cjdata 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cjdata/db.py ADDED
@@ -0,0 +1,190 @@
1
+ """SQLite schema helpers for the cjdata package."""
2
+ from __future__ import annotations
3
+
4
+ import sqlite3
5
+ from contextlib import contextmanager
6
+ from dataclasses import dataclass
7
+ from typing import Iterable, Iterator, Sequence, Any, Optional
8
+
9
# Idempotent DDL executed in order by ensure_schema(). Every CREATE uses
# IF NOT EXISTS, so the tuple can safely be re-applied to an existing DB.
SCHEMA_STATEMENTS: tuple[str, ...] = (
    # Enable FK enforcement for the session that runs the schema script.
    "PRAGMA foreign_keys = ON;",
    # Per-stock reference data, keyed by exchange-suffixed code.
    # updated_at is auto-stamped as a local YYYYMMDDHHMMSS string.
    """
    CREATE TABLE IF NOT EXISTS stock_basic (
        stock_code TEXT PRIMARY KEY,
        stock_name TEXT,
        market TEXT,
        board TEXT,
        listed_date TEXT,
        total_volume REAL,
        float_volume REAL,
        updated_at TEXT DEFAULT (strftime('%Y%m%d%H%M%S', 'now', 'localtime'))
    );
    """,
    # Daily OHLC bars; one row per (date, code, adjustflag) combination,
    # so the same bar can be stored under several adjustment series.
    """
    CREATE TABLE IF NOT EXISTS daily_k_data (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        date TEXT NOT NULL,
        code TEXT NOT NULL,
        open REAL,
        high REAL,
        low REAL,
        close REAL,
        preclose REAL,
        volume REAL,
        amount REAL,
        adjustflag INTEGER NOT NULL,
        turn REAL,
        tradestatus INTEGER,
        pctChg REAL,
        peTTM REAL,
        pbMRQ REAL,
        psTTM REAL,
        pcfNcfTTM REAL,
        isST INTEGER,
        source TEXT,
        created_at TEXT DEFAULT (strftime('%Y%m%d%H%M%S', 'now', 'localtime')),
        UNIQUE(date, code, adjustflag)
    );
    """,
    # Many-to-many sector membership.
    """
    CREATE TABLE IF NOT EXISTS sector_stocks (
        sector_name TEXT NOT NULL,
        stock_code TEXT NOT NULL,
        PRIMARY KEY (sector_name, stock_code)
    );
    """,
    # Per-market trading calendar; trade_date is an INTEGER epoch value
    # (presumably milliseconds, matching the reader code -- verify).
    """
    CREATE TABLE IF NOT EXISTS trading_days (
        market TEXT NOT NULL,
        trade_date INTEGER NOT NULL,
        PRIMARY KEY (market, trade_date)
    );
    """,
    # DuPont ROE decomposition per (code, statDate) reporting period.
    """
    CREATE TABLE IF NOT EXISTS dupont_data (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        code TEXT NOT NULL,
        pubDate TEXT NOT NULL,
        statDate TEXT NOT NULL,
        dupontROE REAL,
        dupontAssetStoEquity REAL,
        dupontAssetTurn REAL,
        dupontPnitoni REAL,
        dupontNitogr REAL,
        dupontTaxBurden REAL,
        dupontIntburden REAL,
        dupontEbittogr REAL,
        UNIQUE(code, statDate)
    );
    """,
    # Intraday bars; `time` is the numeric timestamp used for uniqueness,
    # `trade_time` a redundant textual form.
    """
    CREATE TABLE IF NOT EXISTS minutes (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        stock_code TEXT NOT NULL,
        freq TEXT NOT NULL,
        trade_time TEXT,
        time INTEGER NOT NULL,
        open REAL,
        high REAL,
        low REAL,
        close REAL,
        volume REAL,
        amount REAL,
        pre_close REAL,
        UNIQUE(stock_code, freq, time)
    );
    """,
)
98
+
99
+
100
def connect(path: str) -> sqlite3.Connection:
    """Open *path* with Row access, FK enforcement and WAL journaling."""
    db = sqlite3.connect(path)
    db.row_factory = sqlite3.Row
    for pragma in ("PRAGMA foreign_keys = ON", "PRAGMA journal_mode = WAL"):
        db.execute(pragma)
    return db
106
+
107
+
108
@contextmanager
def connection(path: str) -> Iterator[sqlite3.Connection]:
    """Context manager yielding a connection to the database at *path*.

    Commits when the managed block finishes normally.  On an exception the
    open transaction is rolled back explicitly (rather than relying on the
    implicit rollback performed by close()) and the exception propagates.
    The connection is always closed.
    """
    conn = connect(path)
    try:
        yield conn
        conn.commit()
    except BaseException:
        # Make the discard of uncommitted work explicit before closing.
        conn.rollback()
        raise
    finally:
        conn.close()
116
+
117
+
118
def ensure_schema(conn: sqlite3.Connection) -> None:
    """Create all cjdata tables on *conn* if they do not already exist.

    Runs every entry of SCHEMA_STATEMENTS (all idempotent DDL) and commits.
    The cursor is closed even when a statement raises, so a failed
    migration does not leak the cursor.
    """
    cursor = conn.cursor()
    try:
        for statement in SCHEMA_STATEMENTS:
            cursor.executescript(statement)
    finally:
        cursor.close()
    conn.commit()
124
+
125
+
126
@dataclass(frozen=True)
class UpsertSpec:
    """Describes the target of a bulk insert/upsert (see upsert_rows)."""

    # Destination table name; interpolated into SQL, so must be trusted.
    table: str
    # Column names, in the same order as the values in each row tuple.
    columns: Sequence[str]
    # Columns forming the uniqueness constraint used as the ON CONFLICT target.
    conflict_columns: Sequence[str]
131
+
132
+
133
def upsert_rows(
    conn: sqlite3.Connection,
    spec: UpsertSpec,
    rows: Iterable[Sequence[Any]],
    update_columns: Optional[Sequence[str]] = None,
) -> int:
    """Bulk-insert *rows* into ``spec.table``.

    With ``update_columns=None`` conflicting rows are silently skipped
    (INSERT OR IGNORE).  Otherwise an ``ON CONFLICT ... DO UPDATE`` upsert
    refreshes the listed columns from the incoming row (``excluded.*``),
    using ``spec.conflict_columns`` as the conflict target.

    Returns the number of rows submitted (not the number actually changed).
    Raises ValueError if an upsert is requested with empty update or
    conflict column lists, which would otherwise generate invalid SQL.
    """
    rows = list(rows)
    if not rows:
        return 0

    placeholders = ", ".join(["?" for _ in spec.columns])
    column_list = ", ".join(spec.columns)

    if update_columns is None:
        sql = (
            f"INSERT OR IGNORE INTO {spec.table} ({column_list}) "
            f"VALUES ({placeholders})"
        )
    else:
        # Guard: "" in either list position yields "DO UPDATE SET " or
        # "ON CONFLICT()" -- a confusing OperationalError at execute time.
        if not update_columns or not spec.conflict_columns:
            raise ValueError(
                "update_columns and spec.conflict_columns must be non-empty"
            )
        update_clause = ", ".join(
            f"{col}=excluded.{col}" for col in update_columns
        )
        conflict_cols = ", ".join(spec.conflict_columns)
        sql = (
            f"INSERT INTO {spec.table} ({column_list}) VALUES ({placeholders}) "
            f"ON CONFLICT({conflict_cols}) DO UPDATE SET {update_clause}"
        )

    conn.executemany(sql, rows)
    return len(rows)
163
+
164
+
165
def insert_ignore(
    conn: sqlite3.Connection,
    table: str,
    columns: Sequence[str],
    rows: Iterable[Sequence[Any]],
) -> int:
    """Insert *rows* into *table*, silently skipping constraint conflicts.

    Returns the number of rows submitted, which is zero for empty input.
    The table/column names are interpolated verbatim and must be trusted;
    row values are bound as parameters.
    """
    materialized = list(rows)
    if not materialized:
        return 0
    column_list = ", ".join(columns)
    placeholders = ", ".join("?" * len(columns))
    sql = (
        f"INSERT OR IGNORE INTO {table} ({column_list}) VALUES ({placeholders})"
    )
    conn.executemany(sql, materialized)
    return len(materialized)
181
+
182
+
183
def delete_rows(
    conn: sqlite3.Connection,
    table: str,
    where_clause: str,
    params: Sequence[Any],
) -> None:
    """Execute ``DELETE FROM table WHERE where_clause`` with bound *params*.

    *table* and *where_clause* are interpolated verbatim and must come from
    trusted code; only *params* are bound safely as SQL parameters.
    """
    statement = f"DELETE FROM {table} WHERE {where_clause}"
    conn.execute(statement, params)
cjdata/local_data.py ADDED
@@ -0,0 +1,482 @@
1
+ """Read-only data access helpers for cjdata."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import sqlite3
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import dataclass
8
+ from datetime import datetime, timedelta
9
+ from enum import Enum
10
+ from typing import Optional, Dict
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+
15
+ from .utils import to_yyyymmdd
16
+
17
+ try:
18
+ import talib # type: ignore
19
+ except ImportError: # pragma: no cover
20
+ talib = None
21
+
22
+
23
class TrendType(Enum):
    """Trend labels produced by LocalData.calculate_trend()."""

    STRONG_UPTREND = "strong_uptrend"      # MA5 > MA20, rising, ADX > 25
    WEAK_UPTREND = "weak_uptrend"          # MA5 > MA20, rising, ADX <= 25
    STRONG_DOWNTREND = "strong_downtrend"  # MA5 < MA20, falling, ADX > 25
    WEAK_DOWNTREND = "weak_downtrend"      # MA5 < MA20, falling, ADX <= 25
    SIDEWAYS = "sideways"                  # MAs converged, low ATR, ADX < 20
    UNCLEAR = "unclear"                    # none of the above rules matched
30
+
31
+
32
class CodeFormat(Enum):
    """How stock codes are rendered by get_stock_list_in_sector()."""

    # Stored form, e.g. "600000.SH" (code followed by market suffix).
    MARKET_SUFFIX = "suffix"
    # Converted form, e.g. "sh.600000" (lower-cased market prefix).
    MARKET_PREFIX = "prefix"
35
+
36
+
37
class FinData(ABC):
    """Abstract interface for daily-bar financial data providers."""

    @abstractmethod
    def get_daily(self, stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Return daily OHLC bars for *stock_code* within the date range."""
        raise NotImplementedError

    def close(self) -> None:
        """Release underlying resources; subclasses must override."""
        raise NotImplementedError
44
+
45
+
46
@dataclass
class _AdjConfig:
    """Internal pairing of an adjustflag value with its source table."""

    # adjustflag stored in daily_k_data (1 = hfq, 2 = qfq per get_daily).
    flag: int
    # Source table name (currently always "daily_k_data").
    table: str
50
+
51
+
52
class LocalData(FinData):
    """Read-only accessor over a local cjdata SQLite database.

    Dates are accepted as 'YYYY-MM-DD' or 'YYYYMMDD' and normalized with
    to_yyyymmdd().  Daily bars live in ``daily_k_data`` selected by an
    integer ``adjustflag`` (1 = hfq, 2 = qfq); when that table is absent,
    legacy per-adjustment tables ``daily_qfq``/``daily_hfq`` are used.
    """

    def __init__(self, path: str) -> None:
        # Keep the path for reference; all queries go through self.conn.
        self._path = path
        if os.name == "nt":
            # On Windows the connection is opened normally and made read-only
            # via PRAGMA query_only -- presumably because the ?mode=ro URI
            # form is problematic there; TODO confirm.
            self.conn = sqlite3.connect(path, check_same_thread=False)
            self.conn.execute("PRAGMA query_only = ON")
        else:
            # Elsewhere, open the file truly read-only via a SQLite URI.
            uri = f"file:{path}?mode=ro"
            self.conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
        self.conn.row_factory = sqlite3.Row

    def _table_exists(self, table: str) -> bool:
        """Return True if *table* exists in the connected database."""
        cursor = self.conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
        )
        return cursor.fetchone() is not None

    def get_daily(self, stock_code: str, start_date: str, end_date: str, adj: str = "qfq") -> pd.DataFrame:
        """Return daily OHLC bars indexed by ``trade_date`` (ascending).

        ``adj`` selects the adjustment series ('hfq' -> adjustflag 1,
        'qfq' -> adjustflag 2); any other value raises ValueError.
        Returns an empty DataFrame when no table or no rows are found.
        """
        adj = adj.lower()
        if adj not in ("qfq", "hfq"):
            raise ValueError("adj must be 'qfq' or 'hfq'")

        start = to_yyyymmdd(start_date)
        end = to_yyyymmdd(end_date)
        # _AdjConfig.table is currently always daily_k_data; only .flag varies.
        mapping = {"hfq": _AdjConfig(1, "daily_k_data"), "qfq": _AdjConfig(2, "daily_k_data")}
        config = mapping[adj]

        if self._table_exists("daily_k_data"):
            # The LIKE clauses let index codes (000xxx.SH / 399xxx.SZ) match
            # regardless of adjustflag -- presumably indices are stored as a
            # single series; verify against the ingestion code.
            query = (
                "SELECT date AS trade_date, open, high, low, close, preclose AS pre_close, volume "
                "FROM daily_k_data WHERE code=? AND date BETWEEN ? AND ? AND (code like '000%.SH' OR code like '399%.SZ' OR adjustflag = ?) "
                "ORDER BY date"
            )
            df = pd.read_sql(
                query,
                self.conn,
                params=(stock_code, start, end, config.flag),
            )
        else:
            # Legacy layout: one table per adjustment kind.
            fallback_table = f"daily_{adj}"
            if not self._table_exists(fallback_table):
                return pd.DataFrame()
            query = (
                "SELECT trade_date, open, high, low, close, pre_close, volume "
                f"FROM {fallback_table} WHERE stock_code=? AND trade_date BETWEEN ? AND ? ORDER BY trade_date"
            )
            df = pd.read_sql(query, self.conn, params=(stock_code, start, end))

        if df.empty:
            return df

        df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
        df = df.set_index("trade_date").sort_index()
        return df

    def get_weekly(self, stock_code: str, start_date: str, end_date: str, adj: str = "qfq") -> pd.DataFrame:
        """Resample daily bars to weekly (Friday-anchored) OHLCV bars."""
        data = self.get_daily(stock_code, start_date, end_date, adj)
        if data.empty:
            return data
        weekly = data.resample("W-FRI").agg(
            {
                "open": "first",
                "high": "max",
                "low": "min",
                "close": "last",
                "volume": "sum",
            }
        )
        # Drop calendar weeks that had no trading days in the range.
        return weekly.dropna()

    def get_minutes(self, stock_code: str, start_date: str, end_date: str, freq: str = "5m") -> pd.DataFrame:
        """Return intraday bars for *freq*, indexed by naive Shanghai time.

        ``minutes.time`` holds epoch-millisecond timestamps.  The end bound
        is extended by one day so that bars on *end_date* itself are kept.
        """
        if not self._table_exists("minutes"):
            return pd.DataFrame()

        start = to_yyyymmdd(start_date)
        end = to_yyyymmdd(end_date)
        # NOTE(review): these bounds come from naive parsing of YYYYMMDD --
        # confirm the stored epoch values use the same time reference.
        start_ts = int(pd.to_datetime(start).timestamp() * 1000)
        end_ts = int(pd.to_datetime(end).timestamp() * 1000) + 24 * 60 * 60 * 1000

        query = (
            "SELECT trade_time, time, open, high, low, close, volume, amount, pre_close "
            "FROM minutes WHERE stock_code=? AND freq=? AND time >= ? AND time < ? ORDER BY time"
        )
        df = pd.read_sql(query, self.conn, params=(stock_code, freq, start_ts, end_ts))
        if df.empty:
            return df

        # Convert epoch-ms (treated as UTC) to naive Asia/Shanghai wall time.
        df["datetime"] = pd.to_datetime(df["time"], unit="ms", utc=True)
        df["datetime"] = df["datetime"].dt.tz_convert("Asia/Shanghai").dt.tz_localize(None)
        df = df.set_index("datetime").drop(columns=["trade_time", "time", "amount", "pre_close"], errors="ignore")
        for column in ("open", "high", "low", "close", "volume"):
            df[column] = pd.to_numeric(df[column], errors="coerce")
        df = df.dropna()
        return df

    def get_price(self, stock_code: str, date: str, adj: str = "qfq") -> float:
        """Return the most recent close at or before *date* (0.0 if unknown)."""
        adj = adj.lower()
        if adj not in ("qfq", "hfq"):
            raise ValueError("adj must be 'qfq' or 'hfq'")

        target_date = to_yyyymmdd(date)
        if self._table_exists("daily_k_data"):
            flag = 1 if adj == "hfq" else 2
            query = (
                "SELECT close FROM daily_k_data WHERE code=? AND adjustflag=? AND date <= ? ORDER BY date DESC LIMIT 1"
            )
            df = pd.read_sql(query, self.conn, params=(stock_code, flag, target_date))
        else:
            table = f"daily_{adj}"
            if not self._table_exists(table):
                return 0.0
            query = (
                f"SELECT close FROM {table} WHERE stock_code=? AND trade_date <= ? ORDER BY trade_date DESC LIMIT 1"
            )
            df = pd.read_sql(query, self.conn, params=(stock_code, target_date))

        if df.empty:
            return 0.0
        return float(df.iloc[0]["close"])

    def get_stock_list_in_sector(self, sector_name: str, format: CodeFormat = CodeFormat.MARKET_SUFFIX) -> list[str]:
        """List stock codes in *sector_name*, optionally as 'sh.600000'-style."""
        df = pd.read_sql(
            "SELECT stock_code FROM sector_stocks WHERE sector_name = ?",
            self.conn,
            params=(sector_name,),
        )
        if df.empty:
            return []
        codes = df["stock_code"].tolist()
        if format == CodeFormat.MARKET_SUFFIX:
            return codes
        converted: list[str] = []
        for code in codes:
            parts = code.split(".")
            if len(parts) == 2:
                # "600000.SH" -> "sh.600000"
                converted.append(f"{parts[1].lower()}.{parts[0]}")
            else:
                # Codes without a market suffix pass through unchanged.
                converted.append(code)
        return converted

    def get_stock_data_frame_in_sector(
        self,
        sector_name: str,
        start_date: str,
        end_date: str,
        adj: str = "hfq",
    ) -> pd.DataFrame:
        """Return daily bars for every stock in a sector as one long frame.

        The result is reindexed onto the full (stock_code, trade_date) grid,
        so stocks missing a date appear as NaN rows.
        """
        adj = adj.lower()
        if adj not in ("qfq", "hfq"):
            raise ValueError("adj must be 'qfq' or 'hfq'")

        start = to_yyyymmdd(start_date)
        end = to_yyyymmdd(end_date)
        if self._table_exists("daily_k_data"):
            flag = 1 if adj == "hfq" else 2
            query = (
                "SELECT dk.code AS stock_code, dk.date AS trade_date, dk.open, dk.high, dk.low, dk.close, dk.volume "
                "FROM daily_k_data dk JOIN sector_stocks ss ON dk.code = ss.stock_code "
                "WHERE ss.sector_name=? AND dk.date BETWEEN ? AND ? AND dk.adjustflag=? ORDER BY dk.code, dk.date"
            )
            df = pd.read_sql(query, self.conn, params=(sector_name, start, end, flag))
        else:
            table = f"daily_{adj}"
            if not self._table_exists(table):
                return pd.DataFrame()
            query = (
                f"SELECT dq.stock_code, dq.trade_date, dq.open, dq.high, dq.low, dq.close, dq.volume "
                f"FROM {table} dq JOIN sector_stocks ss ON dq.stock_code = ss.stock_code "
                "WHERE ss.sector_name=? AND dq.trade_date BETWEEN ? AND ? ORDER BY dq.stock_code, dq.trade_date"
            )
            df = pd.read_sql(query, self.conn, params=(sector_name, start, end))

        if df.empty:
            return df

        df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
        stocks = self.get_stock_list_in_sector(sector_name)
        if not stocks:
            return df
        dates = pd.to_datetime(df["trade_date"].drop_duplicates())
        full_index = pd.MultiIndex.from_product([stocks, dates], names=["stock_code", "trade_date"])
        df = df.set_index(["stock_code", "trade_date"]).reindex(full_index).reset_index()
        df = df.sort_values(["stock_code", "trade_date"])
        return df

    def get_trading_dates(self, market: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Return trading days for *market* as naive local datetimes.

        ``trading_days.trade_date`` is compared as an epoch-millisecond int.
        """
        start_ts = int(pd.to_datetime(start_date).timestamp() * 1000)
        end_ts = int(pd.to_datetime(end_date).timestamp() * 1000)
        df = pd.read_sql(
            "SELECT trade_date FROM trading_days WHERE market = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date",
            self.conn,
            params=(market, start_ts, end_ts),
        )
        if df.empty:
            return df
        # ms -> seconds, interpreted in the local timezone.
        df["trade_date"] = df["trade_date"].apply(lambda value: datetime.fromtimestamp(value / 1000))
        return df

    def get_etf_sector_list(self) -> list[str]:
        """Return the distinct sector names that start with 'ETF'."""
        df = pd.read_sql(
            "SELECT DISTINCT sector_name FROM sector_stocks WHERE sector_name LIKE 'ETF%'",
            self.conn,
        )
        return df["sector_name"].tolist()

    def resample_data(self, df: pd.DataFrame, target_period: str) -> pd.DataFrame:
        """Resample an OHLCV frame (datetime-indexed) to *target_period*.

        Unknown periods fall back to 15 minutes; '1m' returns *df* as-is.
        NOTE(review): the "T"/"H" offset aliases are deprecated in
        pandas >= 2.2; prefer "min"/"h" when upgrading.
        """
        if df.empty:
            return df
        period_map = {
            "1m": "1T",
            "5m": "5T",
            "15m": "15T",
            "30m": "30T",
            "45m": "45T",
            "60m": "60T",
            "1h": "1H",
            "4h": "4H",
            "1d": "1D",
        }
        if target_period == "1m":
            return df
        freq = period_map.get(target_period, "15T")
        resampled = df.resample(freq).agg(
            {
                "open": "first",
                "high": "max",
                "low": "min",
                "close": "last",
                "volume": "sum",
            }
        )
        return resampled.dropna()

    def get_stock_name(self, stock_code: str) -> str:
        """Return the display name for *stock_code* ('' when unknown)."""
        df = pd.read_sql(
            "SELECT stock_name FROM stock_basic WHERE stock_code = ?",
            self.conn,
            params=(stock_code,),
        )
        if df.empty:
            return ""
        return str(df.iloc[0]["stock_name"])

    def get_stock_name_in_sector(self, sector: str) -> pd.DataFrame:
        """Return (stock_code, stock_name) rows for every stock in *sector*."""
        df = pd.read_sql(
            "SELECT sb.stock_code, sb.stock_name FROM stock_basic sb JOIN sector_stocks ss ON sb.stock_code = ss.stock_code "
            "WHERE ss.sector_name = ?",
            self.conn,
            params=(sector,),
        )
        return df

    def get_stock_volume(self, stock_code: str) -> tuple[float, float]:
        """Return (total_volume, float_volume); NULLs map to 0.0."""
        df = pd.read_sql(
            "SELECT total_volume, float_volume FROM stock_basic WHERE stock_code = ?",
            self.conn,
            params=(stock_code,),
        )
        if df.empty:
            return (0.0, 0.0)
        row = df.iloc[0]
        # `or 0.0` converts SQL NULL (None) to 0.0 before float().
        return (float(row["total_volume"] or 0.0), float(row["float_volume"] or 0.0))

    def get_stock_basic_by_sector(self, sector: str) -> pd.DataFrame:
        """Return basic reference columns for every stock in *sector*."""
        df = pd.read_sql(
            "SELECT sb.stock_code, sb.stock_name, sb.market, sb.total_volume, sb.float_volume "
            "FROM stock_basic sb JOIN sector_stocks ss ON sb.stock_code = ss.stock_code WHERE ss.sector_name = ?",
            self.conn,
            params=(sector,),
        )
        return df

    def get_dupont_data_by_sector(self, sector: str, year: int, quarter: int) -> pd.DataFrame:
        """Return DuPont metrics for a sector at a given reporting quarter.

        The LEFT JOIN keeps every sector member even without dupont_data,
        leaving its metric columns as NaN.  Raises ValueError for a quarter
        outside 1..4.
        """
        quarter_end_dates = {1: f"{year}0331", 2: f"{year}0630", 3: f"{year}0930", 4: f"{year}1231"}
        if quarter not in quarter_end_dates:
            raise ValueError("quarter must be 1, 2, 3, or 4")
        stat_date = quarter_end_dates[quarter]
        df = pd.read_sql(
            "SELECT sb.stock_code, sb.stock_name, dd.pubDate, dd.statDate, dd.dupontROE, dd.dupontAssetStoEquity, "
            "dd.dupontAssetTurn, dd.dupontPnitoni, dd.dupontNitogr, dd.dupontTaxBurden, dd.dupontIntburden, dd.dupontEbittogr "
            "FROM stock_basic sb JOIN sector_stocks ss ON sb.stock_code = ss.stock_code "
            "LEFT JOIN dupont_data dd ON sb.stock_code = dd.code AND dd.statDate = ? WHERE ss.sector_name = ?",
            self.conn,
            params=(stat_date, sector),
        )
        return df

    def search_stocks(self, search_str: str, limit: int = 20) -> list[tuple[str, str]]:
        """Substring-search codes and names; returns (code, name) tuples."""
        df = pd.read_sql(
            "SELECT stock_code, stock_name FROM stock_basic WHERE stock_code LIKE ? OR stock_name LIKE ? LIMIT ?",
            self.conn,
            params=(f"%{search_str}%", f"%{search_str}%", limit),
        )
        return list(df.itertuples(index=False, name=None))

    def get_stock_returns(self, start_date: str, end_date: str, limit: int = 100) -> pd.DataFrame:
        """Rank stocks by return between the trading days nearest the bounds.

        Uses the SH calendar to snap both bounds to the latest trading day
        at or before them, then compares hfq closes (adjustflag = 1).
        NOTE(review): start_ts/end_ts/limit are interpolated into the SQL
        with an f-string; they are ints here, but binding them as
        parameters would be safer if callers ever change.
        """
        start_ts = int(pd.to_datetime(start_date).timestamp() * 1000)
        end_ts = int(pd.to_datetime(end_date).timestamp() * 1000)
        query = f"""
        WITH
        start_trade_date AS (
            SELECT trade_date FROM trading_days WHERE market = 'SH' AND trade_date <= {start_ts}
            ORDER BY trade_date DESC LIMIT 1
        ),
        end_trade_date AS (
            SELECT trade_date FROM trading_days WHERE market = 'SH' AND trade_date <= {end_ts}
            ORDER BY trade_date DESC LIMIT 1
        )
        SELECT t1.stock_code, b.stock_name, t1.close AS start_price, t2.close AS end_price,
               (t2.close - t1.close) / t1.close AS return_rate,
               datetime(s.trade_date/1000, 'unixepoch') AS start_date,
               datetime(e.trade_date/1000, 'unixepoch') AS end_date
        FROM daily_k_data t1
        JOIN daily_k_data t2 ON t1.code = t2.code AND t1.adjustflag = 1 AND t2.adjustflag = 1
        JOIN stock_basic b ON t1.code = b.stock_code
        JOIN start_trade_date s
        JOIN end_trade_date e
        WHERE t1.date = (SELECT strftime('%Y%m%d', datetime(s.trade_date/1000, 'unixepoch')))
        AND t2.date = (SELECT strftime('%Y%m%d', datetime(e.trade_date/1000, 'unixepoch')))
        ORDER BY return_rate DESC
        LIMIT {limit}
        """
        df = pd.read_sql(query, self.conn)
        return df

    def get_stock_code_with_suffix(self, raw_code: str) -> str:
        """Resolve a bare code like '600000' to its suffixed form ('' if none)."""
        df = pd.read_sql(
            "SELECT stock_code FROM stock_basic WHERE stock_code LIKE ?",
            self.conn,
            params=(f"{raw_code}.%",),
        )
        if df.empty:
            return ""
        return str(df.iloc[0]["stock_code"])

    def calculate_trend(self, df: pd.DataFrame) -> Optional[TrendType]:
        """Classify the trend of an OHLC frame; None when not computable.

        Requires talib and at least 30 rows.  Uses MA5/MA20 crossover plus
        ADX/ATR thresholds.
        NOTE(review): the uptrend slope uses ma5[-3:] while the downtrend
        slope uses ma20[-5:] -- confirm the asymmetry is intentional.
        """
        if talib is None or df.empty or len(df) < 30:
            return None
        try:
            close_prices = np.asarray(df["close"].astype(float).values, dtype=float)
            high_prices = np.asarray(df["high"].astype(float).values, dtype=float)
            low_prices = np.asarray(df["low"].astype(float).values, dtype=float)
            ma5 = talib.SMA(close_prices, 5)
            ma20 = talib.SMA(close_prices, 20)
            adx = talib.ADX(high_prices, low_prices, close_prices, 14)
            atr = talib.ATR(high_prices, low_prices, close_prices, 14)
        except Exception:
            return None
        # Indicators need warm-up rows; bail out while any is still NaN.
        if any(np.isnan(val) for val in (ma5[-1], ma20[-1], adx[-1], atr[-1])):
            return None
        if ma5[-1] > ma20[-1] and np.nanmean(np.diff(ma5[-3:])) > 0:
            return TrendType.STRONG_UPTREND if adx[-1] > 25 else TrendType.WEAK_UPTREND
        if ma5[-1] < ma20[-1] and np.nanmean(np.diff(ma20[-5:])) < 0:
            return TrendType.STRONG_DOWNTREND if adx[-1] > 25 else TrendType.WEAK_DOWNTREND
        if (
            abs(ma5[-1] - ma20[-1]) < 0.02 * close_prices[-1]
            and atr[-1] < 0.03 * close_prices[-1]
            and adx[-1] < 20
        ):
            return TrendType.SIDEWAYS
        return TrendType.UNCLEAR

    def determine_trend(
        self,
        stock_code: str,
        date: str,
        period: int = 60,
        adjust: str = "hfq",
    ) -> Optional[TrendType]:
        """Classify the trend ending at *date* over roughly *period* bars.

        Fetches period * 1.5 calendar days of dailies to cover non-trading
        days, then delegates to calculate_trend().
        """
        end_date = to_yyyymmdd(date)
        end_dt = datetime.strptime(end_date, "%Y%m%d")
        start_dt = end_dt - timedelta(days=int(period * 1.5))
        start_date = start_dt.strftime("%Y%m%d")
        df = self.get_daily(stock_code, start_date, end_date, adjust)
        return self.calculate_trend(df)

    def get_trend_analysis(
        self,
        stock_code: str,
        date: str,
        period: int = 60,
        adjust: str = "qfq",
    ) -> Optional[Dict[str, object]]:
        """Return a dict of trend label plus the indicator values behind it.

        None when talib is missing, data is insufficient (< 30 rows), or
        any indicator is still NaN at the last bar.
        """
        if talib is None:
            return None
        end_date = to_yyyymmdd(date)
        end_dt = datetime.strptime(end_date, "%Y%m%d")
        # Extra 20 days of warm-up for the 20-period MA before the window.
        start_dt = end_dt - timedelta(days=int((period + 20) * 1.5))
        start_date = start_dt.strftime("%Y%m%d")
        df = self.get_daily(stock_code, start_date, end_date, adjust)
        if df.empty or len(df) < 30:
            return None
        try:
            close_prices = np.asarray(df["close"].astype(float).values, dtype=float)
            high_prices = np.asarray(df["high"].astype(float).values, dtype=float)
            low_prices = np.asarray(df["low"].astype(float).values, dtype=float)
            ma5 = talib.SMA(close_prices, 5)
            ma20 = talib.SMA(close_prices, 20)
            adx = talib.ADX(high_prices, low_prices, close_prices, 14)
            atr = talib.ATR(high_prices, low_prices, close_prices, 14)
        except Exception:
            return None
        if any(np.isnan(val) for val in (ma5[-1], ma20[-1], adx[-1], atr[-1])):
            return None
        ma5_slope = float(np.nanmean(np.diff(ma5[-3:]))) if len(ma5) >= 3 else 0.0
        ma20_slope = float(np.nanmean(np.diff(ma20[-5:]))) if len(ma20) >= 5 else 0.0
        # Re-runs the classification (re-fetching data) for the label.
        trend = self.determine_trend(stock_code, end_date, period, adjust)
        return {
            "trend": trend.value if trend else None,
            "ma5": float(ma5[-1]),
            "ma20": float(ma20[-1]),
            "adx": float(adx[-1]),
            "atr": float(atr[-1]),
            "current_price": float(close_prices[-1]),
            "ma5_slope": ma5_slope,
            "ma20_slope": ma20_slope,
            "analysis_date": end_date,
            "stock_code": stock_code,
            "data_points": len(df),
        }

    def get_latest_date(self) -> Optional[str]:
        """Return the max date in daily_k_data as 'YYYYMMDD', or None."""
        if not self._table_exists("daily_k_data"):
            return None
        df = pd.read_sql("SELECT MAX(date) AS latest_date FROM daily_k_data", self.conn)
        if df.empty or df.iloc[0]["latest_date"] is None:
            return None
        return str(df.iloc[0]["latest_date"])

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()
cjdata/utils.py ADDED
@@ -0,0 +1,30 @@
1
+ """Utility helpers for cjdata."""
2
+ from __future__ import annotations
3
+
4
+ from datetime import datetime
5
+ from typing import Iterable, Iterator, Sequence, TypeVar
6
+
7
+ T = TypeVar("T")
8
+
9
+
10
def to_yyyymmdd(date_str: str) -> str:
    """Normalize 'YYYY-MM-DD' (or already-compact 'YYYYMMDD') to 'YYYYMMDD'."""
    already_compact = len(date_str) == 8 and date_str.isdigit()
    if already_compact:
        return date_str
    parsed = datetime.strptime(date_str, "%Y-%m-%d")
    return parsed.strftime("%Y%m%d")
14
+
15
+
16
def to_iso_date(date_str: str) -> str:
    """Normalize 'YYYYMMDD' (or already-ISO 'YYYY-MM-DD') to 'YYYY-MM-DD'."""
    already_iso = len(date_str) == 10 and date_str[4] == "-"
    if already_iso:
        return date_str
    parsed = datetime.strptime(date_str, "%Y%m%d")
    return parsed.strftime("%Y-%m-%d")
20
+
21
+
22
def chunked(seq: Sequence[T], size: int) -> Iterator[Sequence[T]]:
    """Yield successive slices of *seq* holding at most *size* elements.

    The final slice may be shorter.  Raises ValueError for size <= 0.
    """
    if size <= 0:
        raise ValueError("size must be > 0")
    start = 0
    total = len(seq)
    while start < total:
        yield seq[start : start + size]
        start += size
27
+
28
+
29
def ensure_list(iterable: Iterable[T]) -> list[T]:
    """Materialize *iterable* into a new list."""
    materialized: list[T] = list(iterable)
    return materialized