Handle mypy errors

This commit is contained in:
ricca 2023-10-06 17:32:14 +02:00
parent 9d3da1ae4f
commit 694ba686e4

View File

@@ -11,14 +11,12 @@ https://en.wikipedia.org/wiki/Naive_Bayes_classifier
import doctest
import numpy as np
import numpy.typing as npt
from scipy import sparse
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
def group_indices_by_target(targets: npt.ArrayLike) -> dict:
def group_indices_by_target(targets):
"""
Associates to each target label the indices of the examples with that label
@@ -48,13 +46,13 @@ def group_indices_by_target(targets: npt.ArrayLike) -> dict:
class MultinomialNBClassifier:
def __init__(self, alpha: int = 1):
def __init__(self, alpha=1):
self.classes = None
self.features_probs = None
self.priors = None
self.alpha = alpha
def fit(self, data: sparse.csr_matrix, targets: npt.ArrayLike) -> None:
def fit(self, data, targets):
"""
Parameters
----------
@@ -81,7 +79,7 @@ class MultinomialNBClassifier:
tot_features_count + self.alpha * n_features
)
def predict(self, data: sparse.csr_matrix) -> np.ndarray:
def predict(self, data):
"""
Parameters
----------
@@ -97,6 +95,7 @@ class MultinomialNBClassifier:
----------
Let's test the function following an example taken from the documentation
of the MultinomialNB model from sklearn
>>> from scipy import sparse
>>> rng = np.random.RandomState(1)
>>> data = rng.randint(5, size=(6, 100))
>>> data = sparse.csr_matrix(data)
@@ -119,7 +118,7 @@ class MultinomialNBClassifier:
return np.array(y_pred)
def main() -> None:
def main():
newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test")
x_train = newsgroups_train["data"]