本文轉載于 SegmentFault 社群 作者:zack 這個文章是用 pytorch 和 matplotlib 實作一個二進制分類器并且可視化。
思路:
- 自己生成兩團資料
- 定義自己的神經網絡類
- 訓練網絡
- 列印出邊界
先放效果圖:
關于可視化
定義網絡、訓練網絡主要沒什麼好說的啦其實,畢竟有 pytorch 這麼好的架構,已經提供了如此簡單的代碼工作。 主要是可視化的技巧。 主要是 matplotlib 中有個 contourf ,本身是畫等高線用的,就是地理中那個三維圖投射到二維圖的那種圖。 我們可以把這個用到可視化上來 (當然隻是3維的,如果是更高次元就沒法用這個可視化了) 。
具體怎麼可視化的?
首先,先自己生成 200 個訓練資料 (這步對應 getData 函數) ,然後把屬于不同類别的資料染上不同顔色;
然後,進行網絡的訓練 (對應 run 函數) ;
然後,用同樣的資料讓網絡進行預測。因為二進制分類器最後預測的結果要麼是 0,要麼是 1,是以可以利用 matplotlib 中的畫等高線的函數,來近似畫出決策邊界。這一步主要對應 showBoundary 函數。
使用 conturf 函數
這個函數我自己在用的時候有點懵逼,使用這個要先 meshgrid , mesh 合并的意思, grid 網格的意思,要把兩個清單先合成一個網格,這個形式我也不是很喜歡。
勉勉強強參考了一些部落格才寫了出來。具體我也沒辦法一一講述,還請各位原諒。
不過其中, cmap 是畫出來的圖的風格參數,可以是 camp=plt.cm.hot 等等, alpha 是透明度。 用了 conturf 這個函數,就可以有顔色的差別了。
最後放代碼
- END -