多分類任務的混淆矩陣處理

2021-08-28 11:27:54 字數 1394 閱讀 4177

在多分類任務中,不適合使用pr曲線和roc曲線來進行指標評價,但我們仍可以通過混淆矩陣來進行處理。可以通過matplotlib的matshow()函式,直觀地展示分類結果的好壞。

先使用cross_val_predict得出各個分類值的分數

y_train_pred = cross_val_predict(sgd_clf, x_train_scaled, y_train, cv= 3 )
再使用confusion_matrix()得出最終的混淆矩陣

conf_mx = confusion_matrix(y_train, y_train_pred)
然後使用 matplotlib 的 matshow() 函式,將混淆矩陣以影象的方式呈現

plt.matshow(conf_mx, cmap=plt.cm.gray)
如下圖所示,行代表了實際的類別,列代表了**的結果,從圖中可看出大致都在正對角線上,說明分類結果還不錯。

但是我們應該關注僅包含誤差資料的影象呈現,所以將混淆矩陣的每乙個值除以相應類別的的總數目。這樣子,你可以比較錯誤率,而不是絕對的錯誤數(這對大的類別不公平)

row_sums = conf_mx.sum(axis= 1 , keepdims= true )

norm_conf_mx = conf_mx / row_sums

然後用 0 來填充對角線(使正確的分類不可見),這樣子就只保留了被錯誤分類的資料。

np.fill_diagonal(norm_conf_mx,  0 )

plt.matshow(norm_conf_mx, cmap=plt.cm.gray)

如下圖所示,8,9列比較亮,說明有很多都被錯誤地分到了8,9類中去。相似的,第 8、9 行也相當亮,也就是說8,9類也經常被誤以為是其他類別。

所以通過這個混淆矩陣影象,分析混淆矩陣通常可以給你提供深刻的見解去改善你的分類器。回顧這幅圖,看樣子你應該努力改善分類器在類別8 和類別 9 上的表現,和糾正 3/5 的混淆。

舉例子,你可以嘗試去收集更多的資料,或者你可以構造新的、有助於分類器的特徵。舉例子,寫乙個演算法去數閉合的環(比如,數字 8 有兩個環,數字 6 有乙個, 5 沒有)。又或者你可以預處理(比如,使用 scikit-learn,pillow, opencv)去構造乙個模式,比如閉合的環。

多分類任務的混淆矩陣

今天我將討論如何在多分類中使用混淆矩陣評估模型的效能。什麼是混淆矩陣?它顯示了實際值和 值之間的差異。它告訴我們有多少資料點被正確 哪些資料點沒有被正確 對於多分類來說,它是乙個 n n 矩陣,其中 n 是編號。輸出列中的類別,也稱為目標屬性。一二分類任務中包含了 2 個類也就是乙個 22 矩陣,一...

多分類任務的混淆矩陣和評價指標

之前一直不明白多分類任務的混淆矩陣,今天研究了一下。拿乙個三分類任務來說 cat dog bird 有8個 結果 值 dog,dog,cat cat,cat,dog,bird,cat 真實值 dog,cat,cat,cat,bird,bird,cat,cat 要對每乙個類別做混淆矩陣。拿cat類來說...

Matlab畫混淆矩陣(多分類)

在神經網路和機器學習的結果分析中,常常會用混淆矩陣和roc曲線來分析識別 分類結果的好壞,而且 中也經常出現這種圖。對於卷積神經網路來說畫混淆矩陣很簡單,要用到函式plotconfusion,格式為plotconfusion 實際標籤,標籤 畫出來是這樣的 實際標籤是我們提前就知道的,標籤在神經網路...