點選藍字
關注我們
關注并星标
從此不迷路
計算機視覺研究院
公衆号ID|計算機視覺研究院
學習群|掃碼在首頁擷取加入方式
- 論文位址:https://arxiv.org/pdf/2407.05483
- 項目首頁:https://github.com/HazyResearch/prefix-linear-attention
計算機視覺研究院專欄
Column of Computer Vision Institute
在目前 AI 領域,大語言模型采用的主流架構是 Transformer。不過,随着 RWKV、Mamba 等架構的陸續問世,出現了一個很明顯的趨勢:在語言模組化困惑度方面與 Transformer 較量的循環大語言模型正在快速進入人們的視線。
令人興奮的是,這些架構在推理期間使用了恒定量的記憶體。不過,受制于有限的記憶體,循環語言模型(LM)無法記憶并使用長上下文中的所有資訊,這導緻了上下文學習(in-context learning,ICL)品質的不佳。是以,獲得高效大語言模型的關鍵挑戰在于選擇存儲或者丢棄哪些資訊。
在最近的論文《Just read twice: closing the recall gap for recurrent language models》中,來自斯坦福大學、布法羅大學的研究者通過簡單觀察發現,資料在推理期間湧入循環語言模型的排序極大地影響了在有限記憶體中預測存儲哪些資訊的難度。
我們假設根據文檔 D(比如伽利略・伽利萊的詳細維基百科)來提問:伽利略是什麼時候搬到的佛羅倫薩?這時,如果提示遵循了 [Q, D] 的排序,則模型隻需要記住文檔 D 中的一個事實即可。相反,如果提示遵循了 [D, Q] 的排序,則模型需要記住所有事實。如下圖 1(左)所示。
是以,本文首先從理論上形式化了資料排序如何影響記憶體需求,然後提出兩種方法來減輕對資料排序的依賴,分别是 Just-read-twice(JRT)提示政策和 JRT 循環架構。本文主要分為以下幾個部分展開:
了解資料排序的作用。研究者得出的第一個洞見是:記憶問題的 hardness 要降低到與設定剝離(set disjointness,SD)相同,這是通信複雜度理論中持續數十年的最典型問題。SD 要求一種流算法(比如循環模型)來決定上下文中提供的輸入集是否剝離:
理論分析和實驗結果表明,第一個集 | A | 掌控了求解 SD 所需的記憶體。因果模型需要存儲 A 中的所有元素以與 B 中的元素進行比較。這表明了,使用上下文中的「正确資料排序」(如将最小 min (|A|, |B|) 的集放在首位)将有助于記憶體受限的模型。更進一步,觀察到上下文非因果邏輯的模型可在空間最小的 (|A|, |B|) 中求解 SD,而無需考慮資料排序。
其次是利用「正确的」排序。本文提出了一種非常簡單的 JRT-Prompt 政策,在模型生成答案之前在上下文中将資訊重複多次(如上圖 1 右所示)。在第二以及更多輪次中,語言模型在決定存儲哪些資訊時要以完整的上下文為條件,進而有效避免了将資料排序「歸正」的問題。
結果表明,JRT-Prompt 在 16 個已有循環語言模型和 6 項 ICL 任務上,實作了平均 11.0 ± 1.3 百分點的提升,而吞吐量是 FlashAttention-2(長度 32k、批大小 16)的 11.9 倍。JRT-Prompt 雖然增加了上下文長度,但漸進來看仍然比注意力更加地計算和記憶體高效。
超越因果模型。本文提出了 JRT-RNN,它的靈感來源于簡單的 Prefix-LM 編碼器解碼器架構設計。大多數的上下文學習輸入包含兩部分内容,分别是輸入的提示(上下文、指令)和作為輸出的模型生成文本。在 Prefix-LM 架構中,LM 并沒有遵循因果邏輯地處理提示區域,而對輸出進行了因果解碼,其中在因果區域僅使用了标準的下一個 token 預測損失,以及非因果區域上的損失。
不過遺憾的是,此前 Prefix-LM 模型的訓練方法取得的成功有限,并使用了低效的 Transformer 主幹。是以本文通過一些簡單的改變來提高品質和效率,包括改進訓練損失并使用稱之為「Prefix Linear Attention,PLA」 的線性注意力公式。研究者發現,使用他們的 IO 感覺實作,JRT-RNN 在 360m 和 1.3b 參數設定下,分别可以提供 13.7 和 6.9 百分點的平均品質改進,吞吐量是 FA2 的 19.2 倍。
JRT-Prompt 方法概覽
上下文學習任務以 (C, Q, Y) 作為輸入,其中 C 為一些上下文來源(如文檔或代碼存儲庫),Q 為給定上下文時對模型的一些問題或請求,Y 為答案。對于使用自回歸 LM A 的标準上下文學習,研究者輸入 C 和 Q,并根據正确的完成情況 Y 來評估生成的輸出 Yˆ = A (C, Q)。
JRT-Prompt 是一種極其簡單的方法,在提示模型輸出答案之前會在上下文中重複提示中的資訊(如問題和文檔),例如下圖 1 右的 Yˆ = A (C, Q, C, Q)。是以,在上下文第二次出現時,模型根據完整的上下文來決定存儲哪些資訊。
此外,JRT-Prompt 可以與現成的 LLM 一起使用。研究者在零樣本提示下,在一系列記憶密集型上下文任務上評估了以下 LM:
- Based 預訓練 LM,參數規模為 1.3B,在 Pile 的 10 − 50B 個 token 上進行訓練;
- Mamba 預訓練的 LM,參數規模為 130M、370M、1.4B 和 2.8B,在 Pile 的 300B 個 token 上進行訓練;
- Gated Linear Attention 預訓練的 LM,參數規模為 1.3B 和 2.7B,在 SlimPajama 資料集的 100B 個 token 上進行訓練;
- Mamba-2 預訓練的 LM,參數規模為 130M、370M、1.3B 和 2.7B,在 Pile 的 300B 個 token 上進行訓練。
結果如下表 1 所示,通過增加狀态(state)大小,研究者發現 JRT-Prompt 方法在各個模型和任務上平均帶來了 11.0 ± 1.3 百分點的性能提升,利用該方法的 Based 模型平均優于利用标準提示的 Transformer 模型。
他們還發現,JRT-Prompt 可以使 Transformer 模型受益,并且該方法在一些任務上(附錄 2)比少樣本學習更加有效。值得注意的是,Springer 等人在論文《Repetition improves language model embeddings》中提出使用自回歸 Transformer 模型來重複上下文以實作生成嵌入的目的,本文的研究結果也類似。研究者專注于亞二次架構和上下文學習任務。
JRT-Prompt 雖然由于重複而增加了上下文長度,但是其使用的亞二次循環架構仍比使用二次 Transformer 模型更高效。研究者發現,在序列長度 N = 32768、批大小為 16 時,使用 JRT-Prompt(序列長度 2N)在英偉達 H100 上提供的吞吐量是 FlashAttention-2(序列長度 N)的 11.9 倍。
JRT-RNN:編碼器 - 解碼器循環架構
JRT-RNN 的靈感來自于 Prefix-LMs,但側重于擴充品質 - 效率權衡空間的帕累托邊界(Pareto frontier)。為了提高品質,JRT-RNN 在編碼器端使用了單獨的 k_e 和 v_e 映射,在解碼器端使用了 k_d 和 v_d 映射。雖然 Prefix LM 模型對編碼器和解碼器區域使用了共享映射權重,但研究者發現使用兩組映射可以提高品質。
為了提高效率,JRT-RNN 為編碼器使用了非因果線性注意力,而為解碼器使用标準因果線性注意力。研究者稱為 Prefix Linear Attention(PLA)(圖 1 右),公式如下:
JRT-RNN 訓練目标。Prefix LMs 通常不計算非因果區域的損失,而 JRT-RNN 将下一個 token 預測與掩碼語言模組化(MLM)目标進行了結合。并且對于添加的 MLM 目标,研究者用一個 [MASK] token 替換了來自編碼器區域 {u_1, ..., u_M} 的比例為 P 的 tokens,并在預測原始 token 時測量了交叉熵損失
。
損失如下:
實驗結果
在實驗中,研究者評估了 JRT-RNN 在以下三個名額上的品質和效率:
- 上下文學習品質
- 整體語言模組化
- 生成
上下文學習品質
如下表 2 所示,研究者發現,JRT-RNN 在參數為 360M(30B tokens)時比僅解碼器的基線(Based)平均高出 13.7 個百分點,在參數為 1.3B(50B tokens)時平均高出 6.9 個百分點。
同時,JRT-RNN 在參數為 360M 和 1.3B 時與 Transformer++ 的差距分别縮小到了 0.5 個百分點和 1.9 個百分點之内。
在下表 3 中,研究者比較了當 prefill 長度 l 小于編碼器長度 M 時,JRT-RNN 與同類推理政策的表現。
整體自然語言了解
根據以往研究,研究者進一步将困惑度分為了兩組:聯想記憶「AR slice」包括了被稱為「AR hits」的 tokens,它們需要模型按照順序執行記憶以正确地預測下一個 token;而「Other slice」包含剩餘的 tokens(如記憶的知識)。
對于記憶頻率,JRT-RNN 在「AR slice」表現出色。對于訓練期間不常見的二進制組(即不太可能在模型參數中被記住的),JRT-RNN 的困惑度相對于 Based 和 Mamba 這兩個強大的因果循環基線有所改善。
對于記憶距離,在「AR slice」中,JRT-RNN 與僅解碼器基線之間的差距随着上下文中重複二進制組的增加而擴大。這也進一步證明了 JRT-RNN 可以幫助完成更長的上下文記憶任務。
非記憶頻率。對于訓練期間很少見到的二進制組的非記憶「Other slice」,JRT-RNN 的困惑度比僅解碼器的 LM 更差。這是意料之中的結果,因為 JRT-RNN 計算了僅解碼器 LM 的 65% tokens 的損失。
我們預計這一差距會随着規模和訓練時間的延長而縮小(随着二進制文法頻率的增加而增加)(圖 3,左上角)。
生成吞吐量
生成可以分解為提示「prefill 處理」和解碼「下一個 token 預測」兩步。相較于标準的僅解碼器循環模型,JRT-RNN 不會修改解碼步驟,是以讨論重點在 prefill 階段。
使用 Simran Arora 等人論文《Simple linear attention language models balance the recall-throughput tradeof》中提出的 Based CUDAn 核心,JRT-Prompt 在處理 prefill 時吞吐量分别是 FlashAttention-2 和 FLA Triton 核心的 11.9 和 13.7 倍,如下表 5 所示。
當研究者将批大小增加到 64 時,JRT-Prompt 吞吐量分别是 FlashAttention-2 和 FLA Triton 核心的 6.1 倍和 7.2 倍。
接下來他們擴充了 Based 核心以支援 JRT-RNN,并且證明了當将序列長度增加到 32768 時,吞吐量分别是 FlashAttention-2 和 FLA 的 19.2 倍和 22.0 倍。當将批大小增加到 64 時,JRT-RNN 分别又提供了 9.7 倍和 11.5 倍的吞吐量提升。JRT-RNN 所需的時間是 Based prefill 的 1.24 倍,比 JRT-Prompt 更加高效。
更多技術細節和實驗結果請參閱原論文。
END
轉載請聯系本公衆号獲得授權
計算機視覺研究院學習群等你加入!
ABOUT
計算機視覺研究院
計算機視覺研究院主要涉及深度學習領域,主要緻力于目标檢測、目标跟蹤、圖像分割、OCR、模型量化、模型部署等研究方向。研究院每日分享最新的論文算法新架構,提供論文一鍵下載下傳,并分享實戰項目。研究院主要着重”技術研究“和“實踐落地”。研究院會針對不同領域分享實踐過程,讓大家真正體會擺脫理論的真實場景,培養愛動手程式設計愛動腦思考的習慣!
🔗