aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""Memory store protocol and implementations."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import hashlib
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any, Protocol, runtime_checkable
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class MemoryEntry:
    """A single stored memory item.

    Attributes:
        id: Unique identifier of the entry.
        content: Free-text content; also the input of ``content_hash``.
        session_id: Owning session, if any.
        invocation_id: Invocation that produced the entry, if any.
        created_at: Creation time (defaults to naive local ``datetime.now()``).
        metadata: Arbitrary extra key/value data.
    """

    id: str
    content: str
    session_id: str | None = None
    invocation_id: str | None = None
    created_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, with ``created_at`` as an ISO-8601 string.

        ``metadata`` is shallow-copied so that mutating the returned dict does
        not silently mutate this entry.
        """
        return {
            "id": self.id,
            "content": self.content,
            "session_id": self.session_id,
            "invocation_id": self.invocation_id,
            "created_at": self.created_at.isoformat(),
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> MemoryEntry:
        """Build an entry from a dict such as one produced by :meth:`to_dict`.

        ``id`` and ``content`` are required (``KeyError`` otherwise).
        ``created_at`` may be an ISO-8601 string or a ``datetime``. When it is
        missing, null or empty, the current time is used. A missing or null
        ``metadata`` becomes an empty dict.
        """
        raw_created = data.get("created_at")
        if isinstance(raw_created, datetime):
            created_at = raw_created
        elif raw_created:
            created_at = datetime.fromisoformat(raw_created)
        else:
            created_at = datetime.now()
        return cls(
            id=data["id"],
            content=data["content"],
            session_id=data.get("session_id"),
            invocation_id=data.get("invocation_id"),
            created_at=created_at,
            # Copy so the entry does not share a mutable dict with ``data``.
            metadata=dict(data.get("metadata") or {}),
        )

    @property
    def content_hash(self) -> str:
        """Get hash of content for deduplication (first 16 hex chars of SHA-256)."""
        return hashlib.sha256(self.content.encode()).hexdigest()[:16]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class ScoredEntry:
    """Pairs a memory entry with the relevance score a search assigned it.

    Attributes:
        entry: The matched memory entry.
        score: Relevance score for the match.
        source: Label naming what produced the match.
    """

    entry: MemoryEntry
    score: float
    source: str = field(default="default")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@runtime_checkable
class MemoryStore(Protocol):
    """Structural interface that memory stores implement.

    Every operation is a coroutine. Because the protocol is
    ``runtime_checkable``, ``isinstance`` can be used to check that an object
    provides these methods.
    """

    async def add(self, entry: MemoryEntry) -> str:
        """Store *entry* and return the ID it is stored under."""

    async def search(
        self,
        query: str,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
    ) -> list[ScoredEntry]:
        """Find entries relevant to *query*; *filter* and *limit* narrow the result."""

    async def get(self, entry_id: str) -> MemoryEntry | None:
        """Look up the entry stored under *entry_id* (``None`` if absent)."""

    async def remove(self, entry_id: str) -> None:
        """Delete the entry stored under *entry_id*."""

    async def revert(
        self,
        session_id: str,
        after_invocation_id: str,
    ) -> list[str]:
        """Drop *session_id*'s entries that come after *after_invocation_id*.

        Returns the IDs of the deleted entries.
        """
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class InMemoryStore:
    """Simple in-memory store for testing.

    Entries are deduplicated by ``content_hash``. Adding content that is
    already stored returns the existing entry's ID instead of storing a copy.
    """

    def __init__(self) -> None:
        self._entries: dict[str, MemoryEntry] = {}
        self._content_hashes: dict[str, str] = {}  # hash -> entry_id

    def _forget_hash(self, entry_id: str, content_hash: str) -> None:
        """Drop the hash-index record for *content_hash* only if it points at *entry_id*."""
        if self._content_hashes.get(content_hash) == entry_id:
            del self._content_hashes[content_hash]

    @staticmethod
    def _matches_filter(entry: MemoryEntry, criteria: dict[str, Any]) -> bool:
        """Return True if *entry* satisfies every ``key == value`` pair in *criteria*.

        Each key is read as an entry attribute first. A missing or falsy
        attribute falls back to ``entry.metadata[key]``.
        """
        for key, value in criteria.items():
            entry_value = getattr(entry, key, None) or entry.metadata.get(key)
            if entry_value != value:
                return False
        return True

    async def add(self, entry: MemoryEntry) -> str:
        """Add *entry* unless identical content is already stored.

        Returns:
            ``entry.id``, or the ID of the existing entry with the same content.
        """
        content_hash = entry.content_hash

        # Check for duplicate content
        existing_id = self._content_hashes.get(content_hash)
        if existing_id is not None:
            return existing_id

        # Re-adding an existing ID with different content replaces that entry.
        # Its old content hash must stop resolving to this ID; otherwise a later
        # add of the old content would "dedupe" to an entry with other content.
        previous = self._entries.get(entry.id)
        if previous is not None:
            self._forget_hash(entry.id, previous.content_hash)

        self._entries[entry.id] = entry
        self._content_hashes[content_hash] = entry.id
        return entry.id

    async def search(
        self,
        query: str,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
    ) -> list[ScoredEntry]:
        """Simple keyword search.

        The score is the fraction of distinct query words (lower-cased,
        whitespace-split) that appear in the entry's content. Entries with no
        word overlap that still contain *query* as a substring score 0.5. All
        other entries are excluded. Note that an empty query matches every
        entry through the substring branch. Results are sorted by descending
        score and truncated to *limit*.
        """
        criteria = filter or {}
        query_lower = query.lower()
        query_words = set(query_lower.split())

        results: list[ScoredEntry] = []
        for entry in self._entries.values():
            if criteria and not self._matches_filter(entry, criteria):
                continue

            content_lower = entry.content.lower()
            overlap = len(query_words & set(content_lower.split()))
            if overlap > 0:
                score = overlap / len(query_words)
                results.append(ScoredEntry(entry=entry, score=score, source="memory"))
            elif query_lower in content_lower:
                # Substring match
                results.append(ScoredEntry(entry=entry, score=0.5, source="memory"))

        results.sort(key=lambda scored: scored.score, reverse=True)
        return results[:limit]

    async def get(self, entry_id: str) -> MemoryEntry | None:
        """Return the entry stored under *entry_id*, or ``None``."""
        return self._entries.get(entry_id)

    async def remove(self, entry_id: str) -> None:
        """Delete *entry_id* and its hash-index record; unknown IDs are ignored."""
        entry = self._entries.pop(entry_id, None)
        if entry is not None:
            self._forget_hash(entry_id, entry.content_hash)

    async def revert(
        self,
        session_id: str,
        after_invocation_id: str,
    ) -> list[str]:
        """Remove *session_id*'s entries whose invocation ID sorts after *after_invocation_id*.

        NOTE: invocation IDs are compared lexicographically. That is only
        meaningful if IDs sort in creation order; a production store should
        use timestamps or sequence numbers instead. Entries without an
        invocation ID are kept.

        Returns:
            IDs of the removed entries.
        """
        to_delete = [
            entry_id
            for entry_id, entry in self._entries.items()
            if entry.session_id == session_id
            and entry.invocation_id
            and entry.invocation_id > after_invocation_id
        ]
        # IDs are collected first so the dict is not mutated while iterating it.
        for entry_id in to_delete:
            await self.remove(entry_id)
        return to_delete

    def clear(self) -> None:
        """Clear all entries."""
        self._entries.clear()
        self._content_hashes.clear()
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""Memory type definitions.
|
|
2
|
+
|
|
3
|
+
Core types for the memory system:
|
|
4
|
+
- MemorySummary: Compressed overview of conversation history
|
|
5
|
+
- MemoryRecall: Key points extracted from invocations
|
|
6
|
+
- MemoryContext: Combined context for LLM (summary + recalls)
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class MemorySummary:
    """Compressed overview of conversation history.

    Represents the "big picture" of a session's conversation.
    Updated incrementally as invocations complete.

    Attributes:
        session_id: Session the summary belongs to.
        content: Summary text.
        last_invocation_id: Last invocation included in the summary.
        updated_at: When the summary was last updated.
        token_count: Estimated token count.
        metadata: Arbitrary extra key/value data.
    """

    session_id: str
    content: str  # Summary text
    last_invocation_id: str  # Last invocation included in summary
    updated_at: datetime = field(default_factory=datetime.now)
    token_count: int = 0  # Estimated token count
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (``updated_at`` as ISO-8601; ``metadata`` copied)."""
        return {
            "session_id": self.session_id,
            "content": self.content,
            "last_invocation_id": self.last_invocation_id,
            "updated_at": self.updated_at.isoformat(),
            "token_count": self.token_count,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MemorySummary":
        """Build a summary from a dict such as one produced by :meth:`to_dict`.

        ``session_id``, ``content`` and ``last_invocation_id`` are required
        (``KeyError`` otherwise). ``updated_at`` may be an ISO-8601 string or a
        ``datetime``; a missing, null or empty value falls back to the current
        time, matching the dataclass default. A missing or null ``metadata``
        becomes an empty dict.
        """
        raw_updated = data.get("updated_at")
        if isinstance(raw_updated, datetime):
            updated_at = raw_updated
        elif raw_updated:
            updated_at = datetime.fromisoformat(raw_updated)
        else:
            updated_at = datetime.now()
        return cls(
            session_id=data["session_id"],
            content=data["content"],
            last_invocation_id=data["last_invocation_id"],
            updated_at=updated_at,
            token_count=data.get("token_count", 0),
            # Copy so the summary does not share a mutable dict with ``data``.
            metadata=dict(data.get("metadata") or {}),
        )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class MemoryRecall:
    """Key point extracted from an invocation.

    Represents an important piece of information that should be
    recalled when building LLM context. Linked to a specific invocation
    for isolation and revert support.

    Attributes:
        id: Unique identifier of the recall.
        session_id: Session the recall belongs to.
        invocation_id: Invocation this recall came from.
        content: Recall text.
        importance: 0.0 - 1.0, higher = more important.
        tags: Labels for the recall.
        created_at: Creation time.
        metadata: Arbitrary extra key/value data.
    """

    id: str
    session_id: str
    invocation_id: str  # Which invocation this came from
    content: str  # Recall content
    importance: float = 0.5  # 0.0 - 1.0, higher = more important
    tags: list[str] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (``created_at`` as ISO-8601; ``tags``/``metadata`` copied)."""
        return {
            "id": self.id,
            "session_id": self.session_id,
            "invocation_id": self.invocation_id,
            "content": self.content,
            "importance": self.importance,
            "tags": list(self.tags),
            "created_at": self.created_at.isoformat(),
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MemoryRecall":
        """Build a recall from a dict such as one produced by :meth:`to_dict`.

        ``id``, ``session_id``, ``invocation_id`` and ``content`` are required
        (``KeyError`` otherwise). ``created_at`` may be an ISO-8601 string or a
        ``datetime``; a missing, null or empty value falls back to the current
        time, matching the dataclass default. Missing or null ``tags`` and
        ``metadata`` become empty containers.
        """
        raw_created = data.get("created_at")
        if isinstance(raw_created, datetime):
            created_at = raw_created
        elif raw_created:
            created_at = datetime.fromisoformat(raw_created)
        else:
            created_at = datetime.now()
        return cls(
            id=data["id"],
            session_id=data["session_id"],
            invocation_id=data["invocation_id"],
            content=data["content"],
            importance=data.get("importance", 0.5),
            # Copies so the recall does not share mutable containers with ``data``.
            tags=list(data.get("tags") or []),
            created_at=created_at,
            metadata=dict(data.get("metadata") or {}),
        )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
class MemoryContext:
    """Memory context that is injected into the LLM prompt.

    Bundles an optional high-level ``summary`` of the conversation with a
    list of ``recalls`` (specific important points).
    """

    summary: MemorySummary | None = None
    recalls: list[MemoryRecall] = field(default_factory=list)

    def to_system_message(self) -> str:
        """Render this context as system-message text for the LLM.

        Produces a "Conversation History Overview" section when the summary
        has content. Produces a "Key Points" section that lists recalls by
        descending importance, each prefixed with its comma-joined tags (or
        "note" when untagged). Sections are separated by a blank line, and an
        empty string is returned when neither section applies.
        """
        sections: list[str] = []

        summary_text = self.summary.content if self.summary else None
        if summary_text:
            sections.append(f"## Conversation History Overview\n{summary_text}")

        if self.recalls:
            ranked = sorted(self.recalls, key=lambda recall: recall.importance, reverse=True)
            bullet_lines = []
            for recall in ranked:
                label = ", ".join(recall.tags) if recall.tags else "note"
                bullet_lines.append(f"- [{label}] {recall.content}")
            bullets = "\n".join(bullet_lines)
            sections.append(f"## Key Points\n{bullets}")

        return "\n\n".join(sections)

    @property
    def is_empty(self) -> bool:
        """True when there is neither summary content nor any recall."""
        has_summary = bool(self.summary and self.summary.content)
        return not (has_summary or self.recalls)

    @property
    def total_recalls(self) -> int:
        """Number of recalls held in this context."""
        return len(self.recalls)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# Names exported by ``from ... import *``.
__all__ = ["MemorySummary", "MemoryRecall", "MemoryContext"]
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Message types for conversation history.
|
|
2
|
+
|
|
3
|
+
Note: Message persistence is now handled by:
|
|
4
|
+
- MessageBackend (backends/message/): Storage layer
|
|
5
|
+
- MessageBackendMiddleware (middleware/message.py): Save via on_message_save hook
|
|
6
|
+
- MessageContextProvider (context_providers/message.py): Fetch for context
|
|
7
|
+
|
|
8
|
+
This module provides message types used across the system.
|
|
9
|
+
"""
|
|
10
|
+
from .types import (
|
|
11
|
+
MessageRole,
|
|
12
|
+
Message,
|
|
13
|
+
)
|
|
14
|
+
from .store import (
|
|
15
|
+
MessageStore,
|
|
16
|
+
InMemoryMessageStore,
|
|
17
|
+
)
|
|
18
|
+
from .raw_store import (
|
|
19
|
+
RawMessageStore,
|
|
20
|
+
StateBackendRawMessageStore,
|
|
21
|
+
InMemoryRawMessageStore,
|
|
22
|
+
)
|
|
23
|
+
from .config import (
|
|
24
|
+
MessageConfig,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
__all__ = [
    # Message types
    "MessageRole",
    "Message",
    # Message store: protocol plus an in-memory implementation for tests
    "MessageStore",
    "InMemoryMessageStore",
    # Raw (untruncated) message store: protocol and implementations
    "RawMessageStore",
    "StateBackendRawMessageStore",
    "InMemoryRawMessageStore",
    # Storage configuration
    "MessageConfig",
]
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Message storage configuration.
|
|
2
|
+
|
|
3
|
+
Configures how messages are stored and retrieved.
|
|
4
|
+
"""
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Literal
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class MessageConfig:
    """Configuration for message storage.

    Controls raw message storage and recall behavior.

    Attributes:
        enable_raw_store: Whether to store complete messages in RawMessageStore.
            Required for HITL recovery and full-context recall.
        persist_raw: Whether to keep raw messages after invocation completes.
            False = clean up after invocation (default, saves space)
            True = keep forever (for audit/recall)
        recall_mode: How to build context for LLM recall.
            "mixed" = previous invocations truncated + current invocation raw
            "raw" = all raw messages (requires persist_raw=True or current inv only)

    Raises:
        ValueError: If ``recall_mode`` is not ``"mixed"`` or ``"raw"``, or if
            ``recall_mode="raw"`` is combined with ``enable_raw_store=False``.

    Example:
        # Default: raw for recovery, clean up after
        config = MessageConfig()

        # Keep all raw messages for full recall
        config = MessageConfig(persist_raw=True, recall_mode="raw")

        # Disable raw storage (no HITL recovery)
        config = MessageConfig(enable_raw_store=False)
    """

    enable_raw_store: bool = True
    persist_raw: bool = False
    recall_mode: Literal["mixed", "raw"] = "mixed"

    def __post_init__(self) -> None:
        """Validate the option combination at construction time."""
        # ``Literal`` is not enforced at runtime, so reject unknown modes (e.g.
        # typos) explicitly instead of silently accepting them.
        if self.recall_mode not in ("mixed", "raw"):
            raise ValueError(
                f"recall_mode must be 'mixed' or 'raw', got {self.recall_mode!r}"
            )
        # Validate: raw recall mode requires raw storage
        if self.recall_mode == "raw" and not self.enable_raw_store:
            raise ValueError("recall_mode='raw' requires enable_raw_store=True")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
__all__ = [
    "MessageConfig",  # the only public name of this module
]
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
"""Raw message store for complete message storage.
|
|
2
|
+
|
|
3
|
+
RawMessageStore stores complete, untruncated messages for:
|
|
4
|
+
- HITL recovery (restore exact state)
|
|
5
|
+
- Full-context recall (when truncated history is insufficient)
|
|
6
|
+
- Audit/debugging
|
|
7
|
+
|
|
8
|
+
Messages are stored per invocation and can be cleaned up after invocation completes.
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from typing import Any, Protocol, runtime_checkable
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
|
|
15
|
+
from ..core.types.session import generate_id
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@runtime_checkable
class RawMessageStore(Protocol):
    """Protocol for raw message storage.

    Stores complete messages per invocation.
    """

    async def add(
        self,
        invocation_id: str,
        message: dict[str, Any],
    ) -> str:
        """Add a message and return its ID.

        Args:
            invocation_id: Invocation this message belongs to
            message: Complete message dict

        Returns:
            Generated message ID
        """
        ...

    async def get(self, msg_id: str) -> dict[str, Any] | None:
        """Get a single message by ID.

        Args:
            msg_id: Message ID

        Returns:
            Message dict or None if not found
        """
        ...

    async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple messages by IDs.

        Args:
            msg_ids: List of message IDs

        Returns:
            Message dicts for the IDs that were found, in the order of
            ``msg_ids``. IDs with no stored message are omitted, so the
            result can be shorter than ``msg_ids``; no ``None`` placeholders
            are returned.
        """
        ...

    async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
        """Get all messages for an invocation.

        Args:
            invocation_id: Invocation ID

        Returns:
            List of messages in chronological order
        """
        ...

    async def delete_by_invocation(self, invocation_id: str) -> int:
        """Delete all messages for an invocation.

        Args:
            invocation_id: Invocation ID

        Returns:
            Number of messages deleted
        """
        ...
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class StateBackendRawMessageStore:
    """RawMessageStore implementation backed by StateBackend.

    Everything lives in the backend's ``"raw_messages"`` namespace:

    * ``<msg_id>`` -> ``{"id", "invocation_id", "message", "created_at"}``
    * ``raw_msg_ids:<invocation_id>`` -> list of message IDs in insertion order
    """

    _NAMESPACE = "raw_messages"

    def __init__(self, backend: Any):  # StateBackend
        """Initialize with StateBackend.

        Args:
            backend: StateBackend instance. Only its async ``get``, ``set`` and
                ``remove`` methods (namespace, key[, value]) are used here.
        """
        self._backend = backend

    @staticmethod
    def _index_key(invocation_id: str) -> str:
        """Backend key that holds the ordered message IDs of *invocation_id*."""
        return f"raw_msg_ids:{invocation_id}"

    async def add(
        self,
        invocation_id: str,
        message: dict[str, Any],
    ) -> str:
        """Store *message* under a new ID and append that ID to the invocation index."""
        msg_id = generate_id("rmsg")

        # Store message with metadata
        await self._backend.set(self._NAMESPACE, msg_id, {
            "id": msg_id,
            "invocation_id": invocation_id,
            "message": message,
            "created_at": datetime.now().isoformat(),
        })

        # Add to invocation's message list. Copy first so we never mutate a
        # list object the backend may still be holding.
        # NOTE(review): this read-modify-write is not atomic; concurrent add()
        # calls for the same invocation could drop an ID. Confirm that callers
        # serialize adds per invocation.
        inv_key = self._index_key(invocation_id)
        msg_ids = list(await self._backend.get(self._NAMESPACE, inv_key) or [])
        msg_ids.append(msg_id)
        await self._backend.set(self._NAMESPACE, inv_key, msg_ids)

        return msg_id

    async def get(self, msg_id: str) -> dict[str, Any] | None:
        """Return the stored message payload for *msg_id*, or ``None`` if absent."""
        data = await self._backend.get(self._NAMESPACE, msg_id)
        if data:
            return data.get("message")
        return None

    async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
        """Return messages for *msg_ids* in order, skipping IDs that are not found.

        Only ``None`` counts as "not found". A stored empty-dict message is
        still returned.
        """
        results = []
        for msg_id in msg_ids:
            msg = await self.get(msg_id)
            if msg is not None:
                results.append(msg)
        return results

    async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
        """Return all messages of *invocation_id* in insertion order."""
        msg_ids = await self._backend.get(self._NAMESPACE, self._index_key(invocation_id)) or []
        return await self.get_many(msg_ids)

    async def delete_by_invocation(self, invocation_id: str) -> int:
        """Delete every message of *invocation_id* together with its index.

        Returns:
            Number of message IDs that were indexed for the invocation. This
            includes IDs whose payload was already missing.
        """
        inv_key = self._index_key(invocation_id)
        msg_ids = await self._backend.get(self._NAMESPACE, inv_key) or []

        # Delete each message
        for msg_id in msg_ids:
            await self._backend.remove(self._NAMESPACE, msg_id)

        # Delete the index
        await self._backend.remove(self._NAMESPACE, inv_key)

        return len(msg_ids)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class InMemoryRawMessageStore:
    """In-memory raw message store for testing."""

    def __init__(self) -> None:
        self._messages: dict[str, dict[str, Any]] = {}  # msg_id -> message data
        self._invocation_index: dict[str, list[str]] = {}  # inv_id -> [msg_ids]

    async def add(
        self,
        invocation_id: str,
        message: dict[str, Any],
    ) -> str:
        """Store *message* under a new ID and append it to the invocation index."""
        msg_id = generate_id("rmsg")

        self._messages[msg_id] = {
            "id": msg_id,
            "invocation_id": invocation_id,
            "message": message,
            "created_at": datetime.now().isoformat(),
        }
        self._invocation_index.setdefault(invocation_id, []).append(msg_id)

        return msg_id

    async def get(self, msg_id: str) -> dict[str, Any] | None:
        """Return the stored message payload for *msg_id*, or ``None`` if absent."""
        data = self._messages.get(msg_id)
        if data:
            return data.get("message")
        return None

    async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
        """Return messages for *msg_ids* in order, skipping IDs that are not found.

        Only ``None`` counts as "not found". A stored empty-dict message is
        still returned.
        """
        results = []
        for msg_id in msg_ids:
            msg = await self.get(msg_id)
            if msg is not None:
                results.append(msg)
        return results

    async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
        """Return all messages of *invocation_id* in insertion order."""
        msg_ids = self._invocation_index.get(invocation_id, [])
        return await self.get_many(msg_ids)

    async def delete_by_invocation(self, invocation_id: str) -> int:
        """Delete every message of *invocation_id* and its index entry.

        Returns:
            Number of message IDs that were indexed for the invocation.
        """
        msg_ids = self._invocation_index.pop(invocation_id, [])
        for msg_id in msg_ids:
            self._messages.pop(msg_id, None)
        return len(msg_ids)

    def clear(self) -> None:
        """Clear all messages (for testing)."""
        self._messages.clear()
        self._invocation_index.clear()
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
# Public API: the protocol followed by its two implementations.
__all__ = ["RawMessageStore", "StateBackendRawMessageStore", "InMemoryRawMessageStore"]
|