天天看點

解密萬億參數M6模型預訓練背後的分布式架構Whale

解密萬億參數M6模型預訓練背後的分布式架構Whale

作者 | 王林

來源 | 阿裡技術公衆号

最近,阿裡雲PAI團隊和達摩院智能計算實驗室一起釋出“低碳版”巨模型M6,大幅降低萬億參數超大模型訓練能耗。借助我們自研的Whale架構僅使用480卡GPU,即訓練出了規模達人類神經元10倍的萬億參數多模态大模型M6,與傳統海外公司實作萬億參數規模相比,能耗降低超八成、效率提升近11倍。

M6是國内首個實作商業化落地的多模态大模型。M6擁有超越傳統AI的認知和創造能力,擅長繪畫、寫作、問答,在電商、制造業、文學藝術等諸多領域擁有廣泛應用前景。

這裡來為大家介紹支援萬億參數模型訓練的Whale架構設計。

一 模型發展趨勢和挑戰

1 模型發展趨勢

随着深度學習的火爆,模型的參數規模也增長迅速,OpenAI資料顯示:

  • 2012年以前,模型計算耗時每2年增長一倍,和摩爾定律保持一緻;
  • 2012年後,模型計算耗時每3.4個月翻一倍,遠超硬體發展速度;
解密萬億參數M6模型預訓練背後的分布式架構Whale

近一年模型參數規模飛速增長,谷歌、英偉達、阿裡、智源研究院都釋出了萬億參數模型,有大廠也釋出了百億、千億參數模型。同時,随着模型參數規模增大,模型效果也在逐漸提高,Nvidia測試Bert模型不同參數規模,發現模型困惑度随模型參數規模增加而降低。

解密萬億參數M6模型預訓練背後的分布式架構Whale

Google在GShard paper中也發現MoETransformer 模型參數規模越大,翻譯品質越高。

解密萬億參數M6模型預訓練背後的分布式架構Whale

2 大模型訓練的挑戰

大模型帶來模型效果提升的同時,也為訓練架構帶來更大的挑戰,例如當我們要訓練一個萬億規模的模型時會面臨如下挑戰:

  • 訓練難:
    • GPU顯存已經不夠存放模型副本,資料并行已經不能滿足需求;
    • 需要架構提供新的并行政策,協同多GPU能力來存放和訓練模型;
    • 如何給使用者提供簡潔、易用的接口,讓使用者能很容易實作分布式版模型;
    • 超大規模模型對計算效率、通信效率都帶來很大挑戰,如何提高計算和通信效率;
    • 下遊任務如何對接,如何支援批量預測和線上推理需求;
  • 成本高:
    • 以萬億模型為例,模型參數有4TB大小、梯度也有4TB,加上optimizer states和active tensor,顯存需求巨大;
    • 業界訓練同等規模模型需要的資源:英偉達 3072 A100、谷歌 2048 TPU v3,成本太高很難落地;
    • 如何降本增效,使用更少的資源,更快的訓練收斂;

目前已經有一些分布式訓練架構,例如:Horovod、Tensorflow Estimator、PyTorch DDP等支援資料并行,Gpipe、PipeDream、PipeMare等支援流水并行,Mesh Tensorflow、FlexFlow、OneFlow、MindSpore等支援算子拆分,但這些架構還有一些不足:

  • 模式單一:很多架構隻支援部分并行政策,不能完全支援各種混合并行;
  • 接入門檻高:使用者實作模型分布式版本難度大、成本高,需要有領域專家經驗才能實作高效的分布式并行政策;
  • 遷移代價大:不同分布式架構并行化實作割裂,不同架構有各自定義的DSL,當使用者要切換并行政策時,需要學習各種接口,重新改寫模型;
  • 性能不理想:部分架構實作未考慮叢集實體環境;

為了應對目前分布式訓練的挑戰,我們研發了分布式訓練架構Whale,主要目标是:

  • 統一多種并行政策:在一個架構中支援各種并行政策以及這些政策的各種組合;
  • 簡潔易用的接口:使用者隻需添加幾行annotation即可完成并行政策的配置,模型代碼不需要改動;
  • 高效的訓練架構:結合硬體資源、網絡拓撲和模型進行協同優化,打造高效分布式訓練架構;

二 PAI自研Whale架構

1 Whale架構

我們推出統一多種并行政策的高性能分布式訓練架構Whale,從如下角度來應對分布式訓練的挑戰:

  • 将不同并行化政策進行統一抽象、封裝,在一套分布式訓練架構中支援多種并行政策;
  • 基于Tensorflow設計一套分布式并行接口,完全相容Tensorflow,使用者僅僅隻需添加幾行annotation就可以實作豐富的分布式并行政策;
  • 結合模型結構和網絡拓撲進行排程和通信優化,提供高效的分布式訓練能力。

Whale架構如下圖所示,主要分4個子產品:

  • API:提供簡潔易用接口,讓使用者組合使用各種混合并行政策;
  • Whale IR:将并行政策轉成内部表達,通過TaskGraph、Multi-Dimension、VirtualDevices抽象來表達各種并行政策;
  • Whale Engine:基于WhaleIR,通過圖編輯工具來建構分布式執行圖;
  • Runtime:将分布式執行圖轉成TFGraph,再調用TF 的Runtime來執行;
解密萬億參數M6模型預訓練背後的分布式架構Whale

2 Whale簡介易用接口

Whale提供簡潔易用的接口來描述各種并行政策,主要的原語:

  • cluster:配置Virtual Device的劃分方法
  • replica:資料并行
  • stage:劃分TaskGraph
  • pipeline:流水并行
  • split:算子拆分

用這些接口可以組合各種并行政策,例如:

  • 資料并行:
解密萬億參數M6模型預訓練背後的分布式架構Whale
  • 流水并行:
解密萬億參數M6模型預訓練背後的分布式架構Whale
  • 流水并行+資料并行:
解密萬億參數M6模型預訓練背後的分布式架構Whale
  • 更多并行政策示例:
解密萬億參數M6模型預訓練背後的分布式架構Whale

3 Whale訓練流程

使用Whale進行分布式訓練流程:

  • 并行政策配置:
    • 使用Whale API來為模型配置并行政策,隻需添加幾行annotation,無需修改模型代碼,方法如 2.2節 所示;
    • 可以将模型劃分為多個TaskGraph,TaskGraph支援配置多個并行政策,每個TaskGraph可以配置不同的并行政策;
  • 虛拟資源劃分:
    • 按并行政策來劃分Virtual Device,每個TaskGraph對應一個Virtual Device;
    • 按GPU資源和網絡topo來為Virtual Device選擇Physical Device;
  • 分布式執行圖:
    • 基于并行政策和資源配置設定資訊,使用圖編輯工具來編輯執行圖(圖拷貝、拆分、插入通信節點等),生成最終的分布式執行圖;
    • 調用TF的runtime來執行分布式Graph;
解密萬億參數M6模型預訓練背後的分布式架構Whale

三 萬億M6模型預訓練

萬億模型的算力需求非常大,為了降低算力需求,Whale中實作了MoE(Mixture-of-Experts)結構,MoE的主要特點是稀疏激活,使用Gating(Router)來為輸入選擇Top k的expert進行計算(k常用取值1、2),進而大大減少算力需求。

解密萬億參數M6模型預訓練背後的分布式架構Whale

Whale中實作了MoE(Mixture-of-Experts) layer,并支援專家并行,将experts拆分到多個Devices上,降低單個Device的顯存和算力需求。同時資料并行有利于提升訓練的并發度,是以采用資料并行+專家并行組合的混合并行政策來訓練M6模型:MoElayer采用專家并行,其他layer采用資料并行。

解密萬億參數M6模型預訓練背後的分布式架構Whale

Whale中提供簡潔易用的接口來進行模型的混合并行訓練,隻需要增加幾行annotation來配置并行政策,模型本身不需要任何修改。M6模型采用資料并行+專家并行的政策,隻需要增加如下圖的annotation:

解密萬億參數M6模型預訓練背後的分布式架構Whale

同時為了節約訓練資源,提高訓練效率,Whale中提供各種優化技術:

顯存優化:

  • Auto Gradient Checkpoint,自動選擇最優checkpoint節點,節約activation的顯存;
  • Group-wise Apply,優化Optimizer Apply階段的顯存;
  • CPU Offload技術,優化Optimizer status和Weight的顯存;
  • 通信池化,控制通信的資料塊大小和并發,節約通信的顯存;

計算、通信加速:

  • 采用DP+EP混合并行政策,降低算力需求;
  • 采用分組融合通信、半精度通信、拓撲感覺的All2All通信算子等技術來提高通信效率;
  • 結合混合精度、編譯優化等技術提高訓練效率;

借助Whale架構,首次在480 V100 上,3天内完成萬億M6模型的預訓練。相比此前英偉達使用3072 A100 GPU實作萬億參數、谷歌使用2048 TPU實作1.6萬億參數大模型,此次達摩院僅使用480卡V100 32G GPU就實作了萬億模型M6,節省算力資源超80%,且訓練效率提升近11倍。

四 結語

模型參數規模已越來越大,大模型已成為發展趨勢,為解決超大模型訓練的挑戰,我們自研Whale架構,将不同并行化政策進行統一抽象、封裝,在一套分布式訓練架構中支援多種并行政策。Whale提供簡潔易用的接口,使用者隻需添加幾行annotation即可實作各種并行政策,不需要對模型本身進行修改。同時我們結合硬體資源、網絡topo、模型進行軟硬體協同優化,提供高效分布式訓練架構。

通過Whale架構,我們用480 V100 GPU卡訓練萬億規模模型,并在3天内完成模型訓練收斂,為超大規模模型訓練落地提供了可能,後續我們會進一步完善Whale架構,從更大規模、更快速度、更高成本效益3個次元去擴充Whale架構的能力。同時也會推動Whale能力在更多業務場景落地,讓技術能力到産品能力的轉變。

《PostgreSQL實戰教程》

從實戰角度出發,帶你全面掌握PostgreSQL核心技術。

點選這裡

,下載下傳教程~

繼續閱讀