决策树模型(固定模型)

来源:https://uqer.io/community/share/568dce2d228e5b18e2ba296e

楼主上学时学的是机器学习,现在在BAT做数据挖掘,一直对将机器学习的知识应用到金融领域比较感兴趣。

最近发现了优矿这个平台之后,有点着迷了,通过看大家的策略,也学到些知识。

因为楼主对金融投资认识不多,所以写的策略比较简单粗暴,希望向大家多多学习~

策略:1、不预测具体股价,只预测次日收盘价相比当日收盘价是涨还是跌;2、若预测为涨,则满仓买入(已持有则继续持有);若预测为跌,则全部卖出(清仓)。

方法:基于某只股票(本例为 601872.XSHG)的历史行情与技术因子数据,采用机器学习方法(决策树分类器)挖掘其中规律,预测该股票次日收盘价是涨还是跌。

  1. import numpy as np
  2. from CAL.PyCAL import *
  3. from sklearn.cross_validation import train_test_split
  4. from sklearn.externals import joblib
  5. import pandas as pd
  6. cal = Calendar('China.SSE')
  7. # 第一步:设置基本参数
  8. start = '2015-01-01'
  9. end = '2015-11-01'
  10. capital_base = 1000000
  11. refresh_rate = 1
  12. benchmark = 'HS300'
  13. ##HS300
  14. freq = 'd'
  15. #601872.XSHG HS300
  16. # 第二步:选择主题,设置股票池
  17. universe = ['601872.XSHG', ]
  18. ##训练模型
  19. def model_train(begin_date,end_date):
  20. data1=DataAPI.MktEqudGet(secID=u"601872.XSHG",beginDate=begin_date,endDate=end_date,field=['tradeDate','highestPrice','lowestPrice','openPrice','closePrice','turnoverVol','turnoverRate'],pandas="1")
  21. data2=DataAPI.MktStockFactorsDateRangeGet(secID=u"601872.XSHG",beginDate=begin_date,endDate=end_date,field=['tradeDate','DAVOL5','EMA5','EMA10','MA5','MA20','RSI','VOL5','VOL10','MACD'],pandas="1")
  22. df_data=pd.merge(data1,data2,on='tradeDate')
  23. tmp=[]
  24. for i in range(len(df_data.values)):
  25. mark_1=0
  26. for j in range(len(df_data.values[i])):
  27. if str(df_data.values[i][j])=='nan':
  28. mark_1=1
  29. if mark_1==0:
  30. a=list(df_data.values[i])
  31. a.append(df_data.values[i][4]-df_data.values[i][10])
  32. a.append(df_data.values[i][4]-df_data.values[i][11])
  33. tmp.append(a)
  34. data=tmp
  35. print len(data)
  36. x=[]
  37. y=[]
  38. for i in range(len(data)-1):
  39. if data[i][4]<data[i+1][4]:
  40. y.append(1)
  41. else:
  42. y.append(0)
  43. x.append(data[i][1:])
  44. x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.0, random_state=42)
  45. ##训练模型
  46. from sklearn import tree
  47. clf = tree.DecisionTreeClassifier( max_depth =3 )
  48. clf.fit(x_train,y_train)
  49. y_predict=clf.predict(x_train)
  50. n_1=0
  51. for i in range(len(y_predict)):
  52. if y_train[i]==y_predict[i]:
  53. n_1=n_1+1
  54. n_2=0
  55. for i in range(len(y_predict)):
  56. if y_train[i]==y_predict[i] and y_predict[i]==1:
  57. n_2=n_2+1
  58. joblib.dump(clf, 'clf.model')
  59. return clf,float(n_1)/float( len(y_predict) ),float(n_2)/float( int(sum(y_train)) ) ,float(sum(y_train))/float(len(y_train))
  60. def initialize(account):
  61. ##使用2015年2月1日之前800个交易日的数据进行训练
  62. today='20150201'
  63. train_begin_date = cal.advanceDate(today,'-800B',BizDayConvention.Preceding).strftime('%Y%m%d')
  64. train_end_date = cal.advanceDate(today,'-1B',BizDayConvention.Preceding).strftime('%Y%m%d')
  65. model,acc_rate,recall_rate,balance=model_train(train_begin_date,train_end_date)
  66. print acc_rate,recall_rate,balance ##正确率、召回率、正负样本均衡度
  67. def handle_data(account):
  68. # 本策略将使用account的以下属性:
  69. # account.referencePortfolioValue表示根据前收计算的当前持有证券市场价值与现金之和。
  70. # account.universe表示当天,股票池中可以进行交易的证券池,剔除停牌退市等股票。
  71. # account.referencePrice表示股票的参考价,一般使用的是上一日收盘价。
  72. # account.valid_secpos字典,键为证券代码,值为虚拟账户中当前所持有该股票的数量。
  73. c = account.referencePortfolioValue
  74. today = account.current_date.strftime('%Y-%m-%d')
  75. begin_date = cal.advanceDate(today,'-1B',BizDayConvention.Preceding).strftime('%Y%m%d')
  76. end_date = cal.advanceDate(today,'-1B',BizDayConvention.Preceding).strftime('%Y%m%d')
  77. data1=DataAPI.MktEqudGet(secID=u"601872.XSHG",beginDate=begin_date,endDate=end_date,field=['tradeDate','highestPrice','lowestPrice','openPrice','closePrice','turnoverVol','turnoverRate'],pandas="1")
  78. data2=DataAPI.MktStockFactorsDateRangeGet(secID=u"601872.XSHG",beginDate=begin_date,endDate=end_date,field=['tradeDate','DAVOL5','EMA5','EMA10','MA5','MA20','RSI','VOL5','VOL10','MACD'],pandas="1")
  79. df_data=pd.merge(data1,data2,on='tradeDate')
  80. a=list(df_data.values[0])
  81. a.append(df_data.values[0][4]-df_data.values[0][10])
  82. a.append(df_data.values[0][4]-df_data.values[0][11])
  83. x_predict=a[1:]
  84. for i in range(len(x_predict)):
  85. if str(x_predict[i])=='nan':
  86. x_predict[i]=10000000
  87. clf = joblib.load('clf.model')
  88. y_predict=clf.predict(x_predict)
  89. # 计算调仓数量
  90. change = {}
  91. for stock in account.universe:
  92. if y_predict>0 and stock not in account.valid_secpos:
  93. p = account.referencePrice[stock]
  94. order(stock,int(c / p))
  95. if y_predict==0 and stock in account.valid_secpos:
  96. order_to(stock,0)
  97. #print today,x_predict[3],y_predict

决策树模型(固定模型) - 图1

  1. 713
  2. 0.580056179775 0.334384858044 0.445224719101

This is an empty markdown cell