# Overall Statistics
from Alphas.RsiAlphaModel import RsiAlphaModel
from Execution.ImmediateExecutionModel import ImmediateExecutionModel
from Portfolio.EqualWeightingPortfolioConstructionModel import EqualWeightingPortfolioConstructionModel
from Risk.MaximumDrawdownPercentPerSecurity import MaximumDrawdownPercentPerSecurity
from Selection.QC500UniverseSelectionModel import QC500UniverseSelectionModel
from System.Drawing import Color
from enum import Enum

class MyAlgorithm(QCAlgorithm):
    '''Holds SPY, wires up an RSI alpha model through the Algorithm Framework,
    and plots SMA, Bollinger Band and RSI indicators plus a portfolio/trigger chart.'''

    def Initialize(self):
        '''Configures dates, cash, the data subscription, framework models and charts.'''
        self.symbol = "SPY"
        self.res2use = Resolution.Daily

        self.SetStartDate(2018, 6, 17)  # Set Start Date
        self.SetCash(100000)  # Set Strategy Cash

        # request the daily equity data
        self.AddEquity(self.symbol, self.res2use)
        self.AddPlots(self.symbol, self.res2use)

        # Algorithm Framework modules (plug-and-play alpha, execution,
        # portfolio construction, risk management and universe selection)
        self.AddAlpha(RsiAlphaModel(14, self.res2use))
        self.SetExecution(ImmediateExecutionModel())
        self.SetPortfolioConstruction(EqualWeightingPortfolioConstructionModel())
        self.SetRiskManagement(MaximumDrawdownPercentPerSecurity(0.01))
        # self.SetUniverseSelection(QC500UniverseSelectionModel())
        symbols = [Symbol.Create(self.symbol, SecurityType.Equity, Market.USA)]
        self.SetUniverseSelection(ManualUniverseSelectionModel(symbols))

        # Custom chart: a portfolio-value line plus scatter markers; the "Buy" and
        # "Sell" series are plotted by the alpha model when it emits insights.
        plotExample = Chart("Portfolio and Triggers")
        plotExample.AddSeries(Series("Portfolio Value", SeriesType.Line, '$', Color.Green))
        plotExample.AddSeries(Series('Buy', SeriesType.Scatter, '$', Color.Red, ScatterMarkerSymbol.Triangle))
        plotExample.AddSeries(Series('Sell', SeriesType.Scatter, '$', Color.Blue, ScatterMarkerSymbol.TriangleDown))
        self.AddChart(plotExample)

    def OnData(self, data):
        '''OnData event is the primary entry point for your algorithm. Each new data point will be pumped in here.
            Arguments:
                data: Slice object keyed by symbol containing the stock data
        '''
        # NOTE(review): this manual order runs alongside the framework models set in
        # Initialize (alpha/portfolio construction/risk); confirm both are meant to trade.
        if not self.Portfolio.Invested:
            self.SetHoldings(self.symbol, 1.0)

    def AddPlots(self, symbol, res2use):
        '''Creates SMA(50), SMA(200), Bollinger Band and RSI indicators for `symbol`
        and plots them.

        The indicator helper methods (Identity, SMA, BB, RSI) already register the
        indicator for automatic updates at `res2use`. Calling RegisterIndicator on
        them again would attach a second consolidator and update every indicator
        twice per bar, so it is intentionally not done here.

        Args:
            symbol: ticker/Symbol to build the indicators for
            res2use: resolution at which the indicators are updated'''
        # Raw price series, drawn on the same plots as the averages and bands.
        self.sym_price = self.Identity(symbol)

        # SMA - Simple moving average
        self.sma50 = self.SMA(symbol, 50, res2use)
        self.sma200 = self.SMA(symbol, 200, res2use)
        self.PlotIndicator("SMA50-SMA200", self.sym_price, self.sma50, self.sma200)

        # BB - Bollinger Bands. BB's third positional parameter is the
        # standard-deviation multiplier k, so the resolution must come after
        # the moving-average type rather than in that slot.
        self.bb = self.BB(symbol, 200, 2, MovingAverageType.Simple, res2use)
        self.PlotIndicator("BB", self.sym_price, self.bb.UpperBand, self.bb.LowerBand)

        # RSI - Relative Strength Index
        self.rsi = self.RSI(symbol, 10, MovingAverageType.Simple, res2use)
        self.PlotIndicator("RSI", self.rsi)
    
class RsiAlphaModel(AlphaModel):
    '''Uses Wilder's RSI to create insights.
    With the hard-coded thresholds, a move below 30 emits an Up insight and a move
    above 70 emits a Down insight. After tripping, the RSI must move back above 35
    (from below) or below 65 (from above) before the state resets to Middle, which
    prevents repeated signals while the RSI hovers near a threshold.'''

    def __init__(self,
                 period = 14,
                 resolution = Resolution.Daily):
        '''Initializes a new instance of the RsiAlphaModel class
        Args:
            period: The RSI indicator period; each insight lasts `period` bars of `resolution`
            resolution: The resolution used to update the RSI and size the insight period'''
        self.period = period
        self.resolution = resolution
        self.insightPeriod = Time.Multiply(Extensions.ToTimeSpan(resolution), period)
        self.symbolDataBySymbol = {}
        # Consolidator feeding each symbol's RSI. It is kept so that removal
        # detaches only this model's consolidator, not ones owned by others.
        self.consolidatorBySymbol = {}

        resolutionString = Extensions.GetEnumString(resolution, Resolution)
        self.Name = '{}({},{})'.format(self.__class__.__name__, period, resolutionString)

    def Update(self, algorithm, data):
        '''Updates this alpha model with the latest data from the algorithm.
        This is called each time the algorithm receives data for subscribed securities
        Args:
            algorithm: The algorithm instance
            data: The new data available
        Returns:
            The new insights generated'''
        insights = []
        for symbol, symbolData in self.symbolDataBySymbol.items():
            rsi = symbolData.RSI
            # Do not advance the state on warm-up values: a tripped state recorded
            # before the RSI is ready would suppress the first genuine crossing.
            if not rsi.IsReady:
                continue

            previous_state = symbolData.State
            state = self.GetState(rsi, previous_state)

            if state != previous_state:
                if state == State.TrippedLow:
                    insights.append(Insight.Price(symbol, self.insightPeriod, InsightDirection.Up))
                elif state == State.TrippedHigh:
                    insights.append(Insight.Price(symbol, self.insightPeriod, InsightDirection.Down))

            symbolData.State = state

        # Plot portfolio value on every call and mark each emitted insight on the same chart.
        portfolioValue = algorithm.Portfolio.TotalPortfolioValue
        algorithm.Plot("Portfolio and Triggers", "Portfolio Value", portfolioValue)
        for insight in insights:
            series = "Buy" if insight.Direction > 0 else "Sell"
            algorithm.Plot("Portfolio and Triggers", series, portfolioValue)
        return insights

    def OnSecuritiesChanged(self, algorithm, changes):
        '''Cleans out old security data and initializes the RSI for any newly added securities.
        Event fired each time the we add/remove securities from the data feed
        Args:
            algorithm: The algorithm instance that experienced the change in securities
            changes: The security additions and removals from the algorithm'''

        # Clean up data for removed securities. Only this model's own consolidator
        # is detached; other components' consolidators on the symbol are untouched.
        for security in changes.RemovedSecurities:
            symbol = security.Symbol
            self.symbolDataBySymbol.pop(symbol, None)
            consolidator = self.consolidatorBySymbol.pop(symbol, None)
            if consolidator is not None:
                algorithm.SubscriptionManager.RemoveConsolidator(symbol, consolidator)

        # initialize data for added securities
        addedSymbols = [x.Symbol for x in changes.AddedSecurities if x.Symbol not in self.symbolDataBySymbol]
        if not addedSymbols:
            return

        history = algorithm.History(addedSymbols, self.period, self.resolution)

        for symbol in addedSymbols:
            rsi = RelativeStrengthIndex(self.period, MovingAverageType.Wilders)
            consolidator = algorithm.ResolveConsolidator(symbol, self.resolution)
            algorithm.RegisterIndicator(symbol, rsi, consolidator)
            self.consolidatorBySymbol[symbol] = consolidator

            # Warm the RSI up from history when it is available.
            if not history.empty:
                ticker = SymbolCache.GetTicker(symbol)
                if ticker in history.index.levels[0]:
                    for row in history.loc[ticker].itertuples():
                        rsi.Update(row.Index, row.close)
                else:
                    # Still track the symbol: the RSI warms up from live data, and
                    # Update ignores it until rsi.IsReady.
                    Log.Trace(f'RsiAlphaModel.OnSecuritiesChanged: {ticker} not found in history data frame.')

            self.symbolDataBySymbol[symbol] = SymbolData(symbol, rsi)

    def GetState(self, rsi, previous):
        '''Determines the new crossover state from the current RSI value.
        Above 70 -> TrippedHigh; below 30 -> TrippedLow. A tripped state resets to
        Middle only after moving 5 points back inside the band (above 35 from
        TrippedLow, below 65 from TrippedHigh); otherwise `previous` is returned.'''
        value = rsi.Current.Value
        if value > 70:
            return State.TrippedHigh
        if value < 30:
            return State.TrippedLow
        if previous == State.TrippedLow and value > 35:
            return State.Middle
        if previous == State.TrippedHigh and value < 65:
            return State.Middle

        return previous


class SymbolData:
    '''Per-symbol record for the RSI alpha model: the symbol itself, the RSI
    indicator that tracks it, and the last crossover state (initially Middle).'''

    def __init__(self, symbol, rsi):
        # Every symbol starts out between the thresholds.
        self.State = State.Middle
        self.RSI = rsi
        self.Symbol = symbol


class State(Enum):
    '''Crossover state of a symbol's RSI, tracked to avoid emitting the same
    signal repeatedly and to detect bounces back into the neutral band.'''
    TrippedLow = 0   # RSI dropped below the lower threshold
    Middle = 1       # RSI is in the neutral band
    TrippedHigh = 2  # RSI rose above the upper threshold