aria-code 4.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agents/__init__.py +32 -0
- agents/base.py +190 -0
- agents/deep/__init__.py +37 -0
- agents/deep/calibration_loop.py +144 -0
- agents/deep/critic.py +125 -0
- agents/deep/deepen.py +193 -0
- agents/deep/models.py +149 -0
- agents/deep/pipeline.py +164 -0
- agents/deep/quant_fusion.py +192 -0
- agents/deep/themes.py +95 -0
- agents/deep/tiers.py +106 -0
- agents/financial/__init__.py +10 -0
- agents/financial/catalyst.py +279 -0
- agents/financial/debate.py +145 -0
- agents/financial/earnings.py +303 -0
- agents/financial/fundamental.py +159 -0
- agents/financial/macro.py +99 -0
- agents/financial/news.py +207 -0
- agents/financial/risk.py +132 -0
- agents/financial/sector.py +279 -0
- agents/financial/synthesis.py +274 -0
- agents/financial/technical.py +258 -0
- agents/portfolio_agent.py +333 -0
- agents/realty/__init__.py +62 -0
- agents/realty/asset_diagnosis.py +150 -0
- agents/realty/business_match.py +165 -0
- agents/realty/cashflow_verify.py +208 -0
- agents/realty/contract_rules.py +209 -0
- agents/realty/energy_anomaly.py +188 -0
- agents/realty/exit_settlement.py +207 -0
- agents/realty/fulfillment_risk.py +205 -0
- agents/realty/ops_optimize.py +159 -0
- agents/realty/revenue_share.py +214 -0
- agents/registry.py +144 -0
- agents/sports/__init__.py +0 -0
- agents/sports/football_agent.py +169 -0
- agents/team.py +289 -0
- aliyun_data_client.py +660 -0
- apps/README.md +12 -0
- apps/__init__.py +2 -0
- apps/channels/README.md +15 -0
- apps/cli/README.md +13 -0
- apps/cli/__init__.py +2 -0
- apps/cli/bootstrap.py +99 -0
- apps/cli/codegen_paths.py +29 -0
- apps/cli/commands/__init__.py +16 -0
- apps/cli/commands/analysis_cmds.py +288 -0
- apps/cli/commands/backtest_cmds.py +1887 -0
- apps/cli/commands/broker_cmds.py +1154 -0
- apps/cli/commands/business_workflow_cmds.py +289 -0
- apps/cli/commands/catalog.py +84 -0
- apps/cli/commands/data_cmds.py +405 -0
- apps/cli/commands/diagnostic_cmds.py +179 -0
- apps/cli/commands/diagnostic_ops_cmds.py +696 -0
- apps/cli/commands/finance_render.py +12 -0
- apps/cli/commands/market.py +399 -0
- apps/cli/commands/market_cmds.py +1276 -0
- apps/cli/commands/market_context.py +425 -0
- apps/cli/commands/market_render.py +7 -0
- apps/cli/commands/model_cmds.py +1579 -0
- apps/cli/commands/ops_cmds.py +668 -0
- apps/cli/commands/portfolio_cmds.py +962 -0
- apps/cli/commands/report.py +377 -0
- apps/cli/commands/scaffold_templates.py +617 -0
- apps/cli/commands/session_cmds.py +179 -0
- apps/cli/commands/session_ux_cmds.py +280 -0
- apps/cli/commands/team.py +588 -0
- apps/cli/commands/team_render.py +8 -0
- apps/cli/commands/ui_cmds.py +358 -0
- apps/cli/commands/workflow_cmds.py +279 -0
- apps/cli/commands/workspace_cmds.py +1414 -0
- apps/cli/config_paths.py +70 -0
- apps/cli/config_store.py +61 -0
- apps/cli/deterministic.py +122 -0
- apps/cli/direct.py +48 -0
- apps/cli/github_app_auth.py +135 -0
- apps/cli/handlers/__init__.py +11 -0
- apps/cli/handlers/broker_handlers.py +122 -0
- apps/cli/handlers/chart_handlers.py +1309 -0
- apps/cli/handlers/market_handlers.py +2509 -0
- apps/cli/handlers/realty_handlers.py +114 -0
- apps/cli/handlers/strategy_advice.py +82 -0
- apps/cli/hooks.py +180 -0
- apps/cli/i18n.py +284 -0
- apps/cli/intent.py +136 -0
- apps/cli/intent_router.py +217 -0
- apps/cli/lifecycle_hooks.py +48 -0
- apps/cli/main.py +29 -0
- apps/cli/market_metadata.py +135 -0
- apps/cli/market_universe.py +265 -0
- apps/cli/message_processing.py +257 -0
- apps/cli/plan_mode.py +139 -0
- apps/cli/plotly_html.py +15 -0
- apps/cli/prediction_feedback.py +202 -0
- apps/cli/preflight.py +497 -0
- apps/cli/project_aria.py +60 -0
- apps/cli/prompts/__init__.py +0 -0
- apps/cli/prompts/coding.py +658 -0
- apps/cli/prompts/system_prompts.py +531 -0
- apps/cli/prompts/ui.py +434 -0
- apps/cli/providers/__init__.py +1 -0
- apps/cli/providers/base.py +271 -0
- apps/cli/providers/chat_routing.py +80 -0
- apps/cli/providers/llm/__init__.py +1 -0
- apps/cli/providers/llm/ollama_stream.py +1170 -0
- apps/cli/providers/llm/sse_stream.py +216 -0
- apps/cli/providers/runtime_bridge.py +185 -0
- apps/cli/runtime_consumer.py +489 -0
- apps/cli/session_export.py +87 -0
- apps/cli/session_jsonl.py +207 -0
- apps/cli/session_store.py +112 -0
- apps/cli/todo_tracker.py +190 -0
- apps/cli/tools/__init__.py +40 -0
- apps/cli/tools/context.py +46 -0
- apps/cli/tools/file_tools.py +112 -0
- apps/cli/tools/market_tools.py +549 -0
- apps/cli/tools/notebook_tools.py +111 -0
- apps/cli/tools/system_tools.py +669 -0
- apps/cli/tools/write_tools.py +715 -0
- apps/cli/tradingview_bridge.py +434 -0
- apps/cli/update_check.py +152 -0
- apps/cli/utils/__init__.py +0 -0
- apps/cli/utils/market_detect.py +1578 -0
- apps/daemon/README.md +14 -0
- apps/vscode/README.md +115 -0
- apps/vscode/package.json +70 -0
- aria_cli.py +11636 -0
- aria_code-4.1.3.dist-info/METADATA +952 -0
- aria_code-4.1.3.dist-info/RECORD +284 -0
- aria_code-4.1.3.dist-info/WHEEL +5 -0
- aria_code-4.1.3.dist-info/entry_points.txt +2 -0
- aria_code-4.1.3.dist-info/licenses/LICENSE +121 -0
- aria_code-4.1.3.dist-info/top_level.txt +50 -0
- aria_daemon.py +1295 -0
- aria_feishu_bot.py +1359 -0
- aria_relay_client.py +182 -0
- aria_relay_server.py +405 -0
- aria_telegram_bot.py +202 -0
- ariarc.py +328 -0
- artifacts.py +491 -0
- backtest_report.py +472 -0
- brokers/__init__.py +72 -0
- brokers/base.py +207 -0
- brokers/capabilities.py +264 -0
- brokers/cn/__init__.py +10 -0
- brokers/cn/easytrader_broker.py +193 -0
- brokers/cn/futu_broker.py +194 -0
- brokers/cn/longbridge_broker.py +190 -0
- brokers/cn/tiger_broker.py +196 -0
- brokers/cn/xtquant_broker.py +175 -0
- brokers/config.py +364 -0
- brokers/intl/__init__.py +5 -0
- brokers/intl/alpaca_broker.py +183 -0
- brokers/intl/ibkr_broker.py +215 -0
- brokers/intl/webull_broker.py +156 -0
- brokers/paper_broker.py +259 -0
- brokers/planning.py +296 -0
- brokers/registry.py +181 -0
- brokers/trading.py +237 -0
- change_store.py +127 -0
- command_safety.py +19 -0
- computer_use_tools.py +504 -0
- dashboard_generator.py +578 -0
- data_analysis_tools.py +808 -0
- data_cleaner.py +483 -0
- data_service.py +481 -0
- datasources/__init__.py +23 -0
- datasources/base.py +166 -0
- datasources/router.py +221 -0
- datasources/sources/__init__.py +15 -0
- datasources/sources/akshare_source.py +269 -0
- datasources/sources/alpha_vantage_source.py +202 -0
- datasources/sources/edgar_source.py +218 -0
- datasources/sources/finnhub_source.py +197 -0
- datasources/sources/fred_source.py +219 -0
- datasources/sources/tushare_source.py +141 -0
- datasources/sources/web_scraper_source.py +278 -0
- datasources/sources/world_bank_source.py +205 -0
- datasources/sources/yfinance_source.py +152 -0
- demo_player.py +204 -0
- doctor.py +508 -0
- file_analysis_tools.py +734 -0
- finance_formulas.py +389 -0
- football_data_client.py +1670 -0
- intent_classifier.py +358 -0
- local_finance_tools.py +3221 -0
- local_llm_provider.py +552 -0
- macro_tools.py +368 -0
- market_data_client.py +1899 -0
- mcp_client.py +506 -0
- memory_manager.py +245 -0
- model_capability.py +416 -0
- notification_tools.py +248 -0
- packages/__init__.py +23 -0
- packages/aria_agents/__init__.py +5 -0
- packages/aria_agents/manifest.py +69 -0
- packages/aria_core/__init__.py +34 -0
- packages/aria_core/architecture.py +192 -0
- packages/aria_core/export.py +124 -0
- packages/aria_core/manifest.py +65 -0
- packages/aria_infra/__init__.py +15 -0
- packages/aria_infra/arthera.py +52 -0
- packages/aria_infra/doctor.py +246 -0
- packages/aria_infra/product.py +37 -0
- packages/aria_mcp/__init__.py +25 -0
- packages/aria_mcp/bridge.py +38 -0
- packages/aria_mcp/config.py +97 -0
- packages/aria_mcp/tools.py +61 -0
- packages/aria_sdk/__init__.py +19 -0
- packages/aria_sdk/client.py +396 -0
- packages/aria_sdk/providers.py +70 -0
- packages/aria_sdk/streaming.py +73 -0
- packages/aria_sdk/types.py +86 -0
- packages/aria_services/__init__.py +55 -0
- packages/aria_services/context.py +258 -0
- packages/aria_services/data.py +11 -0
- packages/aria_services/provider_health.py +189 -0
- packages/aria_services/registry.py +213 -0
- packages/aria_services/usage.py +138 -0
- packages/aria_skills/__init__.py +5 -0
- packages/aria_skills/registry.py +59 -0
- packages/aria_tools/__init__.py +5 -0
- packages/aria_tools/registry.py +128 -0
- packages/quant_engine/__init__.py +6 -0
- packages/quant_engine/sports/__init__.py +72 -0
- packages/quant_engine/sports/calibrator.py +353 -0
- packages/quant_engine/sports/dixon_coles.py +234 -0
- packages/quant_engine/sports/elo.py +299 -0
- packages/quant_engine/sports/form.py +188 -0
- packages/quant_engine/sports/h2h.py +195 -0
- packages/quant_engine/sports/ml_model.py +354 -0
- packages/quant_engine/sports/predictor.py +311 -0
- packages/quant_engine/sports/tracker.py +664 -0
- packages/quant_engine/stochastic/__init__.py +27 -0
- packages/quant_engine/stochastic/gbm_enhanced.py +195 -0
- packages/quant_engine/stochastic/ito_calculus.py +477 -0
- packages/quant_engine/stochastic/kelly_criterion.py +181 -0
- packages/quant_engine/stochastic/monte_carlo_advanced.py +95 -0
- packages/quant_engine/stochastic/options_pricing.py +573 -0
- packages/quant_engine/stochastic/stochastic_processes.py +90 -0
- plan_utils.py +194 -0
- plugin_loader.py +328 -0
- portfolio_ledger.py +262 -0
- privacy/__init__.py +5 -0
- privacy/feedback.py +123 -0
- project_tools.py +525 -0
- providers/__init__.py +30 -0
- providers/llm/__init__.py +19 -0
- providers/llm/anthropic.py +184 -0
- providers/llm/base.py +139 -0
- providers/llm/ollama.py +128 -0
- providers/llm/openai_compat.py +282 -0
- providers/llm/registry.py +358 -0
- realty_data_tools.py +659 -0
- report_generator.py +1314 -0
- runtime/__init__.py +103 -0
- runtime/agent_loop.py +1183 -0
- runtime/approval.py +51 -0
- runtime/events.py +102 -0
- runtime/gateway.py +128 -0
- runtime/lsp.py +346 -0
- runtime/subagent.py +258 -0
- runtime/tool_executor.py +104 -0
- runtime/tool_policy.py +106 -0
- safety/__init__.py +21 -0
- safety/permissions.py +275 -0
- setup_wizard.py +653 -0
- strategy_vault.py +420 -0
- ui/__init__.py +100 -0
- ui/banner.py +310 -0
- ui/completer.py +391 -0
- ui/console.py +271 -0
- ui/image_render.py +243 -0
- ui/input_box.py +376 -0
- ui/picker.py +195 -0
- ui/render/__init__.py +11 -0
- ui/render/finance.py +1480 -0
- ui/render/market.py +225 -0
- ui/render/output.py +681 -0
- ui/render/team.py +346 -0
- ui/robot.py +235 -0
- workspace/__init__.py +6 -0
- workspace/files.py +170 -0
- workspace/verify.py +113 -0
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""
|
|
2
|
+
providers/llm/anthropic.py — Anthropic Claude Provider
|
|
3
|
+
=======================================================
|
|
4
|
+
支持 claude-3-5-sonnet / claude-3-haiku / claude-3-opus 等。
|
|
5
|
+
支持流式 thinking(扩展思考模式)。
|
|
6
|
+
|
|
7
|
+
需要设置: ANTHROPIC_API_KEY=sk-ant-...
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import logging
|
|
14
|
+
import os
|
|
15
|
+
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
16
|
+
|
|
17
|
+
from .base import BaseLLMProvider, Message, ProviderConfig
|
|
18
|
+
|
|
19
|
+
# Module logger, keyed by this module's import path.
logger = logging.getLogger(name=__name__)

# Endpoint of the Anthropic Messages API used for streaming completions.
_ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"
# Value sent in the ``anthropic-version`` request header.
_ANTHROPIC_VERSION = "2023-06-01"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AnthropicProvider(BaseLLMProvider):
    """Streaming provider for the Anthropic Messages API (``/v1/messages``).

    The API key comes from ``config.api_key`` or, failing that, the
    ``ANTHROPIC_API_KEY`` environment variable. ``stream()`` translates the
    server-sent events into the provider-neutral event dicts defined by
    ``BaseLLMProvider`` ("token", "thinking", "tool_call", "done", "error").
    """

    provider_name = "anthropic"
    supports_tools = True
    # Thinking content blocks are forwarded as "thinking" events whenever the
    # API emits them (extended-thinking mode).
    supports_thinking = True
    local = False

    DEFAULT_MODEL = "claude-3-5-haiku-latest"

    # Minimum system-prompt length (chars) to bother caching.
    # Prompts shorter than this don't benefit from prompt caching.
    _CACHE_MIN_CHARS = 1024

    # Upper bound of Anthropic's accepted temperature range [0.0, 1.0];
    # OpenAI-style configs may carry values up to 2.0.
    _MAX_TEMPERATURE = 1.0

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self.api_key = config.api_key or os.getenv("ANTHROPIC_API_KEY", "")
        self.model = config.model or self.DEFAULT_MODEL

    async def is_available(self) -> bool:
        """Return True when an API key is configured (no network probe)."""
        return bool(self.api_key)

    def _build_cached_system(self, system_text: str) -> list:
        """
        Wrap a long system prompt in a cache_control block so Anthropic
        caches it across calls. Returns the `system` field value (a list
        of content blocks, or a plain string for short prompts).

        Anthropic prompt caching docs:
        https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        """
        if len(system_text) < self._CACHE_MIN_CHARS:
            return system_text  # type: ignore[return-value]
        return [
            {
                "type": "text",
                "text": system_text,
                "cache_control": {"type": "ephemeral"},
            }
        ]

    @staticmethod
    def _convert_messages(messages: List[Message]) -> List[Dict[str, Any]]:
        """Convert non-system messages into Messages-API ``messages`` entries.

        The Messages API only accepts the "user" and "assistant" roles, so
        role="tool" results are sent as user turns labelled with the tool
        name. System messages are skipped (they go into the ``system`` field).
        """
        converted: List[Dict[str, Any]] = []
        for m in messages:
            if m.role == "system":
                continue
            if m.role == "tool":
                label = f"[tool result: {m.name}]" if m.name else "[tool result]"
                converted.append({"role": "user", "content": f"{label}\n{m.content}"})
            else:
                converted.append({"role": m.role, "content": m.content})
        return converted

    @staticmethod
    def _convert_tools(tools: List[Dict]) -> List[Dict[str, Any]]:
        """Map tool specs to Anthropic's ``{name, description, input_schema}``.

        Accepts flat specs (``{"name", "description", "parameters"}``) as well
        as OpenAI-style wrappers (``{"type": "function", "function": {...}}``).
        """
        converted: List[Dict[str, Any]] = []
        for t in tools:
            fn = t.get("function")
            spec = fn if isinstance(fn, dict) else t
            converted.append({
                "name": spec.get("name"),
                "description": spec.get("description", ""),
                "input_schema": spec.get(
                    "parameters", {"type": "object", "properties": {}}
                ),
            })
        return converted

    @staticmethod
    def _parse_tool_args(raw: str, fallback: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the accumulated ``input_json_delta`` buffer of a tool call.

        An empty buffer means the tool was called without arguments, in which
        case the ``input`` from ``content_block_start`` (usually ``{}``) is
        used. Undecodable JSON is preserved under ``"_raw"``.
        """
        if not raw.strip():
            return dict(fallback)
        try:
            return json.loads(raw)
        except ValueError:
            return {"_raw": raw}

    async def stream(
        self,
        messages: List[Message],
        tools: Optional[List[Dict]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        cancel_event=None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Stream a completion as provider-neutral event dicts.

        Yields "token", "thinking", "tool_call" (with ``id``, ``name``,
        ``arguments``), a final "done" event whose ``text`` is the full
        response, or an "error" event. Returns early without a final event
        when ``cancel_event`` is set.
        """
        import aiohttp

        temp = temperature if temperature is not None else self.config.temperature
        n_tokens = max_tokens if max_tokens is not None else self.config.max_tokens

        # Anthropic takes the system prompt as a top-level field, not a message.
        system_parts = [m.content for m in messages if m.role == "system"]
        system_text = "\n\n".join(system_parts) if system_parts else None

        payload: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": n_tokens,
            "messages": self._convert_messages(messages),
            "stream": True,
        }
        if system_text:
            # Use prompt caching for long system prompts to reduce TTFT and cost
            payload["system"] = self._build_cached_system(system_text)
        if temp is not None:
            # Send 0.0 explicitly: omitting the field makes the API use its own
            # default instead of the deterministic setting that was asked for.
            payload["temperature"] = min(max(float(temp), 0.0), self._MAX_TEMPERATURE)
        if tools:
            payload["tools"] = self._convert_tools(tools)

        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": _ANTHROPIC_VERSION,
            "anthropic-beta": "prompt-caching-2024-07-31",
            "content-type": "application/json",
        }

        proxy = (os.getenv("HTTPS_PROXY") or os.getenv("https_proxy")
                 or os.getenv("HTTP_PROXY") or os.getenv("http_proxy"))

        text_parts: List[str] = []
        # State of the tool_use block currently being streamed, if any.
        tool_name = ""
        tool_id = ""
        tool_args = ""
        tool_input: Dict[str, Any] = {}
        try:
            async with aiohttp.ClientSession() as sess:
                async with sess.post(
                    _ANTHROPIC_API_URL, json=payload, headers=headers,
                    proxy=proxy,
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout)
                ) as resp:
                    if resp.status != 200:
                        body = await resp.text()
                        yield {"type": "error",
                               "message": f"Anthropic HTTP {resp.status}: {body[:300]}"}
                        return

                    async for raw in resp.content:
                        if cancel_event and cancel_event.is_set():
                            return
                        line = raw.decode("utf-8", errors="ignore").strip()
                        # Only "data:" lines carry JSON; "event:" lines repeat the type.
                        if not line.startswith("data:"):
                            continue
                        try:
                            data = json.loads(line[5:].strip())
                        except json.JSONDecodeError:
                            continue

                        etype = data.get("type", "")

                        if etype == "content_block_start":
                            cb = data.get("content_block", {})
                            if cb.get("type") == "thinking":
                                yield {"type": "thinking", "text": cb.get("thinking", "")}
                            elif cb.get("type") == "tool_use":
                                tool_name = cb.get("name", "")
                                tool_id = cb.get("id", "")
                                tool_input = cb.get("input") or {}
                                tool_args = ""

                        elif etype == "content_block_delta":
                            delta = data.get("delta", {})
                            dt = delta.get("type", "")
                            if dt == "text_delta":
                                text = delta.get("text", "")
                                text_parts.append(text)
                                yield {"type": "token", "text": text}
                            elif dt == "thinking_delta":
                                yield {"type": "thinking", "text": delta.get("thinking", "")}
                            elif dt == "input_json_delta":
                                tool_args += delta.get("partial_json", "")

                        elif etype == "content_block_stop":
                            if tool_name:
                                yield {"type": "tool_call",
                                       "id": tool_id,
                                       "name": tool_name,
                                       "arguments": self._parse_tool_args(tool_args, tool_input)}
                                tool_name = tool_id = tool_args = ""
                                tool_input = {}

                        elif etype == "error":
                            # Mid-stream failures (e.g. overload) arrive as an SSE
                            # error event after the HTTP 200 status.
                            err = data.get("error") or {}
                            yield {"type": "error",
                                   "message": f"Anthropic stream error: "
                                              f"{err.get('type', '')}: {err.get('message', '')}"}
                            return

                        elif etype == "message_stop":
                            yield {"type": "done", "text": "".join(text_parts)}
                            return

                    # The connection closed without message_stop: the response
                    # is truncated, so report it rather than ending silently.
                    yield {"type": "error",
                           "message": "Anthropic stream ended before message_stop"}

        except aiohttp.ClientConnectorError as e:
            yield {"type": "error", "message": f"Anthropic 连接失败: {e}"}
        except Exception as e:
            yield {"type": "error", "message": f"Anthropic 错误: {e}"}
|
providers/llm/base.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""
|
|
2
|
+
providers/llm/base.py — LLM Provider 统一抽象基类
|
|
3
|
+
==================================================
|
|
4
|
+
所有 provider(本地/云端)实现同一接口,上层代码无需关心具体后端。
|
|
5
|
+
|
|
6
|
+
事件类型 (stream yields):
|
|
7
|
+
{"type": "token", "text": "..."} # 文本增量
|
|
8
|
+
{"type": "thinking", "text": "..."} # 思考过程 (Claude/DeepSeek-R1)
|
|
9
|
+
{"type": "tool_call", "name": "...", "arguments": {...}} # 工具调用
|
|
10
|
+
{"type": "done", "text": "完整响应"} # 流结束
|
|
11
|
+
{"type": "error", "message": "..."} # 错误
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class Message:
    """A single chat turn passed to an LLM provider."""

    # Speaker of the turn: "system", "user", "assistant" or "tool".
    role: str
    # Text payload of the turn.
    content: str
    # Tool name, used for role="tool" messages.
    name: Optional[str] = field(default=None)
    # Optional tool-call identifier; presumably links a tool result to its
    # originating call — confirm against callers.
    tool_call_id: Optional[str] = field(default=None)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class ProviderConfig:
    """Connection settings for one LLM provider.

    Values may be overridden from providers.yaml, environment variables or
    CLI arguments.
    """
    name: str
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    model: Optional[str] = None
    temperature: float = 0.3
    max_tokens: int = 4096
    timeout: int = 120
    # Extension settings that individual providers may define.
    extra: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_env(cls, name: str, **defaults) -> "ProviderConfig":
        """Build a config, reading the API key from the provider's env var.

        Args:
            name: Provider name; matched case-insensitively against the known
                providers (e.g. "deepseek" -> ``DEEPSEEK_API_KEY``).
            **defaults: Any other ``ProviderConfig`` fields. A truthy
                ``api_key`` here takes precedence over the environment.

        Returns:
            A new ``ProviderConfig``; ``api_key`` is None when neither source
            provides a key.
        """
        key_map = {
            "deepseek": "DEEPSEEK_API_KEY",
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "groq": "GROQ_API_KEY",
            "together": "TOGETHER_API_KEY",
            "dashscope": "DASHSCOPE_API_KEY",
            "siliconflow": "SILICONFLOW_API_KEY",
            "moonshot": "MOONSHOT_API_KEY",
            "zhipu": "ZHIPUAI_API_KEY",
        }
        # Pop an explicit key so it is not passed twice to the constructor
        # (which would raise TypeError) and so it wins over the environment.
        explicit_key = defaults.pop("api_key", None)
        env_var = key_map.get(name.lower())
        env_key = os.getenv(env_var, "") if env_var else ""
        return cls(name=name, api_key=explicit_key or env_key or None, **defaults)

    def is_configured(self) -> bool:
        """Return True if the provider is usable (local providers need no api_key)."""
        local_providers = frozenset({"ollama", "lmstudio", "vllm", "llamacpp", "jan"})
        if self.name.lower() in local_providers:
            return True
        return bool(self.api_key)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class BaseLLMProvider(ABC):
    """Common base class for every LLM provider, local or cloud.

    Subclasses only need to implement ``stream()``; ``complete()`` aggregates
    the streamed events into a single response dict.

    Event dicts yielded by ``stream()``:
        {"type": "token", "text": ...}                          text delta
        {"type": "thinking", "text": ...}                       reasoning delta
        {"type": "tool_call", "name": ..., "arguments": {...}}  tool invocation
        {"type": "done", ...}                                   end of stream
        {"type": "error", "message": ...}                       failure
    """

    # Capability flags declared by subclasses.
    provider_name: str = "base"
    supports_tools: bool = False
    supports_thinking: bool = False
    local: bool = False  # True = runs locally and needs no api_key

    def __init__(self, config: ProviderConfig):
        self.config = config

    # ── Required ──────────────────────────────────────────────────────────────

    @abstractmethod
    async def stream(
        self,
        messages: List[Message],
        tools: Optional[List[Dict]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        cancel_event=None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Stream the model's response.

        Yields one event dict at a time (see the class docstring for the types).
        """
        ...
        yield {}  # makes Python treat this method as an async generator

    # ── Default implementations (overridable) ────────────────────────────────

    async def complete(
        self,
        messages: List[Message],
        tools: Optional[List[Dict]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        cancel_event=None,
    ) -> Dict[str, Any]:
        """Drain ``stream()`` and return the aggregated response.

        Args:
            cancel_event: Optional event forwarded to ``stream()`` so callers
                can abort a non-streaming request too.

        Returns:
            ``{"success": True, "response": text, "tool_calls": [...]}``; on the
            first "error" event, ``{"success": False, "error": message,
            "response": partial_text, "tool_calls": [...]}``.
        """
        text_parts: List[str] = []
        tool_calls: List[Dict[str, Any]] = []
        stream_kwargs: Dict[str, Any] = {
            "tools": tools,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if cancel_event is not None:
            # Forwarded only when supplied, so stream() overrides that do not
            # accept cancel_event keep working with plain complete() calls.
            stream_kwargs["cancel_event"] = cancel_event

        async for event in self.stream(messages, **stream_kwargs):
            etype = event.get("type")
            if etype == "token":
                text_parts.append(event.get("text", ""))
            elif etype == "tool_call":
                tool_calls.append({
                    "name": event.get("name"),
                    "arguments": event.get("arguments", {}),
                })
            elif etype == "error":
                return {"success": False, "error": event.get("message"),
                        "response": "".join(text_parts), "tool_calls": tool_calls}
        return {"success": True, "response": "".join(text_parts), "tool_calls": tool_calls}

    async def is_available(self) -> bool:
        """Report whether the provider is usable (subclasses may probe for real)."""
        return self.config.is_configured()

    def __repr__(self) -> str:
        # Prefer the resolved model: subclasses store their default on self.model,
        # while config.model is None when no model was configured explicitly.
        model = getattr(self, "model", None) or self.config.model or "default"
        return f"{self.__class__.__name__}(model={model})"
|
providers/llm/ollama.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
providers/llm/ollama.py — Ollama 本地 LLM Provider
|
|
3
|
+
====================================================
|
|
4
|
+
连接本地运行的 Ollama 服务,支持工具调用(native + text-based fallback)。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
from .base import BaseLLMProvider, Message, ProviderConfig
|
|
14
|
+
|
|
15
|
+
# Module logger, keyed by this module's import path.
logger = logging.getLogger(name=__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class OllamaProvider(BaseLLMProvider):
    """Provider for a locally running Ollama server.

    Streams NDJSON from ``/api/chat`` and converts it into the
    provider-neutral event dicts ("token", "tool_call", "done", "error").
    Native tool calls reported by the server are forwarded as "tool_call".
    """

    provider_name = "ollama"
    supports_tools = True
    supports_thinking = False
    local = True

    DEFAULT_URL = "http://localhost:11434"
    # Context window sent as options.num_ctx; override via config.extra["num_ctx"].
    DEFAULT_NUM_CTX = 32768

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        # Strip trailing slashes so f"{base_url}/api/..." never yields "//api".
        self.base_url = (config.base_url or self.DEFAULT_URL).rstrip("/")
        self.model = config.model or "qwen2.5:7b"

    def _probe_urls(self) -> List[str]:
        """Base URLs to probe: the configured one, plus a 127.0.0.1 variant of localhost."""
        urls = [self.base_url]
        alt = self.base_url.replace("localhost", "127.0.0.1")
        if alt != self.base_url:
            urls.append(alt)
        return urls

    def _probe_sync(self) -> bool:
        """Blocking check that some probe URL answers ``/api/tags``."""
        import urllib.request
        for url in self._probe_urls():
            try:
                with urllib.request.urlopen(f"{url}/api/tags", timeout=2):
                    return True
            except Exception:
                # Best-effort probe: any failure just means "try the next URL".
                continue
        return False

    async def is_available(self) -> bool:
        """Probe whether the Ollama service is online without blocking the event loop."""
        import asyncio
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._probe_sync)

    def _list_models_sync(self) -> List[str]:
        """Blocking fetch of installed model names; [] on any failure."""
        import urllib.request
        try:
            with urllib.request.urlopen(f"{self.base_url}/api/tags", timeout=3) as r:
                data = json.loads(r.read())
            return [m["name"] for m in data.get("models", []) if "name" in m]
        except Exception:
            return []

    async def list_models(self) -> List[str]:
        """Return the installed model names (runs the HTTP call in a worker thread)."""
        import asyncio
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._list_models_sync)

    async def stream(
        self,
        messages: List[Message],
        tools: Optional[List[Dict]] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        cancel_event=None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Stream a chat completion from ``/api/chat``.

        Yields "token" and "tool_call" events, then a "done" event whose
        ``text`` is the full response, or an "error" event. Returns early
        without a final event when ``cancel_event`` is set.
        """
        import aiohttp

        temp = temperature if temperature is not None else self.config.temperature
        n_tokens = max_tokens if max_tokens is not None else self.config.max_tokens
        extra = getattr(self.config, "extra", None) or {}
        num_ctx = extra.get("num_ctx", self.DEFAULT_NUM_CTX)

        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": m.role, "content": m.content} for m in messages],
            "stream": True,
            "options": {"temperature": temp, "num_predict": n_tokens, "num_ctx": num_ctx},
        }
        if tools:
            # Forwarded as-is; no per-model tool-capability check happens here.
            payload["tools"] = tools

        url = f"{self.base_url}/api/chat"
        text_parts: List[str] = []
        try:
            async with aiohttp.ClientSession() as sess:
                async with sess.post(
                    url, json=payload,
                    timeout=aiohttp.ClientTimeout(total=self.config.timeout)
                ) as resp:
                    if resp.status != 200:
                        body = await resp.text()
                        yield {"type": "error",
                               "message": f"Ollama HTTP {resp.status}: {body[:200]}"}
                        return

                    # One JSON object per line (NDJSON).
                    async for raw in resp.content:
                        if cancel_event and cancel_event.is_set():
                            return
                        line = raw.decode("utf-8", errors="ignore").strip()
                        if not line:
                            continue
                        try:
                            data = json.loads(line)
                        except json.JSONDecodeError:
                            continue

                        # The server can report failures inside a 200 stream.
                        if data.get("error"):
                            yield {"type": "error", "message": f"Ollama 错误: {data['error']}"}
                            return

                        message = data.get("message") or {}

                        # Native tool calls
                        for tc in message.get("tool_calls") or []:
                            fn = tc.get("function") or {}
                            args = fn.get("arguments", {})
                            if isinstance(args, str):
                                try:
                                    args = json.loads(args)
                                except json.JSONDecodeError:
                                    args = {"_raw": args}
                            yield {"type": "tool_call",
                                   "name": fn.get("name", ""),
                                   "arguments": args}

                        # Text token
                        token = message.get("content", "")
                        if token:
                            text_parts.append(token)
                            yield {"type": "token", "text": token}

                        if data.get("done"):
                            yield {"type": "done", "text": "".join(text_parts)}
                            return

                    # The connection closed without a done marker: the response
                    # is truncated, so report it rather than ending silently.
                    yield {"type": "error",
                           "message": "Ollama stream ended before done marker"}

        except aiohttp.ClientConnectorError as e:
            yield {"type": "error",
                   "message": f"Ollama 连接失败 ({self.base_url}): {e}"}
        except Exception as e:
            yield {"type": "error", "message": f"Ollama 错误: {e}"}
|