Fix DeprecationWarning in local_weighted_learning.py (Attempt 2) (#9170)

* Fix DeprecationWarning in local_weighted_learning.py

Fix DeprecationWarning that occurs during build due to converting an
np.ndarray to a scalar implicitly

* DeprecationWarning fix attempt 2
This commit is contained in:
Tianyi Zheng 2023-10-01 00:07:25 -04:00 committed by GitHub
parent 320d895b86
commit 280dfc1a22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -122,7 +122,7 @@ def local_weight_regression(
"""
y_pred = np.zeros(len(x_train)) # Initialize array of predictions
for i, item in enumerate(x_train):
y_pred[i] = np.dot(item, local_weight(item, x_train, y_train, tau))
y_pred[i] = np.dot(item, local_weight(item, x_train, y_train, tau)).item()
return y_pred