Tuesday, May 9, 2017

Saving a model that does not import the backend as "K" breaks model loading. #5088

 https://github.com/fchollet/keras/issues/5088

airalcorn2 commented on Jan 19
The following code saves the model without complaint, but the final load_model call fails with: NameError: name 'backend' is not defined.
from keras import backend
from keras.layers import Dense, Input, Lambda
from keras.models import load_model, Model

x = Input((100, ), dtype = "float32")
x_i = Lambda(lambda x: x + backend.epsilon())(x)
o = Dense(10, activation = "softmax")(x_i)
model = Model(input = [x], output = o)
model.compile("sgd", loss = "categorical_crossentropy")
model.save("toy_model.h5")
model = load_model("toy_model.h5")
airalcorn2 changed the title from Saving a model that does not import the Keras backend as "K" breaks model loading. to Saving a model that does not import the backend as "K" breaks model loading. on Jan 19
joelthchao (Contributor) commented on Jan 19
Quick solution: change the backend import name. When it loads the model, Keras doesn't know about the variable name used inside your Lambda layer:
from keras import backend as K
from keras.layers import Dense, Input, Lambda
from keras.models import load_model, Model

x = Input((100, ), dtype = "float32")
x_i = Lambda(lambda x: x + K.epsilon())(x)
o = Dense(10, activation = "softmax")(x_i)
model = Model(input = [x], output = o)
model.compile("sgd", loss = "categorical_crossentropy")
model.save("toy_model.h5")
model = load_model("toy_model.h5")
airalcorn2 commented on Jan 19
@joelthchao, thanks for trying to help, but I had already realized that importing backend as K would make the error go away (hence my title). The problem is that Keras lets you save a model whose Lambda refers to the backend module by its actual name, but it cannot load that model again.
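A simplified, pure-Python sketch of why saving succeeds but loading fails. It assumes that Keras stores only the lambda's compiled bytecode (marshalled) and, on load, rebuilds the function against the globals of its own layer module, where the backend is bound as K rather than backend. The helpers dump_function and load_function and the FakeBackend class are hypothetical stand-ins, not Keras APIs.

```python
import builtins
import marshal
import types


def dump_function(func):
    # Hypothetical helper: store only the compiled bytecode.
    # Global names such as `backend` are referenced by name, not captured.
    return marshal.dumps(func.__code__)


def load_function(raw, globs):
    # Hypothetical helper: rebuild the function against the loader's globals.
    return types.FunctionType(marshal.loads(raw), globs)


class FakeBackend:
    """Hypothetical stand-in for keras.backend."""

    @staticmethod
    def epsilon():
        return 1e-7


# User code: the backend is bound to the name `backend`.
backend = FakeBackend()
saved = dump_function(lambda x: x + backend.epsilon())

# Loader namespace: assumed to bind the backend only as `K`.
loader_globals = {"__builtins__": builtins, "K": FakeBackend()}

restored = load_function(saved, loader_globals)
try:
    restored(1.0)
    error = None
except NameError as exc:
    error = str(exc)
print(error)  # the NameError message mentions 'backend'

# The same lambda written against `K` resolves in the loader's namespace.
# (K is undefined here, which is fine because this lambda is never called directly.)
saved_k = dump_function(lambda x: x + K.epsilon())
restored_k = load_function(saved_k, loader_globals)
print(restored_k(1.0))
```

Under this model, the saved bytecode only records that the lambda looks up a global called backend; whether that lookup succeeds depends entirely on the namespace the loader supplies, which is why renaming the import to K makes the error disappear.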


bstriner (Contributor) commented on Jan 19
No quick fix, but it's worth looking into. There are similar issues if you use custom imports in your lambda. Another workaround is to do the imports inside the function you pass to Lambda:
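from keras.layers import Input, Lambda

def mylambda(x):
    # Imported when the function runs, so this still works after save/load
    from keras import backend
    # import mymodule  # custom modules can be imported the same way
    return x + backend.epsilon()

h = Input((100, ), dtype = "float32")
y = Lambda(mylambda)(h)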
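This workaround relies on import statements being executed each time the function runs, rather than being captured when it is defined. A pure-Python sketch under the simplifying assumption that only the function's bytecode is saved and the loader supplies its own globals; round_trip, uses_module_import and imports_inside are hypothetical names, and the standard math module stands in for backend or mymodule.

```python
import builtins
import marshal
import math
import types


def round_trip(func, loader_globals):
    # Simplified model of save/load (an assumption, not Keras's code):
    # keep only the bytecode, then rebind it to the loader's globals.
    code = marshal.loads(marshal.dumps(func.__code__))
    return types.FunctionType(code, loader_globals)


def uses_module_import(x):
    # Relies on the module-level `import math` above.
    return math.sqrt(x)


def imports_inside(x):
    # The import runs on every call and binds a local name,
    # so no global binding for `math` is needed.
    import math
    return math.sqrt(x)


# Hypothetical loader namespace: has builtins, but no `math`.
loader_globals = {"__builtins__": builtins}

broken = round_trip(uses_module_import, loader_globals)
working = round_trip(imports_inside, loader_globals)

try:
    broken(16.0)
    broken_ok = True
except NameError:
    broken_ok = False

print(broken_ok)      # False: `math` is not defined in the loader's globals
print(working(16.0))  # 4.0
```

The cost of this design is an import lookup on every call, but once a module is in sys.modules that lookup is cheap, and it makes the function independent of whatever names the loading code happens to have in scope.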