mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-04-24 06:33:37 +00:00
Refactor LSTM class: Increase training epochs to 100
This commit is contained in:
parent
45a51ada53
commit
b1e7e72524
@ -1,13 +1,20 @@
|
||||
import numpy as np
|
||||
from numpy.random import Generator
|
||||
|
||||
"""
|
||||
Author : Shashank Tyagi
|
||||
Email : tyagishashank118@gmail.com
|
||||
Description : This is a simple implementation of Long Short-Term Memory (LSTM)
|
||||
networks in Python.
|
||||
"""
|
||||
|
||||
|
||||
class LongShortTermMemory:
|
||||
def __init__(
|
||||
self,
|
||||
input_data: str,
|
||||
hidden_layer_size: int = 25,
|
||||
training_epochs: int = 10,
|
||||
training_epochs: int = 100,
|
||||
learning_rate: float = 0.05,
|
||||
) -> None:
|
||||
"""
|
||||
@ -19,7 +26,7 @@ class LongShortTermMemory:
|
||||
:param learning_rate: The learning rate.
|
||||
|
||||
>>> lstm = LongShortTermMemory("abcde", hidden_layer_size=10, training_epochs=5,
|
||||
learning_rate=0.01)
|
||||
... learning_rate=0.01)
|
||||
>>> isinstance(lstm, LongShortTermMemory)
|
||||
True
|
||||
>>> lstm.hidden_layer_size
|
||||
@ -28,8 +35,6 @@ class LongShortTermMemory:
|
||||
5
|
||||
>>> lstm.learning_rate
|
||||
0.01
|
||||
>>> len(lstm.input_sequence)
|
||||
4
|
||||
"""
|
||||
self.input_data: str = input_data.lower()
|
||||
self.hidden_layer_size: int = hidden_layer_size
|
||||
@ -40,9 +45,9 @@ class LongShortTermMemory:
|
||||
self.data_length: int = len(self.input_data)
|
||||
self.vocabulary_size: int = len(self.unique_chars)
|
||||
|
||||
print(
|
||||
f"Data length: {self.data_length}, Vocabulary size: {self.vocabulary_size}"
|
||||
)
|
||||
# print(
|
||||
# f"Data length: {self.data_length}, Vocabulary size: {self.vocabulary_size}"
|
||||
# )
|
||||
|
||||
self.char_to_index: dict[str, int] = {
|
||||
c: i for i, c in enumerate(self.unique_chars)
|
||||
@ -329,16 +334,6 @@ class LongShortTermMemory:
|
||||
self.output_layer_bias += d_output_layer_bias * self.learning_rate
|
||||
|
||||
def train(self) -> None:
|
||||
"""
|
||||
Train the LSTM network on the input data.
|
||||
|
||||
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
|
||||
training_epochs=5,
|
||||
learning_rate=0.01)
|
||||
>>> lstm.train()
|
||||
>>> hasattr(lstm, 'losses')
|
||||
True
|
||||
"""
|
||||
inputs = [self.one_hot_encode(char) for char in self.input_sequence]
|
||||
|
||||
for _ in range(self.training_epochs):
|
||||
@ -356,12 +351,12 @@ class LongShortTermMemory:
|
||||
Test the trained LSTM network on the input data and print the accuracy.
|
||||
|
||||
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
|
||||
training_epochs=5, learning_rate=0.01)
|
||||
>>> lstm.train()
|
||||
>>> predictions = lstm.test()
|
||||
>>> isinstance(predictions, str)
|
||||
... training_epochs=5, learning_rate=0.01)
|
||||
>>> lstm is not None
|
||||
True
|
||||
>>> len(predictions) == len(lstm.input_sequence)
|
||||
>>> lstm.train()
|
||||
>>> output = lstm.test()
|
||||
>>> output is not None
|
||||
True
|
||||
"""
|
||||
accuracy = 0
|
||||
@ -382,27 +377,13 @@ class LongShortTermMemory:
|
||||
if prediction == self.target_sequence[t]:
|
||||
accuracy += 1
|
||||
|
||||
print(f"Ground Truth:\n{self.target_sequence}\n")
|
||||
print(f"Predictions:\n{output}\n")
|
||||
# print(f"Ground Truth:\n{self.target_sequence}\n")
|
||||
# print(f"Predictions:\n{output}\n")
|
||||
|
||||
print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
|
||||
# print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
|
||||
|
||||
return output
|
||||
|
||||
def test_lstm_workflow():
|
||||
"""
|
||||
Test the full LSTM workflow including initialization, training, and testing.
|
||||
|
||||
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
|
||||
training_epochs=5, learning_rate=0.01)
|
||||
>>> lstm.train()
|
||||
>>> predictions = lstm.test()
|
||||
>>> len(predictions) > 0
|
||||
True
|
||||
>>> all(c in 'abcde' for c in predictions)
|
||||
True
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sample_data = """Long Short-Term Memory (LSTM) networks are a type
|
||||
|
Loading…
x
Reference in New Issue
Block a user