tribalmemory 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tribalmemory/__init__.py +3 -0
- tribalmemory/a21/__init__.py +38 -0
- tribalmemory/a21/config/__init__.py +20 -0
- tribalmemory/a21/config/providers.py +104 -0
- tribalmemory/a21/config/system.py +184 -0
- tribalmemory/a21/container/__init__.py +8 -0
- tribalmemory/a21/container/container.py +212 -0
- tribalmemory/a21/providers/__init__.py +32 -0
- tribalmemory/a21/providers/base.py +241 -0
- tribalmemory/a21/providers/deduplication.py +99 -0
- tribalmemory/a21/providers/lancedb.py +232 -0
- tribalmemory/a21/providers/memory.py +128 -0
- tribalmemory/a21/providers/mock.py +54 -0
- tribalmemory/a21/providers/openai.py +151 -0
- tribalmemory/a21/providers/timestamp.py +88 -0
- tribalmemory/a21/system.py +293 -0
- tribalmemory/cli.py +298 -0
- tribalmemory/interfaces.py +306 -0
- tribalmemory/mcp/__init__.py +9 -0
- tribalmemory/mcp/__main__.py +6 -0
- tribalmemory/mcp/server.py +484 -0
- tribalmemory/performance/__init__.py +1 -0
- tribalmemory/performance/benchmarks.py +285 -0
- tribalmemory/performance/corpus_generator.py +171 -0
- tribalmemory/portability/__init__.py +1 -0
- tribalmemory/portability/embedding_metadata.py +320 -0
- tribalmemory/server/__init__.py +9 -0
- tribalmemory/server/__main__.py +6 -0
- tribalmemory/server/app.py +187 -0
- tribalmemory/server/config.py +115 -0
- tribalmemory/server/models.py +206 -0
- tribalmemory/server/routes.py +378 -0
- tribalmemory/services/__init__.py +15 -0
- tribalmemory/services/deduplication.py +115 -0
- tribalmemory/services/embeddings.py +273 -0
- tribalmemory/services/import_export.py +506 -0
- tribalmemory/services/memory.py +275 -0
- tribalmemory/services/vector_store.py +360 -0
- tribalmemory/testing/__init__.py +22 -0
- tribalmemory/testing/embedding_utils.py +110 -0
- tribalmemory/testing/fixtures.py +123 -0
- tribalmemory/testing/metrics.py +256 -0
- tribalmemory/testing/mocks.py +560 -0
- tribalmemory/testing/semantic_expansions.py +91 -0
- tribalmemory/utils.py +23 -0
- tribalmemory-0.1.0.dist-info/METADATA +275 -0
- tribalmemory-0.1.0.dist-info/RECORD +51 -0
- tribalmemory-0.1.0.dist-info/WHEEL +5 -0
- tribalmemory-0.1.0.dist-info/entry_points.txt +3 -0
- tribalmemory-0.1.0.dist-info/licenses/LICENSE +190 -0
- tribalmemory-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""Tribal Memory Service - Main API for agents."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Optional
|
|
6
|
+
import uuid
|
|
7
|
+
|
|
8
|
+
from ..interfaces import (
|
|
9
|
+
IMemoryService,
|
|
10
|
+
IEmbeddingService,
|
|
11
|
+
IVectorStore,
|
|
12
|
+
MemoryEntry,
|
|
13
|
+
MemorySource,
|
|
14
|
+
RecallResult,
|
|
15
|
+
StoreResult,
|
|
16
|
+
)
|
|
17
|
+
from .deduplication import SemanticDeduplicationService
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TribalMemoryService(IMemoryService):
    """Production tribal memory service.

    Coordinates an embedding service, a vector store, and a semantic
    deduplication service to provide remember/recall/correct/forget
    semantics for agent memories.

    Usage:
        service = TribalMemoryService(
            instance_id="clawdio-1",
            embedding_service=embedding_service,
            vector_store=vector_store
        )

        await service.remember("Joe prefers TypeScript")
        results = await service.recall("What language for Wally?")
    """

    def __init__(
        self,
        instance_id: str,
        embedding_service: IEmbeddingService,
        vector_store: IVectorStore,
        dedup_exact_threshold: float = 0.98,
        dedup_near_threshold: float = 0.90,
        auto_reject_duplicates: bool = True,
    ):
        self.instance_id = instance_id
        self.embedding_service = embedding_service
        self.vector_store = vector_store
        self.auto_reject_duplicates = auto_reject_duplicates

        # The dedup service shares the same store and embedder so its
        # similarity checks happen in the same vector space as recall.
        self.dedup_service = SemanticDeduplicationService(
            vector_store=vector_store,
            embedding_service=embedding_service,
            exact_threshold=dedup_exact_threshold,
            near_threshold=dedup_near_threshold,
        )

    async def remember(
        self,
        content: str,
        source_type: MemorySource = MemorySource.AUTO_CAPTURE,
        context: Optional[str] = None,
        tags: Optional[list[str]] = None,
        skip_dedup: bool = False,
    ) -> StoreResult:
        """Store a new memory.

        Args:
            content: Memory text; must be non-empty after stripping.
            source_type: Provenance of the memory.
            context: Optional free-text context attached to the entry.
            tags: Optional tags for later filtered recall.
            skip_dedup: Bypass the duplicate check for this call.

        Returns:
            StoreResult; on duplicate rejection, success=False with
            duplicate_of set to the existing memory's id.
        """
        if not content or not content.strip():
            return StoreResult(success=False, error="TribalMemory: Empty content not allowed")

        content = content.strip()

        try:
            embedding = await self.embedding_service.embed(content)
        except Exception as e:
            return StoreResult(success=False, error=f"Embedding generation failed: {e}")

        if not skip_dedup and self.auto_reject_duplicates:
            is_dup, dup_id = await self.dedup_service.is_duplicate(content, embedding)
            if is_dup:
                return StoreResult(success=False, duplicate_of=dup_id)

        # Capture the timestamp once so created_at == updated_at on a
        # brand-new entry (two utcnow() calls could differ by microseconds).
        now = datetime.utcnow()
        entry = MemoryEntry(
            id=str(uuid.uuid4()),
            content=content,
            embedding=embedding,
            source_instance=self.instance_id,
            source_type=source_type,
            created_at=now,
            updated_at=now,
            tags=tags or [],
            context=context,
            confidence=1.0,
        )

        return await self.vector_store.store(entry)

    async def recall(
        self,
        query: str,
        limit: int = 5,
        min_relevance: float = 0.7,
        tags: Optional[list[str]] = None,
    ) -> list[RecallResult]:
        """Recall relevant memories.

        Args:
            query: Natural language query
            limit: Maximum results
            min_relevance: Minimum similarity score
            tags: Filter by tags (e.g., ["work", "preferences"])

        Returns:
            Matching results with superseded memories filtered out.
            Returns an empty list if query embedding fails (recall is
            deliberately best-effort and never raises here).
        """
        try:
            query_embedding = await self.embedding_service.embed(query)
        except Exception:
            return []

        filters = {"tags": tags} if tags else None

        results = await self.vector_store.recall(
            query_embedding,
            limit=limit,
            min_similarity=min_relevance,
            filters=filters,
        )

        return self._filter_superseded(results)

    async def correct(
        self,
        original_id: str,
        corrected_content: str,
        context: Optional[str] = None,
    ) -> StoreResult:
        """Store a correction to an existing memory.

        The correction is a new entry that supersedes the original;
        the original is kept for provenance and linked via related_to.
        """
        original = await self.vector_store.get(original_id)
        if not original:
            return StoreResult(success=False, error=f"Original memory {original_id} not found")

        try:
            embedding = await self.embedding_service.embed(corrected_content)
        except Exception as e:
            return StoreResult(success=False, error=f"Embedding generation failed: {e}")

        # Single timestamp so created_at == updated_at (see remember()).
        now = datetime.utcnow()
        entry = MemoryEntry(
            id=str(uuid.uuid4()),
            content=corrected_content,
            embedding=embedding,
            source_instance=self.instance_id,
            source_type=MemorySource.CORRECTION,
            created_at=now,
            updated_at=now,
            tags=original.tags,
            context=context or f"Correction of memory {original_id}",
            confidence=1.0,
            supersedes=original_id,
            related_to=[original_id],
        )

        return await self.vector_store.store(entry)

    async def forget(self, memory_id: str) -> bool:
        """Forget (soft delete) a memory. Returns True on success."""
        return await self.vector_store.delete(memory_id)

    async def get(self, memory_id: str) -> Optional[MemoryEntry]:
        """Get a memory by ID with full provenance, or None if absent."""
        return await self.vector_store.get(memory_id)

    async def get_stats(self) -> dict:
        """Get memory statistics.

        Note: Stats are computed over up to 10,000 most recent memories.
        For systems with >10k memories, consider using count() with filters.
        """
        all_memories = await self.vector_store.list(limit=10000)

        by_source: dict[str, int] = {}
        by_instance: dict[str, int] = {}
        by_tag: dict[str, int] = {}

        for m in all_memories:
            source = m.source_type.value
            by_source[source] = by_source.get(source, 0) + 1

            instance = m.source_instance
            by_instance[instance] = by_instance.get(instance, 0) + 1

            for tag in m.tags:
                by_tag[tag] = by_tag.get(tag, 0) + 1

        corrections = sum(1 for m in all_memories if m.supersedes)

        return {
            "total_memories": len(all_memories),
            "by_source_type": by_source,
            "by_tag": by_tag,
            "by_instance": by_instance,
            "corrections": corrections,
        }

    @staticmethod
    def _filter_superseded(results: list[RecallResult]) -> list[RecallResult]:
        """Remove memories that are superseded by corrections in the result set.

        Only corrections present in *this* result set can hide their
        originals; superseded memories outside the set are untouched.
        """
        superseded_ids = {
            r.memory.supersedes for r in results if r.memory.supersedes
        }
        if not superseded_ids:
            return results
        return [r for r in results if r.memory.id not in superseded_ids]
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def create_memory_service(
    instance_id: Optional[str] = None,
    db_path: Optional[str] = None,
    openai_api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    embedding_model: Optional[str] = None,
    embedding_dimensions: Optional[int] = None,
) -> TribalMemoryService:
    """Build a ready-to-use memory service with sensible defaults.

    Args:
        instance_id: Unique identifier for this agent instance. Falls back
            to the TRIBAL_MEMORY_INSTANCE_ID env var, then "default".
        db_path: Path for LanceDB persistent storage. If None, uses in-memory.
        openai_api_key: API key. Falls back to OPENAI_API_KEY env var.
            Not required for local models (when api_base is set).
        api_base: Base URL for the embedding API.
            For Ollama: "http://localhost:11434/v1"
        embedding_model: Embedding model name. Default: "text-embedding-3-small".
        embedding_dimensions: Embedding output dimensions. Default: 1536.

    Returns:
        Configured TribalMemoryService ready for use.

    Warning:
        If db_path is provided but LanceDB is not installed, falls back to
        in-memory storage. This means data will NOT persist across restarts.
    """
    import logging

    from .embeddings import OpenAIEmbeddingService
    from .vector_store import InMemoryVectorStore, LanceDBVectorStore

    logger = logging.getLogger(__name__)

    resolved_id = instance_id or os.environ.get("TRIBAL_MEMORY_INSTANCE_ID", "default")

    # Only forward optional settings that were explicitly provided, so the
    # embedding service's own defaults apply otherwise.
    embed_kwargs: dict = {"api_key": openai_api_key}
    overrides = {
        "api_base": api_base,
        "model": embedding_model,
        "dimensions": embedding_dimensions,
    }
    embed_kwargs.update({k: v for k, v in overrides.items() if v is not None})

    embedder = OpenAIEmbeddingService(**embed_kwargs)

    if db_path:
        try:
            store = LanceDBVectorStore(
                embedding_service=embedder,
                db_path=db_path
            )
        except ImportError:
            logger.warning(
                "LanceDB not installed. Falling back to in-memory storage. "
                "Data will NOT persist across restarts. Install with: pip install lancedb"
            )
            store = InMemoryVectorStore(embedder)
    else:
        store = InMemoryVectorStore(embedder)

    return TribalMemoryService(
        instance_id=resolved_id,
        embedding_service=embedder,
        vector_store=store
    )
|
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
"""Vector Store implementations.
|
|
2
|
+
|
|
3
|
+
Provides both LanceDB (persistent) and in-memory storage options.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
import time
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Optional, Union
|
|
12
|
+
|
|
13
|
+
from ..interfaces import (
|
|
14
|
+
IVectorStore,
|
|
15
|
+
IEmbeddingService,
|
|
16
|
+
MemoryEntry,
|
|
17
|
+
MemorySource,
|
|
18
|
+
RecallResult,
|
|
19
|
+
StoreResult,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LanceDBVectorStore(IVectorStore):
    """LanceDB-backed vector store for persistent storage.

    Supports both local file storage and LanceDB Cloud.

    The connection is opened lazily on first use; construction never
    touches the database. List-valued fields (tags, related_to) are
    stored as JSON strings, and deletes are soft (a `deleted` flag).
    """

    # Single table holding all memories.
    TABLE_NAME = "memories"

    def __init__(
        self,
        embedding_service: IEmbeddingService,
        db_path: Optional[Union[str, Path]] = None,
        db_uri: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        # Used both to backfill missing embeddings in store() and to
        # discover the expected vector dimension for the table schema.
        self.embedding_service = embedding_service
        self.db_path = Path(db_path) if db_path else None
        self.db_uri = db_uri
        # Cloud API key; falls back to the LANCEDB_API_KEY env var.
        self.api_key = api_key or os.environ.get("LANCEDB_API_KEY")

        # Populated lazily by _ensure_initialized().
        self._db = None
        self._table = None
        self._initialized = False

    async def _ensure_initialized(self):
        """Lazily initialize database connection.

        Raises:
            ImportError: if the lancedb package is not installed.
            ValueError: if neither db_path nor db_uri was provided.
        """
        if self._initialized:
            return

        try:
            import lancedb
        except ImportError:
            raise ImportError("LanceDB not installed. Run: pip install lancedb")

        # db_uri (cloud) takes precedence over db_path (local directory).
        if self.db_uri:
            self._db = lancedb.connect(self.db_uri, api_key=self.api_key)
        elif self.db_path:
            self.db_path.mkdir(parents=True, exist_ok=True)
            self._db = lancedb.connect(str(self.db_path))
        else:
            raise ValueError("Either db_path or db_uri must be provided")

        # Reuse an existing table if present; otherwise create it.
        if self.TABLE_NAME in self._db.table_names():
            self._table = self._db.open_table(self.TABLE_NAME)
        else:
            self._table = self._create_table()

        self._initialized = True

    def _create_table(self) -> "lancedb.table.Table":
        """Create the memories table with the defined schema.

        Timestamps are stored as ISO-8601 strings; tags/related_to as
        JSON-encoded strings (see store()/_row_to_entry()).
        """
        import pyarrow as pa

        schema = pa.schema([
            pa.field("id", pa.string()),
            pa.field("content", pa.string()),
            # Fixed-size float vector sized from the embedding service.
            pa.field("vector", pa.list_(pa.float32(), self._get_embedding_dim())),
            pa.field("source_instance", pa.string()),
            pa.field("source_type", pa.string()),
            pa.field("created_at", pa.string()),
            pa.field("updated_at", pa.string()),
            pa.field("tags", pa.string()),
            pa.field("context", pa.string()),
            pa.field("confidence", pa.float32()),
            pa.field("supersedes", pa.string()),
            pa.field("related_to", pa.string()),
            # Soft-delete marker; queries filter on this flag.
            pa.field("deleted", pa.bool_()),
        ])

        return self._db.create_table(self.TABLE_NAME, schema=schema)

    def _get_embedding_dim(self) -> int:
        """Get the expected embedding dimension from the embedding service."""
        if hasattr(self.embedding_service, 'dimensions'):
            return self.embedding_service.dimensions
        return 1536  # Default for text-embedding-3-small

    async def store(self, entry: MemoryEntry) -> StoreResult:
        """Persist one entry; never raises — failures come back in StoreResult."""
        await self._ensure_initialized()

        try:
            # Backfill the embedding if the caller did not provide one.
            if entry.embedding is None:
                entry.embedding = await self.embedding_service.embed(entry.content)

            # Validate embedding dimensions
            expected_dim = self._get_embedding_dim()
            if len(entry.embedding) != expected_dim:
                return StoreResult(
                    success=False,
                    error=f"Invalid embedding dimension: got {len(entry.embedding)}, expected {expected_dim}"
                )

            # Optional/None fields are flattened to "" for the string schema;
            # list fields are JSON-encoded (mirrored in _row_to_entry()).
            row = {
                "id": entry.id,
                "content": entry.content,
                "vector": entry.embedding,
                "source_instance": entry.source_instance,
                "source_type": entry.source_type.value,
                "created_at": entry.created_at.isoformat(),
                "updated_at": entry.updated_at.isoformat(),
                "tags": json.dumps(entry.tags),
                "context": entry.context or "",
                "confidence": entry.confidence,
                "supersedes": entry.supersedes or "",
                "related_to": json.dumps(entry.related_to),
                "deleted": False,
            }

            self._table.add([row])
            return StoreResult(success=True, memory_id=entry.id)

        except Exception as e:
            return StoreResult(success=False, error=str(e))

    async def recall(
        self,
        query_embedding: list[float],
        limit: int = 10,
        min_similarity: float = 0.7,
        filters: Optional[dict] = None,
    ) -> list[RecallResult]:
        """Vector search with post-filtering by similarity and tags.

        Over-fetches (limit * 2) so that similarity/tag filtering still
        leaves enough candidates; final list is re-sorted and truncated.
        NOTE(review): with heavy tag filtering, fewer than `limit` results
        may survive even if more matches exist in the table.
        """
        await self._ensure_initialized()

        start = time.perf_counter()

        results = (
            self._table.search(query_embedding)
            .where("deleted = false")
            .limit(limit * 2)
            .to_list()
        )

        elapsed_ms = (time.perf_counter() - start) * 1000

        recall_results = []
        for row in results:
            # LanceDB returns L2 distance. Convert to cosine similarity.
            # For normalized vectors (which OpenAI embeddings are):
            # L2_distance² = 2 * (1 - cosine_similarity)
            # Therefore: cosine_similarity = 1 - (L2_distance² / 2)
            distance = row.get("_distance", 0)
            similarity = max(0, 1 - (distance * distance / 2))

            if similarity < min_similarity:
                continue

            entry = self._row_to_entry(row)

            # Apply filters
            if filters:
                if "tags" in filters and filters["tags"]:
                    # Keep the row if it shares at least one requested tag.
                    if not any(t in entry.tags for t in filters["tags"]):
                        continue

            recall_results.append(RecallResult(
                memory=entry,
                similarity_score=similarity,
                retrieval_time_ms=elapsed_ms
            ))

        recall_results.sort(key=lambda x: x.similarity_score, reverse=True)
        return recall_results[:limit]

    async def get(self, memory_id: str) -> Optional[MemoryEntry]:
        """Fetch a single non-deleted entry by id, or None if absent."""
        await self._ensure_initialized()

        # Sanitize memory_id to prevent SQL injection
        safe_id = self._sanitize_id(memory_id)

        results = (
            self._table.search()
            .where(f"id = '{safe_id}' AND deleted = false")
            .limit(1)
            .to_list()
        )

        if not results:
            return None
        return self._row_to_entry(results[0])

    async def delete(self, memory_id: str) -> bool:
        """Soft-delete an entry by setting its `deleted` flag.

        Returns True if the update call succeeded — NOTE(review): LanceDB's
        update does not report whether any row matched, so True does not
        guarantee the id existed.
        """
        await self._ensure_initialized()

        # Sanitize memory_id to prevent SQL injection
        safe_id = self._sanitize_id(memory_id)

        try:
            self._table.update(
                where=f"id = '{safe_id}'",
                values={"deleted": True, "updated_at": datetime.utcnow().isoformat()}
            )
            return True
        except Exception:
            return False

    def _sanitize_id(self, memory_id: str) -> str:
        """Sanitize memory_id to prevent SQL injection.

        UUIDs should only contain alphanumeric characters and hyphens.
        Rejects any input containing quotes, semicolons, or other SQL metacharacters.

        Raises:
            ValueError: if memory_id contains disallowed characters.
        """
        import re
        # UUID pattern: only allow alphanumeric and hyphens
        if not re.match(r'^[a-zA-Z0-9\-]+$', memory_id):
            raise ValueError(f"Invalid memory_id format: {memory_id[:20]}...")
        return memory_id

    async def list(
        self,
        limit: int = 1000,
        offset: int = 0,
        filters: Optional[dict] = None,
    ) -> list[MemoryEntry]:
        """List non-deleted entries with offset/limit paging.

        Tag filtering happens after paging, so a filtered page may contain
        fewer than `limit` entries.
        """
        await self._ensure_initialized()

        # Fetch enough rows to cover the requested page.
        results = (
            self._table.search()
            .where("deleted = false")
            .limit(limit + offset)
            .to_list()
        )

        entries = [self._row_to_entry(row) for row in results[offset:offset + limit]]

        if filters and "tags" in filters and filters["tags"]:
            entries = [e for e in entries if any(t in e.tags for t in filters["tags"])]

        return entries

    async def count(self, filters: Optional[dict] = None) -> int:
        """Count entries by listing — capped at 100,000 rows."""
        entries = await self.list(limit=100000, filters=filters)
        return len(entries)

    def _row_to_entry(self, row: dict) -> MemoryEntry:
        """Inverse of the row mapping in store(): decode JSON fields,
        parse ISO timestamps, and turn "" sentinels back into None."""
        return MemoryEntry(
            id=row["id"],
            content=row["content"],
            embedding=row.get("vector"),
            source_instance=row.get("source_instance", "unknown"),
            source_type=MemorySource(row.get("source_type", "unknown")),
            created_at=datetime.fromisoformat(row["created_at"]) if row.get("created_at") else datetime.utcnow(),
            updated_at=datetime.fromisoformat(row["updated_at"]) if row.get("updated_at") else datetime.utcnow(),
            tags=json.loads(row.get("tags", "[]")),
            context=row.get("context") or None,
            confidence=row.get("confidence", 1.0),
            supersedes=row.get("supersedes") or None,
            related_to=json.loads(row.get("related_to", "[]")),
        )
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class InMemoryVectorStore(IVectorStore):
    """Simple in-memory vector store for testing.

    Entries live in a dict keyed by id; deletes are soft, tracked as a
    set of tombstoned ids so provenance is preserved.
    """

    def __init__(self, embedding_service: IEmbeddingService):
        self.embedding_service = embedding_service
        self._entries: dict[str, MemoryEntry] = {}
        self._tombstones: set[str] = set()

    async def store(self, entry: MemoryEntry) -> StoreResult:
        """Save an entry, embedding its content first if needed."""
        if entry.embedding is None:
            entry.embedding = await self.embedding_service.embed(entry.content)

        self._entries[entry.id] = entry
        return StoreResult(success=True, memory_id=entry.id)

    async def recall(
        self,
        query_embedding: list[float],
        limit: int = 10,
        min_similarity: float = 0.7,
        filters: Optional[dict] = None,
    ) -> list[RecallResult]:
        """Brute-force similarity scan over all live entries."""
        started = time.perf_counter()
        wanted_tags = filters.get("tags") if filters else None

        scored: list[tuple[MemoryEntry, float]] = []
        for candidate in self._entries.values():
            # Skip tombstoned entries and any without an embedding.
            if candidate.id in self._tombstones or candidate.embedding is None:
                continue
            # Tag filter: at least one requested tag must match.
            if wanted_tags and not any(t in candidate.tags for t in wanted_tags):
                continue
            score = self.embedding_service.similarity(query_embedding, candidate.embedding)
            if score >= min_similarity:
                scored.append((candidate, score))

        scored.sort(key=lambda pair: pair[1], reverse=True)
        took_ms = (time.perf_counter() - started) * 1000

        return [
            RecallResult(memory=entry, similarity_score=score, retrieval_time_ms=took_ms)
            for entry, score in scored[:limit]
        ]

    async def get(self, memory_id: str) -> Optional[MemoryEntry]:
        """Return the entry for memory_id, or None if deleted/unknown."""
        if memory_id in self._tombstones:
            return None
        return self._entries.get(memory_id)

    async def delete(self, memory_id: str) -> bool:
        """Soft-delete an entry; False if the id was never stored."""
        if memory_id not in self._entries:
            return False
        self._tombstones.add(memory_id)
        return True

    async def list(
        self,
        limit: int = 1000,
        offset: int = 0,
        filters: Optional[dict] = None,
    ) -> list[MemoryEntry]:
        """List live entries in insertion order with offset/limit paging."""
        alive = [e for e in self._entries.values() if e.id not in self._tombstones]

        if filters and filters.get("tags"):
            wanted = filters["tags"]
            alive = [e for e in alive if any(t in e.tags for t in wanted)]

        return alive[offset:offset + limit]

    async def upsert(self, entry: MemoryEntry) -> StoreResult:
        """Insert or replace, clearing any soft-delete tombstone."""
        self._tombstones.discard(entry.id)
        if entry.embedding is None:
            entry.embedding = await self.embedding_service.embed(entry.content)
        self._entries[entry.id] = entry
        return StoreResult(success=True, memory_id=entry.id)

    async def count(self, filters: Optional[dict] = None) -> int:
        """Count live entries (capped at 100,000) matching the filters."""
        return len(await self.list(limit=100000, filters=filters))
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Testing utilities for Tribal Memory."""
|
|
2
|
+
|
|
3
|
+
from .mocks import (
|
|
4
|
+
MockEmbeddingService,
|
|
5
|
+
MockVectorStore,
|
|
6
|
+
MockMemoryService,
|
|
7
|
+
MockTimestampService,
|
|
8
|
+
)
|
|
9
|
+
from .metrics import LatencyTracker, SimilarityCalculator, TestResultLogger
|
|
10
|
+
from .fixtures import load_test_data, TestDataSet
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"MockEmbeddingService",
|
|
14
|
+
"MockVectorStore",
|
|
15
|
+
"MockMemoryService",
|
|
16
|
+
"MockTimestampService",
|
|
17
|
+
"LatencyTracker",
|
|
18
|
+
"SimilarityCalculator",
|
|
19
|
+
"TestResultLogger",
|
|
20
|
+
"load_test_data",
|
|
21
|
+
"TestDataSet",
|
|
22
|
+
]
|