siat 3.10.130__py3-none-any.whl → 3.10.132__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (217) hide show
  1. build/lib/build/lib/siat/__init__.py +75 -0
  2. build/lib/build/lib/siat/allin.py +137 -0
  3. build/lib/build/lib/siat/assets_liquidity.py +915 -0
  4. build/lib/build/lib/siat/beta_adjustment.py +1058 -0
  5. build/lib/build/lib/siat/beta_adjustment_china.py +548 -0
  6. build/lib/build/lib/siat/blockchain.py +143 -0
  7. build/lib/build/lib/siat/bond.py +2900 -0
  8. build/lib/build/lib/siat/bond_base.py +992 -0
  9. build/lib/build/lib/siat/bond_china.py +100 -0
  10. build/lib/build/lib/siat/bond_zh_sina.py +143 -0
  11. build/lib/build/lib/siat/capm_beta.py +783 -0
  12. build/lib/build/lib/siat/capm_beta2.py +887 -0
  13. build/lib/build/lib/siat/common.py +5360 -0
  14. build/lib/build/lib/siat/compare_cross.py +642 -0
  15. build/lib/build/lib/siat/copyrights.py +18 -0
  16. build/lib/build/lib/siat/cryptocurrency.py +667 -0
  17. build/lib/build/lib/siat/economy.py +1471 -0
  18. build/lib/build/lib/siat/economy2.py +1853 -0
  19. build/lib/build/lib/siat/esg.py +536 -0
  20. build/lib/build/lib/siat/event_study.py +815 -0
  21. build/lib/build/lib/siat/fama_french.py +1521 -0
  22. build/lib/build/lib/siat/fin_stmt2_yahoo.py +982 -0
  23. build/lib/build/lib/siat/financial_base.py +1160 -0
  24. build/lib/build/lib/siat/financial_statements.py +598 -0
  25. build/lib/build/lib/siat/financials.py +2339 -0
  26. build/lib/build/lib/siat/financials2.py +1278 -0
  27. build/lib/build/lib/siat/financials_china.py +4433 -0
  28. build/lib/build/lib/siat/financials_china2.py +2212 -0
  29. build/lib/build/lib/siat/fund.py +629 -0
  30. build/lib/build/lib/siat/fund_china.py +3307 -0
  31. build/lib/build/lib/siat/future_china.py +551 -0
  32. build/lib/build/lib/siat/google_authenticator.py +47 -0
  33. build/lib/build/lib/siat/grafix.py +3636 -0
  34. build/lib/build/lib/siat/holding_risk.py +867 -0
  35. build/lib/build/lib/siat/luchy_draw.py +638 -0
  36. build/lib/build/lib/siat/market_china.py +1168 -0
  37. build/lib/build/lib/siat/markowitz.py +2363 -0
  38. build/lib/build/lib/siat/markowitz2.py +3150 -0
  39. build/lib/build/lib/siat/markowitz2_20250704.py +2969 -0
  40. build/lib/build/lib/siat/markowitz2_20250705.py +3158 -0
  41. build/lib/build/lib/siat/markowitz_simple.py +373 -0
  42. build/lib/build/lib/siat/ml_cases.py +2291 -0
  43. build/lib/build/lib/siat/ml_cases_example.py +60 -0
  44. build/lib/build/lib/siat/option_china.py +3069 -0
  45. build/lib/build/lib/siat/option_pricing.py +1925 -0
  46. build/lib/build/lib/siat/other_indexes.py +409 -0
  47. build/lib/build/lib/siat/risk_adjusted_return.py +1576 -0
  48. build/lib/build/lib/siat/risk_adjusted_return2.py +1900 -0
  49. build/lib/build/lib/siat/risk_evaluation.py +2218 -0
  50. build/lib/build/lib/siat/risk_free_rate.py +351 -0
  51. build/lib/build/lib/siat/sector_china.py +4140 -0
  52. build/lib/build/lib/siat/security_price2.py +727 -0
  53. build/lib/build/lib/siat/security_prices.py +3408 -0
  54. build/lib/build/lib/siat/security_trend.py +402 -0
  55. build/lib/build/lib/siat/security_trend2.py +646 -0
  56. build/lib/build/lib/siat/stock.py +4284 -0
  57. build/lib/build/lib/siat/stock_advice_linear.py +934 -0
  58. build/lib/build/lib/siat/stock_base.py +26 -0
  59. build/lib/build/lib/siat/stock_china.py +2095 -0
  60. build/lib/build/lib/siat/stock_prices_kneighbors.py +910 -0
  61. build/lib/build/lib/siat/stock_prices_linear.py +386 -0
  62. build/lib/build/lib/siat/stock_profile.py +707 -0
  63. build/lib/build/lib/siat/stock_technical.py +3305 -0
  64. build/lib/build/lib/siat/stooq.py +74 -0
  65. build/lib/build/lib/siat/transaction.py +347 -0
  66. build/lib/build/lib/siat/translate.py +5183 -0
  67. build/lib/build/lib/siat/valuation.py +1378 -0
  68. build/lib/build/lib/siat/valuation_china.py +2076 -0
  69. build/lib/build/lib/siat/var_model_validation.py +444 -0
  70. build/lib/build/lib/siat/yf_name.py +811 -0
  71. build/lib/siat/__init__.py +75 -0
  72. build/lib/siat/allin.py +137 -0
  73. build/lib/siat/assets_liquidity.py +915 -0
  74. build/lib/siat/beta_adjustment.py +1058 -0
  75. build/lib/siat/beta_adjustment_china.py +548 -0
  76. build/lib/siat/blockchain.py +143 -0
  77. build/lib/siat/bond.py +2900 -0
  78. build/lib/siat/bond_base.py +992 -0
  79. build/lib/siat/bond_china.py +100 -0
  80. build/lib/siat/bond_zh_sina.py +143 -0
  81. build/lib/siat/capm_beta.py +783 -0
  82. build/lib/siat/capm_beta2.py +887 -0
  83. build/lib/siat/common.py +5360 -0
  84. build/lib/siat/compare_cross.py +642 -0
  85. build/lib/siat/copyrights.py +18 -0
  86. build/lib/siat/cryptocurrency.py +667 -0
  87. build/lib/siat/economy.py +1471 -0
  88. build/lib/siat/economy2.py +1853 -0
  89. build/lib/siat/esg.py +536 -0
  90. build/lib/siat/event_study.py +815 -0
  91. build/lib/siat/fama_french.py +1521 -0
  92. build/lib/siat/fin_stmt2_yahoo.py +982 -0
  93. build/lib/siat/financial_base.py +1160 -0
  94. build/lib/siat/financial_statements.py +598 -0
  95. build/lib/siat/financials.py +2339 -0
  96. build/lib/siat/financials2.py +1278 -0
  97. build/lib/siat/financials_china.py +4433 -0
  98. build/lib/siat/financials_china2.py +2212 -0
  99. build/lib/siat/fund.py +629 -0
  100. build/lib/siat/fund_china.py +3307 -0
  101. build/lib/siat/future_china.py +551 -0
  102. build/lib/siat/google_authenticator.py +47 -0
  103. build/lib/siat/grafix.py +3636 -0
  104. build/lib/siat/holding_risk.py +867 -0
  105. build/lib/siat/luchy_draw.py +638 -0
  106. build/lib/siat/market_china.py +1168 -0
  107. build/lib/siat/markowitz.py +2363 -0
  108. build/lib/siat/markowitz2.py +3150 -0
  109. build/lib/siat/markowitz2_20250704.py +2969 -0
  110. build/lib/siat/markowitz2_20250705.py +3158 -0
  111. build/lib/siat/markowitz_simple.py +373 -0
  112. build/lib/siat/ml_cases.py +2291 -0
  113. build/lib/siat/ml_cases_example.py +60 -0
  114. build/lib/siat/option_china.py +3069 -0
  115. build/lib/siat/option_pricing.py +1925 -0
  116. build/lib/siat/other_indexes.py +409 -0
  117. build/lib/siat/risk_adjusted_return.py +1576 -0
  118. build/lib/siat/risk_adjusted_return2.py +1900 -0
  119. build/lib/siat/risk_evaluation.py +2218 -0
  120. build/lib/siat/risk_free_rate.py +351 -0
  121. build/lib/siat/sector_china.py +4140 -0
  122. build/lib/siat/security_price2.py +727 -0
  123. build/lib/siat/security_prices.py +3408 -0
  124. build/lib/siat/security_trend.py +402 -0
  125. build/lib/siat/security_trend2.py +646 -0
  126. build/lib/siat/stock.py +4284 -0
  127. build/lib/siat/stock_advice_linear.py +934 -0
  128. build/lib/siat/stock_base.py +26 -0
  129. build/lib/siat/stock_china.py +2095 -0
  130. build/lib/siat/stock_prices_kneighbors.py +910 -0
  131. build/lib/siat/stock_prices_linear.py +386 -0
  132. build/lib/siat/stock_profile.py +707 -0
  133. build/lib/siat/stock_technical.py +3305 -0
  134. build/lib/siat/stooq.py +74 -0
  135. build/lib/siat/transaction.py +347 -0
  136. build/lib/siat/translate.py +5183 -0
  137. build/lib/siat/valuation.py +1378 -0
  138. build/lib/siat/valuation_china.py +2076 -0
  139. build/lib/siat/var_model_validation.py +444 -0
  140. build/lib/siat/yf_name.py +811 -0
  141. siat/__init__.py +0 -0
  142. siat/allin.py +0 -0
  143. siat/assets_liquidity.py +0 -0
  144. siat/beta_adjustment.py +0 -0
  145. siat/beta_adjustment_china.py +0 -0
  146. siat/blockchain.py +0 -0
  147. siat/bond.py +0 -0
  148. siat/bond_base.py +0 -0
  149. siat/bond_china.py +0 -0
  150. siat/bond_zh_sina.py +0 -0
  151. siat/capm_beta.py +0 -0
  152. siat/capm_beta2.py +0 -0
  153. siat/common.py +94 -30
  154. siat/compare_cross.py +0 -0
  155. siat/copyrights.py +0 -0
  156. siat/cryptocurrency.py +0 -0
  157. siat/economy.py +0 -0
  158. siat/economy2.py +0 -0
  159. siat/esg.py +0 -0
  160. siat/event_study.py +0 -0
  161. siat/fama_french.py +0 -0
  162. siat/fin_stmt2_yahoo.py +0 -0
  163. siat/financial_base.py +0 -0
  164. siat/financial_statements.py +0 -0
  165. siat/financials.py +0 -0
  166. siat/financials2.py +0 -0
  167. siat/financials_china.py +0 -0
  168. siat/financials_china2.py +0 -0
  169. siat/fund.py +0 -0
  170. siat/fund_china.py +0 -0
  171. siat/future_china.py +0 -0
  172. siat/google_authenticator.py +0 -0
  173. siat/grafix.py +1 -1
  174. siat/holding_risk.py +0 -0
  175. siat/luchy_draw.py +0 -0
  176. siat/market_china.py +7 -1
  177. siat/markowitz.py +0 -0
  178. siat/markowitz2.py +240 -39
  179. siat/markowitz2_20250704.py +2969 -0
  180. siat/markowitz2_20250705.py +3158 -0
  181. siat/markowitz_simple.py +0 -0
  182. siat/ml_cases.py +0 -0
  183. siat/ml_cases_example.py +0 -0
  184. siat/option_china.py +0 -0
  185. siat/option_pricing.py +0 -0
  186. siat/other_indexes.py +0 -0
  187. siat/risk_adjusted_return.py +0 -0
  188. siat/risk_adjusted_return2.py +0 -0
  189. siat/risk_evaluation.py +0 -0
  190. siat/risk_free_rate.py +0 -0
  191. siat/sector_china.py +0 -0
  192. siat/security_price2.py +0 -0
  193. siat/security_prices.py +3 -1
  194. siat/security_trend.py +0 -0
  195. siat/security_trend2.py +1 -1
  196. siat/stock.py +4 -2
  197. siat/stock_advice_linear.py +0 -0
  198. siat/stock_base.py +0 -0
  199. siat/stock_china.py +0 -0
  200. siat/stock_prices_kneighbors.py +0 -0
  201. siat/stock_prices_linear.py +0 -0
  202. siat/stock_profile.py +0 -0
  203. siat/stock_technical.py +0 -0
  204. siat/stooq.py +0 -0
  205. siat/transaction.py +0 -0
  206. siat/translate.py +11 -11
  207. siat/valuation.py +0 -0
  208. siat/valuation_china.py +0 -0
  209. siat/var_model_validation.py +0 -0
  210. siat/yf_name.py +0 -0
  211. {siat-3.10.130.dist-info → siat-3.10.132.dist-info}/METADATA +11 -11
  212. siat-3.10.132.dist-info/RECORD +218 -0
  213. siat-3.10.132.dist-info/top_level.txt +4 -0
  214. siat-3.10.130.dist-info/RECORD +0 -76
  215. siat-3.10.130.dist-info/top_level.txt +0 -1
  216. {siat-3.10.130.dist-info → siat-3.10.132.dist-info}/WHEEL +0 -0
  217. {siat-3.10.130.dist-info → siat-3.10.132.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,386 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @function: 预测美股股价,教学演示用,其他用途责任自负
4
+ @model:线性模型,ols, ridge, lasso, elasticnet
5
+ @version:v1.0,2019.4.4
6
+ @purpose: 仅限机器学习课程案例使用
7
+ @author: 王德宏,北京外国语大学国际商学院
8
+ """
9
+
10
+ #=====================================================================
11
def get_stock_price(ticker, atdate, fromdate):
    """
    Fetch a stock's closing prices from the 'stooq' source, newest first.

    ticker: stock symbol handed to pandas_datareader (e.g. 'MSFT').
    atdate: end of the requested range; forwarded unchanged to DataReader
        (the demo in this file passes strings such as '4/3/2019').
    fromdate: start of the requested range; also used to filter out any
        rows whose index is earlier than this value.

    Returns a DataFrame sorted by index in descending order with columns:
        Date    - the index formatted as "%Y-%m-%d"
        Weekday - index.weekday + 1 (pandas counts Monday as 0, so Monday is 1)
        Close   - the closing price
    """
    from pandas_datareader import data

    raw = data.DataReader(ticker, 'stooq', fromdate, atdate)

    # Discard anything older than the requested start, then put the most
    # recent prices at the top.
    newest_first = raw[raw.index >= fromdate].sort_index(axis=0, ascending=False)

    # Add the calendar columns derived from the datetime index.
    newest_first['Date'] = newest_first.index.strftime("%Y-%m-%d")
    newest_first['Weekday'] = newest_first.index.weekday + 1

    # Output layout: date, weekday, closing price.
    return newest_first[['Date', 'Weekday', 'Close']]
46
+
47
+
48
if __name__ == '__main__':
    # Demo: download MSFT prices and look at a few slices.
    # The bare expressions only display something in an interactive session.
    dfprice = get_stock_price('MSFT', '4/3/2019', '1/1/2015')
    dfprice.head(5)
    dfprice.tail(3)
    dfprice[dfprice.Date == '2019-03-29']
    dfprice[(dfprice.Date >= '2019-03-20') & (dfprice.Date <= '2019-03-29')]
54
+
55
+
56
+ #=====================================================================
57
def make_price_sample(dfprice, n_nextdays=1, n_samples=240, n_features=20):
    """
    Build feature/label samples from a closing-price series.

    Sample i uses the price at position i as its label and the n_features
    prices starting n_nextdays positions later as its features. With the
    newest-first ordering produced by get_stock_price, the features are
    therefore the older prices that precede the label date.

    dfprice: object exposing the prices as ``dfprice.Close``
        (e.g. the DataFrame returned by get_stock_price).
    n_nextdays: offset between a label and the first of its features,
        default 1. Must be >= 0.
    n_samples: number of samples (rows) to generate, default 240.
    n_features: number of prices used as features per sample, default 20.

    Returns (X, y, ndprice):
        X       - np.matrix of shape (n_samples, n_features)
        y       - np.matrix of shape (n_samples, 1)
        ndprice - np.matrix of shape (1, number of prices), the full series
    np.matrix is kept as the return type so that existing callers indexing
    ``ndprice[0, a:b]`` still receive a 2-D row.

    Raises ValueError if the size arguments are invalid or the series holds
    fewer than n_nextdays + n_features + n_samples - 1 prices.
    """
    import numpy as np

    if n_samples < 1 or n_features < 1 or n_nextdays < 0:
        raise ValueError(
            "n_samples and n_features must be >= 1 and n_nextdays must be >= 0"
        )

    # Flatten to a plain 1-D array first; the matrix view is derived from it.
    prices = np.asarray(dfprice.Close).ravel()
    ndprice = np.asmatrix(prices)

    # The last sample reads up to this many prices from the start.
    required = n_nextdays + n_features + n_samples - 1
    if prices.shape[0] < required:
        raise ValueError(
            "need at least %d prices to build %d samples of %d features, got %d"
            % (required, n_samples, n_features, prices.shape[0])
        )

    # Build every feature window in one indexing operation instead of
    # appending row by row (which re-copies the whole matrix each time).
    row_start = np.arange(n_samples).reshape(-1, 1)
    window = np.arange(n_nextdays, n_nextdays + n_features).reshape(1, -1)
    X = np.asmatrix(prices[row_start + window])

    # Labels are simply the first n_samples prices, as a column.
    y = np.asmatrix(prices[:n_samples].reshape(-1, 1))

    return X, y, ndprice
88
+
89
if __name__ == '__main__':
    # Demo: build 240 samples of 20 features from MSFT prices and peek at them.
    fdprice = get_stock_price('MSFT', '4/3/2019', '1/1/2015')
    X, y, ndprice = make_price_sample(fdprice, 1, 240, 20)
    y[:5]
    y[2:5]  # the first row has index 0
    X[:5]
    X[:-5]
    X[3 - 1, 2 - 1]
97
+
98
+
99
+ #=====================================================================
100
def bestR1(X, y):
    """
    Fit a ridge regression, choosing alpha in two cross-validated passes.

    The sample is split with train_test_split(random_state=0). A first
    RidgeCV (cv=5) picks alpha from a coarse grid whose points bracket each
    power of ten. If the pick is one of the "round" values the coarse model
    is returned as is; otherwise a second RidgeCV searches a finer grid
    spanning the decade that contains the coarse pick.

    Alpha is selected by RidgeCV's cross-validation on the training split;
    the test split is only used to report a score.

    Returns (model, alpha, score_train, score_test).
    """
    import numpy as np
    from sklearn.linear_model import RidgeCV
    from sklearn.model_selection import train_test_split

    # Random split of the whole sample into training and test sets.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # Coarse grid: neighbours around each decade reveal on which side of
    # the decade the optimum lies.
    coarse_alphas = [0.001, 0.0011, 0.00999, 0.01, 0.01001, 0.999, 1, 1.01,
                     9.99, 10, 10.01, 99, 100, 101, 999, 1000, 1001, 10000]

    # NOTE(review): `normalize=` is not accepted by newer scikit-learn
    # releases - confirm the version this package is pinned to.
    coarse = RidgeCV(alphas=coarse_alphas, cv=5, fit_intercept=True, normalize=True)
    coarse.fit(X_train, y_train)
    coarse_train = coarse.score(X_train, y_train)
    coarse_test = coarse.score(X_test, y_test)
    coarse_alpha = coarse.alpha_

    # A pick that sits exactly on a round value is accepted without refinement.
    if coarse_alpha in [0.001, 0.01, 1, 2, 10, 100, 1000, 10000]:
        return coarse, coarse_alpha, coarse_train, coarse_test

    # (lower bound, upper bound, step) of the fine grid for each decade.
    decades = (
        (0.001, 0.01, 0.0005),
        (0.01, 1, 0.005),
        (1, 10, 0.01),
        (10, 100, 0.1),
        (100, 1000, 1),
        (1000, 10000, 10),
    )
    for low, high, step in decades:
        if low < coarse_alpha < high:
            fine_alphas = np.arange(low, high, step)
            break

    fine = RidgeCV(alphas=fine_alphas, cv=5, fit_intercept=True, normalize=True)
    fine.fit(X_train, y_train)

    return (
        fine,
        fine.alpha_,
        fine.score(X_train, y_train),
        fine.score(X_test, y_test),
    )
150
+
151
+
152
if __name__ == '__main__':
    # Demo: ridge model on MSFT prices.
    dfprice = get_stock_price('MSFT', '4/3/2019', '1/1/2015')
    X, y, ndprice = make_price_sample(dfprice, 1, 240, 20)

    model, alpha, score_train, score_test = bestR1(X, y)
    print("%.5f, %.5f, %.5f" % (alpha, score_train, score_test))
    # Recorded result: 0.045, 0.9277, 0.8940

    # Predict from the 20 most recent prices.
    X_new = ndprice[0, 0:20]
    y_new = model.predict(X_new)
    print("%.2f" % y_new)
    # Recorded result: 119.43
164
+ #=====================================================================
165
def bestL1(X, y):
    """
    Fit a lasso regression, choosing alpha in two cross-validated passes.

    The labels are flattened with column_or_1d and the sample is split with
    train_test_split(random_state=0). A first LassoCV (cv=5,
    max_iter=10**6) picks alpha from a coarse grid whose points bracket each
    decade (and the value 2). If the pick is one of the "round" values the
    coarse model is returned; otherwise a second LassoCV searches a finer
    grid inside the decade containing the coarse pick. The second fit does
    not pass max_iter, so it runs with the estimator's default.

    Alpha is selected by LassoCV's cross-validation on the training split;
    the test split is only used to report a score.

    Returns (model, alpha, score_train, score_test).
    """
    import numpy as np
    from sklearn.linear_model import LassoCV
    from sklearn.model_selection import train_test_split
    from sklearn.utils import column_or_1d

    # Flatten the labels, then split randomly into training and test sets.
    labels = column_or_1d(y, warn=False)
    X_train, X_test, y_train, y_test = train_test_split(X, labels, random_state=0)

    # Coarse grid: neighbours around each decade reveal on which side of
    # the decade the optimum lies.
    coarse_alphas = [0.001, 0.0011, 0.00999, 0.01, 0.01001, 0.999, 1, 1.01,
                     1.99, 2, 2.01, 9.99, 10, 10.01, 99, 100, 101, 999, 1000,
                     1001, 10000]

    # NOTE(review): `normalize=` is not accepted by newer scikit-learn
    # releases - confirm the version this package is pinned to.
    coarse = LassoCV(alphas=coarse_alphas, max_iter=10**6,
                     cv=5, fit_intercept=True, normalize=True)
    coarse.fit(X_train, y_train)
    coarse_train = coarse.score(X_train, y_train)
    coarse_test = coarse.score(X_test, y_test)
    coarse_alpha = coarse.alpha_

    # A pick that sits exactly on a round value is accepted without refinement.
    if coarse_alpha in [0.001, 0.01, 1, 2, 10, 100, 1000, 10000]:
        return coarse, coarse_alpha, coarse_train, coarse_test

    # (lower bound, upper bound, fine-grid start, fine-grid step); the fine
    # grid starts one step above the lower bound.
    decades = (
        (0.001, 0.01, 0.0015, 0.0005),
        (0.01, 1, 0.015, 0.005),
        (1, 10, 1.01, 0.01),
        (10, 100, 10.1, 0.1),
        (100, 1000, 101, 1),
        (1000, 10000, 1010, 10),
    )
    for low, high, start, step in decades:
        if low < coarse_alpha < high:
            fine_alphas = np.arange(start, high, step)
            break

    fine = LassoCV(alphas=fine_alphas, cv=5, fit_intercept=True, normalize=True)
    fine.fit(X_train, y_train)

    return (
        fine,
        fine.alpha_,
        fine.score(X_train, y_train),
        fine.score(X_test, y_test),
    )
221
+
222
if __name__ == '__main__':
    # Demo: lasso model on MSFT prices.
    dfprice = get_stock_price('MSFT', '4/3/2019', '1/1/2015')
    X, y, ndprice = make_price_sample(dfprice, 1, 240, 20)

    model, alpha, score_train, score_test = bestL1(X, y)
    print("%.5f, %.5f, %.5f" % (alpha, score_train, score_test))
    # Recorded result: 0.015, 0.9284, 0.9043

    # Predict from the 20 most recent prices.
    X_new = ndprice[0, 0:20]
    y_new = model.predict(X_new)
    print("%.2f" % y_new)
    # Recorded result: 119.37
234
+
235
+ #=====================================================================
236
+
237
def bestEN2(X, y, maxalpha=2):
    """
    Fit an elastic-net regression, letting ElasticNetCV choose alpha and
    l1_ratio from explicit grids (the original author notes this is slow).

    The labels are flattened with column_or_1d and the sample is split with
    train_test_split(random_state=66).

    maxalpha: exclusive upper end of the alpha grid, which runs from 0.01
        in steps of 0.01. The l1_ratio grid runs from 0.01 up to (not
        including) 1 in steps of 0.01.

    Returns (model, alpha, l1ratio, score_train, score_test).
    """
    import numpy as np
    from sklearn.linear_model import ElasticNetCV
    from sklearn.model_selection import train_test_split
    from sklearn.utils import column_or_1d

    # Flatten the labels, then split randomly into training and test sets.
    labels = column_or_1d(y, warn=False)
    X_train, X_test, y_train, y_test = train_test_split(X, labels, random_state=66)

    # Candidate grids for both hyper-parameters.
    search = ElasticNetCV(
        alphas=np.arange(0.01, maxalpha, 0.01),
        l1_ratio=np.arange(0.01, 1, 0.01),
    )
    search.fit(X_train, y_train)

    train_r2 = search.score(X_train, y_train)
    test_r2 = search.score(X_test, y_test)
    return search, search.alpha_, search.l1_ratio_, train_r2, test_r2
264
+
265
if __name__ == '__main__':
    # Demo: elastic net (ElasticNetCV with explicit grids) on MSFT prices.
    dfprice = get_stock_price('MSFT', '4/3/2019', '1/1/2015')
    X, y, ndprice = make_price_sample(dfprice, 1, 240, 20)

    model, alpha, l1ratio, score_train, score_test = bestEN2(X, y)
    print("%.5f, %.5f, %.5f, %.5f" % (alpha, l1ratio, score_train, score_test))
    # Recorded result: 0.42, 0.99, 0.9258, 0.9174

    # Predict from the 20 most recent prices.
    X_new = ndprice[0, 0:20]
    y_new = model.predict(X_new)
    print("%.2f" % y_new)
    # Recorded result: 119.60
277
+
278
+ #=======
279
+
280
+ #==============================================================================
281
+
282
def bestEN3(X, y):
    """
    Fit an elastic-net regression with ElasticNetCV using 5-fold
    cross-validation and an automatically generated alpha path
    (alphas=None, n_alphas=100); the original author notes this is fast.

    The labels are flattened with column_or_1d and the sample is split with
    train_test_split(random_state=66). The l1_ratio grid runs from 0.01 up
    to (not including) 1 in steps of 0.01.

    Algorithm contributed by Xu Lexin.

    Returns (model, alpha, l1ratio, score_train, score_test).
    """
    import numpy as np
    from sklearn.linear_model import ElasticNetCV
    from sklearn.model_selection import train_test_split
    from sklearn.utils import column_or_1d

    # Flatten the labels, then split randomly into training and test sets.
    labels = column_or_1d(y, warn=False)
    X_train, X_test, y_train, y_test = train_test_split(X, labels, random_state=66)

    # Every estimator option is spelled out explicitly.
    # NOTE(review): `normalize=` is not accepted by newer scikit-learn
    # releases - confirm the version this package is pinned to.
    settings = dict(
        alphas=None,
        copy_X=True,
        cv=5,
        eps=0.001,
        fit_intercept=True,
        l1_ratio=np.arange(0.01, 1, 0.01),
        max_iter=8000,
        n_alphas=100,
        n_jobs=None,
        normalize=True,
        positive=False,
        precompute='auto',
        random_state=0,
        selection='cyclic',
        tol=0.0001,
        verbose=0,
    )
    enet = ElasticNetCV(**settings)
    enet.fit(X_train, y_train)

    train_r2 = enet.score(X_train, y_train)
    test_r2 = enet.score(X_test, y_test)
    return enet, enet.alpha_, enet.l1_ratio_, train_r2, test_r2
313
+
314
if __name__ == '__main__':
    # Demo: elastic net (ElasticNetCV with automatic alpha path) on MSFT prices.
    dfprice = get_stock_price('MSFT', '4/3/2019', '1/1/2015')
    X, y, ndprice = make_price_sample(dfprice, 1, 240, 20)

    model, alpha, l1ratio, score_train, score_test = bestEN3(X, y)
    print("%.5f, %.5f, %.5f, %.5f" % (alpha, l1ratio, score_train, score_test))
    # Recorded result: 0.005836, 0.99, 0.925, 0.9194

    # Predict from the 20 most recent prices.
    X_new = ndprice[0, 0:20]
    y_new = model.predict(X_new)
    print("%.2f" % y_new)
    # Recorded result: 119.48
326
+ #==============================================================================
327
+
328
+
329
def bestEN1(X, y, maxalpha=2, min_score=0.6):
    """
    Fit an elastic-net regression by exhaustively trying every
    (alpha, l1_ratio) pair on a grid and keeping the model with the highest
    test-set score.

    The labels are flattened with column_or_1d and the sample is split with
    train_test_split(random_state=66). Note that the winning model is chosen
    on the test split itself, so the reported test score is optimistic.

    Algorithm contributed by Xu Lexin.

    maxalpha: exclusive upper end of the alpha grid, which runs from 0.01
        in steps of 0.01. The l1_ratio grid runs from 0.01 up to (not
        including) 1 in steps of 0.01.
    min_score: a candidate is only accepted if its test score is strictly
        greater than this threshold (default 0.6, the value that used to be
        hard-coded).

    Returns (model, alpha, l1ratio, score_train, score_test).

    Raises ValueError if no candidate scores above min_score. Previously
    this case failed with an UnboundLocalError at the return statement.
    """
    import numpy as np
    from sklearn.linear_model import ElasticNet
    from sklearn.model_selection import train_test_split
    from sklearn.utils import column_or_1d

    # Flatten the labels, then split randomly into training and test sets.
    y = column_or_1d(y, warn=False)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=66)

    # Candidate grids for both hyper-parameters.
    alphalist = np.arange(0.01, maxalpha, 0.01)
    l1list = np.arange(0.01, 1, 0.01)

    # Best candidate so far; stays None until one beats the threshold.
    best = None
    king_score = min_score

    for i in alphalist:
        for j in l1list:
            reg = ElasticNet(alpha=i, l1_ratio=j)
            reg.fit(X_train, y_train)
            temp_score = reg.score(X_test, y_test)
            if temp_score > king_score:
                king_score = temp_score
                # The train score is only computed for new leaders.
                best = (reg, i, j, reg.score(X_train, y_train), temp_score)

    if best is None:
        raise ValueError(
            "no ElasticNet candidate scored above %.4f on the test split"
            % min_score
        )

    return best
365
+
366
if __name__ == '__main__':
    # Demo: elastic net (brute-force grid search) on MSFT prices.
    dfprice = get_stock_price('MSFT', '4/3/2019', '1/1/2015')
    X, y, ndprice = make_price_sample(dfprice, 1, 240, 20)

    model, alpha, l1ratio, score_train, score_test = bestEN1(X, y)
    print("%.5f, %.5f, %.5f, %.5f" % (alpha, l1ratio, score_train, score_test))
    # Recorded result: 1.31, 0.56, 0.9241, 0.9196

    # Predict from the 20 most recent prices.
    X_new = ndprice[0, 0:20]
    y_new = model.predict(X_new)
    print("%.2f" % y_new)
    # Recorded result: 119.36
378
+ #==============================================================================
379
+
380
+
381
+
382
+
383
+
384
+
385
+ #==============================================================================
386
+