deeptrade-quant 0.3.0.tar.gz → 0.3.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/PKG-INFO +1 -1
  2. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/__init__.py +1 -1
  3. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/cli.py +5 -3
  4. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/cli_config.py +8 -10
  5. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/cli_plugin.py +6 -18
  6. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/config.py +12 -3
  7. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/github_fetch.py +9 -25
  8. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/llm_manager.py +2 -4
  9. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/plugin_manager.py +1 -3
  10. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/plugin_source.py +3 -8
  11. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/registry.py +8 -22
  12. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/tushare_client.py +173 -17
  13. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/pyproject.toml +1 -1
  14. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_config.py +2 -6
  15. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_db.py +1 -0
  16. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_github_fetch.py +3 -3
  17. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_llm_client.py +6 -18
  18. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_llm_manager.py +1 -3
  19. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_plugin_install.py +4 -12
  20. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_plugin_source.py +3 -1
  21. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_plugin_upgrade.py +3 -3
  22. deeptrade_quant-0.3.1/tests/core/test_tushare_classifier.py +274 -0
  23. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_tushare_client.py +14 -28
  24. deeptrade_quant-0.3.1/tests/core/test_tushare_retry_r1.py +234 -0
  25. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/.gitignore +0 -0
  26. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/LICENSE +0 -0
  27. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/README.md +0 -0
  28. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/cli_data.py +0 -0
  29. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/__init__.py +0 -0
  30. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/config_migrations.py +0 -0
  31. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/db.py +0 -0
  32. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/llm_client.py +0 -0
  33. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/logging_config.py +0 -0
  34. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/migrations/__init__.py +0 -0
  35. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/migrations/core/20260509_001_init.sql +0 -0
  36. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/migrations/core/__init__.py +0 -0
  37. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/paths.py +0 -0
  38. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/run_status.py +0 -0
  39. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/core/secrets.py +0 -0
  40. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/plugins_api/__init__.py +0 -0
  41. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/plugins_api/base.py +0 -0
  42. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/plugins_api/events.py +0 -0
  43. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/plugins_api/llm.py +0 -0
  44. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/plugins_api/metadata.py +0 -0
  45. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/deeptrade/theme.py +0 -0
  46. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/__init__.py +0 -0
  47. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/cli/__init__.py +0 -0
  48. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/cli/test_config_cmd.py +0 -0
  49. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/cli/test_plugin_cmd.py +0 -0
  50. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/cli/test_routing.py +0 -0
  51. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/conftest.py +0 -0
  52. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/__init__.py +0 -0
  53. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_config_migrations.py +0 -0
  54. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_paths.py +0 -0
  55. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_registry.py +0 -0
  56. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/core/test_secrets.py +0 -0
  57. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/plugins_api/__init__.py +0 -0
  58. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/plugins_api/test_protocol.py +0 -0
  59. {deeptrade_quant-0.3.0 → deeptrade_quant-0.3.1}/tests/test_smoke.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: deeptrade-quant
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: LLM-driven A-share (Shanghai/Shenzhen main board) stock screening CLI
5
5
  Project-URL: Homepage, https://github.com/ty19880929/deeptrade
6
6
  Project-URL: Repository, https://github.com/ty19880929/deeptrade
@@ -2,5 +2,5 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- __version__ = "0.3.0"
5
+ __version__ = "0.3.1"
6
6
  __all__ = ["__version__"]
@@ -103,8 +103,7 @@ def _build_plugin_command(plugin_id: str) -> click.Command | None:
103
103
  )
104
104
  def _disabled() -> None:
105
105
  typer.echo(
106
- f"✘ plugin {plugin_id!r} is disabled; "
107
- f"run `deeptrade plugin enable {plugin_id}`"
106
+ f"✘ plugin {plugin_id!r} is disabled; run `deeptrade plugin enable {plugin_id}`"
108
107
  )
109
108
  raise typer.Exit(2)
110
109
 
@@ -204,15 +203,16 @@ def init(
204
203
  cmd_set_llm()
205
204
 
206
205
 
207
-
208
206
  @app.command(name="db", context_settings={"ignore_unknown_options": True, "allow_extra_args": True})
209
207
  def db_cmd(ctx: click.Context) -> None:
210
208
  """Database migration and management commands (legacy stub; use `deeptrade db init` via group if added)."""
211
209
  pass
212
210
 
211
+
213
212
  db_app = typer.Typer(name="db", help="Database migration and management commands.")
214
213
  app.add_typer(db_app, name="db")
215
214
 
215
+
216
216
  @db_app.command("init")
217
217
  def db_init() -> None:
218
218
  """Initialize the core database tables and apply migrations."""
@@ -232,10 +232,12 @@ def db_init() -> None:
232
232
  finally:
233
233
  db.close()
234
234
 
235
+
235
236
  @db_app.command("upgrade")
236
237
  def db_upgrade() -> None:
237
238
  """Apply any pending core migrations."""
238
239
  db_init()
239
240
 
241
+
240
242
  if __name__ == "__main__":
241
243
  app()
@@ -156,10 +156,12 @@ def cmd_set_llm() -> None:
156
156
  existing = sorted(cfg.llm_providers.keys())
157
157
 
158
158
  if existing:
159
- choices = ["[+] Add new provider"] + [f"[~] {n}" for n in existing] + ["[x] Delete a provider"]
160
- picked = questionary.select(
161
- "Pick action:", choices=choices
162
- ).ask()
159
+ choices = (
160
+ ["[+] Add new provider"]
161
+ + [f"[~] {n}" for n in existing]
162
+ + ["[x] Delete a provider"]
163
+ )
164
+ picked = questionary.select("Pick action:", choices=choices).ask()
163
165
  if picked is None:
164
166
  raise typer.Exit(1)
165
167
  if picked.startswith("[+]"):
@@ -177,9 +179,7 @@ def cmd_set_llm() -> None:
177
179
 
178
180
 
179
181
  def _set_llm_new(svc: ConfigService) -> None:
180
- name = questionary.text(
181
- "Provider name (e.g. deepseek, qwen-plus, kimi):"
182
- ).ask()
182
+ name = questionary.text("Provider name (e.g. deepseek, qwen-plus, kimi):").ask()
183
183
  if not name:
184
184
  raise typer.Exit(1)
185
185
  name = name.strip()
@@ -259,9 +259,7 @@ def _prompt_and_save_provider(
259
259
  raise typer.Exit(2) from e
260
260
 
261
261
  api_key_prompt = (
262
- "API key (leave empty to keep existing):"
263
- if defaults is not None
264
- else "API key:"
262
+ "API key (leave empty to keep existing):" if defaults is not None else "API key:"
265
263
  )
266
264
  api_key = questionary.password(api_key_prompt).ask()
267
265
  if api_key is None:
@@ -58,9 +58,7 @@ def _format_origin(resolved: ResolvedSource) -> str:
58
58
  if resolved.origin == "local":
59
59
  return f"本地路径 ({d.get('local_path', resolved.path)})"
60
60
  if resolved.origin == "github_registry":
61
- return (
62
- f"GitHub 注册表 ({d['repo']}@{d['ref']}, subdir={d['subdir']})"
63
- )
61
+ return f"GitHub 注册表 ({d['repo']}@{d['ref']}, subdir={d['subdir']})"
64
62
  if resolved.origin == "github_url":
65
63
  return f"GitHub URL ({d['repo']}@{d['ref']})"
66
64
  return resolved.origin
@@ -68,9 +66,7 @@ def _format_origin(resolved: ResolvedSource) -> str:
68
66
 
69
67
  @app.command("install")
70
68
  def cmd_install(
71
- source: str = typer.Argument(
72
- ..., help="短名(注册表)/ 本地路径 / GitHub URL"
73
- ),
69
+ source: str = typer.Argument(..., help="短名(注册表)/ 本地路径 / GitHub URL"),
74
70
  ref: str | None = typer.Option(
75
71
  None, "--ref", help="Tag / branch / sha (默认 = 该插件最新 release)"
76
72
  ),
@@ -150,9 +146,7 @@ def cmd_info(plugin_id: str = typer.Argument(...)) -> None:
150
146
  try:
151
147
  try:
152
148
  rec = mgr.info(plugin_id)
153
- typer.echo(
154
- yaml.safe_dump(rec.metadata.model_dump(mode="json"), allow_unicode=True)
155
- )
149
+ typer.echo(yaml.safe_dump(rec.metadata.model_dump(mode="json"), allow_unicode=True))
156
150
  return
157
151
  except PluginNotFoundError:
158
152
  pass # fall through to registry lookup
@@ -237,9 +231,7 @@ def cmd_uninstall(
237
231
 
238
232
  @app.command("upgrade")
239
233
  def cmd_upgrade(
240
- source: str = typer.Argument(
241
- ..., help="短名(注册表)/ 本地路径 / GitHub URL"
242
- ),
234
+ source: str = typer.Argument(..., help="短名(注册表)/ 本地路径 / GitHub URL"),
243
235
  ref: str | None = typer.Option(
244
236
  None, "--ref", help="Tag / branch / sha (默认 = 该插件最新 release)"
245
237
  ),
@@ -269,9 +261,7 @@ def cmd_upgrade(
269
261
  pid = meta.plugin_id
270
262
  except PluginInstallError:
271
263
  pid = source
272
- typer.echo(
273
- f'✘ 插件 "{pid}" 未安装,请先执行 deeptrade plugin install'
274
- )
264
+ typer.echo(f'✘ 插件 "{pid}" 未安装,请先执行 deeptrade plugin install')
275
265
  raise typer.Exit(2) from e
276
266
  except PluginInstallError as e:
277
267
  typer.echo(f"✘ Upgrade failed: {e}")
@@ -293,9 +283,7 @@ def cmd_search(
293
283
  keyword: str | None = typer.Argument(
294
284
  None, help="可选过滤关键词(匹配 plugin_id / name / description)"
295
285
  ),
296
- no_cache: bool = typer.Option(
297
- False, "--no-cache", help="强制刷新注册表(旁路 ETag 缓存)"
298
- ),
286
+ no_cache: bool = typer.Option(False, "--no-cache", help="强制刷新注册表(旁路 ETag 缓存)"),
299
287
  ) -> None:
300
288
  """List plugins available in the official registry."""
301
289
  try:
@@ -87,6 +87,11 @@ class AppConfig(BaseModel):
87
87
  # tushare.* (token lives in secret_store)
88
88
  tushare_rps: float = Field(default=6.0, gt=0)
89
89
  tushare_timeout: int = Field(default=30, ge=1)
90
+ # Tenacity stop_after_attempt for transient errors (rate limit / server /
91
+ # transport). Default 7 keeps worst-case wait around one minute of
92
+ # jittered exponential backoff. Each attempt re-enters the token bucket,
93
+ # so retries never bypass rate limiting.
94
+ tushare_max_retries: int = Field(default=7, ge=1, le=20)
90
95
 
91
96
  # Global preset name. v0.7 — renamed from ``deepseek.profile``; semantics
92
97
  # are vendor-agnostic. Per-stage tuning is resolved by each plugin's
@@ -134,6 +139,7 @@ _DOT_TO_FIELD: dict[str, str] = {
134
139
  "app.close_after": "app_close_after",
135
140
  "tushare.rps": "tushare_rps",
136
141
  "tushare.timeout": "tushare_timeout",
142
+ "tushare.max_retries": "tushare_max_retries",
137
143
  "app.profile": "app_profile",
138
144
  "llm.providers": "llm_providers",
139
145
  "llm.audit_full_payload": "llm_audit_full_payload",
@@ -365,8 +371,7 @@ class ConfigService:
365
371
  # already default; otherwise we'd leave the dict with no
366
372
  # default at all. Preserve prior_default in that case.
367
373
  other_has_default = any(
368
- k != name and bool((v or {}).get("is_default"))
369
- for k, v in current.items()
374
+ k != name and bool((v or {}).get("is_default")) for k, v in current.items()
370
375
  )
371
376
  new_default = prior_default if not other_has_default else False
372
377
 
@@ -374,7 +379,11 @@ class ConfigService:
374
379
  # the invariant "at most one default" holds.
375
380
  if new_default:
376
381
  for other_name, other_cfg in current.items():
377
- if other_name != name and isinstance(other_cfg, dict) and other_cfg.get("is_default"):
382
+ if (
383
+ other_name != name
384
+ and isinstance(other_cfg, dict)
385
+ and other_cfg.get("is_default")
386
+ ):
378
387
  current[other_name] = {**other_cfg, "is_default": False}
379
388
 
380
389
  current[name] = {
@@ -95,20 +95,14 @@ def latest_release_tag(repo: str, tag_prefix: str = "", *, timeout: float = 15.0
95
95
  payload = resp.read()
96
96
  link_header = resp.headers.get("Link")
97
97
  except HTTPError as e:
98
- raise GitHubFetchError(
99
- f"HTTP {e.code} listing releases for {repo}: {e}"
100
- ) from e
98
+ raise GitHubFetchError(f"HTTP {e.code} listing releases for {repo}: {e}") from e
101
99
  except URLError as e:
102
- raise GitHubFetchError(
103
- f"network error listing releases for {repo}: {e}"
104
- ) from e
100
+ raise GitHubFetchError(f"network error listing releases for {repo}: {e}") from e
105
101
 
106
102
  try:
107
103
  data = json.loads(payload.decode("utf-8"))
108
104
  except (UnicodeDecodeError, json.JSONDecodeError) as e:
109
- raise GitHubFetchError(
110
- f"invalid JSON in releases response for {repo}: {e}"
111
- ) from e
105
+ raise GitHubFetchError(f"invalid JSON in releases response for {repo}: {e}") from e
112
106
 
113
107
  if not isinstance(data, list):
114
108
  raise GitHubFetchError(
@@ -125,7 +119,7 @@ def latest_release_tag(repo: str, tag_prefix: str = "", *, timeout: float = 15.0
125
119
  continue
126
120
  if tag_prefix and not tag.startswith(tag_prefix):
127
121
  continue
128
- ver_str = tag[len(tag_prefix):] if tag_prefix else tag
122
+ ver_str = tag[len(tag_prefix) :] if tag_prefix else tag
129
123
  ver_str = ver_str.lstrip("v")
130
124
  try:
131
125
  candidates.append((Version(ver_str), tag))
@@ -142,9 +136,7 @@ def latest_release_tag(repo: str, tag_prefix: str = "", *, timeout: float = 15.0
142
136
  return candidates[0][1]
143
137
 
144
138
 
145
- def fetch_tarball(
146
- repo: str, ref: str, dest_dir: Path, *, timeout: float = 60.0
147
- ) -> Path:
139
+ def fetch_tarball(repo: str, ref: str, dest_dir: Path, *, timeout: float = 60.0) -> Path:
148
140
  """Download ``repo`` at ``ref`` from the GitHub tarball API and extract.
149
141
 
150
142
  Returns the unique top-level directory created inside ``dest_dir``
@@ -157,9 +149,7 @@ def fetch_tarball(
157
149
 
158
150
  tmp_path: Path | None = None
159
151
  try:
160
- with tempfile.NamedTemporaryFile(
161
- suffix=".tar.gz", delete=False
162
- ) as tmp:
152
+ with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
163
153
  tmp_path = Path(tmp.name)
164
154
 
165
155
  req = _build_request(url, accept="application/vnd.github+json")
@@ -167,13 +157,9 @@ def fetch_tarball(
167
157
  with urlopen(req, timeout=timeout) as resp, tmp_path.open("wb") as fout:
168
158
  shutil.copyfileobj(resp, fout)
169
159
  except HTTPError as e:
170
- raise TarballFetchError(
171
- f"HTTP {e.code} downloading tarball {repo}@{ref}: {e}"
172
- ) from e
160
+ raise TarballFetchError(f"HTTP {e.code} downloading tarball {repo}@{ref}: {e}") from e
173
161
  except URLError as e:
174
- raise TarballFetchError(
175
- f"network error downloading tarball {repo}@{ref}: {e}"
176
- ) from e
162
+ raise TarballFetchError(f"network error downloading tarball {repo}@{ref}: {e}") from e
177
163
 
178
164
  try:
179
165
  with tarfile.open(tmp_path, mode="r:gz") as tf:
@@ -208,9 +194,7 @@ def _safe_extract(tf: tarfile.TarFile, dest: Path) -> None:
208
194
  try:
209
195
  member_path.relative_to(dest_resolved)
210
196
  except ValueError as e:
211
- raise tarfile.TarError(
212
- f"unsafe path in tarball (would escape dest): {m.name!r}"
213
- ) from e
197
+ raise tarfile.TarError(f"unsafe path in tarball (would escape dest): {m.name!r}") from e
214
198
 
215
199
  try:
216
200
  tf.extractall(dest, filter="data")
@@ -111,8 +111,7 @@ class LLMManager:
111
111
  provider = cfg.llm_providers.get(name)
112
112
  if provider is None:
113
113
  raise LLMNotConfiguredError(
114
- f"LLM provider {name!r} is not configured; "
115
- "run `deeptrade config set-llm` to add it"
114
+ f"LLM provider {name!r} is not configured; run `deeptrade config set-llm` to add it"
116
115
  )
117
116
  return LLMProviderInfo(name=name, model=provider.model, base_url=provider.base_url)
118
117
 
@@ -160,8 +159,7 @@ class LLMManager:
160
159
  provider = cfg.llm_providers.get(name)
161
160
  if provider is None:
162
161
  raise LLMNotConfiguredError(
163
- f"LLM provider {name!r} is not configured; "
164
- "run `deeptrade config set-llm` to add it"
162
+ f"LLM provider {name!r} is not configured; run `deeptrade config set-llm` to add it"
165
163
  )
166
164
  api_key = self._config.get(f"llm.{name}.api_key")
167
165
  if not api_key:
@@ -373,9 +373,7 @@ class PluginManager:
373
373
  new_ver = Version(meta.version)
374
374
  cur_ver = Version(existing.version)
375
375
  except InvalidVersion as e:
376
- raise PluginInstallError(
377
- f"invalid version on {meta.plugin_id}: {e}"
378
- ) from e
376
+ raise PluginInstallError(f"invalid version on {meta.plugin_id}: {e}") from e
379
377
 
380
378
  if new_ver == cur_ver:
381
379
  return UpgradeNoop(plugin_id=meta.plugin_id, version=existing.version)
@@ -68,9 +68,7 @@ def _is_git_url(s: str) -> bool:
68
68
  def _parse_github_url(url: str) -> tuple[str, str]:
69
69
  m = _GITHUB_URL_RE.match(url)
70
70
  if not m:
71
- raise SourceResolveError(
72
- f"unsupported URL form (only github.com is supported): {url!r}"
73
- )
71
+ raise SourceResolveError(f"unsupported URL form (only github.com is supported): {url!r}")
74
72
  return m.group(1), m.group(2)
75
73
 
76
74
 
@@ -121,14 +119,11 @@ class SourceResolver:
121
119
  top = fetch_tarball(entry.repo, ref, Path(tmp.name))
122
120
  plugin_path = top / entry.subdir
123
121
  if not plugin_path.is_dir():
124
- raise SourceResolveError(
125
- f"subdir {entry.subdir!r} not found in {entry.repo}@{ref}"
126
- )
122
+ raise SourceResolveError(f"subdir {entry.subdir!r} not found in {entry.repo}@{ref}")
127
123
  yaml_path = plugin_path / "deeptrade_plugin.yaml"
128
124
  if not yaml_path.is_file():
129
125
  raise SourceResolveError(
130
- f"no deeptrade_plugin.yaml in {entry.repo}@{ref} "
131
- f"under {entry.subdir!r}"
126
+ f"no deeptrade_plugin.yaml in {entry.repo}@{ref} under {entry.subdir!r}"
132
127
  )
133
128
  except Exception:
134
129
  tmp.cleanup()
@@ -22,8 +22,7 @@ from deeptrade.core import paths
22
22
  logger = logging.getLogger(__name__)
23
23
 
24
24
  REGISTRY_URL = (
25
- "https://raw.githubusercontent.com/ty19880929/"
26
- "DeepTradePluginOfficial/main/registry/index.json"
25
+ "https://raw.githubusercontent.com/ty19880929/DeepTradePluginOfficial/main/registry/index.json"
27
26
  )
28
27
 
29
28
  _REQUIRED_FIELDS = frozenset(
@@ -81,9 +80,7 @@ def _parse_registry(data: Any) -> Registry:
81
80
 
82
81
  schema_version = data.get("schema_version")
83
82
  if schema_version != 1:
84
- raise RegistrySchemaError(
85
- f"schema_version must be 1, got {schema_version!r}"
86
- )
83
+ raise RegistrySchemaError(f"schema_version must be 1, got {schema_version!r}")
87
84
 
88
85
  plugins_raw = data.get("plugins")
89
86
  if not isinstance(plugins_raw, dict):
@@ -95,9 +92,7 @@ def _parse_registry(data: Any) -> Registry:
95
92
  raise RegistrySchemaError(f"plugins.{plugin_id} must be an object")
96
93
  missing = _REQUIRED_FIELDS - set(raw)
97
94
  if missing:
98
- raise RegistrySchemaError(
99
- f"plugins.{plugin_id} missing fields: {sorted(missing)}"
100
- )
95
+ raise RegistrySchemaError(f"plugins.{plugin_id} missing fields: {sorted(missing)}")
101
96
  entries[plugin_id] = RegistryEntry(
102
97
  plugin_id=plugin_id,
103
98
  **{k: raw[k] for k in _REQUIRED_FIELDS},
@@ -144,18 +139,12 @@ class RegistryClient:
144
139
  if e.code == 304 and cached is not None:
145
140
  logger.debug("registry: 304 Not Modified, using cache")
146
141
  return _parse_registry(cached["body"])
147
- raise RegistryFetchError(
148
- f"HTTP {e.code} fetching registry: {e}"
149
- ) from e
142
+ raise RegistryFetchError(f"HTTP {e.code} fetching registry: {e}") from e
150
143
  except URLError as e:
151
144
  if cached is not None:
152
- logger.warning(
153
- "registry: network error %s, falling back to cache", e
154
- )
145
+ logger.warning("registry: network error %s, falling back to cache", e)
155
146
  return _parse_registry(cached["body"])
156
- raise RegistryFetchError(
157
- f"network error fetching registry: {e}"
158
- ) from e
147
+ raise RegistryFetchError(f"network error fetching registry: {e}") from e
159
148
 
160
149
  try:
161
150
  body = json.loads(raw_bytes.decode("utf-8"))
@@ -171,8 +160,7 @@ class RegistryClient:
171
160
  registry = self.fetch()
172
161
  if plugin_id not in registry.plugins:
173
162
  raise RegistryNotFoundError(
174
- f"plugin {plugin_id!r} not in registry. "
175
- f"Available: {sorted(registry.plugins)}"
163
+ f"plugin {plugin_id!r} not in registry. Available: {sorted(registry.plugins)}"
176
164
  )
177
165
  return registry.plugins[plugin_id]
178
166
 
@@ -186,6 +174,4 @@ class RegistryClient:
186
174
 
187
175
  def _write_cache(self, payload: dict) -> None:
188
176
  self.cache_path.parent.mkdir(parents=True, exist_ok=True)
189
- self.cache_path.write_text(
190
- json.dumps(payload, ensure_ascii=False), encoding="utf-8"
191
- )
177
+ self.cache_path.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8")
@@ -32,10 +32,10 @@ from typing import Any, Literal
32
32
 
33
33
  import pandas as pd
34
34
  from tenacity import (
35
- retry,
35
+ Retrying,
36
36
  retry_if_exception_type,
37
37
  stop_after_attempt,
38
- wait_exponential,
38
+ wait_exponential_jitter,
39
39
  )
40
40
 
41
41
  from deeptrade.core.db import Database
@@ -133,6 +133,154 @@ class TushareServerError(TushareError):
133
133
  """5xx / transient transport error — eligible for retry."""
134
134
 
135
135
 
136
+ class TushareTransportError(TushareServerError):
137
+ """Transport-layer transient failure — protocol error, connection reset,
138
+ response ended prematurely, read timeout, etc.
139
+
140
+ Subclass of ``TushareServerError`` so the existing retry whitelist
141
+ (`retry_if_exception_type((TushareRateLimitError, TushareServerError))`)
142
+ and the 5xx → cache fallback path (`_fetch_and_store`) both pick it up
143
+ automatically. New callers don't need any code change.
144
+ """
145
+
146
+
147
+ # ---------------------------------------------------------------------------
148
+ # Exception classifier — type-first, status-code-second, string-last,
149
+ # unknown-defaults-to-transient (so we retry rather than terminate).
150
+ # ---------------------------------------------------------------------------
151
+
152
+
153
+ _TRANSIENT_TYPE_KEYWORDS: tuple[str, ...] = (
154
+ # httpx / h11
155
+ "RemoteProtocolError",
156
+ # http.client / requests / stdlib
157
+ "RemoteDisconnected",
158
+ "ConnectionError",
159
+ "ConnectionResetError",
160
+ "ConnectionAbortedError",
161
+ # http.client / urllib3
162
+ "IncompleteRead",
163
+ # requests
164
+ "ChunkedEncodingError",
165
+ # urllib3
166
+ "ProtocolError",
167
+ # httpx / requests / urllib3
168
+ "ReadTimeout",
169
+ "ReadTimeoutError",
170
+ "ConnectTimeout",
171
+ "ConnectTimeoutError",
172
+ # stdlib / asyncio
173
+ "TimeoutError",
174
+ # SSL handshake interruption — most are transient
175
+ "SSLError",
176
+ )
177
+
178
+ _TRANSIENT_MSG_KEYWORDS: tuple[str, ...] = (
179
+ # the original symptom that motivated this classifier
180
+ "premature",
181
+ "remote protocol",
182
+ "remote disconnect",
183
+ "incomplete read",
184
+ "chunked",
185
+ "connection reset",
186
+ "connection aborted",
187
+ "connection refused",
188
+ "connection error",
189
+ "broken pipe",
190
+ "timeout",
191
+ "timed out",
192
+ "eof occurred",
193
+ )
194
+
195
+
196
+ def _is_transient_transport_error(e: BaseException, type_name: str) -> bool:
197
+ """True if ``e`` is a transport-layer transient failure across httpx /
198
+ requests / urllib3 / stdlib.
199
+
200
+ Type-name match is preferred (fully qualified ``module.QualName`` lets
201
+ one keyword cover multiple stacks). Message-keyword fallback handles
202
+ the case where the tushare SDK swallows the original exception type
203
+ and re-raises a plain ``Exception(str)``.
204
+ """
205
+ if any(k in type_name for k in _TRANSIENT_TYPE_KEYWORDS):
206
+ return True
207
+ msg = str(e).lower()
208
+ return any(k in msg for k in _TRANSIENT_MSG_KEYWORDS)
209
+
210
+
211
+ def _extract_http_status(e: BaseException) -> int | None:
212
+ """Best-effort HTTP status extraction.
213
+
214
+ Looks at ``e.response.status_code`` (httpx / requests) or the leading
215
+ three digits of ``str(e)`` (some SDK wrappers prefix the code).
216
+ """
217
+ response = getattr(e, "response", None)
218
+ if response is not None:
219
+ status = getattr(response, "status_code", None)
220
+ if status is None:
221
+ status = getattr(response, "status", None)
222
+ if isinstance(status, int):
223
+ return status
224
+ msg = str(e)
225
+ if len(msg) >= 3 and msg[:3].isdigit():
226
+ try:
227
+ return int(msg[:3])
228
+ except ValueError:
229
+ return None
230
+ return None
231
+
232
+
233
+ def _classify_tushare_exception(e: BaseException) -> TushareError:
234
+ """Map an arbitrary upstream exception to the right TushareError subclass.
235
+
236
+ Order matters:
237
+ 1. Exception type — the most reliable signal.
238
+ 2. HTTP status code if present on the exception.
239
+ 3. Tushare business-layer keywords (Chinese / English).
240
+ 4. Default to TushareTransportError (retryable). Inverting the
241
+ default from "unknown → fatal" to "unknown → transient" is the
242
+ central design change: a remote network service's unknown
243
+ errors are far more likely transient than permanent. Worst case
244
+ we waste a few retries; best case we ride out an outage that
245
+ used to terminate hours of work.
246
+ """
247
+ type_name = f"{type(e).__module__}.{type(e).__qualname__}"
248
+
249
+ # 1. Type-based — covers RemoteProtocolError / ChunkedEncodingError / ...
250
+ if _is_transient_transport_error(e, type_name):
251
+ return TushareTransportError(str(e))
252
+
253
+ # 2. HTTP status code, if available
254
+ status = _extract_http_status(e)
255
+ if status is not None:
256
+ if status == 429:
257
+ return TushareRateLimitError(str(e))
258
+ if 500 <= status < 600:
259
+ return TushareServerError(str(e))
260
+ if status in (401, 403):
261
+ return TushareUnauthorizedError(str(e))
262
+
263
+ # 3. Tushare business-layer text matching
264
+ msg = str(e)
265
+ low = msg.lower()
266
+ if "权限" in msg or "未开通" in msg or "permission" in low or "no permission" in low:
267
+ return TushareUnauthorizedError(msg)
268
+ # Tushare's actual rate-limit response is the long-form
269
+ # "抱歉,您每分钟最多访问该接口500次" — match the "每分钟" + "次" pair so
270
+ # those, plus the shorter "频率"/"限流" variants, all funnel here.
271
+ if (
272
+ "频率" in msg
273
+ or "限流" in msg
274
+ or ("每分钟" in msg and "次" in msg)
275
+ or "rate" in low
276
+ or "429" in msg
277
+ ):
278
+ return TushareRateLimitError(msg)
279
+
280
+ # 4. Default-inverted: unknown is treated as transient, not fatal.
281
+ return TushareTransportError(f"unclassified: {msg}")
282
+
283
+
136
284
  # ---------------------------------------------------------------------------
137
285
  # Transport abstraction
138
286
  # ---------------------------------------------------------------------------
@@ -166,15 +314,7 @@ class TushareSDKTransport(TushareTransport):
166
314
  try:
167
315
  df = method(**kwargs)
168
316
  except Exception as e: # noqa: BLE001 — translate SDK errors uniformly
169
- msg = str(e)
170
- low = msg.lower()
171
- if "权限" in msg or "permission" in low or "未开通" in msg or "no permission" in low:
172
- raise TushareUnauthorizedError(msg) from e
173
- if "频率" in msg or "rate" in low or "429" in msg:
174
- raise TushareRateLimitError(msg) from e
175
- if "5" in msg[:3] or "timeout" in low or "connection" in low:
176
- raise TushareServerError(msg) from e
177
- raise TushareError(msg) from e
317
+ raise _classify_tushare_exception(e) from e
178
318
  if df is None:
179
319
  return pd.DataFrame()
180
320
  return df
@@ -296,6 +436,9 @@ class TushareClient:
296
436
  intraday: if True, all writes for INTRADAY_SENSITIVE_APIS get
297
437
  data_completeness='intraday'; reads will only accept matching
298
438
  completeness.
439
+ max_retries: tenacity stop_after_attempt for transient errors
440
+ (rate limit + server + transport). Default 7 → worst-case
441
+ wait ≈ (1+2+4+8+16+30) ≈ 60s of jittered backoff.
299
442
  event_cb: optional callback for surfacing operationally-relevant
300
443
  tushare events (5xx fallback, etc.) to the caller. Signature
301
444
  ``event_cb(event_type, message, payload_dict)``. Kept as
@@ -310,6 +453,7 @@ class TushareClient:
310
453
  plugin_id: str,
311
454
  rps: float = 6.0,
312
455
  intraday: bool = False,
456
+ max_retries: int = 7,
313
457
  event_cb: Callable[[str, str, dict[str, Any]], None] | None = None,
314
458
  ) -> None:
315
459
  self._db = db
@@ -318,6 +462,18 @@ class TushareClient:
318
462
  self._bucket = _TokenBucket(rps)
319
463
  self._intraday = intraday
320
464
  self._event_cb = event_cb
465
+ # R1 (HARD CONSTRAINT): the Retrying object wraps `_do_fetch`, whose
466
+ # FIRST line is `self._bucket.acquire()`. Every retry attempt re-enters
467
+ # `_do_fetch`, so the token bucket is honored on every attempt — the
468
+ # tenacity backoff and the bucket throttle compose, never bypass.
469
+ # Don't move bucket.acquire() out of `_do_fetch` without updating
470
+ # `tests/core/test_tushare_retry_r1.py`.
471
+ self._retrying = Retrying(
472
+ retry=retry_if_exception_type((TushareRateLimitError, TushareServerError)),
473
+ stop=stop_after_attempt(max_retries),
474
+ wait=wait_exponential_jitter(initial=1, max=30, jitter=2),
475
+ reraise=True,
476
+ )
321
477
 
322
478
  @property
323
479
  def plugin_id(self) -> str:
@@ -614,13 +770,11 @@ class TushareClient:
614
770
  self._write_cached(api_name, cache_key_date, params, df)
615
771
  return df
616
772
 
617
- @retry(
618
- retry=retry_if_exception_type((TushareRateLimitError, TushareServerError)),
619
- stop=stop_after_attempt(5),
620
- wait=wait_exponential(multiplier=1, min=1, max=15),
621
- reraise=True,
622
- )
623
773
  def _fetch_with_retries(self, api_name: str, params: dict[str, Any]) -> pd.DataFrame:
774
+ """Fetch with tenacity retry around `_do_fetch`. See R1 in __init__."""
775
+ return self._retrying(self._do_fetch, api_name, params)
776
+
777
+ def _do_fetch(self, api_name: str, params: dict[str, Any]) -> pd.DataFrame:
624
778
  """Fetch the widest payload we'd ever want for this API.
625
779
 
626
780
  For most APIs tushare returns every column when ``fields`` is omitted,
@@ -628,6 +782,8 @@ class TushareClient:
628
782
  ``WIDE_FIELDS`` overrides ``fields=`` per-API so the cache row contains
629
783
  every column downstream callers need; ``call()``'s ``_project_fields``
630
784
  narrows it back at READ time.
785
+
786
+ First line is `self._bucket.acquire()` — see R1 in `__init__`.
631
787
  """
632
788
  self._bucket.acquire()
633
789
  t0 = time.monotonic()
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "deeptrade-quant"
7
- version = "0.3.0"
7
+ version = "0.3.1"
8
8
  description = "LLM-driven A-share (Shanghai/Shenzhen main board) stock screening CLI"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.11"