CS336 2/?
从零实现 BPE Tokenizer:CS336 实战总结
在斯坦福 CS336(大模型基础)课程中,有一个经典的动手作业:从头实现一个 BPE(Byte Pair Encoding)Tokenizer。这篇文章记录了完整的实现思路、核心数据结构,以及那些真实踩过的坑。
什么是 BPE Tokenizer?
BPE Tokenizer 的核心职责很简单:
- encode:把一段文本转换成 token ID 序列
- decode:把 token ID 序列还原回文本
看起来简单,但实现细节充满陷阱。
核心数据结构
vocab: dict[int, bytes] # ID → bytesmerges: list[tuple[bytes, bytes]] # 按创建顺序排列的合并规则reverse_vocab: dict[bytes, int] # bytes → ID(初始化时构建)三个结构各司其职:
vocab是主索引,给 decode 用merges记录了 BPE 训练出的所有合并规则,顺序至关重要reverse_vocab是vocab的反向查找表,初始化时一次性构建,encode 时高频使用
encode 流程
encode 分两个阶段处理:special tokens 和普通文本。
第一步:用 special tokens 切分文本
# 构建正则,special tokens 按长度降序排列sorted_specials = sorted(self.special_tokens, key=len, reverse=True)pattern = "(" + "|".join(re.escape(s) for s in sorted_specials) + ")"parts = regex.split(pattern, text)关键点 1:用捕获组 (...) 而不是非捕获组 (?:...)
regex.split 如果用非捕获组,分隔符本身会从结果中消失——special token 就丢了。用捕获组才能让 special token 出现在 split 结果里。
关键点 2:special tokens 必须按长度降序排列
如果有 <|im_start|> 和 <|im|> 两个 special token,短的不能优先匹配,否则长的永远匹配不到。排序后构建正则,优先尝试最长匹配。
关键点 3:用排序后的变量构建正则
这听起来是废话,但实际上很容易写成:
# 错误写法:忘了用排序后的变量pattern = "(" + "|".join(re.escape(s) for s in self.special_tokens) + ")"必须用 sorted_specials,不能用原始的 self.special_tokens。
第二步:分别处理每个片段
for part in parts: if part in self.special_tokens: ids.append(self.reverse_vocab[part.encode('utf-8')]) else: ids.extend(self._encode_chunk(part))- special token 片段:直接查
reverse_vocab - 普通文本片段:走
_encode_chunk
_encode_chunk:BPE 合并的核心
def _encode_chunk(self, text: str) -> list[int]: # 1. 预分词(用 GPT-2 / tiktoken 风格的 regex) words = pretokenize(text)
ids = [] for word in words: # 2. UTF-8 编码,转成 tuple[bytes, ...] tokens = tuple(bytes([b]) for b in word.encode('utf-8'))
# 3. 按顺序应用所有 merges for pair in self.merges: tokens = apply_merge(tokens, pair)
# 4. 查 reverse_vocab 得到 ID ids.extend(self.reverse_vocab[t] for t in tokens)
return ids关键点:字节迭代的陷阱
Python 中,对 bytes 对象直接迭代得到的是整数,不是单字节 bytes:
for b in "hello".encode('utf-8'): print(type(b)) # <class 'int'>,不是 bytes!所以必须用:
tuple(bytes([b]) for b in word.encode('utf-8'))而不是:
tuple(word.encode('utf-8')) # 得到整数 tuple,查 reverse_vocab 会 KeyErrordecode 流程
decode 相对简单:
def decode(self, ids: list[int]) -> str: return b''.join(self.vocab[i] for i in ids).decode('utf-8', errors='replace')先把每个 ID 映射回 bytes,拼接后统一做 UTF-8 解码。用 errors='replace' 处理边界处可能出现的不完整 UTF-8 序列。
踩坑总结
| 坑 | 原因 | 解法 |
|---|---|---|
| special token 从 split 结果消失 | 用了非捕获组 (?:...) | 改成捕获组 (...) |
| 长 special token 匹配失败 | 正则未排序,短 token 优先 | 按长度降序排列后构建正则 |
| 正则排序没生效 | 用了原始 self.special_tokens | 用排序后的 sorted_specials |
reverse_vocab 查找 KeyError | 字节迭代得到整数而非 bytes | 用 bytes([b]) 包装每个字节 |
小结
BPE Tokenizer 的实现不难,但细节密集。最容易出问题的地方集中在两处:
- 正则切分 special tokens — 捕获组 vs 非捕获组,以及排序问题
- Python 的字节类型行为 —
bytes迭代出整数这个反直觉的特性
把这些细节搞清楚之后,整个 tokenizer 的逻辑其实非常清晰。理解了这些,再去读 tiktoken 或 HuggingFace tokenizers 的源码,会有一种豁然开朗的感觉。
文章分享
如果这篇文章对你有帮助,欢迎分享给更多人!
printsdf's Blog