DoReMi: Guiding Large Model Pretraining Mixtures via Small-Model DRO Trajectories

DoReMi (DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining, NeurIPS 2023, Google DeepMind + Stanford) is one of the earliest works on automatic pretraining data mixture optimization. Subsequent methods like RegMix, DoGE, and Online Data Mixing all respond to or build upon the framework it introduced. The core idea: run Group DRO (Group Distributionally Robust Optimization, an online algorithm from Sagawa et al. 2020 that minimizes worst-group loss) on a 280M proxy model to dynamically adjust domain weights, average the weights across the training trajectory, then use those weights to train an 8B model. The total cost of the mixture search (training both a reference and a proxy model at 280M) is roughly 8% of the 8B main model's compute, yet the main model's downstream accuracy improves by about 6.5 points and it reaches the baseline's final accuracy in about 2.6x fewer training steps.

Paper link: arxiv.org/abs/2305.10429

Problem and Motivation

Pretraining corpora are mixtures of multiple domains (Wikipedia, books, web text, code, etc.), and the sampling proportion of each domain directly affects final model quality. The Pile’s domain weights were set heuristically; PaLM/GLaM’s weights were tuned on downstream tasks, which requires training many models for grid search and risks overfitting to a particular evaluation suite.

The problem DoReMi addresses: without any downstream task labels, automatically find domain weights that make the large model perform well across all domains. The optimization target is not “lowest average loss” but “lowest worst-case excess loss across domains”—a minimax criterion.

Three-Step Pipeline

DoReMi three-step pipeline: train reference model → train proxy model with Group DRO to extract domain weights → train large model with new weights

The figure above shows DoReMi’s overall pipeline: train a small reference model on initial domain weights, train a same-sized proxy model with Group DRO to extract domain weights, then resample data with the optimized weights to train the large model.

Train the reference model. Train a 280M-parameter model on initial domain weights (e.g., The Pile’s default weights or uniform weights for GLaM). This reference model establishes a baseline difficulty for each example: high reference loss means the example is inherently high-entropy (e.g., random text) and doesn’t need extra attention; low reference loss means the content is “learnable.”

Train the proxy model with Group DRO and extract domain weights. The core step. Group DRO’s basic idea: partition training data into groups (domains), and instead of optimizing average loss across all groups, optimize the worst group’s loss. It maintains a set of domain weights $\alpha$ and alternates between two operations during training: (1) gradient ascent on $\alpha$ to concentrate weight on the domain with the highest current loss; (2) gradient descent on model parameters to reduce the weighted loss. The effect is to force the model to perform reasonably on every domain—any domain that falls behind gets amplified by $\alpha$.
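To make the alternation concrete, here is a minimal toy sketch of online Group DRO in plain Python. It uses two hypothetical "domains" whose losses are simple quadratics in a single scalar parameter; the loss functions, step sizes, and starting point are illustrative assumptions, not anything from the paper. Minimizing the average loss would settle at θ = −0.6, where domain 0's loss is 2.56. The minimax solution is θ = −1/3, where both losses equal 16/9 ≈ 1.78.

```python
import math

def group_losses(theta):
    # Two toy "domains" with different loss landscapes (hypothetical).
    return [(theta - 1.0) ** 2, 4.0 * (theta + 1.0) ** 2]

def group_grads(theta):
    # Derivatives of the two toy losses with respect to theta.
    return [2.0 * (theta - 1.0), 8.0 * (theta + 1.0)]

def online_group_dro(steps=5000, lr_theta=0.05, lr_q=0.05):
    theta = 0.0
    q = [0.5, 0.5]  # domain weights on the probability simplex
    for _ in range(steps):
        losses = group_losses(theta)
        # (1) Exponentiated gradient ascent on q: upweight high-loss domains.
        q = [qi * math.exp(lr_q * li) for qi, li in zip(q, losses)]
        total = sum(q)
        q = [qi / total for qi in q]
        # (2) Gradient descent on theta for the q-weighted loss.
        grad = sum(qi * gi for qi, gi in zip(q, group_grads(theta)))
        theta -= lr_theta * grad
    return theta, q

theta, q = online_group_dro()
print(f"theta={theta:.3f}, q={[round(x, 3) for x in q]}, worst loss={max(group_losses(theta)):.3f}")
```

The weights settle near q = (2/3, 1/3). Domain 0, the one left behind by the average-loss solution, ends up with more weight, and the worst-case loss drops from 2.56 to about 1.78.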

In DoReMi, the proxy model’s training objective is a minimax excess loss:

$$\min_\theta \max_{\alpha \in \Delta^k} \sum_{i=1}^{k} \alpha_i \cdot \frac{1}{\sum_{x \in D_i} |x|} \sum_{x \in D_i} \left[ \ell_\theta(x) - \ell_{\text{ref}}(x) \right]$$

Here $\ell_\theta(x)$ and $\ell_{\text{ref}}(x)$ are the proxy and reference model’s NLL on example $x$, and $|x|$ is the number of tokens. The inner max performs exponentiated gradient ascent on domain weights $\alpha$, concentrating weight on domains with the highest excess loss; the outer min updates the proxy model parameters with a standard optimizer. The two are alternated at each step. After training, the time-averaged $\alpha_t$ across the entire trajectory becomes the final domain weights.
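Writing the per-domain term as $\lambda_i(\theta)$ shows why the inner maximum is a worst-case criterion. This shorthand is introduced here and is not the source's notation. A weighted average over the simplex can never exceed its largest entry, and it attains that entry when all weight sits on the worst domain:

$$\lambda_i(\theta) = \frac{1}{\sum_{x \in D_i} |x|} \sum_{x \in D_i} \left[ \ell_\theta(x) - \ell_{\text{ref}}(x) \right], \qquad \max_{\alpha \in \Delta^k} \sum_{i=1}^{k} \alpha_i \lambda_i(\theta) = \max_{1 \le i \le k} \lambda_i(\theta)$$

The objective is therefore equivalent to $\min_\theta \max_i \lambda_i(\theta)$. Exponentiated gradient ascent moves $\alpha$ toward that worst domain gradually rather than jumping to it.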

Implementation details: each step samples a minibatch uniformly (not according to current domain weights), per-token excess loss is clipped to non-negative (to satisfy Group DRO’s requirement for non-negative losses), exponentiated gradient step size $\eta = 1$, smoothing parameter $c = 10^{-3}$.
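The following is a minimal NumPy sketch of one domain-weight update using the details above: per-token excess loss clipped at zero, per-domain normalization by token count, an exponentiated gradient step with $\eta = 1$, and smoothing toward uniform with $c = 10^{-3}$. Several choices are assumptions of this sketch and not the official implementation: the function name `doremi_weight_update`, its inputs (flat per-token loss arrays plus a per-token domain index), and giving zero excess loss to domains absent from a batch. The proxy's own parameter update, an optimizer step on the α-weighted loss, is left out.

```python
import numpy as np

def doremi_weight_update(proxy_tok_loss, ref_tok_loss, tok_domain, alpha, eta=1.0, c=1e-3):
    """One DoReMi-style domain-weight step (hypothetical sketch).

    proxy_tok_loss, ref_tok_loss: per-token NLL arrays for one minibatch.
    tok_domain: integer domain id for each token.
    alpha: current domain weights (length k, sums to 1).
    """
    k = alpha.shape[0]
    # Per-token excess loss, clipped so Group DRO sees non-negative losses.
    excess = np.maximum(proxy_tok_loss - ref_tok_loss, 0.0)
    # Total excess loss and token count per domain.
    excess_sum = np.bincount(tok_domain, weights=excess, minlength=k)
    tok_count = np.bincount(tok_domain, minlength=k)
    # lambda_i: token-normalized excess loss (0 for domains absent from the batch).
    lam = np.divide(excess_sum, tok_count, out=np.zeros(k), where=tok_count > 0)
    # Exponentiated gradient ascent: multiplicatively upweight high-excess domains.
    alpha_prime = alpha * np.exp(eta * lam)
    alpha_prime /= alpha_prime.sum()
    # Smooth toward uniform so no domain weight collapses to exactly zero.
    return (1.0 - c) * alpha_prime + c / k, lam

# Toy minibatch: 4 tokens from 3 domains (made-up values).
proxy = np.array([3.0, 2.0, 1.0, 4.0])
ref = np.array([2.0, 2.5, 1.0, 1.0])
dom = np.array([0, 0, 1, 2])

new_alpha, lam = doremi_weight_update(proxy, ref, dom, np.full(3, 1.0 / 3))
print(lam)        # per-domain clipped excess loss
print(new_alpha)  # domain 2 (largest excess loss) receives most of the weight

# The final weights are the average of alpha over all steps. The same batch is
# reused here only to keep the example short; in real training the losses
# change as the proxy learns.
alpha, history = np.full(3, 1.0 / 3), []
for step in range(100):
    alpha, _lam = doremi_weight_update(proxy, ref, dom, alpha)
    history.append(alpha)
avg_alpha = np.mean(history, axis=0)
print(avg_alpha)
```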

Train the large model with new domain weights. Use the weights from step two as the new sampling distribution, resample data, and train an 8B model in the standard way.
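Once the averaged weights are fixed, resampling can be as simple as drawing a domain for each training example from the categorical distribution the weights define, then taking the next example from that domain's stream. Below is a minimal standard-library sketch. The three domain names and their weights are hypothetical placeholders, not DoReMi's published weights, and per-domain example streams are outside its scope.

```python
import random
from collections import Counter

def sample_domains(weights, n, seed=0):
    """Draw n domain labels i.i.d. in proportion to the given domain weights."""
    rng = random.Random(seed)
    names = list(weights)
    return rng.choices(names, weights=[weights[d] for d in names], k=n)

# Hypothetical optimized weights (illustrative only).
weights = {"pile_cc": 0.6, "openwebtext2": 0.3, "arxiv": 0.1}
draws = sample_domains(weights, 100_000)
print(Counter(draws))  # counts roughly proportional to the weights
```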

Design Logic of the Excess Loss

Why not directly optimize the proxy model's raw loss? The paper ablates this with two variants: using only the proxy loss, which tends to upweight the hardest domains, and using only the negative reference loss, which tends to upweight the domains that are easiest for the reference model. Both underperform the full excess loss.

Excess loss = proxy loss $-$ reference loss, measuring the proxy model’s “room for improvement” relative to the reference. High excess loss typically means the example is learnable (reference loss is not high) but the proxy hasn’t learned it yet (current training is insufficient). Low excess loss can arise from two causes: the example is inherently high-entropy and unlearnable (reference loss is also high), or the example is so easy the proxy has already learned it. Neither case needs extra weighting.

This design prevents DRO from wasting weight on inherently unlearnable noisy domains, and also from wasting it on domains that are already saturated.
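The ablated scoring rules and the two low-excess cases can be illustrated with three made-up examples, each with hypothetical per-token losses in nats. The sketch works per example rather than per domain for simplicity. It asks which example each rule would upweight most: raw proxy loss, negative reference loss, or clipped excess loss.

```python
# Hypothetical per-token losses (nats) for three illustrative examples.
examples = {
    "learnable, not yet learned": {"proxy": 3.0, "ref": 1.0},
    "high-entropy noise": {"proxy": 6.0, "ref": 5.9},
    "already learned": {"proxy": 0.6, "ref": 0.9},
}

scores = {
    "proxy loss": {n: e["proxy"] for n, e in examples.items()},
    "negative reference loss": {n: -e["ref"] for n, e in examples.items()},
    # Excess loss clipped at zero, as in DoReMi's implementation details.
    "clipped excess loss": {n: max(e["proxy"] - e["ref"], 0.0) for n, e in examples.items()},
}

# For each scoring rule, report the example it would upweight most.
top = {rule: max(s, key=s.get) for rule, s in scores.items()}
for rule, name in top.items():
    print(f"{rule:>24}: {name}")
```

Raw proxy loss picks the noise example, and negative reference loss picks the already-learned one. Only the clipped excess loss picks the example that is learnable but not yet learned.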

Experiments on The Pile

An 8B model is trained with the domain weights from the 280M proxy and evaluated on per-domain perplexity across The Pile's 22 domains and on downstream tasks.

DoReMi 8B model training curve on The Pile: downstream accuracy improves by ~6.5 points, reaches baseline accuracy ~2.6x faster

The figure above shows the 8B model's downstream one-shot accuracy vs. training steps on The Pile. DoReMi leads the baseline throughout training and finishes about 6.5 points ahead; it matches the baseline's 200k-step accuracy at roughly 75k steps.

  • Perplexity decreases on all domains. Even heavily downweighted domains (ArXiv, PubMed Central, StackExchange) have better perplexity than the baseline. This is the paper’s most counterintuitive finding.
  • Across 5 generative few-shot tasks (TriviaQA, NaturalQuestions, WebQuestions, SQuADv2, LAMBADA), average accuracy improves by about 6.5 points.
  • Reaching the baseline’s final accuracy requires only ~75k steps vs. 200k for the baseline, a ~2.6x speedup.

DoReMi’s domain weights raise Pile-CC (the Common Crawl subset) from 11.2% to 60.6% while sharply reducing ArXiv (10.5% → 0.4%), PubMed Central (10.7% → 0.5%), StackExchange (9.3% → 1.5%), and others.

Per-domain log-perplexity: despite downweighting some domains, DoReMi outperforms the baseline on all 22 domains

The figure above compares per-domain log-perplexity for the 8B models on The Pile. Red is the baseline, blue is DoReMi. DoReMi is lower on every single domain, including heavily downweighted ones like ArXiv and PubMed Central.

Why does perplexity drop even on downweighted domains? The paper's hypothesis: neither the lowest-entropy domains (simple, repetitive content) nor the highest-entropy domains (close to a uniform distribution) need many samples to learn well. Allocating more budget to medium-entropy domains (like diverse web text) produces positive transfer that benefits all domains.

Experiments on the GLaM Dataset

GLaM has only 8 domains, and downstream-tuned weights for it are already available. DoReMi starts from uniform initial weights and runs 3 rounds of iterated DoReMi (each round uses the previous round's output weights as the training weights for the new reference model). The final weights follow a similar pattern to the downstream-tuned weights: Filtered Webpages gets the highest weight (0.51 vs. 0.42 downstream-tuned), Conversations and Books come next, and Forums and News are suppressed.

In downstream accuracy, the iterated DoReMi round-2 8B model performs comparably to the model trained with downstream-tuned domain weights.
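Iterated DoReMi is an outer loop around the reference-then-proxy search. The minimal sketch below assumes a hypothetical `run_round` callable that would train a reference model on the given weights, run the DRO proxy, and return the time-averaged weights. The L1 stopping tolerance is also an assumption of this sketch. To let the loop run without any training, the toy `toy_round` stand-in simply moves the weights halfway toward a fixed target.

```python
from typing import Callable

Weights = list[float]

def iterated_doremi(init: Weights, run_round: Callable[[Weights], Weights],
                    max_rounds: int = 10, tol: float = 1e-2) -> tuple[Weights, list[Weights]]:
    """Repeat reference -> proxy rounds until the domain weights stop changing."""
    ref_weights = list(init)
    history = [ref_weights]
    for _ in range(max_rounds):
        new_weights = run_round(ref_weights)
        history.append(new_weights)
        change = sum(abs(a - b) for a, b in zip(new_weights, ref_weights))
        ref_weights = new_weights  # the next round's reference trains on these
        if change < tol:
            break
    return ref_weights, history

# Hypothetical stand-in for one full DoReMi round (no real training).
target = [0.6, 0.3, 0.1]

def toy_round(w: Weights) -> Weights:
    return [0.5 * wi + 0.5 * ti for wi, ti in zip(w, target)]

final, history = iterated_doremi([1 / 3] * 3, toy_round)
print(len(history) - 1, [round(x, 3) for x in final])
```

The toy needs six rounds only because its contraction rate is arbitrary. On GLaM the paper reports convergence within 3 rounds.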

Cross-Scale Transfer and Proxy Model Size

The paper examines how proxy model scale affects the final large model. Fixing the large model at 8B and sweeping proxy size from 70M to 1B:

| Proxy → main model | Worst-case log-ppl | Avg log-ppl | Domains beating baseline |
|---|---|---|---|
| Baseline (8B) | 1.71 | 1.64 | 0/22 |
| 70M → 8B | 1.63 | 1.53 | 22/22 |
| 150M → 8B | 1.56 | 1.52 | 22/22 |
| 280M → 8B | 1.46 | 1.40 | 22/22 |
| 1B → 8B | 1.58 | 1.54 | 22/22 |

280M is the sweet spot. The 1B proxy actually underperforms 280M; the paper attributes this to Group DRO’s optimizer degrading at larger model scales—the 1B proxy model itself is notably worse than a same-size standard baseline (19/22 domains worse), while the 280M proxy is better than its baseline (19/22 domains better). Nonetheless, even with a poor-quality 1B proxy, the resulting domain weights still give the 8B main model a 2x speedup to baseline accuracy.

A separate set of experiments matches proxy and main model sizes (280M/510M/760M/1B). DoReMi consistently improves downstream accuracy by about 2 points and achieves ~4x speedup to baseline accuracy across all scales.

Limitations

Domain granularity. DoReMi’s effectiveness depends directly on how domains are defined: The Pile with 22 domains shows large gains; GLaM with only 8 domains shows much smaller improvements. If data lacks pre-existing domain labels (e.g., a large Common Crawl dump), DoReMi cannot be applied directly. The paper discusses the possibility of discovering fine-grained domains via clustering but doesn’t experimentally validate this.

DRO degrades for large proxy models. From 280M to 1B proxy, domain weight quality drops. The paper hypothesizes a mismatch between DRO’s loss reweighting during proxy training and the resampling used for the main model: the proxy sees weighted gradients per batch, while the main model simply trains on resampled data. The paper observes this mismatch is more severe at 1B scale than 280M but doesn’t provide a complete theoretical explanation.
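One way to see why the reweighting-versus-resampling distinction is subtle: in expectation, both target the same weighted loss, but resampling changes how many examples each domain contributes to a given batch. The sketch below isolates only that sampling effect and does not reproduce the paper's 1B-scale observation. It holds per-domain losses fixed at hypothetical values and uses a hypothetical weight vector with one domain at 0.4%, the scale of ArXiv's DoReMi weight.

```python
import random

def compare_reweight_vs_resample(domain_losses, alpha, batch_size=64, n_batches=20000, seed=0):
    rng = random.Random(seed)
    k = len(alpha)
    # Reweighting (proxy-style): each domain's loss is scaled by alpha_i.
    reweighted = sum(a * l for a, l in zip(alpha, domain_losses))
    # Resampling (main-model-style): batches are drawn in proportion to alpha.
    rarest = min(range(k), key=lambda i: alpha[i])
    batch_means, missing_rarest = [], 0
    for _ in range(n_batches):
        idx = rng.choices(range(k), weights=alpha, k=batch_size)
        batch_means.append(sum(domain_losses[i] for i in idx) / batch_size)
        missing_rarest += rarest not in idx
    return reweighted, sum(batch_means) / n_batches, missing_rarest / n_batches

# Hypothetical per-domain losses and weights.
reweighted, resampled_avg, frac_missing = compare_reweight_vs_resample(
    domain_losses=[1.0, 3.0, 5.0], alpha=[0.6, 0.396, 0.004])
print(f"reweighted objective: {reweighted:.3f}")
print(f"mean resampled batch loss: {resampled_avg:.3f}")
print(f"batches with no example from the 0.4% domain: {frac_missing:.2f}")
```

The two estimates agree closely, yet roughly three quarters of resampled batches of size 64 contain no example from the 0.4% domain (0.996^64 ≈ 0.77). By contrast, a uniformly sampled batch like the proxy's contains about 21 examples from it on average, each down-weighted by α.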

Non-trivial compute cost. While reference plus proxy training totals only about 8% of the large model's compute, in the paper's setup the proxy runs for the full number of main-model training steps (200k). The paper observes that domain weights change most in early training and stabilize later, suggesting early stopping with extrapolation may be possible, but doesn't rigorously validate this.
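As a rough sanity check on the ~8% figure, assume the common approximation that training compute is about $6NT$ FLOPs for $N$ parameters and $T$ tokens. Also assume, as part of this sketch, that the reference and proxy models see the same number of tokens $T$ as the 8B model:

$$\frac{C_{\text{ref}} + C_{\text{proxy}}}{C_{\text{main}}} \approx \frac{6 \cdot 0.28\,\text{B} \cdot T + 6 \cdot 0.28\,\text{B} \cdot T}{6 \cdot 8\,\text{B} \cdot T} = \frac{0.56}{8} = 7\%$$

This lands near the quoted ~8%, since the approximation ignores details such as exact parameter counts. It also makes the scaling explicit: the search cost is proportional to the small models' token budget, so stopping the proxy at half the steps would roughly halve the proxy's share.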

Single DRO run may reach local optima. The 280M and 1B proxies yield different domain weight patterns (the former favors Pile-CC, the latter favors OpenWebText2), suggesting multiple local optima in the domain weight space. Iterated DoReMi works on GLaM (converging in 3 rounds), but multi-round iteration on The Pile is not reported.

Summary

DoReMi frames pretraining mixture optimization as a minimax problem over excess loss, solves it with Group DRO on a small model, and transfers the resulting domain weights to a large model. Subsequent work has largely followed this “small model proxies for large model mixture search” paradigm (RegMix replaces DRO with regression, DoGE uses gradient signals instead of loss, Online Data Mixing does online adjustment). The key empirical finding DoReMi established: domain weights optimized on a 280M model can directly accelerate 8B model training, and in the Pile experiments the improvement covers all 22 domains—every domain’s perplexity beats the baseline.