diff --git a/neural_network/lstm.py b/neural_network/lstm.py index 20df37d23..c642e6df1 100644 --- a/neural_network/lstm.py +++ b/neural_network/lstm.py @@ -28,6 +28,15 @@ class LSTM: :param epochs: The number of training epochs. :param lr: The learning rate. """ + """ + Test the LSTM model. + + >>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01) + >>> lstm.train() + >>> predictions = lstm.test() + >>> len(predictions) > 0 + True + """ self.data: str = data.lower() self.hidden_dim: int = hidden_dim self.epochs: int = epochs @@ -157,6 +166,15 @@ class LSTM: :param inputs: The input data as a list of one-hot encoded vectors. :return: The outputs of the network. """ + """ + Forward pass through the LSTM network. + + >>> lstm = LSTM(data="abcde", hidden_dim=10, epochs=1, lr=0.01) + >>> inputs = [lstm.one_hot_encode(char) for char in lstm.train_X] + >>> outputs = lstm.forward(inputs) + >>> len(outputs) == len(inputs) + True + """ self.reset() outputs = [] @@ -282,6 +300,14 @@ class LSTM: """ Train the LSTM network on the input data. """ + """ + Train the LSTM network on the input data. + + >>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01) + >>> lstm.train() + >>> lstm.losses[-1] < lstm.losses[0] + True + """ inputs = [self.one_hot_encode(char) for char in self.train_X] for _ in range(self.epochs): @@ -298,6 +324,15 @@ class LSTM: """ Test the trained LSTM network on the input data and print the accuracy. """ + """ + Test the LSTM model. + + >>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01) + >>> lstm.train() + >>> predictions = lstm.test() + >>> len(predictions) > 0 + True + """ accuracy = 0 probabilities = self.forward( [self.one_hot_encode(char) for char in self.train_X] @@ -328,6 +363,8 @@ if __name__ == "__main__": "machine translation, speech recognition, and more. iter and Schmidhuber in 1997, and were refined and " "popularized by many people in following work.""" + import doctest + doctest.testmod() # lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)