# Overall Statistics
from AlgorithmImports import *

class SpectralStatArb(QCAlgorithm):
    def initialize(self):
        """Configure the backtest window, universe, strategy parameters and schedules."""
        # Strategy parameters: (correlation window, z-score window) and
        # (entry threshold, exit threshold) for the spread z-score.
        self.laplacian_lookback, self.arb_lookback = 252, 30
        self.entry_z_score, self.exit_z_score = 2.0, 0.5

        # Per-symbol trackers, and the currently traded pair as (asset_a, asset_b).
        self.symbol_data = dict()
        self.active_pair = None

        # Backtest window and starting capital.
        self.set_start_date(2021, 1, 1)
        self.set_end_date(2022, 1, 1)
        self.set_cash(100_000)
        self.settings.seed_initial_prices = True

        # Daily-resolution universe driven by the fundamental selection function.
        self.universe_settings.resolution = Resolution.DAILY
        self.add_universe(self.fundamental_selection_function)

        # Monthly pair selection at 10:00, daily trade evaluation at 10:30.
        date_rules, time_rules = self.date_rules, self.time_rules
        self.schedule.on(
            date_rules.month_start(),
            time_rules.at(10, 0),
            self.compute_laplacian_and_pairs,
        )
        self.schedule.on(
            date_rules.every_day(),
            time_rules.at(10, 30),
            self.execute_pairs_trade,
        )

        self.set_warm_up(timedelta(days=350))

# TODO: move your schedulers in here: 
    # def on_warmup_finished(self):
    #     # Configure training and prediction schedule.
    #     trading_time_rule = self.time_rules.at(8, 0)
    #     training_time_rule = self.time_rules.at(6, 0)
    #     self.train(self.date_rules.month_end(), training_time_rule, self._train)
    #     self.schedule.on(self.date_rules.every_day('SPY'), trading_time_rule, self._rebalance)
    #     # Rebalance today too.
    #     if self.live_mode:
    #         self._train()
    #         self._rebalance()
    #     else:
    #         self.train(self.date_rules.today, training_time_rule, self._train)
    #         self.schedule.on(self.date_rules.today, trading_time_rule, self._rebalance)

    def fundamental_selection_function(self, fundamental):
        """Return the symbols of the 20 highest dollar-volume entries that have fundamental data.

        The universe is kept small so the correlation matrix built from it stays cheap.
        """
        universe_size = 20
        candidates = [item for item in fundamental if item.has_fundamental_data]
        # In-place descending sort by dollar volume (stable, like sorted()).
        candidates.sort(key=lambda item: item.dollar_volume, reverse=True)
        return [item.symbol for item in candidates[:universe_size]]

    def on_securities_changed(self, changes):
        """Create trackers for added securities and tear them down for removed ones.

        A removed security that belongs to the active pair causes both legs to be
        liquidated and the pair to be cleared.
        """
        # Securities entering the universe: create a tracker and seed it from history.
        for added in changes.added_securities:
            key = added.symbol
            if key in self.symbol_data:
                continue

            tracker = SymbolData(self, key, self.laplacian_lookback)
            self.symbol_data[key] = tracker

            # Feed lookback + 1 daily bars into the indicator; every update is
            # pushed into the tracker's rolling window by its event handler.
            bars = self.history[TradeBar](key, self.laplacian_lookback + 1, Resolution.DAILY)
            for bar in bars:
                tracker.logr.update(bar.end_time, bar.close)

        # Securities leaving the universe: detach the tracker and stop tracking it.
        for removed in changes.removed_securities:
            key = removed.symbol
            if key not in self.symbol_data:
                continue

            self.symbol_data[key].dispose()
            del self.symbol_data[key]

            # Losing either leg invalidates the pair, so close both positions.
            pair = self.active_pair
            if pair and key in pair:
                for leg in pair[:2]:
                    self.liquidate(leg)
                self.active_pair = None

    def compute_laplacian_and_pairs(self):
        """Select the pair to trade by spectral bisection of the correlation graph.

        Steps:
          1. Collect the symbols whose rolling window of log returns is full.
          2. Build the adjacency matrix W = |corr| (zero diagonal), the degree
             matrix D and the graph Laplacian L = D - W.
          3. Split the symbols into two clusters by the sign of the Fiedler
             vector (eigenvector of the second-smallest eigenvalue of L).
          4. Choose the two symbols in the same cluster with the largest
             |correlation|.

        The sign of an eigenvector is arbitrary, so the choice is made over both
        clusters rather than only the "positive" one. If the chosen pair is the
        pair already being traded, the open position is left untouched; otherwise
        the old pair is liquidated before the new one (if any) becomes active.
        Does nothing when fewer than two symbols are ready.
        """
        ready_symbols = [sym for sym, data in self.symbol_data.items() if data.is_ready()]
        if len(ready_symbols) < 2:
            return

        # Columns are keyed by position so they map straight back to ready_symbols.
        # RollingWindow is ordered newest to oldest, so reverse into chronological order.
        df_returns = pd.DataFrame({
            idx: list(self.symbol_data[sym].window)[::-1]
            for idx, sym in enumerate(ready_symbols)
        })

        # --- SPECTRAL GRAPH THEORY LOGIC ---

        # Adjacency matrix (W); undefined correlations are treated as "no edge".
        W = np.abs(df_returns.corr().fillna(0).values)
        np.fill_diagonal(W, 0)

        # Degree matrix (D) and graph Laplacian (L).
        L = np.diag(W.sum(axis=1)) - W

        # eigh returns eigenvalues in ascending order, so column 1 is the Fiedler vector.
        _, eigenvectors = np.linalg.eigh(L)
        fiedler_vector = eigenvectors[:, 1]

        best_edge = self._strongest_same_cluster_edge(W, fiedler_vector)
        if best_edge is None:
            new_pair = None
        else:
            new_pair = (ready_symbols[best_edge[0]], ready_symbols[best_edge[1]])

        current_pair = self.active_pair

        # Same two assets as before: keep the existing tuple (and any open position)
        # instead of paying for a needless liquidate / re-enter round trip.
        if current_pair and new_pair and set(current_pair) == set(new_pair):
            return

        # The pair is changing (or disappearing): close out the old one first.
        if current_pair:
            self.liquidate(current_pair[0])
            self.liquidate(current_pair[1])
            self.active_pair = None

        if new_pair:
            self.active_pair = new_pair
            self.log(f"New Pair Selected: {self.active_pair[0].value} & {self.active_pair[1].value}")

    @staticmethod
    def _strongest_same_cluster_edge(weights, fiedler_vector):
        """Return (i, j), i < j, of the heaviest edge joining two nodes on the same
        side of the Fiedler split, or None if no such edge has positive weight."""
        side = fiedler_vector > 0
        same_side = side[:, None] == side[None, :]
        # Upper triangle only: each unordered pair once, and never a node with itself.
        eligible = np.triu(same_side, k=1)
        # Weights are non-negative, so -1 safely marks ineligible entries.
        scores = np.where(eligible, weights, -1.0)
        row, col = np.unravel_index(np.argmax(scores), scores.shape)
        if scores[row, col] <= 0:
            return None
        return int(row), int(col)

    def execute_pairs_trade(self):
        """Enter or exit the active pair based on the z-score of its log-price spread.

        The spread is log(close_a) - log(close_b) over the last `arb_lookback`
        daily bars. With no position in either leg, a z-score above
        `entry_z_score` shorts A / buys B and a z-score below `-entry_z_score`
        buys A / shorts B (50% of the portfolio per leg). With a position open,
        both legs are liquidated once |z| falls below `exit_z_score`.

        Returns without trading while the algorithm is warming up, when there is
        no active pair, when either asset is untracked or has no data, or when
        the spread has fewer than two valid observations or no dispersion.
        """
        # Scheduled events can fire during the warm-up period, when orders are
        # not accepted; skip the history request and trading logic entirely.
        if self.is_warming_up:
            return

        if not self.active_pair:
            self.log('no active pair')
            return

        asset_a, asset_b = self.active_pair

        # Check if both securities are in our universe and tradeable
        if asset_a not in self.symbol_data or asset_b not in self.symbol_data:
            self.log('one or more asset not in universe')
            # Log the tracked tickers as text rather than a raw dict_keys object.
            self.log(', '.join(sym.value for sym in self.symbol_data))
            return

        if not self.securities[asset_a].has_data or not self.securities[asset_b].has_data:
            self.log('one or more asset has no data')
            return

        # 1. Fetch the daily history used for the z-score of the spread
        history = self.history([asset_a, asset_b], self.arb_lookback, Resolution.DAILY)
        if history.empty or len(history.index.levels[0]) < 2:
            return

        prices = history['close'].unstack(level=0)
        if asset_a.value not in prices.columns or asset_b.value not in prices.columns:
            return

        # 2. Calculate spread and z-score
        # Drop dates where either leg has no close, so a single missing bar
        # cannot turn the z-score into NaN (which would silently never trade).
        spread = (np.log(prices[asset_a.value]) - np.log(prices[asset_b.value])).dropna()
        if len(spread) < 2:
            return

        mean_spread = spread.mean()
        std_spread = spread.std(ddof=0)  # population std, as np.std computes

        # Also rejects a NaN standard deviation, not just an exact zero.
        if not std_spread > 0:
            return

        current_spread = spread.iloc[-1]
        z_score = (current_spread - mean_spread) / std_spread

        self.log(f"Z-Score for {asset_a.value}/{asset_b.value}: {z_score:.2f}")

        # 3. Execution logic
        is_invested_a = self.portfolio[asset_a].invested
        is_invested_b = self.portfolio[asset_b].invested

        if not is_invested_a and not is_invested_b:
            if z_score > self.entry_z_score:
                # Spread is rich: sell A, buy B.
                self.set_holdings(asset_a, -0.5)
                self.set_holdings(asset_b, 0.5)
                self.log(f"ENTRY: Short {asset_a.value}, Long {asset_b.value}")
            elif z_score < -self.entry_z_score:
                # Spread is cheap: buy A, sell B.
                self.set_holdings(asset_a, 0.5)
                self.set_holdings(asset_b, -0.5)
                self.log(f"ENTRY: Long {asset_a.value}, Short {asset_b.value}")
        elif abs(z_score) < self.exit_z_score:
            self.liquidate(asset_a)
            self.liquidate(asset_b)
            self.log(f"EXIT: Pair converged")
    # def on_data(self, data):
    #     if not self.active_pair: 
    #         self.log('no active pair')
    #         return
        
    #     asset_a, asset_b = self.active_pair
        
    #     if not data.contains_key(asset_a) or not data.contains_key(asset_b):
    #         self.log('an asset is not in the data')
    #         return
            
    #     # 1. Fetch 30-day history for the Z-score spread
    #     history = self.history([asset_a, asset_b], self.arb_lookback, Resolution.DAILY)
    #     if history.empty or len(history.index.levels[0]) < 2:
    #         return
            
    #     prices = history['close'].unstack(level=0)
    #     if asset_a.value not in prices.columns or asset_b.value not in prices.columns:
    #         return
            
    #     # 2. Calculate Spread and Z-Score
    #     log_a = np.log(prices[asset_a.value])
    #     log_b = np.log(prices[asset_b.value])
    #     spread = log_a - log_b
        
    #     mean_spread = np.mean(spread)
    #     std_spread = np.std(spread)
        
    #     if std_spread == 0: return 
            
    #     current_spread = spread.iloc[-1]
    #     z_score = (current_spread - mean_spread) / std_spread
        
    #     # 3. Execution Logic
    #     is_invested_a = self.portfolio[asset_a].invested
    #     is_invested_b = self.portfolio[asset_b].invested
        
    #     if not is_invested_a and not is_invested_b:
    #         if z_score > self.entry_z_score:
    #             self.set_holdings(asset_a, -0.5)
    #             self.set_holdings(asset_b, 0.5)
    #         elif z_score < -self.entry_z_score:
    #             self.set_holdings(asset_a, 0.5)
    #             self.set_holdings(asset_b, -0.5)
                
    #     else:
    #         if abs(z_score) < self.exit_z_score:
    #             self.liquidate(asset_a)
    #             self.liquidate(asset_b)


class SymbolData:
    """Tracks a 1-period daily LogReturn indicator and a rolling window of its values.

    Args:
        algorithm: Algorithm instance used to create (and later deregister) the indicator.
        symbol: Symbol the indicator is attached to.
        lookback: Capacity of the rolling window of indicator values.
    """
    def __init__(self, algorithm, symbol, lookback):
        # Kept so dispose() can deregister the indicator it created.
        self._algorithm = algorithm
        self.symbol = symbol
        # Initialize the LogReturn indicator
        self.logr = algorithm.logr(symbol, 1, Resolution.DAILY)
        # Create a rolling window to store the indicator's output
        self.window = RollingWindow[float](lookback)

        # Attach the event handler so the window updates automatically
        self.logr.updated += self.logr_updated

    def logr_updated(self, sender, updated):
        """Indicator `updated` handler: push the new indicator value into the window."""
        self.window.add(updated.value)

    def is_ready(self):
        """Return whether the rolling window reports itself ready."""
        return self.window.is_ready

    def dispose(self):
        """Release everything this tracker registered when its symbol leaves the universe."""
        # Detach the event handler so the indicator no longer references this tracker.
        self.logr.updated -= self.logr_updated
        # Also deregister the indicator itself; otherwise the algorithm keeps it
        # registered after the symbol is gone.
        self._algorithm.deregister_indicator(self.logr)