cjdata 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cjdata/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ """cjdata package - local stock data toolkit."""
2
+ from __future__ import annotations
3
+
4
+ import warnings
5
+
6
# xtquant pulls in pkg_resources, which emits a deprecation UserWarning on
# import. Install the filter before the submodule imports below trigger it.
# Note: no module is given, so this filter is process-global and also hides
# the same message when it comes from any other package.
warnings.filterwarnings(
    action="ignore",
    message=".*pkg_resources is deprecated.*",
    category=UserWarning,
)
8
+
9
+ from .builder import CJDataBuilder
10
+ from .local_data import LocalData, TrendType, CodeFormat
11
+
12
# Public API: the builder plus the local-data reader types imported above.
__all__ = ["CJDataBuilder", "LocalData", "TrendType", "CodeFormat"]
cjdata/__main__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Entry point for python -m cjdata."""
2
+ from __future__ import annotations
3
+
4
+ from .cli import main
5
+
6
if __name__ == "__main__":  # pragma: no cover
    # Use main()'s integer return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
cjdata/baostock_pipeline.py ADDED
@@ -0,0 +1,318 @@
1
+ """Data acquisition routines backed by baostock."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ from contextlib import contextmanager
6
+ from datetime import datetime
7
+ from typing import Iterator, Optional, Sequence, Any
8
+
9
+ import pandas as pd
10
+
11
+ from .db import UpsertSpec, upsert_rows
12
+ from .utils import to_iso_date, to_yyyymmdd
13
+
14
+ try: # pragma: no cover - optional dependency
15
+ from tqdm import tqdm # type: ignore
16
+ except ImportError: # pragma: no cover
17
+ tqdm = None # type: ignore
18
+
19
+ bs: Any
20
+ try: # pragma: no cover - optional dependency
21
+ import baostock as bs # type: ignore
22
+ except ImportError: # pragma: no cover
23
+ bs = None # type: ignore
24
+
25
# Fallback logger for BaostockPipeline instances created without one.
_LOGGER = logging.getLogger(name=__name__)
26
+
27
+
28
def _require_baostock() -> None:
    """Fail fast when the optional ``baostock`` dependency could not be imported.

    Raises:
        RuntimeError: if the module-level ``baostock`` import failed.
    """
    if bs is not None:
        return
    raise RuntimeError("baostock is not installed. Install baostock to use this feature.")  # pragma: no cover
31
+
32
+
33
@contextmanager
def _baostock_session() -> Iterator[Any]:
    """Log in to baostock, yield the ``baostock`` module, and always log out.

    Raises:
        RuntimeError: if baostock is not installed or the login reports a
            non-"0" error code.
    """
    _require_baostock()
    login = bs.login()
    logged_in = login.error_code == "0"
    if not logged_in:  # pragma: no cover
        raise RuntimeError(f"baostock login failed: {login.error_msg}")
    try:
        yield bs
    finally:
        # Runs on normal exit and when the with-body raises.
        bs.logout()
43
+
44
+
45
def _to_baostock_code(code: str) -> str:
    """Convert a ``NUMBER.MARKET`` code (e.g. ``"600000.SH"``) to baostock's ``market.number`` form.

    The market part is lower-cased: ``"600000.SH"`` -> ``"sh.600000"``.

    Raises:
        ValueError: if ``code`` is not two non-empty parts joined by exactly one dot.
    """
    number, sep, market = code.partition(".")
    # Require exactly one dot with non-empty parts. Anything else would either
    # yield a malformed baostock code ("600000." -> "." + number) or fail with
    # an opaque tuple-unpacking error ("600000.SH.X").
    if not sep or not number or not market or "." in market:
        raise ValueError(f"invalid stock code: {code}")
    return f"{market.lower()}.{number}"
50
+
51
+
52
def _from_baostock_code(code: str) -> str:
    """Convert a baostock ``market.number`` code (e.g. ``"sh.600000"``) to ``NUMBER.MARKET``.

    The market part is upper-cased: ``"sh.600000"`` -> ``"600000.SH"``. This
    reverses the conversion done by ``_to_baostock_code``.

    Raises:
        ValueError: if ``code`` is not two non-empty parts joined by exactly one dot.
    """
    market, sep, number = code.partition(".")
    # Same validation as _to_baostock_code: a clear ValueError instead of an
    # unpacking error or a silently malformed result.
    if not sep or not market or not number or "." in number:
        raise ValueError(f"invalid baostock code: {code}")
    return f"{number}.{market.upper()}"
55
+
56
+
57
class BaostockPipeline:
    """Download baostock daily bars and DuPont data into the local database.

    All writes go through ``upsert_rows`` keyed on the conflict columns
    declared below. Each code is committed as soon as it has been written, so
    an interrupted run keeps the codes already finished. The next run resumes
    them from their latest stored row.
    """

    # Columns stored in ``daily_k_data``. baostock supplies every one of them
    # except "source", which is filled in locally and must stay last (see
    # ``_DAILY_QUERY_FIELDS``).
    _DAILY_COLUMNS = (
        "date",
        "code",
        "open",
        "high",
        "low",
        "close",
        "preclose",
        "volume",
        "amount",
        "adjustflag",
        "turn",
        "tradestatus",
        "pctChg",
        "peTTM",
        "pbMRQ",
        "psTTM",
        "pcfNcfTTM",
        "isST",
        "source",
    )
    # Upsert conflict target for daily bars.
    _DAILY_CONFLICT_COLUMNS = ("date", "code", "adjustflag")
    # Columns refreshed when a bar already exists: everything except the key.
    _DAILY_UPDATE_COLUMNS = (
        "open",
        "high",
        "low",
        "close",
        "preclose",
        "volume",
        "amount",
        "turn",
        "tradestatus",
        "pctChg",
        "peTTM",
        "pbMRQ",
        "psTTM",
        "pcfNcfTTM",
        "isST",
        "source",
    )
    # Parsed with pd.to_numeric(errors="coerce"); unparseable values become NaN.
    _DAILY_NUMERIC_COLUMNS = (
        "open",
        "high",
        "low",
        "close",
        "preclose",
        "volume",
        "amount",
        "turn",
        "tradestatus",
        "pctChg",
        "peTTM",
        "pbMRQ",
        "psTTM",
        "pcfNcfTTM",
        "isST",
    )
    # Comma-separated field list requested from baostock (all but "source").
    _DAILY_QUERY_FIELDS = ",".join(_DAILY_COLUMNS[:-1])

    # Upsert conflict target for DuPont rows.
    _DUPONT_CONFLICT_COLUMNS = ("code", "statDate")
    # DuPont ratio columns parsed with pd.to_numeric(errors="coerce").
    _DUPONT_NUMERIC_COLUMNS = (
        "dupontROE",
        "dupontAssetStoEquity",
        "dupontAssetTurn",
        "dupontPnitoni",
        "dupontNitogr",
        "dupontTaxBurden",
        "dupontIntburden",
        "dupontEbittogr",
    )

    def __init__(self, conn, logger: Optional[logging.Logger] = None) -> None:
        """Bind the pipeline to an open database connection.

        Args:
            conn: DB-API connection using ``?`` placeholders. It is expected to
                already contain the ``daily_k_data`` and ``dupont_data`` tables.
            logger: Optional logger; defaults to this module's logger.

        Raises:
            RuntimeError: if baostock is not installed.
        """
        _require_baostock()
        self.conn = conn
        self.logger = logger or _LOGGER

    @staticmethod
    def _progress(codes: Sequence[str], desc: str) -> Any:
        """Wrap ``codes`` in a tqdm progress bar when tqdm is installed, else return it as-is."""
        if tqdm is None:
            return codes
        return tqdm(
            codes,
            desc=desc,
            total=len(codes) if hasattr(codes, "__len__") else None,
            leave=False,
        )

    def download_daily_for_codes(
        self,
        codes: Sequence[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        adjustflag: str = "1",
    ) -> int:
        """Download daily K-line bars for ``codes`` and upsert them into ``daily_k_data``.

        Each code resumes from the day after its latest stored bar for this
        ``adjustflag``, unless ``start_date`` is later.

        Args:
            codes: Stock codes in ``NUMBER.MARKET`` form, e.g. ``"600000.SH"``.
            start_date: Earliest date to fetch; defaults to ``"20080101"``.
            end_date: Last date to fetch. ``None`` sends an empty end date to
                baostock, and codes already stored through today are skipped.
            adjustflag: baostock adjustment flag, passed through to the query
                and stored with each row.

        Returns:
            Sum of the values returned by ``upsert_rows`` (presumably rows written).
        """
        start = to_yyyymmdd(start_date or "20080101")
        end = to_yyyymmdd(end_date) if end_date else None
        # The spec does not depend on the code, so build it once.
        spec = UpsertSpec(
            table="daily_k_data",
            columns=self._DAILY_COLUMNS,
            conflict_columns=self._DAILY_CONFLICT_COLUMNS,
        )
        total = 0
        with _baostock_session() as session:
            for code in self._progress(codes, "baostock daily"):
                rows = self._download_single_daily(session, code, start, end, adjustflag)
                if not rows:
                    continue
                total += upsert_rows(
                    self.conn,
                    spec,
                    rows,
                    update_columns=self._DAILY_UPDATE_COLUMNS,
                )
                # Resumption is tracked per code (MAX(date)), so committing each
                # code keeps finished work if a later code fails or the run is
                # interrupted.
                self.conn.commit()
            self.conn.commit()
        return total

    def download_dupont_for_codes(
        self,
        codes: Sequence[str],
        start_year: int = 2007,
        start_quarter: int = 1,
    ) -> int:
        """Download quarterly DuPont data for ``codes`` into ``dupont_data``.

        For a code that already has rows, downloading restarts at the quarter
        after its latest stored ``statDate``. Otherwise it starts at
        ``start_year``/``start_quarter``.

        Returns:
            Sum of the values returned by ``upsert_rows``.
        """
        total = 0
        with _baostock_session() as session:
            for code in self._progress(codes, "baostock dupont"):
                total += self._download_single_dupont(session, code, start_year, start_quarter)
                # Commit per code for the same resumability reason as daily data.
                self.conn.commit()
        return total

    def _download_single_daily(
        self,
        session: Any,
        stock_code: str,
        start_date: str,
        end_date: Optional[str],
        adjustflag: str,
    ) -> list[tuple[Any, ...]]:
        """Fetch daily bars for one code as rows ordered like ``_DAILY_COLUMNS``.

        Fetching starts at the later of ``start_date`` and the day after the
        latest stored bar. Returns an empty list in these cases:
        - nothing new can exist yet;
        - the request fails (logged as a warning);
        - baostock returns no rows.
        """
        final_start = max(start_date, self._next_download_date(stock_code, adjustflag))
        # Both sides are YYYYMMDD strings, so string comparison is date order.
        # When the stored data already reaches the requested end (or today, by
        # the local clock, when no end is given), skip the network round-trip.
        upper_bound = end_date or datetime.now().strftime("%Y%m%d")
        if final_start > upper_bound:
            return []
        start_iso = to_iso_date(final_start)
        end_iso = to_iso_date(end_date) if end_date else ""
        bs_code = _to_baostock_code(stock_code)
        try:
            rs = session.query_history_k_data_plus(
                code=bs_code,
                fields=self._DAILY_QUERY_FIELDS,
                start_date=start_iso,
                end_date=end_iso,
                frequency="d",
                adjustflag=adjustflag,
            )
        except Exception as exc:  # pragma: no cover - network errors
            self.logger.warning("baostock download failed for %s: %s", stock_code, exc)
            return []
        if rs.error_code != "0":  # pragma: no cover
            self.logger.warning("baostock error for %s: %s", stock_code, rs.error_msg)
            return []
        records = []
        while rs.next():
            records.append(rs.get_row_data())
        if not records:
            return []
        frame = pd.DataFrame(records, columns=rs.fields)
        frame["date"] = frame["date"].str.replace("-", "")
        frame["code"] = frame["code"].apply(_from_baostock_code)
        for column in self._DAILY_NUMERIC_COLUMNS:
            if column in frame.columns:
                frame[column] = pd.to_numeric(frame[column], errors="coerce")
        # Fall back to the requested flag when baostock's value is unparseable.
        frame["adjustflag"] = (
            pd.to_numeric(frame["adjustflag"], errors="coerce").fillna(int(adjustflag)).astype(int)
        )
        frame["source"] = "baostock"
        # A list (not a tuple) selects columns; a tuple would be one MultiIndex key.
        frame = frame[list(self._DAILY_COLUMNS)]
        return list(frame.itertuples(index=False, name=None))

    @staticmethod
    def _next_quarter(stat_date: str) -> tuple[int, int]:
        """Return ``(year, quarter)`` for the quarter after the one containing ``stat_date``.

        ``stat_date`` is expected as ``YYYYMMDD``; only the year and month are read.
        """
        year = int(stat_date[:4])
        quarter = (int(stat_date[4:6]) - 1) // 3 + 1
        if quarter == 4:
            return year + 1, 1
        return year, quarter + 1

    def _download_single_dupont(
        self,
        session: Any,
        stock_code: str,
        start_year: int,
        start_quarter: int,
    ) -> int:
        """Download every missing quarter of DuPont data for one code.

        The range runs up to and including the current calendar quarter.
        Returns the sum of ``upsert_rows`` results.
        """
        bs_code = _to_baostock_code(stock_code)
        last_date = self._latest_dupont_quarter(stock_code)
        if last_date:
            start_year, start_quarter = self._next_quarter(last_date)
        current = pd.Timestamp.now()
        current_quarter = (current.month - 1) // 3 + 1
        total = 0
        for year in range(start_year, current.year + 1):
            first_quarter = start_quarter if year == start_year else 1
            last_quarter = current_quarter if year == current.year else 4
            for quarter in range(first_quarter, last_quarter + 1):
                total += self._download_dupont_quarter(session, stock_code, bs_code, year, quarter)
        return total

    def _download_dupont_quarter(
        self,
        session: Any,
        stock_code: str,
        bs_code: str,
        year: int,
        quarter: int,
    ) -> int:
        """Fetch one quarter of DuPont data for a code and upsert it.

        Returns the ``upsert_rows`` result, or 0 when the request fails or
        yields no rows.
        """
        try:
            rs = session.query_dupont_data(code=bs_code, year=year, quarter=quarter)
        except Exception as exc:  # pragma: no cover
            self.logger.warning("dupont download failed for %s %sQ%s: %s", stock_code, year, quarter, exc)
            return 0
        if rs.error_code != "0":  # pragma: no cover
            # "no records" is an expected outcome (e.g. a quarter not yet reported).
            if "no records" not in rs.error_msg.lower():
                self.logger.warning("dupont error for %s %sQ%s: %s", stock_code, year, quarter, rs.error_msg)
            return 0
        records: list[dict[str, Any]] = []
        while rs.next():
            records.append(dict(zip(rs.fields, rs.get_row_data())))
        if not records:
            return 0
        frame = pd.DataFrame(records)
        frame["code"] = frame["code"].apply(_from_baostock_code)
        frame["pubDate"] = frame["pubDate"].str.replace("-", "")
        frame["statDate"] = frame["statDate"].str.replace("-", "")
        for column in self._DUPONT_NUMERIC_COLUMNS:
            if column in frame.columns:
                frame[column] = pd.to_numeric(frame[column], errors="coerce")
        columns = tuple(frame.columns)
        spec = UpsertSpec(
            table="dupont_data",
            columns=columns,
            conflict_columns=self._DUPONT_CONFLICT_COLUMNS,
        )
        # The conflict key never needs rewriting, so only the payload is updated.
        update_columns = tuple(c for c in columns if c not in self._DUPONT_CONFLICT_COLUMNS)
        return upsert_rows(
            self.conn,
            spec,
            list(frame.itertuples(index=False, name=None)),
            update_columns=update_columns,
        )

    def _next_download_date(self, code: str, adjustflag: str) -> str:
        """Return the YYYYMMDD day after the latest stored bar for ``code``/``adjustflag``.

        Returns ``"20080101"`` when the code has no stored bars for that flag.
        """
        cursor = self.conn.execute(
            "SELECT MAX(date) FROM daily_k_data WHERE code=? AND adjustflag=?",
            (code, int(adjustflag)),
        )
        latest = cursor.fetchone()[0]
        if not latest:
            return "20080101"
        latest_dt = datetime.strptime(latest, "%Y%m%d")
        next_dt = latest_dt + pd.Timedelta(days=1)
        return next_dt.strftime("%Y%m%d")

    def _latest_dupont_quarter(self, code: str) -> Optional[str]:
        """Return the largest stored ``statDate`` for ``code``, or ``None`` if there is none."""
        cursor = self.conn.execute(
            "SELECT MAX(statDate) FROM dupont_data WHERE code=?",
            (code,),
        )
        value = cursor.fetchone()[0]
        return value if value else None
cjdata/builder.py ADDED
@@ -0,0 +1,94 @@
1
+ """Orchestrates local database construction and updates."""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ from typing import Optional, Sequence
6
+
7
+ from . import db
8
+ from .baostock_pipeline import BaostockPipeline
9
+ from .xtquant_pipeline import XtQuantPipeline
10
+
11
# Module logger; CJDataBuilder uses it unless a logger_override is supplied.
logger = logging.getLogger(name=__name__)
12
+
13
+
14
class CJDataBuilder:
    """Build (``bootstrap``) or refresh (``update``) the local SQLite stock database.

    Each data source runs as an independent stage. A stage that raises
    ``RuntimeError`` (baostock does so when it is not installed) is logged and
    skipped, so the remaining stages still run.
    """

    def __init__(self, db_path: str, logger_override: Optional[logging.Logger] = None) -> None:
        """Remember the database path and the logger to use.

        Args:
            db_path: Path of the SQLite database file.
            logger_override: Logger to use instead of this module's logger.
        """
        self.db_path = db_path
        self.logger = logger_override or logger

    def bootstrap(
        self,
        start_date: str = "20080101",
        end_date: Optional[str] = None,
        include_dupont: bool = False,
        skip_xtquant: bool = False,
        skip_baostock: bool = False,
    ) -> None:
        """Ensure the schema exists, then run a full download from ``start_date``.

        The xtquant stage refreshes the trading calendar, sector membership,
        stock basics and ETF daily bars. The baostock stage downloads A-share
        and index daily bars, plus DuPont data when ``include_dupont`` is set.
        """
        with db.connection(self.db_path) as conn:
            db.ensure_schema(conn)
            if not skip_xtquant:
                self._run_stage(
                    "Skip xtquant stage: %s",
                    self._bootstrap_xtquant,
                    conn,
                    start_date,
                    end_date,
                )
            if not skip_baostock:
                self._run_stage(
                    "Skip baostock stage: %s",
                    self._bootstrap_baostock,
                    conn,
                    start_date,
                    end_date,
                    include_dupont,
                )

    def update(
        self,
        end_date: Optional[str] = None,
        skip_xtquant: bool = False,
        skip_baostock: bool = False,
    ) -> None:
        """Ensure the schema exists, then incrementally refresh ETF and A-share/index daily bars."""
        with db.connection(self.db_path) as conn:
            db.ensure_schema(conn)
            if not skip_xtquant:
                self._run_stage("Skip xtquant update: %s", self._update_xtquant, conn, end_date)
            if not skip_baostock:
                self._run_stage("Skip baostock update: %s", self._update_baostock, conn, end_date)

    def _run_stage(self, skip_message, stage, *args) -> None:
        """Call ``stage(*args)``; on ``RuntimeError`` log ``skip_message`` % exc and continue."""
        try:
            stage(*args)
        except RuntimeError as exc:
            self.logger.warning(skip_message, exc)

    def _bootstrap_xtquant(self, conn, start_date, end_date) -> None:
        """Full xtquant stage: reference data first, then ETF daily bars."""
        xt_logger = self.logger.getChild("xtquant")
        pipeline = XtQuantPipeline(conn, xt_logger)
        pipeline.download_trading_calendar()
        pipeline.update_sector_membership()
        pipeline.update_stock_basic()
        etf_codes = pipeline.default_etf_codes()
        if etf_codes:
            xt_logger.info("Downloading ETF daily data for %s codes", len(etf_codes))
            pipeline.download_daily_for_codes(etf_codes, start_date=start_date, end_date=end_date)

    def _update_xtquant(self, conn, end_date) -> None:
        """Incremental xtquant stage: ETF daily bars only."""
        xt_logger = self.logger.getChild("xtquant")
        pipeline = XtQuantPipeline(conn, xt_logger)
        etf_codes = pipeline.default_etf_codes()
        if etf_codes:
            xt_logger.info("Updating ETF daily data for %s codes", len(etf_codes))
            pipeline.download_daily_for_codes(etf_codes, end_date=end_date)

    def _bootstrap_baostock(self, conn, start_date, end_date, include_dupont) -> None:
        """Full baostock stage: A-share and index daily bars, optionally A-share DuPont data."""
        bs_logger = self.logger.getChild("baostock")
        pipeline = BaostockPipeline(conn, bs_logger)
        daily_codes = self._sector_codes(conn, ("沪深A股", "沪深指数"))
        if daily_codes:
            bs_logger.info("Downloading BA daily data for %s codes", len(daily_codes))
            pipeline.download_daily_for_codes(daily_codes, start_date=start_date, end_date=end_date)
        if not include_dupont:
            return
        dupont_codes = self._sector_codes(conn, ("沪深A股",))
        if dupont_codes:
            bs_logger.info("Downloading DuPont data for %s codes", len(dupont_codes))
            pipeline.download_dupont_for_codes(dupont_codes)

    def _update_baostock(self, conn, end_date) -> None:
        """Incremental baostock stage: A-share and index daily bars."""
        bs_logger = self.logger.getChild("baostock")
        pipeline = BaostockPipeline(conn, bs_logger)
        daily_codes = self._sector_codes(conn, ("沪深A股", "沪深指数"))
        if daily_codes:
            bs_logger.info("Updating BA daily data for %s codes", len(daily_codes))
            pipeline.download_daily_for_codes(daily_codes, end_date=end_date)

    def _sector_codes(self, conn, sectors: Sequence[str]) -> list[str]:
        """Return the distinct ``stock_code`` values in ``sector_stocks`` for any of ``sectors``."""
        params = tuple(sectors)
        placeholders = ",".join(["?"] * len(params))
        query = f"SELECT DISTINCT stock_code FROM sector_stocks WHERE sector_name IN ({placeholders})"
        return [stock_code for (stock_code,) in conn.execute(query, params).fetchall()]
cjdata/cli.py ADDED
@@ -0,0 +1,74 @@
1
+ """Command line interface for the cjdata package."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import logging
6
+ from typing import Sequence, Optional
7
+
8
+ from .builder import CJDataBuilder
9
+
10
# SQLite file used by both subcommands when --db is not given.
DEFAULT_DB: str = "stock_data_hfq.db"
11
+
12
+
13
def _resolve_log_level(level: str) -> int:
    """Map a level name (any case) or a decimal string to a numeric logging level.

    Unknown names fall back to ``logging.INFO`` so a typo never aborts the CLI.
    """
    text = level.strip()
    if text.isdecimal():
        return int(text)
    resolved = getattr(logging, text.upper(), None)
    # ``logging`` also exposes non-level uppercase attributes (BASIC_FORMAT is
    # a format string); passing one to basicConfig(level=...) raises, so only
    # genuine integer level constants are accepted.
    if isinstance(resolved, int) and not isinstance(resolved, bool):
        return resolved
    return logging.INFO


def _configure_logging(level: str) -> None:
    """Configure root logging for the CLI at the level named by ``level``.

    ``level`` may be a name such as ``"debug"`` or a number such as ``"15"``.
    Unrecognised values fall back to INFO.
    """
    logging.basicConfig(
        level=_resolve_log_level(level),
        format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    )
18
+
19
+
20
def build_parser() -> argparse.ArgumentParser:
    """Create the ``cjdata`` argument parser with ``download`` and ``update`` subcommands."""

    def add_db_option(sub: argparse.ArgumentParser) -> None:
        sub.add_argument("--db", default=DEFAULT_DB, help="SQLite database path")

    def add_end_date_option(sub: argparse.ArgumentParser) -> None:
        sub.add_argument("--end-date", help="End date in YYYYMMDD")

    def add_skip_flags(sub: argparse.ArgumentParser) -> None:
        sub.add_argument("--skip-xtquant", action="store_true", help="Skip xtquant stage")
        sub.add_argument("--skip-baostock", action="store_true", help="Skip baostock stage")

    parser = argparse.ArgumentParser(prog="cjdata", description="Local stock data toolkit")
    parser.add_argument("--log-level", default="INFO", help="Logging level (default: INFO)")
    commands = parser.add_subparsers(dest="command")

    # Option order is kept stable so --help output lists them the same way.
    full = commands.add_parser("download", help="Perform a full data download")
    add_db_option(full)
    full.add_argument("--start-date", default="20080101", help="Start date in YYYYMMDD")
    add_end_date_option(full)
    full.add_argument("--include-dupont", action="store_true", help="Download DuPont data")
    add_skip_flags(full)

    incremental = commands.add_parser("update", help="Incrementally update existing data")
    add_db_option(incremental)
    add_end_date_option(incremental)
    add_skip_flags(incremental)

    return parser
41
+
42
+
43
def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run the ``cjdata`` command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        1 when no subcommand was given (help is printed), otherwise 0.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    command = args.command
    if not command:
        parser.print_help()
        return 1

    _configure_logging(args.log_level)
    builder = CJDataBuilder(args.db)

    if command == "update":
        builder.update(
            end_date=args.end_date,
            skip_xtquant=args.skip_xtquant,
            skip_baostock=args.skip_baostock,
        )
    elif command == "download":
        builder.bootstrap(
            start_date=args.start_date,
            end_date=args.end_date,
            include_dupont=args.include_dupont,
            skip_xtquant=args.skip_xtquant,
            skip_baostock=args.skip_baostock,
        )
    else:
        # Defensive: argparse already rejects unknown subcommands.
        parser.error(f"Unknown command: {args.command}")
    return 0
71
+
72
+
73
if __name__ == "__main__":  # pragma: no cover
    # Use main()'s integer return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)