diff --git a/machine_learning/linear_discriminant_analysis.py b/machine_learning/linear_discriminant_analysis.py index 22ee63a5a..0d19e970e 100644 --- a/machine_learning/linear_discriminant_analysis.py +++ b/machine_learning/linear_discriminant_analysis.py @@ -2,6 +2,7 @@ Linear Discriminant Analysis + Assumptions About Data : 1. The input variables has a gaussian distribution. 2. The variance calculated for each input variables by class grouping is the @@ -44,6 +45,7 @@ from math import log from os import name, system from random import gauss, seed +from typing import Callable, TypeVar # Make a training dataset drawn from a gaussian distribution @@ -245,6 +247,40 @@ def accuracy(actual_y: list, predicted_y: list) -> float: return (correct / len(actual_y)) * 100 +num = TypeVar("num") + + +def valid_input( + input_type: Callable[[object], num], # Usually float or int + input_msg: str, + err_msg: str, + condition: Callable[[num], bool] = lambda x: True, + default: str = None, +) -> num: + """ + Ask for user value and validate that it fulfill a condition. + + :input_type: user input expected type of value + :input_msg: message to show user in the screen + :err_msg: message to show in the screen in case of error + :condition: function that represents the condition that user input is valid. + :default: Default value in case the user does not type anything + :return: user's input + """ + while True: + try: + user_input = input_type(input(input_msg).strip() or default) + if condition(user_input): + return user_input + else: + print(f"{user_input}: {err_msg}") + continue + except ValueError: + print( + f"{user_input}: Incorrect input type, expected {input_type.__name__!r}" + ) + + # Main Function def main(): """ This function starts execution phase """ @@ -254,48 +290,26 @@ def main(): print("First of all we should specify the number of classes that") print("we want to generate as training dataset") # Trying to get number of classes - n_classes = 0 - while True: - try: - user_input = int( - input("Enter the number of classes (Data Groupings): ").strip() - ) - if user_input > 0: - n_classes = user_input - break - else: - print( - f"Your entered value is {user_input} , Number of classes " - f"should be positive!" - ) - continue - except ValueError: - print("Your entered value is not numerical!") + n_classes = valid_input( + input_type=int, + condition=lambda x: x > 0, + input_msg="Enter the number of classes (Data Groupings): ", + err_msg="Number of classes should be positive!", + ) print("-" * 100) - std_dev = 1.0 # Default value for standard deviation of dataset # Trying to get the value of standard deviation - while True: - try: - user_sd = float( - input( - "Enter the value of standard deviation" - "(Default value is 1.0 for all classes): " - ).strip() - or "1.0" - ) - if user_sd >= 0.0: - std_dev = user_sd - break - else: - print( - f"Your entered value is {user_sd}, Standard deviation should " - f"not be negative!" - ) - continue - except ValueError: - print("Your entered value is not numerical!") + std_dev = valid_input( + input_type=float, + condition=lambda x: x >= 0, + input_msg=( + "Enter the value of standard deviation" + "(Default value is 1.0 for all classes): " + ), + err_msg="Standard deviation should not be negative!", + default="1.0", + ) print("-" * 100) @@ -303,38 +317,24 @@ def main(): # dataset counts = [] # An empty list to store instance counts of classes in dataset for i in range(n_classes): - while True: - try: - user_count = int( - input(f"Enter The number of instances for class_{i+1}: ") - ) - if user_count > 0: - counts.append(user_count) - break - else: - print( - f"Your entered value is {user_count}, Number of " - "instances should be positive!" - ) - continue - except ValueError: - print("Your entered value is not numerical!") + user_count = valid_input( + input_type=int, + condition=lambda x: x > 0, + input_msg=(f"Enter The number of instances for class_{i+1}: "), + err_msg="Number of instances should be positive!", + ) + counts.append(user_count) print("-" * 100) # An empty list to store values of user-entered means of classes user_means = [] for a in range(n_classes): - while True: - try: - user_mean = float( - input(f"Enter the value of mean for class_{a+1}: ") - ) - if isinstance(user_mean, float): - user_means.append(user_mean) - break - print(f"You entered an invalid value: {user_mean}") - except ValueError: - print("Your entered value is not numerical!") + user_mean = valid_input( + input_type=float, + input_msg=(f"Enter the value of mean for class_{a+1}: "), + err_msg="This is an invalid value.", + ) + user_means.append(user_mean) print("-" * 100) print("Standard deviation: ", std_dev)