aria-code 4.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agents/__init__.py +32 -0
- agents/base.py +190 -0
- agents/deep/__init__.py +37 -0
- agents/deep/calibration_loop.py +144 -0
- agents/deep/critic.py +125 -0
- agents/deep/deepen.py +193 -0
- agents/deep/models.py +149 -0
- agents/deep/pipeline.py +164 -0
- agents/deep/quant_fusion.py +192 -0
- agents/deep/themes.py +95 -0
- agents/deep/tiers.py +106 -0
- agents/financial/__init__.py +10 -0
- agents/financial/catalyst.py +279 -0
- agents/financial/debate.py +145 -0
- agents/financial/earnings.py +303 -0
- agents/financial/fundamental.py +159 -0
- agents/financial/macro.py +99 -0
- agents/financial/news.py +207 -0
- agents/financial/risk.py +132 -0
- agents/financial/sector.py +279 -0
- agents/financial/synthesis.py +274 -0
- agents/financial/technical.py +258 -0
- agents/portfolio_agent.py +333 -0
- agents/realty/__init__.py +62 -0
- agents/realty/asset_diagnosis.py +150 -0
- agents/realty/business_match.py +165 -0
- agents/realty/cashflow_verify.py +208 -0
- agents/realty/contract_rules.py +209 -0
- agents/realty/energy_anomaly.py +188 -0
- agents/realty/exit_settlement.py +207 -0
- agents/realty/fulfillment_risk.py +205 -0
- agents/realty/ops_optimize.py +159 -0
- agents/realty/revenue_share.py +214 -0
- agents/registry.py +144 -0
- agents/sports/__init__.py +0 -0
- agents/sports/football_agent.py +169 -0
- agents/team.py +289 -0
- aliyun_data_client.py +660 -0
- apps/README.md +12 -0
- apps/__init__.py +2 -0
- apps/channels/README.md +15 -0
- apps/cli/README.md +13 -0
- apps/cli/__init__.py +2 -0
- apps/cli/bootstrap.py +99 -0
- apps/cli/codegen_paths.py +29 -0
- apps/cli/commands/__init__.py +16 -0
- apps/cli/commands/analysis_cmds.py +288 -0
- apps/cli/commands/backtest_cmds.py +1887 -0
- apps/cli/commands/broker_cmds.py +1154 -0
- apps/cli/commands/business_workflow_cmds.py +289 -0
- apps/cli/commands/catalog.py +84 -0
- apps/cli/commands/data_cmds.py +405 -0
- apps/cli/commands/diagnostic_cmds.py +179 -0
- apps/cli/commands/diagnostic_ops_cmds.py +696 -0
- apps/cli/commands/finance_render.py +12 -0
- apps/cli/commands/market.py +399 -0
- apps/cli/commands/market_cmds.py +1276 -0
- apps/cli/commands/market_context.py +425 -0
- apps/cli/commands/market_render.py +7 -0
- apps/cli/commands/model_cmds.py +1579 -0
- apps/cli/commands/ops_cmds.py +668 -0
- apps/cli/commands/portfolio_cmds.py +962 -0
- apps/cli/commands/report.py +377 -0
- apps/cli/commands/scaffold_templates.py +617 -0
- apps/cli/commands/session_cmds.py +179 -0
- apps/cli/commands/session_ux_cmds.py +280 -0
- apps/cli/commands/team.py +588 -0
- apps/cli/commands/team_render.py +8 -0
- apps/cli/commands/ui_cmds.py +358 -0
- apps/cli/commands/workflow_cmds.py +279 -0
- apps/cli/commands/workspace_cmds.py +1414 -0
- apps/cli/config_paths.py +70 -0
- apps/cli/config_store.py +61 -0
- apps/cli/deterministic.py +122 -0
- apps/cli/direct.py +48 -0
- apps/cli/github_app_auth.py +135 -0
- apps/cli/handlers/__init__.py +11 -0
- apps/cli/handlers/broker_handlers.py +122 -0
- apps/cli/handlers/chart_handlers.py +1309 -0
- apps/cli/handlers/market_handlers.py +2509 -0
- apps/cli/handlers/realty_handlers.py +114 -0
- apps/cli/handlers/strategy_advice.py +82 -0
- apps/cli/hooks.py +180 -0
- apps/cli/i18n.py +284 -0
- apps/cli/intent.py +136 -0
- apps/cli/intent_router.py +217 -0
- apps/cli/lifecycle_hooks.py +48 -0
- apps/cli/main.py +29 -0
- apps/cli/market_metadata.py +135 -0
- apps/cli/market_universe.py +265 -0
- apps/cli/message_processing.py +257 -0
- apps/cli/plan_mode.py +139 -0
- apps/cli/plotly_html.py +15 -0
- apps/cli/prediction_feedback.py +202 -0
- apps/cli/preflight.py +497 -0
- apps/cli/project_aria.py +60 -0
- apps/cli/prompts/__init__.py +0 -0
- apps/cli/prompts/coding.py +658 -0
- apps/cli/prompts/system_prompts.py +531 -0
- apps/cli/prompts/ui.py +434 -0
- apps/cli/providers/__init__.py +1 -0
- apps/cli/providers/base.py +271 -0
- apps/cli/providers/chat_routing.py +80 -0
- apps/cli/providers/llm/__init__.py +1 -0
- apps/cli/providers/llm/ollama_stream.py +1170 -0
- apps/cli/providers/llm/sse_stream.py +216 -0
- apps/cli/providers/runtime_bridge.py +185 -0
- apps/cli/runtime_consumer.py +489 -0
- apps/cli/session_export.py +87 -0
- apps/cli/session_jsonl.py +207 -0
- apps/cli/session_store.py +112 -0
- apps/cli/todo_tracker.py +190 -0
- apps/cli/tools/__init__.py +40 -0
- apps/cli/tools/context.py +46 -0
- apps/cli/tools/file_tools.py +112 -0
- apps/cli/tools/market_tools.py +549 -0
- apps/cli/tools/notebook_tools.py +111 -0
- apps/cli/tools/system_tools.py +669 -0
- apps/cli/tools/write_tools.py +715 -0
- apps/cli/tradingview_bridge.py +434 -0
- apps/cli/update_check.py +152 -0
- apps/cli/utils/__init__.py +0 -0
- apps/cli/utils/market_detect.py +1578 -0
- apps/daemon/README.md +14 -0
- apps/vscode/README.md +115 -0
- apps/vscode/package.json +70 -0
- aria_cli.py +11636 -0
- aria_code-4.1.3.dist-info/METADATA +952 -0
- aria_code-4.1.3.dist-info/RECORD +284 -0
- aria_code-4.1.3.dist-info/WHEEL +5 -0
- aria_code-4.1.3.dist-info/entry_points.txt +2 -0
- aria_code-4.1.3.dist-info/licenses/LICENSE +121 -0
- aria_code-4.1.3.dist-info/top_level.txt +50 -0
- aria_daemon.py +1295 -0
- aria_feishu_bot.py +1359 -0
- aria_relay_client.py +182 -0
- aria_relay_server.py +405 -0
- aria_telegram_bot.py +202 -0
- ariarc.py +328 -0
- artifacts.py +491 -0
- backtest_report.py +472 -0
- brokers/__init__.py +72 -0
- brokers/base.py +207 -0
- brokers/capabilities.py +264 -0
- brokers/cn/__init__.py +10 -0
- brokers/cn/easytrader_broker.py +193 -0
- brokers/cn/futu_broker.py +194 -0
- brokers/cn/longbridge_broker.py +190 -0
- brokers/cn/tiger_broker.py +196 -0
- brokers/cn/xtquant_broker.py +175 -0
- brokers/config.py +364 -0
- brokers/intl/__init__.py +5 -0
- brokers/intl/alpaca_broker.py +183 -0
- brokers/intl/ibkr_broker.py +215 -0
- brokers/intl/webull_broker.py +156 -0
- brokers/paper_broker.py +259 -0
- brokers/planning.py +296 -0
- brokers/registry.py +181 -0
- brokers/trading.py +237 -0
- change_store.py +127 -0
- command_safety.py +19 -0
- computer_use_tools.py +504 -0
- dashboard_generator.py +578 -0
- data_analysis_tools.py +808 -0
- data_cleaner.py +483 -0
- data_service.py +481 -0
- datasources/__init__.py +23 -0
- datasources/base.py +166 -0
- datasources/router.py +221 -0
- datasources/sources/__init__.py +15 -0
- datasources/sources/akshare_source.py +269 -0
- datasources/sources/alpha_vantage_source.py +202 -0
- datasources/sources/edgar_source.py +218 -0
- datasources/sources/finnhub_source.py +197 -0
- datasources/sources/fred_source.py +219 -0
- datasources/sources/tushare_source.py +141 -0
- datasources/sources/web_scraper_source.py +278 -0
- datasources/sources/world_bank_source.py +205 -0
- datasources/sources/yfinance_source.py +152 -0
- demo_player.py +204 -0
- doctor.py +508 -0
- file_analysis_tools.py +734 -0
- finance_formulas.py +389 -0
- football_data_client.py +1670 -0
- intent_classifier.py +358 -0
- local_finance_tools.py +3221 -0
- local_llm_provider.py +552 -0
- macro_tools.py +368 -0
- market_data_client.py +1899 -0
- mcp_client.py +506 -0
- memory_manager.py +245 -0
- model_capability.py +416 -0
- notification_tools.py +248 -0
- packages/__init__.py +23 -0
- packages/aria_agents/__init__.py +5 -0
- packages/aria_agents/manifest.py +69 -0
- packages/aria_core/__init__.py +34 -0
- packages/aria_core/architecture.py +192 -0
- packages/aria_core/export.py +124 -0
- packages/aria_core/manifest.py +65 -0
- packages/aria_infra/__init__.py +15 -0
- packages/aria_infra/arthera.py +52 -0
- packages/aria_infra/doctor.py +246 -0
- packages/aria_infra/product.py +37 -0
- packages/aria_mcp/__init__.py +25 -0
- packages/aria_mcp/bridge.py +38 -0
- packages/aria_mcp/config.py +97 -0
- packages/aria_mcp/tools.py +61 -0
- packages/aria_sdk/__init__.py +19 -0
- packages/aria_sdk/client.py +396 -0
- packages/aria_sdk/providers.py +70 -0
- packages/aria_sdk/streaming.py +73 -0
- packages/aria_sdk/types.py +86 -0
- packages/aria_services/__init__.py +55 -0
- packages/aria_services/context.py +258 -0
- packages/aria_services/data.py +11 -0
- packages/aria_services/provider_health.py +189 -0
- packages/aria_services/registry.py +213 -0
- packages/aria_services/usage.py +138 -0
- packages/aria_skills/__init__.py +5 -0
- packages/aria_skills/registry.py +59 -0
- packages/aria_tools/__init__.py +5 -0
- packages/aria_tools/registry.py +128 -0
- packages/quant_engine/__init__.py +6 -0
- packages/quant_engine/sports/__init__.py +72 -0
- packages/quant_engine/sports/calibrator.py +353 -0
- packages/quant_engine/sports/dixon_coles.py +234 -0
- packages/quant_engine/sports/elo.py +299 -0
- packages/quant_engine/sports/form.py +188 -0
- packages/quant_engine/sports/h2h.py +195 -0
- packages/quant_engine/sports/ml_model.py +354 -0
- packages/quant_engine/sports/predictor.py +311 -0
- packages/quant_engine/sports/tracker.py +664 -0
- packages/quant_engine/stochastic/__init__.py +27 -0
- packages/quant_engine/stochastic/gbm_enhanced.py +195 -0
- packages/quant_engine/stochastic/ito_calculus.py +477 -0
- packages/quant_engine/stochastic/kelly_criterion.py +181 -0
- packages/quant_engine/stochastic/monte_carlo_advanced.py +95 -0
- packages/quant_engine/stochastic/options_pricing.py +573 -0
- packages/quant_engine/stochastic/stochastic_processes.py +90 -0
- plan_utils.py +194 -0
- plugin_loader.py +328 -0
- portfolio_ledger.py +262 -0
- privacy/__init__.py +5 -0
- privacy/feedback.py +123 -0
- project_tools.py +525 -0
- providers/__init__.py +30 -0
- providers/llm/__init__.py +19 -0
- providers/llm/anthropic.py +184 -0
- providers/llm/base.py +139 -0
- providers/llm/ollama.py +128 -0
- providers/llm/openai_compat.py +282 -0
- providers/llm/registry.py +358 -0
- realty_data_tools.py +659 -0
- report_generator.py +1314 -0
- runtime/__init__.py +103 -0
- runtime/agent_loop.py +1183 -0
- runtime/approval.py +51 -0
- runtime/events.py +102 -0
- runtime/gateway.py +128 -0
- runtime/lsp.py +346 -0
- runtime/subagent.py +258 -0
- runtime/tool_executor.py +104 -0
- runtime/tool_policy.py +106 -0
- safety/__init__.py +21 -0
- safety/permissions.py +275 -0
- setup_wizard.py +653 -0
- strategy_vault.py +420 -0
- ui/__init__.py +100 -0
- ui/banner.py +310 -0
- ui/completer.py +391 -0
- ui/console.py +271 -0
- ui/image_render.py +243 -0
- ui/input_box.py +376 -0
- ui/picker.py +195 -0
- ui/render/__init__.py +11 -0
- ui/render/finance.py +1480 -0
- ui/render/market.py +225 -0
- ui/render/output.py +681 -0
- ui/render/team.py +346 -0
- ui/robot.py +235 -0
- workspace/__init__.py +6 -0
- workspace/files.py +170 -0
- workspace/verify.py +113 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""MCP configuration helpers for Aria and Arthera package bridges."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def arthera_quant_engine_server_config(
    arthera_root: Path | str | None = None,
    *,
    python_executable: str = "python3",
) -> Dict[str, object]:
    """Return a ready-to-write MCP server entry for Arthera QuantEngine.

    Args:
        arthera_root: Root of the Arthera checkout, as a ``Path`` or ``str``;
            ``~`` is expanded. A falsy value selects ``~/Desktop/Arthera``.
        python_executable: Interpreter command used to launch the server
            script (defaults to ``"python3"``).

    Returns:
        A dict with ``name``, ``command``, ``args`` (the path of
        ``packages/quant_engine/mcp_server.py`` under the root), ``env``
        (``PYTHONPATH`` set to the root) and ``description``. The script path
        is not checked for existence here.
    """
    # Wrap in Path() so plain strings are accepted as well as Path objects.
    root = Path(arthera_root).expanduser() if arthera_root else Path.home() / "Desktop" / "Arthera"
    server = root / "packages" / "quant_engine" / "mcp_server.py"
    return {
        "name": "arthera_quant_engine",
        "command": python_executable,
        "args": [str(server)],
        "env": {"PYTHONPATH": str(root)},
        "description": "Arthera QuantEngine tools exposed through MCP",
    }
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def merge_server_config(existing: Dict[str, object], server: Dict[str, object]) -> Dict[str, object]:
    """Return config with server upserted by name.

    The first dict entry in ``existing["servers"]`` whose ``name`` equals
    ``server["name"]`` is replaced in place. Any later entries with that same
    name are dropped, so the result holds exactly one entry per name. When
    there is no match, ``server`` is appended. A ``server`` without a name
    never matches anything and is always appended. Non-dict entries are kept
    untouched. ``existing`` itself is not mutated; other top-level keys are
    shallow-copied into the result.
    """
    name = server.get("name")
    merged: List[object] = []
    placed = False
    for item in existing.get("servers") or []:
        # A nameless server must not "match" every other nameless entry.
        same_name = name is not None and isinstance(item, dict) and item.get("name") == name
        if not same_name:
            merged.append(item)
        elif not placed:
            merged.append(server)
            placed = True
        # else: a later duplicate of the same name is dropped.
    if not placed:
        merged.append(server)
    return {**existing, "servers": merged}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def load_mcp_config(path: Path) -> Dict[str, object]:
    """Read an MCP config JSON file, falling back to an empty config.

    Returns ``{"servers": []}`` in any of these cases:

    - the file is missing or unreadable;
    - the file is not valid UTF-8 JSON;
    - the top-level JSON value is not an object. Callers use ``.get`` on the
      result, so a list, number or ``null`` would otherwise crash them.

    ``path`` may be a ``Path`` or a ``str``.
    """
    empty: Dict[str, object] = {"servers": []}
    try:
        raw = Path(path).read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):  # missing file, directory, bad encoding, ...
        return empty
    try:
        data = json.loads(raw)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return empty
    return data if isinstance(data, dict) else empty
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def write_mcp_config(path: Path, config: Dict[str, object]) -> Path:
    """Serialize ``config`` as pretty-printed UTF-8 JSON at ``path``.

    Missing parent directories are created first. Non-ASCII text is written
    as-is rather than ``\\u`` escaped. Returns ``path`` for convenient chaining.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(config, ensure_ascii=False, indent=2)
    path.write_text(serialized, encoding="utf-8")
    return path
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def find_server_config(config: Dict[str, object], name: str) -> Optional[Dict[str, object]]:
    """Return the first dict in ``config["servers"]`` whose ``name`` matches.

    Non-dict entries are skipped. A missing or empty ``servers`` list yields
    ``None``, as does the absence of any match.
    """
    candidates = (
        entry
        for entry in (config.get("servers") or [])
        if isinstance(entry, dict) and entry.get("name") == name
    )
    return next(candidates, None)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def mcp_server_status(
    config_path: Path,
    server_name: str = "arthera_quant_engine",
    runtime_status: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Return config/file/runtime status for a configured MCP server.

    Args:
        config_path: MCP config JSON file, read via ``load_mcp_config``.
        server_name: ``name`` of the server entry to report on.
        runtime_status: Optional list of runtime status dicts. The first dict
            whose ``name`` equals ``server_name`` supplies ``running``,
            ``tool_count`` and ``tools``. Non-dict entries are ignored.

    Returns:
        A dict with these keys:

        - ``name``, ``config_path``, ``configured``;
        - ``server_path``: the server's first launch argument, or ``""`` if
          there is none;
        - ``server_file_exists``, ``running``, ``tool_count``, ``tools``,
          ``description``.
    """
    config = load_mcp_config(config_path)
    server = find_server_config(config, server_name)

    runtime: Optional[Dict[str, Any]] = next(
        (
            item
            for item in runtime_status or []
            if isinstance(item, dict) and item.get("name") == server_name
        ),
        None,
    )

    server_path: Optional[Path] = None
    if server:
        args = server.get("args") or []
        # A bare string is one argument. Indexing it would yield its first
        # character (e.g. "/"), which is never the intended path and may
        # falsely report the file as existing.
        if isinstance(args, str):
            first_arg: Any = args
        elif isinstance(args, (list, tuple)) and args:
            first_arg = args[0]
        else:
            first_arg = None
        if first_arg:
            server_path = Path(str(first_arg)).expanduser()

    return {
        "name": server_name,
        "config_path": str(config_path),
        "configured": server is not None,
        "server_path": str(server_path) if server_path else "",
        "server_file_exists": bool(server_path and server_path.exists()),
        "running": bool(runtime and runtime.get("running")),
        # ``or`` tolerates explicit nulls reported by a runtime.
        "tool_count": int(runtime.get("tool_count") or 0) if runtime else 0,
        "tools": list(runtime.get("tools") or []) if runtime else [],
        "description": (server or {}).get("description", ""),
    }
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Adapters from MCP tool descriptors to Aria tool manifests."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Iterable, List
|
|
6
|
+
|
|
7
|
+
from packages.aria_core import PermissionLevel
|
|
8
|
+
from packages.aria_tools import ToolSpec
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _capabilities_for_mcp_tool(name: str, description: str = "") -> List[str]:
|
|
12
|
+
text = f"{name} {description}".lower()
|
|
13
|
+
capabilities: List[str] = []
|
|
14
|
+
if any(word in text for word in ("quote", "market", "ohlc", "price")):
|
|
15
|
+
capabilities.append("market.data")
|
|
16
|
+
if any(word in text for word in ("backtest", "strategy", "simulation")):
|
|
17
|
+
capabilities.append("strategy.backtest")
|
|
18
|
+
if any(word in text for word in ("risk", "var", "cvar", "drawdown")):
|
|
19
|
+
capabilities.append("risk")
|
|
20
|
+
if any(word in text for word in ("factor", "alpha", "feature")):
|
|
21
|
+
capabilities.append("factors")
|
|
22
|
+
if any(word in text for word in ("signal", "predict", "regime")):
|
|
23
|
+
capabilities.append("signals")
|
|
24
|
+
if any(word in text for word in ("news", "filing", "web", "search")):
|
|
25
|
+
capabilities.append("research")
|
|
26
|
+
if any(word in text for word in ("trade", "order", "execution")):
|
|
27
|
+
capabilities.append("broker")
|
|
28
|
+
return capabilities or ["mcp.tool"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _permissions_for_mcp_tool(name: str, description: str = "") -> List[PermissionLevel]:
    """Choose the permission levels an MCP tool requires.

    Every tool needs ``NETWORK``. The first matching tier, checked from most
    to least privileged, adds one extra level: trading keywords add
    ``BROKER_TRADE``; write/export/report/backtest keywords add
    ``WORKSPACE_WRITE``. Matching is by substring on the lower-cased
    ``"<name> <description>"`` text.
    """
    haystack = f"{name} {description}".lower()
    tiers = (
        (("trade", "order", "execution"), PermissionLevel.BROKER_TRADE),
        (("write", "export", "report", "backtest"), PermissionLevel.WORKSPACE_WRITE),
    )
    for keywords, extra_level in tiers:
        if any(word in haystack for word in keywords):
            return [PermissionLevel.NETWORK, extra_level]
    return [PermissionLevel.NETWORK]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def mcp_tool_to_spec(tool: Dict[str, Any], server_name: str) -> ToolSpec:
    """Convert one MCP tool descriptor into an Aria ToolSpec.

    The spec is named ``qualified_name`` when the descriptor provides one,
    otherwise ``"<server_name>/<name>"``. The schema is taken from
    ``inputSchema`` or ``parameters``; a non-dict value becomes ``{}``. No
    handler is attached. Permissions and capabilities are inferred from the
    short name plus the (possibly defaulted) description.
    """
    short_name = str(tool.get("name") or "unknown")
    fallback_name = f"{server_name}/{short_name}"
    qualified_name = str(tool.get("qualified_name") or fallback_name)
    summary = str(tool.get("description") or f"MCP tool from {server_name}")
    raw_schema = tool.get("inputSchema") or tool.get("parameters") or {}
    input_schema = raw_schema if isinstance(raw_schema, dict) else {}
    return ToolSpec(
        name=qualified_name,
        handler=None,
        description=summary,
        schema=input_schema,
        permissions=_permissions_for_mcp_tool(short_name, summary),
        capabilities=_capabilities_for_mcp_tool(short_name, summary),
    )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def mcp_tools_to_specs(tools: Iterable[Dict[str, Any]], server_name: str) -> List[ToolSpec]:
    """Convert every MCP tool descriptor and return the specs sorted by name."""
    specs = [mcp_tool_to_spec(descriptor, server_name) for descriptor in tools]
    specs.sort(key=lambda spec: spec.name)
    return specs
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Public Aria Agent SDK facade."""
|
|
2
|
+
|
|
3
|
+
from .client import AriaSDKClient, query, run
|
|
4
|
+
from .providers import ProviderSelection, build_llm_provider, normalize_provider_name
|
|
5
|
+
from .streaming import stream_provider_result
|
|
6
|
+
from .types import AriaAgentOptions, AriaMessage, AriaResult
|
|
7
|
+
|
|
8
|
+
# Public SDK surface, in sorted order: option/message/result types, the
# client, provider helpers, then the one-shot query/run/streaming helpers.
__all__ = [
    "AriaAgentOptions", "AriaMessage", "AriaResult",
    "AriaSDKClient",
    "ProviderSelection", "build_llm_provider", "normalize_provider_name",
    "query", "run", "stream_provider_result",
]
|
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
"""Embeddable Aria Agent SDK client.
|
|
2
|
+
|
|
3
|
+
The SDK owns the agent-facing event stream. Terminal UI, Rich panels, and
|
|
4
|
+
interactive prompts remain CLI adapters layered above this package.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import os
|
|
11
|
+
import uuid
|
|
12
|
+
from dataclasses import replace
|
|
13
|
+
from typing import AsyncGenerator
|
|
14
|
+
|
|
15
|
+
from apps.cli.deterministic import run_deterministic_chain
|
|
16
|
+
from apps.cli.providers.base import (
|
|
17
|
+
LLMDone,
|
|
18
|
+
LLMStatus,
|
|
19
|
+
LLMThinking,
|
|
20
|
+
LLMToken,
|
|
21
|
+
LLMToolCall,
|
|
22
|
+
LLMToolResult,
|
|
23
|
+
)
|
|
24
|
+
from runtime import (
|
|
25
|
+
AgentEventCancelled,
|
|
26
|
+
AgentEventComplete,
|
|
27
|
+
AgentEventError,
|
|
28
|
+
AgentEventStatus,
|
|
29
|
+
AgentEventThinking,
|
|
30
|
+
AgentEventToken,
|
|
31
|
+
AgentEventToolCall,
|
|
32
|
+
AgentEventToolResult,
|
|
33
|
+
AgentOptions,
|
|
34
|
+
ToolExecutor,
|
|
35
|
+
run_agent,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from .providers import build_llm_provider
|
|
39
|
+
from .streaming import stream_provider_result
|
|
40
|
+
from .types import AriaAgentOptions, AriaMessage, AriaResult
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AriaSDKClient:
    """A reusable agent client that can be embedded outside the terminal CLI.

    A client holds one conversation. ``session_id`` tags the emitted result
    events. Successful turns are appended to ``self.messages``, so a later
    :meth:`query` call without an explicit ``history`` continues from them.

    Each turn is routed to the first applicable backend:

    1. the deterministic chain, when ``options.deterministic`` is set and the
       chain reports ``success`` (otherwise the turn falls through);
    2. the shared runtime tool loop (``run_agent``), when a ``tool_executor``
       was supplied;
    3. a plain streaming call to the configured LLM provider.
    """

    def __init__(
        self,
        options: AriaAgentOptions | None = None,
        *,
        session_id: str | None = None,
        tool_executor: ToolExecutor | None = None,
        remote_runner=None,
        tool_result_formatter=None,
    ) -> None:
        """Create a client.

        Args:
            options: Agent options; ``AriaAgentOptions()`` when omitted.
            session_id: Session identifier reported in events; a random
                ``uuid4`` hex string when omitted.
            tool_executor: When given, turns run through ``run_agent``.
            remote_runner: Forwarded unchanged to ``run_agent``.
            tool_result_formatter: Forwarded unchanged to ``run_agent``.
        """
        self.options = options or AriaAgentOptions()
        self.session_id = session_id or uuid.uuid4().hex
        # Transcript of successful turns as user/assistant message pairs.
        self.messages: list[dict[str, str]] = []
        self.tool_executor = tool_executor
        self.remote_runner = remote_runner
        self.tool_result_formatter = tool_result_formatter

    async def query(
        self,
        prompt: str,
        *,
        history: list | None = None,
        cancel_event: asyncio.Event | None = None,
    ) -> AsyncGenerator[AriaMessage, None]:
        """Yield SDK events for one agent turn.

        The stream opens with a ``system`` event whose content is
        ``"aria_sdk.turn_started"``, followed by a ``user`` echo of ``prompt``.
        The routed backend then yields ``token``, ``thinking``, ``tool_use``,
        ``tool_result``, ``status`` and ``assistant`` events, and normally
        finishes with a ``result`` event.

        Args:
            prompt: User message for this turn.
            history: Messages to send instead of ``self.messages``. It is
                copied, not mutated. Successful turns are still recorded in
                ``self.messages``.
            cancel_event: Optional event forwarded to the backend to request
                cancellation.
        """
        # Snapshot: the backend receives a copy, never the caller's list or
        # self.messages itself.
        active_history = list(history) if history is not None else list(self.messages)
        user_record = {"role": "user", "content": prompt}

        yield AriaMessage(
            kind="system",
            role="system",
            content="aria_sdk.turn_started",
            data={
                "session_id": self.session_id,
                "model": self.options.model,
                "provider": self.options.provider,
                "local_mode": self.options.local_mode,
                "cwd": self.options.cwd or os.getcwd(),
                "permission_mode": self.options.permission_mode,
            },
        )
        yield AriaMessage(kind="user", role="user", content=prompt)

        if self.options.deterministic:
            deterministic = run_deterministic_chain(
                prompt,
                model_has_tools=self.options.model_has_tools,
                history=active_history,
                has_brokers=self.options.has_brokers,
                get_broker_registry=self.options.get_broker_registry,
            )
            # Only a successful deterministic answer ends the turn; otherwise
            # fall through to the agent / LLM paths below.
            if deterministic.get("success"):
                content = str(deterministic.get("response", ""))
                self.messages.extend([user_record, {"role": "assistant", "content": content}])
                yield AriaMessage(
                    kind="assistant",
                    role="assistant",
                    content=content,
                    data={
                        "provider": "deterministic",
                        "tools_used": list(deterministic.get("tools_used") or []),
                        "raw": deterministic,
                    },
                )
                yield AriaMessage(
                    kind="result",
                    role="assistant",
                    content=content,
                    data={
                        "success": True,
                        "provider": "deterministic",
                        "session_id": self.session_id,
                        "tools_used": list(deterministic.get("tools_used") or []),
                    },
                )
                return

        if self.tool_executor is not None:
            async for event in self._run_agent(prompt, history=active_history, cancel_event=cancel_event):
                yield event
            return

        async for event in self._run_llm(prompt, history=active_history, cancel_event=cancel_event):
            yield event

    async def _provider_result(
        self,
        prompt: str,
        history: list,
        *,
        on_token=None,
        on_thinking=None,
        on_tool_call=None,
        on_tool_result=None,
        on_status=None,
        cancel_event: asyncio.Event | None = None,
    ) -> dict:
        """Stream one provider round and return its result dict.

        Used as ``provider_fn`` for ``run_agent``. A fresh provider is built
        for every call. The callbacks are forwarded to
        ``stream_provider_result``. If the result carries no provider name (or
        ``"unknown"``), the selected provider's name is filled in.
        """
        selection = build_llm_provider(self.options)
        result = await stream_provider_result(
            selection.provider,
            prompt,
            history,
            tools=list(self.options.tool_schemas),
            cancel_event=cancel_event,
            on_token=on_token,
            on_thinking=on_thinking,
            on_tool_call=on_tool_call,
            on_tool_result=on_tool_result,
            on_status=on_status,
        )
        if not result.get("provider") or result.get("provider") == "unknown":
            result["provider"] = selection.name
        return result

    async def _run_agent(
        self,
        prompt: str,
        *,
        history: list,
        cancel_event: asyncio.Event | None = None,
    ) -> AsyncGenerator[AriaMessage, None]:
        """Run the provider through the shared runtime tool loop.

        Runtime ``AgentEvent*`` objects are translated one-to-one into
        ``AriaMessage`` events. Cancellation, error and completion each yield
        a ``result`` event. Only a successful completion is recorded in
        ``self.messages``.
        """
        if self.tool_executor is None:
            return

        final_text = ""
        async for event in run_agent(
            prompt,
            history,
            provider_fn=self._provider_result,
            tool_executor=self.tool_executor,
            options=AgentOptions(
                # At least one round, even if max_turns is 0/None.
                max_rounds=max(1, int(self.options.max_turns or 1)),
                tool_schemas=list(self.options.tool_schemas),
            ),
            remote_runner=self.remote_runner,
            cancel_event=cancel_event,
            tool_result_formatter=self.tool_result_formatter,
        ):
            if isinstance(event, AgentEventToken):
                yield AriaMessage(kind="token", role="assistant", content=event.text)
            elif isinstance(event, AgentEventThinking):
                yield AriaMessage(kind="thinking", role="assistant", content=event.content)
            elif isinstance(event, AgentEventToolCall):
                yield AriaMessage(
                    kind="tool_use",
                    role="assistant",
                    content=event.tool,
                    data={"tool": event.tool, "params": dict(event.params)},
                )
            elif isinstance(event, AgentEventToolResult):
                yield AriaMessage(
                    kind="tool_result",
                    role="tool",
                    content=event.tool,
                    data={
                        "tool": event.tool,
                        "result": dict(event.result),
                        "elapsed": event.elapsed,
                    },
                )
            elif isinstance(event, AgentEventStatus):
                yield AriaMessage(
                    kind="status",
                    role="system",
                    content=event.message,
                    data={"state": event.state},
                )
            elif isinstance(event, AgentEventCancelled):
                # A cancelled turn is reported as a successful partial result.
                final_text = event.partial_text
                yield AriaMessage(
                    kind="result",
                    role="assistant",
                    content=final_text,
                    data={
                        "success": True,
                        "cancelled": True,
                        "provider": "runtime",
                        "session_id": self.session_id,
                    },
                )
            elif isinstance(event, AgentEventError):
                yield AriaMessage(
                    kind="result",
                    role="assistant",
                    content="",
                    data={
                        "success": False,
                        "provider": "runtime",
                        "session_id": self.session_id,
                        "error": event.error,
                    },
                )
            elif isinstance(event, AgentEventComplete):
                final_text = event.result.final_text
                if event.result.success:
                    self.messages.extend([
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": final_text},
                    ])
                yield AriaMessage(
                    kind="assistant",
                    role="assistant",
                    content=final_text,
                    data=event.result.to_dict(),
                )
                yield AriaMessage(
                    kind="result",
                    role="assistant",
                    content=final_text,
                    data={
                        "success": event.result.success,
                        "cancelled": event.result.cancelled,
                        "provider": event.result.provider,
                        "session_id": self.session_id,
                        "tools": list(event.result.tools),
                        "error": event.result.error,
                    },
                )

    async def _run_llm(
        self,
        prompt: str,
        *,
        history: list,
        cancel_event: asyncio.Event | None = None,
    ) -> AsyncGenerator[AriaMessage, None]:
        """Fallback to the configured LLM provider and normalize its events.

        The provider is streamed with no tool schemas (an empty tools list),
        and its ``LLM*`` events are translated into ``AriaMessage`` events.
        The stream always ends with a ``result`` event:

        - on ``LLMDone``, the result reflects that completion;
        - if the provider raises, a failed result is yielded;
        - if the stream ends without ``LLMDone``, a cancelled result with the
          partial text is yielded when ``cancel_event`` is set, and a failed
          result otherwise.
        """
        selection = build_llm_provider(self.options)
        provider = selection.provider
        messages = list(history) + [{"role": "user", "content": prompt}]
        token_parts: list[str] = []
        saw_done = False

        try:
            async for llm_event in provider.stream(messages, [], cancel_event=cancel_event):
                if isinstance(llm_event, LLMToken):
                    token_parts.append(llm_event.text)
                    yield AriaMessage(kind="token", role="assistant", content=llm_event.text)
                elif isinstance(llm_event, LLMThinking):
                    yield AriaMessage(kind="thinking", role="assistant", content=llm_event.content)
                elif isinstance(llm_event, LLMToolCall):
                    yield AriaMessage(
                        kind="tool_use",
                        role="assistant",
                        content=llm_event.tool,
                        data={"tool": llm_event.tool, "params": dict(llm_event.params)},
                    )
                elif isinstance(llm_event, LLMToolResult):
                    yield AriaMessage(
                        kind="tool_result",
                        role="tool",
                        content=llm_event.summary,
                        data={"tool": llm_event.tool, "summary": llm_event.summary},
                    )
                elif isinstance(llm_event, LLMStatus):
                    yield AriaMessage(
                        kind="status",
                        role="system",
                        content=llm_event.message,
                        data={"state": llm_event.state},
                    )
                elif isinstance(llm_event, LLMDone):
                    saw_done = True
                    # Prefer the provider's full response; fall back to the
                    # streamed tokens.
                    content = llm_event.response or "".join(token_parts)
                    if llm_event.success:
                        self.messages.extend([
                            {"role": "user", "content": prompt},
                            {"role": "assistant", "content": content},
                        ])
                    yield AriaMessage(
                        kind="assistant",
                        role="assistant",
                        content=content,
                        data={
                            "provider": llm_event.provider,
                            # Guard in case a provider reports usage=None.
                            "usage": dict(llm_event.usage or {}),
                            "success": llm_event.success,
                            "error": llm_event.error,
                        },
                    )
                    yield AriaMessage(
                        kind="result",
                        role="assistant",
                        content=content,
                        data={
                            "success": llm_event.success,
                            "provider": llm_event.provider,
                            "session_id": self.session_id,
                            "error": llm_event.error,
                        },
                    )
        except Exception as exc:
            yield AriaMessage(
                kind="result",
                role="assistant",
                content="",
                data={
                    "success": False,
                    "provider": selection.name,
                    "session_id": self.session_id,
                    "error": str(exc),
                },
            )
            return

        if not saw_done:
            # The stream ended without a completion event. Still emit a
            # terminal result so streaming consumers are not left waiting.
            # Cancellation mirrors _run_agent's cancelled result.
            partial = "".join(token_parts)
            if cancel_event is not None and cancel_event.is_set():
                data = {
                    "success": True,
                    "cancelled": True,
                    "provider": selection.name,
                    "session_id": self.session_id,
                }
            else:
                data = {
                    "success": False,
                    "provider": selection.name,
                    "session_id": self.session_id,
                    "error": "provider stream ended without a completion event",
                }
            yield AriaMessage(kind="result", role="assistant", content=partial, data=data)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
async def query(
    prompt: str,
    *,
    options: AriaAgentOptions | None = None,
    **overrides,
) -> AsyncGenerator[AriaMessage, None]:
    """Stream the events of a single, throwaway SDK turn.

    ``overrides`` are applied to the options with ``dataclasses.replace``.
    A new ``AriaSDKClient`` is created for this one call, so no conversation
    state carries over between calls.
    """
    effective = options if options else AriaAgentOptions()
    effective = replace(effective, **overrides) if overrides else effective
    session = AriaSDKClient(effective)
    async for message in session.query(prompt):
        yield message
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
async def run(
    prompt: str,
    *,
    options: AriaAgentOptions | None = None,
    **overrides,
) -> AriaResult:
    """Run a one-off SDK turn to completion and return its last ``result`` event.

    If the stream produced no ``result`` event, a failed ``AriaResult`` with
    ``error="no_result"`` is returned.
    """
    last_result: AriaMessage | None = None
    async for message in query(prompt, options=options, **overrides):
        if message.kind != "result":
            continue
        last_result = message

    if last_result is None:
        return AriaResult(success=False, error="no_result")

    payload = last_result.data
    return AriaResult(
        success=bool(payload.get("success")),
        content=last_result.content,
        provider=str(payload.get("provider") or ""),
        session_id=str(payload.get("session_id") or ""),
        error=str(payload.get("error") or ""),
        data=dict(payload),
    )
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
# Public entry points of the client module.
__all__ = ["AriaSDKClient", "query", "run"]
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Provider selection for the public Aria SDK."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
from apps.cli.providers.base import AriaSSEProvider, LLMProvider, OllamaProvider
|
|
8
|
+
|
|
9
|
+
from .types import AriaAgentOptions
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
class ProviderSelection:
    """Immutable pairing of a normalized provider name with its provider object.

    Instances are returned by ``build_llm_provider``.
    """

    # Canonical provider key after alias normalization (e.g. "ollama").
    name: str
    # Provider instance that will stream the turn.
    provider: LLMProvider
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def normalize_provider_name(name: str) -> str:
    """Map a user-supplied provider name onto its canonical key.

    Normalization steps:

    - matching is case-insensitive;
    - hyphens and runs of internal whitespace become underscores;
    - ``None``, empty and whitespace-only input yield ``"auto"``;
    - known aliases are folded: ``local``/``ollama_local`` -> ``ollama`` and
      ``cloud``/``remote``/``sse``/``aria`` -> ``aria_sse``.

    Any other value is returned in its normalized form.
    """
    key = (name or "").strip().lower().replace("-", "_")
    # Fold internal whitespace ("aria sse") the same way as hyphens.
    key = "_".join(key.split())
    if not key:
        # Whitespace-only input previously normalized to "" and was then
        # rejected as unsupported; treat it like an unset name.
        return "auto"
    aliases = {
        "local": "ollama",
        "ollama_local": "ollama",
        "cloud": "aria_sse",
        "remote": "aria_sse",
        "sse": "aria_sse",
        "aria": "aria_sse",
    }
    return aliases.get(key, key)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def build_llm_provider(options: AriaAgentOptions) -> ProviderSelection:
    """Instantiate the LLM provider requested by ``options``.

    ``"auto"`` resolves to ``"ollama"`` in local mode and to ``"aria_sse"``
    otherwise.

    Raises:
        ValueError: if the normalized name is neither ``"ollama"`` nor
            ``"aria_sse"``.
    """
    chosen = normalize_provider_name(options.provider)
    if chosen == "auto":
        chosen = "ollama" if options.local_mode else "aria_sse"

    if chosen not in ("ollama", "aria_sse"):
        raise ValueError(f"Unsupported SDK provider: {options.provider}")

    if chosen == "ollama":
        backend = OllamaProvider(
            options.ollama_url,
            options.model,
            system_override=options.system_prompt or None,
        )
    else:
        backend = AriaSSEProvider(
            options.api_url,
            options.model,
            auth_token=options.auth_token or None,
            thinking_mode=options.thinking_mode,
            user_context=dict(options.user_context),
            system_override=options.system_prompt or None,
        )
    return ProviderSelection(name=chosen, provider=backend)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Public provider-selection helpers.
__all__ = ["ProviderSelection", "build_llm_provider", "normalize_provider_name"]
|