cjdata 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,315 @@
1
+ """Data acquisition routines backed by xtquant."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ from datetime import datetime, timedelta
6
+ from typing import Optional, Sequence, Any, TYPE_CHECKING
7
+
8
+ import pandas as pd
9
+ from pandas.api.types import is_datetime64_any_dtype
10
+
11
+ from .db import UpsertSpec, insert_ignore, upsert_rows
12
+ from .utils import to_yyyymmdd
13
+
14
+ try: # pragma: no cover - optional dependency
15
+ from tqdm import tqdm # type: ignore
16
+ except ImportError: # pragma: no cover
17
+ tqdm = None # type: ignore
18
+
19
+ xtdata: Any
20
+ if TYPE_CHECKING: # pragma: no cover
21
+ from xtquant import xtdata as _xtdata
22
+
23
+ try: # pragma: no cover - optional dependency
24
+ from xtquant import xtdata as _xtdata # type: ignore
25
+ xtdata = _xtdata
26
+ except ImportError: # pragma: no cover
27
+ xtdata = None
28
+
29
+ _LOGGER = logging.getLogger(__name__)
30
+
31
+
32
def _require_xtquant() -> None:
    """Ensure the optional ``xtquant`` dependency was importable.

    Raises:
        RuntimeError: if the module-level ``xtdata`` handle is ``None``,
            i.e. the guarded import at the top of this module failed.
    """
    if xtdata is not None:
        return
    raise RuntimeError("xtquant is not installed. Install xtquant to use this feature.")  # pragma: no cover
35
+
36
+
37
+ def _to_epoch_ms(value: Any) -> int:
38
+ if isinstance(value, (int, float)):
39
+ return int(value)
40
+ if isinstance(value, datetime):
41
+ return int(value.timestamp() * 1000)
42
+ return int(pd.Timestamp(value).timestamp() * 1000)
43
+
44
+
45
class XtQuantPipeline:
    """Pull reference and daily bar data from xtquant into a database.

    The pipeline talks to the module-level ``xtdata`` handle and persists
    results through ``insert_ignore`` / ``upsert_rows`` on the connection it
    was constructed with. The connection only needs ``execute()`` (returning
    a cursor with ``fetchone()`` / ``fetchall()``) and ``commit()``.
    """

    # Column order of the rows written to ``daily_k_data``.
    _DAILY_COLUMNS = (
        "date",
        "code",
        "open",
        "high",
        "low",
        "close",
        "preclose",
        "volume",
        "amount",
        "adjustflag",
        "turn",
        "tradestatus",
        "pctChg",
        "peTTM",
        "pbMRQ",
        "psTTM",
        "pcfNcfTTM",
        "isST",
        "source",
    )
    # Key that identifies an existing ``daily_k_data`` row.
    _DAILY_CONFLICT_COLUMNS = ("date", "code", "adjustflag")
    # Every daily column except the conflict key, in the same order.
    _DAILY_UPDATE_COLUMNS = (
        "open",
        "high",
        "low",
        "close",
        "preclose",
        "volume",
        "amount",
        "turn",
        "tradestatus",
        "pctChg",
        "peTTM",
        "pbMRQ",
        "psTTM",
        "pcfNcfTTM",
        "isST",
        "source",
    )
    # Fields requested from ``xtdata.get_market_data_ex``.
    _DAILY_FIELDS = (
        "time",
        "open",
        "high",
        "low",
        "close",
        "volume",
        "amount",
        "preClose",
        "turn",
        "tradeStatus",
        "pctChg",
        "peTTM",
        "pbMRQ",
        "psTTM",
        "pcfNcfTTM",
        "isST",
    )
    # Column order of the rows written to ``stock_basic``; the first entry is
    # the conflict key, the remainder are refreshed on conflict.
    _STOCK_BASIC_COLUMNS = (
        "stock_code",
        "stock_name",
        "market",
        "board",
        "listed_date",
        "total_volume",
        "float_volume",
    )

    def __init__(self, conn, logger: Optional[logging.Logger] = None) -> None:
        """Bind the pipeline to *conn*.

        Args:
            conn: database connection used for every read and write.
            logger: logger for warnings; defaults to this module's logger.

        Raises:
            RuntimeError: if xtquant could not be imported.
        """
        _require_xtquant()
        self.conn = conn
        self.logger = logger or _LOGGER

    def download_trading_calendar(self, markets: Sequence[str] = ("SH", "SZ")) -> int:
        """Store the trading dates of each market in ``trading_days``.

        Markets for which xtquant returns nothing are skipped. Returns the
        sum of the counts reported by ``insert_ignore``.
        """
        total = 0
        for market in markets:
            dates = xtdata.get_trading_dates(market)
            if not dates:
                continue
            rows = [(market, _to_epoch_ms(day)) for day in dates]
            total += insert_ignore(self.conn, "trading_days", ("market", "trade_date"), rows)
        self.conn.commit()
        return total

    def update_sector_membership(self) -> int:
        """Store every (sector, stock) pair xtquant reports in ``sector_stocks``.

        Returns the sum of the counts reported by ``insert_ignore``.
        """
        total = 0
        # Guard against a missing sector list instead of failing on iteration.
        for sector in xtdata.get_sector_list() or ():
            stocks = xtdata.get_stock_list_in_sector(sector)
            if not stocks:
                # Nothing to insert; also avoids iterating over ``None``.
                continue
            rows = [(sector, code) for code in stocks]
            total += insert_ignore(self.conn, "sector_stocks", ("sector_name", "stock_code"), rows)
        self.conn.commit()
        return total

    def update_stock_basic(self, sectors: Sequence[str] = ("沪深A股", "沪深指数", "沪深基金")) -> int:
        """Refresh ``stock_basic`` for every code belonging to *sectors*.

        Codes are read from ``sector_stocks``; codes for which xtquant
        returns no instrument detail are skipped. Returns the count reported
        by ``upsert_rows``.
        """
        # Materialise once so the placeholder count and the bound parameters
        # are guaranteed to agree.
        sector_names = tuple(sectors)
        placeholders = ",".join("?" for _ in sector_names)
        cursor = self.conn.execute(
            "SELECT DISTINCT stock_code FROM sector_stocks WHERE sector_name IN ({})".format(
                placeholders
            ),
            sector_names,
        )
        codes = [row[0] for row in cursor.fetchall()]

        rows = []
        for code in codes:
            detail = xtdata.get_instrument_detail(code)
            if not detail:
                continue
            rows.append(
                (
                    code,
                    detail.get("InstrumentName"),
                    detail.get("ExchangeID"),
                    self._determine_board(code, detail),
                    detail.get("OpenDate"),
                    detail.get("TotalShares"),
                    detail.get("CirculatingShares"),
                )
            )

        spec = UpsertSpec(
            table="stock_basic",
            columns=self._STOCK_BASIC_COLUMNS,
            conflict_columns=("stock_code",),
        )
        updated = upsert_rows(
            self.conn,
            spec,
            rows,
            update_columns=self._STOCK_BASIC_COLUMNS[1:],
        )
        self.conn.commit()
        return updated

    def download_daily_for_codes(
        self,
        codes: Sequence[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        dividend_type: str = "back_ratio",
    ) -> int:
        """Download daily bars for *codes* and upsert them into ``daily_k_data``.

        Args:
            codes: instrument codes to fetch.
            start_date: earliest date to request; defaults to ``"20080101"``.
            end_date: latest date to request; an empty string is passed to
                xtquant when omitted.
            dividend_type: forwarded to ``xtdata.get_market_data_ex``.

        Returns:
            The sum of the counts reported by ``upsert_rows``.
        """
        total = 0
        start = to_yyyymmdd(start_date or "20080101")
        end = to_yyyymmdd(end_date) if end_date else ""

        if tqdm is not None:
            iterable = tqdm(
                codes,
                desc="xtquant daily",
                total=len(codes) if hasattr(codes, "__len__") else None,
                leave=False,
            )
        else:
            iterable = codes

        # The upsert specification does not depend on the instrument, so it
        # is built once rather than on every iteration.
        spec = UpsertSpec(
            table="daily_k_data",
            columns=self._DAILY_COLUMNS,
            conflict_columns=self._DAILY_CONFLICT_COLUMNS,
        )
        for code in iterable:
            rows = self._download_single(code, start, end, dividend_type)
            if not rows:
                continue
            total += upsert_rows(
                self.conn,
                spec,
                rows,
                update_columns=self._DAILY_UPDATE_COLUMNS,
            )
            # Persist each instrument as soon as it is written so that an
            # interrupted run keeps the work already completed.
            self.conn.commit()
        # Always commit before returning, even when no rows were written.
        self.conn.commit()
        return total

    def default_etf_codes(self) -> list[str]:
        """Return the distinct codes stored under the ``沪深基金`` sector."""
        cursor = self.conn.execute(
            "SELECT DISTINCT stock_code FROM sector_stocks WHERE sector_name = ?",
            ("沪深基金",),
        )
        return [row[0] for row in cursor.fetchall()]

    def _download_single(
        self,
        stock_code: str,
        start_date: str,
        end_date: str,
        dividend_type: str,
    ) -> list[tuple]:
        """Fetch daily bars for one instrument and return them as row tuples.

        The request starts at the day after the latest stored bar (see
        ``_next_download_date``). Any exception raised by xtquant is logged
        as a warning and results in an empty list, so one failing instrument
        does not abort a batch.
        """
        start_dt = self._next_download_date(stock_code, start_date)
        try:
            xtdata.download_history_data(
                stock_code=stock_code,
                period="1d",
                start_time=start_dt,
                end_time=end_date,
                incrementally=True,
            )
            data = xtdata.get_market_data_ex(
                field_list=list(self._DAILY_FIELDS),
                stock_list=[stock_code],
                period="1d",
                start_time=start_dt,
                end_time=end_date,
                dividend_type=dividend_type,
                fill_data=True,
            )
        except Exception as exc:  # pragma: no cover - network error path
            self.logger.warning("xtquant download failed for %s: %s", stock_code, exc)
            return []
        frame = data.get(stock_code) if data else None
        return self._frame_to_rows(frame, stock_code)

    @classmethod
    def _frame_to_rows(cls, frame: Any, stock_code: str) -> list[tuple]:
        """Convert an xtquant bar frame into ``daily_k_data`` row tuples.

        Rows whose date cannot be determined are dropped. Columns missing
        from the frame are filled with ``None``. ``adjustflag`` is always
        ``1`` and ``source`` is always ``"xtquant"``.
        """
        if frame is None or frame.empty:
            return []
        frame = frame.copy()
        stamps = cls._resolve_dates(frame)
        valid = stamps.notna()
        frame = frame[valid].copy()
        stamps = stamps[valid]
        if frame.empty:
            return []
        frame["date"] = stamps.dt.strftime("%Y%m%d")
        frame["code"] = stock_code
        frame = frame.rename(columns={"preClose": "preclose", "tradeStatus": "tradestatus"})
        frame["adjustflag"] = 1
        frame["source"] = "xtquant"
        for column in cls._DAILY_COLUMNS:
            if column not in frame.columns:
                frame[column] = None
        frame = frame[list(cls._DAILY_COLUMNS)]
        return list(frame.itertuples(index=False, name=None))

    @staticmethod
    def _resolve_dates(frame: Any) -> Any:
        """Return a datetime Series (NaT where unknown) aligned with *frame*.

        The frame index is tried first. Only when no index value parses is
        the ``time`` column consulted.
        """
        stamps = frame.index.to_series()
        if not is_datetime64_any_dtype(stamps):
            stamps = pd.to_datetime(stamps, errors="coerce")
        if stamps.isna().all() and "time" in frame.columns:
            raw = frame["time"]
            if pd.api.types.is_numeric_dtype(raw):
                # Without an explicit unit pandas reads integers as
                # nanoseconds, which silently yields 1970 dates instead of
                # NaT; numeric values are therefore parsed as milliseconds
                # straight away.
                # NOTE(review): values are interpreted as UTC; confirm
                # whether xtquant timestamps need a local-time offset.
                stamps = pd.to_datetime(raw, unit="ms", errors="coerce")
            else:
                stamps = pd.to_datetime(raw, errors="coerce")
                if stamps.isna().all():
                    stamps = pd.to_datetime(raw, unit="ms", errors="coerce")
        return stamps

    def _determine_board(self, code: str, detail: dict) -> Optional[str]:
        """Classify *code* into a board name based on its prefix.

        Codes that match no known prefix fall back to the
        ``InstrumentStatus`` entry of *detail* (``None`` when absent).
        """
        if code.startswith("688"):
            return "科创板"
        # ChiNext listings use the 301 prefix in addition to 300.
        if code.startswith(("300", "301")):
            return "创业板"
        if code.startswith("8"):
            return "北交所"
        # NOTE(review): InstrumentStatus is used as the board here; confirm
        # that this is the intended field.
        return detail.get("InstrumentStatus")

    def _next_download_date(self, code: str, start_date: str) -> str:
        """Return the first date (``YYYYMMDD``) that still needs downloading.

        This is the day after the latest stored ``adjustflag=1`` bar for
        *code*, but never earlier than *start_date*. When nothing is stored,
        *start_date* is returned unchanged.
        """
        cursor = self.conn.execute(
            "SELECT MAX(date) FROM daily_k_data WHERE code=? AND adjustflag=1",
            (code,),
        )
        latest = cursor.fetchone()[0]
        if not latest:
            return start_date
        latest_dt = datetime.strptime(latest, "%Y%m%d") + timedelta(days=1)
        fallback = datetime.strptime(start_date, "%Y%m%d")
        return max(latest_dt, fallback).strftime("%Y%m%d")
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: cjdata
3
+ Version: 0.0.2
4
+ Summary: Utilities for building and querying local stock data from xtquant and baostock sources.
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: pandas>=2.0
7
+ Requires-Dist: numpy>=1.23
8
+ Provides-Extra: baostock
9
+ Requires-Dist: baostock>=0.8.8; extra == "baostock"
10
+ Provides-Extra: progress
11
+ Requires-Dist: tqdm>=4.60; extra == "progress"
12
+ Provides-Extra: talib
13
+ Requires-Dist: TA-Lib>=0.4; extra == "talib"
14
+ Provides-Extra: all
15
+ Requires-Dist: baostock>=0.8.8; extra == "all"
16
+ Requires-Dist: tqdm>=4.60; extra == "all"
17
+ Requires-Dist: TA-Lib>=0.4; extra == "all"
18
+ Provides-Extra: dev
19
+ Requires-Dist: pytest>=7.0; extra == "dev"
@@ -0,0 +1,15 @@
1
+ cjdata/__init__.py,sha256=epPQD9pG7z1JfuQG8P6fegjZB6D1EbJRLqh_FuA-zMM,473
2
+ cjdata/__main__.py,sha256=Pl4ZPieuVFtDhheztm31d7o09XEGwBLpf1nNOVBcP2E,182
3
+ cjdata/baostock_pipeline.py,sha256=SOuBjhP6I6wJ0i01VCui7mIIvivpKrhj8lqBs-rJ7Zs,11430
4
+ cjdata/builder.py,sha256=zcRZqXqm9vMERLVEUHz2m6sC-nptF0E-dyxFLfwn1x0,4438
5
+ cjdata/cli.py,sha256=wYTSpIdgzzYZkAmmFW7zsdJjsCPl6gR3ij0tOJd2xBk,2757
6
+ cjdata/db.py,sha256=iv8fGMdvru9msU3TlgprXAF567K3qdkFfn7e8M2z6QY,5019
7
+ cjdata/local_data.py,sha256=IhopkzTYirM4iAYxZv1_d3Gl00RLEOAggZfKDv5w-ls,19723
8
+ cjdata/utils.py,sha256=L0R1W5h98f2JZeRTkHnqRjQDgrTakLnXbOeVoYxZOKw,864
9
+ cjdata/xtquant_pipeline.py,sha256=-W_AAA9VpNwUjd6F6lFLRC0XDXt5AXNlKP4HH6s_TPM,10662
10
+ cjdata/__pycache__/__main__.cpython-312.pyc,sha256=6NfjvGzKVlMk2ONRLXVRQraIPba104-fdTl3swWwiNY,347
11
+ cjdata-0.0.2.dist-info/METADATA,sha256=8_2DoZjUc2LEYib_D8OC1vlfT5pV423nxoicx4dl7lI,675
12
+ cjdata-0.0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
13
+ cjdata-0.0.2.dist-info/entry_points.txt,sha256=87lnCH8cJY_xs2jpPfXF5JA5memlJSdwVc9RFdzTyvs,43
14
+ cjdata-0.0.2.dist-info/top_level.txt,sha256=SMBWvkmEOUn4wQORzUpIJheNRmSgSDCrwzd7112afzc,7
15
+ cjdata-0.0.2.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ cjdata = cjdata.cli:main
@@ -0,0 +1 @@
1
+ cjdata