mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-02-14 13:28:09 +00:00
Implemented input check
This commit is contained in:
parent
2759947a48
commit
37184e21de
|
@ -11,7 +11,7 @@ from sklearn.datasets import fetch_20newsgroups
|
||||||
from sklearn.metrics import accuracy_score
|
from sklearn.metrics import accuracy_score
|
||||||
|
|
||||||
|
|
||||||
def group_data_by_target(targets):
|
def group_indices_by_target(targets):
|
||||||
"""
|
"""
|
||||||
Associates to each target label the indices of the examples with that label
|
Associates to each target label the indices of the examples with that label
|
||||||
|
|
||||||
|
@ -22,21 +22,21 @@ def group_data_by_target(targets):
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
----------
|
----------
|
||||||
grouped_data : dict of (label : list)
|
grouped_indices : dict of (label : list)
|
||||||
Maps each target label to the list of indices of the examples with that label
|
Maps each target label to the list of indices of the examples with that label
|
||||||
|
|
||||||
Example
|
Example
|
||||||
----------
|
----------
|
||||||
>>> y = np.array([1, 2, 3, 1, 2, 5])
|
>>> y = np.array([1, 2, 3, 1, 2, 5])
|
||||||
>>> group_data_by_target(y)
|
>>> group_indices_by_target(y)
|
||||||
{1: [0, 3], 2: [1, 4], 3: [2], 5: [5]}
|
{1: [0, 3], 2: [1, 4], 3: [2], 5: [5]}
|
||||||
"""
|
"""
|
||||||
grouped_data = {}
|
grouped_indices = {}
|
||||||
for i, y in enumerate(targets):
|
for i, y in enumerate(targets):
|
||||||
if y not in grouped_data:
|
if y not in grouped_indices:
|
||||||
grouped_data[y] = []
|
grouped_indices[y] = []
|
||||||
grouped_data[y].append(i)
|
grouped_indices[y].append(i)
|
||||||
return grouped_data
|
return grouped_indices
|
||||||
|
|
||||||
|
|
||||||
class MultinomialNBClassifier:
|
class MultinomialNBClassifier:
|
||||||
|
@ -46,6 +46,16 @@ class MultinomialNBClassifier:
|
||||||
self.priors = None
|
self.priors = None
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
|
|
||||||
|
def _check_X(self, X):
    """
    Validate the feature matrix handed to the classifier.

    Parameters
    ----------
    X : object
        Candidate feature matrix; accepted only when ``sparse.issparse(X)``
        is true.

    Raises
    ----------
    ValueError
        If ``sparse.issparse(X)`` is false.
    """
    # NOTE(review): the message names csr_matrix, yet the guard is
    # sparse.issparse -- presumably every SciPy sparse format passes;
    # confirm which of the two is intended.
    if sparse.issparse(X):
        return
    raise ValueError("Matrix X must be an instance of scipy.sparse.csr_matrix")
|
||||||
|
|
||||||
|
def _check_X_y(self, X, y):
    """
    Validate a feature matrix together with its target labels.

    X is first validated through ``self._check_X``; afterwards the number
    of rows of X is compared with the length of y.

    Parameters
    ----------
    X : object
        Feature matrix; must pass ``self._check_X`` and expose ``shape``.
    y : sized object
        Target labels; only ``len(y)`` is inspected here.

    Raises
    ----------
    ValueError
        If ``self._check_X`` rejects X, or if ``X.shape[0] != len(y)``.
    """
    self._check_X(X)
    n_rows = X.shape[0]
    n_labels = len(y)
    if n_rows == n_labels:
        return
    message = (
        "The expected dimension for array y is (" + str(n_rows)
        + ",), but got (" + str(n_labels) + ",)"
    )
    raise ValueError(message)
|
||||||
|
|
||||||
def fit(self, X, y):
|
def fit(self, X, y):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -56,16 +66,15 @@ class MultinomialNBClassifier:
|
||||||
y : array-like of shape (n_samples,)
|
y : array-like of shape (n_samples,)
|
||||||
Target labels
|
Target labels
|
||||||
"""
|
"""
|
||||||
if not sparse.issparse(X):
|
self._check_X_y(X, y)
|
||||||
raise ValueError("Matrix X must be an instance of scipy.sparse.csr_matrix")
|
|
||||||
n_examples, n_features = X.shape
|
n_examples, n_features = X.shape
|
||||||
grouped_data = group_data_by_target(y)
|
grouped_indices = group_indices_by_target(y)
|
||||||
self.classes = list(grouped_data.keys())
|
self.classes = list(grouped_indices.keys())
|
||||||
self.priors = np.zeros(shape=len(self.classes))
|
self.priors = np.zeros(shape=len(self.classes))
|
||||||
self.features_probs = np.zeros(shape=(len(self.classes), n_features))
|
self.features_probs = np.zeros(shape=(len(self.classes), n_features))
|
||||||
|
|
||||||
for i, class_i in enumerate(self.classes):
|
for i, class_i in enumerate(self.classes):
|
||||||
data_class_i = X[grouped_data[class_i]]
|
data_class_i = X[grouped_indices[class_i]]
|
||||||
prior_class_i = data_class_i.shape[0] / n_examples
|
prior_class_i = data_class_i.shape[0] / n_examples
|
||||||
self.priors[i] = prior_class_i
|
self.priors[i] = prior_class_i
|
||||||
tot_features_count = data_class_i.sum() # count of all features in class_i
|
tot_features_count = data_class_i.sum() # count of all features in class_i
|
||||||
|
@ -98,8 +107,7 @@ class MultinomialNBClassifier:
|
||||||
>>> model.predict(X[2:3])
|
>>> model.predict(X[2:3])
|
||||||
array([3])
|
array([3])
|
||||||
"""
|
"""
|
||||||
if not sparse.issparse(X):
|
self._check_X(X)
|
||||||
raise ValueError("Matrix X must be an instance of scipy.sparse.csr_matrix")
|
|
||||||
y_pred = []
|
y_pred = []
|
||||||
log_features_probs = np.log(self.features_probs)
|
log_features_probs = np.log(self.features_probs)
|
||||||
log_priors = np.log(self.priors)
|
log_priors = np.log(self.priors)
|
||||||
|
@ -126,7 +134,7 @@ def main():
|
||||||
model.fit(X_train, y_train)
|
model.fit(X_train, y_train)
|
||||||
|
|
||||||
y_pred = model.predict(X_test)
|
y_pred = model.predict(X_test)
|
||||||
print("Accuracy of Naive Bayes text classifier: " + str(accuracy_score(y_test, y_pred)))
|
print("Accuracy of naive bayes text classifier: " + str(accuracy_score(y_test, y_pred)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user