added doc tests

This commit is contained in:
“Shashank 2024-10-15 20:50:31 +05:30
parent e48555dfbc
commit 1608382d42

View File

@ -28,6 +28,15 @@ class LSTM:
:param epochs: The number of training epochs. :param epochs: The number of training epochs.
:param lr: The learning rate. :param lr: The learning rate.
""" """
"""
Test the LSTM model.
>>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
>>> lstm.train()
>>> predictions = lstm.test()
>>> len(predictions) > 0
True
"""
self.data: str = data.lower() self.data: str = data.lower()
self.hidden_dim: int = hidden_dim self.hidden_dim: int = hidden_dim
self.epochs: int = epochs self.epochs: int = epochs
@ -157,6 +166,15 @@ class LSTM:
:param inputs: The input data as a list of one-hot encoded vectors. :param inputs: The input data as a list of one-hot encoded vectors.
:return: The outputs of the network. :return: The outputs of the network.
""" """
"""
Forward pass through the LSTM network.
>>> lstm = LSTM(data="abcde", hidden_dim=10, epochs=1, lr=0.01)
>>> inputs = [lstm.one_hot_encode(char) for char in lstm.train_X]
>>> outputs = lstm.forward(inputs)
>>> len(outputs) == len(inputs)
True
"""
self.reset() self.reset()
outputs = [] outputs = []
@ -282,6 +300,14 @@ class LSTM:
""" """
Train the LSTM network on the input data. Train the LSTM network on the input data.
""" """
"""
Train the LSTM network on the input data.
>>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
>>> lstm.train()
>>> lstm.losses[-1] < lstm.losses[0]
True
"""
inputs = [self.one_hot_encode(char) for char in self.train_X] inputs = [self.one_hot_encode(char) for char in self.train_X]
for _ in range(self.epochs): for _ in range(self.epochs):
@ -298,6 +324,15 @@ class LSTM:
""" """
Test the trained LSTM network on the input data and print the accuracy. Test the trained LSTM network on the input data and print the accuracy.
""" """
"""
Test the LSTM model.
>>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
>>> lstm.train()
>>> predictions = lstm.test()
>>> len(predictions) > 0
True
"""
accuracy = 0 accuracy = 0
probabilities = self.forward( probabilities = self.forward(
[self.one_hot_encode(char) for char in self.train_X] [self.one_hot_encode(char) for char in self.train_X]
@ -328,6 +363,8 @@ if __name__ == "__main__":
"machine translation, speech recognition, and more. "machine translation, speech recognition, and more.
iter and Schmidhuber in 1997, and were refined and " iter and Schmidhuber in 1997, and were refined and "
"popularized by many people in following work.""" "popularized by many people in following work."""
import doctest
doctest.testmod()
# lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05) # lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)