hindsight-api 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. hindsight_api/__init__.py +38 -0
  2. hindsight_api/api/__init__.py +105 -0
  3. hindsight_api/api/http.py +1872 -0
  4. hindsight_api/api/mcp.py +157 -0
  5. hindsight_api/engine/__init__.py +47 -0
  6. hindsight_api/engine/cross_encoder.py +97 -0
  7. hindsight_api/engine/db_utils.py +93 -0
  8. hindsight_api/engine/embeddings.py +113 -0
  9. hindsight_api/engine/entity_resolver.py +575 -0
  10. hindsight_api/engine/llm_wrapper.py +269 -0
  11. hindsight_api/engine/memory_engine.py +3095 -0
  12. hindsight_api/engine/query_analyzer.py +519 -0
  13. hindsight_api/engine/response_models.py +222 -0
  14. hindsight_api/engine/retain/__init__.py +50 -0
  15. hindsight_api/engine/retain/bank_utils.py +423 -0
  16. hindsight_api/engine/retain/chunk_storage.py +82 -0
  17. hindsight_api/engine/retain/deduplication.py +104 -0
  18. hindsight_api/engine/retain/embedding_processing.py +62 -0
  19. hindsight_api/engine/retain/embedding_utils.py +54 -0
  20. hindsight_api/engine/retain/entity_processing.py +90 -0
  21. hindsight_api/engine/retain/fact_extraction.py +1027 -0
  22. hindsight_api/engine/retain/fact_storage.py +176 -0
  23. hindsight_api/engine/retain/link_creation.py +121 -0
  24. hindsight_api/engine/retain/link_utils.py +651 -0
  25. hindsight_api/engine/retain/orchestrator.py +405 -0
  26. hindsight_api/engine/retain/types.py +206 -0
  27. hindsight_api/engine/search/__init__.py +15 -0
  28. hindsight_api/engine/search/fusion.py +122 -0
  29. hindsight_api/engine/search/observation_utils.py +132 -0
  30. hindsight_api/engine/search/reranking.py +103 -0
  31. hindsight_api/engine/search/retrieval.py +503 -0
  32. hindsight_api/engine/search/scoring.py +161 -0
  33. hindsight_api/engine/search/temporal_extraction.py +64 -0
  34. hindsight_api/engine/search/think_utils.py +255 -0
  35. hindsight_api/engine/search/trace.py +215 -0
  36. hindsight_api/engine/search/tracer.py +447 -0
  37. hindsight_api/engine/search/types.py +160 -0
  38. hindsight_api/engine/task_backend.py +223 -0
  39. hindsight_api/engine/utils.py +203 -0
  40. hindsight_api/metrics.py +227 -0
  41. hindsight_api/migrations.py +163 -0
  42. hindsight_api/models.py +309 -0
  43. hindsight_api/pg0.py +425 -0
  44. hindsight_api/web/__init__.py +12 -0
  45. hindsight_api/web/server.py +143 -0
  46. hindsight_api-0.0.13.dist-info/METADATA +41 -0
  47. hindsight_api-0.0.13.dist-info/RECORD +48 -0
  48. hindsight_api-0.0.13.dist-info/WHEEL +4 -0
@@ -0,0 +1,3095 @@
1
+ """
2
+ Memory Engine for Memory Banks.
3
+
4
+ This implements a sophisticated memory architecture that combines:
5
+ 1. Temporal links: Memories connected by time proximity
6
+ 2. Semantic links: Memories connected by meaning/similarity
7
+ 3. Entity links: Memories connected by shared entities (PERSON, ORG, etc.)
8
+ 4. Spreading activation: Search through the graph with activation decay
9
+ 5. Dynamic weighting: Recency and frequency-based importance
10
+ """
11
+ import json
12
+ import os
13
+ from datetime import datetime, timedelta, timezone
14
+ from typing import Any, Dict, List, Optional, Tuple, Union, TypedDict
15
+ import asyncpg
16
+ import asyncio
17
+ from .embeddings import Embeddings, SentenceTransformersEmbeddings
18
+ from .cross_encoder import CrossEncoderModel
19
+ import time
20
+ import numpy as np
21
+ import uuid
22
+ import logging
23
+ from pydantic import BaseModel, Field
24
+
25
+
26
+ class RetainContentDict(TypedDict, total=False):
27
+ """Type definition for content items in retain_batch_async.
28
+
29
+ Fields:
30
+ content: Text content to store (required)
31
+ context: Context about the content (optional)
32
+ event_date: When the content occurred (optional, defaults to now)
33
+ metadata: Custom key-value metadata (optional)
34
+ document_id: Document ID for this content item (optional)
35
+ """
36
+ content: str # Required
37
+ context: str
38
+ event_date: datetime
39
+ metadata: Dict[str, str]
40
+ document_id: str
41
+
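For illustration, a single content item conforming to RetainContentDict could look like this (values invented):

example_item: RetainContentDict = {
    "content": "Alice moved to Berlin.",                      # required
    "context": "chat with Alice",                             # optional
    "event_date": datetime(2024, 6, 1, tzinfo=timezone.utc),  # optional, defaults to now
    "metadata": {"source": "chat"},                           # optional
    "document_id": "doc-42",                                  # optional
}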
42
+ from .query_analyzer import QueryAnalyzer
43
+ from .search.scoring import (
44
+ calculate_recency_weight,
45
+ calculate_frequency_weight,
46
+ )
47
+ from .entity_resolver import EntityResolver
48
+ from .retain import embedding_utils, bank_utils
49
+ from .search import think_utils, observation_utils
50
+ from .llm_wrapper import LLMConfig
51
+ from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation
52
+ from .task_backend import TaskBackend, AsyncIOQueueBackend
53
+ from .search.reranking import CrossEncoderReranker
54
+ from ..pg0 import EmbeddedPostgres
55
+ from enum import Enum
56
+
57
+
58
+ class Budget(str, Enum):
59
+ """Budget levels for recall/reflect operations."""
60
+ LOW = "low"
61
+ MID = "mid"
62
+ HIGH = "high"
63
+
64
+
65
+ def utcnow():
66
+ """Get current UTC time with timezone info."""
67
+ return datetime.now(timezone.utc)
68
+
69
+
70
+ # Logger for memory system
71
+ logger = logging.getLogger(__name__)
72
+
73
+ from .db_utils import acquire_with_retry, retry_with_backoff
74
+
75
+ import tiktoken
76
+ from dateutil import parser as date_parser
77
+
78
+ # Cache tiktoken encoding for token budget filtering (module-level singleton)
79
+ _TIKTOKEN_ENCODING = None
80
+
81
+ def _get_tiktoken_encoding():
82
+ """Get cached tiktoken encoding (cl100k_base for GPT-4/3.5)."""
83
+ global _TIKTOKEN_ENCODING
84
+ if _TIKTOKEN_ENCODING is None:
85
+ _TIKTOKEN_ENCODING = tiktoken.get_encoding("cl100k_base")
86
+ return _TIKTOKEN_ENCODING
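A minimal sketch of how the cached encoding can be used for token counting (illustrative only):

encoding = _get_tiktoken_encoding()
num_tokens = len(encoding.encode("Alice works at Google"))  # same counting used by the token-budget filter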
87
+
88
+
89
+ class MemoryEngine:
90
+ """
91
+ Advanced memory system using temporal and semantic linking with PostgreSQL.
92
+
93
+ This class provides:
94
+ - Embedding generation for semantic search
95
+ - Entity, temporal, and semantic link creation
96
+ - Think operations for formulating answers with opinions
97
+ - Bank profile and personality management
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ db_url: str,
103
+ memory_llm_provider: str,
104
+ memory_llm_api_key: str,
105
+ memory_llm_model: str,
106
+ memory_llm_base_url: Optional[str] = None,
107
+ embeddings: Optional[Embeddings] = None,
108
+ cross_encoder: Optional[CrossEncoderModel] = None,
109
+ query_analyzer: Optional[QueryAnalyzer] = None,
110
+ pool_min_size: int = 5,
111
+ pool_max_size: int = 100,
112
+ task_backend: Optional[TaskBackend] = None,
113
+ ):
114
+ """
115
+ Initialize the temporal + semantic memory system.
116
+
117
+ Args:
118
+ db_url: PostgreSQL connection URL (postgresql://user:pass@host:port/dbname), or "pg0" to start an embedded PostgreSQL instance. Required.
119
+ memory_llm_provider: LLM provider for memory operations: "openai", "groq", or "ollama". Required.
120
+ memory_llm_api_key: API key for the LLM provider. Required.
121
+ memory_llm_model: Model name to use for all memory operations (put/think/opinions). Required.
122
+ memory_llm_base_url: Base URL for the LLM API. Optional. Defaults based on provider:
123
+ - groq: https://api.groq.com/openai/v1
124
+ - ollama: http://localhost:11434/v1
125
+ embeddings: Embeddings implementation to use. If not provided, uses SentenceTransformersEmbeddings
126
+ cross_encoder: Cross-encoder model for reranking. If not provided, uses default when cross-encoder reranker is selected
127
+ query_analyzer: Query analyzer implementation to use. If not provided, uses DateparserQueryAnalyzer
128
+ pool_min_size: Minimum number of connections in the pool (default: 5)
129
+ pool_max_size: Maximum number of connections in the pool (default: 100)
130
+ Increase for parallel think/search operations (e.g., 200-300 for 100+ parallel thinks)
131
+ task_backend: Custom task backend for async task execution. If not provided, uses AsyncIOQueueBackend
132
+ """
133
+ if not db_url:
134
+ raise ValueError("Database url is required")
135
+ # Track pg0 instance (if used)
136
+ self._pg0: Optional[EmbeddedPostgres] = None
137
+
138
+ # Initialize PostgreSQL connection URL
139
+ # The actual URL will be set during initialize() after starting the server
140
+ self._use_pg0 = db_url == "pg0"
141
+ self.db_url = db_url if not self._use_pg0 else None
142
+
143
+
144
+ # Set default base URL if not provided
145
+ if memory_llm_base_url is None:
146
+ if memory_llm_provider.lower() == "groq":
147
+ memory_llm_base_url = "https://api.groq.com/openai/v1"
148
+ elif memory_llm_provider.lower() == "ollama":
149
+ memory_llm_base_url = "http://localhost:11434/v1"
150
+ else:
151
+ memory_llm_base_url = ""
152
+
153
+ # Connection pool (will be created in initialize())
154
+ self._pool = None
155
+ self._initialized = False
156
+ self._pool_min_size = pool_min_size
157
+ self._pool_max_size = pool_max_size
158
+
159
+ # Initialize entity resolver (will be created in initialize())
160
+ self.entity_resolver = None
161
+
162
+ # Initialize embeddings
163
+ if embeddings is not None:
164
+ self.embeddings = embeddings
165
+ else:
166
+ self.embeddings = SentenceTransformersEmbeddings("BAAI/bge-small-en-v1.5")
167
+
168
+ # Initialize query analyzer
169
+ if query_analyzer is not None:
170
+ self.query_analyzer = query_analyzer
171
+ else:
172
+ from .query_analyzer import DateparserQueryAnalyzer
173
+ self.query_analyzer = DateparserQueryAnalyzer()
174
+
175
+ # Initialize LLM configuration
176
+ self._llm_config = LLMConfig(
177
+ provider=memory_llm_provider,
178
+ api_key=memory_llm_api_key,
179
+ base_url=memory_llm_base_url,
180
+ model=memory_llm_model,
181
+ )
182
+
183
+ # Store client and model for convenience (deprecated: use _llm_config.call() instead)
184
+ self._llm_client = self._llm_config._client
185
+ self._llm_model = self._llm_config.model
186
+
187
+ # Initialize cross-encoder reranker (cached for performance)
188
+ self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)
189
+
190
+ # Initialize task backend
191
+ self._task_backend = task_backend or AsyncIOQueueBackend(
192
+ batch_size=100,
193
+ batch_interval=1.0
194
+ )
195
+
196
+ # Backpressure mechanism: limit concurrent searches to prevent overwhelming the database
197
+ # Limit concurrent searches to prevent connection pool exhaustion
198
+ # Each search can use 2-4 connections, so with 10 concurrent searches
199
+ # we use ~20-40 connections max, staying well within pool limits
200
+ self._search_semaphore = asyncio.Semaphore(10)
201
+
202
+ # Backpressure for put operations: limit concurrent puts to prevent database contention
203
+ # Each put_batch holds a connection for the entire transaction, so we limit to 5
204
+ # concurrent puts to avoid connection pool exhaustion and reduce write contention
205
+ self._put_semaphore = asyncio.Semaphore(5)
206
+
207
+ # Initialize the encoding eagerly so the first request does not pay the load cost
208
+ _get_tiktoken_encoding()
209
+
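A hedged construction sketch, assuming an OpenAI-compatible deployment; the connection string, API key, and model name are placeholders, and "pg0" would instead start the embedded PostgreSQL:

engine = MemoryEngine(
    db_url="postgresql://user:pass@localhost:5432/hindsight",  # or "pg0" for embedded Postgres
    memory_llm_provider="openai",
    memory_llm_api_key="sk-placeholder",
    memory_llm_model="model-name-placeholder",
)
await engine.initialize()  # must run inside an async context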
210
+ async def _handle_access_count_update(self, task_dict: Dict[str, Any]):
211
+ """
212
+ Handler for access count update tasks.
213
+
214
+ Args:
215
+ task_dict: Dict with 'node_ids' key containing list of node IDs to update
216
+ """
217
+ node_ids = task_dict.get('node_ids', [])
218
+ if not node_ids:
219
+ return
220
+
221
+ pool = await self._get_pool()
222
+ try:
223
+ # Convert string UUIDs to UUID type for faster matching
224
+ uuid_list = [uuid.UUID(nid) for nid in node_ids]
225
+ async with acquire_with_retry(pool) as conn:
226
+ await conn.execute(
227
+ "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
228
+ uuid_list
229
+ )
230
+ except Exception as e:
231
+ logger.error(f"Access count handler: Error updating access counts: {e}")
232
+
233
+ async def _handle_batch_retain(self, task_dict: Dict[str, Any]):
234
+ """
235
+ Handler for batch retain tasks.
236
+
237
+ Args:
238
+ task_dict: Dict with 'bank_id', 'contents'
239
+ """
240
+ try:
241
+ bank_id = task_dict.get('bank_id')
242
+ contents = task_dict.get('contents', [])
243
+
244
+ logger.info(f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items")
245
+
246
+ await self.retain_batch_async(
247
+ bank_id=bank_id,
248
+ contents=contents
249
+ )
250
+
251
+ logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
252
+ except Exception as e:
253
+ logger.error(f"Batch retain handler: Error processing batch retain: {e}")
254
+ import traceback
255
+ traceback.print_exc()
256
+
257
+ async def execute_task(self, task_dict: Dict[str, Any]):
258
+ """
259
+ Execute a task by routing it to the appropriate handler.
260
+
261
+ This method is called by the task backend to execute tasks.
262
+ It receives a plain dict that can be serialized and sent over the network.
263
+
264
+ Args:
265
+ task_dict: Task dictionary with 'type' key and other payload data
266
+ Example: {'type': 'access_count_update', 'node_ids': [...]}
267
+ """
268
+ task_type = task_dict.get('type')
269
+ operation_id = task_dict.get('operation_id')
270
+ retry_count = task_dict.get('retry_count', 0)
271
+ max_retries = 3
272
+
273
+ # Check if operation was cancelled (only for tasks with operation_id)
274
+ if operation_id:
275
+ try:
276
+ pool = await self._get_pool()
277
+ async with acquire_with_retry(pool) as conn:
278
+ result = await conn.fetchrow(
279
+ "SELECT id FROM async_operations WHERE id = $1",
280
+ uuid.UUID(operation_id)
281
+ )
282
+ if not result:
283
+ # Operation was cancelled, skip processing
284
+ logger.info(f"Skipping cancelled operation: {operation_id}")
285
+ return
286
+ except Exception as e:
287
+ logger.error(f"Failed to check operation status {operation_id}: {e}")
288
+ # Continue with processing if we can't check status
289
+
290
+ try:
291
+ if task_type == 'access_count_update':
292
+ await self._handle_access_count_update(task_dict)
293
+ elif task_type == 'reinforce_opinion':
294
+ await self._handle_reinforce_opinion(task_dict)
295
+ elif task_type == 'form_opinion':
296
+ await self._handle_form_opinion(task_dict)
297
+ elif task_type == 'batch_put':
298
+ await self._handle_batch_retain(task_dict)
299
+ elif task_type == 'regenerate_observations':
300
+ await self._handle_regenerate_observations(task_dict)
301
+ else:
302
+ logger.error(f"Unknown task type: {task_type}")
303
+ # Don't retry unknown task types
304
+ if operation_id:
305
+ await self._delete_operation_record(operation_id)
306
+ return
307
+
308
+ # Task succeeded - delete operation record
309
+ if operation_id:
310
+ await self._delete_operation_record(operation_id)
311
+
312
+ except Exception as e:
313
+ # Task failed - check if we should retry
314
+ logger.error(f"Task execution failed (attempt {retry_count + 1}/{max_retries + 1}): {task_type}, error: {e}")
315
+ import traceback
316
+ error_traceback = traceback.format_exc()
317
+ traceback.print_exc()
318
+
319
+ if retry_count < max_retries:
320
+ # Reschedule with incremented retry count
321
+ task_dict['retry_count'] = retry_count + 1
322
+ logger.info(f"Rescheduling task {task_type} (retry {retry_count + 1}/{max_retries})")
323
+ await self._task_backend.submit_task(task_dict)
324
+ else:
325
+ # Max retries exceeded - mark operation as failed
326
+ logger.error(f"Max retries exceeded for task {task_type}, marking as failed")
327
+ if operation_id:
328
+ await self._mark_operation_failed(operation_id, str(e), error_traceback)
329
+
330
+ async def _delete_operation_record(self, operation_id: str):
331
+ """Helper to delete an operation record from the database."""
332
+ try:
333
+ pool = await self._get_pool()
334
+ async with acquire_with_retry(pool) as conn:
335
+ await conn.execute(
336
+ "DELETE FROM async_operations WHERE id = $1",
337
+ uuid.UUID(operation_id)
338
+ )
339
+ except Exception as e:
340
+ logger.error(f"Failed to delete async operation record {operation_id}: {e}")
341
+
342
+ async def _mark_operation_failed(self, operation_id: str, error_message: str, error_traceback: str):
343
+ """Helper to mark an operation as failed in the database."""
344
+ try:
345
+ pool = await self._get_pool()
346
+ # Truncate error message to avoid extremely long strings
347
+ full_error = f"{error_message}\n\nTraceback:\n{error_traceback}"
348
+ truncated_error = full_error[:5000] if len(full_error) > 5000 else full_error
349
+
350
+ async with acquire_with_retry(pool) as conn:
351
+ await conn.execute(
352
+ """
353
+ UPDATE async_operations
354
+ SET status = 'failed', error_message = $2
355
+ WHERE id = $1
356
+ """,
357
+ uuid.UUID(operation_id),
358
+ truncated_error
359
+ )
360
+ logger.info(f"Marked async operation as failed: {operation_id}")
361
+ except Exception as e:
362
+ logger.error(f"Failed to mark operation as failed {operation_id}: {e}")
363
+
364
+ async def initialize(self):
365
+ """Initialize the connection pool, models, and background workers.
366
+
367
+ Loads models (embeddings, cross-encoder) in parallel with pg0 startup
368
+ for faster overall initialization.
369
+ """
370
+ if self._initialized:
371
+ return
372
+
373
+ import concurrent.futures
374
+
375
+ # Run model loading in thread pool (CPU-bound) in parallel with pg0 startup
376
+ loop = asyncio.get_event_loop()
377
+
378
+ async def start_pg0():
379
+ """Start pg0 if configured."""
380
+ if self._use_pg0:
381
+ self._pg0 = EmbeddedPostgres()
382
+ self.db_url = await self._pg0.ensure_running()
383
+
384
+ def load_embeddings():
385
+ """Load embedding model (CPU-bound)."""
386
+ self.embeddings.load()
387
+
388
+ def load_cross_encoder():
389
+ """Load cross-encoder model (CPU-bound)."""
390
+ self._cross_encoder_reranker.cross_encoder.load()
391
+
392
+ def load_query_analyzer():
393
+ """Load query analyzer model (CPU-bound)."""
394
+ self.query_analyzer.load()
395
+
396
+ # Run pg0 and all model loads in parallel
397
+ # pg0 is async (IO-bound), models are sync (CPU-bound in thread pool)
398
+ # Use 3 workers to load all models concurrently
399
+ with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
400
+ # Start all tasks
401
+ pg0_task = asyncio.create_task(start_pg0())
402
+ embeddings_future = loop.run_in_executor(executor, load_embeddings)
403
+ cross_encoder_future = loop.run_in_executor(executor, load_cross_encoder)
404
+ query_analyzer_future = loop.run_in_executor(executor, load_query_analyzer)
405
+
406
+ # Wait for all to complete
407
+ await asyncio.gather(
408
+ pg0_task, embeddings_future, cross_encoder_future, query_analyzer_future
409
+ )
410
+
411
+ logger.info(f"Connecting to PostgreSQL at {self.db_url}")
412
+
413
+ # Create connection pool
414
+ # For read-heavy workloads with many parallel think/search operations,
415
+ # we need a larger pool. Read operations don't need strong isolation.
416
+ self._pool = await asyncpg.create_pool(
417
+ self.db_url,
418
+ min_size=self._pool_min_size,
419
+ max_size=self._pool_max_size,
420
+ command_timeout=60,
421
+ statement_cache_size=0, # Disable prepared statement cache
422
+ timeout=30, # Connection acquisition timeout (seconds)
423
+ )
424
+
425
+ # Initialize entity resolver with pool
426
+ self.entity_resolver = EntityResolver(self._pool)
427
+
428
+ # Set executor for task backend and initialize
429
+ self._task_backend.set_executor(self.execute_task)
430
+ await self._task_backend.initialize()
431
+
432
+ self._initialized = True
433
+ logger.info("Memory system initialized (pool and task backend started)")
434
+
435
+ async def _get_pool(self) -> asyncpg.Pool:
436
+ """Get the connection pool (must call initialize() first)."""
437
+ if not self._initialized:
438
+ await self.initialize()
439
+ return self._pool
440
+
441
+ async def _acquire_connection(self):
442
+ """
443
+ Acquire a connection from the pool with retry logic.
444
+
445
+ Returns an async context manager that yields a connection.
446
+ Retries on transient connection errors with exponential backoff.
447
+ """
448
+ pool = await self._get_pool()
449
+
450
+ async def acquire():
451
+ return await pool.acquire()
452
+
453
+ return await retry_with_backoff(acquire)
454
+
455
+ async def health_check(self) -> dict:
456
+ """
457
+ Perform a health check by querying the database.
458
+
459
+ Returns:
460
+ dict with status and optional error message
461
+ """
462
+ try:
463
+ pool = await self._get_pool()
464
+ async with pool.acquire() as conn:
465
+ result = await conn.fetchval("SELECT 1")
466
+ if result == 1:
467
+ return {"status": "healthy", "database": "connected"}
468
+ else:
469
+ return {"status": "unhealthy", "database": "unexpected response"}
470
+ except Exception as e:
471
+ return {"status": "unhealthy", "database": "error", "error": str(e)}
472
+
473
+ async def close(self):
474
+ """Close the connection pool and shutdown background workers."""
475
+ logger.info("close() started")
476
+
477
+ # Shutdown task backend
478
+ await self._task_backend.shutdown()
479
+
480
+ # Close pool
481
+ if self._pool is not None:
482
+ self._pool.terminate()
483
+ self._pool = None
484
+
485
+ self._initialized = False
486
+
487
+ # Stop pg0 if we started it
488
+ if self._pg0 is not None:
489
+ logger.info("Stopping pg0...")
490
+ await self._pg0.stop()
491
+ self._pg0 = None
492
+ logger.info("pg0 stopped")
493
+
494
+
495
+ async def wait_for_background_tasks(self):
496
+ """
497
+ Wait for all pending background tasks to complete.
498
+
499
+ This is useful in tests to ensure background tasks (like opinion reinforcement)
500
+ complete before making assertions.
501
+ """
502
+ if hasattr(self._task_backend, 'wait_for_pending_tasks'):
503
+ await self._task_backend.wait_for_pending_tasks()
504
+
505
+ def _format_readable_date(self, dt: datetime) -> str:
506
+ """
507
+ Format a datetime into a readable string for temporal matching.
508
+
509
+ Examples:
510
+ - June 2024
511
+ - January 15, 2024
512
+ - December 2023
513
+
514
+ This helps queries like "camping in June" match facts that happened in June.
515
+
516
+ Args:
517
+ dt: datetime object to format
518
+
519
+ Returns:
520
+ Readable date string
521
+ """
522
+ # Format as "Month Year" for most cases
523
+ # Could be extended to include day for very specific dates if needed
524
+ month_name = dt.strftime("%B") # Full month name (e.g., "June")
525
+ year = dt.strftime("%Y") # Year (e.g., "2024")
526
+
527
+ # For now, use "Month Year" format
528
+ # Could check if day is significant (not 1st or 15th) and include it
529
+ return f"{month_name} {year}"
530
+
531
+ async def _find_duplicate_facts_batch(
532
+ self,
533
+ conn,
534
+ bank_id: str,
535
+ texts: List[str],
536
+ embeddings: List[List[float]],
537
+ event_date: datetime,
538
+ time_window_hours: int = 24,
539
+ similarity_threshold: float = 0.95
540
+ ) -> List[bool]:
541
+ """
542
+ Check which facts are duplicates using semantic similarity + temporal window.
543
+
544
+ For each new fact, checks if a semantically similar fact already exists
545
+ within the time window. Uses pgvector cosine similarity for efficiency.
546
+
547
+ Args:
548
+ conn: Database connection
549
+ bank_id: Bank identifier
550
+ texts: List of fact texts to check
551
+ embeddings: Corresponding embeddings
552
+ event_date: Event date for temporal filtering
553
+ time_window_hours: Hours before/after event_date to search (default: 24)
554
+ similarity_threshold: Minimum cosine similarity to consider duplicate (default: 0.95)
555
+
556
+ Returns:
557
+ List of booleans - True if fact is a duplicate (should skip), False if new
558
+ """
559
+ if not texts:
560
+ return []
561
+
562
+ # Handle edge cases where event_date is at datetime boundaries
563
+ try:
564
+ time_lower = event_date - timedelta(hours=time_window_hours)
565
+ except OverflowError:
566
+ time_lower = datetime.min
567
+ try:
568
+ time_upper = event_date + timedelta(hours=time_window_hours)
569
+ except OverflowError:
570
+ time_upper = datetime.max
571
+
572
+ # Fetch ALL existing facts in time window ONCE (much faster than N queries)
573
+ import time as time_mod
574
+ fetch_start = time_mod.time()
575
+ existing_facts = await conn.fetch(
576
+ """
577
+ SELECT id, text, embedding
578
+ FROM memory_units
579
+ WHERE bank_id = $1
580
+ AND event_date BETWEEN $2 AND $3
581
+ """,
582
+ bank_id, time_lower, time_upper
583
+ )
584
+
585
+ # If no existing facts, nothing is duplicate
586
+ if not existing_facts:
587
+ return [False] * len(texts)
588
+
589
+ # Compute similarities in Python (vectorized with numpy)
590
+ import numpy as np
591
+ is_duplicate = []
592
+
593
+ # Convert existing embeddings to numpy for faster computation
594
+ embedding_arrays = []
595
+ for row in existing_facts:
596
+ raw_emb = row['embedding']
597
+ # Handle different pgvector formats
598
+ if isinstance(raw_emb, str):
599
+ # Parse string format: "[1.0, 2.0, ...]"
600
+ import json
601
+ emb = np.array(json.loads(raw_emb), dtype=np.float32)
602
+ elif isinstance(raw_emb, (list, tuple)):
603
+ emb = np.array(raw_emb, dtype=np.float32)
604
+ else:
605
+ # Try direct conversion
606
+ emb = np.array(raw_emb, dtype=np.float32)
607
+ embedding_arrays.append(emb)
608
+
609
+ if not embedding_arrays:
610
+ existing_embeddings = np.array([])
611
+ elif len(embedding_arrays) == 1:
612
+ # Single embedding: reshape to (1, dim)
613
+ existing_embeddings = embedding_arrays[0].reshape(1, -1)
614
+ else:
615
+ # Multiple embeddings: vstack
616
+ existing_embeddings = np.vstack(embedding_arrays)
617
+
618
+ comp_start = time_mod.time()
619
+ for embedding in embeddings:
620
+ # Compute cosine similarity with all existing facts
621
+ emb_array = np.array(embedding)
622
+ # Cosine similarity = 1 - cosine distance
623
+ # For normalized vectors: cosine_sim = dot product
624
+ similarities = np.dot(existing_embeddings, emb_array)
625
+
626
+ # Check if any existing fact is too similar
627
+ max_similarity = np.max(similarities) if len(similarities) > 0 else 0
628
+ is_duplicate.append(max_similarity > similarity_threshold)
629
+
630
+
631
+ return is_duplicate
632
+
633
+ def retain(
634
+ self,
635
+ bank_id: str,
636
+ content: str,
637
+ context: str = "",
638
+ event_date: Optional[datetime] = None,
639
+ ) -> List[str]:
640
+ """
641
+ Store content as memory units (synchronous wrapper).
642
+
643
+ This is a synchronous wrapper around retain_async() for convenience.
644
+ For best performance, use retain_async() directly.
645
+
646
+ Args:
647
+ bank_id: Unique identifier for the bank
648
+ content: Text content to store
649
+ context: Context about when/why this memory was formed
650
+ event_date: When the event occurred (defaults to now)
651
+
652
+ Returns:
653
+ List of created unit IDs
654
+ """
655
+ # Run async version synchronously
656
+ return asyncio.run(self.retain_async(bank_id, content, context, event_date))
657
+
658
+ async def retain_async(
659
+ self,
660
+ bank_id: str,
661
+ content: str,
662
+ context: str = "",
663
+ event_date: Optional[datetime] = None,
664
+ document_id: Optional[str] = None,
665
+ fact_type_override: Optional[str] = None,
666
+ confidence_score: Optional[float] = None,
667
+ ) -> List[str]:
668
+ """
669
+ Store content as memory units with temporal and semantic links (ASYNC version).
670
+
671
+ This is a convenience wrapper around retain_batch_async for a single content item.
672
+
673
+ Args:
674
+ bank_id: Unique identifier for the bank
675
+ content: Text content to store
676
+ context: Context about when/why this memory was formed
677
+ event_date: When the event occurred (defaults to now)
678
+ document_id: Optional document ID for tracking (always upserts if document already exists)
679
+ fact_type_override: Override fact type ('world', 'bank', 'opinion')
680
+ confidence_score: Confidence score for opinions (0.0 to 1.0)
681
+
682
+ Returns:
683
+ List of created unit IDs
684
+ """
685
+ # Build content dict
686
+ content_dict: RetainContentDict = {
687
+ "content": content,
688
+ "context": context,
689
+ "event_date": event_date
690
+ }
691
+ if document_id:
692
+ content_dict["document_id"] = document_id
693
+
694
+ # Use retain_batch_async with a single item (avoids code duplication)
695
+ result = await self.retain_batch_async(
696
+ bank_id=bank_id,
697
+ contents=[content_dict],
698
+ fact_type_override=fact_type_override,
699
+ confidence_score=confidence_score
700
+ )
701
+
702
+ # Return the first (and only) list of unit IDs
703
+ return result[0] if result else []
704
+
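Hedged usage sketch for the single-item wrapper (identifiers invented); it delegates to retain_batch_async internally:

unit_ids = await engine.retain_async(
    bank_id="user123",
    content="Alice works at Google",
    context="onboarding chat",
    document_id="doc-1",
)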
705
+ async def retain_batch_async(
706
+ self,
707
+ bank_id: str,
708
+ contents: List[RetainContentDict],
709
+ document_id: Optional[str] = None,
710
+ fact_type_override: Optional[str] = None,
711
+ confidence_score: Optional[float] = None,
712
+ ) -> List[List[str]]:
713
+ """
714
+ Store multiple content items as memory units in ONE batch operation.
715
+
716
+ This is MUCH more efficient than calling retain_async multiple times:
717
+ - Extracts facts from all contents in parallel
718
+ - Generates ALL embeddings in ONE batch
719
+ - Does ALL database operations in ONE transaction
720
+ - Automatically chunks large batches to prevent timeouts
721
+
722
+ Args:
723
+ bank_id: Unique identifier for the bank
724
+ contents: List of dicts with keys:
725
+ - "content" (required): Text content to store
726
+ - "context" (optional): Context about the memory
727
+ - "event_date" (optional): When the event occurred
728
+ - "document_id" (optional): Document ID for this specific content item
729
+ document_id: **DEPRECATED** - Use "document_id" key in each content dict instead.
730
+ Applies the same document_id to ALL content items that don't specify their own.
731
+ fact_type_override: Override fact type for all facts ('world', 'bank', 'opinion')
732
+ confidence_score: Confidence score for opinions (0.0 to 1.0)
733
+
734
+ Returns:
735
+ List of lists of unit IDs (one list per content item)
736
+
737
+ Example (new style - per-content document_id):
738
+ unit_ids = await memory.retain_batch_async(
739
+ bank_id="user123",
740
+ contents=[
741
+ {"content": "Alice works at Google", "document_id": "doc1"},
742
+ {"content": "Bob loves Python", "document_id": "doc2"},
743
+ {"content": "More about Alice", "document_id": "doc1"},
744
+ ]
745
+ )
746
+ # Returns: [["unit-id-1"], ["unit-id-2"], ["unit-id-3"]]
747
+
748
+ Example (deprecated style - batch-level document_id):
749
+ unit_ids = await memory.retain_batch_async(
750
+ bank_id="user123",
751
+ contents=[
752
+ {"content": "Alice works at Google"},
753
+ {"content": "Bob loves Python"},
754
+ ],
755
+ document_id="meeting-2024-01-15"
756
+ )
757
+ # Returns: [["unit-id-1"], ["unit-id-2"]]
758
+ """
759
+ start_time = time.time()
760
+
761
+ if not contents:
762
+ return []
763
+
764
+ # Apply batch-level document_id to contents that don't have their own (backwards compatibility)
765
+ if document_id:
766
+ for item in contents:
767
+ if "document_id" not in item:
768
+ item["document_id"] = document_id
769
+
770
+ # Auto-chunk large batches by character count to avoid timeouts and memory issues
771
+ # Calculate total character count
772
+ total_chars = sum(len(item.get("content", "")) for item in contents)
773
+
774
+ CHARS_PER_BATCH = 600_000
775
+
776
+ if total_chars > CHARS_PER_BATCH:
777
+ # Split into smaller batches based on character count
778
+ logger.info(f"Large batch detected ({total_chars:,} chars from {len(contents)} items). Splitting into sub-batches of ~{CHARS_PER_BATCH:,} chars each...")
779
+
780
+ sub_batches = []
781
+ current_batch = []
782
+ current_batch_chars = 0
783
+
784
+ for item in contents:
785
+ item_chars = len(item.get("content", ""))
786
+
787
+ # If adding this item would exceed the limit, start a new batch
788
+ # (unless current batch is empty - then we must include it even if it's large)
789
+ if current_batch and current_batch_chars + item_chars > CHARS_PER_BATCH:
790
+ sub_batches.append(current_batch)
791
+ current_batch = [item]
792
+ current_batch_chars = item_chars
793
+ else:
794
+ current_batch.append(item)
795
+ current_batch_chars += item_chars
796
+
797
+ # Add the last batch
798
+ if current_batch:
799
+ sub_batches.append(current_batch)
800
+
801
+ logger.info(f"Split into {len(sub_batches)} sub-batches: {[len(b) for b in sub_batches]} items each")
802
+
803
+ # Process each sub-batch using internal method (skip chunking check)
804
+ all_results = []
805
+ for i, sub_batch in enumerate(sub_batches, 1):
806
+ sub_batch_chars = sum(len(item.get("content", "")) for item in sub_batch)
807
+ logger.info(f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars")
808
+
809
+ sub_results = await self._retain_batch_async_internal(
810
+ bank_id=bank_id,
811
+ contents=sub_batch,
812
+ document_id=document_id,
813
+ is_first_batch=i == 1, # Only upsert on first batch
814
+ fact_type_override=fact_type_override,
815
+ confidence_score=confidence_score
816
+ )
817
+ all_results.extend(sub_results)
818
+
819
+ total_time = time.time() - start_time
820
+ logger.info(f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s")
821
+ return all_results
822
+
823
+ # Small batch - use internal method directly
824
+ return await self._retain_batch_async_internal(
825
+ bank_id=bank_id,
826
+ contents=contents,
827
+ document_id=document_id,
828
+ is_first_batch=True,
829
+ fact_type_override=fact_type_override,
830
+ confidence_score=confidence_score
831
+ )
832
+
833
+ async def _retain_batch_async_internal(
834
+ self,
835
+ bank_id: str,
836
+ contents: List[RetainContentDict],
837
+ document_id: Optional[str] = None,
838
+ is_first_batch: bool = True,
839
+ fact_type_override: Optional[str] = None,
840
+ confidence_score: Optional[float] = None,
841
+ ) -> List[List[str]]:
842
+ """
843
+ Internal method for batch processing without chunking logic.
844
+
845
+ Assumes contents are already appropriately sized (at or below the per-batch character limit).
846
+ Called by retain_batch_async after chunking large batches.
847
+
848
+ Uses semaphore for backpressure to limit concurrent retains.
849
+
850
+ Args:
851
+ bank_id: Unique identifier for the bank
852
+ contents: List of dicts with content, context, event_date
853
+ document_id: Optional document ID (always upserts if exists)
854
+ is_first_batch: Whether this is the first batch (for chunked operations, only delete on first batch)
855
+ fact_type_override: Override fact type for all facts
856
+ confidence_score: Confidence score for opinions
857
+ """
858
+ # Backpressure: limit concurrent retains to prevent database contention
859
+ async with self._put_semaphore:
860
+ # Use the new modular orchestrator
861
+ from .retain import orchestrator
862
+
863
+ pool = await self._get_pool()
864
+ return await orchestrator.retain_batch(
865
+ pool=pool,
866
+ embeddings_model=self.embeddings,
867
+ llm_config=self._llm_config,
868
+ entity_resolver=self.entity_resolver,
869
+ task_backend=self._task_backend,
870
+ format_date_fn=self._format_readable_date,
871
+ duplicate_checker_fn=self._find_duplicate_facts_batch,
872
+ regenerate_observations_fn=self._regenerate_observations_sync,
873
+ bank_id=bank_id,
874
+ contents_dicts=contents,
875
+ document_id=document_id,
876
+ is_first_batch=is_first_batch,
877
+ fact_type_override=fact_type_override,
878
+ confidence_score=confidence_score
879
+ )
880
+
881
+ def recall(
882
+ self,
883
+ bank_id: str,
884
+ query: str,
885
+ fact_type: str,
886
+ budget: Budget = Budget.MID,
887
+ max_tokens: int = 4096,
888
+ enable_trace: bool = False,
889
+ ) -> RecallResultModel:
890
+ """
891
+ Recall memories using 4-way parallel retrieval (synchronous wrapper).
892
+
893
+ This is a synchronous wrapper around recall_async() for convenience.
894
+ For best performance, use recall_async() directly.
895
+
896
+ Args:
897
+ bank_id: bank ID to recall for
898
+ query: Recall query
899
+ fact_type: Required filter for fact type ('world', 'bank', or 'opinion')
900
+ budget: Budget level for graph traversal (low=100, mid=300, high=600 units)
901
+ max_tokens: Maximum tokens to return (counts only 'text' field, default 4096)
902
+ enable_trace: If True, returns detailed trace object
903
+
904
+ Returns:
905
+ RecallResultModel containing the results and optional trace
906
+ """
907
+ # Run async version synchronously
908
+ return asyncio.run(self.recall_async(
909
+ bank_id, query, [fact_type], budget, max_tokens, enable_trace
910
+ ))
911
+
912
+ async def recall_async(
913
+ self,
914
+ bank_id: str,
915
+ query: str,
916
+ fact_type: List[str],
917
+ budget: Budget = Budget.MID,
918
+ max_tokens: int = 4096,
919
+ enable_trace: bool = False,
920
+ question_date: Optional[datetime] = None,
921
+ include_entities: bool = False,
922
+ max_entity_tokens: int = 1024,
923
+ include_chunks: bool = False,
924
+ max_chunk_tokens: int = 8192,
925
+ ) -> RecallResultModel:
926
+ """
927
+ Recall memories using N*4-way parallel retrieval (N fact types × 4 retrieval methods).
928
+
929
+ This implements the core RECALL operation:
930
+ 1. Retrieval: For each fact type, run 4 parallel retrievals (semantic vector, BM25 keyword, graph activation, temporal graph)
931
+ 2. Merge: Combine using Reciprocal Rank Fusion (RRF)
932
+ 3. Rerank: Score using selected reranker (heuristic or cross-encoder)
933
+ 4. Diversify: Apply MMR for diversity
934
+ 5. Token Filter: Return results up to max_tokens budget
935
+
936
+ Args:
937
+ bank_id: bank ID to recall for
938
+ query: Recall query
939
+ fact_type: List of fact types to recall (e.g., ['world', 'bank'])
940
+ budget: Budget level for graph traversal (low=100, mid=300, high=600 units)
941
+ max_tokens: Maximum tokens to return (counts only 'text' field, default 4096)
942
+ Results are returned until token budget is reached, stopping before
943
+ including a fact that would exceed the limit
944
+ enable_trace: Whether to return trace for debugging (deprecated)
945
+ question_date: Optional date when question was asked (for temporal filtering)
946
+ include_entities: Whether to include entity observations in the response
947
+ max_entity_tokens: Maximum tokens for entity observations (default 1024)
948
+ include_chunks: Whether to include raw chunks in the response
949
+ max_chunk_tokens: Maximum tokens for chunks (default 8192)
950
+
951
+ Returns:
952
+ RecallResultModel containing:
953
+ - results: List of MemoryFact objects
954
+ - trace: Optional trace information for debugging
955
+ - entities: Optional dict of entity states (if include_entities=True)
956
+ - chunks: Optional dict of chunks (if include_chunks=True)
957
+ """
958
+ # Map budget enum to thinking_budget number
959
+ budget_mapping = {
960
+ Budget.LOW: 100,
961
+ Budget.MID: 300,
962
+ Budget.HIGH: 600
963
+ }
964
+ thinking_budget = budget_mapping[budget]
965
+
966
+ # Backpressure: limit concurrent recalls to prevent overwhelming the database
967
+ async with self._search_semaphore:
968
+ # Retry loop for connection errors
969
+ max_retries = 3
970
+ for attempt in range(max_retries + 1):
971
+ try:
972
+ return await self._search_with_retries(
973
+ bank_id, query, fact_type, thinking_budget, max_tokens, enable_trace, question_date,
974
+ include_entities, max_entity_tokens, include_chunks, max_chunk_tokens
975
+ )
976
+ except Exception as e:
977
+ # Check if it's a connection error
978
+ is_connection_error = (
979
+ isinstance(e, asyncpg.TooManyConnectionsError) or
980
+ isinstance(e, asyncpg.CannotConnectNowError) or
981
+ (isinstance(e, asyncpg.PostgresError) and 'connection' in str(e).lower())
982
+ )
983
+
984
+ if is_connection_error and attempt < max_retries:
985
+ # Wait with exponential backoff before retry
986
+ wait_time = 0.5 * (2 ** attempt) # 0.5s, 1s, 2s
987
+ logger.warning(
988
+ f"Connection error on search attempt {attempt + 1}/{max_retries + 1}: {str(e)}. "
989
+ f"Retrying in {wait_time:.1f}s..."
990
+ )
991
+ await asyncio.sleep(wait_time)
992
+ else:
993
+ # Not a connection error or out of retries - raise
994
+ raise
995
+ raise Exception("Exceeded maximum retries for search due to connection errors.")
996
+
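Hedged usage sketch of recall_async; the result fields follow the RecallResult model referenced above:

result = await engine.recall_async(
    bank_id="user123",
    query="Where does Alice work?",
    fact_type=["world"],
    budget=Budget.MID,
    max_tokens=2048,
)
for fact in result.results:  # results is a list of MemoryFact objects
    print(fact.text)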
997
+ async def _search_with_retries(
998
+ self,
999
+ bank_id: str,
1000
+ query: str,
1001
+ fact_type: List[str],
1002
+ thinking_budget: int,
1003
+ max_tokens: int,
1004
+ enable_trace: bool,
1005
+ question_date: Optional[datetime] = None,
1006
+ include_entities: bool = False,
1007
+ max_entity_tokens: int = 500,
1008
+ include_chunks: bool = False,
1009
+ max_chunk_tokens: int = 8192,
1010
+ ) -> RecallResultModel:
1011
+ """
1012
+ Search implementation with modular retrieval and reranking.
1013
+
1014
+ Architecture:
1015
+ 1. Retrieval: 4-way parallel (semantic, keyword, graph, temporal graph)
1016
+ 2. Merge: RRF to combine ranked lists
1017
+ 3. Reranking: Pluggable strategy (heuristic or cross-encoder)
1018
+ 4. Diversity: MMR with λ=0.5
1019
+ 5. Token Filter: Limit results to max_tokens budget
1020
+
1021
+ Args:
1022
+ bank_id: Bank identifier
1023
+ query: Search query
1024
+ fact_type: List of fact types to search
1025
+ thinking_budget: Nodes to explore in graph traversal
1026
+ max_tokens: Maximum tokens to return (counts only 'text' field)
1027
+ enable_trace: Whether to return search trace (deprecated)
1028
+ include_entities: Whether to include entity observations
1029
+ max_entity_tokens: Maximum tokens for entity observations
1030
+ include_chunks: Whether to include raw chunks
1031
+ max_chunk_tokens: Maximum tokens for chunks
1032
+
1033
+ Returns:
1034
+ RecallResultModel with results, trace, optional entities, and optional chunks
1035
+ """
1036
+ # Initialize tracer if requested
1037
+ from .search.tracer import SearchTracer
1038
+ tracer = SearchTracer(query, thinking_budget, max_tokens) if enable_trace else None
1039
+ if tracer:
1040
+ tracer.start()
1041
+
1042
+ pool = await self._get_pool()
1043
+ search_start = time.time()
1044
+
1045
+ # Buffer logs for clean output in concurrent scenarios
1046
+ search_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
1047
+ log_buffer = []
1048
+ log_buffer.append(f"[SEARCH {search_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})")
1049
+
1050
+ try:
1051
+ # Step 1: Generate query embedding (for semantic search)
1052
+ step_start = time.time()
1053
+ query_embedding = embedding_utils.generate_embedding(self.embeddings, query)
1054
+ step_duration = time.time() - step_start
1055
+ log_buffer.append(f" [1] Generate query embedding: {step_duration:.3f}s")
1056
+
1057
+ if tracer:
1058
+ tracer.record_query_embedding(query_embedding)
1059
+ tracer.add_phase_metric("generate_query_embedding", step_duration)
1060
+
1061
+ # Step 2: N*4-Way Parallel Retrieval (N fact types × 4 retrieval methods)
1062
+ step_start = time.time()
1063
+ query_embedding_str = str(query_embedding)
1064
+
1065
+ from .search.retrieval import retrieve_parallel
1066
+
1067
+ # Track each retrieval start time
1068
+ retrieval_start = time.time()
1069
+
1070
+ # Run retrieval for each fact type in parallel
1071
+ retrieval_tasks = [
1072
+ retrieve_parallel(
1073
+ pool, query, query_embedding_str, bank_id, ft, thinking_budget,
1074
+ question_date, self.query_analyzer
1075
+ )
1076
+ for ft in fact_type
1077
+ ]
1078
+ all_retrievals = await asyncio.gather(*retrieval_tasks)
1079
+
1080
+ # Combine all results from all fact types and aggregate timings
1081
+ semantic_results = []
1082
+ bm25_results = []
1083
+ graph_results = []
1084
+ temporal_results = []
1085
+ aggregated_timings = {"semantic": 0.0, "bm25": 0.0, "graph": 0.0, "temporal": 0.0}
1086
+
1087
+ detected_temporal_constraint = None
1088
+ for idx, (ft_semantic, ft_bm25, ft_graph, ft_temporal, ft_timings, ft_temporal_constraint) in enumerate(all_retrievals):
1089
+ # Log fact types in this retrieval batch
1090
+ ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
1091
+ logger.debug(f"[SEARCH {search_id}] Fact type '{ft_name}': semantic={len(ft_semantic)}, bm25={len(ft_bm25)}, graph={len(ft_graph)}, temporal={len(ft_temporal) if ft_temporal else 0}")
1092
+
1093
+ semantic_results.extend(ft_semantic)
1094
+ bm25_results.extend(ft_bm25)
1095
+ graph_results.extend(ft_graph)
1096
+ if ft_temporal:
1097
+ temporal_results.extend(ft_temporal)
1098
+ # Track max timing for each method (since they run in parallel across fact types)
1099
+ for method, duration in ft_timings.items():
1100
+ aggregated_timings[method] = max(aggregated_timings[method], duration)
1101
+ # Capture temporal constraint (same across all fact types)
1102
+ if ft_temporal_constraint:
1103
+ detected_temporal_constraint = ft_temporal_constraint
1104
+
1105
+ # If no temporal results from any fact type, set to None
1106
+ if not temporal_results:
1107
+ temporal_results = None
1108
+
1109
+ # Sort combined results by score (descending) so higher-scored results
1110
+ # get better ranks in the trace, regardless of fact type
1111
+ semantic_results.sort(key=lambda r: r.similarity if hasattr(r, 'similarity') else 0, reverse=True)
1112
+ bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r, 'bm25_score') else 0, reverse=True)
1113
+ graph_results.sort(key=lambda r: r.activation if hasattr(r, 'activation') else 0, reverse=True)
1114
+ if temporal_results:
1115
+ temporal_results.sort(key=lambda r: r.combined_score if hasattr(r, 'combined_score') else 0, reverse=True)
1116
+
1117
+ retrieval_duration = time.time() - retrieval_start
1118
+
1119
+ step_duration = time.time() - step_start
1120
+ total_retrievals = len(fact_type) * (4 if temporal_results else 3)
1121
+ # Format per-method timings
1122
+ timing_parts = [
1123
+ f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
1124
+ f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
1125
+ f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)"
1126
+ ]
1127
+ temporal_info = ""
1128
+ if detected_temporal_constraint:
1129
+ start_dt, end_dt = detected_temporal_constraint
1130
+ temporal_count = len(temporal_results) if temporal_results else 0
1131
+ timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
1132
+ temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
1133
+ log_buffer.append(f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}")
1134
+
1135
+ # Record retrieval results for tracer (convert typed results to old format)
1136
+ if tracer:
1137
+ # Convert RetrievalResult to old tuple format for tracer
1138
+ def to_tuple_format(results):
1139
+ return [(r.id, r.__dict__) for r in results]
1140
+
1141
+ # Add semantic retrieval results
1142
+ tracer.add_retrieval_results(
1143
+ method_name="semantic",
1144
+ results=to_tuple_format(semantic_results),
1145
+ duration_seconds=aggregated_timings["semantic"],
1146
+ score_field="similarity",
1147
+ metadata={"limit": thinking_budget}
1148
+ )
1149
+
1150
+ # Add BM25 retrieval results
1151
+ tracer.add_retrieval_results(
1152
+ method_name="bm25",
1153
+ results=to_tuple_format(bm25_results),
1154
+ duration_seconds=aggregated_timings["bm25"],
1155
+ score_field="bm25_score",
1156
+ metadata={"limit": thinking_budget}
1157
+ )
1158
+
1159
+ # Add graph retrieval results
1160
+ tracer.add_retrieval_results(
1161
+ method_name="graph",
1162
+ results=to_tuple_format(graph_results),
1163
+ duration_seconds=aggregated_timings["graph"],
1164
+ score_field="similarity", # Graph uses similarity for activation
1165
+ metadata={"budget": thinking_budget}
1166
+ )
1167
+
1168
+ # Add temporal retrieval results if present
1169
+ if temporal_results:
1170
+ tracer.add_retrieval_results(
1171
+ method_name="temporal",
1172
+ results=to_tuple_format(temporal_results),
1173
+ duration_seconds=aggregated_timings["temporal"],
1174
+ score_field="temporal_score",
1175
+ metadata={"budget": thinking_budget}
1176
+ )
1177
+
1178
+ # Record entry points (from semantic results) for legacy graph view
1179
+ for rank, retrieval in enumerate(semantic_results[:10], start=1): # Top 10 as entry points
1180
+ tracer.add_entry_point(retrieval.id, retrieval.text, retrieval.similarity or 0.0, rank)
1181
+
1182
+ tracer.add_phase_metric("parallel_retrieval", step_duration, {
1183
+ "semantic_count": len(semantic_results),
1184
+ "bm25_count": len(bm25_results),
1185
+ "graph_count": len(graph_results),
1186
+ "temporal_count": len(temporal_results) if temporal_results else 0
1187
+ })
1188
+
1189
+ # Step 3: Merge with RRF
1190
+ step_start = time.time()
1191
+ from .search.fusion import reciprocal_rank_fusion
1192
+
1193
+ # Merge 3 or 4 result lists depending on temporal constraint
1194
+ if temporal_results:
1195
+ merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results, temporal_results])
1196
+ else:
1197
+ merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results])
1198
+
1199
+ step_duration = time.time() - step_start
1200
+ log_buffer.append(f" [3] RRF merge: {len(merged_candidates)} unique candidates in {step_duration:.3f}s")
1201
+
1202
+ if tracer:
1203
+ # Convert MergedCandidate to old tuple format for tracer
1204
+ tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
1205
+ for mc in merged_candidates]
1206
+ tracer.add_rrf_merged(tracer_merged)
1207
+ tracer.add_phase_metric("rrf_merge", step_duration, {"candidates_merged": len(merged_candidates)})
1208
+
1209
+ # Step 4: Rerank using cross-encoder (MergedCandidate -> ScoredResult)
1210
+ step_start = time.time()
1211
+ reranker_instance = self._cross_encoder_reranker
1212
+ log_buffer.append(f" [4] Using cross-encoder reranker")
1213
+
1214
+ # Rerank using cross-encoder
1215
+ scored_results = reranker_instance.rerank(query, merged_candidates)
1216
+
1217
+ step_duration = time.time() - step_start
1218
+ log_buffer.append(f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s")
1219
+
1220
+ if tracer:
1221
+ # Convert to old format for tracer
1222
+ results_dict = [sr.to_dict() for sr in scored_results]
1223
+ tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
1224
+ for mc in merged_candidates]
1225
+ tracer.add_reranked(results_dict, tracer_merged)
1226
+ tracer.add_phase_metric("reranking", step_duration, {
1227
+ "reranker_type": "cross-encoder",
1228
+ "candidates_reranked": len(scored_results)
1229
+ })
1230
+
1231
+ # Step 4.5: Combine cross-encoder score with retrieval signals
1232
+ # This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
1233
+ if scored_results:
1234
+ # Normalize RRF scores to [0, 1] range
1235
+ rrf_scores = [sr.candidate.rrf_score for sr in scored_results]
1236
+ max_rrf = max(rrf_scores) if rrf_scores else 1.0
1237
+ min_rrf = min(rrf_scores) if rrf_scores else 0.0
1238
+ rrf_range = max_rrf - min_rrf if max_rrf > min_rrf else 1.0
1239
+
1240
+ # Calculate recency based on occurred_start (more recent = higher score)
1241
+ now = utcnow()
1242
+ for sr in scored_results:
1243
+ # Normalize RRF score
1244
+ sr.rrf_normalized = (sr.candidate.rrf_score - min_rrf) / rrf_range if rrf_range > 0 else 0.5
1245
+
1246
+ # Calculate recency (decay over 365 days, minimum 0.1)
1247
+ sr.recency = 0.5 # default for missing dates
1248
+ if sr.retrieval.occurred_start:
1249
+ occurred = sr.retrieval.occurred_start
1250
+ if hasattr(occurred, 'tzinfo') and occurred.tzinfo is None:
1251
+ from datetime import timezone
1252
+ occurred = occurred.replace(tzinfo=timezone.utc)
1253
+ days_ago = (now - occurred).total_seconds() / 86400
1254
+ sr.recency = max(0.1, 1.0 - (days_ago / 365)) # Linear decay over 1 year
1255
+
1256
+ # Get temporal proximity if available (already 0-1)
1257
+ sr.temporal = sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5
1258
+
1259
+ # Weighted combination
1260
+ # Cross-encoder: 60% (semantic relevance)
1261
+ # RRF: 20% (retrieval consensus)
1262
+ # Temporal proximity: 10% (time relevance for temporal queries)
1263
+ # Recency: 10% (prefer recent facts)
1264
+ sr.combined_score = (
1265
+ 0.6 * sr.cross_encoder_score_normalized +
1266
+ 0.2 * sr.rrf_normalized +
1267
+ 0.1 * sr.temporal +
1268
+ 0.1 * sr.recency
1269
+ )
1270
+ sr.weight = sr.combined_score # Update weight for final ranking
1271
+
1272
+ # Re-sort by combined score
1273
+ scored_results.sort(key=lambda x: x.weight, reverse=True)
1274
+ log_buffer.append(f" [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)")
1275
+
1276
+ # Step 5: Truncate to thinking_budget * 2 for token filtering
1277
+ rerank_limit = thinking_budget * 2
1278
+ top_scored = scored_results[:rerank_limit]
1279
+ log_buffer.append(f" [5] Truncated to top {len(top_scored)} results")
1280
+
1281
+ # Step 6: Token budget filtering
1282
+ step_start = time.time()
1283
+
1284
+ # Convert to dict for token filtering (backward compatibility)
1285
+ top_dicts = [sr.to_dict() for sr in top_scored]
1286
+ filtered_dicts, total_tokens = self._filter_by_token_budget(top_dicts, max_tokens)
1287
+
1288
+ # Convert back to list of IDs and filter scored_results
1289
+ filtered_ids = {d["id"] for d in filtered_dicts}
1290
+ top_scored = [sr for sr in top_scored if sr.id in filtered_ids]
1291
+
1292
+ step_duration = time.time() - step_start
1293
+ log_buffer.append(f" [6] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s")
1294
+
1295
+ if tracer:
1296
+ tracer.add_phase_metric("token_filtering", step_duration, {
1297
+ "results_selected": len(top_scored),
1298
+ "tokens_used": total_tokens,
1299
+ "max_tokens": max_tokens
1300
+ })
1301
+
1302
+ # Record visits for all retrieved nodes
1303
+ if tracer:
1304
+ for sr in scored_results:
1305
+ tracer.visit_node(
1306
+ node_id=sr.id,
1307
+ text=sr.retrieval.text,
1308
+ context=sr.retrieval.context or "",
1309
+ event_date=sr.retrieval.occurred_start,
1310
+ access_count=sr.retrieval.access_count,
1311
+ is_entry_point=(sr.id in [ep.node_id for ep in tracer.entry_points]),
1312
+ parent_node_id=None, # In parallel retrieval, there's no clear parent
1313
+ link_type=None,
1314
+ link_weight=None,
1315
+ activation=sr.candidate.rrf_score, # Use RRF score as activation
1316
+ semantic_similarity=sr.retrieval.similarity or 0.0,
1317
+ recency=sr.recency,
1318
+ frequency=0.0,
1319
+ final_weight=sr.weight
1320
+ )
1321
+
1322
+ # Step 8: Queue access count updates for visited nodes
1323
+ visited_ids = list(set([sr.id for sr in scored_results[:50]])) # Top 50
1324
+ if visited_ids:
1325
+ await self._task_backend.submit_task({
1326
+ 'type': 'access_count_update',
1327
+ 'node_ids': visited_ids
1328
+ })
1329
+ log_buffer.append(f" [7] Queued access count updates for {len(visited_ids)} nodes")
1330
+
1331
+ # Log fact_type distribution in results
1332
+ fact_type_counts = {}
1333
+ for sr in top_scored:
1334
+ ft = sr.retrieval.fact_type
1335
+ fact_type_counts[ft] = fact_type_counts.get(ft, 0) + 1
1336
+
1337
+ total_time = time.time() - search_start
1338
+ fact_type_summary = ", ".join([f"{ft}={count}" for ft, count in sorted(fact_type_counts.items())])
1339
+ log_buffer.append(f"[SEARCH {search_id}] Complete: {len(top_scored)} results ({fact_type_summary}) ({total_tokens} tokens) in {total_time:.3f}s")
1340
+
1341
+ # Log all buffered logs at once
1342
+ logger.info("\n" + "\n".join(log_buffer))
1343
+
1344
+ # Convert ScoredResult to dicts with ISO datetime strings
1345
+ top_results_dicts = []
1346
+ for sr in top_scored:
1347
+ result_dict = sr.to_dict()
1348
+ # Convert datetime objects to ISO strings for JSON serialization
1349
+ if result_dict.get("occurred_start"):
1350
+ occurred_start = result_dict["occurred_start"]
1351
+ result_dict["occurred_start"] = occurred_start.isoformat() if hasattr(occurred_start, 'isoformat') else occurred_start
1352
+ if result_dict.get("occurred_end"):
1353
+ occurred_end = result_dict["occurred_end"]
1354
+ result_dict["occurred_end"] = occurred_end.isoformat() if hasattr(occurred_end, 'isoformat') else occurred_end
1355
+ if result_dict.get("mentioned_at"):
1356
+ mentioned_at = result_dict["mentioned_at"]
1357
+ result_dict["mentioned_at"] = mentioned_at.isoformat() if hasattr(mentioned_at, 'isoformat') else mentioned_at
1358
+ top_results_dicts.append(result_dict)
1359
+
1360
+ # Get entities for each fact if include_entities is requested
1361
+ fact_entity_map = {} # unit_id -> list of (entity_id, entity_name)
1362
+ if include_entities and top_scored:
1363
+ unit_ids = [uuid.UUID(sr.id) for sr in top_scored]
1364
+ if unit_ids:
1365
+ async with acquire_with_retry(pool) as entity_conn:
1366
+ entity_rows = await entity_conn.fetch(
1367
+ """
1368
+ SELECT ue.unit_id, e.id as entity_id, e.canonical_name
1369
+ FROM unit_entities ue
1370
+ JOIN entities e ON ue.entity_id = e.id
1371
+ WHERE ue.unit_id = ANY($1::uuid[])
1372
+ """,
1373
+ unit_ids
1374
+ )
1375
+ for row in entity_rows:
1376
+ unit_id = str(row['unit_id'])
1377
+ if unit_id not in fact_entity_map:
1378
+ fact_entity_map[unit_id] = []
1379
+ fact_entity_map[unit_id].append({
1380
+ 'entity_id': str(row['entity_id']),
1381
+ 'canonical_name': row['canonical_name']
1382
+ })
1383
+
1384
+ # Convert results to MemoryFact objects
1385
+ memory_facts = []
1386
+ for result_dict in top_results_dicts:
1387
+ result_id = str(result_dict.get("id"))
1388
+ # Get entity names for this fact
1389
+ entity_names = None
1390
+ if include_entities and result_id in fact_entity_map:
1391
+ entity_names = [e['canonical_name'] for e in fact_entity_map[result_id]]
1392
+
1393
+ memory_facts.append(MemoryFact(
1394
+ id=result_id,
1395
+ text=result_dict.get("text"),
1396
+ fact_type=result_dict.get("fact_type", "world"),
1397
+ entities=entity_names,
1398
+ context=result_dict.get("context"),
1399
+ occurred_start=result_dict.get("occurred_start"),
1400
+ occurred_end=result_dict.get("occurred_end"),
1401
+ mentioned_at=result_dict.get("mentioned_at"),
1402
+ document_id=result_dict.get("document_id"),
1403
+ chunk_id=result_dict.get("chunk_id"),
1404
+ activation=result_dict.get("weight") # Use final weight as activation
1405
+ ))
1406
+
1407
+ # Fetch entity observations if requested
1408
+ entities_dict = None
1409
+ if include_entities and fact_entity_map:
1410
+ # Collect unique entities in order of fact relevance (preserving order from top_scored)
1411
+ # Use a list to maintain order, but track seen entities to avoid duplicates
1412
+ entities_ordered = [] # list of (entity_id, entity_name) tuples
1413
+ seen_entity_ids = set()
1414
+
1415
+ # Iterate through facts in relevance order
1416
+ for sr in top_scored:
1417
+ unit_id = sr.id
1418
+ if unit_id in fact_entity_map:
1419
+ for entity in fact_entity_map[unit_id]:
1420
+ entity_id = entity['entity_id']
1421
+ entity_name = entity['canonical_name']
1422
+ if entity_id not in seen_entity_ids:
1423
+ entities_ordered.append((entity_id, entity_name))
1424
+ seen_entity_ids.add(entity_id)
1425
+
1426
+ # Fetch observations for each entity (respect token budget, in order)
1427
+ entities_dict = {}
1428
+ total_entity_tokens = 0
1429
+ encoding = _get_tiktoken_encoding()
1430
+
1431
+ for entity_id, entity_name in entities_ordered:
1432
+ if total_entity_tokens >= max_entity_tokens:
1433
+ break
1434
+
1435
+ observations = await self.get_entity_observations(bank_id, entity_id, limit=5)
1436
+
1437
+ # Calculate tokens for this entity's observations
1438
+ entity_tokens = 0
1439
+ included_observations = []
1440
+ for obs in observations:
1441
+ obs_tokens = len(encoding.encode(obs.text))
1442
+ if total_entity_tokens + entity_tokens + obs_tokens <= max_entity_tokens:
1443
+ included_observations.append(obs)
1444
+ entity_tokens += obs_tokens
1445
+ else:
1446
+ break
1447
+
1448
+ if included_observations:
1449
+ entities_dict[entity_name] = EntityState(
1450
+ entity_id=entity_id,
1451
+ canonical_name=entity_name,
1452
+ observations=included_observations
1453
+ )
1454
+ total_entity_tokens += entity_tokens
1455
+
1456
+ # Fetch chunks if requested
1457
+ chunks_dict = None
1458
+ if include_chunks and top_scored:
1459
+ from .response_models import ChunkInfo
1460
+
1461
+ # Collect chunk_ids in order of fact relevance (preserving order from top_scored)
1462
+ # Use a list to maintain order, but track seen chunks to avoid duplicates
1463
+ chunk_ids_ordered = []
1464
+ seen_chunk_ids = set()
1465
+ for sr in top_scored:
1466
+ chunk_id = sr.retrieval.chunk_id
1467
+ if chunk_id and chunk_id not in seen_chunk_ids:
1468
+ chunk_ids_ordered.append(chunk_id)
1469
+ seen_chunk_ids.add(chunk_id)
1470
+
1471
+ if chunk_ids_ordered:
1472
+ # Fetch chunk data from database using chunk_ids (no ORDER BY to preserve input order)
1473
+ async with acquire_with_retry(pool) as conn:
1474
+ chunks_rows = await conn.fetch(
1475
+ """
1476
+ SELECT chunk_id, chunk_text, chunk_index
1477
+ FROM chunks
1478
+ WHERE chunk_id = ANY($1::text[])
1479
+ """,
1480
+ chunk_ids_ordered
1481
+ )
1482
+
1483
+ # Create a lookup dict for fast access
1484
+ chunks_lookup = {row['chunk_id']: row for row in chunks_rows}
1485
+
1486
+ # Apply token limit and build chunks_dict in the order of chunk_ids_ordered
1487
+ chunks_dict = {}
1488
+ total_chunk_tokens = 0
1489
+ encoding = _get_tiktoken_encoding()
1490
+
1491
+ for chunk_id in chunk_ids_ordered:
1492
+ if chunk_id not in chunks_lookup:
1493
+ continue
1494
+
1495
+ row = chunks_lookup[chunk_id]
1496
+ chunk_text = row['chunk_text']
1497
+ chunk_tokens = len(encoding.encode(chunk_text))
1498
+
1499
+ # Check if adding this chunk would exceed the limit
1500
+ if total_chunk_tokens + chunk_tokens > max_chunk_tokens:
1501
+ # Truncate the chunk to fit within the remaining budget
1502
+ remaining_tokens = max_chunk_tokens - total_chunk_tokens
1503
+ if remaining_tokens > 0:
1504
+ # Truncate to remaining tokens
1505
+ truncated_text = encoding.decode(encoding.encode(chunk_text)[:remaining_tokens])
1506
+ chunks_dict[chunk_id] = ChunkInfo(
1507
+ chunk_text=truncated_text,
1508
+ chunk_index=row['chunk_index'],
1509
+ truncated=True
1510
+ )
1511
+ total_chunk_tokens = max_chunk_tokens
1512
+ # Stop adding more chunks once we hit the limit
1513
+ break
1514
+ else:
1515
+ chunks_dict[chunk_id] = ChunkInfo(
1516
+ chunk_text=chunk_text,
1517
+ chunk_index=row['chunk_index'],
1518
+ truncated=False
1519
+ )
1520
+ total_chunk_tokens += chunk_tokens
1521
+
1522
+ # Finalize trace if enabled
1523
+ trace_dict = None
1524
+ if tracer:
1525
+ trace = tracer.finalize(top_results_dicts)
1526
+ trace_dict = trace.to_dict() if trace else None
1527
+
1528
+ return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)
1529
+
1530
+ except Exception as e:
1531
+ log_buffer.append(f"[SEARCH {search_id}] ERROR after {time.time() - search_start:.3f}s: {str(e)}")
1532
+ logger.error("\n" + "\n".join(log_buffer))
1533
+ raise Exception(f"Failed to search memories: {str(e)}")
1534
+
1535
+ def _filter_by_token_budget(
1536
+ self,
1537
+ results: List[Dict[str, Any]],
1538
+ max_tokens: int
1539
+ ) -> Tuple[List[Dict[str, Any]], int]:
1540
+ """
1541
+ Filter results to fit within token budget.
1542
+
1543
+ Counts tokens only for the 'text' field using tiktoken (cl100k_base encoding).
1544
+ Stops before including a fact that would exceed the budget.
1545
+
1546
+ Args:
1547
+ results: List of search results
1548
+ max_tokens: Maximum tokens allowed
1549
+
1550
+ Returns:
1551
+ Tuple of (filtered_results, total_tokens_used)
1552
+ """
1553
+ encoding = _get_tiktoken_encoding()
1554
+
1555
+ filtered_results = []
1556
+ total_tokens = 0
1557
+
1558
+ for result in results:
1559
+ text = result.get("text", "")
1560
+ text_tokens = len(encoding.encode(text))
1561
+
1562
+ # Check if adding this result would exceed budget
1563
+ if total_tokens + text_tokens <= max_tokens:
1564
+ filtered_results.append(result)
1565
+ total_tokens += text_tokens
1566
+ else:
1567
+ # Stop before including a fact that would exceed limit
1568
+ break
1569
+
1570
+ return filtered_results, total_tokens
1571
+
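# Illustrative sketch (not part of the packaged module): how the greedy token
# budget in _filter_by_token_budget behaves. The whitespace split below is only a
# stand-in for the tiktoken encoding the method actually uses; the inputs are made up.
def _demo_filter_by_token_budget(results, max_tokens):
    kept, used = [], 0
    for r in results:
        cost = len(r.get("text", "").split())  # stand-in for len(encoding.encode(text))
        if used + cost > max_tokens:
            break  # stop before the first fact that would exceed the budget
        kept.append(r)
        used += cost
    return kept, used

# Example: with a 5-"token" budget only the first two facts fit.
# _demo_filter_by_token_budget(
#     [{"text": "alpha beta"}, {"text": "gamma delta"}, {"text": "epsilon zeta eta"}],
#     max_tokens=5,
# )  -> ([{"text": "alpha beta"}, {"text": "gamma delta"}], 4)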
1572
+ async def get_document(self, document_id: str, bank_id: str) -> Optional[Dict[str, Any]]:
1573
+ """
1574
+ Retrieve document metadata and statistics.
1575
+
1576
+ Args:
1577
+ document_id: Document ID to retrieve
1578
+ bank_id: bank ID that owns the document
1579
+
1580
+ Returns:
1581
+ Dictionary with document info or None if not found
1582
+ """
1583
+ pool = await self._get_pool()
1584
+ async with acquire_with_retry(pool) as conn:
1585
+ doc = await conn.fetchrow(
1586
+ """
1587
+ SELECT d.id, d.bank_id, d.original_text, d.content_hash,
1588
+ d.created_at, d.updated_at, COUNT(mu.id) as unit_count
1589
+ FROM documents d
1590
+ LEFT JOIN memory_units mu ON mu.document_id = d.id
1591
+ WHERE d.id = $1 AND d.bank_id = $2
1592
+ GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
1593
+ """,
1594
+ document_id, bank_id
1595
+ )
1596
+
1597
+ if not doc:
1598
+ return None
1599
+
1600
+ return {
1601
+ "id": doc["id"],
1602
+ "bank_id": doc["bank_id"],
1603
+ "original_text": doc["original_text"],
1604
+ "content_hash": doc["content_hash"],
1605
+ "memory_unit_count": doc["unit_count"],
1606
+ "created_at": doc["created_at"],
1607
+ "updated_at": doc["updated_at"]
1608
+ }
1609
+
1610
+ async def delete_document(self, document_id: str, bank_id: str) -> Dict[str, int]:
1611
+ """
1612
+ Delete a document and all its associated memory units and links.
1613
+
1614
+ Args:
1615
+ document_id: Document ID to delete
1616
+ bank_id: bank ID that owns the document
1617
+
1618
+ Returns:
1619
+ Dictionary with counts of deleted items
1620
+ """
1621
+ pool = await self._get_pool()
1622
+ async with acquire_with_retry(pool) as conn:
1623
+ async with conn.transaction():
1624
+ # Count units before deletion
1625
+ units_count = await conn.fetchval(
1626
+ "SELECT COUNT(*) FROM memory_units WHERE document_id = $1",
1627
+ document_id
1628
+ )
1629
+
1630
+ # Delete document (cascades to memory_units and all their links)
1631
+ deleted = await conn.fetchval(
1632
+ "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id",
1633
+ document_id, bank_id
1634
+ )
1635
+
1636
+ return {
1637
+ "document_deleted": 1 if deleted else 0,
1638
+ "memory_units_deleted": units_count if deleted else 0
1639
+ }
1640
+
1641
+ async def delete_memory_unit(self, unit_id: str) -> Dict[str, Any]:
1642
+ """
1643
+ Delete a single memory unit and all its associated links.
1644
+
1645
+ Due to CASCADE DELETE constraints, this will automatically delete:
1646
+ - All links from this unit (memory_links where from_unit_id = unit_id)
1647
+ - All links to this unit (memory_links where to_unit_id = unit_id)
1648
+ - All entity associations (unit_entities where unit_id = unit_id)
1649
+
1650
+ Args:
1651
+ unit_id: UUID of the memory unit to delete
1652
+
1653
+ Returns:
1654
+ Dictionary with deletion result
1655
+ """
1656
+ pool = await self._get_pool()
1657
+ async with acquire_with_retry(pool) as conn:
1658
+ async with conn.transaction():
1659
+ # Delete the memory unit (cascades to links and associations)
1660
+ deleted = await conn.fetchval(
1661
+ "DELETE FROM memory_units WHERE id = $1 RETURNING id",
1662
+ unit_id
1663
+ )
1664
+
1665
+ return {
1666
+ "success": deleted is not None,
1667
+ "unit_id": str(deleted) if deleted else None,
1668
+ "message": "Memory unit and all its links deleted successfully" if deleted else "Memory unit not found"
1669
+ }
1670
+
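# Illustrative usage sketch (assumes `engine` is an initialized MemoryEngine and that
# the document/unit IDs below exist; they are placeholders, not part of the package).
async def _demo_delete(engine):
    # Deleting a document cascades to its memory units and their links.
    doc_result = await engine.delete_document(document_id="doc-123", bank_id="bank-1")
    # e.g. {"document_deleted": 1, "memory_units_deleted": 42}

    # Deleting a single unit relies on the same CASCADE constraints.
    unit_result = await engine.delete_memory_unit(unit_id="00000000-0000-0000-0000-000000000000")
    # e.g. {"success": False, "unit_id": None, "message": "Memory unit not found"}
    return doc_result, unit_result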
1671
+ async def delete_bank(self, bank_id: str, fact_type: Optional[str] = None) -> Dict[str, int]:
1672
+ """
1673
+ Delete all data for a specific bank (multi-tenant cleanup).
1674
+
1675
+ This is much more efficient than dropping all tables and allows
1676
+ multiple banks to coexist in the same database.
1677
+
1678
+ Deletes (with CASCADE):
1679
+ - All memory units for this bank (optionally filtered by fact_type)
1680
+ - All entities for this bank (if deleting all memory units)
1681
+ - All associated links, unit-entity associations, and co-occurrences
1682
+
1683
+ Args:
1684
+ bank_id: bank ID to delete
1685
+ fact_type: Optional fact type filter (world, bank, opinion). If provided, only deletes memories of that type.
1686
+
1687
+ Returns:
1688
+ Dictionary with counts of deleted items
1689
+ """
1690
+ pool = await self._get_pool()
1691
+ async with acquire_with_retry(pool) as conn:
1692
+ async with conn.transaction():
1693
+ try:
1694
+ if fact_type:
1695
+ # Delete only memories of a specific fact type
1696
+ units_count = await conn.fetchval(
1697
+ "SELECT COUNT(*) FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
1698
+ bank_id, fact_type
1699
+ )
1700
+ await conn.execute(
1701
+ "DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
1702
+ bank_id, fact_type
1703
+ )
1704
+
1705
+ # Note: We don't delete entities when fact_type is specified,
1706
+ # as they may be referenced by other memory units
1707
+ return {
1708
+ "memory_units_deleted": units_count,
1709
+ "entities_deleted": 0
1710
+ }
1711
+ else:
1712
+ # Delete all data for the bank
1713
+ units_count = await conn.fetchval("SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id)
1714
+ entities_count = await conn.fetchval("SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id)
1715
+ documents_count = await conn.fetchval("SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id)
1716
+
1717
+ # Delete documents (cascades to chunks)
1718
+ await conn.execute("DELETE FROM documents WHERE bank_id = $1", bank_id)
1719
+
1720
+ # Delete memory units (cascades to unit_entities, memory_links)
1721
+ await conn.execute("DELETE FROM memory_units WHERE bank_id = $1", bank_id)
1722
+
1723
+ # Delete entities (cascades to unit_entities, entity_cooccurrences, memory_links with entity_id)
1724
+ await conn.execute("DELETE FROM entities WHERE bank_id = $1", bank_id)
1725
+
1726
+ return {
1727
+ "memory_units_deleted": units_count,
1728
+ "entities_deleted": entities_count,
1729
+ "documents_deleted": documents_count
1730
+ }
1731
+
1732
+ except Exception as e:
1733
+ raise Exception(f"Failed to delete agent data: {str(e)}")
1734
+
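# Illustrative usage sketch (assumes an initialized MemoryEngine `engine`): wiping a
# whole bank versus clearing only one fact type. IDs and counts are placeholders.
async def _demo_delete_bank(engine):
    # Remove only opinion memories; entities are kept because other units may reference them.
    opinions_only = await engine.delete_bank("bank-1", fact_type="opinion")
    # e.g. {"memory_units_deleted": 7, "entities_deleted": 0}

    # Remove everything belonging to the bank (documents, units, entities).
    everything = await engine.delete_bank("bank-1")
    # e.g. {"memory_units_deleted": 120, "entities_deleted": 35, "documents_deleted": 4}
    return opinions_only, everything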
1735
+ async def get_graph_data(self, bank_id: Optional[str] = None, fact_type: Optional[str] = None):
1736
+ """
1737
+ Get graph data for visualization.
1738
+
1739
+ Args:
1740
+ bank_id: Filter by bank ID
1741
+ fact_type: Filter by fact type (world, bank, opinion)
1742
+
1743
+ Returns:
1744
+ Dict with nodes, edges, and table_rows
1745
+ """
1746
+ pool = await self._get_pool()
1747
+ async with acquire_with_retry(pool) as conn:
1748
+ # Get memory units, optionally filtered by bank_id and fact_type
1749
+ query_conditions = []
1750
+ query_params = []
1751
+ param_count = 0
1752
+
1753
+ if bank_id:
1754
+ param_count += 1
1755
+ query_conditions.append(f"bank_id = ${param_count}")
1756
+ query_params.append(bank_id)
1757
+
1758
+ if fact_type:
1759
+ param_count += 1
1760
+ query_conditions.append(f"fact_type = ${param_count}")
1761
+ query_params.append(fact_type)
1762
+
1763
+ where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
1764
+
1765
+ units = await conn.fetch(f"""
1766
+ SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
1767
+ FROM memory_units
1768
+ {where_clause}
1769
+ ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
1770
+ LIMIT 1000
1771
+ """, *query_params)
1772
+
1773
+ # Get links, filtering to only include links between units of the selected bank
1774
+ unit_ids = [row['id'] for row in units]
1775
+ if unit_ids:
1776
+ links = await conn.fetch("""
1777
+ SELECT
1778
+ ml.from_unit_id,
1779
+ ml.to_unit_id,
1780
+ ml.link_type,
1781
+ ml.weight,
1782
+ e.canonical_name as entity_name
1783
+ FROM memory_links ml
1784
+ LEFT JOIN entities e ON ml.entity_id = e.id
1785
+ WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
1786
+ ORDER BY ml.link_type, ml.weight DESC
1787
+ """, unit_ids)
1788
+ else:
1789
+ links = []
1790
+
1791
+ # Get entity information
1792
+ unit_entities = await conn.fetch("""
1793
+ SELECT ue.unit_id, e.canonical_name
1794
+ FROM unit_entities ue
1795
+ JOIN entities e ON ue.entity_id = e.id
1796
+ ORDER BY ue.unit_id
1797
+ """)
1798
+
1799
+ # Build entity mapping
1800
+ entity_map = {}
1801
+ for row in unit_entities:
1802
+ unit_id = row['unit_id']
1803
+ entity_name = row['canonical_name']
1804
+ if unit_id not in entity_map:
1805
+ entity_map[unit_id] = []
1806
+ entity_map[unit_id].append(entity_name)
1807
+
1808
+ # Build nodes
1809
+ nodes = []
1810
+ for row in units:
1811
+ unit_id = row['id']
1812
+ text = row['text']
1813
+ event_date = row['event_date']
1814
+ context = row['context']
1815
+
1816
+ entities = entity_map.get(unit_id, [])
1817
+ entity_count = len(entities)
1818
+
1819
+ # Color by entity count
1820
+ if entity_count == 0:
1821
+ color = "#e0e0e0"
1822
+ elif entity_count == 1:
1823
+ color = "#90caf9"
1824
+ else:
1825
+ color = "#42a5f5"
1826
+
1827
+ nodes.append({
1828
+ "data": {
1829
+ "id": str(unit_id),
1830
+ "label": f"{text[:30]}..." if len(text) > 30 else text,
1831
+ "text": text,
1832
+ "date": event_date.isoformat() if event_date else "",
1833
+ "context": context if context else "",
1834
+ "entities": ", ".join(entities) if entities else "None",
1835
+ "color": color
1836
+ }
1837
+ })
1838
+
1839
+ # Build edges
1840
+ edges = []
1841
+ for row in links:
1842
+ from_id = str(row['from_unit_id'])
1843
+ to_id = str(row['to_unit_id'])
1844
+ link_type = row['link_type']
1845
+ weight = row['weight']
1846
+ entity_name = row['entity_name']
1847
+
1848
+ # Color by link type
1849
+ if link_type == 'temporal':
1850
+ color = "#00bcd4"
1851
+ line_style = "dashed"
1852
+ elif link_type == 'semantic':
1853
+ color = "#ff69b4"
1854
+ line_style = "solid"
1855
+ elif link_type == 'entity':
1856
+ color = "#ffd700"
1857
+ line_style = "solid"
1858
+ else:
1859
+ color = "#999999"
1860
+ line_style = "solid"
1861
+
1862
+ edges.append({
1863
+ "data": {
1864
+ "id": f"{from_id}-{to_id}-{link_type}",
1865
+ "source": from_id,
1866
+ "target": to_id,
1867
+ "linkType": link_type,
1868
+ "weight": weight,
1869
+ "entityName": entity_name if entity_name else "",
1870
+ "color": color,
1871
+ "lineStyle": line_style
1872
+ }
1873
+ })
1874
+
1875
+ # Build table rows
1876
+ table_rows = []
1877
+ for row in units:
1878
+ unit_id = row['id']
1879
+ entities = entity_map.get(unit_id, [])
1880
+
1881
+ table_rows.append({
1882
+ "id": str(unit_id),
1883
+ "text": row['text'],
1884
+ "context": row['context'] if row['context'] else "N/A",
1885
+ "occurred_start": row['occurred_start'].isoformat() if row['occurred_start'] else None,
1886
+ "occurred_end": row['occurred_end'].isoformat() if row['occurred_end'] else None,
1887
+ "mentioned_at": row['mentioned_at'].isoformat() if row['mentioned_at'] else None,
1888
+ "date": row['event_date'].strftime("%Y-%m-%d %H:%M") if row['event_date'] else "N/A", # Deprecated, kept for backwards compatibility
1889
+ "entities": ", ".join(entities) if entities else "None",
1890
+ "document_id": row['document_id'],
1891
+ "chunk_id": row['chunk_id'] if row['chunk_id'] else None,
1892
+ "fact_type": row['fact_type']
1893
+ })
1894
+
1895
+ return {
1896
+ "nodes": nodes,
1897
+ "edges": edges,
1898
+ "table_rows": table_rows,
1899
+ "total_units": len(units)
1900
+ }
1901
+
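# Illustrative sketch (assumes an initialized MemoryEngine `engine`): the nodes and
# edges returned by get_graph_data wrap their fields in a "data" dict, so a quick
# summary of a bank's graph might look like this. The bank ID is a placeholder.
from collections import Counter

async def _demo_graph_summary(engine):
    graph = await engine.get_graph_data(bank_id="bank-1", fact_type="world")
    edge_types = Counter(e["data"]["linkType"] for e in graph["edges"])
    multi_entity_nodes = [n["data"]["id"] for n in graph["nodes"] if n["data"]["color"] == "#42a5f5"]
    return {
        "total_units": graph["total_units"],
        "edges_by_type": dict(edge_types),        # e.g. {"temporal": 10, "semantic": 4, "entity": 6}
        "nodes_with_2plus_entities": len(multi_entity_nodes),
    }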
1902
+ async def list_memory_units(
1903
+ self,
1904
+ bank_id: Optional[str] = None,
1905
+ fact_type: Optional[str] = None,
1906
+ search_query: Optional[str] = None,
1907
+ limit: int = 100,
1908
+ offset: int = 0
1909
+ ):
1910
+ """
1911
+ List memory units for table view with optional full-text search.
1912
+
1913
+ Args:
1914
+ bank_id: Filter by bank ID
1915
+ fact_type: Filter by fact type (world, bank, opinion)
1916
+ search_query: Full-text search query (searches text and context fields)
1917
+ limit: Maximum number of results to return
1918
+ offset: Offset for pagination
1919
+
1920
+ Returns:
1921
+ Dict with items (list of memory units) and total count
1922
+ """
1923
+ pool = await self._get_pool()
1924
+ async with acquire_with_retry(pool) as conn:
1925
+ # Build query conditions
1926
+ query_conditions = []
1927
+ query_params = []
1928
+ param_count = 0
1929
+
1930
+ if bank_id:
1931
+ param_count += 1
1932
+ query_conditions.append(f"bank_id = ${param_count}")
1933
+ query_params.append(bank_id)
1934
+
1935
+ if fact_type:
1936
+ param_count += 1
1937
+ query_conditions.append(f"fact_type = ${param_count}")
1938
+ query_params.append(fact_type)
1939
+
1940
+ if search_query:
1941
+ # Full-text search on text and context fields using ILIKE
1942
+ param_count += 1
1943
+ query_conditions.append(f"(text ILIKE ${param_count} OR context ILIKE ${param_count})")
1944
+ query_params.append(f"%{search_query}%")
1945
+
1946
+ where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
1947
+
1948
+ # Get total count
1949
+ count_query = f"""
1950
+ SELECT COUNT(*) as total
1951
+ FROM memory_units
1952
+ {where_clause}
1953
+ """
1954
+ count_result = await conn.fetchrow(count_query, *query_params)
1955
+ total = count_result['total']
1956
+
1957
+ # Get units with limit and offset
1958
+ param_count += 1
1959
+ limit_param = f"${param_count}"
1960
+ query_params.append(limit)
1961
+
1962
+ param_count += 1
1963
+ offset_param = f"${param_count}"
1964
+ query_params.append(offset)
1965
+
1966
+ units = await conn.fetch(f"""
1967
+ SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id
1968
+ FROM memory_units
1969
+ {where_clause}
1970
+ ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
1971
+ LIMIT {limit_param} OFFSET {offset_param}
1972
+ """, *query_params)
1973
+
1974
+ # Get entity information for these units
1975
+ if units:
1976
+ unit_ids = [row['id'] for row in units]
1977
+ unit_entities = await conn.fetch("""
1978
+ SELECT ue.unit_id, e.canonical_name
1979
+ FROM unit_entities ue
1980
+ JOIN entities e ON ue.entity_id = e.id
1981
+ WHERE ue.unit_id = ANY($1::uuid[])
1982
+ ORDER BY ue.unit_id
1983
+ """, unit_ids)
1984
+ else:
1985
+ unit_entities = []
1986
+
1987
+ # Build entity mapping
1988
+ entity_map = {}
1989
+ for row in unit_entities:
1990
+ unit_id = row['unit_id']
1991
+ entity_name = row['canonical_name']
1992
+ if unit_id not in entity_map:
1993
+ entity_map[unit_id] = []
1994
+ entity_map[unit_id].append(entity_name)
1995
+
1996
+ # Build result items
1997
+ items = []
1998
+ for row in units:
1999
+ unit_id = row['id']
2000
+ entities = entity_map.get(unit_id, [])
2001
+
2002
+ items.append({
2003
+ "id": str(unit_id),
2004
+ "text": row['text'],
2005
+ "context": row['context'] if row['context'] else "",
2006
+ "date": row['event_date'].isoformat() if row['event_date'] else "",
2007
+ "fact_type": row['fact_type'],
2008
+ "mentioned_at": row['mentioned_at'].isoformat() if row['mentioned_at'] else None,
2009
+ "occurred_start": row['occurred_start'].isoformat() if row['occurred_start'] else None,
2010
+ "occurred_end": row['occurred_end'].isoformat() if row['occurred_end'] else None,
2011
+ "entities": ", ".join(entities) if entities else "",
2012
+ "chunk_id": row['chunk_id'] if row['chunk_id'] else None
2013
+ })
2014
+
2015
+ return {
2016
+ "items": items,
2017
+ "total": total,
2018
+ "limit": limit,
2019
+ "offset": offset
2020
+ }
2021
+
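# Illustrative usage sketch (assumes an initialized MemoryEngine `engine`): paginating
# through memory units of one bank that mention "coffee". Values are placeholders.
async def _demo_list_units(engine):
    page = await engine.list_memory_units(
        bank_id="bank-1",
        fact_type="world",
        search_query="coffee",   # ILIKE match on text and context
        limit=25,
        offset=0,
    )
    for item in page["items"]:
        print(item["id"], item["fact_type"], item["text"][:60])
    return page["total"]  # total matches across all pages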
2022
+ async def list_documents(
2023
+ self,
2024
+ bank_id: str,
2025
+ search_query: Optional[str] = None,
2026
+ limit: int = 100,
2027
+ offset: int = 0
2028
+ ):
2029
+ """
2030
+ List documents with optional search and pagination.
2031
+
2032
+ Args:
2033
+ bank_id: bank ID (required)
2034
+ search_query: Search in document ID
2035
+ limit: Maximum number of results
2036
+ offset: Offset for pagination
2037
+
2038
+ Returns:
2039
+ Dict with items (list of documents without original_text) and total count
2040
+ """
2041
+ pool = await self._get_pool()
2042
+ async with acquire_with_retry(pool) as conn:
2043
+ # Build query conditions
2044
+ query_conditions = []
2045
+ query_params = []
2046
+ param_count = 0
2047
+
2048
+ param_count += 1
2049
+ query_conditions.append(f"bank_id = ${param_count}")
2050
+ query_params.append(bank_id)
2051
+
2052
+ if search_query:
2053
+ # Search in document ID
2054
+ param_count += 1
2055
+ query_conditions.append(f"id ILIKE ${param_count}")
2056
+ query_params.append(f"%{search_query}%")
2057
+
2058
+ where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
2059
+
2060
+ # Get total count
2061
+ count_query = f"""
2062
+ SELECT COUNT(*) as total
2063
+ FROM documents
2064
+ {where_clause}
2065
+ """
2066
+ count_result = await conn.fetchrow(count_query, *query_params)
2067
+ total = count_result['total']
2068
+
2069
+ # Get documents with limit and offset (without original_text for performance)
2070
+ param_count += 1
2071
+ limit_param = f"${param_count}"
2072
+ query_params.append(limit)
2073
+
2074
+ param_count += 1
2075
+ offset_param = f"${param_count}"
2076
+ query_params.append(offset)
2077
+
2078
+ documents = await conn.fetch(f"""
2079
+ SELECT
2080
+ id,
2081
+ bank_id,
2082
+ content_hash,
2083
+ created_at,
2084
+ updated_at,
2085
+ LENGTH(original_text) as text_length,
2086
+ retain_params
2087
+ FROM documents
2088
+ {where_clause}
2089
+ ORDER BY created_at DESC
2090
+ LIMIT {limit_param} OFFSET {offset_param}
2091
+ """, *query_params)
2092
+
2093
+ # Get memory unit count for each document
2094
+ if documents:
2095
+ doc_ids = [(row['id'], row['bank_id']) for row in documents]
2096
+
2097
+ # Create placeholders for the query
2098
+ placeholders = []
2099
+ params_for_count = []
2100
+ for i, (doc_id, bank_id_val) in enumerate(doc_ids):
2101
+ idx_doc = i * 2 + 1
2102
+ idx_agent = i * 2 + 2
2103
+ placeholders.append(f"(document_id = ${idx_doc} AND bank_id = ${idx_agent})")
2104
+ params_for_count.extend([doc_id, bank_id_val])
2105
+
2106
+ where_clause_count = " OR ".join(placeholders)
2107
+
2108
+ unit_counts = await conn.fetch(f"""
2109
+ SELECT document_id, bank_id, COUNT(*) as unit_count
2110
+ FROM memory_units
2111
+ WHERE {where_clause_count}
2112
+ GROUP BY document_id, bank_id
2113
+ """, *params_for_count)
2114
+ else:
2115
+ unit_counts = []
2116
+
2117
+ # Build count mapping
2118
+ count_map = {(row['document_id'], row['bank_id']): row['unit_count'] for row in unit_counts}
2119
+
2120
+ # Build result items
2121
+ items = []
2122
+ for row in documents:
2123
+ doc_id = row['id']
2124
+ bank_id_val = row['bank_id']
2125
+ unit_count = count_map.get((doc_id, bank_id_val), 0)
2126
+
2127
+ items.append({
2128
+ "id": doc_id,
2129
+ "bank_id": bank_id_val,
2130
+ "content_hash": row['content_hash'],
2131
+ "created_at": row['created_at'].isoformat() if row['created_at'] else "",
2132
+ "updated_at": row['updated_at'].isoformat() if row['updated_at'] else "",
2133
+ "text_length": row['text_length'] or 0,
2134
+ "memory_unit_count": unit_count,
2135
+ "retain_params": row['retain_params'] if row['retain_params'] else None
2136
+ })
2137
+
2138
+ return {
2139
+ "items": items,
2140
+ "total": total,
2141
+ "limit": limit,
2142
+ "offset": offset
2143
+ }
2144
+
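# Illustrative sketch (assumes an initialized MemoryEngine `engine`): listing documents
# returns metadata only; original_text is fetched separately via get_document.
async def _demo_list_documents(engine):
    listing = await engine.list_documents(bank_id="bank-1", search_query="meeting", limit=10)
    summaries = [
        (d["id"], d["text_length"], d["memory_unit_count"])
        for d in listing["items"]
    ]
    return listing["total"], summaries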
2145
+ async def get_document(
2146
+ self,
2147
+ document_id: str,
2148
+ bank_id: str
2149
+ ):
2150
+ """
2151
+ Get a specific document including its original_text.
2152
+
2153
+ Args:
2154
+ document_id: Document ID
2155
+ bank_id: bank ID
2156
+
2157
+ Returns:
2158
+ Dict with document details including original_text, or None if not found
2159
+ """
2160
+ pool = await self._get_pool()
2161
+ async with acquire_with_retry(pool) as conn:
2162
+ doc = await conn.fetchrow("""
2163
+ SELECT
2164
+ id,
2165
+ bank_id,
2166
+ original_text,
2167
+ content_hash,
2168
+ created_at,
2169
+ updated_at,
2170
+ retain_params
2171
+ FROM documents
2172
+ WHERE id = $1 AND bank_id = $2
2173
+ """, document_id, bank_id)
2174
+
2175
+ if not doc:
2176
+ return None
2177
+
2178
+ # Get memory unit count
2179
+ unit_count_row = await conn.fetchrow("""
2180
+ SELECT COUNT(*) as unit_count
2181
+ FROM memory_units
2182
+ WHERE document_id = $1 AND bank_id = $2
2183
+ """, document_id, bank_id)
2184
+
2185
+ return {
2186
+ "id": doc['id'],
2187
+ "bank_id": doc['bank_id'],
2188
+ "original_text": doc['original_text'],
2189
+ "content_hash": doc['content_hash'],
2190
+ "created_at": doc['created_at'].isoformat() if doc['created_at'] else "",
2191
+ "updated_at": doc['updated_at'].isoformat() if doc['updated_at'] else "",
2192
+ "memory_unit_count": unit_count_row['unit_count'] if unit_count_row else 0,
2193
+ "retain_params": doc['retain_params'] if doc['retain_params'] else None
2194
+ }
2195
+
2196
+ async def get_chunk(
2197
+ self,
2198
+ chunk_id: str
2199
+ ):
2200
+ """
2201
+ Get a specific chunk by its ID.
2202
+
2203
+ Args:
2204
+ chunk_id: Chunk ID (format: bank_id_document_id_chunk_index)
2205
+
2206
+ Returns:
2207
+ Dict with chunk details including chunk_text, or None if not found
2208
+ """
2209
+ pool = await self._get_pool()
2210
+ async with acquire_with_retry(pool) as conn:
2211
+ chunk = await conn.fetchrow("""
2212
+ SELECT
2213
+ chunk_id,
2214
+ document_id,
2215
+ bank_id,
2216
+ chunk_index,
2217
+ chunk_text,
2218
+ created_at
2219
+ FROM chunks
2220
+ WHERE chunk_id = $1
2221
+ """, chunk_id)
2222
+
2223
+ if not chunk:
2224
+ return None
2225
+
2226
+ return {
2227
+ "chunk_id": chunk['chunk_id'],
2228
+ "document_id": chunk['document_id'],
2229
+ "bank_id": chunk['bank_id'],
2230
+ "chunk_index": chunk['chunk_index'],
2231
+ "chunk_text": chunk['chunk_text'],
2232
+ "created_at": chunk['created_at'].isoformat() if chunk['created_at'] else ""
2233
+ }
2234
+
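# Illustrative sketch (assumes an initialized MemoryEngine `engine` and existing IDs;
# both IDs are placeholders): fetching a full document and one of its stored chunks.
async def _demo_fetch_document_and_chunk(engine):
    doc = await engine.get_document(document_id="doc-123", bank_id="bank-1")
    if doc is None:
        return None
    # chunk_id follows the bank_id_document_id_chunk_index format described above
    chunk = await engine.get_chunk(chunk_id="bank-1_doc-123_0")
    return doc["original_text"], (chunk["chunk_text"] if chunk else None)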
2235
+ async def _evaluate_opinion_update_async(
2236
+ self,
2237
+ opinion_text: str,
2238
+ opinion_confidence: float,
2239
+ new_event_text: str,
2240
+ entity_name: str,
2241
+ ) -> Optional[Dict[str, Any]]:
2242
+ """
2243
+ Evaluate if an opinion should be updated based on a new event.
2244
+
2245
+ Args:
2246
+ opinion_text: Current opinion text (includes reasons)
2247
+ opinion_confidence: Current confidence score (0.0-1.0)
2248
+ new_event_text: Text of the new event
2249
+ entity_name: Name of the entity this opinion is about
2250
+
2251
+ Returns:
2252
+ Dict with 'action' ('keep'|'update'), 'new_confidence', 'new_text' (if action=='update')
2253
+ or None if no changes needed
2254
+ """
2255
+ from pydantic import BaseModel, Field
2256
+
2257
+ class OpinionEvaluation(BaseModel):
2258
+ """Evaluation of whether an opinion should be updated."""
2259
+ action: str = Field(description="Action to take: 'keep' (no change) or 'update' (modify opinion)")
2260
+ reasoning: str = Field(description="Brief explanation of why this action was chosen")
2261
+ new_confidence: float = Field(description="New confidence score (0.0-1.0). Can be higher, lower, or same as before.")
2262
+ new_opinion_text: Optional[str] = Field(
2263
+ default=None,
2264
+ description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None."
2265
+ )
2266
+
2267
+ evaluation_prompt = f"""You are evaluating whether an existing opinion should be updated based on new information.
2268
+
2269
+ ENTITY: {entity_name}
2270
+
2271
+ EXISTING OPINION:
2272
+ {opinion_text}
2273
+ Current confidence: {opinion_confidence:.2f}
2274
+
2275
+ NEW EVENT:
2276
+ {new_event_text}
2277
+
2278
+ Evaluate whether this new event:
2279
+ 1. REINFORCES the opinion (increase confidence, keep text)
2280
+ 2. WEAKENS the opinion (decrease confidence, keep text)
2281
+ 3. CHANGES the opinion (update both text and confidence, noting "Previously I thought X, but now Y...")
2282
+ 4. IRRELEVANT (keep everything as is)
2283
+
2284
+ Guidelines:
2285
+ - Only suggest 'update' action if the new event genuinely contradicts or significantly modifies the opinion
2286
+ - If updating the text, acknowledge the previous opinion and explain the change
2287
+ - Confidence should reflect accumulated evidence (0.0 = no confidence, 1.0 = very confident)
2288
+ - Small changes in confidence are normal; large jumps should be rare"""
2289
+
2290
+ try:
2291
+ result = await self._llm_config.call(
2292
+ messages=[
2293
+ {"role": "system", "content": "You evaluate and update opinions based on new information."},
2294
+ {"role": "user", "content": evaluation_prompt}
2295
+ ],
2296
+ response_format=OpinionEvaluation,
2297
+ scope="memory_evaluate_opinion",
2298
+ temperature=0.3 # Lower temperature for more consistent evaluation
2299
+ )
2300
+
2301
+ # Only return updates if something actually changed
2302
+ if result.action == 'keep' and abs(result.new_confidence - opinion_confidence) < 0.01:
2303
+ return None
2304
+
2305
+ return {
2306
+ 'action': result.action,
2307
+ 'reasoning': result.reasoning,
2308
+ 'new_confidence': result.new_confidence,
2309
+ 'new_text': result.new_opinion_text if result.action == 'update' else None
2310
+ }
2311
+
2312
+ except Exception as e:
2313
+ logger.warning(f"Failed to evaluate opinion update: {str(e)}")
2314
+ return None
2315
+
2316
+ async def _handle_form_opinion(self, task_dict: Dict[str, Any]):
2317
+ """
2318
+ Handler for form opinion tasks.
2319
+
2320
+ Args:
2321
+ task_dict: Dict with keys: 'bank_id', 'answer_text', 'query'
2322
+ """
2323
+ bank_id = task_dict['bank_id']
2324
+ answer_text = task_dict['answer_text']
2325
+ query = task_dict['query']
2326
+
2327
+ await self._extract_and_store_opinions_async(
2328
+ bank_id=bank_id,
2329
+ answer_text=answer_text,
2330
+ query=query
2331
+ )
2332
+
2333
+ async def _handle_reinforce_opinion(self, task_dict: Dict[str, Any]):
2334
+ """
2335
+ Handler for reinforce opinion tasks.
2336
+
2337
+ Args:
2338
+ task_dict: Dict with keys: 'bank_id', 'created_unit_ids', 'unit_texts', 'unit_entities'
2339
+ """
2340
+ bank_id = task_dict['bank_id']
2341
+ created_unit_ids = task_dict['created_unit_ids']
2342
+ unit_texts = task_dict['unit_texts']
2343
+ unit_entities = task_dict['unit_entities']
2344
+
2345
+ await self._reinforce_opinions_async(
2346
+ bank_id=bank_id,
2347
+ created_unit_ids=created_unit_ids,
2348
+ unit_texts=unit_texts,
2349
+ unit_entities=unit_entities
2350
+ )
2351
+
2352
+ async def _reinforce_opinions_async(
2353
+ self,
2354
+ bank_id: str,
2355
+ created_unit_ids: List[str],
2356
+ unit_texts: List[str],
2357
+ unit_entities: List[List[Dict[str, str]]],
2358
+ ):
2359
+ """
2360
+ Background task to reinforce opinions based on newly ingested events.
2361
+
2362
+ This runs asynchronously and does not block the put operation.
2363
+
2364
+ Args:
2365
+ bank_id: bank ID
2366
+ created_unit_ids: List of newly created memory unit IDs
2367
+ unit_texts: Texts of the newly created units
2368
+ unit_entities: Entities extracted from each unit
2369
+ """
2370
+ try:
2371
+ # Extract all unique entity names from the new units
2372
+ entity_names = set()
2373
+ for entities_list in unit_entities:
2374
+ for entity in entities_list:
2375
+ # Handle both Entity objects and dicts
2376
+ if hasattr(entity, 'text'):
2377
+ entity_names.add(entity.text)
2378
+ elif isinstance(entity, dict):
2379
+ entity_names.add(entity['text'])
2380
+
2381
+ if not entity_names:
2382
+ return
2383
+
2384
+
2385
+ pool = await self._get_pool()
2386
+ async with acquire_with_retry(pool) as conn:
2387
+ # Find all opinions related to these entities
2388
+ opinions = await conn.fetch(
2389
+ """
2390
+ SELECT DISTINCT mu.id, mu.text, mu.confidence_score, e.canonical_name
2391
+ FROM memory_units mu
2392
+ JOIN unit_entities ue ON mu.id = ue.unit_id
2393
+ JOIN entities e ON ue.entity_id = e.id
2394
+ WHERE mu.bank_id = $1
2395
+ AND mu.fact_type = 'opinion'
2396
+ AND e.canonical_name = ANY($2::text[])
2397
+ """,
2398
+ bank_id,
2399
+ list(entity_names)
2400
+ )
2401
+
2402
+ if not opinions:
2403
+ return
2404
+
2405
+
2406
+ # Use cached LLM config
2407
+ if self._llm_config is None:
2408
+ logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
2409
+ return
2410
+
2411
+ # Evaluate each opinion against the new events
2412
+ updates_to_apply = []
2413
+ for opinion in opinions:
2414
+ opinion_id = str(opinion['id'])
2415
+ opinion_text = opinion['text']
2416
+ opinion_confidence = opinion['confidence_score']
2417
+ entity_name = opinion['canonical_name']
2418
+
2419
+ # Find all new events mentioning this entity
2420
+ relevant_events = []
2421
+ for unit_text, entities_list in zip(unit_texts, unit_entities):
2422
+ if any(e['text'] == entity_name for e in entities_list):
2423
+ relevant_events.append(unit_text)
2424
+
2425
+ if not relevant_events:
2426
+ continue
2427
+
2428
+ # Combine all relevant events
2429
+ combined_events = "\n".join(relevant_events)
2430
+
2431
+ # Evaluate if opinion should be updated
2432
+ evaluation = await self._evaluate_opinion_update_async(
2433
+ opinion_text,
2434
+ opinion_confidence,
2435
+ combined_events,
2436
+ entity_name
2437
+ )
2438
+
2439
+ if evaluation:
2440
+ updates_to_apply.append({
2441
+ 'opinion_id': opinion_id,
2442
+ 'evaluation': evaluation
2443
+ })
2444
+
2445
+ # Apply all updates in a single transaction
2446
+ if updates_to_apply:
2447
+ async with conn.transaction():
2448
+ for update in updates_to_apply:
2449
+ opinion_id = update['opinion_id']
2450
+ evaluation = update['evaluation']
2451
+
2452
+ if evaluation['action'] == 'update' and evaluation['new_text']:
2453
+ # Update both text and confidence
2454
+ await conn.execute(
2455
+ """
2456
+ UPDATE memory_units
2457
+ SET text = $1, confidence_score = $2, updated_at = NOW()
2458
+ WHERE id = $3
2459
+ """,
2460
+ evaluation['new_text'],
2461
+ evaluation['new_confidence'],
2462
+ uuid.UUID(opinion_id)
2463
+ )
2464
+ else:
2465
+ # Only update confidence
2466
+ await conn.execute(
2467
+ """
2468
+ UPDATE memory_units
2469
+ SET confidence_score = $1, updated_at = NOW()
2470
+ WHERE id = $2
2471
+ """,
2472
+ evaluation['new_confidence'],
2473
+ uuid.UUID(opinion_id)
2474
+ )
2475
+
2476
+ else:
2477
+ pass # No opinions to update
2478
+
2479
+ except Exception as e:
2480
+ logger.error(f"[REINFORCE] Error during opinion reinforcement: {str(e)}")
2481
+ import traceback
2482
+ traceback.print_exc()
2483
+
2484
+ # ==================== bank profile Methods ====================
2485
+
2486
+ async def get_bank_profile(self, bank_id: str) -> "bank_utils.BankProfile":
2487
+ """
2488
+ Get bank profile (name, personality + background).
2489
+ Auto-creates the bank with default values if it does not exist.
2490
+
2491
+ Args:
2492
+ bank_id: Bank identifier
2493
+
2494
+ Returns:
2495
+ BankProfile with name, typed PersonalityTraits, and background
2496
+ """
2497
+ pool = await self._get_pool()
2498
+ return await bank_utils.get_bank_profile(pool, bank_id)
2499
+
2500
+ async def update_bank_personality(
2501
+ self,
2502
+ bank_id: str,
2503
+ personality: Dict[str, float]
2504
+ ) -> None:
2505
+ """
2506
+ Update bank personality traits.
2507
+
2508
+ Args:
2509
+ bank_id: Bank identifier
2510
+ personality: Dict with Big Five traits + bias_strength (all 0-1)
2511
+ """
2512
+ pool = await self._get_pool()
2513
+ await bank_utils.update_bank_personality(pool, bank_id, personality)
2514
+
2515
+ async def merge_bank_background(
2516
+ self,
2517
+ bank_id: str,
2518
+ new_info: str,
2519
+ update_personality: bool = True
2520
+ ) -> dict:
2521
+ """
2522
+ Merge new background information with existing background using LLM.
2523
+ Normalizes to first person ("I") and resolves conflicts.
2524
+ Optionally infers personality traits from the merged background.
2525
+
2526
+ Args:
2527
+ bank_id: Bank identifier
2528
+ new_info: New background information to add/merge
2529
+ update_personality: If True, infer Big Five traits from background (default: True)
2530
+
2531
+ Returns:
2532
+ Dict with 'background' (str) and optionally 'personality' (dict) keys
2533
+ """
2534
+ pool = await self._get_pool()
2535
+ return await bank_utils.merge_bank_background(
2536
+ pool, self._llm_config, bank_id, new_info, update_personality
2537
+ )
2538
+
2539
+ async def list_banks(self) -> list:
2540
+ """
2541
+ List all banks in the system.
2542
+
2543
+ Returns:
2544
+ List of dicts with bank_id, name, personality, background, created_at, updated_at
2545
+ """
2546
+ pool = await self._get_pool()
2547
+ return await bank_utils.list_banks(pool)
2548
+
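# Illustrative sketch (assumes an initialized MemoryEngine `engine`): reading a bank
# profile, merging in new background, and updating personality traits. The trait key
# names shown here are assumptions; see bank_utils for the canonical schema.
async def _demo_bank_profile(engine):
    profile = await engine.get_bank_profile("bank-1")   # auto-created if missing
    merged = await engine.merge_bank_background(
        "bank-1",
        new_info="I spent five years as a field geologist.",
        update_personality=True,
    )
    # merged["background"] is the first-person merged text; merged may also carry
    # inferred Big Five traits under "personality".
    await engine.update_bank_personality("bank-1", {
        "openness": 0.8, "conscientiousness": 0.7, "extraversion": 0.4,   # assumed key names
        "agreeableness": 0.6, "neuroticism": 0.3, "bias_strength": 0.5,
    })
    return profile, merged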
2549
+ # ==================== Reflect Methods ====================
2550
+
2551
+ async def reflect_async(
2552
+ self,
2553
+ bank_id: str,
2554
+ query: str,
2555
+ budget: Budget = Budget.LOW,
2556
+ context: str = None,
2557
+ ) -> ReflectResult:
2558
+ """
2559
+ Reflect and formulate an answer using bank identity, world facts, and opinions.
2560
+
2561
+ This method:
2562
+ 1. Retrieves agent facts (bank's identity and past actions)
2563
+ 2. Retrieves world facts (general knowledge)
2564
+ 3. Retrieves existing opinions (bank's formed perspectives)
2565
+ 4. Uses LLM to formulate an answer
2566
+ 5. Extracts and stores any new opinions formed during reflection
2567
+ 6. Returns plain text answer and the facts used
2568
+
2569
+ Args:
2570
+ bank_id: bank identifier
2571
+ query: Question to answer
2572
+ budget: Budget level for memory exploration (low=100, mid=300, high=600 units)
2573
+ context: Additional context string to include in LLM prompt (not used in recall)
2574
+
2575
+ Returns:
2576
+ ReflectResult containing:
2577
+ - text: Plain text answer (no markdown)
2578
+ - based_on: Dict with 'world', 'agent', and 'opinion' fact lists (MemoryFact objects)
2579
+ - new_opinions: List of newly formed opinions
2580
+ """
2581
+ # Use cached LLM config
2582
+ if self._llm_config is None:
2583
+ raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
2584
+
2585
+ # Steps 1-3: Run multi-fact-type search (12-way retrieval: 4 methods × 3 fact types)
2586
+ search_result = await self.recall_async(
2587
+ bank_id=bank_id,
2588
+ query=query,
2589
+ budget=budget,
2590
+ max_tokens=4096,
2591
+ enable_trace=False,
2592
+ fact_type=['agent', 'world', 'opinion'],
2593
+ include_entities=True
2594
+ )
2595
+
2596
+ all_results = search_result.results
2597
+ logger.info(f"[THINK] Search returned {len(all_results)} results")
2598
+
2599
+ # Split results by fact type for structured response
2600
+ agent_results = [r for r in all_results if r.fact_type == 'bank']
2601
+ world_results = [r for r in all_results if r.fact_type == 'world']
2602
+ opinion_results = [r for r in all_results if r.fact_type == 'opinion']
2603
+
2604
+ logger.info(f"[THINK] Split results - agent: {len(agent_results)}, world: {len(world_results)}, opinion: {len(opinion_results)}")
2605
+
2606
+ # Format facts for LLM
2607
+ agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
2608
+ world_facts_text = think_utils.format_facts_for_prompt(world_results)
2609
+ opinion_facts_text = think_utils.format_facts_for_prompt(opinion_results)
2610
+
2611
+ logger.info(f"[THINK] Formatted facts - agent: {len(agent_facts_text)} chars, world: {len(world_facts_text)} chars, opinion: {len(opinion_facts_text)} chars")
2612
+
2613
+ # Get bank profile (name, personality + background)
2614
+ profile = await self.get_bank_profile(bank_id)
2615
+ name = profile["name"]
2616
+ personality = profile["personality"] # Typed as PersonalityTraits
2617
+ background = profile["background"]
2618
+
2619
+ # Build the prompt
2620
+ prompt = think_utils.build_think_prompt(
2621
+ agent_facts_text=agent_facts_text,
2622
+ world_facts_text=world_facts_text,
2623
+ opinion_facts_text=opinion_facts_text,
2624
+ query=query,
2625
+ name=name,
2626
+ personality=personality,
2627
+ background=background,
2628
+ context=context,
2629
+ )
2630
+
2631
+ logger.info(f"[THINK] Full prompt length: {len(prompt)} chars")
2632
+
2633
+ system_message = think_utils.get_system_message(personality)
2634
+
2635
+ answer_text = await self._llm_config.call(
2636
+ messages=[
2637
+ {"role": "system", "content": system_message},
2638
+ {"role": "user", "content": prompt}
2639
+ ],
2640
+ scope="memory_think",
2641
+ temperature=0.9,
2642
+ max_tokens=1000
2643
+ )
2644
+
2645
+ answer_text = answer_text.strip()
2646
+
2647
+ # Submit form_opinion task for background processing
2648
+ await self._task_backend.submit_task({
2649
+ 'type': 'form_opinion',
2650
+ 'bank_id': bank_id,
2651
+ 'answer_text': answer_text,
2652
+ 'query': query
2653
+ })
2654
+
2655
+ # Return response with facts split by type
2656
+ return ReflectResult(
2657
+ text=answer_text,
2658
+ based_on={
2659
+ "world": world_results,
2660
+ "agent": agent_results,
2661
+ "opinion": opinion_results
2662
+ },
2663
+ new_opinions=[] # Opinions are being extracted asynchronously
2664
+ )
2665
+
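# Illustrative usage sketch (assumes an initialized MemoryEngine `engine`): reflecting
# on a question and inspecting which facts the answer leaned on. The query and context
# strings are placeholders; budget defaults to Budget.LOW.
async def _demo_reflect(engine):
    result = await engine.reflect_async(
        bank_id="bank-1",
        query="Should we expand into the EU market?",
        context="Quarterly planning discussion",   # extra prompt context; not used in recall
    )
    print(result.text)                              # plain-text answer
    for kind in ("world", "agent", "opinion"):
        print(kind, [fact.id for fact in result.based_on[kind]])
    return result                                   # new_opinions is filled in asynchronously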
2666
+ async def _extract_and_store_opinions_async(
2667
+ self,
2668
+ bank_id: str,
2669
+ answer_text: str,
2670
+ query: str
2671
+ ):
2672
+ """
2673
+ Background task to extract and store opinions from think response.
2674
+
2675
+ This runs asynchronously and does not block the think response.
2676
+
2677
+ Args:
2678
+ bank_id: Bank identifier
2679
+ answer_text: The generated answer text
2680
+ query: The original query
2681
+ """
2682
+ try:
2683
+ # Extract opinions from the answer
2684
+ new_opinions = await think_utils.extract_opinions_from_text(
2685
+ self._llm_config, text=answer_text, query=query
2686
+ )
2687
+
2688
+ # Store new opinions
2689
+ if new_opinions:
2690
+ from datetime import datetime, timezone
2691
+ current_time = datetime.now(timezone.utc)
2692
+ for opinion in new_opinions:
2693
+ await self.retain_async(
2694
+ bank_id=bank_id,
2695
+ content=opinion.opinion,
2696
+ context=f"formed during thinking about: {query}",
2697
+ event_date=current_time,
2698
+ fact_type_override='opinion',
2699
+ confidence_score=opinion.confidence
2700
+ )
2701
+
2702
+ except Exception as e:
2703
+ logger.warning(f"[THINK] Failed to extract/store opinions: {str(e)}")
2704
+
2705
+ async def get_entity_observations(
2706
+ self,
2707
+ bank_id: str,
2708
+ entity_id: str,
2709
+ limit: int = 10
2710
+ ) -> List[EntityObservation]:
2711
+ """
2712
+ Get observations linked to an entity.
2713
+
2714
+ Args:
2715
+ bank_id: Bank identifier
2716
+ entity_id: Entity UUID to get observations for
2717
+ limit: Maximum number of observations to return
2718
+
2719
+ Returns:
2720
+ List of EntityObservation objects
2721
+ """
2722
+ pool = await self._get_pool()
2723
+ async with acquire_with_retry(pool) as conn:
2724
+ rows = await conn.fetch(
2725
+ """
2726
+ SELECT mu.text, mu.mentioned_at
2727
+ FROM memory_units mu
2728
+ JOIN unit_entities ue ON mu.id = ue.unit_id
2729
+ WHERE mu.bank_id = $1
2730
+ AND mu.fact_type = 'observation'
2731
+ AND ue.entity_id = $2
2732
+ ORDER BY mu.mentioned_at DESC
2733
+ LIMIT $3
2734
+ """,
2735
+ bank_id, uuid.UUID(entity_id), limit
2736
+ )
2737
+
2738
+ observations = []
2739
+ for row in rows:
2740
+ mentioned_at = row['mentioned_at'].isoformat() if row['mentioned_at'] else None
2741
+ observations.append(EntityObservation(
2742
+ text=row['text'],
2743
+ mentioned_at=mentioned_at
2744
+ ))
2745
+ return observations
2746
+
2747
+ async def list_entities(
2748
+ self,
2749
+ bank_id: str,
2750
+ limit: int = 100
2751
+ ) -> List[Dict[str, Any]]:
2752
+ """
2753
+ List all entities for a bank.
2754
+
2755
+ Args:
2756
+ bank_id: Bank identifier
2757
+ limit: Maximum number of entities to return
2758
+
2759
+ Returns:
2760
+ List of entity dicts with id, canonical_name, mention_count, first_seen, last_seen
2761
+ """
2762
+ pool = await self._get_pool()
2763
+ async with acquire_with_retry(pool) as conn:
2764
+ rows = await conn.fetch(
2765
+ """
2766
+ SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
2767
+ FROM entities
2768
+ WHERE bank_id = $1
2769
+ ORDER BY mention_count DESC, last_seen DESC
2770
+ LIMIT $2
2771
+ """,
2772
+ bank_id, limit
2773
+ )
2774
+
2775
+ entities = []
2776
+ for row in rows:
2777
+ # Handle metadata - may be dict, JSON string, or None
2778
+ metadata = row['metadata']
2779
+ if metadata is None:
2780
+ metadata = {}
2781
+ elif isinstance(metadata, str):
2782
+ import json
2783
+ try:
2784
+ metadata = json.loads(metadata)
2785
+ except json.JSONDecodeError:
2786
+ metadata = {}
2787
+
2788
+ entities.append({
2789
+ 'id': str(row['id']),
2790
+ 'canonical_name': row['canonical_name'],
2791
+ 'mention_count': row['mention_count'],
2792
+ 'first_seen': row['first_seen'].isoformat() if row['first_seen'] else None,
2793
+ 'last_seen': row['last_seen'].isoformat() if row['last_seen'] else None,
2794
+ 'metadata': metadata
2795
+ })
2796
+ return entities
2797
+
2798
+ async def get_entity_state(
2799
+ self,
2800
+ bank_id: str,
2801
+ entity_id: str,
2802
+ entity_name: str,
2803
+ limit: int = 10
2804
+ ) -> EntityState:
2805
+ """
2806
+ Get the current state (mental model) of an entity.
2807
+
2808
+ Args:
2809
+ bank_id: Bank identifier
2810
+ entity_id: Entity UUID
2811
+ entity_name: Canonical name of the entity
2812
+ limit: Maximum number of observations to include
2813
+
2814
+ Returns:
2815
+ EntityState with observations
2816
+ """
2817
+ observations = await self.get_entity_observations(bank_id, entity_id, limit)
2818
+ return EntityState(
2819
+ entity_id=entity_id,
2820
+ canonical_name=entity_name,
2821
+ observations=observations
2822
+ )
2823
+
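# Illustrative sketch (assumes an initialized MemoryEngine `engine`): walking the
# most-mentioned entities of a bank and pulling their synthesized observations.
async def _demo_entity_states(engine):
    states = {}
    for entity in await engine.list_entities("bank-1", limit=5):
        state = await engine.get_entity_state(
            bank_id="bank-1",
            entity_id=entity["id"],
            entity_name=entity["canonical_name"],
            limit=10,
        )
        states[entity["canonical_name"]] = [obs.text for obs in state.observations]
    return states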
2824
+ async def regenerate_entity_observations(
2825
+ self,
2826
+ bank_id: str,
2827
+ entity_id: str,
2828
+ entity_name: str,
2829
+ version: str | None = None
2830
+ ) -> List[str]:
2831
+ """
2832
+ Regenerate observations for an entity by:
2833
+ 1. Checking version for deduplication (if provided)
2834
+ 2. Searching all facts mentioning the entity
2835
+ 3. Using LLM to synthesize observations (no personality)
2836
+ 4. Deleting old observations for this entity
2837
+ 5. Storing new observations linked to the entity
2838
+
2839
+ Args:
2840
+ bank_id: Bank identifier
2841
+ entity_id: Entity UUID
2842
+ entity_name: Canonical name of the entity
2843
+ version: Entity's last_seen timestamp when task was created (for deduplication)
2844
+
2845
+ Returns:
2846
+ List of created observation IDs
2847
+ """
2848
+ pool = await self._get_pool()
2849
+
2850
+ # Step 1: Check version for deduplication
2851
+ if version:
2852
+ async with acquire_with_retry(pool) as conn:
2853
+ current_last_seen = await conn.fetchval(
2854
+ """
2855
+ SELECT last_seen
2856
+ FROM entities
2857
+ WHERE id = $1 AND bank_id = $2
2858
+ """,
2859
+ uuid.UUID(entity_id), bank_id
2860
+ )
2861
+
2862
+ if current_last_seen and current_last_seen.isoformat() != version:
2863
+ return []
2864
+
2865
+ # Step 2: Get all facts mentioning this entity (exclude observations themselves)
2866
+ async with acquire_with_retry(pool) as conn:
2867
+ rows = await conn.fetch(
2868
+ """
2869
+ SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
2870
+ FROM memory_units mu
2871
+ JOIN unit_entities ue ON mu.id = ue.unit_id
2872
+ WHERE mu.bank_id = $1
2873
+ AND ue.entity_id = $2
2874
+ AND mu.fact_type IN ('world', 'agent')
2875
+ ORDER BY mu.occurred_start DESC
2876
+ LIMIT 50
2877
+ """,
2878
+ bank_id, uuid.UUID(entity_id)
2879
+ )
2880
+
2881
+ if not rows:
2882
+ return []
2883
+
2884
+ # Convert to MemoryFact objects for the observation extraction
2885
+ facts = []
2886
+ for row in rows:
2887
+ occurred_start = row['occurred_start'].isoformat() if row['occurred_start'] else None
2888
+ facts.append(MemoryFact(
2889
+ id=str(row['id']),
2890
+ text=row['text'],
2891
+ fact_type=row['fact_type'],
2892
+ context=row['context'],
2893
+ occurred_start=occurred_start
2894
+ ))
2895
+
2896
+ # Step 3: Extract observations using LLM (no personality)
2897
+ observations = await observation_utils.extract_observations_from_facts(
2898
+ self._llm_config,
2899
+ entity_name,
2900
+ facts
2901
+ )
2902
+
2903
+ if not observations:
2904
+ return []
2905
+
2906
+ # Step 4: Delete old observations and insert new ones in a transaction
2907
+ async with acquire_with_retry(pool) as conn:
2908
+ async with conn.transaction():
2909
+ # Delete old observations for this entity
2910
+ await conn.execute(
2911
+ """
2912
+ DELETE FROM memory_units
2913
+ WHERE id IN (
2914
+ SELECT mu.id
2915
+ FROM memory_units mu
2916
+ JOIN unit_entities ue ON mu.id = ue.unit_id
2917
+ WHERE mu.bank_id = $1
2918
+ AND mu.fact_type = 'observation'
2919
+ AND ue.entity_id = $2
2920
+ )
2921
+ """,
2922
+ bank_id, uuid.UUID(entity_id)
2923
+ )
2924
+
2925
+ # Generate embeddings for new observations
2926
+ embeddings = await embedding_utils.generate_embeddings_batch(
2927
+ self.embeddings, observations
2928
+ )
2929
+
2930
+ # Insert new observations
2931
+ current_time = utcnow()
2932
+ created_ids = []
2933
+
2934
+ for obs_text, embedding in zip(observations, embeddings):
2935
+ result = await conn.fetchrow(
2936
+ """
2937
+ INSERT INTO memory_units (
2938
+ bank_id, text, embedding, context, event_date,
2939
+ occurred_start, occurred_end, mentioned_at,
2940
+ fact_type, access_count
2941
+ )
2942
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
2943
+ RETURNING id
2944
+ """,
2945
+ bank_id,
2946
+ obs_text,
2947
+ str(embedding),
2948
+ f"observation about {entity_name}",
2949
+ current_time,
2950
+ current_time,
2951
+ current_time,
2952
+ current_time
2953
+ )
2954
+ obs_id = str(result['id'])
2955
+ created_ids.append(obs_id)
2956
+
2957
+ # Link observation to entity
2958
+ await conn.execute(
2959
+ """
2960
+ INSERT INTO unit_entities (unit_id, entity_id)
2961
+ VALUES ($1, $2)
2962
+ """,
2963
+ uuid.UUID(obs_id), uuid.UUID(entity_id)
2964
+ )
2965
+
2966
+ # Single consolidated log line
2967
+ logger.info(f"[OBSERVATIONS] {entity_name}: {len(facts)} facts -> {len(created_ids)} observations")
2968
+ return created_ids
2969
+
2970
+ async def _regenerate_observations_sync(
2971
+ self,
2972
+ bank_id: str,
2973
+ entity_ids: List[str],
2974
+ min_facts: int = 5
2975
+ ) -> None:
2976
+ """
2977
+ Regenerate observations for entities synchronously (called during retain).
2978
+
2979
+ Args:
2980
+ bank_id: Bank identifier
2981
+ entity_ids: List of entity IDs to process
2982
+ min_facts: Minimum facts required to regenerate observations
2983
+ """
2984
+ if not bank_id or not entity_ids:
2985
+ return
2986
+
2987
+ pool = await self._get_pool()
2988
+ async with pool.acquire() as conn:
2989
+ for entity_id in entity_ids:
2990
+ try:
2991
+ entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
2992
+
2993
+ # Check if entity exists
2994
+ entity_exists = await conn.fetchrow(
2995
+ "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
2996
+ entity_uuid, bank_id
2997
+ )
2998
+
2999
+ if not entity_exists:
3000
+ logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
3001
+ continue
3002
+
3003
+ entity_name = entity_exists['canonical_name']
3004
+
3005
+ # Count facts linked to this entity
3006
+ fact_count = await conn.fetchval(
3007
+ "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1",
3008
+ entity_uuid
3009
+ ) or 0
3010
+
3011
+ # Only regenerate if entity has enough facts
3012
+ if fact_count >= min_facts:
3013
+ await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
3014
+ else:
3015
+ logger.debug(f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)")
3016
+
3017
+ except Exception as e:
3018
+ logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
3019
+ continue
3020
+
3021
+ async def _handle_regenerate_observations(self, task_dict: Dict[str, Any]):
3022
+ """
3023
+ Handler for regenerate_observations tasks.
3024
+
3025
+ Args:
3026
+ task_dict: Dict with 'bank_id' and either:
3027
+ - 'entity_ids' (list): Process multiple entities
3028
+ - 'entity_id', 'entity_name': Process single entity (legacy)
3029
+ """
3030
+ try:
3031
+ bank_id = task_dict.get('bank_id')
3032
+
3033
+ # New format: multiple entity_ids
3034
+ if 'entity_ids' in task_dict:
3035
+ entity_ids = task_dict.get('entity_ids', [])
3036
+ min_facts = task_dict.get('min_facts', 5)
3037
+
3038
+ if not bank_id or not entity_ids:
3039
+ logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
3040
+ return
3041
+
3042
+ # Process each entity
3043
+ pool = await self._get_pool()
3044
+ async with pool.acquire() as conn:
3045
+ for entity_id in entity_ids:
3046
+ try:
3047
+ # Fetch entity name and check fact count
3048
+ import uuid as uuid_module
3049
+ entity_uuid = uuid_module.UUID(entity_id) if isinstance(entity_id, str) else entity_id
3050
+
3051
+ # First check if entity exists
3052
+ entity_exists = await conn.fetchrow(
3053
+ "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
3054
+ entity_uuid, bank_id
3055
+ )
3056
+
3057
+ if not entity_exists:
3058
+ logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
3059
+ continue
3060
+
3061
+ entity_name = entity_exists['canonical_name']
3062
+
3063
+ # Count facts linked to this entity
3064
+ fact_count = await conn.fetchval(
3065
+ "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1",
3066
+ entity_uuid
3067
+ ) or 0
3068
+
3069
+ # Only regenerate if entity has enough facts
3070
+ if fact_count >= min_facts:
3071
+ await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
3072
+ else:
3073
+ logger.debug(f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)")
3074
+
3075
+ except Exception as e:
3076
+ logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
3077
+ continue
3078
+
3079
+ # Legacy format: single entity
3080
+ else:
3081
+ entity_id = task_dict.get('entity_id')
3082
+ entity_name = task_dict.get('entity_name')
3083
+ version = task_dict.get('version')
3084
+
3085
+ if not all([bank_id, entity_id, entity_name]):
3086
+ logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
3087
+ return
3088
+
3089
+ await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version)
3090
+
3091
+ except Exception as e:
3092
+ logger.error(f"[OBSERVATIONS] Error regenerating observations: {e}")
3093
+ import traceback
3094
+ traceback.print_exc()
3095
+
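# Illustrative sketch (assumes an initialized MemoryEngine `engine`): the payload shape
# this handler consumes when queued through the engine's task backend. The
# 'regenerate_observations' type string is assumed from the handler name, and the
# entity UUID is a placeholder; the other keys mirror what the handler reads above.
async def _demo_queue_regeneration(engine):
    await engine._task_backend.submit_task({
        'type': 'regenerate_observations',
        'bank_id': 'bank-1',
        'entity_ids': ['9b2f6c3e-0000-0000-0000-000000000000'],
        'min_facts': 5,  # entities with fewer linked facts are skipped
    })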