Update decision_tree.py

This commit is contained in:
thor-harsh 2023-08-17 12:34:23 +05:30 committed by GitHub
parent f6b12420ce
commit 506fca12ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@ Input data set: The input data set must be 1-dimensional with continuous labels.
Output: The decision tree maps a real number input to a real number output. Output: The decision tree maps a real number input to a real number output.
""" """
import numpy as np import numpy as np
import doctest
class DecisionTree: class DecisionTree:
def __init__(self, depth=5, min_leaf_size=5): def __init__(self, depth=5, min_leaf_size=5):
@ -151,13 +151,17 @@ class TestDecisionTree:
return float(squared_error_sum / labels.size) return float(squared_error_sum / labels.size)
def main(): def main():
""" """
In this demonstration we're generating a sample data set from the sin function in In this demonstration first we are generating x which is a numpy array containing values starting
numpy. We then train a decision tree on the data set and use the decision tree to from -1 to 1 with an interval of 0.005 i.e [-1,-0.995,....,0.995,1] this is what we are
predict the label of 10 different test values. Then the mean squared error over getting by applying arange function of numpy.Then the we are generating y by applying sin function
this test is displayed. on x which is an array containing values from -1 to 1 with difference of 0.005 i.e we are getting
an array y which contains sin of each value of x. We then train a decision tree on the data set
and use the decision tree to predict the label of 10 different test values. Here we should prefer
calculating Root Mean Squared Error over Mean Sqaured error beacause RMSE should be used
when you need to communicate your results in an understandable way to end users or when
penalising outliers is less of a priority.Interpretation will be easy in this case.
""" """
x = np.arange(-1.0, 1.0, 0.005) x = np.arange(-1.0, 1.0, 0.005)
y = np.sin(x) y = np.sin(x)
@ -167,7 +171,12 @@ def main():
test_cases = (np.random.rand(10) * 2) - 1 test_cases = (np.random.rand(10) * 2) - 1
predictions = np.array([tree.predict(x) for x in test_cases]) predictions = np.array([tree.predict(x) for x in test_cases])
avg_error = np.mean((predictions - test_cases) ** 2) mse = mean_squared_error(y_true, y_pred)
mse_error = np.mean((predictions - test_cases) ** 2)
"""RMSE error"""
avg_error = np.sqrt(avg_error)
print("Test values: " + str(test_cases)) print("Test values: " + str(test_cases))
print("Predictions: " + str(predictions)) print("Predictions: " + str(predictions))
@ -176,6 +185,5 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
import doctest
doctest.testmod(name="mean_squarred_error", verbose=True) doctest.testmod(name="mean_squarred_error", verbose=True)