tensorflow增加新的层后重载模型部分参数

it2022-05-05  196

踩坑实录: 在做迁移学习的时候经常会碰到 增加了新的层却需要调取已有模型的部分参数的情况 因为已有的checkpoint里并没有新加入层的variables,报错为:

Key xxx not found in checkpoint

可以通过get_collection/看看该层的所有variables var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope=‘新加入层的scope’) 或者 var=slim.get_variables('新加入层的scope')

在restore的时候,知道哪些variables是与新加入的层相关的之后,exclude这些variables就好了,(但我在解决实际问题的时候 发现tf无法正则匹配scope的关键字 比如新增加的scope为 scope1=model/inference/dense 那么这个scope属于绝对路径 比如scope2=model/optimizer/model/inference/dense 这个就得重新写入exclude list 我们只能比较前后有哪些key是不同的 然后剔除这些不同的key):

var_to_restore=tf.contrib.framework.get_variables_to_restore(exclude=['scope1','scope2'...]) saver=tf.train.Saver(var_to_restore,max_to_keep=xx)

最新回复(0)