Add type hints

This commit is contained in:
ricca 2023-11-28 19:08:28 +01:00
parent 694ba686e4
commit 040a292eca

View File

@ -11,12 +11,13 @@ https://en.wikipedia.org/wiki/Naive_Bayes_classifier
import doctest import doctest
import numpy as np import numpy as np
import scipy
from sklearn.datasets import fetch_20newsgroups from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
def group_indices_by_target(targets): def group_indices_by_target(targets: np.ndarray) -> dict[int, list[int]]:
""" """
Associates to each target label the indices of the examples with that label Associates to each target label the indices of the examples with that label
@ -37,7 +38,7 @@ def group_indices_by_target(targets):
>>> group_indices_by_target(y) >>> group_indices_by_target(y)
{1: [0, 3], 2: [1, 4], 3: [2], 5: [5]} {1: [0, 3], 2: [1, 4], 3: [2], 5: [5]}
""" """
grouped_indices = {} grouped_indices: dict[int, list[int]] = {}
for i, y in enumerate(targets): for i, y in enumerate(targets):
if y not in grouped_indices: if y not in grouped_indices:
grouped_indices[y] = [] grouped_indices[y] = []
@ -46,13 +47,13 @@ def group_indices_by_target(targets):
class MultinomialNBClassifier: class MultinomialNBClassifier:
def __init__(self, alpha=1): def __init__(self, alpha: int = 1) -> None:
self.classes = None self.classes: list[int] = []
self.features_probs = None self.features_probs: np.ndarray = np.array([])
self.priors = None self.priors: np.ndarray = np.array([])
self.alpha = alpha self.alpha = alpha
def fit(self, data, targets): def fit(self, data: scipy.sparse.csr_matrix, targets: np.ndarray) -> None:
""" """
Parameters Parameters
---------- ----------
@ -61,6 +62,17 @@ class MultinomialNBClassifier:
targets : array-like of shape (n_samples,) targets : array-like of shape (n_samples,)
Target labels Target labels
Example
----------
>>> from scipy import sparse
>>> rng = np.random.RandomState(1)
>>> data = rng.randint(5, size=(6, 100))
>>> data = sparse.csr_matrix(data)
>>> y = np.array([1, 2, 3, 4, 5, 6])
>>> model = MultinomialNBClassifier()
>>> print(model.fit(data, y))
None
""" """
n_examples, n_features = data.shape n_examples, n_features = data.shape
grouped_indices = group_indices_by_target(targets) grouped_indices = group_indices_by_target(targets)
@ -79,7 +91,7 @@ class MultinomialNBClassifier:
tot_features_count + self.alpha * n_features tot_features_count + self.alpha * n_features
) )
def predict(self, data): def predict(self, data: scipy.sparse.csr_matrix) -> np.ndarray:
""" """
Parameters Parameters
---------- ----------
@ -93,8 +105,6 @@ class MultinomialNBClassifier:
Example Example
---------- ----------
Let's test the function following an example taken from the documentation
of the MultinomialNB model from sklearn
>>> from scipy import sparse >>> from scipy import sparse
>>> rng = np.random.RandomState(1) >>> rng = np.random.RandomState(1)
>>> data = rng.randint(5, size=(6, 100)) >>> data = rng.randint(5, size=(6, 100))
@ -118,7 +128,7 @@ class MultinomialNBClassifier:
return np.array(y_pred) return np.array(y_pred)
def main(): if __name__ == "__main__":
newsgroups_train = fetch_20newsgroups(subset="train") newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test") newsgroups_test = fetch_20newsgroups(subset="test")
x_train = newsgroups_train["data"] x_train = newsgroups_train["data"]
@ -138,8 +148,4 @@ def main():
"Accuracy of naive bayes text classifier: " "Accuracy of naive bayes text classifier: "
+ str(accuracy_score(y_test, y_pred)) + str(accuracy_score(y_test, y_pred))
) )
if __name__ == "__main__":
main()
doctest.testmod() doctest.testmod()