aury-agent 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff shows the published contents of two public package versions as they appear in their registry. It is provided for informational purposes only.
@@ -72,10 +72,22 @@ class InMemoryMessageStore:
         message: Message,
         namespace: str | None = None,
     ) -> None:
+        from ..core.logging import storage_logger as logger
+
         key = self._make_key(session_id, namespace)
         if key not in self._messages:
             self._messages[key] = []
         self._messages[key].append(message)
+
+        logger.debug(
+            "Message stored",
+            extra={
+                "session_id": session_id,
+                "invocation_id": getattr(message, "invocation_id", None),
+                "role": getattr(message, "role", None),
+                "namespace": namespace,
+            },
+        )
 
     async def get_all(
         self,
@@ -101,6 +113,8 @@ class InMemoryMessageStore:
         invocation_id: str,
         namespace: str | None = None,
     ) -> int:
+        from ..core.logging import storage_logger as logger
+
         key = self._make_key(session_id, namespace)
         if key not in self._messages:
             return 0
@@ -109,7 +123,19 @@ class InMemoryMessageStore:
         self._messages[key] = [
             m for m in original if m.invocation_id != invocation_id
         ]
-        return len(original) - len(self._messages[key])
+        deleted_count = len(original) - len(self._messages[key])
+
+        if deleted_count > 0:
+            logger.debug(
+                "Messages deleted by invocation",
+                extra={
+                    "session_id": session_id,
+                    "invocation_id": invocation_id,
+                    "count": deleted_count,
+                },
+            )
+
+        return deleted_count
 
 
 __all__ = [
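
Note on the logging calls added above: they rely on standard Python logging behaviour, where keys passed via the extra dict become attributes on the emitted LogRecord. The sketch below shows one way a consumer could surface those fields. It is illustrative only; the logger name used here is an assumption, since the diff only shows that storage_logger is imported from ..core.logging, not how it is configured.

import logging


class SessionAwareFormatter(logging.Formatter):
    """Hypothetical formatter that appends the structured fields to each line."""

    def format(self, record: logging.LogRecord) -> str:
        base = super().format(record)
        session = getattr(record, "session_id", "-")
        invocation = getattr(record, "invocation_id", "-")
        return f"{base} [session={session} invocation={invocation}]"


handler = logging.StreamHandler()
handler.setFormatter(SessionAwareFormatter("%(levelname)s %(name)s: %(message)s"))
storage_logger = logging.getLogger("aury_agent.storage")  # logger name assumed for illustration
storage_logger.addHandler(handler)
storage_logger.setLevel(logging.DEBUG)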
@@ -87,6 +87,22 @@ class Middleware(Protocol):
         """
         ...
 
+    async def on_thinking_stream(
+        self,
+        chunk: dict[str, Any],
+        context: dict[str, Any],
+    ) -> dict[str, Any] | None:
+        """Process thinking stream chunk.
+
+        Args:
+            chunk: The thinking chunk with {"delta": str}
+            context: Execution context
+
+        Returns:
+            Modified chunk, or None to skip
+        """
+        ...
+
     # ========== Agent Lifecycle Hooks ==========
 
     async def on_agent_start(
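
For reference, a minimal sketch of a middleware implementing the new on_thinking_stream hook. It relies only on the signature and the {"delta": str} chunk shape documented in the hunk above; how such a middleware is registered is outside this hunk, so the class below is illustrative rather than taken from the package.

from typing import Any


class RedactingThinkingMiddleware:
    """Masks a configured keyword inside streamed thinking deltas."""

    def __init__(self, keyword: str) -> None:
        self._keyword = keyword

    async def on_thinking_stream(
        self,
        chunk: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        delta = chunk.get("delta", "")
        if self._keyword in delta:
            # Return a modified chunk; returning None would drop it entirely.
            return {**chunk, "delta": delta.replace(self._keyword, "[redacted]")}
        return chunk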
@@ -143,6 +159,29 @@ class Middleware(Protocol):
         """
         ...
 
+    async def on_tool_call_delta(
+        self,
+        call_id: str,
+        tool_name: str,
+        delta: dict[str, Any],
+        context: dict[str, Any],
+    ) -> dict[str, Any] | None:
+        """Called during streaming tool argument generation.
+
+        Only triggered for tools with stream_arguments=True.
+        Receives incremental updates as LLM generates tool parameters.
+
+        Args:
+            call_id: Tool call identifier
+            tool_name: Name of the tool being called
+            delta: Incremental parameter update (e.g. {"content": "more text"})
+            context: Execution context with 'accumulated_args' containing current state
+
+        Returns:
+            Modified delta, or None to skip emitting this delta
+        """
+        ...
+
     async def on_tool_end(
         self,
         tool: "BaseTool",
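
A hedged sketch of the companion hook: a middleware that watches streamed tool-argument deltas and stops emitting them once the accumulated arguments grow too large. Only the signature and the 'accumulated_args' context key come from the docstring above; the write_file tool name and the size limit are illustrative assumptions.

from typing import Any


class CapToolArgumentsMiddleware:
    """Drops further argument deltas for a hypothetical write_file tool past a size cap."""

    def __init__(self, max_chars: int = 4_000) -> None:
        self._max_chars = max_chars

    async def on_tool_call_delta(
        self,
        call_id: str,
        tool_name: str,
        delta: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        if tool_name != "write_file":  # tool name is an illustration, not from the diff
            return delta
        accumulated = context.get("accumulated_args", {})
        if len(str(accumulated.get("content", ""))) >= self._max_chars:
            return None  # skip emitting further deltas for this call
        return delta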
@@ -268,6 +307,14 @@ class BaseMiddleware:
         """Default: pass through."""
         return chunk
 
+    async def on_thinking_stream(
+        self,
+        chunk: dict[str, Any],
+        context: dict[str, Any],
+    ) -> dict[str, Any] | None:
+        """Default: pass through."""
+        return chunk
+
     # ========== Agent Lifecycle Hooks ==========
 
     async def on_agent_start(
@@ -297,6 +344,16 @@ class BaseMiddleware:
         """Default: continue."""
         return HookResult.proceed()
 
+    async def on_tool_call_delta(
+        self,
+        call_id: str,
+        tool_name: str,
+        delta: dict[str, Any],
+        context: dict[str, Any],
+    ) -> dict[str, Any] | None:
+        """Default: pass through."""
+        return delta
+
     async def on_tool_end(
         self,
         tool: "BaseTool",
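
Because BaseMiddleware now supplies pass-through defaults for both new hooks, existing subclasses keep working unchanged, and a new subclass only overrides the hooks it cares about. The sketch below assumes BaseMiddleware can be subclassed directly; its import path is not visible in this diff and is left as a placeholder comment.

from typing import Any

# from <middleware module> import BaseMiddleware  # exact import path not shown in this diff


class LogToolDeltasMiddleware(BaseMiddleware):
    """Prints streamed tool-argument deltas and passes them through unchanged."""

    async def on_tool_call_delta(
        self,
        call_id: str,
        tool_name: str,
        delta: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        print(f"{tool_name}[{call_id}] delta: {delta}")
        return delta

    # on_thinking_stream, on_request, on_response, etc. fall back to the
    # inherited pass-through defaults.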
@@ -29,6 +29,7 @@ class MiddlewareChain:
 
         # Add initial middlewares if provided
         if middlewares:
+            logger.debug(f"MiddlewareChain init with {len(middlewares)} middlewares")
             for mw in middlewares:
                 self.use(mw)
 
@@ -55,6 +56,7 @@ class MiddlewareChain:
         entry = MiddlewareEntry(middleware=middleware, inherit=effective_inherit)
         self._entries.append(entry)
         self._entries.sort(key=lambda e: e.middleware.config.priority)
+        logger.debug(f"Added middleware to chain, priority={middleware.config.priority}, inherit={effective_inherit}, total={len(self._entries)}")
         return self
 
     def remove(self, middleware: Middleware) -> "MiddlewareChain":
@@ -114,13 +116,16 @@ class MiddlewareChain:
     ) -> dict[str, Any] | None:
         """Process request through all middlewares."""
         current = request
+        logger.debug(f"Processing request through {len(self._middlewares)} middlewares")
 
-        for mw in self._middlewares:
+        for i, mw in enumerate(self._middlewares):
             result = await mw.on_request(current, context)
             if result is None:
+                logger.info(f"Middleware #{i} blocked request")
                 return None
             current = result
 
+        logger.debug("Request processing completed")
         return current
 
     async def process_response(
@@ -130,13 +135,16 @@ class MiddlewareChain:
     ) -> dict[str, Any] | None:
         """Process response through all middlewares (reverse order)."""
         current = response
+        logger.debug(f"Processing response through {len(self._middlewares)} middlewares (reverse order)")
 
-        for mw in reversed(self._middlewares):
+        for i, mw in enumerate(reversed(self._middlewares)):
             result = await mw.on_response(current, context)
             if result is None:
+                logger.info(f"Middleware #{i} blocked response")
                 return None
             current = result
 
+        logger.debug("Response processing completed")
         return current
 
     async def process_error(
@@ -146,13 +154,16 @@ class MiddlewareChain:
     ) -> Exception | None:
         """Process error through all middlewares."""
         current = error
+        logger.debug(f"Processing error {type(error).__name__} through {len(self._middlewares)} middlewares")
 
-        for mw in self._middlewares:
+        for i, mw in enumerate(self._middlewares):
             result = await mw.on_error(current, context)
             if result is None:
+                logger.info(f"Middleware #{i} suppressed error")
                 return None
             current = result
 
+        logger.debug("Error processing completed")
         return current
 
     async def process_stream_chunk(
@@ -164,18 +175,55 @@ class MiddlewareChain:
         text = chunk.get("text", chunk.get("delta", ""))
         self._token_buffer += text
         self._token_count += 1
+        logger.debug(f"Processing stream chunk, token_count={self._token_count}, triggered_middlewares=?")
 
         current = chunk
+        triggered_count = 0
 
-        for mw in self._middlewares:
+        for i, mw in enumerate(self._middlewares):
             should_trigger = self._should_trigger(mw, text)
 
             if should_trigger:
+                triggered_count += 1
+                logger.debug(f"Middleware #{i} triggered, mode={mw.config.trigger_mode}")
                 result = await mw.on_model_stream(current, context)
                 if result is None:
+                    logger.info(f"Middleware #{i} blocked stream chunk")
                     return None
                 current = result
 
+        logger.debug(f"Stream chunk processing completed, {triggered_count} middlewares triggered")
+        return current
+
+    async def process_tool_call_delta(
+        self,
+        call_id: str,
+        tool_name: str,
+        delta: dict[str, Any],
+        context: dict[str, Any],
+    ) -> dict[str, Any] | None:
+        """Process tool call delta through all middlewares.
+
+        Args:
+            call_id: Tool call identifier
+            tool_name: Name of the tool being called
+            delta: Incremental parameter update
+            context: Execution context with 'accumulated_args'
+
+        Returns:
+            Modified delta, or None to skip emitting
+        """
+        current = delta
+        logger.debug(f"Processing tool_call_delta for {tool_name} (call_id={call_id}) through {len(self._middlewares)} middlewares")
+
+        for i, mw in enumerate(self._middlewares):
+            result = await mw.on_tool_call_delta(call_id, tool_name, current, context)
+            if result is None:
+                logger.info(f"Middleware #{i} blocked tool_call_delta")
+                return None
+            current = result
+
+        logger.debug("Tool call delta processing completed")
         return current
 
     def _should_trigger(self, middleware: Middleware, text: str) -> bool:
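
The new chain method mirrors process_stream_chunk: each middleware sees the current delta and may modify it or return None to drop it. A possible caller is sketched below under the assumption that the executor itself maintains the 'accumulated_args' dict; the forward_tool_deltas helper and the deltas iterator are illustrative, not part of the package.

from collections.abc import AsyncIterator
from typing import Any


async def forward_tool_deltas(
    chain: "MiddlewareChain",  # the chain type from this diff; import path not shown
    call_id: str,
    tool_name: str,
    deltas: AsyncIterator[dict[str, Any]],  # produced by the LLM client (assumed)
) -> dict[str, Any]:
    """Feed streamed tool-argument deltas through the middleware chain."""
    accumulated: dict[str, Any] = {}
    context: dict[str, Any] = {"accumulated_args": accumulated}
    async for delta in deltas:
        processed = await chain.process_tool_call_delta(call_id, tool_name, delta, context)
        if processed is None:
            continue  # a middleware chose to skip emitting this delta
        for key, value in processed.items():
            if isinstance(value, str):
                accumulated[key] = accumulated.get(key, "") + value
            else:
                accumulated[key] = value
    return accumulated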
@@ -198,6 +246,7 @@ class MiddlewareChain:
 
     def reset_stream_state(self) -> None:
         """Reset streaming state (call at start of new stream)."""
+        logger.debug("Resetting stream state")
         self._token_buffer = ""
         self._token_count = 0
 
@@ -219,12 +268,14 @@ class MiddlewareChain:
         Returns:
             First non-CONTINUE result, or CONTINUE if all pass
         """
-        for mw in self._middlewares:
+        logger.debug(f"Processing agent_start for agent_id={agent_id}, {len(self._middlewares)} middlewares")
+        for i, mw in enumerate(self._middlewares):
             if hasattr(mw, 'on_agent_start'):
                 result = await mw.on_agent_start(agent_id, input_data, context)
                 if result.action != HookAction.CONTINUE:
-                    logger.debug(f"Middleware returned {result.action} on agent_start")
+                    logger.info(f"Middleware #{i} returned {result.action} on agent_start")
                     return result
+        logger.debug("Agent start processing completed, all middlewares passed")
         return HookResult.proceed()
 
     async def process_agent_end(
@@ -234,12 +285,14 @@ class MiddlewareChain:
         context: dict[str, Any],
     ) -> HookResult:
         """Process agent end through all middlewares (reverse order)."""
-        for mw in reversed(self._middlewares):
+        logger.debug(f"Processing agent_end for agent_id={agent_id}, {len(self._middlewares)} middlewares (reverse order)")
+        for i, mw in enumerate(reversed(self._middlewares)):
             if hasattr(mw, 'on_agent_end'):
                 hook_result = await mw.on_agent_end(agent_id, result, context)
                 if hook_result.action != HookAction.CONTINUE:
-                    logger.debug(f"Middleware returned {hook_result.action} on agent_end")
+                    logger.info(f"Middleware #{i} returned {hook_result.action} on agent_end")
                     return hook_result
+        logger.debug("Agent end processing completed, all middlewares passed")
         return HookResult.proceed()
 
     async def process_tool_call(
@@ -253,12 +306,14 @@ class MiddlewareChain:
         Returns:
             SKIP to skip tool, RETRY with modified_data to change params
         """
-        for mw in self._middlewares:
+        logger.debug(f"Processing tool_call for tool={tool.name}, {len(self._middlewares)} middlewares")
+        for i, mw in enumerate(self._middlewares):
             if hasattr(mw, 'on_tool_call'):
                 result = await mw.on_tool_call(tool, params, context)
                 if result.action != HookAction.CONTINUE:
-                    logger.debug(f"Middleware returned {result.action} on tool_call")
+                    logger.info(f"Middleware #{i} returned {result.action} on tool_call for tool={tool.name}")
                     return result
+        logger.debug("Tool call processing completed, all middlewares passed")
         return HookResult.proceed()
 
     async def process_tool_end(
@@ -268,12 +323,14 @@ class MiddlewareChain:
         context: dict[str, Any],
     ) -> HookResult:
         """Process tool end through all middlewares (reverse order)."""
-        for mw in reversed(self._middlewares):
+        logger.debug(f"Processing tool_end for tool={tool.name}, {len(self._middlewares)} middlewares (reverse order)")
+        for i, mw in enumerate(reversed(self._middlewares)):
             if hasattr(mw, 'on_tool_end'):
                 hook_result = await mw.on_tool_end(tool, result, context)
                 if hook_result.action != HookAction.CONTINUE:
-                    logger.debug(f"Middleware returned {hook_result.action} on tool_end")
+                    logger.info(f"Middleware #{i} returned {hook_result.action} on tool_end for tool={tool.name}")
                     return hook_result
+        logger.debug("Tool end processing completed, all middlewares passed")
         return HookResult.proceed()
 
     async def process_subagent_start(
@@ -284,14 +341,16 @@ class MiddlewareChain:
         context: dict[str, Any],
     ) -> HookResult:
         """Process sub-agent start through all middlewares."""
-        for mw in self._middlewares:
+        logger.debug(f"Processing subagent_start, parent={parent_agent_id}, child={child_agent_id}, mode={mode}, {len(self._middlewares)} middlewares")
+        for i, mw in enumerate(self._middlewares):
             if hasattr(mw, 'on_subagent_start'):
                 result = await mw.on_subagent_start(
                     parent_agent_id, child_agent_id, mode, context
                 )
                 if result.action != HookAction.CONTINUE:
-                    logger.debug(f"Middleware returned {result.action} on subagent_start")
+                    logger.info(f"Middleware #{i} returned {result.action} on subagent_start")
                     return result
+        logger.debug("Subagent start processing completed, all middlewares passed")
         return HookResult.proceed()
 
     async def process_subagent_end(
@@ -302,14 +361,16 @@ class MiddlewareChain:
         context: dict[str, Any],
     ) -> HookResult:
         """Process sub-agent end through all middlewares (reverse order)."""
-        for mw in reversed(self._middlewares):
+        logger.debug(f"Processing subagent_end, parent={parent_agent_id}, child={child_agent_id}, {len(self._middlewares)} middlewares (reverse order)")
+        for i, mw in enumerate(reversed(self._middlewares)):
             if hasattr(mw, 'on_subagent_end'):
                 hook_result = await mw.on_subagent_end(
                     parent_agent_id, child_agent_id, result, context
                 )
                 if hook_result.action != HookAction.CONTINUE:
-                    logger.debug(f"Middleware returned {hook_result.action} on subagent_end")
+                    logger.info(f"Middleware #{i} returned {hook_result.action} on subagent_end")
                     return hook_result
+        logger.debug("Subagent end processing completed, all middlewares passed")
         return HookResult.proceed()
 
     async def process_message_save(
@@ -327,15 +388,17 @@ class MiddlewareChain:
         Modified message, or None to skip saving
         """
         current = message
+        logger.debug(f"Processing message_save, role={message.get('role')}, {len(self._middlewares)} middlewares")
 
-        for mw in self._middlewares:
+        for i, mw in enumerate(self._middlewares):
             if hasattr(mw, 'on_message_save'):
                 result = await mw.on_message_save(current, context)
                 if result is None:
-                    logger.debug("Middleware blocked message save")
+                    logger.info(f"Middleware #{i} blocked message save for role={message.get('role')}")
                     return None
                 current = result
 
+        logger.debug("Message save processing completed, all middlewares passed")
         return current
 
 