From 98332393b2aabedbd8b806e2b3f6b561415d65a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CShashank?= Date: Tue, 15 Oct 2024 22:18:25 +0530 Subject: [PATCH] Refactor LSTM class: Improve test method and add comments --- neural_network/lstm.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/neural_network/lstm.py b/neural_network/lstm.py index 6d1ea1cd3..f03c578a3 100644 --- a/neural_network/lstm.py +++ b/neural_network/lstm.py @@ -346,18 +346,12 @@ class LongShortTermMemory: self.backward_pass(errors, inputs) - def test(self) -> None: + def test(self): """ - Test the trained LSTM network on the input data and print the accuracy. + Test the LSTM model. - >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10, - ... training_epochs=5, learning_rate=0.01) - >>> lstm is not None - True - >>> lstm.train() - >>> output = lstm.test() - >>> output is not None - True + Returns: + str: The output predictions. """ accuracy = 0 probabilities = self.forward_pass( @@ -366,6 +360,7 @@ class LongShortTermMemory: output = "" for t in range(len(self.target_sequence)): + # Apply softmax to get probabilities for predictions probs = self.softmax(probabilities[t].reshape(-1)) prediction_index = self.random_generator.choice( self.vocabulary_size, p=probs @@ -374,17 +369,18 @@ class LongShortTermMemory: output += prediction + # Calculate accuracy if prediction == self.target_sequence[t]: accuracy += 1 - # print(f"Ground Truth:\n{self.target_sequence}\n") - # print(f"Predictions:\n{output}\n") - - # print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%") + print(f"Ground Truth:\n{self.target_sequence}\n") + print(f"Predictions:\n{output}\n") + print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%") return output + if __name__ == "__main__": sample_data = """Long Short-Term Memory (LSTM) networks are a type of recurrent neural network (RNN) capable of learning "