diff --git a/neural_network/lstm.py b/neural_network/lstm.py
index 6d1ea1cd3..f03c578a3 100644
--- a/neural_network/lstm.py
+++ b/neural_network/lstm.py
@@ -346,18 +346,12 @@ class LongShortTermMemory:
 
             self.backward_pass(errors, inputs)
 
-    def test(self) -> None:
+    def test(self):
         """
-        Test the trained LSTM network on the input data and print the accuracy.
+        Test the LSTM model.
 
-        >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
-        ... training_epochs=5, learning_rate=0.01)
-        >>> lstm is not None
-        True
-        >>> lstm.train()
-        >>> output = lstm.test()
-        >>> output is not None
-        True
+        Returns:
+            str: The output predictions.
         """
         accuracy = 0
         probabilities = self.forward_pass(
@@ -366,6 +360,7 @@ class LongShortTermMemory:
         output = ""
 
         for t in range(len(self.target_sequence)):
+            # Apply softmax to get probabilities for predictions
             probs = self.softmax(probabilities[t].reshape(-1))
             prediction_index = self.random_generator.choice(
                 self.vocabulary_size, p=probs
@@ -374,17 +369,18 @@ class LongShortTermMemory:
 
             output += prediction
 
+            # Calculate accuracy
             if prediction == self.target_sequence[t]:
                 accuracy += 1
 
-        # print(f"Ground Truth:\n{self.target_sequence}\n")
-        # print(f"Predictions:\n{output}\n")
-
-        # print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
+        print(f"Ground Truth:\n{self.target_sequence}\n")
+        print(f"Predictions:\n{output}\n")
+        print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
 
         return output
 
+
 if __name__ == "__main__":
     sample_data = """Long Short-Term Memory (LSTM) networks are a type of
     recurrent neural network (RNN) capable of learning "
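
For readers of this diff: below is a minimal standalone sketch of the prediction logic that test() now documents with comments, i.e. softmax over each timestep's raw scores, a categorical draw from the resulting distribution, and a running accuracy count. The names logits, target, and vocabulary are made up for illustration and stand in for the model's forward_pass() output, target_sequence, and character vocabulary; this is not the module's API, just the sampling step in isolation.

# Illustrative sketch only: fake per-timestep logits stand in for forward_pass() output.
import numpy as np

rng = np.random.default_rng(seed=0)

vocabulary = list("abcde")
vocab_size = len(vocabulary)
index_to_char = dict(enumerate(vocabulary))

target = "abcdeabcde"  # stand-in for target_sequence
logits = rng.normal(size=(len(target), vocab_size))  # stand-in for forward_pass() output


def softmax(x: np.ndarray) -> np.ndarray:
    # Subtract the max for numerical stability before exponentiating.
    exp_x = np.exp(x - np.max(x))
    return exp_x / exp_x.sum()


accuracy = 0
output = ""
for t, expected in enumerate(target):
    # Turn raw scores into a probability distribution over the vocabulary.
    probs = softmax(logits[t].reshape(-1))
    # Sample a character index from that distribution (a draw, not argmax).
    prediction_index = rng.choice(vocab_size, p=probs)
    prediction = index_to_char[int(prediction_index)]
    output += prediction
    # Count exact matches against the target character.
    if prediction == expected:
        accuracy += 1

print(f"Ground Truth:\n{target}\n")
print(f"Predictions:\n{output}\n")
print(f"Accuracy: {round(accuracy * 100 / len(target), 2)}%")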