aury-agent 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149) hide show
  1. aury/__init__.py +2 -0
  2. aury/agents/__init__.py +55 -0
  3. aury/agents/a2a/__init__.py +168 -0
  4. aury/agents/backends/__init__.py +196 -0
  5. aury/agents/backends/artifact/__init__.py +9 -0
  6. aury/agents/backends/artifact/memory.py +130 -0
  7. aury/agents/backends/artifact/types.py +133 -0
  8. aury/agents/backends/code/__init__.py +65 -0
  9. aury/agents/backends/file/__init__.py +11 -0
  10. aury/agents/backends/file/local.py +66 -0
  11. aury/agents/backends/file/types.py +40 -0
  12. aury/agents/backends/invocation/__init__.py +8 -0
  13. aury/agents/backends/invocation/memory.py +81 -0
  14. aury/agents/backends/invocation/types.py +110 -0
  15. aury/agents/backends/memory/__init__.py +8 -0
  16. aury/agents/backends/memory/memory.py +179 -0
  17. aury/agents/backends/memory/types.py +136 -0
  18. aury/agents/backends/message/__init__.py +9 -0
  19. aury/agents/backends/message/memory.py +122 -0
  20. aury/agents/backends/message/types.py +124 -0
  21. aury/agents/backends/sandbox.py +275 -0
  22. aury/agents/backends/session/__init__.py +8 -0
  23. aury/agents/backends/session/memory.py +93 -0
  24. aury/agents/backends/session/types.py +124 -0
  25. aury/agents/backends/shell/__init__.py +11 -0
  26. aury/agents/backends/shell/local.py +110 -0
  27. aury/agents/backends/shell/types.py +55 -0
  28. aury/agents/backends/shell.py +209 -0
  29. aury/agents/backends/snapshot/__init__.py +19 -0
  30. aury/agents/backends/snapshot/git.py +95 -0
  31. aury/agents/backends/snapshot/hybrid.py +125 -0
  32. aury/agents/backends/snapshot/memory.py +86 -0
  33. aury/agents/backends/snapshot/types.py +59 -0
  34. aury/agents/backends/state/__init__.py +29 -0
  35. aury/agents/backends/state/composite.py +49 -0
  36. aury/agents/backends/state/file.py +57 -0
  37. aury/agents/backends/state/memory.py +52 -0
  38. aury/agents/backends/state/sqlite.py +262 -0
  39. aury/agents/backends/state/types.py +178 -0
  40. aury/agents/backends/subagent/__init__.py +165 -0
  41. aury/agents/cli/__init__.py +41 -0
  42. aury/agents/cli/chat.py +239 -0
  43. aury/agents/cli/config.py +236 -0
  44. aury/agents/cli/extensions.py +460 -0
  45. aury/agents/cli/main.py +189 -0
  46. aury/agents/cli/session.py +337 -0
  47. aury/agents/cli/workflow.py +276 -0
  48. aury/agents/context_providers/__init__.py +66 -0
  49. aury/agents/context_providers/artifact.py +299 -0
  50. aury/agents/context_providers/base.py +177 -0
  51. aury/agents/context_providers/memory.py +70 -0
  52. aury/agents/context_providers/message.py +130 -0
  53. aury/agents/context_providers/skill.py +50 -0
  54. aury/agents/context_providers/subagent.py +46 -0
  55. aury/agents/context_providers/tool.py +68 -0
  56. aury/agents/core/__init__.py +83 -0
  57. aury/agents/core/base.py +573 -0
  58. aury/agents/core/context.py +797 -0
  59. aury/agents/core/context_builder.py +303 -0
  60. aury/agents/core/event_bus/__init__.py +15 -0
  61. aury/agents/core/event_bus/bus.py +203 -0
  62. aury/agents/core/factory.py +169 -0
  63. aury/agents/core/isolator.py +97 -0
  64. aury/agents/core/logging.py +95 -0
  65. aury/agents/core/parallel.py +194 -0
  66. aury/agents/core/runner.py +139 -0
  67. aury/agents/core/services/__init__.py +5 -0
  68. aury/agents/core/services/file_session.py +144 -0
  69. aury/agents/core/services/message.py +53 -0
  70. aury/agents/core/services/session.py +53 -0
  71. aury/agents/core/signals.py +109 -0
  72. aury/agents/core/state.py +363 -0
  73. aury/agents/core/types/__init__.py +107 -0
  74. aury/agents/core/types/action.py +176 -0
  75. aury/agents/core/types/artifact.py +135 -0
  76. aury/agents/core/types/block.py +736 -0
  77. aury/agents/core/types/message.py +350 -0
  78. aury/agents/core/types/recall.py +144 -0
  79. aury/agents/core/types/session.py +257 -0
  80. aury/agents/core/types/subagent.py +154 -0
  81. aury/agents/core/types/tool.py +205 -0
  82. aury/agents/eval/__init__.py +331 -0
  83. aury/agents/hitl/__init__.py +57 -0
  84. aury/agents/hitl/ask_user.py +242 -0
  85. aury/agents/hitl/compaction.py +230 -0
  86. aury/agents/hitl/exceptions.py +87 -0
  87. aury/agents/hitl/permission.py +617 -0
  88. aury/agents/hitl/revert.py +216 -0
  89. aury/agents/llm/__init__.py +31 -0
  90. aury/agents/llm/adapter.py +367 -0
  91. aury/agents/llm/openai.py +294 -0
  92. aury/agents/llm/provider.py +476 -0
  93. aury/agents/mcp/__init__.py +153 -0
  94. aury/agents/memory/__init__.py +46 -0
  95. aury/agents/memory/compaction.py +394 -0
  96. aury/agents/memory/manager.py +465 -0
  97. aury/agents/memory/processor.py +177 -0
  98. aury/agents/memory/store.py +187 -0
  99. aury/agents/memory/types.py +137 -0
  100. aury/agents/messages/__init__.py +40 -0
  101. aury/agents/messages/config.py +47 -0
  102. aury/agents/messages/raw_store.py +224 -0
  103. aury/agents/messages/store.py +118 -0
  104. aury/agents/messages/types.py +88 -0
  105. aury/agents/middleware/__init__.py +31 -0
  106. aury/agents/middleware/base.py +341 -0
  107. aury/agents/middleware/chain.py +342 -0
  108. aury/agents/middleware/message.py +129 -0
  109. aury/agents/middleware/message_container.py +126 -0
  110. aury/agents/middleware/raw_message.py +153 -0
  111. aury/agents/middleware/truncation.py +139 -0
  112. aury/agents/middleware/types.py +81 -0
  113. aury/agents/plugin.py +162 -0
  114. aury/agents/react/__init__.py +4 -0
  115. aury/agents/react/agent.py +1923 -0
  116. aury/agents/sandbox/__init__.py +23 -0
  117. aury/agents/sandbox/local.py +239 -0
  118. aury/agents/sandbox/remote.py +200 -0
  119. aury/agents/sandbox/types.py +115 -0
  120. aury/agents/skill/__init__.py +16 -0
  121. aury/agents/skill/loader.py +180 -0
  122. aury/agents/skill/types.py +83 -0
  123. aury/agents/tool/__init__.py +39 -0
  124. aury/agents/tool/builtin/__init__.py +23 -0
  125. aury/agents/tool/builtin/ask_user.py +155 -0
  126. aury/agents/tool/builtin/bash.py +107 -0
  127. aury/agents/tool/builtin/delegate.py +726 -0
  128. aury/agents/tool/builtin/edit.py +121 -0
  129. aury/agents/tool/builtin/plan.py +277 -0
  130. aury/agents/tool/builtin/read.py +91 -0
  131. aury/agents/tool/builtin/thinking.py +111 -0
  132. aury/agents/tool/builtin/yield_result.py +130 -0
  133. aury/agents/tool/decorator.py +252 -0
  134. aury/agents/tool/set.py +204 -0
  135. aury/agents/usage/__init__.py +12 -0
  136. aury/agents/usage/tracker.py +236 -0
  137. aury/agents/workflow/__init__.py +85 -0
  138. aury/agents/workflow/adapter.py +268 -0
  139. aury/agents/workflow/dag.py +116 -0
  140. aury/agents/workflow/dsl.py +575 -0
  141. aury/agents/workflow/executor.py +659 -0
  142. aury/agents/workflow/expression.py +136 -0
  143. aury/agents/workflow/parser.py +182 -0
  144. aury/agents/workflow/state.py +145 -0
  145. aury/agents/workflow/types.py +86 -0
  146. aury_agent-0.0.4.dist-info/METADATA +90 -0
  147. aury_agent-0.0.4.dist-info/RECORD +149 -0
  148. aury_agent-0.0.4.dist-info/WHEEL +4 -0
  149. aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,465 @@
1
+ """Memory manager for unified memory operations."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, field
5
+ from enum import Enum
6
+ from typing import Any
7
+ from uuid import uuid4
8
+
9
+ from ..core.event_bus import Bus, Events
10
+ from ..core.logging import memory_logger as logger
11
+ from ..core.types.session import generate_id
12
+ from .types import MemorySummary, MemoryRecall, MemoryContext
13
+ from .store import MemoryEntry, ScoredEntry, MemoryStore
14
+ from .processor import WriteFilter, WriteDecision, WriteResult, MemoryProcessor, ProcessContext, WriteContext, ReadContext
15
+
16
+
17
class WriteTrigger(Enum):
    """Reason a memory write happened.

    ``MemoryManager.add`` records the member's value in each new entry's
    metadata under the ``"trigger"`` key.
    """

    # Direct call to ``MemoryManager.add`` (its default trigger).
    MANUAL = "manual"
    # Automatic write when the bus reports that an invocation ended.
    INVOCATION_END = "invocation_end"
    # Messages ejected during context compression (``on_compress``).
    COMPRESS = "compress"
    # Event-driven write; no handler in this module subscribes for it.
    EVENT = "event"
23
+
24
+
25
@dataclass
class RetrievalSource:
    """One store that ``MemoryManager.search`` consults, plus how to weigh it."""

    # Key into ``MemoryManager.stores`` naming the store to query.
    store_name: str
    # Multiplier applied to every score this source returns.
    weight: float = 1.0
    # Extra filter overlaid on the caller's filter (these keys take precedence).
    filter: dict[str, Any] | None = None
    # Maximum number of results requested from this store.
    limit: int = 10
32
+
33
+
34
class MemoryManager:
    """Unified memory manager.

    Handles:
    - Multiple memory stores (every write goes to all of them)
    - Write pipeline (filter, process, store)
    - Read pipeline (search, merge, post-process)
    - Auto-triggers from bus events

    Args:
        stores: Named memory stores.
        retrieval_config: Sources consulted by ``search``. Defaults to one
            source per store with weight 1.0 and limit 10.
        write_filters: Filters run in order before storing; any of them may
            skip the write or replace the entries.
        write_processors: Processors run in order after the filters.
        read_processors: Post-processors run on merged search results.
        auto_triggers: Triggers that cause automatic writes. ``None`` selects
            ``{WriteTrigger.INVOCATION_END}``; an explicit empty set disables
            automatic writes.
        bus: Optional event bus used for auto-triggers and MEMORY_* events.
    """

    def __init__(
        self,
        stores: dict[str, MemoryStore],
        retrieval_config: list[RetrievalSource] | None = None,
        write_filters: list[WriteFilter] | None = None,
        write_processors: list[MemoryProcessor] | None = None,
        read_processors: list[Any] | None = None,
        auto_triggers: set[WriteTrigger] | None = None,
        bus: Bus | None = None,
    ):
        self.stores = stores
        self.retrieval_config = retrieval_config or [
            RetrievalSource(store_name=name, limit=10)
            for name in stores
        ]
        self.write_filters = write_filters or []
        self.write_processors = write_processors or []
        self.read_processors = read_processors or []
        # ``is None`` rather than ``or``: an explicit empty set means "no
        # auto-triggers" and must not be silently replaced by the default.
        self.auto_triggers = (
            auto_triggers
            if auto_triggers is not None
            else {WriteTrigger.INVOCATION_END}
        )
        self.bus = bus

        # Register bus handlers. Compare with None so that a bus object which
        # happens to be falsy is still wired up.
        if bus is not None:
            self._register_triggers()

    def _register_triggers(self) -> None:
        """Subscribe bus handlers for the enabled auto-triggers."""
        if self.bus is None:
            return
        if WriteTrigger.INVOCATION_END in self.auto_triggers:
            self.bus.subscribe(Events.INVOCATION_END, self._on_invocation_end)

    async def _on_invocation_end(self, event_type: str, payload: dict[str, Any]) -> None:
        """Store the invocation's conversation when the bus reports its end.

        Reads ``messages``, ``session_id`` and ``invocation_id`` from
        ``payload``; does nothing when there are no messages.
        """
        messages = payload.get("messages", [])
        if not messages:
            return

        content = self._format_messages(messages)

        await self.add(
            content=content,
            session_id=payload.get("session_id"),
            invocation_id=payload.get("invocation_id"),
            metadata={"type": "conversation"},
            trigger=WriteTrigger.INVOCATION_END,
        )

    def _format_messages(self, messages: list[dict[str, Any]]) -> str:
        """Render chat messages as ``[role]: text`` blocks separated by blank lines.

        List-valued content is reduced to its ``type == "text"`` parts joined
        by spaces. Missing or ``None`` content renders as an empty string.
        """
        parts = []
        for msg in messages:
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            if content is None:
                # Avoid storing the literal text "None".
                content = ""
            elif isinstance(content, list):
                # Handle multi-part content; a None text part must not break join.
                text_parts = [
                    p.get("text") or "" for p in content
                    if isinstance(p, dict) and p.get("type") == "text"
                ]
                content = " ".join(text_parts)
            parts.append(f"[{role}]: {content}")
        return "\n\n".join(parts)

    async def add(
        self,
        content: str,
        session_id: str | None = None,
        invocation_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        trigger: WriteTrigger = WriteTrigger.MANUAL,
    ) -> str | None:
        """Add content to memory.

        Runs through the write pipeline:
        1. Filters (can skip/transform)
        2. Processors (transform)
        3. Store in all stores

        Returns:
            ID of the first stored entry, or None if the filters or
            processors left nothing to store (no MEMORY_ADD event then).
        """
        logger.debug(
            "Adding to memory",
            extra={"trigger": trigger.value, "session_id": session_id}
        )
        entries: list[MemoryEntry] = [
            MemoryEntry(
                id=str(uuid4()),
                content=content,
                session_id=session_id,
                invocation_id=invocation_id,
                metadata={**(metadata or {}), "trigger": trigger.value},
            )
        ]
        write_context = WriteContext(
            trigger=trigger,
            session_id=session_id,
            invocation_id=invocation_id,
        )

        # 1. Apply filters (named write_filter to avoid shadowing the builtin).
        for write_filter in self.write_filters:
            result = await write_filter.filter(entries, write_context)
            if result.decision == WriteDecision.SKIP:
                return None
            if result.decision == WriteDecision.TRANSFORM:
                entries = result.entries or []
                if not entries:
                    return None

        # 2. Apply processors
        process_context = ProcessContext(session_id=session_id)
        for processor in self.write_processors:
            entries = await processor.process(entries, process_context)

        # A processor may drop everything; treat that like a filtered-out write.
        if not entries:
            return None

        # 3. Store in all stores
        for item in entries:
            for store in self.stores.values():
                await store.add(item)

        if self.bus is not None:
            await self.bus.publish(Events.MEMORY_ADD, {
                "entry_id": entries[0].id,
                "count": len(entries),
            })

        return entries[0].id

    async def search(
        self,
        query: str,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
    ) -> list[ScoredEntry]:
        """Search all configured sources and merge the results.

        Each source is queried with the caller's ``filter`` overlaid by the
        source's own filter (source keys win) and with the source's own
        ``limit``; its scores are multiplied by the source weight. Results
        from every source, including several sources sharing one store, are
        then sorted by score. They are de-duplicated by entry ID, keeping the
        highest score, passed through the read processors, and cut to
        ``limit``.
        """
        # 1. Search all sources. Results go into one flat list so that two
        #    sources sharing a store name do not overwrite each other.
        scored: list[ScoredEntry] = []
        for source in self.retrieval_config:
            if source.store_name not in self.stores:
                continue

            store = self.stores[source.store_name]
            merged_filter = {**(filter or {}), **(source.filter or {})}

            results = await store.search(
                query=query,
                filter=merged_filter,
                limit=source.limit,
            )

            # Apply source weight (in place on the objects the store returned).
            for r in results:
                r.score *= source.weight

            scored.extend(results)

        # 2. Merge: best score first, then keep the first occurrence of each ID.
        scored.sort(key=lambda item: item.score, reverse=True)
        seen_ids: set[str] = set()
        merged: list[ScoredEntry] = []
        for item in scored:
            if item.entry.id in seen_ids:
                continue
            seen_ids.add(item.entry.id)
            merged.append(item)

        # 3. Apply read processors
        read_context = ReadContext(limit=limit)
        for processor in self.read_processors:
            merged = await processor.process(merged, query, read_context)

        top = merged[:limit]

        if self.bus is not None:
            await self.bus.publish(Events.MEMORY_SEARCH, {
                "query": query[:100],
                "result_count": len(top),
            })

        return top

    async def revert(
        self,
        session_id: str,
        after_invocation_id: str,
    ) -> list[str]:
        """Revert memory entries after the given invocation in every store.

        Returns:
            Concatenation of the IDs each store's ``revert`` reported.
        """
        deleted = []

        for store in self.stores.values():
            ids = await store.revert(session_id, after_invocation_id)
            deleted.extend(ids)

        return deleted

    async def on_compress(
        self,
        session_id: str,
        invocation_id: str,
        ejected_messages: list[dict[str, Any]],
    ) -> str | None:
        """Save messages ejected by compression, if COMPRESS is an auto-trigger.

        Returns:
            The new entry ID, or None when the trigger is disabled or the
            write was filtered out.
        """
        if WriteTrigger.COMPRESS not in self.auto_triggers:
            return None

        content = self._format_messages(ejected_messages)

        return await self.add(
            content=content,
            session_id=session_id,
            invocation_id=invocation_id,
            metadata={"type": "compressed"},
            trigger=WriteTrigger.COMPRESS,
        )

    # ========== Summary & Recall API ==========

    async def get_context(
        self,
        session_id: str,
        invocation_ids: list[str] | None = None,
        recall_limit: int = 10,
    ) -> MemoryContext:
        """Get memory context for LLM.

        Args:
            session_id: Session to get context for
            invocation_ids: Filter recalls to these invocations (for isolation)
            recall_limit: Max number of recalls to return

        Returns:
            MemoryContext with summary and recalls
        """
        summary = await self.get_summary(session_id)

        # Get recalls, filtered by invocation chain if provided
        recalls = await self.get_recalls(
            session_id=session_id,
            invocation_ids=invocation_ids,
            limit=recall_limit,
        )

        return MemoryContext(summary=summary, recalls=recalls)

    async def get_summary(self, session_id: str) -> MemorySummary | None:
        """Get the session summary.

        Stores exposing ``get_summary`` are asked in order and the first
        non-None answer wins. If none has one, fall back to searching for an
        entry whose metadata ``type`` is ``"summary"`` (as written by
        ``update_summary``).
        """
        for store in self.stores.values():
            getter = getattr(store, "get_summary", None)
            if getter is None:
                continue
            summary = await getter(session_id)
            # A store that has no summary must not short-circuit the fallback.
            if summary is not None:
                return summary

        # Fallback: search for summary entry
        results = await self.search(
            query="conversation summary",
            filter={"session_id": session_id, "type": "summary"},
            limit=1,
        )

        if results:
            entry = results[0].entry
            return MemorySummary(
                session_id=session_id,
                content=entry.content,
                last_invocation_id=entry.invocation_id or "",
            )

        return None

    async def get_recalls(
        self,
        session_id: str,
        invocation_ids: list[str] | None = None,
        limit: int = 10,
    ) -> list[MemoryRecall]:
        """Get recall entries, optionally restricted to some invocations.

        Args:
            session_id: Session to read from. An empty string omits the
                session filter, so recalls are matched by the other criteria
                only.
            invocation_ids: If given, used as the ``invocation_id`` filter.
            limit: Maximum number of recalls returned.

        Returns:
            Recalls whose ``id`` is the ID ``add_recall`` returned (falling
            back to the entry ID for entries without a ``recall_id``) and whose
            ``session_id`` is the entry's own session when it has one.
        """
        filter_dict: dict[str, Any] = {"type": "recall"}
        if session_id:
            filter_dict["session_id"] = session_id
        if invocation_ids:
            # NOTE(review): passes a list; assumes stores treat a list value
            # as "any of" - confirm against the MemoryStore implementations.
            filter_dict["invocation_id"] = invocation_ids

        results = await self.search(
            query="key points recalls",
            filter=filter_dict,
            limit=limit,
        )

        recalls = []
        for r in results:
            entry = r.entry
            recalls.append(MemoryRecall(
                id=entry.metadata.get("recall_id", entry.id),
                session_id=entry.session_id or session_id,
                invocation_id=entry.invocation_id or "",
                content=entry.content,
                importance=entry.metadata.get("importance", 0.5),
                tags=entry.metadata.get("tags", []),
            ))

        return recalls

    async def add_recall(
        self,
        session_id: str,
        invocation_id: str,
        content: str,
        importance: float = 0.5,
        tags: list[str] | None = None,
    ) -> str:
        """Add a recall entry.

        Returns:
            Recall ID (stored in the entry metadata as ``recall_id``).
        """
        recall_id = generate_id("recall")

        await self.add(
            content=content,
            session_id=session_id,
            invocation_id=invocation_id,
            metadata={
                "type": "recall",
                "recall_id": recall_id,
                "importance": importance,
                "tags": tags or [],
            },
            trigger=WriteTrigger.MANUAL,
        )

        return recall_id

    async def update_summary(
        self,
        session_id: str,
        content: str,
        last_invocation_id: str,
    ) -> None:
        """Replace the session summary.

        Old summary entries are deleted only from stores that expose
        ``delete_by_filter``; the new one is written through ``add``.
        """
        for store in self.stores.values():
            if hasattr(store, 'delete_by_filter'):
                await store.delete_by_filter({
                    "session_id": session_id,
                    "type": "summary",
                })

        await self.add(
            content=content,
            session_id=session_id,
            invocation_id=last_invocation_id,
            metadata={"type": "summary"},
            trigger=WriteTrigger.MANUAL,
        )

    async def delete_by_invocation(self, invocation_id: str) -> int:
        """Delete all memory entries for an invocation (for revert).

        Only stores exposing ``delete_by_filter`` participate; a store whose
        result is not an int contributes 0 to the count.

        Returns:
            Number of entries deleted
        """
        count = 0
        for store in self.stores.values():
            if hasattr(store, 'delete_by_filter'):
                deleted = await store.delete_by_filter({"invocation_id": invocation_id})
                count += deleted if isinstance(deleted, int) else 0
        return count

    async def on_subagent_complete(
        self,
        sub_inv_id: str,
        parent_inv_id: str,
        merge_mode: str,
    ) -> None:
        """Handle SubAgent completion - merge memory based on mode.

        Args:
            sub_inv_id: SubAgent's invocation ID
            parent_inv_id: Parent's invocation ID
            merge_mode: "merge" copies each sub recall to the parent,
                "summarize" adds one combined recall (first 500 chars),
                anything else (e.g. "discard") leaves memory untouched.
        """
        if merge_mode not in ("merge", "summarize"):
            # "discard" mode: do nothing, sub's memory stays isolated
            return

        # Empty session_id drops the session filter: recalls are looked up by
        # the sub-agent's invocation ID and keep their own session_id.
        sub_recalls = await self.get_recalls(
            session_id="",
            invocation_ids=[sub_inv_id],
            limit=100,
        )
        if not sub_recalls:
            return

        if merge_mode == "merge":
            for recall in sub_recalls:
                await self.add_recall(
                    session_id=recall.session_id,
                    invocation_id=parent_inv_id,
                    content=recall.content,
                    importance=recall.importance,
                    tags=recall.tags,
                )
            return

        # "summarize": one combined recall; mark truncation only if it happened.
        max_chars = 500
        combined = "\n".join(r.content for r in sub_recalls)
        snippet = combined[:max_chars]
        if len(combined) > max_chars:
            snippet += "..."
        await self.add_recall(
            session_id=sub_recalls[0].session_id,
            invocation_id=parent_inv_id,
            content=f"[SubAgent result] {snippet}",
            importance=0.7,
            tags=["subagent_result"],
        )
@@ -0,0 +1,177 @@
1
+ """Memory processors for filtering and transformation."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from enum import Enum
6
+ from typing import Any, Protocol
7
+
8
+ from .store import MemoryEntry, ScoredEntry
9
+
10
+
11
class WriteDecision(Enum):
    """Verdict a write filter returns for a batch of entries."""

    # Drop the write entirely.
    SKIP = "skip"
    # Keep the entries as they are.
    PASS = "pass"
    # Replace the entries with the ones carried in ``WriteResult.entries``.
    TRANSFORM = "transform"
16
+
17
+
18
@dataclass
class WriteResult:
    """Outcome of a write filter: the verdict plus optional payload."""

    # What to do with the write (skip / pass / transform).
    decision: WriteDecision
    # Replacement entries; only meaningful with ``WriteDecision.TRANSFORM``.
    entries: list[MemoryEntry] | None = None
    # Free-text explanation, e.g. why the write was skipped.
    reason: str | None = None
24
+
25
+
26
@dataclass
class WriteContext:
    """Information handed to write filters alongside the entries."""

    # The WriteTrigger that caused the write (typed Any to avoid an import cycle).
    trigger: Any
    # Session the write belongs to, if any.
    session_id: str | None = None
    # Invocation the write belongs to, if any.
    invocation_id: str | None = None
32
+
33
+
34
@dataclass
class ProcessContext:
    """Information handed to write processors alongside the entries."""

    # Session the entries belong to, if any.
    session_id: str | None = None
38
+
39
+
40
@dataclass
class ReadContext:
    """Information handed to read post-processors alongside search results."""

    # Session being searched, if any.
    session_id: str | None = None
    # Number of results the caller ultimately wants.
    limit: int = 10
45
+
46
+
47
class WriteFilter(Protocol):
    """Structural interface for filters run before entries are written."""

    async def filter(
        self,
        entries: list[MemoryEntry],
        context: WriteContext,
    ) -> WriteResult:
        """Inspect the pending entries and decide what to do with them.

        Implementations return a ``WriteResult`` whose ``decision`` says
        whether to skip, pass, or transform the write.
        """
60
+
61
+
62
class MemoryProcessor(Protocol):
    """Structural interface for transforms applied to entries before storage."""

    async def process(
        self,
        entries: list[MemoryEntry],
        context: ProcessContext,
    ) -> list[MemoryEntry]:
        """Return the (possibly rewritten) list of entries to store."""
72
+
73
+
74
class ReadPostProcessor(Protocol):
    """Structural interface for post-processing merged search results."""

    async def process(
        self,
        results: list[ScoredEntry],
        query: str,
        context: ReadContext,
    ) -> list[ScoredEntry]:
        """Return the (possibly re-ranked or filtered) search results."""
85
+
86
+
87
class DeduplicationFilter:
    """Write filter that drops entries already (nearly) present in a store.

    Each candidate entry is searched against ``store``, scoped to the entry's
    session when it has one. An entry whose best match scores strictly above
    ``similarity_threshold`` counts as a duplicate.

    Outcome:
        - no duplicates           -> ``PASS``
        - every entry a duplicate -> ``SKIP`` (reason names the first match)
        - some duplicates         -> ``TRANSFORM`` keeping only the new
          entries, so one duplicate no longer discards the whole batch.
    """

    def __init__(
        self,
        store: Any,  # MemoryStore
        similarity_threshold: float = 0.9,
    ):
        # Store queried for near-duplicates; must offer an async
        # ``search(query=..., filter=..., limit=...)`` whose results expose
        # ``.score`` and ``.entry.id``.
        self.store = store
        self.similarity_threshold = similarity_threshold

    async def _find_duplicate_id(self, entry: MemoryEntry) -> Any:
        """Return the ID of a stored entry duplicating ``entry``, else None."""
        results = await self.store.search(
            query=entry.content,
            filter={"session_id": entry.session_id} if entry.session_id else None,
            limit=1,
        )
        if results and results[0].score > self.similarity_threshold:
            return results[0].entry.id
        return None

    async def filter(
        self,
        entries: list[MemoryEntry],
        context: WriteContext,
    ) -> WriteResult:
        """Check each entry for duplicate content and drop the duplicates."""
        kept: list[MemoryEntry] = []
        duplicate_ids: list[Any] = []
        for entry in entries:
            dup_id = await self._find_duplicate_id(entry)
            if dup_id is None:
                kept.append(entry)
            else:
                duplicate_ids.append(dup_id)

        if not duplicate_ids:
            return WriteResult(decision=WriteDecision.PASS)

        if not kept:
            return WriteResult(
                decision=WriteDecision.SKIP,
                reason=f"Duplicate found: {duplicate_ids[0]}",
            )

        return WriteResult(
            decision=WriteDecision.TRANSFORM,
            entries=kept,
            reason=(
                f"Dropped {len(duplicate_ids)} duplicate(s): "
                + ", ".join(map(str, duplicate_ids))
            ),
        )
119
+
120
+
121
class LengthFilter:
    """Write filter that drops entries whose content length is out of bounds.

    Entries with ``min_length <= len(content) <= max_length`` are kept.

    Outcome:
        - all entries in bounds     -> ``PASS``
        - no entry in bounds        -> ``SKIP`` (reason describes the first
          offending entry)
        - some entries out of bounds -> ``TRANSFORM`` keeping only the entries
          in bounds, so one bad entry no longer discards the whole batch.
    """

    def __init__(self, min_length: int = 10, max_length: int = 10000):
        self.min_length = min_length
        self.max_length = max_length

    def _violation(self, entry: MemoryEntry) -> str | None:
        """Describe why ``entry`` is out of bounds, or None if it is fine."""
        size = len(entry.content)
        if size < self.min_length:
            return f"Content too short: {size} < {self.min_length}"
        if size > self.max_length:
            return f"Content too long: {size} > {self.max_length}"
        return None

    async def filter(
        self,
        entries: list[MemoryEntry],
        context: WriteContext,
    ) -> WriteResult:
        """Filter entries by content length."""
        kept: list[MemoryEntry] = []
        first_reason: str | None = None
        for entry in entries:
            reason = self._violation(entry)
            if reason is None:
                kept.append(entry)
            elif first_reason is None:
                first_reason = reason

        if first_reason is None:
            return WriteResult(decision=WriteDecision.PASS)

        if not kept:
            return WriteResult(decision=WriteDecision.SKIP, reason=first_reason)

        dropped = len(entries) - len(kept)
        return WriteResult(
            decision=WriteDecision.TRANSFORM,
            entries=kept,
            reason=f"Dropped {dropped} entr{'y' if dropped == 1 else 'ies'} out of length bounds; first: {first_reason}",
        )
147
+
148
+
149
class TruncationProcessor:
    """Processor that clips over-long entry content.

    An entry whose ``content`` is longer than ``max_length`` characters is
    replaced by a new ``MemoryEntry`` whose content is the first
    ``max_length`` characters plus the marker ``"... (truncated)"`` and whose
    metadata gains ``"truncated": True``. Other entries pass through as the
    same object. Order is preserved.
    """

    def __init__(self, max_length: int = 5000):
        # Characters kept before the truncation marker is appended.
        self.max_length = max_length

    def _clip(self, entry: MemoryEntry) -> MemoryEntry:
        """Return ``entry`` itself, or a truncated copy if it is too long."""
        if len(entry.content) <= self.max_length:
            return entry
        return MemoryEntry(
            id=entry.id,
            content=entry.content[:self.max_length] + "... (truncated)",
            session_id=entry.session_id,
            invocation_id=entry.invocation_id,
            created_at=entry.created_at,
            metadata={**entry.metadata, "truncated": True},
        )

    async def process(
        self,
        entries: list[MemoryEntry],
        context: ProcessContext,
    ) -> list[MemoryEntry]:
        """Return the entries with over-long content truncated."""
        return [self._clip(entry) for entry in entries]