# Overall Statistics
from clr import AddReference

from System import *
from QuantConnect import *
from QuantConnect.Indicators import *
from QuantConnect.Logging import Log
from QuantConnect.Algorithm import *
from QuantConnect.Algorithm.Framework import *
from QuantConnect.Algorithm.Framework.Alphas import *
from QuantConnect.Orders import OrderStatus
from QuantConnect.Orders.Fees import ConstantFeeModel
from QuantConnect.Algorithm.Framework.Risk import *
from QuantConnect.Algorithm.Framework.Execution import *
from QuantConnect.Algorithm.Framework.Portfolio import *
from QuantConnect.Algorithm.Framework.Selection import *
import pandas as pd
from datetime import timedelta
from enum import Enum

class GasAndCrudeOilEnergyCorrelationAlpha(QCAlgorithm):
    '''Framework algorithm trading a fixed set of natural-gas and crude-oil ETFs.

    Insights come from the RsiAlphaModel defined in this file; positions are
    equal-weighted, executed immediately and no risk management is applied.'''

    def Initialize(self):
        '''Configures dates, cash, universe, fees and the framework models.'''
        self.SetStartDate(2018, 1, 1)   #Set Start Date
        self.SetCash(100000)            #Set Strategy Cash

        natural_gas = [Symbol.Create(x, SecurityType.Equity, Market.USA) for x in ['UNG', 'BOIL', 'FCG']]
        crude_oil = [Symbol.Create(x, SecurityType.Equity, Market.USA) for x in ['USO', 'UCO', 'DBO']]

        ## Set Universe Selection
        self.UniverseSettings.Resolution = Resolution.Minute
        self.SetUniverseSelection(ManualUniverseSelectionModel(natural_gas + crude_oil))
        # Every security gets a zero constant fee model.
        self.SetSecurityInitializer(lambda security: security.SetFeeModel(ConstantFeeModel(0)))

        ## Custom Alpha Model -- previously left unset, so no insights (and no trades) were produced.
        self.SetAlpha(RsiAlphaModel())

        ## Equal-weight the securities that have active insights
        self.SetPortfolioConstruction(EqualWeightingPortfolioConstructionModel(resolution = Resolution.Minute))

        ## Immediate Execution Fill Model
        self.SetExecution(ImmediateExecutionModel())

        ## Null Risk-Management Model
        self.SetRiskManagement(NullRiskManagementModel())

    def OnOrderEvent(self, orderEvent):
        '''Logs every filled order, distinguishing buys from sells.
            orderEvent: The order event raised by the brokerage/fill model'''
        if orderEvent.Status == OrderStatus.Filled:
            # A positive fill quantity is a buy; a negative one is a sell.
            action = 'Purchased' if orderEvent.FillQuantity > 0 else 'Sold'
            self.Debug(f'{action} Stock: {orderEvent.Symbol}')

    def OnEndOfAlgorithm(self):
        '''Logs every symbol still held when the algorithm ends.'''
        for kvp in self.Portfolio:
            if kvp.Value.Invested:
                self.Log(f'Invested in: {kvp.Key}')

class RsiAlphaModel(AlphaModel):
    '''Uses Wilder's RSI and Bollinger Bands to create insights.

    An Up insight is emitted when RSI drops below 30; a Down insight is emitted
    when the latest close exceeds twice the upper Bollinger band. Insights are
    only emitted when a symbol's state changes, which prevents signal spamming.'''

    def __init__(self,
                 period = 14,
                 resolution = Resolution.Daily):
        '''Initializes a new instance of the RsiAlphaModel class
            period: The RSI and Bollinger Band period (also the warm-up history length)
            resolution: The resolution of the indicators and warm-up history'''
        self.period = period
        self.resolution = resolution
        # Each insight lasts `period` bars of `resolution`.
        self.insightPeriod = Time.Multiply(Extensions.ToTimeSpan(resolution), period)
        self.symbolDataBySymbol = {}

        resolutionString = Extensions.GetEnumString(resolution, Resolution)
        self.Name = '{}({},{})'.format(self.__class__.__name__, period, resolutionString)

    def Update(self, algorithm, data):
        '''Updates this alpha model with the latest data from the algorithm.
        This is called each time the algorithm receives data for subscribed securities
            algorithm: The algorithm instance
            data: The new data available
        Returns the new insights generated'''
        insights = []
        for symbol, symbolData in self.symbolDataBySymbol.items():
            # Refresh the price every call: the value captured when the security
            # was added never changes on its own, so the BB check would use a stale price.
            symbolData.Close = algorithm.Securities[symbol].Close
            rsi = symbolData.RSI
            bb = symbolData.BB
            previous_state = symbolData.State
            state = self.GetState(rsi, previous_state, bb, symbolData.Close)

            # Only emit on a state transition and once both indicators are warmed up.
            if state != previous_state and rsi.IsReady and bb.IsReady:
                if state == State.TrippedLow:
                    insights.append(Insight.Price(symbol, self.insightPeriod, InsightDirection.Up))
                elif state == State.TrippedHigh:
                    insights.append(Insight.Price(symbol, self.insightPeriod, InsightDirection.Down))

            symbolData.State = state

        return insights

    def OnSecuritiesChanged(self, algorithm, changes):
        '''Cleans out old security data and initializes the RSI/BB for any newly added securities.
        Event fired each time the we add/remove securities from the data feed
            algorithm: The algorithm instance that experienced the change in securities
            changes: The security additions and removals from the algorithm'''

        # clean up data for removed securities
        for security in changes.RemovedSecurities:
            self.symbolDataBySymbol.pop(security.Symbol, None)

        # initialize data for added securities
        addedSymbols = [x.Symbol for x in changes.AddedSecurities if x.Symbol not in self.symbolDataBySymbol]
        if not addedSymbols:
            return

        history = algorithm.History(addedSymbols, self.period, self.resolution)

        for symbol in addedSymbols:
            rsi = algorithm.RSI(symbol, self.period, MovingAverageType.Wilders, self.resolution)
            # Bollinger Bands with k = 3 standard deviations, same period/resolution as the RSI.
            bb = algorithm.BB(symbol, self.period, 3, MovingAverageType.Wilders, self.resolution)
            close = algorithm.Securities[symbol].Close
            if not history.empty:
                ticker = SymbolCache.GetTicker(symbol)

                if ticker not in history.index.levels[0]:
                    # Previously fell through to history.loc[ticker] and raised KeyError.
                    Log.Trace(f'RsiAlphaModel.OnSecuritiesChanged: {ticker} not found in history data frame.')
                else:
                    # Warm both indicators up with the requested history.
                    for row in history.loc[ticker].itertuples():
                        rsi.Update(row.Index, row.close)
                        bb.Update(row.Index, row.close)

            self.symbolDataBySymbol[symbol] = SymbolData(symbol, rsi, bb, close)

    def GetState(self, rsi, previous, bb, close):
        ''' Determines the new state from the RSI, the upper Bollinger band and the close.

        TrippedHigh when close > 2 * upper band; TrippedLow when RSI < 30. A tripped
        state only returns to Middle once RSI crosses back past 35 (from low) or
        65 (from high); those hard-coded gaps act as the bounce tolerance.
        Otherwise the previous state is kept.'''
        # NOTE(review): with k = 3 the price almost never exceeds twice the upper
        # band, so TrippedHigh may be effectively unreachable -- confirm the 2x factor.
        if close > 2*bb.UpperBand.Current.Value:
            return State.TrippedHigh
        if rsi.Current.Value < 30:
            return State.TrippedLow
        if previous == State.TrippedLow:
            if rsi.Current.Value > 35:
                return State.Middle
        if previous == State.TrippedHigh:
            if rsi.Current.Value < 65:
                return State.Middle

        return previous

class SymbolData:
    '''Per-symbol bundle of the indicators, price and signal state used by RsiAlphaModel.'''
    def __init__(self, symbol, rsi, bb, close):
        # Identity plus the RSI and Bollinger Band indicators created by the caller.
        self.Symbol, self.RSI, self.BB = symbol, rsi, bb
        # Close price supplied by the caller.
        self.Close = close
        # Every symbol starts in the neutral state.
        self.State = State.Middle

class State(Enum):
    '''Per-symbol signal state; an insight is only emitted on a transition into a
    tripped state, which prevents repeating the same signal every bar.'''
    # RSI has fallen below the low threshold.
    TrippedLow = 0
    # Neutral: neither side is tripped.
    Middle = 1
    # The high-side condition has fired.
    TrippedHigh = 2
class CustomExecutionModel(ExecutionModel):
    '''Execution model that immediately submits market orders to reach the desired portfolio
    targets, liquidating the previously ordered symbol whenever a different symbol is targeted.'''

    def __init__(self):
        '''Initializes a new instance of the CustomExecutionModel class'''
        self.targetsCollection = PortfolioTargetCollection()
        # Symbol of the most recent security a market order was submitted for.
        self.previous_symbol = None

    def Execute(self, algorithm, targets):
        '''Immediately submits orders for the specified portfolio targets.
            algorithm: The algorithm instance
            targets: The portfolio targets to be ordered'''
        # Merge the incoming targets into the pending collection; previously this was
        # missing, so the collection stayed empty and no orders were ever placed.
        self.targetsCollection.AddRange(targets)

        for target in self.targetsCollection.OrderByMarginImpact(algorithm):
            open_quantity = sum([x.Quantity for x in algorithm.Transactions.GetOpenOrders(target.Symbol)])
            existing = algorithm.Securities[target.Symbol].Holdings.Quantity + open_quantity
            quantity = target.Quantity - existing
            ## Liquidate positions in Crude Oil ETF that is no longer part of the highest-correlation pair
            if self.previous_symbol is not None and str(target.Symbol) != str(self.previous_symbol):
                algorithm.Liquidate(self.previous_symbol)
            if quantity != 0:
                algorithm.MarketOrder(target.Symbol, quantity)
                self.previous_symbol = target.Symbol

        # Drop targets that have been reached so they are not re-processed next call.
        self.targetsCollection.ClearFulfilled(algorithm)