Fix DeprecationWarning in local_weighted_learning.py (#9165)

Fix DeprecationWarning that occurs during build due to converting an
np.ndarray to a scalar implicitly
This commit is contained in:
Tianyi Zheng 2023-09-30 23:31:35 -04:00 committed by GitHub
parent aaf7195465
commit 5f8d1cb5c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -122,7 +122,7 @@ def local_weight_regression(
""" """
y_pred = np.zeros(len(x_train)) # Initialize array of predictions y_pred = np.zeros(len(x_train)) # Initialize array of predictions
for i, item in enumerate(x_train): for i, item in enumerate(x_train):
y_pred[i] = item @ local_weight(item, x_train, y_train, tau) y_pred[i] = np.dot(item, local_weight(item, x_train, y_train, tau))
return y_pred return y_pred