ezKit 1.9.12__py3-none-any.whl → 1.10.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
ezKit/utils.py CHANGED
@@ -9,12 +9,11 @@ import time
9
9
  import tomllib
10
10
  from copy import deepcopy
11
11
  from itertools import islice
12
- from multiprocessing import Pool, Process
12
+ from multiprocessing import Pool
13
13
  from multiprocessing.pool import ThreadPool
14
14
  from pathlib import Path
15
15
  from shutil import rmtree
16
- from threading import Thread
17
- from typing import Any, Callable, List, Optional, Union
16
+ from typing import Any, Callable, List, Union
18
17
  from urllib.parse import ParseResult, urlparse
19
18
  from uuid import uuid4
20
19
 
@@ -742,31 +741,23 @@ def parent_dir(
742
741
  # --------------------------------------------------------------------------------------------------
743
742
 
744
743
 
745
- def retry(
746
- times: int,
747
- func: Callable,
748
- **kwargs
749
- ):
744
+ def retry(func: Callable, times: int = 3, **kwargs):
750
745
  """重试"""
746
+
751
747
  # 函数传递参数: https://stackoverflow.com/a/803632
752
748
  # callable() 判断类型是非为函数: https://stackoverflow.com/a/624939
753
- try:
754
- _num = 0
755
- while True:
756
- # 重试次数判断 (0 表示无限次数, 这里条件使用 > 0, 表示有限次数)
757
- if times > 0:
758
- _num += 1
759
- if _num > times:
760
- return
761
- # 执行函数
762
- try:
763
- return func(**kwargs)
764
- except Exception as e:
765
- logger.exception(e)
749
+
750
+ for attempt in range(times):
751
+ try:
752
+ # 执行函数并结果
753
+ return func(**kwargs)
754
+ except Exception as e:
755
+ logger.exception(e)
756
+ if attempt < (times - 1):
766
757
  logger.info('retrying ...')
767
- continue
768
- except Exception as e:
769
- logger.exception(e)
758
+ else:
759
+ logger.error("all retries failed")
760
+ return False
770
761
 
771
762
 
772
763
  # --------------------------------------------------------------------------------------------------
@@ -1199,65 +1190,61 @@ def delete_directory(
1199
1190
  # --------------------------------------------------------------------------------------------------
1200
1191
 
1201
1192
 
1202
- def process_pool(
1193
+ def processor(
1203
1194
  process_func: Callable,
1204
- process_data: Any = None,
1195
+ process_data: List[Any],
1205
1196
  process_num: int = 2,
1206
1197
  thread: bool = False,
1207
1198
  **kwargs
1208
- ) -> list | bool:
1209
- """
1210
- 多线程(MultiThread) | 多进程(MultiProcess)
1211
- """
1212
- # ThreadPool 线程池
1199
+ ) -> Union[List[Any], bool]:
1200
+ """使用多线程或多进程对数据进行并行处理"""
1201
+
1202
+ # :param process_func: 处理函数
1203
+ # :param process_data: 待处理数据列表
1204
+ # :param process_num: 并行数量
1205
+ # :param thread: 是否使用多线程
1206
+ # :param kwargs: 其他可选参数传递给线程池或进程池
1207
+ # :return: 处理后的结果列表或 False(异常情况)
1208
+ #
1209
+ # MultiThread 多线程
1210
+ # MultiProcess 多进程
1211
+ #
1212
+ # ThreadPool 线程池
1213
+ # Pool 进程池
1214
+ #
1213
1215
  # ThreadPool 共享内存, Pool 不共享内存
1214
1216
  # ThreadPool 可以解决 Pool 在某些情况下产生的 Can't pickle local object 的错误
1215
- # https://stackoverflow.com/a/58897266
1216
- try:
1217
+ # https://stackoverflow.com/a/58897266
1218
+ #
1219
+ # 如果要启动一个新的进程或者线程, 将 process_num 设置为 1 即可
1217
1220
 
1218
- # 处理数据
1219
- if len(process_data) <= process_num:
1220
- process_num = len(process_data)
1221
- _data = process_data
1222
- else:
1223
- _data = list_split(process_data, process_num, equally=True)
1221
+ try:
1224
1222
 
1225
- if _data is None:
1223
+ # 检查参数
1224
+ if not check_arguments([(process_data, list, "process_data")]):
1226
1225
  return False
1227
1226
 
1228
- # 执行函数
1229
- if isTrue(thread, bool):
1230
- # 多线程
1231
- logger.info("execute multi thread ......")
1232
- with ThreadPool(process_num, **kwargs) as p:
1233
- return p.map(process_func, _data)
1234
- else:
1235
- # 多进程
1236
- logger.info("execute multi process ......")
1237
- with Pool(process_num, **kwargs) as p:
1238
- return p.map(process_func, _data)
1227
+ # 确保并行数不超过数据量
1228
+ process_num = min(len(process_data), process_num)
1229
+ _data_chunks = (
1230
+ list_split(process_data, process_num, equally=True)
1231
+ if process_num > 1
1232
+ else [process_data]
1233
+ )
1239
1234
 
1240
- except Exception as e:
1241
- logger.exception(e)
1242
- return False
1235
+ if not _data_chunks:
1236
+ logger.error("data chunks error")
1237
+ return False
1243
1238
 
1239
+ logger.info(
1240
+ f"Starting {'multi-threading' if thread else 'multi-processing'} with {process_num} workers..."
1241
+ )
1242
+
1243
+ # 执行多线程或多进程任务
1244
+ pool_cls = ThreadPool if thread else Pool
1245
+ with pool_cls(process_num, **kwargs) as pool:
1246
+ return pool.map(process_func, _data_chunks)
1244
1247
 
1245
- def new_process(
1246
- process_func: Callable,
1247
- process_data: Any = None,
1248
- thread: bool = False,
1249
- daemon: bool = True,
1250
- **kwargs
1251
- ) -> Thread | Process | bool:
1252
- """New Process"""
1253
- try:
1254
- if isTrue(thread, bool):
1255
- process = Thread(target=process_func, args=process_data, **kwargs)
1256
- else:
1257
- process = Process(target=process_func, args=process_data, **kwargs)
1258
- process.daemon = daemon
1259
- process.start()
1260
- return process
1261
1248
  except Exception as e:
1262
1249
  logger.exception(e)
1263
1250
  return False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ezKit
3
- Version: 1.9.12
3
+ Version: 1.10.1
4
4
  Summary: Easy Kit
5
5
  Author: septvean
6
6
  Author-email: septvean@gmail.com
@@ -2,19 +2,17 @@ ezKit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  ezKit/bottle.py,sha256=usKK1wVaZw4_D-4VwMYmOIc8jtz4TrpM30nck59HMFw,180178
3
3
  ezKit/bottle_extensions.py,sha256=3reEQVZuHklXTl6r7F8kiBFFPb0RaAGc3mYJJnrMDjQ,1129
4
4
  ezKit/cipher.py,sha256=0T_StbjiNI4zgrjVgcfU-ffKgu1waBA9UDudAnqFcNM,2896
5
- ezKit/cls.py,sha256=e7_72kv0Q_o023xcjKNtrkfKg7frABQvCF_JjoHV94U,10800
6
5
  ezKit/database.py,sha256=Rc4RgjHOOtf5dMLvMkK1beRfbIai5E1x4HTsDwKsA-Q,6822
7
6
  ezKit/http.py,sha256=i3Kn5AMAMicDMcDjxKKZU7zqEKTU88Ec9_LwCuBJy-0,1801
8
7
  ezKit/mongo.py,sha256=dOm_1wXEPp_e8Ml5Qq78M7FDNrQUAZaThzVIiiLJJwk,2393
9
8
  ezKit/qywx.py,sha256=X_H4fzP-iEqeDEbumr7D1bXi6dxczaxfO8iyutzy02s,7171
10
9
  ezKit/redis.py,sha256=g2_V4jvq0djRc20jLZkgeAeF_bYrq-Rbl_kHcCUPZcA,1965
11
10
  ezKit/sendemail.py,sha256=tRXCsJm_RfTJ9xEWe_lTQ5kOs2JxHGPXvq0oWA7prq0,7263
12
- ezKit/stock.py,sha256=4wphZahpiDs0MuPVCUcD22joOQldJhmXjogdroxyR00,12346
13
11
  ezKit/token.py,sha256=HKREyZj_T2S8-aFoFIrBXTaCKExQq4zE66OHXhGHqQg,1750
14
- ezKit/utils.py,sha256=TDsL3PRkQy6NdZgphkgwacbWvHqEmq4LOkfNzmxV4DY,42682
12
+ ezKit/utils.py,sha256=k3hSnOwNSyyRDVwfEzQUXQh_oJJ51KOT-PvLFPYOtOE,42517
15
13
  ezKit/xftp.py,sha256=XyIdr_2rxRVLqPofG6fIYWhAMVsFwTyp46dg5P9FLW4,7774
16
- ezKit-1.9.12.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
17
- ezKit-1.9.12.dist-info/METADATA,sha256=cCgqBuFpmH0zpv8j267E9tTkm6q1WzdfOhUCwS3MTjI,191
18
- ezKit-1.9.12.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
19
- ezKit-1.9.12.dist-info/top_level.txt,sha256=aYLB_1WODsqNTsTFWcKP-BN0KCTKcV-HZJ4zlHkCFw8,6
20
- ezKit-1.9.12.dist-info/RECORD,,
14
+ ezKit-1.10.1.dist-info/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
15
+ ezKit-1.10.1.dist-info/METADATA,sha256=7-BrvzaxysoECMgvFVFeBbVHX2RgQGbdYUrZS3xumUo,191
16
+ ezKit-1.10.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
17
+ ezKit-1.10.1.dist-info/top_level.txt,sha256=aYLB_1WODsqNTsTFWcKP-BN0KCTKcV-HZJ4zlHkCFw8,6
18
+ ezKit-1.10.1.dist-info/RECORD,,
ezKit/cls.py DELETED
@@ -1,313 +0,0 @@
1
- """财联社数据"""
2
- import re
3
-
4
- import pandas as pd
5
- import requests
6
- from loguru import logger
7
-
8
- from . import stock, utils
9
-
10
-
11
- def up_down_analysis(
12
- target: str = "up_pool",
13
- df: bool = False
14
- ) -> list | pd.DataFrame | None:
15
- """涨停跌停数据"""
16
-
17
- # 判断参数是否正确
18
- match True:
19
- case True if not utils.isTrue(target, str):
20
- logger.error("argument error: target")
21
- return None
22
- case _:
23
- pass
24
-
25
- info: str = "获取涨停池股票"
26
- match True:
27
- case True if target == "up_pool":
28
- info = "获取涨停池股票"
29
- case True if target == "continuous_up_pool":
30
- info = "获取连板池股票"
31
- case True if target == "up_open_pool":
32
- info = "获取炸板池股票"
33
- case True if target == "down_pool":
34
- info = "获取跌停池股票"
35
- case _:
36
- pass
37
-
38
- try:
39
- logger.info(f"{info} ......")
40
-
41
- user_agent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
42
- headers = {"User-Agent": user_agent}
43
-
44
- # 涨停池: https://x-quote.cls.cn/quote/index/up_down_analysis?rever=1&way=last_px&type=up_pool
45
- # 连板池: https://x-quote.cls.cn/quote/index/up_down_analysis?rever=1&way=last_px&type=continuous_up_pool
46
- # 炸板池: https://x-quote.cls.cn/quote/index/up_down_analysis?rever=1&way=last_px&type=up_open_pool
47
- # 跌停池: https://x-quote.cls.cn/quote/index/up_down_analysis?rever=1&way=last_px&type=down_pool
48
- api = f"https://x-quote.cls.cn/quote/index/up_down_analysis?rever=1&way=last_px&type={target}"
49
-
50
- response = requests.get(api, headers=headers, timeout=10)
51
-
52
- response_dict: dict = response.json()
53
-
54
- result: list = []
55
-
56
- for i in response_dict["data"]:
57
-
58
- # if re.match(r"^(sz00|sh60)", i["secu_code"]):
59
- # print(i["secu_code"])
60
-
61
- # if re.search(r"ST|银行", i["secu_name"]):
62
- # print(i["secu_name"])
63
-
64
- # 主板, 非ST, 非银行, 非证券
65
- if (not re.match(r"^(sz00|sh60)", i["secu_code"])) or re.search(r"ST|银行|证券", i["secu_name"]):
66
- continue
67
-
68
- if target in ["up_pool", "up_pool"]:
69
- result.append({
70
- "code": stock.coderename(i["secu_code"], restore=True),
71
- "name": i["secu_name"],
72
- "up_days": i["limit_up_days"],
73
- "reason": i["up_reason"]
74
- })
75
-
76
- if target in ["up_open_pool", "down_pool"]:
77
- result.append({
78
- "code": stock.coderename(i["secu_code"], restore=True),
79
- "name": i["secu_name"]
80
- })
81
-
82
- if not utils.isTrue(df, bool):
83
- logger.success(f"{info} [成功]")
84
- return result
85
-
86
- # data: pd.DataFrame = pd.DataFrame(response_dict["data"], columns=["secu_code", "secu_name", "limit_up_days", "up_reason"])
87
- # data = data.rename(columns={"secu_code": "code", "secu_name": "name", "limit_up_days": "up_days", "up_reason": "reason"})
88
-
89
- return pd.DataFrame(data=pd.DataFrame(result))
90
-
91
- except Exception as e:
92
- logger.error(f"{info} [失败]")
93
- logger.exception(e)
94
- return None
95
-
96
-
97
- # --------------------------------------------------------------------------------------------------
98
-
99
-
100
- def latest_data(
101
- payload: str | dict,
102
- data_type: str = "stock",
103
- df: bool = False
104
- ) -> list | pd.DataFrame | None:
105
- """股票或板块的最新数据"""
106
-
107
- # 热门板块
108
- # https://www.cls.cn/hotPlate
109
- # 行业板块
110
- # https://x-quote.cls.cn/web_quote/plate/plate_list?rever=1&way=change&type=industry
111
- # 概念板块
112
- # https://x-quote.cls.cn/web_quote/plate/plate_list?rever=1&way=change&type=concept
113
- # 地域板块
114
- # https://x-quote.cls.cn/web_quote/plate/plate_list?rever=1&way=change&type=area
115
-
116
- # ----------------------------------------------------------------------------------------------
117
-
118
- # 判断参数类型
119
- match True:
120
- case True if not utils.isTrue(payload, (str, dict)):
121
- logger.error("argument error: payload")
122
- return None
123
- case True if not utils.isTrue(data_type, str):
124
- logger.error("argument error: data_type")
125
- return None
126
- case _:
127
- pass
128
-
129
- # ----------------------------------------------------------------------------------------------
130
-
131
- # 判断数据类型. 数据类型: 个股, 板块 (产业链: industry)
132
- if data_type not in ["stock", "plate"]:
133
- logger.error("data_type error")
134
- return None
135
-
136
- # ----------------------------------------------------------------------------------------------
137
-
138
- # 日志信息
139
-
140
- # 个股 (默认)
141
- info: str = "获取股票最新数据"
142
-
143
- # 板块
144
- if data_type == "plate":
145
- info = "获取板块最新数据"
146
-
147
- # match True:
148
- # case True if data_type == "plate":
149
- # info = "获取板块最新数据"
150
- # case True if data_type == "industry":
151
- # info = "获取产业链最新数据"
152
- # case _:
153
- # pass
154
-
155
- # ----------------------------------------------------------------------------------------------
156
-
157
- try:
158
-
159
- logger.info(f"{info} ......")
160
-
161
- # ------------------------------------------------------------------------------------------
162
-
163
- # HTTP User Agent
164
- user_agent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
165
-
166
- # HTTP Headers
167
- headers = {"User-Agent": user_agent}
168
-
169
- # ------------------------------------------------------------------------------------------
170
-
171
- # 请求参数
172
- params: dict = {}
173
-
174
- # 默认请求参数
175
- if isinstance(payload, str) and utils.isTrue(payload, str):
176
- params = {"secu_code": payload}
177
-
178
- # 请求参数
179
- if isinstance(payload, dict) and utils.isTrue(payload, dict):
180
- params = payload
181
-
182
- # ------------------------------------------------------------------------------------------
183
-
184
- # 不直接在API后面跟参数, 使用 params 传递参数
185
-
186
- # API: 股票
187
- # api: str = f"https://x-quote.cls.cn/quote/stock/basic?secu_code={code}"
188
- api: str = "https://x-quote.cls.cn/quote/stock/basic"
189
-
190
- # API: 板块
191
- if data_type == "plate":
192
- # api = f"https://x-quote.cls.cn/web_quote/plate/stocks?secu_code={code}"
193
- api = "https://x-quote.cls.cn/web_quote/plate/stocks"
194
-
195
- # match True:
196
- # case True if data_type == "plate":
197
- # # 板块
198
- # # api = f"https://x-quote.cls.cn/web_quote/plate/stocks?secu_code={code}"
199
- # api = "https://x-quote.cls.cn/web_quote/plate/stocks"
200
- # case True if data_type == "industry":
201
- # # 产业链
202
- # # api = f"https://x-quote.cls.cn/web_quote/plate/industry?secu_code={code}"
203
- # api = "https://x-quote.cls.cn/web_quote/plate/industry"
204
- # case _:
205
- # pass
206
-
207
- # ------------------------------------------------------------------------------------------
208
-
209
- # 获取数据
210
- # response = requests.get(api, headers=headers, timeout=10)
211
- response = requests.get(api, headers=headers, params=params, timeout=10)
212
-
213
- # 转换数据类型
214
- response_dict: dict = response.json()
215
-
216
- # 判断数据是否正确
217
- if True not in [utils.isTrue(response_dict["data"], dict), utils.isTrue(response_dict["data"], list)]:
218
- logger.error(f"{info} [失败]")
219
- return None
220
-
221
- # ------------------------------------------------------------------------------------------
222
-
223
- # 个股
224
-
225
- if data_type == "stock":
226
-
227
- # 停牌, 返回 None
228
- if response_dict["data"]["trade_status"] == "STOPT":
229
- logger.error(f"{info} [停牌]")
230
- return None
231
-
232
- # pd.DataFrame 数据
233
- if utils.isTrue(df, bool):
234
- df_data = {
235
- # "date": [pd.to_datetime(date_today)],
236
- "open": [float(response_dict["data"]["open_px"])],
237
- "close": [float(response_dict["data"]["last_px"])],
238
- "high": [float(response_dict["data"]["high_px"])],
239
- "low": [float(response_dict["data"]["low_px"])],
240
- "volume": [int(response_dict["data"]["business_amount"])],
241
- "turnover": [float(response_dict["data"]["tr"])]
242
- }
243
- logger.success(f"{info} [成功]")
244
- return pd.DataFrame(data=df_data)
245
-
246
- # 默认返回的数据
247
- logger.success(f"{info} [成功]")
248
- return response_dict["data"]
249
-
250
- # ------------------------------------------------------------------------------------------
251
-
252
- # 板块
253
-
254
- # 板块数据不能转换为 pd.DataFrame
255
- if (data_type == "plate") and utils.isTrue(df, bool):
256
- logger.error(f"{info} [错误]")
257
- return None
258
-
259
- # 数据结果
260
- result: list = []
261
-
262
- # 筛选 主板, 非ST, 非银行, 非证券 的股票
263
- for i in response_dict["data"]["stocks"]:
264
- if (re.match(r"^(sz00|sh60)", i["secu_code"])) and (not re.search(r"ST|银行|证券", i["secu_name"])):
265
- result.append(i)
266
-
267
- # 返回数据
268
- logger.success(f"{info} [成功]")
269
- return result
270
-
271
- except Exception as e:
272
- logger.error(f"{info} [失败]")
273
- logger.exception(e)
274
- return None
275
-
276
-
277
- # --------------------------------------------------------------------------------------------------
278
-
279
-
280
- def plate_codes(
281
- plate: str
282
- ) -> list | None:
283
- """获取板块成分股代码"""
284
-
285
- # 判断参数是否正确
286
- match True:
287
- case True if not utils.isTrue(plate, str):
288
- logger.error("argument error: plate")
289
- return None
290
- case _:
291
- pass
292
-
293
- info: str = "获取板块成分股代码"
294
-
295
- try:
296
-
297
- logger.info(f"{info} ......")
298
-
299
- items = latest_data(payload=plate, data_type="plate")
300
-
301
- if isinstance(items, list):
302
- codes: list = [stock.coderename(i["secu_code"], restore=True) for i in items]
303
- codes.sort()
304
- logger.success(f"{info} [成功]")
305
- return codes
306
-
307
- logger.error(f"{info} [失败]")
308
- return None
309
-
310
- except Exception as e:
311
- logger.error(f"{info} [失败]")
312
- logger.exception(e)
313
- return None
ezKit/stock.py DELETED
@@ -1,355 +0,0 @@
1
- """股票"""
2
- import re
3
- from copy import deepcopy
4
-
5
- import akshare as ak
6
- import numpy as np
7
- import talib as ta
8
- from loguru import logger
9
- from pandas import DataFrame
10
- from sqlalchemy.engine import Engine
11
-
12
- from . import utils
13
-
14
-
15
- def coderename(
16
- target: str | dict,
17
- restore: bool = False
18
- ) -> str | dict | None:
19
- """代码重命名"""
20
-
21
- # 正向:
22
- # coderename('000001') => 'sz000001'
23
- # coderename({'code': '000001', 'name': '平安银行'}) => {'code': 'sz000001', 'name': '平安银行'}
24
- # 反向:
25
- # coderename('sz000001', restore=True) => '000001'
26
- # coderename({'code': 'sz000001', 'name': '平安银行'}) => {'code': '000001', 'name': '平安银行'}
27
-
28
- # 判断参数是否正确
29
- match True:
30
- case True if not utils.isTrue(target, (str, dict)):
31
- logger.error("argument error: target")
32
- return None
33
- case _:
34
- pass
35
-
36
- try:
37
-
38
- # 初始化
39
- code_object: dict = {}
40
- code_name: str | dict = ""
41
-
42
- # 判断 target 是 string 还是 dictionary
43
- if isinstance(target, str) and utils.isTrue(target, str):
44
- code_name = target
45
- elif isinstance(target, dict) and utils.isTrue(target, dict):
46
- code_object = deepcopy(target)
47
- code_name = str(deepcopy(target["code"]))
48
- else:
49
- return None
50
-
51
- # 是否还原
52
- if utils.isTrue(restore, bool):
53
- if len(code_name) == 8 and re.match(r"^(sz|sh)", code_name):
54
- code_name = deepcopy(code_name[2:8])
55
- else:
56
- return None
57
- else:
58
- if code_name[0:2] == "00":
59
- code_name = f"sz{code_name}"
60
- elif code_name[0:2] == "60":
61
- code_name = f"sh{code_name}"
62
- else:
63
- return None
64
-
65
- # 返回结果
66
- if utils.isTrue(target, str):
67
- return code_name
68
-
69
- if utils.isTrue(target, dict):
70
- code_object["code"] = code_name
71
- return code_object
72
-
73
- return None
74
-
75
- except Exception as e:
76
- logger.exception(e)
77
- return None
78
-
79
-
80
- # --------------------------------------------------------------------------------------------------
81
-
82
-
83
- def kdj_vector(
84
- df: DataFrame,
85
- kdj_options: tuple[int, int, int] = (9, 3, 3)
86
- ) -> DataFrame | None:
87
- """KDJ计算器"""
88
-
89
- # 计算周期:Calculation Period, 也可使用 Lookback Period 表示回溯周期, 指用于计算指标值的时间周期.
90
- # 移动平均周期: Smoothing Period 或 Moving Average Period, 指对指标进行平滑处理时采用的周期.
91
- # 同花顺默认参数: 9 3 3
92
- # https://www.daimajiaoliu.com/daima/4ed4ffa26100400
93
- # 说明: KDJ 指标的中文名称又叫随机指标, 融合了动量观念、强弱指标和移动平均线的一些优点, 能够比较迅速、快捷、直观地研判行情, 被广泛用于股市的中短期趋势分析.
94
- # 有采用 ewm 使用 com=2 的, 但是如果使用 com=2 在默认值的情况下KDJ值是正确的.
95
- # 但是非默认值, 比如调整参数, 尝试慢速 KDJ 时就不对了, 最终采用 alpha = 1/m 的情况, 对比同花顺数据, 是正确的.
96
-
97
- # 检查参数
98
- if isinstance(df, DataFrame) and df.empty:
99
- logger.error("argument error: df")
100
- return None
101
-
102
- if not utils.check_arguments([(kdj_options, tuple, "kdj_options")]):
103
- return None
104
-
105
- if not all(utils.isTrue(item, int) for item in kdj_options):
106
- logger.error("argument error: kdj_options")
107
- return None
108
-
109
- try:
110
- low_list = df['low'].rolling(kdj_options[0]).min()
111
- high_list = df['high'].rolling(kdj_options[0]).max()
112
- rsv = (df['close'] - low_list) / (high_list - low_list) * 100
113
- df['K'] = rsv.ewm(alpha=1 / kdj_options[1], adjust=False).mean()
114
- df['D'] = df['K'].ewm(alpha=1 / kdj_options[2], adjust=False).mean()
115
- df['J'] = (3 * df['K']) - (2 * df['D'])
116
- return df
117
- except Exception as e:
118
- logger.exception(e)
119
- return None
120
-
121
-
122
- # --------------------------------------------------------------------------------------------------
123
-
124
-
125
- def data_vector(
126
- df: DataFrame,
127
- macd_options: tuple[int, int, int] = (12, 26, 9),
128
- kdj_options: tuple[int, int, int] = (9, 3, 3)
129
- ) -> DataFrame | None:
130
- """数据运算"""
131
-
132
- # 检查参数
133
- if isinstance(df, DataFrame) and df.empty:
134
- logger.error("argument error: df")
135
- return None
136
-
137
- if not utils.check_arguments([(macd_options, tuple, "macd_options"), (kdj_options, tuple, "kdj_options")]):
138
- return None
139
-
140
- if not all(utils.isTrue(item, int) for item in macd_options):
141
- logger.error("argument error: macd_options")
142
- return None
143
-
144
- if not all(utils.isTrue(item, int) for item in kdj_options):
145
- logger.error("argument error: kdj_options")
146
- return None
147
-
148
- try:
149
-
150
- # ------------------------------------------------------------------------------------------
151
-
152
- # 计算均线: 3,7日均线
153
- # pylint: disable=E1101
154
- # df['SMA03'] = ta.SMA(df['close'], timeperiod=3) # type: ignore
155
- # df['SMA07'] = ta.SMA(df['close'], timeperiod=7) # type: ignore
156
-
157
- # 3,7日均线金叉: 0 无, 1 金叉, 2 死叉
158
- # df['SMA37_X'] = 0
159
- # sma37_position = df['SMA03'] > df['SMA07']
160
- # df.loc[sma37_position[(sma37_position is True) & (sma37_position.shift() is False)].index, 'SMA37_X'] = 1 # type: ignore
161
- # df.loc[sma37_position[(sma37_position is False) & (sma37_position.shift() is True)].index, 'SMA37_X'] = 2 # type: ignore
162
-
163
- # 计算均线: 20,25日均线
164
- # df['SMA20'] = ta.SMA(df['close'], timeperiod=20) # type: ignore
165
- # df['SMA25'] = ta.SMA(df['close'], timeperiod=25) # type: ignore
166
-
167
- # 20,25日均线金叉: 0 无, 1 金叉, 2 死叉
168
- # df['SMA225_X'] = 0
169
- # sma225_position = df['SMA20'] > df['SMA25']
170
- # df.loc[sma225_position[(sma225_position is True) & (sma225_position.shift() is False)].index, 'SMA225_X'] = 1 # type: ignore
171
- # df.loc[sma225_position[(sma225_position is False) & (sma225_position.shift() is True)].index, 'SMA225_X'] = 2 # type: ignore
172
-
173
- # ------------------------------------------------------------------------------------------
174
-
175
- # 计算 MACD: 默认参数 12 26 9
176
- macd_dif, macd_dea, macd_bar = ta.MACD( # type: ignore
177
- df['close'].values,
178
- fastperiod=macd_options[0],
179
- slowperiod=macd_options[1],
180
- signalperiod=macd_options[2]
181
- )
182
-
183
- macd_dif[np.isnan(macd_dif)], macd_dea[np.isnan(macd_dea)], macd_bar[np.isnan(macd_bar)] = 0, 0, 0
184
-
185
- # https://www.bilibili.com/read/cv10185856
186
- df['MACD'] = 2 * (macd_dif - macd_dea)
187
- df['MACD_DIF'] = macd_dif
188
- df['MACD_DEA'] = macd_dea
189
-
190
- # 初始化 MACD_X 列(0 无, 1 金叉, 2 死叉)
191
- df['MACD_X'] = 0
192
-
193
- # 计算 MACD 条件
194
- macd_position = df['MACD_DIF'] > df['MACD_DEA']
195
-
196
- # 设置 MACD_X = 1: 从 False 变为 True 的位置
197
- df.loc[macd_position & ~macd_position.shift(fill_value=False), 'MACD_X'] = 1
198
-
199
- # 设置 MACD_X = 2: 从 True 变为 False 的位置
200
- df.loc[~macd_position & macd_position.shift(fill_value=False), 'MACD_X'] = 2
201
-
202
- # 将浮点数限制为小数点后两位
203
- df['MACD'] = df['MACD'].round(2)
204
- df['MACD_DIF'] = df['MACD_DIF'].round(2)
205
- df['MACD_DEA'] = df['MACD_DEA'].round(2)
206
-
207
- # ------------------------------------------------------------------------------------------
208
-
209
- # # 计算 KDJ: : 默认参数 9 3 3
210
- kdj_data = kdj_vector(df, kdj_options)
211
-
212
- if kdj_data is not None:
213
-
214
- # KDJ 数据
215
- df['K'] = kdj_data['K'].values
216
- df['D'] = kdj_data['D'].values
217
- df['J'] = kdj_data['J'].values
218
-
219
- # 初始化 KDJ_X 列(0 无, 1 金叉, 2 死叉)
220
- df['KDJ_X'] = 0
221
-
222
- # 计算 MACD 条件
223
- kdj_position = df['J'] > df['D']
224
-
225
- # 设置 KDJ_X = 1: 从 False 变为 True 的位置
226
- df.loc[kdj_position & ~kdj_position.shift(fill_value=False), 'KDJ_X'] = 1
227
-
228
- # 设置 KDJ_X = 2: 从 True 变为 False 的位置
229
- df.loc[~kdj_position & kdj_position.shift(fill_value=False), 'KDJ_X'] = 2
230
-
231
- # 将浮点数限制为小数点后两位
232
- df['K'] = df['K'].round(2)
233
- df['D'] = df['D'].round(2)
234
- df['J'] = df['J'].round(2)
235
-
236
- # ------------------------------------------------------------------------------------------
237
-
238
- return df
239
-
240
- except Exception as e:
241
- logger.exception(e)
242
- return None
243
-
244
-
245
- # --------------------------------------------------------------------------------------------------
246
-
247
-
248
- def get_code_name_from_akshare() -> DataFrame | None:
249
- """获取股票代码和名称"""
250
- info = "获取股票代码和名称"
251
- try:
252
- logger.info(f"{info} ......")
253
- df: DataFrame = ak.stock_info_a_code_name()
254
- if df.empty:
255
- logger.error(f"{info} [失败]")
256
- return None
257
- # 排除 ST、证券和银行
258
- # https://towardsdatascience.com/8-ways-to-filter-pandas-dataframes-d34ba585c1b8
259
- df = df[df.code.str.contains("^00|^60") & ~df.name.str.contains("ST|证券|银行")]
260
- logger.success(f"{info} [成功]")
261
- return df
262
- except Exception as e:
263
- logger.error(f"{info} [失败]")
264
- logger.exception(e)
265
- return None
266
-
267
-
268
- # --------------------------------------------------------------------------------------------------
269
-
270
-
271
- def get_stock_data_from_akshare(
272
- code: str,
273
- adjust: str = "qfq",
274
- period: str = "daily",
275
- start_date: str = "19700101",
276
- end_date: str = "20500101",
277
- timeout: float = 10
278
- ) -> DataFrame | None:
279
- """从 akshare 获取股票数据"""
280
- info = f"获取股票数据: {code}"
281
- try:
282
- logger.info(f"{info} ......")
283
- # https://akshare.akfamily.xyz/data/stock/stock.html#id22
284
- df: DataFrame = ak.stock_zh_a_hist(symbol=code, adjust=adjust, period=period, start_date=start_date, end_date=end_date, timeout=timeout)
285
- df = df.rename(columns={
286
- "日期": "date",
287
- "开盘": "open",
288
- "收盘": "close",
289
- "最高": "high",
290
- "最低": "low",
291
- "成交量": "volume"
292
- })
293
- logger.success(f"{info} [成功]")
294
- return df[['date', 'open', 'close', 'high', 'low', 'volume']].copy()
295
- except Exception as e:
296
- logger.error(f"{info} [失败]")
297
- logger.exception(e)
298
- return None
299
-
300
-
301
- # --------------------------------------------------------------------------------------------------
302
-
303
-
304
- def save_data_to_database(engine: Engine, code: str, latest: bool = False) -> bool:
305
- """保存股票所有数据到数据库"""
306
-
307
- # 默认将所有数据保存到数据库中的表里
308
- # 如果 latest 为 True, 插入最新的数据到数据库中的表里
309
- # 即: 将最后一条数据插入到数据库中的表里
310
-
311
- info: str = "保存股票所有数据到数据库"
312
-
313
- if utils.isTrue(latest, bool):
314
- info = "保存股票最新数据到数据库"
315
-
316
- try:
317
-
318
- logger.info(f"{info} ......")
319
-
320
- # 代码名称转换
321
- name = coderename(code)
322
-
323
- if not isinstance(name, str):
324
- logger.error(f"{info} [代码名称转换错误]")
325
- return False
326
-
327
- # 获取数据
328
- df: DataFrame | None = get_stock_data_from_akshare(code)
329
-
330
- if df is None:
331
- logger.error(f"{info} [获取数据错误]")
332
- return False
333
-
334
- # 计算数据
335
- df: DataFrame | None = data_vector(df)
336
-
337
- if df is None:
338
- logger.error(f"{info} [计算数据错误]")
339
- return False
340
-
341
- # 保存到数据库
342
- if utils.isTrue(latest, bool):
343
- df = df.tail(1)
344
- df.to_sql(name=name, con=engine, if_exists="append", index=False)
345
- else:
346
- df.to_sql(name=name, con=engine, if_exists="replace", index=False)
347
-
348
- logger.success(f"{info} [成功]")
349
-
350
- return True
351
-
352
- except Exception as e:
353
- logger.success(f"{info} [失败]")
354
- logger.exception(e)
355
- return False
File without changes