mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-04-15 02:07:36 +00:00
changed code a bit to meet ruff standards
This commit is contained in:
parent
5a00ca63fc
commit
3d9b893ee0
@ -46,12 +46,13 @@ Date: [Current Date]
|
||||
|
||||
##### Imports #####
|
||||
import numpy as np
|
||||
from numpy.random import Generator
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class LSTM:
|
||||
def __init__(
|
||||
self, data: str, hidden_dim: int = 25, epochs: int = 1000, lr: float = 0.05
|
||||
self, data: str, hidden_dim: int = 25, epochs: int = 10, lr: float = 0.05
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the LSTM network with the given data and hyperparameters.
|
||||
@ -75,6 +76,7 @@ class LSTM:
|
||||
self.idx_to_char = dict(enumerate(self.chars))
|
||||
|
||||
self.train_X, self.train_y = self.data[:-1], self.data[1:]
|
||||
self.rng: Generator = np.random.default_rng()
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
@ -94,32 +96,32 @@ class LSTM:
|
||||
"""
|
||||
Initialize the weights and biases for the LSTM network.
|
||||
"""
|
||||
rng = np.random.default_rng()
|
||||
|
||||
self.wf = self.init_weights(
|
||||
self.char_size + self.hidden_dim, self.hidden_dim, rng
|
||||
self.char_size + self.hidden_dim, self.hidden_dim
|
||||
)
|
||||
self.bf = np.zeros((self.hidden_dim, 1))
|
||||
|
||||
self.wi = self.init_weights(
|
||||
self.char_size + self.hidden_dim, self.hidden_dim, rng
|
||||
self.char_size + self.hidden_dim, self.hidden_dim
|
||||
)
|
||||
self.bi = np.zeros((self.hidden_dim, 1))
|
||||
|
||||
self.wc = self.init_weights(
|
||||
self.char_size + self.hidden_dim, self.hidden_dim, rng
|
||||
self.char_size + self.hidden_dim, self.hidden_dim
|
||||
)
|
||||
self.bc = np.zeros((self.hidden_dim, 1))
|
||||
|
||||
self.wo = self.init_weights(
|
||||
self.char_size + self.hidden_dim, self.hidden_dim, rng
|
||||
self.char_size + self.hidden_dim, self.hidden_dim
|
||||
)
|
||||
self.bo = np.zeros((self.hidden_dim, 1))
|
||||
|
||||
self.wy = self.init_weights(self.hidden_dim, self.char_size, rng)
|
||||
self.wy = self.init_weights(self.hidden_dim, self.char_size)
|
||||
self.by = np.zeros((self.char_size, 1))
|
||||
|
||||
def init_weights(
|
||||
self, input_dim: int, output_dim: int, rng: np.random.Generator
|
||||
self, input_dim: int, output_dim: int
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Initialize weights with random values.
|
||||
@ -129,7 +131,7 @@ class LSTM:
|
||||
:param rng: The random number generator.
|
||||
:return: A matrix of initialized weights.
|
||||
"""
|
||||
return rng.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(
|
||||
return self.rng.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(
|
||||
6 / (input_dim + output_dim)
|
||||
)
|
||||
|
||||
@ -343,9 +345,6 @@ class LSTM:
|
||||
self.backward(errors, self.concat_inputs)
|
||||
|
||||
def test(self) -> None:
|
||||
"""
|
||||
Test the trained LSTM network on the input data and print the accuracy.
|
||||
"""
|
||||
accuracy = 0
|
||||
probabilities = self.forward(
|
||||
[self.one_hot_encode(char) for char in self.train_X]
|
||||
@ -353,11 +352,9 @@ class LSTM:
|
||||
|
||||
output = ""
|
||||
for t in range(len(self.train_y)):
|
||||
prediction = self.idx_to_char[
|
||||
np.random.choice(
|
||||
range(self.char_size), p=self.softmax(probabilities[t].reshape(-1))
|
||||
)
|
||||
]
|
||||
probs = self.softmax(probabilities[t].reshape(-1))
|
||||
prediction_index = self.rng.choice(self.char_size, p=probs)
|
||||
prediction = self.idx_to_char[prediction_index]
|
||||
|
||||
output += prediction
|
||||
|
||||
@ -370,6 +367,7 @@ class LSTM:
|
||||
print(f"Accuracy: {round(accuracy * 100 / len(self.train_X), 2)}%")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data = """Long Short-Term Memory (LSTM) networks are a type
|
||||
of recurrent neural network (RNN) capable of learning "
|
||||
@ -379,7 +377,7 @@ if __name__ == "__main__":
|
||||
iter and Schmidhuber in 1997, and were refined and "
|
||||
"popularized by many people in following work."""
|
||||
|
||||
lstm = LSTM(data=data, hidden_dim=25, epochs=1000, lr=0.05)
|
||||
lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)
|
||||
|
||||
##### Training #####
|
||||
lstm.train()
|
||||
|
Loading…
x
Reference in New Issue
Block a user