using System;
using System.Collections.Generic;
using QuantConnect.Data.Consolidators;
using QuantConnect.Indicators;
using QuantConnect.Data.Market;
using Accord.MachineLearning.VectorMachines;
using Accord.MachineLearning.VectorMachines.Learning;
namespace QuantConnect.Algorithm
{
/*
* Linear support vector machine prediction using historical data:
* an example of using the Accord.NET machine-learning library to predict the next day's
* direction (up/down) from a RollingWindow of daily returns.
* Also includes an example of an ATR-based stop loss.
* - Gene Wildhart
*/
/// <summary>
/// Each morning, trains a linear SVM on recent daily up/down moves and holds a long position
/// on days the model predicts an up move, protected by an ATR-based stop loss.
/// </summary>
public class ExampleMachineLearning : QCAlgorithm
{
    // Number of historical samples the SVM is trained on each morning.
    private const int TrainSize = 70;
    // Number of consecutive daily up/down signals that make up one input vector.
    private const int InputSize = 4;
    // Stop-loss distance below the close, as a multiple of the ATR.
    private const decimal StopPct = 0.6m;
    // Fraction of total portfolio value committed when going long.
    private const decimal TradePct = 1.0m;

    private readonly string _symbol = "SPY";
    // Long stop level; recomputed each morning the model predicts an up day.
    private decimal _stopPrice;
    // Daily bars, newest at index 0. Training reads indices up to TrainSize + InputSize,
    // so exactly TrainSize + InputSize + 1 bars are needed before the first prediction.
    private readonly RollingWindow<TradeBar> _dayWindow = new RollingWindow<TradeBar>(TrainSize + InputSize + 1);
    private AverageTrueRange _atr;
    private SupportVectorMachine _svm;

    /// <summary>
    /// Initialize QuantConnect Strategy: minute data for the symbol, a 20-period simple ATR on
    /// daily resolution, and a one-day consolidator that feeds <see cref="OnDataDay"/>.
    /// </summary>
    public override void Initialize()
    {
        SetStartDate(2010, 1, 1);
        SetEndDate(DateTime.Now.Date.AddDays(-1));
        SetCash(10000);
        AddSecurity(SecurityType.Equity, _symbol, Resolution.Minute);
        _atr = ATR(_symbol, 20, MovingAverageType.Simple, Resolution.Daily);

        // Consolidate minute bars into daily bars for the rolling history.
        var dayConsolidator = new TradeBarConsolidator(TimeSpan.FromDays(1));
        dayConsolidator.DataConsolidated += OnDataDay;
        SubscriptionManager.AddConsolidator(_symbol, dayConsolidator);
    }

    /// <summary>
    /// OnData callback (called every minute). The prediction is made once per day at 9:31,
    /// one minute after the market opens; the stop loss is checked on every minute bar.
    /// </summary>
    /// <param name="data">The minute bars for the current time slice.</param>
    public void OnData(TradeBars data)
    {
        // Wait until both the ATR and the daily history are fully populated.
        if (!_atr.IsReady || !_dayWindow.IsReady) return;
        // A slice without a bar for our symbol has nothing to act on (indexing it would throw).
        if (!data.ContainsKey(_symbol)) return;

        var bar = data[_symbol];

        if (data.Time.Hour == 9 && data.Time.Minute == 31)
        {
            TrainModel();

            double output = _svm.Compute(CurrentInputs());
            Debug("Prediction: " + output);

            if (output > 0)
            {
                GoLong(bar.Close);
            }
            else
            {
                // Exit Market
                Liquidate(_symbol);
                Debug("Exiting Market");
            }
        }

        CheckStopLoss(bar);
    }

    /// <summary>
    /// Returns +1 if the close at window index <paramref name="index"/> is above the close at
    /// <c>index + 1</c> (the previous day), otherwise -1. For positive prices this is the
    /// same as testing whether the daily return is positive, without dividing.
    /// </summary>
    private int DayDirection(int index)
    {
        return _dayWindow[index].Close > _dayWindow[index + 1].Close ? 1 : -1;
    }

    /// <summary>
    /// Builds a training set from the daily history and fits a new linear SVM to it.
    /// Sample i uses the directions of days i+1 .. i+InputSize as inputs and the
    /// direction of day i (one day later) as its target.
    /// </summary>
    private void TrainModel()
    {
        var inputs = new double[TrainSize][];
        var targets = new int[TrainSize];

        for (int i = 0; i < TrainSize; i++)
        {
            // Each sample needs its own array: reusing a single buffer would make every
            // row of the training set alias the same (last-written) vector.
            var sample = new double[InputSize];
            for (int j = 0; j < InputSize; j++)
            {
                sample[j] = DayDirection(i + j + 1);
            }
            inputs[i] = sample;
            targets[i] = DayDirection(i);
        }

        Debug("Training...");
        _svm = new SupportVectorMachine(inputs: InputSize);
        var teacher = new LinearCoordinateDescent(_svm, inputs, targets);
        teacher.UseComplexityHeuristic = true;
        teacher.Run();
    }

    /// <summary>
    /// Returns the most recent InputSize daily directions, aligned the same way as the
    /// training inputs, for predicting the next day.
    /// </summary>
    private double[] CurrentInputs()
    {
        var current = new double[InputSize];
        for (int i = 0; i < InputSize; i++)
        {
            current[i] = DayDirection(i);
        }
        return current;
    }

    /// <summary>
    /// Updates the long stop from <paramref name="price"/> and the ATR, then ensures a long
    /// position: buys from flat, or buys back the short plus the long size if short.
    /// An existing long position is left unchanged (only its stop is moved).
    /// </summary>
    private void GoLong(decimal price)
    {
        _stopPrice = price - StopPct * _atr;

        var holdings = Securities[_symbol].Holdings;
        var targetQuantity = Portfolio.TotalPortfolioValue / price * TradePct;

        if (holdings.Quantity == 0)
        {
            Order(_symbol, targetQuantity);
            Debug("Buy Shares: " + Securities[_symbol].Holdings.Quantity);
        }
        else if (holdings.Quantity < 0)
        {
            // Cover the short and establish the full long size in one order.
            Order(_symbol, ((decimal)-holdings.Quantity) + targetQuantity);
            Debug("Buy Shares: " + Securities[_symbol].Holdings.Quantity);
        }
    }

    /// <summary>
    /// Closes the position with a market order when the bar's extreme crosses the stop.
    /// </summary>
    private void CheckStopLoss(TradeBar bar)
    {
        var holdings = Securities[_symbol].Holdings;
        if (holdings.Quantity == 0) return;

        if (holdings.IsLong)
        {
            if (bar.Low <= _stopPrice)
            {
                MarketOrder(_symbol, -holdings.Quantity);
                Debug("Hit StopLoss: " + bar.Low);
            }
        }
        else if (holdings.IsShort)
        {
            // NOTE(review): _stopPrice is only ever computed for long positions (see GoLong);
            // this strategy never opens shorts, so this branch is defensive only. A real short
            // stop would need to sit above the entry price.
            if (bar.High >= _stopPrice)
            {
                MarketOrder(_symbol, -holdings.Quantity);
                Debug("Hit StopLoss: " + bar.High);
            }
        }
    }

    /// <summary>
    /// OnDataDay callback (called once per consolidated day): appends the daily bar to the
    /// rolling history used for training and prediction.
    /// </summary>
    private void OnDataDay(object sender, TradeBar consolidated)
    {
        _dayWindow.Add(consolidated);
    }
}
}