训练的模型Model保存,读取

来自软件实验室
跳转至: 导航搜索

当我们训练完后需要保存一些数据以便下次使用,就得把数据保存起来,用的时候在读取。

tensorflow通过tf.Save(sess,save_path,...)来保存,tf.Restore(sess,save_path,...)来读取数据 核心参数是sess和save_path。如下:

保存权重和偏执

#save to file
W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name="weights")
b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases")
#init
init = tf.global_variables_initializer();
#
saver = tf.train.Saver()
#session
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess,"file/save_net.ckpt")
    print ("Save to path:",save_path)
 

打印消息:

('Save to path:', 'file/save_net.ckpt')

上面是保存,下面是读取---------------------------------

读取权重和偏执

W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights")
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases")
#
saver = tf.train.Saver()
# init
init = tf.global_variables_initializer()
# sess
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess,"file/save_net.ckpt")
    print ("weights:",sess.run(W))
    print ("biases:",sess.run(b))

注意点就是,数据定义时候要和原数据类型相同,不然会报错。