LiBai Model Library to Train Large Models More Easily and Efficiently

OneFlow
Aug 9, 2022


Translated by Hu Yanjun, Shen Jiali, Dong Wenwen, Jia Chuan

Starting with BERT in 2018, large models have sprung up one after another, including GPT-3 and ViT, with parameter counts reaching into the billions. Explosive growth in model size now happens so often that it hardly impresses AI developers anymore. What really troubles engineers is how to accelerate the training of such large models.

Larger models come with much higher training costs and put more pressure on compute and memory resources. For instance, training GPT-3, a model with over 100 billion parameters, on a single state-of-the-art NVIDIA A100 GPU would take more than 100 years.

Larger models demand more GPU memory, but GPU memory capacity is not growing fast enough to keep up. According to a report by OpenAI, AI model size doubles every 3.5 months, far outpacing GPU memory, which doubles roughly every 18 months. As a result, a single GPU can no longer hold all the parameters of a large model.

Developers therefore have to split the computation across multiple GPUs, which makes distributed training unavoidable.

However, distributed training has a high barrier to entry, and abundant computing resources alone are not enough. Programming for distributed parallel training requires expertise in computer systems and architecture as well as rich hands-on experience, which makes it harder to explore cutting-edge algorithms and new models. As a result, developing and training large models has become the preserve of tech giants. Accelerating model training and making large models accessible to more engineers is therefore a top priority.

But with all the model libraries available for distributed training, which one should we choose?

Luckily, OneFlow, an open-source DL framework known for its excellent performance, recently released the LiBai model library, which makes this question easier to answer. LiBai combines the strengths of mainstream Transformer libraries such as Hugging Face, Megatron-LM, DeepSpeed, and FairSeq, and outperforms many of its competitors in distributed training. More importantly, its Global View Programming lowers the bar for distributed training, so more developers can train large-scale models.

Find out more about the LiBai model library: https://github.com/Oneflow-Inc/libai .

So what sets LiBai apart? In the rest of this article, we compare LiBai with other distributed training tools in terms of training performance, ease of use, and more, to give you a reference point the next time you need to choose one.

1. One-click auto distributed training with better performance than Megatron-LM and DeepSpeed

Specifically, as a simple and efficient distributed model training toolkit, the LiBai library boasts the following six features:

  • Easy scaling from single-GPU to multi-GPU training. Models in LiBai are aligned with their PyTorch counterparts, so you don't have to learn a new programming style. Scaling to parallel training only requires a simple configuration change. To add a new feature to a model and train it in a distributed setting, users only need to implement and debug it in single-GPU code, and LiBai takes care of the rest. Users who want to skip the distributed configuration step altogether can install the Auto_Parallel package (https://libai.readthedocs.io/en/latest/tutorials/basics/Auto_Parallel.html) and set a single option in LiBai: graph.auto_parallel=True. They can then concentrate on the models themselves without worrying about the implementation details of distributed training, while still benefiting from fast training.
  • Compatibility with Hugging Face. OneFlow is highly compatible with PyTorch at the API level, so users can import Hugging Face models with only minor code changes. Simply by writing import oneflow as torch, they can train a large model with LiBai mechanisms such as Data Parallelism, Automatic Mixed Precision, Activation Checkpointing, and the Zero Redundancy Optimizer (ZeRO). To train large models with 3D parallelism, users only need to replace a few layers of the model with the corresponding layers in LiBai.
  • Modular design. LiBai offers not only reusable basic computation modules for building models, but also abstractions and modules for data loading, training logic, metric computation, and so on. This modular design lets users override individual components and plug them into LiBai's training system as needed.
  • Out-of-the-box. Training a large-scale model usually requires a common set of techniques. LiBai supports Mixed Precision Training, Gradient Re-computation, Gradient Accumulation, and ZeRO, all of which can easily be combined with data parallelism, model parallelism, and pipeline parallelism.
  • Rapid reproduction of experiments. LiBai's configuration system borrows from Detectron2 LazyConfig (https://github.com/facebookresearch/detectron2/blob/main/docs/tutorials/lazyconfigs.md), which makes it more flexible than traditional argparse- and yacs-based configuration. Configs are written in Python syntax, so adding new parameters and modules is easy: adding a new module only requires importing it. The training configuration can also be serialized to a YAML file for storage, so users can search for configuration items in the file by keyword. To reproduce a previous experiment, users can directly load its config.yaml as the training configuration. This avoids keeping many separate script files around, a practice that makes it hard to tell which modifications are actually in effect and risks mixing up the configurations of different experiments.
  • High efficiency. By strictly aligning its kernels with Megatron-LM, LiBai implements various kernel fusion operations. Moreover, thanks to OneFlow's static graph design, LiBai surpasses NVIDIA's Megatron-LM and Microsoft's DeepSpeed in both single-GPU performance and the efficiency of various mixed parallelism methods.
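The "Rapid reproduction of experiments" point rests on three ideas: a config is an ordinary Python object, objects are described lazily and built later, and the whole config can be serialized and reloaded. The minimal sketch below illustrates those ideas in plain Python. It is not LiBai's API. `lazy_call`, `instantiate`, the string registry, and the stand-in `Linear` class are hypothetical, and JSON replaces YAML so that the example needs only the standard library.

```python
import json


def lazy_call(target, **kwargs):
    """Hypothetical helper: record a constructor call without executing it."""
    return {"_target_": target, **kwargs}


def instantiate(node, registry):
    """Recursively build objects from a lazy config node (hypothetical helper)."""
    if isinstance(node, dict) and "_target_" in node:
        kwargs = {k: instantiate(v, registry) for k, v in node.items() if k != "_target_"}
        return registry[node["_target_"]](**kwargs)
    if isinstance(node, dict):
        return {k: instantiate(v, registry) for k, v in node.items()}
    return node


class Linear:
    """Stand-in for a real model module."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features


cfg = {
    "model": lazy_call("Linear", in_features=1024, out_features=10),
    "train": {"lr": 1e-4, "dist": {"data_parallel_size": 2}},
}

# Overriding a parameter is a plain Python assignment.
cfg["train"]["dist"]["data_parallel_size"] = 8

# Serialize the whole experiment to text and load it back (JSON here, YAML in LiBai).
text = json.dumps(cfg, indent=2, sort_keys=True)
restored = json.loads(text)

# Objects are only constructed when the config is instantiated.
model = instantiate(restored["model"], registry={"Linear": Linear})
print(model.out_features, restored["train"]["dist"]["data_parallel_size"])
```

Because the serialized text fully describes the run, reloading it reproduces the same model and settings. This is the property the bullet credits with avoiding many diverging script files.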

Thanks to OneFlow SBP's native support for various parallel techniques, LiBai decouples the algorithmic description from the parallel system and implements the same features with much less code. NVIDIA Megatron-LM and Microsoft DeepSpeed together take about 100,000 lines of code to do what LiBai does in roughly 30,000.

The data speaks for itself. The following compares LiBai and Megatron-LM on various models under the same hardware environment, third-party dependencies (CUDA, cuDNN, etc.), parameters, and network structures. (All performance results are public and reproducible: https://libai.readthedocs.io/en/latest/tutorials/get_started/Benchmark.html.) In the future, OneFlow will release LiBai's performance on larger device clusters.

Data parallelism

Note: Here are the meanings of the parameters involved:

DP: Data Parallelism

MP: Model Parallelism

PP: Pipeline Parallelism

2D: 2D Parallelism

3D: 3D Parallelism

fp16: enable automatic mixed precision (amp) training

nl: num layers (when the pipeline parallel size is 8, we increase num layers from 24 to 48 so that each stage has a reasonable number of layers to compute)

ac: enable activation checkpointing

mb: micro-batch size per GPU

gb: total global batch size

dxmxp:

d = data-parallel-size

m = tensor-model-parallel-size

p = pipeline-model-parallel-size

1n1g: 1 node, 1 GPU

1n8g: 1 node, 8 GPUs

2n8g: 2 nodes, 8 GPUs per node (16 GPUs in total)

4n8g: 4 nodes, 8 GPUs per node (32 GPUs in total)

grad_acc_num_step (the number of gradient accumulation steps) = global_batch_size / (micro_batch_size * data_parallel_size). All results are reported as throughput.
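The gradient accumulation formula above can be checked against the benchmark notes. The model parallelism note lists mb = 128, gb = 1024, and grad acc step = 8. That agrees with the formula if data_parallel_size is 1 in those pure model-parallel runs, which is an assumption on our part. The helper name below is hypothetical.

```python
def grad_acc_steps(global_batch_size, micro_batch_size, data_parallel_size):
    """Micro-batches each data-parallel rank accumulates per optimizer step (hypothetical helper)."""
    samples_per_pass = micro_batch_size * data_parallel_size
    if global_batch_size % samples_per_pass != 0:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * data_parallel_size")
    return global_batch_size // samples_per_pass


# Model-parallel setting from the notes, assuming data_parallel_size = 1.
print(grad_acc_steps(1024, 128, 1))  # 8
# The same global batch split across 2 data-parallel groups needs half as many steps.
print(grad_acc_steps(1024, 128, 2))  # 4
```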

(Note: In Group 1, num layers = 24, amp enabled, 1n1g micro-batch size = 24; in Groups 2 to 5, micro-batch size = 16.)
(Note: In Group 1, num layers = 24, amp enabled, 1n1g micro-batch size = 6; in Groups 2 to 5, micro-batch size = 4.)

Model Parallelism

(Note: num layers = 24, amp enabled, activation checkpointing enabled, micro-batch size = 128, global batch size = 1024 and grad acc step = 8.)
(Note: num layers = 24, amp enabled.)

Pipeline Parallelism

(Note: In Groups 1 and 2, num layers = 24, grad acc step = 8; in Group 3, num layers = 48, grad acc step = 16. Both amp and activation checkpointing are enabled in all 3 groups.)

2D Parallelism

Data & Model Parallelism

(Note: num layers = 24, amp enabled, activation checkpointing enabled, micro-batch size = 128, grad acc step = 8.)
(Note: num layers = 24, amp enabled, activation checkpointing enabled, micro-batch size = 32, grad acc step = 8.)

Data & Pipeline Parallelism

(Note: num layers = 24, amp enabled, activation checkpointing enabled, micro-batch size = 128, grad acc step = 8.)
(Note: num layers = 24, amp enabled, activation checkpointing enabled, micro-batch size = 32, grad acc step = 8.)

3D Parallelism

(Note: num layers = 24, amp enabled, activation checkpointing enabled, grad acc step = 8.)

As the results show, under strictly aligned experimental settings, LiBai trains faster than Megatron-LM on both BERT and GPT-2 in every experiment.

2. LiBai: more and better

As we mentioned, currently there are plenty of large model training solutions, such as Hugging Face, DeepSpeed, Megatron-LM, and FairSeq. Do we really need another model library?

To answer this question, let’s see what LiBai has to offer.

Hugging Face: It provides all kinds of SOTA Transformer models that are pretrained and only need some fine-tuning before use, and it has a well-developed community and ecosystem to support developers. However, it only supports data parallelism, which makes it less practical once the model size exceeds the memory capacity of a single GPU. In addition, training models from scratch with Hugging Face is slow.

FairSeq: It targets sequence models and lacks support for CV models, despite the current trend of NLP and CV converging.

Megatron-LM: Based on PyTorch, it is able to implement data parallelism, model parallelism, and pipeline parallelism and deliver high performance. It can handle the training of super large-scale models.

However, it requires a great deal of customization, which makes it unfriendly to algorithm engineers who are not distributed training experts. It also provides far fewer models than Hugging Face does, so engineers who want to reproduce a large PyTorch model often have to wait until someone more adept at distributed training implements that model in Megatron-LM.

DeepSpeed: It is a deeply customized, PyTorch-based library focused on model memory optimization. It supports distributed training, mixed precision training, and ZeRO, so it can greatly reduce memory overhead and train large models effectively under data parallelism. However, DeepSpeed does not support model parallelism. Model parallelism (tensor parallelism, pipeline parallelism) is a better choice when a single GPU cannot hold the parameters of certain layers of the model, or when DeepSpeed's sharding drags down communication efficiency. In such cases, users have to combine DeepSpeed with Megatron-LM and modify the original code.

Megatron-LM and DeepSpeed were the earliest libraries for large model training in the PyTorch ecosystem. Several renowned organizations later joined the arena with libraries such as FairSeq. Notably, though, the core distributed functions of these later libraries are all built on Megatron-LM and DeepSpeed.

LiBai, by contrast, is not a slightly upgraded version of any of these libraries. It is a toolkit for developing large pretrained models, built on OneFlow's outstanding distributed training and graph compiler performance. That is why it stands out in both performance and ease of use for distributed training.

  • Compatibility. LiBai is compatible with the existing PyTorch-based SOTA models so users can transfer models from PyTorch conveniently.
  • High Efficiency. LiBai delivers high efficiency in both single-GPU and multi-GPU training.
  • Ease of Use. LiBai is highly extensible, so users can easily modify models to suit their needs or add new features to speed up prototyping. It lowers the bar for distributed deep learning so much that users can get started without a steep learning curve. When developing new models and features, you only need to write single-GPU code, and LiBai scales it to large GPU clusters for distributed training without any code rewrites. What a time saver!

We believe all these traits make LiBai a wise choice for distributed training.

3. LiBai supports all regular parallel training methods

Distributed large model training involves multiple parallel methods, including data parallelism, tensor/model parallelism, and pipeline parallelism. LiBai supports all of these methods and any combination of them. (For more information on the parallel methods, please refer to: https://docs.oneflow.org/en/master/parallelism/01_introduction.html)

Learning to implement new parallel methods on your own is always a headache. For example, people used to go through the trouble of configuring Apex to enable AMP training, DALI to support data loading pipelines, and DeepSpeed to use ZeRO to reduce memory usage. With LiBai, you have no such worries, since it comes packed with various parallel methods and offers great extensibility.

The following examples show how to implement parallelism with LiBai.

A general way to implement parallelism

Via the SBP interface of OneFlow, users can easily shard the input data or weights in the neural network based on their needs and GPU arrangements to implement data parallelism or tensor parallelism.

libai.layers includes a series of network layers, such as the frequently used Linear, MLP, and Transformer modules, which automatically adapt to various parallelism strategies. Therefore, when building a neural network with libai.layers, you only need to adjust the distributed training hyperparameters in the configuration file to implement strategies such as data parallelism, tensor parallelism, and mixed data & tensor parallelism.

The format of distributed configuration is as follows:

# configs/common/train.py
# Distributed arguments
dist=dict(
    data_parallel_size=1,
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
)

data_parallel_size and tensor_parallel_size determine how the input data and model weights are sharded across different GPU groups, while pipeline_parallel_size sets the number of pipeline stages. In the snippet above, all values are set to 1, which means training on a single GPU. Now suppose a user has 8 GPUs. Below, we show how to modify the configuration file to implement data parallelism, tensor parallelism, and pipeline parallelism on those 8 GPUs.

This document provides detailed instructions about distributed configuration in LiBai:

https://libai.readthedocs.io/en/latest/tutorials/basics/Distributed_Configuration.html

Data parallelism & model parallelism

To implement data (or model) parallel training on 8 GPUs, users only need to override the hyperparameters of distributed training in the configuration files.

  • Data parallelism
# your config.py
from libai.config import get_config
train = get_config("common/train.py").train
train.dist.data_parallel_size = 8

In data parallel training, every rank holds a copy of the same model and processes its own part of the input data.
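Why do replicas that see different data still stay in sync? The plain-Python sketch below shows the underlying arithmetic for a one-parameter model with a mean-squared-error loss and equal-sized shards: averaging the per-rank gradients, as an all-reduce does, yields exactly the full-batch gradient. This is a conceptual illustration only, not LiBai or OneFlow code. In LiBai the corresponding communication is handled by the framework.

```python
def mse_grad(w, xs, ys):
    """Gradient d/dw of mean((w * x - y) ** 2) for the one-parameter model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x + 1.0 for x in xs]
w = 0.5           # every rank holds the same copy of the weight
world_size = 4    # number of data-parallel ranks

# Each rank computes a gradient on its own contiguous shard of the global batch.
shard = len(xs) // world_size
local_grads = [
    mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]

# Averaging the local gradients (the job of an all-reduce) reproduces the
# full-batch gradient, so every replica applies the same update.
averaged = sum(local_grads) / world_size
full_batch = mse_grad(w, xs, ys)
print(averaged, full_batch)
```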

  • Model parallelism
# your config.py
from libai.config import get_config
train = get_config("common/train.py").train
train.dist.tensor_parallel_size = 8

In model parallel training, the model is partitioned across the 8 GPUs, and each GPU holds only part of the model.
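To make "each GPU holds only part of the model" concrete, the sketch below splits a Linear layer's weight by output columns across two simulated devices. Each device computes its slice of the output from the same replicated input, and concatenating the slices (an all-gather) recovers the full result. This column-split scheme is one common form of tensor parallelism and is shown here in plain Python as an illustration. It is an assumption that it mirrors exactly how libai.layers shards any particular layer.

```python
def matmul(a, b):
    """Plain-Python product of matrix a (n x k) and matrix b (k x m)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w, parts):
    """Split weight matrix w (k x m) into `parts` column blocks of equal width."""
    width = len(w[0]) // parts
    return [[row[r * width:(r + 1) * width] for row in w] for r in range(parts)]


x = [[1, 2], [3, 4]]                # input activations, replicated on every device
w = [[1, 0, 2, -1], [0, 1, 1, 3]]   # full Linear weight (2 x 4)

# Each of the 2 devices stores half of the output columns and computes its slice.
shards = split_columns(w, parts=2)
partial_outputs = [matmul(x, w_r) for w_r in shards]

# Concatenating the slices along the column axis gives the full output.
gathered = [sum((out[i] for out in partial_outputs), []) for i in range(len(x))]
print(gathered == matmul(x, w))  # True
```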

Data & model mixed parallelism

To implement data & model mixed parallelism on 8 GPUs, users only need to make simple modifications to the distributed training parameters in the configuration files.

# your config.py
from libai.config import get_config
train = get_config("common/train.py").train
train.dist.data_parallel_size = 2
train.dist.tensor_parallel_size = 4

In this case, LiBai automatically divides the GPUs into groups. Numbering the 8 GPUs from 0 to 7, when data_parallel_size is 2 and tensor_parallel_size is 4, the system divides them into two groups, [[0, 1, 2, 3], [4, 5, 6, 7]]. Data parallelism is applied across the two groups, and model parallelism is applied within each group.

Configuration of pipeline parallelism

In essence, pipeline parallelism can be explained as follows: 1) the neural network is divided into stages; 2) each stage is distributed to one GPU; 3) the computation result of one stage will be passed to the next stage for further computation, which works like an assembly line. For more information about pipeline parallelism, please check: https://docs.oneflow.org/en/master/parallelism/01_introduction.html

  • Configuration of naive pipeline parallelism

In LiBai, you can assign different layers of the network to different GPUs by setting the placement parameters. You can easily set values for the placement parameters via the get_layer_placement() interface in libai.utils.distributed. LiBai can automatically partition stages and assign placements to stages according to the distributed configuration in the configuration file (config). Therefore, for configuration of pipeline parallelism, you only need to configure the placement for each layer of the network.

Most networks use a Linear layer as the head to produce the final results for classification or other tasks. Therefore, we use the Linear layer as an example to introduce the simplest way to configure pipeline parallelism in LiBai:

from libai.layers import Linear
self.head = Linear(hidden_size, num_classes)
  • Configure the placement of network modules

There are two ways to assign a layer of network to the corresponding placement in LiBai:

  1. Manually specify the placement via the to_global interface and get_layer_placement(). In the following snippet, get_layer_placement(-1) assigns the head layer to the last placement.
from libai.layers import Linear
import libai.utils.distributed as dist
self.head = Linear(hidden_size, num_classes).to_global(placement=dist.get_layer_placement(-1))
  2. (Recommended) Modules implemented in libai.layers come with a layer_idx parameter, so you can specify the placement of a layer directly by setting layer_idx.
from libai.layers import Linear
self.head = Linear(hidden_size, num_classes, layer_idx=-1)

Configure the placement of input data

After configuring the placement of the modules in the network, users need to specify the placement of the input data, because computation can only be carried out when the input and the network layer are in the same stage. The most intuitive approach is to give the input the same placement as the layer, which can be done via to_global together with get_layer_placement():

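from oneflow import nn
import libai.utils.distributed as dist

class MyModule(nn.Module):
    def __init__(self, ..., *, layer_idx):
        ...
        self.layer_idx = layer_idx
        ...

    def forward(self, input_data):
        # Move the input to the same placement as this layer before computing
        input_data = input_data.to_global(placement=dist.get_layer_placement(self.layer_idx))
        ...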

Implement naive pipeline parallelism easily with configuration files

After configuring the placement of network layers and the input data, users only need to adjust the configuration file (config) before they can implement pipeline parallelism. Users need to know the number of network layers beforehand, and adjust the pipeline_num_layers in the configuration file:

# set the number of pipeline stages to be 2
train.dist.pipeline_parallel_size = 2
# set model layers for pipeline
train.dist.pipeline_num_layers = hidden_layers
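Setting pipeline_num_layers lets LiBai decide which stage each layer_idx belongs to. The sketch below shows one plausible rule: split the layers evenly across stages, with negative indices counting from the end so that layer_idx=-1 lands on the last stage, as the head example above describes. The function name and the even-split rule are assumptions for illustration, and LiBai's actual partitioning, especially for layer counts that do not divide evenly, may differ.

```python
def layer_stage(layer_idx, pipeline_num_layers, pipeline_parallel_size):
    """Hypothetical mapping from a layer index to a pipeline stage (even split)."""
    if layer_idx < 0:  # negative indices count from the end, so -1 is the last layer
        layer_idx += pipeline_num_layers
    layers_per_stage = -(-pipeline_num_layers // pipeline_parallel_size)  # ceiling division
    return min(layer_idx // layers_per_stage, pipeline_parallel_size - 1)


hidden_layers = 8
print([layer_stage(i, hidden_layers, 2) for i in range(hidden_layers)])  # [0, 0, 0, 0, 1, 1, 1, 1]
print(layer_stage(-1, hidden_layers, 2))  # 1: the head goes to the last stage
```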

1F1B is a pipeline-parallel training schedule introduced in the PipeDream paper (https://arxiv.org/pdf/1806.03377.pdf). It saves GPU memory and uses resources more efficiently. LiBai supports the 1F1B strategy with little extra effort (https://github.com/Oneflow-Inc/libai/blob/main/docs/source/tutorials/advanced_tutorials/customize_dataloader.md).
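The memory saving of 1F1B comes from its ordering. Each stage runs a few warm-up forward passes, then alternates one forward with one backward, then drains the remaining backwards. As a result, a stage never holds activations for more micro-batches than the number of stages ahead of it. The sketch below generates that ordering for the synchronous, flush-based variant, which is the form commonly used in practice. It models only the op order, not PipeDream's weight stashing or LiBai's implementation, and the function names are hypothetical.

```python
def one_f_one_b(num_stages, num_microbatches, stage):
    """Op order ("F", i) / ("B", i) for one stage under a flush-based 1F1B schedule."""
    num_warmup = min(num_stages - stage - 1, num_microbatches)
    ops, next_fwd, next_bwd = [], 0, 0
    # Warm-up: forward passes only, to fill the pipeline.
    for _ in range(num_warmup):
        ops.append(("F", next_fwd))
        next_fwd += 1
    # Steady state: one forward followed by one backward.
    while next_fwd < num_microbatches:
        ops.append(("F", next_fwd))
        next_fwd += 1
        ops.append(("B", next_bwd))
        next_bwd += 1
    # Cool-down: drain the remaining backward passes.
    while next_bwd < num_microbatches:
        ops.append(("B", next_bwd))
        next_bwd += 1
    return ops


def peak_in_flight(ops):
    """Largest number of micro-batches whose activations are held at the same time."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


for stage in range(2):
    ops = one_f_one_b(num_stages=2, num_microbatches=4, stage=stage)
    print(stage, " ".join(f"{k}{i}" for k, i in ops), "peak:", peak_in_flight(ops))
```

With 2 stages and 4 micro-batches, the first stage holds at most 2 micro-batches of activations and the last holds 1. A schedule that runs all forwards before any backward would hold all 4.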

The realization of 3D parallelism

Once you have mastered data & model mixed parallelism and pipeline parallelism, you only need to combine the changes above to configure data + model + pipeline parallelism.

# your config.py
from libai.config import get_config
train = get_config("common/train.py").train
train.dist.data_parallel_size = 2
train.dist.tensor_parallel_size = 2
train.dist.pipeline_parallel_size = 2
hidden_layers = 8 # Layers of the network
train.dist.pipeline_num_layers = hidden_layers

Again taking 8 GPUs as an example, after setting data_parallel_size, tensor_parallel_size, and pipeline_parallel_size to 2, the model is automatically divided across the 8 GPUs according to the pipeline_num_layers set by the user.

With the above-mentioned configuration, the model will be partitioned into two stages that are implemented by GPU [0, 1, 2, 3] and [4, 5, 6, 7], respectively. In Stage 0, GPU [0, 2] and [1, 3] will implement data parallelism, and GPU [0, 1] and [2, 3] will implement model parallelism. In Stage 1, GPU [4, 6] and [5, 7] will implement data parallelism, and GPU [4, 5] and [6, 7] will implement model parallelism.
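The rank groupings described above, and the 2D grouping from the data & model parallelism example, follow from laying the ranks out as a grid. The sketch below assumes a layout in which the tensor-parallel index varies fastest, then the data-parallel index, then the pipeline stage. We chose this layout because it reproduces every group stated in the text. It is an illustration, not LiBai's internal code, and `parallel_groups` is a hypothetical helper.

```python
def parallel_groups(dp, tp, pp):
    """Return (stages, dp_groups, tp_groups) for dp * tp * pp ranks (hypothetical helper)."""

    def rank(p, d, t):
        # Assumed layout: tensor-parallel index fastest, then data-parallel, then stage.
        return (p * dp + d) * tp + t

    stages = [[rank(p, d, t) for d in range(dp) for t in range(tp)] for p in range(pp)]
    # Same stage and same data-parallel index: these ranks split the model weights.
    tp_groups = [[rank(p, d, t) for t in range(tp)] for p in range(pp) for d in range(dp)]
    # Same stage and same tensor-parallel index: these ranks are replicas that split the data.
    dp_groups = [[rank(p, d, t) for d in range(dp)] for p in range(pp) for t in range(tp)]
    return stages, dp_groups, tp_groups


# 3D example: 8 GPUs with all three sizes set to 2.
stages, dp_groups, tp_groups = parallel_groups(dp=2, tp=2, pp=2)
print("stages:", stages)               # [[0, 1, 2, 3], [4, 5, 6, 7]]
print("data parallel:", dp_groups)     # [[0, 2], [1, 3], [4, 6], [5, 7]]
print("model parallel:", tp_groups)    # [[0, 1], [2, 3], [4, 5], [6, 7]]

# 2D example: data_parallel_size = 2, tensor_parallel_size = 4.
_, dp2, tp2 = parallel_groups(dp=2, tp=4, pp=1)
print("model parallel:", tp2)          # [[0, 1, 2, 3], [4, 5, 6, 7]]
print("data parallel:", dp2)           # [[0, 4], [1, 5], [2, 6], [3, 7]]
```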

Custom parallel training

As described above, LiBai provides encapsulated modules for users to call in libai/layers/. Using these modules as building blocks, users can construct their own parallel networks.

When the modules in LiBai do not meet their needs, users can easily customize the parallel strategy. In PyTorch, users have to insert a complex series of communication operations such as scatter, forward, and reduce. In LiBai, users only need to define the sbp and placement when initializing a tensor, which makes implementing parallelism as easy as writing code for a single device. (For details of sbp and placement, please refer to https://docs.oneflow.org/en/master/parallelism/04_2d-sbp.html.)

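For example, suppose a user trains on 4 GPUs and an intermediate result of the network is a 2D-parallel tensor of shape (16, 8), divided across the 4 GPUs as shown below. In LiBai, this tensor's placement is ranks = [[0, 1], [2, 3]] and its SBP is (S(0), S(1)): the first split divides the rows between [0, 1] and [2, 3], and the second divides the columns within each pair. (The SBP (S(1), S(0)) would also produce four (8, 4) blocks, but would assign them to the GPUs in a different order.)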

+----------+----------+
| X00 gpu0 | X01 gpu1 |
+----------+----------+
| X10 gpu2 | X11 gpu3 |
+----------+----------+

Each block Xij has shape (8, 4), so the tensor is evenly distributed across the GPUs. Suppose you want to add some random noise to this tensor in LiBai.

LiBai provides dist.get_nd_sbp() to keep the same code compatible with 1D parallelism, and dist.get_layer_placement() to simplify configuring pipeline parallelism. In most cases, users can directly follow the code below:

# test.py
import oneflow as flow
from omegaconf import DictConfig
from oneflow import nn
from libai.utils import distributed as dist

cfg = DictConfig(
    dict(data_parallel_size=2, tensor_parallel_size=2, pipeline_parallel_size=1))
dist.setup_dist_util(cfg)


class Noise(nn.Module):
    def __init__(self):
        super().__init__()
        self.noise_tensor = flow.randn(
            16, 8,
            sbp=dist.get_nd_sbp([flow.sbp.split(0), flow.sbp.split(1)]),
            placement=dist.get_layer_placement(layer_idx=0)
        )
        # Or the following instead
        # self.noise_tensor = flow.randn(
        #     16, 8,
        #     sbp=(flow.sbp.split(0), flow.sbp.split(1)),
        #     placement=flow.placement("cuda", ranks=[[0, 1], [2, 3]])
        # )

    def forward(self, x):
        return x + self.noise_tensor


# Use a lowercase instance name so the class is not shadowed
noise = Noise()
x = flow.zeros(
    16, 8,
    sbp=(flow.sbp.split(0), flow.sbp.split(1)),
    placement=flow.placement("cuda", ranks=[[0, 1], [2, 3]])
)
y = noise(x)
print(f"rank: {flow.env.get_rank()}, global tensor: shape {y.shape} sbp {y.sbp} placement {y.placement}, local tensor shape: {y.to_local().shape}")

Run command:

python3 -m oneflow.distributed.launch --nproc_per_node 4 test.py

The output is shown below. Every rank reports the same global shape (16, 8), SBP, and placement, while its local tensor holds only an (8, 4) block. This shows both the global view of the tensor and how it is distributed across the ranks.

rank: 2, global tensor: shape oneflow.Size([16, 8]) sbp (oneflow.sbp.split(axis=0), oneflow.sbp.split(axis=1)) placement oneflow.placement(type="cuda", ranks=[[0, 1], [2, 3]]), local tensor shape: oneflow.Size([8, 4])
rank: 3, global tensor: shape oneflow.Size([16, 8]) sbp (oneflow.sbp.split(axis=0), oneflow.sbp.split(axis=1)) placement oneflow.placement(type="cuda", ranks=[[0, 1], [2, 3]]), local tensor shape: oneflow.Size([8, 4])
rank: 1, global tensor: shape oneflow.Size([16, 8]) sbp (oneflow.sbp.split(axis=0), oneflow.sbp.split(axis=1)) placement oneflow.placement(type="cuda", ranks=[[0, 1], [2, 3]]), local tensor shape: oneflow.Size([8, 4])
rank: 0, global tensor: shape oneflow.Size([16, 8]) sbp (oneflow.sbp.split(axis=0), oneflow.sbp.split(axis=1)) placement oneflow.placement(type="cuda", ranks=[[0, 1], [2, 3]]), local tensor shape: oneflow.Size([8, 4])
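The local (8, 4) shapes printed above follow directly from the placement and the 2D SBP. The plain-Python sketch below computes which rows and columns each rank holds. It assumes that the first SBP entry applies to the outer dimension of the ranks grid (the rows of [[0, 1], [2, 3]]) and the second to the inner one, which matches the block diagram earlier in this section. It models only split SBPs with even division, not broadcast or partial-sum, and `local_slices` is a hypothetical helper rather than a OneFlow API.

```python
def local_slices(shape, mesh, split_axes):
    """Return {rank: ((row_start, row_end), (col_start, col_end))} for a 2D tensor.

    shape:      global tensor shape, e.g. (16, 8)
    mesh:       2D list of ranks (the placement), e.g. [[0, 1], [2, 3]]
    split_axes: tensor axis split by each mesh dimension, e.g. (0, 1) for (S(0), S(1))
    """
    result = {}
    for i, row in enumerate(mesh):
        for j, rank in enumerate(row):
            bounds = [[0, shape[0]], [0, shape[1]]]
            # Apply the outer mesh dimension's split first, then the inner one.
            for coord, size, axis in ((i, len(mesh), split_axes[0]), (j, len(row), split_axes[1])):
                start, end = bounds[axis]
                step = (end - start) // size  # assumes the split divides evenly
                bounds[axis] = [start + coord * step, start + (coord + 1) * step]
            result[rank] = (tuple(bounds[0]), tuple(bounds[1]))
    return result


for rank, (rows, cols) in sorted(local_slices((16, 8), [[0, 1], [2, 3]], (0, 1)).items()):
    print(f"rank {rank}: rows {rows}, cols {cols}, local shape ({rows[1] - rows[0]}, {cols[1] - cols[0]})")
```

Swapping `(0, 1)` for `(1, 0)` still yields four (8, 4) blocks, but with ranks 1 and 2 exchanging X10 and X01.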

4. Future Plan

So far, LiBai supports common models such as BERT, GPT, ViT, Swin-Transformer and T5, as well as the latest technologies like MoCoV3 and MAE. In LiBai, these are out of the box and can be easily fine-tuned for downstream tasks.

OneFlow will improve compatibility with the Hugging Face models and increase connectivity to the Hugging Face ecosystem. Meanwhile, via OneFlow’s automatic parallel function, users will enjoy the convenience of automatic scaling from single-GPU training to distributed training.

In the future, OneFlow will not only support more models but also improve its features related to inference and serving. From training to deployment, OneFlow aims to be a one-stop development platform for AI engineers.

Visit OneFlow on GitHub and follow us on Twitter and LinkedIn.

You are also welcome to join our Discord group to discuss and ask OneFlow-related questions, and to connect with OneFlow contributors and users around the world.

Written by OneFlow

OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient. https://github.com/Oneflow-Inc/oneflow
