mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-02-25 18:38:39 +00:00
Update decision_tree.py
This commit is contained in:
parent
f6b12420ce
commit
506fca12ec
@ -4,7 +4,7 @@ Input data set: The input data set must be 1-dimensional with continuous labels.
|
||||
Output: The decision tree maps a real number input to a real number output.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import doctest
|
||||
|
||||
class DecisionTree:
|
||||
def __init__(self, depth=5, min_leaf_size=5):
|
||||
@ -151,23 +151,32 @@ class TestDecisionTree:
|
||||
|
||||
return float(squared_error_sum / labels.size)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
In this demonstration we're generating a sample data set from the sin function in
|
||||
numpy. We then train a decision tree on the data set and use the decision tree to
|
||||
predict the label of 10 different test values. Then the mean squared error over
|
||||
this test is displayed.
|
||||
In this demonstration first we are generating x which is a numpy array containing values starting
|
||||
from -1 to 1 with an interval of 0.005 i.e [-1,-0.995,....,0.995,1] this is what we are
|
||||
getting by applying arange function of numpy.Then the we are generating y by applying sin function
|
||||
on x which is an array containing values from -1 to 1 with difference of 0.005 i.e we are getting
|
||||
an array y which contains sin of each value of x. We then train a decision tree on the data set
|
||||
and use the decision tree to predict the label of 10 different test values. Here we should prefer
|
||||
calculating Root Mean Squared Error over Mean Sqaured error beacause RMSE should be used
|
||||
when you need to communicate your results in an understandable way to end users or when
|
||||
penalising outliers is less of a priority.Interpretation will be easy in this case.
|
||||
"""
|
||||
x = np.arange(-1.0, 1.0, 0.005)
|
||||
y = np.sin(x)
|
||||
|
||||
tree = DecisionTree(depth=10, min_leaf_size=10)
|
||||
tree.train(x, y)
|
||||
|
||||
|
||||
test_cases = (np.random.rand(10) * 2) - 1
|
||||
predictions = np.array([tree.predict(x) for x in test_cases])
|
||||
avg_error = np.mean((predictions - test_cases) ** 2)
|
||||
mse = mean_squared_error(y_true, y_pred)
|
||||
|
||||
mse_error = np.mean((predictions - test_cases) ** 2)
|
||||
|
||||
"""RMSE error"""
|
||||
avg_error = np.sqrt(avg_error)
|
||||
|
||||
print("Test values: " + str(test_cases))
|
||||
print("Predictions: " + str(predictions))
|
||||
@ -176,6 +185,5 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
import doctest
|
||||
|
||||
|
||||
doctest.testmod(name="mean_squarred_error", verbose=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user