aury-agent 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/agents/backends/__init__.py +2 -3
- aury/agents/backends/message/__init__.py +1 -2
- aury/agents/backends/message/memory.py +15 -37
- aury/agents/backends/message/types.py +5 -32
- aury/agents/context_providers/message.py +0 -1
- aury/agents/core/context.py +3 -3
- aury/agents/core/types/tool.py +47 -7
- aury/agents/llm/adapter.py +3 -0
- aury/agents/llm/provider.py +43 -4
- aury/agents/messages/__init__.py +0 -9
- aury/agents/middleware/__init__.py +0 -2
- aury/agents/middleware/base.py +35 -5
- aury/agents/middleware/chain.py +64 -8
- aury/agents/middleware/message.py +5 -53
- aury/agents/react/factory.py +11 -6
- aury/agents/react/persistence.py +12 -2
- aury/agents/react/step.py +72 -6
- aury/agents/react/tools.py +20 -4
- aury/agents/tool/decorator.py +5 -20
- {aury_agent-0.0.9.dist-info → aury_agent-0.0.11.dist-info}/METADATA +1 -1
- {aury_agent-0.0.9.dist-info → aury_agent-0.0.11.dist-info}/RECORD +23 -25
- aury/agents/messages/raw_store.py +0 -224
- aury/agents/middleware/raw_message.py +0 -154
- {aury_agent-0.0.9.dist-info → aury_agent-0.0.11.dist-info}/WHEEL +0 -0
- {aury_agent-0.0.9.dist-info → aury_agent-0.0.11.dist-info}/entry_points.txt +0 -0
aury/agents/backends/__init__.py
CHANGED
|
@@ -5,7 +5,7 @@ Backends provide abstracted interfaces for various capabilities:
|
|
|
5
5
|
Data Backends (storage):
|
|
6
6
|
- SessionBackend: Session management
|
|
7
7
|
- InvocationBackend: Invocation management
|
|
8
|
-
- MessageBackend: Message storage
|
|
8
|
+
- MessageBackend: Message storage
|
|
9
9
|
- MemoryBackend: Long-term memory with search
|
|
10
10
|
- ArtifactBackend: File/artifact storage
|
|
11
11
|
- StateBackend: Generic key-value state
|
|
@@ -26,7 +26,7 @@ from typing import TYPE_CHECKING
|
|
|
26
26
|
# Data backends - new architecture
|
|
27
27
|
from .session import SessionBackend, InMemorySessionBackend
|
|
28
28
|
from .invocation import InvocationBackend, InMemoryInvocationBackend
|
|
29
|
-
from .message import MessageBackend,
|
|
29
|
+
from .message import MessageBackend, InMemoryMessageBackend
|
|
30
30
|
from .memory import MemoryBackend, InMemoryMemoryBackend
|
|
31
31
|
from .artifact import ArtifactBackend, ArtifactSource, InMemoryArtifactBackend
|
|
32
32
|
|
|
@@ -144,7 +144,6 @@ __all__ = [
|
|
|
144
144
|
|
|
145
145
|
# Message backend
|
|
146
146
|
"MessageBackend",
|
|
147
|
-
"MessageType",
|
|
148
147
|
"InMemoryMessageBackend",
|
|
149
148
|
|
|
150
149
|
# Memory backend
|
|
@@ -4,45 +4,36 @@ from __future__ import annotations
|
|
|
4
4
|
from datetime import datetime
|
|
5
5
|
from typing import Any
|
|
6
6
|
|
|
7
|
-
from .types import MessageType
|
|
8
|
-
|
|
9
7
|
|
|
10
8
|
class InMemoryMessageBackend:
|
|
11
9
|
"""In-memory implementation of MessageBackend.
|
|
12
10
|
|
|
13
|
-
|
|
14
|
-
Suitable for testing and simple single-process use cases.
|
|
11
|
+
Simple in-memory storage for testing and single-process use cases.
|
|
15
12
|
"""
|
|
16
13
|
|
|
17
14
|
def __init__(self) -> None:
|
|
18
15
|
# Key format: "{session_id}" or "{session_id}:{namespace}"
|
|
19
16
|
# Value: list of message dicts
|
|
20
|
-
self.
|
|
21
|
-
self._raw: dict[str, list[dict[str, Any]]] = {}
|
|
17
|
+
self._messages: dict[str, list[dict[str, Any]]] = {}
|
|
22
18
|
|
|
23
19
|
def _make_key(self, session_id: str, namespace: str | None) -> str:
|
|
24
20
|
if namespace:
|
|
25
21
|
return f"{session_id}:{namespace}"
|
|
26
22
|
return session_id
|
|
27
23
|
|
|
28
|
-
def _get_store(self, type: MessageType) -> dict[str, list[dict[str, Any]]]:
|
|
29
|
-
return self._truncated if type == "truncated" else self._raw
|
|
30
|
-
|
|
31
24
|
async def add(
|
|
32
25
|
self,
|
|
33
26
|
session_id: str,
|
|
34
27
|
message: dict[str, Any],
|
|
35
|
-
type: MessageType = "truncated",
|
|
36
28
|
agent_id: str | None = None,
|
|
37
29
|
namespace: str | None = None,
|
|
38
30
|
invocation_id: str | None = None,
|
|
39
31
|
) -> None:
|
|
40
32
|
"""Add a message."""
|
|
41
33
|
key = self._make_key(session_id, namespace)
|
|
42
|
-
store = self._get_store(type)
|
|
43
34
|
|
|
44
|
-
if key not in
|
|
45
|
-
|
|
35
|
+
if key not in self._messages:
|
|
36
|
+
self._messages[key] = []
|
|
46
37
|
|
|
47
38
|
# Add metadata
|
|
48
39
|
msg = {
|
|
@@ -51,20 +42,18 @@ class InMemoryMessageBackend:
|
|
|
51
42
|
"invocation_id": invocation_id,
|
|
52
43
|
"created_at": datetime.now().isoformat(),
|
|
53
44
|
}
|
|
54
|
-
|
|
45
|
+
self._messages[key].append(msg)
|
|
55
46
|
|
|
56
47
|
async def get(
|
|
57
48
|
self,
|
|
58
49
|
session_id: str,
|
|
59
|
-
type: MessageType = "truncated",
|
|
60
50
|
agent_id: str | None = None,
|
|
61
51
|
namespace: str | None = None,
|
|
62
52
|
limit: int | None = None,
|
|
63
53
|
) -> list[dict[str, Any]]:
|
|
64
54
|
"""Get messages."""
|
|
65
55
|
key = self._make_key(session_id, namespace)
|
|
66
|
-
|
|
67
|
-
messages = store.get(key, [])
|
|
56
|
+
messages = self._messages.get(key, [])
|
|
68
57
|
|
|
69
58
|
# Filter by agent_id if specified
|
|
70
59
|
if agent_id:
|
|
@@ -80,42 +69,31 @@ class InMemoryMessageBackend:
|
|
|
80
69
|
self,
|
|
81
70
|
session_id: str,
|
|
82
71
|
invocation_id: str,
|
|
83
|
-
type: MessageType | None = None,
|
|
84
72
|
namespace: str | None = None,
|
|
85
73
|
) -> int:
|
|
86
74
|
"""Delete messages by invocation."""
|
|
87
75
|
key = self._make_key(session_id, namespace)
|
|
88
|
-
deleted = 0
|
|
89
|
-
|
|
90
|
-
types_to_delete = [type] if type else ["truncated", "raw"]
|
|
91
76
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
if key in store:
|
|
95
|
-
original = store[key]
|
|
96
|
-
store[key] = [m for m in original if m.get("invocation_id") != invocation_id]
|
|
97
|
-
deleted += len(original) - len(store[key])
|
|
77
|
+
if key not in self._messages:
|
|
78
|
+
return 0
|
|
98
79
|
|
|
99
|
-
|
|
80
|
+
original = self._messages[key]
|
|
81
|
+
self._messages[key] = [m for m in original if m.get("invocation_id") != invocation_id]
|
|
82
|
+
return len(original) - len(self._messages[key])
|
|
100
83
|
|
|
101
84
|
async def clear(
|
|
102
85
|
self,
|
|
103
86
|
session_id: str,
|
|
104
|
-
type: MessageType | None = None,
|
|
105
87
|
namespace: str | None = None,
|
|
106
88
|
) -> int:
|
|
107
89
|
"""Clear all messages for a session."""
|
|
108
90
|
key = self._make_key(session_id, namespace)
|
|
109
|
-
deleted = 0
|
|
110
|
-
|
|
111
|
-
types_to_clear = [type] if type else ["truncated", "raw"]
|
|
112
91
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
if key in store:
|
|
116
|
-
deleted += len(store[key])
|
|
117
|
-
del store[key]
|
|
92
|
+
if key not in self._messages:
|
|
93
|
+
return 0
|
|
118
94
|
|
|
95
|
+
deleted = len(self._messages[key])
|
|
96
|
+
del self._messages[key]
|
|
119
97
|
return deleted
|
|
120
98
|
|
|
121
99
|
|
|
@@ -1,49 +1,29 @@
|
|
|
1
1
|
"""Message backend types and protocols."""
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
|
|
4
|
-
from typing import Any,
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
MessageType = Literal["truncated", "raw"]
|
|
4
|
+
from typing import Any, Protocol, runtime_checkable
|
|
8
5
|
|
|
9
6
|
|
|
10
7
|
@runtime_checkable
|
|
11
8
|
class MessageBackend(Protocol):
|
|
12
9
|
"""Protocol for message storage.
|
|
13
10
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
- truncated: Messages kept in context window, may be summarized/trimmed
|
|
18
|
-
- raw: Full original messages for audit/replay
|
|
11
|
+
Simple interface for message persistence.
|
|
12
|
+
Storage details (raw/truncated handling) are left to the application layer.
|
|
19
13
|
|
|
20
14
|
Example usage:
|
|
21
|
-
# Add truncated message (for LLM context)
|
|
22
15
|
await backend.add(
|
|
23
16
|
session_id="sess_123",
|
|
24
17
|
message={"role": "user", "content": "Hello"},
|
|
25
|
-
type="truncated",
|
|
26
|
-
)
|
|
27
|
-
|
|
28
|
-
# Add raw message (for audit)
|
|
29
|
-
await backend.add(
|
|
30
|
-
session_id="sess_123",
|
|
31
|
-
message={"role": "user", "content": "Hello", "attachments": [...]},
|
|
32
|
-
type="raw",
|
|
33
18
|
)
|
|
34
19
|
|
|
35
|
-
|
|
36
|
-
messages = await backend.get("sess_123", type="truncated", limit=50)
|
|
37
|
-
|
|
38
|
-
# Get raw history
|
|
39
|
-
raw_messages = await backend.get("sess_123", type="raw")
|
|
20
|
+
messages = await backend.get("sess_123", limit=50)
|
|
40
21
|
"""
|
|
41
22
|
|
|
42
23
|
async def add(
|
|
43
24
|
self,
|
|
44
25
|
session_id: str,
|
|
45
26
|
message: dict[str, Any],
|
|
46
|
-
type: MessageType = "truncated",
|
|
47
27
|
agent_id: str | None = None,
|
|
48
28
|
namespace: str | None = None,
|
|
49
29
|
invocation_id: str | None = None,
|
|
@@ -53,7 +33,6 @@ class MessageBackend(Protocol):
|
|
|
53
33
|
Args:
|
|
54
34
|
session_id: Session ID
|
|
55
35
|
message: Message dict (role, content, tool_call_id, etc.)
|
|
56
|
-
type: Message type - "truncated" or "raw"
|
|
57
36
|
agent_id: Optional agent ID
|
|
58
37
|
namespace: Optional namespace for sub-agent isolation
|
|
59
38
|
invocation_id: Optional invocation ID for grouping
|
|
@@ -63,7 +42,6 @@ class MessageBackend(Protocol):
|
|
|
63
42
|
async def get(
|
|
64
43
|
self,
|
|
65
44
|
session_id: str,
|
|
66
|
-
type: MessageType = "truncated",
|
|
67
45
|
agent_id: str | None = None,
|
|
68
46
|
namespace: str | None = None,
|
|
69
47
|
limit: int | None = None,
|
|
@@ -72,7 +50,6 @@ class MessageBackend(Protocol):
|
|
|
72
50
|
|
|
73
51
|
Args:
|
|
74
52
|
session_id: Session ID
|
|
75
|
-
type: Message type - "truncated" or "raw"
|
|
76
53
|
agent_id: Optional filter by agent
|
|
77
54
|
namespace: Optional namespace filter
|
|
78
55
|
limit: Max messages to return (None = all)
|
|
@@ -86,7 +63,6 @@ class MessageBackend(Protocol):
|
|
|
86
63
|
self,
|
|
87
64
|
session_id: str,
|
|
88
65
|
invocation_id: str,
|
|
89
|
-
type: MessageType | None = None,
|
|
90
66
|
namespace: str | None = None,
|
|
91
67
|
) -> int:
|
|
92
68
|
"""Delete messages by invocation (for revert).
|
|
@@ -94,7 +70,6 @@ class MessageBackend(Protocol):
|
|
|
94
70
|
Args:
|
|
95
71
|
session_id: Session ID
|
|
96
72
|
invocation_id: Invocation ID to delete
|
|
97
|
-
type: Message type to delete, None = both types
|
|
98
73
|
namespace: Optional namespace filter
|
|
99
74
|
|
|
100
75
|
Returns:
|
|
@@ -105,14 +80,12 @@ class MessageBackend(Protocol):
|
|
|
105
80
|
async def clear(
|
|
106
81
|
self,
|
|
107
82
|
session_id: str,
|
|
108
|
-
type: MessageType | None = None,
|
|
109
83
|
namespace: str | None = None,
|
|
110
84
|
) -> int:
|
|
111
85
|
"""Clear all messages for a session.
|
|
112
86
|
|
|
113
87
|
Args:
|
|
114
88
|
session_id: Session ID
|
|
115
|
-
type: Message type to clear, None = both types
|
|
116
89
|
namespace: Optional namespace filter
|
|
117
90
|
|
|
118
91
|
Returns:
|
|
@@ -121,4 +94,4 @@ class MessageBackend(Protocol):
|
|
|
121
94
|
...
|
|
122
95
|
|
|
123
96
|
|
|
124
|
-
__all__ = ["MessageBackend"
|
|
97
|
+
__all__ = ["MessageBackend"]
|
|
@@ -79,7 +79,6 @@ class MessageContextProvider(BaseContextProvider):
|
|
|
79
79
|
if ctx.backends is not None and ctx.backends.message is not None:
|
|
80
80
|
messages = await ctx.backends.message.get(
|
|
81
81
|
session_id=ctx.session.id,
|
|
82
|
-
type="truncated",
|
|
83
82
|
limit=self.max_messages,
|
|
84
83
|
)
|
|
85
84
|
# Convert to LLM format (include tool_call_id for tool messages)
|
aury/agents/core/context.py
CHANGED
|
@@ -728,11 +728,11 @@ class InvocationContext:
|
|
|
728
728
|
**{k: v for k, v in request.items() if k not in ("messages", "stream")}
|
|
729
729
|
):
|
|
730
730
|
if self.middleware:
|
|
731
|
-
chunk_dict = {"
|
|
732
|
-
processed = await self.middleware.
|
|
731
|
+
chunk_dict = {"delta": chunk}
|
|
732
|
+
processed = await self.middleware.process_text_stream(chunk_dict)
|
|
733
733
|
if processed is None:
|
|
734
734
|
continue
|
|
735
|
-
chunk = processed.get("
|
|
735
|
+
chunk = processed.get("delta", chunk)
|
|
736
736
|
yield chunk
|
|
737
737
|
|
|
738
738
|
except Exception as e:
|
aury/agents/core/types/tool.py
CHANGED
|
@@ -90,14 +90,41 @@ class ToolContext:
|
|
|
90
90
|
|
|
91
91
|
@dataclass
|
|
92
92
|
class ToolResult:
|
|
93
|
-
"""Tool execution result for LLM.
|
|
94
|
-
|
|
93
|
+
"""Tool execution result for LLM.
|
|
94
|
+
|
|
95
|
+
Supports dual output for context management:
|
|
96
|
+
- output: Complete output (raw), for storage and recall
|
|
97
|
+
- truncated_output: Shortened output for context window
|
|
98
|
+
|
|
99
|
+
If truncated_output is not provided, it defaults to output.
|
|
100
|
+
"""
|
|
101
|
+
output: str # Complete output (raw)
|
|
95
102
|
is_error: bool = False
|
|
103
|
+
truncated_output: str | None = None # Shortened output (defaults to output)
|
|
104
|
+
|
|
105
|
+
def __post_init__(self):
|
|
106
|
+
# Default truncated to output if not provided
|
|
107
|
+
if self.truncated_output is None:
|
|
108
|
+
self.truncated_output = self.output
|
|
96
109
|
|
|
97
110
|
@classmethod
|
|
98
|
-
def success(
|
|
99
|
-
|
|
100
|
-
|
|
111
|
+
def success(
|
|
112
|
+
cls,
|
|
113
|
+
output: str,
|
|
114
|
+
*,
|
|
115
|
+
truncated_output: str | None = None,
|
|
116
|
+
) -> ToolResult:
|
|
117
|
+
"""Create a successful result.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
output: Complete output (raw)
|
|
121
|
+
truncated_output: Shortened output for context (defaults to output)
|
|
122
|
+
"""
|
|
123
|
+
return cls(
|
|
124
|
+
output=output,
|
|
125
|
+
is_error=False,
|
|
126
|
+
truncated_output=truncated_output,
|
|
127
|
+
)
|
|
101
128
|
|
|
102
129
|
@classmethod
|
|
103
130
|
def error(cls, message: str) -> ToolResult:
|
|
@@ -121,6 +148,7 @@ class ToolInvocation:
|
|
|
121
148
|
args: dict[str, Any] = field(default_factory=dict)
|
|
122
149
|
args_raw: str = "" # Raw JSON string for streaming
|
|
123
150
|
result: str | None = None
|
|
151
|
+
truncated_result: str | None = None # Shortened result for context window
|
|
124
152
|
is_error: bool = False
|
|
125
153
|
|
|
126
154
|
# Timing
|
|
@@ -134,10 +162,22 @@ class ToolInvocation:
|
|
|
134
162
|
self.state = ToolInvocationState.CALL
|
|
135
163
|
self.time["start"] = datetime.now()
|
|
136
164
|
|
|
137
|
-
def mark_result(
|
|
138
|
-
|
|
165
|
+
def mark_result(
|
|
166
|
+
self,
|
|
167
|
+
result: str,
|
|
168
|
+
is_error: bool = False,
|
|
169
|
+
truncated_result: str | None = None,
|
|
170
|
+
) -> None:
|
|
171
|
+
"""Mark execution complete.
|
|
172
|
+
|
|
173
|
+
Args:
|
|
174
|
+
result: Complete result (raw)
|
|
175
|
+
is_error: Whether this is an error result
|
|
176
|
+
truncated_result: Shortened result for context window (defaults to result)
|
|
177
|
+
"""
|
|
139
178
|
self.state = ToolInvocationState.RESULT
|
|
140
179
|
self.result = result
|
|
180
|
+
self.truncated_result = truncated_result if truncated_result is not None else result
|
|
141
181
|
self.is_error = is_error
|
|
142
182
|
self.time["end"] = datetime.now()
|
|
143
183
|
|
aury/agents/llm/adapter.py
CHANGED
|
@@ -261,6 +261,9 @@ class ModelClientProvider:
|
|
|
261
261
|
case Evt.thinking:
|
|
262
262
|
return LLMEvent(type="thinking", delta=event.delta)
|
|
263
263
|
|
|
264
|
+
case Evt.thinking_completed:
|
|
265
|
+
return LLMEvent(type="thinking_completed")
|
|
266
|
+
|
|
264
267
|
case Evt.tool_call_start:
|
|
265
268
|
if event.tool_call:
|
|
266
269
|
return LLMEvent(
|
aury/agents/llm/provider.py
CHANGED
|
@@ -133,18 +133,38 @@ class LLMMessage:
|
|
|
133
133
|
- system: System prompt
|
|
134
134
|
- user: User message (can include images)
|
|
135
135
|
- assistant: Assistant response (can include tool_calls)
|
|
136
|
-
- tool: Tool result (requires tool_call_id)
|
|
136
|
+
- tool: Tool result (requires tool_call_id and name)
|
|
137
|
+
|
|
138
|
+
Supports dual content for context management:
|
|
139
|
+
- content: Complete content (raw), for storage and recall
|
|
140
|
+
- truncated_content: Shortened content for context window (defaults to content)
|
|
137
141
|
"""
|
|
138
142
|
role: Literal["system", "user", "assistant", "tool"]
|
|
139
143
|
content: str | list[dict[str, Any]]
|
|
140
144
|
tool_call_id: str | None = None # Required for tool role
|
|
145
|
+
name: str | None = None # Tool name, required for Gemini compatibility
|
|
146
|
+
truncated_content: str | list[dict[str, Any]] | None = None # Shortened content (defaults to content)
|
|
141
147
|
|
|
142
148
|
def to_dict(self) -> dict[str, Any]:
|
|
143
149
|
d = {"role": self.role, "content": self.content}
|
|
144
150
|
if self.tool_call_id:
|
|
145
151
|
d["tool_call_id"] = self.tool_call_id
|
|
152
|
+
if self.name:
|
|
153
|
+
d["name"] = self.name
|
|
154
|
+
if self.truncated_content is not None:
|
|
155
|
+
d["truncated_content"] = self.truncated_content
|
|
146
156
|
return d
|
|
147
157
|
|
|
158
|
+
def get(self, key: str, default: Any = None) -> Any:
|
|
159
|
+
"""Dict-like access for middleware compatibility."""
|
|
160
|
+
return getattr(self, key, default)
|
|
161
|
+
|
|
162
|
+
def __getitem__(self, key: str) -> Any:
|
|
163
|
+
"""Dict-like access via []."""
|
|
164
|
+
if hasattr(self, key):
|
|
165
|
+
return getattr(self, key)
|
|
166
|
+
raise KeyError(key)
|
|
167
|
+
|
|
148
168
|
@classmethod
|
|
149
169
|
def system(cls, content: str) -> "LLMMessage":
|
|
150
170
|
"""Create system message."""
|
|
@@ -161,9 +181,28 @@ class LLMMessage:
|
|
|
161
181
|
return cls(role="assistant", content=content)
|
|
162
182
|
|
|
163
183
|
@classmethod
|
|
164
|
-
def tool(
|
|
165
|
-
|
|
166
|
-
|
|
184
|
+
def tool(
|
|
185
|
+
cls,
|
|
186
|
+
content: str,
|
|
187
|
+
tool_call_id: str,
|
|
188
|
+
name: str | None = None,
|
|
189
|
+
truncated_content: str | None = None,
|
|
190
|
+
) -> "LLMMessage":
|
|
191
|
+
"""Create tool result message.
|
|
192
|
+
|
|
193
|
+
Args:
|
|
194
|
+
content: Tool result content (complete/raw)
|
|
195
|
+
tool_call_id: ID of the tool call this result is for
|
|
196
|
+
name: Tool name (required for Gemini compatibility)
|
|
197
|
+
truncated_content: Shortened content for context window (defaults to content)
|
|
198
|
+
"""
|
|
199
|
+
return cls(
|
|
200
|
+
role="tool",
|
|
201
|
+
content=content,
|
|
202
|
+
tool_call_id=tool_call_id,
|
|
203
|
+
name=name,
|
|
204
|
+
truncated_content=truncated_content,
|
|
205
|
+
)
|
|
167
206
|
|
|
168
207
|
|
|
169
208
|
@runtime_checkable
|
aury/agents/messages/__init__.py
CHANGED
|
@@ -15,11 +15,6 @@ from .store import (
|
|
|
15
15
|
MessageStore,
|
|
16
16
|
InMemoryMessageStore,
|
|
17
17
|
)
|
|
18
|
-
from .raw_store import (
|
|
19
|
-
RawMessageStore,
|
|
20
|
-
StateBackendRawMessageStore,
|
|
21
|
-
InMemoryRawMessageStore,
|
|
22
|
-
)
|
|
23
18
|
from .config import (
|
|
24
19
|
MessageConfig,
|
|
25
20
|
)
|
|
@@ -31,10 +26,6 @@ __all__ = [
|
|
|
31
26
|
# Store (protocol + in-memory for testing)
|
|
32
27
|
"MessageStore",
|
|
33
28
|
"InMemoryMessageStore",
|
|
34
|
-
# Raw Store
|
|
35
|
-
"RawMessageStore",
|
|
36
|
-
"StateBackendRawMessageStore",
|
|
37
|
-
"InMemoryRawMessageStore",
|
|
38
29
|
# Config
|
|
39
30
|
"MessageConfig",
|
|
40
31
|
]
|
|
@@ -13,7 +13,6 @@ from .chain import MiddlewareChain
|
|
|
13
13
|
from .message_container import MessageContainerMiddleware
|
|
14
14
|
from .message import MessageBackendMiddleware
|
|
15
15
|
from .truncation import MessageTruncationMiddleware
|
|
16
|
-
from .raw_message import RawMessageMiddleware
|
|
17
16
|
|
|
18
17
|
__all__ = [
|
|
19
18
|
"TriggerMode",
|
|
@@ -27,5 +26,4 @@ __all__ = [
|
|
|
27
26
|
"MessageContainerMiddleware",
|
|
28
27
|
"MessageBackendMiddleware",
|
|
29
28
|
"MessageTruncationMiddleware",
|
|
30
|
-
"RawMessageMiddleware",
|
|
31
29
|
]
|
aury/agents/middleware/base.py
CHANGED
|
@@ -73,17 +73,28 @@ class Middleware(Protocol):
|
|
|
73
73
|
"""
|
|
74
74
|
...
|
|
75
75
|
|
|
76
|
-
async def
|
|
76
|
+
async def on_text_stream(
|
|
77
77
|
self,
|
|
78
78
|
chunk: dict[str, Any],
|
|
79
79
|
) -> dict[str, Any] | None:
|
|
80
|
-
"""Process streaming chunk
|
|
80
|
+
"""Process text streaming chunk.
|
|
81
81
|
|
|
82
82
|
Args:
|
|
83
|
-
chunk: The
|
|
83
|
+
chunk: The text chunk with {"delta": str}
|
|
84
84
|
|
|
85
85
|
Returns:
|
|
86
|
-
Modified chunk, or None to skip
|
|
86
|
+
Modified chunk, or None to skip
|
|
87
|
+
"""
|
|
88
|
+
...
|
|
89
|
+
|
|
90
|
+
async def on_text_stream_end(self) -> dict[str, Any] | None:
|
|
91
|
+
"""Called when text stream ends.
|
|
92
|
+
|
|
93
|
+
Use this to flush any buffered text content.
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
Optional dict with {"delta": str} to emit final content,
|
|
97
|
+
or None if no additional content.
|
|
87
98
|
"""
|
|
88
99
|
...
|
|
89
100
|
|
|
@@ -101,6 +112,17 @@ class Middleware(Protocol):
|
|
|
101
112
|
"""
|
|
102
113
|
...
|
|
103
114
|
|
|
115
|
+
async def on_thinking_stream_end(self) -> dict[str, Any] | None:
|
|
116
|
+
"""Called when thinking stream ends.
|
|
117
|
+
|
|
118
|
+
Use this to flush any buffered thinking content.
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
Optional dict with {"delta": str} to emit final thinking content,
|
|
122
|
+
or None if no additional content.
|
|
123
|
+
"""
|
|
124
|
+
...
|
|
125
|
+
|
|
104
126
|
# ========== Agent Lifecycle Hooks ==========
|
|
105
127
|
|
|
106
128
|
async def on_agent_start(
|
|
@@ -283,13 +305,17 @@ class BaseMiddleware:
|
|
|
283
305
|
"""Default: re-raise error."""
|
|
284
306
|
return error
|
|
285
307
|
|
|
286
|
-
async def
|
|
308
|
+
async def on_text_stream(
|
|
287
309
|
self,
|
|
288
310
|
chunk: dict[str, Any],
|
|
289
311
|
) -> dict[str, Any] | None:
|
|
290
312
|
"""Default: pass through."""
|
|
291
313
|
return chunk
|
|
292
314
|
|
|
315
|
+
async def on_text_stream_end(self) -> dict[str, Any] | None:
|
|
316
|
+
"""Default: no additional content."""
|
|
317
|
+
return None
|
|
318
|
+
|
|
293
319
|
async def on_thinking_stream(
|
|
294
320
|
self,
|
|
295
321
|
chunk: dict[str, Any],
|
|
@@ -297,6 +323,10 @@ class BaseMiddleware:
|
|
|
297
323
|
"""Default: pass through."""
|
|
298
324
|
return chunk
|
|
299
325
|
|
|
326
|
+
async def on_thinking_stream_end(self) -> dict[str, Any] | None:
|
|
327
|
+
"""Default: no additional content."""
|
|
328
|
+
return None
|
|
329
|
+
|
|
300
330
|
# ========== Agent Lifecycle Hooks ==========
|
|
301
331
|
|
|
302
332
|
async def on_agent_start(
|