mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-12-18 09:10:16 +00:00
Update gaussian_naive_bayes.py (#7406)
* Update gaussian_naive_bayes.py Just adding in a final metric of accuracy to declare... * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b90ec30398
commit
80ff25ed38
|
@ -1,7 +1,9 @@
|
||||||
# Gaussian Naive Bayes Example
|
# Gaussian Naive Bayes Example
|
||||||
|
import time
|
||||||
|
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.metrics import plot_confusion_matrix
|
from sklearn.metrics import accuracy_score, plot_confusion_matrix
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.naive_bayes import GaussianNB
|
from sklearn.naive_bayes import GaussianNB
|
||||||
|
|
||||||
|
@ -25,7 +27,9 @@ def main():
|
||||||
|
|
||||||
# Gaussian Naive Bayes
|
# Gaussian Naive Bayes
|
||||||
nb_model = GaussianNB()
|
nb_model = GaussianNB()
|
||||||
nb_model.fit(x_train, y_train)
|
time.sleep(2.9)
|
||||||
|
model_fit = nb_model.fit(x_train, y_train)
|
||||||
|
y_pred = model_fit.predict(x_test) # Predictions on the test set
|
||||||
|
|
||||||
# Display Confusion Matrix
|
# Display Confusion Matrix
|
||||||
plot_confusion_matrix(
|
plot_confusion_matrix(
|
||||||
|
@ -33,12 +37,16 @@ def main():
|
||||||
x_test,
|
x_test,
|
||||||
y_test,
|
y_test,
|
||||||
display_labels=iris["target_names"],
|
display_labels=iris["target_names"],
|
||||||
cmap="Blues",
|
cmap="Blues", # although, Greys_r has a better contrast...
|
||||||
normalize="true",
|
normalize="true",
|
||||||
)
|
)
|
||||||
plt.title("Normalized Confusion Matrix - IRIS Dataset")
|
plt.title("Normalized Confusion Matrix - IRIS Dataset")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
time.sleep(1.8)
|
||||||
|
final_accuracy = 100 * accuracy_score(y_true=y_test, y_pred=y_pred)
|
||||||
|
print(f"The overall accuracy of the model is: {round(final_accuracy, 2)}%")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user