hindsight-api 0.0.21__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. hindsight_api/__init__.py +10 -2
  2. hindsight_api/alembic/README +1 -0
  3. hindsight_api/alembic/env.py +146 -0
  4. hindsight_api/alembic/script.py.mako +28 -0
  5. hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +274 -0
  6. hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +70 -0
  7. hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +39 -0
  8. hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +48 -0
  9. hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +62 -0
  10. hindsight_api/alembic/versions/rename_personality_to_disposition.py +65 -0
  11. hindsight_api/api/__init__.py +2 -4
  12. hindsight_api/api/http.py +112 -164
  13. hindsight_api/api/mcp.py +2 -1
  14. hindsight_api/config.py +154 -0
  15. hindsight_api/engine/__init__.py +7 -2
  16. hindsight_api/engine/cross_encoder.py +225 -16
  17. hindsight_api/engine/embeddings.py +198 -19
  18. hindsight_api/engine/entity_resolver.py +56 -29
  19. hindsight_api/engine/llm_wrapper.py +147 -106
  20. hindsight_api/engine/memory_engine.py +337 -192
  21. hindsight_api/engine/response_models.py +15 -17
  22. hindsight_api/engine/retain/bank_utils.py +25 -35
  23. hindsight_api/engine/retain/entity_processing.py +5 -5
  24. hindsight_api/engine/retain/fact_extraction.py +86 -24
  25. hindsight_api/engine/retain/fact_storage.py +1 -1
  26. hindsight_api/engine/retain/link_creation.py +12 -6
  27. hindsight_api/engine/retain/link_utils.py +50 -56
  28. hindsight_api/engine/retain/observation_regeneration.py +264 -0
  29. hindsight_api/engine/retain/orchestrator.py +31 -44
  30. hindsight_api/engine/retain/types.py +14 -0
  31. hindsight_api/engine/search/reranking.py +6 -10
  32. hindsight_api/engine/search/retrieval.py +2 -2
  33. hindsight_api/engine/search/think_utils.py +59 -30
  34. hindsight_api/engine/search/tracer.py +1 -1
  35. hindsight_api/main.py +201 -0
  36. hindsight_api/migrations.py +61 -39
  37. hindsight_api/models.py +1 -2
  38. hindsight_api/pg0.py +17 -36
  39. hindsight_api/server.py +43 -0
  40. {hindsight_api-0.0.21.dist-info → hindsight_api-0.1.1.dist-info}/METADATA +2 -3
  41. hindsight_api-0.1.1.dist-info/RECORD +60 -0
  42. hindsight_api-0.1.1.dist-info/entry_points.txt +2 -0
  43. hindsight_api/cli.py +0 -128
  44. hindsight_api/web/__init__.py +0 -12
  45. hindsight_api/web/server.py +0 -109
  46. hindsight_api-0.0.21.dist-info/RECORD +0 -50
  47. hindsight_api-0.0.21.dist-info/entry_points.txt +0 -2
  48. {hindsight_api-0.0.21.dist-info → hindsight_api-0.1.1.dist-info}/WHEEL +0 -0
@@ -11,17 +11,20 @@ This implements a sophisticated memory architecture that combines:
  import json
  import os
  from datetime import datetime, timedelta, timezone
- from typing import Any, Dict, List, Optional, Tuple, Union, TypedDict
+ from typing import Any, Dict, List, Optional, Tuple, Union, TypedDict, TYPE_CHECKING
  import asyncpg
  import asyncio
- from .embeddings import Embeddings, SentenceTransformersEmbeddings
- from .cross_encoder import CrossEncoderModel
+ from .embeddings import Embeddings, create_embeddings_from_env
+ from .cross_encoder import CrossEncoderModel, create_cross_encoder_from_env
  import time
  import numpy as np
  import uuid
  import logging
  from pydantic import BaseModel, Field

+ if TYPE_CHECKING:
+ from ..config import HindsightConfig
+

  class RetainContentDict(TypedDict, total=False):
  """Type definition for content items in retain_batch_async.
@@ -48,7 +51,7 @@ from .entity_resolver import EntityResolver
  from .retain import embedding_utils, bank_utils
  from .search import think_utils, observation_utils
  from .llm_wrapper import LLMConfig
- from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation
+ from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation, VALID_RECALL_FACT_TYPES
  from .task_backend import TaskBackend, AsyncIOQueueBackend
  from .search.reranking import CrossEncoderReranker
  from ..pg0 import EmbeddedPostgres
@@ -99,10 +102,10 @@ class MemoryEngine:

  def __init__(
  self,
- db_url: str,
- memory_llm_provider: str,
- memory_llm_api_key: str,
- memory_llm_model: str,
+ db_url: Optional[str] = None,
+ memory_llm_provider: Optional[str] = None,
+ memory_llm_api_key: Optional[str] = None,
+ memory_llm_model: Optional[str] = None,
  memory_llm_base_url: Optional[str] = None,
  embeddings: Optional[Embeddings] = None,
  cross_encoder: Optional[CrossEncoderModel] = None,
@@ -110,35 +113,67 @@ class MemoryEngine:
  pool_min_size: int = 5,
  pool_max_size: int = 100,
  task_backend: Optional[TaskBackend] = None,
+ run_migrations: bool = True,
  ):
  """
  Initialize the temporal + semantic memory system.

+ All parameters are optional and will be read from environment variables if not provided.
+ See hindsight_api.config for environment variable names and defaults.
+
  Args:
- db_url: PostgreSQL connection URL (postgresql://user:pass@host:port/dbname). Required.
- memory_llm_provider: LLM provider for memory operations: "openai", "groq", or "ollama". Required.
- memory_llm_api_key: API key for the LLM provider. Required.
- memory_llm_model: Model name to use for all memory operations (put/think/opinions). Required.
- memory_llm_base_url: Base URL for the LLM API. Optional. Defaults based on provider:
- - groq: https://api.groq.com/openai/v1
- - ollama: http://localhost:11434/v1
- embeddings: Embeddings implementation to use. If not provided, uses SentenceTransformersEmbeddings
- cross_encoder: Cross-encoder model for reranking. If not provided, uses default when cross-encoder reranker is selected
- query_analyzer: Query analyzer implementation to use. If not provided, uses TransformerQueryAnalyzer
+ db_url: PostgreSQL connection URL. Defaults to HINDSIGHT_API_DATABASE_URL env var or "pg0".
+ Also supports pg0 URLs: "pg0" or "pg0://instance-name" or "pg0://instance-name:port"
+ memory_llm_provider: LLM provider. Defaults to HINDSIGHT_API_LLM_PROVIDER env var or "groq".
+ memory_llm_api_key: API key for the LLM provider. Defaults to HINDSIGHT_API_LLM_API_KEY env var.
+ memory_llm_model: Model name. Defaults to HINDSIGHT_API_LLM_MODEL env var.
+ memory_llm_base_url: Base URL for the LLM API. Defaults based on provider.
+ embeddings: Embeddings implementation. If not provided, created from env vars.
+ cross_encoder: Cross-encoder model. If not provided, created from env vars.
+ query_analyzer: Query analyzer implementation. If not provided, uses DateparserQueryAnalyzer.
  pool_min_size: Minimum number of connections in the pool (default: 5)
  pool_max_size: Maximum number of connections in the pool (default: 100)
- Increase for parallel think/search operations (e.g., 200-300 for 100+ parallel thinks)
- task_backend: Custom task backend for async task execution. If not provided, uses AsyncIOQueueBackend
- """
- if not db_url:
- raise ValueError("Database url is required")
+ task_backend: Custom task backend. If not provided, uses AsyncIOQueueBackend.
+ run_migrations: Whether to run database migrations during initialize(). Default: True
+ """
+ # Load config from environment for any missing parameters
+ from ..config import get_config
+ config = get_config()
+
+ # Apply defaults from config
+ db_url = db_url or config.database_url
+ memory_llm_provider = memory_llm_provider or config.llm_provider
+ memory_llm_api_key = memory_llm_api_key or config.llm_api_key
+ memory_llm_model = memory_llm_model or config.llm_model
+ memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
  # Track pg0 instance (if used)
  self._pg0: Optional[EmbeddedPostgres] = None
+ self._pg0_instance_name: Optional[str] = None

  # Initialize PostgreSQL connection URL
  # The actual URL will be set during initialize() after starting the server
- self._use_pg0 = db_url == "pg0"
- self.db_url = db_url if not self._use_pg0 else None
+ # Supports: "pg0" (default instance), "pg0://instance-name" (named instance), or regular postgresql:// URL
+ if db_url == "pg0":
+ self._use_pg0 = True
+ self._pg0_instance_name = "hindsight"
+ self._pg0_port = None # Use default port
+ self.db_url = None
+ elif db_url.startswith("pg0://"):
+ self._use_pg0 = True
+ # Parse instance name and optional port: pg0://instance-name or pg0://instance-name:port
+ url_part = db_url[6:] # Remove "pg0://"
+ if ":" in url_part:
+ self._pg0_instance_name, port_str = url_part.rsplit(":", 1)
+ self._pg0_port = int(port_str)
+ else:
+ self._pg0_instance_name = url_part or "hindsight"
+ self._pg0_port = None # Use default port
+ self.db_url = None
+ else:
+ self._use_pg0 = False
+ self._pg0_instance_name = None
+ self._pg0_port = None
+ self.db_url = db_url


  # Set default base URL if not provided
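The constructor now accepts three pg0 forms ("pg0", "pg0://instance-name", "pg0://instance-name:port") alongside a regular postgresql:// URL. A minimal standalone sketch of the branch logic above; parse_pg0_url is a hypothetical helper used only for illustration:

from typing import Optional, Tuple

def parse_pg0_url(db_url: str) -> Tuple[bool, Optional[str], Optional[int]]:
    """Return (use_pg0, instance_name, port), mirroring the branches above."""
    if db_url == "pg0":
        return True, "hindsight", None            # default embedded instance
    if db_url.startswith("pg0://"):
        part = db_url[6:]                          # strip "pg0://"
        if ":" in part:
            name, port_str = part.rsplit(":", 1)
            return True, name, int(port_str)       # named instance with explicit port
        return True, part or "hindsight", None     # named instance, default port
    return False, None, None                       # regular postgresql:// URL

assert parse_pg0_url("pg0") == (True, "hindsight", None)
assert parse_pg0_url("pg0://analytics:5544") == (True, "analytics", 5544)
assert parse_pg0_url("postgresql://user:pass@host:5432/db") == (False, None, None)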
@@ -155,15 +190,16 @@ class MemoryEngine:
  self._initialized = False
  self._pool_min_size = pool_min_size
  self._pool_max_size = pool_max_size
+ self._run_migrations = run_migrations

  # Initialize entity resolver (will be created in initialize())
  self.entity_resolver = None

- # Initialize embeddings
+ # Initialize embeddings (from env vars if not provided)
  if embeddings is not None:
  self.embeddings = embeddings
  else:
- self.embeddings = SentenceTransformersEmbeddings("BAAI/bge-small-en-v1.5")
+ self.embeddings = create_embeddings_from_env()

  # Initialize query analyzer
  if query_analyzer is not None:
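Together with the env-driven defaults above, the engine can now be constructed with no arguments. A usage sketch, assuming the environment variables named in the docstring are set (the import path follows the file list at the top of this diff; the values below are placeholders):

import os
from hindsight_api.engine.memory_engine import MemoryEngine

os.environ.setdefault("HINDSIGHT_API_DATABASE_URL", "pg0")       # embedded Postgres
os.environ.setdefault("HINDSIGHT_API_LLM_PROVIDER", "groq")
os.environ.setdefault("HINDSIGHT_API_LLM_API_KEY", "placeholder-key")
os.environ.setdefault("HINDSIGHT_API_LLM_MODEL", "placeholder-model")

engine = MemoryEngine()                                   # every parameter falls back to env config
# engine = MemoryEngine(db_url="pg0://my-instance:5544")  # or override selectively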
@@ -294,7 +330,7 @@ class MemoryEngine:
  await self._handle_reinforce_opinion(task_dict)
  elif task_type == 'form_opinion':
  await self._handle_form_opinion(task_dict)
- elif task_type == 'batch_put':
+ elif task_type == 'batch_retain':
  await self._handle_batch_retain(task_dict)
  elif task_type == 'regenerate_observations':
  await self._handle_regenerate_observations(task_dict)
@@ -378,35 +414,58 @@ class MemoryEngine:
  async def start_pg0():
  """Start pg0 if configured."""
  if self._use_pg0:
- self._pg0 = EmbeddedPostgres()
- self.db_url = await self._pg0.ensure_running()
-
- def load_embeddings():
- """Load embedding model (CPU-bound)."""
- self.embeddings.load()
-
- def load_cross_encoder():
- """Load cross-encoder model (CPU-bound)."""
- self._cross_encoder_reranker.cross_encoder.load()
-
- def load_query_analyzer():
- """Load query analyzer model (CPU-bound)."""
- self.query_analyzer.load()
-
- # Run pg0 and all model loads in parallel
- # pg0 is async (IO-bound), models are sync (CPU-bound in thread pool)
- # Use 3 workers to load all models concurrently
- with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
- # Start all tasks
- pg0_task = asyncio.create_task(start_pg0())
- embeddings_future = loop.run_in_executor(executor, load_embeddings)
- cross_encoder_future = loop.run_in_executor(executor, load_cross_encoder)
- query_analyzer_future = loop.run_in_executor(executor, load_query_analyzer)
-
- # Wait for all to complete
- await asyncio.gather(
- pg0_task, embeddings_future, cross_encoder_future, query_analyzer_future
- )
+ kwargs = {"name": self._pg0_instance_name}
+ if self._pg0_port is not None:
+ kwargs["port"] = self._pg0_port
+ pg0 = EmbeddedPostgres(**kwargs)
+ # Check if pg0 is already running before we start it
+ was_already_running = await pg0.is_running()
+ self.db_url = await pg0.ensure_running()
+ # Only track pg0 (to stop later) if WE started it
+ if not was_already_running:
+ self._pg0 = pg0
+
+ async def init_embeddings():
+ """Initialize embedding model."""
+ # For local providers, run in thread pool to avoid blocking event loop
+ if self.embeddings.provider_name == "local":
+ await loop.run_in_executor(
+ None,
+ lambda: asyncio.run(self.embeddings.initialize())
+ )
+ else:
+ await self.embeddings.initialize()
+
+ async def init_cross_encoder():
+ """Initialize cross-encoder model."""
+ cross_encoder = self._cross_encoder_reranker.cross_encoder
+ # For local providers, run in thread pool to avoid blocking event loop
+ if cross_encoder.provider_name == "local":
+ await loop.run_in_executor(
+ None,
+ lambda: asyncio.run(cross_encoder.initialize())
+ )
+ else:
+ await cross_encoder.initialize()
+
+ async def init_query_analyzer():
+ """Initialize query analyzer model."""
+ # Query analyzer load is sync and CPU-bound
+ await loop.run_in_executor(None, self.query_analyzer.load)
+
+ # Run pg0 and all model initializations in parallel
+ await asyncio.gather(
+ start_pg0(),
+ init_embeddings(),
+ init_cross_encoder(),
+ init_query_analyzer(),
+ )
+
+ # Run database migrations if enabled
+ if self._run_migrations:
+ from ..migrations import run_migrations
+ logger.info("Running database migrations...")
+ run_migrations(self.db_url)

  logger.info(f"Connecting to PostgreSQL at {self.db_url}")

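Startup now mixes async work (starting pg0, remote providers) with CPU-bound local model loads under a single asyncio.gather(). A generic sketch of that pattern with stand-in names (start_db and load_model are placeholders, not hindsight_api functions):

import asyncio

async def initialize_all() -> None:
    loop = asyncio.get_running_loop()

    async def start_db() -> None:
        await asyncio.sleep(0.1)                  # stands in for starting pg0 (I/O-bound)

    def load_model() -> None:
        sum(i * i for i in range(1_000_000))      # stands in for a local model load (CPU-bound)

    # Async tasks and executor-wrapped sync loads all finish before gather returns.
    await asyncio.gather(
        start_db(),
        loop.run_in_executor(None, load_model),   # default thread pool
        loop.run_in_executor(None, load_model),
    )

asyncio.run(initialize_all())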
@@ -869,7 +928,6 @@ class MemoryEngine:
  task_backend=self._task_backend,
  format_date_fn=self._format_readable_date,
  duplicate_checker_fn=self._find_duplicate_facts_batch,
- regenerate_observations_fn=self._regenerate_observations_sync,
  bank_id=bank_id,
  contents_dicts=contents,
  document_id=document_id,
@@ -955,11 +1013,19 @@ class MemoryEngine:
  - entities: Optional dict of entity states (if include_entities=True)
  - chunks: Optional dict of chunks (if include_chunks=True)
  """
+ # Validate fact types early
+ invalid_types = set(fact_type) - VALID_RECALL_FACT_TYPES
+ if invalid_types:
+ raise ValueError(
+ f"Invalid fact type(s): {', '.join(sorted(invalid_types))}. "
+ f"Must be one of: {', '.join(sorted(VALID_RECALL_FACT_TYPES))}"
+ )
+
  # Map budget enum to thinking_budget number
  budget_mapping = {
  Budget.LOW: 100,
  Budget.MID: 300,
- Budget.HIGH: 600
+ Budget.HIGH: 1000
  }
  thinking_budget = budget_mapping[budget]

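The early validation and the budget mapping can be exercised on their own. A sketch in which both VALID_RECALL_FACT_TYPES and the Budget enum are assumed shapes (the real definitions live in the package; the three fact types match those passed to recall by reflect later in this diff):

from enum import Enum

VALID_RECALL_FACT_TYPES = {"experience", "world", "opinion"}   # assumed contents

class Budget(Enum):                                            # assumed shape
    LOW = "low"
    MID = "mid"
    HIGH = "high"

def resolve_recall_params(fact_type: list, budget: Budget) -> int:
    invalid_types = set(fact_type) - VALID_RECALL_FACT_TYPES
    if invalid_types:
        raise ValueError(
            f"Invalid fact type(s): {', '.join(sorted(invalid_types))}. "
            f"Must be one of: {', '.join(sorted(VALID_RECALL_FACT_TYPES))}"
        )
    return {Budget.LOW: 100, Budget.MID: 300, Budget.HIGH: 1000}[budget]

print(resolve_recall_params(["world", "opinion"], Budget.HIGH))  # 1000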
@@ -1040,12 +1106,12 @@ class MemoryEngine:
  tracer.start()

  pool = await self._get_pool()
- search_start = time.time()
+ recall_start = time.time()

  # Buffer logs for clean output in concurrent scenarios
- search_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
+ recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
  log_buffer = []
- log_buffer.append(f"[SEARCH {search_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})")
+ log_buffer.append(f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})")

  try:
  # Step 1: Generate query embedding (for semantic search)
@@ -1088,7 +1154,7 @@ class MemoryEngine:
  for idx, (ft_semantic, ft_bm25, ft_graph, ft_temporal, ft_timings, ft_temporal_constraint) in enumerate(all_retrievals):
  # Log fact types in this retrieval batch
  ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
- logger.debug(f"[SEARCH {search_id}] Fact type '{ft_name}': semantic={len(ft_semantic)}, bm25={len(ft_bm25)}, graph={len(ft_graph)}, temporal={len(ft_temporal) if ft_temporal else 0}")
+ logger.debug(f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(ft_semantic)}, bm25={len(ft_bm25)}, graph={len(ft_graph)}, temporal={len(ft_temporal) if ft_temporal else 0}")

  semantic_results.extend(ft_semantic)
  bm25_results.extend(ft_bm25)
@@ -1209,7 +1275,6 @@ class MemoryEngine:
  # Step 4: Rerank using cross-encoder (MergedCandidate -> ScoredResult)
  step_start = time.time()
  reranker_instance = self._cross_encoder_reranker
- log_buffer.append(f" [4] Using cross-encoder reranker")

  # Rerank using cross-encoder
  scored_results = reranker_instance.rerank(query, merged_candidates)
@@ -1334,12 +1399,7 @@ class MemoryEngine:
  ft = sr.retrieval.fact_type
  fact_type_counts[ft] = fact_type_counts.get(ft, 0) + 1

- total_time = time.time() - search_start
  fact_type_summary = ", ".join([f"{ft}={count}" for ft, count in sorted(fact_type_counts.items())])
- log_buffer.append(f"[SEARCH {search_id}] Complete: {len(top_scored)} results ({fact_type_summary}) ({total_tokens} tokens) in {total_time:.3f}s")
-
- # Log all buffered logs at once
- logger.info("\n" + "\n".join(log_buffer))

  # Convert ScoredResult to dicts with ISO datetime strings
  top_results_dicts = []
@@ -1401,11 +1461,12 @@ class MemoryEngine:
  mentioned_at=result_dict.get("mentioned_at"),
  document_id=result_dict.get("document_id"),
  chunk_id=result_dict.get("chunk_id"),
- activation=result_dict.get("weight") # Use final weight as activation
  ))

  # Fetch entity observations if requested
  entities_dict = None
+ total_entity_tokens = 0
+ total_chunk_tokens = 0
  if include_entities and fact_entity_map:
  # Collect unique entities in order of fact relevance (preserving order from top_scored)
  # Use a list to maintain order, but track seen entities to avoid duplicates
@@ -1425,7 +1486,6 @@ class MemoryEngine:

  # Fetch observations for each entity (respect token budget, in order)
  entities_dict = {}
- total_entity_tokens = 0
  encoding = _get_tiktoken_encoding()

  for entity_id, entity_name in entities_ordered:
@@ -1485,7 +1545,6 @@ class MemoryEngine:

  # Apply token limit and build chunks_dict in the order of chunk_ids_ordered
  chunks_dict = {}
- total_chunk_tokens = 0
  encoding = _get_tiktoken_encoding()

  for chunk_id in chunk_ids_ordered:
@@ -1525,10 +1584,17 @@ class MemoryEngine:
  trace = tracer.finalize(top_results_dicts)
  trace_dict = trace.to_dict() if trace else None

+ # Log final recall stats
+ total_time = time.time() - recall_start
+ num_chunks = len(chunks_dict) if chunks_dict else 0
+ num_entities = len(entities_dict) if entities_dict else 0
+ log_buffer.append(f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s")
+ logger.info("\n" + "\n".join(log_buffer))
+
  return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)

  except Exception as e:
- log_buffer.append(f"[SEARCH {search_id}] ERROR after {time.time() - search_start:.3f}s: {str(e)}")
+ log_buffer.append(f"[RECALL {recall_id}] ERROR after {time.time() - recall_start:.3f}s: {str(e)}")
  logger.error("\n" + "\n".join(log_buffer))
  raise Exception(f"Failed to search memories: {str(e)}")

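The [SEARCH ...] prefix becomes [RECALL ...], and the completion line now fires after entity and chunk accounting so their token counts can be reported. A reduced sketch of the buffered-logging pattern itself (run_recall is illustrative, not part of the API):

import logging
import time

logger = logging.getLogger("hindsight.recall")

def run_recall(bank_id: str, query: str) -> None:
    recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
    recall_start = time.time()
    log_buffer = [f"[RECALL {recall_id}] Query: '{query[:50]}...'"]
    try:
        # ... retrieval, reranking, and token budgeting would run here ...
        log_buffer.append(f"[RECALL {recall_id}] Complete in {time.time() - recall_start:.3f}s")
        logger.info("\n" + "\n".join(log_buffer))   # one call, so concurrent recalls don't interleave
    except Exception as e:
        log_buffer.append(f"[RECALL {recall_id}] ERROR after {time.time() - recall_start:.3f}s: {e}")
        logger.error("\n" + "\n".join(log_buffer))
        raise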
@@ -2502,14 +2568,14 @@ Guidelines:
  async def update_bank_disposition(
  self,
  bank_id: str,
- disposition: Dict[str, float]
+ disposition: Dict[str, int]
  ) -> None:
  """
  Update bank disposition traits.

  Args:
  bank_id: bank IDentifier
- disposition: Dict with Big Five traits + bias_strength (all 0-1)
+ disposition: Dict with skepticism, literalism, empathy (all 1-5)
  """
  pool = await self._get_pool()
  await bank_utils.update_bank_disposition(pool, bank_id, disposition)
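Per the updated docstring, callers now pass integer traits on a 1-5 scale instead of Big Five floats; for example (trait values here are arbitrary):

new_disposition = {"skepticism": 4, "literalism": 2, "empathy": 5}
# await engine.update_bank_disposition(bank_id, new_disposition)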
@@ -2584,7 +2650,13 @@ Guidelines:
  if self._llm_config is None:
  raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")

+ reflect_start = time.time()
+ reflect_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
+ log_buffer = []
+ log_buffer.append(f"[REFLECT {reflect_id}] Query: '{query[:50]}...'")
+
  # Steps 1-3: Run multi-fact-type search (12-way retrieval: 4 methods × 3 fact types)
+ recall_start = time.time()
  search_result = await self.recall_async(
  bank_id=bank_id,
  query=query,
@@ -2594,24 +2666,22 @@ Guidelines:
  fact_type=['experience', 'world', 'opinion'],
  include_entities=True
  )
+ recall_time = time.time() - recall_start

  all_results = search_result.results
- logger.info(f"[THINK] Search returned {len(all_results)} results")

  # Split results by fact type for structured response
  agent_results = [r for r in all_results if r.fact_type == 'experience']
  world_results = [r for r in all_results if r.fact_type == 'world']
  opinion_results = [r for r in all_results if r.fact_type == 'opinion']

- logger.info(f"[THINK] Split results - agent: {len(agent_results)}, world: {len(world_results)}, opinion: {len(opinion_results)}")
+ log_buffer.append(f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s")

  # Format facts for LLM
  agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
  world_facts_text = think_utils.format_facts_for_prompt(world_results)
  opinion_facts_text = think_utils.format_facts_for_prompt(opinion_results)

- logger.info(f"[THINK] Formatted facts - agent: {len(agent_facts_text)} chars, world: {len(world_facts_text)} chars, opinion: {len(opinion_facts_text)} chars")
-
  # Get bank profile (name, disposition + background)
  profile = await self.get_bank_profile(bank_id)
  name = profile["name"]
@@ -2630,10 +2700,11 @@ Guidelines:
  context=context,
  )

- logger.info(f"[THINK] Full prompt length: {len(prompt)} chars")
+ log_buffer.append(f"[REFLECT {reflect_id}] Prompt: {len(prompt)} chars")

  system_message = think_utils.get_system_message(disposition)

+ llm_start = time.time()
  answer_text = await self._llm_config.call(
  messages=[
  {"role": "system", "content": system_message},
@@ -2641,8 +2712,9 @@ Guidelines:
  ],
  scope="memory_think",
  temperature=0.9,
- max_tokens=1000
+ max_completion_tokens=1000
  )
+ llm_time = time.time() - llm_start

  answer_text = answer_text.strip()

@@ -2654,6 +2726,10 @@ Guidelines:
  'query': query
  })

+ total_time = time.time() - reflect_start
+ log_buffer.append(f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s")
+ logger.info("\n" + "\n".join(log_buffer))
+
  # Return response with facts split by type
  return ReflectResult(
  text=answer_text,
@@ -2702,7 +2778,7 @@ Guidelines:
  )

  except Exception as e:
- logger.warning(f"[THINK] Failed to extract/store opinions: {str(e)}")
+ logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")

  async def get_entity_observations(
  self,
@@ -2828,7 +2904,8 @@ Guidelines:
  bank_id: str,
  entity_id: str,
  entity_name: str,
- version: str | None = None
+ version: str | None = None,
+ conn=None
  ) -> List[str]:
  """
  Regenerate observations for an entity by:
@@ -2843,43 +2920,58 @@ Guidelines:
  entity_id: Entity UUID
  entity_name: Canonical name of the entity
  version: Entity's last_seen timestamp when task was created (for deduplication)
+ conn: Optional database connection (for transactional atomicity with caller)

  Returns:
  List of created observation IDs
  """
  pool = await self._get_pool()
+ entity_uuid = uuid.UUID(entity_id)

- # Step 1: Check version for deduplication
- if version:
- async with acquire_with_retry(pool) as conn:
- current_last_seen = await conn.fetchval(
- """
- SELECT last_seen
- FROM entities
- WHERE id = $1 AND bank_id = $2
- """,
- uuid.UUID(entity_id), bank_id
- )
+ # Helper to run a query with provided conn or acquire one
+ async def fetch_with_conn(query, *args):
+ if conn is not None:
+ return await conn.fetch(query, *args)
+ else:
+ async with acquire_with_retry(pool) as acquired_conn:
+ return await acquired_conn.fetch(query, *args)

- if current_last_seen and current_last_seen.isoformat() != version:
- return []
+ async def fetchval_with_conn(query, *args):
+ if conn is not None:
+ return await conn.fetchval(query, *args)
+ else:
+ async with acquire_with_retry(pool) as acquired_conn:
+ return await acquired_conn.fetchval(query, *args)

- # Step 2: Get all facts mentioning this entity (exclude observations themselves)
- async with acquire_with_retry(pool) as conn:
- rows = await conn.fetch(
+ # Step 1: Check version for deduplication
+ if version:
+ current_last_seen = await fetchval_with_conn(
  """
- SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
- FROM memory_units mu
- JOIN unit_entities ue ON mu.id = ue.unit_id
- WHERE mu.bank_id = $1
- AND ue.entity_id = $2
- AND mu.fact_type IN ('world', 'experience')
- ORDER BY mu.occurred_start DESC
- LIMIT 50
+ SELECT last_seen
+ FROM entities
+ WHERE id = $1 AND bank_id = $2
  """,
- bank_id, uuid.UUID(entity_id)
+ entity_uuid, bank_id
  )

+ if current_last_seen and current_last_seen.isoformat() != version:
+ return []
+
+ # Step 2: Get all facts mentioning this entity (exclude observations themselves)
+ rows = await fetch_with_conn(
+ """
+ SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
+ FROM memory_units mu
+ JOIN unit_entities ue ON mu.id = ue.unit_id
+ WHERE mu.bank_id = $1
+ AND ue.entity_id = $2
+ AND mu.fact_type IN ('world', 'experience')
+ ORDER BY mu.occurred_start DESC
+ LIMIT 50
+ """,
+ bank_id, entity_uuid
+ )
+
  if not rows:
  return []
@@ -2905,120 +2997,173 @@ Guidelines:
  if not observations:
  return []

- # Step 4: Delete old observations and insert new ones in a transaction
- async with acquire_with_retry(pool) as conn:
- async with conn.transaction():
- # Delete old observations for this entity
- await conn.execute(
+ # Step 4: Delete old observations and insert new ones
+ # If conn provided, we're already in a transaction - don't start another
+ # If conn is None, acquire one and start a transaction
+ async def do_db_operations(db_conn):
+ # Delete old observations for this entity
+ await db_conn.execute(
+ """
+ DELETE FROM memory_units
+ WHERE id IN (
+ SELECT mu.id
+ FROM memory_units mu
+ JOIN unit_entities ue ON mu.id = ue.unit_id
+ WHERE mu.bank_id = $1
+ AND mu.fact_type = 'observation'
+ AND ue.entity_id = $2
+ )
+ """,
+ bank_id, entity_uuid
+ )
+
+ # Generate embeddings for new observations
+ embeddings = await embedding_utils.generate_embeddings_batch(
+ self.embeddings, observations
+ )
+
+ # Insert new observations
+ current_time = utcnow()
+ created_ids = []
+
+ for obs_text, embedding in zip(observations, embeddings):
+ result = await db_conn.fetchrow(
  """
- DELETE FROM memory_units
- WHERE id IN (
- SELECT mu.id
- FROM memory_units mu
- JOIN unit_entities ue ON mu.id = ue.unit_id
- WHERE mu.bank_id = $1
- AND mu.fact_type = 'observation'
- AND ue.entity_id = $2
+ INSERT INTO memory_units (
+ bank_id, text, embedding, context, event_date,
+ occurred_start, occurred_end, mentioned_at,
+ fact_type, access_count
  )
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
+ RETURNING id
  """,
- bank_id, uuid.UUID(entity_id)
+ bank_id,
+ obs_text,
+ str(embedding),
+ f"observation about {entity_name}",
+ current_time,
+ current_time,
+ current_time,
+ current_time
  )
+ obs_id = str(result['id'])
+ created_ids.append(obs_id)

- # Generate embeddings for new observations
- embeddings = await embedding_utils.generate_embeddings_batch(
- self.embeddings, observations
+ # Link observation to entity
+ await db_conn.execute(
+ """
+ INSERT INTO unit_entities (unit_id, entity_id)
+ VALUES ($1, $2)
+ """,
+ uuid.UUID(obs_id), entity_uuid
  )

- # Insert new observations
- current_time = utcnow()
- created_ids = []
-
- for obs_text, embedding in zip(observations, embeddings):
- result = await conn.fetchrow(
- """
- INSERT INTO memory_units (
- bank_id, text, embedding, context, event_date,
- occurred_start, occurred_end, mentioned_at,
- fact_type, access_count
- )
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
- RETURNING id
- """,
- bank_id,
- obs_text,
- str(embedding),
- f"observation about {entity_name}",
- current_time,
- current_time,
- current_time,
- current_time
- )
- obs_id = str(result['id'])
- created_ids.append(obs_id)
-
- # Link observation to entity
- await conn.execute(
- """
- INSERT INTO unit_entities (unit_id, entity_id)
- VALUES ($1, $2)
- """,
- uuid.UUID(obs_id), uuid.UUID(entity_id)
- )
+ return created_ids

- # Single consolidated log line
- logger.info(f"[OBSERVATIONS] {entity_name}: {len(facts)} facts -> {len(created_ids)} observations")
- return created_ids
+ if conn is not None:
+ # Use provided connection (already in a transaction)
+ return await do_db_operations(conn)
+ else:
+ # Acquire connection and start our own transaction
+ async with acquire_with_retry(pool) as acquired_conn:
+ async with acquired_conn.transaction():
+ return await do_db_operations(acquired_conn)

  async def _regenerate_observations_sync(
  self,
  bank_id: str,
  entity_ids: List[str],
- min_facts: int = 5
+ min_facts: int = 5,
+ conn=None
  ) -> None:
  """
  Regenerate observations for entities synchronously (called during retain).

+ Processes entities in PARALLEL for faster execution.
+
  Args:
  bank_id: Bank identifier
  entity_ids: List of entity IDs to process
  min_facts: Minimum facts required to regenerate observations
+ conn: Optional database connection (for transactional atomicity)
  """
  if not bank_id or not entity_ids:
  return

- pool = await self._get_pool()
- async with pool.acquire() as conn:
- for entity_id in entity_ids:
- try:
- entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
-
- # Check if entity exists
- entity_exists = await conn.fetchrow(
- "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
- entity_uuid, bank_id
- )
+ # Convert to UUIDs
+ entity_uuids = [uuid.UUID(eid) if isinstance(eid, str) else eid for eid in entity_ids]

- if not entity_exists:
- logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
- continue
+ # Use provided connection or acquire a new one
+ if conn is not None:
+ # Use the provided connection (transactional with caller)
+ entity_rows = await conn.fetch(
+ """
+ SELECT id, canonical_name FROM entities
+ WHERE id = ANY($1) AND bank_id = $2
+ """,
+ entity_uuids, bank_id
+ )
+ entity_names = {row['id']: row['canonical_name'] for row in entity_rows}

- entity_name = entity_exists['canonical_name']
+ fact_counts = await conn.fetch(
+ """
+ SELECT ue.entity_id, COUNT(*) as cnt
+ FROM unit_entities ue
+ JOIN memory_units mu ON ue.unit_id = mu.id
+ WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
+ GROUP BY ue.entity_id
+ """,
+ entity_uuids, bank_id
+ )
+ entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+ else:
+ # Acquire a new connection (standalone call)
+ pool = await self._get_pool()
+ async with pool.acquire() as acquired_conn:
+ entity_rows = await acquired_conn.fetch(
+ """
+ SELECT id, canonical_name FROM entities
+ WHERE id = ANY($1) AND bank_id = $2
+ """,
+ entity_uuids, bank_id
+ )
+ entity_names = {row['id']: row['canonical_name'] for row in entity_rows}

- # Count facts linked to this entity
- fact_count = await conn.fetchval(
- "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1",
- entity_uuid
- ) or 0
+ fact_counts = await acquired_conn.fetch(
+ """
+ SELECT ue.entity_id, COUNT(*) as cnt
+ FROM unit_entities ue
+ JOIN memory_units mu ON ue.unit_id = mu.id
+ WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
+ GROUP BY ue.entity_id
+ """,
+ entity_uuids, bank_id
+ )
+ entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+
+ # Filter entities that meet the threshold
+ entities_to_process = []
+ for entity_id in entity_ids:
+ entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+ if entity_uuid not in entity_names:
+ continue
+ fact_count = entity_fact_counts.get(entity_uuid, 0)
+ if fact_count >= min_facts:
+ entities_to_process.append((entity_id, entity_names[entity_uuid]))
+
+ if not entities_to_process:
+ return

- # Only regenerate if entity has enough facts
- if fact_count >= min_facts:
- await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
- else:
- logger.debug(f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)")
+ # Process all entities in PARALLEL (LLM calls are the bottleneck)
+ async def process_entity(entity_id: str, entity_name: str):
+ try:
+ await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None, conn=conn)
+ except Exception as e:
+ logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")

- except Exception as e:
- logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
- continue
+ await asyncio.gather(*[
+ process_entity(eid, name) for eid, name in entities_to_process
+ ])

  async def _handle_regenerate_observations(self, task_dict: Dict[str, Any]):
  """