hindsight-api 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +252 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/api/http.py +282 -20
  6. hindsight_api/api/mcp.py +47 -52
  7. hindsight_api/config.py +238 -6
  8. hindsight_api/engine/cross_encoder.py +599 -86
  9. hindsight_api/engine/db_budget.py +284 -0
  10. hindsight_api/engine/db_utils.py +11 -0
  11. hindsight_api/engine/embeddings.py +453 -26
  12. hindsight_api/engine/entity_resolver.py +8 -5
  13. hindsight_api/engine/interface.py +8 -4
  14. hindsight_api/engine/llm_wrapper.py +241 -27
  15. hindsight_api/engine/memory_engine.py +609 -122
  16. hindsight_api/engine/query_analyzer.py +4 -3
  17. hindsight_api/engine/response_models.py +38 -0
  18. hindsight_api/engine/retain/fact_extraction.py +388 -192
  19. hindsight_api/engine/retain/fact_storage.py +34 -8
  20. hindsight_api/engine/retain/link_utils.py +24 -16
  21. hindsight_api/engine/retain/orchestrator.py +52 -17
  22. hindsight_api/engine/retain/types.py +9 -0
  23. hindsight_api/engine/search/graph_retrieval.py +42 -13
  24. hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
  25. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  26. hindsight_api/engine/search/reranking.py +2 -2
  27. hindsight_api/engine/search/retrieval.py +847 -200
  28. hindsight_api/engine/search/tags.py +172 -0
  29. hindsight_api/engine/search/think_utils.py +1 -1
  30. hindsight_api/engine/search/trace.py +12 -0
  31. hindsight_api/engine/search/tracer.py +24 -1
  32. hindsight_api/engine/search/types.py +21 -0
  33. hindsight_api/engine/task_backend.py +109 -18
  34. hindsight_api/engine/utils.py +1 -1
  35. hindsight_api/extensions/context.py +10 -1
  36. hindsight_api/main.py +56 -4
  37. hindsight_api/metrics.py +433 -48
  38. hindsight_api/migrations.py +141 -1
  39. hindsight_api/models.py +3 -1
  40. hindsight_api/pg0.py +53 -0
  41. hindsight_api/server.py +39 -2
  42. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
  43. hindsight_api-0.3.0.dist-info/RECORD +82 -0
  44. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
  45. hindsight_api-0.2.0.dist-info/RECORD +0 -75
  46. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0
hindsight_api/engine/memory_engine.py
@@ -18,6 +18,8 @@ from datetime import UTC, datetime, timedelta
18
18
  from typing import TYPE_CHECKING, Any
19
19
 
20
20
  from ..config import get_config
21
+ from ..metrics import get_metrics_collector
22
+ from .db_budget import budgeted_operation
21
23
 
22
24
  # Context variable for current schema (async-safe, per-task isolation)
23
25
  _current_schema: contextvars.ContextVar[str] = contextvars.ContextVar("current_schema", default="public")
@@ -132,17 +134,25 @@ if TYPE_CHECKING:
132
134
 
133
135
  from enum import Enum
134
136
 
135
- from ..pg0 import EmbeddedPostgres
137
+ from ..pg0 import EmbeddedPostgres, parse_pg0_url
136
138
  from .entity_resolver import EntityResolver
137
139
  from .llm_wrapper import LLMConfig
138
140
  from .query_analyzer import QueryAnalyzer
139
- from .response_models import VALID_RECALL_FACT_TYPES, EntityObservation, EntityState, MemoryFact, ReflectResult
141
+ from .response_models import (
142
+ VALID_RECALL_FACT_TYPES,
143
+ EntityObservation,
144
+ EntityState,
145
+ MemoryFact,
146
+ ReflectResult,
147
+ TokenUsage,
148
+ )
140
149
  from .response_models import RecallResult as RecallResultModel
141
150
  from .retain import bank_utils, embedding_utils
142
151
  from .retain.types import RetainContentDict
143
152
  from .search import observation_utils, think_utils
144
153
  from .search.reranking import CrossEncoderReranker
145
- from .task_backend import AsyncIOQueueBackend, TaskBackend
154
+ from .search.tags import TagsMatch
155
+ from .task_backend import AsyncIOQueueBackend, NoopTaskBackend, TaskBackend
146
156
 
147
157
 
148
158
  class Budget(str, Enum):
@@ -195,12 +205,25 @@ class MemoryEngine(MemoryEngineInterface):
195
205
  memory_llm_api_key: str | None = None,
196
206
  memory_llm_model: str | None = None,
197
207
  memory_llm_base_url: str | None = None,
208
+ # Per-operation LLM config (optional, falls back to memory_llm_* params)
209
+ retain_llm_provider: str | None = None,
210
+ retain_llm_api_key: str | None = None,
211
+ retain_llm_model: str | None = None,
212
+ retain_llm_base_url: str | None = None,
213
+ reflect_llm_provider: str | None = None,
214
+ reflect_llm_api_key: str | None = None,
215
+ reflect_llm_model: str | None = None,
216
+ reflect_llm_base_url: str | None = None,
198
217
  embeddings: Embeddings | None = None,
199
218
  cross_encoder: CrossEncoderModel | None = None,
200
219
  query_analyzer: QueryAnalyzer | None = None,
201
- pool_min_size: int = 5,
202
- pool_max_size: int = 100,
220
+ pool_min_size: int | None = None,
221
+ pool_max_size: int | None = None,
222
+ db_command_timeout: int | None = None,
223
+ db_acquire_timeout: int | None = None,
203
224
  task_backend: TaskBackend | None = None,
225
+ task_batch_size: int | None = None,
226
+ task_batch_interval: float | None = None,
204
227
  run_migrations: bool = True,
205
228
  operation_validator: "OperationValidatorExtension | None" = None,
206
229
  tenant_extension: "TenantExtension | None" = None,
@@ -220,12 +243,24 @@ class MemoryEngine(MemoryEngineInterface):
220
243
  memory_llm_api_key: API key for the LLM provider. Defaults to HINDSIGHT_API_LLM_API_KEY env var.
221
244
  memory_llm_model: Model name. Defaults to HINDSIGHT_API_LLM_MODEL env var.
222
245
  memory_llm_base_url: Base URL for the LLM API. Defaults based on provider.
246
+ retain_llm_provider: LLM provider for retain operations. Falls back to memory_llm_provider.
247
+ retain_llm_api_key: API key for retain LLM. Falls back to memory_llm_api_key.
248
+ retain_llm_model: Model for retain operations. Falls back to memory_llm_model.
249
+ retain_llm_base_url: Base URL for retain LLM. Falls back to memory_llm_base_url.
250
+ reflect_llm_provider: LLM provider for reflect operations. Falls back to memory_llm_provider.
251
+ reflect_llm_api_key: API key for reflect LLM. Falls back to memory_llm_api_key.
252
+ reflect_llm_model: Model for reflect operations. Falls back to memory_llm_model.
253
+ reflect_llm_base_url: Base URL for reflect LLM. Falls back to memory_llm_base_url.
223
254
  embeddings: Embeddings implementation. If not provided, created from env vars.
224
255
  cross_encoder: Cross-encoder model. If not provided, created from env vars.
225
256
  query_analyzer: Query analyzer implementation. If not provided, uses DateparserQueryAnalyzer.
226
- pool_min_size: Minimum number of connections in the pool (default: 5)
227
- pool_max_size: Maximum number of connections in the pool (default: 100)
257
+ pool_min_size: Minimum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MIN_SIZE.
258
+ pool_max_size: Maximum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MAX_SIZE.
259
+ db_command_timeout: PostgreSQL command timeout in seconds. Defaults to HINDSIGHT_API_DB_COMMAND_TIMEOUT.
260
+ db_acquire_timeout: Connection acquisition timeout in seconds. Defaults to HINDSIGHT_API_DB_ACQUIRE_TIMEOUT.
228
261
  task_backend: Custom task backend. If not provided, uses AsyncIOQueueBackend.
262
+ task_batch_size: Background task batch size. Defaults to HINDSIGHT_API_TASK_BACKEND_MEMORY_BATCH_SIZE.
263
+ task_batch_interval: Background task batch interval in seconds. Defaults to HINDSIGHT_API_TASK_BACKEND_MEMORY_BATCH_INTERVAL.
229
264
  run_migrations: Whether to run database migrations during initialize(). Default: True
230
265
  operation_validator: Optional extension to validate operations before execution.
231
266
  If provided, retain/recall/reflect operations will be validated.
@@ -252,38 +287,21 @@ class MemoryEngine(MemoryEngineInterface):
252
287
  db_url = db_url or config.database_url
253
288
  memory_llm_provider = memory_llm_provider or config.llm_provider
254
289
  memory_llm_api_key = memory_llm_api_key or config.llm_api_key
255
- # Ollama doesn't require an API key
256
- if not memory_llm_api_key and memory_llm_provider != "ollama":
290
+ # Ollama and mock don't require an API key
291
+ if not memory_llm_api_key and memory_llm_provider not in ("ollama", "mock"):
257
292
  raise ValueError("LLM API key is required. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
258
293
  memory_llm_model = memory_llm_model or config.llm_model
259
294
  memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
260
295
  # Track pg0 instance (if used)
261
296
  self._pg0: EmbeddedPostgres | None = None
262
- self._pg0_instance_name: str | None = None
263
297
 
264
298
  # Initialize PostgreSQL connection URL
265
299
  # The actual URL will be set during initialize() after starting the server
266
300
  # Supports: "pg0" (default instance), "pg0://instance-name" (named instance), or regular postgresql:// URL
267
- if db_url == "pg0":
268
- self._use_pg0 = True
269
- self._pg0_instance_name = "hindsight"
270
- self._pg0_port = None # Use default port
271
- self.db_url = None
272
- elif db_url.startswith("pg0://"):
273
- self._use_pg0 = True
274
- # Parse instance name and optional port: pg0://instance-name or pg0://instance-name:port
275
- url_part = db_url[6:] # Remove "pg0://"
276
- if ":" in url_part:
277
- self._pg0_instance_name, port_str = url_part.rsplit(":", 1)
278
- self._pg0_port = int(port_str)
279
- else:
280
- self._pg0_instance_name = url_part or "hindsight"
281
- self._pg0_port = None # Use default port
301
+ self._use_pg0, self._pg0_instance_name, self._pg0_port = parse_pg0_url(db_url)
302
+ if self._use_pg0:
282
303
  self.db_url = None
283
304
  else:
284
- self._use_pg0 = False
285
- self._pg0_instance_name = None
286
- self._pg0_port = None
287
305
  self.db_url = db_url
288
306
 
289
307
  # Set default base URL if not provided
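Note on the change above: the inline handling of "pg0", "pg0://name", and "pg0://name:port" is consolidated into a parse_pg0_url helper in hindsight_api/pg0.py (+53 lines in this release, not shown here). The sketch below only restates the rules visible in the removed lines; it is not the shipped helper.

```python
def parse_pg0_url(db_url: str) -> tuple[bool, str | None, int | None]:
    """Return (use_pg0, instance_name, port) for a database URL.

    Mirrors the removed inline branches: "pg0" selects the default embedded
    instance, "pg0://name[:port]" a named instance, and anything else is
    treated as a regular postgresql:// URL.
    """
    if db_url == "pg0":
        return True, "hindsight", None
    if db_url.startswith("pg0://"):
        url_part = db_url[6:]  # strip "pg0://"
        if ":" in url_part:
            name, port_str = url_part.rsplit(":", 1)
            return True, name, int(port_str)
        return True, url_part or "hindsight", None
    return False, None, None
```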
@@ -298,8 +316,10 @@ class MemoryEngine(MemoryEngineInterface):
298
316
  # Connection pool (will be created in initialize())
299
317
  self._pool = None
300
318
  self._initialized = False
301
- self._pool_min_size = pool_min_size
302
- self._pool_max_size = pool_max_size
319
+ self._pool_min_size = pool_min_size if pool_min_size is not None else config.db_pool_min_size
320
+ self._pool_max_size = pool_max_size if pool_max_size is not None else config.db_pool_max_size
321
+ self._db_command_timeout = db_command_timeout if db_command_timeout is not None else config.db_command_timeout
322
+ self._db_acquire_timeout = db_acquire_timeout if db_acquire_timeout is not None else config.db_acquire_timeout
303
323
  self._run_migrations = run_migrations
304
324
 
305
325
  # Initialize entity resolver (will be created in initialize())
@@ -319,7 +339,7 @@ class MemoryEngine(MemoryEngineInterface):
319
339
 
320
340
  self.query_analyzer = DateparserQueryAnalyzer()
321
341
 
322
- # Initialize LLM configuration
342
+ # Initialize LLM configuration (default, used as fallback)
323
343
  self._llm_config = LLMConfig(
324
344
  provider=memory_llm_provider,
325
345
  api_key=memory_llm_api_key,
@@ -331,17 +351,68 @@ class MemoryEngine(MemoryEngineInterface):
331
351
  self._llm_client = self._llm_config._client
332
352
  self._llm_model = self._llm_config.model
333
353
 
354
+ # Initialize per-operation LLM configs (fall back to default if not specified)
355
+ # Retain LLM config - for fact extraction (benefits from strong structured output)
356
+ retain_provider = retain_llm_provider or config.retain_llm_provider or memory_llm_provider
357
+ retain_api_key = retain_llm_api_key or config.retain_llm_api_key or memory_llm_api_key
358
+ retain_model = retain_llm_model or config.retain_llm_model or memory_llm_model
359
+ retain_base_url = retain_llm_base_url or config.retain_llm_base_url or memory_llm_base_url
360
+ # Apply provider-specific base URL defaults for retain
361
+ if retain_base_url is None:
362
+ if retain_provider.lower() == "groq":
363
+ retain_base_url = "https://api.groq.com/openai/v1"
364
+ elif retain_provider.lower() == "ollama":
365
+ retain_base_url = "http://localhost:11434/v1"
366
+ else:
367
+ retain_base_url = ""
368
+
369
+ self._retain_llm_config = LLMConfig(
370
+ provider=retain_provider,
371
+ api_key=retain_api_key,
372
+ base_url=retain_base_url,
373
+ model=retain_model,
374
+ )
375
+
376
+ # Reflect LLM config - for think/observe operations (can use lighter models)
377
+ reflect_provider = reflect_llm_provider or config.reflect_llm_provider or memory_llm_provider
378
+ reflect_api_key = reflect_llm_api_key or config.reflect_llm_api_key or memory_llm_api_key
379
+ reflect_model = reflect_llm_model or config.reflect_llm_model or memory_llm_model
380
+ reflect_base_url = reflect_llm_base_url or config.reflect_llm_base_url or memory_llm_base_url
381
+ # Apply provider-specific base URL defaults for reflect
382
+ if reflect_base_url is None:
383
+ if reflect_provider.lower() == "groq":
384
+ reflect_base_url = "https://api.groq.com/openai/v1"
385
+ elif reflect_provider.lower() == "ollama":
386
+ reflect_base_url = "http://localhost:11434/v1"
387
+ else:
388
+ reflect_base_url = ""
389
+
390
+ self._reflect_llm_config = LLMConfig(
391
+ provider=reflect_provider,
392
+ api_key=reflect_api_key,
393
+ base_url=reflect_base_url,
394
+ model=reflect_model,
395
+ )
396
+
334
397
  # Initialize cross-encoder reranker (cached for performance)
335
398
  self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)
336
399
 
337
400
  # Initialize task backend
338
- self._task_backend = task_backend or AsyncIOQueueBackend(batch_size=100, batch_interval=1.0)
401
+ if task_backend:
402
+ self._task_backend = task_backend
403
+ elif config.task_backend == "noop":
404
+ self._task_backend = NoopTaskBackend()
405
+ else:
406
+ # Default to memory (AsyncIOQueueBackend)
407
+ _task_batch_size = task_batch_size if task_batch_size is not None else config.task_backend_memory_batch_size
408
+ _task_batch_interval = (
409
+ task_batch_interval if task_batch_interval is not None else config.task_backend_memory_batch_interval
410
+ )
411
+ self._task_backend = AsyncIOQueueBackend(batch_size=_task_batch_size, batch_interval=_task_batch_interval)
339
412
 
340
413
  # Backpressure mechanism: limit concurrent searches to prevent overwhelming the database
341
- # Limit concurrent searches to prevent connection pool exhaustion
342
- # Each search can use 2-4 connections, so with 10 concurrent searches
343
- # we use ~20-40 connections max, staying well within pool limits
344
- self._search_semaphore = asyncio.Semaphore(10)
414
+ # Configurable via HINDSIGHT_API_RECALL_MAX_CONCURRENT (default: 50)
415
+ self._search_semaphore = asyncio.Semaphore(get_config().recall_max_concurrent)
345
416
 
346
417
  # Backpressure for put operations: limit concurrent puts to prevent database contention
347
418
  # Each put_batch holds a connection for the entire transaction, so we limit to 5
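Taken together, the new constructor parameters allow a different model per operation. A hypothetical construction is shown below; the parameter names come from the signature in this diff, while the import path, provider, and model names are illustrative assumptions rather than documented usage.

```python
import os

from hindsight_api.engine.memory_engine import MemoryEngine  # import path assumed from the wheel layout

engine = MemoryEngine(
    db_url="pg0",                                  # embedded Postgres instance
    memory_llm_provider="groq",
    memory_llm_api_key=os.environ.get("HINDSIGHT_API_LLM_API_KEY"),
    memory_llm_model="llama-3.3-70b-versatile",    # default / fallback model
    retain_llm_model="llama-3.3-70b-versatile",    # fact extraction: strong structured output
    reflect_llm_model="llama-3.1-8b-instant",      # think/observe: a lighter model suffices
    pool_max_size=50,                              # else HINDSIGHT_API_DB_POOL_MAX_SIZE applies
)
```

Anything not passed explicitly falls back first to the matching HINDSIGHT_API_* environment variable and then to the memory_llm_* defaults, per the fallback chains added above.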
@@ -618,9 +689,27 @@ class MemoryEngine(MemoryEngineInterface):
618
689
  await loop.run_in_executor(None, self.query_analyzer.load)
619
690
 
620
691
  async def verify_llm():
621
- """Verify LLM connection is working."""
692
+ """Verify LLM connections are working for all unique configs."""
622
693
  if not self._skip_llm_verification:
694
+ # Verify default config
623
695
  await self._llm_config.verify_connection()
696
+ # Verify retain config if different from default
697
+ retain_is_different = (
698
+ self._retain_llm_config.provider != self._llm_config.provider
699
+ or self._retain_llm_config.model != self._llm_config.model
700
+ )
701
+ if retain_is_different:
702
+ await self._retain_llm_config.verify_connection()
703
+ # Verify reflect config if different from default and retain
704
+ reflect_is_different = (
705
+ self._reflect_llm_config.provider != self._llm_config.provider
706
+ or self._reflect_llm_config.model != self._llm_config.model
707
+ ) and (
708
+ self._reflect_llm_config.provider != self._retain_llm_config.provider
709
+ or self._reflect_llm_config.model != self._retain_llm_config.model
710
+ )
711
+ if reflect_is_different:
712
+ await self._reflect_llm_config.verify_connection()
624
713
 
625
714
  # Build list of initialization tasks
626
715
  init_tasks = [
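The three pairwise checks in verify_llm amount to "verify each distinct (provider, model) pair once". An equivalent formulation, shown only to restate the intent and not the code shipped in this release:

```python
async def verify_unique_llm_configs(configs) -> None:
    """Call verify_connection() once per distinct (provider, model) pair."""
    seen: set[tuple[str, str]] = set()
    for cfg in configs:
        key = (cfg.provider, cfg.model)
        if key not in seen:
            seen.add(key)
            await cfg.verify_connection()
```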
@@ -642,13 +731,17 @@ class MemoryEngine(MemoryEngineInterface):
642
731
 
643
732
  # Run database migrations if enabled
644
733
  if self._run_migrations:
645
- from ..migrations import run_migrations
734
+ from ..migrations import ensure_embedding_dimension, run_migrations
646
735
 
647
736
  if not self.db_url:
648
737
  raise ValueError("Database URL is required for migrations")
649
738
  logger.info("Running database migrations...")
650
739
  run_migrations(self.db_url)
651
740
 
741
+ # Ensure embedding column dimension matches the model's dimension
742
+ # This is done after migrations and after embeddings.initialize()
743
+ ensure_embedding_dimension(self.db_url, self.embeddings.dimension)
744
+
652
745
  logger.info(f"Connecting to PostgreSQL at {self.db_url}")
653
746
 
654
747
  # Create connection pool
@@ -658,9 +751,9 @@ class MemoryEngine(MemoryEngineInterface):
658
751
  self.db_url,
659
752
  min_size=self._pool_min_size,
660
753
  max_size=self._pool_max_size,
661
- command_timeout=60,
754
+ command_timeout=self._db_command_timeout,
662
755
  statement_cache_size=0, # Disable prepared statement cache
663
- timeout=30, # Connection acquisition timeout (seconds)
756
+ timeout=self._db_acquire_timeout, # Connection acquisition timeout (seconds)
664
757
  )
665
758
 
666
759
  # Initialize entity resolver with pool
@@ -967,7 +1060,9 @@ class MemoryEngine(MemoryEngineInterface):
967
1060
  document_id: str | None = None,
968
1061
  fact_type_override: str | None = None,
969
1062
  confidence_score: float | None = None,
970
- ) -> list[list[str]]:
1063
+ document_tags: list[str] | None = None,
1064
+ return_usage: bool = False,
1065
+ ):
971
1066
  """
972
1067
  Store multiple content items as memory units in ONE batch operation.
973
1068
 
@@ -988,9 +1083,11 @@ class MemoryEngine(MemoryEngineInterface):
988
1083
  Applies the same document_id to ALL content items that don't specify their own.
989
1084
  fact_type_override: Override fact type for all facts ('world', 'experience', 'opinion')
990
1085
  confidence_score: Confidence score for opinions (0.0 to 1.0)
1086
+ return_usage: If True, returns tuple of (unit_ids, TokenUsage). Default False for backward compatibility.
991
1087
 
992
1088
  Returns:
993
- List of lists of unit IDs (one list per content item)
1089
+ If return_usage=False: List of lists of unit IDs (one list per content item)
1090
+ If return_usage=True: Tuple of (unit_ids, TokenUsage)
994
1091
 
995
1092
  Example (new style - per-content document_id):
996
1093
  unit_ids = await memory.retain_batch_async(
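As a usage sketch, the new options slot into the docstring's memory.retain_batch_async example like this. Bank and content values are invented, and a request_context argument is assumed, as for the other public methods in this diff.

```python
async def example(memory, ctx) -> None:
    # return_usage=True switches the return value to (unit_ids, TokenUsage);
    # document_tags apply to every item in the batch.
    unit_ids, usage = await memory.retain_batch_async(
        bank_id="support-bot",
        contents=[{"content": "Customer prefers email follow-ups."}],
        document_tags=["channel:email"],
        return_usage=True,
        request_context=ctx,
    )
    print(len(unit_ids), usage)
```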
@@ -1017,6 +1114,8 @@ class MemoryEngine(MemoryEngineInterface):
1017
1114
  start_time = time.time()
1018
1115
 
1019
1116
  if not contents:
1117
+ if return_usage:
1118
+ return [], TokenUsage()
1020
1119
  return []
1021
1120
 
1022
1121
  # Authenticate tenant and set schema in context (for fq_table())
@@ -1046,6 +1145,7 @@ class MemoryEngine(MemoryEngineInterface):
1046
1145
  # Auto-chunk large batches by character count to avoid timeouts and memory issues
1047
1146
  # Calculate total character count
1048
1147
  total_chars = sum(len(item.get("content", "")) for item in contents)
1148
+ total_usage = TokenUsage()
1049
1149
 
1050
1150
  CHARS_PER_BATCH = 600_000
1051
1151
 
@@ -1086,15 +1186,17 @@ class MemoryEngine(MemoryEngineInterface):
1086
1186
  f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars"
1087
1187
  )
1088
1188
 
1089
- sub_results = await self._retain_batch_async_internal(
1189
+ sub_results, sub_usage = await self._retain_batch_async_internal(
1090
1190
  bank_id=bank_id,
1091
1191
  contents=sub_batch,
1092
1192
  document_id=document_id,
1093
1193
  is_first_batch=i == 1, # Only upsert on first batch
1094
1194
  fact_type_override=fact_type_override,
1095
1195
  confidence_score=confidence_score,
1196
+ document_tags=document_tags,
1096
1197
  )
1097
1198
  all_results.extend(sub_results)
1199
+ total_usage = total_usage + sub_usage
1098
1200
 
1099
1201
  total_time = time.time() - start_time
1100
1202
  logger.info(
@@ -1103,13 +1205,14 @@ class MemoryEngine(MemoryEngineInterface):
1103
1205
  result = all_results
1104
1206
  else:
1105
1207
  # Small batch - use internal method directly
1106
- result = await self._retain_batch_async_internal(
1208
+ result, total_usage = await self._retain_batch_async_internal(
1107
1209
  bank_id=bank_id,
1108
1210
  contents=contents,
1109
1211
  document_id=document_id,
1110
1212
  is_first_batch=True,
1111
1213
  fact_type_override=fact_type_override,
1112
1214
  confidence_score=confidence_score,
1215
+ document_tags=document_tags,
1113
1216
  )
1114
1217
 
1115
1218
  # Call post-operation hook if validator is configured
@@ -1132,6 +1235,8 @@ class MemoryEngine(MemoryEngineInterface):
1132
1235
  except Exception as e:
1133
1236
  logger.warning(f"Post-retain hook error (non-fatal): {e}")
1134
1237
 
1238
+ if return_usage:
1239
+ return result, total_usage
1135
1240
  return result
1136
1241
 
1137
1242
  async def _retain_batch_async_internal(
@@ -1142,7 +1247,8 @@ class MemoryEngine(MemoryEngineInterface):
1142
1247
  is_first_batch: bool = True,
1143
1248
  fact_type_override: str | None = None,
1144
1249
  confidence_score: float | None = None,
1145
- ) -> list[list[str]]:
1250
+ document_tags: list[str] | None = None,
1251
+ ) -> tuple[list[list[str]], "TokenUsage"]:
1146
1252
  """
1147
1253
  Internal method for batch processing without chunking logic.
1148
1254
 
@@ -1158,6 +1264,10 @@ class MemoryEngine(MemoryEngineInterface):
1158
1264
  is_first_batch: Whether this is the first batch (for chunked operations, only delete on first batch)
1159
1265
  fact_type_override: Override fact type for all facts
1160
1266
  confidence_score: Confidence score for opinions
1267
+ document_tags: Tags applied to all items in this batch
1268
+
1269
+ Returns:
1270
+ Tuple of (unit ID lists, token usage for fact extraction)
1161
1271
  """
1162
1272
  # Backpressure: limit concurrent retains to prevent database contention
1163
1273
  async with self._put_semaphore:
@@ -1168,7 +1278,7 @@ class MemoryEngine(MemoryEngineInterface):
1168
1278
  return await orchestrator.retain_batch(
1169
1279
  pool=pool,
1170
1280
  embeddings_model=self.embeddings,
1171
- llm_config=self._llm_config,
1281
+ llm_config=self._retain_llm_config,
1172
1282
  entity_resolver=self.entity_resolver,
1173
1283
  task_backend=self._task_backend,
1174
1284
  format_date_fn=self._format_readable_date,
@@ -1179,6 +1289,7 @@ class MemoryEngine(MemoryEngineInterface):
1179
1289
  is_first_batch=is_first_batch,
1180
1290
  fact_type_override=fact_type_override,
1181
1291
  confidence_score=confidence_score,
1292
+ document_tags=document_tags,
1182
1293
  )
1183
1294
 
1184
1295
  def recall(
@@ -1237,6 +1348,8 @@ class MemoryEngine(MemoryEngineInterface):
1237
1348
  include_chunks: bool = False,
1238
1349
  max_chunk_tokens: int = 8192,
1239
1350
  request_context: "RequestContext",
1351
+ tags: list[str] | None = None,
1352
+ tags_match: TagsMatch = "any",
1240
1353
  ) -> RecallResultModel:
1241
1354
  """
1242
1355
  Recall memories using N*4-way parallel retrieval (N fact types × 4 retrieval methods).
@@ -1262,6 +1375,8 @@ class MemoryEngine(MemoryEngineInterface):
1262
1375
  max_entity_tokens: Maximum tokens for entity observations (default 500)
1263
1376
  include_chunks: Whether to include raw chunks in the response
1264
1377
  max_chunk_tokens: Maximum tokens for chunks (default 8192)
1378
+ tags: Optional list of tags for visibility filtering (OR matching - returns
1379
+ memories that have at least one matching tag)
1265
1380
 
1266
1381
  Returns:
1267
1382
  RecallResultModel containing:
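A hypothetical recall using the new visibility filter; the keyword names are taken from this diff, while the bank, query, and request context are invented.

```python
async def example(memory, ctx) -> None:
    result = await memory.recall_async(
        bank_id="support-bot",
        query="What does Alice prefer?",
        tags=["user:alice", "team:support"],
        tags_match="any",  # return memories carrying at least one of these tags
        request_context=ctx,
    )
    for fact in result.results:
        print(fact)
```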
@@ -1313,7 +1428,9 @@ class MemoryEngine(MemoryEngineInterface):
1313
1428
  # Backpressure: limit concurrent recalls to prevent overwhelming the database
1314
1429
  result = None
1315
1430
  error_msg = None
1431
+ semaphore_wait_start = time.time()
1316
1432
  async with self._search_semaphore:
1433
+ semaphore_wait = time.time() - semaphore_wait_start
1317
1434
  # Retry loop for connection errors
1318
1435
  max_retries = 3
1319
1436
  for attempt in range(max_retries + 1):
@@ -1331,6 +1448,9 @@ class MemoryEngine(MemoryEngineInterface):
1331
1448
  include_chunks,
1332
1449
  max_chunk_tokens,
1333
1450
  request_context,
1451
+ semaphore_wait=semaphore_wait,
1452
+ tags=tags,
1453
+ tags_match=tags_match,
1334
1454
  )
1335
1455
  break # Success - exit retry loop
1336
1456
  except Exception as e:
@@ -1448,6 +1568,9 @@ class MemoryEngine(MemoryEngineInterface):
1448
1568
  include_chunks: bool = False,
1449
1569
  max_chunk_tokens: int = 8192,
1450
1570
  request_context: "RequestContext" = None,
1571
+ semaphore_wait: float = 0.0,
1572
+ tags: list[str] | None = None,
1573
+ tags_match: TagsMatch = "any",
1451
1574
  ) -> RecallResultModel:
1452
1575
  """
1453
1576
  Search implementation with modular retrieval and reranking.
@@ -1477,7 +1600,9 @@ class MemoryEngine(MemoryEngineInterface):
1477
1600
  # Initialize tracer if requested
1478
1601
  from .search.tracer import SearchTracer
1479
1602
 
1480
- tracer = SearchTracer(query, thinking_budget, max_tokens) if enable_trace else None
1603
+ tracer = (
1604
+ SearchTracer(query, thinking_budget, max_tokens, tags=tags, tags_match=tags_match) if enable_trace else None
1605
+ )
1481
1606
  if tracer:
1482
1607
  tracer.start()
1483
1608
 
@@ -1487,8 +1612,9 @@ class MemoryEngine(MemoryEngineInterface):
1487
1612
  # Buffer logs for clean output in concurrent scenarios
1488
1613
  recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
1489
1614
  log_buffer = []
1615
+ tags_info = f", tags={tags}, tags_match={tags_match}" if tags else ""
1490
1616
  log_buffer.append(
1491
- f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})"
1617
+ f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens}{tags_info})"
1492
1618
  )
1493
1619
 
1494
1620
  try:
@@ -1502,37 +1628,67 @@ class MemoryEngine(MemoryEngineInterface):
1502
1628
  tracer.record_query_embedding(query_embedding)
1503
1629
  tracer.add_phase_metric("generate_query_embedding", step_duration)
1504
1630
 
1505
- # Step 2: N*4-Way Parallel Retrieval (N fact types × 4 retrieval methods)
1631
+ # Step 2: Optimized parallel retrieval using batched queries
1632
+ # - Semantic + BM25 combined in 1 CTE query for ALL fact types
1633
+ # - Graph runs per fact type (complex traversal)
1634
+ # - Temporal runs per fact type (if constraint detected)
1506
1635
  step_start = time.time()
1507
1636
  query_embedding_str = str(query_embedding)
1508
1637
 
1509
- from .search.retrieval import retrieve_parallel
1638
+ from .search.retrieval import (
1639
+ get_default_graph_retriever,
1640
+ retrieve_all_fact_types_parallel,
1641
+ )
1510
1642
 
1511
1643
  # Track each retrieval start time
1512
1644
  retrieval_start = time.time()
1513
1645
 
1514
- # Run retrieval for each fact type in parallel
1515
- retrieval_tasks = [
1516
- retrieve_parallel(
1517
- pool, query, query_embedding_str, bank_id, ft, thinking_budget, question_date, self.query_analyzer
1646
+ # Run optimized retrieval with connection budget
1647
+ config = get_config()
1648
+ async with budgeted_operation(
1649
+ max_connections=config.recall_connection_budget,
1650
+ operation_id=f"recall-{recall_id}",
1651
+ ) as op:
1652
+ budgeted_pool = op.wrap_pool(pool)
1653
+ parallel_start = time.time()
1654
+ multi_result = await retrieve_all_fact_types_parallel(
1655
+ budgeted_pool,
1656
+ query,
1657
+ query_embedding_str,
1658
+ bank_id,
1659
+ fact_type, # Pass all fact types at once
1660
+ thinking_budget,
1661
+ question_date,
1662
+ self.query_analyzer,
1663
+ tags=tags,
1664
+ tags_match=tags_match,
1518
1665
  )
1519
- for ft in fact_type
1520
- ]
1521
- all_retrievals = await asyncio.gather(*retrieval_tasks)
1666
+ parallel_duration = time.time() - parallel_start
1522
1667
 
1523
1668
  # Combine all results from all fact types and aggregate timings
1524
1669
  semantic_results = []
1525
1670
  bm25_results = []
1526
1671
  graph_results = []
1527
1672
  temporal_results = []
1528
- aggregated_timings = {"semantic": 0.0, "bm25": 0.0, "graph": 0.0, "temporal": 0.0}
1673
+ aggregated_timings = {
1674
+ "semantic": 0.0,
1675
+ "bm25": 0.0,
1676
+ "graph": 0.0,
1677
+ "temporal": 0.0,
1678
+ "temporal_extraction": 0.0,
1679
+ }
1680
+ all_mpfp_timings = []
1529
1681
 
1530
1682
  detected_temporal_constraint = None
1531
- for idx, retrieval_result in enumerate(all_retrievals):
1683
+ max_conn_wait = multi_result.max_conn_wait
1684
+ for ft in fact_type:
1685
+ retrieval_result = multi_result.results_by_fact_type.get(ft)
1686
+ if not retrieval_result:
1687
+ continue
1688
+
1532
1689
  # Log fact types in this retrieval batch
1533
- ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
1534
1690
  logger.debug(
1535
- f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
1691
+ f"[RECALL {recall_id}] Fact type '{ft}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
1536
1692
  )
1537
1693
 
1538
1694
  semantic_results.extend(retrieval_result.semantic)
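budgeted_operation and wrap_pool come from the new engine/db_budget.py (+284 lines), which this diff does not include. As a rough sketch of the idea only, assuming an asyncpg-style pool, capping how many connections one recall may hold concurrently could look like:

```python
import asyncio
from contextlib import asynccontextmanager


class _BudgetedPool:
    """Pool wrapper whose acquire() first waits for budget headroom."""

    def __init__(self, pool, semaphore: asyncio.Semaphore):
        self._pool = pool
        self._semaphore = semaphore

    @asynccontextmanager
    async def acquire(self):
        async with self._semaphore:          # take a budget slot
            async with self._pool.acquire() as conn:
                yield conn


class _BudgetedOperation:
    def __init__(self, max_connections: int):
        self._semaphore = asyncio.Semaphore(max_connections)

    def wrap_pool(self, pool) -> _BudgetedPool:
        return _BudgetedPool(pool, self._semaphore)


@asynccontextmanager
async def budgeted_operation(max_connections: int, operation_id: str):
    # operation_id is presumably used for logging/metrics in the real module.
    yield _BudgetedOperation(max_connections)
```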
@@ -1546,6 +1702,8 @@ class MemoryEngine(MemoryEngineInterface):
1546
1702
  # Capture temporal constraint (same across all fact types)
1547
1703
  if retrieval_result.temporal_constraint:
1548
1704
  detected_temporal_constraint = retrieval_result.temporal_constraint
1705
+ # Collect MPFP timings
1706
+ all_mpfp_timings.extend(retrieval_result.mpfp_timings)
1549
1707
 
1550
1708
  # If no temporal results from any fact type, set to None
1551
1709
  if not temporal_results:
@@ -1564,12 +1722,12 @@ class MemoryEngine(MemoryEngineInterface):
1564
1722
  retrieval_duration = time.time() - retrieval_start
1565
1723
 
1566
1724
  step_duration = time.time() - step_start
1567
- total_retrievals = len(fact_type) * (4 if temporal_results else 3)
1568
- # Format per-method timings
1725
+ # Format per-method timings (these are the actual parallel retrieval times)
1569
1726
  timing_parts = [
1570
1727
  f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
1571
1728
  f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
1572
1729
  f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)",
1730
+ f"temporal_extraction={aggregated_timings['temporal_extraction']:.3f}s",
1573
1731
  ]
1574
1732
  temporal_info = ""
1575
1733
  if detected_temporal_constraint:
@@ -1578,9 +1736,41 @@ class MemoryEngine(MemoryEngineInterface):
1578
1736
  timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
1579
1737
  temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
1580
1738
  log_buffer.append(
1581
- f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}"
1739
+ f" [2] Parallel retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {parallel_duration:.3f}s{temporal_info}"
1582
1740
  )
1583
1741
 
1742
+ # Log graph retriever timing breakdown if available
1743
+ if all_mpfp_timings:
1744
+ retriever_name = get_default_graph_retriever().name.upper()
1745
+ mpfp_total = all_mpfp_timings[0] # Take first fact type's timing as representative
1746
+ mpfp_parts = [
1747
+ f"db_queries={mpfp_total.db_queries}",
1748
+ f"edge_load={mpfp_total.edge_load_time:.3f}s",
1749
+ f"edges={mpfp_total.edge_count}",
1750
+ f"patterns={mpfp_total.pattern_count}",
1751
+ ]
1752
+ if mpfp_total.seeds_time > 0.01:
1753
+ mpfp_parts.append(f"seeds={mpfp_total.seeds_time:.3f}s")
1754
+ if mpfp_total.fusion > 0.001:
1755
+ mpfp_parts.append(f"fusion={mpfp_total.fusion:.3f}s")
1756
+ if mpfp_total.fetch > 0.001:
1757
+ mpfp_parts.append(f"fetch={mpfp_total.fetch:.3f}s")
1758
+ log_buffer.append(f" [{retriever_name}] {', '.join(mpfp_parts)}")
1759
+ # Log detailed hop timing for debugging slow queries
1760
+ if mpfp_total.hop_details:
1761
+ for hd in mpfp_total.hop_details:
1762
+ log_buffer.append(
1763
+ f" hop{hd['hop']}: exec={hd.get('exec_time', 0) * 1000:.0f}ms, "
1764
+ f"uncached={hd.get('uncached_after_filter', 0)}, "
1765
+ f"load={hd.get('load_time', 0) * 1000:.0f}ms, "
1766
+ f"edges={hd.get('edges_loaded', 0)}"
1767
+ )
1768
+
1769
+ # Record temporal constraint in tracer if detected
1770
+ if tracer and detected_temporal_constraint:
1771
+ start_dt, end_dt = detected_temporal_constraint
1772
+ tracer.record_temporal_constraint(start_dt, end_dt)
1773
+
1584
1774
  # Record retrieval results for tracer - per fact type
1585
1775
  if tracer:
1586
1776
  # Convert RetrievalResult to old tuple format for tracer
@@ -1588,8 +1778,10 @@ class MemoryEngine(MemoryEngineInterface):
1588
1778
  return [(r.id, r.__dict__) for r in results]
1589
1779
 
1590
1780
  # Add retrieval results per fact type (to show parallel execution in UI)
1591
- for idx, rr in enumerate(all_retrievals):
1592
- ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
1781
+ for ft_name in fact_type:
1782
+ rr = multi_result.results_by_fact_type.get(ft_name)
1783
+ if not rr:
1784
+ continue
1593
1785
 
1594
1786
  # Add semantic retrieval results for this fact type
1595
1787
  tracer.add_retrieval_results(
@@ -1621,14 +1813,22 @@ class MemoryEngine(MemoryEngineInterface):
1621
1813
  fact_type=ft_name,
1622
1814
  )
1623
1815
 
1624
- # Add temporal retrieval results for this fact type (even if empty, to show it ran)
1625
- if rr.temporal is not None:
1816
+ # Add temporal retrieval results for this fact type
1817
+ # Show temporal even with 0 results if constraint was detected
1818
+ if rr.temporal is not None or rr.temporal_constraint is not None:
1819
+ temporal_metadata = {"budget": thinking_budget}
1820
+ if rr.temporal_constraint:
1821
+ start_dt, end_dt = rr.temporal_constraint
1822
+ temporal_metadata["constraint"] = {
1823
+ "start": start_dt.isoformat() if start_dt else None,
1824
+ "end": end_dt.isoformat() if end_dt else None,
1825
+ }
1626
1826
  tracer.add_retrieval_results(
1627
1827
  method_name="temporal",
1628
- results=to_tuple_format(rr.temporal),
1828
+ results=to_tuple_format(rr.temporal or []),
1629
1829
  duration_seconds=rr.timings.get("temporal", 0.0),
1630
1830
  score_field="temporal_score",
1631
- metadata={"budget": thinking_budget},
1831
+ metadata=temporal_metadata,
1632
1832
  fact_type=ft_name,
1633
1833
  )
1634
1834
 
@@ -1678,11 +1878,24 @@ class MemoryEngine(MemoryEngineInterface):
1678
1878
  # Ensure reranker is initialized (for lazy initialization mode)
1679
1879
  await reranker_instance.ensure_initialized()
1680
1880
 
1881
+ # Pre-filter candidates to reduce reranking cost (RRF already provides good ranking)
1882
+ # This is especially important for remote rerankers with network latency
1883
+ reranker_max_candidates = get_config().reranker_max_candidates
1884
+ pre_filtered_count = 0
1885
+ if len(merged_candidates) > reranker_max_candidates:
1886
+ # Sort by RRF score and take top candidates
1887
+ merged_candidates.sort(key=lambda mc: mc.rrf_score, reverse=True)
1888
+ pre_filtered_count = len(merged_candidates) - reranker_max_candidates
1889
+ merged_candidates = merged_candidates[:reranker_max_candidates]
1890
+
1681
1891
  # Rerank using cross-encoder
1682
- scored_results = reranker_instance.rerank(query, merged_candidates)
1892
+ scored_results = await reranker_instance.rerank(query, merged_candidates)
1683
1893
 
1684
1894
  step_duration = time.time() - step_start
1685
- log_buffer.append(f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s")
1895
+ pre_filter_note = f" (pre-filtered {pre_filtered_count})" if pre_filtered_count > 0 else ""
1896
+ log_buffer.append(
1897
+ f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s{pre_filter_note}"
1898
+ )
1686
1899
 
1687
1900
  # Step 4.5: Combine cross-encoder score with retrieval signals
1688
1901
  # This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
@@ -1732,9 +1945,6 @@ class MemoryEngine(MemoryEngineInterface):
1732
1945
 
1733
1946
  # Re-sort by combined score
1734
1947
  scored_results.sort(key=lambda x: x.weight, reverse=True)
1735
- log_buffer.append(
1736
- " [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)"
1737
- )
1738
1948
 
1739
1949
  # Add reranked results to tracer AFTER combined scoring (so normalized values are included)
1740
1950
  if tracer:
@@ -1753,7 +1963,6 @@ class MemoryEngine(MemoryEngineInterface):
1753
1963
  # Step 5: Truncate to thinking_budget * 2 for token filtering
1754
1964
  rerank_limit = thinking_budget * 2
1755
1965
  top_scored = scored_results[:rerank_limit]
1756
- log_buffer.append(f" [5] Truncated to top {len(top_scored)} results")
1757
1966
 
1758
1967
  # Step 6: Token budget filtering
1759
1968
  step_start = time.time()
@@ -1768,7 +1977,7 @@ class MemoryEngine(MemoryEngineInterface):
1768
1977
 
1769
1978
  step_duration = time.time() - step_start
1770
1979
  log_buffer.append(
1771
- f" [6] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s"
1980
+ f" [5] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s"
1772
1981
  )
1773
1982
 
1774
1983
  if tracer:
@@ -1802,7 +2011,6 @@ class MemoryEngine(MemoryEngineInterface):
1802
2011
  visited_ids = list(set([sr.id for sr in scored_results[:50]])) # Top 50
1803
2012
  if visited_ids:
1804
2013
  await self._task_backend.submit_task({"type": "access_count_update", "node_ids": visited_ids})
1805
- log_buffer.append(f" [7] Queued access count updates for {len(visited_ids)} nodes")
1806
2014
 
1807
2015
  # Log fact_type distribution in results
1808
2016
  fact_type_counts = {}
@@ -1835,6 +2043,7 @@ class MemoryEngine(MemoryEngineInterface):
1835
2043
  top_results_dicts.append(result_dict)
1836
2044
 
1837
2045
  # Get entities for each fact if include_entities is requested
2046
+ step_start = time.time()
1838
2047
  fact_entity_map = {} # unit_id -> list of (entity_id, entity_name)
1839
2048
  if include_entities and top_scored:
1840
2049
  unit_ids = [uuid.UUID(sr.id) for sr in top_scored]
@@ -1856,6 +2065,7 @@ class MemoryEngine(MemoryEngineInterface):
1856
2065
  fact_entity_map[unit_id].append(
1857
2066
  {"entity_id": str(row["entity_id"]), "canonical_name": row["canonical_name"]}
1858
2067
  )
2068
+ entity_map_duration = time.time() - step_start
1859
2069
 
1860
2070
  # Convert results to MemoryFact objects
1861
2071
  memory_facts = []
@@ -1878,10 +2088,12 @@ class MemoryEngine(MemoryEngineInterface):
1878
2088
  mentioned_at=result_dict.get("mentioned_at"),
1879
2089
  document_id=result_dict.get("document_id"),
1880
2090
  chunk_id=result_dict.get("chunk_id"),
2091
+ tags=result_dict.get("tags"),
1881
2092
  )
1882
2093
  )
1883
2094
 
1884
2095
  # Fetch entity observations if requested
2096
+ step_start = time.time()
1885
2097
  entities_dict = None
1886
2098
  total_entity_tokens = 0
1887
2099
  total_chunk_tokens = 0
@@ -1902,7 +2114,13 @@ class MemoryEngine(MemoryEngineInterface):
1902
2114
  entities_ordered.append((entity_id, entity_name))
1903
2115
  seen_entity_ids.add(entity_id)
1904
2116
 
1905
- # Fetch observations for each entity (respect token budget, in order)
2117
+ # Fetch all observations in a single batched query
2118
+ entity_ids = [eid for eid, _ in entities_ordered]
2119
+ all_observations = await self.get_entity_observations_batch(
2120
+ bank_id, entity_ids, limit_per_entity=5, request_context=request_context
2121
+ )
2122
+
2123
+ # Build entities_dict respecting token budget, in relevance order
1906
2124
  entities_dict = {}
1907
2125
  encoding = _get_tiktoken_encoding()
1908
2126
 
@@ -1910,9 +2128,7 @@ class MemoryEngine(MemoryEngineInterface):
1910
2128
  if total_entity_tokens >= max_entity_tokens:
1911
2129
  break
1912
2130
 
1913
- observations = await self.get_entity_observations(
1914
- bank_id, entity_id, limit=5, request_context=request_context
1915
- )
2131
+ observations = all_observations.get(entity_id, [])
1916
2132
 
1917
2133
  # Calculate tokens for this entity's observations
1918
2134
  entity_tokens = 0
@@ -1930,8 +2146,10 @@ class MemoryEngine(MemoryEngineInterface):
1930
2146
  entity_id=entity_id, canonical_name=entity_name, observations=included_observations
1931
2147
  )
1932
2148
  total_entity_tokens += entity_tokens
2149
+ entity_obs_duration = time.time() - step_start
1933
2150
 
1934
2151
  # Fetch chunks if requested
2152
+ step_start = time.time()
1935
2153
  chunks_dict = None
1936
2154
  if include_chunks and top_scored:
1937
2155
  from .response_models import ChunkInfo
@@ -1991,6 +2209,12 @@ class MemoryEngine(MemoryEngineInterface):
1991
2209
  chunk_text=chunk_text, chunk_index=row["chunk_index"], truncated=False
1992
2210
  )
1993
2211
  total_chunk_tokens += chunk_tokens
2212
+ chunks_duration = time.time() - step_start
2213
+
2214
+ # Log entity/chunk fetch timing (only if any enrichment was requested)
2215
+ log_buffer.append(
2216
+ f" [6] Response enrichment: entity_map={entity_map_duration:.3f}s, entity_obs={entity_obs_duration:.3f}s, chunks={chunks_duration:.3f}s"
2217
+ )
1994
2218
 
1995
2219
  # Finalize trace if enabled
1996
2220
  trace_dict = None
@@ -2002,8 +2226,15 @@ class MemoryEngine(MemoryEngineInterface):
2002
2226
  total_time = time.time() - recall_start
2003
2227
  num_chunks = len(chunks_dict) if chunks_dict else 0
2004
2228
  num_entities = len(entities_dict) if entities_dict else 0
2229
+ # Include wait times in log if significant
2230
+ wait_parts = []
2231
+ if semaphore_wait > 0.01:
2232
+ wait_parts.append(f"sem={semaphore_wait:.3f}s")
2233
+ if max_conn_wait > 0.01:
2234
+ wait_parts.append(f"conn={max_conn_wait:.3f}s")
2235
+ wait_info = f" | waits: {', '.join(wait_parts)}" if wait_parts else ""
2005
2236
  log_buffer.append(
2006
- f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s"
2237
+ f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s{wait_info}"
2007
2238
  )
2008
2239
  logger.info("\n" + "\n".join(log_buffer))
2009
2240
 
@@ -2073,11 +2304,11 @@ class MemoryEngine(MemoryEngineInterface):
2073
2304
  doc = await conn.fetchrow(
2074
2305
  f"""
2075
2306
  SELECT d.id, d.bank_id, d.original_text, d.content_hash,
2076
- d.created_at, d.updated_at, COUNT(mu.id) as unit_count
2307
+ d.created_at, d.updated_at, d.tags, COUNT(mu.id) as unit_count
2077
2308
  FROM {fq_table("documents")} d
2078
2309
  LEFT JOIN {fq_table("memory_units")} mu ON mu.document_id = d.id
2079
2310
  WHERE d.id = $1 AND d.bank_id = $2
2080
- GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
2311
+ GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at, d.tags
2081
2312
  """,
2082
2313
  document_id,
2083
2314
  bank_id,
@@ -2094,6 +2325,7 @@ class MemoryEngine(MemoryEngineInterface):
2094
2325
  "memory_unit_count": doc["unit_count"],
2095
2326
  "created_at": doc["created_at"].isoformat() if doc["created_at"] else None,
2096
2327
  "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else None,
2328
+ "tags": list(doc["tags"]) if doc["tags"] else [],
2097
2329
  }
2098
2330
 
2099
2331
  async def delete_document(
@@ -2199,9 +2431,10 @@ class MemoryEngine(MemoryEngineInterface):
2199
2431
  await self._authenticate_tenant(request_context)
2200
2432
  pool = await self._get_pool()
2201
2433
  async with acquire_with_retry(pool) as conn:
2202
- # Ensure connection is not in read-only mode (can happen with connection poolers)
2203
- await conn.execute("SET SESSION CHARACTERISTICS AS TRANSACTION READ WRITE")
2204
2434
  async with conn.transaction():
2435
+ # Ensure transaction is not in read-only mode (can happen with connection poolers)
2436
+ # Using SET LOCAL so it only affects this transaction, not the session
2437
+ await conn.execute("SET LOCAL transaction_read_only TO off")
2205
2438
  try:
2206
2439
  if fact_type:
2207
2440
  # Delete only memories of a specific fact type
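The move from SET SESSION CHARACTERISTICS to SET LOCAL matters with pooled connections: a transaction-scoped setting cannot leak into whatever borrows the connection next. The pattern in isolation (table and filter are illustrative only):

```python
async def force_read_write_delete(pool, bank_id: str) -> None:
    async with pool.acquire() as conn:
        async with conn.transaction():
            # Reverts automatically when the transaction ends, so the pooled
            # connection is returned in its original state.
            await conn.execute("SET LOCAL transaction_read_only TO off")
            await conn.execute("DELETE FROM memory_units WHERE bank_id = $1", bank_id)
```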
@@ -2258,6 +2491,7 @@ class MemoryEngine(MemoryEngineInterface):
2258
2491
  bank_id: str | None = None,
2259
2492
  fact_type: str | None = None,
2260
2493
  *,
2494
+ limit: int = 1000,
2261
2495
  request_context: "RequestContext",
2262
2496
  ):
2263
2497
  """
@@ -2266,10 +2500,11 @@ class MemoryEngine(MemoryEngineInterface):
2266
2500
  Args:
2267
2501
  bank_id: Filter by bank ID
2268
2502
  fact_type: Filter by fact type (world, experience, opinion)
2503
+ limit: Maximum number of items to return (default: 1000)
2269
2504
  request_context: Request context for authentication.
2270
2505
 
2271
2506
  Returns:
2272
- Dict with nodes, edges, and table_rows
2507
+ Dict with nodes, edges, table_rows, total_units, and limit
2273
2508
  """
2274
2509
  await self._authenticate_tenant(request_context)
2275
2510
  pool = await self._get_pool()
@@ -2291,15 +2526,29 @@ class MemoryEngine(MemoryEngineInterface):
2291
2526
 
2292
2527
  where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
2293
2528
 
2529
+ # Get total count first
2530
+ total_count_result = await conn.fetchrow(
2531
+ f"""
2532
+ SELECT COUNT(*) as total
2533
+ FROM {fq_table("memory_units")}
2534
+ {where_clause}
2535
+ """,
2536
+ *query_params,
2537
+ )
2538
+ total_count = total_count_result["total"] if total_count_result else 0
2539
+
2540
+ # Get units with limit
2541
+ param_count += 1
2294
2542
  units = await conn.fetch(
2295
2543
  f"""
2296
2544
  SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
2297
2545
  FROM {fq_table("memory_units")}
2298
2546
  {where_clause}
2299
2547
  ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
2300
- LIMIT 1000
2548
+ LIMIT ${param_count}
2301
2549
  """,
2302
2550
  *query_params,
2551
+ limit,
2303
2552
  )
2304
2553
 
2305
2554
  # Get links, filtering to only include links between units of the selected agent
@@ -2436,7 +2685,7 @@ class MemoryEngine(MemoryEngineInterface):
2436
2685
  }
2437
2686
  )
2438
2687
 
2439
- return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": len(units)}
2688
+ return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": total_count, "limit": limit}
2440
2689
 
2441
2690
  async def list_memory_units(
2442
2691
  self,
@@ -2565,6 +2814,68 @@ class MemoryEngine(MemoryEngineInterface):
2565
2814
 
2566
2815
  return {"items": items, "total": total, "limit": limit, "offset": offset}
2567
2816
 
2817
+ async def get_memory_unit(
2818
+ self,
2819
+ bank_id: str,
2820
+ memory_id: str,
2821
+ request_context: "RequestContext",
2822
+ ):
2823
+ """
2824
+ Get a single memory unit by ID.
2825
+
2826
+ Args:
2827
+ bank_id: Bank ID
2828
+ memory_id: Memory unit ID
2829
+ request_context: Request context for authentication.
2830
+
2831
+ Returns:
2832
+ Dict with memory unit data or None if not found
2833
+ """
2834
+ await self._authenticate_tenant(request_context)
2835
+ pool = await self._get_pool()
2836
+ async with acquire_with_retry(pool) as conn:
2837
+ # Get the memory unit
2838
+ row = await conn.fetchrow(
2839
+ f"""
2840
+ SELECT id, text, context, event_date, occurred_start, occurred_end,
2841
+ mentioned_at, fact_type, document_id, chunk_id, tags
2842
+ FROM {fq_table("memory_units")}
2843
+ WHERE id = $1 AND bank_id = $2
2844
+ """,
2845
+ memory_id,
2846
+ bank_id,
2847
+ )
2848
+
2849
+ if not row:
2850
+ return None
2851
+
2852
+ # Get entity information
2853
+ entities_rows = await conn.fetch(
2854
+ f"""
2855
+ SELECT e.canonical_name
2856
+ FROM {fq_table("unit_entities")} ue
2857
+ JOIN {fq_table("entities")} e ON ue.entity_id = e.id
2858
+ WHERE ue.unit_id = $1
2859
+ """,
2860
+ row["id"],
2861
+ )
2862
+ entities = [r["canonical_name"] for r in entities_rows]
2863
+
2864
+ return {
2865
+ "id": str(row["id"]),
2866
+ "text": row["text"],
2867
+ "context": row["context"] if row["context"] else "",
2868
+ "date": row["event_date"].isoformat() if row["event_date"] else "",
2869
+ "type": row["fact_type"],
2870
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
2871
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
2872
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
2873
+ "entities": entities,
2874
+ "document_id": row["document_id"] if row["document_id"] else None,
2875
+ "chunk_id": str(row["chunk_id"]) if row["chunk_id"] else None,
2876
+ "tags": row["tags"] if row["tags"] else [],
2877
+ }
2878
+
2568
2879
  async def list_documents(
2569
2880
  self,
2570
2881
  bank_id: str,
@@ -2799,7 +3110,7 @@ Guidelines:
2799
3110
  - Small changes in confidence are normal; large jumps should be rare"""
2800
3111
 
2801
3112
  try:
2802
- result = await self._llm_config.call(
3113
+ result = await self._reflect_llm_config.call(
2803
3114
  messages=[
2804
3115
  {"role": "system", "content": "You evaluate and update opinions based on new information."},
2805
3116
  {"role": "user", "content": evaluation_prompt},
@@ -2909,7 +3220,7 @@ Guidelines:
2909
3220
  return
2910
3221
 
2911
3222
  # Use cached LLM config
2912
- if self._llm_config is None:
3223
+ if self._reflect_llm_config is None:
2913
3224
  logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
2914
3225
  return
2915
3226
 
@@ -3054,7 +3365,9 @@ Guidelines:
3054
3365
  """
3055
3366
  await self._authenticate_tenant(request_context)
3056
3367
  pool = await self._get_pool()
3057
- return await bank_utils.merge_bank_background(pool, self._llm_config, bank_id, new_info, update_disposition)
3368
+ return await bank_utils.merge_bank_background(
3369
+ pool, self._reflect_llm_config, bank_id, new_info, update_disposition
3370
+ )
3058
3371
 
3059
3372
  async def list_banks(
3060
3373
  self,
@@ -3086,6 +3399,8 @@ Guidelines:
3086
3399
  max_tokens: int = 4096,
3087
3400
  response_schema: dict | None = None,
3088
3401
  request_context: "RequestContext",
3402
+ tags: list[str] | None = None,
3403
+ tags_match: TagsMatch = "any",
3089
3404
  ) -> ReflectResult:
3090
3405
  """
3091
3406
  Reflect and formulate an answer using bank identity, world facts, and opinions.
@@ -3114,7 +3429,7 @@ Guidelines:
3114
3429
  - structured_output: Optional dict if response_schema was provided
3115
3430
  """
3116
3431
  # Use cached LLM config
3117
- if self._llm_config is None:
3432
+ if self._reflect_llm_config is None:
3118
3433
  raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
3119
3434
 
3120
3435
  # Authenticate tenant and set schema in context (for fq_table())
@@ -3140,16 +3455,22 @@ Guidelines:
3140
3455
 
3141
3456
  # Steps 1-3: Run multi-fact-type search (12-way retrieval: 4 methods × 3 fact types)
3142
3457
  recall_start = time.time()
3143
- search_result = await self.recall_async(
3144
- bank_id=bank_id,
3145
- query=query,
3146
- budget=budget,
3147
- max_tokens=4096,
3148
- enable_trace=False,
3149
- fact_type=["experience", "world", "opinion"],
3150
- include_entities=True,
3151
- request_context=request_context,
3152
- )
3458
+ metrics = get_metrics_collector()
3459
+ with metrics.record_operation(
3460
+ "recall", bank_id=bank_id, source="reflect", budget=budget.value if budget else None
3461
+ ):
3462
+ search_result = await self.recall_async(
3463
+ bank_id=bank_id,
3464
+ query=query,
3465
+ budget=budget,
3466
+ max_tokens=4096,
3467
+ enable_trace=False,
3468
+ fact_type=["experience", "world", "opinion"],
3469
+ include_entities=True,
3470
+ request_context=request_context,
3471
+ tags=tags,
3472
+ tags_match=tags_match,
3473
+ )
3153
3474
  recall_time = time.time() - recall_start
3154
3475
 
3155
3476
  all_results = search_result.results
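get_metrics_collector().record_operation(...) is defined in hindsight_api/metrics.py (+433 lines in this release), which the diff does not show. Judging only from how it is used here, it behaves like a labeled timing context manager; a minimal stand-in under that assumption:

```python
import time
from contextlib import contextmanager


class MetricsCollector:
    """Toy stand-in that records (operation, labels, duration) tuples in memory."""

    def __init__(self) -> None:
        self.samples: list[tuple[str, dict, float]] = []

    @contextmanager
    def record_operation(self, operation: str, **labels):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.samples.append((operation, labels, time.perf_counter() - start))
```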
@@ -3205,7 +3526,7 @@ Guidelines:
3205
3526
  response_format = JsonSchemaWrapper(response_schema)
3206
3527
 
3207
3528
  llm_start = time.time()
3208
- result = await self._llm_config.call(
3529
+ llm_result, usage = await self._reflect_llm_config.call(
3209
3530
  messages=messages,
3210
3531
  scope="memory_reflect",
3211
3532
  max_completion_tokens=max_tokens,
@@ -3214,17 +3535,18 @@ Guidelines:
3214
3535
  # Don't enforce strict_schema - not all providers support it and may retry forever
3215
3536
  # Soft enforcement (schema in prompt + json_object mode) is sufficient
3216
3537
  strict_schema=False,
3538
+ return_usage=True,
3217
3539
  )
3218
3540
  llm_time = time.time() - llm_start
3219
3541
 
3220
3542
  # Handle response based on whether structured output was requested
3221
3543
  if response_schema is not None:
3222
- structured_output = result
3544
+ structured_output = llm_result
3223
3545
  answer_text = "" # Empty for backward compatibility
3224
3546
  log_buffer.append(f"[REFLECT {reflect_id}] Structured output generated")
3225
3547
  else:
3226
3548
  structured_output = None
3227
- answer_text = result.strip()
3549
+ answer_text = llm_result.strip()
3228
3550
 
3229
3551
  # Submit form_opinion task for background processing
3230
3552
  # Pass tenant_id from request context for internal authentication in background task
@@ -3250,6 +3572,7 @@ Guidelines:
3250
3572
  based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
3251
3573
  new_opinions=[], # Opinions are being extracted asynchronously
3252
3574
  structured_output=structured_output,
3575
+ usage=usage,
3253
3576
  )
3254
3577
 
3255
3578
  # Call post-operation hook if validator is configured
@@ -3289,7 +3612,9 @@ Guidelines:
3289
3612
  """
3290
3613
  try:
3291
3614
  # Extract opinions from the answer
3292
- new_opinions = await think_utils.extract_opinions_from_text(self._llm_config, text=answer_text, query=query)
3615
+ new_opinions = await think_utils.extract_opinions_from_text(
3616
+ self._reflect_llm_config, text=answer_text, query=query
3617
+ )
3293
3618
 
3294
3619
  # Store new opinions
3295
3620
  if new_opinions:
@@ -3360,37 +3685,110 @@ Guidelines:
3360
3685
  observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
3361
3686
  return observations
3362
3687
 
3688
+ async def get_entity_observations_batch(
3689
+ self,
3690
+ bank_id: str,
3691
+ entity_ids: list[str],
3692
+ *,
3693
+ limit_per_entity: int = 5,
3694
+ request_context: "RequestContext",
3695
+ ) -> dict[str, list[Any]]:
3696
+ """
3697
+ Get observations for multiple entities in a single query.
3698
+
3699
+ Args:
3700
+ bank_id: bank IDentifier
3701
+ entity_ids: List of entity UUIDs to get observations for
3702
+ limit_per_entity: Maximum observations per entity
3703
+ request_context: Request context for authentication.
3704
+
3705
+ Returns:
3706
+ Dict mapping entity_id -> list of EntityObservation objects
3707
+ """
3708
+ if not entity_ids:
3709
+ return {}
3710
+
3711
+ await self._authenticate_tenant(request_context)
3712
+ pool = await self._get_pool()
3713
+ async with acquire_with_retry(pool) as conn:
3714
+ # Use window function to limit observations per entity
3715
+ rows = await conn.fetch(
3716
+ f"""
3717
+ WITH ranked AS (
3718
+ SELECT
3719
+ ue.entity_id,
3720
+ mu.text,
3721
+ mu.mentioned_at,
3722
+ ROW_NUMBER() OVER (PARTITION BY ue.entity_id ORDER BY mu.mentioned_at DESC) as rn
3723
+ FROM {fq_table("memory_units")} mu
3724
+ JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
3725
+ WHERE mu.bank_id = $1
3726
+ AND mu.fact_type = 'observation'
3727
+ AND ue.entity_id = ANY($2::uuid[])
3728
+ )
3729
+ SELECT entity_id, text, mentioned_at
3730
+ FROM ranked
3731
+ WHERE rn <= $3
3732
+ ORDER BY entity_id, rn
3733
+ """,
3734
+ bank_id,
3735
+ [uuid.UUID(eid) for eid in entity_ids],
3736
+ limit_per_entity,
3737
+ )
3738
+
3739
+ result: dict[str, list[Any]] = {eid: [] for eid in entity_ids}
3740
+ for row in rows:
3741
+ entity_id = str(row["entity_id"])
3742
+ mentioned_at = row["mentioned_at"].isoformat() if row["mentioned_at"] else None
3743
+ result[entity_id].append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
3744
+ return result
3745
+
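Editor's note: the batch method above replaces N per-entity queries with one query that ranks observations per entity via ROW_NUMBER(). A hedged usage sketch, assuming an initialized engine, a valid RequestContext, and that EntityObservation exposes its text field as an attribute (none of which are constructed in this hunk).

# Usage sketch for the batched observation lookup added above; "engine" and
# "request_context" are assumed to exist already, and the entity IDs would
# normally come from list_entities() or entity resolution.
async def print_entity_observations(engine, bank_id, entity_ids, request_context):
    observations_by_entity = await engine.get_entity_observations_batch(
        bank_id,
        entity_ids,
        limit_per_entity=5,  # keep only the 5 most recent observations per entity
        request_context=request_context,
    )
    for entity_id, observations in observations_by_entity.items():
        # Entities with no stored observations still appear here with an empty
        # list, because the method pre-seeds its result dict from entity_ids.
        print(entity_id, [obs.text for obs in observations])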
3363
3746
  async def list_entities(
3364
3747
  self,
3365
3748
  bank_id: str,
3366
3749
  *,
3367
3750
  limit: int = 100,
3751
+ offset: int = 0,
3368
3752
  request_context: "RequestContext",
3369
- ) -> list[dict[str, Any]]:
3753
+ ) -> dict[str, Any]:
3370
3754
  """
3371
- List all entities for a bank.
3755
+ List all entities for a bank with pagination.
3372
3756
 
3373
3757
  Args:
3374
3758
  bank_id: bank IDentifier
3375
3759
  limit: Maximum number of entities to return
3760
+ offset: Offset for pagination
3376
3761
  request_context: Request context for authentication.
3377
3762
 
3378
3763
  Returns:
3379
- List of entity dicts with id, canonical_name, mention_count, first_seen, last_seen
3764
+ Dict with items, total, limit, offset
3380
3765
  """
3381
3766
  await self._authenticate_tenant(request_context)
3382
3767
  pool = await self._get_pool()
3383
3768
  async with acquire_with_retry(pool) as conn:
3769
+ # Get total count
3770
+ total_row = await conn.fetchrow(
3771
+ f"""
3772
+ SELECT COUNT(*) as total
3773
+ FROM {fq_table("entities")}
3774
+ WHERE bank_id = $1
3775
+ """,
3776
+ bank_id,
3777
+ )
3778
+ total = total_row["total"] if total_row else 0
3779
+
3780
+ # Get paginated entities
3384
3781
  rows = await conn.fetch(
3385
3782
  f"""
3386
3783
  SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
3387
3784
  FROM {fq_table("entities")}
3388
3785
  WHERE bank_id = $1
3389
3786
  ORDER BY mention_count DESC, last_seen DESC
3390
- LIMIT $2
3787
+ LIMIT $2 OFFSET $3
3391
3788
  """,
3392
3789
  bank_id,
3393
3790
  limit,
3791
+ offset,
3394
3792
  )
3395
3793
 
3396
3794
  entities = []
@@ -3417,7 +3815,91 @@ Guidelines:
3417
3815
  "metadata": metadata,
3418
3816
  }
3419
3817
  )
3420
- return entities
3818
+ return {
3819
+ "items": entities,
3820
+ "total": total,
3821
+ "limit": limit,
3822
+ "offset": offset,
3823
+ }
3824
+
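Editor's note: since list_entities now returns a pagination envelope instead of a bare list, callers page by offset until they reach the reported total. A minimal sketch under the same assumptions as above (engine and request_context provided elsewhere); note the total comes from a separate COUNT query, so it can drift slightly if entities are written concurrently.

# Page through every entity in a bank using the {items, total, limit, offset}
# envelope that list_entities() now returns.
async def iter_all_entities(engine, bank_id, request_context, page_size=100):
    offset = 0
    while True:
        page = await engine.list_entities(
            bank_id,
            limit=page_size,
            offset=offset,
            request_context=request_context,
        )
        for entity in page["items"]:
            yield entity
        offset += page_size
        if offset >= page["total"]:
            break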
3825
+ async def list_tags(
3826
+ self,
3827
+ bank_id: str,
3828
+ *,
3829
+ pattern: str | None = None,
3830
+ limit: int = 100,
3831
+ offset: int = 0,
3832
+ request_context: "RequestContext",
3833
+ ) -> dict[str, Any]:
3834
+ """
3835
+ List all unique tags for a bank with usage counts.
3836
+
3837
+ Use this to discover available tags or expand wildcard patterns.
3838
+ Supports '*' as wildcard for flexible matching (case-insensitive):
3839
+ - 'user:*' matches user:alice, user:bob
3840
+ - '*-admin' matches role-admin, super-admin
3841
+ - 'env*-prod' matches env-prod, environment-prod
3842
+
3843
+ Args:
3844
+ bank_id: Bank identifier
3845
+ pattern: Wildcard pattern to filter tags (use '*' as wildcard, case-insensitive)
3846
+ limit: Maximum number of tags to return
3847
+ offset: Offset for pagination
3848
+ request_context: Request context for authentication.
3849
+
3850
+ Returns:
3851
+ Dict with items (list of {tag, count}), total, limit, offset
3852
+ """
3853
+ await self._authenticate_tenant(request_context)
3854
+ pool = await self._get_pool()
3855
+ async with acquire_with_retry(pool) as conn:
3856
+ # Build pattern filter if provided (convert * to % for ILIKE)
3857
+ pattern_clause = ""
3858
+ params: list[Any] = [bank_id]
3859
+ if pattern:
3860
+ # Convert wildcard pattern: * -> % for SQL ILIKE
3861
+ sql_pattern = pattern.replace("*", "%")
3862
+ pattern_clause = "AND tag ILIKE $2"
3863
+ params.append(sql_pattern)
3864
+
3865
+ # Get total count of distinct tags matching pattern
3866
+ total_row = await conn.fetchrow(
3867
+ f"""
3868
+ SELECT COUNT(DISTINCT tag) as total
3869
+ FROM {fq_table("memory_units")}, unnest(tags) AS tag
3870
+ WHERE bank_id = $1 AND tags IS NOT NULL AND tags != '{{}}'
3871
+ {pattern_clause}
3872
+ """,
3873
+ *params,
3874
+ )
3875
+ total = total_row["total"] if total_row else 0
3876
+
3877
+ # Get paginated tags with counts, ordered by frequency
3878
+ limit_param = len(params) + 1
3879
+ offset_param = len(params) + 2
3880
+ params.extend([limit, offset])
3881
+
3882
+ rows = await conn.fetch(
3883
+ f"""
3884
+ SELECT tag, COUNT(*) as count
3885
+ FROM {fq_table("memory_units")}, unnest(tags) AS tag
3886
+ WHERE bank_id = $1 AND tags IS NOT NULL AND tags != '{{}}'
3887
+ {pattern_clause}
3888
+ GROUP BY tag
3889
+ ORDER BY count DESC, tag ASC
3890
+ LIMIT ${limit_param} OFFSET ${offset_param}
3891
+ """,
3892
+ *params,
3893
+ )
3894
+
3895
+ items = [{"tag": row["tag"], "count": row["count"]} for row in rows]
3896
+
3897
+ return {
3898
+ "items": items,
3899
+ "total": total,
3900
+ "limit": limit,
3901
+ "offset": offset,
3902
+ }
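Editor's note: because list_tags translates '*' into SQL '%' and matches with ILIKE, a single call can expand a tag namespace case-insensitively. A hedged usage sketch; the "user:*" namespace and tag values are purely illustrative.

# Expand a wildcard tag namespace using the new list_tags() method. '*' is the
# documented wildcard and matching is case-insensitive via ILIKE.
async def expand_tag_namespace(engine, bank_id, request_context):
    page = await engine.list_tags(
        bank_id,
        pattern="user:*",  # would match user:alice, user:bob, USER:carol, ...
        limit=50,
        offset=0,
        request_context=request_context,
    )
    # Items arrive ordered by descending usage count, then tag name.
    return {item["tag"]: item["count"] for item in page["items"]}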
3421
3903
 
3422
3904
  async def get_entity_state(
3423
3905
  self,
@@ -3540,7 +4022,9 @@ Guidelines:
3540
4022
  )
3541
4023
 
3542
4024
  # Step 3: Extract observations using LLM (no personality)
3543
- observations = await observation_utils.extract_observations_from_facts(self._llm_config, entity_name, facts)
4025
+ observations = await observation_utils.extract_observations_from_facts(
4026
+ self._reflect_llm_config, entity_name, facts
4027
+ )
3544
4028
 
3545
4029
  if not observations:
3546
4030
  return []
@@ -4061,6 +4545,7 @@ Guidelines:
4061
4545
  contents: list[dict[str, Any]],
4062
4546
  *,
4063
4547
  request_context: "RequestContext",
4548
+ document_tags: list[str] | None = None,
4064
4549
  ) -> dict[str, Any]:
4065
4550
  """Submit a batch retain operation to run asynchronously."""
4066
4551
  await self._authenticate_tenant(request_context)
@@ -4084,14 +4569,16 @@ Guidelines:
4084
4569
  )
4085
4570
 
4086
4571
  # Submit task to background queue
4087
- await self._task_backend.submit_task(
4088
- {
4089
- "type": "batch_retain",
4090
- "operation_id": str(operation_id),
4091
- "bank_id": bank_id,
4092
- "contents": contents,
4093
- }
4094
- )
4572
+ task_payload = {
4573
+ "type": "batch_retain",
4574
+ "operation_id": str(operation_id),
4575
+ "bank_id": bank_id,
4576
+ "contents": contents,
4577
+ }
4578
+ if document_tags:
4579
+ task_payload["document_tags"] = document_tags
4580
+
4581
+ await self._task_backend.submit_task(task_payload)
4095
4582
 
4096
4583
  logger.info(f"Retain task queued for bank_id={bank_id}, {len(contents)} items, operation_id={operation_id}")
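Editor's note: the submission path above only adds document_tags to the task payload when a non-empty list is provided, so existing task consumers see no new key otherwise. A hedged caller sketch; the public method name and the shape of each content item are not visible in this hunk, so submit_batch_retain() and the "content" key below are assumptions.

# Attach document-level tags to a batch retain submission (tags are hypothetical).
async def retain_with_tags(engine, bank_id, contents, request_context):
    return await engine.submit_batch_retain(  # method name assumed, not shown here
        bank_id,
        contents,  # e.g. [{"content": "..."}] -- item schema assumed
        request_context=request_context,
        document_tags=["source:import", "batch:2024-q3"],
    )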
4097
4584