输入“/”快速插入内容

101N0101 ngram Python 核心代码解读

2024年11月1日修改
前言
今天将和大家一起学习 LLM101n 课程中 N-gram 部分。本期我们先详解 n-gram 模型的算法原理(包括困惑度的定义、计算方式(与熵的关系)、数据稀疏问题的解决方式等),再来对基于 Python 和 C 的 ngram 代码进行解读。
n-gram 算法原理
n-gram 算法是一种语言模型,本质和 transfromer 语言算法模型一样, 也是用来预测下一个token(词元,可以简单理解为一个单词或词组、词)的算法。但 n-gram 是一种更简单,形式清晰的语言模型。
先看看一句话如何计算分词(token):<s>我爱北京天安门。</s>
这句话通过分词后会是:["<s>", "我", "爱", "北京", "天安门", "。", "</s>"]
如何计算这句话的概率, 当然是联合概率分布:
其中
表示句子序列 w_1w_2...w_n
公式里描述的是最完美情况,但是这样的每个token的预测都依赖所有的历史token这个计算代价非常高为什么?一方面是因为需要计算语料库中任意 N 个 tokens 的所有排列的概率分布(这几乎是不可能实现的),另一方面是因为 N-gram 算法的空间复杂度和时间复杂度是关于 N 的指数函数(即随 N 提升,训练所需投入的资源量也呈指数上升,这是不可取的)
N 取不同值时,N-gram 模型的参数变化。可以发现随 N 的上升,模型参数量呈指数上升(图源:CSDN)
为了解决计算复杂度的问题,我们可以采用马尔可夫假设来优化做个问题,即一个词的出现仅与它之前的若干个词有关比如下一个词的只依赖上一个词概率分布,即:
这就是 n=2 的 bigram 算法(又称 2-gram)
如果假设每个token都是独立的分布的,即:
这就是 n=1 的 unigram 算法(又称 1-gram)
类推,n=3 的 trigram 公式:
n=4的 4-gram 公式:
回到我们前面提到的计算复杂度的问题。当 N 从 1 到 3 时,模型的效果上升显著;而当模型从 3 到 4 时,效果的提升就不是很显著了,而资源的耗费却增加的非常快。因此 N-gram 模型中 N 的取值大多不超过 3。[更多详情请参阅吴军《数学之美》相关章节]
按照上面的例子, 假设是bigram模型,训练语料如下:
代码块
<s> 我 爱 北京 天安门 。 </s>
<s> 我 想 去 北京 。 </s>
<s> 北京 是 首都 </s>
这里为了简单,用空格进行隔开代表分词。
可以得类似以下的条件概率值:
其中< s >表示 start token,是一种特殊的标记,可以作为一个 “虚拟的前序单词” 参与概率计算。
Perplexity 困惑度的定义和理解
困惑度(Perplexity,常用 PPL)计算公式是: