Handle mypy errors

This commit is contained in:
ricca 2023-10-06 17:32:14 +02:00
parent 9d3da1ae4f
commit 694ba686e4

View File

@@ -11,14 +11,12 @@ https://en.wikipedia.org/wiki/Naive_Bayes_classifier
import doctest import doctest
import numpy as np import numpy as np
import numpy.typing as npt
from scipy import sparse
from sklearn.datasets import fetch_20newsgroups from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
def group_indices_by_target(targets: npt.ArrayLike) -> dict: def group_indices_by_target(targets):
""" """
Associates to each target label the indices of the examples with that label Associates to each target label the indices of the examples with that label
@@ -48,13 +46,13 @@ def group_indices_by_target(targets: npt.ArrayLike) -> dict:
class MultinomialNBClassifier: class MultinomialNBClassifier:
def __init__(self, alpha: int = 1): def __init__(self, alpha=1):
self.classes = None self.classes = None
self.features_probs = None self.features_probs = None
self.priors = None self.priors = None
self.alpha = alpha self.alpha = alpha
def fit(self, data: sparse.csr_matrix, targets: npt.ArrayLike) -> None: def fit(self, data, targets):
""" """
Parameters Parameters
---------- ----------
@@ -81,7 +79,7 @@ class MultinomialNBClassifier:
tot_features_count + self.alpha * n_features tot_features_count + self.alpha * n_features
) )
def predict(self, data: sparse.csr_matrix) -> np.ndarray: def predict(self, data):
""" """
Parameters Parameters
---------- ----------
@@ -97,6 +95,7 @@ class MultinomialNBClassifier:
---------- ----------
Let's test the function following an example taken from the documentation Let's test the function following an example taken from the documentation
of the MultinomialNB model from sklearn of the MultinomialNB model from sklearn
>>> from scipy import sparse
>>> rng = np.random.RandomState(1) >>> rng = np.random.RandomState(1)
>>> data = rng.randint(5, size=(6, 100)) >>> data = rng.randint(5, size=(6, 100))
>>> data = sparse.csr_matrix(data) >>> data = sparse.csr_matrix(data)
@@ -119,7 +118,7 @@ class MultinomialNBClassifier:
return np.array(y_pred) return np.array(y_pred)
def main() -> None: def main():
newsgroups_train = fetch_20newsgroups(subset="train") newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test") newsgroups_test = fetch_20newsgroups(subset="test")
x_train = newsgroups_train["data"] x_train = newsgroups_train["data"]