1138 字
6 分钟

CS336 1/?

2026-03-22
浏览量 加载中...

BPE 算法实现笔记#

从零实现 Byte Pair Encoding,记录踩坑与核心洞察。


什么是 BPE?#

Byte Pair Encoding(字节对编码)是构建 LLM 词表的核心算法。

核心思想只有一句话:

反复找出语料中最频繁的相邻 token 对,将其合并为新 token,直到词表大小达到目标。

它解决了一个平衡问题:

粒度优点缺点
纯字符级词表小序列太长,语义弱
纯词级语义强词表爆炸,OOV 问题
BPE两者平衡

第一步:预分词(Pretokenization)#

BPE 训练之前,需要先把文本切成「词」单元。

为什么需要这一步?#

如果不预分词,BPE 会跨越词边界合并,产生没有语言意义的 token:

  • "end\nthe" 里的 d\n 会被合并 → 无意义
  • " the"(带空格)和 "the" 语义不同,却可能被当作同一单元处理

GPT-2 的预分词 Pattern#

⚠️ 必须用 tiktoken 实际使用的 pattern,而不是网上流传的版本——两者对换行符处理不同,会导致 merge 顺序偏差。

import regex
GPT2_PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"""
pat = regex.compile(GPT2_PAT)

必须用 regex 库(非标准库 re),因为需要 \p{L}\p{N} 等 Unicode 属性支持。

关键规则:空格属于后面那个词" the" 是一个完整单元。

处理 Special Tokens#

Special tokens(如 <|endoftext|>)不参与 BPE 训练,需要先切分文本:

# 长的先匹配,防止短的截断长的
special_tokens = sorted(special_tokens, key=len, reverse=True)
# regex.escape 防止 <| 等字符被解释为正则元字符
split_pat = "|".join(f"(?:{regex.escape(st)})" for st in special_tokens)
chunks = regex.split(split_pat, text)

顺序:文本按 special tokens 切分对每个 chunk 预分词word_freqs


第二步:数据结构#

# 每个预分词单元(字节 tuple)的出现频次
word_freqs: dict[tuple[bytes, ...], int]
# 所有相邻 pair 的总频次
from collections import Counter
pair_counts: Counter  # (bytes, bytes) -> int

初始化 vocab#

BPE 从最小单位出发——所有可能的单字节(256 个):

vocab = {i: bytes([i]) for i in range(256)}
for st in special_tokens:
   vocab[len(vocab)] = st.encode("utf-8")

初始化 pair_counts#

for word, freq in word_freqs.items():
   for i in range(len(word) - 1):
       pair_counts[(word[i], word[i + 1])] += freq

第三步:主循环#

重复 (vocab_size - 初始vocab大小) 次:
1. 找最频繁的 pair
2. 记录到 merges
3. 加入 vocab
4. 更新 word_freqs 和 pair_counts

Tiebreak 规则#

频次相同时,选字典序更大的 pair:

best_pair = max(
  (p for p in pair_counts if pair_counts[p] > 0),
   key=lambda p: (pair_counts[p], p)
)

核心优化:增量更新 pair_counts#

朴素做法(慢)#

每次 merge 后重新扫描全部词,重建 Counter → O(n) per merge → 太慢。

高效做法#

每次 merge 之后,只有被合并 token 的邻居 pair 会变化,其余不变。

当把 (a, b) 合并成 ab 时,对词 ... x a b y ...

合并前:... x a b y ...
合并后:... x ab   y ...
操作pair原因
减少(x, a)x 的右邻居从 a 变成了 ab
减少(b, y)y 的左邻居从 b 变成了 ab
增加(x, ab)新搭档出现
增加(ab, y)新搭档出现

边界条件:x 不存在(a 在词首)或 y 不存在(b 在词尾)时跳过对应更新。

if i > 0:
   pair_counts[(new_word[-1], a)] -= freq   # 用 new_word[-1] 取左邻居!
   pair_counts[(new_word[-1], merged)] += freq
new_word.append(merged)
if i + 2 < len(word):
   pair_counts[(b, word[i + 2])] -= freq
   pair_counts[(merged, word[i + 2])] += freq

⚠️ 踩坑:左邻居索引#

取左邻居时必须用 new_word[-1]不能用 word[i-1]

原因:词 (a, b, a, b) 中,第二个 (a, b) 的左邻居已经是 merged,而不是原词里的 b。用 word[i-1] 会指向错误的 token。

⚠️ 踩坑:****(a, b) 自身计数#

merge 后要立即清除 (a, b) 的计数,否则下次循环可能重复选同一个 pair:

a, b = best_pair
del pair_counts[best_pair]  # 在更新 word_freqs 之前

⚠️ 踩坑:迭代时修改字典#

不能在遍历 word_freqs.items() 的同时修改它,用 to_update 收集变化,循环后统一更新:

to_update = {}
for word, freq in word_freqs.items():
   # ... 构建 new_word ...
   to_update[word] = (tuple(new_word), freq)
for old_word, (new_word, freq) in to_update.items():
   del word_freqs[old_word]
   word_freqs[new_word] = word_freqs.get(new_word, 0) + freq
   # 用 .get(..., 0) 防止 new_word 原先就存在

完整流程图#

输入文本
按 special tokens 切分 → chunks
GPT-2 regex 预分词 → word_freqs
初始化 vocab(256字节 + special tokens)
初始化 pair_counts
┌─────────────────────────────┐
│ 找最频繁 pair (tiebreak:大) │
│ → 加入 merges, vocab       │
│ → 增量更新 pair_counts     │
│ → 更新 word_freqs         │
└──────────────┬──────────────┘
              │ 重复直到 vocab_size
             
        (vocab, merges)

测试要求#

  • corpus.en,vocab_size=500:< 1.5 秒完成
  • merges 顺序必须与 GPT-2 参考实现完全一致
  • special tokens 不出现在任何 BPE merge 结果中

文章分享

如果这篇文章对你有帮助,欢迎分享给更多人!

CS336 1/?
https://printsdf.dpdns.org/posts/cs336-1/
作者
printsdf
发布于
2026-03-22
许可协议
CC BY-NC-SA 4.0
最后更新于 2026-03-22,距今已过 33 天

部分内容可能已过时

评论区

Profile Image of the Author
printsdf
Hello, I'm printsdf.
公告
欢迎来到我的博客!这是一则示例公告。
音乐
封面

音乐

暂未播放

0:00 0:00
暂无歌词
分类
标签
站点统计
文章
37
分类
12
标签
14
总字数
47,088
运行时长
0
最后活动
0 天前

目录