Mirror of
https://github.com/metafy-social/python-scripts.git
Synced 2024-11-30 15:31:10 +00:00
54 lines · 1.9 KiB · Python
import itertools
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from sklearn.metrics import confusion_matrix
|
|
|
|
def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15,
                          norm=False, savefig=False, filename="confusion_matrix.png"):
    """Plot an annotated confusion matrix of ``y_true`` against ``y_pred``.

    The matrix is computed with ``sklearn.metrics.confusion_matrix`` and drawn
    with ``Axes.matshow`` on a new figure. Each cell shows its raw count and,
    when ``norm`` is true, that count as a percentage of its row total.

    Args:
        y_true: Ground-truth labels, passed unchanged to ``confusion_matrix``.
        y_pred: Predicted labels, passed unchanged to ``confusion_matrix``.
        classes: Optional sequence of tick labels, one per class. They must be
            in the order ``confusion_matrix`` uses for its rows/columns (the
            sorted labels found in ``y_true`` and ``y_pred``). If ``None`` or
            empty, integer indices ``0..n_classes-1`` are used instead.
        figsize: Figure size in inches, forwarded to ``plt.subplots``.
        text_size: Font size of the per-cell annotations.
        norm: If true, also show each cell as a percentage of its row
            (i.e. of the number of samples with that true label).
        savefig: If true, save the figure to ``filename``.
        filename: Output path used when ``savefig`` is true. Defaults to
            ``"confusion_matrix.png"`` in the current working directory.
    """
    # Create the confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    n_classes = cm.shape[0]  # confusion_matrix returns a square (n_classes, n_classes) array

    # Row-normalise: divide each cell by the number of true samples of its class.
    # A label that appears only in y_pred has an all-zero row. Leave its
    # percentages at 0 rather than dividing by zero, which would give NaN
    # and a RuntimeWarning.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)

    # Plot the figure and make it pretty
    fig, ax = plt.subplots(figsize=figsize)
    cax = ax.matshow(cm, cmap=plt.cm.Blues)  # darker cells hold larger counts
    fig.colorbar(cax)

    # Use the supplied class names when there are any, otherwise integer indices.
    # An explicit None/len check is used instead of plain truthiness, because
    # `if classes:` raises "truth value ... is ambiguous" for NumPy arrays.
    if classes is not None and len(classes) > 0:
        labels = classes
    else:
        labels = np.arange(n_classes)

    # Label the axes
    ax.set(title="Confusion Matrix",
           xlabel="Predicted label",
           ylabel="True label",
           xticks=np.arange(n_classes),  # one tick slot per class
           yticks=np.arange(n_classes),
           xticklabels=labels,  # class names (if given) or ints
           yticklabels=labels)

    # matshow places x-axis ticks on top; move the ticks and the label to the bottom
    ax.xaxis.set_label_position("bottom")
    ax.xaxis.tick_bottom()

    # Cells above the midpoint of the count range get white text for contrast
    threshold = (cm.max() + cm.min()) / 2.

    # Annotate every cell. Draw on `ax` explicitly rather than through pyplot's
    # implicit "current axes", and centre the text both ways within the cell.
    for i, j in itertools.product(range(n_classes), range(cm.shape[1])):
        if norm:
            text = f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)"
        else:
            text = f"{cm[i, j]}"
        ax.text(j, i, text,
                horizontalalignment="center",
                verticalalignment="center",
                color="white" if cm[i, j] > threshold else "black",
                size=text_size)

    # Save the figure (to the current working directory by default)
    if savefig:
        fig.savefig(filename)