Continual Learning of Foundation Models:
CL-FoMo Suite of 9.6B and 410M LLMs
This is a progress update on our CL-FoMo (Continually-Learned Foundation Models) suite of open-source LLMs, containing four 410M models and four 9.6B models, trained on Pile, SlimPajama (SP), a Pile+SP mix, and the continual sequence (Pile, SP). On most benchmarks, our 9.6B continual model is competitive with the baseline joint model trained on the mixed Pile+SP data (we call it IID Pile+SP to underscore that samples in the mix are independently and identically distributed), and it attains similar performance to RedPajama 7B and Pythia 12B on the featured benchmarks.
See also:
Presentations given by the team:
The 6th Scaling workshop: Continual Pretraining of Foundation Models
AI @ Scale: Continual Pretraining of Foundation Models
Paper: "Simple and Scalable Strategies to Continually Pre-train Large Language Models" (submitted).
Earlier version: Continual Pre-Training of Large Language Models: How to (re)warm your model, in Efficient Natural Language and Speech Processing (ENLSP-III) workshop at NeurIPS 2023.
Main points
Our focus is on developing best practices re: learning rate schedule and replay, for continual pretraining of foundation models without the need to start training from scratch each time a new dataset becomes available.
While this is a work in progress, we already observe that:
- Our 9.6B continual model (Pile, SP) is competitive with the performance of the baseline joint model (Pile+SP)
- Our 9.6B model attains similar performance to RedPajama 7B and Pythia 12B on featured benchmarks
Also, we observe that:
- The amount of data used for warming up the learning rate does not significantly influence the perplexity on the downstream task (learning) or the upstream task (forgetting)
- Altering the maximum learning rate can be an effective way to trade off adaptation against forgetting
- Continual learning is not necessarily worse than retraining on all of the available data
Introduction: why continual pretraining?
Large-scale unsupervised pre-trained models, a.k.a. foundation models, are taking the AI field by storm: they beat prior state-of-the-art performance by a wide margin and achieve impressive few-shot generalization on a variety of tasks in multiple domains. These successes are possible because the performance of foundation models scales consistently with model, data, and compute size. Neural scaling laws predicting such capability advances have been observed and characterized in various recent works.
While the initial trend (2020-2021) was primarily toward increasing model size, following the original Kaplan scaling laws, the more recent trend is toward smaller models and (significantly) larger datasets, along with higher-quality training data. The shift started with the discovery of better, compute-optimal scaling laws in the Chinchilla paper: Training Compute-Optimal Large Language Models. The LLaMA models released by Meta further improved the state of the art by training relatively small models on datasets even larger than Chinchilla-optimal, and also emphasized the importance of accounting not only for training efficiency but also for inference efficiency. As a result, for example, LLaMA-13B outperformed GPT-3 (175B) on most evaluations.
However, despite the much larger datasets used to train recent models such as LLaMA2 (see below), their performance is far from saturated (see the figure below, on the right). Clearly, the LLaMAs (and other LLM species) are still quite data-hungry, and we can continue to feed them more data to improve their performance (of course, data quality matters: "we are what we eat", and LLMs are no exception). However, there is a clear inefficiency in the way current models are trained: the countless (open-source) models rapidly expanding the LLM ecosystem are all pre-trained on billions of tokens, only for the process to be restarted once new data becomes available. A much cheaper and more efficient solution would be to enable continual pre-training of these models, i.e., updating pre-trained models with new data instead of re-training them from scratch. However, the distribution shift induced by novel data typically degrades performance on past data; moreover, learning from a sequence of datasets may be inferior to learning simultaneously from a mixed, shuffled dataset.
This project aims to develop best practices that allow the practitioners to learn from (potentially infinite) sequences of continually arriving datasets, while maintaining the model performance close to (or even better than) the performance of the "ideal" model one could obtain from training on all data at once. We summarize the main points and recent results below.
Number of parameters vs tokens seen during training for various models. The number of parameters of each model is indicated in parentheses next to each model. Source.
Diverse types of new data
- Annual updates to existing data
- Introduction of newer sources and domains of data
- Improved data compilations and preprocessing techniques, as seen in datasets such as RedPajama, SlimPajama, HLPT, RedPajamaV2, and RefinedWeb
- Instruction tuning and alignment data for Reinforcement Learning from Human Feedback (RLHF)
- Data from new modalities to build generalist agents
https://commoncrawl.github.io/cc-crawl-statistics/plots/crawlsize
Current practice
When a new dataset becomes available, it is merged with existing datasets into a weighted mixture, and models are trained from scratch on this new mixture. This method, although common, is quite expensive. Do we really need to start over each time new data is introduced?
Is there an alternative? Can we instead build "lifelong" learning foundation models?
A viable solution might be Continual Learning! Rather than starting afresh, we can keep training on new data using existing checkpoints, continuously updating the model’s knowledge and broadening its capabilities. This approach is especially beneficial in scenarios where we accumulate new data continuously, such as generalist agents and RLHF.
So, what's the challenge? There are two main questions to address:
How can we learn new data efficiently?
How can we prevent catastrophic forgetting, i.e., train on new data without erasing previously acquired knowledge, even when the volume of new data is commensurate with the scale of the pretraining data?
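One standard mitigation for forgetting, and one of the levers this project studies alongside the learning-rate schedule, is replay: mixing a small fraction of previously seen data into the new training stream. Below is a minimal sketch of such a mixer; the stream names and the 5% replay fraction are illustrative assumptions, not our exact recipe.

```python
import random

def replay_mixture(new_stream, old_stream, replay_fraction=0.05, seed=0):
    """Interleave a small fraction of old-data examples into the new-data
    stream. Both streams are hypothetical iterators over tokenized examples."""
    rng = random.Random(seed)
    new_it, old_it = iter(new_stream), iter(old_stream)
    while True:
        # With probability `replay_fraction`, draw a replayed (old) example.
        source = old_it if rng.random() < replay_fraction else new_it
        try:
            yield next(source)
        except StopIteration:
            return  # stop once either stream is exhausted

# Example (hypothetical iterators): mix 5% Pile back into SlimPajama.
# mixed_stream = replay_mixture(slimpajama_stream, pile_stream, replay_fraction=0.05)
```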
Experimental Setup
Pre-training datasets: (1) Pile, (2) SlimPajama (SP), and (3) their union, Pile+SP
Composition of SlimPajama by Category
Settings and baselines
- SlimPajama, 300B tokens (lower bound?)
- Pile, 300B tokens -> SlimPajama, 300B tokens (ours)
- Pile + SlimPajama, 600B tokens (upper bound?)
Model suite: 410M and 9.6B parameter models
Research questions
What is the effect of warm-up duration on forgetting?
How does varying maximum learning rate affect upstream and downstream performance?
How do non-converged checkpoints perform when continually pre-trained?
Effect of warm-up duration on validation loss (410M)
Ablating linear warmup for 0%, 0.5%, 1%, and 2% of training data.
We see that the amount of data used for warming up the learning rate does not significantly influence the perplexity on the downstream task (learning) or the upstream task (forgetting).
Effect of varying the maximum learning rate
Ablating the maximum learning rate over 6e-4, 3e-4, and 1.5e-4.
Altering the maximum learning rate can be an effective way to trade off adaptation against forgetting, with a higher learning rate increasing both adaptation and forgetting.
Using a constant learning rate leads to a model that fails to adapt well to the new dataset.
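For reference, the schedule ablated in the two experiments above is a standard linear-warmup plus cosine-decay rule, re-applied ("re-warmed") when continual pre-training starts on the new dataset. A minimal sketch follows; the default maximum and minimum learning rates are illustrative, not the exact values from our runs.

```python
import math

def warmup_cosine_lr(step, total_steps, max_lr=3e-4, min_lr=3e-5, warmup_frac=0.01):
    """Linear warmup to `max_lr` over `warmup_frac` of training, then cosine
    decay down to `min_lr`. Default values are illustrative placeholders."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps          # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```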
Continual pre-training from non-converged checkpoints
We notice that using an earlier checkpoint of Pile pretraining does not lead to learning faster on SlimPajama.
Forgetting and adaptation: 410M models
We observe that the model continually pretrained on 300B tokens of SP outperforms the model that was trained only (and hence is specialized) on 300B tokens of the new SP dataset, as seen on the right. This shows that leveraging a pretrained model is beneficial when the new dataset covers the downstream evaluation well. On the other hand, on Pile (left), the continually pretrained model does not outperform the model trained only on 300B tokens of Pile.
On SP (right), the continually pretrained model also outperforms the model retrained from scratch on 600B tokens (300B+300B) of the Pile+SP mix, while using only half as many tokens' worth of compute. On Pile, however, the continually pretrained model again fails to outperform the model trained on Pile+SP.
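The forgetting/adaptation comparison above boils down to tracking validation loss (equivalently, perplexity) on both the upstream (Pile) and downstream (SP) validation sets. A minimal evaluation sketch, assuming a hypothetical Hugging Face-style causal LM and pre-tokenized validation batches:

```python
import math
import torch

@torch.no_grad()
def validation_perplexity(model, batches, device="cuda"):
    """Mean token-level cross-entropy over a validation set, reported as
    perplexity. `batches` is a hypothetical iterable of dicts holding
    `input_ids`; passing `labels=input_ids` makes a Hugging Face causal LM
    return the shifted next-token loss."""
    model.eval()
    total_loss, n_batches = 0.0, 0
    for batch in batches:
        input_ids = batch["input_ids"].to(device)
        loss = model(input_ids=input_ids, labels=input_ids).loss
        total_loss += loss.item()
        n_batches += 1
    # Assumes equal-sized batches; otherwise weight by token count.
    return math.exp(total_loss / max(1, n_batches))

# Evaluating one checkpoint on a Pile split tracks forgetting (upstream),
# and on a SlimPajama split tracks adaptation (downstream).
```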
Forgetting and adaptation: 9.6B models
We see that our observations for the 410M model hold for the 9.6B-parameter model.
Continual Model (Pile, SP) matches the performance of the baseline joint model (Pile+SP)
Our 9.6B parameter model attains similar performance to RedPajama 7B and Pythia 12B on featured benchmarks
Our continually pre-trained 9.6B-parameter model assumes that an existing pre-trained model is available (e.g., one of the LLMs on HuggingFace); its total FLOPs are therefore measured over the continual training phase only. The plots compare models trained with different compute budgets. While the models and the amounts of data used are not directly comparable, we observe that continuing to pre-train from a pre-trained checkpoint achieves similar or stronger performance on the featured benchmarks despite using fewer FLOPs.
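As a rough illustration of this compute comparison (not the exact accounting used in the plots), training FLOPs for a dense transformer are commonly approximated as 6 x parameters x tokens:

```python
def approx_training_flops(n_params, n_tokens):
    """Rule-of-thumb training compute estimate: ~6 FLOPs per parameter per
    token (forward + backward pass), as in common scaling-law accounting."""
    return 6 * n_params * n_tokens

# Illustrative comparison for the 9.6B model:
continual_phase = approx_training_flops(9.6e9, 300e9)  # continual phase: 300B new tokens
joint_retrain = approx_training_flops(9.6e9, 600e9)    # from scratch on Pile+SP: 600B tokens
print(f"continual ~ {continual_phase:.2e} FLOPs vs joint ~ {joint_retrain:.2e} FLOPs")
```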
Comparison with other popular open-source models
Conclusions
In a continual pre-training setting:
- We investigated the details of the linear warmup and cosine decay schedule.
- We established the performance of continually pre-trained models compared to practically relevant baselines at two different scales.
- Our results demonstrate that continual pre-training greatly reduces the cost of producing competitive language models.