added doctests for each function

Shashank 2024-10-15 23:39:50 +05:30
parent f058116f95
commit 750c9f6fc8


@@ -80,6 +80,18 @@ class LongShortTermMemory:
:param char: The character to encode.
:return: A one-hot encoded vector.
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
>>> output = lstm.one_hot_encode('a')
>>> isinstance(output, np.ndarray)
True
>>> output.shape
(5, 1)
>>> output = lstm.one_hot_encode('c')
>>> isinstance(output, np.ndarray)
True
>>> output.shape
(5, 1)
""" """
vector = np.zeros((self.vocabulary_size, 1)) vector = np.zeros((self.vocabulary_size, 1))
vector[self.char_to_index[char]] = 1 vector[self.char_to_index[char]] = 1
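For intuition, the same encoding can be reproduced outside the class; this is a minimal standalone sketch with a hypothetical five-character vocabulary, not code from the commit:

import numpy as np

vocabulary = sorted(set("abcde"))                         # 5 unique characters
char_to_index = {c: i for i, c in enumerate(vocabulary)}  # mirrors self.char_to_index
vector = np.zeros((len(vocabulary), 1))                   # one column slot per character
vector[char_to_index["c"]] = 1                            # flip only the slot for 'c'
print(vector.ravel())                                     # [0. 0. 1. 0. 0.]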
@@ -88,8 +100,48 @@ class LongShortTermMemory:
def initialize_weights(self) -> None:
"""
Initialize the weights and biases for the LSTM network.
This method initializes the weights and biases for the forget gate, input gate,
cell candidate, and output gate, as well as for the final output layer.
The doctests below verify that each parameter has the expected shape.
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
>>> # Check the shapes of the weights and biases after initialization
>>> lstm.initialize_weights()
>>> # Forget gate weights and bias
>>> lstm.forget_gate_weights.shape
(10, 15)
>>> lstm.forget_gate_bias.shape
(10, 1)
>>> # Input gate weights and bias
>>> lstm.input_gate_weights.shape
(10, 15)
>>> lstm.input_gate_bias.shape
(10, 1)
>>> # Cell candidate weights and bias
>>> lstm.cell_candidate_weights.shape
(10, 15)
>>> lstm.cell_candidate_bias.shape
(10, 1)
>>> # Output gate weights and bias
>>> lstm.output_gate_weights.shape
(10, 15)
>>> lstm.output_gate_bias.shape
(10, 1)
>>> # Output layer weights and bias
>>> lstm.output_layer_weights.shape
(5, 10)
>>> lstm.output_layer_bias.shape
(5, 1)
"""
self.forget_gate_weights = self.init_weights(
    self.vocabulary_size + self.hidden_layer_size, self.hidden_layer_size
)
@@ -110,10 +162,10 @@ class LongShortTermMemory:
)
self.output_gate_bias = np.zeros((self.hidden_layer_size, 1))
self.output_layer_weights = self.init_weights(
    self.hidden_layer_size, self.vocabulary_size
)
self.output_layer_bias = np.zeros((self.vocabulary_size, 1))
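For context on the (10, 15) gate shapes in the doctest above: each gate multiplies its weight matrix by the previous hidden state concatenated with the current one-hot input, so the second dimension is hidden_layer_size + vocabulary_size. A standalone numpy sketch of that shape arithmetic (illustrative only, not the commit's code):

import numpy as np

hidden_size, vocab_size = 10, 5
gate_weights = np.random.randn(hidden_size, hidden_size + vocab_size)  # (10, 15)
hidden_state = np.zeros((hidden_size, 1))
one_hot_input = np.zeros((vocab_size, 1))
combined = np.vstack((hidden_state, one_hot_input))  # (15, 1) combined input
pre_activation = gate_weights @ combined             # (10, 15) @ (15, 1) -> (10, 1)
print(pre_activation.shape)                          # (10, 1), same as the gate biases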
def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
"""
@@ -134,6 +186,16 @@ class LongShortTermMemory:
:param x: The input array.
:param derivative: Whether to compute the derivative.
:return: The sigmoid activation or its derivative.
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
>>> output = lstm.sigmoid(np.array([[1, 2, 3]]))
>>> isinstance(output, np.ndarray)
True
>>> np.round(output, 3)
array([[0.731, 0.881, 0.953]])
>>> derivative_output = lstm.sigmoid(output, derivative=True)
>>> np.round(derivative_output, 3)
array([[0.197, 0.105, 0.045]])
""" """
if derivative: if derivative:
return x * (1 - x) return x * (1 - x)
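Note the convention here: with derivative=True, x is expected to already be the sigmoid output, since sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)). A quick standalone check in plain numpy, independent of the class:

import numpy as np

z = np.array([[1.0, 2.0, 3.0]])
s = 1 / (1 + np.exp(-z))  # sigmoid activation
ds = s * (1 - s)          # derivative computed from the activation itself
print(np.round(ds, 3))    # [[0.197 0.105 0.045]], matching the doctest above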
@@ -146,6 +208,16 @@ class LongShortTermMemory:
:param x: The input array.
:param derivative: Whether to compute the derivative.
:return: The tanh activation or its derivative.
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
>>> output = lstm.tanh(np.array([[1, 2, 3]]))
>>> isinstance(output, np.ndarray)
True
>>> np.round(output, 3)
array([[0.762, 0.964, 0.995]])
>>> derivative_output = lstm.tanh(output, derivative=True)
>>> np.round(derivative_output, 3)
array([[0.42 , 0.071, 0.01 ]])
""" """
if derivative: if derivative:
return 1 - x**2 return 1 - x**2
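The same convention holds for tanh: derivative=True assumes x is already the tanh output, because d/dz tanh(z) = 1 - tanh(z)^2. A finite-difference sanity check (standalone sketch, not the commit's code):

import numpy as np

z = np.array([[1.0, 2.0, 3.0]])
t = np.tanh(z)
analytic = 1 - t**2                                       # derivative via the activation
numeric = (np.tanh(z + 1e-6) - np.tanh(z - 1e-6)) / 2e-6  # central difference
print(np.allclose(analytic, numeric, atol=1e-6))          # True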
@@ -157,6 +229,13 @@ class LongShortTermMemory:
:param x: The input array.
:return: The softmax activation.
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
>>> output = lstm.softmax(np.array([1, 2, 3]))
>>> isinstance(output, np.ndarray)
True
>>> np.round(output, 3)
array([0.09 , 0.245, 0.665])
""" """
exp_x = np.exp(x - np.max(x)) exp_x = np.exp(x - np.max(x))
return exp_x / exp_x.sum(axis=0) return exp_x / exp_x.sum(axis=0)
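Subtracting np.max(x) before exponentiating is the standard numerical-stability trick: the constant factor cancels in the ratio, so the probabilities are unchanged, but large inputs no longer overflow. A standalone illustration:

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])  # naive np.exp(x) would overflow to inf
shifted = np.exp(x - np.max(x))         # exponents are now <= 0, safe to evaluate
print(np.round(shifted / shifted.sum(axis=0), 3))  # [0.09  0.245 0.665]

The result matches the doctest for input [1, 2, 3] exactly because softmax is invariant to a constant shift.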
@@ -164,6 +243,20 @@ class LongShortTermMemory:
def reset_network_state(self) -> None:
"""
Reset the LSTM network states.
Resets the internal states of the LSTM network, including the combined inputs,
hidden states, cell states, gate activations, and network outputs.
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
>>> lstm.reset_network_state()
>>> lstm.hidden_states[-1].shape == (10, 1)
True
>>> lstm.cell_states[-1].shape == (10, 1)
True
>>> lstm.combined_inputs == {}
True
>>> lstm.network_outputs == {}
True
""" """
self.combined_inputs = {} self.combined_inputs = {}
self.hidden_states = {-1: np.zeros((self.hidden_layer_size, 1))} self.hidden_states = {-1: np.zeros((self.hidden_layer_size, 1))}
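The -1 key is the detail to notice: it holds the initial hidden (and cell) state so the forward pass at time step 0 can read hidden_states[t - 1] with no special case. A toy sketch of the pattern (illustrative, not the commit's forward pass):

import numpy as np

hidden_size = 10
hidden_states = {-1: np.zeros((hidden_size, 1))}  # initial state lives at key -1
for t in range(3):                                # toy forward loop
    previous = hidden_states[t - 1]               # t == 0 reads the initial state
    hidden_states[t] = np.tanh(previous + 1)      # stand-in for the real cell update
print(sorted(hidden_states))                      # [-1, 0, 1, 2]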
@@ -232,12 +325,6 @@ class LongShortTermMemory:
return outputs
def backward_pass(self, errors: list[np.ndarray], inputs: list[np.ndarray]) -> None:
"""
Perform backpropagation through time to compute gradients and update weights.
:param errors: The errors at each time step.
:param inputs: The input data as a list of one-hot encoded vectors.
"""
d_forget_gate_weights, d_forget_gate_bias = 0, 0
d_input_gate_weights, d_input_gate_bias = 0, 0
d_cell_candidate_weights, d_cell_candidate_bias = 0, 0
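Seeding the accumulators with the scalar 0 works because the first += against a numpy array broadcasts the scalar away; backpropagation through time then sums each parameter's gradient across all time steps. A toy sketch of that accumulation pattern (stand-in gradients, not the real ones):

import numpy as np

d_weights = 0                                     # scalar seed; becomes an array on first +=
for t in reversed(range(3)):                      # walk the time steps backwards (BPTT)
    step_gradient = np.full((10, 15), float(t))   # stand-in for the real per-step gradient
    d_weights += step_gradient                    # accumulates into shape (10, 15)
print(d_weights.shape)                            # (10, 15)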
@@ -400,8 +487,8 @@ if __name__ == "__main__":
# learning_rate=0.05,
# )
# #### Training #####
# lstm_model.train()
# #### Testing #####
# lstm_model.test()
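Since the point of this commit is the doctests, a natural follow-up is running them. One common pattern (assuming the module is saved as lstm.py, a hypothetical name) is a doctest hook under the existing __main__ guard:

import doctest

# Run every >>> example added in this commit; TestResults reports counts.
results = doctest.testmod(verbose=False)
print(f"{results.attempted} examples run, {results.failed} failures")

Equivalently, `python -m doctest lstm.py -v` exercises the same examples from the command line.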