hindsight-api 0.1.5-py3-none-any.whl → 0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. hindsight_api/__init__.py +10 -9
  2. hindsight_api/alembic/env.py +5 -8
  3. hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
  4. hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
  5. hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
  6. hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
  7. hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
  8. hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
  9. hindsight_api/api/__init__.py +10 -10
  10. hindsight_api/api/http.py +575 -593
  11. hindsight_api/api/mcp.py +30 -28
  12. hindsight_api/banner.py +13 -6
  13. hindsight_api/config.py +9 -13
  14. hindsight_api/engine/__init__.py +9 -9
  15. hindsight_api/engine/cross_encoder.py +22 -21
  16. hindsight_api/engine/db_utils.py +5 -4
  17. hindsight_api/engine/embeddings.py +22 -21
  18. hindsight_api/engine/entity_resolver.py +81 -75
  19. hindsight_api/engine/llm_wrapper.py +61 -79
  20. hindsight_api/engine/memory_engine.py +603 -625
  21. hindsight_api/engine/query_analyzer.py +100 -97
  22. hindsight_api/engine/response_models.py +105 -106
  23. hindsight_api/engine/retain/__init__.py +9 -16
  24. hindsight_api/engine/retain/bank_utils.py +34 -58
  25. hindsight_api/engine/retain/chunk_storage.py +4 -12
  26. hindsight_api/engine/retain/deduplication.py +9 -28
  27. hindsight_api/engine/retain/embedding_processing.py +4 -11
  28. hindsight_api/engine/retain/embedding_utils.py +3 -4
  29. hindsight_api/engine/retain/entity_processing.py +7 -17
  30. hindsight_api/engine/retain/fact_extraction.py +155 -165
  31. hindsight_api/engine/retain/fact_storage.py +11 -23
  32. hindsight_api/engine/retain/link_creation.py +11 -39
  33. hindsight_api/engine/retain/link_utils.py +166 -95
  34. hindsight_api/engine/retain/observation_regeneration.py +39 -52
  35. hindsight_api/engine/retain/orchestrator.py +72 -62
  36. hindsight_api/engine/retain/types.py +49 -43
  37. hindsight_api/engine/search/__init__.py +5 -5
  38. hindsight_api/engine/search/fusion.py +6 -15
  39. hindsight_api/engine/search/graph_retrieval.py +22 -23
  40. hindsight_api/engine/search/mpfp_retrieval.py +76 -92
  41. hindsight_api/engine/search/observation_utils.py +9 -16
  42. hindsight_api/engine/search/reranking.py +4 -7
  43. hindsight_api/engine/search/retrieval.py +87 -66
  44. hindsight_api/engine/search/scoring.py +5 -7
  45. hindsight_api/engine/search/temporal_extraction.py +8 -11
  46. hindsight_api/engine/search/think_utils.py +115 -39
  47. hindsight_api/engine/search/trace.py +68 -39
  48. hindsight_api/engine/search/tracer.py +44 -35
  49. hindsight_api/engine/search/types.py +20 -17
  50. hindsight_api/engine/task_backend.py +21 -26
  51. hindsight_api/engine/utils.py +25 -10
  52. hindsight_api/main.py +21 -40
  53. hindsight_api/mcp_local.py +190 -0
  54. hindsight_api/metrics.py +44 -30
  55. hindsight_api/migrations.py +10 -8
  56. hindsight_api/models.py +60 -72
  57. hindsight_api/pg0.py +22 -23
  58. hindsight_api/server.py +3 -6
  59. {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +2 -2
  60. hindsight_api-0.1.6.dist-info/RECORD +64 -0
  61. {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
  62. hindsight_api-0.1.5.dist-info/RECORD +0 -63
  63. {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
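The per-file +/- counts above can be reproduced from the two wheels themselves. Below is a minimal sketch using only the Python standard library; the wheel file names follow the standard naming for these releases and are assumed to have been downloaded locally first (e.g. with `pip download hindsight-api==0.1.5 --no-deps` and the same for 0.1.6) — they are not paths taken from this page.

import difflib
import zipfile

# Assumed local filenames (standard wheel names for these two releases).
OLD = "hindsight_api-0.1.5-py3-none-any.whl"
NEW = "hindsight_api-0.1.6-py3-none-any.whl"


def wheel_files(path: str) -> dict[str, list[str]]:
    """Map each member of the wheel archive to its decoded text lines."""
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines()
            for name in zf.namelist()
        }


old, new = wheel_files(OLD), wheel_files(NEW)
for name in sorted(set(old) | set(new)):
    diff = list(
        difflib.unified_diff(
            old.get(name, []),
            new.get(name, []),
            fromfile=f"0.1.5/{name}",
            tofile=f"0.1.6/{name}",
            lineterm="",
        )
    )
    if diff:
        # Count added/removed lines, excluding the +++/--- file headers.
        added = sum(1 for d in diff if d.startswith("+") and not d.startswith("+++"))
        removed = sum(1 for d in diff if d.startswith("-") and not d.startswith("---"))
        print(f"{name}: +{added} -{removed}")

Running this prints one "+added -removed" summary per changed file, which is how the table above should be read; the hunks that follow show the full diff for memory_engine.py.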
@@ -8,22 +8,23 @@ This implements a sophisticated memory architecture that combines:
  4. Spreading activation: Search through the graph with activation decay
  5. Dynamic weighting: Recency and frequency-based importance
  """
- import json
- import os
- from datetime import datetime, timedelta, timezone
- from typing import Any, Dict, List, Optional, Tuple, Union, TypedDict, TYPE_CHECKING
- import asyncpg
+
  import asyncio
- from .embeddings import Embeddings, create_embeddings_from_env
- from .cross_encoder import CrossEncoderModel, create_cross_encoder_from_env
+ import logging
  import time
- import numpy as np
  import uuid
- import logging
+ from datetime import UTC, datetime, timedelta
+ from typing import TYPE_CHECKING, Any, TypedDict
+
+ import asyncpg
+ import numpy as np
  from pydantic import BaseModel, Field

+ from .cross_encoder import CrossEncoderModel
+ from .embeddings import Embeddings, create_embeddings_from_env
+
  if TYPE_CHECKING:
- from ..config import HindsightConfig
+ pass


  class RetainContentDict(TypedDict, total=False):
@@ -36,30 +37,31 @@ class RetainContentDict(TypedDict, total=False):
  metadata: Custom key-value metadata (optional)
  document_id: Document ID for this content item (optional)
  """
+
  content: str # Required
  context: str
  event_date: datetime
- metadata: Dict[str, str]
+ metadata: dict[str, str]
  document_id: str

- from .query_analyzer import QueryAnalyzer
- from .search.scoring import (
- calculate_recency_weight,
- calculate_frequency_weight,
- )
+
+ from enum import Enum
+
+ from ..pg0 import EmbeddedPostgres
  from .entity_resolver import EntityResolver
- from .retain import embedding_utils, bank_utils
- from .search import think_utils, observation_utils
  from .llm_wrapper import LLMConfig
- from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation, VALID_RECALL_FACT_TYPES
- from .task_backend import TaskBackend, AsyncIOQueueBackend
+ from .query_analyzer import QueryAnalyzer
+ from .response_models import VALID_RECALL_FACT_TYPES, EntityObservation, EntityState, MemoryFact, ReflectResult
+ from .response_models import RecallResult as RecallResultModel
+ from .retain import bank_utils, embedding_utils
+ from .search import observation_utils, think_utils
  from .search.reranking import CrossEncoderReranker
- from ..pg0 import EmbeddedPostgres
- from enum import Enum
+ from .task_backend import AsyncIOQueueBackend, TaskBackend


  class Budget(str, Enum):
  """Budget levels for recall/reflect operations."""
+
  LOW = "low"
  MID = "mid"
  HIGH = "high"
@@ -67,20 +69,20 @@ class Budget(str, Enum):

  def utcnow():
  """Get current UTC time with timezone info."""
- return datetime.now(timezone.utc)
+ return datetime.now(UTC)


  # Logger for memory system
  logger = logging.getLogger(__name__)

- from .db_utils import acquire_with_retry, retry_with_backoff
-
  import tiktoken
- from dateutil import parser as date_parser
+
+ from .db_utils import acquire_with_retry

  # Cache tiktoken encoding for token budget filtering (module-level singleton)
  _TIKTOKEN_ENCODING = None

+
  def _get_tiktoken_encoding():
  """Get cached tiktoken encoding (cl100k_base for GPT-4/3.5)."""
  global _TIKTOKEN_ENCODING
@@ -102,17 +104,17 @@ class MemoryEngine:

  def __init__(
  self,
- db_url: Optional[str] = None,
- memory_llm_provider: Optional[str] = None,
- memory_llm_api_key: Optional[str] = None,
- memory_llm_model: Optional[str] = None,
- memory_llm_base_url: Optional[str] = None,
- embeddings: Optional[Embeddings] = None,
- cross_encoder: Optional[CrossEncoderModel] = None,
- query_analyzer: Optional[QueryAnalyzer] = None,
+ db_url: str | None = None,
+ memory_llm_provider: str | None = None,
+ memory_llm_api_key: str | None = None,
+ memory_llm_model: str | None = None,
+ memory_llm_base_url: str | None = None,
+ embeddings: Embeddings | None = None,
+ cross_encoder: CrossEncoderModel | None = None,
+ query_analyzer: QueryAnalyzer | None = None,
  pool_min_size: int = 5,
  pool_max_size: int = 100,
- task_backend: Optional[TaskBackend] = None,
+ task_backend: TaskBackend | None = None,
  run_migrations: bool = True,
  ):
  """
@@ -138,6 +140,7 @@ class MemoryEngine:
  """
  # Load config from environment for any missing parameters
  from ..config import get_config
+
  config = get_config()

  # Apply defaults from config
@@ -147,8 +150,8 @@ class MemoryEngine:
  memory_llm_model = memory_llm_model or config.llm_model
  memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
  # Track pg0 instance (if used)
- self._pg0: Optional[EmbeddedPostgres] = None
- self._pg0_instance_name: Optional[str] = None
+ self._pg0: EmbeddedPostgres | None = None
+ self._pg0_instance_name: str | None = None

  # Initialize PostgreSQL connection URL
  # The actual URL will be set during initialize() after starting the server
@@ -175,7 +178,6 @@ class MemoryEngine:
  self._pg0_port = None
  self.db_url = db_url

-
  # Set default base URL if not provided
  if memory_llm_base_url is None:
  if memory_llm_provider.lower() == "groq":
@@ -206,6 +208,7 @@ class MemoryEngine:
  self.query_analyzer = query_analyzer
  else:
  from .query_analyzer import DateparserQueryAnalyzer
+
  self.query_analyzer = DateparserQueryAnalyzer()

  # Initialize LLM configuration
@@ -224,10 +227,7 @@ class MemoryEngine:
  self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)

  # Initialize task backend
- self._task_backend = task_backend or AsyncIOQueueBackend(
- batch_size=100,
- batch_interval=1.0
- )
+ self._task_backend = task_backend or AsyncIOQueueBackend(batch_size=100, batch_interval=1.0)

  # Backpressure mechanism: limit concurrent searches to prevent overwhelming the database
  # Limit concurrent searches to prevent connection pool exhaustion
@@ -243,14 +243,14 @@ class MemoryEngine:
  # initialize encoding eagerly to avoid delaying the first time
  _get_tiktoken_encoding()

- async def _handle_access_count_update(self, task_dict: Dict[str, Any]):
+ async def _handle_access_count_update(self, task_dict: dict[str, Any]):
  """
  Handler for access count update tasks.

  Args:
  task_dict: Dict with 'node_ids' key containing list of node IDs to update
  """
- node_ids = task_dict.get('node_ids', [])
+ node_ids = task_dict.get("node_ids", [])
  if not node_ids:
  return

@@ -260,13 +260,12 @@ class MemoryEngine:
  uuid_list = [uuid.UUID(nid) for nid in node_ids]
  async with acquire_with_retry(pool) as conn:
  await conn.execute(
- "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
- uuid_list
+ "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])", uuid_list
  )
  except Exception as e:
  logger.error(f"Access count handler: Error updating access counts: {e}")

- async def _handle_batch_retain(self, task_dict: Dict[str, Any]):
+ async def _handle_batch_retain(self, task_dict: dict[str, Any]):
  """
  Handler for batch retain tasks.

@@ -274,23 +273,23 @@ class MemoryEngine:
  task_dict: Dict with 'bank_id', 'contents'
  """
  try:
- bank_id = task_dict.get('bank_id')
- contents = task_dict.get('contents', [])
-
- logger.info(f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items")
+ bank_id = task_dict.get("bank_id")
+ contents = task_dict.get("contents", [])

- await self.retain_batch_async(
- bank_id=bank_id,
- contents=contents
+ logger.info(
+ f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
  )

+ await self.retain_batch_async(bank_id=bank_id, contents=contents)
+
  logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
  except Exception as e:
  logger.error(f"Batch retain handler: Error processing batch retain: {e}")
  import traceback
+
  traceback.print_exc()

- async def execute_task(self, task_dict: Dict[str, Any]):
+ async def execute_task(self, task_dict: dict[str, Any]):
  """
  Execute a task by routing it to the appropriate handler.

@@ -301,9 +300,9 @@ class MemoryEngine:
  task_dict: Task dictionary with 'type' key and other payload data
  Example: {'type': 'access_count_update', 'node_ids': [...]}
  """
- task_type = task_dict.get('type')
- operation_id = task_dict.get('operation_id')
- retry_count = task_dict.get('retry_count', 0)
+ task_type = task_dict.get("type")
+ operation_id = task_dict.get("operation_id")
+ retry_count = task_dict.get("retry_count", 0)
  max_retries = 3

  # Check if operation was cancelled (only for tasks with operation_id)
@@ -312,8 +311,7 @@ class MemoryEngine:
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  result = await conn.fetchrow(
- "SELECT id FROM async_operations WHERE id = $1",
- uuid.UUID(operation_id)
+ "SELECT id FROM async_operations WHERE id = $1", uuid.UUID(operation_id)
  )
  if not result:
  # Operation was cancelled, skip processing
@@ -324,15 +322,15 @@ class MemoryEngine:
  # Continue with processing if we can't check status

  try:
- if task_type == 'access_count_update':
+ if task_type == "access_count_update":
  await self._handle_access_count_update(task_dict)
- elif task_type == 'reinforce_opinion':
+ elif task_type == "reinforce_opinion":
  await self._handle_reinforce_opinion(task_dict)
- elif task_type == 'form_opinion':
+ elif task_type == "form_opinion":
  await self._handle_form_opinion(task_dict)
- elif task_type == 'batch_retain':
+ elif task_type == "batch_retain":
  await self._handle_batch_retain(task_dict)
- elif task_type == 'regenerate_observations':
+ elif task_type == "regenerate_observations":
  await self._handle_regenerate_observations(task_dict)
  else:
  logger.error(f"Unknown task type: {task_type}")
@@ -347,14 +345,17 @@ class MemoryEngine:

  except Exception as e:
  # Task failed - check if we should retry
- logger.error(f"Task execution failed (attempt {retry_count + 1}/{max_retries + 1}): {task_type}, error: {e}")
+ logger.error(
+ f"Task execution failed (attempt {retry_count + 1}/{max_retries + 1}): {task_type}, error: {e}"
+ )
  import traceback
+
  error_traceback = traceback.format_exc()
  traceback.print_exc()

  if retry_count < max_retries:
  # Reschedule with incremented retry count
- task_dict['retry_count'] = retry_count + 1
+ task_dict["retry_count"] = retry_count + 1
  logger.info(f"Rescheduling task {task_type} (retry {retry_count + 1}/{max_retries})")
  await self._task_backend.submit_task(task_dict)
  else:
@@ -368,10 +369,7 @@ class MemoryEngine:
  try:
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
- await conn.execute(
- "DELETE FROM async_operations WHERE id = $1",
- uuid.UUID(operation_id)
- )
+ await conn.execute("DELETE FROM async_operations WHERE id = $1", uuid.UUID(operation_id))
  except Exception as e:
  logger.error(f"Failed to delete async operation record {operation_id}: {e}")

@@ -391,7 +389,7 @@ class MemoryEngine:
  WHERE id = $1
  """,
  uuid.UUID(operation_id),
- truncated_error
+ truncated_error,
  )
  logger.info(f"Marked async operation as failed: {operation_id}")
  except Exception as e:
@@ -406,8 +404,6 @@ class MemoryEngine:
  if self._initialized:
  return

- import concurrent.futures
-
  # Run model loading in thread pool (CPU-bound) in parallel with pg0 startup
  loop = asyncio.get_event_loop()

@@ -429,10 +425,7 @@ class MemoryEngine:
  """Initialize embedding model."""
  # For local providers, run in thread pool to avoid blocking event loop
  if self.embeddings.provider_name == "local":
- await loop.run_in_executor(
- None,
- lambda: asyncio.run(self.embeddings.initialize())
- )
+ await loop.run_in_executor(None, lambda: asyncio.run(self.embeddings.initialize()))
  else:
  await self.embeddings.initialize()

@@ -441,10 +434,7 @@ class MemoryEngine:
  cross_encoder = self._cross_encoder_reranker.cross_encoder
  # For local providers, run in thread pool to avoid blocking event loop
  if cross_encoder.provider_name == "local":
- await loop.run_in_executor(
- None,
- lambda: asyncio.run(cross_encoder.initialize())
- )
+ await loop.run_in_executor(None, lambda: asyncio.run(cross_encoder.initialize()))
  else:
  await cross_encoder.initialize()

@@ -469,6 +459,7 @@ class MemoryEngine:
  # Run database migrations if enabled
  if self._run_migrations:
  from ..migrations import run_migrations
+
  logger.info("Running database migrations...")
  run_migrations(self.db_url)

@@ -555,7 +546,6 @@ class MemoryEngine:
  self._pg0 = None
  logger.info("pg0 stopped")

-
  async def wait_for_background_tasks(self):
  """
  Wait for all pending background tasks to complete.
@@ -563,7 +553,7 @@ class MemoryEngine:
  This is useful in tests to ensure background tasks (like opinion reinforcement)
  complete before making assertions.
  """
- if hasattr(self._task_backend, 'wait_for_pending_tasks'):
+ if hasattr(self._task_backend, "wait_for_pending_tasks"):
  await self._task_backend.wait_for_pending_tasks()

  def _format_readable_date(self, dt: datetime) -> str:
@@ -596,12 +586,12 @@ class MemoryEngine:
  self,
  conn,
  bank_id: str,
- texts: List[str],
- embeddings: List[List[float]],
+ texts: list[str],
+ embeddings: list[list[float]],
  event_date: datetime,
  time_window_hours: int = 24,
- similarity_threshold: float = 0.95
- ) -> List[bool]:
+ similarity_threshold: float = 0.95,
+ ) -> list[bool]:
  """
  Check which facts are duplicates using semantic similarity + temporal window.

@@ -635,6 +625,7 @@ class MemoryEngine:

  # Fetch ALL existing facts in time window ONCE (much faster than N queries)
  import time as time_mod
+
  fetch_start = time_mod.time()
  existing_facts = await conn.fetch(
  """
@@ -643,7 +634,9 @@ class MemoryEngine:
  WHERE bank_id = $1
  AND event_date BETWEEN $2 AND $3
  """,
- bank_id, time_lower, time_upper
+ bank_id,
+ time_lower,
+ time_upper,
  )

  # If no existing facts, nothing is duplicate
@@ -651,17 +644,17 @@ class MemoryEngine:
  return [False] * len(texts)

  # Compute similarities in Python (vectorized with numpy)
- import numpy as np
  is_duplicate = []

  # Convert existing embeddings to numpy for faster computation
  embedding_arrays = []
  for row in existing_facts:
- raw_emb = row['embedding']
+ raw_emb = row["embedding"]
  # Handle different pgvector formats
  if isinstance(raw_emb, str):
  # Parse string format: "[1.0, 2.0, ...]"
  import json
+
  emb = np.array(json.loads(raw_emb), dtype=np.float32)
  elif isinstance(raw_emb, (list, tuple)):
  emb = np.array(raw_emb, dtype=np.float32)
@@ -691,7 +684,6 @@ class MemoryEngine:
  max_similarity = np.max(similarities) if len(similarities) > 0 else 0
  is_duplicate.append(max_similarity > similarity_threshold)

-
  return is_duplicate

  def retain(
@@ -699,8 +691,8 @@ class MemoryEngine:
  bank_id: str,
  content: str,
  context: str = "",
- event_date: Optional[datetime] = None,
- ) -> List[str]:
+ event_date: datetime | None = None,
+ ) -> list[str]:
  """
  Store content as memory units (synchronous wrapper).

@@ -724,11 +716,11 @@ class MemoryEngine:
  bank_id: str,
  content: str,
  context: str = "",
- event_date: Optional[datetime] = None,
- document_id: Optional[str] = None,
- fact_type_override: Optional[str] = None,
- confidence_score: Optional[float] = None,
- ) -> List[str]:
+ event_date: datetime | None = None,
+ document_id: str | None = None,
+ fact_type_override: str | None = None,
+ confidence_score: float | None = None,
+ ) -> list[str]:
  """
  Store content as memory units with temporal and semantic links (ASYNC version).

@@ -747,11 +739,7 @@ class MemoryEngine:
  List of created unit IDs
  """
  # Build content dict
- content_dict: RetainContentDict = {
- "content": content,
- "context": context,
- "event_date": event_date
- }
+ content_dict: RetainContentDict = {"content": content, "context": context, "event_date": event_date}
  if document_id:
  content_dict["document_id"] = document_id

@@ -760,7 +748,7 @@ class MemoryEngine:
  bank_id=bank_id,
  contents=[content_dict],
  fact_type_override=fact_type_override,
- confidence_score=confidence_score
+ confidence_score=confidence_score,
  )

  # Return the first (and only) list of unit IDs
@@ -769,11 +757,11 @@ class MemoryEngine:
  async def retain_batch_async(
  self,
  bank_id: str,
- contents: List[RetainContentDict],
- document_id: Optional[str] = None,
- fact_type_override: Optional[str] = None,
- confidence_score: Optional[float] = None,
- ) -> List[List[str]]:
+ contents: list[RetainContentDict],
+ document_id: str | None = None,
+ fact_type_override: str | None = None,
+ confidence_score: float | None = None,
+ ) -> list[list[str]]:
  """
  Store multiple content items as memory units in ONE batch operation.

@@ -839,7 +827,9 @@ class MemoryEngine:

  if total_chars > CHARS_PER_BATCH:
  # Split into smaller batches based on character count
- logger.info(f"Large batch detected ({total_chars:,} chars from {len(contents)} items). Splitting into sub-batches of ~{CHARS_PER_BATCH:,} chars each...")
+ logger.info(
+ f"Large batch detected ({total_chars:,} chars from {len(contents)} items). Splitting into sub-batches of ~{CHARS_PER_BATCH:,} chars each..."
+ )

  sub_batches = []
  current_batch = []
@@ -868,7 +858,9 @@ class MemoryEngine:
  all_results = []
  for i, sub_batch in enumerate(sub_batches, 1):
  sub_batch_chars = sum(len(item.get("content", "")) for item in sub_batch)
- logger.info(f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars")
+ logger.info(
+ f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars"
+ )

  sub_results = await self._retain_batch_async_internal(
  bank_id=bank_id,
@@ -876,12 +868,14 @@ class MemoryEngine:
  document_id=document_id,
  is_first_batch=i == 1, # Only upsert on first batch
  fact_type_override=fact_type_override,
- confidence_score=confidence_score
+ confidence_score=confidence_score,
  )
  all_results.extend(sub_results)

  total_time = time.time() - start_time
- logger.info(f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s")
+ logger.info(
+ f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s"
+ )
  return all_results

  # Small batch - use internal method directly
@@ -891,18 +885,18 @@ class MemoryEngine:
  document_id=document_id,
  is_first_batch=True,
  fact_type_override=fact_type_override,
- confidence_score=confidence_score
+ confidence_score=confidence_score,
  )

  async def _retain_batch_async_internal(
  self,
  bank_id: str,
- contents: List[RetainContentDict],
- document_id: Optional[str] = None,
+ contents: list[RetainContentDict],
+ document_id: str | None = None,
  is_first_batch: bool = True,
- fact_type_override: Optional[str] = None,
- confidence_score: Optional[float] = None,
- ) -> List[List[str]]:
+ fact_type_override: str | None = None,
+ confidence_score: float | None = None,
+ ) -> list[list[str]]:
  """
  Internal method for batch processing without chunking logic.

@@ -938,7 +932,7 @@ class MemoryEngine:
  document_id=document_id,
  is_first_batch=is_first_batch,
  fact_type_override=fact_type_override,
- confidence_score=confidence_score
+ confidence_score=confidence_score,
  )

  def recall(
@@ -949,7 +943,7 @@ class MemoryEngine:
  budget: Budget = Budget.MID,
  max_tokens: int = 4096,
  enable_trace: bool = False,
- ) -> tuple[List[Dict[str, Any]], Optional[Any]]:
+ ) -> tuple[list[dict[str, Any]], Any | None]:
  """
  Recall memories using 4-way parallel retrieval (synchronous wrapper).

@@ -968,19 +962,17 @@ class MemoryEngine:
  Tuple of (results, trace)
  """
  # Run async version synchronously
- return asyncio.run(self.recall_async(
- bank_id, query, [fact_type], budget, max_tokens, enable_trace
- ))
+ return asyncio.run(self.recall_async(bank_id, query, [fact_type], budget, max_tokens, enable_trace))

  async def recall_async(
  self,
  bank_id: str,
  query: str,
- fact_type: List[str],
+ fact_type: list[str],
  budget: Budget = Budget.MID,
  max_tokens: int = 4096,
  enable_trace: bool = False,
- question_date: Optional[datetime] = None,
+ question_date: datetime | None = None,
  include_entities: bool = False,
  max_entity_tokens: int = 1024,
  include_chunks: bool = False,
@@ -1027,11 +1019,7 @@ class MemoryEngine:
  )

  # Map budget enum to thinking_budget number
- budget_mapping = {
- Budget.LOW: 100,
- Budget.MID: 300,
- Budget.HIGH: 1000
- }
+ budget_mapping = {Budget.LOW: 100, Budget.MID: 300, Budget.HIGH: 1000}
  thinking_budget = budget_mapping[budget]

  # Backpressure: limit concurrent recalls to prevent overwhelming the database
@@ -1041,20 +1029,29 @@ class MemoryEngine:
  for attempt in range(max_retries + 1):
  try:
  return await self._search_with_retries(
- bank_id, query, fact_type, thinking_budget, max_tokens, enable_trace, question_date,
- include_entities, max_entity_tokens, include_chunks, max_chunk_tokens
+ bank_id,
+ query,
+ fact_type,
+ thinking_budget,
+ max_tokens,
+ enable_trace,
+ question_date,
+ include_entities,
+ max_entity_tokens,
+ include_chunks,
+ max_chunk_tokens,
  )
  except Exception as e:
  # Check if it's a connection error
  is_connection_error = (
- isinstance(e, asyncpg.TooManyConnectionsError) or
- isinstance(e, asyncpg.CannotConnectNowError) or
- (isinstance(e, asyncpg.PostgresError) and 'connection' in str(e).lower())
+ isinstance(e, asyncpg.TooManyConnectionsError)
+ or isinstance(e, asyncpg.CannotConnectNowError)
+ or (isinstance(e, asyncpg.PostgresError) and "connection" in str(e).lower())
  )

  if is_connection_error and attempt < max_retries:
  # Wait with exponential backoff before retry
- wait_time = 0.5 * (2 ** attempt) # 0.5s, 1s, 2s
+ wait_time = 0.5 * (2**attempt) # 0.5s, 1s, 2s
  logger.warning(
  f"Connection error on search attempt {attempt + 1}/{max_retries + 1}: {str(e)}. "
  f"Retrying in {wait_time:.1f}s..."
@@ -1069,11 +1066,11 @@ class MemoryEngine:
  self,
  bank_id: str,
  query: str,
- fact_type: List[str],
+ fact_type: list[str],
  thinking_budget: int,
  max_tokens: int,
  enable_trace: bool,
- question_date: Optional[datetime] = None,
+ question_date: datetime | None = None,
  include_entities: bool = False,
  max_entity_tokens: int = 500,
  include_chunks: bool = False,
@@ -1106,6 +1103,7 @@ class MemoryEngine:
  """
  # Initialize tracer if requested
  from .search.tracer import SearchTracer
+
  tracer = SearchTracer(query, thinking_budget, max_tokens) if enable_trace else None
  if tracer:
  tracer.start()
@@ -1116,7 +1114,9 @@ class MemoryEngine:
  # Buffer logs for clean output in concurrent scenarios
  recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
  log_buffer = []
- log_buffer.append(f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})")
+ log_buffer.append(
+ f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})"
+ )

  try:
  # Step 1: Generate query embedding (for semantic search)
@@ -1141,8 +1141,7 @@ class MemoryEngine:
  # Run retrieval for each fact type in parallel
  retrieval_tasks = [
  retrieve_parallel(
- pool, query, query_embedding_str, bank_id, ft, thinking_budget,
- question_date, self.query_analyzer
+ pool, query, query_embedding_str, bank_id, ft, thinking_budget, question_date, self.query_analyzer
  )
  for ft in fact_type
  ]
@@ -1159,7 +1158,9 @@ class MemoryEngine:
  for idx, retrieval_result in enumerate(all_retrievals):
  # Log fact types in this retrieval batch
  ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
- logger.debug(f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}")
+ logger.debug(
+ f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
+ )

  semantic_results.extend(retrieval_result.semantic)
  bm25_results.extend(retrieval_result.bm25)
@@ -1179,11 +1180,13 @@ class MemoryEngine:

  # Sort combined results by score (descending) so higher-scored results
  # get better ranks in the trace, regardless of fact type
- semantic_results.sort(key=lambda r: r.similarity if hasattr(r, 'similarity') else 0, reverse=True)
- bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r, 'bm25_score') else 0, reverse=True)
- graph_results.sort(key=lambda r: r.activation if hasattr(r, 'activation') else 0, reverse=True)
+ semantic_results.sort(key=lambda r: r.similarity if hasattr(r, "similarity") else 0, reverse=True)
+ bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r, "bm25_score") else 0, reverse=True)
+ graph_results.sort(key=lambda r: r.activation if hasattr(r, "activation") else 0, reverse=True)
  if temporal_results:
- temporal_results.sort(key=lambda r: r.combined_score if hasattr(r, 'combined_score') else 0, reverse=True)
+ temporal_results.sort(
+ key=lambda r: r.combined_score if hasattr(r, "combined_score") else 0, reverse=True
+ )

  retrieval_duration = time.time() - retrieval_start

@@ -1193,7 +1196,7 @@ class MemoryEngine:
  timing_parts = [
  f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
  f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
- f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)"
+ f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)",
  ]
  temporal_info = ""
  if detected_temporal_constraint:
@@ -1201,7 +1204,9 @@ class MemoryEngine:
  temporal_count = len(temporal_results) if temporal_results else 0
  timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
  temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
- log_buffer.append(f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}")
+ log_buffer.append(
+ f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}"
+ )

  # Record retrieval results for tracer - per fact type
  if tracer:
@@ -1220,7 +1225,7 @@ class MemoryEngine:
  duration_seconds=rr.timings.get("semantic", 0.0),
  score_field="similarity",
  metadata={"limit": thinking_budget},
- fact_type=ft_name
+ fact_type=ft_name,
  )

  # Add BM25 retrieval results for this fact type
@@ -1230,7 +1235,7 @@ class MemoryEngine:
  duration_seconds=rr.timings.get("bm25", 0.0),
  score_field="bm25_score",
  metadata={"limit": thinking_budget},
- fact_type=ft_name
+ fact_type=ft_name,
  )

  # Add graph retrieval results for this fact type
@@ -1240,7 +1245,7 @@ class MemoryEngine:
  duration_seconds=rr.timings.get("graph", 0.0),
  score_field="activation",
  metadata={"budget": thinking_budget},
- fact_type=ft_name
+ fact_type=ft_name,
  )

  # Add temporal retrieval results for this fact type (even if empty, to show it ran)
@@ -1251,19 +1256,23 @@ class MemoryEngine:
  duration_seconds=rr.timings.get("temporal", 0.0),
  score_field="temporal_score",
  metadata={"budget": thinking_budget},
- fact_type=ft_name
+ fact_type=ft_name,
  )

  # Record entry points (from semantic results) for legacy graph view
  for rank, retrieval in enumerate(semantic_results[:10], start=1): # Top 10 as entry points
  tracer.add_entry_point(retrieval.id, retrieval.text, retrieval.similarity or 0.0, rank)

- tracer.add_phase_metric("parallel_retrieval", step_duration, {
- "semantic_count": len(semantic_results),
- "bm25_count": len(bm25_results),
- "graph_count": len(graph_results),
- "temporal_count": len(temporal_results) if temporal_results else 0
- })
+ tracer.add_phase_metric(
+ "parallel_retrieval",
+ step_duration,
+ {
+ "semantic_count": len(semantic_results),
+ "bm25_count": len(bm25_results),
+ "graph_count": len(graph_results),
+ "temporal_count": len(temporal_results) if temporal_results else 0,
+ },
+ )

  # Step 3: Merge with RRF
  step_start = time.time()
@@ -1271,7 +1280,9 @@ class MemoryEngine:

  # Merge 3 or 4 result lists depending on temporal constraint
  if temporal_results:
- merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results, temporal_results])
+ merged_candidates = reciprocal_rank_fusion(
+ [semantic_results, bm25_results, graph_results, temporal_results]
+ )
  else:
  merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results])

@@ -1280,8 +1291,10 @@ class MemoryEngine:

  if tracer:
  # Convert MergedCandidate to old tuple format for tracer
- tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
- for mc in merged_candidates]
+ tracer_merged = [
+ (mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
+ for mc in merged_candidates
+ ]
  tracer.add_rrf_merged(tracer_merged)
  tracer.add_phase_metric("rrf_merge", step_duration, {"candidates_merged": len(merged_candidates)})

@@ -1318,14 +1331,15 @@ class MemoryEngine:
  sr.recency = 0.5 # default for missing dates
  if sr.retrieval.occurred_start:
  occurred = sr.retrieval.occurred_start
- if hasattr(occurred, 'tzinfo') and occurred.tzinfo is None:
- from datetime import timezone
- occurred = occurred.replace(tzinfo=timezone.utc)
+ if hasattr(occurred, "tzinfo") and occurred.tzinfo is None:
+ occurred = occurred.replace(tzinfo=UTC)
  days_ago = (now - occurred).total_seconds() / 86400
  sr.recency = max(0.1, 1.0 - (days_ago / 365)) # Linear decay over 1 year

  # Get temporal proximity if available (already 0-1)
- sr.temporal = sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5
+ sr.temporal = (
+ sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5
+ )

  # Weighted combination
  # Cross-encoder: 60% (semantic relevance)
@@ -1333,27 +1347,32 @@ class MemoryEngine:
  # Temporal proximity: 10% (time relevance for temporal queries)
  # Recency: 10% (prefer recent facts)
  sr.combined_score = (
- 0.6 * sr.cross_encoder_score_normalized +
- 0.2 * sr.rrf_normalized +
- 0.1 * sr.temporal +
- 0.1 * sr.recency
+ 0.6 * sr.cross_encoder_score_normalized
+ + 0.2 * sr.rrf_normalized
+ + 0.1 * sr.temporal
+ + 0.1 * sr.recency
  )
  sr.weight = sr.combined_score # Update weight for final ranking

  # Re-sort by combined score
  scored_results.sort(key=lambda x: x.weight, reverse=True)
- log_buffer.append(f" [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)")
+ log_buffer.append(
+ " [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)"
+ )

  # Add reranked results to tracer AFTER combined scoring (so normalized values are included)
  if tracer:
  results_dict = [sr.to_dict() for sr in scored_results]
- tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
- for mc in merged_candidates]
+ tracer_merged = [
+ (mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
+ for mc in merged_candidates
+ ]
  tracer.add_reranked(results_dict, tracer_merged)
- tracer.add_phase_metric("reranking", step_duration, {
- "reranker_type": "cross-encoder",
- "candidates_reranked": len(scored_results)
- })
+ tracer.add_phase_metric(
+ "reranking",
+ step_duration,
+ {"reranker_type": "cross-encoder", "candidates_reranked": len(scored_results)},
+ )

  # Step 5: Truncate to thinking_budget * 2 for token filtering
  rerank_limit = thinking_budget * 2
@@ -1372,14 +1391,16 @@ class MemoryEngine:
  top_scored = [sr for sr in top_scored if sr.id in filtered_ids]

  step_duration = time.time() - step_start
- log_buffer.append(f" [6] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s")
+ log_buffer.append(
+ f" [6] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s"
+ )

  if tracer:
- tracer.add_phase_metric("token_filtering", step_duration, {
- "results_selected": len(top_scored),
- "tokens_used": total_tokens,
- "max_tokens": max_tokens
- })
+ tracer.add_phase_metric(
+ "token_filtering",
+ step_duration,
+ {"results_selected": len(top_scored), "tokens_used": total_tokens, "max_tokens": max_tokens},
+ )

  # Record visits for all retrieved nodes
  if tracer:
@@ -1398,16 +1419,13 @@ class MemoryEngine:
  semantic_similarity=sr.retrieval.similarity or 0.0,
  recency=sr.recency,
  frequency=0.0,
- final_weight=sr.weight
+ final_weight=sr.weight,
  )

  # Step 8: Queue access count updates for visited nodes
  visited_ids = list(set([sr.id for sr in scored_results[:50]])) # Top 50
  if visited_ids:
- await self._task_backend.submit_task({
- 'type': 'access_count_update',
- 'node_ids': visited_ids
- })
+ await self._task_backend.submit_task({"type": "access_count_update", "node_ids": visited_ids})
  log_buffer.append(f" [7] Queued access count updates for {len(visited_ids)} nodes")

  # Log fact_type distribution in results
@@ -1425,13 +1443,19 @@ class MemoryEngine:
  # Convert datetime objects to ISO strings for JSON serialization
  if result_dict.get("occurred_start"):
  occurred_start = result_dict["occurred_start"]
- result_dict["occurred_start"] = occurred_start.isoformat() if hasattr(occurred_start, 'isoformat') else occurred_start
+ result_dict["occurred_start"] = (
+ occurred_start.isoformat() if hasattr(occurred_start, "isoformat") else occurred_start
+ )
  if result_dict.get("occurred_end"):
  occurred_end = result_dict["occurred_end"]
- result_dict["occurred_end"] = occurred_end.isoformat() if hasattr(occurred_end, 'isoformat') else occurred_end
+ result_dict["occurred_end"] = (
+ occurred_end.isoformat() if hasattr(occurred_end, "isoformat") else occurred_end
+ )
  if result_dict.get("mentioned_at"):
  mentioned_at = result_dict["mentioned_at"]
- result_dict["mentioned_at"] = mentioned_at.isoformat() if hasattr(mentioned_at, 'isoformat') else mentioned_at
+ result_dict["mentioned_at"] = (
+ mentioned_at.isoformat() if hasattr(mentioned_at, "isoformat") else mentioned_at
+ )
  top_results_dicts.append(result_dict)

  # Get entities for each fact if include_entities is requested
@@ -1447,16 +1471,15 @@ class MemoryEngine:
  JOIN entities e ON ue.entity_id = e.id
  WHERE ue.unit_id = ANY($1::uuid[])
  """,
- unit_ids
+ unit_ids,
  )
  for row in entity_rows:
- unit_id = str(row['unit_id'])
+ unit_id = str(row["unit_id"])
  if unit_id not in fact_entity_map:
  fact_entity_map[unit_id] = []
- fact_entity_map[unit_id].append({
- 'entity_id': str(row['entity_id']),
- 'canonical_name': row['canonical_name']
- })
+ fact_entity_map[unit_id].append(
+ {"entity_id": str(row["entity_id"]), "canonical_name": row["canonical_name"]}
+ )

  # Convert results to MemoryFact objects
  memory_facts = []
@@ -1465,20 +1488,22 @@ class MemoryEngine:
  # Get entity names for this fact
  entity_names = None
  if include_entities and result_id in fact_entity_map:
- entity_names = [e['canonical_name'] for e in fact_entity_map[result_id]]
-
- memory_facts.append(MemoryFact(
- id=result_id,
- text=result_dict.get("text"),
- fact_type=result_dict.get("fact_type", "world"),
- entities=entity_names,
- context=result_dict.get("context"),
- occurred_start=result_dict.get("occurred_start"),
- occurred_end=result_dict.get("occurred_end"),
- mentioned_at=result_dict.get("mentioned_at"),
- document_id=result_dict.get("document_id"),
- chunk_id=result_dict.get("chunk_id"),
- ))
+ entity_names = [e["canonical_name"] for e in fact_entity_map[result_id]]
+
+ memory_facts.append(
+ MemoryFact(
+ id=result_id,
+ text=result_dict.get("text"),
+ fact_type=result_dict.get("fact_type", "world"),
+ entities=entity_names,
+ context=result_dict.get("context"),
+ occurred_start=result_dict.get("occurred_start"),
+ occurred_end=result_dict.get("occurred_end"),
+ mentioned_at=result_dict.get("mentioned_at"),
+ document_id=result_dict.get("document_id"),
+ chunk_id=result_dict.get("chunk_id"),
+ )
+ )

  # Fetch entity observations if requested
  entities_dict = None
@@ -1495,8 +1520,8 @@ class MemoryEngine:
  unit_id = sr.id
  if unit_id in fact_entity_map:
  for entity in fact_entity_map[unit_id]:
- entity_id = entity['entity_id']
- entity_name = entity['canonical_name']
+ entity_id = entity["entity_id"]
+ entity_name = entity["canonical_name"]
  if entity_id not in seen_entity_ids:
  entities_ordered.append((entity_id, entity_name))
  seen_entity_ids.add(entity_id)
@@ -1524,9 +1549,7 @@ class MemoryEngine:

  if included_observations:
  entities_dict[entity_name] = EntityState(
- entity_id=entity_id,
- canonical_name=entity_name,
- observations=included_observations
+ entity_id=entity_id, canonical_name=entity_name, observations=included_observations
  )
  total_entity_tokens += entity_tokens

@@ -1554,11 +1577,11 @@ class MemoryEngine:
  FROM chunks
  WHERE chunk_id = ANY($1::text[])
  """,
- chunk_ids_ordered
+ chunk_ids_ordered,
  )

  # Create a lookup dict for fast access
- chunks_lookup = {row['chunk_id']: row for row in chunks_rows}
+ chunks_lookup = {row["chunk_id"]: row for row in chunks_rows}

  # Apply token limit and build chunks_dict in the order of chunk_ids_ordered
  chunks_dict = {}
@@ -1569,7 +1592,7 @@ class MemoryEngine:
  continue

  row = chunks_lookup[chunk_id]
- chunk_text = row['chunk_text']
+ chunk_text = row["chunk_text"]
  chunk_tokens = len(encoding.encode(chunk_text))

  # Check if adding this chunk would exceed the limit
@@ -1580,18 +1603,14 @@ class MemoryEngine:
  # Truncate to remaining tokens
  truncated_text = encoding.decode(encoding.encode(chunk_text)[:remaining_tokens])
  chunks_dict[chunk_id] = ChunkInfo(
- chunk_text=truncated_text,
- chunk_index=row['chunk_index'],
- truncated=True
+ chunk_text=truncated_text, chunk_index=row["chunk_index"], truncated=True
  )
  total_chunk_tokens = max_chunk_tokens
  # Stop adding more chunks once we hit the limit
  break
  else:
  chunks_dict[chunk_id] = ChunkInfo(
- chunk_text=chunk_text,
- chunk_index=row['chunk_index'],
- truncated=False
+ chunk_text=chunk_text, chunk_index=row["chunk_index"], truncated=False
  )
  total_chunk_tokens += chunk_tokens

@@ -1605,7 +1624,9 @@ class MemoryEngine:
  total_time = time.time() - recall_start
  num_chunks = len(chunks_dict) if chunks_dict else 0
  num_entities = len(entities_dict) if entities_dict else 0
- log_buffer.append(f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s")
+ log_buffer.append(
+ f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s"
+ )
  logger.info("\n" + "\n".join(log_buffer))

  return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)
@@ -1616,10 +1637,8 @@ class MemoryEngine:
1616
1637
  raise Exception(f"Failed to search memories: {str(e)}")
1617
1638
 
1618
1639
  def _filter_by_token_budget(
1619
- self,
1620
- results: List[Dict[str, Any]],
1621
- max_tokens: int
1622
- ) -> Tuple[List[Dict[str, Any]], int]:
1640
+ self, results: list[dict[str, Any]], max_tokens: int
1641
+ ) -> tuple[list[dict[str, Any]], int]:
1623
1642
  """
1624
1643
  Filter results to fit within token budget.
1625
1644
 
@@ -1652,7 +1671,7 @@ class MemoryEngine:
1652
1671
 
1653
1672
  return filtered_results, total_tokens
1654
1673
 
1655
- async def get_document(self, document_id: str, bank_id: str) -> Optional[Dict[str, Any]]:
1674
+ async def get_document(self, document_id: str, bank_id: str) -> dict[str, Any] | None:
1656
1675
  """
1657
1676
  Retrieve document metadata and statistics.
1658
1677
 
@@ -1674,7 +1693,8 @@ class MemoryEngine:
1674
1693
  WHERE d.id = $1 AND d.bank_id = $2
1675
1694
  GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
1676
1695
  """,
1677
- document_id, bank_id
1696
+ document_id,
1697
+ bank_id,
1678
1698
  )
1679
1699
 
1680
1700
  if not doc:
@@ -1687,10 +1707,10 @@ class MemoryEngine:
1687
1707
  "content_hash": doc["content_hash"],
1688
1708
  "memory_unit_count": doc["unit_count"],
1689
1709
  "created_at": doc["created_at"],
1690
- "updated_at": doc["updated_at"]
1710
+ "updated_at": doc["updated_at"],
1691
1711
  }
1692
1712
 
1693
- async def delete_document(self, document_id: str, bank_id: str) -> Dict[str, int]:
1713
+ async def delete_document(self, document_id: str, bank_id: str) -> dict[str, int]:
1694
1714
  """
1695
1715
  Delete a document and all its associated memory units and links.
1696
1716
 
@@ -1706,22 +1726,17 @@ class MemoryEngine:
1706
1726
  async with conn.transaction():
1707
1727
  # Count units before deletion
1708
1728
  units_count = await conn.fetchval(
1709
- "SELECT COUNT(*) FROM memory_units WHERE document_id = $1",
1710
- document_id
1729
+ "SELECT COUNT(*) FROM memory_units WHERE document_id = $1", document_id
1711
1730
  )
1712
1731
 
1713
1732
  # Delete document (cascades to memory_units and all their links)
1714
1733
  deleted = await conn.fetchval(
1715
- "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id",
1716
- document_id, bank_id
1734
+ "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id", document_id, bank_id
1717
1735
  )
1718
1736
 
1719
- return {
1720
- "document_deleted": 1 if deleted else 0,
1721
- "memory_units_deleted": units_count if deleted else 0
1722
- }
1737
+ return {"document_deleted": 1 if deleted else 0, "memory_units_deleted": units_count if deleted else 0}
1723
1738
 
1724
- async def delete_memory_unit(self, unit_id: str) -> Dict[str, Any]:
1739
+ async def delete_memory_unit(self, unit_id: str) -> dict[str, Any]:
1725
1740
  """
1726
1741
  Delete a single memory unit and all its associated links.
1727
1742
 
@@ -1740,18 +1755,17 @@ class MemoryEngine:
1740
1755
  async with acquire_with_retry(pool) as conn:
1741
1756
  async with conn.transaction():
1742
1757
  # Delete the memory unit (cascades to links and associations)
1743
- deleted = await conn.fetchval(
1744
- "DELETE FROM memory_units WHERE id = $1 RETURNING id",
1745
- unit_id
1746
- )
1758
+ deleted = await conn.fetchval("DELETE FROM memory_units WHERE id = $1 RETURNING id", unit_id)
1747
1759
 
1748
1760
  return {
1749
1761
  "success": deleted is not None,
1750
1762
  "unit_id": str(deleted) if deleted else None,
1751
- "message": "Memory unit and all its links deleted successfully" if deleted else "Memory unit not found"
1763
+ "message": "Memory unit and all its links deleted successfully"
1764
+ if deleted
1765
+ else "Memory unit not found",
1752
1766
  }
1753
1767
 
1754
- async def delete_bank(self, bank_id: str, fact_type: Optional[str] = None) -> Dict[str, int]:
1768
+ async def delete_bank(self, bank_id: str, fact_type: str | None = None) -> dict[str, int]:
1755
1769
  """
1756
1770
  Delete all data for a specific agent (multi-tenant cleanup).
1757
1771
 
@@ -1780,24 +1794,27 @@ class MemoryEngine:
1780
1794
  # Delete only memories of a specific fact type
1781
1795
  units_count = await conn.fetchval(
1782
1796
  "SELECT COUNT(*) FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
1783
- bank_id, fact_type
1797
+ bank_id,
1798
+ fact_type,
1784
1799
  )
1785
1800
  await conn.execute(
1786
- "DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
1787
- bank_id, fact_type
1801
+ "DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2", bank_id, fact_type
1788
1802
  )
1789
1803
 
1790
1804
  # Note: We don't delete entities when fact_type is specified,
1791
1805
  # as they may be referenced by other memory units
1792
- return {
1793
- "memory_units_deleted": units_count,
1794
- "entities_deleted": 0
1795
- }
1806
+ return {"memory_units_deleted": units_count, "entities_deleted": 0}
1796
1807
  else:
1797
1808
  # Delete all data for the bank
1798
- units_count = await conn.fetchval("SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id)
1799
- entities_count = await conn.fetchval("SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id)
1800
- documents_count = await conn.fetchval("SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id)
1809
+ units_count = await conn.fetchval(
1810
+ "SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id
1811
+ )
1812
+ entities_count = await conn.fetchval(
1813
+ "SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id
1814
+ )
1815
+ documents_count = await conn.fetchval(
1816
+ "SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id
1817
+ )
1801
1818
 
1802
1819
  # Delete documents (cascades to chunks)
1803
1820
  await conn.execute("DELETE FROM documents WHERE bank_id = $1", bank_id)
@@ -1815,13 +1832,13 @@ class MemoryEngine:
  "memory_units_deleted": units_count,
  "entities_deleted": entities_count,
  "documents_deleted": documents_count,
- "bank_deleted": True
+ "bank_deleted": True,
  }

  except Exception as e:
  raise Exception(f"Failed to delete agent data: {str(e)}")

- async def get_graph_data(self, bank_id: Optional[str] = None, fact_type: Optional[str] = None):
+ async def get_graph_data(self, bank_id: str | None = None, fact_type: str | None = None):
  """
  Get graph data for visualization.

@@ -1851,19 +1868,23 @@ class MemoryEngine:

  where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""

- units = await conn.fetch(f"""
+ units = await conn.fetch(
+ f"""
  SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
  FROM memory_units
  {where_clause}
  ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
  LIMIT 1000
- """, *query_params)
+ """,
+ *query_params,
+ )

  # Get links, filtering to only include links between units of the selected agent
  # Use DISTINCT ON with LEAST/GREATEST to deduplicate bidirectional links
- unit_ids = [row['id'] for row in units]
+ unit_ids = [row["id"] for row in units]
  if unit_ids:
- links = await conn.fetch("""
+ links = await conn.fetch(
+ """
  SELECT DISTINCT ON (LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid))
  ml.from_unit_id,
  ml.to_unit_id,
@@ -1874,7 +1895,9 @@ class MemoryEngine:
  LEFT JOIN entities e ON ml.entity_id = e.id
  WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
  ORDER BY LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid), ml.weight DESC
- """, unit_ids)
+ """,
+ unit_ids,
+ )
  else:
  links = []

@@ -1889,8 +1912,8 @@ class MemoryEngine:
  # Build entity mapping
  entity_map = {}
  for row in unit_entities:
- unit_id = row['unit_id']
- entity_name = row['canonical_name']
+ unit_id = row["unit_id"]
+ entity_name = row["canonical_name"]
  if unit_id not in entity_map:
  entity_map[unit_id] = []
  entity_map[unit_id].append(entity_name)
@@ -1898,10 +1921,10 @@ class MemoryEngine:
  # Build nodes
  nodes = []
  for row in units:
- unit_id = row['id']
- text = row['text']
- event_date = row['event_date']
- context = row['context']
+ unit_id = row["id"]
+ text = row["text"]
+ event_date = row["event_date"]
+ context = row["context"]

  entities = entity_map.get(unit_id, [])
  entity_count = len(entities)
@@ -1914,88 +1937,91 @@ class MemoryEngine:
  else:
  color = "#42a5f5"

- nodes.append({
- "data": {
- "id": str(unit_id),
- "label": f"{text[:30]}..." if len(text) > 30 else text,
- "text": text,
- "date": event_date.isoformat() if event_date else "",
- "context": context if context else "",
- "entities": ", ".join(entities) if entities else "None",
- "color": color
+ nodes.append(
+ {
+ "data": {
+ "id": str(unit_id),
+ "label": f"{text[:30]}..." if len(text) > 30 else text,
+ "text": text,
+ "date": event_date.isoformat() if event_date else "",
+ "context": context if context else "",
+ "entities": ", ".join(entities) if entities else "None",
+ "color": color,
+ }
  }
- })
+ )

  # Build edges
  edges = []
  for row in links:
- from_id = str(row['from_unit_id'])
- to_id = str(row['to_unit_id'])
- link_type = row['link_type']
- weight = row['weight']
- entity_name = row['entity_name']
+ from_id = str(row["from_unit_id"])
+ to_id = str(row["to_unit_id"])
+ link_type = row["link_type"]
+ weight = row["weight"]
+ entity_name = row["entity_name"]

  # Color by link type
- if link_type == 'temporal':
+ if link_type == "temporal":
  color = "#00bcd4"
  line_style = "dashed"
- elif link_type == 'semantic':
+ elif link_type == "semantic":
  color = "#ff69b4"
  line_style = "solid"
- elif link_type == 'entity':
+ elif link_type == "entity":
  color = "#ffd700"
  line_style = "solid"
  else:
  color = "#999999"
  line_style = "solid"

- edges.append({
- "data": {
- "id": f"{from_id}-{to_id}-{link_type}",
- "source": from_id,
- "target": to_id,
- "linkType": link_type,
- "weight": weight,
- "entityName": entity_name if entity_name else "",
- "color": color,
- "lineStyle": line_style
+ edges.append(
+ {
+ "data": {
+ "id": f"{from_id}-{to_id}-{link_type}",
+ "source": from_id,
+ "target": to_id,
+ "linkType": link_type,
+ "weight": weight,
+ "entityName": entity_name if entity_name else "",
+ "color": color,
+ "lineStyle": line_style,
+ }
  }
- })
+ )

  # Build table rows
  table_rows = []
  for row in units:
- unit_id = row['id']
+ unit_id = row["id"]
  entities = entity_map.get(unit_id, [])

- table_rows.append({
- "id": str(unit_id),
- "text": row['text'],
- "context": row['context'] if row['context'] else "N/A",
- "occurred_start": row['occurred_start'].isoformat() if row['occurred_start'] else None,
- "occurred_end": row['occurred_end'].isoformat() if row['occurred_end'] else None,
- "mentioned_at": row['mentioned_at'].isoformat() if row['mentioned_at'] else None,
- "date": row['event_date'].strftime("%Y-%m-%d %H:%M") if row['event_date'] else "N/A", # Deprecated, kept for backwards compatibility
- "entities": ", ".join(entities) if entities else "None",
- "document_id": row['document_id'],
- "chunk_id": row['chunk_id'] if row['chunk_id'] else None,
- "fact_type": row['fact_type']
- })
-
- return {
- "nodes": nodes,
- "edges": edges,
- "table_rows": table_rows,
- "total_units": len(units)
- }
+ table_rows.append(
+ {
+ "id": str(unit_id),
+ "text": row["text"],
+ "context": row["context"] if row["context"] else "N/A",
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
+ "date": row["event_date"].strftime("%Y-%m-%d %H:%M")
+ if row["event_date"]
+ else "N/A", # Deprecated, kept for backwards compatibility
+ "entities": ", ".join(entities) if entities else "None",
+ "document_id": row["document_id"],
+ "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
+ "fact_type": row["fact_type"],
+ }
+ )
+
+ return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": len(units)}

  async def list_memory_units(
  self,
- bank_id: Optional[str] = None,
- fact_type: Optional[str] = None,
- search_query: Optional[str] = None,
+ bank_id: str | None = None,
+ fact_type: str | None = None,
+ search_query: str | None = None,
  limit: int = 100,
- offset: int = 0
+ offset: int = 0,
  ):
  """
  List memory units for table view with optional full-text search.
@@ -2042,7 +2068,7 @@ class MemoryEngine:
  {where_clause}
  """
  count_result = await conn.fetchrow(count_query, *query_params)
- total = count_result['total']
+ total = count_result["total"]

  # Get units with limit and offset
  param_count += 1
@@ -2053,32 +2079,38 @@ class MemoryEngine:
  offset_param = f"${param_count}"
  query_params.append(offset)

- units = await conn.fetch(f"""
+ units = await conn.fetch(
+ f"""
  SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id
  FROM memory_units
  {where_clause}
  ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
  LIMIT {limit_param} OFFSET {offset_param}
- """, *query_params)
+ """,
+ *query_params,
+ )

  # Get entity information for these units
  if units:
- unit_ids = [row['id'] for row in units]
- unit_entities = await conn.fetch("""
+ unit_ids = [row["id"] for row in units]
+ unit_entities = await conn.fetch(
+ """
  SELECT ue.unit_id, e.canonical_name
  FROM unit_entities ue
  JOIN entities e ON ue.entity_id = e.id
  WHERE ue.unit_id = ANY($1::uuid[])
  ORDER BY ue.unit_id
- """, unit_ids)
+ """,
+ unit_ids,
+ )
  else:
  unit_entities = []

  # Build entity mapping
  entity_map = {}
  for row in unit_entities:
- unit_id = row['unit_id']
- entity_name = row['canonical_name']
+ unit_id = row["unit_id"]
+ entity_name = row["canonical_name"]
  if unit_id not in entity_map:
  entity_map[unit_id] = []
  entity_map[unit_id].append(entity_name)
@@ -2086,36 +2118,27 @@ class MemoryEngine:
  # Build result items
  items = []
  for row in units:
- unit_id = row['id']
+ unit_id = row["id"]
  entities = entity_map.get(unit_id, [])

- items.append({
- "id": str(unit_id),
- "text": row['text'],
- "context": row['context'] if row['context'] else "",
- "date": row['event_date'].isoformat() if row['event_date'] else "",
- "fact_type": row['fact_type'],
- "mentioned_at": row['mentioned_at'].isoformat() if row['mentioned_at'] else None,
- "occurred_start": row['occurred_start'].isoformat() if row['occurred_start'] else None,
- "occurred_end": row['occurred_end'].isoformat() if row['occurred_end'] else None,
- "entities": ", ".join(entities) if entities else "",
- "chunk_id": row['chunk_id'] if row['chunk_id'] else None
- })
+ items.append(
+ {
+ "id": str(unit_id),
+ "text": row["text"],
+ "context": row["context"] if row["context"] else "",
+ "date": row["event_date"].isoformat() if row["event_date"] else "",
+ "fact_type": row["fact_type"],
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
+ "entities": ", ".join(entities) if entities else "",
+ "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
+ }
+ )

- return {
- "items": items,
- "total": total,
- "limit": limit,
- "offset": offset
- }
+ return {"items": items, "total": total, "limit": limit, "offset": offset}

- async def list_documents(
- self,
- bank_id: str,
- search_query: Optional[str] = None,
- limit: int = 100,
- offset: int = 0
- ):
+ async def list_documents(self, bank_id: str, search_query: str | None = None, limit: int = 100, offset: int = 0):
  """
  List documents with optional search and pagination.

@@ -2154,7 +2177,7 @@ class MemoryEngine:
  {where_clause}
  """
  count_result = await conn.fetchrow(count_query, *query_params)
- total = count_result['total']
+ total = count_result["total"]

  # Get documents with limit and offset (without original_text for performance)
  param_count += 1
@@ -2165,7 +2188,8 @@ class MemoryEngine:
  offset_param = f"${param_count}"
  query_params.append(offset)

- documents = await conn.fetch(f"""
+ documents = await conn.fetch(
+ f"""
  SELECT
  id,
  bank_id,
@@ -2178,11 +2202,13 @@ class MemoryEngine:
  {where_clause}
  ORDER BY created_at DESC
  LIMIT {limit_param} OFFSET {offset_param}
- """, *query_params)
+ """,
+ *query_params,
+ )

  # Get memory unit count for each document
  if documents:
- doc_ids = [(row['id'], row['bank_id']) for row in documents]
+ doc_ids = [(row["id"], row["bank_id"]) for row in documents]

  # Create placeholders for the query
  placeholders = []
@@ -2195,48 +2221,44 @@ class MemoryEngine:

  where_clause_count = " OR ".join(placeholders)

- unit_counts = await conn.fetch(f"""
+ unit_counts = await conn.fetch(
+ f"""
  SELECT document_id, bank_id, COUNT(*) as unit_count
  FROM memory_units
  WHERE {where_clause_count}
  GROUP BY document_id, bank_id
- """, *params_for_count)
+ """,
+ *params_for_count,
+ )
  else:
  unit_counts = []

  # Build count mapping
- count_map = {(row['document_id'], row['bank_id']): row['unit_count'] for row in unit_counts}
+ count_map = {(row["document_id"], row["bank_id"]): row["unit_count"] for row in unit_counts}

  # Build result items
  items = []
  for row in documents:
- doc_id = row['id']
- bank_id_val = row['bank_id']
+ doc_id = row["id"]
+ bank_id_val = row["bank_id"]
  unit_count = count_map.get((doc_id, bank_id_val), 0)

- items.append({
- "id": doc_id,
- "bank_id": bank_id_val,
- "content_hash": row['content_hash'],
- "created_at": row['created_at'].isoformat() if row['created_at'] else "",
- "updated_at": row['updated_at'].isoformat() if row['updated_at'] else "",
- "text_length": row['text_length'] or 0,
- "memory_unit_count": unit_count,
- "retain_params": row['retain_params'] if row['retain_params'] else None
- })
+ items.append(
+ {
+ "id": doc_id,
+ "bank_id": bank_id_val,
+ "content_hash": row["content_hash"],
+ "created_at": row["created_at"].isoformat() if row["created_at"] else "",
+ "updated_at": row["updated_at"].isoformat() if row["updated_at"] else "",
+ "text_length": row["text_length"] or 0,
+ "memory_unit_count": unit_count,
+ "retain_params": row["retain_params"] if row["retain_params"] else None,
+ }
+ )

- return {
- "items": items,
- "total": total,
- "limit": limit,
- "offset": offset
- }
+ return {"items": items, "total": total, "limit": limit, "offset": offset}

- async def get_document(
- self,
- document_id: str,
- bank_id: str
- ):
+ async def get_document(self, document_id: str, bank_id: str):
  """
  Get a specific document including its original_text.

@@ -2249,7 +2271,8 @@ class MemoryEngine:
  """
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
- doc = await conn.fetchrow("""
+ doc = await conn.fetchrow(
+ """
  SELECT
  id,
  bank_id,
@@ -2260,33 +2283,37 @@ class MemoryEngine:
  retain_params
  FROM documents
  WHERE id = $1 AND bank_id = $2
- """, document_id, bank_id)
+ """,
+ document_id,
+ bank_id,
+ )

  if not doc:
  return None

  # Get memory unit count
- unit_count_row = await conn.fetchrow("""
+ unit_count_row = await conn.fetchrow(
+ """
  SELECT COUNT(*) as unit_count
  FROM memory_units
  WHERE document_id = $1 AND bank_id = $2
- """, document_id, bank_id)
+ """,
+ document_id,
+ bank_id,
+ )

  return {
- "id": doc['id'],
- "bank_id": doc['bank_id'],
- "original_text": doc['original_text'],
- "content_hash": doc['content_hash'],
- "created_at": doc['created_at'].isoformat() if doc['created_at'] else "",
- "updated_at": doc['updated_at'].isoformat() if doc['updated_at'] else "",
- "memory_unit_count": unit_count_row['unit_count'] if unit_count_row else 0,
- "retain_params": doc['retain_params'] if doc['retain_params'] else None
+ "id": doc["id"],
+ "bank_id": doc["bank_id"],
+ "original_text": doc["original_text"],
+ "content_hash": doc["content_hash"],
+ "created_at": doc["created_at"].isoformat() if doc["created_at"] else "",
+ "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else "",
+ "memory_unit_count": unit_count_row["unit_count"] if unit_count_row else 0,
+ "retain_params": doc["retain_params"] if doc["retain_params"] else None,
  }

- async def get_chunk(
- self,
- chunk_id: str
- ):
+ async def get_chunk(self, chunk_id: str):
  """
  Get a specific chunk by its ID.

@@ -2298,7 +2325,8 @@ class MemoryEngine:
  """
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
- chunk = await conn.fetchrow("""
+ chunk = await conn.fetchrow(
+ """
  SELECT
  chunk_id,
  document_id,
@@ -2308,18 +2336,20 @@ class MemoryEngine:
  created_at
  FROM chunks
  WHERE chunk_id = $1
- """, chunk_id)
+ """,
+ chunk_id,
+ )

  if not chunk:
  return None

  return {
- "chunk_id": chunk['chunk_id'],
- "document_id": chunk['document_id'],
- "bank_id": chunk['bank_id'],
- "chunk_index": chunk['chunk_index'],
- "chunk_text": chunk['chunk_text'],
- "created_at": chunk['created_at'].isoformat() if chunk['created_at'] else ""
+ "chunk_id": chunk["chunk_id"],
+ "document_id": chunk["document_id"],
+ "bank_id": chunk["bank_id"],
+ "chunk_index": chunk["chunk_index"],
+ "chunk_text": chunk["chunk_text"],
+ "created_at": chunk["created_at"].isoformat() if chunk["created_at"] else "",
  }

  async def _evaluate_opinion_update_async(
  async def _evaluate_opinion_update_async(
@@ -2328,7 +2358,7 @@ class MemoryEngine:
2328
2358
  opinion_confidence: float,
2329
2359
  new_event_text: str,
2330
2360
  entity_name: str,
2331
- ) -> Optional[Dict[str, Any]]:
2361
+ ) -> dict[str, Any] | None:
2332
2362
  """
2333
2363
  Evaluate if an opinion should be updated based on a new event.
2334
2364
 
@@ -2342,16 +2372,18 @@ class MemoryEngine:
2342
2372
  Dict with 'action' ('keep'|'update'), 'new_confidence', 'new_text' (if action=='update')
2343
2373
  or None if no changes needed
2344
2374
  """
2345
- from pydantic import BaseModel, Field
2346
2375
 
2347
2376
  class OpinionEvaluation(BaseModel):
2348
2377
  """Evaluation of whether an opinion should be updated."""
2378
+
2349
2379
  action: str = Field(description="Action to take: 'keep' (no change) or 'update' (modify opinion)")
2350
2380
  reasoning: str = Field(description="Brief explanation of why this action was chosen")
2351
- new_confidence: float = Field(description="New confidence score (0.0-1.0). Can be higher, lower, or same as before.")
2352
- new_opinion_text: Optional[str] = Field(
2381
+ new_confidence: float = Field(
2382
+ description="New confidence score (0.0-1.0). Can be higher, lower, or same as before."
2383
+ )
2384
+ new_opinion_text: str | None = Field(
2353
2385
  default=None,
2354
- description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None."
2386
+ description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None.",
2355
2387
  )
2356
2388
 
2357
2389
  evaluation_prompt = f"""You are evaluating whether an existing opinion should be updated based on new information.
@@ -2381,70 +2413,63 @@ Guidelines:
  result = await self._llm_config.call(
  messages=[
  {"role": "system", "content": "You evaluate and update opinions based on new information."},
- {"role": "user", "content": evaluation_prompt}
+ {"role": "user", "content": evaluation_prompt},
  ],
  response_format=OpinionEvaluation,
  scope="memory_evaluate_opinion",
- temperature=0.3 # Lower temperature for more consistent evaluation
+ temperature=0.3, # Lower temperature for more consistent evaluation
  )

  # Only return updates if something actually changed
- if result.action == 'keep' and abs(result.new_confidence - opinion_confidence) < 0.01:
+ if result.action == "keep" and abs(result.new_confidence - opinion_confidence) < 0.01:
  return None

  return {
- 'action': result.action,
- 'reasoning': result.reasoning,
- 'new_confidence': result.new_confidence,
- 'new_text': result.new_opinion_text if result.action == 'update' else None
+ "action": result.action,
+ "reasoning": result.reasoning,
+ "new_confidence": result.new_confidence,
+ "new_text": result.new_opinion_text if result.action == "update" else None,
  }

  except Exception as e:
  logger.warning(f"Failed to evaluate opinion update: {str(e)}")
  return None

- async def _handle_form_opinion(self, task_dict: Dict[str, Any]):
+ async def _handle_form_opinion(self, task_dict: dict[str, Any]):
  """
  Handler for form opinion tasks.

  Args:
  task_dict: Dict with keys: 'bank_id', 'answer_text', 'query'
  """
- bank_id = task_dict['bank_id']
- answer_text = task_dict['answer_text']
- query = task_dict['query']
+ bank_id = task_dict["bank_id"]
+ answer_text = task_dict["answer_text"]
+ query = task_dict["query"]

- await self._extract_and_store_opinions_async(
- bank_id=bank_id,
- answer_text=answer_text,
- query=query
- )
+ await self._extract_and_store_opinions_async(bank_id=bank_id, answer_text=answer_text, query=query)

- async def _handle_reinforce_opinion(self, task_dict: Dict[str, Any]):
+ async def _handle_reinforce_opinion(self, task_dict: dict[str, Any]):
  """
  Handler for reinforce opinion tasks.

  Args:
  task_dict: Dict with keys: 'bank_id', 'created_unit_ids', 'unit_texts', 'unit_entities'
  """
- bank_id = task_dict['bank_id']
- created_unit_ids = task_dict['created_unit_ids']
- unit_texts = task_dict['unit_texts']
- unit_entities = task_dict['unit_entities']
+ bank_id = task_dict["bank_id"]
+ created_unit_ids = task_dict["created_unit_ids"]
+ unit_texts = task_dict["unit_texts"]
+ unit_entities = task_dict["unit_entities"]

  await self._reinforce_opinions_async(
- bank_id=bank_id,
- created_unit_ids=created_unit_ids,
- unit_texts=unit_texts,
- unit_entities=unit_entities
+ bank_id=bank_id, created_unit_ids=created_unit_ids, unit_texts=unit_texts, unit_entities=unit_entities
  )

  async def _reinforce_opinions_async(
  self,
  bank_id: str,
- created_unit_ids: List[str],
- unit_texts: List[str],
- unit_entities: List[List[Dict[str, str]]],
+ created_unit_ids: list[str],
+ unit_texts: list[str],
+ unit_entities: list[list[dict[str, str]]],
  ):
  """
  Background task to reinforce opinions based on newly ingested events.
@@ -2463,15 +2488,14 @@ Guidelines:
  for entities_list in unit_entities:
  for entity in entities_list:
  # Handle both Entity objects and dicts
- if hasattr(entity, 'text'):
+ if hasattr(entity, "text"):
  entity_names.add(entity.text)
  elif isinstance(entity, dict):
- entity_names.add(entity['text'])
+ entity_names.add(entity["text"])

  if not entity_names:
  return

-
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  # Find all opinions related to these entities
@@ -2486,13 +2510,12 @@ Guidelines:
  AND e.canonical_name = ANY($2::text[])
  """,
  bank_id,
- list(entity_names)
+ list(entity_names),
  )

  if not opinions:
  return

-
  # Use cached LLM config
  if self._llm_config is None:
  logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
@@ -2501,15 +2524,15 @@ Guidelines:
  # Evaluate each opinion against the new events
  updates_to_apply = []
  for opinion in opinions:
- opinion_id = str(opinion['id'])
- opinion_text = opinion['text']
- opinion_confidence = opinion['confidence_score']
- entity_name = opinion['canonical_name']
+ opinion_id = str(opinion["id"])
+ opinion_text = opinion["text"]
+ opinion_confidence = opinion["confidence_score"]
+ entity_name = opinion["canonical_name"]

  # Find all new events mentioning this entity
  relevant_events = []
  for unit_text, entities_list in zip(unit_texts, unit_entities):
- if any(e['text'] == entity_name for e in entities_list):
+ if any(e["text"] == entity_name for e in entities_list):
  relevant_events.append(unit_text)

  if not relevant_events:
@@ -2520,26 +2543,20 @@ Guidelines:

  # Evaluate if opinion should be updated
  evaluation = await self._evaluate_opinion_update_async(
- opinion_text,
- opinion_confidence,
- combined_events,
- entity_name
+ opinion_text, opinion_confidence, combined_events, entity_name
  )

  if evaluation:
- updates_to_apply.append({
- 'opinion_id': opinion_id,
- 'evaluation': evaluation
- })
+ updates_to_apply.append({"opinion_id": opinion_id, "evaluation": evaluation})

  # Apply all updates in a single transaction
  if updates_to_apply:
  async with conn.transaction():
  for update in updates_to_apply:
- opinion_id = update['opinion_id']
- evaluation = update['evaluation']
+ opinion_id = update["opinion_id"]
+ evaluation = update["evaluation"]

- if evaluation['action'] == 'update' and evaluation['new_text']:
+ if evaluation["action"] == "update" and evaluation["new_text"]:
  # Update both text and confidence
  await conn.execute(
  """
@@ -2547,9 +2564,9 @@ Guidelines:
  SET text = $1, confidence_score = $2, updated_at = NOW()
  WHERE id = $3
  """,
- evaluation['new_text'],
- evaluation['new_confidence'],
- uuid.UUID(opinion_id)
+ evaluation["new_text"],
+ evaluation["new_confidence"],
+ uuid.UUID(opinion_id),
  )
  else:
  # Only update confidence
@@ -2559,8 +2576,8 @@ Guidelines:
  SET confidence_score = $1, updated_at = NOW()
  WHERE id = $2
  """,
- evaluation['new_confidence'],
- uuid.UUID(opinion_id)
+ evaluation["new_confidence"],
+ uuid.UUID(opinion_id),
  )

  else:
@@ -2569,6 +2586,7 @@ Guidelines:
  except Exception as e:
  logger.error(f"[REINFORCE] Error during opinion reinforcement: {str(e)}")
  import traceback
+
  traceback.print_exc()

  # ==================== bank profile Methods ====================
@@ -2587,11 +2605,7 @@ Guidelines:
  pool = await self._get_pool()
  return await bank_utils.get_bank_profile(pool, bank_id)

- async def update_bank_disposition(
- self,
- bank_id: str,
- disposition: Dict[str, int]
- ) -> None:
+ async def update_bank_disposition(self, bank_id: str, disposition: dict[str, int]) -> None:
  """
  Update bank disposition traits.

@@ -2602,12 +2616,7 @@ Guidelines:
  pool = await self._get_pool()
  await bank_utils.update_bank_disposition(pool, bank_id, disposition)

- async def merge_bank_background(
- self,
- bank_id: str,
- new_info: str,
- update_disposition: bool = True
- ) -> dict:
+ async def merge_bank_background(self, bank_id: str, new_info: str, update_disposition: bool = True) -> dict:
  """
  Merge new background information with existing background using LLM.
  Normalizes to first person ("I") and resolves conflicts.
@@ -2622,9 +2631,7 @@ Guidelines:
  Dict with 'background' (str) and optionally 'disposition' (dict) keys
  """
  pool = await self._get_pool()
- return await bank_utils.merge_bank_background(
- pool, self._llm_config, bank_id, new_info, update_disposition
- )
+ return await bank_utils.merge_bank_background(pool, self._llm_config, bank_id, new_info, update_disposition)

  async def list_banks(self) -> list:
  """
@@ -2685,19 +2692,21 @@ Guidelines:
  budget=budget,
  max_tokens=4096,
  enable_trace=False,
- fact_type=['experience', 'world', 'opinion'],
- include_entities=True
+ fact_type=["experience", "world", "opinion"],
+ include_entities=True,
  )
  recall_time = time.time() - recall_start

  all_results = search_result.results

  # Split results by fact type for structured response
- agent_results = [r for r in all_results if r.fact_type == 'experience']
- world_results = [r for r in all_results if r.fact_type == 'world']
- opinion_results = [r for r in all_results if r.fact_type == 'opinion']
+ agent_results = [r for r in all_results if r.fact_type == "experience"]
+ world_results = [r for r in all_results if r.fact_type == "world"]
+ opinion_results = [r for r in all_results if r.fact_type == "opinion"]

- log_buffer.append(f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s")
+ log_buffer.append(
+ f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s"
+ )

  # Format facts for LLM
  agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
@@ -2728,47 +2737,34 @@ Guidelines:

  llm_start = time.time()
  answer_text = await self._llm_config.call(
- messages=[
- {"role": "system", "content": system_message},
- {"role": "user", "content": prompt}
- ],
+ messages=[{"role": "system", "content": system_message}, {"role": "user", "content": prompt}],
  scope="memory_think",
  temperature=0.9,
- max_completion_tokens=1000
+ max_completion_tokens=1000,
  )
  llm_time = time.time() - llm_start

  answer_text = answer_text.strip()

  # Submit form_opinion task for background processing
- await self._task_backend.submit_task({
- 'type': 'form_opinion',
- 'bank_id': bank_id,
- 'answer_text': answer_text,
- 'query': query
- })
+ await self._task_backend.submit_task(
+ {"type": "form_opinion", "bank_id": bank_id, "answer_text": answer_text, "query": query}
+ )

  total_time = time.time() - reflect_start
- log_buffer.append(f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s")
+ log_buffer.append(
+ f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s"
+ )
  logger.info("\n" + "\n".join(log_buffer))

  # Return response with facts split by type
  return ReflectResult(
  text=answer_text,
- based_on={
- "world": world_results,
- "experience": agent_results,
- "opinion": opinion_results
- },
- new_opinions=[] # Opinions are being extracted asynchronously
+ based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
+ new_opinions=[], # Opinions are being extracted asynchronously
  )

- async def _extract_and_store_opinions_async(
- self,
- bank_id: str,
- answer_text: str,
- query: str
- ):
+ async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str, query: str):
  """
  Background task to extract and store opinions from think response.

@@ -2781,33 +2777,27 @@ Guidelines:
  """
  try:
  # Extract opinions from the answer
- new_opinions = await think_utils.extract_opinions_from_text(
- self._llm_config, text=answer_text, query=query
- )
+ new_opinions = await think_utils.extract_opinions_from_text(self._llm_config, text=answer_text, query=query)

  # Store new opinions
  if new_opinions:
- from datetime import datetime, timezone
- current_time = datetime.now(timezone.utc)
+ from datetime import datetime
+
+ current_time = datetime.now(UTC)
  for opinion in new_opinions:
  await self.retain_async(
  bank_id=bank_id,
  content=opinion.opinion,
  context=f"formed during thinking about: {query}",
  event_date=current_time,
- fact_type_override='opinion',
- confidence_score=opinion.confidence
+ fact_type_override="opinion",
+ confidence_score=opinion.confidence,
  )

  except Exception as e:
  logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")

- async def get_entity_observations(
- self,
- bank_id: str,
- entity_id: str,
- limit: int = 10
- ) -> List[EntityObservation]:
+ async def get_entity_observations(self, bank_id: str, entity_id: str, limit: int = 10) -> list[EntityObservation]:
  """
  Get observations linked to an entity.

@@ -2832,23 +2822,18 @@ Guidelines:
  ORDER BY mu.mentioned_at DESC
  LIMIT $3
  """,
- bank_id, uuid.UUID(entity_id), limit
+ bank_id,
+ uuid.UUID(entity_id),
+ limit,
  )

  observations = []
  for row in rows:
- mentioned_at = row['mentioned_at'].isoformat() if row['mentioned_at'] else None
- observations.append(EntityObservation(
- text=row['text'],
- mentioned_at=mentioned_at
- ))
+ mentioned_at = row["mentioned_at"].isoformat() if row["mentioned_at"] else None
+ observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
  return observations

- async def list_entities(
- self,
- bank_id: str,
- limit: int = 100
- ) -> List[Dict[str, Any]]:
+ async def list_entities(self, bank_id: str, limit: int = 100) -> list[dict[str, Any]]:
  """
  List all entities for a bank.

@@ -2869,39 +2854,37 @@ Guidelines:
  ORDER BY mention_count DESC, last_seen DESC
  LIMIT $2
  """,
- bank_id, limit
+ bank_id,
+ limit,
  )

  entities = []
  for row in rows:
  # Handle metadata - may be dict, JSON string, or None
- metadata = row['metadata']
+ metadata = row["metadata"]
  if metadata is None:
  metadata = {}
  elif isinstance(metadata, str):
  import json
+
  try:
  metadata = json.loads(metadata)
  except json.JSONDecodeError:
  metadata = {}

- entities.append({
- 'id': str(row['id']),
- 'canonical_name': row['canonical_name'],
- 'mention_count': row['mention_count'],
- 'first_seen': row['first_seen'].isoformat() if row['first_seen'] else None,
- 'last_seen': row['last_seen'].isoformat() if row['last_seen'] else None,
- 'metadata': metadata
- })
+ entities.append(
+ {
+ "id": str(row["id"]),
+ "canonical_name": row["canonical_name"],
+ "mention_count": row["mention_count"],
+ "first_seen": row["first_seen"].isoformat() if row["first_seen"] else None,
+ "last_seen": row["last_seen"].isoformat() if row["last_seen"] else None,
+ "metadata": metadata,
+ }
+ )
  return entities

- async def get_entity_state(
- self,
- bank_id: str,
- entity_id: str,
- entity_name: str,
- limit: int = 10
- ) -> EntityState:
+ async def get_entity_state(self, bank_id: str, entity_id: str, entity_name: str, limit: int = 10) -> EntityState:
  """
  Get the current state (mental model) of an entity.

@@ -2915,20 +2898,11 @@ Guidelines:
  EntityState with observations
  """
  observations = await self.get_entity_observations(bank_id, entity_id, limit)
- return EntityState(
- entity_id=entity_id,
- canonical_name=entity_name,
- observations=observations
- )
+ return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=observations)

  async def regenerate_entity_observations(
- self,
- bank_id: str,
- entity_id: str,
- entity_name: str,
- version: str | None = None,
- conn=None
- ) -> List[str]:
+ self, bank_id: str, entity_id: str, entity_name: str, version: str | None = None, conn=None
+ ) -> list[str]:
  """
  Regenerate observations for an entity by:
  1. Checking version for deduplication (if provided)
@@ -2973,7 +2947,8 @@ Guidelines:
  FROM entities
  WHERE id = $1 AND bank_id = $2
  """,
- entity_uuid, bank_id
+ entity_uuid,
+ bank_id,
  )

  if current_last_seen and current_last_seen.isoformat() != version:
@@ -2991,7 +2966,8 @@ Guidelines:
  ORDER BY mu.occurred_start DESC
  LIMIT 50
  """,
- bank_id, entity_uuid
+ bank_id,
+ entity_uuid,
  )

  if not rows:
@@ -3000,21 +2976,19 @@ Guidelines:
  # Convert to MemoryFact objects for the observation extraction
  facts = []
  for row in rows:
- occurred_start = row['occurred_start'].isoformat() if row['occurred_start'] else None
- facts.append(MemoryFact(
- id=str(row['id']),
- text=row['text'],
- fact_type=row['fact_type'],
- context=row['context'],
- occurred_start=occurred_start
- ))
+ occurred_start = row["occurred_start"].isoformat() if row["occurred_start"] else None
+ facts.append(
+ MemoryFact(
+ id=str(row["id"]),
+ text=row["text"],
+ fact_type=row["fact_type"],
+ context=row["context"],
+ occurred_start=occurred_start,
+ )
+ )

  # Step 3: Extract observations using LLM (no personality)
- observations = await observation_utils.extract_observations_from_facts(
- self._llm_config,
- entity_name,
- facts
- )
+ observations = await observation_utils.extract_observations_from_facts(self._llm_config, entity_name, facts)

  if not observations:
  return []
@@ -3036,13 +3010,12 @@ Guidelines:
  AND ue.entity_id = $2
  )
  """,
- bank_id, entity_uuid
+ bank_id,
+ entity_uuid,
  )

  # Generate embeddings for new observations
- embeddings = await embedding_utils.generate_embeddings_batch(
- self.embeddings, observations
- )
+ embeddings = await embedding_utils.generate_embeddings_batch(self.embeddings, observations)

  # Insert new observations
  current_time = utcnow()
@@ -3066,9 +3039,9 @@ Guidelines:
  current_time,
  current_time,
  current_time,
- current_time
+ current_time,
  )
- obs_id = str(result['id'])
+ obs_id = str(result["id"])
  created_ids.append(obs_id)

  # Link observation to entity
@@ -3077,7 +3050,8 @@ Guidelines:
  INSERT INTO unit_entities (unit_id, entity_id)
  VALUES ($1, $2)
  """,
- uuid.UUID(obs_id), entity_uuid
+ uuid.UUID(obs_id),
+ entity_uuid,
  )

  return created_ids
@@ -3092,11 +3066,7 @@ Guidelines:
  return await do_db_operations(acquired_conn)

  async def _regenerate_observations_sync(
- self,
- bank_id: str,
- entity_ids: List[str],
- min_facts: int = 5,
- conn=None
+ self, bank_id: str, entity_ids: list[str], min_facts: int = 5, conn=None
  ) -> None:
  """
  Regenerate observations for entities synchronously (called during retain).
@@ -3123,9 +3093,10 @@ Guidelines:
  SELECT id, canonical_name FROM entities
  WHERE id = ANY($1) AND bank_id = $2
  """,
- entity_uuids, bank_id
+ entity_uuids,
+ bank_id,
  )
- entity_names = {row['id']: row['canonical_name'] for row in entity_rows}
+ entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

  fact_counts = await conn.fetch(
  """
@@ -3135,9 +3106,10 @@ Guidelines:
  WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
  GROUP BY ue.entity_id
  """,
- entity_uuids, bank_id
+ entity_uuids,
+ bank_id,
  )
- entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+ entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}
  else:
  # Acquire a new connection (standalone call)
  pool = await self._get_pool()
@@ -3147,9 +3119,10 @@ Guidelines:
  SELECT id, canonical_name FROM entities
  WHERE id = ANY($1) AND bank_id = $2
  """,
- entity_uuids, bank_id
+ entity_uuids,
+ bank_id,
  )
- entity_names = {row['id']: row['canonical_name'] for row in entity_rows}
+ entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

  fact_counts = await acquired_conn.fetch(
  """
@@ -3159,9 +3132,10 @@ Guidelines:
  WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
  GROUP BY ue.entity_id
  """,
- entity_uuids, bank_id
+ entity_uuids,
+ bank_id,
  )
- entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+ entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}

  # Filter entities that meet the threshold
  entities_to_process = []
@@ -3183,11 +3157,9 @@ Guidelines:
  except Exception as e:
  logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")

- await asyncio.gather(*[
- process_entity(eid, name) for eid, name in entities_to_process
- ])
+ await asyncio.gather(*[process_entity(eid, name) for eid, name in entities_to_process])

- async def _handle_regenerate_observations(self, task_dict: Dict[str, Any]):
+ async def _handle_regenerate_observations(self, task_dict: dict[str, Any]):
  """
  Handler for regenerate_observations tasks.

@@ -3197,12 +3169,12 @@ Guidelines:
  - 'entity_id', 'entity_name': Process single entity (legacy)
  """
  try:
- bank_id = task_dict.get('bank_id')
+ bank_id = task_dict.get("bank_id")

  # New format: multiple entity_ids
- if 'entity_ids' in task_dict:
- entity_ids = task_dict.get('entity_ids', [])
- min_facts = task_dict.get('min_facts', 5)
+ if "entity_ids" in task_dict:
+ entity_ids = task_dict.get("entity_ids", [])
+ min_facts = task_dict.get("min_facts", 5)

  if not bank_id or not entity_ids:
  logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
@@ -3215,31 +3187,37 @@ Guidelines:
  try:
  # Fetch entity name and check fact count
  import uuid as uuid_module
+
  entity_uuid = uuid_module.UUID(entity_id) if isinstance(entity_id, str) else entity_id

  # First check if entity exists
  entity_exists = await conn.fetchrow(
  "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
- entity_uuid, bank_id
+ entity_uuid,
+ bank_id,
  )

  if not entity_exists:
  logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
  continue

- entity_name = entity_exists['canonical_name']
+ entity_name = entity_exists["canonical_name"]

  # Count facts linked to this entity
- fact_count = await conn.fetchval(
- "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1",
- entity_uuid
- ) or 0
+ fact_count = (
+ await conn.fetchval(
+ "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1", entity_uuid
+ )
+ or 0
+ )

  # Only regenerate if entity has enough facts
  if fact_count >= min_facts:
  await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
  else:
- logger.debug(f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)")
+ logger.debug(
+ f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
+ )

  except Exception as e:
  logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
@@ -3247,9 +3225,9 @@ Guidelines:

  # Legacy format: single entity
  else:
- entity_id = task_dict.get('entity_id')
- entity_name = task_dict.get('entity_name')
- version = task_dict.get('version')
+ entity_id = task_dict.get("entity_id")
+ entity_name = task_dict.get("entity_name")
+ version = task_dict.get("version")

  if not all([bank_id, entity_id, entity_name]):
  logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
@@ -3260,5 +3238,5 @@ Guidelines:
  except Exception as e:
  logger.error(f"[OBSERVATIONS] Error regenerating observations: {e}")
  import traceback
- traceback.print_exc()

+ traceback.print_exc()