"""DBSCAN (density-based spatial clustering) for 2-D points, plus a plotting helper."""

import math
import warnings
from collections import deque

import numpy as np


def euclidean_distance(q, p):
    """
    Calculate the Euclidean distance between the 2-D points q and p.

    Points may be given as lists or tuples, and the two may be mixed.

    >>> euclidean_distance((0, 0), [3, 4])
    5.0

    Distance can only be calculated between numeric values

    >>> euclidean_distance([1, 'a'], [1, 2])
    Traceback (most recent call last):
        ...
    ValueError: Non-numeric input detected

    Only two-dimensional points are supported

    >>> euclidean_distance([1, 1, 1], [1, 2])
    Traceback (most recent call last):
        ...
    ValueError: expected dimensions to be 2-d, instead got q:3 and p:2

    Input should be in the format [x,y] or (x,y)

    >>> euclidean_distance(1, 2)
    Traceback (most recent call last):
        ...
    TypeError: inputs must be iterable, either list [x,y] or tuple (x,y)

    :param q: first point, ``[x, y]`` or ``(x, y)``
    :param p: second point, ``[x, y]`` or ``(x, y)``
    :return: the distance as a float
    :raises TypeError: if an input is not an iterable point or is a str
    :raises ValueError: if a point is not 2-D or has a non-numeric or
        non-finite coordinate
    """
    if not hasattr(q, "__iter__") or not hasattr(p, "__iter__"):
        raise TypeError("inputs must be iterable, either list [x,y] or tuple (x,y)")
    if isinstance(q, str) or isinstance(p, str):
        raise TypeError("inputs cannot be str")
    if len(q) != 2 or len(p) != 2:
        raise ValueError(
            f"expected dimensions to be 2-d, instead got q:{len(q)} and p:{len(p)}"
        )
    # Unpack instead of `q + p` so a tuple and a list can be mixed.
    for num in (*q, *p):
        # Numeric strings/bytes convert with float() but break the arithmetic below.
        if isinstance(num, (str, bytes)):
            raise ValueError("Non-numeric input detected")
        try:
            value = float(num)
        except (TypeError, ValueError):
            raise ValueError("Non-numeric input detected") from None
        # NaN/inf coordinates were rejected before (int() fails on them); keep that.
        if not math.isfinite(value):
            raise ValueError("Non-numeric input detected")
    return math.hypot(q[0] - p[0], q[1] - p[1])


def find_neighbors(db, q, eps):
    """
    Find all points in db that lie within a distance of eps from q.

    The result includes q itself when q is a key of db, since its distance is 0.

    eps value should be a number

    >>> find_neighbors({(1, 2): {'label': 'undefined'},
    ...                 (2, 3): {'label': 'undefined'}}, (2, 5), 'a')
    Traceback (most recent call last):
        ...
    ValueError: eps should be either int or float

    Q must be a 2-d point as list or tuple

    >>> find_neighbors({(1, 2): {'label': 'undefined'},
    ...                 (2, 3): {'label': 'undefined'}}, 2, 0.5)
    Traceback (most recent call last):
        ...
    TypeError: Q must a 2-dimentional point in the format (x,y) or [x,y]

    Points must be in correct format

    >>> find_neighbors([], (2, 2), 0.4)
    Traceback (most recent call last):
        ...
    TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}

    :param db: dict mapping ``(x, y)`` tuples to ``{'label': ...}`` dicts
    :param q: the query point
    :param eps: neighbourhood radius (inclusive)
    :return: list of keys of db whose distance to q is ``<= eps``
    """
    if not isinstance(eps, (int, float)):
        raise ValueError("eps should be either int or float")
    if not hasattr(q, "__iter__"):
        raise TypeError("Q must a 2-dimentional point in the format (x,y) or [x,y]")
    if not isinstance(db, dict):
        raise TypeError(
            "db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
        )
    return [p for p in db if euclidean_distance(q, p) <= eps]


def plot_cluster(db, clusters, ax):
    """
    Group the points of db by cluster label and plot each group on ax.

    Each cluster gets its own colour from the rainbow colormap. Points labelled
    "noise" are plotted once, in black.

    db cannot be empty

    >>> import matplotlib.pyplot as plt
    >>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
    >>> plot_cluster({}, [1, 2], axes[1])
    Traceback (most recent call last):
        ...
    Exception: db is empty. No points to cluster

    clusters cannot be empty

    >>> plot_cluster({(1, 2): {'label': 'undefined'},
    ...               (2, 3): {'label': 'undefined'}}, [], axes[1])
    Traceback (most recent call last):
        ...
    Exception: nothing to cluster. Empty clusters

    ax must be plottable

    >>> plot_cluster({(1, 2): {'label': '1'}, (2, 3): {'label': '2'}}, [1, 2], [])
    Traceback (most recent call last):
        ...
    TypeError: ax must be an slot in a matplotlib figure

    :param db: dict mapping ``(x, y)`` tuples to ``{'label': ...}`` dicts,
        e.g. as returned by ``dbscan``
    :param clusters: list of cluster labels to plot
    :param ax: a matplotlib Axes (anything with a ``plot`` method)
    """
    if len(db) == 0:
        raise Exception("db is empty. No points to cluster")
    if len(clusters) == 0:
        raise Exception("nothing to cluster. Empty clusters")
    if not hasattr(ax, "plot"):
        raise TypeError("ax must be an slot in a matplotlib figure")

    # Imported lazily so the clustering functions do not require matplotlib.
    import matplotlib.pyplot as plt

    # Single pass over db. Noise is collected exactly once; before, it was
    # re-appended for every cluster and so plotted len(clusters) times.
    members = {label: [] for label in clusters}
    noise = []
    for point, info in db.items():
        label = info["label"]
        if label == "noise":
            noise.append(point)
        elif label in members:
            members[label].append(point)

    colors = plt.cm.rainbow(np.linspace(0, 1, len(clusters)))
    for label, color in zip(clusters, colors):
        group = members[label]
        # Marker-only fmt "o": the colour comes solely from `color`, which avoids
        # matplotlib's "color is redundantly defined" warning.
        ax.plot([pt[0] for pt in group], [pt[1] for pt in group], "o", color=color)

    ax.plot([pt[0] for pt in noise], [pt[1] for pt in noise], "o", color="black")


def dbscan(db, eps, min_pts):
    """
    Label every point in db with DBSCAN.

    db is modified in place. Every point whose label is "undefined" gets
    either an integer cluster id (1, 2, ...) or "noise". Points that already
    carry another label are left untouched.

    Points must be in correct format

    >>> dbscan([], (2, 2), 0.4)
    Traceback (most recent call last):
        ...
    TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}

    eps value should be a number

    >>> dbscan({(1, 2): {'label': 'undefined'},
    ...         (2, 3): {'label': 'undefined'}}, 'a', 20)
    Traceback (most recent call last):
        ...
    ValueError: eps should be either int or float

    min_pts value should be an integer

    >>> dbscan({(1, 2): {'label': 'undefined'},
    ...         (2, 3): {'label': 'undefined'}}, 0.4, 20.0)
    Traceback (most recent call last):
        ...
    ValueError: min_pts should be int

    db cannot be empty

    >>> dbscan({}, 0.4, 20.0)
    Traceback (most recent call last):
        ...
    Exception: db is empty, nothing to cluster

    min_pts cannot be negative

    >>> dbscan({(1, 2): {'label': 'undefined'},
    ...         (2, 3): {'label': 'undefined'}}, 0.4, -20)
    Traceback (most recent call last):
        ...
    ValueError: min_pts or eps cannot be negative

    eps cannot be negative

    >>> dbscan({(1, 2): {'label': 'undefined'},
    ...         (2, 3): {'label': 'undefined'}}, -0.4, 20)
    Traceback (most recent call last):
        ...
    ValueError: min_pts or eps cannot be negative

    :param db: dict mapping ``(x, y)`` tuples to ``{'label': 'undefined'}``
    :param eps: neighbourhood radius (inclusive)
    :param min_pts: minimum neighbourhood size, including the point itself,
        for a point to be a core point
    :return: ``(db, clusters)``, where db is the same dict (now labelled) and
        clusters is the list of cluster ids that were created
    """
    if not isinstance(db, dict):
        raise TypeError(
            "db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
        )
    if len(db) == 0:
        raise Exception("db is empty, nothing to cluster")
    if not isinstance(eps, (int, float)):
        raise ValueError("eps should be either int or float")
    if not isinstance(min_pts, int):
        raise ValueError("min_pts should be int")
    if min_pts < 0 or eps < 0:
        raise ValueError("min_pts or eps cannot be negative")
    if min_pts == 0:
        warnings.warn("min_pts is 0. Are you sure you want this ?")
    if eps == 0:
        warnings.warn("eps is 0. Are you sure you want this ?")

    clusters = []
    cluster_id = 0
    for point in db:
        if db[point]["label"] != "undefined":
            continue
        neighbors = find_neighbors(db, point, eps)
        if len(neighbors) < min_pts:
            # May be relabelled later as a border point of some cluster.
            db[point]["label"] = "noise"
            continue

        cluster_id += 1
        clusters.append(cluster_id)
        db[point]["label"] = cluster_id

        # FIFO frontier of points to examine. A deque keeps the original
        # expansion order but makes each pop O(1) instead of list.pop(0).
        # Filtering `point` out (rather than list.remove) cannot raise.
        seeds = deque(n for n in neighbors if n != point)
        while seeds:
            q = seeds.popleft()
            if db[q]["label"] == "noise":
                # Reachable from a core point: it becomes a border point.
                db[q]["label"] = cluster_id
            if db[q]["label"] != "undefined":
                continue
            db[q]["label"] = cluster_id
            q_neighbors = find_neighbors(db, q, eps)
            if len(q_neighbors) >= min_pts:
                # q is a core point; expand through its neighbourhood. Points
                # already in a cluster never change label, so skip queueing them.
                seeds.extend(
                    n for n in q_neighbors if db[n]["label"] in ("undefined", "noise")
                )
    return db, clusters


if __name__ == "__main__":
    # Demo-only dependencies are imported here, so importing this module does
    # not require scikit-learn.
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_moons

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
    x, _ = make_moons(n_samples=200, noise=0.1, random_state=19)
    axes[0].plot(x[:, 0], x[:, 1], "ro")
    points = {(point[0], point[1]): {"label": "undefined"} for point in x}
    eps = 0.25
    min_pts = 12
    db, clusters = dbscan(points, eps, min_pts)
    plot_cluster(db, clusters, axes[1])
    plt.show()