agnt5 0.3.0a8__cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agnt5 might be problematic. Click here for more details.
- agnt5/__init__.py +119 -0
- agnt5/_compat.py +16 -0
- agnt5/_core.abi3.so +0 -0
- agnt5/_retry_utils.py +172 -0
- agnt5/_schema_utils.py +312 -0
- agnt5/_sentry.py +515 -0
- agnt5/_telemetry.py +191 -0
- agnt5/agent/__init__.py +48 -0
- agnt5/agent/context.py +458 -0
- agnt5/agent/core.py +1793 -0
- agnt5/agent/decorator.py +112 -0
- agnt5/agent/handoff.py +105 -0
- agnt5/agent/registry.py +68 -0
- agnt5/agent/result.py +39 -0
- agnt5/checkpoint.py +246 -0
- agnt5/client.py +1478 -0
- agnt5/context.py +210 -0
- agnt5/entity.py +1230 -0
- agnt5/events.py +566 -0
- agnt5/exceptions.py +102 -0
- agnt5/function.py +325 -0
- agnt5/lm.py +1033 -0
- agnt5/memory.py +521 -0
- agnt5/tool.py +657 -0
- agnt5/tracing.py +196 -0
- agnt5/types.py +110 -0
- agnt5/version.py +19 -0
- agnt5/worker.py +1982 -0
- agnt5/workflow.py +1584 -0
- agnt5-0.3.0a8.dist-info/METADATA +26 -0
- agnt5-0.3.0a8.dist-info/RECORD +32 -0
- agnt5-0.3.0a8.dist-info/WHEEL +5 -0
agnt5/agent/core.py
ADDED
|
@@ -0,0 +1,1793 @@
|
|
|
1
|
+
"""Agent class - core LLM-driven agent with tool orchestration."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
from ..context import Context, get_current_context, set_current_context
|
|
9
|
+
from .. import lm
|
|
10
|
+
from ..lm import GenerateRequest, GenerateResponse, LanguageModel, Message, ModelConfig, ToolDefinition
|
|
11
|
+
from ..tool import Tool, ToolRegistry
|
|
12
|
+
from .._telemetry import setup_module_logger
|
|
13
|
+
from ..exceptions import WaitingForUserInputException
|
|
14
|
+
from ..events import Event, EventType
|
|
15
|
+
|
|
16
|
+
from .context import AgentContext
|
|
17
|
+
from .result import AgentResult
|
|
18
|
+
from .handoff import Handoff
|
|
19
|
+
from .registry import AgentRegistry
|
|
20
|
+
|
|
21
|
+
logger = setup_module_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _serialize_tool_result(result: Any) -> str:
    """Serialize a tool result to a JSON string.

    Pydantic models (v2 ``model_dump`` / v1 ``dict``) and dataclass instances
    are converted to plain data before encoding. The same conversion is also
    applied when such objects appear nested inside lists, dicts or other
    containers, so e.g. a list of dataclasses serializes instead of raising.

    Args:
        result: The tool execution result (may be Pydantic model, dataclass, dict, etc.)

    Returns:
        JSON string representation of the result

    Raises:
        TypeError: If the result contains a value that is neither JSON-native
            nor one of the supported model/dataclass kinds.
    """
    import dataclasses as dc

    def _to_plain(value: Any) -> Any:
        """``json.dumps`` ``default`` hook: convert one non-JSON-native value."""
        # Classes expose model_dump/dict as plain functions; only instances
        # can be converted, so classes fall through to the TypeError below.
        if not isinstance(value, type):
            if hasattr(value, 'model_dump'):
                return value.model_dump()
            if hasattr(value, 'dict') and hasattr(value, '__fields__'):
                return value.dict()
            if dc.is_dataclass(value):
                return dc.asdict(value)
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    if result is None:
        return "null"

    # Handle Pydantic models (v2 API)
    if hasattr(result, 'model_dump'):
        return json.dumps(result.model_dump(), default=_to_plain)

    # Handle Pydantic models (v1 API)
    if hasattr(result, 'dict') and hasattr(result, '__fields__'):
        return json.dumps(result.dict(), default=_to_plain)

    # Handle dataclasses (instances only; is_dataclass is also true for classes)
    if dc.is_dataclass(result) and not isinstance(result, type):
        return json.dumps(dc.asdict(result), default=_to_plain)

    # Default JSON serialization; the hook only fires for nested values that
    # the encoder cannot handle natively.
    return json.dumps(result, default=_to_plain)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class _StreamedLMResponse:
    """Final outcome of one streamed LLM call: the collected text plus any tool calls."""

    # Response text collected over the whole call.
    text: str
    # Tool calls requested by the model; an empty list when there are none.
    tool_calls: List[Dict[str, Any]]
    # Optional usage counters; presumably token counts keyed by name - TODO confirm.
    usage: Optional[Dict[str, int]] = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Agent:
|
|
62
|
+
"""Autonomous LLM-driven agent with tool orchestration.
|
|
63
|
+
|
|
64
|
+
Current features:
|
|
65
|
+
- LLM integration (OpenAI, Anthropic, etc.)
|
|
66
|
+
- Tool selection and execution
|
|
67
|
+
- Multi-turn reasoning
|
|
68
|
+
- Context and state management
|
|
69
|
+
|
|
70
|
+
Future enhancements:
|
|
71
|
+
- Durable execution with checkpointing
|
|
72
|
+
- Multi-agent coordination
|
|
73
|
+
- Platform-backed tool execution
|
|
74
|
+
|
|
75
|
+
Example:
|
|
76
|
+
```python
|
|
77
|
+
from agnt5 import Agent, tool
|
|
78
|
+
|
|
79
|
+
@tool
|
|
80
|
+
async def search_web(query: str) -> str:
|
|
81
|
+
'''Search the web for information.'''
|
|
82
|
+
return f"Results for: {query}"
|
|
83
|
+
|
|
84
|
+
agent = Agent(
|
|
85
|
+
name="researcher",
|
|
86
|
+
model="openai/gpt-4o-mini",
|
|
87
|
+
instructions="You are a research assistant.",
|
|
88
|
+
tools=[search_web],
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
result = await agent.run_sync("Find recent AI developments")
|
|
92
|
+
print(result.output)
|
|
93
|
+
```
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
def __init__(
    self,
    name: str,
    model: Union[str, LanguageModel],
    instructions: str,
    tools: Optional[List[Any]] = None,
    model_config: Optional[ModelConfig] = None,
    handoffs: Optional[List[Union["Agent", Handoff]]] = None,
    # Legacy parameters (kept for backward compatibility)
    model_name: Optional[str] = None,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    max_iterations: int = 10,
):
    """Initialize agent.

    Args:
        name: Agent identifier
        model: Model specification. Either:
            - String like "openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"
            - LanguageModel instance (legacy, for backward compatibility)
        instructions: System prompt for the agent
        tools: List of tools, Tool instances, or Agents (used as tools)
        model_config: Model configuration (temperature, max_tokens, etc.)
        handoffs: List of agents to hand off to (creates transfer_to_* tools)
        model_name: Deprecated - use `model` parameter instead
        temperature: LLM temperature (0-1). Legacy parameter - prefer model_config.
        max_tokens: Maximum tokens in response. Legacy parameter - prefer model_config.
        top_p: Top-p sampling. Legacy parameter - prefer model_config.
        max_iterations: Maximum reasoning iterations

    Raises:
        ValueError: If ``model`` is neither a string nor a LanguageModel.
    """
    self.name = name
    self.instructions = instructions
    self.max_iterations = max_iterations
    self.logger = logging.getLogger(f"agnt5.agent.{name}")

    # Handle model parameter: string or LanguageModel
    if isinstance(model, str):
        # New API: model is a string like "openai/gpt-4o-mini"
        self.model = model
        self.model_name = model  # For compatibility
        self._language_model = None
    elif isinstance(model, LanguageModel):
        # Legacy API: model is a LanguageModel instance
        self._language_model = model
        self.model = model_name or "mock-model"
        self.model_name = model_name or "mock-model"
    else:
        raise ValueError("model must be a string (e.g., 'openai/gpt-4o-mini') or LanguageModel instance")

    # Model configuration: the config object and the legacy scalar
    # parameters are both stored as given; nothing is merged here.
    self.model_config = model_config
    self.temperature = temperature
    self.max_tokens = max_tokens
    self.top_p = top_p

    # Cost tracking
    self._cumulative_cost_usd: float = 0.0

    # Initialize tools registry
    self.tools: Dict[str, Tool] = {}

    if tools:
        for item in tools:
            if isinstance(item, Tool):
                self.tools[item.name] = item
            elif isinstance(item, Agent):
                # Agent as tool - wrap it
                agent_tool = item.to_tool()
                self.tools[agent_tool.name] = agent_tool
                self.logger.debug(f"Wrapped agent '{item.name}' as tool")
            elif callable(item):
                # Function decorated with @tool: resolve it by function name.
                # Not every callable has __name__ (functools.partial objects,
                # callable instances); those are warned about and skipped
                # rather than aborting construction with AttributeError.
                func_name = getattr(item, "__name__", None)
                tool_instance = ToolRegistry.get(func_name) if func_name else None
                if tool_instance:
                    self.tools[tool_instance.name] = tool_instance
                else:
                    label = func_name if func_name else repr(item)
                    self.logger.warning(f"Tool '{label}' not found in registry")
            else:
                self.logger.warning(f"Skipping unknown tool type: {type(item)}")

    # Store handoffs for introspection
    self.handoffs: List[Handoff] = []

    # Process handoffs: create transfer_to_* tools for each target agent
    if handoffs:
        for item in handoffs:
            if isinstance(item, Agent):
                # Auto-wrap Agent in Handoff with defaults
                handoff_config = Handoff(agent=item)
            elif isinstance(item, Handoff):
                handoff_config = item
            else:
                self.logger.warning(f"Skipping unknown handoff type: {type(item)}")
                continue

            # Store the handoff configuration
            self.handoffs.append(handoff_config)

            # Create handoff tool
            handoff_tool = self._create_handoff_tool(handoff_config)
            self.tools[handoff_tool.name] = handoff_tool
            self.logger.debug(f"Added handoff tool '{handoff_tool.name}'")

    # Auto-register agent in registry (similar to Entity auto-registration)
    AgentRegistry.register(self)
    self.logger.debug(f"Auto-registered agent '{self.name}'")
|
|
204
|
+
|
|
205
|
+
@property
def cumulative_cost_usd(self) -> float:
    """Running total, in USD, of the LLM call costs recorded on this agent.

    Returns:
        The current value of the agent's cost accumulator.
    """
    total_usd = self._cumulative_cost_usd
    return total_usd
|
|
213
|
+
|
|
214
|
+
def _track_llm_cost(self, response: GenerateResponse, workflow_ctx: Optional[Any] = None) -> None:
    """Add one LLM call's cost to the running total and optionally report it.

    Args:
        response: LLM response; its ``cost_usd`` and ``usage`` attributes are
            read if present.
        workflow_ctx: Optional workflow context; when given, an
            ``agent.llm_cost`` checkpoint is sent through it.
    """
    call_cost = getattr(response, 'cost_usd', None)
    if not call_cost:
        # Missing or zero cost: nothing to accumulate or report.
        return

    self._cumulative_cost_usd += call_cost
    self.logger.debug(
        f"LLM call cost: ${call_cost:.6f}, "
        f"cumulative: ${self._cumulative_cost_usd:.6f}"
    )

    if not workflow_ctx:
        return

    # Emit cost event for observability
    token_usage = getattr(response, 'usage', None)
    payload = {
        "agent.name": self.name,
        "call_cost_usd": call_cost,
        "cumulative_cost_usd": self._cumulative_cost_usd,
        "input_tokens": token_usage.get("input_tokens") if token_usage else None,
        "output_tokens": token_usage.get("output_tokens") if token_usage else None,
    }
    workflow_ctx._send_checkpoint("agent.llm_cost", payload)
|
|
239
|
+
|
|
240
|
+
def to_tool(self) -> Tool:
    """Expose this agent as a tool that other agents can call.

    A coroutine that forwards a message to this agent's ``run_sync`` and
    returns the result's ``output`` is decorated under the name
    ``ask_<agent name>``. The tool is then looked up in ``ToolRegistry``
    under that same name and returned.

    Returns:
        Whatever ``ToolRegistry.get`` yields for ``ask_<agent name>``.

    Example:
        ```python
        # Create specialist agents
        researcher = Agent(name="researcher", ...)
        analyst = Agent(name="analyst", ...)

        # Use them as tools
        coordinator = Agent(
            name="coordinator",
            tools=[researcher.to_tool(), analyst.to_tool()]
        )
        ```
    """
    from ..tool import tool as tool_decorator

    # Bind the agent to a local so the closure below captures it.
    wrapped = self
    registered_name = f"ask_{wrapped.name}"
    summary = wrapped.instructions or f"Ask the {wrapped.name} agent for help"

    @tool_decorator(name=registered_name, description=summary)
    async def agent_as_tool(ctx: Context, message: str) -> str:
        """Invoke the agent with a message and return its response."""
        reply = await wrapped.run_sync(message, context=ctx)
        return reply.output

    # Get the tool from registry
    return ToolRegistry.get(registered_name)
|
|
277
|
+
|
|
278
|
+
def _create_handoff_tool(self, handoff: Handoff) -> Tool:
    """Build the tool that transfers control to a handoff's target agent.

    Args:
        handoff: Handoff configuration; its ``agent``, ``pass_full_history``,
            ``tool_name`` and ``description`` attributes are read.

    Returns:
        Whatever ``ToolRegistry.get`` yields for ``handoff.tool_name``.
    """
    from ..tool import tool as tool_decorator

    destination = handoff.agent
    forward_history = handoff.pass_full_history

    @tool_decorator(name=handoff.tool_name, description=handoff.description)
    async def transfer_tool(ctx: Context, message: str) -> Dict[str, Any]:
        """Transfer control to another agent.

        Args:
            ctx: Execution context (auto-injected)
            message: Message to pass to the target agent

        Returns:
            Dict with handoff marker and target agent's result
        """
        # Forward the current conversation only when requested and when the
        # context actually carries one.
        prior_messages = None
        if (
            forward_history
            and ctx
            and hasattr(ctx, '_agent_data')
            and "_current_conversation" in ctx._agent_data
        ):
            prior_messages = ctx._agent_data["_current_conversation"]

        # Non-streaming invocation of the target agent.
        outcome = await destination.run_sync(
            message,
            context=ctx,
            history=prior_messages,
        )

        # The "_handoff" marker is what the reasoning loop checks for.
        return {
            "_handoff": True,
            "to_agent": destination.name,
            "output": outcome.output,
            "tool_calls": outcome.tool_calls,
        }

    return ToolRegistry.get(handoff.tool_name)
|
|
328
|
+
|
|
329
|
+
def _render_prompt(
    self,
    template: str,
    context_vars: Optional[Dict[str, Any]] = None
) -> str:
    """Render system prompt template with context variables.

    Every ``{{variable_name}}`` placeholder whose name is a key of
    ``context_vars`` is replaced by ``str(value)``. Substitution happens in a
    single pass over the template, so text that comes from a substituted
    value is never scanned for placeholders again and the result does not
    depend on the order of ``context_vars``. Placeholders with no matching
    key are left untouched.

    Args:
        template: System prompt with {{variable_name}} placeholders
        context_vars: Variables to substitute

    Returns:
        Rendered prompt string
    """
    if not context_vars:
        return template

    import re

    substitutions = {str(key): str(value) for key, value in context_vars.items()}

    # One alternation of the literal placeholders; longest first so that a
    # key which itself contains braces cannot shadow a longer placeholder.
    ordered_keys = sorted(substitutions, key=len, reverse=True)
    pattern = re.compile(
        "|".join(re.escape("{{" + key + "}}") for key in ordered_keys)
    )

    # Each match is exactly "{{" + key + "}}", so stripping two characters
    # from both ends recovers the key.
    return pattern.sub(lambda match: substitutions[match.group(0)[2:-2]], template)
|
|
353
|
+
|
|
354
|
+
def _detect_memory_scope(
    self,
    context: Optional[Context] = None
) -> tuple[str, str]:
    """Pick the memory scope for a context.

    The context's ``user_id``, ``session_id`` and ``run_id`` attributes are
    consulted in that priority order. A ``session_id`` equal to ``run_id``
    is not treated as an explicit session.

    Returns:
        Tuple of (entity_key, scope) where:
        - entity_key: e.g., "user:user-456", "session:abc-123", "run:xyz-789"
        - scope: "user", "session", or "run"

    Example:
        entity_key, scope = agent._detect_memory_scope(ctx)
        # If ctx.user_id="user-123": ("user:user-123", "user")
        # If ctx.session_id="sess-456": ("session:sess-456", "session")
        # Otherwise: ("run:run-789", "run")
    """
    # A missing or falsy context contributes no identifiers at all.
    source = context if context else None
    user_id, session_id, run_id = (
        getattr(source, attr, None)
        for attr in ("user_id", "session_id", "run_id")
    )

    if user_id:
        return (f"user:{user_id}", "user")

    # Explicit session (not defaulting to run_id)
    if session_id and session_id != run_id:
        return (f"session:{session_id}", "session")

    if run_id:
        return (f"run:{run_id}", "run")

    # No identifier available: make up an ephemeral run key.
    import uuid
    ephemeral_id = f"agent-{self.name}-{uuid.uuid4().hex[:8]}"
    return (f"run:{ephemeral_id}", "run")
|
|
390
|
+
|
|
391
|
+
async def _run_core(
    self,
    user_message: str,
    context: Optional[Context] = None,
    history: Optional[List[Message]] = None,
    prompt_context: Optional[Dict[str, Any]] = None,
    sequence_start: int = 0,
) -> AsyncGenerator[Union[Event, AgentResult], None]:
    """Core streaming execution loop.

    This async generator yields events during execution and returns
    the final AgentResult as the last yielded item.

    Args:
        user_message: The user input appended as the newest message.
        context: Execution context. When None, the current context is used;
            anything that is not already an AgentContext is wrapped in one.
        history: Optional earlier messages (Message objects or role/content
            dicts) placed before any stored conversation history.
        prompt_context: Variables substituted into the instructions template.
        sequence_start: Sequence number assigned to the first emitted event.

    Yields:
        Event objects (LM events, tool events) during execution
        AgentResult as the final item

    Raises:
        WaitingForUserInputException: Re-raised with agent state attached
            when a tool pauses for user input.

    Used by:
    - run(): Wraps with agent.started/completed events
    - run_sync(): Consumes events and extracts final result
    """
    sequence = sequence_start

    # Create or adapt context
    if context is None:
        context = get_current_context()

    # Capture workflow context for checkpoints
    from ..workflow import WorkflowContext
    workflow_ctx = context if isinstance(context, WorkflowContext) else None

    if context is None:
        import uuid
        run_id = f"agent-{self.name}-{uuid.uuid4().hex[:8]}"
        context = AgentContext(
            run_id=run_id,
            agent_name=self.name,
        )
    elif isinstance(context, AgentContext):
        pass
    elif hasattr(context, '_workflow_entity'):
        # Only the entity key is needed here; the scope label is unused.
        entity_key, _ = self._detect_memory_scope(context)
        run_id = f"{context.run_id}:agent:{self.name}"
        detected_session_id = entity_key.split(":", 1)[1] if ":" in entity_key else context.run_id
        context = AgentContext(
            run_id=run_id,
            agent_name=self.name,
            session_id=detected_session_id,
            parent_context=context,
            runtime_context=getattr(context, '_runtime_context', None),
        )
    else:
        run_id = f"{context.run_id}:agent:{self.name}"
        context = AgentContext(
            run_id=run_id,
            agent_name=self.name,
            parent_context=context,
            runtime_context=getattr(context, '_runtime_context', None),
        )

    # Emit checkpoint if in workflow context
    if workflow_ctx is not None:
        workflow_ctx._send_checkpoint("agent.started", {
            "agent.name": self.name,
            "agent.model": self.model_name,
            "agent.tools": list(self.tools.keys()),
            "agent.max_iterations": self.max_iterations,
            "user_message": user_message,
        })

    # Check for HITL resume
    if workflow_ctx and hasattr(workflow_ctx, "_agent_resume_info"):
        resume_info = workflow_ctx._agent_resume_info
        if resume_info["agent_name"] == self.name:
            self.logger.info("Detected HITL resume, calling resume_from_hitl()")
            # Consume the resume marker so it is handled only once.
            delattr(workflow_ctx, "_agent_resume_info")
            result = await self.resume_from_hitl(
                context=workflow_ctx,
                agent_context=resume_info["agent_context"],
                user_response=resume_info["user_response"],
            )
            yield result
            return

    # Set context in task-local storage
    token = set_current_context(context)
    try:
        # Build conversation messages
        messages: List[Message] = []

        if history:
            # Convert dicts to Message objects if needed (for JSON history from platform)
            for msg in history:
                if isinstance(msg, Message):
                    messages.append(msg)
                elif isinstance(msg, dict):
                    role_str = msg.get("role", "user")
                    content = msg.get("content", "")
                    if role_str == "user":
                        messages.append(Message.user(content))
                    elif role_str == "assistant":
                        messages.append(Message.assistant(content))
                    elif role_str == "system":
                        messages.append(Message.system(content))
                    else:
                        # Unknown roles are treated as user messages.
                        messages.append(Message.user(content))
                else:
                    # Try to use it as a Message anyway
                    messages.append(msg)
            self.logger.debug(f"Prepended {len(history)} messages from explicit history")

        if isinstance(context, AgentContext):
            stored_messages = await context.get_conversation_history()
            messages.extend(stored_messages)

        messages.append(Message.user(user_message))

        if isinstance(context, AgentContext):
            # With explicit history, only the stored messages plus the new
            # user message are saved back; the explicit history is not.
            messages_to_save = stored_messages + [Message.user(user_message)] if history else messages
            await context.save_conversation_history(messages_to_save)

        # Create span for tracing
        from .._core import create_span

        with create_span(
            self.name,
            "agent",
            context._runtime_context if hasattr(context, "_runtime_context") else None,
            {
                "agent.name": self.name,
                "agent.model": self.model_name,
                "agent.max_iterations": str(self.max_iterations),
            },
        ) as span:
            all_tool_calls: List[Dict[str, Any]] = []
            import time as _time

            # Render system prompt
            rendered_instructions = self._render_prompt(self.instructions, prompt_context)

            # Reasoning loop
            for iteration in range(self.max_iterations):
                iteration_start_time = _time.time()

                if workflow_ctx:
                    workflow_ctx._send_checkpoint("agent.iteration.started", {
                        "agent.name": self.name,
                        "iteration": iteration + 1,
                        "max_iterations": self.max_iterations,
                    })

                # Build tool definitions
                tool_defs = [
                    ToolDefinition(
                        name=tool.name,
                        description=tool.description,
                        parameters=tool.input_schema,
                    )
                    for tool in self.tools.values()
                ]

                # Build request
                request = GenerateRequest(
                    model=self.model if not self._language_model else "mock-model",
                    system_prompt=rendered_instructions,
                    messages=messages,
                    tools=tool_defs if tool_defs else [],
                )
                request.config.temperature = self.temperature
                if self.max_tokens:
                    request.config.max_tokens = self.max_tokens
                if self.top_p:
                    request.config.top_p = self.top_p

                # Stream LLM call and yield events
                response_text = ""
                response_tool_calls = []

                async for item, seq in self._stream_lm_call(request, sequence):
                    if isinstance(item, _StreamedLMResponse):
                        # Final aggregate of the call; not forwarded to the caller.
                        response_text = item.text
                        response_tool_calls = item.tool_calls
                        sequence = seq
                    else:
                        # Yield LM event
                        yield item
                        sequence = seq

                # Add assistant response to messages
                messages.append(Message.assistant(response_text))

                # Check if LLM wants to use tools
                if response_tool_calls:
                    self.logger.debug(f"Agent calling {len(response_tool_calls)} tool(s)")

                    # Make the live conversation reachable from the context;
                    # handoff tools read it from context._agent_data.
                    if not hasattr(context, '_agent_data'):
                        context._agent_data = {}
                    context._agent_data["_current_conversation"] = messages

                    # Execute tool calls
                    tool_results = []
                    for tool_idx, tool_call in enumerate(response_tool_calls):
                        tool_name = tool_call["name"]
                        tool_args_str = tool_call["arguments"]
                        tool_call_id = tool_call.get("id")  # From LLM response

                        all_tool_calls.append({
                            "name": tool_name,
                            "arguments": tool_args_str,
                            "iteration": iteration + 1,
                            "id": tool_call_id,
                        })

                        # Yield tool call started event with unique content_index
                        yield Event.agent_tool_call_started(
                            tool_name=tool_name,
                            arguments=tool_args_str,
                            tool_call_id=tool_call_id,
                            content_index=tool_idx,
                            sequence=sequence,
                        )
                        sequence += 1

                        try:
                            tool_args = json.loads(tool_args_str)
                            tool = self.tools.get(tool_name)

                            if not tool:
                                result_text = f"Error: Tool '{tool_name}' not found"
                            else:
                                result = await tool.invoke(context, **tool_args)

                                if isinstance(result, dict) and result.get("_handoff"):
                                    # A handoff ends this agent's run: the target
                                    # agent's output becomes the final result.
                                    self.logger.info(f"Handoff to '{result['to_agent']}'")
                                    if isinstance(context, AgentContext):
                                        await context.save_conversation_history(messages)

                                    # Yield tool completed and final result
                                    yield Event.agent_tool_call_completed(
                                        tool_name=tool_name,
                                        result=_serialize_tool_result(result["output"]),
                                        tool_call_id=tool_call_id,
                                        content_index=tool_idx,
                                        sequence=sequence,
                                    )
                                    sequence += 1

                                    yield AgentResult(
                                        output=result["output"],
                                        tool_calls=all_tool_calls + result.get("tool_calls", []),
                                        context=context,
                                        handoff_to=result["to_agent"],
                                        handoff_metadata=result,
                                    )
                                    return

                                result_text = _serialize_tool_result(result)

                            tool_results.append({
                                "tool": tool_name,
                                "result": result_text,
                                "error": None,
                            })

                            # Yield tool completed event
                            yield Event.agent_tool_call_completed(
                                tool_name=tool_name,
                                result=result_text,
                                tool_call_id=tool_call_id,
                                content_index=tool_idx,
                                sequence=sequence,
                            )
                            sequence += 1

                        except WaitingForUserInputException as e:
                            self.logger.info(f"Agent pausing for user input at iteration {iteration}")
                            messages_dict = [
                                {"role": msg.role.value, "content": msg.content}
                                for msg in messages
                            ]
                            # Re-raise with the agent state attached so the
                            # run can be resumed later.
                            raise WaitingForUserInputException(
                                question=e.question,
                                input_type=e.input_type,
                                options=e.options,
                                checkpoint_state=e.checkpoint_state,
                                agent_context={
                                    "agent_name": self.name,
                                    "iteration": iteration,
                                    "messages": messages_dict,
                                    "tool_results": tool_results,
                                    "pending_tool_call": {
                                        "name": tool_call["name"],
                                        "arguments": tool_call["arguments"],
                                        # Use the loop position: list.index()
                                        # returns the first *equal* entry and is
                                        # wrong for duplicate tool calls.
                                        "tool_call_index": tool_idx,
                                    },
                                    "all_tool_calls": all_tool_calls,
                                    "model_config": {
                                        "model": self.model,
                                        "temperature": self.temperature,
                                        "max_tokens": self.max_tokens,
                                        "top_p": self.top_p,
                                    },
                                },
                            ) from e

                        except Exception as e:
                            # Tool failures are reported to the model, not raised.
                            self.logger.error(f"Tool execution error: {e}")
                            tool_results.append({
                                "tool": tool_name,
                                "result": None,
                                "error": str(e),
                            })
                            yield Event.agent_tool_call_completed(
                                tool_name=tool_name,
                                result=None,
                                error=str(e),
                                tool_call_id=tool_call_id,
                                content_index=tool_idx,
                                sequence=sequence,
                            )
                            sequence += 1

                    # Add tool results to conversation
                    results_text = "\n".join([
                        f"Tool: {tr['tool']}\nResult: {tr['result']}"
                        if tr["error"] is None
                        else f"Tool: {tr['tool']}\nError: {tr['error']}"
                        for tr in tool_results
                    ])
                    messages.append(Message.user(
                        f"Tool results:\n{results_text}\n\nPlease provide your final answer based on these results."
                    ))

                    iteration_duration_ms = int((_time.time() - iteration_start_time) * 1000)
                    if workflow_ctx:
                        workflow_ctx._send_checkpoint("agent.iteration.completed", {
                            "agent.name": self.name,
                            "iteration": iteration + 1,
                            "duration_ms": iteration_duration_ms,
                            "has_tool_calls": True,
                            "tool_calls_count": len(tool_results),
                        })

                else:
                    # No tool calls - agent is done
                    self.logger.debug(f"Agent completed after {iteration + 1} iterations")

                    iteration_duration_ms = int((_time.time() - iteration_start_time) * 1000)
                    if workflow_ctx:
                        workflow_ctx._send_checkpoint("agent.iteration.completed", {
                            "agent.name": self.name,
                            "iteration": iteration + 1,
                            "duration_ms": iteration_duration_ms,
                            "has_tool_calls": False,
                        })

                    if isinstance(context, AgentContext):
                        await context.save_conversation_history(messages)

                    if workflow_ctx:
                        workflow_ctx._send_checkpoint("agent.completed", {
                            "agent.name": self.name,
                            "agent.iterations": iteration + 1,
                            "agent.tool_calls_count": len(all_tool_calls),
                            "output_length": len(response_text),
                        })

                    yield AgentResult(
                        output=response_text,
                        tool_calls=all_tool_calls,
                        context=context,
                    )
                    return

            # Max iterations reached
            self.logger.warning(f"Agent reached max iterations ({self.max_iterations})")
            # NOTE(review): every iteration that falls through ends by appending
            # the "Tool results" user message, so messages[-1] here is that
            # prompt rather than an assistant answer - confirm this is intended.
            final_output = messages[-1].content if messages else "No output generated"

            if workflow_ctx:
                workflow_ctx._send_checkpoint("agent.max_iterations.reached", {
                    "agent.name": self.name,
                    "max_iterations": self.max_iterations,
                    "tool_calls_count": len(all_tool_calls),
                })

            if isinstance(context, AgentContext):
                await context.save_conversation_history(messages)

            if workflow_ctx:
                workflow_ctx._send_checkpoint("agent.completed", {
                    "agent.name": self.name,
                    "agent.iterations": self.max_iterations,
                    "agent.tool_calls_count": len(all_tool_calls),
                    "agent.max_iterations_reached": True,
                    "output_length": len(final_output),
                })

            yield AgentResult(
                output=final_output,
                tool_calls=all_tool_calls,
                context=context,
            )

    except Exception as e:
        # Report the failure to the workflow, then let it propagate unchanged.
        if workflow_ctx:
            workflow_ctx._send_checkpoint("agent.failed", {
                "agent.name": self.name,
                "error": str(e),
                "error_type": type(e).__name__,
            })
        raise
    finally:
        # Restore whatever context was current before this run.
        from ..context import _current_context
        _current_context.reset(token)
|
|
807
|
+
|
|
808
|
+
async def _stream_lm_call(
    self,
    request: GenerateRequest,
    sequence_start: int = 0,
) -> AsyncGenerator[Tuple[Event, int], None]:
    """Run one LLM call and yield its events, then the final response.

    When ``request.tools`` is non-empty the call goes through ``generate()``
    and synthetic message start/delta/stop events are emitted, because the
    streaming path does not carry tool calls. Otherwise the backend's
    ``stream()`` is consumed and its events are forwarded with renumbered
    sequence values.

    Args:
        request: The generate request with model, messages, tools, etc.
        sequence_start: Sequence number assigned to the first emitted event.

    Yields:
        ``(event, next_sequence)`` tuples for every forwarded or synthetic
        event, followed by a final ``(_StreamedLMResponse, next_sequence)``
        tuple carrying the collected text, tool calls and usage.

    Raises:
        ValueError: If no legacy language model is set and ``self.model`` is
            not of the form ``"provider/model"`` (the split cannot unpack).
    """
    from ..lm import _LanguageModel

    # Resolve the backend once; both the generate() and stream() paths use it.
    if self._language_model is not None:
        language_model = self._language_model
    else:
        provider, _ = self.model.split('/', 1)
        language_model = _LanguageModel(provider=provider.lower(), default_model=None)

    sequence = sequence_start
    collected_text = ""
    usage_dict = None
    tool_calls = []

    if request.tools:
        # Streaming doesn't support tool calls yet, so use generate().
        response = await language_model.generate(request)

        # Emit synthetic LM events so consumers see the same event shape.
        yield (Event.message_start(index=0, sequence=sequence), sequence + 1)
        sequence += 1
        if response.text:
            yield (Event.message_delta(content=response.text, index=0, sequence=sequence), sequence + 1)
            sequence += 1
        yield (Event.message_stop(index=0, sequence=sequence), sequence + 1)
        sequence += 1

        # Normalize a missing text to "" so the final response text is always a str.
        collected_text = response.text or ""
        tool_calls = response.tool_calls or []
        if response.usage:
            usage = response.usage
            usage_dict = {
                "input_tokens": getattr(usage, 'input_tokens', getattr(usage, 'prompt_tokens', 0)),
                "output_tokens": getattr(usage, 'output_tokens', getattr(usage, 'completion_tokens', 0)),
            }
    else:
        # Real streaming - forwards thinking/message events as they arrive.
        async for event in language_model.stream(request):
            if event.event_type == EventType.LM_STREAM_COMPLETED:
                # Prefer the final text from the completion event, but keep the
                # text accumulated from deltas if the event does not carry one.
                collected_text = event.data.get("text", collected_text)
                if "usage" in event.data:
                    usage_dict = event.data["usage"]
                continue

            # Forward LM events (thinking/message start/delta/stop), renumbered.
            event.sequence = sequence
            yield (event, sequence + 1)
            sequence += 1

            # Collect text from message deltas only (not thinking);
            # data is the raw content string for delta events.
            if event.event_type == EventType.LM_MESSAGE_DELTA and event.data:
                collected_text += event.data

    # Yield the final response
    yield (_StreamedLMResponse(
        text=collected_text,
        tool_calls=tool_calls,
        usage=usage_dict,
    ), sequence)
|
|
913
|
+
|
|
914
|
+
async def run(
    self,
    user_message: str,
    context: Optional[Context] = None,
    history: Optional[List[Message]] = None,
    prompt_context: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[Event, None]:
    """Execute the agent and stream its events.

    Async generator: consume it with ``async for event in agent.run(...)``.

    Args:
        user_message: User's input message.
        context: Optional execution context, passed through to the core loop.
        history: Optional conversation history to include.
        prompt_context: Optional variables for the system prompt template.

    Yields:
        An ``agent_started`` event first, then every ``Event`` produced by
        the core loop, then an ``agent_completed`` event built from the
        core loop's ``AgentResult`` (if one was produced). On error an
        ``agent_failed`` event is yielded before the exception is re-raised.

    Example:
        ```python
        # Streaming execution
        async for event in agent.run("Analyze recent tech news"):
            if event.event_type == EventType.LM_MESSAGE_DELTA:
                print(event.data, end="", flush=True)  # data is raw content for deltas
            elif event.event_type == EventType.AGENT_COMPLETED:
                print(f"\\nFinal: {event.data['output']}")

        # Non-streaming (use run_sync instead)
        result = await agent.run_sync("Analyze recent tech news")
        print(result.output)
        ```
    """
    # Sequence number to stamp on the next event this method creates.
    next_seq = 0

    started_event = Event.agent_started(
        agent_name=self.name,
        model=self.model_name,
        tools=list(self.tools.keys()),
        max_iterations=self.max_iterations,
        sequence=next_seq,
    )
    yield started_event
    next_seq += 1

    try:
        final = None
        core_stream = self._run_core(
            user_message=user_message,
            context=context,
            history=history,
            prompt_context=prompt_context,
            sequence_start=next_seq,
        )
        async for produced in core_stream:
            if isinstance(produced, AgentResult):
                # Terminal item: remember it and pick up its sequence, if it has one.
                final = produced
                next_seq = getattr(produced, '_last_sequence', next_seq)
                continue
            if not isinstance(produced, Event):
                continue

            # LM / tool event: pass it straight to the consumer.
            yield produced
            if hasattr(produced, 'sequence'):
                next_seq = produced.sequence + 1

        if final:
            # Iteration count is estimated from the number of recorded tool calls.
            if final.tool_calls:
                iteration_estimate = len(final.tool_calls) // 2 + 1
            else:
                iteration_estimate = 1
            yield Event.agent_completed(
                output=final.output,
                iterations=iteration_estimate,
                tool_calls=final.tool_calls,
                handoff_to=final.handoff_to,
                max_iterations_reached=False,
                sequence=next_seq,
            )

    except Exception as e:
        # Surface the failure as an event, then let the exception propagate.
        failure_event = Event.agent_failed(
            error=str(e),
            error_type=type(e).__name__,
            agent_name=self.name,
            sequence=next_seq,
        )
        yield failure_event
        raise
|
|
1005
|
+
|
|
1006
|
+
async def run_sync(
    self,
    user_message: str,
    context: Optional[Context] = None,
    history: Optional[List[Message]] = None,
    prompt_context: Optional[Dict[str, Any]] = None,
) -> AgentResult:
    """Drive ``run()`` to the end and return the final result.

    Despite the name this is a coroutine; it simply drains the event stream
    for callers that do not need the intermediate events.

    Args:
        user_message: User's input message.
        context: Optional execution context.
        history: Optional conversation history.
        prompt_context: Optional context variables.

    Returns:
        AgentResult assembled from the ``AGENT_COMPLETED`` event's data. Its
        ``context`` field is the ``context`` argument exactly as passed in.

    Raises:
        RuntimeError: If the stream ends without an ``AGENT_COMPLETED`` event.

    Example:
        ```python
        result = await agent.run_sync("Analyze recent tech news")
        print(result.output)
        ```
    """
    final = None
    async for evt in self.run(user_message, context, history, prompt_context):
        if evt.event_type != EventType.AGENT_COMPLETED:
            # Everything else is ignored here, including AGENT_FAILED:
            # run() re-raises the underlying error right after emitting it.
            continue

        payload = evt.data
        final = AgentResult(
            output=payload["output"],
            tool_calls=payload.get("tool_calls", []),
            context=context,
            handoff_to=payload.get("handoff_to"),
        )

    if final is not None:
        return final

    # This shouldn't happen, but fail loudly rather than return None.
    raise RuntimeError("Agent completed without producing a result")
|
|
1052
|
+
|
|
1053
|
+
async def _run_impl(
    self,
    user_message: str,
    context: Optional[Context] = None,
    history: Optional[List[Message]] = None,
    prompt_context: Optional[Dict[str, Any]] = None,
) -> AgentResult:
    """Run the non-streaming agent loop and return the final result.

    Steps: resolve/wrap the execution context into an ``AgentContext``,
    short-circuit into ``resume_from_hitl()`` when the workflow context
    carries resume info for this agent, assemble the conversation, then
    alternate LLM calls and tool executions for at most
    ``self.max_iterations`` iterations.

    Args:
        user_message: User's input message.
        context: Optional execution context; falls back to the task-local
            current context, and to a fresh ``AgentContext`` if none is set.
        history: Optional messages placed before the stored conversation.
        prompt_context: Optional variables used to render the instructions.

    Returns:
        AgentResult with the final output, the recorded tool calls and the
        ``AgentContext`` used. On a handoff the result also carries
        ``handoff_to`` / ``handoff_metadata``.

    Raises:
        WaitingForUserInputException: Re-raised with ``agent_context``
            attached when a tool pauses for user input.
        Exception: Any other error from the loop is re-raised after an
            ``agent.failed`` checkpoint is sent (when in a workflow).
    """
    # Create or adapt context
    if context is None:
        # Try to get context from task-local storage (set by workflow/function decorator)
        context = get_current_context()

    # IMPORTANT: Capture workflow context NOW before we replace it with AgentContext
    # This allows LM calls inside the agent to emit workflow checkpoints
    from ..workflow import WorkflowContext
    workflow_ctx = context if isinstance(context, WorkflowContext) else None

    if context is None:
        # Standalone execution - create AgentContext
        import uuid
        run_id = f"agent-{self.name}-{uuid.uuid4().hex[:8]}"
        context = AgentContext(
            run_id=run_id,
            agent_name=self.name,
        )
    elif isinstance(context, AgentContext):
        # Already AgentContext - use as-is
        pass
    elif hasattr(context, '_workflow_entity'):
        # WorkflowContext - create AgentContext that inherits state
        # Auto-detect memory scope based on user_id/session_id/run_id priority
        entity_key, _scope = self._detect_memory_scope(context)

        run_id = f"{context.run_id}:agent:{self.name}"
        # Extract the ID from entity_key (e.g., "session:abc-123" → "abc-123")
        detected_session_id = entity_key.split(":", 1)[1] if ":" in entity_key else context.run_id

        context = AgentContext(
            run_id=run_id,
            agent_name=self.name,
            session_id=detected_session_id,  # Use auto-detected scope
            parent_context=context,
            runtime_context=getattr(context, '_runtime_context', None),  # Inherit trace context
        )
    else:
        # FunctionContext or other - create new AgentContext
        run_id = f"{context.run_id}:agent:{self.name}"
        context = AgentContext(
            run_id=run_id,
            agent_name=self.name,
            parent_context=context,  # Inherit streaming context
            runtime_context=getattr(context, '_runtime_context', None),  # Inherit trace context
        )

    # Emit checkpoint if called within a workflow context
    if workflow_ctx is not None:
        workflow_ctx._send_checkpoint("agent.started", {
            "agent.name": self.name,
            "agent.model": self.model_name,
            "agent.tools": list(self.tools.keys()),
            "agent.max_iterations": self.max_iterations,
            "user_message": user_message,
        })

    # Check if this is a resume from HITL
    if workflow_ctx and hasattr(workflow_ctx, "_agent_resume_info"):
        resume_info = workflow_ctx._agent_resume_info
        if resume_info["agent_name"] == self.name:
            self.logger.info("Detected HITL resume, calling resume_from_hitl()")

            # Clear resume info to avoid re-entry
            delattr(workflow_ctx, "_agent_resume_info")

            # Resume from checkpoint (context setup happens inside resume_from_hitl)
            return await self.resume_from_hitl(
                context=workflow_ctx,
                agent_context=resume_info["agent_context"],
                user_response=resume_info["user_response"],
            )

    # Set context in task-local storage for automatic propagation to tools and LM calls
    token = set_current_context(context)
    try:
        # Build conversation messages
        messages: List[Message] = []

        # 1. Start with explicitly provided history (if any)
        if history:
            messages.extend(history)
            self.logger.debug(f"Prepended {len(history)} messages from explicit history")

        # 2. Load conversation history from state (if AgentContext)
        if isinstance(context, AgentContext):
            stored_messages = await context.get_conversation_history()
            messages.extend(stored_messages)

        # 3. Add new user message
        messages.append(Message.user(user_message))

        # 4. Save updated conversation to context storage
        if isinstance(context, AgentContext):
            # Only save the stored + new message (not the explicit history)
            if history:
                messages_to_save = stored_messages + [Message.user(user_message)]
            else:
                messages_to_save = messages
            await context.save_conversation_history(messages_to_save)

        # Create span for agent execution with trace linking
        from .._core import create_span

        with create_span(
            self.name,
            "agent",
            context._runtime_context if hasattr(context, "_runtime_context") else None,
            {
                "agent.name": self.name,
                "agent.model": self.model_name,  # Use model_name (always a string)
                "agent.max_iterations": str(self.max_iterations),
            },
        ):
            all_tool_calls: List[Dict[str, Any]] = []
            import time as _time

            # Emit agent started checkpoint
            # NOTE(review): an "agent.started" checkpoint with a different payload
            # was already sent earlier in this method; both are kept — confirm
            # whether consumers expect two.
            if workflow_ctx:
                workflow_ctx._send_checkpoint("agent.started", {
                    "agent.name": self.name,
                    "agent.model": self.model_name,
                    "agent.max_iterations": self.max_iterations,
                    "agent.tools_count": len(self.tools),
                })

            # Render system prompt with context variables
            rendered_instructions = self._render_prompt(self.instructions, prompt_context)
            if prompt_context:
                self.logger.debug(f"Rendered system prompt with {len(prompt_context)} context variables")

            # Reasoning loop
            for iteration in range(self.max_iterations):
                iteration_start_time = _time.time()

                # Emit iteration started checkpoint
                if workflow_ctx:
                    workflow_ctx._send_checkpoint("agent.iteration.started", {
                        "agent.name": self.name,
                        "iteration": iteration + 1,
                        "max_iterations": self.max_iterations,
                    })

                # Build tool definitions for LLM
                tool_defs = [
                    ToolDefinition(
                        name=tool.name,
                        description=tool.description,
                        parameters=tool.input_schema,
                    )
                    for tool in self.tools.values()
                ]

                # Pick the LM backend; the two APIs differ only in the backend
                # object and the model string placed on the request.
                if self._language_model is not None:
                    # Legacy API: use provided LanguageModel instance
                    language_model = self._language_model
                    request_model = "mock-model"  # Not used by MockLanguageModel
                else:
                    # New API: model is a string, create internal LM instance
                    # TODO: Use model_config when provided
                    from ..lm import _LanguageModel
                    provider, _ = self.model.split('/', 1)
                    language_model = _LanguageModel(provider=provider.lower(), default_model=None)
                    request_model = self.model

                request = GenerateRequest(
                    model=request_model,
                    system_prompt=rendered_instructions,
                    messages=messages,
                    tools=tool_defs if tool_defs else [],
                )
                request.config.temperature = self.temperature
                if self.max_tokens:
                    request.config.max_tokens = self.max_tokens
                if self.top_p:
                    request.config.top_p = self.top_p

                # Call LLM
                response = await language_model.generate(request)

                # Track cost for this LLM call
                self._track_llm_cost(response, workflow_ctx)

                # Add assistant response to messages
                messages.append(Message.assistant(response.text))

                # Check if LLM wants to use tools
                if response.tool_calls:
                    self.logger.debug(f"Agent calling {len(response.tool_calls)} tool(s)")

                    # Store current conversation in context for potential handoffs
                    # Use a simple dict attribute since we don't need full state persistence for this
                    if not hasattr(context, '_agent_data'):
                        context._agent_data = {}
                    context._agent_data["_current_conversation"] = messages

                    # Execute tool calls
                    tool_results = []
                    # enumerate() gives the true position of each call; list.index()
                    # would return the first equal entry when two calls are identical.
                    for tool_call_index, tool_call in enumerate(response.tool_calls):
                        tool_name = tool_call["name"]
                        tool_args_str = tool_call["arguments"]

                        # Track tool call
                        all_tool_calls.append(
                            {
                                "name": tool_name,
                                "arguments": tool_args_str,
                                "iteration": iteration + 1,
                            }
                        )

                        # Execute tool
                        try:
                            # Parse arguments
                            tool_args = json.loads(tool_args_str)

                            # Get tool
                            tool = self.tools.get(tool_name)
                            if not tool:
                                # Unknown tool is reported back to the LLM as a result text.
                                result_text = f"Error: Tool '{tool_name}' not found"
                            else:
                                # Execute tool
                                result = await tool.invoke(context, **tool_args)

                                # Check if this was a handoff
                                if isinstance(result, dict) and result.get("_handoff"):
                                    self.logger.info(
                                        f"Handoff detected to '{result['to_agent']}', "
                                        f"terminating current agent"
                                    )
                                    # Save conversation before returning
                                    if isinstance(context, AgentContext):
                                        await context.save_conversation_history(messages)
                                    # Return immediately with handoff result
                                    return AgentResult(
                                        output=result["output"],
                                        tool_calls=all_tool_calls + result.get("tool_calls", []),
                                        context=context,
                                        handoff_to=result["to_agent"],
                                        handoff_metadata=result,
                                    )

                                result_text = _serialize_tool_result(result)

                            tool_results.append(
                                {"tool": tool_name, "result": result_text, "error": None}
                            )

                        except WaitingForUserInputException as e:
                            # HITL PAUSE: Capture agent state and propagate exception
                            self.logger.info(f"Agent pausing for user input at iteration {iteration}")

                            # Serialize messages to dict format
                            serialized_messages = [
                                {"role": msg.role.value, "content": msg.content}
                                for msg in messages
                            ]

                            # Enhance exception with agent execution context
                            raise WaitingForUserInputException(
                                question=e.question,
                                input_type=e.input_type,
                                options=e.options,
                                checkpoint_state=e.checkpoint_state,
                                agent_context={
                                    "agent_name": self.name,
                                    "iteration": iteration,
                                    "messages": serialized_messages,
                                    "tool_results": tool_results,
                                    "pending_tool_call": {
                                        "name": tool_call["name"],
                                        "arguments": tool_call["arguments"],
                                        "tool_call_index": tool_call_index,
                                    },
                                    "all_tool_calls": all_tool_calls,
                                    "model_config": {
                                        "model": self.model,
                                        "temperature": self.temperature,
                                        "max_tokens": self.max_tokens,
                                        "top_p": self.top_p,
                                    },
                                },
                            ) from e

                        except Exception as e:
                            # Regular tool errors - log and continue
                            self.logger.error(f"Tool execution error: {e}")
                            tool_results.append(
                                {"tool": tool_name, "result": None, "error": str(e)}
                            )

                    # Add tool results to conversation
                    results_text = "\n".join(
                        [
                            f"Tool: {tr['tool']}\nResult: {tr['result']}"
                            if tr["error"] is None
                            else f"Tool: {tr['tool']}\nError: {tr['error']}"
                            for tr in tool_results
                        ]
                    )
                    messages.append(Message.user(f"Tool results:\n{results_text}\n\nPlease provide your final answer based on these results."))

                    # Emit iteration completed checkpoint (with tool calls)
                    iteration_duration_ms = int((_time.time() - iteration_start_time) * 1000)
                    if workflow_ctx:
                        workflow_ctx._send_checkpoint("agent.iteration.completed", {
                            "agent.name": self.name,
                            "iteration": iteration + 1,
                            "duration_ms": iteration_duration_ms,
                            "has_tool_calls": True,
                            "tool_calls_count": len(tool_results),
                        })

                    # Continue loop for agent to process results

                else:
                    # No tool calls - agent is done
                    self.logger.debug(f"Agent completed after {iteration + 1} iterations")

                    # Emit iteration completed checkpoint
                    iteration_duration_ms = int((_time.time() - iteration_start_time) * 1000)
                    if workflow_ctx:
                        workflow_ctx._send_checkpoint("agent.iteration.completed", {
                            "agent.name": self.name,
                            "iteration": iteration + 1,
                            "duration_ms": iteration_duration_ms,
                            "has_tool_calls": False,
                        })

                    # Save conversation before returning
                    if isinstance(context, AgentContext):
                        await context.save_conversation_history(messages)

                    # Emit completion checkpoint
                    if workflow_ctx:
                        workflow_ctx._send_checkpoint("agent.completed", {
                            "agent.name": self.name,
                            "agent.iterations": iteration + 1,
                            "agent.tool_calls_count": len(all_tool_calls),
                            # Guard against a response without text so the
                            # checkpoint itself cannot raise TypeError.
                            "output_length": len(response.text or ""),
                        })

                    return AgentResult(
                        output=response.text,
                        tool_calls=all_tool_calls,
                        context=context,
                    )

            # Max iterations reached
            self.logger.warning(f"Agent reached max iterations ({self.max_iterations})")
            final_output = messages[-1].content if messages else "No output generated"

            # Emit max iterations reached checkpoint (separate event for metrics)
            if workflow_ctx:
                workflow_ctx._send_checkpoint("agent.max_iterations.reached", {
                    "agent.name": self.name,
                    "max_iterations": self.max_iterations,
                    "tool_calls_count": len(all_tool_calls),
                })

            # Save conversation before returning
            if isinstance(context, AgentContext):
                await context.save_conversation_history(messages)

            # Emit completion checkpoint with max iterations flag
            if workflow_ctx:
                workflow_ctx._send_checkpoint("agent.completed", {
                    "agent.name": self.name,
                    "agent.iterations": self.max_iterations,
                    "agent.tool_calls_count": len(all_tool_calls),
                    "agent.max_iterations_reached": True,
                    "output_length": len(final_output),
                })

            return AgentResult(
                output=final_output,
                tool_calls=all_tool_calls,
                context=context,
            )
    except Exception as e:
        # Emit error checkpoint for observability
        # NOTE(review): if WaitingForUserInputException derives from Exception,
        # a HITL pause also passes through here and is reported as
        # "agent.failed" — confirm that is intended.
        if workflow_ctx:
            workflow_ctx._send_checkpoint("agent.failed", {
                "agent.name": self.name,
                "error": str(e),
                "error_type": type(e).__name__,
            })
        raise
    finally:
        # Always reset context to prevent leakage between agent executions
        from ..context import _current_context
        _current_context.reset(token)
|
|
1477
|
+
|
|
1478
|
+
async def resume_from_hitl(
    self,
    context: Context,
    agent_context: Dict,
    user_response: str,
) -> AgentResult:
    """
    Resume agent execution after HITL pause.

    Rebuilds the conversation from the checkpointed ``agent_context``,
    records the user's response as the successful result of the tool call
    that was pending, appends the combined tool results as a user message,
    and continues the loop at the following iteration.

    The ``agent_context`` dict is not modified: the tool-result and
    tool-call lists are copied before anything is appended, so the same
    checkpoint can be resumed again without the user response being
    injected twice.

    Args:
        context: Current execution context (workflow or agent)
        agent_context: Agent state from WaitingForUserInputException.agent_context.
            Keys read here: "messages", "iteration", "all_tool_calls",
            "tool_results", "pending_tool_call".
        user_response: User's answer to the HITL question

    Returns:
        AgentResult with final output and tool calls

    Raises:
        KeyError: If ``agent_context`` lacks one of the keys listed above.
    """
    self.logger.info(f"Resuming agent '{self.name}' from HITL pause")

    # 1. Restore conversation state
    messages = [
        Message(role=lm.MessageRole(msg["role"]), content=msg["content"])
        for msg in agent_context["messages"]
    ]
    iteration = agent_context["iteration"]
    # Copy so later appends do not leak back into the checkpoint.
    all_tool_calls = list(agent_context["all_tool_calls"])

    # 2. Restore partial tool results for current iteration (copied, see above)
    tool_results = list(agent_context["tool_results"])

    # 3. Inject user response as successful tool result
    pending_tool = agent_context["pending_tool_call"]
    tool_results.append({
        "tool": pending_tool["name"],
        "result": json.dumps(user_response),
        "error": None,
    })

    self.logger.debug(
        f"Injected user response for tool '{pending_tool['name']}': {user_response}"
    )

    # 4. Add tool results to conversation
    results_text = "\n".join([
        f"Tool: {tr['tool']}\nResult: {tr['result']}"
        if tr["error"] is None
        else f"Tool: {tr['tool']}\nError: {tr['error']}"
        for tr in tool_results
    ])
    messages.append(Message.user(
        f"Tool results:\n{results_text}\n\n"
        "Please provide your final answer based on these results."
    ))

    # 5. Continue agent execution loop from next iteration
    return await self._continue_execution_from_iteration(
        context=context,
        messages=messages,
        iteration=iteration + 1,  # Next iteration
        all_tool_calls=all_tool_calls,
    )
|
|
1543
|
+
|
|
1544
|
+
async def _continue_execution_from_iteration(
    self,
    context: Context,
    messages: List[Message],
    iteration: int,
    all_tool_calls: List[Dict],
) -> AgentResult:
    """
    Continue agent execution from a specific iteration.

    This is the core execution loop extracted to support both:
    1. Normal execution (starting from iteration 0)
    2. Resume after HITL (starting from iteration N)

    Each iteration sends the conversation to the language model, appends
    the assistant reply to ``messages``, and either executes the requested
    tools (feeding their results back as a user message) or returns the
    reply as the final output when no tools were requested.

    Args:
        context: Execution context
        messages: Conversation history (appended to in place)
        iteration: Starting iteration number (0-based)
        all_tool_calls: Accumulated tool calls (appended to in place)

    Returns:
        AgentResult with output and tool calls. When a tool returns a
        handoff dict, the result also carries ``handoff_to`` and
        ``handoff_metadata``.

    Raises:
        WaitingForUserInputException: Re-raised (chained to the original)
            with an ``agent_context`` snapshot when a tool pauses for
            user input.
    """
    # Bind the exception class BEFORE the loop. A function-scope import makes
    # the name local to the whole function; importing it inside the handler
    # (as was done previously) left it unbound when the ``except`` clause was
    # evaluated, so every tool exception surfaced as UnboundLocalError.
    from ..exceptions import WaitingForUserInputException

    # Extract workflow context for checkpointing
    workflow_ctx = None
    if hasattr(context, "_workflow_entity"):
        workflow_ctx = context
    elif hasattr(context, "_agent_data") and "_workflow_ctx" in context._agent_data:
        workflow_ctx = context._agent_data["_workflow_ctx"]

    # Prepare tool definitions
    tool_defs = [
        ToolDefinition(
            name=name,
            description=tool.description or f"Tool: {name}",
            parameters=tool.input_schema if hasattr(tool, "input_schema") else {},
        )
        for name, tool in self.tools.items()
    ]

    # Main iteration loop (continue from specified iteration)
    while iteration < self.max_iterations:
        self.logger.debug(f"Agent iteration {iteration + 1}/{self.max_iterations}")

        # Call LLM for next response
        if self._language_model:
            # Legacy API: model is a LanguageModel instance
            request = GenerateRequest(
                system_prompt=self.instructions,
                messages=messages,
                tools=tool_defs if tool_defs else [],
            )
            request.config.temperature = self.temperature
            if self.max_tokens:
                request.config.max_tokens = self.max_tokens
            if self.top_p:
                request.config.top_p = self.top_p
            response = await self._language_model.generate(request)
        else:
            # New API: model is a string, create internal LM instance
            request = GenerateRequest(
                model=self.model,
                system_prompt=self.instructions,
                messages=messages,
                tools=tool_defs if tool_defs else [],
            )
            request.config.temperature = self.temperature
            if self.max_tokens:
                request.config.max_tokens = self.max_tokens
            if self.top_p:
                request.config.top_p = self.top_p

            # Create internal LM instance for generation
            from ..lm import _LanguageModel

            # Two-value unpacking is kept deliberately: a model string
            # without a "provider/" prefix raises ValueError here.
            provider, _ = self.model.split("/", 1)
            internal_lm = _LanguageModel(provider=provider.lower(), default_model=None)
            response = await internal_lm.generate(request)

        # Track cost for this LLM call (both API paths)
        self._track_llm_cost(response, workflow_ctx)

        # Add assistant response to messages
        messages.append(Message.assistant(response.text))

        # Check if LLM wants to use tools
        if response.tool_calls:
            self.logger.debug(f"Agent calling {len(response.tool_calls)} tool(s)")

            # Store current conversation in context for potential handoffs
            if not hasattr(context, "_agent_data"):
                context._agent_data = {}
            context._agent_data["_current_conversation"] = messages

            # Execute tool calls
            tool_results = []
            # enumerate gives the true position of the call; list.index()
            # would return the first equal entry when calls are duplicated.
            for call_index, tool_call in enumerate(response.tool_calls):
                tool_name = tool_call["name"]
                tool_args_str = tool_call["arguments"]

                # Track tool call
                all_tool_calls.append({
                    "name": tool_name,
                    "arguments": tool_args_str,
                    "iteration": iteration + 1,
                })

                # Execute tool
                try:
                    # Parse arguments
                    tool_args = json.loads(tool_args_str)

                    # Get tool
                    tool = self.tools.get(tool_name)
                    if not tool:
                        result_text = f"Error: Tool '{tool_name}' not found"
                    else:
                        # Execute tool
                        result = await tool.invoke(context, **tool_args)

                        # Check if this was a handoff
                        if isinstance(result, dict) and result.get("_handoff"):
                            self.logger.info(
                                f"Handoff detected to '{result['to_agent']}', "
                                f"terminating current agent"
                            )
                            # Save conversation before returning
                            if isinstance(context, AgentContext):
                                await context.save_conversation_history(messages)
                            # Return immediately with handoff result
                            return AgentResult(
                                output=result["output"],
                                tool_calls=all_tool_calls + result.get("tool_calls", []),
                                context=context,
                                handoff_to=result["to_agent"],
                                handoff_metadata=result,
                            )

                        result_text = _serialize_tool_result(result)

                    tool_results.append(
                        {"tool": tool_name, "result": result_text, "error": None}
                    )

                except WaitingForUserInputException as e:
                    # HITL PAUSE: Capture agent state and propagate exception
                    self.logger.info(f"Agent pausing for user input at iteration {iteration}")

                    # Serialize messages to dict format
                    messages_dict = [
                        {"role": msg.role.value, "content": msg.content}
                        for msg in messages
                    ]

                    # Enhance exception with agent execution context
                    raise WaitingForUserInputException(
                        question=e.question,
                        input_type=e.input_type,
                        options=e.options,
                        checkpoint_state=e.checkpoint_state,
                        agent_context={
                            "agent_name": self.name,
                            "iteration": iteration,
                            "messages": messages_dict,
                            "tool_results": tool_results,
                            "pending_tool_call": {
                                "name": tool_call["name"],
                                "arguments": tool_call["arguments"],
                                "tool_call_index": call_index,
                            },
                            "all_tool_calls": all_tool_calls,
                            "model_config": {
                                "model": self.model,
                                "temperature": self.temperature,
                                "max_tokens": self.max_tokens,
                                "top_p": self.top_p,
                            },
                        },
                    ) from e

                except Exception as e:
                    # Regular tool errors - log and continue
                    self.logger.error(f"Tool execution error: {e}")
                    tool_results.append(
                        {"tool": tool_name, "result": None, "error": str(e)}
                    )

            # Add tool results to conversation
            results_text = "\n".join([
                f"Tool: {tr['tool']}\nResult: {tr['result']}"
                if tr["error"] is None
                else f"Tool: {tr['tool']}\nError: {tr['error']}"
                for tr in tool_results
            ])
            messages.append(Message.user(
                f"Tool results:\n{results_text}\n\n"
                "Please provide your final answer based on these results."
            ))

            # Continue loop for agent to process results

        else:
            # No tool calls - agent is done
            self.logger.debug(f"Agent completed after {iteration + 1} iterations")
            # Save conversation before returning
            if isinstance(context, AgentContext):
                await context.save_conversation_history(messages)

            # Emit completion checkpoint
            if workflow_ctx:
                workflow_ctx._send_checkpoint("agent.completed", {
                    "agent.name": self.name,
                    "agent.iterations": iteration + 1,
                    "agent.tool_calls_count": len(all_tool_calls),
                    "output_length": len(response.text),
                })

            return AgentResult(
                output=response.text,
                tool_calls=all_tool_calls,
                context=context,
            )

        iteration += 1

    # Max iterations reached
    self.logger.warning(f"Agent reached max iterations ({self.max_iterations})")
    # NOTE(review): the last message is whatever was appended most recently;
    # after a tool-calling iteration that is the "Tool results" user message,
    # not an assistant reply. Confirm this is the intended fallback output.
    final_output = messages[-1].content if messages else "No output generated"
    # Save conversation before returning
    if isinstance(context, AgentContext):
        await context.save_conversation_history(messages)

    # Emit completion checkpoint with max iterations flag
    if workflow_ctx:
        workflow_ctx._send_checkpoint("agent.completed", {
            "agent.name": self.name,
            "agent.iterations": self.max_iterations,
            "agent.tool_calls_count": len(all_tool_calls),
            "agent.max_iterations_reached": True,
            "output_length": len(final_output),
        })

    return AgentResult(
        output=final_output,
        tool_calls=all_tool_calls,
        context=context,
    )
|