import gym
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces
import warnings
import matplotlib.pyplot as plt
import time

# Globally ignore every warning category from this point on.
# NOTE(review): this blanket filter may also hide environment problems that
# check_env() reports as warnings further down — consider narrowing it.
warnings.simplefilter("ignore")

class StockEnvTrainMyself(gym.Env):
  """Rule-gated multi-stock trading environment for Gymnasium / Stable-Baselines3.

  The observation is an ``(n_stocks, 7)`` float32 matrix with one row per entry
  of ``df_list``. Its columns are:

    0 cash balance (the same value is stored on every row)
    1 close price    2 shares held    3 MACD
    4 RSI (``rsi_30``)    5 CCI (``cci_30``)    6 ADX/DX (``dx_30``)

  The action holds one value per stock: 0 = sell, 1 = hold, 2 = buy. Sells and
  buys move 100 shares, and they only execute when the rule checks in ``step``
  also pass.

  Args:
    df_list: One DataFrame per stock with ``close``, ``macd``, ``rsi_30``,
      ``cci_30`` and ``dx_30`` columns; ``df_list[0]`` must also have a
      ``date`` column. Row numbers resolved from ``df_list[0]['date']`` are
      used positionally in every DataFrame (assumes the frames are aligned
      row-for-row — TODO confirm).
    start_day, end_day: Dates accepted by ``pd.to_datetime`` that bound an
      episode. A date missing from the data is snapped to the nearest
      available date inside the window.
    balance_log_path: Text file that receives the running average balance
      after every step (appended to).
    initial_balance: Starting cash.
    initial_shares: Shares of each stock held at the start of an episode.
  """

  # State column -> DataFrame column for the market features that are refreshed
  # from the data each step (columns 0 and 2 are portfolio state, not data).
  _MARKET_COLUMNS = {1: 'close', 3: 'macd', 4: 'rsi_30', 5: 'cci_30', 6: 'dx_30'}

  def __init__(self, df_list, start_day, end_day, *,
               balance_log_path="drive/MyDrive/AverageBalance/original_with_ADX_03_21_10.txt",
               initial_balance=1000000, initial_shares=100):
    if not df_list:
      raise ValueError("df_list must contain at least one DataFrame")
    self.n_stocks = len(df_list)
    self.seed_value = None
    self.reset_df_list = df_list
    self.df_list = df_list
    self.balance_log_path = balance_log_path
    self.initial_balance = initial_balance
    self.initial_shares = initial_shares
    # One sub-action per stock: 0 = sell, 1 = hold, 2 = buy.
    self.action_space = spaces.MultiDiscrete([3] * self.n_stocks)
    # One row per stock, seven features per row (see class docstring).
    self.observation_space = spaces.Box(
        low=-500, high=np.inf, shape=(self.n_stocks, 7), dtype=np.float32)
    # Keep the requested window as given so reset() can re-resolve it;
    # start_day/end_day are overwritten with row numbers by _init_episode().
    self.start_day = start_day
    self.reset_start_day = start_day
    self.end_day = end_day
    self.reset_end_day = end_day
    # Parse the trading calendar without writing back into the caller's DataFrame.
    self.flag_date = pd.to_datetime(self.df_list[0]['date'])
    self._init_episode()

  def _locate_row(self, date, label, *, after):
    """Return the positional row of ``date`` in ``self.flag_date``.

    The first exact match wins. Otherwise the date is snapped to the nearest
    available one inside the window. When ``after`` is true (start of window),
    that is the earliest later date. When ``after`` is false (end of window),
    it is the latest earlier date.

    Raises:
      ValueError: if no suitable date exists in the data.
    """
    target = pd.to_datetime(date)
    exact = np.flatnonzero((self.flag_date == target).to_numpy())
    if exact.size:
      return int(exact[0])
    mask = (self.flag_date > target) if after else (self.flag_date < target)
    candidates = np.flatnonzero(mask.to_numpy())
    if candidates.size == 0:
      side = "after" if after else "before"
      raise ValueError(f"{label} date {date} not found and no date in the data lies {side} it")
    candidate_dates = self.flag_date.iloc[candidates]
    # Series.argmin/argmax return positions within candidate_dates.
    pick = candidate_dates.argmin() if after else candidate_dates.argmax()
    row = int(candidates[pick])
    print(f"{label} date {date} not found; using {self.flag_date.iloc[row]} (row {row})")
    return row

  def _column_at(self, column, row):
    """Return ``column`` at positional ``row`` for every stock as an ``(n, 1)`` array."""
    return np.array([df[column].iloc[row] for df in self.df_list], dtype=float).reshape(-1, 1)

  def _init_episode(self):
    """Resolve the date window to row numbers and rebuild the initial state."""
    self.start_day = self._locate_row(self.reset_start_day, "Start", after=True)
    self.end_day = self._locate_row(self.reset_end_day, "End", after=False)
    n = self.n_stocks
    self.balance = np.full((n, 1), self.initial_balance, dtype=float)
    self.close = self._column_at('close', self.start_day)
    self.shares = np.full((n, 1), self.initial_shares, dtype=float)
    self.macd = self._column_at('macd', self.start_day)
    self.rsi = self._column_at('rsi_30', self.start_day)
    self.cci = self._column_at('cci_30', self.start_day)
    self.adx = self._column_at('dx_30', self.start_day)
    # float32 to match observation_space in both __init__ and reset.
    self.state = np.concatenate(
        [self.balance, self.close, self.shares, self.macd, self.rsi, self.cci, self.adx],
        axis=1).astype(np.float32)
    self.asset = [self.state[0, 0]]
    # Running total of the rewards earned this episode.
    self.reward = 0
    self.terminate = False
    self.truncated = False
    # Reference price for the sell rule; re-initialised every episode so that
    # prices from a previous episode never leak into the next one.
    self.bought_price = [df['close'].iloc[self.start_day] for df in self.df_list]

  def _plot_assets(self):
    """Plot the recorded balance history and show the figure."""
    plt.grid(True)
    plt.title("Stock Assets Variation")
    plt.plot(self.asset)
    plt.show()

  def step(self, action):
    """Advance one trading day and apply ``action`` to every stock.

    Returns:
      ``(observation, reward, terminated, truncated, info)``. ``reward`` is
      the reward earned by this step only; the running episode total is kept in
      ``self.reward``.
    """
    step_reward = 0
    # Advance the cursor. Once the end row is reached, or the window is empty
    # (start after end), the episode terminates instead of reading rows that
    # were never set.
    if self.start_day < self.end_day:
      today = self.start_day
      self.start_day += 1
      next_day = self.start_day
    else:
      self.terminate = True

    last_balance = self.state[0, 0]
    if not self.terminate:
      today_close = [df['close'].iloc[today] for df in self.df_list]
      next_day_close = [df['close'].iloc[next_day] for df in self.df_list]
      # NOTE(review): the sell/buy gates below read next_day prices and ADX,
      # i.e. data from after the decision point (look-ahead) — confirm intended.
      for dimension, act in enumerate(action):
        if self.state[0, 0] <= 0:
          # Out of cash: end the episode with a large penalty and plot the history.
          self.terminate = True
          step_reward -= 1000
          self._plot_assets()
          break

        df = self.df_list[dimension]
        next_adx = df['dx_30'].iloc[next_day]
        if act == 0:
          # Sell one 100-share lot when holding one, the price is at least 20
          # above the last buy price, and ADX is below 20.
          if (self.state[dimension, 2] - 100 >= 0
              and next_day_close[dimension] - self.bought_price[dimension] >= 20
              and next_adx < 20):
            print("Sold")
            self.state[dimension, 2] -= 100
            self.state[:, 0] += 100 * next_day_close[dimension]
            self.asset.append(self.state[0, 0])
        elif act == 2:
          # Buy one 100-share lot when the price rises by at least 10, at most
          # 500 shares are held, and ADX is above 40.
          if (next_day_close[dimension] - today_close[dimension] >= 10
              and self.state[dimension, 2] <= 500
              and next_adx > 40):
            print("Bought")
            self.state[dimension, 2] += 100
            self.bought_price[dimension] = next_day_close[dimension]
            self.state[:, 0] -= 100 * next_day_close[dimension]
            self.asset.append(self.state[0, 0])
        else:
          # Hold: credit the day's price change on the shares held to cash.
          # NOTE(review): a later sell also credits the full sale price, so
          # price moves may be counted twice — confirm the intended accounting.
          self.state[:, 0] += self.state[dimension, 2] * (next_day_close[dimension] - today_close[dimension])
          self.asset.append(self.state[0, 0])

        for state_col, df_col in self._MARKET_COLUMNS.items():
          self.state[dimension, state_col] = df[df_col].iloc[next_day]

    if self.state[0, 0] - last_balance > 0:
      step_reward += 1
    elif self.state[0, 0] < 0:
      step_reward -= 100
    # Gymnasium expects the reward for this step, not the episode total.
    self.reward += step_reward

    average_balance = sum(self.asset) / len(self.asset)
    print("Average Balance: ", average_balance)
    with open(self.balance_log_path, "a") as f:
      f.write(str(average_balance) + "\n")
    return self.state.copy(), step_reward, self.terminate, self.truncated, {}

  def reset(self, seed=None, options=None):
    """Restore the initial portfolio at the start of the date window.

    Args:
      seed: Forwarded to ``gymnasium.Env.reset`` (seeds ``self.np_random``)
        and stored in ``self.seed_value``.
      options: Accepted for Gymnasium API compatibility; unused.

    Returns:
      ``(observation, info)`` with an empty ``info`` dict.
    """
    super().reset(seed=seed)
    self.seed_value = seed
    self.df_list = self.reset_df_list
    self._init_episode()
    return self.state.copy(), {}

  def render(self):
    pass

# Tickers in the same order as the environment's rows (formerly df1..df10).
TICKERS = ["AAPL", "AMGN", "BA", "DIS", "INTC", "KO", "MCD", "MSFT", "TRV", "V"]
DATA_DIR = "drive/MyDrive/HybridModelData"

# Train only when run as a script, so importing this module (e.g. to reuse
# StockEnvTrainMyself) does not read data or start a training run.
if __name__ == "__main__":
  from stable_baselines3 import PPO
  from stable_baselines3.common.env_checker import check_env
  from stable_baselines3.common.logger import configure

  # Each ticker's file lives at <DATA_DIR>/<TICKER>/<ticker>1_original_test_data.csv.
  df_list = [
      pd.read_csv(f"{DATA_DIR}/{ticker}/{ticker.lower()}1_original_test_data.csv")
      for ticker in TICKERS
  ]
  # NOTE(review): 2010-07-10 fell on a Saturday, so it is probably not a row in
  # the CSVs — confirm the intended start date.
  my_stock = StockEnvTrainMyself(df_list, start_day="2010-07-10", end_day="2013-04-10")

  log_path = "drive/MyDrive/ppo_log"
  new_logger = configure(log_path, ["stdout", "csv", "tensorboard"])

  # Validate the environment against the SB3/Gymnasium API before training.
  check_env(my_stock)
  model = PPO("MlpPolicy", my_stock, verbose=1)
  model.set_logger(new_logger)
  model.learn(total_timesteps=60000)
  model.save("drive/MyDrive/StockModel/ppo_10w_2024_03_21_with_ADX")
