Overall Statistics
Total Orders
21557
Average Win
0.17%
Average Loss
-0.16%
Compounding Annual Return
-53.271%
Drawdown
99.100%
Expectancy
-0.262
Start Equity
30000.00
End Equity
301.48
Net Profit
-98.995%
Sharpe Ratio
-4.392
Sortino Ratio
-4.669
Probabilistic Sharpe Ratio
0%
Loss Rate
65%
Win Rate
35%
Profit-Loss Ratio
1.08
Alpha
-0.432
Beta
-0.005
Annual Standard Deviation
0.099
Annual Variance
0.01
Information Ratio
-1.409
Tracking Error
0.394
Treynor Ratio
79.224
Total Fees
$27418.68
Estimated Strategy Capacity
$600000.00
Lowest Capacity Asset
ETHUSDT 18N
Portfolio Turnover
195.14%
# backtest.py
from AlgorithmImports import *
import numpy as np
import pandas as pd
import pickle
import base64
from collections import defaultdict
import ast

def initialize_backtest(algo):
    """
    Initialize the backtest module.

    Switches the algorithm into backtest mode, loads the trained Q-table from
    the ObjectStore entry "qtable.pkl" (falling back to an empty table when
    the entry is missing or cannot be decoded), resets the containers used
    for performance tracking and registers the "Portfolio Performance" chart.

    Args:
        algo: The algorithm instance
    """
    algo.run_training = False
    algo.run_backtest = True

    # Attempt to load the Q-table from ObjectStore
    try:
        qtable_data_b64 = algo.ObjectStore.Read("qtable.pkl")
        if qtable_data_b64 is not None:
            qtable_data = base64.b64decode(qtable_data_b64)
            # NOTE(review): pickle.loads can execute arbitrary code embedded in
            # the payload. This is only acceptable while "qtable.pkl" is
            # written exclusively by our own training run - confirm.
            loaded_data = pickle.loads(qtable_data)

            # Handle both formats: direct dict or nested dict with metadata
            if isinstance(loaded_data, dict) and "q_table" in loaded_data:
                # New format with metadata
                q_table_dict = loaded_data["q_table"]
                # Metadata is informational only; a payload without it must
                # not cause the whole Q-table to be thrown away.
                metadata = loaded_data.get("metadata") or {}
                algo.Log(f"Loaded QTable from training at episode {metadata.get('episode', 'unknown')}")

                # Log state count and action distribution
                if "state_count" in metadata:
                    algo.Log(f"QTable contains {metadata['state_count']} states")
            else:
                # Old format - direct dictionary
                q_table_dict = loaded_data

            # Convert string keys back to tuples if needed
            fixed_dict = {}
            for k, v in q_table_dict.items():
                try:
                    if isinstance(k, str) and k.startswith('(') and k.endswith(')'):
                        # Convert string representation of tuple back to actual tuple
                        tuple_key = ast.literal_eval(k)
                        fixed_dict[tuple_key] = np.array(v)
                    else:
                        # Already a tuple or other key type
                        fixed_dict[k] = np.array(v)
                except Exception as e:
                    # A single malformed entry is skipped, the rest is kept
                    algo.Log(f"Error parsing key {k}: {str(e)}")
                    continue

            # Create a defaultdict with the loaded values; unseen states get
            # an all-zero row sized to the action list
            algo.q_table = defaultdict(lambda: np.zeros(len(algo.actions)), fixed_dict)

            # Analyze Q-table for biases
            analyze_and_fix_q_table(algo)

            # Debug information
            non_zero_states = sum(1 for v in algo.q_table.values() if np.any(v != 0))
            algo.Log(f"QTable loaded with {non_zero_states} non-zero states")

            # Sample and log a few Q-values for debugging
            sample_keys = list(fixed_dict.keys())[:3]
            for k in sample_keys:
                algo.Debug(f"Sample state {k}: actions={algo.q_table[k]}")

        else:
            algo.Log("QTable not found in ObjectStore; using empty QTable.")
            algo.q_table = defaultdict(lambda: np.zeros(len(algo.actions)))
    except Exception as e:
        # Any failure while reading/decoding leaves us with an empty table
        algo.Error(f"Error loading QTable: {str(e)}")
        algo.q_table = defaultdict(lambda: np.zeros(len(algo.actions)))

    # Initialize backtest performance tracking
    algo.portfolioHistory = {}
    algo.btcHistory = {}
    algo.totalTrades = 0
    algo.actionHistory = {}
    algo.tradeLog = []
    algo.profitLossHistory = []

    # Track initial values for comparison
    algo.initial_btc_value = None
    algo.initial_portfolio_value = None

    # Add additional plot for position and equity tracking
    chart = Chart("Portfolio Performance")
    chart.AddSeries(Series("Strategy", SeriesType.Line))
    chart.AddSeries(Series("Benchmark", SeriesType.Line))
    chart.AddSeries(Series("Position", SeriesType.Bar, 0, "rgba(0, 200, 0, 0.3)"))
    algo.AddChart(chart)


def analyze_and_fix_q_table(algo):
    """
    Inspect the loaded Q-table and trigger corrections where needed.

    Only states with at least one non-zero Q-value are considered. The share
    of states preferring each action is logged; if one action is preferred in
    more than 80% of them the table is rebalanced, and if any state's largest
    Q-value exceeds 50 the table is normalized.

    Args:
        algo: The algorithm instance
    """
    best_action_tally = [0, 0, 0]
    active_state_count = 0
    oversized = []

    for state_key, row in algo.q_table.items():
        # All-zero rows carry no learned preference
        if not np.any(row != 0):
            continue

        active_state_count += 1
        best_action_tally[np.argmax(row)] += 1

        # Remember states whose top value is large enough to dominate
        peak = np.max(row)
        if peak > 50:
            oversized.append((state_key, peak))

    if active_state_count > 0:
        shares = [tally / active_state_count * 100 for tally in best_action_tally]
        algo.Log(f"Q-table action biases: Hold={shares[0]:.1f}%, Long={shares[1]:.1f}%, Short={shares[2]:.1f}%")

        # A single action winning in >80% of states counts as skewed
        if any(share > 80 for share in shares):
            algo.Log("Detected extreme action bias. Applying correction...")
            balance_q_table(algo)

    if oversized:
        algo.Log(f"Found {len(oversized)} states with excessive Q-values. Normalizing...")
        normalize_q_table(algo)


def balance_q_table(algo):
    """
    Balance the Q-table to ensure a better distribution of actions.

    Counts which action is best in every non-zero state. When the least
    preferred action wins in fewer than 20% as many states as the most
    preferred one, a random sample of the dominant states (at most 20% of
    the non-zero states, capped at 50) is modified in place so that the
    under-represented action becomes the best one there.

    Args:
        algo: The algorithm instance
    """
    import random

    # Count best actions
    action_counts = [0, 0, 0]
    valid_states = []
    for state, q_values in algo.q_table.items():
        if np.any(q_values != 0):  # Only count non-zero states
            valid_states.append(state)
            action_counts[np.argmax(q_values)] += 1

    # Find most common and least common actions
    dominant_action = int(np.argmax(action_counts))
    underrep_action = int(np.argmin(action_counts))

    # Only act on a severe imbalance - at least 5:1 ratio
    if action_counts[underrep_action] >= action_counts[dominant_action] * 0.2:
        return

    dominant_states = [
        state for state in valid_states
        if np.argmax(algo.q_table[state]) == dominant_action
    ]
    if not dominant_states:
        return

    # Modify up to 20% of the non-zero states, never more than 50
    states_to_modify = min(len(dominant_states), int(len(valid_states) * 0.2), 50)
    sample_states = random.sample(dominant_states, states_to_modify)

    for state in sample_states:
        # Work on floats so the boosted value is not truncated when the row
        # was loaded as an integer array; float rows are updated in place.
        values = np.asarray(algo.q_table[state], dtype=float)
        dominant_value = values[dominant_action]
        if dominant_value > 0:
            boosted = dominant_value * 1.05
        else:
            # Scaling a negative (or zero) value by 1.05 would not raise it
            # above the dominant action, so add a positive margin instead.
            boosted = dominant_value + max(abs(dominant_value) * 0.05, 1e-6)
        values[underrep_action] = boosted
        algo.q_table[state] = values

    algo.Log(f"Modified {states_to_modify} states to balance action distribution")
        

def normalize_q_table(algo):
    """
    Rescale the Q-table when its largest absolute value is excessive.

    The largest absolute Q-value over all non-zero states is determined; if
    it exceeds 50, every non-zero state is divided by a common factor so
    that this largest value becomes 25. All-zero states are left untouched.

    Args:
        algo: The algorithm instance
    """
    # Largest magnitude found in any state that holds learned values
    largest = 0
    for row in algo.q_table.values():
        if np.any(row != 0):
            largest = max(largest, np.max(np.abs(row)))

    # Nothing to do while values stay within the accepted range
    if not largest > 50:
        return

    divisor = largest / 25  # brings the largest magnitude down to ~25
    for key in algo.q_table:
        row = algo.q_table[key]
        if np.any(row != 0):
            # Rebind to a new array instead of dividing in place
            algo.q_table[key] = row / divisor

    algo.Log(f"Normalized Q-values: max value reduced from {largest:.1f} to {largest/divisor:.1f}")


def on_data_backtest(algo, data):
    """
    Handle new data during backtesting.

    For every slice that contains both the traded and the benchmark symbol
    this records portfolio and benchmark history, plots them, picks the
    greedy action from the Q-table for the current state, executes it and
    logs any resulting position change.

    The P&L stored with a position change is the holding's unrealized profit
    sampled *before* the order is sent, i.e. the open profit of the position
    being exited or flipped. Reading it after the trade would describe the
    new holding instead.

    Args:
        algo: The algorithm instance
        data: The price data
    """
    # Ensure we have data for both symbols
    if not (data.ContainsKey(algo.symbol) and data.ContainsKey(algo.btcSymbol)):
        return

    # Make sure the securities have valid data
    if not algo.Securities[algo.symbol].HasData or not algo.Securities[algo.btcSymbol].HasData:
        return

    current_date = algo.Time.date()
    current_value = algo.Portfolio.TotalPortfolioValue

    # Track portfolio history (keyed by date, so the last slice of a day wins)
    algo.portfolioHistory[current_date] = current_value
    algo.btcHistory[current_date] = data[algo.btcSymbol].Close

    # Initialize benchmark comparison values
    if algo.initial_btc_value is None:
        algo.initial_btc_value = data[algo.btcSymbol].Close
        algo.initial_portfolio_value = current_value

    # Calculate and plot benchmark performance (normalized)
    benchmark_value = algo.initial_portfolio_value * (data[algo.btcSymbol].Close / algo.initial_btc_value)

    # Plot performance
    algo.Plot("Portfolio Performance", "Strategy", current_value)
    algo.Plot("Portfolio Performance", "Benchmark", benchmark_value)

    # Get the current state for the model
    current_state = get_current_state(algo)

    # Get Q-values for current state
    q_values = algo.q_table[current_state]

    # Emit state diagnostics only for slices on the first day of a month to
    # limit log volume
    if current_date.day == 1:
        max_q = np.max(q_values)
        max_action = np.argmax(q_values)
        action_labels = {0: "Hold", 1: "Long", 2: "Short"}
        algo.Debug(f"State: {current_state}, Best action: {action_labels[max_action]} (Q={max_q:.2f})")
        algo.Debug(f"Q-values: Hold={q_values[0]:.2f}, Long={q_values[1]:.2f}, Short={q_values[2]:.2f}")

    # If all Q-values are zero, randomly add small values to encourage exploration
    # NOTE(review): for states without learned values this makes the chosen
    # action random on every slice - confirm this is wanted in a backtest.
    if np.all(q_values == 0):
        # Add random tiny values (this will be temporary for just this decision)
        exploration_values = np.random.random(len(q_values)) * 0.1
        q_values = q_values + exploration_values

    # Select action with highest Q-value
    action = int(np.argmax(q_values))

    # Record action for later analysis
    algo.actionHistory[current_date] = action

    # Get current position
    prev_position = "LONG" if algo.Portfolio[algo.symbol].IsLong else "SHORT" if algo.Portfolio[algo.symbol].IsShort else "NONE"
    position_value = 0
    if prev_position == "LONG":
        position_value = 1
    elif prev_position == "SHORT":
        position_value = -1

    # Plot position
    algo.Plot("Portfolio Performance", "Position", position_value)

    # Debug position before trading
    algo.Debug(f"Date: {current_date}, Before trade: Position={prev_position}, Action={action}")

    # Sample the open profit of the current holding before it is changed
    open_pnl = algo.Portfolio[algo.symbol].UnrealizedProfit

    # Execute the trade
    execute_with_confirmation(algo, action)

    # Track position changes
    curr_position = "LONG" if algo.Portfolio[algo.symbol].IsLong else "SHORT" if algo.Portfolio[algo.symbol].IsShort else "NONE"

    # Debug position after trading
    algo.Debug(f"After trade: Position={curr_position}")

    if prev_position != curr_position:
        price = algo.Securities[algo.symbol].Price

        algo.tradeLog.append({
            'date': algo.Time,
            'action': action,
            'price': price,
            'from': prev_position,
            'to': curr_position,
            'pnl': open_pnl
        })

        # Entries from a flat position have no profit to record
        if open_pnl != 0:
            algo.profitLossHistory.append(open_pnl)

        # Log the trade
        algo.Log(f"Trade: {prev_position} -> {curr_position} at {price}")
        algo.totalTrades += 1


def execute_with_confirmation(algo, action):
    """
    Run a trade for *action* and verify the resulting position.

    The trade is skipped when the algorithm reports no valid price. After
    executing, the position is compared with what the action should produce
    (0 keeps the current position, 1 ends long, 2 ends short); on a mismatch
    an error is raised through the algorithm and the position is forced.

    Args:
        algo: The algorithm instance
        action: The action to execute
    """
    def position_label():
        # Describe the current holding of the traded symbol
        holding = algo.Portfolio[algo.symbol]
        if holding.IsLong:
            return "LONG"
        return "SHORT" if holding.IsShort else "NONE"

    if not algo.HasValidPrice():
        algo.Log("Skipping trade due to invalid price.")
        return

    position_before = position_label()
    algo.ExecuteTrade(action)
    position_after = position_label()

    # Hold keeps whatever we had; long/short have a fixed target; any other
    # action value has no expectation and is not checked
    if action == 0:
        anticipated = position_before
    else:
        anticipated = {1: "LONG", 2: "SHORT"}.get(action)

    if anticipated is None or anticipated == position_after:
        return

    algo.Error(f"Trade execution failed: Expected {anticipated}, got {position_after}. Action: {action}, Before: {position_before}")
    # Try to force position to match action
    force_position(algo, action)


def force_position(algo, action):
    """
    Rebuild the position for the traded symbol from scratch.

    Any existing holding is liquidated first. For action 1 a long target of
    the configured allocation is set, for action 2 the negated allocation;
    any other action leaves the symbol flat.

    Args:
        algo: The algorithm instance
        action: The action to enforce
    """
    # Always start from a flat book
    algo.Liquidate(algo.symbol)

    # Only long (1) and short (2) open a new position
    if action not in (1, 2):
        return

    weight = algo.config.allocation
    algo.SetHoldings(algo.symbol, weight if action == 1 else -weight)
    

def process_consolidated_data(algo):
    """
    Process consolidated data bars during backtesting.

    Does nothing while the algorithm is warming up, when backtest mode is
    off, or until consolidated bars exist for both the traded and the
    benchmark symbol. Otherwise it records and plots portfolio/benchmark
    values, picks the greedy Q-table action for the current state, executes
    it and logs any resulting position change together with a confidence
    score derived from the spread of the Q-values.

    The P&L stored with a position change is the holding's unrealized profit
    sampled *before* the order is sent, i.e. the open profit of the position
    being exited or flipped. Reading it after the trade would describe the
    new holding instead.

    Args:
        algo: The algorithm instance
    """
    if algo.IsWarmingUp:
        return

    if not algo.run_backtest:
        return

    # Both symbols need at least one consolidated bar
    if (algo.btcSymbol not in algo.consolidated_data or
            algo.symbol not in algo.consolidated_data or
            len(algo.consolidated_data[algo.btcSymbol]) == 0 or
            len(algo.consolidated_data[algo.symbol]) == 0):
        return

    # Most recent consolidated benchmark bar
    btc_bar = algo.consolidated_data[algo.btcSymbol][-1]

    current_date = algo.Time.date()
    current_value = algo.Portfolio.TotalPortfolioValue

    # Track history (keyed by date, so the last bar of a day wins)
    algo.portfolioHistory[current_date] = current_value
    algo.btcHistory[current_date] = btc_bar.Close

    # Initialize benchmark comparison at first data point
    if algo.initial_btc_value is None:
        algo.initial_btc_value = btc_bar.Close
        algo.initial_portfolio_value = current_value

    # Calculate and plot benchmark performance (normalized)
    benchmark_value = algo.initial_portfolio_value * (btc_bar.Close / algo.initial_btc_value)

    # Track portfolio value in the chart
    algo.Plot("Portfolio Performance", "Strategy", current_value)
    algo.Plot("Portfolio Performance", "Benchmark", benchmark_value)

    # Get the current state and determine action from Q-table
    current_state = get_current_state(algo)
    q_values = algo.q_table[current_state]

    # If all Q-values are zero, add small random values for exploration
    # NOTE(review): for states without learned values this makes the chosen
    # action random on every bar - confirm this is wanted in a backtest.
    if np.all(q_values == 0):
        exploration_values = np.random.random(len(q_values)) * 0.1
        q_values = q_values + exploration_values

    # Calculate confidence score based on Q-value spread
    q_range = max(q_values) - min(q_values)
    confidence = q_range / (1 + q_range)  # Normalized between 0 and 1

    # Choose action with highest Q-value
    action = int(np.argmax(q_values))

    # Record action for later analysis
    algo.actionHistory[current_date] = action

    # Get current position
    prev_position = "LONG" if algo.Portfolio[algo.symbol].IsLong else "SHORT" if algo.Portfolio[algo.symbol].IsShort else "NONE"

    # Track position size for chart
    position_value = 0
    if prev_position == "LONG":
        position_value = 1
    elif prev_position == "SHORT":
        position_value = -1
    algo.Plot("Portfolio Performance", "Position", position_value)

    # Debug trading info
    action_labels = {0: "Hold", 1: "Long", 2: "Short"}
    algo.Debug(f"State: {current_state}, Action: {action_labels[action]}, Confidence: {confidence:.2f}")

    # Sample the open profit of the current holding before it is changed
    open_pnl = algo.Portfolio[algo.symbol].UnrealizedProfit

    # Execute the trade
    execute_with_confirmation(algo, action)

    # Track position changes
    curr_position = "LONG" if algo.Portfolio[algo.symbol].IsLong else "SHORT" if algo.Portfolio[algo.symbol].IsShort else "NONE"

    if prev_position != curr_position:
        price = algo.Securities[algo.symbol].Price

        algo.tradeLog.append({
            'date': algo.Time,
            'action': action,
            'price': price,
            'from': prev_position,
            'to': curr_position,
            'pnl': open_pnl,
            'confidence': confidence
        })

        # Entries from a flat position have no profit to record
        if open_pnl != 0:
            algo.profitLossHistory.append(open_pnl)

        # Log the trade
        algo.Log(f"Trade: {prev_position} -> {curr_position} at {price} (conf: {confidence:.2f})")
        algo.totalTrades += 1


def get_current_state(algo):
    """
    Build the discretized state tuple used to index the Q-table.

    The state combines binned price, volume, RSI and MACD readings with a
    0/1 flag telling whether algo.sma20 is above algo.sma50. An all-zero
    state is returned when the indicators are not ready, when no history is
    available, or when discretization raises.

    Args:
        algo: The algorithm instance

    Returns:
        tuple: (price_bin, volume_bin, rsi_bin, macd_bin, sma_cross) as ints
    """
    fallback_state = (0, 0, 0, 0, 0)

    # Every indicator must be warmed up before a state can be derived
    if (not algo.sma20.IsReady or not algo.sma50.IsReady
            or not algo.rsi.IsReady or not algo.macd.IsReady):
        return fallback_state

    security = algo.Securities[algo.symbol]
    last_close = security.Close
    last_volume = security.Volume
    current_rsi = algo.rsi.Current.Value
    current_macd = algo.macd.Current.Value

    # Minute mode uses minute history resampled to the consolidation
    # interval; every other resolution uses daily history as reference
    lookback = 252
    use_minute_bars = algo.config.resolution_backtest == Resolution.Minute
    raw_history = algo.History(
        algo.symbol,
        lookback,
        Resolution.Minute if use_minute_bars else Resolution.Daily,
    )
    if raw_history.empty:
        return fallback_state

    # History may come back indexed by (symbol, time); drop the symbol level
    if raw_history.index.nlevels > 1:
        raw_history = raw_history.loc[algo.symbol]

    if use_minute_bars:
        rule = f"{algo.config.consolidation_interval}T"
        ohlcv_rules = {
            'open': 'first',
            'high': 'max',
            'low': 'min',
            'close': 'last',
            'volume': 'sum'
        }
        reference_bars = raw_history.resample(rule).agg(ohlcv_rules).dropna()
    else:
        reference_bars = raw_history

    try:
        closes = reference_bars['close']
        price_bin = algo.DiscretizeValue(last_close, closes)
        volume_bin = algo.DiscretizeValue(last_volume, reference_bars['volume'])
        # RSI is binned against the integers 0..100
        rsi_bin = algo.DiscretizeValue(current_rsi, pd.Series(range(0, 101)))

        # MACD is binned against the distribution of close-to-close changes
        close_changes = reference_bars['close'].pct_change().dropna()
        if len(close_changes) >= 5:
            macd_bin = algo.DiscretizeValue(current_macd, close_changes)
        else:
            macd_bin = 0  # too little data for a meaningful reference

        sma_cross = 1 if algo.sma20.Current.Value > algo.sma50.Current.Value else 0
    except Exception as e:
        algo.Error(f"Error creating state: {str(e)}")
        return fallback_state

    # Bins may arrive as numpy scalars; the Q-table keys are plain ints
    return (int(price_bin), int(volume_bin), int(rsi_bin), int(macd_bin), int(sma_cross))


def _log_selected_actions(algo, action_history, header_lines):
    """Log *header_lines* followed by the Hold/Long/Short selection counts."""
    tally = {0: 0, 1: 0, 2: 0}
    for chosen in action_history.values():
        tally[chosen] += 1

    grand_total = sum(tally.values())
    if grand_total <= 0:
        return

    for line in header_lines:
        algo.Log(line)
    for label, action_id in (("Hold:  ", 0), ("Long:  ", 1), ("Short: ", 2)):
        algo.Log(f"{label}{tally[action_id]} ({tally[action_id]/grand_total*100:.1f}%)")


def _report_no_trades(algo):
    """Log Q-table and action-selection diagnostics for a trade-less run."""
    algo.Log("No trades were executed during the backtest.")

    # Show up to ten states that carry learned values
    algo.Log("Q-Table sample values:")
    learned = [(state, row) for state, row in algo.q_table.items() if np.any(row != 0)]
    for state, row in learned[:10]:
        algo.Log(f"State {state}: [Hold={row[0]:.2f}, Long={row[1]:.2f}, Short={row[2]:.2f}]")

    # Which action does the table prefer across those states?
    preferred = [0, 0, 0]
    for _, row in learned:
        preferred[np.argmax(row)] += 1

    state_total = sum(preferred)
    if state_total > 0:
        shares = [count / state_total * 100 for count in preferred]
        algo.Log(f"Q-table action distribution: Hold={shares[0]:.1f}%, Long={shares[1]:.1f}%, Short={shares[2]:.1f}%")

    # Which actions were actually selected while running?
    if algo.actionHistory:
        _log_selected_actions(algo, algo.actionHistory, ["Action selections during backtest:"])


def _log_annual_summary(algo):
    """
    Log per-year strategy vs. benchmark returns.

    Returns (years compared, years the strategy won), or None when no
    benchmark history was recorded and the summary was therefore skipped.
    """
    equity_by_year = {}
    for day, equity in algo.portfolioHistory.items():
        equity_by_year.setdefault(day.year, {})[day] = equity

    if not (hasattr(algo, 'btcHistory') and algo.btcHistory):
        return None

    btc_by_year = {}
    for day, close in algo.btcHistory.items():
        btc_by_year.setdefault(day.year, {})[day] = close

    algo.Log("=== Annual Performance Summary ===")
    algo.Log("Year  Strategy Return  Benchmark Return  Beat BTC?")
    algo.Log("-----------------------------------------------")

    years_compared = 0
    years_won = 0
    for year in sorted(equity_by_year):
        if year not in btc_by_year:
            continue

        # A return needs a first and a last observation within the year
        equity_days = sorted(equity_by_year[year])
        if len(equity_days) < 2:
            continue
        first_equity = equity_by_year[year][equity_days[0]]
        last_equity = equity_by_year[year][equity_days[-1]]
        strategy_return = (last_equity / first_equity - 1) * 100

        btc_days = sorted(btc_by_year[year])
        if len(btc_days) < 2:
            continue
        first_btc = btc_by_year[year][btc_days[0]]
        last_btc = btc_by_year[year][btc_days[-1]]
        benchmark_return = (last_btc / first_btc - 1) * 100

        outperformed = strategy_return > benchmark_return
        beat = "✓" if outperformed else "✗"
        algo.Log(f"{year}   {strategy_return:7.2f}%         {benchmark_return:7.2f}%       {beat}")
        years_compared += 1
        if outperformed:
            years_won += 1

    return years_compared, years_won


def on_end_of_algorithm(algo):
    """
    Log the end-of-run performance report for a backtest.

    Without recorded portfolio history only a notice is logged. When no
    trades were counted, Q-table and action-selection diagnostics are logged
    instead of the report. Otherwise the report covers yearly returns versus
    the benchmark, total return, max drawdown and Sharpe ratio (both only
    with more than 30 history points), trade statistics, the distribution of
    selected actions and a tally of position transitions.

    Args:
        algo: The algorithm instance
    """
    history = getattr(algo, 'portfolioHistory', None)
    if not history:
        algo.Log("No portfolio history recorded.")
        return

    if algo.totalTrades == 0:
        _report_no_trades(algo)
        return

    # None means the annual comparison was skipped (no benchmark history)
    yearly = _log_annual_summary(algo)

    # Overall return between the earliest and latest recorded date
    if len(history) >= 2:
        ordered_days = sorted(history)
        total_return = (history[ordered_days[-1]] / history[ordered_days[0]] - 1) * 100
        algo.Log(f"Total Return: {total_return:.2f}%")

    # Risk figures are only computed with enough observations
    max_drawdown = 0
    sharpe_ratio = 0
    if len(history) > 30:
        equity_curve = np.array(list(history.values()))
        running_peak = np.maximum.accumulate(equity_curve)
        max_drawdown = abs(min((equity_curve - running_peak) / running_peak)) * 100

        equity_points = list(history.values())
        step_returns = [later / earlier - 1 for earlier, later in zip(equity_points, equity_points[1:])]
        mean_step = np.mean(step_returns)
        step_std = np.std(step_returns)
        if step_std > 0:
            sharpe_ratio = mean_step / step_std * np.sqrt(252)  # Annualized

    # Trade statistics from the recorded per-trade P&L values
    pnl_history = getattr(algo, 'profitLossHistory', None)
    if pnl_history:
        gains = [pnl for pnl in pnl_history if pnl > 0]
        losses = [pnl for pnl in pnl_history if pnl < 0]
        trade_win_rate = len(gains) / len(pnl_history) * 100
        avg_win = np.mean(gains) if gains else 0
        avg_loss = np.mean(losses) if losses else 0
        gross_loss = abs(sum(losses))
        profit_factor = sum(gains) / gross_loss if gross_loss > 0 else 0
    else:
        trade_win_rate = avg_win = avg_loss = profit_factor = 0

    algo.Log("")
    algo.Log("=== Final Results ===")
    if yearly is not None:
        total_years, years_beat_btc = yearly
        algo.Log(f"Total Years: {total_years}")
        algo.Log(f"Years Beat BTC: {years_beat_btc}")
        win_rate = years_beat_btc / total_years * 100 if total_years > 0 else 0
        algo.Log(f"Win Rate vs. BTC: {win_rate:.1f}%")
    algo.Log(f"Total Trades: {algo.totalTrades}")
    algo.Log(f"Max Drawdown: {max_drawdown:.2f}%")
    algo.Log(f"Sharpe Ratio: {sharpe_ratio:.2f}")

    algo.Log("")
    algo.Log("=== Trade Statistics ===")
    algo.Log(f"Trade Win Rate: {trade_win_rate:.1f}%")
    algo.Log(f"Average Win: ${avg_win:.2f}")
    algo.Log(f"Average Loss: ${avg_loss:.2f}")
    algo.Log(f"Profit Factor: {profit_factor:.2f}")

    # Distribution of the actions selected during the run
    if hasattr(algo, 'actionHistory') and algo.actionHistory:
        _log_selected_actions(algo, algo.actionHistory, ["", "=== Action Distribution ==="])

    # Tally of position transitions, in order of first occurrence
    trade_log = getattr(algo, 'tradeLog', ())
    if len(trade_log) > 0:
        algo.Log("")
        algo.Log("=== Trade Summary ===")
        algo.Log(f"Total Position Changes: {len(trade_log)}")

        transitions = defaultdict(int)
        for entry in trade_log:
            transitions[f"{entry['from']} -> {entry['to']}"] += 1

        for change, count in transitions.items():
            algo.Log(f"{change}: {count}")
# main.py
from AlgorithmImports import *
import numpy as np
import pandas as pd
import random
from collections import defaultdict, deque
from System.Drawing import Color
from decimal import Decimal
import math
import pickle
import base64
from datetime import datetime, timedelta
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# Import specialized modules
import training
import backtest


# -------------------------------
# Configuration Section
# -------------------------------
class Config:
    """
    Central configuration class for the algorithm.

    User-tunable values are read through ``algorithm.GetParameter`` and
    validated; a missing or invalid value falls back to the hard-coded
    default and a warning is logged. All other fields are fixed constants
    assigned in ``__init__``.
    """
    def __init__(self, algorithm):
        """
        Populate every configuration field.

        Args:
            algorithm: Object exposing ``GetParameter(name[, default])`` and
                ``Log(message)`` (the running algorithm instance).
        """
        # Mode configuration (training or backtest).
        # A missing parameter resolves to "backtest" through the GetParameter
        # default; an empty value or an unknown mode resolves to "training".
        raw_mode = algorithm.GetParameter("mode", "backtest")
        self.mode = raw_mode.lower() if raw_mode else "training"
        if self.mode not in ["training", "backtest"]:
            algorithm.Log(f"Invalid mode '{self.mode}', defaulting to 'training'")
            self.mode = "training"

        # Time frame parameters
        self.start_date = datetime(2019, 1, 1)
        self.end_date = datetime(2025, 3, 31)
        self.initial_cash = self._validate_positive(algorithm, "initial_cash", 30000)

        # Trading parameters
        self.commission_rate = self._validate_range(algorithm, "commission_rate", 0.001, 0, 0.02)
        self.slippage = self._validate_range(algorithm, "slippage", 0.0, 0, 0.01)
        self.allocation = self._validate_range(algorithm, "allocation", 0.2, 0.05, 1.0)

        # Symbol parameters
        self.benchmark_symbol = algorithm.GetParameter("benchmark_symbol", "BTCUSDT")
        self.trading_symbol = algorithm.GetParameter("trading_symbol", "ETHUSDT")

        # Exchange configuration
        self.exchange = algorithm.GetParameter("exchange", "binance")

        # Data resolution
        self.resolution_training = Resolution.Hour
        self.resolution_backtest = Resolution.Hour  # or Minute
        self.consolidation_interval = 15  # minutes; used when consolidating minute bars

        # Indicator parameters
        self.sma_short_period = self._validate_positive_int(algorithm, "sma_short", 20)
        self.sma_long_period = self._validate_positive_int(algorithm, "sma_long", 60)
        self.rsi_period = 14
        self.macd_fast = 12
        self.macd_slow = 26
        self.macd_signal = 9

        # State space configuration
        self.state_config = {
            'price_bins': 10,
            'volume_bins': 5,
            'rsi_bins': 4,
            'macd_bins': 3,
            'sma_cross': True
        }

        # RL parameters (for training mode)
        self.learning_rate = self._validate_range(algorithm, "learning_rate", 0.1, 0.001, 0.5)
        self.discount_factor = self._validate_range(algorithm, "discount_factor", 0.95, 0.8, 0.99)
        self.epsilon = self._validate_range(algorithm, "epsilon", 1.0, 0.5, 1.0)
        self.epsilon_min = self._validate_range(algorithm, "epsilon_min", 0.01, 0.001, 0.1)
        self.epsilon_decay = self._validate_range(algorithm, "epsilon_decay", 0.99, 0.9, 0.999)
        self.episodes = self._validate_positive_int(algorithm, "episodes", 20)
        self.min_avg_reward = 0.001

        # MBPO specific parameters
        self.model_rollout_length = 3
        self.model_train_freq = 50
        self.synthetic_data_ratio = 1
        self.min_real_samples = 500
        self.replay_buffer_size = 20000

        # Early stopping
        self.patience = 10
        self.min_improvement = 0.0005

        # Minimum data requirements
        self.min_training_days = 60

        # Warmup period for indicators: at least 100 bars, or twice the long SMA
        self.warmup_period = max(self.sma_long_period * 2, 100)

        # Reward scaling
        self.reward_scaling = 50.0

        # Logging frequency control
        self.log_training_frequency = 5
        self.log_model_error_frequency = 2

        # Model reset parameters
        self.max_consecutive_failures = 5
        self.max_failed_model_updates = 3

        # Fixed rewards for early stopping
        self.consecutive_negative_rewards_threshold = 10

        # Q-table saving
        self.save_q_table_frequency = 5

        # Validate that required parameters are consistent (warning only)
        if self.sma_short_period >= self.sma_long_period:
            algorithm.Log(f"Warning: sma_short_period ({self.sma_short_period}) should be less than sma_long_period ({self.sma_long_period})")

    def _validate_positive(self, algorithm, param_name, default_value):
        """
        Read a parameter that must be a positive number.

        Args:
            algorithm: Source of ``GetParameter`` and ``Log``.
            param_name: Name of the parameter to read.
            default_value: Value used when the parameter is missing,
                non-numeric, or not positive.

        Returns:
            The parsed float, or ``default_value`` (unconverted) on failure.
        """
        # Guard the conversion like the sibling validators do, so a
        # non-numeric parameter falls back to the default instead of raising.
        try:
            value = float(algorithm.GetParameter(param_name) or default_value)
        except (ValueError, TypeError):
            algorithm.Log(f"Warning: {param_name} must be a number. Using default {default_value}")
            return default_value
        if value <= 0:
            algorithm.Log(f"Warning: {param_name} must be positive. Using default {default_value}")
            return default_value
        return value

    def _validate_positive_int(self, algorithm, param_name, default_value):
        """
        Read a parameter that must be a positive integer.

        Returns:
            The parsed int, or ``default_value`` when the parameter is
            missing, not an integer, or not positive.
        """
        try:
            value = int(algorithm.GetParameter(param_name) or default_value)
            if value <= 0:
                algorithm.Log(f"Warning: {param_name} must be positive. Using default {default_value}")
                return default_value
            return value
        except (ValueError, TypeError):
            algorithm.Log(f"Warning: {param_name} must be an integer. Using default {default_value}")
            return default_value

    def _validate_range(self, algorithm, param_name, default_value, min_value, max_value):
        """
        Read a numeric parameter that must lie in ``[min_value, max_value]``.

        Returns:
            The parsed float, or ``default_value`` when the parameter is
            missing, non-numeric, or outside the inclusive range.
        """
        try:
            value = float(algorithm.GetParameter(param_name) or default_value)
            if value < min_value or value > max_value:
                algorithm.Log(f"Warning: {param_name} must be between {min_value} and {max_value}. Using default {default_value}")
                return default_value
            return value
        except (ValueError, TypeError):
            algorithm.Log(f"Warning: {param_name} must be a number. Using default {default_value}")
            return default_value

    def get_market(self):
        """
        Map the configured exchange name to a ``Market`` constant.

        The lookup is case-insensitive; unrecognised names fall back to
        ``Market.Binance``.
        """
        exchange_key = self.exchange.lower()
        if exchange_key == "binance":
            return Market.Binance
        elif exchange_key in ["coinbase", "gdax"]:
            return Market.GDAX
        elif exchange_key == "bitfinex":
            return Market.Bitfinex
        elif exchange_key == "kraken":
            return Market.Kraken
        elif exchange_key == "bitmex":
            return Market.BitMEX
        return Market.Binance


# -------------------------------
# Custom Fee Model
# -------------------------------
class CustomFeeModel(FeeModel):
    """
    Percentage-based fee model for cryptocurrency trading.

    The fee is ``abs(order quantity) * security price * commission_rate``,
    charged in the security's quote currency.
    """
    def __init__(self, commission_rate):
        """
        Args:
            commission_rate: Fee as a fraction of traded value (0.001 = 0.1%).
        """
        self.commission_rate = commission_rate

    def GetOrderFee(self, security, order=None):
        """
        Calculate the fee for a given order based on the commission rate.

        Two calling conventions are accepted:

        * ``GetOrderFee(security, order)`` - the original two-argument form.
        * ``GetOrderFee(parameters)`` - a single object exposing ``Security``
          and ``Order`` attributes.

        NOTE(review): LEAN's documented Python fee model signature is the
        single-parameter form; confirm against the engine version in use.

        Args:
            security: The traded security, or the parameters object when
                ``order`` is omitted.
            order: The order being priced (omit for the single-argument form).

        Returns:
            OrderFee: The fee expressed in the security's quote currency.
        """
        if order is None:
            # Single-argument call: unpack the parameters object.
            parameters = security
            security = parameters.Security
            order = parameters.Order

        trade_value = abs(order.Quantity) * security.Price
        # Go through str() so Decimal does not inherit binary float noise.
        fee = Decimal(str(trade_value)) * Decimal(str(self.commission_rate))
        currency = security.QuoteCurrency.Symbol
        return OrderFee(CashAmount(fee, currency))


# -------------------------------
# Custom Slippage Model
# -------------------------------
class CustomSlippageModel(ConstantSlippageModel):
    """
    Constant percentage slippage model for cryptocurrency trading.

    ``slippage`` is a fraction of the current price (0.001 = 0.1%).
    """
    def __init__(self, slippage):
        """
        Args:
            slippage: Slippage as a fraction of price.
        """
        # The base-class percentage is left at zero; this class computes the
        # slippage itself in GetSlippageApproximation.
        super().__init__(0.0)
        self.slippage = Decimal(str(slippage))

    def GetSlippageApproximation(self, security, order):
        """
        Return the slippage for a given security and order in price units.

        The fill price is shifted by the returned amount, so the configured
        fraction is scaled by the security's current price. Returning the raw
        fraction would apply e.g. 0.001 quote-currency units, not 0.1%.
        NOTE(review): based on LEAN's documented slippage contract (value in
        quote-currency price units) - confirm against the engine version.

        Args:
            security: The security being traded (its ``Price`` is read).
            order: The order being filled (unused).

        Returns:
            Decimal: ``security.Price * slippage``.
        """
        # Go through str() so Decimal does not inherit binary float noise.
        return Decimal(str(security.Price)) * self.slippage


# -----------------------------------------------------------
# Improved Transition Model for MBPO (using scikit-learn)
# -----------------------------------------------------------
class TransitionModel:
    """
    Predicts the next state given the current state and action.
    Used in Model-Based Policy Optimization (MBPO) for reinforcement learning.

    The estimator is a scikit-learn pipeline (RobustScaler followed by a
    RandomForestRegressor) fitted on ``[state, action]`` rows.
    """
    def __init__(self, state_dim, action_dim):
        """
        Initialize the transition model.

        Args:
            state_dim: Dimension of the state space
            action_dim: Dimension of the action space
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.model = Pipeline([
            ('scaler', RobustScaler()),
            ('regressor', RandomForestRegressor(
                n_estimators=30,
                max_depth=6,
                min_samples_leaf=3,
                random_state=42,
                bootstrap=True,
                max_features=0.7  # each split considers 70% of the features
            ))
        ])
        self.trained = False
        self.validation_error = float('inf')  # error of the currently accepted model
        self.train_attempts = 0
        self.consecutive_failures = 0
        self.best_error = float('inf')
        self.error_history = []  # validation error of every candidate fit

    def train(self, states, actions, next_states):
        """
        Train the transition model using historical data.

        A fresh candidate pipeline is fitted on a random 80/20 split. It
        replaces the current model only if no model has been accepted yet or
        its validation MSE is at least 10% lower than the accepted model's.
        A rejected candidate is discarded, so ``self.model`` and
        ``self.validation_error`` always describe the same estimator.

        Args:
            states: Array of observed states
            actions: Array of actions taken
            next_states: Array of resulting next states

        Returns:
            bool: False when there are fewer than 200 samples or fitting
            raised; True otherwise (including when the candidate was rejected,
            which only increments ``consecutive_failures``).
        """
        if len(states) < 200:
            return False

        self.train_attempts += 1
        X = np.hstack((states, actions.reshape(-1, 1)))
        y = next_states

        try:
            # Local import keeps this block self-contained; sklearn is already
            # a dependency of this file.
            from sklearn.base import clone

            # Plain random split; the seed changes on every attempt.
            X_train, X_val, y_train, y_val = train_test_split(
                X, y, test_size=0.2, random_state=42 + self.train_attempts
            )

            # Fit an unfitted copy so a rejected fit cannot overwrite the
            # accepted model.
            candidate = clone(self.model)
            candidate.fit(X_train, y_train)
            y_pred = candidate.predict(X_val)
            new_validation_error = mean_squared_error(y_val, y_pred)
            self.error_history.append(new_validation_error)

            # Only accept the model if it improves significantly or is the first model
            improvement_threshold = 0.90  # 10% improvement requirement

            if not self.trained or new_validation_error < self.validation_error * improvement_threshold:
                self.model = candidate
                self.validation_error = new_validation_error
                self.trained = True
                self.consecutive_failures = 0

                # Track best model
                if new_validation_error < self.best_error:
                    self.best_error = new_validation_error

                return True

            # Candidate rejected: keep the previous model, just track the failure.
            self.consecutive_failures += 1
            return True

        except Exception as e:
            print(f"Model training error: {str(e)}")
            self.consecutive_failures += 1
            return False

    def predict(self, state, action):
        """
        Predict the next state given the current state and action.

        Args:
            state: Current state
            action: Action to take

        Returns:
            Predicted next state, or None if the model is untrained or the
            prediction raises.
        """
        if not self.trained:
            return None
        try:
            X = np.hstack((state.reshape(1, -1), np.array([[action]])))
            return self.model.predict(X)[0]
        except Exception:
            return None

    def needs_reset(self, max_failures):
        """
        Determine if the model needs to be reset due to repeated failures.

        Args:
            max_failures: Maximum allowed consecutive failures

        Returns:
            bool: Whether the model should be reset
        """
        return self.consecutive_failures >= max_failures


# -----------------------------------------------------------
# Improved Reward Model for MBPO
# -----------------------------------------------------------
class RewardModel:
    """
    Predicts the reward for a given state-action pair.
    Used in Model-Based Policy Optimization (MBPO) for reinforcement learning.

    The estimator is a scikit-learn pipeline (RobustScaler followed by a
    RandomForestRegressor) fitted on ``[state, action]`` rows.
    """
    def __init__(self, state_dim, action_dim):
        """
        Initialize the reward model.

        Args:
            state_dim: Dimension of the state space
            action_dim: Dimension of the action space
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.model = Pipeline([
            ('scaler', RobustScaler()),
            ('regressor', RandomForestRegressor(
                n_estimators=25,
                max_depth=5,
                min_samples_leaf=4,
                random_state=42,
                max_features=0.8  # each split considers 80% of the features
            ))
        ])
        self.trained = False
        self.validation_error = float('inf')  # error of the currently accepted model
        self.train_attempts = 0
        self.consecutive_failures = 0
        self.best_error = float('inf')
        self.error_history = []  # validation error of every candidate fit

    def train(self, states, actions, rewards):
        """
        Train the reward model using historical data.

        A fresh candidate pipeline is fitted on a random 80/20 split. It
        replaces the current model only if no model has been accepted yet or
        its validation MSE is at least 10% lower than the accepted model's.
        A rejected candidate is discarded, so ``self.model`` and
        ``self.validation_error`` always describe the same estimator.

        Args:
            states: Array of observed states
            actions: Array of actions taken
            rewards: Array of observed rewards

        Returns:
            bool: False when there are fewer than 200 samples or fitting
            raised; True otherwise (including when the candidate was rejected,
            which only increments ``consecutive_failures``).
        """
        if len(states) < 200:
            return False

        self.train_attempts += 1
        X = np.hstack((states, actions.reshape(-1, 1)))
        y = rewards

        try:
            # Local import keeps this block self-contained; sklearn is already
            # a dependency of this file.
            from sklearn.base import clone

            # Plain random split; the seed changes on every attempt.
            X_train, X_val, y_train, y_val = train_test_split(
                X, y, test_size=0.2, random_state=42 + self.train_attempts
            )

            # Fit an unfitted copy so a rejected fit cannot overwrite the
            # accepted model.
            candidate = clone(self.model)
            candidate.fit(X_train, y_train)
            y_pred = candidate.predict(X_val)
            new_validation_error = mean_squared_error(y_val, y_pred)
            self.error_history.append(new_validation_error)

            # Only accept significant improvements
            improvement_threshold = 0.90  # 10% improvement requirement

            if not self.trained or new_validation_error < self.validation_error * improvement_threshold:
                self.model = candidate
                self.validation_error = new_validation_error
                self.trained = True
                self.consecutive_failures = 0

                # Track best model
                if new_validation_error < self.best_error:
                    self.best_error = new_validation_error

                return True

            # Candidate rejected: keep the previous model, just track the failure.
            self.consecutive_failures += 1
            return True

        except Exception as e:
            print(f"Reward model training error: {str(e)}")
            self.consecutive_failures += 1
            return False

    def predict(self, state, action):
        """
        Predict the reward for a given state-action pair.

        Args:
            state: Current state
            action: Action to take

        Returns:
            Predicted reward, or 0.0 if the model is untrained or the
            prediction raises.
        """
        if not self.trained:
            return 0.0
        try:
            X = np.hstack((state.reshape(1, -1), np.array([[action]])))
            return self.model.predict(X)[0]
        except Exception:
            return 0.0

    def needs_reset(self, max_failures):
        """
        Determine if the model needs to be reset due to repeated failures.

        Args:
            max_failures: Maximum allowed consecutive failures

        Returns:
            bool: Whether the model should be reset
        """
        return self.consecutive_failures >= max_failures


# -----------------------------------------------------------
# Improved Experience Replay Buffer for MBPO
# -----------------------------------------------------------
class ReplayBuffer:
    """
    Stores and retrieves experience tuples for reinforcement learning.
    Uses prioritized experience replay for better sampling.

    Experiences and their priorities live in two bounded deques of equal
    ``maxlen``; once full, the oldest entry of each is dropped together.
    """
    def __init__(self, buffer_size=10000):
        """
        Initialize the replay buffer with a maximum size.

        Args:
            buffer_size: Maximum number of experiences to store
        """
        self.buffer = deque(maxlen=buffer_size)
        self.priorities = deque(maxlen=buffer_size)
        self.alpha = 0.6  # Priority exponent - higher means more prioritization
        self.beta = 0.4   # Default importance-sampling exponent used by sample()
        self.epsilon = 1e-5  # Small constant to avoid zero priority

    def add(self, state, action, reward, next_state, td_error=None):
        """
        Add an experience to the buffer with its priority.

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state observed
            td_error: Optional TD error for priority calculation. When it is
                None or not finite, the current maximum priority is used
                (1.0 for the first entry).
        """
        experience = (state, action, reward, next_state)
        self.buffer.append(experience)

        # A NaN/inf priority would make every sampling probability NaN, so
        # treat a non-finite TD error like a missing one.
        if td_error is None or not np.isfinite(td_error):
            priority = max(self.priorities) if self.priorities else 1.0
        else:
            priority = abs(td_error) + self.epsilon

        self.priorities.append(priority)

    def update_priorities(self, indices, td_errors):
        """
        Update priorities for specific experiences.

        Out-of-range indices and non-finite errors are ignored.

        Args:
            indices: List of indices to update
            td_errors: List of corresponding TD errors
        """
        for idx, error in zip(indices, td_errors):
            if not np.isfinite(error):
                continue
            if 0 <= idx < len(self.priorities):
                self.priorities[idx] = abs(error) + self.epsilon

    def sample(self, batch_size, beta=None):
        """
        Sample a batch of experiences from the buffer using prioritized sampling.

        Sampling is without replacement; ``batch_size`` is capped at the
        number of stored experiences. An empty buffer or a non-positive
        ``batch_size`` yields six empty arrays instead of raising.

        Args:
            batch_size: Number of experiences to sample
            beta: Importance sampling weight (if None, use default)

        Returns:
            Tuple of (states, actions, rewards, next_states, indices, weights)
        """
        batch_size = min(batch_size, len(self.buffer))
        if batch_size <= 0:
            # Nothing to draw: np.random.choice and weights.max() would raise.
            return (np.array([]), np.array([]), np.array([]), np.array([]),
                    np.array([], dtype=int), np.array([]))

        beta = beta if beta is not None else self.beta

        # Convert priorities to probabilities
        priorities = np.array(self.priorities, dtype=float)
        probabilities = priorities ** self.alpha
        probabilities /= probabilities.sum()

        # Sample experiences based on probabilities
        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities, replace=False)

        # Calculate importance sampling weights, normalized so the largest is 1
        weights = (len(self.buffer) * probabilities[indices]) ** (-beta)
        weights /= weights.max()

        # Extract experiences
        batch = [self.buffer[idx] for idx in indices]
        states = np.array([x[0] for x in batch])
        actions = np.array([x[1] for x in batch])
        rewards = np.array([x[2] for x in batch])
        next_states = np.array([x[3] for x in batch])

        return states, actions, rewards, next_states, indices, weights

    def __len__(self):
        """Return the current size of the buffer."""
        return len(self.buffer)

    def clear(self):
        """Clear all experiences from the buffer."""
        self.buffer.clear()
        self.priorities.clear()


# -----------------------------------------------------------
# Main Combined Q-Learning Algorithm
# -----------------------------------------------------------
class CombinedQLearningAlgorithm(QCAlgorithm):
    """
    Main algorithm class that combines reinforcement learning with trading.

    Supports both training and backtesting modes: every lifecycle hook
    checks ``self.config.mode`` and delegates to the ``training`` module when
    it is "training" and to the ``backtest`` module otherwise. The remaining
    methods are utilities shared by both modules.
    """
    def Initialize(self):
        """Initialize the algorithm, setting up symbols, indicators, and parameters."""
        # 1) Load config
        self.config = Config(self)

        # 2) Set timeframe and capital
        self.SetStartDate(self.config.start_date)
        self.SetEndDate(self.config.end_date)
        self.SetCash(self.config.initial_cash)

        # 3) Choose resolution based on mode
        resolution = self.config.resolution_training if self.config.mode == "training" else self.config.resolution_backtest

        # 4) Get the correct market
        market = self.config.get_market()
        self.SetBrokerageModel(DefaultBrokerageModel(AccountType.Margin))

        # 5) Add benchmark and trading symbols
        self.btcSymbol = self.AddCrypto(self.config.benchmark_symbol, resolution, market).Symbol
        self.SetBenchmark(self.btcSymbol)
        self.symbol = self.AddCrypto(self.config.trading_symbol, resolution, market).Symbol

        # 6) If using Minute resolution, set up consolidators.
        #    Both consolidators feed the same handler.
        if resolution == Resolution.Minute:
            self.consolidator_btc = TradeBarConsolidator(timedelta(minutes=self.config.consolidation_interval))
            self.consolidator_trade = TradeBarConsolidator(timedelta(minutes=self.config.consolidation_interval))
            self.SubscriptionManager.AddConsolidator(self.btcSymbol, self.consolidator_btc)
            self.SubscriptionManager.AddConsolidator(self.symbol, self.consolidator_trade)
            self.consolidator_btc.DataConsolidated += self.OnDataConsolidated
            self.consolidator_trade.DataConsolidated += self.OnDataConsolidated
            self.consolidated_data = {}

        # 7) Apply custom fee & slippage to each security
        for security in self.Securities.Values:
            security.SetFeeModel(CustomFeeModel(self.config.commission_rate))
            security.SetSlippageModel(CustomSlippageModel(self.config.slippage))

        # 8) Create and warm up indicators.
        #    The attribute names say 20/50, but the periods come from the config
        #    (sma_short_period / sma_long_period).
        self.sma20 = self.SMA(self.symbol, self.config.sma_short_period, resolution)
        self.sma50 = self.SMA(self.symbol, self.config.sma_long_period, resolution)
        self.rsi = self.RSI(self.symbol, self.config.rsi_period, MovingAverageType.Simple, resolution)
        self.macd = self.MACD(self.symbol, self.config.macd_fast, self.config.macd_slow, self.config.macd_signal,
                              MovingAverageType.Exponential, resolution)
        self.SetWarmUp(self.config.warmup_period, resolution)

        # 9) Define RL action space and dimensions
        self.actions = [0, 1, 2]  # 0=Hold, 1=Long, 2=Short
        self.state_dim = 5
        self.action_dim = 1

        # 10) Set up some shared variables
        self.training_data_loaded = False
        self.training_index = None
        self.training_data_length = None
        self.last_training_action_time = None
        self.training_action_interval = timedelta(days=1)

        self.epsilon = self.config.epsilon
        self.epsilon_min = self.config.epsilon_min
        self.epsilon_decay = self.config.epsilon_decay

        self.best_avg_reward = -float('inf')
        self.no_improvement_count = 0
        self.model_reset_count = 0
        self.consecutive_negative_rewards = 0
        self.log_count = 0

        # IMPORTANT: Provide references for training/backtest to create new model instances
        self.TransitionModel = TransitionModel
        self.RewardModel = RewardModel
        self.ReplayBuffer = ReplayBuffer

        # 11) Depending on mode, delegate initialization
        if self.config.mode == "training":
            training.initialize_training(self)
        else:
            backtest.initialize_backtest(self)

        market_name = str(market).split('.')[-1]
        self.Debug(f"Initialized in {self.config.mode.upper()} mode on {market_name}.")
        self.Debug(f"Benchmark: {self.config.benchmark_symbol}, Trading: {self.config.trading_symbol}")

    def OnDataConsolidated(self, sender, consolidated_bar):
        """
        Handle consolidated data bars (for minute resolution with consolidation).

        Bars are appended to ``self.consolidated_data[symbol]`` (capped at the
        most recent 1000 per symbol). Once both the benchmark and the trading
        symbol have at least one bar, the mode-specific
        ``process_consolidated_data`` function is called.

        Args:
            sender: The consolidator sending the data
            consolidated_bar: The consolidated price bar
        """
        if consolidated_bar is None:
            return
        symbol = consolidated_bar.Symbol
        if not hasattr(self, 'consolidated_data'):
            self.consolidated_data = {}
        if symbol not in self.consolidated_data:
            self.consolidated_data[symbol] = []
        self.consolidated_data[symbol].append(consolidated_bar)

        # Limit data storage to prevent memory issues
        max_bars = 1000
        if len(self.consolidated_data[symbol]) > max_bars:
            self.consolidated_data[symbol] = self.consolidated_data[symbol][-max_bars:]

        if (self.btcSymbol in self.consolidated_data and
            self.symbol in self.consolidated_data and
            len(self.consolidated_data[self.btcSymbol]) > 0 and
            len(self.consolidated_data[self.symbol]) > 0):
            # Delegate to training or backtest process function
            if self.config.mode == "training":
                training.process_consolidated_data(self)
            else:
                backtest.process_consolidated_data(self)

    def OnData(self, data):
        """
        Handle incoming price data.

        Nothing is done while the algorithm is warming up.

        Args:
            data: The new price data
        """
        if self.IsWarmingUp:
            return
        # Delegate OnData to correct module
        if self.config.mode == "training":
            training.on_data_training(self, data)
        else:
            backtest.on_data_backtest(self, data)

    def OnEndOfAlgorithm(self):
        """
        Handle end of algorithm execution.

        In training mode the results are evaluated and the Q-table saved,
        unless ``training_complete`` is already set. In backtest mode the
        backtest module's end-of-algorithm hook runs.
        """
        if self.config.mode == "training":
            if not getattr(self, 'training_complete', False):
                training.evaluate_training_results(self)
                training.save_q_table(self)
        else:
            backtest.on_end_of_algorithm(self)

    # -------------------------------
    # SHARED UTILITY METHODS
    # -------------------------------
    def HasValidPrice(self):
        """
        Check if we have valid price data for the trading symbol.

        Returns:
            bool: True when the security has data and a price above zero
        """
        if not self.Securities[self.symbol].HasData:
            return False
        price = self.Securities[self.symbol].Price
        return (price is not None) and (price > 0)

    def ExecuteTrade(self, action):
        """
        Execute a trade based on the given action.

        Long/Short first liquidate an opposite position, then open the new
        one at ``config.allocation`` if nothing is held. Hold (0) places no
        orders and leaves any existing position untouched.

        Args:
            action: The action to take (0=Hold, 1=Long, 2=Short)
        """
        if not self.HasValidPrice():
            self.Log("Skipping trade: no valid price data available.")
            return
        # 1=Long, 2=Short, 0=Hold
        if action == 1:  # LONG
            if self.Portfolio[self.symbol].Invested and self.Portfolio[self.symbol].IsShort:
                self.Liquidate(self.symbol)
                self.Log(f"Closing SHORT at {self.Time}")
            if not self.Portfolio[self.symbol].Invested:
                self.SetHoldings(self.symbol, self.config.allocation)
                # totalTrades only exists when the active mode created it
                if hasattr(self, 'totalTrades'):
                    self.totalTrades += 1
                self.Log(f"Opening LONG at {self.Time}")
        elif action == 2:  # SHORT
            if self.Portfolio[self.symbol].Invested and not self.Portfolio[self.symbol].IsShort:
                self.Liquidate(self.symbol)
                self.Log(f"Closing LONG at {self.Time}")
            if not self.Portfolio[self.symbol].Invested:
                self.SetHoldings(self.symbol, -self.config.allocation)
                if hasattr(self, 'totalTrades'):
                    self.totalTrades += 1
                self.Log(f"Opening SHORT at {self.Time}")

    def CalculateReward(self, action, current_price, next_price):
        """
        Calculate the reward for a given action and price change.

        Long earns the fractional price change and Short its negative. Hold
        earns +0.001 when the move is within 2%, otherwise a penalty of 20%
        of the absolute move.

        Args:
            action: The action taken (0=Hold, 1=Long, 2=Short)
            current_price: The price when the action was taken
            next_price: The price after the action

        Returns:
            float: The calculated reward (0 for invalid prices or unknown actions)
        """
        if current_price <= 0 or not np.isfinite(current_price) or not np.isfinite(next_price):
            return 0
        price_change = (next_price - current_price) / current_price

        if action == 0:
            significant_change_threshold = 0.02  # 2%
            if abs(price_change) > significant_change_threshold:
                # Penalty for holding during big movements
                return -abs(price_change) * 0.2
            else:
                # Small non-zero reward for holding during small movements,
                # to prevent always choosing long/short
                return 0.001
        elif action == 1:  # LONG
            return price_change
        elif action == 2:  # SHORT
            return -price_change
        return 0

    def Discretize(self, values, num_bins):
        """
        Discretize continuous values into bins.

        When there are no more distinct non-null values than bins, each
        distinct value maps to its rank. Otherwise quantile bins are used.
        Missing values map to bin 0, and any error yields all zeros.

        Args:
            values: Series of values to discretize
            num_bins: Number of bins to use

        Returns:
            numpy.ndarray: Array of discretized values
        """
        if len(values) == 0:
            return np.zeros(0)
        try:
            unique_values = np.unique(values.dropna())
            if len(unique_values) <= num_bins:
                # Direct map
                value_to_bin = {val: i for i, val in enumerate(sorted(unique_values))}
                bins = np.array([value_to_bin.get(val, 0) if not pd.isna(val) else 0 for val in values])
                return bins

            # Use quantile-based discretization for more robust binning
            bins = pd.qcut(values, num_bins, labels=False, duplicates='drop')
            return np.nan_to_num(bins, nan=0)
        except Exception as e:
            self.Log(f"Discretize error: {str(e)}")
            return np.zeros(len(values))

    def CalculateRSI(self, prices, window):
        """
        Calculate Relative Strength Index for a series of prices.

        Args:
            prices: Series of price values
            window: RSI window period

        Returns:
            pandas.Series: Series of RSI values indexed like ``prices``;
            a constant neutral 50 when the series has no gains or no losses.
        """
        delta = prices.diff()
        gain = delta.where(delta > 0, 0)
        loss = -delta.where(delta < 0, 0)
        if (gain == 0).all() or (loss == 0).all():
            # Keep the caller's index so the neutral result aligns with
            # `prices` exactly like the normal path below does.
            return pd.Series(50.0, index=prices.index)

        # Use exponential moving average for smoother RSI
        avg_gain = gain.ewm(com=window-1, min_periods=window).mean()
        avg_loss = loss.ewm(com=window-1, min_periods=window).mean()

        rs = avg_gain / avg_loss.replace(0, 1e-10)  # Avoid division by zero
        rsi = 100 - (100 / (1 + rs))
        return rsi.fillna(50)  # Fill missing values with neutral RSI

    def CalculateMACD(self, prices):
        """
        Calculate MACD for a series of prices.

        Args:
            prices: Series of price values

        Returns:
            pandas.Series: Fast EMA minus slow EMA (spans from the config)
        """
        ema12 = prices.ewm(span=self.config.macd_fast, adjust=False).mean()
        ema26 = prices.ewm(span=self.config.macd_slow, adjust=False).mean()
        return ema12 - ema26

    def DiscretizeValue(self, value, values):
        """
        Discretize a single value based on a series of reference values.

        The value is clamped to the reference range, ranked by the fraction
        of reference values strictly below it, and scaled to a bin index.

        Args:
            value: Value to discretize
            values: Reference values for discretization (list, Series or ndarray)

        Returns:
            int: Discretized bin value (0 on empty/constant references or errors)
        """
        if not isinstance(values, (list, pd.Series, np.ndarray)) or len(values) == 0:
            return 0
        try:
            # Lists and ndarrays both need the Series API used below.
            if not isinstance(values, pd.Series):
                values = pd.Series(values)

            # Handle case where all values are the same
            if values.nunique() <= 1:
                return 0

            in_rsi_range = values.min() >= 0 and values.max() <= 100

            # NOTE(review): the range heuristics below run after the name
            # check and override it, so a series named 'close' whose values
            # leave [0, 100] ends up with macd_bins. Kept as-is because
            # changing it would change the state keys - confirm intent.
            bin_count = self.config.state_config['price_bins']
            if hasattr(values, 'name'):
                if values.name == 'volume':
                    bin_count = self.config.state_config['volume_bins']
                elif values.name == 'close':
                    bin_count = self.config.state_config['price_bins']
            # If data is in [0..100], assume it's RSI
            if in_rsi_range:
                bin_count = self.config.state_config['rsi_bins']
            # Otherwise, default to macd_bins if it doesn't look like RSI
            if bin_count == self.config.state_config['price_bins'] and not in_rsi_range:
                bin_count = self.config.state_config['macd_bins']

            # Ensure value is in the range of values
            adjusted_value = max(min(value, values.max()), values.min())

            # Percentile ranking: share of reference values strictly below
            percentile_rank = (values < adjusted_value).mean()
            val_bin = int(percentile_rank * (bin_count - 1))

            return val_bin if val_bin >= 0 else 0
        except Exception as e:
            self.Log(f"DiscretizeValue error: {str(e)}")
            return 0
# training.py
from AlgorithmImports import *
import numpy as np
import pandas as pd
import random
import pickle
import base64
from collections import defaultdict
from datetime import timedelta
from sklearn.preprocessing import StandardScaler

def initialize_training(algo):
    """
    Put ``algo`` into training mode and attach all training-time state.

    Creates the environment models and replay buffer, a fresh state
    normalizer, an empty Q-table, the per-episode bookkeeping fields and
    the "Training Rewards" chart. According to the original note this is
    called from main.py when mode='training'.
    """
    algo.run_training, algo.run_backtest = True, False

    # Environment models and the buffer of real transitions
    state_dim, action_dim = algo.state_dim, algo.action_dim
    algo.transition_model = algo.TransitionModel(state_dim, action_dim)
    algo.reward_model = algo.RewardModel(state_dim, action_dim)
    algo.replay_buffer = algo.ReplayBuffer(algo.config.replay_buffer_size)

    # Model-update bookkeeping and state normalisation
    algo.days_since_model_update = 0
    algo.state_normalizer = StandardScaler()
    algo.state_normalizer_fitted = False
    algo.model_update_failures = 0

    # Tabular Q-function: unseen states start with one zero per action.
    # The action count is looked up lazily, each time a new state is added.
    def _fresh_action_values():
        return np.zeros(len(algo.actions))

    algo.q_table = defaultdict(_fresh_action_values)

    # Per-episode progress tracking
    algo.episode_rewards = []
    algo.current_episode = 0
    algo.training_complete = False
    algo.current_episode_reward = 0

    # Diagnostics for the learned environment model
    algo.model_predictions = []
    algo.model_prediction_errors = []

    # One chart holding the reward, its moving average and the model error
    rewards_chart = Chart("Training Rewards")
    for series_name in ("Reward", "Moving Average", "Model Error"):
        rewards_chart.AddSeries(Series(series_name, SeriesType.Line))
    algo.AddChart(rewards_chart)


def on_data_training(algo, data):
    """
    Time-throttled training hook (per the original note: invoked from
    OnData() in main.py when mode='training').

    On the first call the training data is loaded and the throttle clock is
    started. After that, a training step is attempted whenever at least
    ``algo.training_action_interval`` has elapsed since the last attempt.
    ``data`` is not used.
    """
    now = algo.Time

    # Lazy, one-off load of the historical training set
    if not algo.training_data_loaded:
        load_training_data(algo)
        algo.training_data_loaded = True
        algo.last_training_action_time = now

    previous = algo.last_training_action_time
    is_due = previous is None or now - previous >= algo.training_action_interval
    if not is_due:
        return

    # Only step when there is at least one (state, next state) pair
    bar_count = algo.training_data_length
    if bar_count and bar_count > 1:
        train_step(algo)

    # The clock restarts even when no step was possible
    algo.last_training_action_time = now


def process_consolidated_data(algo):
    """
    Per-bar training hook (per the original note: invoked from
    OnDataConsolidated in main.py when minute bars + consolidators are used).

    Does nothing while the algorithm is warming up. Otherwise it loads the
    training data on first use and runs one training step, provided the
    data set holds more than one bar.
    """
    if algo.IsWarmingUp:
        return

    needs_load = not algo.training_data_loaded
    if needs_load:
        load_training_data(algo)
        algo.training_data_loaded = True

    bar_count = algo.training_data_length
    if bar_count and bar_count > 1:
        train_step(algo)


def load_training_data(algo):
    """
    Fetch 252 daily bars for ``algo.symbol`` and turn them into the offline
    training set used by the Q-learning / MBPO loop.

    On success the closes, volumes, RSI, MACD, their discretized bins, the
    SMA-cross flag and the stacked raw state matrix are stored on ``algo``
    (``tr_*`` attributes), together with ``training_data_length`` and the
    starting ``training_index``.

    When too little history is available, or anything raises, the problem
    is logged and training is marked complete with an empty data set.
    """
    def _mark_unusable():
        # Shared by both failure paths: nothing to train on.
        algo.training_complete = True
        algo.training_data_length = 0
        algo.training_index = 0

    try:
        bars = algo.History(algo.symbol, 252, Resolution.Daily)
        if bars.empty or len(bars) < algo.config.min_training_days:
            algo.Log("Not enough daily historical data to train.")
            _mark_unusable()
            return

        algo.Log(f"Loaded {len(bars)} daily bars for training.")
        bars = bars.loc[algo.symbol]

        algo.tr_closes = bars['close'].values
        algo.tr_volumes = bars['volume'].values

        close_prices = pd.Series(algo.tr_closes)
        fast_average = close_prices.rolling(algo.config.sma_short_period).mean()
        slow_average = close_prices.rolling(algo.config.sma_long_period).mean()

        # Indicators computed over the whole history
        algo.tr_rsi = algo.CalculateRSI(close_prices, algo.config.rsi_period)
        algo.tr_macd = algo.CalculateMACD(close_prices)

        # Bin each feature according to the configured resolution
        bin_counts = algo.config.state_config
        algo.tr_price_bins = algo.Discretize(close_prices, bin_counts['price_bins'])
        algo.tr_volume_bins = algo.Discretize(pd.Series(algo.tr_volumes), bin_counts['volume_bins'])
        algo.tr_rsi_bins = algo.Discretize(algo.tr_rsi, bin_counts['rsi_bins'])
        algo.tr_macd_bins = algo.Discretize(algo.tr_macd, bin_counts['macd_bins'])

        # 1 where the fast SMA is above the slow SMA, else 0
        fast_above_slow = (fast_average > slow_average).astype(int)
        algo.tr_sma_cross = fast_above_slow.fillna(0).values

        # One row per bar, one column per feature
        feature_columns = (
            algo.tr_price_bins,
            algo.tr_volume_bins,
            algo.tr_rsi_bins,
            algo.tr_macd_bins,
            algo.tr_sma_cross,
        )
        algo.tr_raw_states = np.column_stack(feature_columns)

        algo.training_data_length = len(algo.tr_closes)
        # Begin at bar 50 (or earlier for short histories) so the
        # indicators have some lookback behind them
        algo.training_index = min(50, algo.training_data_length - 2)

        algo.Log(f"Training data length: {algo.training_data_length}")
        algo.Log(f"Starting Episode {algo.current_episode+1}.")

    except Exception as exc:
        algo.Log(f"Error loading training data: {str(exc)}")
        _mark_unusable()


def train_step(algo):
    """
    Executes one Q-learning step over the offline training dataset.
    Updates the Q-table and environment models (MBPO).

    Each call does exactly one of the following:
      * nothing, if training has already been marked complete;
      * handles the "too many negative episodes" condition (model reset or
        early stop);
      * closes the current episode when the end of the data is reached
        (logging, early-stopping checks, start of the next episode);
      * performs a single epsilon-greedy Q-learning update on the
        transition ``training_index -> training_index + 1``.

    Any exception raised during the Q-learning update is logged and the
    current episode is forced to end on the next call.
    """
    # Once training is complete every further call must be a no-op.
    # Without this guard the terminal branches below would run again on
    # each call: extra episode rewards would be appended, epsilon decayed,
    # and the Q-table evaluated and saved over and over.
    if getattr(algo, "training_complete", False):
        return

    if algo.training_data_length is None or algo.training_index is None:
        algo.Log("Training data not loaded. Skipping train_step.")
        return
    if algo.training_data_length <= 1:
        algo.Log("Insufficient training data length. Stopping training.")
        algo.training_complete = True
        return

    # Early exit if negative rewards accumulate
    if algo.consecutive_negative_rewards >= 5:
        if algo.model_reset_count < 2:
            # reset_models() also zeroes the negative-reward counter
            reset_models(algo)
        else:
            algo.Log("Early stopping due to consecutive negative rewards.")
            algo.training_complete = True
            evaluate_training_results(algo)
            save_q_table(algo)
        return

    # Check end of dataset (episode boundary)
    if algo.training_index >= algo.training_data_length - 1:
        # Conclude the episode
        algo.episode_rewards.append(algo.current_episode_reward)
        algo.current_episode += 1
        algo.epsilon = max(algo.epsilon_min, algo.epsilon * algo.epsilon_decay)

        # Logging
        if (algo.current_episode % algo.config.log_training_frequency == 0
            or algo.current_episode == 1
            or algo.current_episode == algo.config.episodes):

            algo.Log(f"Episode {algo.current_episode} done. "
                     f"Reward: {algo.current_episode_reward:.2f}, "
                     f"Epsilon: {algo.epsilon:.3f}")

        # Log model error (median of this episode's errors), then start a
        # fresh error list for the next episode
        if algo.model_predictions and algo.model_prediction_errors:
            avg_error = np.median(algo.model_prediction_errors)
            algo.Plot("Training Rewards", "Model Error", avg_error)
            algo.Log(f"Model prediction error: {avg_error:.4f}")
            algo.model_prediction_errors = []

        update_learning_curve(algo)

        # Check for improvement or resets
        if algo.current_episode > 1:
            recent_window = min(3, len(algo.episode_rewards))
            current_avg_reward = np.mean(algo.episode_rewards[-recent_window:])

            if current_avg_reward <= 0:
                algo.consecutive_negative_rewards += 1
            else:
                algo.consecutive_negative_rewards = 0

            if current_avg_reward > algo.best_avg_reward + algo.config.min_improvement:
                algo.best_avg_reward = current_avg_reward
                algo.no_improvement_count = 0
            else:
                algo.no_improvement_count += 1

            if algo.no_improvement_count >= algo.config.patience:
                if algo.model_reset_count < 2:
                    reset_models(algo)
                    algo.no_improvement_count = 0
                else:
                    algo.Log(f"No improvement for {algo.config.patience} episodes. Stopping early.")
                    algo.training_complete = True
                    evaluate_training_results(algo)
                    save_q_table(algo)
                    return

        # Check total episodes limit
        if algo.current_episode >= algo.config.episodes:
            algo.training_complete = True
            evaluate_training_results(algo)
            save_q_table(algo)
            return

        # Otherwise, reset for the next training episode
        algo.training_index = min(50, algo.training_data_length - 2)
        algo.current_episode_reward = 0
        algo.Log(f"Starting Episode {algo.current_episode+1}.")
        return

    # Actual Q-learning logic
    i = algo.training_index
    try:
        current_state_raw = algo.tr_raw_states[i].copy()
        current_state_tuple = tuple(current_state_raw)

        algo.days_since_model_update += 1

        # Periodic environment model updates
        if (algo.days_since_model_update >= algo.config.model_train_freq
            and len(algo.replay_buffer) >= algo.config.min_real_samples):

            success = update_dynamics_model(algo)
            if not success:
                algo.model_update_failures += 1
                if algo.model_update_failures >= algo.config.max_failed_model_updates:
                    if algo.model_reset_count < 2:
                        reset_models(algo)
                    else:
                        algo.Log("Too many model update failures. Stopping training.")
                        algo.training_complete = True
                        evaluate_training_results(algo)
                        save_q_table(algo)
                    # No Q-update on this call once the failure limit is hit
                    return
            else:
                algo.model_update_failures = 0

            algo.days_since_model_update = 0

        # Epsilon-greedy policy
        if random.random() < algo.epsilon:
            action = random.choice(algo.actions)
        else:
            action = int(np.argmax(algo.q_table[current_state_tuple]))

        # Defensive: the episode-boundary check above should already
        # prevent stepping past the last state
        if i + 1 >= len(algo.tr_raw_states):
            algo.Log(f"Index {i+1} out of range in training data.")
            algo.training_complete = True
            return

        real_next_state_raw = algo.tr_raw_states[i+1].copy()
        real_next_state_tuple = tuple(real_next_state_raw)

        reward = algo.CalculateReward(action, algo.tr_closes[i], algo.tr_closes[i+1])
        reward *= algo.config.reward_scaling

        # Store in replay buffer
        algo.replay_buffer.add(current_state_raw, action, reward, real_next_state_raw)

        # Optionally measure model prediction error
        if algo.transition_model.trained and algo.reward_model.trained:
            if (not algo.state_normalizer_fitted) and (len(algo.replay_buffer) >= 100):
                # sample() returns 6 values; only the states are needed here
                states, _, _, _, _, _ = algo.replay_buffer.sample(min(len(algo.replay_buffer), 1000))
                algo.state_normalizer.fit(states)
                algo.state_normalizer_fitted = True

            if algo.state_normalizer_fitted:
                try:
                    normalized_state = algo.state_normalizer.transform(
                        current_state_raw.reshape(1, -1)
                    )[0]
                    predicted_next_state = algo.transition_model.predict(normalized_state, action)
                    if predicted_next_state is not None:
                        normalized_real_next = algo.state_normalizer.transform(
                            real_next_state_raw.reshape(1, -1)
                        )[0]
                        prediction_error = np.mean((normalized_real_next - predicted_next_state)**2)
                        # Clip extreme values
                        prediction_error = min(prediction_error, 100.0)
                        algo.model_prediction_errors.append(prediction_error)
                        algo.model_predictions.append((normalized_state, action,
                                                       predicted_next_state,
                                                       normalized_real_next))
                except Exception:
                    # Best-effort diagnostics only: a failed error
                    # measurement must not interrupt the Q-learning update
                    pass

        # Update episode reward
        algo.current_episode_reward += reward

        # For debugging, plot reward occasionally
        if i % 10 == 0:
            algo.Plot("Training Rewards", "Reward", reward)

        # Q-learning update
        best_next_action = int(np.argmax(algo.q_table[real_next_state_tuple]))
        td_target = reward + algo.config.discount_factor * algo.q_table[real_next_state_tuple][best_next_action]
        td_error = td_target - algo.q_table[current_state_tuple][action]
        algo.q_table[current_state_tuple][action] += algo.config.learning_rate * td_error

        # Generate synthetic experience if environment models are ready
        if (algo.transition_model.trained and algo.reward_model.trained
            and len(algo.replay_buffer) >= algo.config.min_real_samples):
            generate_synthetic_experience(algo, current_state_raw, action)

        # Move to next step
        algo.training_index += 1

    except Exception as e:
        algo.Log(f"Error in train_step: {str(e)}")
        # Jump to the end of the data so the next call closes the episode
        algo.training_index = algo.training_data_length


def reset_models(algo):
    """
    Resets transition and reward models upon repeated failures
    or negative reward scenarios. Increases epsilon to encourage
    more exploration afterward.

    Side effects on ``algo``: new ``transition_model`` / ``reward_model``
    instances, an emptied replay buffer, a fresh unfitted state normalizer,
    cleared model diagnostics, ``model_update_failures`` and
    ``consecutive_negative_rewards`` set to 0, ``model_reset_count``
    incremented, and ``epsilon`` raised to at least 0.5 (at most 1.0).
    """
    algo.Log(f"Resetting models (reset #{algo.model_reset_count+1})")

    # Recreate model instances
    algo.transition_model = algo.TransitionModel(algo.state_dim, algo.action_dim)
    algo.reward_model = algo.RewardModel(algo.state_dim, algo.action_dim)

    # Clear replay buffer so old samples are discarded
    algo.replay_buffer.clear()

    # Reset normalizer and counters
    algo.state_normalizer = StandardScaler()
    algo.state_normalizer_fitted = False
    algo.model_update_failures = 0
    algo.model_predictions = []
    algo.model_prediction_errors = []

    algo.model_reset_count += 1
    algo.consecutive_negative_rewards = 0

    # Increase epsilon for more exploration, but never beyond 1.0.
    # Epsilon is used as a probability (it is compared against
    # random.random(), which is always below 1.0), so a value such as 1.5
    # adds no exploration; it only delays the point at which the
    # per-episode decay brings epsilon back under 1.0.
    algo.epsilon = min(1.0, max(0.5, algo.epsilon * 1.5))


def update_dynamics_model(algo):
    """
    Fit the transition and reward models on a batch drawn from the
    replay buffer.

    Returns False straight away when the buffer holds fewer than
    ``algo.config.min_real_samples`` entries. Otherwise up to 2000 samples
    are drawn, both models are trained (the reward model is trained even
    if the transition model reports failure), the outcome is logged, and
    True is returned only when both training calls report success.
    """
    buffer = algo.replay_buffer
    if len(buffer) < algo.config.min_real_samples:
        return False

    batch = buffer.sample(min(len(buffer), 2000))
    # sample() yields six values; the last two are not needed here
    states, actions, rewards, next_states, _indices, _weights = batch

    algo.Log(f"Training environment models with {len(states)} samples")

    transition_ok = algo.transition_model.train(states, actions, next_states)
    reward_ok = algo.reward_model.train(states, actions, rewards)

    if not (transition_ok and reward_ok):
        algo.Log("Failed to update environment models")
        return False

    algo.Log("Successfully updated environment models")
    return True


def generate_synthetic_experience(algo, init_state, init_action):
    """
    Run a few short imagined rollouts (MBPO style) starting from
    ``init_state`` / ``init_action`` and apply Q-learning updates to the
    transitions predicted by the environment models.

    Nothing happens unless both models are trained and the state
    normalizer is fitted. At most 5 rollouts of at most 3 steps each are
    generated. A rollout that raises is abandoned silently and the next
    one is started.
    """
    both_trained = algo.transition_model.trained and algo.reward_model.trained
    if not both_trained or not algo.state_normalizer_fitted:
        return

    # State columns 0..3 are bin indices, limited by these config entries
    bin_limit_keys = ('price_bins', 'volume_bins', 'rsi_bins', 'macd_bins')

    rollout_count = min(algo.config.synthetic_data_ratio, 5)
    for _ in range(rollout_count):
        current = init_state.copy()
        chosen_action = init_action
        try:
            current_norm = algo.state_normalizer.transform(current.reshape(1, -1))[0]
            horizon = min(algo.config.model_rollout_length, 3)
            for _rollout_step in range(horizon):
                predicted_norm = algo.transition_model.predict(current_norm, chosen_action)
                if predicted_norm is None:
                    break

                # Back to raw scale, then snap onto the discrete state grid
                predicted = algo.state_normalizer.inverse_transform(
                    predicted_norm.reshape(1, -1)
                )[0]
                predicted = np.round(predicted).astype(int)
                for column, key in enumerate(bin_limit_keys):
                    highest_bin = algo.config.state_config[key] - 1
                    predicted[column] = max(0, min(highest_bin, predicted[column]))
                # Last column is the binary SMA-cross flag
                predicted[4] = 1 if predicted[4] > 0 else 0

                # Imagined reward for taking the action in the current state
                reward = algo.reward_model.predict(current_norm, chosen_action) * algo.config.reward_scaling

                current_key = tuple(current.astype(int))
                next_key = tuple(predicted.astype(int))

                next_values = algo.q_table[next_key]
                greedy_next = int(np.argmax(next_values))
                target = reward + algo.config.discount_factor * next_values[greedy_next]
                delta = target - algo.q_table[current_key][chosen_action]

                # Synthetic transitions get half the normal learning rate
                algo.q_table[current_key][chosen_action] += algo.config.learning_rate * 0.5 * delta

                # Step forward inside the model
                current, current_norm = predicted, predicted_norm

                # Epsilon-greedy choice of the next imagined action
                if random.random() < algo.epsilon:
                    chosen_action = random.choice(algo.actions)
                else:
                    chosen_action = int(np.argmax(algo.q_table[next_key]))
        except Exception:
            continue


def update_learning_curve(algo):
    """
    Plot and log the mean of the most recent (up to three) episode rewards.

    Does nothing when no episode has finished yet.
    """
    rewards = algo.episode_rewards
    if not rewards:
        return

    # A non-empty list always yields a window of 1 to 3 episodes
    recent = rewards[-min(3, len(rewards)):]
    moving_avg = sum(recent) / len(recent)
    last_reward = rewards[-1]

    algo.Plot("Training Rewards", "Moving Average", moving_avg)
    algo.Log(f"Learning curve updated: Last Episode={last_reward:.4f}, MA={moving_avg:.4f}")


def evaluate_training_results(algo):
    """
    Log a summary of the finished training run.

    Three messages are written: mean / max / min episode reward (all 0
    when no episode was recorded), the median model prediction error (or a
    note that none was recorded), and whether the mean reward reached
    ``algo.config.min_avg_reward``.
    """
    rewards = algo.episode_rewards
    if rewards:
        avg_reward, max_reward, min_reward = np.mean(rewards), max(rewards), min(rewards)
    else:
        avg_reward = max_reward = min_reward = 0

    summary = (f"Training complete. Avg={avg_reward:.2f}, "
               f"Max={max_reward:.2f}, Min={min_reward:.2f}.")
    algo.Log(summary)

    has_model_errors = algo.model_predictions and algo.model_prediction_errors
    if not has_model_errors:
        algo.Log("No valid model prediction errors recorded.")
    else:
        # The logged figure is the median of the recorded errors
        avg_model_error = np.median(algo.model_prediction_errors)
        algo.Log(f"Model average prediction error: {avg_model_error:.4f}")

    reached_target = avg_reward >= algo.config.min_avg_reward
    algo.Log("Training reached min average reward. Ready to trade."
             if reached_target
             else "Training did not meet minimum average reward.")


def save_q_table(algo):
    """
    Saves the Q-table to the ObjectStore for retrieval by backtest mode.

    States whose action values are all zero are dropped. The remaining
    table (action values as plain lists) is stored together with a small
    metadata dict, pickled, base64-encoded and written under the key
    ``"qtable.pkl"``. Any error is logged; nothing is raised.
    """
    try:
        # Keep only states that were actually updated, and store their
        # values as plain lists for simpler serialization
        cleaned_q_table = {
            state: np.asarray(values).tolist()
            for state, values in algo.q_table.items()
            if np.any(np.asarray(values) != 0)
        }

        # Include some metadata for better debugging
        save_data = {
            "q_table": cleaned_q_table,
            "metadata": {
                "episode": algo.current_episode,
                "avg_reward": np.mean(algo.episode_rewards) if algo.episode_rewards else 0,
                "state_count": len(cleaned_q_table),
                "training_date": str(algo.Time)
            }
        }

        # Count, per action, how many states prefer it. The list grows on
        # demand: with a fixed three-slot list, a best action with index
        # >= 3 raised IndexError here and the table was never saved.
        action_labels = ["Hold", "Long", "Short"]
        action_counts = [0] * len(action_labels)
        for values in cleaned_q_table.values():
            best_action = int(np.argmax(values))
            if best_action >= len(action_counts):
                action_counts.extend([0] * (best_action + 1 - len(action_counts)))
            action_counts[best_action] += 1

        total = sum(action_counts)
        if total > 0:
            algo.Log(f"Saving Q-table with {len(cleaned_q_table)} states")
            shares = []
            for index, count in enumerate(action_counts):
                # Actions beyond the three named ones are reported by index
                label = action_labels[index] if index < len(action_labels) else f"Action{index}"
                shares.append(f"{label}={count/total*100:.1f}%")
            algo.Log("Action distribution: " + ", ".join(shares))

        qtable_data = pickle.dumps(save_data)
        qtable_data_b64 = base64.b64encode(qtable_data).decode('utf-8')

        if algo.ObjectStore is not None:
            algo.ObjectStore.Save("qtable.pkl", qtable_data_b64)
            algo.Log("QTable saved to ObjectStore as 'qtable.pkl'")
        else:
            algo.Log("ObjectStore unavailable. QTable not saved.")
    except Exception as e:
        algo.Log(f"Error saving QTable: {str(e)}")