"helloworld"例子

The "hello world" example in TensorFlow is MNIST: recognizing handwritten digits. First, run the code below to download the data:


import tensorflow.examples.tutorials.mnist.input_data as input_data

# Load the MNIST data (downloaded into Mnist_data/ on the first run)
mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)

print("Training images shape:", mnist.train.images.shape)
print("Training labels shape:", mnist.train.labels.shape)
print("Validation images shape:", mnist.validation.images.shape)
print("Validation labels shape:", mnist.validation.labels.shape)
print("Test images shape:", mnist.test.images.shape)
print("Test labels shape:", mnist.test.labels.shape)

(Image: Mnist数据.png) Check the printed shapes against your own output.
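For reference, with the standard split that read_data_sets produces (55,000 training, 5,000 validation, and 10,000 test examples, each image flattened to 784 pixels and each label one-hot over 10 classes), the printout should look roughly like this:

Training images shape: (55000, 784)
Training labels shape: (55000, 10)
Validation images shape: (5000, 784)
Validation labels shape: (5000, 10)
Test images shape: (10000, 784)
Test labels shape: (10000, 10)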

Then put the official example in and run it:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)

sess = tf.InteractiveSession()

# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# Train
tf.initialize_all_variables().run()
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  train_step.run({x: batch_xs, y_: batch_ys})

# Test trained model
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

The first run may need to download the data; just wait. When the console prints the following messages, the download has finished:

Extracting Mnist_data/train-images-idx3-ubyte.gz
Extracting Mnist_data/train-labels-idx1-ubyte.gz
Extracting Mnist_data/t10k-images-idx3-ubyte.gz
Extracting Mnist_data/t10k-labels-idx1-ubyte.gz

0.9056

The 0.9056 is the accuracy on the test set. Try to walk through TensorFlow's execution flow on your own: guess which step runs first, which runs next, and what each of the methods involved does. One more reminder: matrices show up everywhere here. When I first came over from Java I was very uncomfortable with this, but just treat everything as matrix data and you'll be fine; the small sketch below walks through the same build-then-run flow.
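As a rough sketch of that flow (not part of the official example, and assuming the same old-style TensorFlow API used above): the tf.* calls only build graph nodes, and nothing is computed until a session runs them, at which point the data flows through as matrices.

import numpy as np
import tensorflow as tf

# Step 1: build the graph. These lines only create ops; no computation happens yet.
a = tf.placeholder(tf.float32, [None, 784])   # a batch of flattened 28x28 images
w = tf.Variable(tf.zeros([784, 10]))          # weight matrix, one column per digit class
scores = tf.matmul(a, w)                      # [None, 784] x [784, 10] -> [None, 10]

# Step 2: run the graph. Only sess.run() actually pushes data through the ops.
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    out = sess.run(scores, feed_dict={a: np.zeros((3, 784), dtype=np.float32)})
    print(out.shape)  # (3, 10): one row of 10 class scores per input image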