From 790475a62298fe5b845e822dc0b1c5266a564a55 Mon Sep 17 00:00:00 2001 From: Hashir Ahmad Date: Sat, 5 Oct 2024 19:52:56 +0200 Subject: [PATCH] feat: replace dict get with dunder method --- strings/bpe_tokenizer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/strings/bpe_tokenizer.py b/strings/bpe_tokenizer.py index 2f4366619..d32bda79a 100644 --- a/strings/bpe_tokenizer.py +++ b/strings/bpe_tokenizer.py @@ -7,7 +7,7 @@ import itertools from collections import OrderedDict -def get_byte_pair_counts(ids: list[int]): +def get_byte_pair_counts(ids: list[int]) -> dict: """Count consecutive byte-pairs of an encoded string. >>> ids = [73, 32, 97, 109, 32, 74, 111, 110, 83, 110, 111, 119, 46] @@ -23,7 +23,7 @@ def get_byte_pair_counts(ids: list[int]): return counts -def merge(ids: list[int], pair: tuple, idx: int): +def merge(ids: list[int], pair: tuple, idx: int) -> list[int]: """Replace most occurring byte pair with new byte that is not used in the data. For utf-8 encoding, we start with 256 as the new byte @@ -48,12 +48,12 @@ def merge(ids: list[int], pair: tuple, idx: int): class Tokenizer: """Tokenize a string using the byte-pair encoding algorithm""" - def __init__(self, num_merges: int = 20, verbose: bool = False): + def __init__(self, num_merges: int = 20, verbose: bool = False) -> None: self.num_merges = num_merges self.merges: dict = {} self.verbose = verbose - def encode(self, text: str): + def encode(self, text: str) -> list[int]: """Convert a string to tokens (bytes) >>> t = Tokenizer() @@ -80,7 +80,7 @@ class Tokenizer: # start merging most frequently occurring byte pairs for i in range(num_merges): counts = get_byte_pair_counts(ids) - pair = max(counts, key=counts.get) + pair = max(counts, key=counts.__getitem__) if counts[pair] == 1: continue @@ -93,7 +93,7 @@ class Tokenizer: return ids - def decode(self, ids: list[int]): + def decode(self, ids: list[int]) -> str: """Convert a list of tokens to the original string >>> t = Tokenizer()