[mypy] annotate compression (#5570)

This commit is contained in:
Erwin Junge 2021-10-26 12:29:27 +02:00 committed by GitHub
parent de07245c17
commit e49d8e3af4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 28 deletions

View File

@ -12,6 +12,13 @@ of text compression algorithms, costing only some extra computation.
""" """
from __future__ import annotations from __future__ import annotations
from typing import TypedDict
class BWTTransformDict(TypedDict):
bwt_string: str
idx_original_string: int
def all_rotations(s: str) -> list[str]: def all_rotations(s: str) -> list[str]:
""" """
@ -43,7 +50,7 @@ def all_rotations(s: str) -> list[str]:
return [s[i:] + s[:i] for i in range(len(s))] return [s[i:] + s[:i] for i in range(len(s))]
def bwt_transform(s: str) -> dict: def bwt_transform(s: str) -> BWTTransformDict:
""" """
:param s: The string that will be used at bwt algorithm :param s: The string that will be used at bwt algorithm
:return: the string composed of the last char of each row of the ordered :return: the string composed of the last char of each row of the ordered
@ -75,10 +82,11 @@ def bwt_transform(s: str) -> dict:
rotations = all_rotations(s) rotations = all_rotations(s)
rotations.sort() # sort the list of rotations in alphabetically order rotations.sort() # sort the list of rotations in alphabetically order
# make a string composed of the last char of each rotation # make a string composed of the last char of each rotation
return { response: BWTTransformDict = {
"bwt_string": "".join([word[-1] for word in rotations]), "bwt_string": "".join([word[-1] for word in rotations]),
"idx_original_string": rotations.index(s), "idx_original_string": rotations.index(s),
} }
return response
def reverse_bwt(bwt_string: str, idx_original_string: int) -> str: def reverse_bwt(bwt_string: str, idx_original_string: int) -> str:

View File

@ -1,29 +1,31 @@
from __future__ import annotations
import sys import sys
class Letter: class Letter:
def __init__(self, letter, freq): def __init__(self, letter: str, freq: int):
self.letter = letter self.letter: str = letter
self.freq = freq self.freq: int = freq
self.bitstring = {} self.bitstring: dict[str, str] = {}
def __repr__(self): def __repr__(self) -> str:
return f"{self.letter}:{self.freq}" return f"{self.letter}:{self.freq}"
class TreeNode: class TreeNode:
def __init__(self, freq, left, right): def __init__(self, freq: int, left: Letter | TreeNode, right: Letter | TreeNode):
self.freq = freq self.freq: int = freq
self.left = left self.left: Letter | TreeNode = left
self.right = right self.right: Letter | TreeNode = right
def parse_file(file_path): def parse_file(file_path: str) -> list[Letter]:
""" """
Read the file and build a dict of all letters and their Read the file and build a dict of all letters and their
frequencies, then convert the dict into a list of Letters. frequencies, then convert the dict into a list of Letters.
""" """
chars = {} chars: dict[str, int] = {}
with open(file_path) as f: with open(file_path) as f:
while True: while True:
c = f.read(1) c = f.read(1)
@ -33,22 +35,23 @@ def parse_file(file_path):
return sorted((Letter(c, f) for c, f in chars.items()), key=lambda l: l.freq) return sorted((Letter(c, f) for c, f in chars.items()), key=lambda l: l.freq)
def build_tree(letters): def build_tree(letters: list[Letter]) -> Letter | TreeNode:
""" """
Run through the list of Letters and build the min heap Run through the list of Letters and build the min heap
for the Huffman Tree. for the Huffman Tree.
""" """
while len(letters) > 1: response: list[Letter | TreeNode] = letters # type: ignore
left = letters.pop(0) while len(response) > 1:
right = letters.pop(0) left = response.pop(0)
right = response.pop(0)
total_freq = left.freq + right.freq total_freq = left.freq + right.freq
node = TreeNode(total_freq, left, right) node = TreeNode(total_freq, left, right)
letters.append(node) response.append(node)
letters.sort(key=lambda l: l.freq) response.sort(key=lambda l: l.freq)
return letters[0] return response[0]
def traverse_tree(root, bitstring): def traverse_tree(root: Letter | TreeNode, bitstring: str) -> list[Letter]:
""" """
Recursively traverse the Huffman Tree to set each Recursively traverse the Huffman Tree to set each
Letter's bitstring dictionary, and return the list of Letters Letter's bitstring dictionary, and return the list of Letters
@ -56,13 +59,14 @@ def traverse_tree(root, bitstring):
if type(root) is Letter: if type(root) is Letter:
root.bitstring[root.letter] = bitstring root.bitstring[root.letter] = bitstring
return [root] return [root]
treenode: TreeNode = root # type: ignore
letters = [] letters = []
letters += traverse_tree(root.left, bitstring + "0") letters += traverse_tree(treenode.left, bitstring + "0")
letters += traverse_tree(root.right, bitstring + "1") letters += traverse_tree(treenode.right, bitstring + "1")
return letters return letters
def huffman(file_path): def huffman(file_path: str) -> None:
""" """
Parse the file, build the tree, then run through the file Parse the file, build the tree, then run through the file
again, using the letters dictionary to find and print out the again, using the letters dictionary to find and print out the

View File

@ -26,7 +26,7 @@ def read_file_binary(file_path: str) -> str:
def add_key_to_lexicon( def add_key_to_lexicon(
lexicon: dict, curr_string: str, index: int, last_match_id: str lexicon: dict[str, str], curr_string: str, index: int, last_match_id: str
) -> None: ) -> None:
""" """
Adds new strings (curr_string + "0", curr_string + "1") to the lexicon Adds new strings (curr_string + "0", curr_string + "1") to the lexicon
@ -110,7 +110,7 @@ def write_file_binary(file_path: str, to_write: str) -> None:
sys.exit() sys.exit()
def compress(source_path, destination_path: str) -> None: def compress(source_path: str, destination_path: str) -> None:
""" """
Reads source file, compresses it and writes the compressed result in destination Reads source file, compresses it and writes the compressed result in destination
file file

View File

@ -12,7 +12,7 @@ import cv2
import numpy as np import numpy as np
def psnr(original, contrast): def psnr(original: float, contrast: float) -> float:
mse = np.mean((original - contrast) ** 2) mse = np.mean((original - contrast) ** 2)
if mse == 0: if mse == 0:
return 100 return 100
@ -21,7 +21,7 @@ def psnr(original, contrast):
return PSNR return PSNR
def main(): def main() -> None:
dir_path = os.path.dirname(os.path.realpath(__file__)) dir_path = os.path.dirname(os.path.realpath(__file__))
# Loading images (original image and compressed image) # Loading images (original image and compressed image)
original = cv2.imread(os.path.join(dir_path, "image_data/original_image.png")) original = cv2.imread(os.path.join(dir_path, "image_data/original_image.png"))