输入“/”快速插入内容

Deploy034 投机解码

2024年10月10日修改
1.1 投机推理\采样 (Speculative Decoding )
关于投机性执行是分布式计算的老话题,在计算机体系结构中也有投机预取,处理器优化中也有推测投机执行指令,也就是分支预测,还有分布式领域 MapReduce/Hadoop 时代(大概在 2010 年前后)就有大量关于投机性执行一些预计性任务的工作等等的相关思想。大语言模型都需要迭代运行顺序生成标记,也就是一个一个 token 往外吐,不好在 token 级别并行进行加速。所以就有了投机推理技术来帮忙优化增加并行度,一般被称为 speculative decode,推测性解码、投机解码。
2.1 Speculative Decoding 背后有两个关键思想
2.1.1 预测 Token 难易不一致
在下图中,预测标记 'of ' 真的很容易,而且它可能很容易被小得多的模型预测,因此使用较小的模型来预测简单的标记,而使用大模型仅用于预测更困难的标记。
预测标记 'of ' 真的很容易,并且可以通过小得多的模型轻松预测,而标记 'Edinburg' 的预测相对来说很困难,而较小的模型可能无法预测。
2.1.2 并行验证 Token
尽管这些自回归模型通常一次迭代生成一个单词,但它们可以一次输入多个Token。在生成下一个 Token 时,他们可以一次检查序列中的所有 Token。它通过计算序列中每个标记的概率来实现此目的。
在上图中,较小的模型预测“Toronto”,但正确的单词是“Edinburgh”,较大的模型可以看到“Toronto”的概率很低,并将其更正为“Edinburgh”。
2.2 推测解码:推理加速的关键因素
推测解码使用更小、更快的模型(称为Draft/Small 模型)在生成多组输出。随后检查此输出,并在必要时由更大、功能更强大的模型(称为Main/Target模型)进行校正。推测解码的本质在于它能够加快推理过程,利用较小模型的速度,同时保持较大模型提供的质量保证。
小型模型提供低延迟,但它们生成的文本质量通常较差。
较大的模型可能能够生成高质量的文本,但响应时间可能会很慢。
Fast Inference from Transformers via Speculative Decoding 从Paper当中的例子我们能看到
每一行代表模型的一次迭代。
绿色标记是由 6M参数Draft模型 提出的建议,
目标模型为 97M参数Target模型 接受了这些建议,而红色和蓝色标记分别是被拒绝的建议及其修正。
例如,在第一行中,Target模型只运行了一次,生成了5个标记。剩余的Token均为小模型生成。
如果 draft 模型的建议始终准确,那么推测解码的总成本将只是 draft 模型的推理成本,加上Target模型的单个验证步骤。与仅使用 Target 模型的传统解码方法 (均使用大模型逐token生成) 相比,这种配置将更快的推理。
Draft模型必须比主模型小得多且速度更快。
Draft模型应生成主模型可以验证通过的大量token建议。
相反,如果Draft模型的建议大多是错误的,那么Target模型将需要纠正每个标记,从而有效地将解码工作加倍。(首先生成预测,然后纠正它们。这种情况会使推测解码比单独使用主模型慢。)
📌
投机解码的加速关键是 “Draft模型的推理速度” 与 “能够被Target模型验证通过的概率”;
2.2.1 Huggingface 推测解码实现
代码块
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
prompt = "Alice and Bob"
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).to(device)
outputs = model.generate(**inputs, assistant_model=assistant_model)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']