mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-01-18 16:27:02 +00:00
DBSCAN algorithm (#1207)
* Added dbscan in two formats. A jupyter notebook file for the storytelling and a .py file for people that just want to look at the code. The code in both is essentially the same. With a few things different in the .py file for plotting the clusters. * fixed LGTM problems * Some requested changes implemented. Still need to do docstring * implememted all changes as requested
This commit is contained in:
parent
a79fc2b92a
commit
4617aa78b2
376
machine_learning/dbscan/dbscan.ipynb
Normal file
376
machine_learning/dbscan/dbscan.ipynb
Normal file
File diff suppressed because one or more lines are too long
271
machine_learning/dbscan/dbscan.py
Normal file
271
machine_learning/dbscan/dbscan.py
Normal file
|
@ -0,0 +1,271 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from sklearn.datasets import make_moons
|
||||
import warnings
|
||||
|
||||
|
||||
def euclidean_distance(q, p):
|
||||
"""
|
||||
Calculates the Euclidean distance
|
||||
between points q and p
|
||||
|
||||
Distance can only be calculated between numeric values
|
||||
>>> euclidean_distance([1,'a'],[1,2])
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Non-numeric input detected
|
||||
|
||||
The dimentions of both the points must be the same
|
||||
>>> euclidean_distance([1,1,1],[1,2])
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: expected dimensions to be 2-d, instead got p:3 and q:2
|
||||
|
||||
Supports only two dimentional points
|
||||
>>> euclidean_distance([1,1,1],[1,2])
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: expected dimensions to be 2-d, instead got p:3 and q:2
|
||||
|
||||
Input should be in the format [x,y] or (x,y)
|
||||
>>> euclidean_distance(1,2)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: inputs must be iterable, either list [x,y] or tuple (x,y)
|
||||
"""
|
||||
if not hasattr(q, "__iter__") or not hasattr(p, "__iter__"):
|
||||
raise TypeError("inputs must be iterable, either list [x,y] or tuple (x,y)")
|
||||
|
||||
if isinstance(q, str) or isinstance(p, str):
|
||||
raise TypeError("inputs cannot be str")
|
||||
|
||||
if len(q) != 2 or len(p) != 2:
|
||||
raise ValueError(
|
||||
"expected dimensions to be 2-d, instead got p:{} and q:{}".format(
|
||||
len(q), len(p)
|
||||
)
|
||||
)
|
||||
|
||||
for num in q + p:
|
||||
try:
|
||||
num = int(num)
|
||||
except:
|
||||
raise ValueError("Non-numeric input detected")
|
||||
|
||||
a = pow((q[0] - p[0]), 2)
|
||||
b = pow((q[1] - p[1]), 2)
|
||||
return pow((a + b), 0.5)
|
||||
|
||||
|
||||
def find_neighbors(db, q, eps):
|
||||
"""
|
||||
Finds all points in the db that
|
||||
are within a distance of eps from Q
|
||||
|
||||
eps value should be a number
|
||||
>>> find_neighbors({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}}, (2,5),'a')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: eps should be either int or float
|
||||
|
||||
Q must be a 2-d point as list or tuple
|
||||
>>> find_neighbors({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}}, 2, 0.5)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: Q must a 2-dimentional point in the format (x,y) or [x,y]
|
||||
|
||||
Points must be in correct format
|
||||
>>> find_neighbors([], (2,2) ,0.4)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}
|
||||
"""
|
||||
|
||||
if not isinstance(eps, (int, float)):
|
||||
raise ValueError("eps should be either int or float")
|
||||
|
||||
if not hasattr(q, "__iter__"):
|
||||
raise TypeError("Q must a 2-dimentional point in the format (x,y) or [x,y]")
|
||||
|
||||
if not isinstance(db, dict):
|
||||
raise TypeError(
|
||||
"db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
|
||||
)
|
||||
|
||||
return [p for p in db if euclidean_distance(q, p) <= eps]
|
||||
|
||||
|
||||
def plot_cluster(db, clusters, ax):
|
||||
"""
|
||||
Extracts all the points in the db and puts them together
|
||||
as seperate clusters and finally plots them
|
||||
|
||||
db cannot be empty
|
||||
>>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
|
||||
>>> plot_cluster({},[1,2], axes[1] )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
Exception: db is empty. No points to cluster
|
||||
|
||||
clusters cannot be empty
|
||||
>>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
|
||||
>>> plot_cluster({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},[],axes[1] )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
Exception: nothing to cluster. Empty clusters
|
||||
|
||||
clusters cannot be empty
|
||||
>>> fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
|
||||
>>> plot_cluster({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},[],axes[1] )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
Exception: nothing to cluster. Empty clusters
|
||||
|
||||
ax must be a plotable
|
||||
>>> plot_cluster({ (1,2):{'label':'1'}, (2,3):{'label':'2'}},[1,2], [] )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: ax must be an slot in a matplotlib figure
|
||||
"""
|
||||
if len(db) == 0:
|
||||
raise Exception("db is empty. No points to cluster")
|
||||
|
||||
if len(clusters) == 0:
|
||||
raise Exception("nothing to cluster. Empty clusters")
|
||||
|
||||
if not hasattr(ax, "plot"):
|
||||
raise TypeError("ax must be an slot in a matplotlib figure")
|
||||
|
||||
temp = []
|
||||
noise = []
|
||||
for i in clusters:
|
||||
stack = []
|
||||
for k, v in db.items():
|
||||
if v["label"] == i:
|
||||
stack.append(k)
|
||||
elif v["label"] == "noise":
|
||||
noise.append(k)
|
||||
temp.append(stack)
|
||||
|
||||
color = iter(plt.cm.rainbow(np.linspace(0, 1, len(clusters))))
|
||||
for i in range(0, len(temp)):
|
||||
c = next(color)
|
||||
x = [l[0] for l in temp[i]]
|
||||
y = [l[1] for l in temp[i]]
|
||||
ax.plot(x, y, "ro", c=c)
|
||||
|
||||
x = [l[0] for l in noise]
|
||||
y = [l[1] for l in noise]
|
||||
ax.plot(x, y, "ro", c="0")
|
||||
|
||||
|
||||
def dbscan(db, eps, min_pts):
|
||||
"""
|
||||
Implementation of the DBSCAN algorithm
|
||||
|
||||
Points must be in correct format
|
||||
>>> dbscan([], (2,2) ,0.4)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}
|
||||
|
||||
eps value should be a number
|
||||
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},'a',20 )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: eps should be either int or float
|
||||
|
||||
min_pts value should be an integer
|
||||
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},0.4,20.0 )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: min_pts should be int
|
||||
|
||||
db cannot be empty
|
||||
>>> dbscan({},0.4,20.0 )
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
Exception: db is empty, nothing to cluster
|
||||
|
||||
min_pts cannot be negative
|
||||
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}}, 0.4, -20)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: min_pts or eps cannot be negative
|
||||
|
||||
eps cannot be negative
|
||||
>>> dbscan({ (1,2):{'label':'undefined'}, (2,3):{'label':'undefined'}},-0.4, 20)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: min_pts or eps cannot be negative
|
||||
|
||||
"""
|
||||
if not isinstance(db, dict):
|
||||
raise TypeError(
|
||||
"db must be a dict of points in the format {(x,y):{'label':'boolean/undefined'}}"
|
||||
)
|
||||
|
||||
if len(db) == 0:
|
||||
raise Exception("db is empty, nothing to cluster")
|
||||
|
||||
if not isinstance(eps, (int, float)):
|
||||
raise ValueError("eps should be either int or float")
|
||||
|
||||
if not isinstance(min_pts, int):
|
||||
raise ValueError("min_pts should be int")
|
||||
|
||||
if min_pts < 0 or eps < 0:
|
||||
raise ValueError("min_pts or eps cannot be negative")
|
||||
|
||||
if min_pts == 0:
|
||||
warnings.warn("min_pts is 0. Are you sure you want this ?")
|
||||
|
||||
if eps == 0:
|
||||
warnings.warn("eps is 0. Are you sure you want this ?")
|
||||
|
||||
clusters = []
|
||||
c = 0
|
||||
for p in db:
|
||||
if db[p]["label"] != "undefined":
|
||||
continue
|
||||
neighbors = find_neighbors(db, p, eps)
|
||||
if len(neighbors) < min_pts:
|
||||
db[p]["label"] = "noise"
|
||||
continue
|
||||
c += 1
|
||||
clusters.append(c)
|
||||
db[p]["label"] = c
|
||||
neighbors.remove(p)
|
||||
seed_set = neighbors.copy()
|
||||
while seed_set != []:
|
||||
q = seed_set.pop(0)
|
||||
if db[q]["label"] == "noise":
|
||||
db[q]["label"] = c
|
||||
if db[q]["label"] != "undefined":
|
||||
continue
|
||||
db[q]["label"] = c
|
||||
neighbors_n = find_neighbors(db, q, eps)
|
||||
if len(neighbors_n) >= min_pts:
|
||||
seed_set = seed_set + neighbors_n
|
||||
return db, clusters
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 5))
|
||||
|
||||
x, label = make_moons(n_samples=200, noise=0.1, random_state=19)
|
||||
|
||||
axes[0].plot(x[:, 0], x[:, 1], "ro")
|
||||
|
||||
points = {(point[0], point[1]): {"label": "undefined"} for point in x}
|
||||
|
||||
eps = 0.25
|
||||
|
||||
min_pts = 12
|
||||
|
||||
db, clusters = dbscan(points, eps, min_pts)
|
||||
|
||||
plot_cluster(db, clusters, axes[1])
|
||||
|
||||
plt.show()
|
Loading…
Reference in New Issue
Block a user