financechatbotkit 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- financechatbotkit-2.0.0.dist-info/METADATA +11 -0
- financechatbotkit-2.0.0.dist-info/RECORD +39 -0
- financechatbotkit-2.0.0.dist-info/WHEEL +5 -0
- financechatbotkit-2.0.0.dist-info/entry_points.txt +2 -0
- financechatbotkit-2.0.0.dist-info/top_level.txt +2 -0
- orchestrator/__init__.py +29 -0
- orchestrator/bond/__init__.py +8 -0
- orchestrator/bond/base_reader.py +139 -0
- orchestrator/bond/getBondBasiInfo.py +84 -0
- orchestrator/bond/getBondWithOptiCallRede.py +83 -0
- orchestrator/bond/getEarlExerOpti.py +90 -0
- orchestrator/bond/getIssuIssuItemStat.py +85 -0
- orchestrator/bond/getOptiExer.py +83 -0
- orchestrator/bond/getOptiExerPricAdju.py +84 -0
- orchestrator/bond/workflow.py +252 -0
- orchestrator/exceptions.py +17 -0
- orchestrator/fnguide/__init__.py +21 -0
- orchestrator/fnguide/workflow.py +391 -0
- orchestrator/mapping/__init__.py +22 -0
- orchestrator/mapping/data/__init__.py +1 -0
- orchestrator/mapping/data/corp_codes_raw.json +693170 -0
- orchestrator/mapping/update_raw_data.py +96 -0
- orchestrator/mapping/workflow.py +303 -0
- orchestrator/price/__init__.py +15 -0
- orchestrator/price/workflow.py +250 -0
- telebotkit/__init__.py +51 -0
- telebotkit/bot/__init__.py +38 -0
- telebotkit/bot/client.py +217 -0
- telebotkit/bot/reply.py +36 -0
- telebotkit/bot/router.py +125 -0
- telebotkit/bot/safety.py +28 -0
- telebotkit/bot/telegram.py +41 -0
- telebotkit/firestore/__init__.py +45 -0
- telebotkit/firestore/client.py +141 -0
- telebotkit/firestore/documents.py +164 -0
- telebotkit/firestore/fetch.py +228 -0
- telebotkit/firestore/locks.py +74 -0
- telebotkit/firestore/upload.py +75 -0
- telebotkit/sheets.py +219 -0
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""CLI utility for refreshing the bundled OpenDart corp-code JSON."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
import tempfile
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Sequence
|
|
12
|
+
|
|
13
|
+
from ..exceptions import MappingWorkflowError
|
|
14
|
+
from .workflow import download_latest_raw_entries
|
|
15
|
+
|
|
16
|
+
# Location of the corp-code JSON that ships next to this module (data/corp_codes_raw.json).
_DEFAULT_OUTPUT_PATH = Path(__file__).resolve().parent.joinpath("data", "corp_codes_raw.json")


def get_default_output_path() -> Path:
    """Return the path of the bundled ``corp_codes_raw.json`` file."""
    return _DEFAULT_OUTPUT_PATH
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def update_raw_corp_codes_file(
    *,
    api_key: str,
    output_path: str | Path | None = None,
) -> tuple[Path, int]:
    """Download the latest OpenDart corp codes and write them to a JSON file.

    Args:
        api_key: OpenDart API key, passed to ``download_latest_raw_entries``.
        output_path: Destination file. When ``None`` the bundled default from
            ``get_default_output_path()`` is used.

    Returns:
        A ``(target_path, entry_count)`` tuple: the file that was written and
        the number of entries it contains.

    Raises:
        OSError: If the directory or file cannot be created or replaced.
        Any exception raised by ``download_latest_raw_entries`` propagates
        unchanged.
    """
    target_path = Path(output_path) if output_path is not None else get_default_output_path()
    raw_entries = download_latest_raw_entries(api_key)
    # Downloaded rows carry the name under "stock_name"; the file on disk
    # stores it under "corp_name".
    payload = [
        {
            "corp_code": row["corp_code"],
            "corp_name": row["stock_name"],
            "stock_code": row["stock_code"],
            "modify_date": row["modify_date"],
        }
        for row in raw_entries
    ]
    serialized = json.dumps(payload, ensure_ascii=False, indent=2) + "\n"

    target_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path: Path | None = None
    try:
        # Write to a sibling temp file first and move it into place afterwards,
        # so a failure midway never leaves a truncated target file.
        with tempfile.NamedTemporaryFile(
            "w",
            encoding="utf-8",
            dir=target_path.parent,
            delete=False,
        ) as handle:
            # Record the path BEFORE writing: if the write fails, the finally
            # block below must still be able to remove the leftover temp file.
            temp_path = Path(handle.name)
            handle.write(serialized)
        temp_path.replace(target_path)
    finally:
        # After a successful replace the temp path no longer exists, so this
        # only fires on failure paths.
        if temp_path is not None and temp_path.exists():
            temp_path.unlink()
    return target_path, len(payload)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the corp-code refresh utility.

    ``--api-key`` defaults to the ``DART_API_KEY`` environment variable (read
    when the parser is built); ``--output`` defaults to ``None``.
    """
    cli = argparse.ArgumentParser(
        description="Refresh the bundled OpenDart corp_codes_raw.json file.",
    )
    env_api_key = os.environ.get("DART_API_KEY")
    cli.add_argument(
        "--api-key",
        default=env_api_key,
        help="OpenDart API key. Defaults to the DART_API_KEY environment variable.",
    )
    default_output = get_default_output_path()
    cli.add_argument(
        "--output",
        default=None,
        help=f"Output path. Defaults to {default_output}",
    )
    return cli
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def main(argv: Sequence[str] | None = None) -> int:
    """Run the refresh command and return a process exit code.

    Returns 0 after a successful refresh and 1 when the refresh fails with a
    ``MappingWorkflowError`` or ``OSError`` (the error is printed to stderr).
    A missing API key is reported through ``parser.error``.
    """
    cli = build_parser()
    options = cli.parse_args(argv)
    if not options.api_key:
        cli.error("--api-key is required unless DART_API_KEY is set.")

    try:
        written_path, total = update_raw_corp_codes_file(
            api_key=options.api_key,
            output_path=options.output,
        )
    except (MappingWorkflowError, OSError) as exc:
        print(f"Failed to refresh corp codes: {exc}", file=sys.stderr)
        return 1
    else:
        print(f"Updated {written_path} with {total} entries.")
        return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
"""Mapping workflow for OpenDart corp-code data."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import io
|
|
6
|
+
import json
|
|
7
|
+
import ssl
|
|
8
|
+
import urllib.error
|
|
9
|
+
import urllib.parse
|
|
10
|
+
import urllib.request
|
|
11
|
+
import xml.etree.ElementTree as ET
|
|
12
|
+
import zipfile
|
|
13
|
+
from collections import OrderedDict
|
|
14
|
+
from importlib.resources import files
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from ..exceptions import DownloadError, InvalidInputError, NotFoundError
|
|
18
|
+
|
|
19
|
+
_CORP_CODE_URL = "https://opendart.fss.or.kr/api/corpCode.xml"
|
|
20
|
+
_BUNDLED_DATA_KEY = "__bundled__"
|
|
21
|
+
_MISSING = object()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _normalize_entry(item: Any) -> dict[str, Any]:
    """Validate one corp-code record and return it with cleaned string fields.

    The company name is read from ``corp_name`` and falls back to
    ``stock_name``; it is returned under the ``stock_name`` key.

    Raises:
        InvalidInputError: If ``item`` is not a dict, or if the cleaned
            ``corp_code`` or name is missing.
    """
    if not isinstance(item, dict):
        raise InvalidInputError("Each corp-code entry must be a JSON object.")

    normalized = {
        "corp_code": _clean(item.get("corp_code")),
        "stock_name": _clean(item.get("corp_name") or item.get("stock_name")),
        "stock_code": _clean(item.get("stock_code")),
        "modify_date": _clean(item.get("modify_date")),
    }
    required_missing = normalized["corp_code"] is None or normalized["stock_name"] is None
    if required_missing:
        raise InvalidInputError("Each corp-code entry must include corp_code and corp_name.")
    return normalized
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _clean(value: Any) -> str | None:
    """Return ``str(value)`` stripped of surrounding whitespace, or ``None``.

    ``None`` and values whose stripped text is empty both yield ``None``.
    """
    if value is None:
        return None
    stripped = str(value).strip()
    if stripped:
        return stripped
    return None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _normalize_name(value: str) -> str:
    """Collapse whitespace runs to single spaces and case-fold the result."""
    words = value.split()
    collapsed = " ".join(words)
    return collapsed.casefold()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _build_input_payload(
    *,
    stock_code: str | None,
    stock_name: str | None,
    corp_code: str | None,
    get_latest: bool,
    get_all: bool,
    include_nonlisted: bool,
) -> dict[str, Any]:
    """Build the ``input`` echo section of a mapping response.

    The three identifier arguments are passed through ``_clean``; the boolean
    flags are copied unchanged.
    """
    identifiers = (
        ("stock_code", stock_code),
        ("stock_name", stock_name),
        ("corp_code", corp_code),
    )
    echoed: dict[str, Any] = {key: _clean(raw) for key, raw in identifiers}
    echoed["get_latest"] = get_latest
    echoed["get_all"] = get_all
    echoed["include_nonlisted"] = include_nonlisted
    return echoed
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _detect_lookup_key(*, stock_code: str | None, stock_name: str | None, corp_code: str | None) -> tuple[str, str]:
    """Return the ``(key, cleaned_value)`` pair of the single identifier given.

    Raises:
        InvalidInputError: If zero or more than one of the three identifiers
            has a non-empty cleaned value.
    """
    candidates = {
        "stock_code": stock_code,
        "stock_name": stock_name,
        "corp_code": corp_code,
    }
    supplied: list[tuple[str, str]] = []
    for key, raw in candidates.items():
        cleaned = _clean(raw)
        if cleaned is not None:
            supplied.append((key, cleaned))

    if len(supplied) != 1:
        raise InvalidInputError("Exactly one of stock_code, stock_name, or corp_code must be provided.")
    return supplied[0]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _filter_entries(entries: tuple[dict[str, Any], ...], lookup_key: str, lookup_value: str) -> list[dict[str, Any]]:
    """Return the entries whose ``lookup_key`` field matches ``lookup_value``.

    ``stock_name`` lookups compare names after ``_normalize_name``; every other
    key is compared with plain equality.
    """
    if lookup_key != "stock_name":
        return [candidate for candidate in entries if candidate[lookup_key] == lookup_value]

    wanted = _normalize_name(lookup_value)
    hits: list[dict[str, Any]] = []
    for candidate in entries:
        if _normalize_name(candidate["stock_name"]) == wanted:
            hits.append(candidate)
    return hits
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def download_latest_raw_entries(api_key: str) -> tuple[dict[str, Any], ...]:
    """Download the OpenDart corp-code archive and return its normalized rows.

    The response is expected to be a ZIP archive containing one ``.xml``
    member; each ``<list>`` element of that document becomes one entry.

    Args:
        api_key: OpenDart API key, sent as the ``crtfc_key`` query parameter.

    Returns:
        A tuple of dicts as produced by ``_normalize_entry`` (keys
        ``corp_code``, ``stock_name``, ``stock_code``, ``modify_date``).

    Raises:
        InvalidInputError: If ``api_key`` is empty, or a row lacks the fields
            ``_normalize_entry`` requires.
        DownloadError: If the request fails, the payload is not a usable ZIP
            archive, or the XML cannot be parsed.
    """
    import http.client  # noqa: PLC0415 - only needed for the exception type below

    api_key = _clean(api_key)
    if api_key is None:
        raise InvalidInputError("api_key is required when get_latest=True.")

    query = urllib.parse.urlencode({"crtfc_key": api_key})
    request_url = f"{_CORP_CODE_URL}?{query}"
    try:
        with urllib.request.urlopen(request_url, timeout=60, context=ssl.create_default_context()) as response:
            content = response.read()
    except (OSError, http.client.HTTPException) as exc:
        # OSError covers URLError, TimeoutError and connection resets during
        # read(); HTTPException covers protocol failures such as a truncated
        # body (IncompleteRead). All are reported uniformly as DownloadError.
        raise DownloadError("Failed to download OpenDart corpCode.xml.") from exc

    try:
        with zipfile.ZipFile(io.BytesIO(content)) as zipped:
            # StopIteration here means the archive has no .xml member.
            xml_name = next(name for name in zipped.namelist() if name.upper().endswith(".XML"))
            xml_bytes = zipped.read(xml_name)
    except (StopIteration, zipfile.BadZipFile, KeyError) as exc:
        raise DownloadError("Downloaded OpenDart corpCode.xml payload is invalid.") from exc

    try:
        root = ET.fromstring(xml_bytes)
    except ET.ParseError as exc:
        raise DownloadError("OpenDart corpCode.xml could not be parsed.") from exc

    return tuple(
        _normalize_entry(
            {
                "corp_code": _clean(item.findtext("corp_code")),
                "corp_name": _clean(item.findtext("corp_name")),
                "stock_code": _clean(item.findtext("stock_code")),
                "modify_date": _clean(item.findtext("modify_date")),
            }
        )
        for item in root.iter("list")
    )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class MappingWorkflow:
    """Resolve stock_name, stock_code, and corp_code from OpenDart corp-code data.

    Entries come either from a fresh OpenDart download (``get_latest=True``) or
    from a JSON file: the packaged ``corp_codes_raw.json`` resource by default,
    or the file named by ``data_path``. Parsed file contents may be kept in two
    per-instance LRU caches (raw and processed), each holding at most
    ``cache_size`` items; ``cache_size=0`` disables caching.
    """

    def __init__(self, *, cache_size: int = 0, data_path: str | None = None) -> None:
        """Create a workflow.

        Args:
            cache_size: Maximum number of items kept in each cache; 0 disables
                caching.
            data_path: Default JSON file to read instead of the packaged data.

        Raises:
            InvalidInputError: If ``cache_size`` is negative.
        """
        if cache_size < 0:
            raise InvalidInputError("cache_size must be >= 0.")
        self._cache_size = cache_size
        self._data_path = data_path
        self._raw_cache: OrderedDict[str, tuple[dict[str, Any], ...]] = OrderedDict()
        self._processed_cache: OrderedDict[tuple[str, bool], tuple[dict[str, Any], ...]] = OrderedDict()

    @property
    def cache_size(self) -> int:
        """Maximum number of items kept in each cache."""
        return self._cache_size

    def clear_cache(self) -> None:
        """Drop everything held in the raw and processed caches."""
        self._raw_cache.clear()
        self._processed_cache.clear()

    def run(
        self,
        *,
        stock_code: str | None = None,
        stock_name: str | None = None,
        corp_code: str | None = None,
        get_latest: bool = False,
        get_all: bool = False,
        include_nonlisted: bool = False,
        api_key: str | None = None,
        data_path: str | None = None,
    ) -> dict[str, Any]:
        """Look up mappings by exactly one identifier, or return all of them.

        Args:
            stock_code: Stock code to look up.
            stock_name: Company name to look up (whitespace- and
                case-insensitive).
            corp_code: OpenDart corp code to look up.
            get_latest: Download fresh data from OpenDart instead of reading
                the JSON file; requires ``api_key``.
            get_all: Return every mapping; the identifiers are then ignored
                for filtering.
            include_nonlisted: Keep entries that have no stock code.
            api_key: OpenDart API key, used only when ``get_latest`` is true.
            data_path: JSON file overriding the instance default for this call.

        Returns:
            ``{"input": <echo of the arguments>, "data": {"mappings": [...]}}``
            where each mapping has ``stock_name``, ``stock_code`` and
            ``corp_code``.

        Raises:
            InvalidInputError: If the lookup arguments or the data are invalid.
            NotFoundError: If no entry matches the lookup.
            DownloadError: If ``get_latest`` is true and the download fails.
        """
        input_payload = _build_input_payload(
            stock_code=stock_code,
            stock_name=stock_name,
            corp_code=corp_code,
            get_latest=get_latest,
            get_all=get_all,
            include_nonlisted=include_nonlisted,
        )

        # Validate the lookup arguments up front so that bad input fails
        # before any network download or file parsing is attempted.
        lookup: tuple[str, str] | None = None
        if not get_all:
            lookup = _detect_lookup_key(
                stock_code=stock_code,
                stock_name=stock_name,
                corp_code=corp_code,
            )

        if get_latest:
            raw_entries = download_latest_raw_entries(api_key or "")
            entries = tuple(
                {
                    "stock_name": row["stock_name"],
                    "stock_code": row["stock_code"],
                    "corp_code": row["corp_code"],
                }
                for row in raw_entries
                if include_nonlisted or row["stock_code"] is not None
            )
        else:
            entries = self._load_processed_entries(data_path, include_nonlisted)

        if lookup is None:
            return {
                "input": input_payload,
                "data": {"mappings": list(entries)},
            }

        lookup_key, lookup_value = lookup
        matches = _filter_entries(entries, lookup_key, lookup_value)
        if not matches:
            raise NotFoundError(f"No mapping found for {lookup_key}={lookup_value!r}.")
        return {
            "input": input_payload,
            "data": {"mappings": matches},
        }

    def _load_raw_entries(self, data_path: str | None) -> tuple[dict[str, Any], ...]:
        """Read and normalize every entry of the resolved JSON source.

        Raises:
            InvalidInputError: If the JSON document is not a list, or an entry
                fails ``_normalize_entry``.
        """
        resolved_data_path = self._resolve_data_path(data_path)
        cache_key = resolved_data_path or _BUNDLED_DATA_KEY
        cached = self._cache_get(self._raw_cache, cache_key)
        if cached is not _MISSING:
            return cached

        if resolved_data_path is None:
            # No explicit file: fall back to the JSON packaged with the library.
            resource = files("orchestrator.mapping.data").joinpath("corp_codes_raw.json")
            text = resource.read_text(encoding="utf-8")
        else:
            with open(resolved_data_path, "r", encoding="utf-8") as handle:
                text = handle.read()

        data = json.loads(text)
        if not isinstance(data, list):
            raise InvalidInputError("Bundled corp-code JSON must contain a list of entries.")
        rows = tuple(_normalize_entry(item) for item in data)
        self._cache_set(self._raw_cache, cache_key, rows)
        return rows

    def _load_processed_entries(
        self,
        data_path: str | None,
        include_nonlisted: bool,
    ) -> tuple[dict[str, Any], ...]:
        """Return file entries reduced to the three mapping fields.

        Entries without a stock code are dropped unless ``include_nonlisted``
        is true.
        """
        resolved_data_path = self._resolve_data_path(data_path)
        cache_key = (resolved_data_path or _BUNDLED_DATA_KEY, include_nonlisted)
        cached = self._cache_get(self._processed_cache, cache_key)
        if cached is not _MISSING:
            return cached

        rows = self._load_raw_entries(resolved_data_path)
        result = tuple(
            {
                "stock_name": row["stock_name"],
                "stock_code": row["stock_code"],
                "corp_code": row["corp_code"],
            }
            for row in rows
            if include_nonlisted or row["stock_code"] is not None
        )
        self._cache_set(self._processed_cache, cache_key, result)
        return result

    def _resolve_data_path(self, data_path: str | None) -> str | None:
        """Prefer the per-call ``data_path`` over the instance default."""
        return data_path if data_path is not None else self._data_path

    def _cache_get(self, cache: OrderedDict[Any, Any], key: Any) -> Any:
        """Return the cached value for ``key`` or the ``_MISSING`` sentinel.

        A hit moves the key to the most-recently-used end of ``cache``.
        """
        if self._cache_size == 0:
            return _MISSING
        try:
            cache.move_to_end(key)
        except KeyError:
            return _MISSING
        return cache[key]

    def _cache_set(self, cache: OrderedDict[Any, Any], key: Any, value: Any) -> None:
        """Store ``value`` as most recently used and evict the oldest overflow."""
        if self._cache_size == 0:
            return
        cache[key] = value
        # Assigning to an existing key keeps its old position, so move it
        # explicitly to mark it as most recently used.
        cache.move_to_end(key)
        while len(cache) > self._cache_size:
            cache.popitem(last=False)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
# Shared module-level workflow used by run_mapping_workflow; caching is disabled.
_DEFAULT_WORKFLOW = MappingWorkflow(cache_size=0)


def run_mapping_workflow(
    *,
    stock_code: str | None = None,
    stock_name: str | None = None,
    corp_code: str | None = None,
    get_latest: bool = False,
    get_all: bool = False,
    include_nonlisted: bool = False,
    api_key: str | None = None,
    data_path: str | None = None,
) -> dict[str, Any]:
    """Run a mapping lookup on the shared default ``MappingWorkflow``.

    All arguments are forwarded unchanged to ``MappingWorkflow.run``, whose
    result is returned as-is.
    """
    forwarded: dict[str, Any] = {
        "stock_code": stock_code,
        "stock_name": stock_name,
        "corp_code": corp_code,
        "get_latest": get_latest,
        "get_all": get_all,
        "include_nonlisted": include_nonlisted,
        "api_key": api_key,
        "data_path": data_path,
    }
    return _DEFAULT_WORKFLOW.run(**forwarded)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Price workflows for FinanceChatbot."""
|
|
2
|
+
|
|
3
|
+
from .workflow import (
|
|
4
|
+
PricePeriodWorkflow,
|
|
5
|
+
PriceSnapshotWorkflow,
|
|
6
|
+
run_price_period_workflow,
|
|
7
|
+
run_price_snapshot_workflow,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"PricePeriodWorkflow",
|
|
12
|
+
"PriceSnapshotWorkflow",
|
|
13
|
+
"run_price_period_workflow",
|
|
14
|
+
"run_price_snapshot_workflow",
|
|
15
|
+
]
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
"""Price workflows built on FinanceDataReader."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from datetime import date, timedelta
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from ..exceptions import InvalidInputError, NotFoundError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
class PriceBar:
    """Normalized daily price bar.

    ``trade_date`` is always set; each price field is ``None`` when the source
    row had no value for it.
    """

    trade_date: str
    open: int | None = None
    close: int | None = None
    low: int | None = None
    high: int | None = None
    volume: int | None = None

    def has_fields(self, field_names: tuple[str, ...]) -> bool:
        """Return True when every named field holds a value."""
        for name in field_names:
            if getattr(self, name) is None:
                return False
        return True

    def to_dict(self, *, field_names: tuple[str, ...]) -> dict[str, int | str]:
        """Return ``{"date": trade_date, <field>: <value>, ...}`` for the named fields.

        Raises:
            InvalidInputError: If any requested field is ``None``.
        """
        result: dict[str, int | str] = {"date": self.trade_date}
        for name in field_names:
            current = getattr(self, name)
            if current is None:
                raise InvalidInputError(f"Missing {name} value for trade_date={self.trade_date}.")
            result[name] = current
        return result
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
_PRICE_FIELD_PRESETS: dict[str, tuple[str, ...]] = {
|
|
37
|
+
"close": ("close",),
|
|
38
|
+
"oclh": ("open", "close", "low", "high"),
|
|
39
|
+
"oclhv": ("open", "close", "low", "high", "volume"),
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
_PRICE_SOURCE_COLUMNS = {
|
|
43
|
+
"open": "Open",
|
|
44
|
+
"close": "Close",
|
|
45
|
+
"low": "Low",
|
|
46
|
+
"high": "High",
|
|
47
|
+
"volume": "Volume",
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class PriceSnapshotWorkflow:
    """Return the latest price and the closest previous price."""

    def __init__(self, *, lookback_days: int = 30) -> None:
        """Create a workflow.

        Args:
            lookback_days: Default number of calendar days to look back.

        Raises:
            InvalidInputError: If ``lookback_days`` is less than 2.
        """
        if lookback_days < 2:
            raise InvalidInputError("lookback_days must be >= 2.")
        self._lookback_days = lookback_days

    def run(
        self,
        *,
        stock_code: str,
        lookback_days: int | None = None,
        price_fields: str = "close",
    ) -> dict[str, Any]:
        """Fetch recent bars and report the last two usable ones.

        Args:
            stock_code: Stock code to query.
            lookback_days: Window size in calendar days; ``None`` uses the
                instance default.
            price_fields: Field preset name (``close``, ``oclh`` or ``oclhv``).

        Returns:
            ``{"input": {...}, "data": {...}}``. For the ``close`` preset the
            data holds ``lastClose`` and ``latestClose``; otherwise it holds
            ``lastPrice`` and ``latestPrice`` dicts.

        Raises:
            InvalidInputError: If an argument is invalid.
            NotFoundError: If fewer than two usable bars are found.
        """
        stock_code = _normalize_stock_code(stock_code)
        window = lookback_days if lookback_days is not None else self._lookback_days
        if window < 2:
            raise InvalidInputError("lookback_days must be >= 2.")
        normalized_price_fields, field_names = _normalize_price_fields(price_fields)

        # Read the clock once so both ends of the window derive from the same
        # day even if the call straddles midnight.
        today = date.today()
        bars = _fetch_price_bars(
            stock_code=stock_code,
            start=(today - timedelta(days=window)).isoformat(),
            end=today.isoformat(),
            field_names=field_names,
        )
        if len(bars) < 2:
            raise NotFoundError(f"Not enough price history found for stock_code={stock_code!r}.")

        # NOTE(review): assumes the bars are in ascending date order, so the
        # final element is the most recent one; confirm against the source.
        latest = bars[-1]
        previous = bars[-2]
        response = {
            "input": {
                "stock_code": stock_code,
                "lookback_days": window,
                "price_fields": normalized_price_fields,
            },
            "data": {},
        }
        if normalized_price_fields == "close":
            response["data"] = {
                "lastClose": previous.close,
                "latestClose": latest.close,
            }
            return response

        response["data"] = {
            "lastPrice": previous.to_dict(field_names=field_names),
            "latestPrice": latest.to_dict(field_names=field_names),
        }
        return response
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class PricePeriodWorkflow:
    """Return daily prices for a specific period."""

    def run(
        self,
        *,
        stock_code: str,
        start_date: str,
        end_date: str | None = None,
        price_fields: str = "close",
    ) -> dict[str, Any]:
        """Fetch the bars between ``start_date`` and ``end_date``.

        A falsy ``end_date`` means today. Dates are normalized to ISO
        ``YYYY-MM-DD`` strings before they are compared.

        Returns:
            ``{"input": {...}, "data": {"prices": [...]}}`` with one dict per
            usable bar.

        Raises:
            InvalidInputError: If an argument is invalid or the start date is
                after the end date.
            NotFoundError: If no usable bar exists in the period.
        """
        code = _normalize_stock_code(stock_code)
        period_start = _normalize_date_string(start_date, field_name="start_date")
        if end_date:
            period_end = _normalize_date_string(end_date, field_name="end_date")
        else:
            period_end = date.today().isoformat()
        # ISO date strings order the same way as the dates they represent.
        if period_start > period_end:
            raise InvalidInputError("start_date must be <= end_date.")
        preset, field_names = _normalize_price_fields(price_fields)

        bars = _fetch_price_bars(
            stock_code=code,
            start=period_start,
            end=period_end,
            field_names=field_names,
        )
        if not bars:
            raise NotFoundError(
                f"No price data found for stock_code={code!r} between {period_start} and {period_end}."
            )

        echoed_input = {
            "stock_code": code,
            "start_date": period_start,
            "end_date": period_end,
            "price_fields": preset,
        }
        prices = [bar.to_dict(field_names=field_names) for bar in bars]
        return {"input": echoed_input, "data": {"prices": prices}}
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# Shared module-level workflow instances used by the function-style entry points below.
_DEFAULT_SNAPSHOT_WORKFLOW = PriceSnapshotWorkflow()
_DEFAULT_PERIOD_WORKFLOW = PricePeriodWorkflow()


def run_price_snapshot_workflow(
    *,
    stock_code: str,
    lookback_days: int = 30,
    price_fields: str = "close",
) -> dict[str, Any]:
    """Run a snapshot lookup on the shared ``PriceSnapshotWorkflow``.

    All arguments are forwarded unchanged to ``PriceSnapshotWorkflow.run``.
    """
    workflow = _DEFAULT_SNAPSHOT_WORKFLOW
    result = workflow.run(
        stock_code=stock_code,
        lookback_days=lookback_days,
        price_fields=price_fields,
    )
    return result


def run_price_period_workflow(
    *,
    stock_code: str,
    start_date: str,
    end_date: str | None = None,
    price_fields: str = "close",
) -> dict[str, Any]:
    """Run a period lookup on the shared ``PricePeriodWorkflow``.

    All arguments are forwarded unchanged to ``PricePeriodWorkflow.run``.
    """
    workflow = _DEFAULT_PERIOD_WORKFLOW
    result = workflow.run(
        stock_code=stock_code,
        start_date=start_date,
        end_date=end_date,
        price_fields=price_fields,
    )
    return result
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _fetch_price_bars(
    *,
    stock_code: str,
    start: str,
    end: str,
    field_names: tuple[str, ...],
) -> list[PriceBar]:
    """Load the price frame and return the bars that have every requested field.

    Returns an empty list when the loader yields ``None`` or an empty frame.
    Rows missing any of ``field_names`` are skipped.
    """
    frame = _load_price_frame(stock_code=stock_code, start=start, end=end)
    if frame is None or frame.empty:
        return []

    candidates = (_build_price_bar(idx=idx, row=row) for idx, row in frame.iterrows())
    return [candidate for candidate in candidates if candidate.has_fields(field_names)]
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _load_price_frame(*, stock_code: str, start: str, end: str):
    """Return ``FinanceDataReader.DataReader(stock_code, start, end)``.

    The package is imported lazily, so it is needed only when a price
    workflow actually runs.

    Raises:
        InvalidInputError: If FinanceDataReader cannot be imported.
    """
    try:
        import FinanceDataReader as fdr  # noqa: PLC0415
    except ImportError as exc:
        raise InvalidInputError("FinanceDataReader is required for price workflows.") from exc
    else:
        frame = fdr.DataReader(stock_code, start, end)
    return frame
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _normalize_stock_code(stock_code: str) -> str:
    """Return the stock code as a stripped string.

    Raises:
        InvalidInputError: If the value is falsy or blank after stripping.
    """
    candidate = str(stock_code or "").strip()
    if candidate:
        return candidate
    raise InvalidInputError("stock_code is required.")
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _normalize_date_string(value: str | None, *, field_name: str) -> str:
    """Validate a date string and return it in ISO ``YYYY-MM-DD`` form.

    Args:
        value: Date text to validate; surrounding whitespace is ignored.
        field_name: Name used in the error messages.

    Raises:
        InvalidInputError: If the value is empty or ``date.fromisoformat``
            rejects it.
    """
    text = str(value or "").strip()
    if not text:
        raise InvalidInputError(f"{field_name} is required.")
    try:
        parsed = date.fromisoformat(text)
    except ValueError as exc:
        raise InvalidInputError(f"{field_name} must be YYYY-MM-DD.") from exc
    return parsed.isoformat()
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _normalize_price_fields(price_fields: str | None) -> tuple[str, tuple[str, ...]]:
    """Resolve a preset name to ``(normalized_name, field_names)``.

    The name is stripped and lower-cased; a falsy value means ``"close"``.

    Raises:
        InvalidInputError: If the name is not a key of ``_PRICE_FIELD_PRESETS``.
    """
    preset = str(price_fields or "close").strip().lower()
    if preset not in _PRICE_FIELD_PRESETS:
        raise InvalidInputError("price_fields must be one of: close, oclh, oclhv.")
    return preset, _PRICE_FIELD_PRESETS[preset]
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def _build_price_bar(*, idx: Any, row: Any) -> PriceBar:
    """Convert one frame row into a ``PriceBar``.

    Each column listed in ``_PRICE_SOURCE_COLUMNS`` is read with ``row.get``
    and converted with ``int``; missing values leave the field unset. The
    trade date comes from ``idx.date()`` when available, otherwise from the
    first ten characters of ``str(idx)``.
    """
    present: dict[str, int] = {}
    for field_name, column_name in _PRICE_SOURCE_COLUMNS.items():
        cell = row.get(column_name)
        if not _is_missing(cell):
            present[field_name] = int(cell)

    if hasattr(idx, "date"):
        trade_date = idx.date().isoformat()
    else:
        trade_date = str(idx)[:10]
    return PriceBar(trade_date=trade_date, **present)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def _is_missing(value: Any) -> bool:
    """Return True for ``None`` and for values that are not equal to themselves."""
    if value is None:
        return True
    # A value that differs from itself is how NaN-like values are detected.
    return value != value
|