https://github.com/karpathy/minbpe

BPE(Byte Pair Encoding,字节对编码)算法原本是一种数据压缩技术,后来被广泛应用于自然语言处理(NLP),特别是文本预处理和子词(subword)单元的生成。在NLP领域,BPE算法用于缓解词表过大导致的数据稀疏问题,并处理未登录词或罕见词。BPE算法的核心思想是:反复把出现频率最高的相邻符号对合并成一个新的单元。这里的“符号”起初是单个字节(下面的实现直接在UTF-8字节上操作,因此一个中文字符通常对应3个字节,而不是一个符号),之后也可以是已经合并出来的单元。

核心操作只有两个:统计每个相邻pair出现的次数(get_stats),以及把序列中所有指定的pair替换为一个新的idx(merge)。

def get_stats(ids):
    """Count how often each pair of adjacent items occurs in ``ids``.

    Returns a dict mapping ``(left, right)`` tuples to their counts. Keys are
    inserted in the order the pairs first appear in ``ids``. An input with
    fewer than two items yields an empty dict.

    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    """
    pair_counts = {}
    for left, right in zip(ids[:-1], ids[1:]):
        key = (left, right)
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

def merge(ids, pair, idx):
    """Return a new list in which each occurrence of ``pair`` becomes ``idx``.

    Matching is greedy from left to right and non-overlapping, so
    ``merge([1, 1, 1], (1, 1), 9)`` gives ``[9, 1]``. ``ids`` itself is not
    modified.

    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    result = []
    size = len(ids)
    pos = 0
    while pos < size:
        has_next = pos + 1 < size
        if has_next and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            # both halves of the pair match: emit the new token, skip two
            result.append(idx)
            pos += 2
            continue
        result.append(ids[pos])
        pos += 1
    return result

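train 先把输入文本按 UTF-8 编码成字节序列(每个元素是 0..255 的整数),然后最多执行 vocab_size - 256 次 merge:每次统计所有相邻 pair 的出现次数,把出现次数最多的 pair 全部替换成一个新的 idx(从 256 开始依次递增)。训练的产物不是某个字节序列,而是两张表:merges(pair -> idx,encode 时使用)和 vocab(idx -> 对应的 bytes,decode 时使用)。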

class BasicTokenizer:
    """Byte-level BPE tokenizer without regex pre-splitting or special tokens.

    Attributes:
        merges: dict mapping a merged pair ``(int, int)`` to the new token id
            assigned to it by ``train``.
        vocab: dict mapping token id -> ``bytes``. Ids 0..255 are the single
            raw bytes; each id >= 256 is the concatenation of its merged pair.
    """

    def __init__(self):
        # By default the vocab holds the 256 single-byte tokens and there are
        # no merges.
        self.merges = {}
        self.vocab = {idx: bytes([idx]) for idx in range(256)}

    def train(self, text, vocab_size, verbose=False):
        """Learn BPE merges from ``text`` and store them on the tokenizer.

        Args:
            text: training text. It is encoded as UTF-8 and BPE runs on the
                resulting bytes.
            vocab_size: target vocabulary size; must be >= 256. At most
                ``vocab_size - 256`` merges are learned.
            verbose: if true, print every merge as it is created.

        Training stops early when fewer than two ids remain. Examples are an
        empty or one-byte ``text``, or a text that has already collapsed into
        a single token. In that case ``self.vocab`` ends up smaller than
        ``vocab_size``. Replaces ``self.merges`` and ``self.vocab``.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # input text preprocessing: raw UTF-8 bytes as ints in range 0..255
        ids = list(text.encode("utf-8"))

        # iteratively merge the most common pairs to create new tokens
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            stats = get_stats(ids)
            if not stats:
                # No adjacent pair is left. max() on an empty dict would raise
                # ValueError, so keep the merges learned so far.
                break
            # find the pair with the highest count (ties: first pair in ids)
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = merge(ids, pair, idx)
            # save the merge and the bytes the new token stands for
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges  # used in encode()
        self.vocab = vocab    # used in decode()

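decode 把 merge 后得到的 token id 序列还原为文本:先通过 vocab 把每个 id 映射回对应的 bytes 并拼接起来,再按 UTF-8 解码(遇到非法字节序列时用替换字符 U+FFFD 代替)。encode 则相反:先把文本转成 UTF-8 字节序列,再反复挑选 merges 中最早学到(idx 最小)的相邻 pair 进行合并,直到没有可合并的 pair 为止。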

def decode(self, ids):
    """Convert a list of token ids back into a Python string.

    Each id is looked up in ``self.vocab`` (id -> bytes). The resulting byte
    strings are concatenated and decoded as UTF-8. Byte sequences that are
    not valid UTF-8 are replaced with U+FFFD instead of raising.

    Raises:
        KeyError: if an id is missing from ``self.vocab``.
    """
    # Join all bytes first and only then decode: one character can span
    # several tokens, so decoding token by token would break multi-byte
    # characters.
    text_bytes = b"".join(self.vocab[idx] for idx in ids)
    text = text_bytes.decode("utf-8", errors="replace")
    return text

def encode(self, text):
    """Encode ``text`` into a list of token ids using ``self.merges``.

    The text is first turned into its UTF-8 bytes (ids 0..255). While at
    least two ids remain, the adjacent pair with the smallest merge index in
    ``self.merges`` (the earliest merge learned) is replaced everywhere by its
    token id. Encoding stops as soon as no adjacent pair has a learned merge.
    """
    tokens = list(text.encode("utf-8"))
    while len(tokens) > 1:
        ranks = self.merges
        # Pairs without a learned merge get rank +inf. If no pair is
        # mergeable, min() just returns the first pair, and the membership
        # check below ends the loop.
        best = min(get_stats(tokens), key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # nothing else can be merged
        tokens = merge(tokens, best, ranks[best])
    return tokens
if __name__ == "__main__":
    # Quick self-check based on the Wikipedia BPE example:
    # <https://en.wikipedia.org/wiki/Byte_pair_encoding>
    #
    # Wikipedia runs 3 merges of BPE on "aaabdaaabac" and gets "XdXac", where
    # X = "aaab". With ASCII a=97, b=98, c=99, d=100, the new tokens are
    # numbered 256, 257, 258 in creation order, and X is 258. So encoding the
    # text must give [258, 100, 258, 97, 99]. Decoding must give the text back.
    #
    # Note: on the second merge, pairs (256, 97) and (97, 98) both occur twice.
    # The first one seen wins, so 257 is "aaa" here rather than Wikipedia's
    # "ab". Token 258 is still "aaab", so the final ids are unchanged.
    sample = "aaabdaaabac"
    expected_ids = [258, 100, 258, 97, 99]

    tok = BasicTokenizer()
    tok.train(sample, 256 + 3)  # 256 byte tokens + exactly 3 merges

    encoded = tok.encode(sample)
    print("OK" if encoded == expected_ids else "FAIL")

    round_trip = tok.decode(tok.encode(sample))
    print("OK" if round_trip == sample else "FAIL")