aury-agent 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,224 +0,0 @@
1
- """Raw message store for complete message storage.
2
-
3
- RawMessageStore stores complete, untruncated messages for:
4
- - HITL recovery (restore exact state)
5
- - Full-context recall (when truncated history is insufficient)
6
- - Audit/debugging
7
-
8
- Messages are stored per invocation and can be cleaned up after invocation completes.
9
- """
10
- from __future__ import annotations
11
-
12
- from typing import Any, Protocol, runtime_checkable
13
- from datetime import datetime
14
-
15
- from ..core.types.session import generate_id
16
-
17
-
18
- @runtime_checkable
19
- class RawMessageStore(Protocol):
20
- """Protocol for raw message storage.
21
-
22
- Stores complete messages per invocation.
23
- """
24
-
25
- async def add(
26
- self,
27
- invocation_id: str,
28
- message: dict[str, Any],
29
- ) -> str:
30
- """Add a message and return its ID.
31
-
32
- Args:
33
- invocation_id: Invocation this message belongs to
34
- message: Complete message dict
35
-
36
- Returns:
37
- Generated message ID
38
- """
39
- ...
40
-
41
- async def get(self, msg_id: str) -> dict[str, Any] | None:
42
- """Get a single message by ID.
43
-
44
- Args:
45
- msg_id: Message ID
46
-
47
- Returns:
48
- Message dict or None if not found
49
- """
50
- ...
51
-
52
- async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
53
- """Get multiple messages by IDs.
54
-
55
- Args:
56
- msg_ids: List of message IDs
57
-
58
- Returns:
59
- List of message dicts (in same order, None for missing)
60
- """
61
- ...
62
-
63
- async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
64
- """Get all messages for an invocation.
65
-
66
- Args:
67
- invocation_id: Invocation ID
68
-
69
- Returns:
70
- List of messages in chronological order
71
- """
72
- ...
73
-
74
- async def delete_by_invocation(self, invocation_id: str) -> int:
75
- """Delete all messages for an invocation.
76
-
77
- Args:
78
- invocation_id: Invocation ID
79
-
80
- Returns:
81
- Number of messages deleted
82
- """
83
- ...
84
-
85
-
86
- class StateBackendRawMessageStore:
87
- """RawMessageStore implementation backed by StateBackend."""
88
-
89
- def __init__(self, backend: Any): # StateBackend
90
- """Initialize with StateBackend.
91
-
92
- Args:
93
- backend: StateBackend instance
94
- """
95
- self._backend = backend
96
-
97
- async def add(
98
- self,
99
- invocation_id: str,
100
- message: dict[str, Any],
101
- ) -> str:
102
- """Add a message using StateBackend."""
103
- msg_id = generate_id("rmsg")
104
-
105
- # Store message with metadata
106
- await self._backend.set("raw_messages", msg_id, {
107
- "id": msg_id,
108
- "invocation_id": invocation_id,
109
- "message": message,
110
- "created_at": datetime.now().isoformat(),
111
- })
112
-
113
- # Add to invocation's message list
114
- inv_key = f"raw_msg_ids:{invocation_id}"
115
- msg_ids = await self._backend.get("raw_messages", inv_key) or []
116
- msg_ids.append(msg_id)
117
- await self._backend.set("raw_messages", inv_key, msg_ids)
118
-
119
- return msg_id
120
-
121
- async def get(self, msg_id: str) -> dict[str, Any] | None:
122
- """Get a single message."""
123
- data = await self._backend.get("raw_messages", msg_id)
124
- if data:
125
- return data.get("message")
126
- return None
127
-
128
- async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
129
- """Get multiple messages."""
130
- results = []
131
- for msg_id in msg_ids:
132
- msg = await self.get(msg_id)
133
- if msg:
134
- results.append(msg)
135
- return results
136
-
137
- async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
138
- """Get all messages for an invocation."""
139
- inv_key = f"raw_msg_ids:{invocation_id}"
140
- msg_ids = await self._backend.get("raw_messages", inv_key) or []
141
- return await self.get_many(msg_ids)
142
-
143
- async def delete_by_invocation(self, invocation_id: str) -> int:
144
- """Delete all messages for an invocation."""
145
- inv_key = f"raw_msg_ids:{invocation_id}"
146
- msg_ids = await self._backend.get("raw_messages", inv_key) or []
147
-
148
- # Delete each message
149
- for msg_id in msg_ids:
150
- await self._backend.remove("raw_messages", msg_id)
151
-
152
- # Delete the index
153
- await self._backend.remove("raw_messages", inv_key)
154
-
155
- return len(msg_ids)
156
-
157
-
158
- class InMemoryRawMessageStore:
159
- """In-memory raw message store for testing."""
160
-
161
- def __init__(self) -> None:
162
- self._messages: dict[str, dict[str, Any]] = {} # msg_id -> message data
163
- self._invocation_index: dict[str, list[str]] = {} # inv_id -> [msg_ids]
164
-
165
- async def add(
166
- self,
167
- invocation_id: str,
168
- message: dict[str, Any],
169
- ) -> str:
170
- """Add a message."""
171
- msg_id = generate_id("rmsg")
172
-
173
- self._messages[msg_id] = {
174
- "id": msg_id,
175
- "invocation_id": invocation_id,
176
- "message": message,
177
- "created_at": datetime.now().isoformat(),
178
- }
179
-
180
- if invocation_id not in self._invocation_index:
181
- self._invocation_index[invocation_id] = []
182
- self._invocation_index[invocation_id].append(msg_id)
183
-
184
- return msg_id
185
-
186
- async def get(self, msg_id: str) -> dict[str, Any] | None:
187
- """Get a single message."""
188
- data = self._messages.get(msg_id)
189
- if data:
190
- return data.get("message")
191
- return None
192
-
193
- async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
194
- """Get multiple messages."""
195
- results = []
196
- for msg_id in msg_ids:
197
- msg = await self.get(msg_id)
198
- if msg:
199
- results.append(msg)
200
- return results
201
-
202
- async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
203
- """Get all messages for an invocation."""
204
- msg_ids = self._invocation_index.get(invocation_id, [])
205
- return await self.get_many(msg_ids)
206
-
207
- async def delete_by_invocation(self, invocation_id: str) -> int:
208
- """Delete all messages for an invocation."""
209
- msg_ids = self._invocation_index.pop(invocation_id, [])
210
- for msg_id in msg_ids:
211
- self._messages.pop(msg_id, None)
212
- return len(msg_ids)
213
-
214
- def clear(self) -> None:
215
- """Clear all messages (for testing)."""
216
- self._messages.clear()
217
- self._invocation_index.clear()
218
-
219
-
220
# Public API of the raw message store module.
__all__ = ["RawMessageStore", "StateBackendRawMessageStore", "InMemoryRawMessageStore"]
@@ -1,154 +0,0 @@
1
- """RawMessageMiddleware - stores complete messages for HITL recovery.
2
-
3
- This middleware stores complete, untruncated messages to RawMessageStore.
4
- Works alongside MessageBackendMiddleware which stores truncated messages.
5
-
6
- Usage:
7
- raw_store = InMemoryRawMessageStore()
8
- raw_middleware = RawMessageMiddleware(raw_store, persist_raw=False)
9
-
10
- agent = ReactAgent.create(
11
- llm=llm,
12
- middlewares=[raw_middleware, MessageBackendMiddleware()],
13
- )
14
- """
15
- from __future__ import annotations
16
-
17
- from typing import TYPE_CHECKING, Any
18
-
19
- from .base import BaseMiddleware
20
- from .types import HookResult
21
-
22
- if TYPE_CHECKING:
23
- from ..messages.raw_store import RawMessageStore
24
- from ..core.state import State
25
-
26
-
27
class RawMessageMiddleware(BaseMiddleware):
    """Middleware that stores complete messages to RawMessageStore.

    Stores untruncated messages for:
    - HITL recovery (restore exact conversation state)
    - Full-context recall (when truncated history is insufficient)

    Messages are grouped by the current invocation ID. Unless ``persist_raw``
    is True, they are deleted from the store when ``on_agent_end`` runs
    inside that invocation.
    """

    def __init__(
        self,
        raw_store: "RawMessageStore",
        persist_raw: bool = False,
        state: "State | None" = None,
    ):
        """Initialize with RawMessageStore.

        Args:
            raw_store: RawMessageStore for storing complete messages
            persist_raw: Whether to keep messages after invocation completes.
                False = clean up after invocation (default)
                True = keep forever (for audit/recall)
            state: State instance for storing message_ids in execution namespace.
                If provided, message IDs are automatically added to
                state.execution["message_ids"].
        """
        self.raw_store = raw_store
        self.persist_raw = persist_raw
        self.state = state

        # invocation_id -> raw message IDs stored through this middleware.
        # Entries are removed by _cleanup_invocation. NOTE(review): with
        # persist_raw=True nothing removes them, so this dict grows for the
        # lifetime of the middleware instance.
        self._invocation_msg_ids: dict[str, list[str]] = {}

    def set_state(self, state: "State") -> None:
        """Set state instance (can be set after construction).

        Args:
            state: State instance for storing message_ids
        """
        self.state = state

    @staticmethod
    def _current_invocation_id() -> str:
        """Return the active context's invocation ID, or "" when there is no context."""
        from ..core.context import get_current_ctx_or_none

        ctx = get_current_ctx_or_none()
        return ctx.invocation_id if ctx is not None else ""

    async def on_message_save(
        self,
        message: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Store complete message to RawMessageStore.

        Messages saved outside an invocation are passed through unstored.

        Args:
            message: Complete message dict with 'role', 'content', etc.

        Returns:
            The same dict, mutated in place to carry a 'raw_msg_id' field
        """
        invocation_id = self._current_invocation_id()
        if not invocation_id:
            return message

        msg_id = await self.raw_store.add(invocation_id, message)

        # Track for cleanup / get_message_ids().
        self._invocation_msg_ids.setdefault(invocation_id, []).append(msg_id)

        # Explicit None check: a State object that happens to evaluate falsy
        # must still receive the ID.
        if self.state is not None:
            message_ids = self.state.execution.get("message_ids", [])
            message_ids.append(msg_id)
            # Re-assign so a freshly created list is actually written back.
            self.state.execution["message_ids"] = message_ids

        # Expose the ID to downstream middlewares. NOTE(review): this mutates
        # the dict that was just handed to raw_store.add(); a store that keeps
        # references (e.g. the in-memory one) will see this key too.
        message["raw_msg_id"] = msg_id

        return message

    async def on_agent_end(
        self,
        agent_id: str,
        result: Any,
    ) -> HookResult:
        """Clean up raw messages when invocation completes.

        Only cleans up if persist_raw is False. ``agent_id`` and ``result``
        are not used.
        """
        if self.persist_raw:
            return HookResult.proceed()

        invocation_id = self._current_invocation_id()
        if invocation_id:
            await self._cleanup_invocation(invocation_id)

        return HookResult.proceed()

    async def _cleanup_invocation(self, invocation_id: str) -> int:
        """Clean up raw messages for an invocation.

        Args:
            invocation_id: Invocation ID to clean up

        Returns:
            Number of messages deleted, as reported by the store
        """
        self._invocation_msg_ids.pop(invocation_id, None)
        return await self.raw_store.delete_by_invocation(invocation_id)

    def get_message_ids(self, invocation_id: str) -> list[str]:
        """Get tracked message IDs for an invocation.

        Args:
            invocation_id: Invocation ID

        Returns:
            A copy of the tracked message IDs (empty if none or already cleaned up)
        """
        return list(self._invocation_msg_ids.get(invocation_id, []))
152
-
153
-
154
# Public API of the raw message middleware module.
__all__ = [
    "RawMessageMiddleware",
]