import csv
import statistics
import math
import sys


class Analysis:
    """Turn recorded scroll gestures into per-user ARFF data for Weka.

    Pipeline: ``read_csvs`` -> ``generate_subscrolls`` -> ``generate_stats``
    -> ``print_file``.

    - Each scroll is split into subscrolls whose y coordinate moves in one
      direction only.
    - Features are computed for every subscroll with enough points.
    - One ARFF data line is written per user who has at least
      ``num_scrolls_limit`` such subscrolls.
    """

    def __init__(self, num_scrolls_limit=None):
        """Set thresholds, input file names and empty containers.

        :param num_scrolls_limit: number of subscrolls written per user in
            the output file. Defaults to ``sys.argv[1]`` so the script keeps
            working from the command line.
        """
        # scroll must have MORE than this many points to be considered
        self.scroll_point_threshold = 2

        # scroll must be significantly tall to be considered
        # NOTE(review): not referenced anywhere in this class yet.
        self.y_diff_threshold = 10

        # number of scrolls for output file
        if num_scrolls_limit is None:
            num_scrolls_limit = sys.argv[1]
        self.num_scrolls_limit = int(num_scrolls_limit)

        # csv
        self.delim = ','
        self.csvUsers = 'allUsers.csv'
        self.csvScrolls = 'allScrolls.csv'
        self.csvPoints = 'allPoints.csv'
        self.csvClicks = 'allClicks.csv'

        self.users = []
        self.scrolls = []
        self.points = []
        self.clicks = []

        self.subscrolls = []
        self.stats = []

        # Highest scroll id seen so far; generated subscroll ids continue
        # from here so they never collide with ids read from the CSV.
        self.last_scroll_id = -1

    def create_user(self, uid, gender, hand, age, usage, lang,
                    device, screen_width):
        """ This function creates a user dict and adds it to
        the users list"""
        self.users.append({'id': int(uid),
                           'gender': gender, 'hand': hand,
                           'age': int(age), 'usage': int(usage),
                           'lang': lang, 'device': device,
                           'screen_width': int(screen_width)})

    def create_scroll(self, sid, uid, time):
        """ This function creates a scroll dict and adds it
            to the scrolls list """
        self.scrolls.append({'id': int(sid), 'uid': int(uid),
                             'time': float(time)})

    def create_point(self, pid, scroll_id, x, y):
        """ This function creates a point dict and adds it
            to the points list """
        self.points.append({'id': int(pid), 'scroll_id': int(scroll_id),
                            'x': float(x), 'y': float(y)})

    def read_csvs(self):
        """Read the users, scrolls and points CSVs into this instance.

        The following rows are skipped:

        - user rows with an empty id/hand/gender/age/usage/screen_width;
        - user rows whose hand is ``bothThumbs``;
        - scroll rows with an empty id or uid;
        - point rows with an empty id, scroll id, x or y.

        ``last_scroll_id`` is raised to the largest scroll id read.
        """
        required_user_fields = ('id', 'hand', 'gender', 'age', 'usage',
                                'screen_width')

        # newline='' is what the csv module expects for file objects.
        with open(self.csvUsers, newline='') as csv_read_users:
            reader = csv.DictReader(csv_read_users, delimiter=self.delim)
            for row in reader:
                if row['hand'] == "bothThumbs" or \
                        any(row[field] == "" for field in required_user_fields):
                    continue
                self.create_user(uid=row['id'],
                                 gender=row['gender'], hand=row['hand'],
                                 age=row['age'], usage=row['usage'],
                                 lang=row['lang'], device=row['device'],
                                 screen_width=row['screen_width'])

        with open(self.csvScrolls, newline='') as csv_read_scrolls:
            reader = csv.DictReader(csv_read_scrolls, delimiter=self.delim)
            for row in reader:
                if row['uid'] == "" or row['id'] == "":
                    continue
                sid = int(row['id'])
                self.create_scroll(sid=sid, uid=row['uid'], time=row['time'])
                self.last_scroll_id = max(self.last_scroll_id, sid)

        with open(self.csvPoints, newline='') as csv_read_points:
            reader = csv.DictReader(csv_read_points, delimiter=self.delim)
            for row in reader:
                # NOTE(review): 'srcoll_id' (sic) must match the header of
                # allPoints.csv; presumably misspelled at the source, so
                # confirm before renaming.
                if row['id'] == "" or row['srcoll_id'] == "" or \
                        row['x'] == "" or row['y'] == "":
                    continue
                self.create_point(pid=row['id'],
                                  scroll_id=row['srcoll_id'],
                                  x=row['x'], y=row['y'])

    def generate_subscrolls(self):
        """Break each user's scrolls into upward-only or downward-only subscrolls.

        Only scrolls with more than ``scroll_point_threshold`` points are
        split. Every point of a split scroll is placed into exactly one
        subscroll, and its ``scroll_id`` is rewritten to that subscroll's id.
        """
        # Group once instead of rescanning every scroll/point per user.
        scrolls_by_user = {}
        for scroll in self.scrolls:
            scrolls_by_user.setdefault(scroll['uid'], []).append(scroll)
        points_by_scroll = {}
        for point in self.points:
            points_by_scroll.setdefault(point['scroll_id'], []).append(point)

        for user in self.users:
            for scroll in scrolls_by_user.get(user['id'], []):
                # pop(): a recorded point is never assigned to two subscrolls.
                scroll_points = points_by_scroll.pop(scroll['id'], [])
                if len(scroll_points) <= self.scroll_point_threshold:
                    continue
                scroll_points.sort(key=lambda point: point['id'])
                self._split_scroll(scroll, scroll_points)

    def _split_scroll(self, scroll, scroll_points):
        """Append the monotone subscrolls of one id-sorted scroll to self.subscrolls."""
        direction = 0  # sign of the last non-zero y step; 0 = not known yet
        prev_y = scroll_points[0]['y']
        for index, point in enumerate(scroll_points):
            step = point['y'] - prev_y
            step_dir = (step > 0) - (step < 0)  # 1, -1 or 0 (no y change)

            # Every scroll starts a fresh subscroll, and so does every reversal
            # of vertical direction; steps without y change never split.
            if index == 0 or step_dir * direction < 0:
                self.last_scroll_id += 1
                self.subscrolls.append({'sid': self.last_scroll_id,
                                        'uid': scroll['uid'], 'time': 0,
                                        'points': []})
            if step_dir:
                direction = step_dir

            point['scroll_id'] = self.last_scroll_id
            self.subscrolls[-1]['points'].append(point)
            prev_y = point['y']

    def generate_stats(self):
        """Compute per-subscroll features and attach per-user values.

        Per subscroll the features are: mean x, median x, max x, min x,
        start x, slope, and y displacement. Per user, each stat also gets
        the median of that user's subscroll median-x values and the hand.
        Subscrolls with ``scroll_point_threshold`` points or fewer are skipped.
        """
        for subscroll in self.subscrolls:
            points = subscroll['points']
            if len(points) <= self.scroll_point_threshold:
                continue
            x = [point['x'] for point in points]
            y = [point['y'] for point in points]
            x_displacement = x[-1] - x[0]
            y_displacement = y[-1] - y[0]
            # Slope of the line through the first and last point; a purely
            # vertical subscroll gets sys.maxsize as a sentinel.
            slope = y_displacement / x_displacement \
                if x_displacement != 0 else sys.maxsize

            self.stats.append({
                'scrolls': [{
                    'sid': subscroll['sid'],
                    'mean_x': statistics.mean(x),
                    'median_x': statistics.median(x),
                    'min_x': min(x),
                    'max_x': max(x),
                    'start_x': x[0],
                    'slope': slope,
                    'y_displacement': y_displacement,
                }],
                'median_scrolls': 0,
                'uid': subscroll['uid'],
                'hand': '',
            })

        # add median of scrolls and hand
        stats_by_uid = {}
        for stat in self.stats:
            stats_by_uid.setdefault(stat['uid'], []).append(stat)
        for user in self.users:
            user_stats = stats_by_uid.get(user['id'], [])
            if not user_stats:
                continue
            median_scrolls = statistics.median(
                [stat['scrolls'][0]['median_x'] for stat in user_stats])
            for stat in user_stats:
                stat['hand'] = user['hand']
                stat['median_scrolls'] = median_scrolls

    @staticmethod
    def _normalize_hand(hand):
        """Map 'leftIndex'/'rightIndex' to 'left'/'right'; other values pass through."""
        return {'leftIndex': 'left', 'rightIndex': 'right'}.get(hand, hand)

    def print_file(self, filename):
        """Write the stats to ``filename + '.arff'`` for Weka.

        The header declares six attributes for each of ``num_scrolls_limit``
        subscrolls, then the overall median x, the user id and the hand.
        Only users with at least ``num_scrolls_limit`` stats get a data line.
        Weka will not recognise a line with missing columns.
        """
        stats_by_uid = {}
        for stat in self.stats:
            stats_by_uid.setdefault(stat['uid'], []).append(stat)

        # `with` closes the file even if a write raises.
        with open(filename + '.arff', 'w') as f:
            f.write('@RELATION data{0}'.format(self.num_scrolls_limit))

            # We'll write information only for the number of scrolls required
            for num_scroll in range(self.num_scrolls_limit):
                f.write('''
    @ATTRIBUTE s{0}_mean_x REAL                
    @ATTRIBUTE s{0}_max_x REAL                
    @ATTRIBUTE s{0}_min_x REAL                
    @ATTRIBUTE s{0}_start_x REAL                
    @ATTRIBUTE s{0}_slope REAL                
    @ATTRIBUTE s{0}_y_displ REAL
    '''.format(num_scroll + 1))

            f.write('''@ATTRIBUTE all_median_x REAL
    @ATTRIBUTE user_id REAL
    @ATTRIBUTE hand {right, left}
    @DATA\n
    ''')

            for user in self.users:
                user_stats = stats_by_uid.get(user['id'], [])
                if not len(user_stats) >= self.num_scrolls_limit > 0:
                    continue

                for stat in user_stats[:self.num_scrolls_limit]:
                    features = stat['scrolls'][0]
                    f.write('{0},{1},{2},{3},{4},{5},'.format(
                        features['mean_x'],
                        features['max_x'],
                        features['min_x'],
                        features['start_x'],
                        features['slope'],
                        features['y_displacement'],
                    ))

                # NOTE(review): hands other than left/right/leftIndex/
                # rightIndex are written unchanged and would not match the
                # {right, left} nominal attribute; confirm none occur.
                f.write('{0},{1},{2}\n'.format(
                    user_stats[0]['median_scrolls'],
                    user_stats[0]['uid'],
                    self._normalize_hand(user_stats[0]['hand']),
                ))


################ PROGRAM ###################
# Command-line entry point:
#   python <script> NUM_SCROLLS OUTPUT_BASENAME
# Reads the CSV inputs from the working directory and writes
# OUTPUT_BASENAME.arff. Guarded so importing this module has no side effects.
if __name__ == "__main__":
    if len(sys.argv) < 3:
        sys.exit("usage: {0} NUM_SCROLLS OUTPUT_BASENAME".format(sys.argv[0]))

    analysis = Analysis()
    analysis.read_csvs()
    analysis.generate_subscrolls()
    analysis.generate_stats()

    analysis.print_file(sys.argv[2])
