from AlgorithmImports import *
#from AlgorithmImports.Indicators import *
import joblib

class RandomForestAlgorithm(QCAlgorithm):
    """Trade SPY long/flat from a pre-trained random-forest classifier.

    The model is loaded from the ObjectStore with joblib and queried on every
    slice that carries SPY, VIX and US Treasury yield-curve data together.
    A predicted label of ``1`` maps to "Up" (100% SPY); any other outcome
    sets the SPY holding to 0.
    """

    def Initialize(self):
        """Set dates/cash, load the model, and subscribe to data and indicators."""
        self.SetStartDate(2000, 1, 1)  # Set the start date
        self.SetEndDate(2000, 1, 31)
        self.SetCash(100000)  # Set the initial cash balance

        # Load the machine learning model. Default to None so a failed load
        # is detectable later instead of surfacing as an AttributeError.
        self.model = None
        model_key = 'spy_randomforest_predictor'
        if self.ObjectStore.ContainsKey(model_key):
            file_name = self.ObjectStore.GetFilePath(model_key)
            self.model = joblib.load(file_name)
        else:
            # Handle the case when the model is not found. Don't rely on
            # Quit() to abort execution immediately; OnData and GetPrediction
            # also guard on self.model being None.
            self.Error(f"Model '{model_key}' not found.")
            self.Quit()

        # NOTE(review): SPY is subscribed at the default resolution while VIX
        # and USTYCR are daily. OnData only acts when all three arrive in the
        # same slice. Confirm that this actually happens, or subscribe SPY at
        # Resolution.Daily.
        self.spy = self.AddEquity("SPY").Symbol  # Add SPY data

        # Add the VIX data using AddData
        self.vix = self.AddData(CBOE, "VIX", Resolution.Daily).Symbol

        # Add USTYCR data using AddData
        self.syield10yr = self.AddData(USTreasuryYieldCurveRate, "USTYCR", Resolution.Daily).Symbol

        # Rolling windows of the 30 most recent values for each input series
        self.vix_window = RollingWindow[float](30)
        self.spy_window = RollingWindow[float](30)
        self.syield10yr_window = RollingWindow[float](30)

        # Technical indicators computed on SPY
        self.mfi = MoneyFlowIndex(14)
        self.macd = MovingAverageConvergenceDivergence(12, 26, 9, MovingAverageType.Wilders)
        self.williams = WilliamsPercentR(14)
        self.stochastic = Stochastic(14, 3, 3)

        # Feed the indicators from SPY data at daily resolution
        self.RegisterIndicator(self.spy, self.mfi, Resolution.Daily)
        self.RegisterIndicator(self.spy, self.macd, Resolution.Daily)
        self.RegisterIndicator(self.spy, self.williams, Resolution.Daily)
        self.RegisterIndicator(self.spy, self.stochastic, Resolution.Daily)

        # Warm-up period. Some indicators (e.g. MACD 12/26/9) may need more
        # bars than this to become ready, so OnData also checks IsReady.
        self.SetWarmUp(30)

    def OnData(self, slice):
        """Update the rolling windows and rebalance SPY from the model's prediction.

        Only acts on slices containing SPY, VIX and USTYCR data together.
        Trades only after warm-up and once every indicator reports ready.
        """
        if self.model is None:
            # No model was loaded in Initialize; nothing meaningful to do.
            return

        # Index the slice by the Symbol objects returned from AddEquity/AddData
        # rather than hand-written ticker strings.
        has_all_data = (
            slice.ContainsKey(self.vix)
            and slice.ContainsKey(self.spy)
            and slice.ContainsKey(self.syield10yr)
        )
        if not has_all_data:
            return

        vix_data = slice[self.vix]
        spy_data = slice[self.spy]
        syield10yr_data = slice[self.syield10yr]

        # Update rolling windows with new data
        self.vix_window.Add(vix_data.Close)
        self.spy_window.Add(spy_data.Close)
        self.syield10yr_window.Add(syield10yr_data.TenYear)

        if self.IsWarmingUp:
            return

        # Indicator values are not meaningful until IsReady. Feeding them to
        # the model early would produce unreliable predictions.
        indicators = (self.mfi, self.macd, self.williams, self.stochastic)
        if not all(indicator.IsReady for indicator in indicators):
            return

        df = self.CalculateFeatures()
        prediction = self.GetPrediction(df)
        self.SetHoldings(self.spy, 1 if prediction == "Up" else 0)

    def CalculateFeatures(self):
        """Build a one-row DataFrame of model features from the latest data.

        Returns:
            A single-row pandas DataFrame with the columns 'vix_lagged',
            'return_shift', 'change_yield_rate_shift', 'mfi', 'macd',
            'williams' and 'stochastic'. These presumably must match the
            column names and order used when the model was trained. Confirm
            this against the training code.
        """
        # Placeholders: these are the latest raw values (VIX close, SPY close,
        # 10-year yield), not true lags, returns or changes. Replace them with
        # the transformations used at training time.
        features = {
            'vix_lagged': self.vix_window[0],
            'return_shift': self.spy_window[0],
            'change_yield_rate_shift': self.syield10yr_window[0],
            'mfi': self.mfi.Current.Value,
            'macd': self.macd.Current.Value,
            'williams': self.williams.Current.Value,
            'stochastic': self.stochastic.Current.Value,
        }

        # Wrap the dict in a list. pd.DataFrame({...}) with only scalar values
        # raises ValueError ("If using all scalar values, you must pass an index").
        df = pd.DataFrame([features])
        self.Debug(str(df))
        return df

    def GetPrediction(self, df):
        """Classify the feature row(s) in ``df`` with the loaded model.

        Args:
            df: DataFrame of features as produced by CalculateFeatures.

        Returns:
            "Up" if the model's label for the first row equals 1, "Down"
            otherwise, or "NotEnoughData" when ``df`` is empty or no model
            is loaded.
        """
        if df.empty or self.model is None:
            return "NotEnoughData"

        # Assumes a scikit-learn style estimator whose predict() returns one
        # label per row. Compare the scalar label, not the whole array.
        prediction = self.model.predict(df)
        return "Up" if prediction[0] == 1 else "Down"