From 36700a13ee46d2f4bb01d00df2d66f72393d96a9 Mon Sep 17 00:00:00 2001 From: tkgowtham Date: Wed, 2 Oct 2024 21:14:43 +0530 Subject: [PATCH] Update dbscan.py with more test cases --- machine_learning/dbscan.py | 73 +++++++++++++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/machine_learning/dbscan.py b/machine_learning/dbscan.py index 90fec1c0a..744ef69b8 100644 --- a/machine_learning/dbscan.py +++ b/machine_learning/dbscan.py @@ -7,6 +7,7 @@ LinkedIn : https://www.linkedin.com/in/gowtham-kamalasekar/ import math +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import pandas as pd @@ -118,6 +119,38 @@ class DbScan: 11 [2, 10, 11, 12] 12 [9, 11, 12] + >>> result = DbScan(3, 2.5).perform_dbscan() + >>> for key in sorted(result): + ... print(key, sorted(result[key])) + 1 [1, 2, 10, 11] + 2 [1, 2, 3, 10, 11] + 3 [2, 3, 4, 11] + 4 [3, 4, 5, 6, 7, 8] + 5 [4, 5, 6, 7, 8] + 6 [4, 5, 6, 7] + 7 [4, 5, 6, 7, 8] + 8 [4, 5, 7, 8] + 9 [9, 11, 12] + 10 [1, 2, 10, 11, 12] + 11 [1, 2, 3, 9, 10, 11, 12] + 12 [9, 10, 11, 12] + + >>> result = DbScan(5, 2.5).perform_dbscan() + >>> for key in sorted(result): + ... print(key, sorted(result[key])) + 1 [1, 2, 10, 11] + 2 [1, 2, 3, 10, 11] + 3 [2, 3, 4, 11] + 4 [3, 4, 5, 6, 7, 8] + 5 [4, 5, 6, 7, 8] + 6 [4, 5, 6, 7] + 7 [4, 5, 6, 7, 8] + 8 [4, 5, 7, 8] + 9 [9, 11, 12] + 10 [1, 2, 10, 11, 12] + 11 [1, 2, 3, 9, 10, 11, 12] + 12 [9, 10, 11, 12] + """ if type(self.file) is str: data = pd.read_csv(self.file) @@ -159,6 +192,35 @@ class DbScan: 10 [1, 10, 11] ---> Noise ---> Border 11 [2, 10, 11, 12] ---> Core 12 [9, 11, 12] ---> Noise ---> Border + + >>> DbScan(5,2.5).print_dbscan() + 1 [1, 2, 10, 11] ---> Noise ---> Border + 2 [1, 2, 3, 10, 11] ---> Core + 3 [2, 3, 4, 11] ---> Noise ---> Border + 4 [3, 4, 5, 6, 7, 8] ---> Core + 5 [4, 5, 6, 7, 8] ---> Core + 6 [4, 5, 6, 7] ---> Noise ---> Border + 7 [4, 5, 6, 7, 8] ---> Core + 8 [4, 5, 7, 8] ---> Noise ---> Border + 9 [9, 11, 12] ---> Noise ---> Border + 10 [1, 2, 10, 11, 12] ---> Core + 11 [1, 2, 3, 9, 10, 11, 12] ---> Core + 12 [9, 10, 11, 12] ---> Noise ---> Border + + >>> DbScan(2,0.5).print_dbscan() + 1 [1] ---> Noise + 2 [2] ---> Noise + 3 [3] ---> Noise + 4 [4] ---> Noise + 5 [5] ---> Noise + 6 [6] ---> Noise + 7 [7] ---> Noise + 8 [8] ---> Noise + 9 [9] ---> Noise + 10 [10] ---> Noise + 11 [11] ---> Noise + 12 [12] ---> Noise + """ for i in self.dict1: print(i, " ", self.dict1[i], end=" ---> ") @@ -185,6 +247,13 @@ class DbScan: >>> DbScan(4,1.9).plot_dbscan() Plotted Successfully + + >>> DbScan(5,2.5).plot_dbscan() + Plotted Successfully + + >>> DbScan(5,2.5).plot_dbscan() + Plotted Successfully + """ if type(self.file) is str: data = pd.read_csv(self.file) @@ -214,10 +283,12 @@ class DbScan: ha="center", va="bottom", ) + core_legend = mpatches.Patch(color="red", label="Core") + noise_legend = mpatches.Patch(color="green", label="Noise") plt.xlabel("X") plt.ylabel("Y") plt.title("DBSCAN Clustering") - plt.legend(["Core", "Noise"]) + plt.legend(handles=[core_legend, noise_legend]) plt.show() print("Plotted Successfully")