編輯:桃子 喬楊
【新智元導讀】ChatGPT能耗驚人,該怎麼解?谷歌DeepMind新算法JEST問世,讓LLM訓練的疊代次數降低13倍,計算量減少10倍,或将重塑AI未來。
ChatGPT早已成為世界耗能大戶:一天用掉超50萬度電,相當于1.7萬個美國家庭的用電量!
然而,大模型對能源的吞噬,遠不僅如此。
國際能源署(IEA)預測,從2022年到2026年,資料中心的用電量将翻一番。
随着AI計算需求的膨脹,還需要用水來冷卻計算系統。研究稱,微軟用水量從2021年到22年飙升了34%,ChatGPT每處理5-50個提示就會消耗接近半升水。
針對這種現狀,我們有更好的解決政策嗎?
最近,谷歌DeepMind研究團隊提出了一種加快AI訓練的新方法——多模态對比學習與聯合示例選擇(JEST),大大減少了所需的計算資源和時間。
JEST以13倍更少的疊代次數,以及10倍更少的計算量,超越了最先進的模型!
論文位址:https://arxiv.org/pdf/2406.17711
預訓練的參考模型,已經學習了什麼樣的資料是有「優質的」或「有用的」。然後通過模型,來引導資料選擇那些精心篩選過的小型資料集。
這一發現揭示了,資料篩選水準可以作為評判Scaling Law的一個新次元。
網友激動表示,「我沒想到這麼快就會發生。模型能夠自主選擇訓練資料的能力是巨大的,因為它使訓練變得顯著更容易,你不再需要猜測什麼是高品質的訓練資料,你有一個能夠『了解』什麼樣的資料對自身學習最有價值的模型」。
前谷歌、蘋果軟體工程師稱贊道,這項研究非常令人印象深刻。
從「超級batch」中篩選資料
無論是語言、視覺還是多模态模型,資料品質是預訓練性能的重要驅動因素。比如Phi-3、Gemma 2等模型的成功讓我們看到了,更少、更高品質的資料有可能實作更強大的性能。
要篩選出高品質的資料,資料管道的建立就成為重要的工作。現有的方法大體可以分為兩種:1)手動管理 2)基于模型的資料管理,用正在訓練模型的特征選擇高品質資料。
前者成本高昂且難以擴充,後者則有望為多模态LLM實作Scaling Law。
然而,現有方法忽略了一個事實。
如果僅在單個資料點的層面進行篩選,就沒有考慮到資料集以及batch的總體組成。畢竟,訓練資料是以batch為機關,資料點之間的依賴性不可忽視。
許多計算機視覺的研究都曾表明,hard negatives(表達空間中相近但标簽不同的樣本)相比可被平凡解的資料簇,能提供更有效的學習信号。
那麼如何讓模型以batch為機關篩選資料呢?
論文提出的JEST算法正是要解決這個問題,原理很好了解:就是直接從「超級batch」中篩選出「子batch」。
技術介紹
用數學語言來描述這個問題,就是從大小為B的「超級batch」中提取出與學習最相關的子batch ℬ={,∈[1,…,]}⊂,過濾比率可以寫作=1−/。
之前的優先采樣(prioritized sampling)會使用基于模型的評分函數對每個資料點打分,再按比例采樣。JEST則直接對整個子batch評分,再按照batch級别的分數采樣。
一種最直覺的啟發式方法就是在現有模型參數 : hard(ℬ|)=ℓ(ℬ|) 中,直接選擇損失值最高的batch,這種方法可被稱之為「硬學習」(hard learner)。
這種方法具有丢棄瑣碎資料的理想屬性,已被證明适用于小型、幹淨的資料集;然而對于較大、較少管理的資料集往往弊大于利,因為它依舊會采樣到噪聲資料。
另一種方法常用于多模态,使用具有參數 ∗:^easy(ℬ|∗)=−ℓ(ℬ|∗) 的參考模型為預訓練模型采樣資料。但作者依舊否定了這個方案,因為它無法直接反映模型目前的狀态,可能過度依賴參考模型的選擇,而且不易于擴充。
最後,論文選擇借鑒ICML 2022年的一篇論文中提到的方法,将上述兩方面的評分結合起來:^learn(ℬ|,∗)=hard(ℬ|)+^easy(ℬ|∗)=ℓ(ℬ|)−ℓ(ℬ|∗),并将這種啟發式方法稱為「可學習性評分」(learnability score)。
其中,batch上的損失值ℓ(ℬ|)是各資料點之和,使用sigmoid對比損失函數計算(sigmoid-contrastive loss),因為相比softmax對比損失而言,它的擴充性更強。
由于batch上的對比損失可以分解為每個樣本的條件損失之和,是以可學習性評分可被分解為單個樣本可學習性評分(|,∗,ℬ)之和,寫作:
使用的順序采樣方法則受到了block Gibbs采樣的啟發。在第n次疊代、對第B_n個batch進行采樣時,依據如下機率公式對塊{X_k}進行無替換采樣:
将X_k塊添加到B_n中來更新目前采樣的batch,直至疊代數n=N時終止。算法的總體流程如下圖所示:
實驗中發現,使用疊代數N=16且每次疊代時獨立采樣b/N=2048個樣本時,就足以恢複出學習性非常高的batch。
可學習性評分中涉及到使用參考模型為資料點打分,之前的方法慣常使用額外的小型模型,但這會增加每次疊代的計算成本,降低總體FLOP效率增益。
是以論文使用了線上模型近似的方法以及效率較高的FlexiViT架構,隻使用降低分辨率的32×32的patch來評估「超級batch」,與全分辨率、patch大小為16×16的方法相比減少了72%的FLOP,以及67%的挂鐘時間(wall-clock time)。
此外,論文還提出了進行多分辨率訓練的技巧。将每個batch随機分成兩半,使用不同分辨率編碼後再拼接起來,提升了評分過程和訓練的效率。
下圖較長的描述了全分辨率JEST和多分辨率Flexi-JEST方法的僞代碼實作。
所有JEST實驗都在WebLI資料集上運作,包含經過寬松過濾的十億規模的英語圖像-文本對,參考模型的訓練則使用其中經過高品質過濾100M大小的子集(被稱為WebLI-curated)。
在WebLI的基礎上,作者還額外從網絡上抓取了6億個文本-圖像對并經過同樣強度的過濾,組成WebLI-curated++資料集訓練參考模型,拓展出JEST++/FlexiJEST++方法,來探索對資料管理的擴充。
論文所報告的平均性能包括4個多模态規範基準:ImageNet 0-Shot和10-Shot 分類以及COCO圖像到文本和文本到圖像的top-1檢索。
實驗結果
圖1中可以看到,使用JEST或FlexiJEST方法的最明顯優勢就是效率提升。
左圖中,相比原有的SigLIP基線模型,JEST++可以在訓練資料量減少13.1×的情況下達到相同準确率。即使考慮到額外引入的打分成本,也有近10×的FLOP效率提升(中圖)。
右圖展現了JEST++/FlexiJEST++(綠色)與先前方法(灰色)的比較,相比CLIP、EVA-CLIP經典模型實作了計算成本和性能的雙重提升。
左圖和中圖的平均準确率由8個下遊任務得出,右圖性能由ImageNet和COCO基準測試得出
産生可學習batch
研究人員首先評估了JEST在選擇可學習batch方面的效果。
為了直覺地了解這一方法,作者們先将可學習性矩陣進行可視化,即學習模型和參考模型之間,對batch中所有示例對的損失差異。
JEST就是按照示例子矩陣的可學習性總和比例進行采樣。
由于矩陣明顯非對角關系(圖2,左),獨立選擇顯然是次優的。
經過少量疊代(對應于用N=16個塊填充batch),作者發現子batch的可學習性快速增加,達到了需要數千次疊代的暴力吉布斯采樣(Gibbs sampling )所提取batch的可學習性(圖2,中)。
對于0.5、0.8和0.9的過濾比例,他們從大小分别為65,536、163,840和327,680的超級batch中選擇32,768個示例的子batch。
在圖2右側,研究者還發現子batch的可學習性随着更大的過濾比例而增加。
總之,JEST算法是在訓練過程中選擇高度可學習batch的有效,且高效的方法。
加速多模态學習
接下來,研究人員使用JEST算法選擇的可學習batch,檢驗訓練模型的效果。
所有實驗都使用在WebLI-curated上訓練的參考模型,這是一個ViT-B/16和Bert-B圖像-文本雙編碼器,30億訓練樣本,采用sigmoid對比損失函數。
圖3(左)顯示了在訓練過程中多個下遊任務(ImageNet 0-Shot/10-Shot準确率和COCO圖像到文本/文本到圖像檢索)的平均性能。
結果還發現,JEST顯著加速了學習過程。
在使用50%、80%和90%的過濾比例時,分别隻需20億、10億和6.7億訓練樣本就達到了30億均勻基準的最終性能。
在更大的過濾比例下,坐着觀察到類似于更大batch size時的訓練不穩定性,需要修改Adam優化器(β2 = 0.95)以穩定訓練,這表明JEST的資料篩選可以被視為增加了有效batch size。
在最終性能方面,當過濾90%的資料時,JEST也帶來了高達6%的顯著提升(圖3,中間,藍色曲線)。
值得注意的是,這種scaling行為這種性能提升在獨立樣本選擇方法中,并沒有觀察到。(圖3,中間,橙色曲線)。
最後,研究者還評估JEST是否也改善了,除可學習性之外的其他優先标準。
圖3右側顯示了使用easy-reference優先選擇的模型在不同過濾比例下的性能。
與基于可學習性的優先選擇一緻,JEST仍優于獨立樣本選擇,特别是在高過濾比例下(在這種情況下,獨立樣本選擇導緻性能下降)。
優先選擇具有最高損失的資料産生了較小的收益,并且随着過濾更多資料而更快地退化(圖10)。
由于基于可學習性的JEST産生了最佳的scaling行為,研究人員在後續實驗中保留了這一标準。
多分辨率訓練和線上batch選擇之間的協同效應
随着資料batch中被過濾的比例增加,基于可學習性評分的JEST變得更加高效。
然而,評分的成本會帶來顯著的提升:過濾超級batch 80%的資料會導緻每次疊代的浮點運算量是IID訓練的4倍,或者在緩存參考模型得分時是2.3倍。
盡管JEST在訓練疊代次數方面(以下簡稱「訓練效率」)顯著提高了效率,但額外的評分浮點運算降低了其相對于IID基準的計算效率(圖1,左vs右)。
是以,作者還研究了一種計算效率更高的變體,稱為Flexi-JEST,它使用多分辨率訓練和低分辨率評分,将總開銷降低到僅比基準高10%(圖4,左)。
這些近似方法對性能有什麼影響?
正如預期的那樣,Flexi-JEST的每次疊代性能相對于JEST有所下降,但仍然比IID有顯著的加速(圖1,左;圖4,中)。
然而,考慮到總浮點運算量的減少,每次疊代性能的下降是非常有利的:最好的Flexi-JEST模型與40B Siglip運作産生相同的平均性能,但浮點運算量減少了9.9倍,比全分辨率JEST少2倍(圖1,右;圖4,中)。
這些實驗表明了多分辨率訓練和聯合示例選擇之間的協同效應,前者為加速後者提供了高效和準确的評分能力。
實驗結果,還指出了資料策劃政策的帕累托前沿(pareto front)。
如果以計算為代價來最大化訓練速度或訓練效率,全分辨率JEST方法相對于可比的IID訓練運作,可以産生高達13倍的加速。
實作強大資料品質引導
可學習性評分的核心是,一個在人類選擇的小型、精心篩選的資料集上,訓練的參考模型。
JEST的性能如何随不同的篩選政策(在品質和數量之間權衡)而變化?
此外,JEST訓練的改進是否與參考模型的性能相關,還是這些名額是分離的?
了解品質與數量的權衡
研究人員探索了三種規模的資料篩選,每種都是原始WebLI資料集的一個子集:
- 弱篩選(十億級規模):使用圖像-文本對齊(ITA)過濾器。
- 中度篩選(3億級規模):使用ITA過濾器或文本品質(TQ)過濾器。
- 強篩選(1億級規模):結合使用TQ、ITA和額外的圖像品質(aesthetic)過濾器。
在整個過程中,作者将這個強篩選子集稱為「WebLI-curated」。
然後,他們在這四個WebLI子集上,各訓練10個epoch的标準SigLIP編碼器,并将它們用作在全WebLI資料集上進行JEST訓練的參考模型。
在不同的資料篩選方法中,參考模型的性能和JEST的性能似乎是解耦的(甚至可能是反相關的;圖5,左)。
雖然增加篩選(和減少資料集大小)會産生較弱的模型,但當它們被用作JEST預訓練的參考模型時,卻産生了相反的效果:
使用強篩選參考模型的JEST獲得了2.7%的改進,中度篩選獲得了1.5%的改進,弱篩選獲得了0.3%的改進。
擴充資料篩選
假設參考模型性能與JEST性能之間的普遍解耦,可能僅僅是由資料篩選所施加的資料集大小限制造成的。
為了了解這種效果,研究人員在WebLI-curated上訓練了5個參考模型,同時改變所見的總樣本數(從2.5億到30億)。
在這種情況下,圖5(右)顯示了改進的參考模型與更好的JEST預訓練之間存在着顯著的相關性。
這表明「解耦」現象主要可以歸因于參考模型因篩選後資料集大小減少而導緻的飽和。
此外,研究人員還注意到,當資料集達到飽和時,圖5(右)中的相關性開始崩解,即在10個epoch或者看到10億個樣本之後。
這些結果表明,JEST可能會從進一步擴大參考資料集的資料篩選中獲益。
鑒于使用WebLI-curated++對資料進行擴充整理能顯著提高參考模型的性能,作者提出了是否有必要在原始WebLI資料集上進行預訓練的問題。
然而,在評估參考模型在不同資料集上的性能時,卻發現:雖然它在2個下遊任務上的性能優于WebLI預訓練,但在其他6個任務上的性能,以及平均性能都明顯低于WebLI預訓練(表 5)。
與現有資料比較
最後,論文應用JEST++在公開的LAION-2B資料集上進行預訓練,删除了其中不安全的圖像-文本對,但沒有進行其他的預先過濾。
這個資料規模相比的SOTA方法DBP減少了4×,但JEST++依舊遠遠超過了所有之前的離線資料管理方法。
簡化資料管理
之前提到過,用于預訓練的WebLI-curated是原始資料集WebLI過濾後得到的,以求篩選出高品質的圖像-文本對齊的資料。
如表3所示,這種離線資料管理流程對IID(獨立同分布)訓練方法的性能至關重要,但JEST++則表現出了對預過濾流程的魯棒性。即使沒有過濾,JEST++的性能也沒有出現明顯下滑,降低了模型對基礎資料集的要求。
結論和局限性
總體來說,JEST方法展現出了「資料品質引導」(data quality bootstrapping)方法的巨大潛力,即使用小規模精選資料集來指導對更大的、未經管理的資料集的學習。
最近的研究表明,在下遊任務未知時,靜态資料集的過濾會限制模型性能。這篇論文的結果則表明,相比單獨選擇樣本的方法,線上建構batch能提高預訓練的效率。
無論是使用JEST參考模型對資料集進行預評分,還是通過可學習性評分來根據模型需求進行動态調整,都可以成為通用基礎資料集的更有效率的替代方案。
論文的最後,作者也提出了該方法的局限性。雖然JEST同時實作了性能增益和訓練成本降低,但依舊依賴于小型、精心管理的參考資料集,它指定了未經管理的更大資料集中優先考慮的分布。
是以,未來的工作可以探索一種方法,從指定的下遊任務中如何推斷出參考資料集的組成和分布。
參考資料:
https://www.reddit.com/r/singularity/comments/1dw7xnf/google_deepminds_jest_method_can_reduce_ai/
https://decrypt.co/238730/new-ai-training-technique-is-drastically-faster-says-google