import os

import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import transforms
from torch.utils import data
from Unet_Class import Unet_model
BATCH_SIZE = 2  # in this file, only referenced by the commented-out test-set visualisation

# Build the network and load the trained weights. map_location='cpu' lets a
# checkpoint saved from a GPU run load on a CPU-only machine; everything in
# this script runs on CPU tensors anyway (the model is never moved to a GPU).
my_model = Unet_model()
my_model.load_state_dict(torch.load('Model_Unet.pkl', map_location='cpu'))
# Switch to inference mode so layers such as Dropout/BatchNorm (if Unet_model
# contains any) use their evaluation behaviour rather than training behaviour.
my_model.eval()

# Preprocessing for every input image: resize to 256x256, then convert the
# PIL image to a float tensor laid out as (C, H, W) with values in [0, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# base_dir = r'D:\code\dataset\exp1\4weather'
# test_dir = os.path.join(base_dir, 'test')
# test_ds = torchvision.datasets.ImageFolder(test_dir,transform=transform)
# test_dl = torch.utils.data.DataLoader(
#     test_ds,batch_size=BATCH_SIZE,shuffle=True
# )
#
# num = 3
# image_batch,mask_batch = next(iter(test_dl))
# pred_batch = my_model(image_batch)
#
# plt.figure(figsize = (10,10))
# for i in range(num):
#     plt.subplot(3, 3, i * num + 1)
#     plt.imshow(image_batch[i].permute(1,2,0).cpu().numpy())
#     plt.subplot(3, 3, i * num + 2)
#     plt.imshow(mask_batch[i].cpu().numpy())
#     plt.subplot(3, 3, i * num + 3)
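#     # permute(1, 2, 0) -> reorders the tensor from C*H*W to H*W*C, e.g. 2*224*224 -> 224*224*2
#     # argmax(axis=-1) -> collapses the last (class) axis, giving an H*W label map whose elements are in {0, 1}
#     # detach() -> cuts the tensor out of the autograd graph (stops back-propagation through this branch)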
#     pred_img = torch.argmax(pred_batch[i].permute(1, 2, 0),axis = -1).detach().numpy()
#     plt.imshow(pred_img)

# Segment a real-world photo with the trained model and show input vs. prediction.
plt.figure(figsize=(10, 10))
img1 = 'D:/code/dataset/testImg/plmm/p1.jpg'
img2 = 'D:/code/dataset/testImg/plmm/p2.jpg'  # second sample; not used below yet
# Close the file handle promptly, and force 3 channels: grayscale, RGBA or
# palette images would otherwise turn into 1- or 4-channel tensors.
with Image.open(img1) as pil_file:
    pil_img1 = pil_file.convert('RGB')
img1_tensor = transform(pil_img1)  # (C, H, W)
# Add a leading batch dimension: (1, C, H, W).
img_tensor_batch = torch.unsqueeze(img1_tensor, 0)
# Pure inference: no autograd graph is needed.
with torch.no_grad():
    pred_img1 = my_model(img_tensor_batch)
print("--------------")
# pred_img1[0] is presumably (num_classes, H, W) -- TODO confirm against Unet_model.
# argmax over dim 0 yields a per-pixel label map (H, W); this is identical to
# permute(1, 2, 0) followed by argmax over the last axis.
pred_img1_torch = torch.argmax(pred_img1[0], dim=0).cpu().numpy()

plt.subplot(1, 2, 1)
plt.imshow(img1_tensor.permute(1, 2, 0).numpy())  # back to (H, W, C) for display
plt.subplot(1, 2, 2)
plt.imshow(pred_img1_torch)
plt.show()
