mirror of https://github.com/TheAlgorithms/Python.git (synced 2025-04-22 13:47:37 +00:00)

added doc tests

commit 1608382d42, parent e48555dfbc
@@ -28,6 +28,15 @@ class LSTM:
         :param epochs: The number of training epochs.
         :param lr: The learning rate.
         """
+        """
+        Test the LSTM model.
+
+        >>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
+        >>> lstm.train()
+        >>> predictions = lstm.test()
+        >>> len(predictions) > 0
+        True
+        """
         self.data: str = data.lower()
         self.hidden_dim: int = hidden_dim
         self.epochs: int = epochs
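The doctest added to __init__ exercises the whole workflow in one session: build the model on a short repeating string, train it, and check that test() returns a non-empty list of predictions. As a plain-script equivalent of that session (a minimal sketch; the import name lstm is an assumption, since the diff does not show the module path):

    from lstm import LSTM  # module name assumed for illustration

    model = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
    model.train()
    predictions = model.test()
    assert len(predictions) > 0  # the same property the doctest checks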
@@ -157,6 +166,15 @@ class LSTM:
         :param inputs: The input data as a list of one-hot encoded vectors.
         :return: The outputs of the network.
         """
+        """
+        Forward pass through the LSTM network.
+
+        >>> lstm = LSTM(data="abcde", hidden_dim=10, epochs=1, lr=0.01)
+        >>> inputs = [lstm.one_hot_encode(char) for char in lstm.train_X]
+        >>> outputs = lstm.forward(inputs)
+        >>> len(outputs) == len(inputs)
+        True
+        """
         self.reset()

         outputs = []
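The forward() doctest builds its inputs with lstm.one_hot_encode and then checks that forward() returns one output per input vector. For context, a standalone sketch of one-hot encoding a character over a fixed vocabulary is shown below; it is illustrative only and is not the repository's one_hot_encode implementation.

    import numpy as np

    def one_hot(char: str, vocab: str) -> np.ndarray:
        # Column vector of zeros with a single 1 at the character's index in the vocabulary.
        vec = np.zeros((len(vocab), 1))
        vec[vocab.index(char)] = 1.0
        return vec

    inputs = [one_hot(c, vocab="abcde") for c in "abcde"]
    print(len(inputs), inputs[0].shape)  # 5 (5, 1)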
@@ -282,6 +300,14 @@ class LSTM:
         """
         Train the LSTM network on the input data.
         """
+        """
+        Train the LSTM network on the input data.
+
+        >>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
+        >>> lstm.train()
+        >>> lstm.losses[-1] < lstm.losses[0]
+        True
+        """
         inputs = [self.one_hot_encode(char) for char in self.train_X]

         for _ in range(self.epochs):
@@ -298,6 +324,15 @@ class LSTM:
         """
         Test the trained LSTM network on the input data and print the accuracy.
         """
+        """
+        Test the LSTM model.
+
+        >>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
+        >>> lstm.train()
+        >>> predictions = lstm.test()
+        >>> len(predictions) > 0
+        True
+        """
         accuracy = 0
         probabilities = self.forward(
             [self.one_hot_encode(char) for char in self.train_X]
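test() prints an accuracy derived from the probability vectors that forward() returns. The sketch below shows the usual way such a figure is computed, by taking the arg-max of each probability vector and comparing it with the target index; the names probabilities and targets are hypothetical and the repository's internals may differ.

    import numpy as np

    def accuracy_from_probabilities(
        probabilities: list[np.ndarray], targets: list[int]
    ) -> float:
        # Fraction of steps where the most probable class matches the target index.
        correct = sum(
            int(np.argmax(probs) == target)
            for probs, target in zip(probabilities, targets)
        )
        return correct / len(targets)

    probs = [np.array([0.1, 0.7, 0.2]), np.array([0.6, 0.3, 0.1])]
    print(accuracy_from_probabilities(probs, targets=[1, 2]))  # 0.5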
@@ -328,6 +363,8 @@ if __name__ == "__main__":
         "machine translation, speech recognition, and more.
         iter and Schmidhuber in 1997, and were refined and "
         "popularized by many people in following work."""
+    import doctest
+    doctest.testmod()

     # lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)

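The last hunk wires everything into the __main__ block with doctest.testmod(). A rough sketch of driving the same machinery from a REPL and of listing which examples doctest discovers (doctest only collects examples that appear in actual docstrings, i.e. the first string literal of a module, class, or function); the import name lstm is an assumption:

    import doctest

    import lstm  # assumed import name of the edited module

    # Run every example doctest finds in the module's docstrings.
    print(doctest.testmod(lstm, verbose=False))  # e.g. TestResults(failed=0, attempted=...)

    # List the discovered docstrings that actually contain examples.
    for test in doctest.DocTestFinder().find(lstm):
        if test.examples:
            print(test.name, len(test.examples))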