hindsight-api 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -9
- hindsight_api/alembic/env.py +5 -8
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
- hindsight_api/api/__init__.py +10 -10
- hindsight_api/api/http.py +575 -593
- hindsight_api/api/mcp.py +30 -28
- hindsight_api/banner.py +13 -6
- hindsight_api/config.py +9 -13
- hindsight_api/engine/__init__.py +9 -9
- hindsight_api/engine/cross_encoder.py +22 -21
- hindsight_api/engine/db_utils.py +5 -4
- hindsight_api/engine/embeddings.py +22 -21
- hindsight_api/engine/entity_resolver.py +81 -75
- hindsight_api/engine/llm_wrapper.py +61 -79
- hindsight_api/engine/memory_engine.py +603 -625
- hindsight_api/engine/query_analyzer.py +100 -97
- hindsight_api/engine/response_models.py +105 -106
- hindsight_api/engine/retain/__init__.py +9 -16
- hindsight_api/engine/retain/bank_utils.py +34 -58
- hindsight_api/engine/retain/chunk_storage.py +4 -12
- hindsight_api/engine/retain/deduplication.py +9 -28
- hindsight_api/engine/retain/embedding_processing.py +4 -11
- hindsight_api/engine/retain/embedding_utils.py +3 -4
- hindsight_api/engine/retain/entity_processing.py +7 -17
- hindsight_api/engine/retain/fact_extraction.py +155 -165
- hindsight_api/engine/retain/fact_storage.py +11 -23
- hindsight_api/engine/retain/link_creation.py +11 -39
- hindsight_api/engine/retain/link_utils.py +166 -95
- hindsight_api/engine/retain/observation_regeneration.py +39 -52
- hindsight_api/engine/retain/orchestrator.py +72 -62
- hindsight_api/engine/retain/types.py +49 -43
- hindsight_api/engine/search/__init__.py +5 -5
- hindsight_api/engine/search/fusion.py +6 -15
- hindsight_api/engine/search/graph_retrieval.py +22 -23
- hindsight_api/engine/search/mpfp_retrieval.py +76 -92
- hindsight_api/engine/search/observation_utils.py +9 -16
- hindsight_api/engine/search/reranking.py +4 -7
- hindsight_api/engine/search/retrieval.py +87 -66
- hindsight_api/engine/search/scoring.py +5 -7
- hindsight_api/engine/search/temporal_extraction.py +8 -11
- hindsight_api/engine/search/think_utils.py +115 -39
- hindsight_api/engine/search/trace.py +68 -39
- hindsight_api/engine/search/tracer.py +44 -35
- hindsight_api/engine/search/types.py +20 -17
- hindsight_api/engine/task_backend.py +21 -26
- hindsight_api/engine/utils.py +25 -10
- hindsight_api/main.py +21 -40
- hindsight_api/mcp_local.py +190 -0
- hindsight_api/metrics.py +44 -30
- hindsight_api/migrations.py +10 -8
- hindsight_api/models.py +60 -72
- hindsight_api/pg0.py +22 -23
- hindsight_api/server.py +3 -6
- {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +2 -2
- hindsight_api-0.1.6.dist-info/RECORD +64 -0
- {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.1.5.dist-info/RECORD +0 -63
- {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
@@ -8,22 +8,23 @@ This implements a sophisticated memory architecture that combines:
 4. Spreading activation: Search through the graph with activation decay
 5. Dynamic weighting: Recency and frequency-based importance
 """
-
-import os
-from datetime import datetime, timedelta, timezone
-from typing import Any, Dict, List, Optional, Tuple, Union, TypedDict, TYPE_CHECKING
-import asyncpg
+
 import asyncio
-
-from .cross_encoder import CrossEncoderModel, create_cross_encoder_from_env
+import logging
 import time
-import numpy as np
 import uuid
-import
+from datetime import UTC, datetime, timedelta
+from typing import TYPE_CHECKING, Any, TypedDict
+
+import asyncpg
+import numpy as np
 from pydantic import BaseModel, Field
 
+from .cross_encoder import CrossEncoderModel
+from .embeddings import Embeddings, create_embeddings_from_env
+
 if TYPE_CHECKING:
-
+    pass
 
 
 class RetainContentDict(TypedDict, total=False):
@@ -36,30 +37,31 @@ class RetainContentDict(TypedDict, total=False):
         metadata: Custom key-value metadata (optional)
         document_id: Document ID for this content item (optional)
     """
+
     content: str  # Required
     context: str
     event_date: datetime
-    metadata:
+    metadata: dict[str, str]
     document_id: str
 
-
-from
-
-
-)
+
+from enum import Enum
+
+from ..pg0 import EmbeddedPostgres
 from .entity_resolver import EntityResolver
-from .retain import embedding_utils, bank_utils
-from .search import think_utils, observation_utils
 from .llm_wrapper import LLMConfig
-from .
-from .
+from .query_analyzer import QueryAnalyzer
+from .response_models import VALID_RECALL_FACT_TYPES, EntityObservation, EntityState, MemoryFact, ReflectResult
+from .response_models import RecallResult as RecallResultModel
+from .retain import bank_utils, embedding_utils
+from .search import observation_utils, think_utils
 from .search.reranking import CrossEncoderReranker
-from
-from enum import Enum
+from .task_backend import AsyncIOQueueBackend, TaskBackend
 
 
 class Budget(str, Enum):
     """Budget levels for recall/reflect operations."""
+
     LOW = "low"
     MID = "mid"
     HIGH = "high"
@@ -67,20 +69,20 @@ class Budget(str, Enum):
 
 def utcnow():
     """Get current UTC time with timezone info."""
-    return datetime.now(
+    return datetime.now(UTC)
 
 
 # Logger for memory system
 logger = logging.getLogger(__name__)
 
-from .db_utils import acquire_with_retry, retry_with_backoff
-
 import tiktoken
-
+
+from .db_utils import acquire_with_retry
 
 # Cache tiktoken encoding for token budget filtering (module-level singleton)
 _TIKTOKEN_ENCODING = None
 
+
 def _get_tiktoken_encoding():
     """Get cached tiktoken encoding (cl100k_base for GPT-4/3.5)."""
     global _TIKTOKEN_ENCODING
@@ -102,17 +104,17 @@ class MemoryEngine:
 
     def __init__(
         self,
-        db_url:
-        memory_llm_provider:
-        memory_llm_api_key:
-        memory_llm_model:
-        memory_llm_base_url:
-        embeddings:
-        cross_encoder:
-        query_analyzer:
+        db_url: str | None = None,
+        memory_llm_provider: str | None = None,
+        memory_llm_api_key: str | None = None,
+        memory_llm_model: str | None = None,
+        memory_llm_base_url: str | None = None,
+        embeddings: Embeddings | None = None,
+        cross_encoder: CrossEncoderModel | None = None,
+        query_analyzer: QueryAnalyzer | None = None,
         pool_min_size: int = 5,
         pool_max_size: int = 100,
-        task_backend:
+        task_backend: TaskBackend | None = None,
         run_migrations: bool = True,
     ):
         """
@@ -138,6 +140,7 @@ class MemoryEngine:
         """
         # Load config from environment for any missing parameters
         from ..config import get_config
+
         config = get_config()
 
         # Apply defaults from config
@@ -147,8 +150,8 @@ class MemoryEngine:
         memory_llm_model = memory_llm_model or config.llm_model
         memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
         # Track pg0 instance (if used)
-        self._pg0:
-        self._pg0_instance_name:
+        self._pg0: EmbeddedPostgres | None = None
+        self._pg0_instance_name: str | None = None
 
         # Initialize PostgreSQL connection URL
         # The actual URL will be set during initialize() after starting the server
@@ -175,7 +178,6 @@ class MemoryEngine:
         self._pg0_port = None
         self.db_url = db_url
 
-
         # Set default base URL if not provided
         if memory_llm_base_url is None:
             if memory_llm_provider.lower() == "groq":
@@ -206,6 +208,7 @@ class MemoryEngine:
             self.query_analyzer = query_analyzer
         else:
             from .query_analyzer import DateparserQueryAnalyzer
+
             self.query_analyzer = DateparserQueryAnalyzer()
 
         # Initialize LLM configuration
@@ -224,10 +227,7 @@ class MemoryEngine:
         self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)
 
         # Initialize task backend
-        self._task_backend = task_backend or AsyncIOQueueBackend(
-            batch_size=100,
-            batch_interval=1.0
-        )
+        self._task_backend = task_backend or AsyncIOQueueBackend(batch_size=100, batch_interval=1.0)
 
         # Backpressure mechanism: limit concurrent searches to prevent overwhelming the database
         # Limit concurrent searches to prevent connection pool exhaustion
@@ -243,14 +243,14 @@ class MemoryEngine:
         # initialize encoding eagerly to avoid delaying the first time
         _get_tiktoken_encoding()
 
-    async def _handle_access_count_update(self, task_dict:
+    async def _handle_access_count_update(self, task_dict: dict[str, Any]):
         """
         Handler for access count update tasks.
 
         Args:
             task_dict: Dict with 'node_ids' key containing list of node IDs to update
         """
-        node_ids = task_dict.get(
+        node_ids = task_dict.get("node_ids", [])
         if not node_ids:
             return
 
@@ -260,13 +260,12 @@ class MemoryEngine:
             uuid_list = [uuid.UUID(nid) for nid in node_ids]
             async with acquire_with_retry(pool) as conn:
                 await conn.execute(
-                    "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
-                    uuid_list
+                    "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])", uuid_list
                 )
         except Exception as e:
             logger.error(f"Access count handler: Error updating access counts: {e}")
 
-    async def _handle_batch_retain(self, task_dict:
+    async def _handle_batch_retain(self, task_dict: dict[str, Any]):
         """
         Handler for batch retain tasks.
 
@@ -274,23 +273,23 @@ class MemoryEngine:
             task_dict: Dict with 'bank_id', 'contents'
         """
         try:
-            bank_id = task_dict.get(
-            contents = task_dict.get(
-
-            logger.info(f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items")
+            bank_id = task_dict.get("bank_id")
+            contents = task_dict.get("contents", [])
 
-
-                bank_id=bank_id,
-                contents=contents
+            logger.info(
+                f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
             )
 
+            await self.retain_batch_async(bank_id=bank_id, contents=contents)
+
             logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
         except Exception as e:
             logger.error(f"Batch retain handler: Error processing batch retain: {e}")
             import traceback
+
             traceback.print_exc()
 
-    async def execute_task(self, task_dict:
+    async def execute_task(self, task_dict: dict[str, Any]):
         """
         Execute a task by routing it to the appropriate handler.
 
@@ -301,9 +300,9 @@ class MemoryEngine:
             task_dict: Task dictionary with 'type' key and other payload data
                 Example: {'type': 'access_count_update', 'node_ids': [...]}
         """
-        task_type = task_dict.get(
-        operation_id = task_dict.get(
-        retry_count = task_dict.get(
+        task_type = task_dict.get("type")
+        operation_id = task_dict.get("operation_id")
+        retry_count = task_dict.get("retry_count", 0)
         max_retries = 3
 
         # Check if operation was cancelled (only for tasks with operation_id)
@@ -312,8 +311,7 @@ class MemoryEngine:
                 pool = await self._get_pool()
                 async with acquire_with_retry(pool) as conn:
                     result = await conn.fetchrow(
-                        "SELECT id FROM async_operations WHERE id = $1",
-                        uuid.UUID(operation_id)
+                        "SELECT id FROM async_operations WHERE id = $1", uuid.UUID(operation_id)
                     )
                     if not result:
                         # Operation was cancelled, skip processing
@@ -324,15 +322,15 @@ class MemoryEngine:
                 # Continue with processing if we can't check status
 
         try:
-            if task_type ==
+            if task_type == "access_count_update":
                 await self._handle_access_count_update(task_dict)
-            elif task_type ==
+            elif task_type == "reinforce_opinion":
                 await self._handle_reinforce_opinion(task_dict)
-            elif task_type ==
+            elif task_type == "form_opinion":
                 await self._handle_form_opinion(task_dict)
-            elif task_type ==
+            elif task_type == "batch_retain":
                 await self._handle_batch_retain(task_dict)
-            elif task_type ==
+            elif task_type == "regenerate_observations":
                 await self._handle_regenerate_observations(task_dict)
             else:
                 logger.error(f"Unknown task type: {task_type}")
@@ -347,14 +345,17 @@ class MemoryEngine:
 
         except Exception as e:
             # Task failed - check if we should retry
-            logger.error(
+            logger.error(
+                f"Task execution failed (attempt {retry_count + 1}/{max_retries + 1}): {task_type}, error: {e}"
+            )
             import traceback
+
             error_traceback = traceback.format_exc()
             traceback.print_exc()
 
             if retry_count < max_retries:
                 # Reschedule with incremented retry count
-                task_dict[
+                task_dict["retry_count"] = retry_count + 1
                 logger.info(f"Rescheduling task {task_type} (retry {retry_count + 1}/{max_retries})")
                 await self._task_backend.submit_task(task_dict)
             else:
@@ -368,10 +369,7 @@ class MemoryEngine:
             try:
                 pool = await self._get_pool()
                 async with acquire_with_retry(pool) as conn:
-                    await conn.execute(
-                        "DELETE FROM async_operations WHERE id = $1",
-                        uuid.UUID(operation_id)
-                    )
+                    await conn.execute("DELETE FROM async_operations WHERE id = $1", uuid.UUID(operation_id))
             except Exception as e:
                 logger.error(f"Failed to delete async operation record {operation_id}: {e}")
 
@@ -391,7 +389,7 @@ class MemoryEngine:
                     WHERE id = $1
                     """,
                     uuid.UUID(operation_id),
-                    truncated_error
+                    truncated_error,
                 )
                 logger.info(f"Marked async operation as failed: {operation_id}")
             except Exception as e:
@@ -406,8 +404,6 @@ class MemoryEngine:
         if self._initialized:
             return
 
-        import concurrent.futures
-
         # Run model loading in thread pool (CPU-bound) in parallel with pg0 startup
         loop = asyncio.get_event_loop()
 
@@ -429,10 +425,7 @@ class MemoryEngine:
             """Initialize embedding model."""
             # For local providers, run in thread pool to avoid blocking event loop
             if self.embeddings.provider_name == "local":
-                await loop.run_in_executor(
-                    None,
-                    lambda: asyncio.run(self.embeddings.initialize())
-                )
+                await loop.run_in_executor(None, lambda: asyncio.run(self.embeddings.initialize()))
             else:
                 await self.embeddings.initialize()
 
@@ -441,10 +434,7 @@ class MemoryEngine:
             cross_encoder = self._cross_encoder_reranker.cross_encoder
             # For local providers, run in thread pool to avoid blocking event loop
             if cross_encoder.provider_name == "local":
-                await loop.run_in_executor(
-                    None,
-                    lambda: asyncio.run(cross_encoder.initialize())
-                )
+                await loop.run_in_executor(None, lambda: asyncio.run(cross_encoder.initialize()))
             else:
                 await cross_encoder.initialize()
 
@@ -469,6 +459,7 @@ class MemoryEngine:
         # Run database migrations if enabled
         if self._run_migrations:
             from ..migrations import run_migrations
+
             logger.info("Running database migrations...")
             run_migrations(self.db_url)
 
@@ -555,7 +546,6 @@ class MemoryEngine:
             self._pg0 = None
             logger.info("pg0 stopped")
 
-
     async def wait_for_background_tasks(self):
         """
         Wait for all pending background tasks to complete.
@@ -563,7 +553,7 @@ class MemoryEngine:
         This is useful in tests to ensure background tasks (like opinion reinforcement)
         complete before making assertions.
         """
-        if hasattr(self._task_backend,
+        if hasattr(self._task_backend, "wait_for_pending_tasks"):
            await self._task_backend.wait_for_pending_tasks()
 
    def _format_readable_date(self, dt: datetime) -> str:
@@ -596,12 +586,12 @@ class MemoryEngine:
         self,
         conn,
         bank_id: str,
-        texts:
-        embeddings:
+        texts: list[str],
+        embeddings: list[list[float]],
         event_date: datetime,
         time_window_hours: int = 24,
-        similarity_threshold: float = 0.95
-    ) ->
+        similarity_threshold: float = 0.95,
+    ) -> list[bool]:
         """
         Check which facts are duplicates using semantic similarity + temporal window.
 
@@ -635,6 +625,7 @@ class MemoryEngine:
 
         # Fetch ALL existing facts in time window ONCE (much faster than N queries)
         import time as time_mod
+
         fetch_start = time_mod.time()
         existing_facts = await conn.fetch(
             """
@@ -643,7 +634,9 @@ class MemoryEngine:
             WHERE bank_id = $1
               AND event_date BETWEEN $2 AND $3
             """,
-            bank_id,
+            bank_id,
+            time_lower,
+            time_upper,
         )
 
         # If no existing facts, nothing is duplicate
@@ -651,17 +644,17 @@ class MemoryEngine:
             return [False] * len(texts)
 
         # Compute similarities in Python (vectorized with numpy)
-        import numpy as np
         is_duplicate = []
 
         # Convert existing embeddings to numpy for faster computation
         embedding_arrays = []
         for row in existing_facts:
-            raw_emb = row[
+            raw_emb = row["embedding"]
             # Handle different pgvector formats
             if isinstance(raw_emb, str):
                 # Parse string format: "[1.0, 2.0, ...]"
                 import json
+
                 emb = np.array(json.loads(raw_emb), dtype=np.float32)
             elif isinstance(raw_emb, (list, tuple)):
                 emb = np.array(raw_emb, dtype=np.float32)
@@ -691,7 +684,6 @@ class MemoryEngine:
             max_similarity = np.max(similarities) if len(similarities) > 0 else 0
             is_duplicate.append(max_similarity > similarity_threshold)
 
-
         return is_duplicate
 
     def retain(
@@ -699,8 +691,8 @@ class MemoryEngine:
         bank_id: str,
         content: str,
         context: str = "",
-        event_date:
-    ) ->
+        event_date: datetime | None = None,
+    ) -> list[str]:
         """
         Store content as memory units (synchronous wrapper).
 
@@ -724,11 +716,11 @@ class MemoryEngine:
         bank_id: str,
         content: str,
         context: str = "",
-        event_date:
-        document_id:
-        fact_type_override:
-        confidence_score:
-    ) ->
+        event_date: datetime | None = None,
+        document_id: str | None = None,
+        fact_type_override: str | None = None,
+        confidence_score: float | None = None,
+    ) -> list[str]:
         """
         Store content as memory units with temporal and semantic links (ASYNC version).
 
@@ -747,11 +739,7 @@ class MemoryEngine:
             List of created unit IDs
         """
         # Build content dict
-        content_dict: RetainContentDict = {
-            "content": content,
-            "context": context,
-            "event_date": event_date
-        }
+        content_dict: RetainContentDict = {"content": content, "context": context, "event_date": event_date}
         if document_id:
             content_dict["document_id"] = document_id
 
@@ -760,7 +748,7 @@ class MemoryEngine:
             bank_id=bank_id,
             contents=[content_dict],
             fact_type_override=fact_type_override,
-            confidence_score=confidence_score
+            confidence_score=confidence_score,
         )
 
         # Return the first (and only) list of unit IDs
@@ -769,11 +757,11 @@ class MemoryEngine:
     async def retain_batch_async(
         self,
         bank_id: str,
-        contents:
-        document_id:
-        fact_type_override:
-        confidence_score:
-    ) ->
+        contents: list[RetainContentDict],
+        document_id: str | None = None,
+        fact_type_override: str | None = None,
+        confidence_score: float | None = None,
+    ) -> list[list[str]]:
         """
         Store multiple content items as memory units in ONE batch operation.
 
@@ -839,7 +827,9 @@ class MemoryEngine:
 
         if total_chars > CHARS_PER_BATCH:
             # Split into smaller batches based on character count
-            logger.info(
+            logger.info(
+                f"Large batch detected ({total_chars:,} chars from {len(contents)} items). Splitting into sub-batches of ~{CHARS_PER_BATCH:,} chars each..."
+            )
 
             sub_batches = []
             current_batch = []
@@ -868,7 +858,9 @@ class MemoryEngine:
             all_results = []
             for i, sub_batch in enumerate(sub_batches, 1):
                 sub_batch_chars = sum(len(item.get("content", "")) for item in sub_batch)
-                logger.info(
+                logger.info(
+                    f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars"
+                )
 
                 sub_results = await self._retain_batch_async_internal(
                     bank_id=bank_id,
@@ -876,12 +868,14 @@ class MemoryEngine:
                     document_id=document_id,
                     is_first_batch=i == 1,  # Only upsert on first batch
                     fact_type_override=fact_type_override,
-                    confidence_score=confidence_score
+                    confidence_score=confidence_score,
                 )
                 all_results.extend(sub_results)
 
             total_time = time.time() - start_time
-            logger.info(
+            logger.info(
+                f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s"
+            )
             return all_results
 
         # Small batch - use internal method directly
@@ -891,18 +885,18 @@ class MemoryEngine:
             document_id=document_id,
             is_first_batch=True,
             fact_type_override=fact_type_override,
-            confidence_score=confidence_score
+            confidence_score=confidence_score,
         )
 
     async def _retain_batch_async_internal(
         self,
         bank_id: str,
-        contents:
-        document_id:
+        contents: list[RetainContentDict],
+        document_id: str | None = None,
         is_first_batch: bool = True,
-        fact_type_override:
-        confidence_score:
-    ) ->
+        fact_type_override: str | None = None,
+        confidence_score: float | None = None,
+    ) -> list[list[str]]:
         """
         Internal method for batch processing without chunking logic.
 
@@ -938,7 +932,7 @@ class MemoryEngine:
             document_id=document_id,
             is_first_batch=is_first_batch,
             fact_type_override=fact_type_override,
-            confidence_score=confidence_score
+            confidence_score=confidence_score,
         )
 
     def recall(
@@ -949,7 +943,7 @@ class MemoryEngine:
         budget: Budget = Budget.MID,
         max_tokens: int = 4096,
         enable_trace: bool = False,
-    ) -> tuple[
+    ) -> tuple[list[dict[str, Any]], Any | None]:
         """
         Recall memories using 4-way parallel retrieval (synchronous wrapper).
 
@@ -968,19 +962,17 @@ class MemoryEngine:
             Tuple of (results, trace)
         """
         # Run async version synchronously
-        return asyncio.run(self.recall_async(
-            bank_id, query, [fact_type], budget, max_tokens, enable_trace
-        ))
+        return asyncio.run(self.recall_async(bank_id, query, [fact_type], budget, max_tokens, enable_trace))
 
     async def recall_async(
         self,
         bank_id: str,
         query: str,
-        fact_type:
+        fact_type: list[str],
         budget: Budget = Budget.MID,
         max_tokens: int = 4096,
         enable_trace: bool = False,
-        question_date:
+        question_date: datetime | None = None,
         include_entities: bool = False,
         max_entity_tokens: int = 1024,
         include_chunks: bool = False,
@@ -1027,11 +1019,7 @@ class MemoryEngine:
         )
 
         # Map budget enum to thinking_budget number
-        budget_mapping = {
-            Budget.LOW: 100,
-            Budget.MID: 300,
-            Budget.HIGH: 1000
-        }
+        budget_mapping = {Budget.LOW: 100, Budget.MID: 300, Budget.HIGH: 1000}
        thinking_budget = budget_mapping[budget]
 
        # Backpressure: limit concurrent recalls to prevent overwhelming the database
@@ -1041,20 +1029,29 @@ class MemoryEngine:
         for attempt in range(max_retries + 1):
             try:
                 return await self._search_with_retries(
-                    bank_id,
-
+                    bank_id,
+                    query,
+                    fact_type,
+                    thinking_budget,
+                    max_tokens,
+                    enable_trace,
+                    question_date,
+                    include_entities,
+                    max_entity_tokens,
+                    include_chunks,
+                    max_chunk_tokens,
                 )
             except Exception as e:
                 # Check if it's a connection error
                 is_connection_error = (
-                    isinstance(e, asyncpg.TooManyConnectionsError)
-                    isinstance(e, asyncpg.CannotConnectNowError)
-                    (isinstance(e, asyncpg.PostgresError) and
+                    isinstance(e, asyncpg.TooManyConnectionsError)
+                    or isinstance(e, asyncpg.CannotConnectNowError)
+                    or (isinstance(e, asyncpg.PostgresError) and "connection" in str(e).lower())
                 )
 
                 if is_connection_error and attempt < max_retries:
                     # Wait with exponential backoff before retry
-                    wait_time = 0.5 * (2
+                    wait_time = 0.5 * (2**attempt)  # 0.5s, 1s, 2s
                     logger.warning(
                         f"Connection error on search attempt {attempt + 1}/{max_retries + 1}: {str(e)}. "
                         f"Retrying in {wait_time:.1f}s..."
@@ -1069,11 +1066,11 @@ class MemoryEngine:
         self,
         bank_id: str,
         query: str,
-        fact_type:
+        fact_type: list[str],
         thinking_budget: int,
         max_tokens: int,
         enable_trace: bool,
-        question_date:
+        question_date: datetime | None = None,
         include_entities: bool = False,
         max_entity_tokens: int = 500,
         include_chunks: bool = False,
@@ -1106,6 +1103,7 @@ class MemoryEngine:
         """
         # Initialize tracer if requested
         from .search.tracer import SearchTracer
+
         tracer = SearchTracer(query, thinking_budget, max_tokens) if enable_trace else None
         if tracer:
             tracer.start()
@@ -1116,7 +1114,9 @@ class MemoryEngine:
         # Buffer logs for clean output in concurrent scenarios
         recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
         log_buffer = []
-        log_buffer.append(
+        log_buffer.append(
+            f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})"
+        )
 
         try:
             # Step 1: Generate query embedding (for semantic search)
@@ -1141,8 +1141,7 @@ class MemoryEngine:
             # Run retrieval for each fact type in parallel
             retrieval_tasks = [
                 retrieve_parallel(
-                    pool, query, query_embedding_str, bank_id, ft, thinking_budget,
-                    question_date, self.query_analyzer
+                    pool, query, query_embedding_str, bank_id, ft, thinking_budget, question_date, self.query_analyzer
                 )
                 for ft in fact_type
             ]
@@ -1159,7 +1158,9 @@ class MemoryEngine:
             for idx, retrieval_result in enumerate(all_retrievals):
                 # Log fact types in this retrieval batch
                 ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
-                logger.debug(
+                logger.debug(
+                    f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
+                )
 
                 semantic_results.extend(retrieval_result.semantic)
                 bm25_results.extend(retrieval_result.bm25)
@@ -1179,11 +1180,13 @@ class MemoryEngine:
 
             # Sort combined results by score (descending) so higher-scored results
             # get better ranks in the trace, regardless of fact type
-            semantic_results.sort(key=lambda r: r.similarity if hasattr(r,
-            bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r,
-            graph_results.sort(key=lambda r: r.activation if hasattr(r,
+            semantic_results.sort(key=lambda r: r.similarity if hasattr(r, "similarity") else 0, reverse=True)
+            bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r, "bm25_score") else 0, reverse=True)
+            graph_results.sort(key=lambda r: r.activation if hasattr(r, "activation") else 0, reverse=True)
             if temporal_results:
-                temporal_results.sort(
+                temporal_results.sort(
+                    key=lambda r: r.combined_score if hasattr(r, "combined_score") else 0, reverse=True
+                )
 
             retrieval_duration = time.time() - retrieval_start
 
@@ -1193,7 +1196,7 @@ class MemoryEngine:
             timing_parts = [
                 f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
                 f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
-                f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)"
+                f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)",
             ]
             temporal_info = ""
             if detected_temporal_constraint:
@@ -1201,7 +1204,9 @@ class MemoryEngine:
                 temporal_count = len(temporal_results) if temporal_results else 0
                 timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
                 temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
-            log_buffer.append(
+            log_buffer.append(
+                f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}"
+            )
 
             # Record retrieval results for tracer - per fact type
             if tracer:
@@ -1220,7 +1225,7 @@ class MemoryEngine:
                         duration_seconds=rr.timings.get("semantic", 0.0),
                         score_field="similarity",
                         metadata={"limit": thinking_budget},
-                        fact_type=ft_name
+                        fact_type=ft_name,
                     )
 
                     # Add BM25 retrieval results for this fact type
@@ -1230,7 +1235,7 @@ class MemoryEngine:
                         duration_seconds=rr.timings.get("bm25", 0.0),
                         score_field="bm25_score",
                         metadata={"limit": thinking_budget},
-                        fact_type=ft_name
+                        fact_type=ft_name,
                     )
 
                     # Add graph retrieval results for this fact type
@@ -1240,7 +1245,7 @@ class MemoryEngine:
                         duration_seconds=rr.timings.get("graph", 0.0),
                         score_field="activation",
                         metadata={"budget": thinking_budget},
-                        fact_type=ft_name
+                        fact_type=ft_name,
                     )
 
                     # Add temporal retrieval results for this fact type (even if empty, to show it ran)
@@ -1251,19 +1256,23 @@ class MemoryEngine:
                         duration_seconds=rr.timings.get("temporal", 0.0),
                         score_field="temporal_score",
                         metadata={"budget": thinking_budget},
-                        fact_type=ft_name
+                        fact_type=ft_name,
                     )
 
                 # Record entry points (from semantic results) for legacy graph view
                 for rank, retrieval in enumerate(semantic_results[:10], start=1):  # Top 10 as entry points
                     tracer.add_entry_point(retrieval.id, retrieval.text, retrieval.similarity or 0.0, rank)
 
-                tracer.add_phase_metric(
-                    "
-
-
-
-
+                tracer.add_phase_metric(
+                    "parallel_retrieval",
+                    step_duration,
+                    {
+                        "semantic_count": len(semantic_results),
+                        "bm25_count": len(bm25_results),
+                        "graph_count": len(graph_results),
+                        "temporal_count": len(temporal_results) if temporal_results else 0,
+                    },
+                )
 
             # Step 3: Merge with RRF
             step_start = time.time()
@@ -1271,7 +1280,9 @@ class MemoryEngine:
 
             # Merge 3 or 4 result lists depending on temporal constraint
             if temporal_results:
-                merged_candidates = reciprocal_rank_fusion(
+                merged_candidates = reciprocal_rank_fusion(
+                    [semantic_results, bm25_results, graph_results, temporal_results]
+                )
             else:
                 merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results])
 
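Results from the semantic, BM25, graph and (optionally) temporal retrievers are merged with reciprocal rank fusion before reranking. The diff only shows the call sites; the sketch below is a standard RRF implementation, and the `k = 60` constant and list-of-IDs shape are assumptions rather than details taken from the package's `fusion.py`:

```python
def reciprocal_rank_fusion(result_lists: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
    """Merge ranked ID lists; each appearance contributes 1 / (k + rank) to the fused score."""
    scores: dict[str, float] = {}
    for results in result_lists:
        for rank, item_id in enumerate(results, start=1):
            scores[item_id] = scores.get(item_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


semantic = ["a", "b", "c"]
bm25 = ["b", "a", "d"]
graph = ["c", "b"]
print(reciprocal_rank_fusion([semantic, bm25, graph]))
# 'b' wins: it ranks near the top of all three lists.
```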
@@ -1280,8 +1291,10 @@ class MemoryEngine:
 
             if tracer:
                 # Convert MergedCandidate to old tuple format for tracer
-                tracer_merged = [
-
+                tracer_merged = [
+                    (mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
+                    for mc in merged_candidates
+                ]
                 tracer.add_rrf_merged(tracer_merged)
                 tracer.add_phase_metric("rrf_merge", step_duration, {"candidates_merged": len(merged_candidates)})
 
@@ -1318,14 +1331,15 @@ class MemoryEngine:
                 sr.recency = 0.5  # default for missing dates
                 if sr.retrieval.occurred_start:
                     occurred = sr.retrieval.occurred_start
-                    if hasattr(occurred,
-
-                        occurred = occurred.replace(tzinfo=timezone.utc)
+                    if hasattr(occurred, "tzinfo") and occurred.tzinfo is None:
+                        occurred = occurred.replace(tzinfo=UTC)
                     days_ago = (now - occurred).total_seconds() / 86400
                     sr.recency = max(0.1, 1.0 - (days_ago / 365))  # Linear decay over 1 year
 
                 # Get temporal proximity if available (already 0-1)
-                sr.temporal =
+                sr.temporal = (
+                    sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5
+                )
 
                 # Weighted combination
                 # Cross-encoder: 60% (semantic relevance)
@@ -1333,27 +1347,32 @@ class MemoryEngine:
                 # Temporal proximity: 10% (time relevance for temporal queries)
                 # Recency: 10% (prefer recent facts)
                 sr.combined_score = (
-                    0.6 * sr.cross_encoder_score_normalized
-                    0.2 * sr.rrf_normalized
-                    0.1 * sr.temporal
-                    0.1 * sr.recency
+                    0.6 * sr.cross_encoder_score_normalized
+                    + 0.2 * sr.rrf_normalized
+                    + 0.1 * sr.temporal
+                    + 0.1 * sr.recency
                 )
                 sr.weight = sr.combined_score  # Update weight for final ranking
 
             # Re-sort by combined score
             scored_results.sort(key=lambda x: x.weight, reverse=True)
-            log_buffer.append(
+            log_buffer.append(
+                " [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)"
+            )
 
             # Add reranked results to tracer AFTER combined scoring (so normalized values are included)
             if tracer:
                 results_dict = [sr.to_dict() for sr in scored_results]
-                tracer_merged = [
-
+                tracer_merged = [
+                    (mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
+                    for mc in merged_candidates
+                ]
                 tracer.add_reranked(results_dict, tracer_merged)
-                tracer.add_phase_metric(
-                    "
-
-
+                tracer.add_phase_metric(
+                    "reranking",
+                    step_duration,
+                    {"reranker_type": "cross-encoder", "candidates_reranked": len(scored_results)},
+                )
 
             # Step 5: Truncate to thinking_budget * 2 for token filtering
             rerank_limit = thinking_budget * 2
@@ -1372,14 +1391,16 @@ class MemoryEngine:
             top_scored = [sr for sr in top_scored if sr.id in filtered_ids]
 
             step_duration = time.time() - step_start
-            log_buffer.append(
+            log_buffer.append(
+                f" [6] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s"
+            )
 
             if tracer:
-                tracer.add_phase_metric(
-                    "
-
-                    "max_tokens": max_tokens
-
+                tracer.add_phase_metric(
+                    "token_filtering",
+                    step_duration,
+                    {"results_selected": len(top_scored), "tokens_used": total_tokens, "max_tokens": max_tokens},
+                )
 
             # Record visits for all retrieved nodes
             if tracer:
@@ -1398,16 +1419,13 @@ class MemoryEngine:
                         semantic_similarity=sr.retrieval.similarity or 0.0,
                         recency=sr.recency,
                         frequency=0.0,
-                        final_weight=sr.weight
+                        final_weight=sr.weight,
                     )
 
             # Step 8: Queue access count updates for visited nodes
             visited_ids = list(set([sr.id for sr in scored_results[:50]]))  # Top 50
             if visited_ids:
-                await self._task_backend.submit_task({
-                    'type': 'access_count_update',
-                    'node_ids': visited_ids
-                })
+                await self._task_backend.submit_task({"type": "access_count_update", "node_ids": visited_ids})
                 log_buffer.append(f" [7] Queued access count updates for {len(visited_ids)} nodes")
 
             # Log fact_type distribution in results
@@ -1425,13 +1443,19 @@ class MemoryEngine:
                 # Convert datetime objects to ISO strings for JSON serialization
                 if result_dict.get("occurred_start"):
                     occurred_start = result_dict["occurred_start"]
-                    result_dict["occurred_start"] =
+                    result_dict["occurred_start"] = (
+                        occurred_start.isoformat() if hasattr(occurred_start, "isoformat") else occurred_start
+                    )
                 if result_dict.get("occurred_end"):
                     occurred_end = result_dict["occurred_end"]
-                    result_dict["occurred_end"] =
+                    result_dict["occurred_end"] = (
+                        occurred_end.isoformat() if hasattr(occurred_end, "isoformat") else occurred_end
+                    )
                 if result_dict.get("mentioned_at"):
                     mentioned_at = result_dict["mentioned_at"]
-                    result_dict["mentioned_at"] =
+                    result_dict["mentioned_at"] = (
+                        mentioned_at.isoformat() if hasattr(mentioned_at, "isoformat") else mentioned_at
+                    )
                 top_results_dicts.append(result_dict)
 
             # Get entities for each fact if include_entities is requested
@@ -1447,16 +1471,15 @@ class MemoryEngine:
                         JOIN entities e ON ue.entity_id = e.id
                         WHERE ue.unit_id = ANY($1::uuid[])
                         """,
-                        unit_ids
+                        unit_ids,
                     )
                     for row in entity_rows:
-                        unit_id = str(row[
+                        unit_id = str(row["unit_id"])
                         if unit_id not in fact_entity_map:
                             fact_entity_map[unit_id] = []
-                        fact_entity_map[unit_id].append(
-
-
-                        })
+                        fact_entity_map[unit_id].append(
+                            {"entity_id": str(row["entity_id"]), "canonical_name": row["canonical_name"]}
+                        )
 
             # Convert results to MemoryFact objects
             memory_facts = []
@@ -1465,20 +1488,22 @@ class MemoryEngine:
                 # Get entity names for this fact
                 entity_names = None
                 if include_entities and result_id in fact_entity_map:
-                    entity_names = [e[
-
-                memory_facts.append(
-
-
-
-
-
-
-
-
-
-
+                    entity_names = [e["canonical_name"] for e in fact_entity_map[result_id]]
+
+                memory_facts.append(
+                    MemoryFact(
+                        id=result_id,
+                        text=result_dict.get("text"),
+                        fact_type=result_dict.get("fact_type", "world"),
+                        entities=entity_names,
+                        context=result_dict.get("context"),
+                        occurred_start=result_dict.get("occurred_start"),
+                        occurred_end=result_dict.get("occurred_end"),
+                        mentioned_at=result_dict.get("mentioned_at"),
+                        document_id=result_dict.get("document_id"),
+                        chunk_id=result_dict.get("chunk_id"),
+                    )
+                )
 
             # Fetch entity observations if requested
             entities_dict = None
@@ -1495,8 +1520,8 @@ class MemoryEngine:
                     unit_id = sr.id
                     if unit_id in fact_entity_map:
                         for entity in fact_entity_map[unit_id]:
-                            entity_id = entity[
-                            entity_name = entity[
+                            entity_id = entity["entity_id"]
+                            entity_name = entity["canonical_name"]
                             if entity_id not in seen_entity_ids:
                                 entities_ordered.append((entity_id, entity_name))
                                 seen_entity_ids.add(entity_id)
@@ -1524,9 +1549,7 @@ class MemoryEngine:
 
                     if included_observations:
                         entities_dict[entity_name] = EntityState(
-                            entity_id=entity_id,
-                            canonical_name=entity_name,
-                            observations=included_observations
+                            entity_id=entity_id, canonical_name=entity_name, observations=included_observations
                         )
                         total_entity_tokens += entity_tokens
 
@@ -1554,11 +1577,11 @@ class MemoryEngine:
                         FROM chunks
                         WHERE chunk_id = ANY($1::text[])
                         """,
-                        chunk_ids_ordered
+                        chunk_ids_ordered,
                     )
 
                     # Create a lookup dict for fast access
-                    chunks_lookup = {row[
+                    chunks_lookup = {row["chunk_id"]: row for row in chunks_rows}
 
                     # Apply token limit and build chunks_dict in the order of chunk_ids_ordered
                     chunks_dict = {}
@@ -1569,7 +1592,7 @@ class MemoryEngine:
                             continue
 
                         row = chunks_lookup[chunk_id]
-                        chunk_text = row[
+                        chunk_text = row["chunk_text"]
                         chunk_tokens = len(encoding.encode(chunk_text))
 
                         # Check if adding this chunk would exceed the limit
@@ -1580,18 +1603,14 @@ class MemoryEngine:
                                 # Truncate to remaining tokens
                                 truncated_text = encoding.decode(encoding.encode(chunk_text)[:remaining_tokens])
                                 chunks_dict[chunk_id] = ChunkInfo(
-                                    chunk_text=truncated_text,
-                                    chunk_index=row['chunk_index'],
-                                    truncated=True
+                                    chunk_text=truncated_text, chunk_index=row["chunk_index"], truncated=True
                                 )
                                 total_chunk_tokens = max_chunk_tokens
                             # Stop adding more chunks once we hit the limit
                             break
                         else:
                             chunks_dict[chunk_id] = ChunkInfo(
-                                chunk_text=chunk_text,
-                                chunk_index=row['chunk_index'],
-                                truncated=False
+                                chunk_text=chunk_text, chunk_index=row["chunk_index"], truncated=False
                             )
                             total_chunk_tokens += chunk_tokens
 
@@ -1605,7 +1624,9 @@ class MemoryEngine:
             total_time = time.time() - recall_start
             num_chunks = len(chunks_dict) if chunks_dict else 0
             num_entities = len(entities_dict) if entities_dict else 0
-            log_buffer.append(
+            log_buffer.append(
+                f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s"
+            )
             logger.info("\n" + "\n".join(log_buffer))
 
             return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)
@@ -1616,10 +1637,8 @@ class MemoryEngine:
             raise Exception(f"Failed to search memories: {str(e)}")
 
     def _filter_by_token_budget(
-        self,
-
-        max_tokens: int
-    ) -> Tuple[List[Dict[str, Any]], int]:
+        self, results: list[dict[str, Any]], max_tokens: int
+    ) -> tuple[list[dict[str, Any]], int]:
         """
         Filter results to fit within token budget.
 
@@ -1652,7 +1671,7 @@ class MemoryEngine:
|
|
|
1652
1671
|
|
|
1653
1672
|
return filtered_results, total_tokens
|
|
1654
1673
|
|
|
1655
|
-
async def get_document(self, document_id: str, bank_id: str) ->
|
|
1674
|
+
async def get_document(self, document_id: str, bank_id: str) -> dict[str, Any] | None:
|
|
1656
1675
|
"""
|
|
1657
1676
|
Retrieve document metadata and statistics.
|
|
1658
1677
|
|
|
@@ -1674,7 +1693,8 @@ class MemoryEngine:
|
|
|
1674
1693
|
WHERE d.id = $1 AND d.bank_id = $2
|
|
1675
1694
|
GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
|
|
1676
1695
|
""",
|
|
1677
|
-
document_id,
|
|
1696
|
+
document_id,
|
|
1697
|
+
bank_id,
|
|
1678
1698
|
)
|
|
1679
1699
|
|
|
1680
1700
|
if not doc:
|
|
@@ -1687,10 +1707,10 @@ class MemoryEngine:
|
|
|
1687
1707
|
"content_hash": doc["content_hash"],
|
|
1688
1708
|
"memory_unit_count": doc["unit_count"],
|
|
1689
1709
|
"created_at": doc["created_at"],
|
|
1690
|
-
"updated_at": doc["updated_at"]
|
|
1710
|
+
"updated_at": doc["updated_at"],
|
|
1691
1711
|
}
|
|
1692
1712
|
|
|
1693
|
-
async def delete_document(self, document_id: str, bank_id: str) ->
|
|
1713
|
+
async def delete_document(self, document_id: str, bank_id: str) -> dict[str, int]:
|
|
1694
1714
|
"""
|
|
1695
1715
|
Delete a document and all its associated memory units and links.
|
|
1696
1716
|
|
|
@@ -1706,22 +1726,17 @@ class MemoryEngine:
|
|
|
1706
1726
|
async with conn.transaction():
|
|
1707
1727
|
# Count units before deletion
|
|
1708
1728
|
units_count = await conn.fetchval(
|
|
1709
|
-
"SELECT COUNT(*) FROM memory_units WHERE document_id = $1",
|
|
1710
|
-
document_id
|
|
1729
|
+
"SELECT COUNT(*) FROM memory_units WHERE document_id = $1", document_id
|
|
1711
1730
|
)
|
|
1712
1731
|
|
|
1713
1732
|
# Delete document (cascades to memory_units and all their links)
|
|
1714
1733
|
deleted = await conn.fetchval(
|
|
1715
|
-
"DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id",
|
|
1716
|
-
document_id, bank_id
|
|
1734
|
+
"DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id", document_id, bank_id
|
|
1717
1735
|
)
|
|
1718
1736
|
|
|
1719
|
-
return {
|
|
1720
|
-
"document_deleted": 1 if deleted else 0,
|
|
1721
|
-
"memory_units_deleted": units_count if deleted else 0
|
|
1722
|
-
}
|
|
1737
|
+
return {"document_deleted": 1 if deleted else 0, "memory_units_deleted": units_count if deleted else 0}
|
|
1723
1738
|
|
|
1724
|
-
async def delete_memory_unit(self, unit_id: str) ->
|
|
1739
|
+
async def delete_memory_unit(self, unit_id: str) -> dict[str, Any]:
|
|
1725
1740
|
"""
|
|
1726
1741
|
Delete a single memory unit and all its associated links.
|
|
1727
1742
|
|
|
@@ -1740,18 +1755,17 @@ class MemoryEngine:
|
|
|
1740
1755
|
async with acquire_with_retry(pool) as conn:
|
|
1741
1756
|
async with conn.transaction():
|
|
1742
1757
|
# Delete the memory unit (cascades to links and associations)
|
|
1743
|
-
deleted = await conn.fetchval(
|
|
1744
|
-
"DELETE FROM memory_units WHERE id = $1 RETURNING id",
|
|
1745
|
-
unit_id
|
|
1746
|
-
)
|
|
1758
|
+
deleted = await conn.fetchval("DELETE FROM memory_units WHERE id = $1 RETURNING id", unit_id)
|
|
1747
1759
|
|
|
1748
1760
|
return {
|
|
1749
1761
|
"success": deleted is not None,
|
|
1750
1762
|
"unit_id": str(deleted) if deleted else None,
|
|
1751
|
-
"message": "Memory unit and all its links deleted successfully"
|
|
1763
|
+
"message": "Memory unit and all its links deleted successfully"
|
|
1764
|
+
if deleted
|
|
1765
|
+
else "Memory unit not found",
|
|
1752
1766
|
}
|
|
1753
1767
|
|
|
1754
|
-
async def delete_bank(self, bank_id: str, fact_type:
|
|
1768
|
+
async def delete_bank(self, bank_id: str, fact_type: str | None = None) -> dict[str, int]:
|
|
1755
1769
|
"""
|
|
1756
1770
|
Delete all data for a specific agent (multi-tenant cleanup).
|
|
1757
1771
|
|
|
@@ -1780,24 +1794,27 @@ class MemoryEngine:
|
|
|
1780
1794
|
# Delete only memories of a specific fact type
|
|
1781
1795
|
units_count = await conn.fetchval(
|
|
1782
1796
|
"SELECT COUNT(*) FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
|
|
1783
|
-
bank_id,
|
|
1797
|
+
bank_id,
|
|
1798
|
+
fact_type,
|
|
1784
1799
|
)
|
|
1785
1800
|
await conn.execute(
|
|
1786
|
-
"DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
|
|
1787
|
-
bank_id, fact_type
|
|
1801
|
+
"DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2", bank_id, fact_type
|
|
1788
1802
|
)
|
|
1789
1803
|
|
|
1790
1804
|
# Note: We don't delete entities when fact_type is specified,
|
|
1791
1805
|
# as they may be referenced by other memory units
|
|
1792
|
-
return {
|
|
1793
|
-
"memory_units_deleted": units_count,
|
|
1794
|
-
"entities_deleted": 0
|
|
1795
|
-
}
|
|
1806
|
+
return {"memory_units_deleted": units_count, "entities_deleted": 0}
|
|
1796
1807
|
else:
|
|
1797
1808
|
# Delete all data for the bank
|
|
1798
|
-
units_count = await conn.fetchval(
|
|
1799
|
-
|
|
1800
|
-
|
|
1809
|
+
units_count = await conn.fetchval(
|
|
1810
|
+
"SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id
|
|
1811
|
+
)
|
|
1812
|
+
entities_count = await conn.fetchval(
|
|
1813
|
+
"SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id
|
|
1814
|
+
)
|
|
1815
|
+
documents_count = await conn.fetchval(
|
|
1816
|
+
"SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id
|
|
1817
|
+
)
|
|
1801
1818
|
|
|
1802
1819
|
# Delete documents (cascades to chunks)
|
|
1803
1820
|
await conn.execute("DELETE FROM documents WHERE bank_id = $1", bank_id)
|
|
@@ -1815,13 +1832,13 @@ class MemoryEngine:
|
|
|
1815
1832
|
"memory_units_deleted": units_count,
|
|
1816
1833
|
"entities_deleted": entities_count,
|
|
1817
1834
|
"documents_deleted": documents_count,
|
|
1818
|
-
"bank_deleted": True
|
|
1835
|
+
"bank_deleted": True,
|
|
1819
1836
|
}
|
|
1820
1837
|
|
|
1821
1838
|
except Exception as e:
|
|
1822
1839
|
raise Exception(f"Failed to delete agent data: {str(e)}")
|
|
1823
1840
|
|
|
1824
|
-
async def get_graph_data(self, bank_id:
|
|
1841
|
+
async def get_graph_data(self, bank_id: str | None = None, fact_type: str | None = None):
|
|
1825
1842
|
"""
|
|
1826
1843
|
Get graph data for visualization.
|
|
1827
1844
|
|
|
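For context, a minimal sketch of how a caller might interpret the dictionaries these delete methods now return. The result values below are hand-written samples that mirror the shapes in the diff, not output captured from the package:

```python
# Hypothetical sample results mirroring the shapes shown in the diff above.
unit_result = {
    "success": True,
    "unit_id": "0b5c9a4e-1111-2222-3333-444455556666",
    "message": "Memory unit and all its links deleted successfully",
}
bank_result = {"memory_units_deleted": 12, "entities_deleted": 0}


def summarize_deletion(unit: dict, bank: dict) -> str:
    """Build a short human-readable summary from the two result dicts."""
    unit_part = unit["message"] if unit["success"] else "No memory unit was deleted"
    bank_part = f'{bank["memory_units_deleted"]} unit(s) removed, {bank["entities_deleted"]} entity(ies) removed'
    return f"{unit_part}; {bank_part}."


print(summarize_deletion(unit_result, bank_result))
```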
@@ -1851,19 +1868,23 @@ class MemoryEngine:

where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""

- units = await conn.fetch(
+ units = await conn.fetch(
+ f"""
SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
FROM memory_units
{where_clause}
ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
LIMIT 1000
- """,
+ """,
+ *query_params,
+ )

# Get links, filtering to only include links between units of the selected agent
# Use DISTINCT ON with LEAST/GREATEST to deduplicate bidirectional links
- unit_ids = [row[
+ unit_ids = [row["id"] for row in units]
if unit_ids:
- links = await conn.fetch(
+ links = await conn.fetch(
+ """
SELECT DISTINCT ON (LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid))
ml.from_unit_id,
ml.to_unit_id,
@@ -1874,7 +1895,9 @@ class MemoryEngine:
LEFT JOIN entities e ON ml.entity_id = e.id
WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
ORDER BY LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid), ml.weight DESC
- """,
+ """,
+ unit_ids,
+ )
else:
links = []

@@ -1889,8 +1912,8 @@ class MemoryEngine:
# Build entity mapping
entity_map = {}
for row in unit_entities:
- unit_id = row[
- entity_name = row[
+ unit_id = row["unit_id"]
+ entity_name = row["canonical_name"]
if unit_id not in entity_map:
entity_map[unit_id] = []
entity_map[unit_id].append(entity_name)
@@ -1898,10 +1921,10 @@ class MemoryEngine:
# Build nodes
nodes = []
for row in units:
- unit_id = row[
- text = row[
- event_date = row[
- context = row[
+ unit_id = row["id"]
+ text = row["text"]
+ event_date = row["event_date"]
+ context = row["context"]

entities = entity_map.get(unit_id, [])
entity_count = len(entities)
@@ -1914,88 +1937,91 @@ class MemoryEngine:
else:
color = "#42a5f5"

- nodes.append(
-
- "
-
-
-
-
-
-
+ nodes.append(
+ {
+ "data": {
+ "id": str(unit_id),
+ "label": f"{text[:30]}..." if len(text) > 30 else text,
+ "text": text,
+ "date": event_date.isoformat() if event_date else "",
+ "context": context if context else "",
+ "entities": ", ".join(entities) if entities else "None",
+ "color": color,
+ }
}
-
+ )

# Build edges
edges = []
for row in links:
- from_id = str(row[
- to_id = str(row[
- link_type = row[
- weight = row[
- entity_name = row[
+ from_id = str(row["from_unit_id"])
+ to_id = str(row["to_unit_id"])
+ link_type = row["link_type"]
+ weight = row["weight"]
+ entity_name = row["entity_name"]

# Color by link type
- if link_type ==
+ if link_type == "temporal":
color = "#00bcd4"
line_style = "dashed"
- elif link_type ==
+ elif link_type == "semantic":
color = "#ff69b4"
line_style = "solid"
- elif link_type ==
+ elif link_type == "entity":
color = "#ffd700"
line_style = "solid"
else:
color = "#999999"
line_style = "solid"

- edges.append(
-
- "
-
-
-
-
-
-
-
+ edges.append(
+ {
+ "data": {
+ "id": f"{from_id}-{to_id}-{link_type}",
+ "source": from_id,
+ "target": to_id,
+ "linkType": link_type,
+ "weight": weight,
+ "entityName": entity_name if entity_name else "",
+ "color": color,
+ "lineStyle": line_style,
+ }
}
-
+ )

# Build table rows
table_rows = []
for row in units:
- unit_id = row[
+ unit_id = row["id"]
entities = entity_map.get(unit_id, [])

- table_rows.append(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- }
+ table_rows.append(
+ {
+ "id": str(unit_id),
+ "text": row["text"],
+ "context": row["context"] if row["context"] else "N/A",
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
+ "date": row["event_date"].strftime("%Y-%m-%d %H:%M")
+ if row["event_date"]
+ else "N/A",  # Deprecated, kept for backwards compatibility
+ "entities": ", ".join(entities) if entities else "None",
+ "document_id": row["document_id"],
+ "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
+ "fact_type": row["fact_type"],
+ }
+ )
+
+ return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": len(units)}

async def list_memory_units(
self,
- bank_id:
- fact_type:
- search_query:
+ bank_id: str | None = None,
+ fact_type: str | None = None,
+ search_query: str | None = None,
limit: int = 100,
- offset: int = 0
+ offset: int = 0,
):
"""
List memory units for table view with optional full-text search.
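For context, a minimal sketch of consuming the `{"nodes", "edges", "table_rows", "total_units"}` payload that `get_graph_data` assembles above, here just counting edges per node. The sample payload is hand-written to match the field names in the diff, not real engine output:

```python
from collections import Counter

# Hand-written sample in the same shape as the diff's nodes/edges lists.
graph = {
    "nodes": [
        {"data": {"id": "a", "label": "Met Alice at the conference", "color": "#42a5f5"}},
        {"data": {"id": "b", "label": "Alice works on embeddings", "color": "#42a5f5"}},
    ],
    "edges": [
        {"data": {"id": "a-b-entity", "source": "a", "target": "b", "linkType": "entity", "weight": 0.8}},
    ],
    "table_rows": [],
    "total_units": 2,
}

# Count how many edges touch each node id.
degree: Counter[str] = Counter()
for edge in graph["edges"]:
    degree[edge["data"]["source"]] += 1
    degree[edge["data"]["target"]] += 1

for node in graph["nodes"]:
    node_id = node["data"]["id"]
    print(f'{node_id}: {node["data"]["label"]!r} (degree {degree[node_id]})')
```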
@@ -2042,7 +2068,7 @@ class MemoryEngine:
{where_clause}
"""
count_result = await conn.fetchrow(count_query, *query_params)
- total = count_result[
+ total = count_result["total"]

# Get units with limit and offset
param_count += 1
@@ -2053,32 +2079,38 @@ class MemoryEngine:
offset_param = f"${param_count}"
query_params.append(offset)

- units = await conn.fetch(
+ units = await conn.fetch(
+ f"""
SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id
FROM memory_units
{where_clause}
ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
LIMIT {limit_param} OFFSET {offset_param}
- """,
+ """,
+ *query_params,
+ )

# Get entity information for these units
if units:
- unit_ids = [row[
- unit_entities = await conn.fetch(
+ unit_ids = [row["id"] for row in units]
+ unit_entities = await conn.fetch(
+ """
SELECT ue.unit_id, e.canonical_name
FROM unit_entities ue
JOIN entities e ON ue.entity_id = e.id
WHERE ue.unit_id = ANY($1::uuid[])
ORDER BY ue.unit_id
- """,
+ """,
+ unit_ids,
+ )
else:
unit_entities = []

# Build entity mapping
entity_map = {}
for row in unit_entities:
- unit_id = row[
- entity_name = row[
+ unit_id = row["unit_id"]
+ entity_name = row["canonical_name"]
if unit_id not in entity_map:
entity_map[unit_id] = []
entity_map[unit_id].append(entity_name)
@@ -2086,36 +2118,27 @@ class MemoryEngine:
# Build result items
items = []
for row in units:
- unit_id = row[
+ unit_id = row["id"]
entities = entity_map.get(unit_id, [])

- items.append(
-
-
-
-
-
-
-
-
-
-
+ items.append(
+ {
+ "id": str(unit_id),
+ "text": row["text"],
+ "context": row["context"] if row["context"] else "",
+ "date": row["event_date"].isoformat() if row["event_date"] else "",
+ "fact_type": row["fact_type"],
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
+ "entities": ", ".join(entities) if entities else "",
+ "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
+ }
+ )

- return {
- "items": items,
- "total": total,
- "limit": limit,
- "offset": offset
- }
+ return {"items": items, "total": total, "limit": limit, "offset": offset}

- async def list_documents(
- self,
- bank_id: str,
- search_query: Optional[str] = None,
- limit: int = 100,
- offset: int = 0
- ):
+ async def list_documents(self, bank_id: str, search_query: str | None = None, limit: int = 100, offset: int = 0):
"""
List documents with optional search and pagination.

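For context, a minimal sketch of paging through a listing endpoint that returns `{"items", "total", "limit", "offset"}`, as `list_memory_units` and `list_documents` do above. The `fetch_page` stub stands in for the real engine call and is purely illustrative:

```python
import asyncio
from typing import Any


async def fetch_page(limit: int, offset: int) -> dict[str, Any]:
    # Stub standing in for a call such as engine.list_memory_units(...); returns fake rows.
    total = 7
    items = [{"id": str(i)} for i in range(offset, min(offset + limit, total))]
    return {"items": items, "total": total, "limit": limit, "offset": offset}


async def fetch_all(page_size: int = 3) -> list[dict[str, Any]]:
    """Collect every item by advancing the offset until total is reached."""
    collected: list[dict[str, Any]] = []
    offset = 0
    while True:
        page = await fetch_page(limit=page_size, offset=offset)
        collected.extend(page["items"])
        offset += page_size
        if offset >= page["total"]:
            return collected


print(len(asyncio.run(fetch_all())))  # 7
```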
@@ -2154,7 +2177,7 @@ class MemoryEngine:
{where_clause}
"""
count_result = await conn.fetchrow(count_query, *query_params)
- total = count_result[
+ total = count_result["total"]

# Get documents with limit and offset (without original_text for performance)
param_count += 1
@@ -2165,7 +2188,8 @@ class MemoryEngine:
offset_param = f"${param_count}"
query_params.append(offset)

- documents = await conn.fetch(
+ documents = await conn.fetch(
+ f"""
SELECT
id,
bank_id,
@@ -2178,11 +2202,13 @@ class MemoryEngine:
{where_clause}
ORDER BY created_at DESC
LIMIT {limit_param} OFFSET {offset_param}
- """,
+ """,
+ *query_params,
+ )

# Get memory unit count for each document
if documents:
- doc_ids = [(row[
+ doc_ids = [(row["id"], row["bank_id"]) for row in documents]

# Create placeholders for the query
placeholders = []
@@ -2195,48 +2221,44 @@ class MemoryEngine:

where_clause_count = " OR ".join(placeholders)

- unit_counts = await conn.fetch(
+ unit_counts = await conn.fetch(
+ f"""
SELECT document_id, bank_id, COUNT(*) as unit_count
FROM memory_units
WHERE {where_clause_count}
GROUP BY document_id, bank_id
- """,
+ """,
+ *params_for_count,
+ )
else:
unit_counts = []

# Build count mapping
- count_map = {(row[
+ count_map = {(row["document_id"], row["bank_id"]): row["unit_count"] for row in unit_counts}

# Build result items
items = []
for row in documents:
- doc_id = row[
- bank_id_val = row[
+ doc_id = row["id"]
+ bank_id_val = row["bank_id"]
unit_count = count_map.get((doc_id, bank_id_val), 0)

- items.append(
-
-
-
-
-
-
-
-
-
+ items.append(
+ {
+ "id": doc_id,
+ "bank_id": bank_id_val,
+ "content_hash": row["content_hash"],
+ "created_at": row["created_at"].isoformat() if row["created_at"] else "",
+ "updated_at": row["updated_at"].isoformat() if row["updated_at"] else "",
+ "text_length": row["text_length"] or 0,
+ "memory_unit_count": unit_count,
+ "retain_params": row["retain_params"] if row["retain_params"] else None,
+ }
+ )

- return {
- "items": items,
- "total": total,
- "limit": limit,
- "offset": offset
- }
+ return {"items": items, "total": total, "limit": limit, "offset": offset}

- async def get_document(
- self,
- document_id: str,
- bank_id: str
- ):
+ async def get_document(self, document_id: str, bank_id: str):
"""
Get a specific document including its original_text.

@@ -2249,7 +2271,8 @@ class MemoryEngine:
"""
pool = await self._get_pool()
async with acquire_with_retry(pool) as conn:
- doc = await conn.fetchrow(
+ doc = await conn.fetchrow(
+ """
SELECT
id,
bank_id,
@@ -2260,33 +2283,37 @@ class MemoryEngine:
retain_params
FROM documents
WHERE id = $1 AND bank_id = $2
- """,
+ """,
+ document_id,
+ bank_id,
+ )

if not doc:
return None

# Get memory unit count
- unit_count_row = await conn.fetchrow(
+ unit_count_row = await conn.fetchrow(
+ """
SELECT COUNT(*) as unit_count
FROM memory_units
WHERE document_id = $1 AND bank_id = $2
- """,
+ """,
+ document_id,
+ bank_id,
+ )

return {
- "id": doc[
- "bank_id": doc[
- "original_text": doc[
- "content_hash": doc[
- "created_at": doc[
- "updated_at": doc[
- "memory_unit_count": unit_count_row[
- "retain_params": doc[
+ "id": doc["id"],
+ "bank_id": doc["bank_id"],
+ "original_text": doc["original_text"],
+ "content_hash": doc["content_hash"],
+ "created_at": doc["created_at"].isoformat() if doc["created_at"] else "",
+ "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else "",
+ "memory_unit_count": unit_count_row["unit_count"] if unit_count_row else 0,
+ "retain_params": doc["retain_params"] if doc["retain_params"] else None,
}

- async def get_chunk(
- self,
- chunk_id: str
- ):
+ async def get_chunk(self, chunk_id: str):
"""
Get a specific chunk by its ID.

@@ -2298,7 +2325,8 @@ class MemoryEngine:
"""
pool = await self._get_pool()
async with acquire_with_retry(pool) as conn:
- chunk = await conn.fetchrow(
+ chunk = await conn.fetchrow(
+ """
SELECT
chunk_id,
document_id,
@@ -2308,18 +2336,20 @@ class MemoryEngine:
created_at
FROM chunks
WHERE chunk_id = $1
- """,
+ """,
+ chunk_id,
+ )

if not chunk:
return None

return {
- "chunk_id": chunk[
- "document_id": chunk[
- "bank_id": chunk[
- "chunk_index": chunk[
- "chunk_text": chunk[
- "created_at": chunk[
+ "chunk_id": chunk["chunk_id"],
+ "document_id": chunk["document_id"],
+ "bank_id": chunk["bank_id"],
+ "chunk_index": chunk["chunk_index"],
+ "chunk_text": chunk["chunk_text"],
+ "created_at": chunk["created_at"].isoformat() if chunk["created_at"] else "",
}

async def _evaluate_opinion_update_async(
@@ -2328,7 +2358,7 @@ class MemoryEngine:
opinion_confidence: float,
new_event_text: str,
entity_name: str,
- ) ->
+ ) -> dict[str, Any] | None:
"""
Evaluate if an opinion should be updated based on a new event.

@@ -2342,16 +2372,18 @@ class MemoryEngine:
Dict with 'action' ('keep'|'update'), 'new_confidence', 'new_text' (if action=='update')
or None if no changes needed
"""
- from pydantic import BaseModel, Field

class OpinionEvaluation(BaseModel):
"""Evaluation of whether an opinion should be updated."""
+
action: str = Field(description="Action to take: 'keep' (no change) or 'update' (modify opinion)")
reasoning: str = Field(description="Brief explanation of why this action was chosen")
- new_confidence: float = Field(
-
+ new_confidence: float = Field(
+ description="New confidence score (0.0-1.0). Can be higher, lower, or same as before."
+ )
+ new_opinion_text: str | None = Field(
default=None,
- description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None."
+ description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None.",
)

evaluation_prompt = f"""You are evaluating whether an existing opinion should be updated based on new information.
@@ -2381,70 +2413,63 @@ Guidelines:
result = await self._llm_config.call(
messages=[
{"role": "system", "content": "You evaluate and update opinions based on new information."},
- {"role": "user", "content": evaluation_prompt}
+ {"role": "user", "content": evaluation_prompt},
],
response_format=OpinionEvaluation,
scope="memory_evaluate_opinion",
- temperature=0.3 # Lower temperature for more consistent evaluation
+ temperature=0.3,  # Lower temperature for more consistent evaluation
)

# Only return updates if something actually changed
- if result.action ==
+ if result.action == "keep" and abs(result.new_confidence - opinion_confidence) < 0.01:
return None

return {
-
-
-
-
+ "action": result.action,
+ "reasoning": result.reasoning,
+ "new_confidence": result.new_confidence,
+ "new_text": result.new_opinion_text if result.action == "update" else None,
}

except Exception as e:
logger.warning(f"Failed to evaluate opinion update: {str(e)}")
return None

- async def _handle_form_opinion(self, task_dict:
+ async def _handle_form_opinion(self, task_dict: dict[str, Any]):
"""
Handler for form opinion tasks.

Args:
task_dict: Dict with keys: 'bank_id', 'answer_text', 'query'
"""
- bank_id = task_dict[
- answer_text = task_dict[
- query = task_dict[
+ bank_id = task_dict["bank_id"]
+ answer_text = task_dict["answer_text"]
+ query = task_dict["query"]

- await self._extract_and_store_opinions_async(
- bank_id=bank_id,
- answer_text=answer_text,
- query=query
- )
+ await self._extract_and_store_opinions_async(bank_id=bank_id, answer_text=answer_text, query=query)

- async def _handle_reinforce_opinion(self, task_dict:
+ async def _handle_reinforce_opinion(self, task_dict: dict[str, Any]):
"""
Handler for reinforce opinion tasks.

Args:
task_dict: Dict with keys: 'bank_id', 'created_unit_ids', 'unit_texts', 'unit_entities'
"""
- bank_id = task_dict[
- created_unit_ids = task_dict[
- unit_texts = task_dict[
- unit_entities = task_dict[
+ bank_id = task_dict["bank_id"]
+ created_unit_ids = task_dict["created_unit_ids"]
+ unit_texts = task_dict["unit_texts"]
+ unit_entities = task_dict["unit_entities"]

await self._reinforce_opinions_async(
- bank_id=bank_id,
- created_unit_ids=created_unit_ids,
- unit_texts=unit_texts,
- unit_entities=unit_entities
+ bank_id=bank_id, created_unit_ids=created_unit_ids, unit_texts=unit_texts, unit_entities=unit_entities
)

async def _reinforce_opinions_async(
self,
bank_id: str,
- created_unit_ids:
- unit_texts:
- unit_entities:
+ created_unit_ids: list[str],
+ unit_texts: list[str],
+ unit_entities: list[list[dict[str, str]]],
):
"""
Background task to reinforce opinions based on newly ingested events.
@@ -2463,15 +2488,14 @@ Guidelines:
for entities_list in unit_entities:
for entity in entities_list:
# Handle both Entity objects and dicts
- if hasattr(entity,
+ if hasattr(entity, "text"):
entity_names.add(entity.text)
elif isinstance(entity, dict):
- entity_names.add(entity[
+ entity_names.add(entity["text"])

if not entity_names:
return

-
pool = await self._get_pool()
async with acquire_with_retry(pool) as conn:
# Find all opinions related to these entities
@@ -2486,13 +2510,12 @@ Guidelines:
AND e.canonical_name = ANY($2::text[])
""",
bank_id,
- list(entity_names)
+ list(entity_names),
)

if not opinions:
return

-
# Use cached LLM config
if self._llm_config is None:
logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
@@ -2501,15 +2524,15 @@ Guidelines:
# Evaluate each opinion against the new events
updates_to_apply = []
for opinion in opinions:
- opinion_id = str(opinion[
- opinion_text = opinion[
- opinion_confidence = opinion[
- entity_name = opinion[
+ opinion_id = str(opinion["id"])
+ opinion_text = opinion["text"]
+ opinion_confidence = opinion["confidence_score"]
+ entity_name = opinion["canonical_name"]

# Find all new events mentioning this entity
relevant_events = []
for unit_text, entities_list in zip(unit_texts, unit_entities):
- if any(e[
+ if any(e["text"] == entity_name for e in entities_list):
relevant_events.append(unit_text)

if not relevant_events:
@@ -2520,26 +2543,20 @@ Guidelines:

# Evaluate if opinion should be updated
evaluation = await self._evaluate_opinion_update_async(
- opinion_text,
- opinion_confidence,
- combined_events,
- entity_name
+ opinion_text, opinion_confidence, combined_events, entity_name
)

if evaluation:
- updates_to_apply.append({
- 'opinion_id': opinion_id,
- 'evaluation': evaluation
- })
+ updates_to_apply.append({"opinion_id": opinion_id, "evaluation": evaluation})

# Apply all updates in a single transaction
if updates_to_apply:
async with conn.transaction():
for update in updates_to_apply:
- opinion_id = update[
- evaluation = update[
+ opinion_id = update["opinion_id"]
+ evaluation = update["evaluation"]

- if evaluation[
+ if evaluation["action"] == "update" and evaluation["new_text"]:
# Update both text and confidence
await conn.execute(
"""
@@ -2547,9 +2564,9 @@ Guidelines:
SET text = $1, confidence_score = $2, updated_at = NOW()
WHERE id = $3
""",
- evaluation[
- evaluation[
- uuid.UUID(opinion_id)
+ evaluation["new_text"],
+ evaluation["new_confidence"],
+ uuid.UUID(opinion_id),
)
else:
# Only update confidence
@@ -2559,8 +2576,8 @@ Guidelines:
SET confidence_score = $1, updated_at = NOW()
WHERE id = $2
""",
- evaluation[
- uuid.UUID(opinion_id)
+ evaluation["new_confidence"],
+ uuid.UUID(opinion_id),
)

else:
@@ -2569,6 +2586,7 @@ Guidelines:
except Exception as e:
logger.error(f"[REINFORCE] Error during opinion reinforcement: {str(e)}")
import traceback
+
traceback.print_exc()

# ==================== bank profile Methods ====================
@@ -2587,11 +2605,7 @@ Guidelines:
pool = await self._get_pool()
return await bank_utils.get_bank_profile(pool, bank_id)

- async def update_bank_disposition(
- self,
- bank_id: str,
- disposition: Dict[str, int]
- ) -> None:
+ async def update_bank_disposition(self, bank_id: str, disposition: dict[str, int]) -> None:
"""
Update bank disposition traits.

@@ -2602,12 +2616,7 @@ Guidelines:
pool = await self._get_pool()
await bank_utils.update_bank_disposition(pool, bank_id, disposition)

- async def merge_bank_background(
- self,
- bank_id: str,
- new_info: str,
- update_disposition: bool = True
- ) -> dict:
+ async def merge_bank_background(self, bank_id: str, new_info: str, update_disposition: bool = True) -> dict:
"""
Merge new background information with existing background using LLM.
Normalizes to first person ("I") and resolves conflicts.
@@ -2622,9 +2631,7 @@ Guidelines:
Dict with 'background' (str) and optionally 'disposition' (dict) keys
"""
pool = await self._get_pool()
- return await bank_utils.merge_bank_background(
- pool, self._llm_config, bank_id, new_info, update_disposition
- )
+ return await bank_utils.merge_bank_background(pool, self._llm_config, bank_id, new_info, update_disposition)

async def list_banks(self) -> list:
"""
@@ -2685,19 +2692,21 @@ Guidelines:
budget=budget,
max_tokens=4096,
enable_trace=False,
- fact_type=[
- include_entities=True
+ fact_type=["experience", "world", "opinion"],
+ include_entities=True,
)
recall_time = time.time() - recall_start

all_results = search_result.results

# Split results by fact type for structured response
- agent_results = [r for r in all_results if r.fact_type ==
- world_results = [r for r in all_results if r.fact_type ==
- opinion_results = [r for r in all_results if r.fact_type ==
+ agent_results = [r for r in all_results if r.fact_type == "experience"]
+ world_results = [r for r in all_results if r.fact_type == "world"]
+ opinion_results = [r for r in all_results if r.fact_type == "opinion"]

- log_buffer.append(
+ log_buffer.append(
+ f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s"
+ )

# Format facts for LLM
agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
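For context, the split into experience/world/opinion lists above can also be expressed as a single pass with a `defaultdict`; a minimal sketch over hand-written records (the real recall results are objects with a `fact_type` attribute, so plain dicts are used here only for illustration):

```python
from collections import defaultdict

# Hand-written stand-ins for recall results; the engine returns objects, not dicts.
all_results = [
    {"fact_type": "experience", "text": "I demoed the search API"},
    {"fact_type": "world", "text": "The memory units live in Postgres"},
    {"fact_type": "opinion", "text": "Graph links seem to improve recall"},
    {"fact_type": "world", "text": "Documents cascade-delete their chunks"},
]

by_type: dict[str, list[dict]] = defaultdict(list)
for result in all_results:
    by_type[result["fact_type"]].append(result)

for fact_type in ("experience", "world", "opinion"):
    print(f"{fact_type}: {len(by_type[fact_type])} fact(s)")
```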
@@ -2728,47 +2737,34 @@ Guidelines:

llm_start = time.time()
answer_text = await self._llm_config.call(
- messages=[
- {"role": "system", "content": system_message},
- {"role": "user", "content": prompt}
- ],
+ messages=[{"role": "system", "content": system_message}, {"role": "user", "content": prompt}],
scope="memory_think",
temperature=0.9,
- max_completion_tokens=1000
+ max_completion_tokens=1000,
)
llm_time = time.time() - llm_start

answer_text = answer_text.strip()

# Submit form_opinion task for background processing
- await self._task_backend.submit_task(
-
-
- 'answer_text': answer_text,
- 'query': query
- })
+ await self._task_backend.submit_task(
+ {"type": "form_opinion", "bank_id": bank_id, "answer_text": answer_text, "query": query}
+ )

total_time = time.time() - reflect_start
- log_buffer.append(
+ log_buffer.append(
+ f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s"
+ )
logger.info("\n" + "\n".join(log_buffer))

# Return response with facts split by type
return ReflectResult(
text=answer_text,
- based_on={
-
- "experience": agent_results,
- "opinion": opinion_results
- },
- new_opinions=[] # Opinions are being extracted asynchronously
+ based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
+ new_opinions=[],  # Opinions are being extracted asynchronously
)

- async def _extract_and_store_opinions_async(
- self,
- bank_id: str,
- answer_text: str,
- query: str
- ):
+ async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str, query: str):
"""
Background task to extract and store opinions from think response.

@@ -2781,33 +2777,27 @@ Guidelines:
"""
try:
# Extract opinions from the answer
- new_opinions = await think_utils.extract_opinions_from_text(
- self._llm_config, text=answer_text, query=query
- )
+ new_opinions = await think_utils.extract_opinions_from_text(self._llm_config, text=answer_text, query=query)

# Store new opinions
if new_opinions:
- from datetime import datetime
-
+ from datetime import datetime
+
+ current_time = datetime.now(UTC)
for opinion in new_opinions:
await self.retain_async(
bank_id=bank_id,
content=opinion.opinion,
context=f"formed during thinking about: {query}",
event_date=current_time,
- fact_type_override=
- confidence_score=opinion.confidence
+ fact_type_override="opinion",
+ confidence_score=opinion.confidence,
)

except Exception as e:
logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")

- async def get_entity_observations(
- self,
- bank_id: str,
- entity_id: str,
- limit: int = 10
- ) -> List[EntityObservation]:
+ async def get_entity_observations(self, bank_id: str, entity_id: str, limit: int = 10) -> list[EntityObservation]:
"""
Get observations linked to an entity.

@@ -2832,23 +2822,18 @@ Guidelines:
ORDER BY mu.mentioned_at DESC
LIMIT $3
""",
- bank_id,
+ bank_id,
+ uuid.UUID(entity_id),
+ limit,
)

observations = []
for row in rows:
- mentioned_at = row[
- observations.append(EntityObservation(
- text=row['text'],
- mentioned_at=mentioned_at
- ))
+ mentioned_at = row["mentioned_at"].isoformat() if row["mentioned_at"] else None
+ observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
return observations

- async def list_entities(
- self,
- bank_id: str,
- limit: int = 100
- ) -> List[Dict[str, Any]]:
+ async def list_entities(self, bank_id: str, limit: int = 100) -> list[dict[str, Any]]:
"""
List all entities for a bank.

@@ -2869,39 +2854,37 @@ Guidelines:
ORDER BY mention_count DESC, last_seen DESC
LIMIT $2
""",
- bank_id,
+ bank_id,
+ limit,
)

entities = []
for row in rows:
# Handle metadata - may be dict, JSON string, or None
- metadata = row[
+ metadata = row["metadata"]
if metadata is None:
metadata = {}
elif isinstance(metadata, str):
import json
+
try:
metadata = json.loads(metadata)
except json.JSONDecodeError:
metadata = {}

- entities.append(
-
-
-
-
-
-
-
+ entities.append(
+ {
+ "id": str(row["id"]),
+ "canonical_name": row["canonical_name"],
+ "mention_count": row["mention_count"],
+ "first_seen": row["first_seen"].isoformat() if row["first_seen"] else None,
+ "last_seen": row["last_seen"].isoformat() if row["last_seen"] else None,
+ "metadata": metadata,
+ }
+ )
return entities

- async def get_entity_state(
- self,
- bank_id: str,
- entity_id: str,
- entity_name: str,
- limit: int = 10
- ) -> EntityState:
+ async def get_entity_state(self, bank_id: str, entity_id: str, entity_name: str, limit: int = 10) -> EntityState:
"""
Get the current state (mental model) of an entity.

@@ -2915,20 +2898,11 @@ Guidelines:
EntityState with observations
"""
observations = await self.get_entity_observations(bank_id, entity_id, limit)
- return EntityState(
- entity_id=entity_id,
- canonical_name=entity_name,
- observations=observations
- )
+ return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=observations)

async def regenerate_entity_observations(
- self,
-
- entity_id: str,
- entity_name: str,
- version: str | None = None,
- conn=None
- ) -> List[str]:
+ self, bank_id: str, entity_id: str, entity_name: str, version: str | None = None, conn=None
+ ) -> list[str]:
"""
Regenerate observations for an entity by:
1. Checking version for deduplication (if provided)
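For context, one plausible reading of the version check described here: the caller passes the `last_seen` timestamp it observed as `version`, and regeneration only proceeds while that still matches the entity's stored `last_seen`, mirroring the `isoformat()` comparison in the hunks below. A minimal sketch with hand-written timestamps, not code from the package:

```python
from datetime import datetime, timezone


def should_regenerate(current_last_seen: datetime | None, version: str | None) -> bool:
    """Skip regeneration when the stored last_seen no longer matches the requested version."""
    if version is None or current_last_seen is None:
        return True  # no version to compare against, proceed
    return current_last_seen.isoformat() == version


last_seen = datetime(2024, 5, 1, 12, 0, tzinfo=timezone.utc)
print(should_regenerate(last_seen, last_seen.isoformat()))        # True: versions match
print(should_regenerate(last_seen, "2024-04-30T09:00:00+00:00"))  # False: entity changed since
```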
@@ -2973,7 +2947,8 @@ Guidelines:
FROM entities
WHERE id = $1 AND bank_id = $2
""",
- entity_uuid,
+ entity_uuid,
+ bank_id,
)

if current_last_seen and current_last_seen.isoformat() != version:
@@ -2991,7 +2966,8 @@ Guidelines:
ORDER BY mu.occurred_start DESC
LIMIT 50
""",
- bank_id,
+ bank_id,
+ entity_uuid,
)

if not rows:
@@ -3000,21 +2976,19 @@ Guidelines:
# Convert to MemoryFact objects for the observation extraction
facts = []
for row in rows:
- occurred_start = row[
- facts.append(
-
-
-
-
-
-
+ occurred_start = row["occurred_start"].isoformat() if row["occurred_start"] else None
+ facts.append(
+ MemoryFact(
+ id=str(row["id"]),
+ text=row["text"],
+ fact_type=row["fact_type"],
+ context=row["context"],
+ occurred_start=occurred_start,
+ )
+ )

# Step 3: Extract observations using LLM (no personality)
- observations = await observation_utils.extract_observations_from_facts(
- self._llm_config,
- entity_name,
- facts
- )
+ observations = await observation_utils.extract_observations_from_facts(self._llm_config, entity_name, facts)

if not observations:
return []
@@ -3036,13 +3010,12 @@ Guidelines:
AND ue.entity_id = $2
)
""",
- bank_id,
+ bank_id,
+ entity_uuid,
)

# Generate embeddings for new observations
- embeddings = await embedding_utils.generate_embeddings_batch(
- self.embeddings, observations
- )
+ embeddings = await embedding_utils.generate_embeddings_batch(self.embeddings, observations)

# Insert new observations
current_time = utcnow()
@@ -3066,9 +3039,9 @@ Guidelines:
current_time,
current_time,
current_time,
- current_time
+ current_time,
)
- obs_id = str(result[
+ obs_id = str(result["id"])
created_ids.append(obs_id)

# Link observation to entity
@@ -3077,7 +3050,8 @@ Guidelines:
INSERT INTO unit_entities (unit_id, entity_id)
VALUES ($1, $2)
""",
- uuid.UUID(obs_id),
+ uuid.UUID(obs_id),
+ entity_uuid,
)

return created_ids
@@ -3092,11 +3066,7 @@ Guidelines:
return await do_db_operations(acquired_conn)

async def _regenerate_observations_sync(
- self,
- bank_id: str,
- entity_ids: List[str],
- min_facts: int = 5,
- conn=None
+ self, bank_id: str, entity_ids: list[str], min_facts: int = 5, conn=None
) -> None:
"""
Regenerate observations for entities synchronously (called during retain).
@@ -3123,9 +3093,10 @@ Guidelines:
SELECT id, canonical_name FROM entities
WHERE id = ANY($1) AND bank_id = $2
""",
- entity_uuids,
+ entity_uuids,
+ bank_id,
)
- entity_names = {row[
+ entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

fact_counts = await conn.fetch(
"""
@@ -3135,9 +3106,10 @@ Guidelines:
WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
GROUP BY ue.entity_id
""",
- entity_uuids,
+ entity_uuids,
+ bank_id,
)
- entity_fact_counts = {row[
+ entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}
else:
# Acquire a new connection (standalone call)
pool = await self._get_pool()
@@ -3147,9 +3119,10 @@ Guidelines:
SELECT id, canonical_name FROM entities
WHERE id = ANY($1) AND bank_id = $2
""",
- entity_uuids,
+ entity_uuids,
+ bank_id,
)
- entity_names = {row[
+ entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

fact_counts = await acquired_conn.fetch(
"""
@@ -3159,9 +3132,10 @@ Guidelines:
WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
GROUP BY ue.entity_id
""",
- entity_uuids,
+ entity_uuids,
+ bank_id,
)
- entity_fact_counts = {row[
+ entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}

# Filter entities that meet the threshold
entities_to_process = []
@@ -3183,11 +3157,9 @@ Guidelines:
except Exception as e:
logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")

- await asyncio.gather(*[
- process_entity(eid, name) for eid, name in entities_to_process
- ])
+ await asyncio.gather(*[process_entity(eid, name) for eid, name in entities_to_process])

- async def _handle_regenerate_observations(self, task_dict:
+ async def _handle_regenerate_observations(self, task_dict: dict[str, Any]):
"""
Handler for regenerate_observations tasks.

@@ -3197,12 +3169,12 @@ Guidelines:
- 'entity_id', 'entity_name': Process single entity (legacy)
"""
try:
- bank_id = task_dict.get(
+ bank_id = task_dict.get("bank_id")

# New format: multiple entity_ids
- if
- entity_ids = task_dict.get(
- min_facts = task_dict.get(
+ if "entity_ids" in task_dict:
+ entity_ids = task_dict.get("entity_ids", [])
+ min_facts = task_dict.get("min_facts", 5)

if not bank_id or not entity_ids:
logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
@@ -3215,31 +3187,37 @@ Guidelines:
try:
# Fetch entity name and check fact count
import uuid as uuid_module
+
entity_uuid = uuid_module.UUID(entity_id) if isinstance(entity_id, str) else entity_id

# First check if entity exists
entity_exists = await conn.fetchrow(
"SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
- entity_uuid,
+ entity_uuid,
+ bank_id,
)

if not entity_exists:
logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
continue

- entity_name = entity_exists[
+ entity_name = entity_exists["canonical_name"]

# Count facts linked to this entity
- fact_count =
-
-
-
+ fact_count = (
+ await conn.fetchval(
+ "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1", entity_uuid
+ )
+ or 0
+ )

# Only regenerate if entity has enough facts
if fact_count >= min_facts:
await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
else:
- logger.debug(
+ logger.debug(
+ f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
+ )

except Exception as e:
logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
@@ -3247,9 +3225,9 @@ Guidelines:

# Legacy format: single entity
else:
- entity_id = task_dict.get(
- entity_name = task_dict.get(
- version = task_dict.get(
+ entity_id = task_dict.get("entity_id")
+ entity_name = task_dict.get("entity_name")
+ version = task_dict.get("version")

if not all([bank_id, entity_id, entity_name]):
logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
@@ -3260,5 +3238,5 @@ Guidelines:
except Exception as e:
logger.error(f"[OBSERVATIONS] Error regenerating observations: {e}")
import traceback
- traceback.print_exc()

+ traceback.print_exc()