Python使用tensorflow中梯度下降算法求解变量最优值

2016-12-21 董付国 Python小屋 Python小屋

TensorFlow是一个用于人工智能的开源神器,是一个采用数据流图(data flow graphs)用于数值计算的开源软件库。数据流图使用节点(nodes)和边线(edges)的有向图来描述数学计算,图中的节点表示数学操作,也可以表示数据输入的起点或者数据输出的终点,而边线表示在节点之间的输入/输出关系,用来运输大小可动态调整的多维数据数组,也就是张量(tensor)。TensorFlow可以在普通计算机、服务器和移动设备的CPU和GPU上展开计算,具有很强的可移植性,并且支持C++、Python等多种语言。


import tensorflow as tf
import numpy as np
import time


#使用 NumPy 生成随机数据, 总共 2行100列个点.
x_data = np.float32(np.random.rand(2, 200))
#矩阵乘法
#这里的W=[0.100, 0.200]和b=0.300是理论数据

通过后面的训练来验证
y_data = np.dot([0.100, 0.200], x_data) + 0.300

#构造一个线性模型,训练求解W和b
#初始值b = [0.0]
b = tf.Variable(tf.zeros([1]))
#初始值W为1x2的矩阵,元素值介于[-1.0, 1.0]区间
W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
#构建训练模型,matmul为矩阵乘法运算
y = tf.matmul(W, x_data) + b

#最小均方差
loss = tf.reduce_mean(tf.square(y - y_data))
#使用梯度下降算法进行优化求解
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

#初始化变量
init = tf.global_variables_initializer()

with tf.Session() as sess:
    #初始化
    sess.run(init)

    #拟合平面,训练次数越多越精确,但是也没有必要训练太多次
    for step in range(0, 201):
        sess.run(train)
        #显示训练过程,这里演示了两种查看变量值的方法
        print(step, sess.run(W), b.eval())



运行结果如下,可以发现求解的结果非常接近理论值,为避免浪费大家流量,这里省略了中间的180个训练结果。

0 [[ 0.15414073  0.32496157]] [ 0.43636867]
1 [[ 0.03701844  0.20617545]] [ 0.21209587]
2 [[ 0.10094656  0.26214167]] [ 0.32871172]
3 [[ 0.07112053  0.22834063]] [ 0.26936483]
4 [[ 0.08938536  0.24124807]] [ 0.30078542]
5 [[ 0.08256587  0.23040327]] [ 0.28532556]
6 [[ 0.08844876  0.23214042]] [ 0.29402208]
7 [[ 0.08755529  0.22768299]] [ 0.29021722]
8 [[ 0.08995744  0.22669716]] [ 0.29283032]
9 [[ 0.09047545  0.2241728 ]] [ 0.29209939]
10 [[ 0.09179883  0.22267541]] [ 0.2930637]
...
191 [[ 0.10000042  0.20000042]] [ 0.29999956]
192 [[ 0.1000004   0.20000039]] [ 0.29999959]
193 [[ 0.10000037  0.20000036]] [ 0.29999962]
194 [[ 0.10000035  0.20000033]] [ 0.29999962]
195 [[ 0.10000034  0.20000032]] [ 0.29999965]
196 [[ 0.10000032  0.2000003 ]] [ 0.29999968]
197 [[ 0.1000003   0.20000029]] [ 0.29999968]
198 [[ 0.10000029  0.20000027]] [ 0.29999971]
199 [[ 0.10000027  0.20000026]] [ 0.29999971]
200 [[ 0.10000026  0.20000026]] [ 0.29999974]