cjdata-0.0.2-py3-none-any.whl
This diff shows the contents of publicly available package versions as released to their respective public registries, and is provided for informational purposes only.
- cjdata/__init__.py +17 -0
- cjdata/__main__.py +7 -0
- cjdata/__pycache__/__main__.cpython-312.pyc +0 -0
- cjdata/baostock_pipeline.py +318 -0
- cjdata/builder.py +94 -0
- cjdata/cli.py +74 -0
- cjdata/db.py +190 -0
- cjdata/local_data.py +482 -0
- cjdata/utils.py +30 -0
- cjdata/xtquant_pipeline.py +315 -0
- cjdata-0.0.2.dist-info/METADATA +19 -0
- cjdata-0.0.2.dist-info/RECORD +15 -0
- cjdata-0.0.2.dist-info/WHEEL +5 -0
- cjdata-0.0.2.dist-info/entry_points.txt +2 -0
- cjdata-0.0.2.dist-info/top_level.txt +1 -0
cjdata/db.py
ADDED
@@ -0,0 +1,190 @@
"""SQLite schema helpers for the cjdata package."""
from __future__ import annotations

import sqlite3
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Iterable, Iterator, Optional, Sequence

SCHEMA_STATEMENTS: tuple[str, ...] = (
    "PRAGMA foreign_keys = ON;",
    """
    CREATE TABLE IF NOT EXISTS stock_basic (
        stock_code TEXT PRIMARY KEY,
        stock_name TEXT,
        market TEXT,
        board TEXT,
        listed_date TEXT,
        total_volume REAL,
        float_volume REAL,
        updated_at TEXT DEFAULT (strftime('%Y%m%d%H%M%S', 'now', 'localtime'))
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS daily_k_data (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        date TEXT NOT NULL,
        code TEXT NOT NULL,
        open REAL,
        high REAL,
        low REAL,
        close REAL,
        preclose REAL,
        volume REAL,
        amount REAL,
        adjustflag INTEGER NOT NULL,
        turn REAL,
        tradestatus INTEGER,
        pctChg REAL,
        peTTM REAL,
        pbMRQ REAL,
        psTTM REAL,
        pcfNcfTTM REAL,
        isST INTEGER,
        source TEXT,
        created_at TEXT DEFAULT (strftime('%Y%m%d%H%M%S', 'now', 'localtime')),
        UNIQUE(date, code, adjustflag)
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS sector_stocks (
        sector_name TEXT NOT NULL,
        stock_code TEXT NOT NULL,
        PRIMARY KEY (sector_name, stock_code)
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS trading_days (
        market TEXT NOT NULL,
        trade_date INTEGER NOT NULL,
        PRIMARY KEY (market, trade_date)
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS dupont_data (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        code TEXT NOT NULL,
        pubDate TEXT NOT NULL,
        statDate TEXT NOT NULL,
        dupontROE REAL,
        dupontAssetStoEquity REAL,
        dupontAssetTurn REAL,
        dupontPnitoni REAL,
        dupontNitogr REAL,
        dupontTaxBurden REAL,
        dupontIntburden REAL,
        dupontEbittogr REAL,
        UNIQUE(code, statDate)
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS minutes (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        stock_code TEXT NOT NULL,
        freq TEXT NOT NULL,
        trade_time TEXT,
        time INTEGER NOT NULL,
        open REAL,
        high REAL,
        low REAL,
        close REAL,
        volume REAL,
        amount REAL,
        pre_close REAL,
        UNIQUE(stock_code, freq, time)
    );
    """,
)


def connect(path: str) -> sqlite3.Connection:
    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys = ON")
    conn.execute("PRAGMA journal_mode = WAL")
    return conn


@contextmanager
def connection(path: str) -> Iterator[sqlite3.Connection]:
    conn = connect(path)
    try:
        yield conn
        conn.commit()
    finally:
        conn.close()


def ensure_schema(conn: sqlite3.Connection) -> None:
    cursor = conn.cursor()
    for statement in SCHEMA_STATEMENTS:
        cursor.executescript(statement)
    cursor.close()
    conn.commit()


@dataclass(frozen=True)
class UpsertSpec:
    table: str
    columns: Sequence[str]
    conflict_columns: Sequence[str]


def upsert_rows(
    conn: sqlite3.Connection,
    spec: UpsertSpec,
    rows: Iterable[Sequence[Any]],
    update_columns: Optional[Sequence[str]] = None,
) -> int:
    rows = list(rows)
    if not rows:
        return 0

    placeholders = ", ".join(["?" for _ in spec.columns])
    column_list = ", ".join(spec.columns)

    if update_columns is None:
        sql = (
            f"INSERT OR IGNORE INTO {spec.table} ({column_list}) "
            f"VALUES ({placeholders})"
        )
    else:
        update_clause = ", ".join(
            f"{col}=excluded.{col}" for col in update_columns
        )
        conflict_cols = ", ".join(spec.conflict_columns)
        sql = (
            f"INSERT INTO {spec.table} ({column_list}) VALUES ({placeholders}) "
            f"ON CONFLICT({conflict_cols}) DO UPDATE SET {update_clause}"
        )

    conn.executemany(sql, rows)
    return len(rows)


def insert_ignore(
    conn: sqlite3.Connection,
    table: str,
    columns: Sequence[str],
    rows: Iterable[Sequence[Any]],
) -> int:
    rows = list(rows)
    if not rows:
        return 0
    placeholders = ", ".join(["?" for _ in columns])
    column_list = ", ".join(columns)
    sql = (
        f"INSERT OR IGNORE INTO {table} ({column_list}) VALUES ({placeholders})"
    )
    conn.executemany(sql, rows)
    return len(rows)


def delete_rows(
    conn: sqlite3.Connection,
    table: str,
    where_clause: str,
    params: Sequence[Any],
) -> None:
    sql = f"DELETE FROM {table} WHERE {where_clause}"
    conn.execute(sql, params)

cjdata/local_data.py
ADDED
@@ -0,0 +1,482 @@
"""Read-only data access helpers for cjdata."""
from __future__ import annotations

import os
import sqlite3
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, Optional

import numpy as np
import pandas as pd

from .utils import to_yyyymmdd

try:
    import talib  # type: ignore
except ImportError:  # pragma: no cover
    talib = None


class TrendType(Enum):
    STRONG_UPTREND = "strong_uptrend"
    WEAK_UPTREND = "weak_uptrend"
    STRONG_DOWNTREND = "strong_downtrend"
    WEAK_DOWNTREND = "weak_downtrend"
    SIDEWAYS = "sideways"
    UNCLEAR = "unclear"


class CodeFormat(Enum):
    MARKET_SUFFIX = "suffix"
    MARKET_PREFIX = "prefix"


class FinData(ABC):
    @abstractmethod
    def get_daily(self, stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
        raise NotImplementedError

    def close(self) -> None:
        raise NotImplementedError


@dataclass
class _AdjConfig:
    flag: int
    table: str


class LocalData(FinData):
    def __init__(self, path: str) -> None:
        self._path = path
        # Windows lacks reliable read-only URI handling here, so fall back to
        # PRAGMA query_only; elsewhere open the file via a read-only URI.
        if os.name == "nt":
            self.conn = sqlite3.connect(path, check_same_thread=False)
            self.conn.execute("PRAGMA query_only = ON")
        else:
            uri = f"file:{path}?mode=ro"
            self.conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
        self.conn.row_factory = sqlite3.Row

    def _table_exists(self, table: str) -> bool:
        cursor = self.conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
        )
        return cursor.fetchone() is not None

    def get_daily(self, stock_code: str, start_date: str, end_date: str, adj: str = "qfq") -> pd.DataFrame:
        adj = adj.lower()
        if adj not in ("qfq", "hfq"):
            raise ValueError("adj must be 'qfq' or 'hfq'")

        start = to_yyyymmdd(start_date)
        end = to_yyyymmdd(end_date)
        # adjustflag 1 = hfq (back-adjusted), 2 = qfq (forward-adjusted).
        mapping = {"hfq": _AdjConfig(1, "daily_k_data"), "qfq": _AdjConfig(2, "daily_k_data")}
        config = mapping[adj]

        if self._table_exists("daily_k_data"):
            # Index codes (000xxx.SH / 399xxx.SZ) are matched regardless of adjustflag.
            query = (
                "SELECT date AS trade_date, open, high, low, close, preclose AS pre_close, volume "
                "FROM daily_k_data WHERE code=? AND date BETWEEN ? AND ? AND (code LIKE '000%.SH' OR code LIKE '399%.SZ' OR adjustflag = ?) "
                "ORDER BY date"
            )
            df = pd.read_sql(
                query,
                self.conn,
                params=(stock_code, start, end, config.flag),
            )
        else:
            fallback_table = f"daily_{adj}"
            if not self._table_exists(fallback_table):
                return pd.DataFrame()
            query = (
                "SELECT trade_date, open, high, low, close, pre_close, volume "
                f"FROM {fallback_table} WHERE stock_code=? AND trade_date BETWEEN ? AND ? ORDER BY trade_date"
            )
            df = pd.read_sql(query, self.conn, params=(stock_code, start, end))

        if df.empty:
            return df

        df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
        df = df.set_index("trade_date").sort_index()
        return df

    def get_weekly(self, stock_code: str, start_date: str, end_date: str, adj: str = "qfq") -> pd.DataFrame:
        data = self.get_daily(stock_code, start_date, end_date, adj)
        if data.empty:
            return data
        weekly = data.resample("W-FRI").agg(
            {
                "open": "first",
                "high": "max",
                "low": "min",
                "close": "last",
                "volume": "sum",
            }
        )
        return weekly.dropna()

    def get_minutes(self, stock_code: str, start_date: str, end_date: str, freq: str = "5m") -> pd.DataFrame:
        if not self._table_exists("minutes"):
            return pd.DataFrame()

        start = to_yyyymmdd(start_date)
        end = to_yyyymmdd(end_date)
        # The `time` column stores a millisecond epoch; make the end bound exclusive
        # by pushing it one full day past the end date.
        start_ts = int(pd.to_datetime(start).timestamp() * 1000)
        end_ts = int(pd.to_datetime(end).timestamp() * 1000) + 24 * 60 * 60 * 1000

        query = (
            "SELECT trade_time, time, open, high, low, close, volume, amount, pre_close "
            "FROM minutes WHERE stock_code=? AND freq=? AND time >= ? AND time < ? ORDER BY time"
        )
        df = pd.read_sql(query, self.conn, params=(stock_code, freq, start_ts, end_ts))
        if df.empty:
            return df

        df["datetime"] = pd.to_datetime(df["time"], unit="ms", utc=True)
        df["datetime"] = df["datetime"].dt.tz_convert("Asia/Shanghai").dt.tz_localize(None)
        df = df.set_index("datetime").drop(columns=["trade_time", "time", "amount", "pre_close"], errors="ignore")
        for column in ("open", "high", "low", "close", "volume"):
            df[column] = pd.to_numeric(df[column], errors="coerce")
        df = df.dropna()
        return df

    def get_price(self, stock_code: str, date: str, adj: str = "qfq") -> float:
        adj = adj.lower()
        if adj not in ("qfq", "hfq"):
            raise ValueError("adj must be 'qfq' or 'hfq'")

        target_date = to_yyyymmdd(date)
        if self._table_exists("daily_k_data"):
            flag = 1 if adj == "hfq" else 2
            query = (
                "SELECT close FROM daily_k_data WHERE code=? AND adjustflag=? AND date <= ? ORDER BY date DESC LIMIT 1"
            )
            df = pd.read_sql(query, self.conn, params=(stock_code, flag, target_date))
        else:
            table = f"daily_{adj}"
            if not self._table_exists(table):
                return 0.0
            query = (
                f"SELECT close FROM {table} WHERE stock_code=? AND trade_date <= ? ORDER BY trade_date DESC LIMIT 1"
            )
            df = pd.read_sql(query, self.conn, params=(stock_code, target_date))

        if df.empty:
            return 0.0
        return float(df.iloc[0]["close"])

    def get_stock_list_in_sector(self, sector_name: str, format: CodeFormat = CodeFormat.MARKET_SUFFIX) -> list[str]:
        df = pd.read_sql(
            "SELECT stock_code FROM sector_stocks WHERE sector_name = ?",
            self.conn,
            params=(sector_name,),
        )
        if df.empty:
            return []
        codes = df["stock_code"].tolist()
        if format == CodeFormat.MARKET_SUFFIX:
            return codes
        # Convert "600000.SH" (suffix form) to "sh.600000" (prefix form).
        converted: list[str] = []
        for code in codes:
            parts = code.split(".")
            if len(parts) == 2:
                converted.append(f"{parts[1].lower()}.{parts[0]}")
            else:
                converted.append(code)
        return converted

    def get_stock_data_frame_in_sector(
        self,
        sector_name: str,
        start_date: str,
        end_date: str,
        adj: str = "hfq",
    ) -> pd.DataFrame:
        adj = adj.lower()
        if adj not in ("qfq", "hfq"):
            raise ValueError("adj must be 'qfq' or 'hfq'")

        start = to_yyyymmdd(start_date)
        end = to_yyyymmdd(end_date)
        if self._table_exists("daily_k_data"):
            flag = 1 if adj == "hfq" else 2
            query = (
                "SELECT dk.code AS stock_code, dk.date AS trade_date, dk.open, dk.high, dk.low, dk.close, dk.volume "
                "FROM daily_k_data dk JOIN sector_stocks ss ON dk.code = ss.stock_code "
                "WHERE ss.sector_name=? AND dk.date BETWEEN ? AND ? AND dk.adjustflag=? ORDER BY dk.code, dk.date"
            )
            df = pd.read_sql(query, self.conn, params=(sector_name, start, end, flag))
        else:
            table = f"daily_{adj}"
            if not self._table_exists(table):
                return pd.DataFrame()
            query = (
                "SELECT dq.stock_code, dq.trade_date, dq.open, dq.high, dq.low, dq.close, dq.volume "
                f"FROM {table} dq JOIN sector_stocks ss ON dq.stock_code = ss.stock_code "
                "WHERE ss.sector_name=? AND dq.trade_date BETWEEN ? AND ? ORDER BY dq.stock_code, dq.trade_date"
            )
            df = pd.read_sql(query, self.conn, params=(sector_name, start, end))

        if df.empty:
            return df

        df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d")
        stocks = self.get_stock_list_in_sector(sector_name)
        if not stocks:
            return df
        # Reindex onto the full (stock, date) grid so every stock has a row
        # for every trade date, with NaNs where data is missing.
        dates = pd.to_datetime(df["trade_date"].drop_duplicates())
        full_index = pd.MultiIndex.from_product([stocks, dates], names=["stock_code", "trade_date"])
        df = df.set_index(["stock_code", "trade_date"]).reindex(full_index).reset_index()
        df = df.sort_values(["stock_code", "trade_date"])
        return df

    def get_trading_dates(self, market: str, start_date: str, end_date: str) -> pd.DataFrame:
        start_ts = int(pd.to_datetime(start_date).timestamp() * 1000)
        end_ts = int(pd.to_datetime(end_date).timestamp() * 1000)
        df = pd.read_sql(
            "SELECT trade_date FROM trading_days WHERE market = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date",
            self.conn,
            params=(market, start_ts, end_ts),
        )
        if df.empty:
            return df
        df["trade_date"] = df["trade_date"].apply(lambda value: datetime.fromtimestamp(value / 1000))
        return df

    def get_etf_sector_list(self) -> list[str]:
        df = pd.read_sql(
            "SELECT DISTINCT sector_name FROM sector_stocks WHERE sector_name LIKE 'ETF%'",
            self.conn,
        )
        return df["sector_name"].tolist()

    def resample_data(self, df: pd.DataFrame, target_period: str) -> pd.DataFrame:
        if df.empty:
            return df
        period_map = {
            "1m": "1T",
            "5m": "5T",
            "15m": "15T",
            "30m": "30T",
            "45m": "45T",
            "60m": "60T",
            "1h": "1H",
            "4h": "4H",
            "1d": "1D",
        }
        if target_period == "1m":
            return df
        freq = period_map.get(target_period, "15T")
        resampled = df.resample(freq).agg(
            {
                "open": "first",
                "high": "max",
                "low": "min",
                "close": "last",
                "volume": "sum",
            }
        )
        return resampled.dropna()

    def get_stock_name(self, stock_code: str) -> str:
        df = pd.read_sql(
            "SELECT stock_name FROM stock_basic WHERE stock_code = ?",
            self.conn,
            params=(stock_code,),
        )
        if df.empty:
            return ""
        return str(df.iloc[0]["stock_name"])

    def get_stock_name_in_sector(self, sector: str) -> pd.DataFrame:
        df = pd.read_sql(
            "SELECT sb.stock_code, sb.stock_name FROM stock_basic sb JOIN sector_stocks ss ON sb.stock_code = ss.stock_code "
            "WHERE ss.sector_name = ?",
            self.conn,
            params=(sector,),
        )
        return df

    def get_stock_volume(self, stock_code: str) -> tuple[float, float]:
        df = pd.read_sql(
            "SELECT total_volume, float_volume FROM stock_basic WHERE stock_code = ?",
            self.conn,
            params=(stock_code,),
        )
        if df.empty:
            return (0.0, 0.0)
        row = df.iloc[0]
        return (float(row["total_volume"] or 0.0), float(row["float_volume"] or 0.0))

    def get_stock_basic_by_sector(self, sector: str) -> pd.DataFrame:
        df = pd.read_sql(
            "SELECT sb.stock_code, sb.stock_name, sb.market, sb.total_volume, sb.float_volume "
            "FROM stock_basic sb JOIN sector_stocks ss ON sb.stock_code = ss.stock_code WHERE ss.sector_name = ?",
            self.conn,
            params=(sector,),
        )
        return df

    def get_dupont_data_by_sector(self, sector: str, year: int, quarter: int) -> pd.DataFrame:
        quarter_end_dates = {1: f"{year}0331", 2: f"{year}0630", 3: f"{year}0930", 4: f"{year}1231"}
        if quarter not in quarter_end_dates:
            raise ValueError("quarter must be 1, 2, 3, or 4")
        stat_date = quarter_end_dates[quarter]
        df = pd.read_sql(
            "SELECT sb.stock_code, sb.stock_name, dd.pubDate, dd.statDate, dd.dupontROE, dd.dupontAssetStoEquity, "
            "dd.dupontAssetTurn, dd.dupontPnitoni, dd.dupontNitogr, dd.dupontTaxBurden, dd.dupontIntburden, dd.dupontEbittogr "
            "FROM stock_basic sb JOIN sector_stocks ss ON sb.stock_code = ss.stock_code "
            "LEFT JOIN dupont_data dd ON sb.stock_code = dd.code AND dd.statDate = ? WHERE ss.sector_name = ?",
            self.conn,
            params=(stat_date, sector),
        )
        return df

    def search_stocks(self, search_str: str, limit: int = 20) -> list[tuple[str, str]]:
        df = pd.read_sql(
            "SELECT stock_code, stock_name FROM stock_basic WHERE stock_code LIKE ? OR stock_name LIKE ? LIMIT ?",
            self.conn,
            params=(f"%{search_str}%", f"%{search_str}%", limit),
        )
        return list(df.itertuples(index=False, name=None))

    def get_stock_returns(self, start_date: str, end_date: str, limit: int = 100) -> pd.DataFrame:
        start_ts = int(pd.to_datetime(start_date).timestamp() * 1000)
        end_ts = int(pd.to_datetime(end_date).timestamp() * 1000)
        query = f"""
            WITH
            start_trade_date AS (
                SELECT trade_date FROM trading_days WHERE market = 'SH' AND trade_date <= {start_ts}
                ORDER BY trade_date DESC LIMIT 1
            ),
            end_trade_date AS (
                SELECT trade_date FROM trading_days WHERE market = 'SH' AND trade_date <= {end_ts}
                ORDER BY trade_date DESC LIMIT 1
            )
            SELECT t1.code AS stock_code, b.stock_name, t1.close AS start_price, t2.close AS end_price,
                   (t2.close - t1.close) / t1.close AS return_rate,
                   datetime(s.trade_date/1000, 'unixepoch') AS start_date,
                   datetime(e.trade_date/1000, 'unixepoch') AS end_date
            FROM daily_k_data t1
            JOIN daily_k_data t2 ON t1.code = t2.code AND t1.adjustflag = 1 AND t2.adjustflag = 1
            JOIN stock_basic b ON t1.code = b.stock_code
            JOIN start_trade_date s
            JOIN end_trade_date e
            WHERE t1.date = (SELECT strftime('%Y%m%d', datetime(s.trade_date/1000, 'unixepoch')))
              AND t2.date = (SELECT strftime('%Y%m%d', datetime(e.trade_date/1000, 'unixepoch')))
            ORDER BY return_rate DESC
            LIMIT {limit}
        """
        df = pd.read_sql(query, self.conn)
        return df

    def get_stock_code_with_suffix(self, raw_code: str) -> str:
        df = pd.read_sql(
            "SELECT stock_code FROM stock_basic WHERE stock_code LIKE ?",
            self.conn,
            params=(f"{raw_code}.%",),
        )
        if df.empty:
            return ""
        return str(df.iloc[0]["stock_code"])

    def calculate_trend(self, df: pd.DataFrame) -> Optional[TrendType]:
        if talib is None or df.empty or len(df) < 30:
            return None
        try:
            close_prices = np.asarray(df["close"].astype(float).values, dtype=float)
            high_prices = np.asarray(df["high"].astype(float).values, dtype=float)
            low_prices = np.asarray(df["low"].astype(float).values, dtype=float)
            ma5 = talib.SMA(close_prices, 5)
            ma20 = talib.SMA(close_prices, 20)
            adx = talib.ADX(high_prices, low_prices, close_prices, 14)
            atr = talib.ATR(high_prices, low_prices, close_prices, 14)
        except Exception:
            return None
        if any(np.isnan(val) for val in (ma5[-1], ma20[-1], adx[-1], atr[-1])):
            return None
        if ma5[-1] > ma20[-1] and np.nanmean(np.diff(ma5[-3:])) > 0:
            return TrendType.STRONG_UPTREND if adx[-1] > 25 else TrendType.WEAK_UPTREND
        if ma5[-1] < ma20[-1] and np.nanmean(np.diff(ma20[-5:])) < 0:
            return TrendType.STRONG_DOWNTREND if adx[-1] > 25 else TrendType.WEAK_DOWNTREND
        if (
            abs(ma5[-1] - ma20[-1]) < 0.02 * close_prices[-1]
            and atr[-1] < 0.03 * close_prices[-1]
            and adx[-1] < 20
        ):
            return TrendType.SIDEWAYS
        return TrendType.UNCLEAR

    def determine_trend(
        self,
        stock_code: str,
        date: str,
        period: int = 60,
        adjust: str = "hfq",
    ) -> Optional[TrendType]:
        end_date = to_yyyymmdd(date)
        end_dt = datetime.strptime(end_date, "%Y%m%d")
        start_dt = end_dt - timedelta(days=int(period * 1.5))
        start_date = start_dt.strftime("%Y%m%d")
        df = self.get_daily(stock_code, start_date, end_date, adjust)
        return self.calculate_trend(df)

    def get_trend_analysis(
        self,
        stock_code: str,
        date: str,
        period: int = 60,
        adjust: str = "qfq",
    ) -> Optional[Dict[str, object]]:
        if talib is None:
            return None
        end_date = to_yyyymmdd(date)
        end_dt = datetime.strptime(end_date, "%Y%m%d")
        start_dt = end_dt - timedelta(days=int((period + 20) * 1.5))
        start_date = start_dt.strftime("%Y%m%d")
        df = self.get_daily(stock_code, start_date, end_date, adjust)
        if df.empty or len(df) < 30:
            return None
        try:
            close_prices = np.asarray(df["close"].astype(float).values, dtype=float)
            high_prices = np.asarray(df["high"].astype(float).values, dtype=float)
            low_prices = np.asarray(df["low"].astype(float).values, dtype=float)
            ma5 = talib.SMA(close_prices, 5)
            ma20 = talib.SMA(close_prices, 20)
            adx = talib.ADX(high_prices, low_prices, close_prices, 14)
            atr = talib.ATR(high_prices, low_prices, close_prices, 14)
        except Exception:
            return None
        if any(np.isnan(val) for val in (ma5[-1], ma20[-1], adx[-1], atr[-1])):
            return None
        ma5_slope = float(np.nanmean(np.diff(ma5[-3:]))) if len(ma5) >= 3 else 0.0
        ma20_slope = float(np.nanmean(np.diff(ma20[-5:]))) if len(ma20) >= 5 else 0.0
        trend = self.determine_trend(stock_code, end_date, period, adjust)
        return {
            "trend": trend.value if trend else None,
            "ma5": float(ma5[-1]),
            "ma20": float(ma20[-1]),
            "adx": float(adx[-1]),
            "atr": float(atr[-1]),
            "current_price": float(close_prices[-1]),
            "ma5_slope": ma5_slope,
            "ma20_slope": ma20_slope,
            "analysis_date": end_date,
            "stock_code": stock_code,
            "data_points": len(df),
        }

    def get_latest_date(self) -> Optional[str]:
        if not self._table_exists("daily_k_data"):
            return None
        df = pd.read_sql("SELECT MAX(date) AS latest_date FROM daily_k_data", self.conn)
        if df.empty or df.iloc[0]["latest_date"] is None:
            return None
        return str(df.iloc[0]["latest_date"])

    def close(self) -> None:
        self.conn.close()

cjdata/utils.py
ADDED
@@ -0,0 +1,30 @@
"""Utility helpers for cjdata."""
from __future__ import annotations

from datetime import datetime
from typing import Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")


def to_yyyymmdd(date_str: str) -> str:
    if len(date_str) == 8 and date_str.isdigit():
        return date_str
    return datetime.strptime(date_str, "%Y-%m-%d").strftime("%Y%m%d")


def to_iso_date(date_str: str) -> str:
    if len(date_str) == 10 and date_str[4] == "-":
        return date_str
    return datetime.strptime(date_str, "%Y%m%d").strftime("%Y-%m-%d")


def chunked(seq: Sequence[T], size: int) -> Iterator[Sequence[T]]:
    if size <= 0:
        raise ValueError("size must be > 0")
    for idx in range(0, len(seq), size):
        yield seq[idx : idx + size]


def ensure_list(iterable: Iterable[T]) -> list[T]:
    return list(iterable)
