Continual Learning of Foundation Models:
CL-FoMo Suite of 9.6B and 410M LLMs
This is a progress update on our CL-FoMo (Continually-Learned Foundation Models) suite of open-source LLMs, containing four 410M and four 9.6B models trained on Pile, SlimPajama (SP), a Pile+SP mix, and continually on Pile then SP. On most benchmarks, our 9.6B continual model is competitive with the baseline joint model trained on the mixed Pile+SP data (we call it IID Pile+SP to underscore that samples in the mix are identically distributed), and attains performance similar to RedPajama 7B and Pythia 12B on the featured benchmarks.
See also:
Presentations given by the team:
The 6th Scaling workshop: Continual Pretraining of Foundation Models
AI @ Scale: Continual Pretraining of Foundation Models
Paper: "Simple and Scalable Strategies to Continually Pre-train Large Language Models" (submitted).
Earlier version: Continual Pre-Training of Large Language Models: How to (re)warm your model, in Efficient Natural Language and Speech Processing (ENLSP-III) workshop at NeurIPS 2023.
Main points
Our focus is on developing best practices regarding the learning rate schedule and replay for continual pretraining of foundation models, without the need to start training from scratch each time a new dataset becomes available.
While this is a work in progress, we already observe that:
the amount of data used for warming up the learning rate does not significantly influence the perplexity on the downstream task (learning) or the upstream task (forgetting);
altering the maximum learning rate can be an effective way to balance the trade-off between adaptation and forgetting;
continual learning is not necessarily worse than retraining on all the available data.
However, despite the much larger datasets used to train recent models such as LLaMA2 (see below), their performance is far from saturation (see the figure below, on the right). Clearly, the LLaMAs (and other LLM species) are still quite data-hungry, and we can continue to feed them more data to improve their performance (surely, the quality of the data is important - "we are what we eat", and LLMs are no exception). However, there is a clear inefficiency in the way current models are being trained: the countless (open-source) models rapidly expanding the LLM ecosystem are all pre-trained on billions of tokens, only to restart the process once new data becomes available. A much cheaper and more efficient solution would be to enable the continual pre-training of these models, i.e., updating pre-trained models with new data instead of re-training them from scratch. However, the distribution shift induced by novel data typically results in degraded performance on past data; moreover, learning from a sequence of datasets may be inferior to learning simultaneously from a mixed, shuffled dataset.
This project aims to develop best practices that allow practitioners to learn from (potentially infinite) sequences of continually arriving datasets, while maintaining model performance close to (or even better than) that of the "ideal" model one could obtain by training on all data at once. We summarize the main points and recent results below.
Number of parameters vs. tokens seen during training for various models. The parameter count of each model is indicated in parentheses. Source.
Annual updates to existing data
Introduction of newer sources and domains of data
Improved data compilations and preprocessing techniques, as seen in datasets such as RedPajama, SlimPajama, HPLT, RedPajamaV2, and RefinedWeb
Instruction-tuning and alignment data for Reinforcement Learning from Human Feedback (RLHF)
Data from new modalities to build generalist agents
https://commoncrawl.github.io/cc-crawl-statistics/plots/crawlsize
When a new dataset becomes available, it is merged with existing datasets to create a weighted composition, and models are trained from scratch on this new dataset. This method, although common, proves to be quite expensive. Do we really need to start over each time new data is introduced?
A viable solution might be Continual Learning! Rather than starting afresh, we can keep training on new data using existing checkpoints, continuously updating the model’s knowledge and broadening its capabilities. This approach is especially beneficial in scenarios where we accumulate new data continuously, such as generalist agents and RLHF.
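To make the idea concrete, here is a minimal sketch of continual pre-training with replay, assuming hypothetical new_data / old_data iterables and a replay_fraction value of our own choosing; the actual training pipeline may differ.

    import random

    def replay_mixture(new_data, old_data, replay_fraction=0.05, seed=0):
        """Yield training examples, drawing from previously seen (old) data
        with probability replay_fraction, and from the new data otherwise."""
        rng = random.Random(seed)
        new_it, old_it = iter(new_data), iter(old_data)
        while True:
            try:
                yield next(old_it) if rng.random() < replay_fraction else next(new_it)
            except StopIteration:  # stop once either stream is exhausted
                return

    # Usage sketch (hypothetical names): resume from an existing checkpoint
    # and keep training on the mixed stream.
    # model.load_state_dict(torch.load("pile_checkpoint.pt"))
    # for example in replay_mixture(slimpajama_stream, pile_stream):
    #     ... standard language-model training step ...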
How can we learn new data efficiently?
How can we prevent Catastrophic Forgetting, i.e., train on new data without erasing previously acquired knowledge, even when the volume of new data is commensurate with the scale of the pretraining data?
Composition of SlimPajama by Category
Lower bound?: SlimPajama - 300B tokens
Ours (continual): Pile - 300B tokens -> SlimPajama - 300B tokens
Upper bound?: Pile + SlimPajama mix - 600B tokens
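Written out as a purely illustrative configuration summary (the field names are our own, not the actual training config), the three settings above are:

    # Illustrative summary of the three training configurations compared below.
    experiments = {
        "lower_bound": {"data": ["SlimPajama"],         "tokens": "300B", "init": "random (from scratch)"},
        "ours":        {"data": ["SlimPajama"],         "tokens": "300B", "init": "Pile checkpoint (300B tokens)"},
        "upper_bound": {"data": ["Pile", "SlimPajama"], "tokens": "600B", "init": "random (from scratch)"},
    }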
What is the effect of warm-up duration on forgetting?
How does varying maximum learning rate affect upstream and downstream performance?
How do non-converged checkpoints perform when continually pre-trained?
Ablating linear warmup for 0%, 0.5%, 1%, and 2% of training data.
We see that the amount of data used for warming up the learning rate does not significantly influence the perplexity on the downstream task (learning) or the upstream task (forgetting).
Ablating the maximum learning rate over 6e-4, 3e-4, and 1.5e-4.
Altering the maximum learning rate can be an effective way to balance the trade-off between adaptation and forgetting, with a higher maximum learning rate increasing both adaptation and forgetting.
Using a constant learning rate leads to a model that fails to adapt well to the new dataset.
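For concreteness, here is a minimal sketch of the schedule family being ablated (parameter names and defaults are ours, not the exact training configuration): linear warmup over a fraction of the training budget to the maximum learning rate, followed by cosine decay to a minimum learning rate.

    import math

    def lr_at(step, total_steps, max_lr=3e-4, min_lr=3e-5, warmup_frac=0.01):
        """Linear warmup for warmup_frac of training, then cosine decay to min_lr."""
        warmup_steps = int(warmup_frac * total_steps)
        if step < warmup_steps:
            return max_lr * (step + 1) / warmup_steps  # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

    # The ablations above vary warmup_frac (0%, 0.5%, 1%, 2%) and max_lr (6e-4, 3e-4, 1.5e-4).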
We notice that using an earlier (non-converged) checkpoint from Pile pretraining does not lead to faster learning on SlimPajama.
We observe that the model that is continually pretrained on 300B tokens of SP outperforms the model that was only trained (and hence specialized) on 300B tokens of the new SP dataset, as seen on the right. This shows that leveraging a pretrained model is beneficial when the new dataset is well aligned with the downstream performance of interest. On the other hand, on Pile (left), the continually pretrained model does not outperform the model only trained on 300B tokens of Pile.
On SP (right), the continually pretrained model also outperforms the model retrained from scratch on 600B tokens (300B + 300B) of the Pile + SP mix, using only half as many tokens' worth of compute. On the other hand, on Pile, the continually pretrained model again fails to outperform the model trained on Pile + SP.
We see that our observations for the 410M model hold for the 9.6B-parameter model.
Our continually pre-trained 9.6B-parameter model assumes that an existing pre-trained model is available (e.g., one of the LLMs on HuggingFace); therefore, its total FLOPs are counted only over the continual training phase. The plots compare models trained with different compute budgets. While the models and the amounts of data used are not directly comparable, we observe that continuing to pre-train from an existing checkpoint achieves similar or stronger performance on the featured benchmarks, despite using fewer FLOPs.
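For a rough sense of the compute savings, one can apply the common 6 x parameters x tokens approximation for training FLOPs; this is our own back-of-the-envelope estimate, not the exact accounting behind the plots.

    def approx_train_flops(params, tokens):
        """Rough training cost via the common 6 * N * D approximation."""
        return 6 * params * tokens

    continual    = approx_train_flops(9.6e9, 300e9)  # continual phase only: 300B SlimPajama tokens
    from_scratch = approx_train_flops(9.6e9, 600e9)  # retraining on Pile + SlimPajama: 600B tokens
    print(f"{continual:.1e} vs {from_scratch:.1e} FLOPs")  # ~1.7e+22 vs ~3.5e+22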
In a continual pre-training setting:
We investigated details of the linear warmup and cosine decay schedule.
We established the performance of continually pre-trained models compared to practically relevant baselines at two different scales.
Our results demonstrate that continual pre-training greatly reduces the cost of producing competitive language models.