mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-04-04 04:46:50 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
369a6b2d16
commit
91c8173691
@ -42,7 +42,7 @@ class LSTM:
|
|||||||
def __init__(self, data: str, hidden_dim: int = 25, epochs: int = 1000, lr: float = 0.05) -> None:
|
def __init__(self, data: str, hidden_dim: int = 25, epochs: int = 1000, lr: float = 0.05) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the LSTM network with the given data and hyperparameters.
|
Initialize the LSTM network with the given data and hyperparameters.
|
||||||
|
|
||||||
:param data: The input data as a string.
|
:param data: The input data as a string.
|
||||||
:param hidden_dim: The number of hidden units in the LSTM layer.
|
:param hidden_dim: The number of hidden units in the LSTM layer.
|
||||||
:param epochs: The number of training epochs.
|
:param epochs: The number of training epochs.
|
||||||
@ -69,7 +69,7 @@ class LSTM:
|
|||||||
def one_hot_encode(self, char: str) -> np.ndarray:
|
def one_hot_encode(self, char: str) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
One-hot encode a character.
|
One-hot encode a character.
|
||||||
|
|
||||||
:param char: The character to encode.
|
:param char: The character to encode.
|
||||||
:return: A one-hot encoded vector.
|
:return: A one-hot encoded vector.
|
||||||
"""
|
"""
|
||||||
@ -99,7 +99,7 @@ class LSTM:
|
|||||||
def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
|
def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Initialize weights with random values.
|
Initialize weights with random values.
|
||||||
|
|
||||||
:param input_dim: The input dimension.
|
:param input_dim: The input dimension.
|
||||||
:param output_dim: The output dimension.
|
:param output_dim: The output dimension.
|
||||||
:return: A matrix of initialized weights.
|
:return: A matrix of initialized weights.
|
||||||
@ -110,7 +110,7 @@ class LSTM:
|
|||||||
def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
|
def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Sigmoid activation function.
|
Sigmoid activation function.
|
||||||
|
|
||||||
:param x: The input array.
|
:param x: The input array.
|
||||||
:param derivative: Whether to compute the derivative.
|
:param derivative: Whether to compute the derivative.
|
||||||
:return: The sigmoid activation or its derivative.
|
:return: The sigmoid activation or its derivative.
|
||||||
@ -122,7 +122,7 @@ class LSTM:
|
|||||||
def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
|
def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Tanh activation function.
|
Tanh activation function.
|
||||||
|
|
||||||
:param x: The input array.
|
:param x: The input array.
|
||||||
:param derivative: Whether to compute the derivative.
|
:param derivative: Whether to compute the derivative.
|
||||||
:return: The tanh activation or its derivative.
|
:return: The tanh activation or its derivative.
|
||||||
@ -134,7 +134,7 @@ class LSTM:
|
|||||||
def softmax(self, x: np.ndarray) -> np.ndarray:
|
def softmax(self, x: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Softmax activation function.
|
Softmax activation function.
|
||||||
|
|
||||||
:param x: The input array.
|
:param x: The input array.
|
||||||
:return: The softmax activation.
|
:return: The softmax activation.
|
||||||
"""
|
"""
|
||||||
@ -161,7 +161,7 @@ class LSTM:
|
|||||||
def forward(self, inputs: list) -> list:
|
def forward(self, inputs: list) -> list:
|
||||||
"""
|
"""
|
||||||
Perform forward propagation through the LSTM network.
|
Perform forward propagation through the LSTM network.
|
||||||
|
|
||||||
:param inputs: The input data as a list of one-hot encoded vectors.
|
:param inputs: The input data as a list of one-hot encoded vectors.
|
||||||
:return: The outputs of the network.
|
:return: The outputs of the network.
|
||||||
"""
|
"""
|
||||||
@ -186,7 +186,7 @@ class LSTM:
|
|||||||
def backward(self, errors: list, inputs: list) -> None:
|
def backward(self, errors: list, inputs: list) -> None:
|
||||||
"""
|
"""
|
||||||
Perform backpropagation through time to compute gradients and update weights.
|
Perform backpropagation through time to compute gradients and update weights.
|
||||||
|
|
||||||
:param errors: The errors at each time step.
|
:param errors: The errors at each time step.
|
||||||
:param inputs: The input data as a list of one-hot encoded vectors.
|
:param inputs: The input data as a list of one-hot encoded vectors.
|
||||||
"""
|
"""
|
||||||
@ -224,7 +224,7 @@ class LSTM:
|
|||||||
d_i = d_cs * self.candidate_gates[t] * self.sigmoid(self.input_gates[t], derivative=True)
|
d_i = d_cs * self.candidate_gates[t] * self.sigmoid(self.input_gates[t], derivative=True)
|
||||||
d_wi += np.dot(d_i, inputs[t].T)
|
d_wi += np.dot(d_i, inputs[t].T)
|
||||||
d_bi += d_i
|
d_bi += d_i
|
||||||
|
|
||||||
# Candidate Gate Weights and Biases Errors
|
# Candidate Gate Weights and Biases Errors
|
||||||
d_c = d_cs * self.input_gates[t] * self.tanh(self.candidate_gates[t], derivative=True)
|
d_c = d_cs * self.input_gates[t] * self.tanh(self.candidate_gates[t], derivative=True)
|
||||||
d_wc += np.dot(d_c, inputs[t].T)
|
d_wc += np.dot(d_c, inputs[t].T)
|
||||||
@ -270,7 +270,7 @@ class LSTM:
|
|||||||
errors[-1][self.char_to_idx[self.train_y[t]]] += 1
|
errors[-1][self.char_to_idx[self.train_y[t]]] += 1
|
||||||
|
|
||||||
self.backward(errors, self.concat_inputs)
|
self.backward(errors, self.concat_inputs)
|
||||||
|
|
||||||
def test(self) -> None:
|
def test(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test the trained LSTM network on the input data and print the accuracy.
|
Test the trained LSTM network on the input data and print the accuracy.
|
||||||
@ -289,7 +289,7 @@ class LSTM:
|
|||||||
|
|
||||||
print(f'Ground Truth:\n{self.train_y}\n')
|
print(f'Ground Truth:\n{self.train_y}\n')
|
||||||
print(f'Predictions:\n{output}\n')
|
print(f'Predictions:\n{output}\n')
|
||||||
|
|
||||||
print(f'Accuracy: {round(accuracy * 100 / len(self.train_X), 2)}%')
|
print(f'Accuracy: {round(accuracy * 100 / len(self.train_X), 2)}%')
|
||||||
|
|
||||||
##### Data #####
|
##### Data #####
|
||||||
@ -314,4 +314,4 @@ if __name__ == "__main__":
|
|||||||
##### Testing #####
|
##### Testing #####
|
||||||
# lstm.test()
|
# lstm.test()
|
||||||
|
|
||||||
# testing can be done by uncommenting the above lines of code.
|
# testing can be done by uncommenting the above lines of code.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user