Greedy Search
贪婪解码,每一步从词汇表中选择具有最高条件概率的 token 作为 next token,直到遇到结束 token 或达到最大的上下文长度。
在撰写贪婪解码的代码前,我对自回归解码中撰写的 Sampler 基类进行了完善,为了提高代码的复用性,在其中提供了 get_next_token 接口。
class Sampler:
    """Wraps a causal LM + tokenizer with helpers shared by all decoding strategies."""

    def __init__(self, model_name: str = "Qwen2.5-0.5B") -> None:
        # Prefer CUDA, then Apple MPS, falling back to CPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

    def encode(self, text: str):
        """Tokenize *text* and return the id tensor on the model's device."""
        return self.tokenizer.encode(text, return_tensors="pt").to(self.device)

    def decode(self, ids: torch.Tensor):
        """Decode token id(s) back into text."""
        return self.tokenizer.decode(ids)

    def get_next_token_prob(self, input_ids: torch.Tensor):
        """Return the softmax distribution over the vocabulary for the next token.

        :param input_ids: (1, seq_len) token ids
        :return: 1-D tensor of shape (vocab_size,)
        """
        # Inference only: no gradient bookkeeping needed.
        with torch.no_grad():
            logits = self.model(input_ids=input_ids).logits
        # logits is (batch=1, seq_len, vocab_size); keep only the last position.
        logits = logits[0, -1, :]
        probs = torch.softmax(logits, dim=-1)
        return probs

    def plot_scores(self, scores, title, k):
        """Bar-plot the top-k scores.

        :param scores: 1-D tensor of scores over the vocabulary
        :param title: figure title
        :param k: number of tokens to display
        :return: None
        """
        # torch.topk (O(V)) replaces the original full argsort (O(V log V));
        # results are returned in the same descending order.
        top_probs, top_indices = torch.topk(scores, k=min(k, scores.numel()))
        tokens = [self.decode(idx) for idx in top_indices]
        # .cpu() is a no-op on CPU tensors, so no device branch is needed.
        top_probs = top_probs.cpu().numpy()
        palette = ['#E95B68', '#C4C956', '#58BB7B', '#CAC1C5', '#87601F', '#F7311B',
                   '#C53D39', '#38658F', '#242ABC', '#9DA52F', '#329018', '#D415C5',
                   '#6DCE59', '#ADF212', '#9CF042']
        # BUG FIX: cycle the palette so k may exceed the number of predefined
        # colors (the original simply truncated, leaving bars > 15 uncolored).
        colors = [palette[i % len(palette)] for i in range(len(top_indices))]
        fig = go.Figure(
            data=[
                go.Bar(x=tokens, y=top_probs, marker_color=colors, textposition="inside")
            ]
        )
        fig.update_layout(title=title)
        fig.show()

    def get_next_token(self, text: str, k: int):
        """Return the top-k candidate next tokens for *text*.

        :return: dict with keys "token" (decoded strings), "next_token_probs"
                 (the full distribution) and "ids" (top-k token ids).
        """
        input_ids = self.encode(text)
        next_token_probs = self.get_next_token_prob(input_ids)
        top_indices = torch.topk(next_token_probs, k=k).indices
        tokens = [self.decode(idx) for idx in top_indices]
        return {
            "token": tokens,
            "next_token_probs": next_token_probs,
            "ids": top_indices
        }

    def get_next_token_plot_pipeline(self, text: str, k: int = 10):
        """Convenience pipeline: encode *text*, get the distribution, plot top-k."""
        input_ids = self.encode(text)
        next_token_prob = self.get_next_token_prob(input_ids)
        self.plot_scores(next_token_prob, text, k=k)
基于 Sampler 基类,构建 GreedySampler,该类以 Sampler 类为父类,继承其中的方法。
class GreedySampler(Sampler):
    """Greedy decoding: at every step emit the single most probable next token."""

    def __call__(self, prompt, max_new_tokens=10):
        """Generate up to *max_new_tokens* tokens after *prompt*.

        Stops early when the model emits the end-of-sequence token.

        :param prompt: initial text to continue
        :param max_new_tokens: hard cap on generated tokens
        :return: prompt concatenated with the generated continuation
        """
        predictions = []  # per-step probability of the chosen token
        result = prompt
        for _ in range(max_new_tokens):
            # Greedy search is simply top-k search with k=1.
            next_token_dict = self.get_next_token(result, k=1)
            next_token = next_token_dict['token'][0]
            next_token_id = next_token_dict['ids'][0].item()
            result += next_token
            next_token_probs = next_token_dict['next_token_probs']
            ids = next_token_dict['ids']
            predictions.append(next_token_probs[ids].item())
            # BUG FIX: compare token *ids*. The original compared the decoded
            # string against the integer eos_token_id, which is never equal,
            # so generation never stopped at end-of-sequence.
            if next_token_id == self.tokenizer.eos_token_id:
                break
        return result
基于 Qwen2.5-0.5B-Instruct 进行实验,输入 prompt 为 "the color of sky is",其返回结果为:"The color of sky is always changing. The sky is blue when the sun is up, and it is white when the sun"。
Greedy Search 本质是一种贪心算法,虽然在解码策略上它较为简单和高效,所需的计算资源也较其他解码策略更少,但仍存在一些缺点:
- 缺乏多样性且无长远考虑:短视、贪心地每次输出概率最大的 token,忽略了多样性层面、长远的考虑。
- 重复性:由于每一步都确定性地选择最可能的词,生成结果容易出现重复的片段,且输出完全可预测。
- 错误放大:贪心算法无法纠正错误,一旦前序解码中存在一定的错误,之后的选择都会受到影响。
Beam Search
束搜索相较于 Greedy Search 而言,不是只考虑每一步的最优情况,而是在束宽度 $k$ 参数下同时跟踪多个潜在序列。
在每个阶段,选择 top-$k$ 个序列,不只考虑即时的高概率词,同时关注整体序列的概率。针对束搜索解码中可能包含重复的相同词序列的问题,使用 n-gram 惩罚的概念,即如果一个 n-gram 序列被生成放入序列中,而该 n-gram 已经在序列中存在,则设置其概率为 0。
class Beam:
    """One candidate sequence (hypothesis) tracked during beam search."""

    def __init__(self, device, size, input_ids, score, output=None):
        """
        :param device: torch device the tensors should live on
        :param size: beam width (num_beam, i.e. k)
        :param input_ids: full token-id tensor of the hypothesis, shape (1, seq_len)
        :param score: cumulative probability of the sequence so far
        :param output: prompt + generated tokens; None until the first extend()
        """
        self.device = device
        self.size = size
        self.input_ids = input_ids.to(self.device)
        # Fixed the original misspelling 'socre' -> 'score'.
        self.score = score
        self.output = output.to(self.device) if output is not None else None

    def get_current_state(self):
        """Token ids to feed the model for the next decoding step."""
        return self.input_ids

    def get_score(self):
        """Cumulative probability of this hypothesis."""
        return self.score

    def extend(self, token_id, score):
        """Return a NEW Beam with *token_id* appended and *score* multiplied in.

        (Not recursion, as the original comment claimed — just a fresh instance.)
        """
        new_input_ids = torch.cat([self.input_ids, token_id.unsqueeze(0)], dim=-1)
        # Chain rule: P(seq + tok) = P(seq) * P(tok | seq).
        new_score = self.score * score
        new_output = torch.cat([self.output, token_id.unsqueeze(0)], dim=-1) if self.output is not None else new_input_ids
        return Beam(self.device, self.size, new_input_ids, new_score, new_output)
class BeamSampler(Sampler):
    """Beam-search decoding: keep the num_beam highest-probability sequences each step."""

    def beam_decode(self, ids):
        """Decode a (1, seq_len) id tensor into text."""
        return self.tokenizer.decode(ids.squeeze().tolist())

    # Static: depends only on its arguments, not on instance state.
    @staticmethod
    def get_topk(prob, k=1):
        """Top-k probabilities and their token ids from a distribution."""
        scores, token_ids = torch.topk(prob, k=k, dim=-1)
        return scores, token_ids

    def __call__(self, prompt, max_new_tokens=10, num_beam=1):
        """Generate a continuation of *prompt* with beam width *num_beam*.

        :param prompt: initial text to continue
        :param max_new_tokens: number of decoding steps
        :param num_beam: beam width k
        :return: prompt concatenated with the best beam's continuation
        """
        input_ids = self.encode(prompt)
        # BUG FIX: start from ONE root beam. The original seeded num_beam
        # identical copies, so every candidate set was duplicated and the
        # top-k selection kept copies of the same best extension — the search
        # silently degenerated to greedy decoding.
        beams = [Beam(self.device, num_beam, input_ids, 1)]
        for _ in range(max_new_tokens):
            # Next-token distribution for every surviving hypothesis.
            all_next_token_prob = [
                self.get_next_token_prob(input_ids=beam.get_current_state())
                for beam in beams
            ]
            # Each beam proposes its own top-k extensions.
            all_topk_scores = []
            all_topk_token_ids = []
            for prob in all_next_token_prob:
                scores, token_ids = self.get_topk(prob, k=num_beam)
                all_topk_scores.append(scores)
                all_topk_token_ids.append(token_ids)
            all_topk_scores = torch.stack(all_topk_scores)
            all_topk_token_ids = torch.stack(all_topk_token_ids)
            # Expand every beam with each of its candidates...
            new_beams = []
            for j, beam in enumerate(beams):
                for k in range(num_beam):
                    score = all_topk_scores[j][k].item()
                    token_id = all_topk_token_ids[j][k].unsqueeze(0)
                    new_beams.append(beam.extend(token_id, score))
            # ...then keep only the num_beam hypotheses with the highest
            # cumulative sequence probability.
            beams = sorted(new_beams, key=lambda b: b.get_score(), reverse=True)[:num_beam]
        # Slice off the prompt tokens before decoding the winning beam.
        generated_text = self.beam_decode(beams[0].output[:, len(input_ids[0]):])
        return prompt + generated_text
Beam Search 的代码较为复杂,之后可以提供一个可视化的图来更直观地表述。其中每个序列的 probs score 计算方法仍是基于条件概率
每一个 Beam 增加了新 token 后,会直接计算其代表的整个序列的分数,最后 Beams 列表中分数最大的路径则为应该选择的路径。从全局视角看也是贪婪解码,因为选择了分数最大的 beam,而从局部视野看,增加了可选择路径的宽度。相较于 Greedy Search 而言需要更多计算资源,需要在每一步维护和计算 $k$ 个序列的概率;也无法保证找到最可能的序列,尤其是在 $k \ll |V|$ 时。
Reference
Contact
There may be some errors present. If you find any, please feel free to contact me at wangqiyao@mail.dlut.edu.cn
. I would appreciate it!