mirror of
https://github.com/metafy-social/python-scripts.git
synced 2025-01-31 05:33:44 +00:00
confusion matrix python script added
This commit is contained in:
parent
a1dde3ba2d
commit
f9e93bf8ba
24
scripts/Confusion_Matrix/README.md
Normal file
24
scripts/Confusion_Matrix/README.md
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
# CODE BY ZeroToMastery TensorFlow course.
|
||||||
|
# The function makes a labelled confusion matrix comparing predictions and ground truth labels.
|
||||||
|
|
||||||
|
# If classes is passed, confusion matrix will be labelled, if not, integer class values
|
||||||
|
# will be used.
|
||||||
|
|
||||||
|
## Args:
|
||||||
|
|
||||||
|
`y_true`: Array of truth labels (must be same shape as y_pred).
|
||||||
|
`y_pred`: Array of predicted labels (must be same shape as y_true).
|
||||||
|
`classes`: Array of class labels (e.g. string form). If `None`, integer labels are used.
|
||||||
|
`figsize`: Size of output figure (default=(10, 10)).
|
||||||
|
`text_size`: Size of output figure text (default=15).
|
||||||
|
`norm`: normalize values or not (default=False).
|
||||||
|
`savefig`: save confusion matrix to file (default=False).
|
||||||
|
|
||||||
|
## Returns: A labelled confusion matrix plot comparing y_true and y_pred.
|
||||||
|
### Example usage:
|
||||||
|
|
||||||
|
"""make_confusion_matrix(y_true=test_labels, # ground truth test labels
|
||||||
|
y_pred=y_preds, # predicted labels
|
||||||
|
classes=class_names, # array of class label names
|
||||||
|
figsize=(15, 15),
|
||||||
|
text_size=10)"""
|
59
scripts/Confusion_Matrix/make_confusion_matrix.py
Normal file
59
scripts/Confusion_Matrix/make_confusion_matrix.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import itertools
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.metrics import confusion_matrix
|
||||||
|
|
||||||
|
def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False):
|
||||||
|
# Create the confustion matrix
|
||||||
|
cm = confusion_matrix(y_true, y_pred)
|
||||||
|
cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] # normalize it
|
||||||
|
n_classes = cm.shape[0] # find the number of classes we're dealing with
|
||||||
|
|
||||||
|
# Plot the figure and make it pretty
|
||||||
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
|
cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors will represent how 'correct' a class is, darker == better
|
||||||
|
fig.colorbar(cax)
|
||||||
|
|
||||||
|
# Are there a list of classes?
|
||||||
|
if classes:
|
||||||
|
labels = classes
|
||||||
|
else:
|
||||||
|
labels = np.arange(cm.shape[0])
|
||||||
|
|
||||||
|
# Label the axes
|
||||||
|
ax.set(title="Confusion Matrix",
|
||||||
|
xlabel="Predicted label",
|
||||||
|
ylabel="True label",
|
||||||
|
xticks=np.arange(n_classes), # create enough axis slots for each class
|
||||||
|
yticks=np.arange(n_classes),
|
||||||
|
xticklabels=labels, # axes will labeled with class names (if they exist) or ints
|
||||||
|
yticklabels=labels)
|
||||||
|
|
||||||
|
# Make x-axis labels appear on bottom
|
||||||
|
ax.xaxis.set_label_position("bottom")
|
||||||
|
ax.xaxis.tick_bottom()
|
||||||
|
|
||||||
|
# Set the threshold for different colors
|
||||||
|
threshold = (cm.max() + cm.min()) / 2.
|
||||||
|
|
||||||
|
# Plot the text on each cell
|
||||||
|
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
|
||||||
|
if norm:
|
||||||
|
plt.text(j, i, f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)",
|
||||||
|
horizontalalignment="center",
|
||||||
|
color="white" if cm[i, j] > threshold else "black",
|
||||||
|
size=text_size)
|
||||||
|
else:
|
||||||
|
plt.text(j, i, f"{cm[i, j]}",
|
||||||
|
horizontalalignment="center",
|
||||||
|
color="white" if cm[i, j] > threshold else "black",
|
||||||
|
size=text_size)
|
||||||
|
|
||||||
|
# Save the figure to the current working directory
|
||||||
|
if savefig:
|
||||||
|
fig.savefig("confusion_matrix.png")
|
||||||
|
|
||||||
|
y_true = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0]
|
||||||
|
y_pred = [0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0]
|
||||||
|
|
||||||
|
make_confusion_matrix(y_true, y_pred)
|
Loading…
Reference in New Issue
Block a user