Python/machine_learning/dbscan/dbscan.py
Kaushik Amar Das 4617aa78b2 DBSCAN algorithm (#1207)
* Added DBSCAN in two formats: a Jupyter notebook for the
storytelling and a .py file for people who just want to look at the
code. The code in both is essentially the same, with a few
differences in the .py file for plotting the clusters.

* fixed LGTM problems

* Some requested changes implemented.
Still need to do docstring

* implemented all changes as requested
2019-09-29 10:44:41 +02:00

272 lines
8.4 KiB
Python

import warnings
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_moons
def euclidean_distance(q, p):
    """
    Return the Euclidean distance between the 2-d points ``q`` and ``p``.

    Both points must be non-string iterables of exactly two numeric
    coordinates, e.g. ``[x, y]`` or ``(x, y)``; lists and tuples may be mixed.
    Coordinates must be convertible with ``int()``, so ``nan`` and ``inf`` are
    rejected as non-numeric.

    >>> euclidean_distance((0, 0), (3, 4))
    5.0
    >>> euclidean_distance([0, 0], (3, 4))
    5.0

    Distance can only be calculated between numeric values:

    >>> euclidean_distance([1, 'a'], [1, 2])
    Traceback (most recent call last):
    ...
    ValueError: Non-numeric input detected

    Only two-dimensional points are supported:

    >>> euclidean_distance([1, 1, 1], [1, 2])
    Traceback (most recent call last):
    ...
    ValueError: expected dimensions to be 2-d, instead got q:3 and p:2

    Inputs must be iterable, in the format [x,y] or (x,y):

    >>> euclidean_distance(1, 2)
    Traceback (most recent call last):
    ...
    TypeError: inputs must be iterable, either list [x,y] or tuple (x,y)
    """
    if not hasattr(q, "__iter__") or not hasattr(p, "__iter__"):
        raise TypeError("inputs must be iterable, either list [x,y] or tuple (x,y)")
    if isinstance(q, str) or isinstance(p, str):
        raise TypeError("inputs cannot be str")
    if len(q) != 2 or len(p) != 2:
        # Each length is reported next to its own argument name.
        raise ValueError(
            "expected dimensions to be 2-d, instead got q:{} and p:{}".format(
                len(q), len(p)
            )
        )
    # Unpack both points into one tuple. `q + p` would fail for list + tuple,
    # and it adds element-wise for numpy arrays.
    for num in (*q, *p):
        # int() accepts numeric text such as "3", which would then break the
        # subtraction below, so reject text explicitly.
        if isinstance(num, (str, bytes, bytearray)):
            raise ValueError("Non-numeric input detected")
        try:
            int(num)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError("Non-numeric input detected") from err
    dx = q[0] - p[0]
    dy = q[1] - p[1]
    return (dx ** 2 + dy ** 2) ** 0.5
def find_neighbors(db, q, eps):
    """
    Collect the points of ``db`` that lie within distance ``eps`` of ``q``.

    Distance comes from ``euclidean_distance`` and the comparison is
    inclusive (``<= eps``). The result is a list that follows ``db``'s
    iteration order.

    eps has to be a number:

    >>> find_neighbors({(1, 2): {'label': 'undefined'}, (2, 3): {'label': 'undefined'}}, (2, 5), 'a')
    Traceback (most recent call last):
    ...
    ValueError: eps should be either int or float

    q has to be an iterable 2-d point such as (x, y) or [x, y]:

    >>> find_neighbors({(1, 2): {'label': 'undefined'}, (2, 3): {'label': 'undefined'}}, 2, 0.5)
    Traceback (most recent call last):
    ...
    TypeError: Q must a 2-dimentional point in the format (x,y) or [x,y]

    db has to be a dict keyed by points:

    >>> find_neighbors([], (2, 2), 0.4)
    Traceback (most recent call last):
    ...
    TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}
    """
    eps_is_number = isinstance(eps, (int, float))
    if not eps_is_number:
        raise ValueError("eps should be either int or float")

    q_is_iterable = hasattr(q, "__iter__")
    if not q_is_iterable:
        raise TypeError("Q must a 2-dimentional point in the format (x,y) or [x,y]")

    db_is_dict = isinstance(db, dict)
    if not db_is_dict:
        raise TypeError(
            "db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
        )

    def _is_close(candidate):
        # True when `candidate` falls inside the (closed) eps-ball around q.
        return euclidean_distance(q, candidate) <= eps

    return list(filter(_is_close, db))
def plot_cluster(db, clusters, ax):
    """
    Group the points in ``db`` by cluster label and plot each group on ``ax``.

    Each label in ``clusters`` is drawn in its own colour, sampled evenly from
    the ``rainbow`` colormap. Points labelled ``"noise"`` are drawn once, in
    black. Points whose label is neither noise nor in ``clusters`` are not drawn.

    :param db: dict mapping (x, y) points to ``{"label": ...}`` dicts
    :param clusters: the cluster labels to draw, in colour order
    :param ax: a matplotlib Axes (anything with a ``plot`` method)
    :raises Exception: if ``db`` or ``clusters`` is empty
    :raises TypeError: if ``ax`` has no ``plot`` method

    db cannot be empty:

    >>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
    >>> plot_cluster({}, [1, 2], axes[1])
    Traceback (most recent call last):
    ...
    Exception: db is empty. No points to cluster

    clusters cannot be empty:

    >>> plot_cluster({(1, 2): {'label': 'undefined'}, (2, 3): {'label': 'undefined'}}, [], axes[1])
    Traceback (most recent call last):
    ...
    Exception: nothing to cluster. Empty clusters

    ax must be plottable:

    >>> plot_cluster({(1, 2): {'label': '1'}, (2, 3): {'label': '2'}}, [1, 2], [])
    Traceback (most recent call last):
    ...
    TypeError: ax must be an slot in a matplotlib figure
    """
    if len(db) == 0:
        raise Exception("db is empty. No points to cluster")
    if len(clusters) == 0:
        raise Exception("nothing to cluster. Empty clusters")
    if not hasattr(ax, "plot"):
        raise TypeError("ax must be an slot in a matplotlib figure")

    # Collect noise in a single pass. Gathering it inside the per-cluster loop
    # duplicated every noise point once per cluster.
    noise = [point for point, info in db.items() if info["label"] == "noise"]

    colors = plt.cm.rainbow(np.linspace(0, 1, len(clusters)))
    for label, color in zip(clusters, colors):
        members = [point for point, info in db.items() if info["label"] == label]
        xs = [point[0] for point in members]
        ys = [point[1] for point in members]
        # Marker-only fmt: a colour in the fmt string ("ro") plus color= makes
        # matplotlib warn about a redundant colour spec.
        ax.plot(xs, ys, "o", color=color)

    ax.plot([point[0] for point in noise], [point[1] for point in noise], "o", color="0")
def dbscan(db, eps, min_pts):
    """
    Cluster the 2-d points in ``db`` with the DBSCAN algorithm.

    ``db`` is labelled in place. Every point whose label is ``"undefined"`` on
    entry ends up labelled either ``"noise"`` or with an integer cluster id
    (1, 2, ...). Points that already carry another label are never used as a
    starting point. A point labelled ``"noise"`` can still be absorbed into a
    cluster as a border point.

    :param db: dict mapping (x, y) points to ``{"label": "undefined"}`` dicts
    :param eps: neighbourhood radius; int or float, must not be negative
    :param min_pts: minimum neighbourhood size, the point itself included, for
        a point to be a core point; int, must not be negative
    :return: tuple ``(db, clusters)``, i.e. the same dict (now labelled) and the
        list of cluster ids that were created
    :raises TypeError: if ``db`` is not a dict
    :raises Exception: if ``db`` is empty
    :raises ValueError: if ``eps`` or ``min_pts`` has the wrong type or is negative

    A ``warnings.warn`` is issued when ``min_pts`` or ``eps`` is 0.

    Points must be in correct format:

    >>> dbscan([], (2, 2), 0.4)
    Traceback (most recent call last):
    ...
    TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}

    eps value should be a number:

    >>> dbscan({(1, 2): {'label': 'undefined'}, (2, 3): {'label': 'undefined'}}, 'a', 20)
    Traceback (most recent call last):
    ...
    ValueError: eps should be either int or float

    min_pts value should be an integer:

    >>> dbscan({(1, 2): {'label': 'undefined'}, (2, 3): {'label': 'undefined'}}, 0.4, 20.0)
    Traceback (most recent call last):
    ...
    ValueError: min_pts should be int

    db cannot be empty:

    >>> dbscan({}, 0.4, 20.0)
    Traceback (most recent call last):
    ...
    Exception: db is empty, nothing to cluster

    min_pts cannot be negative:

    >>> dbscan({(1, 2): {'label': 'undefined'}, (2, 3): {'label': 'undefined'}}, 0.4, -20)
    Traceback (most recent call last):
    ...
    ValueError: min_pts or eps cannot be negative

    eps cannot be negative:

    >>> dbscan({(1, 2): {'label': 'undefined'}, (2, 3): {'label': 'undefined'}}, -0.4, 20)
    Traceback (most recent call last):
    ...
    ValueError: min_pts or eps cannot be negative
    """
    if not isinstance(db, dict):
        raise TypeError(
            "db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
        )
    if len(db) == 0:
        raise Exception("db is empty, nothing to cluster")
    if not isinstance(eps, (int, float)):
        raise ValueError("eps should be either int or float")
    if not isinstance(min_pts, int):
        raise ValueError("min_pts should be int")
    if min_pts < 0 or eps < 0:
        raise ValueError("min_pts or eps cannot be negative")
    if min_pts == 0:
        warnings.warn("min_pts is 0. Are you sure you want this ?")
    if eps == 0:
        warnings.warn("eps is 0. Are you sure you want this ?")

    clusters = []
    cluster_id = 0
    for point in db:
        if db[point]["label"] != "undefined":
            continue
        neighbors = find_neighbors(db, point, eps)
        if len(neighbors) < min_pts:
            # Not a core point. The label is provisional: a later cluster may
            # still absorb this point as a border point.
            db[point]["label"] = "noise"
            continue
        cluster_id += 1
        clusters.append(cluster_id)
        db[point]["label"] = cluster_id
        # FIFO queue of points still to visit for this cluster. popleft() is
        # O(1), unlike list.pop(0), and extend() avoids rebuilding the queue
        # each step. Filtering instead of list.remove() cannot raise if `point`
        # is missing from its own neighbourhood.
        seeds = deque(n for n in neighbors if n != point)
        while seeds:
            candidate = seeds.popleft()
            label = db[candidate]["label"]
            if label == "noise":
                # Border point: joins the cluster but is not expanded further.
                db[candidate]["label"] = cluster_id
                continue
            if label != "undefined":
                # Already labelled (e.g. earlier in this cluster's expansion).
                continue
            db[candidate]["label"] = cluster_id
            candidate_neighbors = find_neighbors(db, candidate, eps)
            if len(candidate_neighbors) >= min_pts:
                # Core point: its whole neighbourhood is density-reachable.
                seeds.extend(candidate_neighbors)
    return db, clusters
if __name__ == "__main__":
    # Left panel: the raw two-moons sample. Right panel: the DBSCAN clustering.
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
    raw_ax, clustered_ax = axes
    samples, _ = make_moons(n_samples=200, noise=0.1, random_state=19)
    raw_ax.plot(samples[:, 0], samples[:, 1], "ro")
    # dbscan expects {(x, y): {"label": "undefined"}} and labels it in place.
    points = {(row[0], row[1]): {"label": "undefined"} for row in samples}
    db, clusters = dbscan(points, eps=0.25, min_pts=12)
    plot_cluster(db, clusters, clustered_ax)
    plt.show()