"""
LeNet Network

Paper: http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf
"""

import numpy
import torch
import torch.nn as nn


class LeNet(nn.Module):
    """
    LeNet-style convolutional network.

    Three 5x5 convolutions (no padding, stride 1) with tanh activations,
    2x2 average pooling after the first two, followed by two fully
    connected layers.

    Because no padding is used, a 32x32 input shrinks
    32 -> 28 -> 14 -> 10 -> 5 -> 1 through the conv/pool stack, which is
    what makes the 120 feature maps of ``conv3`` line up with the 120 input
    features of ``linear1``. Other spatial sizes will fail at ``linear1``.

    Args:
        in_channels (int): Number of channels in the input images.
            Defaults to 1 (grayscale).
        num_classes (int): Number of scores produced per image.
            Defaults to 10.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()

        # Parameter-free layers, shared by every stage of the network.
        self.tanh = nn.Tanh()
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Layers are created in the same order as before so that seeded
        # initialisation and state_dict keys stay unchanged.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=6,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(0, 0),
        )
        self.conv2 = nn.Conv2d(
            in_channels=6,
            out_channels=16,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(0, 0),
        )
        self.conv3 = nn.Conv2d(
            in_channels=16,
            out_channels=120,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(0, 0),
        )

        self.linear1 = nn.Linear(120, 84)
        self.linear2 = nn.Linear(84, num_classes)

    def forward(self, image_array: torch.Tensor) -> torch.Tensor:
        """
        Run a forward pass.

        Args:
            image_array (torch.Tensor): Batch of images, expected as
                ``(batch, in_channels, 32, 32)``.

        Returns:
            torch.Tensor: Output of the last linear layer with shape
            ``(batch, num_classes)``; no activation is applied to it.
        """
        features = self.avgpool(self.tanh(self.conv1(image_array)))
        features = self.avgpool(self.tanh(self.conv2(features)))
        features = self.tanh(self.conv3(features))

        # Collapse everything after the leading (batch) dimension into one
        # feature axis for the fully connected layers.
        features = torch.flatten(features, start_dim=1)
        hidden = self.tanh(self.linear1(features))
        return self.linear2(hidden)


def test_model(image_tensor: torch.Tensor) -> bool:
    """
    Test the model on an input batch of 64 images

    Args:
        image_tensor (torch.Tensor): Batch of Images for the model

    Returns:
        bool: True when the forward pass succeeds and the output has shape
        ``(64, 10)``; False when the forward pass raises ``RuntimeError``
        or the output has any other shape.

    >>> test_model(torch.randn(64, 1, 32, 32))
    True

    """
    model = LeNet()
    try:
        # Only the output shape is inspected, so gradients are not needed.
        with torch.no_grad():
            output = model(image_tensor)
    except RuntimeError:
        return False

    # Compare against a Size directly instead of allocating a dummy tensor.
    return output.shape == torch.Size([64, 10])


if __name__ == "__main__":
    # One correctly shaped batch and one input missing the batch dimension.
    batched_images = torch.randn(64, 1, 32, 32)
    unbatched_image = torch.randn(1, 32, 32)

    batched_passed = test_model(batched_images)
    print(f"random_image_1 Model Passed: {batched_passed}")

    unbatched_passed = test_model(unbatched_image)
    print(f"\nrandom_image_2 Model Passed: {unbatched_passed}")