MRL:训练时套娃,推理时按需截维
MRL(Matryoshka Representation Learning,2022)的核心思路:训练时让同一个 d 维向量的前 m 个维度($m$ 取 ${8, 16, 32, \dots, d}$ 这一组对数刻度上的值)独立承担分类损失,得到一个由粗到细嵌套的表示。推理时按算力预算只取前 $m$ 维,效果与单独训练一个 m 维模型相当。

问题背景
向量表示是现代 ML 系统的基础设施,一次 forward 算出 $d$ 维 embedding,下游分类、检索都消费这同一个向量。但下游任务的算力和精度需求差别很大,把所有任务都按 $d = 2048$ 部署,对召回类的廉价任务是浪费,对长尾任务又可能不够。
现有解法各有短板。独立训练多个低维模型(FF,fixed feature)要存多份模型权重和多份索引,且每加一档就要重训。事后压缩(PCA/SVD)和随机选维在维度变低时掉点严重。slimmable networks 之类子网方法切换粒度时仍然需要不同的 forward pass,对检索场景意味着重新编码整个数据库。论文要解决的就是这个矛盾:用同一个 forward pass 拿到一组对应不同算力档位的 embedding,且每一档都不输给独立训练的同尺寸模型。
方法
设嵌套维度集合 $M \subseteq [d]$,论文用 $M = {8, 16, 32, 64, 128, 256, 512, 1024, 2048}$,即对数刻度上 $|M| \approx \log d$ 个尺寸。模型 $F$ 把输入 $x$ 映射到 $d$ 维向量 $z = F(x)$,对每个 $m \in M$,截取前 $m$ 维 $z_{1:m}$,配一个独立的线性分类头 $W^{(m)} \in \mathbb{R}^{L \times m}$,套上 softmax 交叉熵 $\mathcal{L}_{\text{CE}}$。把 $|M|$ 个损失加权求和,对 $F$ 与所有 $W^{(m)}$ 联合优化:
$$\mathcal{L}_{\text{MRL}} = \frac{1}{N} \sum_{i=1}^{N} \sum_{m \in M} c_m \cdot \mathcal{L}_{\text{CE}}\bigl(W^{(m)} z_{i,1:m},, y_i\bigr)$$
其中 $z_i = F(x_i)$,权重 $c_m$ 默认全部为 1。论文在消融里讨论了非均匀加权的潜力但主实验没调。
一个直接的省参变体叫 MRL-E(Efficient MRL):让所有分类头共享同一组权重 $W \in \mathbb{R}^{L \times d}$,第 $m$ 档头就取 $W^{(m)} = W_{1:m}$。这把分类头的参数从 $\sum_{m \in M} L \cdot m \approx 2L \cdot d$ 砍到 $L \cdot d$,对类别数 $L$ 极大的场景(比如 ImageNet-21K,JFT 那种百万级标签空间)是必要的。MLM 训练里 BERT 的输出投影本就和 embedding 矩阵共享权重,所以 BERT 上 MRL 自动就是 MRL-E 形式。
对比学习里的适配也是直接的:对要被对比的两端都套 MRL,每个嵌套维度独立做 $L_2$ 归一化(论文强调归一化必须在每档维度上分别做,否则效果会差)。
为什么这个嵌套能 work
直觉上这个设定有点反常:低维度的损失是高维度损失的 strict subset,按理高维度应该任意学,把信息分散到 2048 维上即可,没动力把"前 8 维"专门做成可用的表示。MRL 通过对前 $m$ 维显式施加分类损失,强迫模型把粗粒度的判别信息压到向量的开头。$z_{1:8}$ 必须能直接喂给一个线性分类器把 1000 类分开,这就把它逼成了一个低维瓶颈表示。
更出乎意料的是 MRL 只优化 $O(\log d)$ 个嵌套尺寸,但中间那些没被显式约束的维度(比如 24,48,96)的精度竟然在两端之间平滑插值。论文 Figure 5 给出了这一点:ResNet50-MRL 在 12,24,48 这些没训练过的尺寸上取 1-NN 精度,曲线和训练过的尺寸基本连续。这意味着信息确实是逐维度堆叠的,而不是跳跃式地塞进 ${8, 16, 32, \dots}$ 这几个刻度里。
实验结果
ImageNet-1K 分类(ResNet50)。Figure 2 给出 linear probe top-1,MRL 在每个尺寸上都不输独立训练的 FF,MRL-E 从 16 维起在 1% 之内。1-NN 精度 Figure 3 的差距更大,MRL 在低维度(8,16,32)上比 FF 高 1 到 2 个点。FF 自己把每档当独立任务训了一遍 9 个模型,MRL 只训一次就追平甚至反超。
Web-scale。ViT-B/16 在 JFT-300M 上,以及 ALIGN 的视觉编码器(ViT-B/16 + BERT 在 ALIGN 数据上对比训练),都能直接套 MRL。Table 4 的 1-NN top-1 上,ALIGN 在 12 维只有 11.90,ALIGN-MRL 是 43.57,差距巨大;到 192 维收敛到 66.71 vs 67.00。JFT-ViT 12 维 27.07 → JFT-ViT-MRL 53.61。低维差距大是 ALIGN/JFT 没有专门训过低维 head 的缘故,对照的 baseline 是从高维 ALIGN/JFT 模型上随机选维。
BERT MLM。Table 7 给出 BERT-FF 和 BERT-MRL 在每档维度上的验证准确率,差距都在 0.5% 以内(768 维 65.54 vs 65.00)。BERT 上 MRL 的优势不像视觉那么明显,主要价值在于一个模型出多档维度。
ImageNet-1K 检索 mAP@10(Table 8)。Ds=8 时 FF 53.42,MRL 56.74;Ds=16 时 FF 61.63,MRL 62.94;Ds=2048 时 FF 62.90,MRL 65.20。MRL 在所有维度上都不输 FF,低维优势更明显。一个值得注意的现象:FF 的 mAP@10 在 Ds 从 64 升到 1024 反而下降(63.26 → 61.13),MRL 则单调上升到 65.20。论文未对此专门解释,但这反映出 FF 每档独立训练时只优化分类损失,多档之间没有约束,得到的高维表示对检索任务并非线性更优。
Adaptive Classification 与 Adaptive Retrieval
有了从 8 到 2048 共 9 档可用的表示,部署时不必所有样本都用 2048 维。
AC。在 holdout 验证集上对每档分类头学一个 max softmax probability 阈值 $t_m$:从 $m = 8$ 开始预测,置信度大于 $t_8$ 就停下用这一档结果,否则升到 $m = 16$,依次到 2048。论文用 ImageNet-1K 验证集 10K 条样本网格搜索 $t_m$(0 到 1 取 100 个点)。期望维度 $\approx 37$ 时 ImageNet-1K top-1 达到 76.30%,与 FF-512 同精度,等于 14× 压缩。论文同时给出了"按累积维度算"的更悲观估计 62 维(因为多档头的累计计算量),仍是 8.2× 压缩。

AR。检索按两阶段做:用 $D_s$ 维(短)embedding 取 K=200 候选,再用 $D_r$ 维(长)embedding 重排。$D_r$ 重排只在 200 个候选上跑,只占 400 KFLOPs,主要成本在 $D_s$ 那一步对全库的近邻搜索。ImageNet-1K 上 $D_s = 16,, D_r = 2048$ 与 single-shot $D_s = 2048$ 同精度,FLOPs 上 128× 加速,HNSW 实测 wall-clock 14× 加速。ImageNet-4K 因为类多更难,要 $D_s = 64$ 才能保持精度,对应 32× 理论加速和 6× 实测加速。

Funnel Retrieval。AR 要手选 $D_s$ 和 $D_r$,论文又提出 funnel:每一步把候选数减半,把维度翻倍。比如 200 → 100 → 50 → 25 → 10 候选,配 16 → 32 → 64 → 128 → 256 → 2048 维。ImageNet-1K 上 funnel 与 single-shot 2048 维相比 top-1 在 0.1% 以内,128× FLOPs 节省。这给出一种不依赖手调超参的级联策略。
几个易被忽视的细节
信息插值。MRL 只对 9 个尺寸算损失,但中间维度(12,24,48 等)也能用,且精度连续。Table 8 里 MRL-Interpolated 行 Ds=12 时 mAP@10 是 60.84,介于 8 维 56.74 和 16 维 62.94 之间。这意味着可以按任意整数维度切,不必拘泥于 2 的幂次。
对比学习里的归一化。每个嵌套尺寸都要独立做 $L_2$ 归一化。直接对 $z_{1:m}$ 做 normalize 等价于把每档当成独立单位向量,而不是从 $z / |z|$ 截前 $m$ 维(后者长度小于 1)。论文 Appendix C 把这一点放进了实现说明里。
c_m 全部设为 1 是简化选择。Appendix K Table 27 的消融显示,给低维更高的 $c_m$ 能小幅改善低维精度而不伤高维,但论文主实验全用 1,因为想避免引入需要调的超参。后续工作如果在意低维档的极致精度,是一个可调的钩子。
granularity 起点不能太低。Appendix K Table 28 给出,从 4 维起步会拖累整体;Table 29 验证了对数间距比线性间距好。这两点反映 MRL 的归纳偏置:要避开太瘦的瓶颈维度,因为分类任务在那个维度本身就不可解,强加损失会污染上层。
鲁棒性与超类。MRL 在 ImageNet-V2/R/A/Sketch 上的鲁棒性不输 FF,ImageNet-A 上还高 0.6%。31 类超类分类(Figure 10)上 MRL 在 8 维 top-1 已经达到 85.57%(同维度 FF 略低),远高于 1000 类细粒度的 1-NN 精度。说明低维瓶颈天然适合做粗粒度路由:先用 8 维判超类,再用更高维做细分。
FLUID 长尾。在 FLUID 长尾顺序学习里,MRL 在尾部新类上比 FF 高 2%(Table 16),head 类上不输。论文未给出这一收益的具体解释。
Disagreement 现象。同一张图,不同维度可能给出不同预测。理想路由(已知 oracle)下,MRL 模型 top-1 还能再涨 4.6%。这反映出 MRL 学到的不是单调的"维度越高越准",而是在不同维度上对不同样本有不同偏好。论文把它列为未来方向,没给路由方案。
小结
MRL 的设计思想是用 $O(\log d)$ 个嵌套损失逼模型把信息按粗到细排进向量。一次训练得到 9 档可用维度,每档不输独立训的 FF。配 AC/AR 之后,分类有 14× 维度压缩,检索有 128× FLOPs 加速和 14× wall-clock 加速。代价主要在训练阶段:每个 batch 要算 9 个分类头的损失,但在 ResNet50/ViT/BERT/ALIGN 这些主流架构上论文都没专门调超参就跑通了。MRL 没有解决的问题包括:嵌套损失权重默认全 1 而非 Pareto 最优,不同维度给不同损失(高维 robust 损失,低维高 recall 损失)的设想没在主实验里做,路由策略只用了 max softmax 阈值这种简单 heuristic。但作为一种几乎零成本的多粒度训练改造,MRL 的工程价值已经被后来的 OpenAI text-embedding-3 和 Nomic Embed 等模型证明,是当前可变维度 embedding 的标准做法。