added doc tests

This commit is contained in:
“Shashank 2024-10-15 20:50:31 +05:30
parent e48555dfbc
commit 1608382d42

View File

@ -28,6 +28,15 @@ class LSTM:
:param epochs: The number of training epochs.
:param lr: The learning rate.
"""
"""
Test the LSTM model.
>>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
>>> lstm.train()
>>> predictions = lstm.test()
>>> len(predictions) > 0
True
"""
self.data: str = data.lower()
self.hidden_dim: int = hidden_dim
self.epochs: int = epochs
@ -157,6 +166,15 @@ class LSTM:
:param inputs: The input data as a list of one-hot encoded vectors.
:return: The outputs of the network.
"""
"""
Forward pass through the LSTM network.
>>> lstm = LSTM(data="abcde", hidden_dim=10, epochs=1, lr=0.01)
>>> inputs = [lstm.one_hot_encode(char) for char in lstm.train_X]
>>> outputs = lstm.forward(inputs)
>>> len(outputs) == len(inputs)
True
"""
self.reset()
outputs = []
@ -282,6 +300,14 @@ class LSTM:
"""
Train the LSTM network on the input data.
"""
"""
Train the LSTM network on the input data.
>>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
>>> lstm.train()
>>> lstm.losses[-1] < lstm.losses[0]
True
"""
inputs = [self.one_hot_encode(char) for char in self.train_X]
for _ in range(self.epochs):
@ -298,6 +324,15 @@ class LSTM:
"""
Test the trained LSTM network on the input data and print the accuracy.
"""
"""
Test the LSTM model.
>>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
>>> lstm.train()
>>> predictions = lstm.test()
>>> len(predictions) > 0
True
"""
accuracy = 0
probabilities = self.forward(
[self.one_hot_encode(char) for char in self.train_X]
@ -328,6 +363,8 @@ if __name__ == "__main__":
"machine translation, speech recognition, and more.
iter and Schmidhuber in 1997, and were refined and "
"popularized by many people in following work."""
import doctest
doctest.testmod()
# lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)