aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
"""Middleware chain for sequential processing."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any, TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from ..core.logging import middleware_logger as logger
|
|
8
|
+
from .types import TriggerMode, HookAction, HookResult
|
|
9
|
+
from .base import Middleware
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from ..core.types.tool import BaseTool, ToolResult
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class MiddlewareEntry:
    """A registered middleware together with its resolved ``inherit`` flag.

    ``MiddlewareChain`` stores one entry per registered middleware; the flag is
    what ``get_inheritable`` filters on when building sub-agent chains.
    """

    # The registered middleware instance.
    middleware: Middleware
    # Resolved inherit flag: the explicit ``inherit=`` override given to
    # ``MiddlewareChain.use`` when one was passed, else the middleware's
    # ``config.inherit`` default.
    inherit: bool
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MiddlewareChain:
    """Chain of middlewares for sequential processing.

    Entries are kept sorted by ``middleware.config.priority`` (ascending; ties
    keep insertion order because ``list.sort`` is stable). Request, error,
    stream-chunk, ``*_start``/``tool_call`` and message-save hooks visit the
    middlewares in that order; response and ``*_end`` hooks visit them in
    reverse order.
    """

    # Punctuation that ends a sentence/clause in streamed text, in ASCII and
    # full-width (CJK) forms. A trailing newline is handled separately in
    # ``_is_boundary``.
    _BOUNDARY_CHARS = (".", "。", "!", "！", "?", "？", ";", "；")

    def __init__(self, middlewares: list[Middleware] | None = None) -> None:
        """Create a chain, registering each of ``middlewares`` via ``use``."""
        self._entries: list[MiddlewareEntry] = []
        # Streaming state, cleared by ``reset_stream_state``.
        self._token_buffer: str = ""  # concatenated text of chunks seen so far
        self._token_count: int = 0  # number of chunks seen (not model tokens)

        for mw in middlewares or ():
            self.use(mw)

    @property
    def _middlewares(self) -> list[Middleware]:
        """Get middleware list (a fresh list, in priority order)."""
        return [e.middleware for e in self._entries]

    def use(
        self,
        middleware: Middleware,
        *,
        inherit: bool | None = None,
    ) -> "MiddlewareChain":
        """Add middleware to chain.

        Args:
            middleware: The middleware to add
            inherit: Override inherit setting (None = use middleware's config default)

        Returns:
            This chain, to allow fluent ``chain.use(a).use(b)`` calls.

        Maintains sorted order by priority.
        """
        effective_inherit = inherit if inherit is not None else middleware.config.inherit
        self._entries.append(MiddlewareEntry(middleware=middleware, inherit=effective_inherit))
        self._entries.sort(key=lambda e: e.middleware.config.priority)
        return self

    def remove(self, middleware: Middleware) -> "MiddlewareChain":
        """Remove every entry whose middleware compares equal to ``middleware``."""
        self._entries = [e for e in self._entries if e.middleware != middleware]
        return self

    def clear(self) -> "MiddlewareChain":
        """Clear all middlewares."""
        self._entries.clear()
        return self

    def get_inheritable(self) -> list[MiddlewareEntry]:
        """Get entries that should be inherited by sub-agents."""
        return [e for e in self._entries if e.inherit]

    def merge(self, other: "MiddlewareChain | None") -> "MiddlewareChain":
        """Merge this chain's inheritable middlewares with another chain.

        Creates a new chain with:
        - This chain's inheritable middlewares
        - All of other chain's middlewares not already present

        Args:
            other: Chain to merge with (sub-agent's own middlewares)

        Returns:
            New merged MiddlewareChain (neither input chain is modified)
        """
        merged = MiddlewareChain()
        merged._entries = [
            MiddlewareEntry(middleware=e.middleware, inherit=e.inherit)
            for e in self.get_inheritable()
        ]

        if other is not None:
            # Maintained incrementally instead of being rebuilt per entry.
            present = [e.middleware for e in merged._entries]
            for entry in other._entries:
                # Avoid duplicates (e.g. a middleware already inherited above).
                if entry.middleware in present:
                    continue
                present.append(entry.middleware)
                merged._entries.append(
                    MiddlewareEntry(middleware=entry.middleware, inherit=entry.inherit)
                )

        # Re-sort by priority
        merged._entries.sort(key=lambda e: e.middleware.config.priority)
        return merged

    async def process_request(
        self,
        request: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Process request through all middlewares.

        Returns:
            The (possibly replaced) request, or None as soon as any
            middleware's ``on_request`` returns None.
        """
        current = request
        for mw in self._middlewares:
            result = await mw.on_request(current, context)
            if result is None:
                return None
            current = result
        return current

    async def process_response(
        self,
        response: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Process response through all middlewares (reverse order).

        Returns:
            The (possibly replaced) response, or None as soon as any
            middleware's ``on_response`` returns None.
        """
        current = response
        for mw in reversed(self._middlewares):
            result = await mw.on_response(current, context)
            if result is None:
                return None
            current = result
        return current

    async def process_error(
        self,
        error: Exception,
        context: dict[str, Any],
    ) -> Exception | None:
        """Process error through all middlewares.

        Returns:
            The (possibly replaced) exception, or None as soon as any
            middleware's ``on_error`` returns None.
        """
        current = error
        for mw in self._middlewares:
            result = await mw.on_error(current, context)
            if result is None:
                return None
            current = result
        return current

    async def process_stream_chunk(
        self,
        chunk: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Process streaming chunk through middlewares based on trigger mode.

        The chunk's text is read from ``"text"`` (falling back to ``"delta"``);
        a missing or non-string value is treated as empty text instead of
        raising ``TypeError`` on concatenation.

        Returns:
            The (possibly replaced) chunk, or None as soon as a triggered
            middleware's ``on_model_stream`` returns None.
        """
        text = chunk.get("text", chunk.get("delta", ""))
        if not isinstance(text, str):
            text = ""
        self._token_buffer += text
        self._token_count += 1

        current = chunk
        for mw in self._middlewares:
            if not self._should_trigger(mw, text):
                continue
            result = await mw.on_model_stream(current, context)
            if result is None:
                return None
            current = result
        return current

    def _should_trigger(self, middleware: Middleware, text: str) -> bool:
        """Check if middleware should be triggered for the current chunk."""
        mode = middleware.config.trigger_mode

        if mode == TriggerMode.EVERY_N_TOKENS:
            interval = middleware.config.trigger_n
            # A zero/unset interval would raise ZeroDivisionError; treat it as
            # "every chunk" instead of crashing the stream.
            if not interval:
                return True
            return self._token_count % interval == 0
        if mode == TriggerMode.ON_BOUNDARY:
            return self._is_boundary(text)
        # EVERY_TOKEN and any unrecognised mode trigger on every chunk.
        return True

    def _is_boundary(self, text: str) -> bool:
        """Check if text ends with a sentence/paragraph boundary.

        A chunk is a boundary when its trailing whitespace contains a newline,
        or when its last non-whitespace character is sentence punctuation.
        """
        stripped = text.rstrip()
        # Inspect the trailing whitespace before discarding it: stripping first
        # and then looking for "\n" can never match.
        if "\n" in text[len(stripped):]:
            return True
        return stripped.endswith(self._BOUNDARY_CHARS)

    def reset_stream_state(self) -> None:
        """Reset streaming state (call at start of new stream)."""
        self._token_buffer = ""
        self._token_count = 0

    @property
    def middlewares(self) -> list[Middleware]:
        """Get list of middlewares (a copy; mutating it does not affect the chain)."""
        return self._middlewares

    # ========== Lifecycle Hook Processing ==========

    async def process_agent_start(
        self,
        agent_id: str,
        input_data: Any,
        context: dict[str, Any],
    ) -> HookResult:
        """Process agent start through all middlewares.

        Returns:
            First non-CONTINUE result, or CONTINUE if all pass
        """
        for mw in self._middlewares:
            if hasattr(mw, 'on_agent_start'):
                result = await mw.on_agent_start(agent_id, input_data, context)
                if result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {result.action} on agent_start")
                    return result
        return HookResult.proceed()

    async def process_agent_end(
        self,
        agent_id: str,
        result: Any,
        context: dict[str, Any],
    ) -> HookResult:
        """Process agent end through all middlewares (reverse order).

        Returns:
            First non-CONTINUE result, or CONTINUE if all pass
        """
        for mw in reversed(self._middlewares):
            if hasattr(mw, 'on_agent_end'):
                hook_result = await mw.on_agent_end(agent_id, result, context)
                if hook_result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {hook_result.action} on agent_end")
                    return hook_result
        return HookResult.proceed()

    async def process_tool_call(
        self,
        tool: "BaseTool",
        params: dict[str, Any],
        context: dict[str, Any],
    ) -> HookResult:
        """Process tool call through all middlewares.

        Returns:
            SKIP to skip tool, RETRY with modified_data to change params
        """
        for mw in self._middlewares:
            if hasattr(mw, 'on_tool_call'):
                result = await mw.on_tool_call(tool, params, context)
                if result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {result.action} on tool_call")
                    return result
        return HookResult.proceed()

    async def process_tool_end(
        self,
        tool: "BaseTool",
        result: "ToolResult",
        context: dict[str, Any],
    ) -> HookResult:
        """Process tool end through all middlewares (reverse order).

        Returns:
            First non-CONTINUE result, or CONTINUE if all pass
        """
        for mw in reversed(self._middlewares):
            if hasattr(mw, 'on_tool_end'):
                hook_result = await mw.on_tool_end(tool, result, context)
                if hook_result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {hook_result.action} on tool_end")
                    return hook_result
        return HookResult.proceed()

    async def process_subagent_start(
        self,
        parent_agent_id: str,
        child_agent_id: str,
        mode: str,
        context: dict[str, Any],
    ) -> HookResult:
        """Process sub-agent start through all middlewares.

        Returns:
            First non-CONTINUE result, or CONTINUE if all pass
        """
        for mw in self._middlewares:
            if hasattr(mw, 'on_subagent_start'):
                result = await mw.on_subagent_start(
                    parent_agent_id, child_agent_id, mode, context
                )
                if result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {result.action} on subagent_start")
                    return result
        return HookResult.proceed()

    async def process_subagent_end(
        self,
        parent_agent_id: str,
        child_agent_id: str,
        result: Any,
        context: dict[str, Any],
    ) -> HookResult:
        """Process sub-agent end through all middlewares (reverse order).

        Returns:
            First non-CONTINUE result, or CONTINUE if all pass
        """
        for mw in reversed(self._middlewares):
            if hasattr(mw, 'on_subagent_end'):
                hook_result = await mw.on_subagent_end(
                    parent_agent_id, child_agent_id, result, context
                )
                if hook_result.action != HookAction.CONTINUE:
                    logger.debug(f"Middleware returned {hook_result.action} on subagent_end")
                    return hook_result
        return HookResult.proceed()

    async def process_message_save(
        self,
        message: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Process message save through all middlewares.

        Args:
            message: Message to be saved
            context: Execution context

        Returns:
            Modified message, or None to skip saving
        """
        current = message
        for mw in self._middlewares:
            if hasattr(mw, 'on_message_save'):
                result = await mw.on_message_save(current, context)
                if result is None:
                    logger.debug("Middleware blocked message save")
                    return None
                current = result
        return current
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
# Public names exported by ``from <this module> import *``.
__all__ = [
    "MiddlewareChain",
]
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Message persistence middleware.
|
|
2
|
+
|
|
3
|
+
Uses backends.message directly to persist messages.
|
|
4
|
+
|
|
5
|
+
Usage:
|
|
6
|
+
middleware = MessageBackendMiddleware()
|
|
7
|
+
|
|
8
|
+
agent = ReactAgent.create(
|
|
9
|
+
llm=llm,
|
|
10
|
+
middlewares=[middleware],
|
|
11
|
+
)
|
|
12
|
+
"""
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Any
|
|
16
|
+
|
|
17
|
+
from .base import BaseMiddleware
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from ..backends import MessageBackend
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MessageBackendMiddleware(BaseMiddleware):
    """Middleware that persists messages directly via MessageBackend.

    This is the recommended middleware for message persistence. It looks up
    ``backends.message`` on the current execution context and stores every
    message that reaches ``on_message_save``.

    Features:
    - Saves the message's role/content/tool fields with type ``"truncated"``
      (the view used for LLM context)
    - With ``save_raw=True``, also saves the full original message with type
      ``"raw"`` (for audit)
    - Forwards the hook context's ``namespace`` so sub-agents can be isolated
    - Works with any MessageBackend implementation

    Example:
        agent = ReactAgent.create(
            llm=llm,
            middlewares=[MessageBackendMiddleware()],
        )
    """

    def __init__(
        self,
        *,
        save_raw: bool = False,
        max_history: int = 100,
    ):
        """Initialize MessageBackendMiddleware.

        Args:
            save_raw: Also save raw (untruncated) messages for audit
            max_history: Max messages to keep in history (for truncation).
                NOTE(review): stored but not read anywhere in this class —
                confirm whether a caller or backend consumes it.
        """
        # Run BaseMiddleware's initialiser, as MessageContainerMiddleware does,
        # so any base state (presumably the ``config`` that MiddlewareChain
        # reads — confirm) is set up.
        super().__init__()
        self.save_raw = save_raw
        self.max_history = max_history

    async def on_message_save(
        self,
        message: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Save message via backends.message.

        Nothing is persisted (the message is simply passed through) when the
        hook context has no ``session_id`` or the current execution context
        has no message backend.

        Args:
            message: Message dict with 'role', 'content', and optionally
                'tool_call_id' / 'tool_calls'.
            context: Hook context; 'session_id', 'invocation_id', 'agent_id'
                and 'namespace' are read from it. The backend itself comes
                from the current execution context, not from this dict.

        Returns:
            The message, unchanged (pass through to other middlewares)
        """
        from ..core.context import get_current_ctx_or_none

        session_id = context.get("session_id", "")
        if not session_id:
            return message

        # Get MessageBackend from the current execution context.
        ctx = get_current_ctx_or_none()
        if ctx is None or ctx.backends is None or ctx.backends.message is None:
            # No backend available, pass through
            return message
        backend = ctx.backends.message

        invocation_id = context.get("invocation_id", "")
        agent_id = context.get("agent_id")
        namespace = context.get("namespace")

        # Build the reduced message dict for the backend.
        msg_dict = {
            "role": message.get("role", ""),
            "content": message.get("content", ""),
        }
        tool_call_id = message.get("tool_call_id")
        if tool_call_id:
            msg_dict["tool_call_id"] = tool_call_id
        # Include tool_calls if present (for assistant messages)
        if "tool_calls" in message:
            msg_dict["tool_calls"] = message["tool_calls"]

        # Stored under type "truncated" (LLM-context view). This method does
        # not itself shorten the content.
        await backend.add(
            session_id=session_id,
            message=msg_dict,
            type="truncated",
            agent_id=agent_id,
            namespace=namespace,
            invocation_id=invocation_id,
        )

        # Optionally save raw message (for audit)
        if self.save_raw:
            await backend.add(
                session_id=session_id,
                # Shallow copy: the same dict is returned to later middlewares,
                # which may modify it in place after it has been stored.
                message=dict(message),
                type="raw",
                agent_id=agent_id,
                namespace=namespace,
                invocation_id=invocation_id,
            )

        # Pass through to other middlewares
        return message
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# Public names exported by ``from <this module> import *``.
__all__ = [
    "MessageBackendMiddleware",
]
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""MessageContainerMiddleware - Groups thinking and text blocks under a message container.
|
|
2
|
+
|
|
3
|
+
This middleware creates a "message" container block when LLM request starts,
|
|
4
|
+
and all thinking/text blocks emitted during the LLM call will have parent_id
|
|
5
|
+
pointing to this container.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
from aury.agents.middleware import MessageContainerMiddleware, MiddlewareChain
|
|
9
|
+
|
|
10
|
+
chain = MiddlewareChain()
|
|
11
|
+
chain.use(MessageContainerMiddleware())
|
|
12
|
+
|
|
13
|
+
agent = ReactAgent.create(llm=llm, middleware=chain)
|
|
14
|
+
|
|
15
|
+
Result structure:
|
|
16
|
+
message (block_id: blk_abc)
|
|
17
|
+
├── thinking (parent_id: blk_abc)
|
|
18
|
+
└── text (parent_id: blk_abc)
|
|
19
|
+
|
|
20
|
+
tool_use (parent_id: None) - not grouped
|
|
21
|
+
tool_result (parent_id: None) - not grouped
|
|
22
|
+
"""
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from typing import Any, TYPE_CHECKING
|
|
26
|
+
|
|
27
|
+
from .base import BaseMiddleware
|
|
28
|
+
from .types import HookResult
|
|
29
|
+
from ..core.context import set_parent_id, reset_parent_id, emit, get_parent_id
|
|
30
|
+
from ..core.types.session import generate_id
|
|
31
|
+
from ..core.types.block import BlockEvent, BlockKind, BlockOp
|
|
32
|
+
from ..core.logging import middleware_logger as logger
|
|
33
|
+
|
|
34
|
+
if TYPE_CHECKING:
|
|
35
|
+
from ..core.types.tool import BaseTool, ToolResult
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class MessageContainerMiddleware(BaseMiddleware):
    """Groups thinking and text blocks under a message container.

    When an LLM request starts, creates a "message" container block and sets
    the parent_id ContextVar. Only thinking and text blocks will inherit this
    parent_id (tool_use, tool_result, etc. are not grouped).

    This allows frontend to group thinking + text as a single unit for display.

    Args:
        apply_to_kinds: Set of block kinds that should inherit the container's
            parent_id. Defaults to {"thinking", "text"}.
    """

    # Keys under which the ContextVar token and container block id are stored
    # in the per-call middleware context dict.
    _TOKEN_KEY = "_message_container_token"
    _BLOCK_ID_KEY = "_message_container_block_id"

    # Default kinds that should be grouped under message container
    DEFAULT_KINDS = {"thinking", "text"}

    def __init__(self, apply_to_kinds: set[str] | None = None):
        """Initialize with optional custom kinds filter.

        Args:
            apply_to_kinds: Block kinds to group. Defaults to {"thinking", "text"}
                when None or empty. Any iterable of strings is accepted.
        """
        super().__init__()
        # Copy so that mutating this instance's kinds never mutates the shared
        # class-level DEFAULT_KINDS or the caller's set.
        self.apply_to_kinds = set(apply_to_kinds or self.DEFAULT_KINDS)

    def _release_parent_id(self, context: dict[str, Any]) -> None:
        """Restore the previous parent_id at most once per ``on_request``.

        The token is popped from ``context`` so a later ``on_response`` /
        ``on_error`` for the same call is a no-op instead of resetting with a
        spent token (``reset_parent_id`` presumably wraps ``ContextVar.reset``,
        which raises RuntimeError on reuse — confirm).
        """
        token = context.pop(self._TOKEN_KEY, None)
        if token is not None:
            reset_parent_id(token)

    async def on_request(
        self,
        request: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Create message container block and set parent_id.

        Emits a ``kind="message"`` block, then stores the parent_id token and
        the container's block id in ``context``. The request is returned
        unchanged.
        """
        message_block_id = generate_id("blk")

        logger.debug(
            "[MessageContainerMiddleware] Creating container",
            extra={"block_id": message_block_id, "apply_to_kinds": list(self.apply_to_kinds)}
        )

        # Emit the container block
        await emit(BlockEvent(
            block_id=message_block_id,
            kind="message",  # Container type
            op=BlockOp.APPLY,
            data={
                "type": "llm_response",
                "step": context.get("step"),
            },
            session_id=context.get("session_id"),
            invocation_id=context.get("invocation_id"),
        ))

        # Set parent_id in ContextVar with apply_to_kinds filter
        # Only blocks matching apply_to_kinds will inherit this parent_id
        token = set_parent_id(message_block_id, apply_to_kinds=self.apply_to_kinds)
        context[self._TOKEN_KEY] = token
        context[self._BLOCK_ID_KEY] = message_block_id

        return request

    async def on_response(
        self,
        response: dict[str, Any],
        context: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Reset parent_id to previous value; returns the response unchanged."""
        self._release_parent_id(context)
        return response

    async def on_error(
        self,
        error: Exception,
        context: dict[str, Any],
    ) -> Exception | None:
        """Reset parent_id on error too; returns the error unchanged."""
        self._release_parent_id(context)
        return error
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
# Public names exported by ``from <this module> import *``.
__all__ = [
    "MessageContainerMiddleware",
]
|