aury-agent 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/agents/context_providers/message.py +8 -5
- aury/agents/core/base.py +11 -0
- aury/agents/core/factory.py +8 -0
- aury/agents/core/parallel.py +26 -4
- aury/agents/core/state.py +25 -0
- aury/agents/core/types/tool.py +1 -0
- aury/agents/hitl/ask_user.py +44 -0
- aury/agents/llm/adapter.py +55 -26
- aury/agents/llm/openai.py +5 -1
- aury/agents/memory/manager.py +33 -2
- aury/agents/messages/store.py +27 -1
- aury/agents/middleware/base.py +57 -0
- aury/agents/middleware/chain.py +81 -18
- aury/agents/react/agent.py +161 -1484
- aury/agents/react/context.py +309 -0
- aury/agents/react/factory.py +301 -0
- aury/agents/react/pause.py +241 -0
- aury/agents/react/persistence.py +182 -0
- aury/agents/react/step.py +680 -0
- aury/agents/react/tools.py +323 -0
- aury/agents/tool/builtin/bash.py +11 -0
- aury/agents/tool/builtin/delegate.py +38 -3
- aury/agents/tool/builtin/edit.py +16 -0
- aury/agents/tool/builtin/plan.py +19 -0
- aury/agents/tool/builtin/read.py +13 -0
- aury/agents/tool/builtin/thinking.py +10 -4
- aury/agents/tool/builtin/yield_result.py +9 -6
- aury/agents/tool/set.py +23 -0
- aury/agents/workflow/adapter.py +22 -3
- aury/agents/workflow/executor.py +51 -7
- {aury_agent-0.0.4.dist-info → aury_agent-0.0.6.dist-info}/METADATA +1 -1
- {aury_agent-0.0.4.dist-info → aury_agent-0.0.6.dist-info}/RECORD +34 -28
- {aury_agent-0.0.4.dist-info → aury_agent-0.0.6.dist-info}/WHEEL +0 -0
- {aury_agent-0.0.4.dist-info → aury_agent-0.0.6.dist-info}/entry_points.txt +0 -0
aury/agents/messages/store.py
CHANGED
|
@@ -72,10 +72,22 @@ class InMemoryMessageStore:
|
|
|
72
72
|
message: Message,
|
|
73
73
|
namespace: str | None = None,
|
|
74
74
|
) -> None:
|
|
75
|
+
from ..core.logging import storage_logger as logger
|
|
76
|
+
|
|
75
77
|
key = self._make_key(session_id, namespace)
|
|
76
78
|
if key not in self._messages:
|
|
77
79
|
self._messages[key] = []
|
|
78
80
|
self._messages[key].append(message)
|
|
81
|
+
|
|
82
|
+
logger.debug(
|
|
83
|
+
"Message stored",
|
|
84
|
+
extra={
|
|
85
|
+
"session_id": session_id,
|
|
86
|
+
"invocation_id": getattr(message, "invocation_id", None),
|
|
87
|
+
"role": getattr(message, "role", None),
|
|
88
|
+
"namespace": namespace,
|
|
89
|
+
},
|
|
90
|
+
)
|
|
79
91
|
|
|
80
92
|
async def get_all(
|
|
81
93
|
self,
|
|
@@ -101,6 +113,8 @@ class InMemoryMessageStore:
|
|
|
101
113
|
invocation_id: str,
|
|
102
114
|
namespace: str | None = None,
|
|
103
115
|
) -> int:
|
|
116
|
+
from ..core.logging import storage_logger as logger
|
|
117
|
+
|
|
104
118
|
key = self._make_key(session_id, namespace)
|
|
105
119
|
if key not in self._messages:
|
|
106
120
|
return 0
|
|
@@ -109,7 +123,19 @@ class InMemoryMessageStore:
|
|
|
109
123
|
self._messages[key] = [
|
|
110
124
|
m for m in original if m.invocation_id != invocation_id
|
|
111
125
|
]
|
|
112
|
-
|
|
126
|
+
deleted_count = len(original) - len(self._messages[key])
|
|
127
|
+
|
|
128
|
+
if deleted_count > 0:
|
|
129
|
+
logger.debug(
|
|
130
|
+
"Messages deleted by invocation",
|
|
131
|
+
extra={
|
|
132
|
+
"session_id": session_id,
|
|
133
|
+
"invocation_id": invocation_id,
|
|
134
|
+
"count": deleted_count,
|
|
135
|
+
},
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
return deleted_count
|
|
113
139
|
|
|
114
140
|
|
|
115
141
|
__all__ = [
|
aury/agents/middleware/base.py
CHANGED
|
@@ -87,6 +87,22 @@ class Middleware(Protocol):
|
|
|
87
87
|
"""
|
|
88
88
|
...
|
|
89
89
|
|
|
90
|
+
async def on_thinking_stream(
|
|
91
|
+
self,
|
|
92
|
+
chunk: dict[str, Any],
|
|
93
|
+
context: dict[str, Any],
|
|
94
|
+
) -> dict[str, Any] | None:
|
|
95
|
+
"""Process thinking stream chunk.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
chunk: The thinking chunk with {"delta": str}
|
|
99
|
+
context: Execution context
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
Modified chunk, or None to skip
|
|
103
|
+
"""
|
|
104
|
+
...
|
|
105
|
+
|
|
90
106
|
# ========== Agent Lifecycle Hooks ==========
|
|
91
107
|
|
|
92
108
|
async def on_agent_start(
|
|
@@ -143,6 +159,29 @@ class Middleware(Protocol):
|
|
|
143
159
|
"""
|
|
144
160
|
...
|
|
145
161
|
|
|
162
|
+
async def on_tool_call_delta(
|
|
163
|
+
self,
|
|
164
|
+
call_id: str,
|
|
165
|
+
tool_name: str,
|
|
166
|
+
delta: dict[str, Any],
|
|
167
|
+
context: dict[str, Any],
|
|
168
|
+
) -> dict[str, Any] | None:
|
|
169
|
+
"""Called during streaming tool argument generation.
|
|
170
|
+
|
|
171
|
+
Only triggered for tools with stream_arguments=True.
|
|
172
|
+
Receives incremental updates as LLM generates tool parameters.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
call_id: Tool call identifier
|
|
176
|
+
tool_name: Name of the tool being called
|
|
177
|
+
delta: Incremental parameter update (e.g. {"content": "more text"})
|
|
178
|
+
context: Execution context with 'accumulated_args' containing current state
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
Modified delta, or None to skip emitting this delta
|
|
182
|
+
"""
|
|
183
|
+
...
|
|
184
|
+
|
|
146
185
|
async def on_tool_end(
|
|
147
186
|
self,
|
|
148
187
|
tool: "BaseTool",
|
|
@@ -268,6 +307,14 @@ class BaseMiddleware:
|
|
|
268
307
|
"""Default: pass through."""
|
|
269
308
|
return chunk
|
|
270
309
|
|
|
310
|
+
async def on_thinking_stream(
|
|
311
|
+
self,
|
|
312
|
+
chunk: dict[str, Any],
|
|
313
|
+
context: dict[str, Any],
|
|
314
|
+
) -> dict[str, Any] | None:
|
|
315
|
+
"""Default: pass through."""
|
|
316
|
+
return chunk
|
|
317
|
+
|
|
271
318
|
# ========== Agent Lifecycle Hooks ==========
|
|
272
319
|
|
|
273
320
|
async def on_agent_start(
|
|
@@ -297,6 +344,16 @@ class BaseMiddleware:
|
|
|
297
344
|
"""Default: continue."""
|
|
298
345
|
return HookResult.proceed()
|
|
299
346
|
|
|
347
|
+
async def on_tool_call_delta(
|
|
348
|
+
self,
|
|
349
|
+
call_id: str,
|
|
350
|
+
tool_name: str,
|
|
351
|
+
delta: dict[str, Any],
|
|
352
|
+
context: dict[str, Any],
|
|
353
|
+
) -> dict[str, Any] | None:
|
|
354
|
+
"""Default: pass through."""
|
|
355
|
+
return delta
|
|
356
|
+
|
|
300
357
|
async def on_tool_end(
|
|
301
358
|
self,
|
|
302
359
|
tool: "BaseTool",
|
aury/agents/middleware/chain.py
CHANGED
|
@@ -29,6 +29,7 @@ class MiddlewareChain:
|
|
|
29
29
|
|
|
30
30
|
# Add initial middlewares if provided
|
|
31
31
|
if middlewares:
|
|
32
|
+
logger.debug(f"MiddlewareChain init with {len(middlewares)} middlewares")
|
|
32
33
|
for mw in middlewares:
|
|
33
34
|
self.use(mw)
|
|
34
35
|
|
|
@@ -55,6 +56,7 @@ class MiddlewareChain:
|
|
|
55
56
|
entry = MiddlewareEntry(middleware=middleware, inherit=effective_inherit)
|
|
56
57
|
self._entries.append(entry)
|
|
57
58
|
self._entries.sort(key=lambda e: e.middleware.config.priority)
|
|
59
|
+
logger.debug(f"Added middleware to chain, priority={middleware.config.priority}, inherit={effective_inherit}, total={len(self._entries)}")
|
|
58
60
|
return self
|
|
59
61
|
|
|
60
62
|
def remove(self, middleware: Middleware) -> "MiddlewareChain":
|
|
@@ -114,13 +116,16 @@ class MiddlewareChain:
|
|
|
114
116
|
) -> dict[str, Any] | None:
|
|
115
117
|
"""Process request through all middlewares."""
|
|
116
118
|
current = request
|
|
119
|
+
logger.debug(f"Processing request through {len(self._middlewares)} middlewares")
|
|
117
120
|
|
|
118
|
-
for mw in self._middlewares:
|
|
121
|
+
for i, mw in enumerate(self._middlewares):
|
|
119
122
|
result = await mw.on_request(current, context)
|
|
120
123
|
if result is None:
|
|
124
|
+
logger.info(f"Middleware #{i} blocked request")
|
|
121
125
|
return None
|
|
122
126
|
current = result
|
|
123
127
|
|
|
128
|
+
logger.debug("Request processing completed")
|
|
124
129
|
return current
|
|
125
130
|
|
|
126
131
|
async def process_response(
|
|
@@ -130,13 +135,16 @@ class MiddlewareChain:
|
|
|
130
135
|
) -> dict[str, Any] | None:
|
|
131
136
|
"""Process response through all middlewares (reverse order)."""
|
|
132
137
|
current = response
|
|
138
|
+
logger.debug(f"Processing response through {len(self._middlewares)} middlewares (reverse order)")
|
|
133
139
|
|
|
134
|
-
for mw in reversed(self._middlewares):
|
|
140
|
+
for i, mw in enumerate(reversed(self._middlewares)):
|
|
135
141
|
result = await mw.on_response(current, context)
|
|
136
142
|
if result is None:
|
|
143
|
+
logger.info(f"Middleware #{i} blocked response")
|
|
137
144
|
return None
|
|
138
145
|
current = result
|
|
139
146
|
|
|
147
|
+
logger.debug("Response processing completed")
|
|
140
148
|
return current
|
|
141
149
|
|
|
142
150
|
async def process_error(
|
|
@@ -146,13 +154,16 @@ class MiddlewareChain:
|
|
|
146
154
|
) -> Exception | None:
|
|
147
155
|
"""Process error through all middlewares."""
|
|
148
156
|
current = error
|
|
157
|
+
logger.debug(f"Processing error {type(error).__name__} through {len(self._middlewares)} middlewares")
|
|
149
158
|
|
|
150
|
-
for mw in self._middlewares:
|
|
159
|
+
for i, mw in enumerate(self._middlewares):
|
|
151
160
|
result = await mw.on_error(current, context)
|
|
152
161
|
if result is None:
|
|
162
|
+
logger.info(f"Middleware #{i} suppressed error")
|
|
153
163
|
return None
|
|
154
164
|
current = result
|
|
155
165
|
|
|
166
|
+
logger.debug("Error processing completed")
|
|
156
167
|
return current
|
|
157
168
|
|
|
158
169
|
async def process_stream_chunk(
|
|
@@ -164,18 +175,55 @@ class MiddlewareChain:
|
|
|
164
175
|
text = chunk.get("text", chunk.get("delta", ""))
|
|
165
176
|
self._token_buffer += text
|
|
166
177
|
self._token_count += 1
|
|
178
|
+
logger.debug(f"Processing stream chunk, token_count={self._token_count}, triggered_middlewares=?")
|
|
167
179
|
|
|
168
180
|
current = chunk
|
|
181
|
+
triggered_count = 0
|
|
169
182
|
|
|
170
|
-
for mw in self._middlewares:
|
|
183
|
+
for i, mw in enumerate(self._middlewares):
|
|
171
184
|
should_trigger = self._should_trigger(mw, text)
|
|
172
185
|
|
|
173
186
|
if should_trigger:
|
|
187
|
+
triggered_count += 1
|
|
188
|
+
logger.debug(f"Middleware #{i} triggered, mode={mw.config.trigger_mode}")
|
|
174
189
|
result = await mw.on_model_stream(current, context)
|
|
175
190
|
if result is None:
|
|
191
|
+
logger.info(f"Middleware #{i} blocked stream chunk")
|
|
176
192
|
return None
|
|
177
193
|
current = result
|
|
178
194
|
|
|
195
|
+
logger.debug(f"Stream chunk processing completed, {triggered_count} middlewares triggered")
|
|
196
|
+
return current
|
|
197
|
+
|
|
198
|
+
async def process_tool_call_delta(
|
|
199
|
+
self,
|
|
200
|
+
call_id: str,
|
|
201
|
+
tool_name: str,
|
|
202
|
+
delta: dict[str, Any],
|
|
203
|
+
context: dict[str, Any],
|
|
204
|
+
) -> dict[str, Any] | None:
|
|
205
|
+
"""Process tool call delta through all middlewares.
|
|
206
|
+
|
|
207
|
+
Args:
|
|
208
|
+
call_id: Tool call identifier
|
|
209
|
+
tool_name: Name of the tool being called
|
|
210
|
+
delta: Incremental parameter update
|
|
211
|
+
context: Execution context with 'accumulated_args'
|
|
212
|
+
|
|
213
|
+
Returns:
|
|
214
|
+
Modified delta, or None to skip emitting
|
|
215
|
+
"""
|
|
216
|
+
current = delta
|
|
217
|
+
logger.debug(f"Processing tool_call_delta for {tool_name} (call_id={call_id}) through {len(self._middlewares)} middlewares")
|
|
218
|
+
|
|
219
|
+
for i, mw in enumerate(self._middlewares):
|
|
220
|
+
result = await mw.on_tool_call_delta(call_id, tool_name, current, context)
|
|
221
|
+
if result is None:
|
|
222
|
+
logger.info(f"Middleware #{i} blocked tool_call_delta")
|
|
223
|
+
return None
|
|
224
|
+
current = result
|
|
225
|
+
|
|
226
|
+
logger.debug("Tool call delta processing completed")
|
|
179
227
|
return current
|
|
180
228
|
|
|
181
229
|
def _should_trigger(self, middleware: Middleware, text: str) -> bool:
|
|
@@ -198,6 +246,7 @@ class MiddlewareChain:
|
|
|
198
246
|
|
|
199
247
|
def reset_stream_state(self) -> None:
|
|
200
248
|
"""Reset streaming state (call at start of new stream)."""
|
|
249
|
+
logger.debug("Resetting stream state")
|
|
201
250
|
self._token_buffer = ""
|
|
202
251
|
self._token_count = 0
|
|
203
252
|
|
|
@@ -219,12 +268,14 @@ class MiddlewareChain:
|
|
|
219
268
|
Returns:
|
|
220
269
|
First non-CONTINUE result, or CONTINUE if all pass
|
|
221
270
|
"""
|
|
222
|
-
for
|
|
271
|
+
logger.debug(f"Processing agent_start for agent_id={agent_id}, {len(self._middlewares)} middlewares")
|
|
272
|
+
for i, mw in enumerate(self._middlewares):
|
|
223
273
|
if hasattr(mw, 'on_agent_start'):
|
|
224
274
|
result = await mw.on_agent_start(agent_id, input_data, context)
|
|
225
275
|
if result.action != HookAction.CONTINUE:
|
|
226
|
-
logger.
|
|
276
|
+
logger.info(f"Middleware #{i} returned {result.action} on agent_start")
|
|
227
277
|
return result
|
|
278
|
+
logger.debug("Agent start processing completed, all middlewares passed")
|
|
228
279
|
return HookResult.proceed()
|
|
229
280
|
|
|
230
281
|
async def process_agent_end(
|
|
@@ -234,12 +285,14 @@ class MiddlewareChain:
|
|
|
234
285
|
context: dict[str, Any],
|
|
235
286
|
) -> HookResult:
|
|
236
287
|
"""Process agent end through all middlewares (reverse order)."""
|
|
237
|
-
for
|
|
288
|
+
logger.debug(f"Processing agent_end for agent_id={agent_id}, {len(self._middlewares)} middlewares (reverse order)")
|
|
289
|
+
for i, mw in enumerate(reversed(self._middlewares)):
|
|
238
290
|
if hasattr(mw, 'on_agent_end'):
|
|
239
291
|
hook_result = await mw.on_agent_end(agent_id, result, context)
|
|
240
292
|
if hook_result.action != HookAction.CONTINUE:
|
|
241
|
-
logger.
|
|
293
|
+
logger.info(f"Middleware #{i} returned {hook_result.action} on agent_end")
|
|
242
294
|
return hook_result
|
|
295
|
+
logger.debug("Agent end processing completed, all middlewares passed")
|
|
243
296
|
return HookResult.proceed()
|
|
244
297
|
|
|
245
298
|
async def process_tool_call(
|
|
@@ -253,12 +306,14 @@ class MiddlewareChain:
|
|
|
253
306
|
Returns:
|
|
254
307
|
SKIP to skip tool, RETRY with modified_data to change params
|
|
255
308
|
"""
|
|
256
|
-
for
|
|
309
|
+
logger.debug(f"Processing tool_call for tool={tool.name}, {len(self._middlewares)} middlewares")
|
|
310
|
+
for i, mw in enumerate(self._middlewares):
|
|
257
311
|
if hasattr(mw, 'on_tool_call'):
|
|
258
312
|
result = await mw.on_tool_call(tool, params, context)
|
|
259
313
|
if result.action != HookAction.CONTINUE:
|
|
260
|
-
logger.
|
|
314
|
+
logger.info(f"Middleware #{i} returned {result.action} on tool_call for tool={tool.name}")
|
|
261
315
|
return result
|
|
316
|
+
logger.debug("Tool call processing completed, all middlewares passed")
|
|
262
317
|
return HookResult.proceed()
|
|
263
318
|
|
|
264
319
|
async def process_tool_end(
|
|
@@ -268,12 +323,14 @@ class MiddlewareChain:
|
|
|
268
323
|
context: dict[str, Any],
|
|
269
324
|
) -> HookResult:
|
|
270
325
|
"""Process tool end through all middlewares (reverse order)."""
|
|
271
|
-
for
|
|
326
|
+
logger.debug(f"Processing tool_end for tool={tool.name}, {len(self._middlewares)} middlewares (reverse order)")
|
|
327
|
+
for i, mw in enumerate(reversed(self._middlewares)):
|
|
272
328
|
if hasattr(mw, 'on_tool_end'):
|
|
273
329
|
hook_result = await mw.on_tool_end(tool, result, context)
|
|
274
330
|
if hook_result.action != HookAction.CONTINUE:
|
|
275
|
-
logger.
|
|
331
|
+
logger.info(f"Middleware #{i} returned {hook_result.action} on tool_end for tool={tool.name}")
|
|
276
332
|
return hook_result
|
|
333
|
+
logger.debug("Tool end processing completed, all middlewares passed")
|
|
277
334
|
return HookResult.proceed()
|
|
278
335
|
|
|
279
336
|
async def process_subagent_start(
|
|
@@ -284,14 +341,16 @@ class MiddlewareChain:
|
|
|
284
341
|
context: dict[str, Any],
|
|
285
342
|
) -> HookResult:
|
|
286
343
|
"""Process sub-agent start through all middlewares."""
|
|
287
|
-
|
|
344
|
+
logger.debug(f"Processing subagent_start, parent={parent_agent_id}, child={child_agent_id}, mode={mode}, {len(self._middlewares)} middlewares")
|
|
345
|
+
for i, mw in enumerate(self._middlewares):
|
|
288
346
|
if hasattr(mw, 'on_subagent_start'):
|
|
289
347
|
result = await mw.on_subagent_start(
|
|
290
348
|
parent_agent_id, child_agent_id, mode, context
|
|
291
349
|
)
|
|
292
350
|
if result.action != HookAction.CONTINUE:
|
|
293
|
-
logger.
|
|
351
|
+
logger.info(f"Middleware #{i} returned {result.action} on subagent_start")
|
|
294
352
|
return result
|
|
353
|
+
logger.debug("Subagent start processing completed, all middlewares passed")
|
|
295
354
|
return HookResult.proceed()
|
|
296
355
|
|
|
297
356
|
async def process_subagent_end(
|
|
@@ -302,14 +361,16 @@ class MiddlewareChain:
|
|
|
302
361
|
context: dict[str, Any],
|
|
303
362
|
) -> HookResult:
|
|
304
363
|
"""Process sub-agent end through all middlewares (reverse order)."""
|
|
305
|
-
|
|
364
|
+
logger.debug(f"Processing subagent_end, parent={parent_agent_id}, child={child_agent_id}, {len(self._middlewares)} middlewares (reverse order)")
|
|
365
|
+
for i, mw in enumerate(reversed(self._middlewares)):
|
|
306
366
|
if hasattr(mw, 'on_subagent_end'):
|
|
307
367
|
hook_result = await mw.on_subagent_end(
|
|
308
368
|
parent_agent_id, child_agent_id, result, context
|
|
309
369
|
)
|
|
310
370
|
if hook_result.action != HookAction.CONTINUE:
|
|
311
|
-
logger.
|
|
371
|
+
logger.info(f"Middleware #{i} returned {hook_result.action} on subagent_end")
|
|
312
372
|
return hook_result
|
|
373
|
+
logger.debug("Subagent end processing completed, all middlewares passed")
|
|
313
374
|
return HookResult.proceed()
|
|
314
375
|
|
|
315
376
|
async def process_message_save(
|
|
@@ -327,15 +388,17 @@ class MiddlewareChain:
|
|
|
327
388
|
Modified message, or None to skip saving
|
|
328
389
|
"""
|
|
329
390
|
current = message
|
|
391
|
+
logger.debug(f"Processing message_save, role={message.get('role')}, {len(self._middlewares)} middlewares")
|
|
330
392
|
|
|
331
|
-
for mw in self._middlewares:
|
|
393
|
+
for i, mw in enumerate(self._middlewares):
|
|
332
394
|
if hasattr(mw, 'on_message_save'):
|
|
333
395
|
result = await mw.on_message_save(current, context)
|
|
334
396
|
if result is None:
|
|
335
|
-
logger.
|
|
397
|
+
logger.info(f"Middleware #{i} blocked message save for role={message.get('role')}")
|
|
336
398
|
return None
|
|
337
399
|
current = result
|
|
338
400
|
|
|
401
|
+
logger.debug("Message save processing completed, all middlewares passed")
|
|
339
402
|
return current
|
|
340
403
|
|
|
341
404
|
|