import argparse
import math
from pathlib import Path
from PIL import Image
import torch
from torch import nn
from torchvision import transforms


class OrientationModel(nn.Module):
    """CNN that maps a single-channel image to two output values.

    The feature extractor is five identical stages of
    ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2, 2)``, with channel widths
    1 -> 32 -> 64 -> 128 -> 128 -> 128. Each pooling stage halves (floors) the
    spatial size, so H and W shrink by a factor of 32 overall.

    The head starts with ``LazyLinear``, so its input width is inferred on the
    first forward pass (or from a loaded state dict). Images fed to a trained
    model presumably need the same H, W as the training data — TODO confirm.

    ``predict_angle`` uses the two outputs as the (y, x) arguments of
    ``atan2``.
    """

    def __init__(self):
        super().__init__()

        # Channel width entering/leaving each conv stage.
        widths = (1, 32, 64, 128, 128, 128)
        stages = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            # Appending 3 modules per stage keeps the Sequential indices
            # (conv.0, conv.3, conv.6, ...) identical to saved state dicts.
            stages.extend(
                (
                    nn.Conv2d(in_channels=width_in, out_channels=width_out, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
            )
        self.conv = nn.Sequential(*stages)

        self.flatten = nn.Flatten()

        head = [
            nn.LazyLinear(out_features=64),  # input size inferred lazily
            nn.ReLU(),
            nn.Dropout(0.25),  # only active in train() mode
            nn.Linear(in_features=64, out_features=2),
        ]
        self.fc = nn.Sequential(*head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run conv features -> flatten -> fully connected head; returns (N, 2)."""
        features = self.flatten(self.conv(x))
        return self.fc(features)


def predict_angle(model_path: Path, image: Image.Image) -> float:
    """Predict the orientation angle of *image*, in degrees, with a saved model.

    Args:
        model_path: Path to a state dict saved from an ``OrientationModel``.
        image: PIL image to analyse. It is converted to greyscale ("L")
            because the model's first convolution takes a single input channel.

    Returns:
        ``atan2(output[0], output[1])`` of the model's two outputs, converted
        to degrees and returned as a plain Python float in [-180, 180].

    Note:
        The model's head is a ``LazyLinear`` whose input width is fixed by the
        saved weights, so *image* presumably must have the same dimensions as
        the training images — TODO confirm.
    """
    # ToTensor yields a (C, H, W) float tensor; unsqueeze adds the batch axis.
    image_tensor = transforms.ToTensor()(image.convert("L")).unsqueeze(0)

    model = OrientationModel()
    # map_location="cpu" lets weights saved from a GPU session load on a
    # CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted model
    # files (consider weights_only=True if the installed torch supports it).
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()  # disables Dropout for deterministic inference

    with torch.no_grad():
        output = model(image_tensor)[0]

    # .item() turns the 0-dim tensor into a float, matching the annotation and
    # what PIL's Image.rotate expects.
    return math.degrees(torch.atan2(output[0], output[1]).item())


# Command-line interface: predict an image's orientation and show the
# original next to the rotation-corrected version.
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=Path, default="orientation-model.pth", help="path to saved model state dict")
parser.add_argument("image_path", type=Path, help="path of PNG image to correct")

if __name__ == "__main__":
    args = parser.parse_args()
    # Greyscale ("L") matches the model's single input channel. convert()
    # loads the pixels into a new image, so the source file can be closed.
    with Image.open(args.image_path) as source_image:
        image = source_image.convert("L")

    # float() guarantees a plain number for printing and Image.rotate, even
    # if predict_angle hands back a 0-dim tensor.
    predicted_angle = float(predict_angle(args.model, image))

    print(f"{predicted_angle=}")
    image.show()  # original, for comparison with the corrected image
    # NOTE(review): Image.rotate turns counter-clockwise by the given degrees;
    # confirm the model's angle sign convention matches this correction.
    corrected_image = image.rotate(predicted_angle)
    corrected_image.show()
