aury-agent 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- aury/agents/core/base.py +10 -1
- aury/agents/core/context.py +53 -32
- aury/agents/core/event_bus/bus.py +0 -2
- aury/agents/core/factory.py +6 -1
- aury/agents/core/logging.py +24 -1
- aury/agents/core/types/block.py +16 -10
- aury/agents/llm/adapter.py +7 -1
- aury/agents/middleware/base.py +19 -46
- aury/agents/middleware/chain.py +20 -33
- aury/agents/middleware/message.py +6 -11
- aury/agents/middleware/message_container.py +15 -23
- aury/agents/middleware/raw_message.py +6 -5
- aury/agents/middleware/truncation.py +0 -2
- aury/agents/react/agent.py +24 -20
- aury/agents/react/factory.py +13 -1
- aury/agents/react/persistence.py +1 -8
- aury/agents/react/step.py +20 -30
- aury/agents/react/tools.py +44 -11
- aury/agents/tool/builtin/delegate.py +1 -6
- aury/agents/tool/decorator.py +2 -1
- aury/agents/workflow/adapter.py +3 -16
- aury/agents/workflow/executor.py +5 -13
- {aury_agent-0.0.6.dist-info → aury_agent-0.0.7.dist-info}/METADATA +1 -1
- {aury_agent-0.0.6.dist-info → aury_agent-0.0.7.dist-info}/RECORD +26 -26
- {aury_agent-0.0.6.dist-info → aury_agent-0.0.7.dist-info}/WHEEL +0 -0
- {aury_agent-0.0.6.dist-info → aury_agent-0.0.7.dist-info}/entry_points.txt +0 -0
|
@@ -49,10 +49,6 @@ class MessageContainerMiddleware(BaseMiddleware):
|
|
|
49
49
|
parent_id. Defaults to {"thinking", "text"}.
|
|
50
50
|
"""
|
|
51
51
|
|
|
52
|
-
# Key to store the token in middleware context
|
|
53
|
-
_TOKEN_KEY = "_message_container_token"
|
|
54
|
-
_BLOCK_ID_KEY = "_message_container_block_id"
|
|
55
|
-
|
|
56
52
|
# Default kinds that should be grouped under message container
|
|
57
53
|
DEFAULT_KINDS = {"thinking", "text"}
|
|
58
54
|
|
|
@@ -64,13 +60,17 @@ class MessageContainerMiddleware(BaseMiddleware):
|
|
|
64
60
|
"""
|
|
65
61
|
super().__init__()
|
|
66
62
|
self.apply_to_kinds = apply_to_kinds or self.DEFAULT_KINDS
|
|
63
|
+
self._token: Any = None
|
|
64
|
+
self._block_id: str | None = None
|
|
67
65
|
|
|
68
66
|
async def on_request(
|
|
69
67
|
self,
|
|
70
68
|
request: dict[str, Any],
|
|
71
|
-
context: dict[str, Any],
|
|
72
69
|
) -> dict[str, Any] | None:
|
|
73
70
|
"""Create message container block and set parent_id."""
|
|
71
|
+
from ..core.context import get_current_ctx_or_none
|
|
72
|
+
ctx = get_current_ctx_or_none()
|
|
73
|
+
|
|
74
74
|
# Generate container block ID
|
|
75
75
|
message_block_id = generate_id("blk")
|
|
76
76
|
|
|
@@ -86,41 +86,33 @@ class MessageContainerMiddleware(BaseMiddleware):
|
|
|
86
86
|
op=BlockOp.APPLY,
|
|
87
87
|
data={
|
|
88
88
|
"type": "llm_response",
|
|
89
|
-
"step": context.get("step"),
|
|
90
89
|
},
|
|
91
|
-
session_id=
|
|
92
|
-
invocation_id=
|
|
90
|
+
session_id=ctx.session_id if ctx else None,
|
|
91
|
+
invocation_id=ctx.invocation_id if ctx else None,
|
|
93
92
|
))
|
|
94
93
|
|
|
95
94
|
# Set parent_id in ContextVar with apply_to_kinds filter
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
context[self._TOKEN_KEY] = token
|
|
99
|
-
context[self._BLOCK_ID_KEY] = message_block_id
|
|
95
|
+
self._token = set_parent_id(message_block_id, apply_to_kinds=self.apply_to_kinds)
|
|
96
|
+
self._block_id = message_block_id
|
|
100
97
|
|
|
101
98
|
return request
|
|
102
99
|
|
|
103
100
|
async def on_response(
|
|
104
101
|
self,
|
|
105
102
|
response: dict[str, Any],
|
|
106
|
-
context: dict[str, Any],
|
|
107
103
|
) -> dict[str, Any] | None:
|
|
108
104
|
"""Reset parent_id to previous value."""
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
105
|
+
if self._token is not None:
|
|
106
|
+
reset_parent_id(self._token)
|
|
107
|
+
self._token = None
|
|
112
108
|
return response
|
|
113
109
|
|
|
114
110
|
async def on_error(
|
|
115
111
|
self,
|
|
116
112
|
error: Exception,
|
|
117
|
-
context: dict[str, Any],
|
|
118
113
|
) -> Exception | None:
|
|
119
114
|
"""Reset parent_id on error too."""
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
115
|
+
if self._token is not None:
|
|
116
|
+
reset_parent_id(self._token)
|
|
117
|
+
self._token = None
|
|
123
118
|
return error
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
__all__ = ["MessageContainerMiddleware"]
|
|
@@ -70,18 +70,18 @@ class RawMessageMiddleware(BaseMiddleware):
|
|
|
70
70
|
async def on_message_save(
|
|
71
71
|
self,
|
|
72
72
|
message: dict[str, Any],
|
|
73
|
-
context: dict[str, Any],
|
|
74
73
|
) -> dict[str, Any] | None:
|
|
75
74
|
"""Store complete message to RawMessageStore.
|
|
76
75
|
|
|
77
76
|
Args:
|
|
78
77
|
message: Complete message dict with 'role', 'content', etc.
|
|
79
|
-
context: Execution context with 'invocation_id', etc.
|
|
80
78
|
|
|
81
79
|
Returns:
|
|
82
80
|
The message with added 'raw_msg_id' field
|
|
83
81
|
"""
|
|
84
|
-
|
|
82
|
+
from ..core.context import get_current_ctx_or_none
|
|
83
|
+
ctx = get_current_ctx_or_none()
|
|
84
|
+
invocation_id = ctx.invocation_id if ctx else ""
|
|
85
85
|
if not invocation_id:
|
|
86
86
|
return message
|
|
87
87
|
|
|
@@ -108,7 +108,6 @@ class RawMessageMiddleware(BaseMiddleware):
|
|
|
108
108
|
self,
|
|
109
109
|
agent_id: str,
|
|
110
110
|
result: Any,
|
|
111
|
-
context: dict[str, Any],
|
|
112
111
|
) -> HookResult:
|
|
113
112
|
"""Clean up raw messages when invocation completes.
|
|
114
113
|
|
|
@@ -117,7 +116,9 @@ class RawMessageMiddleware(BaseMiddleware):
|
|
|
117
116
|
if self.persist_raw:
|
|
118
117
|
return HookResult.proceed()
|
|
119
118
|
|
|
120
|
-
|
|
119
|
+
from ..core.context import get_current_ctx_or_none
|
|
120
|
+
ctx = get_current_ctx_or_none()
|
|
121
|
+
invocation_id = ctx.invocation_id if ctx else ""
|
|
121
122
|
if invocation_id:
|
|
122
123
|
await self._cleanup_invocation(invocation_id)
|
|
123
124
|
|
|
@@ -48,13 +48,11 @@ class MessageTruncationMiddleware(BaseMiddleware):
|
|
|
48
48
|
async def on_message_save(
|
|
49
49
|
self,
|
|
50
50
|
message: dict[str, Any],
|
|
51
|
-
context: dict[str, Any],
|
|
52
51
|
) -> dict[str, Any] | None:
|
|
53
52
|
"""Truncate message content before saving.
|
|
54
53
|
|
|
55
54
|
Args:
|
|
56
55
|
message: Message dict with 'role', 'content', etc.
|
|
57
|
-
context: Execution context
|
|
58
56
|
|
|
59
57
|
Returns:
|
|
60
58
|
Modified message with truncated content
|
aury/agents/react/agent.py
CHANGED
|
@@ -93,6 +93,7 @@ class ReactAgent(BaseAgent):
|
|
|
93
93
|
enable_history: bool = True,
|
|
94
94
|
history_limit: int = 50,
|
|
95
95
|
delegate_tool_class: "type[BaseTool] | None" = None,
|
|
96
|
+
context_metadata: dict | None = None,
|
|
96
97
|
) -> "ReactAgent":
|
|
97
98
|
"""Create ReactAgent with minimal boilerplate. See factory.create_react_agent for details."""
|
|
98
99
|
return create_react_agent(
|
|
@@ -110,6 +111,7 @@ class ReactAgent(BaseAgent):
|
|
|
110
111
|
enable_history=enable_history,
|
|
111
112
|
history_limit=history_limit,
|
|
112
113
|
delegate_tool_class=delegate_tool_class,
|
|
114
|
+
context_metadata=context_metadata,
|
|
113
115
|
)
|
|
114
116
|
|
|
115
117
|
@classmethod
|
|
@@ -263,41 +265,38 @@ class ReactAgent(BaseAgent):
|
|
|
263
265
|
"Starting ReactAgent run",
|
|
264
266
|
extra={
|
|
265
267
|
"session_id": self.session.id,
|
|
266
|
-
"
|
|
268
|
+
"agent_id": self.ctx.agent_id,
|
|
267
269
|
}
|
|
268
270
|
)
|
|
269
271
|
|
|
270
|
-
# Build middleware context
|
|
271
|
-
from ..core.context import emit as global_emit
|
|
272
|
-
mw_context = {
|
|
273
|
-
"session_id": self.session.id,
|
|
274
|
-
"agent_id": self.name,
|
|
275
|
-
"agent_type": self.agent_type,
|
|
276
|
-
"emit": global_emit,
|
|
277
|
-
"backends": self.ctx.backends,
|
|
278
|
-
}
|
|
279
|
-
|
|
280
272
|
try:
|
|
281
|
-
# Create new invocation
|
|
273
|
+
# Create new invocation using ctx.invocation_id
|
|
282
274
|
self._current_invocation = Invocation(
|
|
283
|
-
id=
|
|
275
|
+
id=self.ctx.invocation_id,
|
|
284
276
|
session_id=self.session.id,
|
|
277
|
+
agent_id=self.ctx.agent_id,
|
|
285
278
|
state=InvocationState.RUNNING,
|
|
286
279
|
started_at=datetime.now(),
|
|
287
280
|
)
|
|
288
|
-
mw_context["invocation_id"] = self._current_invocation.id
|
|
289
281
|
|
|
290
282
|
logger.info("Created invocation", extra={"invocation_id": self._current_invocation.id})
|
|
291
283
|
|
|
284
|
+
# Persist invocation immediately (so we have record even if agent fails)
|
|
285
|
+
if self.ctx.backends and self.ctx.backends.invocation:
|
|
286
|
+
await self.ctx.backends.invocation.create(
|
|
287
|
+
self._current_invocation.id,
|
|
288
|
+
self.session.id,
|
|
289
|
+
self._current_invocation.to_dict(),
|
|
290
|
+
agent_id=self.ctx.agent_id,
|
|
291
|
+
)
|
|
292
|
+
|
|
292
293
|
# === Middleware: on_agent_start ===
|
|
293
294
|
if self.middleware:
|
|
294
295
|
logger.info(
|
|
295
296
|
"Calling middleware: on_agent_start",
|
|
296
297
|
extra={"invocation_id": self._current_invocation.id},
|
|
297
298
|
)
|
|
298
|
-
hook_result = await self.middleware.process_agent_start(
|
|
299
|
-
self.name, input, mw_context
|
|
300
|
-
)
|
|
299
|
+
hook_result = await self.middleware.process_agent_start(input)
|
|
301
300
|
if hook_result.action == HookAction.STOP:
|
|
302
301
|
logger.warning("Agent stopped by middleware on_agent_start", extra={"invocation_id": self._current_invocation.id})
|
|
303
302
|
await self.ctx.emit(BlockEvent(
|
|
@@ -449,6 +448,13 @@ class ReactAgent(BaseAgent):
|
|
|
449
448
|
|
|
450
449
|
# Complete invocation
|
|
451
450
|
if is_aborted:
|
|
451
|
+
# Save current buffer content before marking as aborted
|
|
452
|
+
if self._text_buffer or self._thinking_buffer or self._tool_invocations:
|
|
453
|
+
await persist_helpers.save_assistant_message(self)
|
|
454
|
+
# Save completed tool results
|
|
455
|
+
if self._tool_invocations:
|
|
456
|
+
await persist_helpers.save_tool_messages(self)
|
|
457
|
+
# Mark invocation as aborted
|
|
452
458
|
self._current_invocation.state = InvocationState.ABORTED
|
|
453
459
|
logger.info(
|
|
454
460
|
"Invocation aborted by user",
|
|
@@ -490,9 +496,7 @@ class ReactAgent(BaseAgent):
|
|
|
490
496
|
# === Middleware: on_agent_end ===
|
|
491
497
|
if self.middleware:
|
|
492
498
|
await self.middleware.process_agent_end(
|
|
493
|
-
self.name,
|
|
494
499
|
{"steps": self._current_step, "finish_reason": finish_reason},
|
|
495
|
-
mw_context,
|
|
496
500
|
)
|
|
497
501
|
|
|
498
502
|
await self.bus.publish(
|
|
@@ -551,7 +555,7 @@ class ReactAgent(BaseAgent):
|
|
|
551
555
|
|
|
552
556
|
# === Middleware: on_error ===
|
|
553
557
|
if self.middleware:
|
|
554
|
-
processed_error = await self.middleware.process_error(e
|
|
558
|
+
processed_error = await self.middleware.process_error(e)
|
|
555
559
|
if processed_error is None:
|
|
556
560
|
logger.warning(
|
|
557
561
|
"Error suppressed by middleware",
|
aury/agents/react/factory.py
CHANGED
|
@@ -44,6 +44,8 @@ def create_react_agent(
|
|
|
44
44
|
history_limit: int = 50,
|
|
45
45
|
# Tool customization
|
|
46
46
|
delegate_tool_class: "type[BaseTool] | None" = None,
|
|
47
|
+
# Context metadata
|
|
48
|
+
context_metadata: dict | None = None,
|
|
47
49
|
) -> "ReactAgent":
|
|
48
50
|
"""Create ReactAgent with minimal boilerplate.
|
|
49
51
|
|
|
@@ -159,16 +161,26 @@ def create_react_agent(
|
|
|
159
161
|
all_providers = default_providers + (context_providers or [])
|
|
160
162
|
|
|
161
163
|
# Build context
|
|
164
|
+
# agent_id: use config.id if provided, fallback to config.code, then "react_agent"
|
|
165
|
+
# agent_name: use config.name (display name)
|
|
166
|
+
agent_id = (
|
|
167
|
+
config.id if config and config.id
|
|
168
|
+
else (config.code if config and config.code else "react_agent")
|
|
169
|
+
)
|
|
170
|
+
agent_name = config.name if config else None
|
|
171
|
+
|
|
162
172
|
ctx = InvocationContext(
|
|
163
173
|
session=session,
|
|
164
174
|
invocation_id=generate_id("inv"),
|
|
165
|
-
agent_id=
|
|
175
|
+
agent_id=agent_id,
|
|
176
|
+
agent_name=agent_name,
|
|
166
177
|
backends=backends,
|
|
167
178
|
bus=bus,
|
|
168
179
|
llm=llm,
|
|
169
180
|
middleware=middleware_chain,
|
|
170
181
|
memory=memory,
|
|
171
182
|
snapshot=snapshot,
|
|
183
|
+
metadata=context_metadata or {},
|
|
172
184
|
)
|
|
173
185
|
|
|
174
186
|
agent = ReactAgent(ctx, config)
|
aury/agents/react/persistence.py
CHANGED
|
@@ -81,14 +81,7 @@ async def trigger_message_save(agent: "ReactAgent", message: dict) -> dict | Non
|
|
|
81
81
|
if not agent.middleware:
|
|
82
82
|
return message
|
|
83
83
|
|
|
84
|
-
|
|
85
|
-
mw_context = {
|
|
86
|
-
"session_id": agent.session.id,
|
|
87
|
-
"agent_id": agent.name,
|
|
88
|
-
"namespace": namespace,
|
|
89
|
-
}
|
|
90
|
-
|
|
91
|
-
return await agent.middleware.process_message_save(message, mw_context)
|
|
84
|
+
return await agent.middleware.process_message_save(message)
|
|
92
85
|
|
|
93
86
|
|
|
94
87
|
async def save_user_message(agent: "ReactAgent", input: "PromptInput") -> None:
|
aury/agents/react/step.py
CHANGED
|
@@ -190,6 +190,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
190
190
|
agent._text_buffer = ""
|
|
191
191
|
agent._thinking_buffer = "" # Buffer for non-streaming thinking
|
|
192
192
|
agent._tool_invocations = []
|
|
193
|
+
agent._last_usage = None # Store usage for middleware
|
|
193
194
|
|
|
194
195
|
# Reset block IDs for this step (each step gets fresh block IDs)
|
|
195
196
|
agent._current_text_block_id = None
|
|
@@ -199,20 +200,9 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
199
200
|
agent._call_id_to_tool = {}
|
|
200
201
|
agent._tool_call_blocks = {}
|
|
201
202
|
|
|
202
|
-
# Track accumulated arguments for streaming tool calls
|
|
203
|
+
# Track accumulated arguments for streaming tool calls
|
|
203
204
|
tool_call_accumulated_args: dict[str, dict[str, Any]] = {}
|
|
204
205
|
|
|
205
|
-
# Build middleware context for this step
|
|
206
|
-
mw_context = {
|
|
207
|
-
"session_id": agent.session.id,
|
|
208
|
-
"invocation_id": agent._current_invocation.id if agent._current_invocation else "",
|
|
209
|
-
"step": agent._current_step,
|
|
210
|
-
"agent_id": agent.name,
|
|
211
|
-
"emit": global_emit, # For middleware to emit BlockEvent/ActionEvent
|
|
212
|
-
"backends": agent.ctx.backends,
|
|
213
|
-
"tool_mode": effective_tool_mode.value, # Add tool mode to context
|
|
214
|
-
}
|
|
215
|
-
|
|
216
206
|
# Build LLM call kwargs
|
|
217
207
|
# Note: temperature, max_tokens, timeout, retries are configured on LLMProvider
|
|
218
208
|
llm_kwargs: dict[str, Any] = {
|
|
@@ -244,7 +234,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
244
234
|
"Calling middleware: on_request",
|
|
245
235
|
extra={"invocation_id": agent._current_invocation.id},
|
|
246
236
|
)
|
|
247
|
-
llm_kwargs = await agent.middleware.process_request(llm_kwargs
|
|
237
|
+
llm_kwargs = await agent.middleware.process_request(llm_kwargs)
|
|
248
238
|
if llm_kwargs is None:
|
|
249
239
|
logger.warning(
|
|
250
240
|
"LLM request cancelled by middleware",
|
|
@@ -307,7 +297,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
307
297
|
stream_chunk = {"delta": event.delta, "type": "content"}
|
|
308
298
|
if agent.middleware:
|
|
309
299
|
stream_chunk = await agent.middleware.process_stream_chunk(
|
|
310
|
-
stream_chunk
|
|
300
|
+
stream_chunk
|
|
311
301
|
)
|
|
312
302
|
if stream_chunk is None:
|
|
313
303
|
continue # Skip this chunk
|
|
@@ -342,7 +332,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
342
332
|
stream_chunk = {"delta": event.delta, "type": "thinking"}
|
|
343
333
|
if agent.middleware:
|
|
344
334
|
stream_chunk = await agent.middleware.process_stream_chunk(
|
|
345
|
-
stream_chunk
|
|
335
|
+
stream_chunk
|
|
346
336
|
)
|
|
347
337
|
if stream_chunk is None:
|
|
348
338
|
continue # Skip this chunk
|
|
@@ -450,12 +440,9 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
450
440
|
# === Middleware: on_tool_call_delta ===
|
|
451
441
|
processed_delta = arguments_delta
|
|
452
442
|
if agent.middleware:
|
|
453
|
-
|
|
454
|
-
**mw_context,
|
|
455
|
-
"accumulated_args": tool_call_accumulated_args.get(call_id, {}),
|
|
456
|
-
}
|
|
443
|
+
accumulated_args = tool_call_accumulated_args.get(call_id, {})
|
|
457
444
|
processed_delta = await agent.middleware.process_tool_call_delta(
|
|
458
|
-
call_id, tool_name, arguments_delta,
|
|
445
|
+
call_id, tool_name, arguments_delta, accumulated_args
|
|
459
446
|
)
|
|
460
447
|
if processed_delta is None:
|
|
461
448
|
continue # Skip this delta
|
|
@@ -539,17 +526,19 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
539
526
|
|
|
540
527
|
elif event.type == "usage":
|
|
541
528
|
if event.usage:
|
|
529
|
+
# Store usage for middleware
|
|
530
|
+
agent._last_usage = {
|
|
531
|
+
"provider": agent.llm.provider,
|
|
532
|
+
"model": agent.llm.model,
|
|
533
|
+
"input_tokens": event.usage.input_tokens,
|
|
534
|
+
"output_tokens": event.usage.output_tokens,
|
|
535
|
+
"cache_read_tokens": event.usage.cache_read_tokens,
|
|
536
|
+
"cache_write_tokens": event.usage.cache_write_tokens,
|
|
537
|
+
"reasoning_tokens": event.usage.reasoning_tokens,
|
|
538
|
+
}
|
|
542
539
|
await agent.bus.publish(
|
|
543
540
|
Events.USAGE_RECORDED,
|
|
544
|
-
|
|
545
|
-
"provider": agent.llm.provider,
|
|
546
|
-
"model": agent.llm.model,
|
|
547
|
-
"input_tokens": event.usage.input_tokens,
|
|
548
|
-
"output_tokens": event.usage.output_tokens,
|
|
549
|
-
"cache_read_tokens": event.usage.cache_read_tokens,
|
|
550
|
-
"cache_write_tokens": event.usage.cache_write_tokens,
|
|
551
|
-
"reasoning_tokens": event.usage.reasoning_tokens,
|
|
552
|
-
},
|
|
541
|
+
agent._last_usage,
|
|
553
542
|
)
|
|
554
543
|
|
|
555
544
|
elif event.type == "error":
|
|
@@ -624,6 +613,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
624
613
|
"thinking": agent._thinking_buffer,
|
|
625
614
|
"tool_calls": len(agent._tool_invocations),
|
|
626
615
|
"finish_reason": finish_reason,
|
|
616
|
+
"usage": agent._last_usage, # Include usage for middleware
|
|
627
617
|
}
|
|
628
618
|
if agent.middleware:
|
|
629
619
|
logger.debug(
|
|
@@ -635,7 +625,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
|
|
|
635
625
|
},
|
|
636
626
|
)
|
|
637
627
|
llm_response_data = await agent.middleware.process_response(
|
|
638
|
-
llm_response_data
|
|
628
|
+
llm_response_data
|
|
639
629
|
)
|
|
640
630
|
|
|
641
631
|
await agent.bus.publish(
|
aury/agents/react/tools.py
CHANGED
|
@@ -49,6 +49,19 @@ async def execute_tool(agent: "ReactAgent", invocation: ToolInvocation) -> ToolR
|
|
|
49
49
|
Returns:
|
|
50
50
|
ToolResult from tool execution
|
|
51
51
|
"""
|
|
52
|
+
# Check abort before execution
|
|
53
|
+
if await agent._check_abort():
|
|
54
|
+
error_msg = f"Tool {invocation.tool_name} aborted before execution"
|
|
55
|
+
invocation.mark_result(error_msg, is_error=True)
|
|
56
|
+
logger.info(
|
|
57
|
+
f"Tool aborted before execution: {invocation.tool_name}",
|
|
58
|
+
extra={
|
|
59
|
+
"invocation_id": agent._current_invocation.id if agent._current_invocation else "",
|
|
60
|
+
"call_id": invocation.tool_call_id,
|
|
61
|
+
},
|
|
62
|
+
)
|
|
63
|
+
return ToolResult.error(error_msg)
|
|
64
|
+
|
|
52
65
|
invocation.mark_call_complete()
|
|
53
66
|
|
|
54
67
|
logger.info(
|
|
@@ -60,14 +73,6 @@ async def execute_tool(agent: "ReactAgent", invocation: ToolInvocation) -> ToolR
|
|
|
60
73
|
},
|
|
61
74
|
)
|
|
62
75
|
|
|
63
|
-
# Build middleware context
|
|
64
|
-
mw_context = {
|
|
65
|
-
"session_id": agent.session.id,
|
|
66
|
-
"invocation_id": agent._current_invocation.id if agent._current_invocation else "",
|
|
67
|
-
"tool_call_id": invocation.tool_call_id,
|
|
68
|
-
"agent_id": agent.name,
|
|
69
|
-
}
|
|
70
|
-
|
|
71
76
|
try:
|
|
72
77
|
# Get tool from agent context
|
|
73
78
|
tool = get_tool(agent, invocation.tool_name)
|
|
@@ -87,7 +92,7 @@ async def execute_tool(agent: "ReactAgent", invocation: ToolInvocation) -> ToolR
|
|
|
87
92
|
extra={"invocation_id": agent._current_invocation.id, "call_id": invocation.tool_call_id},
|
|
88
93
|
)
|
|
89
94
|
hook_result = await agent.middleware.process_tool_call(
|
|
90
|
-
tool, invocation.args
|
|
95
|
+
tool, invocation.args
|
|
91
96
|
)
|
|
92
97
|
if hook_result.action == HookAction.SKIP:
|
|
93
98
|
logger.warning(
|
|
@@ -136,7 +141,7 @@ async def execute_tool(agent: "ReactAgent", invocation: ToolInvocation) -> ToolR
|
|
|
136
141
|
f"Calling middleware: on_tool_end ({invocation.tool_name})",
|
|
137
142
|
extra={"invocation_id": agent._current_invocation.id},
|
|
138
143
|
)
|
|
139
|
-
hook_result = await agent.middleware.process_tool_end(tool, result
|
|
144
|
+
hook_result = await agent.middleware.process_tool_end(tool, result)
|
|
140
145
|
if hook_result.action == HookAction.RETRY:
|
|
141
146
|
logger.info(
|
|
142
147
|
f"Tool {invocation.tool_name} retry requested by middleware",
|
|
@@ -212,16 +217,44 @@ async def process_tool_results(agent: "ReactAgent") -> None:
|
|
|
212
217
|
},
|
|
213
218
|
)
|
|
214
219
|
|
|
220
|
+
# Check abort before starting tool execution
|
|
221
|
+
if await agent._check_abort():
|
|
222
|
+
logger.info(
|
|
223
|
+
"Tool execution aborted before starting",
|
|
224
|
+
extra={"invocation_id": agent._current_invocation.id},
|
|
225
|
+
)
|
|
226
|
+
# Return empty results - agent loop will handle abort
|
|
227
|
+
return
|
|
228
|
+
|
|
215
229
|
# Execute tools based on configuration
|
|
216
230
|
if agent.config.parallel_tool_execution:
|
|
217
231
|
# Parallel execution using asyncio.gather with create_task
|
|
218
232
|
# create_task ensures each task gets its own ContextVar copy
|
|
219
233
|
tasks = [asyncio.create_task(execute_tool(agent, inv)) for inv in agent._tool_invocations]
|
|
220
234
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
235
|
+
|
|
236
|
+
# Check abort after parallel execution - cancel remaining if aborted
|
|
237
|
+
if await agent._check_abort():
|
|
238
|
+
logger.info(
|
|
239
|
+
"Tool execution aborted after parallel execution",
|
|
240
|
+
extra={"invocation_id": agent._current_invocation.id},
|
|
241
|
+
)
|
|
221
242
|
else:
|
|
222
|
-
# Sequential execution
|
|
243
|
+
# Sequential execution with abort check between tools
|
|
223
244
|
results = []
|
|
224
245
|
for inv in agent._tool_invocations:
|
|
246
|
+
# Check abort before each tool
|
|
247
|
+
if await agent._check_abort():
|
|
248
|
+
logger.info(
|
|
249
|
+
f"Tool execution aborted before {inv.tool_name}",
|
|
250
|
+
extra={"invocation_id": agent._current_invocation.id},
|
|
251
|
+
)
|
|
252
|
+
# Mark remaining as aborted
|
|
253
|
+
error_result = ToolResult.error(f"Aborted before execution")
|
|
254
|
+
results.append(error_result)
|
|
255
|
+
inv.mark_result(error_result.output, is_error=True)
|
|
256
|
+
continue
|
|
257
|
+
|
|
225
258
|
try:
|
|
226
259
|
result = await execute_tool(agent, inv)
|
|
227
260
|
results.append(result)
|
|
@@ -153,12 +153,7 @@ Specify the agent key and task data."""
|
|
|
153
153
|
|
|
154
154
|
# Get dynamic agents from middleware (progressive disclosure)
|
|
155
155
|
if self.middleware and ctx:
|
|
156
|
-
|
|
157
|
-
"session_id": ctx.session_id,
|
|
158
|
-
"invocation_id": ctx.invocation_id,
|
|
159
|
-
"agent_id": ctx.agent,
|
|
160
|
-
}
|
|
161
|
-
dynamic_agents = await self.middleware.get_dynamic_subagents(mw_context)
|
|
156
|
+
dynamic_agents = await self.middleware.get_dynamic_subagents()
|
|
162
157
|
if dynamic_agents:
|
|
163
158
|
# Store dynamic agents for later lookup
|
|
164
159
|
for config in dynamic_agents:
|
aury/agents/tool/decorator.py
CHANGED
|
@@ -53,7 +53,8 @@ def _type_to_schema(t: type) -> dict[str, Any]:
|
|
|
53
53
|
bool: {"type": "boolean"},
|
|
54
54
|
list: {"type": "array"},
|
|
55
55
|
dict: {"type": "object"},
|
|
56
|
-
Any:
|
|
56
|
+
# Any type: JSON Schema 2020-12 requires valid type, use object as fallback
|
|
57
|
+
Any: {"type": "object"},
|
|
57
58
|
}
|
|
58
59
|
|
|
59
60
|
return type_map.get(t, {"type": "string"})
|
aury/agents/workflow/adapter.py
CHANGED
|
@@ -162,21 +162,10 @@ class WorkflowAgent(BaseAgent):
|
|
|
162
162
|
inputs = input if isinstance(input, dict) else {"input": input}
|
|
163
163
|
logger.info(f"WorkflowAgent executing, workflow={self.workflow.spec.name}, invocation_id={self.ctx.invocation_id}")
|
|
164
164
|
|
|
165
|
-
# Build middleware context
|
|
166
|
-
mw_context = {
|
|
167
|
-
"session_id": self.session.id,
|
|
168
|
-
"invocation_id": self.ctx.invocation_id,
|
|
169
|
-
"agent_id": self.name,
|
|
170
|
-
"agent_type": self.agent_type,
|
|
171
|
-
"workflow_name": self.workflow.spec.name,
|
|
172
|
-
}
|
|
173
|
-
|
|
174
165
|
# === Middleware: on_agent_start ===
|
|
175
166
|
if self.middleware:
|
|
176
167
|
logger.debug(f"WorkflowAgent: processing on_agent_start hooks, invocation_id={self.ctx.invocation_id}")
|
|
177
|
-
hook_result = await self.middleware.process_agent_start(
|
|
178
|
-
self.name, inputs, mw_context
|
|
179
|
-
)
|
|
168
|
+
hook_result = await self.middleware.process_agent_start(inputs)
|
|
180
169
|
if hook_result.action == HookAction.STOP:
|
|
181
170
|
logger.warning(f"Workflow stopped by middleware on_agent_start, invocation_id={self.ctx.invocation_id}")
|
|
182
171
|
await self.ctx.emit(BlockEvent(
|
|
@@ -205,16 +194,14 @@ class WorkflowAgent(BaseAgent):
|
|
|
205
194
|
# === Middleware: on_agent_end ===
|
|
206
195
|
if self.middleware:
|
|
207
196
|
logger.debug(f"WorkflowAgent: processing on_agent_end hooks, invocation_id={self.ctx.invocation_id}")
|
|
208
|
-
await self.middleware.process_agent_end(
|
|
209
|
-
self.name, result, mw_context
|
|
210
|
-
)
|
|
197
|
+
await self.middleware.process_agent_end(result)
|
|
211
198
|
|
|
212
199
|
except Exception as e:
|
|
213
200
|
logger.error(f"WorkflowAgent: execution error, error={type(e).__name__}, workflow={self.workflow.spec.name}, invocation_id={self.ctx.invocation_id}", exc_info=True)
|
|
214
201
|
# === Middleware: on_error ===
|
|
215
202
|
if self.middleware:
|
|
216
203
|
logger.debug(f"WorkflowAgent: processing on_error hooks, invocation_id={self.ctx.invocation_id}")
|
|
217
|
-
processed_error = await self.middleware.process_error(e
|
|
204
|
+
processed_error = await self.middleware.process_error(e)
|
|
218
205
|
if processed_error is None:
|
|
219
206
|
logger.warning(f"WorkflowAgent: error suppressed by middleware, invocation_id={self.ctx.invocation_id}")
|
|
220
207
|
return
|
aury/agents/workflow/executor.py
CHANGED
|
@@ -598,24 +598,12 @@ class WorkflowExecutor:
|
|
|
598
598
|
# Get effective middleware for this node
|
|
599
599
|
effective_middleware = self._get_effective_middleware(node)
|
|
600
600
|
|
|
601
|
-
# Build middleware context
|
|
602
|
-
mw_context = {
|
|
603
|
-
"session_id": self.ctx.session_id,
|
|
604
|
-
"invocation_id": self.ctx.invocation_id,
|
|
605
|
-
"parent_agent_id": self.workflow.spec.name,
|
|
606
|
-
"child_agent_id": node.agent,
|
|
607
|
-
"node_id": node.id,
|
|
608
|
-
"parent_block_id": parent_block_id,
|
|
609
|
-
"has_node_middleware": bool(node.middleware),
|
|
610
|
-
}
|
|
611
|
-
|
|
612
601
|
# === Middleware: on_subagent_start ===
|
|
613
602
|
if effective_middleware:
|
|
614
603
|
hook_result = await effective_middleware.process_subagent_start(
|
|
615
604
|
self.workflow.spec.name,
|
|
616
605
|
node.agent,
|
|
617
606
|
"embedded", # Workflow nodes are embedded execution
|
|
618
|
-
mw_context,
|
|
619
607
|
)
|
|
620
608
|
if hook_result.action == HookAction.SKIP:
|
|
621
609
|
logger.info(f"SubAgent {node.agent} skipped by middleware")
|
|
@@ -628,8 +616,13 @@ class WorkflowExecutor:
|
|
|
628
616
|
try:
|
|
629
617
|
# Create child context for sub-agent with effective middleware
|
|
630
618
|
# Note: parent_block_id is already set via ContextVar above
|
|
619
|
+
# Get agent name from factory if available
|
|
620
|
+
agent_class = self.agent_factory.get_class(node.agent)
|
|
621
|
+
agent_name = getattr(agent_class, 'name', node.agent) if agent_class else node.agent
|
|
622
|
+
|
|
631
623
|
child_ctx = self.ctx.create_child(
|
|
632
624
|
agent_id=node.agent,
|
|
625
|
+
agent_name=agent_name,
|
|
633
626
|
middleware=effective_middleware,
|
|
634
627
|
)
|
|
635
628
|
|
|
@@ -662,7 +655,6 @@ class WorkflowExecutor:
|
|
|
662
655
|
self.workflow.spec.name,
|
|
663
656
|
node.agent,
|
|
664
657
|
result,
|
|
665
|
-
mw_context,
|
|
666
658
|
)
|
|
667
659
|
|
|
668
660
|
return result
|