Tensorflow之RNN實踐(二)

2021-07-24 18:23:32 字數 2322 閱讀 7348

參考以下這篇文章學習,從上層模組結構到底層**都分析得很清晰

直接從模型構建過程開始,rnn的核心單元是乙個可沿時間序列展開的cell,一般先定義模型中要使用到的cell結構:

lstm_cell = tf.nn

.rnn_cell.basiclstmcell(size, forget_bias=0.0, state_is_tuple=true)

if is_training and config.keep_prob < 1:

lstm_cell = tf.nn

lstm_cell, output_keep_prob=config.keep_prob)

cell = tf.nn

.rnn_cell.multirnncell([lstm_cell] * config.num_layers, state_is_tuple=true)

接下來是word embedding. 預處理的資料是one-hot向量,但是詞與詞之間的關係沒辦法用位置關係來描述,因此需要用單詞嵌入。

with tf.device("/cpu:0"):

embedding = tf.get_variable(

"embedding", [vocab_size, size], dtype=data_type())

inputs = tf.nn

.embedding_lookup(embedding, input_.input_data)

上面是一種單詞嵌入的方法,設定乙個單詞嵌入矩陣,該矩陣的引數也是訓練引數,在網路訓練過程中同時訓練這些引數。當然還有一種方法,使用已經訓練好的單詞向量,通過詞典查詢使用,這種方法可以防止在語料不夠時,不足以訓練出好的單詞向量的情況。

設定cell的輸入輸出

outputs, state = tf.nn

.rnn(cell, inputs, initial_state=self._initial_state)

設定最後的輸出:

output = tf.reshape(tf.concat(1, outputs), [-1, size])

softmax_w = tf.get_variable(

"softmax_w", [size, vocab_size], dtype=data_type())

softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())

logits = tf.matmul(output, softmax_w) + softmax_b

設定損失函式:

loss = tf.nn

.seq2seq.sequence_loss_by_example(

[logits],

[tf.reshape(input_.targets, [-1])],

[tf.ones([batch_size * num_steps], dtype=data_type())])

self._cost = cost = tf.reduce_sum(loss) / batch_size

這裡的cost是乙個batch裡面label錯誤的個數的平均數,在輸出的時候我發現輸出的評價指標是perplexity,原來以為是語言模型的word perplexity, 在另一篇部落格 中有提到,但是在ptbmodel的run_epoch函式裡面發現,perplexity的定義如下:

np.exp(costs / iters)
也就是說依然是cost的正向函式。

接下來是梯度傳遞函式,與其它的nn一樣:

self._lr = tf.variable(0.0, trainable=false)

tvars = tf.trainable_variables()

grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),

config.max_grad_norm)

optimizer = tf.train

.gradientdescentoptimizer(self._lr)

zip(grads, tvars),

global_step=tf.contrib

.framework

.get_or_create_global_step())

注意這裡的clip_by_global_norm函式是限制梯度過大的時候要削去,即取設定好的最大梯度值,防止訓練到區域性極小值。

官網例子**:

tensorflow實現普通RNN

coding utf 8 author zhangxianke file test.py time 2018 11 09 from tensorflow.examples.tutorials.mnist import input data import tensorflow as tf data i...

TensorFlow2 0之RNN情感分類問題實戰

tensorflow2.0之rnn情感分類問題實戰 import tensorflow as tf from tensorflow import keras from tensorflow.keras import sequential,layers,datasets,optimizers,loss...

用tensorflow構建動態RNN

直接看 def create cell cell rnn.lstmcell num units rnn cell rnn.multirnncell create cell for in range 2 output,states tf.nn.dynamic rnn rnn cell,x,dtype ...