mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-04-18 19:57:35 +00:00
Refactor LSTM class: Improve test method and add comments
This commit is contained in:
parent
b1e7e72524
commit
98332393b2
@ -346,18 +346,12 @@ class LongShortTermMemory:
|
|||||||
|
|
||||||
self.backward_pass(errors, inputs)
|
self.backward_pass(errors, inputs)
|
||||||
|
|
||||||
def test(self) -> None:
|
def test(self):
|
||||||
"""
|
"""
|
||||||
Test the trained LSTM network on the input data and print the accuracy.
|
Test the LSTM model.
|
||||||
|
|
||||||
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
|
Returns:
|
||||||
... training_epochs=5, learning_rate=0.01)
|
str: The output predictions.
|
||||||
>>> lstm is not None
|
|
||||||
True
|
|
||||||
>>> lstm.train()
|
|
||||||
>>> output = lstm.test()
|
|
||||||
>>> output is not None
|
|
||||||
True
|
|
||||||
"""
|
"""
|
||||||
accuracy = 0
|
accuracy = 0
|
||||||
probabilities = self.forward_pass(
|
probabilities = self.forward_pass(
|
||||||
@ -366,6 +360,7 @@ class LongShortTermMemory:
|
|||||||
|
|
||||||
output = ""
|
output = ""
|
||||||
for t in range(len(self.target_sequence)):
|
for t in range(len(self.target_sequence)):
|
||||||
|
# Apply softmax to get probabilities for predictions
|
||||||
probs = self.softmax(probabilities[t].reshape(-1))
|
probs = self.softmax(probabilities[t].reshape(-1))
|
||||||
prediction_index = self.random_generator.choice(
|
prediction_index = self.random_generator.choice(
|
||||||
self.vocabulary_size, p=probs
|
self.vocabulary_size, p=probs
|
||||||
@ -374,17 +369,18 @@ class LongShortTermMemory:
|
|||||||
|
|
||||||
output += prediction
|
output += prediction
|
||||||
|
|
||||||
|
# Calculate accuracy
|
||||||
if prediction == self.target_sequence[t]:
|
if prediction == self.target_sequence[t]:
|
||||||
accuracy += 1
|
accuracy += 1
|
||||||
|
|
||||||
# print(f"Ground Truth:\n{self.target_sequence}\n")
|
print(f"Ground Truth:\n{self.target_sequence}\n")
|
||||||
# print(f"Predictions:\n{output}\n")
|
print(f"Predictions:\n{output}\n")
|
||||||
|
print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
|
||||||
# print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sample_data = """Long Short-Term Memory (LSTM) networks are a type
|
sample_data = """Long Short-Term Memory (LSTM) networks are a type
|
||||||
of recurrent neural network (RNN) capable of learning "
|
of recurrent neural network (RNN) capable of learning "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user