“训练的模型Model保存,读取”的版本间的差异
来自软件实验室
(创建页面,内容为“保存模型 save(sess,save_path,...) 从文件中恢复模型 restore(sess,save_path,...)”) |
|||
第1行: | 第1行: | ||
− | + | 当我们训练完后需要保存一些数据以便下次使用,就得把数据保存起来,用的时候在读取。 | |
− | + | tensorflow通过tf.Save(sess,save_path,...)来保存,tf.Restore(sess,save_path,...)来读取数据 | |
+ | 核心参数是sess和save_path。如下: | ||
− | + | 保存权重和偏执 | |
− | + | <nowiki> | |
+ | #save to file | ||
+ | W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name="weights") | ||
+ | b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases") | ||
+ | #init | ||
+ | init = tf.global_variables_initializer(); | ||
+ | # | ||
+ | saver = tf.train.Saver() | ||
+ | #session | ||
+ | with tf.Session() as sess: | ||
+ | sess.run(init) | ||
+ | save_path = saver.save(sess,"file/save_net.ckpt") | ||
+ | print ("Save to path:",save_path) | ||
+ | </nowiki> | ||
+ | |||
+ | 打印消息: | ||
+ | <nowiki>('Save to path:', 'file/save_net.ckpt')</nowiki> | ||
+ | |||
+ | -----------上面是保存,下面是读取--------------------------------- | ||
+ | |||
+ | 读取权重和偏执 | ||
+ | |||
+ | <nowiki>W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights") | ||
+ | b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases") | ||
+ | # | ||
+ | saver = tf.train.Saver() | ||
+ | # init | ||
+ | init = tf.global_variables_initializer() | ||
+ | # sess | ||
+ | with tf.Session() as sess: | ||
+ | sess.run(init) | ||
+ | saver.restore(sess,"file/save_net.ckpt") | ||
+ | print ("weights:",sess.run(W)) | ||
+ | print ("biases:",sess.run(b)) | ||
+ | </nowiki> | ||
+ | |||
+ | '''注意点就是,数据定义时候要和原数据类型相同,不然会报错。''' |
2017年1月9日 (一) 10:44的最新版本
当我们训练完后需要保存一些数据以便下次使用,就得把数据保存起来,用的时候在读取。
tensorflow通过tf.Save(sess,save_path,...)来保存,tf.Restore(sess,save_path,...)来读取数据 核心参数是sess和save_path。如下:
保存权重和偏执
#save to file W = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32,name="weights") b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases") #init init = tf.global_variables_initializer(); # saver = tf.train.Saver() #session with tf.Session() as sess: sess.run(init) save_path = saver.save(sess,"file/save_net.ckpt") print ("Save to path:",save_path)
打印消息:
('Save to path:', 'file/save_net.ckpt')
上面是保存,下面是读取---------------------------------
读取权重和偏执
W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights") b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases") # saver = tf.train.Saver() # init init = tf.global_variables_initializer() # sess with tf.Session() as sess: sess.run(init) saver.restore(sess,"file/save_net.ckpt") print ("weights:",sess.run(W)) print ("biases:",sess.run(b))
注意点就是,数据定义时候要和原数据类型相同,不然会报错。