Mirror of https://github.com/TheAlgorithms/Python.git (synced 2025-04-11 08:15:55 +00:00)
Refactor LSTM class to improve code readability and maintainability
This commit is contained in:
parent 5c186b16e8
commit 94ad70c234
@@ -10,41 +10,8 @@ Github: LEVII007
 Date: [Current Date]
 """
 
-#### Explanation #####
-# from typing import dict, list
-
-# This script implements a Long Short-Term Memory (LSTM)
-# network to learn and predict sequences of characters.
-# It uses numpy for numerical operations and tqdm for progress visualization.
-
-# The data is a paragraph about LSTM, converted to
-# lowercase and split into characters.
-# Each character is one-hot encoded for training.
-
-# The LSTM class initializes weights and biases for the
-# forget, input, candidate, and output gates.
-# It also initializes weights and biases for the final output layer.
-
-# The forward method performs forward propagation
-# through the LSTM network, computing hidden and cell states.
-# It uses sigmoid and tanh activation functions for the gates and cell states.
-
-# The backward method performs backpropagation
-# through time, computing gradients for the weights and biases.
-# It updates the weights and biases using the
-# computed gradients and the learning rate.
-
-# The train method trains the LSTM network on
-# the input data for a specified number of epochs.
-# It uses one-hot encoded inputs and computes
-# errors using the softmax function.
-
-# The test method evaluates the trained LSTM
-# network on the input data, computing accuracy based on predictions.
-
-# The script initializes the LSTM network with
-# specified hyperparameters and trains it on the input data.
-# Finally, it tests the trained network and prints the accuracy of the predictions.
-
-##### Imports #####
 import numpy as np
 from numpy.random import Generator
 from tqdm import tqdm
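Note: the removed `# from typing import dict, list` line was doubly dead: it was commented out, and it would fail even if restored, since the typing module exports Dict and List (capitalized). The lowercase builtin generics that the new annotations below rely on need no import at all on Python 3.9+. A minimal illustration (the names here are made up for the example):

from typing import Dict  # pre-3.9 spelling; note the capitalization

# Python 3.9+ builtin generics, the style this refactor adopts; no import needed.
char_to_idx: dict[str, int] = {"a": 0, "b": 1}

# Pre-3.9 equivalent via the typing module.
char_to_idx_legacy: Dict[str, int] = {"a": 0, "b": 1}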
@@ -62,25 +29,37 @@ class LSTM:
         :param epochs: The number of training epochs.
         :param lr: The learning rate.
         """
-        self.data = data.lower()
-        self.hidden_dim = hidden_dim
-        self.epochs = epochs
-        self.lr = lr
+        self.data: str = data.lower()
+        self.hidden_dim: int = hidden_dim
+        self.epochs: int = epochs
+        self.lr: float = lr
 
-        self.chars = set(self.data)
-        self.data_size, self.char_size = len(self.data), len(self.chars)
+        self.chars: set = set(self.data)
+        self.data_size: int = len(self.data)
+        self.char_size: int = len(self.chars)
 
         print(f"Data size: {self.data_size}, Char Size: {self.char_size}")
 
-        self.char_to_idx = {c: i for i, c in enumerate(self.chars)}
-        self.idx_to_char = dict(enumerate(self.chars))
+        self.char_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.chars)}
+        self.idx_to_char: dict[int, str] = dict(enumerate(self.chars))
 
-        self.train_X, self.train_y = self.data[:-1], self.data[1:]
+        self.train_X: str = self.data[:-1]
+        self.train_y: str = self.data[1:]
         self.rng: Generator = np.random.default_rng()
 
+        # Initialize attributes used in reset method
+        self.concat_inputs: dict[int, np.ndarray] = {}
+        self.hidden_states: dict[int, np.ndarray] = {-1: np.zeros((self.hidden_dim, 1))}
+        self.cell_states: dict[int, np.ndarray] = {-1: np.zeros((self.hidden_dim, 1))}
+        self.activation_outputs: dict[int, np.ndarray] = {}
+        self.candidate_gates: dict[int, np.ndarray] = {}
+        self.output_gates: dict[int, np.ndarray] = {}
+        self.forget_gates: dict[int, np.ndarray] = {}
+        self.input_gates: dict[int, np.ndarray] = {}
+        self.outputs: dict[int, np.ndarray] = {}
 
         self.initialize_weights()
 
-    ##### Helper Functions #####
     def one_hot_encode(self, char: str) -> np.ndarray:
         """
         One-hot encode a character.
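Note: the new attribute block mirrors what reset() later rebuilds, so instances are fully initialized before the first forward pass and static checkers can see every attribute in __init__. The dictionaries are keyed by time step, with slot -1 holding the initial all-zero hidden and cell states so that step 0 can read states[t - 1] without a special case. A standalone sketch of that pattern (sizes and updates here are placeholders, not the class's real arithmetic):

import numpy as np

hidden_dim = 3  # illustrative size only

# Slot -1 holds the initial state, so every step can read states[t - 1].
hidden_states = {-1: np.zeros((hidden_dim, 1))}
cell_states = {-1: np.zeros((hidden_dim, 1))}

for t in range(2):
    prev_h = hidden_states[t - 1]  # valid even at t == 0
    cell_states[t] = cell_states[t - 1]       # placeholder carry-over
    hidden_states[t] = np.tanh(prev_h + 1.0)  # placeholder update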
@@ -109,8 +88,8 @@ class LSTM:
         self.wo = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
         self.bo = np.zeros((self.hidden_dim, 1))
 
-        self.wy = self.init_weights(self.hidden_dim, self.char_size)
-        self.by = np.zeros((self.char_size, 1))
+        self.wy: np.ndarray = self.init_weights(self.hidden_dim, self.char_size)
+        self.by: np.ndarray = np.zeros((self.char_size, 1))
 
     def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
         """
@@ -118,14 +97,12 @@ class LSTM:
 
         :param input_dim: The input dimension.
         :param output_dim: The output dimension.
-        :param rng: The random number generator.
         :return: A matrix of initialized weights.
         """
         return self.rng.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(
             6 / (input_dim + output_dim)
         )
 
-    ##### Activation Functions #####
     def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
         """
         Sigmoid activation function.
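Note: init_weights scales a uniform(-1, 1) draw by sqrt(6 / (input_dim + output_dim)), which is the Xavier/Glorot uniform scheme: the samples end up uniform on [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out)). A quick standalone check (the dimensions are illustrative):

import numpy as np

rng = np.random.default_rng(0)
fan_in, fan_out = 40, 25  # illustrative dimensions

# Same construction as init_weights: uniform(-1, 1) scaled by the Glorot limit.
limit = np.sqrt(6 / (fan_in + fan_out))
w = rng.uniform(-1, 1, (fan_out, fan_in)) * limit

assert w.shape == (fan_out, fan_in)
assert np.all(np.abs(w) <= limit)  # equivalent to sampling uniform(-limit, limit)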
@@ -160,16 +137,13 @@ class LSTM:
         exp_x = np.exp(x - np.max(x))
         return exp_x / exp_x.sum(axis=0)
 
-    ##### LSTM Network Methods #####
     def reset(self) -> None:
         """
         Reset the LSTM network states.
         """
         self.concat_inputs = {}
-
         self.hidden_states = {-1: np.zeros((self.hidden_dim, 1))}
         self.cell_states = {-1: np.zeros((self.hidden_dim, 1))}
-
         self.activation_outputs = {}
         self.candidate_gates = {}
         self.output_gates = {}
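Note: the np.max(x) shift in the softmax context above is the standard overflow guard; subtracting a constant from every logit leaves the output unchanged because the factor exp(-max) cancels in the ratio. A standalone check:

import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Shift by the max so np.exp never sees a large positive argument.
    exp_x = np.exp(x - np.max(x))
    return exp_x / exp_x.sum(axis=0)

logits = np.array([[1000.0], [1001.0], [1002.0]])  # raw np.exp would overflow
p = softmax(logits)
assert np.isclose(p.sum(), 1.0)  # still a valid distribution, no overflow warnings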
@@ -177,7 +151,7 @@ class LSTM:
         self.input_gates = {}
         self.outputs = {}
 
-    def forward(self, inputs: list) -> list:
+    def forward(self, inputs: list[np.ndarray]) -> list[np.ndarray]:
         """
         Perform forward propagation through the LSTM network.
 
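For reference, one step of the gate arithmetic that forward implements on the concatenated input z_t = [h_{t-1}; x_t] (the same concatenation the class stores in self.concat_inputs). This is a self-contained sketch with random stand-in weights and no bias terms, not the class's actual code:

import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1 / (1 + np.exp(-x))

hidden_dim, char_size = 3, 5  # illustrative sizes
rng = np.random.default_rng(0)
shape = (hidden_dim, hidden_dim + char_size)
wf, wi, wc, wo = (rng.normal(size=shape) for _ in range(4))

h_prev = np.zeros((hidden_dim, 1))
c_prev = np.zeros((hidden_dim, 1))
x_t = np.eye(char_size, 1)  # a one-hot input column

z = np.concatenate((h_prev, x_t))  # [h_{t-1}; x_t]
f = sigmoid(wf @ z)                # forget gate
i = sigmoid(wi @ z)                # input gate
c_hat = np.tanh(wc @ z)            # candidate cell state
o = sigmoid(wo @ z)                # output gate
c_t = f * c_prev + i * c_hat       # new cell state
h_t = o * np.tanh(c_t)             # new hidden state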
@@ -217,7 +191,7 @@ class LSTM:
 
         return outputs
 
-    def backward(self, errors: list, inputs: list) -> None:
+    def backward(self, errors: list[np.ndarray], inputs: list[np.ndarray]) -> None:
         """
         Perform backpropagation through time to compute gradients and update weights.
 
@@ -237,23 +211,19 @@ class LSTM:
         for t in reversed(range(len(inputs))):
             error = errors[t]
 
-            # Final Gate Weights and Biases Errors
             d_wy += np.dot(error, self.hidden_states[t].T)
             d_by += error
 
-            # Hidden State Error
             d_hs = np.dot(self.wy.T, error) + dh_next
 
-            # Output Gate Weights and Biases Errors
             d_o = (
                 self.tanh(self.cell_states[t])
                 * d_hs
                 * self.sigmoid(self.output_gates[t], derivative=True)
             )
-            d_wo += np.dot(d_o, inputs[t].T)
+            d_wo += np.dot(d_o, self.concat_inputs[t].T)
             d_bo += d_o
 
-            # Cell State Error
             d_cs = (
                 self.tanh(self.tanh(self.cell_states[t]), derivative=True)
                 * self.output_gates[t]
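The inputs[t] -> self.concat_inputs[t] changes in this hunk and the next are the substantive bug fix of this commit: the gate weights multiply the concatenated vector [h_{t-1}; x_t], so their gradients must be outer products with that same concatenated vector; using the raw one-hot inputs[t] yields a matrix whose shape does not match the weight it is added to. A shape check under the script's own layout (sizes illustrative):

import numpy as np

hidden_dim, char_size = 25, 40  # illustrative sizes

d_o = np.ones((hidden_dim, 1))                   # gate error at step t
x_t = np.ones((char_size, 1))                    # raw one-hot input
concat_t = np.ones((hidden_dim + char_size, 1))  # [h_{t-1}; x_t]

# wo has shape (hidden_dim, hidden_dim + char_size) = (25, 65).
print(np.dot(d_o, x_t.T).shape)       # (25, 40) -- mismatches wo
print(np.dot(d_o, concat_t.T).shape)  # (25, 65) -- matches wo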
@@ -261,34 +231,30 @@ class LSTM:
                 + dc_next
             )
 
-            # Forget Gate Weights and Biases Errors
             d_f = (
                 d_cs
                 * self.cell_states[t - 1]
                 * self.sigmoid(self.forget_gates[t], derivative=True)
             )
-            d_wf += np.dot(d_f, inputs[t].T)
+            d_wf += np.dot(d_f, self.concat_inputs[t].T)
             d_bf += d_f
 
-            # Input Gate Weights and Biases Errors
             d_i = (
                 d_cs
                 * self.candidate_gates[t]
                 * self.sigmoid(self.input_gates[t], derivative=True)
             )
-            d_wi += np.dot(d_i, inputs[t].T)
+            d_wi += np.dot(d_i, self.concat_inputs[t].T)
             d_bi += d_i
 
-            # Candidate Gate Weights and Biases Errors
             d_c = (
                 d_cs
                 * self.input_gates[t]
                 * self.tanh(self.candidate_gates[t], derivative=True)
             )
-            d_wc += np.dot(d_c, inputs[t].T)
+            d_wc += np.dot(d_c, self.concat_inputs[t].T)
             d_bc += d_c
 
-            # Concatenated Input Error (Sum of Error at Each Gate!)
             d_z = (
                 np.dot(self.wf.T, d_f)
                 + np.dot(self.wi.T, d_i)
|
|||||||
+ np.dot(self.wo.T, d_o)
|
+ np.dot(self.wo.T, d_o)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Error of Hidden State and Cell State at Next Time Step
|
|
||||||
dh_next = d_z[: self.hidden_dim, :]
|
dh_next = d_z[: self.hidden_dim, :]
|
||||||
dc_next = self.forget_gates[t] * d_cs
|
dc_next = self.forget_gates[t] * d_cs
|
||||||
|
|
||||||
for d_ in (d_wf, d_bf, d_wi, d_bi, d_wc, d_bc, d_wo, d_bo, d_wy, d_by):
|
for d in (d_wf, d_bf, d_wi, d_bi, d_wc, d_bc, d_wo, d_bo, d_wy, d_by):
|
||||||
np.clip(d_, -1, 1, out=d_)
|
np.clip(d, -1, 1, out=d)
|
||||||
|
|
||||||
self.wf += d_wf * self.lr
|
self.wf += d_wf * self.lr
|
||||||
self.bf += d_bf * self.lr
|
self.bf += d_bf * self.lr
|
||||||
|
|
||||||
self.wi += d_wi * self.lr
|
self.wi += d_wi * self.lr
|
||||||
self.bi += d_bi * self.lr
|
self.bi += d_bi * self.lr
|
||||||
|
|
||||||
self.wc += d_wc * self.lr
|
self.wc += d_wc * self.lr
|
||||||
self.bc += d_bc * self.lr
|
self.bc += d_bc * self.lr
|
||||||
|
|
||||||
self.wo += d_wo * self.lr
|
self.wo += d_wo * self.lr
|
||||||
self.bo += d_bo * self.lr
|
self.bo += d_bo * self.lr
|
||||||
|
|
||||||
self.wy += d_wy * self.lr
|
self.wy += d_wy * self.lr
|
||||||
self.by += d_by * self.lr
|
self.by += d_by * self.lr
|
||||||
|
|
||||||
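Note: the renamed loop clips every accumulated gradient element-wise into [-1, 1] before the updates, the usual guard against exploding gradients in backpropagation through time; out=d makes np.clip mutate the arrays in place, which is why rebinding the loop variable is never needed. A minimal demonstration:

import numpy as np

d_wf = np.array([[3.0, -0.5], [-2.0, 0.25]])

np.clip(d_wf, -1, 1, out=d_wf)  # in-place element-wise clip
print(d_wf)  # [[ 1.   -0.5 ] [-1.    0.25]]

Element-wise clipping can change a gradient's direction; clipping by global norm is a common alternative, but the element-wise form is what this script uses.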
@@ -332,9 +293,12 @@ class LSTM:
             errors.append(-self.softmax(predictions[t]))
             errors[-1][self.char_to_idx[self.train_y[t]]] += 1
 
-        self.backward(errors, self.concat_inputs)
+        self.backward(errors, inputs)
 
     def test(self) -> None:
+        """
+        Test the trained LSTM network on the input data and print the accuracy.
+        """
         accuracy = 0
         probabilities = self.forward(
             [self.one_hot_encode(char) for char in self.train_X]
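Note on the error term that train builds above the fixed backward call: appending -softmax(prediction) and then adding 1 at the target index produces one_hot(target) - softmax(prediction), which is the negative gradient of softmax cross-entropy with respect to the logits. Since the weight updates apply += gradient * lr, following this direction is ordinary gradient descent on the loss. A standalone sketch:

import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    exp_x = np.exp(x - np.max(x))
    return exp_x / exp_x.sum(axis=0)

prediction = np.array([[2.0], [0.5], [-1.0]])  # illustrative logits
target_idx = 1

error = -softmax(prediction)  # -p
error[target_idx] += 1        # one_hot - p
print(error.ravel())          # positive at the target, negative elsewhere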
@@ -366,12 +330,10 @@ if __name__ == "__main__":
     iter and Schmidhuber in 1997, and were refined and "
     "popularized by many people in following work."""
 
-    lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)
+    # lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)
 
     ##### Training #####
-    lstm.train()
+    # lstm.train()
 
     ##### Testing #####
-    lstm.test()
+    # lstm.test()
+
+    # testing can be done by uncommenting the above lines of code.
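Running the module as committed therefore constructs nothing and trains nothing; restoring the original behavior is a matter of uncommenting the three driver lines inside the main guard:

lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)
lstm.train()
lstm.test()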