hindsight-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. hindsight_api/__init__.py +10 -9
  2. hindsight_api/alembic/env.py +5 -8
  3. hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
  4. hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
  5. hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
  6. hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
  7. hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
  8. hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
  9. hindsight_api/api/__init__.py +10 -10
  10. hindsight_api/api/http.py +575 -593
  11. hindsight_api/api/mcp.py +31 -33
  12. hindsight_api/banner.py +13 -6
  13. hindsight_api/config.py +17 -12
  14. hindsight_api/engine/__init__.py +9 -9
  15. hindsight_api/engine/cross_encoder.py +23 -27
  16. hindsight_api/engine/db_utils.py +5 -4
  17. hindsight_api/engine/embeddings.py +22 -21
  18. hindsight_api/engine/entity_resolver.py +81 -75
  19. hindsight_api/engine/llm_wrapper.py +74 -88
  20. hindsight_api/engine/memory_engine.py +663 -673
  21. hindsight_api/engine/query_analyzer.py +100 -97
  22. hindsight_api/engine/response_models.py +105 -106
  23. hindsight_api/engine/retain/__init__.py +9 -16
  24. hindsight_api/engine/retain/bank_utils.py +34 -58
  25. hindsight_api/engine/retain/chunk_storage.py +4 -12
  26. hindsight_api/engine/retain/deduplication.py +9 -28
  27. hindsight_api/engine/retain/embedding_processing.py +4 -11
  28. hindsight_api/engine/retain/embedding_utils.py +3 -4
  29. hindsight_api/engine/retain/entity_processing.py +7 -17
  30. hindsight_api/engine/retain/fact_extraction.py +155 -165
  31. hindsight_api/engine/retain/fact_storage.py +11 -23
  32. hindsight_api/engine/retain/link_creation.py +11 -39
  33. hindsight_api/engine/retain/link_utils.py +166 -95
  34. hindsight_api/engine/retain/observation_regeneration.py +39 -52
  35. hindsight_api/engine/retain/orchestrator.py +72 -62
  36. hindsight_api/engine/retain/types.py +49 -43
  37. hindsight_api/engine/search/__init__.py +15 -1
  38. hindsight_api/engine/search/fusion.py +6 -15
  39. hindsight_api/engine/search/graph_retrieval.py +234 -0
  40. hindsight_api/engine/search/mpfp_retrieval.py +438 -0
  41. hindsight_api/engine/search/observation_utils.py +9 -16
  42. hindsight_api/engine/search/reranking.py +4 -7
  43. hindsight_api/engine/search/retrieval.py +388 -193
  44. hindsight_api/engine/search/scoring.py +5 -7
  45. hindsight_api/engine/search/temporal_extraction.py +8 -11
  46. hindsight_api/engine/search/think_utils.py +115 -39
  47. hindsight_api/engine/search/trace.py +68 -38
  48. hindsight_api/engine/search/tracer.py +49 -35
  49. hindsight_api/engine/search/types.py +22 -16
  50. hindsight_api/engine/task_backend.py +21 -26
  51. hindsight_api/engine/utils.py +25 -10
  52. hindsight_api/main.py +21 -40
  53. hindsight_api/mcp_local.py +190 -0
  54. hindsight_api/metrics.py +44 -30
  55. hindsight_api/migrations.py +10 -8
  56. hindsight_api/models.py +60 -72
  57. hindsight_api/pg0.py +64 -337
  58. hindsight_api/server.py +3 -6
  59. {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +6 -5
  60. hindsight_api-0.1.6.dist-info/RECORD +64 -0
  61. {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
  62. hindsight_api-0.1.4.dist-info/RECORD +0 -61
  63. {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
@@ -8,22 +8,23 @@ This implements a sophisticated memory architecture that combines:
  4. Spreading activation: Search through the graph with activation decay
  5. Dynamic weighting: Recency and frequency-based importance
  """
- import json
- import os
- from datetime import datetime, timedelta, timezone
- from typing import Any, Dict, List, Optional, Tuple, Union, TypedDict, TYPE_CHECKING
- import asyncpg
+
  import asyncio
- from .embeddings import Embeddings, create_embeddings_from_env
- from .cross_encoder import CrossEncoderModel, create_cross_encoder_from_env
+ import logging
  import time
- import numpy as np
  import uuid
- import logging
+ from datetime import UTC, datetime, timedelta
+ from typing import TYPE_CHECKING, Any, TypedDict
+
+ import asyncpg
+ import numpy as np
  from pydantic import BaseModel, Field

+ from .cross_encoder import CrossEncoderModel
+ from .embeddings import Embeddings, create_embeddings_from_env
+
  if TYPE_CHECKING:
- from ..config import HindsightConfig
+ pass


  class RetainContentDict(TypedDict, total=False):
@@ -36,30 +37,31 @@ class RetainContentDict(TypedDict, total=False):
  metadata: Custom key-value metadata (optional)
  document_id: Document ID for this content item (optional)
  """
+
  content: str # Required
  context: str
  event_date: datetime
- metadata: Dict[str, str]
+ metadata: dict[str, str]
  document_id: str

- from .query_analyzer import QueryAnalyzer
- from .search.scoring import (
- calculate_recency_weight,
- calculate_frequency_weight,
- )
+
+ from enum import Enum
+
+ from ..pg0 import EmbeddedPostgres
  from .entity_resolver import EntityResolver
- from .retain import embedding_utils, bank_utils
- from .search import think_utils, observation_utils
  from .llm_wrapper import LLMConfig
- from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation, VALID_RECALL_FACT_TYPES
- from .task_backend import TaskBackend, AsyncIOQueueBackend
+ from .query_analyzer import QueryAnalyzer
+ from .response_models import VALID_RECALL_FACT_TYPES, EntityObservation, EntityState, MemoryFact, ReflectResult
+ from .response_models import RecallResult as RecallResultModel
+ from .retain import bank_utils, embedding_utils
+ from .search import observation_utils, think_utils
  from .search.reranking import CrossEncoderReranker
- from ..pg0 import EmbeddedPostgres
- from enum import Enum
+ from .task_backend import AsyncIOQueueBackend, TaskBackend


  class Budget(str, Enum):
  """Budget levels for recall/reflect operations."""
+
  LOW = "low"
  MID = "mid"
  HIGH = "high"
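For orientation, a minimal sketch (not part of the package) of how the `RetainContentDict` payload and the `Budget` enum from the hunks above fit together; the field names and enum values come from the diff, the concrete values are made up:

```python
from datetime import UTC, datetime
from enum import Enum
from typing import TypedDict


class Budget(str, Enum):
    LOW = "low"
    MID = "mid"
    HIGH = "high"


class RetainContentDict(TypedDict, total=False):
    content: str  # required
    context: str
    event_date: datetime
    metadata: dict[str, str]
    document_id: str


# One content item, shaped the way the retain/retain_batch_async APIs shown later expect it.
item: RetainContentDict = {
    "content": "Alice moved to Berlin.",
    "context": "chat transcript",
    "event_date": datetime.now(UTC),
    "metadata": {"source": "example"},
}

print(Budget.MID.value, item["content"])
```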
@@ -67,20 +69,20 @@ class Budget(str, Enum):

  def utcnow():
  """Get current UTC time with timezone info."""
- return datetime.now(timezone.utc)
+ return datetime.now(UTC)


  # Logger for memory system
  logger = logging.getLogger(__name__)

- from .db_utils import acquire_with_retry, retry_with_backoff
-
  import tiktoken
- from dateutil import parser as date_parser
+
+ from .db_utils import acquire_with_retry

  # Cache tiktoken encoding for token budget filtering (module-level singleton)
  _TIKTOKEN_ENCODING = None

+
  def _get_tiktoken_encoding():
  """Get cached tiktoken encoding (cl100k_base for GPT-4/3.5)."""
  global _TIKTOKEN_ENCODING
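The hunk above caches a module-level `cl100k_base` encoding for token-budget filtering. As a rough illustration (a sketch, not the package's `_filter_by_token_budget`), trimming ranked results to a token budget with tiktoken can look like this:

```python
import tiktoken

_ENCODING = None


def get_encoding():
    # Cache the encoding at module level; creating it repeatedly is comparatively slow.
    global _ENCODING
    if _ENCODING is None:
        _ENCODING = tiktoken.get_encoding("cl100k_base")
    return _ENCODING


def filter_by_token_budget(texts: list[str], max_tokens: int) -> tuple[list[str], int]:
    # Keep results in ranked order until the cumulative token count would exceed the budget.
    enc = get_encoding()
    kept, total = [], 0
    for text in texts:
        n = len(enc.encode(text))
        if total + n > max_tokens:
            break
        kept.append(text)
        total += n
    return kept, total


print(filter_by_token_budget(["alpha beta", "gamma delta epsilon"], max_tokens=4))
```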
@@ -102,17 +104,17 @@ class MemoryEngine:

  def __init__(
  self,
- db_url: Optional[str] = None,
- memory_llm_provider: Optional[str] = None,
- memory_llm_api_key: Optional[str] = None,
- memory_llm_model: Optional[str] = None,
- memory_llm_base_url: Optional[str] = None,
- embeddings: Optional[Embeddings] = None,
- cross_encoder: Optional[CrossEncoderModel] = None,
- query_analyzer: Optional[QueryAnalyzer] = None,
+ db_url: str | None = None,
+ memory_llm_provider: str | None = None,
+ memory_llm_api_key: str | None = None,
+ memory_llm_model: str | None = None,
+ memory_llm_base_url: str | None = None,
+ embeddings: Embeddings | None = None,
+ cross_encoder: CrossEncoderModel | None = None,
+ query_analyzer: QueryAnalyzer | None = None,
  pool_min_size: int = 5,
  pool_max_size: int = 100,
- task_backend: Optional[TaskBackend] = None,
+ task_backend: TaskBackend | None = None,
  run_migrations: bool = True,
  ):
  """
@@ -138,6 +140,7 @@ class MemoryEngine:
  """
  # Load config from environment for any missing parameters
  from ..config import get_config
+
  config = get_config()

  # Apply defaults from config
@@ -147,8 +150,8 @@ class MemoryEngine:
  memory_llm_model = memory_llm_model or config.llm_model
  memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
  # Track pg0 instance (if used)
- self._pg0: Optional[EmbeddedPostgres] = None
- self._pg0_instance_name: Optional[str] = None
+ self._pg0: EmbeddedPostgres | None = None
+ self._pg0_instance_name: str | None = None

  # Initialize PostgreSQL connection URL
  # The actual URL will be set during initialize() after starting the server
@@ -175,7 +178,6 @@ class MemoryEngine:
  self._pg0_port = None
  self.db_url = db_url

-
  # Set default base URL if not provided
  if memory_llm_base_url is None:
  if memory_llm_provider.lower() == "groq":
@@ -206,6 +208,7 @@ class MemoryEngine:
  self.query_analyzer = query_analyzer
  else:
  from .query_analyzer import DateparserQueryAnalyzer
+
  self.query_analyzer = DateparserQueryAnalyzer()

  # Initialize LLM configuration
@@ -224,10 +227,7 @@ class MemoryEngine:
  self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)

  # Initialize task backend
- self._task_backend = task_backend or AsyncIOQueueBackend(
- batch_size=100,
- batch_interval=1.0
- )
+ self._task_backend = task_backend or AsyncIOQueueBackend(batch_size=100, batch_interval=1.0)

  # Backpressure mechanism: limit concurrent searches to prevent overwhelming the database
  # Limit concurrent searches to prevent connection pool exhaustion
@@ -243,14 +243,14 @@ class MemoryEngine:
243
243
  # initialize encoding eagerly to avoid delaying the first time
244
244
  _get_tiktoken_encoding()
245
245
 
246
- async def _handle_access_count_update(self, task_dict: Dict[str, Any]):
246
+ async def _handle_access_count_update(self, task_dict: dict[str, Any]):
247
247
  """
248
248
  Handler for access count update tasks.
249
249
 
250
250
  Args:
251
251
  task_dict: Dict with 'node_ids' key containing list of node IDs to update
252
252
  """
253
- node_ids = task_dict.get('node_ids', [])
253
+ node_ids = task_dict.get("node_ids", [])
254
254
  if not node_ids:
255
255
  return
256
256
 
@@ -260,13 +260,12 @@ class MemoryEngine:
260
260
  uuid_list = [uuid.UUID(nid) for nid in node_ids]
261
261
  async with acquire_with_retry(pool) as conn:
262
262
  await conn.execute(
263
- "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
264
- uuid_list
263
+ "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])", uuid_list
265
264
  )
266
265
  except Exception as e:
267
266
  logger.error(f"Access count handler: Error updating access counts: {e}")
268
267
 
269
- async def _handle_batch_retain(self, task_dict: Dict[str, Any]):
268
+ async def _handle_batch_retain(self, task_dict: dict[str, Any]):
270
269
  """
271
270
  Handler for batch retain tasks.
272
271
 
@@ -274,23 +273,23 @@ class MemoryEngine:
274
273
  task_dict: Dict with 'bank_id', 'contents'
275
274
  """
276
275
  try:
277
- bank_id = task_dict.get('bank_id')
278
- contents = task_dict.get('contents', [])
279
-
280
- logger.info(f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items")
276
+ bank_id = task_dict.get("bank_id")
277
+ contents = task_dict.get("contents", [])
281
278
 
282
- await self.retain_batch_async(
283
- bank_id=bank_id,
284
- contents=contents
279
+ logger.info(
280
+ f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
285
281
  )
286
282
 
283
+ await self.retain_batch_async(bank_id=bank_id, contents=contents)
284
+
287
285
  logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
288
286
  except Exception as e:
289
287
  logger.error(f"Batch retain handler: Error processing batch retain: {e}")
290
288
  import traceback
289
+
291
290
  traceback.print_exc()
292
291
 
293
- async def execute_task(self, task_dict: Dict[str, Any]):
292
+ async def execute_task(self, task_dict: dict[str, Any]):
294
293
  """
295
294
  Execute a task by routing it to the appropriate handler.
296
295
 
@@ -301,9 +300,9 @@ class MemoryEngine:
301
300
  task_dict: Task dictionary with 'type' key and other payload data
302
301
  Example: {'type': 'access_count_update', 'node_ids': [...]}
303
302
  """
304
- task_type = task_dict.get('type')
305
- operation_id = task_dict.get('operation_id')
306
- retry_count = task_dict.get('retry_count', 0)
303
+ task_type = task_dict.get("type")
304
+ operation_id = task_dict.get("operation_id")
305
+ retry_count = task_dict.get("retry_count", 0)
307
306
  max_retries = 3
308
307
 
309
308
  # Check if operation was cancelled (only for tasks with operation_id)
@@ -312,8 +311,7 @@ class MemoryEngine:
312
311
  pool = await self._get_pool()
313
312
  async with acquire_with_retry(pool) as conn:
314
313
  result = await conn.fetchrow(
315
- "SELECT id FROM async_operations WHERE id = $1",
316
- uuid.UUID(operation_id)
314
+ "SELECT id FROM async_operations WHERE id = $1", uuid.UUID(operation_id)
317
315
  )
318
316
  if not result:
319
317
  # Operation was cancelled, skip processing
@@ -324,15 +322,15 @@ class MemoryEngine:
324
322
  # Continue with processing if we can't check status
325
323
 
326
324
  try:
327
- if task_type == 'access_count_update':
325
+ if task_type == "access_count_update":
328
326
  await self._handle_access_count_update(task_dict)
329
- elif task_type == 'reinforce_opinion':
327
+ elif task_type == "reinforce_opinion":
330
328
  await self._handle_reinforce_opinion(task_dict)
331
- elif task_type == 'form_opinion':
329
+ elif task_type == "form_opinion":
332
330
  await self._handle_form_opinion(task_dict)
333
- elif task_type == 'batch_retain':
331
+ elif task_type == "batch_retain":
334
332
  await self._handle_batch_retain(task_dict)
335
- elif task_type == 'regenerate_observations':
333
+ elif task_type == "regenerate_observations":
336
334
  await self._handle_regenerate_observations(task_dict)
337
335
  else:
338
336
  logger.error(f"Unknown task type: {task_type}")
@@ -347,14 +345,17 @@ class MemoryEngine:
347
345
 
348
346
  except Exception as e:
349
347
  # Task failed - check if we should retry
350
- logger.error(f"Task execution failed (attempt {retry_count + 1}/{max_retries + 1}): {task_type}, error: {e}")
348
+ logger.error(
349
+ f"Task execution failed (attempt {retry_count + 1}/{max_retries + 1}): {task_type}, error: {e}"
350
+ )
351
351
  import traceback
352
+
352
353
  error_traceback = traceback.format_exc()
353
354
  traceback.print_exc()
354
355
 
355
356
  if retry_count < max_retries:
356
357
  # Reschedule with incremented retry count
357
- task_dict['retry_count'] = retry_count + 1
358
+ task_dict["retry_count"] = retry_count + 1
358
359
  logger.info(f"Rescheduling task {task_type} (retry {retry_count + 1}/{max_retries})")
359
360
  await self._task_backend.submit_task(task_dict)
360
361
  else:
@@ -368,10 +369,7 @@ class MemoryEngine:
368
369
  try:
369
370
  pool = await self._get_pool()
370
371
  async with acquire_with_retry(pool) as conn:
371
- await conn.execute(
372
- "DELETE FROM async_operations WHERE id = $1",
373
- uuid.UUID(operation_id)
374
- )
372
+ await conn.execute("DELETE FROM async_operations WHERE id = $1", uuid.UUID(operation_id))
375
373
  except Exception as e:
376
374
  logger.error(f"Failed to delete async operation record {operation_id}: {e}")
377
375
 
@@ -391,7 +389,7 @@ class MemoryEngine:
391
389
  WHERE id = $1
392
390
  """,
393
391
  uuid.UUID(operation_id),
394
- truncated_error
392
+ truncated_error,
395
393
  )
396
394
  logger.info(f"Marked async operation as failed: {operation_id}")
397
395
  except Exception as e:
@@ -406,8 +404,6 @@ class MemoryEngine:
406
404
  if self._initialized:
407
405
  return
408
406
 
409
- import concurrent.futures
410
-
411
407
  # Run model loading in thread pool (CPU-bound) in parallel with pg0 startup
412
408
  loop = asyncio.get_event_loop()
413
409
 
@@ -429,10 +425,7 @@ class MemoryEngine:
429
425
  """Initialize embedding model."""
430
426
  # For local providers, run in thread pool to avoid blocking event loop
431
427
  if self.embeddings.provider_name == "local":
432
- await loop.run_in_executor(
433
- None,
434
- lambda: asyncio.run(self.embeddings.initialize())
435
- )
428
+ await loop.run_in_executor(None, lambda: asyncio.run(self.embeddings.initialize()))
436
429
  else:
437
430
  await self.embeddings.initialize()
438
431
 
@@ -441,10 +434,7 @@ class MemoryEngine:
441
434
  cross_encoder = self._cross_encoder_reranker.cross_encoder
442
435
  # For local providers, run in thread pool to avoid blocking event loop
443
436
  if cross_encoder.provider_name == "local":
444
- await loop.run_in_executor(
445
- None,
446
- lambda: asyncio.run(cross_encoder.initialize())
447
- )
437
+ await loop.run_in_executor(None, lambda: asyncio.run(cross_encoder.initialize()))
448
438
  else:
449
439
  await cross_encoder.initialize()
450
440
 
@@ -469,6 +459,7 @@ class MemoryEngine:
469
459
  # Run database migrations if enabled
470
460
  if self._run_migrations:
471
461
  from ..migrations import run_migrations
462
+
472
463
  logger.info("Running database migrations...")
473
464
  run_migrations(self.db_url)
474
465
 
@@ -555,7 +546,6 @@ class MemoryEngine:
555
546
  self._pg0 = None
556
547
  logger.info("pg0 stopped")
557
548
 
558
-
559
549
  async def wait_for_background_tasks(self):
560
550
  """
561
551
  Wait for all pending background tasks to complete.
@@ -563,7 +553,7 @@ class MemoryEngine:
563
553
  This is useful in tests to ensure background tasks (like opinion reinforcement)
564
554
  complete before making assertions.
565
555
  """
566
- if hasattr(self._task_backend, 'wait_for_pending_tasks'):
556
+ if hasattr(self._task_backend, "wait_for_pending_tasks"):
567
557
  await self._task_backend.wait_for_pending_tasks()
568
558
 
569
559
  def _format_readable_date(self, dt: datetime) -> str:
@@ -596,12 +586,12 @@ class MemoryEngine:
596
586
  self,
597
587
  conn,
598
588
  bank_id: str,
599
- texts: List[str],
600
- embeddings: List[List[float]],
589
+ texts: list[str],
590
+ embeddings: list[list[float]],
601
591
  event_date: datetime,
602
592
  time_window_hours: int = 24,
603
- similarity_threshold: float = 0.95
604
- ) -> List[bool]:
593
+ similarity_threshold: float = 0.95,
594
+ ) -> list[bool]:
605
595
  """
606
596
  Check which facts are duplicates using semantic similarity + temporal window.
607
597
 
@@ -635,6 +625,7 @@ class MemoryEngine:
635
625
 
636
626
  # Fetch ALL existing facts in time window ONCE (much faster than N queries)
637
627
  import time as time_mod
628
+
638
629
  fetch_start = time_mod.time()
639
630
  existing_facts = await conn.fetch(
640
631
  """
@@ -643,7 +634,9 @@ class MemoryEngine:
643
634
  WHERE bank_id = $1
644
635
  AND event_date BETWEEN $2 AND $3
645
636
  """,
646
- bank_id, time_lower, time_upper
637
+ bank_id,
638
+ time_lower,
639
+ time_upper,
647
640
  )
648
641
 
649
642
  # If no existing facts, nothing is duplicate
@@ -651,17 +644,17 @@ class MemoryEngine:
651
644
  return [False] * len(texts)
652
645
 
653
646
  # Compute similarities in Python (vectorized with numpy)
654
- import numpy as np
655
647
  is_duplicate = []
656
648
 
657
649
  # Convert existing embeddings to numpy for faster computation
658
650
  embedding_arrays = []
659
651
  for row in existing_facts:
660
- raw_emb = row['embedding']
652
+ raw_emb = row["embedding"]
661
653
  # Handle different pgvector formats
662
654
  if isinstance(raw_emb, str):
663
655
  # Parse string format: "[1.0, 2.0, ...]"
664
656
  import json
657
+
665
658
  emb = np.array(json.loads(raw_emb), dtype=np.float32)
666
659
  elif isinstance(raw_emb, (list, tuple)):
667
660
  emb = np.array(raw_emb, dtype=np.float32)
@@ -691,7 +684,6 @@ class MemoryEngine:
691
684
  max_similarity = np.max(similarities) if len(similarities) > 0 else 0
692
685
  is_duplicate.append(max_similarity > similarity_threshold)
693
686
 
694
-
695
687
  return is_duplicate
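The deduplication hunks above fetch recent facts for the bank, then compare embeddings with vectorized cosine similarity and flag anything above `similarity_threshold` (0.95 by default). A standalone numpy sketch of that check, with made-up vectors standing in for the pgvector rows:

```python
import numpy as np


def flag_duplicates(new_embs: np.ndarray, existing_embs: np.ndarray, threshold: float = 0.95) -> list[bool]:
    """Return one flag per new embedding: True if it is near-identical to an existing one."""
    if existing_embs.size == 0:
        return [False] * len(new_embs)

    def normalize(m: np.ndarray) -> np.ndarray:
        # Row-normalize so the dot product equals cosine similarity.
        return m / np.clip(np.linalg.norm(m, axis=1, keepdims=True), 1e-12, None)

    sims = normalize(new_embs) @ normalize(existing_embs).T  # shape: (new, existing)
    return (sims.max(axis=1) > threshold).tolist()


existing = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
new = np.array([[0.99, 0.01], [0.5, 0.5]], dtype=np.float32)
print(flag_duplicates(new, existing))  # [True, False]
```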
696
688
 
697
689
  def retain(
@@ -699,8 +691,8 @@ class MemoryEngine:
699
691
  bank_id: str,
700
692
  content: str,
701
693
  context: str = "",
702
- event_date: Optional[datetime] = None,
703
- ) -> List[str]:
694
+ event_date: datetime | None = None,
695
+ ) -> list[str]:
704
696
  """
705
697
  Store content as memory units (synchronous wrapper).
706
698
 
@@ -724,11 +716,11 @@ class MemoryEngine:
724
716
  bank_id: str,
725
717
  content: str,
726
718
  context: str = "",
727
- event_date: Optional[datetime] = None,
728
- document_id: Optional[str] = None,
729
- fact_type_override: Optional[str] = None,
730
- confidence_score: Optional[float] = None,
731
- ) -> List[str]:
719
+ event_date: datetime | None = None,
720
+ document_id: str | None = None,
721
+ fact_type_override: str | None = None,
722
+ confidence_score: float | None = None,
723
+ ) -> list[str]:
732
724
  """
733
725
  Store content as memory units with temporal and semantic links (ASYNC version).
734
726
 
@@ -747,11 +739,7 @@ class MemoryEngine:
747
739
  List of created unit IDs
748
740
  """
749
741
  # Build content dict
750
- content_dict: RetainContentDict = {
751
- "content": content,
752
- "context": context,
753
- "event_date": event_date
754
- }
742
+ content_dict: RetainContentDict = {"content": content, "context": context, "event_date": event_date}
755
743
  if document_id:
756
744
  content_dict["document_id"] = document_id
757
745
 
@@ -760,7 +748,7 @@ class MemoryEngine:
760
748
  bank_id=bank_id,
761
749
  contents=[content_dict],
762
750
  fact_type_override=fact_type_override,
763
- confidence_score=confidence_score
751
+ confidence_score=confidence_score,
764
752
  )
765
753
 
766
754
  # Return the first (and only) list of unit IDs
@@ -769,11 +757,11 @@ class MemoryEngine:
769
757
  async def retain_batch_async(
770
758
  self,
771
759
  bank_id: str,
772
- contents: List[RetainContentDict],
773
- document_id: Optional[str] = None,
774
- fact_type_override: Optional[str] = None,
775
- confidence_score: Optional[float] = None,
776
- ) -> List[List[str]]:
760
+ contents: list[RetainContentDict],
761
+ document_id: str | None = None,
762
+ fact_type_override: str | None = None,
763
+ confidence_score: float | None = None,
764
+ ) -> list[list[str]]:
777
765
  """
778
766
  Store multiple content items as memory units in ONE batch operation.
779
767
 
@@ -839,7 +827,9 @@ class MemoryEngine:
839
827
 
840
828
  if total_chars > CHARS_PER_BATCH:
841
829
  # Split into smaller batches based on character count
842
- logger.info(f"Large batch detected ({total_chars:,} chars from {len(contents)} items). Splitting into sub-batches of ~{CHARS_PER_BATCH:,} chars each...")
830
+ logger.info(
831
+ f"Large batch detected ({total_chars:,} chars from {len(contents)} items). Splitting into sub-batches of ~{CHARS_PER_BATCH:,} chars each..."
832
+ )
843
833
 
844
834
  sub_batches = []
845
835
  current_batch = []
@@ -868,7 +858,9 @@ class MemoryEngine:
868
858
  all_results = []
869
859
  for i, sub_batch in enumerate(sub_batches, 1):
870
860
  sub_batch_chars = sum(len(item.get("content", "")) for item in sub_batch)
871
- logger.info(f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars")
861
+ logger.info(
862
+ f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars"
863
+ )
872
864
 
873
865
  sub_results = await self._retain_batch_async_internal(
874
866
  bank_id=bank_id,
@@ -876,12 +868,14 @@ class MemoryEngine:
876
868
  document_id=document_id,
877
869
  is_first_batch=i == 1, # Only upsert on first batch
878
870
  fact_type_override=fact_type_override,
879
- confidence_score=confidence_score
871
+ confidence_score=confidence_score,
880
872
  )
881
873
  all_results.extend(sub_results)
882
874
 
883
875
  total_time = time.time() - start_time
884
- logger.info(f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s")
876
+ logger.info(
877
+ f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s"
878
+ )
885
879
  return all_results
886
880
 
887
881
  # Small batch - use internal method directly
@@ -891,18 +885,18 @@ class MemoryEngine:
891
885
  document_id=document_id,
892
886
  is_first_batch=True,
893
887
  fact_type_override=fact_type_override,
894
- confidence_score=confidence_score
888
+ confidence_score=confidence_score,
895
889
  )
896
890
 
897
891
  async def _retain_batch_async_internal(
898
892
  self,
899
893
  bank_id: str,
900
- contents: List[RetainContentDict],
901
- document_id: Optional[str] = None,
894
+ contents: list[RetainContentDict],
895
+ document_id: str | None = None,
902
896
  is_first_batch: bool = True,
903
- fact_type_override: Optional[str] = None,
904
- confidence_score: Optional[float] = None,
905
- ) -> List[List[str]]:
897
+ fact_type_override: str | None = None,
898
+ confidence_score: float | None = None,
899
+ ) -> list[list[str]]:
906
900
  """
907
901
  Internal method for batch processing without chunking logic.
908
902
 
@@ -938,7 +932,7 @@ class MemoryEngine:
938
932
  document_id=document_id,
939
933
  is_first_batch=is_first_batch,
940
934
  fact_type_override=fact_type_override,
941
- confidence_score=confidence_score
935
+ confidence_score=confidence_score,
942
936
  )
943
937
 
944
938
  def recall(
@@ -949,7 +943,7 @@ class MemoryEngine:
949
943
  budget: Budget = Budget.MID,
950
944
  max_tokens: int = 4096,
951
945
  enable_trace: bool = False,
952
- ) -> tuple[List[Dict[str, Any]], Optional[Any]]:
946
+ ) -> tuple[list[dict[str, Any]], Any | None]:
953
947
  """
954
948
  Recall memories using 4-way parallel retrieval (synchronous wrapper).
955
949
 
@@ -968,19 +962,17 @@ class MemoryEngine:
968
962
  Tuple of (results, trace)
969
963
  """
970
964
  # Run async version synchronously
971
- return asyncio.run(self.recall_async(
972
- bank_id, query, [fact_type], budget, max_tokens, enable_trace
973
- ))
965
+ return asyncio.run(self.recall_async(bank_id, query, [fact_type], budget, max_tokens, enable_trace))
974
966
 
975
967
  async def recall_async(
976
968
  self,
977
969
  bank_id: str,
978
970
  query: str,
979
- fact_type: List[str],
971
+ fact_type: list[str],
980
972
  budget: Budget = Budget.MID,
981
973
  max_tokens: int = 4096,
982
974
  enable_trace: bool = False,
983
- question_date: Optional[datetime] = None,
975
+ question_date: datetime | None = None,
984
976
  include_entities: bool = False,
985
977
  max_entity_tokens: int = 1024,
986
978
  include_chunks: bool = False,
@@ -1027,11 +1019,7 @@ class MemoryEngine:
1027
1019
  )
1028
1020
 
1029
1021
  # Map budget enum to thinking_budget number
1030
- budget_mapping = {
1031
- Budget.LOW: 100,
1032
- Budget.MID: 300,
1033
- Budget.HIGH: 1000
1034
- }
1022
+ budget_mapping = {Budget.LOW: 100, Budget.MID: 300, Budget.HIGH: 1000}
1035
1023
  thinking_budget = budget_mapping[budget]
1036
1024
 
1037
1025
  # Backpressure: limit concurrent recalls to prevent overwhelming the database
@@ -1041,20 +1029,29 @@ class MemoryEngine:
1041
1029
  for attempt in range(max_retries + 1):
1042
1030
  try:
1043
1031
  return await self._search_with_retries(
1044
- bank_id, query, fact_type, thinking_budget, max_tokens, enable_trace, question_date,
1045
- include_entities, max_entity_tokens, include_chunks, max_chunk_tokens
1032
+ bank_id,
1033
+ query,
1034
+ fact_type,
1035
+ thinking_budget,
1036
+ max_tokens,
1037
+ enable_trace,
1038
+ question_date,
1039
+ include_entities,
1040
+ max_entity_tokens,
1041
+ include_chunks,
1042
+ max_chunk_tokens,
1046
1043
  )
1047
1044
  except Exception as e:
1048
1045
  # Check if it's a connection error
1049
1046
  is_connection_error = (
1050
- isinstance(e, asyncpg.TooManyConnectionsError) or
1051
- isinstance(e, asyncpg.CannotConnectNowError) or
1052
- (isinstance(e, asyncpg.PostgresError) and 'connection' in str(e).lower())
1047
+ isinstance(e, asyncpg.TooManyConnectionsError)
1048
+ or isinstance(e, asyncpg.CannotConnectNowError)
1049
+ or (isinstance(e, asyncpg.PostgresError) and "connection" in str(e).lower())
1053
1050
  )
1054
1051
 
1055
1052
  if is_connection_error and attempt < max_retries:
1056
1053
  # Wait with exponential backoff before retry
1057
- wait_time = 0.5 * (2 ** attempt) # 0.5s, 1s, 2s
1054
+ wait_time = 0.5 * (2**attempt) # 0.5s, 1s, 2s
1058
1055
  logger.warning(
1059
1056
  f"Connection error on search attempt {attempt + 1}/{max_retries + 1}: {str(e)}. "
1060
1057
  f"Retrying in {wait_time:.1f}s..."
@@ -1069,11 +1066,11 @@ class MemoryEngine:
1069
1066
  self,
1070
1067
  bank_id: str,
1071
1068
  query: str,
1072
- fact_type: List[str],
1069
+ fact_type: list[str],
1073
1070
  thinking_budget: int,
1074
1071
  max_tokens: int,
1075
1072
  enable_trace: bool,
1076
- question_date: Optional[datetime] = None,
1073
+ question_date: datetime | None = None,
1077
1074
  include_entities: bool = False,
1078
1075
  max_entity_tokens: int = 500,
1079
1076
  include_chunks: bool = False,
@@ -1106,6 +1103,7 @@ class MemoryEngine:
1106
1103
  """
1107
1104
  # Initialize tracer if requested
1108
1105
  from .search.tracer import SearchTracer
1106
+
1109
1107
  tracer = SearchTracer(query, thinking_budget, max_tokens) if enable_trace else None
1110
1108
  if tracer:
1111
1109
  tracer.start()
@@ -1116,7 +1114,9 @@ class MemoryEngine:
1116
1114
  # Buffer logs for clean output in concurrent scenarios
1117
1115
  recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
1118
1116
  log_buffer = []
1119
- log_buffer.append(f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})")
1117
+ log_buffer.append(
1118
+ f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})"
1119
+ )
1120
1120
 
1121
1121
  try:
1122
1122
  # Step 1: Generate query embedding (for semantic search)
@@ -1141,8 +1141,7 @@ class MemoryEngine:
1141
1141
  # Run retrieval for each fact type in parallel
1142
1142
  retrieval_tasks = [
1143
1143
  retrieve_parallel(
1144
- pool, query, query_embedding_str, bank_id, ft, thinking_budget,
1145
- question_date, self.query_analyzer
1144
+ pool, query, query_embedding_str, bank_id, ft, thinking_budget, question_date, self.query_analyzer
1146
1145
  )
1147
1146
  for ft in fact_type
1148
1147
  ]
@@ -1156,22 +1155,24 @@ class MemoryEngine:
1156
1155
  aggregated_timings = {"semantic": 0.0, "bm25": 0.0, "graph": 0.0, "temporal": 0.0}
1157
1156
 
1158
1157
  detected_temporal_constraint = None
1159
- for idx, (ft_semantic, ft_bm25, ft_graph, ft_temporal, ft_timings, ft_temporal_constraint) in enumerate(all_retrievals):
1158
+ for idx, retrieval_result in enumerate(all_retrievals):
1160
1159
  # Log fact types in this retrieval batch
1161
1160
  ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
1162
- logger.debug(f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(ft_semantic)}, bm25={len(ft_bm25)}, graph={len(ft_graph)}, temporal={len(ft_temporal) if ft_temporal else 0}")
1161
+ logger.debug(
1162
+ f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
1163
+ )
1163
1164
 
1164
- semantic_results.extend(ft_semantic)
1165
- bm25_results.extend(ft_bm25)
1166
- graph_results.extend(ft_graph)
1167
- if ft_temporal:
1168
- temporal_results.extend(ft_temporal)
1165
+ semantic_results.extend(retrieval_result.semantic)
1166
+ bm25_results.extend(retrieval_result.bm25)
1167
+ graph_results.extend(retrieval_result.graph)
1168
+ if retrieval_result.temporal:
1169
+ temporal_results.extend(retrieval_result.temporal)
1169
1170
  # Track max timing for each method (since they run in parallel across fact types)
1170
- for method, duration in ft_timings.items():
1171
- aggregated_timings[method] = max(aggregated_timings[method], duration)
1171
+ for method, duration in retrieval_result.timings.items():
1172
+ aggregated_timings[method] = max(aggregated_timings.get(method, 0.0), duration)
1172
1173
  # Capture temporal constraint (same across all fact types)
1173
- if ft_temporal_constraint:
1174
- detected_temporal_constraint = ft_temporal_constraint
1174
+ if retrieval_result.temporal_constraint:
1175
+ detected_temporal_constraint = retrieval_result.temporal_constraint
1175
1176
 
1176
1177
  # If no temporal results from any fact type, set to None
1177
1178
  if not temporal_results:
@@ -1179,11 +1180,13 @@ class MemoryEngine:
1179
1180
 
1180
1181
  # Sort combined results by score (descending) so higher-scored results
1181
1182
  # get better ranks in the trace, regardless of fact type
1182
- semantic_results.sort(key=lambda r: r.similarity if hasattr(r, 'similarity') else 0, reverse=True)
1183
- bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r, 'bm25_score') else 0, reverse=True)
1184
- graph_results.sort(key=lambda r: r.activation if hasattr(r, 'activation') else 0, reverse=True)
1183
+ semantic_results.sort(key=lambda r: r.similarity if hasattr(r, "similarity") else 0, reverse=True)
1184
+ bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r, "bm25_score") else 0, reverse=True)
1185
+ graph_results.sort(key=lambda r: r.activation if hasattr(r, "activation") else 0, reverse=True)
1185
1186
  if temporal_results:
1186
- temporal_results.sort(key=lambda r: r.combined_score if hasattr(r, 'combined_score') else 0, reverse=True)
1187
+ temporal_results.sort(
1188
+ key=lambda r: r.combined_score if hasattr(r, "combined_score") else 0, reverse=True
1189
+ )
1187
1190
 
1188
1191
  retrieval_duration = time.time() - retrieval_start
1189
1192
 
@@ -1193,7 +1196,7 @@ class MemoryEngine:
1193
1196
  timing_parts = [
1194
1197
  f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
1195
1198
  f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
1196
- f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)"
1199
+ f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)",
1197
1200
  ]
1198
1201
  temporal_info = ""
1199
1202
  if detected_temporal_constraint:
@@ -1201,61 +1204,75 @@ class MemoryEngine:
1201
1204
  temporal_count = len(temporal_results) if temporal_results else 0
1202
1205
  timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
1203
1206
  temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
1204
- log_buffer.append(f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}")
1207
+ log_buffer.append(
1208
+ f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}"
1209
+ )
1205
1210
 
1206
- # Record retrieval results for tracer (convert typed results to old format)
1211
+ # Record retrieval results for tracer - per fact type
1207
1212
  if tracer:
1208
1213
  # Convert RetrievalResult to old tuple format for tracer
1209
1214
  def to_tuple_format(results):
1210
1215
  return [(r.id, r.__dict__) for r in results]
1211
1216
 
1212
- # Add semantic retrieval results
1213
- tracer.add_retrieval_results(
1214
- method_name="semantic",
1215
- results=to_tuple_format(semantic_results),
1216
- duration_seconds=aggregated_timings["semantic"],
1217
- score_field="similarity",
1218
- metadata={"limit": thinking_budget}
1219
- )
1217
+ # Add retrieval results per fact type (to show parallel execution in UI)
1218
+ for idx, rr in enumerate(all_retrievals):
1219
+ ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
1220
1220
 
1221
- # Add BM25 retrieval results
1222
- tracer.add_retrieval_results(
1223
- method_name="bm25",
1224
- results=to_tuple_format(bm25_results),
1225
- duration_seconds=aggregated_timings["bm25"],
1226
- score_field="bm25_score",
1227
- metadata={"limit": thinking_budget}
1228
- )
1221
+ # Add semantic retrieval results for this fact type
1222
+ tracer.add_retrieval_results(
1223
+ method_name="semantic",
1224
+ results=to_tuple_format(rr.semantic),
1225
+ duration_seconds=rr.timings.get("semantic", 0.0),
1226
+ score_field="similarity",
1227
+ metadata={"limit": thinking_budget},
1228
+ fact_type=ft_name,
1229
+ )
1229
1230
 
1230
- # Add graph retrieval results
1231
- tracer.add_retrieval_results(
1232
- method_name="graph",
1233
- results=to_tuple_format(graph_results),
1234
- duration_seconds=aggregated_timings["graph"],
1235
- score_field="similarity", # Graph uses similarity for activation
1236
- metadata={"budget": thinking_budget}
1237
- )
1231
+ # Add BM25 retrieval results for this fact type
1232
+ tracer.add_retrieval_results(
1233
+ method_name="bm25",
1234
+ results=to_tuple_format(rr.bm25),
1235
+ duration_seconds=rr.timings.get("bm25", 0.0),
1236
+ score_field="bm25_score",
1237
+ metadata={"limit": thinking_budget},
1238
+ fact_type=ft_name,
1239
+ )
1238
1240
 
1239
- # Add temporal retrieval results if present
1240
- if temporal_results:
1241
+ # Add graph retrieval results for this fact type
1241
1242
  tracer.add_retrieval_results(
1242
- method_name="temporal",
1243
- results=to_tuple_format(temporal_results),
1244
- duration_seconds=aggregated_timings["temporal"],
1245
- score_field="temporal_score",
1246
- metadata={"budget": thinking_budget}
1243
+ method_name="graph",
1244
+ results=to_tuple_format(rr.graph),
1245
+ duration_seconds=rr.timings.get("graph", 0.0),
1246
+ score_field="activation",
1247
+ metadata={"budget": thinking_budget},
1248
+ fact_type=ft_name,
1247
1249
  )
1248
1250
 
1251
+ # Add temporal retrieval results for this fact type (even if empty, to show it ran)
1252
+ if rr.temporal is not None:
1253
+ tracer.add_retrieval_results(
1254
+ method_name="temporal",
1255
+ results=to_tuple_format(rr.temporal),
1256
+ duration_seconds=rr.timings.get("temporal", 0.0),
1257
+ score_field="temporal_score",
1258
+ metadata={"budget": thinking_budget},
1259
+ fact_type=ft_name,
1260
+ )
1261
+
1249
1262
  # Record entry points (from semantic results) for legacy graph view
1250
1263
  for rank, retrieval in enumerate(semantic_results[:10], start=1): # Top 10 as entry points
1251
1264
  tracer.add_entry_point(retrieval.id, retrieval.text, retrieval.similarity or 0.0, rank)
1252
1265
 
1253
- tracer.add_phase_metric("parallel_retrieval", step_duration, {
1254
- "semantic_count": len(semantic_results),
1255
- "bm25_count": len(bm25_results),
1256
- "graph_count": len(graph_results),
1257
- "temporal_count": len(temporal_results) if temporal_results else 0
1258
- })
1266
+ tracer.add_phase_metric(
1267
+ "parallel_retrieval",
1268
+ step_duration,
1269
+ {
1270
+ "semantic_count": len(semantic_results),
1271
+ "bm25_count": len(bm25_results),
1272
+ "graph_count": len(graph_results),
1273
+ "temporal_count": len(temporal_results) if temporal_results else 0,
1274
+ },
1275
+ )
1259
1276
 
1260
1277
  # Step 3: Merge with RRF
1261
1278
  step_start = time.time()
@@ -1263,7 +1280,9 @@ class MemoryEngine:
1263
1280
 
1264
1281
  # Merge 3 or 4 result lists depending on temporal constraint
1265
1282
  if temporal_results:
1266
- merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results, temporal_results])
1283
+ merged_candidates = reciprocal_rank_fusion(
1284
+ [semantic_results, bm25_results, graph_results, temporal_results]
1285
+ )
1267
1286
  else:
1268
1287
  merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results])
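For reference, reciprocal rank fusion merges ranked lists by summing 1/(k + rank) for every list a result appears in; higher is better. A self-contained sketch (the package's `reciprocal_rank_fusion` in `engine/search/fusion.py` operates on richer candidate objects and may use a different k):

```python
def reciprocal_rank_fusion(ranked_lists: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
    # Each input list is ordered best-first; ids accumulate 1/(k + rank) from every list.
    scores: dict[str, float] = {}
    for ranked in ranked_lists:
        for rank, doc_id in enumerate(ranked, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


semantic = ["a", "b", "c"]
bm25 = ["b", "a", "d"]
graph = ["c", "b"]
print(reciprocal_rank_fusion([semantic, bm25, graph]))  # "b" wins: it ranks highly in all three lists
```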
1269
1288
 
@@ -1272,8 +1291,10 @@ class MemoryEngine:
1272
1291
 
1273
1292
  if tracer:
1274
1293
  # Convert MergedCandidate to old tuple format for tracer
1275
- tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
1276
- for mc in merged_candidates]
1294
+ tracer_merged = [
1295
+ (mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
1296
+ for mc in merged_candidates
1297
+ ]
1277
1298
  tracer.add_rrf_merged(tracer_merged)
1278
1299
  tracer.add_phase_metric("rrf_merge", step_duration, {"candidates_merged": len(merged_candidates)})
1279
1300
 
@@ -1287,44 +1308,38 @@ class MemoryEngine:
1287
1308
  step_duration = time.time() - step_start
1288
1309
  log_buffer.append(f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s")
1289
1310
 
1290
- if tracer:
1291
- # Convert to old format for tracer
1292
- results_dict = [sr.to_dict() for sr in scored_results]
1293
- tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
1294
- for mc in merged_candidates]
1295
- tracer.add_reranked(results_dict, tracer_merged)
1296
- tracer.add_phase_metric("reranking", step_duration, {
1297
- "reranker_type": "cross-encoder",
1298
- "candidates_reranked": len(scored_results)
1299
- })
1300
-
1301
1311
  # Step 4.5: Combine cross-encoder score with retrieval signals
1302
1312
  # This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
1303
1313
  if scored_results:
1304
- # Normalize RRF scores to [0, 1] range
1314
+ # Normalize RRF scores to [0, 1] range using min-max normalization
1305
1315
  rrf_scores = [sr.candidate.rrf_score for sr in scored_results]
1306
- max_rrf = max(rrf_scores) if rrf_scores else 1.0
1316
+ max_rrf = max(rrf_scores) if rrf_scores else 0.0
1307
1317
  min_rrf = min(rrf_scores) if rrf_scores else 0.0
1308
- rrf_range = max_rrf - min_rrf if max_rrf > min_rrf else 1.0
1318
+ rrf_range = max_rrf - min_rrf # Don't force to 1.0, let fallback handle it
1309
1319
 
1310
1320
  # Calculate recency based on occurred_start (more recent = higher score)
1311
1321
  now = utcnow()
1312
1322
  for sr in scored_results:
1313
- # Normalize RRF score
1314
- sr.rrf_normalized = (sr.candidate.rrf_score - min_rrf) / rrf_range if rrf_range > 0 else 0.5
1323
+ # Normalize RRF score (0-1 range, 0.5 if all same)
1324
+ if rrf_range > 0:
1325
+ sr.rrf_normalized = (sr.candidate.rrf_score - min_rrf) / rrf_range
1326
+ else:
1327
+ # All RRF scores are the same, use neutral value
1328
+ sr.rrf_normalized = 0.5
1315
1329
 
1316
1330
  # Calculate recency (decay over 365 days, minimum 0.1)
1317
1331
  sr.recency = 0.5 # default for missing dates
1318
1332
  if sr.retrieval.occurred_start:
1319
1333
  occurred = sr.retrieval.occurred_start
1320
- if hasattr(occurred, 'tzinfo') and occurred.tzinfo is None:
1321
- from datetime import timezone
1322
- occurred = occurred.replace(tzinfo=timezone.utc)
1334
+ if hasattr(occurred, "tzinfo") and occurred.tzinfo is None:
1335
+ occurred = occurred.replace(tzinfo=UTC)
1323
1336
  days_ago = (now - occurred).total_seconds() / 86400
1324
1337
  sr.recency = max(0.1, 1.0 - (days_ago / 365)) # Linear decay over 1 year
1325
1338
 
1326
1339
  # Get temporal proximity if available (already 0-1)
1327
- sr.temporal = sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5
1340
+ sr.temporal = (
1341
+ sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5
1342
+ )
1328
1343
 
1329
1344
  # Weighted combination
1330
1345
  # Cross-encoder: 60% (semantic relevance)
@@ -1332,16 +1347,32 @@ class MemoryEngine:
1332
1347
  # Temporal proximity: 10% (time relevance for temporal queries)
1333
1348
  # Recency: 10% (prefer recent facts)
1334
1349
  sr.combined_score = (
1335
- 0.6 * sr.cross_encoder_score_normalized +
1336
- 0.2 * sr.rrf_normalized +
1337
- 0.1 * sr.temporal +
1338
- 0.1 * sr.recency
1350
+ 0.6 * sr.cross_encoder_score_normalized
1351
+ + 0.2 * sr.rrf_normalized
1352
+ + 0.1 * sr.temporal
1353
+ + 0.1 * sr.recency
1339
1354
  )
1340
1355
  sr.weight = sr.combined_score # Update weight for final ranking
1341
1356
 
1342
1357
  # Re-sort by combined score
1343
1358
  scored_results.sort(key=lambda x: x.weight, reverse=True)
1344
- log_buffer.append(f" [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)")
1359
+ log_buffer.append(
1360
+ " [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)"
1361
+ )
1362
+
1363
+ # Add reranked results to tracer AFTER combined scoring (so normalized values are included)
1364
+ if tracer:
1365
+ results_dict = [sr.to_dict() for sr in scored_results]
1366
+ tracer_merged = [
1367
+ (mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
1368
+ for mc in merged_candidates
1369
+ ]
1370
+ tracer.add_reranked(results_dict, tracer_merged)
1371
+ tracer.add_phase_metric(
1372
+ "reranking",
1373
+ step_duration,
1374
+ {"reranker_type": "cross-encoder", "candidates_reranked": len(scored_results)},
1375
+ )
1345
1376
 
1346
1377
  # Step 5: Truncate to thinking_budget * 2 for token filtering
1347
1378
  rerank_limit = thinking_budget * 2
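The weighted combination in the hunk above blends four signals that have each been normalized to [0, 1]: 60% cross-encoder relevance, 20% RRF, 10% temporal proximity, 10% recency. A worked example with illustrative numbers:

```python
# Illustrative, pre-normalized scores (not taken from the package).
cross_encoder, rrf, temporal, recency = 0.8, 0.5, 0.5, 0.9
combined = 0.6 * cross_encoder + 0.2 * rrf + 0.1 * temporal + 0.1 * recency
print(round(combined, 2))  # 0.48 + 0.10 + 0.05 + 0.09 = 0.72
```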
@@ -1360,14 +1391,16 @@ class MemoryEngine:
1360
1391
  top_scored = [sr for sr in top_scored if sr.id in filtered_ids]
1361
1392
 
1362
1393
  step_duration = time.time() - step_start
1363
- log_buffer.append(f" [6] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s")
1394
+ log_buffer.append(
1395
+ f" [6] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s"
1396
+ )
1364
1397
 
1365
1398
  if tracer:
1366
- tracer.add_phase_metric("token_filtering", step_duration, {
1367
- "results_selected": len(top_scored),
1368
- "tokens_used": total_tokens,
1369
- "max_tokens": max_tokens
1370
- })
1399
+ tracer.add_phase_metric(
1400
+ "token_filtering",
1401
+ step_duration,
1402
+ {"results_selected": len(top_scored), "tokens_used": total_tokens, "max_tokens": max_tokens},
1403
+ )
1371
1404
 
1372
1405
  # Record visits for all retrieved nodes
1373
1406
  if tracer:
@@ -1386,16 +1419,13 @@ class MemoryEngine:
1386
1419
  semantic_similarity=sr.retrieval.similarity or 0.0,
1387
1420
  recency=sr.recency,
1388
1421
  frequency=0.0,
1389
- final_weight=sr.weight
1422
+ final_weight=sr.weight,
1390
1423
  )
1391
1424
 
1392
1425
  # Step 8: Queue access count updates for visited nodes
1393
1426
  visited_ids = list(set([sr.id for sr in scored_results[:50]])) # Top 50
1394
1427
  if visited_ids:
1395
- await self._task_backend.submit_task({
1396
- 'type': 'access_count_update',
1397
- 'node_ids': visited_ids
1398
- })
1428
+ await self._task_backend.submit_task({"type": "access_count_update", "node_ids": visited_ids})
1399
1429
  log_buffer.append(f" [7] Queued access count updates for {len(visited_ids)} nodes")
1400
1430
 
1401
1431
  # Log fact_type distribution in results
@@ -1413,13 +1443,19 @@ class MemoryEngine:
1413
1443
  # Convert datetime objects to ISO strings for JSON serialization
1414
1444
  if result_dict.get("occurred_start"):
1415
1445
  occurred_start = result_dict["occurred_start"]
1416
- result_dict["occurred_start"] = occurred_start.isoformat() if hasattr(occurred_start, 'isoformat') else occurred_start
1446
+ result_dict["occurred_start"] = (
1447
+ occurred_start.isoformat() if hasattr(occurred_start, "isoformat") else occurred_start
1448
+ )
1417
1449
  if result_dict.get("occurred_end"):
1418
1450
  occurred_end = result_dict["occurred_end"]
1419
- result_dict["occurred_end"] = occurred_end.isoformat() if hasattr(occurred_end, 'isoformat') else occurred_end
1451
+ result_dict["occurred_end"] = (
1452
+ occurred_end.isoformat() if hasattr(occurred_end, "isoformat") else occurred_end
1453
+ )
1420
1454
  if result_dict.get("mentioned_at"):
1421
1455
  mentioned_at = result_dict["mentioned_at"]
1422
- result_dict["mentioned_at"] = mentioned_at.isoformat() if hasattr(mentioned_at, 'isoformat') else mentioned_at
1456
+ result_dict["mentioned_at"] = (
1457
+ mentioned_at.isoformat() if hasattr(mentioned_at, "isoformat") else mentioned_at
1458
+ )
1423
1459
  top_results_dicts.append(result_dict)
1424
1460
 
1425
1461
  # Get entities for each fact if include_entities is requested
@@ -1435,16 +1471,15 @@ class MemoryEngine:
1435
1471
  JOIN entities e ON ue.entity_id = e.id
1436
1472
  WHERE ue.unit_id = ANY($1::uuid[])
1437
1473
  """,
1438
- unit_ids
1474
+ unit_ids,
1439
1475
  )
1440
1476
  for row in entity_rows:
1441
- unit_id = str(row['unit_id'])
1477
+ unit_id = str(row["unit_id"])
1442
1478
  if unit_id not in fact_entity_map:
1443
1479
  fact_entity_map[unit_id] = []
1444
- fact_entity_map[unit_id].append({
1445
- 'entity_id': str(row['entity_id']),
1446
- 'canonical_name': row['canonical_name']
1447
- })
1480
+ fact_entity_map[unit_id].append(
1481
+ {"entity_id": str(row["entity_id"]), "canonical_name": row["canonical_name"]}
1482
+ )
1448
1483
 
1449
1484
  # Convert results to MemoryFact objects
1450
1485
  memory_facts = []
@@ -1453,20 +1488,22 @@ class MemoryEngine:
1453
1488
  # Get entity names for this fact
1454
1489
  entity_names = None
1455
1490
  if include_entities and result_id in fact_entity_map:
1456
- entity_names = [e['canonical_name'] for e in fact_entity_map[result_id]]
1457
-
1458
- memory_facts.append(MemoryFact(
1459
- id=result_id,
1460
- text=result_dict.get("text"),
1461
- fact_type=result_dict.get("fact_type", "world"),
1462
- entities=entity_names,
1463
- context=result_dict.get("context"),
1464
- occurred_start=result_dict.get("occurred_start"),
1465
- occurred_end=result_dict.get("occurred_end"),
1466
- mentioned_at=result_dict.get("mentioned_at"),
1467
- document_id=result_dict.get("document_id"),
1468
- chunk_id=result_dict.get("chunk_id"),
1469
- ))
1491
+ entity_names = [e["canonical_name"] for e in fact_entity_map[result_id]]
1492
+
1493
+ memory_facts.append(
1494
+ MemoryFact(
1495
+ id=result_id,
1496
+ text=result_dict.get("text"),
1497
+ fact_type=result_dict.get("fact_type", "world"),
1498
+ entities=entity_names,
1499
+ context=result_dict.get("context"),
1500
+ occurred_start=result_dict.get("occurred_start"),
1501
+ occurred_end=result_dict.get("occurred_end"),
1502
+ mentioned_at=result_dict.get("mentioned_at"),
1503
+ document_id=result_dict.get("document_id"),
1504
+ chunk_id=result_dict.get("chunk_id"),
1505
+ )
1506
+ )
1470
1507
 
1471
1508
  # Fetch entity observations if requested
1472
1509
  entities_dict = None
@@ -1483,8 +1520,8 @@ class MemoryEngine:
1483
1520
  unit_id = sr.id
1484
1521
  if unit_id in fact_entity_map:
1485
1522
  for entity in fact_entity_map[unit_id]:
1486
- entity_id = entity['entity_id']
1487
- entity_name = entity['canonical_name']
1523
+ entity_id = entity["entity_id"]
1524
+ entity_name = entity["canonical_name"]
1488
1525
  if entity_id not in seen_entity_ids:
1489
1526
  entities_ordered.append((entity_id, entity_name))
1490
1527
  seen_entity_ids.add(entity_id)
@@ -1512,9 +1549,7 @@ class MemoryEngine:
1512
1549
 
1513
1550
  if included_observations:
1514
1551
  entities_dict[entity_name] = EntityState(
1515
- entity_id=entity_id,
1516
- canonical_name=entity_name,
1517
- observations=included_observations
1552
+ entity_id=entity_id, canonical_name=entity_name, observations=included_observations
1518
1553
  )
1519
1554
  total_entity_tokens += entity_tokens
1520
1555
 
@@ -1542,11 +1577,11 @@ class MemoryEngine:
1542
1577
  FROM chunks
1543
1578
  WHERE chunk_id = ANY($1::text[])
1544
1579
  """,
1545
- chunk_ids_ordered
1580
+ chunk_ids_ordered,
1546
1581
  )
1547
1582
 
1548
1583
  # Create a lookup dict for fast access
1549
- chunks_lookup = {row['chunk_id']: row for row in chunks_rows}
1584
+ chunks_lookup = {row["chunk_id"]: row for row in chunks_rows}
1550
1585
 
1551
1586
  # Apply token limit and build chunks_dict in the order of chunk_ids_ordered
1552
1587
  chunks_dict = {}
@@ -1557,7 +1592,7 @@ class MemoryEngine:
1557
1592
  continue
1558
1593
 
1559
1594
  row = chunks_lookup[chunk_id]
1560
- chunk_text = row['chunk_text']
1595
+ chunk_text = row["chunk_text"]
1561
1596
  chunk_tokens = len(encoding.encode(chunk_text))
1562
1597
 
1563
1598
  # Check if adding this chunk would exceed the limit
@@ -1568,18 +1603,14 @@ class MemoryEngine:
1568
1603
  # Truncate to remaining tokens
1569
1604
  truncated_text = encoding.decode(encoding.encode(chunk_text)[:remaining_tokens])
1570
1605
  chunks_dict[chunk_id] = ChunkInfo(
1571
- chunk_text=truncated_text,
1572
- chunk_index=row['chunk_index'],
1573
- truncated=True
1606
+ chunk_text=truncated_text, chunk_index=row["chunk_index"], truncated=True
1574
1607
  )
1575
1608
  total_chunk_tokens = max_chunk_tokens
1576
1609
  # Stop adding more chunks once we hit the limit
1577
1610
  break
1578
1611
  else:
1579
1612
  chunks_dict[chunk_id] = ChunkInfo(
1580
- chunk_text=chunk_text,
1581
- chunk_index=row['chunk_index'],
1582
- truncated=False
1613
+ chunk_text=chunk_text, chunk_index=row["chunk_index"], truncated=False
1583
1614
  )
1584
1615
  total_chunk_tokens += chunk_tokens
1585
1616
 
@@ -1593,7 +1624,9 @@ class MemoryEngine:
1593
1624
  total_time = time.time() - recall_start
1594
1625
  num_chunks = len(chunks_dict) if chunks_dict else 0
1595
1626
  num_entities = len(entities_dict) if entities_dict else 0
1596
- log_buffer.append(f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s")
1627
+ log_buffer.append(
1628
+ f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s"
1629
+ )
1597
1630
  logger.info("\n" + "\n".join(log_buffer))
1598
1631
 
1599
1632
  return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)
@@ -1604,10 +1637,8 @@ class MemoryEngine:
1604
1637
  raise Exception(f"Failed to search memories: {str(e)}")
1605
1638
 
1606
1639
  def _filter_by_token_budget(
1607
- self,
1608
- results: List[Dict[str, Any]],
1609
- max_tokens: int
1610
- ) -> Tuple[List[Dict[str, Any]], int]:
1640
+ self, results: list[dict[str, Any]], max_tokens: int
1641
+ ) -> tuple[list[dict[str, Any]], int]:
1611
1642
  """
1612
1643
  Filter results to fit within token budget.
1613
1644
 
@@ -1640,7 +1671,7 @@ class MemoryEngine:
1640
1671
 
1641
1672
  return filtered_results, total_tokens

- async def get_document(self, document_id: str, bank_id: str) -> Optional[Dict[str, Any]]:
+ async def get_document(self, document_id: str, bank_id: str) -> dict[str, Any] | None:
  """
  Retrieve document metadata and statistics.

@@ -1662,7 +1693,8 @@ class MemoryEngine:
  WHERE d.id = $1 AND d.bank_id = $2
  GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
  """,
- document_id, bank_id
+ document_id,
+ bank_id,
  )

  if not doc:
@@ -1675,10 +1707,10 @@ class MemoryEngine:
  "content_hash": doc["content_hash"],
  "memory_unit_count": doc["unit_count"],
  "created_at": doc["created_at"],
- "updated_at": doc["updated_at"]
+ "updated_at": doc["updated_at"],
  }

- async def delete_document(self, document_id: str, bank_id: str) -> Dict[str, int]:
+ async def delete_document(self, document_id: str, bank_id: str) -> dict[str, int]:
  """
  Delete a document and all its associated memory units and links.

@@ -1694,22 +1726,17 @@ class MemoryEngine:
  async with conn.transaction():
  # Count units before deletion
  units_count = await conn.fetchval(
- "SELECT COUNT(*) FROM memory_units WHERE document_id = $1",
- document_id
+ "SELECT COUNT(*) FROM memory_units WHERE document_id = $1", document_id
  )

  # Delete document (cascades to memory_units and all their links)
  deleted = await conn.fetchval(
- "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id",
- document_id, bank_id
+ "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id", document_id, bank_id
  )

- return {
- "document_deleted": 1 if deleted else 0,
- "memory_units_deleted": units_count if deleted else 0
- }
+ return {"document_deleted": 1 if deleted else 0, "memory_units_deleted": units_count if deleted else 0}

- async def delete_memory_unit(self, unit_id: str) -> Dict[str, Any]:
+ async def delete_memory_unit(self, unit_id: str) -> dict[str, Any]:
  """
  Delete a single memory unit and all its associated links.

@@ -1728,18 +1755,17 @@ class MemoryEngine:
  async with acquire_with_retry(pool) as conn:
  async with conn.transaction():
  # Delete the memory unit (cascades to links and associations)
- deleted = await conn.fetchval(
- "DELETE FROM memory_units WHERE id = $1 RETURNING id",
- unit_id
- )
+ deleted = await conn.fetchval("DELETE FROM memory_units WHERE id = $1 RETURNING id", unit_id)

  return {
  "success": deleted is not None,
  "unit_id": str(deleted) if deleted else None,
- "message": "Memory unit and all its links deleted successfully" if deleted else "Memory unit not found"
+ "message": "Memory unit and all its links deleted successfully"
+ if deleted
+ else "Memory unit not found",
  }

- async def delete_bank(self, bank_id: str, fact_type: Optional[str] = None) -> Dict[str, int]:
+ async def delete_bank(self, bank_id: str, fact_type: str | None = None) -> dict[str, int]:
  """
  Delete all data for a specific agent (multi-tenant cleanup).

@@ -1768,24 +1794,27 @@ class MemoryEngine:
  # Delete only memories of a specific fact type
  units_count = await conn.fetchval(
  "SELECT COUNT(*) FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
- bank_id, fact_type
+ bank_id,
+ fact_type,
  )
  await conn.execute(
- "DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
- bank_id, fact_type
+ "DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2", bank_id, fact_type
  )

  # Note: We don't delete entities when fact_type is specified,
  # as they may be referenced by other memory units
- return {
- "memory_units_deleted": units_count,
- "entities_deleted": 0
- }
+ return {"memory_units_deleted": units_count, "entities_deleted": 0}
  else:
  # Delete all data for the bank
- units_count = await conn.fetchval("SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id)
- entities_count = await conn.fetchval("SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id)
- documents_count = await conn.fetchval("SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id)
+ units_count = await conn.fetchval(
+ "SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id
+ )
+ entities_count = await conn.fetchval(
+ "SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id
+ )
+ documents_count = await conn.fetchval(
+ "SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id
+ )

  # Delete documents (cascades to chunks)
  await conn.execute("DELETE FROM documents WHERE bank_id = $1", bank_id)
@@ -1803,13 +1832,13 @@ class MemoryEngine:
  "memory_units_deleted": units_count,
  "entities_deleted": entities_count,
  "documents_deleted": documents_count,
- "bank_deleted": True
+ "bank_deleted": True,
  }

  except Exception as e:
  raise Exception(f"Failed to delete agent data: {str(e)}")
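
Aside (not part of the diff): the deletion helpers above keep their signatures across this release; only formatting and type hints change. A minimal usage sketch against those signatures, assuming an already-initialized MemoryEngine instance (construction is not shown in these hunks):

```python
# Illustrative sketch only; `engine` is assumed to be an initialized MemoryEngine.
async def purge_document(engine, document_id: str, bank_id: str) -> dict:
    """Delete one document and report how many memory units went with it."""
    doc = await engine.get_document(document_id, bank_id)
    if doc is None:
        return {"document_deleted": 0, "memory_units_deleted": 0}
    # Deletion cascades to the document's memory units and their links.
    return await engine.delete_document(document_id, bank_id)


async def purge_opinions(engine, bank_id: str) -> dict:
    """Remove only 'opinion' facts for a bank; entities are left in place."""
    return await engine.delete_bank(bank_id, fact_type="opinion")
```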

- async def get_graph_data(self, bank_id: Optional[str] = None, fact_type: Optional[str] = None):
+ async def get_graph_data(self, bank_id: str | None = None, fact_type: str | None = None):
  """
  Get graph data for visualization.

@@ -1839,19 +1868,23 @@ class MemoryEngine:
1839
1868
 
1840
1869
  where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
1841
1870
 
1842
- units = await conn.fetch(f"""
1871
+ units = await conn.fetch(
1872
+ f"""
1843
1873
  SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
1844
1874
  FROM memory_units
1845
1875
  {where_clause}
1846
1876
  ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
1847
1877
  LIMIT 1000
1848
- """, *query_params)
1878
+ """,
1879
+ *query_params,
1880
+ )
1849
1881
 
1850
1882
  # Get links, filtering to only include links between units of the selected agent
1851
1883
  # Use DISTINCT ON with LEAST/GREATEST to deduplicate bidirectional links
1852
- unit_ids = [row['id'] for row in units]
1884
+ unit_ids = [row["id"] for row in units]
1853
1885
  if unit_ids:
1854
- links = await conn.fetch("""
1886
+ links = await conn.fetch(
1887
+ """
1855
1888
  SELECT DISTINCT ON (LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid))
1856
1889
  ml.from_unit_id,
1857
1890
  ml.to_unit_id,
@@ -1862,7 +1895,9 @@ class MemoryEngine:
1862
1895
  LEFT JOIN entities e ON ml.entity_id = e.id
1863
1896
  WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
1864
1897
  ORDER BY LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid), ml.weight DESC
1865
- """, unit_ids)
1898
+ """,
1899
+ unit_ids,
1900
+ )
1866
1901
  else:
1867
1902
  links = []
1868
1903
 
@@ -1877,8 +1912,8 @@ class MemoryEngine:
1877
1912
  # Build entity mapping
1878
1913
  entity_map = {}
1879
1914
  for row in unit_entities:
1880
- unit_id = row['unit_id']
1881
- entity_name = row['canonical_name']
1915
+ unit_id = row["unit_id"]
1916
+ entity_name = row["canonical_name"]
1882
1917
  if unit_id not in entity_map:
1883
1918
  entity_map[unit_id] = []
1884
1919
  entity_map[unit_id].append(entity_name)
@@ -1886,10 +1921,10 @@ class MemoryEngine:
1886
1921
  # Build nodes
1887
1922
  nodes = []
1888
1923
  for row in units:
1889
- unit_id = row['id']
1890
- text = row['text']
1891
- event_date = row['event_date']
1892
- context = row['context']
1924
+ unit_id = row["id"]
1925
+ text = row["text"]
1926
+ event_date = row["event_date"]
1927
+ context = row["context"]
1893
1928
 
1894
1929
  entities = entity_map.get(unit_id, [])
1895
1930
  entity_count = len(entities)
@@ -1902,88 +1937,91 @@ class MemoryEngine:
1902
1937
  else:
1903
1938
  color = "#42a5f5"
1904
1939
 
1905
- nodes.append({
1906
- "data": {
1907
- "id": str(unit_id),
1908
- "label": f"{text[:30]}..." if len(text) > 30 else text,
1909
- "text": text,
1910
- "date": event_date.isoformat() if event_date else "",
1911
- "context": context if context else "",
1912
- "entities": ", ".join(entities) if entities else "None",
1913
- "color": color
1940
+ nodes.append(
1941
+ {
1942
+ "data": {
1943
+ "id": str(unit_id),
1944
+ "label": f"{text[:30]}..." if len(text) > 30 else text,
1945
+ "text": text,
1946
+ "date": event_date.isoformat() if event_date else "",
1947
+ "context": context if context else "",
1948
+ "entities": ", ".join(entities) if entities else "None",
1949
+ "color": color,
1950
+ }
1914
1951
  }
1915
- })
1952
+ )
1916
1953
 
1917
1954
  # Build edges
1918
1955
  edges = []
1919
1956
  for row in links:
1920
- from_id = str(row['from_unit_id'])
1921
- to_id = str(row['to_unit_id'])
1922
- link_type = row['link_type']
1923
- weight = row['weight']
1924
- entity_name = row['entity_name']
1957
+ from_id = str(row["from_unit_id"])
1958
+ to_id = str(row["to_unit_id"])
1959
+ link_type = row["link_type"]
1960
+ weight = row["weight"]
1961
+ entity_name = row["entity_name"]
1925
1962
 
1926
1963
  # Color by link type
1927
- if link_type == 'temporal':
1964
+ if link_type == "temporal":
1928
1965
  color = "#00bcd4"
1929
1966
  line_style = "dashed"
1930
- elif link_type == 'semantic':
1967
+ elif link_type == "semantic":
1931
1968
  color = "#ff69b4"
1932
1969
  line_style = "solid"
1933
- elif link_type == 'entity':
1970
+ elif link_type == "entity":
1934
1971
  color = "#ffd700"
1935
1972
  line_style = "solid"
1936
1973
  else:
1937
1974
  color = "#999999"
1938
1975
  line_style = "solid"
1939
1976
 
1940
- edges.append({
1941
- "data": {
1942
- "id": f"{from_id}-{to_id}-{link_type}",
1943
- "source": from_id,
1944
- "target": to_id,
1945
- "linkType": link_type,
1946
- "weight": weight,
1947
- "entityName": entity_name if entity_name else "",
1948
- "color": color,
1949
- "lineStyle": line_style
1977
+ edges.append(
1978
+ {
1979
+ "data": {
1980
+ "id": f"{from_id}-{to_id}-{link_type}",
1981
+ "source": from_id,
1982
+ "target": to_id,
1983
+ "linkType": link_type,
1984
+ "weight": weight,
1985
+ "entityName": entity_name if entity_name else "",
1986
+ "color": color,
1987
+ "lineStyle": line_style,
1988
+ }
1950
1989
  }
1951
- })
1990
+ )
1952
1991
 
1953
1992
  # Build table rows
1954
1993
  table_rows = []
1955
1994
  for row in units:
1956
- unit_id = row['id']
1995
+ unit_id = row["id"]
1957
1996
  entities = entity_map.get(unit_id, [])
1958
1997
 
1959
- table_rows.append({
1960
- "id": str(unit_id),
1961
- "text": row['text'],
1962
- "context": row['context'] if row['context'] else "N/A",
1963
- "occurred_start": row['occurred_start'].isoformat() if row['occurred_start'] else None,
1964
- "occurred_end": row['occurred_end'].isoformat() if row['occurred_end'] else None,
1965
- "mentioned_at": row['mentioned_at'].isoformat() if row['mentioned_at'] else None,
1966
- "date": row['event_date'].strftime("%Y-%m-%d %H:%M") if row['event_date'] else "N/A", # Deprecated, kept for backwards compatibility
1967
- "entities": ", ".join(entities) if entities else "None",
1968
- "document_id": row['document_id'],
1969
- "chunk_id": row['chunk_id'] if row['chunk_id'] else None,
1970
- "fact_type": row['fact_type']
1971
- })
1972
-
1973
- return {
1974
- "nodes": nodes,
1975
- "edges": edges,
1976
- "table_rows": table_rows,
1977
- "total_units": len(units)
1978
- }
1998
+ table_rows.append(
1999
+ {
2000
+ "id": str(unit_id),
2001
+ "text": row["text"],
2002
+ "context": row["context"] if row["context"] else "N/A",
2003
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
2004
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
2005
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
2006
+ "date": row["event_date"].strftime("%Y-%m-%d %H:%M")
2007
+ if row["event_date"]
2008
+ else "N/A", # Deprecated, kept for backwards compatibility
2009
+ "entities": ", ".join(entities) if entities else "None",
2010
+ "document_id": row["document_id"],
2011
+ "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
2012
+ "fact_type": row["fact_type"],
2013
+ }
2014
+ )
2015
+
2016
+ return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": len(units)}
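
Aside (not part of the diff): get_graph_data() now builds its return value as a single-line dict, but the payload shape is unchanged: node and edge elements wrapped in a "data" key (a Cytoscape-style element format), plus table_rows and total_units. A small sketch of how a caller might summarize that payload (illustrative, not from the package):

```python
from collections import Counter


def summarize_graph(graph: dict) -> dict:
    """Summarize the payload returned by MemoryEngine.get_graph_data()."""
    # Each edge looks like {"data": {"linkType": "temporal" | "semantic" | "entity" | ..., ...}}.
    link_counts = Counter(edge["data"]["linkType"] for edge in graph["edges"])
    return {
        "total_units": graph["total_units"],
        "node_count": len(graph["nodes"]),
        "edges_by_link_type": dict(link_counts),
    }
```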
1979
2017
 
1980
2018
  async def list_memory_units(
1981
2019
  self,
1982
- bank_id: Optional[str] = None,
1983
- fact_type: Optional[str] = None,
1984
- search_query: Optional[str] = None,
2020
+ bank_id: str | None = None,
2021
+ fact_type: str | None = None,
2022
+ search_query: str | None = None,
1985
2023
  limit: int = 100,
1986
- offset: int = 0
2024
+ offset: int = 0,
1987
2025
  ):
1988
2026
  """
1989
2027
  List memory units for table view with optional full-text search.
@@ -2030,7 +2068,7 @@ class MemoryEngine:
2030
2068
  {where_clause}
2031
2069
  """
2032
2070
  count_result = await conn.fetchrow(count_query, *query_params)
2033
- total = count_result['total']
2071
+ total = count_result["total"]
2034
2072
 
2035
2073
  # Get units with limit and offset
2036
2074
  param_count += 1
@@ -2041,32 +2079,38 @@ class MemoryEngine:
2041
2079
  offset_param = f"${param_count}"
2042
2080
  query_params.append(offset)
2043
2081
 
2044
- units = await conn.fetch(f"""
2082
+ units = await conn.fetch(
2083
+ f"""
2045
2084
  SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id
2046
2085
  FROM memory_units
2047
2086
  {where_clause}
2048
2087
  ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
2049
2088
  LIMIT {limit_param} OFFSET {offset_param}
2050
- """, *query_params)
2089
+ """,
2090
+ *query_params,
2091
+ )
2051
2092
 
2052
2093
  # Get entity information for these units
2053
2094
  if units:
2054
- unit_ids = [row['id'] for row in units]
2055
- unit_entities = await conn.fetch("""
2095
+ unit_ids = [row["id"] for row in units]
2096
+ unit_entities = await conn.fetch(
2097
+ """
2056
2098
  SELECT ue.unit_id, e.canonical_name
2057
2099
  FROM unit_entities ue
2058
2100
  JOIN entities e ON ue.entity_id = e.id
2059
2101
  WHERE ue.unit_id = ANY($1::uuid[])
2060
2102
  ORDER BY ue.unit_id
2061
- """, unit_ids)
2103
+ """,
2104
+ unit_ids,
2105
+ )
2062
2106
  else:
2063
2107
  unit_entities = []
2064
2108
 
2065
2109
  # Build entity mapping
2066
2110
  entity_map = {}
2067
2111
  for row in unit_entities:
2068
- unit_id = row['unit_id']
2069
- entity_name = row['canonical_name']
2112
+ unit_id = row["unit_id"]
2113
+ entity_name = row["canonical_name"]
2070
2114
  if unit_id not in entity_map:
2071
2115
  entity_map[unit_id] = []
2072
2116
  entity_map[unit_id].append(entity_name)
@@ -2074,36 +2118,27 @@ class MemoryEngine:
2074
2118
  # Build result items
2075
2119
  items = []
2076
2120
  for row in units:
2077
- unit_id = row['id']
2121
+ unit_id = row["id"]
2078
2122
  entities = entity_map.get(unit_id, [])
2079
2123
 
2080
- items.append({
2081
- "id": str(unit_id),
2082
- "text": row['text'],
2083
- "context": row['context'] if row['context'] else "",
2084
- "date": row['event_date'].isoformat() if row['event_date'] else "",
2085
- "fact_type": row['fact_type'],
2086
- "mentioned_at": row['mentioned_at'].isoformat() if row['mentioned_at'] else None,
2087
- "occurred_start": row['occurred_start'].isoformat() if row['occurred_start'] else None,
2088
- "occurred_end": row['occurred_end'].isoformat() if row['occurred_end'] else None,
2089
- "entities": ", ".join(entities) if entities else "",
2090
- "chunk_id": row['chunk_id'] if row['chunk_id'] else None
2091
- })
2124
+ items.append(
2125
+ {
2126
+ "id": str(unit_id),
2127
+ "text": row["text"],
2128
+ "context": row["context"] if row["context"] else "",
2129
+ "date": row["event_date"].isoformat() if row["event_date"] else "",
2130
+ "fact_type": row["fact_type"],
2131
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
2132
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
2133
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
2134
+ "entities": ", ".join(entities) if entities else "",
2135
+ "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
2136
+ }
2137
+ )
2092
2138
 
2093
- return {
2094
- "items": items,
2095
- "total": total,
2096
- "limit": limit,
2097
- "offset": offset
2098
- }
2139
+ return {"items": items, "total": total, "limit": limit, "offset": offset}
2099
2140
 
2100
- async def list_documents(
2101
- self,
2102
- bank_id: str,
2103
- search_query: Optional[str] = None,
2104
- limit: int = 100,
2105
- offset: int = 0
2106
- ):
2141
+ async def list_documents(self, bank_id: str, search_query: str | None = None, limit: int = 100, offset: int = 0):
2107
2142
  """
2108
2143
  List documents with optional search and pagination.
2109
2144
 
@@ -2142,7 +2177,7 @@ class MemoryEngine:
2142
2177
  {where_clause}
2143
2178
  """
2144
2179
  count_result = await conn.fetchrow(count_query, *query_params)
2145
- total = count_result['total']
2180
+ total = count_result["total"]
2146
2181
 
2147
2182
  # Get documents with limit and offset (without original_text for performance)
2148
2183
  param_count += 1
@@ -2153,7 +2188,8 @@ class MemoryEngine:
2153
2188
  offset_param = f"${param_count}"
2154
2189
  query_params.append(offset)
2155
2190
 
2156
- documents = await conn.fetch(f"""
2191
+ documents = await conn.fetch(
2192
+ f"""
2157
2193
  SELECT
2158
2194
  id,
2159
2195
  bank_id,
@@ -2166,11 +2202,13 @@ class MemoryEngine:
2166
2202
  {where_clause}
2167
2203
  ORDER BY created_at DESC
2168
2204
  LIMIT {limit_param} OFFSET {offset_param}
2169
- """, *query_params)
2205
+ """,
2206
+ *query_params,
2207
+ )
2170
2208
 
2171
2209
  # Get memory unit count for each document
2172
2210
  if documents:
2173
- doc_ids = [(row['id'], row['bank_id']) for row in documents]
2211
+ doc_ids = [(row["id"], row["bank_id"]) for row in documents]
2174
2212
 
2175
2213
  # Create placeholders for the query
2176
2214
  placeholders = []
@@ -2183,48 +2221,44 @@ class MemoryEngine:
2183
2221
 
2184
2222
  where_clause_count = " OR ".join(placeholders)
2185
2223
 
2186
- unit_counts = await conn.fetch(f"""
2224
+ unit_counts = await conn.fetch(
2225
+ f"""
2187
2226
  SELECT document_id, bank_id, COUNT(*) as unit_count
2188
2227
  FROM memory_units
2189
2228
  WHERE {where_clause_count}
2190
2229
  GROUP BY document_id, bank_id
2191
- """, *params_for_count)
2230
+ """,
2231
+ *params_for_count,
2232
+ )
2192
2233
  else:
2193
2234
  unit_counts = []
2194
2235
 
2195
2236
  # Build count mapping
2196
- count_map = {(row['document_id'], row['bank_id']): row['unit_count'] for row in unit_counts}
2237
+ count_map = {(row["document_id"], row["bank_id"]): row["unit_count"] for row in unit_counts}
2197
2238
 
2198
2239
  # Build result items
2199
2240
  items = []
2200
2241
  for row in documents:
2201
- doc_id = row['id']
2202
- bank_id_val = row['bank_id']
2242
+ doc_id = row["id"]
2243
+ bank_id_val = row["bank_id"]
2203
2244
  unit_count = count_map.get((doc_id, bank_id_val), 0)
2204
2245
 
2205
- items.append({
2206
- "id": doc_id,
2207
- "bank_id": bank_id_val,
2208
- "content_hash": row['content_hash'],
2209
- "created_at": row['created_at'].isoformat() if row['created_at'] else "",
2210
- "updated_at": row['updated_at'].isoformat() if row['updated_at'] else "",
2211
- "text_length": row['text_length'] or 0,
2212
- "memory_unit_count": unit_count,
2213
- "retain_params": row['retain_params'] if row['retain_params'] else None
2214
- })
2246
+ items.append(
2247
+ {
2248
+ "id": doc_id,
2249
+ "bank_id": bank_id_val,
2250
+ "content_hash": row["content_hash"],
2251
+ "created_at": row["created_at"].isoformat() if row["created_at"] else "",
2252
+ "updated_at": row["updated_at"].isoformat() if row["updated_at"] else "",
2253
+ "text_length": row["text_length"] or 0,
2254
+ "memory_unit_count": unit_count,
2255
+ "retain_params": row["retain_params"] if row["retain_params"] else None,
2256
+ }
2257
+ )
2215
2258
 
2216
- return {
2217
- "items": items,
2218
- "total": total,
2219
- "limit": limit,
2220
- "offset": offset
2221
- }
2259
+ return {"items": items, "total": total, "limit": limit, "offset": offset}
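
Aside (not part of the diff): list_documents() (and list_memory_units() above) keep the same paging contract, a dict with items, total, limit, and offset. A minimal pagination sketch against that contract, again assuming an initialized MemoryEngine instance:

```python
# Illustrative sketch only; `engine` is assumed to be an initialized MemoryEngine.
async def iter_documents(engine, bank_id: str, page_size: int = 100):
    """Yield every document for a bank by walking limit/offset pages."""
    offset = 0
    while True:
        page = await engine.list_documents(bank_id, limit=page_size, offset=offset)
        for item in page["items"]:
            yield item
        offset += page_size
        if offset >= page["total"]:
            break
```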
2222
2260
 
2223
- async def get_document(
2224
- self,
2225
- document_id: str,
2226
- bank_id: str
2227
- ):
2261
+ async def get_document(self, document_id: str, bank_id: str):
2228
2262
  """
2229
2263
  Get a specific document including its original_text.
2230
2264
 
@@ -2237,7 +2271,8 @@ class MemoryEngine:
2237
2271
  """
2238
2272
  pool = await self._get_pool()
2239
2273
  async with acquire_with_retry(pool) as conn:
2240
- doc = await conn.fetchrow("""
2274
+ doc = await conn.fetchrow(
2275
+ """
2241
2276
  SELECT
2242
2277
  id,
2243
2278
  bank_id,
@@ -2248,33 +2283,37 @@ class MemoryEngine:
2248
2283
  retain_params
2249
2284
  FROM documents
2250
2285
  WHERE id = $1 AND bank_id = $2
2251
- """, document_id, bank_id)
2286
+ """,
2287
+ document_id,
2288
+ bank_id,
2289
+ )
2252
2290
 
2253
2291
  if not doc:
2254
2292
  return None
2255
2293
 
2256
2294
  # Get memory unit count
2257
- unit_count_row = await conn.fetchrow("""
2295
+ unit_count_row = await conn.fetchrow(
2296
+ """
2258
2297
  SELECT COUNT(*) as unit_count
2259
2298
  FROM memory_units
2260
2299
  WHERE document_id = $1 AND bank_id = $2
2261
- """, document_id, bank_id)
2300
+ """,
2301
+ document_id,
2302
+ bank_id,
2303
+ )
2262
2304
 
2263
2305
  return {
2264
- "id": doc['id'],
2265
- "bank_id": doc['bank_id'],
2266
- "original_text": doc['original_text'],
2267
- "content_hash": doc['content_hash'],
2268
- "created_at": doc['created_at'].isoformat() if doc['created_at'] else "",
2269
- "updated_at": doc['updated_at'].isoformat() if doc['updated_at'] else "",
2270
- "memory_unit_count": unit_count_row['unit_count'] if unit_count_row else 0,
2271
- "retain_params": doc['retain_params'] if doc['retain_params'] else None
2306
+ "id": doc["id"],
2307
+ "bank_id": doc["bank_id"],
2308
+ "original_text": doc["original_text"],
2309
+ "content_hash": doc["content_hash"],
2310
+ "created_at": doc["created_at"].isoformat() if doc["created_at"] else "",
2311
+ "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else "",
2312
+ "memory_unit_count": unit_count_row["unit_count"] if unit_count_row else 0,
2313
+ "retain_params": doc["retain_params"] if doc["retain_params"] else None,
2272
2314
  }
2273
2315
 
2274
- async def get_chunk(
2275
- self,
2276
- chunk_id: str
2277
- ):
2316
+ async def get_chunk(self, chunk_id: str):
2278
2317
  """
2279
2318
  Get a specific chunk by its ID.
2280
2319
 
@@ -2286,7 +2325,8 @@ class MemoryEngine:
2286
2325
  """
2287
2326
  pool = await self._get_pool()
2288
2327
  async with acquire_with_retry(pool) as conn:
2289
- chunk = await conn.fetchrow("""
2328
+ chunk = await conn.fetchrow(
2329
+ """
2290
2330
  SELECT
2291
2331
  chunk_id,
2292
2332
  document_id,
@@ -2296,18 +2336,20 @@ class MemoryEngine:
2296
2336
  created_at
2297
2337
  FROM chunks
2298
2338
  WHERE chunk_id = $1
2299
- """, chunk_id)
2339
+ """,
2340
+ chunk_id,
2341
+ )
2300
2342
 
2301
2343
  if not chunk:
2302
2344
  return None
2303
2345
 
2304
2346
  return {
2305
- "chunk_id": chunk['chunk_id'],
2306
- "document_id": chunk['document_id'],
2307
- "bank_id": chunk['bank_id'],
2308
- "chunk_index": chunk['chunk_index'],
2309
- "chunk_text": chunk['chunk_text'],
2310
- "created_at": chunk['created_at'].isoformat() if chunk['created_at'] else ""
2347
+ "chunk_id": chunk["chunk_id"],
2348
+ "document_id": chunk["document_id"],
2349
+ "bank_id": chunk["bank_id"],
2350
+ "chunk_index": chunk["chunk_index"],
2351
+ "chunk_text": chunk["chunk_text"],
2352
+ "created_at": chunk["created_at"].isoformat() if chunk["created_at"] else "",
2311
2353
  }
2312
2354
 
2313
2355
  async def _evaluate_opinion_update_async(
@@ -2316,7 +2358,7 @@ class MemoryEngine:
2316
2358
  opinion_confidence: float,
2317
2359
  new_event_text: str,
2318
2360
  entity_name: str,
2319
- ) -> Optional[Dict[str, Any]]:
2361
+ ) -> dict[str, Any] | None:
2320
2362
  """
2321
2363
  Evaluate if an opinion should be updated based on a new event.
2322
2364
 
@@ -2330,16 +2372,18 @@ class MemoryEngine:
2330
2372
  Dict with 'action' ('keep'|'update'), 'new_confidence', 'new_text' (if action=='update')
2331
2373
  or None if no changes needed
2332
2374
  """
2333
- from pydantic import BaseModel, Field
2334
2375
 
2335
2376
  class OpinionEvaluation(BaseModel):
2336
2377
  """Evaluation of whether an opinion should be updated."""
2378
+
2337
2379
  action: str = Field(description="Action to take: 'keep' (no change) or 'update' (modify opinion)")
2338
2380
  reasoning: str = Field(description="Brief explanation of why this action was chosen")
2339
- new_confidence: float = Field(description="New confidence score (0.0-1.0). Can be higher, lower, or same as before.")
2340
- new_opinion_text: Optional[str] = Field(
2381
+ new_confidence: float = Field(
2382
+ description="New confidence score (0.0-1.0). Can be higher, lower, or same as before."
2383
+ )
2384
+ new_opinion_text: str | None = Field(
2341
2385
  default=None,
2342
- description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None."
2386
+ description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None.",
2343
2387
  )
2344
2388
 
2345
2389
  evaluation_prompt = f"""You are evaluating whether an existing opinion should be updated based on new information.
@@ -2369,70 +2413,63 @@ Guidelines:
2369
2413
  result = await self._llm_config.call(
2370
2414
  messages=[
2371
2415
  {"role": "system", "content": "You evaluate and update opinions based on new information."},
2372
- {"role": "user", "content": evaluation_prompt}
2416
+ {"role": "user", "content": evaluation_prompt},
2373
2417
  ],
2374
2418
  response_format=OpinionEvaluation,
2375
2419
  scope="memory_evaluate_opinion",
2376
- temperature=0.3 # Lower temperature for more consistent evaluation
2420
+ temperature=0.3, # Lower temperature for more consistent evaluation
2377
2421
  )
2378
2422
 
2379
2423
  # Only return updates if something actually changed
2380
- if result.action == 'keep' and abs(result.new_confidence - opinion_confidence) < 0.01:
2424
+ if result.action == "keep" and abs(result.new_confidence - opinion_confidence) < 0.01:
2381
2425
  return None
2382
2426
 
2383
2427
  return {
2384
- 'action': result.action,
2385
- 'reasoning': result.reasoning,
2386
- 'new_confidence': result.new_confidence,
2387
- 'new_text': result.new_opinion_text if result.action == 'update' else None
2428
+ "action": result.action,
2429
+ "reasoning": result.reasoning,
2430
+ "new_confidence": result.new_confidence,
2431
+ "new_text": result.new_opinion_text if result.action == "update" else None,
2388
2432
  }
2389
2433
 
2390
2434
  except Exception as e:
2391
2435
  logger.warning(f"Failed to evaluate opinion update: {str(e)}")
2392
2436
  return None
2393
2437
 
2394
- async def _handle_form_opinion(self, task_dict: Dict[str, Any]):
2438
+ async def _handle_form_opinion(self, task_dict: dict[str, Any]):
2395
2439
  """
2396
2440
  Handler for form opinion tasks.
2397
2441
 
2398
2442
  Args:
2399
2443
  task_dict: Dict with keys: 'bank_id', 'answer_text', 'query'
2400
2444
  """
2401
- bank_id = task_dict['bank_id']
2402
- answer_text = task_dict['answer_text']
2403
- query = task_dict['query']
2445
+ bank_id = task_dict["bank_id"]
2446
+ answer_text = task_dict["answer_text"]
2447
+ query = task_dict["query"]
2404
2448
 
2405
- await self._extract_and_store_opinions_async(
2406
- bank_id=bank_id,
2407
- answer_text=answer_text,
2408
- query=query
2409
- )
2449
+ await self._extract_and_store_opinions_async(bank_id=bank_id, answer_text=answer_text, query=query)
2410
2450
 
2411
- async def _handle_reinforce_opinion(self, task_dict: Dict[str, Any]):
2451
+ async def _handle_reinforce_opinion(self, task_dict: dict[str, Any]):
2412
2452
  """
2413
2453
  Handler for reinforce opinion tasks.
2414
2454
 
2415
2455
  Args:
2416
2456
  task_dict: Dict with keys: 'bank_id', 'created_unit_ids', 'unit_texts', 'unit_entities'
2417
2457
  """
2418
- bank_id = task_dict['bank_id']
2419
- created_unit_ids = task_dict['created_unit_ids']
2420
- unit_texts = task_dict['unit_texts']
2421
- unit_entities = task_dict['unit_entities']
2458
+ bank_id = task_dict["bank_id"]
2459
+ created_unit_ids = task_dict["created_unit_ids"]
2460
+ unit_texts = task_dict["unit_texts"]
2461
+ unit_entities = task_dict["unit_entities"]
2422
2462
 
2423
2463
  await self._reinforce_opinions_async(
2424
- bank_id=bank_id,
2425
- created_unit_ids=created_unit_ids,
2426
- unit_texts=unit_texts,
2427
- unit_entities=unit_entities
2464
+ bank_id=bank_id, created_unit_ids=created_unit_ids, unit_texts=unit_texts, unit_entities=unit_entities
2428
2465
  )
2429
2466
 
2430
2467
  async def _reinforce_opinions_async(
2431
2468
  self,
2432
2469
  bank_id: str,
2433
- created_unit_ids: List[str],
2434
- unit_texts: List[str],
2435
- unit_entities: List[List[Dict[str, str]]],
2470
+ created_unit_ids: list[str],
2471
+ unit_texts: list[str],
2472
+ unit_entities: list[list[dict[str, str]]],
2436
2473
  ):
2437
2474
  """
2438
2475
  Background task to reinforce opinions based on newly ingested events.
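
Aside (not part of the diff): the opinion-evaluation path shown in the previous hunk requests a structured response by passing a Pydantic model as response_format to the internal LLM wrapper. A minimal sketch of that pattern; the model mirrors the one in the diff, while the commented call is a hypothetical stand-in for the wrapper:

```python
from pydantic import BaseModel, Field


class OpinionEvaluation(BaseModel):
    """Structured verdict on whether an existing opinion should change."""

    action: str = Field(description="'keep' or 'update'")
    reasoning: str = Field(description="Brief explanation of the choice")
    new_confidence: float = Field(description="Revised confidence in [0.0, 1.0]")
    new_opinion_text: str | None = Field(default=None, description="Revised text when action is 'update'")


# Hypothetical call shape, mirroring the wrapper usage in the hunk above:
# result = await llm_config.call(
#     messages=[...],
#     response_format=OpinionEvaluation,
#     scope="memory_evaluate_opinion",
#     temperature=0.3,
# )
```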
@@ -2451,15 +2488,14 @@ Guidelines:
2451
2488
  for entities_list in unit_entities:
2452
2489
  for entity in entities_list:
2453
2490
  # Handle both Entity objects and dicts
2454
- if hasattr(entity, 'text'):
2491
+ if hasattr(entity, "text"):
2455
2492
  entity_names.add(entity.text)
2456
2493
  elif isinstance(entity, dict):
2457
- entity_names.add(entity['text'])
2494
+ entity_names.add(entity["text"])
2458
2495
 
2459
2496
  if not entity_names:
2460
2497
  return
2461
2498
 
2462
-
2463
2499
  pool = await self._get_pool()
2464
2500
  async with acquire_with_retry(pool) as conn:
2465
2501
  # Find all opinions related to these entities
@@ -2474,13 +2510,12 @@ Guidelines:
2474
2510
  AND e.canonical_name = ANY($2::text[])
2475
2511
  """,
2476
2512
  bank_id,
2477
- list(entity_names)
2513
+ list(entity_names),
2478
2514
  )
2479
2515
 
2480
2516
  if not opinions:
2481
2517
  return
2482
2518
 
2483
-
2484
2519
  # Use cached LLM config
2485
2520
  if self._llm_config is None:
2486
2521
  logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
@@ -2489,15 +2524,15 @@ Guidelines:
2489
2524
  # Evaluate each opinion against the new events
2490
2525
  updates_to_apply = []
2491
2526
  for opinion in opinions:
2492
- opinion_id = str(opinion['id'])
2493
- opinion_text = opinion['text']
2494
- opinion_confidence = opinion['confidence_score']
2495
- entity_name = opinion['canonical_name']
2527
+ opinion_id = str(opinion["id"])
2528
+ opinion_text = opinion["text"]
2529
+ opinion_confidence = opinion["confidence_score"]
2530
+ entity_name = opinion["canonical_name"]
2496
2531
 
2497
2532
  # Find all new events mentioning this entity
2498
2533
  relevant_events = []
2499
2534
  for unit_text, entities_list in zip(unit_texts, unit_entities):
2500
- if any(e['text'] == entity_name for e in entities_list):
2535
+ if any(e["text"] == entity_name for e in entities_list):
2501
2536
  relevant_events.append(unit_text)
2502
2537
 
2503
2538
  if not relevant_events:
@@ -2508,26 +2543,20 @@ Guidelines:
2508
2543
 
2509
2544
  # Evaluate if opinion should be updated
2510
2545
  evaluation = await self._evaluate_opinion_update_async(
2511
- opinion_text,
2512
- opinion_confidence,
2513
- combined_events,
2514
- entity_name
2546
+ opinion_text, opinion_confidence, combined_events, entity_name
2515
2547
  )
2516
2548
 
2517
2549
  if evaluation:
2518
- updates_to_apply.append({
2519
- 'opinion_id': opinion_id,
2520
- 'evaluation': evaluation
2521
- })
2550
+ updates_to_apply.append({"opinion_id": opinion_id, "evaluation": evaluation})
2522
2551
 
2523
2552
  # Apply all updates in a single transaction
2524
2553
  if updates_to_apply:
2525
2554
  async with conn.transaction():
2526
2555
  for update in updates_to_apply:
2527
- opinion_id = update['opinion_id']
2528
- evaluation = update['evaluation']
2556
+ opinion_id = update["opinion_id"]
2557
+ evaluation = update["evaluation"]
2529
2558
 
2530
- if evaluation['action'] == 'update' and evaluation['new_text']:
2559
+ if evaluation["action"] == "update" and evaluation["new_text"]:
2531
2560
  # Update both text and confidence
2532
2561
  await conn.execute(
2533
2562
  """
@@ -2535,9 +2564,9 @@ Guidelines:
2535
2564
  SET text = $1, confidence_score = $2, updated_at = NOW()
2536
2565
  WHERE id = $3
2537
2566
  """,
2538
- evaluation['new_text'],
2539
- evaluation['new_confidence'],
2540
- uuid.UUID(opinion_id)
2567
+ evaluation["new_text"],
2568
+ evaluation["new_confidence"],
2569
+ uuid.UUID(opinion_id),
2541
2570
  )
2542
2571
  else:
2543
2572
  # Only update confidence
@@ -2547,8 +2576,8 @@ Guidelines:
2547
2576
  SET confidence_score = $1, updated_at = NOW()
2548
2577
  WHERE id = $2
2549
2578
  """,
2550
- evaluation['new_confidence'],
2551
- uuid.UUID(opinion_id)
2579
+ evaluation["new_confidence"],
2580
+ uuid.UUID(opinion_id),
2552
2581
  )
2553
2582
 
2554
2583
  else:
@@ -2557,6 +2586,7 @@ Guidelines:
2557
2586
  except Exception as e:
2558
2587
  logger.error(f"[REINFORCE] Error during opinion reinforcement: {str(e)}")
2559
2588
  import traceback
2589
+
2560
2590
  traceback.print_exc()
2561
2591
 
2562
2592
  # ==================== bank profile Methods ====================
@@ -2575,11 +2605,7 @@ Guidelines:
2575
2605
  pool = await self._get_pool()
2576
2606
  return await bank_utils.get_bank_profile(pool, bank_id)
2577
2607
 
2578
- async def update_bank_disposition(
2579
- self,
2580
- bank_id: str,
2581
- disposition: Dict[str, int]
2582
- ) -> None:
2608
+ async def update_bank_disposition(self, bank_id: str, disposition: dict[str, int]) -> None:
2583
2609
  """
2584
2610
  Update bank disposition traits.
2585
2611
 
@@ -2590,12 +2616,7 @@ Guidelines:
2590
2616
  pool = await self._get_pool()
2591
2617
  await bank_utils.update_bank_disposition(pool, bank_id, disposition)
2592
2618
 
2593
- async def merge_bank_background(
2594
- self,
2595
- bank_id: str,
2596
- new_info: str,
2597
- update_disposition: bool = True
2598
- ) -> dict:
2619
+ async def merge_bank_background(self, bank_id: str, new_info: str, update_disposition: bool = True) -> dict:
2599
2620
  """
2600
2621
  Merge new background information with existing background using LLM.
2601
2622
  Normalizes to first person ("I") and resolves conflicts.
@@ -2610,9 +2631,7 @@ Guidelines:
2610
2631
  Dict with 'background' (str) and optionally 'disposition' (dict) keys
2611
2632
  """
2612
2633
  pool = await self._get_pool()
2613
- return await bank_utils.merge_bank_background(
2614
- pool, self._llm_config, bank_id, new_info, update_disposition
2615
- )
2634
+ return await bank_utils.merge_bank_background(pool, self._llm_config, bank_id, new_info, update_disposition)
2616
2635
 
2617
2636
  async def list_banks(self) -> list:
2618
2637
  """
@@ -2673,19 +2692,21 @@ Guidelines:
2673
2692
  budget=budget,
2674
2693
  max_tokens=4096,
2675
2694
  enable_trace=False,
2676
- fact_type=['experience', 'world', 'opinion'],
2677
- include_entities=True
2695
+ fact_type=["experience", "world", "opinion"],
2696
+ include_entities=True,
2678
2697
  )
2679
2698
  recall_time = time.time() - recall_start
2680
2699
 
2681
2700
  all_results = search_result.results
2682
2701
 
2683
2702
  # Split results by fact type for structured response
2684
- agent_results = [r for r in all_results if r.fact_type == 'experience']
2685
- world_results = [r for r in all_results if r.fact_type == 'world']
2686
- opinion_results = [r for r in all_results if r.fact_type == 'opinion']
2703
+ agent_results = [r for r in all_results if r.fact_type == "experience"]
2704
+ world_results = [r for r in all_results if r.fact_type == "world"]
2705
+ opinion_results = [r for r in all_results if r.fact_type == "opinion"]
2687
2706
 
2688
- log_buffer.append(f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s")
2707
+ log_buffer.append(
2708
+ f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s"
2709
+ )
2689
2710
 
2690
2711
  # Format facts for LLM
2691
2712
  agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
@@ -2716,47 +2737,34 @@ Guidelines:
2716
2737
 
2717
2738
  llm_start = time.time()
2718
2739
  answer_text = await self._llm_config.call(
2719
- messages=[
2720
- {"role": "system", "content": system_message},
2721
- {"role": "user", "content": prompt}
2722
- ],
2740
+ messages=[{"role": "system", "content": system_message}, {"role": "user", "content": prompt}],
2723
2741
  scope="memory_think",
2724
2742
  temperature=0.9,
2725
- max_completion_tokens=1000
2743
+ max_completion_tokens=1000,
2726
2744
  )
2727
2745
  llm_time = time.time() - llm_start
2728
2746
 
2729
2747
  answer_text = answer_text.strip()
2730
2748
 
2731
2749
  # Submit form_opinion task for background processing
2732
- await self._task_backend.submit_task({
2733
- 'type': 'form_opinion',
2734
- 'bank_id': bank_id,
2735
- 'answer_text': answer_text,
2736
- 'query': query
2737
- })
2750
+ await self._task_backend.submit_task(
2751
+ {"type": "form_opinion", "bank_id": bank_id, "answer_text": answer_text, "query": query}
2752
+ )
2738
2753
 
2739
2754
  total_time = time.time() - reflect_start
2740
- log_buffer.append(f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s")
2755
+ log_buffer.append(
2756
+ f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s"
2757
+ )
2741
2758
  logger.info("\n" + "\n".join(log_buffer))
2742
2759
 
2743
2760
  # Return response with facts split by type
2744
2761
  return ReflectResult(
2745
2762
  text=answer_text,
2746
- based_on={
2747
- "world": world_results,
2748
- "experience": agent_results,
2749
- "opinion": opinion_results
2750
- },
2751
- new_opinions=[] # Opinions are being extracted asynchronously
2763
+ based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
2764
+ new_opinions=[], # Opinions are being extracted asynchronously
2752
2765
  )
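
Aside (not part of the diff): reflect() above partitions recall results by fact_type before prompting and before building the based_on payload. A generic sketch of that partitioning step (illustrative; it only assumes result objects expose a fact_type attribute, as they do in this module):

```python
from collections import defaultdict


def split_by_fact_type(results) -> dict[str, list]:
    """Bucket recall results into 'experience', 'world', 'opinion', etc."""
    buckets: dict[str, list] = defaultdict(list)
    for fact in results:
        buckets[fact.fact_type].append(fact)
    return dict(buckets)
```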
2753
2766
 
2754
- async def _extract_and_store_opinions_async(
2755
- self,
2756
- bank_id: str,
2757
- answer_text: str,
2758
- query: str
2759
- ):
2767
+ async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str, query: str):
2760
2768
  """
2761
2769
  Background task to extract and store opinions from think response.
2762
2770
 
@@ -2769,33 +2777,27 @@ Guidelines:
2769
2777
  """
2770
2778
  try:
2771
2779
  # Extract opinions from the answer
2772
- new_opinions = await think_utils.extract_opinions_from_text(
2773
- self._llm_config, text=answer_text, query=query
2774
- )
2780
+ new_opinions = await think_utils.extract_opinions_from_text(self._llm_config, text=answer_text, query=query)
2775
2781
 
2776
2782
  # Store new opinions
2777
2783
  if new_opinions:
2778
- from datetime import datetime, timezone
2779
- current_time = datetime.now(timezone.utc)
2784
+ from datetime import datetime
2785
+
2786
+ current_time = datetime.now(UTC)
2780
2787
  for opinion in new_opinions:
2781
2788
  await self.retain_async(
2782
2789
  bank_id=bank_id,
2783
2790
  content=opinion.opinion,
2784
2791
  context=f"formed during thinking about: {query}",
2785
2792
  event_date=current_time,
2786
- fact_type_override='opinion',
2787
- confidence_score=opinion.confidence
2793
+ fact_type_override="opinion",
2794
+ confidence_score=opinion.confidence,
2788
2795
  )
2789
2796
 
2790
2797
  except Exception as e:
2791
2798
  logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")
2792
2799
 
2793
- async def get_entity_observations(
2794
- self,
2795
- bank_id: str,
2796
- entity_id: str,
2797
- limit: int = 10
2798
- ) -> List[EntityObservation]:
2800
+ async def get_entity_observations(self, bank_id: str, entity_id: str, limit: int = 10) -> list[EntityObservation]:
2799
2801
  """
2800
2802
  Get observations linked to an entity.
2801
2803
 
@@ -2820,23 +2822,18 @@ Guidelines:
2820
2822
  ORDER BY mu.mentioned_at DESC
2821
2823
  LIMIT $3
2822
2824
  """,
2823
- bank_id, uuid.UUID(entity_id), limit
2825
+ bank_id,
2826
+ uuid.UUID(entity_id),
2827
+ limit,
2824
2828
  )
2825
2829
 
2826
2830
  observations = []
2827
2831
  for row in rows:
2828
- mentioned_at = row['mentioned_at'].isoformat() if row['mentioned_at'] else None
2829
- observations.append(EntityObservation(
2830
- text=row['text'],
2831
- mentioned_at=mentioned_at
2832
- ))
2832
+ mentioned_at = row["mentioned_at"].isoformat() if row["mentioned_at"] else None
2833
+ observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
2833
2834
  return observations
2834
2835
 
2835
- async def list_entities(
2836
- self,
2837
- bank_id: str,
2838
- limit: int = 100
2839
- ) -> List[Dict[str, Any]]:
2836
+ async def list_entities(self, bank_id: str, limit: int = 100) -> list[dict[str, Any]]:
2840
2837
  """
2841
2838
  List all entities for a bank.
2842
2839
 
@@ -2857,39 +2854,37 @@ Guidelines:
2857
2854
  ORDER BY mention_count DESC, last_seen DESC
2858
2855
  LIMIT $2
2859
2856
  """,
2860
- bank_id, limit
2857
+ bank_id,
2858
+ limit,
2861
2859
  )
2862
2860
 
2863
2861
  entities = []
2864
2862
  for row in rows:
2865
2863
  # Handle metadata - may be dict, JSON string, or None
2866
- metadata = row['metadata']
2864
+ metadata = row["metadata"]
2867
2865
  if metadata is None:
2868
2866
  metadata = {}
2869
2867
  elif isinstance(metadata, str):
2870
2868
  import json
2869
+
2871
2870
  try:
2872
2871
  metadata = json.loads(metadata)
2873
2872
  except json.JSONDecodeError:
2874
2873
  metadata = {}
2875
2874
 
2876
- entities.append({
2877
- 'id': str(row['id']),
2878
- 'canonical_name': row['canonical_name'],
2879
- 'mention_count': row['mention_count'],
2880
- 'first_seen': row['first_seen'].isoformat() if row['first_seen'] else None,
2881
- 'last_seen': row['last_seen'].isoformat() if row['last_seen'] else None,
2882
- 'metadata': metadata
2883
- })
2875
+ entities.append(
2876
+ {
2877
+ "id": str(row["id"]),
2878
+ "canonical_name": row["canonical_name"],
2879
+ "mention_count": row["mention_count"],
2880
+ "first_seen": row["first_seen"].isoformat() if row["first_seen"] else None,
2881
+ "last_seen": row["last_seen"].isoformat() if row["last_seen"] else None,
2882
+ "metadata": metadata,
2883
+ }
2884
+ )
2884
2885
  return entities
2885
2886
 
2886
- async def get_entity_state(
2887
- self,
2888
- bank_id: str,
2889
- entity_id: str,
2890
- entity_name: str,
2891
- limit: int = 10
2892
- ) -> EntityState:
2887
+ async def get_entity_state(self, bank_id: str, entity_id: str, entity_name: str, limit: int = 10) -> EntityState:
2893
2888
  """
2894
2889
  Get the current state (mental model) of an entity.
2895
2890
 
@@ -2903,20 +2898,11 @@ Guidelines:
2903
2898
  EntityState with observations
2904
2899
  """
2905
2900
  observations = await self.get_entity_observations(bank_id, entity_id, limit)
2906
- return EntityState(
2907
- entity_id=entity_id,
2908
- canonical_name=entity_name,
2909
- observations=observations
2910
- )
2901
+ return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=observations)
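
Aside (not part of the diff): get_entity_state() composes get_entity_observations() into an EntityState. A small usage sketch, assuming an initialized MemoryEngine and taking the entity id from list_entities(), whose rows are ordered by mention_count DESC:

```python
# Illustrative sketch only; `engine` is assumed to be an initialized MemoryEngine.
async def print_entity_profile(engine, bank_id: str) -> None:
    """Print recent observations for the most-mentioned entity in a bank."""
    entities = await engine.list_entities(bank_id, limit=1)
    if not entities:
        return
    top = entities[0]
    state = await engine.get_entity_state(bank_id, top["id"], top["canonical_name"], limit=10)
    for obs in state.observations:
        print(f"- {obs.text} (mentioned_at={obs.mentioned_at})")
```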
2911
2902
 
2912
2903
  async def regenerate_entity_observations(
2913
- self,
2914
- bank_id: str,
2915
- entity_id: str,
2916
- entity_name: str,
2917
- version: str | None = None,
2918
- conn=None
2919
- ) -> List[str]:
2904
+ self, bank_id: str, entity_id: str, entity_name: str, version: str | None = None, conn=None
2905
+ ) -> list[str]:
2920
2906
  """
2921
2907
  Regenerate observations for an entity by:
2922
2908
  1. Checking version for deduplication (if provided)
@@ -2961,7 +2947,8 @@ Guidelines:
2961
2947
  FROM entities
2962
2948
  WHERE id = $1 AND bank_id = $2
2963
2949
  """,
2964
- entity_uuid, bank_id
2950
+ entity_uuid,
2951
+ bank_id,
2965
2952
  )
2966
2953
 
2967
2954
  if current_last_seen and current_last_seen.isoformat() != version:
@@ -2979,7 +2966,8 @@ Guidelines:
2979
2966
  ORDER BY mu.occurred_start DESC
2980
2967
  LIMIT 50
2981
2968
  """,
2982
- bank_id, entity_uuid
2969
+ bank_id,
2970
+ entity_uuid,
2983
2971
  )
2984
2972
 
2985
2973
  if not rows:
@@ -2988,21 +2976,19 @@ Guidelines:
2988
2976
  # Convert to MemoryFact objects for the observation extraction
2989
2977
  facts = []
2990
2978
  for row in rows:
2991
- occurred_start = row['occurred_start'].isoformat() if row['occurred_start'] else None
2992
- facts.append(MemoryFact(
2993
- id=str(row['id']),
2994
- text=row['text'],
2995
- fact_type=row['fact_type'],
2996
- context=row['context'],
2997
- occurred_start=occurred_start
2998
- ))
2979
+ occurred_start = row["occurred_start"].isoformat() if row["occurred_start"] else None
2980
+ facts.append(
2981
+ MemoryFact(
2982
+ id=str(row["id"]),
2983
+ text=row["text"],
2984
+ fact_type=row["fact_type"],
2985
+ context=row["context"],
2986
+ occurred_start=occurred_start,
2987
+ )
2988
+ )
2999
2989
 
3000
2990
  # Step 3: Extract observations using LLM (no personality)
3001
- observations = await observation_utils.extract_observations_from_facts(
3002
- self._llm_config,
3003
- entity_name,
3004
- facts
3005
- )
2991
+ observations = await observation_utils.extract_observations_from_facts(self._llm_config, entity_name, facts)
3006
2992
 
3007
2993
  if not observations:
3008
2994
  return []
@@ -3024,13 +3010,12 @@ Guidelines:
3024
3010
  AND ue.entity_id = $2
3025
3011
  )
3026
3012
  """,
3027
- bank_id, entity_uuid
3013
+ bank_id,
3014
+ entity_uuid,
3028
3015
  )
3029
3016
 
3030
3017
  # Generate embeddings for new observations
3031
- embeddings = await embedding_utils.generate_embeddings_batch(
3032
- self.embeddings, observations
3033
- )
3018
+ embeddings = await embedding_utils.generate_embeddings_batch(self.embeddings, observations)
3034
3019
 
3035
3020
  # Insert new observations
3036
3021
  current_time = utcnow()
@@ -3054,9 +3039,9 @@ Guidelines:
3054
3039
  current_time,
3055
3040
  current_time,
3056
3041
  current_time,
3057
- current_time
3042
+ current_time,
3058
3043
  )
3059
- obs_id = str(result['id'])
3044
+ obs_id = str(result["id"])
3060
3045
  created_ids.append(obs_id)
3061
3046
 
3062
3047
  # Link observation to entity
@@ -3065,7 +3050,8 @@ Guidelines:
3065
3050
  INSERT INTO unit_entities (unit_id, entity_id)
3066
3051
  VALUES ($1, $2)
3067
3052
  """,
3068
- uuid.UUID(obs_id), entity_uuid
3053
+ uuid.UUID(obs_id),
3054
+ entity_uuid,
3069
3055
  )
3070
3056
 
3071
3057
  return created_ids
@@ -3080,11 +3066,7 @@ Guidelines:
3080
3066
  return await do_db_operations(acquired_conn)
3081
3067
 
3082
3068
  async def _regenerate_observations_sync(
3083
- self,
3084
- bank_id: str,
3085
- entity_ids: List[str],
3086
- min_facts: int = 5,
3087
- conn=None
3069
+ self, bank_id: str, entity_ids: list[str], min_facts: int = 5, conn=None
3088
3070
  ) -> None:
3089
3071
  """
3090
3072
  Regenerate observations for entities synchronously (called during retain).
@@ -3111,9 +3093,10 @@ Guidelines:
3111
3093
  SELECT id, canonical_name FROM entities
3112
3094
  WHERE id = ANY($1) AND bank_id = $2
3113
3095
  """,
3114
- entity_uuids, bank_id
3096
+ entity_uuids,
3097
+ bank_id,
3115
3098
  )
3116
- entity_names = {row['id']: row['canonical_name'] for row in entity_rows}
3099
+ entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}
3117
3100
 
3118
3101
  fact_counts = await conn.fetch(
3119
3102
  """
@@ -3123,9 +3106,10 @@ Guidelines:
3123
3106
  WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
3124
3107
  GROUP BY ue.entity_id
3125
3108
  """,
3126
- entity_uuids, bank_id
3109
+ entity_uuids,
3110
+ bank_id,
3127
3111
  )
3128
- entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
3112
+ entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}
3129
3113
  else:
3130
3114
  # Acquire a new connection (standalone call)
3131
3115
  pool = await self._get_pool()
@@ -3135,9 +3119,10 @@ Guidelines:
3135
3119
  SELECT id, canonical_name FROM entities
3136
3120
  WHERE id = ANY($1) AND bank_id = $2
3137
3121
  """,
3138
- entity_uuids, bank_id
3122
+ entity_uuids,
3123
+ bank_id,
3139
3124
  )
3140
- entity_names = {row['id']: row['canonical_name'] for row in entity_rows}
3125
+ entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}
3141
3126
 
3142
3127
  fact_counts = await acquired_conn.fetch(
3143
3128
  """
@@ -3147,9 +3132,10 @@ Guidelines:
3147
3132
  WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
3148
3133
  GROUP BY ue.entity_id
3149
3134
  """,
3150
- entity_uuids, bank_id
3135
+ entity_uuids,
3136
+ bank_id,
3151
3137
  )
3152
- entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
3138
+ entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}
3153
3139
 
3154
3140
  # Filter entities that meet the threshold
3155
3141
  entities_to_process = []
@@ -3171,11 +3157,9 @@ Guidelines:
3171
3157
  except Exception as e:
3172
3158
  logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
3173
3159
 
3174
- await asyncio.gather(*[
3175
- process_entity(eid, name) for eid, name in entities_to_process
3176
- ])
3160
+ await asyncio.gather(*[process_entity(eid, name) for eid, name in entities_to_process])
3177
3161
 
3178
- async def _handle_regenerate_observations(self, task_dict: Dict[str, Any]):
3162
+ async def _handle_regenerate_observations(self, task_dict: dict[str, Any]):
3179
3163
  """
3180
3164
  Handler for regenerate_observations tasks.
3181
3165
 
@@ -3185,12 +3169,12 @@ Guidelines:
3185
3169
  - 'entity_id', 'entity_name': Process single entity (legacy)
3186
3170
  """
3187
3171
  try:
3188
- bank_id = task_dict.get('bank_id')
3172
+ bank_id = task_dict.get("bank_id")
3189
3173
 
3190
3174
  # New format: multiple entity_ids
3191
- if 'entity_ids' in task_dict:
3192
- entity_ids = task_dict.get('entity_ids', [])
3193
- min_facts = task_dict.get('min_facts', 5)
3175
+ if "entity_ids" in task_dict:
3176
+ entity_ids = task_dict.get("entity_ids", [])
3177
+ min_facts = task_dict.get("min_facts", 5)
3194
3178
 
3195
3179
  if not bank_id or not entity_ids:
3196
3180
  logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
@@ -3203,31 +3187,37 @@ Guidelines:
3203
3187
  try:
3204
3188
  # Fetch entity name and check fact count
3205
3189
  import uuid as uuid_module
3190
+
3206
3191
  entity_uuid = uuid_module.UUID(entity_id) if isinstance(entity_id, str) else entity_id
3207
3192
 
3208
3193
  # First check if entity exists
3209
3194
  entity_exists = await conn.fetchrow(
3210
3195
  "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
3211
- entity_uuid, bank_id
3196
+ entity_uuid,
3197
+ bank_id,
3212
3198
  )
3213
3199
 
3214
3200
  if not entity_exists:
3215
3201
  logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
3216
3202
  continue
3217
3203
 
3218
- entity_name = entity_exists['canonical_name']
3204
+ entity_name = entity_exists["canonical_name"]
3219
3205
 
3220
3206
  # Count facts linked to this entity
3221
- fact_count = await conn.fetchval(
3222
- "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1",
3223
- entity_uuid
3224
- ) or 0
3207
+ fact_count = (
3208
+ await conn.fetchval(
3209
+ "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1", entity_uuid
3210
+ )
3211
+ or 0
3212
+ )
3225
3213
 
3226
3214
  # Only regenerate if entity has enough facts
3227
3215
  if fact_count >= min_facts:
3228
3216
  await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
3229
3217
  else:
3230
- logger.debug(f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)")
3218
+ logger.debug(
3219
+ f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
3220
+ )
3231
3221
 
3232
3222
  except Exception as e:
3233
3223
  logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
@@ -3235,9 +3225,9 @@ Guidelines:
3235
3225
 
3236
3226
  # Legacy format: single entity
3237
3227
  else:
3238
- entity_id = task_dict.get('entity_id')
3239
- entity_name = task_dict.get('entity_name')
3240
- version = task_dict.get('version')
3228
+ entity_id = task_dict.get("entity_id")
3229
+ entity_name = task_dict.get("entity_name")
3230
+ version = task_dict.get("version")
3241
3231
 
3242
3232
  if not all([bank_id, entity_id, entity_name]):
3243
3233
  logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
@@ -3248,5 +3238,5 @@ Guidelines:
3248
3238
  except Exception as e:
3249
3239
  logger.error(f"[OBSERVATIONS] Error regenerating observations: {e}")
3250
3240
  import traceback
3251
- traceback.print_exc()
3252
3241
 
3242
+ traceback.print_exc()