LoRA遇上MoE,大模型再也不会健忘了

最近组里同学在尝试实现LoRAMoE,意在解决大模型微调后遗忘世界知识的问题。参考的是复旦23年年底的这篇论文:"LoRAMoE: Revolutionizing Mixture of Experts for Maintaining World Knowledge in Language Model Alignment"[1],将LoRA和MoE方法做了组合,有效解决了上述问题,本文详细讲解一下这篇论文。

  1. 背景

大模型经过大量语料的无监督预训练后,得到所谓的基座模型,这时候通常还不能很好地完成下游任务,需要经过有监督的微调(SFT)后才能和人类指令对齐,释放其全部潜力。

一般来说,SFT的训练数据不需要太多,但当下游任务增多或者需要强化特定任务的性能时,增加SFT训练数据还是有必要的。如下图的左侧部分,当SFT数据从100K提升到3M时,大部分任务的性能显著增强。

但随着SFT数据的大规模增加,新的问题出现了:如下图的右侧部分所示,在某些评测数据集上性能显著下降,与之相伴的是大模型的参数变化量剧增(见红色线段)。这些数据集属于闭卷问答任务(Closed-Book Question Answering,简称CBQA),即只给大模型输入问题,大模型主要依靠在预训练过程中习得的世界知识来给出答案。

这里补充一下:像TriviaQA、Natural Questions这类数据其实是包含问题相关上下文的,也就是说如果用作开卷问答任务,则输入不仅包括问题还包括上下文,大模型可以从上下文中总结出答案;但如果用作闭卷问答任务,则输入中不提供上下文。论文[2]中6.1节有提到。

我们有理由怀疑CBQA的性能下降与大模型世界知识的崩溃有关,下面将通过实验证明这一点。首先验证CBQA的推理依赖大模型的世界知识,其次证明CBQA数据集上性能的大幅下降归因于大规模微调会显着改变模型参数,导致世界知识的破坏,即发生知识遗忘。

  1. 大规模微调导致世界知识破坏

下面详细介绍一下实验过程,即做实验证明大规模SFT导致大模型的世界知识严重受损,引起知识遗忘。

2.1 实验设计


数据集

准备了7种任务的数据集,分别是CBQA(闭卷问答)、coreference resolution(指代消解)、NLI(自然语言推理)、summarization(文本摘要)、multi-lingual translation(多语言翻译)、reading comprehension(阅读理解)、text classification(文本分类)。具体数据集见下图:

基座模型

采用LLaMA-2-7B作为基座模型,属于在学术界非常流行的LLM之一。

评估

将任务分为两类:CBQA数据集用于评估模型的世界知识,前人工作发现CBQA数据集中有train-test重叠,因此做了过滤,只用未重叠的test集,命名为Filtered TriviaQA和Filtered NQ这种;其他的下游任务用opencompass[3]框架来评测。

2.2 实验结果


用前面说的7种任务的混合数据集微调大模型,数据规模逐渐增加,然后看不同下游任务的性能表现。如下图所示,像左侧的摘要、NLI、机器翻译这类任务,随着SFT训练数据的增加,性能显著提升;但是右侧的CBQA任务,却出现断崖式下跌。

我们已经高度怀疑CBQA的性能下降是由于大模型的世界知识崩坏引起的,为了更加确信,接下来我们仔细实验一下CBQA和大模型的世界知识到底有什么关系,具体做法是单独拿CBQA的25万条样本训练大模型,然后看大模型在未重叠的测试集上的表现。

如下图所示,在训练一开始大约1000样本的时候,性能已经快速提升到了很高的点,后续再增加更多的训练样本其实提升很有限。说明少量样本微调就帮助大模型完成了人类指令的对齐,大模型完成CBQA指标评测的能力主要依靠的是内在的世界知识,而不是微调过程中训练样本灌输的。因此我们更加确性CBQA指标高度依赖大模型在预训练过程中学到的世界知识,上图中CBQA的性能下降的原因就是世界知识的破坏。

再进一步实验,证明是大规模的微调导致了世界知识受损。具体做法如下表:第三列仅用CBQA训练数据微调,是可以CBQA测试集上打败Baseline的;而第四列是分两阶段,先用300万不包换CBQA的数据微调,然后再用和第三列同样的CBQA数据继续微调,结果在CBQA测试集上的表现比Baseline都差很远。

对比一下第三列和第四列,差别只在后者多了一个第一阶段300万数据的微调,说明它正是大模型世界知识崩塌的罪魁祸首,第二阶段即便加上CBQA训练数据,也无法弥补回来。同时发现大模型的参数发生了巨大变化,正好和前面结论相互佐证。

  1. LoRAMoE方法

前面的实验表明,有些下游任务需要SFT的训练数据越多越好,即LLM的参数改变越大越好,而有些下游任务需要尽可能保留世界知识,即参数变化越小越好。这种冲突对于一般的全参微调或者LoRA微调都是搞不定的。论文[1]引入了MoE的思想来解决,实现LoRA微调的自适应。题外话,MoE在搜广推领域早就烂大街,这里又用到了LLM微调领域,说明技术都是相通的。

下面先分别介绍一下MoE和LoRA,然后看如何结合。

3.1 MoE简介


MoE全称Mixture of Experts,意味着有多个专家网络投票共同输出结果,只不过每个专家根据输入不同,分配不同的权重。我们可以想象成不同的专家具备不同领域的能力,然后根据输入的特征,给更匹配的专家分配更高的权重,从而动态组合专家输出。

MoE本身是一种思想,需要结合具体模型设计。对于transformer形式的网络,MoE可以将每个block中的前馈网改造成N个结构相同的前馈网 作为专家,然后再配合上门控函数 作为路由,即根据输入给不同的专家分配不同的权重。具体为:

其中, 为block中attention层的输出,同时作为MoE层的输入, 为MoE层的输出, 为第 个专家的输出, 为第 个专家输出对应的权重。门控函数 具体为:

为可训练的参数矩阵。

3.2 LoRA简介


如果大家对推荐系统中的SVD算法了解的话,LoRA的原理就非常简单了,无非用两个低秩矩阵相乘来拟合一个高秩矩阵,只不过这里拟合的不是模型的参数矩阵 本身,而是参数矩阵的增量 ,即更新后的参数矩阵变为:

其中 ,且 ,这样微调过程中我们只需存储两个低秩的 和 矩阵即可,大幅减少存储空间。

按照LoRA的原始论文[4], 用高斯初始化, 用全零初始化,且会增加一个缩放系数 , 为超参数,输入为 ,输出为 ,具体为:

训练过程中,固定 不变, 用全零初始化可以保证在初始化阶段 ,调整 相当于调整学习率,且我们在实验中可能经常会调整参数 ,这里缩放系数中除以 能减少超参数的调整。

论文[4]中对缩放系数的作用只是很简略的说了一段,令人摸不着头脑。对此我自己做了推导和调研,发现原生的缩放因子并未最佳选择,因和本文主旨无关,后续单独作文述之。

3.3 LoRA+MoE


我们结合LoRA和MoE的优点,将二者组合起来,便是LoRAMoE,看下图所示:

具体做法就是冻结已经预训练好的LLM,然后在FFN层中的每个线性层增加了一组LoRA适配器作为专家网络组,并通过门控网络(路由器机制)分配权重,公式为:

其中, 为线性层的输出, 为线性层的输入, 为线性层的预训练好的参数(被冻结), 为门控函数,即 ,共 个专家网络。其余变量含义和前面介绍的LoRA完全类似,不再重复。修改后的线性层称之为LoRAMoE层。

对照公式我们再回看上面的图4,只能说图中的LoRAMoE部分是个示意图,意会即可。

3.4 专家平衡约束


如果不加任何约束微调MoE,经常会出现门控函数收敛到一种状态,即少数专家掌握了话语权,其他专家权重非常小,失去了平衡。

LoRAMoE人为地将专家分为两组,一组专注于学习下游任务,另一组专注于将世界知识和人类指令对齐。

形式上,我们先给每个LoRAMoE层定义一个重要性矩阵 ,每个元素 表示一个batch中第 个样本在第 个专家上的重要度,论文[1]中给出一个公式(13)来计算 ,但公式明显有误,我只能按我的猜测重新给出一个:

其中, 表示第 个训练样本的token数, 表示该样本第 个token的输入向量, 为该样本所有token在第 个专家上的路由器权重之和,然后在所有专家上做个softmax归一化得到 , 为温度超参数。

然后再给 配上一个相同尺寸的重要性系数矩阵 ,定义为:

其中, 控制专家之间的平衡度。前面所讲,我们将专家人为分了两组,样本也可分为同样的两组, 表示专家 的分组, 表示样本 的分组。

我们真正使用的是加权版的重要性矩阵 ,即如果专家和样本分组一致,就在重要性分数 上乘上一个大于1的权重,否则乘上一个小于1的权重,目的是放大重要性分数之间的差距。然后构建一个损失函数如下,作为整体损失函数的一部分:

分子分母分别为方差和均值。通过这个损失函数就能够抑制专家强者恒强的现象。

整体损失函数为LLM的损失 加上所有层的 之和,后者有个 权重:

3.5 实验


实验参数:

LoRAMoE层只替换LLM中FFN的线性层,且每个LoRAMoE层的专家数为6,其中3个用于下游任务,另外3个用于对齐世界知识。

超参 ,,,dropout为0.05,学习率为2e-4,batch size为64。
300万的训练样本,在32张A100上训练。

实验结果:

可以看到,相比于全量微调或者传统的LoRA,本文的方法都取得明显提升,世界知识的遗忘问题也不再发生。详细结论不再细表,总之,LoRA结合上MoE的路由功能,让LoRA的参数增量不再是静态的死板一块,而是可以根据不同任务的输入来动态生成,有效解决了SFT大量训练数据和世界知识遗忘的冲突。

  1. 参考文献

[1] LoRAMoE: Revolutionizing Mixture of Experts for Maintaining World Knowledge in Language Model Alignment, 2023

[2] Palm: Scaling language modeling with pathways, 2022

[3] https://opencompass.org.cn/

[4] Lora: Low-rank adaptation of large language models, 2021

转载自:https://mp.weixin.qq.com/s/fBrjK49Qhtc-rtCT3n6G3g

1