本文来源于微信公众号“ 自然语言处理算法与实践”,作者:烛之文
1、前言
现在主流的大语言模型(large language models,LLM)基本都是采用decoder-only,自回归的框架模式。今天就分享一种新的语言模型框架——检索式语言模型(COPY-GENERATOR,CoG),出自论文<COPY IS ALL YOU NEED>,代码地址:https://github.com/gmftbyGMFTBY/Copyisallyouneed,该篇论文被2023 ICLR 会议收录。
其主要思路是:从语料库中自回归式地复制文本片段(text segments)来进行文本生成,改变原来从词表中逐个token的生成方式。带来的好处有:1)解码效率提高,从原始token-level的解码变成segment-level;2)因为是片段式的复制,生成的文本整体流畅性更好;3)因为是检索式,训练出来的模型可以独立领域语料,迁移性强。下面详细介绍下。
2、语言模型
模型包括三个部分:1)prefix encoder,实现对输入的序列进行编码;2)phrase encoder,对语料库中文本片段(phrase)进行编码;3)phrase table,存储文本片段向量库,实现检索式生成。
2.1 Prefix Encoder
采用标准的transformer框架,实现对the prefix 序列的编码,来进行next-phrase生成。具体来说,将 转变成序列向量 ,回归方式可生成the prefix 的序列向量
然后取 中最后一个token的向量来代表the prefix 的表征向量,记为 。
2.2 Phrase Encoder
对于包含有n篇文档的语料库 ,假设某篇文档的序列长度为m,即为 ,然后同样使用transfomer编码器,得到文档的编码向量为 ,接着对接两个MLP层,将其转成一个start token表征 和一个end token 表征,即:
这样对于文档中的任意片段 的表征向量都可以用开头和结尾两个token的向量合并来表示,即:
上述表征思路带来的好处有:1)对文档进行一次编码,即可得到文档中所有片段的编码;2)只需存储文档中token的向量,不用存储片段(phrase)的向量,大大减少存储量。
2.3 Phrase Table
在Phrase Table中存储有上步编码的文档向量(文档中token的向量),除此外,也存储传统的词表向量(独立文档的token,视为长度为1的片段)。例如上图中,输入“he Dune film was released”,检索三个来自文档中的片段(3种不同颜色的文本),和三个token(Before、that、,),生成最终的序列文本。加入词表向量的好处:当从文档中检索不到合适的片段,就可以按目前主流的方式从词表中生成。
3 训练
在训练CoG模型中,首选需要对语料库的文档进行片段化处理,类似中文分词。其原理是按最大原则(事先设定一个最大长度),将某篇文档的序列从左到右进行切片,如果某个片段在其他文档出现过且没超过最大长度,就视为合理的片段。把所有文档切分后,会剩下一些长度为1的token,这类就可以加入词表中。
在切分片段后,就可以生成训练数据。对比传统的LM模型训练数据集,其输入一段文本(token-level),其学习思路是用前面的token序列预测下一个token;而CoG只是变成phrase-level,其学习思路变成利用前面的phrase序列,预测下一个phrase,只是粒度的差异。论文在训练中使用两个优化函数:对比损失和交叉熵。先看第一个优化目标:
其含义是:假设输入一个batch-size包含有n个连续片段 , 视为第 片段表征向量。由于片段可以来自文档(长度>1),也可以来自词表(长度=1),前者表征向量通过PhraseEncoder的首尾索引获取到,二者直接从token embedding中检索到。
是第 片段前面所有片段组成的prefix表征向量,将< >视为一对正样本,然后 来自的所有集合 , 构成对比集合,二者分别表示 包含 文档中所有的片段集合和词表。
此外,论文也加另一个token-level级别的学习,对应的优化目标就是交叉熵,形式如下
其中 表示文档D中第i个token向量, 是文档中前i个token组成的prefix的向量,m为文档长度,其形式跟传统LM学习方式一致。
这样,CoG的整体训练目标为:
4 推理
在推理阶段,对于输入的文本 ,通过PrefixEncoder得到其表征向量 ;然后利用事先构建片段向量集合Phrase Table找与 相似度最大的片段 作为输出,即为:
由于phrase table中 包含上亿级别的量,为了提高检索计算效率,文中是使用了FAISS向量数据库,利用召回top-k的方式进行生成。
5 实验
论文对比Transformer(GPT2)、kNN-LM、RETRO三种框架,后两者是检索式的语言模型,在数据集WIKITEXT-103测评的结果如下:
可以看出:
1)在MAUVE指标上,论文提出CoG框架在两种解码方式上都达到最好结果,提升约有3个点;
2)在CoG框架下,greedy解码的重复生成度最小(Rep2,Rep3、Rep4),比其他模型下降约有10个点;
3)在多样性上(Diversity),CoG用greedy解码效果好些,但用nucleus采样方式效果差些,表现最好是kNN-LM框架;
4)解码速度上(Latency),其跟正常的语言模型(transformer下的gpt2)相当,虽有提升,但并不明显,kNN-LM解码速度是最慢的。
总的来说,论文提出的CoG语言框架,相对其他模型,贪婪解码方式更适合它,这样依然存在多样性差的问题;此外,其解码效率相比正常的语言模型并没有提升多少。
上图显示CoG模型在生成时,通过检索复制过来的token长度占比情况。可以看出,模型还是偏向单个token生成,占比达到约50%,超过6-gram的片段占比已经很低了。这说明,虽然论文的思路是想让模型生成时去检索复制语料库中片段,可训练的模型还是偏向原始语言模型的解码(token-level生成)。这应该也是优化目标决定的,文中phrase-level和token-level对应的两个优化损失权重是1:1,比较吻合这个实验结果。
6 结语
本次分享的还是关于检索式语言模型,其跟之前的不太一样:论文提出的CoG语言框架是让检索生成的片段也参与了训练,更准确的说是融合phrase-level和token-level这两种方式生成。个人觉得phrase-level级别的生成方式还是值得借鉴。本篇论文工作应是在大模型还没火之前完成的,再想想当前比较热的研究点——LLM+检索式增强,觉得本篇论文提出的思路还有很多值得进一步优化的地方:如只训练phrase-level生成,token-level交给LLM来完成,二者融合是不是也实现了检索式增强;