1. Generating a ckpt file with TensorFlow
Simply call the following two statements after training finishes:
saver = tf.train.Saver()
saver.save(sess, './name')
For example: saver = tf.train.Saver(); saver.save(sess, './my_checkpoint')
This generates the following files:
|--checkpoint_dir
| |--checkpoint
| |--my_checkpoint.meta
| |--my_checkpoint.data-00000-of-00001
| |--my_checkpoint.index
Before version 0.11, a TensorFlow model consisted of only three files:
|--checkpoint_dir
| |--checkpoint
| |--my_checkpoint.meta
| |--my_checkpoint.ckpt
The .meta file is in protocol buffer (pb) format and stores the graph structure, including variables, ops, collections, and so on. The .ckpt file is a binary file that stores all the variable values, such as weights, biases, and gradients. Since version 0.11, these values have instead been split across two files: .index and .data.
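These files are enough to restore the model later. A minimal sketch of restoring (assuming the checkpoint was saved to the current directory as above):

import tensorflow as tf

with tf.Session() as sess:
    # Rebuild the graph structure from the .meta file
    saver = tf.train.import_meta_graph('./my_checkpoint.meta')
    # Restore the variable values stored in the .index/.data files
    saver.restore(sess, './my_checkpoint')
    # Or restore whatever the 'checkpoint' file lists as the latest:
    # saver.restore(sess, tf.train.latest_checkpoint('./'))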
2. Converting a ckpt file to npy
from tensorflow.python import pywrap_tensorflow
import numpy as np

# Dictionary holding a [kernel, bias] pair for each layer of the autoencoder
autoencoder = {'enc_conv1': [[], []], 'enc_conv2': [[], []], 'enc_conv3': [[], []],
               'dec_conv1': [[], []], 'dec_conv2': [[], []], 'dec_conv3': [[], []]}

path = 'C://Users//Administrator//Desktop//autoencoder//./my_checkpoint'
reader = pywrap_tensorflow.NewCheckpointReader(path)
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    str_name = key
    # Skip the optimizer's auxiliary variables (Adam slots, beta power accumulators)
    if str_name.find('Adam') > -1:
        continue
    if str_name.find('power') > -1:
        continue
    # Variable names look like 'enc_conv1/kernel'; split into layer name and part
    if str_name.find('/') > -1:
        names = str_name.split('/')
        layer_name = names[0]
        layer_info = names[1]
    else:
        layer_name = str_name
        layer_info = None
    if layer_info == 'kernel':
        autoencoder[layer_name][0] = reader.get_tensor(key)
    elif layer_info == 'bias':
        autoencoder[layer_name][1] = reader.get_tensor(key)
    else:
        autoencoder[layer_name] = reader.get_tensor(key)

np.save('autoencoder.npy', autoencoder)
print('save npy over...')
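As a quick sanity check, the saved dictionary can be loaded back and its shapes inspected (a hypothetical snippet, assuming every entry holds a [kernel, bias] pair):

check = np.load('autoencoder.npy', allow_pickle=True).item()
for layer_name, (kernel, bias) in check.items():
    print(layer_name, np.shape(kernel), np.shape(bias))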
3. Loading the weights stored in the npy file
def load_weights():
    # allow_pickle=True is required by newer NumPy versions to load a dict-valued .npy file
    weights_dict = np.load("weights/autoencoder.npy", allow_pickle=True).item()
    assign_list = []
    for op_name in weights_dict.keys():
        # Reuse the variables already defined in the graph under this layer's scope
        with tf.variable_scope(op_name, reuse=tf.AUTO_REUSE):
            for data in weights_dict[op_name]:
                if len(data.shape) == 1:
                    # 1-D arrays are biases
                    var = tf.get_variable('bias')
                    assign_list.append(var.assign(data))
                else:
                    # Higher-dimensional arrays are kernels
                    var = tf.get_variable('kernel')
                    assign_list.append(var.assign(data))
    return assign_list
Outside the function, simply run sess.run(load_weights()).
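Putting it together, a sketch under the assumption that the graph defines its layers with the same scope names (e.g. tf.layers.conv2d(..., name='enc_conv1'), which creates variables called kernel and bias inside that scope):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Overwrite the freshly initialized values with the weights from the npy file
    sess.run(load_weights())
    # ... run inference with the restored weights ...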
4. A simpler way to save and load npy files
Applicable when the weights are defined explicitly and there are not many of them. Saving: use eval() to convert each Tensor to a NumPy array, then save them as a dictionary with np.savez() (a sketch of this save step follows the loading snippet below). Loading: for example:
data = np.load('weights/autoencoder.npz')
self.weights['w1'] = tf.Variable(tf.convert_to_tensor(data['w1']))
self.weights['b1'] = tf.Variable(tf.convert_to_tensor(data['b1']))
self.weights['w2'] = tf.Variable(tf.convert_to_tensor(data['w2']))
self.weights['b2'] = tf.Variable(tf.convert_to_tensor(data['b2']))
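The save step mentioned above might look like this (a sketch assuming the variables live in self.weights and sess is an active session):

# Hypothetical save step: evaluate each Variable to a NumPy array and store them by name
np.savez('weights/autoencoder.npz',
         w1=self.weights['w1'].eval(session=sess),
         b1=self.weights['b1'].eval(session=sess),
         w2=self.weights['w2'].eval(session=sess),
         b2=self.weights['b2'].eval(session=sess))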
Reference links:
Concepts and usage of ckpt files: https://www.cnblogs.com/adong7639/p/7764769.html
Converting a ckpt file to npy: https://blog.csdn.net/raby_gyl/article/details/79075716
https://download.csdn.net/download/weixin_42713739/11317513