結果視覺化

2022-09-01 18:15:08 字數 2261 閱讀 9540

1

import

tensorflow as tf

2import

numpy as np

3import

matplotlib.pyplot as plt

4def add_layer(inputs, in_size, out_size,activation_function=none):

5 weights =tf.variable(tf.random_normal([in_size, out_size]))

6 biases = tf.variable(tf.zeros([1, out_size]) + 0.1)

7 wx_plus_b = tf.matmul(inputs, weights) +biases

8if activation_function is

none:

9 outputs =wx_plus_b

10else

:11 outputs =activation_function(wx_plus_b)

12return

outputs

1314 x_data=np.linspace(-1,1,300,dtype=np.float32)[:,np.newaxis]

15 noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)

16 y_data=np.square(x_data)-0.5+noise

17 xs=tf.placeholder(tf.float32,[none,1],name='

x_input')

18 ys=tf.placeholder(tf.float32,[none,1],name='

y_input')

1920 l1=add_layer(xs,1,10,activation_function=tf.nn.relu) #

隱藏層21 prediction=add_layer(l1,10,1,activation_function=none) #

輸出層22 loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),

23 reduction_indices=[1]))

24 train_step = tf.train.gradientdescentoptimizer(0.1).minimize(loss)

25 init =tf.global_variables_initializer()

26 sess =tf.session()

27sess.run(init)

28 fig = plt.figure() #

生成框架

29 ax=fig.add_subplot(1,1,1) #

連續性的畫圖

30 ax.scatter(x_data,y_data) #

用點的形式把真實的資料畫出來

31 plt.ion() #

用於連續顯示,不會show一下就停止顯示

32plt.show()

33for i in range(1000):

34 sess.run(train_step,feed_dict=)

35if i%50==0:

36print(sess.run(loss,feed_dict=))

37try

:38 ax.lines.remove(lines[0]) #

在中去除第一條線

39except

exception:

40pass

41 prediction_value = sess.run(prediction,feed_dict=)

42 lines=ax.plot(x_data,prediction_value,'

r-',lw=5) #

紅色,寬度為5的線,x,y軸的資料plot上去

43 plt.pause(0.1) #

暫停0.1s

結果視覺化

import tensorflow as tf import numpy as np import matplotlib.pyplot as plt def add layer inputs,input size,output size,activation function none weight...

Tensorfow 之 結果視覺化

安裝 中輸出結果視覺化模組 matplotlib import matplotlib.pyplot as plt使用上篇部落格的例子,在裡面新增了視覺化的部分,就可以將這先乏味的資料通通影象更直觀的檢視了。首先需要的是構建圖形,用散點圖描述真實資料之間的關係。每隔50次訓練重新整理一次圖形,用紅色 ...

高維聚類結果視覺化

利用sklearn包裡的birch演算法,以iris資料集,聚類結果視覺化 如下 import numpy as np import matplotlib.pyplot as plt from sklearn.datasets.samples generator import make blobs ...