# version 11 denotes automatic Keras augmentation, no validation
# version 12 denotes automatic Keras augmentation, with validation

# version 13 denotes automatic Keras augmentation, with validation, plus Flatten and a fully connected layer
# version 2 denotes my own Keras augmentation, with validation
# version 3 denotes my own augmentation using rotation, with validation

from keras.applications.inception_v3 import InceptionV3
from keras.applications import InceptionResNetV2
from keras.applications import ResNet50

from keras.preprocessing import image
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.layers import Flatten
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Input
import warnings
from keras.optimizers import SGD
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import numpy as np
import os

from keras_applications.resnet_common import ResNet152
from numpy.random import seed
#seed(1)
from tensorflow import set_random_seed
#set_random_seed(2)

warnings.filterwarnings('ignore')

##########################################################
# Experiment configuration: data locations, backbone choice, run index.

trainPath = '/home/amax/mydata/NRtexture3/data/train1000'
validpath = '/home/amax/mydata/NRtexture3/data/valid'
resultfolder = '/home/amax/mydata/NRtexture3/results_v13'

# Alternative data set used previously:
#trainPath = '/home/amax/mydata/NRtexture/train-moredots0'
#resultfolder='/home/amax/mydata/NRtexture/results-moredots0'

evalPath = '/home/amax/mydata/NRtexture3/data/test170'
# Images drawn in the result figures are read from visualPath; by default
# these are the very images that are evaluated.
visualPath = evalPath
#visualPath = '/home/amax/mydata/NRtexture/test/test2_visual'

# Backbone selector:
#   0 -> InceptionV3
#   1 -> ResNet50
#   2 -> ResNet152
#   3 -> InceptionResNetV2
model_num = 3

##########################################################
# Index of the current run; it prefixes the per-run result folders/files.
runTimes = 8

##########################################################
# model_num -> (input image side length in pixels, result sub-folder name)
_MODEL_SETTINGS = {
    0: (299, 'InceptionV3'),
    1: (224, 'Res50'),
    2: (224, 'Res152'),
    3: (299, 'InceptionResV2'),
}

if model_num not in _MODEL_SETTINGS:
    # Fail immediately: continuing with igSize == 0 would only produce a
    # confusing error later, when the network is built.
    raise ValueError('undefine model: %r' % (model_num,))

igSize, _model_subfolder = _MODEL_SETTINGS[model_num]
resultfolder = resultfolder + '/' + _model_subfolder


# Make sure the four per-run output folders exist. The evaluation code
# further down saves each figure into "<runTimes>_<true class><predicted
# class>", i.e. _00 and _11 hold correct predictions, _01 and _10 hold errors.
# An already existing folder is kept as it is (it is not emptied).
for _folder_suffix in ('_00', '_01', '_10', '_11'):
    filename = resultfolder + '/' + str(runTimes) + _folder_suffix
    if os.path.exists(filename):
        print(filename + ' existed, no clear before!')
    else:
        os.makedirs(filename)


# Pre-processing options shared by the training, validation and evaluation
# generators: pixel values are rescaled by 1/255 and the feature-wise
# centering / std-normalisation switches are turned on.
# NOTE(review): Keras' feature-wise options normally need the generator to be
# fitted on sample data (ImageDataGenerator.fit) before they take effect; no
# such call is visible in this script -- confirm whether that is intended.
_shared_preprocessing = dict(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rescale=1./255,
)

# Only the training generator augments the data (random horizontal and
# vertical flips). Options that are currently switched off:
# rotation_range=90, width_shift_range=0.2, height_shift_range=0.2.
trainGen = ImageDataGenerator(horizontal_flip=True,
                              vertical_flip=True,
                              **_shared_preprocessing)

# Validation images are not augmented (validation_split=0.1 is switched off).
validGen = ImageDataGenerator(**_shared_preprocessing)

# Evaluation images are not augmented either.
evalGen = ImageDataGenerator(**_shared_preprocessing)



input_tensor = Input(shape=(igSize, igSize, 3))


def _build_base_model(model_id, inputs, side):
    """Return the ImageNet-pretrained backbone selected by ``model_id``.

    model_id -- 0: InceptionV3, 1: ResNet50, 2: ResNet152,
                3: InceptionResNetV2
    inputs   -- Keras input tensor the backbone is attached to
    side     -- side length in pixels of the square RGB input

    The classification top is left out (include_top=False).
    Raises ValueError for any other ``model_id`` so that the script stops
    here instead of failing later on an undefined ``base_model``.
    """
    common = dict(input_tensor=inputs,
                  input_shape=(side, side, 3),
                  weights='imagenet',
                  include_top=False)
    if model_id == 0:
        return InceptionV3(**common)
    if model_id == 1:
        return ResNet50(**common)
    if model_id == 2:
        # ResNet152 is imported straight from keras_applications rather than
        # through keras.applications, so the Keras submodules it builds the
        # network with are handed over explicitly.
        from keras import layers as _keras_layers
        from keras import models as _keras_models
        from keras import utils as _keras_utils
        return ResNet152(backend=K,
                         layers=_keras_layers,
                         models=_keras_models,
                         utils=_keras_utils,
                         **common)
    if model_id == 3:
        return InceptionResNetV2(**common)
    raise ValueError('undefine model: %r' % (model_id,))


base_model = _build_base_model(model_num, input_tensor, igSize)

#####
#testmodel=InceptionResNetV2()
#testmodel.summary()
####



#x = base_model.output
#x = GlobalAveragePooling2D()(x)
#y=base_model.layers[-1].output

# Classification head placed on top of the backbone: the backbone output is
# flattened, passed through a 1024-unit ReLU layer and finally through a
# 2-way softmax.
_flat_features = Flatten()(base_model.output)
_hidden = Dense(1024, activation='relu')(_flat_features)
predictions = Dense(2, activation='softmax')(_hidden)

model = Model(inputs=base_model.input, outputs=predictions)

# NOTE(review): in Keras' SGD, `decay` is a time-based decay applied at every
# update step, not an exponential rate; 0.94 shrinks the learning rate very
# quickly -- confirm that this value is intended.
_optimizer = SGD(lr=0.01, momentum=0.9, decay=0.94)
model.compile(loss='categorical_crossentropy',
              optimizer=_optimizer,
              metrics=['accuracy'])

#model.summary()

def _batches_per_pass(iterator):
    """Return how many batches are needed to go through ``iterator`` once.

    Uses integer ceiling division so that the result is always a whole
    number (plain ``/`` yields a float on Python 3 and a floored value,
    possibly 0, on Python 2) and a last, partially filled batch is counted.
    """
    return -(-iterator.samples // iterator.batch_size)


train_generator = trainGen.flow_from_directory(trainPath,
                                               target_size=(igSize, igSize),
                                               batch_size=32,
                                               shuffle=True)
valid_generator = validGen.flow_from_directory(validpath,
                                               target_size=(igSize, igSize),
                                               batch_size=139)

validation_steps = _batches_per_pass(valid_generator)
print(validation_steps)

history = model.fit(train_generator,
                    steps_per_epoch=_batches_per_pass(train_generator),
                    validation_data=valid_generator,
                    validation_steps=validation_steps,
                    shuffle=True,
                    verbose=2,
                    epochs=500)  # epoch number

# shuffle=False keeps the predictions computed later in the same order as
# test_generator.filenames; batch_size=1 gives one prediction step per image.
test_generator = evalGen.flow_from_directory(evalPath,
                                             target_size=(igSize, igSize),
                                             shuffle=False,
                                             batch_size=1)


# list all data in history
print(history.history.keys())

# summarize history for accuracy
# Depending on the Keras version, the validation accuracy is recorded as
# 'val_acc' or as 'val_accuracy'; use whichever key is present so that the
# script does not die with a KeyError once training has finished.
if 'val_acc' in history.history:
    _val_acc_key = 'val_acc'
else:
    _val_acc_key = 'val_accuracy'
plt.plot(history.history[_val_acc_key])
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.show()

# summarize history for loss (training loss only)
plt.plot(history.history['loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train'], loc='upper left')
plt.show()

# Predict every evaluation image, save one annotated figure per image into the
# result folder "<runTimes>_<true class><predicted class>" and count the
# misclassifications of each class.
filenames = test_generator.filenames
nb_samples = len(filenames)
pred = model.predict_generator(test_generator, steps=nb_samples)
predicted_class_indices = np.argmax(pred, axis=1)
#print(predicted_class_indices)

# True labels come from the generator itself, so they always line up with
# `filenames` (the generator was created with shuffle=False). Counting the
# directory entries with os.listdir would go wrong as soon as a class folder
# contains a file that the generator skips.
true_class_indices = test_generator.classes

num_error_0 = 0
num_error_1 = 0
plt.ion()  # interactive mode: plt.show() does not block the loop
for i, relpath in enumerate(filenames):
    true_class = int(true_class_indices[i])
    predicted_class = int(predicted_class_indices[i])
    # relpath looks like "<class folder>/<image name>"
    igname = os.path.basename(relpath)
    print(i)
    print(relpath)
    print(igname)

    img = mpimg.imread(os.path.join(visualPath, relpath))
    plt.figure(i)
    plt.imshow(img)
    # Annotate the image with the probability assigned to its true class.
    plt.text(20, 20, str(round(pred[i, true_class], 2)), color='r', fontsize=20)
    plt.show()
    plt.axis('off')

    outfolder = (resultfolder + '/' + str(runTimes) + '_'
                 + str(true_class) + str(predicted_class))
    plt.savefig(outfolder + '/' + igname)
    if predicted_class != true_class:
        if true_class == 0:
            num_error_0 = num_error_0 + 1
        else:
            num_error_1 = num_error_1 + 1
    plt.close()

# Round the accuracy itself; rounding the error rate first and subtracting it
# from 1 yields values such as 0.9299999999999999.
if nb_samples:
    accuracy = round(1.0 - float(num_error_0 + num_error_1) / nb_samples, 2)
else:
    accuracy = 0.0

# Append a one-line summary of this run to the shared mark file. Append mode
# creates the file when it does not exist yet ('r+' raised an error in that
# case, after training and before the model was saved), and the `with` block
# closes the file even if writing fails.
markfile = resultfolder + '/mark.txt'
with open(markfile, mode='a') as fileh:
    fileh.write('Model_' + str(model_num) + '_Runtime_' + str(runTimes)
                + '      ' + str(num_error_0) + '  ' + str(num_error_1)
                + '   ' + str(accuracy))
    fileh.write('\n')

# save the raw prediction scores of this run to a text file
datafile = resultfolder + '/' + 'submark_' + str(runTimes)
np.savetxt(datafile, (pred))

#model.save('v3_1000_NR3_train1000test170epoch500times8.h5')
model.save('v2_1000_NR3_train1000test170epoch500times8.h5')

#input('Press any key......')


