def plot_kmeans(data, centroids, cluster_assignment):
    """
    Show a 3D scatter plot of the labelled data points and the centroids.

    Only the first three feature columns of ``data`` and ``centroids`` are
    drawn (as the X, Y and Z axes); any additional columns are ignored.

    Args:
        data: 2D array-like of shape (n_points, n_features), n_features >= 3.
        centroids: 2D array-like of shape (k, n_features), n_features >= 3.
        cluster_assignment: cluster label for each row of ``data``; passed to
            matplotlib as the colour value of every point.

    Raises:
        ValueError: if ``data`` or ``centroids`` is not two-dimensional or has
            fewer than three feature columns.
    """
    # Function-scope import so the block does not depend on the module header.
    import numpy as np

    # Accept nested lists / DataFrames as well as ndarrays; a no-op for ndarrays.
    data = np.asarray(data)
    centroids = np.asarray(centroids)

    # Fail early with a readable message instead of an opaque indexing error
    # from the column slices below.
    for label, points in (("data", data), ("centroids", centroids)):
        if points.ndim != 2 or points.shape[1] < 3:
            raise ValueError(
                f"{label} must be a 2D array with at least 3 feature columns, "
                f"got shape {points.shape}"
            )

    # Use a dedicated figure so the 3D axes never overlays axes left on the
    # current figure by an earlier plot.
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=cluster_assignment, cmap="viridis")
    ax.scatter(
        centroids[:, 0], centroids[:, 1], centroids[:, 2], c="red", s=100, marker="x"
    )
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_title("3D K-Means Clustering Visualization")
    plt.show()