Overall Statistics
Total Orders
19843
Average Win
0.01%
Average Loss
-0.01%
Compounding Annual Return
1.409%
Drawdown
2.400%
Expectancy
0.148
Start Equity
1000000000
End Equity
1013235313.25
Net Profit
1.324%
Sharpe Ratio
-1.683
Sortino Ratio
-2.052
Probabilistic Sharpe Ratio
23.383%
Loss Rate
48%
Win Rate
52%
Profit-Loss Ratio
1.21
Alpha
-0.038
Beta
-0.077
Annual Standard Deviation
0.027
Annual Variance
0.001
Information Ratio
-0.733
Tracking Error
0.179
Treynor Ratio
0.582
Total Fees
$0.00
Estimated Strategy Capacity
$0
Lowest Capacity Asset
CRM 32Z9BP0KRSO5I|CRM SZQUJUA9SVOL
Portfolio Turnover
8.14%
Drawdown Recovery
204
# region imports
from AlgorithmImports import *
import QuantLib as ql
from OptionPricing.option_pricing_model import OptionPricingModel, IndexOptionPricingModel, FallbackIVStrategy
# endregion

class ImpliedVolatilityIndicator(PythonIndicator):
    """Implied volatility interpolated to a fixed number of days to expiry.

    On every ``update`` the indicator looks for two option expiries bracketing
    ``target_expiry_days`` (one at or before the target, one at or after it),
    picks the highest-strike call inside the ``moneyness`` band for each,
    computes a Black implied volatility for both through the option pricing
    model and linearly interpolates the implied *variance* (in days to expiry)
    to the target horizon. ``value`` is the square root of that variance, or 0
    when suitable contracts could not be found.

    NOTE: ``update`` also contains a futures-underlying code path, but the
    constructor currently only accepts index symbols.
    """

    def __init__(
        self, 
        algo: QCAlgorithm,
        symbol: Symbol, 
        name: str, 
        moneyness: float,
        target_expiry_days: int = 25, 
    ) -> None:
        """
        :param algo: algorithm used for option chains, subscriptions and logging
        :param symbol: underlying symbol; only SecurityType.INDEX is accepted
        :param name: indicator name
        :param moneyness: relative band around the underlying price; a contract
            qualifies when (1 - moneyness) <= |underlying_price / strike| <= (1 + moneyness)
        :param target_expiry_days: target horizon (calendar days) of the interpolated IV
        :raises TypeError: if ``symbol`` is not an index
        """
        super().__init__()
        
        self._algo: QCAlgorithm = algo
        self._symbol: Symbol = symbol
        self._moneyness: float = moneyness
        self._target_expiry_days: int = target_expiry_days

        self.name: str = name
        self.value: float = 0.

        if self._symbol.security_type == SecurityType.INDEX:
            self._option_pricing: Optional[OptionPricingModel] = IndexOptionPricingModel(algo)
        else:
            # report the type of the symbol we were actually given
            raise TypeError(f"ImpliedVolatilityIndicator.__init__: underlying security type: {self._symbol.security_type} is not supported")

    def update(self, input: TradeBar) -> bool:
        """Recompute ``value`` from the current option chain(s).

        For an index underlying the bracketing is done on option expiries directly.
        For a futures underlying the bracketing is done on futures contracts and the
        nearest option expiry of each future is used.

        :param input: bar of the underlying; only its type is validated
        :returns: True if option contracts for both bracketing expiries were found
        :raises TypeError: if ``input`` is not a TradeBar or the underlying type is unsupported
        """
        if not isinstance(input, TradeBar):
            raise TypeError('ImpliedVolatilityIndicator.update: input must be a TradeBar')

        self.value = 0.

        # securities subscribed only for this calculation; removed again before returning
        symbols_to_cleanup: List[Symbol] = []

        before_target_future: Optional[Future] = None
        after_target_future: Optional[Future] = None
        before_target_option_contract: Optional[OptionContract] = None
        after_target_option_contract: Optional[OptionContract] = None

        sec_type: SecurityType = self._algo.securities[self._symbol].type
        if sec_type == SecurityType.INDEX:
            before_target_option_contract, after_target_option_contract = tuple(
                self._find_atm_options_for_index(self._symbol)
            )

        elif sec_type == SecurityType.FUTURE:
            # find two future contracts; one of them has sooner expiry than target expiry; other one expires later
            before_target, after_target = self._find_futures_contracts()

            # both futures are required: the helpers below dereference each of them
            if before_target is None or after_target is None:
                self._algo.log(
                    f'ImpliedVolatilityIndicator.update: No futures contracts with targeted expiry found for: {self._symbol}'
                )
            else:
                # subscribe futures
                before_target_future, after_target_future = tuple(
                    self._subscribe_futures(before_target, after_target, symbols_to_cleanup)
                )

                # find ATM options for two futures
                before_target_option_contract, after_target_option_contract = tuple(
                    self._find_atm_options_for_futures(
                        before_target_future,
                        after_target_future
                    )
                )
        else:
            raise TypeError(f'ImpliedVolatilityIndicator.update: security type {sec_type} is not supported')
    
        found: bool = before_target_option_contract is not None and after_target_option_contract is not None
        if not found:
            self._algo.log(
                f'ImpliedVolatilityIndicator.update: No option contracts with targeted expiry found for: {self._symbol}'
            )
        elif all(
            self._algo.subscription_manager.subscription_data_config_service.get_subscription_data_configs(c.underlying_symbol)
            for c in [before_target_option_contract, after_target_option_contract]
        ):
            # underlying has been subscribed
            # NOTE DataNormalizationMode.RAW for underlying future is required

            # subscribe options
            before_target_option, after_target_option = tuple(
                self._subscribe_options(
                    before_target_option_contract, 
                    after_target_option_contract, 
                    symbols_to_cleanup
                )
            )

            # calculate implied volatility
            self.value = self._get_implied_volatility(
                before_target_option, 
                after_target_option, 
                before_target_future, 
                after_target_future
            )

        # remove assets that were added only for this calculation; done on every path so
        # futures subscribed above are not left in the algorithm when no options are found
        for asset in symbols_to_cleanup:
            self._algo.remove_security(asset)

        return found

    def _find_futures_contracts(self) -> Tuple[Optional[Symbol], Optional[Symbol]]:
        """Return the futures contracts expiring closest at/before and at/after the target date.

        Either element is None when no contract exists on that side; both are None
        when the chain is empty.
        """
        future_contracts: List[Symbol] = self._algo.future_chain_provider.get_future_contract_list(
            self._symbol, self._algo.time
        )

        if len(future_contracts) == 0:
            self._algo.log(
                f'ImpliedVolatilityIndicator._find_futures_contract: No futures found for: {self._symbol}'
            )
            return None, None

        target_expiry: datetime = self._algo.time + timedelta(days=self._target_expiry_days)

        # select two futures - one of them has sooner expiry than target expiry; other one expires later
        before_target: Optional[Symbol] = max((f for f in future_contracts if f.ID.date <= target_expiry), key=lambda f: f.ID.date, default=None)
        after_target: Optional[Symbol] = min((f for f in future_contracts if f.ID.date >= target_expiry), key=lambda f: f.ID.date, default=None)

        return before_target, after_target

    def _select_atm_call(
        self,
        contracts: List[OptionContract],
        underlying_price: float,
        expiry: Optional[datetime]
    ) -> Optional[OptionContract]:
        """Return the highest-strike call expiring at ``expiry`` whose strike lies in the moneyness band, or None."""
        calls: List[OptionContract] = [
            x for x in contracts
            if x.right == OptionRight.CALL
            and x.expiry == expiry
            and x.strike != 0
            and (1 - self._moneyness) <= abs(underlying_price / x.strike) <= (1 + self._moneyness)
        ]
        # filtering calls before taking the max avoids failing when the top strike only has a put
        return max(calls, key=lambda x: x.strike, default=None)

    # NOTE this could be the same for equity symbol as well
    def _find_atm_options_for_index(
        self, 
        index_symbol: Symbol
    ) -> Tuple[Optional[OptionContract]]:
        """Yield the selected call for the expiry at/before and the expiry at/after the target.

        Always yields exactly two items (None where no contract qualifies).
        """
        chain_contracts: List = list(self._algo.option_chain(index_symbol, flatten=True).contracts.values())

        if not chain_contracts:
            self._algo.log(
                f'ImpliedVolatilityIndicator._find_atm_options_for_index: No options found for: {index_symbol}'
            )
            yield None
            yield None
            return

        underlying_price: float = self._algo.securities[index_symbol].ask_price
        expiries: List[datetime] = [x.expiry for x in chain_contracts]
        target_expiry: datetime = self._algo.time + timedelta(days=self._target_expiry_days)

        # nearest expiries on either side of the target (an expiry equal to the target counts for both)
        before_target: Optional[datetime] = max((e for e in expiries if e <= target_expiry), default=None)
        after_target: Optional[datetime] = min((e for e in expiries if e >= target_expiry), default=None)

        for expiry in (before_target, after_target):
            atm_option: Optional[OptionContract] = self._select_atm_call(chain_contracts, underlying_price, expiry)
            if atm_option is None:
                self._algo.log(
                    f'ImpliedVolatilityIndicator._find_atm_options_for_index: No options filtered for: {index_symbol}'
                )
            yield atm_option

    def _find_atm_options_for_futures(
        self, 
        before_target_future: Future, 
        after_target_future: Future
    ) -> Tuple[Optional[OptionContract]]:
        """Yield, for each of the two futures, the selected call of its nearest option expiry (or None)."""
        for target_future in [before_target_future, after_target_future]:
            chain_contracts: List = list(self._algo.option_chain(target_future.symbol, flatten=True).contracts.values())

            if not chain_contracts:
                self._algo.log(
                    f'ImpliedVolatilityIndicator._find_atm_options_for_futures: No options found for: {target_future.symbol}'
                )
                yield None
                continue

            underlying_price: float = target_future.ask_price
            # use the nearest option expiry listed on this future
            expiry: datetime = min(x.expiry for x in chain_contracts)

            atm_option: Optional[OptionContract] = self._select_atm_call(chain_contracts, underlying_price, expiry)
            if atm_option is None:
                self._algo.log(
                    f'ImpliedVolatilityIndicator._find_atm_options_for_futures: No options filtered for: {target_future.symbol}'
                )
            yield atm_option
        
    def _get_implied_volatility(
        self, 
        before_target_option: OptionContract, 
        after_target_option: OptionContract, 
        before_target_future: Optional[Future] = None, 
        after_target_future: Optional[Future] = None
    ) -> float:
        """Interpolate the Black implied variance of the two options to the target horizon.

        :param before_target_option: subscribed option expiring at/before the target
        :param after_target_option: subscribed option expiring at/after the target
        :param before_target_future: underlying future of ``before_target_option`` (futures path only)
        :param after_target_future: underlying future of ``after_target_option`` (futures path only)
        :returns: square root of the interpolated implied variance
        :raises TypeError: if the options' underlying is neither a future nor an index
        """
        underlying_type: SecurityType = before_target_option.underlying.type

        # calculate implied volatility for each option using black model
        if underlying_type == SecurityType.FUTURE:
            implied_vols: List[float] = [
                self._option_pricing.bs_iv(
                    option_symbol=o.symbol, 
                    option_price=o.price, 
                    forward_price=f.price, 
                    evaluation_dt=self._algo.time, 
                    fallback_iv_strategy=FallbackIVStrategy.CALL_PUT_PARITY_IV, 
                    discount_factor=1.
                )
                for o, f in zip(
                    [before_target_option, after_target_option],
                    [before_target_future, after_target_future],
                )
            ]
        elif underlying_type == SecurityType.INDEX:
            bar: Optional[TradeBar] = self._algo.current_slice.bars.get(self._symbol)
            # fall back to the last known price when the current slice carries no bar
            spot_price: float = bar.close if bar is not None else self._algo.securities[self._symbol].price
            pricing_model: IndexOptionPricingModel = IndexOptionPricingModel(self._algo)
            rfr: float = pricing_model.get_risk_free_rate()

            implied_vols: List[float] = []
            for o in [before_target_option, after_target_option]:
                dividends: float = pricing_model.get_dividends(o.symbol)
                discount_factor: float = pricing_model.get_discount_factor(rfr, dividends, o.expiry)
                forward_price: float = pricing_model.get_forward_price(spot_price, discount_factor)

                iv: float = self._option_pricing.bs_iv(
                    option_symbol=o.symbol, 
                    option_price=o.price, 
                    forward_price=forward_price, 
                    evaluation_dt=self._algo.time, 
                    fallback_iv_strategy=FallbackIVStrategy.CALL_PUT_PARITY_IV, 
                    discount_factor=discount_factor
                )
                implied_vols.append(iv)
        else:
            raise TypeError(
                f'ImpliedVolatilityIndicator._get_implied_volatility: underlying security type {underlying_type} is not supported'
            )

        # interpolate variances
        interpolated_variance: float = self._interpolate(
            (before_target_option.expiry - self._algo.time).days,
            (after_target_option.expiry - self._algo.time).days,
            implied_vols[0] ** 2,   # variance
            implied_vols[1] ** 2,   # variance
            self._target_expiry_days
        )

        # get volatility
        return np.sqrt(interpolated_variance)
    
    def _subscribe_futures(
        self, 
        before_target: Symbol, 
        after_target: Symbol, 
        symbols_to_cleanup: List[Symbol]
    ) -> Tuple[Future]:
        """Yield the Future securities for both symbols, subscribing (RAW normalization) those not yet in the algorithm.

        Newly added symbols are appended to ``symbols_to_cleanup``.
        """
        for s in [before_target, after_target]:
            if not self._algo.securities.contains_key(s):
                target_future: Future = self._algo.add_future_contract(s)
                self._algo.securities[target_future.symbol].set_data_normalization_mode(DataNormalizationMode.RAW)
                
                # add symbol for a later clean up; only those symbols which were not subscribed already (those might be traded)
                symbols_to_cleanup.append(target_future.symbol)
            else:
                target_future: Future = self._algo.securities.get(s)
            
            yield target_future

    def _subscribe_options(
        self, 
        before_target_option_contract: OptionContract, 
        after_target_option_contract: OptionContract, 
        symbols_to_cleanup: List[Symbol]
    ) -> Tuple:
        """Yield the option securities for both contracts, subscribing those not yet in the algorithm.

        Index options are added with ``add_index_option_contract``, everything else
        with ``add_future_option_contract``. Newly added symbols are appended to
        ``symbols_to_cleanup``.
        """
        for oc in [before_target_option_contract, after_target_option_contract]:
            option_symbol: Symbol = oc.symbol
            if not self._algo.securities.contains_key(option_symbol):
                if option_symbol.security_type == SecurityType.INDEX_OPTION:
                    target_option = self._algo.add_index_option_contract(option_symbol)
                else:
                    target_option = self._algo.add_future_option_contract(option_symbol)

                # add symbol for a later clean up; only those symbols which were not subscribed already (those might be traded)
                symbols_to_cleanup.append(target_option.symbol)
            else:
                target_option = self._algo.securities.get(option_symbol)
            
            yield target_option

    def _interpolate(self, x1: float, x2: float, y1: float, y2: float, x: float) -> float:
        """Linearly interpolate (or extrapolate) the line through (x1, y1) and (x2, y2) at ``x``.

        When x1 == x2 (both bracketing expiries fall on the same day) the line is
        undefined; the mean of y1 and y2 is returned instead of dividing by zero.
        """
        if x2 == x1:
            return (y1 + y2) / 2
        return ((y2 - y1) * x + x2 * y1 - x1 * y2) / (x2 - x1)
# region imports
from AlgorithmImports import *
from data import *
from enum import Enum
from collections import deque
# endregion

class PortfolioValueType(Enum):
    """Kind of value stored by PortfolioValueIndicator."""

    PORTFOLIO = 1   # total portfolio value
    OPTION = 2      # value of the tracked option positions
    INDICATOR = 3   # portfolio value minus option value

class PortfolioValueIndicator():
    '''
        Stores last n values of portfolio value, (put) option value and indicator value (indicator value = portfolio value - (put) option value) in deques
        Indexing: id -1 -> latest value; id 0 -> oldest value
    '''
        
    def __init__(
        self,
        algo: QCAlgorithm, 
        underlying_symbols: List[Symbol], 
        observations_n: int, 
        option_right: Optional[OptionRight] = None, 
        option_direction: Optional[Direction] = None, 
        update_date_rule: Optional[ScheduleRule] = None, 
        use_trade_manager: bool = True, 
        trade_manager = None
    ) -> None:
        '''
        :param algo: algorithm whose portfolio is sampled
        :param underlying_symbols: underlyings whose options are valued
        :param observations_n: number of observations kept per value type (deque maxlen)
        :param option_right: optional option-right filter passed to filter_invested_options
        :param option_direction: optional direction filter passed to filter_invested_options
        :param update_date_rule: if given, _update is scheduled with its date_rule / time_rule
        :param use_trade_manager: take positions from trade_manager.trades_hist.positions instead of the whole portfolio
        :param trade_manager: trade manager object; can also be assigned later via the trade_manager property
        '''
        self._algo: QCAlgorithm = algo
        self._underlying_symbols: List[Symbol] = underlying_symbols
        self._option_right: OptionRight = option_right
        self._option_direction: Direction = option_direction
        self._use_trade_manager: bool = use_trade_manager
        self._trade_manager = trade_manager

        self._value_storage: Dict[PortfolioValueType, deque] = {
            pv_type: deque(maxlen=observations_n) for pv_type in PortfolioValueType
        }

        # register indicator for automatic updates
        if update_date_rule is not None:
            algo.schedule.on(
                update_date_rule.date_rule, 
                update_date_rule.time_rule, 
                self._update
            )

    @property 
    def trade_manager(self):
        return self._trade_manager 

    @trade_manager.setter
    def trade_manager(self, value):
        self._trade_manager = value

    @property
    def is_ready(self) -> bool:
        '''True once every value deque holds observations_n values.'''
        return all(len(deq) == deq.maxlen for deq in self._value_storage.values())

    def values_by_index(self, index: int) -> Dict[PortfolioValueType, Optional[float]]:
        '''Return the value at ``index`` for every value type; raises IndexError if any deque is too short.'''
        return {
            portfolio_type: self.value_by_index(portfolio_type, index) for portfolio_type in self._value_storage
        }

    def values_by_type(self, portfolio_type: PortfolioValueType) -> deque:
        '''Return the (live) deque of stored values for ``portfolio_type``.'''
        return self._value_storage[portfolio_type]

    def value_by_index(self, portfolio_type: PortfolioValueType, index: int) -> Optional[float]:
        '''
        Return the stored value of ``portfolio_type`` at ``index`` (-1 latest, 0 oldest).

        :raises IndexError: if no value is stored at ``index``
        '''
        try:
            return self._value_storage[portfolio_type][index]
        except IndexError as err:
            raise IndexError("PortfolioValueIndicator.value_by_index: index out of range") from err

    def latest_value(self, portfolio_type: PortfolioValueType) -> Optional[float]:
        '''Return the most recent value of ``portfolio_type``; raises IndexError if nothing is stored yet.'''
        return self.value_by_index(portfolio_type, -1)

    # manual update fn
    def update(self) -> None:
        self._update()

    # automatic update fn
    def _update(self) -> None:
        '''Append the current portfolio value, invested option value and their difference.'''
        if self._use_trade_manager:
            # filter options from trade manager history
            assert self.trade_manager is not None, \
                'PortfolioValueIndicator._update: trade_manager must be set when use_trade_manager is True'
            positions: Dict[Symbol, Union[float, int]] = self._trade_manager.trades_hist.positions
        else:
            # filter options from the whole portfolio
            positions: Dict[Symbol, Union[float, int]] = {symbol: holding.quantity for symbol, holding in self._algo.portfolio.items()}
        
        underlying_types: List[SecurityType] = list(set(map(lambda x: self._algo.securities[x].type, self._underlying_symbols)))

        invested_options: List[Symbol] = filter_invested_options(
            algo=self._algo, 
            positions=positions, 
            option_types=[SecurityType.OPTION, SecurityType.INDEX_OPTION], 
            underlying_types=underlying_types, 
            underlying_symbols=self._underlying_symbols, 
            only_invested=True, 
            option_right=self._option_right, 
            direction=self._option_direction
        )

        portf_value: float = self._algo.portfolio.total_portfolio_value

        # calculate option portfolio value
        option_value: float = 0.
        for opt_symbol in invested_options:
            if self._use_trade_manager:
                quantity: float = positions[opt_symbol]
                # long positions are valued at the ask, short positions at the bid
                price: float = self._algo.securities[opt_symbol].ask_price if quantity > 0 else self._algo.securities[opt_symbol].bid_price
                multiplier: float = self._algo.securities[opt_symbol].contract_multiplier
                option_value += (price * abs(quantity) * multiplier)
            else:
                option_value += self._algo.portfolio[opt_symbol].absolute_holdings_value

        # TODO: verify (e.g. by plotting) that option value + cash == total portfolio value

        indicator_value: float = portf_value - option_value

        self._value_storage[PortfolioValueType.PORTFOLIO].append(portf_value)
        self._value_storage[PortfolioValueType.OPTION].append(option_value)
        self._value_storage[PortfolioValueType.INDICATOR].append(indicator_value)
# region imports
from AlgorithmImports import *
from sklearn.linear_model import LinearRegression
from data import *
from collections import deque
from OptionsTrading.data import ScheduleRule
# endregion

class AssetPercentilesStorage:
    '''
    Historical distribution of an asset's trailing total returns.

    ``update`` loads the close history of ``symbol`` from ``min_date`` to the
    algorithm time, computes the ``trailing_days``-bar total return
    (last / first - 1) over a rolling window and stores the quantile levels
    that split it into ``no_of_quantiles`` buckets (the top level is +inf).
    ``update`` runs on construction and on ``update_schedule_rule``.
    '''

    def __init__(
        self,
        algo: QCAlgorithm, 
        symbol: Symbol,
        trailing_days,
        min_date: datetime, 
        no_of_quantiles: int, 
        min_observations: int, 
        update_schedule_rule: ScheduleRule,
        resolution: Resolution = Resolution.DAILY,
        ):
        '''
        :param trailing_days: window length (in bars) of the trailing return
        :param min_date: start of the history used for the distribution
        :param no_of_quantiles: number of buckets (must be >= 1)
        :param min_observations: number of trailing returns required for is_ready
        :param update_schedule_rule: date/time rule on which the levels are recomputed
        :param resolution: history resolution
        :raises ValueError: if no_of_quantiles < 1
        '''
        if no_of_quantiles < 1:
            raise ValueError(f'AssetPercentilesStorage.__init__: no_of_quantiles must be >= 1, got {no_of_quantiles}')

        self._trailing_days = trailing_days
        self._algo: QCAlgorithm = algo
        self._symbol = symbol
        self._min_date = min_date
        self._no_of_quantiles = no_of_quantiles
        self._resolution = resolution
        self._min_observations = min_observations
        self._schedule_rule = update_schedule_rule
        
        # interior cut points k/n for k = 1..n-1
        quantiles = [(1 / no_of_quantiles) * (i + 1) for i in range(no_of_quantiles - 1)]
        # the last interior cut point is (n-1)/n < 1, so the upper bound 1 is added;
        # the emptiness check makes no_of_quantiles == 1 work (there are no interior points)
        if not quantiles or quantiles[-1] != 1:
            quantiles.append(1)
        
        self._quantiles = quantiles
        self.update()

        self._algo.schedule.on(
            self._schedule_rule.date_rule,
            self._schedule_rule.time_rule, 
            self.update
        )

    def percentile(self, x) -> Optional[int]:
        '''
        Return the bucket index of the total return of price window ``x``.

        The return (x[-1] - x[0]) / x[0] is compared with the stored levels and the
        index of the first level strictly above it is returned. The last level is
        +inf, so every finite return maps to a bucket; a NaN return yields None.
        '''
        assert len(x) >= self._trailing_days
        ret = (x[-1] - x[0]) / x[0]
        
        for i, level in enumerate(self._levels):
            if ret < level:
                return i
        return None

    @property
    def is_ready(self) -> bool:
        '''True once the stored trailing-return history has at least min_observations entries.'''
        return len(self._data) >= self._min_observations

    def update(self):
        '''Reload the close history and recompute the quantile levels.'''
        asset_hist = self._algo.history(TradeBar, self._symbol, self._min_date, self._algo.time, self._resolution)
        asset_hist = asset_hist['close']
        
        def total_return(data):
            return (data[-1] - data[0]) / data[0]

        trailing_rets = asset_hist.rolling(window=self._trailing_days).apply(total_return, raw=True)
        # drop the leading NaNs produced before the first full window
        trailing_rets = trailing_rets.loc[trailing_rets.first_valid_index():]
        levels = trailing_rets.quantile(self._quantiles).values
        # open-ended top bucket
        levels[-1] = np.inf
        self._levels = levels
        
        self._data = trailing_rets
        assert len(self._levels) == self._no_of_quantiles 


class QuantilesIndicator(PythonIndicator):
    '''
    Maps the latest ``trailing_days``-bar total return of ``symbol`` to its
    historical quantile bucket (0 .. no_of_quantiles - 1), as computed by
    AssetPercentilesStorage. A per-bucket hit counter is logged once per
    calendar year.
    '''
    def __init__(
        self,
        algo: QCAlgorithm, 
        symbol: Symbol,
        trailing_days: int,
        min_history_date: datetime, 
        no_of_quantiles: int, 
        min_observations: int, 
        qunatiles_update_schedule_rule: ScheduleRule,
        auto_updates: bool,
        resolution: Resolution = Resolution.DAILY,
        name: str = 'Quantiles Indicator'
        ):
        '''
        :param trailing_days: window length (in bars) of the trailing return
        :param min_history_date: start of the history used for the quantile levels
        :param no_of_quantiles: number of buckets
        :param min_observations: history length required before the indicator reports ready
        :param qunatiles_update_schedule_rule: date/time rule on which the levels are recomputed
        :param auto_updates: register with algo.register_indicator at ``resolution``
        '''
        super().__init__()
        
        self._algo: QCAlgorithm = algo
        self._symbol = symbol
        self._trailing_days = trailing_days
        self.min_history_date = min_history_date
        self._no_of_quantiles = no_of_quantiles
        self._resolution = resolution
        self._min_observations = min_observations
        self._schedule_rule = qunatiles_update_schedule_rule
        self._auto_updates = auto_updates
        self.value = -1
        self.name = name
        self._quantiles_counter = [0] * self._no_of_quantiles

        self._historic_quantiles_indicator = AssetPercentilesStorage(
            algo, 
            symbol, 
            trailing_days, 
            min_history_date, 
            no_of_quantiles, 
            min_observations, 
            qunatiles_update_schedule_rule, 
            resolution
        )
        
        # register indicator for automatic updates
        if auto_updates:
            algo.register_indicator(symbol, self, resolution)

        # seed the window with the most recent closes so value is available immediately
        self._trailing_data = deque(
            self._algo.history(TradeBar, self._symbol, self._trailing_days, self._resolution).close,
            maxlen=self._trailing_days
        )
        self.value = self._historic_quantiles_indicator.percentile(self._trailing_data)
        self._quantiles_counter[self.value] += 1
        self._last_year = self._algo.time.year

    def update(self, input: TradeBar) -> bool:
        '''
        Slide the close window by ``input`` and recompute the bucket index.

        :returns: whether the historical distribution has at least min_observations entries
        :raises TypeError: if ``input`` is not a TradeBar
        '''
        if not isinstance(input, TradeBar):
            raise TypeError('QuantilesIndicator.update: input must be a TradeBar')
        
        assert len(self._trailing_data) == self._trailing_days
        self._trailing_data.popleft()
        self._trailing_data.append(input.close)
        self.value = self._historic_quantiles_indicator.percentile(self._trailing_data)
        self._quantiles_counter[self.value] += 1

        # once per calendar year, log how often each bucket has been hit
        if self._algo.time.year > self._last_year:
            self._algo.log('indicator counter:')
            for i, count in enumerate(self._quantiles_counter):
                self._algo.log(str(i) + ': ' + str(count))
            self._last_year = self._algo.time.year

        return self._historic_quantiles_indicator.is_ready





class RealizedVolatilityIndicator(PythonIndicator):
    """Annualized realized volatility of close-to-close log returns.

    Returns are accumulated from the first bar on. Once the bars span at least
    ``daily_period`` calendar days, the number of returns collected so far is
    frozen as the window length; from then on every update reports the
    population standard deviation of the last window-length returns times sqrt(250).
    """

    def __init__(
        self,
        algo: QCAlgorithm, 
        symbol: Symbol, 
        name: str, 
        daily_period: int,
        auto_updates: bool, 
        resolution: Resolution = Resolution.DAILY
    ) -> None:
        """
        :param algo: algorithm used to register the indicator when ``auto_updates`` is set
        :param symbol: symbol whose bars feed the indicator
        :param name: indicator name
        :param daily_period: calendar-day span that determines the window length
        :param auto_updates: register with ``algo.register_indicator`` at ``resolution``
        :param resolution: resolution used for automatic registration
        """
        super().__init__()

        self._daily_period: int = daily_period
        self._auto_updates: bool = auto_updates

        self.name: str = name
        self.value: float = 0.

        self._first_bar: Optional[TradeBar] = None
        self._recent_bar: Optional[TradeBar] = None
        self._return_values = deque()
        # running sums of the windowed returns and of their squares
        self._rolling_sum: float = 0.
        self._rolling_sum_of_squares: float = 0.
        # window length (number of returns); None until the indicator is ready
        self._n: Optional[int] = None
        
        # register indicator for automatic updates
        if auto_updates:
            algo.register_indicator(symbol, self, resolution)

    @property
    def is_auto_updated(self) -> bool:
        """Whether the indicator was registered for automatic updates."""
        return self._auto_updates

    def update(self, input: TradeBar) -> bool:
        """Add the log return from the previous bar and refresh ``value`` once ready.

        :returns: True once the window length has been fixed (indicator ready)
        :raises TypeError: if ``input`` is not a TradeBar
        """
        if not isinstance(input, TradeBar):
            raise TypeError('RealizedVolatilityIndicator.update: input must be a TradeBar')
        
        if self._first_bar is None:
            # store first bar
            self._first_bar = input
        else:
            log_return: float = np.log(input.close / self._recent_bar.close)
            self._return_values.append(log_return)
            
            # update rolling sums
            self._rolling_sum += log_return
            self._rolling_sum_of_squares += log_return ** 2

            is_ready: bool = (input.end_time - self._first_bar.time).days >= self._daily_period
            if is_ready:
                # freeze the window length the first time the span is reached
                if self._n is None:
                    self._n = len(self._return_values)
                
                mean_value_1: float = self._rolling_sum / self._n
                mean_value_2: float = self._rolling_sum_of_squares / self._n

                # drop the oldest return so the next update again covers _n returns
                removed_return: float = self._return_values.popleft()
                self._rolling_sum -= removed_return
                self._rolling_sum_of_squares -= removed_return ** 2

                # E[r^2] - E[r]^2 can come out marginally negative through floating-point
                # cancellation (e.g. flat prices); clamp so np.sqrt never yields NaN
                variance: float = max(mean_value_2 - mean_value_1 ** 2, 0.)
                self.value = np.sqrt(variance) * np.sqrt(250) # annualized
        
        self._recent_bar = input
        
        return self._n is not None
# region imports
from AlgorithmImports import *
from collections import deque
from typing import Optional
# endregion

class RealizedVolatilityIndicator(PythonIndicator):
    """Annualized realized volatility of close-to-close log returns.

    Returns are accumulated from the first bar on. Once the bars span at least
    ``daily_period`` calendar days, the number of returns collected so far is
    frozen as the window length; from then on every update reports the
    population standard deviation of the last window-length returns times sqrt(250).
    """

    def __init__(
        self,
        algo: QCAlgorithm, 
        symbol: Symbol, 
        name: str, 
        daily_period: int,
        auto_updates: bool, 
        resolution: Resolution = Resolution.DAILY
    ) -> None:
        """
        :param algo: algorithm used to register the indicator when ``auto_updates`` is set
        :param symbol: symbol whose bars feed the indicator
        :param name: indicator name
        :param daily_period: calendar-day span that determines the window length
        :param auto_updates: register with ``algo.register_indicator`` at ``resolution``
        :param resolution: resolution used for automatic registration
        """
        super().__init__()

        self._daily_period: int = daily_period
        self._auto_updates: bool = auto_updates

        self.name: str = name
        self.value: float = 0.

        self._first_bar: Optional[TradeBar] = None
        self._recent_bar: Optional[TradeBar] = None
        self._return_values = deque()
        # running sums of the windowed returns and of their squares
        self._rolling_sum: float = 0.
        self._rolling_sum_of_squares: float = 0.
        # window length (number of returns); None until the indicator is ready
        self._n: Optional[int] = None
        
        # register indicator for automatic updates
        if auto_updates:
            algo.register_indicator(symbol, self, resolution)

    @property
    def is_auto_updated(self) -> bool:
        """Whether the indicator was registered for automatic updates."""
        return self._auto_updates

    def update(self, input: TradeBar) -> bool:
        """Add the log return from the previous bar and refresh ``value`` once ready.

        :returns: True once the window length has been fixed (indicator ready)
        :raises TypeError: if ``input`` is not a TradeBar
        """
        if not isinstance(input, TradeBar):
            raise TypeError('RealizedVolatilityIndicator.update: input must be a TradeBar')
        
        if self._first_bar is None:
            # store first bar
            self._first_bar = input
        else:
            log_return: float = np.log(input.close / self._recent_bar.close)
            self._return_values.append(log_return)
            
            # update rolling sums
            self._rolling_sum += log_return
            self._rolling_sum_of_squares += log_return ** 2

            is_ready: bool = (input.end_time - self._first_bar.time).days >= self._daily_period
            if is_ready:
                # freeze the window length the first time the span is reached
                if self._n is None:
                    self._n = len(self._return_values)
                
                mean_value_1: float = self._rolling_sum / self._n
                mean_value_2: float = self._rolling_sum_of_squares / self._n

                # drop the oldest return so the next update again covers _n returns
                removed_return: float = self._return_values.popleft()
                self._rolling_sum -= removed_return
                self._rolling_sum_of_squares -= removed_return ** 2

                # E[r^2] - E[r]^2 can come out marginally negative through floating-point
                # cancellation (e.g. flat prices); clamp so np.sqrt never yields NaN
                variance: float = max(mean_value_2 - mean_value_1 ** 2, 0.)
                self.value = np.sqrt(variance) * np.sqrt(250) # annualized
        
        self._recent_bar = input
        
        return self._n is not None
# region imports
from AlgorithmImports import *
from pandas.core.frame import DataFrame
from pandas.core.series import Series
from collections import deque
from abc import ABC, abstractmethod
# endregion

class VolatilityCone(ABC):
    """Abstract base for volatility-cone calculators.

    Holds a close-price table in ``_price_storage`` (created empty with a single
    ``close`` column) and declares the interface subclasses must implement.
    """

    def __init__(self) -> None:
        empty_prices: DataFrame = DataFrame(columns=['close'])
        self._price_storage: DataFrame = empty_prices

    @abstractmethod
    def update_price(self, dt: Optional[datetime] = None, price: Optional[float] = None) -> None:
        """Record a new close price in the storage."""

    @abstractmethod
    def _get_price_df_from_dt(self, dt_from: Optional[datetime], dt_to: Optional[datetime]) -> DataFrame:
        """Return close prices between ``dt_from`` and ``dt_to``.

        The base implementation ignores both bounds and hands back the current
        storage, which is still empty until a subclass fills it.
        """
        return self._price_storage

    @abstractmethod
    def get_cone_df(
        self,
        dt_from: datetime,
        dt_to: datetime,
        period_list: List[int],
        percentile_list: List[int],
        min_historical_period: int = 250,
    ) -> DataFrame:
        """Return the volatility cone between ``dt_from`` and ``dt_to``.

        ``min_historical_period`` is the minimal number of days used to calculate
        the percentile distribution for each period in ``period_list``.
        """

    @abstractmethod
    def get_rolling_cone_values(self, period: int, percentiles: List[int], rolling_period: int = 250) -> DataFrame:
        """Return cone values for ``period`` over a rolling distribution.

        ``rolling_period`` is the rolling number of days used to calculate the
        percentile distribution for the given period.
        """

    @abstractmethod
    def get_expanding_cone_values(self, period: int, percentiles: List[int], min_historical_period: int = 250) -> DataFrame:
        """Return cone values for ``period`` over an expanding distribution.

        ``min_historical_period`` is the minimal number of days used to calculate
        the percentile distribution for the given period.
        """

class VolatilityConeIndicator(VolatilityCone):
    """Close-to-close realized-volatility cone for a single symbol.

    Daily closes are warmed up from history (starting 1998-01-01) at
    construction and then appended by ``update_price``. That method is
    scheduled every trading day right before the symbol's market close.
    """

    def __init__(
        self,
        algo: QCAlgorithm, 
        symbol: Symbol
    ) -> None:
        """
        Args:
            algo: running algorithm; used for scheduling, history requests and logging.
            symbol: symbol whose daily closes feed the cone.
        """
        self._algo: QCAlgorithm = algo
        self._symbol: Symbol = symbol

        # append the latest daily close right before the market close
        algo.schedule.on(
            algo.date_rules.every_day(self._symbol), 
            algo.time_rules.before_market_close(self._symbol, 0), 
            self.update_price
        )

        super(VolatilityConeIndicator, self).__init__()

        # price warmup
        self._price_storage = self._get_price_df_from_dt(datetime(1998, 1, 1), self._algo.time)

    def update_price(self, dt: Optional[datetime] = None, price: Optional[float] = None) -> None:
        """Append one close price to the price storage.

        If both ``dt`` and ``price`` are given, they are stored directly.
        Otherwise the close and end time of the current slice's trade bar for
        the symbol are used. The scheduled daily call passes no arguments.
        If the slice has no bar for the symbol, nothing is stored.
        """
        if dt is None or price is None:
            bar: Optional[TradeBar] = self._algo.current_slice.bars.get(self._symbol)
            if bar is None:
                # no bar for the symbol in this slice: skip instead of raising AttributeError
                return
            price = bar.close
            dt = bar.end_time

        self._price_storage.loc[dt, 'close'] = price

    def _get_price_df_from_dt(self, dt_from: datetime, dt_to: datetime) -> DataFrame:
        """Return daily closes of the symbol between ``dt_from`` and ``dt_to`` as a one-column frame.

        An empty ``close`` frame is returned when history yields no data. If
        history is indexed by a MultiIndex that includes a ``symbol`` level,
        that level is dropped. This keeps the index time-only, matching the
        datetime keys written by ``update_price``.
        """
        history: DataFrame = self._algo.history(TradeBar, self._symbol, dt_from, dt_to, Resolution.DAILY)

        if history.empty or 'close' not in history.columns:
            return DataFrame(columns=['close'])

        history = history[['close']]

        if history.index.nlevels > 1 and 'symbol' in history.index.names:
            history = history.droplevel('symbol')

        return history

    def _calculate_volatility(self, price_df: DataFrame, period: int) -> np.ndarray:
        """Return annualized close-to-close realized volatility for every ``period``-long window.

        Log returns come from consecutive closes in ``price_df``. Each output
        value is the population standard deviation (ddof=0) of one window of
        ``period`` returns, scaled by sqrt(250). An empty array is returned when
        there are fewer than ``period`` returns or ``period`` is not positive.

        Standard close-to-close estimator:
        https://macrosynergy.com/research/six-ways-to-estimate-realized-volatility/
        """
        ratios: DataFrame = (price_df / price_df.shift(1)).dropna().astype(float)
        log_ret: np.ndarray = np.log(ratios.to_numpy()).ravel()

        if period <= 0 or len(log_ret) < period:
            return np.array([])

        # Rolling windows as a zero-copy view: one vectorised std instead of a Python loop.
        windows: np.ndarray = np.lib.stride_tricks.sliding_window_view(log_ret, period)
        return windows.std(axis=1) * np.sqrt(250)

    def get_rolling_cone_values(self, period: int, percentiles: List[int], rolling_period: int = 250) -> DataFrame:
        """Return percentiles of the last ``rolling_period`` realized-vol values for ``period``.

        The result is a one-row frame indexed by ``period``. It has one column
        per percentile plus ``'actual'``, the latest realized-vol value. When
        there is not enough data, an empty frame with the percentile columns is
        returned and a log message is written.
        """
        result: DataFrame = DataFrame(columns=percentiles)

        if len(self._price_storage) < period:
            self._algo.log(f'QCVolatilityConeIndicator.get_cone_values - not enough data to calculate volatility')
            return result

        realized_vol: np.ndarray = self._calculate_volatility(self._price_storage, period)

        if len(realized_vol) < rolling_period:
            self._algo.log(f'QCVolatilityConeIndicator.get_cone_values - not enough data to calculate percentile distribution')
            return result

        for percentile in percentiles:
            result.loc[period, percentile] = np.percentile(realized_vol[-rolling_period:], percentile)
        result.loc[period, 'actual'] = realized_vol[-1]

        return result

    def get_expanding_cone_values(self, period: int, percentiles: List[int], min_historical_period: int = 250) -> DataFrame:
        """Return percentiles of all realized-vol values for ``period``.

        At least ``min_historical_period`` values are required. The result has
        the same layout as ``get_rolling_cone_values``.
        """
        result: DataFrame = DataFrame(columns=percentiles)

        if len(self._price_storage) < period:
            self._algo.log(f'QCVolatilityConeIndicator.get_cone_values - not enough data to calculate volatility')
            return result

        realized_vol: np.ndarray = self._calculate_volatility(self._price_storage, period)

        if len(realized_vol) < min_historical_period:
            self._algo.log(f'QCVolatilityConeIndicator.get_cone_values - not enough data to calculate percentile distribution')
            return result

        for percentile in percentiles:
            result.loc[period, percentile] = np.percentile(realized_vol, percentile)
        result.loc[period, 'actual'] = realized_vol[-1]

        return result

    def get_cone_df(
        self, 
        dt_from: datetime, 
        dt_to: datetime, 
        period_list: List[int], 
        percentile_list: List[int], 

        # minimal number of days to calculate percentile distribution for each period in period list
        min_historical_period: int = 250,

    ) -> DataFrame:
        """Return the volatility cone computed from history between ``dt_from`` and ``dt_to``.

        Each row is indexed by a look-back period. Its columns are the requested
        percentiles plus ``'actual'`` (the latest realized-vol value). Periods
        with fewer than ``min_historical_period`` realized-vol values are
        omitted from the result rather than breaking its construction.
        """
        history: DataFrame = self._get_price_df_from_dt(dt_from, dt_to)

        rows: List[Dict] = []
        kept_periods: List[int] = []

        for period in period_list:
            realized_vol: np.ndarray = self._calculate_volatility(history, period)

            if len(realized_vol) < min_historical_period:
                continue

            row: Dict = {percentile: np.percentile(realized_vol, percentile) for percentile in percentile_list}
            row['actual'] = realized_vol[-1]

            rows.append(row)
            kept_periods.append(period)

        # index only the periods that produced a row (skipped periods used to cause a length mismatch)
        return DataFrame(rows, index=kept_periods)
# region imports
from AlgorithmImports import *
from .option_pricing_model import OptionPricingModel, IndexOptionPricingModel, FallbackIVStrategy
from decimal import Decimal
from pandas.core.frame import DataFrame
# endregion

class IVMonitor():
    """Diagnostic comparing the project's Black-model IV with QC's ``ImpliedVolatility`` indicator.

    Option contracts are registered with ``add``. A daily event, scheduled
    right before the underlying's market close, runs ``_compare_iv``. That
    method feeds the latest bars into each contract's QC indicator. For
    invested contracts it recomputes IV with ``IndexOptionPricingModel`` and
    logs the relative difference when its absolute value reaches
    ``abs_delta_to_notify``.
    """

    def __init__(
        self, 
        algo: QCAlgorithm, 
        underlying: Symbol, 
        include_dividends: bool = True, 
        abs_delta_to_notify: float = 0.
    ) -> None:
        """
        Args:
            algo: running algorithm (scheduling, data access, logging).
            underlying: underlying symbol; only ``SecurityType.INDEX`` is supported.
            include_dividends: use the option's dividend yield provider if True, else a zero yield.
            abs_delta_to_notify: minimal ``|black_iv / qc_iv - 1|`` that gets logged.

        Raises:
            TypeError: if the underlying is not an index.
        """
        self._algo: QCAlgorithm = algo
        self._underlying: Symbol = underlying
        self._include_dividends: bool = include_dividends
        self._abs_delta_to_notify: float = abs_delta_to_notify

        if self._underlying.security_type == SecurityType.INDEX:
            self._option_pricing: Optional[OptionPricingModel] = IndexOptionPricingModel(algo)
        else:
            raise TypeError(f"IVMonitor.__init__: underlying security type: {self._underlying.security_type} is not supported")

        # QC's implied volatility indicator indexed by option symbol
        self._iv_by_option: Dict[Symbol, ImpliedVolatility] = {}

        # Schedule only once the instance is fully initialised, so a rejected
        # underlying does not leave a dangling daily callback behind.
        self._scheduler()

    def add(self, option_symbol: Symbol) -> None:
        """Start tracking ``option_symbol``; does nothing if it is already tracked.

        Creates a QC ``ImpliedVolatility`` indicator (Black-Scholes, no mirror
        option). It uses the algorithm's risk-free rate model and either the
        option's dividend yield provider or a zero constant yield. The indicator
        is warmed up before being stored.
        """
        if option_symbol in self._iv_by_option:
            return

        iv_indicator: ImpliedVolatility = ImpliedVolatility(
            option=option_symbol, 
            risk_free_rate_model=self._algo.risk_free_interest_rate_model, 
            dividend_yield_model=DividendYieldProvider.create_for_option(option_symbol) if self._include_dividends else ConstantDividendYieldModel(0), 
            mirror_option=None, 
            option_model=OptionPricingModelType.BLACK_SCHOLES
        )

        self._algo.warm_up_indicator(option_symbol, iv_indicator)

        self._iv_by_option[option_symbol] = iv_indicator

    def _scheduler(self):
        """Run ``_compare_iv`` every trading day right before the underlying's market close."""
        self._algo.schedule.on(
            self._algo.date_rules.every_day(self._underlying), 
            self._algo.time_rules.before_market_close(self._underlying, 0), 
            self._compare_iv
        )

    def _compare_iv(self) -> None:
        """Update QC IV indicators and log Black-vs-QC IV differences.

        For every tracked contract:
          - feed today's underlying close and option close into its indicator
            (only when both bars exist);
          - drop it once DTE <= 0;
          - if invested, the indicator is ready and today's bars exist:
            recompute Black IV, assert both models saw the same inputs, and log
            ``black_iv / qc_iv - 1`` when its absolute value reaches
            ``abs_delta_to_notify``.
        Contracts whose Black IV cannot be solved, or whose QC IV is 0, are skipped.
        """
        # NOTE this can be deleted later - used only for assertion
        def trim_to_a_point(num: float, dec_point: int = 4) -> float:
            # truncate (not round) to ``dec_point`` decimals
            factor: int = 10**dec_point
            return int(num * factor) / factor

        # NOTE this can be deleted later - used only for assertion
        def get_precision(num: float) -> int:
            # number of decimal places in the string representation of ``num``
            sign, digits, exp = Decimal(str(num)).as_tuple()
            return abs(exp)

        option_symbols_to_remove: List[Symbol] = []

        current_slice: Slice = self._algo.current_slice
        underlying_bar: TradeBar = current_slice.bars.get(self._underlying)

        # iterate through the option symbols and measure the difference in our black model IV value and QC's indicator IV value
        for option_symbol, iv_indicator in self._iv_by_option.items():
            opt_bar: QuoteBar = current_slice.quote_bars.get(option_symbol)
            has_bars: bool = bool(underlying_bar and opt_bar)

            # update indicator
            if has_bars:
                iv_indicator.update(IndicatorDataPoint(self._underlying, underlying_bar.end_time, underlying_bar.close))
                iv_indicator.update(IndicatorDataPoint(option_symbol, opt_bar.end_time, opt_bar.close))

            dte: int = (option_symbol.id.date - self._algo.time).days

            if dte <= 0:
                option_symbols_to_remove.append(option_symbol)
                continue

            if not self._algo.portfolio[option_symbol].invested:
                continue

            # The indicator may be ready from earlier days even when today's
            # bars are missing; the comparison below needs both bars.
            if not has_bars or not iv_indicator.is_ready:
                continue

            spot_price: float = underlying_bar.close

            if self._include_dividends:
                dividends: float = self._option_pricing.get_dividends(option_symbol)
            else:
                dividends: float = 0.

            rfr: float = self._option_pricing.get_risk_free_rate()
            discount_factor: float = self._option_pricing.get_discount_factor(rfr, dividends, option_symbol.id.date)
            forward_price: float = self._option_pricing.get_forward_price(spot_price, discount_factor)

            iv: Optional[float] = self._option_pricing.bs_iv(
                option_symbol=option_symbol, 
                option_price=opt_bar.close, 
                forward_price=forward_price, 
                evaluation_dt=self._algo.time, 
                fallback_iv_strategy=FallbackIVStrategy.CALL_PUT_PARITY_IV, 
                discount_factor=discount_factor
            )

            if iv is None:
                # Black IV could not be solved; nothing to compare
                continue

            opt_price: float = self._option_pricing.black_price(
                option_symbol=option_symbol, 
                forward_price=forward_price, 
                vol=iv, 
                discount_factor=discount_factor
            )

            qc_indicator_iv: float = iv_indicator.current.value
            black_iv: float = iv

            # NOTE these can be removed later - assertion only
            qc_opt_price: float = opt_bar.close
            indicator_opt_price: float = iv_indicator.price.current.value
            black_opt_price: float = round(opt_price, get_precision(qc_opt_price)) # round to get rid of black model floating point precision to a point when price is comparable to QC's price
            indicator_spot_price: float = iv_indicator.underlying_price.current.value

            assert trim_to_a_point(iv_indicator.dividend_yield.current.value) == trim_to_a_point(dividends), f'indicator dividends: {iv_indicator.dividend_yield.current.value}, black dividends: {dividends}'
            assert trim_to_a_point(iv_indicator.risk_free_rate.current.value) == trim_to_a_point(rfr), f'indicator RFR: {iv_indicator.risk_free_rate.current.value}, black RFR: {rfr}'
            assert indicator_opt_price == qc_opt_price, f"indicator option price does not equal to last available QC option price; QC option price {qc_opt_price}; indicator option price: {indicator_opt_price}"
            # assert indicator_opt_price == black_opt_price, f"indicator option price does not equal to black model option price; black option price {black_opt_price}; indicator option price: {indicator_opt_price}"
            assert indicator_spot_price == spot_price, f"indicator spot price does not equal to QC spot price; QC spot price {spot_price}; indicator spot price: {indicator_spot_price}"

            if qc_indicator_iv == 0:
                # relative difference is undefined
                continue

            diff: float = black_iv / qc_indicator_iv - 1
            if abs(diff) >= self._abs_delta_to_notify:
                self._algo.log(
                    f'{option_symbol.value} - Black model IV: ' + '{:.2f}'.format(black_iv) + '; QC IV indicator: ' + '{:.2f}'.format(qc_indicator_iv) + '; diff: ' + '{:.2f}'.format(diff*100) + '%; DTE: ' + f'{dte}'
                )

        # remove expired options
        for symbol in option_symbols_to_remove:
            del self._iv_by_option[symbol]
# region imports
from AlgorithmImports import *
import QuantLib as ql
from abc import ABC, abstractmethod
from dataclasses import dataclass
from math import e
from enum import Enum
# endregion

class FallbackIVStrategy(Enum):
    """How to recover an implied volatility when the direct Black solve fails.

    MIRROR_OPTION_IV: solve on the mirror option (same strike/expiry, opposite right).
    CALL_PUT_PARITY_IV: solve with a discount factor implied by call-put parity.
    """

    MIRROR_OPTION_IV = 1

    CALL_PUT_PARITY_IV = 2

class OptionPricingModel(ABC):
    """Abstract base for option pricing helpers bound to a running algorithm."""

    def __init__(self, algo: QCAlgorithm) -> None:
        self._algo: QCAlgorithm = algo

    @abstractmethod
    def get_forward_price(
        self, 
        spot_price: float, 
        discount_factor: float, 
    ) -> float:
        """Return the forward price implied by ``spot_price`` and ``discount_factor``."""

    @abstractmethod
    def get_risk_free_rate(self, calc_dt: Optional[datetime] = None) -> float:
        """Return the risk-free rate at ``calc_dt`` (implementation-defined default)."""

    @abstractmethod
    def get_discount_factor(
        self, 
        rfr: float, 
        dividends: float, 
        expiry: datetime, 
        calc_dt: Optional[datetime] = None
    ) -> float:
        """Return the discount factor from ``calc_dt`` to ``expiry``."""

    @abstractmethod
    def get_dividends(self, option_symbol: Symbol, calc_dt: Optional[datetime] = None) -> float:
        """Return the dividend yield relevant for ``option_symbol`` at ``calc_dt``."""

    def get_history_price(self, option_symbol: Symbol, period: int = 15) -> Optional[float]:
        """Return the latest hourly quote price of ``option_symbol``, or ``None``.

        Requests ``period`` hourly QuoteBars and reads the last row. The result
        is the bid/ask close mid when both are available, otherwise whichever
        side is available. A side whose column is missing or holds NaN counts
        as unavailable.
        """
        history: DataFrame = self._algo.history(QuoteBar, option_symbol, period, Resolution.HOUR)

        if history.empty:
            # self._algo.log(f'OptionPricingModel.get_history_price - cannot fetch option {option_symbol} price')
            return None

        last_bar = history.iloc[-1]
        bid = last_bar.get('bidclose')
        ask = last_bar.get('askclose')

        # a column can exist while holding NaN (no quote on that side); treat it as missing
        has_bid: bool = bid is not None and not np.isnan(bid)
        has_ask: bool = ask is not None and not np.isnan(ask)

        if has_bid and has_ask:
            return (bid + ask) / 2
        if has_bid:
            return bid
        if has_ask:
            return ask
        return None

    @staticmethod
    def get_t_days(expiration_date: datetime, calc_dt: datetime) -> float:
        """Return the time from ``calc_dt`` to ``expiration_date`` in (fractional) days.

        Whole days plus the remaining seconds converted to days; microseconds
        are ignored. The result is negative when ``calc_dt`` is after
        ``expiration_date``.
        """
        dt = (expiration_date - calc_dt)
        days_dt: float = dt.days
        seconds_dt: float = dt.seconds / (24*60*60) # days
        t: float = seconds_dt + days_dt

        return t
        
class IndexOptionPricingModel(OptionPricingModel):
    """Forward-based Black-model pricing helpers for index options, using QuantLib.

    Rates come from the algorithm's risk-free interest rate model, and
    dividend yields from ``DividendYieldProvider``. The forward is
    ``spot / discount_factor`` with
    ``discount_factor = exp((q - r) * days / 360)``.
    """

    def __init__(self, algo: QCAlgorithm) -> None:
        super(IndexOptionPricingModel, self).__init__(algo)

    def get_forward_price(
        self, 
        spot_price: float, 
        discount_factor: float
    ) -> float:
        """Return the forward ``spot_price / discount_factor``."""
        F: float = spot_price / discount_factor
        return F

    def get_risk_free_rate(self, calc_dt: Optional[datetime] = None) -> float:
        """Return the algorithm's risk-free rate at ``calc_dt`` (default: algorithm time)."""
        if calc_dt is None: calc_dt = self._algo.time

        rfr: float = RiskFreeInterestRateModelExtensions.get_risk_free_rate(
            self._algo.risk_free_interest_rate_model, 
            calc_dt, calc_dt
        )
        return rfr

    def get_discount_factor(
        self, 
        rfr: float, 
        dividends: float, 
        expiry: datetime, 
        calc_dt: Optional[datetime] = None
    ) -> float:
        """Return ``exp((dividends - rfr) * days / 360)``, days = whole days from ``calc_dt`` to ``expiry``.

        NOTE(review): IVMonitor also passes this factor to Black as the
        ``discount`` argument. Confirm that using the (r - q) carry factor for
        discounting is intended.
        """
        if calc_dt is None: calc_dt = self._algo.time

        discount_factor: float = e**((-rfr+dividends)*((expiry - calc_dt).days / 360))
        return discount_factor

    def get_dividends(self, option_symbol: Symbol, calc_dt: Optional[datetime] = None) -> float:
        """Return the continuous dividend yield for the option's underlying at ``calc_dt``."""
        if calc_dt is None: calc_dt = self._algo.time

        # https://www.quantconnect.com/docs/v2/writing-algorithms/reality-modeling/dividend-yield/key-concepts
        # DividendYieldProvider, which provides the continuous yield calculated from all dividend payoffs from the underlying Equity over the previous 350 days.
        # SPX => SPY dividends
        # NDX => QQQ dividends
        # VIX => 0
        return DividendYieldProvider.create_for_option(option_symbol).get_dividend_yield(calc_dt, self._algo.securities[option_symbol.underlying].price)

    def bs_iv(
        self, 
        option_symbol: Symbol, 
        option_price: float, 
        forward_price: float, 
        evaluation_dt: datetime, 
        fallback_iv_strategy: FallbackIVStrategy, 
        discount_factor: float = 1
    ) -> Optional[float]:
        """Return the annualized Black implied volatility of ``option_symbol``, or ``None``.

        ``_black_iv`` yields the total standard deviation. It is divided by
        sqrt(t), where t is the time from ``evaluation_dt`` to
        ``option_symbol.id.date`` in days / 360. Returns ``None`` when the solve
        (including the fallback) fails, or when t <= 0: there the
        annualization would produce inf/NaN.
        """
        t: float = OptionPricingModel.get_t_days(option_symbol.id.date, evaluation_dt) / 360
        if t <= 0:
            return None

        iv: Optional[float] = self._black_iv(option_symbol, option_price, forward_price, fallback_iv_strategy, discount_factor)
        if iv is None:
            return None

        # iv *= np.sqrt(t)    # iv in annualized terms
        return iv / np.sqrt(t)      # to match previous results, IV monitor and indicators

    def _black_iv(
        self, 
        option_symbol: Symbol, 
        option_price: float, 
        forward_price: float,
        fallback_iv_strategy: FallbackIVStrategy, 
        discount_factor: float = 1,
    ) -> Optional[float]:
        """Return the Black implied standard deviation (not annualized), or ``None``.

        The direct QuantLib solve is tried first. If it raises ``RuntimeError``,
        the mirror contract (same strike and expiry, opposite right) is priced
        from recent hourly quotes. ``fallback_iv_strategy`` then decides how IV
        is recovered. Returns ``None`` when the mirror price is unavailable or
        the fallback solve raises ``RuntimeError`` as well.

        Raises:
            ValueError: if the fallback path is reached with an unsupported strategy.
        """
        strike_price: float = option_symbol.id.strike_price
        option_type: int = ql.Option.Call if option_symbol.id.option_right == OptionRight.CALL else ql.Option.Put

        try:
            # QuantLib also accepts (displacement, guess, accuracy, maxIterations);
            # an earlier variant of this call used (0, 0.000001, 1.0e-3, 100).
            return ql.blackFormulaImpliedStdDev(
                option_type, strike_price, forward_price, option_price, discount_factor
            )
        except RuntimeError:
            pass

        # create mirror option in case there was an exception caught in the first black model solver
        mirror_option_symbol: Symbol = Symbol.create_option(
            option_symbol.underlying, 
            option_symbol.id.market, 
            option_symbol.id.option_style, 
            OptionRight.CALL if option_symbol.id.option_right == OptionRight.PUT else OptionRight.PUT,
            strike_price, 
            option_symbol.id.date
        )

        mirror_option_price: Optional[float] = self.get_history_price(mirror_option_symbol)

        if mirror_option_price is None:
            self._algo.log(f'IndexOptionPricingModel._black_iv - cannot fetch option {mirror_option_symbol} price')
            return None

        try:
            if fallback_iv_strategy == FallbackIVStrategy.MIRROR_OPTION_IV:
                return self._mirror_option_iv(
                    mirror_option_symbol=mirror_option_symbol, 
                    mirror_option_price=mirror_option_price, 
                    forward_price=forward_price, 
                    discount_factor=discount_factor
                )
            if fallback_iv_strategy == FallbackIVStrategy.CALL_PUT_PARITY_IV:
                return self._call_put_parity_iv(
                    option_symbol=option_symbol, 
                    option_price=option_price,
                    mirror_option_symbol=mirror_option_symbol, 
                    mirror_option_price=mirror_option_price, 
                    forward_price=forward_price, 
                    discount_factor=discount_factor
                )
        except RuntimeError:
            # the fallback solve failed too: report and give up instead of crashing the algorithm
            self._algo.log(f'IndexOptionPricingModel._black_iv - fallback IV solve failed for option {option_symbol}; strategy: {fallback_iv_strategy}')
            return None

        raise ValueError(f'IndexOptionPricingModel._black_iv - unsupported fallback IV strategy: {fallback_iv_strategy}')

    def _mirror_option_iv(
        self,
        mirror_option_symbol: Symbol, 
        mirror_option_price: float, 
        forward_price: float, 
        discount_factor: float = 1
    ) -> float:
        """Return the Black implied standard deviation of the mirror option at its own price.

        Raises ``RuntimeError`` (from QuantLib) if the solver fails.
        """
        strike_price: float = mirror_option_symbol.id.strike_price
        option_type: int = ql.Option.Call if mirror_option_symbol.id.option_right == OptionRight.CALL else ql.Option.Put
        implied_vol: float = ql.blackFormulaImpliedStdDev(
            option_type, strike_price, forward_price, mirror_option_price, discount_factor
        )

        return implied_vol

    def _call_put_parity_iv(
        self, 
        option_symbol: Symbol, 
        option_price: float, 
        mirror_option_symbol: Symbol, 
        mirror_option_price: float, 
        forward_price: float,
        discount_factor: float = 1
    ) -> Optional[float]:
        """Return the Black implied standard deviation using a parity-implied discount factor.

        The discount factor is ``(put + spot - call) / strike``, computed from
        the option/mirror prices and the underlying's current price. The forward
        is ``spot / discount_factor``. The ``forward_price`` and
        ``discount_factor`` arguments are not used. Returns ``None`` if a price
        is missing or the implied discount factor is not positive. Raises
        ``RuntimeError`` (from QuantLib) if the solver fails.
        """
        strike_price: float = option_symbol.id.strike_price
        spot_price: float = self._algo.securities[option_symbol.underlying].price
        call_price: Optional[float] = option_price if option_symbol.id.option_right == OptionRight.CALL else mirror_option_price
        put_price: Optional[float] = option_price if option_symbol.id.option_right == OptionRight.PUT else mirror_option_price

        if call_price is None or put_price is None:
            return None

        fixed_discount_factor: float = (put_price + spot_price - call_price) / strike_price
        if fixed_discount_factor <= 0:
            # inconsistent quotes: a non-positive discount factor would divide by zero / flip the forward sign
            return None

        fixed_forward_price: float = spot_price / fixed_discount_factor

        option_type: int = ql.Option.Call if option_symbol.id.option_right == OptionRight.CALL else ql.Option.Put

        implied_vol: float = ql.blackFormulaImpliedStdDev(
            option_type, strike_price, fixed_forward_price, option_price, fixed_discount_factor
        )
        return implied_vol

    def black_price(
        self,
        option_symbol: Symbol, 
        forward_price: float,
        vol: float,
        discount_factor=1
    ) -> float:
        """Compute the Black option price for a given standard deviation ``vol``."""

        option_type: int = ql.Option.Call if option_symbol.id.option_right == OptionRight.CALL else ql.Option.Put
        price: float = ql.blackFormula(option_type, option_symbol.id.strike_price, forward_price, vol, discount_factor)
        return price

    def black_delta(
        self, 
        option_symbol: Symbol, 
        forward_price: float, 
        implied_vol: float, 
        discount_factor: float=1.
    ) -> float:
        """Return the Black delta w.r.t. spot, with spot taken as ``discount_factor * forward_price``.

        ``implied_vol`` is passed to ``ql.BlackCalculator`` as its standard
        deviation argument.
        """
        strike_price: float = option_symbol.id.strike_price
        option_type: int = ql.Option.Call if option_symbol.id.option_right == OptionRight.CALL else ql.Option.Put

        strikepayoff = ql.PlainVanillaPayoff(option_type, strike_price)
        black = ql.BlackCalculator(strikepayoff, forward_price, implied_vol, discount_factor)
        delta: float = black.delta(discount_factor * forward_price)

        return delta
from AlgorithmImports import *


class AssetFilter:
    """Base interface for predicates over asset symbols.

    Subclasses override ``eval``; the base version has no body and returns ``None``.
    """

    def eval(self, security: Symbol) -> bool:
        """Return whether ``security`` passes the filter (override in subclasses)."""


class AssetTypeFilter(AssetFilter):
    """Accept symbols whose ``security_type`` is one of the configured types."""

    def __init__(self, asset_types: Iterable):
        allowed_types = set(asset_types)
        self._types = allowed_types

    def eval(self, security: Symbol):
        security_type = security.security_type
        return security_type in self._types


class PutOrCallOptionFilter(AssetFilter): 
    """Accept option symbols of one right (put or call) that also pass an asset-type filter.

    Args:
        option_right: ``OptionRight.CALL`` or ``OptionRight.PUT``.
        asset_type_filter: type pre-filter. ``None`` (default) accepts equity
            and index options.
    """

    def __init__(self, option_right, asset_type_filter: Optional[AssetTypeFilter] = None):
        assert option_right in {OptionRight.CALL, OptionRight.PUT}
        self._right = option_right
        # Build the default per instance rather than sharing one default object
        # created once at class-definition time.
        if asset_type_filter is None:
            asset_type_filter = AssetTypeFilter({SecurityType.OPTION, SecurityType.INDEX_OPTION})
        self._type_filter = asset_type_filter

    def eval(self, security: Symbol):
        # snake_case accessor, consistent with ``id.option_right`` used elsewhere in the project
        return self._type_filter.eval(security) and security.id.option_right == self._right


class PutOptionFilter(PutOrCallOptionFilter):
    """Put-only option filter; ``asset_type_filter`` overrides the parent's default type filter."""

    def __init__(self, asset_type_filter: Optional[AssetTypeFilter] = None):
        # forward the type filter only when given, so the parent's default applies otherwise
        overrides = {} if asset_type_filter is None else {'asset_type_filter': asset_type_filter}
        super().__init__(option_right=OptionRight.PUT, **overrides)
    

class CallOptionFilter(PutOrCallOptionFilter):
    """Call-only option filter; ``asset_type_filter`` overrides the parent's default type filter."""

    def __init__(self, asset_type_filter: Optional[AssetTypeFilter] = None):
        # forward the type filter only when given, so the parent's default applies otherwise
        overrides = {} if asset_type_filter is None else {'asset_type_filter': asset_type_filter}
        super().__init__(option_right=OptionRight.CALL, **overrides)

class ChainFilters(AssetFilter):
    """Conjunction of filters: a security passes only if every filter accepts it.

    An empty collection of filters accepts everything.
    """

    def __init__(self, security_filters):
        # Materialise once: a one-shot iterable (e.g. a generator) would otherwise
        # be exhausted by the first ``eval``, making every later call pass.
        self._filters = tuple(security_filters)

    def eval(self, security):
        return all(f.eval(security) for f in self._filters)

class EmptyFilter(AssetFilter):
    """Pass-through filter: every security is accepted."""

    def eval(self, security):
        accepted = True
        return accepted
# region imports
from AlgorithmImports import *
from sortedcontainers import SortedList
from abc import ABC, abstractmethod
import copy
from enum import Enum
from logs import log_missing_option_in_chain, log_zero_greek
# endregion

class Direction(Enum):
    """Trade direction encoded as +1 (LONG) / -1 (SHORT)."""

    LONG = 1
    SHORT = -1

class ThresholdType(Enum):
    """Side of a threshold: ABOVE (+1) or BELOW (-1)."""

    ABOVE = 1
    BELOW = -1

class ExpiryConstraint(Enum):
    """Expiry selection mode: PRIOR (-1), CLOSEST (0) or POST (+1).

    Presumably relative to a target expiry; confirm against the callers.
    """

    PRIOR = -1
    CLOSEST = 0
    POST = 1

# for charting purpose
class GreeksType(Enum):
    """Charting option: show greeks as ABSOLUTE values or as PERCENTAGE."""

    ABSOLUTE = 1
    PERCENTAGE = 2

# for charting purpose
class PortfolioType(Enum):
    """Charting option: chart the LAST_TRADE only or the WHOLE_PORTFOLIO."""

    LAST_TRADE = 1
    WHOLE_PORTFOLIO = 2


class OptionsData_:
    """Holds the ``Symbol`` column of an options DataFrame."""

    def __init__(self, df: pd.DataFrame):
        symbol_column = df.Symbol
        self._options = symbol_column


class OptionsData:
    """Dict-like container mapping option symbols to per-contract data."""

    def __init__(self, contracts_dict: Optional[Dict[Symbol, Any]]=None):
        # A non-empty dict is stored as-is (not copied); None or {} gets a fresh dict.
        if contracts_dict: 
            self._options = contracts_dict
        else: 
            self._options: Dict[Symbol, Any] = {}

    @property
    def option_symbols(self):
        """Live keys view of the stored option symbols."""
        return self._options.keys()

    def __add__(self, other):
        """Return a new container with both mappings merged (``other`` wins on clashes)."""
        res = OptionsData()
        res._options = self._options | other._options

        return res

    def get_subset(self, symbols: Iterable):
        """Return a new container restricted to ``symbols`` (KeyError for unknown symbols)."""
        return OptionsData({symbol: self._options[symbol] for symbol in symbols})

    def __iter__(self):
        return iter(self._options)

    def __len__(self):
        return len(self.option_symbols)

    def __contains__(self, key):
        # O(1) dict lookup instead of the O(n) scan Python falls back to via __iter__
        return key in self._options

    def items(self):
        return self._options.items()

    def __getitem__(self, key):
        return self._options[key]

    def segregate_by_underlying(self):
        """Split the contracts into one ``OptionsData`` per underlying symbol.

        Returns a dict ``underlying -> OptionsData``. Both the underlyings and
        the contracts inside each group keep their first-seen order.
        """
        # Group into plain dicts first. Repeated ``+=`` on OptionsData would
        # copy the whole group on every insert, which is O(n^2).
        grouped = {}
        for key, value in self.items():
            grouped.setdefault(key.underlying, {})[key] = value

        return {underlying: OptionsData(contracts) for underlying, contracts in grouped.items()}


class TradeData:
    """Per-asset traded quantities of one trade, plus a time stamp and unwind state.

    Behaves like a small dict (``in``, iteration, indexing, ``items``/``keys``,
    ``len``). ``add`` merges another trade's quantities in place and keeps the
    later of the two time stamps; ``+`` does the same on a copy.
    """

    def __init__(
        self, 
        trade_data: Dict[Symbol, int] = None,
        time_stamp: Optional[datetime] = None
    ):
        # NOTE: a provided dict is kept by reference (not copied)
        self._trade_dict: Dict[Symbol, int] = trade_data if trade_data is not None else {}
        self._time_stamp = time_stamp
        self._unwind_dt: Optional[datetime] = None
        self._unwined: bool = False

    # --- properties --------------------------------------------------------

    @property
    def time_stamp(self):
        return self._time_stamp

    @property
    def unwined(self) -> bool:
        """Unwind flag (attribute name spelled 'unwined'; kept for compatibility)."""
        return self._unwined

    @unwined.setter
    def unwined(self, value: bool):
        self._unwined = value

    @property
    def unwind_date(self) -> datetime:
        return self._unwind_dt

    @unwind_date.setter
    def unwind_date(self, value: datetime):
        self._unwind_dt = value

    @property
    def trade_data(self):
        """The underlying symbol -> quantity dict (live, not a copy)."""
        return self._trade_dict

    # --- dict-like protocol ------------------------------------------------

    def __contains__(self, item):
        return item in self._trade_dict

    def __iter__(self):
        return iter(self._trade_dict)

    def __len__(self):
        return len(self._trade_dict)

    def __getitem__(self, key):
        return self._trade_dict[key]

    def __setitem__(self, key, value): 
        self._trade_dict[key] = value

    def __delitem__(self, key):
        del self._trade_dict[key]

    def items(self):
        return self._trade_dict.items()

    def keys(self):
        return self._trade_dict.keys()

    # --- combination -------------------------------------------------------

    def copy(self):
        """Shallow copy of the quantities and time stamp; unwind state is not copied."""
        return TradeData(self._trade_dict.copy(), self._time_stamp)

    def add(self, other):
        """Merge ``other``'s quantities into this trade in place and return ``self``.

        The resulting time stamp is the later of the two, ignoring ``None``.
        """
        book = self._trade_dict
        for symbol in other:
            if symbol not in book:
                book[symbol] = other[symbol]
            else:
                book[symbol] += other[symbol]

        theirs = other._time_stamp
        if theirs is not None:
            mine = self._time_stamp
            self._time_stamp = theirs if mine is None else max(mine, theirs)

        return self

    def __add__(self, other):
        combined = self.copy()
        combined.add(other)
        return combined


class TradeRecord:
    """A trade's ``TradeData`` together with an identifier and a time stamp.

    Equality and hashing are based on ``id`` only.
    """

    def __init__(self, id: Any, trade_data: TradeData, time_stamp: datetime):
        self._id = id 
        self._trade_data = trade_data
        # assigned through the ``time_stamp`` property setter
        self.time_stamp = time_stamp

    def __hash__(self): 
        return hash(self.id)

    def __eq__(self, other):
        """Compare by ``id``.

        Objects without an ``id`` attribute (e.g. ``None``) are not comparable.
        ``NotImplemented`` is returned so Python falls back to its default
        comparison instead of raising ``AttributeError``.
        """
        if not hasattr(other, 'id'):
            return NotImplemented
        return self.id == other.id

    def items(self):
        """Items view of the underlying trade data."""
        return self._trade_data.items()

    @property
    def time_stamp(self):
        return self._time_stamp

    @time_stamp.setter
    def time_stamp(self, value: datetime):
        self._time_stamp = value

    @property
    def trade_data(self):
        return self._trade_data

    @property
    def id(self):
        return self._id

    def __iter__(self):
        return iter(self.trade_data)



class TradeTimeStamp:
    """(trade id, time stamp) pair; TradeCollection sorts these by ``time_stamp``."""

    def __init__(self, id, time_stamp: datetime):
        # ``id`` shadows the builtin but is part of the public constructor signature
        self.id, self.time_stamp = id, time_stamp




class TradeCollection:
    """Container of trades keyed by id, with derived per-asset bookkeeping.

    Kept in sync on every add/delete:
      * ``_trades``: trade id -> trade
      * ``_time_stamps``: one TradeTimeStamp per trade, sorted by ``time_stamp``
      * ``_assets_trades``: asset id -> set of ids of the trades containing the asset
      * ``_positions``: asset id -> net quantity of the asset over all stored trades
    """

    # Absolute tolerance for the "net position is back to zero" check done when an
    # asset's last trade is removed. Quantities are floats (e.g. rescaled trade sizes);
    # adding and later subtracting them in a different order can leave a residual
    # such as 5.55e-17 instead of an exact 0, which an exact ``== 0`` check rejects.
    _ZERO_POSITION_TOL: float = 1e-6

    def __init__(self):
        self._trades: Dict[Any, Trade] = {}
        self._time_stamps: SortedList[TradeTimeStamp] = SortedList(key=lambda x: x.time_stamp)
        # asset_id -> set of trade ids that include the asset
        self._assets_trades: Dict[Symbol, Set[Any]] = {}
        # asset_id -> net position of the strategy in the asset
        self._positions: Dict[Symbol, Union[float, int]] = {}

    @property
    def positions(self):
        return self._positions

    def add(self, trade: Trade):
        """Store ``trade`` and fold its quantities into positions / asset index."""
        assert trade.id not in self._trades
        self._trades[trade.id] = trade
        self._time_stamps.add(TradeTimeStamp(id=trade.id, time_stamp=trade.time_stamp))

        for asset_id in trade.trade_data:
            if asset_id in self._positions:
                self._positions[asset_id] += trade.trade_data[asset_id]
                # ToDo: Sanity check. Delete.
                assert asset_id in self._assets_trades
                self._assets_trades[asset_id].add(trade.id)
            else:
                self._positions[asset_id] = trade.trade_data[asset_id]
                # ToDo: Sanity check. Delete.
                assert asset_id not in self._assets_trades
                self._assets_trades[asset_id] = {trade.id}

    def __getitem__(self, id):
        return self.trades[id]

    @property
    def trades(self):
        return self._trades

    def _del_trade_in_positions_and_assets_trades(self, trade):
        """Undo ``trade``'s contribution to ``_positions`` and ``_assets_trades``.

        An asset whose last trade is removed is dropped from both maps; its net
        position must then be zero up to ``_ZERO_POSITION_TOL``.
        """
        for asset_id in trade:
            self._assets_trades[asset_id].remove(trade.id)
            self._positions[asset_id] -= trade.trade_data[asset_id]
            if len(self._assets_trades[asset_id]) == 0:
                self._assets_trades.pop(asset_id)
                residual = self._positions[asset_id]
                assert abs(residual) <= self._ZERO_POSITION_TOL, \
                    f'TradeCollection: non-zero residual position {residual} for {asset_id}'
                self._positions.pop(asset_id)

    @property
    def asset_trades(self):
        return self._assets_trades

    def get_all_trades_with_asset(self, asset_id):
        return self.asset_trades[asset_id]

    def _remove_time_stamp(self, trade):
        """Delete the stored TradeTimeStamp entry belonging to ``trade``.

        TradeTimeStamp defines no ``__eq__``, so ``SortedList.remove`` with a freshly
        built instance never matches a stored one (object equality is identity) and
        raises ValueError. Locate the stored entry by its sort key and id instead.
        """
        probe = TradeTimeStamp(None, trade.time_stamp)
        lo = self._time_stamps.bisect_left(probe)
        hi = self._time_stamps.bisect_right(probe)
        for idx in range(lo, hi):
            if self._time_stamps[idx].id == trade.id:
                del self._time_stamps[idx]
                return

        # Fallback: the trade's time stamp may have been reassigned after insertion
        # (TradeRecord exposes a setter), so its key no longer matches the entry.
        for idx, stamp in enumerate(self._time_stamps):
            if stamp.id == trade.id:
                del self._time_stamps[idx]
                return

        raise ValueError(f'TradeCollection: no time stamp entry for trade {trade.id}')

    def delete_trade(self, id):
        """Remove trade ``id`` and update positions, asset index and time stamps."""
        assert id in self._trades
        trade = self._trades.pop(id)

        # adjust positions and assets->trades
        self._del_trade_in_positions_and_assets_trades(trade)

        # adjust time_stamps sorted list
        self._remove_time_stamp(trade)

    def __contains__(self, trade_id):
        return trade_id in self._trades

    def delete_all_trades_before_date(self, date: datetime):
        """Remove every trade whose time stamp is <= ``date``."""
        index = self._time_stamps.bisect_right(TradeTimeStamp(None, date))
        for x in self._time_stamps[:index]:
            trade = self.trades.pop(x.id)
            self._del_trade_in_positions_and_assets_trades(trade)

        del self._time_stamps[:index]
    

class ScheduleRule:
    """Immutable-by-convention pair of a date rule and a time rule for ``schedule.on``."""

    def __init__(self, date_rule, time_rule):
        self._time_rule = time_rule
        self._date_rule = date_rule

    @property
    def date_rule(self):
        """The date rule passed at construction."""
        return self._date_rule

    @property
    def time_rule(self):
        """The time rule passed at construction."""
        return self._time_rule



class TradeIdGeneratorMeta(type, ABC):
    """Metaclass that makes every class using it a per-class singleton.

    The first ``Cls(*args, **kwargs)`` call constructs and caches the instance;
    later calls return the cached instance and ignore their arguments.

    NOTE(review): ``gen_trade_id`` is declared abstract on the *metaclass*, so it
    does not appear to be enforced on the generator classes themselves; it only
    documents the expected interface.
    """

    # class -> its singleton instance (shared registry for all such classes)
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls not in registry:
            registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]

    @abstractmethod
    def gen_trade_id(self, trade_data: TradeData):
        """Return a new trade id for ``trade_data``."""

class BaseTradeIdGenerator(metaclass=TradeIdGeneratorMeta):
    """Common base for singleton trade-id generators (see TradeIdGeneratorMeta)."""

class IncrementIdGenerator(BaseTradeIdGenerator):
    """Hands out consecutive integer ids: ``start + 1``, ``start + 2``, ...

    ``trade_data`` is ignored by ``gen_trade_id``.
    """

    def __init__(self, start: int=0):
        self._current = start

    def gen_trade_id(self, trade_data):
        next_id = self._current + 1
        self._current = next_id
        return next_id

class TimeMeasure():
    """Collects timing samples and logs summary statistics, percentiles and a histogram.

    Args:
        class_name: label prefixed to every logged line.
        units: unit string appended to logged values (samples are not converted).
    """

    def __init__(self, class_name: str, units: str = 'seconds') -> None:
        self._class_name: str = class_name
        # +/- infinity so the first sample always becomes both min and max.
        # (sys.float_info.min is the smallest *positive* float, not the most
        # negative one, so it is not a valid start value for a running maximum.)
        self._min: float = float('inf')
        self._max: float = float('-inf')
        self._total: float = 0.
        self._count: int = 0
        self._units: str = units

        self._times: List[float] = []
        self._percentiles: List[int] = [0, 5, 10, 25, 50, 75, 90, 95, 99, 100]

    def is_ready(self) -> bool:
        return self._count != 0

    def update(self, delta_retrieval_time: float) -> None:
        self._times.append(delta_retrieval_time)

        self._count += 1
        self._total += delta_retrieval_time

        if delta_retrieval_time > self._max:
            self._max = delta_retrieval_time
        if delta_retrieval_time < self._min:
            self._min = delta_retrieval_time

    def _log_if_empty(self, algo: "QCAlgorithm") -> bool:
        """Log a notice and return True when no samples exist (stats are undefined)."""
        if self._times:
            return False
        algo.log(f"{self._class_name} - no records")
        return True

    def print_percentiles(self, algo: "QCAlgorithm") -> None:
        if self._log_if_empty(algo):
            return
        results = np.percentile(self._times, self._percentiles)

        algo.log(f"{self._class_name} - Percentiles")
        for p, v in zip(self._percentiles, results):
            algo.log(f"{p}% percentile = {v:.4f} {self._units}")

    def print_histogram(self, algo: "QCAlgorithm") -> None:
        if self._log_if_empty(algo):
            return
        # bin edges are the sample percentiles, so bin i spans percentile[i]..percentile[i+1]
        bounds = np.percentile(self._times, self._percentiles)
        counts, _ = np.histogram(self._times, bins=bounds)

        algo.log(f"{self._class_name} - Histogram")
        for i in range(len(counts)):
            algo.log(f"{self._percentiles[i]}–{self._percentiles[i+1]} percentile: {counts[i]} occurrences")

    def print_stats(self, algo: "QCAlgorithm") -> None:
        if self._log_if_empty(algo):
            return
        algo.log(
            f"{self._class_name} - MEAN: {(self._total / self._count):.4f} {self._units}, MIN: {self._min:.4f} {self._units}, MAX: {self._max:.4f} {self._units}, # of records: {len(self._times)}; TOTAL: {self._total:.4f} {self._units}"
        )

def mid_quote(security):
    """Return the mid price of ``security``, falling back to a one-sided quote.

    Zero (falsy) bid/ask prices are treated as missing.

    Raises:
        RuntimeError: if neither a bid nor an ask price is available.
    """
    bid, ask = security.bid_price, security.ask_price
    if bid and ask:
        return (bid + ask) / 2
    if bid:
        return bid
    if ask:
        return ask
    raise RuntimeError(
        f"mid_quote: no bid or ask price available for {getattr(security, 'symbol', security)}"
    )

def portfolio_cash(algo):
    """Cash held in ``algo``'s portfolio."""
    portfolio = algo.portfolio
    return portfolio.cash


def portfolio_value(algo):
    """Total value of ``algo``'s portfolio."""
    portfolio = algo.portfolio
    return portfolio.total_portfolio_value


def subscribe_to_option(algo, symbol, resolution):
    """Subscribe ``algo`` to the single option contract ``symbol``.

    Index-option contracts go through ``AddIndexOptionContract``; equity-option
    contracts through ``AddOptionContract``. Returns whatever that call returns.

    Raises:
        NotImplementedError: for any other underlying security type.
    """
    underlying_type = symbol.underlying.security_type
    if underlying_type == SecurityType.INDEX:
        return algo.AddIndexOptionContract(symbol=symbol, resolution=resolution)
    if underlying_type == SecurityType.EQUITY:
        return algo.AddOptionContract(symbol=symbol, resolution=resolution)
    raise NotImplementedError(
        f'subscribe_to_option: unsupported underlying security type {underlying_type} for {symbol}'
    )

# Morningstar sector code -> ticker of the ETF used to represent that sector.
# Codes not listed here are skipped by get_sector_market_caps.
# NOTE: insertion order matters - get_sector_market_caps builds its result dict by
# iterating these values, so reordering entries reorders downstream dicts.
SECTOR_ETF_MAP: Dict[int, str] = {
    MorningstarSectorCode.BASIC_MATERIALS : 'XLB', 
    MorningstarSectorCode.CONSUMER_CYCLICAL : 'XLY', 
    MorningstarSectorCode.FINANCIAL_SERVICES : 'XLF', 
    MorningstarSectorCode.REAL_ESTATE : 'XLRE', 
    MorningstarSectorCode.CONSUMER_DEFENSIVE : 'XLP', 
    MorningstarSectorCode.HEALTHCARE : 'XLV', 
    MorningstarSectorCode.UTILITIES : 'XLU', 
    MorningstarSectorCode.COMMUNICATION_SERVICES : 'XLC', 
    MorningstarSectorCode.ENERGY : 'XLE', 
    MorningstarSectorCode.INDUSTRIALS : 'XLI', 
    MorningstarSectorCode.TECHNOLOGY : 'XLK', 
}

def get_sector_market_caps(algo: QCAlgorithm, universe) -> Dict[Symbol, float]:
    """Sum constituent market caps of ``universe`` per sector ETF symbol.

    Every ETF in SECTOR_ETF_MAP appears in the result (0 when no member maps to it);
    members whose Morningstar sector code is not in the map are ignored.
    """
    caps: Dict[Symbol, float] = {}
    for etf_ticker in SECTOR_ETF_MAP.values():
        caps[algo.symbol(etf_ticker)] = 0

    for member in universe.members:
        fundamentals = member.value.fundamentals
        sector_code = fundamentals.asset_classification.morningstar_sector_code
        if sector_code not in SECTOR_ETF_MAP:
            continue
        etf_symbol = algo.symbol(SECTOR_ETF_MAP[sector_code])
        caps[etf_symbol] += fundamentals.market_cap

    return caps

def get_sector_weights(algo: QCAlgorithm, universe) -> Dict[Symbol, float]:
    """Return each sector ETF's share of the universe's total market cap.

    When the total market cap is 0 (e.g. an empty universe) every sector gets a
    weight of 0 instead of raising ZeroDivisionError.
    """
    market_cap_by_sector: Dict[Symbol, float] = get_sector_market_caps(algo, universe)

    total_market_cap: float = sum(market_cap_by_sector.values())
    if total_market_cap == 0:
        return {symbol: 0. for symbol in market_cap_by_sector}

    return {symbol: cap / total_market_cap for symbol, cap in market_cap_by_sector.items()}

def get_previous_trading_day(algo: QCAlgorithm, symbol: Union[str, Symbol], time: datetime) -> datetime:
    """Previous trading day before ``time`` according to ``symbol``'s exchange hours.

    ``symbol`` must already be subscribed (it is looked up in ``algo.securities``).
    """
    exchange_hours = algo.securities[symbol].exchange.hours
    return exchange_hours.get_previous_trading_day(time)

def calc_trade_expenditure(algo, trade_data: TradeData, contract_multiplier: int = 100) -> float:
    """Gross premium of ``trade_data``: ``sum(|qty| * price * contract_multiplier)``.

    Legs with a positive quantity are priced at the ask, all others at the bid.
    """
    def leg_cost(option) -> float:
        quantity = trade_data[option]
        security = algo.securities[option]
        fill_price = security.ask_price if quantity > 0 else security.bid_price
        return abs(quantity) * fill_price * contract_multiplier

    return sum(leg_cost(option) for option in trade_data)

def calc_trade_notional(algo, trade_data: TradeData, contract_multiplier: int = 100) -> float:
    """Absolute *net* underlying notional of ``trade_data``.

    Signed quantities are summed before taking the absolute value, so long and
    short legs on the same underlying offset each other.
    """
    net_notional = sum(
        trade_data[option] * contract_multiplier * algo.securities[option.underlying].price
        for option in trade_data
    )
    return abs(net_notional)

# NOTE version with precalculated delta
# TODO unify the way of getting delta with getting delta within delta hedge
def calc_trade_greek(
    algo: QCAlgorithm, 
    trade_data: TradeData, 
    contract_multiplier: int = 100, 
    greek_name: str = 'gamma', 
    percentage_terms: bool = True, 
    is_new_trade: bool = False, # TODO only for logging/debugging - remove later
    sender_str: str = '',  # TODO only for logging/debugging - remove later
) -> float:
    """Return the absolute aggregate ``greek_name`` exposure of ``trade_data``.

    Each leg contributes ``qty * greek * contract_multiplier * scale`` where, when
    ``percentage_terms`` is set, ``scale`` is ``0.01 * S**2`` for gamma and
    ``0.01 * S`` for any other greek (S = underlying price); otherwise ``scale`` is 1.
    NOTE(review): the ``0.01 * S`` scaling is only meaningful for delta - confirm for
    other greeks. Legs whose chain or contract is unavailable are logged and skipped.
    """
    sender: str = f'data.py.calc_trade_greek.{sender_str}'
    use_precalculated_delta: bool = True
    canonicals = {opt_symbol.canonical for opt_symbol in trade_data}
    if use_precalculated_delta:
        chains: Dict[Symbol, OptionChain] = {
            canonical: algo.option_chain(canonical) for canonical in canonicals
        }
    else:
        chains: Dict[Symbol, OptionChain] = {
            canonical: algo.current_slice.option_chains.get(canonical) for canonical in canonicals
        }

    total_greek: float = 0
    for opt_symbol in trade_data:
        opt_chain: OptionChain = chains.get(opt_symbol.canonical)
        if opt_chain is None:
            algo.log(f'data.py.calc_trade_greek.{sender_str} - {opt_symbol.canonical} is not in current_slice.option_chains')
            continue

        opt_contract = opt_chain.contracts.get(opt_symbol)
        if opt_contract is None:
            log_missing_option_in_chain(algo, sender, opt_symbol, False)
            continue

        greek_val: float = getattr(opt_contract.greeks, greek_name)
        if greek_val == 0:
            # report the greek actually being evaluated (was hard-coded to "gamma")
            log_zero_greek(algo, sender, opt_symbol, opt_contract, greek_name, False)

        underlying_price = algo.securities[opt_symbol.underlying].price
        if not percentage_terms:
            multiplier: float = 1.
        elif greek_name == 'gamma':
            multiplier = 0.01 * (underlying_price ** 2)
        else:
            multiplier = 0.01 * underlying_price

        quantity: float = trade_data[opt_symbol]
        total_greek += quantity * greek_val * contract_multiplier * multiplier

    return abs(total_greek)

# NOTE old version
# TODO if this version is used again, make sure we are evaluating the price model once for each option symbol and then look up each greek from the same model
# def calc_trade_greek(
    # algo: QCAlgorithm, 
#     trade_data: TradeData, 
#     contract_multiplier: int = 100, 
#     greek_name: str = 'gamma', 
#     percentage_terms: bool = True, 
#     is_new_trade: bool = False, # TODO only for logging/debugging - remove later
#     sender_str: str = '',  # TODO only for logging/debugging - remove later
# ) -> float:

#     total_greek: float = 0
#     for opt_symbol in trade_data:
#         opt_contract = algo.option_chain(opt_symbol.canonical).contracts.get(opt_symbol)

#         if opt_contract is None:
#             log_missing_option_in_chain(
#                 algo, 
#                 f'data.py.calc_trade_greek.{sender_str}', 
#                 opt_symbol, 
#                 is_new_trade
#             )
#             continue

#         price_model: OptionPriceModelResult = get_price_model_from_opt_symbol(algo, algo.securities[opt_contract])
#         quantity: float = trade_data[opt_symbol]

#         greek_val: float = getattr(price_model.greeks, greek_name)

#         if greek_val == 0:
#             log_zero_greek(
#                 algo, 
#                 f'data.py.calc_trade_greek.{sender_str}', 
#                 opt_symbol, 
#                 opt_contract, 
#                 greek_name, 
#                 is_new_trade
#             )

#         if greek_name == 'gamma':
#             multiplier: float = (0.01 * (algo.securities[opt_symbol.underlying].price ** 2)) if percentage_terms else 1.
#         else:
#             multiplier: float = (0.01 * algo.securities[opt_symbol.underlying].price) if percentage_terms else 1.
        
#         total_greek += quantity * greek_val * contract_multiplier * multiplier

#     return abs(total_greek)

def target_notional(
    algo: QCAlgorithm, 
    trade_data: TradeData, 
    holding_calculator: Callable, 
    notional_fraction_target: float = 2/52
) -> TradeData:
    """Scale ``trade_data`` so its net notional equals a fraction of a holding value.

    The target is ``holding_calculator(algo) * notional_fraction_target``; every
    quantity is multiplied by ``target / calc_trade_notional(algo, trade_data)``.
    """
    desired_notional: float = holding_calculator(algo) * notional_fraction_target
    scale: float = desired_notional / calc_trade_notional(algo, trade_data)

    scaled: TradeData = TradeData()
    for option in trade_data:
        scaled[option] = trade_data[option] * scale
    return scaled

def target_greek(
    algo: QCAlgorithm, 
    trade_data: TradeData, 
    greek_target_cash_terms: float, 
    greek_name: str = 'gamma'
) -> TradeData:
    """Rescale ``trade_data`` so its absolute ``greek_name`` exposure hits the target.

    Quantities are first normalised by the sum of absolute quantities (proxy trade),
    the proxy is scaled to the proxy target, and the result is mapped back.

    Raises:
        ValueError: if all quantities are zero, or the trade has zero ``greek_name``
            exposure, so that no finite scaling factor exists.
        AssertionError: if the rescaled trade misses the target by 0.5 or more.
    """
    # sum of all abs quantities of all options
    quantity_sum: float = abs(np.array(list(trade_data.trade_data.values()))).sum()
    if quantity_sum == 0:
        raise ValueError('target_greek: trade_data has no non-zero quantities to scale')

    # adjust individual quantities by the sum of quantities and create proxy portfolio
    proxy_res = {option: trade_data[option] / quantity_sum for option in trade_data}

    # greek_name must be passed by keyword: the third positional parameter of
    # calc_trade_greek is contract_multiplier.
    proxy_total_greek: float = calc_trade_greek(algo, proxy_res, greek_name=greek_name)
    if proxy_total_greek == 0:
        raise ValueError(
            f'target_greek: trade has zero {greek_name} exposure; cannot scale to {greek_target_cash_terms}'
        )
    proxy_greek_target: float = greek_target_cash_terms / quantity_sum
    proxy_coeff: float = proxy_greek_target / proxy_total_greek
    for option in proxy_res:
        proxy_res[option] *= proxy_coeff

    # map proxy trades back to real trades
    res: TradeData = TradeData()
    for option in trade_data:
        res[option] = proxy_res[option] * quantity_sum

    # assert that greek target is met (tolerance instead of exact round() equality,
    # which can never hold for a non-integer target)
    total_greek: float = calc_trade_greek(algo, res, greek_name=greek_name)
    assert abs(total_greek - greek_target_cash_terms) < 0.5, \
        f'total_greek = {total_greek}; target {greek_target_cash_terms}'

    return res

def filter_invested_options(
    algo: QCAlgorithm, 
    positions: Dict[Symbol, Union[float, int]],     # symbol -> quantity dictionary
    option_types: List[SecurityType], 
    underlying_types: List[SecurityType], 
    underlying_symbols: Optional[List[Symbol]] = None, 
    only_invested: bool = True, 
    option_right: Optional[OptionRight] = None, 
    direction: Optional[Direction] = None
) -> List[Symbol]:
    '''Return the option symbols in ``positions`` that pass every given criterion.

    A symbol is kept when:
      * its security type is in ``option_types`` and its underlying's type is in
        ``underlying_types``;
      * ``option_right`` (if given) matches;
      * ``direction`` (if given) matches the sign of its quantity in ``positions``
        (zero quantities are rejected in that case);
      * its underlying is in ``underlying_symbols`` (if given; a collection of symbols);
      * the portfolio holding is invested (when ``only_invested``);
      * it expires more than one day after ``algo.time``.
    '''

    result: List[Symbol] = []
    for opt_symbol in positions:
        if algo.securities[opt_symbol].type not in option_types:
            continue

        if algo.securities[opt_symbol.underlying].type not in underlying_types:
            continue

        if option_right is not None:
            if opt_symbol.id.option_right != option_right:
                continue

        if direction is not None:
            quantity: Union[float, int] = positions[opt_symbol]
            opt_direction: Direction = Direction.LONG if quantity > 0 else Direction.SHORT
            if opt_direction != direction or quantity == 0:
                continue

        if underlying_symbols is not None:
            if opt_symbol.underlying not in underlying_symbols:
                continue

        if only_invested:
            if not algo.portfolio[opt_symbol].invested:
                continue

        # TODO discuss: options expiring within one day are excluded
        if opt_symbol.id.date <= algo.time + timedelta(days=1):
            continue

        result.append(opt_symbol)

    return result

def get_price_model_from_opt_symbol(algo: QCAlgorithm, option_symbol) -> OptionPriceModelResult:
    """Evaluate ``option_symbol``'s price model for a contract stamped at ``algo.time``.

    NOTE(review): ``option_symbol`` must provide ``evaluate_price_model`` and is passed
    to ``OptionContract(...)``, so it is presumably an Option security rather than a
    Symbol - confirm against callers.
    """
    contract = OptionContract(option_symbol)
    contract.time = algo.time
    return option_symbol.evaluate_price_model(None, contract)

# NOTE used in BaseTradeManager_ at the moment for manual indicator update with recent TradeData
class EventSource:
    """Minimal observer: register callables with ``+=`` and invoke them all via ``emit``."""

    def __init__(self):
        self.listeners = []

    def __iadd__(self, listener):
        self.listeners += [listener]
        return self

    def emit(self, *args, **kwargs):
        """Call every registered listener, in registration order, with the given arguments."""
        for callback in self.listeners:
            callback(*args, **kwargs)
# region imports
from AlgorithmImports import *
from abc import ABC, abstractmethod
from data import *
from logs import log_missing_option_in_chain, log_zero_greek
import time as t
import sys
# endregion

class DeltaHedge(ABC):
    """Base class for scheduled delta hedges.

    Subclasses implement ``_calc_hedge`` to build the hedge TradeData. ``schedule()``
    registers ``_delta_hedge`` with the algorithm using ``schedule_rule``; each run
    forwards a non-empty hedge to ``trade_manager.make_hedge_trade``.

    Args:
        algo: the running algorithm.
        trade_manager: object providing ``make_hedge_trade``; may also be assigned
            later through the ``trade_manager`` property.
        schedule_rule: date/time rule pair used by ``schedule()``.
        multiplier: scale that subclasses apply to the computed delta.
        use_precalculated_delta: stored for subclasses to choose their chain source.
    """

    def __init__(
        self, 
        algo: QCAlgorithm, 
        trade_manager=None, 
        schedule_rule: ScheduleRule = None,
        multiplier: float = 1., 
        use_precalculated_delta: bool = False
    ) -> None:
        self._algo: QCAlgorithm = algo
        self._trade_manager = trade_manager
        self._schedule_rule = schedule_rule
        self._multiplier: float = multiplier
        self._use_precalculated_delta: bool = use_precalculated_delta

        # NOTE(review): leftover profiling hook - schedules the (currently no-op)
        # _print_measures on a hard-coded date using 'SPX' market hours.
        debug_date_rule = algo.date_rules.on(2025, 11, 26)
        debug_time_rule = algo.time_rules.before_market_close('SPX', 1)
        self._algo.schedule.on(debug_date_rule, debug_time_rule, self._print_measures)

    def _print_measures(self) -> None:
        """Profiling hook; intentionally does nothing at the moment."""

    def schedule(self) -> None:
        """Register the periodic hedge with the algorithm's scheduler."""
        rule = self._schedule_rule
        self._algo.schedule.on(rule.date_rule, rule.time_rule, self._delta_hedge)

    @property
    def trade_manager(self):
        return self._trade_manager

    @trade_manager.setter
    def trade_manager(self, value) -> None:
        self._trade_manager = value

    @abstractmethod
    def _calc_hedge(self) -> TradeData:
        """Return the hedge trade for the current time (may be empty)."""

    def _delta_hedge(self) -> None:
        hedge: TradeData = self._calc_hedge()
        if len(hedge.trade_data) == 0:
            return
        self._trade_manager.make_hedge_trade(hedge)
    
class EquityOptionsDeltaHedge(DeltaHedge):
    """Delta-hedges equity options in ``trades_hist`` by trading their underlying stocks.

    Per underlying, the target stock quantity is
    ``-multiplier * sum(delta * contract_multiplier * option quantity)``. Stock
    positions recorded in ``trades_hist`` whose underlying no longer backs any held
    option are liquidated.
    """

    def __init__(
        self, 
        algo: QCAlgorithm, 
        trade_manager=None, 
        schedule_rule: ScheduleRule = None, 
        multiplier: float = 1., 
        use_precalculated_delta: bool = False  # NOTE lagging version uses self.option_chains whereas non-lagging version uses current slice option chain (which is not precalculated)
    ) -> None:

        super().__init__(algo, trade_manager, schedule_rule, multiplier, use_precalculated_delta)

    def _calc_hedge(self) -> TradeData:
        assert self._trade_manager is not None

        trade_data: TradeData = TradeData(time_stamp=self._algo.time)
        delta_by_underlying: Dict[Symbol, float] = {}
        trade_hist: TradeCollection = self._trade_manager.trades_hist

        option_symbols: List[Symbol] = filter_invested_options(
            algo=self._algo, 
            positions=trade_hist.positions, 
            option_types=[SecurityType.OPTION], 
            underlying_types=[SecurityType.EQUITY]
        )

        # find the chain for each canonical
        canonicals = {opt_symbol.canonical for opt_symbol in option_symbols}
        if self._use_precalculated_delta:
            chains: Dict[Symbol, OptionChain] = {
                canonical: self._algo.option_chain(canonical) for canonical in canonicals
            }
        else:
            chains: Dict[Symbol, OptionChain] = {
                canonical: self._algo.current_slice.option_chains.get(canonical) for canonical in canonicals
            }

        for opt_symbol in option_symbols:
            opt_chain: OptionChain = chains.get(opt_symbol.canonical)
            if opt_chain is None:
                self._algo.log(f'EquityOptionsDeltaHedge._calc_hedge - {opt_symbol.canonical} is not in option_chain')
                continue

            opt_contract: OptionContract = opt_chain.contracts.get(opt_symbol)
            if opt_contract is None:
                log_missing_option_in_chain(
                    self._algo, 
                    'EquityOptionsDeltaHedge._calc_hedge', 
                    opt_symbol, 
                    False
                )
                continue

            delta: float = opt_contract.greeks.delta
            if delta == 0:
                log_zero_greek(
                    self._algo, 
                    'EquityOptionsDeltaHedge._calc_hedge', 
                    opt_symbol, 
                    opt_contract, 
                    "delta", 
                    False
                )

            # position delta expressed in shares of the underlying
            position_delta: float = (
                delta
                * self._algo.securities[opt_symbol].symbol_properties.contract_multiplier
                * self._algo.portfolio[opt_symbol].quantity
            )
            underlying_symbol: Symbol = opt_symbol.underlying
            if underlying_symbol not in delta_by_underlying:
                delta_by_underlying[underlying_symbol] = position_delta
            else:
                delta_by_underlying[underlying_symbol] += position_delta

        for underlying, delta in delta_by_underlying.items():
            quantity_to_hedge: float = -delta * self._multiplier
            quantity_to_hedge = quantity_to_hedge - self._algo.portfolio[underlying].quantity

            # skip adjustments smaller than one share (not tradable)
            if abs(quantity_to_hedge) >= 1.:
                trade_data.add(TradeData({underlying: quantity_to_hedge}, self._algo.time))

        # Underlyings that still back a held option keep their stock hedge, even when
        # no re-hedge order was produced above (adjustment below one share, or missing
        # chain/contract data). Checking only ``x not in trade_data`` would liquidate
        # an already correctly sized hedge and re-buy it on the next run.
        underlyings_with_options = {opt_symbol.underlying for opt_symbol in option_symbols}

        # equity symbols to liquidate: stock underlyings whose options are no longer held
        eq_to_liquidate: List[Symbol] = [
            x for x in trade_hist.positions.keys()
            if self._algo.securities[x].type == SecurityType.EQUITY
            and self._algo.portfolio[x].invested
            and x not in trade_data
            and x not in underlyings_with_options
        ]

        for eq_symbol in eq_to_liquidate:
            trade_data.add(TradeData({eq_symbol: -self._algo.portfolio[eq_symbol].quantity}, self._algo.time))

        return trade_data

class IndexDeltaHedge(DeltaHedge):
    """Delta-hedges index options on ``underlying_index`` with the future currently
    mapped by ``continuous_future``, closing the old contract when the mapping changes.
    """

    def __init__(
        self, 
        algo: QCAlgorithm, 
        underlying_index: Symbol, 
        continuous_future: Future, 
        trade_manager=None, 
        schedule_rule: ScheduleRule = None, 
        multiplier: float = 1., 
        use_precalculated_delta: bool = False  # NOTE lagging version uses self.option_chains whereas non-lagging version uses current slice option chain (which is not precalculated)
    ) -> None:

        self._continuous_future: Future = continuous_future
        self._underlying_index: Symbol = underlying_index
        self._current_future: Optional[Symbol] = self._continuous_future.mapped

        super().__init__(algo, trade_manager, schedule_rule, multiplier, use_precalculated_delta)

    def _calc_hedge(self) -> TradeData:
        assert self._trade_manager is not None

        trade_data: TradeData = TradeData(time_stamp=self._algo.time)

        if self._current_future is None:
            # pick up the mapping now; no hedge is traded during this run
            self._current_future = self._continuous_future.mapped
            self._algo.log('IndexDeltaHedge._calc_hedge - no mapped future to hedge')
            return trade_data

        if self._current_future != self._continuous_future.mapped:
            self._algo.log('IndexDeltaHedge._calc_hedge - self._current_fut != self._continuous_future.mapped')
            old_quantity = self._algo.portfolio[self._current_future].quantity
            # close the hedge held in the previously mapped contract (skip if already flat)
            if old_quantity != 0:
                trade_data.add(TradeData({self._current_future: -old_quantity}, self._algo.time))
            self._current_future = self._continuous_future.mapped
            if self._current_future is None:
                # mapping lost: only the old contract can be closed this run
                self._algo.log('IndexDeltaHedge._calc_hedge - no mapped future to hedge')
                return trade_data

        trade_hist: TradeCollection = self._trade_manager.trades_hist

        option_symbols: List[Symbol] = filter_invested_options(
            algo=self._algo, 
            positions=trade_hist.positions, 
            option_types=[SecurityType.INDEX_OPTION], 
            underlying_types=[SecurityType.INDEX], 
            underlying_symbols=[self._underlying_index]
        )

        # find the chain for each canonical
        canonicals = {opt_symbol.canonical for opt_symbol in option_symbols}
        if self._use_precalculated_delta:
            chains: Dict[Symbol, OptionChain] = {
                canonical: self._algo.option_chain(canonical) for canonical in canonicals
            }
        else:
            chains: Dict[Symbol, OptionChain] = {
                canonical: self._algo.current_slice.option_chains.get(canonical) for canonical in canonicals
            }

        total_delta: float = 0.
        for opt_symbol in option_symbols:
            opt_chain: OptionChain = chains.get(opt_symbol.canonical)
            if opt_chain is None:
                self._algo.log(f'IndexDeltaHedge._calc_hedge - {opt_symbol.canonical} is not in current_slice.option_chains')
                continue

            opt_contract = opt_chain.contracts.get(opt_symbol)
            if opt_contract is None:
                log_missing_option_in_chain(
                    self._algo, 
                    'IndexDeltaHedge._calc_hedge', 
                    opt_symbol, 
                    False
                )
                continue

            delta: float = opt_contract.greeks.delta
            if delta == 0:
                log_zero_greek(
                    self._algo, 
                    'IndexDeltaHedge._calc_hedge', 
                    opt_symbol, 
                    opt_contract, 
                    "delta", 
                    False
                )

            # position delta in index units
            total_delta += (
                delta
                * self._algo.securities[opt_symbol].symbol_properties.contract_multiplier
                * self._algo.portfolio[opt_symbol].quantity
            )

        future_multiplier = self._algo.securities[self._current_future].symbol_properties.contract_multiplier
        # Round to the nearest whole contract. Floor division would always round toward
        # -inf, i.e. toward a larger short, over-hedging long deltas and under-hedging
        # short deltas by up to one contract.
        target_quantity = round(-(total_delta * self._multiplier) / future_multiplier)
        quantity_to_hedge: float = target_quantity - self._algo.portfolio[self._current_future].quantity

        # skip fractional (not tradable) adjustments; a one-contract adjustment is tradable
        if abs(quantity_to_hedge) >= 1.:
            trade_data.add(TradeData({self._current_future: quantity_to_hedge}, self._algo.time))

        # TODO scale by futures delta, if needed. For now, only one future symbol is held at the specific time

        # NOTE only previous and new (rolled) symbol can be traded at the same time (old one can be liquidated)
        assert len(trade_data) <= 2

        return trade_data
# region imports
from AlgorithmImports import *
from data import *
from option_providers import *
from quantity_calculators import *
from filters import *
import time as t
from option_universe import *
import time
# endregion

class EnterStrategy(ABC):
    """Interface for components that propose an entry trade."""

    @abstractmethod
    def calc_trade(self) -> TradeData:
        """Return the trade to enter (an empty TradeData when there is nothing to do)."""

class BaseEnterStrategy(EnterStrategy):
    """Entry strategy that gathers the contracts of ``canonical_options``' chains, lets
    the option provider select from them, and sizes the selection with the quantity
    calculator."""

    def __init__(self, algo, 
                 option_provider: OptionProvider, 
                 quantity_calculator: QuantityCalculator, 
                 canonical_options: List,
                 ):
        self._algo = algo
        self._option_provider = option_provider
        self._quantity_calculator = quantity_calculator
        self._canonical_options = canonical_options
        # set through the ``trade_manager`` property
        self._trade_manager: Optional[TradeManager] = None

    @property
    def quantity_calculator(self):
        return self._quantity_calculator

    @property
    def option_provider(self):
        return self._option_provider

    @property
    def trade_manager(self):
        return self._trade_manager

    @trade_manager.setter
    def trade_manager(self, value):
        self._trade_manager = value

    def _validate(self):
        """Validation hook; currently a no-op."""

    def calc_trade(self) -> TradeData:
        """Return the sized entry trade, or an empty TradeData when nothing is selected."""
        options = OptionsData()
        for canonical in self._canonical_options:
            chain_contracts = self._algo.option_chain(canonical).contracts
            options += OptionsData({contract.key: contract.value for contract in chain_contracts})

        selected = self._option_provider.get_options(options)
        if len(selected) == 0:
            return TradeData()
        return self._quantity_calculator.calc_quantity(selected)
        
class TwoLegsBudgetEnterStrategy(EnterStrategy):
    '''
    Combine a main-leg entry strategy with a single "budget" leg.

    The main leg comes from ``main_leg_enter_strategy``. Its mid-price cost, minus
    ``budget_frac`` times the portfolio cash, is handed to a DynamicBudgetOptionProvider
    as its current budget. That provider picks exactly one budget-leg option from the
    pooled canonical chains, and ``budget_leg_quantity_calculator`` sizes it.
    '''
    def __init__(self, 
                algo, 
                main_leg_enter_strategy,
                budget_frac,
                budget_leg_option_provider, 
                budget_leg_quantity_calculator,
                canonical_options,
                lower_bound = 0.7, 
                upper_bound = 1.3,
                contract_multiplier: int = 100, 
                budget_leg_filter: Filter = None
                ):
        self._main_leg_enter_strategy = main_leg_enter_strategy
        self._budget_leg_option_provider = budget_leg_option_provider
        self._budget_leg_quantity_calculator = budget_leg_quantity_calculator
        self._algo = algo
        self._contract_multiplier = contract_multiplier
        self._budget_frac = budget_frac
        self._canonical_options = canonical_options
        self._budget_leg_filter = budget_leg_filter
        
        self._secondary_option_provider = DynamicBudgetOptionProvider(
            algo, 
            option_provider = self._budget_leg_option_provider, 
            quantity_calculator = self._budget_leg_quantity_calculator, 
            lower_bound=lower_bound, 
            upper_bound=upper_bound, 
            contract_multiplier=contract_multiplier)

    def _get_cost(self, trade_data: TradeData) -> Optional[float]:
        '''
        Mid-price cost of ``trade_data``.

        For each option, the cost is the mean of the last hourly ask and bid close,
        times the contract multiplier, times the traded quantity.

        Returns None, after logging, when any option has no usable bid/ask history.
        Previously this case logged and then crashed on the missing columns.
        '''
        cost = 0
        for option in trade_data: 
            data = self._algo.history(option, 15, resolution=Resolution.HOUR, fill_forward=True)
            if not hasattr(data, 'askclose') or not hasattr(data, 'bidclose') or len(data) == 0:
                self._algo.log(f'TwoLegsBudgetEnterStrategy._get_cost - no bid/ask history for {option}')
                return None
            mid_price = (data.askclose.values[-1] + data.bidclose.values[-1]) / 2
            cost += mid_price * self._contract_multiplier * trade_data[option]

        return cost

    def calc_trade(self) -> TradeData: 
        '''
        Main-leg trade, plus one budget-leg option when the budget leg is enabled.

        Returns an empty TradeData when the main leg is empty. The main leg is returned
        alone when the budget-leg filter rejects or when the main leg's cost cannot be priced.
        '''
        main_trade_data = self._main_leg_enter_strategy.calc_trade()
        if len(main_trade_data) == 0: 
            return TradeData()

        # filter on budget leg
        if self._budget_leg_filter is not None and not self._budget_leg_filter.filter():
            return main_trade_data

        cost = self._get_cost(main_trade_data)
        if cost is None:
            # Without a cost the budget leg cannot be sized; fall back to the main leg alone,
            # the same outcome as a filter rejection.
            return main_trade_data

        self._secondary_option_provider.current_budget = cost - self._budget_frac * self._algo.portfolio.cash
        options_pool = OptionsData()
        for canonical_option in self._canonical_options:
            options_pool += OptionsData({option.key: option.value for option in self._algo.option_chain(canonical_option).contracts})
        
        option_data = self._secondary_option_provider.get_options(options_pool)
        assert len(option_data) == 1

        trade_data = self._budget_leg_quantity_calculator.calc_quantity(option_data)
        main_trade_data += trade_data

        return main_trade_data


        

class MultiLeggedEnterStrategy(EnterStrategy):
    '''Merge the trades of several entry strategies into one TradeData (via TradeData.add).'''

    def __init__(self, legs: Iterable[EnterStrategy]):
        self._legs = legs

    def calc_trade(self) -> TradeData:
        '''Call each leg's calc_trade in order and add the results together.'''
        combined = TradeData()
        for leg_strategy in self._legs:
            leg_trade = leg_strategy.calc_trade()
            combined.add(leg_trade)
        return combined

class BaseDispersionEnterStrategy(EnterStrategy):
    '''
    Shared state for dispersion entry strategies: an index option leg plus single-name legs.

    The constructor only stores its arguments; subclasses implement calc_trade.
    '''
    def __init__(
        self, 
        algo: QCAlgorithm, 
        index_option_provider: OptionProvider, 
        single_stock_option_provider: OptionProvider, 
        single_stock_quantity_calculator: QuantityCalculator, 
        index_quantity_calculator: QuantityCalculator, 
        etf_symbol: str, 
        canonical_index_option: Symbol, 
        no_of_stocks: Optional[int] = None, 
        contract_multiplier = 100, 
        subscription_resolution: Resolution = Resolution.HOUR
    ) -> None:
    
        self._algo: QCAlgorithm = algo

        # TODO these two are only used in constructor of subclasses, do we keep em here?
        self._no_of_stocks: Optional[int] = no_of_stocks
        self._etf_symbol: str = etf_symbol
        
        self._index_quantity_calculator = index_quantity_calculator
        self._index_option_provider = index_option_provider
        self._single_stock_option_provider: OptionProvider = single_stock_option_provider
        self._single_stock_quantity_calculator: QuantityCalculator = single_stock_quantity_calculator
        # canonical (chain) symbol of the index option
        self._canonical_option: Symbol = canonical_index_option
        self._contract_multiplier: int = contract_multiplier
        self._resolution: Resolution = subscription_resolution
    
    @abstractmethod
    def calc_trade(self) -> TradeData:
        ...

class DispersionEnterStrategy(BaseDispersionEnterStrategy):
    '''
    Dispersion entry over the constituents of an ETF universe.

    The index leg comes from the canonical index option chain. One single-stock leg is
    built per universe member. All stock quantities are then multiplied by
    index_notional / stock_notional, both computed with calc_trade_notional.
    '''
    def __init__(
        self, 
        algo, 
        index_option_provider: OptionProvider, 
        single_stock_option_provider: OptionProvider, 
        single_stock_quantity_calculator: QuantityCalculator, 
        index_quantity_calculator: QuantityCalculator, 
        etf_symbol: str, 
        canonical_index_option: Symbol, 
        no_of_stocks: Optional[int] = None, 
        contract_multiplier = 100, 
        subscription_resolution: Resolution = Resolution.HOUR
    ) -> None:

        super().__init__(
            algo, index_option_provider, single_stock_option_provider, single_stock_quantity_calculator, index_quantity_calculator, etf_symbol, canonical_index_option, no_of_stocks, contract_multiplier, subscription_resolution
        )

        # algo.universe_settings.schedule.on(algo.date_rules.week_start())
        self._universe, self._etf = add_etf_universe(
            algo, etf_symbol=etf_symbol, no_of_stocks=no_of_stocks, resolution=subscription_resolution
        )

    def calc_trade(self) -> TradeData:
        '''
        Build the index leg and the notional-matched stock legs.

        Returns an empty TradeData when no stock leg is produced, or when the stock legs'
        notional is zero and therefore cannot be scaled to the index notional.
        '''
        index_pool = OptionsData({option.key: option.value for option in self._algo.option_chain(self._canonical_option).contracts})
        index_options = self._index_option_provider.get_options(index_pool)
        index_trade_data = self._index_quantity_calculator.calc_quantity(index_options)
        total_index_notional = calc_trade_notional(self._algo, index_trade_data, self._contract_multiplier)

        stock_trade_data = TradeData()
        for member in self._universe.members:
            symbol = member.key
            cur_options_pool = OptionsData({option.key: option.value for option in self._algo.option_chain(symbol).contracts})
            cur_trade_options = self._single_stock_option_provider.get_options(cur_options_pool)
            single_trade_data = self._single_stock_quantity_calculator.calc_quantity(cur_trade_options)
            stock_trade_data.add(single_trade_data)

        if len(stock_trade_data) == 0: 
            return stock_trade_data

        stock_notional = calc_trade_notional(self._algo, stock_trade_data, contract_multiplier=self._contract_multiplier)
        if stock_notional == 0:
            # Scaling would divide by zero; skip entering rather than trade an unhedged index leg.
            self._algo.log('DispersionEnterStrategy.calc_trade - stock notional is 0; cannot match index notional, no trade')
            return TradeData()
        coeff = total_index_notional / stock_notional

        for option in stock_trade_data: 
            stock_trade_data[option] *= coeff

        return stock_trade_data + index_trade_data

class SectorDispersionEnterStrategy(BaseDispersionEnterStrategy):
    '''
    Dispersion entry whose non-index leg trades options on sector ETFs.

    An ETF-constituents universe is added and handed to the single-stock quantity
    calculator. Each sector ETF is subscribed with RAW normalization.
    '''
    def __init__(
        self, 
        algo: QCAlgorithm, 
        index_option_provider: OptionProvider, 
        single_stock_option_provider: OptionProvider, 
        single_stock_quantity_calculator: QuantityCalculator, 
        index_quantity_calculator: QuantityCalculator, 
        etf_symbol: str, 
        canonical_index_option: Symbol, 
        no_of_stocks: Optional[int] = None, 
        sector_etfs: List[str] = ['XLB', 'XLY', 'XLF', 'XLRE', 'XLP', 'XLV', 'XLU', 'XLC', 'XLE', 'XLI', 'XLK'], 
        contract_multiplier = 100, 
        subscription_resolution: Resolution = Resolution.HOUR
    ):

        self._universe, _ = add_etf_universe(
            algo, etf_symbol=etf_symbol, no_of_stocks=no_of_stocks, resolution=subscription_resolution
        )

        single_stock_quantity_calculator.universe = self._universe

        super().__init__(
            algo, index_option_provider, single_stock_option_provider, single_stock_quantity_calculator, index_quantity_calculator, etf_symbol, canonical_index_option, no_of_stocks, contract_multiplier, subscription_resolution
        )
        
        # Copy so the instance never aliases the shared default list (or a caller's list).
        self._sector_etfs: List[str] = list(sector_etfs)

        for sector_etf in self._sector_etfs:
            eq: Equity = self._algo.add_equity(sector_etf, resolution=subscription_resolution)
            eq.set_data_normalization_mode(DataNormalizationMode.RAW)

    def calc_trade(self) -> TradeData:
        '''
        Index leg plus sector-ETF option legs.

        The sector legs are sized together in one calc_quantity call. Returns an empty
        TradeData when no sector-ETF leg is produced.
        '''
        index_options = OptionsData({option.key: option.value for option in self._algo.option_chain(self._canonical_option).contracts})
        index_options = self._index_option_provider.get_options(index_options)
        index_trade_data = self._index_quantity_calculator.calc_quantity(index_options)

        # append OptionsData from option provider
        trade_options = OptionsData()
        for sector_etf_ticker in self._sector_etfs:
            sector_etf: Symbol = self._algo.symbol(sector_etf_ticker)
            cur_options_pool: OptionsData = OptionsData({option.key: option.value for option in self._algo.option_chain(sector_etf).contracts})
            cur_trade_options: OptionsData = self._single_stock_option_provider.get_options(cur_options_pool)
            trade_options += cur_trade_options

        # calculate quantity for all the options positions at once, not separately by individual sector
        sector_etf_trade_data: TradeData = self._single_stock_quantity_calculator.calc_quantity(trade_options)
        
        if len(sector_etf_trade_data) == 0:
            return sector_etf_trade_data

        return sector_etf_trade_data + index_trade_data

# def calc_trade_quantity(algo, trade_data: TradeData, contract_multiplier=100):
#     total_notional = 0
#     wgts = []
#     for option in trade_data: 
#         total_notional += abs(trade_data[option]) 
#         wgts[option] =  contract_multiplier * algo.securities[option.underlying].price

#     return total_notional, wgts


# TODO move somewhere else
def add_etf_universe(algo, etf_symbol: str, resolution: Resolution, no_of_stocks: Optional[int] = None):
    '''
    Subscribe to an ETF and add a universe made of its constituents.

    Constituents without fundamental data are dropped. The rest are ordered by
    market cap, largest first, and truncated to the first ``no_of_stocks`` when
    that argument is given. Universe members use RAW normalization, asynchronous
    selection and ``resolution``.

    Args:
        algo: the running algorithm.
        etf_symbol: ticker of the ETF whose constituents form the universe.
        resolution: data resolution for the ETF and the universe members.
        no_of_stocks: optional cap on the number of selected constituents.

    Returns:
        Tuple of (universe returned by ``algo.add_universe``, the ETF's Symbol).
    '''
    etf = algo.add_equity(etf_symbol, resolution).symbol
    
    def _select_assets(constituents: list[ETFConstituentUniverse]) -> list[Symbol]:
        # Rank constituents that have fundamental data by market cap, largest first.
        market_cap_by_symbol = {c.symbol: c.market_cap for c in constituents if c.has_fundamental_data}
        selected = [symbol for symbol, _ in sorted(market_cap_by_symbol.items(), key=lambda x: x[1], reverse=True)]
        if no_of_stocks is not None: 
            selected = selected[:no_of_stocks]
        return selected

    # Add a universe of the ETF's constituents.
    settings = UniverseSettings(algo.universe_settings)
    settings.data_normalization_mode = DataNormalizationMode.RAW
    settings.asynchronous = True
    settings.resolution = resolution 
    etf_univ = algo.universe.etf(etf_ticker=etf, universe_settings=settings)
    universe = algo.add_universe(etf_univ, _select_assets)
    
    return universe, etf





class EnterStrategy_(ABC):
    '''
    Base class for entry strategies that draw options from OptionUniverse objects.

    Stores the algorithm, the optional option provider and quantity calculator, the
    universes to draw from, and a trade manager that is attached after construction.
    '''
    def __init__(
        self, 
        algo: QCAlgorithm, 
        option_provider: Optional[OptionProvider], 
        quantity_calculator: Optional[QuantityCalculator], 
        option_universe_list: List[OptionUniverse]
    ) -> None:

        self._algo: QCAlgorithm = algo
        self._option_universe_list: List[OptionUniverse] = option_universe_list
        self._option_provider: Optional[OptionProvider] = option_provider
        self._quantity_calculator: Optional[QuantityCalculator] = quantity_calculator

        # attached later via the trade_manager setter
        self._trade_manager: Optional[TradeManager] = None
        self._last_trade: Optional[TradeData] = None

    @property
    def quantity_calculator(self):
        '''Quantity calculator given at construction (may be None).'''
        return self._quantity_calculator

    @property
    def option_provider(self):
        '''Option provider given at construction (may be None).'''
        return self._option_provider

    @property
    def trade_manager(self):
        '''Trade manager attached to this strategy (None until set).'''
        return self._trade_manager

    @trade_manager.setter
    def trade_manager(self, value):
        self._trade_manager = value

    def _validate(self):
        pass

    @abstractmethod
    def calc_trade(self) -> TradeData:
        '''Return the TradeData to enter now.'''
    
class BaseEnterStrategy_(EnterStrategy_):
    '''
    Universe-based single-pool entry strategy.

    The options of all freshly updated universes are merged. The option provider then
    selects from them and the quantity calculator sizes the selection.
    '''
    def __init__(
        self, 
        algo: QCAlgorithm, 
        option_provider: Optional[OptionProvider], 
        quantity_calculator: Optional[QuantityCalculator], 
        option_universe_list: List[OptionUniverse]
    ) -> None:

        super().__init__(algo, option_provider, quantity_calculator, option_universe_list)

    def _calc_trade(self, options_data: OptionsData) -> TradeData:
        '''Select from ``options_data`` and size it; empty TradeData if nothing is selected.'''
        chosen = self._option_provider.get_options(options_data)
        if len(chosen) > 0:
            return self._quantity_calculator.calc_quantity(chosen)
        return TradeData()

    def calc_trade(self) -> TradeData:
        return self._calc_trade(self._get_universe_options_data())

    def _get_universe_options_data(self) -> OptionsData:
        '''
        Merge the options of every universe updated on or after the previous trading day
        into a single OptionsData. Stale universes are ignored.
        '''
        merged: OptionsData = OptionsData()
        fresh_universes = (
            universe for universe in self._option_universe_list
            if universe.update_dt.date() >= universe.get_previous_trading_day()
        )
        for universe in fresh_universes:
            merged += universe.universe_options_data
        return merged

class DispersionEnterStrategy_(EnterStrategy_):
    '''
    Dispersion entry built from exactly two universes.

    The first is a SymbolOptionUniverse for the index leg. The second is an
    ETFConstituentsOptionUniverse or StaticStocksOptionUniverse for the stock leg.
    Each universe sizes its own leg with its own quantity calculator.
    '''
    def __init__(
        self, 
        algo: QCAlgorithm, 
        option_provider: Optional[OptionProvider], 
        quantity_calculator: Optional[QuantityCalculator], 
        option_universe_list: List[OptionUniverse],
        entanglement_calc_fn: Optional[Callable], 
        contract_multiplier: float = 100
    ) -> None:

        super().__init__(algo, option_provider, quantity_calculator, option_universe_list)
        self._contract_multiplier: int = contract_multiplier
        # fn(algo, trade_data, contract_multiplier) -> float; used to scale the stock leg to the index leg
        self._entanglement_calc_fn: Callable = entanglement_calc_fn

        assert len(option_universe_list) == 2

        # find needed universes
        self._index_universe: OptionUniverse = next(iter([x for x in option_universe_list if type(x) == SymbolOptionUniverse]), None)
        self._stock_universe: OptionUniverse = next(iter([x for x in option_universe_list if type(x) in [ETFConstituentsOptionUniverse, StaticStocksOptionUniverse]]), None)
        assert self._index_universe is not None and self._stock_universe is not None
        
        # each universe has its own quantity calculator
        assert all(opt_universe.quantity_calculator is not None for opt_universe in [self._index_universe, self._stock_universe])

    def calc_trade(self) -> TradeData:
        '''
        Build the index leg and the (optionally entangled) stock leg as one trade.

        When ``entanglement_calc_fn`` is set and the stock leg is non-empty, stock
        quantities are multiplied by index_value / stock_value.

        The combined trade's unwind_date is the single common expiry of the index options.

        Returns an empty TradeData in three cases:
          - either universe was not updated on or after the previous trading day;
          - the index leg is empty, so there is no expiry to anchor the unwind date;
          - the stock entanglement value is 0, so the stock leg cannot be scaled.
        '''
        # universe options have not been updated recently
        universes = [self._index_universe, self._stock_universe]
        if not all(universe.update_dt.date() >= universe.get_previous_trading_day() for universe in universes):
            return TradeData()

        index_options: OptionsData = self._index_universe.universe_options_data
        stock_options: OptionsData = self._stock_universe.universe_options_data
        
        index_trade_data: TradeData = self._index_universe.quantity_calculator.calc_quantity(index_options)
        stock_trade_data: TradeData = self._stock_universe.quantity_calculator.calc_quantity(stock_options)

        # The unwind date comes from the index expiry; an empty index leg previously failed the expiry assert.
        if len(index_trade_data) == 0:
            return TradeData()

        # entangle two legs of the trade
        if self._entanglement_calc_fn is not None and len(stock_trade_data) != 0:
            index_entanglement_value: float = self._entanglement_calc_fn(self._algo, index_trade_data, self._contract_multiplier)
            stock_entanglement_value: float = self._entanglement_calc_fn(self._algo, stock_trade_data, self._contract_multiplier)
            if stock_entanglement_value == 0:
                self._algo.log(
                    f'DispersionEnterStrategy_.calc_trade - stock_entanglement_value is 0, cannot entangle legs; index_entanglement_value: {index_entanglement_value}; no trade'
                )
                return TradeData()
            coeff: float = index_entanglement_value / stock_entanglement_value

            for option in stock_trade_data:
                new_quantity: float = stock_trade_data[option] * coeff
                # compare magnitudes: short legs carry negative quantities
                if abs(new_quantity) < 1.:
                    self._algo.log(
                        f'DispersionEnterStrategy_.calc_trade - Absolute quantity is lower than 1 for {option} after entanglement with coeff: {coeff}; index_entanglement_value: {index_entanglement_value}; stock_entanglement_value: {stock_entanglement_value}'
                    )
                stock_trade_data[option] = new_quantity

        # set unwind_date date for the whole stock trade
        index_option_expirations: Set[datetime] = {symbol.id.date for symbol in index_trade_data.trade_data}
        assert len(index_option_expirations) == 1, "not all of the index options have the same expiry"
        unwind_date: datetime = next(iter(index_option_expirations))

        trade = index_trade_data + stock_trade_data
        trade.unwind_date = unwind_date

        return trade
    
class SectorDispersionEnterStrategy_(DispersionEnterStrategy_):
    '''
    DispersionEnterStrategy_ whose stock leg is sized by a SectorETFMarketCapQuantityCalculator.

    Adds an ETF-constituents universe for ``etf_symbol`` and hands it to that calculator.
    '''
    def __init__(
        self, 
        algo: QCAlgorithm, 
        option_provider: Optional[OptionProvider], 
        quantity_calculator: Optional[QuantityCalculator], 
        option_universe_list: List[OptionUniverse],
        entanglement_calc_fn: Optional[Callable], 
        contract_multiplier: float = 100, 
        subscription_resolution: Resolution = Resolution.HOUR, 
        etf_symbol: str = 'SPY',
    ) -> None:

        super().__init__(algo, option_provider, quantity_calculator, option_universe_list, entanglement_calc_fn, contract_multiplier)

        constituents_universe, _etf = add_etf_universe(
            algo, etf_symbol=etf_symbol, no_of_stocks=None, resolution=subscription_resolution
        )
        self._universe = constituents_universe

        stock_calculator = self._stock_universe.quantity_calculator
        # only this calculator type accepts a constituents universe
        assert type(stock_calculator) in [SectorETFMarketCapQuantityCalculator]
        stock_calculator.universe = self._universe

    def calc_trade(self) -> TradeData:
        '''Same trade construction as DispersionEnterStrategy_.'''
        return super().calc_trade()
# region imports
from AlgorithmImports import *
from abc import ABC, abstractmethod
from data import *
from asset_filters import *
from OptionIndicators.portfolio_value import PortfolioValueIndicator, PortfolioValueType
# endregion

class ExitStrategy(ABC):
    '''
    Base class for exit strategies; calc_trade returns the TradeData that closes positions.

    The trade manager may be given at construction or assigned later through the property.
    '''
    def __init__(self, algo, trade_manager):
        self._algo = algo
        self._trade_manager = trade_manager

    @property
    def trade_manager(self):
        '''Trade manager consulted by calc_trade (may be None until assigned).'''
        return self._trade_manager

    @trade_manager.setter
    def trade_manager(self, value):
        self._trade_manager = value

    @abstractmethod
    def calc_trade(self) -> TradeData:
        '''Return the TradeData needed to exit positions now.'''

class ExitByPortfolioValue(ExitStrategy):
    '''
    Track portfolio value and option protection value through a PortfolioValueIndicator.

    The indicator updates daily, one minute before the SPX close, and keeps 10 observations.
    calc_trade currently only exercises the indicator's accessors and never emits exit orders.
    '''

    def __init__(
        self, 
        algo: QCAlgorithm,
        underlying_symbols: List[Symbol], 
        option_right: Optional[OptionRight] = None, 
        option_direction: Optional[Direction] = None
    ) -> None:

        every_day = algo.DateRules.EveryDay('SPX')
        before_close = algo.TimeRules.BeforeMarketClose('SPX', 1)
        update_date_rule: ScheduleRule = ScheduleRule(every_day, before_close)

        indicator_kwargs = dict(
            algo=algo,
            underlying_symbols=underlying_symbols,
            observations_n=10,
            option_right=option_right,
            option_direction=option_direction,
            update_date_rule=update_date_rule,
            use_trade_manager=True,
        )
        self._portfolio_value_indicator: PortfolioValueIndicator = PortfolioValueIndicator(**indicator_kwargs)

        super().__init__(algo, None)

    @property
    def portfolio_value_indicator(self) -> PortfolioValueIndicator:
        '''Indicator holding the tracked value series.'''
        return self._portfolio_value_indicator

    def calc_trade(self) -> TradeData:
        '''
        Return an empty TradeData stamped with the current time.

        When the indicator is ready, every accessor is called once first.
        The results are discarded; this is an API smoke test.
        '''
        assert self._trade_manager is not None

        no_exits: TradeData = TradeData(time_stamp=self._algo.time)
        indicator = self._portfolio_value_indicator
        if not indicator.is_ready:
            return no_exits

        # NOTE functions test: results are intentionally unused for now
        value_types = (PortfolioValueType.INDICATOR, PortfolioValueType.OPTION, PortfolioValueType.PORTFOLIO)
        for value_type in value_types:
            indicator.latest_value(value_type)
        for value_type in value_types:
            indicator.value_by_index(value_type, -1)
        for value_type in value_types:
            indicator.values_by_type(value_type)
        indicator.values_by_index(-1)

        return no_exits

class ExitAtIndexExpiry(ExitStrategy):
    '''
    Unwind a trade's stock-option legs once its unwind date has been reached.

    Only trades with an unwind_date that are not yet marked unwound are considered.
    For each due trade, every stock option (equity underlying) that the portfolio still
    holds is closed with the negated quantity recorded in that trade. The trade is then
    marked unwound so it is not processed again.
    '''

    def __init__(self, algo: QCAlgorithm) -> None:
        super().__init__(algo, None)

    def calc_trade(self) -> TradeData:
        '''
        Return the exit TradeData for all due trades, stamped with the current time.

        When several due trades contain the same option, their unwind quantities are
        summed. Previously the last trade's quantity overwrote the earlier ones.
        '''
        assert self._trade_manager is not None

        exit_trade_data: TradeData = TradeData(time_stamp=self._algo.time)

        trades_hist = self._trade_manager.trades_hist
        trade_data_to_check: List[TradeData] = [
            t.trade_data for t in list(trades_hist.trades.values()) if t.trade_data.unwind_date is not None and t.trade_data.unwined == False
        ]
        
        for trade_data in trade_data_to_check:
            # trade should be unwined
            if trade_data.unwind_date <= self._algo.time:
                positions: Dict[Symbol, Union[float, int]] = trade_data.trade_data

                # find all invested option that have not expired yet
                invested_stock_options: List[Symbol] = filter_invested_options(
                    algo=self._algo, 
                    positions=positions, 
                    option_types=[SecurityType.OPTION], 
                    underlying_types=[SecurityType.EQUITY], 
                    only_invested=True
                )
                
                if len(invested_stock_options) != 0:
                    self._algo.log(f'ExitAtIndexExpiry.calc_trade - unwinding {len(invested_stock_options)} options; total trade position count: {len(positions)}; trade timestamp: {trade_data.time_stamp}; unwind date: {trade_data.unwind_date}; current time: {self._algo.time}')
                
                # unwind positions, accumulating across trades that share a contract
                for symbol in invested_stock_options:
                    quantity: float = positions[symbol]
                    if symbol in exit_trade_data:
                        exit_trade_data[symbol] = exit_trade_data[symbol] - quantity
                    else:
                        exit_trade_data[symbol] = -quantity
                
                # do not unwind again
                trade_data.unwined = True

        return exit_trade_data

class ExitByDeltaThreshold(ExitStrategy):
    '''
    Close option positions whose absolute delta crosses a threshold.

    Each canonical option of the trade manager that has a chain in the current slice is
    scanned. A contract qualifies when it is tracked in trades_hist.positions, has a
    non-zero holdings value and passes ``asset_filter``. A qualifying contract is closed,
    with the negated tracked position, when its |delta| crosses ``delta`` in the
    direction given by ``threshold_type``. If ``exit_related_trades`` is set, the other
    tradable assets of every trade containing that contract are closed as well.
    '''
    def __init__(self, 
        algo, 
        delta, 
        threshold_type: ThresholdType = ThresholdType.ABOVE, 
        asset_filter: AssetFilter = EmptyFilter(), 
        exit_related_trades: bool = False,
        trade_calculator=None):
        '''
        Args:
            algo: the running algorithm.
            delta: threshold compared against the absolute contract delta.
            threshold_type: multiplies both sides of the delta comparison in calc_trade
                (presumably +1 for ABOVE and -1 for BELOW, flipping the inequality — TODO confirm).
            asset_filter: per-symbol predicate (``eval``) restricting which contracts are checked.
                NOTE(review): the default EmptyFilter() is created once at definition time and
                shared by all instances; fine only if it is stateless.
            exit_related_trades: also close the other assets of trades that contain a triggering contract.
            trade_calculator: passed to ExitStrategy as the trade manager; it can be replaced
                later through the trade_manager property.
        '''

        super().__init__(algo, trade_calculator)
        self._delta = delta
        self._threshold_type = threshold_type
        self._asset_filter = asset_filter
        self._exit_related = exit_related_trades


    def calc_trade(self) -> TradeData:
        '''
        Return the exit TradeData for contracts whose delta crossed the threshold.

        NOTE(review): if two triggering contracts share a related trade, that trade's other
        assets are added twice. A related asset that later triggers on its own would also
        fail the ``not in trade_data`` assert. Whether ``TradeData.add`` accumulates is not
        visible here — confirm before relying on exit_related_trades.
        '''
        assert self._trade_manager is not None
        slice = self._algo.current_slice
        trades_hist = self._trade_manager.trades_hist
        trade_data = TradeData(time_stamp=self._algo.time)

        # reads the trade manager's private canonical option list
        for canonical_option in self._trade_manager._canonical_options:
            if canonical_option not in slice.option_chains:   
                continue
        
            option_contracts = slice.option_chains[canonical_option]
            for contract in option_contracts:
                # ToDo: Remove. Sanity check for particular exit. 
                # Every held contract must be tracked by trades_hist with a matching quantity.
                if (
                    contract.symbol in self._algo.portfolio 
                    and self._algo.portfolio[contract.symbol].holdings_value != 0
                ):
                    assert contract.symbol in trades_hist.positions 
                    assert trades_hist.positions[contract.symbol] == self._algo.portfolio[contract.symbol].quantity
                # Candidate for exit: tracked, currently held, and accepted by the asset filter.
                if (
                    contract.symbol in trades_hist.positions
                    and contract.symbol in self._algo.portfolio 
                    and self._algo.portfolio[contract.symbol].holdings_value != 0
                    and self._asset_filter.eval(contract.symbol)
                ):
                
                    if self._threshold_type * abs(contract.greeks.delta) > self._threshold_type * self._delta:
                        assert contract.symbol not in trade_data
                        #trade_dict[contract.symbol] = -trades_hist.positions[contract.symbol]
                        # close the whole tracked position of the triggering contract
                        trade_data.add(TradeData({contract.symbol: -trades_hist.positions[contract.symbol]}, self._algo.time))
                        if self._exit_related: 
                            related_trades = trades_hist.get_all_trades_with_asset(contract.symbol)
                            for trade in related_trades: 
                                for asset, quantity in trades_hist[trade].items():
                                    if asset != contract.symbol and self._algo.securities[asset].is_tradable:
                                        # NOTE(review): unlike the entry above, no time stamp is passed here
                                        trade_data.add(TradeData({asset: -quantity}))
                                        # NOTE(review): signed comparison; for short legs (negative quantity) this
                                        # check has the opposite meaning — confirm intent
                                        if self._algo.portfolio[asset].quantity > quantity: 
                                            self._algo.log('error')
                                        assert self._algo.portfolio[asset].quantity <= quantity
                                        assert trades_hist.positions[asset] == self._algo.portfolio[asset].quantity 

        return trade_data
        #return TradeData(trade_dict, self._algo.time)
from abc import ABC, abstractmethod
from AlgorithmImports import *

class Filter(ABC):
    '''Gate interface: filter() returns a truthy value when the gated action may proceed.'''

    @abstractmethod
    def filter(self):
        '''Return whether the gated action is allowed right now.'''

class VixThresholdFilter(Filter):
    '''
    Gate on the VIX level.

    filter() returns False while the VIX price is strictly above ``threshold`` and True otherwise.
    VIX is subscribed hourly with raw normalization.
    '''
    def __init__(self, algo, threshold):
        self._algo = algo
        self._threshold = threshold

        vix_security = algo.AddIndex("VIX", Resolution.Hour)
        self.vix_symbol = vix_security.Symbol
        algo.Securities[self.vix_symbol].SetDataNormalizationMode(DataNormalizationMode.Raw)

    def filter(self):
        vix_price = self._algo.Securities[self.vix_symbol].price
        return not (vix_price > self._threshold)

class Discrete_Indicator_Filter(Filter):
    '''
    Gate on a discrete indicator value.

    filter() returns False when ``indicator.value`` is one of ``values`` and True otherwise.
    '''
    def __init__(self, algo, values: List, indicator):
        self._algo = algo
        self._values = values
        self._indicator = indicator

    def filter(self):
        return self._indicator.value not in self._values

# region imports
from AlgorithmImports import *
# endregion

def log_missing_option_in_chain(algo: QCAlgorithm, sender_str: str, opt_symbol: Symbol, is_new_trade: bool) -> None:
    '''
    Log that ``opt_symbol`` is missing from algo.option_chains.

    The message includes DTE, moneyness, strike, underlying price, whether the contract
    is invested, and whether this is a new trade.
    '''
    info: Dict = _get_option_info(algo, opt_symbol)

    message_parts = [
        f'{sender_str} - {opt_symbol} not in algo.option_chains; ',
        f'DTE: {info["dte"]}; ',
        f'Moneyness: {info["moneyness"]} {info["moneyness_str"]}; ',
        f'Strike: {info["strike"]}; ',
        f'Underlying price: {info["underlying_price"]}; ',
        f'Invested: {info["is_invested"]}; ',
        f'Is new trade: {is_new_trade}',
    ]
    algo.log(''.join(message_parts))

def log_zero_greek(algo: QCAlgorithm, sender_str: str, opt_symbol: Symbol, opt_contract, greek_name: str, is_new_trade: bool) -> None:
    '''
    Log that the greek ``greek_name`` of ``opt_symbol`` evaluated to 0.

    The message includes the bid of ``opt_contract`` plus DTE, moneyness, strike,
    underlying price, invested flag and the new-trade flag.
    '''
    info: Dict = _get_option_info(algo, opt_symbol)
    bid_price: float = algo.securities[opt_contract].bid_price

    message_parts = [
        f'data.py.calc_trade_greek.{sender_str} - {opt_symbol} {greek_name} = 0; ',
        f'Bid: {bid_price}; ',
        f'DTE: {info["dte"]}; ',
        f'Moneyness: {info["moneyness"]} {info["moneyness_str"]}; ',
        f'Strike: {info["strike"]}; ',
        f'Underlying price: {info["underlying_price"]}; ',
        f'Invested: {info["is_invested"]}; ',
        f'Is new trade: {is_new_trade}',
    ]
    algo.log(''.join(message_parts))

def _get_option_info(algo: QCAlgorithm, opt_symbol: Symbol) -> Dict:
    '''
    Collect diagnostic details about an option contract for this module's logging helpers.

    Args:
        algo: running algorithm (current time, securities, portfolio).
        opt_symbol: option contract to describe.

    Returns:
        Dict with keys:
            'dte': whole days between the contract expiry and algo.time,
            'strike': strike price,
            'underlying_price': current underlying price,
            'is_invested': whether the portfolio holds the contract,
            'moneyness': strike / underlying price rounded to 2 decimals
                (NaN when the underlying price is 0),
            'moneyness_str': 'ITM' / 'OTM' / 'ATM' relative to the option right
                ('N/A' when the underlying price is 0).
    '''
    dte: int = (opt_symbol.id.date - algo.time).days
    strike: float = opt_symbol.id.strike_price
    underlying_price: float = algo.securities[opt_symbol.underlying].price
    is_invested: bool = algo.portfolio[opt_symbol].invested

    if underlying_price == 0:
        # The price can still be 0 before any underlying data has arrived.
        # This helper only feeds log messages, so report "unknown" instead of
        # raising ZeroDivisionError.
        moneyness: float = float('nan')
        moneyness_str: str = 'N/A'
    else:
        moneyness = round(strike / underlying_price, 2)
        if opt_symbol.id.option_right == OptionRight.CALL:
            # a call is out of the money when strike > spot
            moneyness_str = 'OTM' if moneyness > 1 else ('ITM' if moneyness < 1 else 'ATM')
        else:
            # a put is in the money when strike > spot
            moneyness_str = 'ITM' if moneyness > 1 else ('OTM' if moneyness < 1 else 'ATM')

    return {
        'dte': dte,
        'strike': strike,
        'underlying_price': underlying_price,
        'is_invested': is_invested,
        'moneyness': moneyness,
        'moneyness_str': moneyness_str,
    }
# region imports
from AlgorithmImports import *
from abc import ABC, abstractmethod
import bisect
import numpy as np
from data import *
from quantity_calculators import * 
from OptionPricing.option_pricing_model import IndexOptionPricingModel, FallbackIVStrategy
from regression_storage import VIXMoneynessRegressionStorage
import time
# endregion

class OptionProvider(ABC):
    '''
    Base class for option selection steps.

    A provider receives an OptionsData pool and returns the OptionsData it
    selects from it. Providers are composed by the chaining/union providers
    in this module.
    '''

    def __init__(self, algo):
        # algorithm handle, used by subclasses for prices, history and logging
        self._algo = algo

    @abstractmethod
    def get_options(self, options: OptionsData) -> OptionsData:
        '''Return the options this provider selects from `options`.'''
        ...


class PassThroughOptionProvider(OptionProvider):
    '''Identity provider: returns the input options object unchanged.'''

    def get_options(self, options: OptionsData) -> OptionsData:
        unchanged: OptionsData = options
        return unchanged

class StrikeRangeOptionProvider(OptionProvider):
    '''
    Keeps the options whose strike lies in the closed interval
    [lower_bound, upper_bound] (absolute strike levels).
    '''

    def __init__(self, algo, lower_bound: float, upper_bound: float):
        super().__init__(algo)
        self._lb = lower_bound
        self._ub = upper_bound

    def get_options(self, options: OptionsData) -> OptionsData:
        in_range = [
            symbol for symbol in options
            if self._lb <= symbol.id.strike_price <= self._ub
        ]
        return options.get_subset(in_range)


class StrikeRangeFromUnderlyingOptionProvider(OptionProvider):
    '''
    Keeps the options whose strike lies in
    [lower_fraction * spot, upper_fraction * spot], where spot is the current
    price of each option's own underlying.
    '''

    def __init__(self, algo, lower_fraction, upper_fraction):
        super().__init__(algo)
        self._lf = lower_fraction
        self._uf = upper_fraction

    def get_options(self, options: OptionsData):
        def _within_band(symbol) -> bool:
            spot: float = self._algo.securities[symbol.underlying].price
            return self._lf * spot <= symbol.id.strike_price <= self._uf * spot

        return options.get_subset([symbol for symbol in options if _within_band(symbol)])

class VIXMoneynessRegressionOptionProvider(OptionProvider):
    '''
    Selects the options whose strike is closest to (1 + m) * spot, where m is
    the moneyness returned by a VIXMoneynessRegressionStorage.

    Returns an empty OptionsData when there are no options or when the
    regression storage is not ready yet.
    '''

    def __init__(self, algo, regression_storage: VIXMoneynessRegressionStorage):
        super().__init__(algo)
        self._closest_strike_provider: ClosestStrikeOptionProvider = ClosestStrikeOptionProvider(algo, 1.)
        self._regression_storage: VIXMoneynessRegressionStorage = regression_storage

    def get_options(self, options: OptionsData):
        # nothing to choose from, or the regression cannot produce a moneyness yet
        if len(options) == 0 or not self._regression_storage.is_ready:
            return OptionsData()

        strike_fraction: float = 1. + self._regression_storage.get_moneyness()
        self._closest_strike_provider.set_fraction(strike_fraction)
        return self._closest_strike_provider.get_options(options)


class VIXMoneynessOptionProvider(OptionProvider):
    '''
    Selects the options whose strike is closest to (1 + m) * spot, where m is
    the value returned by `eval_function()` at selection time.
    '''

    def __init__(self, algo, eval_function: Callable):
        super().__init__(algo)
        self._eval_function: Callable = eval_function
        self._closest_strike_provider: ClosestStrikeOptionProvider = ClosestStrikeOptionProvider(algo, 1.)

    def get_options(self, options: OptionsData):
        if len(options) == 0:
            return OptionsData()

        strike_fraction: float = 1. + self._eval_function()
        self._closest_strike_provider.set_fraction(strike_fraction)
        return self._closest_strike_provider.get_options(options)

#Todo: Change name
class ClosestStrikeOptionProviderNoUnderlying(OptionProvider):
    '''
    Returns every option whose strike is closest to a fixed, absolute target strike.

    Unlike ClosestStrikeOptionProvider, the target is an absolute strike level,
    not a fraction of the underlying price, so no underlying price is read.
    When two strikes are equally distant from the target, the strike of the
    first such option (in the iteration order of `options`) wins, and all
    contracts at that strike (e.g. both call and put) are returned.
    '''

    def __init__(self, algo, target_strike):
        super().__init__(algo)
        self._target_strike = target_strike

    @property
    def strike(self):
        '''Absolute target strike.'''
        return self._target_strike

    @strike.setter
    def strike(self, value):
        self._target_strike = value

    def get_options(self, options: OptionsData):
        if len(options) == 0:
            return OptionsData()

        # sorted() is stable, so ties in distance keep their input order
        sorted_opt_list = sorted(options, key=lambda x: abs(x.id.strike_price - self._target_strike))
        strike = sorted_opt_list[0].id.strike_price

        # Collect every contract at the winning strike. The previous
        # "break on first different strike, then slice [0:i]" approach dropped
        # the last contract when all options shared one strike, and lost
        # same-strike contracts interleaved with an equidistant strike.
        same_strike = [opt for opt in sorted_opt_list if opt.id.strike_price == strike]
        return options.get_subset(same_strike)


class SegregateByUnderlyingOptionProvider(OptionProvider):
    '''
    Applies an inner provider separately to each underlying's group of options
    and accumulates the per-group results with OptionsData `+=`.
    '''

    def __init__(self, algo, option_provider: OptionProvider):
        super().__init__(algo)
        self._inner_provider = option_provider

    def get_options(self, options: OptionsData):
        groups = options.segregate_by_underlying()
        merged = OptionsData()
        for underlying_key in groups:
            merged += self._inner_provider.get_options(groups[underlying_key])
        return merged


class ClosestStrikeOptionProvider(OptionProvider):
    '''
    Returns every option whose strike is closest to `fraction` * underlying price.

    Note: all options are assumed to share one underlying. Its price is read
    from an arbitrary option of the input, so if several underlyings are
    included, no guarantee is given about which underlying price is used.
    '''

    # Added to each strike before measuring distance. When the target price
    # lies exactly midway between two strikes, the lower strike ends up
    # marginally closer, so the choice is deterministic.
    _TIE_BREAK_OFFSET: float = 0.001

    def __init__(self, algo, fraction):
        super().__init__(algo)
        self._fraction = fraction

    def set_fraction(self, fraction: float) -> None:
        '''Set the multiple of the underlying price used as the target strike.'''
        self._fraction = fraction

    def get_options(self, options: OptionsData):
        if len(options) == 0:
            return OptionsData()

        option = next(iter(options))
        target_price = self._algo.securities[option.underlying].price * self._fraction

        offset = self._TIE_BREAK_OFFSET
        sorted_opt_list = sorted(options, key=lambda x: abs((x.id.strike_price + offset) - target_price))
        strike = sorted_opt_list[0].id.strike_price

        # Collect every contract at the winning strike. The previous
        # "break on first different strike, then slice [0:i]" approach dropped
        # the last contract when all options shared one strike.
        same_strike = [opt for opt in sorted_opt_list if opt.id.strike_price == strike]
        return options.get_subset(same_strike)


class ClosestDeltaOptionsProvider(OptionProvider):
    '''
    Selects the options whose absolute model delta is closest to a target delta.

    Each candidate is subscribed via `subscribe_to_option`, and its delta is
    taken from the security's own price model. Options with an absolute delta
    below `threshold` are never selected.
    '''
    # ToDo: Rewrite class

    def __init__(self, algo: QCAlgorithm, delta: float, threshold=0, resolution=None):
        super().__init__(algo)
        self._target_delta = delta
        self._resolution = resolution or Resolution.MINUTE
        self._threshold = threshold

    def _get_option_by_delta(self, options: Dict[Symbol, Option]):
        '''
        Return the list of symbols sharing the absolute delta closest to the
        target, or None when no option reaches the threshold.
        '''
        symbols_by_delta = {}

        for symbol, security in options.items():
            contract = OptionContract(security)
            contract.time = self._algo.time
            abs_delta = abs(security.evaluate_price_model(None, contract).greeks.delta)

            # Only deltas at or above the threshold are ever stored as keys,
            # so a delta that already has a bucket satisfies the threshold too.
            if abs_delta >= self._threshold:
                symbols_by_delta.setdefault(abs_delta, []).append(symbol)

        if not symbols_by_delta:
            return None

        # min() returns the first minimal key in insertion order, like sorted()[0]
        best_delta = min(symbols_by_delta, key=lambda d: abs(d - self._target_delta))
        self._algo.Log(f'The delta difference that we get is {abs(best_delta-self._target_delta)}')

        return symbols_by_delta[best_delta]

    def get_options(self, options: OptionsData):
        # ToDo: rewrite not to have legacy code names.
        subscribed = {
            symbol: subscribe_to_option(self._algo, symbol, self._resolution)
            for symbol in options
        }

        chosen = self._get_option_by_delta(subscribed)
        if chosen is None:
            self._algo.log('No deltas')
            return OptionsData()

        return options.get_subset(chosen)



# NOTE re-write of ClosestDeltaOptionsProvider 
class ClosestDeltaOptionsProvider_(OptionProvider):
    '''
    Rewrite of ClosestDeltaOptionsProvider that prices options with
    IndexOptionPricingModel instead of subscribing to every contract.

    For each option it fetches a historical price, derives a Black implied
    volatility (falling back to call/put-parity IV), and computes the absolute
    Black delta. Options whose price or IV cannot be obtained are skipped.
    Options with an absolute delta below `threshold` are never selected.
    Timing statistics are logged on every call.
    '''

    def __init__(self, algo: QCAlgorithm, delta: float, threshold: float=0, resolution: Optional[Resolution]=None):
        super().__init__(algo)
        self._target_delta: float = delta
        # `resolution` is accepted for signature compatibility with
        # ClosestDeltaOptionsProvider but is not used by this implementation.
        self._threshold: float = threshold

        self._pricing_model: IndexOptionPricingModel = IndexOptionPricingModel(self._algo)

    def _log_timing_stats(self, label: str, samples: List[float]) -> None:
        '''Log count/mean/max/min of `samples` (seconds); tolerates an empty list.'''
        if not samples:
            # max()/min() raise ValueError on an empty sequence, which used to
            # abort selection whenever no option could be priced.
            self._algo.log(f"{label} times (count 0) - no samples")
            return
        self._algo.log(f"{label} times (count {len(samples)}) - mean: {np.mean(samples):.4f} seconds, max: {max(samples):.4f}, min: {min(samples):.4f}")

    def _get_option_by_delta(self, options: OptionsData):
        '''
        Return the list of symbols sharing the absolute delta closest to the
        target, or None when no option could be priced above the threshold.
        '''
        deltas_dict = {}
        rfr: float = self._pricing_model.get_risk_free_rate()

        # timers (only fully processed options contribute samples)
        all_options_time = time.time()
        option_times = []
        price_times = []
        delta_times = []

        for opt_symbol in options:
            option_time = time.time()

            spot_price: float = self._algo.securities[opt_symbol.underlying].price

            price_time = time.time()
            opt_price: Optional[float] = self._pricing_model.get_history_price(opt_symbol)
            if opt_price is None:
                self._algo.log(f'ClosestDeltaOptionsProvider_._get_option_by_delta - cannot fetch option {opt_symbol} price')
                continue
            price_time = time.time() - price_time

            delta_time = time.time()

            dividends: float = self._pricing_model.get_dividends(opt_symbol)
            discount_factor: float = self._pricing_model.get_discount_factor(rfr, dividends, opt_symbol.id.date)
            forward_price: float = self._pricing_model.get_forward_price(spot_price, discount_factor)

            iv: Optional[float] = self._pricing_model.bs_iv(
                option_symbol=opt_symbol,
                option_price=opt_price,
                forward_price=forward_price,
                evaluation_dt=self._algo.time,
                fallback_iv_strategy=FallbackIVStrategy.CALL_PUT_PARITY_IV,
                discount_factor=discount_factor
            )
            if iv is None or np.isnan(iv):
                self._algo.log(f'ClosestDeltaOptionsProvider_._get_option_by_delta - cannot get option IV for {opt_symbol}')
                continue

            delta: float = abs(self._pricing_model.black_delta(
                option_symbol=opt_symbol,
                forward_price=forward_price,
                implied_vol=iv,
                discount_factor=discount_factor
            ))

            # keys are only created for deltas >= threshold, so an existing
            # bucket already satisfies the threshold
            if delta in deltas_dict:
                deltas_dict[delta].append(opt_symbol)
            elif delta >= self._threshold:
                deltas_dict[delta] = [opt_symbol]

            delta_time = time.time() - delta_time
            option_time = time.time() - option_time

            option_times.append(option_time)
            price_times.append(price_time)
            delta_times.append(delta_time)

        all_options_time = time.time() - all_options_time

        # timer logs
        self._algo.log(f"Total time: {all_options_time:.4f} seconds")
        self._log_timing_stats("Option", option_times)
        self._log_timing_stats("Price", price_times)
        self._log_timing_stats("Delta", delta_times)

        if len(deltas_dict) == 0:
            return None

        deltas = sorted(deltas_dict.keys(), key=lambda x: abs(x - self._target_delta))
        self._algo.Log(f'The delta difference that we get is {abs(deltas[0]-self._target_delta)}')

        return deltas_dict[deltas[0]]

    def get_options(self, options: OptionsData):
        symbols = self._get_option_by_delta(options)
        if symbols is None:
            self._algo.log(f'No deltas')
            return OptionsData()

        return options.get_subset(symbols)



class UnionOptionProvider(OptionProvider):
    '''
    Runs every provider on the same input and accumulates their results with
    OptionsData `+=`, in provider order.
    '''

    def __init__(self, algo: QCAlgorithm, option_providers: Iterable[OptionProvider]):
        super().__init__(algo)
        # Materialised once. A one-shot iterable (e.g. a generator) would
        # otherwise be exhausted by the first get_options call, making every
        # later call silently return an empty result.
        self._option_providers = tuple(option_providers)

    def get_options(self, options: OptionsData) -> OptionsData:
        # ToDo: improve OptionData to include leg -> options data correspondence
        res = OptionsData()
        for provider in self._option_providers:
            res += provider.get_options(options)

        return res



class PutsCallsOptionProvider(OptionProvider):
    '''Keeps only the options of a single right (OptionRight.PUT or OptionRight.CALL).'''

    def __init__(self, algo: QCAlgorithm, right):
        super().__init__(algo=algo)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if right != OptionRight.PUT and right != OptionRight.CALL:
            raise ValueError(f'PutsCallsOptionProvider.__init__ - right must be OptionRight.PUT or OptionRight.CALL, got: {right}')
        self._option_right = right

    def get_options(self, options: OptionsData) -> OptionsData:
        # snake_case `option_right`, consistent with the rest of the module
        return options.get_subset([option for option in options if option.id.option_right == self._option_right])



class PutsOptionProvider(PutsCallsOptionProvider):
    '''Keeps only put options.'''

    def __init__(self, algo):
        super().__init__(algo=algo, right=OptionRight.PUT)



class CallsOptionProvider(PutsCallsOptionProvider):
    '''Keeps only call options.'''

    def __init__(self, algo):
        super().__init__(algo=algo, right=OptionRight.CALL)



class ClosestExpiryOptionProvider(OptionProvider):
    '''
    Selects all options sharing the expiry whose calendar days-to-expiry (DTE)
    is closest to `days_to_expiry`.

    Expiries less than 1 calendar day away are ignored. `expiry_constraint`
    further restricts the candidates:
        PRIOR   - only expiries with DTE <= days_to_expiry,
        POST    - only expiries with DTE >= days_to_expiry,
        CLOSEST - no extra restriction.
    Returns an empty OptionsData when no expiry qualifies.
    '''

    def __init__(self,
                 algo: QCAlgorithm,
                 days_to_expiry: int,
                 expiry_constraint: ExpiryConstraint = ExpiryConstraint.CLOSEST):

        super().__init__(algo)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(days_to_expiry, int) or days_to_expiry < 1:
            raise ValueError(f'ClosestExpiryOptionProvider.__init__ - days_to_expiry must be an int >= 1, got: {days_to_expiry}')
        self._days_to_expiry = days_to_expiry
        self._expiry_constraint = expiry_constraint

    def _is_eligible(self, dte: int) -> bool:
        '''Whether an expiry `dte` calendar days away passes the minimum-DTE and constraint checks.'''
        if dte < 1:
            return False
        if self._expiry_constraint == ExpiryConstraint.PRIOR:
            return dte <= self._days_to_expiry
        if self._expiry_constraint == ExpiryConstraint.POST:
            return dte >= self._days_to_expiry
        if self._expiry_constraint == ExpiryConstraint.CLOSEST:
            return True
        raise NotImplementedError(f'ClosestExpiryOptionProvider - expiry constraint {self._expiry_constraint} is not implemented')

    def get_options(self, options: OptionsData) -> OptionsData:

        if len(options) == 0:
            return OptionsData()

        today = self._algo.time.date()
        days_to_expiry = self._days_to_expiry
        expiry_symbol_dict = {}

        # Group options by expiry. Eligibility depends only on the expiry, so
        # it is decided once, for the first option seen with that expiry.
        for symbol in options:
            expiry_date = symbol.id.date
            if expiry_date in expiry_symbol_dict:
                expiry_symbol_dict[expiry_date].append(symbol)
            elif self._is_eligible((expiry_date.date() - today).days):
                expiry_symbol_dict[expiry_date] = [symbol]

        if not expiry_symbol_dict:
            # previously an IndexError on dates[0]
            self._algo.log('ClosestExpiryOptionProvider.get_options - no expiry satisfies the expiry constraints')
            return OptionsData()

        # min() keeps the first expiry (insertion order) on DTE-distance ties, like sorted()[0]
        closest_expiry = min(expiry_symbol_dict, key=lambda x: abs((x.date() - today).days - days_to_expiry))

        # Only logged, not rejected: a single option at the chosen expiry can be
        # legitimate, but it was flagged by earlier code as a possible data issue.
        if len(expiry_symbol_dict[closest_expiry]) < 2:
            self._algo.log('only one option to trade')

        return options.get_subset(expiry_symbol_dict[closest_expiry])
 


class ChainOptionProviders(OptionProvider):
    '''
    Applies providers in sequence: each provider receives the output of the
    previous one, and the last provider's output is returned.
    '''

    def __init__(self, algo: QCAlgorithm, option_providers: Iterable[OptionProvider]):

        super().__init__(algo)
        # Materialised once. A one-shot iterable (e.g. a generator) would
        # otherwise be exhausted by the first get_options call, and every
        # later call would silently return its input unfiltered.
        self._option_providers = tuple(option_providers)

    def get_options(self, options: OptionsData) -> OptionsData:
        for provider in self._option_providers:
            options = provider.get_options(options)

        return options
        
    
# NOTE(review): disabled draft of a multi-legged provider, kept as a bare string
# literal, so it is never executed. If revived, fix its attribute names:
# __init__ sets `_options_providers`, while __len__ reads `_option_providers`
# and get_options reads `_providers`. It also never calls super().__init__.
'''
    class SequentialMultileggedOptionProvider(OptionProvider):
        def __init__(self, option_providers: List[OptionProvider]):
            self._options_providers = option_providers
        
        def __len__(self):
            return len(self._option_providers)

        def get_options(self, options):
            return [provider.get_options(options) for provider in self._providers]
'''

class BudgetOptionProvider(OptionProvider):
    '''
    Picks the single option whose position cost fits just under a budget.

    The inner provider supplies a pool of options that must share underlying,
    expiry and right. A quantity is computed once, for the first option of the
    pool, via `quantity_calculator`. Each candidate's cost is then taken as

        (last hourly ask close + last hourly bid close) / 2 * contract_multiplier * |quantity|

    Candidates are restricted to strikes strictly inside
    (lower_bound * spot, upper_bound * spot) and ordered cheapest-first, by
    strike: descending for calls, ascending for puts. The most expensive
    candidate whose cost is below `current_budget` is returned. If every
    candidate costs at least the budget, the cheapest one is returned.

    The budget is `budget_frac` * portfolio cash.
    '''

    def __init__(self,
                 algo,
                 option_provider: OptionProvider,
                 quantity_calculator: QuantityCalculator,
                 budget_frac: float,
                 lower_bound = 0.7,
                 upper_bound = 1.3,
                 contract_multiplier: int = 100
                ):

        super().__init__(algo)
        self._budget = budget_frac
        self._contract_multiplier = contract_multiplier
        self._lb = lower_bound
        self._ub = upper_bound
        self._option_provider = option_provider
        self._quantity_calculator = quantity_calculator

    def _validate_options(self, options):
        '''Raise RuntimeError unless all options share underlying, expiry and right.'''
        if len(options) == 0:
            return

        opt = next(iter(options))
        underlying = opt.underlying

        for option in options:
            if option.underlying != underlying:
                raise RuntimeError('all options in pool of options should have same underlying')
            if option.id.date != opt.id.date:
                raise RuntimeError('all option in pool of options should have same expiry')
            if option.id.option_right != opt.id.option_right:
                raise RuntimeError('all option in pool of options should have same right')

    @property
    def current_budget(self):
        '''Budget available for the selected position: budget fraction * portfolio cash.'''
        return self._budget * self._algo.portfolio.cash

    def get_options(self, options: Optional[OptionsData] = None) -> OptionsData:
        options_pool = self._option_provider.get_options(options)
        if len(options_pool) < 1:
            # OptionsData, matching the annotated return type (was TradeData)
            return OptionsData()

        opt = next(iter(options_pool))
        self._validate_options(options_pool)
        underlying_price = self._algo.securities[opt.underlying].price
        trade_data = self._quantity_calculator.calc_quantity(options_pool.get_subset({opt}))
        quantity = abs(trade_data[opt])
        search_symbols = [
            s for s in options_pool
            if self._lb * underlying_price < s.id.strike_price < self._ub * underlying_price
        ]
        if not search_symbols:
            # previously an IndexError on search_symbols[0]
            self._algo.log('BudgetOptionProvider.get_options - no options with strike inside the search bounds')
            return OptionsData()

        # Order cheapest first: calls get cheaper as the strike rises, puts as it falls
        is_call = search_symbols[0].id.option_right == OptionRight.CALL
        ordered_symbols = sorted(search_symbols, key=lambda x: x.id.strike_price, reverse=is_call)

        target_budget = self.current_budget

        # bisect probes the same symbols that are priced again below; cache the
        # costs so each contract triggers at most one history request
        cost_cache = {}

        def key_func(symbol):
            if symbol in cost_cache:
                return cost_cache[symbol]
            data = self._algo.history(symbol, 15, resolution=Resolution.HOUR, fill_forward=True)
            if not hasattr(data, 'askclose') or not hasattr(data, 'bidclose') or len(data) == 0:
                # Previously logged and then crashed on the missing columns.
                # Treat an unpriceable contract as unaffordable instead.
                self._algo.log('no data')
                cost = np.inf
            else:
                cost = (data.askclose.values[-1] + data.bidclose.values[-1]) / 2 * self._contract_multiplier * quantity
            cost_cache[symbol] = cost
            return cost

        # first position whose cost is >= the budget
        index = bisect.bisect_left(ordered_symbols, target_budget, key=key_func)

        if index > 0:
            res_index = index - 1
            lower_cost = key_func(ordered_symbols[res_index])
        else:
            # every candidate costs at least the budget: fall back to the cheapest one
            res_index = 0
            lower_cost = -np.inf

        upper_cost = key_func(ordered_symbols[index]) if index < len(ordered_symbols) else np.inf

        # sanity check that costs increase along ordered_symbols, as bisect requires
        assert lower_cost <= target_budget <= upper_cost

        # Build the result from the pool actually searched. `options` may be None,
        # which is the default.
        return options_pool.get_subset({ordered_symbols[res_index]})



class DynamicBudgetOptionProvider(BudgetOptionProvider):
    '''
    BudgetOptionProvider whose budget is an absolute amount assigned from
    outside via `current_budget`, rather than a fraction of portfolio cash.

    The budget is None until it is assigned.
    '''

    def __init__(self,
                 algo,
                 option_provider: OptionProvider,
                 quantity_calculator: QuantityCalculator,
                 lower_bound = 0.7,
                 upper_bound = 1.3,
                 contract_multiplier: int = 100
                ):
        super().__init__(
            algo,
            option_provider=option_provider,
            quantity_calculator=quantity_calculator,
            budget_frac=None,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            contract_multiplier=contract_multiplier,
        )
        # absolute budget; must be set through current_budget before selecting
        self._budget = None

    def _get_budget(self):
        return self._budget

    def _set_budget(self, value):
        self._budget = value

    # absolute budget, read by BudgetOptionProvider.get_options
    current_budget = property(_get_budget, _set_budget)
# region imports
from AlgorithmImports import *
from abc import ABC, abstractmethod
from data import *
from option_providers import *
import time as t
# endregion

class OptionUniverse(ABC):
    '''
    Base class for option universes whose contract filter delegates the
    selection to an OptionProvider.

    The filter stores the contracts it selects in `universe_options_data`.
    The stored set is cleared the first time the filter succeeds at a new
    algorithm time, so several underlyings filtered at the same time
    accumulate into one set. `update_dt` records the algorithm time of the
    last successful selection.
    '''

    def __init__(
        self,
        algo: QCAlgorithm,
        option_provider: OptionProvider,
        universe_selection_date_rule: ScheduleRule,
        quantity_calculator: Optional[QuantityCalculator]=None
    ) -> None:

        self._algo: QCAlgorithm = algo

        # filtered options data and the algorithm time it was last updated
        self._universe_options_data: OptionsData = OptionsData()
        self._option_universe_update_dt: datetime = datetime(1,1,1,0,0)
        # algorithm time of the selection the stored data belongs to; a new time triggers a reset
        self._last_update_dt: datetime = datetime(1,1,1,0,0)
        self._option_provider: OptionProvider = option_provider
        self._quantity_calculator: Optional[QuantityCalculator] = quantity_calculator
        self._universe_selection_date_rule: ScheduleRule = universe_selection_date_rule

        # wall-clock start (t.time()) of the current XLB filter pass; only used for the slow-selection log
        self._begin_time = None

    @property
    def quantity_calculator(self):
        '''Quantity calculator passed at construction (may be None).'''
        return self._quantity_calculator

    @property
    def update_dt(self) -> datetime:
        '''Algorithm time of the last successful contract selection.'''
        return self._option_universe_update_dt

    @abstractmethod
    def get_previous_trading_day(self) -> datetime:
        '''
        Returns previous trading day for traded universe underlying.
        '''
        ...

    @property
    def universe_options_data(self) -> OptionsData:
        '''Contracts selected by the filter for the most recent selection time.'''
        return self._universe_options_data

    def option_filter(self, option_filter_universe: OptionFilterUniverse) -> OptionFilterUniverse:
        '''Universe filter: include weeklys and delegate contract choice to `_filter_function`.'''
        # Peek at one element instead of materialising the whole universe with list()
        first_contract = next(iter(option_filter_universe), None)
        if first_contract is not None:
            if first_contract.underlying.symbol.value == 'XLB':
                self._begin_time = t.time()
        else:
            self._begin_time = None

        return option_filter_universe.include_weeklys().contracts(contract_selector=self._filter_function)

    def _filter_function(self, contracts: List[Symbol]) -> List[Symbol]:
        '''
        Select contracts for one underlying via the option provider and record them.

        Each element of `contracts` is read through `.symbol` and
        `.underlying.symbol`. Returns the selected contract symbols, or [] when
        nothing suitable is found.
        '''
        fn_name: str = 'OptionUniverse._filter_function'
        if len(contracts) == 0:
            self._algo.log(f'{fn_name} - no options available in the universe')
            return []

        # transform all available universe options to OptionsData
        options_data: OptionsData = OptionsData(
            {contract.symbol: contract for contract in contracts}
        )

        underlying: Symbol = next(iter(contracts)).underlying.symbol
        underlying_ticker: str = underlying.value
        underlying_price: float = self._algo.securities[underlying].price

        if underlying_price == 0:
            self._algo.log(f'{fn_name} - underlying price is 0 for stock: {underlying_ticker}')
            return []

        # use options provider to find the options
        contracts_to_trade: OptionsData = self._option_provider.get_options(options_data)

        if len(contracts_to_trade) == 0:
            self._algo.log(f'{fn_name} - no options found for {underlying_ticker} after opt. provider passage; n of options available: {len(contracts)}; n of opt. after provider: {len(contracts_to_trade)}')
            return []

        # TODO remove - only for straddles (emptiness was already handled above)
        if len(contracts_to_trade) != 2:
            self._algo.log(f'{fn_name} - no straddle found for {underlying_ticker} after opt. provider passage; total n of options available: {len(contracts)}; n of opt. after provider: {len(contracts_to_trade)}')
            return []

        # reset universe options data on a new universe selection
        if self._last_update_dt != self._algo.time:
            self._universe_options_data = OptionsData()
            self._last_update_dt = self._algo.time

        # update stored data
        self._option_universe_update_dt = self._algo.time
        self._universe_options_data += contracts_to_trade

        if underlying.value == 'XLB' and self._begin_time is not None:
            elapsed_time = t.time() - self._begin_time
            if elapsed_time > 1:
                self._algo.log(f'OptionUniverse._filter_function - {underlying.value} option selection took: {elapsed_time:.4f} seconds')

        return list(contracts_to_trade._options.keys())

class SymbolOptionUniverse(OptionUniverse):
    '''
    Option universe over the option chain of a single index or equity symbol.

    Raises TypeError for any other security type.
    '''

    def __init__(
        self,
        algo: QCAlgorithm,
        symbol: Symbol,
        resolution: Resolution,
        option_provider: OptionProvider,
        universe_selection_date_rule: ScheduleRule,
        quantity_calculator: Optional[QuantityCalculator]=None
    ) -> None:

        # TODO how to schedule filtering function, it's daily by default
        super().__init__(algo, option_provider, universe_selection_date_rule, quantity_calculator)
        self._symbol: Symbol = symbol

        security_type = symbol.security_type
        if security_type == SecurityType.INDEX:
            add_chain = algo.add_index_option
        elif security_type == SecurityType.EQUITY:
            add_chain = algo.add_option
        else:
            raise TypeError(f"SymbolOptionUniverse.__init__ - security type: {symbol.security_type} is not supported")

        chain = add_chain(symbol, resolution=resolution)
        chain.set_filter(self.option_filter)

    def get_previous_trading_day(self) -> datetime:
        '''Date of the previous trading day of the underlying symbol.'''
        previous_dt = get_previous_trading_day(self._algo, self._symbol, self._algo.time)
        return previous_dt.date()

class ETFConstituentsOptionUniverse(OptionUniverse):
    '''
    Option universe over the constituents of an ETF.

    Constituents with fundamental data are ranked by market cap (largest
    first), optionally truncated to the top `no_of_stocks`, and their option
    chains are filtered with `option_filter`.
    '''

    def __init__(
        self,
        algo: QCAlgorithm,
        etf_ticker: str,
        resolution: Resolution,
        option_provider: OptionProvider,
        universe_selection_date_rule: ScheduleRule,
        quantity_calculator: Optional[QuantityCalculator]=None,
        no_of_stocks: Optional[int] = None,
    ) -> None:

        super().__init__(algo, option_provider, universe_selection_date_rule, quantity_calculator)

        self._no_of_stocks: Optional[int] = no_of_stocks
        self._etf_ticker: str = etf_ticker

        self._init_universe(algo, etf_ticker, resolution, no_of_stocks)

    def get_previous_trading_day(self) -> datetime:
        '''Date of the previous trading day of the ETF.'''
        previous_dt = get_previous_trading_day(self._algo, self._etf_ticker, self._algo.time)
        return previous_dt.date()

    def _init_universe(self, algo: QCAlgorithm, etf_ticker: str, resolution: Resolution, no_of_stocks: Optional[int] = None) -> None:
        def _etf_constituents_filter(constituents: List[Fundamental]) -> List[Symbol]:
            market_caps = {}
            for constituent in constituents:
                if constituent.has_fundamental_data:
                    market_caps[constituent.symbol] = constituent.market_cap

            ranked = sorted(market_caps, key=market_caps.get, reverse=True)
            # the limit is read from the instance at selection time
            return ranked if self._no_of_stocks is None else ranked[:self._no_of_stocks]

        algo.add_equity(etf_ticker, resolution)

        # TODO check this resolution setting, it only works after algo.universe_settings.resolution = resolution is set firstly
        algo.universe_settings.resolution = resolution

        settings: UniverseSettings = UniverseSettings(algo.universe_settings)
        settings.asynchronous = False
        settings.data_normalization_mode = DataNormalizationMode.RAW
        settings.minimum_time_in_universe = timedelta(0)
        settings.schedule.on(self._universe_selection_date_rule.date_rule)

        etf_universe = algo.add_universe(
            algo.universe.etf(etf_ticker, settings), _etf_constituents_filter
        )
        algo.add_universe_options(etf_universe, self.option_filter)

class StaticStocksOptionUniverse(OptionUniverse):
    '''
    Option universe over a fixed list of US equity tickers.

    The same symbols are returned by every universe selection, and their
    option chains are filtered with `option_filter`.
    '''

    def __init__(
        self,
        algo: QCAlgorithm,
        stock_tickers: List[str],
        resolution: Resolution,
        option_provider: OptionProvider,
        universe_selection_date_rule: ScheduleRule,
        quantity_calculator: Optional[QuantityCalculator]=None
    ) -> None:

        super().__init__(algo, option_provider, universe_selection_date_rule, quantity_calculator)

        self._stock_tickers: List[str] = stock_tickers

        self._init_universe(algo, stock_tickers, resolution)

    def get_previous_trading_day(self) -> datetime:
        '''
        Date of the previous trading day shared by all tickers.

        Raises AssertionError when the tickers disagree.
        '''
        distinct_days = {
            get_previous_trading_day(self._algo, ticker, self._algo.time).date()
            for ticker in self._stock_tickers
        }

        assert len(distinct_days) == 1, 'StaticStocksOptionUniverse.get_previous_trading_day - not all symbols have the same last trading day'

        (previous_day,) = distinct_days
        return previous_day

    def _init_universe(self, algo: QCAlgorithm, stock_tickers: List[str], resolution: Resolution) -> None:
        # TODO check this resolution setting, it only works after algo.universe_settings.resolution = resolution is set firstly
        algo.universe_settings.resolution = resolution

        settings: UniverseSettings = UniverseSettings(algo.universe_settings)
        settings.asynchronous = False
        settings.data_normalization_mode = DataNormalizationMode.RAW
        settings.minimum_time_in_universe = timedelta(0)
        settings.schedule.on(self._universe_selection_date_rule.date_rule)
        # NOTE(review): `settings` is configured above but never passed to
        # add_universe below, unlike ETFConstituentsOptionUniverse, which hands
        # its settings to algo.universe.etf. Presumably the static universe
        # therefore runs with algo.universe_settings; confirm this is intended.

        def _static_selection(fundamentals):
            # the fundamentals are ignored: the universe is the fixed ticker list
            return [Symbol.create(ticker, SecurityType.EQUITY, Market.USA) for ticker in stock_tickers]

        static_universe = algo.add_universe(_static_selection)
        algo.add_universe_options(static_universe, self.option_filter)
# region imports
from AlgorithmImports import *
from abc import ABC, abstractmethod
from data import *
# endregion

class QuantityCalculator(ABC):
    '''
    Abstract base for objects that turn a set of options into signed trade quantities.

    Subclasses implement ``calc_quantity``. The protected helpers convert a target
    notional into an unrounded number of contracts.
    '''
    def __init__(self, algo):
        # running algorithm; the helpers read portfolio values and security prices from it
        self._algo = algo

    @abstractmethod
    def calc_quantity(self, option_data: OptionsData) -> TradeData:
        '''Return a ``TradeData`` mapping options from ``option_data`` to quantities.'''

    def _get_amount_as_fraction_of_portfolio(self, option, fraction_of_portfolio_value):
        '''
        Unrounded number of contracts whose underlying notional equals
        ``fraction_of_portfolio_value`` of the total portfolio value.

        ``option`` must expose ``ContractMultiplier`` and ``Underlying.Price``
        (presumably an option ``Security`` object — confirm against callers).
        '''
        contract_size = option.ContractMultiplier
        desired_notional = self._algo.Portfolio.TotalPortfolioValue * fraction_of_portfolio_value
        return desired_notional / (contract_size * option.Underlying.Price)

    def _get_quantity_as_fraction_of_holding(self, holding_value, option, fraction, multiplier):
        '''
        Unrounded number of contracts whose underlying notional equals
        ``fraction * holding_value``.

        ``option`` is an option symbol; its underlying's current price comes from
        ``algo.securities``. ``multiplier`` is the contract multiplier.
        '''
        desired_notional = holding_value * fraction
        underlying_price = self._algo.securities[option.underlying].price
        return desired_notional / (multiplier * underlying_price)



class MinimumQuantityCalculator(QuantityCalculator): 
    '''
    Combines several quantity calculators by keeping, per option, the smallest
    quantity produced by any of them.
    '''
    def __init__(self, quantity_calculators: Iterable[QuantityCalculator]):
        '''
        Receives an iterable of quantity calculators and outputs quantities equal to minimum 
        of the quantities calculated by calculators in the iterable. 

        :param quantity_calculators: calculators to combine. The iterable is materialized
            into a list so that one-shot iterables (generators, ``map`` objects) are used
            on every ``calc_quantity`` call, not only on the first one.
        '''
        # NOTE: QuantityCalculator.__init__ is not called, so ``self._algo`` is not set and
        # the base-class helpers are not usable on this class.
        self._quantity_calculators: List[QuantityCalculator] = list(quantity_calculators)
    
    def calc_quantity(self, option_data: OptionsData):
        '''
        :param option_data: options to size; passed unchanged to every calculator
        :return: TradeData holding, for every option, the minimum quantity across all
            calculators (``np.inf`` when no calculators were supplied)
        '''
        res = TradeData({option: np.inf for option in option_data})
        for calculator in self._quantity_calculators:
            trade_data = calculator.calc_quantity(option_data)
            
            for option in option_data:
                if trade_data[option] <= res[option]:
                    res[option] = trade_data[option] 
        
        return res

class SectorETFMarketCapQuantityCalculator(QuantityCalculator):
    '''
    Sizes options on sector ETFs proportionally to the sector's market cap:
    quantity = direction * sector_market_cap / underlying_price.

    ``universe`` must be set (constructor or property) before ``calc_quantity`` is
    called; it is passed to ``get_sector_market_caps``.
    '''
    def __init__(self, algo, direction: Direction, universe=None, target_fn: Optional[Callable] = None):
        '''
        :param algo: running algorithm
        :param direction: Direction.LONG gives positive quantities, anything else negative
        :param universe: universe passed to ``get_sector_market_caps`` (may be set later)
        :param target_fn: optional ``f(algo, trade_data) -> trade_data`` post-processing
            (target notional or greek)
        '''
        super().__init__(algo)
        self._direction = 1 if direction == Direction.LONG else -1
        self._universe = universe
        self._target_fn: Optional[Callable] = target_fn

    @property
    def universe(self):
        '''Universe used to look up sector market caps.'''
        return self._universe
    
    @universe.setter
    def universe(self, value):
        self._universe = value

    def calc_quantity(self, option_data: OptionsData) -> TradeData:
        '''
        :param option_data: options on sector ETFs; every sector must be represented
            by the same number of options
        :return: TradeData of signed quantities (after ``target_fn`` if set). Options whose
            underlying price is 0 are logged and left out.
        :raises AssertionError: if no universe is set, the per-sector option counts differ,
            or an underlying has no market-cap entry
        '''
        assert self._universe is not None

        # precalculate sectors' market cap, allows multiple underlyings
        sector_market_cap: Dict[Symbol, float] = get_sector_market_caps(self._algo, self._universe)

        # all sectors have the same number of options present in option_data;
        # count options per underlying in a single pass
        options_per_sector: Dict[Symbol, int] = {}
        for opt_symbol in option_data.option_symbols:
            underlying = opt_symbol.underlying
            options_per_sector[underlying] = options_per_sector.get(underlying, 0) + 1

        assert len(set(options_per_sector.values())) <= 1

        res = TradeData(time_stamp=self._algo.time)
        for option in option_data:
            assert option.underlying in sector_market_cap

            cap: float = sector_market_cap[option.underlying]

            underlying_price: float = self._algo.securities[option.underlying].price
            if underlying_price != 0:
                res[option] = (self._direction * cap) / underlying_price
            else:
                self._algo.log(f'SectorETFMarketCapQuantityCalculator.calc_quantity: underlying price is 0 for {option.value}')
            
        # target notional or greek
        if self._target_fn is not None:
            res = self._target_fn(self._algo, res)

        return res

class UnderlyingMarketCapQuantityCalculator(QuantityCalculator):
    '''
    Sizes each option proportionally to the market cap of its own underlying:
    quantity = direction * market_cap / underlying_price, optionally post-processed
    by ``target_fn(algo, trade_data)`` (target notional or greek).
    '''
    def __init__(self, algo, direction: Direction, target_fn: Optional[Callable] = None):
        super().__init__(algo)
        self._direction = 1 if direction == Direction.LONG else -1
        self._target_fn: Optional[Callable] = target_fn

    def calc_quantity(self, option_data: OptionsData):
        '''
        :raises AssertionError: if an underlying has no fundamental data (logged first)
        '''
        res = TradeData(time_stamp=self._algo.time)
        for option in option_data:
            underlying_security = self._algo.securities[option.underlying]
            has_fundamentals = underlying_security.fundamentals.has_fundamental_data
            if not has_fundamentals:
                self._algo.log('UnderlyingMarketCapQuantityCalculator.calc_quantity: no fundamentals')
            assert has_fundamentals

            market_cap = underlying_security.fundamentals.market_cap
            underlying_price: float = underlying_security.price
            if underlying_price == 0:
                # a zero price would divide by zero; leave the option out of the result
                self._algo.log(f'UnderlyingMarketCapQuantityCalculator.calc_quantity: underlying price is 0 for {option.value}')
                continue
            res[option] = (self._direction * market_cap) / underlying_price

        if self._target_fn is None:
            return res
        return self._target_fn(self._algo, res)

class FractionOfHoldingQuantityCalculator(QuantityCalculator):
    '''
    Sizes every option to a fixed fraction of a holding value.

    The holding value comes from ``holding_calculator(algo)``. It is halved by
    ``ComputeNotionalSplit``, multiplied by ``fraction`` and converted to contracts
    using the underlying price and ``multiplier``. The result is rounded and signed
    by the direction.
    '''
    def __init__(self, algo, notional_frac, direction: Direction, multiplier=100, holding_calculator: Callable[[Any],float]=portfolio_cash):
        super().__init__(algo)
        self._multiplier = multiplier
        self._notional_frac = notional_frac
        self._direction = 1 if direction == Direction.LONG else -1
        self._holding_calc = holding_calculator

    @property
    def fraction(self):
        '''Fraction of the split holding value allocated to each option.'''
        return self._notional_frac

    @fraction.setter
    def fraction(self, value):
        self._notional_frac = value

    def calc_quantity(self, option_data: OptionsData):
        res = TradeData(time_stamp=self._algo.time)
        split_notional = self.ComputeNotionalSplit(self._holding_calc(self._algo))
        for option in option_data:
            raw_quantity = self._get_quantity_as_fraction_of_holding(
                split_notional, option, self._notional_frac, self._multiplier
            )
            res[option] = self._direction * round(raw_quantity)
        return res

    def ComputeNotionalSplit(self, total_capital):
        ''' Split the notional into two equal parts and return one of them. '''
        return total_capital / 2



class FractionOfCashQuantityCalculator(QuantityCalculator):
    # ToDo: remove. FractionOfHoldingQuanityCalculator with the appropriate holding_calc 
    # does the job.
    # NOTE(review): FractionOfHoldingQuantityCalculator also halves the holding value via
    # ComputeNotionalSplit, so it is not a drop-in replacement for this class.
    '''
    Sizes every option to ``notional_frac`` of the portfolio cash, converted to contracts
    using the underlying price and ``multiplier``; rounded and signed by direction.
    '''
    def __init__(self, algo, notional_frac, direction, multiplier=100):
        super().__init__(algo)
        self._multiplier = multiplier
        self._notional_frac = notional_frac
        self._direction = 1 if direction == Direction.LONG else -1

    def calc_quantity(self, option_data: OptionsData):
        res = TradeData()
        cash = self._algo.portfolio.cash
        for option in option_data:
            contracts = self._get_quantity_as_fraction_of_holding(cash, option, self._notional_frac, self._multiplier)
            res[option] = self._direction * round(contracts)
        return res


class BudgetQuantityCalculator(QuantityCalculator):
    def __init__(self, algo, budget_frac: float, direction, multiplier=100, holding_calculator: Callable[[Any],float]=portfolio_cash):
        '''
        Calculate quantity as a fraction of current portfolio value. Distributes equal
        weighted between all options (rounding)

        :param algo: running algorithm
        :param budget_frac: fraction of ``holding_calculator(algo)`` spent in total
        :param direction: Direction.LONG gives positive quantities, anything else negative
        :param multiplier: contract multiplier used to convert premium to cash
        :param holding_calculator: returns the holding value the budget is taken from
        '''
        super().__init__(algo)
        self._budget = budget_frac
        self._holding_calc = holding_calculator
        self._multiplier = multiplier
        self._direction = 1 if direction == Direction.LONG else -1

    def _mid_price(self, option) -> float:
        '''
        Mid price of ``option``: from the subscribed security when available, otherwise
        from the last bar of a 10-bar hourly history request.

        :raises RuntimeError: if the history holds no ask/bid close data
        '''
        if option in self._algo.securities:
            return mid_quote(self._algo.securities[option])

        data = self._algo.history(option, 10, resolution=Resolution.HOUR, fill_forward=True)
        # both quote sides are needed for a mid price, and an empty frame has no last row
        if len(data) == 0 or not hasattr(data, 'askclose') or not hasattr(data, 'bidclose'):
            raise RuntimeError(f'BudgetQuantityCalculator: no bid/ask history for {option}')
        return (data.askclose.values[-1] + data.bidclose.values[-1]) / 2

    def calc_quantity(self, option_data: OptionsData):
        '''
        Split the budget equally across options and buy/sell as many contracts of each
        as its mid price allows (rounded).

        :raises AssertionError: if ``option_data`` is empty
        :raises RuntimeError: if a price cannot be obtained for an unsubscribed option
        '''
        assert len(option_data) > 0
        res = TradeData()
        budget_per_asset = self._budget * self._holding_calc(self._algo) / len(option_data)

        for option in option_data:
            price = self._mid_price(option)
            quantity = round(budget_per_asset / (price * self._multiplier))
            res[option] = self._direction * quantity

        return res 


class ConstantQuantityCalculator(QuantityCalculator):
    '''Assigns the same fixed quantity to every option, independent of prices or portfolio.'''
    def __init__(self, quantity: int):
        # QuantityCalculator.__init__ is not called: calc_quantity never uses the algorithm.
        self._quantity = quantity

    @property
    def quantity(self):
        '''The constant quantity assigned to each option.'''
        return self._quantity

    def calc_quantity(self, option_data: OptionsData):
        res = TradeData()
        fixed_quantity = self._quantity
        for option in option_data:
            res[option] = fixed_quantity
        return res
              

class BudgetQuantityCalculatorTrial(QuantityCalculator):
    def __init__(self, algo, budget_frac: float, direction_vec: List[Direction], multiplier=100, holding_calculator: Callable[[Any],float]=portfolio_cash):
        '''
        Calculate quantity as a fraction of current portfolio value. Distributes equal
        weighted between all options (rounding)

        :param budget_frac: fraction of ``holding_calculator(algo)`` spent on the position
        :param direction_vec: one direction per leg, in the iteration order of the
            ``option_data`` later passed to ``calc_quantity``
        :param multiplier: contract multiplier used to convert premium to cash
        '''
        super().__init__(algo)
        self._budget = budget_frac
        self._holding_calc = holding_calculator
        self._multiplier = multiplier
        # +1 for long legs, -1 for every other direction
        self._direction = [1 if direction == Direction.LONG else -1 for direction in direction_vec]

    def _mid_price(self, option) -> float:
        '''
        Mid price of ``option``: from the subscribed security when available, otherwise
        from the last bar of a 10-bar hourly history request.

        :raises RuntimeError: if the history holds no ask/bid close data
        '''
        if option in self._algo.securities:
            return mid_quote(self._algo.securities[option])

        data = self._algo.history(option, 10, resolution=Resolution.HOUR, fill_forward=True)
        # both quote sides are needed for a mid price, and an empty frame has no last row
        if len(data) == 0 or not hasattr(data, 'askclose') or not hasattr(data, 'bidclose'):
            raise RuntimeError(f'BudgetQuantityCalculatorTrial: no bid/ask history for {option}')
        return (data.askclose.values[-1] + data.bidclose.values[-1]) / 2

    def calc_quantity(self, option_data: OptionsData):
        '''
        Size a multi-leg position: every leg gets the same absolute quantity, signed by
        its direction, such that the net premium of one unit (sum of direction * mid
        price) times the multiplier matches the budget (rounded).

        NOTE: a net credit (negative unit cost) yields a negative quantity, which flips
        every leg's sign; a net unit cost of exactly 0 raises ZeroDivisionError.

        :raises AssertionError: if ``option_data`` is empty or its length differs from
            the number of directions
        '''
        assert len(option_data) > 0
        assert len(option_data) == len(self._direction)
        res = TradeData()
        budget = self._budget * self._holding_calc(self._algo)

        # net premium of one unit of the combined position
        cost_per_unit = 0
        for direction, option in zip(self._direction, option_data):
            cost_per_unit += self._mid_price(option) * direction
        quantity = round(budget / (cost_per_unit * self._multiplier))

        for direction, option in zip(self._direction, option_data):
            res[option] = direction * quantity

        return res 
# region imports
from AlgorithmImports import *
from sklearn.linear_model import LinearRegression
from data import *
from collections import deque
# endregion

class VIXMoneynessRegressionStorage():
    '''
    Used for scheduled storing of vix values and moneyness of the options provided by option provider.
    Also handles scheduled regression.

    On every ``store_date_rule`` event one (vix, moneyness) pair is appended, with
    moneyness = -(100 - strike / underlying_price * 100) of the single option selected by
    ``option_provider``. On every ``regress_date_rule`` event, once at least
    ``min_period`` pairs are stored, moneyness is linearly regressed on vix and the
    intercept/beta are kept for ``get_moneyness``.

    Use case:
    VIXMoneynessRegressionOptionProvider
    '''
    def __init__(
        self,
        algo: QCAlgorithm, 
        option_provider, 
        vix: Symbol,
        store_date_rule: ScheduleRule, 
        regress_date_rule: ScheduleRule, 
        canonical_options: List[Symbol], 
        eval_function: Callable, 
        min_period: int = 21, 
        expanding: bool = True, 
    ) -> None:
        '''
        :param option_provider: object with ``get_options(OptionsData)`` that must return one option
        :param vix: symbol whose price is read from the current slice
        :param canonical_options: canonical option symbols whose chains feed the provider
        :param eval_function: object with ``eval(intercept=..., beta=...)`` used by ``get_moneyness``
        :param min_period: observations required before regressing
        :param expanding: True keeps all observations; False keeps only the last ``min_period``
        '''
        self._algo: QCAlgorithm = algo
        self._option_provider = option_provider
        self._vix: Symbol = vix
        self._store_date_rule: ScheduleRule = store_date_rule
        self._regress_date_rule: ScheduleRule = regress_date_rule
        self._canonical_options: List[Symbol] = canonical_options
        self._eval_function: Callable = eval_function
        self._min_period: int = min_period

        # (vix, moneyness) observations; bounded to a rolling window when not expanding
        self._regression_storage: deque = deque() if expanding else deque(maxlen=min_period)

        # recalculated params on each _regress run
        self._recent_intercept: Optional[float] = None
        self._recent_beta: Optional[float] = None

        self._algo.schedule.on(
            self._store_date_rule.date_rule,
            self._store_date_rule.time_rule, 
            self._store
        )
        
        self._algo.schedule.on(
            self._regress_date_rule.date_rule,
            self._regress_date_rule.time_rule, 
            self._regress
        )
    
    def _store(self) -> None:
        '''
        Append one (vix, moneyness) observation.

        Skips the observation (with a log line) when the underlying price is 0 or the
        VIX has no data in the current slice, instead of raising.

        :raises AssertionError: if the option provider does not return exactly one option
        '''
        candidates = OptionsData()
        for canonical_option in self._canonical_options:
            candidates += OptionsData({option.key: option.value for option in self._algo.option_chain(canonical_option).contracts})

        selected: OptionsData = self._option_provider.get_options(candidates)
        assert len(selected) == 1
        
        option = list(selected)[0]

        strike_price: float = option.id.strike_price
        underlying_price: float = self._algo.securities[option.underlying].price
        if underlying_price == 0:
            self._algo.log(f'VIXMoneynessRegressionStorage._store: underlying price is 0 for {option}; skipping')
            return

        current_slice = self._algo.current_slice
        if not current_slice.contains_key(self._vix):
            self._algo.log(f'VIXMoneynessRegressionStorage._store: no data for {self._vix} in current slice; skipping')
            return

        vix: float = current_slice[self._vix].price
        moneyness: float = -(100 - (strike_price / underlying_price) * 100)

        self._regression_storage.append((vix, moneyness))

    def _regress(self) -> None:
        '''Fit moneyness ~ vix by OLS once ``min_period`` observations are stored.'''
        # wait for enough data points
        if len(self._regression_storage) < self._min_period:
            return

        # column vectors of shape (n, 1), as expected by LinearRegression
        vix: np.ndarray = np.array(list(map(lambda x: x[0], self._regression_storage))).reshape(-1,1)
        moneyness: np.ndarray = np.array(list(map(lambda x: x[1], self._regression_storage))).reshape(-1,1)

        lin_reg = LinearRegression()
        lin_reg.fit(vix, moneyness)
        
        # store params anew each time regression is done
        self._recent_intercept = lin_reg.intercept_[0]
        self._recent_beta = lin_reg.coef_[0][0]
    
    @property
    def is_ready(self) -> bool:
        '''True once at least one regression has produced both parameters.'''
        return self._recent_intercept is not None and self._recent_beta is not None
    
    def get_moneyness(self) -> float:
        '''
        Moneyness from the most recent regression parameters, divided by 100.

        Intended to be called only when ``is_ready``; otherwise ``None`` parameters are
        passed to ``eval_function.eval``.
        '''
        # calculate moneyness from recent params
        return self._eval_function.eval(intercept=self._recent_intercept, beta=self._recent_beta) / 100.
# region imports
from AlgorithmImports import *
from enum import Enum
from abc import ABC, abstractmethod
from option_providers import *
from data import *
from quantity_calculators import *
from exit_strategies import *
from enter_strategies import *
from filters import *
from delta_hedge import DeltaHedge
import time
# endregion


class TradeManager(ABC):
    '''
    Common base class for trade managers.

    Declares no abstract methods; the concrete managers derived from it
    (such as ``BaseTradeManager``) define the actual behaviour.
    '''


class BaseTradeManager(TradeManager):
    '''
    This class calculates trades that are executed by the BuyClass. 

    On construction it registers scheduled events on ``algo``:
      - ``calc_trade`` on ``calc_trade_date_rule``: asks the enter strategy for a trade and stages it;
      - ``trade`` on ``trade_date_rule``: sends market orders for the staged trade;
      - ``exit_trades`` on ``exit_date_rule`` (only if an exit strategy is given);
      - ``_clear_trade_history`` on ``date_rules.every(days=[0])`` at midnight
        (only if ``trade_history_length >= 0``);
      - ``_garbage_collect`` on ``gc_date_rule`` (only if given).
    Delta hedge strategies not yet bound to a trade manager are bound to this one and
    asked to ``schedule()`` themselves.
    '''
    def __init__(self, 
                 algo, 
                 enter_strategy: EnterStrategy, 
                 canonical_options: List, 
                 calc_trade_date_rule: ScheduleRule, 
                 trade_date_rule: ScheduleRule,
                 subscription_resolution: Resolution, 
                 trade_id_generator: BaseTradeIdGenerator,
                 filter: Optional[Filter] = None, 
                 trade_history_length: int = -1,
                 exit_strategy: Optional[ExitStrategy] = None, 
                 exit_date_rule: Optional[ScheduleRule] = None,
                 delta_hedge_strategies: Optional[Iterable[DeltaHedge]] = None, 
                 gc_date_rule: Optional[ScheduleRule] = None,
                 **params):
        '''
        :param algo: running algorithm
        :param enter_strategy: produces entry trades; bound to this manager
        :param canonical_options: canonical option symbols (stored; not read by this class)
        :param calc_trade_date_rule: schedule for ``calc_trade``
        :param trade_date_rule: schedule for ``trade``
        :param subscription_resolution: resolution for contracts subscribed on demand
        :param trade_id_generator: creates ids for recorded trades
        :param filter: optional gate; ``calc_trade`` does nothing when ``filter.filter()`` is falsy
        :param trade_history_length: -1 keeps the whole history; 0 does not record trades
            made by ``trade``; N > 0 periodically drops trades older than N days
        :param exit_strategy: optional; requires ``exit_date_rule``
        :param exit_date_rule: schedule for ``exit_trades``
        :param delta_hedge_strategies: optional hedgers; none by default
        :param gc_date_rule: optional schedule for ``_garbage_collect``
        :param params: extra keyword arguments, stored in ``self._params``
        '''
        if delta_hedge_strategies is None:
            delta_hedge_strategies = ()

        self._algo = algo

        self._resolution = subscription_resolution
        self._canonical_options = canonical_options

        self._calc_trade_date_rule = calc_trade_date_rule
        self._trade_rule = trade_date_rule
        self._trade_id_generator = trade_id_generator
        self._trade_history_length = trade_history_length
        
        if exit_strategy: 
            assert exit_date_rule is not None
            if exit_strategy.trade_manager is None:
                exit_strategy.trade_manager = self
        
        self._enter_strategy = enter_strategy
        self._enter_strategy.trade_manager = self
        self._exit_strategy = exit_strategy
        self._exit_rule = exit_date_rule
        self._gc_date_rule = gc_date_rule
        
        # only hedgers that are not already owned by another manager are bound and scheduled
        for delta_hedge_strategy in delta_hedge_strategies:
            if delta_hedge_strategy.trade_manager is None:
                delta_hedge_strategy.trade_manager = self
                delta_hedge_strategy.schedule()
        
        self._filter = filter

        self._params = params

        # trade staged by calc_trade and consumed by trade
        self._current_trade: TradeData = TradeData()
        self._trade_updated = False

        self._trades = TradeCollection()

        self._algo.schedule.on(
            self._calc_trade_date_rule.date_rule,
            self._calc_trade_date_rule.time_rule, 
            self.calc_trade
        )

        self._algo.schedule.on(
            self._trade_rule.date_rule,
            self._trade_rule.time_rule, 
            self.trade
        )

        if self._exit_strategy:
            self._algo.schedule.on(
                self._exit_rule.date_rule,
                self._exit_rule.time_rule, 
                self.exit_trades
            )

        if self._trade_history_length >= 0: 
            self._algo.schedule.on(
                self._algo.date_rules.every(days = [0]),
                self._algo.time_rules.midnight, 
                self._clear_trade_history
            )

        if self._gc_date_rule is not None:
            self._algo.schedule.on(
                self._gc_date_rule.date_rule,
                self._gc_date_rule.time_rule, 
                self._garbage_collect
            )

    def _clear_trade_history(self):
        '''Drop recorded trades older than ``trade_history_length`` days.'''
        clear_date = self._algo.time - timedelta(days=self._trade_history_length)
        self._trades.delete_all_trades_before_date(clear_date)

    @property
    def trades_hist(self):
        '''Collection of all trades recorded by this manager.'''
        return self._trades
    
    def _subscribe_to_option(self, symbol):
        '''
        Subscribe to one option contract at ``self._resolution``.

        :raises NotImplementedError: for underlyings that are neither Index nor Equity
        '''
        if symbol.underlying.SecurityType == SecurityType.Index:
            return self._algo.AddIndexOptionContract(symbol=symbol, resolution=self._resolution)
        
        elif symbol.underlying.SecurityType == SecurityType.Equity:
            return self._algo.AddOptionContract(symbol=symbol, resolution=self._resolution)

        else:
            raise NotImplementedError

    def calc_trade(self): 
        '''
        Ask the enter strategy for a trade and stage it for the next ``trade`` call.

        Does nothing when the filter rejects, and keeps the previously staged state
        when the strategy returns an empty trade.
        '''
        if self._filter is not None and not self._filter.filter():
            return

        trade_data = self._enter_strategy.calc_trade()
        
        if len(trade_data) == 0: 
            self._algo.log('no enter trade data found')
            return

        self._current_trade = trade_data
        self._trade_updated = True

    def _make_trade(self, trade_data):
        '''
        Send a market order per (option, quantity), subscribing to contracts that are
        not yet in ``algo.securities``.

        :return: a ``TradeRecord`` for the orders, or None if ``trade_data`` is empty
        '''
        if len(trade_data) == 0: 
            self._algo.log('no enter trade data found')
            return

        for option, quantity in trade_data.items():
            if not self._algo.securities.contains_key(option):
                self._subscribe_to_option(option)
                self._algo.log(f'{option} subscription')
                
            self._algo.market_order(symbol=option, quantity=quantity)
        
        trade_id = self._trade_id_generator.gen_trade_id(trade_data)
        return TradeRecord(trade_id, trade_data, self._algo.time)

    def trade(self) -> None:
        '''
        Execute the trade staged by ``calc_trade`` (if any). The resulting record is
        stored unless ``trade_history_length == 0``.
        '''
        if not self._trade_updated:
            self._algo.log('trade data not updated')
            return

        current_trade = self._make_trade(self._current_trade)
        if self._trade_history_length != 0:
            self._trades.add(current_trade)
        self._trade_updated = False

    def exit_trades(self):
        '''Execute and record the trade proposed by the exit strategy, if non-empty.'''
        trade_data = self._exit_strategy.calc_trade()
        if len(trade_data.trade_data) > 0:
            current_trade = self._make_trade(trade_data)
            self._trades.add(current_trade)
            # Todo: not a valid assertion. Remove. Only for sanity check a specific
            # exit strategy
            for asset in trade_data:
                if self.trades_hist.positions[asset] != 0: 
                    assert self.trades_hist.positions[asset] < 0
                if asset.id.OptionRight == OptionRight.PUT:
                    assert self.trades_hist.positions[asset] == 0

    def make_hedge_trade(self, trade_data: TradeData) -> None:
        '''
        Send market orders for hedge instruments (no subscription step) and record
        the trade regardless of ``trade_history_length``.
        '''
        if len(trade_data) == 0: 
            self._algo.log('no enter trade data found')
            return

        for symbol, q in trade_data.items():
            self._algo.market_order(symbol=symbol, quantity=q)
        
        trade_id = self._trade_id_generator.gen_trade_id(trade_data)
        current_trade = TradeRecord(trade_id, trade_data, self._algo.time)
        self._trades.add(current_trade)

    def _garbage_collect(self) -> None:
        '''
        Remove subscriptions of expired, no-longer-invested option contracts found in the
        trade history, then log the elapsed time and current subscription counts.
        '''
        start_time = time.time()
        for symbol in self._trades.positions.keys():
            if not self._algo.portfolio[symbol].invested:
                if symbol.id.security_type in [SecurityType.OPTION, SecurityType.INDEX_OPTION, SecurityType.FUTURE_OPTION]:
                    if symbol.id.date < self._algo.time:
                        self._algo.remove_option_contract(symbol)
        elapsed_time = time.time() - start_time

        self._algo.log(f"BaseTradeManager._garbage_collect - iteration of position traceback: {elapsed_time:.4f} seconds")
        self._algo.log(f"BaseTradeManager._garbage_collect - count of securities {self._algo.securities.count}")
        self._algo.log(f"BaseTradeManager._garbage_collect - count of active_securities {self._algo.active_securities.count}")
        self._algo.log(f"BaseTradeManager._garbage_collect - subscription_manager count: {self._algo.subscription_manager.count}")

class BaseTradeManager_(TradeManager):
    '''
    This class calculates trades that are executed by the BuyClass. 
    '''
    def __init__(
        self, 
        algo, 
        enter_strategy: EnterStrategy, 
        calc_trade_date_rule: ScheduleRule, 
        trade_date_rule: ScheduleRule,
        subscription_resolution: Resolution, 
        trade_id_generator: BaseTradeIdGenerator,
        filter: Optional[Filter] = None, 
        trade_history_length: int = -1,
        exit_strategy: Optional[ExitStrategy] = None, 
        exit_date_rule: Optional[ScheduleRule] = None,
        delta_hedge_strategies: Optional[Iterable[DeltaHedge]] = None, 
        gc_date_rule: Optional[ScheduleRule] = None,
        **params
    ):
        '''
        Bind the strategies to this manager and register its scheduled events:
        ``calc_trade``, ``trade``, optional ``exit_trades`` / ``_clear_trade_history`` /
        ``_garbage_collect``, plus the debug-only ``_monitor_cash`` and ``_print_measures``
        hooks (marked TODO remove).

        :param trade_history_length: -1 keeps the whole history; values >= 0 schedule
            ``_clear_trade_history`` on ``date_rules.every(days=[0])`` at midnight
        :param exit_strategy: optional; requires ``exit_date_rule``. Its
            ``portfolio_value_indicator`` (if present) is bound to this manager too.
        :param delta_hedge_strategies: optional hedgers; none by default. Unbound ones
            are bound to this manager and scheduled.
        :param params: extra keyword arguments, stored in ``self._params``
        '''
        if delta_hedge_strategies is None:
            delta_hedge_strategies = ()

        self._algo = algo

        self._resolution = subscription_resolution

        self._calc_trade_date_rule = calc_trade_date_rule
        self._trade_rule = trade_date_rule
        self._trade_id_generator = trade_id_generator
        self._trade_history_length = trade_history_length
        
        # set up exit strategy
        if exit_strategy: 
            assert exit_date_rule is not None
            if exit_strategy.trade_manager is None:
                exit_strategy.trade_manager = self
            
            # set trade manager to portfolio value indicator
            if hasattr(exit_strategy, 'portfolio_value_indicator'):
                if exit_strategy.portfolio_value_indicator.trade_manager is None:
                    exit_strategy.portfolio_value_indicator.trade_manager = self
        
        self._enter_strategy = enter_strategy
        self._enter_strategy.trade_manager = self
        self._exit_strategy = exit_strategy
        self._exit_rule = exit_date_rule
        self._gc_date_rule = gc_date_rule
        
        for delta_hedge_strategy in delta_hedge_strategies:
            if delta_hedge_strategy.trade_manager is None:
                delta_hedge_strategy.trade_manager = self
                delta_hedge_strategy.schedule()
        
        self._filter = filter

        self._params = params

        # trade staged by calc_trade and consumed by trade
        self._current_trade: TradeData = TradeData()
        self._trade_updated = False

        self._trades = TradeCollection()

        self._algo.schedule.on(
            self._calc_trade_date_rule.date_rule,
            self._calc_trade_date_rule.time_rule, 
            self.calc_trade
        )

        self._algo.schedule.on(
            self._trade_rule.date_rule,
            self._trade_rule.time_rule, 
            self.trade
        )

        if self._exit_strategy:
            self._algo.schedule.on(
                self._exit_rule.date_rule,
                self._exit_rule.time_rule, 
                self.exit_trades
            )

        if self._trade_history_length >= 0: 
            self._algo.schedule.on(
                self._algo.date_rules.every(days = [0]),
                self._algo.time_rules.midnight, 
                self._clear_trade_history
            )

        if self._gc_date_rule is not None:
            self._algo.schedule.on(
                self._gc_date_rule.date_rule,
                self._gc_date_rule.time_rule, 
                self._garbage_collect
            )

        # TODO remove
        # NOTE(review): hard-coded 'SPX'; presumably requires an SPX security in the algorithm — confirm
        self._algo.schedule.on(
            algo.date_rules.every_day('SPX'), 
            algo.time_rules.before_market_close('SPX', 70), 
            self._monitor_cash
        )
        
        # emitted by calc_trade with every calculated enter trade (e.g. for chart updates)
        self._enter_trade_event: EventSource = EventSource()

        # TODO remove
        self._make_hedge_trade_measures: TimeMeasure = TimeMeasure("BaseTradeManager_ make_hedge_trade")
        self._make_exit_trade_measures: TimeMeasure = TimeMeasure("BaseTradeManager_ _make_exit_trade")
        self._make_trade_measures: TimeMeasure = TimeMeasure("BaseTradeManager_ _make_trade")
        self._hedge_count_measure: TimeMeasure = TimeMeasure("BaseTradeManager_ hedge count measure", units='occurrences')

        self._algo.schedule.on(
            algo.date_rules.on(2025, 12, 4), 
            algo.time_rules.before_market_close('SPX', 1), 
            self._print_measures
        )

    # TODO remove
    def _print_measures(self) -> None:
        '''Debug helper: write the collected TimeMeasure statistics to the algorithm log.'''
        timing_measures = (
            self._make_hedge_trade_measures,
            self._make_exit_trade_measures,
            self._make_trade_measures,
        )
        for measure in timing_measures:
            if not measure.is_ready():
                continue
            measure.print_percentiles(self._algo)
            measure.print_histogram(self._algo)
            measure.print_stats(self._algo)

        # occurrence counter: only the summary stats are printed
        if self._hedge_count_measure.is_ready():
            self._hedge_count_measure.print_stats(self._algo)

    @property
    def enter_trade_event(self) -> EventSource:
        '''Event emitted by ``calc_trade`` with every calculated enter trade, even an empty one.'''
        event: EventSource = self._enter_trade_event
        return event

    @enter_trade_event.setter
    def enter_trade_event(self, value: EventSource):
        '''Replace the event source used for enter-trade notifications.'''
        self._enter_trade_event = value

    def _monitor_cash(self) -> None:
        '''Debug logging of cash, total portfolio value and SPX price; active only in September 2025.'''
        now = self._algo.time
        if (now.year, now.month) != (2025, 9):
            return
        self._algo.log(f"{self._algo.time} - 5 min before calc_trade; cash: {self._algo.portfolio.cash:,.0f}; total_portfolio_value: {self._algo.portfolio.total_portfolio_value:,.0f}; index price: {self._algo.securities['spx'].price}")

    def _monitor_delta_(self) -> None:
        '''Alias of ``_monitor_delta`` (presumably so it can be scheduled as a second, separate event).'''
        return self._monitor_delta()

    def _monitor_delta(self) -> None:
        '''
        Log the aggregate delta of invested index options and futures.

        Each option contributes greeks.delta * contract multiplier * held quantity, using
        the contract from the current slice's option chains (missing chains/contracts are
        logged and skipped). Each future contributes multiplier * quantity.
        '''
        portfolio = self._algo.portfolio
        securities = self._algo.securities

        def invested_non_canonical(security_type) -> List[Symbol]:
            # non-canonical symbols of the given type that currently hold a position
            return [
                sym for sym, _ in portfolio.items()
                if sym.security_type == security_type and not sym.is_canonical() and portfolio[sym].invested
            ]

        portfolio_options: List[Symbol] = invested_non_canonical(SecurityType.INDEX_OPTION)
        portfolio_futures: List[Symbol] = invested_non_canonical(SecurityType.FUTURE)

        options_delta: float = 0.
        if portfolio_options:
            option_chains = self._algo.current_slice.option_chains

            for opt_symbol in portfolio_options:
                canonical = opt_symbol.canonical
                if canonical not in option_chains:
                    self._algo.log(f'BaseTradeManager_._monitor_delta - {opt_symbol.canonical} not in slice option chains')
                    continue

                contracts: OptionContracts = option_chains.get(canonical).contracts
                if opt_symbol not in contracts:
                    self._algo.log(f'BaseTradeManager_._monitor_delta - {opt_symbol} not in option contracts')
                    continue

                contract_delta: float = contracts.get(opt_symbol).greeks.delta
                multiplier: float = securities[opt_symbol].symbol_properties.contract_multiplier
                options_delta += contract_delta * multiplier * portfolio[opt_symbol].quantity

        # a future has a delta of 1 per unit of underlying
        futures_delta: float = 0.
        for future in portfolio_futures:
            futures_delta += 1 * securities[future].symbol_properties.contract_multiplier * portfolio[future].quantity
        
        self._algo.log(f'total options delta: {options_delta}; total futures delta: {futures_delta}')

    def _clear_trade_history(self):
        '''Delete recorded trades older than ``trade_history_length`` days before the current algorithm time.'''
        history_window = timedelta(days=self._trade_history_length)
        self._trades.delete_all_trades_before_date(self._algo.time - history_window)


    @property
    def trades_hist(self):
        '''The ``TradeCollection`` holding every trade recorded by this manager.'''
        collection = self._trades
        return collection
    
    def _subscribe_to_option(self, symbol):
        """Subscribe to one option contract at ``self._resolution``.

        Contracts on an index underlying go through ``AddIndexOptionContract``.
        Contracts on an equity underlying go through ``AddOptionContract``.

        Returns:
            Whatever the chosen ``QCAlgorithm`` add-contract call returns.

        Raises:
            NotImplementedError: for any other underlying security type.
        """
        underlying_type = symbol.underlying.SecurityType

        if underlying_type == SecurityType.Index:
            return self._algo.AddIndexOptionContract(symbol=symbol, resolution=self._resolution)

        if underlying_type == SecurityType.Equity:
            return self._algo.AddOptionContract(symbol=symbol, resolution=self._resolution)

        raise NotImplementedError
    

    def calc_trade(self):
        """Compute the next entry trade and stage it for ``trade()``.

        Returns immediately when a filter is configured and ``filter()`` rejects.
        Otherwise the enter strategy's ``calc_trade()`` result is handled as follows:
        - It is always emitted on ``self._enter_trade_event`` (chart update), even when empty.
        - If it is non-empty, it is stored in ``self._current_trade`` and
          ``self._trade_updated`` is set, so the next ``trade()`` call executes it.
        """
        if self._filter is not None and not self._filter.filter():
            return

        trade_data = self._enter_strategy.calc_trade()

        # Emit before the emptiness check so the chart update also sees empty results.
        self._enter_trade_event.emit(trade_data)

        if len(trade_data) == 0:
            self._algo.log('BaseTradeManager_.calc_trade: no enter trade data found')
            return

        self._current_trade = trade_data
        self._trade_updated = True


    def _make_trade(self, trade_data):
        """Submit market orders for the legs of an entry trade and build its record.

        A leg is skipped when:
        - its security is not in ``algo.securities`` (only a total count is logged), or
        - its subscription resolution is not ``Resolution.HOUR`` (logged per leg).
        The wall-clock time of each ``market_order`` call is fed to
        ``self._make_trade_measures``.

        Args:
            trade_data: entry legs, iterated via ``.items()`` as
                ``(option symbol, signed quantity)`` pairs.

        Returns:
            A ``TradeRecord`` for the whole ``trade_data``, stamped with the current
            algorithm time. Skipped legs are included in the record. Returns
            ``None`` (after logging) when ``trade_data`` is empty.
        """
        if len(trade_data) == 0:
            self._algo.log('BaseTradeManager_._make_trade - no enter trade data found')
            return None

        # Whole-method timer. Per-order timing below uses its own variable, so it
        # cannot overwrite this start point.
        start_time = time.time()

        securities = self._algo.securities
        not_subscribed_count = 0

        for option, quantity in trade_data.items():
            if not securities.contains_key(option):
                not_subscribed_count += 1
                continue

            # TODO remove: only hourly subscriptions are traded.
            resolution = securities[option].resolution
            if resolution != Resolution.HOUR:
                self._algo.log(f'BaseTradeManager_._make_trade - no HOUR subscription: {option}; res: {resolution}')
                continue

            order_start = time.time()
            self._algo.market_order(symbol=option, quantity=quantity)
            self._make_trade_measures.update(time.time() - order_start)

        if not_subscribed_count != 0:
            self._algo.log('BaseTradeManager_._make_trade - no subscriptions for ' + str(not_subscribed_count) + ' options')

        # NOTE(review): the record uses the full trade_data, including skipped legs.
        trade_id = self._trade_id_generator.gen_trade_id(trade_data)
        current_trade = TradeRecord(trade_id, trade_data, self._algo.time)

        elapsed_time = time.time() - start_time
        # self._algo.log(f"BaseTradeManager._make_trade() which is subroutine of .trade() {elapsed_time:.4f} seconds for trade_data len: {len(trade_data)}")

        return current_trade


    def trade(self) -> None:
        """Execute the entry trade staged by ``calc_trade`` and record it.

        Only logs and returns unless ``calc_trade`` has set ``self._trade_updated``.
        The record is added to the trade history when both of these hold:
        - ``self._trade_history_length != 0``, and
        - ``_make_trade`` actually produced a record (it returns ``None`` for
          empty data).
        The pending-trade flag is cleared afterwards.
        """
        if not self._trade_updated:
            self._algo.log('BaseTradeManager_.trade - trade data not updated')
            return

        current_trade = self._make_trade(self._current_trade)
        # Never store a None record in the history.
        if self._trade_history_length != 0 and current_trade is not None:
            self._trades.add(current_trade)
        self._trade_updated = False
        

    def _make_exit_trade(self, trade_data: TradeData) -> Optional[TradeRecord]:
        """Submit a market order for every leg of an exit trade and build its record.

        The wall-clock time of each ``market_order`` call is fed to
        ``self._make_exit_trade_measures``.

        Args:
            trade_data: legs to trade, iterated via ``.items()`` as
                ``(symbol, signed quantity)`` pairs.

        Returns:
            A ``TradeRecord`` for ``trade_data``, stamped with the current algorithm
            time. Returns ``None`` (after logging) when ``trade_data`` is empty.
        """
        if len(trade_data) == 0:
            self._algo.log('BaseTradeManager_._make_exit_trade - no exit trade data found')
            return None

        for option, quantity in trade_data.items():
            order_start = time.time()
            self._algo.market_order(symbol=option, quantity=quantity)
            self._make_exit_trade_measures.update(time.time() - order_start)

        trade_id = self._trade_id_generator.gen_trade_id(trade_data)
        return TradeRecord(trade_id, trade_data, self._algo.time)


    def exit_trades(self):
        """Ask the exit strategy for closing orders and, if there are any, execute and record them.

        The object returned by the exit strategy's ``calc_trade()`` is executed via
        ``_make_exit_trade`` only when its ``trade_data`` is non-empty. The
        resulting record is then added to the trade history.
        """
        exit_data = self._exit_strategy.calc_trade()

        if len(exit_data.trade_data) == 0:
            return

        self._trades.add(self._make_exit_trade(exit_data))


    def make_hedge_trade(self, trade_data: TradeData) -> None:
        """Submit asynchronous market orders for a hedge trade and record it.

        Each leg is sent with ``asynchronous=True``. Two measures are updated:
        - ``self._make_hedge_trade_measures`` gets the wall-clock time of every
          ``market_order`` call;
        - ``self._hedge_count_measure`` gets the number of legs.
        Unlike ``trade()``, the record is added to the trade history regardless of
        ``self._trade_history_length``.

        Args:
            trade_data: hedge legs, iterated via ``.items()`` as
                ``(symbol, signed quantity)`` pairs.
        """
        if len(trade_data) == 0:
            self._algo.log('BaseTradeManager_.make_hedge_trade: no hedge trade data found')
            return

        # TODO
        # record how many hedge trades are done

        # TODO
        # mitigate trading:
        # rank by abs change in delta
        # threshold for each stock
        # threshold for delta delta

        # TODO
        # change universe renewal quarterly - schedule month/quarterly/6m
        # from 2021

        for symbol, q in trade_data.items():
            order_start = time.time()
            self._algo.market_order(symbol=symbol, quantity=q, asynchronous=True)
            self._make_hedge_trade_measures.update(time.time() - order_start)

        self._hedge_count_measure.update(len(trade_data))

        trade_id = self._trade_id_generator.gen_trade_id(trade_data)
        current_trade = TradeRecord(trade_id, trade_data, self._algo.time)
        self._trades.add(current_trade)


    def _garbage_collect(self) -> None:
        """Unsubscribe expired option contracts that are no longer held, then log counts.

        For every symbol in the trade history's positions, the contract is removed
        via ``remove_option_contract`` when all of these hold:
        - the portfolio is not invested in it;
        - it is an option (equity, index or future option);
        - its ``id.date`` (the expiry for option symbols) is before the algorithm time.
        Afterwards, the loop timing and the securities / active-securities /
        subscription counts are logged.
        """
        option_types = [SecurityType.OPTION, SecurityType.INDEX_OPTION, SecurityType.FUTURE_OPTION]

        began = time.time()
        for sym in self._trades.positions.keys():
            if self._algo.portfolio[sym].invested:
                continue
            if sym.id.security_type not in option_types:
                continue
            if not sym.id.date < self._algo.time:
                continue
            self._algo.remove_option_contract(sym)
        took = time.time() - began

        self._algo.log(f"BaseTradeManager_._garbage_collect - iteration of position traceback: {took:.4f} seconds")
        self._algo.log(f"BaseTradeManager_._garbage_collect - count of securities {self._algo.securities.count}")
        self._algo.log(f"BaseTradeManager_._garbage_collect - count of active_securities {self._algo.active_securities.count}")
        self._algo.log(f"BaseTradeManager_._garbage_collect - subscription_manager count: {self._algo.subscription_manager.count}")
# region imports
from AlgorithmImports import *
from abc import ABC, abstractmethod
# endregion

class Function(ABC):
    """Abstract scalar function evaluated against the running algorithm.

    Subclasses implement ``eval`` and can read live algorithm state via ``self._algo``.
    """

    def __init__(self, algo: QCAlgorithm):
        # Algorithm handle shared with subclasses.
        self._algo: QCAlgorithm = algo

    @abstractmethod
    def eval(self) -> float:
        """Return the function's current value."""

class VIXLinearFunction(Function):
    """Linear function of the VIX level: ``intercept + slope * vix``.

    Args:
        algo: running algorithm. Its current slice, or else its securities,
            supplies the VIX price.
        vix: symbol whose price is used as the VIX level. Presumably already
            subscribed by the caller (confirm).
        intercept: constant term.
        slope: coefficient applied to the VIX price.
    """

    def __init__(self, algo: QCAlgorithm, vix: Symbol, intercept: float, slope: float):
        super().__init__(algo)
        self._vix: Symbol = vix
        self._intercept = intercept
        self._slope = slope

    def eval(self) -> float:
        """Return ``intercept + slope * price`` for the VIX symbol.

        Uses the price in the algorithm's current slice when the slice carries data
        for the symbol. Otherwise it falls back to the security's last known price.
        Indexing a slice that has no entry for the symbol would raise instead.
        """
        current = self._algo.current_slice
        if current is not None and current.contains_key(self._vix):
            vix: float = current[self._vix].price
        else:
            vix = self._algo.securities[self._vix].price
        return self._intercept + self._slope * vix
# region imports
from AlgorithmImports import *
from OptionsTrading.option_providers import *
from OptionsTrading.trade_managers import *
from OptionsTrading.exit_strategies import *
from OptionsTrading.enter_strategies import *
from OptionsTrading.data import ThresholdType, Direction
from OptionsTrading.filters import *
from OptionsTrading.vix_functions import VIXLinearFunction
from OptionsTrading.delta_hedge import DeltaHedge, EquityOptionsDeltaHedge, IndexDeltaHedge
from dispersion_option_provider import *
from option_universe import *
from stats import *
# endregion

def sector_dispersion_(algo): 
    """Build the hourly sector-ETF vs SPX option dispersion trade manager.

    Side effects on ``algo``:
    - subscribes SPX (index) and XLP (equity) at hourly resolution;
    - adds the E-mini S&P 500 continuous future (hourly, extended market
      hours, BACKWARDS_RATIO normalization), which is passed to the index
      delta hedge;
    - sets the algorithm name to ``'Sector_Dispersion_'``.

    Option universes:
      * the eleven XL* tickers via ``StaticStocksOptionUniverse``, sized by
        ``SectorETFMarketCapQuantityCalculator`` with ``Direction.LONG``;
      * SPX via ``SymbolOptionUniverse``, sized by
        ``FractionOfHoldingQuantityCalculator`` (``notional_frac=4/52``,
        ``Direction.SHORT``, ``holding_calculator=portfolio_value``).
    Both universes share a ``ChainOptionProviders`` built from
    ``ClosestExpiryOptionProvider(algo, 180)`` and
    ``ClosestStrikeOptionProvider(algo, 1.0)``.

    Schedule (SPX calendar, minutes before market close):
      * calc trade:   ``WeekStart(spx, 2)``, 65
      * trade:        ``WeekStart(spx, 2)``, 60
      * exit check:   every day, 55
      * delta hedges: every day, 60

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager_``.
    """
    spx = algo.add_index('SPX', Resolution.HOUR).symbol
    xlp = algo.add_equity('XLP', Resolution.HOUR).symbol

    # Offset passed to DateRules.WeekStart for the calc/trade rules below.
    day_offset: int = 2
    # Fires 5 minutes before trade_date_rule. Presumably this drives calc_trade,
    # which stages the trade that the trade rule executes.
    calc_trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx, day_offset), 
        algo.TimeRules.BeforeMarketClose(spx, 65))
    
    trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx, day_offset), 
        algo.TimeRules.BeforeMarketClose(spx, 60))
    
    exit_date_rule = ScheduleRule(
        algo.DateRules.EveryDay(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 55))

    # date rules only matter for universe selection scheduling (the time rule is None)
    universe_selection_date_rule = ScheduleRule(
        algo.date_rules.every_day(),
        None 
    )

    trade_id_generator = IncrementIdGenerator(0)

    #trade filter (None = no filtering)
    filter = None
    # exit_strat = ExitAtIndexExpiry(algo=algo)
    exit_strat = ExitByPortfolioValue(
        algo=algo, 
        underlying_symbols=[spx, xlp], 
        option_right=OptionRight.PUT, 
        option_direction=Direction.SHORT
    )
    # exit_strat = None
    
    providers_chain = [ClosestExpiryOptionProvider(algo, 180), ClosestStrikeOptionProvider(algo, 1.0)]
    universe_option_provider = ChainOptionProviders(algo, providers_chain)
    # NOTE(review): index_option_provider is not referenced below.
    index_option_provider = universe_option_provider
    # Passed to the enter strategy; contract selection happens in the universes.
    strategy_option_provider = PassThroughOptionProvider(algo)

    index_quantity_calculator = FractionOfHoldingQuantityCalculator(
        algo, 
        notional_frac=4/52, 
        direction=Direction.SHORT,
        # holding_calculator=portfolio_cash
        holding_calculator=portfolio_value
    )
    stocks_quantity_calculator = SectorETFMarketCapQuantityCalculator(
        algo, 
        Direction.LONG, 
        # target_fn=lambda algo, trade_data: target_notional(
        #     algo, 
        #     trade_data, 
        #     holding_calculator=portfolio_cash, 
        #     notional_fraction_target=4/52
        # )
    )

    option_universe_list: List[OptionUniverse] = [
        StaticStocksOptionUniverse(
            algo=algo, 
            stock_tickers=['XLB', 'XLY', 'XLF', 'XLRE', 'XLP', 'XLV', 'XLU', 'XLC', 'XLE', 'XLI', 'XLK'],
            resolution=Resolution.HOUR, 
            option_provider=universe_option_provider, 
            universe_selection_date_rule=universe_selection_date_rule, 
            quantity_calculator=stocks_quantity_calculator
        ), 
        SymbolOptionUniverse(
            algo=algo, 
            symbol=spx, 
            resolution=Resolution.HOUR, 
            option_provider=universe_option_provider, 
            universe_selection_date_rule=universe_selection_date_rule, 
            quantity_calculator=index_quantity_calculator
        )
    ]
    
    # enter strategy; entanglement_calc_fn wraps calc_trade_greek on 'gamma'
    # (how the strategy uses it is defined in SectorDispersionEnterStrategy_)
    enter_strat = SectorDispersionEnterStrategy_(
        algo=algo, 
        option_provider=strategy_option_provider, 
        quantity_calculator=None, 
        option_universe_list=option_universe_list, 
        # entanglement_calc_fn=lambda algo, trade_data, contract_multiplier: calc_trade_notional(algo, trade_data, contract_multiplier)
        entanglement_calc_fn=lambda algo, trade_data, contract_multiplier: calc_trade_greek(algo, trade_data, contract_multiplier, greek_name='gamma', is_new_trade=True, sender_str='entanglement_calc_fn')
    )

    # Continuous E-mini S&P 500 future, handed to IndexDeltaHedge below.
    future: Future = algo.add_future(
        Futures.Indices.SP_500_E_MINI,
        extended_market_hours=True,
        data_mapping_mode=DataMappingMode.LAST_TRADING_DAY, #OPEN_INTEREST
        data_normalization_mode=DataNormalizationMode.BACKWARDS_RATIO,
        contract_depth_offset=0, 
        resolution=Resolution.HOUR
    )
    future.set_filter(0, 60)

    # Both hedges run every day 60 minutes before the SPX close.
    delta_hedges: List[DeltaHedge] = [
        EquityOptionsDeltaHedge(
            algo, 
            trade_manager=None, 
            schedule_rule=ScheduleRule(
                algo.date_rules.every_day(spx), 
                algo.time_rules.before_market_close(spx, 60)
                # algo.time_rules.every(TimeSpan.FromHours(2)) # starting at midnight in case of hours/minutes
            ), 
            use_precalculated_delta=True
        ),
        IndexDeltaHedge(
            algo, 
            underlying_index=spx, 
            continuous_future=future, 
            trade_manager=None, 
            schedule_rule=ScheduleRule(
                algo.date_rules.every_day(spx), 
                algo.time_rules.before_market_close(spx, 60)
                # algo.time_rules.every(TimeSpan.FromHours(2)) # starting at midnight in case of hours/minutes
            ), 
            use_precalculated_delta=True
        )
    ]
    # delta_hedges = []

    algo.set_name('Sector_Dispersion_')
    
    # trade_manager = BaseTradeManager(
    trade_manager = BaseTradeManager_(
        algo=algo,
        enter_strategy=enter_strat, 
        calc_trade_date_rule=calc_trade_date_rule, 
        trade_date_rule=trade_date_rule,
        subscription_resolution=Resolution.HOUR, 
        trade_id_generator=trade_id_generator,
        trade_history_length=-1, 
        exit_strategy=exit_strat,
        exit_date_rule=exit_date_rule, 
        delta_hedge_strategies=delta_hedges, 
        gc_date_rule=None, 
        filter=filter
    )

    # charts
    # stats_date_rule = ScheduleRule(
    #     algo.DateRules.EveryDay(spx), 
    #     algo.TimeRules.BeforeMarketClose(spx, 30))
        
    # GreeksChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager, 
    #     greeks=['gamma', 'vega', 'theta', 'delta'], 
    #     greeks_type=GreeksType.PERCENTAGE, 
    #     portfolio_type=PortfolioType.LAST_TRADE
    # )

    # GreeksChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager, 
    #     greeks=['gamma', 'vega', 'theta', 'delta'], 
    #     greeks_type=GreeksType.PERCENTAGE, 
    #     portfolio_type=PortfolioType.WHOLE_PORTFOLIO, 
    #     schedule_rule=stats_date_rule
    # )

    # NotionalChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager, 
    #     portfolio_type=PortfolioType.LAST_TRADE
    # )

    # NotionalChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager, 
    #     portfolio_type=PortfolioType.WHOLE_PORTFOLIO, 
    #     schedule_rule=stats_date_rule
    # )
    
    # ExpenditureChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager
    # )
    
    # SectorWeightChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager
    # )

    return trade_manager

def dispersion_(algo): 
    """Build the hourly SPX vs SPY-constituents option dispersion trade manager.

    Side effects on ``algo``:
    - subscribes SPX (index) at hourly resolution;
    - adds the E-mini S&P 500 continuous future (hourly, extended market
      hours, BACKWARDS_RATIO normalization), which is passed to the index
      delta hedge;
    - sets the algorithm name to ``'Dispersion_'``.

    Option universes:
      * SPX via ``SymbolOptionUniverse``, sized by
        ``FractionOfHoldingQuantityCalculator`` (``notional_frac=4/52``,
        ``Direction.SHORT``, ``holding_calculator=portfolio_value``);
      * 50 SPY constituents via ``ETFConstituentsOptionUniverse``, sized by
        ``UnderlyingMarketCapQuantityCalculator`` with ``Direction.LONG``.
    Both universes share a ``ChainOptionProviders`` built from
    ``ClosestExpiryOptionProvider(algo, 180)`` and
    ``ClosestStrikeOptionProvider(algo, 1.0)``. Exits use ``ExitAtIndexExpiry``.

    Schedule (SPX calendar, minutes before market close):
      * calc trade:   ``WeekStart(spx, 2)``, 65
      * trade:        ``WeekStart(spx, 2)``, 60
      * exit check:   every day, 55
      * delta hedges: every day, 60

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager_``.
    """
    spx = algo.add_index('SPX', Resolution.HOUR).symbol

    # Offset passed to DateRules.WeekStart for the calc/trade rules below.
    day_offset: int = 2
    calc_trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx, day_offset), 
        algo.TimeRules.BeforeMarketClose(spx, 65))
    
    trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx, day_offset), 
        algo.TimeRules.BeforeMarketClose(spx, 60))
    
    exit_date_rule = ScheduleRule(
        algo.DateRules.EveryDay(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 55))

    # date rules only matter for universe selection scheduling (the time rule is None)
    universe_selection_date_rule = ScheduleRule(
        algo.date_rules.every_day(),
        None 
    )

    trade_id_generator = IncrementIdGenerator(0)

    # trade filter (None = no filtering)
    filter = None
    # exit_strat = None
    exit_strat = ExitAtIndexExpiry(algo=algo)
    
    providers_chain = [ClosestExpiryOptionProvider(algo, 180), ClosestStrikeOptionProvider(algo, 1.0)]
    universe_option_provider = ChainOptionProviders(algo, providers_chain)
    # NOTE(review): index_option_provider is not referenced below.
    index_option_provider = universe_option_provider

    index_quantity_calculator = FractionOfHoldingQuantityCalculator(
        algo, 
        notional_frac=4/52, 
        direction=Direction.SHORT, 
        holding_calculator=portfolio_value
    )
    stocks_quantity_calculator = UnderlyingMarketCapQuantityCalculator(algo, Direction.LONG)

    option_universe_list: List[OptionUniverse] = [
        SymbolOptionUniverse(
            algo=algo, 
            symbol=spx, 
            resolution=Resolution.HOUR, 
            option_provider=universe_option_provider, 
            universe_selection_date_rule=universe_selection_date_rule, 
            quantity_calculator=index_quantity_calculator
        ), 
        ETFConstituentsOptionUniverse(
            algo=algo, 
            etf_ticker='SPY', 
            resolution=Resolution.HOUR, 
            option_provider=universe_option_provider, 
            universe_selection_date_rule=universe_selection_date_rule, 
            quantity_calculator=stocks_quantity_calculator, 
            no_of_stocks=50
        ), 
    ]
    
    # enter strategy; entanglement_calc_fn wraps calc_trade_greek on 'gamma'
    # (how the strategy uses it is defined in DispersionEnterStrategy_)
    enter_strat = DispersionEnterStrategy_(
        algo=algo, 
        option_provider=None, 
        quantity_calculator=None, 
        option_universe_list=option_universe_list, 
        # entanglement_calc_fn=lambda algo, trade_data, contract_multiplier: calc_trade_notional(algo, trade_data, contract_multiplier)
        entanglement_calc_fn=lambda algo, trade_data, contract_multiplier: calc_trade_greek(algo, trade_data, contract_multiplier, greek_name='gamma', is_new_trade=True, sender_str='entanglement_calc_fn')
    )

    # Continuous E-mini S&P 500 future, handed to IndexDeltaHedge below.
    future: Future = algo.add_future(
        Futures.Indices.SP_500_E_MINI,
        extended_market_hours=True,
        data_mapping_mode=DataMappingMode.LAST_TRADING_DAY, #OPEN_INTEREST
        data_normalization_mode=DataNormalizationMode.BACKWARDS_RATIO,
        contract_depth_offset=0, 
        resolution=Resolution.HOUR
    )
    future.set_filter(0, 60)

    # Both hedges run every day 60 minutes before the SPX close.
    delta_hedges: List[DeltaHedge] = [
        EquityOptionsDeltaHedge(
            algo, 
            trade_manager=None, 
            schedule_rule=ScheduleRule(
                algo.date_rules.every_day(spx), 
                algo.time_rules.before_market_close(spx, 60)
                # algo.time_rules.every(TimeSpan.FromHours(2)) # starting at midnight in case of hours/minutes
            ), 
            use_precalculated_delta=True

        ),
        IndexDeltaHedge(
            algo, 
            underlying_index=spx, 
            continuous_future=future, 
            trade_manager=None, 
            schedule_rule=ScheduleRule(
                algo.date_rules.every_day(spx), 
                algo.time_rules.before_market_close(spx, 60)
                # algo.time_rules.every(TimeSpan.FromHours(2)) # starting at midnight in case of hours/minutes
            ), 
            use_precalculated_delta=True
        )
    ]
    # delta_hedges = []

    algo.set_name('Dispersion_')
    
    # trade_manager = BaseTradeManager(
    trade_manager = BaseTradeManager_(
        algo=algo,
        enter_strategy=enter_strat, 
        calc_trade_date_rule=calc_trade_date_rule, 
        trade_date_rule=trade_date_rule,
        subscription_resolution=Resolution.HOUR, 
        trade_id_generator=trade_id_generator,
        trade_history_length=-1, 
        exit_strategy=exit_strat,
        exit_date_rule=exit_date_rule, 
        delta_hedge_strategies=delta_hedges, 
        gc_date_rule=None, 
        filter=filter
    )

    # charts
    # stats_date_rule = ScheduleRule(
    #     algo.DateRules.EveryDay(spx), 
    #     algo.TimeRules.BeforeMarketClose(spx, 30))
        
    # GreeksChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager, 
    #     greeks=['gamma', 'vega', 'theta', 'delta'], 
    #     greeks_type=GreeksType.PERCENTAGE, 
    #     portfolio_type=PortfolioType.LAST_TRADE
    # )

    # GreeksChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager, 
    #     greeks=['gamma', 'vega', 'theta', 'delta'], 
    #     greeks_type=GreeksType.PERCENTAGE, 
    #     portfolio_type=PortfolioType.WHOLE_PORTFOLIO, 
    #     schedule_rule=stats_date_rule
    # )

    # NotionalChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager, 
    #     portfolio_type=PortfolioType.LAST_TRADE
    # )

    # NotionalChart_(
    #     algo=algo, 
    #     trade_manager=trade_manager, 
    #     portfolio_type=PortfolioType.WHOLE_PORTFOLIO, 
    #     schedule_rule=stats_date_rule
    # )

    return trade_manager

def straddle_(algo): 
    """Build the ``'Straddle_'`` trade manager over options on 50 SPY constituents.

    Side effects on ``algo``: subscribes SPX (index) at hourly resolution and
    sets the algorithm name to ``'Straddle_'``.

    Configuration:
    - Option universe: ``ETFConstituentsOptionUniverse`` (SPY, 50 stocks)
      using ``ClosestExpiryOptionProvider(algo, 180)`` ->
      ``ClosestStrikeOptionProvider(algo, 1.0)``. Its options are reset daily
      at 23:59.
    - Enter strategy: ``BaseEnterStrategy_`` with a pass-through option
      provider and a constant quantity of 1.
    - No exit strategy, no delta hedges, no filter.

    Schedule (SPX calendar, minutes before market close): calc trade at
    ``WeekStart(spx)`` 65, trade at ``WeekStart(spx)`` 60, exit rule every
    day at 60.

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager_``.
    """
    spx = algo.add_index('SPX', Resolution.HOUR).symbol

    calc_trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 65))
    
    trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 60))
    
    # NOTE(review): exit_strat is None below, so this rule may have nothing to run.
    exit_date_rule = ScheduleRule(
        algo.DateRules.EveryDay(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 60))

    # date rules only matter for universe selection scheduling (the time rule is None)
    universe_selection_date_rule = ScheduleRule(
        algo.date_rules.every_day(),
        None 
    )
    reset_universe_options_date_rule = ScheduleRule(
        algo.date_rules.every_day(spx), 
        algo.time_rules.at(23, 59, 0)
    )

    trade_id_generator = IncrementIdGenerator(0)

    #trade filter (None = no filtering)
    filter = None
    exit_strat = None
    
    # single stock legs
    providers_chain = [ClosestExpiryOptionProvider(algo, 180), ClosestStrikeOptionProvider(algo, 1.0)]
    universe_option_provider = ChainOptionProviders(algo, providers_chain)
    option_provider = PassThroughOptionProvider(algo)
    quantity_calculator = ConstantQuantityCalculator(quantity=1)

    option_universe_list: List[OptionUniverse] = [
        ETFConstituentsOptionUniverse(
            algo=algo, 
            etf_ticker='SPY', 
            resolution=Resolution.HOUR, 
            option_provider=universe_option_provider, 
            universe_selection_date_rule=universe_selection_date_rule, 
            reset_universe_options_date_rule=reset_universe_options_date_rule, 
            quantity_calculator=None, 
            no_of_stocks=50
        )
    ]
    
    # enter strategy
    enter_strat = BaseEnterStrategy_(
        algo=algo, 
        option_provider=option_provider, 
        quantity_calculator=quantity_calculator, 
        option_universe_list=option_universe_list
    )

    # future: Future = algo.add_future(
    #     Futures.Indices.SP_500_E_MINI,
    #     extended_market_hours=True,
    #     data_mapping_mode=DataMappingMode.LAST_TRADING_DAY, #OPEN_INTEREST
    #     data_normalization_mode=DataNormalizationMode.BACKWARDS_RATIO,
    #     contract_depth_offset=0, 
    #     resolution=Resolution.MINUTE
    # )
    # future.set_filter(0, 60)

    # delta_hedges: List[DeltaHedge] = [
    #     EquityOptionsDeltaHedge(
    #         algo, 
    #         trade_manager=None, 
    #         schedule_rule=ScheduleRule(
    #             algo.date_rules.every_day(spx), 
    #             algo.time_rules.before_market_close(spx, 60)
    #             # algo.time_rules.every(TimeSpan.FromHours(2)) # starting at midnight in case of hours/minutes
    #         )
    #     ),
    #     IndexDeltaHedge(
    #         algo, 
    #         underlying_index=spx, 
    #         continuous_future=future, 
    #         trade_manager=None, 
    #         schedule_rule=ScheduleRule(
    #             algo.date_rules.every_day(spx), 
    #             algo.time_rules.before_market_close(spx, 60)
    #             # algo.time_rules.every(TimeSpan.FromHours(2)) # starting at midnight in case of hours/minutes
    #         )
    #     )
    # ]
    # Delta hedging is disabled for this setup.
    delta_hedges = []

    algo.set_name('Straddle_')
    
    # trade_manager = BaseTradeManager(
    trade_manager = BaseTradeManager_(
        algo=algo,
        enter_strategy=enter_strat, 
        calc_trade_date_rule=calc_trade_date_rule, 
        trade_date_rule=trade_date_rule,
        subscription_resolution=Resolution.HOUR, 
        trade_id_generator=trade_id_generator,
        trade_history_length=-1, 
        exit_strategy=exit_strat,
        exit_date_rule=exit_date_rule, 
        delta_hedge_strategies=delta_hedges, 
        gc_date_rule=None, 
        filter=filter
    )

    return trade_manager

def buy_put_by_moneyness_test(algo):
    """Build a minute-resolution test setup that buys SPX puts every day.

    Uses the legacy ``BaseEnterStrategy`` / ``BaseTradeManager`` (not the
    underscore-suffixed classes).

    Configuration:
    - Contract selection: ``PutsOptionProvider`` -> ``ClosestExpiryOptionProvider(algo, 5)``
      -> ``ClosestStrikeOptionProvider(algo, 1.)``, applied to the ``?SPX``
      canonical option.
    - Sizing: ``FractionOfCashQuantityCalculator(algo, 2/52, direction=Direction.LONG)``.
    - Schedule: calc 60 min and trade 30 min before the SPX close every day;
      garbage collection at week start, 60 min before the close.
    - No exit strategy and no filter. The exit rule reuses ``trade_date_rule``.

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager``.
    """
    algo.set_name('buy_put_by_moneyness_test')
    spx = algo.add_index('SPX', Resolution.MINUTE).symbol
    canonical_options = [Symbol.CreateCanonicalOption(spx, 'SPX', Market.USA, '?SPX')]
    
    calc_trade_date_rule = ScheduleRule(
        algo.DateRules.EveryDay(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 60)
    )
        
    trade_date_rule = ScheduleRule(
        algo.DateRules.EveryDay(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 30)
    )

    gc_date_rule = ScheduleRule(
        algo.DateRules.week_start(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 60))

    trade_id_generator = IncrementIdGenerator(0)
    filter = None
    exit_strat = None
    providers_chain = [
        PutsOptionProvider(algo), 
        ClosestExpiryOptionProvider(algo, 5),
        ClosestStrikeOptionProvider(algo, 1.)
    ]

    option_provider = ChainOptionProviders(algo, providers_chain)
    quantity_calculator = FractionOfCashQuantityCalculator(algo, 2/52, direction=Direction.LONG) 

    enter_strat = BaseEnterStrategy(
        algo, 
        option_provider, 
        quantity_calculator, 
        canonical_options
    )

    # NOTE(review): exit_date_rule reuses trade_date_rule; exit_strat is None.
    trade_calculator = BaseTradeManager(
        algo=algo,
        enter_strategy=enter_strat, 
        canonical_options= canonical_options, 
        calc_trade_date_rule=calc_trade_date_rule, 
        trade_date_rule=trade_date_rule,
        subscription_resolution=Resolution.MINUTE, 
        trade_id_generator=trade_id_generator,
        trade_history_length=-1, 
        exit_strategy=exit_strat,
        exit_date_rule=trade_date_rule, 
        gc_date_rule=gc_date_rule, 
        filter=filter
    )
    return trade_calculator

def sector_dispersion(algo): 
    """Build the legacy sector dispersion trade manager (``BaseTradeManager``).

    Side effects on ``algo``:
    - subscribes SPX (index, default resolution);
    - adds the E-mini S&P 500 continuous future at minute resolution;
    - sets the algorithm name to ``'Sector Dispersion'``.

    Configuration:
    - Both legs use
      ``DispersionOptionProvider(algo, 0.1, no_of_options=4, days_to_expiry=180)``.
    - Single-stock legs are sized by
      ``SectorETFMarketCapTargetGreekQuantityCalculator`` (``Direction.LONG``,
      ``greek_name='vega'``, ``greek_target_cash_terms=10e6``).
    - The SPX leg is sized by ``FractionOfHoldingQuantityCalculator``
      (``notional_frac=1/52``, ``Direction.SHORT``); the ETF symbol is ``'SPY'``.
    - No exit strategy and no filter. Delta hedges are constructed but then
      replaced by an empty list.

    Schedule (SPX calendar, minutes before market close):
      * calc trade at week start, 65
      * trade at week start, 60
      * exit rule every day, 60
      * garbage collection at month start, 60

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager``.
    """
    spx = algo.add_index('SPX').symbol
    canonical_option = Symbol.CreateCanonicalOption(spx, 'SPX', Market.USA, '?SPX')

    calc_trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 65))
    
    trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 60))
    
    exit_date_rule = ScheduleRule(
        algo.DateRules.EveryDay(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 60))
    trade_id_generator = IncrementIdGenerator(0)

    gc_date_rule = ScheduleRule(
        algo.DateRules.month_start(spx), 
        algo.TimeRules.BeforeMarketClose(spx, 60))

    #trade filter
    #filter = VixThresholdFilter(algo, threshold=30)
    filter = None

    exit_strat = None
    
    #enter strategy
    # NOTE(review): this ChainOptionProviders is built and then immediately
    # overwritten by DispersionOptionProvider below, so it is never used.
    providers_chain = [ClosestExpiryOptionProvider(algo, 180), ClosestStrikeOptionProvider(algo, 1.0)]
    option_provider = ChainOptionProviders(algo, providers_chain)

    option_provider = DispersionOptionProvider(algo, 0.1, no_of_options=4, days_to_expiry=180)

    # single stock legs
    single_stock_option_provider = option_provider
    # single_stock_quantity_calculator = SectorETFMarketCapQuantityCalculator(algo, Direction.LONG)
    single_stock_quantity_calculator = SectorETFMarketCapTargetGreekQuantityCalculator(algo, Direction.LONG, greek_name='vega', greek_target_cash_terms=10e6)
    # single_stock_quantity_calculator = SectorETFMarketCapTargetNotionalQuantityCalculator(algo, Direction.LONG, notional_fraction_target=1/52)
    
    # index leg
    index_option_provider = option_provider
    index_quantity_calculator = FractionOfHoldingQuantityCalculator(
        algo, 
        notional_frac=1/(52), 
        direction=Direction.SHORT
    )
    
    # start_time = time.time()
    enter_strat = SectorDispersionEnterStrategy(
        algo, 
        index_option_provider=index_option_provider,
        single_stock_option_provider=single_stock_option_provider,
        single_stock_quantity_calculator=single_stock_quantity_calculator,
        index_quantity_calculator=index_quantity_calculator,
        etf_symbol='SPY', 
        canonical_index_option=canonical_option, 
    )

    # NOTE(review): the future is still subscribed even though the hedges that
    # use it are discarded below.
    future: Future = algo.add_future(
        Futures.Indices.SP_500_E_MINI,
        extended_market_hours=True,
        data_mapping_mode=DataMappingMode.LAST_TRADING_DAY, #OPEN_INTEREST
        data_normalization_mode=DataNormalizationMode.BACKWARDS_RATIO,
        contract_depth_offset=0, 
        resolution=Resolution.MINUTE
    )
    future.set_filter(0, 60)
    
    delta_hedges: List[DeltaHedge] = [
        EquityOptionsDeltaHedge(
            algo, 
            trade_manager=None, 
            schedule_rule=ScheduleRule(
                algo.date_rules.every_day(spx), 
                algo.time_rules.before_market_close(spx, 60)
                # algo.time_rules.every(TimeSpan.FromHours(2)) # starting at midnight in case of hours/minutes
            )
        ),
        IndexDeltaHedge(
            algo, 
            underlying_index=spx, 
            continuous_future=future, 
            trade_manager=None, 
            schedule_rule=ScheduleRule(
                algo.date_rules.every_day(spx), 
                algo.time_rules.before_market_close(spx, 60)
                # algo.time_rules.every(TimeSpan.FromHours(2)) # starting at midnight in case of hours/minutes
            )
        )
    ]
    # Hedging disabled: the list above is discarded. Its constructors still run,
    # so any side effects they have on algo still happen.
    delta_hedges = []

    # elapsed_time4 = time.time() - start_time
    # algo.Log(f" Dispersion in {elapsed_time4:.4f} seconds")
    
    algo.set_name('Sector Dispersion')
    
    trade_manager = BaseTradeManager(
        algo=algo,
        enter_strategy=enter_strat, 
        canonical_options=canonical_option, 
        calc_trade_date_rule=calc_trade_date_rule, 
        trade_date_rule=trade_date_rule,
        subscription_resolution=Resolution.HOUR, 
        trade_id_generator=trade_id_generator,
        trade_history_length=-1, 
        exit_strategy=exit_strat,
        exit_date_rule=exit_date_rule, 
        delta_hedge_strategies=delta_hedges, 
        gc_date_rule=gc_date_rule, 
        filter=filter
    )

    return trade_manager

def dispersion(algo):
    """Build the trade manager for the SPX dispersion strategy.

    Legs (entered at each week start):
      * single stocks: long options, sized by ``UnderlyingMarketCapQuantityCalculator``;
      * index: short SPX options, sized at 1/52 of holdings notional.
    Both legs select contracts with one shared ``DispersionOptionProvider``
    (0.1 delta wing, 4 strikes, expiry closest to 180 days). No exit strategy,
    no trade filter and no delta hedges are configured.

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager``.
    """
    spx = algo.add_index('SPX').symbol
    spx_canonical = Symbol.CreateCanonicalOption(spx, 'SPX', Market.USA, '?SPX')

    # Weekly: compute the trade 65 min before the close, place it 60 min before the close.
    calc_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 65))
    enter_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 60))
    # Daily exit check and monthly gc run, both 60 min before the close.
    exit_rule = ScheduleRule(algo.DateRules.EveryDay(spx), algo.TimeRules.BeforeMarketClose(spx, 60))
    cleanup_rule = ScheduleRule(algo.DateRules.month_start(spx), algo.TimeRules.BeforeMarketClose(spx, 60))

    id_generator = IncrementIdGenerator(0)

    # One provider picks the contracts for both the stock legs and the index leg.
    shared_provider = DispersionOptionProvider(algo, 0.1, no_of_options=4, days_to_expiry=180)

    # Arguments are evaluated left to right, so the quantity calculators are
    # constructed in the same order as the legs are listed.
    enter_strategy = DispersionEnterStrategy(
        algo,
        single_stock_option_provider=shared_provider,
        single_stock_quantity_calculator=UnderlyingMarketCapQuantityCalculator(algo, Direction.LONG),
        index_option_provider=shared_provider,
        index_quantity_calculator=FractionOfHoldingQuantityCalculator(
            algo,
            notional_frac=1 / 52,
            direction=Direction.SHORT,
        ),
        no_of_stocks=5,
        etf_symbol='SPY',
        canonical_index_option=spx_canonical,
    )

    algo.set_name('Dispersion')

    return BaseTradeManager(
        algo=algo,
        enter_strategy=enter_strategy,
        canonical_options=spx_canonical,
        calc_trade_date_rule=calc_rule,
        trade_date_rule=enter_rule,
        subscription_resolution=Resolution.HOUR,
        trade_id_generator=id_generator,
        trade_history_length=-1,
        exit_strategy=None,
        exit_date_rule=exit_rule,
        delta_hedge_strategies=[],
        gc_date_rule=cleanup_rule,
        filter=None,
    )

def buy_put_by_moneyness(algo):
    """Build a trade manager that buys SPX puts at 0.85 moneyness, ~180 days out.

    Each week start the trade is computed 60 min and entered 30 min before the
    close, spending 2/52 of cash. The exit rule reuses the weekly entry
    schedule; the gc rule runs at month start.

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager``.
    """
    algo.set_name('buy_put_85')
    spx = algo.add_index('SPX').symbol
    canonicals = [Symbol.CreateCanonicalOption(spx, 'SPX', Market.USA, '?SPX')]

    calc_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 60))
    enter_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 30))
    gc_rule = ScheduleRule(algo.DateRules.month_start(spx), algo.TimeRules.BeforeMarketClose(spx, 60))

    id_generator = IncrementIdGenerator(0)

    # puts -> expiry closest to 180 days -> strike closest to 0.85 moneyness
    put_selector = ChainOptionProviders(algo, [
        PutsOptionProvider(algo),
        ClosestExpiryOptionProvider(algo, 180),
        ClosestStrikeOptionProvider(algo, 0.85),
    ])
    sizing = FractionOfCashQuantityCalculator(algo, 2 / 52, direction=Direction.LONG)
    entry = BaseEnterStrategy(algo, put_selector, sizing, canonicals)

    return BaseTradeManager(
        algo=algo,
        enter_strategy=entry,
        canonical_options=canonicals,
        calc_trade_date_rule=calc_rule,
        trade_date_rule=enter_rule,
        subscription_resolution=Resolution.MINUTE,
        trade_id_generator=id_generator,
        trade_history_length=-1,
        exit_strategy=None,
        exit_date_rule=enter_rule,
        gc_date_rule=gc_rule,
        filter=None,
    )

def sell_call_by_vix_regression(algo):
    """Build a trade manager that sells ~30-day SPX calls at a VIX-regressed moneyness.

    A ``VIXMoneynessRegressionStorage`` is fed, every day 60 min before the
    close, with options selected as: puts -> expiry closest to 5 days -> strike
    within 0.7-1.0 of the underlying -> delta closest to 0.2. It is regressed
    at each month start (min_period=21, expanding window) against VIX through
    ``VIXLinearFunction``. Each week start the call strike comes from
    ``VIXMoneynessRegressionOptionProvider``. Positions are delta hedged daily
    with E-mini S&P 500 futures via ``IndexDeltaHedge``.

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager``.
    """
    spx = algo.add_index('SPX').symbol
    vix = algo.add_index('VIX', Resolution.MINUTE).symbol
    canonicals = [Symbol.create_canonical_option(spx, 'SPX', Market.USA, '?SPX')]

    # --- VIX -> moneyness regression ---
    store_rule = ScheduleRule(
        algo.date_rules.every_day(spx),
        algo.time_rules.before_market_close(spx, 60),
    )
    refit_rule = ScheduleRule(
        algo.date_rules.month_start(spx),
        algo.time_rules.before_market_close(spx, 30),
    )
    sample_selector = ChainOptionProviders(algo, [
        PutsOptionProvider(algo),
        ClosestExpiryOptionProvider(algo, 5),
        StrikeRangeFromUnderlyingOptionProvider(algo, 0.7, 1.0),
        ClosestDeltaOptionsProvider_(algo, 0.2),
    ])
    vix_fit = VIXLinearFunction(algo, vix)
    regression = VIXMoneynessRegressionStorage(
        algo=algo,
        option_provider=sample_selector,
        vix=vix,
        store_date_rule=store_rule,
        regress_date_rule=refit_rule,
        canonical_options=canonicals,
        eval_function=vix_fit,
        min_period=21,
        expanding=True,
    )

    # --- weekly entry ---
    calc_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 60))
    enter_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 30))
    id_generator = IncrementIdGenerator(2)

    call_selector = ChainOptionProviders(algo, [
        CallsOptionProvider(algo),
        ClosestExpiryOptionProvider(algo, 30),
        VIXMoneynessRegressionOptionProvider(algo, regression),
    ])
    sizing = FractionOfCashQuantityCalculator(algo, 2 / 52, Direction.SHORT, 100)
    entry = BaseEnterStrategy(algo, call_selector, sizing, canonicals)

    # --- delta hedge with the E-mini future (contract_depth_offset=0) ---
    es_future: Future = algo.add_future(
        Futures.Indices.SP_500_E_MINI,
        extended_market_hours=True,
        data_mapping_mode=DataMappingMode.LAST_TRADING_DAY,  # alternative: OPEN_INTEREST
        data_normalization_mode=DataNormalizationMode.BACKWARDS_RATIO,
        contract_depth_offset=0,
        resolution=Resolution.MINUTE,
    )
    es_future.set_filter(0, 20)

    index_hedge = IndexDeltaHedge(
        algo,
        underlying_index=spx,
        continuous_future=es_future,
        trade_manager=None,
        schedule_rule=ScheduleRule(
            algo.date_rules.every_day(spx),
            algo.time_rules.before_market_close(spx, 60),
        ),
    )

    return BaseTradeManager(
        algo=algo,
        enter_strategy=entry,
        canonical_options=canonicals,
        calc_trade_date_rule=calc_rule,
        trade_date_rule=enter_rule,
        subscription_resolution=Resolution.MINUTE,
        trade_id_generator=id_generator,
        trade_history_length=-1,
        exit_strategy=None,
        exit_date_rule=enter_rule,
        delta_hedge_strategies=[index_hedge],
    )

def sell_call_by_vix(algo):
    """Build a trade manager that sells ~90-day SPX calls at a VIX-driven moneyness.

    Each week start the call strike is chosen by ``VIXMoneynessOptionProvider``
    using ``VIXLinearFunction.eval`` on the VIX index. The trade is computed
    60 min and entered 30 min before the close, sized at 2/52 of cash. Positions
    are delta hedged by ``EquityOptionsDeltaHedge`` every 2 hours.

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager``.
    """
    spx = algo.add_index('SPX').symbol
    vix = algo.add_index('VIX', Resolution.MINUTE).symbol
    canonical_options = [Symbol.CreateCanonicalOption(spx, 'SPX', Market.USA, '?SPX')]

    calc_trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx),
        algo.TimeRules.BeforeMarketClose(spx, 60)
    )

    trade_date_rule = ScheduleRule(
        algo.DateRules.WeekStart(spx),
        algo.TimeRules.BeforeMarketClose(spx, 30)
    )

    trade_id_generator = IncrementIdGenerator(2)

    vix_fn: VIXLinearFunction = VIXLinearFunction(algo, vix)
    eval_function: Callable = vix_fn.eval

    providers_chain = [
        CallsOptionProvider(algo),
        ClosestExpiryOptionProvider(algo, 90),
        VIXMoneynessOptionProvider(algo, eval_function)
    ]

    option_provider = ChainOptionProviders(algo, providers_chain)
    quantity_calculator = FractionOfCashQuantityCalculator(algo, 2*1/52, Direction.SHORT, 100)
    enter_strat = BaseEnterStrategy(algo, option_provider, quantity_calculator, canonical_options)
    exit_strat = None

    # Delta hedges are passed the same way as in every other factory here:
    # the hedge owns its schedule (``schedule_rule``) and the manager takes a
    # list via ``delta_hedge_strategies``. The former ``delta_hedge_strategy`` /
    # ``delta_hedge_date_rule`` keywords are not used by any other
    # BaseTradeManager call.
    # NOTE(review): the other SPX factories hedge index options with
    # IndexDeltaHedge (futures); confirm EquityOptionsDeltaHedge is intended here.
    delta_hedges: List[DeltaHedge] = [
        EquityOptionsDeltaHedge(
            algo,
            trade_manager=None,
            schedule_rule=ScheduleRule(
                algo.date_rules.every_day(spx),
                algo.time_rules.every(TimeSpan.FromHours(2))  # starting at midnight in case of hours/minutes
            )
        )
    ]

    trade_calculator = BaseTradeManager(
        algo=algo,
        enter_strategy=enter_strat,
        canonical_options=canonical_options,
        calc_trade_date_rule=calc_trade_date_rule,
        trade_date_rule=trade_date_rule,
        subscription_resolution=Resolution.MINUTE,
        trade_id_generator=trade_id_generator,
        trade_history_length=-1,
        exit_strategy=exit_strat,
        exit_date_rule=trade_date_rule,
        delta_hedge_strategies=delta_hedges
    )

    return trade_calculator

def sell_call_x_delta(algo):
    """Build a trade manager that sells ~180-day SPX calls closest to 0.1 delta.

    Each week start the trade is computed 60 min and entered 30 min before the
    close, sized at 2/52 of cash. The exit rule reuses the entry schedule and
    no exit strategy is set.

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager``.
    """
    spx = algo.add_index('SPX').symbol
    canonicals = [Symbol.CreateCanonicalOption(spx, 'SPX', Market.USA, '?SPX')]

    calc_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 60))
    enter_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 30))
    id_generator = IncrementIdGenerator(2)

    # Exit strategy currently disabled; a delta-based one would look like:
    # ExitByDeltaThreshold(algo=algo, delta=0.1)
    call_selector = ChainOptionProviders(algo, [
        CallsOptionProvider(algo),
        ClosestExpiryOptionProvider(algo, 180),
        ClosestDeltaOptionsProvider(algo, 0.1),
    ])
    sizing = FractionOfCashQuantityCalculator(algo, 2 / 52, Direction.SHORT, 100)
    entry = BaseEnterStrategy(algo, call_selector, sizing, canonicals)

    return BaseTradeManager(
        algo=algo,
        enter_strategy=entry,
        canonical_options=canonicals,
        calc_trade_date_rule=calc_rule,
        trade_date_rule=enter_rule,
        subscription_resolution=Resolution.MINUTE,
        trade_id_generator=id_generator,
        trade_history_length=-1,
        exit_strategy=None,
        exit_date_rule=enter_rule,
    )

def budget_put_spread(algo):
    """Build a trade manager for a weekly SPX put spread with a short-call budget leg.

    Main leg (``BudgetQuantityCalculatorTrial``, budget 0.03/52): long the
    ~180-day put with strike closest to 1.0 moneyness and short the ~180-day put
    closest to 0.85 moneyness. Budget leg (``TwoLegsBudgetEnterStrategy``,
    budget_frac 0.005/52, bounds 0.8-1.2): short ~180-day calls sized at 2/52
    of cash. Trades go through ``VixThresholdFilter(threshold=30)``. No exit
    strategy is active.

    Args:
        algo: the running ``QCAlgorithm``.

    Returns:
        The configured ``BaseTradeManager``.
    """
    spx = algo.add_index('SPX').symbol
    canonicals = [Symbol.CreateCanonicalOption(spx, 'SPX', Market.USA, '?SPX')]

    calc_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 65))
    enter_rule = ScheduleRule(algo.DateRules.WeekStart(spx), algo.TimeRules.BeforeMarketClose(spx, 60))
    exit_rule = ScheduleRule(algo.DateRules.EveryDay(spx), algo.TimeRules.BeforeMarketClose(spx, 60))
    id_generator = IncrementIdGenerator(0)

    vix_filter = VixThresholdFilter(algo, threshold=30)

    # A delta-based exit is still constructed but is disabled: the manager
    # below gets exit_strategy=None.
    _disabled_exit = ExitByDeltaThreshold(
        algo=algo,
        delta=0.5,
        threshold_type=ThresholdType.ABOVE,
        asset_filter=PutOptionFilter(),
        exit_related_trades=True,
    )

    def put_leg(moneyness):
        # puts -> expiry closest to 180 days -> strike closest to `moneyness`
        return ChainOptionProviders(algo, [
            PutsOptionProvider(algo),
            ClosestExpiryOptionProvider(algo, 180),
            ClosestStrikeOptionProvider(algo, moneyness),
        ])

    # Main leg: long ATM put, short 0.85 put.
    spread_selector = UnionOptionProvider(algo, [put_leg(1.0), put_leg(0.85)])
    spread_sizing = BudgetQuantityCalculatorTrial(
        algo, 0.03 / 52, direction_vec=[Direction.LONG, Direction.SHORT]
    )
    spread_entry = BaseEnterStrategy(algo, spread_selector, spread_sizing, canonicals)

    # Budget leg: pool of ~180-day calls to sell.
    call_pool = ChainOptionProviders(algo, [
        CallsOptionProvider(algo),
        ClosestExpiryOptionProvider(algo, 180),
    ])
    call_sizing = FractionOfCashQuantityCalculator(algo, 2 / 52, Direction.SHORT, 100)

    entry = TwoLegsBudgetEnterStrategy(
        algo,
        budget_frac=0.005 / 52,
        main_leg_enter_strategy=spread_entry,
        budget_leg_option_provider=call_pool,
        budget_leg_quantity_calculator=call_sizing,
        canonical_options=canonicals,
        lower_bound=0.8,
        upper_bound=1.2,
        contract_multiplier=100,
    )

    algo.set_name('collar_180_15_2x_not_cap')

    return BaseTradeManager(
        algo=algo,
        enter_strategy=entry,
        canonical_options=canonicals,
        calc_trade_date_rule=calc_rule,
        trade_date_rule=enter_rule,
        subscription_resolution=Resolution.MINUTE,
        trade_id_generator=id_generator,
        trade_history_length=-1,
        exit_strategy=None,
        exit_date_rule=exit_rule,
        filter=vix_filter,
    )
# region imports
from AlgorithmImports import *
from OptionsTrading.option_providers import *
import time
from datetime import datetime
# endregion

class InterpolatedStrikesOptionProvider(OptionProvider):
    """Select options on an evenly spaced strike grid between two anchor selections.

    Two inner providers each select anchor contract(s). The strike of the first
    contract each returns bounds a ``np.linspace`` grid of ``no_of_options``
    points. For every interior grid point, the contract closest to that strike
    is added, and both anchor selections are appended to the result.
    """
    def __init__(self, algo, option_provider1, option_provider2, no_of_options):
        super().__init__(algo)
        self._opt_provider1 = option_provider1
        self._opt_provider2 = option_provider2
        self._no_of_options = no_of_options
        # Its ``strike`` is overwritten for every grid point in get_options;
        # the 0 here is only a placeholder.
        self._strike_opt_provider = ClosestStrikeOptionProviderNoUnderlying(algo, 0)

    def get_options(self, options):
        """Return the anchor selections plus one contract per interior grid strike.

        Returns an empty ``OptionsData`` when either anchor provider selects
        nothing, because there is then no strike range to interpolate.
        """
        opt1 = self._opt_provider1.get_options(options)
        opt2 = self._opt_provider2.get_options(options)

        # Previously next(iter(...)) raised StopIteration on an empty selection
        # (e.g. DeltaOptionProvider returns an empty OptionsData()).
        anchor1 = next(iter(opt1), None)
        anchor2 = next(iter(opt2), None)
        if anchor1 is None or anchor2 is None:
            return OptionsData()

        low, high = sorted((anchor1.id.strike_price, anchor2.id.strike_price))

        # Drop the end points: they are the anchors, which are added below.
        strikes = np.linspace(low, high, self._no_of_options).tolist()[1:-1]
        res = OptionsData()
        for strike in strikes:
            self._strike_opt_provider.strike = strike
            res += self._strike_opt_provider.get_options(options)

        res = res + opt1 + opt2
        return res
    
class DispersionOptionProvider(OptionProvider):
    """Select call and put strips for one expiry, from the ATM strike out to a delta wing.

    For calls and for puts separately, the provider:
      * keeps the expiry closest to ``days_to_expiry``;
      * takes the contract with strike closest to 1.0 of the underlying;
      * takes the contract with absolute delta closest to ``delta``;
      * takes contracts at evenly spaced strikes between those two anchors
        (``no_of_options`` grid points including the anchors), via
        ``InterpolatedStrikesOptionProvider``.

    Args:
        algo: the running algorithm.
        delta: target absolute delta of the wing contract.
        days_to_expiry: target days to expiry.
        no_of_options: number of strike grid points per side, anchors included.
    """
    def __init__(self, algo, delta, days_to_expiry, no_of_options):
        super().__init__(algo)
        self._no_of_options = no_of_options
        self._target_delta = delta

        self._puts_opt_provider = ChainOptionProviders(algo, [PutsOptionProvider(algo),
                        ClosestExpiryOptionProvider(algo, days_to_expiry)])
        self._calls_opt_provider = ChainOptionProviders(algo, [CallsOptionProvider(algo),
                        ClosestExpiryOptionProvider(algo, days_to_expiry)])
        self._atm_opt_provider = ClosestStrikeOptionProvider(algo, 1.0)
        self._delta_opt_provider = DeltaOptionProvider(algo, delta=delta)
        # Shared by calls and puts; it is applied to each side's chain separately.
        self._interpolated_opt_provider = InterpolatedStrikesOptionProvider(
            algo, self._atm_opt_provider, self._delta_opt_provider, no_of_options
        )

    def get_options(self, options):
        """Return the interpolated call strip combined with the interpolated put strip."""
        if len(options) == 0:
            return OptionsData()

        call_options = self._calls_opt_provider.get_options(options)
        put_options = self._puts_opt_provider.get_options(options)

        # The hard-coded VZ / 2023-01-03 debug log and the unused timer were
        # removed: the log indexed option_symbols[0], which raised IndexError
        # whenever the call selection was empty.
        call_options = self._interpolated_opt_provider.get_options(call_options)
        put_options = self._interpolated_opt_provider.get_options(put_options)

        return call_options + put_options

class DeltaOptionProvider(OptionProvider):
    """Select the option(s) whose absolute delta is closest to a target delta.

    Contracts whose absolute delta is below ``threshold`` are ignored. Every
    contract sharing the winning absolute delta is returned. When several deltas
    are equally close, the one seen first in ``options`` wins.
    """
    def __init__(self, algo: QCAlgorithm, delta: float, threshold=0, resolution=None):
        super().__init__(algo)
        self._target_delta = delta
        # Defaults to minute resolution; stored but not read by this class.
        self._resolution = resolution or Resolution.MINUTE
        self._threshold = threshold

    def get_options(self, options):
        """Return the subset of ``options`` at the absolute delta nearest the target."""
        by_delta = {}
        for symbol, contract in options.items():
            abs_delta = abs(contract.greeks.delta)
            if abs_delta in by_delta or abs_delta >= self._threshold:
                by_delta.setdefault(abs_delta, []).append(symbol)

        if not by_delta:
            return OptionsData()

        # min() keeps the first key (insertion order) among ties, as a stable sort would.
        best = min(by_delta, key=lambda d: abs(d - self._target_delta))
        return options.get_subset(by_delta[best])
# region imports
from AlgorithmImports import *
from config_new import *
from cProfile import Profile
from pstats import Stats
from io import StringIO
# endregion

# TODO
# Indicator: daily monitor.
# value = portfolio value - value of all put options in the portfolio
# keep the last n days of values

# The monetization function takes (or creates) that indicator.
# params: 1. list of canonicals
#         2. direction

# The indicator remembers:
# - value of put options
# - value of portfolio
# - total value

# Module-level cProfile profiler, enabled as soon as this module is imported.
# SPX_Options.on_end_of_algorithm disables it and saves the report to the Object Store.
profile = Profile()
profile.enable()

class SPX_Options(QCAlgorithm):
    """Algorithm entry point: configures the backtest and builds the active strategy.

    The strategy is selected by uncommenting one of the factory calls in
    ``initialize``. At the end of the run, the module-level cProfile profiler is
    stopped and its top-20 cumulative-time report is saved to the Object Store.
    """
    def initialize(self):
        """Set start date, cash and security initializer, then build the strategy."""
        # self.set_start_date(2012, 1, 1)
        # self.set_start_date(2017, 1, 1)
        # self.set_start_date(2021, 1, 1)
        self.set_start_date(2025, 1, 1)
        
        self.set_cash(1_000_000_000)

        # Must be set before the strategy factory below adds securities, so that
        # every added security goes through MySecurityInitializer and is seeded
        # with its last known price.
        self.set_security_initializer(
            MySecurityInitializer(self.BrokerageModel, FuncSecuritySeeder(self.get_last_known_prices))
        )

        # self.settings.performance_sample_period = timedelta(7)

        # Exactly one strategy factory should be active. The returned manager is
        # only kept in a local variable.
        # NOTE(review): presumably the manager registers its own scheduled events
        # with the algorithm and so stays alive; confirm.
        # trade_calculator = sector_dispersion_(self)
        trade_calculator = dispersion_(self)
        # trade_calculator = straddle_(self)
        
        # trade_calculator = sell_call_by_vix(self)
        # trade_calculator = sell_call_by_vix_regression(self)
        # trade_calculator = dispersion(self)
        # trade_calculator = buy_put_by_moneyness_test(self)
        # trade_calculator = sector_dispersion(self)
        # trade_calculator = buy_put_by_moneyness(self)

        # order counts
        # NOTE(review): only the commented-out on_order_event below updates these.
        self._market_order_futures_count: int = 0
        self._market_order_equity_count: int = 0
        self._market_order_options_count: int = 0
        self._market_order_innovation_count: int = 0
        self._expired_count: int = 0
        self._MOC_order_count: int = 0
        self._other_order_count: int = 0

    def on_data(self, slice: Slice) -> None:
        """No per-slice logic; trading happens in scheduled callbacks."""
        pass

    # Disabled order-count bookkeeping (kept for reference):
    # def on_order_event(self, order_event: OrderEvent) -> None:
    #     return

    #     if order_event.status == OrderStatus.FILLED:
    #         if order_event.symbol.security_type in [SecurityType.FUTURE]:
    #             if order_event.ticket.order_type == OrderType.MARKET:
    #                 self._market_order_futures_count += 1

    #         if order_event.symbol.security_type in [SecurityType.EQUITY]:
    #             if order_event.ticket.order_type == OrderType.MARKET:
    #                 self._market_order_equity_count += 1

    #         if order_event.symbol.security_type in [SecurityType.OPTION, SecurityType.INDEX_OPTION]:
    #             if order_event.ticket.order_type == OrderType.MARKET:
    #                 self._market_order_options_count += 1

    #                 # TODO this should be not in trade history instead
    #                 # if not self.portfolio[order_event.symbol].invested:
    #                 #     self._market_order_innovation_count += 1

    #             elif order_event.ticket.order_type == OrderType.OPTION_EXERCISE:
    #                 self._expired_count += 1
    #             elif order_event.ticket.order_type == OrderType.MARKET_ON_CLOSE:
    #                 self._MOC_order_count += 1
    #             else:
    #                 self._other_order_count += 1

    # def on_end_of_algorithm(self):
    #     # self.log(f'Market Order Innovation Count (only new option positions, with no overlap): {self._market_order_innovation_count}')
    #     self.log(f'Market Order Option Count (our option trades): {self._market_order_options_count}')
    #     self.log(f'Market Order Equity Count (our equity trades): {self._market_order_equity_count}')
    #     self.log(f'Market Order Futures Count (our hedge future trades): {self._market_order_futures_count}')
    #     self.log(f'Expiration Order Count: {self._expired_count}')
    #     self.log(f'Market On Close Count (probably split handling): {self._MOC_order_count}')
    #     self.log(f'Other Order Count (inspect if any): {self._other_order_count}')

    def on_end_of_algorithm(self):
        """Stop the module-level profiler and store its report in the Object Store."""
        # Stop collecting profiling data
        profile.disable()
        stream = StringIO()
        # Print the 20 entries with the highest cumulative time into the in-memory stream.
        Stats(profile, stream=stream).sort_stats('cumulative').print_stats(20)
        # Save the report to the Object Store under a key derived from the algorithm ID.
        self.object_store.save(f"{self.algorithm_id}_profile", stream.getvalue())

class MySecurityInitializer(BrokerageModelSecurityInitializer):
    """Security initializer that applies the brokerage-model defaults, then overrides them.

    Overrides, applied to every security:
      * buying power model: ``BuyingPowerModel.NULL``;
      * option price model (equity and index options only): binomial Cox-Ross-Rubinstein;
      * fee model: constant zero fee.
    """

    def __init__(self, brokerage_model: IBrokerageModel, security_seeder: ISecuritySeeder) -> None:
        super().__init__(brokerage_model, security_seeder)

    def Initialize(self, security: Security) -> None:
        """Apply brokerage defaults to ``security``, then the overrides listed on the class."""
        # Defaults first, so the overrides below take precedence.
        super().initialize(security)

        security.set_buying_power_model(BuyingPowerModel.NULL)

        is_option = security.type in (SecurityType.OPTION, SecurityType.INDEX_OPTION)
        if is_option:
            # Alternatives tried: OptionPriceModels.crank_nicolson_fd(), black_scholes()
            security.price_model = OptionPriceModels.binomial_cox_ross_rubinstein()

        security.set_fee_model(ConstantFeeModel(0))
#region imports
from AlgorithmImports import *
from abc import ABC, abstractmethod
from trade_managers import TradeManager
from data import *
#endregion

class Charts(ABC):
    """Base class for a custom chart that subclasses refresh in ``_update_chart``.

    The chart and its series are registered with the algorithm on construction.
    When ``scheduled_update`` is true, ``_update_chart`` is scheduled with
    ``schedule_rule``. If no rule is given, it runs every SPX trading day at the
    market close (``BeforeMarketClose('SPX', 0)``).

    Args:
        algo: the running algorithm.
        chart_name: name of the chart to create.
        series_list: dicts with ``'name'``, ``'type'`` and ``'unit'`` keys, one per series.
        scheduled_update: whether to schedule ``_update_chart`` at all.
        schedule_rule: optional custom date/time rule for the refresh.
    """
    def __init__(
        self,
        algo,
        chart_name,
        series_list,
        scheduled_update: bool = True,
        schedule_rule: ScheduleRule = None
    ):
        self._algo = algo
        self._chart_name = chart_name
        self._series_list = series_list
        self._add_chart()

        self._schedule_rule: ScheduleRule = schedule_rule

        if scheduled_update:
            self._scheduler()

    def _add_chart(self):
        """Register the chart with the algorithm and add one series per spec."""
        chart = Chart(self._chart_name)
        self._algo.AddChart(chart)
        for spec in self._series_list:
            series = Series(spec['name'], spec['type'], spec['unit'])
            chart.AddSeries(series)

    def _scheduler(self):
        """Schedule ``_update_chart`` on the custom rule, or at each SPX close by default."""
        algo = self._algo
        rule = self._schedule_rule
        if rule is None:
            date_rule = algo.DateRules.EveryDay('SPX')
            time_rule = algo.TimeRules.BeforeMarketClose('SPX', 0)
        else:
            date_rule = rule.date_rule
            time_rule = rule.time_rule
        algo.Schedule.On(date_rule, time_rule, self._update_chart)

    @abstractmethod
    def _update_chart(self):
        pass

# TODO merge GreeksChart_ and NotionalChart_ if possible
class GreeksChart_(Charts):
    """Plot option greeks separately for the index-option leg and the stock-option leg.

    * ``PortfolioType.WHOLE_PORTFOLIO``: refreshed on a schedule from the invested
      positions in ``trade_manager.trades_hist``.
    * ``PortfolioType.LAST_TRADE``: refreshed whenever the trade manager fires its
      enter-trade event, using only that trade's positions.
    """
    def __init__(
        self,
        algo: QCAlgorithm,
        trade_manager: TradeManager,
        greeks: Optional[List[str]] = None,
        greeks_type: GreeksType = GreeksType.PERCENTAGE,
        portfolio_type: PortfolioType = PortfolioType.WHOLE_PORTFOLIO,
        schedule_rule: ScheduleRule = None
    ) -> None:
        """
        Args:
            algo: the running algorithm.
            trade_manager: source of positions (``trades_hist``) and of the enter-trade event.
            greeks: greek names to plot; defaults to gamma, delta, vega, theta.
            greeks_type: ``GreeksType.PERCENTAGE`` plots greeks in percentage terms.
            portfolio_type: WHOLE_PORTFOLIO (scheduled) or LAST_TRADE (event-driven).
            schedule_rule: optional custom refresh schedule (WHOLE_PORTFOLIO only).
        """
        # Build a fresh default per instance rather than sharing one mutable
        # default list between all instances.
        if greeks is None:
            greeks = ['gamma', 'delta', 'vega', 'theta']

        scheduled_update: bool = True

        # LAST_TRADE charts are not refreshed by the scheduler; they are refreshed
        # each time the trade manager enters a new trade.
        if portfolio_type == PortfolioType.LAST_TRADE:
            scheduled_update = False
            trade_manager._enter_trade_event += self._manual_update

        self._trade_manager: TradeManager = trade_manager

        self._greeks: List[str] = greeks
        self._greeks_type: GreeksType = greeks_type
        self._portfolio_type: PortfolioType = portfolio_type

        # All index-leg series first, then all stock-leg series.
        series_list: List[Dict] = [
            {
                'name': f'{leg}_leg_{greek}',
                'type': SeriesType.Line,
                'unit': ''
            }
            for leg in ('index', 'stock')
            for greek in greeks
        ]

        super().__init__(
            algo=algo,
            chart_name=f'Greeks_{portfolio_type.name}',
            series_list=series_list,
            scheduled_update=scheduled_update,
            schedule_rule=schedule_rule
        )

    # automatic - scheduled update
    def _update_chart(self) -> None:
        """Plot greeks of all invested positions in the trade history."""
        assert self._portfolio_type == PortfolioType.WHOLE_PORTFOLIO

        trade_hist: TradeCollection = self._trade_manager.trades_hist

        # NOTE assuming we are trading only one index underlying
        index_options_trade_data, stock_options_trade_data = self._get_legs_trade_data(
            collection=trade_hist.positions,
            only_invested=True
        )

        self._print_greeks(index_options_trade_data, stock_options_trade_data, is_new_trade=False)

    # manual - non scheduled update using last TradeData
    def _manual_update(self, trade_data: TradeData) -> None:
        """Plot greeks of the trade just entered (enter-trade event handler)."""
        assert self._portfolio_type == PortfolioType.LAST_TRADE

        index_options_trade_data, stock_options_trade_data = self._get_legs_trade_data(
            collection=trade_data._trade_dict,
            only_invested=False
        )

        self._print_greeks(index_options_trade_data, stock_options_trade_data, is_new_trade=True)

    def _get_legs_trade_data(
        self,
        collection: Dict[Symbol, Union[float, int]],
        only_invested: bool
    ) -> Tuple[TradeData, TradeData]:
        """Split ``collection`` into (index-option leg, stock-option leg) TradeData."""
        # TODO in case 'collection' comes from TradeData, there's no need to create new TradeData out of it
        # alternative to filter_invested_options that takes TradeData should be implemented
        index_options_trade_data: TradeData = self._leg_trade_data(
            collection, SecurityType.INDEX_OPTION, SecurityType.INDEX, only_invested
        )
        stock_options_trade_data: TradeData = self._leg_trade_data(
            collection, SecurityType.OPTION, SecurityType.EQUITY, only_invested
        )

        return index_options_trade_data, stock_options_trade_data

    def _leg_trade_data(
        self,
        collection: Dict[Symbol, Union[float, int]],
        option_type: SecurityType,
        underlying_type: SecurityType,
        only_invested: bool
    ) -> TradeData:
        """Keep the entries of ``collection`` whose options match one option/underlying type."""
        return TradeData(
            {
                symbol: collection[symbol] for symbol in filter_invested_options(
                    algo=self._algo,
                    positions=collection,
                    option_types=[option_type],
                    underlying_types=[underlying_type],
                    only_invested=only_invested
                )
            }
        )

    def _print_greeks(
        self,
        index_options_trade_data: TradeData,
        stock_options_trade_data: TradeData,
        is_new_trade: bool
    ) -> None:
        """Plot each configured greek for both legs."""
        percentage_terms: bool = self._greeks_type == GreeksType.PERCENTAGE

        for greek in self._greeks:
            index_leg_greek: float = calc_trade_greek(
                algo=self._algo,
                trade_data=index_options_trade_data,
                greek_name=greek,
                percentage_terms=percentage_terms,
                is_new_trade=is_new_trade,
                sender_str='GreeksChart_',
            )

            stock_leg_greek: float = calc_trade_greek(
                algo=self._algo,
                trade_data=stock_options_trade_data,
                greek_name=greek,
                percentage_terms=percentage_terms,
                is_new_trade=is_new_trade,
                sender_str='GreeksChart_',
            )

            self._algo.plot(self._chart_name, f'index_leg_{greek}', index_leg_greek)
            self._algo.plot(self._chart_name, f'stock_leg_{greek}', stock_leg_greek)

class NotionalChart_(Charts):
    """Chart the option notional of the index leg, the stock leg and their ratio.

    The chart is named ``Notional_<portfolio_type.name>`` and has three series.
    ``index_leg_notional`` and ``stock_leg_notional`` are both in $ and computed
    with ``calc_trade_notional``. ``notional_ratio [stock/index]`` is only plotted
    while the index-leg notional is non-zero.

    Update modes:
    * ``PortfolioType.LAST_TRADE``: ``scheduled_update=False`` is passed to the
      base class and ``_manual_update`` is subscribed to
      ``trade_manager._enter_trade_event``. The chart is therefore redrawn from
      every newly entered trade.
    * any other portfolio type: ``scheduled_update=True`` and ``schedule_rule``
      are passed to the base class. ``_update_chart`` only supports
      ``PortfolioType.WHOLE_PORTFOLIO`` and reads ``trade_manager.trades_hist``.
    """

    def __init__(
        self, 
        algo: QCAlgorithm, 
        trade_manager: TradeManager, 
        portfolio_type: PortfolioType = PortfolioType.WHOLE_PORTFOLIO, 
        schedule_rule: Optional[ScheduleRule] = None
    ) -> None:
        """
        Args:
            algo: algorithm instance, forwarded to ``Charts`` (the methods below
                use it as ``self._algo``, presumably stored by the base class).
            trade_manager: provides the trade history and the new-trade event.
            portfolio_type: which positions are charted (see the class docstring).
            schedule_rule: rule for scheduled updates, forwarded to ``Charts``.
        """
        # LAST_TRADE charts are not refreshed by the scheduled rule; they are
        # redrawn whenever a new trade is entered instead.
        scheduled_update: bool = portfolio_type != PortfolioType.LAST_TRADE
        if not scheduled_update:
            trade_manager._enter_trade_event += self._manual_update

        self._trade_manager: TradeManager = trade_manager
        self._portfolio_type: PortfolioType = portfolio_type

        series_list: List[Dict] = [
            {
                'name': 'index_leg_notional', 
                'type': SeriesType.Line, 
                'unit': '$'
            },
            {
                'name': 'stock_leg_notional', 
                'type': SeriesType.Line, 
                'unit': '$'
            },
            {
                'name': 'notional_ratio [stock/index]', 
                'type': SeriesType.Line, 
                'unit': ''
            }
        ]

        super().__init__(
            algo=algo, 
            chart_name=f'Notional_{portfolio_type.name}', 
            series_list=series_list, 
            scheduled_update=scheduled_update, 
            schedule_rule=schedule_rule
        )
    
    # automatic - scheduled update
    def _update_chart(self) -> None:
        """Scheduled update: chart the invested positions of the whole trade history."""
        assert self._portfolio_type == PortfolioType.WHOLE_PORTFOLIO, \
            f'NotionalChart_._update_chart: unsupported portfolio type {self._portfolio_type}'
        
        trade_hist: TradeCollection = self._trade_manager.trades_hist
        
        # NOTE assuming we are trading only one index underlying
        index_options_trade_data, stock_options_trade_data = self._get_legs_trade_data(
            collection=trade_hist.positions, 
            only_invested=True
        )

        self._print_notional(index_options_trade_data, stock_options_trade_data)
    
    # manual - non scheduled update
    def _manual_update(self, trade_data: TradeData) -> None:
        """Handle ``_enter_trade_event``: chart every position of the trade just entered.

        ``only_invested=False`` is used, so all positions of the trade are charted,
        not only the invested ones.
        """
        assert self._portfolio_type == PortfolioType.LAST_TRADE, \
            f'NotionalChart_._manual_update: unsupported portfolio type {self._portfolio_type}'
        
        index_options_trade_data, stock_options_trade_data = self._get_legs_trade_data(
            collection=trade_data._trade_dict, 
            only_invested=False
        )

        self._print_notional(index_options_trade_data, stock_options_trade_data)

    def _get_legs_trade_data(
        self, 
        collection: Dict[Symbol, Union[float, int]], 
        only_invested: bool
    ) -> Tuple[TradeData, TradeData]:
        """Split *collection* into ``(index_options_trade_data, stock_options_trade_data)``.

        The index leg holds index options on an index underlying. The stock leg
        holds equity options on an equity underlying. Both are selected via
        ``filter_invested_options``.
        """
        # TODO in case 'collection' comes from TradeData, there's no need to create new TradeData out of it
        # alternative to filter_invested_options that takes TradeData should be implemented
        index_options_trade_data: TradeData = self._build_leg_trade_data(
            collection=collection,
            option_types=[SecurityType.INDEX_OPTION],
            underlying_types=[SecurityType.INDEX],
            only_invested=only_invested,
        )

        stock_options_trade_data: TradeData = self._build_leg_trade_data(
            collection=collection,
            option_types=[SecurityType.OPTION],
            underlying_types=[SecurityType.EQUITY],
            only_invested=only_invested,
        )

        return index_options_trade_data, stock_options_trade_data

    def _build_leg_trade_data(
        self,
        collection: Dict[Symbol, Union[float, int]],
        option_types: List[SecurityType],
        underlying_types: List[SecurityType],
        only_invested: bool,
    ) -> TradeData:
        """Return a ``TradeData`` of the *collection* entries that ``filter_invested_options`` keeps for one leg."""
        return TradeData(
            {
                symbol: collection[symbol]
                for symbol in filter_invested_options(
                    algo=self._algo,
                    positions=collection,
                    option_types=option_types,
                    underlying_types=underlying_types,
                    only_invested=only_invested,
                )
            }
        )
    
    def _print_notional(
        self, 
        index_options_trade_data: TradeData, 
        stock_options_trade_data: TradeData
    ) -> None:
        """Plot both legs' notional and, when the index leg is non-zero, their stock/index ratio."""
        index_leg_notional: float = calc_trade_notional(
            algo=self._algo, 
            trade_data=index_options_trade_data, 
        )

        stock_leg_notional: float = calc_trade_notional(
            algo=self._algo, 
            trade_data=stock_options_trade_data, 
        )
        
        self._algo.plot(self._chart_name, 'index_leg_notional', index_leg_notional)
        self._algo.plot(self._chart_name, 'stock_leg_notional', stock_leg_notional)
        
        # the ratio is undefined without an index leg; skip the point instead of dividing by zero
        if index_leg_notional != 0:
            self._algo.plot(self._chart_name, 'notional_ratio [stock/index]', stock_leg_notional / index_leg_notional)

# updated on a new trade only
class ExpenditureChart_(Charts):
    """Chart the expenditure of the index leg, the stock leg and their ratio for each new trade.

    There is no scheduled update: ``_manual_update`` is subscribed to
    ``trade_manager._enter_trade_event``. For every entered trade it plots
    ``index_leg_expenditure`` and ``stock_leg_expenditure``, both in $ and
    computed with ``calc_trade_expenditure``. It also plots
    ``expenditure_ratio [stock/index]`` while the index-leg expenditure is
    non-zero.
    """

    def __init__(
        self, 
        algo: QCAlgorithm, 
        trade_manager: TradeManager
    ) -> None:
        """
        Args:
            algo: algorithm instance, forwarded to ``Charts`` (the methods below
                use it as ``self._algo``, presumably stored by the base class).
            trade_manager: its ``_enter_trade_event`` drives the chart updates.
        """
        # refreshed only on new trades, never by a schedule rule
        scheduled_update: bool = False
        trade_manager._enter_trade_event += self._manual_update

        self._trade_manager: TradeManager = trade_manager

        series_list: List[Dict] = [
            {
                'name': 'index_leg_expenditure', 
                'type': SeriesType.Line, 
                'unit': '$'
            },
            {
                'name': 'stock_leg_expenditure', 
                'type': SeriesType.Line, 
                'unit': '$'
            },
            {
                'name': 'expenditure_ratio [stock/index]', 
                'type': SeriesType.Line, 
                'unit': ''
            }
        ]

        super().__init__(
            algo=algo, 
            chart_name='Trade Expenditure', 
            series_list=series_list, 
            scheduled_update=scheduled_update, 
            schedule_rule=None
        )

    # automatic - scheduled update
    def _update_chart(self) -> None:
        """No-op: this chart is only updated from ``_manual_update``."""
        pass

    # manual - non scheduled update
    def _manual_update(self, trade_data: TradeData) -> None:
        """Handle ``_enter_trade_event``: plot the expenditure of the trade just entered.

        ``only_invested=False`` is used, so every position of the trade is included.
        """
        index_options_trade_data, stock_options_trade_data = self._get_legs_trade_data(
            collection=trade_data._trade_dict, 
            only_invested=False
        )

        self._print_expenditure(index_options_trade_data, stock_options_trade_data)

    def _get_legs_trade_data(
        self, 
        collection: Dict[Symbol, Union[float, int]], 
        only_invested: bool
    ) -> Tuple[TradeData, TradeData]:
        """Split *collection* into ``(index_options_trade_data, stock_options_trade_data)``.

        The index leg holds index options on an index underlying. The stock leg
        holds equity options on an equity underlying. Both are selected via
        ``filter_invested_options``.
        """
        # TODO in case 'collection' comes from TradeData, there's no need to create new TradeData out of it
        # alternative to filter_invested_options that takes TradeData should be implemented
        index_options_trade_data: TradeData = self._build_leg_trade_data(
            collection=collection,
            option_types=[SecurityType.INDEX_OPTION],
            underlying_types=[SecurityType.INDEX],
            only_invested=only_invested,
        )

        stock_options_trade_data: TradeData = self._build_leg_trade_data(
            collection=collection,
            option_types=[SecurityType.OPTION],
            underlying_types=[SecurityType.EQUITY],
            only_invested=only_invested,
        )

        return index_options_trade_data, stock_options_trade_data

    def _build_leg_trade_data(
        self,
        collection: Dict[Symbol, Union[float, int]],
        option_types: List[SecurityType],
        underlying_types: List[SecurityType],
        only_invested: bool,
    ) -> TradeData:
        """Return a ``TradeData`` of the *collection* entries that ``filter_invested_options`` keeps for one leg."""
        return TradeData(
            {
                symbol: collection[symbol]
                for symbol in filter_invested_options(
                    algo=self._algo,
                    positions=collection,
                    option_types=option_types,
                    underlying_types=underlying_types,
                    only_invested=only_invested,
                )
            }
        )

    def _print_expenditure(
        self, 
        index_options_trade_data: TradeData, 
        stock_options_trade_data: TradeData
    ) -> None:
        """Plot both legs' expenditure and, when the index leg is non-zero, their stock/index ratio."""
        index_leg_expenditure: float = calc_trade_expenditure(
            algo=self._algo, 
            trade_data=index_options_trade_data, 
        )

        stock_leg_expenditure: float = calc_trade_expenditure(
            algo=self._algo, 
            trade_data=stock_options_trade_data, 
        )
        
        self._algo.plot(self._chart_name, 'index_leg_expenditure', index_leg_expenditure)
        self._algo.plot(self._chart_name, 'stock_leg_expenditure', stock_leg_expenditure)
        
        # the ratio is undefined without an index leg; skip the point instead of dividing by zero
        if index_leg_expenditure != 0:
            self._algo.plot(
                self._chart_name, 
                'expenditure_ratio [stock/index]', 
                stock_leg_expenditure / index_leg_expenditure
            )

# updated on a new trade only
class SectorWeightChart_(Charts):
    def __init__(
        self, 
        algo: QCAlgorithm, 
        trade_manager: TradeManager
    ) -> None:
        """Create the 'Sector Weight' chart, which is redrawn only when a new trade is entered.

        One stacked-area series (unit '%') is declared per value of
        ``SECTOR_ETF_MAP``. Scheduled updates are disabled and ``_manual_update``
        is subscribed to ``trade_manager._enter_trade_event``.

        NOTE(review): ``_print_sector_weight`` plots series named after each
        stock-leg underlying ticker; confirm those names match the values of
        ``SECTOR_ETF_MAP`` declared here.
        """
        # no schedule rule: the chart follows the trade manager's entry events
        trade_manager._enter_trade_event += self._manual_update
        self._trade_manager: TradeManager = trade_manager

        series_list: List[Dict] = [
            {'name': f'{sector_name}', 'type': SeriesType.STACKED_AREA, 'unit': '%'}
            for sector_name in SECTOR_ETF_MAP.values()
        ]

        super().__init__(
            algo=algo,
            chart_name='Sector Weight',
            series_list=series_list,
            scheduled_update=False,
            schedule_rule=None,
        )

    # automatic - scheduled update
    def _update_chart(self) -> None:
        """Do nothing: this chart has no scheduled update and is drawn from ``_manual_update``."""

    # manual - non scheduled update
    def _manual_update(self, trade_data: TradeData) -> None:
        """Handle ``_enter_trade_event``: plot sector weights for the trade just entered.

        Every position of the trade is considered (``only_invested=False``).
        """
        index_leg, stock_leg = self._get_legs_trade_data(
            collection=trade_data._trade_dict,
            only_invested=False,
        )
        self._print_sector_weight(index_leg, stock_leg)

    def _get_legs_trade_data(
        self, 
        collection: Dict[Symbol, Union[float, int]], 
        only_invested: bool
    ) -> Tuple[TradeData, TradeData]:
        """Split *collection* into ``(index_options_trade_data, stock_options_trade_data)``.

        The index leg holds index options on an index underlying. The stock leg
        holds equity options on an equity underlying. Both are selected via
        ``filter_invested_options``.
        """
        # TODO in case 'collection' comes from TradeData, there's no need to create new TradeData out of it
        # alternative to filter_invested_options that takes TradeData should be implemented
        index_options_trade_data: TradeData = self._build_leg_trade_data(
            collection=collection,
            option_types=[SecurityType.INDEX_OPTION],
            underlying_types=[SecurityType.INDEX],
            only_invested=only_invested,
        )

        stock_options_trade_data: TradeData = self._build_leg_trade_data(
            collection=collection,
            option_types=[SecurityType.OPTION],
            underlying_types=[SecurityType.EQUITY],
            only_invested=only_invested,
        )

        return index_options_trade_data, stock_options_trade_data

    def _build_leg_trade_data(
        self,
        collection: Dict[Symbol, Union[float, int]],
        option_types: List[SecurityType],
        underlying_types: List[SecurityType],
        only_invested: bool,
    ) -> TradeData:
        """Return a ``TradeData`` of the *collection* entries that ``filter_invested_options`` keeps for one leg."""
        return TradeData(
            {
                symbol: collection[symbol]
                for symbol in filter_invested_options(
                    algo=self._algo,
                    positions=collection,
                    option_types=option_types,
                    underlying_types=underlying_types,
                    only_invested=only_invested,
                )
            }
        )

    def _print_sector_weight(
        self, 
        index_options_trade_data: TradeData, 
        stock_options_trade_data: TradeData
    ) -> None:
        """Plot each stock-leg underlying's share of the total stock-leg option quantity.

        Quantities of all stock-option positions are summed per underlying ticker
        (``symbol.underlying.value``). Each ticker is plotted on
        ``self._chart_name`` as ``quantity / total_quantity * 100``. When the
        total quantity is zero, every ticker is plotted as 0.

        ``index_options_trade_data`` is not used; it is kept so the signature
        matches the other chart ``_print_*`` helpers.
        """
        # Aggregate per underlying. Quantities and tickers must not be
        # de-duplicated with set() independently and then zipped: sets are
        # unordered, so that pairs quantities with unrelated tickers, drops equal
        # quantities from the total, and ignores several legs on one underlying.
        quantity_by_underlying: Dict[str, float] = {}
        for symbol, quantity in stock_options_trade_data.trade_data.items():
            ticker: str = symbol.underlying.value
            quantity_by_underlying[ticker] = quantity_by_underlying.get(ticker, 0.) + quantity

        # NOTE(review): quantities are signed; if the stock leg ever mixes long and
        # short positions the weights can fall outside [0, 100] - confirm whether
        # absolute quantities are intended.
        total_quantity: float = sum(quantity_by_underlying.values())

        for ticker, quantity in quantity_by_underlying.items():
            if total_quantity != 0:
                weight: float = (quantity / total_quantity) * 100
            else:
                # TODO this needs to be inspected, it's probably tied to gamma entanglement result when gamma is not available for none of the options in a trade
                weight = 0.
            self._algo.plot(self._chart_name, f'{ticker}', weight)