Overall Statistics
Total Orders
440
Average Win
0.53%
Average Loss
-0.44%
Compounding Annual Return
12.301%
Drawdown
3.400%
Expectancy
0.306
Start Equity
100000
End Equity
133924
Net Profit
33.924%
Sharpe Ratio
0.556
Sortino Ratio
0.433
Probabilistic Sharpe Ratio
85.147%
Loss Rate
40%
Win Rate
60%
Profit-Loss Ratio
1.19
Alpha
0.024
Beta
0.062
Annual Standard Deviation
0.054
Annual Variance
0.003
Information Ratio
-0.586
Tracking Error
0.132
Treynor Ratio
0.486
Total Fees
$946.00
Estimated Strategy Capacity
$110000000.00
Lowest Capacity Asset
NQ YYFADOG4CO3L
Portfolio Turnover
165.95%
Drawdown Recovery
61
"""
backtest.py — Orchestrator
=============================
Runs the full pipeline for one or many instrument × config combos.
Returns structured results with full trade logs.

Usage in notebook:
    import config
    from backtest import run_single, run_sweep
    results = run_single(qb, "NQ")
    table, all_results = run_sweep(qb, config.SWEEP_CONFIG)
"""
from AlgorithmImports import *
import numpy as np
import pandas as pd

from config import PARAMS, INSTRUMENTS, EMISSION_CONFIGS, GATE_CONFIGS, \
    EXIT_CONFIGS, DIRECTION_CONFIGS, SWEEP_CONFIG
import data_pipeline as dp
import hmm_model as hm
import tensor as tn
import branch2 as b2
# NOTE: ma_signal, decision, sizing are in the combined file.
# In QC, these would be separate files. For now, import from the combined module.
# Adjust imports based on your actual file structure:
# import ma_signal as ma
# import decision as dec
# import sizing as sz


def run_single(qb, ticker, gate_name=None, emission_name="lr+gk",
               exit_name="fixed_7", direction_override=None,
               lookback_days=None, verbose=True):
    """
    Full pipeline for one instrument × one config.

    Steps: data → HMM → tensors → KL predictor → stress (branch 2) →
    MA crossovers → forward returns → gated entries → sizing →
    trade simulation → optional printed summary.

    Parameters
    ----------
    qb : QuantBook
    ticker : str — key into config.INSTRUMENTS
    gate_name : str or None — gate name passed to decision.generate_entries;
        None uses the instrument's default gate
    emission_name : str — key into config.EMISSION_CONFIGS
    exit_name : str — key into config.EXIT_CONFIGS
    direction_override : dict or None — {"cross": ..., "direction": ...}
        (e.g. a value of config.DIRECTION_CONFIGS) replacing the instrument's
        default cross/direction
    lookback_days : int or None — forwarded to data_pipeline.run_pipeline
    verbose : bool — print progress and a results summary

    Returns
    -------
    dict with keys "df", "sessions", "tensors", "xo", "entries", "trades",
    "model", "hmm_results", "kp_results" and "config".

    Raises
    ------
    KeyError — unknown ticker, emission name or exit name. These are resolved
        before any data is pulled, so a typo fails immediately rather than
        after the full data/HMM/tensor pipeline has run.
    """
    # Resolve every config name up front (fail fast on typos).
    cfg = INSTRUMENTS[ticker].copy()
    em_cfg = EMISSION_CONFIGS[emission_name]
    exit_cfg = EXIT_CONFIGS[exit_name]

    if direction_override:
        cfg["cross"] = direction_override["cross"]
        cfg["direction"] = direction_override["direction"]
    if gate_name is None:
        gate_name = cfg["gate"]

    if verbose:
        print(f"\n{'='*60}")
        print(f"  {ticker} | {emission_name} | {gate_name} | {exit_name}")
        print(f"  {cfg['cross']} {'long' if cfg['direction']==1 else 'short'}")
        print(f"{'='*60}")

    # ── 1. Data ──
    df, sessions, has_quotes = dp.run_pipeline(qb, cfg, lookback_days)

    # ── 2. HMM ──
    emission_cols = em_cfg["columns"]
    missing = [c for c in emission_cols if c not in df.columns]
    if missing:
        print(f"  WARNING: Missing emission columns {missing}, falling back to lr+gk")
        emission_cols = ["LR_norm", "GK_norm"]

    # Only the training split is used to fit the model here.
    train_sess, _val_sess = hm.get_train_sessions(sessions)
    model = hm.fit_hmm(df, emission_cols, train_sess)

    if verbose:
        print(f"  Running HMM on {len(sessions)} sessions...")
    hmm_results = hm.run_hmm_all_sessions(df, model, emission_cols, sessions)

    # Column 1 of the HMM alpha matrix, one value per bar.
    df["Alpha_P1"] = hmm_results["alpha"][:, 1]

    # ── 3. Tensor ──
    if verbose:
        print(f"  Building tensors...")
    all_tensors = tn.build_tensors_all_sessions(df, hmm_results, sessions, model)
    features = tn.extract_features(all_tensors, hmm_results["alpha"], df["Active"].values)
    df = tn.attach_features_to_df(df, features, sessions)

    # ── 3b. KL Predictor ──
    if verbose:
        print(f"  Running KL predictor...")
    import kl_predictor as kp
    df, kp_results = kp.run_kl_predictor(df, hmm_results, sessions)

    # ── 4. Branch 2 (stress) ──
    df = b2.attach_stress(df, hmm_results)

    # ── 5. MA crossovers ──
    from ma_signal import compute_crossovers
    df = compute_crossovers(df)

    # ── 6. Forward returns ──
    df = dp.compute_forward_returns(df)

    # ── 7. Decision (entries) ──
    from decision import generate_entries
    entries, xo = generate_entries(df, cfg, gate_name)

    # ── 8. Sizing ──
    from sizing import compute_conviction_column
    if len(entries) > 0:
        entries = compute_conviction_column(entries)

    # ── 9. Trade simulation ──
    trades = simulate_trades(df, entries, cfg, exit_cfg, all_tensors)

    # ── 10. Summary ──
    if verbose and len(trades) > 0:
        cost = cfg["cost_pts"]
        pv = cfg["point_value"]
        net = trades["final_pnl"].values - cost
        gross_win = net[net > 0].sum()
        gross_loss = -net[net <= 0].sum()
        # With wins but no losses the profit factor is unbounded: report inf
        # rather than 0, which would rank a perfect run as the worst one.
        if gross_loss > 0:
            pf = gross_win / gross_loss
        else:
            pf = np.inf if gross_win > 0 else 0.0
        print(f"\n  RESULTS: n={len(trades)}, net={net.mean():+.3f} pts, "
              f"hit={(net>0).mean():.3f}, PF={pf:.2f}")
        print(f"  Total $: ${net.sum() * pv:+,.0f}")

    return {
        "df": df, "sessions": sessions, "tensors": all_tensors,
        "xo": xo, "entries": entries, "trades": trades,
        "model": model, "hmm_results": hmm_results,
        "kp_results": kp_results,
        "config": {
            "ticker": ticker, "instrument": cfg, "gate": gate_name,
            "emission": emission_name, "exit": exit_name,
        },
    }
    
# ══════════════════════════════════════════════════════════════
# TRADE SIMULATION
# ══════════════════════════════════════════════════════════════

def simulate_trades(df, entries, instrument_cfg, exit_cfg, all_tensors):
    """
    Run exits.simulate_trade once per gated entry and collect the records.

    Parameters
    ----------
    df : pd.DataFrame — full bar frame; every index label of `entries`
        must also be a label of df.index
    entries : pd.DataFrame — gated entry bars, indexed by labels shared with df
    instrument_cfg : dict — must provide "direction" and "cost_pts"
    exit_cfg : dict — from config.EXIT_CONFIGS
    all_tensors : [N, 5, 11]

    Returns
    -------
    trades : pd.DataFrame — one row per record returned by simulate_trade
        (None results are dropped); an empty DataFrame when there are
        no entries.
    """
    if not len(entries):
        return pd.DataFrame()

    from exits import simulate_trade

    # Arguments identical for every trade.
    common = dict(
        df=df,
        direction=instrument_cfg["direction"],
        cost=instrument_cfg["cost_pts"],
        exit_cfg=exit_cfg,
        all_tensors=all_tensors,
    )

    trade_records = []
    for label in entries.index:
        outcome = simulate_trade(
            entry_loc=df.index.get_loc(label),
            entry_row=entries.loc[label],
            **common,
        )
        if outcome is not None:
            trade_records.append(outcome)

    return pd.DataFrame(trade_records)


# ══════════════════════════════════════════════════════════════
# SWEEP
# ══════════════════════════════════════════════════════════════

def run_sweep(qb, sweep_cfg=None, verbose=False):
    """
    Run combinatoric sweep across instruments × configs.

    Every combination of instrument × direction × emission × gate × exit is
    run through run_single(). A configuration that raises is reported,
    recorded as an error row, and the sweep moves on to the next one.

    Parameters
    ----------
    qb : QuantBook
    sweep_cfg : dict or None — from config.SWEEP_CONFIG. "instruments" is
        required; "emissions" / "gates" / "exits" / "directions" are
        optional. A missing or empty value falls back to "lr+gk" / the
        instrument's default gate / "fixed_7" / the instrument's default
        direction.
    verbose : bool — forwarded to run_single

    Returns
    -------
    results_table : pd.DataFrame — one row per config with summary stats.
        Every row, including error rows, carries key, ticker, emission,
        gate, exit and direction, so failures can be filtered like
        successes. pf is inf for a config with winning trades and no
        losing ones.
    all_results : dict — full run_single() results keyed by config key,
        for every run_single() call that returned
    """
    if sweep_cfg is None:
        sweep_cfg = SWEEP_CONFIG

    instruments = sweep_cfg["instruments"]
    emissions = sweep_cfg.get("emissions") or ["lr+gk"]
    gates = sweep_cfg.get("gates") or [None]  # None = instrument default
    exits = sweep_cfg.get("exits") or ["fixed_7"]
    directions = sweep_cfg.get("directions")
    dir_list = list(directions) if directions else [None]  # None = instrument default

    rows = []
    all_results = {}
    total = len(instruments) * len(emissions) * len(gates) * len(exits) * len(dir_list)

    print(f"SWEEP: {total} configurations")
    i = 0

    for ticker in instruments:
        cfg = INSTRUMENTS[ticker]

        for dir_name in dir_list:
            dir_override = DIRECTION_CONFIGS[dir_name] if dir_name else None

            for em_name in emissions:
                for gate_name in gates:
                    g = gate_name or cfg["gate"]
                    for exit_name in exits:
                        i += 1
                        key = f"{ticker}_{em_name}_{g}_{exit_name}"
                        if dir_name:
                            key += f"_{dir_name}"

                        print(f"\n[{i}/{total}] {key}")

                        # Identifying columns shared by success, empty and error rows.
                        base_row = {
                            "key": key, "ticker": ticker,
                            "emission": em_name, "gate": g,
                            "exit": exit_name,
                            "direction": dir_name or "default",
                        }

                        try:
                            res = run_single(
                                qb, ticker,
                                gate_name=g,
                                emission_name=em_name,
                                exit_name=exit_name,
                                direction_override=dir_override,
                                verbose=verbose,
                            )
                            all_results[key] = res

                            # Summary row
                            trades = res["trades"]
                            cost = cfg["cost_pts"]
                            pv = cfg["point_value"]
                            if len(trades) > 0:
                                net = trades["final_pnl"].values - cost
                                gross_win = net[net > 0].sum()
                                gross_loss = -net[net <= 0].sum()
                                # Wins with no losses → unbounded profit factor.
                                if gross_loss > 0:
                                    pf = gross_win / gross_loss
                                else:
                                    pf = np.inf if gross_win > 0 else 0.0
                                mfe_mean = trades["mfe"].mean()
                                rows.append({
                                    **base_row,
                                    "n_trades": len(trades),
                                    "net_mean": net.mean(),
                                    "hit_rate": (net > 0).mean(),
                                    "pf": pf,
                                    "total_pts": net.sum(),
                                    "total_dollar": net.sum() * pv,
                                    "mfe_mean": mfe_mean,
                                    "mae_mean": trades["mae"].mean(),
                                    "capture": trades["final_pnl"].mean() / mfe_mean if mfe_mean > 0 else 0,
                                })
                            else:
                                rows.append({**base_row, "n_trades": 0})

                        except Exception as e:
                            # Sweep boundary: record the failure and keep going.
                            print(f"  FAILED: {e}")
                            rows.append({**base_row, "error": str(e)})

    results_table = pd.DataFrame(rows)
    print(f"\nSweep complete: {len(results_table)} configs")

    return results_table, all_results
# region imports
from AlgorithmImports import *
# endregion

# Your New Python File

import numpy as np
from config import PARAMS
"""
═══════════════════════════════════════════════════════════════
FILE 1: branch2.py — Stress EWMA (Session-Reset)
═══════════════════════════════════════════════════════════════
Computes fast/slow EWMA of KL values with session resets.
In live: fed by tensor slot 2 KL (most recent bar with backward info).
"""
import numpy as np
from config import PARAMS


def compute_stress(kl_array, session_dates, fast_span=None, slow_span=None):
    """
    Session-reset EWMA stress indicator.

    For each session the fast and slow EWMAs are seeded with that session's
    first valid (non-NaN) KL value, so no state carries over from the
    previous session. Within a session, a NaN KL value is replaced by the
    current fast EWMA ("carry forward"). That leaves the fast EWMA unchanged
    and pulls the slow EWMA toward it.

    Parameters
    ----------
    kl_array : array-like [N] — per-bar KL values (NaN allowed)
    session_dates : array-like [N] — session identifier per bar; a change of
        value between consecutive bars starts a new session
    fast_span, slow_span : int or None — EWMA spans
        (default PARAMS["ewma_fast"] / PARAMS["ewma_slow"])

    Returns
    -------
    stress : np.ndarray [N] — fast EWMA - slow EWMA. Bars that precede the
        first valid KL value of their session get 0.
    """
    if fast_span is None:
        fast_span = PARAMS["ewma_fast"]
    if slow_span is None:
        slow_span = PARAMS["ewma_slow"]

    kl = np.asarray(kl_array, dtype=float)
    n = len(kl)
    kf = np.zeros(n)
    ks = np.zeros(n)
    af = 2.0 / (fast_span + 1)
    asl = 2.0 / (slow_span + 1)

    prev_sess = None
    seeded = False  # has the current session produced a valid KL value yet?
    fv = sv = 0.0

    for i in range(n):
        se = session_dates[i]
        if i == 0 or se != prev_sess:
            # Session reset: forget all EWMA state from the prior session.
            prev_sess = se
            seeded = False

        kv = kl[i]
        if not seeded:
            if np.isnan(kv):
                # Nothing observed yet this session; kf/ks stay 0 -> stress 0.
                # (Previously the prior session's fast EWMA leaked in here.)
                continue
            fv = sv = kv
            seeded = True
        else:
            if np.isnan(kv):
                kv = fv  # carry forward
            fv = af * kv + (1 - af) * fv
            sv = asl * kv + (1 - asl) * sv

        kf[i] = fv
        ks[i] = sv

    return kf - ks


def attach_stress(df, hmm_results, mode="B"):
    """
    Return a copy of ``df`` with two session-reset stress columns added.

    Both variants are always computed. ``mode`` is accepted for interface
    compatibility but is not read by this function.

      stress_A — from hmm_results["kl_a"] (full backward KL; for comparison)
      stress_B — from column 2 of hmm_results["kl_b_slots"]
                 (tensor slot 2 KL; the live-feasible variant)

    Parameters
    ----------
    df : pd.DataFrame — must contain a "SessionDate" column
    hmm_results : dict from hmm_model.run_hmm_all_sessions()
    mode : str — ignored (see above)

    Returns
    -------
    pd.DataFrame — copy of df with "stress_A" and "stress_B" columns
    """
    out = df.copy()
    session_ids = out["SessionDate"].values

    # Full-backward KL → comparison-only stress.
    out["stress_A"] = compute_stress(hmm_results["kl_a"], session_ids)

    # Tensor slot 2 KL → live-feasible stress.
    slot2_kl = hmm_results["kl_b_slots"][:, 2]
    out["stress_B"] = compute_stress(slot2_kl, session_ids)

    return out
"""
config.py — Central Configuration
====================================
Single source of truth for all parameters, instrument configs,
gate definitions, exit strategies, and sweep configurations.
"""
from AlgorithmImports import *

# ══════════════════════════════════════════════════════════════
# 1. GLOBAL PARAMETERS
# ══════════════════════════════════════════════════════════════

# Global defaults shared by all modules. Per-instrument values (e.g. the RTH
# window) live in INSTRUMENTS below.
PARAMS = {
    # Data
    "lookback_days": 1500,     # default history window (days) for data_pipeline.pull_data
    "rth_start": "09:30",      # default RTH window for data_pipeline.build_5min_bars
    "rth_end": "15:55",
    # NOTE(review): build_5min_bars resamples with a literal "5min" rather than
    # reading this key, so changing it alone does not change the bar size.
    "bar_size": "5min",
    "warmup_bars": 6,          # 5-min bars dropped at the start of each session
    "gk_floor": 0.00005,       # bars with GK at or below this are marked inactive
    "hmm_ewma_span": 390,      # EWMA span for HMM emission normalization

    # HMM
    "hmm_states": 2,
    "hmm_iter": 50,
    "hmm_train_frac": 0.40,    # presumably the share of sessions used to fit — confirm in hmm_model
    "hmm_random_state": 42,

    # Branch 2 (stress EWMA)
    "ewma_fast": 3,            # default spans for branch2.compute_stress
    "ewma_slow": 15,
    "stress_threshold": 0.02,  # stress gates pass when stress_B < this

    # Entry gates
    "kl_slope_threshold": 0.0,    # KL gates pass when T_kl_slope_3s > this
    "alpha_std_threshold": 0.05,  # alpha gate passes when T_alpha_std < this
    "no_trade_after_hour": 15,

    # MA
    "ma_fast": 8,
    "ma_slow": 20,

    # Tensor — per-bar tensor is tensor_slots x tensor_cols. The *_idx pairs
    # appear to be half-open [start, stop) column ranges; together with
    # tensor_kl_idx they tile columns 0..10.
    "tensor_slots": 5,
    "tensor_cols": 11,
    "tensor_gamma_idx": (0, 2),
    "tensor_alpha_idx": (2, 4),
    "tensor_proj_idx": (4, 6),
    "tensor_kl_idx": 6,
    "tensor_micro_idx": (7, 11),

    # Exit (defaults)
    "hold_bars": 7,
    "max_hold_bars": 12,

    # Sizing
    "conviction_weights": {    # weights sum to 1.0
        "stress": 0.30,
        "age": 0.25,
        "alpha": 0.20,
        "kl_slope": 0.25,
    },

    # Tiers are listed from highest to lowest threshold (presumably the first
    # match wins — confirm in sizing).
    "contract_tiers": [
        {"threshold": 0.7, "contracts": 3},
        {"threshold": 0.5, "contracts": 2},
        {"threshold": 0.0, "contracts": 1},
    ],

    # Validation
    "warmup_sessions": 60,
    "bootstrap_n": 10000,
}


# ══════════════════════════════════════════════════════════════
# 1b. ADAPTIVE EXIT PROFILES (per instrument)
# ══════════════════════════════════════════════════════════════

# Per-instrument adaptive exit parameters. All *_pts / avg_* values are in
# price points of the respective instrument. These profiles are merged into
# EXIT_CONFIGS at the end of section 5, so "adaptive_nq" etc. can be passed
# as an exit_name.
# NOTE(review): avg_profit / avg_mfe look like historical averages used to
# scale the thresholds — confirm their meaning in exits.py.
ADAPTIVE_EXIT_PROFILES = {
    "adaptive_nq": {
        "type": "adaptive",
        "max_bars": 10,
        "quick_profit_mult": 1.5,
        "quick_profit_bars": 2,
        "avg_profit": 32.0,
        "max_mae_pts": 35.0,
        "trail_activation": 0.5,
        "trail_pct": 0.4,
        "avg_mfe": 32.5,
    },
    "adaptive_es": {
        "type": "adaptive",
        "max_bars": 10,
        "quick_profit_mult": 1.5,
        "quick_profit_bars": 2,
        "avg_profit": 7.0,
        "max_mae_pts": 8.0,
        "trail_activation": 0.5,
        "trail_pct": 0.4,
        "avg_mfe": 7.2,
    },
    "adaptive_ym": {
        "type": "adaptive",
        "max_bars": 12,
        "quick_profit_mult": 1.0,
        "quick_profit_bars": 3,
        "avg_profit": 73.0,
        "max_mae_pts": 80.0,
        "trail_activation": 0.4,
        "trail_pct": 0.35,
        "avg_mfe": 60.0,
    },
}


# ══════════════════════════════════════════════════════════════
# 2. INSTRUMENT CONFIGURATIONS
# ══════════════════════════════════════════════════════════════

# Per-instrument configuration. Fields:
#   future       — QuantConnect future ticker passed to qb.add_future
#   cross        — which MA crossover triggers entries ("bull_x" / "bear_x")
#   direction    — 1 = long, -1 = short
#   gate         — default gate name, used when run_single's gate_name is None
#   exit_profile — name of an ADAPTIVE_EXIT_PROFILES entry, or None
#                  NOTE(review): run_single selects exits via its exit_name
#                  argument, not this key — confirm where it is consumed.
#   cost_pts     — cost in points subtracted from each trade's final_pnl
#   point_value  — dollars per point, used for dollar P&L
#   rth_start / rth_end — RTH window for this instrument
INSTRUMENTS = {
    "NQ": {
        "future": Futures.Indices.NASDAQ100EMini,
        "cross": "bull_x",
        "direction": 1,
        "gate": "stress",
        "exit_profile": "adaptive_nq",
        "cost_pts": 1.10,
        "point_value": 20.0,
        "rth_start": "09:30",
        "rth_end": "15:55",
        "notes": "Momentum long. Highest retail/systematic participation.",
    },
    "ES": {
        "future": Futures.Indices.SP500EMini,
        "cross": "bear_x",
        "direction": 1,
        "gate": "kl+alpha+stress",
        "exit_profile": "adaptive_es",
        "cost_pts": 0.58,
        "point_value": 50.0,
        "rth_start": "09:30",
        "rth_end": "15:55",
        "notes": "Contrarian long. Regime-dependent.",
    },
    "YM": {
        "future": Futures.Indices.Dow30EMini,
        "cross": "bull_x",
        "direction": -1,
        "gate": "kl+alpha+stress",
        "exit_profile": "adaptive_ym",
        "cost_pts": 2.00,
        "point_value": 5.0,
        "rth_start": "09:30",
        "rth_end": "15:55",
        "notes": "Mean-reversion short. Degenerate HMM [0.393, 0.803].",
    },
    "RTY": {
        "future": Futures.Indices.Russell2000EMini,
        "cross": "bear_x",
        "direction": 1,
        "gate": "kl+alpha+stress",
        # NOTE(review): RTY reuses the ES profile, whose thresholds are in ES
        # points — confirm they suit RTY's price scale.
        "exit_profile": "adaptive_es",
        "cost_pts": 0.40,
        "point_value": 50.0,
        "rth_start": "09:30",
        "rth_end": "15:55",
        "notes": "Contrarian long. Best quarterly consistency.",
    },
    "CL": {
        "future": Futures.Energies.CrudeOilWTI,
        "cross": "bear_x",
        "direction": -1,
        "gate": "kl+alpha+stress",
        "exit_profile": None,
        "cost_pts": 0.04,
        "point_value": 1000.0,
        "rth_start": "09:00",
        "rth_end": "14:25",
        "notes": "DEAD. Commercial hedger dominated. No edge expected.",
    },
    "GC": {
        "future": Futures.Metals.Gold,
        "cross": "bull_x",
        "direction": 1,
        "gate": "kl+alpha+stress",
        "exit_profile": None,
        "cost_pts": 0.40,
        "point_value": 100.0,
        "rth_start": "09:00",
        "rth_end": "14:25",
        "notes": "DEAD. Macro fund dominated. No edge expected.",
    },
}


# ══════════════════════════════════════════════════════════════
# 3. HMM EMISSION CONFIGURATIONS
# ══════════════════════════════════════════════════════════════

# HMM emission sets: "columns" are dataframe columns fed to the HMM.
# backtest.run_single falls back to ["LR_norm", "GK_norm"] when any listed
# column is missing from the dataframe (e.g. Spread_norm without quote data).
EMISSION_CONFIGS = {
    "lr+gk": {
        "columns": ["LR_norm", "GK_norm"],
        "notes": "Default. Log returns + Garman-Klass vol.",
    },
    "gk_only": {
        "columns": ["GK_norm"],
        "notes": "Volatility only. Tests if returns add noise.",
    },
    "lr_only": {
        "columns": ["LR_norm"],
        "notes": "Returns only. Tests if GK is redundant.",
    },
    "lr+rvol": {
        "columns": ["LR_norm", "RVol_norm"],
        "notes": "Returns + realized vol (std of returns).",
    },
    "lr+gk+spread": {
        "columns": ["LR_norm", "GK_norm", "Spread_norm"],
        "notes": "3-emission with spread. Requires quote data.",
    },
}


# ══════════════════════════════════════════════════════════════
# 4. GATE CONFIGURATIONS
# ══════════════════════════════════════════════════════════════

def _gate_stress(xo):
    """Stress-only gate: keep crossovers whose stress_B is below the threshold."""
    stress_b = xo["stress_B"]
    limit = PARAMS["stress_threshold"]
    return stress_b < limit

def _gate_kl_alpha_stress(xo):
    """KL slope above threshold AND alpha std below threshold AND low stress_B."""
    kl_rising = xo["T_kl_slope_3s"] > PARAMS["kl_slope_threshold"]
    alpha_stable = xo["T_alpha_std"] < PARAMS["alpha_std_threshold"]
    low_stress = xo["stress_B"] < PARAMS["stress_threshold"]
    return kl_rising & alpha_stable & low_stress

def _gate_kl_stress(xo):
    """KL slope above threshold AND low stress_B (no alpha filter)."""
    kl_rising = xo["T_kl_slope_3s"] > PARAMS["kl_slope_threshold"]
    low_stress = xo["stress_B"] < PARAMS["stress_threshold"]
    return kl_rising & low_stress

def _gate_micro_stress(xo):
    """
    Micro GK std above its median AND low stress_B.

    NOTE(review): the median is taken over every row of xo at once; if xo
    spans the whole backtest, this threshold uses later crossovers.
    """
    gk_std = xo["T_micro_gk_std"]
    high_micro_vol = gk_std > gk_std.median()
    low_stress = xo["stress_B"] < PARAMS["stress_threshold"]
    return high_micro_vol & low_stress

def _gate_argmax_trans(xo):
    """T_am_alpha_transition_now == 1 (an alpha-argmax transition flag) AND low stress_B."""
    transition_now = xo["T_am_alpha_transition_now"] == 1
    low_stress = xo["stress_B"] < PARAMS["stress_threshold"]
    return transition_now & low_stress

def _gate_argmax_majvol_kl(xo):
    """
    T_am_alpha_majority_vol == 1 AND KL slope above threshold (no stress filter).

    Uses PARAMS["kl_slope_threshold"] (default 0.0), like the other KL-slope
    gates, so tuning that threshold affects every KL gate consistently.
    """
    return (
        (xo["T_am_alpha_majority_vol"] == 1) &
        (xo["T_kl_slope_3s"] > PARAMS["kl_slope_threshold"])
    )

def _gate_none(xo):
    import pandas as pd
    return pd.Series(True, index=xo.index)

def _gate_micro_thesis_a(xo):
    """GK variability high + close position stable. t=3.35 on NQ crossovers."""
    return (
        (xo["T_micro_gk_std"] > xo["T_micro_gk_std"].median()) &
        (xo["T_micro_cpos_std"] < xo["T_micro_cpos_std"].median())
    )

def _gate_micro_thesis_b(xo):
    """KP surprise high + KL slope declining. t=3.07 on NQ crossovers."""
    return (
        (xo["KP_surprise"] > xo["KP_surprise"].median()) &
        (xo["T_kl_slope_3s"] < xo["T_kl_slope_3s"].median())
    )

def _gate_micro_thesis_c(xo):
    """Close position trending up + KL more stable than predicted. t=3.05 on NQ crossovers."""
    return (
        (xo["T_micro_cpos_slope"] > xo["T_micro_cpos_slope"].median()) &
        (xo["KP_pred_vs_actual"] > xo["KP_pred_vs_actual"].median())
    )

def _gate_thesis_union(xo):
    """Logical OR of thesis gates A, B and C (evaluated in that order)."""
    masks = (_gate_micro_thesis_a(xo), _gate_micro_thesis_b(xo), _gate_micro_thesis_c(xo))
    union = masks[0]
    for mask in masks[1:]:
        union = union | mask
    return union

def _gate_thesis_union_stress(xo):
    """Any thesis gate firing, restricted to rows with low stress_B."""
    thesis_fired = _gate_thesis_union(xo)
    low_stress = xo["stress_B"] < PARAMS["stress_threshold"]
    return thesis_fired & low_stress

# Gate registry: name -> callable(xo) returning a boolean mask aligned with
# xo's rows. INSTRUMENTS[...]["gate"] and SWEEP_CONFIG["gates"] use these names.
# NOTE(review): the micro/thesis gates threshold on xo[...].median(), a median
# over every row of xo at once. If xo spans the whole backtest, later
# crossovers influence earlier decisions (possible look-ahead) — confirm
# this is acceptable for the research use.
GATE_CONFIGS = {
    "none":                  _gate_none,
    "stress":                _gate_stress,
    "kl+alpha+stress":       _gate_kl_alpha_stress,
    "kl+stress":             _gate_kl_stress,
    "micro+stress":          _gate_micro_stress,
    "am_trans+stress":       _gate_argmax_trans,
    "am_majvol+kl":          _gate_argmax_majvol_kl,
    "thesis_a":              _gate_micro_thesis_a,
    "thesis_b":              _gate_micro_thesis_b,
    "thesis_c":              _gate_micro_thesis_c,
    "thesis_union":          _gate_thesis_union,
    "thesis_union+stress":   _gate_thesis_union_stress,
}


# ══════════════════════════════════════════════════════════════
# 5. EXIT CONFIGURATIONS
# ══════════════════════════════════════════════════════════════

# Exit strategies keyed by name (the exit_name argument of run_single). Each
# dict is passed unchanged to exits.simulate_trade as exit_cfg; "type"
# presumably selects the exit logic there — confirm in exits.py.
# "max_bars" / "bars" are counted in 5-min bars; *_pts values are in
# instrument price points, so the thesis_exit_* variants are instrument-specific.
EXIT_CONFIGS = {
    "fixed_5": {
        "type": "fixed",
        "bars": 5,
    },
    "fixed_7": {
        "type": "fixed",
        "bars": 7,
    },
    "fixed_10": {
        "type": "fixed",
        "bars": 10,
    },
    "atr_15_20": {
        "type": "atr",
        "tp_mult": 1.5,
        "sl_mult": 2.0,
        "max_bars": 10,
        "atr_period": 14,
    },
    "atr_20_25": {
        "type": "atr",
        "tp_mult": 2.0,
        "sl_mult": 2.5,
        "max_bars": 10,
        "atr_period": 14,
    },
    "atr_10_15": {
        "type": "atr",
        "tp_mult": 1.0,
        "sl_mult": 1.5,
        "max_bars": 8,
        "atr_period": 14,
    },
    "tensor_monitored": {
        "type": "tensor",
        "max_bars": 10,
        "exit_on_regime_flip": True,
        "exit_on_stress_spike": True,
        "stress_exit_threshold": 0.05,
        "exit_on_micro_spike": False,
    },
    "mae_informed": {
        "type": "mae",
        "max_bars": 10,
        "mae_percentile": 75,
    },
    "trailing_mfe": {
        "type": "trailing",
        "max_bars": 10,
        "trail_activation": 0.5,
        "trail_pct": 0.4,
    },
    "combined_atr_tensor": {
        "type": "combined",
        "atr_tp_mult": 1.5,
        "atr_sl_mult": 2.5,
        "max_bars": 10,
        "atr_period": 14,
        "exit_on_regime_flip": True,
        "exit_on_stress_spike": True,
        "stress_exit_threshold": 0.05,
    },
    # ── Thesis exits v2: revised from parameter sweep + ablation ──
    # Key changes: wider quick profit, looser trailing, regime flip off
    # (regime flip demoted to trail-tightener, not hard exit)
    "thesis_exit_nq": {
        "type": "thesis",
        "max_bars": 10,
        "quick_profit_pts": 25.0,       # was 15 — let early spikes develop more
        "quick_profit_bars": 2,
        "mae_stop_pts": 30.0,           # was 25 — slight improvement at 30
        "trail_activate_pts": 20.0,     # was 15 — let runners build before trailing
        "trail_pct": 0.55,              # was 0.45 — give back less of peak
        "exit_on_regime_flip": False,   # was True — ablation showed it hurts
        "exit_on_stress_spike": True,
        "stress_exit_threshold": 0.04,
        "exit_on_micro_collapse": True,
        "micro_collapse_ratio": 0.3,
        # Regime flip tightens trail instead of hard exit
        "regime_flip_tightens_trail": True,
        "regime_flip_trail_pct": 0.30,  # tighter trail after flip
    },
    "thesis_exit_es": {
        "type": "thesis",
        "max_bars": 10,
        "quick_profit_pts": 5.0,        # ES-scaled (avg MFE ~7)
        "quick_profit_bars": 2,
        "mae_stop_pts": 6.0,
        "trail_activate_pts": 4.0,
        "trail_pct": 0.55,
        "exit_on_regime_flip": False,
        "exit_on_stress_spike": True,
        "stress_exit_threshold": 0.04,
        "exit_on_micro_collapse": True,
        "micro_collapse_ratio": 0.3,
        "regime_flip_tightens_trail": True,
        "regime_flip_trail_pct": 0.30,
    },
    "thesis_exit_ym": {
        "type": "thesis",
        "max_bars": 12,
        "quick_profit_pts": 50.0,       # YM-scaled (avg MFE ~60)
        "quick_profit_bars": 3,
        "mae_stop_pts": 60.0,
        "trail_activate_pts": 35.0,
        "trail_pct": 0.50,
        "exit_on_regime_flip": False,
        "exit_on_stress_spike": True,
        "stress_exit_threshold": 0.04,
        "exit_on_micro_collapse": True,
        "micro_collapse_ratio": 0.3,
        "regime_flip_tightens_trail": True,
        "regime_flip_trail_pct": 0.30,
    },
}

# Register adaptive profiles as exit configs
EXIT_CONFIGS.update(ADAPTIVE_EXIT_PROFILES)


# ══════════════════════════════════════════════════════════════
# 6. DIRECTION CONFIGS (for sweep)
# ══════════════════════════════════════════════════════════════

# Overrides for run_single(direction_override=...): "cross" and "direction"
# replace the instrument's defaults (direction 1 = long, -1 = short).
DIRECTION_CONFIGS = {
    "long_bull":  {"cross": "bull_x", "direction":  1},
    "short_bull": {"cross": "bull_x", "direction": -1},
    "long_bear":  {"cross": "bear_x", "direction":  1},
    "short_bear": {"cross": "bear_x", "direction": -1},
}


# ══════════════════════════════════════════════════════════════
# 7. SWEEP CONFIGURATION
# ══════════════════════════════════════════════════════════════

# Default grid for backtest.run_sweep: instruments x emissions x gates x exits
# (x directions when set). "directions": None keeps each instrument's own
# cross/direction.
# NOTE(review): thesis_exit_nq/es/ym use instrument-scaled point thresholds,
# but the sweep crosses every exit with every instrument (e.g. ES with the
# NQ-scaled exit) — confirm this is intended.
SWEEP_CONFIG = {
    "instruments": ["NQ", "ES", "YM", "RTY"],
    "emissions": ["lr+gk", "gk_only"],
    "directions": None,
    "gates": ["none", "stress", "kl+alpha+stress",
              "thesis_a", "thesis_b", "thesis_union+stress"],
    "exits": ["fixed_7", "atr_10_15", "thesis_exit_nq",
              "thesis_exit_es", "thesis_exit_ym"],
}


# ══════════════════════════════════════════════════════════════
# 8. MICRO FEATURE DEFINITIONS
# ══════════════════════════════════════════════════════════════

# Base names of per-bar micro features (imported by data_pipeline).
MICRO_FEATURES = [
    "micro_gk",
    "micro_range",
    "micro_vol",
    "micro_cpos",
]

# Quote-based micro features — presumably only built when quote data is
# available; confirm in data_pipeline.
MICRO_FEATURES_QUOTE = [
    "micro_spread",
    "micro_imbalance",
]


# ══════════════════════════════════════════════════════════════
# 9. FEATURE NAMES (for diagnostics)
# ══════════════════════════════════════════════════════════════

# Names of tensor-derived features, for diagnostics. The gates read them from
# crossover frames with a "T_" prefix (e.g. "T_kl_slope_3s").
# NOTE: this is a set literal, so iteration order is arbitrary (not the
# order written here) and duplicate names would collapse silently.
TENSOR_FEATURE_NAMES = {
    "kl_slope_3s", "kl_slope_4s", "kl_mean_3s", "kl_max_3s",
    "kl_min_3s", "kl_range_3s", "kl_slot2", "kl_accel",
    "alpha_std", "alpha_mean", "alpha_conv", "alpha_slope", "alpha_range",
    "am_alpha_s0", "am_alpha_s1", "am_alpha_s2", "am_alpha_s3",
    "am_alpha_sum", "am_alpha_unan_quiet", "am_alpha_unan_vol",
    "am_alpha_unanimous", "am_alpha_majority_quiet", "am_alpha_majority_vol",
    "am_alpha_flips", "am_alpha_last_flip", "am_alpha_first_flip",
    "am_alpha_pattern", "am_pattern_stable_quiet", "am_pattern_stable_vol",
    "am_pattern_entering_quiet", "am_pattern_entering_vol", "am_pattern_mixed",
    "am_alpha_current_quiet", "am_alpha_transition_now", "am_alpha_recent_trans",
    "gamma_std", "gamma_mean", "gamma_range", "gamma_slope",
    "am_gamma_s0", "am_gamma_s1", "am_gamma_s2", "am_gamma_s3",
    "am_gamma_sum", "am_gamma_unan_quiet", "am_gamma_unan_vol",
    "am_gamma_flips", "am_gamma_pattern",
    "disagree_s0", "disagree_s1", "disagree_s2", "disagree_s3",
    "disagree_count", "disagree_any", "disagree_current",
    "disagree_trend", "disagree_pattern",
    "ag_direction_s0", "ag_direction_s1", "ag_direction_s2", "ag_direction_s3",
    "ag_direction_net",
    "proj_quiet_val", "proj_conv", "am_proj_quiet",
    "proj_agrees_alpha", "proj_agrees_gamma",
    "proj_continues_trend", "proj_error", "am_5slot_pattern",
    "micro_gk_mean", "micro_gk_std", "micro_gk_last", "micro_gk_slope",
    "micro_range_mean", "micro_range_std", "micro_range_last", "micro_range_slope",
    "micro_vol_mean", "micro_vol_std", "micro_vol_last", "micro_vol_slope",
    "micro_cpos_mean", "micro_cpos_std", "micro_cpos_last", "micro_cpos_slope",
    "regime_age", "bars_since_flip", "current_streak",
    "kl_x_alpha_std", "kl_x_flips", "kl_x_disagree",
    "conv_x_streak", "conv_x_unanimous",
    "kl_high_unstable", "kl_high_stable", "kl_low_stable",
    "disagree_at_transition", "vol_calming_quiet",
}
"""
data_pipeline.py — Multi-Instrument Data Infrastructure
=========================================================
Pulls raw futures data for any instrument, builds 5-min RTH
dataframe with all base features, micro features, and multiple
emission normalizations for HMM testing.

No module-level state; pull_data() issues data requests through the QuantBook it receives. Config comes from config.py.
"""
from AlgorithmImports import *
import numpy as np
import pandas as pd
from config import PARAMS, MICRO_FEATURES


# ══════════════════════════════════════════════════════════════
# CORE FUNCTIONS
# ══════════════════════════════════════════════════════════════

def compute_gk(df):
    """
    Garman-Klass per-bar volatility: sqrt(0.5*ln(H/L)^2 - (2 ln 2 - 1)*ln(C/O)^2).

    Works on any df (or mapping) with "Open", "High", "Low" and "Close".
    A negative variance estimate is clipped to 0 before the square root.
    """
    log_hl = np.log(df["High"] / df["Low"])
    log_co = np.log(df["Close"] / df["Open"])
    variance = 0.5 * log_hl ** 2 - (2 * np.log(2) - 1) * log_co ** 2
    return np.sqrt(np.maximum(variance, 0.0))


def pull_data(qb, instrument_cfg, lookback_days=None):
    """
    Pull raw 1-min data for any instrument.

    Registers the instrument's future on ``qb`` (minute resolution, RAW
    normalization) and requests trade bars, then best-effort quote bars.

    Parameters
    ----------
    qb : QuantBook
    instrument_cfg : dict — from config.INSTRUMENTS[ticker]
    lookback_days : int or None — override config default

    Returns
    -------
    raw_df : pd.DataFrame — 1-min OHLCV bars, datetime indexed
    has_quotes : bool — whether quote data is available
    quote_df : pd.DataFrame or None — 1-min quote bars if available

    Raises
    ------
    ValueError — if the trade-history request returns no rows.
    """
    if lookback_days is None:
        lookback_days = PARAMS["lookback_days"]

    sym = qb.add_future(instrument_cfg["future"], Resolution.MINUTE)
    sym.set_filter(0, 90)
    # NOTE(review): normalization is set after add_future(); confirm the
    # history request below actually uses RAW for the continuous contract.
    sym.data_normalization_mode = DataNormalizationMode.RAW

    print(f"  Pulling trade data ({lookback_days} days)...")
    trade_hist = qb.history(sym.symbol, timedelta(days=lookback_days), Resolution.MINUTE)
    if trade_hist is None or len(trade_hist) == 0:
        # An empty history frame may lack the OHLCV columns entirely, which
        # would otherwise surface as an opaque KeyError on the selection below.
        raise ValueError(
            f"No trade history returned for {sym.symbol} over {lookback_days} days"
        )

    raw_df = trade_hist[["open", "high", "low", "close", "volume"]].copy()
    raw_df.index = raw_df.index.get_level_values("time")
    raw_df.columns = ["Open", "High", "Low", "Close", "Volume"]
    print(f"    Trade bars: {len(raw_df)}")

    # Try quote data (may not be available for all instruments/periods).
    # Best effort: any failure is reported and the pipeline continues without quotes.
    has_quotes = False
    quote_df = None
    try:
        print(f"  Pulling quote data...")
        quote_hist = qb.history(
            QuoteBar, sym.symbol, timedelta(days=lookback_days), Resolution.MINUTE
        )
        if len(quote_hist) > 0:
            quote_df = quote_hist[["bidclose", "askclose", "bidsize", "asksize"]].copy()
            quote_df.index = quote_df.index.get_level_values("time")
            quote_df.columns = ["BidClose", "AskClose", "BidSize", "AskSize"]
            has_quotes = True
            print(f"    Quote bars: {len(quote_df)}")
    except Exception as e:
        print(f"    Quote data not available: {e}")

    return raw_df, has_quotes, quote_df


def build_5min_bars(raw_df, rth_start=None, rth_end=None,
                    warmup_bars=None, quote_df=None):
    """
    Resample 1-min bars to 5-min RTH bars and attach the base feature set.

    Parameters
    ----------
    raw_df : pd.DataFrame — 1-min OHLCV from pull_data()
    rth_start, rth_end : str — RTH window passed to ``between_time``;
        PARAMS["rth_start"] / PARAMS["rth_end"] when None
    warmup_bars : int — number of 5-min bars dropped at the start of each
        session; PARAMS["warmup_bars"] when None
    quote_df : pd.DataFrame or None — 1-min quote bars, left-joined when given

    Returns
    -------
    df : pd.DataFrame — 5-min bars with SessionDate, BarIndex, LogReturn, GK,
        Active, LR_norm, GK_norm, RVol, RVol_norm and (with quotes) Spread,
        Spread_norm
    """
    rth_start = PARAMS["rth_start"] if rth_start is None else rth_start
    rth_end = PARAMS["rth_end"] if rth_end is None else rth_end
    warmup_bars = PARAMS["warmup_bars"] if warmup_bars is None else warmup_bars

    # 5-min OHLCV aggregation inside the RTH window.
    ohlcv_rules = {
        "Open": "first", "High": "max", "Low": "min",
        "Close": "last", "Volume": "sum",
    }
    df = (raw_df.between_time(rth_start, rth_end)
          .resample("5min").agg(ohlcv_rules).dropna())

    # Optional quote columns (last value within each 5-min bucket).
    if quote_df is not None:
        quote_rules = {
            "BidClose": "last", "AskClose": "last",
            "BidSize": "last", "AskSize": "last",
        }
        quotes_5m = (quote_df.between_time(rth_start, rth_end)
                     .resample("5min").agg(quote_rules).dropna())
        df = df.join(quotes_5m, how="left")

    # Drop the first `warmup_bars` bars of every session, then re-number.
    df["SessionDate"] = df.index.date
    position_in_session = df.groupby("SessionDate").cumcount()
    df = df[position_in_session >= warmup_bars].copy()
    df["BarIndex"] = df.groupby("SessionDate").cumcount()

    df["LogReturn"] = np.log(df["Close"] / df["Close"].shift(1))
    df["GK"] = compute_gk(df)
    df["Active"] = (df["High"] != df["Low"]) & (df["GK"] > PARAMS["gk_floor"])

    # The first surviving bar of a session would otherwise carry the return
    # measured against the previous session's last bar.
    session_openers = df.groupby("SessionDate").head(1).index
    df.loc[session_openers, "LogReturn"] = np.nan
    df.dropna(subset=["LogReturn", "GK"], inplace=True)

    span = PARAMS["hmm_ewma_span"]

    def ewm_zscore(series):
        # Deviation from the EWMA mean, scaled by the EWMA std of the same span.
        smoothed = series.ewm(span=span)
        return (series - smoothed.mean()) / smoothed.std()

    df["LR_norm"] = ewm_zscore(df["LogReturn"])
    df["GK_norm"] = ewm_zscore(df["GK"])

    # Realized vol: rolling std of 5-min log returns.
    df["RVol"] = df["LogReturn"].rolling(20, min_periods=10).std()
    df["RVol_norm"] = ewm_zscore(df["RVol"])

    if {"AskClose", "BidClose"}.issubset(df.columns):
        df["Spread"] = df["AskClose"] - df["BidClose"]
        df["Spread_norm"] = ewm_zscore(df["Spread"])

    df.dropna(subset=["LR_norm", "GK_norm"], inplace=True)
    return df


def build_micro_features(df):
    """
    Return a copy of ``df`` with the micro-structure columns used by the tensor.

    Adds micro_gk (the bar's GK), micro_range ((High - Low) / Close),
    micro_vol (Volume over its span-50 EWMA, with a zero EWMA replaced by 1),
    and micro_cpos (Close position within the bar, 0.5 for zero-range bars).
    Every name in MICRO_FEATURES is then guaranteed to exist: existing columns
    have NaN filled with 0.0, and missing ones are created as 0.0.

    Parameters
    ----------
    df : pd.DataFrame — from build_5min_bars()

    Returns
    -------
    df : pd.DataFrame — new frame with the micro feature columns added
    """
    out = df.copy()
    low, close = out["Low"], out["Close"]
    spread = out["High"] - low

    out["micro_gk"] = out["GK"]
    out["micro_range"] = spread / close

    baseline_volume = out["Volume"].ewm(span=50).mean().replace(0, 1)
    out["micro_vol"] = out["Volume"] / baseline_volume

    out["micro_cpos"] = np.where(spread > 0, (close - low) / spread, 0.5)

    for name in MICRO_FEATURES:
        out[name] = out[name].fillna(0.0) if name in out.columns else 0.0

    return out


def build_atr(df, period=14):
    """
    Compute a session-aware ATR for exit strategies.

    The true range of a bar is max(High - Low, |High - prevClose|,
    |Low - prevClose|), where prevClose is the previous close *within the same
    session*. The first bar of each session has no in-session previous close,
    so its true range is simply High - Low. It does not absorb the overnight
    gap from the prior session's close. The EWMA of true range is also run
    independently per session.

    Parameters
    ----------
    df : pd.DataFrame — must have High, Low, Close, SessionDate
    period : int — EWMA span. min_periods is capped at the session length, so a
        session shorter than ``period`` only gets an ATR value on its last bar.

    Returns
    -------
    df : pd.DataFrame — copy of the input with an ATR column added
    """
    df = df.copy()
    sessions = df["SessionDate"]

    # Previous close restricted to the same session (NaN on each session's first bar).
    prev_close = df.groupby(sessions)["Close"].shift(1)

    # np.fmax ignores NaN, so a missing prev close leaves the plain bar range.
    true_range = np.fmax(
        df["High"] - df["Low"],
        np.fmax((df["High"] - prev_close).abs(), (df["Low"] - prev_close).abs()),
    )

    df["ATR"] = true_range.groupby(sessions).transform(
        lambda s: s.ewm(span=period, min_periods=min(period, len(s))).mean()
    )
    return df


def compute_forward_returns(df, hold_bars=None):
    """
    Compute forward log returns for backtest evaluation.

    For each bar i, fwd_{hold_bars} is the sum of the LogReturn values of the
    next ``hold_bars`` bars in the same session (bars i+1 … i+hold_bars). It is
    NaN for the last ``hold_bars`` bars of each session, where the window would
    cross into the next session. It is also NaN for any window containing a
    NaN return. The computation is vectorised per session, using a rolling sum
    shifted back by ``hold_bars`` bars. Float results may differ from a naive
    sum in the last ulp.

    Parameters
    ----------
    df : pd.DataFrame — must have LogReturn, SessionDate
    hold_bars : int — holding period in bars (PARAMS["hold_bars"] when None);
        0 yields an all-zero column

    Returns
    -------
    df : pd.DataFrame — copy of the input with a fwd_{hold_bars} column

    Raises
    ------
    ValueError — if hold_bars is negative
    """
    if hold_bars is None:
        hold_bars = PARAMS["hold_bars"]
    if hold_bars < 0:
        raise ValueError(f"hold_bars must be non-negative, got {hold_bars}")

    df = df.copy()
    col = f"fwd_{hold_bars}"

    if hold_bars == 0:
        # Empty forward window: the sum over zero bars is 0.0 for every bar.
        df[col] = 0.0
    else:
        # Trailing rolling sum ending at bar i + hold_bars, shifted back to bar i,
        # computed independently within each session.
        df[col] = df.groupby("SessionDate")["LogReturn"].transform(
            lambda s: s.rolling(hold_bars).sum().shift(-hold_bars)
        )

    return df


# ══════════════════════════════════════════════════════════════
# FULL PIPELINE
# ══════════════════════════════════════════════════════════════

def run_pipeline(qb, instrument_cfg, lookback_days=None):
    """
    Run the full data pipeline for one instrument.

    The stages are: pull_data, build_5min_bars (using the instrument's RTH
    window, with PARAMS as fallback), build_micro_features, build_atr and
    compute_forward_returns. It then prints a short summary.

    Parameters
    ----------
    qb : QuantBook
    instrument_cfg : dict — from config.INSTRUMENTS[ticker]
    lookback_days : int or None

    Returns
    -------
    df : pd.DataFrame — ready for HMM and downstream; may be empty if no bars
        survive RTH/warmup filtering
    sessions : list — sorted session dates (empty when df is empty)
    has_quotes : bool
    """
    rth_s = instrument_cfg.get("rth_start", PARAMS["rth_start"])
    rth_e = instrument_cfg.get("rth_end", PARAMS["rth_end"])

    raw_df, has_quotes, quote_df = pull_data(qb, instrument_cfg, lookback_days)

    df = build_5min_bars(raw_df, rth_start=rth_s, rth_end=rth_e,
                         quote_df=quote_df)
    df = build_micro_features(df)
    df = build_atr(df)
    df = compute_forward_returns(df)

    sessions = sorted(df["SessionDate"].unique())

    # Guard: with no sessions, sessions[0] below would raise IndexError.
    if not sessions:
        print("  Pipeline complete: no bars survived filtering "
              "(check lookback_days, RTH window and data availability)")
        print(f"    Quotes: {'yes' if has_quotes else 'no'}")
        return df, sessions, has_quotes

    print(f"  Pipeline complete:")
    print(f"    {len(df)} bars, {len(sessions)} sessions")
    print(f"    Range: {sessions[0]} to {sessions[-1]}")
    print(f"    Active: {df['Active'].mean()*100:.1f}%")
    print(f"    Quotes: {'yes' if has_quotes else 'no'}")

    return df, sessions, has_quotes
# region imports
from AlgorithmImports import *
# endregion

# decision.py — gate application + entry generation
import numpy as np
import pandas as pd
from config import PARAMS, GATE_CONFIGS

"""
═══════════════════════════════════════════════════════════════
FILE 3: decision.py — Gate Application + Entry Generation
═══════════════════════════════════════════════════════════════
Applies gate configs to crossover bars, generates entry signals.
"""
import pandas as pd
from config import GATE_CONFIGS


def get_crossover_bars(df, instrument_cfg, sessions=None):
    """
    Select the crossover bars eligible for entry and attach their point return.

    A bar qualifies when all of the following hold:
      - its session is in ``sessions``;
      - it is Active;
      - its forward-return column is not NaN;
      - its hour is before PARAMS["no_trade_after_hour"];
      - the configured crossover flag equals True.

    Parameters
    ----------
    df : pd.DataFrame — with MA_bull_x, MA_bear_x, fwd_{hold_bars}, Active,
        Close, SessionDate and a datetime index
    instrument_cfg : dict — from config.INSTRUMENTS; reads "cross" ("bear_x"
        selects MA_bear_x, anything else MA_bull_x) and "direction"
    sessions : list — session dates to include; by default all sessions after
        the first PARAMS["warmup_sessions"]

    Returns
    -------
    xo : pd.DataFrame — crossover bars with ret_pts = fwd * Close * direction
    """
    if sessions is None:
        sessions = sorted(df["SessionDate"].unique())[PARAMS["warmup_sessions"]:]

    cross_col = "MA_bear_x" if instrument_cfg["cross"] == "bear_x" else "MA_bull_x"
    direction = instrument_cfg["direction"]
    fwd_col = f"fwd_{PARAMS['hold_bars']}"

    in_sessions = df["SessionDate"].isin(sessions)
    tradable = df["Active"] & df[fwd_col].notna()
    before_cutoff = df.index.hour < PARAMS["no_trade_after_hour"]
    eligible = df[in_sessions & tradable & before_cutoff].copy()

    # .eq(True) (rather than plain boolean indexing) tolerates object/NaN flags.
    crossings = eligible[eligible[cross_col].eq(True)].copy()
    crossings["ret_pts"] = crossings[fwd_col] * crossings["Close"] * direction

    return crossings


def apply_gate(xo, gate_name):
    """
    Apply a named gate config to crossover bars.

    Parameters
    ----------
    xo : pd.DataFrame — crossover bars
    gate_name : str — key in config.GATE_CONFIGS (whose values are callables
        taking ``xo`` and returning the mask)

    Returns
    -------
    mask : pd.Series of bool

    Raises
    ------
    KeyError — if ``gate_name`` is not a key of GATE_CONFIGS. The message lists
        the available gate names, assuming GATE_CONFIGS is a dict.
    """
    try:
        gate_fn = GATE_CONFIGS[gate_name]
    except KeyError:
        available = ", ".join(sorted(str(k) for k in GATE_CONFIGS))
        raise KeyError(
            f"Unknown gate {gate_name!r}; available gates: {available}"
        ) from None
    return gate_fn(xo)


def generate_entries(df, instrument_cfg, gate_name=None):
    """
    Build gated entry signals: crossover bars (with EOD filter) filtered by a gate.

    Prints the number of entries that pass the gate out of all crossovers.

    Parameters
    ----------
    df : pd.DataFrame — fully featured
    instrument_cfg : dict
    gate_name : str or None — gate to apply; instrument_cfg["gate"] when None

    Returns
    -------
    entries : pd.DataFrame — crossover bars that pass the gate
    xo_all : pd.DataFrame — all crossover bars (for comparison)
    """
    gate = instrument_cfg["gate"] if gate_name is None else gate_name

    crossovers = get_crossover_bars(df, instrument_cfg)
    entries = crossovers[apply_gate(crossovers, gate)].copy()

    n_entries, n_crossovers = len(entries), len(crossovers)
    pass_pct = n_entries / max(n_crossovers, 1) * 100
    print(f"  Entries: {n_entries}/{n_crossovers} crossovers "
          f"({gate}, {pass_pct:.0f}% pass rate)")

    return entries, crossovers
# region imports
from AlgorithmImports import *
# endregion
"""
diagnostics.py — Trade Quality + Feature Analysis + Bootstrap
================================================================
Analyzes results from backtest.run_single().

Functions:
  trade_report()      — MAE/MFE, capture ratio, winner/loser profiles
  trade_paths()       — bar-by-bar mean trade path, winner vs loser
  entry_context()     — pre-entry momentum, vol, session position
  feature_scan()      — correlate all tensor features with returns
  pattern_analysis()  — argmax pattern breakdown at crossover bars
  gate_combinatorics() — exhaustive single/2-way/3-way gate testing
  bootstrap()         — confidence intervals on net mean, hit, PF
  cross_asset_test()  — does edge correlate with retail flow?
  trade_detail()      — best/worst individual trades with entry context
  full_diagnostic()   — runs everything, prints comprehensive report

Usage:
    from diagnostics import full_diagnostic
    results = run_single(qb, "NQ")
    full_diagnostic(results)
"""
import numpy as np
import pandas as pd
from scipy import stats
from config import PARAMS, TENSOR_FEATURE_NAMES, GATE_CONFIGS


# ══════════════════════════════════════════════════════════════
# 1. TRADE QUALITY REPORT
# ══════════════════════════════════════════════════════════════

def trade_report(results):
    """
    Print a trade-quality report for one backtest result.

    The report covers: overview, exit reasons, MAE/MFE by winner/loser, MAE
    distribution, MAE separation, capture ratio, MFE peak timing and winner
    quality.

    Parameters
    ----------
    results : dict from backtest.run_single()
        - results["trades"]: DataFrame with winner, final_pnl, bars_held,
          exit_reason, mfe, mae, entry_eff and path columns.
        - results["config"]: ticker, gate, exit, and an "instrument" dict with
          cost_pts, point_value, cross and direction.

    Notes
    -----
    Net P&L is final_pnl - cost_pts (points); dollar figures multiply by
    point_value. Profit factor = gross net wins / |gross net losses|. It is inf
    when there are net gains but no net losses, and 0 when there are no net
    gains at all.
    """
    trades = results["trades"]
    cfg = results["config"]
    cost = cfg["instrument"]["cost_pts"]
    pv = cfg["instrument"]["point_value"]

    if len(trades) == 0:
        print("  No trades.")
        return

    t = trades
    w = t[t["winner"]]
    l = t[~t["winner"]]

    ticker = cfg["ticker"]
    print(f"\n{'='*65}")
    print(f"  TRADE REPORT: {ticker} | {cfg['gate']} | {cfg['exit']}")
    print(f"  {cfg['instrument']['cross']} {'long' if cfg['instrument']['direction']==1 else 'short'}")
    print(f"{'='*65}")

    # ── Summary ──
    net = t["final_pnl"].values - cost
    gross_win = net[net > 0].sum()
    gross_loss = abs(net[net <= 0].sum())
    # A record with gains and no losses has an unbounded profit factor, not 0.
    if gross_loss > 0:
        pf = gross_win / gross_loss
    else:
        pf = np.inf if gross_win > 0 else 0.0

    print(f"\n  Overview:")
    print(f"    Trades: {len(t)}")
    print(f"    Net/trade: {net.mean():+.3f} pts (${net.mean()*pv:+.1f})")
    print(f"    Hit rate: {(net>0).mean():.3f}")
    print(f"    PF: {pf:.2f}")
    print(f"    Total: {net.sum():+.1f} pts (${net.sum()*pv:+,.0f})")
    print(f"    Avg bars held: {t['bars_held'].mean():.1f}")

    # ── Exit reasons ──
    print(f"\n  Exit reasons:")
    for reason, count in t["exit_reason"].value_counts().items():
        sub = t[t["exit_reason"] == reason]
        sub_net = sub["final_pnl"].values - cost
        print(f"    {reason:<20s} n={count:>4} net={sub_net.mean():+.3f}")

    # ── MAE / MFE ──
    print(f"\n  MAE / MFE Analysis:")
    print(f"    {'Metric':<20s} {'All':>10} {'Winners':>10} {'Losers':>10}")
    for m in ["final_pnl", "mfe", "mae", "entry_eff"]:
        a = t[m].mean()
        wv = w[m].mean() if len(w) > 0 else np.nan
        lv = l[m].mean() if len(l) > 0 else np.nan
        print(f"    {m:<20s} {a:>+10.3f} {wv:>+10.3f} {lv:>+10.3f}")

    # ── MAE distribution ──
    print(f"\n  MAE distribution (how far trades go against you):")
    for pct in [10, 25, 50, 75, 90]:
        print(f"    P{pct}: {np.percentile(t['mae'], pct):+.3f} pts")

    # ── Separation quality ──
    if len(w) >= 5 and len(l) >= 5:
        mae_sep = abs(w["mae"].mean() - l["mae"].mean())
        quality = "GOOD — entries have timing edge" if mae_sep > 1.0 else \
                  "MODERATE" if mae_sep > 0.5 else \
                  "WEAK — similar drawdown profile, entries are noisy"
        print(f"\n  MAE separation: {mae_sep:.3f} pts ({quality})")

    # ── Capture ratio ──
    capture = t["final_pnl"].mean() / t["mfe"].mean() if t["mfe"].mean() > 0 else 0
    cap_quality = "Good" if capture > 0.3 else "Moderate" if capture > 0.15 else \
                  "POOR — leaving most of the move on the table"
    print(f"  Capture ratio: {capture:.3f} ({cap_quality})")
    print(f"    Mean MFE:       {t['mfe'].mean():.3f} pts (${t['mfe'].mean()*pv:.1f})")
    print(f"    Mean final P&L: {t['final_pnl'].mean():.3f} pts")
    print(f"    Gap (wasted):   {t['mfe'].mean() - t['final_pnl'].mean():.3f} pts "
          f"(${(t['mfe'].mean() - t['final_pnl'].mean())*pv:.1f})")

    # ── MFE timing: where does the peak occur? ──
    print(f"\n  MFE timing (when does peak profit occur?):")
    # 1-based bar index of each non-empty path's maximum.
    max_bar_per_trade = [np.argmax(path) + 1 for path in t["path"].values if len(path) > 0]
    if len(max_bar_per_trade) > 0:
        mb = np.array(max_bar_per_trade)
        print(f"    Mean peak bar: {mb.mean():.1f}")
        for b in range(1, min(PARAMS["hold_bars"] + 1, 11)):
            pct = (mb == b).mean() * 100
            if pct > 3:
                print(f"    Bar {b}: {pct:.0f}% of trades peak here")

    # ── Clean vs lucky winners ──
    if len(w) >= 5:
        # Split winners at the median winner MAE. MAE is assumed to be signed
        # (more negative = deeper adverse excursion — TODO confirm with the
        # exit engine), so winners with MAE >= median saw the least adverse
        # movement ("clean"), and the rest won only after a deep drawdown ("lucky").
        mae_thresh = w["mae"].quantile(0.5)
        structural = w[w["mae"] >= mae_thresh]
        lucky = w[w["mae"] < mae_thresh]

        print(f"\n  Winner quality:")
        print(f"    'Clean' winners (MAE above median): n={len(structural)}, "
              f"final={structural['final_pnl'].mean():.3f}, MAE={structural['mae'].mean():+.3f}")
        print(f"    'Lucky' winners (deep MAE):         n={len(lucky)}, "
              f"final={lucky['final_pnl'].mean():.3f}, MAE={lucky['mae'].mean():+.3f}")


# ══════════════════════════════════════════════════════════════
# 2. TRADE PATHS
# ══════════════════════════════════════════════════════════════

def trade_paths(results):
    """
    Print the bar-by-bar mean trade path for all trades, winners and losers.

    Paths are left-aligned into a NaN-padded matrix. For each of the first 12
    bars with at least 5 observations, it prints the mean P&L for all trades,
    winners and losers, the winner-loser gap and the % of trades above zero.
    When paths reach at least 3 bars, Welch's t-test compares winners' and
    losers' bar-1 values.

    Parameters
    ----------
    results : dict from run_single() — reads results["trades"] with "path"
        (per-bar P&L sequence — assumed, confirm with the exit engine) and
        "winner" columns
    """
    trades = results["trades"]
    if len(trades) == 0:
        return

    path_list = trades["path"].values
    width = max(len(p) for p in path_list)
    grid = np.full((len(path_list), width), np.nan)
    for row, p in enumerate(path_list):
        grid[row, :len(p)] = p

    is_win = trades["winner"].values
    is_loss = ~is_win

    print(f"\n  Mean Trade Path:")
    print(f"  {'Bar':<5} {'All':>8} {'Winners':>8} {'Losers':>8} {'Gap':>8} {'%>0':>6}")

    for b in range(min(width, 12)):
        column = grid[:, b]
        present = ~np.isnan(column)
        v_all = column[present]
        v_win = column[is_win & present]
        v_lose = column[is_loss & present]
        if len(v_all) < 5:
            continue

        has_win, has_lose = len(v_win) > 0, len(v_lose) > 0
        win_mean = v_win.mean() if has_win else 0
        lose_mean = v_lose.mean() if has_lose else 0
        gap = win_mean - lose_mean if (has_win and has_lose) else 0
        print(f"  B{b+1:<4} {v_all.mean():>+8.3f} "
              f"{win_mean:>+8.3f} "
              f"{lose_mean:>+8.3f} "
              f"{gap:>+8.3f} {(v_all>0).mean()*100:>5.1f}%")

    # Do winners and losers already differ at bar 1?
    if width < 3:
        return
    first_bar = grid[:, 0]
    win_b1 = first_bar[is_win]
    win_b1 = win_b1[~np.isnan(win_b1)]
    lose_b1 = first_bar[is_loss]
    lose_b1 = lose_b1[~np.isnan(lose_b1)]

    if len(win_b1) > 5 and len(lose_b1) > 5:
        t_stat, p_val = stats.ttest_ind(win_b1, lose_b1, equal_var=False)
        print(f"\n  Bar 1 divergence test: t={t_stat:.2f}, p={p_val:.4f}")
        if p_val < 0.05:
            print("    Winners and losers are distinguishable by bar 1 → signal has real timing")
        else:
            print("    Not distinguishable at bar 1 → signal doesn't time the immediate move")


# ══════════════════════════════════════════════════════════════
# 3. ENTRY CONTEXT
# ══════════════════════════════════════════════════════════════

def entry_context(results):
    """
    Compare pre-entry context between winning and losing trades.

    Prints the mean of pre_momentum, entry_atr and conviction (whichever
    columns exist, NaN ignored) for winners and losers, plus their difference.
    It prints nothing when there are fewer than 10 trades, or fewer than 5
    winners or losers.

    Parameters
    ----------
    results : dict from run_single() — reads results["trades"]
    """
    trades = results["trades"]
    if len(trades) < 10:
        return

    winners = trades[trades["winner"]]
    losers = trades[~trades["winner"]]
    if min(len(winners), len(losers)) < 5:
        return

    print(f"\n  Entry Context (winner vs loser):")
    print(f"    {'Metric':<25s} {'Winners':>10} {'Losers':>10} {'Diff':>10}")

    present = [c for c in ("pre_momentum", "entry_atr", "conviction") if c in trades.columns]
    for metric in present:
        w_avg = winners[metric].dropna().mean()
        l_avg = losers[metric].dropna().mean()
        print(f"    {metric:<25s} {w_avg:>+10.5f} {l_avg:>+10.5f} {w_avg-l_avg:>+10.5f}")


# ══════════════════════════════════════════════════════════════
# 4. FEATURE SCAN
# ══════════════════════════════════════════════════════════════

def feature_scan(results, top_n=30):
    """
    Correlate all tensor features with returns at crossover bars.

    Every ``T_*`` column, plus stress_A / stress_B when present, is
    Pearson-correlated with ret_pts over rows where both values are numeric and
    non-NaN. Features with fewer than 20 such rows are skipped. Non-float
    columns (bool, int, numeric strings) are coerced to float first.

    Parameters
    ----------
    results : dict from run_single() — reads results["xo"] and
        results["config"] (ticker, instrument.direction); ret_pts is rebuilt
        from fwd_{PARAMS["hold_bars"]} * Close * direction if absent
    top_n : int — how many features to print (only those with p < 0.2 shown)

    Returns
    -------
    feat_df : pd.DataFrame — columns feature, r, p, sig, nunique, sorted by p.
        It is empty (with those columns) when no feature qualifies.
    """
    xo = results["xo"]
    cfg = results["config"]
    direction = cfg["instrument"]["direction"]

    if "ret_pts" not in xo.columns:
        hold = PARAMS["hold_bars"]
        fwd_col = f"fwd_{hold}"
        xo = xo.copy()
        xo["ret_pts"] = xo[fwd_col] * xo["Close"] * direction

    ret = pd.to_numeric(xo["ret_pts"], errors="coerce").to_numpy(dtype=float)
    t_cols = sorted([c for c in xo.columns if c.startswith("T_")] +
                    ["stress_A", "stress_B"])

    rows = []
    for col in t_cols:
        if col not in xo.columns:
            continue
        # Coerce so bool/int/object-numeric columns work with np.isnan.
        vals = pd.to_numeric(xo[col], errors="coerce").to_numpy(dtype=float)
        ok = ~np.isnan(vals) & ~np.isnan(ret)
        if ok.sum() < 20:
            continue
        r, p = stats.pearsonr(vals[ok], ret[ok])
        sig = "***" if p < .001 else "**" if p < .01 else "*" if p < .05 else ""
        rows.append({"feature": col, "r": r, "p": p, "sig": sig,
                     "nunique": len(np.unique(vals[ok]))})

    # Explicit columns keep sort_values("p") valid when no feature qualified.
    feat_df = pd.DataFrame(
        rows, columns=["feature", "r", "p", "sig", "nunique"]
    ).sort_values("p")

    print(f"\n  Feature Scan ({cfg['ticker']}, n={len(xo)} crossovers):")
    print(f"  {'Feature':<40s} {'r':>7} {'p':>10} {'sig':>4}")
    for _, row in feat_df.head(top_n).iterrows():
        if row["p"] < 0.2:
            print(f"  {row['feature']:<40s} {row['r']:>+7.4f} {row['p']:>10.4f} {row['sig']:>4}")

    n_sig = int((feat_df["p"] < 0.05).sum())
    print(f"\n  {n_sig} features significant at p<0.05 out of {len(feat_df)}")

    return feat_df


# ══════════════════════════════════════════════════════════════
# 5. ARGMAX PATTERN ANALYSIS
# ══════════════════════════════════════════════════════════════

def pattern_analysis(results):
    """
    Print net return and hit rate per argmax alpha pattern at crossover bars.

    Each distinct T_am_alpha_pattern value with at least 5 crossovers is shown
    as its 4-bit binary form and a Q/V string (bit '1' -> "Q", anything else
    -> "V"), with mean net (ret_pts - cost_pts) and hit rate.

    Parameters
    ----------
    results : dict from run_single() — reads results["xo"] (needs ret_pts)
        and results["config"]["instrument"]["cost_pts"]
    """
    xo = results["xo"]
    cost = results["config"]["instrument"]["cost_pts"]
    pattern_col = "T_am_alpha_pattern"

    if pattern_col not in xo.columns:
        print("  No argmax pattern column.")
        return

    print(f"\n  Argmax Pattern Breakdown:")
    print(f"  {'Pattern':<12s} {'Bits':<6s} {'N':>5} {'Net':>8} {'Hit':>6}")

    patterns = xo[pattern_col]
    for code in sorted(patterns.unique()):
        group = xo[patterns == code]
        if len(group) < 5:
            continue
        net = group["ret_pts"].values - cost
        bits = format(int(code), "04b")
        regime = "".join("Q" if ch == "1" else "V" for ch in bits)
        print(f"  {regime:<12s} {bits:<6s} {len(group):>5} {net.mean():>+8.3f} "
              f"{(net>0).mean():>6.3f}")


# ══════════════════════════════════════════════════════════════
# 6. GATE COMBINATORICS
# ══════════════════════════════════════════════════════════════

def gate_combinatorics(results, max_singles=20, max_combos=15):
    """
    Exhaustive gate testing on crossover bars.

    The steps are:
      1. Test every single gate from _build_single_gates() that keeps at least
         8 crossovers.
      2. Print the singles that beat the baseline by > 0.3 pts or have PF > 1.1.
      3. Test all 2-way AND combinations of the best singles. These are the
         singles with net above the baseline and n >= 15, ranked by net and
         capped at ``max_singles``.

    Parameters
    ----------
    results : dict from run_single() — reads results["xo"] (needs ret_pts)
        and results["config"]["instrument"]["cost_pts"]
    max_singles : int — max single gates fed into the 2-way search
    max_combos : int — max 2-way combos printed (only those above baseline)

    Notes
    -----
    Profit factor is inf for a subset with net gains and no net losses, so a
    perfect gate passes the PF filter instead of being reported as 0.
    """
    xo = results["xo"]
    cost = results["config"]["instrument"]["cost_pts"]

    if len(xo) < 20:
        print("  Too few crossovers.")
        return

    def _subset_stats(sub):
        # Net-of-cost mean, hit rate and profit factor for a subset of crossovers.
        net = sub["ret_pts"].values - cost
        gross_win = net[net > 0].sum()
        gross_loss = abs(net[net <= 0].sum())
        if gross_loss > 0:
            pf = gross_win / gross_loss
        else:
            pf = np.inf if gross_win > 0 else 0.0
        return net.mean(), (net > 0).mean(), pf

    # Build all single gates
    singles = _build_single_gates(xo)
    base_net = (xo["ret_pts"].values - cost).mean()

    print(f"\n  Gate Combinatorics ({len(xo)} crossovers, baseline net={base_net:+.3f}):")
    print(f"\n  {'Gate':<35s} {'N':>5} {'Net':>8} {'Hit':>6} {'PF':>6} {'vs base':>8}")

    single_results = {}
    for gname, gmask in singles.items():
        sub = xo[gmask]
        if len(sub) < 8:
            continue
        net_mean, hit, pf = _subset_stats(sub)
        delta = net_mean - base_net
        single_results[gname] = {"n": len(sub), "net": net_mean, "hit": hit,
                                 "pf": pf, "delta": delta, "mask": gmask}
        if delta > 0.3 or pf > 1.1:
            print(f"  {gname:<35s} {len(sub):>5} {net_mean:>+8.3f} "
                  f"{hit:>6.3f} {pf:>6.2f} {delta:>+8.3f}")

    # Top singles for combos
    ranked = sorted(single_results.items(), key=lambda x: -x[1]["net"])
    top = [k for k, v in ranked if v["net"] > base_net and v["n"] >= 15][:max_singles]

    if len(top) < 2:
        return

    print(f"\n  2-WAY COMBOS (top {max_combos}):")
    print(f"  {'Gate':<55s} {'N':>5} {'Net':>8} {'Hit':>6} {'PF':>6}")

    combos = []
    for i in range(len(top)):
        for j in range(i + 1, len(top)):
            mask = single_results[top[i]]["mask"] & single_results[top[j]]["mask"]
            sub = xo[mask]
            if len(sub) < 8:
                continue
            net_mean, hit, pf = _subset_stats(sub)
            combos.append({"gate": f"{top[i]} + {top[j]}", "n": len(sub),
                           "net": net_mean, "hit": hit, "pf": pf})

    combos.sort(key=lambda x: -x["net"])
    for c in combos[:max_combos]:
        if c["net"] > base_net:
            print(f"  {c['gate']:<55s} {c['n']:>5} {c['net']:>+8.3f} "
                  f"{c['hit']:>6.3f} {c['pf']:>6.2f}")


def _build_single_gates(xo):
    """Build dict of all single gate masks for combinatorics."""
    gates = {}

    # KL
    if "T_kl_slope_3s" in xo: gates["kl_slope_3s>0"] = xo["T_kl_slope_3s"] > 0
    if "T_kl_accel" in xo: gates["kl_accel>0"] = xo["T_kl_accel"] > 0
    if "T_kl_mean_3s" in xo:
        gates["kl_mean>med"] = xo["T_kl_mean_3s"] > xo["T_kl_mean_3s"].median()

    # Alpha continuous
    if "T_alpha_std" in xo:
        gates["alpha_std<0.05"] = xo["T_alpha_std"] < 0.05
        gates["alpha_std<0.10"] = xo["T_alpha_std"] < 0.10
    if "T_alpha_conv" in xo:
        gates["alpha_conv>0.3"] = xo["T_alpha_conv"] > 0.3
    if "T_alpha_slope" in xo:
        gates["alpha_slope>0"] = xo["T_alpha_slope"] > 0

    # Alpha argmax
    if "T_am_alpha_unan_quiet" in xo: gates["am_unan_quiet"] = xo["T_am_alpha_unan_quiet"] == 1
    if "T_am_alpha_majority_vol" in xo: gates["am_maj_vol"] = xo["T_am_alpha_majority_vol"] == 1
    if "T_am_alpha_flips" in xo:
        gates["am_flips==0"] = xo["T_am_alpha_flips"] == 0
        gates["am_flips>=2"] = xo["T_am_alpha_flips"] >= 2
    if "T_am_alpha_transition_now" in xo: gates["am_trans_now"] = xo["T_am_alpha_transition_now"] == 1
    if "T_am_alpha_current_quiet" in xo:
        gates["am_current_quiet"] = xo["T_am_alpha_current_quiet"] == 1
        gates["am_current_vol"] = xo["T_am_alpha_current_quiet"] == 0

    # Argmax patterns
    if "T_am_pattern_entering_vol" in xo: gates["am_pat_enter_v"] = xo["T_am_pattern_entering_vol"] == 1
    if "T_am_pattern_entering_quiet" in xo: gates["am_pat_enter_q"] = xo["T_am_pattern_entering_quiet"] == 1
    if "T_am_pattern_stable_quiet" in xo: gates["am_pat_stable_q"] = xo["T_am_pattern_stable_quiet"] == 1

    # Gamma
    if "T_gamma_std" in xo: gates["gamma_std<0.05"] = xo["T_gamma_std"] < 0.05
    if "T_am_gamma_unan_quiet" in xo: gates["gam_unan_quiet"] = xo["T_am_gamma_unan_quiet"] == 1

    # Disagreement
    if "T_disagree_any" in xo:
        gates["no_disagree"] = xo["T_disagree_any"] == 0
    if "T_ag_direction_net" in xo:
        gates["ag_dir_net>0"] = xo["T_ag_direction_net"] > 0
        gates["ag_dir_net<0"] = xo["T_ag_direction_net"] < 0

    # Projection
    if "T_proj_agrees_alpha" in xo: gates["proj_agrees"] = xo["T_proj_agrees_alpha"] == 1

    # Regime
    if "T_regime_age" in xo:
        gates["age>=5"] = xo["T_regime_age"] >= 5
        gates["age>=10"] = xo["T_regime_age"] >= 10
    if "T_current_streak" in xo: gates["streak>=3"] = xo["T_current_streak"] >= 3

    # Stress
    if "stress_A" in xo: gates["stress_A<0.02"] = xo["stress_A"] < 0.02
    if "stress_B" in xo: gates["stress_B<0.02"] = xo["stress_B"] < 0.02

    # Micro
    if "T_micro_gk_slope" in xo: gates["gk_slope<0"] = xo["T_micro_gk_slope"] < 0
    if "T_vol_calming_quiet" in xo: gates["vol_calm_quiet"] = xo["T_vol_calming_quiet"] == 1
    if "T_kl_high_stable" in xo: gates["kl_high_stable"] = xo["T_kl_high_stable"] == 1

    return gates


# ══════════════════════════════════════════════════════════════
# 7. BOOTSTRAP
# ══════════════════════════════════════════════════════════════

def bootstrap(results, n_boot=None, seed=None):
    """
    Bootstrap confidence intervals on net mean, hit rate and profit factor.

    Parameters
    ----------
    results : dict from backtest.run_single() — reads results["trades"]["final_pnl"]
        and results["config"]["instrument"]["cost_pts"]
    n_boot : int or None — number of resamples (PARAMS["bootstrap_n"] when None)
    seed : int or None — when given, resampling uses np.random.default_rng(seed)
        for reproducible output. When None, the global np.random state is used,
        so a caller's np.random.seed(...) still applies.

    Notes
    -----
    A resample with net gains but no net losses has an infinite profit factor
    and counts toward P(PF > 1). When any resampled PF is infinite, the PF
    interval uses nearest-rank percentiles, because linear interpolation
    between infinite values yields NaN. Otherwise it uses np.percentile.
    """
    if n_boot is None:
        n_boot = PARAMS["bootstrap_n"]

    trades = results["trades"]
    cost = results["config"]["instrument"]["cost_pts"]
    if len(trades) < 15:
        print("  Too few trades for bootstrap.")
        return

    net = trades["final_pnl"].values - cost
    rng = np.random if seed is None else np.random.default_rng(seed)

    b_means = np.zeros(n_boot)
    b_hits = np.zeros(n_boot)
    b_pfs = np.zeros(n_boot)

    for i in range(n_boot):
        s = rng.choice(net, size=len(net), replace=True)
        b_means[i] = s.mean()
        b_hits[i] = (s > 0).mean()
        gross_win = s[s > 0].sum()
        gross_loss = abs(s[s <= 0].sum())
        if gross_loss > 0:
            b_pfs[i] = gross_win / gross_loss
        else:
            b_pfs[i] = np.inf if gross_win > 0 else 0.0

    if np.isfinite(b_pfs).all():
        pf_lo, pf_hi = np.percentile(b_pfs, [5, 95])
    else:
        ordered = np.sort(b_pfs)
        pf_lo = ordered[int(round(0.05 * (n_boot - 1)))]
        pf_hi = ordered[int(round(0.95 * (n_boot - 1)))]

    print(f"\n  Bootstrap ({len(net)} trades, {n_boot} resamples):")
    print(f"    Net mean 90% CI:  [{np.percentile(b_means, 5):+.3f}, {np.percentile(b_means, 95):+.3f}]")
    print(f"    Hit rate 90% CI:  [{np.percentile(b_hits, 5):.3f}, {np.percentile(b_hits, 95):.3f}]")
    print(f"    PF 90% CI:        [{pf_lo:.2f}, {pf_hi:.2f}]")
    print(f"    P(net > 0):       {(b_means > 0).mean() * 100:.1f}%")
    print(f"    P(PF > 1):        {(b_pfs > 1).mean() * 100:.1f}%")


# ══════════════════════════════════════════════════════════════
# 8. CROSS-ASSET HYPOTHESIS TEST
# ══════════════════════════════════════════════════════════════

def cross_asset_test(all_results):
    """
    Test: does edge correlate with retail flow presence?
    Expected ranking: NQ > ES > RTY >> YM > CL/GC.

    Prints a per-ticker table (N, net per trade, hit rate, profit factor,
    total net points) sorted by net per trade. It then compares the average
    net of equity-index futures (NQ/ES/RTY/YM) with commodities (CL/GC).
    Profit factor is inf for a ticker with net gains and no net losses.

    Parameters
    ----------
    all_results : dict of {ticker: results_dict} — each results_dict as from
        run_single(); reads ["trades"]["final_pnl"] and
        ["config"]["instrument"]["cost_pts"]
    """
    print(f"\n{'='*65}")
    print(f"  CROSS-ASSET HYPOTHESIS TEST")
    print(f"  Thesis: edge exists where retail/systematic flow dominates")
    print(f"  Expected: NQ > ES > RTY >> YM > CL/GC")
    print(f"{'='*65}")

    rows = []
    for ticker, res in all_results.items():
        trades = res["trades"]
        cost = res["config"]["instrument"]["cost_pts"]
        if len(trades) == 0:
            rows.append({"ticker": ticker, "n": 0, "net": np.nan})
            continue
        net = trades["final_pnl"].values - cost
        gross_win = net[net > 0].sum()
        gross_loss = abs(net[net <= 0].sum())
        if gross_loss > 0:
            pf = gross_win / gross_loss
        else:
            pf = np.inf if gross_win > 0 else 0.0
        rows.append({
            "ticker": ticker, "n": len(trades), "net": net.mean(),
            "hit": (net > 0).mean(), "pf": pf, "total": net.sum(),
        })

    # An empty input would otherwise fail in sort_values("net").
    if not rows:
        print("\n  No results to compare.")
        return

    df = pd.DataFrame(rows).sort_values("net", ascending=False)
    print(f"\n  {'Ticker':<8s} {'N':>5} {'Net/tr':>8} {'Hit':>6} {'PF':>6} {'Total':>10}")
    for _, r in df.iterrows():
        if np.isnan(r.get("net", np.nan)):
            print(f"  {r['ticker']:<8s} {'no trades':>5}")
        else:
            print(f"  {r['ticker']:<8s} {r['n']:>5} {r['net']:>+8.3f} "
                  f"{r['hit']:>6.3f} {r['pf']:>6.2f} {r['total']:>+10.1f}")

    # Check if equity index futures outperform commodity futures
    equity = df[df["ticker"].isin(["NQ", "ES", "RTY", "YM"])]["net"].dropna()
    commodity = df[df["ticker"].isin(["CL", "GC"])]["net"].dropna()

    if len(equity) > 0 and len(commodity) > 0:
        print(f"\n  Equity index avg net: {equity.mean():+.3f}")
        print(f"  Commodity avg net:    {commodity.mean():+.3f}")
        if equity.mean() > commodity.mean():
            print(f"  ✓ Retail flow thesis SUPPORTED")
        else:
            print(f"  ✗ Retail flow thesis NOT SUPPORTED")


# ══════════════════════════════════════════════════════════════
# 9. FULL DIAGNOSTIC (runs everything)
# ══════════════════════════════════════════════════════════════

def trade_detail(results, n_best=5, n_worst=5):
    """
    Print detailed diagnostics for individual trades.

    Sections: the ``n_best`` best and ``n_worst`` worst trades by ``net_pnl``
    (the worst table starts with the single worst trade), a winner-vs-loser
    comparison of entry features (effect size + Welch t-test), hold dynamics
    (regime flips / stress during the hold) and net P&L grouped by the bar
    on which MFE occurred.

    Parameters
    ----------
    results : dict
        ``results["trades"]`` is a DataFrame of trade records with
        ``net_pnl``, ``final_pnl``, ``mfe``, ``mae``, ``bars_held``,
        ``exit_reason``, ``entry_time`` and ``winner`` columns, and optionally
        ``entry_features`` (dict per trade), ``regime_flipped_during``,
        ``max_stress_during`` and ``mfe_bar``.
        ``results["config"]["instrument"]["cost_pts"]`` is the per-trade cost
        subtracted from ``final_pnl``.
    n_best, n_worst : int
        Number of best / worst trades to list.

    Returns
    -------
    None. Output is printed; nothing is printed for fewer than 10 trades.
    """
    trades = results["trades"]
    cost = results["config"]["instrument"]["cost_pts"]

    if len(trades) < 10:
        return

    def _features(trade):
        # A trade without a usable entry_features dict (missing / NaN) is
        # treated as having no features instead of raising AttributeError.
        ef = trade.get("entry_features", {})
        return ef if isinstance(ef, dict) else {}

    def _as_float(value):
        # None or non-numeric feature values count as missing (NaN).
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    def _print_trades(title, rows):
        # One table: header line, then one line per trade in the given order.
        print(f"\n  {title}:")
        print(f"  {'#':<3} {'Time':<20s} {'Net':>8} {'MFE':>8} {'MAE':>8} "
              f"{'Bars':>5} {'Exit':<15s} {'Pattern':<8s} {'Stress':>7}")
        for rank, (_, t) in enumerate(rows.iterrows(), start=1):
            ef = _features(t)
            pat = _as_float(ef.get("T_am_alpha_pattern", np.nan))
            pat_str = f"{int(pat):04b}" if np.isfinite(pat) else "????"
            stress = _as_float(ef.get("stress_B", np.nan))
            print(f"  {rank:<3} {str(t['entry_time'])[:19]:<20s} {t['net_pnl']:>+8.1f} "
                  f"{t['mfe']:>+8.1f} {t['mae']:>+8.1f} {t['bars_held']:>5} "
                  f"{t['exit_reason']:<15s} {pat_str:<8s} {stress:>+7.4f}")

    # ── Best / worst trades ──
    # Both tables are sorted from the extreme inward; NaN net_pnl sorts last
    # in either direction, so it never shows up as a "worst" trade.
    _print_trades(f"TOP {n_best} TRADES",
                  trades.sort_values("net_pnl", ascending=False).head(n_best))
    _print_trades(f"WORST {n_worst} TRADES",
                  trades.sort_values("net_pnl", ascending=True).head(n_worst))

    # ── Clustering: what do winning trades have in common? ──
    print(f"\n  WINNER vs LOSER CLUSTERING:")
    is_winner = trades["winner"].to_numpy(dtype=bool)
    if "entry_features" in trades.columns:
        feats = [ef if isinstance(ef, dict) else {} for ef in trades["entry_features"]]
    else:
        feats = [{}] * len(trades)
    feature_keys = sorted({key for ef in feats for key in ef})

    interesting = []
    for key in feature_keys:
        # Missing keys / non-numeric values become NaN and are dropped.
        vals = np.array([_as_float(ef.get(key, np.nan)) for ef in feats])
        w_vals = vals[is_winner]
        l_vals = vals[~is_winner]
        w_vals = w_vals[~np.isnan(w_vals)]
        l_vals = l_vals[~np.isnan(l_vals)]
        if len(w_vals) < 10 or len(l_vals) < 10:
            continue
        w_mean = np.mean(w_vals)
        l_mean = np.mean(l_vals)
        pooled_std = np.sqrt((np.var(w_vals) + np.var(l_vals)) / 2)
        effect_size = (w_mean - l_mean) / pooled_std if pooled_std > 0 else 0
        _, p_val = stats.ttest_ind(w_vals, l_vals, equal_var=False)
        if abs(effect_size) > 0.1 or p_val < 0.1:
            interesting.append({
                "feature": key, "w_mean": w_mean, "l_mean": l_mean,
                "effect": effect_size, "p": p_val,
            })

    if interesting:
        interesting.sort(key=lambda x: x["p"])
        print(f"  {'Feature':<40s} {'W mean':>10} {'L mean':>10} {'Effect':>8} {'p':>8}")
        for row in interesting[:20]:
            sig = "***" if row["p"] < .001 else "**" if row["p"] < .01 else "*" if row["p"] < .05 else ""
            print(f"  {row['feature']:<40s} {row['w_mean']:>+10.4f} {row['l_mean']:>+10.4f} "
                  f"{row['effect']:>+8.3f} {row['p']:>7.4f} {sig}")

    # ── Hold dynamics: does regime flip / stress during hold predict outcome? ──
    # Lines are collected first so the section header is printed whenever
    # either sub-section has enough trades.
    hold_lines = []
    if "regime_flipped_during" in trades.columns:
        flipped = trades[trades["regime_flipped_during"].eq(True)]
        stable = trades[trades["regime_flipped_during"].eq(False)]
        if len(flipped) >= 5 and len(stable) >= 5:
            f_net = (flipped["final_pnl"].values - cost).mean()
            s_net = (stable["final_pnl"].values - cost).mean()
            hold_lines.append(f"    Regime flipped during hold: n={len(flipped)}, net={f_net:+.3f}")
            hold_lines.append(f"    Regime stable during hold:  n={len(stable)}, net={s_net:+.3f}")

    if "max_stress_during" in trades.columns:
        med_stress = trades["max_stress_during"].median()
        high_s = trades[trades["max_stress_during"] > med_stress]
        low_s = trades[trades["max_stress_during"] <= med_stress]
        if len(high_s) >= 5 and len(low_s) >= 5:
            h_net = (high_s["final_pnl"].values - cost).mean()
            l_net = (low_s["final_pnl"].values - cost).mean()
            hold_lines.append(f"    High stress during hold:    n={len(high_s)}, net={h_net:+.3f}")
            hold_lines.append(f"    Low stress during hold:     n={len(low_s)}, net={l_net:+.3f}")

    if hold_lines:
        print(f"\n  Hold dynamics:")
        for line in hold_lines:
            print(line)

    # ── MFE bar clustering ──
    if "mfe_bar" in trades.columns:
        print(f"\n  MFE bar clustering (net by when peak occurs):")
        print(f"  {'Peak bar':<10s} {'N':>5} {'Net':>8} {'Hit':>6}")
        for b in sorted(trades["mfe_bar"].dropna().unique()):
            sub = trades[trades["mfe_bar"] == b]
            if len(sub) >= 5:
                sub_net = sub["final_pnl"].values - cost
                print(f"  Bar {int(b):<6} {len(sub):>5} {sub_net.mean():>+8.3f} "
                      f"{(sub_net>0).mean():>6.3f}")


def reversal_analysis(results):
    """
    Compare per-trade net P&L with the P&L of trading the opposite side.

    Flipping direction negates each trade's gross ``final_pnl`` while the
    round-trip cost (``cost_pts``) is still paid. A clearly negative original
    net therefore hints that the opposite direction may work. Prints a short
    verdict and returns None. Nothing is printed for fewer than 10 trades.
    """
    trades = results["trades"]
    inst = results["config"]["instrument"]
    cost = inst["cost_pts"]

    if len(trades) < 10:
        return

    gross = trades["final_pnl"].values
    net_orig = gross - cost       # as traded
    net_rev = -gross - cost       # mirror image, same cost

    mean_orig = net_orig.mean()
    mean_rev = net_rev.mean()

    print(f"\n  Reversal Analysis:")
    side = 'long' if inst['direction'] == 1 else 'short'
    print(f"    Original ({inst['cross']} {side}): "
          f"net={mean_orig:+.3f}, hit={(net_orig > 0).mean():.3f}")
    print(f"    Reversed: net={mean_rev:+.3f}, hit={(net_rev > 0).mean():.3f}")

    if mean_rev > mean_orig and mean_rev > 0:
        print(f"    ⚠ REVERSAL IS BETTER by {mean_rev - mean_orig:+.3f} pts/trade")
        print(f"    Consider flipping direction for this instrument.")
        return
    if mean_orig <= 0 and mean_rev <= 0:
        print(f"    Neither direction works well.")
        return
    print(f"    Original direction is correct.")


def full_diagnostic(results):
    """
    Run the complete per-instrument diagnostic suite on one result dict.

    Each report is called in a fixed order. Only ``feature_scan`` produces a
    value, and that value is returned.
    """
    for report in (trade_report, trade_paths, entry_context,
                   trade_detail, reversal_analysis, bootstrap):
        report(results)
    feat_df = feature_scan(results)
    for report in (pattern_analysis, gate_combinatorics):
        report(results)
    return feat_df
"""
exits.py — Dynamic Exit Strategies
=====================================
Called from backtest.py. simulate_trade() walks one trade bar by bar, asks
_check_exit() for an exit reason (a string, or None to keep holding), and
returns a trade_record dict with full path logging.

Exit strategies (the value of exit_cfg["type"]):
  fixed      — hold N bars, exit at close
  atr        — TP/SL based on entry-bar ATR
  tensor     — exit on regime flip, stress spike, or micro change
  mae        — stop (reason "mae_stop") if drawdown exceeds historical winner MAE
  trailing   — trailing stop after MFE threshold reached
  combined   — ATR envelope + tensor monitoring
  adaptive   — early exit for quick movers, MAE stop, trailing stop for runners
  thesis     — data-driven: quick profit + MAE + trailing + regime/stress/micro

All exits log bar-by-bar tensor state during the hold for post-hoc analysis.
"""
import numpy as np
from config import PARAMS


# ══════════════════════════════════════════════════════════════
# CORE: SIMULATE ONE TRADE (called by backtest.py)
# ══════════════════════════════════════════════════════════════

def simulate_trade(df, entry_loc, direction, cost, exit_cfg,
                   all_tensors=None, entry_row=None):
    """
    Simulate one trade bar by bar from ``entry_loc`` and log its full path.

    The trade enters at the close of bar ``entry_loc``. Each following bar is
    evaluated until ``_check_exit`` returns a reason, the hold limit is
    reached ("max_hold"), the session changes ("session_end") or the data
    ends ("data_end").

    Parameters
    ----------
    df : pd.DataFrame
        Bar data with ``Close``, ``High``, ``Low`` and ``SessionDate``
        columns. ``ATR``, ``GK`` and ``stress_B`` are used when present.
    entry_loc : int
        Positional (``iloc``) index of the entry bar.
    direction : int
        1 for long; any other value is treated as the short side for the
        intrabar best/worst excursion.
    cost : float
        Per-trade cost in points; ``net_pnl = final_pnl - cost``.
    exit_cfg : dict
        ``exit_cfg["type"]`` selects the exit rule in ``_check_exit``. The
        hold limit is ``max_bars``, else ``bars``, else ``PARAMS["hold_bars"]``.
    all_tensors : sequence, optional
        Per-bar tensors aligned with ``df`` rows. ``ts[3, 3]`` and
        ``ts[3, 7]`` are logged as regime / micro GK during the hold.
    entry_row : pd.Series, optional
        Entry signal row. ``T_*`` and ``stress_A``/``stress_B`` fields are
        snapshotted; ``conviction`` and ``n_contracts`` are copied.

    Returns
    -------
    dict or None
        The trade record, or None if no bar after entry could be evaluated.
    """
    entry_bar = df.iloc[entry_loc]
    entry_close = entry_bar["Close"]
    entry_sess = entry_bar["SessionDate"]
    entry_atr = entry_bar.get("ATR", np.nan)
    entry_time = df.index[entry_loc]

    exit_type = exit_cfg["type"]
    # Hold limit: max_bars, else bars, else the global default. PARAMS is
    # only consulted when the config does not set a limit itself.
    max_hold = exit_cfg.get("max_bars")
    if max_hold is None:
        max_hold = exit_cfg.get("bars")
    if max_hold is None:
        max_hold = PARAMS["hold_bars"]

    n_rows = len(df)
    bars = []
    exit_reason = "max_hold"
    peak_pnl = 0.0
    trail_active = False

    for b in range(1, max_hold + 1):
        loc = entry_loc + b
        if loc >= n_rows:
            exit_reason = "data_end"
            break
        bar_data = df.iloc[loc]
        if bar_data["SessionDate"] != entry_sess:
            exit_reason = "session_end"
            break

        bc = bar_data["Close"]
        bh = bar_data["High"]
        bl = bar_data["Low"]

        # P&L in points: close-based, plus best/worst intrabar excursions.
        unrealized = (bc - entry_close) * direction
        bar_best = ((bh - entry_close) if direction == 1 else (entry_close - bl)) * abs(direction)
        bar_worst = ((bl - entry_close) if direction == 1 else (entry_close - bh)) * abs(direction)

        peak_pnl = max(peak_pnl, unrealized, bar_best)

        # Tensor state during hold (logged for post-hoc analysis)
        hold_alpha_regime = None
        hold_stress = None
        hold_micro_gk = None

        if all_tensors is not None and loc < len(all_tensors):
            ts = all_tensors[loc]
            hold_alpha_regime = 1 if ts[3, 3] > 0.5 else 0
            hold_micro_gk = ts[3, 7]

        if "stress_B" in bar_data.index:
            hold_stress = bar_data["stress_B"]

        bars.append({
            "bar": b, "close": bc, "high": bh, "low": bl,
            "unrealized": unrealized, "best": bar_best, "worst": bar_worst,
            "peak_pnl": peak_pnl,
            "alpha_regime": hold_alpha_regime,
            "stress": hold_stress,
            "micro_gk": hold_micro_gk,
        })

        # Exit logic
        exited = _check_exit(
            exit_type, exit_cfg, b, unrealized, peak_pnl,
            entry_atr, bar_data, entry_loc, df, all_tensors,
            direction, trail_active
        )

        if exited:
            exit_reason = exited
            break

        # Arm the trailing stop once the peak clears the activation level.
        # This runs after the exit check, so it takes effect from the next bar.
        if exit_type in ("trailing", "adaptive", "combined"):
            thresh = exit_cfg.get("trail_activation", 0.5)
            avg_mfe = exit_cfg.get("avg_mfe", 10.0)
            if peak_pnl > thresh * avg_mfe:
                trail_active = True
        elif exit_type == "thesis":
            trail_pts = exit_cfg.get("trail_activate_pts", 15.0)
            if peak_pnl >= trail_pts:
                trail_active = True

    if len(bars) == 0:
        return None

    # Compute trade metrics
    unr = [b["unrealized"] for b in bars]
    bst = [b["best"] for b in bars]
    wst = [b["worst"] for b in bars]

    final = unr[-1]
    mfe = max(max(unr), max(bst))
    mae = min(min(unr), min(wst))
    eff = final / mfe if mfe > 0 else 0.0
    mfe_bar = np.argmax([max(u, b) for u, b in zip(unr, bst)]) + 1

    # Entry context: up to 5 bars before entry (may span a session boundary)
    pre_start = max(0, entry_loc - 5)
    pre_bars = df.iloc[pre_start:entry_loc]
    pre_momentum = 0.0
    pre_vol = 0.0
    if len(pre_bars) >= 2:
        first_close = pre_bars["Close"].iloc[0]
        pre_momentum = (pre_bars["Close"].iloc[-1] - first_close) / first_close * direction
        pre_vol = pre_bars["GK"].mean() if "GK" in pre_bars.columns else 0.0

    # Entry tensor snapshot
    entry_tensor_features = {}
    if entry_row is not None:
        for col in entry_row.index:
            if col.startswith("T_") or col in ("stress_A", "stress_B"):
                val = entry_row[col]
                if not isinstance(val, (str, bool)):
                    entry_tensor_features[col] = val

    # Session position, computed from row positions so a non-unique index
    # cannot break the lookup.
    sess_positions = np.flatnonzero((df["SessionDate"] == entry_sess).to_numpy())
    if len(sess_positions) > 0:
        session_pct = (entry_loc - sess_positions[0]) / len(sess_positions)
    else:
        session_pct = 0.5

    # Regime during hold
    hold_regimes = [b["alpha_regime"] for b in bars if b["alpha_regime"] is not None]
    regime_flipped_during = False
    if len(hold_regimes) >= 2:
        regime_flipped_during = any(r != hold_regimes[0] for r in hold_regimes[1:])

    hold_stresses = [b["stress"] for b in bars if b["stress"] is not None]
    # Ignore NaN readings for the max: built-in max() with NaN gives an
    # order-dependent result (NaN only when it happens to come first).
    finite_stresses = [s for s in hold_stresses if not np.isnan(s)]
    max_stress_during = max(finite_stresses) if finite_stresses else np.nan

    return {
        "entry_time": entry_time,
        "entry_close": entry_close,
        "entry_atr": entry_atr,
        "session_date": entry_sess,
        "session_pct": session_pct,
        "final_pnl": final,
        "net_pnl": final - cost,
        "mfe": mfe,
        "mae": mae,
        "entry_eff": eff,
        "mfe_bar": mfe_bar,
        "bars_held": len(bars),
        "exit_reason": exit_reason,
        "winner": final > cost,
        "path": unr,
        "path_best": bst,
        "path_worst": wst,
        "pre_momentum": pre_momentum,
        "pre_vol": pre_vol,
        "direction": direction,
        "entry_features": entry_tensor_features,
        "regime_flipped_during": regime_flipped_during,
        "max_stress_during": max_stress_during,
        "hold_regimes": hold_regimes,
        "hold_stresses": hold_stresses,
        "conviction": entry_row.get("conviction", np.nan) if entry_row is not None else np.nan,
        "n_contracts": entry_row.get("n_contracts", 1) if entry_row is not None else 1,
    }


# ══════════════════════════════════════════════════════════════
# EXIT CHECKS
# ══════════════════════════════════════════════════════════════

def _check_exit(exit_type, cfg, bar_num, unrealized, peak_pnl,
                entry_atr, bar_data, entry_loc, df, all_tensors,
                direction, trail_active):
    """Returns exit_reason string or None."""

    # ── Fixed ──
    if exit_type == "fixed":
        if bar_num >= cfg["bars"]:
            return "fixed_hold"

    # ── ATR-based ──
    elif exit_type == "atr":
        if not np.isnan(entry_atr) and entry_atr > 0:
            if unrealized >= entry_atr * cfg["tp_mult"]:
                return "atr_tp"
            if unrealized <= -entry_atr * cfg["sl_mult"]:
                return "atr_sl"

    # ── Tensor-monitored ──
    elif exit_type == "tensor":
        if cfg.get("exit_on_regime_flip", True):
            entry_regime = 1 if df.iloc[entry_loc].get("Alpha_P1", 0.5) > 0.5 else 0
            curr_regime = 1 if bar_data.get("Alpha_P1", 0.5) > 0.5 else 0
            if curr_regime != entry_regime:
                return "regime_flip"
        if cfg.get("exit_on_stress_spike", True):
            curr_stress = bar_data.get("stress_B", 0)
            if curr_stress > cfg.get("stress_exit_threshold", 0.05):
                return "stress_spike"
        if cfg.get("exit_on_micro_spike", False):
            entry_gk = df.iloc[entry_loc].get("micro_gk", 0)
            curr_gk = bar_data.get("micro_gk", 0)
            if entry_gk > 0 and curr_gk > entry_gk * 2.0:
                return "micro_spike"

    # ── MAE-informed stop ──
    elif exit_type == "mae":
        max_mae = cfg.get("max_mae_pts", None)
        if max_mae is not None and unrealized <= -abs(max_mae):
            return "mae_stop"

    # ── Trailing stop ──
    elif exit_type == "trailing":
        if trail_active:
            trail_pct = cfg.get("trail_pct", 0.4)
            trail_level = peak_pnl * (1 - trail_pct)
            if unrealized <= trail_level and peak_pnl > 0:
                return "trailing_stop"

    # ── Adaptive (time + trailing + MAE) ──
    elif exit_type == "adaptive":
        quick_thresh = cfg.get("quick_profit_mult", 1.5)
        avg_profit = cfg.get("avg_profit", 5.0)
        quick_bars = cfg.get("quick_profit_bars", 2)
        if bar_num <= quick_bars and unrealized > quick_thresh * avg_profit:
            return "quick_profit"
        max_mae = cfg.get("max_mae_pts", None)
        if max_mae is not None and unrealized <= -abs(max_mae):
            return "mae_stop"
        if trail_active:
            trail_pct = cfg.get("trail_pct", 0.4)
            trail_level = peak_pnl * (1 - trail_pct)
            if unrealized <= trail_level and peak_pnl > 0:
                return "trailing_stop"

    # ── Combined (ATR + tensor) ──
    elif exit_type == "combined":
        if not np.isnan(entry_atr) and entry_atr > 0:
            if unrealized >= entry_atr * cfg.get("atr_tp_mult", 1.5):
                return "atr_tp"
            if unrealized <= -entry_atr * cfg.get("atr_sl_mult", 2.5):
                return "atr_sl"
        if cfg.get("exit_on_regime_flip", True):
            entry_regime = 1 if df.iloc[entry_loc].get("Alpha_P1", 0.5) > 0.5 else 0
            curr_regime = 1 if bar_data.get("Alpha_P1", 0.5) > 0.5 else 0
            if curr_regime != entry_regime:
                return "regime_flip"
        if cfg.get("exit_on_stress_spike", True):
            curr_stress = bar_data.get("stress_B", 0)
            if curr_stress > cfg.get("stress_exit_threshold", 0.05):
                return "stress_spike"

    # ── Thesis v2: revised from parameter sweep + ablation ──
    elif exit_type == "thesis":

        # 1. Quick profit: captures bimodal bar-1 spike
        #    Raised to 25 pts — 15 was grabbing trades that had 30+ in them
        quick_pts = cfg.get("quick_profit_pts", 25.0)
        quick_bars = cfg.get("quick_profit_bars", 2)
        if bar_num <= quick_bars and unrealized >= quick_pts:
            return "quick_profit"

        # 2. MAE stop: cut before loser digs deeper than winners survive
        #    30 pts optimal from sweep (25 was close, 30 slightly better)
        mae_stop = cfg.get("mae_stop_pts", 30.0)
        if unrealized <= -mae_stop:
            return "mae_stop"

        # 3. Regime flip: demoted from hard exit to trail tightener
        #    Ablation showed hard exit hurts — catches noise flips + cuts
        #    trades that would recover. Instead, if regime flips and trailing
        #    is active, tighten the trail percentage to lock in more.
        if cfg.get("exit_on_regime_flip", False) and bar_num >= 2:
            entry_regime = 1 if df.iloc[entry_loc].get("Alpha_P1", 0.5) > 0.5 else 0
            curr_regime = 1 if bar_data.get("Alpha_P1", 0.5) > 0.5 else 0
            if curr_regime != entry_regime:
                return "regime_flip"

        # 4. Stress spike: KL divergence accelerating
        if cfg.get("exit_on_stress_spike", True):
            curr_stress = bar_data.get("stress_B", 0)
            if curr_stress > cfg.get("stress_exit_threshold", 0.04):
                return "stress_spike"

        # 5. Micro GK collapse: volatility variability gone, thesis dead
        if cfg.get("exit_on_micro_collapse", True) and bar_num >= 3:
            entry_mgk = df.iloc[entry_loc].get("micro_gk", 0)
            curr_mgk = bar_data.get("micro_gk", 0)
            ratio = cfg.get("micro_collapse_ratio", 0.3)
            if entry_mgk > 0 and curr_mgk < entry_mgk * ratio:
                return "micro_collapse"

        # 6. Trailing stop with regime-aware tightening
        #    Base trail_pct = 0.55 (give back less of peak)
        #    If regime flipped and tightening is enabled, use tighter trail
        if trail_active:
            trail_pct = cfg.get("trail_pct", 0.55)

            # Regime flip tightens trail instead of hard exiting
            if cfg.get("regime_flip_tightens_trail", False) and bar_num >= 2:
                entry_regime = 1 if df.iloc[entry_loc].get("Alpha_P1", 0.5) > 0.5 else 0
                curr_regime = 1 if bar_data.get("Alpha_P1", 0.5) > 0.5 else 0
                if curr_regime != entry_regime:
                    trail_pct = cfg.get("regime_flip_trail_pct", 0.30)

            trail_level = peak_pnl * (1 - trail_pct)
            if unrealized <= trail_level and peak_pnl > 0:
                return "trailing_stop"

    return None


# ══════════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════════

def compute_exit_thresholds(historical_trades, percentile=75):
    """
    Compute data-driven exit thresholds from historical trade data.

    Parameters
    ----------
    historical_trades : pd.DataFrame
        Trade records with ``winner`` (bool), ``mae``, ``mfe`` and
        ``final_pnl`` columns.
    percentile : float
        The MAE stop is the absolute value of the ``100 - percentile``-th
        percentile of winner MAE.

    Returns
    -------
    dict
        ``{}`` when there are fewer than 20 trades. Otherwise it has these keys:

        - ``max_mae_pts``: None when there are no winners. A None stop is
          skipped by the ``max_mae_pts`` checks.
        - ``avg_mfe``
        - ``avg_profit``: mean winner ``final_pnl``, or 1.0 with no winners.
        - ``median_mfe``
    """
    if len(historical_trades) < 20:
        return {}

    winners = historical_trades[historical_trades["winner"]]
    all_mfe = historical_trades["mfe"].values

    if len(winners) > 0:
        max_mae = abs(np.percentile(winners["mae"].values, 100 - percentile))
        avg_profit = winners["final_pnl"].mean()
    else:
        # No winners → nothing to calibrate against. A 0-pt stop would close
        # every trade on its first non-positive bar, so leave it unset.
        max_mae = None
        avg_profit = 1.0

    thresholds = {
        "max_mae_pts": max_mae,
        "avg_mfe": np.mean(all_mfe),
        "avg_profit": avg_profit,
        "median_mfe": np.median(all_mfe),
    }

    print("  Exit thresholds from {} trades:".format(len(historical_trades)))
    if max_mae is None:
        print("    MAE stop (P{} winner MAE): n/a (no winners)".format(percentile))
    else:
        print("    MAE stop (P{} winner MAE): {:.2f} pts".format(percentile, max_mae))
    print("    Avg MFE: {:.2f} pts".format(thresholds["avg_mfe"]))
    print("    Avg winner profit: {:.2f} pts".format(thresholds["avg_profit"]))

    return thresholds


def update_exit_config(exit_cfg, thresholds):
    """
    Return a copy of ``exit_cfg`` with computed thresholds merged in.

    Only ``max_mae_pts``, ``avg_mfe`` and ``avg_profit`` are copied over, and
    only when present in ``thresholds``. The input config is not modified.
    """
    merged = exit_cfg.copy()
    for name in ("max_mae_pts", "avg_mfe", "avg_profit"):
        if name in thresholds:
            merged[name] = thresholds[name]
    return merged
# region imports
from AlgorithmImports import *
# endregion
"""
hmm_model.py — HMM Regime Detection
=======================================
Fits 2-state Gaussian HMM on configurable emissions.
Provides both Mode A (full backward, backtest ceiling) and
Mode B (progressive backward, live-feasible).

Pure functions. No global state.

Mode A: Full session forward-backward. Gamma at bar t uses bars t+1..T.
        NOT available in live. Used for comparison/ceiling.

Mode B: Progressive backward. At bar t, run backward from t with beta[t]=1
        (only the window t-3..t is needed for the four tensor slots).
        Gamma at bar t-3 has 3 bars of backward info.
        Gamma at bar t (frontier) ≈ alpha (beta=1).
        This is what you'd run in live.
"""
import numpy as np
from hmmlearn import hmm
from config import PARAMS


# ══════════════════════════════════════════════════════════════
# HMM FITTING
# ══════════════════════════════════════════════════════════════

def fit_hmm(df, emission_cols, train_sessions, n_states=None, n_iter=None):
    """
    Fit Gaussian HMM on specified emissions and training sessions.

    Each training session is passed to hmmlearn as its own sequence. For a
    2-state model the states are relabelled so that state 0 is the "high"
    state: higher mean of the second emission or, with a single emission,
    higher variance of the first.

    Parameters
    ----------
    df : pd.DataFrame — must contain emission_cols and SessionDate
    emission_cols : list of str — column names for emissions
    train_sessions : list — session dates for training
    n_states : int — number of HMM states (default from config)
    n_iter : int — EM iterations (default from config)

    Returns
    -------
    model : fitted hmmlearn.GaussianHMM

    Raises
    ------
    ValueError
        If no row of ``df`` belongs to ``train_sessions``.
    """
    if n_states is None: n_states = PARAMS["hmm_states"]
    if n_iter is None: n_iter = PARAMS["hmm_iter"]

    train_mask = df["SessionDate"].isin(train_sessions)
    X_train = df.loc[train_mask, emission_cols].values
    if len(X_train) == 0:
        raise ValueError("fit_hmm: no rows of df belong to train_sessions")

    # Sequence lengths must follow the row order of X_train. Take them from
    # runs of consecutive equal SessionDate values among the training rows.
    # Counting per entry of train_sessions would misalign sequences whenever
    # that list is ordered differently from df, has duplicates, or names
    # sessions missing from df.
    sess = df.loc[train_mask, "SessionDate"].to_numpy()
    run_starts = np.flatnonzero(sess[1:] != sess[:-1]) + 1
    bounds = np.concatenate(([0], run_starts, [len(sess)]))
    lengths = np.diff(bounds).tolist()

    model = hmm.GaussianHMM(
        n_components=n_states, covariance_type="full",
        n_iter=n_iter, random_state=PARAMS["hmm_random_state"],
    )
    model.fit(X_train, lengths=lengths)

    # Enforce label consistency: State 0 = high vol, State 1 = quiet
    if n_states == 2:
        if len(emission_cols) >= 2:
            # Second emission's mean (assumed to be GK — TODO confirm order)
            swap = model.means_[0, 1] < model.means_[1, 1]
        else:
            # Variance of the single emission
            swap = model.covars_[0, 0, 0] < model.covars_[1, 0, 0]

        if swap:
            p = [1, 0]
            model.means_ = model.means_[p]
            model.covars_ = model.covars_[p]
            model.transmat_ = model.transmat_[np.ix_(p, p)]
            model.startprob_ = model.startprob_[p]

    diag = np.diag(model.transmat_)
    print(f"  HMM fitted on {len(train_sessions)} sessions, "
          f"{len(emission_cols)} emissions")
    print(f"    Transition diag: [{', '.join(f'{d:.3f}' for d in diag)}]")

    return model


# ══════════════════════════════════════════════════════════════
# FORWARD PASS (same for Mode A and Mode B)
# ══════════════════════════════════════════════════════════════

def run_forward(model, X, active):
    """
    Forward algorithm with activity masking.

    Inactive bars (``active[t]`` false) skip the emission update and only
    propagate through the transition matrix. Every step is normalised, so
    ``alpha[t]`` is the filtered state distribution given bars 0..t.

    For the recursion, each bar's likelihoods are exponentiated after
    subtracting that bar's maximum log-likelihood. This rescales the row by
    a constant, which the normalisation removes, but it stops an extreme bar
    from underflowing to all zeros. Such an underflow would otherwise zero
    alpha for the rest of the session.

    Parameters
    ----------
    model : fitted GaussianHMM (uses n_components, startprob_, transmat_,
        _compute_log_likelihood)
    X : np.ndarray [T, n_emissions]
    active : np.ndarray [T] bool

    Returns
    -------
    alpha : np.ndarray [T, n_states]
    B : np.ndarray [T, n_states] — exp(log-likelihood), unscaled
        (cached for backward)
    """
    T = len(X)
    N = model.n_components
    log_b = model._compute_log_likelihood(X)
    B = np.exp(log_b)

    alpha = np.zeros((T, N))
    if T == 0:
        return alpha, B

    # Per-bar rescaled likelihoods (largest entry of each row = 1).
    row_max = log_b.max(axis=1, keepdims=True)
    row_max[~np.isfinite(row_max)] = 0.0
    E = np.exp(log_b - row_max)

    alpha[0] = model.startprob_ * (E[0] if active[0] else 1.0)
    s = alpha[0].sum()
    if s > 0: alpha[0] /= s

    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ model.transmat_) * (E[t] if active[t] else 1.0)
        s = alpha[t].sum()
        if s > 0: alpha[t] /= s

    return alpha, B


# ══════════════════════════════════════════════════════════════
# MODE A: FULL BACKWARD (backtest ceiling)
# ══════════════════════════════════════════════════════════════

def run_full_backward(model, B, active):
    """
    Full backward pass over entire session.
    NOT live-feasible — uses future data within session.

    Parameters
    ----------
    model : fitted GaussianHMM
    B : np.ndarray [T, n_states] — emission probs from forward pass
    active : np.ndarray [T] bool

    Returns
    -------
    beta : np.ndarray [T, n_states]
    """
    T = len(B)
    N = model.n_components
    beta = np.zeros((T, N))
    beta[T - 1] = 1.0

    for t in range(T - 2, -1, -1):
        bn = B[t + 1] if active[t + 1] else 1.0
        beta[t] = (model.transmat_ * bn * beta[t + 1]).sum(axis=1)
        s = beta[t].sum()
        if s > 0: beta[t] /= s

    return beta


def compute_gamma_kl(alpha, beta, active):
    """
    Compute smoothed posteriors (gamma) and KL divergence.

    gamma is alpha * beta normalised per row (rows summing to 0 are left as
    zeros instead of dividing by zero).

    Parameters
    ----------
    alpha : np.ndarray [T, n_states]
    beta : np.ndarray [T, n_states]
    active : array-like [T] — activity mask. It is coerced to bool, so a
        0/1 integer mask works too.

    Returns
    -------
    gamma : np.ndarray [T, n_states]
    kl : np.ndarray [T] — KL(alpha || gamma), NaN for inactive bars
    """
    # ``~`` on an int 0/1 array gives -1/-2, which would index (and NaN-out)
    # the last two rows instead of the inactive ones.
    active = np.asarray(active, dtype=bool)

    fb = alpha * beta
    fb_sum = fb.sum(axis=1, keepdims=True)
    fb_sum[fb_sum == 0] = 1.0
    gamma = fb / fb_sum

    kl = _kl_div(alpha, gamma)
    kl[~active] = np.nan

    return gamma, kl


def mode_a_session(model, X, active):
    """
    Mode A for one session: full forward-backward (backtest ceiling).

    Uses the full-session backward pass, i.e. future bars of the session,
    so it is not available live.

    Returns
    -------
    alpha : [T, n_states] forward posteriors
    gamma : [T, n_states] smoothed posteriors
    kl : [T] KL(alpha || gamma), NaN on inactive bars
    """
    alpha, emis = run_forward(model, X, active)
    gamma, kl = compute_gamma_kl(alpha, run_full_backward(model, emis, active), active)
    return alpha, gamma, kl


# ══════════════════════════════════════════════════════════════
# MODE B: PROGRESSIVE BACKWARD (live-feasible)
# ══════════════════════════════════════════════════════════════

def mode_b_session(model, X, active):
    """
    Progressive (live-feasible) backward pass for one session.

    At each bar t, a backward pass with beta[t] = 1 is run over the trailing
    window max(0, t-3)..t. Gamma and KL are then computed for the four tensor
    slots [t-3, t-2, t-1, t]. Beta at those slots depends only on bars inside
    the window, so the result matches a backward pass over 0..t.

    Padding and degenerate cases:
    - Slots before the session start get a uniform gamma (1 / n_states) and
      KL = 0.
    - Inactive bars get KL = 0.
    - A bar whose emissions underflowed to all zeros is treated as
      uninformative (like an inactive bar), so the window never collapses
      to zero.

    Returns
    -------
    alpha : [T, n_states] — forward posteriors (same as Mode A)
    kl_slots : [T, 4] — KL for tensor slots [t-3, t-2, t-1, t]
        Slot 3 (frontier) will have KL ≈ 0 since beta[t]=1
    gamma_slots : [T, 4, n_states] — gamma for tensor slots
    """
    T = len(X)
    N = model.n_components
    alpha, B = run_forward(model, X, active)
    A = model.transmat_

    kl_slots = np.full((T, 4), np.nan)
    gamma_slots = np.zeros((T, 4, N))

    for t in range(T):
        # Backward pass over the trailing window start..t
        start = max(0, t - 3)
        length = t - start + 1

        b_partial = np.zeros((length, N))
        b_partial[-1] = 1.0  # beta[t] = 1 always

        for tb in range(length - 2, -1, -1):
            abs_idx = start + tb + 1
            bn = B[abs_idx] if active[abs_idx] else 1.0
            b_partial[tb] = (A * bn * b_partial[tb + 1]).sum(axis=1)
            s = b_partial[tb].sum()
            if s == 0:
                # Emissions at abs_idx underflowed: use the no-emission step.
                b_partial[tb] = (A * b_partial[tb + 1]).sum(axis=1)
                s = b_partial[tb].sum()
            if s > 0: b_partial[tb] /= s

        # Fill slots
        for slot in range(4):
            abs_bar = t - (3 - slot)  # slot 0 = t-3, slot 3 = t
            if abs_bar < start:  # start >= 0, so this also covers abs_bar < 0
                gamma_slots[t, slot] = 1.0 / N
                kl_slots[t, slot] = 0.0
                continue

            local_idx = abs_bar - start
            fb = alpha[abs_bar] * b_partial[local_idx]
            s = fb.sum()
            if s > 0: fb /= s

            gamma_slots[t, slot] = fb
            if active[abs_bar]:
                kl_slots[t, slot] = _kl_div_single(alpha[abs_bar], fb)
            else:
                kl_slots[t, slot] = 0.0

    return alpha, kl_slots, gamma_slots


# ══════════════════════════════════════════════════════════════
# RUN ON ALL SESSIONS
# ══════════════════════════════════════════════════════════════

def run_hmm_all_sessions(df, model, emission_cols, sessions, mode="progressive"):
    """
    Run HMM forward + backward on all sessions.

    Both Mode A (full backward) and Mode B (progressive backward) are always
    computed. ``mode`` is accepted for compatibility but is not used.

    Each session's results are written at that session's own row positions
    in ``df``, so outputs stay aligned with ``df`` even if ``sessions`` is
    not in row order or does not cover every row. Rows of unlisted sessions
    keep their initial values (alpha 0, KL NaN, gamma 0). Listed sessions
    with no rows in ``df`` are skipped.

    Parameters
    ----------
    df : pd.DataFrame — needs SessionDate, Active and emission_cols
    model : fitted GaussianHMM
    emission_cols : list of str
    sessions : list of session dates
    mode : "progressive" (Mode B) or "full" (Mode A) — currently unused

    Returns
    -------
    results : dict with keys:
        alpha : [n_rows, n_states] — forward posteriors for all bars
        kl_a : [n_rows] — Mode A KL (full backward)
        kl_b_slots : [n_rows, 4] — Mode B KL per tensor slot
        gamma_b_slots : [n_rows, 4, n_states] — Mode B gamma per tensor slot
    """
    n_rows = len(df)
    n_states = model.n_components

    alpha_all = np.zeros((n_rows, n_states))
    kl_a_all = np.full(n_rows, np.nan)

    # Always compute Mode B (it's cheap and needed for tensor)
    kl_b_slots = np.full((n_rows, 4), np.nan)
    gamma_b_slots = np.zeros((n_rows, 4, n_states))

    session_col = df["SessionDate"]
    for si, sess in enumerate(sessions):
        rows = np.flatnonzero((session_col == sess).to_numpy())
        if len(rows) > 0:
            sd = df.iloc[rows]
            X = sd[emission_cols].values
            act = sd["Active"].values

            # Mode A (always, for comparison)
            alpha_s, _, kl_a_s = mode_a_session(model, X, act)
            alpha_all[rows] = alpha_s
            kl_a_all[rows] = kl_a_s

            # Mode B (progressive backward)
            _, kl_b_s, gamma_b_s = mode_b_session(model, X, act)
            kl_b_slots[rows] = kl_b_s
            gamma_b_slots[rows] = gamma_b_s

        if (si + 1) % 200 == 0:
            print(f"    ...{si+1}/{len(sessions)} sessions")

    return {
        "alpha": alpha_all,
        "kl_a": kl_a_all,
        "kl_b_slots": kl_b_slots,
        "gamma_b_slots": gamma_b_slots,
    }


# ══════════════════════════════════════════════════════════════
# UTILITIES
# ══════════════════════════════════════════════════════════════

def _kl_div(p, q, eps=1e-10):
    """KL(p || q) per row. Returns [T] array."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return np.sum(p * np.log(p / q), axis=-1)


def _kl_div_single(p, q, eps=1e-10):
    """KL(p || q) for single pair of distributions."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return np.sum(p * np.log(p / q))


def get_train_sessions(sessions, train_frac=None):
    """
    Split an ordered session sequence into (train, validation) slices.

    The first ``int(len(sessions) * train_frac)`` sessions form the training
    slice; the rest form the validation slice. ``train_frac`` defaults to
    PARAMS["hmm_train_frac"].
    """
    frac = PARAMS["hmm_train_frac"] if train_frac is None else train_frac
    cut = int(len(sessions) * frac)
    return sessions[:cut], sessions[cut:]
"""
kl_predictor.py — Two-Branch KL Forward Prediction
=====================================================
The missing core: predicts KL(alpha || gamma) at t+1 using two
independent branches, blended by adaptive confidence weighting.

Branch 1 (MICRO): Exogenous microstructure features predict next KL.
    Inputs: micro_gk, micro_range, micro_vol, micro_cpos
    Thesis: market microstructure shifts precede regime divergence.

Branch 2 (AUTO): Autoregressive — KL's own history predicts next KL.
    Inputs: KL at t, t-1, t-2, KL slope, KL acceleration, alpha state
    Thesis: KL momentum/mean-reversion carries forward.

Both branches use Recursive Least Squares (RLS) with forgetting factor
for online, live-feasible updating. No look-ahead. Session-reset option.

The combined prediction + confidence feeds the decision layer as the
signal that the static tensor features alone cannot provide.

Sits between tensor.py and decision.py in the pipeline.

Usage:
    import kl_predictor as kp
    df, kp_results = kp.run_kl_predictor(df, hmm_results, sessions)
    # df now has KP_* columns for gates/decision
"""
import numpy as np
import pandas as pd
from config import PARAMS, MICRO_FEATURES


# ══════════════════════════════════════════════════════════════
# CONFIGURATION
# ══════════════════════════════════════════════════════════════

KP_PARAMS = dict(
    # ── RLS forgetting factors, in (0, 1] ──
    # 1.0 never forgets; ~0.95 adapts quickly. Lower values follow changes
    # faster at the cost of noisier weights; higher values are smoother.
    rls_lambda_micro=0.98,
    rls_lambda_auto=0.98,

    # Scale of the initial P matrix (identity * delta) — acts as the ridge prior.
    rls_delta=100.0,

    # Bars into a session before prediction errors start driving RLS updates.
    warmup_bars=10,

    # EWMA span of the per-branch squared errors used for adaptive blending.
    confidence_ewma_span=30,

    # True: reset RLS weights and standardization state every session.
    # False: carry them forward (better if the relationship generalizes
    # across sessions; True is safer if microstructure shifts day to day).
    session_reset=False,

    # "adaptive" = inverse-MSE weighting of the branches; anything else = 50/50.
    blend_mode="adaptive",

    # Floor on either branch's blend weight so neither is zeroed out.
    min_branch_weight=0.15,

    # EWMA span for online feature standardization.
    feature_ewma_span=200,
)


# ══════════════════════════════════════════════════════════════
# RLS (RECURSIVE LEAST SQUARES) ENGINE
# ══════════════════════════════════════════════════════════════

class RLSBranch:
    """
    Single-output Recursive Least Squares with forgetting factor.

    At each step:
      1. Predict: y_hat = x @ w (variance proxy x @ P @ x)
      2. Observe: y_actual
      3. Update: w += K * (y - y_hat), P updated

    Forgetting factor lambda < 1 downweights old observations,
    allowing the model to track non-stationary relationships.

    ``update`` ignores samples whose features or target contain NaN/inf:
    a single non-finite value would otherwise propagate into ``w`` and
    make every later prediction NaN.
    """

    def __init__(self, n_features, lam=0.98, delta=100.0):
        """
        Parameters
        ----------
        n_features : int — input dimension
        lam : float — forgetting factor in (0, 1]
        delta : float — initial P matrix scale (higher = less regularization);
            also the scale ``reset()`` restores by default
        """
        self.n = n_features
        self.lam = lam
        self.delta = delta                       # remembered for reset()
        self.w = np.zeros(n_features)           # weight vector
        self.P = np.eye(n_features) * delta      # inverse covariance
        self.steps = 0                           # number of applied updates
        self.last_pred = 0.0
        self.last_var = 1.0                      # prediction variance (floored)

    def predict(self, x):
        """
        Predict y for feature vector x.

        Returns (prediction, variance) with variance = x @ P @ x, returned
        as computed; ``self.last_var`` keeps it floored at 1e-10.
        """
        x = np.asarray(x, dtype=np.float64)
        pred = x @ self.w
        var = x @ self.P @ x   # prediction uncertainty
        self.last_pred = pred
        self.last_var = max(var, 1e-10)
        return pred, var

    def update(self, x, y):
        """
        Update weights given features x and observed target y.

        The update is skipped (state untouched, ``steps`` not incremented)
        when x or y is non-finite or the gain denominator is degenerate.
        """
        x = np.asarray(x, dtype=np.float64)
        y = float(y)
        if not (np.isfinite(y) and np.all(np.isfinite(x))):
            return  # a NaN here would poison w (and P) permanently

        # Kalman gain
        Px = self.P @ x
        denom = self.lam + x @ Px
        if abs(denom) < 1e-12:
            return  # degenerate, skip update

        K = Px / denom

        # Prediction error
        err = y - x @ self.w

        # Weight update
        self.w = self.w + K * err

        # Covariance update
        self.P = (self.P - np.outer(K, x @ self.P)) / self.lam

        # Numerical stability: symmetrize P
        self.P = (self.P + self.P.T) / 2.0

        self.steps += 1

    def reset(self, delta=None):
        """
        Reset weights, covariance and bookkeeping to the initial state.

        delta : float, optional — P scale to restart from; defaults to the
            ``delta`` this branch was constructed with.
        """
        if delta is None:
            delta = self.delta
        self.w = np.zeros(self.n)
        self.P = np.eye(self.n) * delta
        self.steps = 0
        self.last_pred = 0.0
        self.last_var = 1.0


# ══════════════════════════════════════════════════════════════
# FEATURE BUILDERS (per bar, no look-ahead)
# ══════════════════════════════════════════════════════════════

def _build_micro_features(df_row, running_stats):
    """
    Branch 1 features: microstructure at the current bar.

    Returns a length-5 vector: the four micro features passed through the
    "micro" online standardizer, followed by a constant 1.0 intercept.

      [0] micro_gk
      [1] micro_range
      [2] micro_vol
      [3] micro_cpos
      [4] intercept (bias term)

    Missing keys and NaN/inf values fall back to neutral defaults
    (0.0, or 0.5 for micro_cpos). A single bad bar therefore cannot poison
    the running standardization stats or the RLS weights.
    """
    names = ("micro_gk", "micro_range", "micro_vol", "micro_cpos")
    defaults = np.array([0.0, 0.0, 0.0, 0.5], dtype=np.float64)

    raw = np.array(
        [df_row.get(name, default) for name, default in zip(names, defaults)],
        dtype=np.float64,
    )
    bad = ~np.isfinite(raw)
    raw[bad] = defaults[bad]

    # Online standardization
    normed = _online_standardize(raw, running_stats, "micro")

    return np.append(normed, 1.0)  # add intercept


def _build_auto_features(kl_hist, alpha_hist, running_stats):
    """
    Branch 2 feature vector built from the KL series' own recent history.

    kl_hist    : KL values ordered oldest to newest, [t-2, t-1, t]; NaNs are
                 treated as 0.
    alpha_hist : alpha P(quiet) values [t-1, t]; with fewer than two entries
                 the level defaults to 0.5 and the change to 0.0.

    The seven raw features are passed through the "auto" online
    standardizer, and a constant 1.0 intercept is appended (length 8):
      kl_t, kl_{t-1}, kl_{t-2},
      slope        = (kl_t - kl_{t-2}) / 2,
      acceleration = (kl_t - kl_{t-1}) - (kl_{t-1} - kl_{t-2}),
      alpha_t, alpha_t - alpha_{t-1}
    """
    hist = np.nan_to_num(np.array(kl_hist, dtype=np.float64), nan=0.0)
    kl_lag2, kl_lag1, kl_now = hist[0], hist[1], hist[2]

    kl_slope = (kl_now - kl_lag2) / 2.0
    kl_accel = (kl_now - kl_lag1) - (kl_lag1 - kl_lag2)

    if len(alpha_hist) > 1:
        alpha_level = alpha_hist[1]
        alpha_change = alpha_hist[1] - alpha_hist[0]
    else:
        alpha_level, alpha_change = 0.5, 0.0

    features = np.array(
        [kl_now, kl_lag1, kl_lag2, kl_slope, kl_accel, alpha_level, alpha_change],
        dtype=np.float64,
    )
    standardized = _online_standardize(features, running_stats, "auto")
    return np.append(standardized, 1.0)


def _online_standardize(raw, stats, prefix):
    """
    Standardize ``raw`` against running EWMA statistics stored in ``stats``.

    State lives under the keys ``{prefix}_mean``, ``{prefix}_var`` and
    ``{prefix}_n`` and is mutated in place.

    - First 5 calls for a prefix: the mean is a plain running average, the
      variance is held at 1, and ``raw`` is returned unchanged.
    - Afterwards: the mean and variance are EWMA-updated with the current
      observation (span = KP_PARAMS["feature_ewma_span"]), and
      (raw - mean) / std is returned. Only current and past values are used.
    """
    decay = 2.0 / (KP_PARAMS["feature_ewma_span"] + 1.0)

    mean_key = f"{prefix}_mean"
    var_key = f"{prefix}_var"
    count_key = f"{prefix}_n"

    if mean_key not in stats:
        stats[mean_key] = raw.copy()
        stats[var_key] = np.ones_like(raw)
        stats[count_key] = 0

    count = stats[count_key]
    stats[count_key] = count + 1

    if count < 5:
        # Cold start: accumulate the mean, do not standardize yet.
        stats[mean_key] = (stats[mean_key] * count + raw) / (count + 1)
        stats[var_key] = np.ones_like(raw)
        return raw

    stats[mean_key] = decay * raw + (1 - decay) * stats[mean_key]
    deviation = raw - stats[mean_key]
    stats[var_key] = decay * (deviation ** 2) + (1 - decay) * stats[var_key]

    std = np.sqrt(np.maximum(stats[var_key], 1e-10))
    return deviation / std


# ══════════════════════════════════════════════════════════════
# SESSION RUNNER
# ══════════════════════════════════════════════════════════════

def run_kl_predictor_session(df_sess, kl_slot2, alpha_q,
                             micro_branch, auto_branch,
                             running_stats, blend_state):
    """
    Run KL prediction for one session. Updates models in-place.

    At each bar t >= 3, the features for t are built exactly once (this
    advances the online standardization stats once per bar). Both branches
    then forecast KL at t+1 and the forecasts are blended. From bar
    ``warmup_bars`` on, the blended forecast made at t-1 is scored against
    KL at t. Both RLS branches are updated with the exact feature vectors
    that produced that forecast. Bars whose observed KL is NaN are never
    used as training targets.

    Parameters
    ----------
    df_sess : pd.DataFrame — bars for this session
    kl_slot2 : np.ndarray [T] — KL from tensor slot 2 (most recent with backward info)
    alpha_q : np.ndarray [T] — alpha P(quiet) for this session
    micro_branch : RLSBranch — micro model (updated in-place)
    auto_branch : RLSBranch — auto model (updated in-place)
    running_stats : dict — online standardization state (updated in-place)
    blend_state : dict — adaptive blending state (updated in-place)

    Returns
    -------
    preds : dict of [T] arrays.
        pred_*[t], var_*[t], w_*[t] and confidence[t] describe the forecast
        of KL at t+1 made at bar t. err_*[t] is the absolute error of the
        forecast made at t-1 against KL at t. actual_kl[t] is KL at t.
    """
    T = len(df_sess)

    # Output arrays
    pred_micro = np.full(T, np.nan)
    pred_auto = np.full(T, np.nan)
    pred_blend = np.full(T, np.nan)
    actual_kl = np.full(T, np.nan)
    err_micro = np.full(T, np.nan)
    err_auto = np.full(T, np.nan)
    err_blend = np.full(T, np.nan)
    var_micro = np.full(T, np.nan)
    var_auto = np.full(T, np.nan)
    w_micro = np.full(T, 0.5)
    w_auto = np.full(T, 0.5)
    confidence = np.full(T, np.nan)

    warmup = KP_PARAMS["warmup_bars"]
    min_w = KP_PARAMS["min_branch_weight"]
    adaptive = KP_PARAMS["blend_mode"] == "adaptive"

    # Running error tracking for adaptive blending
    ew_span = KP_PARAMS["confidence_ewma_span"]
    ew_alpha = 2.0 / (ew_span + 1.0)

    if "mse_micro" not in blend_state:
        blend_state["mse_micro"] = 1.0
        blend_state["mse_auto"] = 1.0

    # Feature vectors that produced the previous bar's forecast. They are reused
    # verbatim for the RLS update instead of being rebuilt. Rebuilding would
    # re-advance the standardization stats and yield different vectors.
    x_micro_prev = None
    x_auto_prev = None

    for t in range(T):
        kl_now = kl_slot2[t]
        actual_kl[t] = kl_now

        # ── Skip if not enough history ──
        if t < 3:
            continue

        # ── Build features (once per bar) ──
        x_micro = _build_micro_features(df_sess.iloc[t], running_stats)
        x_auto = _build_auto_features(
            [kl_slot2[t-2], kl_slot2[t-1], kl_slot2[t]],
            [alpha_q[t-1], alpha_q[t]],
            running_stats,
        )

        # ── Predict KL at t+1 (before observing t+1) ──
        pm, vm = micro_branch.predict(x_micro)
        pa, va = auto_branch.predict(x_auto)

        pred_micro[t] = pm
        pred_auto[t] = pa
        var_micro[t] = vm
        var_auto[t] = va

        # ── Adaptive blend ──
        mse_m = blend_state["mse_micro"]
        mse_a = blend_state["mse_auto"]

        if adaptive and mse_m + mse_a > 0:
            # Inverse-MSE weighting: better branch gets more weight
            inv_m = 1.0 / max(mse_m, 1e-8)
            inv_a = 1.0 / max(mse_a, 1e-8)
            wm = inv_m / (inv_m + inv_a)
            wa = 1.0 - wm
            # Floor, then renormalize
            wm = max(wm, min_w)
            wa = max(wa, min_w)
            s = wm + wa
            wm /= s
            wa /= s
        else:
            wm = wa = 0.5

        w_micro[t] = wm
        w_auto[t] = wa
        pred_blend[t] = wm * pm + wa * pa

        # ── Confidence: inverse of blended prediction variance ──
        blended_var = wm * vm + wa * va
        confidence[t] = 1.0 / (1.0 + blended_var)

        # ── Score the t-1 forecast against KL at t and update the models ──
        if (t >= warmup
                and x_micro_prev is not None
                and np.isfinite(pred_blend[t-1])
                and np.isfinite(kl_now)):
            target = kl_now  # this is what we predicted at t-1

            # RLS update with the features actually used at t-1
            micro_branch.update(x_micro_prev, target)
            auto_branch.update(x_auto_prev, target)

            # Track errors
            em = abs(pred_micro[t-1] - target)
            ea = abs(pred_auto[t-1] - target)
            eb = abs(pred_blend[t-1] - target)

            err_micro[t] = em
            err_auto[t] = ea
            err_blend[t] = eb

            # EWMA error tracking for blend weights
            blend_state["mse_micro"] = ew_alpha * (em**2) + (1 - ew_alpha) * mse_m
            blend_state["mse_auto"] = ew_alpha * (ea**2) + (1 - ew_alpha) * mse_a

        x_micro_prev = x_micro
        x_auto_prev = x_auto

    return {
        "pred_micro": pred_micro,
        "pred_auto": pred_auto,
        "pred_blend": pred_blend,
        "actual_kl": actual_kl,
        "err_micro": err_micro,
        "err_auto": err_auto,
        "err_blend": err_blend,
        "var_micro": var_micro,
        "var_auto": var_auto,
        "w_micro": w_micro,
        "w_auto": w_auto,
        "confidence": confidence,
    }


# ══════════════════════════════════════════════════════════════
# FULL PIPELINE
# ══════════════════════════════════════════════════════════════

def run_kl_predictor(df, hmm_results, sessions):
    """
    Run KL prediction across all sessions. Main entry point.

    Call this AFTER tensor.py builds tensors and BEFORE decision.py
    generates entries. Attaches KP_* columns to df.

    Alignment: pred_*[t] is the forecast of KL at t+1 made at bar t.
    err_*[t] scores the forecast made at t-1 against KL at t.

    Parameters
    ----------
    df : pd.DataFrame — with micro features and SessionDate
    hmm_results : dict — from hmm_model.run_hmm_all_sessions(); uses
        "kl_b_slots" (slot 2) and "alpha" (column 1)
    sessions : list — sorted session dates.
        NOTE(review): rows of df and of the hmm_results arrays are assumed
        to be laid out session by session in this same order (offsets are
        accumulated from per-session row counts). Confirm with the caller.

    Returns
    -------
    df : pd.DataFrame — copy with KP_* columns attached
    kp_results : dict — full prediction arrays + diagnostics
    """
    N = len(df)

    # Initialize branches
    # NOTE(review): _build_micro_features always emits 4 features + intercept,
    # so this assumes len(MICRO_FEATURES) == 4.
    n_micro_feat = len(MICRO_FEATURES) + 1  # +1 for intercept
    n_auto_feat = 7 + 1                      # 7 features + intercept

    micro_branch = RLSBranch(
        n_micro_feat,
        lam=KP_PARAMS["rls_lambda_micro"],
        delta=KP_PARAMS["rls_delta"],
    )
    auto_branch = RLSBranch(
        n_auto_feat,
        lam=KP_PARAMS["rls_lambda_auto"],
        delta=KP_PARAMS["rls_delta"],
    )

    running_stats = {}
    blend_state = {}

    # Accumulate results
    all_pred_micro = np.full(N, np.nan)
    all_pred_auto = np.full(N, np.nan)
    all_pred_blend = np.full(N, np.nan)
    all_actual = np.full(N, np.nan)
    all_err_micro = np.full(N, np.nan)
    all_err_auto = np.full(N, np.nan)
    all_err_blend = np.full(N, np.nan)
    all_var_micro = np.full(N, np.nan)
    all_var_auto = np.full(N, np.nan)
    all_w_micro = np.full(N, 0.5)
    all_w_auto = np.full(N, 0.5)
    all_confidence = np.full(N, np.nan)

    offset = 0
    for si, sess in enumerate(sessions):
        sm = df["SessionDate"] == sess
        T_s = int(sm.sum())
        df_sess = df[sm]

        # KL from tensor slot 2 (most recent bar with backward info)
        kl_s2 = hmm_results["kl_b_slots"][offset:offset + T_s, 2]
        alpha_q = hmm_results["alpha"][offset:offset + T_s, 1]

        if KP_PARAMS["session_reset"]:
            micro_branch.reset()
            auto_branch.reset()
            running_stats = {}

        preds = run_kl_predictor_session(
            df_sess, kl_s2, alpha_q,
            micro_branch, auto_branch,
            running_stats, blend_state,
        )

        # Store
        sl = slice(offset, offset + T_s)
        all_pred_micro[sl] = preds["pred_micro"]
        all_pred_auto[sl] = preds["pred_auto"]
        all_pred_blend[sl] = preds["pred_blend"]
        all_actual[sl] = preds["actual_kl"]
        all_err_micro[sl] = preds["err_micro"]
        all_err_auto[sl] = preds["err_auto"]
        all_err_blend[sl] = preds["err_blend"]
        all_var_micro[sl] = preds["var_micro"]
        all_var_auto[sl] = preds["var_auto"]
        all_w_micro[sl] = preds["w_micro"]
        all_w_auto[sl] = preds["w_auto"]
        all_confidence[sl] = preds["confidence"]

        offset += T_s

        if (si + 1) % 200 == 0:
            print(f"    KL predictor: {si+1}/{len(sessions)} sessions")

    # ── Attach to dataframe ──
    df = df.copy()
    df["KP_pred_micro"] = all_pred_micro
    df["KP_pred_auto"] = all_pred_auto
    df["KP_pred_blend"] = all_pred_blend
    df["KP_actual_kl"] = all_actual
    df["KP_err_blend"] = all_err_blend
    df["KP_confidence"] = all_confidence
    df["KP_w_micro"] = all_w_micro
    df["KP_w_auto"] = all_w_auto

    # ── Derived features for decision layer ──
    df["KP_pred_vs_actual"] = all_pred_blend - all_actual
    df["KP_pred_direction"] = np.sign(
        all_pred_blend - np.roll(all_actual, 1)
    )
    df["KP_err_ratio"] = np.where(
        all_err_auto > 1e-10,
        all_err_micro / np.maximum(all_err_auto, 1e-10),
        1.0,
    )
    # Prediction surprise: when actual KL deviates far from prediction
    # NOTE(review): pred_blend[t] forecasts KL at t+1, so this compares it with
    # KL at t (closer to a predicted-change magnitude). A surprise against the
    # forecast *for* bar t would be all_err_blend / |all_actual|. Left as-is
    # because downstream gates may be tuned on the current definition.
    df["KP_surprise"] = np.abs(all_pred_blend - all_actual) / np.maximum(
        np.abs(all_actual), 1e-8
    )
    # Is micro branch dominating? Signals microstructure is more predictive
    df["KP_micro_dominant"] = (all_w_micro > 0.6).astype(float)
    # Confidence trend (is model getting better or worse?)
    conf_diff = np.diff(all_confidence, prepend=np.nan)
    df["KP_confidence_trend"] = np.nan_to_num(conf_diff, nan=0.0)

    # ── Naive baseline comparison ──
    # Naive forecast for bar t is KL at t-1 (random walk).
    naive_pred = np.roll(all_actual, 1)
    naive_pred[0] = np.nan
    naive_err = np.abs(naive_pred - all_actual)
    # Model error for bar t must score the forecast made at t-1 against KL at t,
    # which is exactly err_blend[t] (computed within-session). Comparing
    # pred_blend[t] (a forecast of t+1) with KL at t would be misaligned with
    # the naive baseline above.
    model_err = all_err_blend

    # Skill: fraction of bars where model beats naive
    valid = ~np.isnan(model_err) & ~np.isnan(naive_err) & (naive_err > 0)
    if valid.sum() > 0:
        skill = (model_err[valid] < naive_err[valid]).mean()
        avg_improvement = (naive_err[valid] - model_err[valid]).mean()
        print(f"\n  KL Predictor diagnostics:")
        print(f"    Bars predicted: {valid.sum()}")
        print(f"    Skill vs naive: {skill:.3f} (>{0.5} means model adds value)")
        print(f"    Avg error improvement: {avg_improvement:.6f}")
        print(f"    Final branch weights: micro={all_w_micro[valid][-1]:.3f}, "
              f"auto={all_w_auto[valid][-1]:.3f}")
        print(f"    Micro branch weight history: "
              f"[{all_w_micro[valid][0]:.2f} -> {all_w_micro[valid][-1]:.2f}]")
    else:
        print(f"  KL Predictor: insufficient valid predictions for diagnostics")

    kp_results = {
        "pred_micro": all_pred_micro,
        "pred_auto": all_pred_auto,
        "pred_blend": all_pred_blend,
        "actual": all_actual,
        "err_micro": all_err_micro,
        "err_auto": all_err_auto,
        "err_blend": all_err_blend,
        "confidence": all_confidence,
        "w_micro": all_w_micro,
        "w_auto": all_w_auto,
        "naive_err": naive_err,
        "model_err": model_err,
        "micro_weights": micro_branch.w.copy(),
        "auto_weights": auto_branch.w.copy(),
    }

    return df, kp_results


# ══════════════════════════════════════════════════════════════
# DIAGNOSTICS (call from notebook)
# ══════════════════════════════════════════════════════════════

def diagnose_kl_predictor(df, kp_results):
    """
    Print detailed diagnostics. Call from notebook after run_kl_predictor.

    Only bars where both the model error (``err_blend``) and the naive
    random-walk error (``naive_err``) are defined are scored. Model and
    baseline MAE/skill are therefore computed on the same bars.

    Parameters
    ----------
    df : pd.DataFrame — output of run_kl_predictor; if it has a ``fwd_7``
        column, KP_surprise and KP_pred_direction are also read
    kp_results : dict — output of run_kl_predictor

    Returns
    -------
    dict — summary metrics for logging; empty if fewer than 50 scored bars.
    """
    err_blend = kp_results["err_blend"]
    naive_all = kp_results["naive_err"]
    valid = ~np.isnan(err_blend) & ~np.isnan(naive_all)
    n_valid = valid.sum()

    if n_valid < 50:
        print("Insufficient data for diagnostics")
        return {}

    model_err = err_blend[valid]
    naive_err = naive_all[valid]
    micro_err = kp_results["err_micro"][valid & ~np.isnan(kp_results["err_micro"])]
    auto_err = kp_results["err_auto"][valid & ~np.isnan(kp_results["err_auto"])]

    print(f"\n{'='*50}")
    print(f"  KL PREDICTOR DIAGNOSTICS")
    print(f"{'='*50}")
    print(f"  Valid predictions: {n_valid}")
    print(f"\n  Model vs Naive (lower = better):")
    print(f"    Model MAE:  {model_err.mean():.6f}")
    print(f"    Naive MAE:  {naive_err.mean():.6f}")
    print(f"    Skill:      {(model_err < naive_err).mean():.3f}")
    print(f"    Improvement: {((naive_err - model_err) / np.maximum(naive_err, 1e-8)).mean()*100:.1f}%")

    if len(micro_err) > 0 and len(auto_err) > 0:
        print(f"\n  Branch comparison:")
        print(f"    Micro MAE: {micro_err.mean():.6f}")
        print(f"    Auto MAE:  {auto_err.mean():.6f}")
        better = "MICRO" if micro_err.mean() < auto_err.mean() else "AUTO"
        print(f"    Better branch: {better}")

    print(f"\n  Micro branch weights (what drives KL from microstructure):")
    mw = kp_results["micro_weights"]
    feat_names = MICRO_FEATURES + ["intercept"]
    for i, name in enumerate(feat_names):
        if i < len(mw):
            print(f"    {name:20s}: {mw[i]:+.4f}")

    print(f"\n  Auto branch weights (what drives KL from its own history):")
    aw = kp_results["auto_weights"]
    auto_names = ["kl_t", "kl_t1", "kl_t2", "kl_slope", "kl_accel",
                  "alpha_q", "alpha_chg", "intercept"]
    for i, name in enumerate(auto_names):
        if i < len(aw):
            print(f"    {name:20s}: {aw[i]:+.4f}")

    # Correlation with forward returns (the money question)
    if "fwd_7" in df.columns:
        fwd = df["fwd_7"].values
        conf = kp_results["confidence"]
        pred = kp_results["pred_blend"]
        surp = df["KP_surprise"].values
        dirn = df["KP_pred_direction"].values

        # Every series must be defined so each correlation uses the same bars
        # (a NaN in any input would make np.corrcoef return NaN).
        v = (~np.isnan(fwd) & ~np.isnan(conf) & ~np.isnan(pred)
             & ~np.isnan(surp) & ~np.isnan(dirn))
        if v.sum() > 100:
            print(f"\n  Correlation with fwd_7 returns:")
            print(f"    KP_confidence:     {np.corrcoef(conf[v], fwd[v])[0,1]:+.4f}")
            print(f"    KP_pred_blend:     {np.corrcoef(pred[v], fwd[v])[0,1]:+.4f}")
            print(f"    KP_surprise:       {np.corrcoef(surp[v], fwd[v])[0,1]:+.4f}")
            print(f"    KP_pred_direction: {np.corrcoef(dirn[v], fwd[v])[0,1]:+.4f}")

    summary = {
        "n_valid": n_valid,
        "model_mae": model_err.mean(),
        "naive_mae": naive_err.mean(),
        "skill": (model_err < naive_err).mean(),
        "micro_mae": micro_err.mean() if len(micro_err) > 0 else np.nan,
        "auto_mae": auto_err.mean() if len(auto_err) > 0 else np.nan,
    }

    return summary
# region imports
from AlgorithmImports import *
# endregion

"""
═══════════════════════════════════════════════════════════════
FILE 2: ma_signal.py — MA Crossover Detection
═══════════════════════════════════════════════════════════════
Detects bullish and bearish 8/20 SMA crossovers per session.
"""
import numpy as np
import pandas as pd
from config import PARAMS

def compute_crossovers(df, fast=None, slow=None):
    """
    Detect bull and bear fast/slow SMA crossovers within each session.

    Moving averages are computed per session, so they never span a session
    boundary. A crossover is flagged on the bar where the sign of
    (fast MA > slow MA) flips between +1 and -1. The first bar on which both
    MAs exist is never a crossover. Sessions with fewer than ``slow + 1``
    bars are skipped and keep the default (no-signal) values.

    Parameters
    ----------
    df : pd.DataFrame — must have Close, SessionDate
    fast, slow : int — MA periods (default PARAMS["ma_fast"] / PARAMS["ma_slow"])

    Returns
    -------
    df : pd.DataFrame — copy of the input with MA_bull_x, MA_bear_x (bool)
        and MA_signal (+1 fast above slow, -1 otherwise once both MAs exist,
        0 before that or in skipped sessions)
    """
    if fast is None:
        fast = PARAMS["ma_fast"]
    if slow is None:
        slow = PARAMS["ma_slow"]

    df = df.copy()
    df["MA_bull_x"] = False
    df["MA_bear_x"] = False
    df["MA_signal"] = 0

    sessions = sorted(df["SessionDate"].unique())
    n_bull = n_bear = 0

    for sess in sessions:
        sm = df["SessionDate"] == sess
        sc = df.loc[sm, "Close"]
        # Need at least one bar after the slow MA becomes defined.
        if len(sc) < slow + 1:
            continue

        fma = sc.rolling(fast, min_periods=fast).mean()
        sma = sc.rolling(slow, min_periods=slow).mean()

        # +1 when fast > slow, -1 otherwise, 0 until both MAs exist.
        sig = pd.Series(0, index=sc.index)
        v = fma.notna() & sma.notna()
        sig[v] = np.where(fma[v] > sma[v], 1, -1)

        # Only +1 <-> -1 flips count; leaving the 0 warm-up state does not.
        ps = sig.shift(1)
        xo = (sig != ps) & (sig != 0) & (ps != 0)
        xo.iloc[0] = False  # ps is NaN on the first bar

        bull = xo & (sig == 1)
        bear = xo & (sig == -1)

        df.loc[sm, "MA_signal"] = sig
        df.loc[sm, "MA_bull_x"] = bull
        df.loc[sm, "MA_bear_x"] = bear

        n_bull += int(bull.sum())
        n_bear += int(bear.sum())

    # Guard the per-session rate against an empty frame (no sessions).
    per_session = (n_bull + n_bear) / len(sessions) if sessions else 0.0
    print(f"  MA crossovers ({fast}/{slow}): {n_bull} bull, {n_bear} bear "
          f"({per_session:.1f}/session)")

    return df
"""
NQ Short Squeeze Detection Algorithm
========================================
Entry:  8/20 SMA bull crossover on 5-min bars
Gate:   Union of 3 microstructure/KL conditions + stress filter
Exit:   Fixed 7-bar hold (35 minutes)

Walk-forward validated: t=3.62, 58.8% hit, PF=1.56, 79% positive quarters.
All gate thresholds are rolling medians from past crossovers only.
"""
from AlgorithmImports import *
import numpy as np
from hmmlearn import hmm


class NQShortSqueeze(QCAlgorithm):

    # ══════════════════════════════════════════════════════════
    # INITIALIZE
    # ══════════════════════════════════════════════════════════

    def Initialize(self):
        """
        QC entry point: configure the backtest, subscribe to the NQ future,
        set strategy parameters, initialize all rolling state, and fit the HMM.

        Statement order matters: the EWMA normalization defaults set here are
        overwritten by _fit_hmm_from_history (called last) when history is
        available.
        """
        self.SetStartDate(2023, 6, 1)
        self.SetEndDate(2026, 3, 4)
        self.SetCash(100000)

        # ── Instrument ──
        self.nq = self.AddFuture(
            Futures.Indices.NASDAQ100EMini,
            Resolution.Minute,
            dataNormalizationMode=DataNormalizationMode.Raw
        )
        self.nq.SetFilter(0, 90)
        # Active contract; chosen in OnData by highest open interest.
        self.contract = None

        # ── Parameters ──
        self.MA_FAST = 8
        self.MA_SLOW = 20
        self.HOLD_BARS = 7          # fixed hold in 5-min bars (35 min per module docstring)
        self.EWMA_SPAN = 390        # EWMA span for LR/GK normalization; also tail length used to seed live EWMA state
        self.STRESS_THRESH = 0.02   # NOTE(review): consumed outside this view — confirm meaning
        self.HMM_TRAIN_FRAC = 0.40  # fraction of history sessions used to fit the HMM
        self.WARMUP_BARS = 6        # 5-min bars skipped at the start of every session
        self.MIN_XO_HISTORY = 150   # presumably min past crossovers before rolling-median gates apply — verify
        self.RTH_START = time(9, 30)   # OnData ignores minute bars ending before this...
        self.RTH_END = time(15, 55)    # ...or after this
        self.NO_TRADE_AFTER = time(15, 0)  # NOTE(review): entry cutoff, usage not visible here
        self.N_CONTRACTS = 1

        # ── EWMA normalization state ──
        # Defaults only; _fit_hmm_from_history reseeds these from recent history.
        ew_alpha = 2.0 / (self.EWMA_SPAN + 1.0)
        self.ew_a = ew_alpha
        self.lr_mean = 0.0
        self.lr_var = 1.0
        self.gk_mean = 0.0
        self.gk_var = 1.0
        self.ew_n = 0

        # ── HMM state ──
        self.hmm_model = None
        # Forward posterior; reset to the model's startprob_ at each session start.
        self.alpha_current = np.array([0.5, 0.5])
        self.session_alpha_history = []    # [(alpha, B, active), ...] within session
        self.kl_history = []               # KL values for stress + tensor

        # ── Tensor slots (last 4 bars) ──
        self.tensor_gamma = []    # list of [P(vol), P(quiet)]
        self.tensor_alpha = []
        self.tensor_kl = []
        self.tensor_micro_gk = []
        self.tensor_micro_cpos = []

        # ── KL Predictor (RLS) ──
        n_auto = 8  # 7 features + intercept
        self.rls_w = np.zeros(n_auto)
        self.rls_P = np.eye(n_auto) * 100.0
        self.rls_lambda = 0.98
        self.kl_pred_prev = None
        self.kl_actual_prev = None
        self.kp_stats = {}

        # ── Stress EWMA ──
        # Reset to 0 at each session start.
        self.stress_fast = 0.0
        self.stress_slow = 0.0
        self.stress_af = 2.0 / 4.0    # fast span 3
        self.stress_as = 2.0 / 16.0   # slow span 15

        # ── MA state ──
        # close_buffer also supplies prev_close for the log return of each 5-min bar.
        self.close_buffer = []
        self.ma_signal_prev = 0

        # ── Volume EWMA (for micro_vol) ──
        self.vol_ewma = 1.0

        # ── Crossover history (for rolling medians) ──
        self.xo_features = []  # list of dicts

        # ── Position tracking ──
        self.open_trades = []  # list of {"entry_bar": int, "entry_price": float}
        self.bar_count = 0     # 5-min bars seen in the current session

        # ── 5-min bar accumulator ──
        self.bar_acc_open = None
        self.bar_acc_high = None
        self.bar_acc_low = None
        self.bar_acc_close = None
        self.bar_acc_volume = 0
        self.bar_acc_count = 0
        self.bar_acc_time = None
        self.current_session = None

        # ── Fit HMM from history ──
        self._fit_hmm_from_history()

        # NOTE(review): logged even when the fit returned early (no history),
        # in which case self.hmm_model is still None and bars are not processed.
        self.Debug("Initialized. HMM fitted. Ready to trade.")

    # ══════════════════════════════════════════════════════════
    # HMM FITTING
    # ══════════════════════════════════════════════════════════

    def _fit_hmm_from_history(self):
        """
        Pull minute history, build 5-min RTH bars and fit a 2-state
        full-covariance Gaussian HMM on EWMA-normalized (LR, GK) features.

        Side effects:
        - ``self.hmm_model`` is set, with states relabeled so state 0 has the
          higher mean GK_n (volatile) and state 1 the lower (quiet).
        - The live EWMA normalization state (lr_mean, lr_var, gk_mean,
          gk_var, ew_n) is seeded from the most recent EWMA_SPAN bars of the
          pulled history.
        Returns early (model stays None) when there is no history or no
        training session to fit on.
        """
        self.Debug("Fitting HMM from history...")
        hist = self.History(
            self.nq.Symbol, timedelta(days=800), Resolution.Minute
        )
        if len(hist) == 0:
            self.Debug("WARNING: No history available for HMM fitting")
            return

        # Build 5-min RTH bars
        # NOTE(review): assumes History on the canonical future symbol yields one
        # price per timestamp. If several contracts are returned, their rows are
        # mixed here — confirm against the QC history output.
        df = hist[["open", "high", "low", "close", "volume"]].copy()
        df.index = df.index.get_level_values("time")
        df.columns = ["O", "H", "L", "C", "V"]

        rth = df.between_time("09:30", "15:55")
        bars = rth.resample("5min").agg(
            {"O": "first", "H": "max", "L": "min", "C": "last", "V": "sum"}
        ).dropna()

        # Session indexing
        bars["sess"] = bars.index.date
        bars["bn"] = bars.groupby("sess").cumcount()

        # Log return vs the previous bar of the SAME session. It is computed
        # before the warmup bars are dropped, so the first kept bar is measured
        # against the last warmup bar, as the live path does via close_buffer.
        # A global shift after filtering would instead measure it against the
        # prior session's close, i.e. an overnight gap.
        bars["LR"] = np.log(bars["C"] / bars.groupby("sess")["C"].shift(1))
        bars = bars[bars["bn"] >= self.WARMUP_BARS].copy()

        bars["GK"] = np.sqrt(np.maximum(
            0.5 * (np.log(bars["H"] / bars["L"])) ** 2
            - (2 * np.log(2) - 1) * (np.log(bars["C"] / bars["O"])) ** 2, 0.0
        ))
        bars.dropna(subset=["LR", "GK"], inplace=True)

        # EWMA normalize
        sp = self.EWMA_SPAN
        for col in ("LR", "GK"):
            ew = bars[col].ewm(span=sp)
            bars[f"{col}_n"] = (bars[col] - ew.mean()) / ew.std()
        bars.dropna(subset=["LR_n", "GK_n"], inplace=True)

        # Train on the first HMM_TRAIN_FRAC of sessions
        sessions = sorted(bars["sess"].unique())
        n_train = int(len(sessions) * self.HMM_TRAIN_FRAC)
        train_sess = sessions[:n_train]
        train = bars.loc[bars["sess"].isin(train_sess)]
        if train.empty:
            self.Debug(f"WARNING: No training sessions for HMM "
                       f"({len(sessions)} sessions in history)")
            return

        X_train = train[["LR_n", "GK_n"]].values
        # Rows are time-ordered, so per-session sizes in date order match
        # the concatenation order of X_train (hmmlearn `lengths` contract).
        lengths = train.groupby("sess", sort=True).size().tolist()

        # Fit
        model = hmm.GaussianHMM(
            n_components=2, covariance_type="full",
            n_iter=50, random_state=42
        )
        model.fit(X_train, lengths=lengths)

        # Enforce state 0 = volatile, state 1 = quiet (by mean GK_n)
        if model.means_[0, 1] < model.means_[1, 1]:
            p = [1, 0]
            model.means_ = model.means_[p]
            model.covars_ = model.covars_[p]
            model.transmat_ = model.transmat_[np.ix_(p, p)]
            model.startprob_ = model.startprob_[p]

        self.hmm_model = model
        diag = np.diag(model.transmat_)
        self.Debug(f"HMM fitted on {n_train} sessions. Diag: [{diag[0]:.3f}, {diag[1]:.3f}]")

        # Seed the live EWMA state from the most recent bars of the pulled
        # history (all sessions, not just the training slice).
        tail = bars.tail(self.EWMA_SPAN)
        self.lr_mean = tail["LR"].mean()
        self.lr_var = tail["LR"].var()
        self.gk_mean = tail["GK"].mean()
        self.gk_var = tail["GK"].var()
        self.ew_n = len(tail)

    # ══════════════════════════════════════════════════════════
    # DATA HANDLING
    # ══════════════════════════════════════════════════════════

    def OnData(self, data):
        """
        Minute-bar handler: select a contract, filter to RTH, roll sessions,
        and aggregate minute bars into 5-bar chunks for _on_five_min_bar.

        - Contract: on every slice carrying a future chain, the contract with
          the highest OpenInterest (> 0) becomes ``self.contract``.
        - RTH: bars whose EndTime.time() is before RTH_START or after RTH_END
          are ignored.
        - Session: a change in EndTime.date() triggers _on_session_start.
        - Aggregation: every 5 received minute bars (counted, not clock-
          aligned) form one OHLCV bar, stamped with the 5th bar's EndTime.

        NOTE(review): the contract is re-selected on every slice, so a change in
        the top-OI contract mid-chunk would mix two contracts' prices in one
        5-min bar. Confirm roll handling.
        NOTE(review): count-based chunks can drift from the clock-aligned
        resample("5min") bars used in _fit_hmm_from_history if a minute bar is
        missing from the feed.
        """
        # Get the active contract
        for chain in data.FutureChains:
            contracts = sorted(
                [c for c in chain.Value if c.OpenInterest > 0],
                key=lambda c: c.OpenInterest, reverse=True
            )
            if contracts:
                self.contract = contracts[0]
                break

        if self.contract is None:
            return

        bar = data.Bars.get(self.contract.Symbol)
        if bar is None:
            return

        t = bar.EndTime.time()
        d = bar.EndTime.date()

        # RTH filter
        if t < self.RTH_START or t > self.RTH_END:
            return

        # Session management
        if d != self.current_session:
            self._on_session_start(d)
            self.current_session = d

        # Accumulate into 5-min bars
        if self.bar_acc_count == 0:
            # First minute of a chunk: open it.
            self.bar_acc_open = float(bar.Open)
            self.bar_acc_high = float(bar.High)
            self.bar_acc_low = float(bar.Low)
            self.bar_acc_time = bar.EndTime  # recorded, but not passed on below
        else:
            self.bar_acc_high = max(self.bar_acc_high, float(bar.High))
            self.bar_acc_low = min(self.bar_acc_low, float(bar.Low))

        self.bar_acc_close = float(bar.Close)
        self.bar_acc_volume += int(bar.Volume)
        self.bar_acc_count += 1

        if self.bar_acc_count >= 5:
            self._on_five_min_bar(
                self.bar_acc_open, self.bar_acc_high,
                self.bar_acc_low, self.bar_acc_close,
                self.bar_acc_volume, bar.EndTime
            )
            # Start a fresh chunk on the next minute.
            self.bar_acc_count = 0
            self.bar_acc_volume = 0

    # ══════════════════════════════════════════════════════════
    # SESSION MANAGEMENT
    # ══════════════════════════════════════════════════════════

    def _on_session_start(self, date):
        """Reset session-level state."""
        if self.hmm_model is not None:
            self.alpha_current = self.hmm_model.startprob_.copy()
        self.session_alpha_history = []
        self.close_buffer = []
        self.ma_signal_prev = 0
        self.bar_count = 0
        self.bar_acc_count = 0
        self.bar_acc_volume = 0

        # Reset stress EWMA
        self.stress_fast = 0.0
        self.stress_slow = 0.0

        # Close any open positions (no overnight)
        self._close_all("session_end")

    # ══════════════════════════════════════════════════════════
    # CORE: 5-MIN BAR PROCESSING
    # ══════════════════════════════════════════════════════════

    def _on_five_min_bar(self, o, h, l, c, v, end_time):
        """
        Process one 5-min RTH bar through the full live pipeline.

        Order (each step depends on state written by the previous ones):
          1. base features (log return vs previous close, Garman-Klass vol,
             close position within the bar range),
          2. online EWMA normalisation,
          3. one HMM forward step (emission skipped on "inactive" bars),
          4. one-step backward pass -> gamma / KL for the previous bar,
          5. push into the rolling 4-slot tensor lists,
          6. KL predictor (RLS) step,
          7. fast/slow stress EWMAs on KL,
          8. 8/20-style SMA crossover check,
          9. time-based exits,
          10. gated long entry on a bull crossover.

        Args:
            o, h, l, c: aggregated open/high/low/close of the 5-min bar.
            v: aggregated volume.
            end_time: end timestamp of the last bar in the aggregate.
        """
        if self.hmm_model is None:
            return

        self.bar_count += 1
        prev_close = self.close_buffer[-1] if self.close_buffer else c

        # Skip warmup bars at session start (their closes still seed the MA buffer)
        if self.bar_count <= self.WARMUP_BARS:
            self.close_buffer.append(c)
            return

        # ── 1. Base features ──
        lr = np.log(c / prev_close) if prev_close > 0 else 0.0
        gk = np.sqrt(max(
            0.5 * (np.log(h / l)) ** 2
            - (2 * np.log(2) - 1) * (np.log(c / o)) ** 2, 0.0
        )) if h > l and o > 0 else 0.0

        # NOTE(review): bar_range and micro_vol are not used anywhere below; only
        # the vol_ewma update has an effect.
        bar_range = (h - l) / c if c > 0 else 0.0
        self.vol_ewma = 0.96 * self.vol_ewma + 0.04 * max(v, 1)
        micro_vol = v / max(self.vol_ewma, 1)
        micro_cpos = (c - l) / (h - l) if h > l else 0.5

        # ── 2. EWMA normalize ──
        lr_n, gk_n = self._ewma_normalize(lr, gk)

        # ── 3. HMM forward step ──
        obs = np.array([lr_n, gk_n])
        B = self._emission_prob(obs)
        # Flat / near-zero-vol bars carry no evidence: propagate through the
        # transition matrix only, without multiplying by the emission.
        active = (h != l) and (gk > 0.00005)

        self.alpha_current = (self.alpha_current @ self.hmm_model.transmat_) * (B if active else 1.0)
        s = self.alpha_current.sum()
        if s > 0:
            self.alpha_current /= s

        # History entry = (filtered alpha, emission vector, active flag); read by
        # _progressive_backward.
        self.session_alpha_history.append((self.alpha_current.copy(), B.copy(), active))

        # ── 4. Progressive backward → gamma + KL ──
        kl_val = self._progressive_backward()

        # ── 5. Update tensor slots ──
        # NOTE(review): _progressive_backward sets self._last_gamma on every
        # path, so the hasattr fallback below is never taken.
        self.tensor_gamma.append(self._last_gamma.copy() if hasattr(self, '_last_gamma') else self.alpha_current.copy())
        self.tensor_alpha.append(self.alpha_current.copy())
        self.tensor_kl.append(kl_val)
        self.tensor_micro_gk.append(gk)
        self.tensor_micro_cpos.append(micro_cpos)

        # Keep last 4 slots.
        # NOTE(review): _on_session_start does not clear these lists, so at the
        # start of a session the slots still hold the previous session's last
        # bars (the offline tensor builder pads pre-session slots instead).
        for lst in [self.tensor_gamma, self.tensor_alpha, self.tensor_kl,
                     self.tensor_micro_gk, self.tensor_micro_cpos]:
            if len(lst) > 4:
                lst.pop(0)

        # ── 6. KL predictor ──
        kp_surprise, kp_pred_vs_actual = self._kl_predictor_step(kl_val)

        # ── 7. Stress ──
        # Stress = fast EWMA of KL minus slow EWMA of KL.
        kl_for_stress = kl_val if not np.isnan(kl_val) else 0.0
        self.stress_fast = self.stress_af * kl_for_stress + (1 - self.stress_af) * self.stress_fast
        self.stress_slow = self.stress_as * kl_for_stress + (1 - self.stress_as) * self.stress_slow
        stress = self.stress_fast - self.stress_slow

        # ── 8. MA crossover ──
        self.close_buffer.append(c)
        crossover = self._check_crossover()

        # ── 9. Manage existing positions ──
        self._manage_positions(end_time, c)

        # ── 10. Entry logic ──
        # Crossovers at/after NO_TRADE_AFTER are neither traded nor recorded in
        # xo_features.
        if crossover and end_time.time() < self.NO_TRADE_AFTER:
            if len(self.tensor_kl) >= 4 and len(self.xo_features) >= self.MIN_XO_HISTORY:
                features = self._extract_gate_features(
                    stress, kp_surprise, kp_pred_vs_actual
                )
                if features is not None:
                    # Gates are evaluated against past crossovers only; the
                    # current one is appended afterwards (no look-ahead).
                    gate_fires = self._evaluate_gates(features)
                    self.xo_features.append(features)

                    if gate_fires and len(self.open_trades) == 0:
                        self._enter_long(c, end_time)
            elif crossover:
                # NOTE(review): `crossover` is always True here; this branch is
                # simply the "not enough slots/history yet" case.
                # Still accumulating history — store features for median building
                features = self._extract_gate_features(
                    stress, kp_surprise, kp_pred_vs_actual
                )
                if features is not None:
                    self.xo_features.append(features)

    # ══════════════════════════════════════════════════════════
    # HMM HELPERS
    # ══════════════════════════════════════════════════════════

    def _emission_prob(self, obs):
        """Compute emission probability for each state."""
        B = np.zeros(2)
        for s in range(2):
            diff = obs - self.hmm_model.means_[s]
            cov = self.hmm_model.covars_[s]
            det = np.linalg.det(cov)
            if det <= 0:
                B[s] = 1e-10
                continue
            inv_cov = np.linalg.inv(cov)
            exponent = -0.5 * diff @ inv_cov @ diff
            B[s] = np.exp(exponent) / np.sqrt((2 * np.pi) ** 2 * max(det, 1e-30))
        return np.maximum(B, 1e-10)

    def _progressive_backward(self):
        """Run backward on last 3-4 bars, return KL at slot 2."""
        hist = self.session_alpha_history
        if len(hist) < 3:
            self._last_gamma = self.alpha_current.copy()
            return 0.0

        A = self.hmm_model.transmat_
        T = min(4, len(hist))
        start = len(hist) - T

        # Backward pass
        beta = np.ones(2)
        betas = [beta.copy()]
        for tb in range(T - 1, 0, -1):
            idx = start + tb
            _, B_next, act_next = hist[idx]
            bn = B_next if act_next else np.ones(2)
            beta = (A * bn * beta).sum(axis=1)
            s = beta.sum()
            if s > 0:
                beta /= s
            betas.insert(0, beta.copy())

        # Compute gamma at slot 2 (t-1, most recent with backward info)
        slot2_idx = T - 2  # second to last
        if slot2_idx >= 0 and slot2_idx < len(betas):
            abs_idx = start + slot2_idx
            alpha_s2 = hist[abs_idx][0]
            fb = alpha_s2 * betas[slot2_idx]
            s = fb.sum()
            if s > 0:
                fb /= s
            self._last_gamma = fb.copy()

            # KL(alpha || gamma)
            eps = 1e-10
            p = np.clip(alpha_s2, eps, 1.0)
            q = np.clip(fb, eps, 1.0)
            kl = np.sum(p * np.log(p / q))
            return kl

        self._last_gamma = self.alpha_current.copy()
        return 0.0

    def _ewma_normalize(self, lr, gk):
        """Online EWMA normalization."""
        a = self.ew_a
        self.ew_n += 1
        if self.ew_n < 10:
            self.lr_mean = (self.lr_mean * (self.ew_n - 1) + lr) / self.ew_n
            self.gk_mean = (self.gk_mean * (self.ew_n - 1) + gk) / self.ew_n
            return 0.0, 0.0

        self.lr_mean = a * lr + (1 - a) * self.lr_mean
        self.gk_mean = a * gk + (1 - a) * self.gk_mean
        self.lr_var = a * (lr - self.lr_mean) ** 2 + (1 - a) * self.lr_var
        self.gk_var = a * (gk - self.gk_mean) ** 2 + (1 - a) * self.gk_var

        lr_std = max(np.sqrt(self.lr_var), 1e-8)
        gk_std = max(np.sqrt(self.gk_var), 1e-8)
        return (lr - self.lr_mean) / lr_std, (gk - self.gk_mean) / gk_std

    # ══════════════════════════════════════════════════════════
    # KL PREDICTOR (simplified RLS, autoregressive branch only)
    # ══════════════════════════════════════════════════════════

    def _kl_predictor_step(self, kl_actual):
        """One step of the KL predictor. Returns (surprise, pred_vs_actual)."""
        kl_hist = self.tensor_kl
        if len(kl_hist) < 3:
            self.kl_pred_prev = None
            self.kl_actual_prev = kl_actual
            return 0.0, 0.0

        # Build auto features: [kl_t, kl_t1, kl_t2, slope, accel, alpha_q, alpha_chg, intercept]
        kl = [kl_hist[-3] if len(kl_hist) >= 3 else 0,
              kl_hist[-2] if len(kl_hist) >= 2 else 0,
              kl_hist[-1]]
        kl = [0.0 if np.isnan(v) else v for v in kl]

        slope = (kl[2] - kl[0]) / 2.0
        accel = (kl[2] - kl[1]) - (kl[1] - kl[0])
        aq = self.alpha_current[1]  # P(quiet)
        aq_prev = self.tensor_alpha[-2][1] if len(self.tensor_alpha) >= 2 else 0.5
        achg = aq - aq_prev

        x = np.array([kl[2], kl[1], kl[0], slope, accel, aq, achg, 1.0])

        # Predict
        pred = float(x @ self.rls_w)

        # Update with previous prediction's error
        if self.kl_pred_prev is not None and self.kl_actual_prev is not None:
            target = kl_actual
            # Build features from previous bar
            kl_p = [kl_hist[-4] if len(kl_hist) >= 4 else 0,
                    kl_hist[-3] if len(kl_hist) >= 3 else 0,
                    kl_hist[-2] if len(kl_hist) >= 2 else 0]
            kl_p = [0.0 if np.isnan(v) else v for v in kl_p]
            sl_p = (kl_p[2] - kl_p[0]) / 2.0
            ac_p = (kl_p[2] - kl_p[1]) - (kl_p[1] - kl_p[0])
            aq_pp = self.tensor_alpha[-3][1] if len(self.tensor_alpha) >= 3 else 0.5
            x_prev = np.array([kl_p[2], kl_p[1], kl_p[0], sl_p, ac_p,
                               aq_prev, aq_prev - aq_pp, 1.0])

            # RLS update
            Px = self.rls_P @ x_prev
            denom = self.rls_lambda + float(x_prev @ Px)
            if abs(denom) > 1e-12:
                K = Px / denom
                err = target - float(x_prev @ self.rls_w)
                self.rls_w = self.rls_w + K * err
                self.rls_P = (self.rls_P - np.outer(K, x_prev @ self.rls_P)) / self.rls_lambda
                self.rls_P = (self.rls_P + self.rls_P.T) / 2.0

        # Compute outputs
        actual = kl_actual if not np.isnan(kl_actual) else 0.0
        surprise = abs(pred - actual) / max(abs(actual), 1e-8)
        pred_vs_actual = pred - actual

        self.kl_pred_prev = pred
        self.kl_actual_prev = kl_actual

        return surprise, pred_vs_actual

    # ══════════════════════════════════════════════════════════
    # MA CROSSOVER
    # ══════════════════════════════════════════════════════════

    def _check_crossover(self):
        """Check for 8/20 SMA bull crossover. Returns True on cross."""
        buf = self.close_buffer
        if len(buf) < self.MA_SLOW + 1:
            return False

        fast_now = np.mean(buf[-self.MA_FAST:])
        slow_now = np.mean(buf[-self.MA_SLOW:])
        fast_prev = np.mean(buf[-self.MA_FAST - 1:-1])
        slow_prev = np.mean(buf[-self.MA_SLOW - 1:-1])

        sig_now = 1 if fast_now > slow_now else -1
        sig_prev = 1 if fast_prev > slow_prev else -1

        is_bull_cross = (sig_now == 1 and sig_prev == -1)
        return is_bull_cross

    # ══════════════════════════════════════════════════════════
    # GATE FEATURES + EVALUATION
    # ══════════════════════════════════════════════════════════

    def _extract_gate_features(self, stress, kp_surprise, kp_pred_vs_actual):
        """Extract the 6 features needed for the 3 thesis gates."""
        if len(self.tensor_micro_gk) < 4 or len(self.tensor_kl) < 3:
            return None

        gk_arr = np.array(self.tensor_micro_gk[-4:])
        cpos_arr = np.array(self.tensor_micro_cpos[-4:])
        kl_arr = np.array(self.tensor_kl[-3:])
        kl_arr = np.nan_to_num(kl_arr, nan=0.0)

        # Slopes: simple (last - first) / n
        cpos_slope = (cpos_arr[-1] - cpos_arr[0]) / 3.0 if len(cpos_arr) == 4 else 0.0
        kl_slope = (kl_arr[-1] - kl_arr[0]) / 2.0 if len(kl_arr) == 3 else 0.0

        return {
            "micro_gk_std": float(np.std(gk_arr)),
            "micro_cpos_std": float(np.std(cpos_arr)),
            "micro_cpos_slope": float(cpos_slope),
            "kl_slope_3s": float(kl_slope),
            "kp_surprise": float(kp_surprise),
            "kp_pred_vs_actual": float(kp_pred_vs_actual),
            "stress": float(stress),
        }

    def _evaluate_gates(self, features):
        """
        Union of 3 thesis gates + stress filter.
        All thresholds are rolling medians from past crossovers.
        """
        if len(self.xo_features) < self.MIN_XO_HISTORY:
            return False

        # Compute rolling medians from past crossovers
        past = self.xo_features  # all prior crossovers
        med_gk_std = np.median([f["micro_gk_std"] for f in past])
        med_cpos_std = np.median([f["micro_cpos_std"] for f in past])
        med_cpos_slope = np.median([f["micro_cpos_slope"] for f in past])
        med_kl_slope = np.median([f["kl_slope_3s"] for f in past])
        med_surprise = np.median([f["kp_surprise"] for f in past])
        med_pva = np.median([f["kp_pred_vs_actual"] for f in past])

        # Gate A: high GK variability + stable close positioning
        gate_a = (features["micro_gk_std"] > med_gk_std and
                  features["micro_cpos_std"] < med_cpos_std)

        # Gate B: KL predictor surprised + KL slope declining
        gate_b = (features["kp_surprise"] > med_surprise and
                  features["kl_slope_3s"] < med_kl_slope)

        # Gate C: close trending up + regime more stable than predicted
        gate_c = (features["micro_cpos_slope"] > med_cpos_slope and
                  features["kp_pred_vs_actual"] > med_pva)

        # Union + stress
        union = gate_a or gate_b or gate_c
        stress_ok = features["stress"] < self.STRESS_THRESH

        fires = union and stress_ok

        if fires:
            gates_str = []
            if gate_a: gates_str.append("A")
            if gate_b: gates_str.append("B")
            if gate_c: gates_str.append("C")
            self.Debug(f"GATE FIRED: {'+'.join(gates_str)} | "
                       f"gk_std={features['micro_gk_std']:.4f} "
                       f"stress={features['stress']:.4f}")

        return fires

    # ══════════════════════════════════════════════════════════
    # POSITION MANAGEMENT
    # ══════════════════════════════════════════════════════════

    def _enter_long(self, price, time):
        """Enter long position."""
        if self.contract is None:
            return

        self.MarketOrder(self.contract.Symbol, self.N_CONTRACTS)
        self.open_trades.append({
            "entry_bar": self.bar_count,
            "entry_price": price,
            "entry_time": time,
        })
        self.Debug(f"ENTRY: Long {self.N_CONTRACTS} @ {price:.2f} at {time}")

    def _manage_positions(self, current_time, current_price):
        """Exit positions after HOLD_BARS bars."""
        if not self.open_trades:
            return

        to_close = []
        for i, trade in enumerate(self.open_trades):
            bars_held = self.bar_count - trade["entry_bar"]
            if bars_held >= self.HOLD_BARS:
                to_close.append(i)

        if to_close:
            self._close_all("fixed_hold")

    def _close_all(self, reason):
        """Close all open positions."""
        if not self.Portfolio.Invested:
            self.open_trades = []
            return

        for trade in self.open_trades:
            pnl = 0
            if self.contract:
                pnl = (self.Securities[self.contract.Symbol].Price
                       - trade["entry_price"])

        if self.contract and self.Portfolio[self.contract.Symbol].Invested:
            self.Liquidate(self.contract.Symbol)
            if self.open_trades:
                self.Debug(f"EXIT: {reason} | trades closed: {len(self.open_trades)}")

        self.open_trades = []
# region imports
from AlgorithmImports import *
# endregion

# Your New Python File
"""
═══════════════════════════════════════════════════════════════
FILE 4: sizing.py — Conviction Score + Position Sizing
═══════════════════════════════════════════════════════════════
"""
import numpy as np
from config import PARAMS

def compute_conviction(row, weights=None):
    """
    Compute conviction score from tensor features at a single bar.

    Parameters
    ----------
    row : pd.Series — one bar from the crossover dataframe (anything with
        ``.get`` works). Reads ``stress_B``, ``T_regime_age``,
        ``T_alpha_conv`` and ``T_kl_slope_3s``; missing keys default to 0.
    weights : dict or None — component weights keyed by "stress", "age",
        "alpha", "kl_slope"; defaults to PARAMS["conviction_weights"].
        Components without a weight contribute 0.

    Returns
    -------
    conviction : float in [0, 1]
    components : dict of individual scores, each in [0, 1]

    Notes
    -----
    A component whose input is NaN scores 0 ("no evidence"). Previously a
    single NaN made the whole conviction NaN, which broke the documented
    [0, 1] range and silently fell through every sizing tier.
    """
    if weights is None:
        weights = PARAMS["conviction_weights"]

    stress_val = row.get("stress_B", 0.0)
    age_val = row.get("T_regime_age", 0)
    alpha_conv = row.get("T_alpha_conv", 0.0)
    kl_slope = row.get("T_kl_slope_3s", 0.0)

    raw = {
        # Lower stress -> higher score; 0 at stress >= 0.02.
        "stress": np.clip((0.02 - stress_val) / 0.02, 0, 1),
        # Saturates after 20 bars in the same regime.
        "age": np.clip(age_val / 20.0, 0, 1),
        # alpha_conv = |P(quiet) - 0.5| in [0, 0.5] -> rescaled to [0, 1].
        "alpha": np.clip(alpha_conv * 2, 0, 1),
        # Positive KL slope, saturating at 0.002.
        "kl_slope": np.clip(kl_slope * 500, 0, 1),
    }
    components = {k: (0.0 if np.isnan(v) else v) for k, v in raw.items()}

    conviction = sum(weights.get(k, 0) * v for k, v in components.items())
    return np.clip(conviction, 0, 1), components


def size_position(conviction, tiers=None, default=1):
    """
    Map conviction to contract count.

    Parameters
    ----------
    conviction : float in [0, 1]
    tiers : list of dicts with "threshold" / "contracts" keys; defaults to
        PARAMS["contract_tiers"]. The list may be in any order: the tier with
        the highest threshold that ``conviction`` meets wins. Previously the
        first matching tier in list order won, so an ascending list always
        returned the lowest tier.
    default : int — contracts returned when no tier is met (including NaN
        conviction). Defaults to 1, as before.

    Returns
    -------
    n_contracts : int
    """
    if tiers is None:
        tiers = PARAMS["contract_tiers"]

    eligible = [tier for tier in tiers if conviction >= tier["threshold"]]
    if not eligible:
        return default
    # max() keeps the first of equal thresholds, matching descending-list order.
    return max(eligible, key=lambda tier: tier["threshold"])["contracts"]


def compute_conviction_column(xo, weights=None, tiers=None):
    """
    Add conviction and contract columns to crossover dataframe.

    Parameters
    ----------
    xo : pd.DataFrame — gated crossover bars
    weights : dict or None — passed to compute_conviction (None -> config default)
    tiers : list or None — passed to size_position (None -> config default)

    Returns
    -------
    xo : pd.DataFrame (a copy) with conviction, n_contracts columns

    Rows are iterated positionally. The previous ``xo.loc[idx]`` lookup
    returned a whole DataFrame for duplicated index labels (e.g. after
    concatenating sessions), which broke the per-row scoring.
    """
    xo = xo.copy()
    convictions = []
    contracts = []

    for _label, row in xo.iterrows():
        conv, _components = compute_conviction(row, weights)
        convictions.append(conv)
        contracts.append(size_position(conv, tiers))

    xo["conviction"] = convictions
    xo["n_contracts"] = contracts

    return xo
    
# region imports
from AlgorithmImports import *
# endregion

# Your New Python File
"""
tensor.py — Tensor Builder + Feature Extraction
==================================================
Builds the [T, 5, 11] tensor per session from HMM outputs.
Extracts all features: continuous, argmax, patterns, disagreement,
micro, interactions.

Uses Mode B (progressive backward) gamma/KL from hmm_model.py.
The tensor is the SINGLE SOURCE OF TRUTH for all decision features.

Tensor layout per slot [11 columns]:
  [0:2]  gamma P(vol), P(quiet)
  [2:4]  alpha P(vol), P(quiet)
  [4:6]  proj  P(vol), P(quiet)     — alpha @ transition matrix
  [6]    KL(alpha || gamma)
  [7:11] micro features              — frozen at bar creation

Slots: [t-3, t-2, t-1, t_current, t+1_projection]
"""
import numpy as np
from config import PARAMS, MICRO_FEATURES


# ══════════════════════════════════════════════════════════════
# TENSOR BUILDER
# ══════════════════════════════════════════════════════════════

def build_tensor_session(alpha, kl_slots, gamma_slots, micro, active, transmat):
    """
    Assemble the per-bar [T, 5, 11] tensor for a single session.

    Slots 0-3 hold bars t-3 .. t. For each of these slots, columns are:
    gamma (progressive-backward value for that slot), alpha at that bar,
    alpha @ transmat, KL for that slot, and the bar's 4 micro features.
    Slots that fall before the session start are filled with 0.5 pairs and
    zero KL/micro. Slot 4 is the one-step projection: alpha_t @ transmat
    (used as both gamma and alpha), its further projection, and zero
    KL/micro.

    Parameters
    ----------
    alpha : [T, 2] forward posteriors
    kl_slots : [T, 4] KL per slot
    gamma_slots : [T, 4, 2] gamma per slot
    micro : [T, 4] micro features per bar
    active : [T] activity mask (not used when building the tensor)
    transmat : [2, 2] HMM transition matrix

    Returns
    -------
    tensor : [T, PARAMS["tensor_slots"], PARAMS["tensor_cols"]]
    """
    n_bars = len(alpha)
    n_slots = PARAMS["tensor_slots"]
    n_cols = PARAMS["tensor_cols"]
    trans = transmat

    tensor = np.zeros((n_bars, n_slots, n_cols))
    uniform = [0.5, 0.5]

    for t in range(n_bars):
        for slot in range(4):
            src = t - 3 + slot          # absolute bar feeding this slot
            cell = tensor[t, slot]      # view: writes land in `tensor`
            if src < 0:
                cell[0:2] = uniform
                cell[2:4] = uniform
                cell[4:6] = uniform
                cell[6] = 0.0
                cell[7:11] = 0.0
            else:
                cell[0:2] = gamma_slots[t, slot]
                cell[2:4] = alpha[src]
                cell[4:6] = alpha[src] @ trans
                cell[6] = kl_slots[t, slot]
                cell[7:11] = micro[src]

        step1 = alpha[t] @ trans
        projection = tensor[t, 4]
        projection[0:2] = step1
        projection[2:4] = step1
        projection[4:6] = step1 @ trans
        projection[6] = 0.0
        projection[7:11] = 0.0

    return tensor


def build_tensors_all_sessions(df, hmm_results, sessions, model):
    """
    Build tensors for all sessions.

    Each session's rows are selected from ``df`` with a boolean mask, but the
    HMM outputs and the result are sliced with a running positional
    ``offset``. This is only aligned when (a) ``df`` rows are grouped by
    session in the same order as ``sessions`` and (b) the arrays in
    ``hmm_results`` are the per-session results concatenated in that same
    order. NOTE(review): presumably guaranteed by the caller /
    run_hmm_all_sessions — confirm. Otherwise tensors land on the wrong rows.

    Parameters
    ----------
    df : pd.DataFrame — with micro feature columns, "SessionDate" and "Active"
    hmm_results : dict — from hmm_model.run_hmm_all_sessions(); keys read here:
        "alpha", "kl_b_slots", "gamma_b_slots"
    sessions : list of session dates
    model : fitted GaussianHMM (only ``transmat_`` is used)

    Returns
    -------
    all_tensors : [N, 5, 11] — tensor for every bar (rows of sessions not in
        ``sessions`` stay zero)
    """
    N = len(df)
    all_tensors = np.zeros((N, PARAMS["tensor_slots"], PARAMS["tensor_cols"]))

    offset = 0
    for sess in sessions:
        sm = df["SessionDate"] == sess
        T_s = sm.sum()

        # Positional slices: assume session-concatenated ordering (see docstring).
        alpha_s = hmm_results["alpha"][offset:offset + T_s]
        kl_s = hmm_results["kl_b_slots"][offset:offset + T_s]
        gamma_s = hmm_results["gamma_b_slots"][offset:offset + T_s]
        micro_s = df.loc[sm, MICRO_FEATURES].values
        active_s = df.loc[sm, "Active"].values

        tens = build_tensor_session(
            alpha_s, kl_s, gamma_s, micro_s, active_s, model.transmat_
        )
        all_tensors[offset:offset + T_s] = tens
        offset += T_s

    return all_tensors


# ══════════════════════════════════════════════════════════════
# FEATURE EXTRACTION
# ══════════════════════════════════════════════════════════════

def extract_features(tensor_history, alpha_all, active):
    """
    Extract ALL features from tensor. Returns dict of [T] arrays.
    Every feature name is prefixed with T_ when attached to dataframe.

    Categories: KL, alpha_continuous, alpha_argmax, gamma,
    disagreement, projection, micro, regime_persistence, interactions.

    Parameters
    ----------
    tensor_history : [T, 5, 11] — as produced by build_tensor_session /
        build_tensors_all_sessions
    alpha_all : [T, 2] — forward posteriors; column 1 is read as P(quiet)
        (assumed to match the tensor's state order — TODO confirm)
    active : not used in this function

    Notes
    -----
    Argmax features threshold P(quiet) at 0.5, so 1 = quiet and 0 = vol.
    The 4-bit patterns weight slot 0 (oldest) as the most significant bit.
    """
    T = len(tensor_history)
    f = {}

    # ── Shortcuts ──
    gamma_q  = tensor_history[:, :4, 1]     # [T, 4] gamma P(quiet)
    alpha_q  = tensor_history[:, :4, 3]     # [T, 4] alpha P(quiet)
    # Column 3 is the alpha P(quiet) column. For slot 4 that column holds
    # alpha_t @ A, i.e. the one-step projection. Only slot 4 of proj_q/proj_am
    # is used below.
    proj_q   = tensor_history[:, :, 3]      # [T, 5] proj P(quiet) — includes slot 4
    kl_slots = tensor_history[:, :4, 6]     # [T, 4] KL
    micro_sl = tensor_history[:, :4, 7:11]  # [T, 4, 4] micro

    alpha_am = (alpha_q > 0.5).astype(int)
    gamma_am = (gamma_q > 0.5).astype(int)
    proj_am  = (proj_q > 0.5).astype(int)

    # ════════════════════════════════════════
    # A. KL FEATURES
    # ════════════════════════════════════════

    # Slots 0-2 only (t-3 .. t-1).
    f["kl_slope_3s"] = _slope_3(kl_slots[:, :3])
    f["kl_slope_4s"] = _slope_4(kl_slots)
    f["kl_mean_3s"]  = np.nanmean(kl_slots[:, :3], axis=1)
    f["kl_max_3s"]   = np.nanmax(kl_slots[:, :3], axis=1)
    f["kl_min_3s"]   = np.nanmin(kl_slots[:, :3], axis=1)
    f["kl_range_3s"] = f["kl_max_3s"] - f["kl_min_3s"]
    f["kl_slot2"]    = kl_slots[:, 2]
    f["kl_accel"]    = (kl_slots[:, 2] - kl_slots[:, 1]) - (kl_slots[:, 1] - kl_slots[:, 0])

    # ════════════════════════════════════════
    # B. ALPHA CONTINUOUS
    # ════════════════════════════════════════

    f["alpha_std"]   = alpha_q.std(axis=1)
    f["alpha_mean"]  = alpha_q.mean(axis=1)
    f["alpha_conv"]  = np.abs(alpha_all[:, 1] - 0.5)
    f["alpha_slope"] = _slope_4(alpha_q)
    f["alpha_range"] = alpha_q.max(axis=1) - alpha_q.min(axis=1)

    # ════════════════════════════════════════
    # C. ALPHA ARGMAX
    # ════════════════════════════════════════

    for s in range(4):
        f[f"am_alpha_s{s}"] = alpha_am[:, s].astype(float)

    f["am_alpha_sum"]            = alpha_am.sum(axis=1).astype(float)
    f["am_alpha_unan_quiet"]     = (f["am_alpha_sum"] == 4).astype(float)
    f["am_alpha_unan_vol"]       = (f["am_alpha_sum"] == 0).astype(float)
    f["am_alpha_unanimous"]      = ((f["am_alpha_sum"] == 4) | (f["am_alpha_sum"] == 0)).astype(float)
    f["am_alpha_majority_quiet"] = (f["am_alpha_sum"] >= 3).astype(float)
    f["am_alpha_majority_vol"]   = (f["am_alpha_sum"] <= 1).astype(float)

    f["am_alpha_flips"] = _count_flips(alpha_am)

    f["am_alpha_last_flip"]  = _last_flip_pos(alpha_am)
    f["am_alpha_first_flip"] = _first_flip_pos(alpha_am)

    f["am_alpha_pattern"] = (alpha_am[:, 0]*8 + alpha_am[:, 1]*4 +
                              alpha_am[:, 2]*2 + alpha_am[:, 3]*1).astype(float)

    # 15 = 1111 (all quiet), 0 = 0000 (all vol); 7/3/1 = vol then quiet run
    # ending at the current slot; 8/12/14 = quiet then vol run.
    f["am_pattern_stable_quiet"]  = (f["am_alpha_pattern"] == 15).astype(float)
    f["am_pattern_stable_vol"]    = (f["am_alpha_pattern"] == 0).astype(float)
    f["am_pattern_entering_quiet"] = np.isin(f["am_alpha_pattern"], [7, 3, 1]).astype(float)
    f["am_pattern_entering_vol"]   = np.isin(f["am_alpha_pattern"], [8, 12, 14]).astype(float)
    f["am_pattern_mixed"]          = np.isin(f["am_alpha_pattern"], [5, 6, 9, 10]).astype(float)

    f["am_alpha_current_quiet"]  = alpha_am[:, 3].astype(float)
    f["am_alpha_transition_now"] = (alpha_am[:, 3] != alpha_am[:, 2]).astype(float)
    f["am_alpha_recent_trans"]   = (f["am_alpha_flips"] > 0).astype(float)

    # ════════════════════════════════════════
    # D. GAMMA CONTINUOUS + ARGMAX
    # ════════════════════════════════════════

    f["gamma_std"]   = gamma_q.std(axis=1)
    f["gamma_mean"]  = gamma_q.mean(axis=1)
    f["gamma_range"] = gamma_q.max(axis=1) - gamma_q.min(axis=1)
    f["gamma_slope"] = _slope_4(gamma_q)

    for s in range(4):
        f[f"am_gamma_s{s}"] = gamma_am[:, s].astype(float)

    f["am_gamma_sum"]        = gamma_am.sum(axis=1).astype(float)
    f["am_gamma_unan_quiet"] = (f["am_gamma_sum"] == 4).astype(float)
    f["am_gamma_unan_vol"]   = (f["am_gamma_sum"] == 0).astype(float)
    f["am_gamma_flips"]      = _count_flips(gamma_am)
    f["am_gamma_pattern"]    = (gamma_am[:, 0]*8 + gamma_am[:, 1]*4 +
                                 gamma_am[:, 2]*2 + gamma_am[:, 3]*1).astype(float)

    # ════════════════════════════════════════
    # E. DISAGREEMENT (alpha vs gamma per slot)
    # ════════════════════════════════════════

    disagree = (alpha_am != gamma_am).astype(float)
    for s in range(4):
        f[f"disagree_s{s}"] = disagree[:, s]

    f["disagree_count"]   = disagree.sum(axis=1)
    f["disagree_any"]     = (f["disagree_count"] > 0).astype(float)
    f["disagree_current"] = disagree[:, 3]
    # > 0: more disagreement in the two recent slots than the two older ones.
    f["disagree_trend"]   = (disagree[:, 2] + disagree[:, 3]) - (disagree[:, 0] + disagree[:, 1])
    f["disagree_pattern"] = (disagree[:, 0]*8 + disagree[:, 1]*4 +
                              disagree[:, 2]*2 + disagree[:, 3]*1)

    # +1: alpha says quiet while gamma says vol; -1: the reverse.
    for s in range(4):
        f[f"ag_direction_s{s}"] = (alpha_am[:, s] - gamma_am[:, s]).astype(float)
    f["ag_direction_net"] = sum(f[f"ag_direction_s{s}"] for s in range(4))

    # ════════════════════════════════════════
    # F. PROJECTION
    # ════════════════════════════════════════

    f["proj_quiet_val"]      = tensor_history[:, 4, 3]
    f["proj_conv"]           = np.abs(tensor_history[:, 4, 3] - 0.5)
    f["am_proj_quiet"]       = proj_am[:, 4].astype(float)
    f["proj_agrees_alpha"]   = (proj_am[:, 4] == alpha_am[:, 3]).astype(float)
    f["proj_agrees_gamma"]   = (proj_am[:, 4] == gamma_am[:, 3]).astype(float)
    # |projected P(quiet) - current alpha P(quiet)|
    f["proj_error"]          = np.abs(tensor_history[:, 4, 3] - tensor_history[:, 3, 3])

    slot23_trend = alpha_am[:, 3] - alpha_am[:, 2]
    proj_trend = proj_am[:, 4].astype(int) - alpha_am[:, 3]
    # No recent change: "continues" means the projection keeps the current
    # state. Otherwise it means the projection moves in the same direction as
    # the slot 2 -> 3 change (a projection that stays put has sign 0 and so
    # does not count).
    f["proj_continues_trend"] = np.where(
        slot23_trend == 0,
        (proj_am[:, 4] == alpha_am[:, 3]).astype(float),
        (np.sign(proj_trend) == np.sign(slot23_trend)).astype(float)
    )

    f["am_5slot_pattern"] = (alpha_am[:, 0]*16 + alpha_am[:, 1]*8 +
                              alpha_am[:, 2]*4 + alpha_am[:, 3]*2 +
                              proj_am[:, 4]*1).astype(float)

    # ════════════════════════════════════════
    # G. MICRO FEATURES
    # ════════════════════════════════════════

    # Assumes the micro column order in the tensor is gk, range, vol, cpos
    # (i.e. MICRO_FEATURES in config) — TODO confirm.
    micro_names = ["gk", "range", "vol", "cpos"]
    for m in range(4):
        f[f"micro_{micro_names[m]}_mean"]  = micro_sl[:, :, m].mean(axis=1)
        f[f"micro_{micro_names[m]}_std"]   = micro_sl[:, :, m].std(axis=1)
        f[f"micro_{micro_names[m]}_last"]  = micro_sl[:, 3, m]
        f[f"micro_{micro_names[m]}_slope"] = _slope_4(micro_sl[:, :, m])

    # ════════════════════════════════════════
    # H. REGIME PERSISTENCE (session-aware — filled separately)
    # ════════════════════════════════════════

    f["regime_age"] = np.zeros(T)  # filled by attach_session_features()

    # 0 = flip between slots 2 and 3, 1 = between 1 and 2, 2 = between 0 and 1,
    # 4.0 = no flip inside the 4-slot window.
    f["bars_since_flip"] = np.full(T, 4.0)
    for i in range(T):
        for j in range(3, 0, -1):
            if alpha_am[i, j] != alpha_am[i, j - 1]:
                f["bars_since_flip"][i] = 3 - j
                break

    # Length (1..4) of the run of slots ending at slot 3 with the current state.
    f["current_streak"] = np.ones(T)
    for i in range(T):
        curr = alpha_am[i, 3]
        for j in range(2, -1, -1):
            if alpha_am[i, j] == curr:
                f["current_streak"][i] += 1
            else:
                break

    # ════════════════════════════════════════
    # I. INTERACTIONS
    # ════════════════════════════════════════

    f["kl_x_alpha_std"]  = f["kl_mean_3s"] * f["alpha_std"]
    f["kl_x_flips"]      = f["kl_mean_3s"] * f["am_alpha_flips"]
    f["kl_x_disagree"]   = f["kl_mean_3s"] * f["disagree_count"]
    f["conv_x_streak"]   = f["alpha_conv"] * f["current_streak"]
    f["conv_x_unanimous"] = f["alpha_conv"] * f["am_alpha_unanimous"]

    # NOTE(review): kl_p75 / kl_p25 are percentiles over ALL T bars passed in,
    # so each bar's kl_high_* / kl_low_* flag depends on later bars (look-ahead
    # when used as a per-bar backtest feature). Consider thresholds fitted on
    # training data or an expanding window.
    kl_p75 = np.nanpercentile(f["kl_mean_3s"], 75) if np.any(~np.isnan(f["kl_mean_3s"])) else 0
    f["kl_high_unstable"] = ((f["kl_mean_3s"] > kl_p75) & (f["alpha_std"] > 0.1)).astype(float)
    f["kl_high_stable"]   = ((f["kl_mean_3s"] > kl_p75) & (f["alpha_std"] < 0.05)).astype(float)
    kl_p25 = np.nanpercentile(f["kl_mean_3s"], 25) if np.any(~np.isnan(f["kl_mean_3s"])) else 0
    f["kl_low_stable"]    = ((f["kl_mean_3s"] < kl_p25) & (f["alpha_std"] < 0.05)).astype(float)

    f["disagree_at_transition"] = f["disagree_current"] * f["am_alpha_transition_now"]
    f["vol_calming_quiet"]      = ((f["micro_gk_slope"] < 0) &
                                    (f["am_alpha_current_quiet"] == 1)).astype(float)

    return f


def attach_features_to_df(df, features, sessions):
    """
    Attach tensor features to a copy of ``df`` with a ``T_`` prefix.

    Also fills the session-aware ``T_regime_age``. Within each session, this
    is the number of consecutive preceding bars whose alpha argmax
    (``Alpha_P1 > 0.5``) matches the current bar's (0 on the first bar and on
    every regime change). Also adds ``T_age_x_conv = T_regime_age *
    T_alpha_conv``, so ``features`` must contain "alpha_conv".

    Regime age is computed on positional numpy arrays and assigned once. The
    previous per-cell ``df.loc[label, ...] = age`` writes were very slow on
    large frames and wrote to every row sharing a label when the index had
    duplicates.

    Parameters
    ----------
    df : pd.DataFrame with "SessionDate" and "Alpha_P1" columns
    features : dict name -> [len(df)] array (e.g. from extract_features)
    sessions : iterable of session dates to process (rows of other sessions
        keep regime age 0)

    Returns
    -------
    pd.DataFrame — the augmented copy; the input is not modified.
    """
    df = df.copy()
    for k, v in features.items():
        df[f"T_{k}"] = v

    # Fill regime_age (needs session awareness), positionally.
    quiet = (df["Alpha_P1"].to_numpy() > 0.5).astype(int)
    ages = np.zeros(len(df), dtype=int)
    for sess in sessions:
        positions = np.flatnonzero((df["SessionDate"] == sess).to_numpy())
        age, prev = 0, -1
        for pos in positions:
            if quiet[pos] == prev:
                age += 1
            else:
                age = 0
                prev = quiet[pos]
            ages[pos] = age
    df["T_regime_age"] = ages

    # age × conviction interaction
    df["T_age_x_conv"] = df["T_regime_age"] * df["T_alpha_conv"]

    return df


# ══════════════════════════════════════════════════════════════
# SLOPE UTILITIES
# ══════════════════════════════════════════════════════════════

def _slope_3(arr_2d):
    """3-point linear slope per row. arr_2d: [T, 3]."""
    T = len(arr_2d)
    slope = np.full(T, 0.0)
    for i in range(T):
        w = arr_2d[i]
        if np.all(~np.isnan(w)) and np.all(w != 0) and np.std(w) > 1e-8:
            slope[i] = np.polyfit(np.arange(3), w, 1)[0]
    return slope


def _slope_4(arr_2d):
    """4-point linear slope per row. arr_2d: [T, 4]."""
    T = len(arr_2d)
    slope = np.full(T, 0.0)
    for i in range(T):
        w = arr_2d[i]
        if np.all(~np.isnan(w)) and np.std(w) > 1e-8:
            slope[i] = np.polyfit(np.arange(4), w, 1)[0]
    return slope


def _count_flips(am):
    """Count state flips in [T, 4] argmax array."""
    T = len(am)
    flips = np.zeros(T)
    for i in range(T):
        for j in range(1, 4):
            if am[i, j] != am[i, j - 1]:
                flips[i] += 1
    return flips


def _last_flip_pos(am):
    """Position of last flip (0=no flip, 1-3=between slots)."""
    T = len(am)
    pos = np.zeros(T)
    for i in range(T):
        for j in range(3, 0, -1):
            if am[i, j] != am[i, j - 1]:
                pos[i] = j
                break
    return pos


def _first_flip_pos(am):
    """Position of first flip."""
    T = len(am)
    pos = np.zeros(T)
    for i in range(T):
        for j in range(1, 4):
            if am[i, j] != am[i, j - 1]:
                pos[i] = j
                break
    return pos