# coding: UTF-8
import os
import torch
import numpy as np
import pickle as pkl
from tqdm import tqdm
import time
from datetime import timedelta
from sklearn.utils import shuffle
from torch_geometric.data import Dataset, Data
import pandas as pd

from sklearn.utils import shuffle, resample
from imblearn.over_sampling import SMOTE
from collections import Counter
import random

VOCAB_MAX_SIZE = 10000
UNK, PAD = '<UNK>', '<PAD>'

def sample_mask(idx, l):
    """Return a boolean tensor of size ``l`` that is True at the positions in ``idx``."""
    # Allocate the mask as bool from the start; every position defaults to False.
    selected = torch.zeros(l, dtype=torch.bool)
    selected[idx] = True
    return selected

def generate_dataset(config):
    """Build the text samples, graph data and split masks described by ``config``.

    Reads the pickled vocabulary at ``config.vocab_path``, the tab-separated
    sample file at ``config.data_path`` (``id<TAB>text<TAB>type<TAB>label`` per
    non-empty line) and the edge CSV at ``config.edge_path`` (columns
    ``Vertex1`` and ``Vertex2``). Text is tokenized per character and
    padded/truncated to ``config.pad_size``. Samples are shuffled with a fixed
    seed and split 50/30/20 into train/val/test; the training split is then
    rebalanced per class.

    Returns:
        A 9-tuple ``(vocab, train_data, val_data, test_data, train_mask,
        dev_mask, test_mask, dataGnn, class_weights)``. Every ``*_data`` entry
        is a list of ``(token_ids, label, seq_len, sample_id)`` tuples, the
        masks are boolean tensors with one entry per sample in the file, and
        ``dataGnn`` is the ``torch_geometric`` ``Data`` object.

    Raises:
        ValueError: if a non-empty line does not have exactly four
            tab-separated fields, or if no samples are loaded.
    """
    # NOTE: pickle can execute arbitrary code while loading; only point
    # ``config.vocab_path`` at trusted files.
    with open(config.vocab_path, 'rb') as vocab_file:
        vocab = pkl.load(vocab_file)
    print(f"Vocab size: {len(vocab)}")

    def tokenizer(x):
        """Split a string into a list of single characters."""
        return list(x)

    def encode(chars, pad_size):
        """Map characters to vocab ids, padded/truncated to ``pad_size``.

        Returns ``(ids, length)`` where ``length`` is the number of real
        (non-padding) characters kept.
        """
        length = len(chars)
        if length < pad_size:
            # Pad with the PAD *token* so the lookup below yields the PAD id.
            # Padding with the PAD id itself would make the lookup miss and
            # turn every padded position into UNK.
            chars = chars + [PAD] * (pad_size - length)
        else:
            chars = chars[:pad_size]
            length = pad_size
        unk_id = vocab.get(UNK)
        return [vocab.get(ch, unk_id) for ch in chars], length

    def enhance_train_set(train_idx, y_train):
        """Rebalance the training indices using quartiles of the class counts.

        Classes with fewer samples than the first quartile are oversampled
        (with replacement) up to the median class count; classes with more
        samples than the third quartile are undersampled (without
        replacement) down to the third quartile; all others are kept as is.
        The resulting index list is shuffled with the global ``random`` state.
        """
        class_counts = Counter(y_train)
        sorted_classes = sorted(class_counts.items(), key=lambda item: item[1])

        counts = np.array([count for _, count in sorted_classes])
        q1 = np.percentile(counts, 25)
        q3 = np.percentile(counts, 75)
        median = np.median(counts)

        class_to_idx = {cls: [] for cls, _ in sorted_classes}
        for sample_idx, label in zip(train_idx, y_train):
            class_to_idx[label].append(sample_idx)

        enhanced_idx = []
        for cls, count in sorted_classes:
            cls_indices = class_to_idx[cls]
            if count < q1:
                # Minority class: draw with replacement up to the median size.
                enhanced_idx.extend(resample(cls_indices,
                                             replace=True,
                                             n_samples=int(median),
                                             random_state=42))
            elif count > q3:
                # Majority class: keep a random subset of size int(q3).
                enhanced_idx.extend(resample(cls_indices,
                                             replace=False,
                                             n_samples=int(q3),
                                             random_state=42))
            else:
                enhanced_idx.extend(cls_indices)
        random.shuffle(enhanced_idx)
        return enhanced_idx

    def calculate_class_weights(train_texts):
        """Return inverse-frequency class weights normalized to sum to 1.

        The tensor has one entry per label in ``0..max(label)``. Labels in
        that range that never occur in ``train_texts`` get weight 0 instead
        of triggering a division by zero.
        """
        class_counts = Counter(sample[1] for sample in train_texts)
        total = sum(class_counts.values())
        num_classes = max(class_counts) + 1
        weights = [
            total / class_counts[cls] if class_counts[cls] else 0.0
            for cls in range(num_classes)
        ]
        weights = torch.tensor(weights, dtype=torch.float)
        return weights / weights.sum()

    def load_dataset(path, pad_size, edgepath):
        """Parse the sample file and edge CSV and build splits, masks and graph data."""
        texts = []
        labels = []
        type_sequences = []
        with open(path, 'r', encoding='UTF-8') as f:
            for raw_line in tqdm(f):
                line = raw_line.strip()
                if not line:
                    continue
                sample_id, text, type_text, label = line.split('\t')
                sequence, seq_len = encode(tokenizer(text), pad_size)
                texts.append((sequence, int(label), seq_len, int(sample_id)))
                labels.append(int(label))
                type_ids, _ = encode(tokenizer(type_text), pad_size)
                type_sequences.append(type_ids)

        # NOTE(review): the encoded type sequences collected above are not
        # used; the node features ``x`` are built from the labels, exactly as
        # before. This looks like a typo for ``torch.tensor(type_sequences)``
        # and leaks the labels into the graph features. It is left unchanged
        # because fixing it changes the shape of ``x`` that downstream code
        # receives -- confirm with the model code before changing.
        node_features = torch.tensor(labels)
        y = torch.tensor(labels)

        # Build an undirected, de-duplicated edge index of shape (2, E).
        edge_frame = pd.read_csv(edgepath)[['Vertex1', 'Vertex2']]
        edge_index = edge_frame.to_numpy().transpose()
        # Append every edge reversed so both directions are present.
        edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)
        # np.unique also sorts the columns, so no explicit sort is needed.
        edge_index = torch.from_numpy(np.unique(edge_index, axis=1))

        sample_number = len(texts)
        seed = 10
        shuffled_idx = shuffle(np.arange(sample_number), random_state=seed)
        train_end = int(0.5 * sample_number)
        val_end = int(0.8 * sample_number)

        train_idx = shuffled_idx[:train_end].tolist()
        y_train = [texts[i][1] for i in train_idx]
        # After rebalancing, train_idx may contain repeated indices.
        train_idx = enhance_train_set(train_idx, y_train)

        val_idx = shuffled_idx[train_end:val_end].tolist()
        test_idx = shuffled_idx[val_end:].tolist()

        train_mask = sample_mask(train_idx, sample_number)
        val_mask = sample_mask(val_idx, sample_number)
        test_mask = sample_mask(test_idx, sample_number)
        dataGnn = Data(x=node_features, edge_index=edge_index, y=y,
                       train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

        train_texts = [texts[i] for i in train_idx]
        class_weights = calculate_class_weights(train_texts)
        val_texts = [texts[i] for i in val_idx]
        test_texts = [texts[i] for i in test_idx]

        return (train_texts, val_texts, test_texts,
                train_mask, val_mask, test_mask, dataGnn, class_weights)

    (train_data, val_data, test_data,
     train_mask, dev_mask, test_mask,
     dataGnn, class_weights) = load_dataset(config.data_path, config.pad_size, config.edge_path)

    return (vocab, train_data, val_data, test_data,
            train_mask, dev_mask, test_mask, dataGnn, class_weights)

class DatasetIter(object):
    """Iterate over a list of samples in batches of tensors.

    Each sample is indexed positionally: element 0 becomes ``x``, element 1
    becomes ``y``, element 2 becomes ``seq_len`` and element 3 (minus one)
    becomes ``id``. ``mask`` is sliced with the same positional window as
    ``batches``.
    NOTE(review): this assumes ``mask`` is aligned element for element with
    ``batches`` -- confirm against the callers.

    The iterator resets itself when exhausted, so the same object can be
    iterated again (e.g. once per epoch).
    """

    def __init__(self, batches, mask, batch_size, device):
        self.mask = mask
        self.batch_size = batch_size
        self.batches = batches
        # Number of full batches; a trailing partial batch is tracked by `residue`.
        self.num_batches = len(batches) // batch_size
        self.residue = len(batches) % self.batch_size != 0
        self.idx = 0
        self.device = device

    def _to_tensor(self, datas, mask):
        """Convert a slice of samples and its mask slice to tensors on ``self.device``.

        Returns ``((x, seq_len, id), y, mask)``.
        """
        x = torch.LongTensor([sample[0] for sample in datas]).to(self.device)
        y = torch.LongTensor([sample[1] for sample in datas]).to(self.device)
        mask = torch.BoolTensor(list(mask)).to(self.device)
        seq_len = torch.LongTensor([sample[2] for sample in datas]).to(self.device)
        # presumably converts 1-based sample ids to 0-based indices -- confirm
        sample_ids = torch.LongTensor([sample[3] - 1 for sample in datas]).to(self.device)
        return (x, seq_len, sample_ids), y, mask

    def __next__(self):
        # Stop after exactly len(self) batches. Comparing with `>` against
        # num_batches used to emit an extra empty batch whenever the sample
        # count was an exact multiple of batch_size (or zero).
        if self.idx >= len(self):
            self.idx = 0
            raise StopIteration
        start = self.idx * self.batch_size
        # Clamp to the sample count so the last (partial) batch and its mask
        # slice end at the same position.
        stop = min(start + self.batch_size, len(self.batches))
        self.idx += 1
        return self._to_tensor(self.batches[start:stop], self.mask[start:stop])

    def __iter__(self):
        return self

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        if self.residue:
            return self.num_batches + 1
        return self.num_batches


def build_iterator(dataset, mask, config):
    """Wrap ``dataset`` and ``mask`` in a ``DatasetIter`` using ``config.batch_size`` and ``config.device``."""
    return DatasetIter(
        batches=dataset,
        mask=mask,
        batch_size=config.batch_size,
        device=config.device,
    )


def get_time_dif(start_time):
    """Return the time elapsed since ``start_time`` as a ``timedelta`` rounded to whole seconds."""
    elapsed_seconds = time.time() - start_time
    whole_seconds = int(round(elapsed_seconds))
    return timedelta(seconds=whole_seconds)


