https://github.com/karpathy/minbpe
BPE (Byte Pair Encoding) started out as a data compression technique and was later widely adopted in natural language processing (NLP), especially for text preprocessing and building subword units. In NLP, BPE mitigates the data sparsity caused by very large vocabularies and gives a principled way to handle unknown or rare words. Its core idea: repeatedly merge the most frequent pair of adjacent bytes (for text, think adjacent characters) into a single new unit.
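To make this concrete, here is a minimal standalone sketch (standard library only; not from minbpe) that finds the most frequent adjacent byte pair in a string:

from collections import Counter

data = list("aaabdaaabac".encode("utf-8"))  # bytes as ints: [97, 97, 97, 98, 100, ...]
pair_counts = Counter(zip(data, data[1:]))  # count every adjacent pair
print(pair_counts.most_common(1))           # [((97, 97), 4)], i.e. "aa" occurs most often

BPE would replace (97, 97) everywhere with a new id and then repeat.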
The whole algorithm rests on two helpers: get_stats counts how often each consecutive pair occurs, and merge replaces every occurrence of a given pair with a new token id (idx).
def get_stats(ids):
    """
    Given a list of integers, return a dictionary of counts of consecutive pairs
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    """
    counts = {}
    for pair in zip(ids, ids[1:]):  # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts
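A quick check of get_stats against its docstring example (a sketch, assuming the function above is in scope):

print(get_stats([1, 2, 3, 1, 2]))  # {(1, 2): 2, (2, 3): 1, (3, 1): 1}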
def merge(ids, pair, idx):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    newids = []
    i = 0
    while i < len(ids):
        # if not at the very last position AND the pair matches, replace it
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
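And the corresponding check for merge:

print(merge([1, 2, 3, 1, 2], (1, 2), 4))  # [4, 3, 4]

Note the scan is greedy left-to-right: once a pair is consumed, i jumps past it (i += 2), so overlapping matches are not double-merged.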
train converts the input text into a sequence of UTF-8 bytes and then merges repeatedly: each round finds the most frequent pair and replaces it everywhere with a newly minted idx. After the requested number of rounds, the merge rules (merges) and the id-to-bytes table (vocab) are saved on the tokenizer.
class BasicTokenizer:

    def __init__(self):
        # by default, we have a vocab size of 256 (all bytes) and no merges
        self.merges = {}
        self.vocab = {idx: bytes([idx]) for idx in range(256)}

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # input text preprocessing
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255

        # iteratively merge the most common pairs to create new tokens
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            stats = get_stats(ids)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = merge(ids, pair, idx)
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()
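For example (a usage sketch), training on the string from the test at the bottom with 3 merges produces the verbose output shown as comments:

tok = BasicTokenizer()
tok.train("aaabdaaabac", vocab_size=256 + 3, verbose=True)
# merge 1/3: (97, 97) -> 256 (b'aa') had 4 occurrences
# merge 2/3: (256, 97) -> 257 (b'aaa') had 2 occurrences
# merge 3/3: (257, 98) -> 258 (b'aaab') had 2 occurrences

Note how merge 2 pairs the freshly minted token 256 with a raw byte: new tokens immediately participate in later merges.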
decode maps the merged token ids back into a Python string: look up each id's bytes in vocab, concatenate them, and decode as UTF-8. errors="replace" keeps this from crashing when the ids happen to split a multi-byte character.
    def decode(self, ids):
        # given ids (list of integers), return Python string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text
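A small sanity check (a sketch; assumes the class above is defined): even an untrained tokenizer can decode raw byte ids, and errors="replace" turns invalid UTF-8 into U+FFFD instead of raising:

tok = BasicTokenizer()
print(tok.decode([104, 105]))  # "hi" (104 = b'h', 105 = b'i')
print(tok.decode([228]))       # "\ufffd", since the lone byte 0xe4 is invalid UTF-8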
    def encode(self, text):
        # given a string text, return the token ids
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
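Putting the pieces together (a usage sketch): after training, encode applies the learned merges in the order they were learned, and unseen text simply falls back to raw bytes:

tok = BasicTokenizer()
tok.train("aaabdaaabac", 256 + 3)
print(tok.encode("aaa"))              # [257]: (97, 97) -> 256 first, then (256, 97) -> 257
print(tok.encode("xyz"))              # [120, 121, 122]: no merges apply
print(tok.decode(tok.encode("aaa")))  # "aaa": round-trip holds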
if __name__ == "__main__":
    """
    Quick unit test, following along the Wikipedia example:
    https://en.wikipedia.org/wiki/Byte_pair_encoding
    According to Wikipedia, running bpe on the input string:
    "aaabdaaabac"
    for 3 merges will result in string:
    "XdXac"
    where:
    X=ZY
    Y=ab
    Z=aa
    Keep in mind that for us a=97, b=98, c=99, d=100 (ASCII values),
    so Z will be 256 and X will be 258. (Subtlety: this implementation
    breaks the tie at the second merge differently from Wikipedia, so
    257 ends up as b'aaa' rather than b'ab'; 258 is still b'aaab', and
    the final ids agree.)
    So we expect the output list of ids to be [258, 100, 258, 97, 99]
    """
text = "aaabdaaabac"
tokenizer = BasicTokenizer()
# we do 3 merges
tokenizer.train(text, 256 + 3)
# verify the correct expected result
ids = tokenizer.encode(text)
print("OK" if ids == [258, 100, 258, 97, 99] else "FAIL")
# verify that decode(encode(x)) == x
print("OK" if tokenizer.decode(tokenizer.encode(text)) == text else "FAIL")