# Overall Statistics
from AlgorithmImports import *
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from io import StringIO
import json

class IBMVolatilityCollector(QCAlgorithm):
    def Initialize(self) -> None:
        """Configure the backtest window, IBM equity/option subscriptions, state and schedule."""
        # Fixed backtest window: 2005-01-01 through 2024-02-29.
        self.SetStartDate(2005, 1, 1)
        self.SetEndDate(2024, 2, 29)
        self.SetCash(100000)

        # IBM shares at daily resolution, unlevered.
        self.ibm = self.AddEquity("IBM", Resolution.Daily)
        self.ibm.SetLeverage(1)
        self.ibm_symbol = self.ibm.Symbol

        def option_universe(universe):
            # Weeklys included, 20 strikes either side of ATM, expiries 0-90 days out.
            return universe.IncludeWeeklys().Strikes(-20, 20).Expiration(0, 90)

        self.option = self.AddOption("IBM", Resolution.Daily)
        self.option.SetFilter(option_universe)

        # Collection state. LoadEarningsData reads StartDate/EndDate, so it must
        # run after SetStartDate/SetEndDate above.
        self.volatility_data = []                      # one dict per collected sample
        self.earnings_dates = self.LoadEarningsData()
        self.historical_prices = {}                    # 'YYYY-MM-DD' -> IBM close
        self.missing_dates = []

        # datetime.min makes the first CollectVolatilityData call attempt a save.
        self.last_save = datetime.min

        # Daily scheduled callback, 30 minutes after IBM's market open.
        self.Schedule.On(
            self.DateRules.EveryDay("IBM"),
            self.TimeRules.AfterMarketOpen("IBM", 30),
            self.CollectVolatilityData,
        )

        self.Log("IBM Volatility Collector initialized")
        
    def LoadEarningsData(self):
        """Return IBM earnings dates as a list of ``datetime`` objects.

        First tries the ObjectStore CSV ``team6/ibm_earnings_history_final.csv``,
        reading its ``report_date`` column (``YYYY-MM-DD``, quoted or not).
        Rows whose date cannot be parsed are logged and skipped.

        Falls back to placeholder dates on the 15th of January, April, July and
        October of each year inside [StartDate, EndDate] when any of these holds:
        the key is missing, the file cannot be read, it has no ``report_date``
        column, or it yields no valid dates.
        """
        import csv  # local import: the module header does not import csv

        earnings_key = "team6/ibm_earnings_history_final.csv"
        if self.ObjectStore.ContainsKey(earnings_key):
            try:
                # csv.reader handles quoted fields, embedded commas and \r\n endings.
                reader = csv.reader(StringIO(self.ObjectStore.Read(earnings_key)))
                header = [h.strip().lstrip('\ufeff').strip('"') for h in next(reader, [])]
                report_date_index = header.index('report_date')  # ValueError if absent

                # Separate list so a failure midway never mixes with the defaults.
                parsed_dates = []
                for row in reader:
                    if len(row) <= report_date_index:
                        continue  # blank or truncated row
                    date_str = row[report_date_index].strip().strip('"')
                    try:
                        parsed_dates.append(datetime.strptime(date_str, '%Y-%m-%d'))
                    except ValueError:
                        self.Log(f"Error parsing date: {date_str}")

                if parsed_dates:
                    self.Log(f"Loaded {len(parsed_dates)} earnings dates from ObjectStore")
                    return parsed_dates
                self.Log("No valid earnings dates found in ObjectStore")
            except Exception as e:
                self.Log(f"Error loading earnings dates: {str(e)}")

        # Fallback: quarterly placeholders on the 15th of Jan/Apr/Jul/Oct.
        self.Log("Using default quarterly earnings dates")
        start_date = self.StartDate
        end_date = self.EndDate
        candidates = (
            datetime(year, month, 15)
            for year in range(start_date.year, end_date.year + 1)
            for month in (1, 4, 7, 10)
        )
        return [d for d in candidates if start_date <= d <= end_date]
        
    def OnData(self, slice: Slice) -> None:
        """Record IBM's daily close and, when an option chain arrives, one ATM IV sample.

        Steps:
          1. Store the IBM close in ``self.historical_prices`` keyed by the
             bar's date ('YYYY-MM-DD'); these feed CalculateRealizedVolatility.
          2. Look up the IBM option chain and take the strike closest to the
             underlying price. Pick a call and a put at that strike with the
             SAME expiry: the nearest expiry strictly after today.
          3. Append a record to ``self.volatility_data`` containing:
             - the average call/put IV
             - 20-day realized volatility
             - the nearest earnings date and the day offset to it
             - the earnings surprise
        """
        if self.ibm_symbol in slice.Bars:
            bar = slice.Bars[self.ibm_symbol]
            date_key = bar.Time.date().strftime('%Y-%m-%d')
            self.historical_prices[date_key] = bar.Close

        # Option chains are keyed by the canonical option symbol created by
        # AddOption, not by the underlying equity symbol.
        option_symbol = self.option.Symbol
        if not slice.OptionChains.ContainsKey(option_symbol):
            return

        contracts = list(slice.OptionChains[option_symbol])
        if not contracts:
            return

        atm_strike = self.FindAtmStrike(contracts)
        if atm_strike is None:  # underlying price not available yet
            return

        # Group ATM contracts by expiry so the call and put share an expiration.
        # Contracts expiring today (or earlier) are skipped: almost no time
        # value is left, so their IV is not comparable.
        today = self.Time.date()
        calls_by_expiry = {}
        puts_by_expiry = {}
        for contract in contracts:
            if abs(contract.Strike - atm_strike) >= 0.01:
                continue
            expiry = contract.Expiry.date()
            if expiry <= today:
                continue
            if contract.Right == OptionRight.Call:
                calls_by_expiry[expiry] = contract
            elif contract.Right == OptionRight.Put:
                puts_by_expiry[expiry] = contract

        paired_expiries = sorted(calls_by_expiry.keys() & puts_by_expiry.keys())
        if not paired_expiries:
            return
        nearest_expiry = paired_expiries[0]
        atm_call = calls_by_expiry[nearest_expiry]
        atm_put = puts_by_expiry[nearest_expiry]

        # A non-positive IV is not a usable quote; don't let it skew the average.
        if atm_call.ImpliedVolatility <= 0 or atm_put.ImpliedVolatility <= 0:
            return

        # ATM IV = mean of the call and put IVs at the same strike and expiry.
        atm_iv = (atm_call.ImpliedVolatility + atm_put.ImpliedVolatility) / 2

        # 20-day close-to-close realized volatility (None until enough history).
        realized_vol = self.CalculateRealizedVolatility(20)

        # days_to_earnings = earnings date minus now: > 0 before, 0 on, < 0 after.
        nearest_earnings, days_to_earnings = self.FindNearestEarnings(self.Time)
        earnings_surprise = self.CalculateEarningsSurprise(days_to_earnings)

        self.volatility_data.append({
            'date': self.Time.strftime('%Y-%m-%d'),
            'symbol': 'IBM',
            'implied_volatility': atm_iv,
            'realized_volatility': realized_vol,
            'days_to_earnings': days_to_earnings,
            'earnings_date': nearest_earnings.strftime('%Y-%m-%d') if nearest_earnings else None,
            'earnings_surprise': earnings_surprise,
            'atm_strike': atm_strike,
            'underlying_price': self.Securities[self.ibm_symbol].Price
        })

        if len(self.volatility_data) % 20 == 0:
            self.Log(f"Collected {len(self.volatility_data)} volatility data points")
    
    def FindAtmStrike(self, chain):
        """Return the strike in *chain* closest to IBM's current price.

        *chain* is any iterable of option contracts exposing ``.Strike``.
        Returns None when the underlying price is 0 (not yet available) or the
        chain has no contracts. Equidistant strikes resolve to the lower one,
        so the result does not depend on iteration order.
        """
        current_price = self.Securities["IBM"].Price
        if current_price == 0:
            return None

        strikes = {contract.Strike for contract in chain}
        if not strikes:
            return None  # min() on an empty sequence would raise ValueError

        return min(strikes, key=lambda strike: (abs(strike - current_price), strike))
    
    def CalculateRealizedVolatility(self, days=20):
        """Return annualized close-to-close realized volatility over *days* log returns.

        Takes the most recent ``days + 1`` closes from ``self.historical_prices``.
        Keys are 'YYYY-MM-DD', so lexical order is chronological. The sample
        standard deviation (ddof=1) of the ``days`` log returns is annualized
        with sqrt(252).

        Returns None when any of these holds:
          - ``days`` < 2 (a sample std needs at least two returns)
          - fewer than ``days + 1`` closes are stored
          - a close is non-positive
          - an unexpected error occurs (reported via ``self.Error``)
        """
        if days < 2:
            return None

        try:
            if len(self.historical_prices) < days + 1:
                return None

            recent_dates = sorted(self.historical_prices)[-(days + 1):]
            prices = np.array([self.historical_prices[d] for d in recent_dates], dtype=float)

            # log() of a non-positive close would yield -inf/nan silently.
            if np.any(prices <= 0):
                return None

            log_returns = np.diff(np.log(prices))
            return float(np.std(log_returns, ddof=1) * np.sqrt(252))

        except Exception as e:
            self.Error(f"Error calculating realized volatility: {str(e)}")
            return None
    
    def FindNearestEarnings(self, current_date):
        """Return ``(nearest_earnings_date, days_diff)`` relative to *current_date*.

        ``days_diff`` is the number of calendar days from *current_date* to the
        earnings date, computed on dates with the time of day ignored.
        It is positive when earnings are still ahead, 0 on the earnings day and
        negative afterwards. The original ``datetime`` from
        ``self.earnings_dates`` is returned unchanged. Returns ``(None, None)``
        when no earnings dates are loaded.
        """
        if not self.earnings_dates:
            return None, None

        def as_date(value):
            # datetime -> date; a plain date passes through unchanged. Without
            # this, a time-of-day component makes the earnings day itself -1.
            return value.date() if isinstance(value, datetime) else value

        today = as_date(current_date)
        nearest_date = min(self.earnings_dates, key=lambda d: abs((as_date(d) - today).days))
        days_diff = (as_date(nearest_date) - today).days

        return nearest_date, days_diff
    
    def CalculateEarningsSurprise(self, days_to_earnings):
        """Return the relative IV change across the current earnings event, or None.

        Defined only 1-10 days after an earnings date, i.e.
        ``-10 <= days_to_earnings <= -1``. Scans ``self.volatility_data`` from
        newest to oldest, restricted to the current event, and returns
        ``(iv_post - iv_day0) / iv_day0``. ``iv_day0`` is the sample recorded
        with ``days_to_earnings == 0``; ``iv_post`` is the one with ``-1``.
        Returns None if either sample is missing or ``iv_day0`` is 0.
        Unexpected errors are reported via ``self.Error``.
        """
        if days_to_earnings is None or days_to_earnings >= 0 or days_to_earnings < -10:
            return None

        try:
            earnings_iv = None
            post_earnings_iv = None

            for data_point in reversed(self.volatility_data):
                point_days = data_point['days_to_earnings']
                # Walking back in time, earlier samples of the current event have
                # days in [days_to_earnings, 0]. A positive value is before the
                # event; a more negative value belongs to an earlier event.
                if point_days is None or point_days > 0 or point_days < days_to_earnings:
                    break
                if point_days == 0:
                    earnings_iv = data_point['implied_volatility']
                elif point_days == -1:
                    post_earnings_iv = data_point['implied_volatility']
                if earnings_iv is not None and post_earnings_iv is not None:
                    break

            if earnings_iv is None or post_earnings_iv is None or earnings_iv == 0:
                return None

            # Surprise = percentage change in IV from earnings day to the day after.
            return (post_earnings_iv - earnings_iv) / earnings_iv

        except Exception as e:
            self.Error(f"Error calculating earnings surprise: {str(e)}")
            return None
    
    def CollectVolatilityData(self):
        """Scheduled daily callback that writes an interim snapshot roughly monthly.

        Samples are gathered in OnData; this callback only persists them. A save
        happens when more than 30 whole days have passed since ``self.last_save``.
        """
        elapsed_days = (self.Time - self.last_save).days
        if elapsed_days <= 30:
            return

        self.SaveVolatilityData("interim")
        self.last_save = self.Time
    
    def OnEndOfAlgorithm(self):
        """End-of-backtest hook: write the 'final' CSV, then log summary statistics."""
        self.SaveVolatilityData(version="final")
        self.LogSummaryStatistics()
    
    def SaveVolatilityData(self, version="final"):
        """Write ``self.volatility_data`` as CSV to ObjectStore key
        ``team6/ibm_volatility_data_{version}.csv``.

        Columns follow the key order of the first record. Every field is quoted,
        with embedded quotes doubled, and rows end with ``\\n``. ``None`` values
        and missing keys become empty fields rather than the literal text
        ``None``. Failures are reported via ``self.Error`` and not raised.
        """
        import csv  # local import: the module header does not import csv

        if not self.volatility_data:
            self.Log("No volatility data to save")
            return

        try:
            headers = list(self.volatility_data[0].keys())

            output = StringIO()
            writer = csv.writer(output, quoting=csv.QUOTE_ALL, lineterminator='\n')
            writer.writerow(headers)
            for item in self.volatility_data:
                row = []
                for key in headers:
                    value = item.get(key)
                    row.append('' if value is None else str(value))
                writer.writerow(row)

            filename = f"team6/ibm_volatility_data_{version}.csv"
            # ObjectStore.Save reports success as a bool; don't claim success blindly.
            if self.ObjectStore.Save(filename, output.getvalue()):
                self.Log(f"Saved {len(self.volatility_data)} volatility data points to {filename}")
            else:
                self.Error(f"ObjectStore did not save {filename}")

        except Exception as e:
            self.Error(f"Error saving volatility data: {str(e)}")
    
    def LogSummaryStatistics(self):
        """Log the date range, IV/RV summary stats, and mean IV/RV by distance to earnings.

        ``days_to_earnings`` is ``(earnings_date - sample_date).days``, so a
        positive value means the sample was taken BEFORE earnings and a
        negative value AFTER. The per-day table covers -10..10 and is listed
        chronologically. Each average uses only the samples where that metric
        is not None; a day with no valid values prints 0.
        """
        if not self.volatility_data:
            return

        self.Log(f"\nVolatility Collection Complete - {len(self.volatility_data)} data points")

        # Date range (volatility_data is non-empty here, so dates is too).
        dates = [datetime.strptime(item['date'], '%Y-%m-%d') for item in self.volatility_data]
        self.Log(f"Date Range: {min(dates).strftime('%Y-%m-%d')} to {max(dates).strftime('%Y-%m-%d')}")

        iv_values = [item['implied_volatility'] for item in self.volatility_data if item['implied_volatility'] is not None]
        if iv_values:
            avg_iv = sum(iv_values) / len(iv_values)
            self.Log(f"Implied Volatility - Avg: {avg_iv:.4f}, Min: {min(iv_values):.4f}, Max: {max(iv_values):.4f}")

        rv_values = [item['realized_volatility'] for item in self.volatility_data if item['realized_volatility'] is not None]
        if rv_values:
            avg_rv = sum(rv_values) / len(rv_values)
            self.Log(f"Realized Volatility - Avg: {avg_rv:.4f}, Min: {min(rv_values):.4f}, Max: {max(rv_values):.4f}")

        earnings_data = [item for item in self.volatility_data if item['days_to_earnings'] is not None]
        self.Log(f"Data Points with Earnings Information: {len(earnings_data)}")

        # Aggregate per day offset in the -10..10 window. IV and RV keep their
        # own counts: RV is None early in the run and must not dilute its mean.
        days_to_earnings = {}
        for item in earnings_data:
            days = item['days_to_earnings']
            if not -10 <= days <= 10:
                continue
            stats = days_to_earnings.setdefault(
                days, {'count': 0, 'iv_sum': 0.0, 'iv_count': 0, 'rv_sum': 0.0, 'rv_count': 0}
            )
            stats['count'] += 1
            if item['implied_volatility'] is not None:
                stats['iv_sum'] += item['implied_volatility']
                stats['iv_count'] += 1
            if item['realized_volatility'] is not None:
                stats['rv_sum'] += item['realized_volatility']
                stats['rv_count'] += 1

        self.Log("\nVolatility by Days to Earnings:")
        # Descending offset = chronological order (10 days before ... 10 days after).
        for days in sorted(days_to_earnings, reverse=True):
            stats = days_to_earnings[days]
            count = stats['count']
            avg_iv = stats['iv_sum'] / stats['iv_count'] if stats['iv_count'] else 0
            avg_rv = stats['rv_sum'] / stats['rv_count'] if stats['rv_count'] else 0
            # Positive offset: earnings still ahead, so the sample is BEFORE earnings.
            days_label = "Before" if days > 0 else ("After" if days < 0 else "On")
            self.Log(f"{abs(days)} Days {days_label} Earnings: Count={count}, Avg IV={avg_iv:.4f}, Avg RV={avg_rv:.4f}")