aury-agent 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149) hide show
  1. aury/__init__.py +2 -0
  2. aury/agents/__init__.py +55 -0
  3. aury/agents/a2a/__init__.py +168 -0
  4. aury/agents/backends/__init__.py +196 -0
  5. aury/agents/backends/artifact/__init__.py +9 -0
  6. aury/agents/backends/artifact/memory.py +130 -0
  7. aury/agents/backends/artifact/types.py +133 -0
  8. aury/agents/backends/code/__init__.py +65 -0
  9. aury/agents/backends/file/__init__.py +11 -0
  10. aury/agents/backends/file/local.py +66 -0
  11. aury/agents/backends/file/types.py +40 -0
  12. aury/agents/backends/invocation/__init__.py +8 -0
  13. aury/agents/backends/invocation/memory.py +81 -0
  14. aury/agents/backends/invocation/types.py +110 -0
  15. aury/agents/backends/memory/__init__.py +8 -0
  16. aury/agents/backends/memory/memory.py +179 -0
  17. aury/agents/backends/memory/types.py +136 -0
  18. aury/agents/backends/message/__init__.py +9 -0
  19. aury/agents/backends/message/memory.py +122 -0
  20. aury/agents/backends/message/types.py +124 -0
  21. aury/agents/backends/sandbox.py +275 -0
  22. aury/agents/backends/session/__init__.py +8 -0
  23. aury/agents/backends/session/memory.py +93 -0
  24. aury/agents/backends/session/types.py +124 -0
  25. aury/agents/backends/shell/__init__.py +11 -0
  26. aury/agents/backends/shell/local.py +110 -0
  27. aury/agents/backends/shell/types.py +55 -0
  28. aury/agents/backends/shell.py +209 -0
  29. aury/agents/backends/snapshot/__init__.py +19 -0
  30. aury/agents/backends/snapshot/git.py +95 -0
  31. aury/agents/backends/snapshot/hybrid.py +125 -0
  32. aury/agents/backends/snapshot/memory.py +86 -0
  33. aury/agents/backends/snapshot/types.py +59 -0
  34. aury/agents/backends/state/__init__.py +29 -0
  35. aury/agents/backends/state/composite.py +49 -0
  36. aury/agents/backends/state/file.py +57 -0
  37. aury/agents/backends/state/memory.py +52 -0
  38. aury/agents/backends/state/sqlite.py +262 -0
  39. aury/agents/backends/state/types.py +178 -0
  40. aury/agents/backends/subagent/__init__.py +165 -0
  41. aury/agents/cli/__init__.py +41 -0
  42. aury/agents/cli/chat.py +239 -0
  43. aury/agents/cli/config.py +236 -0
  44. aury/agents/cli/extensions.py +460 -0
  45. aury/agents/cli/main.py +189 -0
  46. aury/agents/cli/session.py +337 -0
  47. aury/agents/cli/workflow.py +276 -0
  48. aury/agents/context_providers/__init__.py +66 -0
  49. aury/agents/context_providers/artifact.py +299 -0
  50. aury/agents/context_providers/base.py +177 -0
  51. aury/agents/context_providers/memory.py +70 -0
  52. aury/agents/context_providers/message.py +130 -0
  53. aury/agents/context_providers/skill.py +50 -0
  54. aury/agents/context_providers/subagent.py +46 -0
  55. aury/agents/context_providers/tool.py +68 -0
  56. aury/agents/core/__init__.py +83 -0
  57. aury/agents/core/base.py +573 -0
  58. aury/agents/core/context.py +797 -0
  59. aury/agents/core/context_builder.py +303 -0
  60. aury/agents/core/event_bus/__init__.py +15 -0
  61. aury/agents/core/event_bus/bus.py +203 -0
  62. aury/agents/core/factory.py +169 -0
  63. aury/agents/core/isolator.py +97 -0
  64. aury/agents/core/logging.py +95 -0
  65. aury/agents/core/parallel.py +194 -0
  66. aury/agents/core/runner.py +139 -0
  67. aury/agents/core/services/__init__.py +5 -0
  68. aury/agents/core/services/file_session.py +144 -0
  69. aury/agents/core/services/message.py +53 -0
  70. aury/agents/core/services/session.py +53 -0
  71. aury/agents/core/signals.py +109 -0
  72. aury/agents/core/state.py +363 -0
  73. aury/agents/core/types/__init__.py +107 -0
  74. aury/agents/core/types/action.py +176 -0
  75. aury/agents/core/types/artifact.py +135 -0
  76. aury/agents/core/types/block.py +736 -0
  77. aury/agents/core/types/message.py +350 -0
  78. aury/agents/core/types/recall.py +144 -0
  79. aury/agents/core/types/session.py +257 -0
  80. aury/agents/core/types/subagent.py +154 -0
  81. aury/agents/core/types/tool.py +205 -0
  82. aury/agents/eval/__init__.py +331 -0
  83. aury/agents/hitl/__init__.py +57 -0
  84. aury/agents/hitl/ask_user.py +242 -0
  85. aury/agents/hitl/compaction.py +230 -0
  86. aury/agents/hitl/exceptions.py +87 -0
  87. aury/agents/hitl/permission.py +617 -0
  88. aury/agents/hitl/revert.py +216 -0
  89. aury/agents/llm/__init__.py +31 -0
  90. aury/agents/llm/adapter.py +367 -0
  91. aury/agents/llm/openai.py +294 -0
  92. aury/agents/llm/provider.py +476 -0
  93. aury/agents/mcp/__init__.py +153 -0
  94. aury/agents/memory/__init__.py +46 -0
  95. aury/agents/memory/compaction.py +394 -0
  96. aury/agents/memory/manager.py +465 -0
  97. aury/agents/memory/processor.py +177 -0
  98. aury/agents/memory/store.py +187 -0
  99. aury/agents/memory/types.py +137 -0
  100. aury/agents/messages/__init__.py +40 -0
  101. aury/agents/messages/config.py +47 -0
  102. aury/agents/messages/raw_store.py +224 -0
  103. aury/agents/messages/store.py +118 -0
  104. aury/agents/messages/types.py +88 -0
  105. aury/agents/middleware/__init__.py +31 -0
  106. aury/agents/middleware/base.py +341 -0
  107. aury/agents/middleware/chain.py +342 -0
  108. aury/agents/middleware/message.py +129 -0
  109. aury/agents/middleware/message_container.py +126 -0
  110. aury/agents/middleware/raw_message.py +153 -0
  111. aury/agents/middleware/truncation.py +139 -0
  112. aury/agents/middleware/types.py +81 -0
  113. aury/agents/plugin.py +162 -0
  114. aury/agents/react/__init__.py +4 -0
  115. aury/agents/react/agent.py +1923 -0
  116. aury/agents/sandbox/__init__.py +23 -0
  117. aury/agents/sandbox/local.py +239 -0
  118. aury/agents/sandbox/remote.py +200 -0
  119. aury/agents/sandbox/types.py +115 -0
  120. aury/agents/skill/__init__.py +16 -0
  121. aury/agents/skill/loader.py +180 -0
  122. aury/agents/skill/types.py +83 -0
  123. aury/agents/tool/__init__.py +39 -0
  124. aury/agents/tool/builtin/__init__.py +23 -0
  125. aury/agents/tool/builtin/ask_user.py +155 -0
  126. aury/agents/tool/builtin/bash.py +107 -0
  127. aury/agents/tool/builtin/delegate.py +726 -0
  128. aury/agents/tool/builtin/edit.py +121 -0
  129. aury/agents/tool/builtin/plan.py +277 -0
  130. aury/agents/tool/builtin/read.py +91 -0
  131. aury/agents/tool/builtin/thinking.py +111 -0
  132. aury/agents/tool/builtin/yield_result.py +130 -0
  133. aury/agents/tool/decorator.py +252 -0
  134. aury/agents/tool/set.py +204 -0
  135. aury/agents/usage/__init__.py +12 -0
  136. aury/agents/usage/tracker.py +236 -0
  137. aury/agents/workflow/__init__.py +85 -0
  138. aury/agents/workflow/adapter.py +268 -0
  139. aury/agents/workflow/dag.py +116 -0
  140. aury/agents/workflow/dsl.py +575 -0
  141. aury/agents/workflow/executor.py +659 -0
  142. aury/agents/workflow/expression.py +136 -0
  143. aury/agents/workflow/parser.py +182 -0
  144. aury/agents/workflow/state.py +145 -0
  145. aury/agents/workflow/types.py +86 -0
  146. aury_agent-0.0.4.dist-info/METADATA +90 -0
  147. aury_agent-0.0.4.dist-info/RECORD +149 -0
  148. aury_agent-0.0.4.dist-info/WHEEL +4 -0
  149. aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,187 @@
1
+ """Memory store protocol and implementations."""
2
+ from __future__ import annotations
3
+
4
+ import hashlib
5
+ from dataclasses import dataclass, field
6
+ from datetime import datetime
7
+ from typing import Any, Protocol, runtime_checkable
8
+
9
+
10
@dataclass
class MemoryEntry:
    """A single item held by a memory store.

    ``to_dict``/``from_dict`` round-trip the entry through a plain,
    JSON-friendly dict in which ``created_at`` is ISO-8601 text.
    """

    id: str
    content: str
    session_id: str | None = None
    invocation_id: str | None = None
    created_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; ``created_at`` becomes an ISO string."""
        payload: dict[str, Any] = dict(
            id=self.id,
            content=self.content,
            session_id=self.session_id,
            invocation_id=self.invocation_id,
        )
        payload["created_at"] = self.created_at.isoformat()
        payload["metadata"] = self.metadata
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> MemoryEntry:
        """Rebuild an entry from a dict shaped like ``to_dict`` output.

        ``id`` and ``content`` are required. ``created_at`` is parsed from
        ISO text when the key is present; otherwise the current local time
        is used.
        """
        entry_id = data["id"]
        text = data["content"]
        session = data.get("session_id")
        invocation = data.get("invocation_id")
        if "created_at" in data:
            created = datetime.fromisoformat(data["created_at"])
        else:
            created = datetime.now()
        return cls(
            id=entry_id,
            content=text,
            session_id=session,
            invocation_id=invocation,
            created_at=created,
            metadata=data.get("metadata", {}),
        )

    @property
    def content_hash(self) -> str:
        """First 16 hex chars of SHA-256 over the UTF-8 ``content`` (for dedup)."""
        digest = hashlib.sha256(self.content.encode("utf-8"))
        return digest.hexdigest()[:16]
46
+
47
+
48
@dataclass
class ScoredEntry:
    """A memory entry paired with the relevance score a search assigned it."""

    entry: MemoryEntry  # the matched entry
    score: float  # relevance; higher means more relevant
    source: str = field(default="default")  # label of the store/strategy that produced the hit
54
+
55
+
56
@runtime_checkable
class MemoryStore(Protocol):
    """Structural interface for memory stores.

    Every operation is a coroutine. Because the protocol is
    ``runtime_checkable``, ``isinstance(obj, MemoryStore)`` verifies only
    that the methods exist, not their signatures.
    """

    async def add(self, entry: MemoryEntry) -> str:
        """Store ``entry`` and return its ID."""

    async def search(
        self,
        query: str,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
    ) -> list[ScoredEntry]:
        """Return up to ``limit`` entries relevant to ``query``.

        ``filter`` optionally narrows which entries are eligible.
        """

    async def get(self, entry_id: str) -> MemoryEntry | None:
        """Return the entry stored under ``entry_id``, or ``None``."""

    async def remove(self, entry_id: str) -> None:
        """Delete the entry stored under ``entry_id``."""

    async def revert(
        self,
        session_id: str,
        after_invocation_id: str,
    ) -> list[str]:
        """Delete the session's entries that come after ``after_invocation_id``.

        Returns:
            The IDs of the deleted entries.
        """
91
+
92
+
93
class InMemoryStore:
    """Dict-backed memory store intended for tests.

    Entries are deduplicated by ``content_hash``. Adding content that is
    already stored returns the existing entry's ID instead of storing a
    second copy.
    """

    def __init__(self) -> None:
        self._entries: dict[str, MemoryEntry] = {}
        self._content_hashes: dict[str, str] = {}  # content hash -> entry_id

    def _forget_hash(self, entry: MemoryEntry) -> None:
        """Drop ``entry``'s hash mapping, but only if it still points at ``entry.id``."""
        content_hash = entry.content_hash
        if self._content_hashes.get(content_hash) == entry.id:
            del self._content_hashes[content_hash]

    async def add(self, entry: MemoryEntry) -> str:
        """Add ``entry`` unless identical content is already stored.

        Returns:
            ``entry.id`` for newly stored content, or the ID of the existing
            entry that already holds the same content.
        """
        content_hash = entry.content_hash

        existing_id = self._content_hashes.get(content_hash)
        if existing_id is not None:
            return existing_id

        # Re-using an ID with different content replaces the old entry. Drop
        # the old entry's hash so that content no longer resolves to this ID.
        previous = self._entries.get(entry.id)
        if previous is not None:
            self._forget_hash(previous)

        self._entries[entry.id] = entry
        self._content_hashes[content_hash] = entry.id
        return entry.id

    @staticmethod
    def _matches(entry: MemoryEntry, criteria: dict[str, Any]) -> bool:
        """Return True if every ``criteria`` key equals the entry's value.

        Each key is looked up as an attribute first. Only when that attribute
        is missing or ``None`` does the lookup fall back to ``entry.metadata``,
        so falsy-but-set attributes such as ``""`` are compared as-is.
        """
        for key, expected in criteria.items():
            actual = getattr(entry, key, None)
            if actual is None:
                actual = entry.metadata.get(key)
            if actual != expected:
                return False
        return True

    async def search(
        self,
        query: str,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
    ) -> list[ScoredEntry]:
        """Simple keyword search.

        Scoring works as follows:
        - An entry's score is the fraction of distinct query words (lower-cased,
          whitespace-split) that also appear as words in its content.
        - An entry with no word overlap that still contains the whole query as
          a substring scores 0.5.
        - Other entries are omitted.

        Args:
            query: Free-text query.
            filter: Optional ``{key: value}`` constraints; see ``_matches``.
            limit: Maximum number of results returned.

        Returns:
            Up to ``limit`` hits, highest score first.
        """
        query_lower = query.lower()
        query_words = set(query_lower.split())

        results: list[ScoredEntry] = []
        for entry in self._entries.values():
            if filter and not self._matches(entry, filter):
                continue

            content_lower = entry.content.lower()
            overlap = len(query_words & set(content_lower.split()))
            if overlap > 0:
                # overlap > 0 implies query_words is non-empty: no zero division.
                score = overlap / len(query_words)
            elif query_lower in content_lower:
                score = 0.5  # substring match
            else:
                continue
            results.append(ScoredEntry(entry=entry, score=score, source="memory"))

        results.sort(key=lambda hit: hit.score, reverse=True)
        return results[:limit]

    async def get(self, entry_id: str) -> MemoryEntry | None:
        """Return the entry stored under ``entry_id``, or ``None``."""
        return self._entries.get(entry_id)

    async def remove(self, entry_id: str) -> None:
        """Remove the entry stored under ``entry_id``; unknown IDs are ignored."""
        entry = self._entries.pop(entry_id, None)
        if entry is not None:
            self._forget_hash(entry)

    async def revert(
        self,
        session_id: str,
        after_invocation_id: str,
    ) -> list[str]:
        """Remove the session's entries whose invocation ID sorts after ``after_invocation_id``.

        NOTE: invocation IDs are compared lexicographically. This reflects
        chronology only if IDs are generated in sortable order; a production
        store should compare timestamps or sequence numbers instead. Entries
        without an ``invocation_id`` are kept.

        Returns:
            IDs of the removed entries.
        """
        to_delete = [
            entry_id
            for entry_id, entry in self._entries.items()
            if entry.session_id == session_id
            and entry.invocation_id
            and entry.invocation_id > after_invocation_id
        ]
        for entry_id in to_delete:
            await self.remove(entry_id)
        return to_delete

    def clear(self) -> None:
        """Clear all entries."""
        self._entries.clear()
        self._content_hashes.clear()
@@ -0,0 +1,137 @@
1
+ """Memory type definitions.
2
+
3
+ Core types for the memory system:
4
+ - MemorySummary: Compressed overview of conversation history
5
+ - MemoryRecall: Key points extracted from invocations
6
+ - MemoryContext: Combined context for LLM (summary + recalls)
7
+ """
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+ from datetime import datetime
12
+ from typing import Any
13
+
14
+
15
@dataclass
class MemorySummary:
    """Compressed overview of conversation history.

    Represents the "big picture" of a session's conversation.
    Updated incrementally as invocations complete.
    """
    session_id: str
    content: str  # Summary text
    last_invocation_id: str  # Last invocation included in summary
    updated_at: datetime = field(default_factory=datetime.now)
    token_count: int = 0  # Estimated token count
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (``updated_at`` as ISO-8601 text)."""
        return {
            "session_id": self.session_id,
            "content": self.content,
            "last_invocation_id": self.last_invocation_id,
            "updated_at": self.updated_at.isoformat(),
            "token_count": self.token_count,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MemorySummary":
        """Rebuild a summary from a dict shaped like ``to_dict`` output.

        ``session_id``, ``content`` and ``last_invocation_id`` are required.
        A missing (or ``None``) ``updated_at`` falls back to the current local
        time instead of raising ``KeyError``.
        """
        raw_updated = data.get("updated_at")
        updated_at = (
            datetime.fromisoformat(raw_updated)
            if raw_updated is not None
            else datetime.now()
        )
        return cls(
            session_id=data["session_id"],
            content=data["content"],
            last_invocation_id=data["last_invocation_id"],
            updated_at=updated_at,
            token_count=data.get("token_count", 0),
            metadata=data.get("metadata", {}),
        )
49
+
50
+
51
@dataclass
class MemoryRecall:
    """Key point extracted from an invocation.

    Represents an important piece of information that should be
    recalled when building LLM context. Linked to specific invocation
    for isolation and revert support.
    """
    id: str
    session_id: str
    invocation_id: str  # Which invocation this came from
    content: str  # Recall content
    importance: float = 0.5  # 0.0 - 1.0, higher = more important
    tags: list[str] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (``created_at`` as ISO-8601 text)."""
        return {
            "id": self.id,
            "session_id": self.session_id,
            "invocation_id": self.invocation_id,
            "content": self.content,
            "importance": self.importance,
            "tags": self.tags,
            "created_at": self.created_at.isoformat(),
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MemoryRecall":
        """Rebuild a recall from a dict shaped like ``to_dict`` output.

        ``id``, ``session_id``, ``invocation_id`` and ``content`` are required.
        A missing (or ``None``) ``created_at`` falls back to the current local
        time instead of raising ``KeyError``.
        """
        raw_created = data.get("created_at")
        created_at = (
            datetime.fromisoformat(raw_created)
            if raw_created is not None
            else datetime.now()
        )
        return cls(
            id=data["id"],
            session_id=data["session_id"],
            invocation_id=data["invocation_id"],
            content=data["content"],
            importance=data.get("importance", 0.5),
            tags=data.get("tags", []),
            created_at=created_at,
            metadata=data.get("metadata", {}),
        )
92
+
93
+
94
@dataclass
class MemoryContext:
    """Memory context injected into the LLM prompt.

    Pairs an optional high-level ``summary`` of the conversation with a
    list of specific ``recalls``.
    """
    summary: MemorySummary | None = None
    recalls: list[MemoryRecall] = field(default_factory=list)

    def to_system_message(self) -> str:
        """Render the context as system-message text.

        The text has up to two markdown sections, separated by a blank line:
        - the summary overview, when the summary has content;
        - the recalls, ordered by descending importance, each prefixed with
          its comma-joined tags (or ``note`` when untagged).

        Returns ``""`` when there is nothing to render.
        """
        sections: list[str] = []

        summary = self.summary
        if summary and summary.content:
            sections.append(f"## Conversation History Overview\n{summary.content}")

        if self.recalls:
            ranked = sorted(self.recalls, key=lambda recall: recall.importance, reverse=True)
            bullet_lines = []
            for recall in ranked:
                label = ", ".join(recall.tags) if recall.tags else "note"
                bullet_lines.append(f"- [{label}] {recall.content}")
            sections.append("## Key Points\n" + "\n".join(bullet_lines))

        return "\n\n".join(sections)

    @property
    def is_empty(self) -> bool:
        """True when there is neither summary content nor any recall."""
        has_summary = bool(self.summary and self.summary.content)
        return not (has_summary or self.recalls)

    @property
    def total_recalls(self) -> int:
        """Number of recalls held."""
        recalls = self.recalls
        return len(recalls)
131
+
132
+
133
# Public API of the memory types module.
__all__ = ["MemorySummary", "MemoryRecall", "MemoryContext"]
@@ -0,0 +1,40 @@
1
+ """Message types for conversation history.
2
+
3
+ Note: Message persistence is now handled by:
4
+ - MessageBackend (backends/message/): Storage layer
5
+ - MessageBackendMiddleware (middleware/message.py): Save via on_message_save hook
6
+ - MessageContextProvider (context_providers/message.py): Fetch for context
7
+
8
+ This module provides message types used across the system.
9
+ """
10
+ from .types import (
11
+ MessageRole,
12
+ Message,
13
+ )
14
+ from .store import (
15
+ MessageStore,
16
+ InMemoryMessageStore,
17
+ )
18
+ from .raw_store import (
19
+ RawMessageStore,
20
+ StateBackendRawMessageStore,
21
+ InMemoryRawMessageStore,
22
+ )
23
+ from .config import (
24
+ MessageConfig,
25
+ )
26
+
27
# Public API re-exported by the messages package, grouped by role.
__all__ = [
    # Types
    "MessageRole", "Message",
    # Store (protocol + in-memory for testing)
    "MessageStore", "InMemoryMessageStore",
    # Raw Store
    "RawMessageStore", "StateBackendRawMessageStore", "InMemoryRawMessageStore",
    # Config
    "MessageConfig",
]
@@ -0,0 +1,47 @@
1
+ """Message storage configuration.
2
+
3
+ Configures how messages are stored and retrieved.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Literal
9
+
10
+
11
@dataclass
class MessageConfig:
    """Configuration for message storage.

    Controls raw message storage and recall behavior.

    Attributes:
        enable_raw_store: Whether to store complete messages in RawMessageStore.
            Required for HITL recovery and full-context recall.
        persist_raw: Whether to keep raw messages after invocation completes.
            False = clean up after invocation (default, saves space)
            True = keep forever (for audit/recall)
        recall_mode: How to build context for LLM recall.
            "mixed" = previous invocations truncated + current invocation raw
            "raw" = all raw messages (requires persist_raw=True or current inv only;
            the persist_raw part is not validated here)

    Raises:
        ValueError: On construction, if ``recall_mode`` is not one of
            ``"mixed"``/``"raw"``, or if ``recall_mode="raw"`` is combined
            with ``enable_raw_store=False``.

    Example:
        # Default: raw for recovery, clean up after
        config = MessageConfig()

        # Keep all raw messages for full recall
        config = MessageConfig(persist_raw=True, recall_mode="raw")

        # Disable raw storage (no HITL recovery)
        config = MessageConfig(enable_raw_store=False)
    """
    enable_raw_store: bool = True
    persist_raw: bool = False
    recall_mode: Literal["mixed", "raw"] = "mixed"

    def __post_init__(self) -> None:
        # ``Literal`` is not enforced at runtime, so reject unknown modes
        # explicitly instead of silently accepting a typo such as "Raw".
        if self.recall_mode not in ("mixed", "raw"):
            raise ValueError(
                f"recall_mode must be 'mixed' or 'raw', got {self.recall_mode!r}"
            )
        # Validate: raw recall mode requires raw storage
        if self.recall_mode == "raw" and not self.enable_raw_store:
            raise ValueError("recall_mode='raw' requires enable_raw_store=True")
45
+
46
+
47
# Public API of the message configuration module.
__all__ = [
    "MessageConfig",
]
@@ -0,0 +1,224 @@
1
+ """Raw message store for complete message storage.
2
+
3
+ RawMessageStore stores complete, untruncated messages for:
4
+ - HITL recovery (restore exact state)
5
+ - Full-context recall (when truncated history is insufficient)
6
+ - Audit/debugging
7
+
8
+ Messages are stored per invocation and can be cleaned up after invocation completes.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from typing import Any, Protocol, runtime_checkable
13
+ from datetime import datetime
14
+
15
+ from ..core.types.session import generate_id
16
+
17
+
18
@runtime_checkable
class RawMessageStore(Protocol):
    """Protocol for raw message storage.

    Stores complete messages per invocation. All operations are coroutines.
    Because the protocol is ``runtime_checkable``, ``isinstance`` checks only
    verify that the methods exist.
    """

    async def add(
        self,
        invocation_id: str,
        message: dict[str, Any],
    ) -> str:
        """Add a message and return its ID.

        Args:
            invocation_id: Invocation this message belongs to
            message: Complete message dict

        Returns:
            Generated message ID
        """

    async def get(self, msg_id: str) -> dict[str, Any] | None:
        """Get a single message by ID.

        Args:
            msg_id: Message ID

        Returns:
            Message dict or None if not found
        """

    async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple messages by IDs.

        Args:
            msg_ids: List of message IDs

        Returns:
            Message dicts for the IDs that were found, in the order of
            ``msg_ids``. Missing IDs are omitted (no ``None`` placeholders),
            so the result may be shorter than ``msg_ids``.
        """

    async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
        """Get all messages for an invocation.

        Args:
            invocation_id: Invocation ID

        Returns:
            List of messages in chronological order
        """

    async def delete_by_invocation(self, invocation_id: str) -> int:
        """Delete all messages for an invocation.

        Args:
            invocation_id: Invocation ID

        Returns:
            Number of messages deleted
        """
84
+
85
+
86
class StateBackendRawMessageStore:
    """RawMessageStore implementation backed by StateBackend.

    Everything is written to the ``"raw_messages"`` namespace of the backend:
    - ``<msg_id>`` holds ``{"id", "invocation_id", "message", "created_at"}``.
    - ``raw_msg_ids:<invocation_id>`` holds the invocation's message IDs in
      insertion order.
    """

    _NAMESPACE = "raw_messages"

    def __init__(self, backend: Any):  # StateBackend
        """Initialize with StateBackend.

        Args:
            backend: StateBackend instance (used via async ``get``/``set``/``remove``)
        """
        self._backend = backend

    @staticmethod
    def _index_key(invocation_id: str) -> str:
        """Key of the per-invocation list of message IDs."""
        return f"raw_msg_ids:{invocation_id}"

    async def add(
        self,
        invocation_id: str,
        message: dict[str, Any],
    ) -> str:
        """Store ``message`` under a fresh ID and append it to the invocation index."""
        msg_id = generate_id("rmsg")

        # Store message with metadata
        await self._backend.set(self._NAMESPACE, msg_id, {
            "id": msg_id,
            "invocation_id": invocation_id,
            "message": message,
            "created_at": datetime.now().isoformat(),
        })

        # Read-modify-write of the index. This is not atomic: concurrent adds for
        # the same invocation can interleave at the awaits and lose an ID
        # unless the backend serializes them.
        inv_key = self._index_key(invocation_id)
        msg_ids = list(await self._backend.get(self._NAMESPACE, inv_key) or [])
        msg_ids.append(msg_id)
        await self._backend.set(self._NAMESPACE, inv_key, msg_ids)

        return msg_id

    async def get(self, msg_id: str) -> dict[str, Any] | None:
        """Return the stored message dict, or ``None`` if ``msg_id`` is unknown."""
        data = await self._backend.get(self._NAMESPACE, msg_id)
        if not data:
            return None
        return data.get("message")

    async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
        """Return the messages found for ``msg_ids``, in order; missing IDs are skipped."""
        results = []
        for msg_id in msg_ids:
            msg = await self.get(msg_id)
            # ``is not None`` so that a legitimately empty message ``{}`` is kept.
            if msg is not None:
                results.append(msg)
        return results

    async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
        """Return the invocation's messages in insertion order."""
        inv_key = self._index_key(invocation_id)
        msg_ids = await self._backend.get(self._NAMESPACE, inv_key) or []
        return await self.get_many(msg_ids)

    async def delete_by_invocation(self, invocation_id: str) -> int:
        """Delete the invocation's messages and its index.

        Returns:
            Number of IDs that were listed in the invocation's index.
        """
        inv_key = self._index_key(invocation_id)
        msg_ids = await self._backend.get(self._NAMESPACE, inv_key) or []

        for msg_id in msg_ids:
            await self._backend.remove(self._NAMESPACE, msg_id)

        await self._backend.remove(self._NAMESPACE, inv_key)

        return len(msg_ids)
156
+
157
+
158
class InMemoryRawMessageStore:
    """In-memory raw message store for testing."""

    def __init__(self) -> None:
        self._messages: dict[str, dict[str, Any]] = {}  # msg_id -> message data
        self._invocation_index: dict[str, list[str]] = {}  # inv_id -> [msg_ids]

    async def add(
        self,
        invocation_id: str,
        message: dict[str, Any],
    ) -> str:
        """Store ``message`` under a fresh ID and append it to the invocation index."""
        msg_id = generate_id("rmsg")

        self._messages[msg_id] = {
            "id": msg_id,
            "invocation_id": invocation_id,
            "message": message,
            "created_at": datetime.now().isoformat(),
        }
        self._invocation_index.setdefault(invocation_id, []).append(msg_id)

        return msg_id

    async def get(self, msg_id: str) -> dict[str, Any] | None:
        """Return the stored message dict, or ``None`` if ``msg_id`` is unknown."""
        data = self._messages.get(msg_id)
        if not data:
            return None
        return data.get("message")

    async def get_many(self, msg_ids: list[str]) -> list[dict[str, Any]]:
        """Return the messages found for ``msg_ids``, in order; missing IDs are skipped."""
        results = []
        for msg_id in msg_ids:
            msg = await self.get(msg_id)
            # ``is not None`` so that a legitimately empty message ``{}`` is kept.
            if msg is not None:
                results.append(msg)
        return results

    async def get_by_invocation(self, invocation_id: str) -> list[dict[str, Any]]:
        """Return the invocation's messages in insertion order."""
        msg_ids = self._invocation_index.get(invocation_id, [])
        return await self.get_many(msg_ids)

    async def delete_by_invocation(self, invocation_id: str) -> int:
        """Delete the invocation's messages and its index.

        Returns:
            Number of IDs that were listed in the invocation's index.
        """
        msg_ids = self._invocation_index.pop(invocation_id, [])
        for msg_id in msg_ids:
            self._messages.pop(msg_id, None)
        return len(msg_ids)

    def clear(self) -> None:
        """Clear all messages (for testing)."""
        self._messages.clear()
        self._invocation_index.clear()
218
+
219
+
220
# Public API of the raw message store module.
__all__ = ["RawMessageStore", "StateBackendRawMessageStore", "InMemoryRawMessageStore"]