diff --git a/neural_network/lstm.py b/neural_network/lstm.py
index 9abd96053..7eaf6bdac 100644
--- a/neural_network/lstm.py
+++ b/neural_network/lstm.py
@@ -10,41 +10,8 @@
 Github: LEVII007
 Date: [Current Date]
 """
-#### Explanation #####
-# This script implements a Long Short-Term Memory (LSTM)
-# network to learn and predict sequences of characters.
-# It uses numpy for numerical operations and tqdm for progress visualization.
+# from typing import dict, list
 
-# The data is a paragraph about LSTM, converted to
-# lowercase and split into characters.
-# Each character is one-hot encoded for training.
-
-# The LSTM class initializes weights and biases for the
-# forget, input, candidate, and output gates.
-# It also initializes weights and biases for the final output layer.
-
-# The forward method performs forward propagation
-# through the LSTM network, computing hidden and cell states.
-# It uses sigmoid and tanh activation functions for the gates and cell states.
-
-# The backward method performs backpropagation
-# through time, computing gradients for the weights and biases.
-# It updates the weights and biases using the
-# computed gradients and the learning rate.
-
-# The train method trains the LSTM network on
-# the input data for a specified number of epochs.
-# It uses one-hot encoded inputs and computes
-# errors using the softmax function.
-
-# The test method evaluates the trained LSTM
-# network on the input data, computing accuracy based on predictions.
-
-# The script initializes the LSTM network with
-# specified hyperparameters and trains it on the input data.
-# Finally, it tests the trained network and prints the accuracy of the predictions.
-
-##### Imports #####
 import numpy as np
 from numpy.random import Generator
 from tqdm import tqdm
@@ -62,25 +29,37 @@ class LSTM:
         :param epochs: The number of training epochs.
         :param lr: The learning rate.
         """
-        self.data = data.lower()
-        self.hidden_dim = hidden_dim
-        self.epochs = epochs
-        self.lr = lr
+        self.data: str = data.lower()
+        self.hidden_dim: int = hidden_dim
+        self.epochs: int = epochs
+        self.lr: float = lr
 
-        self.chars = set(self.data)
-        self.data_size, self.char_size = len(self.data), len(self.chars)
+        self.chars: set[str] = set(self.data)
+        self.data_size: int = len(self.data)
+        self.char_size: int = len(self.chars)
 
         print(f"Data size: {self.data_size}, Char Size: {self.char_size}")
 
-        self.char_to_idx = {c: i for i, c in enumerate(self.chars)}
-        self.idx_to_char = dict(enumerate(self.chars))
+        self.char_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.chars)}
+        self.idx_to_char: dict[int, str] = dict(enumerate(self.chars))
 
-        self.train_X, self.train_y = self.data[:-1], self.data[1:]
+        self.train_X: str = self.data[:-1]
+        self.train_y: str = self.data[1:]
 
         self.rng: Generator = np.random.default_rng()
 
+        # Initialize attributes used in the reset method
+        self.concat_inputs: dict[int, np.ndarray] = {}
+        self.hidden_states: dict[int, np.ndarray] = {-1: np.zeros((self.hidden_dim, 1))}
+        self.cell_states: dict[int, np.ndarray] = {-1: np.zeros((self.hidden_dim, 1))}
+        self.activation_outputs: dict[int, np.ndarray] = {}
+        self.candidate_gates: dict[int, np.ndarray] = {}
+        self.output_gates: dict[int, np.ndarray] = {}
+        self.forget_gates: dict[int, np.ndarray] = {}
+        self.input_gates: dict[int, np.ndarray] = {}
+        self.outputs: dict[int, np.ndarray] = {}
+
         self.initialize_weights()
 
-    ##### Helper Functions #####
     def one_hot_encode(self, char: str) -> np.ndarray:
         """
         One-hot encode a character.
@@ -109,8 +88,8 @@ class LSTM:
         self.wo = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
         self.bo = np.zeros((self.hidden_dim, 1))
 
-        self.wy = self.init_weights(self.hidden_dim, self.char_size)
-        self.by = np.zeros((self.char_size, 1))
+        self.wy: np.ndarray = self.init_weights(self.hidden_dim, self.char_size)
+        self.by: np.ndarray = np.zeros((self.char_size, 1))
 
     def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
         """
@@ -118,14 +97,12 @@ class LSTM:
 
         :param input_dim: The input dimension.
         :param output_dim: The output dimension.
-        :param rng: The random number generator.
         :return: A matrix of initialized weights.
         """
         return self.rng.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(
             6 / (input_dim + output_dim)
         )
 
-    ##### Activation Functions #####
     def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
         """
         Sigmoid activation function.
@@ -160,16 +137,13 @@ class LSTM:
         exp_x = np.exp(x - np.max(x))
         return exp_x / exp_x.sum(axis=0)
 
-    ##### LSTM Network Methods #####
     def reset(self) -> None:
         """
         Reset the LSTM network states.
         """
         self.concat_inputs = {}
-
         self.hidden_states = {-1: np.zeros((self.hidden_dim, 1))}
         self.cell_states = {-1: np.zeros((self.hidden_dim, 1))}
-
         self.activation_outputs = {}
         self.candidate_gates = {}
         self.output_gates = {}
@@ -177,7 +151,7 @@ class LSTM:
         self.input_gates = {}
         self.outputs = {}
 
-    def forward(self, inputs: list) -> list:
+    def forward(self, inputs: list[np.ndarray]) -> list[np.ndarray]:
         """
         Perform forward propagation through the LSTM network.
 
@@ -217,7 +191,7 @@ class LSTM:
 
         return outputs
 
-    def backward(self, errors: list, inputs: list) -> None:
+    def backward(self, errors: list[np.ndarray], inputs: list[np.ndarray]) -> None:
         """
         Perform backpropagation through time to compute gradients and update weights.
 
@@ -237,23 +211,19 @@ class LSTM:
         for t in reversed(range(len(inputs))):
             error = errors[t]
 
-            # Final Gate Weights and Biases Errors
             d_wy += np.dot(error, self.hidden_states[t].T)
             d_by += error
 
-            # Hidden State Error
             d_hs = np.dot(self.wy.T, error) + dh_next
 
-            # Output Gate Weights and Biases Errors
             d_o = (
                 self.tanh(self.cell_states[t])
                 * d_hs
                 * self.sigmoid(self.output_gates[t], derivative=True)
             )
-            d_wo += np.dot(d_o, inputs[t].T)
+            d_wo += np.dot(d_o, self.concat_inputs[t].T)
             d_bo += d_o
 
-            # Cell State Error
             d_cs = (
                 self.tanh(self.tanh(self.cell_states[t]), derivative=True)
                 * self.output_gates[t]
@@ -261,34 +231,30 @@ class LSTM:
                 + dc_next
             )
 
-            # Forget Gate Weights and Biases Errors
             d_f = (
                 d_cs
                 * self.cell_states[t - 1]
                 * self.sigmoid(self.forget_gates[t], derivative=True)
             )
-            d_wf += np.dot(d_f, inputs[t].T)
+            d_wf += np.dot(d_f, self.concat_inputs[t].T)
             d_bf += d_f
 
-            # Input Gate Weights and Biases Errors
             d_i = (
                 d_cs
                 * self.candidate_gates[t]
                 * self.sigmoid(self.input_gates[t], derivative=True)
             )
-            d_wi += np.dot(d_i, inputs[t].T)
+            d_wi += np.dot(d_i, self.concat_inputs[t].T)
             d_bi += d_i
 
-            # Candidate Gate Weights and Biases Errors
             d_c = (
                 d_cs
                 * self.input_gates[t]
                 * self.tanh(self.candidate_gates[t], derivative=True)
            )
-            d_wc += np.dot(d_c, inputs[t].T)
+            d_wc += np.dot(d_c, self.concat_inputs[t].T)
             d_bc += d_c
 
-            # Concatenated Input Error (Sum of Error at Each Gate!)
             d_z = (
                 np.dot(self.wf.T, d_f)
                 + np.dot(self.wi.T, d_i)
@@ -296,25 +262,20 @@ class LSTM:
                 + np.dot(self.wo.T, d_o)
             )
 
-            # Error of Hidden State and Cell State at Next Time Step
             dh_next = d_z[: self.hidden_dim, :]
             dc_next = self.forget_gates[t] * d_cs
 
-        for d_ in (d_wf, d_bf, d_wi, d_bi, d_wc, d_bc, d_wo, d_bo, d_wy, d_by):
-            np.clip(d_, -1, 1, out=d_)
+        for d in (d_wf, d_bf, d_wi, d_bi, d_wc, d_bc, d_wo, d_bo, d_wy, d_by):
+            np.clip(d, -1, 1, out=d)
 
         self.wf += d_wf * self.lr
         self.bf += d_bf * self.lr
-
         self.wi += d_wi * self.lr
         self.bi += d_bi * self.lr
-
         self.wc += d_wc * self.lr
         self.bc += d_bc * self.lr
-
         self.wo += d_wo * self.lr
         self.bo += d_bo * self.lr
-
         self.wy += d_wy * self.lr
         self.by += d_by * self.lr
@@ -332,9 +293,12 @@ class LSTM:
                 errors.append(-self.softmax(predictions[t]))
                 errors[-1][self.char_to_idx[self.train_y[t]]] += 1
 
-            self.backward(errors, self.concat_inputs)
+            self.backward(errors, inputs)
 
     def test(self) -> None:
+        """
+        Test the trained LSTM network on the input data and print the accuracy.
+        """
         accuracy = 0
         probabilities = self.forward(
             [self.one_hot_encode(char) for char in self.train_X]
         )
@@ -366,12 +330,10 @@ if __name__ == "__main__":
     iter and Schmidhuber in 1997, and were refined and "
     "popularized by many people in following work."""
 
-    lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)
+    # lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)
 
     ##### Training #####
-    lstm.train()
+    # lstm.train()
 
     ##### Testing #####
-    lstm.test()
-
-# testing can be done by uncommenting the above lines of code.
+    # lstm.test()
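
The recurring change in backward (inputs[t].T becoming self.concat_inputs[t].T) is the substantive fix in this diff: every gate weight matrix is created as init_weights(self.char_size + self.hidden_dim, self.hidden_dim), so it acts on the concatenation of the previous hidden state and the current one-hot input, and its gradient must be an outer product with that same concatenated vector. A minimal shape-check sketch of this point, with illustrative sizes that are not taken from the PR:

import numpy as np

# Illustrative sizes only; the PR's __main__ used hidden_dim=25,
# and char_size depends on the training text.
hidden_dim, char_size = 25, 40

# Gate weights act on [h_{t-1}; x_t], matching
# init_weights(char_size + hidden_dim, hidden_dim) in the diff.
wf = np.zeros((hidden_dim, char_size + hidden_dim))

h_prev = np.zeros((hidden_dim, 1))  # previous hidden state
x_t = np.zeros((char_size, 1))  # one-hot input at step t
concat_t = np.concatenate((h_prev, x_t))  # shape (char_size + hidden_dim, 1)

d_f = np.zeros((hidden_dim, 1))  # forget-gate error at step t

# Post-fix gradient: outer product with the concatenated input.
d_wf = np.dot(d_f, concat_t.T)
assert d_wf.shape == wf.shape  # (hidden_dim, char_size + hidden_dim)

# The pre-fix np.dot(d_f, x_t.T) would give (hidden_dim, char_size),
# which cannot be added to wf.

The same hidden-first concatenation order is why d_z[: self.hidden_dim, :] recovers dh_next: the hidden-state slice sits at the top of the concatenated vector.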
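Since the __main__ calls are now commented out, a quick smoke test has to be driven by hand. A minimal sketch, assuming the module path from the diff header and a made-up sample string; the hyperparameters are copied from the commented-out __main__ block:

from neural_network.lstm import LSTM

# Stand-in training text, not the PR's data string.
sample = "long short-term memory networks learn order dependence in sequences."
lstm = LSTM(data=sample, hidden_dim=25, epochs=10, lr=0.05)
lstm.train()
lstm.test()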