diff --git a/maths/weddles_rule.py b/maths/weddles_rule.py index 97ae868a4..03deee89a 100644 --- a/maths/weddles_rule.py +++ b/maths/weddles_rule.py @@ -1,4 +1,5 @@ -from math import * +import numpy as np +from sympy import lambdify, symbols, sympify def get_inputs(): @@ -6,7 +7,8 @@ def get_inputs(): Get user input for the function, lower limit, and upper limit. Returns: - tuple: A tuple containing the function as a string, the lower limit (a), and the upper limit (b) as floats. + tuple: A tuple containing the function as a string, the lower limit (a), + and the upper limit (b) as floats. Example: >>> from unittest.mock import patch @@ -21,6 +23,24 @@ def get_inputs(): return func, a, b +def safe_function_eval(func_str): + """ + Safely evaluates the function by substituting x value using sympy. + + Args: + func_str (str): Function expression as a string. + + Returns: + float: The evaluated function result. + """ + x = symbols('x') + func_expr = sympify(func_str) + + # Convert the function to a callable lambda function + lambda_func = lambdify(x, func_expr, modules=["numpy"]) + return lambda_func + + def compute_table(func, a, b, acc): """ Compute the table of function values based on the limits and accuracy. @@ -35,14 +55,19 @@ def compute_table(func, a, b, acc): tuple: A tuple containing the table of values and the step size (h). Example: - >>> compute_table('1/(1+x**2)', 1, -1, 1) - ([0.5, 0.4235294117647058, 0.36, 0.3076923076923077, 0.26470588235294124, 0.22929936305732482, 0.2], -0.3333333333333333) + >>> compute_table( + ... safe_function_eval('1/(1+x**2)'), 1, -1, 1 + ... ) + (array([0.5 , 0.69230769, 0.9 , 1. , 0.9 , + 0.69230769, 0.5 ]), -0.3333333333333333) """ - h = (b - a) / (acc * 6) - table = [0 for _ in range(acc * 6 + 1)] - for j in range(acc * 6 + 1): - x = a + j / (acc * 6) - table[j] = eval(func) + # Weddle's rule requires number of intervals as a multiple of 6 for accuracy + n_points = acc * 6 + 1 + h = (b - a) / (n_points - 1) + x_vals = np.linspace(a, b, n_points) + + # Evaluate function values at all points + table = func(x_vals) return table, h @@ -86,7 +111,8 @@ def compute_solution(add, table, h): float: The final computed integral solution. Example: - >>> compute_solution([4.33, 6.0, 0.0, -4.33], [0.0, 0.866, 1.0, 0.866, 0.0, -0.866, -1.0], 0.5235983333333333) + >>> compute_solution([4.33, 6.0, 0.0, -4.33], [0.0, 0.866, 1.0, 0.866, 0.0, + ... -0.866, -1.0], 0.5235983333333333) 0.7853975 """ return 0.3 * h * (sum(add) + table[0] + table[-1]) @@ -94,17 +120,16 @@ def compute_solution(add, table, h): if __name__ == "__main__": from doctest import testmod - testmod() - + func, a, b = get_inputs() acc = 1 solution = None - while acc <= 100000: + while acc <= 100_000: table, h = compute_table(func, a, b, acc) add = apply_weights(table) solution = compute_solution(add, table, h) acc *= 10 - print(f"Solution: {solution}") + print(f'Solution: {solution}') \ No newline at end of file