hindsight-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -9
- hindsight_api/alembic/env.py +5 -8
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
- hindsight_api/api/__init__.py +10 -10
- hindsight_api/api/http.py +575 -593
- hindsight_api/api/mcp.py +31 -33
- hindsight_api/banner.py +13 -6
- hindsight_api/config.py +17 -12
- hindsight_api/engine/__init__.py +9 -9
- hindsight_api/engine/cross_encoder.py +23 -27
- hindsight_api/engine/db_utils.py +5 -4
- hindsight_api/engine/embeddings.py +22 -21
- hindsight_api/engine/entity_resolver.py +81 -75
- hindsight_api/engine/llm_wrapper.py +74 -88
- hindsight_api/engine/memory_engine.py +663 -673
- hindsight_api/engine/query_analyzer.py +100 -97
- hindsight_api/engine/response_models.py +105 -106
- hindsight_api/engine/retain/__init__.py +9 -16
- hindsight_api/engine/retain/bank_utils.py +34 -58
- hindsight_api/engine/retain/chunk_storage.py +4 -12
- hindsight_api/engine/retain/deduplication.py +9 -28
- hindsight_api/engine/retain/embedding_processing.py +4 -11
- hindsight_api/engine/retain/embedding_utils.py +3 -4
- hindsight_api/engine/retain/entity_processing.py +7 -17
- hindsight_api/engine/retain/fact_extraction.py +155 -165
- hindsight_api/engine/retain/fact_storage.py +11 -23
- hindsight_api/engine/retain/link_creation.py +11 -39
- hindsight_api/engine/retain/link_utils.py +166 -95
- hindsight_api/engine/retain/observation_regeneration.py +39 -52
- hindsight_api/engine/retain/orchestrator.py +72 -62
- hindsight_api/engine/retain/types.py +49 -43
- hindsight_api/engine/search/__init__.py +15 -1
- hindsight_api/engine/search/fusion.py +6 -15
- hindsight_api/engine/search/graph_retrieval.py +234 -0
- hindsight_api/engine/search/mpfp_retrieval.py +438 -0
- hindsight_api/engine/search/observation_utils.py +9 -16
- hindsight_api/engine/search/reranking.py +4 -7
- hindsight_api/engine/search/retrieval.py +388 -193
- hindsight_api/engine/search/scoring.py +5 -7
- hindsight_api/engine/search/temporal_extraction.py +8 -11
- hindsight_api/engine/search/think_utils.py +115 -39
- hindsight_api/engine/search/trace.py +68 -38
- hindsight_api/engine/search/tracer.py +49 -35
- hindsight_api/engine/search/types.py +22 -16
- hindsight_api/engine/task_backend.py +21 -26
- hindsight_api/engine/utils.py +25 -10
- hindsight_api/main.py +21 -40
- hindsight_api/mcp_local.py +190 -0
- hindsight_api/metrics.py +44 -30
- hindsight_api/migrations.py +10 -8
- hindsight_api/models.py +60 -72
- hindsight_api/pg0.py +64 -337
- hindsight_api/server.py +3 -6
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +6 -5
- hindsight_api-0.1.6.dist-info/RECORD +64 -0
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.1.4.dist-info/RECORD +0 -61
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
hindsight_api/engine/memory_engine.py (expanded diff)

@@ -8,22 +8,23 @@ This implements a sophisticated memory architecture that combines:
 4. Spreading activation: Search through the graph with activation decay
 5. Dynamic weighting: Recency and frequency-based importance
 """
-
-import os
-from datetime import datetime, timedelta, timezone
-from typing import Any, Dict, List, Optional, Tuple, Union, TypedDict, TYPE_CHECKING
-import asyncpg
+
 import asyncio
-
-from .cross_encoder import CrossEncoderModel, create_cross_encoder_from_env
+import logging
 import time
-import numpy as np
 import uuid
-import
+from datetime import UTC, datetime, timedelta
+from typing import TYPE_CHECKING, Any, TypedDict
+
+import asyncpg
+import numpy as np
 from pydantic import BaseModel, Field
 
+from .cross_encoder import CrossEncoderModel
+from .embeddings import Embeddings, create_embeddings_from_env
+
 if TYPE_CHECKING:
-
+    pass
 
 
 class RetainContentDict(TypedDict, total=False):
@@ -36,30 +37,31 @@ class RetainContentDict(TypedDict, total=False):
         metadata: Custom key-value metadata (optional)
         document_id: Document ID for this content item (optional)
     """
+
     content: str # Required
     context: str
     event_date: datetime
-    metadata:
+    metadata: dict[str, str]
     document_id: str
 
-
-from
-
-
-)
+
+from enum import Enum
+
+from ..pg0 import EmbeddedPostgres
 from .entity_resolver import EntityResolver
-from .retain import embedding_utils, bank_utils
-from .search import think_utils, observation_utils
 from .llm_wrapper import LLMConfig
-from .
-from .
+from .query_analyzer import QueryAnalyzer
+from .response_models import VALID_RECALL_FACT_TYPES, EntityObservation, EntityState, MemoryFact, ReflectResult
+from .response_models import RecallResult as RecallResultModel
+from .retain import bank_utils, embedding_utils
+from .search import observation_utils, think_utils
 from .search.reranking import CrossEncoderReranker
-from
-from enum import Enum
+from .task_backend import AsyncIOQueueBackend, TaskBackend
 
 
 class Budget(str, Enum):
     """Budget levels for recall/reflect operations."""
+
     LOW = "low"
     MID = "mid"
     HIGH = "high"
@@ -67,20 +69,20 @@ class Budget(str, Enum):
 
 def utcnow():
     """Get current UTC time with timezone info."""
-    return datetime.now(
+    return datetime.now(UTC)
 
 
 # Logger for memory system
 logger = logging.getLogger(__name__)
 
-from .db_utils import acquire_with_retry, retry_with_backoff
-
 import tiktoken
-
+
+from .db_utils import acquire_with_retry
 
 # Cache tiktoken encoding for token budget filtering (module-level singleton)
 _TIKTOKEN_ENCODING = None
 
+
 def _get_tiktoken_encoding():
     """Get cached tiktoken encoding (cl100k_base for GPT-4/3.5)."""
     global _TIKTOKEN_ENCODING
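The hunk above only shows the first lines of the module-level tiktoken cache, so the rest of the helper is an assumption. A minimal sketch of the lazily initialized singleton pattern it describes, using only the real `tiktoken.get_encoding("cl100k_base")` call:

```python
import tiktoken

_TIKTOKEN_ENCODING = None  # module-level cache so the encoding is built once per process


def get_cached_encoding():
    """Return the cached cl100k_base encoding, creating it on first use."""
    global _TIKTOKEN_ENCODING
    if _TIKTOKEN_ENCODING is None:
        _TIKTOKEN_ENCODING = tiktoken.get_encoding("cl100k_base")
    return _TIKTOKEN_ENCODING
```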
@@ -102,17 +104,17 @@ class MemoryEngine:
 
     def __init__(
         self,
-        db_url:
-        memory_llm_provider:
-        memory_llm_api_key:
-        memory_llm_model:
-        memory_llm_base_url:
-        embeddings:
-        cross_encoder:
-        query_analyzer:
+        db_url: str | None = None,
+        memory_llm_provider: str | None = None,
+        memory_llm_api_key: str | None = None,
+        memory_llm_model: str | None = None,
+        memory_llm_base_url: str | None = None,
+        embeddings: Embeddings | None = None,
+        cross_encoder: CrossEncoderModel | None = None,
+        query_analyzer: QueryAnalyzer | None = None,
         pool_min_size: int = 5,
         pool_max_size: int = 100,
-        task_backend:
+        task_backend: TaskBackend | None = None,
        run_migrations: bool = True,
     ):
         """
@@ -138,6 +140,7 @@ class MemoryEngine:
         """
         # Load config from environment for any missing parameters
         from ..config import get_config
+
         config = get_config()
 
         # Apply defaults from config
@@ -147,8 +150,8 @@ class MemoryEngine:
         memory_llm_model = memory_llm_model or config.llm_model
         memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
         # Track pg0 instance (if used)
-        self._pg0:
-        self._pg0_instance_name:
+        self._pg0: EmbeddedPostgres | None = None
+        self._pg0_instance_name: str | None = None
 
         # Initialize PostgreSQL connection URL
         # The actual URL will be set during initialize() after starting the server
@@ -175,7 +178,6 @@ class MemoryEngine:
         self._pg0_port = None
         self.db_url = db_url
 
-
         # Set default base URL if not provided
         if memory_llm_base_url is None:
             if memory_llm_provider.lower() == "groq":
@@ -206,6 +208,7 @@ class MemoryEngine:
             self.query_analyzer = query_analyzer
         else:
             from .query_analyzer import DateparserQueryAnalyzer
+
             self.query_analyzer = DateparserQueryAnalyzer()
 
         # Initialize LLM configuration
@@ -224,10 +227,7 @@ class MemoryEngine:
         self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)
 
         # Initialize task backend
-        self._task_backend = task_backend or AsyncIOQueueBackend(
-            batch_size=100,
-            batch_interval=1.0
-        )
+        self._task_backend = task_backend or AsyncIOQueueBackend(batch_size=100, batch_interval=1.0)
 
         # Backpressure mechanism: limit concurrent searches to prevent overwhelming the database
         # Limit concurrent searches to prevent connection pool exhaustion
@@ -243,14 +243,14 @@ class MemoryEngine:
         # initialize encoding eagerly to avoid delaying the first time
         _get_tiktoken_encoding()
 
-    async def _handle_access_count_update(self, task_dict:
+    async def _handle_access_count_update(self, task_dict: dict[str, Any]):
         """
         Handler for access count update tasks.
 
         Args:
             task_dict: Dict with 'node_ids' key containing list of node IDs to update
         """
-        node_ids = task_dict.get(
+        node_ids = task_dict.get("node_ids", [])
         if not node_ids:
             return
 
@@ -260,13 +260,12 @@ class MemoryEngine:
             uuid_list = [uuid.UUID(nid) for nid in node_ids]
             async with acquire_with_retry(pool) as conn:
                 await conn.execute(
-                    "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
-                    uuid_list
+                    "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])", uuid_list
                 )
         except Exception as e:
             logger.error(f"Access count handler: Error updating access counts: {e}")
 
-    async def _handle_batch_retain(self, task_dict:
+    async def _handle_batch_retain(self, task_dict: dict[str, Any]):
         """
         Handler for batch retain tasks.
 
@@ -274,23 +273,23 @@ class MemoryEngine:
             task_dict: Dict with 'bank_id', 'contents'
         """
         try:
-            bank_id = task_dict.get(
-            contents = task_dict.get(
-
-            logger.info(f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items")
+            bank_id = task_dict.get("bank_id")
+            contents = task_dict.get("contents", [])
 
-
-                bank_id=bank_id,
-                contents=contents
+            logger.info(
+                f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
             )
 
+            await self.retain_batch_async(bank_id=bank_id, contents=contents)
+
             logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
         except Exception as e:
             logger.error(f"Batch retain handler: Error processing batch retain: {e}")
             import traceback
+
             traceback.print_exc()
 
-    async def execute_task(self, task_dict:
+    async def execute_task(self, task_dict: dict[str, Any]):
         """
         Execute a task by routing it to the appropriate handler.
 
@@ -301,9 +300,9 @@ class MemoryEngine:
             task_dict: Task dictionary with 'type' key and other payload data
                 Example: {'type': 'access_count_update', 'node_ids': [...]}
         """
-        task_type = task_dict.get(
-        operation_id = task_dict.get(
-        retry_count = task_dict.get(
+        task_type = task_dict.get("type")
+        operation_id = task_dict.get("operation_id")
+        retry_count = task_dict.get("retry_count", 0)
         max_retries = 3
 
         # Check if operation was cancelled (only for tasks with operation_id)
@@ -312,8 +311,7 @@ class MemoryEngine:
                 pool = await self._get_pool()
                 async with acquire_with_retry(pool) as conn:
                     result = await conn.fetchrow(
-                        "SELECT id FROM async_operations WHERE id = $1",
-                        uuid.UUID(operation_id)
+                        "SELECT id FROM async_operations WHERE id = $1", uuid.UUID(operation_id)
                     )
                     if not result:
                         # Operation was cancelled, skip processing
@@ -324,15 +322,15 @@ class MemoryEngine:
                 # Continue with processing if we can't check status
 
         try:
-            if task_type ==
+            if task_type == "access_count_update":
                 await self._handle_access_count_update(task_dict)
-            elif task_type ==
+            elif task_type == "reinforce_opinion":
                 await self._handle_reinforce_opinion(task_dict)
-            elif task_type ==
+            elif task_type == "form_opinion":
                 await self._handle_form_opinion(task_dict)
-            elif task_type ==
+            elif task_type == "batch_retain":
                 await self._handle_batch_retain(task_dict)
-            elif task_type ==
+            elif task_type == "regenerate_observations":
                 await self._handle_regenerate_observations(task_dict)
             else:
                 logger.error(f"Unknown task type: {task_type}")
@@ -347,14 +345,17 @@ class MemoryEngine:
 
         except Exception as e:
             # Task failed - check if we should retry
-            logger.error(
+            logger.error(
+                f"Task execution failed (attempt {retry_count + 1}/{max_retries + 1}): {task_type}, error: {e}"
+            )
             import traceback
+
             error_traceback = traceback.format_exc()
             traceback.print_exc()
 
             if retry_count < max_retries:
                 # Reschedule with incremented retry count
-                task_dict[
+                task_dict["retry_count"] = retry_count + 1
                 logger.info(f"Rescheduling task {task_type} (retry {retry_count + 1}/{max_retries})")
                 await self._task_backend.submit_task(task_dict)
             else:
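The retry path above just increments a counter stored in the task dict and resubmits the task. A self-contained sketch of that bounded-retry pattern; the `handler` and `backend` parameters are hypothetical stand-ins for the engine's handler methods and `_task_backend`:

```python
import logging

logger = logging.getLogger(__name__)
MAX_RETRIES = 3


async def execute_with_retry(task: dict, handler, backend) -> None:
    """Run a task handler; on failure, resubmit the task with an incremented retry_count."""
    retry_count = task.get("retry_count", 0)
    try:
        await handler(task)
    except Exception as exc:
        logger.error("Task failed (attempt %d/%d): %s", retry_count + 1, MAX_RETRIES + 1, exc)
        if retry_count < MAX_RETRIES:
            task["retry_count"] = retry_count + 1
            await backend.submit_task(task)  # reschedule for another attempt
        else:
            raise  # give up after the retry budget is exhausted
```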
@@ -368,10 +369,7 @@ class MemoryEngine:
             try:
                 pool = await self._get_pool()
                 async with acquire_with_retry(pool) as conn:
-                    await conn.execute(
-                        "DELETE FROM async_operations WHERE id = $1",
-                        uuid.UUID(operation_id)
-                    )
+                    await conn.execute("DELETE FROM async_operations WHERE id = $1", uuid.UUID(operation_id))
             except Exception as e:
                 logger.error(f"Failed to delete async operation record {operation_id}: {e}")
 
@@ -391,7 +389,7 @@ class MemoryEngine:
                         WHERE id = $1
                         """,
                         uuid.UUID(operation_id),
-                        truncated_error
+                        truncated_error,
                     )
                 logger.info(f"Marked async operation as failed: {operation_id}")
             except Exception as e:
@@ -406,8 +404,6 @@ class MemoryEngine:
         if self._initialized:
             return
 
-        import concurrent.futures
-
         # Run model loading in thread pool (CPU-bound) in parallel with pg0 startup
         loop = asyncio.get_event_loop()
 
@@ -429,10 +425,7 @@ class MemoryEngine:
             """Initialize embedding model."""
             # For local providers, run in thread pool to avoid blocking event loop
             if self.embeddings.provider_name == "local":
-                await loop.run_in_executor(
-                    None,
-                    lambda: asyncio.run(self.embeddings.initialize())
-                )
+                await loop.run_in_executor(None, lambda: asyncio.run(self.embeddings.initialize()))
             else:
                 await self.embeddings.initialize()
 
@@ -441,10 +434,7 @@ class MemoryEngine:
             cross_encoder = self._cross_encoder_reranker.cross_encoder
             # For local providers, run in thread pool to avoid blocking event loop
             if cross_encoder.provider_name == "local":
-                await loop.run_in_executor(
-                    None,
-                    lambda: asyncio.run(cross_encoder.initialize())
-                )
+                await loop.run_in_executor(None, lambda: asyncio.run(cross_encoder.initialize()))
             else:
                 await cross_encoder.initialize()
 
@@ -469,6 +459,7 @@ class MemoryEngine:
         # Run database migrations if enabled
         if self._run_migrations:
             from ..migrations import run_migrations
+
             logger.info("Running database migrations...")
             run_migrations(self.db_url)
 
@@ -555,7 +546,6 @@ class MemoryEngine:
             self._pg0 = None
             logger.info("pg0 stopped")
 
-
     async def wait_for_background_tasks(self):
         """
         Wait for all pending background tasks to complete.
@@ -563,7 +553,7 @@ class MemoryEngine:
         This is useful in tests to ensure background tasks (like opinion reinforcement)
         complete before making assertions.
         """
-        if hasattr(self._task_backend,
+        if hasattr(self._task_backend, "wait_for_pending_tasks"):
             await self._task_backend.wait_for_pending_tasks()
 
     def _format_readable_date(self, dt: datetime) -> str:
@@ -596,12 +586,12 @@ class MemoryEngine:
         self,
         conn,
         bank_id: str,
-        texts:
-        embeddings:
+        texts: list[str],
+        embeddings: list[list[float]],
         event_date: datetime,
         time_window_hours: int = 24,
-        similarity_threshold: float = 0.95
-    ) ->
+        similarity_threshold: float = 0.95,
+    ) -> list[bool]:
         """
         Check which facts are duplicates using semantic similarity + temporal window.
 
@@ -635,6 +625,7 @@ class MemoryEngine:
 
         # Fetch ALL existing facts in time window ONCE (much faster than N queries)
         import time as time_mod
+
         fetch_start = time_mod.time()
         existing_facts = await conn.fetch(
             """
@@ -643,7 +634,9 @@ class MemoryEngine:
             WHERE bank_id = $1
               AND event_date BETWEEN $2 AND $3
             """,
-            bank_id,
+            bank_id,
+            time_lower,
+            time_upper,
         )
 
         # If no existing facts, nothing is duplicate
@@ -651,17 +644,17 @@ class MemoryEngine:
             return [False] * len(texts)
 
         # Compute similarities in Python (vectorized with numpy)
-        import numpy as np
         is_duplicate = []
 
         # Convert existing embeddings to numpy for faster computation
         embedding_arrays = []
        for row in existing_facts:
-            raw_emb = row[
+            raw_emb = row["embedding"]
             # Handle different pgvector formats
             if isinstance(raw_emb, str):
                 # Parse string format: "[1.0, 2.0, ...]"
                 import json
+
                 emb = np.array(json.loads(raw_emb), dtype=np.float32)
             elif isinstance(raw_emb, (list, tuple)):
                 emb = np.array(raw_emb, dtype=np.float32)
@@ -691,7 +684,6 @@ class MemoryEngine:
             max_similarity = np.max(similarities) if len(similarities) > 0 else 0
             is_duplicate.append(max_similarity > similarity_threshold)
 
-
         return is_duplicate
 
     def retain(
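The deduplication hunks above fetch existing facts in one query and then compare embeddings in numpy against a similarity threshold. A minimal sketch of that comparison, assuming non-zero embedding vectors; the function name and signature are illustrative, not the engine's API:

```python
import numpy as np


def find_duplicates(new_embs: list[list[float]],
                    existing_embs: list[list[float]],
                    threshold: float = 0.95) -> list[bool]:
    """Flag each new embedding whose max cosine similarity to any existing one exceeds the threshold."""
    if not existing_embs:
        return [False] * len(new_embs)
    existing = np.array(existing_embs, dtype=np.float32)
    existing /= np.linalg.norm(existing, axis=1, keepdims=True)  # normalize rows once
    flags = []
    for emb in new_embs:
        v = np.array(emb, dtype=np.float32)
        v /= np.linalg.norm(v)
        sims = existing @ v  # cosine similarity against every stored fact
        flags.append(float(sims.max()) > threshold)
    return flags
```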
@@ -699,8 +691,8 @@ class MemoryEngine:
         bank_id: str,
         content: str,
         context: str = "",
-        event_date:
-    ) ->
+        event_date: datetime | None = None,
+    ) -> list[str]:
         """
         Store content as memory units (synchronous wrapper).
 
@@ -724,11 +716,11 @@ class MemoryEngine:
         bank_id: str,
         content: str,
         context: str = "",
-        event_date:
-        document_id:
-        fact_type_override:
-        confidence_score:
-    ) ->
+        event_date: datetime | None = None,
+        document_id: str | None = None,
+        fact_type_override: str | None = None,
+        confidence_score: float | None = None,
+    ) -> list[str]:
         """
         Store content as memory units with temporal and semantic links (ASYNC version).
 
@@ -747,11 +739,7 @@ class MemoryEngine:
             List of created unit IDs
         """
         # Build content dict
-        content_dict: RetainContentDict = {
-            "content": content,
-            "context": context,
-            "event_date": event_date
-        }
+        content_dict: RetainContentDict = {"content": content, "context": context, "event_date": event_date}
         if document_id:
             content_dict["document_id"] = document_id
 
@@ -760,7 +748,7 @@ class MemoryEngine:
             bank_id=bank_id,
             contents=[content_dict],
             fact_type_override=fact_type_override,
-            confidence_score=confidence_score
+            confidence_score=confidence_score,
         )
 
         # Return the first (and only) list of unit IDs
@@ -769,11 +757,11 @@ class MemoryEngine:
     async def retain_batch_async(
         self,
         bank_id: str,
-        contents:
-        document_id:
-        fact_type_override:
-        confidence_score:
-    ) ->
+        contents: list[RetainContentDict],
+        document_id: str | None = None,
+        fact_type_override: str | None = None,
+        confidence_score: float | None = None,
+    ) -> list[list[str]]:
         """
         Store multiple content items as memory units in ONE batch operation.
 
@@ -839,7 +827,9 @@ class MemoryEngine:
 
         if total_chars > CHARS_PER_BATCH:
             # Split into smaller batches based on character count
-            logger.info(
+            logger.info(
+                f"Large batch detected ({total_chars:,} chars from {len(contents)} items). Splitting into sub-batches of ~{CHARS_PER_BATCH:,} chars each..."
+            )
 
             sub_batches = []
             current_batch = []
@@ -868,7 +858,9 @@ class MemoryEngine:
             all_results = []
             for i, sub_batch in enumerate(sub_batches, 1):
                 sub_batch_chars = sum(len(item.get("content", "")) for item in sub_batch)
-                logger.info(
+                logger.info(
+                    f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars"
+                )
 
                 sub_results = await self._retain_batch_async_internal(
                     bank_id=bank_id,
@@ -876,12 +868,14 @@ class MemoryEngine:
                     document_id=document_id,
                     is_first_batch=i == 1,  # Only upsert on first batch
                     fact_type_override=fact_type_override,
-                    confidence_score=confidence_score
+                    confidence_score=confidence_score,
                 )
                 all_results.extend(sub_results)
 
             total_time = time.time() - start_time
-            logger.info(
+            logger.info(
+                f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s"
+            )
             return all_results
 
         # Small batch - use internal method directly
@@ -891,18 +885,18 @@ class MemoryEngine:
             document_id=document_id,
             is_first_batch=True,
             fact_type_override=fact_type_override,
-            confidence_score=confidence_score
+            confidence_score=confidence_score,
         )
 
     async def _retain_batch_async_internal(
         self,
         bank_id: str,
-        contents:
-        document_id:
+        contents: list[RetainContentDict],
+        document_id: str | None = None,
         is_first_batch: bool = True,
-        fact_type_override:
-        confidence_score:
-    ) ->
+        fact_type_override: str | None = None,
+        confidence_score: float | None = None,
+    ) -> list[list[str]]:
         """
         Internal method for batch processing without chunking logic.
 
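The chunking path above splits a large retain batch by total character count before handing each sub-batch to the internal method. A rough sketch of that splitting step under the assumption of a `CHARS_PER_BATCH` constant like the one referenced in the diff (the value below is illustrative):

```python
CHARS_PER_BATCH = 50_000  # illustrative value; the real constant lives elsewhere in the module


def split_into_sub_batches(contents: list[dict], chars_per_batch: int = CHARS_PER_BATCH) -> list[list[dict]]:
    """Group content items into sub-batches whose combined text stays near the character budget."""
    sub_batches: list[list[dict]] = []
    current: list[dict] = []
    current_chars = 0
    for item in contents:
        item_chars = len(item.get("content", ""))
        if current and current_chars + item_chars > chars_per_batch:
            sub_batches.append(current)  # flush the batch before it overflows
            current, current_chars = [], 0
        current.append(item)
        current_chars += item_chars
    if current:
        sub_batches.append(current)
    return sub_batches
```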
@@ -938,7 +932,7 @@ class MemoryEngine:
             document_id=document_id,
             is_first_batch=is_first_batch,
             fact_type_override=fact_type_override,
-            confidence_score=confidence_score
+            confidence_score=confidence_score,
         )
 
     def recall(
@@ -949,7 +943,7 @@ class MemoryEngine:
         budget: Budget = Budget.MID,
         max_tokens: int = 4096,
         enable_trace: bool = False,
-    ) -> tuple[
+    ) -> tuple[list[dict[str, Any]], Any | None]:
         """
         Recall memories using 4-way parallel retrieval (synchronous wrapper).
 
@@ -968,19 +962,17 @@ class MemoryEngine:
             Tuple of (results, trace)
         """
         # Run async version synchronously
-        return asyncio.run(self.recall_async(
-            bank_id, query, [fact_type], budget, max_tokens, enable_trace
-        ))
+        return asyncio.run(self.recall_async(bank_id, query, [fact_type], budget, max_tokens, enable_trace))
 
     async def recall_async(
         self,
         bank_id: str,
         query: str,
-        fact_type:
+        fact_type: list[str],
         budget: Budget = Budget.MID,
         max_tokens: int = 4096,
         enable_trace: bool = False,
-        question_date:
+        question_date: datetime | None = None,
         include_entities: bool = False,
         max_entity_tokens: int = 1024,
         include_chunks: bool = False,
@@ -1027,11 +1019,7 @@ class MemoryEngine:
         )
 
         # Map budget enum to thinking_budget number
-        budget_mapping = {
-            Budget.LOW: 100,
-            Budget.MID: 300,
-            Budget.HIGH: 1000
-        }
+        budget_mapping = {Budget.LOW: 100, Budget.MID: 300, Budget.HIGH: 1000}
         thinking_budget = budget_mapping[budget]
 
         # Backpressure: limit concurrent recalls to prevent overwhelming the database
@@ -1041,20 +1029,29 @@ class MemoryEngine:
         for attempt in range(max_retries + 1):
             try:
                 return await self._search_with_retries(
-                    bank_id,
-
+                    bank_id,
+                    query,
+                    fact_type,
+                    thinking_budget,
+                    max_tokens,
+                    enable_trace,
+                    question_date,
+                    include_entities,
+                    max_entity_tokens,
+                    include_chunks,
+                    max_chunk_tokens,
                 )
             except Exception as e:
                 # Check if it's a connection error
                 is_connection_error = (
-                    isinstance(e, asyncpg.TooManyConnectionsError)
-                    isinstance(e, asyncpg.CannotConnectNowError)
-                    (isinstance(e, asyncpg.PostgresError) and
+                    isinstance(e, asyncpg.TooManyConnectionsError)
+                    or isinstance(e, asyncpg.CannotConnectNowError)
+                    or (isinstance(e, asyncpg.PostgresError) and "connection" in str(e).lower())
                 )
 
                 if is_connection_error and attempt < max_retries:
                     # Wait with exponential backoff before retry
-                    wait_time = 0.5 * (2
+                    wait_time = 0.5 * (2**attempt)  # 0.5s, 1s, 2s
                     logger.warning(
                         f"Connection error on search attempt {attempt + 1}/{max_retries + 1}: {str(e)}. "
                         f"Retrying in {wait_time:.1f}s..."
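The hunk above retries the search only for connection-class errors, doubling the wait each time (0.5 * 2**attempt). A compact sketch of the same pattern, with a generic `is_connection_error` predicate standing in for the asyncpg type checks shown in the diff:

```python
import asyncio


def is_connection_error(exc: Exception) -> bool:
    # Placeholder predicate; the engine checks asyncpg exception types and message text.
    return "connection" in str(exc).lower()


async def call_with_backoff(fn, *args, max_retries: int = 2, base_delay: float = 0.5):
    """Retry an async callable on connection errors with base_delay * 2**attempt backoff."""
    for attempt in range(max_retries + 1):
        try:
            return await fn(*args)
        except Exception as exc:
            if not is_connection_error(exc) or attempt >= max_retries:
                raise  # non-retryable error, or retry budget exhausted
            await asyncio.sleep(base_delay * (2 ** attempt))
```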
@@ -1069,11 +1066,11 @@ class MemoryEngine:
         self,
         bank_id: str,
         query: str,
-        fact_type:
+        fact_type: list[str],
         thinking_budget: int,
         max_tokens: int,
         enable_trace: bool,
-        question_date:
+        question_date: datetime | None = None,
         include_entities: bool = False,
         max_entity_tokens: int = 500,
         include_chunks: bool = False,
@@ -1106,6 +1103,7 @@ class MemoryEngine:
         """
         # Initialize tracer if requested
         from .search.tracer import SearchTracer
+
         tracer = SearchTracer(query, thinking_budget, max_tokens) if enable_trace else None
         if tracer:
             tracer.start()
@@ -1116,7 +1114,9 @@ class MemoryEngine:
         # Buffer logs for clean output in concurrent scenarios
         recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
         log_buffer = []
-        log_buffer.append(
+        log_buffer.append(
+            f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})"
+        )
 
         try:
             # Step 1: Generate query embedding (for semantic search)
@@ -1141,8 +1141,7 @@ class MemoryEngine:
             # Run retrieval for each fact type in parallel
             retrieval_tasks = [
                 retrieve_parallel(
-                    pool, query, query_embedding_str, bank_id, ft, thinking_budget,
-                    question_date, self.query_analyzer
+                    pool, query, query_embedding_str, bank_id, ft, thinking_budget, question_date, self.query_analyzer
                 )
                 for ft in fact_type
             ]
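One retrieval coroutine is built per fact type and they run concurrently. A minimal sketch of that fan-out with `asyncio.gather`, using a hypothetical `retrieve_for_fact_type` coroutine in place of the real `retrieve_parallel` (whose full signature is not shown here):

```python
import asyncio


async def retrieve_for_fact_type(query: str, fact_type: str) -> dict:
    # Hypothetical stand-in: the real code calls retrieve_parallel(pool, query, ...).
    await asyncio.sleep(0)
    return {"fact_type": fact_type, "semantic": [], "bm25": [], "graph": [], "temporal": None}


async def retrieve_all(query: str, fact_types: list[str]) -> list[dict]:
    """Launch one retrieval per fact type and wait for all of them; order follows fact_types."""
    tasks = [retrieve_for_fact_type(query, ft) for ft in fact_types]
    return await asyncio.gather(*tasks)
```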
@@ -1156,22 +1155,24 @@ class MemoryEngine:
             aggregated_timings = {"semantic": 0.0, "bm25": 0.0, "graph": 0.0, "temporal": 0.0}
 
             detected_temporal_constraint = None
-            for idx,
+            for idx, retrieval_result in enumerate(all_retrievals):
                 # Log fact types in this retrieval batch
                 ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
-                logger.debug(
+                logger.debug(
+                    f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
+                )
 
-                semantic_results.extend(
-                bm25_results.extend(
-                graph_results.extend(
-                if
-                    temporal_results.extend(
+                semantic_results.extend(retrieval_result.semantic)
+                bm25_results.extend(retrieval_result.bm25)
+                graph_results.extend(retrieval_result.graph)
+                if retrieval_result.temporal:
+                    temporal_results.extend(retrieval_result.temporal)
                 # Track max timing for each method (since they run in parallel across fact types)
-                for method, duration in
-                    aggregated_timings[method] = max(aggregated_timings
+                for method, duration in retrieval_result.timings.items():
+                    aggregated_timings[method] = max(aggregated_timings.get(method, 0.0), duration)
                 # Capture temporal constraint (same across all fact types)
-                if
-                    detected_temporal_constraint =
+                if retrieval_result.temporal_constraint:
+                    detected_temporal_constraint = retrieval_result.temporal_constraint
 
             # If no temporal results from any fact type, set to None
             if not temporal_results:
@@ -1179,11 +1180,13 @@ class MemoryEngine:
 
             # Sort combined results by score (descending) so higher-scored results
             # get better ranks in the trace, regardless of fact type
-            semantic_results.sort(key=lambda r: r.similarity if hasattr(r,
-            bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r,
-            graph_results.sort(key=lambda r: r.activation if hasattr(r,
+            semantic_results.sort(key=lambda r: r.similarity if hasattr(r, "similarity") else 0, reverse=True)
+            bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r, "bm25_score") else 0, reverse=True)
+            graph_results.sort(key=lambda r: r.activation if hasattr(r, "activation") else 0, reverse=True)
             if temporal_results:
-                temporal_results.sort(
+                temporal_results.sort(
+                    key=lambda r: r.combined_score if hasattr(r, "combined_score") else 0, reverse=True
+                )
 
             retrieval_duration = time.time() - retrieval_start
 
@@ -1193,7 +1196,7 @@ class MemoryEngine:
             timing_parts = [
                 f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
                 f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
-                f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)"
+                f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)",
             ]
             temporal_info = ""
             if detected_temporal_constraint:
@@ -1201,61 +1204,75 @@ class MemoryEngine:
                 temporal_count = len(temporal_results) if temporal_results else 0
                 timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
                 temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
-            log_buffer.append(
+            log_buffer.append(
+                f"  [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}"
+            )
 
-            # Record retrieval results for tracer
+            # Record retrieval results for tracer - per fact type
             if tracer:
                 # Convert RetrievalResult to old tuple format for tracer
                 def to_tuple_format(results):
                     return [(r.id, r.__dict__) for r in results]
 
-                # Add
-
-
-                    results=to_tuple_format(semantic_results),
-                    duration_seconds=aggregated_timings["semantic"],
-                    score_field="similarity",
-                    metadata={"limit": thinking_budget}
-                )
+                # Add retrieval results per fact type (to show parallel execution in UI)
+                for idx, rr in enumerate(all_retrievals):
+                    ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
 
-
-
-
-
-
-
-
-
+                    # Add semantic retrieval results for this fact type
+                    tracer.add_retrieval_results(
+                        method_name="semantic",
+                        results=to_tuple_format(rr.semantic),
+                        duration_seconds=rr.timings.get("semantic", 0.0),
+                        score_field="similarity",
+                        metadata={"limit": thinking_budget},
+                        fact_type=ft_name,
+                    )
 
-
-
-
-
-
-
-
-
+                    # Add BM25 retrieval results for this fact type
+                    tracer.add_retrieval_results(
+                        method_name="bm25",
+                        results=to_tuple_format(rr.bm25),
+                        duration_seconds=rr.timings.get("bm25", 0.0),
+                        score_field="bm25_score",
+                        metadata={"limit": thinking_budget},
+                        fact_type=ft_name,
+                    )
 
-
-                if temporal_results:
+                    # Add graph retrieval results for this fact type
                     tracer.add_retrieval_results(
-                        method_name="
-                        results=to_tuple_format(
-                        duration_seconds=
-                        score_field="
-                        metadata={"budget": thinking_budget}
+                        method_name="graph",
+                        results=to_tuple_format(rr.graph),
+                        duration_seconds=rr.timings.get("graph", 0.0),
+                        score_field="activation",
+                        metadata={"budget": thinking_budget},
+                        fact_type=ft_name,
                     )
 
+                    # Add temporal retrieval results for this fact type (even if empty, to show it ran)
+                    if rr.temporal is not None:
+                        tracer.add_retrieval_results(
+                            method_name="temporal",
+                            results=to_tuple_format(rr.temporal),
+                            duration_seconds=rr.timings.get("temporal", 0.0),
+                            score_field="temporal_score",
+                            metadata={"budget": thinking_budget},
+                            fact_type=ft_name,
+                        )
+
                 # Record entry points (from semantic results) for legacy graph view
                 for rank, retrieval in enumerate(semantic_results[:10], start=1):  # Top 10 as entry points
                     tracer.add_entry_point(retrieval.id, retrieval.text, retrieval.similarity or 0.0, rank)
 
-                tracer.add_phase_metric(
-                    "
-
-
-
-
+                tracer.add_phase_metric(
+                    "parallel_retrieval",
+                    step_duration,
+                    {
+                        "semantic_count": len(semantic_results),
+                        "bm25_count": len(bm25_results),
+                        "graph_count": len(graph_results),
+                        "temporal_count": len(temporal_results) if temporal_results else 0,
+                    },
+                )
 
             # Step 3: Merge with RRF
             step_start = time.time()
@@ -1263,7 +1280,9 @@ class MemoryEngine:
 
             # Merge 3 or 4 result lists depending on temporal constraint
             if temporal_results:
-                merged_candidates = reciprocal_rank_fusion(
+                merged_candidates = reciprocal_rank_fusion(
+                    [semantic_results, bm25_results, graph_results, temporal_results]
+                )
             else:
                 merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results])
 
@@ -1272,8 +1291,10 @@ class MemoryEngine:
 
             if tracer:
                 # Convert MergedCandidate to old tuple format for tracer
-                tracer_merged = [
-
+                tracer_merged = [
+                    (mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
+                    for mc in merged_candidates
+                ]
                 tracer.add_rrf_merged(tracer_merged)
                 tracer.add_phase_metric("rrf_merge", step_duration, {"candidates_merged": len(merged_candidates)})
 
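The merge step feeds three or four ranked lists into reciprocal rank fusion. A generic sketch of RRF over lists of result IDs, using the conventional k=60 constant; the module's own `reciprocal_rank_fusion` may differ in details:

```python
def reciprocal_rank_fusion_ids(ranked_lists: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
    """Fuse ranked lists: each item scores sum(1 / (k + rank)) over every list it appears in."""
    scores: dict[str, float] = {}
    for ranked in ranked_lists:
        for rank, item_id in enumerate(ranked, start=1):
            scores[item_id] = scores.get(item_id, 0.0) + 1.0 / (k + rank)
    # Higher fused score means the item ranked well in more of the retrieval methods
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
```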
@@ -1287,44 +1308,38 @@ class MemoryEngine:
             step_duration = time.time() - step_start
             log_buffer.append(f"  [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s")
 
-            if tracer:
-                # Convert to old format for tracer
-                results_dict = [sr.to_dict() for sr in scored_results]
-                tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
-                                 for mc in merged_candidates]
-                tracer.add_reranked(results_dict, tracer_merged)
-                tracer.add_phase_metric("reranking", step_duration, {
-                    "reranker_type": "cross-encoder",
-                    "candidates_reranked": len(scored_results)
-                })
-
             # Step 4.5: Combine cross-encoder score with retrieval signals
             # This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
             if scored_results:
-                # Normalize RRF scores to [0, 1] range
+                # Normalize RRF scores to [0, 1] range using min-max normalization
                 rrf_scores = [sr.candidate.rrf_score for sr in scored_results]
-                max_rrf = max(rrf_scores) if rrf_scores else
+                max_rrf = max(rrf_scores) if rrf_scores else 0.0
                 min_rrf = min(rrf_scores) if rrf_scores else 0.0
-                rrf_range = max_rrf - min_rrf
+                rrf_range = max_rrf - min_rrf  # Don't force to 1.0, let fallback handle it
 
                 # Calculate recency based on occurred_start (more recent = higher score)
                 now = utcnow()
                 for sr in scored_results:
-                    # Normalize RRF score
-
+                    # Normalize RRF score (0-1 range, 0.5 if all same)
+                    if rrf_range > 0:
+                        sr.rrf_normalized = (sr.candidate.rrf_score - min_rrf) / rrf_range
+                    else:
+                        # All RRF scores are the same, use neutral value
+                        sr.rrf_normalized = 0.5
 
                     # Calculate recency (decay over 365 days, minimum 0.1)
                     sr.recency = 0.5  # default for missing dates
                     if sr.retrieval.occurred_start:
                         occurred = sr.retrieval.occurred_start
-                        if hasattr(occurred,
-
-                            occurred = occurred.replace(tzinfo=timezone.utc)
+                        if hasattr(occurred, "tzinfo") and occurred.tzinfo is None:
+                            occurred = occurred.replace(tzinfo=UTC)
                         days_ago = (now - occurred).total_seconds() / 86400
                         sr.recency = max(0.1, 1.0 - (days_ago / 365))  # Linear decay over 1 year
 
                     # Get temporal proximity if available (already 0-1)
-                    sr.temporal =
+                    sr.temporal = (
+                        sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5
+                    )
 
                     # Weighted combination
                     # Cross-encoder: 60% (semantic relevance)
@@ -1332,16 +1347,32 @@ class MemoryEngine:
                     # Temporal proximity: 10% (time relevance for temporal queries)
                     # Recency: 10% (prefer recent facts)
                     sr.combined_score = (
-                        0.6 * sr.cross_encoder_score_normalized
-                        0.2 * sr.rrf_normalized
-                        0.1 * sr.temporal
-                        0.1 * sr.recency
+                        0.6 * sr.cross_encoder_score_normalized
+                        + 0.2 * sr.rrf_normalized
+                        + 0.1 * sr.temporal
+                        + 0.1 * sr.recency
                     )
                     sr.weight = sr.combined_score  # Update weight for final ranking
 
                 # Re-sort by combined score
                 scored_results.sort(key=lambda x: x.weight, reverse=True)
-                log_buffer.append(
+                log_buffer.append(
+                    "  [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)"
+                )
+
+                # Add reranked results to tracer AFTER combined scoring (so normalized values are included)
+                if tracer:
+                    results_dict = [sr.to_dict() for sr in scored_results]
+                    tracer_merged = [
+                        (mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
+                        for mc in merged_candidates
+                    ]
+                    tracer.add_reranked(results_dict, tracer_merged)
+                    tracer.add_phase_metric(
+                        "reranking",
+                        step_duration,
+                        {"reranker_type": "cross-encoder", "candidates_reranked": len(scored_results)},
+                    )
 
             # Step 5: Truncate to thinking_budget * 2 for token filtering
             rerank_limit = thinking_budget * 2
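The final ranking score is a fixed weighted sum of the normalized cross-encoder score, the min-max-normalized RRF score, temporal proximity, and recency. A worked sketch of those formulas under the weights shown in the diff (0.6 / 0.2 / 0.1 / 0.1); all helper names here are illustrative:

```python
def minmax(value: float, lo: float, hi: float) -> float:
    """Min-max normalization with the same neutral 0.5 fallback used when all scores are equal."""
    return (value - lo) / (hi - lo) if hi > lo else 0.5


def recency_score(days_ago: float) -> float:
    """Linear decay over 365 days with a 0.1 floor, as in the diff."""
    return max(0.1, 1.0 - days_ago / 365)


def combined_score(cross_encoder: float, rrf: float, temporal: float, recency: float) -> float:
    """All inputs are expected in [0, 1]; weights follow the comments in the diff."""
    return 0.6 * cross_encoder + 0.2 * rrf + 0.1 * temporal + 0.1 * recency


# Example: a strong cross-encoder match from ~3 months ago with an average RRF position
print(combined_score(0.9, 0.5, 0.5, recency_score(90)))
```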
@@ -1360,14 +1391,16 @@ class MemoryEngine:
             top_scored = [sr for sr in top_scored if sr.id in filtered_ids]
 
             step_duration = time.time() - step_start
-            log_buffer.append(
+            log_buffer.append(
+                f"  [6] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s"
+            )
 
             if tracer:
-                tracer.add_phase_metric(
-                    "
-
-                    "max_tokens": max_tokens
-
+                tracer.add_phase_metric(
+                    "token_filtering",
+                    step_duration,
+                    {"results_selected": len(top_scored), "tokens_used": total_tokens, "max_tokens": max_tokens},
+                )
 
             # Record visits for all retrieved nodes
             if tracer:
@@ -1386,16 +1419,13 @@ class MemoryEngine:
                         semantic_similarity=sr.retrieval.similarity or 0.0,
                         recency=sr.recency,
                         frequency=0.0,
-                        final_weight=sr.weight
+                        final_weight=sr.weight,
                     )
 
             # Step 8: Queue access count updates for visited nodes
             visited_ids = list(set([sr.id for sr in scored_results[:50]]))  # Top 50
             if visited_ids:
-                await self._task_backend.submit_task({
-                    'type': 'access_count_update',
-                    'node_ids': visited_ids
-                })
+                await self._task_backend.submit_task({"type": "access_count_update", "node_ids": visited_ids})
                 log_buffer.append(f"  [7] Queued access count updates for {len(visited_ids)} nodes")
 
             # Log fact_type distribution in results
@@ -1413,13 +1443,19 @@ class MemoryEngine:
                 # Convert datetime objects to ISO strings for JSON serialization
                 if result_dict.get("occurred_start"):
                     occurred_start = result_dict["occurred_start"]
-                    result_dict["occurred_start"] =
+                    result_dict["occurred_start"] = (
+                        occurred_start.isoformat() if hasattr(occurred_start, "isoformat") else occurred_start
+                    )
                 if result_dict.get("occurred_end"):
                     occurred_end = result_dict["occurred_end"]
-                    result_dict["occurred_end"] =
+                    result_dict["occurred_end"] = (
+                        occurred_end.isoformat() if hasattr(occurred_end, "isoformat") else occurred_end
+                    )
                 if result_dict.get("mentioned_at"):
                     mentioned_at = result_dict["mentioned_at"]
-                    result_dict["mentioned_at"] =
+                    result_dict["mentioned_at"] = (
+                        mentioned_at.isoformat() if hasattr(mentioned_at, "isoformat") else mentioned_at
+                    )
                 top_results_dicts.append(result_dict)
 
             # Get entities for each fact if include_entities is requested
@@ -1435,16 +1471,15 @@ class MemoryEngine:
                         JOIN entities e ON ue.entity_id = e.id
                         WHERE ue.unit_id = ANY($1::uuid[])
                         """,
-                        unit_ids
+                        unit_ids,
                     )
                     for row in entity_rows:
-                        unit_id = str(row[
+                        unit_id = str(row["unit_id"])
                         if unit_id not in fact_entity_map:
                             fact_entity_map[unit_id] = []
-                        fact_entity_map[unit_id].append(
-
-
-                        })
+                        fact_entity_map[unit_id].append(
+                            {"entity_id": str(row["entity_id"]), "canonical_name": row["canonical_name"]}
+                        )
 
             # Convert results to MemoryFact objects
             memory_facts = []
@@ -1453,20 +1488,22 @@ class MemoryEngine:
                 # Get entity names for this fact
                 entity_names = None
                 if include_entities and result_id in fact_entity_map:
-                    entity_names = [e[
-
-                memory_facts.append(
-
-
-
-
-
-
-
-
-
-
+                    entity_names = [e["canonical_name"] for e in fact_entity_map[result_id]]
+
+                memory_facts.append(
+                    MemoryFact(
+                        id=result_id,
+                        text=result_dict.get("text"),
+                        fact_type=result_dict.get("fact_type", "world"),
+                        entities=entity_names,
+                        context=result_dict.get("context"),
+                        occurred_start=result_dict.get("occurred_start"),
+                        occurred_end=result_dict.get("occurred_end"),
+                        mentioned_at=result_dict.get("mentioned_at"),
+                        document_id=result_dict.get("document_id"),
+                        chunk_id=result_dict.get("chunk_id"),
+                    )
+                )
 
             # Fetch entity observations if requested
             entities_dict = None
@@ -1483,8 +1520,8 @@ class MemoryEngine:
                     unit_id = sr.id
                     if unit_id in fact_entity_map:
                         for entity in fact_entity_map[unit_id]:
-                            entity_id = entity[
-                            entity_name = entity[
+                            entity_id = entity["entity_id"]
+                            entity_name = entity["canonical_name"]
                             if entity_id not in seen_entity_ids:
                                 entities_ordered.append((entity_id, entity_name))
                                 seen_entity_ids.add(entity_id)
@@ -1512,9 +1549,7 @@ class MemoryEngine:
 
                     if included_observations:
                         entities_dict[entity_name] = EntityState(
-                            entity_id=entity_id,
-                            canonical_name=entity_name,
-                            observations=included_observations
+                            entity_id=entity_id, canonical_name=entity_name, observations=included_observations
                         )
                         total_entity_tokens += entity_tokens
 
@@ -1542,11 +1577,11 @@ class MemoryEngine:
                         FROM chunks
                         WHERE chunk_id = ANY($1::text[])
                         """,
-                        chunk_ids_ordered
+                        chunk_ids_ordered,
                    )
 
                     # Create a lookup dict for fast access
-                    chunks_lookup = {row[
+                    chunks_lookup = {row["chunk_id"]: row for row in chunks_rows}
 
                     # Apply token limit and build chunks_dict in the order of chunk_ids_ordered
                     chunks_dict = {}
@@ -1557,7 +1592,7 @@ class MemoryEngine:
                             continue
 
                         row = chunks_lookup[chunk_id]
-                        chunk_text = row[
+                        chunk_text = row["chunk_text"]
                         chunk_tokens = len(encoding.encode(chunk_text))
 
                         # Check if adding this chunk would exceed the limit
@@ -1568,18 +1603,14 @@ class MemoryEngine:
                             # Truncate to remaining tokens
                             truncated_text = encoding.decode(encoding.encode(chunk_text)[:remaining_tokens])
                             chunks_dict[chunk_id] = ChunkInfo(
-                                chunk_text=truncated_text,
-                                chunk_index=row['chunk_index'],
-                                truncated=True
+                                chunk_text=truncated_text, chunk_index=row["chunk_index"], truncated=True
                             )
                             total_chunk_tokens = max_chunk_tokens
                             # Stop adding more chunks once we hit the limit
                             break
                         else:
                             chunks_dict[chunk_id] = ChunkInfo(
-                                chunk_text=chunk_text,
-                                chunk_index=row['chunk_index'],
-                                truncated=False
+                                chunk_text=chunk_text, chunk_index=row["chunk_index"], truncated=False
                             )
                             total_chunk_tokens += chunk_tokens
 
@@ -1593,7 +1624,9 @@ class MemoryEngine:
             total_time = time.time() - recall_start
             num_chunks = len(chunks_dict) if chunks_dict else 0
             num_entities = len(entities_dict) if entities_dict else 0
-            log_buffer.append(
+            log_buffer.append(
+                f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s"
+            )
             logger.info("\n" + "\n".join(log_buffer))
 
             return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)
@@ -1604,10 +1637,8 @@ class MemoryEngine:
             raise Exception(f"Failed to search memories: {str(e)}")
 
     def _filter_by_token_budget(
-        self,
-
-        max_tokens: int
-    ) -> Tuple[List[Dict[str, Any]], int]:
+        self, results: list[dict[str, Any]], max_tokens: int
+    ) -> tuple[list[dict[str, Any]], int]:
         """
         Filter results to fit within token budget.
 
@@ -1640,7 +1671,7 @@ class MemoryEngine:
 
         return filtered_results, total_tokens
 
-    async def get_document(self, document_id: str, bank_id: str) ->
+    async def get_document(self, document_id: str, bank_id: str) -> dict[str, Any] | None:
         """
         Retrieve document metadata and statistics.
 
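`_filter_by_token_budget` keeps results until the tiktoken count would exceed `max_tokens`. A greedy sketch of that behaviour, counting tokens on each result's `text` field; this is an assumption, since the method body is not shown in full here:

```python
import tiktoken


def filter_by_token_budget(results: list[dict], max_tokens: int) -> tuple[list[dict], int]:
    """Greedily keep results in rank order until the token budget is exhausted."""
    encoding = tiktoken.get_encoding("cl100k_base")
    kept: list[dict] = []
    total = 0
    for result in results:
        tokens = len(encoding.encode(result.get("text", "")))
        if total + tokens > max_tokens:
            break  # stop at the first result that would overflow the budget
        kept.append(result)
        total += tokens
    return kept, total
```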
@@ -1662,7 +1693,8 @@ class MemoryEngine:
                 WHERE d.id = $1 AND d.bank_id = $2
                 GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
                 """,
-                document_id,
+                document_id,
+                bank_id,
             )
 
             if not doc:
@@ -1675,10 +1707,10 @@ class MemoryEngine:
                 "content_hash": doc["content_hash"],
                 "memory_unit_count": doc["unit_count"],
                 "created_at": doc["created_at"],
-                "updated_at": doc["updated_at"]
+                "updated_at": doc["updated_at"],
            }
 
-    async def delete_document(self, document_id: str, bank_id: str) ->
+    async def delete_document(self, document_id: str, bank_id: str) -> dict[str, int]:
         """
         Delete a document and all its associated memory units and links.
 
@@ -1694,22 +1726,17 @@ class MemoryEngine:
             async with conn.transaction():
                 # Count units before deletion
                 units_count = await conn.fetchval(
-                    "SELECT COUNT(*) FROM memory_units WHERE document_id = $1",
-                    document_id
+                    "SELECT COUNT(*) FROM memory_units WHERE document_id = $1", document_id
                 )
 
                 # Delete document (cascades to memory_units and all their links)
                 deleted = await conn.fetchval(
-                    "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id",
-                    document_id, bank_id
+                    "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id", document_id, bank_id
                 )
 
-                return {
-                    "document_deleted": 1 if deleted else 0,
-                    "memory_units_deleted": units_count if deleted else 0
-                }
+                return {"document_deleted": 1 if deleted else 0, "memory_units_deleted": units_count if deleted else 0}
 
-    async def delete_memory_unit(self, unit_id: str) ->
+    async def delete_memory_unit(self, unit_id: str) -> dict[str, Any]:
         """
         Delete a single memory unit and all its associated links.
 
@@ -1728,18 +1755,17 @@ class MemoryEngine:
|
|
|
1728
1755
|
async with acquire_with_retry(pool) as conn:
|
|
1729
1756
|
async with conn.transaction():
|
|
1730
1757
|
# Delete the memory unit (cascades to links and associations)
|
|
1731
|
-
deleted = await conn.fetchval(
|
|
1732
|
-
"DELETE FROM memory_units WHERE id = $1 RETURNING id",
|
|
1733
|
-
unit_id
|
|
1734
|
-
)
|
|
1758
|
+
deleted = await conn.fetchval("DELETE FROM memory_units WHERE id = $1 RETURNING id", unit_id)
|
|
1735
1759
|
|
|
1736
1760
|
return {
|
|
1737
1761
|
"success": deleted is not None,
|
|
1738
1762
|
"unit_id": str(deleted) if deleted else None,
|
|
1739
|
-
"message": "Memory unit and all its links deleted successfully"
|
|
1763
|
+
"message": "Memory unit and all its links deleted successfully"
|
|
1764
|
+
if deleted
|
|
1765
|
+
else "Memory unit not found",
|
|
1740
1766
|
}
|
|
1741
1767
|
|
|
1742
|
-
async def delete_bank(self, bank_id: str, fact_type:
|
|
1768
|
+
async def delete_bank(self, bank_id: str, fact_type: str | None = None) -> dict[str, int]:
|
|
1743
1769
|
"""
|
|
1744
1770
|
Delete all data for a specific agent (multi-tenant cleanup).
|
|
1745
1771
|
|
|
@@ -1768,24 +1794,27 @@ class MemoryEngine:
|
|
|
1768
1794
|
# Delete only memories of a specific fact type
|
|
1769
1795
|
units_count = await conn.fetchval(
|
|
1770
1796
|
"SELECT COUNT(*) FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
|
|
1771
|
-
bank_id,
|
|
1797
|
+
bank_id,
|
|
1798
|
+
fact_type,
|
|
1772
1799
|
)
|
|
1773
1800
|
await conn.execute(
|
|
1774
|
-
"DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
|
|
1775
|
-
bank_id, fact_type
|
|
1801
|
+
"DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2", bank_id, fact_type
|
|
1776
1802
|
)
|
|
1777
1803
|
|
|
1778
1804
|
# Note: We don't delete entities when fact_type is specified,
|
|
1779
1805
|
# as they may be referenced by other memory units
|
|
1780
|
-
return {
|
|
1781
|
-
"memory_units_deleted": units_count,
|
|
1782
|
-
"entities_deleted": 0
|
|
1783
|
-
}
|
|
1806
|
+
return {"memory_units_deleted": units_count, "entities_deleted": 0}
|
|
1784
1807
|
else:
|
|
1785
1808
|
# Delete all data for the bank
|
|
1786
|
-
units_count = await conn.fetchval(
|
|
1787
|
-
|
|
1788
|
-
|
|
1809
|
+
units_count = await conn.fetchval(
|
|
1810
|
+
"SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id
|
|
1811
|
+
)
|
|
1812
|
+
entities_count = await conn.fetchval(
|
|
1813
|
+
"SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id
|
|
1814
|
+
)
|
|
1815
|
+
documents_count = await conn.fetchval(
|
|
1816
|
+
"SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id
|
|
1817
|
+
)
|
|
1789
1818
|
|
|
1790
1819
|
# Delete documents (cascades to chunks)
|
|
1791
1820
|
await conn.execute("DELETE FROM documents WHERE bank_id = $1", bank_id)
|
|
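A hypothetical usage sketch of the deletion APIs shown above; `engine` stands in for a constructed MemoryEngine with a live database, and the dictionary keys mirror the return values in this diff.

```python
async def cleanup(engine, document_id: str, bank_id: str) -> None:
    # Remove one document (cascades to its memory units and links)
    doc_result = await engine.delete_document(document_id, bank_id)
    print(doc_result["document_deleted"], doc_result["memory_units_deleted"])

    # Remove everything in the bank; fact_type=None wipes all fact types
    bank_result = await engine.delete_bank(bank_id)
    print(bank_result["memory_units_deleted"], bank_result["entities_deleted"])

# To run: asyncio.run(cleanup(engine, "some-document-id", "some-bank-id"))
```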
@@ -1803,13 +1832,13 @@ class MemoryEngine:
"memory_units_deleted": units_count,
"entities_deleted": entities_count,
"documents_deleted": documents_count,
- "bank_deleted": True
+ "bank_deleted": True,
}

except Exception as e:
raise Exception(f"Failed to delete agent data: {str(e)}")

- async def get_graph_data(self, bank_id:
+ async def get_graph_data(self, bank_id: str | None = None, fact_type: str | None = None):
"""
Get graph data for visualization.

@@ -1839,19 +1868,23 @@ class MemoryEngine:

where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""

- units = await conn.fetch(
+ units = await conn.fetch(
+ f"""
SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
FROM memory_units
{where_clause}
ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
LIMIT 1000
- """,
+ """,
+ *query_params,
+ )

# Get links, filtering to only include links between units of the selected agent
# Use DISTINCT ON with LEAST/GREATEST to deduplicate bidirectional links
- unit_ids = [row[
+ unit_ids = [row["id"] for row in units]
if unit_ids:
- links = await conn.fetch(
+ links = await conn.fetch(
+ """
SELECT DISTINCT ON (LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid))
ml.from_unit_id,
ml.to_unit_id,
@@ -1862,7 +1895,9 @@ class MemoryEngine:
LEFT JOIN entities e ON ml.entity_id = e.id
WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
ORDER BY LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid), ml.weight DESC
- """,
+ """,
+ unit_ids,
+ )
else:
links = []

@@ -1877,8 +1912,8 @@ class MemoryEngine:
# Build entity mapping
entity_map = {}
for row in unit_entities:
- unit_id = row[
- entity_name = row[
+ unit_id = row["unit_id"]
+ entity_name = row["canonical_name"]
if unit_id not in entity_map:
entity_map[unit_id] = []
entity_map[unit_id].append(entity_name)
@@ -1886,10 +1921,10 @@ class MemoryEngine:
# Build nodes
nodes = []
for row in units:
- unit_id = row[
- text = row[
- event_date = row[
- context = row[
+ unit_id = row["id"]
+ text = row["text"]
+ event_date = row["event_date"]
+ context = row["context"]

entities = entity_map.get(unit_id, [])
entity_count = len(entities)
@@ -1902,88 +1937,91 @@ class MemoryEngine:
else:
color = "#42a5f5"

- nodes.append(
-
- "
-
-
-
-
-
-
+ nodes.append(
+ {
+ "data": {
+ "id": str(unit_id),
+ "label": f"{text[:30]}..." if len(text) > 30 else text,
+ "text": text,
+ "date": event_date.isoformat() if event_date else "",
+ "context": context if context else "",
+ "entities": ", ".join(entities) if entities else "None",
+ "color": color,
+ }
}
-
+ )

# Build edges
edges = []
for row in links:
- from_id = str(row[
- to_id = str(row[
- link_type = row[
- weight = row[
- entity_name = row[
+ from_id = str(row["from_unit_id"])
+ to_id = str(row["to_unit_id"])
+ link_type = row["link_type"]
+ weight = row["weight"]
+ entity_name = row["entity_name"]

# Color by link type
- if link_type ==
+ if link_type == "temporal":
color = "#00bcd4"
line_style = "dashed"
- elif link_type ==
+ elif link_type == "semantic":
color = "#ff69b4"
line_style = "solid"
- elif link_type ==
+ elif link_type == "entity":
color = "#ffd700"
line_style = "solid"
else:
color = "#999999"
line_style = "solid"

- edges.append(
-
- "
-
-
-
-
-
-
-
+ edges.append(
+ {
+ "data": {
+ "id": f"{from_id}-{to_id}-{link_type}",
+ "source": from_id,
+ "target": to_id,
+ "linkType": link_type,
+ "weight": weight,
+ "entityName": entity_name if entity_name else "",
+ "color": color,
+ "lineStyle": line_style,
+ }
}
-
+ )

# Build table rows
table_rows = []
for row in units:
- unit_id = row[
+ unit_id = row["id"]
entities = entity_map.get(unit_id, [])

- table_rows.append(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- }
+ table_rows.append(
+ {
+ "id": str(unit_id),
+ "text": row["text"],
+ "context": row["context"] if row["context"] else "N/A",
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
+ "date": row["event_date"].strftime("%Y-%m-%d %H:%M")
+ if row["event_date"]
+ else "N/A",  # Deprecated, kept for backwards compatibility
+ "entities": ", ".join(entities) if entities else "None",
+ "document_id": row["document_id"],
+ "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
+ "fact_type": row["fact_type"],
+ }
+ )
+
+ return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": len(units)}

async def list_memory_units(
self,
- bank_id:
- fact_type:
- search_query:
+ bank_id: str | None = None,
+ fact_type: str | None = None,
+ search_query: str | None = None,
limit: int = 100,
- offset: int = 0
+ offset: int = 0,
):
"""
List memory units for table view with optional full-text search.
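The node and edge dictionaries assembled above follow the `{"data": {...}}` element shape used by Cytoscape-style graph viewers. A small illustrative sketch of that shape, with made-up example values and helper names that are not part of the package:

```python
def make_node(unit_id: str, text: str, color: str = "#42a5f5") -> dict:
    # One graph node element: all display fields live under "data"
    return {
        "data": {
            "id": unit_id,
            "label": f"{text[:30]}..." if len(text) > 30 else text,
            "text": text,
            "color": color,
        }
    }


def make_edge(from_id: str, to_id: str, link_type: str, weight: float) -> dict:
    # One graph edge element; the id encodes endpoints plus link type
    return {
        "data": {
            "id": f"{from_id}-{to_id}-{link_type}",
            "source": from_id,
            "target": to_id,
            "linkType": link_type,
            "weight": weight,
        }
    }
```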
@@ -2030,7 +2068,7 @@ class MemoryEngine:
{where_clause}
"""
count_result = await conn.fetchrow(count_query, *query_params)
- total = count_result[
+ total = count_result["total"]

# Get units with limit and offset
param_count += 1
@@ -2041,32 +2079,38 @@ class MemoryEngine:
offset_param = f"${param_count}"
query_params.append(offset)

- units = await conn.fetch(
+ units = await conn.fetch(
+ f"""
SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id
FROM memory_units
{where_clause}
ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
LIMIT {limit_param} OFFSET {offset_param}
- """,
+ """,
+ *query_params,
+ )

# Get entity information for these units
if units:
- unit_ids = [row[
- unit_entities = await conn.fetch(
+ unit_ids = [row["id"] for row in units]
+ unit_entities = await conn.fetch(
+ """
SELECT ue.unit_id, e.canonical_name
FROM unit_entities ue
JOIN entities e ON ue.entity_id = e.id
WHERE ue.unit_id = ANY($1::uuid[])
ORDER BY ue.unit_id
- """,
+ """,
+ unit_ids,
+ )
else:
unit_entities = []

# Build entity mapping
entity_map = {}
for row in unit_entities:
- unit_id = row[
- entity_name = row[
+ unit_id = row["unit_id"]
+ entity_name = row["canonical_name"]
if unit_id not in entity_map:
entity_map[unit_id] = []
entity_map[unit_id].append(entity_name)
@@ -2074,36 +2118,27 @@ class MemoryEngine:
# Build result items
items = []
for row in units:
- unit_id = row[
+ unit_id = row["id"]
entities = entity_map.get(unit_id, [])

- items.append(
-
-
-
-
-
-
-
-
-
-
-
+ items.append(
+ {
+ "id": str(unit_id),
+ "text": row["text"],
+ "context": row["context"] if row["context"] else "",
+ "date": row["event_date"].isoformat() if row["event_date"] else "",
+ "fact_type": row["fact_type"],
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
+ "entities": ", ".join(entities) if entities else "",
+ "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
+ }
+ )

- return {
- "items": items,
- "total": total,
- "limit": limit,
- "offset": offset
- }
+ return {"items": items, "total": total, "limit": limit, "offset": offset}

- async def list_documents(
- self,
- bank_id: str,
- search_query: Optional[str] = None,
- limit: int = 100,
- offset: int = 0
- ):
+ async def list_documents(self, bank_id: str, search_query: str | None = None, limit: int = 100, offset: int = 0):
"""
List documents with optional search and pagination.

@@ -2142,7 +2177,7 @@ class MemoryEngine:
{where_clause}
"""
count_result = await conn.fetchrow(count_query, *query_params)
- total = count_result[
+ total = count_result["total"]

# Get documents with limit and offset (without original_text for performance)
param_count += 1
@@ -2153,7 +2188,8 @@ class MemoryEngine:
offset_param = f"${param_count}"
query_params.append(offset)

- documents = await conn.fetch(
+ documents = await conn.fetch(
+ f"""
SELECT
id,
bank_id,
@@ -2166,11 +2202,13 @@ class MemoryEngine:
{where_clause}
ORDER BY created_at DESC
LIMIT {limit_param} OFFSET {offset_param}
- """,
+ """,
+ *query_params,
+ )

# Get memory unit count for each document
if documents:
- doc_ids = [(row[
+ doc_ids = [(row["id"], row["bank_id"]) for row in documents]

# Create placeholders for the query
placeholders = []
@@ -2183,48 +2221,44 @@ class MemoryEngine:

where_clause_count = " OR ".join(placeholders)

- unit_counts = await conn.fetch(
+ unit_counts = await conn.fetch(
+ f"""
SELECT document_id, bank_id, COUNT(*) as unit_count
FROM memory_units
WHERE {where_clause_count}
GROUP BY document_id, bank_id
- """,
+ """,
+ *params_for_count,
+ )
else:
unit_counts = []

# Build count mapping
- count_map = {(row[
+ count_map = {(row["document_id"], row["bank_id"]): row["unit_count"] for row in unit_counts}

# Build result items
items = []
for row in documents:
- doc_id = row[
- bank_id_val = row[
+ doc_id = row["id"]
+ bank_id_val = row["bank_id"]
unit_count = count_map.get((doc_id, bank_id_val), 0)

- items.append(
-
-
-
-
-
-
-
-
-
+ items.append(
+ {
+ "id": doc_id,
+ "bank_id": bank_id_val,
+ "content_hash": row["content_hash"],
+ "created_at": row["created_at"].isoformat() if row["created_at"] else "",
+ "updated_at": row["updated_at"].isoformat() if row["updated_at"] else "",
+ "text_length": row["text_length"] or 0,
+ "memory_unit_count": unit_count,
+ "retain_params": row["retain_params"] if row["retain_params"] else None,
+ }
+ )

- return {
- "items": items,
- "total": total,
- "limit": limit,
- "offset": offset
- }
+ return {"items": items, "total": total, "limit": limit, "offset": offset}

- async def get_document(
- self,
- document_id: str,
- bank_id: str
- ):
+ async def get_document(self, document_id: str, bank_id: str):
"""
Get a specific document including its original_text.

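Both listing methods return the same `{"items", "total", "limit", "offset"}` envelope, so offset-based paging is straightforward. A hypothetical paging loop over that shape (`engine` again stands in for a constructed MemoryEngine):

```python
async def iter_all_documents(engine, bank_id: str, page_size: int = 100):
    """Yield every document in a bank by walking limit/offset pages."""
    offset = 0
    while True:
        page = await engine.list_documents(bank_id, limit=page_size, offset=offset)
        for item in page["items"]:
            yield item
        offset += page_size
        if offset >= page["total"]:
            break  # walked past the last page
```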
@@ -2237,7 +2271,8 @@ class MemoryEngine:
"""
pool = await self._get_pool()
async with acquire_with_retry(pool) as conn:
- doc = await conn.fetchrow(
+ doc = await conn.fetchrow(
+ """
SELECT
id,
bank_id,
@@ -2248,33 +2283,37 @@ class MemoryEngine:
retain_params
FROM documents
WHERE id = $1 AND bank_id = $2
- """,
+ """,
+ document_id,
+ bank_id,
+ )

if not doc:
return None

# Get memory unit count
- unit_count_row = await conn.fetchrow(
+ unit_count_row = await conn.fetchrow(
+ """
SELECT COUNT(*) as unit_count
FROM memory_units
WHERE document_id = $1 AND bank_id = $2
- """,
+ """,
+ document_id,
+ bank_id,
+ )

return {
- "id": doc[
- "bank_id": doc[
- "original_text": doc[
- "content_hash": doc[
- "created_at": doc[
- "updated_at": doc[
- "memory_unit_count": unit_count_row[
- "retain_params": doc[
+ "id": doc["id"],
+ "bank_id": doc["bank_id"],
+ "original_text": doc["original_text"],
+ "content_hash": doc["content_hash"],
+ "created_at": doc["created_at"].isoformat() if doc["created_at"] else "",
+ "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else "",
+ "memory_unit_count": unit_count_row["unit_count"] if unit_count_row else 0,
+ "retain_params": doc["retain_params"] if doc["retain_params"] else None,
}

- async def get_chunk(
- self,
- chunk_id: str
- ):
+ async def get_chunk(self, chunk_id: str):
"""
Get a specific chunk by its ID.

@@ -2286,7 +2325,8 @@ class MemoryEngine:
"""
pool = await self._get_pool()
async with acquire_with_retry(pool) as conn:
- chunk = await conn.fetchrow(
+ chunk = await conn.fetchrow(
+ """
SELECT
chunk_id,
document_id,
@@ -2296,18 +2336,20 @@ class MemoryEngine:
created_at
FROM chunks
WHERE chunk_id = $1
- """,
+ """,
+ chunk_id,
+ )

if not chunk:
return None

return {
- "chunk_id": chunk[
- "document_id": chunk[
- "bank_id": chunk[
- "chunk_index": chunk[
- "chunk_text": chunk[
- "created_at": chunk[
+ "chunk_id": chunk["chunk_id"],
+ "document_id": chunk["document_id"],
+ "bank_id": chunk["bank_id"],
+ "chunk_index": chunk["chunk_index"],
+ "chunk_text": chunk["chunk_text"],
+ "created_at": chunk["created_at"].isoformat() if chunk["created_at"] else "",
}

async def _evaluate_opinion_update_async(
@@ -2316,7 +2358,7 @@ class MemoryEngine:
opinion_confidence: float,
new_event_text: str,
entity_name: str,
- ) ->
+ ) -> dict[str, Any] | None:
"""
Evaluate if an opinion should be updated based on a new event.

@@ -2330,16 +2372,18 @@ class MemoryEngine:
Dict with 'action' ('keep'|'update'), 'new_confidence', 'new_text' (if action=='update')
or None if no changes needed
"""
- from pydantic import BaseModel, Field

class OpinionEvaluation(BaseModel):
"""Evaluation of whether an opinion should be updated."""
+
action: str = Field(description="Action to take: 'keep' (no change) or 'update' (modify opinion)")
reasoning: str = Field(description="Brief explanation of why this action was chosen")
- new_confidence: float = Field(
-
+ new_confidence: float = Field(
+ description="New confidence score (0.0-1.0). Can be higher, lower, or same as before."
+ )
+ new_opinion_text: str | None = Field(
default=None,
- description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None."
+ description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None.",
)

evaluation_prompt = f"""You are evaluating whether an existing opinion should be updated based on new information.
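The opinion-update check relies on a Pydantic model as the structured LLM response format. A minimal sketch of that pattern, with field names mirroring the diff; the class name and consuming helper are illustrative, not part of the package:

```python
from pydantic import BaseModel, Field


class OpinionEvaluationSketch(BaseModel):
    """Structured output: should an existing opinion be kept or updated?"""

    action: str = Field(description="'keep' (no change) or 'update' (modify opinion)")
    reasoning: str = Field(description="Brief explanation of why this action was chosen")
    new_confidence: float = Field(description="New confidence score (0.0-1.0)")
    new_opinion_text: str | None = Field(default=None, description="Revised text when action == 'update'")


def to_update_dict(result: OpinionEvaluationSketch, old_confidence: float) -> dict | None:
    # Mirror the early exit in the diff: nothing to apply if the opinion is kept
    # and the confidence barely moved.
    if result.action == "keep" and abs(result.new_confidence - old_confidence) < 0.01:
        return None
    return {
        "action": result.action,
        "reasoning": result.reasoning,
        "new_confidence": result.new_confidence,
        "new_text": result.new_opinion_text if result.action == "update" else None,
    }
```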
@@ -2369,70 +2413,63 @@ Guidelines:
result = await self._llm_config.call(
messages=[
{"role": "system", "content": "You evaluate and update opinions based on new information."},
- {"role": "user", "content": evaluation_prompt}
+ {"role": "user", "content": evaluation_prompt},
],
response_format=OpinionEvaluation,
scope="memory_evaluate_opinion",
- temperature=0.3 # Lower temperature for more consistent evaluation
+ temperature=0.3, # Lower temperature for more consistent evaluation
)

# Only return updates if something actually changed
- if result.action ==
+ if result.action == "keep" and abs(result.new_confidence - opinion_confidence) < 0.01:
return None

return {
-
-
-
-
+ "action": result.action,
+ "reasoning": result.reasoning,
+ "new_confidence": result.new_confidence,
+ "new_text": result.new_opinion_text if result.action == "update" else None,
}

except Exception as e:
logger.warning(f"Failed to evaluate opinion update: {str(e)}")
return None

- async def _handle_form_opinion(self, task_dict:
+ async def _handle_form_opinion(self, task_dict: dict[str, Any]):
"""
Handler for form opinion tasks.

Args:
task_dict: Dict with keys: 'bank_id', 'answer_text', 'query'
"""
- bank_id = task_dict[
- answer_text = task_dict[
- query = task_dict[
+ bank_id = task_dict["bank_id"]
+ answer_text = task_dict["answer_text"]
+ query = task_dict["query"]

- await self._extract_and_store_opinions_async(
- bank_id=bank_id,
- answer_text=answer_text,
- query=query
- )
+ await self._extract_and_store_opinions_async(bank_id=bank_id, answer_text=answer_text, query=query)

- async def _handle_reinforce_opinion(self, task_dict:
+ async def _handle_reinforce_opinion(self, task_dict: dict[str, Any]):
"""
Handler for reinforce opinion tasks.

Args:
task_dict: Dict with keys: 'bank_id', 'created_unit_ids', 'unit_texts', 'unit_entities'
"""
- bank_id = task_dict[
- created_unit_ids = task_dict[
- unit_texts = task_dict[
- unit_entities = task_dict[
+ bank_id = task_dict["bank_id"]
+ created_unit_ids = task_dict["created_unit_ids"]
+ unit_texts = task_dict["unit_texts"]
+ unit_entities = task_dict["unit_entities"]

await self._reinforce_opinions_async(
- bank_id=bank_id,
- created_unit_ids=created_unit_ids,
- unit_texts=unit_texts,
- unit_entities=unit_entities
+ bank_id=bank_id, created_unit_ids=created_unit_ids, unit_texts=unit_texts, unit_entities=unit_entities
)

async def _reinforce_opinions_async(
self,
bank_id: str,
- created_unit_ids:
- unit_texts:
- unit_entities:
+ created_unit_ids: list[str],
+ unit_texts: list[str],
+ unit_entities: list[list[dict[str, str]]],
):
"""
Background task to reinforce opinions based on newly ingested events.
@@ -2451,15 +2488,14 @@ Guidelines:
for entities_list in unit_entities:
for entity in entities_list:
# Handle both Entity objects and dicts
- if hasattr(entity,
+ if hasattr(entity, "text"):
entity_names.add(entity.text)
elif isinstance(entity, dict):
- entity_names.add(entity[
+ entity_names.add(entity["text"])

if not entity_names:
return

-
pool = await self._get_pool()
async with acquire_with_retry(pool) as conn:
# Find all opinions related to these entities
@@ -2474,13 +2510,12 @@ Guidelines:
AND e.canonical_name = ANY($2::text[])
""",
bank_id,
- list(entity_names)
+ list(entity_names),
)

if not opinions:
return

-
# Use cached LLM config
if self._llm_config is None:
logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
@@ -2489,15 +2524,15 @@ Guidelines:
# Evaluate each opinion against the new events
updates_to_apply = []
for opinion in opinions:
- opinion_id = str(opinion[
- opinion_text = opinion[
- opinion_confidence = opinion[
- entity_name = opinion[
+ opinion_id = str(opinion["id"])
+ opinion_text = opinion["text"]
+ opinion_confidence = opinion["confidence_score"]
+ entity_name = opinion["canonical_name"]

# Find all new events mentioning this entity
relevant_events = []
for unit_text, entities_list in zip(unit_texts, unit_entities):
- if any(e[
+ if any(e["text"] == entity_name for e in entities_list):
relevant_events.append(unit_text)

if not relevant_events:
@@ -2508,26 +2543,20 @@ Guidelines:

# Evaluate if opinion should be updated
evaluation = await self._evaluate_opinion_update_async(
- opinion_text,
- opinion_confidence,
- combined_events,
- entity_name
+ opinion_text, opinion_confidence, combined_events, entity_name
)

if evaluation:
- updates_to_apply.append({
- 'opinion_id': opinion_id,
- 'evaluation': evaluation
- })
+ updates_to_apply.append({"opinion_id": opinion_id, "evaluation": evaluation})

# Apply all updates in a single transaction
if updates_to_apply:
async with conn.transaction():
for update in updates_to_apply:
- opinion_id = update[
- evaluation = update[
+ opinion_id = update["opinion_id"]
+ evaluation = update["evaluation"]

- if evaluation[
+ if evaluation["action"] == "update" and evaluation["new_text"]:
# Update both text and confidence
await conn.execute(
"""
@@ -2535,9 +2564,9 @@ Guidelines:
SET text = $1, confidence_score = $2, updated_at = NOW()
WHERE id = $3
""",
- evaluation[
- evaluation[
- uuid.UUID(opinion_id)
+ evaluation["new_text"],
+ evaluation["new_confidence"],
+ uuid.UUID(opinion_id),
)
else:
# Only update confidence
@@ -2547,8 +2576,8 @@ Guidelines:
SET confidence_score = $1, updated_at = NOW()
WHERE id = $2
""",
- evaluation[
- uuid.UUID(opinion_id)
+ evaluation["new_confidence"],
+ uuid.UUID(opinion_id),
)

else:
@@ -2557,6 +2586,7 @@ Guidelines:
except Exception as e:
logger.error(f"[REINFORCE] Error during opinion reinforcement: {str(e)}")
import traceback
+
traceback.print_exc()

# ==================== bank profile Methods ====================
@@ -2575,11 +2605,7 @@ Guidelines:
pool = await self._get_pool()
return await bank_utils.get_bank_profile(pool, bank_id)

- async def update_bank_disposition(
- self,
- bank_id: str,
- disposition: Dict[str, int]
- ) -> None:
+ async def update_bank_disposition(self, bank_id: str, disposition: dict[str, int]) -> None:
"""
Update bank disposition traits.

@@ -2590,12 +2616,7 @@ Guidelines:
pool = await self._get_pool()
await bank_utils.update_bank_disposition(pool, bank_id, disposition)

- async def merge_bank_background(
- self,
- bank_id: str,
- new_info: str,
- update_disposition: bool = True
- ) -> dict:
+ async def merge_bank_background(self, bank_id: str, new_info: str, update_disposition: bool = True) -> dict:
"""
Merge new background information with existing background using LLM.
Normalizes to first person ("I") and resolves conflicts.
@@ -2610,9 +2631,7 @@ Guidelines:
Dict with 'background' (str) and optionally 'disposition' (dict) keys
"""
pool = await self._get_pool()
- return await bank_utils.merge_bank_background(
- pool, self._llm_config, bank_id, new_info, update_disposition
- )
+ return await bank_utils.merge_bank_background(pool, self._llm_config, bank_id, new_info, update_disposition)

async def list_banks(self) -> list:
"""
@@ -2673,19 +2692,21 @@ Guidelines:
budget=budget,
max_tokens=4096,
enable_trace=False,
- fact_type=[
- include_entities=True
+ fact_type=["experience", "world", "opinion"],
+ include_entities=True,
)
recall_time = time.time() - recall_start

all_results = search_result.results

# Split results by fact type for structured response
- agent_results = [r for r in all_results if r.fact_type ==
- world_results = [r for r in all_results if r.fact_type ==
- opinion_results = [r for r in all_results if r.fact_type ==
+ agent_results = [r for r in all_results if r.fact_type == "experience"]
+ world_results = [r for r in all_results if r.fact_type == "world"]
+ opinion_results = [r for r in all_results if r.fact_type == "opinion"]

- log_buffer.append(
+ log_buffer.append(
+ f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s"
+ )

# Format facts for LLM
agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
@@ -2716,47 +2737,34 @@ Guidelines:

llm_start = time.time()
answer_text = await self._llm_config.call(
- messages=[
- {"role": "system", "content": system_message},
- {"role": "user", "content": prompt}
- ],
+ messages=[{"role": "system", "content": system_message}, {"role": "user", "content": prompt}],
scope="memory_think",
temperature=0.9,
- max_completion_tokens=1000
+ max_completion_tokens=1000,
)
llm_time = time.time() - llm_start

answer_text = answer_text.strip()

# Submit form_opinion task for background processing
- await self._task_backend.submit_task(
-
-
- 'answer_text': answer_text,
- 'query': query
- })
+ await self._task_backend.submit_task(
+ {"type": "form_opinion", "bank_id": bank_id, "answer_text": answer_text, "query": query}
+ )

total_time = time.time() - reflect_start
- log_buffer.append(
+ log_buffer.append(
+ f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s"
+ )
logger.info("\n" + "\n".join(log_buffer))

# Return response with facts split by type
return ReflectResult(
text=answer_text,
- based_on={
-
- "experience": agent_results,
- "opinion": opinion_results
- },
- new_opinions=[] # Opinions are being extracted asynchronously
+ based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
+ new_opinions=[], # Opinions are being extracted asynchronously
)

- async def _extract_and_store_opinions_async(
- self,
- bank_id: str,
- answer_text: str,
- query: str
- ):
+ async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str, query: str):
"""
Background task to extract and store opinions from think response.

@@ -2769,33 +2777,27 @@ Guidelines:
"""
try:
# Extract opinions from the answer
- new_opinions = await think_utils.extract_opinions_from_text(
- self._llm_config, text=answer_text, query=query
- )
+ new_opinions = await think_utils.extract_opinions_from_text(self._llm_config, text=answer_text, query=query)

# Store new opinions
if new_opinions:
- from datetime import datetime
-
+ from datetime import datetime
+
+ current_time = datetime.now(UTC)
for opinion in new_opinions:
await self.retain_async(
bank_id=bank_id,
content=opinion.opinion,
context=f"formed during thinking about: {query}",
event_date=current_time,
- fact_type_override=
- confidence_score=opinion.confidence
+ fact_type_override="opinion",
+ confidence_score=opinion.confidence,
)

except Exception as e:
logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")

- async def get_entity_observations(
- self,
- bank_id: str,
- entity_id: str,
- limit: int = 10
- ) -> List[EntityObservation]:
+ async def get_entity_observations(self, bank_id: str, entity_id: str, limit: int = 10) -> list[EntityObservation]:
"""
Get observations linked to an entity.

@@ -2820,23 +2822,18 @@ Guidelines:
ORDER BY mu.mentioned_at DESC
LIMIT $3
""",
- bank_id,
+ bank_id,
+ uuid.UUID(entity_id),
+ limit,
)

observations = []
for row in rows:
- mentioned_at = row[
- observations.append(EntityObservation(
- text=row['text'],
- mentioned_at=mentioned_at
- ))
+ mentioned_at = row["mentioned_at"].isoformat() if row["mentioned_at"] else None
+ observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
return observations

- async def list_entities(
- self,
- bank_id: str,
- limit: int = 100
- ) -> List[Dict[str, Any]]:
+ async def list_entities(self, bank_id: str, limit: int = 100) -> list[dict[str, Any]]:
"""
List all entities for a bank.

@@ -2857,39 +2854,37 @@ Guidelines:
ORDER BY mention_count DESC, last_seen DESC
LIMIT $2
""",
- bank_id,
+ bank_id,
+ limit,
)

entities = []
for row in rows:
# Handle metadata - may be dict, JSON string, or None
- metadata = row[
+ metadata = row["metadata"]
if metadata is None:
metadata = {}
elif isinstance(metadata, str):
import json
+
try:
metadata = json.loads(metadata)
except json.JSONDecodeError:
metadata = {}

- entities.append(
-
-
-
-
-
-
-
+ entities.append(
+ {
+ "id": str(row["id"]),
+ "canonical_name": row["canonical_name"],
+ "mention_count": row["mention_count"],
+ "first_seen": row["first_seen"].isoformat() if row["first_seen"] else None,
+ "last_seen": row["last_seen"].isoformat() if row["last_seen"] else None,
+ "metadata": metadata,
+ }
+ )
return entities

- async def get_entity_state(
- self,
- bank_id: str,
- entity_id: str,
- entity_name: str,
- limit: int = 10
- ) -> EntityState:
+ async def get_entity_state(self, bank_id: str, entity_id: str, entity_name: str, limit: int = 10) -> EntityState:
"""
Get the current state (mental model) of an entity.

@@ -2903,20 +2898,11 @@ Guidelines:
EntityState with observations
"""
observations = await self.get_entity_observations(bank_id, entity_id, limit)
- return EntityState(
- entity_id=entity_id,
- canonical_name=entity_name,
- observations=observations
- )
+ return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=observations)

async def regenerate_entity_observations(
- self,
-
- entity_id: str,
- entity_name: str,
- version: str | None = None,
- conn=None
- ) -> List[str]:
+ self, bank_id: str, entity_id: str, entity_name: str, version: str | None = None, conn=None
+ ) -> list[str]:
"""
Regenerate observations for an entity by:
1. Checking version for deduplication (if provided)
@@ -2961,7 +2947,8 @@ Guidelines:
FROM entities
WHERE id = $1 AND bank_id = $2
""",
- entity_uuid,
+ entity_uuid,
+ bank_id,
)

if current_last_seen and current_last_seen.isoformat() != version:
@@ -2979,7 +2966,8 @@ Guidelines:
ORDER BY mu.occurred_start DESC
LIMIT 50
""",
- bank_id,
+ bank_id,
+ entity_uuid,
)

if not rows:
@@ -2988,21 +2976,19 @@ Guidelines:
# Convert to MemoryFact objects for the observation extraction
facts = []
for row in rows:
- occurred_start = row[
- facts.append(
-
-
-
-
-
-
+ occurred_start = row["occurred_start"].isoformat() if row["occurred_start"] else None
+ facts.append(
+ MemoryFact(
+ id=str(row["id"]),
+ text=row["text"],
+ fact_type=row["fact_type"],
+ context=row["context"],
+ occurred_start=occurred_start,
+ )
+ )

# Step 3: Extract observations using LLM (no personality)
- observations = await observation_utils.extract_observations_from_facts(
- self._llm_config,
- entity_name,
- facts
- )
+ observations = await observation_utils.extract_observations_from_facts(self._llm_config, entity_name, facts)

if not observations:
return []
@@ -3024,13 +3010,12 @@ Guidelines:
AND ue.entity_id = $2
)
""",
- bank_id,
+ bank_id,
+ entity_uuid,
)

# Generate embeddings for new observations
- embeddings = await embedding_utils.generate_embeddings_batch(
- self.embeddings, observations
- )
+ embeddings = await embedding_utils.generate_embeddings_batch(self.embeddings, observations)

# Insert new observations
current_time = utcnow()
@@ -3054,9 +3039,9 @@ Guidelines:
current_time,
current_time,
current_time,
- current_time
+ current_time,
)
- obs_id = str(result[
+ obs_id = str(result["id"])
created_ids.append(obs_id)

# Link observation to entity
@@ -3065,7 +3050,8 @@ Guidelines:
INSERT INTO unit_entities (unit_id, entity_id)
VALUES ($1, $2)
""",
- uuid.UUID(obs_id),
+ uuid.UUID(obs_id),
+ entity_uuid,
)

return created_ids
@@ -3080,11 +3066,7 @@ Guidelines:
return await do_db_operations(acquired_conn)

async def _regenerate_observations_sync(
- self,
- bank_id: str,
- entity_ids: List[str],
- min_facts: int = 5,
- conn=None
+ self, bank_id: str, entity_ids: list[str], min_facts: int = 5, conn=None
) -> None:
"""
Regenerate observations for entities synchronously (called during retain).
@@ -3111,9 +3093,10 @@ Guidelines:
SELECT id, canonical_name FROM entities
WHERE id = ANY($1) AND bank_id = $2
""",
- entity_uuids,
+ entity_uuids,
+ bank_id,
)
- entity_names = {row[
+ entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

fact_counts = await conn.fetch(
"""
@@ -3123,9 +3106,10 @@ Guidelines:
WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
GROUP BY ue.entity_id
""",
- entity_uuids,
+ entity_uuids,
+ bank_id,
)
- entity_fact_counts = {row[
+ entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}
else:
# Acquire a new connection (standalone call)
pool = await self._get_pool()
@@ -3135,9 +3119,10 @@ Guidelines:
SELECT id, canonical_name FROM entities
WHERE id = ANY($1) AND bank_id = $2
""",
- entity_uuids,
+ entity_uuids,
+ bank_id,
)
- entity_names = {row[
+ entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

fact_counts = await acquired_conn.fetch(
"""
@@ -3147,9 +3132,10 @@ Guidelines:
WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
GROUP BY ue.entity_id
""",
- entity_uuids,
+ entity_uuids,
+ bank_id,
)
- entity_fact_counts = {row[
+ entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}

# Filter entities that meet the threshold
entities_to_process = []
@@ -3171,11 +3157,9 @@ Guidelines:
except Exception as e:
logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")

- await asyncio.gather(*[
- process_entity(eid, name) for eid, name in entities_to_process
- ])
+ await asyncio.gather(*[process_entity(eid, name) for eid, name in entities_to_process])

- async def _handle_regenerate_observations(self, task_dict:
+ async def _handle_regenerate_observations(self, task_dict: dict[str, Any]):
"""
Handler for regenerate_observations tasks.

@@ -3185,12 +3169,12 @@ Guidelines:
- 'entity_id', 'entity_name': Process single entity (legacy)
"""
try:
- bank_id = task_dict.get(
+ bank_id = task_dict.get("bank_id")

# New format: multiple entity_ids
- if
- entity_ids = task_dict.get(
- min_facts = task_dict.get(
+ if "entity_ids" in task_dict:
+ entity_ids = task_dict.get("entity_ids", [])
+ min_facts = task_dict.get("min_facts", 5)

if not bank_id or not entity_ids:
logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
@@ -3203,31 +3187,37 @@ Guidelines:
try:
# Fetch entity name and check fact count
import uuid as uuid_module
+
entity_uuid = uuid_module.UUID(entity_id) if isinstance(entity_id, str) else entity_id

# First check if entity exists
entity_exists = await conn.fetchrow(
"SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
- entity_uuid,
+ entity_uuid,
+ bank_id,
)

if not entity_exists:
logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
continue

- entity_name = entity_exists[
+ entity_name = entity_exists["canonical_name"]

# Count facts linked to this entity
- fact_count =
-
-
-
+ fact_count = (
+ await conn.fetchval(
+ "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1", entity_uuid
+ )
+ or 0
+ )

# Only regenerate if entity has enough facts
if fact_count >= min_facts:
await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
else:
- logger.debug(
+ logger.debug(
+ f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
+ )

except Exception as e:
logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
@@ -3235,9 +3225,9 @@ Guidelines:

# Legacy format: single entity
else:
- entity_id = task_dict.get(
- entity_name = task_dict.get(
- version = task_dict.get(
+ entity_id = task_dict.get("entity_id")
+ entity_name = task_dict.get("entity_name")
+ version = task_dict.get("version")

if not all([bank_id, entity_id, entity_name]):
logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
@@ -3248,5 +3238,5 @@ Guidelines:
except Exception as e:
logger.error(f"[OBSERVATIONS] Error regenerating observations: {e}")
import traceback
- traceback.print_exc()

+ traceback.print_exc()