Efficient Meta Llama 3 Pre-training with torchtitan on Amazon SageMaker
This post is co-written with Less Wright and Wei Feng from Meta.
Pre-training large language models (LLMs) is the first step in developing powerful AI systems that can understand and generate human-like text. By exposing models to vast amounts of diverse data, pre-training lays the groundwork for LLMs to learn general language patterns, world knowledge, and reasoning capabilities. This foundational process enables LLMs to perform a wide range of tasks without task-specific training, making them highly versatile and adaptable. Pre-training is essential for building a strong base of knowledge, which can then be refined and specialized through fine-tuning, transfer learning, or few-shot learning approaches.
In this post, we collaborate with the team working on PyTorch at Meta to showcase how the torchtitan library accelerates and simplifies the pre-training of Meta Llama 3-like model architectures. We highlight key torchtitan features, such as FSDP2, torch.compile integration, and FP8 support, that improve training efficiency. We pre-train a Meta Llama 3 8B model architecture using torchtitan on Amazon SageMaker on p5.48xlarge instances, each equipped with 8 NVIDIA H100 GPUs, and demonstrate a 38.23% speedup in training throughput compared to a baseline without these optimizations.
To learn more, you can find our complete code sample on GitHub.
Introduction to torchtitan
torchtitan is a reference architecture for large-scale LLM training using native PyTorch. It aims to showcase PyTorch’s latest distributed training features in a clean, minimal code base. The library is designed to be simple to understand, use, and extend for different training purposes, with minimal changes required to the model code when applying various parallel processing techniques.
torchtitan offers several key features, including FSDP2 with per-parameter sharding, tensor parallel processing, selective layer and operator activation checkpointing, and distributed checkpointing. It supports pre-training of Meta Llama 3-like and Llama 2-like model architectures of various sizes and includes configurations for multiple datasets. The library provides straightforward configuration through TOML files and offers performance monitoring through TensorBoard.
Transitioning from FSDP1 to FSDP2
FSDP1 and FSDP2 are two approaches to fully sharded data parallel training. FSDP1 uses flat-parameter sharding, which flattens all parameters to 1D, concatenates them into a single tensor, pads it, and then chunks it across workers. FSDP2, by contrast, handles each parameter individually, representing sharded parameters as DTensors sharded on dim-0. This per-parameter approach makes individual parameters easier to manipulate (for example, to apply per-weight learning rates) and enables communication-free sharded state dicts and simpler meta-device initialization.
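The difference between the two sharding schemes can be illustrated with a small sketch in plain Python. This is not the FSDP implementation and uses no PyTorch APIs; the helper names flat_param_shard and per_param_shard, the toy parameters w1 and w2, and the use of nested lists in place of tensors are all assumptions made for illustration.

```python
import math

def flat_param_shard(params, world_size):
    """FSDP1-style sketch: flatten all parameters to 1D, concatenate,
    pad to a multiple of world_size, then chunk evenly across ranks."""
    flat = [x for p in params.values() for row in p for x in row]
    chunk = math.ceil(len(flat) / world_size)
    flat += [0.0] * (chunk * world_size - len(flat))  # zero padding
    return [flat[r * chunk:(r + 1) * chunk] for r in range(world_size)]

def per_param_shard(params, world_size):
    """FSDP2-style sketch: shard every parameter on dim-0 (rows),
    so each rank holds a named slice of each parameter."""
    shards = [dict() for _ in range(world_size)]
    for name, p in params.items():
        rows = math.ceil(len(p) / world_size)
        for r in range(world_size):
            shards[r][name] = p[r * rows:(r + 1) * rows]
    return shards

# Hypothetical toy model: w1 has shape (3, 1), w2 has shape (2, 2).
params = {
    "w1": [[0.0], [1.0], [2.0]],
    "w2": [[3.0, 4.0], [5.0, 6.0]],
}

flat_shards = flat_param_shard(params, world_size=2)
# Rank 0 holds all of w1 plus the first element of w2; parameter
# boundaries are lost inside the flat chunk, and rank 1 carries padding.

named_shards = per_param_shard(params, world_size=2)
# Each rank holds a dict keyed by parameter name, so per-parameter
# operations can address "w1" and "w2" directly on the local shard.
```

In the flat scheme a rank's chunk can straddle parameter boundaries, which is why per-parameter manipulation is awkward there; in the per-parameter scheme each local shard keeps its name and row structure.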
torchtitan support for torch.compile
torchtitan supports torch.compile, which delivers substantial speedups, especially for large models and complex architectures, by using techniques like operator fusion, memory planning, and automatic kernel selection. This is enabled by setting compile = true in the model’s TOML configuration file.
torchtitan support for FP8 linear operations
torchtitan provides support for FP8 (8-bit floating point) computation that significantly reduces memory footprint and enhances performance in LLM training. FP8 has two formats, E4M3 (4 exponent bits, 3 mantissa bits) and E5M2 (5 exponent bits, 2 mantissa bits). E4M3 offers more precision and is commonly used for weights and activations, whereas E5M2 offers more dynamic range and is commonly used for gradients. torchtitan’s FP8 support is provided through the torchao library, and we enable FP8 by setting enable_float8_linear = true in the model’s TOML configuration file.
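The trade-off between the two formats can be made concrete by computing their representable ranges from the bit layouts. The sketch below assumes the widely used definitions in which E5M2 reserves the top exponent for infinities and NaNs (IEEE-style) while E4M3 reserves only a single NaN bit pattern and has no infinities. The post does not state which variants torchao uses, so treat these as assumptions; the helper name fp8_limits is hypothetical.

```python
def fp8_limits(exp_bits, man_bits, ieee_specials):
    """Return (max_normal, min_normal, min_subnormal) for a small float format.

    ieee_specials=True:  the all-ones exponent is reserved for inf/NaN
                         (assumed for E5M2).
    ieee_specials=False: only all-ones exponent with all-ones mantissa is
                         NaN, so the top exponent still encodes finite
                         values (assumed for E4M3).
    """
    bias = 2 ** (exp_bits - 1) - 1
    top_exp_field = 2 ** exp_bits - 1
    if ieee_specials:
        max_exp = top_exp_field - 1 - bias
        max_frac = 1 + (2 ** man_bits - 1) / 2 ** man_bits
    else:
        max_exp = top_exp_field - bias
        max_frac = 1 + (2 ** man_bits - 2) / 2 ** man_bits
    max_normal = 2.0 ** max_exp * max_frac
    min_normal = 2.0 ** (1 - bias)
    min_subnormal = min_normal / 2 ** man_bits
    return max_normal, min_normal, min_subnormal

e4m3 = fp8_limits(exp_bits=4, man_bits=3, ieee_specials=False)
e5m2 = fp8_limits(exp_bits=5, man_bits=2, ieee_specials=True)

print("E4M3 max normal:", e4m3[0])  # 448.0
print("E5M2 max normal:", e5m2[0])  # 57344.0
```

Under these assumptions, E5M2 covers a much wider range of magnitudes, while E4M3 spends its extra mantissa bit on finer steps between neighboring values.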