siat 3.10.132__py3-none-any.whl → 3.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221) hide show
  1. siat/__init__.py +0 -0
  2. siat/allin.py +8 -0
  3. siat/assets_liquidity.py +0 -0
  4. siat/beta_adjustment.py +0 -0
  5. siat/beta_adjustment_china.py +0 -0
  6. siat/blockchain.py +0 -0
  7. siat/bond.py +0 -0
  8. siat/bond_base.py +0 -0
  9. siat/bond_china.py +0 -0
  10. siat/bond_zh_sina.py +0 -0
  11. siat/capm_beta.py +0 -0
  12. siat/capm_beta2.py +4 -4
  13. siat/common.py +9 -6
  14. siat/compare_cross.py +0 -0
  15. siat/copyrights.py +0 -0
  16. siat/cryptocurrency.py +0 -0
  17. siat/economy.py +0 -0
  18. siat/economy2.py +0 -0
  19. siat/esg.py +0 -0
  20. siat/event_study.py +0 -0
  21. siat/exchange_bond_china.pickle +0 -0
  22. siat/fama_french.py +0 -0
  23. siat/fin_stmt2_yahoo.py +0 -0
  24. siat/financial_base.py +0 -0
  25. siat/financial_statements.py +0 -0
  26. siat/financials.py +0 -0
  27. siat/financials2.py +0 -0
  28. siat/financials_china.py +0 -0
  29. siat/financials_china2.py +0 -0
  30. siat/fund.py +0 -0
  31. siat/fund_china.pickle +0 -0
  32. siat/fund_china.py +0 -0
  33. siat/future_china.py +0 -0
  34. siat/google_authenticator.py +0 -0
  35. siat/grafix.py +55 -4
  36. siat/holding_risk.py +0 -0
  37. siat/luchy_draw.py +0 -0
  38. siat/market_china.py +0 -0
  39. siat/markowitz.py +0 -0
  40. siat/markowitz2.py +1 -0
  41. siat/markowitz2_20250704.py +0 -0
  42. siat/markowitz2_20250705.py +0 -0
  43. siat/markowitz_simple.py +0 -0
  44. siat/ml_cases.py +0 -0
  45. siat/ml_cases_example.py +0 -0
  46. siat/option_china.py +0 -0
  47. siat/option_pricing.py +0 -0
  48. siat/other_indexes.py +0 -0
  49. siat/risk_adjusted_return.py +0 -0
  50. siat/risk_adjusted_return2.py +8 -4
  51. siat/risk_evaluation.py +0 -0
  52. siat/risk_free_rate.py +0 -0
  53. siat/save2docx.py +345 -0
  54. siat/save2pdf.py +145 -0
  55. siat/sector_china.py +0 -0
  56. siat/security_price2.py +0 -0
  57. siat/security_prices.py +168 -6
  58. siat/security_trend.py +0 -0
  59. siat/security_trend2.py +2 -2
  60. siat/stock.py +11 -1
  61. siat/stock_advice_linear.py +0 -0
  62. siat/stock_base.py +0 -0
  63. siat/stock_china.py +0 -0
  64. siat/stock_info.pickle +0 -0
  65. siat/stock_prices_kneighbors.py +0 -0
  66. siat/stock_prices_linear.py +0 -0
  67. siat/stock_profile.py +0 -0
  68. siat/stock_technical.py +0 -0
  69. siat/stooq.py +0 -0
  70. siat/transaction.py +0 -0
  71. siat/translate.py +0 -0
  72. siat/valuation.py +0 -0
  73. siat/valuation_china.py +0 -0
  74. siat/var_model_validation.py +0 -0
  75. siat/yf_name.py +0 -0
  76. {siat-3.10.132.dist-info/licenses → siat-3.11.1.dist-info}/LICENSE +0 -0
  77. {siat-3.10.132.dist-info → siat-3.11.1.dist-info}/METADATA +234 -235
  78. siat-3.11.1.dist-info/RECORD +80 -0
  79. {siat-3.10.132.dist-info → siat-3.11.1.dist-info}/WHEEL +1 -1
  80. {siat-3.10.132.dist-info → siat-3.11.1.dist-info}/top_level.txt +0 -1
  81. build/lib/build/lib/siat/__init__.py +0 -75
  82. build/lib/build/lib/siat/allin.py +0 -137
  83. build/lib/build/lib/siat/assets_liquidity.py +0 -915
  84. build/lib/build/lib/siat/beta_adjustment.py +0 -1058
  85. build/lib/build/lib/siat/beta_adjustment_china.py +0 -548
  86. build/lib/build/lib/siat/blockchain.py +0 -143
  87. build/lib/build/lib/siat/bond.py +0 -2900
  88. build/lib/build/lib/siat/bond_base.py +0 -992
  89. build/lib/build/lib/siat/bond_china.py +0 -100
  90. build/lib/build/lib/siat/bond_zh_sina.py +0 -143
  91. build/lib/build/lib/siat/capm_beta.py +0 -783
  92. build/lib/build/lib/siat/capm_beta2.py +0 -887
  93. build/lib/build/lib/siat/common.py +0 -5360
  94. build/lib/build/lib/siat/compare_cross.py +0 -642
  95. build/lib/build/lib/siat/copyrights.py +0 -18
  96. build/lib/build/lib/siat/cryptocurrency.py +0 -667
  97. build/lib/build/lib/siat/economy.py +0 -1471
  98. build/lib/build/lib/siat/economy2.py +0 -1853
  99. build/lib/build/lib/siat/esg.py +0 -536
  100. build/lib/build/lib/siat/event_study.py +0 -815
  101. build/lib/build/lib/siat/fama_french.py +0 -1521
  102. build/lib/build/lib/siat/fin_stmt2_yahoo.py +0 -982
  103. build/lib/build/lib/siat/financial_base.py +0 -1160
  104. build/lib/build/lib/siat/financial_statements.py +0 -598
  105. build/lib/build/lib/siat/financials.py +0 -2339
  106. build/lib/build/lib/siat/financials2.py +0 -1278
  107. build/lib/build/lib/siat/financials_china.py +0 -4433
  108. build/lib/build/lib/siat/financials_china2.py +0 -2212
  109. build/lib/build/lib/siat/fund.py +0 -629
  110. build/lib/build/lib/siat/fund_china.py +0 -3307
  111. build/lib/build/lib/siat/future_china.py +0 -551
  112. build/lib/build/lib/siat/google_authenticator.py +0 -47
  113. build/lib/build/lib/siat/grafix.py +0 -3636
  114. build/lib/build/lib/siat/holding_risk.py +0 -867
  115. build/lib/build/lib/siat/luchy_draw.py +0 -638
  116. build/lib/build/lib/siat/market_china.py +0 -1168
  117. build/lib/build/lib/siat/markowitz.py +0 -2363
  118. build/lib/build/lib/siat/markowitz2.py +0 -3150
  119. build/lib/build/lib/siat/markowitz2_20250704.py +0 -2969
  120. build/lib/build/lib/siat/markowitz2_20250705.py +0 -3158
  121. build/lib/build/lib/siat/markowitz_simple.py +0 -373
  122. build/lib/build/lib/siat/ml_cases.py +0 -2291
  123. build/lib/build/lib/siat/ml_cases_example.py +0 -60
  124. build/lib/build/lib/siat/option_china.py +0 -3069
  125. build/lib/build/lib/siat/option_pricing.py +0 -1925
  126. build/lib/build/lib/siat/other_indexes.py +0 -409
  127. build/lib/build/lib/siat/risk_adjusted_return.py +0 -1576
  128. build/lib/build/lib/siat/risk_adjusted_return2.py +0 -1900
  129. build/lib/build/lib/siat/risk_evaluation.py +0 -2218
  130. build/lib/build/lib/siat/risk_free_rate.py +0 -351
  131. build/lib/build/lib/siat/sector_china.py +0 -4140
  132. build/lib/build/lib/siat/security_price2.py +0 -727
  133. build/lib/build/lib/siat/security_prices.py +0 -3408
  134. build/lib/build/lib/siat/security_trend.py +0 -402
  135. build/lib/build/lib/siat/security_trend2.py +0 -646
  136. build/lib/build/lib/siat/stock.py +0 -4284
  137. build/lib/build/lib/siat/stock_advice_linear.py +0 -934
  138. build/lib/build/lib/siat/stock_base.py +0 -26
  139. build/lib/build/lib/siat/stock_china.py +0 -2095
  140. build/lib/build/lib/siat/stock_prices_kneighbors.py +0 -910
  141. build/lib/build/lib/siat/stock_prices_linear.py +0 -386
  142. build/lib/build/lib/siat/stock_profile.py +0 -707
  143. build/lib/build/lib/siat/stock_technical.py +0 -3305
  144. build/lib/build/lib/siat/stooq.py +0 -74
  145. build/lib/build/lib/siat/transaction.py +0 -347
  146. build/lib/build/lib/siat/translate.py +0 -5183
  147. build/lib/build/lib/siat/valuation.py +0 -1378
  148. build/lib/build/lib/siat/valuation_china.py +0 -2076
  149. build/lib/build/lib/siat/var_model_validation.py +0 -444
  150. build/lib/build/lib/siat/yf_name.py +0 -811
  151. build/lib/siat/__init__.py +0 -75
  152. build/lib/siat/allin.py +0 -137
  153. build/lib/siat/assets_liquidity.py +0 -915
  154. build/lib/siat/beta_adjustment.py +0 -1058
  155. build/lib/siat/beta_adjustment_china.py +0 -548
  156. build/lib/siat/blockchain.py +0 -143
  157. build/lib/siat/bond.py +0 -2900
  158. build/lib/siat/bond_base.py +0 -992
  159. build/lib/siat/bond_china.py +0 -100
  160. build/lib/siat/bond_zh_sina.py +0 -143
  161. build/lib/siat/capm_beta.py +0 -783
  162. build/lib/siat/capm_beta2.py +0 -887
  163. build/lib/siat/common.py +0 -5360
  164. build/lib/siat/compare_cross.py +0 -642
  165. build/lib/siat/copyrights.py +0 -18
  166. build/lib/siat/cryptocurrency.py +0 -667
  167. build/lib/siat/economy.py +0 -1471
  168. build/lib/siat/economy2.py +0 -1853
  169. build/lib/siat/esg.py +0 -536
  170. build/lib/siat/event_study.py +0 -815
  171. build/lib/siat/fama_french.py +0 -1521
  172. build/lib/siat/fin_stmt2_yahoo.py +0 -982
  173. build/lib/siat/financial_base.py +0 -1160
  174. build/lib/siat/financial_statements.py +0 -598
  175. build/lib/siat/financials.py +0 -2339
  176. build/lib/siat/financials2.py +0 -1278
  177. build/lib/siat/financials_china.py +0 -4433
  178. build/lib/siat/financials_china2.py +0 -2212
  179. build/lib/siat/fund.py +0 -629
  180. build/lib/siat/fund_china.py +0 -3307
  181. build/lib/siat/future_china.py +0 -551
  182. build/lib/siat/google_authenticator.py +0 -47
  183. build/lib/siat/grafix.py +0 -3636
  184. build/lib/siat/holding_risk.py +0 -867
  185. build/lib/siat/luchy_draw.py +0 -638
  186. build/lib/siat/market_china.py +0 -1168
  187. build/lib/siat/markowitz.py +0 -2363
  188. build/lib/siat/markowitz2.py +0 -3150
  189. build/lib/siat/markowitz2_20250704.py +0 -2969
  190. build/lib/siat/markowitz2_20250705.py +0 -3158
  191. build/lib/siat/markowitz_simple.py +0 -373
  192. build/lib/siat/ml_cases.py +0 -2291
  193. build/lib/siat/ml_cases_example.py +0 -60
  194. build/lib/siat/option_china.py +0 -3069
  195. build/lib/siat/option_pricing.py +0 -1925
  196. build/lib/siat/other_indexes.py +0 -409
  197. build/lib/siat/risk_adjusted_return.py +0 -1576
  198. build/lib/siat/risk_adjusted_return2.py +0 -1900
  199. build/lib/siat/risk_evaluation.py +0 -2218
  200. build/lib/siat/risk_free_rate.py +0 -351
  201. build/lib/siat/sector_china.py +0 -4140
  202. build/lib/siat/security_price2.py +0 -727
  203. build/lib/siat/security_prices.py +0 -3408
  204. build/lib/siat/security_trend.py +0 -402
  205. build/lib/siat/security_trend2.py +0 -646
  206. build/lib/siat/stock.py +0 -4284
  207. build/lib/siat/stock_advice_linear.py +0 -934
  208. build/lib/siat/stock_base.py +0 -26
  209. build/lib/siat/stock_china.py +0 -2095
  210. build/lib/siat/stock_prices_kneighbors.py +0 -910
  211. build/lib/siat/stock_prices_linear.py +0 -386
  212. build/lib/siat/stock_profile.py +0 -707
  213. build/lib/siat/stock_technical.py +0 -3305
  214. build/lib/siat/stooq.py +0 -74
  215. build/lib/siat/transaction.py +0 -347
  216. build/lib/siat/translate.py +0 -5183
  217. build/lib/siat/valuation.py +0 -1378
  218. build/lib/siat/valuation_china.py +0 -2076
  219. build/lib/siat/var_model_validation.py +0 -444
  220. build/lib/siat/yf_name.py +0 -811
  221. siat-3.10.132.dist-info/RECORD +0 -218
@@ -1,386 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- @function: 预测美股股价,教学演示用,其他用途责任自负
4
- @model:线性模型,ols, righe, lasso, elasticnet
5
- @version:v1.0,2019.4.4
6
- @purpose: 仅限机器学习课程案例使用
7
- @author: 王德宏,北京外国语大学国际商学院
8
- """
9
-
10
- #=====================================================================
11
- def get_stock_price(ticker,atdate,fromdate):
12
- """
13
- 功能:抓取美股股价
14
- 输出:指定美股的收盘价格序列,最新日期的股价排列在前
15
- ticker:美股股票代码
16
- atdate:当前日期,既可以是今天日期,也可以是一个历史日期,datetime类型
17
- fromdate:样本开始日期,尽量远的日期,以便取得足够多的原始样本,类型同atdate
18
- """
19
-
20
- #仅为调试用的函数入口参数,正式使用前需要注释掉!
21
- #ticker='MSFT'
22
- #atdate='3/29/2019'
23
- #fromdate='1/1/2015'
24
- #---------------------------------------------
25
-
26
- #抓取美股股票价格
27
- from pandas_datareader import data
28
- price=data.DataReader(ticker,'stooq',fromdate,atdate)
29
-
30
- #去掉比起始日期更早的样本
31
- price2=price[price.index >= fromdate]
32
-
33
-
34
- #按日期降序排序,近期的价格排在前面
35
- sortedprice=price2.sort_index(axis=0,ascending=False)
36
-
37
- #提取日期和星期几
38
- #sortedprice['Date']=sortedprice.index.date
39
- sortedprice['Date']=sortedprice.index.strftime("%Y-%m-%d")
40
- sortedprice['Weekday']=sortedprice.index.weekday+1
41
-
42
- #生成输出数据格式:日期,星期几,收盘价
43
- dfprice=sortedprice[['Date','Weekday','Close']]
44
-
45
- return dfprice
46
-
47
-
48
- if __name__=='__main__':
49
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
50
- dfprice.head(5)
51
- dfprice.tail(3)
52
- dfprice[dfprice.Date == '2019-03-29']
53
- dfprice[(dfprice.Date>='2019-03-20') & (dfprice.Date<='2019-03-29')]
54
-
55
-
56
- #=====================================================================
57
- def make_price_sample(dfprice,n_nextdays=1,n_samples=240,n_features=20):
58
- """
59
- 功能:生成指定股票的价格样本
60
- ticker:美股股票代码
61
- n_nextdays:预测从atdate开始未来第几天的股价,默认为1
62
- n_samples:需要生成的样本个数,默认240个(一年的平均交易天数)
63
- n_features:使用的特征数量,默认20个(一个月的平均交易天数)
64
- """
65
-
66
- #提取收盘价,Series类型
67
- closeprice=dfprice.Close
68
-
69
- #将closeprice转换为机器学习需要的ndarray类型ndprice
70
- import numpy as np
71
- ndprice=np.asmatrix(closeprice,dtype=None)
72
-
73
- #生成第一个标签样本:标签矩阵y(形状:n_samples x 1)
74
- import numpy as np
75
- y=np.asmatrix(ndprice[0,0])
76
- #生成第一个特征样本:特征矩阵X(形状:n_samples x n_features)
77
- X=ndprice[0,n_nextdays:n_features+n_nextdays]
78
-
79
- #生成其余的标签样本和特征样本
80
- for i in range(1,n_samples):
81
- y_row=np.asmatrix(ndprice[0,i])
82
- y=np.append(y,y_row,axis=0)
83
-
84
- X_row=ndprice[0,(n_nextdays+i):(n_features+n_nextdays+i)]
85
- X=np.append(X,X_row,axis=0)
86
-
87
- return X,y,ndprice
88
-
89
- if __name__=='__main__':
90
- fdprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
91
- X,y,ndprice=make_price_sample(fdprice,1,240,20)
92
- y[:5]
93
- y[2:5] #第1行的序号为0
94
- X[:5]
95
- X[:-5]
96
- X[3-1,2-1]
97
-
98
-
99
- #=====================================================================
100
- def bestR1(X,y):
101
- """
102
- 功能:给定特征矩阵和标签,使用岭回归,返回最优的alpha参数和模型
103
- 最优策略:测试集分数最高,不管过拟合问题
104
- """
105
-
106
- import numpy as np
107
- #将整个样本随机分割为训练集和测试集
108
- from sklearn.model_selection import train_test_split
109
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=0)
110
-
111
- #初始化alpha,便于判断上行下行方向
112
- alphalist=[0.001,0.0011,0.00999,0.01,0.01001,0.999,1,1.01, \
113
- 9.99,10,10.01,99,100,101,999,1000,1001,10000]
114
-
115
- from sklearn.linear_model import RidgeCV
116
- reg=RidgeCV(alphas=alphalist,cv=5,fit_intercept=True,normalize=True)
117
-
118
- reg.fit(X_train, y_train)
119
- score_train=reg.score(X_train, y_train)
120
- score_test=reg.score(X_test, y_test)
121
- alpha=reg.alpha_
122
- #print("%.5f, %.5f, %.5f"%(alpha,score_train,score_test))
123
-
124
- #确定alpha参数的优化范围
125
- if alpha in [0.001,0.01,1,2,10,100,1000,10000]:
126
- #print("%.5f, %.5f, %.5f"%(alpha,score_train,score_test))
127
- return reg,alpha,score_train,score_test
128
-
129
- if 0.001 < alpha < 0.01:
130
- alphalist1=np.arange(0.001,0.01,0.0005)
131
- if 0.01 < alpha < 1:
132
- alphalist1=np.arange(0.01,1,0.005)
133
- if 1 < alpha < 10:
134
- alphalist1=np.arange(1,10,0.01)
135
- if 10 < alpha < 100:
136
- alphalist1=np.arange(10,100,0.1)
137
- if 100 < alpha < 1000:
138
- alphalist1=np.arange(100,1000,1)
139
- if 1000 < alpha < 10000:
140
- alphalist1=np.arange(1000,10000,10)
141
-
142
- reg1=RidgeCV(alphas=alphalist1,cv=5,fit_intercept=True,normalize=True)
143
- reg1.fit(X_train, y_train)
144
- score1_train=reg1.score(X_train,y_train)
145
- score1_test =reg1.score(X_test, y_test)
146
- alpha1=reg1.alpha_
147
-
148
- #print("%.5f, %.5f, %.5f"%(alpha1,score1_train,score1_test))
149
- return reg1,alpha1,score1_train,score1_test
150
-
151
-
152
- if __name__=='__main__':
153
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
154
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
155
-
156
- model,alpha,score_train,score_test=bestR1(X,y)
157
- print("%.5f, %.5f, %.5f"%(alpha,score_train,score_test))
158
- #结果:0.045,0.9277,0.8940
159
-
160
- X_new=ndprice[0,0:20]
161
- y_new=model.predict(X_new)
162
- print("%.2f"%y_new)
163
- #结果:119.43
164
- #=====================================================================
165
- def bestL1(X,y):
166
- """
167
- 功能:给定特征矩阵和标签,使用拉索回归,返回最优的alpha参数和模型
168
- 最优策略:测试集分数最高,不管过拟合问题
169
- """
170
- import numpy as np
171
- #将整个样本随机分割为训练集和测试集
172
- from sklearn.utils import column_or_1d
173
- y=column_or_1d(y,warn=False)
174
- from sklearn.model_selection import train_test_split
175
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=0)
176
-
177
-
178
- #初始alpha,便于判断上行下行方向
179
- alphalist=[0.001,0.0011,0.00999,0.01,0.01001,0.999,1,1.01,1.99,2,2.01, \
180
- 9.99,10,10.01,99,100,101,999,1000,1001,10000]
181
-
182
- from sklearn.linear_model import LassoCV
183
- reg=LassoCV(alphas=alphalist,max_iter=10**6, \
184
- cv=5,fit_intercept=True,normalize=True)
185
- reg.fit(X_train, y_train)
186
- score_train=reg.score(X_train,y_train)
187
- score_test =reg.score(X_test, y_test)
188
- alpha=reg.alpha_
189
- #print("Step0: %.4f, %.5f, %.5f"%(alpha,score_train,score_test))
190
-
191
- #确定alpha参数的优化范围
192
- if alpha in [0.001,0.01,1,2,10,100,1000,10000]:
193
- #print("Step01: %.5f, %.5f, %.5f"%(alpha,score_train,score_test))
194
- return reg,alpha,score_train,score_test
195
-
196
- if 0.001 < alpha < 0.01:
197
- alphalist1=np.arange(0.0015,0.01,0.0005)
198
-
199
- if 0.01 < alpha < 1:
200
- alphalist1=np.arange(0.015,1,0.005)
201
-
202
- if 1 < alpha < 10:
203
- alphalist1=np.arange(1.01,10,0.01)
204
-
205
- if 10 < alpha < 100:
206
- alphalist1=np.arange(10.1,100,0.1)
207
-
208
- if 100 < alpha < 1000:
209
- alphalist1=np.arange(101,1000,1)
210
-
211
- if 1000 < alpha < 10000:
212
- alphalist1=np.arange(1010,10000,10)
213
-
214
- reg1=LassoCV(alphas=alphalist1,cv=5,fit_intercept=True,normalize=True)
215
- reg1.fit(X_train, y_train)
216
- score1_train=reg1.score(X_train,y_train)
217
- score1_test =reg1.score(X_test, y_test)
218
- alpha1=reg1.alpha_
219
- #print("Step1: %.4f, %.5f, %.5f"%(alpha1,score1_train,score1_test))
220
- return reg1,alpha1,score1_train,score1_test
221
-
222
- if __name__=='__main__':
223
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
224
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
225
-
226
- model,alpha,score_train,score_test=bestL1(X,y)
227
- print("%.5f, %.5f, %.5f"%(alpha,score_train,score_test))
228
- #结果:0.015,0.9284,0.9043
229
-
230
- X_new=ndprice[0,0:20]
231
- y_new=model.predict(X_new)
232
- print("%.2f"%y_new)
233
- #结果:119.37
234
-
235
- #=====================================================================
236
-
237
- def bestEN2(X,y,maxalpha=2):
238
- """
239
- 功能:给定特征矩阵和标签,使用弹性网络回归,返回最优的alpha参数和模型
240
- 最优策略:利用ElasticNetCV筛选机制,速度慢
241
- """
242
- #将整个样本随机分割为训练集和测试集
243
- from sklearn.utils import column_or_1d
244
- y=column_or_1d(y,warn=False)
245
-
246
- from sklearn.model_selection import train_test_split
247
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=66)
248
-
249
- #限定参数范围
250
- import numpy as np
251
- alphalist=np.arange(0.01,maxalpha,0.01)
252
- l1list =np.arange(0.01,1,0.01)
253
-
254
- from sklearn.linear_model import ElasticNetCV
255
- reg=ElasticNetCV(alphas=alphalist,l1_ratio=l1list)
256
-
257
- reg.fit(X_train, y_train)
258
- score_train=reg.score(X_train,y_train)
259
- score_test =reg.score(X_test, y_test)
260
- alpha=reg.alpha_
261
- l1ratio=reg.l1_ratio_
262
-
263
- return reg,alpha,l1ratio,score_train,score_test
264
-
265
- if __name__=='__main__':
266
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
267
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
268
-
269
- model,alpha,l1ratio,score_train,score_test=bestEN2(X,y)
270
- print("%.5f, %.5f, %.5f, %.5f"%(alpha,l1ratio,score_train,score_test))
271
- #结果:0.42,0.99,0.9258,0.9174
272
-
273
- X_new=ndprice[0,0:20]
274
- y_new=model.predict(X_new)
275
- print("%.2f"%y_new)
276
- #结果:119.60
277
-
278
- #=======
279
-
280
- #==============================================================================
281
-
282
- def bestEN3(X,y):
283
- """
284
- 功能:给定特征矩阵和标签,使用弹性网络回归,返回最优的alpha参数和模型
285
- 最优策略:利用cv交叉验证,速度快
286
- 算法贡献者:徐乐欣(韩语国商)
287
- """
288
- import numpy as np
289
- #将整个样本随机分割为训练集和测试集
290
- from sklearn.utils import column_or_1d
291
- y=column_or_1d(y,warn=False)
292
- from sklearn.model_selection import train_test_split
293
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=66)
294
-
295
- from sklearn.linear_model import ElasticNetCV
296
- #reg=ElasticNetCV(cv=5, random_state=0)
297
- #reg.fit(X,y)
298
-
299
- l1list=np.arange(0.01,1,0.01)
300
- ENet=ElasticNetCV(alphas=None, copy_X=True, cv=5, eps=0.001, \
301
- fit_intercept=True,l1_ratio=l1list, max_iter=8000, \
302
- n_alphas=100, n_jobs=None,normalize=True, \
303
- positive=False, precompute='auto', random_state=0, \
304
- selection='cyclic', tol=0.0001, verbose=0)
305
- ENet.fit(X_train, y_train)
306
- score_train=ENet.score(X_train, y_train)
307
- score_test=ENet.score(X_test, y_test)
308
- alpha=ENet.alpha_
309
- l1ratio=ENet.l1_ratio_
310
- #print("S1: %.5f, %.5f, %.5f, %.5f"%(alpha,l1ratio,score_train,score_test))
311
-
312
- return ENet,alpha,l1ratio,score_train,score_test
313
-
314
- if __name__=='__main__':
315
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
316
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
317
-
318
- model,alpha,l1ratio,score_train,score_test=bestEN3(X,y)
319
- print("%.5f, %.5f, %.5f, %.5f"%(alpha,l1ratio,score_train,score_test))
320
- #结果:0.005836,0.99,0.925,0.9194
321
-
322
- X_new=ndprice[0,0:20]
323
- y_new=model.predict(X_new)
324
- print("%.2f"%y_new)
325
- #结果:119.48
326
- #==============================================================================
327
-
328
-
329
- def bestEN1(X,y,maxalpha=2):
330
- """
331
- 功能:给定特征矩阵和标签,使用弹性网络回归,返回最优的alpha参数和模型
332
- 最优策略:对alpha和l1_ratio进行暴力枚举,搜索最高测试集分数,速度中等
333
- 算法贡献者:徐乐欣(韩语国商)
334
- """
335
-
336
- #将整个样本随机分割为训练集和测试集
337
- from sklearn.utils import column_or_1d
338
- y=column_or_1d(y,warn=False)
339
- from sklearn.model_selection import train_test_split
340
- X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=66)
341
-
342
- #设立初始测试集分数门槛
343
- king_score=0.6
344
- from sklearn.linear_model import ElasticNet
345
-
346
- #限定参数范围
347
- import numpy as np
348
- alphalist=np.arange(0.01,maxalpha,0.01)
349
- l1list =np.arange(0.01,1,0.01)
350
-
351
- for i in alphalist:
352
- for j in l1list:
353
- reg=ElasticNet(alpha=i,l1_ratio=j)
354
- reg.fit(X_train,y_train)
355
- temp_score=reg.score(X_test,y_test)
356
- if temp_score > king_score:
357
- king_score=temp_score
358
- alpha=i
359
- l1ratio=j
360
- score_train=reg.score(X_train,y_train)
361
- score_test=temp_score
362
- model=reg
363
-
364
- return model,alpha,l1ratio,score_train,score_test
365
-
366
- if __name__=='__main__':
367
- dfprice=get_stock_price('MSFT','4/3/2019','1/1/2015')
368
- X,y,ndprice=make_price_sample(dfprice,1,240,20)
369
-
370
- model,alpha,l1ratio,score_train,score_test=bestEN1(X,y)
371
- print("%.5f, %.5f, %.5f, %.5f"%(alpha,l1ratio,score_train,score_test))
372
- #结果:1.31,0.56,0.9241,0.9196
373
-
374
- X_new=ndprice[0,0:20]
375
- y_new=model.predict(X_new)
376
- print("%.2f"%y_new)
377
- #结果:119.36
378
- #==============================================================================
379
-
380
-
381
-
382
-
383
-
384
-
385
- #==============================================================================
386
-