Greedy Search and Beam Search

Feb. 26, 2025 · Qiyao Wang #Decoding

Greedy decoding selects, at each step, the token with the highest conditional probability from the vocabulary as the next token, until an end-of-sequence token is produced or the maximum context length is reached.
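
At its core, one greedy step is just an argmax over the next-token distribution. A minimal standalone sketch (the probabilities are toy numbers, invented purely for illustration):

import torch

# toy next-token distribution over a 5-token vocabulary
probs = torch.tensor([0.10, 0.40, 0.20, 0.25, 0.05])
next_token_id = torch.argmax(probs).item()  # greedy: always take the most probable token
print(next_token_id)  # -> 1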

Before writing the greedy-decoding code, I polished the Sampler base class written in the autoregressive-decoding post; to improve code reuse, it now provides a get_next_token interface.

import torch
import plotly.graph_objects as go
from transformers import AutoTokenizer, AutoModelForCausalLM


class Sampler:
    def __init__(self, model_name: str = "Qwen2.5-0.5B") -> None:
        self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

    def encode(self, text: str):
        return self.tokenizer.encode(text, return_tensors="pt").to(self.device)

    def decode(self, ids: torch.Tensor):
        return self.tokenizer.decode(ids)

    def get_next_token_prob(self, input_ids: torch.Tensor):
        # disable gradient computation in the graph
        with torch.no_grad():
            logits = self.model(input_ids=input_ids).logits
        # here logits has shape [batch, seq_len, vocab_size], e.g. torch.Size([1, seq_len, 151936])
        # keep only the distribution at the last position: torch.Size([151936])
        logits = logits[0, -1, :]
        probs = torch.softmax(logits, dim=-1)
        return probs

    def plot_scores(self, scores, title, k):
        """
        :param scores: scores to rank and plot
        :param title: figure title
        :param k: number of tokens to display
        :return: None
        """
        top_indices = torch.argsort(scores, descending=True)[:k]
        tokens = [self.decode(idx) for idx in top_indices]

        if self.device == "cpu":
            top_probs = scores[top_indices].numpy()
        else:
            top_probs = scores[top_indices].cpu().numpy()

        colors = ['#E95B68', '#C4C956', '#58BB7B', '#CAC1C5', '#87601F', '#F7311B',
                  '#C53D39', '#38658F', '#242ABC', '#9DA52F', '#329018', '#D415C5',
                  '#6DCE59', '#ADF212', '#9CF042']

        colors = colors[0: len(top_indices)]

        fig = go.Figure(
            data=[
                go.Bar(x=tokens, y=top_probs, marker_color=colors, textposition="inside")
            ]
        )
        fig.update_layout(title=title)
        fig.show()

    def get_next_token(self, text: str, k: int):
        input_ids = self.encode(text)
        next_token_probs = self.get_next_token_prob(input_ids)
        top_indices = torch.argsort(next_token_probs, descending=True)[:k]
        tokens = [self.decode(idx) for idx in top_indices]
        return {
            "token": tokens,
            "next_token_probs": next_token_probs,
            "ids": top_indices
        }

    def get_next_token_plot_pipeline(self, text: str, k: int = 10):
        input_ids = self.encode(text)
        next_token_prob = self.get_next_token_prob(input_ids)
        self.plot_scores(next_token_prob, text, k=k)
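
As a quick sanity check, get_next_token can be called on its own to inspect the model's top candidates for a prefix (a minimal sketch; the Instruct variant of the model name is an assumption matching the experiment below):

sampler = Sampler(model_name="Qwen2.5-0.5B-Instruct")
print(sampler.get_next_token("the color of sky is", k=5)["token"])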

Building on the Sampler base class, we construct GreedySampler, which takes Sampler as its parent class and inherits its methods.

class GreedySampler(Sampler):
    def __call__(self, prompt, max_new_tokens=10):
        predictions = []
        result = prompt
        # generate tokens until max_new_tokens is reached
        for i in range(max_new_tokens):
            # greedy search => k=1
            next_token_dict = self.get_next_token(result, k=1)
            next_token = next_token_dict['token'][0]
            result += next_token
            next_token_probs = next_token_dict['next_token_probs']
            ids = next_token_dict['ids']
            predictions.append(next_token_probs[ids].item())

            # stop if the end-of-sequence token was generated
            # (compare token ids rather than the decoded string against eos_token_id)
            if ids[0].item() == self.tokenizer.eos_token_id:
                break

        return result

Experimenting with Qwen2.5-0.5B-Instruct and the input prompt "the color of sky is", the result is: "The color of sky is always changing. The sky is blue when the sun is up, and it is white when the sun".
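
For reference, a minimal usage sketch of GreedySampler (the model name matches the experiment above; the exact continuation depends on the model weights):

sampler = GreedySampler(model_name="Qwen2.5-0.5B-Instruct")
print(sampler("the color of sky is", max_new_tokens=25))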

Greedy Search is at heart a greedy algorithm. Although it is a simple and efficient decoding strategy that requires fewer computational resources than the alternatives, it still has several drawbacks:

  • Lack of diversity and foresight: it myopically emits the highest-probability token at each step, ignoring diversity and the long-range quality of the sequence.
  • Repetitiveness: because the most likely token is always chosen, runs are highly reproducible, but the output is predictable and prone to repetition.
  • Error amplification: a greedy algorithm cannot backtrack; once an early decoding step goes wrong, every subsequent choice is affected.

Beam search, unlike Greedy Search, does not consider only the single best option at each step; under a beam-width parameter $k$ it tracks multiple candidate sequences simultaneously.
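
A toy example of why this helps (probabilities invented purely for illustration): suppose at the first step $p(A)=0.5$ and $p(B)=0.4$, so greedy commits to $A$. If the best continuation of $A$ has conditional probability $0.3$ while $B$ has one with probability $0.9$, the full sequences score $0.5 \times 0.3 = 0.15$ versus $0.4 \times 0.9 = 0.36$, and a beam of width $k \ge 2$ keeps $B$ alive long enough to find the better overall sequence.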

At each step it keeps the top-$k$ sequences, attending not only to immediately high-probability tokens but to the probability of the whole sequence. To counter the tendency of beam-search decoding to repeat identical word sequences, the idea of an n-gram penalty can be used: if a candidate n-gram about to be appended to the sequence already occurs in it, its probability is set to 0; a sketch follows below.
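
A minimal sketch of such a penalty, assuming a plain Python list of generated token ids (the function apply_ngram_penalty is illustrative and not part of the classes below; Hugging Face's generate exposes a comparable no_repeat_ngram_size option):

def apply_ngram_penalty(probs, generated_ids, n=2):
    """Zero out tokens that would recreate an n-gram already in generated_ids (n >= 2)."""
    if len(generated_ids) < n - 1:
        return probs
    # the (n-1)-token prefix that the next token would extend into an n-gram
    prefix = tuple(generated_ids[-(n - 1):])
    for i in range(len(generated_ids) - n + 1):
        # any token that has already followed this prefix is banned
        if tuple(generated_ids[i:i + n - 1]) == prefix:
            probs[generated_ids[i + n - 1]] = 0.0
    return probs

The Beam and BeamSampler classes below implement the search itself.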

class Beam:
    def __init__(self, device, size, input_ids, score, output=None):
        self.device = device
        self.size = size # num_beam => k
        self.input_ids = input_ids.to(self.device)
        self.score = score
        self.output = output.to(self.device) if output is not None else None

    # get input_ids
    def get_current_state(self):
        return self.input_ids

    # get the probability score of the whole sequence
    def get_score(self):
        return self.score

    # create a new instance of the Beam class after the top-k selection
    def extend(self, token_id, score):
        # the input_ids of the extended sequence
        new_input_ids = torch.cat([self.input_ids, token_id.unsqueeze(0)], dim=-1)
        # the probability score of the extended sequence
        new_score = self.score * score
        new_output = torch.cat([self.output, token_id.unsqueeze(0)], dim=-1) if self.output is not None else new_input_ids
        # return a new Beam holding the extended sequence; the original beam is unchanged
        return Beam(self.device, self.size, new_input_ids, new_score, new_output)

class BeamSampler(Sampler):
    def beam_decode(self, ids):
        return self.tokenizer.decode(ids.squeeze().tolist())

    # get the top-k ids with the greatest probabilities
    # a static method cannot access class or instance attributes; it depends only on its inputs
    @staticmethod
    def get_topk(prob, k=1):
        scores, token_ids = torch.topk(prob, k=k, dim=-1)
        return scores, token_ids

    def __call__(self, prompt, max_new_tokens=10, num_beam=1):
        input_ids = self.encode(prompt)

        # initialize with a single root beam (the initial node A1); starting with
        # num_beam identical copies would only yield duplicate candidates forever
        beams = [Beam(self.device, num_beam, input_ids, 1)]

        for i in range(max_new_tokens):
            # each timestep
            all_next_token_prob = []
            # for each beam, compute the next-token distribution given its whole sequence
            for beam in beams:
                # predict the continuations of this beam
                # with k=3: A1 -> B2, C2, D2 (each of these is itself a probability distribution)
                next_token_probs = self.get_next_token_prob(input_ids=beam.get_current_state())
                all_next_token_prob.append(next_token_probs)

            # at each timestep, consider the probabilities of all beams together
            all_topk_scores = []
            all_topk_token_ids = []

            for prob in all_next_token_prob:
                # for each distribution (B2, C2, D2), select its own top-k
                scores, token_ids = self.get_topk(prob, k=num_beam)
                all_topk_scores.append(scores)
                all_topk_token_ids.append(token_ids)

            all_topk_scores = torch.stack(all_topk_scores)
            all_topk_token_ids = torch.stack(all_topk_token_ids)

            # build the candidate beams; anything beyond the top num_beam is dropped below
            new_beams = []
            for j, beam in enumerate(beams):
                for k in range(num_beam):
                    score = all_topk_scores[j][k].item()
                    token_id = all_topk_token_ids[j][k].unsqueeze(0)
                    new_beam = beam.extend(token_id, score)
                    new_beams.append(new_beam)

            # sort all extended beams by the probability score of their whole sequence
            beams = sorted(new_beams, key=lambda b: b.get_score(), reverse=True)[:num_beam]

        generated_text = self.beam_decode(beams[0].output[:, len(input_ids[0]):])
        # [:, len(input_ids[0]):] drops the prompt part of the sequence

        return prompt + generated_text
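
Mirroring the greedy example, a minimal usage sketch (num_beam=3 is an arbitrary illustrative choice):

sampler = BeamSampler(model_name="Qwen2.5-0.5B-Instruct")
print(sampler("the color of sky is", max_new_tokens=10, num_beam=3))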

The Beam Search code is fairly involved; a visualization may follow later to make it more intuitive. The probability score of each sequence is still computed from conditional probabilities via the chain rule:

$$ p(\mathbf{x}) = \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1}) = p(x_1)\cdot p(x_2\mid x_1)\cdots p(x_T\mid x_1, x_2, \dots, x_{T-1}) $$

Each time a Beam is extended with a new token, the score of the entire sequence it represents is recomputed immediately, and in the end the highest-scoring path in the beams list is the one selected. Viewed globally this is still greedy decoding, since we pick the beam with the highest score; viewed locally, it widens the set of candidate paths. Compared with Greedy Search it needs more compute, maintaining and scoring $k$ sequences at every step, and it still cannot guarantee finding the most probable sequence, especially when $k \ll |V|$.

Contact

There may be some errors present. If you find any, please feel free to contact me at wangqiyao@mail.dlut.edu.cn. I would appreciate it!