aury-agent 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. aury/__init__.py +2 -0
  2. aury/agents/__init__.py +55 -0
  3. aury/agents/a2a/__init__.py +168 -0
  4. aury/agents/backends/__init__.py +196 -0
  5. aury/agents/backends/artifact/__init__.py +9 -0
  6. aury/agents/backends/artifact/memory.py +130 -0
  7. aury/agents/backends/artifact/types.py +133 -0
  8. aury/agents/backends/code/__init__.py +65 -0
  9. aury/agents/backends/file/__init__.py +11 -0
  10. aury/agents/backends/file/local.py +66 -0
  11. aury/agents/backends/file/types.py +40 -0
  12. aury/agents/backends/invocation/__init__.py +8 -0
  13. aury/agents/backends/invocation/memory.py +81 -0
  14. aury/agents/backends/invocation/types.py +110 -0
  15. aury/agents/backends/memory/__init__.py +8 -0
  16. aury/agents/backends/memory/memory.py +179 -0
  17. aury/agents/backends/memory/types.py +136 -0
  18. aury/agents/backends/message/__init__.py +9 -0
  19. aury/agents/backends/message/memory.py +122 -0
  20. aury/agents/backends/message/types.py +124 -0
  21. aury/agents/backends/sandbox.py +275 -0
  22. aury/agents/backends/session/__init__.py +8 -0
  23. aury/agents/backends/session/memory.py +93 -0
  24. aury/agents/backends/session/types.py +124 -0
  25. aury/agents/backends/shell/__init__.py +11 -0
  26. aury/agents/backends/shell/local.py +110 -0
  27. aury/agents/backends/shell/types.py +55 -0
  28. aury/agents/backends/shell.py +209 -0
  29. aury/agents/backends/snapshot/__init__.py +19 -0
  30. aury/agents/backends/snapshot/git.py +95 -0
  31. aury/agents/backends/snapshot/hybrid.py +125 -0
  32. aury/agents/backends/snapshot/memory.py +86 -0
  33. aury/agents/backends/snapshot/types.py +59 -0
  34. aury/agents/backends/state/__init__.py +29 -0
  35. aury/agents/backends/state/composite.py +49 -0
  36. aury/agents/backends/state/file.py +57 -0
  37. aury/agents/backends/state/memory.py +52 -0
  38. aury/agents/backends/state/sqlite.py +262 -0
  39. aury/agents/backends/state/types.py +178 -0
  40. aury/agents/backends/subagent/__init__.py +165 -0
  41. aury/agents/cli/__init__.py +41 -0
  42. aury/agents/cli/chat.py +239 -0
  43. aury/agents/cli/config.py +236 -0
  44. aury/agents/cli/extensions.py +460 -0
  45. aury/agents/cli/main.py +189 -0
  46. aury/agents/cli/session.py +337 -0
  47. aury/agents/cli/workflow.py +276 -0
  48. aury/agents/context_providers/__init__.py +66 -0
  49. aury/agents/context_providers/artifact.py +299 -0
  50. aury/agents/context_providers/base.py +177 -0
  51. aury/agents/context_providers/memory.py +70 -0
  52. aury/agents/context_providers/message.py +130 -0
  53. aury/agents/context_providers/skill.py +50 -0
  54. aury/agents/context_providers/subagent.py +46 -0
  55. aury/agents/context_providers/tool.py +68 -0
  56. aury/agents/core/__init__.py +83 -0
  57. aury/agents/core/base.py +573 -0
  58. aury/agents/core/context.py +797 -0
  59. aury/agents/core/context_builder.py +303 -0
  60. aury/agents/core/event_bus/__init__.py +15 -0
  61. aury/agents/core/event_bus/bus.py +203 -0
  62. aury/agents/core/factory.py +169 -0
  63. aury/agents/core/isolator.py +97 -0
  64. aury/agents/core/logging.py +95 -0
  65. aury/agents/core/parallel.py +194 -0
  66. aury/agents/core/runner.py +139 -0
  67. aury/agents/core/services/__init__.py +5 -0
  68. aury/agents/core/services/file_session.py +144 -0
  69. aury/agents/core/services/message.py +53 -0
  70. aury/agents/core/services/session.py +53 -0
  71. aury/agents/core/signals.py +109 -0
  72. aury/agents/core/state.py +363 -0
  73. aury/agents/core/types/__init__.py +107 -0
  74. aury/agents/core/types/action.py +176 -0
  75. aury/agents/core/types/artifact.py +135 -0
  76. aury/agents/core/types/block.py +736 -0
  77. aury/agents/core/types/message.py +350 -0
  78. aury/agents/core/types/recall.py +144 -0
  79. aury/agents/core/types/session.py +257 -0
  80. aury/agents/core/types/subagent.py +154 -0
  81. aury/agents/core/types/tool.py +205 -0
  82. aury/agents/eval/__init__.py +331 -0
  83. aury/agents/hitl/__init__.py +57 -0
  84. aury/agents/hitl/ask_user.py +242 -0
  85. aury/agents/hitl/compaction.py +230 -0
  86. aury/agents/hitl/exceptions.py +87 -0
  87. aury/agents/hitl/permission.py +617 -0
  88. aury/agents/hitl/revert.py +216 -0
  89. aury/agents/llm/__init__.py +31 -0
  90. aury/agents/llm/adapter.py +367 -0
  91. aury/agents/llm/openai.py +294 -0
  92. aury/agents/llm/provider.py +476 -0
  93. aury/agents/mcp/__init__.py +153 -0
  94. aury/agents/memory/__init__.py +46 -0
  95. aury/agents/memory/compaction.py +394 -0
  96. aury/agents/memory/manager.py +465 -0
  97. aury/agents/memory/processor.py +177 -0
  98. aury/agents/memory/store.py +187 -0
  99. aury/agents/memory/types.py +137 -0
  100. aury/agents/messages/__init__.py +40 -0
  101. aury/agents/messages/config.py +47 -0
  102. aury/agents/messages/raw_store.py +224 -0
  103. aury/agents/messages/store.py +118 -0
  104. aury/agents/messages/types.py +88 -0
  105. aury/agents/middleware/__init__.py +31 -0
  106. aury/agents/middleware/base.py +341 -0
  107. aury/agents/middleware/chain.py +342 -0
  108. aury/agents/middleware/message.py +129 -0
  109. aury/agents/middleware/message_container.py +126 -0
  110. aury/agents/middleware/raw_message.py +153 -0
  111. aury/agents/middleware/truncation.py +139 -0
  112. aury/agents/middleware/types.py +81 -0
  113. aury/agents/plugin.py +162 -0
  114. aury/agents/react/__init__.py +4 -0
  115. aury/agents/react/agent.py +1923 -0
  116. aury/agents/sandbox/__init__.py +23 -0
  117. aury/agents/sandbox/local.py +239 -0
  118. aury/agents/sandbox/remote.py +200 -0
  119. aury/agents/sandbox/types.py +115 -0
  120. aury/agents/skill/__init__.py +16 -0
  121. aury/agents/skill/loader.py +180 -0
  122. aury/agents/skill/types.py +83 -0
  123. aury/agents/tool/__init__.py +39 -0
  124. aury/agents/tool/builtin/__init__.py +23 -0
  125. aury/agents/tool/builtin/ask_user.py +155 -0
  126. aury/agents/tool/builtin/bash.py +107 -0
  127. aury/agents/tool/builtin/delegate.py +726 -0
  128. aury/agents/tool/builtin/edit.py +121 -0
  129. aury/agents/tool/builtin/plan.py +277 -0
  130. aury/agents/tool/builtin/read.py +91 -0
  131. aury/agents/tool/builtin/thinking.py +111 -0
  132. aury/agents/tool/builtin/yield_result.py +130 -0
  133. aury/agents/tool/decorator.py +252 -0
  134. aury/agents/tool/set.py +204 -0
  135. aury/agents/usage/__init__.py +12 -0
  136. aury/agents/usage/tracker.py +236 -0
  137. aury/agents/workflow/__init__.py +85 -0
  138. aury/agents/workflow/adapter.py +268 -0
  139. aury/agents/workflow/dag.py +116 -0
  140. aury/agents/workflow/dsl.py +575 -0
  141. aury/agents/workflow/executor.py +659 -0
  142. aury/agents/workflow/expression.py +136 -0
  143. aury/agents/workflow/parser.py +182 -0
  144. aury/agents/workflow/state.py +145 -0
  145. aury/agents/workflow/types.py +86 -0
  146. aury_agent-0.0.4.dist-info/METADATA +90 -0
  147. aury_agent-0.0.4.dist-info/RECORD +149 -0
  148. aury_agent-0.0.4.dist-info/WHEEL +4 -0
  149. aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,342 @@
1
+ """Middleware chain for sequential processing."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Any, TYPE_CHECKING
6
+
7
+ from ..core.logging import middleware_logger as logger
8
+ from .types import TriggerMode, HookAction, HookResult
9
+ from .base import Middleware
10
+
11
+ if TYPE_CHECKING:
12
+ from ..core.types.tool import BaseTool, ToolResult
13
+
14
+
15
@dataclass()
class MiddlewareEntry:
    """A middleware paired with the inherit flag that applies to it.

    Attributes:
        middleware: The wrapped middleware instance.
        inherit: Effective inheritance flag for sub-agents: either the
            middleware config's default or an explicit override given to
            ``MiddlewareChain.use``.
    """

    middleware: Middleware
    inherit: bool
20
+
21
+
22
class MiddlewareChain:
    """Ordered chain of middlewares for sequential processing.

    Entries are kept sorted by ``middleware.config.priority`` (ascending; the
    sort is stable, so equal priorities keep insertion order).

    - Walked in priority order: ``process_request``, ``process_error``,
      ``process_stream_chunk``, ``process_agent_start``,
      ``process_tool_call``, ``process_subagent_start`` and
      ``process_message_save``.
    - Walked in reverse order: ``process_response``,
      ``process_agent_end``, ``process_tool_end`` and
      ``process_subagent_end``.
    """

    def __init__(self, middlewares: list[Middleware] | None = None) -> None:
        """Create a chain, optionally pre-populated.

        Args:
            middlewares: Initial middlewares. Each is added via :meth:`use`
                with its config's default ``inherit`` value.
        """
        self._entries: list[MiddlewareEntry] = []
        # Streaming state; cleared by reset_stream_state().
        self._token_buffer: str = ""
        self._token_count: int = 0

        for mw in middlewares or ():
            self.use(mw)

    @property
    def _middlewares(self) -> list[Middleware]:
        """Return the middlewares in priority order (a fresh list)."""
        return [e.middleware for e in self._entries]

    def use(
        self,
        middleware: Middleware,
        *,
        inherit: bool | None = None,
    ) -> "MiddlewareChain":
        """Add a middleware, keeping the chain sorted by priority.

        Args:
            middleware: The middleware to add.
            inherit: Override for sub-agent inheritance. ``None`` uses
                ``middleware.config.inherit``.

        Returns:
            ``self``, for fluent chaining.
        """
        effective_inherit = inherit if inherit is not None else middleware.config.inherit
        self._entries.append(
            MiddlewareEntry(middleware=middleware, inherit=effective_inherit)
        )
        self._entries.sort(key=lambda e: e.middleware.config.priority)
        return self

    def remove(self, middleware: Middleware) -> "MiddlewareChain":
        """Remove every entry whose middleware compares equal to ``middleware``.

        Returns:
            ``self``, for fluent chaining.
        """
        self._entries = [e for e in self._entries if e.middleware != middleware]
        return self

    def clear(self) -> "MiddlewareChain":
        """Remove all middlewares and return ``self``."""
        self._entries.clear()
        return self

    def get_inheritable(self) -> list[MiddlewareEntry]:
        """Return the entries whose ``inherit`` flag is set (for sub-agents)."""
        return [e for e in self._entries if e.inherit]

    def merge(self, other: "MiddlewareChain | None") -> "MiddlewareChain":
        """Merge this chain's inheritable middlewares with another chain.

        Creates a new chain containing:
        - this chain's inheritable middlewares, and
        - all of ``other``'s middlewares that are not already present.

        Duplicates are detected by identity (the same middleware instance),
        not by equality.

        Args:
            other: Chain to merge with (the sub-agent's own middlewares).
                ``None`` is treated as an empty chain.

        Returns:
            A new MiddlewareChain sorted by priority. Neither input is modified.
        """
        merged = MiddlewareChain()
        # Track instance identities so de-duplication is O(1) per entry.
        seen: set[int] = set()

        for entry in self.get_inheritable():
            merged._entries.append(
                MiddlewareEntry(middleware=entry.middleware, inherit=entry.inherit)
            )
            seen.add(id(entry.middleware))

        if other is not None:
            for entry in other._entries:
                key = id(entry.middleware)
                if key in seen:
                    continue
                seen.add(key)
                merged._entries.append(
                    MiddlewareEntry(middleware=entry.middleware, inherit=entry.inherit)
                )

        merged._entries.sort(key=lambda e: e.middleware.config.priority)
        return merged

    async def process_request(
        self,
        request: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Pass ``request`` through each middleware's ``on_request`` in order.

        Each middleware receives the previous one's output. If any middleware
        returns ``None``, processing stops and ``None`` is returned.
        """
        current = request

        for mw in self._middlewares:
            result = await mw.on_request(current, context)
            if result is None:
                return None
            current = result

        return current

    async def process_response(
        self,
        response: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Pass ``response`` through each ``on_response`` in reverse order.

        If any middleware returns ``None``, processing stops and ``None`` is
        returned.
        """
        current = response

        for mw in reversed(self._middlewares):
            result = await mw.on_response(current, context)
            if result is None:
                return None
            current = result

        return current

    async def process_error(
        self,
        error: Exception,
        context: dict[str, Any],
    ) -> Exception | None:
        """Pass ``error`` through each middleware's ``on_error`` in order.

        If any middleware returns ``None``, processing stops and ``None`` is
        returned.
        """
        current = error

        for mw in self._middlewares:
            result = await mw.on_error(current, context)
            if result is None:
                return None
            current = result

        return current

    async def process_stream_chunk(
        self,
        chunk: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Process a streaming chunk through middlewares based on trigger mode.

        The chunk's text (``"text"``, falling back to ``"delta"``) is
        appended to the stream buffer, and the chunk counter is incremented.
        Each middleware whose trigger condition holds then gets
        ``on_model_stream``. Returning ``None`` from a middleware drops the
        chunk.
        """
        text = chunk.get("text")
        if text is None:
            text = chunk.get("delta")
        if not isinstance(text, str):
            # None (or any non-string payload) contributes no text; the
            # previous code raised TypeError on the += below.
            text = ""
        self._token_buffer += text
        self._token_count += 1

        current = chunk

        for mw in self._middlewares:
            if self._should_trigger(mw, text):
                result = await mw.on_model_stream(current, context)
                if result is None:
                    return None
                current = result

        return current

    def _should_trigger(self, middleware: Middleware, text: str) -> bool:
        """Return whether ``middleware`` should see the current stream chunk."""
        mode = middleware.config.trigger_mode

        if mode == TriggerMode.EVERY_TOKEN:
            return True
        if mode == TriggerMode.EVERY_N_TOKENS:
            n = middleware.config.trigger_n
            # A zero/unset interval would raise ZeroDivisionError; treat it
            # as "every chunk" instead.
            if not n:
                return True
            return self._token_count % n == 0
        if mode == TriggerMode.ON_BOUNDARY:
            return self._is_boundary(text)

        # Unrecognised modes trigger unconditionally.
        return True

    def _is_boundary(self, text: str) -> bool:
        """Return True if ``text`` ends at a sentence or paragraph boundary.

        Sentence terminators cover ASCII and full-width CJK forms, and may be
        followed by trailing whitespace. A newline anywhere in the trailing
        whitespace counts as a paragraph boundary.
        """
        boundaries = (".", "。", "!", "?", "！", "？", ";", "；")
        stripped = text.rstrip()
        if stripped.endswith(boundaries):
            return True
        # rstrip() also removes "\n", so look for the paragraph break in the
        # whitespace that was stripped off.
        return "\n" in text[len(stripped):]

    def reset_stream_state(self) -> None:
        """Reset streaming state (call at the start of a new stream)."""
        self._token_buffer = ""
        self._token_count = 0

    @property
    def middlewares(self) -> list[Middleware]:
        """Return a copy of the middleware list in priority order."""
        # _middlewares already builds a fresh list on every access.
        return self._middlewares

    # ========== Lifecycle Hook Processing ==========

    async def process_agent_start(
        self,
        agent_id: str,
        input_data: Any,
        context: dict[str, Any],
    ) -> HookResult:
        """Run ``on_agent_start`` on each middleware that defines it, in order.

        Returns:
            The first non-CONTINUE result, or ``HookResult.proceed()`` if all
            middlewares continue.
        """
        for mw in self._middlewares:
            if hasattr(mw, 'on_agent_start'):
                result = await mw.on_agent_start(agent_id, input_data, context)
                if result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {result.action} on agent_start")
                    return result
        return HookResult.proceed()

    async def process_agent_end(
        self,
        agent_id: str,
        result: Any,
        context: dict[str, Any],
    ) -> HookResult:
        """Run ``on_agent_end`` on each middleware that defines it, in reverse.

        Returns:
            The first non-CONTINUE result, or ``HookResult.proceed()``.
        """
        for mw in reversed(self._middlewares):
            if hasattr(mw, 'on_agent_end'):
                hook_result = await mw.on_agent_end(agent_id, result, context)
                if hook_result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {hook_result.action} on agent_end")
                    return hook_result
        return HookResult.proceed()

    async def process_tool_call(
        self,
        tool: "BaseTool",
        params: dict[str, Any],
        context: dict[str, Any],
    ) -> HookResult:
        """Run ``on_tool_call`` on each middleware that defines it, in order.

        Returns:
            The first non-CONTINUE result (e.g. SKIP to skip the tool, or RETRY
            with modified_data to change params), or ``HookResult.proceed()``.
        """
        for mw in self._middlewares:
            if hasattr(mw, 'on_tool_call'):
                result = await mw.on_tool_call(tool, params, context)
                if result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {result.action} on tool_call")
                    return result
        return HookResult.proceed()

    async def process_tool_end(
        self,
        tool: "BaseTool",
        result: "ToolResult",
        context: dict[str, Any],
    ) -> HookResult:
        """Run ``on_tool_end`` on each middleware that defines it, in reverse.

        Returns:
            The first non-CONTINUE result, or ``HookResult.proceed()``.
        """
        for mw in reversed(self._middlewares):
            if hasattr(mw, 'on_tool_end'):
                hook_result = await mw.on_tool_end(tool, result, context)
                if hook_result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {hook_result.action} on tool_end")
                    return hook_result
        return HookResult.proceed()

    async def process_subagent_start(
        self,
        parent_agent_id: str,
        child_agent_id: str,
        mode: str,
        context: dict[str, Any],
    ) -> HookResult:
        """Run ``on_subagent_start`` on each middleware that defines it, in order.

        Returns:
            The first non-CONTINUE result, or ``HookResult.proceed()``.
        """
        for mw in self._middlewares:
            if hasattr(mw, 'on_subagent_start'):
                result = await mw.on_subagent_start(
                    parent_agent_id, child_agent_id, mode, context
                )
                if result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {result.action} on subagent_start")
                    return result
        return HookResult.proceed()

    async def process_subagent_end(
        self,
        parent_agent_id: str,
        child_agent_id: str,
        result: Any,
        context: dict[str, Any],
    ) -> HookResult:
        """Run ``on_subagent_end`` on each middleware that defines it, in reverse.

        Returns:
            The first non-CONTINUE result, or ``HookResult.proceed()``.
        """
        for mw in reversed(self._middlewares):
            if hasattr(mw, 'on_subagent_end'):
                hook_result = await mw.on_subagent_end(
                    parent_agent_id, child_agent_id, result, context
                )
                if hook_result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {hook_result.action} on subagent_end")
                    return hook_result
        return HookResult.proceed()

    async def process_message_save(
        self,
        message: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Pass a message through each ``on_message_save`` hook, in order.

        Args:
            message: Message to be saved.
            context: Execution context.

        Returns:
            The (possibly modified) message, or ``None`` if a middleware
            returned ``None`` to skip saving.
        """
        current = message

        for mw in self._middlewares:
            if hasattr(mw, 'on_message_save'):
                result = await mw.on_message_save(current, context)
                if result is None:
                    logger.debug("Middleware blocked message save")
                    return None
                current = result

        return current
340
+
341
+
342
# Public API of this module.
__all__ = [
    "MiddlewareChain",
]
@@ -0,0 +1,129 @@
1
+ """Message persistence middleware.
2
+
3
+ Uses backends.message directly to persist messages.
4
+
5
+ Usage:
6
+ middleware = MessageBackendMiddleware()
7
+
8
+ agent = ReactAgent.create(
9
+ llm=llm,
10
+ middlewares=[middleware],
11
+ )
12
+ """
13
+ from __future__ import annotations
14
+
15
+ from typing import TYPE_CHECKING, Any
16
+
17
+ from .base import BaseMiddleware
18
+
19
+ if TYPE_CHECKING:
20
+ from ..backends import MessageBackend
21
+
22
+
23
class MessageBackendMiddleware(BaseMiddleware):
    """Middleware that persists messages directly via a MessageBackend.

    This is the recommended middleware for message persistence. On each
    save, the backend is looked up from the current execution context
    (``get_current_ctx_or_none().backends.message``), so any MessageBackend
    implementation works.

    Features:
    - Always saves a "truncated" record containing only ``role``,
      ``content`` and, when present, ``tool_call_id`` / ``tool_calls``.
    - With ``save_raw=True``, also saves the full original message as a
      "raw" record for auditing.
    - Forwards ``agent_id`` and ``namespace`` from the context to the
      backend (for namespace isolation of sub-agents).

    Example:
        agent = ReactAgent.create(
            llm=llm,
            middlewares=[MessageBackendMiddleware()],
        )
    """

    def __init__(
        self,
        *,
        save_raw: bool = False,
        max_history: int = 100,
    ):
        """Initialize MessageBackendMiddleware.

        Args:
            save_raw: Also save raw (untruncated) messages for audit.
            max_history: Max messages to keep in history (for truncation).
        """
        # Initialise BaseMiddleware state, as MessageContainerMiddleware
        # does. NOTE(review): presumably this sets up ``config``, which
        # MiddlewareChain reads (priority / inherit) — confirm in base.py.
        super().__init__()
        self.save_raw = save_raw
        # NOTE(review): max_history is stored but not read anywhere in this
        # class; confirm whether a backend or subclass consumes it.
        self.max_history = max_history

    async def on_message_save(
        self,
        message: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Save ``message`` via the context's MessageBackend.

        Nothing is saved, and the message passes through unchanged, if any
        of these hold:
        - ``context`` has no ``session_id``;
        - there is no current ctx;
        - the current ctx has no ``backends``, or no ``backends.message``.

        Args:
            message: Message dict with 'role', 'content', etc.
            context: Execution context. Reads 'session_id', 'invocation_id',
                'agent_id' and 'namespace'.

        Returns:
            The original message, passed through to other middlewares.
        """
        # Imported lazily, as in the original implementation.
        from ..core.context import get_current_ctx_or_none

        session_id = context.get("session_id", "")
        if not session_id:
            return message

        # Get MessageBackend from the current execution context.
        ctx = get_current_ctx_or_none()
        if ctx is None or ctx.backends is None or ctx.backends.message is None:
            # No backend available, pass through.
            return message

        backend = ctx.backends.message

        # Extract message fields.
        role = message.get("role", "")
        content = message.get("content", "")
        invocation_id = context.get("invocation_id", "")
        agent_id = context.get("agent_id")
        namespace = context.get("namespace")
        tool_call_id = message.get("tool_call_id")

        # Build the reduced message dict stored as the "truncated" record.
        msg_dict = {
            "role": role,
            "content": content,
        }
        if tool_call_id:
            msg_dict["tool_call_id"] = tool_call_id

        # Include tool_calls if present (for assistant messages).
        if "tool_calls" in message:
            msg_dict["tool_calls"] = message["tool_calls"]

        # Save truncated message (for LLM context).
        await backend.add(
            session_id=session_id,
            message=msg_dict,
            type="truncated",
            agent_id=agent_id,
            namespace=namespace,
            invocation_id=invocation_id,
        )

        # Optionally save the full original message (for audit).
        if self.save_raw:
            await backend.add(
                session_id=session_id,
                message=message,
                type="raw",
                agent_id=agent_id,
                namespace=namespace,
                invocation_id=invocation_id,
            )

        # Pass through to other middlewares.
        return message
127
+
128
+
129
# Public API of this module.
__all__ = [
    "MessageBackendMiddleware",
]
@@ -0,0 +1,126 @@
1
+ """MessageContainerMiddleware - Groups thinking and text blocks under a message container.
2
+
3
+ This middleware creates a "message" container block when LLM request starts,
4
+ and all thinking/text blocks emitted during the LLM call will have parent_id
5
+ pointing to this container.
6
+
7
+ Usage:
8
+ from aury.agents.middleware import MessageContainerMiddleware, MiddlewareChain
9
+
10
+ chain = MiddlewareChain()
11
+ chain.use(MessageContainerMiddleware())
12
+
13
+ agent = ReactAgent.create(llm=llm, middleware=chain)
14
+
15
+ Result structure:
16
+ message (block_id: blk_abc)
17
+ ├── thinking (parent_id: blk_abc)
18
+ └── text (parent_id: blk_abc)
19
+
20
+ tool_use (parent_id: None) - not grouped
21
+ tool_result (parent_id: None) - not grouped
22
+ """
23
+ from __future__ import annotations
24
+
25
+ from typing import Any, TYPE_CHECKING
26
+
27
+ from .base import BaseMiddleware
28
+ from .types import HookResult
29
+ from ..core.context import set_parent_id, reset_parent_id, emit, get_parent_id
30
+ from ..core.types.session import generate_id
31
+ from ..core.types.block import BlockEvent, BlockKind, BlockOp
32
+ from ..core.logging import middleware_logger as logger
33
+
34
+ if TYPE_CHECKING:
35
+ from ..core.types.tool import BaseTool, ToolResult
36
+
37
+
38
class MessageContainerMiddleware(BaseMiddleware):
    """Groups thinking and text blocks under a message container.

    When an LLM request starts, this emits a "message" container block and
    sets the parent_id ContextVar. Only block kinds in ``apply_to_kinds``
    inherit this parent_id (tool_use, tool_result, etc. are not grouped).
    The previous parent_id is restored on response or error.

    This allows the frontend to display thinking + text as a single unit.

    Args:
        apply_to_kinds: Block kinds that should inherit the container's
            parent_id. Defaults to {"thinking", "text"}.
    """

    # Context keys used to hand state from on_request to on_response/on_error.
    _TOKEN_KEY = "_message_container_token"
    _BLOCK_ID_KEY = "_message_container_block_id"

    # Default kinds that should be grouped under the message container.
    DEFAULT_KINDS = {"thinking", "text"}

    def __init__(self, apply_to_kinds: set[str] | None = None):
        """Initialize with an optional custom kinds filter.

        Args:
            apply_to_kinds: Block kinds to group. ``None`` or empty falls
                back to ``DEFAULT_KINDS``. Any iterable of strings is
                accepted; it is copied into a new set.
        """
        super().__init__()
        # Copy so instances never share, and accidentally mutate, the
        # class-level DEFAULT_KINDS set or the caller's set.
        self.apply_to_kinds = (
            set(apply_to_kinds) if apply_to_kinds else set(self.DEFAULT_KINDS)
        )

    async def on_request(
        self,
        request: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Emit a message container block and make it the current parent_id.

        The reset token and the container block id are stored in
        ``context``. The request passes through unchanged.
        """
        message_block_id = generate_id("blk")

        logger.debug(
            "[MessageContainerMiddleware] Creating container",
            extra={"block_id": message_block_id, "apply_to_kinds": list(self.apply_to_kinds)}
        )

        # Emit the container block.
        await emit(BlockEvent(
            block_id=message_block_id,
            kind="message",  # Container type
            op=BlockOp.APPLY,
            data={
                "type": "llm_response",
                "step": context.get("step"),
            },
            session_id=context.get("session_id"),
            invocation_id=context.get("invocation_id"),
        ))

        # Only blocks whose kind is in apply_to_kinds inherit this parent_id.
        token = set_parent_id(message_block_id, apply_to_kinds=self.apply_to_kinds)
        context[self._TOKEN_KEY] = token
        context[self._BLOCK_ID_KEY] = message_block_id

        return request

    def _restore_parent_id(self, context: dict[str, Any]) -> None:
        """Reset parent_id using the stored token, at most once per request."""
        # pop() rather than get(): the token is consumed by the reset, so a
        # later on_response/on_error for the same context must not reuse it.
        token = context.pop(self._TOKEN_KEY, None)
        if token is not None:
            reset_parent_id(token)

    async def on_response(
        self,
        response: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Reset parent_id to its previous value; pass the response through."""
        self._restore_parent_id(context)
        return response

    async def on_error(
        self,
        error: Exception,
        context: dict[str, Any],
    ) -> Exception | None:
        """Reset parent_id on error too; pass the error through."""
        self._restore_parent_id(context)
        return error
124
+
125
+
126
# Public API of this module.
__all__ = [
    "MessageContainerMiddleware",
]