siat 3.10.132__py3-none-any.whl → 3.10.133__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218) hide show
  1. siat/__init__.py +0 -0
  2. siat/allin.py +0 -0
  3. siat/assets_liquidity.py +0 -0
  4. siat/beta_adjustment.py +0 -0
  5. siat/beta_adjustment_china.py +0 -0
  6. siat/blockchain.py +0 -0
  7. siat/bond.py +0 -0
  8. siat/bond_base.py +0 -0
  9. siat/bond_china.py +0 -0
  10. siat/bond_zh_sina.py +0 -0
  11. siat/capm_beta.py +0 -0
  12. siat/capm_beta2.py +0 -0
  13. siat/compare_cross.py +0 -0
  14. siat/copyrights.py +0 -0
  15. siat/cryptocurrency.py +0 -0
  16. siat/economy.py +0 -0
  17. siat/economy2.py +0 -0
  18. siat/esg.py +0 -0
  19. siat/event_study.py +0 -0
  20. siat/exchange_bond_china.pickle +0 -0
  21. siat/fama_french.py +0 -0
  22. siat/fin_stmt2_yahoo.py +0 -0
  23. siat/financial_base.py +0 -0
  24. siat/financial_statements.py +0 -0
  25. siat/financials.py +0 -0
  26. siat/financials2.py +0 -0
  27. siat/financials_china.py +0 -0
  28. siat/financials_china2.py +0 -0
  29. siat/fund.py +0 -0
  30. siat/fund_china.pickle +0 -0
  31. siat/fund_china.py +0 -0
  32. siat/future_china.py +0 -0
  33. siat/google_authenticator.py +0 -0
  34. siat/grafix.py +0 -0
  35. siat/holding_risk.py +0 -0
  36. siat/luchy_draw.py +0 -0
  37. siat/market_china.py +0 -0
  38. siat/markowitz.py +0 -0
  39. siat/markowitz2.py +0 -0
  40. siat/markowitz2_20250704.py +0 -0
  41. siat/markowitz2_20250705.py +0 -0
  42. siat/markowitz_simple.py +0 -0
  43. siat/ml_cases.py +0 -0
  44. siat/ml_cases_example.py +0 -0
  45. siat/option_china.py +0 -0
  46. siat/option_pricing.py +0 -0
  47. siat/other_indexes.py +0 -0
  48. siat/risk_adjusted_return.py +0 -0
  49. siat/risk_adjusted_return2.py +0 -0
  50. siat/risk_evaluation.py +0 -0
  51. siat/risk_free_rate.py +0 -0
  52. siat/sector_china.py +0 -0
  53. siat/security_price2.py +0 -0
  54. siat/security_prices.py +40 -2
  55. siat/security_trend.py +0 -0
  56. siat/security_trend2.py +0 -0
  57. siat/stock.py +0 -0
  58. siat/stock_advice_linear.py +0 -0
  59. siat/stock_base.py +0 -0
  60. siat/stock_china.py +0 -0
  61. siat/stock_info.pickle +0 -0
  62. siat/stock_prices_kneighbors.py +0 -0
  63. siat/stock_prices_linear.py +0 -0
  64. siat/stock_profile.py +0 -0
  65. siat/stock_technical.py +0 -0
  66. siat/stooq.py +0 -0
  67. siat/transaction.py +0 -0
  68. siat/translate.py +0 -0
  69. siat/valuation.py +0 -0
  70. siat/valuation_china.py +0 -0
  71. siat/var_model_validation.py +0 -0
  72. siat/yf_name.py +0 -0
  73. {siat-3.10.132.dist-info/licenses → siat-3.10.133.dist-info}/LICENSE +0 -0
  74. {siat-3.10.132.dist-info → siat-3.10.133.dist-info}/METADATA +232 -235
  75. siat-3.10.133.dist-info/RECORD +78 -0
  76. {siat-3.10.132.dist-info → siat-3.10.133.dist-info}/WHEEL +1 -1
  77. {siat-3.10.132.dist-info → siat-3.10.133.dist-info}/top_level.txt +0 -1
  78. build/lib/build/lib/siat/__init__.py +0 -75
  79. build/lib/build/lib/siat/allin.py +0 -137
  80. build/lib/build/lib/siat/assets_liquidity.py +0 -915
  81. build/lib/build/lib/siat/beta_adjustment.py +0 -1058
  82. build/lib/build/lib/siat/beta_adjustment_china.py +0 -548
  83. build/lib/build/lib/siat/blockchain.py +0 -143
  84. build/lib/build/lib/siat/bond.py +0 -2900
  85. build/lib/build/lib/siat/bond_base.py +0 -992
  86. build/lib/build/lib/siat/bond_china.py +0 -100
  87. build/lib/build/lib/siat/bond_zh_sina.py +0 -143
  88. build/lib/build/lib/siat/capm_beta.py +0 -783
  89. build/lib/build/lib/siat/capm_beta2.py +0 -887
  90. build/lib/build/lib/siat/common.py +0 -5360
  91. build/lib/build/lib/siat/compare_cross.py +0 -642
  92. build/lib/build/lib/siat/copyrights.py +0 -18
  93. build/lib/build/lib/siat/cryptocurrency.py +0 -667
  94. build/lib/build/lib/siat/economy.py +0 -1471
  95. build/lib/build/lib/siat/economy2.py +0 -1853
  96. build/lib/build/lib/siat/esg.py +0 -536
  97. build/lib/build/lib/siat/event_study.py +0 -815
  98. build/lib/build/lib/siat/fama_french.py +0 -1521
  99. build/lib/build/lib/siat/fin_stmt2_yahoo.py +0 -982
  100. build/lib/build/lib/siat/financial_base.py +0 -1160
  101. build/lib/build/lib/siat/financial_statements.py +0 -598
  102. build/lib/build/lib/siat/financials.py +0 -2339
  103. build/lib/build/lib/siat/financials2.py +0 -1278
  104. build/lib/build/lib/siat/financials_china.py +0 -4433
  105. build/lib/build/lib/siat/financials_china2.py +0 -2212
  106. build/lib/build/lib/siat/fund.py +0 -629
  107. build/lib/build/lib/siat/fund_china.py +0 -3307
  108. build/lib/build/lib/siat/future_china.py +0 -551
  109. build/lib/build/lib/siat/google_authenticator.py +0 -47
  110. build/lib/build/lib/siat/grafix.py +0 -3636
  111. build/lib/build/lib/siat/holding_risk.py +0 -867
  112. build/lib/build/lib/siat/luchy_draw.py +0 -638
  113. build/lib/build/lib/siat/market_china.py +0 -1168
  114. build/lib/build/lib/siat/markowitz.py +0 -2363
  115. build/lib/build/lib/siat/markowitz2.py +0 -3150
  116. build/lib/build/lib/siat/markowitz2_20250704.py +0 -2969
  117. build/lib/build/lib/siat/markowitz2_20250705.py +0 -3158
  118. build/lib/build/lib/siat/markowitz_simple.py +0 -373
  119. build/lib/build/lib/siat/ml_cases.py +0 -2291
  120. build/lib/build/lib/siat/ml_cases_example.py +0 -60
  121. build/lib/build/lib/siat/option_china.py +0 -3069
  122. build/lib/build/lib/siat/option_pricing.py +0 -1925
  123. build/lib/build/lib/siat/other_indexes.py +0 -409
  124. build/lib/build/lib/siat/risk_adjusted_return.py +0 -1576
  125. build/lib/build/lib/siat/risk_adjusted_return2.py +0 -1900
  126. build/lib/build/lib/siat/risk_evaluation.py +0 -2218
  127. build/lib/build/lib/siat/risk_free_rate.py +0 -351
  128. build/lib/build/lib/siat/sector_china.py +0 -4140
  129. build/lib/build/lib/siat/security_price2.py +0 -727
  130. build/lib/build/lib/siat/security_prices.py +0 -3408
  131. build/lib/build/lib/siat/security_trend.py +0 -402
  132. build/lib/build/lib/siat/security_trend2.py +0 -646
  133. build/lib/build/lib/siat/stock.py +0 -4284
  134. build/lib/build/lib/siat/stock_advice_linear.py +0 -934
  135. build/lib/build/lib/siat/stock_base.py +0 -26
  136. build/lib/build/lib/siat/stock_china.py +0 -2095
  137. build/lib/build/lib/siat/stock_prices_kneighbors.py +0 -910
  138. build/lib/build/lib/siat/stock_prices_linear.py +0 -386
  139. build/lib/build/lib/siat/stock_profile.py +0 -707
  140. build/lib/build/lib/siat/stock_technical.py +0 -3305
  141. build/lib/build/lib/siat/stooq.py +0 -74
  142. build/lib/build/lib/siat/transaction.py +0 -347
  143. build/lib/build/lib/siat/translate.py +0 -5183
  144. build/lib/build/lib/siat/valuation.py +0 -1378
  145. build/lib/build/lib/siat/valuation_china.py +0 -2076
  146. build/lib/build/lib/siat/var_model_validation.py +0 -444
  147. build/lib/build/lib/siat/yf_name.py +0 -811
  148. build/lib/siat/__init__.py +0 -75
  149. build/lib/siat/allin.py +0 -137
  150. build/lib/siat/assets_liquidity.py +0 -915
  151. build/lib/siat/beta_adjustment.py +0 -1058
  152. build/lib/siat/beta_adjustment_china.py +0 -548
  153. build/lib/siat/blockchain.py +0 -143
  154. build/lib/siat/bond.py +0 -2900
  155. build/lib/siat/bond_base.py +0 -992
  156. build/lib/siat/bond_china.py +0 -100
  157. build/lib/siat/bond_zh_sina.py +0 -143
  158. build/lib/siat/capm_beta.py +0 -783
  159. build/lib/siat/capm_beta2.py +0 -887
  160. build/lib/siat/common.py +0 -5360
  161. build/lib/siat/compare_cross.py +0 -642
  162. build/lib/siat/copyrights.py +0 -18
  163. build/lib/siat/cryptocurrency.py +0 -667
  164. build/lib/siat/economy.py +0 -1471
  165. build/lib/siat/economy2.py +0 -1853
  166. build/lib/siat/esg.py +0 -536
  167. build/lib/siat/event_study.py +0 -815
  168. build/lib/siat/fama_french.py +0 -1521
  169. build/lib/siat/fin_stmt2_yahoo.py +0 -982
  170. build/lib/siat/financial_base.py +0 -1160
  171. build/lib/siat/financial_statements.py +0 -598
  172. build/lib/siat/financials.py +0 -2339
  173. build/lib/siat/financials2.py +0 -1278
  174. build/lib/siat/financials_china.py +0 -4433
  175. build/lib/siat/financials_china2.py +0 -2212
  176. build/lib/siat/fund.py +0 -629
  177. build/lib/siat/fund_china.py +0 -3307
  178. build/lib/siat/future_china.py +0 -551
  179. build/lib/siat/google_authenticator.py +0 -47
  180. build/lib/siat/grafix.py +0 -3636
  181. build/lib/siat/holding_risk.py +0 -867
  182. build/lib/siat/luchy_draw.py +0 -638
  183. build/lib/siat/market_china.py +0 -1168
  184. build/lib/siat/markowitz.py +0 -2363
  185. build/lib/siat/markowitz2.py +0 -3150
  186. build/lib/siat/markowitz2_20250704.py +0 -2969
  187. build/lib/siat/markowitz2_20250705.py +0 -3158
  188. build/lib/siat/markowitz_simple.py +0 -373
  189. build/lib/siat/ml_cases.py +0 -2291
  190. build/lib/siat/ml_cases_example.py +0 -60
  191. build/lib/siat/option_china.py +0 -3069
  192. build/lib/siat/option_pricing.py +0 -1925
  193. build/lib/siat/other_indexes.py +0 -409
  194. build/lib/siat/risk_adjusted_return.py +0 -1576
  195. build/lib/siat/risk_adjusted_return2.py +0 -1900
  196. build/lib/siat/risk_evaluation.py +0 -2218
  197. build/lib/siat/risk_free_rate.py +0 -351
  198. build/lib/siat/sector_china.py +0 -4140
  199. build/lib/siat/security_price2.py +0 -727
  200. build/lib/siat/security_prices.py +0 -3408
  201. build/lib/siat/security_trend.py +0 -402
  202. build/lib/siat/security_trend2.py +0 -646
  203. build/lib/siat/stock.py +0 -4284
  204. build/lib/siat/stock_advice_linear.py +0 -934
  205. build/lib/siat/stock_base.py +0 -26
  206. build/lib/siat/stock_china.py +0 -2095
  207. build/lib/siat/stock_prices_kneighbors.py +0 -910
  208. build/lib/siat/stock_prices_linear.py +0 -386
  209. build/lib/siat/stock_profile.py +0 -707
  210. build/lib/siat/stock_technical.py +0 -3305
  211. build/lib/siat/stooq.py +0 -74
  212. build/lib/siat/transaction.py +0 -347
  213. build/lib/siat/translate.py +0 -5183
  214. build/lib/siat/valuation.py +0 -1378
  215. build/lib/siat/valuation_china.py +0 -2076
  216. build/lib/siat/var_model_validation.py +0 -444
  217. build/lib/siat/yf_name.py +0 -811
  218. siat-3.10.132.dist-info/RECORD +0 -218
@@ -1,386 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- @function: 预测美股股价,教学演示用,其他用途责任自负
4
- @model:线性模型,ols, righe, lasso, elasticnet
5
- @version:v1.0,2019.4.4
6
- @purpose: 仅限机器学习课程案例使用
7
- @author: 王德宏,北京外国语大学国际商学院
8
- """
9
-
10
- #=====================================================================
11
- def get_stock_price(ticker,atdate,fromdate):
12
- """
13
- 功能:抓取美股股价
14
- 输出:指定美股的收盘价格序列,最新日期的股价排列在前
15
- ticker:美股股票代码
16
- atdate:当前日期,既可以是今天日期,也可以是一个历史日期,datetime类型
17
- fromdate:样本开始日期,尽量远的日期,以便取得足够多的原始样本,类型同atdate
18
- """
19
-
20
- #仅为调试用的函数入口参数,正式使用前需要注释掉!
21
- #ticker='MSFT'
22
- #atdate='3/29/2019'
23
- #fromdate='1/1/2015'
24
- #---------------------------------------------
25
-
26
- #抓取美股股票价格
27
- from pandas_datareader import data
28
- price=data.DataReader(ticker,'stooq',fromdate,atdate)
29
-
30
- #去掉比起始日期更早的样本
31
- price2=price[price.index >= fromdate]
32
-
33
-
34
- #按日期降序排序,近期的价格排在前面
35
- sortedprice=price2.sort_index(axis=0,ascending=False)
36
-
37
- #提取日期和星期几
38
- #sortedprice['Date']=sortedprice.index.date
39
- sortedprice['Date']=sortedprice.index.strftime("%Y-%m-%d")
40
- sortedprice['Weekday']=sortedprice.index.weekday+1
41
-
42
- #生成输出数据格式:日期,星期几,收盘价
43
- dfprice=sortedprice[['Date','Weekday','Close']]
44
-
45
- return dfprice
46
-
47
-
48
- if __name__=='__main__':
49
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
50
- dfprice.head(5)
51
- dfprice.tail(3)
52
- dfprice[dfprice.Date == '2019-03-29']
53
- dfprice[(dfprice.Date>='2019-03-20') & (dfprice.Date<='2019-03-29')]
54
-
55
-
56
- #=====================================================================
57
- def make_price_sample(dfprice,n_nextdays=1,n_samples=240,n_features=20):
58
- """
59
- 功能:生成指定股票的价格样本
60
- ticker:美股股票代码
61
- n_nextdays:预测从atdate开始未来第几天的股价,默认为1
62
- n_samples:需要生成的样本个数,默认240个(一年的平均交易天数)
63
- n_features:使用的特征数量,默认20个(一个月的平均交易天数)
64
- """
65
-
66
- #提取收盘价,Series类型
67
- closeprice=dfprice.Close
68
-
69
- #将closeprice转换为机器学习需要的ndarray类型ndprice
70
- import numpy as np
71
- ndprice=np.asmatrix(closeprice,dtype=None)
72
-
73
- #生成第一个标签样本:标签矩阵y(形状:n_samples x 1)
74
- import numpy as np
75
- y=np.asmatrix(ndprice[0,0])
76
- #生成第一个特征样本:特征矩阵X(形状:n_samples x n_features)
77
- X=ndprice[0,n_nextdays:n_features+n_nextdays]
78
-
79
- #生成其余的标签样本和特征样本
80
- for i in range(1,n_samples):
81
- y_row=np.asmatrix(ndprice[0,i])
82
- y=np.append(y,y_row,axis=0)
83
-
84
- X_row=ndprice[0,(n_nextdays+i):(n_features+n_nextdays+i)]
85
- X=np.append(X,X_row,axis=0)
86
-
87
- return X,y,ndprice
88
-
89
- if __name__=='__main__':
90
- fdprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
91
- X,y,ndprice=make_price_sample(fdprice,1,240,20)
92
- y[:5]
93
- y[2:5] #第1行的序号为0
94
- X[:5]
95
- X[:-5]
96
- X[3-1,2-1]
97
-
98
-
99
- #=====================================================================
100
- def bestR1(X,y):
101
- """
102
- 功能:给定特征矩阵和标签,使用岭回归,返回最优的alpha参数和模型
103
- 最优策略:测试集分数最高,不管过拟合问题
104
- """
105
-
106
- import numpy as np
107
- #将整个样本随机分割为训练集和测试集
108
- from sklearn.model_selection import train_test_split
109
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=0)
110
-
111
- #初始化alpha,便于判断上行下行方向
112
- alphalist=[0.001,0.0011,0.00999,0.01,0.01001,0.999,1,1.01, \
113
- 9.99,10,10.01,99,100,101,999,1000,1001,10000]
114
-
115
- from sklearn.linear_model import RidgeCV
116
- reg=RidgeCV(alphas=alphalist,cv=5,fit_intercept=True,normalize=True)
117
-
118
- reg.fit(X_train, y_train)
119
- score_train=reg.score(X_train, y_train)
120
- score_test=reg.score(X_test, y_test)
121
- alpha=reg.alpha_
122
- #print("%.5f, %.5f, %.5f"%(alpha,score_train,score_test))
123
-
124
- #确定alpha参数的优化范围
125
- if alpha in [0.001,0.01,1,2,10,100,1000,10000]:
126
- #print("%.5f, %.5f, %.5f"%(alpha,score_train,score_test))
127
- return reg,alpha,score_train,score_test
128
-
129
- if 0.001 < alpha < 0.01:
130
- alphalist1=np.arange(0.001,0.01,0.0005)
131
- if 0.01 < alpha < 1:
132
- alphalist1=np.arange(0.01,1,0.005)
133
- if 1 < alpha < 10:
134
- alphalist1=np.arange(1,10,0.01)
135
- if 10 < alpha < 100:
136
- alphalist1=np.arange(10,100,0.1)
137
- if 100 < alpha < 1000:
138
- alphalist1=np.arange(100,1000,1)
139
- if 1000 < alpha < 10000:
140
- alphalist1=np.arange(1000,10000,10)
141
-
142
- reg1=RidgeCV(alphas=alphalist1,cv=5,fit_intercept=True,normalize=True)
143
- reg1.fit(X_train, y_train)
144
- score1_train=reg1.score(X_train,y_train)
145
- score1_test =reg1.score(X_test, y_test)
146
- alpha1=reg1.alpha_
147
-
148
- #print("%.5f, %.5f, %.5f"%(alpha1,score1_train,score1_test))
149
- return reg1,alpha1,score1_train,score1_test
150
-
151
-
152
- if __name__=='__main__':
153
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
154
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
155
-
156
- model,alpha,score_train,score_test=bestR1(X,y)
157
- print("%.5f, %.5f, %.5f"%(alpha,score_train,score_test))
158
- #结果:0.045,0.9277,0.8940
159
-
160
- X_new=ndprice[0,0:20]
161
- y_new=model.predict(X_new)
162
- print("%.2f"%y_new)
163
- #结果:119.43
164
- #=====================================================================
165
- def bestL1(X,y):
166
- """
167
- 功能:给定特征矩阵和标签,使用拉索回归,返回最优的alpha参数和模型
168
- 最优策略:测试集分数最高,不管过拟合问题
169
- """
170
- import numpy as np
171
- #将整个样本随机分割为训练集和测试集
172
- from sklearn.utils import column_or_1d
173
- y=column_or_1d(y,warn=False)
174
- from sklearn.model_selection import train_test_split
175
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=0)
176
-
177
-
178
- #初始alpha,便于判断上行下行方向
179
- alphalist=[0.001,0.0011,0.00999,0.01,0.01001,0.999,1,1.01,1.99,2,2.01, \
180
- 9.99,10,10.01,99,100,101,999,1000,1001,10000]
181
-
182
- from sklearn.linear_model import LassoCV
183
- reg=LassoCV(alphas=alphalist,max_iter=10**6, \
184
- cv=5,fit_intercept=True,normalize=True)
185
- reg.fit(X_train, y_train)
186
- score_train=reg.score(X_train,y_train)
187
- score_test =reg.score(X_test, y_test)
188
- alpha=reg.alpha_
189
- #print("Step0: %.4f, %.5f, %.5f"%(alpha,score_train,score_test))
190
-
191
- #确定alpha参数的优化范围
192
- if alpha in [0.001,0.01,1,2,10,100,1000,10000]:
193
- #print("Step01: %.5f, %.5f, %.5f"%(alpha,score_train,score_test))
194
- return reg,alpha,score_train,score_test
195
-
196
- if 0.001 < alpha < 0.01:
197
- alphalist1=np.arange(0.0015,0.01,0.0005)
198
-
199
- if 0.01 < alpha < 1:
200
- alphalist1=np.arange(0.015,1,0.005)
201
-
202
- if 1 < alpha < 10:
203
- alphalist1=np.arange(1.01,10,0.01)
204
-
205
- if 10 < alpha < 100:
206
- alphalist1=np.arange(10.1,100,0.1)
207
-
208
- if 100 < alpha < 1000:
209
- alphalist1=np.arange(101,1000,1)
210
-
211
- if 1000 < alpha < 10000:
212
- alphalist1=np.arange(1010,10000,10)
213
-
214
- reg1=LassoCV(alphas=alphalist1,cv=5,fit_intercept=True,normalize=True)
215
- reg1.fit(X_train, y_train)
216
- score1_train=reg1.score(X_train,y_train)
217
- score1_test =reg1.score(X_test, y_test)
218
- alpha1=reg1.alpha_
219
- #print("Step1: %.4f, %.5f, %.5f"%(alpha1,score1_train,score1_test))
220
- return reg1,alpha1,score1_train,score1_test
221
-
222
- if __name__=='__main__':
223
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
224
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
225
-
226
- model,alpha,score_train,score_test=bestL1(X,y)
227
- print("%.5f, %.5f, %.5f"%(alpha,score_train,score_test))
228
- #结果:0.015,0.9284,0.9043
229
-
230
- X_new=ndprice[0,0:20]
231
- y_new=model.predict(X_new)
232
- print("%.2f"%y_new)
233
- #结果:119.37
234
-
235
- #=====================================================================
236
-
237
- def bestEN2(X,y,maxalpha=2):
238
- """
239
- 功能:给定特征矩阵和标签,使用弹性网络回归,返回最优的alpha参数和模型
240
- 最优策略:利用ElasticNetCV筛选机制,速度慢
241
- """
242
- #将整个样本随机分割为训练集和测试集
243
- from sklearn.utils import column_or_1d
244
- y=column_or_1d(y,warn=False)
245
-
246
- from sklearn.model_selection import train_test_split
247
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=66)
248
-
249
- #限定参数范围
250
- import numpy as np
251
- alphalist=np.arange(0.01,maxalpha,0.01)
252
- l1list =np.arange(0.01,1,0.01)
253
-
254
- from sklearn.linear_model import ElasticNetCV
255
- reg=ElasticNetCV(alphas=alphalist,l1_ratio=l1list)
256
-
257
- reg.fit(X_train, y_train)
258
- score_train=reg.score(X_train,y_train)
259
- score_test =reg.score(X_test, y_test)
260
- alpha=reg.alpha_
261
- l1ratio=reg.l1_ratio_
262
-
263
- return reg,alpha,l1ratio,score_train,score_test
264
-
265
- if __name__=='__main__':
266
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
267
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
268
-
269
- model,alpha,l1ratio,score_train,score_test=bestEN2(X,y)
270
- print("%.5f, %.5f, %.5f, %.5f"%(alpha,l1ratio,score_train,score_test))
271
- #结果:0.42,0.99,0.9258,0.9174
272
-
273
- X_new=ndprice[0,0:20]
274
- y_new=model.predict(X_new)
275
- print("%.2f"%y_new)
276
- #结果:119.60
277
-
278
- #=======
279
-
280
- #==============================================================================
281
-
282
- def bestEN3(X,y):
283
- """
284
- 功能:给定特征矩阵和标签,使用弹性网络回归,返回最优的alpha参数和模型
285
- 最优策略:利用cv交叉验证,速度快
286
- 算法贡献者:徐乐欣(韩语国商)
287
- """
288
- import numpy as np
289
- #将整个样本随机分割为训练集和测试集
290
- from sklearn.utils import column_or_1d
291
- y=column_or_1d(y,warn=False)
292
- from sklearn.model_selection import train_test_split
293
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=66)
294
-
295
- from sklearn.linear_model import ElasticNetCV
296
- #reg=ElasticNetCV(cv=5, random_state=0)
297
- #reg.fit(X,y)
298
-
299
- l1list=np.arange(0.01,1,0.01)
300
- ENet=ElasticNetCV(alphas=None, copy_X=True, cv=5, eps=0.001, \
301
- fit_intercept=True,l1_ratio=l1list, max_iter=8000, \
302
- n_alphas=100, n_jobs=None,normalize=True, \
303
- positive=False, precompute='auto', random_state=0, \
304
- selection='cyclic', tol=0.0001, verbose=0)
305
- ENet.fit(X_train, y_train)
306
- score_train=ENet.score(X_train, y_train)
307
- score_test=ENet.score(X_test, y_test)
308
- alpha=ENet.alpha_
309
- l1ratio=ENet.l1_ratio_
310
- #print("S1: %.5f, %.5f, %.5f, %.5f"%(alpha,l1ratio,score_train,score_test))
311
-
312
- return ENet,alpha,l1ratio,score_train,score_test
313
-
314
- if __name__=='__main__':
315
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
316
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
317
-
318
- model,alpha,l1ratio,score_train,score_test=bestEN3(X,y)
319
- print("%.5f, %.5f, %.5f, %.5f"%(alpha,l1ratio,score_train,score_test))
320
- #结果:0.005836,0.99,0.925,0.9194
321
-
322
- X_new=ndprice[0,0:20]
323
- y_new=model.predict(X_new)
324
- print("%.2f"%y_new)
325
- #结果:119.48
326
- #==============================================================================
327
-
328
-
329
- def bestEN1(X,y,maxalpha=2):
330
- """
331
- 功能:给定特征矩阵和标签,使用弹性网络回归,返回最优的alpha参数和模型
332
- 最优策略:对alpha和l1_ratio进行暴力枚举,搜索最高测试集分数,速度中等
333
- 算法贡献者:徐乐欣(韩语国商)
334
- """
335
-
336
- #将整个样本随机分割为训练集和测试集
337
- from sklearn.utils import column_or_1d
338
- y=column_or_1d(y,warn=False)
339
- from sklearn.model_selection import train_test_split
340
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=66)
341
-
342
- #设立初始测试集分数门槛
343
- king_score=0.6
344
- from sklearn.linear_model import ElasticNet
345
-
346
- #限定参数范围
347
- import numpy as np
348
- alphalist=np.arange(0.01,maxalpha,0.01)
349
- l1list =np.arange(0.01,1,0.01)
350
-
351
- for i in alphalist:
352
- for j in l1list:
353
- reg=ElasticNet(alpha=i,l1_ratio=j)
354
- reg.fit(X_train,y_train)
355
- temp_score=reg.score(X_test,y_test)
356
- if temp_score > king_score:
357
- king_score=temp_score
358
- alpha=i
359
- l1ratio=j
360
- score_train=reg.score(X_train,y_train)
361
- score_test=temp_score
362
- model=reg
363
-
364
- return model,alpha,l1ratio,score_train,score_test
365
-
366
- if __name__=='__main__':
367
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
368
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
369
-
370
- model,alpha,l1ratio,score_train,score_test=bestEN1(X,y)
371
- print("%.5f, %.5f, %.5f, %.5f"%(alpha,l1ratio,score_train,score_test))
372
- #结果:1.31,0.56,0.9241,0.9196
373
-
374
- X_new=ndprice[0,0:20]
375
- y_new=model.predict(X_new)
376
- print("%.2f"%y_new)
377
- #结果:119.36
378
- #==============================================================================
379
-
380
-
381
-
382
-
383
-
384
-
385
- #==============================================================================
386
-