mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-04-22 13:47:37 +00:00
added doc tests
This commit is contained in:
parent
e48555dfbc
commit
1608382d42
@ -28,6 +28,15 @@ class LSTM:
|
||||
:param epochs: The number of training epochs.
|
||||
:param lr: The learning rate.
|
||||
"""
|
||||
"""
|
||||
Test the LSTM model.
|
||||
|
||||
>>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
|
||||
>>> lstm.train()
|
||||
>>> predictions = lstm.test()
|
||||
>>> len(predictions) > 0
|
||||
True
|
||||
"""
|
||||
self.data: str = data.lower()
|
||||
self.hidden_dim: int = hidden_dim
|
||||
self.epochs: int = epochs
|
||||
@ -157,6 +166,15 @@ class LSTM:
|
||||
:param inputs: The input data as a list of one-hot encoded vectors.
|
||||
:return: The outputs of the network.
|
||||
"""
|
||||
"""
|
||||
Forward pass through the LSTM network.
|
||||
|
||||
>>> lstm = LSTM(data="abcde", hidden_dim=10, epochs=1, lr=0.01)
|
||||
>>> inputs = [lstm.one_hot_encode(char) for char in lstm.train_X]
|
||||
>>> outputs = lstm.forward(inputs)
|
||||
>>> len(outputs) == len(inputs)
|
||||
True
|
||||
"""
|
||||
self.reset()
|
||||
|
||||
outputs = []
|
||||
@ -282,6 +300,14 @@ class LSTM:
|
||||
"""
|
||||
Train the LSTM network on the input data.
|
||||
"""
|
||||
"""
|
||||
Train the LSTM network on the input data.
|
||||
|
||||
>>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
|
||||
>>> lstm.train()
|
||||
>>> lstm.losses[-1] < lstm.losses[0]
|
||||
True
|
||||
"""
|
||||
inputs = [self.one_hot_encode(char) for char in self.train_X]
|
||||
|
||||
for _ in range(self.epochs):
|
||||
@ -298,6 +324,15 @@ class LSTM:
|
||||
"""
|
||||
Test the trained LSTM network on the input data and print the accuracy.
|
||||
"""
|
||||
"""
|
||||
Test the LSTM model.
|
||||
|
||||
>>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
|
||||
>>> lstm.train()
|
||||
>>> predictions = lstm.test()
|
||||
>>> len(predictions) > 0
|
||||
True
|
||||
"""
|
||||
accuracy = 0
|
||||
probabilities = self.forward(
|
||||
[self.one_hot_encode(char) for char in self.train_X]
|
||||
@ -328,6 +363,8 @@ if __name__ == "__main__":
|
||||
"machine translation, speech recognition, and more.
|
||||
iter and Schmidhuber in 1997, and were refined and "
|
||||
"popularized by many people in following work."""
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
||||
# lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user