import copy
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.optim import lr_scheduler
from torch.utils import data
from torchvision import transforms

BATCH_SIZE = 2
# pil_img = Image.open(r'D:\code\dataset\exp3\training\00001_matte.png')
# np_img = np.array(pil_img)
# # binarize the matte for semantic segmentation: every non-zero pixel becomes class 1
# np_img[np_img > 0] = 1
# plt.imshow(np_img)
# plt.show()
# print(np.unique(np_img))

all_train_pics = glob.glob(r'D:\code\dataset\exp3\training\*.png')
train_images = [p for p in all_train_pics if 'matte' not in p]
train_anno = [p for p in all_train_pics if 'matte' in p]
# print(len(train_images))
# print(len(train_anno))

# Shuffle images and mattes with the same permutation so pairs stay aligned.
np.random.seed(2022)
index = np.random.permutation(len(train_images))
images = np.array(train_images)[index]
anno = np.array(train_anno)[index]

all_test_pics = glob.glob(r'D:\code\dataset\exp3\testing\*.png')
test_images = [p for p in all_test_pics if 'matte' not in p]
test_anno = [p for p in all_test_pics if 'matte' in p]

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
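
# Note: ToTensor() scales pixel values to [0, 1], so the `anno_tensor > 0`
# binarization in the dataset below marks every non-black matte pixel as
# foreground.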

class HK_DataSet(data.Dataset):
    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        anno_path = self.annos_path[index]
        pil_img = Image.open(img_path).convert('RGB')  # ensure 3 channels
        img_tensor = transform(pil_img)

        anno_img = Image.open(anno_path)
        anno_tensor = transform(anno_img)
        # binarize: any non-zero matte pixel is foreground (class 1)
        anno_tensor[anno_tensor > 0] = 1
        # drop the channel dim: [1, 256, 256] -> [256, 256];
        # CrossEntropyLoss expects long-typed class indices
        anno_tensor = torch.squeeze(anno_tensor).type(torch.long)
        return img_tensor, anno_tensor

    def __len__(self):
        return len(self.imgs_path)
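
# Sanity check (optional), in the spirit of the debug snippets above: one
# sample should be a [3, 256, 256] float image and a [256, 256] long mask
# containing only {0, 1}.
# sample_img, sample_anno = HK_DataSet(images, anno)[0]
# print(sample_img.shape, sample_anno.shape, torch.unique(sample_anno))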

train_ds = HK_DataSet(images, anno)
test_ds = HK_DataSet(test_images, test_anno)

train_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = data.DataLoader(test_ds, batch_size=BATCH_SIZE)

# img_batch, anno_batch = next(iter(train_dl))
# print(img_batch.shape)  # torch.Size([2, 3, 256, 256])
# print(anno_batch.shape)  # torch.Size([2, 256, 256])

# img = img_batch[0].permute(1,2,0).numpy()
# anno_temp = anno_batch[0].numpy()
# plt.subplot(1,2,1)
# plt.imshow(img)
# plt.subplot(1,2,2)
# plt.imshow(anno_temp)
# plt.show()

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        # padding=1 keeps the spatial size unchanged through both convolutions
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        if is_pool:
            x = self.pool(x)
        x = self.conv_relu(x)
        return x
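
# Shape check (optional): with pooling the block halves H and W before the
# convolutions; with is_pool=False only the channel count changes.
# down = DownSample(3, 64)
# print(down(torch.randn(1, 3, 64, 64)).shape)                 # [1, 64, 32, 32]
# print(down(torch.randn(1, 3, 64, 64), is_pool=False).shape)  # [1, 64, 64, 64]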

class UpSample(nn.Module):
    def __init__(self, channels):
        super(UpSample, self).__init__()
        # padding=1 keeps the spatial size unchanged through both convolutions
        self.conv_relu = nn.Sequential(
            nn.Conv2d(channels * 2, channels, kernel_size=3, padding=1),
            # inplace=True overwrites the input tensor to save memory
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.upConv = nn.Sequential(
            # stride=2 with padding=1 and output_padding=1 exactly doubles H and W
            nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2,
                               output_padding=1, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv_relu(x)
        x = self.upConv(x)
        return x
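
# Shape check (optional): ConvTranspose2d with kernel_size=3, stride=2,
# padding=1, output_padding=1 maps H to (H - 1) * 2 - 2 + 3 + 1 = 2H, so each
# UpSample block doubles the spatial size while halving the channels.
# up = UpSample(64)
# print(up(torch.randn(1, 128, 32, 32)).shape)  # torch.Size([1, 32, 64, 64])
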
class Unet_model(nn.Module):
    def __init__(self):
        super(Unet_model, self).__init__()
        self.down1 = DownSample(3, 64)
        self.down2 = DownSample(64, 128)
        self.down3 = DownSample(128, 256)
        self.down4 = DownSample(256, 512)
        self.down5 = DownSample(512, 1024)

        self.up = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1,
                               output_padding=1),
            nn.ReLU()
        )

        self.up1 = UpSample(512)
        self.up2 = UpSample(256)
        self.up3 = UpSample(128)

        self.conv_2 = DownSample(128, 64)
        self.last = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x):
        x1 = self.down1(x, is_pool=False)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)

        x5 = self.up(x5)
        # skip connections: reuse the encoder feature maps by concatenating
        # them with the decoder output along the channel dim (N x C x H x W)
        x5 = torch.cat([x4, x5], dim=1)
        x5 = self.up1(x5)

        x5 = torch.cat([x3, x5], dim=1)
        x5 = self.up2(x5)

        x5 = torch.cat([x2, x5], dim=1)
        x5 = self.up3(x5)
        x5 = torch.cat([x1, x5], dim=1)

        x5 = self.conv_2(x5, is_pool=False)

        x5 = self.last(x5)
        return x5
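
# Shape check (optional): the full network maps [N, 3, 256, 256] images to
# [N, 2, 256, 256] per-pixel logits (two classes: background / person).
# with torch.no_grad():
#     print(Unet_model()(torch.randn(1, 3, 256, 256)).shape)
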
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Unet_model().to(device)
# print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# decay the learning rate by 10x every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)


def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0

    model.train()
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()  # correctly classified pixels
            total += y.size(0)                     # images seen so far
            running_loss += loss.item()
    exp_lr_scheduler.step()
    epoch_loss = running_loss / len(trainloader)   # mean loss per batch
    epoch_acc = correct / (total * 256 * 256)      # per-pixel accuracy

    test_correct = 0
    test_total = 0
    test_running_loss = 0

    model.eval()
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()

    epoch_test_loss = test_running_loss / len(testloader)     # mean loss per batch
    epoch_test_acc = test_correct / (test_total * 256 * 256)  # per-pixel accuracy

    print('epoch:', epoch,
          'loss:', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss:', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
          )

    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

epochs = 15
epoch_index = []
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc = 0.0
modelName = ''
for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,
                                                                 model,
                                                                 train_dl,
                                                                 test_dl)
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc  # track the best *test* accuracy
        modelName = 'Model_Unet.pkl'
        # keep an in-memory copy of the best weights and checkpoint them to disk
        best_model_state = copy.deepcopy(model.state_dict())
        torch.save(best_model_state, modelName)
    epoch_index.append(epoch)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
print(model)

plt.title(modelName)
plt.xlabel("epoch")
plt.ylabel("test_acc")
plt.plot(epoch_index, test_acc)
plt.show()
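
# Inference sketch (optional): reload the best checkpoint and visualize one
# test prediction. Assumes 'Model_Unet.pkl' was written by the loop above.
# infer_model = Unet_model().to(device)
# infer_model.load_state_dict(torch.load(modelName, map_location=device))
# infer_model.eval()
# img, anno = test_ds[0]
# with torch.no_grad():
#     pred = infer_model(img.unsqueeze(0).to(device)).argmax(dim=1).squeeze(0).cpu()
# plt.subplot(1, 3, 1)
# plt.imshow(img.permute(1, 2, 0))
# plt.subplot(1, 3, 2)
# plt.imshow(anno)
# plt.subplot(1, 3, 3)
# plt.imshow(pred)
# plt.show()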