feat: replace dict.get with __getitem__ as the max() key and add return type annotations

This commit is contained in:
Hashir Ahmad 2024-10-05 19:52:56 +02:00
parent 5b58203a0f
commit 790475a622

View File

@ -7,7 +7,7 @@ import itertools
from collections import OrderedDict from collections import OrderedDict
def get_byte_pair_counts(ids: list[int]): def get_byte_pair_counts(ids: list[int]) -> dict:
"""Count consecutive byte-pairs of an encoded string. """Count consecutive byte-pairs of an encoded string.
>>> ids = [73, 32, 97, 109, 32, 74, 111, 110, 83, 110, 111, 119, 46] >>> ids = [73, 32, 97, 109, 32, 74, 111, 110, 83, 110, 111, 119, 46]
@ -23,7 +23,7 @@ def get_byte_pair_counts(ids: list[int]):
return counts return counts
def merge(ids: list[int], pair: tuple, idx: int): def merge(ids: list[int], pair: tuple, idx: int) -> list[int]:
"""Replace most occurring byte pair with new byte that is not used """Replace most occurring byte pair with new byte that is not used
in the data. For utf-8 encoding, we start with 256 as the new byte in the data. For utf-8 encoding, we start with 256 as the new byte
@ -48,12 +48,12 @@ def merge(ids: list[int], pair: tuple, idx: int):
class Tokenizer: class Tokenizer:
"""Tokenize a string using the byte-pair encoding algorithm""" """Tokenize a string using the byte-pair encoding algorithm"""
def __init__(self, num_merges: int = 20, verbose: bool = False): def __init__(self, num_merges: int = 20, verbose: bool = False) -> None:
self.num_merges = num_merges self.num_merges = num_merges
self.merges: dict = {} self.merges: dict = {}
self.verbose = verbose self.verbose = verbose
def encode(self, text: str): def encode(self, text: str) -> list[int]:
"""Convert a string to tokens (bytes) """Convert a string to tokens (bytes)
>>> t = Tokenizer() >>> t = Tokenizer()
@ -80,7 +80,7 @@ class Tokenizer:
# start merging most frequently occurring byte pairs # start merging most frequently occurring byte pairs
for i in range(num_merges): for i in range(num_merges):
counts = get_byte_pair_counts(ids) counts = get_byte_pair_counts(ids)
pair = max(counts, key=counts.get) pair = max(counts, key=counts.__getitem__)
if counts[pair] == 1: if counts[pair] == 1:
continue continue
@ -93,7 +93,7 @@ class Tokenizer:
return ids return ids
def decode(self, ids: list[int]): def decode(self, ids: list[int]) -> str:
"""Convert a list of tokens to the original string """Convert a list of tokens to the original string
>>> t = Tokenizer() >>> t = Tokenizer()