GritLM:用一个 LLM 同时做 embedding 和生成

GritLM(Generative Representational Instruction Tuning,2024)的核心思路:用同一个 LLM 同时承担 embedding 和生成两类任务,通过 instruction 格式区分输入属于哪条流,分别用对比损失和语言建模损失训练,两个目标相加。 之前 HyDE 展示了"LLM 负责相关性,encoder 负责相似度"是可以解耦的,GritLM 进一步把两者合回一个模型。

问题背景

embedding 模型和生成模型一直是两条独立路线。BERT 类双向 encoder 适合做表示,decoder-only LLM 适合做生成,把 LLM 直接拿来取 hidden state 作为 embedding 一般效果不好。论文给的对照是 Llama 2 70B 用 weighted-mean pooling 在 MTEB 上只拿到 35.6,而 BGE Large 0.34B 是 64.2,参数差两个数量级却被反超。

反方向同样不通。把 embedding 模型 fine-tune 之后再加回 LM head 做生成,论文实测 Emb.-only 7B 在 MMLU 上是 23.5(随机基线 25.0),意味着 embedding 训练把生成能力彻底洗掉了。

实际部署里这导致 RAG 系统要同时持有两个模型:embedding 模型编码 query 和文档做检索,生成模型再读 query + 召回文档生成回答。query 和文档都要在两个模型上各跑一遍,共 4 次 forward pass。GritLM 想消除的就是这种割裂。

GritLM 用 instruction 把同一个模型路由到 embedding 或生成两条流

GRIT 方法

GritLM 架构与输入格式

一个预训练 LLM,两份训练数据,两套 forward 路径,两个损失加权求和。

embedding 流。输入格式 <s><|user|>{instruction}<|embed|>{sample},注意力切成 bidirectional,对最后一层 hidden state 做 mean pooling(仅对 sample 部分平均,instruction 和格式 token 不计入但通过 self-attention 影响表示)。损失是 in-batch negatives 的 InfoNCE:

$$\mathcal{L}_{\text{Rep}} = -\frac{1}{M} \sum_{i=1}^{M} \log \frac{\exp(\sigma(f(q_i), f(d_i)) / \tau)}{\sum_{j=1}^{M} \exp(\sigma(f(q_i), f(d_j)) / \tau)}$$

$\sigma$ 是 cosine similarity,$\tau$ 是温度。

生成流。输入格式 <s><|user|>{instruction}<|assistant|>{response}</s>,注意力保持 causal,过 LM head 做 next-token 预测,损失只在 response 部分计算。

总损失:

$$\mathcal{L}_{\text{GRIT}} = \lambda_{\text{Rep}} \mathcal{L}_{\text{Rep}} + \lambda_{\text{Gen}} \mathcal{L}_{\text{Gen}}$$

<|embed|> 这个特殊 token 是关键开关。E5 数据集的 instruction 没有固定前缀,模型必须靠这个 token 知道当前样本走对比损失而不是 LM 损失。

GritLM 7B 用 Mistral 7B 初始化,embedding batch 2048,生成 batch 256,训练 1253 步(embedding 1.36 epoch,生成 1 epoch)。GritLM 8x7B 因算力受限,embedding batch 降到 256。

为什么两个目标能共存

直觉上对比损失逼模型学统一的 sentence-level 表示,LM 损失逼模型学逐 token 的条件概率,两者目标不一致,混训理应互相伤害。论文的实测结果是 GritLM 7B 在 MTEB 上 66.8,单独训 embedding 的 Emb.-only 也是 66.8;GritLM 7B 生成平均 55.5,单独训生成的 Gen.-only 是 55.2。两个目标几乎不打架。

论文给的解释是这两类任务都要求模型深度理解自然语言,差别只在表达方式。论文进一步推测模型内部可能存在"少量参数充当开关",让最终表示要么适合 mean pooling 后做 embedding,要么适合喂给 LM head 做生成,但论文将这一点明确标注为推测(“Possibly…"),未做定位实验。

值得注意的是 MEDI2 数据集下,加生成目标后 embedding 性能反而比 embedding-only 更高,但这一现象在切换到 E5 数据集后消失,两者持平。

实验结果

MTEB(embedding)。GritLM 7B 在 56 个数据集平均 66.8,超过 E5 Mistral 7B 的 66.6 和 BGE Large 的 64.2,是当时开源模型 SOTA。GritLM 8x7B 是 65.7,比 7B 略低,论文归因于 embedding batch 从 2048 砍到 256(算力受限)。

生成任务。GritLM 7B 在 MMLU/GSM8K/BBH/TyDi QA/HumanEval/AlpacaEval 六项平均 55.5,超过 Tülu 2 7B 的 46.3 和 Mistral 7B Instruct 的 44.1,已经能压过 Llama 2 70B 的 46.4。GritLM 8x7B 平均 65.7,在论文对比的开源生成模型里最高,超过 Mixtral 8x7B Instruct 的 60.3 和 Tülu 2 70B 的 65.1。

reranking。GritLM 既能当 bi-encoder 又能当 cross-encoder。论文用 Sun et al. 的 permutation generation prompt,让生成能力对 top-10 召回结果重排,MTEB 检索平均从 57.4 提到 57.9,论文 Conclusion 写的是 16 个检索数据集中 15 个有提升(Table 3 给出的 16 项里 QuoraRetrieval 89.47→88.67 是唯一下降的)。

RAG 缓存加速。这是论文的卖点之一。传统 RAG 用两个独立模型:embedding 模型编码 query 做检索,生成模型读"query + 召回文档"出回答,query 和 doc 在两个模型上各跑一次,共 4 次 forward。GritLM 检索和生成同一组权重,embedding 阶段算出的 transformer 内部 KV states 可以直接喂给生成阶段,省掉重复 forward。这是不同模型无法做的,因为 KV 是模型内部表示,跨模型不通用。

GritLM 把 embedding 阶段的 KV 状态缓存给生成阶段复用

论文给出三种缓存策略:

  • Query Caching:embedding 阶段算 query 的 KV 时顺手缓存,生成阶段不再重新 forward query
  • Doc Caching:建索引时不仅存 doc embedding,还把每篇 doc 的 KV 一起存进索引,命中后直接喂给生成阶段
  • Query-Doc / Doc-Query Caching:两者都缓存,但因为 query 和 doc 各自缓存时没机会互相 attend,会偏离原始 RAG 的语义

在 Natural Questions 上,sample A(query 1 token,doc 4000 tokens)CPU 上 Doc Caching 5.25s vs 传统 RAG 14.18s,提速 63%;sample B(doc 1 token,query 4000 tokens)CPU 上 Query Caching 6.87s vs RAG 14.88s,提速 54%。GPU 上提速幅度小一些(30% 量级),论文解释是 GPU 本来就并行处理整个序列,缓存收益相对小。

但 Query Caching 改变了 query 的 attention 模式(embedding 时是双向,生成时模型期望 causal),论文实测 match 分数 Query Caching 从 RAG 的 30.50 掉到 25.46。Doc Caching 反而微涨到 33.38,论文的解释是文档不需要被像 query 那样彻底理解,“略微损坏"的 KV 状态对生成质量影响不大。Query-Doc 和 Doc-Query Caching 因有双重 attention 错位,分数掉到 21.63 和 18.39,接近 No RAG 的 21.00。论文的 takeaway:Query-Doc Caching 实用性受限,单边缓存才是性价比高的选择。

关键消融

注意力。把 causal LLM 在 fine-tune 阶段切成 bidirectional 后再做 mean pooling,embedding 涨了 1.8(causal+wmean 60.0 → bidirectional+mean 61.8),论文确认了"causal LLM 拿来做 embedding 应该改成双向"这一结论。改成 PrefixLM(instruction 双向 + response causal)反而掉分。

初始化。Mistral 7B > Llama 2 7B > GPT-J 6B 在 embedding 和生成上都成立。论文一个有意思的发现:预训练后直接测 embedding,GPT-J 比 Mistral 强;但 fine-tune 后 Mistral 反超。结论是 pretrained embedding 能力不能预测 fine-tuned embedding 能力,pretrained 生成能力反而更靠谱。

embedding 数据集。E5 (66.0) > MEDI2 (64.7) > MEDI (64.0)。论文将 E5 的优势归为 GPT-4 生成的 hard negative 质量更高、任务多样性更好。

生成损失粒度。token level 还是 sample level 直接影响生成长度,进而影响 AlpacaEval(已知偏好长回答)。论文最终用 mix:32 个样本内 token level,再 8 个 sub-batch 间 sample level。这个 mix 在 AlpacaEval 上是 74.7,纯 sample-level(“Mix (4 -> 64)")只有 67.6,差 7 分,对应生成中位长度 941 → 865。

in-batch negative 来源。让 negative 全部来自同一数据集 vs 任意数据集,平均分一样(66.0),但 Retrieval 子集涨 1.3。论文将原因归为同数据集内的 negative 区分难度更高,逼模型学更细的差异。

embedding batch size。从 256 涨到 4096,embedding 平均 +1.0,主要来自 15 个 retrieval 数据集,生成性能不变。

精度。整体 BF16 mixed precision 即可,但 pooling 和相似度计算必须 cast 到 FP32,否则 embedding 性能略有下降。论文未对此给出更深的理论解释,只是经验性地建议这样做。

few-shot embedding 不 work。在 instruction 后面加一个示例,整体性能下降。即使在 MEDI2 训练里塞了 5% 的 few-shot 样本,模型也没学会用。论文将这一现象简单归为"模型似乎没学会利用 few-shot 示例”,未做更深分析。

一些易被忽视的实现细节

asymmetric 任务用 one-sided instruction。E5 数据集对 retrieval 类任务只给 query 加 instruction,文档不加。这样文档只需编码一次就能跨任务复用,缓存友好。symmetric 任务训练时也是单边,但评估时按双边格式喂给模型,论文说这是合理的,因为 cosine similarity 的传递性保证了 A↔B↔C 仍然成立。

KV 缓存的存储代价。对 2,681,468 个文档用 7B 模型,KV states 总量约 30TB。论文指出这部分可以完全 offload 到磁盘,按需加载,每个样本约 12.5MB。原始 index 只是 43GB,KV cache 比 index 大三个数量级。

KTO 对齐 trade-off。KTO(Kahneman-Tversky Optimization)是一种偏好对齐方法,相比 DPO 不需要成对偏好数据,只需要每条样本的"好/坏"二元标签。论文在 GRIT 之后追加了一段 KTO 阶段,UltraFeedback 二元化数据,只训生成不训 embedding。结果 MTEB 从 66.8 微跌到 66.7,AlpacaEval 涨超过 10 分。意味着对齐阶段会缓慢侵蚀 embedding 能力,需要继续维持 embedding 训练才能保住。

$\lambda_{\text{Rep}} > \lambda_{\text{Gen}}$ 的设定。论文坚持 $\mathcal{L}_{\text{Rep}} / \mathcal{L}_{\text{Gen}} > 1$,理由是模型已经预训练过 LM 损失,对比损失是新东西需要更多学习。实际 embedding 损失下降很快,到训练后期两个损失都稳定在 1.0 附近,初始的权重差异被自动平滑掉。

embedding head 的取舍。可选加一个 4096→1024 的下投影线性层,存储省 4 倍但 embedding 平均掉 1 分。GritLM 最终未采用,留给下游用 PCA 等后处理压维。

概念层面的影响

GritLM 的论点是 embedding 和生成是同一枚硬币的两面,可以由同一个 LLM 用 instruction 区分。这是 HyDE “把相关性外包给 LLM” 思路的进一步推进:HyDE 在推理时用两个独立模型拼装,GritLM 在训练时把两者合一。代价是 fine-tune 的算力翻倍(要同时跑 embedding 和生成两个目标),但模型部署、缓存策略、reranker 复用都因此简化。

后续 LLM2Vec、Echo Embedding 等工作沿用了"causal LLM 改双向 + mean pooling 做 embedding"这条路径,但没有合并生成目标。GritLM 的双目标合一设计在工程上更激进,是否真值得多付训练成本,要看具体场景对统一模型的需求强度。

小结

GritLM 在 7B 规模上同时拿到 MTEB SOTA 和接近最佳的生成性能,证明 embedding 和生成可以在同一模型里共存而不互相伤害,并由此带来 reranker 复用和 RAG 缓存加速两项工程红利。代价是训练阶段两个目标各跑一遍 forward/backward,成本约为单目标的两倍;7B 的体量在生产 embedding 场景下也偏重,常见的 0.1-1B 模型仍有性价比优势;RAG 缓存方面,Doc Caching 需要 30TB KV 存储(对照 43GB 索引),Query Caching 提速明显但 match 分数有可见下降,单边缓存更实用。论文消融部分整理的训练技巧(双向注意力 + mean pooling、bf16 但 pooling cast FP32、in-batch negative 同源)后续被广泛沿用。