分享
101N0101 ngram Python 核心代码解读
输入“/”快速插入内容
🍇
101N0101 ngram Python 核心代码解读
用户5190
用户9737
用户2659
2024年11月1日修改
前言
代码仓库地址:
https://github.com/EurekaLabsAI/ngram
今天将和大家一起学习 LLM101n 课程中 N-gram 部分。本期我们先详解 n-gram 模型的算法原理(包括困惑度的定义、计算方式(与熵的关系)、数据稀疏问题的解决方式等),再来对基于 Python 和 C 的 ngram 代码进行解读。
n-gram 算法原理
n-gram 算法是一种语言模型,本质和 transfromer 语言算法模型一样, 也是用来预测下一个token
(词元,可以简单理解为一个单词或词组、词)
的算法。但
n-gram
是一种更简单,形式清晰的语言模型。
先看看一句话如何计算分词(token):
<s>我爱北京天安门。</s>
这句话通过分词后会是:
["<s>", "我", "爱", "北京", "天安门", "。", "</s>"]
如何计算这句话的概率, 当然是联合概率分布:
其中
表示句子序列 w_1w_2...w_n
公式里描述的是最完美情况,但是这样的每个token的预测都依赖
所有的
历史token
,
这个计算代价非常高
,
为什么?
一方面是因为需要计算语料库中任意 N 个 tokens 的所有排列的概率分布(这几乎是不可能实现的),另一方面是因为 N-gram 算法的空间复杂度和时间复杂度是关于 N 的指数函数(即随 N 提升,训练所需投入的资源量也呈指数上升,这是不可取的)
N 取不同值时,N-gram 模型的参数变化。可以发现随 N 的上升,模型参数量呈指数上升(图源:CSDN)
为了解决计算复杂度的问题,我们可以采用
马尔可夫假设
来优化做个问题,即
一个词的出现仅与它之前的若干个词有关
。
比如下一个词的只依赖上一个词概率分布,即:
这就是 n
=
2 的 bigram 算法
(又称 2-gram)
。
如果假设每个token都是独立的分布的,即:
这就是
n=1 的
unigram
算法(又称 1-gram)
。
类推,n=3
的 trigram 公式:
n=4的 4-gram 公式:
回到我们前面提到的计算复杂度的问题。当 N 从 1 到 3 时,模型的效果上升显著;而当模型从 3 到 4 时,效果的提升就不是很显著了,而资源的耗费却增加的非常快。因此 N-gram 模型中 N 的取值大多不超过 3。[更多详情请参阅吴军《数学之美》相关章节]
按照上面的例子, 假设是bigram模型,训练语料如下:
代码块
Plain Text
<s> 我 爱 北京 天安门 。 </s>
<s> 我 想 去 北京 。 </s>
<s> 北京 是 首都 </s>
这里为了简单,用空格进行隔开代表分词。
可以得类似以下的条件概率值:
其中< s >表示 start token,是一种特殊的标记,可以作为一个 “虚拟的前序单词” 参与概率计算。
Perplexity 困惑度的定义和理解
困惑度
(Perplexity,常用 PPL)
计算公式是: