mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-11-24 13:31:07 +00:00
4617aa78b2
* Added dbscan in two formats. A jupyter notebook file for the storytelling and a .py file for people that just want to look at the code. The code in both is essentially the same. With a few things different in the .py file for plotting the clusters. * fixed LGTM problems * Some requested changes implemented. Still need to do docstring * implememted all changes as requested
272 lines
8.4 KiB
Python
272 lines
8.4 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from sklearn.datasets import make_moons
|
|
import warnings
|
|
|
|
|
|
def euclidean_distance(q, p):
|
|
"""
|
|
Calculates the Euclidean distance
|
|
between points q and p
|
|
|
|
Distance can only be calculated between numeric values
|
|
>>> euclidean_distance([1,'a'],[1,2])
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: Non-numeric input detected
|
|
|
|
The dimentions of both the points must be the same
|
|
>>> euclidean_distance([1,1,1],[1,2])
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: expected dimensions to be 2-d, instead got p:3 and q:2
|
|
|
|
Supports only two dimentional points
|
|
>>> euclidean_distance([1,1,1],[1,2])
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: expected dimensions to be 2-d, instead got p:3 and q:2
|
|
|
|
Input should be in the format [x,y] or (x,y)
|
|
>>> euclidean_distance(1,2)
|
|
Traceback (most recent call last):
|
|
...
|
|
TypeError: inputs must be iterable, either list [x,y] or tuple (x,y)
|
|
"""
|
|
if not hasattr(q, "__iter__") or not hasattr(p, "__iter__"):
|
|
raise TypeError("inputs must be iterable, either list [x,y] or tuple (x,y)")
|
|
|
|
if isinstance(q, str) or isinstance(p, str):
|
|
raise TypeError("inputs cannot be str")
|
|
|
|
if len(q) != 2 or len(p) != 2:
|
|
raise ValueError(
|
|
"expected dimensions to be 2-d, instead got p:{} and q:{}".format(
|
|
len(q), len(p)
|
|
)
|
|
)
|
|
|
|
for num in q + p:
|
|
try:
|
|
num = int(num)
|
|
except:
|
|
raise ValueError("Non-numeric input detected")
|
|
|
|
a = pow((q[0] - p[0]), 2)
|
|
b = pow((q[1] - p[1]), 2)
|
|
return pow((a + b), 0.5)
|
|
|
|
|
|
def find_neighbors(db, q, eps):
|
|
"""
|
|
Finds all points in the db that
|
|
are within a distance of eps from Q
|
|
|
|
eps value should be a number
|
|
>>> find_neighbors({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}}, (2,5),'a')
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: eps should be either int or float
|
|
|
|
Q must be a 2-d point as list or tuple
|
|
>>> find_neighbors({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}}, 2, 0.5)
|
|
Traceback (most recent call last):
|
|
...
|
|
TypeError: Q must a 2-dimentional point in the format (x,y) or [x,y]
|
|
|
|
Points must be in correct format
|
|
>>> find_neighbors([], (2,2) ,0.4)
|
|
Traceback (most recent call last):
|
|
...
|
|
TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}
|
|
"""
|
|
|
|
if not isinstance(eps, (int, float)):
|
|
raise ValueError("eps should be either int or float")
|
|
|
|
if not hasattr(q, "__iter__"):
|
|
raise TypeError("Q must a 2-dimentional point in the format (x,y) or [x,y]")
|
|
|
|
if not isinstance(db, dict):
|
|
raise TypeError(
|
|
"db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
|
|
)
|
|
|
|
return [p for p in db if euclidean_distance(q, p) <= eps]
|
|
|
|
|
|
def plot_cluster(db, clusters, ax):
|
|
"""
|
|
Extracts all the points in the db and puts them together
|
|
as seperate clusters and finally plots them
|
|
|
|
db cannot be empty
|
|
>>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
|
|
>>> plot_cluster({},[1,2], axes[1] )
|
|
Traceback (most recent call last):
|
|
...
|
|
Exception: db is empty. No points to cluster
|
|
|
|
clusters cannot be empty
|
|
>>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
|
|
>>> plot_cluster({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},[],axes[1] )
|
|
Traceback (most recent call last):
|
|
...
|
|
Exception: nothing to cluster. Empty clusters
|
|
|
|
clusters cannot be empty
|
|
>>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
|
|
>>> plot_cluster({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},[],axes[1] )
|
|
Traceback (most recent call last):
|
|
...
|
|
Exception: nothing to cluster. Empty clusters
|
|
|
|
ax must be a plotable
|
|
>>> plot_cluster({ (1,2):{'label':'1'}, (2,3):{'label':'2'}},[1,2], [] )
|
|
Traceback (most recent call last):
|
|
...
|
|
TypeError: ax must be an slot in a matplotlib figure
|
|
"""
|
|
if len(db) == 0:
|
|
raise Exception("db is empty. No points to cluster")
|
|
|
|
if len(clusters) == 0:
|
|
raise Exception("nothing to cluster. Empty clusters")
|
|
|
|
if not hasattr(ax, "plot"):
|
|
raise TypeError("ax must be an slot in a matplotlib figure")
|
|
|
|
temp = []
|
|
noise = []
|
|
for i in clusters:
|
|
stack = []
|
|
for k, v in db.items():
|
|
if v["label"] == i:
|
|
stack.append(k)
|
|
elif v["label"] == "noise":
|
|
noise.append(k)
|
|
temp.append(stack)
|
|
|
|
color = iter(plt.cm.rainbow(np.linspace(0, 1, len(clusters))))
|
|
for i in range(0, len(temp)):
|
|
c = next(color)
|
|
x = [l[0] for l in temp[i]]
|
|
y = [l[1] for l in temp[i]]
|
|
ax.plot(x, y, "ro", c=c)
|
|
|
|
x = [l[0] for l in noise]
|
|
y = [l[1] for l in noise]
|
|
ax.plot(x, y, "ro", c="0")
|
|
|
|
|
|
def dbscan(db, eps, min_pts):
|
|
"""
|
|
Implementation of the DBSCAN algorithm
|
|
|
|
Points must be in correct format
|
|
>>> dbscan([], (2,2) ,0.4)
|
|
Traceback (most recent call last):
|
|
...
|
|
TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}
|
|
|
|
eps value should be a number
|
|
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},'a',20 )
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: eps should be either int or float
|
|
|
|
min_pts value should be an integer
|
|
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},0.4,20.0 )
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: min_pts should be int
|
|
|
|
db cannot be empty
|
|
>>> dbscan({},0.4,20.0 )
|
|
Traceback (most recent call last):
|
|
...
|
|
Exception: db is empty, nothing to cluster
|
|
|
|
min_pts cannot be negative
|
|
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}}, 0.4, -20)
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: min_pts or eps cannot be negative
|
|
|
|
eps cannot be negative
|
|
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},-0.4, 20)
|
|
Traceback (most recent call last):
|
|
...
|
|
ValueError: min_pts or eps cannot be negative
|
|
|
|
"""
|
|
if not isinstance(db, dict):
|
|
raise TypeError(
|
|
"db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
|
|
)
|
|
|
|
if len(db) == 0:
|
|
raise Exception("db is empty, nothing to cluster")
|
|
|
|
if not isinstance(eps, (int, float)):
|
|
raise ValueError("eps should be either int or float")
|
|
|
|
if not isinstance(min_pts, int):
|
|
raise ValueError("min_pts should be int")
|
|
|
|
if min_pts < 0 or eps < 0:
|
|
raise ValueError("min_pts or eps cannot be negative")
|
|
|
|
if min_pts == 0:
|
|
warnings.warn("min_pts is 0. Are you sure you want this ?")
|
|
|
|
if eps == 0:
|
|
warnings.warn("eps is 0. Are you sure you want this ?")
|
|
|
|
clusters = []
|
|
c = 0
|
|
for p in db:
|
|
if db[p]["label"] != "undefined":
|
|
continue
|
|
neighbors = find_neighbors(db, p, eps)
|
|
if len(neighbors) < min_pts:
|
|
db[p]["label"] = "noise"
|
|
continue
|
|
c += 1
|
|
clusters.append(c)
|
|
db[p]["label"] = c
|
|
neighbors.remove(p)
|
|
seed_set = neighbors.copy()
|
|
while seed_set != []:
|
|
q = seed_set.pop(0)
|
|
if db[q]["label"] == "noise":
|
|
db[q]["label"] = c
|
|
if db[q]["label"] != "undefined":
|
|
continue
|
|
db[q]["label"] = c
|
|
neighbors_n = find_neighbors(db, q, eps)
|
|
if len(neighbors_n) >= min_pts:
|
|
seed_set = seed_set + neighbors_n
|
|
return db, clusters
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
|
|
|
|
x, label = make_moons(n_samples=200, noise=0.1, random_state=19)
|
|
|
|
axes[0].plot(x[:, 0], x[:, 1], "ro")
|
|
|
|
points = {(point[0], point[1]): {"label": "undefined"} for point in x}
|
|
|
|
eps = 0.25
|
|
|
|
min_pts = 12
|
|
|
|
db, clusters = dbscan(points, eps, min_pts)
|
|
|
|
plot_cluster(db, clusters, axes[1])
|
|
|
|
plt.show()
|