From 5b58203a0f0b96b521a383f480dcb7770fdf6c80 Mon Sep 17 00:00:00 2001
From: Hashir Ahmad
Date: Sat, 5 Oct 2024 19:31:26 +0200
Subject: [PATCH 1/3] feat: String Tokenization with Byte-Pair Encoding

---
 strings/bpe_tokenizer.py | 129 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 129 insertions(+)
 create mode 100644 strings/bpe_tokenizer.py

diff --git a/strings/bpe_tokenizer.py b/strings/bpe_tokenizer.py
new file mode 100644
index 000000000..2f4366619
--- /dev/null
+++ b/strings/bpe_tokenizer.py
@@ -0,0 +1,129 @@
+"""Byte-Pair Encoding: Subword-based tokenization algorithm used
+by state-of-the-art language models.
+
+Wikipedia: https://en.wikipedia.org/wiki/Byte_pair_encoding"""
+
+import itertools
+from collections import OrderedDict
+
+
+def get_byte_pair_counts(ids: list[int]):
+    """Count consecutive byte-pairs of an encoded string.
+
+    >>> ids = [73, 32, 97, 109, 32, 74, 111, 110, 83, 110, 111, 119, 46]
+    >>> get_byte_pair_counts(ids)
+    {(73, 32): 1, (32, 97): 1, (97, 109): 1, (109, 32): 1, (32, 74): 1, (74, 111): 1, (111, 110): 1, (110, 83): 1, (83, 110): 1, (110, 111): 1, (111, 119): 1, (119, 46): 1}
+    >>> ids = [2, 3, 6, 2, 3, 6, 2, 5]
+    >>> get_byte_pair_counts(ids)
+    {(2, 3): 2, (3, 6): 2, (6, 2): 2, (2, 5): 1}
+    """  # noqa: E501
+    counts: dict = {}
+    for pair in itertools.pairwise(ids):
+        counts[pair] = counts.get(pair, 0) + 1
+    return counts
+
+
+def merge(ids: list[int], pair: tuple, idx: int):
+    """Replace the most frequently occurring byte pair with a new token that is
+    not used in the data. Since utf-8 bytes use 0-255, new tokens start at 256.
+
+    >>> ids = [2, 3, 6, 2, 3, 6, 2, 5]
+    >>> pair = (2, 3)
+    >>> idx = 256
+    >>> merge(ids, pair, idx)
+    [256, 6, 256, 6, 2, 5]
+    """
+    new_ids = []
+    i = 0
+    while i < len(ids):
+        if i < len(ids) - 1 and (ids[i] == pair[0] and ids[i + 1] == pair[1]):
+            new_ids.append(idx)
+            i += 2
+        else:
+            new_ids.append(ids[i])
+            i += 1
+    return new_ids
+
+
+class Tokenizer:
+    """Tokenize a string using the byte-pair encoding algorithm"""
+
+    def __init__(self, num_merges: int = 20, verbose: bool = False):
+        self.num_merges = num_merges
+        self.merges: dict = {}
+        self.verbose = verbose
+
+    def encode(self, text: str):
+        """Convert a string to tokens (bytes)
+
+        >>> t = Tokenizer()
+        >>> text = "I am JonSnow."
+        >>> t.encode(text)
+        [73, 32, 97, 109, 32, 74, 111, 110, 83, 110, 111, 119, 46]
+
+        >>> t = Tokenizer()
+        >>> text = ""
+        >>> t.encode(text)
+        []
+        """
+        text_b = text.encode("utf-8")  # raw bytes
+        tokens = list(map(int, text_b))  # convert to list of integers
+
+        if self.verbose:
+            print(f"Input text: {text}")
+            print(f"Tokens: {tokens}")
+
+        ids = list(tokens)  # create a copy of tokens
+        self.merges = OrderedDict()  # store a mapping of merges (int, int) -> int
+        max_merges = len(tokens) - 1
+        num_merges = min(self.num_merges, max_merges)
+        # start merging most frequently occurring byte pairs
+        for i in range(num_merges):
+            counts = get_byte_pair_counts(ids)
+            pair = max(counts, key=counts.get)
+
+            if counts[pair] == 1:
+                continue
+
+            idx = 256 + i  # create new token for every merge step
+            if self.verbose:
+                print(f"Merging {pair} into a new token {idx}")
+            ids = merge(ids, pair, idx)
+            self.merges[pair] = idx
+
+        return ids
+
+    def decode(self, ids: list[int]):
+        """Convert a list of tokens to the original string
+
+        >>> t = Tokenizer()
+        >>> ids = [73, 32, 97, 109, 32, 74, 111, 110, 83, 110, 111, 119, 46]
+        >>> t.decode(ids)
+        'I am JonSnow.'
+
+        >>> t = Tokenizer()
+        >>> ids = []
+        >>> t.decode(ids)
+        ''
+        """
+        vocab = {idx: bytes([idx]) for idx in range(256)}  # original vocabulary
+        # Items must be iterated in their insertion order. Plain dicts
+        # preserve insertion order since Python 3.7, but an OrderedDict
+        # is used here to make that requirement explicit
+        for (p0, p1), idx in self.merges.items():
+            vocab[idx] = vocab[p0] + vocab[p1]
+
+        if self.verbose:
+            print(f"Vocabulary (after merging): {vocab}")
+
+        tokens = b"".join(vocab[idx] for idx in ids)
+        # handle UnicodeDecodeError by replacing the invalid
+        # start byte to conform to utf-8 format
+        text = tokens.decode("utf-8", errors="replace")
+        return text
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()

From 0f6994c67ef0298f3054328cc1182f85a29da3f1 Mon Sep 17 00:00:00 2001
From: hash-ir
Date: Sat, 5 Oct 2024 17:34:09 +0000
Subject: [PATCH 2/3] updating DIRECTORY.md

---
 DIRECTORY.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/DIRECTORY.md b/DIRECTORY.md
index cdbbac684..cecd881f5 100644
--- a/DIRECTORY.md
+++ b/DIRECTORY.md
@@ -22,6 +22,7 @@
   * [Rat In Maze](backtracking/rat_in_maze.py)
   * [Sudoku](backtracking/sudoku.py)
   * [Sum Of Subsets](backtracking/sum_of_subsets.py)
+  * [Word Break](backtracking/word_break.py)
   * [Word Ladder](backtracking/word_ladder.py)
   * [Word Search](backtracking/word_search.py)
 
@@ -1272,6 +1273,7 @@
   * [Barcode Validator](strings/barcode_validator.py)
   * [Bitap String Match](strings/bitap_string_match.py)
   * [Boyer Moore Search](strings/boyer_moore_search.py)
+  * [Bpe Tokenizer](strings/bpe_tokenizer.py)
   * [Camel Case To Snake Case](strings/camel_case_to_snake_case.py)
   * [Can String Be Rearranged As Palindrome](strings/can_string_be_rearranged_as_palindrome.py)
   * [Capitalize](strings/capitalize.py)

From 790475a62298fe5b845e822dc0b1c5266a564a55 Mon Sep 17 00:00:00 2001
From: Hashir Ahmad
Date: Sat, 5 Oct 2024 19:52:56 +0200
Subject: [PATCH 3/3] feat: replace dict get with dunder method

---
 strings/bpe_tokenizer.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/strings/bpe_tokenizer.py b/strings/bpe_tokenizer.py
index 2f4366619..d32bda79a 100644
--- a/strings/bpe_tokenizer.py
+++ b/strings/bpe_tokenizer.py
@@ -7,7 +7,7 @@ import itertools
 from collections import OrderedDict
 
 
-def get_byte_pair_counts(ids: list[int]):
+def get_byte_pair_counts(ids: list[int]) -> dict:
     """Count consecutive byte-pairs of an encoded string.
 
     >>> ids = [73, 32, 97, 109, 32, 74, 111, 110, 83, 110, 111, 119, 46]
@@ -23,7 +23,7 @@ def get_byte_pair_counts(ids: list[int]):
     return counts
 
 
-def merge(ids: list[int], pair: tuple, idx: int):
+def merge(ids: list[int], pair: tuple, idx: int) -> list[int]:
     """Replace the most frequently occurring byte pair with a new token that is
     not used in the data. Since utf-8 bytes use 0-255, new tokens start at 256.
 
@@ -48,12 +48,12 @@ def merge(ids: list[int], pair: tuple, idx: int):
 class Tokenizer:
     """Tokenize a string using the byte-pair encoding algorithm"""
 
-    def __init__(self, num_merges: int = 20, verbose: bool = False):
+    def __init__(self, num_merges: int = 20, verbose: bool = False) -> None:
         self.num_merges = num_merges
         self.merges: dict = {}
         self.verbose = verbose
 
-    def encode(self, text: str):
+    def encode(self, text: str) -> list[int]:
         """Convert a string to tokens (bytes)
 
         >>> t = Tokenizer()
@@ -80,7 +80,7 @@ class Tokenizer:
         # start merging most frequently occurring byte pairs
         for i in range(num_merges):
             counts = get_byte_pair_counts(ids)
-            pair = max(counts, key=counts.get)
+            pair = max(counts, key=counts.__getitem__)
 
             if counts[pair] == 1:
                 continue
@@ -93,7 +93,7 @@ class Tokenizer:
 
         return ids
 
-    def decode(self, ids: list[int]):
+    def decode(self, ids: list[int]) -> str:
         """Convert a list of tokens to the original string
 
         >>> t = Tokenizer()
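
The series can be exercised end to end once applied. Below is a minimal usage sketch, not part of any patch: it assumes the repository root is on PYTHONPATH so that the new strings/bpe_tokenizer.py module is importable, and the sample string is purely illustrative.

from strings.bpe_tokenizer import Tokenizer, merge

# A single low-level merge step: every occurrence of the pair (2, 3) is
# replaced by the new token 256, mirroring the doctest in PATCH 1/3.
assert merge([2, 3, 6, 2, 3, 6, 2, 5], (2, 3), 256) == [256, 6, 256, 6, 2, 5]

# End to end: encode() greedily learns up to num_merges byte-pair merges and
# stores them in tokenizer.merges; decode() replays them in insertion order.
tokenizer = Tokenizer(num_merges=20)
text = "byte pair encoding merges the most frequent byte pairs"  # illustrative
ids = tokenizer.encode(text)

# Merging never lengthens the sequence of utf-8 bytes.
assert len(ids) <= len(text.encode("utf-8"))

# Decoding with the same instance (which still holds the learned merges)
# reconstructs the original string exactly.
assert tokenizer.decode(ids) == text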