# Overall Statistics
# region imports
from AlgorithmImports import *
# endregion


class DeepOTMOptionSwitchAlgorithm(QCAlgorithm):
    """Weekly deep out-of-the-money SPY option trade switched by a TLT trend filter.

    Every Monday, 120 minutes after the open, the algorithm selects a SPY call
    and a SPY put. It takes the furthest expiry that is strictly between 3 and 7
    days away. At that expiry it picks the highest-strike call and the
    lowest-strike put inside a strike window around the at-the-money strike.
    Two minutes later it buys one call if TLT's 50-period SMA is above its
    200-period SMA, otherwise one put. All holdings are liquidated 60 minutes
    before the close on Thursday and Friday.
    """

    def initialize(self):
        """Configure the backtest, subscriptions, indicators and scheduled events."""
        self.set_start_date(self.end_date - timedelta(5*365))
        self.set_cash(100_000)
        self.settings.seed_initial_prices = True
        self.settings.automatic_indicator_warm_up = True
        # RAW normalization so option strikes are compared against the
        # unadjusted underlying price.
        self._spy = self.add_equity("SPY", data_normalization_mode=DataNormalizationMode.RAW)
        # Contracts chosen by _select_contracts and consumed by _trade.
        self._spy.call = None
        self._spy.put = None

        self._tlt = self.add_equity("TLT", Resolution.HOUR)
        self._tlt.sma_long = self.sma(self._tlt, 200)
        self._tlt.sma_short = self.sma(self._tlt, 50)

        # Close everything before the weekly options expire.
        self.schedule.on(
            self.date_rules.every(DayOfWeek.THURSDAY, DayOfWeek.FRIDAY),
            self.time_rules.before_market_close(self._spy, 60),
            self.liquidate
        )

        self.schedule.on(
            self.date_rules.every(DayOfWeek.MONDAY),
            self.time_rules.after_market_open(self._spy, 120),
            self._select_contracts
        )

        # Runs two minutes after selection so the newly added contracts have data.
        self.schedule.on(
            self.date_rules.every(DayOfWeek.MONDAY),
            self.time_rules.after_market_open(self._spy, 122),
            self._trade
        )

    def _select_contracts(self):
        """Choose this week's deep OTM call and put and subscribe to both.

        Both selections are cleared first. If no usable pair is found,
        `_trade` therefore sees None and skips the week, instead of reusing
        last week's (possibly expired) contracts.
        """
        self._spy.call = None
        self._spy.put = None

        chain = self._initial_filter(-10, 10, 3, 7)
        if not chain:
            return

        # Use the furthest expiry inside the filtered window.
        expiry = max(c.expiry for c in chain)
        same_expiry = [c for c in chain if c.expiry == expiry]
        calls = [c for c in same_expiry if c.right == OptionRight.CALL]
        puts = [c for c in same_expiry if c.right == OptionRight.PUT]
        if not calls or not puts:
            # _trade needs both legs to pick a direction; skip this week.
            return

        # Deepest OTM call = highest strike; deepest OTM put = lowest strike.
        # (The put must be drawn from `puts`, not `calls`.)
        self._spy.call = max(calls, key=lambda c: c.strike)
        self._spy.put = min(puts, key=lambda c: c.strike)
        self.add_option_contract(self._spy.call)
        self.add_option_contract(self._spy.put)

    def _initial_filter(self, min_strike_rank, max_strike_rank, min_expiry, max_expiry):
        """Return SPY option contracts near the money within an expiry window.

        :param min_strike_rank: offset, in strike steps from the ATM strike,
            of the lowest strike kept (inclusive; usually negative).
        :param max_strike_rank: offset, in strike steps from the ATM strike,
            of the upper strike bound (exclusive).
        :param min_expiry: keep contracts expiring strictly more than this many
            calendar days after the current date.
        :param max_expiry: keep contracts expiring strictly fewer than this many
            calendar days after the current date.
        :returns: list of qualifying contracts, or [] if none qualify.
        """
        today = self.time.date()
        chain = [
            c for c in self.option_chain(self._spy)
            if min_expiry < (c.expiry.date() - today).days < max_expiry
        ]
        if not chain:
            return []

        # ATM strike = strike closest to the current underlying price.
        price = self._spy.price
        atm_strike = min(chain, key=lambda c: abs(c.strike - price)).strike
        strikes = sorted({c.strike for c in chain})
        atm_rank = strikes.index(atm_strike)

        # Clamp at 0: a negative slice bound would wrap around to the top of
        # the strike ladder and silently drop the strikes nearest the money.
        low = max(0, atm_rank + min_strike_rank)
        high = max(0, atm_rank + max_strike_rank)
        window = set(strikes[low:high])
        return [c for c in chain if c.strike in window]

    def _trade(self):
        """Buy one contract: the call in a TLT uptrend, otherwise the put.

        The uptrend condition is TLT's 50-period SMA above its 200-period SMA.
        Does nothing if either contract is unselected or the SMAs are not ready.
        """
        call, put = self._spy.call, self._spy.put
        if call is None or put is None:
            return
        sma_short, sma_long = self._tlt.sma_short, self._tlt.sma_long
        if not (sma_short.is_ready and sma_long.is_ready):
            return
        # Compare indicator values explicitly rather than the indicator objects.
        uptrend = sma_short.current.value > sma_long.current.value
        self.buy(call if uptrend else put, 1)