作者丨Fatescript
序
本文算是我工作一年多以來的一些想法和經驗,最早釋出在曠視研究院内部的論壇中,本着開放和分享的精神釋出在我的知乎專欄中,如果想看幹貨的話可以直接跳過動機部分。另外,後續在這個專欄中,我會做一些關于原理和設計方面的一些分享,希望能給領域從業人員提供一些看待問題的不一樣的視角。
動機
前段時間走在路上,一直在思考一個問題:我的時間開銷很多都被拿去給别人解釋一些在我看起來顯而易見的問題了,比如( https://link.zhihu.com/?target=https%3A//github.com/Megvii- BaseDetection/cvpods )裡面的一些code寫法問題(雖然這在某些方面說明了文檔建設的不完善),而這變相導緻了我實際工作時間的減少,如何讓别人少問一些我覺得答案顯而易見的問題?如何讓别人提前規避一些不必要的坑?隻有解決掉這樣的一些問題,我才能從一件件繁瑣的小事中解放出來,把精力放在我真正關心的事情上去。
其實之前同僚有跟我說過類似的話,每次帶一個新人,都要告訴他:你的實作需要注意這裡blabla,還要注意那裡blabla。說實話,我很佩服那些帶intern時候非常細緻和知無不言的人,但我本性上并不喜歡每次花費時間去解釋一些我覺得顯而易見的問題,是以我寫下了這個文章,把我踩過的坑和留下來的經驗分享出去。希望能夠友善别人,同時也節約我的時間。
加入曠視以來,個人一直在做一些關于架構相關的内容,是以内容主要偏向于模型訓練之類的工作。因為 一個擁有知識的人是無法想象知識在别人腦海中的樣子的(the curse of knowledge),是以我隻能選取被問的最多的,和我認為最應該知道的 。
準備好了的話,我們就啟航出發(另,這篇專欄文章會長期進行更新)。
坑/經驗
Data子產品
- python圖像處理用的最多的兩個庫是opencv和Pillow(PIL),但是兩者讀取出來的圖像并不一樣, opencv讀取的圖像格式的三個通道是BGR形式的,但是PIL是RGB格式的 。這個問題看起來很小,但是衍生出來的坑可以有很多,最常見的場景就是資料增強和預訓練模型中。比如有些資料增強的方法是基于channel次元的,比如megengine裡面的HueTransform,這一行代碼 (https://github.com/MegEngine/MegEngine/blob/4d72e7071d6b8f8240edc56c6853384850b7407f/imperative/python/megengine/data/transform/vision/transform.py#L958 ) 顯然是需要確定圖像是BGR的,但是經常會有人隻看有Transform就無腦用了,從來沒有考慮過這些問題。
- 接上條,RGB和BGR的另一個問題就是導緻預訓練模型載入後訓練的方式不對,最常見的場景就是預訓練模型的input channel是RGB的(例如torch官方來的預訓練模型),然後你用cv2做資料處理,最後還忘了convert成RGB的格式,那麼就是會有問題。這個問題應該很多煉丹的同學沒有注意過,我之前寫CenterNet-better(https://github.com/FateScript/CenterNet-better)就發現CenterNet(https://github.com/xingyizhou/CenterNet)存在這麼一個問題,要知道當時這可是一個有着3k多star的倉庫,但是從來沒有人意識到有這個問題。當然,依照我的經驗,如果你訓練的iter足夠多,即使你的channel有問題,對于結果的影響也會非常小。不過,既然能做對,為啥不注意這些問題一次性做對呢?
- torchvision中提供的模型,都是輸入圖像經過了ToTensor操作train出來的。也就是說最後在進入網絡之前會統一除以255進而将網絡的輸入變到0到1之間。torchvision的文檔(https://pytorch.org/vision/stable/models.html)給出了他們使用的mean和std,也是0-1的mean和std。如果你使用torch預訓練的模型,但是輸入還是0-255的,那麼恭喜你,在載入模型上你又會踩一個大坑(要麼你的圖像先除以255,要麼你的code中mean和std的數值都要乘以255)。
- ToTensor之後接資料處理的坑。上一條說了ToTensor之後圖像變成了0到1的,但是一些資料增強對數值做處理的時候,是針對标準圖像,很多人ToTensor之後接了這樣一個資料增強,最後就是練出來的丹是廢的(心疼電費QaQ)。
- 資料集裡面有一個圖特别詭異,隻要train到那一張圖就會炸顯存(CUDA OOM),别的圖訓練起來都沒有問題,應該怎麼處理?通常出現這個問題,首先判斷資料本身是不是有問題。如果資料本身有問題,在一開始生成Dataset對象的時候去掉就行了。如果資料本身沒有問題,隻不過因為一些特殊原因導緻顯存炸了(比如檢測中圖像的GT boxes過多的問題),可以catch一個CUDA OOM的error之後将一些邏輯放在CPU上,最後retry一下,這樣隻是會慢一個iter,但是訓練過程還是可以完整走完的,在我們開源的YOLOX裡有類似的參考code(https://github.com/Megvii-BaseDetection/YOLOX/blob/0.1.0/yolox/models/yolo_head.py#L330-L334)。
- pytorch中dataloader的坑。有時候會遇到pytorch num_workers=0(也就是單程序)沒有問題,但是多程序就會報一些看不懂的錯的現象,這種情況通常是因為torch到了ulimit的上限,更核心的原因是 torch的dataloader不會釋放檔案描述符 (參考issue: https://github.com/pytorch/pytorch/issues/973)。可以ulimit -n 看一下機器的設定。跑程式之前修改一下對應的數值。
- opencv和dataloader的神奇關聯。很多人經常來問為啥要寫cv2.setNumThreads(0),其實是因為cv2在做resize等op的時候會用多線程,當torch的dataloader是多程序的時候,多程序套多線程,很容易就卡死了(具體哪裡死鎖了我沒探究很深)。除了setNumThreads之外,通常還要加一句cv2.ocl.setUseOpenCL(False),原因是cv2使用opencl和cuda一起用的時候通常會拖慢速度,加了萬事大吉,說不定還能加速。感謝評論區 @Yuxin Wu(https://www.zhihu.com/people/ppwwyyxx) 大大的指正
- dataloader會在epoch結束之後進行類似重新加載的操作,複現這個問題的code稍微有些長,放在後面了。這個問題算是可以說是一個進階bug/feature了,可能導緻的問題之一就是煉丹師在本地的code上進行了一些修改,然後訓練過程直接加載進去了。解決方法也很簡單,讓你的sampler源源不斷地産生資料就好,這樣即使本地code有修改也不會加載進去。
Module子產品
- BatchNorm在訓練和推斷的時候的行為是不一緻的。這也是新人最常見的錯誤(類似的算子還有dropout,這裡提一嘴, pytorch的dropout在eval的時候行為是Identity ,之前有遇到過實習生說dropout加了沒效果,直到我看了他的code:x = F.dropout(x, p=0.5)
- BatchNorm疊加分布式訓練的坑。在使用DDP(DistributedDataParallel)進行訓練的時候,每張卡上的BN統計量是可能不一樣的,仔細檢查broadcast_buffer這個參數 。DDP的預設行為是在forward之前将rank0 的 buffer做一次broadcast(broadcast_buffer=True),但是一些常用的開源檢測倉庫是将broadcast_buffer設定成False的(參考:mmdet(https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206) 和 detectron2(https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206),我猜是在檢測任務中因為batchsize過小,統一用卡0的統計量會掉點) 這個問題在一邊訓練一邊測試的code中更常見 ,比如說你train了5個epoch,然後要分布式測試一下。一般的邏輯是将資料集分到每塊卡上,每塊卡進行inference,最後gather到卡0上進行測點。但是 因為每張卡統計量是不一樣的,是以和那種把卡0的模型broadcast到不同卡上測試出來的結果是不一樣的。這也是為啥通常訓練完測的點和單獨起了一個測試腳本跑出來的點不一樣的原因 (當然你用SyncBN就不會有這個問題)。
- Pytorch的SyncBN在1.5之前一直實作的有bug,是以有一些老倉庫是存在使用SyncBN之後掉點的問題的。
- 用了多卡開多尺度訓練,明明尺度更小了,但是速度好像不是很理想?這個問題涉及到多卡的原理,因為分布式訓練的時候,在得到新的參數之後往往需要進行一次同步。假設有兩張卡,卡0的尺度非常小,卡1的尺度非常大,那麼就會出現卡0始終在等卡1,于是就出現了雖然有的尺度變小了,但是整體的訓練速度并沒有變快的現象(木桶效應)。解決這個問題的思路就是 盡量把負載拉均衡一些 。
- 多卡的小batch模拟大batch(梯度累積)的坑。假設我們在單卡下隻能塞下batchsize = 2,那麼為了模拟一個batchsize = 8的效果,通常的做法是forward / backward 4次,不清理梯度,step一次(當然考慮BN的統計量問題這種做法和單純的batchsize=8肯定還是有一些差别的)。在多卡下,因為調用loss.backward的時候會做grad的同步,是以說前三次調用backward的時候需要加ddp.no_sync(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=no_sync#torch.nn.parallel.DistributedDataParallel.no_sync)的context manager(不加的話,第一次bp之後,各個卡上的grad此時會進行同步),最後一次則不需要加。當然,我看很多倉庫并沒有這麼做,我隻能了解他們就是單純想做梯度累積(BTW,加了ddp.no_sync會使得程式快一些,畢竟加了之後bp過程是無通訊的)。
- 浮點數的加法其實不遵守交換律的 ,這個通常能衍生出來GPU上的運算結果不能嚴格複現的現象。可能一些非計算機軟體專業的同學并不了解這一件事情,直接自己開一個python終端體驗可能會更好:
print(1e100 + 1e-4 + -1e100) # ouptut: 0
print(1e100 + -1e100 + 1e-4) # output: 0.0001
訓練子產品
- FP16訓練/混合精度訓練。使用Apex訓練混合精度模型,在儲存checkpoint用于繼續訓練的時候,除了model和optimizer本身的state_dict之外,還需要儲存一下amp的state_dict,這個在amp的文檔(https://link.zhihu.com/?target=https%3A//nvidia.github.io/apex/amp.html%23checkpointing)中也有提過。(當然,經驗上來說忘了儲存影響不大,會多花幾個iter search一個loss scalar出來)
- 多機分布式訓練卡死的問題。好友 @NoahSYZhang(https://www.zhihu.com/people/syzhangbuaa) 遇到的一個坑。場景是先申請了兩個8卡機,然後機器1和機器2用前4塊卡做通訊(local rank最大都是4,總共是兩機8卡)。可以初始化process group,但是在使用DDP的時候會卡死。原因在于pytorch在做DDP的時候會猜測一個rank,參考code(https://github.com/pytorch/pytorch/blob/0d437fe6d0ef17648072eb586484a4a5a080b094/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1622-L1630)。對于上面的場景,第二個機器上因為存在卡5到卡8,而對應的rank也是5到8,是以DDP就會認為自己需要同步的是卡5到卡8,于是就卡死了。
複現Code
Data部分
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import tqdm
import time
class SimpleDataset(Dataset):
def __init__(self, length=400):
self.length = length
self.data_list = list(range(length))
def __getitem__(self, index):
data = self.data_list[index]
time.sleep(0.1)
return data
def __len__(self):
return self.length
def train(local_rank):
dataset = SimpleDataset()
dataloader = DataLoader(dataset, batch_size=1, num_workers=2)
iter_loader = iter(dataloader)
max_iter = 100000
for _ in tqdm.tqdm(range(max_iter)):
try:
_ = next(iter_loader)
except StopIteration:
print("Refresh here !!!!!!!!")
iter_loader = iter(dataloader)
_ = next(iter_loader)
if __name__ == "__main__":
import torch.multiprocessing as mp
mp.spawn(train, args=(), nprocs=2, daemon=False)
當程式運作起來的時候,可以在Dataset裡面的__getitem__方法裡面加一個print輸出一些内容,在refresh之後,就會print對應的内容哦(看到現象是不是覺得自己以前煉的丹可能有問題了呢hhh)
一些碎碎念
一口氣寫了這麼多條也有點累了,後續有踩到新坑的話我也會繼續更新這篇文章的。畢竟寫這篇文章是希望工作中不再會有人踩類似的坑 & 煉丹的人能夠對深度學習架構有意識(雖然某種程度上來講這算是個心智負擔)。
如果說今年來什麼事情是最大的收獲的話,那就是了解了一個開放的生态是可以迸發出極強的活力的,也希望能看到更多的人來分享自己遇到的問題和解決的思路。畢竟探索的答案隻是一個副産品,過程本身才是最大的财寶。
如果覺得有用,就請分享到朋友圈吧!
關于我
你好,我是對白,清華計算機碩士畢業,現大廠算法工程師,拿過8家大廠算法崗SSP offer(含特殊計劃),薪資40+W-80+W不等。
高中榮獲全國數學和化學競賽二等獎。
大學獨立創業五年,兩家公司創始人,拿過三百多萬元融資(已到賬),項目入選南京321高層次創業人才引進計劃。創業做過無人機、機器人和網際網路教育,保研清華後退居股東。
我每周至少更新三篇原創,分享人工智能前沿算法、創業心得和人生感悟。我正在努力實作人生中的第二個小目标,上方關注後可以加我微信交流。