天天看點

寫給程式員的機器學習入門 (八 補充) - 使用 GPU 訓練模型

在之前的文章中我訓練模型都是使用的 CPU,因為家中黃臉婆不允許我浪費錢買電腦😭。終于的,附近一個廢品資源回收筒的朋友轉讓給我一台破爛舊電腦,是以我現在可以體驗使用 GPU 訓練模型了🥳。

pytorch, tensorflow 等主流的架構的 GPU 支援都基于 CUDA 架構,而目前提供 CUDA 支援的顯示卡隻有 nvidia,這次我撿到的破爛是 GTX 1650 4GB 是以滿足最低要求了。簡單描述下目前各種顯示卡的支援程度:

Intel 核顯:死心叭

APU:沒法用

Nvidia Geforce

2GB 可以用來跑一些入門例子

4GB 可以跑一些簡單模型

6GB 可以跑一些中級模型

8GB 可以跑一些進階模型

10GB以上 可以跑最前沿的模型

Radeon:要折騰,試試 ROCm

如果真的要玩機器學習推薦購買 RTX 系列,因為有 tensor 核心和 16 位浮點數支援,訓練速度會快很多,并且使用 16 位浮點數可以讓顯存占用少一半。雖然在過幾個星期就可以看到 3000 系列的顯示卡了,可惜沒錢買🤒。此外,明年如果出支援機器學習的民用國産顯示卡必定會大力支援😡。

Windows 的話會通過 Windows Update 自動安裝, pytorch 會自動檢測出顯示卡,不需要做任何工作。Linux 需要安裝 Nvidia 官方的閉源驅動 (開源的 Nouveau 驅動不支援 CUDA),如果是 Ubuntu 那麼在安裝系統的時候打個勾就可以自動安裝,如果沒打可以參考這篇文章,其他 Linux 系統如果源沒有提供可以去 Nvidia 官方下載下傳驅動。

安裝以後可以執行以下代碼看看 pytorch 是否可以檢測出顯示卡:

如果輸出類似以上的結果,那麼就代表沒有問題了。

pytorch 預設會把 tensor 對象的資料儲存在記憶體上,計算會由 CPU 執行,如果我們想使用 GPU,可以調用 tensor 對象的 <code>cuda</code> 方法把對象的資料複制到顯存上,複制以後的 tensor 對象運算會使用 GPU。注意在記憶體上的 tensor 對象和在顯存上的 tensor 對象之間無法進行運算。

如果你想編寫同時相容 GPU 和 CPU 的代碼可以使用以下寫法,如果有支援的 GPU 則會使用 GPU,如果沒有則會使用 CPU:

如果你插了多張顯示卡,以上的寫法隻會使用第一張,你可以通過 "cuda:序号" 來指定不同的顯示卡來實作分布式計算。

這裡我拿前一篇文章的代碼來展示怎樣實際使用 GPU 訓練識别驗證碼的模型,以下是修改後完整的代碼:

如何生成訓練資料和如何使用這份代碼的說明請參考前一篇文章。

使用 diff 生成相差的部分如下:

可以看到隻改動了五個部分,在頭部添加了 device 的定義,然後在加載模型和 tensor 對象的時候使用 <code>.to(device)</code> 即可。

簡單吧☺️。

那麼訓練速度相差如何呢?隻訓練一個 batch 使用 CPU 和 GPU 消耗的時間分别如下 (機關秒):

差了整整 7 倍😱,,如果是高端的顯示卡估計可以看到數十倍的差距。

如果你想檢視訓練過程中的顯存占用情況,可以使用 <code>nvidia-smi</code> 指令,指令會輸出以下的資訊:

如果訓練過程中出現顯存不足,你會看到以下的異常資訊:

如果你遇到顯存不足的問題,那麼可以嘗試以下的辦法解決,按實用程度排序:

出錢買新顯示卡🤒

減少訓練批次大小 (例如每個批次 100 條資料,減為每個批次 50 條資料)

不使用的對象早回收,例如 <code>predicted = None</code>,pytorch 會在對象聲明周期結束後自動釋放顯存

計算單值的時候使用 <code>item()</code>,例如 <code>acc_total += acc.item()</code>,但配合 <code>backward</code> 生成運算路徑的計算不能用

如果你使用桌面 Linux,試試開機的時候添加 <code>rw init=/bin/bash</code> 進入指令行界面再訓練,這樣可以節省個幾百 MB 顯存

你可能會好奇為什了 pytorch 可以及時釋放顯存,這是因為 python 的對象使用了引用計數 (Reference Counted),GC 基本上隻負責回收循環引用的對象,對象的引用計數歸 0 的時候 python 會自動調用析構函數,不需要等待 GC。而 NET 和 Java 等語言則無法做到及時回收,除非你每個 tensor 對象都及時的去調用 Dispose 方法,或者使用 tensorflow 來編譯靜态運算路徑然後把生命周期管理工作全部交給架構。這也是使用 Python 的一大好處🥳。

這篇本來應該放在最開始,可惜等到現在才有條件寫。下一篇文章預計會介紹對象識别模型,包括 RCNN,FasterRCNN 和 YOLO,看看什麼時候能出來吧。