mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-04-16 10:47:37 +00:00
Refactor LSTM class: Improve test method and add comments
This commit is contained in:
parent
b1e7e72524
commit
98332393b2
@ -346,18 +346,12 @@ class LongShortTermMemory:
|
||||
|
||||
self.backward_pass(errors, inputs)
|
||||
|
||||
def test(self) -> None:
|
||||
def test(self):
|
||||
"""
|
||||
Test the trained LSTM network on the input data and print the accuracy.
|
||||
Test the LSTM model.
|
||||
|
||||
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
|
||||
... training_epochs=5, learning_rate=0.01)
|
||||
>>> lstm is not None
|
||||
True
|
||||
>>> lstm.train()
|
||||
>>> output = lstm.test()
|
||||
>>> output is not None
|
||||
True
|
||||
Returns:
|
||||
str: The output predictions.
|
||||
"""
|
||||
accuracy = 0
|
||||
probabilities = self.forward_pass(
|
||||
@ -366,6 +360,7 @@ class LongShortTermMemory:
|
||||
|
||||
output = ""
|
||||
for t in range(len(self.target_sequence)):
|
||||
# Apply softmax to get probabilities for predictions
|
||||
probs = self.softmax(probabilities[t].reshape(-1))
|
||||
prediction_index = self.random_generator.choice(
|
||||
self.vocabulary_size, p=probs
|
||||
@ -374,17 +369,18 @@ class LongShortTermMemory:
|
||||
|
||||
output += prediction
|
||||
|
||||
# Calculate accuracy
|
||||
if prediction == self.target_sequence[t]:
|
||||
accuracy += 1
|
||||
|
||||
# print(f"Ground Truth:\n{self.target_sequence}\n")
|
||||
# print(f"Predictions:\n{output}\n")
|
||||
|
||||
# print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
|
||||
print(f"Ground Truth:\n{self.target_sequence}\n")
|
||||
print(f"Predictions:\n{output}\n")
|
||||
print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sample_data = """Long Short-Term Memory (LSTM) networks are a type
|
||||
of recurrent neural network (RNN) capable of learning "
|
||||
|
Loading…
x
Reference in New Issue
Block a user