mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-02-07 01:50:55 +00:00
Add type hints
This commit is contained in:
parent
694ba686e4
commit
040a292eca
|
@ -11,12 +11,13 @@ https://en.wikipedia.org/wiki/Naive_Bayes_classifier
|
||||||
import doctest
|
import doctest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import scipy
|
||||||
from sklearn.datasets import fetch_20newsgroups
|
from sklearn.datasets import fetch_20newsgroups
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
from sklearn.metrics import accuracy_score
|
from sklearn.metrics import accuracy_score
|
||||||
|
|
||||||
|
|
||||||
def group_indices_by_target(targets):
|
def group_indices_by_target(targets: np.ndarray) -> dict[int, list[int]]:
|
||||||
"""
|
"""
|
||||||
Associates each target label with the indices of the examples that have that label
|
Associates each target label with the indices of the examples that have that label
|
||||||
|
|
||||||
|
@ -37,7 +38,7 @@ def group_indices_by_target(targets):
|
||||||
>>> group_indices_by_target(y)
|
>>> group_indices_by_target(y)
|
||||||
{1: [0, 3], 2: [1, 4], 3: [2], 5: [5]}
|
{1: [0, 3], 2: [1, 4], 3: [2], 5: [5]}
|
||||||
"""
|
"""
|
||||||
grouped_indices = {}
|
grouped_indices: dict[int, list[int]] = {}
|
||||||
for i, y in enumerate(targets):
|
for i, y in enumerate(targets):
|
||||||
if y not in grouped_indices:
|
if y not in grouped_indices:
|
||||||
grouped_indices[y] = []
|
grouped_indices[y] = []
|
||||||
|
@ -46,13 +47,13 @@ def group_indices_by_target(targets):
|
||||||
|
|
||||||
|
|
||||||
class MultinomialNBClassifier:
|
class MultinomialNBClassifier:
|
||||||
def __init__(self, alpha=1):
|
def __init__(self, alpha: int = 1) -> None:
|
||||||
self.classes = None
|
self.classes: list[int] = []
|
||||||
self.features_probs = None
|
self.features_probs: np.ndarray = np.array([])
|
||||||
self.priors = None
|
self.priors: np.ndarray = np.array([])
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
|
|
||||||
def fit(self, data, targets):
|
def fit(self, data: scipy.sparse.csr_matrix, targets: np.ndarray) -> None:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -61,6 +62,17 @@ class MultinomialNBClassifier:
|
||||||
|
|
||||||
targets : array-like of shape (n_samples,)
|
targets : array-like of shape (n_samples,)
|
||||||
Target labels
|
Target labels
|
||||||
|
|
||||||
|
Example
|
||||||
|
----------
|
||||||
|
>>> from scipy import sparse
|
||||||
|
>>> rng = np.random.RandomState(1)
|
||||||
|
>>> data = rng.randint(5, size=(6, 100))
|
||||||
|
>>> data = sparse.csr_matrix(data)
|
||||||
|
>>> y = np.array([1, 2, 3, 4, 5, 6])
|
||||||
|
>>> model = MultinomialNBClassifier()
|
||||||
|
>>> print(model.fit(data, y))
|
||||||
|
None
|
||||||
"""
|
"""
|
||||||
n_examples, n_features = data.shape
|
n_examples, n_features = data.shape
|
||||||
grouped_indices = group_indices_by_target(targets)
|
grouped_indices = group_indices_by_target(targets)
|
||||||
|
@ -79,7 +91,7 @@ class MultinomialNBClassifier:
|
||||||
tot_features_count + self.alpha * n_features
|
tot_features_count + self.alpha * n_features
|
||||||
)
|
)
|
||||||
|
|
||||||
def predict(self, data):
|
def predict(self, data: scipy.sparse.csr_matrix) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -93,8 +105,6 @@ class MultinomialNBClassifier:
|
||||||
|
|
||||||
Example
|
Example
|
||||||
----------
|
----------
|
||||||
Let's test the function following an example taken from the documentation
|
|
||||||
of the MultinomialNB model from sklearn
|
|
||||||
>>> from scipy import sparse
|
>>> from scipy import sparse
|
||||||
>>> rng = np.random.RandomState(1)
|
>>> rng = np.random.RandomState(1)
|
||||||
>>> data = rng.randint(5, size=(6, 100))
|
>>> data = rng.randint(5, size=(6, 100))
|
||||||
|
@ -118,7 +128,7 @@ class MultinomialNBClassifier:
|
||||||
return np.array(y_pred)
|
return np.array(y_pred)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
if __name__ == "__main__":
|
||||||
newsgroups_train = fetch_20newsgroups(subset="train")
|
newsgroups_train = fetch_20newsgroups(subset="train")
|
||||||
newsgroups_test = fetch_20newsgroups(subset="test")
|
newsgroups_test = fetch_20newsgroups(subset="test")
|
||||||
x_train = newsgroups_train["data"]
|
x_train = newsgroups_train["data"]
|
||||||
|
@ -138,8 +148,4 @@ def main():
|
||||||
"Accuracy of naive bayes text classifier: "
|
"Accuracy of naive bayes text classifier: "
|
||||||
+ str(accuracy_score(y_test, y_pred))
|
+ str(accuracy_score(y_test, y_pred))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
doctest.testmod()
|
doctest.testmod()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user