“训练的模型Model保存,读取”的版本间的差异

来自软件实验室
跳转至: 导航搜索
(创建页面,内容为“保存模型 save(sess,save_path,...) 从文件中恢复模型 restore(sess,save_path,...)”)
 
 
第1行: 第1行:
保存模型
+
当我们训练完后需要保存一些数据以便下次使用,就得把数据保存起来,用的时候在读取。
  
    save(sess,save_path,...)
+
tensorflow通过tf.Save(sess,save_path,...)来保存,tf.Restore(sess,save_path,...)来读取数据
 +
核心参数是sess和save_path。如下:
  
    从文件中恢复模型
+
保存权重和偏执
  
     restore(sess,save_path,...)
+
<nowiki>
 +
#save to file
 +
W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name="weights")
 +
b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases")
 +
#init
 +
init = tf.global_variables_initializer();
 +
#
 +
saver = tf.train.Saver()
 +
#session
 +
with tf.Session() as sess:
 +
     sess.run(init)
 +
    save_path = saver.save(sess,"file/save_net.ckpt")
 +
    print ("Save to path:",save_path)
 +
</nowiki>
 +
 
 +
打印消息:
 +
<nowiki>('Save to path:', 'file/save_net.ckpt')</nowiki>
 +
 
 +
-----------上面是保存,下面是读取---------------------------------
 +
 
 +
读取权重和偏执
 +
 
 +
<nowiki>W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights")
 +
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases")
 +
#
 +
saver = tf.train.Saver()
 +
# init
 +
init = tf.global_variables_initializer()
 +
# sess
 +
with tf.Session() as sess:
 +
    sess.run(init)
 +
    saver.restore(sess,"file/save_net.ckpt")
 +
    print ("weights:",sess.run(W))
 +
    print ("biases:",sess.run(b))
 +
</nowiki>
 +
 
 +
'''注意点就是,数据定义时候要和原数据类型相同,不然会报错。'''

2017年1月9日 (一) 10:44的最新版本

当我们训练完后需要保存一些数据以便下次使用,就得把数据保存起来,用的时候在读取。

tensorflow通过tf.Save(sess,save_path,...)来保存,tf.Restore(sess,save_path,...)来读取数据 核心参数是sess和save_path。如下:

保存权重和偏执

#save to file
W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name="weights")
b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases")
#init
init = tf.global_variables_initializer();
#
saver = tf.train.Saver()
#session
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess,"file/save_net.ckpt")
    print ("Save to path:",save_path)
 

打印消息:

('Save to path:', 'file/save_net.ckpt')

上面是保存,下面是读取---------------------------------

读取权重和偏执

W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights")
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases")
#
saver = tf.train.Saver()
# init
init = tf.global_variables_initializer()
# sess
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess,"file/save_net.ckpt")
    print ("weights:",sess.run(W))
    print ("biases:",sess.run(b))

注意点就是,数据定义时候要和原数据类型相同,不然会报错。