如何視覺化深度學習網路中Attention層

2022-08-01 10:00:18 字數 2153 閱讀 1853

在訓練深度學習模型時,常想一窺網路結構中的attention層權重分布,觀察序列輸入的哪些詞或者詞組合是網路比較care的。在小**中主要研究了關於詞性pos對輸入序列的注意力機制。同時對比實驗採取的是words的self-attention機制。

下圖主要包含兩列:word_attention是self-attention機制的模型訓練結果,pos_attention是詞性模型的訓練結果。

可以看出,相對於word_attention,pos的注意力機制不僅能夠捕捉到評價的aspect,也能根據aspect關聯的詞借助情感語義表達的詞性分布,care到相關詞性的情感詞。

seqs = [["這", "是", "乙個", "測試", "樣例", "而已"]]

attns = [[0.01, 0.19, 0.12, 0.7, 0.2, 0.1]]

for i in range(batch_size):

text = mk_html(seqs[i], attns[i])

display(html(text))

需要在model的返回列表中,新增attention_weight的輸出,理論上維度應該和輸入序列的長度是一致的。

# load model

import torch

# if you train on gpu, you need to move onto cpu

model = torch.load("../docs/model_chk/2018-11-07-02:45:37", map_location=lambda storage, location: storage)

from torch.autograd import variable

for batch_idx, samples in enumerate(test_loader, 0):

v_word = variable(samples['word_vec'])

v_final_label = samples['top_label']

model.eval()

final_probs, att_weight = model(v_word, v_pos)

batch_words = towords(samples["word_vec"].numpy(), idx_word) # id轉化為word

batch_att = getatten(batch_words, att_weight.data.numpy()) # 去除padding詞,根據words的長度擷取attention

labels = tolabel(samples['top_label'].numpy()) # 真實標籤

pre_labels = tolabel(final_probs.data.numpy() >= 0.5) # **標籤

for i in range(len(batch_words)):

text = mk_html(batch_words[i], batch_att[i])

print(labels[i], pre_labels[i])

display(html(text))

深度學習 TensorBoard視覺化

1 概述 tensorboard是tensorflow的視覺化工具 通過tensorflow程式執行過程中輸出的日誌檔案視覺化tensorflow程式的執行狀態 tensorflow和tensorboard程式跑在不同的程序中 清除default graph和不斷增加的節點 tf.reset def...

深度學習 網路正則化

in 1n i 1n yi f xi 2 r d min 1n i 1n yi f xi 2 r d 2 12 22l2 12 222 1 i i l2 1 i i 1 1 2 22 1 1 2 22 2 c 2 原理 對於某層神經元,在訓練階段均以概率p隨機將該神經元權重設定為0,在測試階段所有神...

深度學習 Keras MNIST 資料視覺化

學習 機器學習 machine learning 或 深度學習 deep learning 的乙個重要關鍵是 視覺化 visualization 視覺化 可以用在資料上 或是在模型上 或是在結果上 或是在過程上。本文是 keras mnist 的資料視覺化。注 以上程式在 enthought can...