aury-agent 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -49,10 +49,6 @@ class MessageContainerMiddleware(BaseMiddleware):
49
49
  parent_id. Defaults to {"thinking", "text"}.
50
50
  """
51
51
 
52
- # Key to store the token in middleware context
53
- _TOKEN_KEY = "_message_container_token"
54
- _BLOCK_ID_KEY = "_message_container_block_id"
55
-
56
52
  # Default kinds that should be grouped under message container
57
53
  DEFAULT_KINDS = {"thinking", "text"}
58
54
 
@@ -64,13 +60,17 @@ class MessageContainerMiddleware(BaseMiddleware):
64
60
  """
65
61
  super().__init__()
66
62
  self.apply_to_kinds = apply_to_kinds or self.DEFAULT_KINDS
63
+ self._token: Any = None
64
+ self._block_id: str | None = None
67
65
 
68
66
  async def on_request(
69
67
  self,
70
68
  request: dict[str, Any],
71
- context: dict[str, Any],
72
69
  ) -> dict[str, Any] | None:
73
70
  """Create message container block and set parent_id."""
71
+ from ..core.context import get_current_ctx_or_none
72
+ ctx = get_current_ctx_or_none()
73
+
74
74
  # Generate container block ID
75
75
  message_block_id = generate_id("blk")
76
76
 
@@ -86,41 +86,33 @@ class MessageContainerMiddleware(BaseMiddleware):
86
86
  op=BlockOp.APPLY,
87
87
  data={
88
88
  "type": "llm_response",
89
- "step": context.get("step"),
90
89
  },
91
- session_id=context.get("session_id"),
92
- invocation_id=context.get("invocation_id"),
90
+ session_id=ctx.session_id if ctx else None,
91
+ invocation_id=ctx.invocation_id if ctx else None,
93
92
  ))
94
93
 
95
94
  # Set parent_id in ContextVar with apply_to_kinds filter
96
- # Only blocks matching apply_to_kinds will inherit this parent_id
97
- token = set_parent_id(message_block_id, apply_to_kinds=self.apply_to_kinds)
98
- context[self._TOKEN_KEY] = token
99
- context[self._BLOCK_ID_KEY] = message_block_id
95
+ self._token = set_parent_id(message_block_id, apply_to_kinds=self.apply_to_kinds)
96
+ self._block_id = message_block_id
100
97
 
101
98
  return request
102
99
 
103
100
  async def on_response(
104
101
  self,
105
102
  response: dict[str, Any],
106
- context: dict[str, Any],
107
103
  ) -> dict[str, Any] | None:
108
104
  """Reset parent_id to previous value."""
109
- token = context.get(self._TOKEN_KEY)
110
- if token is not None:
111
- reset_parent_id(token)
105
+ if self._token is not None:
106
+ reset_parent_id(self._token)
107
+ self._token = None
112
108
  return response
113
109
 
114
110
  async def on_error(
115
111
  self,
116
112
  error: Exception,
117
- context: dict[str, Any],
118
113
  ) -> Exception | None:
119
114
  """Reset parent_id on error too."""
120
- token = context.get(self._TOKEN_KEY)
121
- if token is not None:
122
- reset_parent_id(token)
115
+ if self._token is not None:
116
+ reset_parent_id(self._token)
117
+ self._token = None
123
118
  return error
124
-
125
-
126
- __all__ = ["MessageContainerMiddleware"]
@@ -70,18 +70,18 @@ class RawMessageMiddleware(BaseMiddleware):
70
70
  async def on_message_save(
71
71
  self,
72
72
  message: dict[str, Any],
73
- context: dict[str, Any],
74
73
  ) -> dict[str, Any] | None:
75
74
  """Store complete message to RawMessageStore.
76
75
 
77
76
  Args:
78
77
  message: Complete message dict with 'role', 'content', etc.
79
- context: Execution context with 'invocation_id', etc.
80
78
 
81
79
  Returns:
82
80
  The message with added 'raw_msg_id' field
83
81
  """
84
- invocation_id = context.get("invocation_id", "")
82
+ from ..core.context import get_current_ctx_or_none
83
+ ctx = get_current_ctx_or_none()
84
+ invocation_id = ctx.invocation_id if ctx else ""
85
85
  if not invocation_id:
86
86
  return message
87
87
 
@@ -108,7 +108,6 @@ class RawMessageMiddleware(BaseMiddleware):
108
108
  self,
109
109
  agent_id: str,
110
110
  result: Any,
111
- context: dict[str, Any],
112
111
  ) -> HookResult:
113
112
  """Clean up raw messages when invocation completes.
114
113
 
@@ -117,7 +116,9 @@ class RawMessageMiddleware(BaseMiddleware):
117
116
  if self.persist_raw:
118
117
  return HookResult.proceed()
119
118
 
120
- invocation_id = context.get("invocation_id", "")
119
+ from ..core.context import get_current_ctx_or_none
120
+ ctx = get_current_ctx_or_none()
121
+ invocation_id = ctx.invocation_id if ctx else ""
121
122
  if invocation_id:
122
123
  await self._cleanup_invocation(invocation_id)
123
124
 
@@ -48,13 +48,11 @@ class MessageTruncationMiddleware(BaseMiddleware):
48
48
  async def on_message_save(
49
49
  self,
50
50
  message: dict[str, Any],
51
- context: dict[str, Any],
52
51
  ) -> dict[str, Any] | None:
53
52
  """Truncate message content before saving.
54
53
 
55
54
  Args:
56
55
  message: Message dict with 'role', 'content', etc.
57
- context: Execution context
58
56
 
59
57
  Returns:
60
58
  Modified message with truncated content
@@ -93,6 +93,7 @@ class ReactAgent(BaseAgent):
93
93
  enable_history: bool = True,
94
94
  history_limit: int = 50,
95
95
  delegate_tool_class: "type[BaseTool] | None" = None,
96
+ context_metadata: dict | None = None,
96
97
  ) -> "ReactAgent":
97
98
  """Create ReactAgent with minimal boilerplate. See factory.create_react_agent for details."""
98
99
  return create_react_agent(
@@ -110,6 +111,7 @@ class ReactAgent(BaseAgent):
110
111
  enable_history=enable_history,
111
112
  history_limit=history_limit,
112
113
  delegate_tool_class=delegate_tool_class,
114
+ context_metadata=context_metadata,
113
115
  )
114
116
 
115
117
  @classmethod
@@ -263,41 +265,38 @@ class ReactAgent(BaseAgent):
263
265
  "Starting ReactAgent run",
264
266
  extra={
265
267
  "session_id": self.session.id,
266
- "agent": self.name,
268
+ "agent_id": self.ctx.agent_id,
267
269
  }
268
270
  )
269
271
 
270
- # Build middleware context
271
- from ..core.context import emit as global_emit
272
- mw_context = {
273
- "session_id": self.session.id,
274
- "agent_id": self.name,
275
- "agent_type": self.agent_type,
276
- "emit": global_emit,
277
- "backends": self.ctx.backends,
278
- }
279
-
280
272
  try:
281
- # Create new invocation
273
+ # Create new invocation using ctx.invocation_id
282
274
  self._current_invocation = Invocation(
283
- id=generate_id("inv"),
275
+ id=self.ctx.invocation_id,
284
276
  session_id=self.session.id,
277
+ agent_id=self.ctx.agent_id,
285
278
  state=InvocationState.RUNNING,
286
279
  started_at=datetime.now(),
287
280
  )
288
- mw_context["invocation_id"] = self._current_invocation.id
289
281
 
290
282
  logger.info("Created invocation", extra={"invocation_id": self._current_invocation.id})
291
283
 
284
+ # Persist invocation immediately (so we have record even if agent fails)
285
+ if self.ctx.backends and self.ctx.backends.invocation:
286
+ await self.ctx.backends.invocation.create(
287
+ self._current_invocation.id,
288
+ self.session.id,
289
+ self._current_invocation.to_dict(),
290
+ agent_id=self.ctx.agent_id,
291
+ )
292
+
292
293
  # === Middleware: on_agent_start ===
293
294
  if self.middleware:
294
295
  logger.info(
295
296
  "Calling middleware: on_agent_start",
296
297
  extra={"invocation_id": self._current_invocation.id},
297
298
  )
298
- hook_result = await self.middleware.process_agent_start(
299
- self.name, input, mw_context
300
- )
299
+ hook_result = await self.middleware.process_agent_start(input)
301
300
  if hook_result.action == HookAction.STOP:
302
301
  logger.warning("Agent stopped by middleware on_agent_start", extra={"invocation_id": self._current_invocation.id})
303
302
  await self.ctx.emit(BlockEvent(
@@ -449,6 +448,13 @@ class ReactAgent(BaseAgent):
449
448
 
450
449
  # Complete invocation
451
450
  if is_aborted:
451
+ # Save current buffer content before marking as aborted
452
+ if self._text_buffer or self._thinking_buffer or self._tool_invocations:
453
+ await persist_helpers.save_assistant_message(self)
454
+ # Save completed tool results
455
+ if self._tool_invocations:
456
+ await persist_helpers.save_tool_messages(self)
457
+ # Mark invocation as aborted
452
458
  self._current_invocation.state = InvocationState.ABORTED
453
459
  logger.info(
454
460
  "Invocation aborted by user",
@@ -490,9 +496,7 @@ class ReactAgent(BaseAgent):
490
496
  # === Middleware: on_agent_end ===
491
497
  if self.middleware:
492
498
  await self.middleware.process_agent_end(
493
- self.name,
494
499
  {"steps": self._current_step, "finish_reason": finish_reason},
495
- mw_context,
496
500
  )
497
501
 
498
502
  await self.bus.publish(
@@ -551,7 +555,7 @@ class ReactAgent(BaseAgent):
551
555
 
552
556
  # === Middleware: on_error ===
553
557
  if self.middleware:
554
- processed_error = await self.middleware.process_error(e, mw_context)
558
+ processed_error = await self.middleware.process_error(e)
555
559
  if processed_error is None:
556
560
  logger.warning(
557
561
  "Error suppressed by middleware",
@@ -44,6 +44,8 @@ def create_react_agent(
44
44
  history_limit: int = 50,
45
45
  # Tool customization
46
46
  delegate_tool_class: "type[BaseTool] | None" = None,
47
+ # Context metadata
48
+ context_metadata: dict | None = None,
47
49
  ) -> "ReactAgent":
48
50
  """Create ReactAgent with minimal boilerplate.
49
51
 
@@ -159,16 +161,26 @@ def create_react_agent(
159
161
  all_providers = default_providers + (context_providers or [])
160
162
 
161
163
  # Build context
164
+ # agent_id: use config.id if provided, fallback to config.code, then "react_agent"
165
+ # agent_name: use config.name (display name)
166
+ agent_id = (
167
+ config.id if config and config.id
168
+ else (config.code if config and config.code else "react_agent")
169
+ )
170
+ agent_name = config.name if config else None
171
+
162
172
  ctx = InvocationContext(
163
173
  session=session,
164
174
  invocation_id=generate_id("inv"),
165
- agent_id=config.name if config else "react_agent",
175
+ agent_id=agent_id,
176
+ agent_name=agent_name,
166
177
  backends=backends,
167
178
  bus=bus,
168
179
  llm=llm,
169
180
  middleware=middleware_chain,
170
181
  memory=memory,
171
182
  snapshot=snapshot,
183
+ metadata=context_metadata or {},
172
184
  )
173
185
 
174
186
  agent = ReactAgent(ctx, config)
@@ -81,14 +81,7 @@ async def trigger_message_save(agent: "ReactAgent", message: dict) -> dict | Non
81
81
  if not agent.middleware:
82
82
  return message
83
83
 
84
- namespace = getattr(agent, '_message_namespace', None)
85
- mw_context = {
86
- "session_id": agent.session.id,
87
- "agent_id": agent.name,
88
- "namespace": namespace,
89
- }
90
-
91
- return await agent.middleware.process_message_save(message, mw_context)
84
+ return await agent.middleware.process_message_save(message)
92
85
 
93
86
 
94
87
  async def save_user_message(agent: "ReactAgent", input: "PromptInput") -> None:
aury/agents/react/step.py CHANGED
@@ -190,6 +190,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
190
190
  agent._text_buffer = ""
191
191
  agent._thinking_buffer = "" # Buffer for non-streaming thinking
192
192
  agent._tool_invocations = []
193
+ agent._last_usage = None # Store usage for middleware
193
194
 
194
195
  # Reset block IDs for this step (each step gets fresh block IDs)
195
196
  agent._current_text_block_id = None
@@ -199,20 +200,9 @@ async def execute_step(agent: "ReactAgent") -> str | None:
199
200
  agent._call_id_to_tool = {}
200
201
  agent._tool_call_blocks = {}
201
202
 
202
- # Track accumulated arguments for streaming tool calls (for middleware context)
203
+ # Track accumulated arguments for streaming tool calls
203
204
  tool_call_accumulated_args: dict[str, dict[str, Any]] = {}
204
205
 
205
- # Build middleware context for this step
206
- mw_context = {
207
- "session_id": agent.session.id,
208
- "invocation_id": agent._current_invocation.id if agent._current_invocation else "",
209
- "step": agent._current_step,
210
- "agent_id": agent.name,
211
- "emit": global_emit, # For middleware to emit BlockEvent/ActionEvent
212
- "backends": agent.ctx.backends,
213
- "tool_mode": effective_tool_mode.value, # Add tool mode to context
214
- }
215
-
216
206
  # Build LLM call kwargs
217
207
  # Note: temperature, max_tokens, timeout, retries are configured on LLMProvider
218
208
  llm_kwargs: dict[str, Any] = {
@@ -244,7 +234,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
244
234
  "Calling middleware: on_request",
245
235
  extra={"invocation_id": agent._current_invocation.id},
246
236
  )
247
- llm_kwargs = await agent.middleware.process_request(llm_kwargs, mw_context)
237
+ llm_kwargs = await agent.middleware.process_request(llm_kwargs)
248
238
  if llm_kwargs is None:
249
239
  logger.warning(
250
240
  "LLM request cancelled by middleware",
@@ -307,7 +297,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
307
297
  stream_chunk = {"delta": event.delta, "type": "content"}
308
298
  if agent.middleware:
309
299
  stream_chunk = await agent.middleware.process_stream_chunk(
310
- stream_chunk, mw_context
300
+ stream_chunk
311
301
  )
312
302
  if stream_chunk is None:
313
303
  continue # Skip this chunk
@@ -342,7 +332,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
342
332
  stream_chunk = {"delta": event.delta, "type": "thinking"}
343
333
  if agent.middleware:
344
334
  stream_chunk = await agent.middleware.process_stream_chunk(
345
- stream_chunk, mw_context
335
+ stream_chunk
346
336
  )
347
337
  if stream_chunk is None:
348
338
  continue # Skip this chunk
@@ -450,12 +440,9 @@ async def execute_step(agent: "ReactAgent") -> str | None:
450
440
  # === Middleware: on_tool_call_delta ===
451
441
  processed_delta = arguments_delta
452
442
  if agent.middleware:
453
- delta_context = {
454
- **mw_context,
455
- "accumulated_args": tool_call_accumulated_args.get(call_id, {}),
456
- }
443
+ accumulated_args = tool_call_accumulated_args.get(call_id, {})
457
444
  processed_delta = await agent.middleware.process_tool_call_delta(
458
- call_id, tool_name, arguments_delta, delta_context
445
+ call_id, tool_name, arguments_delta, accumulated_args
459
446
  )
460
447
  if processed_delta is None:
461
448
  continue # Skip this delta
@@ -539,17 +526,19 @@ async def execute_step(agent: "ReactAgent") -> str | None:
539
526
 
540
527
  elif event.type == "usage":
541
528
  if event.usage:
529
+ # Store usage for middleware
530
+ agent._last_usage = {
531
+ "provider": agent.llm.provider,
532
+ "model": agent.llm.model,
533
+ "input_tokens": event.usage.input_tokens,
534
+ "output_tokens": event.usage.output_tokens,
535
+ "cache_read_tokens": event.usage.cache_read_tokens,
536
+ "cache_write_tokens": event.usage.cache_write_tokens,
537
+ "reasoning_tokens": event.usage.reasoning_tokens,
538
+ }
542
539
  await agent.bus.publish(
543
540
  Events.USAGE_RECORDED,
544
- {
545
- "provider": agent.llm.provider,
546
- "model": agent.llm.model,
547
- "input_tokens": event.usage.input_tokens,
548
- "output_tokens": event.usage.output_tokens,
549
- "cache_read_tokens": event.usage.cache_read_tokens,
550
- "cache_write_tokens": event.usage.cache_write_tokens,
551
- "reasoning_tokens": event.usage.reasoning_tokens,
552
- },
541
+ agent._last_usage,
553
542
  )
554
543
 
555
544
  elif event.type == "error":
@@ -624,6 +613,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
624
613
  "thinking": agent._thinking_buffer,
625
614
  "tool_calls": len(agent._tool_invocations),
626
615
  "finish_reason": finish_reason,
616
+ "usage": agent._last_usage, # Include usage for middleware
627
617
  }
628
618
  if agent.middleware:
629
619
  logger.debug(
@@ -635,7 +625,7 @@ async def execute_step(agent: "ReactAgent") -> str | None:
635
625
  },
636
626
  )
637
627
  llm_response_data = await agent.middleware.process_response(
638
- llm_response_data, mw_context
628
+ llm_response_data
639
629
  )
640
630
 
641
631
  await agent.bus.publish(
@@ -49,6 +49,19 @@ async def execute_tool(agent: "ReactAgent", invocation: ToolInvocation) -> ToolR
49
49
  Returns:
50
50
  ToolResult from tool execution
51
51
  """
52
+ # Check abort before execution
53
+ if await agent._check_abort():
54
+ error_msg = f"Tool {invocation.tool_name} aborted before execution"
55
+ invocation.mark_result(error_msg, is_error=True)
56
+ logger.info(
57
+ f"Tool aborted before execution: {invocation.tool_name}",
58
+ extra={
59
+ "invocation_id": agent._current_invocation.id if agent._current_invocation else "",
60
+ "call_id": invocation.tool_call_id,
61
+ },
62
+ )
63
+ return ToolResult.error(error_msg)
64
+
52
65
  invocation.mark_call_complete()
53
66
 
54
67
  logger.info(
@@ -60,14 +73,6 @@ async def execute_tool(agent: "ReactAgent", invocation: ToolInvocation) -> ToolR
60
73
  },
61
74
  )
62
75
 
63
- # Build middleware context
64
- mw_context = {
65
- "session_id": agent.session.id,
66
- "invocation_id": agent._current_invocation.id if agent._current_invocation else "",
67
- "tool_call_id": invocation.tool_call_id,
68
- "agent_id": agent.name,
69
- }
70
-
71
76
  try:
72
77
  # Get tool from agent context
73
78
  tool = get_tool(agent, invocation.tool_name)
@@ -87,7 +92,7 @@ async def execute_tool(agent: "ReactAgent", invocation: ToolInvocation) -> ToolR
87
92
  extra={"invocation_id": agent._current_invocation.id, "call_id": invocation.tool_call_id},
88
93
  )
89
94
  hook_result = await agent.middleware.process_tool_call(
90
- tool, invocation.args, mw_context
95
+ tool, invocation.args
91
96
  )
92
97
  if hook_result.action == HookAction.SKIP:
93
98
  logger.warning(
@@ -136,7 +141,7 @@ async def execute_tool(agent: "ReactAgent", invocation: ToolInvocation) -> ToolR
136
141
  f"Calling middleware: on_tool_end ({invocation.tool_name})",
137
142
  extra={"invocation_id": agent._current_invocation.id},
138
143
  )
139
- hook_result = await agent.middleware.process_tool_end(tool, result, mw_context)
144
+ hook_result = await agent.middleware.process_tool_end(tool, result)
140
145
  if hook_result.action == HookAction.RETRY:
141
146
  logger.info(
142
147
  f"Tool {invocation.tool_name} retry requested by middleware",
@@ -212,16 +217,44 @@ async def process_tool_results(agent: "ReactAgent") -> None:
212
217
  },
213
218
  )
214
219
 
220
+ # Check abort before starting tool execution
221
+ if await agent._check_abort():
222
+ logger.info(
223
+ "Tool execution aborted before starting",
224
+ extra={"invocation_id": agent._current_invocation.id},
225
+ )
226
+ # Return empty results - agent loop will handle abort
227
+ return
228
+
215
229
  # Execute tools based on configuration
216
230
  if agent.config.parallel_tool_execution:
217
231
  # Parallel execution using asyncio.gather with create_task
218
232
  # create_task ensures each task gets its own ContextVar copy
219
233
  tasks = [asyncio.create_task(execute_tool(agent, inv)) for inv in agent._tool_invocations]
220
234
  results = await asyncio.gather(*tasks, return_exceptions=True)
235
+
236
+ # Check abort after parallel execution - cancel remaining if aborted
237
+ if await agent._check_abort():
238
+ logger.info(
239
+ "Tool execution aborted after parallel execution",
240
+ extra={"invocation_id": agent._current_invocation.id},
241
+ )
221
242
  else:
222
- # Sequential execution
243
+ # Sequential execution with abort check between tools
223
244
  results = []
224
245
  for inv in agent._tool_invocations:
246
+ # Check abort before each tool
247
+ if await agent._check_abort():
248
+ logger.info(
249
+ f"Tool execution aborted before {inv.tool_name}",
250
+ extra={"invocation_id": agent._current_invocation.id},
251
+ )
252
+ # Mark remaining as aborted
253
+ error_result = ToolResult.error(f"Aborted before execution")
254
+ results.append(error_result)
255
+ inv.mark_result(error_result.output, is_error=True)
256
+ continue
257
+
225
258
  try:
226
259
  result = await execute_tool(agent, inv)
227
260
  results.append(result)
@@ -153,12 +153,7 @@ Specify the agent key and task data."""
153
153
 
154
154
  # Get dynamic agents from middleware (progressive disclosure)
155
155
  if self.middleware and ctx:
156
- mw_context = {
157
- "session_id": ctx.session_id,
158
- "invocation_id": ctx.invocation_id,
159
- "agent_id": ctx.agent,
160
- }
161
- dynamic_agents = await self.middleware.get_dynamic_subagents(mw_context)
156
+ dynamic_agents = await self.middleware.get_dynamic_subagents()
162
157
  if dynamic_agents:
163
158
  # Store dynamic agents for later lookup
164
159
  for config in dynamic_agents:
@@ -53,7 +53,8 @@ def _type_to_schema(t: type) -> dict[str, Any]:
53
53
  bool: {"type": "boolean"},
54
54
  list: {"type": "array"},
55
55
  dict: {"type": "object"},
56
- Any: {},
56
+ # Any type: JSON Schema 2020-12 requires valid type, use object as fallback
57
+ Any: {"type": "object"},
57
58
  }
58
59
 
59
60
  return type_map.get(t, {"type": "string"})
@@ -162,21 +162,10 @@ class WorkflowAgent(BaseAgent):
162
162
  inputs = input if isinstance(input, dict) else {"input": input}
163
163
  logger.info(f"WorkflowAgent executing, workflow={self.workflow.spec.name}, invocation_id={self.ctx.invocation_id}")
164
164
 
165
- # Build middleware context
166
- mw_context = {
167
- "session_id": self.session.id,
168
- "invocation_id": self.ctx.invocation_id,
169
- "agent_id": self.name,
170
- "agent_type": self.agent_type,
171
- "workflow_name": self.workflow.spec.name,
172
- }
173
-
174
165
  # === Middleware: on_agent_start ===
175
166
  if self.middleware:
176
167
  logger.debug(f"WorkflowAgent: processing on_agent_start hooks, invocation_id={self.ctx.invocation_id}")
177
- hook_result = await self.middleware.process_agent_start(
178
- self.name, inputs, mw_context
179
- )
168
+ hook_result = await self.middleware.process_agent_start(inputs)
180
169
  if hook_result.action == HookAction.STOP:
181
170
  logger.warning(f"Workflow stopped by middleware on_agent_start, invocation_id={self.ctx.invocation_id}")
182
171
  await self.ctx.emit(BlockEvent(
@@ -205,16 +194,14 @@ class WorkflowAgent(BaseAgent):
205
194
  # === Middleware: on_agent_end ===
206
195
  if self.middleware:
207
196
  logger.debug(f"WorkflowAgent: processing on_agent_end hooks, invocation_id={self.ctx.invocation_id}")
208
- await self.middleware.process_agent_end(
209
- self.name, result, mw_context
210
- )
197
+ await self.middleware.process_agent_end(result)
211
198
 
212
199
  except Exception as e:
213
200
  logger.error(f"WorkflowAgent: execution error, error={type(e).__name__}, workflow={self.workflow.spec.name}, invocation_id={self.ctx.invocation_id}", exc_info=True)
214
201
  # === Middleware: on_error ===
215
202
  if self.middleware:
216
203
  logger.debug(f"WorkflowAgent: processing on_error hooks, invocation_id={self.ctx.invocation_id}")
217
- processed_error = await self.middleware.process_error(e, mw_context)
204
+ processed_error = await self.middleware.process_error(e)
218
205
  if processed_error is None:
219
206
  logger.warning(f"WorkflowAgent: error suppressed by middleware, invocation_id={self.ctx.invocation_id}")
220
207
  return
@@ -598,24 +598,12 @@ class WorkflowExecutor:
598
598
  # Get effective middleware for this node
599
599
  effective_middleware = self._get_effective_middleware(node)
600
600
 
601
- # Build middleware context
602
- mw_context = {
603
- "session_id": self.ctx.session_id,
604
- "invocation_id": self.ctx.invocation_id,
605
- "parent_agent_id": self.workflow.spec.name,
606
- "child_agent_id": node.agent,
607
- "node_id": node.id,
608
- "parent_block_id": parent_block_id,
609
- "has_node_middleware": bool(node.middleware),
610
- }
611
-
612
601
  # === Middleware: on_subagent_start ===
613
602
  if effective_middleware:
614
603
  hook_result = await effective_middleware.process_subagent_start(
615
604
  self.workflow.spec.name,
616
605
  node.agent,
617
606
  "embedded", # Workflow nodes are embedded execution
618
- mw_context,
619
607
  )
620
608
  if hook_result.action == HookAction.SKIP:
621
609
  logger.info(f"SubAgent {node.agent} skipped by middleware")
@@ -628,8 +616,13 @@ class WorkflowExecutor:
628
616
  try:
629
617
  # Create child context for sub-agent with effective middleware
630
618
  # Note: parent_block_id is already set via ContextVar above
619
+ # Get agent name from factory if available
620
+ agent_class = self.agent_factory.get_class(node.agent)
621
+ agent_name = getattr(agent_class, 'name', node.agent) if agent_class else node.agent
622
+
631
623
  child_ctx = self.ctx.create_child(
632
624
  agent_id=node.agent,
625
+ agent_name=agent_name,
633
626
  middleware=effective_middleware,
634
627
  )
635
628
 
@@ -662,7 +655,6 @@ class WorkflowExecutor:
662
655
  self.workflow.spec.name,
663
656
  node.agent,
664
657
  result,
665
- mw_context,
666
658
  )
667
659
 
668
660
  return result
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aury-agent
3
- Version: 0.0.6
3
+ Version: 0.0.7
4
4
  Summary: Aury Agent Framework - React Agent and Workflow orchestration
5
5
  Author: Aury Team
6
6
  License: MIT