import io
import json
import math
from pathlib import Path
import time

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps
from torchvision import transforms


def no_change(image: Image.Image) -> Image.Image:
    """Identity augmentation: return *image* unchanged (same object)."""
    return image


def jpeg_compress(image: Image.Image, quality: int = 10) -> Image.Image:
    """Round-trip *image* through an in-memory JPEG to add compression artifacts.

    Args:
        image: Source image. If the JPEG encoder rejects its mode (e.g. RGBA,
            P, LA), a grayscale copy is encoded instead, since the result is
            grayscale either way.
        quality: JPEG quality passed to ``Image.save``. The default of 10 is
            deliberately low so the artifacts are strong.

    Returns:
        The decoded JPEG as a new mode "L" image.
    """
    buffer = io.BytesIO()
    try:
        image.save(buffer, format="JPEG", quality=quality)
    except OSError:
        # Pillow raises OSError ("cannot write mode ... as JPEG") for modes the
        # encoder can't handle; start from a fresh buffer with a grayscale copy.
        buffer = io.BytesIO()
        image.convert("L").save(buffer, format="JPEG", quality=quality)
    buffer.seek(0)
    # convert() forces the lazily-opened JPEG to decode while buffer is alive.
    return Image.open(buffer).convert("L")


def compress_and_invert(image: Image.Image) -> Image.Image:
    """Apply ``jpeg_compress`` and then invert the resulting grayscale image."""
    return ImageOps.invert(jpeg_compress(image))


class OrientationDataset(Dataset):
    """Grayscale images paired with their rotation encoded as (sin, cos).

    The answersheet is a JSON list of objects. Each object needs at least a
    ``"filename"`` (resolved relative to the answersheet's directory) and a
    ``"degrees"`` rotation. Any other keys are carried along in the sample
    dict. Every object yields one sample per augmentation; ``no_change`` is
    always included.
    """

    def __init__(self, answersheet_path: Path, augment: bool):
        """Load the answersheet and expand it with the chosen augmentations.

        Args:
            answersheet_path: Path to the JSON answersheet file.
            augment: If true, also emit JPEG-compressed, inverted, and
                compressed+inverted variants of every image.
        """
        self.sample_path = answersheet_path.parent

        augmentations = [no_change]
        if augment:
            augmentations += [jpeg_compress, ImageOps.invert, compress_and_invert]

        # JSON is UTF-8; don't depend on the platform's locale encoding
        # (which read_text() uses when no encoding is given).
        records = json.loads(answersheet_path.read_text(encoding="utf-8"))
        # Record-major order: all augmentations of record 0, then record 1, ...
        self.samples = [
            {**record, "augmentation": augmentation}
            for record in records
            for augmentation in augmentations
        ]
        self.to_tensor = transforms.ToTensor()

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image_tensor, target)`` for sample *index*.

        The image is loaded as mode "L", augmented, and converted with
        ``ToTensor``. The target is a float32 tensor ``[sin(θ), cos(θ)]``,
        where θ is the sample's ``"degrees"`` converted to radians.
        """
        sample = self.samples[index]
        filename = self.sample_path / sample["filename"]
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent, fully loaded copy.
        with Image.open(filename) as source:
            image = source.convert("L")
        image = sample["augmentation"](image)
        radians = sample["degrees"] * math.pi / 180
        answer = torch.tensor([math.sin(radians), math.cos(radians)], dtype=torch.float32)
        return (self.to_tensor(image), answer)

    def __len__(self) -> int:
        """Number of (record, augmentation) samples."""
        return len(self.samples)


class OrientationModel(nn.Module):
    """Small CNN that regresses an image's orientation as a (sin, cos) pair.

    Five identical stages (3x3 conv with padding 1, ReLU, 2x2 max-pool)
    take the channel count 1 -> 32 -> 64 -> 128 -> 128 -> 128 and halve the
    spatial size each time. A two-layer head with dropout maps the flattened
    features to 2 outputs. The head's first layer is lazy, so its input
    width is inferred from the first batch it sees.
    """

    def __init__(self):
        super().__init__()
        widths = (1, 32, 64, 128, 128, 128)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        # Registration order (conv, flatten, fc) and layer indices inside each
        # Sequential determine the state_dict keys saved to disk.
        self.conv = nn.Sequential(*stages)
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.LazyLinear(out_features=64),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(in_features=64, out_features=2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of 1-channel images to unnormalized (sin, cos) outputs."""
        features = self.flatten(self.conv(x))
        return self.fc(features)


# Options shared by both loaders: batches of 32, loaded in the main process.
_loader_options = {"batch_size": 32, "num_workers": 0}

# Training data is augmented and reshuffled every epoch; test data is
# neither, so evaluation order is stable.
train_dataset = OrientationDataset(Path("data-train/answersheet.json"), augment=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)

test_dataset = OrientationDataset(Path("data-test/answersheet.json"), augment=False)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

num_epochs = 20
learning_rate = 0.001

model = OrientationModel()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Plain MSE between predicted and target (sin, cos) pairs.
criterion = nn.MSELoss()


def sincos_to_angles(sines: torch.Tensor, cosines: torch.Tensor) -> torch.Tensor:
    """Convert elementwise (sin, cos) pairs to angles in degrees.

    Uses ``atan2``, so the pairs need not be unit length (only their
    direction matters) and results lie in [-180, 180].
    """
    radians = torch.atan2(sines, cosines)
    return radians * 180 / math.pi


# Angular-error thresholds (degrees) reported after every epoch.
_THRESHOLDS = (5, 10, 20)


def _angle_hits(net, loader, thresholds=_THRESHOLDS):
    """Count samples whose predicted angle is within each threshold of the label.

    Runs *net* on every batch of *loader* as-is; the caller is responsible
    for eval mode and ``torch.no_grad()``.

    Returns:
        ``(hits, total)``: ``hits[i]`` counts samples whose wrapped angular
        error is strictly below ``thresholds[i]``; ``total`` is the number
        of samples seen.
    """
    hits = [0] * len(thresholds)
    total = 0
    for images, labels in loader:
        outputs = net(images)
        predicted = sincos_to_angles(outputs[:, 0], outputs[:, 1])
        expected = sincos_to_angles(labels[:, 0], labels[:, 1])
        diff = torch.abs(predicted - expected)
        # Both angles come from atan2 (range [-180, 180]); take the shorter
        # way round the circle so e.g. 179 vs -179 counts as 2 degrees.
        diff = torch.min(diff, 360 - diff)
        for i, threshold in enumerate(thresholds):
            hits[i] += (diff < threshold).sum().item()
        total += labels.size(0)
    return hits, total


def _percent(count, total):
    """Return count/total as a percentage, or NaN when *total* is 0."""
    return count / total * 100 if total else float("nan")


for epoch in range(num_epochs):
    start_time = time.monotonic()
    running_loss = 0.0
    seen = 0
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight each batch mean by its size so a short final batch doesn't
        # skew the epoch average.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        seen += batch_size
    avg_loss = running_loss / seen if seen else float("nan")

    model.eval()
    with torch.no_grad():
        test_hits, test_total = _angle_hits(model, test_loader)
        # NOTE: train_loader is shuffled and built with augment=True, so this
        # re-scores every augmented copy each epoch.
        train_hits, train_total = _angle_hits(model, train_loader)

    seconds_taken = time.monotonic() - start_time
    print(f"[Epoch {epoch + 1}/{num_epochs}] after {seconds_taken:.1f} seconds:")
    print(f"              loss: {avg_loss:.4f}")
    for heading, hits, total in (
        ("        -- TEST --", test_hits, test_total),
        ("        -- TRAIN --", train_hits, train_total),
    ):
        print(heading)
        for threshold, count in zip(_THRESHOLDS, hits):
            label = f"within {threshold} degrees"
            # Right-align to 18 columns to match the "loss:" line above.
            print(f"{label:>18}: {_percent(count, total):.2f}%")
    print()

torch.save(model.state_dict(), "orientation-model.pth")
