输入“/”快速插入内容

饮鸩止渴?LLM训练要不要过采样/训多个epoch

2024年7月24日修改
作者:咸鱼王
目前已经有很多paper论证了数据量(尤其在pretrain阶段)对LLM的重要性,然而数据、尤其是高质量数据往往可遇不可求,因此过采样或训练多个epoch就成了缓解data hungry的常见手段。但是这种repeat究竟会对LLM带来什么影响?
论文链接:
01 背景
众所周知,为了缓解LLM的data hungry,pretrainTokens的量级一直是目前LLM迭代、更新的最重要因素之一,然而尤其在垂域适配时,优质的数据往往可遇而不可求,因此repeat就成为了缓解hungry的最直接手段。这篇论文主要探讨了pretrain阶段repeat对LLM带来的影响,笔者尝试以本篇论文为主干,结合相关文章和笔者的个人见解,对pretrain、SFT阶段repeat的影响做进一步分析。
02 实验设置
作者基于T5及C4数据进行pretrain实验,验证不同repeat下模型的效果,为了方便阅读和理解,笔者将文中提及的multi-epoch training和的过采样,统一称为repeat。
03 LLM易受repeat影响:过拟合并性能下降
3.1 repeat会导致score下降
为验证repeat对模型带来的影响,作者在确保trainedTokens相同的前提下,训练、对比了不同repeat程度下LLM的MLM准确率。具体的,假设从C4随机采样、用于pretrain的tokens数为 T T ,pretrain阶段repeat的次数为 R R ,作者实验的3组配置为: (T,R)={(235,1),(229,26),(227,28)} (T, R)=\{(2^{35},1),(2^{29}, 2^{6}),(2^{27}, 2^{8})\} ,以确保每个方案下都训练了 T∗R=235 T*R=2^{35} 个tokens。
实验结果如下图,MLM准确率随着repeat次数的增加呈下降趋势,但以笔者拙见,由于缺乏 (229,26) (2^{29},2^6) 和 (229,20) (2^{29},2^0) 的对比实验,并无法确定token-crisis下,repeat能否在一定程度上缓解token-crisis问题。
作者目的更多在于验证data hungry下,repeat对模型带来的影响。但是很显然,在正常分布下,repeat fake的数据显然不如同量级真实分布的数据,如果将数据视为真实世界分布中的采样点,数据越多、分布越广、越均匀,则采样构成的面越平滑、越能还原真实分布,而repeat本身并无法带来这种增益,显然不如同量级真实分布的训练效果。
为验证repeat是否会持续影响后续的SFT甚至RL,作者对repeated pretrain模型在SQuAD上进行了微调,具体如下图,可以看出,pretrain阶段经过多次repeat的LLM,在SFT后仍显著低于未repeat的模型。但仍如之前所述,该实验 无法区分这种负面影响,是repeat带来的,还是data hungry带来的。
3.2 少量repeat也有过拟合风险
Jeremy Howard在 [Can LLMs learn from a single example?] 中给出了另一个典型的案例:LLM在SFT中呈现阶梯状的train_loss曲线,且每次骤降均发生在epoch末尾,具体如下图,Howard认为最直接的原因就是LLM在repeat后产生了过拟合。这个结论还是比较令人吃惊的,不同于论文中64甚至256次的repeat,这意味着模型仅在1+ repeat后就产生了过拟合。
为验证猜想,Howard调整了学习率衰减策略,进一步观测模型在2个epoch内的train_loss/valid_loss变化,验证了模型在1-2次repeat后就发生了较为明显的过自信/过拟合现象。(对验证流程不感兴趣的读者可以先行跳过)
具体的,借助Cyclical Learning Rates衰减策略,使学习率在单个epoch内完成warm_up及衰减,具体如下图1。不出意外,train_loss/valid_loss曲线走向与Howard猜想基本一致,实验结果如下图2:
在epoch_1的warm_up初期,学习率缓慢上升,train_loss/valid_loss由于学习率较低下降缓慢;
在epoch_1的warm_up末尾,学习率上升至高点,train_loss/valid_loss由于学习率较高开始迅速下降;
在epoch_1的中后期,学习率逐渐衰减至低点,train_loss/valid_loss趋于平缓;
在epoch_2的warm_up初期,由于训练集没有shuffle,batch顺序和epoch_1完全一致,即此时的batch也是epoch_1的warm_up初期的batch,由于初见时学习率较低仍未较好拟合,因此train_loss/valid_loss仍比较正常;
在epoch_2的warm_up末尾,学习率上升至高点,此时遇见的batch已经在epoch_1较好拟合,再次遇见之后趋向于过度拟合、甚至记忆样本,此时对model的泛化性几乎没有提升,因此train_loss再次陡降,valid_loss开始陡升,模型开始过自信/过拟合;
在epoch_2的末尾,此时与warm_up初期的情况较为相似,遇见的batch在epoch_1仍未较好拟合,model原本记忆样本所学习到的规则与此时遭遇的样本产生碰撞,甚至矫正model回归至更合理的置信度水平,因此train_loss逐渐攀升、valid_loss逐渐下降。