MRL: Nest Embeddings During Training, Truncate at Inference

The core idea of MRL (Matryoshka Representation Learning, 2022): during training, force the first $m$ dimensions of the same $d$-dimensional vector to carry the classification loss on their own, for every $m$ on a logarithmic ladder $\{8, 16, 32, \dots, d\}$, producing a coarse-to-fine nested representation. At inference, just take the first $m$ dimensions according to the compute budget; accuracy matches a separately trained $m$-dim model.

MRL applies losses on O(log d) nested dimensions during training; at inference the first m dims are taken from the same vector

Background

Vector representations are infrastructure in modern ML systems. A single forward pass produces a $d$-dim embedding that downstream classification and retrieval both consume. But downstream tasks have very different compute and accuracy needs. Deploying everything at $d = 2048$ wastes compute on cheap recall tasks, and may still be insufficient on long-tail ones.

Existing remedies all have weaknesses. Training multiple low-dim models independently (FF, the fixed-feature baseline) means storing several sets of model weights and several indices, and adding a new tier means retraining. Post-hoc compression (PCA/SVD) and random feature selection lose accuracy badly at low dims. Sub-network methods like slimmable networks still need a separate forward pass per granularity, which for retrieval means re-encoding the entire database. The paper sets out to resolve this tension: get a set of embeddings spanning multiple compute tiers from a single forward pass, with each tier matching an independently trained model of the same size.

Method

Let $M \subseteq [d]$ be the set of nesting dimensions; the paper uses $M = \{8, 16, 32, 64, 128, 256, 512, 1024, 2048\}$, i.e. $|M| \approx \log d$ sizes on a log scale. A model $F$ maps input $x$ to a $d$-dim vector $z = F(x)$. For each $m \in M$, take the first $m$ dimensions $z_{1:m}$, attach an independent linear classifier head $W^{(m)} \in \mathbb{R}^{L \times m}$, and apply softmax cross-entropy $\mathcal{L}_{\text{CE}}$. The $|M|$ losses are summed (with weights), and $F$ together with all $W^{(m)}$ is jointly optimized:

$$\mathcal{L}_{\text{MRL}} = \frac{1}{N} \sum_{i=1}^{N} \sum_{m \in M} c_m \cdot \mathcal{L}_{\text{CE}}\bigl(W^{(m)} z_{i,1:m},\, y_i\bigr)$$

where $z_i = F(x_i)$ and $c_m$ defaults to 1 for all $m$. The paper discusses non-uniform weighting in ablations but does not tune it in the main experiments.
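To make the objective concrete, here is a minimal PyTorch sketch (my own, not the authors' release): one independent linear head per nesting size, cross-entropy on each prefix of the same embedding, and a weighted sum of the $|M|$ losses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MRLHeads(nn.Module):
    def __init__(self, d=2048, num_classes=1000,
                 nesting=(8, 16, 32, 64, 128, 256, 512, 1024, 2048)):
        super().__init__()
        self.nesting = nesting
        # One classifier per granularity; head m only ever sees the first m dims.
        self.heads = nn.ModuleList(nn.Linear(m, num_classes) for m in nesting)

    def forward(self, z, y, weights=None):
        # z: (B, d) embeddings from the backbone F(x); y: (B,) integer labels.
        if weights is None:
            weights = [1.0] * len(self.nesting)  # c_m = 1 for all m, as in the paper
        loss = 0.0
        for c_m, m, head in zip(weights, self.nesting, self.heads):
            loss = loss + c_m * F.cross_entropy(head(z[:, :m]), y)
        return loss
```

In use, `z` would be the pooled output of the backbone (e.g. ResNet50 with its final fc layer removed), and the summed loss is backpropagated through the backbone and all heads jointly.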

A direct parameter-saving variant called MRL-E (Efficient MRL) ties all classifier heads to a single weight matrix $W \in \mathbb{R}^{L \times d}$, with $W^{(m)} = W_{1:m}$ (the first $m$ columns of $W$). This cuts classifier-head parameters from $\sum_{m \in M} L \cdot m \approx 2L \cdot d$ down to $L \cdot d$, which matters when $L$ is huge (think ImageNet-21K or JFT-scale label spaces). In MLM training, BERT’s output projection already shares weights with the embedding matrix, so MRL on BERT is automatically the MRL-E form.
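A corresponding sketch of MRL-E under the same assumptions: one shared matrix `W` whose first $m$ columns serve as the head for tier $m$, so the nesting adds no classifier parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MRLESharedHead(nn.Module):
    """Hypothetical MRL-E head: W^(m) is the first m columns of one shared W."""
    def __init__(self, d=2048, num_classes=1000,
                 nesting=(8, 16, 32, 64, 128, 256, 512, 1024, 2048)):
        super().__init__()
        self.nesting = nesting
        self.W = nn.Parameter(torch.randn(num_classes, d) * 0.01)  # L x d, shared

    def forward(self, z, y):
        loss = 0.0
        for m in self.nesting:
            logits = z[:, :m] @ self.W[:, :m].t()  # same columns reused at every tier
            loss = loss + F.cross_entropy(logits, y)
        return loss
```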

Adapting to contrastive learning is similarly direct: apply MRL to both contrasted sides, and do $L_2$ normalization independently per nesting dimension. Appendix C records this implementation detail.
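A hedged sketch of that adaptation for an InfoNCE-style two-tower setup (the function name, symmetric loss, and temperature are illustrative choices, not the paper's): both sides are sliced to each nesting size and $L_2$-normalized per tier before the similarity logits are formed.

```python
import torch
import torch.nn.functional as F

def matryoshka_contrastive_loss(z_a, z_b,
                                nesting=(8, 16, 32, 64, 128, 256, 512, 1024, 2048),
                                temperature=0.07):
    # z_a, z_b: (B, d) un-normalized embeddings of the two contrasted sides.
    targets = torch.arange(z_a.size(0), device=z_a.device)
    loss = 0.0
    for m in nesting:
        u = F.normalize(z_a[:, :m], dim=-1)  # normalize each prefix independently,
        v = F.normalize(z_b[:, :m], dim=-1)  # not a slice of the full-dim unit vector
        logits = u @ v.t() / temperature
        loss = loss + 0.5 * (F.cross_entropy(logits, targets) +
                             F.cross_entropy(logits.t(), targets))
    return loss
```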

Why this nesting works

This setup is somewhat counterintuitive: the representation scored at a low dim is a strict prefix of the one scored at a higher dim. A priori, the model is free to spread information across all 2048 dims and has no reason to make “the first 8” usable on their own. By imposing the classification loss explicitly on the first $m$ dims, MRL forces the model to push coarse-grained discriminative information to the front of the vector: $z_{1:8}$ has to be fed directly to a linear classifier that separates 1000 classes, which makes it function as a low-dim bottleneck.

What is more surprising is that MRL only optimizes $O(\log d)$ nested sizes, yet accuracy at intermediate dimensions that were never explicitly constrained (e.g. 24, 48, 96) interpolates smoothly between the neighboring trained sizes. Figure 5 shows this: ResNet50-MRL evaluated by 1-NN at 12, 24, 48 (untrained sizes) gives a curve that is essentially continuous with the trained sizes. Information really is stacked dimension by dimension, not packed in jumps at $\{8, 16, 32, \dots\}$.

Experimental results

ImageNet-1K classification (ResNet50). Figure 2 shows linear-probe top-1: MRL matches independently trained FF at every size, and MRL-E is within 1% from 16 dims onward. Figure 3 (1-NN accuracy) shows a bigger gap, with MRL beating FF by 1–2 points at low dims (8, 16, 32). FF trains 9 separate models, one per tier; MRL trains once and matches or surpasses them.

Web-scale. ViT-B/16 on JFT-300M, and the vision encoder of ALIGN (ViT-B/16 + BERT contrastively trained on ALIGN data), can both be wrapped in MRL directly. Table 4 (1-NN top-1): ALIGN at 12 dims is only 11.90, while ALIGN-MRL is 43.57 — a huge gap that closes by 192 dims (66.71 vs 67.00). JFT-ViT 12 dims 27.07 → JFT-ViT-MRL 53.61. The big low-dim gaps reflect that the ALIGN/JFT baselines were never trained with low-dim heads; the comparison takes random features from the high-dim model.

BERT MLM. Table 7 shows BERT-FF and BERT-MRL validation accuracy at every tier within 0.5% of each other (768 dims: 65.54 vs 65.00). MRL’s advantage on BERT is less dramatic than in vision; the value is mostly in getting multiple tiers from one model.

ImageNet-1K retrieval mAP@10 (Table 8). At Ds=8: FF 53.42, MRL 56.74; at Ds=16: FF 61.63, MRL 62.94; at Ds=2048: FF 62.90, MRL 65.20. MRL matches or beats FF at every size, with a clearer edge at low dims. One notable observation: FF mAP@10 actually drops as Ds rises from 64 to 1024 (63.26 → 61.13), while MRL grows monotonically to 65.20. The paper offers no specific explanation, but this reflects that FF, trained independently per tier, only optimizes classification loss with no constraint across tiers, so its higher-dim representations are not necessarily monotonically better for retrieval.

Adaptive Classification and Adaptive Retrieval

With 9 usable tiers from 8 to 2048, deployment doesn’t need to commit every sample to 2048 dims.

AC. On a holdout validation set, learn a max-softmax-probability threshold $t_m$ per tier: start at $m = 8$, accept the prediction if confidence exceeds $t_8$, otherwise climb to $m = 16$, and so on up to 2048. The paper grid-searches $t_m$ over 100 points in $[0, 1]$ on a 10K subset of ImageNet-1K validation. At expected dim $\approx 37$, ImageNet-1K top-1 reaches 76.30%, matching FF-512 — a 14× compression. The paper also reports a more pessimistic “cumulative dim” estimate of 62 (accounting for compute through multiple heads), which still gives 8.2× compression.
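A small illustrative sketch of the cascade, assuming the per-tier heads and thresholds $t_m$ have already been trained and grid-searched as described above (the dict-based interface is my own):

```python
import torch
import torch.nn.functional as F

def adaptive_classify(z, heads, thresholds,
                      nesting=(8, 16, 32, 64, 128, 256, 512, 1024, 2048)):
    # z: (d,) one MRL embedding; heads[m]: trained linear head for the first m dims;
    # thresholds[m]: max-softmax threshold t_m found on a holdout set.
    for m in nesting:
        probs = F.softmax(heads[m](z[:m]), dim=-1)
        conf, pred = probs.max(dim=-1)
        if conf.item() >= thresholds[m] or m == nesting[-1]:
            return pred.item(), m  # predicted class and the dims actually consumed
```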

Adaptive Classification reaches FF-512 accuracy (76.30%) at about 37 dims — a 14× compression

AR. Retrieval runs in two stages: shortlist $K = 200$ candidates with a short embedding of $D_s$ dims, then rerank them with a long embedding of $D_r$ dims. Reranking 200 candidates at $D_r$ costs only about 400 KFLOPs; the dominant cost is the shortlist NN search across the full database at $D_s$. On ImageNet-1K, $D_s = 16$, $D_r = 2048$ matches single-shot $D_s = 2048$ in accuracy with a 128× FLOPs speedup and a 14× HNSW wall-clock speedup. ImageNet-4K is harder (more classes), so $D_s = 64$ is needed, giving a 32× theoretical and 6× wall-clock speedup.
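An illustrative two-stage sketch using exact brute-force search; the paper's wall-clock numbers come from an HNSW index for the shortlist stage, which this simplification omits.

```python
import torch
import torch.nn.functional as F

def adaptive_retrieve(query, db, d_s=16, d_r=2048, k=200, top=10):
    # query: (d,), db: (N, d) MRL embeddings; cosine similarity on normalized prefixes.
    q_s = F.normalize(query[:d_s], dim=-1)
    shortlist = (F.normalize(db[:, :d_s], dim=-1) @ q_s).topk(k).indices  # full-DB pass at D_s

    q_r = F.normalize(query[:d_r], dim=-1)
    cand = F.normalize(db[shortlist, :d_r], dim=-1)
    order = (cand @ q_r).argsort(descending=True)  # rerank only the K candidates at D_r
    return shortlist[order][:top]
```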

Adaptive Retrieval mAP@10 vs MFLOPs trade-off on ImageNet-1K/4K — every (Ds, Dr) combo lies above the single-shot Pareto frontier

Funnel Retrieval. AR still requires hand-picking $D_s$ and $D_r$. The paper proposes a funnel: shortlist at the smallest dim, then repeatedly rerank, halving the shortlist while increasing the dim, e.g. 200 → 100 → 50 → 25 → 10 candidates reranked at 16 → 32 → 64 → 128 → 256 → 2048 dims. On ImageNet-1K, funnel matches single-shot 2048-dim top-1 within 0.1% at a 128× FLOPs saving, a cascade strategy that avoids manual hyperparameter choice.
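A sketch of the funnel under one reading of the quoted schedule (my assumption: the final 2048-dim pass only reorders the last 10 candidates rather than pruning further):

```python
import torch
import torch.nn.functional as F

def funnel_retrieve(query, db, dims=(16, 32, 64, 128, 256, 2048),
                    shortlist=(200, 100, 50, 25, 10)):
    # Initial shortlist over the whole database at the smallest dim.
    d0 = dims[0]
    sims = F.normalize(db[:, :d0], dim=-1) @ F.normalize(query[:d0], dim=-1)
    cand = sims.topk(shortlist[0]).indices
    # Rerank the shrinking candidate set at progressively larger dims;
    # once the shortlist schedule is exhausted, later dims only reorder.
    for i, d in enumerate(dims[1:]):
        sims = F.normalize(db[cand, :d], dim=-1) @ F.normalize(query[:d], dim=-1)
        k = shortlist[i + 1] if i + 1 < len(shortlist) else cand.numel()
        cand = cand[sims.topk(k).indices]
    return cand
```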

Easily overlooked details

Information interpolation. MRL only computes losses at 9 sizes, but intermediate dims (12, 24, 48, etc.) are usable too with continuous accuracy. Table 8: MRL-Interpolated at Ds=12 gives mAP@10 = 60.84, between 8-dim 56.74 and 16-dim 62.94. So you can slice at any integer dimension, not just powers of 2.

Normalization in contrastive learning. Each nested size needs its own $L_2$ normalization. Normalizing $z_{1:m}$ directly treats each tier as an independent unit vector, as opposed to taking the first $m$ dims of $z / \lVert z \rVert$ (which would have norm less than 1). Appendix C records this in the implementation notes.

$c_m = 1$ for all $m$ is a simplification. Appendix K Table 27 ablates this and shows that giving low dims a higher $c_m$ slightly improves low-dim accuracy without hurting high dims. The main experiments use 1 throughout to avoid introducing tunable hyperparameters. If you care about extreme low-dim accuracy, this is a tunable knob.

The smallest granularity cannot be too low. Appendix K Table 28 shows that starting at 4 dims drags down overall performance; Table 29 confirms log spacing beats linear spacing. These reflect MRL’s inductive bias: avoid bottleneck dims that are too narrow, because classification at such a dim is essentially unsolvable, and forcing a loss there degrades the rest of the nested representation.

Robustness and superclass. MRL is at least as robust as FF on ImageNet-V2/R/A/Sketch, and gains 0.6% on ImageNet-A. On 31-way superclass classification (Figure 10), MRL at 8 dims already reaches 85.57% top-1 (slightly above FF at the same dim), far above the 1-NN accuracy on the 1000-way fine-grained task. Low-dim bottlenecks naturally suit coarse-grained routing: judge superclass with 8 dims first, then refine with higher dims.

FLUID long-tail. On the FLUID long-tail sequential learning benchmark, MRL outperforms FF by 2% on novel/tail classes (Table 16) without losing on head classes. The paper does not give a specific explanation for this gain.

Disagreement across dimensions. The same image can produce different predictions at different dims. With ideal routing (oracle), MRL top-1 could go up another 4.6%. This means MRL doesn’t learn a monotonic “higher dim = better” — different samples have different preferences across dims. The paper lists this as future work without proposing a routing scheme.

Summary

MRL’s design idea is to use $O(\log d)$ nested losses to push the model into arranging information from coarse to fine within a single vector. One training run gives 9 usable dimension tiers, each as good as an independently trained FF model. Combined with AC/AR, this delivers 14× dim compression for classification and 128× FLOPs / 14× wall-clock speedups for retrieval. The cost lies in training: 9 classifier-head losses per batch, but the paper makes this work on ResNet50/ViT/BERT/ALIGN without hyperparameter tuning. Open issues: the default $c_m = 1$ is not Pareto optimal, using different losses at different fidelities (a robust loss at high dims, a high-recall loss at low dims) was left out of the main experiments, and routing relies on a simple max-softmax threshold heuristic. Still, as a near-zero-cost multi-granularity training change, MRL’s engineering value has been validated by later models like OpenAI text-embedding-3 and Nomic Embed; it is now the standard approach to variable-dimension embeddings.