voxagent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- voxagent/__init__.py +143 -0
- voxagent/_version.py +5 -0
- voxagent/agent/__init__.py +32 -0
- voxagent/agent/abort.py +178 -0
- voxagent/agent/core.py +902 -0
- voxagent/code/__init__.py +9 -0
- voxagent/mcp/__init__.py +16 -0
- voxagent/mcp/manager.py +188 -0
- voxagent/mcp/tool.py +152 -0
- voxagent/providers/__init__.py +110 -0
- voxagent/providers/anthropic.py +498 -0
- voxagent/providers/augment.py +293 -0
- voxagent/providers/auth.py +116 -0
- voxagent/providers/base.py +268 -0
- voxagent/providers/chatgpt.py +415 -0
- voxagent/providers/claudecode.py +162 -0
- voxagent/providers/cli_base.py +265 -0
- voxagent/providers/codex.py +183 -0
- voxagent/providers/failover.py +90 -0
- voxagent/providers/google.py +532 -0
- voxagent/providers/groq.py +96 -0
- voxagent/providers/ollama.py +425 -0
- voxagent/providers/openai.py +435 -0
- voxagent/providers/registry.py +175 -0
- voxagent/py.typed +1 -0
- voxagent/security/__init__.py +14 -0
- voxagent/security/events.py +75 -0
- voxagent/security/filter.py +169 -0
- voxagent/security/registry.py +87 -0
- voxagent/session/__init__.py +39 -0
- voxagent/session/compaction.py +237 -0
- voxagent/session/lock.py +103 -0
- voxagent/session/model.py +109 -0
- voxagent/session/storage.py +184 -0
- voxagent/streaming/__init__.py +52 -0
- voxagent/streaming/emitter.py +286 -0
- voxagent/streaming/events.py +255 -0
- voxagent/subagent/__init__.py +20 -0
- voxagent/subagent/context.py +124 -0
- voxagent/subagent/definition.py +172 -0
- voxagent/tools/__init__.py +32 -0
- voxagent/tools/context.py +50 -0
- voxagent/tools/decorator.py +175 -0
- voxagent/tools/definition.py +131 -0
- voxagent/tools/executor.py +109 -0
- voxagent/tools/policy.py +89 -0
- voxagent/tools/registry.py +89 -0
- voxagent/types/__init__.py +46 -0
- voxagent/types/messages.py +134 -0
- voxagent/types/run.py +176 -0
- voxagent-0.1.0.dist-info/METADATA +186 -0
- voxagent-0.1.0.dist-info/RECORD +53 -0
- voxagent-0.1.0.dist-info/WHEEL +4 -0
voxagent/agent/core.py
ADDED
|
@@ -0,0 +1,902 @@
|
|
|
1
|
+
"""Agent core module for voxagent."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
import inspect
|
|
6
|
+
import re
|
|
7
|
+
import time
|
|
8
|
+
import uuid
|
|
9
|
+
from collections.abc import AsyncIterator
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
|
|
11
|
+
|
|
12
|
+
from voxagent.agent.abort import AbortController, TimeoutHandler
|
|
13
|
+
from voxagent.mcp import MCPServerManager
|
|
14
|
+
from voxagent.providers.base import (
|
|
15
|
+
AbortSignal,
|
|
16
|
+
BaseProvider,
|
|
17
|
+
ErrorChunk,
|
|
18
|
+
MessageEndChunk,
|
|
19
|
+
TextDeltaChunk,
|
|
20
|
+
ToolUseChunk,
|
|
21
|
+
)
|
|
22
|
+
from voxagent.streaming.events import (
|
|
23
|
+
RunEndEvent,
|
|
24
|
+
RunErrorEvent,
|
|
25
|
+
RunStartEvent,
|
|
26
|
+
StreamEventData,
|
|
27
|
+
TextDeltaEvent,
|
|
28
|
+
ToolEndEvent,
|
|
29
|
+
ToolStartEvent,
|
|
30
|
+
)
|
|
31
|
+
from voxagent.subagent.context import DEFAULT_MAX_DEPTH
|
|
32
|
+
from voxagent.subagent.definition import SubAgentDefinition
|
|
33
|
+
from voxagent.tools.definition import ToolDefinition
|
|
34
|
+
from voxagent.tools.executor import execute_tool
|
|
35
|
+
from voxagent.tools.registry import ToolRegistry
|
|
36
|
+
from voxagent.types.messages import Message, ToolCall, ToolResult
|
|
37
|
+
from voxagent.types.run import ModelConfig, RunResult, ToolMeta
|
|
38
|
+
|
|
39
|
+
if TYPE_CHECKING:
    # Reserved for imports needed only by type checkers; currently empty.
    pass

# Generic parameters of ``Agent``: the type of the dependency object that
# callers inject into tools, and the type of the agent's structured output.
DepsT, OutputT = TypeVar("DepsT"), TypeVar("OutputT")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class Agent(Generic[DepsT, OutputT]):
    """Main agent class for voxagent.

    The Agent is the primary entry point for voxagent. It combines:

    - Model configuration (a "provider:model" string parsed into ModelConfig)
    - Dependency injection (``deps_type``; ``deps`` passed to ``run`` are
      forwarded to every tool call)
    - Output type specification (``output_type``; stored and exposed, but not
      enforced by the run loop in this module)
    - Tool registration and management (native tools, sub-agents wrapped as
      tools, and toolsets such as MCP servers connected for each run)
    - Security configuration (secret patterns, redaction). These are stored and
      exposed here; NOTE(review): nothing in this module applies them.
    """

    # Tool names must be plain ASCII identifiers. ``fullmatch`` is used rather
    # than ``match`` with ``^...$`` because ``$`` also matches just before a
    # trailing newline, which would let a name such as "foo\n" through.
    _TOOL_NAME_PATTERN = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")

    def __init__(
        self,
        model: str,  # "provider:model" format
        *,
        name: str | None = None,
        deps_type: type[DepsT] | None = None,
        output_type: type[OutputT] | None = None,
        system_prompt: str | None = None,
        tools: list[ToolDefinition] | None = None,
        toolsets: list[Any] | None = None,
        sub_agents: list["Agent[Any, Any]"] | None = None,
        max_sub_agent_depth: int = DEFAULT_MAX_DEPTH,
        retries: int = 1,
        result_retries: int = 1,
        # Security features
        secret_patterns: list[str] | None = None,
        secrets_to_redact: dict[str, str] | None = None,
    ) -> None:
        """Initialize the Agent.

        Args:
            model: Model string in "provider:model" format (e.g., "openai:gpt-4").
            name: Optional name for this agent (used when registered as sub-agent).
            deps_type: Optional type for dependencies.
            output_type: Optional type for structured output.
            system_prompt: Optional system prompt for the agent.
            tools: Optional list of ToolDefinitions to register.
            toolsets: Optional list of toolsets (MCP servers, etc.), connected
                at the start of every run and disconnected at its end.
            sub_agents: Optional list of child Agents to register as tools.
            max_sub_agent_depth: Maximum nesting depth for sub-agent calls
                (default: ``DEFAULT_MAX_DEPTH``).
            retries: Number of retries for failed operations (default: 1).
            result_retries: Number of retries for result validation (default: 1).
            secret_patterns: Regex patterns to detect and mask secrets.
            secrets_to_redact: Dictionary of named secrets to redact.

        Raises:
            ValueError: If ``model`` is not a valid "provider:model" string.
        """
        # Parse and validate the model string first so a bad model fails fast.
        self._model_config = self._parse_model_string(model)
        self._model_string = model

        # Store configuration
        self._name = name
        self._deps_type = deps_type
        self._output_type = output_type
        self._system_prompt = system_prompt
        self._retries = retries
        self._result_retries = result_retries
        self._max_sub_agent_depth = max_sub_agent_depth

        # Security configuration
        self._secret_patterns = secret_patterns
        self._secrets_to_redact = secrets_to_redact

        # Tool registry
        self._tool_registry = ToolRegistry()
        for native_tool in tools or []:
            self._tool_registry.register(native_tool)

        # Each sub-agent is exposed to the model as a tool of its own.
        for sub_agent in sub_agents or []:
            self._tool_registry.register(
                SubAgentDefinition(
                    agent=sub_agent,
                    name=sub_agent._name,
                    max_depth=max_sub_agent_depth,
                )
            )

        # Toolsets (MCP servers, etc.) are only stored here; they are connected
        # per run inside run()/run_stream().
        self._toolsets = toolsets or []

    @staticmethod
    def _parse_model_string(model_string: str) -> ModelConfig:
        """Parse a 'provider:model' string into a ModelConfig.

        The provider part is lower-cased. The string is split on the first
        colon only, because model names may themselves contain colons.

        Args:
            model_string: String in format "provider:model".

        Returns:
            ModelConfig instance.

        Raises:
            ValueError: If there is no colon, or either part is empty.
        """
        if ":" not in model_string:
            raise ValueError(
                f"Invalid model string '{model_string}'. "
                "Expected format: 'provider:model' (e.g., 'openai:gpt-4')"
            )

        provider_part, _, model = model_string.partition(":")
        provider = provider_part.lower()

        if not provider:
            raise ValueError(
                f"Invalid model string '{model_string}'. "
                "provider must be non-empty."
            )

        if not model:
            raise ValueError(
                f"Invalid model string '{model_string}'. "
                "model must be non-empty."
            )

        return ModelConfig(provider=provider, model=model)

    @property
    def name(self) -> str | None:
        """Get the agent name."""
        return self._name

    @property
    def model(self) -> str:
        """Get the model name (the part after the first colon)."""
        return self._model_config.model

    @property
    def provider(self) -> str:
        """Get the (lower-cased) provider name."""
        return self._model_config.provider

    @property
    def model_string(self) -> str:
        """Get the full model string exactly as passed to the constructor."""
        return self._model_string

    @property
    def model_config(self) -> ModelConfig:
        """Get the model configuration."""
        return self._model_config

    @property
    def deps_type(self) -> type[DepsT] | None:
        """Get the deps type."""
        return self._deps_type

    @property
    def output_type(self) -> type[OutputT] | None:
        """Get the output type."""
        return self._output_type

    @property
    def system_prompt(self) -> str | None:
        """Get the system prompt."""
        return self._system_prompt

    @property
    def retries(self) -> int:
        """Get the number of retries."""
        return self._retries

    @property
    def result_retries(self) -> int:
        """Get the number of result retries."""
        return self._result_retries

    @property
    def secret_patterns(self) -> list[str] | None:
        """Get the secret patterns."""
        return self._secret_patterns

    @property
    def secrets_to_redact(self) -> dict[str, str] | None:
        """Get the secrets to redact."""
        return self._secrets_to_redact

    @property
    def tools(self) -> list[ToolDefinition]:
        """Get a copy of the registered (native and sub-agent) tools."""
        return list(self._tool_registry.list())

    @property
    def has_tools(self) -> bool:
        """Check if any tools are registered in the native registry."""
        return len(self._tool_registry.list()) > 0

    @property
    def toolsets(self) -> list[Any]:
        """Get the toolsets (MCP servers, etc.)."""
        return self._toolsets

    def _get_all_tools(self, mcp_tools: list[ToolDefinition] | None = None) -> list[ToolDefinition]:
        """Get all tools including MCP tools.

        Args:
            mcp_tools: Optional list of MCP tools to include.

        Returns:
            A new list: registered tools first, then MCP tools.
        """
        all_tools = list(self._tool_registry.list())
        if mcp_tools:
            all_tools.extend(mcp_tools)
        return all_tools

    def _has_any_tools(self, mcp_tools: list[ToolDefinition] | None = None) -> bool:
        """Check if any tools are available (native or MCP).

        Args:
            mcp_tools: Optional list of MCP tools.

        Returns:
            True if any tools are available.
        """
        return self.has_tools or bool(mcp_tools)

    def tool(self, fn: Callable[..., Any]) -> ToolDefinition:
        """Decorator to register a tool from a function.

        This method can be used as a decorator to register a function as a tool:

            @agent.tool
            def my_function(x: int) -> str:
                '''Description of the tool.'''
                return str(x)

        Note that the decorated name is rebound to the returned ToolDefinition,
        not to the original function.

        Args:
            fn: The function to convert to a ToolDefinition. Its name becomes
                the tool name, the first docstring line its description, and
                its signature/type hints its parameter schema.

        Returns:
            ToolDefinition: The created tool definition.

        Raises:
            ValueError: If ``fn.__name__`` is not a valid ASCII identifier.
        """
        tool_name = fn.__name__
        if not self._TOOL_NAME_PATTERN.fullmatch(tool_name):
            raise ValueError(f"Invalid tool name: {tool_name}")

        # Only the first line of the docstring is used as the description.
        tool_description = ""
        if fn.__doc__:
            tool_description = fn.__doc__.strip().split("\n")[0].strip()

        tool_def = ToolDefinition(
            name=tool_name,
            description=tool_description,
            parameters=self._build_parameters_schema(fn),
            execute=fn,
            is_async=inspect.iscoroutinefunction(fn),
        )

        self._tool_registry.register(tool_def)
        return tool_def

    def register_tool(self, tool_def: ToolDefinition) -> None:
        """Register a tool definition.

        Args:
            tool_def: The ToolDefinition to register.

        Raises:
            ToolAlreadyRegisteredError: If a tool with the same name is already registered.
        """
        self._tool_registry.register(tool_def)

    @staticmethod
    def _build_parameters_schema(fn: Callable[..., Any]) -> dict[str, Any]:
        """Build a JSON Schema "object" describing a function's parameters.

        Parameters named ``context``/``ctx`` (the tool context) and ``*args`` /
        ``**kwargs`` are skipped. Parameters without a default are listed in
        ``required``; non-None defaults are copied into the schema.

        Args:
            fn: The function to extract parameters from.

        Returns:
            JSON Schema dict for the function parameters.
        """
        from typing import get_type_hints

        sig = inspect.signature(fn)
        try:
            hints: dict[str, Any] = get_type_hints(fn)
        except Exception:
            # Unresolvable forward references or non-function callables:
            # fall back to the raw annotations per parameter below.
            hints = {}

        properties: dict[str, Any] = {}
        required: list[str] = []

        for param_name, param in sig.parameters.items():
            # Skip 'context' or 'ctx' parameter (ToolContext)
            if param_name in ("context", "ctx"):
                continue

            # Skip *args and **kwargs
            if param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            if param_name in hints:
                type_hint = hints[param_name]
            elif param.annotation is inspect.Parameter.empty or isinstance(
                param.annotation, str
            ):
                # Missing or unresolvable string annotation: accept anything.
                type_hint = Any
            else:
                type_hint = param.annotation

            properties[param_name] = Agent._type_to_json_schema(type_hint)

            if param.default is inspect.Parameter.empty:
                required.append(param_name)
            elif param.default is not None:
                properties[param_name]["default"] = param.default

        return {
            "type": "object",
            "properties": properties,
            "required": required,
        }

    @staticmethod
    def _type_to_json_schema(type_hint: Any) -> dict[str, Any]:
        """Convert a Python type hint to a JSON Schema fragment.

        Supported: str/int/float/bool, None, ``Optional[X]`` / ``X | None``
        (schema of X plus ``"nullable": True``), other unions (``anyOf``),
        ``Literal[...]`` (``enum``), list/set/frozenset/tuple (``array``),
        dict (``object``) and ``Any`` (empty schema). Anything else maps to
        ``{"type": "object"}``.
        """
        import types
        from typing import Literal, Union, get_args, get_origin

        none_type = type(None)

        # Handle None / NoneType (raw annotations may carry a bare ``None``).
        if type_hint is none_type or type_hint is None:
            return {"type": "null"}

        # Handle basic types. Identity checks (not ``in``/hashing) so that
        # unusual, unhashable annotation objects cannot raise here.
        for py_type, json_type in (
            (str, "string"),
            (int, "integer"),
            (float, "number"),
            (bool, "boolean"),
        ):
            if type_hint is py_type:
                return {"type": json_type}

        origin = get_origin(type_hint)
        args = get_args(type_hint)

        # ``X | Y`` (PEP 604) has origin ``types.UnionType`` rather than
        # ``typing.Union`` before Python 3.14; treat both the same way.
        pep604_union = getattr(types, "UnionType", None)
        if origin is Union or (pep604_union is not None and origin is pep604_union):
            non_none_args = [a for a in args if a is not none_type]
            if len(non_none_args) == 1 and none_type in args:
                # It's Optional[X]
                inner_schema = Agent._type_to_json_schema(non_none_args[0])
                inner_schema["nullable"] = True
                return inner_schema
            # General Union - use anyOf
            return {"anyOf": [Agent._type_to_json_schema(a) for a in args]}

        if origin is Literal:
            literal_schema: dict[str, Any] = {"enum": list(args)}
            if args and all(isinstance(a, str) for a in args):
                literal_schema["type"] = "string"
            return literal_schema

        # Bare builtins (``list``) have no origin; parametrized ones do.
        container = origin if origin is not None else type_hint

        if container is list or container is set or container is frozenset:
            if args:
                return {"type": "array", "items": Agent._type_to_json_schema(args[0])}
            return {"type": "array"}

        if container is tuple:
            # Only homogeneous ``tuple[X, ...]`` maps to a single item schema.
            if len(args) == 2 and args[1] is Ellipsis:
                return {"type": "array", "items": Agent._type_to_json_schema(args[0])}
            return {"type": "array"}

        if container is dict:
            return {"type": "object"}

        # Handle Any
        if type_hint is Any:
            return {}

        # Default to object for unknown types
        return {"type": "object"}

    # =========================================================================
    # Provider Methods
    # =========================================================================

    def _get_provider(self) -> BaseProvider:
        """Get the provider for this agent's model.

        Looks up ``self.model_string`` in the default provider registry.
        Errors raised by the registry lookup propagate unchanged.

        Returns:
            The provider instance.
        """
        from voxagent.providers.registry import get_default_registry

        registry = get_default_registry()
        return registry.get_provider(self._model_string)

    async def _save_session(self, session_key: str, messages: list[Message]) -> None:
        """Save session messages to storage.

        Currently a no-op placeholder: nothing is persisted.

        Args:
            session_key: The session key.
            messages: The messages to save.
        """
        # Session persistence is expected to be handled by the session
        # storage layer; nothing is written here yet.
        pass

    # =========================================================================
    # Run helpers
    # =========================================================================

    def _build_messages(
        self, prompt: str, message_history: list[Message] | None
    ) -> list[Message]:
        """Build the initial conversation: system prompt, history, user prompt."""
        messages: list[Message] = []
        if self._system_prompt:
            messages.append(Message(role="system", content=self._system_prompt))
        if message_history:
            messages.extend(message_history)
        messages.append(Message(role="user", content=prompt))
        return messages

    @staticmethod
    def _tool_schemas(
        all_tools: list[ToolDefinition], has_any_tools: bool
    ) -> list[Any] | None:
        """Return provider tool schemas, or None when no tools are available."""
        if not has_any_tools:
            return None
        return [t.to_openai_schema() for t in all_tools]

    @staticmethod
    async def _execute_tool_call(
        tc: ToolCall,
        *,
        all_tools: list[ToolDefinition],
        abort_signal: AbortSignal,
        deps: Any,
        session_key: str | None,
        run_id: str,
    ) -> tuple[Any, ToolMeta]:
        """Execute one tool call and time it.

        Returns:
            ``(result, meta)`` where ``result`` is whatever ``execute_tool``
            returns (its ``content``/``is_error`` attributes are read here) and
            ``meta`` records name, id, wall-clock duration and success.
        """
        start_time = time.monotonic()
        result = await execute_tool(
            name=tc.name,
            params=tc.params,
            tools=all_tools,
            abort_signal=abort_signal,
            tool_use_id=tc.id,
            deps=deps,
            session_id=session_key,
            run_id=run_id,
        )
        execution_time_ms = int((time.monotonic() - start_time) * 1000)
        meta = ToolMeta(
            tool_name=tc.name,
            tool_call_id=tc.id,
            execution_time_ms=execution_time_ms,
            success=not result.is_error,
            error=result.content if result.is_error else None,
        )
        return result, meta

    @staticmethod
    def _tool_result_message(tc: ToolCall, result: Any) -> Message:
        """Wrap a tool result as a user message with a ``tool_result`` block."""
        return Message(
            role="user",
            content=[
                {
                    "type": "tool_result",
                    "tool_use_id": tc.id,
                    "tool_name": tc.name,
                    "content": result.content,
                    "is_error": result.is_error,
                }
            ],
        )

    # =========================================================================
    # Run Methods
    # =========================================================================

    async def run(
        self,
        prompt: str,
        *,
        deps: DepsT | None = None,
        session_key: str | None = None,
        message_history: list[Message] | None = None,
        timeout_ms: int | None = None,
    ) -> RunResult:
        """Run the agent with a prompt.

        Streams model turns, executing requested tool calls and feeding their
        results back, until a turn requests no tools, the run is aborted or
        times out, or the provider reports an error.

        Args:
            prompt: The user prompt to process.
            deps: Optional dependencies to inject into tools.
            session_key: Optional session key for persistence.
            message_history: Optional message history to prepend.
            timeout_ms: Optional timeout in milliseconds (falsy means none).

        Returns:
            RunResult containing messages, outputs, and metadata. Provider
            errors are reported via ``error``. Unexpected exceptions are also
            reported via ``error``, together with whatever messages, texts and
            tool metas had been collected before the failure.
        """
        run_id = str(uuid.uuid4())
        abort_controller = AbortController()
        timeout_handler: TimeoutHandler | None = None
        timed_out = False
        error_message: str | None = None
        mcp_manager: MCPServerManager | None = None
        # Declared before ``try`` so the error path can report partial progress.
        messages: list[Message] = []
        assistant_texts: list[str] = []
        tool_metas: list[ToolMeta] = []

        if timeout_ms:
            timeout_handler = TimeoutHandler(timeout_ms)
            await timeout_handler.start(abort_controller)

        try:
            # Connect to MCP servers if any. The manager is bound before
            # connecting so ``finally`` can disconnect a partial connection.
            mcp_tools: list[ToolDefinition] = []
            if self._toolsets:
                mcp_manager = MCPServerManager()
                await mcp_manager.add_servers(self._toolsets)
                mcp_tools = await mcp_manager.connect_all()

            all_tools = self._get_all_tools(mcp_tools)
            has_any_tools = self._has_any_tools(mcp_tools)

            messages = self._build_messages(prompt, message_history)
            provider = self._get_provider()

            # Inference loop
            while not abort_controller.signal.aborted:
                response_text = ""
                tool_calls: list[ToolCall] = []
                stream_failed = False

                try:
                    async for chunk in provider.stream(
                        messages=messages,
                        system=self._system_prompt,
                        tools=self._tool_schemas(all_tools, has_any_tools),
                        abort_signal=abort_controller.signal,
                    ):
                        if isinstance(chunk, TextDeltaChunk):
                            response_text += chunk.delta
                        elif isinstance(chunk, ToolUseChunk):
                            tool_calls.append(chunk.tool_call)
                        elif isinstance(chunk, ErrorChunk):
                            error_message = chunk.error
                            stream_failed = True
                            break
                        elif isinstance(chunk, MessageEndChunk):
                            break
                except Exception as e:
                    error_message = str(e)
                    break

                if stream_failed:
                    # Tool calls from a failed turn must not be executed, and
                    # leaving them on the assistant message would leave tool
                    # uses without matching results in the history.
                    tool_calls = []

                if response_text:
                    assistant_texts.append(response_text)

                messages.append(
                    Message(
                        role="assistant",
                        content=response_text,
                        tool_calls=tool_calls or None,
                    )
                )

                # Stop on provider error, or when the model requested no tools.
                if stream_failed or not tool_calls:
                    break

                for tc in tool_calls:
                    result, meta = await self._execute_tool_call(
                        tc,
                        all_tools=all_tools,
                        abort_signal=abort_controller.signal,
                        deps=deps,
                        session_key=session_key,
                        run_id=run_id,
                    )
                    tool_metas.append(meta)
                    messages.append(self._tool_result_message(tc, result))

            if timeout_handler and timeout_handler.expired:
                timed_out = True

            if session_key:
                await self._save_session(session_key, messages)

            return RunResult(
                messages=messages,
                assistant_texts=assistant_texts,
                tool_metas=tool_metas,
                aborted=abort_controller.signal.aborted and not timed_out,
                timed_out=timed_out,
                error=error_message,
            )

        except Exception as e:
            return RunResult(
                messages=messages,
                assistant_texts=assistant_texts,
                tool_metas=tool_metas,
                aborted=False,
                timed_out=False,
                error=str(e),
            )

        finally:
            if mcp_manager:
                await mcp_manager.disconnect_all()
            if timeout_handler:
                timeout_handler.cancel()
            abort_controller.cleanup()

    async def run_stream(
        self,
        prompt: str,
        *,
        deps: DepsT | None = None,
        session_key: str | None = None,
        message_history: list[Message] | None = None,
        timeout_ms: int | None = None,
    ) -> AsyncIterator[StreamEventData]:
        """Run the agent with streaming events.

        Emits RunStartEvent first, then TextDeltaEvent / ToolStartEvent /
        ToolEndEvent as the run progresses, RunErrorEvent on errors, and
        RunEndEvent last.

        Args:
            prompt: The user prompt to process.
            deps: Optional dependencies to inject into tools.
            session_key: Optional session key for persistence.
            message_history: Optional message history to prepend.
            timeout_ms: Optional timeout in milliseconds (falsy means none).

        Yields:
            StreamEventData events for the run lifecycle.
        """
        run_id = str(uuid.uuid4())
        session = session_key or f"ephemeral-{run_id}"
        abort_controller = AbortController()
        timeout_handler: TimeoutHandler | None = None
        timed_out = False
        mcp_manager: MCPServerManager | None = None
        # Declared before ``try`` so the error path can report partial progress.
        messages: list[Message] = []

        if timeout_ms:
            timeout_handler = TimeoutHandler(timeout_ms)
            await timeout_handler.start(abort_controller)

        try:
            # Emitted before any fallible setup so that every RunErrorEvent
            # is preceded by a RunStartEvent.
            yield RunStartEvent(run_id=run_id, session_key=session)

            mcp_tools: list[ToolDefinition] = []
            if self._toolsets:
                mcp_manager = MCPServerManager()
                await mcp_manager.add_servers(self._toolsets)
                mcp_tools = await mcp_manager.connect_all()

            all_tools = self._get_all_tools(mcp_tools)
            has_any_tools = self._has_any_tools(mcp_tools)

            messages = self._build_messages(prompt, message_history)
            provider = self._get_provider()

            # Inference loop
            while not abort_controller.signal.aborted:
                response_text = ""
                tool_calls: list[ToolCall] = []
                stream_failed = False

                try:
                    async for chunk in provider.stream(
                        messages=messages,
                        system=self._system_prompt,
                        tools=self._tool_schemas(all_tools, has_any_tools),
                        abort_signal=abort_controller.signal,
                    ):
                        if isinstance(chunk, TextDeltaChunk):
                            response_text += chunk.delta
                            yield TextDeltaEvent(run_id=run_id, delta=chunk.delta)
                        elif isinstance(chunk, ToolUseChunk):
                            tool_calls.append(chunk.tool_call)
                        elif isinstance(chunk, ErrorChunk):
                            yield RunErrorEvent(run_id=run_id, error=chunk.error)
                            stream_failed = True
                            break
                        elif isinstance(chunk, MessageEndChunk):
                            break
                except Exception as e:
                    yield RunErrorEvent(run_id=run_id, error=str(e))
                    break

                if stream_failed:
                    # Never execute tool calls from a failed turn (see run()).
                    tool_calls = []

                messages.append(
                    Message(
                        role="assistant",
                        content=response_text,
                        tool_calls=tool_calls or None,
                    )
                )

                if stream_failed or not tool_calls:
                    break

                for tc in tool_calls:
                    yield ToolStartEvent(run_id=run_id, tool_call=tc)

                    result, _meta = await self._execute_tool_call(
                        tc,
                        all_tools=all_tools,
                        abort_signal=abort_controller.signal,
                        deps=deps,
                        session_key=session_key,
                        run_id=run_id,
                    )

                    yield ToolEndEvent(
                        run_id=run_id,
                        tool_call_id=tc.id,
                        result=result,
                    )

                    messages.append(self._tool_result_message(tc, result))

            if timeout_handler and timeout_handler.expired:
                timed_out = True

            yield RunEndEvent(
                run_id=run_id,
                messages=messages,
                aborted=abort_controller.signal.aborted and not timed_out,
                timed_out=timed_out,
            )

        except Exception as e:
            yield RunErrorEvent(run_id=run_id, error=str(e))
            yield RunEndEvent(
                run_id=run_id,
                messages=messages,
                aborted=False,
                timed_out=False,
            )

        finally:
            if mcp_manager:
                await mcp_manager.disconnect_all()
            if timeout_handler:
                timeout_handler.cancel()
            abort_controller.cleanup()

    def run_sync(
        self,
        prompt: str,
        *,
        deps: DepsT | None = None,
        session_key: str | None = None,
        message_history: list[Message] | None = None,
        timeout_ms: int | None = None,
    ) -> RunResult:
        """Synchronous wrapper for run().

        Without a running event loop in the current thread, the run executes
        via ``asyncio.run``. If a loop is already running (e.g. inside async
        code or a notebook), the run executes on a fresh loop in a worker
        thread. In that case this call blocks the calling thread, and therefore
        its running loop, until the run finishes.

        Args:
            prompt: The user prompt to process.
            deps: Optional dependencies to inject into tools.
            session_key: Optional session key for persistence.
            message_history: Optional message history to prepend.
            timeout_ms: Optional timeout in milliseconds.

        Returns:
            RunResult containing messages, outputs, and metadata.
        """
        coro = self.run(
            prompt,
            deps=deps,
            session_key=session_key,
            message_history=message_history,
            timeout_ms=timeout_ms,
        )

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running in this thread: create one.
            return asyncio.run(coro)

        # asyncio.run() refuses to start while a loop is running in this
        # thread, so run the coroutine on a new loop in a worker thread.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()
|
|
902
|
+
|