hindsight-api 0.0.21__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
@@ -48,7 +48,7 @@ from .entity_resolver import EntityResolver
 from .retain import embedding_utils, bank_utils
 from .search import think_utils, observation_utils
 from .llm_wrapper import LLMConfig
-from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation
+from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation, VALID_RECALL_FACT_TYPES
 from .task_backend import TaskBackend, AsyncIOQueueBackend
 from .search.reranking import CrossEncoderReranker
 from ..pg0 import EmbeddedPostgres
@@ -110,12 +110,14 @@ class MemoryEngine:
         pool_min_size: int = 5,
         pool_max_size: int = 100,
         task_backend: Optional[TaskBackend] = None,
+        run_migrations: bool = True,
     ):
         """
         Initialize the temporal + semantic memory system.

         Args:
             db_url: PostgreSQL connection URL (postgresql://user:pass@host:port/dbname). Required.
+                Also supports pg0 URLs: "pg0" or "pg0://instance-name" or "pg0://instance-name:port"
             memory_llm_provider: LLM provider for memory operations: "openai", "groq", or "ollama". Required.
             memory_llm_api_key: API key for the LLM provider. Required.
             memory_llm_model: Model name to use for all memory operations (put/think/opinions). Required.
@@ -129,16 +131,38 @@ class MemoryEngine:
             pool_max_size: Maximum number of connections in the pool (default: 100)
                 Increase for parallel think/search operations (e.g., 200-300 for 100+ parallel thinks)
             task_backend: Custom task backend for async task execution. If not provided, uses AsyncIOQueueBackend
+            run_migrations: Whether to run database migrations during initialize(). Default: True
         """
         if not db_url:
             raise ValueError("Database url is required")
         # Track pg0 instance (if used)
         self._pg0: Optional[EmbeddedPostgres] = None
+        self._pg0_instance_name: Optional[str] = None

         # Initialize PostgreSQL connection URL
         # The actual URL will be set during initialize() after starting the server
-        self._use_pg0 = db_url == "pg0"
-        self.db_url = db_url if not self._use_pg0 else None
+        # Supports: "pg0" (default instance), "pg0://instance-name" (named instance), or regular postgresql:// URL
+        if db_url == "pg0":
+            self._use_pg0 = True
+            self._pg0_instance_name = "hindsight"
+            self._pg0_port = None  # Use default port
+            self.db_url = None
+        elif db_url.startswith("pg0://"):
+            self._use_pg0 = True
+            # Parse instance name and optional port: pg0://instance-name or pg0://instance-name:port
+            url_part = db_url[6:]  # Remove "pg0://"
+            if ":" in url_part:
+                self._pg0_instance_name, port_str = url_part.rsplit(":", 1)
+                self._pg0_port = int(port_str)
+            else:
+                self._pg0_instance_name = url_part or "hindsight"
+                self._pg0_port = None  # Use default port
+            self.db_url = None
+        else:
+            self._use_pg0 = False
+            self._pg0_instance_name = None
+            self._pg0_port = None
+            self.db_url = db_url


         # Set default base URL if not provided
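The three accepted db_url forms in this hunk, in a minimal constructor sketch (the provider, key, and model values are placeholders; only the parameter names documented in this __init__ are taken from the diff):

    engine = MemoryEngine(
        db_url="pg0://my-instance:55432",   # or "pg0" for the default "hindsight" instance,
                                            # or "postgresql://user:pass@host:5432/dbname"
        memory_llm_provider="openai",       # placeholder provider/key/model values
        memory_llm_api_key="sk-placeholder",
        memory_llm_model="gpt-4o-mini",
    )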
@@ -155,6 +179,7 @@ class MemoryEngine:
         self._initialized = False
         self._pool_min_size = pool_min_size
         self._pool_max_size = pool_max_size
+        self._run_migrations = run_migrations

         # Initialize entity resolver (will be created in initialize())
         self.entity_resolver = None
@@ -378,8 +403,16 @@ class MemoryEngine:
         async def start_pg0():
             """Start pg0 if configured."""
             if self._use_pg0:
-                self._pg0 = EmbeddedPostgres()
-                self.db_url = await self._pg0.ensure_running()
+                kwargs = {"name": self._pg0_instance_name}
+                if self._pg0_port is not None:
+                    kwargs["port"] = self._pg0_port
+                pg0 = EmbeddedPostgres(**kwargs)
+                # Check if pg0 is already running before we start it
+                was_already_running = await pg0.is_running()
+                self.db_url = await pg0.ensure_running()
+                # Only track pg0 (to stop later) if WE started it
+                if not was_already_running:
+                    self._pg0 = pg0

         def load_embeddings():
             """Load embedding model (CPU-bound)."""
@@ -408,6 +441,12 @@ class MemoryEngine:
             pg0_task, embeddings_future, cross_encoder_future, query_analyzer_future
         )

+        # Run database migrations if enabled
+        if self._run_migrations:
+            from ..migrations import run_migrations
+            logger.info("Running database migrations...")
+            run_migrations(self.db_url)
+
        logger.info(f"Connecting to PostgreSQL at {self.db_url}")

        # Create connection pool
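A hedged sketch of opting out of the automatic migration step (run_migrations and initialize() are the names used in this diff; the other values are placeholders):

    # Schema is managed out of band, so don't migrate on startup
    engine = MemoryEngine(
        db_url="postgresql://user:pass@db-host:5432/hindsight",
        memory_llm_provider="openai",
        memory_llm_api_key="sk-placeholder",
        memory_llm_model="gpt-4o-mini",
        run_migrations=False,
    )
    await engine.initialize()  # pg0 startup and pool creation still happen; the migration step above is skipped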
@@ -869,7 +908,6 @@ class MemoryEngine:
             task_backend=self._task_backend,
             format_date_fn=self._format_readable_date,
             duplicate_checker_fn=self._find_duplicate_facts_batch,
-            regenerate_observations_fn=self._regenerate_observations_sync,
             bank_id=bank_id,
             contents_dicts=contents,
             document_id=document_id,
@@ -955,11 +993,19 @@ class MemoryEngine:
             - entities: Optional dict of entity states (if include_entities=True)
             - chunks: Optional dict of chunks (if include_chunks=True)
         """
+        # Validate fact types early
+        invalid_types = set(fact_type) - VALID_RECALL_FACT_TYPES
+        if invalid_types:
+            raise ValueError(
+                f"Invalid fact type(s): {', '.join(sorted(invalid_types))}. "
+                f"Must be one of: {', '.join(sorted(VALID_RECALL_FACT_TYPES))}"
+            )
+
         # Map budget enum to thinking_budget number
         budget_mapping = {
             Budget.LOW: 100,
             Budget.MID: 300,
-            Budget.HIGH: 600
+            Budget.HIGH: 1000
         }
         thinking_budget = budget_mapping[budget]

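With the early check above, a bad fact_type now fails fast with a ValueError instead of being silently ignored. A sketch (the exact membership of VALID_RECALL_FACT_TYPES is defined in response_models and is not shown in this diff):

    try:
        await engine.recall_async(
            bank_id="my-bank",
            query="what did we decide about the launch?",
            fact_type=["opinions"],  # typo: elsewhere in this diff the singular 'opinion' is used
        )
    except ValueError as e:
        print(e)  # Invalid fact type(s): opinions. Must be one of: ...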
@@ -1040,12 +1086,12 @@ class MemoryEngine:
         tracer.start()

         pool = await self._get_pool()
-        search_start = time.time()
+        recall_start = time.time()

         # Buffer logs for clean output in concurrent scenarios
-        search_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
+        recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
         log_buffer = []
-        log_buffer.append(f"[SEARCH {search_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})")
+        log_buffer.append(f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})")

         try:
             # Step 1: Generate query embedding (for semantic search)
@@ -1088,7 +1134,7 @@ class MemoryEngine:
             for idx, (ft_semantic, ft_bm25, ft_graph, ft_temporal, ft_timings, ft_temporal_constraint) in enumerate(all_retrievals):
                 # Log fact types in this retrieval batch
                 ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
-                logger.debug(f"[SEARCH {search_id}] Fact type '{ft_name}': semantic={len(ft_semantic)}, bm25={len(ft_bm25)}, graph={len(ft_graph)}, temporal={len(ft_temporal) if ft_temporal else 0}")
+                logger.debug(f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(ft_semantic)}, bm25={len(ft_bm25)}, graph={len(ft_graph)}, temporal={len(ft_temporal) if ft_temporal else 0}")

                 semantic_results.extend(ft_semantic)
                 bm25_results.extend(ft_bm25)
@@ -1209,7 +1255,6 @@ class MemoryEngine:
             # Step 4: Rerank using cross-encoder (MergedCandidate -> ScoredResult)
             step_start = time.time()
             reranker_instance = self._cross_encoder_reranker
-            log_buffer.append(f" [4] Using cross-encoder reranker")

             # Rerank using cross-encoder
             scored_results = reranker_instance.rerank(query, merged_candidates)
@@ -1334,12 +1379,7 @@ class MemoryEngine:
                 ft = sr.retrieval.fact_type
                 fact_type_counts[ft] = fact_type_counts.get(ft, 0) + 1

-            total_time = time.time() - search_start
             fact_type_summary = ", ".join([f"{ft}={count}" for ft, count in sorted(fact_type_counts.items())])
-            log_buffer.append(f"[SEARCH {search_id}] Complete: {len(top_scored)} results ({fact_type_summary}) ({total_tokens} tokens) in {total_time:.3f}s")
-
-            # Log all buffered logs at once
-            logger.info("\n" + "\n".join(log_buffer))

             # Convert ScoredResult to dicts with ISO datetime strings
             top_results_dicts = []
@@ -1401,11 +1441,12 @@ class MemoryEngine:
                     mentioned_at=result_dict.get("mentioned_at"),
                     document_id=result_dict.get("document_id"),
                     chunk_id=result_dict.get("chunk_id"),
-                    activation=result_dict.get("weight")  # Use final weight as activation
                 ))

             # Fetch entity observations if requested
             entities_dict = None
+            total_entity_tokens = 0
+            total_chunk_tokens = 0
             if include_entities and fact_entity_map:
                 # Collect unique entities in order of fact relevance (preserving order from top_scored)
                 # Use a list to maintain order, but track seen entities to avoid duplicates
@@ -1425,7 +1466,6 @@ class MemoryEngine:

                 # Fetch observations for each entity (respect token budget, in order)
                 entities_dict = {}
-                total_entity_tokens = 0
                 encoding = _get_tiktoken_encoding()

                 for entity_id, entity_name in entities_ordered:
@@ -1485,7 +1525,6 @@ class MemoryEngine:

                 # Apply token limit and build chunks_dict in the order of chunk_ids_ordered
                 chunks_dict = {}
-                total_chunk_tokens = 0
                 encoding = _get_tiktoken_encoding()

                 for chunk_id in chunk_ids_ordered:
@@ -1525,10 +1564,17 @@ class MemoryEngine:
             trace = tracer.finalize(top_results_dicts)
             trace_dict = trace.to_dict() if trace else None

+            # Log final recall stats
+            total_time = time.time() - recall_start
+            num_chunks = len(chunks_dict) if chunks_dict else 0
+            num_entities = len(entities_dict) if entities_dict else 0
+            log_buffer.append(f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s")
+            logger.info("\n" + "\n".join(log_buffer))
+
             return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)

         except Exception as e:
-            log_buffer.append(f"[SEARCH {search_id}] ERROR after {time.time() - search_start:.3f}s: {str(e)}")
+            log_buffer.append(f"[RECALL {recall_id}] ERROR after {time.time() - recall_start:.3f}s: {str(e)}")
             logger.error("\n" + "\n".join(log_buffer))
             raise Exception(f"Failed to search memories: {str(e)}")

@@ -2502,14 +2548,14 @@ Guidelines:
     async def update_bank_disposition(
         self,
         bank_id: str,
-        disposition: Dict[str, float]
+        disposition: Dict[str, int]
     ) -> None:
         """
         Update bank disposition traits.

         Args:
             bank_id: bank IDentifier
-            disposition: Dict with Big Five traits + bias_strength (all 0-1)
+            disposition: Dict with skepticism, literalism, empathy (all 1-5)
         """
         pool = await self._get_pool()
         await bank_utils.update_bank_disposition(pool, bank_id, disposition)
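A sketch of the new trait scale (the three trait names and the 1-5 range come from the docstring above; the values themselves are illustrative):

    await engine.update_bank_disposition(
        bank_id="my-bank",
        disposition={"skepticism": 4, "literalism": 2, "empathy": 3},  # integers 1-5, no longer 0-1 floats
    )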
@@ -2584,7 +2630,13 @@ Guidelines:
         if self._llm_config is None:
             raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")

+        reflect_start = time.time()
+        reflect_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
+        log_buffer = []
+        log_buffer.append(f"[REFLECT {reflect_id}] Query: '{query[:50]}...'")
+
         # Steps 1-3: Run multi-fact-type search (12-way retrieval: 4 methods × 3 fact types)
+        recall_start = time.time()
         search_result = await self.recall_async(
             bank_id=bank_id,
             query=query,
@@ -2594,24 +2646,22 @@ Guidelines:
             fact_type=['experience', 'world', 'opinion'],
             include_entities=True
         )
+        recall_time = time.time() - recall_start

         all_results = search_result.results
-        logger.info(f"[THINK] Search returned {len(all_results)} results")

         # Split results by fact type for structured response
         agent_results = [r for r in all_results if r.fact_type == 'experience']
         world_results = [r for r in all_results if r.fact_type == 'world']
         opinion_results = [r for r in all_results if r.fact_type == 'opinion']

-        logger.info(f"[THINK] Split results - agent: {len(agent_results)}, world: {len(world_results)}, opinion: {len(opinion_results)}")
+        log_buffer.append(f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s")

         # Format facts for LLM
         agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
         world_facts_text = think_utils.format_facts_for_prompt(world_results)
         opinion_facts_text = think_utils.format_facts_for_prompt(opinion_results)

-        logger.info(f"[THINK] Formatted facts - agent: {len(agent_facts_text)} chars, world: {len(world_facts_text)} chars, opinion: {len(opinion_facts_text)} chars")
-
         # Get bank profile (name, disposition + background)
         profile = await self.get_bank_profile(bank_id)
         name = profile["name"]
@@ -2630,10 +2680,11 @@ Guidelines:
             context=context,
         )

-        logger.info(f"[THINK] Full prompt length: {len(prompt)} chars")
+        log_buffer.append(f"[REFLECT {reflect_id}] Prompt: {len(prompt)} chars")

         system_message = think_utils.get_system_message(disposition)

+        llm_start = time.time()
         answer_text = await self._llm_config.call(
             messages=[
                 {"role": "system", "content": system_message},
@@ -2643,6 +2694,7 @@ Guidelines:
             ],
             temperature=0.9,
             max_tokens=1000
         )
+        llm_time = time.time() - llm_start

         answer_text = answer_text.strip()
@@ -2654,6 +2706,10 @@ Guidelines:
             'query': query
         })

+        total_time = time.time() - reflect_start
+        log_buffer.append(f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s")
+        logger.info("\n" + "\n".join(log_buffer))
+
         # Return response with facts split by type
         return ReflectResult(
             text=answer_text,
@@ -2702,7 +2758,7 @@ Guidelines:
             )

         except Exception as e:
-            logger.warning(f"[THINK] Failed to extract/store opinions: {str(e)}")
+            logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")

     async def get_entity_observations(
         self,
@@ -2828,7 +2884,8 @@ Guidelines:
         bank_id: str,
         entity_id: str,
         entity_name: str,
-        version: str | None = None
+        version: str | None = None,
+        conn=None
     ) -> List[str]:
         """
         Regenerate observations for an entity by:
@@ -2843,43 +2900,58 @@ Guidelines:
             entity_id: Entity UUID
             entity_name: Canonical name of the entity
             version: Entity's last_seen timestamp when task was created (for deduplication)
+            conn: Optional database connection (for transactional atomicity with caller)

         Returns:
             List of created observation IDs
         """
         pool = await self._get_pool()
+        entity_uuid = uuid.UUID(entity_id)

-        # Step 1: Check version for deduplication
-        if version:
-            async with acquire_with_retry(pool) as conn:
-                current_last_seen = await conn.fetchval(
-                    """
-                    SELECT last_seen
-                    FROM entities
-                    WHERE id = $1 AND bank_id = $2
-                    """,
-                    uuid.UUID(entity_id), bank_id
-                )
+        # Helper to run a query with provided conn or acquire one
+        async def fetch_with_conn(query, *args):
+            if conn is not None:
+                return await conn.fetch(query, *args)
+            else:
+                async with acquire_with_retry(pool) as acquired_conn:
+                    return await acquired_conn.fetch(query, *args)

-            if current_last_seen and current_last_seen.isoformat() != version:
-                return []
+        async def fetchval_with_conn(query, *args):
+            if conn is not None:
+                return await conn.fetchval(query, *args)
+            else:
+                async with acquire_with_retry(pool) as acquired_conn:
+                    return await acquired_conn.fetchval(query, *args)

-        # Step 2: Get all facts mentioning this entity (exclude observations themselves)
-        async with acquire_with_retry(pool) as conn:
-            rows = await conn.fetch(
+        # Step 1: Check version for deduplication
+        if version:
+            current_last_seen = await fetchval_with_conn(
                 """
-                SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
-                FROM memory_units mu
-                JOIN unit_entities ue ON mu.id = ue.unit_id
-                WHERE mu.bank_id = $1
-                AND ue.entity_id = $2
-                AND mu.fact_type IN ('world', 'experience')
-                ORDER BY mu.occurred_start DESC
-                LIMIT 50
+                SELECT last_seen
+                FROM entities
+                WHERE id = $1 AND bank_id = $2
                 """,
-                bank_id, uuid.UUID(entity_id)
+                entity_uuid, bank_id
             )

+            if current_last_seen and current_last_seen.isoformat() != version:
+                return []
+
+        # Step 2: Get all facts mentioning this entity (exclude observations themselves)
+        rows = await fetch_with_conn(
+            """
+            SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
+            FROM memory_units mu
+            JOIN unit_entities ue ON mu.id = ue.unit_id
+            WHERE mu.bank_id = $1
+            AND ue.entity_id = $2
+            AND mu.fact_type IN ('world', 'experience')
+            ORDER BY mu.occurred_start DESC
+            LIMIT 50
+            """,
+            bank_id, entity_uuid
+        )
+
         if not rows:
             return []

@@ -2905,120 +2977,173 @@ Guidelines:
         if not observations:
             return []

-        # Step 4: Delete old observations and insert new ones in a transaction
-        async with acquire_with_retry(pool) as conn:
-            async with conn.transaction():
-                # Delete old observations for this entity
-                await conn.execute(
+        # Step 4: Delete old observations and insert new ones
+        # If conn provided, we're already in a transaction - don't start another
+        # If conn is None, acquire one and start a transaction
+        async def do_db_operations(db_conn):
+            # Delete old observations for this entity
+            await db_conn.execute(
+                """
+                DELETE FROM memory_units
+                WHERE id IN (
+                    SELECT mu.id
+                    FROM memory_units mu
+                    JOIN unit_entities ue ON mu.id = ue.unit_id
+                    WHERE mu.bank_id = $1
+                    AND mu.fact_type = 'observation'
+                    AND ue.entity_id = $2
+                )
+                """,
+                bank_id, entity_uuid
+            )
+
+            # Generate embeddings for new observations
+            embeddings = await embedding_utils.generate_embeddings_batch(
+                self.embeddings, observations
+            )
+
+            # Insert new observations
+            current_time = utcnow()
+            created_ids = []
+
+            for obs_text, embedding in zip(observations, embeddings):
+                result = await db_conn.fetchrow(
                     """
-                    DELETE FROM memory_units
-                    WHERE id IN (
-                        SELECT mu.id
-                        FROM memory_units mu
-                        JOIN unit_entities ue ON mu.id = ue.unit_id
-                        WHERE mu.bank_id = $1
-                        AND mu.fact_type = 'observation'
-                        AND ue.entity_id = $2
+                    INSERT INTO memory_units (
+                        bank_id, text, embedding, context, event_date,
+                        occurred_start, occurred_end, mentioned_at,
+                        fact_type, access_count
                     )
+                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
+                    RETURNING id
                     """,
-                    bank_id, uuid.UUID(entity_id)
+                    bank_id,
+                    obs_text,
+                    str(embedding),
+                    f"observation about {entity_name}",
+                    current_time,
+                    current_time,
+                    current_time,
+                    current_time
                 )
+                obs_id = str(result['id'])
+                created_ids.append(obs_id)

-                # Generate embeddings for new observations
-                embeddings = await embedding_utils.generate_embeddings_batch(
-                    self.embeddings, observations
+                # Link observation to entity
+                await db_conn.execute(
+                    """
+                    INSERT INTO unit_entities (unit_id, entity_id)
+                    VALUES ($1, $2)
+                    """,
+                    uuid.UUID(obs_id), entity_uuid
                )

-                # Insert new observations
-                current_time = utcnow()
-                created_ids = []
+            return created_ids

-                for obs_text, embedding in zip(observations, embeddings):
-                    result = await conn.fetchrow(
-                        """
-                        INSERT INTO memory_units (
-                            bank_id, text, embedding, context, event_date,
-                            occurred_start, occurred_end, mentioned_at,
-                            fact_type, access_count
-                        )
-                        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
-                        RETURNING id
-                        """,
-                        bank_id,
-                        obs_text,
-                        str(embedding),
-                        f"observation about {entity_name}",
-                        current_time,
-                        current_time,
-                        current_time,
-                        current_time
-                    )
-                    obs_id = str(result['id'])
-                    created_ids.append(obs_id)
-
-                    # Link observation to entity
-                    await conn.execute(
-                        """
-                        INSERT INTO unit_entities (unit_id, entity_id)
-                        VALUES ($1, $2)
-                        """,
-                        uuid.UUID(obs_id), uuid.UUID(entity_id)
-                    )
-
-                # Single consolidated log line
-                logger.info(f"[OBSERVATIONS] {entity_name}: {len(facts)} facts -> {len(created_ids)} observations")
-                return created_ids
+        if conn is not None:
+            # Use provided connection (already in a transaction)
+            return await do_db_operations(conn)
+        else:
+            # Acquire connection and start our own transaction
+            async with acquire_with_retry(pool) as acquired_conn:
+                async with acquired_conn.transaction():
+                    return await do_db_operations(acquired_conn)

     async def _regenerate_observations_sync(
         self,
         bank_id: str,
         entity_ids: List[str],
-        min_facts: int = 5
+        min_facts: int = 5,
+        conn=None
     ) -> None:
         """
         Regenerate observations for entities synchronously (called during retain).

+        Processes entities in PARALLEL for faster execution.
+
         Args:
             bank_id: Bank identifier
             entity_ids: List of entity IDs to process
             min_facts: Minimum facts required to regenerate observations
+            conn: Optional database connection (for transactional atomicity)
         """
         if not bank_id or not entity_ids:
             return

-        pool = await self._get_pool()
-        async with pool.acquire() as conn:
-            for entity_id in entity_ids:
-                try:
-                    entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
-
-                    # Check if entity exists
-                    entity_exists = await conn.fetchrow(
-                        "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
-                        entity_uuid, bank_id
-                    )
+        # Convert to UUIDs
+        entity_uuids = [uuid.UUID(eid) if isinstance(eid, str) else eid for eid in entity_ids]

-                    if not entity_exists:
-                        logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
-                        continue
+        # Use provided connection or acquire a new one
+        if conn is not None:
+            # Use the provided connection (transactional with caller)
+            entity_rows = await conn.fetch(
+                """
+                SELECT id, canonical_name FROM entities
+                WHERE id = ANY($1) AND bank_id = $2
+                """,
+                entity_uuids, bank_id
+            )
+            entity_names = {row['id']: row['canonical_name'] for row in entity_rows}

-                    entity_name = entity_exists['canonical_name']
+            fact_counts = await conn.fetch(
+                """
+                SELECT ue.entity_id, COUNT(*) as cnt
+                FROM unit_entities ue
+                JOIN memory_units mu ON ue.unit_id = mu.id
+                WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
+                GROUP BY ue.entity_id
+                """,
+                entity_uuids, bank_id
+            )
+            entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+        else:
+            # Acquire a new connection (standalone call)
+            pool = await self._get_pool()
+            async with pool.acquire() as acquired_conn:
+                entity_rows = await acquired_conn.fetch(
+                    """
+                    SELECT id, canonical_name FROM entities
+                    WHERE id = ANY($1) AND bank_id = $2
+                    """,
+                    entity_uuids, bank_id
+                )
+                entity_names = {row['id']: row['canonical_name'] for row in entity_rows}

-                    # Count facts linked to this entity
-                    fact_count = await conn.fetchval(
-                        "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1",
-                        entity_uuid
-                    ) or 0
+                fact_counts = await acquired_conn.fetch(
+                    """
+                    SELECT ue.entity_id, COUNT(*) as cnt
+                    FROM unit_entities ue
+                    JOIN memory_units mu ON ue.unit_id = mu.id
+                    WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
+                    GROUP BY ue.entity_id
+                    """,
+                    entity_uuids, bank_id
+                )
+                entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+
+        # Filter entities that meet the threshold
+        entities_to_process = []
+        for entity_id in entity_ids:
+            entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+            if entity_uuid not in entity_names:
+                continue
+            fact_count = entity_fact_counts.get(entity_uuid, 0)
+            if fact_count >= min_facts:
+                entities_to_process.append((entity_id, entity_names[entity_uuid]))
+
+        if not entities_to_process:
+            return

-                    # Only regenerate if entity has enough facts
-                    if fact_count >= min_facts:
-                        await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
-                    else:
-                        logger.debug(f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)")
+        # Process all entities in PARALLEL (LLM calls are the bottleneck)
+        async def process_entity(entity_id: str, entity_name: str):
+            try:
+                await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None, conn=conn)
+            except Exception as e:
+                logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")

-                except Exception as e:
-                    logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
-                    continue
+        await asyncio.gather(*[
+            process_entity(eid, name) for eid, name in entities_to_process
+        ])

     async def _handle_regenerate_observations(self, task_dict: Dict[str, Any]):
         """