Refactor LSTM class: Improve test method and add comments

This commit is contained in:
“Shashank 2024-10-15 22:18:25 +05:30
parent b1e7e72524
commit 98332393b2

View File

@ -346,18 +346,12 @@ class LongShortTermMemory:
self.backward_pass(errors, inputs)
def test(self) -> None:
def test(self):
"""
Test the trained LSTM network on the input data and print the accuracy.
Test the LSTM model.
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
... training_epochs=5, learning_rate=0.01)
>>> lstm is not None
True
>>> lstm.train()
>>> output = lstm.test()
>>> output is not None
True
Returns:
str: The output predictions.
"""
accuracy = 0
probabilities = self.forward_pass(
@ -366,6 +360,7 @@ class LongShortTermMemory:
output = ""
for t in range(len(self.target_sequence)):
# Apply softmax to get probabilities for predictions
probs = self.softmax(probabilities[t].reshape(-1))
prediction_index = self.random_generator.choice(
self.vocabulary_size, p=probs
@ -374,17 +369,18 @@ class LongShortTermMemory:
output += prediction
# Calculate accuracy
if prediction == self.target_sequence[t]:
accuracy += 1
# print(f"Ground Truth:\n{self.target_sequence}\n")
# print(f"Predictions:\n{output}\n")
# print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
print(f"Ground Truth:\n{self.target_sequence}\n")
print(f"Predictions:\n{output}\n")
print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
return output
if __name__ == "__main__":
sample_data = """Long Short-Term Memory (LSTM) networks are a type
of recurrent neural network (RNN) capable of learning "