CS336 Assignment 1: BPE Tokenizer's Detailed Implementation

Aug. 01, 2025 - Sep. 29, 2025 · Qiyao Wang

BPE Tokenizer

BPE Training Example (inefficient yet valid logic for understanding)

Assignment 1 给出了一个用于理解 BPE 流程的简单示例,我将其实现为代码(实现较为低效,仅用于理解流程)。BPE 的训练流程及代码如下(所处理的文本见代码):

  1. Vocabulary Initialization: 初始化一个大小为 256+1 的词表:前 256 个 token 分别对应 0-255 的每个字节值(示例代码中以 chr(i) 字符作为键),第 257 个 token 为特殊 token <|endoftext|>;
  2. Pre-Tokenization: 对文本进行预分词,以避免 "dog!" 和 "dog." 这类语义相近的词仅因相邻标点不同而得到完全不同的切分结果;此处直接使用按空白切分的 split 函数,一般可以使用正则表达式进行预分词(见后文);
  3. Merges: 每一轮统计每个词元组中相邻元素组成的 pair 的频率(按词频加权),选择频率最大的 pair(频率相同时选择字典序较大的)进行合并,作为新引入的 token;共迭代 $N$ 次($N$ 为超参数,此处 $N=6$)。
from typing import Dict
from collections import Counter

# Toy corpus for the inefficient BPE walk-through above (the example text from
# Assignment 1). Words are whitespace-separated; pre_tokenization() counts them.
text = """low low low low low
lower lower widest widest widest
newest newest newest newest newest newest
"""

def init_vocab() -> Dict[str, int]:
    """Build the toy vocabulary: 256 single-character entries plus <|endoftext|>.

    For every code point 0..255 the character ``chr(i)`` maps to ``i``; the
    special token ``"<|endoftext|>"`` is appended last with id 256.
    """
    vocab = {chr(code): code for code in range(256)}
    vocab["<|endoftext|>"] = 256
    return vocab

def pre_tokenization(text):
    """Split *text* on whitespace and count each resulting word.

    Returns a plain dict ``{word: frequency}`` whose keys appear in the order
    the words were first seen.
    """
    return dict(Counter(text.split()))

def pair_count(cand_list):
    """Count adjacent symbol pairs over all words, weighted by word frequency.

    Args:
        cand_list: list of single-entry dicts ``{symbol_tuple: frequency}``.

    Returns:
        dict mapping ``(left, right)`` -> summed frequency, in first-seen order.
    """
    totals = Counter()
    for entry in cand_list:
        symbols, freq = list(entry.items())[0]
        # zip against the shifted sequence yields exactly len(symbols) - 1 pairs
        for left, right in zip(symbols, symbols[1:]):
            totals[(left, right)] += freq
    return dict(totals)

def merge(pair_counted, cand_list):
    """Apply one BPE merge step to the toy word list.

    The pair with the highest count in *pair_counted* is chosen; among pairs
    with equal counts the lexicographically largest tuple wins. Every
    occurrence of that pair inside each word is replaced, scanning left to
    right without overlap, by the concatenation of its two symbols.

    Args:
        pair_counted: dict mapping (left, right) symbol pairs to their counts,
            as produced by ``pair_count``. Must not be empty (``max`` raises).
        cand_list: list of single-entry dicts ``{symbol_tuple: frequency}``.

    Returns:
        ``(new_cand_list, chosen_pair)``; words that do not contain the chosen
        pair are carried over unchanged.
    """
    # Highest frequency first, then lexicographically largest pair on ties.
    chosen = max(pair_counted.items(), key=lambda item: (item[1], item[0]))[0]
    span = len(chosen)

    result = []
    for entry in cand_list:
        symbols, freq = list(entry.items())[0]
        # Words without an adjacent occurrence of the chosen pair pass through as-is.
        contains = any(symbols[pos:pos + span] == chosen for pos in range(len(symbols) - 1))
        if not contains:
            result.append({symbols: freq})
            continue

        rebuilt = []
        pos = 0
        while pos < len(symbols):
            if symbols[pos:pos + span] == chosen:
                rebuilt.append("".join(chosen))
                pos += span
            else:
                rebuilt.append(symbols[pos])
                pos += 1
        result.append({tuple(rebuilt): freq})
    return result, chosen

if __name__ == '__main__':
    # Number of merge rounds to run (N in the walk-through above).
    num_merges = 6

    vocab = init_vocab()
    pre_tokenized = pre_tokenization(text)

    # Each word becomes a tuple of single characters paired with its frequency,
    # e.g. {("l", "o", "w"): 5}. (Previously misleadingly named `merges`.)
    word_counts = [{tuple(word): count} for word, count in pre_tokenized.items()]

    new_tokens = []
    for _ in range(num_merges):
        pair_cnt = pair_count(word_counts)
        if not pair_cnt:
            # Every word is already a single symbol: nothing left to merge, and
            # merge() would raise ValueError on max() of an empty dict.
            break
        word_counts, new_token = merge(pair_cnt, word_counts)
        new_tokens.append(new_token)
        print(word_counts)
    print(new_tokens)
    new_tokens = ["".join(item) for item in new_tokens]
    print(new_tokens)
    # Append each learned token with the next free id.
    for token in new_tokens:
        vocab[token] = len(vocab)
    print(vocab)

图 1 为相应的运行结果,代码会打印每一轮合并后的中间结果。

bpe-training-example
图1:BPE Training Example 运行结果

BPE Tokenizer Training (More Efficient Manner)

对于 "Problem (train_bpe): BPE Tokenizer Training (15 points)",在 adapter.py 的 run_bpe_tokenizer() 中引入 bpe_tokenizer.py 中的相关函数,具体实现如下:

import regex as re
import os
from multiprocessing import Pool # 多线程
from typing import BinaryIO
from collections import defaultdict
import time

# Split a large corpus into chunks whose boundaries sit on a special token, so
# the chunks can be pre-tokenized independently in parallel.
def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.

    Boundaries are first spaced uniformly, then every interior boundary is
    moved forward to the first occurrence of ``split_special_token`` at or
    after its initial guess (or to EOF if there is none). The search reads
    4 KiB at a time and carries the last ``len(split_special_token) - 1``
    bytes of each read into the next one, so a token that straddles two
    reads is still found.
    May return fewer chunks if the boundaries end up overlapping.

    Args:
        file: binary file object opened for reading; its position is moved.
        desired_num_chunks: target number of chunks (must be >= 1).
        split_special_token: the special token to align boundaries to, as bytes.

    Returns:
        Sorted, de-duplicated byte offsets starting with 0 and ending with the
        file size; consecutive offsets delimit one chunk.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time
    # Bytes kept from the previous read so a token split across reads is matched.
    overlap = len(split_special_token) - 1

    for bi in range(1, len(chunk_boundaries) - 1):
        read_pos = chunk_boundaries[bi]  # file offset of the next read
        file.seek(read_pos)  # Start at boundary guess
        carry = b""
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if not mini_chunk:
                chunk_boundaries[bi] = file_size
                break

            # window starts len(carry) bytes before read_pos in the file
            window = carry + mini_chunk
            found_at = window.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = read_pos - len(carry) + found_at
                break
            # overlap == 0 (empty token) would make window[-0:] the whole window
            carry = window[-overlap:] if overlap > 0 else b""
            read_pos += len(mini_chunk)

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

# Pre-process one chunk before BPE training: drop special tokens and pre-tokenize.
def process_chunk(args: tuple[str, int, int, list[str]]) -> list[list[bytes]]:
    """
    Pre-tokenize one byte range of the input file.

    The chunk is decoded as UTF-8 (undecodable bytes are dropped) and split on
    the special tokens, which are discarded, so no pre-token spans a special
    token. Each remaining piece is pre-tokenized with the regex ``PAT``.

    Args:
        args: one tuple ``(input_path, start, end, special_tokens)``, packed so
            the function can be passed directly to ``Pool.map``:
            input_path (str): the path of the input file
            start (int): the start byte offset of the chunk
            end (int): the end byte offset of the chunk (exclusive)
            special_tokens (list[str]): special tokens to split on; they never
                appear in the output and so never take part in merges

    Returns:
        pre_tokens_bytes (list[list[bytes]])
            one entry per pre-token occurrence (duplicates are kept), each a
            list of single-byte ``bytes`` objects
    """
    input_path, start, end, special_tokens = args

    with open(input_path, "rb") as file:
        file.seek(start)
        chunk = file.read(end - start).decode("utf-8", errors="ignore")

    # 1. Remove special tokens.
    #    Longest tokens first, so a special token that is a prefix of another
    #    does not split the longer one. With no special tokens the alternation
    #    would be the empty pattern, which matches at every position, so the
    #    whole chunk is kept as a single document instead.
    if special_tokens:
        ordered = sorted(special_tokens, key=len, reverse=True)
        pattern = "|".join(re.escape(tok) for tok in ordered)
        documents = re.split(pattern, chunk)
    else:
        documents = [chunk]

    # 2. Pre-tokenize each document and explode every pre-token into single bytes.
    pre_tokens_bytes: list[list[bytes]] = []
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    for doc in documents:
        for match in re.finditer(PAT, doc):
            token = match.group(0).encode("utf-8")
            pre_tokens_bytes.append([bytes([b]) for b in token])

    return pre_tokens_bytes

def train_bpe(
        input_path: str,
        vocab_size: int,
        special_tokens: list[str],
        num_processes: int = 16
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """
    Train a byte-level BPE tokenizer on the given corpus.

    Pre-tokenization runs in a process pool over chunks produced by
    ``find_chunk_boundaries``. Identical pre-tokens are then collapsed into one
    entry with an occurrence count, so each merge step touches every distinct
    pre-token once instead of once per occurrence. The vocab and merges are
    the same as when every occurrence is tracked separately.

    :param input_path: str, training corpus path
    :param vocab_size: int, maximum size of the final vocabulary
        (256 bytes + special tokens + merges)
    :param special_tokens: list[str], list of special tokens
        e.g. ["<|endoftext|>"]; added directly to the vocabulary and never merged
    :param num_processes: int, optional (default=16); number of worker
        processes and desired number of chunks
    :return:
        vocab: a dictionary mapping token IDs to token bytes
            dict[int, bytes], e.g. {97: b"a", ...}
        merges: merged pairs in the order they were created
            list[tuple[bytes, bytes]]
    """
    start_time = time.time()

    # 1. Initialize vocabulary: the 256 single bytes, then each special token
    #    at the next free id.
    vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
    for token in special_tokens:
        vocab[len(vocab)] = token.encode("utf-8")

    # 2. Pre-tokenization. Chunks are cut at "<|endoftext|>" so they can be
    #    pre-tokenized independently in a process pool (separate processes,
    #    not threads).
    with open(input_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, num_processes, "<|endoftext|>".encode("utf-8"))
    # consecutive boundaries form the (start, end) of each chunk
    task_args = [(input_path, start, end, special_tokens) for start, end in zip(boundaries[:-1], boundaries[1:])]
    with Pool(processes=num_processes) as pool:
        chunk_results = pool.map(process_chunk, task_args)

    # 3. Collapse duplicate pre-tokens: distinct word -> number of occurrences.
    word_freq: dict[tuple[bytes, ...], int] = defaultdict(int)
    for chunk in chunk_results:
        for token in chunk:
            word_freq[tuple(token)] += 1
    del chunk_results  # per-occurrence lists are no longer needed
    words: list[list[bytes]] = [list(word) for word in word_freq]
    freqs: list[int] = list(word_freq.values())
    del word_freq

    # 4. Initial pair statistics.
    #    counts[pair]          -> occurrences of pair, weighted by word frequency
    #    pair_to_indices[pair] -> indices into `words` of words containing pair
    #                             (an inverted index, like docid postings)
    counts: dict[tuple[bytes, bytes], int] = defaultdict(int)
    pair_to_indices: dict[tuple[bytes, bytes], set[int]] = defaultdict(set)
    for j, word in enumerate(words):
        freq = freqs[j]
        for pair in zip(word, word[1:]):
            counts[pair] += freq
            pair_to_indices[pair].add(j)

    # 5. Merge loop.
    merges: list[tuple[bytes, bytes]] = []
    next_id = len(vocab)
    while next_id < vocab_size and counts:
        # Most frequent pair; ties go to the lexicographically larger pair.
        max_pair = max(counts, key=lambda pair: (counts[pair], pair))
        merges.append(max_pair)
        a, b = max_pair
        new_token = a + b
        vocab[next_id] = new_token
        next_id += 1

        # Only words containing max_pair change. Its index set is popped up
        # front (and iterated as a detached set), since after this step no
        # word contains max_pair any more.
        for j in pair_to_indices.pop(max_pair, set()):
            word = words[j]
            freq = freqs[j]

            # Retract all of this word's old pairs from the statistics.
            for pair in zip(word, word[1:]):
                counts[pair] -= freq
                if counts[pair] <= 0:
                    del counts[pair]
                # .get() so an already-removed index set is not recreated empty
                indices = pair_to_indices.get(pair)
                if indices is not None:
                    indices.discard(j)
                    if not indices:
                        del pair_to_indices[pair]

            # Replace each non-overlapping (a, b), scanning left to right.
            merged: list[bytes] = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == a and word[i + 1] == b:
                    merged.append(new_token)
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            words[j] = merged

            # Add back the pairs of the merged word.
            for pair in zip(merged, merged[1:]):
                counts[pair] += freq
                pair_to_indices[pair].add(j)

    end_time = time.time()
    print(f"BPE training finished in {end_time - start_time:.2f} seconds")
    return vocab, merges

上述代码处理 TinyStories 数据集耗时约 3.2 小时(11480.69 s),需要优化;处理 OWT 数据集则需要更长的时间。如何加速?

TODO List

  • 分析BPE Tokenizer速度慢的原因,各环节耗时
  • 设计方式加速BPE Tokenizer

Reference

[1] CS336 Assignment 1 (basics): Building a Transformer LM. Version 1.0.5

[2] Stanford CS336 | Assignment 1 - BPE Tokenizer Training 实现

Contact

There may be some errors present. If you find any, please feel free to contact me at wangqiyao@mail.dlut.edu.cn. I would appreciate it!