import itertools

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix


def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10),
                          text_size=15, norm=False, savefig=False):
    """Plot a labelled confusion matrix comparing ``y_true`` with ``y_pred``.

    Each cell is annotated with its raw count. When ``norm`` is True, the
    row-normalized percentage is appended: the fraction of samples of that
    row's true class that were predicted as the column's class.

    Args:
        y_true: Ground-truth labels, passed to
            ``sklearn.metrics.confusion_matrix``.
        y_pred: Predicted labels, passed to
            ``sklearn.metrics.confusion_matrix``.
        classes: Optional sequence of class names used as tick labels. If
            None or empty, integer indices ``0..n_classes-1`` are used
            instead.
        figsize: Matplotlib figure size, passed to ``plt.subplots``.
        text_size: Font size of the per-cell annotations.
        norm: If True, show "count (pct%)" in each cell instead of "count".
        savefig: If True, save the figure to ``confusion_matrix.png`` in the
            current working directory.

    Returns:
        None. The figure is created via ``plt.subplots`` and left open.
    """
    # Create the confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    n_classes = cm.shape[0]  # number of classes we're dealing with

    # Row-normalize so each row sums to 1. A row whose total is zero (a class
    # that only occurs in y_pred) would otherwise divide by zero and show
    # "nan%"; those cells are left at 0 instead.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm.astype("float"), row_sums,
                        out=np.zeros(cm.shape, dtype=float),
                        where=row_sums != 0)

    # Plot the figure; darker cells correspond to larger counts.
    fig, ax = plt.subplots(figsize=figsize)
    cax = ax.matshow(cm, cmap=plt.cm.Blues)
    fig.colorbar(cax)

    # Explicit checks (rather than plain truthiness) so that NumPy arrays or
    # pandas Index objects can be passed as class names without raising.
    if classes is not None and len(classes) > 0:
        labels = classes
    else:
        labels = np.arange(n_classes)

    # Label the axes
    ax.set(title="Confusion Matrix",
           xlabel="Predicted label",
           ylabel="True label",
           xticks=np.arange(n_classes),  # one tick slot per class
           yticks=np.arange(n_classes),
           xticklabels=labels,  # class names (if given) or ints
           yticklabels=labels)

    # Make x-axis labels appear on bottom
    ax.xaxis.set_label_position("bottom")
    ax.xaxis.tick_bottom()

    # Cells above the midpoint of the count range get white text so it stays
    # readable on the dark end of the colormap.
    threshold = (cm.max() + cm.min()) / 2.

    # Annotate every cell; draw on `ax` explicitly rather than relying on
    # pyplot's implicit "current axes".
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if norm:
            cell_text = f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)"
        else:
            cell_text = f"{cm[i, j]}"
        ax.text(j, i, cell_text,
                horizontalalignment="center",
                verticalalignment="center",
                color="white" if cm[i, j] > threshold else "black",
                size=text_size)

    # Save the figure to the current working directory
    if savefig:
        fig.savefig("confusion_matrix.png")