Class 2 BPE Tokenizer:Encoder & Decoder

在第一节课中我们训练了一个输入文本,输出为token词表和BPE合并项列表的BPE Tokenizer。在这第二节课中,我们将为它增添更多的功能,使其能够接收上面的词表与合并项列表,并使用它们实现文本和token ID的编解码。

编码

对文本编码的过程其实和我们之前训练BPE分词器的过程是类似的。主要分为以下两个步骤:

  1. 预分词

就像我们之前做的一样,首先对文本序列进行预分词,并用UTF-8编码表示分词序列。接下来我们将各个预分词的“token碎片”合并为词表中有的单词。

注意:在合并时,不会在两个预分词之间进行合并,不然到时候解码会造成语义的丢失。

  1. 应用合并

词表中的每个元素都有一个编号,这使得我们可以仅保存编号而不需要保存实际的文本。当我们对输入文本进行预分词后,需要去词表中找到对应的单词编号,从而实现文本的编号化。

此外还有几点需要注意,Tokenizer应该能够正确处理特殊token;并且由于内存限制的原因,程序也需要分段读写,在这期间不能跨段分开处理token,不然会导致编码上的问题。

解码

解码的过程其实就是编码的逆过程。可以通过将编号对应的词表单词重新组合,并转换这些字节为Unicode字符来实现。

但是,要特别注意的是用户可能会随意输入任意的整数ID序列,这部分序列并不能转换为一个合法的Unicode字符。在这种情况下,需要将这种序列替换为Unicode官方规定的替换字符U+FFFD。同时,bytes.decode有一个error字段,用于描述Unicode解码操作时出错应该如何处理,当我们使用error= 'replace'时,程序就会自动用替换字符替换错误的序列。

实验:BPE分词器解码&编码实现

这部分的实验代码需要在adapters.get_tokenizer()中实现。为了方便还是像之前一样,新建一个文件BPE_Tokenizer.py,然后在里面声明类和待实现的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from typing import Iterable, Iterator

class BPE_Tokenizer:
def __init__(self,vocab: dict[int, bytes],
merges: list[tuple[bytes, bytes]],
special_tokens: list[str] | None = None):
*"""*
* Construct a tokenizer from a given vocabulary, list of merges, and (optionally) a list of special tokens.*
* Args:*
* vocab: dict[int, bytes]*
* merges: list[tuple[bytes, bytes]]*
* special_tokens: list[str] | None = None*
* """*
* *pass

@classmethod
def from_files(cls, vocab_filepath: str, merges_filepath: str, special_tokens=None):
*"""*
* Class method that constructs and return a Tokenizer from a serialized vocabulary and list of merges (in the same format that your BPE training code output) and (optionally) a list of special tokens.*
* Args:*
* vocab_filepath: str*
* merges_filepath: str*
* special_tokens: list[str] | None = None*
* """*
* *pass

@staticmethod
def encode(self, text: str) -> list[int]:
*"""*
* Encode an input text into a sequence of token IDs.*
* """*
* *pass

def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
*"""*
* Given an iterable of strings (e.g., a Python file handle), return a generator that lazily yields token IDs. This is required for memory-efficient tokenization of large files that we cannot directly load into memory.*
* """*
* *pass

def decode(self, ids: list[int]) -> str:
*"""*
* Decode a sequence of token IDs into text.*
* """*

from_files()

由于这个函数只涉及基本的外部文件读取操作,可以首先实现。注意这里的读取格式是pkl,如果没有其他意外这个课程的数据都将用pickle进行外部保存和读取。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@classmethod
def from_files(cls, vocab_filepath: str, merges_filepath: str, special_tokens=None):
*"""*
* Class method that constructs and return a Tokenizer from a serialized vocabulary and list of merges (in the same format that your BPE training code output) and (optionally) a list of special tokens.*
* Args:*
* vocab_filepath: str*
* merges_filepath: str*
* special_tokens: list[str] | None = None*
* """*
* *with open(vocab_filepath,"rb") as vf, open(merges_filepath,"rb") as mf:
vocab: dict[int, bytes] = pickle.load(vf)
merges: list[tuple[bytes, bytes]] = pickle.load(mf)
size = len(vocab)
for token in special_tokens:
if token.encode("utf-8") not in vocab.values():
vocab[size] = bytes(token.encode("utf-8"))
size += 1
return cls(vocab, merges, special_tokens=special_tokens)
return cls(vocab, merges, special_tokens=special_tokens)

encode_iterable()

为了防止内存爆炸,这个函数实现了流式编码的效果,可以在还没有实现encode()的情况下先实现这个函数。

1
2
3
4
5
6
def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
"""
Given an iterable of strings (e.g., a Python file handle), return a generator that lazily yields token IDs. This is required for memory-efficient tokenization of large files that we cannot directly load into memory.
"""
for chunk in iterable:
yield from self.encode(chunk)

init()

先来考虑一下编码的全过程:假设现在输入文本是the cat ate,词表vocab{0: b' ', 1: b'a', 2: b'c', 3: b'e', 4: b'h', 5: b't', 6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b' at'},合并项列表merges[(b't', b'h'), (b' ', b'c'), (b' ', b'a'), (b'th', b'e'), (b' a', b't')]

首先,预分词器会把这段文本划分为['the', ' cat', ' ate']的token列表,接下来会对每个token进行合并操作。以the为例,首先表示为[b't', b'h', b'e'],然后查阅merges,得知应当使用(b't', b'h')进行合并。这时候表示更新为[b'th', b'e']。再次查阅merges,可以找到对应的合并,表示更新为[b'the']。这时候已经无法再进行更大的合并了,说明我们的合并已经完成。

接下来去词表vocab查找到对应项的编号为[9],这样就完成一次token的编号化。以此类推,输入文本the cat ate的编号化就是[9, 7, 1, 5, 10, 3]

这里有几个问题需要注意:

  1. 查阅merges的时候,如何确定用什么进行合并?

  2. 普通文本和special_token之间的合并顺序谁优先?

其实这两个问题的回答可以统一考虑。为了能够减小合并粒度,达到更加精确的编码效果,要实现下面三个需求:

  1. 要按字典序查找merges内的合并对,并匹配最先找到的能够合并的合并对。

  2. 同样的,special_token和普通文本应该分开处理

  3. 并且在处理special_token时,应当优先匹配较短的special_token(其实这个用字典序也可以解决)。

这样我们就可以编写__init__函数了,同时可以直接实现第3个需求。注意我们还创建了一个倒排索引bytes_to_id用于快速找到词表id。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def __init__(self, vocab: dict[int, bytes],
merges: list[tuple[bytes, bytes]],
special_tokens: list[str] | None = None):
"""
Construct a tokenizer from a given vocabulary, list of merges, and (optionally) a list of special tokens.
Args:
vocab: dict[int, bytes]
merges: list[tuple[bytes, bytes]]
special_tokens: list[str] | None = None
"""
self.vocab = vocab
self.merges = merges
self.special_tokens = special_tokens or []
# 对special_token进行排序,确保较短的先匹配上
self.special_tokens = sorted(self.special_tokens, key=lambda x: len(x), reverse=True)
# 构建vocab的倒排索引
self.bytes_to_id = {v: k for k, v in self.vocab.items()}
# 构建merge的pairs合并优先级倒排索引,使查找效率达到O(1)
self.merge_ranks = {pair: rank for rank, pair in enumerate(self.merges)}

split_by_special()

可以发现现在special_token和一般的token实际上已经是同等地位了,只是需要分别对待。因此,在分词时我们希望能够首先区分出special_token和一般的token,这就是这个函数的作用。

1
2
3
4
5
6
7
8
9
10
11
@staticmethod
def split_by_special(text, special_tokens, drop_special=True):
if not special_tokens:
return [text]

pattern = "|".join(re.escape(token) for token in special_tokens)
if not drop_special: pattern = f"({pattern})"

pattern = re.compile(pattern)
chunks = pattern.split(text)
return [c for c in chunks if c] # 丢掉空字符串

这个函数关键的一点就是会根据传入的drop_special来选择是否最后会丢掉special_token,举个例子:

1
2
3
4
5
text = "Hello<BOS>world<EOS>!"
special_tokens = ["<BOS>", "<EOS>"]

drop_special = True 结果: ['Hello', 'world', '!']
drop_special = False 结果: ['Hello', '<BOS>', 'world', '<EOS>', '!']

encode()

可以在这个函数中实现第2个需求。其实这个过程是比较容易理解的,见下面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def encode(self, text: str) -> list[int]:
"""
Encode an input text into a sequence of token IDs.
"""
if not text:
return []
pattern = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
token_ids = []

# 首先进行token分词
blocks = BPE_Tokenizer.split_by_special(text, special_tokens=self.special_tokens, drop_special=False)
for block in blocks:
if self.special_tokens and block in self.special_tokens:
# 是special_token,直接找词表转换为id
block_bytes = bytes(block.encode("utf-8"))
token_ids.append(self.bytes_to_id[block_bytes])
else:
# 是一般token,进行分词
tokens: list[str] = []
for match in re.finditer(pattern, block):
tokens.append(match.group(0))
for token in tokens:
token_bytes = [bytes([b]) for b in list(token.encode("utf-8"))]
bytes_list = self.get_best_merge(token_bytes)
for b in bytes_list:
token_ids.append(self.bytes_to_id[b])

return token_ids

还是以过程讲方法:假设现在输入的textHello<|endoftext|>How are you

通过调用split_by_special可以将special_tokentext中单独区分出来。现在的blocks就是["Hello","<|endoftext|>","How are you"]

然后对这个列表里面的每一项进行处理。如果是special_token那么就直接从词表中获取id。如果是普通文本,再进行普通文本的分词,结果放在tokens中,此时内容是["Hello","How"," are"," you"]

接下来需要对tokens的每一项token分别处理。首先是将其编码为UTF-8后转为bytes类型,得到一个列表bytes_list。再调用get_best_merge()函数对这个列表进行合并,才能从词表中找到对应的id。

最后将整段text对应的token_ids返回,整个编码过程就结束啦~

get_best_merge()

这个函数就负责上面encode()函数的合并列表功能,可以算作是最关键的函数了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def get_best_merge(self, token_bytes: list[bytes]) -> list[bytes]:
while len(token_bytes) >= 2:
best_rank = float('inf')
best_pair_id = -1
for i in range(len(token_bytes) - 1):
pair = (token_bytes[i], token_bytes[i + 1])
rank = self.merge_ranks.get(pair)
if rank is not None and rank < best_rank:
best_rank = rank
best_pair_id = i
# 如果这次循环没有找到可以合并的,说明已经无法再合并下去
if best_pair_id == -1:
break
# 合并best_pair_id对应的两个token
best_merge = token_bytes[best_pair_id] + token_bytes[best_pair_id + 1]
# 别忘了更新token_bytes
token_bytes = (
token_bytes[:best_pair_id] +
[best_merge] +
token_bytes[best_pair_id + 2:]
)
return token_bytes

其实主要的功能看了代码就明白了,是个反复穷举查找merges的过程。我们在这个函数里面实现了第1个需求,就是不要忘记最后更新token_bytes

decode()

解码函数其实就很容易实现了,不说了直接上代码吧

1
2
3
4
5
def decode(self, ids: list[int]) -> str:
"""
Decode a sequence of token IDs into text.
"""
return b''.join([self.vocab[t] for t in ids]).decode('utf-8',errors='replace')

可以看到还完成了文档的replace要求,其实通过传参就可以实现了。

test_tokenizer

在终端启动测试

1
uv run pytest tests/test_tokenizer.py

结果如下:

imagepng

test_encode_iterable_memory_usage这个测试耗费的时间有点长(大概8s),测试的大量时间耗在上面。另外test_encode_memory_usage结果为XFAIL 是正常的,因为要求里就写着“Tokenizer.encode is expected to take more memory than allotted (1MB).”,说明你没有作弊(

总结

在这次的课程中,我们实现了一个能够对文本进行编码和解码的BPE Tokenizer,并且成功通过了所有测试。这仅仅只是第一步,我们还没有涉及到机器学习的皮毛。下节课开始我们将会对transformer的结构进行了解并初步实现transformer的linear module和embedding module。