aury-agent 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/agents/backends/__init__.py +2 -3
- aury/agents/backends/message/__init__.py +1 -2
- aury/agents/backends/message/memory.py +15 -37
- aury/agents/backends/message/types.py +5 -32
- aury/agents/context_providers/message.py +0 -1
- aury/agents/core/context.py +3 -3
- aury/agents/core/types/tool.py +47 -7
- aury/agents/llm/adapter.py +3 -0
- aury/agents/llm/provider.py +43 -4
- aury/agents/messages/__init__.py +0 -9
- aury/agents/middleware/__init__.py +0 -2
- aury/agents/middleware/base.py +35 -5
- aury/agents/middleware/chain.py +64 -8
- aury/agents/middleware/message.py +5 -53
- aury/agents/react/factory.py +11 -6
- aury/agents/react/persistence.py +12 -2
- aury/agents/react/step.py +72 -6
- aury/agents/react/tools.py +20 -4
- aury/agents/tool/decorator.py +5 -20
- {aury_agent-0.0.9.dist-info → aury_agent-0.0.11.dist-info}/METADATA +1 -1
- {aury_agent-0.0.9.dist-info → aury_agent-0.0.11.dist-info}/RECORD +23 -25
- aury/agents/messages/raw_store.py +0 -224
- aury/agents/middleware/raw_message.py +0 -154
- {aury_agent-0.0.9.dist-info → aury_agent-0.0.11.dist-info}/WHEEL +0 -0
- {aury_agent-0.0.9.dist-info → aury_agent-0.0.11.dist-info}/entry_points.txt +0 -0
|
@@ -1,224 +0,0 @@
|
|
|
1
|
-
"""Raw message store for complete message storage.
|
|
2
|
-
|
|
3
|
-
RawMessageStore stores complete, untruncated messages for:
|
|
4
|
-
- HITL recovery (restore exact state)
|
|
5
|
-
- Full-context recall (when truncated history is insufficient)
|
|
6
|
-
- Audit/debugging
|
|
7
|
-
|
|
8
|
-
Messages are stored per invocation and can be cleaned up after invocation completes.
|
|
9
|
-
"""
|
|
10
|
-
from __future__ import annotations
|
|
11
|
-
|
|
12
|
-
from typing import Any, Protocol, runtime_checkable
|
|
13
|
-
from datetime import datetime
|
|
14
|
-
|
|
15
|
-
from ..core.types.session import generate_id
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
@runtime_checkable
|
|
19
|
-
class RawMessageStore(Protocol):
|
|
20
|
-
"""Protocol for raw message storage.
|
|
21
|
-
|
|
22
|
-
Stores complete messages per invocation.
|
|
23
|
-
"""
|
|
24
|
-
|
|
25
|
-
async def add(
|
|
26
|
-
self,
|
|
27
|
-
invocation_id: str,
|
|
28
|
-
message: dict[str, Any],
|
|
29
|
-
) -> str:
|
|
30
|
-
"""Add a message and return its ID.
|
|
31
|
-
|
|
32
|
-
Args:
|
|
33
|
-
invocation_id: Invocation this message belongs to
|
|
34
|
-
message: Complete message dict
|
|
35
|
-
|
|
36
|
-
Returns:
|
|
37
|
-
Generated message ID
|
|
38
|
-
"""
|
|
39
|
-
...
|
|
40
|
-
|
|
41
|
-
async def get(self, msg_id: str) -> dict[str, Any] | None:
|
|
42
|
-
"""Get a single message by ID.
|
|
43
|
-
|
|
44
|
-
Args:
|
|
45
|
-
msg_id: Message ID
|
|
46
|
-
|
|
47
|
-
Returns:
|
|
48
|
-
Message dict or None if not found
|
|
49
|
-
"""
|
|
50
|
-
...
|
|
51
|
-
|
|
52
|
-
async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
|
|
53
|
-
"""Get multiple messages by IDs.
|
|
54
|
-
|
|
55
|
-
Args:
|
|
56
|
-
msg_ids: List of message IDs
|
|
57
|
-
|
|
58
|
-
Returns:
|
|
59
|
-
List of message dicts (in same order, None for missing)
|
|
60
|
-
"""
|
|
61
|
-
...
|
|
62
|
-
|
|
63
|
-
async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
|
|
64
|
-
"""Get all messages for an invocation.
|
|
65
|
-
|
|
66
|
-
Args:
|
|
67
|
-
invocation_id: Invocation ID
|
|
68
|
-
|
|
69
|
-
Returns:
|
|
70
|
-
List of messages in chronological order
|
|
71
|
-
"""
|
|
72
|
-
...
|
|
73
|
-
|
|
74
|
-
async def delete_by_invocation(self, invocation_id: str) -> int:
|
|
75
|
-
"""Delete all messages for an invocation.
|
|
76
|
-
|
|
77
|
-
Args:
|
|
78
|
-
invocation_id: Invocation ID
|
|
79
|
-
|
|
80
|
-
Returns:
|
|
81
|
-
Number of messages deleted
|
|
82
|
-
"""
|
|
83
|
-
...
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
class StateBackendRawMessageStore:
|
|
87
|
-
"""RawMessageStore implementation backed by StateBackend."""
|
|
88
|
-
|
|
89
|
-
def __init__(self, backend: Any): # StateBackend
|
|
90
|
-
"""Initialize with StateBackend.
|
|
91
|
-
|
|
92
|
-
Args:
|
|
93
|
-
backend: StateBackend instance
|
|
94
|
-
"""
|
|
95
|
-
self._backend = backend
|
|
96
|
-
|
|
97
|
-
async def add(
|
|
98
|
-
self,
|
|
99
|
-
invocation_id: str,
|
|
100
|
-
message: dict[str, Any],
|
|
101
|
-
) -> str:
|
|
102
|
-
"""Add a message using StateBackend."""
|
|
103
|
-
msg_id = generate_id("rmsg")
|
|
104
|
-
|
|
105
|
-
# Store message with metadata
|
|
106
|
-
await self._backend.set("raw_messages", msg_id, {
|
|
107
|
-
"id": msg_id,
|
|
108
|
-
"invocation_id": invocation_id,
|
|
109
|
-
"message": message,
|
|
110
|
-
"created_at": datetime.now().isoformat(),
|
|
111
|
-
})
|
|
112
|
-
|
|
113
|
-
# Add to invocation's message list
|
|
114
|
-
inv_key = f"raw_msg_ids:{invocation_id}"
|
|
115
|
-
msg_ids = await self._backend.get("raw_messages", inv_key) or []
|
|
116
|
-
msg_ids.append(msg_id)
|
|
117
|
-
await self._backend.set("raw_messages", inv_key, msg_ids)
|
|
118
|
-
|
|
119
|
-
return msg_id
|
|
120
|
-
|
|
121
|
-
async def get(self, msg_id: str) -> dict[str, Any] | None:
|
|
122
|
-
"""Get a single message."""
|
|
123
|
-
data = await self._backend.get("raw_messages", msg_id)
|
|
124
|
-
if data:
|
|
125
|
-
return data.get("message")
|
|
126
|
-
return None
|
|
127
|
-
|
|
128
|
-
async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
|
|
129
|
-
"""Get multiple messages."""
|
|
130
|
-
results = []
|
|
131
|
-
for msg_id in msg_ids:
|
|
132
|
-
msg = await self.get(msg_id)
|
|
133
|
-
if msg:
|
|
134
|
-
results.append(msg)
|
|
135
|
-
return results
|
|
136
|
-
|
|
137
|
-
async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
|
|
138
|
-
"""Get all messages for an invocation."""
|
|
139
|
-
inv_key = f"raw_msg_ids:{invocation_id}"
|
|
140
|
-
msg_ids = await self._backend.get("raw_messages", inv_key) or []
|
|
141
|
-
return await self.get_many(msg_ids)
|
|
142
|
-
|
|
143
|
-
async def delete_by_invocation(self, invocation_id: str) -> int:
|
|
144
|
-
"""Delete all messages for an invocation."""
|
|
145
|
-
inv_key = f"raw_msg_ids:{invocation_id}"
|
|
146
|
-
msg_ids = await self._backend.get("raw_messages", inv_key) or []
|
|
147
|
-
|
|
148
|
-
# Delete each message
|
|
149
|
-
for msg_id in msg_ids:
|
|
150
|
-
await self._backend.remove("raw_messages", msg_id)
|
|
151
|
-
|
|
152
|
-
# Delete the index
|
|
153
|
-
await self._backend.remove("raw_messages", inv_key)
|
|
154
|
-
|
|
155
|
-
return len(msg_ids)
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
class InMemoryRawMessageStore:
|
|
159
|
-
"""In-memory raw message store for testing."""
|
|
160
|
-
|
|
161
|
-
def __init__(self) -> None:
|
|
162
|
-
self._messages: dict[str, dict[str, Any]] = {} # msg_id -> message data
|
|
163
|
-
self._invocation_index: dict[str, list[str]] = {} # inv_id -> [msg_ids]
|
|
164
|
-
|
|
165
|
-
async def add(
|
|
166
|
-
self,
|
|
167
|
-
invocation_id: str,
|
|
168
|
-
message: dict[str, Any],
|
|
169
|
-
) -> str:
|
|
170
|
-
"""Add a message."""
|
|
171
|
-
msg_id = generate_id("rmsg")
|
|
172
|
-
|
|
173
|
-
self._messages[msg_id] = {
|
|
174
|
-
"id": msg_id,
|
|
175
|
-
"invocation_id": invocation_id,
|
|
176
|
-
"message": message,
|
|
177
|
-
"created_at": datetime.now().isoformat(),
|
|
178
|
-
}
|
|
179
|
-
|
|
180
|
-
if invocation_id not in self._invocation_index:
|
|
181
|
-
self._invocation_index[invocation_id] = []
|
|
182
|
-
self._invocation_index[invocation_id].append(msg_id)
|
|
183
|
-
|
|
184
|
-
return msg_id
|
|
185
|
-
|
|
186
|
-
async def get(self, msg_id: str) -> dict[str, Any] | None:
|
|
187
|
-
"""Get a single message."""
|
|
188
|
-
data = self._messages.get(msg_id)
|
|
189
|
-
if data:
|
|
190
|
-
return data.get("message")
|
|
191
|
-
return None
|
|
192
|
-
|
|
193
|
-
async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
|
|
194
|
-
"""Get multiple messages."""
|
|
195
|
-
results = []
|
|
196
|
-
for msg_id in msg_ids:
|
|
197
|
-
msg = await self.get(msg_id)
|
|
198
|
-
if msg:
|
|
199
|
-
results.append(msg)
|
|
200
|
-
return results
|
|
201
|
-
|
|
202
|
-
async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
|
|
203
|
-
"""Get all messages for an invocation."""
|
|
204
|
-
msg_ids = self._invocation_index.get(invocation_id, [])
|
|
205
|
-
return await self.get_many(msg_ids)
|
|
206
|
-
|
|
207
|
-
async def delete_by_invocation(self, invocation_id: str) -> int:
|
|
208
|
-
"""Delete all messages for an invocation."""
|
|
209
|
-
msg_ids = self._invocation_index.pop(invocation_id, [])
|
|
210
|
-
for msg_id in msg_ids:
|
|
211
|
-
self._messages.pop(msg_id, None)
|
|
212
|
-
return len(msg_ids)
|
|
213
|
-
|
|
214
|
-
def clear(self) -> None:
|
|
215
|
-
"""Clear all messages (for testing)."""
|
|
216
|
-
self._messages.clear()
|
|
217
|
-
self._invocation_index.clear()
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
__all__ = [
|
|
221
|
-
"RawMessageStore",
|
|
222
|
-
"StateBackendRawMessageStore",
|
|
223
|
-
"InMemoryRawMessageStore",
|
|
224
|
-
]
|
|
@@ -1,154 +0,0 @@
|
|
|
1
|
-
"""RawMessageMiddleware - stores complete messages for HITL recovery.
|
|
2
|
-
|
|
3
|
-
This middleware stores complete, untruncated messages to RawMessageStore.
|
|
4
|
-
Works alongside MessageBackendMiddleware which stores truncated messages.
|
|
5
|
-
|
|
6
|
-
Usage:
|
|
7
|
-
raw_store = InMemoryRawMessageStore()
|
|
8
|
-
raw_middleware = RawMessageMiddleware(raw_store, persist_raw=False)
|
|
9
|
-
|
|
10
|
-
agent = ReactAgent.create(
|
|
11
|
-
llm=llm,
|
|
12
|
-
middlewares=[raw_middleware, MessageBackendMiddleware()],
|
|
13
|
-
)
|
|
14
|
-
"""
|
|
15
|
-
from __future__ import annotations
|
|
16
|
-
|
|
17
|
-
from typing import TYPE_CHECKING, Any
|
|
18
|
-
|
|
19
|
-
from .base import BaseMiddleware
|
|
20
|
-
from .types import HookResult
|
|
21
|
-
|
|
22
|
-
if TYPE_CHECKING:
|
|
23
|
-
from ..messages.raw_store import RawMessageStore
|
|
24
|
-
from ..core.state import State
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class RawMessageMiddleware(BaseMiddleware):
|
|
28
|
-
"""Middleware that stores complete messages to RawMessageStore.
|
|
29
|
-
|
|
30
|
-
Stores untruncated messages for:
|
|
31
|
-
- HITL recovery (restore exact conversation state)
|
|
32
|
-
- Full-context recall (when truncated history is insufficient)
|
|
33
|
-
|
|
34
|
-
Messages are stored per invocation and can be cleaned up when
|
|
35
|
-
the invocation completes (controlled by persist_raw).
|
|
36
|
-
"""
|
|
37
|
-
|
|
38
|
-
def __init__(
|
|
39
|
-
self,
|
|
40
|
-
raw_store: "RawMessageStore",
|
|
41
|
-
persist_raw: bool = False,
|
|
42
|
-
state: "State | None" = None,
|
|
43
|
-
):
|
|
44
|
-
"""Initialize with RawMessageStore.
|
|
45
|
-
|
|
46
|
-
Args:
|
|
47
|
-
raw_store: RawMessageStore for storing complete messages
|
|
48
|
-
persist_raw: Whether to keep messages after invocation completes.
|
|
49
|
-
False = clean up after invocation (default)
|
|
50
|
-
True = keep forever (for audit/recall)
|
|
51
|
-
state: State instance for storing message_ids in execution namespace.
|
|
52
|
-
If provided, message IDs are automatically added to
|
|
53
|
-
state.execution["message_ids"].
|
|
54
|
-
"""
|
|
55
|
-
self.raw_store = raw_store
|
|
56
|
-
self.persist_raw = persist_raw
|
|
57
|
-
self.state = state
|
|
58
|
-
|
|
59
|
-
# Track message IDs per invocation (for cleanup)
|
|
60
|
-
self._invocation_msg_ids: dict[str, list[str]] = {}
|
|
61
|
-
|
|
62
|
-
def set_state(self, state: "State") -> None:
|
|
63
|
-
"""Set state instance (can be set after construction).
|
|
64
|
-
|
|
65
|
-
Args:
|
|
66
|
-
state: State instance for storing message_ids
|
|
67
|
-
"""
|
|
68
|
-
self.state = state
|
|
69
|
-
|
|
70
|
-
async def on_message_save(
|
|
71
|
-
self,
|
|
72
|
-
message: dict[str, Any],
|
|
73
|
-
) -> dict[str, Any] | None:
|
|
74
|
-
"""Store complete message to RawMessageStore.
|
|
75
|
-
|
|
76
|
-
Args:
|
|
77
|
-
message: Complete message dict with 'role', 'content', etc.
|
|
78
|
-
|
|
79
|
-
Returns:
|
|
80
|
-
The message with added 'raw_msg_id' field
|
|
81
|
-
"""
|
|
82
|
-
from ..core.context import get_current_ctx_or_none
|
|
83
|
-
ctx = get_current_ctx_or_none()
|
|
84
|
-
invocation_id = ctx.invocation_id if ctx else ""
|
|
85
|
-
if not invocation_id:
|
|
86
|
-
return message
|
|
87
|
-
|
|
88
|
-
# Store to raw store
|
|
89
|
-
msg_id = await self.raw_store.add(invocation_id, message)
|
|
90
|
-
|
|
91
|
-
# Track for cleanup
|
|
92
|
-
if invocation_id not in self._invocation_msg_ids:
|
|
93
|
-
self._invocation_msg_ids[invocation_id] = []
|
|
94
|
-
self._invocation_msg_ids[invocation_id].append(msg_id)
|
|
95
|
-
|
|
96
|
-
# Add to state.execution["message_ids"] if state is available
|
|
97
|
-
if self.state:
|
|
98
|
-
message_ids = self.state.execution.get("message_ids", [])
|
|
99
|
-
message_ids.append(msg_id)
|
|
100
|
-
self.state.execution["message_ids"] = message_ids
|
|
101
|
-
|
|
102
|
-
# Add msg_id to message for downstream middlewares
|
|
103
|
-
message["raw_msg_id"] = msg_id
|
|
104
|
-
|
|
105
|
-
return message
|
|
106
|
-
|
|
107
|
-
async def on_agent_end(
|
|
108
|
-
self,
|
|
109
|
-
agent_id: str,
|
|
110
|
-
result: Any,
|
|
111
|
-
) -> HookResult:
|
|
112
|
-
"""Clean up raw messages when invocation completes.
|
|
113
|
-
|
|
114
|
-
Only cleans up if persist_raw is False.
|
|
115
|
-
"""
|
|
116
|
-
if self.persist_raw:
|
|
117
|
-
return HookResult.proceed()
|
|
118
|
-
|
|
119
|
-
from ..core.context import get_current_ctx_or_none
|
|
120
|
-
ctx = get_current_ctx_or_none()
|
|
121
|
-
invocation_id = ctx.invocation_id if ctx else ""
|
|
122
|
-
if invocation_id:
|
|
123
|
-
await self._cleanup_invocation(invocation_id)
|
|
124
|
-
|
|
125
|
-
return HookResult.proceed()
|
|
126
|
-
|
|
127
|
-
async def _cleanup_invocation(self, invocation_id: str) -> int:
|
|
128
|
-
"""Clean up raw messages for an invocation.
|
|
129
|
-
|
|
130
|
-
Args:
|
|
131
|
-
invocation_id: Invocation ID to clean up
|
|
132
|
-
|
|
133
|
-
Returns:
|
|
134
|
-
Number of messages deleted
|
|
135
|
-
"""
|
|
136
|
-
# Remove from tracking
|
|
137
|
-
self._invocation_msg_ids.pop(invocation_id, None)
|
|
138
|
-
|
|
139
|
-
# Delete from store
|
|
140
|
-
return await self.raw_store.delete_by_invocation(invocation_id)
|
|
141
|
-
|
|
142
|
-
def get_message_ids(self, invocation_id: str) -> list[str]:
|
|
143
|
-
"""Get tracked message IDs for an invocation.
|
|
144
|
-
|
|
145
|
-
Args:
|
|
146
|
-
invocation_id: Invocation ID
|
|
147
|
-
|
|
148
|
-
Returns:
|
|
149
|
-
List of message IDs
|
|
150
|
-
"""
|
|
151
|
-
return self._invocation_msg_ids.get(invocation_id, []).copy()
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
__all__ = ["RawMessageMiddleware"]
|
|
File without changes
|
|
File without changes
|