aury-agent 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/agents/core/base.py +10 -1
- aury/agents/core/context.py +53 -32
- aury/agents/core/event_bus/bus.py +0 -2
- aury/agents/core/factory.py +6 -1
- aury/agents/core/logging.py +24 -1
- aury/agents/core/types/block.py +16 -10
- aury/agents/llm/adapter.py +7 -1
- aury/agents/middleware/base.py +19 -46
- aury/agents/middleware/chain.py +20 -33
- aury/agents/middleware/message.py +6 -11
- aury/agents/middleware/message_container.py +15 -23
- aury/agents/middleware/raw_message.py +6 -5
- aury/agents/middleware/truncation.py +0 -2
- aury/agents/react/agent.py +24 -20
- aury/agents/react/factory.py +13 -1
- aury/agents/react/persistence.py +1 -8
- aury/agents/react/step.py +20 -30
- aury/agents/react/tools.py +53 -15
- aury/agents/tool/builtin/delegate.py +1 -6
- aury/agents/tool/decorator.py +2 -1
- aury/agents/workflow/adapter.py +3 -16
- aury/agents/workflow/executor.py +5 -13
- {aury_agent-0.0.5.dist-info → aury_agent-0.0.7.dist-info}/METADATA +1 -1
- {aury_agent-0.0.5.dist-info → aury_agent-0.0.7.dist-info}/RECORD +26 -26
- {aury_agent-0.0.5.dist-info → aury_agent-0.0.7.dist-info}/WHEEL +0 -0
- {aury_agent-0.0.5.dist-info → aury_agent-0.0.7.dist-info}/entry_points.txt +0 -0
aury/agents/middleware/base.py
CHANGED
|
@@ -1,4 +1,11 @@
|
|
|
1
|
-
"""Middleware protocol and base implementation.
|
|
1
|
+
"""Middleware protocol and base implementation.
|
|
2
|
+
|
|
3
|
+
Middleware can access InvocationContext via get_current_ctx_or_none() for:
|
|
4
|
+
- session_id, invocation_id, agent_id, agent_name
|
|
5
|
+
- backends, metadata, etc.
|
|
6
|
+
|
|
7
|
+
Middleware should use self._xxx for internal state between hooks.
|
|
8
|
+
"""
|
|
2
9
|
from __future__ import annotations
|
|
3
10
|
|
|
4
11
|
from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING
|
|
@@ -14,6 +21,7 @@ class Middleware(Protocol):
|
|
|
14
21
|
"""Middleware protocol for request/response processing.
|
|
15
22
|
|
|
16
23
|
Includes both LLM request/response hooks and agent lifecycle hooks.
|
|
24
|
+
Use get_current_ctx_or_none() to access InvocationContext.
|
|
17
25
|
"""
|
|
18
26
|
|
|
19
27
|
@property
|
|
@@ -26,13 +34,11 @@ class Middleware(Protocol):
|
|
|
26
34
|
async def on_request(
|
|
27
35
|
self,
|
|
28
36
|
request: dict[str, Any],
|
|
29
|
-
context: dict[str, Any],
|
|
30
37
|
) -> dict[str, Any] | None:
|
|
31
38
|
"""Process request before LLM call.
|
|
32
39
|
|
|
33
40
|
Args:
|
|
34
41
|
request: The request to process
|
|
35
|
-
context: Execution context
|
|
36
42
|
|
|
37
43
|
Returns:
|
|
38
44
|
Modified request, or None to skip further processing
|
|
@@ -42,13 +48,11 @@ class Middleware(Protocol):
|
|
|
42
48
|
async def on_response(
|
|
43
49
|
self,
|
|
44
50
|
response: dict[str, Any],
|
|
45
|
-
context: dict[str, Any],
|
|
46
51
|
) -> dict[str, Any] | None:
|
|
47
52
|
"""Process response after LLM call.
|
|
48
53
|
|
|
49
54
|
Args:
|
|
50
55
|
response: The response to process
|
|
51
|
-
context: Execution context
|
|
52
56
|
|
|
53
57
|
Returns:
|
|
54
58
|
Modified response, or None to skip further processing
|
|
@@ -58,13 +62,11 @@ class Middleware(Protocol):
|
|
|
58
62
|
async def on_error(
|
|
59
63
|
self,
|
|
60
64
|
error: Exception,
|
|
61
|
-
context: dict[str, Any],
|
|
62
65
|
) -> Exception | None:
|
|
63
66
|
"""Handle errors.
|
|
64
67
|
|
|
65
68
|
Args:
|
|
66
69
|
error: The exception that occurred
|
|
67
|
-
context: Execution context
|
|
68
70
|
|
|
69
71
|
Returns:
|
|
70
72
|
Modified exception, or None to suppress
|
|
@@ -74,13 +76,11 @@ class Middleware(Protocol):
|
|
|
74
76
|
async def on_model_stream(
|
|
75
77
|
self,
|
|
76
78
|
chunk: dict[str, Any],
|
|
77
|
-
context: dict[str, Any],
|
|
78
79
|
) -> dict[str, Any] | None:
|
|
79
80
|
"""Process streaming chunk (triggered by trigger_mode).
|
|
80
81
|
|
|
81
82
|
Args:
|
|
82
83
|
chunk: The streaming chunk
|
|
83
|
-
context: Execution context
|
|
84
84
|
|
|
85
85
|
Returns:
|
|
86
86
|
Modified chunk, or None to skip further processing
|
|
@@ -90,13 +90,11 @@ class Middleware(Protocol):
|
|
|
90
90
|
async def on_thinking_stream(
|
|
91
91
|
self,
|
|
92
92
|
chunk: dict[str, Any],
|
|
93
|
-
context: dict[str, Any],
|
|
94
93
|
) -> dict[str, Any] | None:
|
|
95
94
|
"""Process thinking stream chunk.
|
|
96
95
|
|
|
97
96
|
Args:
|
|
98
97
|
chunk: The thinking chunk with {"delta": str}
|
|
99
|
-
context: Execution context
|
|
100
98
|
|
|
101
99
|
Returns:
|
|
102
100
|
Modified chunk, or None to skip
|
|
@@ -107,16 +105,14 @@ class Middleware(Protocol):
|
|
|
107
105
|
|
|
108
106
|
async def on_agent_start(
|
|
109
107
|
self,
|
|
110
|
-
agent_id: str,
|
|
111
108
|
input_data: Any,
|
|
112
|
-
context: dict[str, Any],
|
|
113
109
|
) -> HookResult:
|
|
114
110
|
"""Called when agent starts processing.
|
|
115
111
|
|
|
112
|
+
Use get_current_ctx_or_none() to access agent_id, session_id, etc.
|
|
113
|
+
|
|
116
114
|
Args:
|
|
117
|
-
agent_id: The agent identifier
|
|
118
115
|
input_data: Input to the agent
|
|
119
|
-
context: Execution context
|
|
120
116
|
|
|
121
117
|
Returns:
|
|
122
118
|
HookResult controlling execution flow
|
|
@@ -125,16 +121,14 @@ class Middleware(Protocol):
|
|
|
125
121
|
|
|
126
122
|
async def on_agent_end(
|
|
127
123
|
self,
|
|
128
|
-
agent_id: str,
|
|
129
124
|
result: Any,
|
|
130
|
-
context: dict[str, Any],
|
|
131
125
|
) -> HookResult:
|
|
132
126
|
"""Called when agent completes processing.
|
|
133
127
|
|
|
128
|
+
Use get_current_ctx_or_none() to access agent_id, session_id, etc.
|
|
129
|
+
|
|
134
130
|
Args:
|
|
135
|
-
agent_id: The agent identifier
|
|
136
131
|
result: Agent's result
|
|
137
|
-
context: Execution context
|
|
138
132
|
|
|
139
133
|
Returns:
|
|
140
134
|
HookResult (only CONTINUE/STOP meaningful here)
|
|
@@ -145,14 +139,12 @@ class Middleware(Protocol):
|
|
|
145
139
|
self,
|
|
146
140
|
tool: "BaseTool",
|
|
147
141
|
params: dict[str, Any],
|
|
148
|
-
context: dict[str, Any],
|
|
149
142
|
) -> HookResult:
|
|
150
143
|
"""Called before tool execution.
|
|
151
144
|
|
|
152
145
|
Args:
|
|
153
146
|
tool: The tool to be called
|
|
154
147
|
params: Tool parameters
|
|
155
|
-
context: Execution context
|
|
156
148
|
|
|
157
149
|
Returns:
|
|
158
150
|
HookResult - SKIP to skip tool, RETRY to modify params
|
|
@@ -164,7 +156,7 @@ class Middleware(Protocol):
|
|
|
164
156
|
call_id: str,
|
|
165
157
|
tool_name: str,
|
|
166
158
|
delta: dict[str, Any],
|
|
167
|
-
|
|
159
|
+
accumulated_args: dict[str, Any],
|
|
168
160
|
) -> dict[str, Any] | None:
|
|
169
161
|
"""Called during streaming tool argument generation.
|
|
170
162
|
|
|
@@ -175,7 +167,7 @@ class Middleware(Protocol):
|
|
|
175
167
|
call_id: Tool call identifier
|
|
176
168
|
tool_name: Name of the tool being called
|
|
177
169
|
delta: Incremental parameter update (e.g. {"content": "more text"})
|
|
178
|
-
|
|
170
|
+
accumulated_args: Current accumulated arguments state
|
|
179
171
|
|
|
180
172
|
Returns:
|
|
181
173
|
Modified delta, or None to skip emitting this delta
|
|
@@ -186,14 +178,12 @@ class Middleware(Protocol):
|
|
|
186
178
|
self,
|
|
187
179
|
tool: "BaseTool",
|
|
188
180
|
result: "ToolResult",
|
|
189
|
-
context: dict[str, Any],
|
|
190
181
|
) -> HookResult:
|
|
191
182
|
"""Called after tool execution.
|
|
192
183
|
|
|
193
184
|
Args:
|
|
194
185
|
tool: The tool that was called
|
|
195
186
|
result: Tool execution result
|
|
196
|
-
context: Execution context
|
|
197
187
|
|
|
198
188
|
Returns:
|
|
199
189
|
HookResult - RETRY to re-execute tool
|
|
@@ -205,7 +195,6 @@ class Middleware(Protocol):
|
|
|
205
195
|
parent_agent_id: str,
|
|
206
196
|
child_agent_id: str,
|
|
207
197
|
mode: str, # "embedded" or "delegated"
|
|
208
|
-
context: dict[str, Any],
|
|
209
198
|
) -> HookResult:
|
|
210
199
|
"""Called when delegating to a sub-agent.
|
|
211
200
|
|
|
@@ -213,7 +202,6 @@ class Middleware(Protocol):
|
|
|
213
202
|
parent_agent_id: Parent agent identifier
|
|
214
203
|
child_agent_id: Child agent identifier
|
|
215
204
|
mode: Delegation mode
|
|
216
|
-
context: Execution context
|
|
217
205
|
|
|
218
206
|
Returns:
|
|
219
207
|
HookResult - SKIP to skip delegation
|
|
@@ -225,7 +213,6 @@ class Middleware(Protocol):
|
|
|
225
213
|
parent_agent_id: str,
|
|
226
214
|
child_agent_id: str,
|
|
227
215
|
result: Any,
|
|
228
|
-
context: dict[str, Any],
|
|
229
216
|
) -> HookResult:
|
|
230
217
|
"""Called when sub-agent completes.
|
|
231
218
|
|
|
@@ -233,7 +220,6 @@ class Middleware(Protocol):
|
|
|
233
220
|
parent_agent_id: Parent agent identifier
|
|
234
221
|
child_agent_id: Child agent identifier
|
|
235
222
|
result: Sub-agent's result
|
|
236
|
-
context: Execution context
|
|
237
223
|
|
|
238
224
|
Returns:
|
|
239
225
|
HookResult (for post-processing)
|
|
@@ -243,7 +229,6 @@ class Middleware(Protocol):
|
|
|
243
229
|
async def on_message_save(
|
|
244
230
|
self,
|
|
245
231
|
message: dict[str, Any],
|
|
246
|
-
context: dict[str, Any],
|
|
247
232
|
) -> dict[str, Any] | None:
|
|
248
233
|
"""Called before saving a message to history.
|
|
249
234
|
|
|
@@ -252,7 +237,6 @@ class Middleware(Protocol):
|
|
|
252
237
|
|
|
253
238
|
Args:
|
|
254
239
|
message: Message dict with 'role', 'content', etc.
|
|
255
|
-
context: Execution context
|
|
256
240
|
|
|
257
241
|
Returns:
|
|
258
242
|
Modified message, or None to skip saving
|
|
@@ -265,6 +249,9 @@ class BaseMiddleware:
|
|
|
265
249
|
|
|
266
250
|
Subclass and override specific hooks as needed.
|
|
267
251
|
All hooks have sensible pass-through defaults.
|
|
252
|
+
|
|
253
|
+
Use get_current_ctx_or_none() to access InvocationContext.
|
|
254
|
+
Use self._xxx for internal state between hooks.
|
|
268
255
|
"""
|
|
269
256
|
|
|
270
257
|
_config: MiddlewareConfig = MiddlewareConfig()
|
|
@@ -278,7 +265,6 @@ class BaseMiddleware:
|
|
|
278
265
|
async def on_request(
|
|
279
266
|
self,
|
|
280
267
|
request: dict[str, Any],
|
|
281
|
-
context: dict[str, Any],
|
|
282
268
|
) -> dict[str, Any] | None:
|
|
283
269
|
"""Default: pass through."""
|
|
284
270
|
return request
|
|
@@ -286,7 +272,6 @@ class BaseMiddleware:
|
|
|
286
272
|
async def on_response(
|
|
287
273
|
self,
|
|
288
274
|
response: dict[str, Any],
|
|
289
|
-
context: dict[str, Any],
|
|
290
275
|
) -> dict[str, Any] | None:
|
|
291
276
|
"""Default: pass through."""
|
|
292
277
|
return response
|
|
@@ -294,7 +279,6 @@ class BaseMiddleware:
|
|
|
294
279
|
async def on_error(
|
|
295
280
|
self,
|
|
296
281
|
error: Exception,
|
|
297
|
-
context: dict[str, Any],
|
|
298
282
|
) -> Exception | None:
|
|
299
283
|
"""Default: re-raise error."""
|
|
300
284
|
return error
|
|
@@ -302,7 +286,6 @@ class BaseMiddleware:
|
|
|
302
286
|
async def on_model_stream(
|
|
303
287
|
self,
|
|
304
288
|
chunk: dict[str, Any],
|
|
305
|
-
context: dict[str, Any],
|
|
306
289
|
) -> dict[str, Any] | None:
|
|
307
290
|
"""Default: pass through."""
|
|
308
291
|
return chunk
|
|
@@ -310,7 +293,6 @@ class BaseMiddleware:
|
|
|
310
293
|
async def on_thinking_stream(
|
|
311
294
|
self,
|
|
312
295
|
chunk: dict[str, Any],
|
|
313
|
-
context: dict[str, Any],
|
|
314
296
|
) -> dict[str, Any] | None:
|
|
315
297
|
"""Default: pass through."""
|
|
316
298
|
return chunk
|
|
@@ -319,18 +301,14 @@ class BaseMiddleware:
|
|
|
319
301
|
|
|
320
302
|
async def on_agent_start(
|
|
321
303
|
self,
|
|
322
|
-
agent_id: str,
|
|
323
304
|
input_data: Any,
|
|
324
|
-
context: dict[str, Any],
|
|
325
305
|
) -> HookResult:
|
|
326
306
|
"""Default: continue."""
|
|
327
307
|
return HookResult.proceed()
|
|
328
308
|
|
|
329
309
|
async def on_agent_end(
|
|
330
310
|
self,
|
|
331
|
-
agent_id: str,
|
|
332
311
|
result: Any,
|
|
333
|
-
context: dict[str, Any],
|
|
334
312
|
) -> HookResult:
|
|
335
313
|
"""Default: continue."""
|
|
336
314
|
return HookResult.proceed()
|
|
@@ -339,7 +317,6 @@ class BaseMiddleware:
|
|
|
339
317
|
self,
|
|
340
318
|
tool: "BaseTool",
|
|
341
319
|
params: dict[str, Any],
|
|
342
|
-
context: dict[str, Any],
|
|
343
320
|
) -> HookResult:
|
|
344
321
|
"""Default: continue."""
|
|
345
322
|
return HookResult.proceed()
|
|
@@ -349,7 +326,7 @@ class BaseMiddleware:
|
|
|
349
326
|
call_id: str,
|
|
350
327
|
tool_name: str,
|
|
351
328
|
delta: dict[str, Any],
|
|
352
|
-
|
|
329
|
+
accumulated_args: dict[str, Any],
|
|
353
330
|
) -> dict[str, Any] | None:
|
|
354
331
|
"""Default: pass through."""
|
|
355
332
|
return delta
|
|
@@ -358,7 +335,6 @@ class BaseMiddleware:
|
|
|
358
335
|
self,
|
|
359
336
|
tool: "BaseTool",
|
|
360
337
|
result: "ToolResult",
|
|
361
|
-
context: dict[str, Any],
|
|
362
338
|
) -> HookResult:
|
|
363
339
|
"""Default: continue."""
|
|
364
340
|
return HookResult.proceed()
|
|
@@ -368,7 +344,6 @@ class BaseMiddleware:
|
|
|
368
344
|
parent_agent_id: str,
|
|
369
345
|
child_agent_id: str,
|
|
370
346
|
mode: str,
|
|
371
|
-
context: dict[str, Any],
|
|
372
347
|
) -> HookResult:
|
|
373
348
|
"""Default: continue."""
|
|
374
349
|
return HookResult.proceed()
|
|
@@ -378,7 +353,6 @@ class BaseMiddleware:
|
|
|
378
353
|
parent_agent_id: str,
|
|
379
354
|
child_agent_id: str,
|
|
380
355
|
result: Any,
|
|
381
|
-
context: dict[str, Any],
|
|
382
356
|
) -> HookResult:
|
|
383
357
|
"""Default: continue."""
|
|
384
358
|
return HookResult.proceed()
|
|
@@ -386,7 +360,6 @@ class BaseMiddleware:
|
|
|
386
360
|
async def on_message_save(
|
|
387
361
|
self,
|
|
388
362
|
message: dict[str, Any],
|
|
389
|
-
context: dict[str, Any],
|
|
390
363
|
) -> dict[str, Any] | None:
|
|
391
364
|
"""Default: pass through."""
|
|
392
365
|
return message
|
aury/agents/middleware/chain.py
CHANGED
|
@@ -112,14 +112,13 @@ class MiddlewareChain:
|
|
|
112
112
|
async def process_request(
|
|
113
113
|
self,
|
|
114
114
|
request: dict[str, Any],
|
|
115
|
-
context: dict[str, Any],
|
|
116
115
|
) -> dict[str, Any] | None:
|
|
117
116
|
"""Process request through all middlewares."""
|
|
118
117
|
current = request
|
|
119
118
|
logger.debug(f"Processing request through {len(self._middlewares)} middlewares")
|
|
120
119
|
|
|
121
120
|
for i, mw in enumerate(self._middlewares):
|
|
122
|
-
result = await mw.on_request(current
|
|
121
|
+
result = await mw.on_request(current)
|
|
123
122
|
if result is None:
|
|
124
123
|
logger.info(f"Middleware #{i} blocked request")
|
|
125
124
|
return None
|
|
@@ -131,14 +130,13 @@ class MiddlewareChain:
|
|
|
131
130
|
async def process_response(
|
|
132
131
|
self,
|
|
133
132
|
response: dict[str, Any],
|
|
134
|
-
context: dict[str, Any],
|
|
135
133
|
) -> dict[str, Any] | None:
|
|
136
134
|
"""Process response through all middlewares (reverse order)."""
|
|
137
135
|
current = response
|
|
138
136
|
logger.debug(f"Processing response through {len(self._middlewares)} middlewares (reverse order)")
|
|
139
137
|
|
|
140
138
|
for i, mw in enumerate(reversed(self._middlewares)):
|
|
141
|
-
result = await mw.on_response(current
|
|
139
|
+
result = await mw.on_response(current)
|
|
142
140
|
if result is None:
|
|
143
141
|
logger.info(f"Middleware #{i} blocked response")
|
|
144
142
|
return None
|
|
@@ -150,14 +148,13 @@ class MiddlewareChain:
|
|
|
150
148
|
async def process_error(
|
|
151
149
|
self,
|
|
152
150
|
error: Exception,
|
|
153
|
-
context: dict[str, Any],
|
|
154
151
|
) -> Exception | None:
|
|
155
152
|
"""Process error through all middlewares."""
|
|
156
153
|
current = error
|
|
157
154
|
logger.debug(f"Processing error {type(error).__name__} through {len(self._middlewares)} middlewares")
|
|
158
155
|
|
|
159
156
|
for i, mw in enumerate(self._middlewares):
|
|
160
|
-
result = await mw.on_error(current
|
|
157
|
+
result = await mw.on_error(current)
|
|
161
158
|
if result is None:
|
|
162
159
|
logger.info(f"Middleware #{i} suppressed error")
|
|
163
160
|
return None
|
|
@@ -169,13 +166,11 @@ class MiddlewareChain:
|
|
|
169
166
|
async def process_stream_chunk(
|
|
170
167
|
self,
|
|
171
168
|
chunk: dict[str, Any],
|
|
172
|
-
context: dict[str, Any],
|
|
173
169
|
) -> dict[str, Any] | None:
|
|
174
170
|
"""Process streaming chunk through middlewares based on trigger mode."""
|
|
175
171
|
text = chunk.get("text", chunk.get("delta", ""))
|
|
176
172
|
self._token_buffer += text
|
|
177
173
|
self._token_count += 1
|
|
178
|
-
logger.debug(f"Processing stream chunk, token_count={self._token_count}, triggered_middlewares=?")
|
|
179
174
|
|
|
180
175
|
current = chunk
|
|
181
176
|
triggered_count = 0
|
|
@@ -185,14 +180,16 @@ class MiddlewareChain:
|
|
|
185
180
|
|
|
186
181
|
if should_trigger:
|
|
187
182
|
triggered_count += 1
|
|
188
|
-
|
|
189
|
-
result = await mw.on_model_stream(current, context)
|
|
183
|
+
result = await mw.on_model_stream(current)
|
|
190
184
|
if result is None:
|
|
191
185
|
logger.info(f"Middleware #{i} blocked stream chunk")
|
|
192
186
|
return None
|
|
193
187
|
current = result
|
|
194
188
|
|
|
195
|
-
|
|
189
|
+
# Log only every 50 tokens to reduce noise
|
|
190
|
+
if self._token_count % 50 == 0:
|
|
191
|
+
logger.debug(f"Stream progress: token_count={self._token_count}, middlewares={len(self._middlewares)}")
|
|
192
|
+
|
|
196
193
|
return current
|
|
197
194
|
|
|
198
195
|
async def process_tool_call_delta(
|
|
@@ -200,7 +197,7 @@ class MiddlewareChain:
|
|
|
200
197
|
call_id: str,
|
|
201
198
|
tool_name: str,
|
|
202
199
|
delta: dict[str, Any],
|
|
203
|
-
|
|
200
|
+
accumulated_args: dict[str, Any],
|
|
204
201
|
) -> dict[str, Any] | None:
|
|
205
202
|
"""Process tool call delta through all middlewares.
|
|
206
203
|
|
|
@@ -208,7 +205,7 @@ class MiddlewareChain:
|
|
|
208
205
|
call_id: Tool call identifier
|
|
209
206
|
tool_name: Name of the tool being called
|
|
210
207
|
delta: Incremental parameter update
|
|
211
|
-
|
|
208
|
+
accumulated_args: Current accumulated arguments state
|
|
212
209
|
|
|
213
210
|
Returns:
|
|
214
211
|
Modified delta, or None to skip emitting
|
|
@@ -217,7 +214,7 @@ class MiddlewareChain:
|
|
|
217
214
|
logger.debug(f"Processing tool_call_delta for {tool_name} (call_id={call_id}) through {len(self._middlewares)} middlewares")
|
|
218
215
|
|
|
219
216
|
for i, mw in enumerate(self._middlewares):
|
|
220
|
-
result = await mw.on_tool_call_delta(call_id, tool_name, current,
|
|
217
|
+
result = await mw.on_tool_call_delta(call_id, tool_name, current, accumulated_args)
|
|
221
218
|
if result is None:
|
|
222
219
|
logger.info(f"Middleware #{i} blocked tool_call_delta")
|
|
223
220
|
return None
|
|
@@ -259,19 +256,17 @@ class MiddlewareChain:
|
|
|
259
256
|
|
|
260
257
|
async def process_agent_start(
|
|
261
258
|
self,
|
|
262
|
-
agent_id: str,
|
|
263
259
|
input_data: Any,
|
|
264
|
-
context: dict[str, Any],
|
|
265
260
|
) -> HookResult:
|
|
266
261
|
"""Process agent start through all middlewares.
|
|
267
262
|
|
|
268
263
|
Returns:
|
|
269
264
|
First non-CONTINUE result, or CONTINUE if all pass
|
|
270
265
|
"""
|
|
271
|
-
logger.debug(f"Processing agent_start
|
|
266
|
+
logger.debug(f"Processing agent_start, {len(self._middlewares)} middlewares")
|
|
272
267
|
for i, mw in enumerate(self._middlewares):
|
|
273
268
|
if hasattr(mw, 'on_agent_start'):
|
|
274
|
-
result = await mw.on_agent_start(
|
|
269
|
+
result = await mw.on_agent_start(input_data)
|
|
275
270
|
if result.action != HookAction.CONTINUE:
|
|
276
271
|
logger.info(f"Middleware #{i} returned {result.action} on agent_start")
|
|
277
272
|
return result
|
|
@@ -280,15 +275,13 @@ class MiddlewareChain:
|
|
|
280
275
|
|
|
281
276
|
async def process_agent_end(
|
|
282
277
|
self,
|
|
283
|
-
agent_id: str,
|
|
284
278
|
result: Any,
|
|
285
|
-
context: dict[str, Any],
|
|
286
279
|
) -> HookResult:
|
|
287
280
|
"""Process agent end through all middlewares (reverse order)."""
|
|
288
|
-
logger.debug(f"Processing agent_end
|
|
281
|
+
logger.debug(f"Processing agent_end, {len(self._middlewares)} middlewares (reverse order)")
|
|
289
282
|
for i, mw in enumerate(reversed(self._middlewares)):
|
|
290
283
|
if hasattr(mw, 'on_agent_end'):
|
|
291
|
-
hook_result = await mw.on_agent_end(
|
|
284
|
+
hook_result = await mw.on_agent_end(result)
|
|
292
285
|
if hook_result.action != HookAction.CONTINUE:
|
|
293
286
|
logger.info(f"Middleware #{i} returned {hook_result.action} on agent_end")
|
|
294
287
|
return hook_result
|
|
@@ -299,7 +292,6 @@ class MiddlewareChain:
|
|
|
299
292
|
self,
|
|
300
293
|
tool: "BaseTool",
|
|
301
294
|
params: dict[str, Any],
|
|
302
|
-
context: dict[str, Any],
|
|
303
295
|
) -> HookResult:
|
|
304
296
|
"""Process tool call through all middlewares.
|
|
305
297
|
|
|
@@ -309,7 +301,7 @@ class MiddlewareChain:
|
|
|
309
301
|
logger.debug(f"Processing tool_call for tool={tool.name}, {len(self._middlewares)} middlewares")
|
|
310
302
|
for i, mw in enumerate(self._middlewares):
|
|
311
303
|
if hasattr(mw, 'on_tool_call'):
|
|
312
|
-
result = await mw.on_tool_call(tool, params
|
|
304
|
+
result = await mw.on_tool_call(tool, params)
|
|
313
305
|
if result.action != HookAction.CONTINUE:
|
|
314
306
|
logger.info(f"Middleware #{i} returned {result.action} on tool_call for tool={tool.name}")
|
|
315
307
|
return result
|
|
@@ -320,13 +312,12 @@ class MiddlewareChain:
|
|
|
320
312
|
self,
|
|
321
313
|
tool: "BaseTool",
|
|
322
314
|
result: "ToolResult",
|
|
323
|
-
context: dict[str, Any],
|
|
324
315
|
) -> HookResult:
|
|
325
316
|
"""Process tool end through all middlewares (reverse order)."""
|
|
326
317
|
logger.debug(f"Processing tool_end for tool={tool.name}, {len(self._middlewares)} middlewares (reverse order)")
|
|
327
318
|
for i, mw in enumerate(reversed(self._middlewares)):
|
|
328
319
|
if hasattr(mw, 'on_tool_end'):
|
|
329
|
-
hook_result = await mw.on_tool_end(tool, result
|
|
320
|
+
hook_result = await mw.on_tool_end(tool, result)
|
|
330
321
|
if hook_result.action != HookAction.CONTINUE:
|
|
331
322
|
logger.info(f"Middleware #{i} returned {hook_result.action} on tool_end for tool={tool.name}")
|
|
332
323
|
return hook_result
|
|
@@ -338,14 +329,13 @@ class MiddlewareChain:
|
|
|
338
329
|
parent_agent_id: str,
|
|
339
330
|
child_agent_id: str,
|
|
340
331
|
mode: str,
|
|
341
|
-
context: dict[str, Any],
|
|
342
332
|
) -> HookResult:
|
|
343
333
|
"""Process sub-agent start through all middlewares."""
|
|
344
334
|
logger.debug(f"Processing subagent_start, parent={parent_agent_id}, child={child_agent_id}, mode={mode}, {len(self._middlewares)} middlewares")
|
|
345
335
|
for i, mw in enumerate(self._middlewares):
|
|
346
336
|
if hasattr(mw, 'on_subagent_start'):
|
|
347
337
|
result = await mw.on_subagent_start(
|
|
348
|
-
parent_agent_id, child_agent_id, mode
|
|
338
|
+
parent_agent_id, child_agent_id, mode
|
|
349
339
|
)
|
|
350
340
|
if result.action != HookAction.CONTINUE:
|
|
351
341
|
logger.info(f"Middleware #{i} returned {result.action} on subagent_start")
|
|
@@ -358,14 +348,13 @@ class MiddlewareChain:
|
|
|
358
348
|
parent_agent_id: str,
|
|
359
349
|
child_agent_id: str,
|
|
360
350
|
result: Any,
|
|
361
|
-
context: dict[str, Any],
|
|
362
351
|
) -> HookResult:
|
|
363
352
|
"""Process sub-agent end through all middlewares (reverse order)."""
|
|
364
353
|
logger.debug(f"Processing subagent_end, parent={parent_agent_id}, child={child_agent_id}, {len(self._middlewares)} middlewares (reverse order)")
|
|
365
354
|
for i, mw in enumerate(reversed(self._middlewares)):
|
|
366
355
|
if hasattr(mw, 'on_subagent_end'):
|
|
367
356
|
hook_result = await mw.on_subagent_end(
|
|
368
|
-
parent_agent_id, child_agent_id, result
|
|
357
|
+
parent_agent_id, child_agent_id, result
|
|
369
358
|
)
|
|
370
359
|
if hook_result.action != HookAction.CONTINUE:
|
|
371
360
|
logger.info(f"Middleware #{i} returned {hook_result.action} on subagent_end")
|
|
@@ -376,13 +365,11 @@ class MiddlewareChain:
|
|
|
376
365
|
async def process_message_save(
|
|
377
366
|
self,
|
|
378
367
|
message: dict[str, Any],
|
|
379
|
-
context: dict[str, Any],
|
|
380
368
|
) -> dict[str, Any] | None:
|
|
381
369
|
"""Process message save through all middlewares.
|
|
382
370
|
|
|
383
371
|
Args:
|
|
384
372
|
message: Message to be saved
|
|
385
|
-
context: Execution context
|
|
386
373
|
|
|
387
374
|
Returns:
|
|
388
375
|
Modified message, or None to skip saving
|
|
@@ -392,7 +379,7 @@ class MiddlewareChain:
|
|
|
392
379
|
|
|
393
380
|
for i, mw in enumerate(self._middlewares):
|
|
394
381
|
if hasattr(mw, 'on_message_save'):
|
|
395
|
-
result = await mw.on_message_save(current
|
|
382
|
+
result = await mw.on_message_save(current)
|
|
396
383
|
if result is None:
|
|
397
384
|
logger.info(f"Middleware #{i} blocked message save for role={message.get('role')}")
|
|
398
385
|
return None
|
|
@@ -56,37 +56,34 @@ class MessageBackendMiddleware(BaseMiddleware):
|
|
|
56
56
|
async def on_message_save(
|
|
57
57
|
self,
|
|
58
58
|
message: dict[str, Any],
|
|
59
|
-
context: dict[str, Any],
|
|
60
59
|
) -> dict[str, Any] | None:
|
|
61
60
|
"""Save message via backends.message.
|
|
62
61
|
|
|
63
62
|
Args:
|
|
64
63
|
message: Message dict with 'role', 'content', etc.
|
|
65
|
-
context: Execution context with 'session_id', 'agent_id', 'backends'
|
|
66
64
|
|
|
67
65
|
Returns:
|
|
68
66
|
The message (pass through to other middlewares)
|
|
69
67
|
"""
|
|
70
68
|
from ..core.context import get_current_ctx_or_none
|
|
71
69
|
|
|
72
|
-
session_id = context.get("session_id", "")
|
|
73
|
-
if not session_id:
|
|
74
|
-
return message
|
|
75
|
-
|
|
76
70
|
# Get MessageBackend from context
|
|
77
71
|
ctx = get_current_ctx_or_none()
|
|
78
72
|
if ctx is None or ctx.backends is None or ctx.backends.message is None:
|
|
79
73
|
# No backend available, pass through
|
|
80
74
|
return message
|
|
81
75
|
|
|
76
|
+
session_id = ctx.session_id or ""
|
|
77
|
+
if not session_id:
|
|
78
|
+
return message
|
|
79
|
+
|
|
82
80
|
backend = ctx.backends.message
|
|
83
81
|
|
|
84
82
|
# Extract message fields
|
|
85
83
|
role = message.get("role", "")
|
|
86
84
|
content = message.get("content", "")
|
|
87
|
-
invocation_id =
|
|
88
|
-
agent_id =
|
|
89
|
-
namespace = context.get("namespace")
|
|
85
|
+
invocation_id = ctx.invocation_id or ""
|
|
86
|
+
agent_id = ctx.agent_id
|
|
90
87
|
tool_call_id = message.get("tool_call_id")
|
|
91
88
|
|
|
92
89
|
# Build message dict for backend
|
|
@@ -107,7 +104,6 @@ class MessageBackendMiddleware(BaseMiddleware):
|
|
|
107
104
|
message=msg_dict,
|
|
108
105
|
type="truncated",
|
|
109
106
|
agent_id=agent_id,
|
|
110
|
-
namespace=namespace,
|
|
111
107
|
invocation_id=invocation_id,
|
|
112
108
|
)
|
|
113
109
|
|
|
@@ -118,7 +114,6 @@ class MessageBackendMiddleware(BaseMiddleware):
|
|
|
118
114
|
message=message, # Full original message
|
|
119
115
|
type="raw",
|
|
120
116
|
agent_id=agent_id,
|
|
121
|
-
namespace=namespace,
|
|
122
117
|
invocation_id=invocation_id,
|
|
123
118
|
)
|
|
124
119
|
|