问题
ValueError: Input 0 of node bn/cond/ReadVariableOp/Switch was passed float from bn/moving_mean:0 incompatible with expected resource.
解决办法: 在加载 Keras 模型之前调用 `tf.keras.backend.set_learning_phase(0)`（即下方 `convert` 中的做法），参考 https://github.com/keras-team/keras/issues/11032#issuecomment-429989228
code
from tensorflow
.python
.framework
import graph_util
, graph_io
from tensorflow
.python
.platform
import gfile
from tensorflow
import keras
as k
import tensorflow
as tf
def freeze_graph(graph, session, save_root, save_name, keep_var_name=None,
                 output_names=None, clear_devices=True):
    """Convert the variables of ``graph`` to constants and write a binary .pb.

    Args:
        graph: the graph to freeze (the caller passes ``session.graph``).
        session: session holding the current variable values.
        save_root: directory the frozen graph is written to.
        save_name: file name of the frozen graph inside ``save_root``.
        keep_var_name: optional iterable of variable op names that must NOT
            be converted to constants.
        output_names: optional list of output op names to keep. The caller's
            list is copied, never modified.
        clear_devices: if True, clear the ``device`` field of every node.

    Returns:
        The frozen ``GraphDef`` (also written to ``save_root/save_name``).
    """
    with graph.as_default():
        global_vars = tf.global_variables()
        freeze_var_names = list(
            set(v.op.name for v in global_vars).difference(keep_var_name or []))
        # Build a new list: ``+=`` on the argument would mutate the caller's list.
        output_names = list(output_names or [])
        # NOTE(review): the variable ops are listed as outputs too, presumably
        # so their (now constant) nodes survive sub-graph extraction.
        output_names += [v.op.name for v in global_vars]

        # Drop nodes that are only needed for training.
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        if clear_devices:
            for node in graphdef_inf.node:
                node.device = ""

        graphdef_frozen = tf.graph_util.convert_variables_to_constants(
            session, graphdef_inf, output_names, freeze_var_names)
        graph_io.write_graph(graphdef_frozen, save_root, save_name, as_text=False)
        return graphdef_frozen
def convert(model_path, save_root='./models/', save_name='model.pb'):
    """Load a Keras model file and write it out as a frozen TensorFlow graph.

    Args:
        model_path: path of the Keras model file (e.g. ``.h5``) to load.
        save_root: directory the frozen graph is written to.
        save_name: file name of the frozen graph inside ``save_root``.
    """
    # Must run BEFORE load_model: freezing a model loaded without this fails
    # with "Input 0 of node bn/cond/ReadVariableOp/Switch ... incompatible
    # with expected resource" (keras-team/keras#11032).
    tf.keras.backend.set_learning_phase(0)
    model = k.models.load_model(model_path)
    session = tf.keras.backend.get_session()
    freeze_graph(session.graph,
                 session,
                 output_names=[out.op.name for out in model.outputs],
                 save_root=save_root,
                 save_name=save_name)
def show_graph(model_path, log_dir='./logs/'):
    """Import a frozen binary GraphDef and write it as a TensorBoard event file.

    Args:
        model_path: path of the binary ``.pb`` GraphDef to visualise.
        log_dir: directory the TensorBoard event file is written to.
    """
    graph_def = tf.GraphDef()
    with gfile.GFile(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    # Import into a fresh graph. The default graph may already contain the
    # Keras model built by convert() in the same process, which would
    # otherwise be mixed into the dumped graph.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')

    writer = tf.summary.FileWriter(log_dir)
    try:
        writer.add_graph(graph)
    finally:
        writer.close()  # close() flushes pending events before closing
if __name__ == '__main__':
    keras_model_path = './models/model.h5'
    frozen_model_path = './models/model.pb'
    # Freeze the Keras model first, then export the frozen graph for TensorBoard.
    convert(keras_model_path)
    show_graph(frozen_model_path)
服务
from tensorflow
.python
.platform
import gfile
from flask
import Flask
, request
, jsonify
from PIL
import Image
import tensorflow
as tf
import numpy
as np
import io
app = Flask(__name__)

# Load the frozen graph into the default graph. The ``with`` block closes the
# file handle, which was previously left open for the process lifetime.
graph_def = tf.GraphDef()
with gfile.GFile('../models/model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
graph = tf.get_default_graph()

# Tensor names of the exported model's input and softmax output.
inputs = graph.get_tensor_by_name('resnet50_input:0')
outputs = graph.get_tensor_by_name('dense_1/Softmax:0')
# Per-row argmax class index.
pred_op = tf.argmax(outputs, axis=1)
# Max over ALL elements of ``outputs``. For a single-image batch this equals
# the top-1 probability.
score_op = tf.reduce_max(outputs)

sess = tf.Session()
@app.route('/predict', methods=['GET', 'POST'])
def predict():
    """Classify an image uploaded as the multipart field ``image``.

    The JSON response always has ``errno`` and ``errmsg``. On success
    ``errno`` is 0, and the response also carries ``pred`` (argmax class
    index) and ``score`` (max softmax value). On failure ``errno`` is 1 and
    ``errmsg`` explains why.
    """
    data = {'errno': 1, 'errmsg': ''}
    if request.method != 'POST':
        data['errmsg'] = 'use POST with an "image" file field'
        return jsonify(data)

    upload = request.files.get('image')
    if not upload:
        data['errmsg'] = 'missing "image" file field'
        return jsonify(data)

    try:
        image = Image.open(io.BytesIO(upload.read()))
        # Force 3 channels: RGBA / greyscale uploads would otherwise yield an
        # array whose last dimension does not match a 3-channel input.
        image = image.convert('RGB').resize((224, 224))
    except OSError:  # PIL raises OSError subclasses for undecodable images
        data['errmsg'] = 'cannot decode uploaded image'
        return jsonify(data)

    # NOTE(review): raw 0-255 pixel values are fed; confirm this matches the
    # preprocessing used when the model was trained.
    image = np.expand_dims(np.asarray(image, 'float32'), axis=0)
    print('* Inputs shape:', image.shape)

    pred, score = sess.run([pred_op, score_op], feed_dict={inputs: image})
    print(pred[0], score)

    # Convert NumPy scalars to plain Python numbers so jsonify can encode them.
    data['errno'] = 0
    data['pred'] = int(pred[0])
    data['score'] = float(score)
    return jsonify(data)
if __name__ == '__main__':
    # Bind the Flask server to a fixed interface and port.
    listen_host, listen_port = '10.255.83.88', 8088
    app.run(host=listen_host, port=listen_port)
预测
curl -X POST -F image=@1.jpg 'http://10.255.83.88:8088/predict'