aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,465 @@
|
|
|
1
|
+
"""Memory manager for unified memory operations."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any
|
|
7
|
+
from uuid import uuid4
|
|
8
|
+
|
|
9
|
+
from ..core.event_bus import Bus, Events
|
|
10
|
+
from ..core.logging import memory_logger as logger
|
|
11
|
+
from ..core.types.session import generate_id
|
|
12
|
+
from .types import MemorySummary, MemoryRecall, MemoryContext
|
|
13
|
+
from .store import MemoryEntry, ScoredEntry, MemoryStore
|
|
14
|
+
from .processor import WriteFilter, WriteDecision, WriteResult, MemoryProcessor, ProcessContext, WriteContext, ReadContext
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class WriteTrigger(Enum):
    """When memory is written."""
    # Explicit call by application code (default for MemoryManager.add()).
    MANUAL = "manual"
    # Automatic write when an invocation finishes (bus-driven; see
    # MemoryManager._on_invocation_end).
    INVOCATION_END = "invocation_end"
    # Write of messages ejected during context compression
    # (see MemoryManager.on_compress).
    COMPRESS = "compress"
    # Write caused by some other bus event.
    EVENT = "event"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class RetrievalSource:
    """Configuration for a retrieval source.

    One entry per store the manager should query during search().
    """
    # Key into MemoryManager.stores identifying which store to query.
    store_name: str
    # Multiplier applied to every result score from this source.
    weight: float = 1.0
    # Extra filter merged over the caller's filter (source wins on key clash).
    filter: dict[str, Any] | None = None
    # Maximum number of results fetched from this source per search.
    limit: int = 10
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class MemoryManager:
|
|
35
|
+
"""Unified memory manager.
|
|
36
|
+
|
|
37
|
+
Handles:
|
|
38
|
+
- Multiple memory stores
|
|
39
|
+
- Write pipeline (filter, process, store)
|
|
40
|
+
- Read pipeline (search, merge, post-process)
|
|
41
|
+
- Auto-triggers from bus events
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
stores: dict[str, MemoryStore],
|
|
47
|
+
retrieval_config: list[RetrievalSource] | None = None,
|
|
48
|
+
write_filters: list[WriteFilter] | None = None,
|
|
49
|
+
write_processors: list[MemoryProcessor] | None = None,
|
|
50
|
+
read_processors: list[Any] | None = None,
|
|
51
|
+
auto_triggers: set[WriteTrigger] | None = None,
|
|
52
|
+
bus: Bus | None = None,
|
|
53
|
+
):
|
|
54
|
+
self.stores = stores
|
|
55
|
+
self.retrieval_config = retrieval_config or [
|
|
56
|
+
RetrievalSource(store_name=name, limit=10)
|
|
57
|
+
for name in stores
|
|
58
|
+
]
|
|
59
|
+
self.write_filters = write_filters or []
|
|
60
|
+
self.write_processors = write_processors or []
|
|
61
|
+
self.read_processors = read_processors or []
|
|
62
|
+
self.auto_triggers = auto_triggers or {WriteTrigger.INVOCATION_END}
|
|
63
|
+
self.bus = bus
|
|
64
|
+
|
|
65
|
+
# Register bus handlers
|
|
66
|
+
if bus:
|
|
67
|
+
self._register_triggers()
|
|
68
|
+
|
|
69
|
+
def _register_triggers(self) -> None:
|
|
70
|
+
"""Register auto-trigger handlers."""
|
|
71
|
+
if WriteTrigger.INVOCATION_END in self.auto_triggers:
|
|
72
|
+
self.bus.subscribe(Events.INVOCATION_END, self._on_invocation_end)
|
|
73
|
+
|
|
74
|
+
async def _on_invocation_end(self, event_type: str, payload: dict[str, Any]) -> None:
|
|
75
|
+
"""Handle invocation end event."""
|
|
76
|
+
messages = payload.get("messages", [])
|
|
77
|
+
if not messages:
|
|
78
|
+
return
|
|
79
|
+
|
|
80
|
+
content = self._format_messages(messages)
|
|
81
|
+
|
|
82
|
+
await self.add(
|
|
83
|
+
content=content,
|
|
84
|
+
session_id=payload.get("session_id"),
|
|
85
|
+
invocation_id=payload.get("invocation_id"),
|
|
86
|
+
metadata={"type": "conversation"},
|
|
87
|
+
trigger=WriteTrigger.INVOCATION_END,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
def _format_messages(self, messages: list[dict[str, Any]]) -> str:
|
|
91
|
+
"""Format messages for storage."""
|
|
92
|
+
parts = []
|
|
93
|
+
for msg in messages:
|
|
94
|
+
role = msg.get("role", "unknown")
|
|
95
|
+
content = msg.get("content", "")
|
|
96
|
+
if isinstance(content, list):
|
|
97
|
+
# Handle multi-part content
|
|
98
|
+
text_parts = [
|
|
99
|
+
p.get("text", "") for p in content
|
|
100
|
+
if isinstance(p, dict) and p.get("type") == "text"
|
|
101
|
+
]
|
|
102
|
+
content = " ".join(text_parts)
|
|
103
|
+
parts.append(f"[{role}]: {content}")
|
|
104
|
+
return "\n\n".join(parts)
|
|
105
|
+
|
|
106
|
+
async def add(
|
|
107
|
+
self,
|
|
108
|
+
content: str,
|
|
109
|
+
session_id: str | None = None,
|
|
110
|
+
invocation_id: str | None = None,
|
|
111
|
+
metadata: dict[str, Any] | None = None,
|
|
112
|
+
trigger: WriteTrigger = WriteTrigger.MANUAL,
|
|
113
|
+
) -> str | None:
|
|
114
|
+
"""Add content to memory.
|
|
115
|
+
|
|
116
|
+
Runs through write pipeline:
|
|
117
|
+
1. Filters (can skip/transform)
|
|
118
|
+
2. Processors (transform)
|
|
119
|
+
3. Store in all stores
|
|
120
|
+
|
|
121
|
+
Returns entry ID or None if filtered out.
|
|
122
|
+
"""
|
|
123
|
+
logger.debug(
|
|
124
|
+
"Adding to memory",
|
|
125
|
+
extra={"trigger": trigger.value, "session_id": session_id}
|
|
126
|
+
)
|
|
127
|
+
entry = MemoryEntry(
|
|
128
|
+
id=str(uuid4()),
|
|
129
|
+
content=content,
|
|
130
|
+
session_id=session_id,
|
|
131
|
+
invocation_id=invocation_id,
|
|
132
|
+
metadata={**(metadata or {}), "trigger": trigger.value},
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
entries = [entry]
|
|
136
|
+
write_context = WriteContext(
|
|
137
|
+
trigger=trigger,
|
|
138
|
+
session_id=session_id,
|
|
139
|
+
invocation_id=invocation_id,
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
# 1. Apply filters
|
|
143
|
+
for filter in self.write_filters:
|
|
144
|
+
result = await filter.filter(entries, write_context)
|
|
145
|
+
|
|
146
|
+
if result.decision == WriteDecision.SKIP:
|
|
147
|
+
return None
|
|
148
|
+
elif result.decision == WriteDecision.TRANSFORM:
|
|
149
|
+
entries = result.entries or []
|
|
150
|
+
|
|
151
|
+
if not entries:
|
|
152
|
+
return None
|
|
153
|
+
|
|
154
|
+
# 2. Apply processors
|
|
155
|
+
process_context = ProcessContext(session_id=session_id)
|
|
156
|
+
for processor in self.write_processors:
|
|
157
|
+
entries = await processor.process(entries, process_context)
|
|
158
|
+
|
|
159
|
+
# 3. Store in all stores
|
|
160
|
+
for entry in entries:
|
|
161
|
+
for store in self.stores.values():
|
|
162
|
+
await store.add(entry)
|
|
163
|
+
|
|
164
|
+
if self.bus:
|
|
165
|
+
await self.bus.publish(Events.MEMORY_ADD, {
|
|
166
|
+
"entry_id": entries[0].id if entries else None,
|
|
167
|
+
"count": len(entries),
|
|
168
|
+
})
|
|
169
|
+
|
|
170
|
+
return entries[0].id if entries else None
|
|
171
|
+
|
|
172
|
+
async def search(
|
|
173
|
+
self,
|
|
174
|
+
query: str,
|
|
175
|
+
filter: dict[str, Any] | None = None,
|
|
176
|
+
limit: int = 10,
|
|
177
|
+
) -> list[ScoredEntry]:
|
|
178
|
+
"""Search memory stores.
|
|
179
|
+
|
|
180
|
+
Searches all configured sources and merges results.
|
|
181
|
+
"""
|
|
182
|
+
# 1. Search all sources
|
|
183
|
+
all_results: dict[str, list[ScoredEntry]] = {}
|
|
184
|
+
|
|
185
|
+
for source in self.retrieval_config:
|
|
186
|
+
if source.store_name not in self.stores:
|
|
187
|
+
continue
|
|
188
|
+
|
|
189
|
+
store = self.stores[source.store_name]
|
|
190
|
+
merged_filter = {**(filter or {}), **(source.filter or {})}
|
|
191
|
+
|
|
192
|
+
results = await store.search(
|
|
193
|
+
query=query,
|
|
194
|
+
filter=merged_filter,
|
|
195
|
+
limit=source.limit,
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
# Apply source weight
|
|
199
|
+
for r in results:
|
|
200
|
+
r.score *= source.weight
|
|
201
|
+
|
|
202
|
+
all_results[source.store_name] = results
|
|
203
|
+
|
|
204
|
+
# 2. Merge results (simple dedup by ID)
|
|
205
|
+
seen_ids: set[str] = set()
|
|
206
|
+
merged: list[ScoredEntry] = []
|
|
207
|
+
|
|
208
|
+
# Flatten and sort by score
|
|
209
|
+
flat_results = []
|
|
210
|
+
for results in all_results.values():
|
|
211
|
+
flat_results.extend(results)
|
|
212
|
+
flat_results.sort(key=lambda x: x.score, reverse=True)
|
|
213
|
+
|
|
214
|
+
for result in flat_results:
|
|
215
|
+
if result.entry.id not in seen_ids:
|
|
216
|
+
seen_ids.add(result.entry.id)
|
|
217
|
+
merged.append(result)
|
|
218
|
+
|
|
219
|
+
# 3. Apply read processors
|
|
220
|
+
read_context = ReadContext(limit=limit)
|
|
221
|
+
for processor in self.read_processors:
|
|
222
|
+
merged = await processor.process(merged, query, read_context)
|
|
223
|
+
|
|
224
|
+
if self.bus:
|
|
225
|
+
await self.bus.publish(Events.MEMORY_SEARCH, {
|
|
226
|
+
"query": query[:100],
|
|
227
|
+
"result_count": len(merged[:limit]),
|
|
228
|
+
})
|
|
229
|
+
|
|
230
|
+
return merged[:limit]
|
|
231
|
+
|
|
232
|
+
async def revert(
|
|
233
|
+
self,
|
|
234
|
+
session_id: str,
|
|
235
|
+
after_invocation_id: str,
|
|
236
|
+
) -> list[str]:
|
|
237
|
+
"""Revert memory entries after specified invocation."""
|
|
238
|
+
deleted = []
|
|
239
|
+
|
|
240
|
+
for store in self.stores.values():
|
|
241
|
+
ids = await store.revert(session_id, after_invocation_id)
|
|
242
|
+
deleted.extend(ids)
|
|
243
|
+
|
|
244
|
+
return deleted
|
|
245
|
+
|
|
246
|
+
async def on_compress(
|
|
247
|
+
self,
|
|
248
|
+
session_id: str,
|
|
249
|
+
invocation_id: str,
|
|
250
|
+
ejected_messages: list[dict[str, Any]],
|
|
251
|
+
) -> str | None:
|
|
252
|
+
"""Handle compression - save ejected messages to memory."""
|
|
253
|
+
if WriteTrigger.COMPRESS not in self.auto_triggers:
|
|
254
|
+
return None
|
|
255
|
+
|
|
256
|
+
content = self._format_messages(ejected_messages)
|
|
257
|
+
|
|
258
|
+
return await self.add(
|
|
259
|
+
content=content,
|
|
260
|
+
session_id=session_id,
|
|
261
|
+
invocation_id=invocation_id,
|
|
262
|
+
metadata={"type": "compressed"},
|
|
263
|
+
trigger=WriteTrigger.COMPRESS,
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
# ========== Summary & Recall API ==========
|
|
267
|
+
|
|
268
|
+
async def get_context(
|
|
269
|
+
self,
|
|
270
|
+
session_id: str,
|
|
271
|
+
invocation_ids: list[str] | None = None,
|
|
272
|
+
recall_limit: int = 10,
|
|
273
|
+
) -> MemoryContext:
|
|
274
|
+
"""Get memory context for LLM.
|
|
275
|
+
|
|
276
|
+
Args:
|
|
277
|
+
session_id: Session to get context for
|
|
278
|
+
invocation_ids: Filter recalls to these invocations (for isolation)
|
|
279
|
+
recall_limit: Max number of recalls to return
|
|
280
|
+
|
|
281
|
+
Returns:
|
|
282
|
+
MemoryContext with summary and recalls
|
|
283
|
+
"""
|
|
284
|
+
# Get summary
|
|
285
|
+
summary = await self.get_summary(session_id)
|
|
286
|
+
|
|
287
|
+
# Get recalls, filtered by invocation chain if provided
|
|
288
|
+
recalls = await self.get_recalls(
|
|
289
|
+
session_id=session_id,
|
|
290
|
+
invocation_ids=invocation_ids,
|
|
291
|
+
limit=recall_limit,
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
return MemoryContext(summary=summary, recalls=recalls)
|
|
295
|
+
|
|
296
|
+
async def get_summary(self, session_id: str) -> MemorySummary | None:
|
|
297
|
+
"""Get session summary."""
|
|
298
|
+
# Look in first store that has summaries
|
|
299
|
+
for store in self.stores.values():
|
|
300
|
+
if hasattr(store, 'get_summary'):
|
|
301
|
+
return await store.get_summary(session_id)
|
|
302
|
+
|
|
303
|
+
# Fallback: search for summary entry
|
|
304
|
+
results = await self.search(
|
|
305
|
+
query="conversation summary",
|
|
306
|
+
filter={"session_id": session_id, "type": "summary"},
|
|
307
|
+
limit=1,
|
|
308
|
+
)
|
|
309
|
+
|
|
310
|
+
if results:
|
|
311
|
+
entry = results[0].entry
|
|
312
|
+
return MemorySummary(
|
|
313
|
+
session_id=session_id,
|
|
314
|
+
content=entry.content,
|
|
315
|
+
last_invocation_id=entry.invocation_id or "",
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
return None
|
|
319
|
+
|
|
320
|
+
async def get_recalls(
|
|
321
|
+
self,
|
|
322
|
+
session_id: str,
|
|
323
|
+
invocation_ids: list[str] | None = None,
|
|
324
|
+
limit: int = 10,
|
|
325
|
+
) -> list[MemoryRecall]:
|
|
326
|
+
"""Get recalls for session, optionally filtered by invocations."""
|
|
327
|
+
filter_dict: dict[str, Any] = {"session_id": session_id, "type": "recall"}
|
|
328
|
+
if invocation_ids:
|
|
329
|
+
filter_dict["invocation_id"] = invocation_ids
|
|
330
|
+
|
|
331
|
+
# Search for recall entries
|
|
332
|
+
results = await self.search(
|
|
333
|
+
query="key points recalls",
|
|
334
|
+
filter=filter_dict,
|
|
335
|
+
limit=limit,
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
recalls = []
|
|
339
|
+
for r in results:
|
|
340
|
+
entry = r.entry
|
|
341
|
+
recalls.append(MemoryRecall(
|
|
342
|
+
id=entry.id,
|
|
343
|
+
session_id=session_id,
|
|
344
|
+
invocation_id=entry.invocation_id or "",
|
|
345
|
+
content=entry.content,
|
|
346
|
+
importance=entry.metadata.get("importance", 0.5),
|
|
347
|
+
tags=entry.metadata.get("tags", []),
|
|
348
|
+
))
|
|
349
|
+
|
|
350
|
+
return recalls
|
|
351
|
+
|
|
352
|
+
async def add_recall(
|
|
353
|
+
self,
|
|
354
|
+
session_id: str,
|
|
355
|
+
invocation_id: str,
|
|
356
|
+
content: str,
|
|
357
|
+
importance: float = 0.5,
|
|
358
|
+
tags: list[str] | None = None,
|
|
359
|
+
) -> str:
|
|
360
|
+
"""Add a recall entry.
|
|
361
|
+
|
|
362
|
+
Returns:
|
|
363
|
+
Recall ID
|
|
364
|
+
"""
|
|
365
|
+
recall_id = generate_id("recall")
|
|
366
|
+
|
|
367
|
+
await self.add(
|
|
368
|
+
content=content,
|
|
369
|
+
session_id=session_id,
|
|
370
|
+
invocation_id=invocation_id,
|
|
371
|
+
metadata={
|
|
372
|
+
"type": "recall",
|
|
373
|
+
"recall_id": recall_id,
|
|
374
|
+
"importance": importance,
|
|
375
|
+
"tags": tags or [],
|
|
376
|
+
},
|
|
377
|
+
trigger=WriteTrigger.MANUAL,
|
|
378
|
+
)
|
|
379
|
+
|
|
380
|
+
return recall_id
|
|
381
|
+
|
|
382
|
+
async def update_summary(
|
|
383
|
+
self,
|
|
384
|
+
session_id: str,
|
|
385
|
+
content: str,
|
|
386
|
+
last_invocation_id: str,
|
|
387
|
+
) -> None:
|
|
388
|
+
"""Update session summary."""
|
|
389
|
+
# Delete old summary
|
|
390
|
+
for store in self.stores.values():
|
|
391
|
+
if hasattr(store, 'delete_by_filter'):
|
|
392
|
+
await store.delete_by_filter({
|
|
393
|
+
"session_id": session_id,
|
|
394
|
+
"type": "summary",
|
|
395
|
+
})
|
|
396
|
+
|
|
397
|
+
# Add new summary
|
|
398
|
+
await self.add(
|
|
399
|
+
content=content,
|
|
400
|
+
session_id=session_id,
|
|
401
|
+
invocation_id=last_invocation_id,
|
|
402
|
+
metadata={"type": "summary"},
|
|
403
|
+
trigger=WriteTrigger.MANUAL,
|
|
404
|
+
)
|
|
405
|
+
|
|
406
|
+
async def delete_by_invocation(self, invocation_id: str) -> int:
|
|
407
|
+
"""Delete all memory entries for an invocation (for revert).
|
|
408
|
+
|
|
409
|
+
Returns:
|
|
410
|
+
Number of entries deleted
|
|
411
|
+
"""
|
|
412
|
+
count = 0
|
|
413
|
+
for store in self.stores.values():
|
|
414
|
+
if hasattr(store, 'delete_by_filter'):
|
|
415
|
+
deleted = await store.delete_by_filter({"invocation_id": invocation_id})
|
|
416
|
+
count += deleted if isinstance(deleted, int) else 0
|
|
417
|
+
return count
|
|
418
|
+
|
|
419
|
+
async def on_subagent_complete(
|
|
420
|
+
self,
|
|
421
|
+
sub_inv_id: str,
|
|
422
|
+
parent_inv_id: str,
|
|
423
|
+
merge_mode: str,
|
|
424
|
+
) -> None:
|
|
425
|
+
"""Handle SubAgent completion - merge memory based on mode.
|
|
426
|
+
|
|
427
|
+
Args:
|
|
428
|
+
sub_inv_id: SubAgent's invocation ID
|
|
429
|
+
parent_inv_id: Parent's invocation ID
|
|
430
|
+
merge_mode: "merge", "summarize", or "discard"
|
|
431
|
+
"""
|
|
432
|
+
if merge_mode == "merge":
|
|
433
|
+
# Move all recalls from sub to parent
|
|
434
|
+
sub_recalls = await self.get_recalls(
|
|
435
|
+
session_id="", # Will be filtered by invocation
|
|
436
|
+
invocation_ids=[sub_inv_id],
|
|
437
|
+
limit=100,
|
|
438
|
+
)
|
|
439
|
+
for recall in sub_recalls:
|
|
440
|
+
await self.add_recall(
|
|
441
|
+
session_id=recall.session_id,
|
|
442
|
+
invocation_id=parent_inv_id,
|
|
443
|
+
content=recall.content,
|
|
444
|
+
importance=recall.importance,
|
|
445
|
+
tags=recall.tags,
|
|
446
|
+
)
|
|
447
|
+
|
|
448
|
+
elif merge_mode == "summarize":
|
|
449
|
+
# Create a summary recall in parent
|
|
450
|
+
sub_recalls = await self.get_recalls(
|
|
451
|
+
session_id="",
|
|
452
|
+
invocation_ids=[sub_inv_id],
|
|
453
|
+
limit=100,
|
|
454
|
+
)
|
|
455
|
+
if sub_recalls:
|
|
456
|
+
combined = "\n".join([r.content for r in sub_recalls])
|
|
457
|
+
await self.add_recall(
|
|
458
|
+
session_id=sub_recalls[0].session_id,
|
|
459
|
+
invocation_id=parent_inv_id,
|
|
460
|
+
content=f"[SubAgent result] {combined[:500]}...",
|
|
461
|
+
importance=0.7,
|
|
462
|
+
tags=["subagent_result"],
|
|
463
|
+
)
|
|
464
|
+
|
|
465
|
+
# "discard" mode: do nothing, sub's memory stays isolated
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""Memory processors for filtering and transformation."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any, Protocol
|
|
7
|
+
|
|
8
|
+
from .store import MemoryEntry, ScoredEntry
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class WriteDecision(Enum):
    """Decision from write filter."""
    # Drop the whole batch; nothing gets written.
    SKIP = "skip"
    # Keep the batch unchanged.
    PASS = "pass"
    # Replace the batch with WriteResult.entries.
    TRANSFORM = "transform"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class WriteResult:
    """Result from write filter."""
    # What to do with the candidate batch.
    decision: WriteDecision
    # Replacement entries; only meaningful when decision is TRANSFORM.
    entries: list[MemoryEntry] | None = None
    # Optional human-readable explanation (e.g. why the batch was skipped).
    reason: str | None = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class WriteContext:
    """Context for write operations."""
    # WriteTrigger value; typed Any here — presumably to avoid importing
    # from the manager module (TODO confirm, consider TYPE_CHECKING import).
    trigger: Any  # WriteTrigger
    # Session the write belongs to, if any.
    session_id: str | None = None
    # Invocation the write belongs to, if any.
    invocation_id: str | None = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class ProcessContext:
    """Context for processing operations."""
    # Session the entries being processed belong to, if any.
    session_id: str | None = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class ReadContext:
    """Context for read operations."""
    # Session the read is scoped to, if any.
    session_id: str | None = None
    # Max number of results the caller ultimately wants.
    limit: int = 10
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class WriteFilter(Protocol):
    """Write filter protocol.

    Implementations decide whether a batch of entries is written, skipped,
    or transformed before storage (see DeduplicationFilter / LengthFilter
    in this module).
    """

    async def filter(
        self,
        entries: list[MemoryEntry],
        context: WriteContext,
    ) -> WriteResult:
        """Filter entries before writing.

        Returns WriteResult with decision (and replacement entries when
        the decision is TRANSFORM).
        """
        ...
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class MemoryProcessor(Protocol):
    """Memory processor protocol.

    Transforms a batch of entries after filtering and before storage
    (see TruncationProcessor in this module).
    """

    async def process(
        self,
        entries: list[MemoryEntry],
        context: ProcessContext,
    ) -> list[MemoryEntry]:
        """Process entries, return transformed list."""
        ...
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ReadPostProcessor(Protocol):
    """Read post-processor protocol.

    Re-ranks, filters, or otherwise adjusts merged search results before
    they are returned to the caller.
    """

    async def process(
        self,
        results: list[ScoredEntry],
        query: str,
        context: ReadContext,
    ) -> list[ScoredEntry]:
        """Post-process search results."""
        ...
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class DeduplicationFilter:
    """Skip writes whose content is already present in the store."""

    def __init__(
        self,
        store: Any,  # MemoryStore
        similarity_threshold: float = 0.9,
    ):
        self.store = store
        self.similarity_threshold = similarity_threshold

    async def filter(
        self,
        entries: list[MemoryEntry],
        context: WriteContext,
    ) -> WriteResult:
        """Return SKIP when any entry closely matches an existing one."""
        for entry in entries:
            # Probe the store with the entry's own content; the top hit's
            # score approximates similarity to what is already stored.
            session_filter = (
                {"session_id": entry.session_id} if entry.session_id else None
            )
            matches = await self.store.search(
                query=entry.content,
                filter=session_filter,
                limit=1,
            )

            if matches and matches[0].score > self.similarity_threshold:
                return WriteResult(
                    decision=WriteDecision.SKIP,
                    reason=f"Duplicate found: {matches[0].entry.id}",
                )

        return WriteResult(decision=WriteDecision.PASS)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class LengthFilter:
    """Skip batches containing content outside the allowed length range."""

    def __init__(self, min_length: int = 10, max_length: int = 10000):
        self.min_length = min_length
        self.max_length = max_length

    async def filter(
        self,
        entries: list[MemoryEntry],
        context: WriteContext,
    ) -> WriteResult:
        """Return SKIP if any entry's content is too short or too long."""
        for entry in entries:
            size = len(entry.content)
            if size < self.min_length:
                return WriteResult(
                    decision=WriteDecision.SKIP,
                    reason=f"Content too short: {size} < {self.min_length}",
                )
            if size > self.max_length:
                return WriteResult(
                    decision=WriteDecision.SKIP,
                    reason=f"Content too long: {size} > {self.max_length}",
                )

        return WriteResult(decision=WriteDecision.PASS)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class TruncationProcessor:
    """Truncate over-long content before storage."""

    def __init__(self, max_length: int = 5000):
        self.max_length = max_length

    async def process(
        self,
        entries: list[MemoryEntry],
        context: ProcessContext,
    ) -> list[MemoryEntry]:
        """Return the entries with any over-long content clipped and flagged."""
        return [self._clip(entry) for entry in entries]

    def _clip(self, entry: MemoryEntry) -> MemoryEntry:
        # Short entries pass through untouched.
        if len(entry.content) <= self.max_length:
            return entry
        # Rebuild the entry rather than mutating it in place, clipping the
        # content and recording the truncation in metadata.
        return MemoryEntry(
            id=entry.id,
            content=entry.content[:self.max_length] + "... (truncated)",
            session_id=entry.session_id,
            invocation_id=entry.invocation_id,
            created_at=entry.created_at,
            metadata={**entry.metadata, "truncated": True},
        )
|