agent-memory-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ from agent_memory.models import MemoryAction, MemoryDecision, RetrievalResult
4
+ from agent_memory.policy import DecisionPolicy, DefaultPolicy
5
+
6
+
7
def build_score_breakdown(
    result: RetrievalResult,
    policy: DecisionPolicy,
) -> dict[str, float]:
    """Return the per-component score breakdown for ``result``.

    A ``DefaultPolicy`` computes the breakdown itself via ``score_breakdown``.
    For any other policy only the retrieval signals are available. In that
    case ``recency_score`` and ``usage_score`` are reported as ``0.0``,
    ``policy_score`` is the result's ``final_score``, and ``final_score`` is
    the result's ``decision_score``.
    """
    if not isinstance(policy, DefaultPolicy):
        entry = result.entry
        return {
            "semantic_score": result.semantic_score,
            "keyword_score": result.keyword_score,
            "recency_score": 0.0,
            "confidence_score": entry.confidence,
            "usage_score": 0.0,
            "policy_score": result.final_score,
            "final_score": result.decision_score,
        }

    return policy.score_breakdown(result.entry, result.semantic_score, result.keyword_score)
23
+
24
+
25
def build_reasons(result: RetrievalResult, policy: DecisionPolicy) -> list[str]:
    """Return human-readable reason tags explaining why ``result`` matched.

    Tags are appended in a fixed order:
    - semantic band: "high" (>= 0.85) or "moderate" (>= 0.70);
    - keyword match (keyword_score >= 0.50);
    - recency (breakdown recency_score >= 0.80), else prior usage
      (access_count > 0);
    - frequent use (access_count >= 3);
    - high confidence (>= 0.90);
    - a verification flag;
    - freshness-sensitive memory types.

    Args:
        result: The retrieval result to describe.
        policy: Policy used to compute the score breakdown (for recency).

    Returns:
        A possibly empty list of reason strings.
    """
    # Use the models module's verify-sensitive type set so this tag cannot
    # drift from it; that module is already a dependency of this one.
    from agent_memory.models import VERIFY_TYPES

    reasons: list[str] = []
    breakdown = build_score_breakdown(result, policy)
    entry = result.entry

    if result.semantic_score >= 0.85:
        reasons.append("high semantic match")
    elif result.semantic_score >= 0.70:
        reasons.append("moderate semantic match")

    if result.keyword_score >= 0.50:
        reasons.append("keyword match")

    if breakdown.get("recency_score", 0) >= 0.80:
        reasons.append("recent memory")
    elif entry.access_count > 0:
        reasons.append("prior usage")

    if entry.access_count >= 3:
        reasons.append("frequently used")

    if entry.confidence >= 0.90:
        reasons.append("high confidence memory")

    if entry.requires_verification:
        reasons.append("requires verification")

    if entry.type.value in VERIFY_TYPES:
        reasons.append("freshness-sensitive type")

    return reasons
55
+
56
+
57
def enrich_decision(
    decision: MemoryDecision,
    policy: DecisionPolicy,
    *,
    replay_threshold: float,
    restore_threshold: float,
) -> MemoryDecision:
    """Attach matched IDs, reason tags, and score breakdown to a decision.

    The decision is mutated in place and also returned. ``decision.context[0]``
    is treated as the best match: its reasons and score breakdown are the
    ones attached. With an empty context the decision gets no matches and the
    single reason "no matching memories".

    Args:
        decision: Decision to enrich (modified in place).
        policy: Policy used for the score breakdown.
        replay_threshold: Minimum best ``decision_score`` for the REPLAY extra tag.
        restore_threshold: Below this best ``decision_score`` a NONE decision
            gets the "below restore threshold" reason.
    """
    candidates = decision.context
    if candidates:
        top = candidates[0]
        decision.matched = [candidate.entry.id for candidate in candidates]
        decision.reasons = build_reasons(top, policy)
        decision.scores = build_score_breakdown(top, policy)

        action = decision.action
        if action == MemoryAction.REPLAY and top.decision_score >= replay_threshold:
            # NOTE(review): build_reasons() appends "high semantic match"
            # whenever semantic_score >= 0.85, so this inner condition is never
            # true and the tag below is never inserted. Confirm the intended
            # trigger before relying on it.
            already_tagged = "high semantic match" in decision.reasons
            if not already_tagged and top.semantic_score >= 0.85:
                decision.reasons.insert(0, "exact or near-exact query match")
        elif action == MemoryAction.NONE and top.decision_score < restore_threshold:
            decision.reasons.append("below restore threshold")

        return decision

    decision.matched = []
    decision.reasons = ["no matching memories"]
    return decision
83
+
84
+
85
def format_explanation(decision: MemoryDecision) -> str:
    """Render ``decision`` as a multi-line, human-readable explanation.

    The output has three parts:
    - the action, confidence (two decimals) and matched memory IDs;
    - a "reasons:" section, only when reasons exist;
    - a "scores:" section listing each known score component in a fixed
      order, formatted to two decimals.

    Score keys outside that fixed set are not printed. When
    ``decision.scores`` is empty, a fixed "no score breakdown" message is
    returned instead.
    """
    scores = decision.scores
    if not scores:
        return "No score breakdown available (no memory matches)."

    # Display order of score components; each is printed under its own key.
    ordered_keys = (
        "semantic_score",
        "keyword_score",
        "recency_score",
        "confidence_score",
        "usage_score",
        "policy_score",
        "final_score",
    )

    out = [
        f"action: {decision.action.value}",
        f"confidence: {decision.confidence:.2f}",
        f"matched: {', '.join(decision.matched) or 'none'}",
        "",
    ]

    if decision.reasons:
        out.append("reasons:")
        out.extend(f" - {reason}" for reason in decision.reasons)
        out.append("")

    out.append("scores:")
    out.extend(f" {key}: {scores[key]:.2f}" for key in ordered_keys if key in scores)

    return "\n".join(out)
@@ -0,0 +1,348 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from pathlib import Path
5
+
6
+ from agent_memory.decision import DecisionEngine
7
+ from agent_memory.models import MemoryDecision, MemoryEntry, MemoryScope, MemoryState, MemoryType
8
+ from agent_memory.policy import DecisionPolicy, DefaultPolicy
9
+ from agent_memory.retriever import MemoryRetriever
10
+ from agent_memory.sqlite_store import SqliteMemoryStore
11
+ from agent_memory.store import ChromaDBStore, MemoryStore
12
+ from agent_memory.ttl import parse_ttl
13
+
14
+
15
class Memory:
    """High-level SDK for persistent agent memory with a decision layer.

    Wires a storage backend (``"sqlite"`` or ``"chromadb"``) to a
    ``MemoryRetriever`` and a ``DecisionEngine`` that share one
    ``DecisionPolicy``. Each synchronous operation has an ``a``-prefixed async
    twin that runs it in a worker thread via ``asyncio.to_thread``.
    """

    # Maximum number of entries requested from ``store.list_all`` by the
    # maintenance operations (cleanup, stats, consolidate). Entries beyond
    # this cap are not visited.
    _SCAN_LIMIT = 10_000

    def __init__(
        self,
        persist_dir: str | Path = ".agent_memory",
        collection_name: str = "agent_memories",
        policy: DecisionPolicy | None = None,
        replay_threshold: float = 0.85,
        restore_threshold: float = 0.70,
        verify_threshold: float = 0.80,
        backend: str = "sqlite",  # "chromadb" or "sqlite"
    ) -> None:
        """Build the backend store, retriever and decision engine.

        Args:
            persist_dir: Passed to the backend store.
            collection_name: Passed to the backend store.
            policy: Scoring policy shared by the retriever and decision
                engine; ``DefaultPolicy()`` when ``None``.
            replay_threshold: Passed to ``DecisionEngine``.
            restore_threshold: Passed to ``DecisionEngine``.
            verify_threshold: Passed to ``DecisionEngine``.
            backend: ``"sqlite"`` (default) or ``"chromadb"``.

        Raises:
            ValueError: If ``backend`` is not one of the supported names.
        """
        self.store: MemoryStore
        if backend == "sqlite":
            self.store = SqliteMemoryStore(persist_dir=persist_dir, collection_name=collection_name)
        elif backend == "chromadb":
            self.store = ChromaDBStore(persist_dir=persist_dir, collection_name=collection_name)
        else:
            raise ValueError(f"Unknown backend: {backend}. Use 'sqlite' or 'chromadb'")
        # ``is None`` instead of ``or``: a caller-supplied policy must never be
        # silently replaced just because the object happens to be falsy.
        self._policy = policy if policy is not None else DefaultPolicy()
        self.retriever = MemoryRetriever(self.store, policy=self._policy)
        self.decision_engine = DecisionEngine(
            self.retriever,
            policy=self._policy,
            replay_threshold=replay_threshold,
            restore_threshold=restore_threshold,
            verify_threshold=verify_threshold,
        )

    def remember(
        self,
        query: str,
        response: str,
        *,
        content: str | None = None,
        type: MemoryType | str = MemoryType.CONVERSATION,
        scope: MemoryScope | str = MemoryScope.USER,
        metadata: dict | None = None,
        tags: list[str] | None = None,
        confidence: float = 1.0,
        requires_verification: bool = False,
        ttl: str | int | float | None = None,
    ) -> MemoryEntry:
        """Store a query/response pair as a new memory.

        ``content`` defaults to ``response``. ``type`` and ``scope`` accept
        either enum members or their string values; an unknown string raises
        ``ValueError`` from the enum constructor. ``ttl`` is converted to an
        absolute ``expires_at`` by ``parse_ttl``.

        Returns:
            The entry as returned by ``store.store``.
        """
        memory_type = MemoryType(type) if isinstance(type, str) else type
        memory_scope = MemoryScope(scope) if isinstance(scope, str) else scope
        entry = MemoryEntry(
            query=query,
            response=response,
            content=content or response,
            type=memory_type,
            scope=memory_scope,
            metadata=metadata or {},
            tags=tags or [],
            confidence=confidence,
            requires_verification=requires_verification,
            expires_at=parse_ttl(ttl),
        )
        entry.refresh_state()
        return self.store.store(entry)

    async def aremember(
        self,
        query: str,
        response: str,
        *,
        content: str | None = None,
        type: MemoryType | str = MemoryType.CONVERSATION,
        scope: MemoryScope | str = MemoryScope.USER,
        metadata: dict | None = None,
        tags: list[str] | None = None,
        confidence: float = 1.0,
        requires_verification: bool = False,
        ttl: str | int | float | None = None,
    ) -> MemoryEntry:
        """Async version of remember()."""
        return await asyncio.to_thread(
            self.remember,
            query,
            response,
            content=content,
            type=type,
            scope=scope,
            metadata=metadata,
            tags=tags,
            confidence=confidence,
            requires_verification=requires_verification,
            ttl=ttl,
        )

    def resolve(
        self,
        query: str,
        *,
        mode: str = "auto",
        top_k: int = 3,
        scope: list[MemoryScope | str] | None = None,
        enable_verify: bool = True,
    ) -> MemoryDecision:
        """Ask the decision engine how to answer ``query`` from stored memories.

        ``scope`` entries may be enum members or string values. An empty or
        missing ``scope`` passes ``scopes=None`` to the engine.
        """
        scopes = [MemoryScope(s) if isinstance(s, str) else s for s in scope] if scope else None
        return self.decision_engine.decide(
            query,
            mode=mode,
            top_k=top_k,
            scopes=scopes,
            enable_verify=enable_verify,
        )

    async def aresolve(
        self,
        query: str,
        *,
        mode: str = "auto",
        top_k: int = 3,
        scope: list[MemoryScope | str] | None = None,
        enable_verify: bool = True,
    ) -> MemoryDecision:
        """Async version of resolve()."""
        return await asyncio.to_thread(
            self.resolve,
            query,
            mode=mode,
            top_k=top_k,
            scope=scope,
            enable_verify=enable_verify,
        )

    def list(
        self,
        limit: int = 100,
        offset: int = 0,
        *,
        scope: list[MemoryScope | str] | None = None,
        include_archived: bool = False,
        type: MemoryType | str | None = None,
    ) -> list[MemoryEntry]:
        """Return stored memories, paged by ``limit``/``offset``, optionally filtered."""
        scopes = [MemoryScope(s) if isinstance(s, str) else s for s in scope] if scope else None
        memory_type = MemoryType(type) if isinstance(type, str) else type
        return self.store.list_all(
            limit=limit,
            offset=offset,
            scopes=scopes,
            include_archived=include_archived,
            memory_type=memory_type,
        )

    async def alist(
        self,
        limit: int = 100,
        offset: int = 0,
        *,
        scope: list[MemoryScope | str] | None = None,
        include_archived: bool = False,
        type: MemoryType | str | None = None,
    ) -> list[MemoryEntry]:
        """Async version of list()."""
        return await asyncio.to_thread(
            self.list,
            limit=limit,
            offset=offset,
            scope=scope,
            include_archived=include_archived,
            type=type,
        )

    def get(self, memory_id: str) -> MemoryEntry | None:
        """Return the memory with ``memory_id`` as looked up by the store."""
        return self.store.get(memory_id)

    async def aget(self, memory_id: str) -> MemoryEntry | None:
        """Async version of get()."""
        return await asyncio.to_thread(self.get, memory_id)

    def forget(self, memory_id: str) -> bool:
        """Delete the memory with ``memory_id``; returns the store's result."""
        return self.store.delete(memory_id)

    async def aforget(self, memory_id: str) -> bool:
        """Async version of forget()."""
        return await asyncio.to_thread(self.forget, memory_id)

    def archive(self, memory_id: str) -> MemoryEntry | None:
        """Mark a memory archived and persist it; ``None`` if the id is unknown."""
        entry = self.store.get(memory_id)
        if entry is None:
            return None
        entry.archived = True
        entry.refresh_state()
        return self.store.update(entry)

    async def aarchive(self, memory_id: str) -> MemoryEntry | None:
        """Async version of archive()."""
        return await asyncio.to_thread(self.archive, memory_id)

    def cleanup(self, *, delete: bool = False) -> dict[str, int]:
        """
        Mark expired memories as expired, optionally deleting them.

        Scans up to ``_SCAN_LIMIT`` entries, including archived and expired
        ones. Each expired entry is either deleted (``delete=True``) or has
        ``MemoryState.EXPIRED`` persisted, unless that state was already
        stored.

        Returns counts: {"expired": N, "deleted": M}, where N is the number of
        entries newly marked expired and M the number the store reported as
        deleted.
        """
        entries = self.store.list_all(
            limit=self._SCAN_LIMIT, include_archived=True, include_expired=True
        )
        expired_count = 0
        deleted_count = 0

        for entry in entries:
            # Capture the state as loaded *before* refresh_state() recomputes
            # it. Otherwise a freshly expired entry already reads EXPIRED and
            # would never be persisted or counted.
            # NOTE(review): assumes list_all returns the stored state unchanged.
            stored_state = entry.state
            entry.refresh_state()
            if not entry.is_expired:
                continue
            if delete:
                if self.store.delete(entry.id):
                    deleted_count += 1
            elif stored_state != MemoryState.EXPIRED:
                entry.state = MemoryState.EXPIRED
                self.store.update(entry)
                expired_count += 1

        return {"expired": expired_count, "deleted": deleted_count}

    async def acleanup(self, *, delete: bool = False) -> dict[str, int]:
        """Async version of cleanup()."""
        return await asyncio.to_thread(self.cleanup, delete=delete)

    def stats(self) -> dict:
        """Return aggregate memory and usage statistics.

        Counts up to ``_SCAN_LIMIT`` entries (archived and expired included).
        Entries are grouped by refreshed state and by type, and their
        ``access_count`` values are summed.
        """
        entries = self.store.list_all(
            limit=self._SCAN_LIMIT, include_archived=True, include_expired=True
        )
        by_state: dict[str, int] = {}
        by_type: dict[str, int] = {}
        total_access = 0

        for entry in entries:
            entry.refresh_state()
            by_state[entry.state.value] = by_state.get(entry.state.value, 0) + 1
            by_type[entry.type.value] = by_type.get(entry.type.value, 0) + 1
            total_access += entry.access_count

        return {
            "total": len(entries),
            "by_state": by_state,
            "by_type": by_type,
            "total_access_count": total_access,
        }

    async def astats(self) -> dict:
        """Async version of stats()."""
        return await asyncio.to_thread(self.stats)

    def consolidate(self, similarity_threshold: float = 0.95) -> list[MemoryEntry]:
        """
        Merge near-duplicate active memories into summary entries.

        For every unarchived, non-summary entry (in ``list_all`` order), one
        ``store.search(entry.query, top_k=5)`` is run. Other unarchived
        entries of the same type and scope that appear among those hits with
        a score of at least ``similarity_threshold`` form a duplicate group
        with it. Each group of two or more becomes a new SUMMARY memory whose
        response lists every member's response; the members are then archived.

        Returns newly created summary memories.
        """
        entries = self.store.list_all(limit=self._SCAN_LIMIT)
        created: list[MemoryEntry] = []
        seen: set[str] = set()

        for entry in entries:
            if entry.id in seen or entry.archived or entry.type == MemoryType.SUMMARY:
                continue

            candidates = [
                other
                for other in entries
                if other.id != entry.id
                and other.id not in seen
                and not other.archived
                and other.type == entry.type
                and other.scope == entry.scope
            ]
            if not candidates:
                continue

            # The search depends only on the anchor entry, and nothing is
            # written to the store during this scan, so run it once per
            # anchor instead of once per candidate.
            close_ids = {
                hit.id
                for hit, score in self.store.search(entry.query, top_k=5)
                if score >= similarity_threshold
            }
            duplicates = [entry, *(other for other in candidates if other.id in close_ids)]
            if len(duplicates) < 2:
                continue

            merged_content = "\n".join(f"- {d.response}" for d in duplicates)
            # Write the summary before archiving its sources, so a failed write
            # cannot leave the originals archived with nothing replacing them.
            summary = self.remember(
                query=duplicates[0].query,
                response=merged_content,
                content=merged_content,
                type=MemoryType.SUMMARY,
                scope=duplicates[0].scope,
                # dict.fromkeys de-duplicates while keeping first-seen order,
                # unlike a set, whose string ordering varies between runs.
                tags=list(dict.fromkeys(tag for d in duplicates for tag in d.tags)),
                metadata={"consolidated_from": [d.id for d in duplicates]},
            )

            for dup in duplicates:
                seen.add(dup.id)
                dup.archived = True
                # Keep the persisted state in line with archive().
                dup.refresh_state()
                self.store.update(dup)

            created.append(summary)

        return created

    async def aconsolidate(self, similarity_threshold: float = 0.95) -> list[MemoryEntry]:
        """Async version of consolidate()."""
        return await asyncio.to_thread(self.consolidate, similarity_threshold=similarity_threshold)

    def format_restore_context(self, decision: MemoryDecision) -> str:
        """Render the decision's retrieved context as Markdown; "" when empty."""
        if not decision.context:
            return ""

        lines = ["## Retrieved Memory Context", ""]
        for result in decision.context:
            entry = result.entry
            lines.append(
                f"### Memory {result.rank} (score: {result.final_score:.2f}, type: {entry.type.value})"
            )
            lines.append(f"**Prior query:** {entry.query}")
            lines.append(f"**Prior response:** {entry.response}")
            if entry.tags:
                lines.append(f"**Tags:** {', '.join(entry.tags)}")
            lines.append("")
        return "\n".join(lines)

    def format_verify_context(self, decision: MemoryDecision) -> str:
        """Render the decision's single memory as a Markdown verification prompt."""
        if not decision.memory:
            return ""
        entry = decision.memory
        return (
            "## Memory Pending Verification\n\n"
            f"**Query:** {entry.query}\n"
            f"**Stored response:** {entry.response}\n"
            f"**Type:** {entry.type.value}\n"
            f"**Last updated:** {entry.updated_at.isoformat()}\n\n"
            "Validate this memory with available tools before reusing or regenerating."
        )

    # Backward compatibility
    def list_memories(self, limit: int = 100, offset: int = 0) -> list[MemoryEntry]:
        """Backward-compatible alias for ``list(limit=..., offset=...)``."""
        return self.list(limit=limit, offset=offset)
346
+
347
+
348
# Alias of Memory, presumably kept for backward compatibility with code that
# imports the older ``MemoryManager`` name.
MemoryManager = Memory
agent_memory/models.py ADDED
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime, timezone
4
+ from enum import Enum
5
+ from typing import Any
6
+ from uuid import uuid4
7
+
8
+ from pydantic import BaseModel, Field, model_validator
9
+
10
# Values of memory types that, by default, are verified before reuse when a
# match is only moderately similar.
VERIFY_TYPES: frozenset[str] = frozenset(("fact", "tool_output", "workflow"))
12
+
13
+
14
class MemoryAction(str, Enum):
    """How the agent should use a retrieved memory.

    Members are ``str`` instances, so each compares equal to its value
    (e.g. ``MemoryAction.REPLAY == "replay"``).
    """

    REPLAY = "replay"  # presumably: reuse the stored response directly — confirm in DecisionEngine
    RESTORE = "restore"  # presumably: feed retrieved memories back as context
    VERIFY = "verify"  # presumably: validate a stored memory before reusing it
    NONE = "none"  # no memory is used for the answer
21
+
22
+
23
class MemoryType(str, Enum):
    """Kind of content a memory holds; members compare equal to their string values."""

    CONVERSATION = "conversation"  # default type of a MemoryEntry
    FACT = "fact"  # listed in VERIFY_TYPES
    WORKFLOW = "workflow"  # listed in VERIFY_TYPES
    TOOL_OUTPUT = "tool_output"  # listed in VERIFY_TYPES
    DOCUMENT = "document"
    SUMMARY = "summary"  # type given to entries created by Memory.consolidate()
    CODE = "code"
    PREFERENCE = "preference"
32
+
33
+
34
class MemoryScope(str, Enum):
    """Scope a memory is filed under.

    Memory.resolve() and Memory.list() accept a list of scopes to filter by.
    Members compare equal to their string values. The precise visibility
    semantics of each scope are defined by the store/retriever (not shown
    here) — NOTE(review): confirm.
    """

    SESSION = "session"
    USER = "user"  # default scope of a MemoryEntry
    PROJECT = "project"
    WORKSPACE = "workspace"
    TEAM = "team"
    GLOBAL = "global"
41
+
42
+
43
class MemoryState(str, Enum):
    """Lifecycle state of a memory entry.

    MemoryEntry.refresh_state() derives ACTIVE, ARCHIVED or EXPIRED from an
    entry's ``archived`` flag and expiry. Nothing in this module sets DELETED.
    """

    ACTIVE = "active"
    ARCHIVED = "archived"
    EXPIRED = "expired"
    DELETED = "deleted"
48
+
49
+
50
class MemoryEntry(BaseModel):
    """A stored memory with structured metadata.

    ``state`` is a cached lifecycle flag. Call ``refresh_state()`` to
    recompute it from ``archived`` and ``expires_at``.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    query: str
    response: str
    # Falls back to ``response`` when left empty (see ``default_content``).
    content: str = ""
    type: MemoryType = MemoryType.CONVERSATION
    scope: MemoryScope = MemoryScope.USER
    metadata: dict[str, Any] = Field(default_factory=dict)
    tags: list[str] = Field(default_factory=list)
    # Validated to lie within [0.0, 1.0].
    confidence: float = Field(default=1.0, ge=0.0, le=1.0)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    access_count: int = 0
    last_accessed_at: datetime | None = None
    archived: bool = False
    requires_verification: bool = False
    # Absolute expiry instant; a naive value is interpreted as UTC by is_expired.
    expires_at: datetime | None = None
    state: MemoryState = MemoryState.ACTIVE

    @property
    def is_expired(self) -> bool:
        """Return True if the entry is marked EXPIRED or its expiry time has passed.

        A naive ``expires_at`` is treated as UTC before it is compared with
        the current UTC time. Entries without ``expires_at`` never expire on
        their own.
        """
        if self.state == MemoryState.EXPIRED:
            return True
        if self.expires_at is None:
            return False
        now = datetime.now(timezone.utc)
        expires = self.expires_at
        if expires.tzinfo is None:
            expires = expires.replace(tzinfo=timezone.utc)
        return now >= expires

    def refresh_state(self) -> None:
        """Recompute ``state``: ARCHIVED takes precedence over EXPIRED, then ACTIVE."""
        # NOTE(review): any other prior state (e.g. DELETED) is reset to ACTIVE
        # here unless the entry is archived or expired — confirm this is intended.
        if self.archived:
            self.state = MemoryState.ARCHIVED
        elif self.is_expired:
            self.state = MemoryState.EXPIRED
        else:
            self.state = MemoryState.ACTIVE

    @model_validator(mode="after")
    def default_content(self) -> MemoryEntry:
        """Use ``response`` as ``content`` when no content was supplied."""
        if not self.content:
            self.content = self.response
        return self

    def touch(self) -> None:
        """Record one access: increment ``access_count`` and stamp the access time.

        A single timestamp is used, so ``last_accessed_at`` and ``updated_at``
        agree exactly for the same access.
        """
        now = datetime.now(timezone.utc)
        self.access_count += 1
        self.last_accessed_at = now
        self.updated_at = now

    def to_dict(self) -> dict[str, Any]:
        """Return a plain dict with enum values and ISO-8601 timestamp strings."""
        # NOTE(review): ``metadata`` and ``requires_verification`` are not
        # included, and the access time is keyed "last_accessed" (not
        # "last_accessed_at"); confirm consumers expect this shape.
        return {
            "id": self.id,
            "content": self.content,
            "query": self.query,
            "response": self.response,
            "type": self.type.value,
            "scope": self.scope.value,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "last_accessed": self.last_accessed_at.isoformat() if self.last_accessed_at else None,
            "access_count": self.access_count,
            "confidence": self.confidence,
            "tags": self.tags,
            "archived": self.archived,
            "state": self.state.value,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
        }
120
+
121
+
122
class RetrievalResult(BaseModel):
    """A memory match from hybrid retrieval.

    Holds the matched entry together with its semantic and keyword signals,
    the policy-level ``final_score``, and a ``rank`` position.
    """

    entry: MemoryEntry
    semantic_score: float
    keyword_score: float = 0.0
    final_score: float = 0.0
    rank: int = 0

    @property
    def similarity(self) -> float:
        """Backward-compatible alias for decision score."""
        return self.decision_score

    @property
    def decision_score(self) -> float:
        """Score used for action thresholds — never below raw retrieval signals.

        The largest of ``final_score``, ``semantic_score`` and a 70/30
        semantic/keyword blend.
        """
        blended = 0.7 * self.semantic_score + 0.3 * self.keyword_score
        return max(self.final_score, self.semantic_score, blended)
141
+
142
+
143
class MemoryDecision(BaseModel):
    """Agent-facing decision on how to answer a query."""

    # How the agent should use memory for this query.
    action: MemoryAction
    query: str
    # Rendered with two decimals by __repr__ and explain().
    confidence: float
    response: str | None = None
    # Single memory the decision refers to (used by Memory.format_verify_context).
    memory: MemoryEntry | None = None
    # Retrieval results; explain.enrich_decision treats context[0] as the best match.
    context: list[RetrievalResult] = Field(default_factory=list)
    # Free-text reason; NOTE(review): its relation to ``reasons`` is not visible here.
    reason: str = ""
    # IDs of the matched memory entries.
    matched: list[str] = Field(default_factory=list)
    # Short reason tags explaining the decision.
    reasons: list[str] = Field(default_factory=list)
    # Per-component score breakdown of the best match.
    scores: dict[str, float] = Field(default_factory=dict)

    def explain(self) -> str:
        """Return a multi-line, human-readable explanation of this decision."""
        # Imported here rather than at module level: the explain module
        # itself imports from this module at load time.
        from agent_memory.explain import format_explanation

        return format_explanation(self)

    def __repr__(self) -> str:
        # Show at most three matched IDs, then an ellipsis marker.
        shown = ", ".join(self.matched[:3])
        suffix = ", ..." if len(self.matched) > 3 else ""
        return (
            f"Decision(action={self.action.value!r}, confidence={self.confidence:.2f}, "
            f"matched=[{shown + suffix}], reasons={self.reasons!r})"
        )

    def __str__(self) -> str:
        return self.__repr__()