Advancing Large Language Models with Multi-Token Prediction

In an article recently submitted to the arXiv* server, researchers proposed training large language models (LLMs) to predict multiple future tokens simultaneously using independent output heads. This approach improved sample efficiency and downstream capabilities, especially for larger models and for generative tasks such as coding.

Study: Enhancing Large Language Models with Multi-Token Prediction. Image Credit: Alexander Supertramp/Shutterstock

*Important notice: arXiv publishes preliminary scientific reports that are not peer-reviewed and, therefore, should not be regarded as definitive, used to guide development decisions, or treated as established information in the field of artificial intelligence research.

The approach showed significant performance gains, solving more problems on coding benchmarks and favoring the development of induction heads and algorithmic reasoning capabilities. Models trained with four-token prediction were also up to three times faster at inference.

Background

LLMs have achieved impressive feats in capturing world knowledge and basic reasoning through the next-token prediction task. However, this training objective is inefficient: it tends to latch onto local patterns and requires far more data than humans need to reach similar levels of fluency. Previous studies have explored multi-token prediction, in which a model predicts several future tokens at once, but despite its promise the approach had not been extensively applied at scale.

The present research addressed this gap by proposing a simple multi-token prediction framework that did not add training-time or memory overhead. The authors provided experimental evidence that the benefits grow with model size: their 13-billion-parameter models solved 12% more problems on HumanEval and 17% more on MBPP than comparable next-token models.

Additionally, multi-token prediction enabled self-speculative decoding, speeding up inference by up to three times, even with large batch sizes. This work highlighted the potential of multi-token prediction to improve LLM performance, coherence, and reasoning abilities beyond what traditional next-token prediction achieves.

Multi-Token Prediction Architecture and Efficient Training Methods

The proposed method generalized standard language modeling by replacing next-token prediction with a multi-token prediction task. At each position, instead of predicting only the next token, the model predicted the n following tokens at once by minimizing a multi-token cross-entropy loss. The architecture comprised a shared transformer trunk that produced a latent representation of the observed context, and n independent output heads that predicted each of the n future tokens in parallel from that representation. Because the heads are independent given the trunk's output, the multi-token cross-entropy loss factorizes into a sum of n per-head cross-entropy terms, one for each future offset.
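A minimal sketch of the factorized loss, written in pure Python with no deep-learning framework. The trunk here is a hypothetical stand-in (an average of token embeddings) rather than a transformer, the heads are plain linear unembedding matrices, and all sizes and random weights are made up; only the structure (one shared latent per position, n independent heads, one cross-entropy term per future offset) mirrors the description above.

```python
import math
import random


def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(matrix, vec):
    # (vocab x dim) matrix times a dim-sized vector -> vocab-sized logits.
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


def toy_trunk(context, embed):
    # HYPOTHETICAL stand-in for the shared transformer trunk: it just averages
    # the embeddings of the observed tokens into one latent vector z.
    dim = len(embed[0])
    return [sum(embed[tok][d] for tok in context) / len(context) for d in range(dim)]


def multi_token_loss(tokens, embed, heads):
    # heads[i] (0-based) is an independent (vocab x dim) output head that
    # predicts the token i+1 positions after the observed context.
    n = len(heads)
    total, count = 0.0, 0
    for t in range(1, len(tokens) - n + 1):
        z = toy_trunk(tokens[:t], embed)  # one shared latent per position
        for i, head in enumerate(heads):  # every head reads the same z
            probs = softmax(matvec(head, z))
            total -= math.log(probs[tokens[t + i]])  # per-head cross-entropy
            count += 1
    # Summed over heads; averaged here only to keep the value on a per-token scale.
    return total / count


random.seed(0)
vocab, dim, n = 8, 4, 4
embed = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(vocab)]
heads = [[[random.gauss(0, 0.1) for _ in range(dim)] for _ in range(vocab)] for _ in range(n)]
tokens = [random.randrange(vocab) for _ in range(16)]
print(multi_token_loss(tokens, embed, heads))
```

With all-zero heads, every head predicts a uniform distribution, so the loss equals the logarithm of the vocabulary size, which makes a quick sanity check. Setting n = 1 recovers ordinary next-token cross-entropy.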

Training multi-token predictors naively is memory-hungry, because the vocabulary-sized logits of all n heads and their gradients would be held in graphics processing unit (GPU) memory at the same time. To avoid this, the authors reordered the forward and backward operations: instead of materializing all logits and their gradients at once, they sequentially computed the forward and backward pass of each output head, accumulated the gradients at the shared trunk, and freed that head's memory before moving on to the next head. This reduced peak GPU memory usage without increasing runtime.
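The reordering can be illustrated with linear softmax heads, for which the gradient with respect to the trunk output z has the closed form head^T (p - onehot(y)). The sketch below is a hypothetical pure-Python illustration, not the authors' implementation: it runs forward and backward for one head at a time, adds that head's gradient into a single accumulator for z, and releases the head's logits before the next head runs. Head weight gradients and the trunk itself are omitted.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def head_forward_backward(head, z, target):
    # Forward for ONE head: its vocab-sized logits are the only large buffer alive.
    logits = [sum(w * x for w, x in zip(row, z)) for row in head]
    probs = softmax(logits)
    loss = -math.log(probs[target])
    # Backward for the same head: dL/dlogits = probs - onehot(target),
    # and dL/dz = head^T @ dL/dlogits (the head's own weight gradient is omitted).
    grad_logits = list(probs)
    grad_logits[target] -= 1.0
    grad_z = [sum(grad_logits[v] * head[v][d] for v in range(len(head))) for d in range(len(z))]
    return loss, grad_z


def sequential_heads_step(heads, z, targets):
    # z is the trunk output, computed once; targets[i] is the token i+1 steps ahead.
    total_loss = 0.0
    grad_z = [0.0] * len(z)
    for head, target in zip(heads, targets):
        loss, g = head_forward_backward(head, z, target)
        total_loss += loss
        grad_z = [a + b for a, b in zip(grad_z, g)]  # accumulate at the trunk
        # logits/probs were local to head_forward_backward, so they are released here.
    # A real trainer would now back-propagate grad_z through the trunk once.
    return total_loss, grad_z


heads = [
    [[0.3, -0.2], [0.1, 0.5], [-0.4, 0.2]],
    [[0.2, 0.2], [-0.3, 0.1], [0.05, -0.6]],
]
z = [0.7, -1.1]
targets = [2, 0]
print(sequential_heads_step(heads, z, targets))
```

Because the per-head gradients are simply summed, the accumulated result equals the gradient of the total loss with respect to z, while only one head's vocabulary-sized logits exist at any moment.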

During inference, the model could perform vanilla autoregressive next-token prediction using only the next-token head. The remaining heads could additionally be used to speed up decoding through self-speculative methods such as blockwise parallel decoding and Medusa-like tree attention, in which draft tokens proposed by the extra heads are verified by the next-token head.
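The draft-then-verify idea behind blockwise parallel decoding can be sketched in a few lines. Everything below is hypothetical: toy_next_token stands in for the next-token head, toy_draft for the extra heads (with its fourth head deliberately wrong), and the verification loop calls the next-token head once per draft token, whereas a real implementation would score all drafts in a single forward pass. Medusa-style tree attention is not shown.

```python
def greedy_decode(next_token, prompt, n_new):
    # Baseline: one sequential call to the next-token head per generated token.
    seq = list(prompt)
    for _ in range(n_new):
        seq.append(next_token(seq))
    return seq


def self_speculative_decode(next_token, draft, prompt, n_new):
    seq = list(prompt)
    target_len = len(prompt) + n_new
    rounds = 0
    while len(seq) < target_len:
        rounds += 1
        proposal = draft(seq)  # k draft tokens from the extra heads
        if not proposal:
            raise ValueError("draft must propose at least one token")
        accepted = []
        # Verification: a real model scores all drafts in ONE forward pass;
        # here each position is checked in a loop for clarity.
        for tok in proposal:
            expected = next_token(seq + accepted)
            if tok != expected:
                accepted.append(expected)  # keep the next-token head's correction
                break
            accepted.append(tok)
        seq.extend(accepted)
    return seq[:target_len], rounds


def toy_next_token(seq):
    # HYPOTHETICAL deterministic "next-token head".
    return (3 * seq[-1] + 1) % 10


def toy_draft(seq, k=4):
    # HYPOTHETICAL draft heads: heads 1-3 agree with toy_next_token, head 4 is off by one.
    out, ctx = [], list(seq)
    for i in range(k):
        tok = toy_next_token(ctx)
        ctx.append(tok)
        out.append((tok + 1) % 10 if i == 3 else tok)
    return out


tokens, rounds = self_speculative_decode(toy_next_token, toy_draft, [1], 12)
print(tokens, rounds)
```

Because a draft token is kept only when it matches the next-token head's own greedy choice, the output is identical to plain greedy decoding; in this toy run, 12 new tokens take 3 draft-and-verify rounds instead of 12 sequential steps. The real speedup depends on how often the extra heads' drafts are accepted.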

Experimental Results on Multi-Token Prediction Models

The researchers conducted seven large-scale experiments to evaluate multi-token prediction models. The findings indicated that multi-token prediction became increasingly beneficial as model size grew, with the clearest gains on code tasks.

  • Model Size Scaling: Multi-token prediction models outperformed next-token models at larger scales, achieving better results on code benchmarks such as Mostly Basic Python Problems (MBPP) and HumanEval.
  • Faster Inference: Using self-speculative decoding, multi-token prediction models achieved up to three times faster inference on code and text.
  • Global Pattern Learning: Multi-token prediction models were better at learning longer-term patterns, particularly with byte-level tokenization, where they solved 67% more problems on MBPP pass@1 than next-byte models.
  • Optimal Token Prediction: In the token-level experiments, training with four future tokens consistently outperformed the other numbers of predicted tokens across various benchmarks.
  • Multi-Epoch Training: The advantages of multi-token prediction persisted across multiple training epochs, maintaining an edge over next-token models.
  • Finetuning: Pretrained multi-token models outperformed next-token baselines when fine-tuned on challenging tasks such as CodeContests.
  • Natural Language: While multi-token prediction models showed modest improvements on some natural language tasks, larger datasets might be necessary for significant gains.
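The code results above are reported as pass@k scores, where pass@1 is the fraction of problems solved on a single attempt. As an illustration of how such scores are commonly computed, the sketch below implements the unbiased pass@k estimator introduced alongside the HumanEval benchmark; it is not taken from the study, and the per-problem sample counts are made up.

```python
from math import comb


def pass_at_k(n, c, k):
    # Unbiased estimate of P(at least one of k samples passes), given that
    # c of the n samples generated for this problem passed the unit tests.
    if n - c < k:
        return 1.0  # every size-k subset contains a passing sample
    return 1.0 - comb(n - c, k) / comb(n, k)


def benchmark_pass_at_k(per_problem, k):
    # per_problem: list of (n_samples, n_correct) pairs, one per problem.
    return sum(pass_at_k(n, c, k) for n, c in per_problem) / len(per_problem)


# Hypothetical sample counts for three problems, 10 samples each.
results = [(10, 3), (10, 0), (10, 10)]
print(benchmark_pass_at_k(results, 1))
```

With k = 1 the estimate reduces to c/n for each problem, so pass@1 is simply the average per-problem success rate; a relative gain such as "67% more problems" compares two such averages.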

Overall, multi-token prediction enhanced model capabilities, sped up inference, and performed robustly across diverse tasks.

Speculation on Why Multi-Token Prediction Works

The authors suggested that multi-token prediction improved performance by reducing the gap between training-time teacher forcing and inference-time autoregressive generation. It implicitly assigned higher weights to critical "choice point" tokens that strongly influence the text that follows, encouraging better decisions at these junctures.

An information-theoretic argument suggested that multi-token prediction places more emphasis on the mutual information between successive tokens, strengthening the model's ability to predict the tokens that matter most for a coherent and relevant continuation. According to the authors, this could lead to more accurate language models, particularly for tasks requiring longer-term dependencies.
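For the simplest case of n = 2, the argument can be written with standard entropy identities. Let X be the next token and Y the one after it, both conditioned on the same context (left implicit). Teacher-forced next-token training targets H(X) at one step and H(Y | X) at the next, whereas 2-token prediction from the same context targets H(X) + H(Y). This is a sketch of the argument as summarized here, not a reproduction of the paper's derivation:

```latex
\begin{aligned}
H(X) &= H(X \mid Y) + I(X;Y), \qquad H(Y) = H(Y \mid X) + I(X;Y) \\
\underbrace{H(X) + H(Y \mid X)}_{\text{next-token}} &= H(X \mid Y) + I(X;Y) + H(Y \mid X) \\
\underbrace{H(X) + H(Y)}_{\text{2-token}} &= H(X \mid Y) + 2\,I(X;Y) + H(Y \mid X)
\end{aligned}
```

The two objectives differ only in the weight on I(X;Y): the 2-token loss counts the information shared between successive tokens twice, which is the sense in which multi-token prediction emphasizes it.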

Conclusion

In conclusion, multi-token prediction offered a substantial advance over next-token prediction for training large language models, particularly for generative and reasoning tasks. By reducing the gap between training and inference, it may improve decision-making at critical points in text generation.

Combined with self-speculative decoding, the approach also significantly accelerated inference. Future research directions include automatically choosing the number of predicted tokens n and determining optimal vocabulary sizes for multi-token prediction, which could further improve model efficiency and effectiveness across diverse applications.


Journal reference:
  • Preliminary scientific report. Gloeckle, F., Idrissi, B. Y., Rozière, B., Lopez-Paz, D., & Synnaeve, G. (2024, April 30). Better & Faster Large Language Models via Multi-token Prediction. ArXiv.org. DOI: 10.48550/arXiv.2404.19737, https://arxiv.org/abs/2404.19737

Written by

Soham Nandi

Soham Nandi is a technical writer based in Memari, India. His academic background is in Computer Science Engineering, specializing in Artificial Intelligence and Machine Learning. He has extensive experience in Data Analytics, Machine Learning, and Python. He has worked on group projects that required the implementation of Computer Vision, Image Classification, and App Development.

Citations

Please use one of the following formats to cite this article in your essay, paper or report:

  • APA

    Nandi, Soham. (2024, July 02). Advancing Large Language Models with Multi-Token Prediction. AZoAi. Retrieved on December 10, 2024 from https://www.azoai.com/news/20240702/Advancing-Large-Language-Models-with-Multi-Token-Prediction.aspx.

  • MLA

    Nandi, Soham. "Advancing Large Language Models with Multi-Token Prediction". AZoAi. 10 December 2024. <https://www.azoai.com/news/20240702/Advancing-Large-Language-Models-with-Multi-Token-Prediction.aspx>.

  • Chicago

    Nandi, Soham. "Advancing Large Language Models with Multi-Token Prediction". AZoAi. https://www.azoai.com/news/20240702/Advancing-Large-Language-Models-with-Multi-Token-Prediction.aspx. (accessed December 10, 2024).

  • Harvard

    Nandi, Soham. 2024. Advancing Large Language Models with Multi-Token Prediction. AZoAi, viewed 10 December 2024, https://www.azoai.com/news/20240702/Advancing-Large-Language-Models-with-Multi-Token-Prediction.aspx.
