# Overall Statistics
from AlgorithmImports import *

# Goal:
#   - Option chain selection
#   - Build option strategies 
#   - Build monitor and scanner system 

class BullCallSpreadStrategy(QCAlgorithm):
    """Buy one TSLA bull call spread per trading day and flatten before the close.

    On the first slice each day that carries a TSLA option chain, the algorithm
    buys the widest call spread available at the furthest expiry that survived
    the universe filter (long the lowest strike, short the highest strike).
    A scheduled event liquidates all holdings 10 minutes before the close.
    """

    def Initialize(self):
        """Configure dates, cash, TSLA equity/option subscriptions and schedules."""
        self.SetStartDate(2022, 12, 1)
        self.SetEndDate(2023, 1, 1)
        self.SetCash(500000)

        # Option strikes are quoted in unadjusted terms, so the underlying uses
        # raw (unadjusted) prices.
        equity = self.AddEquity("TSLA", Resolution.Minute)
        equity.SetDataNormalizationMode(DataNormalizationMode.Raw)

        option = self.AddOption("TSLA", Resolution.Minute)
        self.symbol = option.Symbol

        # Contract universe is defined in UniverseFunc.
        # Older equivalent without weeklies/market-open options:
        # option.SetFilter(-15, 15, timedelta(20), timedelta(30))
        option.SetFilter(self.UniverseFunc)

        # Time of the last spread purchase; used to trade at most once per day.
        self.tradeTime = None

        # Warm-up period (at least 30 days). Only needed when something consumes
        # history before trading starts, e.g. a lagging indicator.
        # self.SetWarmup(30, Resolution.Daily)

        # TODO: open the position from a scheduled event 10 minutes after the
        # open instead of in OnData; self.CurrentSlice["TSLA"] did not work there.
        # self.Schedule.On(
        #     self.DateRules.EveryDay("TSLA"),
        #     self.TimeRules.AfterMarketOpen("TSLA", 10),
        #     self.OpenPosition
        # )

        # Flatten everything 10 minutes before the market close every day.
        self.Schedule.On(
            self.DateRules.EveryDay("TSLA"),
            self.TimeRules.BeforeMarketClose("TSLA", 10),
            self.ClosePosition,
        )

    def ClosePosition(self):
        """Liquidate all holdings (scheduled 10 minutes before the close)."""
        if self.IsWarmingUp:
            return

        self.Log("Closing all position EOD!")
        self.Liquidate()

    def UniverseFunc(self, universe):
        """Option contract filter.

        Keeps contracts within 15 strikes below/above the money and 20-30 days
        to expiry, including weekly expirations. The filter is only applied at
        the market open.
        """
        return (
            universe.Strikes(-15, 15)
            .Expiration(timedelta(20), timedelta(30))
            .IncludeWeeklys()
            .OnlyApplyFilterAtMarketOpen()
        )

    def OnData(self, slice: Slice) -> None:
        """Buy one bull call spread on the first usable option chain of each day."""
        if self.IsWarmingUp:
            return

        # Compare whole dates: comparing only .day would also skip the same
        # day-of-month in a later month.
        if self.tradeTime is not None and self.Time.date() == self.tradeTime.date():
            return

        chain = slice.OptionChains.get(self.symbol, None)
        if not chain:
            return

        option_strategy = self.BullCallSpreadBuilder(chain)
        # No usable spread in this chain (empty, or fewer than two strikes):
        # don't place an order, try again on the next slice.
        if option_strategy is None:
            return

        self.Buy(option_strategy, 1)
        self.tradeTime = self.Time

    def BullCallSpreadBuilder(self, chain):
        """Build the widest bull call spread at the furthest expiry in ``chain``.

        The long leg is the lowest call strike and the short leg the highest call
        strike of that expiry.

        Returns:
            An ``OptionStrategies.BullCallSpread`` strategy, or ``None`` when the
            chain is empty or the chosen expiry has fewer than two distinct call
            strikes.
        """
        contracts = list(chain)
        if not contracts:
            return None

        # Furthest expiration among the filtered contracts.
        expiry = max(contract.Expiry for contract in contracts)

        calls = [
            contract for contract in contracts
            if contract.Expiry == expiry and contract.Right == OptionRight.Call
        ]
        call_strikes = sorted({contract.Strike for contract in calls})

        # NOTE(review): the author expected every strike between the minimum and
        # maximum but saw a 5-point spacing. These are exactly the strikes of the
        # contracts that survived UniverseFunc for this expiry; verify which
        # strikes are actually listed for it.
        self.Log(str(call_strikes))

        # A spread needs two distinct strikes.
        if len(call_strikes) < 2:
            return None

        itm_strike = call_strikes[0]   # long leg: lowest strike
        otm_strike = call_strikes[-1]  # short leg: highest strike

        return OptionStrategies.BullCallSpread(self.symbol, itm_strike, otm_strike, expiry)

    def OnOrderEvent(self, orderEvent):
        """Log each filled order with its symbol, fill price and direction."""
        if orderEvent.Status != OrderStatus.Filled:
            return
        message = (
            f"Symbol {orderEvent.Symbol} | "
            f"Fill Price {orderEvent.FillPrice} | "
            f"Action {orderEvent.Direction}"
        )
        self.Log(f"{self.Time} {message}")

    # TODO: - Try using OptionChainProvider
    def InitialFilter(self, underlyingsymbol, symbol_list, min_strike_rank, max_strike_rank, min_expiry, max_expiry):
        """Filter option contract symbols by days to expiry and strike rank around ATM.

        Args:
            underlyingsymbol: Symbol of the underlying; its current price locates
                the at-the-money (ATM) strike.
            symbol_list: Option contract ``Symbol`` objects to filter.
            min_strike_rank: Lower strike offset from ATM, in listed-strike steps
                (typically negative). Exclusive: the lowest kept strike is
                ``min_strike_rank + 1`` steps from ATM.
            max_strike_rank: Upper strike offset from ATM, also exclusive: the
                highest kept strike is ``max_strike_rank - 1`` steps from ATM.
            min_expiry: Minimum calendar days to expiry (inclusive).
            max_expiry: Maximum calendar days to expiry (inclusive).

        Returns:
            List of contracts inside both the expiry and the strike window (it may
            be empty). Returns ``None`` when ``symbol_list`` is empty or no contract
            falls in the expiry range.
        """
        if not symbol_list:
            return None

        # Filter the contracts by the days-to-expiry range.
        today = self.Time.date()
        contract_list = [
            contract for contract in symbol_list
            if min_expiry <= (contract.ID.Date.date() - today).days <= max_expiry
        ]
        if not contract_list:
            return None

        # ATM strike: the listed strike closest to the underlying's current price.
        underlying_price = self.Securities[underlyingsymbol].Price
        atm_strike = min(
            contract_list,
            key=lambda contract: abs(contract.ID.StrikePrice - underlying_price),
        ).ID.StrikePrice
        strike_list = sorted({contract.ID.StrikePrice for contract in contract_list})
        atm_strike_rank = strike_list.index(atm_strike)

        # Strike window [rank + min + 1, rank + max - 1], clamped to the list.
        # Slicing clamps the upper end. The lower end is floored at 0 because a
        # negative index would wrap around to the wrong end of the list.
        low = max(atm_strike_rank + min_strike_rank + 1, 0)
        high = max(atm_strike_rank + max_strike_rank, 0)  # exclusive slice end
        kept_strikes = set(strike_list[low:high])

        return [contract for contract in contract_list if contract.ID.StrikePrice in kept_strikes]