diff --git a/scripts/Confusion_Matrix/README.md b/scripts/Confusion_Matrix/README.md new file mode 100644 index 0000000..69bdc2a --- /dev/null +++ b/scripts/Confusion_Matrix/README.md @@ -0,0 +1,24 @@ +# CODE BY ZeroToMastery TensorFlow course. +# The function makes a labelled confusion matrix comparing predictions and ground truth labels. + +# If classes is passed, confusion matrix will be labelled, if not, integer class values +# will be used. + +## Args: + +`y_true`: Array of truth labels (must be same shape as y_pred). +`y_pred`: Array of predicted labels (must be same shape as y_true). +`classes`: Array of class labels (e.g. string form). If `None`, integer labels are used. +`figsize`: Size of output figure (default=(10, 10)). +`text_size`: Size of output figure text (default=15). +`norm`: normalize values or not (default=False). +`savefig`: save confusion matrix to file (default=False). + +## Returns: A labelled confusion matrix plot comparing y_true and y_pred. +### Example usage: + +"""make_confusion_matrix(y_true=test_labels, # ground truth test labels + y_pred=y_preds, # predicted labels + classes=class_names, # array of class label names + figsize=(15, 15), + text_size=10)""" \ No newline at end of file diff --git a/scripts/Confusion_Matrix/make_confusion_matrix.py b/scripts/Confusion_Matrix/make_confusion_matrix.py new file mode 100644 index 0000000..2200f12 --- /dev/null +++ b/scripts/Confusion_Matrix/make_confusion_matrix.py @@ -0,0 +1,59 @@ +import itertools +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import confusion_matrix + +def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False): + # Create the confustion matrix + cm = confusion_matrix(y_true, y_pred) + cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] # normalize it + n_classes = cm.shape[0] # find the number of classes we're dealing with + + # Plot the figure and make it pretty + fig, ax = plt.subplots(figsize=figsize) + cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors will represent how 'correct' a class is, darker == better + fig.colorbar(cax) + + # Are there a list of classes? + if classes: + labels = classes + else: + labels = np.arange(cm.shape[0]) + + # Label the axes + ax.set(title="Confusion Matrix", + xlabel="Predicted label", + ylabel="True label", + xticks=np.arange(n_classes), # create enough axis slots for each class + yticks=np.arange(n_classes), + xticklabels=labels, # axes will labeled with class names (if they exist) or ints + yticklabels=labels) + + # Make x-axis labels appear on bottom + ax.xaxis.set_label_position("bottom") + ax.xaxis.tick_bottom() + + # Set the threshold for different colors + threshold = (cm.max() + cm.min()) / 2. + + # Plot the text on each cell + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + if norm: + plt.text(j, i, f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)", + horizontalalignment="center", + color="white" if cm[i, j] > threshold else "black", + size=text_size) + else: + plt.text(j, i, f"{cm[i, j]}", + horizontalalignment="center", + color="white" if cm[i, j] > threshold else "black", + size=text_size) + + # Save the figure to the current working directory + if savefig: + fig.savefig("confusion_matrix.png") + +y_true = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0] +y_pred = [0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0] + +make_confusion_matrix(y_true, y_pred) \ No newline at end of file