
A small trick with a big payoff: "just read twice" prompting lets recurrent language models surpass Transformer++

  • Paper: https://arxiv.org/pdf/2407.05483
  • Project homepage: https://github.com/HazyResearch/prefix-linear-attention


In today's AI landscape, the dominant architecture for large language models is the Transformer. However, with the advent of architectures such as RWKV and Mamba, a clear trend has emerged: recurrent large language models that rival Transformers in language-modeling perplexity are rapidly attracting attention.

Excitingly, these architectures use a constant amount of memory during inference. However, because that memory is limited, recurrent language models (LMs) cannot memorize and use all the information in long contexts, which hurts in-context learning (ICL) quality. A key challenge in building an efficient large language model is therefore choosing what information to store and what to discard.

In the recent paper "Just read twice: closing the recall gap for recurrent language models," researchers from Stanford University and the University at Buffalo make a simple observation: the order in which data streams into a recurrent language model during inference greatly affects how hard it is to predict which information should be stored in the limited memory.

Suppose we ask a question based on a document D (say, Galileo Galilei's detailed Wikipedia page): "When did Galileo move to Florence?" If the prompt follows the ordering [Q, D], the model only needs to remember one fact from document D. Conversely, if the prompt follows the ordering [D, Q], the model needs to remember all the facts, because it does not yet know which one the question will target. This is illustrated in Figure 1 (left).


The paper therefore first formalizes, in theory, how data ordering affects memory requirements, and then proposes two methods to reduce the dependence on data ordering: the just-read-twice (JRT) prompting strategy and the JRT recurrent architecture. The work is organized around the following parts:

Understanding the role of data ordering. The researchers' first insight is that the hardness of this memorization problem reduces to that of set disjointness (SD), the most canonical problem in communication complexity theory and one studied for decades. SD requires a streaming algorithm (such as a recurrent model) to decide whether two sets provided in the context are disjoint:

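The formal statement is rendered as an image in the original post; for reference, a standard formulation of set disjointness (the paper's own notation may differ) is:

```latex
% Standard set-disjointness formulation; the notation here is generic, not necessarily the paper's.
\[
\mathrm{SD}(A, B) \;=\;
\begin{cases}
1, & A \cap B = \emptyset,\\
0, & \text{otherwise,}
\end{cases}
\qquad A, B \subseteq \{1, \dots, n\},
\]
```

where the elements of A arrive in the context before the elements of B.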

Theoretical analysis and experiments show that the size of the first set, |A|, governs the memory needed to solve SD: a causal model must store all the elements of A in order to compare them against the elements of B. This implies that presenting the data in the "right order" in context (e.g., placing the set of size min(|A|, |B|) first) helps memory-constrained models. Going a step further, a model that can view the context non-causally can solve SD in space min(|A|, |B|) regardless of the data ordering.
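To make the intuition concrete, here is a minimal sketch (not from the paper) of a one-pass streaming check for set disjointness, mirroring how a causal recurrent model reads the context left to right: every element of the first set must be buffered before the second set can be checked against it.

```python
def streaming_set_disjointness(stream):
    """One-pass check that two labelled sets are disjoint.
    `stream` yields ("A", x) items followed by ("B", y) items."""
    seen_a = set()
    for label, x in stream:
        if label == "A":
            seen_a.add(x)        # memory grows with |A|, the first set
        elif x in seen_a:        # an element of B also appears in A
            return False         # not disjoint
    return True                  # disjoint

# Streaming the smaller set first keeps the buffer at min(|A|, |B|) elements.
print(streaming_set_disjointness([("A", 1), ("A", 2), ("B", 3), ("B", 2)]))  # False
```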

Exploiting the "right" ordering. The paper proposes the very simple JRT-Prompt strategy: repeat the information in the context multiple times before the model generates an answer (as shown on the right of Figure 1). On the second and later passes, the language model conditions on the full context when deciding what information to store, effectively sidestepping the need to put the data in the "right" order.

The results show that JRT-Prompt delivers an average improvement of 11.0 ± 1.3 percentage points across 16 off-the-shelf recurrent language models and 6 ICL tasks, while achieving 11.9× the throughput of FlashAttention-2 (sequence length 32k, batch size 16). JRT-Prompt does increase the context length, but asymptotically it remains more compute- and memory-efficient than attention.

Going beyond causal models. The paper proposes JRT-RNN, inspired by the simple Prefix-LM encoder-decoder architecture. Most in-context learning inputs consist of two parts: the input prompt (context, instructions) and the text the model generates as output. In the Prefix-LM architecture, rather than processing the prompt region causally, the LM reads it non-causally and decodes the output causally, applying the standard next-token-prediction loss only over the causal region and no loss over the non-causal region.

Unfortunately, previous approaches to training Prefix-LM models have seen limited success and rely on an inefficient Transformer backbone. The paper therefore introduces some simple changes to improve both quality and efficiency, including an improved training loss and a linear attention formulation called "Prefix Linear Attention" (PLA). With their IO-aware implementation, the researchers found that JRT-RNN provides average quality improvements of 13.7 and 6.9 percentage points at the 360M and 1.3B parameter scales, respectively, with 19.2× the throughput of FA2.

Overview of the JRT-Prompt method

An in-context learning task takes (C, Q, Y) as input, where C is some context source (such as a document or a code repository), Q is a question or request posed to the model given that context, and Y is the answer. For standard in-context learning with an autoregressive LM A, the researchers feed in C and Q and evaluate the generated output Yˆ = A(C, Q) against the correct completion Y.

JRT-Prompt is an extremely simple method: repeat the information in the prompt (e.g., the question and the document) in context before asking the model to output an answer, i.e., Yˆ = A(C, Q, C, Q), as shown on the right of Figure 1. On the second occurrence of the context, the model decides what information to store with the full context already in view.

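As a concrete illustration, here is a minimal sketch of assembling such a prompt; the template (separators and the "Question:"/"Answer:" markers) is an assumption for illustration, not the paper's exact format:

```python
def jrt_prompt(context: str, question: str) -> str:
    """Build a 'read twice' prompt: the (context, question) pair appears twice
    before the model is asked for an answer, i.e. A(C, Q, C, Q)."""
    block = f"{context}\n\nQuestion: {question}\n\n"
    return block + block + "Answer:"

# Usage with any off-the-shelf LM (hypothetical `lm.generate` interface):
# answer = lm.generate(jrt_prompt(document, "When did Galileo move to Florence?"))
```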

Moreover, JRT-Prompt can be used with off-the-shelf LLMs. The researchers evaluated the following LMs with zero-shot prompts on a series of recall-intensive in-context tasks:

  • Based pre-trained LMs with 1.3B parameters, trained on 10-50B tokens of the Pile;
  • Mamba pre-trained LMs with 130M, 370M, 1.4B, and 2.8B parameters, trained on 300B tokens of the Pile;
  • Gated Linear Attention pre-trained LMs with 1.3B and 2.7B parameters, trained on 100B tokens of the SlimPajama dataset;
  • Mamba-2 pre-trained LMs with 130M, 370M, 1.3B, and 2.7B parameters, trained on 300B tokens of the Pile.

As shown in Table 1 below, across models of increasing state size the researchers found that JRT-Prompt brings an average improvement of 11.0 ± 1.3 percentage points over the models and tasks, and that the Based model with this method outperforms, on average, the Transformer model using the standard prompt.

They also found that JRT-Prompt can benefit Transformer models, and that on some tasks the method is more effective than few-shot learning (Appendix 2). Notably, Springer et al., in the paper "Repetition improves language model embeddings," proposed having an autoregressive Transformer model repeat the context for the purpose of generating embeddings, and the results reported here point in a similar direction. The present researchers focus instead on sub-quadratic architectures and in-context learning tasks.


Although repetition means JRT-Prompt increases the context length, the sub-quadratic recurrent architectures it runs on are still more efficient than a quadratic Transformer model. The researchers found that, on an NVIDIA H100, JRT-Prompt (sequence length 2N) provides 11.9× the throughput of FlashAttention-2 (sequence length N) at N = 32768 and batch size 16.

JRT-RNN: an encoder-decoder recurrent architecture

JRT-RNN draws inspiration from Prefix-LMs, but focuses on pushing out the Pareto frontier of the quality-efficiency trade-off space. To improve quality, JRT-RNN uses separate k_e and v_e projections on the encoder side and separate k_d and v_d projections on the decoder side. Whereas Prefix-LM models share projection weights between the encoder and decoder regions, the researchers found that using two sets of projections improves quality.

To improve efficiency, JRT-RNN uses non-causal linear attention over the encoder region and standard causal linear attention over the decoder region. The researchers call this Prefix Linear Attention (PLA) (Figure 1, right); the formula is given in the paper.

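The exact PLA formula is shown as an image in the original post. As a rough illustration only, the sketch below implements the idea described above under stated assumptions (the feature map, normalization, and seeding of the decoder recurrence with the prefix state are my choices, not necessarily the paper's): encoder positions attend non-causally over the whole prefix using k_e, v_e, while decoder positions run a standard causal linear-attention recurrence using k_d, v_d.

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    # a common positive feature map for linear attention (this choice is an assumption)
    return F.elu(x) + 1

def prefix_linear_attention(q, k_e, v_e, k_d, v_d, M):
    """Sketch over a length-T sequence whose first M tokens are the encoder prefix.
    q, k_*, v_* are (T, d) tensors."""
    T, d = q.shape
    fq, fke, fkd = feature_map(q), feature_map(k_e), feature_map(k_d)

    # fixed-size state summarizing the entire prefix (non-causal over positions [0, M))
    S = fke[:M].t() @ v_e[:M]          # (d, d)
    z = fke[:M].sum(dim=0)             # (d,)

    out = torch.empty_like(q)
    # encoder positions: every prefix token attends over the whole prefix
    out[:M] = (fq[:M] @ S) / (fq[:M] @ z).clamp(min=1e-6).unsqueeze(-1)

    # decoder positions: causal linear-attention recurrence seeded with the prefix state
    for t in range(M, T):
        S = S + fkd[t].unsqueeze(1) @ v_d[t].unsqueeze(0)   # rank-1 state update
        z = z + fkd[t]
        out[t] = (fq[t] @ S) / (fq[t] @ z).clamp(min=1e-6)
    return out
```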

JRT-RNN training objective. Prefix-LMs typically compute no loss over the non-causal region, whereas JRT-RNN combines next-token prediction with a masked language modeling (MLM) objective. For the added MLM objective, the researchers replace a portion of the tokens from the encoder region {u_1, ..., u_M} with a [MASK] token and measure the cross-entropy loss when predicting the original tokens.


The full loss expression is given in the paper.

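The exact expression appears as an image in the original post; as a sketch only (the weighting λ and the notation are assumptions, not the paper's formula), the combined objective plausibly takes the form:

```latex
% Sketch only: lambda and the notation are assumptions, not the paper's exact formula.
\[
\mathcal{L}
\;=\;
\underbrace{\sum_{t \ge M} -\log p_\theta\!\left(u_{t+1} \mid u_{\le t}\right)}_{\text{next-token prediction (decoder region)}}
\;+\;
\lambda \,
\underbrace{\sum_{t \in \mathcal{M}} -\log p_\theta\!\left(u_t \mid \tilde{u}\right)}_{\text{MLM on masked encoder tokens}},
\]
```

where the first sum is the standard next-token loss over the decoder region, and the second sum runs over the masked encoder positions, with the model predicting each original token from the input in which those positions have been replaced by [MASK].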

Experimental results

In the experiments, the researchers evaluated the quality and efficiency of JRT-RNN along three axes:

  • In-context learning quality
  • Overall language modeling
  • Generation

In-context learning quality

As shown in Table 2 below, the researchers found that JRT-RNN outperforms the decoder-only baseline (Based) by 13.7 percentage points on average at 360M parameters (30B tokens) and by 6.9 percentage points at 1.3B parameters (50B tokens).

At the same time, the gap between JRT-RNN and Transformer++ at 360M and 1.3B narrowed to within 0.5 and 1.9 percentage points, respectively.

In Table 3 below, the researchers compare JRT-RNN against similar inference strategies when the prefill length l is smaller than the encoder length M.


Overall language modeling

Following prior work, the researchers further split perplexity into two groups: the associative-recall "AR slice" contains tokens, called "AR hits," that require the model to recall information from earlier in the sequence to correctly predict the next token, while the "Other slice" contains the remaining tokens (e.g., memorized knowledge).

Recall frequency. JRT-RNN excels on the "AR slice." For bigrams that are seen infrequently during training (i.e., ones unlikely to be memorized in the model parameters), JRT-RNN's perplexity improves relative to Based and Mamba, two strong causal recurrent baselines.

Recall distance. On the AR slice, the gap between JRT-RNN and the decoder-only baselines widens as the number of repeated bigrams in the context increases. This is further evidence that JRT-RNN helps with longer-range in-context recall tasks.

Non-recall frequency. On the non-recall "Other slice," for bigrams that are rarely seen during training, JRT-RNN's perplexity is worse than that of the decoder-only LMs. This is expected, since JRT-RNN computes the next-token prediction loss on only 65% of the tokens that a decoder-only LM does.

The researchers expect this gap to shrink with scale and longer training time (which increases the observed bigram frequencies) (Figure 3, top left).


Generation throughput

Generation can be broken down into two steps: prompt "prefill" and decoding ("next-token prediction"). Since JRT-RNN leaves the decoding step unchanged relative to a standard decoder-only recurrent model, the discussion focuses on the prefill stage.
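For orientation, the generic sketch below (hypothetical `step_fn`, `embed`, and `lm_head` interfaces, not the paper's implementation) shows how the two stages fit together for a recurrent model: prefill scans the prompt once to build the constant-size state, and decoding then emits one token at a time from that state.

```python
import torch

def generate(step_fn, embed, lm_head, prompt_ids, num_new_tokens):
    """step_fn(state, x) -> (state, h) performs one recurrent update (hypothetical interface)."""
    state, h = None, None
    # prefill: scan the prompt once; only the constant-size recurrent state is kept
    for tok in prompt_ids:
        state, h = step_fn(state, embed(tok))
    out = []
    # decode: predict the next token from h, feed it back, repeat
    for _ in range(num_new_tokens):
        tok = int(torch.argmax(lm_head(h), dim=-1))
        out.append(tok)
        state, h = step_fn(state, embed(tok))
    return out
```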

Using the Based CUDA kernels proposed in Simran Arora et al.'s paper "Simple linear attention language models balance the recall-throughput tradeoff," JRT-Prompt achieves 11.9× and 13.7× the prefill throughput of the FlashAttention-2 and FLA Triton kernels, respectively, as shown in Table 5 below.

When the researchers increased the batch size to 64, JRT-Prompt's throughput was 6.1× and 7.2× that of the FlashAttention-2 and FLA Triton kernels, respectively.

Next, they extended the Based kernels to support JRT-RNN and showed that, when the sequence length is increased to 32768, throughput is 19.2× and 22.0× that of FlashAttention-2 and FLA, respectively. When the batch size is increased to 64, JRT-RNN delivers a further 9.7× and 11.5× throughput advantage, respectively. JRT-RNN's prefill takes 1.24× as long as Based prefill, making it more efficient than JRT-Prompt.


For more technical details and experimental results, please refer to the original paper.
