aria-code 4.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. agents/__init__.py +32 -0
  2. agents/base.py +190 -0
  3. agents/deep/__init__.py +37 -0
  4. agents/deep/calibration_loop.py +144 -0
  5. agents/deep/critic.py +125 -0
  6. agents/deep/deepen.py +193 -0
  7. agents/deep/models.py +149 -0
  8. agents/deep/pipeline.py +164 -0
  9. agents/deep/quant_fusion.py +192 -0
  10. agents/deep/themes.py +95 -0
  11. agents/deep/tiers.py +106 -0
  12. agents/financial/__init__.py +10 -0
  13. agents/financial/catalyst.py +279 -0
  14. agents/financial/debate.py +145 -0
  15. agents/financial/earnings.py +303 -0
  16. agents/financial/fundamental.py +159 -0
  17. agents/financial/macro.py +99 -0
  18. agents/financial/news.py +207 -0
  19. agents/financial/risk.py +132 -0
  20. agents/financial/sector.py +279 -0
  21. agents/financial/synthesis.py +274 -0
  22. agents/financial/technical.py +258 -0
  23. agents/portfolio_agent.py +333 -0
  24. agents/realty/__init__.py +62 -0
  25. agents/realty/asset_diagnosis.py +150 -0
  26. agents/realty/business_match.py +165 -0
  27. agents/realty/cashflow_verify.py +208 -0
  28. agents/realty/contract_rules.py +209 -0
  29. agents/realty/energy_anomaly.py +188 -0
  30. agents/realty/exit_settlement.py +207 -0
  31. agents/realty/fulfillment_risk.py +205 -0
  32. agents/realty/ops_optimize.py +159 -0
  33. agents/realty/revenue_share.py +214 -0
  34. agents/registry.py +144 -0
  35. agents/sports/__init__.py +0 -0
  36. agents/sports/football_agent.py +169 -0
  37. agents/team.py +289 -0
  38. aliyun_data_client.py +660 -0
  39. apps/README.md +12 -0
  40. apps/__init__.py +2 -0
  41. apps/channels/README.md +15 -0
  42. apps/cli/README.md +13 -0
  43. apps/cli/__init__.py +2 -0
  44. apps/cli/bootstrap.py +99 -0
  45. apps/cli/codegen_paths.py +29 -0
  46. apps/cli/commands/__init__.py +16 -0
  47. apps/cli/commands/analysis_cmds.py +288 -0
  48. apps/cli/commands/backtest_cmds.py +1887 -0
  49. apps/cli/commands/broker_cmds.py +1154 -0
  50. apps/cli/commands/business_workflow_cmds.py +289 -0
  51. apps/cli/commands/catalog.py +84 -0
  52. apps/cli/commands/data_cmds.py +405 -0
  53. apps/cli/commands/diagnostic_cmds.py +179 -0
  54. apps/cli/commands/diagnostic_ops_cmds.py +696 -0
  55. apps/cli/commands/finance_render.py +12 -0
  56. apps/cli/commands/market.py +399 -0
  57. apps/cli/commands/market_cmds.py +1276 -0
  58. apps/cli/commands/market_context.py +425 -0
  59. apps/cli/commands/market_render.py +7 -0
  60. apps/cli/commands/model_cmds.py +1579 -0
  61. apps/cli/commands/ops_cmds.py +668 -0
  62. apps/cli/commands/portfolio_cmds.py +962 -0
  63. apps/cli/commands/report.py +377 -0
  64. apps/cli/commands/scaffold_templates.py +617 -0
  65. apps/cli/commands/session_cmds.py +179 -0
  66. apps/cli/commands/session_ux_cmds.py +280 -0
  67. apps/cli/commands/team.py +588 -0
  68. apps/cli/commands/team_render.py +8 -0
  69. apps/cli/commands/ui_cmds.py +358 -0
  70. apps/cli/commands/workflow_cmds.py +279 -0
  71. apps/cli/commands/workspace_cmds.py +1414 -0
  72. apps/cli/config_paths.py +70 -0
  73. apps/cli/config_store.py +61 -0
  74. apps/cli/deterministic.py +122 -0
  75. apps/cli/direct.py +48 -0
  76. apps/cli/github_app_auth.py +135 -0
  77. apps/cli/handlers/__init__.py +11 -0
  78. apps/cli/handlers/broker_handlers.py +122 -0
  79. apps/cli/handlers/chart_handlers.py +1309 -0
  80. apps/cli/handlers/market_handlers.py +2509 -0
  81. apps/cli/handlers/realty_handlers.py +114 -0
  82. apps/cli/handlers/strategy_advice.py +82 -0
  83. apps/cli/hooks.py +180 -0
  84. apps/cli/i18n.py +284 -0
  85. apps/cli/intent.py +136 -0
  86. apps/cli/intent_router.py +217 -0
  87. apps/cli/lifecycle_hooks.py +48 -0
  88. apps/cli/main.py +29 -0
  89. apps/cli/market_metadata.py +135 -0
  90. apps/cli/market_universe.py +265 -0
  91. apps/cli/message_processing.py +257 -0
  92. apps/cli/plan_mode.py +139 -0
  93. apps/cli/plotly_html.py +15 -0
  94. apps/cli/prediction_feedback.py +202 -0
  95. apps/cli/preflight.py +497 -0
  96. apps/cli/project_aria.py +60 -0
  97. apps/cli/prompts/__init__.py +0 -0
  98. apps/cli/prompts/coding.py +658 -0
  99. apps/cli/prompts/system_prompts.py +531 -0
  100. apps/cli/prompts/ui.py +434 -0
  101. apps/cli/providers/__init__.py +1 -0
  102. apps/cli/providers/base.py +271 -0
  103. apps/cli/providers/chat_routing.py +80 -0
  104. apps/cli/providers/llm/__init__.py +1 -0
  105. apps/cli/providers/llm/ollama_stream.py +1170 -0
  106. apps/cli/providers/llm/sse_stream.py +216 -0
  107. apps/cli/providers/runtime_bridge.py +185 -0
  108. apps/cli/runtime_consumer.py +489 -0
  109. apps/cli/session_export.py +87 -0
  110. apps/cli/session_jsonl.py +207 -0
  111. apps/cli/session_store.py +112 -0
  112. apps/cli/todo_tracker.py +190 -0
  113. apps/cli/tools/__init__.py +40 -0
  114. apps/cli/tools/context.py +46 -0
  115. apps/cli/tools/file_tools.py +112 -0
  116. apps/cli/tools/market_tools.py +549 -0
  117. apps/cli/tools/notebook_tools.py +111 -0
  118. apps/cli/tools/system_tools.py +669 -0
  119. apps/cli/tools/write_tools.py +715 -0
  120. apps/cli/tradingview_bridge.py +434 -0
  121. apps/cli/update_check.py +152 -0
  122. apps/cli/utils/__init__.py +0 -0
  123. apps/cli/utils/market_detect.py +1578 -0
  124. apps/daemon/README.md +14 -0
  125. apps/vscode/README.md +115 -0
  126. apps/vscode/package.json +70 -0
  127. aria_cli.py +11636 -0
  128. aria_code-4.1.3.dist-info/METADATA +952 -0
  129. aria_code-4.1.3.dist-info/RECORD +284 -0
  130. aria_code-4.1.3.dist-info/WHEEL +5 -0
  131. aria_code-4.1.3.dist-info/entry_points.txt +2 -0
  132. aria_code-4.1.3.dist-info/licenses/LICENSE +121 -0
  133. aria_code-4.1.3.dist-info/top_level.txt +50 -0
  134. aria_daemon.py +1295 -0
  135. aria_feishu_bot.py +1359 -0
  136. aria_relay_client.py +182 -0
  137. aria_relay_server.py +405 -0
  138. aria_telegram_bot.py +202 -0
  139. ariarc.py +328 -0
  140. artifacts.py +491 -0
  141. backtest_report.py +472 -0
  142. brokers/__init__.py +72 -0
  143. brokers/base.py +207 -0
  144. brokers/capabilities.py +264 -0
  145. brokers/cn/__init__.py +10 -0
  146. brokers/cn/easytrader_broker.py +193 -0
  147. brokers/cn/futu_broker.py +194 -0
  148. brokers/cn/longbridge_broker.py +190 -0
  149. brokers/cn/tiger_broker.py +196 -0
  150. brokers/cn/xtquant_broker.py +175 -0
  151. brokers/config.py +364 -0
  152. brokers/intl/__init__.py +5 -0
  153. brokers/intl/alpaca_broker.py +183 -0
  154. brokers/intl/ibkr_broker.py +215 -0
  155. brokers/intl/webull_broker.py +156 -0
  156. brokers/paper_broker.py +259 -0
  157. brokers/planning.py +296 -0
  158. brokers/registry.py +181 -0
  159. brokers/trading.py +237 -0
  160. change_store.py +127 -0
  161. command_safety.py +19 -0
  162. computer_use_tools.py +504 -0
  163. dashboard_generator.py +578 -0
  164. data_analysis_tools.py +808 -0
  165. data_cleaner.py +483 -0
  166. data_service.py +481 -0
  167. datasources/__init__.py +23 -0
  168. datasources/base.py +166 -0
  169. datasources/router.py +221 -0
  170. datasources/sources/__init__.py +15 -0
  171. datasources/sources/akshare_source.py +269 -0
  172. datasources/sources/alpha_vantage_source.py +202 -0
  173. datasources/sources/edgar_source.py +218 -0
  174. datasources/sources/finnhub_source.py +197 -0
  175. datasources/sources/fred_source.py +219 -0
  176. datasources/sources/tushare_source.py +141 -0
  177. datasources/sources/web_scraper_source.py +278 -0
  178. datasources/sources/world_bank_source.py +205 -0
  179. datasources/sources/yfinance_source.py +152 -0
  180. demo_player.py +204 -0
  181. doctor.py +508 -0
  182. file_analysis_tools.py +734 -0
  183. finance_formulas.py +389 -0
  184. football_data_client.py +1670 -0
  185. intent_classifier.py +358 -0
  186. local_finance_tools.py +3221 -0
  187. local_llm_provider.py +552 -0
  188. macro_tools.py +368 -0
  189. market_data_client.py +1899 -0
  190. mcp_client.py +506 -0
  191. memory_manager.py +245 -0
  192. model_capability.py +416 -0
  193. notification_tools.py +248 -0
  194. packages/__init__.py +23 -0
  195. packages/aria_agents/__init__.py +5 -0
  196. packages/aria_agents/manifest.py +69 -0
  197. packages/aria_core/__init__.py +34 -0
  198. packages/aria_core/architecture.py +192 -0
  199. packages/aria_core/export.py +124 -0
  200. packages/aria_core/manifest.py +65 -0
  201. packages/aria_infra/__init__.py +15 -0
  202. packages/aria_infra/arthera.py +52 -0
  203. packages/aria_infra/doctor.py +246 -0
  204. packages/aria_infra/product.py +37 -0
  205. packages/aria_mcp/__init__.py +25 -0
  206. packages/aria_mcp/bridge.py +38 -0
  207. packages/aria_mcp/config.py +97 -0
  208. packages/aria_mcp/tools.py +61 -0
  209. packages/aria_sdk/__init__.py +19 -0
  210. packages/aria_sdk/client.py +396 -0
  211. packages/aria_sdk/providers.py +70 -0
  212. packages/aria_sdk/streaming.py +73 -0
  213. packages/aria_sdk/types.py +86 -0
  214. packages/aria_services/__init__.py +55 -0
  215. packages/aria_services/context.py +258 -0
  216. packages/aria_services/data.py +11 -0
  217. packages/aria_services/provider_health.py +189 -0
  218. packages/aria_services/registry.py +213 -0
  219. packages/aria_services/usage.py +138 -0
  220. packages/aria_skills/__init__.py +5 -0
  221. packages/aria_skills/registry.py +59 -0
  222. packages/aria_tools/__init__.py +5 -0
  223. packages/aria_tools/registry.py +128 -0
  224. packages/quant_engine/__init__.py +6 -0
  225. packages/quant_engine/sports/__init__.py +72 -0
  226. packages/quant_engine/sports/calibrator.py +353 -0
  227. packages/quant_engine/sports/dixon_coles.py +234 -0
  228. packages/quant_engine/sports/elo.py +299 -0
  229. packages/quant_engine/sports/form.py +188 -0
  230. packages/quant_engine/sports/h2h.py +195 -0
  231. packages/quant_engine/sports/ml_model.py +354 -0
  232. packages/quant_engine/sports/predictor.py +311 -0
  233. packages/quant_engine/sports/tracker.py +664 -0
  234. packages/quant_engine/stochastic/__init__.py +27 -0
  235. packages/quant_engine/stochastic/gbm_enhanced.py +195 -0
  236. packages/quant_engine/stochastic/ito_calculus.py +477 -0
  237. packages/quant_engine/stochastic/kelly_criterion.py +181 -0
  238. packages/quant_engine/stochastic/monte_carlo_advanced.py +95 -0
  239. packages/quant_engine/stochastic/options_pricing.py +573 -0
  240. packages/quant_engine/stochastic/stochastic_processes.py +90 -0
  241. plan_utils.py +194 -0
  242. plugin_loader.py +328 -0
  243. portfolio_ledger.py +262 -0
  244. privacy/__init__.py +5 -0
  245. privacy/feedback.py +123 -0
  246. project_tools.py +525 -0
  247. providers/__init__.py +30 -0
  248. providers/llm/__init__.py +19 -0
  249. providers/llm/anthropic.py +184 -0
  250. providers/llm/base.py +139 -0
  251. providers/llm/ollama.py +128 -0
  252. providers/llm/openai_compat.py +282 -0
  253. providers/llm/registry.py +358 -0
  254. realty_data_tools.py +659 -0
  255. report_generator.py +1314 -0
  256. runtime/__init__.py +103 -0
  257. runtime/agent_loop.py +1183 -0
  258. runtime/approval.py +51 -0
  259. runtime/events.py +102 -0
  260. runtime/gateway.py +128 -0
  261. runtime/lsp.py +346 -0
  262. runtime/subagent.py +258 -0
  263. runtime/tool_executor.py +104 -0
  264. runtime/tool_policy.py +106 -0
  265. safety/__init__.py +21 -0
  266. safety/permissions.py +275 -0
  267. setup_wizard.py +653 -0
  268. strategy_vault.py +420 -0
  269. ui/__init__.py +100 -0
  270. ui/banner.py +310 -0
  271. ui/completer.py +391 -0
  272. ui/console.py +271 -0
  273. ui/image_render.py +243 -0
  274. ui/input_box.py +376 -0
  275. ui/picker.py +195 -0
  276. ui/render/__init__.py +11 -0
  277. ui/render/finance.py +1480 -0
  278. ui/render/market.py +225 -0
  279. ui/render/output.py +681 -0
  280. ui/render/team.py +346 -0
  281. ui/robot.py +235 -0
  282. workspace/__init__.py +6 -0
  283. workspace/files.py +170 -0
  284. workspace/verify.py +113 -0
@@ -0,0 +1,354 @@
1
+ """
2
+ sports/ml_model.py — 足球 XGBoost 预测模型
3
+ =============================================
4
+ 从 tracker.py 积累的已结算预测记录中学习,
5
+ 与 Dixon-Coles 规则模型进行 A/B Brier Score 对比。
6
+
7
+ 触发逻辑:
8
+ - 首次训练: ≥20 条已结算记录(Elo + 实际结果)
9
+ - 自动重训: 每新增 10 条记录触发一次
10
+ - 预测时: 优先使用 ML 模型,数据不足则 fallback → DC
11
+
12
+ 特征向量 (9维):
13
+ elo_diff, elo_home, elo_away,
14
+ lambda_home, lambda_away, lambda_ratio,
15
+ league_avg, elo_gap_scaled, is_high_gap
16
+
17
+ 标签: 0=away, 1=draw, 2=home(XGBoost 多分类)
18
+
19
+ 持久化:
20
+ ~/.arthera/football_ml_model.pkl
21
+ ~/.arthera/football_ml_report.json
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import json
27
+ import logging
28
+ import pickle
29
+ import time
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Tuple
32
+
33
+ import numpy as np
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ _MODEL_PATH = Path.home() / ".arthera" / "football_ml_model.pkl"
38
+ _REPORT_PATH = Path.home() / ".arthera" / "football_ml_report.json"
39
+ _MIN_TRAIN = 20
40
+ _RETRAIN_EVERY = 10
41
+
42
# Optional gradient-boosting backends: XGBoost is preferred, LightGBM is the
# fallback. Both flags are bound up front so that every later check can read
# them regardless of which import path succeeded (previously ``_HAS_LGB`` was
# never defined when xgboost imported cleanly).
_HAS_XGB = False
_HAS_LGB = False
try:
    from xgboost import XGBClassifier
    _HAS_XGB = True
except ImportError:
    try:
        import lightgbm as lgb
        _HAS_LGB = True
    except ImportError:
        pass
53
+
54
+ try:
55
+ from sklearn.preprocessing import StandardScaler
56
+ from sklearn.model_selection import cross_val_score, StratifiedKFold
57
+ _HAS_SK = True
58
+ except ImportError:
59
+ _HAS_SK = False
60
+
61
+
62
+ # ── 特征提取 ──────────────────────────────────────────────────────────────────
63
+
64
+ _FEATURE_NAMES = [
65
+ "elo_diff", "elo_home", "elo_away",
66
+ "lambda_home", "lambda_away", "lambda_ratio",
67
+ "league_avg", "elo_gap_scaled", "is_high_gap",
68
+ ]
69
+
70
+
71
def _extract_features(record: Dict) -> Optional[np.ndarray]:
    """Build the 9-dimensional feature vector for one prediction record.

    Args:
        record: Prediction record. Reads ``home_elo``, ``away_elo``,
            ``lambda_home``, ``lambda_away`` (all required) and the optional
            ``league_avg`` (defaults to 1.35 when absent or ``None``).

    Returns:
        A ``float32`` array ordered like ``_FEATURE_NAMES``, or ``None`` when
        a required field is missing, non-numeric, or not finite.
    """
    required = [
        record.get(key)
        for key in ("home_elo", "away_elo", "lambda_home", "lambda_away")
    ]
    if any(value is None for value in required):
        return None

    # A key that is present but explicitly None must fall back to the default
    # too; ``dict.get(key, default)`` alone would hand None to ``float()``.
    avg = record.get("league_avg")
    if avg is None:
        avg = 1.35

    try:
        elo_h, elo_a, lh, la = (float(value) for value in required)
        avg = float(avg)
    except (TypeError, ValueError):
        # Unparseable values are treated the same as missing ones.
        return None

    if not np.isfinite([elo_h, elo_a, lh, la, avg]).all():
        return None

    diff = elo_h - elo_a
    gap = abs(diff)

    return np.array([
        diff,                          # Elo difference (home - away)
        elo_h,                         # home Elo
        elo_a,                         # away Elo
        lh,                            # home expected goals
        la,                            # away expected goals
        lh / (la + 1e-6),              # lambda ratio; epsilon avoids division by zero
        avg,                           # league average goals
        gap / 400.0,                   # absolute Elo gap scaled by 400
        1.0 if gap > 200 else 0.0,     # flag for lopsided fixtures
    ], dtype=np.float32)
96
+
97
+
98
def _result_to_label(result: str) -> int:
    """Map an outcome string to its class index.

    ``"away"`` -> 0, ``"draw"`` -> 1, ``"home"`` -> 2; anything else -> -1.
    """
    label_by_outcome = {"away": 0, "draw": 1, "home": 2}
    if result in label_by_outcome:
        return label_by_outcome[result]
    return -1
101
+
102
+
103
+ # ── 训练器 ────────────────────────────────────────────────────────────────────
104
+
105
class FootballMLModel:
    """
    Gradient-boosted football outcome model (XGBoost, or LightGBM as fallback).

    Usage:
        m = FootballMLModel.load_or_train()
        if m.is_ready:
            p = m.predict(record)  # {"home_win": 0.72, "draw": 0.18, "away_win": 0.10}
    """

    def __init__(self):
        self._model = None        # fitted classifier; None until trained or loaded
        self._scaler = None       # StandardScaler fitted together with the model
        self._report: Dict = {}   # last training report
        self._n_trained = 0       # number of usable samples in the last training run

    @property
    def is_ready(self) -> bool:
        """True once a fitted model is available."""
        return self._model is not None

    @staticmethod
    def _usable_records(records: List[Dict]) -> List[Tuple[Dict, np.ndarray, int]]:
        """Return ``(record, features, label)`` for records with a known result and full features."""
        usable = []
        for rec in records:
            label = _result_to_label(rec.get("result", ""))
            if label == -1:
                continue
            feat = _extract_features(rec)
            if feat is None:
                continue
            usable.append((rec, feat, label))
        return usable

    # ── Training ──────────────────────────────────────────────────────────────

    def train(self, records: Optional[List[Dict]] = None) -> Dict:
        """
        Train from tracker records; settled records are loaded when ``records`` is None.

        Returns:
            The training report dict on success. When training cannot run, a
            small dict instead: ``{"error": ...}`` if a required library is
            missing, or ``{"status": "waiting", ...}`` if there are too few
            usable records or an outcome class is absent from the sample.
        """
        if not (_HAS_XGB or _HAS_LGB):
            return {"error": "pip install xgboost 或 lightgbm 后重试"}
        if not _HAS_SK:
            return {"error": "pip install scikit-learn 后重试"}

        if records is None:
            records = _load_settled_records()

        # Keep only records that carry both a valid label and full features.
        usable = self._usable_records(records)
        n = len(usable)
        if n < _MIN_TRAIN:
            return {"status": "waiting", "n": n, "need": _MIN_TRAIN,
                    "message": f"需要 {_MIN_TRAIN} 条完整记录,当前 {n} 条"}

        X = np.array([feat for _, feat, _ in usable])
        y = np.array([label for _, _, label in usable])

        # The 3-class objective needs every label (0/1/2) present; fitting on
        # fewer classes either fails or yields a probability vector that
        # predict() cannot index. Wait for more data instead of crashing.
        if len(np.unique(y)) < 3:
            return {"status": "waiting", "n": n, "need": _MIN_TRAIN,
                    "message": "训练样本需同时包含 home/draw/away 三类结果"}

        # Standardise features.
        scaler = StandardScaler()
        X_s = scaler.fit_transform(X)

        # Build the estimator.
        if _HAS_XGB:
            model = XGBClassifier(
                n_estimators=200, max_depth=4, learning_rate=0.05,
                subsample=0.8, colsample_bytree=0.8,
                reg_alpha=0.1, reg_lambda=0.5,
                objective="multi:softprob", num_class=3,
                eval_metric="mlogloss", use_label_encoder=False,
                random_state=42, verbosity=0,
            )
        else:
            import lightgbm as lgb
            model = lgb.LGBMClassifier(
                n_estimators=200, max_depth=4, learning_rate=0.05,
                num_class=3, objective="multiclass",
                feature_fraction=0.8, bagging_fraction=0.8,
                reg_alpha=0.1, reg_lambda=0.5,
                verbose=-1, random_state=42,
            )

        # Walk-forward validation (folds follow the order of the records).
        cv_briers = _walk_forward_cv(model, X_s, y, n_splits=min(5, n // 4))

        # Refit on the full sample.
        model.fit(X_s, y)

        self._model = model
        self._scaler = scaler
        self._n_trained = n

        # Baseline: mean stored Brier score of exactly the records used for
        # training (slicing the unfiltered list by n would pick the wrong rows
        # whenever some records were filtered out).
        dc_brier = _dc_brier_from_records([rec for rec, _, _ in usable])
        cv_mean = float(np.mean(cv_briers)) if cv_briers else None
        # improvement = DC - CV(ML); positive means the ML model scored better.
        improvement = round(dc_brier - cv_mean, 4) if cv_mean is not None else None

        lib = "XGBoost" if _HAS_XGB else "LightGBM"
        self._report = {
            "lib": lib,
            "n_samples": int(n),
            "cv_brier_mean": round(cv_mean, 4) if cv_mean is not None else None,
            "cv_brier_std": round(float(np.std(cv_briers)), 4) if cv_briers else None,
            "dc_brier": round(float(dc_brier), 4),
            "improvement": improvement,
            "trained_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "feature_names": _FEATURE_NAMES,
        }

        _save_model(self._model, self._scaler, self._report)
        # ``improvement`` is None when no CV fold could be scored; formatting
        # None with a float spec would raise, so render it separately.
        improvement_txt = "n/a" if improvement is None else f"{improvement:+.4f}"
        logger.info(
            f"[FootballML] {lib} 训练完成 n={n} "
            f"CV Brier={self._report.get('cv_brier_mean')} "
            f"DC Brier={dc_brier:.4f} 提升={improvement_txt}"
        )
        return self._report

    # ── Prediction ────────────────────────────────────────────────────────────

    def predict(self, record: Dict) -> Optional[Dict[str, float]]:
        """
        Return ML outcome probabilities for a prediction record (Elo + lambda fields).

        Returns ``None`` when the model is not ready, the features are
        incomplete, or the model does not emit three class probabilities;
        callers should then fall back to the Dixon-Coles prediction.
        """
        if not self.is_ready:
            return None
        feat = _extract_features(record)
        if feat is None:
            return None

        feat_s = self._scaler.transform(feat.reshape(1, -1))
        proba = self._model.predict_proba(feat_s)[0]  # expected order: [away, draw, home]
        if len(proba) != 3:
            # e.g. a previously saved model fitted on fewer than three classes.
            return None
        return {
            "away_win": round(float(proba[0]), 4),
            "draw": round(float(proba[1]), 4),
            "home_win": round(float(proba[2]), 4),
            "model": "XGB+Elo+λ",
        }

    @property
    def report(self) -> Dict:
        """Last training report (empty dict if never trained)."""
        return self._report

    # ── Load / save ───────────────────────────────────────────────────────────

    @classmethod
    def load_or_train(cls, force_train: bool = False) -> "FootballMLModel":
        """Load the persisted model; train when it is missing, unreadable, stale, or ``force_train``."""
        m = cls()
        if _MODEL_PATH.exists() and not force_train:
            try:
                # NOTE(security): pickle.loads executes code embedded in the
                # file. The path lives under the user's home directory; do not
                # point it at content from an untrusted source.
                payload = pickle.loads(_MODEL_PATH.read_bytes())
                m._model = payload["model"]
                m._scaler = payload["scaler"]
                m._report = payload.get("report", {})
                m._n_trained = payload.get("n_trained", 0)

                # Retrain once enough new *usable* records exist. _n_trained
                # counts usable samples, so compare like with like; counting
                # raw records would retrigger training on every load as soon
                # as some records lack features.
                records = _load_settled_records()
                if len(cls._usable_records(records)) >= m._n_trained + _RETRAIN_EVERY:
                    logger.info("[FootballML] 新增 ≥10 条记录,触发重训")
                    m.train(records)
                return m
            except Exception as e:
                logger.warning(f"[FootballML] 加载失败: {e},重新训练")

        m.train()
        return m
271
+
272
+
273
+ # ── 工具函数 ──────────────────────────────────────────────────────────────────
274
+
275
def _load_settled_records() -> List[Dict]:
    """Load settled prediction records from the tracker store.

    A record counts as settled when it has a truthy ``result`` and a
    non-None ``brier_score``. Any failure (tracker not importable, unreadable
    data, unexpected record shape) yields an empty list; the cause is logged
    at debug level instead of being dropped silently.
    """
    try:
        from .tracker import _PRED_PATH, _load_json
        records = _load_json(_PRED_PATH, [])
        return [r for r in records if r.get("result") and r.get("brier_score") is not None]
    except Exception:
        logging.getLogger(__name__).debug(
            "[FootballML] could not load settled records", exc_info=True
        )
        return []
283
+
284
+
285
def _walk_forward_cv(model, X: np.ndarray, y: np.ndarray, n_splits: int = 5) -> List[float]:
    """Expanding-window validation; returns one Brier score per scored fold.

    Fold ``k`` fits a deep copy of ``model`` on the first ``k * step`` rows and
    scores the next ``step`` rows, where ``step = max(4, len(X) // (n_splits + 1))``.
    A fold is skipped when its training labels cover fewer than three classes,
    when fitting/predicting raises, or when the probability matrix does not
    have three columns. Iteration stops at the first fold whose test window
    would run past the data.
    """
    from copy import deepcopy

    total = len(X)
    step = max(4, total // (n_splits + 1))
    scores = []
    for fold in range(1, n_splits + 1):
        split = fold * step
        stop = split + step
        if stop > total:
            break
        train_labels = y[:split]
        # All three outcome classes must be present in the training window.
        if len(np.unique(train_labels)) >= 3:
            try:
                candidate = deepcopy(model)
                candidate.fit(X[:split], train_labels)
                fold_proba = candidate.predict_proba(X[split:stop])
                if fold_proba.shape[1] == 3:
                    scores.append(_brier_mc(fold_proba, y[split:stop]))
            except Exception:
                # Best effort: a failing fold is simply left out.
                pass
    return scores
309
+
310
+
311
def _brier_mc(proba: np.ndarray, y: np.ndarray) -> float:
    """Multi-class Brier score: mean over samples of the squared error summed over classes.

    Args:
        proba: ``(n_samples, n_classes)`` predicted probabilities; column ``c``
            is compared against label ``c``.
        y: Integer class labels, one per sample. A label outside
            ``range(n_classes)`` matches no column (all-zero target row).

    Returns:
        The score as a built-in float; ``0.0`` when ``y`` is empty.
    """
    probs = np.asarray(proba, dtype=float)
    labels = np.asarray(y)
    n_samples = labels.shape[0] if labels.ndim else 0
    if n_samples == 0:
        return 0.0
    # Only the rows that have a label are scored.
    rows = probs[:n_samples]
    # Vectorised one-hot target instead of a Python loop per sample and class.
    one_hot = labels.reshape(-1, 1) == np.arange(rows.shape[1])
    return float(np.sum((rows - one_hot) ** 2) / n_samples)
319
+
320
+
321
def _dc_brier_from_records(records: List[Dict]) -> float:
    """Average the ``brier_score`` values stored on the records.

    Records without a score (missing or None) are ignored; returns 0.5 when
    no record carries one.
    """
    collected = []
    for rec in records:
        if rec.get("brier_score") is None:
            continue
        collected.append(rec["brier_score"])
    if not collected:
        return 0.5
    return float(np.mean(collected))
325
+
326
+
327
def _save_model(model, scaler, report: Dict) -> None:
    """Persist the model bundle (pickle) and the training report (JSON).

    The pickle holds ``model``, ``scaler``, ``report`` and ``n_trained``
    (taken from ``report["n_samples"]``, 0 if absent). Any failure is logged
    as a warning and otherwise ignored.
    """
    def _jsonable(value):
        # json.dumps fallback: unwrap numpy scalars/arrays into built-in types.
        if isinstance(value, (np.floating, np.float32, np.float64)):
            return float(value)
        if isinstance(value, (np.integer,)):
            return int(value)
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    try:
        bundle = {
            "model": model,
            "scaler": scaler,
            "report": report,
            "n_trained": report.get("n_samples", 0),
        }
        _MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        _MODEL_PATH.write_bytes(pickle.dumps(bundle))
        # Round-trip through JSON first so the report contains only native types.
        native_report = json.loads(json.dumps(report, default=_jsonable))
        report_text = json.dumps(native_report, ensure_ascii=False, indent=2)
        _REPORT_PATH.write_text(report_text, encoding="utf-8")
    except Exception as e:
        logger.warning(f"[FootballML] 保存失败: {e}")
343
+
344
+
345
+ # ── 单例 ─────────────────────────────────────────────────────────────────────
346
+
347
+ _instance: Optional[FootballMLModel] = None
348
+
349
+
350
def get_football_ml() -> FootballMLModel:
    """Return the module-wide FootballMLModel, creating it via load_or_train() on first use."""
    global _instance
    if _instance is not None:
        return _instance
    _instance = FootballMLModel.load_or_train()
    return _instance
@@ -0,0 +1,311 @@
1
+ """
2
+ sports/predictor.py — 统一足球比赛预测引擎 v2
3
+ ==============================================
4
+ 整合 Elo + Dixon-Coles(NB) + 近期状态 + H2H + 赛事情境 五个模块。
5
+
6
+ v2 改进:
7
+ 1. 负二项分布(大比分悬殊时自动启用,尾部更重)
8
+ 2. 动态 DC×Elo 混合权重(form 数据越充足 DC 权重越高)
9
+ 3. 赛事情境参数(必须赢/已出线保守/淘汰赛)
10
+ 4. 动态 WC 场均进球(从 tracker 实时获取)
11
+ 5. ρ 随赛果积累自动校准
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import math
17
+ from typing import Dict, List, Optional, Tuple
18
+
19
+ from .elo import EloRatingSystem, get_elo
20
+ from .dixon_coles import compute_match_probabilities, estimate_rho_from_results
21
+ from .form import analyze_form, parse_api_results
22
+ from .h2h import analyze_h2h, _neutral_h2h
23
+
24
+
25
+ # ── 联赛场均进球(每队每场,后备默认值)──────────────────────────────────────
26
+ _LEAGUE_AVG_GOALS: Dict[str, float] = {
27
+ "wc": 1.35, "euro": 1.20, "copa": 1.28,
28
+ "pl": 1.51, "bl1": 1.56, "sa": 1.33,
29
+ "pd": 1.34, "fl1": 1.43, "cl": 1.40,
30
+ "friendly": 1.45, "default": 1.35,
31
+ }
32
+
33
+ # ── 赛事情境因子 ──────────────────────────────────────────────────────────────
34
+ _CONTEXT: Dict[str, Dict[str, float]] = {
35
+ "normal": {"lmult_h": 1.00, "lmult_a": 1.00, "draw_boost": 0.00},
36
+ "must_win": {"lmult_h": 1.10, "lmult_a": 0.95, "draw_boost": -0.04},
37
+ "safe": {"lmult_h": 0.88, "lmult_a": 0.88, "draw_boost": 0.06},
38
+ "knockout": {"lmult_h": 1.00, "lmult_a": 1.00, "draw_boost": 0.12},
39
+ "knockout_attack": {"lmult_h": 1.08, "lmult_a": 1.00, "draw_boost": 0.05},
40
+ }
41
+
42
+
43
class FootballPredictor:
    """
    Enhanced football match prediction engine (v2).

    Usage:
        pred = FootballPredictor()
        result = pred.predict("germany", "curacao", league="wc",
                              tournament_context="normal")
    """

    def __init__(self, elo_system: Optional["EloRatingSystem"] = None):
        # Compare against None explicitly: a caller-supplied rating system
        # must be honoured even if the object happens to evaluate as falsy.
        self._elo = elo_system if elo_system is not None else get_elo()

    def predict(
        self,
        home_team: str,
        away_team: str,
        league: str = "default",
        neutral_venue: bool = True,
        form_home: Optional[List[Dict]] = None,
        form_away: Optional[List[Dict]] = None,
        h2h_matches: Optional[List[Dict]] = None,
        historical_results: Optional[List[Tuple[int, int]]] = None,
        tournament_context: str = "normal",
        league_avg_override: Optional[float] = None,
        home_attack_override: Optional[float] = None,
        away_attack_override: Optional[float] = None,
        home_defense_override: Optional[float] = None,
        away_defense_override: Optional[float] = None,
    ) -> Dict:
        """
        Main prediction entry point.

        Args:
            home_team, away_team: Team identifiers passed to the Elo system.
            league: League key; looked up case-insensitively with ``-`` and
                ``_`` removed, falling back to the ``"default"`` average.
            neutral_venue: When False a 1.12 home-advantage multiplier is applied.
            form_home, form_away: Raw recent results, parsed with
                ``parse_api_results`` and summarised with ``analyze_form``.
            h2h_matches: Head-to-head matches for ``analyze_h2h``.
            historical_results: Score pairs; with at least 20 entries rho is
                re-estimated from them.
            tournament_context: One of
                "normal"          — regular group-stage match (default)
                "must_win"        — must win (all-out attack)
                "safe"            — already qualified, can play conservatively
                "knockout"        — knockout tie (draw -> extra time)
                "knockout_attack" — knockout side that is trailing
                Unknown values fall back to the "normal" factors.
            league_avg_override: Replaces the league average when positive.
            *_attack_override / *_defense_override: Replace the Elo-derived
                strengths when not None (``0.0`` is a valid override).

        Returns:
            Dict with the blended probabilities, their raw (pre-temperature)
            counterparts, expected goals, and the intermediate model outputs.
        """
        # ── Step 0: base parameters ──────────────────────────────────────────
        league_key = league.lower().replace("-", "").replace("_", "")
        if league_avg_override is not None and league_avg_override > 0:
            league_avg = league_avg_override
        else:
            league_avg = _LEAGUE_AVG_GOALS.get(league_key, _LEAGUE_AVG_GOALS["default"])
        ctx = _CONTEXT.get(tournament_context, _CONTEXT["normal"])

        # ── Step 1: Elo -> base attack/defence strengths ─────────────────────
        home_stats = self._elo.get_attack_defense(home_team, league_avg)
        away_stats = self._elo.get_attack_defense(away_team, league_avg)

        # ``is None`` rather than ``or``: an explicit 0.0 override must not be
        # silently replaced by the Elo-derived value.
        h_attack = home_stats["attack"] if home_attack_override is None else home_attack_override
        a_attack = away_stats["attack"] if away_attack_override is None else away_attack_override
        h_defense = home_stats["defense"] if home_defense_override is None else home_defense_override
        a_defense = away_stats["defense"] if away_defense_override is None else away_defense_override

        home_elo = home_stats["elo"]
        away_elo = away_stats["elo"]
        elo_diff = home_elo - away_elo

        # ── Step 2: recent-form adjustment ───────────────────────────────────
        home_form = _neutral_form_dict()
        away_form = _neutral_form_dict()
        form_matches_h = 0
        form_matches_a = 0

        if form_home:
            parsed_h = parse_api_results(form_home, home_team)
            if parsed_h:
                home_form = analyze_form(parsed_h)
                form_matches_h = home_form.get("matches_analyzed", 0)
        if form_away:
            parsed_a = parse_api_results(form_away, away_team)
            if parsed_a:
                away_form = analyze_form(parsed_a)
                form_matches_a = away_form.get("matches_analyzed", 0)

        h_attack *= home_form["form_factor_attack"]
        a_attack *= away_form["form_factor_attack"]
        h_defense *= home_form["form_factor_defense"]
        a_defense *= away_form["form_factor_defense"]

        # ── Step 3: home advantage ───────────────────────────────────────────
        home_adv_mult = 1.0 if neutral_venue else 1.12

        # ── Step 4: expected goals ───────────────────────────────────────────
        lambda_home = h_attack * a_defense * home_adv_mult * league_avg * ctx["lmult_h"]
        lambda_away = a_attack * h_defense * league_avg * ctx["lmult_a"]

        # Head-to-head nudge on expected goals; the neutral result is used
        # when no H2H matches are supplied.
        h2h_result = _neutral_h2h(home_team, away_team)
        if h2h_matches:
            h2h_result = analyze_h2h(h2h_matches, home_team, away_team)
        h2h_adv = h2h_result.get("h2h_advantage", 0.0)
        lambda_home *= (1.0 + h2h_adv)
        lambda_away *= (1.0 - h2h_adv)

        # ── Step 4b: calibration corrections (best effort) ───────────────────
        # Global lambda bias plus per-team goal bias from the calibrator; any
        # failure leaves the lambdas as computed so far.
        try:
            from .calibrator import get_calibrated_params, get_team_goal_bias
            cal = get_calibrated_params()
            lambda_home *= cal.get("lambda_home_bias", 1.0)
            lambda_away *= cal.get("lambda_away_bias", 1.0)
            lambda_home *= get_team_goal_bias(home_team)
            lambda_away *= get_team_goal_bias(away_team)
        except Exception:
            pass

        # Clamp expected goals to [0.20, 8.0].
        lambda_home = max(0.20, min(lambda_home, 8.0))
        lambda_away = max(0.20, min(lambda_away, 8.0))

        # ── Step 5: rho ──────────────────────────────────────────────────────
        rho = _load_calibrated_rho()
        if historical_results and len(historical_results) >= 20:
            rho = estimate_rho_from_results(historical_results)

        # ── Step 6: Dixon-Coles ──────────────────────────────────────────────
        dc_result = compute_match_probabilities(
            lambda_home, lambda_away, rho, elo_diff=elo_diff
        )

        # ── Step 7: blend with Elo probabilities (dynamic weight) ────────────
        elo_probs = self._elo.win_probability(home_team, away_team, neutral_venue)

        # More analysed form matches -> more weight on DC, capped at 0.78.
        avg_form_matches = (form_matches_h + form_matches_a) / 2.0
        w_dc = min(0.78, 0.55 + avg_form_matches * 0.04)
        w_elo = 1.0 - w_dc

        mix_home = dc_result["home_win"] * w_dc + elo_probs["home_win"] * w_elo
        mix_draw = dc_result["draw"] * w_dc + elo_probs["draw"] * w_elo
        mix_away = dc_result["away_win"] * w_dc + elo_probs["away_win"] * w_elo

        # Context-specific draw adjustment, floored at 0.02, then renormalise.
        draw_boost = ctx["draw_boost"]
        if draw_boost != 0:
            mix_draw = max(0.02, mix_draw + draw_boost)

        total = mix_home + mix_draw + mix_away
        mix_home /= total
        mix_draw /= total
        mix_away /= total

        # Raw (pre-temperature) probabilities — recorded for calibration so the
        # temperature optimizer never compounds an already-applied shrink.
        raw_home, raw_draw, raw_away = mix_home, mix_draw, mix_away

        # ── Step 8: temperature calibration (best effort) ────────────────────
        try:
            from .calibrator import get_confidence_temp, _apply_temp
            _temp = get_confidence_temp()
            if _temp != 1.0:
                mix_home, mix_draw, mix_away = _apply_temp(mix_home, mix_draw, mix_away, _temp)
        except Exception:
            pass

        def impl_odds(p: float) -> float:
            # Decimal odds implied by p; capped at 99.0 for p <= 0.01.
            return round(1.0 / p, 2) if p > 0.01 else 99.0

        # NOTE(review): presumably mirrors the negative-binomial switch inside
        # compute_match_probabilities — confirm the threshold matches.
        use_nb = abs(elo_diff) > 150
        model_tag = f"Elo+DC{'(NB)' if use_nb else ''}+Form+H2H"
        if draw_boost:
            model_tag += f"+{tournament_context}"

        return {
            "home_team": home_team,
            "away_team": away_team,
            "home_win": round(mix_home, 4),
            "draw": round(mix_draw, 4),
            "away_win": round(mix_away, 4),
            "raw_home_win": round(raw_home, 4),
            "raw_draw": round(raw_draw, 4),
            "raw_away_win": round(raw_away, 4),
            "btts": dc_result["btts"],
            "over_2_5": dc_result["over_2_5"],
            "lambda_home": round(lambda_home, 2),
            "lambda_away": round(lambda_away, 2),
            "league_avg_goals": round(league_avg, 2),
            "top_scorelines": dc_result["top_scorelines"],
            "implied_odds": {
                "home": impl_odds(mix_home),
                "draw": impl_odds(mix_draw),
                "away": impl_odds(mix_away),
            },
            "home_elo": home_elo,
            "away_elo": away_elo,
            "elo_diff": round(elo_diff, 0),
            "home_attack": round(h_attack, 3),
            "away_attack": round(a_attack, 3),
            "home_defense": round(h_defense, 3),
            "away_defense": round(a_defense, 3),
            "rho": rho,
            "dc_home_win": dc_result["home_win"],
            "dc_draw": dc_result["draw"],
            "dc_away_win": dc_result["away_win"],
            "elo_home_win": elo_probs["home_win"],
            "elo_draw": elo_probs["draw"],
            "elo_away_win": elo_probs["away_win"],
            "home_form": home_form.get("form_string", "?????"),
            "away_form": away_form.get("form_string", "?????"),
            "home_momentum": home_form.get("momentum", "stable"),
            "away_momentum": away_form.get("momentum", "stable"),
            "h2h_summary": h2h_result.get("summary", ""),
            "h2h_advantage": h2h_adv,
            "w_dc": round(w_dc, 2),
            "w_elo": round(w_elo, 2),
            "use_nb": use_nb,
            "tournament_context": tournament_context,
            "model": model_tag,
        }
255
+
256
+
257
def _neutral_form_dict() -> Dict:
    """Return a fresh neutral form summary: unit attack/defence factors, no matches analysed."""
    return dict(
        form_factor_attack=1.0,
        form_factor_defense=1.0,
        form_string="?????",
        momentum="stable",
        matches_analyzed=0,
    )
265
+
266
+
267
def _load_calibrated_rho() -> float:
    """Read the calibrated rho from ``~/.arthera/wc_rho.json``.

    Returns:
        The stored ``"rho"`` as a float. Falls back to the default ``-0.10``
        when the file is missing or unreadable, the key is absent, or the
        stored value is not a finite number (e.g. ``null`` or a string), so
        callers always receive a usable float.
    """
    default_rho = -0.10
    try:
        from pathlib import Path
        import json
        p = Path.home() / ".arthera" / "wc_rho.json"
        if not p.exists():
            return default_rho
        d = json.loads(p.read_text(encoding="utf-8"))
        rho = d.get("rho", default_rho)
    except Exception:
        # Best effort: any read/parse problem means "no calibration available".
        return default_rho
    # bool is an int subclass; a JSON true/false is not a meaningful rho.
    if isinstance(rho, bool) or not isinstance(rho, (int, float)):
        return default_rho
    if not math.isfinite(rho):
        return default_rho
    return float(rho)
279
+
280
+
281
+ _predictor_instance: Optional[FootballPredictor] = None
282
+
283
+
284
def get_predictor() -> FootballPredictor:
    """Return the module-wide FootballPredictor, constructing it on first use."""
    global _predictor_instance
    if _predictor_instance is not None:
        return _predictor_instance
    _predictor_instance = FootballPredictor()
    return _predictor_instance
289
+
290
+
291
def quick_predict(
    home_team: str,
    away_team: str,
    league: str = "wc",
    neutral_venue: bool = True,
    tournament_context: str = "normal",
    league_avg_override: Optional[float] = None,
) -> Dict:
    """
    One-call convenience wrapper around the shared predictor.

    Example:
        from packages.quant_engine.sports.predictor import quick_predict
        r = quick_predict("germany", "ivory coast", tournament_context="must_win")
        print(f"Germany win: {r['home_win']*100:.1f}%")
    """
    predictor = get_predictor()
    options = {
        "tournament_context": tournament_context,
        "league_avg_override": league_avg_override,
    }
    return predictor.predict(home_team, away_team, league, neutral_venue, **options)