hindsight-api 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +311 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/alembic/versions/h3c4d5e6f7g8_mental_models_v4.py +112 -0
  6. hindsight_api/alembic/versions/i4d5e6f7g8h9_delete_opinions.py +41 -0
  7. hindsight_api/alembic/versions/j5e6f7g8h9i0_mental_model_versions.py +95 -0
  8. hindsight_api/alembic/versions/k6f7g8h9i0j1_add_directive_subtype.py +58 -0
  9. hindsight_api/alembic/versions/l7g8h9i0j1k2_add_worker_columns.py +109 -0
  10. hindsight_api/alembic/versions/m8h9i0j1k2l3_mental_model_id_to_text.py +41 -0
  11. hindsight_api/alembic/versions/n9i0j1k2l3m4_learnings_and_pinned_reflections.py +134 -0
  12. hindsight_api/alembic/versions/o0j1k2l3m4n5_migrate_mental_models_data.py +113 -0
  13. hindsight_api/alembic/versions/p1k2l3m4n5o6_new_knowledge_architecture.py +194 -0
  14. hindsight_api/alembic/versions/q2l3m4n5o6p7_fix_mental_model_fact_type.py +50 -0
  15. hindsight_api/alembic/versions/r3m4n5o6p7q8_add_reflect_response_to_reflections.py +47 -0
  16. hindsight_api/alembic/versions/s4n5o6p7q8r9_add_consolidated_at_to_memory_units.py +53 -0
  17. hindsight_api/alembic/versions/t5o6p7q8r9s0_rename_mental_models_to_observations.py +134 -0
  18. hindsight_api/alembic/versions/u6p7q8r9s0t1_mental_models_text_id.py +41 -0
  19. hindsight_api/alembic/versions/v7q8r9s0t1u2_add_max_tokens_to_mental_models.py +50 -0
  20. hindsight_api/api/http.py +1406 -118
  21. hindsight_api/api/mcp.py +11 -196
  22. hindsight_api/config.py +359 -27
  23. hindsight_api/engine/consolidation/__init__.py +5 -0
  24. hindsight_api/engine/consolidation/consolidator.py +859 -0
  25. hindsight_api/engine/consolidation/prompts.py +69 -0
  26. hindsight_api/engine/cross_encoder.py +706 -88
  27. hindsight_api/engine/db_budget.py +284 -0
  28. hindsight_api/engine/db_utils.py +11 -0
  29. hindsight_api/engine/directives/__init__.py +5 -0
  30. hindsight_api/engine/directives/models.py +37 -0
  31. hindsight_api/engine/embeddings.py +553 -29
  32. hindsight_api/engine/entity_resolver.py +8 -5
  33. hindsight_api/engine/interface.py +40 -17
  34. hindsight_api/engine/llm_wrapper.py +744 -68
  35. hindsight_api/engine/memory_engine.py +2505 -1017
  36. hindsight_api/engine/mental_models/__init__.py +14 -0
  37. hindsight_api/engine/mental_models/models.py +53 -0
  38. hindsight_api/engine/query_analyzer.py +4 -3
  39. hindsight_api/engine/reflect/__init__.py +18 -0
  40. hindsight_api/engine/reflect/agent.py +933 -0
  41. hindsight_api/engine/reflect/models.py +109 -0
  42. hindsight_api/engine/reflect/observations.py +186 -0
  43. hindsight_api/engine/reflect/prompts.py +483 -0
  44. hindsight_api/engine/reflect/tools.py +437 -0
  45. hindsight_api/engine/reflect/tools_schema.py +250 -0
  46. hindsight_api/engine/response_models.py +168 -4
  47. hindsight_api/engine/retain/bank_utils.py +79 -201
  48. hindsight_api/engine/retain/fact_extraction.py +424 -195
  49. hindsight_api/engine/retain/fact_storage.py +35 -12
  50. hindsight_api/engine/retain/link_utils.py +29 -24
  51. hindsight_api/engine/retain/orchestrator.py +24 -43
  52. hindsight_api/engine/retain/types.py +11 -2
  53. hindsight_api/engine/search/graph_retrieval.py +43 -14
  54. hindsight_api/engine/search/link_expansion_retrieval.py +391 -0
  55. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  56. hindsight_api/engine/search/reranking.py +2 -2
  57. hindsight_api/engine/search/retrieval.py +848 -201
  58. hindsight_api/engine/search/tags.py +172 -0
  59. hindsight_api/engine/search/think_utils.py +42 -141
  60. hindsight_api/engine/search/trace.py +12 -1
  61. hindsight_api/engine/search/tracer.py +26 -6
  62. hindsight_api/engine/search/types.py +21 -3
  63. hindsight_api/engine/task_backend.py +113 -106
  64. hindsight_api/engine/utils.py +1 -152
  65. hindsight_api/extensions/__init__.py +10 -1
  66. hindsight_api/extensions/builtin/tenant.py +5 -1
  67. hindsight_api/extensions/context.py +10 -1
  68. hindsight_api/extensions/operation_validator.py +81 -4
  69. hindsight_api/extensions/tenant.py +26 -0
  70. hindsight_api/main.py +69 -6
  71. hindsight_api/mcp_local.py +12 -53
  72. hindsight_api/mcp_tools.py +494 -0
  73. hindsight_api/metrics.py +433 -48
  74. hindsight_api/migrations.py +141 -1
  75. hindsight_api/models.py +3 -3
  76. hindsight_api/pg0.py +53 -0
  77. hindsight_api/server.py +39 -2
  78. hindsight_api/worker/__init__.py +11 -0
  79. hindsight_api/worker/main.py +296 -0
  80. hindsight_api/worker/poller.py +486 -0
  81. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/METADATA +16 -6
  82. hindsight_api-0.4.0.dist-info/RECORD +112 -0
  83. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/entry_points.txt +2 -0
  84. hindsight_api/engine/retain/observation_regeneration.py +0 -254
  85. hindsight_api/engine/search/observation_utils.py +0 -125
  86. hindsight_api/engine/search/scoring.py +0 -159
  87. hindsight_api-0.2.1.dist-info/RECORD +0 -75
  88. {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/WHEEL +0 -0
@@ -11,6 +11,7 @@ This implements a sophisticated memory architecture that combines:
 
  import asyncio
  import contextvars
+ import json
  import logging
  import time
  import uuid
@@ -18,6 +19,8 @@ from datetime import UTC, datetime, timedelta
  from typing import TYPE_CHECKING, Any
 
  from ..config import get_config
+ from ..metrics import get_metrics_collector
+ from .db_budget import budgeted_operation
 
  # Context variable for current schema (async-safe, per-task isolation)
  _current_schema: contextvars.ContextVar[str] = contextvars.ContextVar("current_schema", default="public")
@@ -132,17 +135,31 @@ if TYPE_CHECKING:
 
  from enum import Enum
 
- from ..pg0 import EmbeddedPostgres
+ from ..metrics import get_metrics_collector
+ from ..pg0 import EmbeddedPostgres, parse_pg0_url
  from .entity_resolver import EntityResolver
  from .llm_wrapper import LLMConfig
  from .query_analyzer import QueryAnalyzer
- from .response_models import VALID_RECALL_FACT_TYPES, EntityObservation, EntityState, MemoryFact, ReflectResult
+ from .reflect import run_reflect_agent
+ from .reflect.tools import tool_expand, tool_recall, tool_search_mental_models, tool_search_observations
+ from .response_models import (
+ VALID_RECALL_FACT_TYPES,
+ EntityObservation,
+ EntityState,
+ LLMCallTrace,
+ MemoryFact,
+ ObservationRef,
+ ReflectResult,
+ TokenUsage,
+ ToolCallTrace,
+ )
  from .response_models import RecallResult as RecallResultModel
  from .retain import bank_utils, embedding_utils
  from .retain.types import RetainContentDict
- from .search import observation_utils, think_utils
+ from .search import think_utils
  from .search.reranking import CrossEncoderReranker
- from .task_backend import AsyncIOQueueBackend, TaskBackend
+ from .search.tags import TagsMatch
+ from .task_backend import BrokerTaskBackend, SyncTaskBackend, TaskBackend
 
 
  class Budget(str, Enum):
@@ -195,11 +212,26 @@ class MemoryEngine(MemoryEngineInterface):
  memory_llm_api_key: str | None = None,
  memory_llm_model: str | None = None,
  memory_llm_base_url: str | None = None,
+ # Per-operation LLM config (optional, falls back to memory_llm_* params)
+ retain_llm_provider: str | None = None,
+ retain_llm_api_key: str | None = None,
+ retain_llm_model: str | None = None,
+ retain_llm_base_url: str | None = None,
+ reflect_llm_provider: str | None = None,
+ reflect_llm_api_key: str | None = None,
+ reflect_llm_model: str | None = None,
+ reflect_llm_base_url: str | None = None,
+ consolidation_llm_provider: str | None = None,
+ consolidation_llm_api_key: str | None = None,
+ consolidation_llm_model: str | None = None,
+ consolidation_llm_base_url: str | None = None,
  embeddings: Embeddings | None = None,
  cross_encoder: CrossEncoderModel | None = None,
  query_analyzer: QueryAnalyzer | None = None,
- pool_min_size: int = 5,
- pool_max_size: int = 100,
+ pool_min_size: int | None = None,
+ pool_max_size: int | None = None,
+ db_command_timeout: int | None = None,
+ db_acquire_timeout: int | None = None,
  task_backend: TaskBackend | None = None,
  run_migrations: bool = True,
  operation_validator: "OperationValidatorExtension | None" = None,
@@ -220,12 +252,26 @@ class MemoryEngine(MemoryEngineInterface):
  memory_llm_api_key: API key for the LLM provider. Defaults to HINDSIGHT_API_LLM_API_KEY env var.
  memory_llm_model: Model name. Defaults to HINDSIGHT_API_LLM_MODEL env var.
  memory_llm_base_url: Base URL for the LLM API. Defaults based on provider.
+ retain_llm_provider: LLM provider for retain operations. Falls back to memory_llm_provider.
+ retain_llm_api_key: API key for retain LLM. Falls back to memory_llm_api_key.
+ retain_llm_model: Model for retain operations. Falls back to memory_llm_model.
+ retain_llm_base_url: Base URL for retain LLM. Falls back to memory_llm_base_url.
+ reflect_llm_provider: LLM provider for reflect operations. Falls back to memory_llm_provider.
+ reflect_llm_api_key: API key for reflect LLM. Falls back to memory_llm_api_key.
+ reflect_llm_model: Model for reflect operations. Falls back to memory_llm_model.
+ reflect_llm_base_url: Base URL for reflect LLM. Falls back to memory_llm_base_url.
+ consolidation_llm_provider: LLM provider for consolidation operations. Falls back to memory_llm_provider.
+ consolidation_llm_api_key: API key for consolidation LLM. Falls back to memory_llm_api_key.
+ consolidation_llm_model: Model for consolidation operations. Falls back to memory_llm_model.
+ consolidation_llm_base_url: Base URL for consolidation LLM. Falls back to memory_llm_base_url.
  embeddings: Embeddings implementation. If not provided, created from env vars.
  cross_encoder: Cross-encoder model. If not provided, created from env vars.
  query_analyzer: Query analyzer implementation. If not provided, uses DateparserQueryAnalyzer.
- pool_min_size: Minimum number of connections in the pool (default: 5)
- pool_max_size: Maximum number of connections in the pool (default: 100)
- task_backend: Custom task backend. If not provided, uses AsyncIOQueueBackend.
+ pool_min_size: Minimum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MIN_SIZE.
+ pool_max_size: Maximum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MAX_SIZE.
+ db_command_timeout: PostgreSQL command timeout in seconds. Defaults to HINDSIGHT_API_DB_COMMAND_TIMEOUT.
+ db_acquire_timeout: Connection acquisition timeout in seconds. Defaults to HINDSIGHT_API_DB_ACQUIRE_TIMEOUT.
+ task_backend: Custom task backend. If not provided, uses BrokerTaskBackend for distributed processing.
  run_migrations: Whether to run database migrations during initialize(). Default: True
  operation_validator: Optional extension to validate operations before execution.
  If provided, retain/recall/reflect operations will be validated.
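A minimal construction sketch for the per-operation LLM fallbacks documented above. The import path, provider value, and model names are illustrative assumptions (nothing beyond the parameter names comes from this diff); any parameter left unset falls back to the corresponding memory_llm_* value.

    from hindsight_api.engine.memory_engine import MemoryEngine  # assumed import path

    engine = MemoryEngine(
        db_url="pg0",                         # embedded Postgres instance (see the pg0 handling below)
        memory_llm_provider="ollama",         # default/fallback config; ollama needs no API key
        memory_llm_model="<default-model>",   # placeholder model name
        retain_llm_model="<stronger-model>",  # fact extraction benefits from strong structured output
        reflect_llm_model="<lighter-model>",  # reflect can use a lighter model
        # consolidation_llm_* left unset: falls back to the memory_llm_* values above
    )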
@@ -252,38 +298,21 @@ class MemoryEngine(MemoryEngineInterface):
  db_url = db_url or config.database_url
  memory_llm_provider = memory_llm_provider or config.llm_provider
  memory_llm_api_key = memory_llm_api_key or config.llm_api_key
- # Ollama doesn't require an API key
- if not memory_llm_api_key and memory_llm_provider != "ollama":
+ # Ollama and mock don't require an API key
+ if not memory_llm_api_key and memory_llm_provider not in ("ollama", "mock"):
  raise ValueError("LLM API key is required. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
  memory_llm_model = memory_llm_model or config.llm_model
  memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
  # Track pg0 instance (if used)
  self._pg0: EmbeddedPostgres | None = None
- self._pg0_instance_name: str | None = None
 
  # Initialize PostgreSQL connection URL
  # The actual URL will be set during initialize() after starting the server
  # Supports: "pg0" (default instance), "pg0://instance-name" (named instance), or regular postgresql:// URL
- if db_url == "pg0":
- self._use_pg0 = True
- self._pg0_instance_name = "hindsight"
- self._pg0_port = None # Use default port
- self.db_url = None
- elif db_url.startswith("pg0://"):
- self._use_pg0 = True
- # Parse instance name and optional port: pg0://instance-name or pg0://instance-name:port
- url_part = db_url[6:] # Remove "pg0://"
- if ":" in url_part:
- self._pg0_instance_name, port_str = url_part.rsplit(":", 1)
- self._pg0_port = int(port_str)
- else:
- self._pg0_instance_name = url_part or "hindsight"
- self._pg0_port = None # Use default port
+ self._use_pg0, self._pg0_instance_name, self._pg0_port = parse_pg0_url(db_url)
+ if self._use_pg0:
  self.db_url = None
  else:
- self._use_pg0 = False
- self._pg0_instance_name = None
- self._pg0_port = None
  self.db_url = db_url
 
  # Set default base URL if not provided
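The parse_pg0_url helper (added to hindsight_api/pg0.py in this release, not expanded in this diff) replaces the inline branches removed above. A sketch of the (use_pg0, instance_name, port) tuples it is expected to return, inferred from that removed logic; the instance names are placeholders:

    parse_pg0_url("pg0")                           # -> (True, "hindsight", None)   default instance
    parse_pg0_url("pg0://analytics")               # -> (True, "analytics", None)   named instance
    parse_pg0_url("pg0://analytics:55432")         # -> (True, "analytics", 55432)  named instance with explicit port
    parse_pg0_url("postgresql://user:pw@host/db")  # -> (False, None, None)         regular URL, used as-is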
@@ -298,8 +327,10 @@ class MemoryEngine(MemoryEngineInterface):
  # Connection pool (will be created in initialize())
  self._pool = None
  self._initialized = False
- self._pool_min_size = pool_min_size
- self._pool_max_size = pool_max_size
+ self._pool_min_size = pool_min_size if pool_min_size is not None else config.db_pool_min_size
+ self._pool_max_size = pool_max_size if pool_max_size is not None else config.db_pool_max_size
+ self._db_command_timeout = db_command_timeout if db_command_timeout is not None else config.db_command_timeout
+ self._db_acquire_timeout = db_acquire_timeout if db_acquire_timeout is not None else config.db_acquire_timeout
  self._run_migrations = run_migrations
 
  # Initialize entity resolver (will be created in initialize())
@@ -319,7 +350,7 @@ class MemoryEngine(MemoryEngineInterface):
 
  self.query_analyzer = DateparserQueryAnalyzer()
 
- # Initialize LLM configuration
+ # Initialize LLM configuration (default, used as fallback)
  self._llm_config = LLMConfig(
  provider=memory_llm_provider,
  api_key=memory_llm_api_key,
@@ -331,17 +362,84 @@ class MemoryEngine(MemoryEngineInterface):
  self._llm_client = self._llm_config._client
  self._llm_model = self._llm_config.model
 
+ # Initialize per-operation LLM configs (fall back to default if not specified)
+ # Retain LLM config - for fact extraction (benefits from strong structured output)
+ retain_provider = retain_llm_provider or config.retain_llm_provider or memory_llm_provider
+ retain_api_key = retain_llm_api_key or config.retain_llm_api_key or memory_llm_api_key
+ retain_model = retain_llm_model or config.retain_llm_model or memory_llm_model
+ retain_base_url = retain_llm_base_url or config.retain_llm_base_url or memory_llm_base_url
+ # Apply provider-specific base URL defaults for retain
+ if retain_base_url is None:
+ if retain_provider.lower() == "groq":
+ retain_base_url = "https://api.groq.com/openai/v1"
+ elif retain_provider.lower() == "ollama":
+ retain_base_url = "http://localhost:11434/v1"
+ else:
+ retain_base_url = ""
+
+ self._retain_llm_config = LLMConfig(
+ provider=retain_provider,
+ api_key=retain_api_key,
+ base_url=retain_base_url,
+ model=retain_model,
+ )
+
+ # Reflect LLM config - for think/observe operations (can use lighter models)
+ reflect_provider = reflect_llm_provider or config.reflect_llm_provider or memory_llm_provider
+ reflect_api_key = reflect_llm_api_key or config.reflect_llm_api_key or memory_llm_api_key
+ reflect_model = reflect_llm_model or config.reflect_llm_model or memory_llm_model
+ reflect_base_url = reflect_llm_base_url or config.reflect_llm_base_url or memory_llm_base_url
+ # Apply provider-specific base URL defaults for reflect
+ if reflect_base_url is None:
+ if reflect_provider.lower() == "groq":
+ reflect_base_url = "https://api.groq.com/openai/v1"
+ elif reflect_provider.lower() == "ollama":
+ reflect_base_url = "http://localhost:11434/v1"
+ else:
+ reflect_base_url = ""
+
+ self._reflect_llm_config = LLMConfig(
+ provider=reflect_provider,
+ api_key=reflect_api_key,
+ base_url=reflect_base_url,
+ model=reflect_model,
+ )
+
+ # Consolidation LLM config - for mental model consolidation (can use efficient models)
+ consolidation_provider = consolidation_llm_provider or config.consolidation_llm_provider or memory_llm_provider
+ consolidation_api_key = consolidation_llm_api_key or config.consolidation_llm_api_key or memory_llm_api_key
+ consolidation_model = consolidation_llm_model or config.consolidation_llm_model or memory_llm_model
+ consolidation_base_url = consolidation_llm_base_url or config.consolidation_llm_base_url or memory_llm_base_url
+ # Apply provider-specific base URL defaults for consolidation
+ if consolidation_base_url is None:
+ if consolidation_provider.lower() == "groq":
+ consolidation_base_url = "https://api.groq.com/openai/v1"
+ elif consolidation_provider.lower() == "ollama":
+ consolidation_base_url = "http://localhost:11434/v1"
+ else:
+ consolidation_base_url = ""
+
+ self._consolidation_llm_config = LLMConfig(
+ provider=consolidation_provider,
+ api_key=consolidation_api_key,
+ base_url=consolidation_base_url,
+ model=consolidation_model,
+ )
+
  # Initialize cross-encoder reranker (cached for performance)
  self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)
 
  # Initialize task backend
- self._task_backend = task_backend or AsyncIOQueueBackend(batch_size=100, batch_interval=1.0)
+ # If no custom backend provided, use BrokerTaskBackend which stores tasks in PostgreSQL
+ # The pool_getter lambda will return the pool once it's initialized
+ self._task_backend = task_backend or BrokerTaskBackend(
+ pool_getter=lambda: self._pool,
+ schema_getter=get_current_schema,
+ )
 
  # Backpressure mechanism: limit concurrent searches to prevent overwhelming the database
- # Limit concurrent searches to prevent connection pool exhaustion
- # Each search can use 2-4 connections, so with 10 concurrent searches
- # we use ~20-40 connections max, staying well within pool limits
- self._search_semaphore = asyncio.Semaphore(10)
+ # Configurable via HINDSIGHT_API_RECALL_MAX_CONCURRENT (default: 50)
+ self._search_semaphore = asyncio.Semaphore(get_config().recall_max_concurrent)
 
  # Backpressure for put operations: limit concurrent puts to prevent database contention
  # Each put_batch holds a connection for the entire transaction, so we limit to 5
@@ -401,35 +499,19 @@ class MemoryEngine(MemoryEngineInterface):
  if request_context is None:
  raise AuthenticationError("RequestContext is required when tenant extension is configured")
 
+ # For internal/background operations (e.g., worker tasks), skip extension authentication
+ # if the schema has already been set by execute_task via the _schema field.
+ if request_context.internal:
+ current = _current_schema.get()
+ if current and current != "public":
+ return current
+
  # Let AuthenticationError propagate - HTTP layer will convert to 401
  tenant_context = await self._tenant_extension.authenticate(request_context)
 
  _current_schema.set(tenant_context.schema_name)
  return tenant_context.schema_name
 
- async def _handle_access_count_update(self, task_dict: dict[str, Any]):
- """
- Handler for access count update tasks.
-
- Args:
- task_dict: Dict with 'node_ids' key containing list of node IDs to update
-
- Raises:
- Exception: Any exception from database operations (propagates to execute_task for retry)
- """
- node_ids = task_dict.get("node_ids", [])
- if not node_ids:
- return
-
- pool = await self._get_pool()
- # Convert string UUIDs to UUID type for faster matching
- uuid_list = [uuid.UUID(nid) for nid in node_ids]
- async with acquire_with_retry(pool) as conn:
- await conn.execute(
- f"UPDATE {fq_table('memory_units')} SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
- uuid_list,
- )
-
  async def _handle_batch_retain(self, task_dict: dict[str, Any]):
  """
  Handler for batch retain tasks.
@@ -450,14 +532,113 @@ class MemoryEngine(MemoryEngineInterface):
  f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
  )
 
- # Use internal request context for background tasks
+ # Use internal request context for background tasks (skips tenant auth when schema is pre-set)
  from hindsight_api.models import RequestContext
 
- internal_context = RequestContext()
+ internal_context = RequestContext(internal=True)
  await self.retain_batch_async(bank_id=bank_id, contents=contents, request_context=internal_context)
 
  logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
 
+ async def _handle_consolidation(self, task_dict: dict[str, Any]):
+ """
+ Handler for consolidation tasks.
+
+ Consolidates new memories into mental models for a bank.
+
+ Args:
+ task_dict: Dict with 'bank_id'
+
+ Raises:
+ ValueError: If bank_id is missing
+ Exception: Any exception from consolidation (propagates to execute_task for retry)
+ """
+ bank_id = task_dict.get("bank_id")
+ if not bank_id:
+ raise ValueError("bank_id is required for consolidation task")
+
+ from hindsight_api.models import RequestContext
+
+ from .consolidation import run_consolidation_job
+
+ internal_context = RequestContext(internal=True)
+ result = await run_consolidation_job(
+ memory_engine=self,
+ bank_id=bank_id,
+ request_context=internal_context,
+ )
+
+ logger.info(f"[CONSOLIDATION] bank={bank_id} completed: {result.get('memories_processed', 0)} processed")
+
+ async def _handle_refresh_mental_model(self, task_dict: dict[str, Any]):
+ """
+ Handler for refresh_mental_model tasks.
+
+ Re-runs the source query through reflect and updates the mental model content.
+
+ Args:
+ task_dict: Dict with 'bank_id', 'mental_model_id', 'operation_id'
+
+ Raises:
+ ValueError: If required fields are missing
+ Exception: Any exception from reflect/update (propagates to execute_task for retry)
+ """
+ bank_id = task_dict.get("bank_id")
+ mental_model_id = task_dict.get("mental_model_id")
+
+ if not bank_id or not mental_model_id:
+ raise ValueError("bank_id and mental_model_id are required for refresh_mental_model task")
+
+ logger.info(f"[REFRESH_MENTAL_MODEL_TASK] Starting for bank_id={bank_id}, mental_model_id={mental_model_id}")
+
+ from hindsight_api.models import RequestContext
+
+ internal_context = RequestContext(internal=True)
+
+ # Get the current mental model to get source_query
+ mental_model = await self.get_mental_model(bank_id, mental_model_id, request_context=internal_context)
+ if not mental_model:
+ raise ValueError(f"Mental model {mental_model_id} not found in bank {bank_id}")
+
+ source_query = mental_model["source_query"]
+
+ # Run reflect to generate new content, excluding the mental model being refreshed
+ reflect_result = await self.reflect_async(
+ bank_id=bank_id,
+ query=source_query,
+ request_context=internal_context,
+ exclude_mental_model_ids=[mental_model_id],
+ )
+
+ generated_content = reflect_result.text or "No content generated"
+
+ # Build reflect_response payload to store
+ reflect_response = {
+ "text": reflect_result.text,
+ "based_on": {
+ fact_type: [
+ {
+ "id": str(fact.id),
+ "text": fact.text,
+ "type": fact_type,
+ }
+ for fact in facts
+ ]
+ for fact_type, facts in reflect_result.based_on.items()
+ },
+ }
+
+ # Update the mental model with the generated content and reflect_response
+ await self.update_mental_model(
+ bank_id=bank_id,
+ mental_model_id=mental_model_id,
+ content=generated_content,
+ reflect_response=reflect_response,
+ request_context=internal_context,
+ )
+
+ logger.info(f"[REFRESH_MENTAL_MODEL_TASK] Completed for bank_id={bank_id}, mental_model_id={mental_model_id}")
+
  async def execute_task(self, task_dict: dict[str, Any]):
  """
  Execute a task by routing it to the appropriate handler.
@@ -467,13 +648,18 @@ class MemoryEngine(MemoryEngineInterface):
 
  Args:
  task_dict: Task dictionary with 'type' key and other payload data
- Example: {'type': 'access_count_update', 'node_ids': [...]}
+ Example: {'type': 'batch_retain', 'bank_id': '...', 'contents': [...]}
  """
  task_type = task_dict.get("type")
  operation_id = task_dict.get("operation_id")
  retry_count = task_dict.get("retry_count", 0)
  max_retries = 3
 
+ # Set schema context for multi-tenant task execution
+ schema = task_dict.pop("_schema", None)
+ if schema:
+ _current_schema.set(schema)
+
  # Check if operation was cancelled (only for tasks with operation_id)
  if operation_id:
  try:
@@ -492,16 +678,12 @@ class MemoryEngine(MemoryEngineInterface):
  # Continue with processing if we can't check status
 
  try:
- if task_type == "access_count_update":
- await self._handle_access_count_update(task_dict)
- elif task_type == "reinforce_opinion":
- await self._handle_reinforce_opinion(task_dict)
- elif task_type == "form_opinion":
- await self._handle_form_opinion(task_dict)
- elif task_type == "batch_retain":
+ if task_type == "batch_retain":
  await self._handle_batch_retain(task_dict)
- elif task_type == "regenerate_observations":
- await self._handle_regenerate_observations(task_dict)
+ elif task_type == "consolidation":
+ await self._handle_consolidation(task_dict)
+ elif task_type == "refresh_mental_model":
+ await self._handle_refresh_mental_model(task_dict)
  else:
  logger.error(f"Unknown task type: {task_type}")
  # Don't retry unknown task types
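For orientation, the payload shapes routed above, inferred from the keys each handler reads (all values are placeholders; '_schema' is popped before dispatch to set the tenant schema):

    # Shapes inferred from the handlers in this file; values are illustrative.
    batch_retain_task = {
        "type": "batch_retain",
        "bank_id": "<bank-id>",
        "contents": [{"content": "raw text to retain"}],
        "_schema": "<tenant-schema>",   # optional; sets schema context for the task
    }
    consolidation_task = {"type": "consolidation", "bank_id": "<bank-id>"}
    refresh_task = {
        "type": "refresh_mental_model",
        "bank_id": "<bank-id>",
        "mental_model_id": "<mental-model-id>",
        "operation_id": "<async-operation-uuid>",  # optional; enables status tracking and retries
    }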
@@ -509,9 +691,9 @@ class MemoryEngine(MemoryEngineInterface):
  await self._delete_operation_record(operation_id)
  return
 
- # Task succeeded - delete operation record
+ # Task succeeded - mark operation as completed
  if operation_id:
- await self._delete_operation_record(operation_id)
+ await self._mark_operation_completed(operation_id)
 
  except Exception as e:
  # Task failed - check if we should retry
@@ -557,7 +739,7 @@ class MemoryEngine(MemoryEngineInterface):
  await conn.execute(
  f"""
  UPDATE {fq_table("async_operations")}
- SET status = 'failed', error_message = $2
+ SET status = 'failed', error_message = $2, updated_at = NOW()
  WHERE operation_id = $1
  """,
  uuid.UUID(operation_id),
@@ -567,6 +749,23 @@ class MemoryEngine(MemoryEngineInterface):
  except Exception as e:
  logger.error(f"Failed to mark operation as failed {operation_id}: {e}")
 
+ async def _mark_operation_completed(self, operation_id: str):
+ """Helper to mark an operation as completed in the database."""
+ try:
+ pool = await self._get_pool()
+ async with acquire_with_retry(pool) as conn:
+ await conn.execute(
+ f"""
+ UPDATE {fq_table("async_operations")}
+ SET status = 'completed', updated_at = NOW(), completed_at = NOW()
+ WHERE operation_id = $1
+ """,
+ uuid.UUID(operation_id),
+ )
+ logger.info(f"Marked async operation as completed: {operation_id}")
+ except Exception as e:
+ logger.error(f"Failed to mark operation as completed {operation_id}: {e}")
+
  async def initialize(self):
  """Initialize the connection pool, models, and background workers.
 
@@ -618,9 +817,44 @@ class MemoryEngine(MemoryEngineInterface):
  await loop.run_in_executor(None, self.query_analyzer.load)
 
  async def verify_llm():
- """Verify LLM connection is working."""
+ """Verify LLM connections are working for all unique configs."""
  if not self._skip_llm_verification:
+ # Verify default config
  await self._llm_config.verify_connection()
+ # Verify retain config if different from default
+ retain_is_different = (
+ self._retain_llm_config.provider != self._llm_config.provider
+ or self._retain_llm_config.model != self._llm_config.model
+ )
+ if retain_is_different:
+ await self._retain_llm_config.verify_connection()
+ # Verify reflect config if different from default and retain
+ reflect_is_different = (
+ self._reflect_llm_config.provider != self._llm_config.provider
+ or self._reflect_llm_config.model != self._llm_config.model
+ ) and (
+ self._reflect_llm_config.provider != self._retain_llm_config.provider
+ or self._reflect_llm_config.model != self._retain_llm_config.model
+ )
+ if reflect_is_different:
+ await self._reflect_llm_config.verify_connection()
+ # Verify consolidation config if different from all others
+ consolidation_is_different = (
+ (
+ self._consolidation_llm_config.provider != self._llm_config.provider
+ or self._consolidation_llm_config.model != self._llm_config.model
+ )
+ and (
+ self._consolidation_llm_config.provider != self._retain_llm_config.provider
+ or self._consolidation_llm_config.model != self._retain_llm_config.model
+ )
+ and (
+ self._consolidation_llm_config.provider != self._reflect_llm_config.provider
+ or self._consolidation_llm_config.model != self._reflect_llm_config.model
+ )
+ )
+ if consolidation_is_different:
+ await self._consolidation_llm_config.verify_connection()
 
  # Build list of initialization tasks
  init_tasks = [
@@ -642,13 +876,17 @@ class MemoryEngine(MemoryEngineInterface):
 
  # Run database migrations if enabled
  if self._run_migrations:
- from ..migrations import run_migrations
+ from ..migrations import ensure_embedding_dimension, run_migrations
 
  if not self.db_url:
  raise ValueError("Database URL is required for migrations")
  logger.info("Running database migrations...")
  run_migrations(self.db_url)
 
+ # Ensure embedding column dimension matches the model's dimension
+ # This is done after migrations and after embeddings.initialize()
+ ensure_embedding_dimension(self.db_url, self.embeddings.dimension)
+
  logger.info(f"Connecting to PostgreSQL at {self.db_url}")
 
  # Create connection pool
@@ -658,9 +896,9 @@ class MemoryEngine(MemoryEngineInterface):
  self.db_url,
  min_size=self._pool_min_size,
  max_size=self._pool_max_size,
- command_timeout=60,
+ command_timeout=self._db_command_timeout,
  statement_cache_size=0, # Disable prepared statement cache
- timeout=30, # Connection acquisition timeout (seconds)
+ timeout=self._db_acquire_timeout, # Connection acquisition timeout (seconds)
  )
 
  # Initialize entity resolver with pool
@@ -743,8 +981,7 @@ class MemoryEngine(MemoryEngineInterface):
  """
  Wait for all pending background tasks to complete.
 
- This is useful in tests to ensure background tasks (like opinion reinforcement)
- complete before making assertions.
+ This is useful in tests to ensure background tasks complete before making assertions.
  """
  if hasattr(self._task_backend, "wait_for_pending_tasks"):
  await self._task_backend.wait_for_pending_tasks()
@@ -967,7 +1204,9 @@ class MemoryEngine(MemoryEngineInterface):
  document_id: str | None = None,
  fact_type_override: str | None = None,
  confidence_score: float | None = None,
- ) -> list[list[str]]:
+ document_tags: list[str] | None = None,
+ return_usage: bool = False,
+ ):
  """
  Store multiple content items as memory units in ONE batch operation.
 
@@ -988,9 +1227,11 @@ class MemoryEngine(MemoryEngineInterface):
  Applies the same document_id to ALL content items that don't specify their own.
  fact_type_override: Override fact type for all facts ('world', 'experience', 'opinion')
  confidence_score: Confidence score for opinions (0.0 to 1.0)
+ return_usage: If True, returns tuple of (unit_ids, TokenUsage). Default False for backward compatibility.
 
  Returns:
- List of lists of unit IDs (one list per content item)
+ If return_usage=False: List of lists of unit IDs (one list per content item)
+ If return_usage=True: Tuple of (unit_ids, TokenUsage)
 
  Example (new style - per-content document_id):
  unit_ids = await memory.retain_batch_async(
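A short usage sketch for the new return_usage flag and document_tags parameter (bank id, content, and tags are placeholders; request_context is passed as in the background-task call shown earlier):

    # Default behavior is unchanged; with return_usage=True the call returns (unit_ids, TokenUsage).
    unit_ids, usage = await memory.retain_batch_async(
        bank_id="<bank-id>",
        contents=[{"content": "Alice moved to Berlin in 2023."}],
        document_tags=["<team-tag>"],      # applied to every item in this batch
        return_usage=True,
        request_context=request_context,
    )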
@@ -1017,6 +1258,8 @@ class MemoryEngine(MemoryEngineInterface):
  start_time = time.time()
 
  if not contents:
+ if return_usage:
+ return [], TokenUsage()
  return []
 
  # Authenticate tenant and set schema in context (for fq_table())
@@ -1046,6 +1289,7 @@ class MemoryEngine(MemoryEngineInterface):
  # Auto-chunk large batches by character count to avoid timeouts and memory issues
  # Calculate total character count
  total_chars = sum(len(item.get("content", "")) for item in contents)
+ total_usage = TokenUsage()
 
  CHARS_PER_BATCH = 600_000
@@ -1078,7 +1322,7 @@ class MemoryEngine(MemoryEngineInterface):
 
  logger.info(f"Split into {len(sub_batches)} sub-batches: {[len(b) for b in sub_batches]} items each")
 
- # Process each sub-batch using internal method (skip chunking check)
+ # Process each sub-batch
  all_results = []
  for i, sub_batch in enumerate(sub_batches, 1):
  sub_batch_chars = sum(len(item.get("content", "")) for item in sub_batch)
@@ -1086,15 +1330,17 @@ class MemoryEngine(MemoryEngineInterface):
  f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars"
  )
 
- sub_results = await self._retain_batch_async_internal(
+ sub_results, sub_usage = await self._retain_batch_async_internal(
  bank_id=bank_id,
  contents=sub_batch,
  document_id=document_id,
  is_first_batch=i == 1, # Only upsert on first batch
  fact_type_override=fact_type_override,
  confidence_score=confidence_score,
+ document_tags=document_tags,
  )
  all_results.extend(sub_results)
+ total_usage = total_usage + sub_usage
 
  total_time = time.time() - start_time
  logger.info(
@@ -1103,13 +1349,14 @@ class MemoryEngine(MemoryEngineInterface):
  result = all_results
  else:
  # Small batch - use internal method directly
- result = await self._retain_batch_async_internal(
+ result, total_usage = await self._retain_batch_async_internal(
  bank_id=bank_id,
  contents=contents,
  document_id=document_id,
  is_first_batch=True,
  fact_type_override=fact_type_override,
  confidence_score=confidence_score,
+ document_tags=document_tags,
  )
 
  # Call post-operation hook if validator is configured
@@ -1132,6 +1379,19 @@ class MemoryEngine(MemoryEngineInterface):
  except Exception as e:
  logger.warning(f"Post-retain hook error (non-fatal): {e}")
 
+ # Trigger consolidation as a tracked async operation if enabled
+ from ..config import get_config
+
+ config = get_config()
+ if config.enable_observations:
+ try:
+ await self.submit_async_consolidation(bank_id=bank_id, request_context=request_context)
+ except Exception as e:
+ # Log but don't fail the retain - consolidation is non-critical
+ logger.warning(f"Failed to submit consolidation task for bank {bank_id}: {e}")
+
+ if return_usage:
+ return result, total_usage
  return result
 
  async def _retain_batch_async_internal(
@@ -1142,7 +1402,8 @@ class MemoryEngine(MemoryEngineInterface):
  is_first_batch: bool = True,
  fact_type_override: str | None = None,
  confidence_score: float | None = None,
- ) -> list[list[str]]:
+ document_tags: list[str] | None = None,
+ ) -> tuple[list[list[str]], "TokenUsage"]:
  """
  Internal method for batch processing without chunking logic.
 
@@ -1158,6 +1419,10 @@ class MemoryEngine(MemoryEngineInterface):
  is_first_batch: Whether this is the first batch (for chunked operations, only delete on first batch)
  fact_type_override: Override fact type for all facts
  confidence_score: Confidence score for opinions
+ document_tags: Tags applied to all items in this batch
+
+ Returns:
+ Tuple of (unit ID lists, token usage for fact extraction)
  """
  # Backpressure: limit concurrent retains to prevent database contention
  async with self._put_semaphore:
@@ -1168,9 +1433,8 @@ class MemoryEngine(MemoryEngineInterface):
  return await orchestrator.retain_batch(
  pool=pool,
  embeddings_model=self.embeddings,
- llm_config=self._llm_config,
+ llm_config=self._retain_llm_config,
  entity_resolver=self.entity_resolver,
- task_backend=self._task_backend,
  format_date_fn=self._format_readable_date,
  duplicate_checker_fn=self._find_duplicate_facts_batch,
  bank_id=bank_id,
@@ -1179,6 +1443,7 @@ class MemoryEngine(MemoryEngineInterface):
  is_first_batch=is_first_batch,
  fact_type_override=fact_type_override,
  confidence_score=confidence_score,
+ document_tags=document_tags,
  )
 
  def recall(
@@ -1237,6 +1502,10 @@ class MemoryEngine(MemoryEngineInterface):
  include_chunks: bool = False,
  max_chunk_tokens: int = 8192,
  request_context: "RequestContext",
+ tags: list[str] | None = None,
+ tags_match: TagsMatch = "any",
+ _connection_budget: int | None = None,
+ _quiet: bool = False,
  ) -> RecallResultModel:
  """
  Recall memories using N*4-way parallel retrieval (N fact types × 4 retrieval methods).
@@ -1262,6 +1531,8 @@ class MemoryEngine(MemoryEngineInterface):
  max_entity_tokens: Maximum tokens for entity observations (default 500)
  include_chunks: Whether to include raw chunks in the response
  max_chunk_tokens: Maximum tokens for chunks (default 8192)
+ tags: Optional list of tags for visibility filtering (OR matching - returns
+ memories that have at least one matching tag)
 
  Returns:
  RecallResultModel containing:
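A sketch of tag-filtered recall using the parameters added above. The method name and the surrounding arguments are assumptions (they are not visible in this hunk); only tags, tags_match, and the OR-matching behavior come from the docstring:

    # Hypothetical call shape: returns only memories carrying at least one of the given tags.
    result = await memory.recall_async(          # assumed async entry point
        bank_id="<bank-id>",
        query="What does the team know about billing?",
        tags=["team-billing", "public"],         # placeholder tags
        tags_match="any",                        # default: OR matching
        request_context=request_context,
    )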
@@ -1285,6 +1556,12 @@ class MemoryEngine(MemoryEngineInterface):
  f"Must be one of: {', '.join(sorted(VALID_RECALL_FACT_TYPES))}"
  )
 
+ # Filter out 'opinion' - opinions are no longer returned from recall
+ fact_type = [ft for ft in fact_type if ft != "opinion"]
+ if not fact_type:
+ # All requested types were opinions - return empty result
+ return RecallResultModel(results=[], entities={}, chunks={})
+
  # Validate operation if validator is configured
  if self._operation_validator:
  from hindsight_api.extensions import RecallContext
@@ -1310,10 +1587,17 @@ class MemoryEngine(MemoryEngineInterface):
  effective_budget = budget if budget is not None else Budget.MID
  thinking_budget = budget_mapping[effective_budget]
 
+ # Log recall start with tags if present (skip if quiet mode for internal operations)
+ if not _quiet:
+ tags_info = f", tags={tags} ({tags_match})" if tags else ""
+ logger.info(f"[RECALL {bank_id[:8]}] Starting recall for query: {query[:50]}...{tags_info}")
+
  # Backpressure: limit concurrent recalls to prevent overwhelming the database
  result = None
  error_msg = None
+ semaphore_wait_start = time.time()
  async with self._search_semaphore:
+ semaphore_wait = time.time() - semaphore_wait_start
  # Retry loop for connection errors
  max_retries = 3
  for attempt in range(max_retries + 1):
@@ -1331,6 +1615,11 @@ class MemoryEngine(MemoryEngineInterface):
  include_chunks,
  max_chunk_tokens,
  request_context,
+ semaphore_wait=semaphore_wait,
+ tags=tags,
+ tags_match=tags_match,
+ connection_budget=_connection_budget,
+ quiet=_quiet,
  )
  break # Success - exit retry loop
  except Exception as e:
@@ -1448,6 +1737,11 @@ class MemoryEngine(MemoryEngineInterface):
  include_chunks: bool = False,
  max_chunk_tokens: int = 8192,
  request_context: "RequestContext" = None,
+ semaphore_wait: float = 0.0,
+ tags: list[str] | None = None,
+ tags_match: TagsMatch = "any",
+ connection_budget: int | None = None,
+ quiet: bool = False,
  ) -> RecallResultModel:
  """
  Search implementation with modular retrieval and reranking.
@@ -1477,7 +1771,9 @@ class MemoryEngine(MemoryEngineInterface):
  # Initialize tracer if requested
  from .search.tracer import SearchTracer
 
- tracer = SearchTracer(query, thinking_budget, max_tokens) if enable_trace else None
+ tracer = (
+ SearchTracer(query, thinking_budget, max_tokens, tags=tags, tags_match=tags_match) if enable_trace else None
+ )
  if tracer:
  tracer.start()
 
@@ -1487,8 +1783,9 @@ class MemoryEngine(MemoryEngineInterface):
  # Buffer logs for clean output in concurrent scenarios
  recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
  log_buffer = []
+ tags_info = f", tags={tags}, tags_match={tags_match}" if tags else ""
  log_buffer.append(
- f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})"
+ f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens}{tags_info})"
  )
 
  try:
@@ -1502,37 +1799,70 @@ class MemoryEngine(MemoryEngineInterface):
  tracer.record_query_embedding(query_embedding)
  tracer.add_phase_metric("generate_query_embedding", step_duration)
 
- # Step 2: N*4-Way Parallel Retrieval (N fact types × 4 retrieval methods)
+ # Step 2: Optimized parallel retrieval using batched queries
+ # - Semantic + BM25 combined in 1 CTE query for ALL fact types
+ # - Graph runs per fact type (complex traversal)
+ # - Temporal runs per fact type (if constraint detected)
  step_start = time.time()
  query_embedding_str = str(query_embedding)
 
- from .search.retrieval import retrieve_parallel
+ from .search.retrieval import (
+ get_default_graph_retriever,
+ retrieve_all_fact_types_parallel,
+ )
 
  # Track each retrieval start time
  retrieval_start = time.time()
 
- # Run retrieval for each fact type in parallel
- retrieval_tasks = [
- retrieve_parallel(
- pool, query, query_embedding_str, bank_id, ft, thinking_budget, question_date, self.query_analyzer
+ # Run optimized retrieval with connection budget
+ config = get_config()
+ effective_connection_budget = (
+ connection_budget if connection_budget is not None else config.recall_connection_budget
+ )
+ async with budgeted_operation(
+ max_connections=effective_connection_budget,
+ operation_id=f"recall-{recall_id}",
+ ) as op:
+ budgeted_pool = op.wrap_pool(pool)
+ parallel_start = time.time()
+ multi_result = await retrieve_all_fact_types_parallel(
+ budgeted_pool,
+ query,
+ query_embedding_str,
+ bank_id,
+ fact_type, # Pass all fact types at once
+ thinking_budget,
+ question_date,
+ self.query_analyzer,
+ tags=tags,
+ tags_match=tags_match,
  )
- for ft in fact_type
- ]
- all_retrievals = await asyncio.gather(*retrieval_tasks)
+ parallel_duration = time.time() - parallel_start
 
  # Combine all results from all fact types and aggregate timings
  semantic_results = []
  bm25_results = []
  graph_results = []
  temporal_results = []
- aggregated_timings = {"semantic": 0.0, "bm25": 0.0, "graph": 0.0, "temporal": 0.0}
+ aggregated_timings = {
+ "semantic": 0.0,
+ "bm25": 0.0,
+ "graph": 0.0,
+ "temporal": 0.0,
+ "temporal_extraction": 0.0,
+ }
+ all_mpfp_timings = []
 
  detected_temporal_constraint = None
- for idx, retrieval_result in enumerate(all_retrievals):
+ max_conn_wait = multi_result.max_conn_wait
+ for ft in fact_type:
+ retrieval_result = multi_result.results_by_fact_type.get(ft)
+ if not retrieval_result:
+ continue
+
  # Log fact types in this retrieval batch
- ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
  logger.debug(
- f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
+ f"[RECALL {recall_id}] Fact type '{ft}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
  )
 
  semantic_results.extend(retrieval_result.semantic)
@@ -1570,6 +1900,7 @@ class MemoryEngine(MemoryEngineInterface):
  f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
  f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
  f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)",
+ f"temporal_extraction={aggregated_timings['temporal_extraction']:.3f}s",
  ]
  temporal_info = ""
  if detected_temporal_constraint:
@@ -1578,9 +1909,41 @@ class MemoryEngine(MemoryEngineInterface):
  timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
  temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
  log_buffer.append(
- f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}"
+ f" [2] Parallel retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {parallel_duration:.3f}s{temporal_info}"
  )
 
+ # Log graph retriever timing breakdown if available
+ if all_mpfp_timings:
+ retriever_name = get_default_graph_retriever().name.upper()
+ mpfp_total = all_mpfp_timings[0] # Take first fact type's timing as representative
+ mpfp_parts = [
+ f"db_queries={mpfp_total.db_queries}",
+ f"edge_load={mpfp_total.edge_load_time:.3f}s",
+ f"edges={mpfp_total.edge_count}",
+ f"patterns={mpfp_total.pattern_count}",
+ ]
+ if mpfp_total.seeds_time > 0.01:
+ mpfp_parts.append(f"seeds={mpfp_total.seeds_time:.3f}s")
+ if mpfp_total.fusion > 0.001:
+ mpfp_parts.append(f"fusion={mpfp_total.fusion:.3f}s")
+ if mpfp_total.fetch > 0.001:
+ mpfp_parts.append(f"fetch={mpfp_total.fetch:.3f}s")
+ log_buffer.append(f" [{retriever_name}] {', '.join(mpfp_parts)}")
+ # Log detailed hop timing for debugging slow queries
+ if mpfp_total.hop_details:
+ for hd in mpfp_total.hop_details:
+ log_buffer.append(
+ f" hop{hd['hop']}: exec={hd.get('exec_time', 0) * 1000:.0f}ms, "
+ f"uncached={hd.get('uncached_after_filter', 0)}, "
+ f"load={hd.get('load_time', 0) * 1000:.0f}ms, "
+ f"edges={hd.get('edges_loaded', 0)}"
+ )
+
+ # Record temporal constraint in tracer if detected
+ if tracer and detected_temporal_constraint:
+ start_dt, end_dt = detected_temporal_constraint
+ tracer.record_temporal_constraint(start_dt, end_dt)
+
  # Record retrieval results for tracer - per fact type
  if tracer:
  # Convert RetrievalResult to old tuple format for tracer
@@ -1588,8 +1951,10 @@ class MemoryEngine(MemoryEngineInterface):
  return [(r.id, r.__dict__) for r in results]
 
  # Add retrieval results per fact type (to show parallel execution in UI)
- for idx, rr in enumerate(all_retrievals):
- ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
+ for ft_name in fact_type:
+ rr = multi_result.results_by_fact_type.get(ft_name)
+ if not rr:
+ continue
 
  # Add semantic retrieval results for this fact type
  tracer.add_retrieval_results(
@@ -1621,14 +1986,22 @@ class MemoryEngine(MemoryEngineInterface):
  fact_type=ft_name,
  )
 
- # Add temporal retrieval results for this fact type (even if empty, to show it ran)
- if rr.temporal is not None:
+ # Add temporal retrieval results for this fact type
+ # Show temporal even with 0 results if constraint was detected
+ if rr.temporal is not None or rr.temporal_constraint is not None:
+ temporal_metadata = {"budget": thinking_budget}
+ if rr.temporal_constraint:
+ start_dt, end_dt = rr.temporal_constraint
+ temporal_metadata["constraint"] = {
+ "start": start_dt.isoformat() if start_dt else None,
+ "end": end_dt.isoformat() if end_dt else None,
+ }
  tracer.add_retrieval_results(
  method_name="temporal",
- results=to_tuple_format(rr.temporal),
+ results=to_tuple_format(rr.temporal or []),
  duration_seconds=rr.timings.get("temporal", 0.0),
  score_field="temporal_score",
- metadata={"budget": thinking_budget},
+ metadata=temporal_metadata,
  fact_type=ft_name,
  )
@@ -1678,11 +2051,24 @@ class MemoryEngine(MemoryEngineInterface):
  # Ensure reranker is initialized (for lazy initialization mode)
  await reranker_instance.ensure_initialized()
 
+ # Pre-filter candidates to reduce reranking cost (RRF already provides good ranking)
+ # This is especially important for remote rerankers with network latency
+ reranker_max_candidates = get_config().reranker_max_candidates
+ pre_filtered_count = 0
+ if len(merged_candidates) > reranker_max_candidates:
+ # Sort by RRF score and take top candidates
+ merged_candidates.sort(key=lambda mc: mc.rrf_score, reverse=True)
+ pre_filtered_count = len(merged_candidates) - reranker_max_candidates
+ merged_candidates = merged_candidates[:reranker_max_candidates]
+
  # Rerank using cross-encoder
- scored_results = reranker_instance.rerank(query, merged_candidates)
+ scored_results = await reranker_instance.rerank(query, merged_candidates)
 
  step_duration = time.time() - step_start
- log_buffer.append(f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s")
+ pre_filter_note = f" (pre-filtered {pre_filtered_count})" if pre_filtered_count > 0 else ""
+ log_buffer.append(
+ f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s{pre_filter_note}"
+ )
 
  # Step 4.5: Combine cross-encoder score with retrieval signals
  # This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
@@ -1786,7 +2172,6 @@ class MemoryEngine(MemoryEngineInterface):
  text=sr.retrieval.text,
  context=sr.retrieval.context or "",
  event_date=sr.retrieval.occurred_start,
- access_count=sr.retrieval.access_count,
  is_entry_point=(sr.id in [ep.node_id for ep in tracer.entry_points]),
  parent_node_id=None, # In parallel retrieval, there's no clear parent
  link_type=None,
@@ -1798,12 +2183,6 @@ class MemoryEngine(MemoryEngineInterface):
  final_weight=sr.weight,
  )
 
- # Step 8: Queue access count updates for visited nodes
- visited_ids = list(set([sr.id for sr in scored_results[:50]])) # Top 50
- if visited_ids:
- await self._task_backend.submit_task({"type": "access_count_update", "node_ids": visited_ids})
- log_buffer.append(f" [7] Queued access count updates for {len(visited_ids)} nodes")
-
  # Log fact_type distribution in results
  fact_type_counts = {}
  for sr in top_scored:
@@ -1878,6 +2257,7 @@ class MemoryEngine(MemoryEngineInterface):
  mentioned_at=result_dict.get("mentioned_at"),
  document_id=result_dict.get("document_id"),
  chunk_id=result_dict.get("chunk_id"),
+ tags=result_dict.get("tags"),
  )
  )
 
@@ -1902,35 +2282,15 @@ class MemoryEngine(MemoryEngineInterface):
  entities_ordered.append((entity_id, entity_name))
  seen_entity_ids.add(entity_id)
 
- # Fetch observations for each entity (respect token budget, in order)
+ # Return entities with empty observations (summaries now live in mental models)
  entities_dict = {}
- encoding = _get_tiktoken_encoding()
-
  for entity_id, entity_name in entities_ordered:
- if total_entity_tokens >= max_entity_tokens:
- break
-
- observations = await self.get_entity_observations(
- bank_id, entity_id, limit=5, request_context=request_context
+ entities_dict[entity_name] = EntityState(
+ entity_id=entity_id,
+ canonical_name=entity_name,
+ observations=[], # Mental models provide this now
  )
 
- # Calculate tokens for this entity's observations
- entity_tokens = 0
- included_observations = []
- for obs in observations:
- obs_tokens = len(encoding.encode(obs.text))
- if total_entity_tokens + entity_tokens + obs_tokens <= max_entity_tokens:
- included_observations.append(obs)
- entity_tokens += obs_tokens
- else:
- break
-
- if included_observations:
- entities_dict[entity_name] = EntityState(
- entity_id=entity_id, canonical_name=entity_name, observations=included_observations
- )
- total_entity_tokens += entity_tokens
-
  # Fetch chunks if requested
  chunks_dict = None
  if include_chunks and top_scored:
@@ -2002,16 +2362,25 @@ class MemoryEngine(MemoryEngineInterface):
2002
2362
  total_time = time.time() - recall_start
2003
2363
  num_chunks = len(chunks_dict) if chunks_dict else 0
2004
2364
  num_entities = len(entities_dict) if entities_dict else 0
2365
+ # Include wait times in log if significant
2366
+ wait_parts = []
2367
+ if semaphore_wait > 0.01:
2368
+ wait_parts.append(f"sem={semaphore_wait:.3f}s")
2369
+ if max_conn_wait > 0.01:
2370
+ wait_parts.append(f"conn={max_conn_wait:.3f}s")
2371
+ wait_info = f" | waits: {', '.join(wait_parts)}" if wait_parts else ""
2005
2372
  log_buffer.append(
2006
- f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s"
2373
+ f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s{wait_info}"
2007
2374
  )
2008
- logger.info("\n" + "\n".join(log_buffer))
2375
+ if not quiet:
2376
+ logger.info("\n" + "\n".join(log_buffer))
2009
2377
 
2010
2378
  return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)
2011
2379
 
2012
2380
  except Exception as e:
2013
2381
  log_buffer.append(f"[RECALL {recall_id}] ERROR after {time.time() - recall_start:.3f}s: {str(e)}")
2014
- logger.error("\n" + "\n".join(log_buffer))
2382
+ if not quiet:
2383
+ logger.error("\n" + "\n".join(log_buffer))
2015
2384
  raise Exception(f"Failed to search memories: {str(e)}")
2016
2385
 
2017
2386
  def _filter_by_token_budget(
@@ -2073,11 +2442,11 @@ class MemoryEngine(MemoryEngineInterface):
2073
2442
  doc = await conn.fetchrow(
2074
2443
  f"""
2075
2444
  SELECT d.id, d.bank_id, d.original_text, d.content_hash,
2076
- d.created_at, d.updated_at, COUNT(mu.id) as unit_count
2445
+ d.created_at, d.updated_at, d.tags, COUNT(mu.id) as unit_count
2077
2446
  FROM {fq_table("documents")} d
2078
2447
  LEFT JOIN {fq_table("memory_units")} mu ON mu.document_id = d.id
2079
2448
  WHERE d.id = $1 AND d.bank_id = $2
2080
- GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
2449
+ GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at, d.tags
2081
2450
  """,
2082
2451
  document_id,
2083
2452
  bank_id,
@@ -2094,6 +2463,7 @@ class MemoryEngine(MemoryEngineInterface):
2094
2463
  "memory_unit_count": doc["unit_count"],
2095
2464
  "created_at": doc["created_at"].isoformat() if doc["created_at"] else None,
2096
2465
  "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else None,
2466
+ "tags": list(doc["tags"]) if doc["tags"] else [],
2097
2467
  }
2098
2468
 
2099
2469
  async def delete_document(
@@ -2118,10 +2488,12 @@ class MemoryEngine(MemoryEngineInterface):
2118
2488
  pool = await self._get_pool()
2119
2489
  async with acquire_with_retry(pool) as conn:
2120
2490
  async with conn.transaction():
2121
- # Count units before deletion
2122
- units_count = await conn.fetchval(
2123
- f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE document_id = $1", document_id
2491
+ # Get memory unit IDs before deletion (for mental model invalidation)
2492
+ unit_rows = await conn.fetch(
2493
+ f"SELECT id FROM {fq_table('memory_units')} WHERE document_id = $1", document_id
2124
2494
  )
2495
+ unit_ids = [str(row["id"]) for row in unit_rows]
2496
+ units_count = len(unit_ids)
2125
2497
 
2126
2498
  # Delete document (cascades to memory_units and all their links)
2127
2499
  deleted = await conn.fetchval(
@@ -2130,6 +2502,10 @@ class MemoryEngine(MemoryEngineInterface):
2130
2502
  bank_id,
2131
2503
  )
2132
2504
 
2505
+ # Invalidate deleted fact IDs from mental models
2506
+ if deleted and unit_ids:
2507
+ await self._invalidate_facts_from_mental_models(conn, bank_id, unit_ids)
2508
+
2133
2509
  return {"document_deleted": 1 if deleted else 0, "memory_units_deleted": units_count if deleted else 0}
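delete_document now collects the affected memory-unit IDs before the rows are gone, so `_invalidate_facts_from_mental_models` can run inside the same transaction; delete_memory_unit below follows the same pattern. A rough usage sketch (inside an async context); `engine`, `ctx`, and the positional argument order are assumptions, since the full signature is outside this hunk:

    # Illustrative only, not part of the diff.
    result = await engine.delete_document(bank_id, document_id, request_context=ctx)
    print(result)  # e.g. {"document_deleted": 1, "memory_units_deleted": 7}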
2134
2510
 
2135
2511
  async def delete_memory_unit(
@@ -2157,11 +2533,18 @@ class MemoryEngine(MemoryEngineInterface):
2157
2533
  pool = await self._get_pool()
2158
2534
  async with acquire_with_retry(pool) as conn:
2159
2535
  async with conn.transaction():
2536
+ # Get bank_id before deletion (for mental model invalidation)
2537
+ bank_id = await conn.fetchval(f"SELECT bank_id FROM {fq_table('memory_units')} WHERE id = $1", unit_id)
2538
+
2160
2539
  # Delete the memory unit (cascades to links and associations)
2161
2540
  deleted = await conn.fetchval(
2162
2541
  f"DELETE FROM {fq_table('memory_units')} WHERE id = $1 RETURNING id", unit_id
2163
2542
  )
2164
2543
 
2544
+ # Invalidate deleted fact ID from mental models
2545
+ if deleted and bank_id:
2546
+ await self._invalidate_facts_from_mental_models(conn, bank_id, [str(deleted)])
2547
+
2165
2548
  return {
2166
2549
  "success": deleted is not None,
2167
2550
  "unit_id": str(deleted) if deleted else None,
@@ -2253,11 +2636,85 @@ class MemoryEngine(MemoryEngineInterface):
2253
2636
  except Exception as e:
2254
2637
  raise Exception(f"Failed to delete agent data: {str(e)}")
2255
2638
 
2639
+ async def clear_observations(
2640
+ self,
2641
+ bank_id: str,
2642
+ *,
2643
+ request_context: "RequestContext",
2644
+ ) -> dict[str, int]:
2645
+ """
2646
+ Clear all observations for a bank (consolidated knowledge).
2647
+
2648
+ Args:
2649
+ bank_id: Bank ID to clear observations for
2650
+ request_context: Request context for authentication.
2651
+
2652
+ Returns:
2653
+ Dictionary with count of deleted observations
2654
+ """
2655
+ await self._authenticate_tenant(request_context)
2656
+ pool = await self._get_pool()
2657
+ async with acquire_with_retry(pool) as conn:
2658
+ async with conn.transaction():
2659
+ # Count observations before deletion
2660
+ count = await conn.fetchval(
2661
+ f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = 'observation'",
2662
+ bank_id,
2663
+ )
2664
+
2665
+ # Delete all observations
2666
+ await conn.execute(
2667
+ f"DELETE FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = 'observation'",
2668
+ bank_id,
2669
+ )
2670
+
2671
+ # Reset consolidation timestamp
2672
+ await conn.execute(
2673
+ f"UPDATE {fq_table('banks')} SET last_consolidated_at = NULL WHERE bank_id = $1",
2674
+ bank_id,
2675
+ )
2676
+
2677
+ return {"deleted_count": count or 0}
2678
+
2679
+ async def run_consolidation(
2680
+ self,
2681
+ bank_id: str,
2682
+ *,
2683
+ request_context: "RequestContext",
2684
+ ) -> dict[str, int]:
2685
+ """
2686
+ Run memory consolidation to create/update mental models.
2687
+
2688
+ Args:
2689
+ bank_id: Bank ID to run consolidation for
2690
+ request_context: Request context for authentication.
2691
+
2692
+ Returns:
2693
+ Dictionary with consolidation stats
2694
+ """
2695
+ await self._authenticate_tenant(request_context)
2696
+
2697
+ from .consolidation import run_consolidation_job
2698
+
2699
+ result = await run_consolidation_job(
2700
+ memory_engine=self,
2701
+ bank_id=bank_id,
2702
+ request_context=request_context,
2703
+ )
2704
+
2705
+ return {
2706
+ "processed": result.get("processed", 0),
2707
+ "created": result.get("created", 0),
2708
+ "updated": result.get("updated", 0),
2709
+ "skipped": result.get("skipped", 0),
2710
+ }
2711
+
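A rough usage sketch of the two new maintenance entry points (inside an async context); `engine` and `ctx` are assumed to come from the host application:

    # Illustrative only, not part of the diff.
    cleared = await engine.clear_observations(bank_id, request_context=ctx)
    stats = await engine.run_consolidation(bank_id, request_context=ctx)
    print(cleared["deleted_count"], stats["processed"], stats["created"], stats["updated"])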
2256
2712
  async def get_graph_data(
2257
2713
  self,
2258
2714
  bank_id: str | None = None,
2259
2715
  fact_type: str | None = None,
2260
2716
  *,
2717
+ limit: int = 1000,
2261
2718
  request_context: "RequestContext",
2262
2719
  ):
2263
2720
  """
@@ -2266,10 +2723,11 @@ class MemoryEngine(MemoryEngineInterface):
2266
2723
  Args:
2267
2724
  bank_id: Filter by bank ID
2268
2725
  fact_type: Filter by fact type (world, experience, opinion)
2726
+ limit: Maximum number of items to return (default: 1000)
2269
2727
  request_context: Request context for authentication.
2270
2728
 
2271
2729
  Returns:
2272
- Dict with nodes, edges, and table_rows
2730
+ Dict with nodes, edges, table_rows, total_units, and limit
2273
2731
  """
2274
2732
  await self._authenticate_tenant(request_context)
2275
2733
  pool = await self._get_pool()
@@ -2291,21 +2749,46 @@ class MemoryEngine(MemoryEngineInterface):
2291
2749
 
2292
2750
  where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
2293
2751
 
2752
+ # Get total count first
2753
+ total_count_result = await conn.fetchrow(
2754
+ f"""
2755
+ SELECT COUNT(*) as total
2756
+ FROM {fq_table("memory_units")}
2757
+ {where_clause}
2758
+ """,
2759
+ *query_params,
2760
+ )
2761
+ total_count = total_count_result["total"] if total_count_result else 0
2762
+
2763
+ # Get units with limit
2764
+ param_count += 1
2294
2765
  units = await conn.fetch(
2295
2766
  f"""
2296
- SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
2767
+ SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type, tags, created_at, proof_count, source_memory_ids
2297
2768
  FROM {fq_table("memory_units")}
2298
2769
  {where_clause}
2299
2770
  ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
2300
- LIMIT 1000
2771
+ LIMIT ${param_count}
2301
2772
  """,
2302
2773
  *query_params,
2774
+ limit,
2303
2775
  )
2304
2776
 
2305
2777
  # Get links, filtering to only include links between units of the selected agent
2306
2778
  # Use DISTINCT ON with LEAST/GREATEST to deduplicate bidirectional links
2307
2779
  unit_ids = [row["id"] for row in units]
2308
- if unit_ids:
2780
+ unit_id_set = set(unit_ids)
2781
+
2782
+ # Collect source memory IDs from observations
2783
+ source_memory_ids = []
2784
+ for unit in units:
2785
+ if unit["source_memory_ids"]:
2786
+ source_memory_ids.extend(unit["source_memory_ids"])
2787
+ source_memory_ids = list(set(source_memory_ids)) # Deduplicate
2788
+
2789
+ # Fetch links involving both visible units AND source memories
2790
+ all_relevant_ids = unit_ids + source_memory_ids
2791
+ if all_relevant_ids:
2309
2792
  links = await conn.fetch(
2310
2793
  f"""
2311
2794
  SELECT DISTINCT ON (LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid))
@@ -2316,14 +2799,69 @@ class MemoryEngine(MemoryEngineInterface):
2316
2799
  e.canonical_name as entity_name
2317
2800
  FROM {fq_table("memory_links")} ml
2318
2801
  LEFT JOIN {fq_table("entities")} e ON ml.entity_id = e.id
2319
- WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
2802
+ WHERE ml.from_unit_id = ANY($1::uuid[]) OR ml.to_unit_id = ANY($1::uuid[])
2320
2803
  ORDER BY LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid), ml.weight DESC
2321
2804
  """,
2322
- unit_ids,
2805
+ all_relevant_ids,
2323
2806
  )
2324
2807
  else:
2325
2808
  links = []
2326
2809
 
2810
+ # Copy links from source memories to observations
2811
+ # Observations inherit links from their source memories via source_memory_ids
2812
+ # Build a map from source_id to observation_ids
2813
+ source_to_observations = {}
2814
+ for unit in units:
2815
+ if unit["source_memory_ids"]:
2816
+ for source_id in unit["source_memory_ids"]:
2817
+ if source_id not in source_to_observations:
2818
+ source_to_observations[source_id] = []
2819
+ source_to_observations[source_id].append(unit["id"])
2820
+
2821
+ copied_links = []
2822
+ for link in links:
2823
+ from_id = link["from_unit_id"]
2824
+ to_id = link["to_unit_id"]
2825
+
2826
+ # Get observations that should inherit this link
2827
+ from_observations = source_to_observations.get(from_id, [])
2828
+ to_observations = source_to_observations.get(to_id, [])
2829
+
2830
+ # If from_id is a source memory, copy links to its observations
2831
+ if from_observations:
2832
+ for obs_id in from_observations:
2833
+ # Only include if the target is visible
2834
+ if to_id in unit_id_set or to_observations:
2835
+ target = to_observations[0] if to_observations and to_id not in unit_id_set else to_id
2836
+ if target in unit_id_set:
2837
+ copied_links.append(
2838
+ {
2839
+ "from_unit_id": obs_id,
2840
+ "to_unit_id": target,
2841
+ "link_type": link["link_type"],
2842
+ "weight": link["weight"],
2843
+ "entity_name": link["entity_name"],
2844
+ }
2845
+ )
2846
+
2847
+ # If to_id is a source memory, copy links to its observations
2848
+ if to_observations and from_id in unit_id_set:
2849
+ for obs_id in to_observations:
2850
+ copied_links.append(
2851
+ {
2852
+ "from_unit_id": from_id,
2853
+ "to_unit_id": obs_id,
2854
+ "link_type": link["link_type"],
2855
+ "weight": link["weight"],
2856
+ "entity_name": link["entity_name"],
2857
+ }
2858
+ )
2859
+
2860
+ # Keep only direct links between visible nodes
2861
+ direct_links = [
2862
+ link for link in links if link["from_unit_id"] in unit_id_set and link["to_unit_id"] in unit_id_set
2863
+ ]
2864
+
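# Worked example of the link inheritance above (IDs shortened, illustrative only):
#   visible units: obs-1 (source_memory_ids=[mem-7]) and mem-9; mem-7 is not visible.
#   fetched link: mem-7 <-> mem-9
#   source_to_observations == {mem-7: [obs-1]}
#   copied_links gains obs-1 -> mem-9 (same link_type / weight / entity_name)
#   direct_links drops the original edge because mem-7 is not in unit_id_set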
2327
2865
  # Get entity information
2328
2866
  unit_entities = await conn.fetch(f"""
2329
2867
  SELECT ue.unit_id, e.canonical_name
@@ -2341,6 +2879,18 @@ class MemoryEngine(MemoryEngineInterface):
2341
2879
  entity_map[unit_id] = []
2342
2880
  entity_map[unit_id].append(entity_name)
2343
2881
 
2882
+ # For observations, inherit entities from source memories
2883
+ for unit in units:
2884
+ if unit["source_memory_ids"] and unit["id"] not in entity_map:
2885
+ # Collect entities from all source memories
2886
+ source_entities = []
2887
+ for source_id in unit["source_memory_ids"]:
2888
+ if source_id in entity_map:
2889
+ source_entities.extend(entity_map[source_id])
2890
+ if source_entities:
2891
+ # Deduplicate while preserving order
2892
+ entity_map[unit["id"]] = list(dict.fromkeys(source_entities))
2893
+
2344
2894
  # Build nodes
2345
2895
  nodes = []
2346
2896
  for row in units:
@@ -2374,14 +2924,15 @@ class MemoryEngine(MemoryEngineInterface):
2374
2924
  }
2375
2925
  )
2376
2926
 
2377
- # Build edges
2927
+ # Build edges (combine direct links and copied links from sources)
2378
2928
  edges = []
2379
- for row in links:
2929
+ all_links = direct_links + copied_links
2930
+ for row in all_links:
2380
2931
  from_id = str(row["from_unit_id"])
2381
2932
  to_id = str(row["to_unit_id"])
2382
2933
  link_type = row["link_type"]
2383
2934
  weight = row["weight"]
2384
- entity_name = row["entity_name"]
2935
+ entity_name = row.get("entity_name")
2385
2936
 
2386
2937
  # Color by link type
2387
2938
  if link_type == "temporal":
@@ -2433,10 +2984,13 @@ class MemoryEngine(MemoryEngineInterface):
2433
2984
  "document_id": row["document_id"],
2434
2985
  "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
2435
2986
  "fact_type": row["fact_type"],
2987
+ "tags": list(row["tags"]) if row["tags"] else [],
2988
+ "created_at": row["created_at"].isoformat() if row["created_at"] else None,
2989
+ "proof_count": row["proof_count"] if row["proof_count"] else None,
2436
2990
  }
2437
2991
  )
2438
2992
 
2439
- return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": len(units)}
2993
+ return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": total_count, "limit": limit}
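A rough usage sketch of the updated graph endpoint (inside an async context); `engine` and `ctx` are assumed. Note that `total_units` now reports the full matching count while `nodes` is capped by `limit`:

    # Illustrative only, not part of the diff.
    graph = await engine.get_graph_data(bank_id=bank_id, fact_type="world", limit=200, request_context=ctx)
    print(len(graph["nodes"]), graph["total_units"], graph["limit"])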
2440
2994
 
2441
2995
  async def list_memory_units(
2442
2996
  self,
@@ -2565,6 +3119,97 @@ class MemoryEngine(MemoryEngineInterface):
2565
3119
 
2566
3120
  return {"items": items, "total": total, "limit": limit, "offset": offset}
2567
3121
 
3122
+ async def get_memory_unit(
3123
+ self,
3124
+ bank_id: str,
3125
+ memory_id: str,
3126
+ request_context: "RequestContext",
3127
+ ):
3128
+ """
3129
+ Get a single memory unit by ID.
3130
+
3131
+ Args:
3132
+ bank_id: Bank ID
3133
+ memory_id: Memory unit ID
3134
+ request_context: Request context for authentication.
3135
+
3136
+ Returns:
3137
+ Dict with memory unit data or None if not found
3138
+ """
3139
+ await self._authenticate_tenant(request_context)
3140
+ pool = await self._get_pool()
3141
+ async with acquire_with_retry(pool) as conn:
3142
+ # Get the memory unit (include source_memory_ids for mental models)
3143
+ row = await conn.fetchrow(
3144
+ f"""
3145
+ SELECT id, text, context, event_date, occurred_start, occurred_end,
3146
+ mentioned_at, fact_type, document_id, chunk_id, tags, source_memory_ids
3147
+ FROM {fq_table("memory_units")}
3148
+ WHERE id = $1 AND bank_id = $2
3149
+ """,
3150
+ memory_id,
3151
+ bank_id,
3152
+ )
3153
+
3154
+ if not row:
3155
+ return None
3156
+
3157
+ # Get entity information
3158
+ entities_rows = await conn.fetch(
3159
+ f"""
3160
+ SELECT e.canonical_name
3161
+ FROM {fq_table("unit_entities")} ue
3162
+ JOIN {fq_table("entities")} e ON ue.entity_id = e.id
3163
+ WHERE ue.unit_id = $1
3164
+ """,
3165
+ row["id"],
3166
+ )
3167
+ entities = [r["canonical_name"] for r in entities_rows]
3168
+
3169
+ result = {
3170
+ "id": str(row["id"]),
3171
+ "text": row["text"],
3172
+ "context": row["context"] if row["context"] else "",
3173
+ "date": row["event_date"].isoformat() if row["event_date"] else "",
3174
+ "type": row["fact_type"],
3175
+ "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
3176
+ "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
3177
+ "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
3178
+ "entities": entities,
3179
+ "document_id": row["document_id"] if row["document_id"] else None,
3180
+ "chunk_id": str(row["chunk_id"]) if row["chunk_id"] else None,
3181
+ "tags": row["tags"] if row["tags"] else [],
3182
+ }
3183
+
3184
+ # For observations, include source_memory_ids and fetch source_memories
3185
+ if row["fact_type"] == "observation" and row["source_memory_ids"]:
3186
+ source_ids = row["source_memory_ids"]
3187
+ result["source_memory_ids"] = [str(sid) for sid in source_ids]
3188
+
3189
+ # Fetch source memories
3190
+ source_rows = await conn.fetch(
3191
+ f"""
3192
+ SELECT id, text, fact_type, context, occurred_start, mentioned_at
3193
+ FROM {fq_table("memory_units")}
3194
+ WHERE id = ANY($1::uuid[])
3195
+ ORDER BY mentioned_at DESC NULLS LAST
3196
+ """,
3197
+ source_ids,
3198
+ )
3199
+ result["source_memories"] = [
3200
+ {
3201
+ "id": str(r["id"]),
3202
+ "text": r["text"],
3203
+ "type": r["fact_type"],
3204
+ "context": r["context"],
3205
+ "occurred_start": r["occurred_start"].isoformat() if r["occurred_start"] else None,
3206
+ "mentioned_at": r["mentioned_at"].isoformat() if r["mentioned_at"] else None,
3207
+ }
3208
+ for r in source_rows
3209
+ ]
3210
+
3211
+ return result
3212
+
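A rough usage sketch of the new single-unit lookup (inside an async context); `engine` and `ctx` are assumed. For observation-type units the result also carries the consolidated sources:

    # Illustrative only, not part of the diff.
    unit = await engine.get_memory_unit(bank_id, memory_id, request_context=ctx)
    if unit and unit["type"] == "observation":
        for src in unit.get("source_memories", []):
            print(src["id"], src["text"][:60])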
2568
3213
  async def list_documents(
2569
3214
  self,
2570
3215
  bank_id: str,
@@ -2741,264 +3386,24 @@ class MemoryEngine(MemoryEngineInterface):
2741
3386
  "created_at": chunk["created_at"].isoformat() if chunk["created_at"] else "",
2742
3387
  }
2743
3388
 
2744
- async def _evaluate_opinion_update_async(
3389
+ # ==================== bank profile Methods ====================
3390
+
3391
+ async def get_bank_profile(
2745
3392
  self,
2746
- opinion_text: str,
2747
- opinion_confidence: float,
2748
- new_event_text: str,
2749
- entity_name: str,
2750
- ) -> dict[str, Any] | None:
3393
+ bank_id: str,
3394
+ *,
3395
+ request_context: "RequestContext",
3396
+ ) -> dict[str, Any]:
2751
3397
  """
2752
- Evaluate if an opinion should be updated based on a new event.
2753
-
2754
- Args:
2755
- opinion_text: Current opinion text (includes reasons)
2756
- opinion_confidence: Current confidence score (0.0-1.0)
2757
- new_event_text: Text of the new event
2758
- entity_name: Name of the entity this opinion is about
2759
-
2760
- Returns:
2761
- Dict with 'action' ('keep'|'update'), 'new_confidence', 'new_text' (if action=='update')
2762
- or None if no changes needed
2763
- """
2764
-
2765
- class OpinionEvaluation(BaseModel):
2766
- """Evaluation of whether an opinion should be updated."""
2767
-
2768
- action: str = Field(description="Action to take: 'keep' (no change) or 'update' (modify opinion)")
2769
- reasoning: str = Field(description="Brief explanation of why this action was chosen")
2770
- new_confidence: float = Field(
2771
- description="New confidence score (0.0-1.0). Can be higher, lower, or same as before."
2772
- )
2773
- new_opinion_text: str | None = Field(
2774
- default=None,
2775
- description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None.",
2776
- )
2777
-
2778
- evaluation_prompt = f"""You are evaluating whether an existing opinion should be updated based on new information.
2779
-
2780
- ENTITY: {entity_name}
2781
-
2782
- EXISTING OPINION:
2783
- {opinion_text}
2784
- Current confidence: {opinion_confidence:.2f}
2785
-
2786
- NEW EVENT:
2787
- {new_event_text}
2788
-
2789
- Evaluate whether this new event:
2790
- 1. REINFORCES the opinion (increase confidence, keep text)
2791
- 2. WEAKENS the opinion (decrease confidence, keep text)
2792
- 3. CHANGES the opinion (update both text and confidence, noting "Previously I thought X, but now Y...")
2793
- 4. IRRELEVANT (keep everything as is)
2794
-
2795
- Guidelines:
2796
- - Only suggest 'update' action if the new event genuinely contradicts or significantly modifies the opinion
2797
- - If updating the text, acknowledge the previous opinion and explain the change
2798
- - Confidence should reflect accumulated evidence (0.0 = no confidence, 1.0 = very confident)
2799
- - Small changes in confidence are normal; large jumps should be rare"""
2800
-
2801
- try:
2802
- result = await self._llm_config.call(
2803
- messages=[
2804
- {"role": "system", "content": "You evaluate and update opinions based on new information."},
2805
- {"role": "user", "content": evaluation_prompt},
2806
- ],
2807
- response_format=OpinionEvaluation,
2808
- scope="memory_evaluate_opinion",
2809
- temperature=0.3, # Lower temperature for more consistent evaluation
2810
- )
2811
-
2812
- # Only return updates if something actually changed
2813
- if result.action == "keep" and abs(result.new_confidence - opinion_confidence) < 0.01:
2814
- return None
2815
-
2816
- return {
2817
- "action": result.action,
2818
- "reasoning": result.reasoning,
2819
- "new_confidence": result.new_confidence,
2820
- "new_text": result.new_opinion_text if result.action == "update" else None,
2821
- }
2822
-
2823
- except Exception as e:
2824
- logger.warning(f"Failed to evaluate opinion update: {str(e)}")
2825
- return None
2826
-
2827
- async def _handle_form_opinion(self, task_dict: dict[str, Any]):
2828
- """
2829
- Handler for form opinion tasks.
2830
-
2831
- Args:
2832
- task_dict: Dict with keys: 'bank_id', 'answer_text', 'query', 'tenant_id'
2833
- """
2834
- bank_id = task_dict["bank_id"]
2835
- answer_text = task_dict["answer_text"]
2836
- query = task_dict["query"]
2837
- tenant_id = task_dict.get("tenant_id")
2838
-
2839
- await self._extract_and_store_opinions_async(
2840
- bank_id=bank_id, answer_text=answer_text, query=query, tenant_id=tenant_id
2841
- )
2842
-
2843
- async def _handle_reinforce_opinion(self, task_dict: dict[str, Any]):
2844
- """
2845
- Handler for reinforce opinion tasks.
2846
-
2847
- Args:
2848
- task_dict: Dict with keys: 'bank_id', 'created_unit_ids', 'unit_texts', 'unit_entities'
2849
- """
2850
- bank_id = task_dict["bank_id"]
2851
- created_unit_ids = task_dict["created_unit_ids"]
2852
- unit_texts = task_dict["unit_texts"]
2853
- unit_entities = task_dict["unit_entities"]
2854
-
2855
- await self._reinforce_opinions_async(
2856
- bank_id=bank_id, created_unit_ids=created_unit_ids, unit_texts=unit_texts, unit_entities=unit_entities
2857
- )
2858
-
2859
- async def _reinforce_opinions_async(
2860
- self,
2861
- bank_id: str,
2862
- created_unit_ids: list[str],
2863
- unit_texts: list[str],
2864
- unit_entities: list[list[dict[str, str]]],
2865
- ):
2866
- """
2867
- Background task to reinforce opinions based on newly ingested events.
2868
-
2869
- This runs asynchronously and does not block the put operation.
2870
-
2871
- Args:
2872
- bank_id: bank ID
2873
- created_unit_ids: List of newly created memory unit IDs
2874
- unit_texts: Texts of the newly created units
2875
- unit_entities: Entities extracted from each unit
2876
- """
2877
- try:
2878
- # Extract all unique entity names from the new units
2879
- entity_names = set()
2880
- for entities_list in unit_entities:
2881
- for entity in entities_list:
2882
- # Handle both Entity objects and dicts
2883
- if hasattr(entity, "text"):
2884
- entity_names.add(entity.text)
2885
- elif isinstance(entity, dict):
2886
- entity_names.add(entity["text"])
2887
-
2888
- if not entity_names:
2889
- return
2890
-
2891
- pool = await self._get_pool()
2892
- async with acquire_with_retry(pool) as conn:
2893
- # Find all opinions related to these entities
2894
- opinions = await conn.fetch(
2895
- f"""
2896
- SELECT DISTINCT mu.id, mu.text, mu.confidence_score, e.canonical_name
2897
- FROM {fq_table("memory_units")} mu
2898
- JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
2899
- JOIN {fq_table("entities")} e ON ue.entity_id = e.id
2900
- WHERE mu.bank_id = $1
2901
- AND mu.fact_type = 'opinion'
2902
- AND e.canonical_name = ANY($2::text[])
2903
- """,
2904
- bank_id,
2905
- list(entity_names),
2906
- )
2907
-
2908
- if not opinions:
2909
- return
2910
-
2911
- # Use cached LLM config
2912
- if self._llm_config is None:
2913
- logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
2914
- return
2915
-
2916
- # Evaluate each opinion against the new events
2917
- updates_to_apply = []
2918
- for opinion in opinions:
2919
- opinion_id = str(opinion["id"])
2920
- opinion_text = opinion["text"]
2921
- opinion_confidence = opinion["confidence_score"]
2922
- entity_name = opinion["canonical_name"]
2923
-
2924
- # Find all new events mentioning this entity
2925
- relevant_events = []
2926
- for unit_text, entities_list in zip(unit_texts, unit_entities):
2927
- if any(e["text"] == entity_name for e in entities_list):
2928
- relevant_events.append(unit_text)
2929
-
2930
- if not relevant_events:
2931
- continue
2932
-
2933
- # Combine all relevant events
2934
- combined_events = "\n".join(relevant_events)
2935
-
2936
- # Evaluate if opinion should be updated
2937
- evaluation = await self._evaluate_opinion_update_async(
2938
- opinion_text, opinion_confidence, combined_events, entity_name
2939
- )
2940
-
2941
- if evaluation:
2942
- updates_to_apply.append({"opinion_id": opinion_id, "evaluation": evaluation})
2943
-
2944
- # Apply all updates in a single transaction
2945
- if updates_to_apply:
2946
- async with conn.transaction():
2947
- for update in updates_to_apply:
2948
- opinion_id = update["opinion_id"]
2949
- evaluation = update["evaluation"]
2950
-
2951
- if evaluation["action"] == "update" and evaluation["new_text"]:
2952
- # Update both text and confidence
2953
- await conn.execute(
2954
- f"""
2955
- UPDATE {fq_table("memory_units")}
2956
- SET text = $1, confidence_score = $2, updated_at = NOW()
2957
- WHERE id = $3
2958
- """,
2959
- evaluation["new_text"],
2960
- evaluation["new_confidence"],
2961
- uuid.UUID(opinion_id),
2962
- )
2963
- else:
2964
- # Only update confidence
2965
- await conn.execute(
2966
- f"""
2967
- UPDATE {fq_table("memory_units")}
2968
- SET confidence_score = $1, updated_at = NOW()
2969
- WHERE id = $2
2970
- """,
2971
- evaluation["new_confidence"],
2972
- uuid.UUID(opinion_id),
2973
- )
2974
-
2975
- else:
2976
- pass # No opinions to update
2977
-
2978
- except Exception as e:
2979
- logger.error(f"[REINFORCE] Error during opinion reinforcement: {str(e)}")
2980
- import traceback
2981
-
2982
- traceback.print_exc()
2983
-
2984
- # ==================== bank profile Methods ====================
2985
-
2986
- async def get_bank_profile(
2987
- self,
2988
- bank_id: str,
2989
- *,
2990
- request_context: "RequestContext",
2991
- ) -> dict[str, Any]:
2992
- """
2993
- Get bank profile (name, disposition + background).
2994
- Auto-creates agent with default values if not exists.
3398
+ Get bank profile (name, disposition + mission).
3399
+ Auto-creates agent with default values if not exists.
2995
3400
 
2996
3401
  Args:
2997
3402
  bank_id: bank IDentifier
2998
3403
  request_context: Request context for authentication.
2999
3404
 
3000
3405
  Returns:
3001
- Dict with name, disposition traits, and background
3406
+ Dict with name, disposition traits, and mission
3002
3407
  """
3003
3408
  await self._authenticate_tenant(request_context)
3004
3409
  pool = await self._get_pool()
@@ -3008,7 +3413,7 @@ Guidelines:
3008
3413
  "bank_id": bank_id,
3009
3414
  "name": profile["name"],
3010
3415
  "disposition": disposition,
3011
- "background": profile["background"],
3416
+ "mission": profile["mission"],
3012
3417
  }
3013
3418
 
3014
3419
  async def update_bank_disposition(
@@ -3030,31 +3435,51 @@ Guidelines:
3030
3435
  pool = await self._get_pool()
3031
3436
  await bank_utils.update_bank_disposition(pool, bank_id, disposition)
3032
3437
 
3033
- async def merge_bank_background(
3438
+ async def set_bank_mission(
3439
+ self,
3440
+ bank_id: str,
3441
+ mission: str,
3442
+ *,
3443
+ request_context: "RequestContext",
3444
+ ) -> dict[str, Any]:
3445
+ """
3446
+ Set the mission for a bank.
3447
+
3448
+ Args:
3449
+ bank_id: bank IDentifier
3450
+ mission: The mission text
3451
+ request_context: Request context for authentication.
3452
+
3453
+ Returns:
3454
+ Dict with bank_id and mission.
3455
+ """
3456
+ await self._authenticate_tenant(request_context)
3457
+ pool = await self._get_pool()
3458
+ await bank_utils.set_bank_mission(pool, bank_id, mission)
3459
+ return {"bank_id": bank_id, "mission": mission}
3460
+
3461
+ async def merge_bank_mission(
3034
3462
  self,
3035
3463
  bank_id: str,
3036
3464
  new_info: str,
3037
3465
  *,
3038
- update_disposition: bool = True,
3039
3466
  request_context: "RequestContext",
3040
3467
  ) -> dict[str, Any]:
3041
3468
  """
3042
- Merge new background information with existing background using LLM.
3469
+ Merge new mission information with existing mission using LLM.
3043
3470
  Normalizes to first person ("I") and resolves conflicts.
3044
- Optionally infers disposition traits from the merged background.
3045
3471
 
3046
3472
  Args:
3047
3473
  bank_id: bank IDentifier
3048
- new_info: New background information to add/merge
3049
- update_disposition: If True, infer Big Five traits from background (default: True)
3474
+ new_info: New mission information to add/merge
3050
3475
  request_context: Request context for authentication.
3051
3476
 
3052
3477
  Returns:
3053
- Dict with 'background' (str) and optionally 'disposition' (dict) keys
3478
+ Dict with 'mission' (str) key
3054
3479
  """
3055
3480
  await self._authenticate_tenant(request_context)
3056
3481
  pool = await self._get_pool()
3057
- return await bank_utils.merge_bank_background(pool, self._llm_config, bank_id, new_info, update_disposition)
3482
+ return await bank_utils.merge_bank_mission(pool, self._reflect_llm_config, bank_id, new_info)
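A rough usage sketch of the two mission entry points (inside an async context); `engine`, `ctx`, and the example strings are assumptions:

    # Illustrative only, not part of the diff.
    await engine.set_bank_mission(bank_id, "Track competitor pricing weekly", request_context=ctx)
    merged = await engine.merge_bank_mission(bank_id, "Also cover EU markets", request_context=ctx)
    print(merged["mission"])  # first-person, conflict-resolved mission text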
3058
3483
 
3059
3484
  async def list_banks(
3060
3485
  self,
@@ -3068,7 +3493,7 @@ Guidelines:
3068
3493
  request_context: Request context for authentication.
3069
3494
 
3070
3495
  Returns:
3071
- List of dicts with bank_id, name, disposition, background, created_at, updated_at
3496
+ List of dicts with bank_id, name, disposition, mission, created_at, updated_at
3072
3497
  """
3073
3498
  await self._authenticate_tenant(request_context)
3074
3499
  pool = await self._get_pool()
@@ -3086,35 +3511,44 @@ Guidelines:
3086
3511
  max_tokens: int = 4096,
3087
3512
  response_schema: dict | None = None,
3088
3513
  request_context: "RequestContext",
3514
+ tags: list[str] | None = None,
3515
+ tags_match: TagsMatch = "any",
3516
+ exclude_mental_model_ids: list[str] | None = None,
3089
3517
  ) -> ReflectResult:
3090
3518
  """
3091
- Reflect and formulate an answer using bank identity, world facts, and opinions.
3519
+ Reflect and formulate an answer using an agentic loop with tools.
3092
3520
 
3093
- This method:
3094
- 1. Retrieves experience (conversations and events)
3095
- 2. Retrieves world facts (general knowledge)
3096
- 3. Retrieves existing opinions (bank's formed perspectives)
3097
- 4. Uses LLM to formulate an answer
3098
- 5. Extracts and stores any new opinions formed during reflection
3099
- 6. Optionally generates structured output based on response_schema
3100
- 7. Returns plain text answer and the facts used
3521
+ The reflect agent iteratively uses tools to:
3522
+ 1. lookup: Get mental models (synthesized knowledge)
3523
+ 2. recall: Search facts (semantic + temporal retrieval)
3524
+ 3. learn: Create/update mental models with new insights
3525
+ 4. expand: Get chunk/document context for memories
3526
+
3527
+ The agent starts with empty context and must call tools to gather
3528
+ information. On the last iteration, tools are removed to force a
3529
+ final text response.
3101
3530
 
3102
3531
  Args:
3103
3532
  bank_id: bank identifier
3104
3533
  query: Question to answer
3105
- budget: Budget level for memory exploration (low=100, mid=300, high=600 units)
3106
- context: Additional context string to include in LLM prompt (not used in recall)
3107
- response_schema: Optional JSON Schema for structured output
3534
+ budget: Budget level; scales the agent's iteration ceiling (low=0.5x, mid=1x, high=2x of reflect_max_iterations)
3535
+ context: Additional context string to include in agent prompt
3536
+ max_tokens: Max completion tokens, forwarded to the reflect agent
3537
+ response_schema: Optional JSON Schema for structured output, forwarded to the reflect agent
3538
+ tags: Optional tags to filter memories
3539
+ tags_match: How to match tags - "any" (OR), "all" (AND)
3540
+ exclude_mental_model_ids: Optional list of mental model IDs to exclude from search
3541
+ (used when refreshing a mental model to avoid circular reference)
3108
3542
 
3109
3543
  Returns:
3110
3544
  ReflectResult containing:
3111
- - text: Plain text answer (no markdown)
3112
- - based_on: Dict with 'world', 'experience', and 'opinion' fact lists (MemoryFact objects)
3113
- - new_opinions: List of newly formed opinions
3114
- - structured_output: Optional dict if response_schema was provided
3545
+ - text: Plain text answer
3546
+ - based_on: Facts and mental models the agent reported using, keyed by fact type
3547
+ - new_opinions: Empty list
3548
+ - structured_output: Structured output from the agent when response_schema is provided (may be None)
3115
3549
  """
3116
3550
  # Use cached LLM config
3117
- if self._llm_config is None:
3551
+ if self._reflect_llm_config is None:
3118
3552
  raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
3119
3553
 
3120
3554
  # Authenticate tenant and set schema in context (for fq_table())
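The hunks below replace the single-shot prompt flow with the agentic loop. A rough usage sketch (inside an async context); the method name `reflect_async` is a guess since the signature's opening line is outside this hunk, and `engine`, `ctx`, `Budget`, and the example tag are assumptions:

    # Illustrative only, not part of the diff.
    result = await engine.reflect_async(
        bank_id,
        "What changed in the EU rollout last week?",
        budget=Budget.HIGH,          # scales the agent's iteration ceiling
        tags=["project:rollout"],
        tags_match="any",
        request_context=ctx,
    )
    print(result.text)
    for call in result.tool_trace:
        print(call.tool, call.duration_ms)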
@@ -3135,121 +3569,312 @@ Guidelines:
3135
3569
 
3136
3570
  reflect_start = time.time()
3137
3571
  reflect_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
3138
- log_buffer = []
3139
- log_buffer.append(f"[REFLECT {reflect_id}] Query: '{query[:50]}...'")
3572
+ tags_info = f", tags={tags} ({tags_match})" if tags else ""
3573
+ logger.info(f"[REFLECT {reflect_id}] Starting agentic reflect for query: {query[:50]}...{tags_info}")
3140
3574
 
3141
- # Steps 1-3: Run multi-fact-type search (12-way retrieval: 4 methods × 3 fact types)
3142
- recall_start = time.time()
3143
- search_result = await self.recall_async(
3144
- bank_id=bank_id,
3145
- query=query,
3146
- budget=budget,
3147
- max_tokens=4096,
3148
- enable_trace=False,
3149
- fact_type=["experience", "world", "opinion"],
3150
- include_entities=True,
3151
- request_context=request_context,
3152
- )
3153
- recall_time = time.time() - recall_start
3575
+ # Get bank profile for agent identity
3576
+ profile = await self.get_bank_profile(bank_id, request_context=request_context)
3577
+
3578
+ # NOTE: Mental models are NOT pre-loaded to keep the initial prompt small.
3579
+ # The agent can call lookup() to list available models if needed.
3580
+ # This is critical for banks with many mental models to avoid huge prompts.
3581
+
3582
+ # Compute max iterations based on budget
3583
+ config = get_config()
3584
+ base_max_iterations = config.reflect_max_iterations
3585
+ # Budget multipliers: low=0.5x, mid=1x, high=2x
3586
+ budget_multipliers = {Budget.LOW: 0.5, Budget.MID: 1.0, Budget.HIGH: 2.0}
3587
+ effective_budget = budget or Budget.LOW
3588
+ max_iterations = max(1, int(base_max_iterations * budget_multipliers.get(effective_budget, 1.0)))
3589
+
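# Worked example of the budget scaling above (reflect_max_iterations = 6 is
# illustrative, not necessarily the shipped default):
#   low  -> max(1, int(6 * 0.5)) = 3 iterations
#   mid  -> max(1, int(6 * 1.0)) = 6 iterations
#   high -> max(1, int(6 * 2.0)) = 12 iterations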
3590
+ # Run agentic loop - acquire connections only when needed for DB operations
3591
+ # (not held during LLM calls which can be slow)
3592
+ pool = await self._get_pool()
3154
3593
 
3155
- all_results = search_result.results
3594
+ # Get bank stats for freshness info
3595
+ bank_stats = await self.get_bank_stats(bank_id, request_context=request_context)
3596
+ last_consolidated_at = bank_stats.last_consolidated_at if hasattr(bank_stats, "last_consolidated_at") else None
3597
+ pending_consolidation = bank_stats.pending_consolidation if hasattr(bank_stats, "pending_consolidation") else 0
3156
3598
 
3157
- # Split results by fact type for structured response
3158
- agent_results = [r for r in all_results if r.fact_type == "experience"]
3159
- world_results = [r for r in all_results if r.fact_type == "world"]
3160
- opinion_results = [r for r in all_results if r.fact_type == "opinion"]
3599
+ # Create tool callbacks that acquire connections only when needed
3600
+ from .retain import embedding_utils
3161
3601
 
3162
- log_buffer.append(
3163
- f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s"
3602
+ async def search_mental_models_fn(q: str, max_results: int = 5) -> dict[str, Any]:
3603
+ # Generate embedding for the query
3604
+ embeddings = await embedding_utils.generate_embeddings_batch(self.embeddings, [q])
3605
+ query_embedding = embeddings[0]
3606
+ async with pool.acquire() as conn:
3607
+ return await tool_search_mental_models(
3608
+ conn,
3609
+ bank_id,
3610
+ q,
3611
+ query_embedding,
3612
+ max_results=max_results,
3613
+ tags=tags,
3614
+ tags_match=tags_match,
3615
+ exclude_ids=exclude_mental_model_ids,
3616
+ )
3617
+
3618
+ async def search_observations_fn(q: str, max_tokens: int = 5000) -> dict[str, Any]:
3619
+ return await tool_search_observations(
3620
+ self,
3621
+ bank_id,
3622
+ q,
3623
+ request_context,
3624
+ max_tokens=max_tokens,
3625
+ tags=tags,
3626
+ tags_match=tags_match,
3627
+ last_consolidated_at=last_consolidated_at,
3628
+ pending_consolidation=pending_consolidation,
3629
+ )
3630
+
3631
+ async def recall_fn(q: str, max_tokens: int = 4096) -> dict[str, Any]:
3632
+ return await tool_recall(
3633
+ self, bank_id, q, request_context, max_tokens=max_tokens, tags=tags, tags_match=tags_match
3634
+ )
3635
+
3636
+ async def expand_fn(memory_ids: list[str], depth: str) -> dict[str, Any]:
3637
+ async with pool.acquire() as conn:
3638
+ return await tool_expand(conn, bank_id, memory_ids, depth)
3639
+
3640
+ # Load directives from the dedicated directives table
3641
+ # Directives are hard rules that must be followed in all responses
3642
+ directives_raw = await self.list_directives(
3643
+ bank_id=bank_id,
3644
+ tags=tags,
3645
+ tags_match=tags_match,
3646
+ active_only=True,
3647
+ request_context=request_context,
3164
3648
  )
3649
+ # Convert directive format to the expected format for reflect agent
3650
+ # The agent expects: name, description (optional), observations (list of {title, content})
3651
+ directives = [
3652
+ {
3653
+ "name": d["name"],
3654
+ "description": d["content"], # Use content as description
3655
+ "observations": [], # Directives use content directly, not observations
3656
+ }
3657
+ for d in directives_raw
3658
+ ]
3659
+ if directives:
3660
+ logger.info(f"[REFLECT {reflect_id}] Loaded {len(directives)} directives")
3165
3661
 
3166
- # Format facts for LLM
3167
- agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
3168
- world_facts_text = think_utils.format_facts_for_prompt(world_results)
3169
- opinion_facts_text = think_utils.format_facts_for_prompt(opinion_results)
3662
+ # Check if the bank has any mental models
3663
+ async with pool.acquire() as conn:
3664
+ mental_model_count = await conn.fetchval(
3665
+ f"SELECT COUNT(*) FROM {fq_table('mental_models')} WHERE bank_id = $1",
3666
+ bank_id,
3667
+ )
3668
+ has_mental_models = mental_model_count > 0
3669
+ if has_mental_models:
3670
+ logger.info(f"[REFLECT {reflect_id}] Bank has {mental_model_count} mental models")
3170
3671
 
3171
- # Get bank profile (name, disposition + background)
3172
- profile = await self.get_bank_profile(bank_id, request_context=request_context)
3173
- name = profile["name"]
3174
- disposition = profile["disposition"] # Typed as DispositionTraits
3175
- background = profile["background"]
3176
-
3177
- # Build the prompt
3178
- prompt = think_utils.build_think_prompt(
3179
- agent_facts_text=agent_facts_text,
3180
- world_facts_text=world_facts_text,
3181
- opinion_facts_text=opinion_facts_text,
3672
+ # Run the agent
3673
+ agent_result = await run_reflect_agent(
3674
+ llm_config=self._reflect_llm_config,
3675
+ bank_id=bank_id,
3182
3676
  query=query,
3183
- name=name,
3184
- disposition=disposition,
3185
- background=background,
3677
+ bank_profile=profile,
3678
+ search_mental_models_fn=search_mental_models_fn,
3679
+ search_observations_fn=search_observations_fn,
3680
+ recall_fn=recall_fn,
3681
+ expand_fn=expand_fn,
3186
3682
  context=context,
3683
+ max_iterations=max_iterations,
3684
+ max_tokens=max_tokens,
3685
+ response_schema=response_schema,
3686
+ directives=directives,
3687
+ has_mental_models=has_mental_models,
3688
+ budget=effective_budget,
3187
3689
  )
3188
3690
 
3189
- log_buffer.append(f"[REFLECT {reflect_id}] Prompt: {len(prompt)} chars")
3190
-
3191
- system_message = think_utils.get_system_message(disposition)
3192
- messages = [{"role": "system", "content": system_message}, {"role": "user", "content": prompt}]
3193
-
3194
- # Prepare response_format if schema provided
3195
- response_format = None
3196
- if response_schema is not None:
3197
- # Wrapper class to provide Pydantic-like interface for raw JSON schemas
3198
- class JsonSchemaWrapper:
3199
- def __init__(self, schema: dict):
3200
- self._schema = schema
3201
-
3202
- def model_json_schema(self):
3203
- return self._schema
3204
-
3205
- response_format = JsonSchemaWrapper(response_schema)
3206
-
3207
- llm_start = time.time()
3208
- result = await self._llm_config.call(
3209
- messages=messages,
3210
- scope="memory_reflect",
3211
- max_completion_tokens=max_tokens,
3212
- response_format=response_format,
3213
- skip_validation=True if response_format else False,
3214
- # Don't enforce strict_schema - not all providers support it and may retry forever
3215
- # Soft enforcement (schema in prompt + json_object mode) is sufficient
3216
- strict_schema=False,
3691
+ total_time = time.time() - reflect_start
3692
+ logger.info(
3693
+ f"[REFLECT {reflect_id}] Complete: {len(agent_result.text)} chars, "
3694
+ f"{agent_result.iterations} iterations, {agent_result.tools_called} tool calls | {total_time:.3f}s"
3217
3695
  )
3218
- llm_time = time.time() - llm_start
3219
3696
 
3220
- # Handle response based on whether structured output was requested
3221
- if response_schema is not None:
3222
- structured_output = result
3223
- answer_text = "" # Empty for backward compatibility
3224
- log_buffer.append(f"[REFLECT {reflect_id}] Structured output generated")
3225
- else:
3226
- structured_output = None
3227
- answer_text = result.strip()
3697
+ # Convert agent tool trace to ToolCallTrace objects
3698
+ tool_trace_result = [
3699
+ ToolCallTrace(
3700
+ tool=tc.tool,
3701
+ reason=tc.reason,
3702
+ input=tc.input,
3703
+ output=tc.output,
3704
+ duration_ms=tc.duration_ms,
3705
+ iteration=tc.iteration,
3706
+ )
3707
+ for tc in agent_result.tool_trace
3708
+ ]
3228
3709
 
3229
- # Submit form_opinion task for background processing
3230
- # Pass tenant_id from request context for internal authentication in background task
3231
- await self._task_backend.submit_task(
3232
- {
3233
- "type": "form_opinion",
3234
- "bank_id": bank_id,
3235
- "answer_text": answer_text,
3236
- "query": query,
3237
- "tenant_id": getattr(request_context, "tenant_id", None) if request_context else None,
3238
- }
3239
- )
3710
+ # Convert agent LLM trace to LLMCallTrace objects
3711
+ llm_trace_result = [LLMCallTrace(scope=lc.scope, duration_ms=lc.duration_ms) for lc in agent_result.llm_trace]
3712
+
3713
+ # Extract memories from recall tool outputs - only include memories the agent actually used
3714
+ # agent_result.used_memory_ids contains validated IDs from the done action
3715
+ used_memory_ids_set = set(agent_result.used_memory_ids) if agent_result.used_memory_ids else set()
3716
+ based_on: dict[str, list[MemoryFact]] = {"world": [], "experience": [], "opinion": [], "observation": []}
3717
+ seen_memory_ids: set[str] = set()
3718
+ for tc in agent_result.tool_trace:
3719
+ if tc.tool == "recall" and "memories" in tc.output:
3720
+ for memory_data in tc.output["memories"]:
3721
+ memory_id = memory_data.get("id")
3722
+ # Only include memories that the agent declared as used (or all if none specified)
3723
+ if memory_id and memory_id not in seen_memory_ids:
3724
+ if used_memory_ids_set and memory_id not in used_memory_ids_set:
3725
+ continue # Skip memories not actually used by the agent
3726
+ seen_memory_ids.add(memory_id)
3727
+ fact_type = memory_data.get("type", "world")
3728
+ if fact_type in based_on:
3729
+ based_on[fact_type].append(
3730
+ MemoryFact(
3731
+ id=memory_id,
3732
+ text=memory_data.get("text", ""),
3733
+ fact_type=fact_type,
3734
+ context=None,
3735
+ occurred_start=memory_data.get("occurred"),
3736
+ occurred_end=memory_data.get("occurred"),
3737
+ )
3738
+ )
3240
3739
 
3241
- total_time = time.time() - reflect_start
3242
- log_buffer.append(
3243
- f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s"
3740
+ # Extract mental models from tool outputs - only include models the agent actually used
3741
+ # agent_result.used_mental_model_ids contains validated IDs from the done action
3742
+ used_model_ids_set = set(agent_result.used_mental_model_ids) if agent_result.used_mental_model_ids else set()
3743
+ based_on["mental-models"] = []
3744
+ seen_model_ids: set[str] = set()
3745
+ for tc in agent_result.tool_trace:
3746
+ if tc.tool == "get_mental_model":
3747
+ # Single model lookup (with full details)
3748
+ if tc.output.get("found") and "model" in tc.output:
3749
+ model = tc.output["model"]
3750
+ model_id = model.get("id")
3751
+ if model_id and model_id not in seen_model_ids:
3752
+ # Only include models that the agent declared as used (or all if none specified)
3753
+ if used_model_ids_set and model_id not in used_model_ids_set:
3754
+ continue # Skip models not actually used by the agent
3755
+ seen_model_ids.add(model_id)
3756
+ # Add to based_on as MemoryFact with type "mental-models"
3757
+ model_name = model.get("name", "")
3758
+ model_summary = model.get("summary") or model.get("description", "")
3759
+ based_on["mental-models"].append(
3760
+ MemoryFact(
3761
+ id=model_id,
3762
+ text=f"{model_name}: {model_summary}",
3763
+ fact_type="mental-models",
3764
+ context=f"{model.get('type', 'concept')} ({model.get('subtype', 'structural')})",
3765
+ occurred_start=None,
3766
+ occurred_end=None,
3767
+ )
3768
+ )
3769
+ elif tc.tool == "search_mental_models":
3770
+ # Search mental models - include all returned models (filtered by used_model_ids_set if specified)
3771
+ for model in tc.output.get("mental_models", []):
3772
+ model_id = model.get("id")
3773
+ if model_id and model_id not in seen_model_ids:
3774
+ # Only include models that the agent declared as used (or all if none specified)
3775
+ if used_model_ids_set and model_id not in used_model_ids_set:
3776
+ continue # Skip models not actually used by the agent
3777
+ seen_model_ids.add(model_id)
3778
+ # Add to based_on as MemoryFact with type "mental-models"
3779
+ model_name = model.get("name", "")
3780
+ model_summary = model.get("summary") or model.get("description", "")
3781
+ based_on["mental-models"].append(
3782
+ MemoryFact(
3783
+ id=model_id,
3784
+ text=f"{model_name}: {model_summary}",
3785
+ fact_type="mental-models",
3786
+ context=f"{model.get('type', 'concept')} ({model.get('subtype', 'structural')})",
3787
+ occurred_start=None,
3788
+ occurred_end=None,
3789
+ )
3790
+ )
3791
3816
+ # List all models lookup - don't add to based_on (too verbose, just a listing)
3817
+
3818
+ # Add directives to based_on["mental-models"] (they are mental models with subtype='directive')
3819
+ for directive in directives:
3820
+ # Extract summary from observations
3821
+ summary_parts: list[str] = []
3822
+ for obs in directive.get("observations", []):
3823
+ # Support both Pydantic Observation objects and dicts
3824
+ if hasattr(obs, "content"):
3825
+ content = obs.content
3826
+ title = obs.title
3827
+ else:
3828
+ content = obs.get("content", "")
3829
+ title = obs.get("title", "")
3830
+ if title and content:
3831
+ summary_parts.append(f"{title}: {content}")
3832
+ elif content:
3833
+ summary_parts.append(content)
3834
+
3835
+ # Fallback to description if no observations
3836
+ if not summary_parts and directive.get("description"):
3837
+ summary_parts.append(directive["description"])
3838
+
3839
+ directive_name = directive.get("name", "")
3840
+ directive_summary = "; ".join(summary_parts) if summary_parts else ""
3841
+ based_on["mental-models"].append(
3842
+ MemoryFact(
3843
+ id=directive.get("id", ""),
3844
+ text=f"{directive_name}: {directive_summary}",
3845
+ fact_type="mental-models",
3846
+ context="directive (directive)",
3847
+ occurred_start=None,
3848
+ occurred_end=None,
3849
+ )
3850
+ )
3851
+
3852
+ # Build directives_applied from agent result
3853
+ from hindsight_api.engine.response_models import DirectiveRef
3854
+
3855
+ directives_applied_result = [
3856
+ DirectiveRef(id=d.id, name=d.name, content=d.content) for d in agent_result.directives_applied
3857
+ ]
3858
+
3859
+ # Convert agent usage to TokenUsage format
3860
+ from hindsight_api.engine.response_models import TokenUsage
3861
+
3862
+ usage = TokenUsage(
3863
+ input_tokens=agent_result.usage.input_tokens,
3864
+ output_tokens=agent_result.usage.output_tokens,
3865
+ total_tokens=agent_result.usage.total_tokens,
3244
3866
  )
3245
- logger.info("\n" + "\n".join(log_buffer))
3246
3867
 
3247
- # Return response with facts split by type
3868
+ # Return response (compatible with existing API)
3248
3869
  result = ReflectResult(
3249
- text=answer_text,
3250
- based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
3251
- new_opinions=[], # Opinions are being extracted asynchronously
3252
- structured_output=structured_output,
3870
+ text=agent_result.text,
3871
+ based_on=based_on,
3872
+ new_opinions=[], # Learnings stored as mental models
3873
+ structured_output=agent_result.structured_output,
3874
+ usage=usage,
3875
+ tool_trace=tool_trace_result,
3876
+ llm_trace=llm_trace_result,
3877
+ directives_applied=directives_applied_result,
3253
3878
  )
3254
3879
 
3255
3880
  # Call post-operation hook if validator is configured
@@ -3273,48 +3898,6 @@ Guidelines:
3273
3898
 
3274
3899
  return result
3275
3900
 
3276
- async def _extract_and_store_opinions_async(
3277
- self, bank_id: str, answer_text: str, query: str, tenant_id: str | None = None
3278
- ):
3279
- """
3280
- Background task to extract and store opinions from think response.
3281
-
3282
- This runs asynchronously and does not block the think response.
3283
-
3284
- Args:
3285
- bank_id: bank IDentifier
3286
- answer_text: The generated answer text
3287
- query: The original query
3288
- tenant_id: Tenant identifier for internal authentication
3289
- """
3290
- try:
3291
- # Extract opinions from the answer
3292
- new_opinions = await think_utils.extract_opinions_from_text(self._llm_config, text=answer_text, query=query)
3293
-
3294
- # Store new opinions
3295
- if new_opinions:
3296
- from datetime import datetime
3297
-
3298
- current_time = datetime.now(UTC)
3299
- # Use internal context with tenant_id for background authentication
3300
- # Extension can check internal=True to bypass normal auth
3301
- from hindsight_api.models import RequestContext
3302
-
3303
- internal_context = RequestContext(tenant_id=tenant_id, internal=True)
3304
- for opinion in new_opinions:
3305
- await self.retain_async(
3306
- bank_id=bank_id,
3307
- content=opinion.opinion,
3308
- context=f"formed during thinking about: {query}",
3309
- event_date=current_time,
3310
- fact_type_override="opinion",
3311
- confidence_score=opinion.confidence,
3312
- request_context=internal_context,
3313
- )
3314
-
3315
- except Exception as e:
3316
- logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")
3317
-
3318
3901
  async def get_entity_observations(
3319
3902
  self,
3320
3903
  bank_id: str,
@@ -3324,73 +3907,69 @@ Guidelines:
3324
3907
  request_context: "RequestContext",
3325
3908
  ) -> list[Any]:
3326
3909
  """
3327
- Get observations linked to an entity.
3910
+ Get observations for an entity.
3911
+
3912
+ NOTE: Entity observations/summaries have been moved to mental models.
3913
+ This method returns an empty list. Use mental models for entity summaries.
3328
3914
 
3329
3915
  Args:
3330
3916
  bank_id: bank IDentifier
3331
3917
  entity_id: Entity UUID to get observations for
3332
- limit: Maximum number of observations to return
3918
+ limit: Ignored (kept for backwards compatibility)
3333
3919
  request_context: Request context for authentication.
3334
3920
 
3335
3921
  Returns:
3336
- List of EntityObservation objects
3922
+ Empty list (observations now in mental models)
3337
3923
  """
3338
3924
  await self._authenticate_tenant(request_context)
3339
- pool = await self._get_pool()
3340
- async with acquire_with_retry(pool) as conn:
3341
- rows = await conn.fetch(
3342
- f"""
3343
- SELECT mu.text, mu.mentioned_at
3344
- FROM {fq_table("memory_units")} mu
3345
- JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
3346
- WHERE mu.bank_id = $1
3347
- AND mu.fact_type = 'observation'
3348
- AND ue.entity_id = $2
3349
- ORDER BY mu.mentioned_at DESC
3350
- LIMIT $3
3351
- """,
3352
- bank_id,
3353
- uuid.UUID(entity_id),
3354
- limit,
3355
- )
3356
-
3357
- observations = []
3358
- for row in rows:
3359
- mentioned_at = row["mentioned_at"].isoformat() if row["mentioned_at"] else None
3360
- observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
3361
- return observations
3925
+ return []
3362
3926
 
3363
3927
  async def list_entities(
3364
3928
  self,
3365
3929
  bank_id: str,
3366
3930
  *,
3367
3931
  limit: int = 100,
3932
+ offset: int = 0,
3368
3933
  request_context: "RequestContext",
3369
- ) -> list[dict[str, Any]]:
3934
+ ) -> dict[str, Any]:
3370
3935
  """
3371
- List all entities for a bank.
3936
+ List all entities for a bank with pagination.
3372
3937
 
3373
3938
  Args:
3374
3939
  bank_id: bank IDentifier
3375
3940
  limit: Maximum number of entities to return
3941
+ offset: Offset for pagination
3376
3942
  request_context: Request context for authentication.
3377
3943
 
3378
3944
  Returns:
3379
- List of entity dicts with id, canonical_name, mention_count, first_seen, last_seen
3945
+ Dict with items, total, limit, offset
3380
3946
  """
3381
3947
  await self._authenticate_tenant(request_context)
3382
3948
  pool = await self._get_pool()
3383
3949
  async with acquire_with_retry(pool) as conn:
3950
+ # Get total count
3951
+ total_row = await conn.fetchrow(
3952
+ f"""
3953
+ SELECT COUNT(*) as total
3954
+ FROM {fq_table("entities")}
3955
+ WHERE bank_id = $1
3956
+ """,
3957
+ bank_id,
3958
+ )
3959
+ total = total_row["total"] if total_row else 0
3960
+
3961
+ # Get paginated entities
3384
3962
  rows = await conn.fetch(
3385
3963
  f"""
3386
3964
  SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
3387
3965
  FROM {fq_table("entities")}
3388
3966
  WHERE bank_id = $1
3389
- ORDER BY mention_count DESC, last_seen DESC
3390
- LIMIT $2
3967
+ ORDER BY mention_count DESC, last_seen DESC, id ASC
3968
+ LIMIT $2 OFFSET $3
3391
3969
  """,
3392
3970
  bank_id,
3393
3971
  limit,
3972
+ offset,
3394
3973
  )
3395
3974
 
3396
3975
  entities = []
@@ -3417,7 +3996,91 @@ Guidelines:
3417
3996
  "metadata": metadata,
3418
3997
  }
3419
3998
  )
3420
- return entities
3999
+ return {
4000
+ "items": entities,
4001
+ "total": total,
4002
+ "limit": limit,
4003
+ "offset": offset,
4004
+ }
4005
+
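The paginated response above (items/total/limit/offset) makes exhaustive listing a simple offset loop. A minimal sketch, assuming a constructed engine instance `engine`, a default-constructed `RequestContext`, and a made-up helper name:

from hindsight_api.models import RequestContext

async def iter_all_entities(engine, bank_id: str, page_size: int = 100):
    """Yield every entity in a bank by walking the limit/offset pages."""
    ctx = RequestContext()
    offset = 0
    while True:
        page = await engine.list_entities(
            bank_id, limit=page_size, offset=offset, request_context=ctx
        )
        for entity in page["items"]:
            yield entity
        offset += page_size
        if offset >= page["total"]:
            break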
4006
+ async def list_tags(
4007
+ self,
4008
+ bank_id: str,
4009
+ *,
4010
+ pattern: str | None = None,
4011
+ limit: int = 100,
4012
+ offset: int = 0,
4013
+ request_context: "RequestContext",
4014
+ ) -> dict[str, Any]:
4015
+ """
4016
+ List all unique tags for a bank with usage counts.
4017
+
4018
+ Use this to discover available tags or expand wildcard patterns.
4019
+ Supports '*' as wildcard for flexible matching (case-insensitive):
4020
+ - 'user:*' matches user:alice, user:bob
4021
+ - '*-admin' matches role-admin, super-admin
4022
+ - 'env*-prod' matches env-prod, environment-prod
4023
+
4024
+ Args:
4025
+ bank_id: Bank identifier
4026
+ pattern: Wildcard pattern to filter tags (use '*' as wildcard, case-insensitive)
4027
+ limit: Maximum number of tags to return
4028
+ offset: Offset for pagination
4029
+ request_context: Request context for authentication.
4030
+
4031
+ Returns:
4032
+ Dict with items (list of {tag, count}), total, limit, offset
4033
+ """
4034
+ await self._authenticate_tenant(request_context)
4035
+ pool = await self._get_pool()
4036
+ async with acquire_with_retry(pool) as conn:
4037
+ # Build pattern filter if provided (convert * to % for ILIKE)
4038
+ pattern_clause = ""
4039
+ params: list[Any] = [bank_id]
4040
+ if pattern:
4041
+ # Convert wildcard pattern: * -> % for SQL ILIKE
4042
+ sql_pattern = pattern.replace("*", "%")
4043
+ pattern_clause = "AND tag ILIKE $2"
4044
+ params.append(sql_pattern)
4045
+
4046
+ # Get total count of distinct tags matching pattern
4047
+ total_row = await conn.fetchrow(
4048
+ f"""
4049
+ SELECT COUNT(DISTINCT tag) as total
4050
+ FROM {fq_table("memory_units")}, unnest(tags) AS tag
4051
+ WHERE bank_id = $1 AND tags IS NOT NULL AND tags != '{{}}'
4052
+ {pattern_clause}
4053
+ """,
4054
+ *params,
4055
+ )
4056
+ total = total_row["total"] if total_row else 0
4057
+
4058
+ # Get paginated tags with counts, ordered by frequency
4059
+ limit_param = len(params) + 1
4060
+ offset_param = len(params) + 2
4061
+ params.extend([limit, offset])
4062
+
4063
+ rows = await conn.fetch(
4064
+ f"""
4065
+ SELECT tag, COUNT(*) as count
4066
+ FROM {fq_table("memory_units")}, unnest(tags) AS tag
4067
+ WHERE bank_id = $1 AND tags IS NOT NULL AND tags != '{{}}'
4068
+ {pattern_clause}
4069
+ GROUP BY tag
4070
+ ORDER BY count DESC, tag ASC
4071
+ LIMIT ${limit_param} OFFSET ${offset_param}
4072
+ """,
4073
+ *params,
4074
+ )
4075
+
4076
+ items = [{"tag": row["tag"], "count": row["count"]} for row in rows]
4077
+
4078
+ return {
4079
+ "items": items,
4080
+ "total": total,
4081
+ "limit": limit,
4082
+ "offset": offset,
4083
+ }
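The wildcard support above is a plain '*' to '%' substitution fed to ILIKE. A tiny illustrative restatement of that conversion (not a helper the engine actually defines), useful for pre-validating patterns client-side:

def wildcard_to_ilike(pattern: str) -> str:
    """Mirror the conversion used by list_tags: '*' becomes SQL '%'.

    Note: literal '%' or '_' characters are not escaped, so they keep
    their ILIKE wildcard meaning as well.
    """
    return pattern.replace("*", "%")

assert wildcard_to_ilike("user:*") == "user:%"
assert wildcard_to_ilike("env*-prod") == "env%-prod"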
3421
4084
 
3422
4085
  async def get_entity_state(
3423
4086
  self,
@@ -3429,22 +4092,23 @@ Guidelines:
3429
4092
  request_context: "RequestContext",
3430
4093
  ) -> EntityState:
3431
4094
  """
3432
- Get the current state (mental model) of an entity.
4095
+ Get the current state of an entity.
4096
+
4097
+ NOTE: Entity observations/summaries have been moved to mental models.
4098
+ This method returns an entity with empty observations.
3433
4099
 
3434
4100
  Args:
3435
4101
  bank_id: bank IDentifier
3436
4102
  entity_id: Entity UUID
3437
4103
  entity_name: Canonical name of the entity
3438
- limit: Maximum number of observations to include
4104
+ limit: Maximum number of observations to include (kept for backwards compatibility)
3439
4105
  request_context: Request context for authentication.
3440
4106
 
3441
4107
  Returns:
3442
- EntityState with observations
4108
+ EntityState with empty observations (summaries now in mental models)
3443
4109
  """
3444
- observations = await self.get_entity_observations(
3445
- bank_id, entity_id, limit=limit, request_context=request_context
3446
- )
3447
- return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=observations)
4110
+ await self._authenticate_tenant(request_context)
4111
+ return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=[])
3448
4112
 
3449
4113
  async def regenerate_entity_observations(
3450
4114
  self,
@@ -3455,533 +4119,1228 @@ Guidelines:
3455
4119
  version: str | None = None,
3456
4120
  conn=None,
3457
4121
  request_context: "RequestContext",
3458
- ) -> None:
4122
+ ) -> list[str]:
3459
4123
  """
3460
- Regenerate observations for an entity by:
3461
- 1. Checking version for deduplication (if provided)
3462
- 2. Searching all facts mentioning the entity
3463
- 3. Using LLM to synthesize observations (no personality)
3464
- 4. Deleting old observations for this entity
3465
- 5. Storing new observations linked to the entity
4124
+ Regenerate observations for an entity.
4125
+
4126
+ NOTE: Entity observations/summaries have been moved to mental models.
4127
+ This method is now a no-op and returns an empty list.
3466
4128
 
3467
4129
  Args:
3468
4130
  bank_id: bank IDentifier
3469
4131
  entity_id: Entity UUID
3470
4132
  entity_name: Canonical name of the entity
3471
4133
  version: Entity's last_seen timestamp when task was created (for deduplication)
3472
- conn: Optional database connection (for transactional atomicity with caller)
4134
+ conn: Optional database connection (ignored)
3473
4135
  request_context: Request context for authentication.
4136
+
4137
+ Returns:
4138
+ Empty list (observations now in mental models)
3474
4139
  """
3475
4140
  await self._authenticate_tenant(request_context)
3476
- pool = await self._get_pool()
3477
- entity_uuid = uuid.UUID(entity_id)
3478
-
3479
- # Helper to run a query with provided conn or acquire one
3480
- async def fetch_with_conn(query, *args):
3481
- if conn is not None:
3482
- return await conn.fetch(query, *args)
3483
- else:
3484
- async with acquire_with_retry(pool) as acquired_conn:
3485
- return await acquired_conn.fetch(query, *args)
4141
+ return []
3486
4142
 
3487
- async def fetchval_with_conn(query, *args):
3488
- if conn is not None:
3489
- return await conn.fetchval(query, *args)
3490
- else:
3491
- async with acquire_with_retry(pool) as acquired_conn:
3492
- return await acquired_conn.fetchval(query, *args)
4143
+ # =========================================================================
4144
+ # Statistics & Operations (for HTTP API layer)
4145
+ # =========================================================================
4146
+
4147
+ async def get_bank_stats(
4148
+ self,
4149
+ bank_id: str,
4150
+ *,
4151
+ request_context: "RequestContext",
4152
+ ) -> dict[str, Any]:
4153
+ """Get statistics about memory nodes and links for a bank."""
4154
+ await self._authenticate_tenant(request_context)
4155
+ pool = await self._get_pool()
3493
4156
 
3494
- # Step 1: Check version for deduplication
3495
- if version:
3496
- current_last_seen = await fetchval_with_conn(
4157
+ async with acquire_with_retry(pool) as conn:
4158
+ # Get node counts by fact_type
4159
+ node_stats = await conn.fetch(
4160
+ f"""
4161
+ SELECT fact_type, COUNT(*) as count
4162
+ FROM {fq_table("memory_units")}
4163
+ WHERE bank_id = $1
4164
+ GROUP BY fact_type
4165
+ """,
4166
+ bank_id,
4167
+ )
4168
+
4169
+ # Get link counts by link_type
4170
+ link_stats = await conn.fetch(
3497
4171
  f"""
3498
- SELECT last_seen
4172
+ SELECT ml.link_type, COUNT(*) as count
4173
+ FROM {fq_table("memory_links")} ml
4174
+ JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
4175
+ WHERE mu.bank_id = $1
4176
+ GROUP BY ml.link_type
4177
+ """,
4178
+ bank_id,
4179
+ )
4180
+
4181
+ # Get link counts by fact_type (from nodes)
4182
+ link_fact_type_stats = await conn.fetch(
4183
+ f"""
4184
+ SELECT mu.fact_type, COUNT(*) as count
4185
+ FROM {fq_table("memory_links")} ml
4186
+ JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
4187
+ WHERE mu.bank_id = $1
4188
+ GROUP BY mu.fact_type
4189
+ """,
4190
+ bank_id,
4191
+ )
4192
+
4193
+ # Get link counts by fact_type AND link_type
4194
+ link_breakdown_stats = await conn.fetch(
4195
+ f"""
4196
+ SELECT mu.fact_type, ml.link_type, COUNT(*) as count
4197
+ FROM {fq_table("memory_links")} ml
4198
+ JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
4199
+ WHERE mu.bank_id = $1
4200
+ GROUP BY mu.fact_type, ml.link_type
4201
+ """,
4202
+ bank_id,
4203
+ )
4204
+
4205
+ # Get pending and failed operations counts
4206
+ ops_stats = await conn.fetch(
4207
+ f"""
4208
+ SELECT status, COUNT(*) as count
4209
+ FROM {fq_table("async_operations")}
4210
+ WHERE bank_id = $1
4211
+ GROUP BY status
4212
+ """,
4213
+ bank_id,
4214
+ )
4215
+
4216
+ return {
4217
+ "bank_id": bank_id,
4218
+ "node_counts": {row["fact_type"]: row["count"] for row in node_stats},
4219
+ "link_counts": {row["link_type"]: row["count"] for row in link_stats},
4220
+ "link_counts_by_fact_type": {row["fact_type"]: row["count"] for row in link_fact_type_stats},
4221
+ "link_breakdown": [
4222
+ {"fact_type": row["fact_type"], "link_type": row["link_type"], "count": row["count"]}
4223
+ for row in link_breakdown_stats
4224
+ ],
4225
+ "operations": {row["status"]: row["count"] for row in ops_stats},
4226
+ }
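A brief sketch of consuming the stats payload, assuming `engine` and a `RequestContext` `ctx`; the helper name is made up, the keys are the ones returned above:

async def print_bank_summary(engine, bank_id: str, ctx) -> None:
    """Log a one-line summary of a bank's graph size and queue state."""
    stats = await engine.get_bank_stats(bank_id, request_context=ctx)
    total_nodes = sum(stats["node_counts"].values())
    total_links = sum(stats["link_counts"].values())
    pending = stats["operations"].get("pending", 0)
    print(f"{stats['bank_id']}: {total_nodes} nodes, {total_links} links, {pending} pending ops")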
4227
+
4228
+ async def get_entity(
4229
+ self,
4230
+ bank_id: str,
4231
+ entity_id: str,
4232
+ *,
4233
+ request_context: "RequestContext",
4234
+ ) -> dict[str, Any] | None:
4235
+ """Get entity details including metadata and observations."""
4236
+ await self._authenticate_tenant(request_context)
4237
+ pool = await self._get_pool()
4238
+
4239
+ async with acquire_with_retry(pool) as conn:
4240
+ entity_row = await conn.fetchrow(
4241
+ f"""
4242
+ SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
3499
4243
  FROM {fq_table("entities")}
3500
- WHERE id = $1 AND bank_id = $2
4244
+ WHERE bank_id = $1 AND id = $2
3501
4245
  """,
3502
- entity_uuid,
3503
4246
  bank_id,
4247
+ uuid.UUID(entity_id),
3504
4248
  )
3505
4249
 
3506
- if current_last_seen and current_last_seen.isoformat() != version:
3507
- return []
4250
+ if not entity_row:
4251
+ return None
3508
4252
 
3509
- # Step 2: Get all facts mentioning this entity (exclude observations themselves)
3510
- rows = await fetch_with_conn(
4253
+ # Get observations for the entity
4254
+ observations = await self.get_entity_observations(bank_id, entity_id, limit=20, request_context=request_context)
4255
+
4256
+ return {
4257
+ "id": str(entity_row["id"]),
4258
+ "canonical_name": entity_row["canonical_name"],
4259
+ "mention_count": entity_row["mention_count"],
4260
+ "first_seen": entity_row["first_seen"].isoformat() if entity_row["first_seen"] else None,
4261
+ "last_seen": entity_row["last_seen"].isoformat() if entity_row["last_seen"] else None,
4262
+ "metadata": entity_row["metadata"] or {},
4263
+ "observations": observations,
4264
+ }
4265
+
4266
+ def _parse_observations(self, observations_raw: list):
4267
+ """Parse raw observation dicts into typed Observation models.
4268
+
4269
+ Returns list of Observation models with computed trend/evidence_span/evidence_count.
4270
+ """
4271
+ from .reflect.observations import Observation, ObservationEvidence
4272
+
4273
+ observations: list[Observation] = []
4274
+ for obs in observations_raw:
4275
+ if not isinstance(obs, dict):
4276
+ continue
4277
+
4278
+ try:
4279
+ parsed = Observation(
4280
+ title=obs.get("title", ""),
4281
+ content=obs.get("content", ""),
4282
+ evidence=[
4283
+ ObservationEvidence(
4284
+ memory_id=ev.get("memory_id", ""),
4285
+ quote=ev.get("quote", ""),
4286
+ relevance=ev.get("relevance", ""),
4287
+ timestamp=ev.get("timestamp"),
4288
+ )
4289
+ for ev in obs.get("evidence", [])
4290
+ if isinstance(ev, dict)
4291
+ ],
4292
+ created_at=obs.get("created_at"),
4293
+ )
4294
+ observations.append(parsed)
4295
+ except Exception as e:
4296
+ logger.warning(f"Failed to parse observation: {e}")
4297
+ continue
4298
+
4299
+ return observations
4300
+
4301
+ async def _count_memories_since(
4302
+ self,
4303
+ bank_id: str,
4304
+ since_timestamp: str | None,
4305
+ pool=None,
4306
+ ) -> int:
4307
+ """
4308
+ Count memories created after a given timestamp.
4309
+
4310
+ Args:
4311
+ bank_id: Bank identifier
4312
+ since_timestamp: ISO timestamp string. If None, returns total count.
4313
+ pool: Optional database pool (uses default if not provided)
4314
+
4315
+ Returns:
4316
+ Number of memories created since the timestamp
4317
+ """
4318
+ if pool is None:
4319
+ pool = await self._get_pool()
4320
+
4321
+ async with acquire_with_retry(pool) as conn:
4322
+ if since_timestamp:
4323
+ # Parse the timestamp
4324
+ from datetime import datetime
4325
+
4326
+ try:
4327
+ ts = datetime.fromisoformat(since_timestamp.replace("Z", "+00:00"))
4328
+ except ValueError:
4329
+ # Invalid timestamp, return total count
4330
+ ts = None
4331
+
4332
+ if ts:
4333
+ count = await conn.fetchval(
4334
+ f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1 AND created_at > $2",
4335
+ bank_id,
4336
+ ts,
4337
+ )
4338
+ return count or 0
4339
+
4340
+ # No timestamp or invalid, return total count
4341
+ count = await conn.fetchval(
4342
+ f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1",
4343
+ bank_id,
4344
+ )
4345
+ return count or 0
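The replace("Z", "+00:00") shim exists because datetime.fromisoformat only accepts a trailing 'Z' from Python 3.11 onward. A standalone illustration of the same normalization and fallback:

from datetime import datetime, timezone

def parse_iso_utc(ts: str) -> datetime | None:
    """Parse an ISO-8601 timestamp, tolerating a trailing 'Z' suffix."""
    try:
        return datetime.fromisoformat(ts.replace("Z", "+00:00"))
    except ValueError:
        return None  # caller falls back to a total count, as above

assert parse_iso_utc("2024-05-01T12:00:00Z") == datetime(2024, 5, 1, 12, 0, tzinfo=timezone.utc)
assert parse_iso_utc("not-a-date") is None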
4346
+
4347
+ async def _invalidate_facts_from_mental_models(
4348
+ self,
4349
+ conn,
4350
+ bank_id: str,
4351
+ fact_ids: list[str],
4352
+ ) -> int:
4353
+ """
4354
+ Remove fact IDs from observation source_memory_ids when memories are deleted.
4355
+
4356
+ Observations are stored in memory_units with fact_type='observation'
4357
+ and have a source_memory_ids column (UUID[]) tracking their source memories.
4358
+
4359
+ Args:
4360
+ conn: Database connection
4361
+ bank_id: Bank identifier
4362
+ fact_ids: List of fact IDs to remove from observations
4363
+
4364
+ Returns:
4365
+ Number of observations updated
4366
+ """
4367
+ if not fact_ids:
4368
+ return 0
4369
+
4370
+ # Convert string IDs to UUIDs for the array comparison
4371
+ import uuid as uuid_module
4372
+
4373
+ fact_uuids = [uuid_module.UUID(fid) for fid in fact_ids]
4374
+
4375
+ # Update observations (memory_units with fact_type='observation')
4376
+ # by removing the deleted fact IDs from source_memory_ids
4377
+ # Use array subtraction: source_memory_ids - deleted_ids
4378
+ result = await conn.execute(
3511
4379
  f"""
3512
- SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
3513
- FROM {fq_table("memory_units")} mu
3514
- JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
3515
- WHERE mu.bank_id = $1
3516
- AND ue.entity_id = $2
3517
- AND mu.fact_type IN ('world', 'experience')
3518
- ORDER BY mu.occurred_start DESC
3519
- LIMIT 50
4380
+ UPDATE {fq_table("memory_units")}
4381
+ SET source_memory_ids = (
4382
+ SELECT COALESCE(array_agg(elem), ARRAY[]::uuid[])
4383
+ FROM unnest(source_memory_ids) AS elem
4384
+ WHERE elem != ALL($2::uuid[])
4385
+ ),
4386
+ updated_at = NOW()
4387
+ WHERE bank_id = $1
4388
+ AND fact_type = 'observation'
4389
+ AND source_memory_ids && $2::uuid[]
3520
4390
  """,
3521
4391
  bank_id,
3522
- entity_uuid,
4392
+ fact_uuids,
3523
4393
  )
3524
4394
 
3525
- if not rows:
3526
- return []
4395
+ # Parse the result to get number of updated rows
4396
+ updated_count = int(result.split()[-1]) if result and "UPDATE" in result else 0
4397
+ if updated_count > 0:
4398
+ logger.info(
4399
+ f"[OBSERVATIONS] Invalidated {len(fact_ids)} fact IDs from {updated_count} observations in bank {bank_id}"
4400
+ )
4401
+ return updated_count
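The UPDATE above prunes source_memory_ids inside PostgreSQL with unnest/array_agg instead of round-tripping rows through Python. For reference, the intended semantics amount to this pure-Python difference (set semantics; array_agg without ORDER BY does not guarantee element order):

import uuid

def remove_deleted_sources(source_memory_ids: list[uuid.UUID],
                           deleted_ids: list[uuid.UUID]) -> list[uuid.UUID]:
    """Keep only the source IDs that were not deleted."""
    deleted = set(deleted_ids)
    return [sid for sid in source_memory_ids if sid not in deleted]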
3527
4402
 
3528
- # Convert to MemoryFact objects for the observation extraction
3529
- facts = []
3530
- for row in rows:
3531
- occurred_start = row["occurred_start"].isoformat() if row["occurred_start"] else None
3532
- facts.append(
3533
- MemoryFact(
3534
- id=str(row["id"]),
3535
- text=row["text"],
3536
- fact_type=row["fact_type"],
3537
- context=row["context"],
3538
- occurred_start=occurred_start,
3539
- )
4403
+ # =========================================================================
4404
+ # MENTAL MODELS (CONSOLIDATED) - Read-only access to auto-consolidated mental models
4405
+ # =========================================================================
4406
+
4407
+ async def list_mental_models_consolidated(
4408
+ self,
4409
+ bank_id: str,
4410
+ *,
4411
+ tags: list[str] | None = None,
4412
+ tags_match: str = "any",
4413
+ limit: int = 100,
4414
+ offset: int = 0,
4415
+ request_context: "RequestContext",
4416
+ ) -> list[dict[str, Any]]:
4417
+ """List auto-consolidated observations for a bank.
4418
+
4419
+ Observations are stored in memory_units with fact_type='observation'.
4420
+ They are automatically created and updated by the consolidation engine.
4421
+
4422
+ Args:
4423
+ bank_id: Bank identifier
4424
+ tags: Optional tags to filter by
4425
+ tags_match: How to match tags - 'any', 'all', or 'exact'
4426
+ limit: Maximum number of results
4427
+ offset: Offset for pagination
4428
+ request_context: Request context for authentication
4429
+
4430
+ Returns:
4431
+ List of observation dicts
4432
+ """
4433
+ await self._authenticate_tenant(request_context)
4434
+ pool = await self._get_pool()
4435
+
4436
+ async with acquire_with_retry(pool) as conn:
4437
+ # Build tag filter
4438
+ tag_filter = ""
4439
+ params: list[Any] = [bank_id, limit, offset]
4440
+ if tags:
4441
+ if tags_match == "all":
4442
+ tag_filter = " AND tags @> $4::varchar[]"
4443
+ elif tags_match == "exact":
4444
+ tag_filter = " AND tags = $4::varchar[]"
4445
+ else: # any
4446
+ tag_filter = " AND tags && $4::varchar[]"
4447
+ params.append(tags)
4448
+
4449
+ rows = await conn.fetch(
4450
+ f"""
4451
+ SELECT id, bank_id, text, proof_count, history, tags, source_memory_ids, created_at, updated_at
4452
+ FROM {fq_table("memory_units")}
4453
+ WHERE bank_id = $1 AND fact_type = 'observation' {tag_filter}
4454
+ ORDER BY updated_at DESC NULLS LAST
4455
+ LIMIT $2 OFFSET $3
4456
+ """,
4457
+ *params,
3540
4458
  )
3541
4459
 
3542
- # Step 3: Extract observations using LLM (no personality)
3543
- observations = await observation_utils.extract_observations_from_facts(self._llm_config, entity_name, facts)
4460
+ return [self._row_to_observation_consolidated(row) for row in rows]
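The same tags_match dispatch ('any', 'all', 'exact') recurs in list_mental_models and list_directives below; as an aid, the mapping to PostgreSQL array operators can be stated on its own. Illustrative only, not a helper the engine defines:

def tag_filter_operator(tags_match: str) -> str:
    """Map a tags_match mode to the array operator used in the SQL above."""
    return {
        "all": "@>",   # row tags must contain every requested tag
        "exact": "=",  # row tags must equal the requested list
    }.get(tags_match, "&&")  # default 'any': overlap with requested tags

assert tag_filter_operator("any") == "&&"
assert tag_filter_operator("all") == "@>"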
3544
4461
 
3545
- if not observations:
3546
- return []
4462
+ async def get_observation_consolidated(
4463
+ self,
4464
+ bank_id: str,
4465
+ observation_id: str,
4466
+ *,
4467
+ include_source_memories: bool = True,
4468
+ request_context: "RequestContext",
4469
+ ) -> dict[str, Any] | None:
4470
+ """Get a single observation by ID.
4471
+
4472
+ Args:
4473
+ bank_id: Bank identifier
4474
+ observation_id: Observation ID
4475
+ include_source_memories: Whether to include full source memory details
4476
+ request_context: Request context for authentication
4477
+
4478
+ Returns:
4479
+ Observation dict or None if not found
4480
+ """
4481
+ await self._authenticate_tenant(request_context)
4482
+ pool = await self._get_pool()
3547
4483
 
3548
- # Step 4: Delete old observations and insert new ones
3549
- # If conn provided, we're already in a transaction - don't start another
3550
- # If conn is None, acquire one and start a transaction
3551
- async def do_db_operations(db_conn):
3552
- # Delete old observations for this entity
3553
- await db_conn.execute(
4484
+ async with acquire_with_retry(pool) as conn:
4485
+ row = await conn.fetchrow(
3554
4486
  f"""
3555
- DELETE FROM {fq_table("memory_units")}
3556
- WHERE id IN (
3557
- SELECT mu.id
3558
- FROM {fq_table("memory_units")} mu
3559
- JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
3560
- WHERE mu.bank_id = $1
3561
- AND mu.fact_type = 'observation'
3562
- AND ue.entity_id = $2
3563
- )
4487
+ SELECT id, bank_id, text, proof_count, history, tags, source_memory_ids, created_at, updated_at
4488
+ FROM {fq_table("memory_units")}
4489
+ WHERE bank_id = $1 AND id = $2 AND fact_type = 'observation'
3564
4490
  """,
3565
4491
  bank_id,
3566
- entity_uuid,
4492
+ observation_id,
3567
4493
  )
3568
4494
 
3569
- # Generate embeddings for new observations
3570
- embeddings = await embedding_utils.generate_embeddings_batch(self.embeddings, observations)
4495
+ if not row:
4496
+ return None
3571
4497
 
3572
- # Insert new observations
3573
- current_time = utcnow()
3574
- created_ids = []
4498
+ result = self._row_to_observation_consolidated(row)
3575
4499
 
3576
- for obs_text, embedding in zip(observations, embeddings):
3577
- result = await db_conn.fetchrow(
4500
+ # Fetch source memories if requested and source_memory_ids exist
4501
+ if include_source_memories and result.get("source_memory_ids"):
4502
+ source_ids = [uuid.UUID(sid) if isinstance(sid, str) else sid for sid in result["source_memory_ids"]]
4503
+ source_rows = await conn.fetch(
3578
4504
  f"""
3579
- INSERT INTO {fq_table("memory_units")} (
3580
- bank_id, text, embedding, context, event_date,
3581
- occurred_start, occurred_end, mentioned_at,
3582
- fact_type, access_count
3583
- )
3584
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
3585
- RETURNING id
4505
+ SELECT id, text, fact_type, context, occurred_start, mentioned_at
4506
+ FROM {fq_table("memory_units")}
4507
+ WHERE id = ANY($1::uuid[])
4508
+ ORDER BY mentioned_at DESC NULLS LAST
3586
4509
  """,
3587
- bank_id,
3588
- obs_text,
3589
- str(embedding),
3590
- f"observation about {entity_name}",
3591
- current_time,
3592
- current_time,
3593
- current_time,
3594
- current_time,
4510
+ source_ids,
3595
4511
  )
3596
- obs_id = str(result["id"])
3597
- created_ids.append(obs_id)
4512
+ result["source_memories"] = [
4513
+ {
4514
+ "id": str(r["id"]),
4515
+ "text": r["text"],
4516
+ "type": r["fact_type"],
4517
+ "context": r["context"],
4518
+ "occurred_start": r["occurred_start"].isoformat() if r["occurred_start"] else None,
4519
+ "mentioned_at": r["mentioned_at"].isoformat() if r["mentioned_at"] else None,
4520
+ }
4521
+ for r in source_rows
4522
+ ]
3598
4523
 
3599
- # Link observation to entity
3600
- await db_conn.execute(
3601
- f"""
3602
- INSERT INTO {fq_table("unit_entities")} (unit_id, entity_id)
3603
- VALUES ($1, $2)
3604
- """,
3605
- uuid.UUID(obs_id),
3606
- entity_uuid,
3607
- )
4524
+ return result
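A usage sketch for the single-observation read path, assuming `engine` and a `RequestContext` `ctx`; with include_source_memories=True the result carries the source_memories list built above:

async def show_observation(engine, bank_id: str, observation_id: str, ctx) -> None:
    """Fetch one consolidated observation and print its provenance."""
    obs = await engine.get_observation_consolidated(
        bank_id, observation_id, include_source_memories=True, request_context=ctx
    )
    if obs is None:
        print("observation not found")
        return
    print(obs["text"])
    for mem in obs["source_memories"]:
        print(f"  <- {mem['type']}: {mem['text'][:80]}")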
3608
4525
 
3609
- return created_ids
4526
+ def _row_to_observation_consolidated(self, row: Any) -> dict[str, Any]:
4527
+ """Convert a database row to an observation dict."""
4528
+ import json
3610
4529
 
3611
- if conn is not None:
3612
- # Use provided connection (already in a transaction)
3613
- return await do_db_operations(conn)
3614
- else:
3615
- # Acquire connection and start our own transaction
3616
- async with acquire_with_retry(pool) as acquired_conn:
3617
- async with acquired_conn.transaction():
3618
- return await do_db_operations(acquired_conn)
4530
+ history = row["history"]
4531
+ if isinstance(history, str):
4532
+ history = json.loads(history)
4533
+ elif history is None:
4534
+ history = []
3619
4535
 
3620
- async def _regenerate_observations_sync(
4536
+ # Convert source_memory_ids to strings
4537
+ source_memory_ids = row.get("source_memory_ids") or []
4538
+ source_memory_ids = [str(sid) for sid in source_memory_ids]
4539
+
4540
+ return {
4541
+ "id": str(row["id"]),
4542
+ "bank_id": row["bank_id"],
4543
+ "text": row["text"],
4544
+ "proof_count": row["proof_count"] or 1,
4545
+ "history": history,
4546
+ "tags": row["tags"] or [],
4547
+ "source_memory_ids": source_memory_ids,
4548
+ "source_memories": [], # Populated separately when fetching full details
4549
+ "created_at": row["created_at"].isoformat() if row["created_at"] else None,
4550
+ "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
4551
+ }
4552
+
4553
+ # =========================================================================
4554
+ # MENTAL MODELS CRUD
4555
+ # =========================================================================
4556
+
4557
+ async def list_mental_models(
3621
4558
  self,
3622
4559
  bank_id: str,
3623
- entity_ids: list[str],
3624
- min_facts: int | None = None,
3625
- conn=None,
3626
- request_context: "RequestContext | None" = None,
3627
- ) -> None:
3628
- """
3629
- Regenerate observations for entities synchronously (called during retain).
3630
-
3631
- Processes entities in PARALLEL for faster execution.
4560
+ *,
4561
+ tags: list[str] | None = None,
4562
+ tags_match: str = "any",
4563
+ limit: int = 100,
4564
+ offset: int = 0,
4565
+ request_context: "RequestContext",
4566
+ ) -> list[dict[str, Any]]:
4567
+ """List pinned mental models for a bank.
3632
4568
 
3633
4569
  Args:
3634
4570
  bank_id: Bank identifier
3635
- entity_ids: List of entity IDs to process
3636
- min_facts: Minimum facts required to regenerate observations (uses config default if None)
3637
- conn: Optional database connection (for transactional atomicity)
3638
- """
3639
- if not bank_id or not entity_ids:
3640
- return
4571
+ tags: Optional tags to filter by
4572
+ tags_match: How to match tags - 'any', 'all', or 'exact'
4573
+ limit: Maximum number of results
4574
+ offset: Offset for pagination
4575
+ request_context: Request context for authentication
3641
4576
 
3642
- # Use config default if min_facts not specified
3643
- if min_facts is None:
3644
- min_facts = get_config().observation_min_facts
4577
+ Returns:
4578
+ List of pinned mental model dicts
4579
+ """
4580
+ await self._authenticate_tenant(request_context)
4581
+ pool = await self._get_pool()
3645
4582
 
3646
- # Convert to UUIDs
3647
- entity_uuids = [uuid.UUID(eid) if isinstance(eid, str) else eid for eid in entity_ids]
4583
+ async with acquire_with_retry(pool) as conn:
4584
+ # Build tag filter
4585
+ tag_filter = ""
4586
+ params: list[Any] = [bank_id, limit, offset]
4587
+ if tags:
4588
+ if tags_match == "all":
4589
+ tag_filter = " AND tags @> $4::varchar[]"
4590
+ elif tags_match == "exact":
4591
+ tag_filter = " AND tags = $4::varchar[]"
4592
+ else: # any
4593
+ tag_filter = " AND tags && $4::varchar[]"
4594
+ params.append(tags)
3648
4595
 
3649
- # Use provided connection or acquire a new one
3650
- if conn is not None:
3651
- # Use the provided connection (transactional with caller)
3652
- entity_rows = await conn.fetch(
4596
+ rows = await conn.fetch(
3653
4597
  f"""
3654
- SELECT id, canonical_name FROM {fq_table("entities")}
3655
- WHERE id = ANY($1) AND bank_id = $2
4598
+ SELECT id, bank_id, name, source_query, content, tags,
4599
+ last_refreshed_at, created_at, reflect_response,
4600
+ max_tokens, trigger
4601
+ FROM {fq_table("mental_models")}
4602
+ WHERE bank_id = $1 {tag_filter}
4603
+ ORDER BY last_refreshed_at DESC
4604
+ LIMIT $2 OFFSET $3
3656
4605
  """,
3657
- entity_uuids,
3658
- bank_id,
4606
+ *params,
3659
4607
  )
3660
- entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}
3661
4608
 
3662
- fact_counts = await conn.fetch(
4609
+ return [self._row_to_mental_model(row) for row in rows]
4610
+
4611
+ async def get_mental_model(
4612
+ self,
4613
+ bank_id: str,
4614
+ mental_model_id: str,
4615
+ *,
4616
+ request_context: "RequestContext",
4617
+ ) -> dict[str, Any] | None:
4618
+ """Get a single pinned mental model by ID.
4619
+
4620
+ Args:
4621
+ bank_id: Bank identifier
4622
+ mental_model_id: Pinned mental model UUID
4623
+ request_context: Request context for authentication
4624
+
4625
+ Returns:
4626
+ Pinned mental model dict or None if not found
4627
+ """
4628
+ await self._authenticate_tenant(request_context)
4629
+ pool = await self._get_pool()
4630
+
4631
+ async with acquire_with_retry(pool) as conn:
4632
+ row = await conn.fetchrow(
3663
4633
  f"""
3664
- SELECT ue.entity_id, COUNT(*) as cnt
3665
- FROM {fq_table("unit_entities")} ue
3666
- JOIN {fq_table("memory_units")} mu ON ue.unit_id = mu.id
3667
- WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
3668
- GROUP BY ue.entity_id
4634
+ SELECT id, bank_id, name, source_query, content, tags,
4635
+ last_refreshed_at, created_at, reflect_response,
4636
+ max_tokens, trigger
4637
+ FROM {fq_table("mental_models")}
4638
+ WHERE bank_id = $1 AND id = $2
3669
4639
  """,
3670
- entity_uuids,
3671
4640
  bank_id,
4641
+ mental_model_id,
3672
4642
  )
3673
- entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}
3674
- else:
3675
- # Acquire a new connection (standalone call)
3676
- pool = await self._get_pool()
3677
- async with pool.acquire() as acquired_conn:
3678
- entity_rows = await acquired_conn.fetch(
4643
+
4644
+ return self._row_to_mental_model(row) if row else None
4645
+
4646
+ async def create_mental_model(
4647
+ self,
4648
+ bank_id: str,
4649
+ name: str,
4650
+ source_query: str,
4651
+ content: str,
4652
+ *,
4653
+ mental_model_id: str | None = None,
4654
+ tags: list[str] | None = None,
4655
+ max_tokens: int | None = None,
4656
+ trigger: dict[str, Any] | None = None,
4657
+ request_context: "RequestContext",
4658
+ ) -> dict[str, Any]:
4659
+ """Create a new pinned mental model.
4660
+
4661
+ Args:
4662
+ bank_id: Bank identifier
4663
+ name: Human-readable name for the mental model
4664
+ source_query: The query that generated this mental model
4665
+ content: The synthesized content
4666
+ mental_model_id: Optional UUID for the mental model (auto-generated if not provided)
4667
+ tags: Optional tags for scoped visibility
4668
+ max_tokens: Token limit for content generation during refresh
4669
+ trigger: Trigger settings (e.g., refresh_after_consolidation)
4670
+ request_context: Request context for authentication
4671
+
4672
+ Returns:
4673
+ The created pinned mental model dict
4674
+ """
4675
+ await self._authenticate_tenant(request_context)
4676
+ pool = await self._get_pool()
4677
+
4678
+ # Generate embedding for the content
4679
+ embedding_text = f"{name} {content}"
4680
+ embedding = await embedding_utils.generate_embeddings_batch(self.embeddings, [embedding_text])
4681
+ # Convert embedding to string for asyncpg vector type
4682
+ embedding_str = str(embedding[0]) if embedding else None
4683
+
4684
+ async with acquire_with_retry(pool) as conn:
4685
+ if mental_model_id:
4686
+ row = await conn.fetchrow(
3679
4687
  f"""
3680
- SELECT id, canonical_name FROM {fq_table("entities")}
3681
- WHERE id = ANY($1) AND bank_id = $2
4688
+ INSERT INTO {fq_table("mental_models")}
4689
+ (id, bank_id, name, source_query, content, embedding, tags, max_tokens, trigger)
4690
+ VALUES ($1, $2, $3, $4, $5, $6, $7, COALESCE($8, 2048), COALESCE($9, '{{"refresh_after_consolidation": false}}'::jsonb))
4691
+ RETURNING id, bank_id, name, source_query, content, tags,
4692
+ last_refreshed_at, created_at, reflect_response,
4693
+ max_tokens, trigger
3682
4694
  """,
3683
- entity_uuids,
4695
+ mental_model_id,
3684
4696
  bank_id,
4697
+ name,
4698
+ source_query,
4699
+ content,
4700
+ embedding_str,
4701
+ tags or [],
4702
+ max_tokens,
4703
+ json.dumps(trigger) if trigger else None,
3685
4704
  )
3686
- entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}
3687
-
3688
- fact_counts = await acquired_conn.fetch(
4705
+ else:
4706
+ row = await conn.fetchrow(
3689
4707
  f"""
3690
- SELECT ue.entity_id, COUNT(*) as cnt
3691
- FROM {fq_table("unit_entities")} ue
3692
- JOIN {fq_table("memory_units")} mu ON ue.unit_id = mu.id
3693
- WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
3694
- GROUP BY ue.entity_id
4708
+ INSERT INTO {fq_table("mental_models")}
4709
+ (bank_id, name, source_query, content, embedding, tags, max_tokens, trigger)
4710
+ VALUES ($1, $2, $3, $4, $5, $6, COALESCE($7, 2048), COALESCE($8, '{{"refresh_after_consolidation": false}}'::jsonb))
4711
+ RETURNING id, bank_id, name, source_query, content, tags,
4712
+ last_refreshed_at, created_at, reflect_response,
4713
+ max_tokens, trigger
3695
4714
  """,
3696
- entity_uuids,
3697
4715
  bank_id,
4716
+ name,
4717
+ source_query,
4718
+ content,
4719
+ embedding_str,
4720
+ tags or [],
4721
+ max_tokens,
4722
+ json.dumps(trigger) if trigger else None,
3698
4723
  )
3699
- entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}
3700
4724
 
3701
- # Filter entities that meet the threshold
3702
- entities_to_process = []
3703
- for entity_id in entity_ids:
3704
- entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
3705
- if entity_uuid not in entity_names:
3706
- continue
3707
- fact_count = entity_fact_counts.get(entity_uuid, 0)
3708
- if fact_count >= min_facts:
3709
- entities_to_process.append((entity_id, entity_names[entity_uuid]))
4725
+ logger.info(f"[MENTAL_MODELS] Created pinned mental model '{name}' for bank {bank_id}")
4726
+ return self._row_to_mental_model(row)
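A minimal creation sketch, assuming `engine` and a `RequestContext` `ctx`; when max_tokens and trigger are omitted they fall back to the COALESCE defaults above (2048 tokens, refresh_after_consolidation false). The helper name and values are made up:

async def pin_roadmap_model(engine, bank_id: str, ctx) -> dict:
    """Create a pinned mental model that can later be refreshed."""
    return await engine.create_mental_model(
        bank_id,
        name="Q3 roadmap themes",
        source_query="What are the recurring themes in our Q3 planning discussions?",
        content="(initial synthesis goes here)",
        tags=["team:planning"],
        trigger={"refresh_after_consolidation": True},
        request_context=ctx,
    )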
3710
4727
 
3711
- if not entities_to_process:
3712
- return
4728
+ async def refresh_mental_model(
4729
+ self,
4730
+ bank_id: str,
4731
+ mental_model_id: str,
4732
+ *,
4733
+ request_context: "RequestContext",
4734
+ ) -> dict[str, Any] | None:
4735
+ """Refresh a pinned mental model by re-running its source query.
3713
4736
 
3714
- # Use internal context if not provided (for internal/background calls)
3715
- from hindsight_api.models import RequestContext as RC
4737
+ This method:
4738
+ 1. Gets the pinned mental model
4739
+ 2. Runs the source_query through reflect
4740
+ 3. Updates the content with the new synthesis
4741
+ 4. Updates last_refreshed_at
3716
4742
 
3717
- ctx = request_context if request_context is not None else RC()
4743
+ Args:
4744
+ bank_id: Bank identifier
4745
+ mental_model_id: Pinned mental model UUID
4746
+ request_context: Request context for authentication
3718
4747
 
3719
- # Process all entities in PARALLEL (LLM calls are the bottleneck)
3720
- async def process_entity(entity_id: str, entity_name: str):
3721
- try:
3722
- await self.regenerate_entity_observations(
3723
- bank_id, entity_id, entity_name, version=None, conn=conn, request_context=ctx
3724
- )
3725
- except Exception as e:
3726
- logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
4748
+ Returns:
4749
+ Updated pinned mental model dict or None if not found
4750
+ """
4751
+ await self._authenticate_tenant(request_context)
4752
+
4753
+ # Get the current mental model
4754
+ mental_model = await self.get_mental_model(bank_id, mental_model_id, request_context=request_context)
4755
+ if not mental_model:
4756
+ return None
4757
+
4758
+ # Run reflect with the source query, excluding the mental model being refreshed
4759
+ reflect_result = await self.reflect_async(
4760
+ bank_id=bank_id,
4761
+ query=mental_model["source_query"],
4762
+ request_context=request_context,
4763
+ exclude_mental_model_ids=[mental_model_id],
4764
+ )
4765
+
4766
+ # Build reflect_response payload to store
4767
+ reflect_response_payload = {
4768
+ "text": reflect_result.text,
4769
+ "based_on": {
4770
+ fact_type: [
4771
+ {
4772
+ "id": str(fact.id),
4773
+ "text": fact.text,
4774
+ "type": fact_type,
4775
+ }
4776
+ for fact in facts
4777
+ ]
4778
+ for fact_type, facts in reflect_result.based_on.items()
4779
+ },
4780
+ "mental_models": [], # Mental models are included in based_on["mental-models"]
4781
+ }
3727
4782
 
3728
- await asyncio.gather(*[process_entity(eid, name) for eid, name in entities_to_process])
4783
+ # Update the mental model with new content and reflect_response
4784
+ return await self.update_mental_model(
4785
+ bank_id,
4786
+ mental_model_id,
4787
+ content=reflect_result.text,
4788
+ reflect_response=reflect_response_payload,
4789
+ request_context=request_context,
4790
+ )
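A sketch of driving the refresh path, assuming `engine` and `ctx`; the stored reflect_response mirrors the payload assembled above, so the grounding can be inspected after the fact:

async def refresh_and_report(engine, bank_id: str, mental_model_id: str, ctx) -> None:
    """Re-run a pinned model's source query and report how it was grounded."""
    updated = await engine.refresh_mental_model(bank_id, mental_model_id, request_context=ctx)
    if updated is None:
        print("mental model not found")
        return
    based_on = (updated.get("reflect_response") or {}).get("based_on", {})
    counts = {fact_type: len(facts) for fact_type, facts in based_on.items()}
    print(f"refreshed at {updated['last_refreshed_at']}; grounded in {counts}")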
4791
+
4792
+ async def update_mental_model(
4793
+ self,
4794
+ bank_id: str,
4795
+ mental_model_id: str,
4796
+ *,
4797
+ name: str | None = None,
4798
+ content: str | None = None,
4799
+ source_query: str | None = None,
4800
+ max_tokens: int | None = None,
4801
+ tags: list[str] | None = None,
4802
+ trigger: dict[str, Any] | None = None,
4803
+ reflect_response: dict[str, Any] | None = None,
4804
+ request_context: "RequestContext",
4805
+ ) -> dict[str, Any] | None:
4806
+ """Update a pinned mental model.
4807
+
4808
+ Args:
4809
+ bank_id: Bank identifier
4810
+ mental_model_id: Pinned mental model UUID
4811
+ name: New name (if changing)
4812
+ content: New content (if changing)
4813
+ source_query: New source query (if changing)
4814
+ max_tokens: New max tokens (if changing)
4815
+ tags: New tags (if changing)
4816
+ trigger: New trigger settings (if changing)
4817
+ reflect_response: Full reflect API response payload (if changing)
4818
+ request_context: Request context for authentication
3729
4819
 
3730
- async def _handle_regenerate_observations(self, task_dict: dict[str, Any]):
4820
+ Returns:
4821
+ Updated pinned mental model dict or None if not found
3731
4822
  """
3732
- Handler for regenerate_observations tasks.
4823
+ await self._authenticate_tenant(request_context)
4824
+ pool = await self._get_pool()
4825
+
4826
+ async with acquire_with_retry(pool) as conn:
4827
+ # Build dynamic update
4828
+ updates = []
4829
+ params: list[Any] = [bank_id, mental_model_id]
4830
+ param_idx = 3
4831
+
4832
+ if name is not None:
4833
+ updates.append(f"name = ${param_idx}")
4834
+ params.append(name)
4835
+ param_idx += 1
4836
+
4837
+ if content is not None:
4838
+ updates.append(f"content = ${param_idx}")
4839
+ params.append(content)
4840
+ param_idx += 1
4841
+ updates.append("last_refreshed_at = NOW()")
4842
+ # Also update embedding (convert to string for asyncpg vector type)
4843
+ embedding_text = f"{name or ''} {content}"
4844
+ embedding = await embedding_utils.generate_embeddings_batch(self.embeddings, [embedding_text])
4845
+ if embedding:
4846
+ updates.append(f"embedding = ${param_idx}")
4847
+ params.append(str(embedding[0]))
4848
+ param_idx += 1
4849
+
4850
+ if reflect_response is not None:
4851
+ updates.append(f"reflect_response = ${param_idx}")
4852
+ params.append(json.dumps(reflect_response))
4853
+ param_idx += 1
4854
+
4855
+ if source_query is not None:
4856
+ updates.append(f"source_query = ${param_idx}")
4857
+ params.append(source_query)
4858
+ param_idx += 1
4859
+
4860
+ if max_tokens is not None:
4861
+ updates.append(f"max_tokens = ${param_idx}")
4862
+ params.append(max_tokens)
4863
+ param_idx += 1
4864
+
4865
+ if tags is not None:
4866
+ updates.append(f"tags = ${param_idx}")
4867
+ params.append(tags)
4868
+ param_idx += 1
4869
+
4870
+ if trigger is not None:
4871
+ updates.append(f"trigger = ${param_idx}")
4872
+ params.append(json.dumps(trigger))
4873
+ param_idx += 1
4874
+
4875
+ if not updates:
4876
+ return None
4877
+
4878
+ query = f"""
4879
+ UPDATE {fq_table("mental_models")}
4880
+ SET {", ".join(updates)}
4881
+ WHERE bank_id = $1 AND id = $2
4882
+ RETURNING id, bank_id, name, source_query, content, tags,
4883
+ last_refreshed_at, created_at, reflect_response,
4884
+ max_tokens, trigger
4885
+ """
4886
+
4887
+ row = await conn.fetchrow(query, *params)
4888
+
4889
+ return self._row_to_mental_model(row) if row else None
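Because the SET clause is assembled only from the fields that were passed, partial updates leave everything else (including the embedding, unless content changes) untouched. A tags-only update sketch, assuming `engine` and `ctx`:

async def retag_mental_model(engine, bank_id: str, mental_model_id: str, ctx) -> dict | None:
    """Change only the tags; content, embedding and source_query are left as-is."""
    return await engine.update_mental_model(
        bank_id,
        mental_model_id,
        tags=["team:planning", "visibility:org"],
        request_context=ctx,
    )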
4890
+
4891
+ async def delete_mental_model(
4892
+ self,
4893
+ bank_id: str,
4894
+ mental_model_id: str,
4895
+ *,
4896
+ request_context: "RequestContext",
4897
+ ) -> bool:
4898
+ """Delete a pinned mental model.
3733
4899
 
3734
4900
  Args:
3735
- task_dict: Dict with 'bank_id' and either:
3736
- - 'entity_ids' (list): Process multiple entities
3737
- - 'entity_id', 'entity_name': Process single entity (legacy)
4901
+ bank_id: Bank identifier
4902
+ mental_model_id: Pinned mental model UUID
4903
+ request_context: Request context for authentication
3738
4904
 
3739
- Raises:
3740
- ValueError: If required fields are missing
3741
- Exception: Any exception from regenerate_entity_observations (propagates to execute_task for retry)
4905
+ Returns:
4906
+ True if deleted, False if not found
3742
4907
  """
3743
- bank_id = task_dict.get("bank_id")
3744
- # Use internal request context for background tasks
3745
- from hindsight_api.models import RequestContext
4908
+ await self._authenticate_tenant(request_context)
4909
+ pool = await self._get_pool()
3746
4910
 
3747
- internal_context = RequestContext()
4911
+ async with acquire_with_retry(pool) as conn:
4912
+ result = await conn.execute(
4913
+ f"DELETE FROM {fq_table('mental_models')} WHERE bank_id = $1 AND id = $2",
4914
+ bank_id,
4915
+ mental_model_id,
4916
+ )
3748
4917
 
3749
- # New format: multiple entity_ids
3750
- if "entity_ids" in task_dict:
3751
- entity_ids = task_dict.get("entity_ids", [])
3752
- min_facts = task_dict.get("min_facts", 5)
4918
+ return result == "DELETE 1"
3753
4919
 
3754
- if not bank_id or not entity_ids:
3755
- raise ValueError(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
4920
+ def _row_to_mental_model(self, row) -> dict[str, Any]:
4921
+ """Convert a database row to a mental model dict."""
4922
+ reflect_response = row.get("reflect_response")
4923
+ # Parse JSON string to dict if needed (asyncpg may return JSONB as string)
4924
+ if isinstance(reflect_response, str):
4925
+ try:
4926
+ reflect_response = json.loads(reflect_response)
4927
+ except json.JSONDecodeError:
4928
+ reflect_response = None
4929
+ trigger = row.get("trigger")
4930
+ if isinstance(trigger, str):
4931
+ try:
4932
+ trigger = json.loads(trigger)
4933
+ except json.JSONDecodeError:
4934
+ trigger = None
4935
+ return {
4936
+ "id": str(row["id"]),
4937
+ "bank_id": row["bank_id"],
4938
+ "name": row["name"],
4939
+ "source_query": row["source_query"],
4940
+ "content": row["content"],
4941
+ "tags": row["tags"] or [],
4942
+ "max_tokens": row.get("max_tokens"),
4943
+ "trigger": trigger,
4944
+ "last_refreshed_at": row["last_refreshed_at"].isoformat() if row["last_refreshed_at"] else None,
4945
+ "created_at": row["created_at"].isoformat() if row["created_at"] else None,
4946
+ "reflect_response": reflect_response,
4947
+ }
3756
4948
 
3757
- # Process each entity
3758
- pool = await self._get_pool()
3759
- async with pool.acquire() as conn:
3760
- for entity_id in entity_ids:
3761
- try:
3762
- # Fetch entity name and check fact count
3763
- import uuid as uuid_module
4949
+ # =========================================================================
4950
+ # Directives - Hard rules injected into prompts
4951
+ # =========================================================================
3764
4952
 
3765
- entity_uuid = uuid_module.UUID(entity_id) if isinstance(entity_id, str) else entity_id
4953
+ async def list_directives(
4954
+ self,
4955
+ bank_id: str,
4956
+ *,
4957
+ tags: list[str] | None = None,
4958
+ tags_match: str = "any",
4959
+ active_only: bool = True,
4960
+ limit: int = 100,
4961
+ offset: int = 0,
4962
+ request_context: "RequestContext",
4963
+ ) -> list[dict[str, Any]]:
4964
+ """List directives for a bank.
3766
4965
 
3767
- # First check if entity exists
3768
- entity_exists = await conn.fetchrow(
3769
- f"SELECT canonical_name FROM {fq_table('entities')} WHERE id = $1 AND bank_id = $2",
3770
- entity_uuid,
3771
- bank_id,
3772
- )
4966
+ Args:
4967
+ bank_id: Bank identifier
4968
+ tags: Optional tags to filter by
4969
+ tags_match: How to match tags - 'any', 'all', or 'exact'
4970
+ active_only: Only return active directives (default True)
4971
+ limit: Maximum number of results
4972
+ offset: Offset for pagination
4973
+ request_context: Request context for authentication
3773
4974
 
3774
- if not entity_exists:
3775
- logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
3776
- continue
4975
+ Returns:
4976
+ List of directive dicts
4977
+ """
4978
+ await self._authenticate_tenant(request_context)
4979
+ pool = await self._get_pool()
3777
4980
 
3778
- entity_name = entity_exists["canonical_name"]
4981
+ async with acquire_with_retry(pool) as conn:
4982
+ # Build filters
4983
+ filters = ["bank_id = $1"]
4984
+ params: list[Any] = [bank_id]
4985
+ param_idx = 2
4986
+
4987
+ if active_only:
4988
+ filters.append("is_active = TRUE")
4989
+
4990
+ if tags:
4991
+ if tags_match == "all":
4992
+ filters.append(f"tags @> ${param_idx}::varchar[]")
4993
+ elif tags_match == "exact":
4994
+ filters.append(f"tags = ${param_idx}::varchar[]")
4995
+ else: # any
4996
+ filters.append(f"tags && ${param_idx}::varchar[]")
4997
+ params.append(tags)
4998
+ param_idx += 1
4999
+
5000
+ params.extend([limit, offset])
3779
5001
 
3780
- # Count facts linked to this entity
3781
- fact_count = (
3782
- await conn.fetchval(
3783
- f"SELECT COUNT(*) FROM {fq_table('unit_entities')} WHERE entity_id = $1",
3784
- entity_uuid,
3785
- )
3786
- or 0
3787
- )
5002
+ rows = await conn.fetch(
5003
+ f"""
5004
+ SELECT id, bank_id, name, content, priority, is_active, tags, created_at, updated_at
5005
+ FROM {fq_table("directives")}
5006
+ WHERE {" AND ".join(filters)}
5007
+ ORDER BY priority DESC, created_at DESC
5008
+ LIMIT ${param_idx} OFFSET ${param_idx + 1}
5009
+ """,
5010
+ *params,
5011
+ )
3788
5012
 
3789
- # Only regenerate if entity has enough facts
3790
- if fact_count >= min_facts:
3791
- await self.regenerate_entity_observations(
3792
- bank_id, entity_id, entity_name, version=None, request_context=internal_context
3793
- )
3794
- else:
3795
- logger.debug(
3796
- f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
3797
- )
5013
+ return [self._row_to_directive(row) for row in rows]
3798
5014
 
3799
- except Exception as e:
3800
- # Log but continue processing other entities - individual entity failures
3801
- # shouldn't fail the whole batch
3802
- logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
3803
- continue
5015
+ async def get_directive(
5016
+ self,
5017
+ bank_id: str,
5018
+ directive_id: str,
5019
+ *,
5020
+ request_context: "RequestContext",
5021
+ ) -> dict[str, Any] | None:
5022
+ """Get a single directive by ID.
3804
5023
 
3805
- # Legacy format: single entity
3806
- else:
3807
- entity_id = task_dict.get("entity_id")
3808
- entity_name = task_dict.get("entity_name")
3809
- version = task_dict.get("version")
5024
+ Args:
5025
+ bank_id: Bank identifier
5026
+ directive_id: Directive UUID
5027
+ request_context: Request context for authentication
3810
5028
 
3811
- if not all([bank_id, entity_id, entity_name]):
3812
- raise ValueError(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
5029
+ Returns:
5030
+ Directive dict or None if not found
5031
+ """
5032
+ await self._authenticate_tenant(request_context)
5033
+ pool = await self._get_pool()
3813
5034
 
3814
- # Type assertions after validation
3815
- assert isinstance(bank_id, str) and isinstance(entity_id, str) and isinstance(entity_name, str)
3816
- await self.regenerate_entity_observations(
3817
- bank_id, entity_id, entity_name, version=version, request_context=internal_context
5035
+ async with acquire_with_retry(pool) as conn:
5036
+ row = await conn.fetchrow(
5037
+ f"""
5038
+ SELECT id, bank_id, name, content, priority, is_active, tags, created_at, updated_at
5039
+ FROM {fq_table("directives")}
5040
+ WHERE bank_id = $1 AND id = $2
5041
+ """,
5042
+ bank_id,
5043
+ directive_id,
3818
5044
  )
3819
5045
 
3820
- # =========================================================================
3821
- # Statistics & Operations (for HTTP API layer)
3822
- # =========================================================================
5046
+ return self._row_to_directive(row) if row else None
3823
5047
 
3824
- async def get_bank_stats(
5048
+ async def create_directive(
3825
5049
  self,
3826
5050
  bank_id: str,
5051
+ name: str,
5052
+ content: str,
3827
5053
  *,
5054
+ priority: int = 0,
5055
+ is_active: bool = True,
5056
+ tags: list[str] | None = None,
3828
5057
  request_context: "RequestContext",
3829
5058
  ) -> dict[str, Any]:
3830
- """Get statistics about memory nodes and links for a bank."""
5059
+ """Create a new directive.
5060
+
5061
+ Args:
5062
+ bank_id: Bank identifier
5063
+ name: Human-readable name for the directive
5064
+ content: The directive text to inject into prompts
5065
+ priority: Higher priority directives are injected first (default 0)
5066
+ is_active: Whether this directive is active (default True)
5067
+ tags: Optional tags for filtering
5068
+ request_context: Request context for authentication
5069
+
5070
+ Returns:
5071
+ The created directive dict
5072
+ """
3831
5073
  await self._authenticate_tenant(request_context)
3832
5074
  pool = await self._get_pool()
3833
5075
 
3834
5076
  async with acquire_with_retry(pool) as conn:
3835
- # Get node counts by fact_type
3836
- node_stats = await conn.fetch(
5077
+ row = await conn.fetchrow(
3837
5078
  f"""
3838
- SELECT fact_type, COUNT(*) as count
3839
- FROM {fq_table("memory_units")}
3840
- WHERE bank_id = $1
3841
- GROUP BY fact_type
5079
+ INSERT INTO {fq_table("directives")}
5080
+ (bank_id, name, content, priority, is_active, tags)
5081
+ VALUES ($1, $2, $3, $4, $5, $6)
5082
+ RETURNING id, bank_id, name, content, priority, is_active, tags, created_at, updated_at
3842
5083
  """,
3843
5084
  bank_id,
5085
+ name,
5086
+ content,
5087
+ priority,
5088
+ is_active,
5089
+ tags or [],
3844
5090
  )
3845
5091
 
3846
- # Get link counts by link_type
3847
- link_stats = await conn.fetch(
3848
- f"""
3849
- SELECT ml.link_type, COUNT(*) as count
3850
- FROM {fq_table("memory_links")} ml
3851
- JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
3852
- WHERE mu.bank_id = $1
3853
- GROUP BY ml.link_type
3854
- """,
3855
- bank_id,
3856
- )
5092
+ logger.info(f"[DIRECTIVES] Created directive '{name}' for bank {bank_id}")
5093
+ return self._row_to_directive(row)
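A directive-creation sketch, assuming `engine` and `ctx`; higher priority values are injected first, as documented above, and the name/content here are invented:

async def add_tone_directive(engine, bank_id: str, ctx) -> dict:
    """Create an active, high-priority directive scoped by tag."""
    return await engine.create_directive(
        bank_id,
        name="concise-answers",
        content="Answer in at most three sentences unless asked for detail.",
        priority=10,
        tags=["channel:support"],
        request_context=ctx,
    )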
3857
5094
 
3858
- # Get link counts by fact_type (from nodes)
3859
- link_fact_type_stats = await conn.fetch(
3860
- f"""
3861
- SELECT mu.fact_type, COUNT(*) as count
3862
- FROM {fq_table("memory_links")} ml
3863
- JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
3864
- WHERE mu.bank_id = $1
3865
- GROUP BY mu.fact_type
3866
- """,
3867
- bank_id,
3868
- )
5095
+ async def update_directive(
5096
+ self,
5097
+ bank_id: str,
5098
+ directive_id: str,
5099
+ *,
5100
+ name: str | None = None,
5101
+ content: str | None = None,
5102
+ priority: int | None = None,
5103
+ is_active: bool | None = None,
5104
+ tags: list[str] | None = None,
5105
+ request_context: "RequestContext",
5106
+ ) -> dict[str, Any] | None:
5107
+ """Update a directive.
3869
5108
 
3870
- # Get link counts by fact_type AND link_type
3871
- link_breakdown_stats = await conn.fetch(
3872
- f"""
3873
- SELECT mu.fact_type, ml.link_type, COUNT(*) as count
3874
- FROM {fq_table("memory_links")} ml
3875
- JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
3876
- WHERE mu.bank_id = $1
3877
- GROUP BY mu.fact_type, ml.link_type
3878
- """,
3879
- bank_id,
3880
- )
5109
+ Args:
5110
+ bank_id: Bank identifier
5111
+ directive_id: Directive UUID
5112
+ name: New name (optional)
5113
+ content: New content (optional)
5114
+ priority: New priority (optional)
5115
+ is_active: New active status (optional)
5116
+ tags: New tags (optional)
5117
+ request_context: Request context for authentication
3881
5118
 
3882
- # Get pending and failed operations counts
3883
- ops_stats = await conn.fetch(
5119
+ Returns:
5120
+ Updated directive dict or None if not found
5121
+ """
5122
+ await self._authenticate_tenant(request_context)
5123
+ pool = await self._get_pool()
5124
+
5125
+ # Build update query dynamically
5126
+ updates = ["updated_at = now()"]
5127
+ params: list[Any] = []
5128
+ param_idx = 1
5129
+
5130
+ if name is not None:
5131
+ updates.append(f"name = ${param_idx}")
5132
+ params.append(name)
5133
+ param_idx += 1
5134
+
5135
+ if content is not None:
5136
+ updates.append(f"content = ${param_idx}")
5137
+ params.append(content)
5138
+ param_idx += 1
5139
+
5140
+ if priority is not None:
5141
+ updates.append(f"priority = ${param_idx}")
5142
+ params.append(priority)
5143
+ param_idx += 1
5144
+
5145
+ if is_active is not None:
5146
+ updates.append(f"is_active = ${param_idx}")
5147
+ params.append(is_active)
5148
+ param_idx += 1
5149
+
5150
+ if tags is not None:
5151
+ updates.append(f"tags = ${param_idx}")
5152
+ params.append(tags)
5153
+ param_idx += 1
5154
+
5155
+ params.extend([bank_id, directive_id])
5156
+
5157
+ async with acquire_with_retry(pool) as conn:
5158
+ row = await conn.fetchrow(
3884
5159
  f"""
3885
- SELECT status, COUNT(*) as count
3886
- FROM {fq_table("async_operations")}
3887
- WHERE bank_id = $1
3888
- GROUP BY status
5160
+ UPDATE {fq_table("directives")}
5161
+ SET {", ".join(updates)}
5162
+ WHERE bank_id = ${param_idx} AND id = ${param_idx + 1}
5163
+ RETURNING id, bank_id, name, content, priority, is_active, tags, created_at, updated_at
3889
5164
  """,
3890
- bank_id,
5165
+ *params,
3891
5166
  )
3892
5167
 
3893
- return {
3894
- "bank_id": bank_id,
3895
- "node_counts": {row["fact_type"]: row["count"] for row in node_stats},
3896
- "link_counts": {row["link_type"]: row["count"] for row in link_stats},
3897
- "link_counts_by_fact_type": {row["fact_type"]: row["count"] for row in link_fact_type_stats},
3898
- "link_breakdown": [
3899
- {"fact_type": row["fact_type"], "link_type": row["link_type"], "count": row["count"]}
3900
- for row in link_breakdown_stats
3901
- ],
3902
- "operations": {row["status"]: row["count"] for row in ops_stats},
3903
- }
5168
+ return self._row_to_directive(row) if row else None
3904
5169
 
3905
- async def get_entity(
5170
+ async def delete_directive(
3906
5171
  self,
3907
5172
  bank_id: str,
3908
- entity_id: str,
5173
+ directive_id: str,
3909
5174
  *,
3910
5175
  request_context: "RequestContext",
3911
- ) -> dict[str, Any] | None:
3912
- """Get entity details including metadata and observations."""
5176
+ ) -> bool:
5177
+ """Delete a directive.
5178
+
5179
+ Args:
5180
+ bank_id: Bank identifier
5181
+ directive_id: Directive UUID
5182
+ request_context: Request context for authentication
5183
+
5184
+ Returns:
5185
+ True if deleted, False if not found
5186
+ """
3913
5187
  await self._authenticate_tenant(request_context)
3914
5188
  pool = await self._get_pool()
3915
5189
 
3916
5190
  async with acquire_with_retry(pool) as conn:
3917
- entity_row = await conn.fetchrow(
3918
- f"""
3919
- SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
3920
- FROM {fq_table("entities")}
3921
- WHERE bank_id = $1 AND id = $2
3922
- """,
5191
+ result = await conn.execute(
5192
+ f"DELETE FROM {fq_table('directives')} WHERE bank_id = $1 AND id = $2",
3923
5193
  bank_id,
3924
- uuid.UUID(entity_id),
5194
+ directive_id,
3925
5195
  )
3926
5196
 
3927
- if not entity_row:
3928
- return None
3929
-
3930
- # Get observations for the entity
3931
- observations = await self.get_entity_observations(bank_id, entity_id, limit=20, request_context=request_context)
5197
+ return result == "DELETE 1"
3932
5198
 
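Editor's note: asyncpg's `Connection.execute` returns the PostgreSQL command status tag as a string (for example `"DELETE 1"`), which is why the deletion check compares against that exact literal. A hedged sketch of a helper (name is illustrative, not part of the package) that parses the affected-row count instead, which is handier if a statement can touch more than one row:

```python
def rows_affected(command_tag: str) -> int:
    """Extract the row count from a PostgreSQL command status tag.

    Tags look like "DELETE 1", "UPDATE 3", or "INSERT 0 1"; the count is
    always the last whitespace-separated token.
    """
    try:
        return int(command_tag.rsplit(" ", 1)[-1])
    except (ValueError, IndexError):
        return 0


assert rows_affected("DELETE 1") == 1
assert rows_affected("DELETE 0") == 0
assert rows_affected("INSERT 0 1") == 1
```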
5199
+ def _row_to_directive(self, row) -> dict[str, Any]:
5200
+ """Convert a database row to a directive dict."""
3933
5201
  return {
3934
- "id": str(entity_row["id"]),
3935
- "canonical_name": entity_row["canonical_name"],
3936
- "mention_count": entity_row["mention_count"],
3937
- "first_seen": entity_row["first_seen"].isoformat() if entity_row["first_seen"] else None,
3938
- "last_seen": entity_row["last_seen"].isoformat() if entity_row["last_seen"] else None,
3939
- "metadata": entity_row["metadata"] or {},
3940
- "observations": observations,
5202
+ "id": str(row["id"]),
5203
+ "bank_id": row["bank_id"],
5204
+ "name": row["name"],
5205
+ "content": row["content"],
5206
+ "priority": row["priority"],
5207
+ "is_active": row["is_active"],
5208
+ "tags": row["tags"] or [],
5209
+ "created_at": row["created_at"].isoformat() if row["created_at"] else None,
5210
+ "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
3941
5211
  }
3942
5212
 
3943
5213
  async def list_operations(
3944
5214
  self,
3945
5215
  bank_id: str,
3946
5216
  *,
5217
+ status: str | None = None,
5218
+ limit: int = 20,
5219
+ offset: int = 0,
3947
5220
  request_context: "RequestContext",
3948
- ) -> list[dict[str, Any]]:
3949
- """List async operations for a bank."""
5221
+ ) -> dict[str, Any]:
5222
+ """List async operations for a bank with optional filtering and pagination.
5223
+
5224
+ Args:
5225
+ bank_id: Bank identifier
5226
+ status: Optional status filter (pending, completed, failed); "pending" also matches operations still in processing
5227
+ limit: Maximum number of operations to return (default 20)
5228
+ offset: Number of operations to skip (default 0)
5229
+ request_context: Request context for authentication
5230
+
5231
+ Returns:
5232
+ Dict with total count and list of operations, sorted by most recent first
5233
+ """
3950
5234
  await self._authenticate_tenant(request_context)
3951
5235
  pool = await self._get_pool()
3952
5236
 
3953
5237
  async with acquire_with_retry(pool) as conn:
5238
+ # Build WHERE clause
5239
+ where_conditions = ["bank_id = $1"]
5240
+ params: list[Any] = [bank_id]
5241
+
5242
+ if status:
5243
+ # Map API status to DB statuses (pending includes processing)
5244
+ if status == "pending":
5245
+ where_conditions.append("status IN ('pending', 'processing')")
5246
+ else:
5247
+ where_conditions.append(f"status = ${len(params) + 1}")
5248
+ params.append(status)
5249
+
5250
+ where_clause = " AND ".join(where_conditions)
5251
+
5252
+ # Get total count (with filter)
5253
+ total_row = await conn.fetchrow(
5254
+ f"SELECT COUNT(*) as total FROM {fq_table('async_operations')} WHERE {where_clause}",
5255
+ *params,
5256
+ )
5257
+ total = total_row["total"] if total_row else 0
5258
+
5259
+ # Get operations with pagination
3954
5260
  operations = await conn.fetch(
3955
5261
  f"""
3956
- SELECT operation_id, bank_id, operation_type, created_at, status, error_message, result_metadata
5262
+ SELECT operation_id, operation_type, created_at, status, error_message
3957
5263
  FROM {fq_table("async_operations")}
3958
- WHERE bank_id = $1
5264
+ WHERE {where_clause}
3959
5265
  ORDER BY created_at DESC
5266
+ LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}
3960
5267
  """,
3961
- bank_id,
5268
+ *params,
5269
+ limit,
5270
+ offset,
3962
5271
  )
3963
5272
 
3964
- def parse_metadata(metadata):
3965
- if metadata is None:
3966
- return {}
3967
- if isinstance(metadata, str):
3968
- import json
5273
+ return {
5274
+ "total": total,
5275
+ "operations": [
5276
+ {
5277
+ "id": str(row["operation_id"]),
5278
+ "task_type": row["operation_type"],
5279
+ "items_count": 0,
5280
+ "document_id": None,
5281
+ "created_at": row["created_at"].isoformat(),
5282
+ # Map DB status to API status (processing -> pending for simplicity)
5283
+ "status": "pending" if row["status"] in ("pending", "processing") else row["status"],
5284
+ "error_message": row["error_message"],
5285
+ }
5286
+ for row in operations
5287
+ ],
5288
+ }
3969
5289
 
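Editor's note: because the new `list_operations` returns both `total` and one page of `operations`, a caller can walk the full history with `limit`/`offset`. A usage sketch, not part of the package; the `engine` and `ctx` objects are assumed to be a memory engine instance and a valid RequestContext:

```python
async def collect_pending_operations(engine, bank_id: str, ctx) -> list[dict]:
    """Page through list_operations until every pending operation has been fetched."""
    collected: list[dict] = []
    offset = 0
    page_size = 50
    while True:
        page = await engine.list_operations(
            bank_id,
            status="pending",   # also matches rows still in 'processing'
            limit=page_size,
            offset=offset,
            request_context=ctx,
        )
        collected.extend(page["operations"])
        offset += page_size
        if offset >= page["total"] or not page["operations"]:
            break
    return collected
```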
3970
- return json.loads(metadata)
3971
- return metadata
5290
+ async def get_operation_status(
5291
+ self,
5292
+ bank_id: str,
5293
+ operation_id: str,
5294
+ *,
5295
+ request_context: "RequestContext",
5296
+ ) -> dict[str, Any]:
5297
+ """Get the status of a specific async operation.
3972
5298
 
3973
- return [
3974
- {
3975
- "id": str(row["operation_id"]),
3976
- "task_type": row["operation_type"],
3977
- "items_count": parse_metadata(row["result_metadata"]).get("items_count", 0),
3978
- "document_id": parse_metadata(row["result_metadata"]).get("document_id"),
3979
- "created_at": row["created_at"].isoformat(),
3980
- "status": row["status"],
5299
+ Returns:
5300
+ - status: "pending", "completed", "failed", or "not_found" when the operation does not exist
5301
+ - updated_at: last update timestamp
5302
+ - completed_at: completion timestamp (if completed)
5303
+ """
5304
+ await self._authenticate_tenant(request_context)
5305
+ pool = await self._get_pool()
5306
+
5307
+ op_uuid = uuid.UUID(operation_id)
5308
+
5309
+ async with acquire_with_retry(pool) as conn:
5310
+ row = await conn.fetchrow(
5311
+ f"""
5312
+ SELECT operation_id, operation_type, created_at, updated_at, completed_at, status, error_message
5313
+ FROM {fq_table("async_operations")}
5314
+ WHERE operation_id = $1 AND bank_id = $2
5315
+ """,
5316
+ op_uuid,
5317
+ bank_id,
5318
+ )
5319
+
5320
+ if row:
5321
+ # Map DB status to API status (processing -> pending for simplicity)
5322
+ db_status = row["status"]
5323
+ api_status = "pending" if db_status in ("pending", "processing") else db_status
5324
+ return {
5325
+ "operation_id": operation_id,
5326
+ "status": api_status,
5327
+ "operation_type": row["operation_type"],
5328
+ "created_at": row["created_at"].isoformat() if row["created_at"] else None,
5329
+ "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
5330
+ "completed_at": row["completed_at"].isoformat() if row["completed_at"] else None,
3981
5331
  "error_message": row["error_message"],
3982
5332
  }
3983
- for row in operations
3984
- ]
5333
+ else:
5334
+ # Operation not found
5335
+ return {
5336
+ "operation_id": operation_id,
5337
+ "status": "not_found",
5338
+ "operation_type": None,
5339
+ "created_at": None,
5340
+ "updated_at": None,
5341
+ "completed_at": None,
5342
+ "error_message": None,
5343
+ }
3985
5344
 
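Editor's note: since `get_operation_status` collapses `processing` into `pending` and reports `not_found` for unknown ids, a polling caller only has to wait until the status leaves `pending`. A hedged usage sketch (again assuming `engine` and `ctx`; interval and timeout are arbitrary):

```python
import asyncio


async def wait_for_operation(engine, bank_id: str, operation_id: str, ctx,
                             poll_interval: float = 2.0, timeout: float = 300.0) -> dict:
    """Poll get_operation_status until the operation finishes or the timeout expires."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while True:
        info = await engine.get_operation_status(bank_id, operation_id, request_context=ctx)
        if info["status"] in ("completed", "failed", "not_found"):
            return info
        if loop.time() > deadline:
            raise TimeoutError(f"operation {operation_id} still pending after {timeout}s")
        await asyncio.sleep(poll_interval)
```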
3986
5345
  async def cancel_operation(
3987
5346
  self,
@@ -4022,10 +5381,10 @@ Guidelines:
4022
5381
  bank_id: str,
4023
5382
  *,
4024
5383
  name: str | None = None,
4025
- background: str | None = None,
5384
+ mission: str | None = None,
4026
5385
  request_context: "RequestContext",
4027
5386
  ) -> dict[str, Any]:
4028
- """Update bank name and/or background."""
5387
+ """Update bank name and/or mission."""
4029
5388
  await self._authenticate_tenant(request_context)
4030
5389
  pool = await self._get_pool()
4031
5390
 
@@ -4041,33 +5400,72 @@ Guidelines:
4041
5400
  name,
4042
5401
  )
4043
5402
 
4044
- if background is not None:
5403
+ if mission is not None:
4045
5404
  await conn.execute(
4046
5405
  f"""
4047
5406
  UPDATE {fq_table("banks")}
4048
- SET background = $2, updated_at = NOW()
5407
+ SET mission = $2, updated_at = NOW()
4049
5408
  WHERE bank_id = $1
4050
5409
  """,
4051
5410
  bank_id,
4052
- background,
5411
+ mission,
4053
5412
  )
4054
5413
 
4055
5414
  # Return updated profile
4056
5415
  return await self.get_bank_profile(bank_id, request_context=request_context)
4057
5416
 
4058
- async def submit_async_retain(
5417
+ async def _submit_async_operation(
4059
5418
  self,
4060
5419
  bank_id: str,
4061
- contents: list[dict[str, Any]],
5420
+ operation_type: str,
5421
+ task_type: str,
5422
+ task_payload: dict[str, Any],
4062
5423
  *,
4063
- request_context: "RequestContext",
5424
+ result_metadata: dict[str, Any] | None = None,
5425
+ dedupe_by_bank: bool = False,
4064
5426
  ) -> dict[str, Any]:
4065
- """Submit a batch retain operation to run asynchronously."""
4066
- await self._authenticate_tenant(request_context)
4067
- pool = await self._get_pool()
5427
+ """Generic helper to submit an async operation.
5428
+
5429
+ Args:
5430
+ bank_id: Bank identifier
5431
+ operation_type: Operation type for the async_operations record (e.g., 'consolidation', 'retain')
5432
+ task_type: Task type for the task payload (e.g., 'consolidation', 'batch_retain')
5433
+ task_payload: Additional task payload fields (operation_id and bank_id are added automatically)
5434
+ result_metadata: Optional metadata to store with the operation record
5435
+ dedupe_by_bank: If True, skip creating a new task if one is already pending for this bank+operation_type
4068
5436
 
5437
+ Returns:
5438
+ Dict with operation_id and optionally deduplicated=True if an existing task was found
5439
+ """
4069
5440
  import json
4070
5441
 
5442
+ pool = await self._get_pool()
5443
+
5444
+ # Check for existing pending task if deduplication is enabled
5445
+ # Note: We only check 'pending', not 'processing', because a processing task
5446
+ # uses a watermark from when it started - new memories added after that point
5447
+ # would need another consolidation run to be processed.
5448
+ if dedupe_by_bank:
5449
+ async with acquire_with_retry(pool) as conn:
5450
+ existing = await conn.fetchrow(
5451
+ f"""
5452
+ SELECT operation_id FROM {fq_table("async_operations")}
5453
+ WHERE bank_id = $1 AND operation_type = $2 AND status = 'pending'
5454
+ LIMIT 1
5455
+ """,
5456
+ bank_id,
5457
+ operation_type,
5458
+ )
5459
+ if existing:
5460
+ logger.debug(
5461
+ f"{operation_type} task already pending for bank_id={bank_id}, "
5462
+ f"skipping duplicate (existing operation_id={existing['operation_id']})"
5463
+ )
5464
+ return {
5465
+ "operation_id": str(existing["operation_id"]),
5466
+ "deduplicated": True,
5467
+ }
5468
+
4071
5469
  operation_id = uuid.uuid4()
4072
5470
 
4073
5471
  # Insert operation record into database
@@ -4079,23 +5477,113 @@ Guidelines:
4079
5477
  """,
4080
5478
  operation_id,
4081
5479
  bank_id,
4082
- "retain",
4083
- json.dumps({"items_count": len(contents)}),
5480
+ operation_type,
5481
+ json.dumps(result_metadata or {}),
4084
5482
  )
4085
5483
 
4086
- # Submit task to background queue
4087
- await self._task_backend.submit_task(
4088
- {
4089
- "type": "batch_retain",
4090
- "operation_id": str(operation_id),
4091
- "bank_id": bank_id,
4092
- "contents": contents,
4093
- }
4094
- )
5484
+ # Build and submit task payload
5485
+ full_payload = {
5486
+ "type": task_type,
5487
+ "operation_id": str(operation_id),
5488
+ "bank_id": bank_id,
5489
+ **task_payload,
5490
+ }
5491
+
5492
+ await self._task_backend.submit_task(full_payload)
4095
5493
 
4096
- logger.info(f"Retain task queued for bank_id={bank_id}, {len(contents)} items, operation_id={operation_id}")
5494
+ logger.info(f"{operation_type} task queued for bank_id={bank_id}, operation_id={operation_id}")
4097
5495
 
4098
5496
  return {
4099
5497
  "operation_id": str(operation_id),
4100
- "items_count": len(contents),
4101
5498
  }
5499
+
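Editor's note: `_submit_async_operation` injects `type`, `operation_id`, and `bank_id`, then spreads the caller's `task_payload` on top, so the worker sees one flat dict. A small sketch of the merged payload for the refresh case defined further down (the model id value is illustrative):

```python
import json
import uuid

# Assembled the same way as full_payload above: fixed fields first,
# then the caller-supplied task_payload merged in.
operation_id = uuid.uuid4()
task_payload = {"mental_model_id": "mm-42"}  # illustrative id

full_payload = {
    "type": "refresh_mental_model",
    "operation_id": str(operation_id),
    "bank_id": "bank-1",
    **task_payload,
}
print(json.dumps(full_payload, indent=2))
```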
5500
+ async def submit_async_retain(
5501
+ self,
5502
+ bank_id: str,
5503
+ contents: list[dict[str, Any]],
5504
+ *,
5505
+ request_context: "RequestContext",
5506
+ document_tags: list[str] | None = None,
5507
+ ) -> dict[str, Any]:
5508
+ """Submit a batch retain operation to run asynchronously."""
5509
+ await self._authenticate_tenant(request_context)
5510
+
5511
+ task_payload: dict[str, Any] = {"contents": contents}
5512
+ if document_tags:
5513
+ task_payload["document_tags"] = document_tags
5514
+
5515
+ result = await self._submit_async_operation(
5516
+ bank_id=bank_id,
5517
+ operation_type="retain",
5518
+ task_type="batch_retain",
5519
+ task_payload=task_payload,
5520
+ result_metadata={"items_count": len(contents)},
5521
+ dedupe_by_bank=False,
5522
+ )
5523
+
5524
+ result["items_count"] = len(contents)
5525
+ return result
5526
+
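Editor's note: `submit_async_retain` now accepts optional `document_tags` and echoes `items_count` back in its result. A usage sketch (engine/ctx assumed; the per-item dict shape is illustrative only, consult the retain API for the actual fields):

```python
async def retain_meeting_notes(engine, bank_id: str, notes: list[str], ctx) -> str:
    """Queue a batch retain for a list of note strings and return the operation id."""
    contents = [{"content": note} for note in notes]  # illustrative item shape
    result = await engine.submit_async_retain(
        bank_id,
        contents,
        request_context=ctx,
        document_tags=["meeting-notes"],
    )
    assert result["items_count"] == len(contents)
    return result["operation_id"]
```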
5527
+ async def submit_async_consolidation(
5528
+ self,
5529
+ bank_id: str,
5530
+ *,
5531
+ request_context: "RequestContext",
5532
+ ) -> dict[str, Any]:
5533
+ """Submit a consolidation operation to run asynchronously.
5534
+
5535
+ Deduplicates by bank_id - if there's already a pending consolidation for this bank,
5536
+ returns the existing operation_id instead of creating a new one.
5537
+
5538
+ Args:
5539
+ bank_id: Bank identifier
5540
+ request_context: Request context for authentication
5541
+
5542
+ Returns:
5543
+ Dict with operation_id
5544
+ """
5545
+ await self._authenticate_tenant(request_context)
5546
+ return await self._submit_async_operation(
5547
+ bank_id=bank_id,
5548
+ operation_type="consolidation",
5549
+ task_type="consolidation",
5550
+ task_payload={},
5551
+ dedupe_by_bank=True,
5552
+ )
5553
+
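Editor's note: because consolidation submits with `dedupe_by_bank=True`, the returned dict may carry `deduplicated: True` alongside the operation_id of the already-pending task. A caller-side sketch (engine/ctx assumed):

```python
async def trigger_consolidation(engine, bank_id: str, ctx) -> str:
    """Kick off consolidation, noting whether an existing pending task was reused."""
    result = await engine.submit_async_consolidation(bank_id, request_context=ctx)
    if result.get("deduplicated"):
        print(f"consolidation already pending, reusing operation {result['operation_id']}")
    else:
        print(f"queued new consolidation operation {result['operation_id']}")
    return result["operation_id"]
```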
5554
+ async def submit_async_refresh_mental_model(
5555
+ self,
5556
+ bank_id: str,
5557
+ mental_model_id: str,
5558
+ *,
5559
+ request_context: "RequestContext",
5560
+ ) -> dict[str, Any]:
5561
+ """Submit an async mental model refresh operation.
5562
+
5563
+ This schedules a background task to re-run the source query and update the content.
5564
+
5565
+ Args:
5566
+ bank_id: Bank identifier
5567
+ mental_model_id: Mental model UUID to refresh
5568
+ request_context: Request context for authentication
5569
+
5570
+ Returns:
5571
+ Dict with operation_id
5572
+ """
5573
+ await self._authenticate_tenant(request_context)
5574
+
5575
+ # Verify mental model exists
5576
+ mental_model = await self.get_mental_model(bank_id, mental_model_id, request_context=request_context)
5577
+ if not mental_model:
5578
+ raise ValueError(f"Mental model {mental_model_id} not found in bank {bank_id}")
5579
+
5580
+ return await self._submit_async_operation(
5581
+ bank_id=bank_id,
5582
+ operation_type="refresh_mental_model",
5583
+ task_type="refresh_mental_model",
5584
+ task_payload={
5585
+ "mental_model_id": mental_model_id,
5586
+ },
5587
+ result_metadata={"mental_model_id": mental_model_id, "name": mental_model["name"]},
5588
+ dedupe_by_bank=False,
5589
+ )
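Editor's note: `submit_async_refresh_mental_model` validates the target before queueing and raises `ValueError` when the model is missing, so callers handling user-supplied ids should expect that case. A hedged sketch (engine/ctx assumed):

```python
async def refresh_or_report(engine, bank_id: str, mental_model_id: str, ctx) -> str | None:
    """Queue a mental-model refresh, returning the operation id or None if the model is missing."""
    try:
        result = await engine.submit_async_refresh_mental_model(
            bank_id, mental_model_id, request_context=ctx
        )
    except ValueError:
        # Raised when the mental model is not found in this bank.
        return None
    return result["operation_id"]
```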