CS336 1/?
BPE 算法实现笔记
从零实现 Byte Pair Encoding,记录踩坑与核心洞察。
什么是 BPE?
Byte Pair Encoding(字节对编码)是构建 LLM 词表的核心算法。
核心思想只有一句话:
反复找出语料中最频繁的相邻 token 对,将其合并为新 token,直到词表大小达到目标。
它解决了一个平衡问题:
| 粒度 | 优点 | 缺点 |
|---|---|---|
| 纯字符级 | 词表小 | 序列太长,语义弱 |
| 纯词级 | 语义强 | 词表爆炸,OOV 问题 |
| BPE | 两者平衡 | — |
第一步:预分词(Pretokenization)
BPE 训练之前,需要先把文本切成「词」单元。
为什么需要这一步?
如果不预分词,BPE 会跨越词边界合并,产生没有语言意义的 token:
"end\nthe"里的d\n会被合并 → 无意义" the"(带空格)和"the"语义不同,却可能被当作同一单元处理
GPT-2 的预分词 Pattern
⚠️ 必须用
tiktoken实际使用的 pattern,而不是网上流传的版本——两者对换行符处理不同,会导致 merge 顺序偏差。
import regex
GPT2_PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""pat = regex.compile(GPT2_PAT)必须用
regex库(非标准库re),因为需要\p{L}、\p{N}等 Unicode 属性支持。
关键规则:空格属于后面那个词," the" 是一个完整单元。
处理 Special Tokens
Special tokens(如 <|endoftext|>)不参与 BPE 训练,需要先切分文本:
# 长的先匹配,防止短的截断长的special_tokens = sorted(special_tokens, key=len, reverse=True)
# regex.escape 防止 <| 等字符被解释为正则元字符split_pat = "|".join(f"(?:{regex.escape(st)})" for st in special_tokens)chunks = regex.split(split_pat, text)顺序:文本 → 按 special tokens 切分 → 对每个 chunk 预分词 → word_freqs
第二步:数据结构
# 每个预分词单元(字节 tuple)的出现频次word_freqs: dict[tuple[bytes, ...], int]
# 所有相邻 pair 的总频次from collections import Counterpair_counts: Counter # (bytes, bytes) -> int初始化 vocab
BPE 从最小单位出发——所有可能的单字节(256 个):
vocab = {i: bytes([i]) for i in range(256)}for st in special_tokens: vocab[len(vocab)] = st.encode("utf-8")初始化 pair_counts
for word, freq in word_freqs.items(): for i in range(len(word) - 1): pair_counts[(word[i], word[i + 1])] += freq第三步:主循环
重复 (vocab_size - 初始vocab大小) 次: 1. 找最频繁的 pair 2. 记录到 merges 3. 加入 vocab 4. 更新 word_freqs 和 pair_countsTiebreak 规则
频次相同时,选字典序更大的 pair:
best_pair = max( (p for p in pair_counts if pair_counts[p] > 0), key=lambda p: (pair_counts[p], p))核心优化:增量更新 pair_counts
朴素做法(慢)
每次 merge 后重新扫描全部词,重建 Counter → O(n) per merge → 太慢。
高效做法
每次 merge 之后,只有被合并 token 的邻居 pair 会变化,其余不变。
当把 (a, b) 合并成 ab 时,对词 ... x a b y ...:
合并前:... x a b y ...合并后:... x ab y ...| 操作 | pair | 原因 |
|---|---|---|
| 减少 | (x, a) | x 的右邻居从 a 变成了 ab |
| 减少 | (b, y) | y 的左邻居从 b 变成了 ab |
| 增加 | (x, ab) | 新搭档出现 |
| 增加 | (ab, y) | 新搭档出现 |
边界条件:x 不存在(a 在词首)或 y 不存在(b 在词尾)时跳过对应更新。
if i > 0: pair_counts[(new_word[-1], a)] -= freq # 用 new_word[-1] 取左邻居! pair_counts[(new_word[-1], merged)] += freq
new_word.append(merged)
if i + 2 < len(word): pair_counts[(b, word[i + 2])] -= freq pair_counts[(merged, word[i + 2])] += freq⚠️ 踩坑:左邻居索引
取左邻居时必须用 new_word[-1],不能用 word[i-1]。
原因:词 (a, b, a, b) 中,第二个 (a, b) 的左邻居已经是 merged,而不是原词里的 b。用 word[i-1] 会指向错误的 token。
⚠️ 踩坑:****(a, b) 自身计数
merge 后要立即清除 (a, b) 的计数,否则下次循环可能重复选同一个 pair:
a, b = best_pairdel pair_counts[best_pair] # 在更新 word_freqs 之前⚠️ 踩坑:迭代时修改字典
不能在遍历 word_freqs.items() 的同时修改它,用 to_update 收集变化,循环后统一更新:
to_update = {}for word, freq in word_freqs.items(): # ... 构建 new_word ... to_update[word] = (tuple(new_word), freq)
for old_word, (new_word, freq) in to_update.items(): del word_freqs[old_word] word_freqs[new_word] = word_freqs.get(new_word, 0) + freq # 用 .get(..., 0) 防止 new_word 原先就存在完整流程图
输入文本 │ ▼按 special tokens 切分 → chunks │ ▼GPT-2 regex 预分词 → word_freqs │ ▼初始化 vocab(256字节 + special tokens)初始化 pair_counts │ ▼┌─────────────────────────────┐│ 找最频繁 pair (tiebreak:大) ││ → 加入 merges, vocab ││ → 增量更新 pair_counts ││ → 更新 word_freqs │└──────────────┬──────────────┘ │ 重复直到 vocab_size ▼ (vocab, merges)测试要求
corpus.en,vocab_size=500:< 1.5 秒完成- merges 顺序必须与 GPT-2 参考实现完全一致
- special tokens 不出现在任何 BPE merge 结果中
文章分享
如果这篇文章对你有帮助,欢迎分享给更多人!
部分内容可能已过时
printsdf's Blog