Converting a Keras model to a TensorFlow frozen graph (.pb)

it2022-05-05  94

问题

ValueError: Input 0 of node bn/cond/ReadVariableOp/Switch was passed float from bn/moving_mean:0 incompatible with expected resource.

解决办法：在加载 Keras 模型之前调用 `tf.keras.backend.set_learning_phase(0)`（见下方 `convert` 函数），参见 https://github.com/keras-team/keras/issues/11032#issuecomment-429989228

code

#! -*- coding: utf-8 -*-
"""Freeze a Keras ``.h5`` model into a binary TensorFlow 1.x GraphDef (``.pb``)
and dump the frozen graph to a log directory for TensorBoard."""
from tensorflow.python.framework import graph_util, graph_io
from tensorflow.python.platform import gfile
from tensorflow import keras as k
import tensorflow as tf


def freeze_graph(graph, session, save_root, save_name, keep_var_name=None,
                 output_names=None, clear_devices=True):
    """Convert the variables of *graph* to constants and write it to disk.

    Args:
        graph: The ``tf.Graph`` to freeze.
        session: Session holding the current variable values of *graph*.
        save_root: Directory the frozen graph is written into.
        save_name: File name of the frozen graph inside *save_root*.
        keep_var_name: Optional iterable of variable op names that must NOT
            be converted to constants.
        output_names: Optional list of output node names. It is copied, so
            the caller's list is never modified.
        clear_devices: If true, strip device placements from every node
            before freezing.
    """
    with graph.as_default():
        global_vars = tf.global_variables()
        freeze_var_names = list(
            {v.op.name for v in global_vars}.difference(keep_var_name or []))
        # Copy first: `+=` on the caller's list would otherwise mutate it.
        output_names = list(output_names or [])
        # As in the referenced Keras issue snippet, variable nodes are listed
        # as outputs too, so they are kept when the graph is pruned.
        output_names += [v.op.name for v in global_vars]
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        if clear_devices:
            for node in graphdef_inf.node:
                node.device = ""
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(
            session, graphdef_inf, output_names, freeze_var_names)
        graph_io.write_graph(graphdef_frozen, save_root, save_name, as_text=False)


def convert(model_path, save_root='./models/', save_name='model.pb'):
    """Load the Keras model at *model_path* and freeze it to save_root/save_name.

    Args:
        model_path: Path to a model file accepted by ``keras.models.load_model``.
        save_root: Output directory (default ``./models/``).
        save_name: Output file name (default ``model.pb``).
    """
    # Must run BEFORE load_model: this is the workaround from
    # keras-team/keras#11032 for "Input 0 of node bn/cond/ReadVariableOp/Switch
    # was passed float ... incompatible with expected resource".
    tf.keras.backend.set_learning_phase(0)
    model = k.models.load_model(model_path)
    session = tf.keras.backend.get_session()
    freeze_graph(session.graph, session,
                 output_names=[out.op.name for out in model.outputs],
                 save_root=save_root, save_name=save_name)


def show_graph(model_path, log_dir='./logs/'):
    """Import the frozen graph at *model_path* and write it to *log_dir*.

    The written event file can be inspected with ``tensorboard --logdir``.
    """
    with tf.Session() as sess:
        graph_def = tf.GraphDef()
        with gfile.GFile(model_path, 'rb') as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
        writer = tf.summary.FileWriter(log_dir)
        try:
            writer.add_graph(sess.graph)
            writer.flush()
        finally:
            # Always release the event file, even if add_graph fails.
            writer.close()


if __name__ == '__main__':
    convert('./models/model.h5')
    show_graph('./models/model.pb')

服务

#! -*- coding: utf-8 -*-
"""Flask service that classifies an uploaded image with a frozen TF graph."""
from tensorflow.python.platform import gfile
from flask import Flask, request, jsonify
from PIL import Image
import tensorflow as tf
import numpy as np
import io

app = Flask(__name__)

# Load the frozen graph once at import time; `with` closes the file handle.
graph_def = tf.GraphDef()
with gfile.GFile('../models/model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
graph = tf.get_default_graph()
# Tensor names must match the frozen model; these were taken from the
# exported graph (e.g. as shown in TensorBoard).
inputs = graph.get_tensor_by_name('resnet50_input:0')
outputs = graph.get_tensor_by_name('dense_1/Softmax:0')
pred_op = tf.argmax(outputs, axis=1)
# Per-sample maximum probability, aligned with pred_op (axis=1).
score_op = tf.reduce_max(outputs, axis=1)
sess = tf.Session()


def _load_image(raw, size=(224, 224)):
    """Decode *raw* image bytes into a float32 batch holding one RGB image.

    Converting to RGB first keeps the channel count at 3 for RGBA, palette
    or grayscale uploads. NOTE(review): no model-specific normalisation is
    applied; confirm this matches how the model was trained.
    """
    image = Image.open(io.BytesIO(raw)).convert('RGB')
    image = image.resize(size)
    array = np.asarray(image, 'float32')
    return np.expand_dims(array, axis=0)


@app.route('/predict', methods=['GET', 'POST'])
def predict():
    """Classify the image sent as the multipart field ``image``.

    Returns JSON ``{'errno', 'errmsg'}``; on success ``errno`` is 0 and the
    predicted class index (``pred``) and its probability (``score``) are added.
    """
    data = {'errno': 1, 'errmsg': ''}
    if request.method != 'POST':
        data['errmsg'] = 'use POST with an "image" file field'
        return jsonify(data)
    upload = request.files.get('image')
    if not upload:
        data['errmsg'] = 'missing "image" file field'
        return jsonify(data)
    try:
        image = _load_image(upload.read())
    except OSError:
        # PIL raises an OSError subclass for data it cannot decode.
        data['errmsg'] = 'cannot decode image'
        return jsonify(data)
    print('* Inputs shape:', image.shape)
    pred, score = sess.run([pred_op, score_op], feed_dict={inputs: image})
    print(pred[0], score[0])
    data.update(errno=0, pred=int(pred[0]), score=float(score[0]))
    return jsonify(data)


if __name__ == '__main__':
    # Binds only to this interface; clients must use this address and port.
    app.run(host='10.255.83.88', port=8088)

预测

curl -X POST -F image=@1.jpg 'http://10.255.83.88:8088/predict'


最新回复(0)