hindsight-api 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. hindsight_api/__init__.py +2 -0
  2. hindsight_api/alembic/env.py +24 -1
  3. hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +14 -4
  4. hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +54 -13
  5. hindsight_api/alembic/versions/rename_personality_to_disposition.py +18 -7
  6. hindsight_api/api/http.py +253 -230
  7. hindsight_api/api/mcp.py +14 -3
  8. hindsight_api/config.py +11 -0
  9. hindsight_api/daemon.py +204 -0
  10. hindsight_api/engine/__init__.py +12 -1
  11. hindsight_api/engine/entity_resolver.py +38 -37
  12. hindsight_api/engine/interface.py +592 -0
  13. hindsight_api/engine/llm_wrapper.py +176 -6
  14. hindsight_api/engine/memory_engine.py +1092 -293
  15. hindsight_api/engine/retain/bank_utils.py +13 -12
  16. hindsight_api/engine/retain/chunk_storage.py +3 -2
  17. hindsight_api/engine/retain/fact_storage.py +10 -7
  18. hindsight_api/engine/retain/link_utils.py +17 -16
  19. hindsight_api/engine/retain/observation_regeneration.py +17 -16
  20. hindsight_api/engine/retain/orchestrator.py +2 -3
  21. hindsight_api/engine/retain/types.py +25 -8
  22. hindsight_api/engine/search/graph_retrieval.py +6 -5
  23. hindsight_api/engine/search/mpfp_retrieval.py +8 -7
  24. hindsight_api/engine/search/reranking.py +17 -0
  25. hindsight_api/engine/search/retrieval.py +12 -11
  26. hindsight_api/engine/search/think_utils.py +1 -1
  27. hindsight_api/engine/search/tracer.py +1 -1
  28. hindsight_api/engine/task_backend.py +32 -0
  29. hindsight_api/extensions/__init__.py +66 -0
  30. hindsight_api/extensions/base.py +81 -0
  31. hindsight_api/extensions/builtin/__init__.py +18 -0
  32. hindsight_api/extensions/builtin/tenant.py +33 -0
  33. hindsight_api/extensions/context.py +110 -0
  34. hindsight_api/extensions/http.py +89 -0
  35. hindsight_api/extensions/loader.py +125 -0
  36. hindsight_api/extensions/operation_validator.py +325 -0
  37. hindsight_api/extensions/tenant.py +63 -0
  38. hindsight_api/main.py +97 -17
  39. hindsight_api/mcp_local.py +7 -1
  40. hindsight_api/migrations.py +54 -10
  41. hindsight_api/models.py +15 -0
  42. hindsight_api/pg0.py +1 -1
  43. {hindsight_api-0.1.11.dist-info → hindsight_api-0.1.13.dist-info}/METADATA +1 -1
  44. hindsight_api-0.1.13.dist-info/RECORD +75 -0
  45. hindsight_api-0.1.11.dist-info/RECORD +0 -64
  46. {hindsight_api-0.1.11.dist-info → hindsight_api-0.1.13.dist-info}/WHEEL +0 -0
  47. {hindsight_api-0.1.11.dist-info → hindsight_api-0.1.13.dist-info}/entry_points.txt +0 -0
@@ -10,39 +10,122 @@ This implements a sophisticated memory architecture that combines:
  """
 
  import asyncio
+ import contextvars
  import logging
  import time
  import uuid
  from datetime import UTC, datetime, timedelta
- from typing import TYPE_CHECKING, Any, TypedDict
+ from typing import TYPE_CHECKING, Any
 
- import asyncpg
- import numpy as np
- from pydantic import BaseModel, Field
+ # Context variable for current schema (async-safe, per-task isolation)
+ _current_schema: contextvars.ContextVar[str] = contextvars.ContextVar("current_schema", default="public")
 
- from .cross_encoder import CrossEncoderModel
- from .embeddings import Embeddings, create_embeddings_from_env
 
- if TYPE_CHECKING:
+ def get_current_schema() -> str:
+ """Get the current schema from context (default: 'public')."""
+ return _current_schema.get()
+
+
+ def fq_table(table_name: str) -> str:
+ """
+ Get fully-qualified table name with current schema.
+
+ Example:
+ fq_table("memory_units") -> "public.memory_units"
+ fq_table("memory_units") -> "tenant_xyz.memory_units" (if schema is set)
+ """
+ return f"{get_current_schema()}.{table_name}"
+
+
+ # Tables that must be schema-qualified (for runtime validation)
+ _PROTECTED_TABLES = frozenset(
+ [
+ "memory_units",
+ "memory_links",
+ "unit_entities",
+ "entities",
+ "entity_cooccurrences",
+ "banks",
+ "documents",
+ "chunks",
+ "async_operations",
+ ]
+ )
+
+ # Enable runtime SQL validation (can be disabled in production for performance)
+ _VALIDATE_SQL_SCHEMAS = True
+
+
+ class UnqualifiedTableError(Exception):
+ """Raised when SQL contains unqualified table references."""
+
  pass
 
 
- class RetainContentDict(TypedDict, total=False):
- """Type definition for content items in retain_batch_async.
+ def validate_sql_schema(sql: str) -> None:
+ """
+ Validate that SQL doesn't contain unqualified table references.
 
- Fields:
- content: Text content to store (required)
- context: Context about the content (optional)
- event_date: When the content occurred (optional, defaults to now)
- metadata: Custom key-value metadata (optional)
- document_id: Document ID for this content item (optional)
+ This is a runtime safety check to prevent cross-tenant data access.
+ Raises UnqualifiedTableError if any protected table is referenced
+ without a schema prefix.
+
+ Args:
+ sql: The SQL query to validate
+
+ Raises:
+ UnqualifiedTableError: If unqualified table reference found
  """
+ if not _VALIDATE_SQL_SCHEMAS:
+ return
+
+ import re
+
+ sql_upper = sql.upper()
+
+ for table in _PROTECTED_TABLES:
+ table_upper = table.upper()
+
+ # Pattern: SQL keyword followed by unqualified table name
+ # Matches: FROM memory_units, JOIN memory_units, INTO memory_units, UPDATE memory_units
+ patterns = [
+ rf"FROM\s+{table_upper}(?:\s|$|,|\)|;)",
+ rf"JOIN\s+{table_upper}(?:\s|$|,|\)|;)",
+ rf"INTO\s+{table_upper}(?:\s|$|\()",
+ rf"UPDATE\s+{table_upper}(?:\s|$)",
+ rf"DELETE\s+FROM\s+{table_upper}(?:\s|$|;)",
+ ]
+
+ for pattern in patterns:
+ match = re.search(pattern, sql_upper)
+ if match:
+ # Check if it's actually qualified (preceded by schema.)
+ # Look backwards from match to see if there's a dot
+ start = match.start()
+ # Find the table name position in the match
+ table_pos = sql_upper.find(table_upper, start)
+ if table_pos > 0:
+ # Check character before table name (skip whitespace)
+ prefix = sql[:table_pos].rstrip()
+ if not prefix.endswith("."):
+ raise UnqualifiedTableError(
+ f"Unqualified table reference '{table}' in SQL. "
+ f"Use fq_table('{table}') for schema safety. "
+ f"SQL snippet: ...{sql[max(0, start - 10) : start + 50]}..."
+ )
 
- content: str # Required
- context: str
- event_date: datetime
- metadata: dict[str, str]
- document_id: str
+
+ import asyncpg
+ import numpy as np
+ from pydantic import BaseModel, Field
+
+ from .cross_encoder import CrossEncoderModel
+ from .embeddings import Embeddings, create_embeddings_from_env
+ from .interface import MemoryEngineInterface
+
+ if TYPE_CHECKING:
+ from hindsight_api.extensions import OperationValidatorExtension, TenantExtension
+ from hindsight_api.models import RequestContext
 
 
  from enum import Enum
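
The schema-qualification helpers above are the heart of this release's multi-tenancy change, so a short usage sketch may help. It assumes only the definitions added in this hunk (`_current_schema`, `fq_table`, `validate_sql_schema`); the `tenant_xyz` schema name is invented for illustration, and in real code the contextvar is set by `_authenticate_tenant()` rather than directly:

```python
# Sketch, assuming the module-level helpers added above are in scope.
_current_schema.set("tenant_xyz")  # normally done by _authenticate_tenant()

sql = f"SELECT id FROM {fq_table('memory_units')} WHERE bank_id = $1"
# -> "SELECT id FROM tenant_xyz.memory_units WHERE bank_id = $1"
validate_sql_schema(sql)  # passes: the protected table is schema-qualified

validate_sql_schema("SELECT id FROM memory_units")  # raises UnqualifiedTableError
```

Because contextvars are per-task, two concurrent asyncio tasks serving different tenants each see their own schema, which is what makes the `fq_table()` indirection safe under concurrency.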
@@ -54,6 +137,7 @@ from .query_analyzer import QueryAnalyzer
  from .response_models import VALID_RECALL_FACT_TYPES, EntityObservation, EntityState, MemoryFact, ReflectResult
  from .response_models import RecallResult as RecallResultModel
  from .retain import bank_utils, embedding_utils
+ from .retain.types import RetainContentDict
  from .search import observation_utils, think_utils
  from .search.reranking import CrossEncoderReranker
  from .task_backend import AsyncIOQueueBackend, TaskBackend
@@ -91,7 +175,7 @@ def _get_tiktoken_encoding():
  return _TIKTOKEN_ENCODING
 
 
- class MemoryEngine:
+ class MemoryEngine(MemoryEngineInterface):
  """
  Advanced memory system using temporal and semantic linking with PostgreSQL.
 
@@ -116,6 +200,10 @@ class MemoryEngine:
  pool_max_size: int = 100,
  task_backend: TaskBackend | None = None,
  run_migrations: bool = True,
+ operation_validator: "OperationValidatorExtension | None" = None,
+ tenant_extension: "TenantExtension | None" = None,
+ skip_llm_verification: bool | None = None,
+ lazy_reranker: bool | None = None,
  ):
  """
  Initialize the temporal + semantic memory system.
@@ -137,16 +225,34 @@ class MemoryEngine:
  pool_max_size: Maximum number of connections in the pool (default: 100)
  task_backend: Custom task backend. If not provided, uses AsyncIOQueueBackend.
  run_migrations: Whether to run database migrations during initialize(). Default: True
+ operation_validator: Optional extension to validate operations before execution.
+ If provided, retain/recall/reflect operations will be validated.
+ tenant_extension: Optional extension for multi-tenancy and API key authentication.
+ If provided, operations require a RequestContext for authentication.
+ skip_llm_verification: Skip LLM connection verification during initialization.
+ Defaults to HINDSIGHT_API_SKIP_LLM_VERIFICATION env var or False.
+ lazy_reranker: Delay reranker initialization until first use. Useful for retain-only
+ operations that don't need the cross-encoder. Defaults to
+ HINDSIGHT_API_LAZY_RERANKER env var or False.
  """
  # Load config from environment for any missing parameters
  from ..config import get_config
 
  config = get_config()
 
+ # Apply optimization flags from config if not explicitly provided
+ self._skip_llm_verification = (
+ skip_llm_verification if skip_llm_verification is not None else config.skip_llm_verification
+ )
+ self._lazy_reranker = lazy_reranker if lazy_reranker is not None else config.lazy_reranker
+
  # Apply defaults from config
  db_url = db_url or config.database_url
  memory_llm_provider = memory_llm_provider or config.llm_provider
  memory_llm_api_key = memory_llm_api_key or config.llm_api_key
+ # Ollama doesn't require an API key
+ if not memory_llm_api_key and memory_llm_provider != "ollama":
+ raise ValueError("LLM API key is required. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
  memory_llm_model = memory_llm_model or config.llm_model
  memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
  # Track pg0 instance (if used)
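
Taken together, the new constructor arguments might be exercised like this. This is a hypothetical sketch with placeholder values, not an excerpt from the package:

```python
# Hypothetical construction using the 0.1.13 flags; values are placeholders.
engine = MemoryEngine(
    db_url="postgresql://localhost:5432/hindsight",
    memory_llm_provider="ollama",  # per the check above, ollama needs no API key
    skip_llm_verification=True,    # else falls back to HINDSIGHT_API_SKIP_LLM_VERIFICATION
    lazy_reranker=True,            # cross-encoder loads on first recall, not at startup
)
```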
@@ -243,27 +349,82 @@ class MemoryEngine:
  # initialize encoding eagerly to avoid delaying the first time
  _get_tiktoken_encoding()
 
+ # Store operation validator extension (optional)
+ self._operation_validator = operation_validator
+
+ # Store tenant extension (optional)
+ self._tenant_extension = tenant_extension
+
+ async def _validate_operation(self, validation_coro) -> None:
+ """
+ Run validation if an operation validator is configured.
+
+ Args:
+ validation_coro: Coroutine that returns a ValidationResult
+
+ Raises:
+ OperationValidationError: If validation fails
+ """
+ if self._operation_validator is None:
+ return
+
+ from hindsight_api.extensions import OperationValidationError
+
+ result = await validation_coro
+ if not result.allowed:
+ raise OperationValidationError(result.reason or "Operation not allowed")
+
+ async def _authenticate_tenant(self, request_context: "RequestContext | None") -> str:
+ """
+ Authenticate tenant and set schema in context variable.
+
+ The schema is stored in a contextvar for async-safe, per-task isolation.
+ Use fq_table(table_name) to get fully-qualified table names.
+
+ Args:
+ request_context: The request context with API key. Required if tenant_extension is configured.
+
+ Returns:
+ Schema name that was set in the context.
+
+ Raises:
+ AuthenticationError: If authentication fails or request_context is missing when required.
+ """
+ if self._tenant_extension is None:
+ _current_schema.set("public")
+ return "public"
+
+ from hindsight_api.extensions import AuthenticationError
+
+ if request_context is None:
+ raise AuthenticationError("RequestContext is required when tenant extension is configured")
+
+ tenant_context = await self._tenant_extension.authenticate(request_context)
+ _current_schema.set(tenant_context.schema_name)
+ return tenant_context.schema_name
+
  async def _handle_access_count_update(self, task_dict: dict[str, Any]):
  """
  Handler for access count update tasks.
 
  Args:
  task_dict: Dict with 'node_ids' key containing list of node IDs to update
+
+ Raises:
+ Exception: Any exception from database operations (propagates to execute_task for retry)
  """
  node_ids = task_dict.get("node_ids", [])
  if not node_ids:
  return
 
  pool = await self._get_pool()
- try:
- # Convert string UUIDs to UUID type for faster matching
- uuid_list = [uuid.UUID(nid) for nid in node_ids]
- async with acquire_with_retry(pool) as conn:
- await conn.execute(
- "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])", uuid_list
- )
- except Exception as e:
- logger.error(f"Access count handler: Error updating access counts: {e}")
+ # Convert string UUIDs to UUID type for faster matching
+ uuid_list = [uuid.UUID(nid) for nid in node_ids]
+ async with acquire_with_retry(pool) as conn:
+ await conn.execute(
+ f"UPDATE {fq_table('memory_units')} SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
+ uuid_list,
+ )
 
  async def _handle_batch_retain(self, task_dict: dict[str, Any]):
  """
@@ -271,23 +432,27 @@ class MemoryEngine:
 
  Args:
  task_dict: Dict with 'bank_id', 'contents'
+
+ Raises:
+ ValueError: If bank_id is missing
+ Exception: Any exception from retain_batch_async (propagates to execute_task for retry)
  """
- try:
- bank_id = task_dict.get("bank_id")
- contents = task_dict.get("contents", [])
+ bank_id = task_dict.get("bank_id")
+ if not bank_id:
+ raise ValueError("bank_id is required for batch retain task")
+ contents = task_dict.get("contents", [])
 
- logger.info(
- f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
- )
+ logger.info(
+ f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
+ )
 
- await self.retain_batch_async(bank_id=bank_id, contents=contents)
+ # Use internal request context for background tasks
+ from hindsight_api.models import RequestContext
 
- logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
- except Exception as e:
- logger.error(f"Batch retain handler: Error processing batch retain: {e}")
- import traceback
+ internal_context = RequestContext()
+ await self.retain_batch_async(bank_id=bank_id, contents=contents, request_context=internal_context)
 
- traceback.print_exc()
+ logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
 
  async def execute_task(self, task_dict: dict[str, Any]):
  """
@@ -311,7 +476,8 @@ class MemoryEngine:
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  result = await conn.fetchrow(
- "SELECT operation_id FROM async_operations WHERE operation_id = $1", uuid.UUID(operation_id)
+ f"SELECT operation_id FROM {fq_table('async_operations')} WHERE operation_id = $1",
+ uuid.UUID(operation_id),
  )
  if not result:
  # Operation was cancelled, skip processing
@@ -369,7 +535,9 @@ class MemoryEngine:
  try:
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
- await conn.execute("DELETE FROM async_operations WHERE operation_id = $1", uuid.UUID(operation_id))
+ await conn.execute(
+ f"DELETE FROM {fq_table('async_operations')} WHERE operation_id = $1", uuid.UUID(operation_id)
+ )
  except Exception as e:
  logger.error(f"Failed to delete async operation record {operation_id}: {e}")
 
@@ -383,8 +551,8 @@ class MemoryEngine:
 
  async with acquire_with_retry(pool) as conn:
  await conn.execute(
- """
- UPDATE async_operations
+ f"""
+ UPDATE {fq_table("async_operations")}
  SET status = 'failed', error_message = $2
  WHERE operation_id = $1
  """,
@@ -413,7 +581,7 @@ class MemoryEngine:
  kwargs = {"name": self._pg0_instance_name}
  if self._pg0_port is not None:
  kwargs["port"] = self._pg0_port
- pg0 = EmbeddedPostgres(**kwargs)
+ pg0 = EmbeddedPostgres(**kwargs)  # type: ignore[invalid-argument-type] - dict kwargs
  # Check if pg0 is already running before we start it
  was_already_running = await pg0.is_running()
  self.db_url = await pg0.ensure_running()
@@ -437,6 +605,8 @@ class MemoryEngine:
  await loop.run_in_executor(None, lambda: asyncio.run(cross_encoder.initialize()))
  else:
  await cross_encoder.initialize()
+ # Mark reranker as initialized
+ self._cross_encoder_reranker._initialized = True
 
  async def init_query_analyzer():
  """Initialize query analyzer model."""
@@ -445,21 +615,33 @@ class MemoryEngine:
 
  async def verify_llm():
  """Verify LLM connection is working."""
- await self._llm_config.verify_connection()
+ if not self._skip_llm_verification:
+ await self._llm_config.verify_connection()
 
- # Run pg0 and all model initializations in parallel
- await asyncio.gather(
+ # Build list of initialization tasks
+ init_tasks = [
  start_pg0(),
  init_embeddings(),
- init_cross_encoder(),
  init_query_analyzer(),
- verify_llm(),
- )
+ ]
+
+ # Only init cross-encoder eagerly if not using lazy initialization
+ if not self._lazy_reranker:
+ init_tasks.append(init_cross_encoder())
+
+ # Only verify LLM if not skipping
+ if not self._skip_llm_verification:
+ init_tasks.append(verify_llm())
+
+ # Run pg0 and selected model initializations in parallel
+ await asyncio.gather(*init_tasks)
 
  # Run database migrations if enabled
  if self._run_migrations:
  from ..migrations import run_migrations
 
+ if not self.db_url:
+ raise ValueError("Database URL is required for migrations")
  logger.info("Running database migrations...")
  run_migrations(self.db_url)
 
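Both startup flags resolve the same way: an explicit constructor argument wins, otherwise the HINDSIGHT_API_* environment variable read through `get_config()`, otherwise False. A hedged sketch of that precedence (the env parsing shown here is an assumption; the real logic lives in `config.py`, +11 lines in this release):

```python
import os

def resolve_flag(explicit: bool | None, env_name: str) -> bool:
    """Explicit argument first, then env var, then False (illustration only)."""
    if explicit is not None:
        return explicit
    return os.environ.get(env_name, "").strip().lower() in ("1", "true", "yes")

lazy = resolve_flag(None, "HINDSIGHT_API_LAZY_RERANKER")
```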
@@ -628,9 +810,9 @@ class MemoryEngine:
 
  fetch_start = time_mod.time()
  existing_facts = await conn.fetch(
- """
+ f"""
  SELECT id, text, embedding
- FROM memory_units
+ FROM {fq_table("memory_units")}
  WHERE bank_id = $1
  AND event_date BETWEEN $2 AND $3
  """,
@@ -692,6 +874,7 @@ class MemoryEngine:
  content: str,
  context: str = "",
  event_date: datetime | None = None,
+ request_context: "RequestContext | None" = None,
  ) -> list[str]:
  """
  Store content as memory units (synchronous wrapper).
@@ -704,12 +887,16 @@ class MemoryEngine:
  content: Text content to store
  context: Context about when/why this memory was formed
  event_date: When the event occurred (defaults to now)
+ request_context: Request context for authentication (optional, uses internal context if not provided)
 
  Returns:
  List of created unit IDs
  """
  # Run async version synchronously
- return asyncio.run(self.retain_async(bank_id, content, context, event_date))
+ from hindsight_api.models import RequestContext as RC
+
+ ctx = request_context if request_context is not None else RC()
+ return asyncio.run(self.retain_async(bank_id, content, context, event_date, request_context=ctx))
 
  async def retain_async(
  self,
@@ -720,6 +907,8 @@ class MemoryEngine:
  document_id: str | None = None,
  fact_type_override: str | None = None,
  confidence_score: float | None = None,
+ *,
+ request_context: "RequestContext",
  ) -> list[str]:
  """
  Store content as memory units with temporal and semantic links (ASYNC version).
@@ -734,12 +923,15 @@ class MemoryEngine:
  document_id: Optional document ID for tracking (always upserts if document already exists)
  fact_type_override: Override fact type ('world', 'experience', 'opinion')
  confidence_score: Confidence score for opinions (0.0 to 1.0)
+ request_context: Request context for authentication.
 
  Returns:
  List of created unit IDs
  """
  # Build content dict
- content_dict: RetainContentDict = {"content": content, "context": context, "event_date": event_date}
+ content_dict: RetainContentDict = {"content": content, "context": context}  # type: ignore[typeddict-item] - building incrementally
+ if event_date:
+ content_dict["event_date"] = event_date
  if document_id:
  content_dict["document_id"] = document_id
 
@@ -747,6 +939,7 @@ class MemoryEngine:
  result = await self.retain_batch_async(
  bank_id=bank_id,
  contents=[content_dict],
+ request_context=request_context,
  fact_type_override=fact_type_override,
  confidence_score=confidence_score,
  )
@@ -758,6 +951,8 @@ class MemoryEngine:
  self,
  bank_id: str,
  contents: list[RetainContentDict],
+ *,
+ request_context: "RequestContext",
  document_id: str | None = None,
  fact_type_override: str | None = None,
  confidence_score: float | None = None,
@@ -813,6 +1008,24 @@ class MemoryEngine:
  if not contents:
  return []
 
+ # Authenticate tenant and set schema in context (for fq_table())
+ await self._authenticate_tenant(request_context)
+
+ # Validate operation if validator is configured
+ contents_copy = [dict(c) for c in contents]  # Convert TypedDict to regular dict for extension
+ if self._operation_validator:
+ from hindsight_api.extensions import RetainContext
+
+ ctx = RetainContext(
+ bank_id=bank_id,
+ contents=contents_copy,
+ request_context=request_context,
+ document_id=document_id,
+ fact_type_override=fact_type_override,
+ confidence_score=confidence_score,
+ )
+ await self._validate_operation(self._operation_validator.validate_retain(ctx))
+
  # Apply batch-level document_id to contents that don't have their own (backwards compatibility)
  if document_id:
  for item in contents:
@@ -876,17 +1089,39 @@ class MemoryEngine:
  logger.info(
  f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s"
  )
- return all_results
+ result = all_results
+ else:
+ # Small batch - use internal method directly
+ result = await self._retain_batch_async_internal(
+ bank_id=bank_id,
+ contents=contents,
+ document_id=document_id,
+ is_first_batch=True,
+ fact_type_override=fact_type_override,
+ confidence_score=confidence_score,
+ )
 
- # Small batch - use internal method directly
- return await self._retain_batch_async_internal(
- bank_id=bank_id,
- contents=contents,
- document_id=document_id,
- is_first_batch=True,
- fact_type_override=fact_type_override,
- confidence_score=confidence_score,
- )
+ # Call post-operation hook if validator is configured
+ if self._operation_validator:
+ from hindsight_api.extensions import RetainResult
+
+ result_ctx = RetainResult(
+ bank_id=bank_id,
+ contents=contents_copy,
+ request_context=request_context,
+ document_id=document_id,
+ fact_type_override=fact_type_override,
+ confidence_score=confidence_score,
+ unit_ids=result,
+ success=True,
+ error=None,
+ )
+ try:
+ await self._operation_validator.on_retain_complete(result_ctx)
+ except Exception as e:
+ logger.warning(f"Post-retain hook error (non-fatal): {e}")
+
+ return result
 
  async def _retain_batch_async_internal(
  self,
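
The validator protocol used here is small: `validate_retain(ctx)` returns a result whose `allowed`/`reason` fields `_validate_operation()` inspects, and `on_retain_complete(result)` is a best-effort post-hook whose exceptions are logged rather than raised. A sketch of a validator against that protocol (the `ValidationResult` shape is inferred from `_validate_operation()`; the real base class is `OperationValidatorExtension` in `hindsight_api/extensions/operation_validator.py`, not shown in this excerpt):

```python
from dataclasses import dataclass

@dataclass
class ValidationResult:  # inferred shape: _validate_operation() reads .allowed and .reason
    allowed: bool
    reason: str | None = None

class QuotaValidator:
    """Illustrative validator: rejects oversized retain batches."""

    def __init__(self, max_items: int = 100):
        self.max_items = max_items

    async def validate_retain(self, ctx) -> ValidationResult:
        if len(ctx.contents) > self.max_items:
            return ValidationResult(False, f"batch exceeds {self.max_items} items")
        return ValidationResult(True)

    async def on_retain_complete(self, result) -> None:
        # Post-hooks are non-fatal by design: the engine logs and continues.
        print(f"retained {len(result.unit_ids or [])} units in bank {result.bank_id}")
```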
@@ -961,22 +1196,36 @@ class MemoryEngine:
  Returns:
  Tuple of (results, trace)
  """
- # Run async version synchronously
- return asyncio.run(self.recall_async(bank_id, query, [fact_type], budget, max_tokens, enable_trace))
+ # Run async version synchronously - deprecated sync method, passing None for request_context
+ from hindsight_api.models import RequestContext
+
+ return asyncio.run(
+ self.recall_async(
+ bank_id,
+ query,
+ budget=budget,
+ max_tokens=max_tokens,
+ enable_trace=enable_trace,
+ fact_type=[fact_type],
+ request_context=RequestContext(),
+ )
+ )
 
  async def recall_async(
  self,
  bank_id: str,
  query: str,
- fact_type: list[str],
- budget: Budget = Budget.MID,
+ *,
+ budget: Budget | None = None,
  max_tokens: int = 4096,
  enable_trace: bool = False,
+ fact_type: list[str] | None = None,
  question_date: datetime | None = None,
  include_entities: bool = False,
- max_entity_tokens: int = 1024,
+ max_entity_tokens: int = 500,
  include_chunks: bool = False,
  max_chunk_tokens: int = 8192,
+ request_context: "RequestContext",
  ) -> RecallResultModel:
  """
  Recall memories using N*4-way parallel retrieval (N fact types × 4 retrieval methods).
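
Since `recall_async` is now keyword-only after `query`, with `fact_type` and `budget` optional, a call under the new signature would look roughly like this (hypothetical usage inside an async function; `RequestContext()` mirrors the internal-context pattern the sync wrapper above uses):

```python
# Hypothetical 0.1.13-style call; fact_type defaults to all VALID_RECALL_FACT_TYPES
# and budget defaults to Budget.MID when omitted.
result = await engine.recall_async(
    "my-bank",
    "what did we decide about pricing?",
    max_tokens=2048,
    include_entities=True,
    request_context=RequestContext(),
)
```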
@@ -1010,6 +1259,13 @@ class MemoryEngine:
  - entities: Optional dict of entity states (if include_entities=True)
  - chunks: Optional dict of chunks (if include_chunks=True)
  """
+ # Authenticate tenant and set schema in context (for fq_table())
+ await self._authenticate_tenant(request_context)
+
+ # Default to all fact types if not specified
+ if fact_type is None:
+ fact_type = list(VALID_RECALL_FACT_TYPES)
+
  # Validate fact types early
  invalid_types = set(fact_type) - VALID_RECALL_FACT_TYPES
  if invalid_types:
@@ -1018,17 +1274,40 @@ class MemoryEngine:
  f"Must be one of: {', '.join(sorted(VALID_RECALL_FACT_TYPES))}"
  )
 
- # Map budget enum to thinking_budget number
+ # Validate operation if validator is configured
+ if self._operation_validator:
+ from hindsight_api.extensions import RecallContext
+
+ ctx = RecallContext(
+ bank_id=bank_id,
+ query=query,
+ request_context=request_context,
+ budget=budget,
+ max_tokens=max_tokens,
+ enable_trace=enable_trace,
+ fact_types=list(fact_type),
+ question_date=question_date,
+ include_entities=include_entities,
+ max_entity_tokens=max_entity_tokens,
+ include_chunks=include_chunks,
+ max_chunk_tokens=max_chunk_tokens,
+ )
+ await self._validate_operation(self._operation_validator.validate_recall(ctx))
+
+ # Map budget enum to thinking_budget number (default to MID if None)
  budget_mapping = {Budget.LOW: 100, Budget.MID: 300, Budget.HIGH: 1000}
- thinking_budget = budget_mapping[budget]
+ effective_budget = budget if budget is not None else Budget.MID
+ thinking_budget = budget_mapping[effective_budget]
 
  # Backpressure: limit concurrent recalls to prevent overwhelming the database
+ result = None
+ error_msg = None
  async with self._search_semaphore:
  # Retry loop for connection errors
  max_retries = 3
  for attempt in range(max_retries + 1):
  try:
- return await self._search_with_retries(
+ result = await self._search_with_retries(
  bank_id,
  query,
  fact_type,
@@ -1040,7 +1319,9 @@ class MemoryEngine:
  max_entity_tokens,
  include_chunks,
  max_chunk_tokens,
+ request_context,
  )
+ break  # Success - exit retry loop
  except Exception as e:
  # Check if it's a connection error
  is_connection_error = (
@@ -1058,9 +1339,89 @@ class MemoryEngine:
  )
  await asyncio.sleep(wait_time)
  else:
- # Not a connection error or out of retries - raise
+ # Not a connection error or out of retries - call post-hook and raise
+ error_msg = str(e)
+ if self._operation_validator:
+ from hindsight_api.extensions.operation_validator import RecallResult
+
+ result_ctx = RecallResult(
+ bank_id=bank_id,
+ query=query,
+ request_context=request_context,
+ budget=budget,
+ max_tokens=max_tokens,
+ enable_trace=enable_trace,
+ fact_types=list(fact_type),
+ question_date=question_date,
+ include_entities=include_entities,
+ max_entity_tokens=max_entity_tokens,
+ include_chunks=include_chunks,
+ max_chunk_tokens=max_chunk_tokens,
+ result=None,
+ success=False,
+ error=error_msg,
+ )
+ try:
+ await self._operation_validator.on_recall_complete(result_ctx)
+ except Exception as hook_err:
+ logger.warning(f"Post-recall hook error (non-fatal): {hook_err}")
  raise
- raise Exception("Exceeded maximum retries for search due to connection errors.")
+ else:
+ # Exceeded max retries
+ error_msg = "Exceeded maximum retries for search due to connection errors."
+ if self._operation_validator:
+ from hindsight_api.extensions.operation_validator import RecallResult
+
+ result_ctx = RecallResult(
+ bank_id=bank_id,
+ query=query,
+ request_context=request_context,
+ budget=budget,
+ max_tokens=max_tokens,
+ enable_trace=enable_trace,
+ fact_types=list(fact_type),
+ question_date=question_date,
+ include_entities=include_entities,
+ max_entity_tokens=max_entity_tokens,
+ include_chunks=include_chunks,
+ max_chunk_tokens=max_chunk_tokens,
+ result=None,
+ success=False,
+ error=error_msg,
+ )
+ try:
+ await self._operation_validator.on_recall_complete(result_ctx)
+ except Exception as hook_err:
+ logger.warning(f"Post-recall hook error (non-fatal): {hook_err}")
+ raise Exception(error_msg)
+
+ # Call post-operation hook for success
+ if self._operation_validator and result is not None:
+ from hindsight_api.extensions.operation_validator import RecallResult
+
+ result_ctx = RecallResult(
+ bank_id=bank_id,
+ query=query,
+ request_context=request_context,
+ budget=budget,
+ max_tokens=max_tokens,
+ enable_trace=enable_trace,
+ fact_types=list(fact_type),
+ question_date=question_date,
+ include_entities=include_entities,
+ max_entity_tokens=max_entity_tokens,
+ include_chunks=include_chunks,
+ max_chunk_tokens=max_chunk_tokens,
+ result=result,
+ success=True,
+ error=None,
+ )
+ try:
+ await self._operation_validator.on_recall_complete(result_ctx)
+ except Exception as e:
+ logger.warning(f"Post-recall hook error (non-fatal): {e}")
+
+ return result
 
  async def _search_with_retries(
  self,
@@ -1075,6 +1436,7 @@ class MemoryEngine:
  max_entity_tokens: int = 500,
  include_chunks: bool = False,
  max_chunk_tokens: int = 8192,
+ request_context: "RequestContext" = None,
  ) -> RecallResultModel:
  """
  Search implementation with modular retrieval and reranking.
@@ -1302,6 +1664,9 @@ class MemoryEngine:
  step_start = time.time()
  reranker_instance = self._cross_encoder_reranker
 
+ # Ensure reranker is initialized (for lazy initialization mode)
+ await reranker_instance.ensure_initialized()
+
  # Rerank using cross-encoder
  scored_results = reranker_instance.rerank(query, merged_candidates)
 
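`ensure_initialized()` is what makes `lazy_reranker=True` safe: the first recall pays the model-load cost, later ones hit an already-set flag. The actual method is added to `CrossEncoderReranker` in `search/reranking.py` (+17 lines in this release, not shown in this excerpt); a typical idempotent, concurrency-safe shape for such a guard is:

```python
import asyncio

class LazyReranker:
    """Sketch of a lazily initialized reranker; not the package's implementation."""

    def __init__(self):
        self._initialized = False
        self._lock = asyncio.Lock()

    async def ensure_initialized(self) -> None:
        if self._initialized:   # fast path once warm
            return
        async with self._lock:  # serialize concurrent first recalls
            if not self._initialized:
                await self._load_model()
                self._initialized = True

    async def _load_model(self) -> None:
        ...  # download and warm the cross-encoder
```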
@@ -1465,10 +1830,10 @@ class MemoryEngine:
  if unit_ids:
  async with acquire_with_retry(pool) as entity_conn:
  entity_rows = await entity_conn.fetch(
- """
+ f"""
  SELECT ue.unit_id, e.id as entity_id, e.canonical_name
- FROM unit_entities ue
- JOIN entities e ON ue.entity_id = e.id
+ FROM {fq_table("unit_entities")} ue
+ JOIN {fq_table("entities")} e ON ue.entity_id = e.id
  WHERE ue.unit_id = ANY($1::uuid[])
  """,
  unit_ids,
@@ -1534,7 +1899,9 @@ class MemoryEngine:
  if total_entity_tokens >= max_entity_tokens:
  break
 
- observations = await self.get_entity_observations(bank_id, entity_id, limit=5)
+ observations = await self.get_entity_observations(
+ bank_id, entity_id, limit=5, request_context=request_context
+ )
 
  # Calculate tokens for this entity's observations
  entity_tokens = 0
@@ -1572,9 +1939,9 @@ class MemoryEngine:
  # Fetch chunk data from database using chunk_ids (no ORDER BY to preserve input order)
  async with acquire_with_retry(pool) as conn:
  chunks_rows = await conn.fetch(
- """
+ f"""
  SELECT chunk_id, chunk_text, chunk_index
- FROM chunks
+ FROM {fq_table("chunks")}
  WHERE chunk_id = ANY($1::text[])
  """,
  chunk_ids_ordered,
@@ -1671,25 +2038,33 @@ class MemoryEngine:
 
  return filtered_results, total_tokens
 
- async def get_document(self, document_id: str, bank_id: str) -> dict[str, Any] | None:
+ async def get_document(
+ self,
+ document_id: str,
+ bank_id: str,
+ *,
+ request_context: "RequestContext",
+ ) -> dict[str, Any] | None:
  """
  Retrieve document metadata and statistics.
 
  Args:
  document_id: Document ID to retrieve
  bank_id: bank ID that owns the document
+ request_context: Request context for authentication.
 
  Returns:
  Dictionary with document info or None if not found
  """
+ await self._authenticate_tenant(request_context)
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  doc = await conn.fetchrow(
- """
+ f"""
  SELECT d.id, d.bank_id, d.original_text, d.content_hash,
  d.created_at, d.updated_at, COUNT(mu.id) as unit_count
- FROM documents d
- LEFT JOIN memory_units mu ON mu.document_id = d.id
+ FROM {fq_table("documents")} d
+ LEFT JOIN {fq_table("memory_units")} mu ON mu.document_id = d.id
  WHERE d.id = $1 AND d.bank_id = $2
  GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
  """,
@@ -1706,37 +2081,52 @@ class MemoryEngine:
  "original_text": doc["original_text"],
  "content_hash": doc["content_hash"],
  "memory_unit_count": doc["unit_count"],
- "created_at": doc["created_at"],
- "updated_at": doc["updated_at"],
+ "created_at": doc["created_at"].isoformat() if doc["created_at"] else None,
+ "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else None,
  }
 
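The `isoformat()` change above is a serialization fix: asyncpg returns `datetime` objects, which are not JSON-serializable, so timestamps now leave the engine as ISO 8601 strings (or None). A tiny self-contained illustration:

```python
from datetime import UTC, datetime

row = {"created_at": datetime(2025, 1, 1, tzinfo=UTC), "updated_at": None}
doc = {k: (v.isoformat() if v else None) for k, v in row.items()}
# doc == {"created_at": "2025-01-01T00:00:00+00:00", "updated_at": None}
```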
- async def delete_document(self, document_id: str, bank_id: str) -> dict[str, int]:
+ async def delete_document(
+ self,
+ document_id: str,
+ bank_id: str,
+ *,
+ request_context: "RequestContext",
+ ) -> dict[str, int]:
  """
  Delete a document and all its associated memory units and links.
 
  Args:
  document_id: Document ID to delete
  bank_id: bank ID that owns the document
+ request_context: Request context for authentication.
 
  Returns:
  Dictionary with counts of deleted items
  """
+ await self._authenticate_tenant(request_context)
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  async with conn.transaction():
  # Count units before deletion
  units_count = await conn.fetchval(
- "SELECT COUNT(*) FROM memory_units WHERE document_id = $1", document_id
+ f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE document_id = $1", document_id
  )
 
  # Delete document (cascades to memory_units and all their links)
  deleted = await conn.fetchval(
- "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id", document_id, bank_id
+ f"DELETE FROM {fq_table('documents')} WHERE id = $1 AND bank_id = $2 RETURNING id",
+ document_id,
+ bank_id,
  )
 
  return {"document_deleted": 1 if deleted else 0, "memory_units_deleted": units_count if deleted else 0}
 
- async def delete_memory_unit(self, unit_id: str) -> dict[str, Any]:
+ async def delete_memory_unit(
+ self,
+ unit_id: str,
+ *,
+ request_context: "RequestContext",
+ ) -> dict[str, Any]:
  """
  Delete a single memory unit and all its associated links.
 
@@ -1747,15 +2137,19 @@ class MemoryEngine:
 
  Args:
  unit_id: UUID of the memory unit to delete
+ request_context: Request context for authentication.
 
  Returns:
  Dictionary with deletion result
  """
+ await self._authenticate_tenant(request_context)
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  async with conn.transaction():
  # Delete the memory unit (cascades to links and associations)
- deleted = await conn.fetchval("DELETE FROM memory_units WHERE id = $1 RETURNING id", unit_id)
+ deleted = await conn.fetchval(
+ f"DELETE FROM {fq_table('memory_units')} WHERE id = $1 RETURNING id", unit_id
+ )
 
  return {
  "success": deleted is not None,
@@ -1765,7 +2159,13 @@ class MemoryEngine:
  else "Memory unit not found",
  }
 
- async def delete_bank(self, bank_id: str, fact_type: str | None = None) -> dict[str, int]:
+ async def delete_bank(
+ self,
+ bank_id: str,
+ fact_type: str | None = None,
+ *,
+ request_context: "RequestContext",
+ ) -> dict[str, int]:
  """
  Delete all data for a specific agent (multi-tenant cleanup).
 
@@ -1780,10 +2180,12 @@ class MemoryEngine:
  Args:
  bank_id: bank ID to delete
  fact_type: Optional fact type filter (world, experience, opinion). If provided, only deletes memories of that type.
+ request_context: Request context for authentication.
 
  Returns:
  Dictionary with counts of deleted items
  """
+ await self._authenticate_tenant(request_context)
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  # Ensure connection is not in read-only mode (can happen with connection poolers)
@@ -1793,12 +2195,14 @@ class MemoryEngine:
  if fact_type:
  # Delete only memories of a specific fact type
  units_count = await conn.fetchval(
- "SELECT COUNT(*) FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
+ f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = $2",
  bank_id,
  fact_type,
  )
  await conn.execute(
- "DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2", bank_id, fact_type
+ f"DELETE FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = $2",
+ bank_id,
+ fact_type,
  )
 
  # Note: We don't delete entities when fact_type is specified,
@@ -1807,26 +2211,26 @@ class MemoryEngine:
  else:
  # Delete all data for the bank
  units_count = await conn.fetchval(
- "SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id
+ f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1", bank_id
  )
  entities_count = await conn.fetchval(
- "SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id
+ f"SELECT COUNT(*) FROM {fq_table('entities')} WHERE bank_id = $1", bank_id
  )
  documents_count = await conn.fetchval(
- "SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id
+ f"SELECT COUNT(*) FROM {fq_table('documents')} WHERE bank_id = $1", bank_id
  )
 
  # Delete documents (cascades to chunks)
- await conn.execute("DELETE FROM documents WHERE bank_id = $1", bank_id)
+ await conn.execute(f"DELETE FROM {fq_table('documents')} WHERE bank_id = $1", bank_id)
 
  # Delete memory units (cascades to unit_entities, memory_links)
- await conn.execute("DELETE FROM memory_units WHERE bank_id = $1", bank_id)
+ await conn.execute(f"DELETE FROM {fq_table('memory_units')} WHERE bank_id = $1", bank_id)
 
  # Delete entities (cascades to unit_entities, entity_cooccurrences, memory_links with entity_id)
- await conn.execute("DELETE FROM entities WHERE bank_id = $1", bank_id)
+ await conn.execute(f"DELETE FROM {fq_table('entities')} WHERE bank_id = $1", bank_id)
 
  # Delete the bank profile itself
- await conn.execute("DELETE FROM banks WHERE bank_id = $1", bank_id)
+ await conn.execute(f"DELETE FROM {fq_table('banks')} WHERE bank_id = $1", bank_id)
 
  return {
  "memory_units_deleted": units_count,
@@ -1838,17 +2242,25 @@ class MemoryEngine:
  except Exception as e:
  raise Exception(f"Failed to delete agent data: {str(e)}")
 
- async def get_graph_data(self, bank_id: str | None = None, fact_type: str | None = None):
+ async def get_graph_data(
+ self,
+ bank_id: str | None = None,
+ fact_type: str | None = None,
+ *,
+ request_context: "RequestContext",
+ ):
  """
  Get graph data for visualization.
 
  Args:
  bank_id: Filter by bank ID
  fact_type: Filter by fact type (world, experience, opinion)
+ request_context: Request context for authentication.
 
  Returns:
  Dict with nodes, edges, and table_rows
  """
+ await self._authenticate_tenant(request_context)
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  # Get memory units, optionally filtered by bank_id and fact_type
@@ -1871,7 +2283,7 @@ class MemoryEngine:
  units = await conn.fetch(
  f"""
  SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
- FROM memory_units
+ FROM {fq_table("memory_units")}
  {where_clause}
  ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
  LIMIT 1000
@@ -1884,15 +2296,15 @@ class MemoryEngine:
  unit_ids = [row["id"] for row in units]
  if unit_ids:
  links = await conn.fetch(
- """
+ f"""
  SELECT DISTINCT ON (LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid))
  ml.from_unit_id,
  ml.to_unit_id,
  ml.link_type,
  ml.weight,
  e.canonical_name as entity_name
- FROM memory_links ml
- LEFT JOIN entities e ON ml.entity_id = e.id
+ FROM {fq_table("memory_links")} ml
+ LEFT JOIN {fq_table("entities")} e ON ml.entity_id = e.id
  WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
  ORDER BY LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid), ml.weight DESC
  """,
@@ -1902,10 +2314,10 @@ class MemoryEngine:
  links = []
 
  # Get entity information
- unit_entities = await conn.fetch("""
+ unit_entities = await conn.fetch(f"""
  SELECT ue.unit_id, e.canonical_name
- FROM unit_entities ue
- JOIN entities e ON ue.entity_id = e.id
+ FROM {fq_table("unit_entities")} ue
+ JOIN {fq_table("entities")} e ON ue.entity_id = e.id
  ORDER BY ue.unit_id
  """)
 
@@ -2017,11 +2429,13 @@ class MemoryEngine:
 
  async def list_memory_units(
  self,
- bank_id: str | None = None,
+ bank_id: str,
+ *,
  fact_type: str | None = None,
  search_query: str | None = None,
  limit: int = 100,
  offset: int = 0,
+ request_context: "RequestContext",
  ):
  """
  List memory units for table view with optional full-text search.
@@ -2032,10 +2446,12 @@ class MemoryEngine:
  search_query: Full-text search query (searches text and context fields)
  limit: Maximum number of results to return
  offset: Offset for pagination
+ request_context: Request context for authentication.
 
  Returns:
  Dict with items (list of memory units) and total count
  """
+ await self._authenticate_tenant(request_context)
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  # Build query conditions
@@ -2064,7 +2480,7 @@ class MemoryEngine:
  # Get total count
  count_query = f"""
  SELECT COUNT(*) as total
- FROM memory_units
+ FROM {fq_table("memory_units")}
  {where_clause}
  """
  count_result = await conn.fetchrow(count_query, *query_params)
@@ -2082,7 +2498,7 @@ class MemoryEngine:
  units = await conn.fetch(
  f"""
  SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id
- FROM memory_units
+ FROM {fq_table("memory_units")}
  {where_clause}
  ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
  LIMIT {limit_param} OFFSET {offset_param}
@@ -2094,10 +2510,10 @@ class MemoryEngine:
  if units:
  unit_ids = [row["id"] for row in units]
  unit_entities = await conn.fetch(
- """
+ f"""
  SELECT ue.unit_id, e.canonical_name
- FROM unit_entities ue
- JOIN entities e ON ue.entity_id = e.id
+ FROM {fq_table("unit_entities")} ue
+ JOIN {fq_table("entities")} e ON ue.entity_id = e.id
  WHERE ue.unit_id = ANY($1::uuid[])
  ORDER BY ue.unit_id
  """,
@@ -2138,7 +2554,15 @@ class MemoryEngine:
 
  return {"items": items, "total": total, "limit": limit, "offset": offset}
 
- async def list_documents(self, bank_id: str, search_query: str | None = None, limit: int = 100, offset: int = 0):
+ async def list_documents(
+ self,
+ bank_id: str,
+ *,
+ search_query: str | None = None,
+ limit: int = 100,
+ offset: int = 0,
+ request_context: "RequestContext",
+ ):
  """
  List documents with optional search and pagination.
 
@@ -2147,10 +2571,12 @@ class MemoryEngine:
  search_query: Search in document ID
  limit: Maximum number of results
  offset: Offset for pagination
+ request_context: Request context for authentication.
 
  Returns:
  Dict with items (list of documents without original_text) and total count
  """
+ await self._authenticate_tenant(request_context)
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  # Build query conditions
@@ -2173,7 +2599,7 @@ class MemoryEngine:
  # Get total count
  count_query = f"""
  SELECT COUNT(*) as total
- FROM documents
+ FROM {fq_table("documents")}
  {where_clause}
  """
  count_result = await conn.fetchrow(count_query, *query_params)
@@ -2198,7 +2624,7 @@ class MemoryEngine:
  updated_at,
  LENGTH(original_text) as text_length,
  retain_params
- FROM documents
+ FROM {fq_table("documents")}
  {where_clause}
  ORDER BY created_at DESC
  LIMIT {limit_param} OFFSET {offset_param}
@@ -2224,7 +2650,7 @@ class MemoryEngine:
  unit_counts = await conn.fetch(
  f"""
  SELECT document_id, bank_id, COUNT(*) as unit_count
- FROM memory_units
+ FROM {fq_table("memory_units")}
  WHERE {where_clause_count}
  GROUP BY document_id, bank_id
  """,
@@ -2258,75 +2684,27 @@ class MemoryEngine:
 
  return {"items": items, "total": total, "limit": limit, "offset": offset}
 
- async def get_document(self, document_id: str, bank_id: str):
- """
- Get a specific document including its original_text.
-
- Args:
- document_id: Document ID
- bank_id: bank ID
-
- Returns:
- Dict with document details including original_text, or None if not found
- """
- pool = await self._get_pool()
- async with acquire_with_retry(pool) as conn:
- doc = await conn.fetchrow(
- """
- SELECT
- id,
- bank_id,
- original_text,
- content_hash,
- created_at,
- updated_at,
- retain_params
- FROM documents
- WHERE id = $1 AND bank_id = $2
- """,
- document_id,
- bank_id,
- )
-
- if not doc:
- return None
-
- # Get memory unit count
- unit_count_row = await conn.fetchrow(
- """
- SELECT COUNT(*) as unit_count
- FROM memory_units
- WHERE document_id = $1 AND bank_id = $2
- """,
- document_id,
- bank_id,
- )
-
- return {
- "id": doc["id"],
- "bank_id": doc["bank_id"],
- "original_text": doc["original_text"],
- "content_hash": doc["content_hash"],
- "created_at": doc["created_at"].isoformat() if doc["created_at"] else "",
- "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else "",
- "memory_unit_count": unit_count_row["unit_count"] if unit_count_row else 0,
- "retain_params": doc["retain_params"] if doc["retain_params"] else None,
- }
-
- async def get_chunk(self, chunk_id: str):
+ async def get_chunk(
+ self,
+ chunk_id: str,
+ *,
+ request_context: "RequestContext",
+ ):
  """
  Get a specific chunk by its ID.
 
  Args:
  chunk_id: Chunk ID (format: bank_id_document_id_chunk_index)
+ request_context: Request context for authentication.
 
  Returns:
  Dict with chunk details including chunk_text, or None if not found
  """
+ await self._authenticate_tenant(request_context)
  pool = await self._get_pool()
  async with acquire_with_retry(pool) as conn:
  chunk = await conn.fetchrow(
- """
+ f"""
  SELECT
  chunk_id,
  document_id,
@@ -2334,7 +2712,7 @@ class MemoryEngine:
  chunk_index,
  chunk_text,
  created_at
- FROM chunks
+ FROM {fq_table("chunks")}
  WHERE chunk_id = $1
  """,
  chunk_id,
@@ -2500,11 +2878,11 @@ Guidelines:
  async with acquire_with_retry(pool) as conn:
  # Find all opinions related to these entities
  opinions = await conn.fetch(
- """
+ f"""
  SELECT DISTINCT mu.id, mu.text, mu.confidence_score, e.canonical_name
- FROM memory_units mu
- JOIN unit_entities ue ON mu.id = ue.unit_id
- JOIN entities e ON ue.entity_id = e.id
+ FROM {fq_table("memory_units")} mu
+ JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
+ JOIN {fq_table("entities")} e ON ue.entity_id = e.id
  WHERE mu.bank_id = $1
  AND mu.fact_type = 'opinion'
  AND e.canonical_name = ANY($2::text[])
@@ -2559,8 +2937,8 @@ Guidelines:
  if evaluation["action"] == "update" and evaluation["new_text"]:
  # Update both text and confidence
  await conn.execute(
- """
- UPDATE memory_units
+ f"""
+ UPDATE {fq_table("memory_units")}
  SET text = $1, confidence_score = $2, updated_at = NOW()
  WHERE id = $3
  """,
@@ -2571,8 +2949,8 @@ Guidelines:
  else:
  # Only update confidence
  await conn.execute(
- """
- UPDATE memory_units
+ f"""
+ UPDATE {fq_table("memory_units")}
  SET confidence_score = $1, updated_at = NOW()
  WHERE id = $2
  """,
@@ -2591,32 +2969,61 @@ Guidelines:
2591
2969
 
2592
2970
  # ==================== bank profile Methods ====================
2593
2971
 
2594
- async def get_bank_profile(self, bank_id: str) -> "bank_utils.BankProfile":
2972
+ async def get_bank_profile(
2973
+ self,
2974
+ bank_id: str,
2975
+ *,
2976
+ request_context: "RequestContext",
2977
+ ) -> dict[str, Any]:
2595
2978
  """
2596
2979
  Get bank profile (name, disposition + background).
2597
2980
  Auto-creates agent with default values if not exists.
2598
2981
 
2599
2982
  Args:
2600
2983
  bank_id: bank IDentifier
2984
+ request_context: Request context for authentication.
2601
2985
 
2602
2986
  Returns:
2603
- BankProfile with name, typed DispositionTraits, and background
2987
+ Dict with name, disposition traits, and background
2604
2988
  """
2989
+ await self._authenticate_tenant(request_context)
2605
2990
  pool = await self._get_pool()
2606
- return await bank_utils.get_bank_profile(pool, bank_id)
2607
-
2608
- async def update_bank_disposition(self, bank_id: str, disposition: dict[str, int]) -> None:
2991
+ profile = await bank_utils.get_bank_profile(pool, bank_id)
2992
+ disposition = profile["disposition"]
2993
+ return {
2994
+ "bank_id": bank_id,
2995
+ "name": profile["name"],
2996
+ "disposition": disposition,
2997
+ "background": profile["background"],
2998
+ }
2999
+
3000
+ async def update_bank_disposition(
3001
+ self,
3002
+ bank_id: str,
3003
+ disposition: dict[str, int],
3004
+ *,
3005
+ request_context: "RequestContext",
3006
+ ) -> None:
2609
3007
  """
2610
3008
  Update bank disposition traits.
2611
3009
 
2612
3010
  Args:
2613
3011
  bank_id: bank IDentifier
2614
3012
  disposition: Dict with skepticism, literalism, empathy (all 1-5)
3013
+ request_context: Request context for authentication.
2615
3014
  """
3015
+ await self._authenticate_tenant(request_context)
2616
3016
  pool = await self._get_pool()
2617
3017
  await bank_utils.update_bank_disposition(pool, bank_id, disposition)
2618
3018
 
2619
- async def merge_bank_background(self, bank_id: str, new_info: str, update_disposition: bool = True) -> dict:
3019
+ async def merge_bank_background(
3020
+ self,
3021
+ bank_id: str,
3022
+ new_info: str,
3023
+ *,
3024
+ update_disposition: bool = True,
3025
+ request_context: "RequestContext",
3026
+ ) -> dict[str, Any]:
2620
3027
  """
2621
3028
  Merge new background information with existing background using LLM.
2622
3029
  Normalizes to first person ("I") and resolves conflicts.
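`get_bank_profile` changes twice over: `request_context` becomes a required keyword-only parameter (note the bare `*`), and the typed `bank_utils.BankProfile` return is flattened to a plain dict that now also echoes `bank_id`. Callers written against 0.1.11 need a mechanical update; a sketch, with `engine` assumed to be a configured `MemoryEngine`:

```python
from hindsight_api.models import RequestContext

# 0.1.13: the bare * in the signature forces keyword form
profile = await engine.get_bank_profile("bank-42", request_context=RequestContext())
print(profile["bank_id"], profile["name"])
print(profile["disposition"])  # skepticism / literalism / empathy traits
print(profile["background"])
```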
@@ -2626,20 +3033,30 @@ Guidelines:
             bank_id: bank IDentifier
             new_info: New background information to add/merge
             update_disposition: If True, infer Big Five traits from background (default: True)
+            request_context: Request context for authentication.
 
         Returns:
             Dict with 'background' (str) and optionally 'disposition' (dict) keys
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         return await bank_utils.merge_bank_background(pool, self._llm_config, bank_id, new_info, update_disposition)
 
-    async def list_banks(self) -> list:
+    async def list_banks(
+        self,
+        *,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
         """
         List all agents in the system.
 
+        Args:
+            request_context: Request context for authentication.
+
         Returns:
             List of dicts with bank_id, name, disposition, background, created_at, updated_at
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         return await bank_utils.list_banks(pool)
 
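`merge_bank_background` and `list_banks` get the same treatment, and `update_disposition` additionally moves behind the `*`, so it can no longer be passed positionally. An updated call might look like this (argument values illustrative):

```python
result = await engine.merge_bank_background(
    "bank-42",
    "Prefers concise answers; works mostly in Rust.",
    update_disposition=True,           # keyword-only as of this release
    request_context=RequestContext(),
)
merged = result["background"]          # LLM-merged, first-person text
traits = result.get("disposition")     # present only when inferred
```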
@@ -2649,8 +3066,10 @@ Guidelines:
         self,
         bank_id: str,
         query: str,
-        budget: Budget = Budget.LOW,
-        context: str = None,
+        *,
+        budget: Budget | None = None,
+        context: str | None = None,
+        request_context: "RequestContext",
     ) -> ReflectResult:
         """
         Reflect and formulate an answer using bank identity, world facts, and opinions.
@@ -2679,6 +3098,22 @@ Guidelines:
         if self._llm_config is None:
             raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
 
+        # Authenticate tenant and set schema in context (for fq_table())
+        await self._authenticate_tenant(request_context)
+
+        # Validate operation if validator is configured
+        if self._operation_validator:
+            from hindsight_api.extensions import ReflectContext
+
+            ctx = ReflectContext(
+                bank_id=bank_id,
+                query=query,
+                request_context=request_context,
+                budget=budget,
+                context=context,
+            )
+            await self._validate_operation(self._operation_validator.validate_reflect(ctx))
+
         reflect_start = time.time()
         reflect_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
         log_buffer = []
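`reflect` gains two extension points. Before doing any retrieval it authenticates the tenant (which also binds the schema that `fq_table()` reads), then, when an operation validator is configured, packs the call's arguments into a `ReflectContext` and runs it through `validate_reflect`. The validator machinery ships in the new `hindsight_api/extensions/operation_validator.py`; its veto mechanism is not visible in this hunk, so the sketch below assumes a validator rejects by raising, which `_validate_operation` would then surface to the caller:

```python
from hindsight_api.extensions.operation_validator import OperationValidator

class QuotaValidator(OperationValidator):
    """Illustrative pre-hook: cap reflect calls per bank (contract assumed)."""

    def __init__(self, limit: int = 100) -> None:
        self._limit = limit
        self._seen: dict[str, int] = {}

    async def validate_reflect(self, ctx) -> None:  # ctx: ReflectContext
        used = self._seen.get(ctx.bank_id, 0)
        if used >= self._limit:
            # Assumption: raising is how a validator vetoes an operation.
            raise PermissionError(f"reflect quota exhausted for {ctx.bank_id}")
        self._seen[ctx.bank_id] = used + 1
```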
@@ -2694,6 +3129,7 @@ Guidelines:
             enable_trace=False,
             fact_type=["experience", "world", "opinion"],
             include_entities=True,
+            request_context=request_context,
         )
         recall_time = time.time() - recall_start
 
@@ -2714,7 +3150,7 @@ Guidelines:
         opinion_facts_text = think_utils.format_facts_for_prompt(opinion_results)
 
         # Get bank profile (name, disposition + background)
-        profile = await self.get_bank_profile(bank_id)
+        profile = await self.get_bank_profile(bank_id, request_context=request_context)
         name = profile["name"]
         disposition = profile["disposition"]  # Typed as DispositionTraits
         background = profile["background"]
@@ -2758,12 +3194,33 @@ Guidelines:
         logger.info("\n" + "\n".join(log_buffer))
 
         # Return response with facts split by type
-        return ReflectResult(
+        result = ReflectResult(
             text=answer_text,
             based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
             new_opinions=[],  # Opinions are being extracted asynchronously
         )
 
+        # Call post-operation hook if validator is configured
+        if self._operation_validator:
+            from hindsight_api.extensions.operation_validator import ReflectResultContext
+
+            result_ctx = ReflectResultContext(
+                bank_id=bank_id,
+                query=query,
+                request_context=request_context,
+                budget=budget,
+                context=context,
+                result=result,
+                success=True,
+                error=None,
+            )
+            try:
+                await self._operation_validator.on_reflect_complete(result_ctx)
+            except Exception as e:
+                logger.warning(f"Post-reflect hook error (non-fatal): {e}")
+
+        return result
+
     async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str, query: str):
         """
         Background task to extract and store opinions from think response.
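The tail of `reflect` is reworked to match: the answer is bound to a local `result`, wrapped in a `ReflectResultContext`, and handed to `on_reflect_complete` before being returned. Hook exceptions are logged at warning level and swallowed, so a misbehaving extension cannot fail a reflect that already produced an answer; that makes the post-hook a natural place for best-effort concerns such as audit logging or usage metering. A sketch against the same assumed base class:

```python
import logging

audit_log = logging.getLogger("hindsight.audit")

class AuditValidator(OperationValidator):
    """Illustrative post-hook: record completed reflects."""

    async def on_reflect_complete(self, result_ctx) -> None:
        # Field names taken from the constructor call in this diff:
        # bank_id, query, result, success, error (plus budget/context).
        answer = result_ctx.result.text if result_ctx.result else ""
        audit_log.info(
            "reflect bank=%s success=%s answer_chars=%d",
            result_ctx.bank_id, result_ctx.success, len(answer),
        )
```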
@@ -2784,6 +3241,10 @@ Guidelines:
             from datetime import datetime
 
             current_time = datetime.now(UTC)
+            # Use internal request context for background tasks
+            from hindsight_api.models import RequestContext
+
+            internal_context = RequestContext()
             for opinion in new_opinions:
                 await self.retain_async(
                     bank_id=bank_id,
@@ -2792,12 +3253,20 @@ Guidelines:
                     event_date=current_time,
                     fact_type_override="opinion",
                     confidence_score=opinion.confidence,
+                    request_context=internal_context,
                 )
 
         except Exception as e:
             logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")
 
-    async def get_entity_observations(self, bank_id: str, entity_id: str, limit: int = 10) -> list[EntityObservation]:
+    async def get_entity_observations(
+        self,
+        bank_id: str,
+        entity_id: str,
+        *,
+        limit: int = 10,
+        request_context: "RequestContext",
+    ) -> list[Any]:
         """
         Get observations linked to an entity.
 
@@ -2805,17 +3274,19 @@ Guidelines:
             bank_id: bank IDentifier
             entity_id: Entity UUID to get observations for
             limit: Maximum number of observations to return
+            request_context: Request context for authentication.
 
         Returns:
             List of EntityObservation objects
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             rows = await conn.fetch(
-                """
+                f"""
                 SELECT mu.text, mu.mentioned_at
-                FROM memory_units mu
-                JOIN unit_entities ue ON mu.id = ue.unit_id
+                FROM {fq_table("memory_units")} mu
+                JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
                 WHERE mu.bank_id = $1
                   AND mu.fact_type = 'observation'
                   AND ue.entity_id = $2
@@ -2833,23 +3304,31 @@ Guidelines:
                 observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
             return observations
 
-    async def list_entities(self, bank_id: str, limit: int = 100) -> list[dict[str, Any]]:
+    async def list_entities(
+        self,
+        bank_id: str,
+        *,
+        limit: int = 100,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
         """
         List all entities for a bank.
 
         Args:
             bank_id: bank IDentifier
             limit: Maximum number of entities to return
+            request_context: Request context for authentication.
 
         Returns:
             List of entity dicts with id, canonical_name, mention_count, first_seen, last_seen
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             rows = await conn.fetch(
-                """
+                f"""
                 SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
-                FROM entities
+                FROM {fq_table("entities")}
                 WHERE bank_id = $1
                 ORDER BY mention_count DESC, last_seen DESC
                 LIMIT $2
@@ -2884,7 +3363,15 @@ Guidelines:
             )
             return entities
 
-    async def get_entity_state(self, bank_id: str, entity_id: str, entity_name: str, limit: int = 10) -> EntityState:
+    async def get_entity_state(
+        self,
+        bank_id: str,
+        entity_id: str,
+        entity_name: str,
+        *,
+        limit: int = 10,
+        request_context: "RequestContext",
+    ) -> EntityState:
         """
         Get the current state (mental model) of an entity.
 
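The entity read path (`get_entity_observations`, `list_entities`, `get_entity_state`) is migrated the same way: `limit` slides behind the bare `*`, `request_context` becomes required, and every method authenticates before touching the pool. Under 0.1.13 a caller-side traversal looks roughly like this (`engine` and the default-constructed context assumed as before):

```python
ctx = RequestContext()

entities = await engine.list_entities("bank-42", limit=25, request_context=ctx)
for ent in entities:
    state = await engine.get_entity_state(
        "bank-42", ent["id"], ent["canonical_name"], limit=5, request_context=ctx
    )
    print(ent["canonical_name"], len(state.observations))
```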
@@ -2893,16 +3380,26 @@ Guidelines:
             entity_id: Entity UUID
             entity_name: Canonical name of the entity
             limit: Maximum number of observations to include
+            request_context: Request context for authentication.
 
         Returns:
             EntityState with observations
         """
-        observations = await self.get_entity_observations(bank_id, entity_id, limit)
+        observations = await self.get_entity_observations(
+            bank_id, entity_id, limit=limit, request_context=request_context
+        )
         return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=observations)
 
     async def regenerate_entity_observations(
-        self, bank_id: str, entity_id: str, entity_name: str, version: str | None = None, conn=None
-    ) -> list[str]:
+        self,
+        bank_id: str,
+        entity_id: str,
+        entity_name: str,
+        *,
+        version: str | None = None,
+        conn=None,
+        request_context: "RequestContext",
+    ) -> None:
         """
         Regenerate observations for an entity by:
         1. Checking version for deduplication (if provided)
@@ -2917,10 +3414,9 @@ Guidelines:
             entity_name: Canonical name of the entity
             version: Entity's last_seen timestamp when task was created (for deduplication)
             conn: Optional database connection (for transactional atomicity with caller)
-
-        Returns:
-            List of created observation IDs
+            request_context: Request context for authentication.
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         entity_uuid = uuid.UUID(entity_id)
 
@@ -2942,9 +3438,9 @@ Guidelines:
         # Step 1: Check version for deduplication
         if version:
             current_last_seen = await fetchval_with_conn(
-                """
+                f"""
                 SELECT last_seen
-                FROM entities
+                FROM {fq_table("entities")}
                 WHERE id = $1 AND bank_id = $2
                 """,
                 entity_uuid,
@@ -2956,10 +3452,10 @@ Guidelines:
 
         # Step 2: Get all facts mentioning this entity (exclude observations themselves)
         rows = await fetch_with_conn(
-            """
+            f"""
             SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
-            FROM memory_units mu
-            JOIN unit_entities ue ON mu.id = ue.unit_id
+            FROM {fq_table("memory_units")} mu
+            JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
             WHERE mu.bank_id = $1
               AND ue.entity_id = $2
               AND mu.fact_type IN ('world', 'experience')
@@ -2999,12 +3495,12 @@ Guidelines:
         async def do_db_operations(db_conn):
             # Delete old observations for this entity
             await db_conn.execute(
-                """
-                DELETE FROM memory_units
+                f"""
+                DELETE FROM {fq_table("memory_units")}
                 WHERE id IN (
                     SELECT mu.id
-                    FROM memory_units mu
-                    JOIN unit_entities ue ON mu.id = ue.unit_id
+                    FROM {fq_table("memory_units")} mu
+                    JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
                     WHERE mu.bank_id = $1
                       AND mu.fact_type = 'observation'
                       AND ue.entity_id = $2
@@ -3023,8 +3519,8 @@ Guidelines:
 
             for obs_text, embedding in zip(observations, embeddings):
                 result = await db_conn.fetchrow(
-                    """
-                    INSERT INTO memory_units (
+                    f"""
+                    INSERT INTO {fq_table("memory_units")} (
                         bank_id, text, embedding, context, event_date,
                         occurred_start, occurred_end, mentioned_at,
                         fact_type, access_count
@@ -3046,8 +3542,8 @@ Guidelines:
 
                 # Link observation to entity
                 await db_conn.execute(
-                    """
-                    INSERT INTO unit_entities (unit_id, entity_id)
+                    f"""
+                    INSERT INTO {fq_table("unit_entities")} (unit_id, entity_id)
                     VALUES ($1, $2)
                     """,
                     uuid.UUID(obs_id),
@@ -3066,7 +3562,12 @@ Guidelines:
             return await do_db_operations(acquired_conn)
 
     async def _regenerate_observations_sync(
-        self, bank_id: str, entity_ids: list[str], min_facts: int = 5, conn=None
+        self,
+        bank_id: str,
+        entity_ids: list[str],
+        min_facts: int = 5,
+        conn=None,
+        request_context: "RequestContext | None" = None,
     ) -> None:
         """
         Regenerate observations for entities synchronously (called during retain).
@@ -3089,8 +3590,8 @@ Guidelines:
         if conn is not None:
             # Use the provided connection (transactional with caller)
             entity_rows = await conn.fetch(
-                """
-                SELECT id, canonical_name FROM entities
+                f"""
+                SELECT id, canonical_name FROM {fq_table("entities")}
                 WHERE id = ANY($1) AND bank_id = $2
                 """,
                 entity_uuids,
@@ -3099,10 +3600,10 @@ Guidelines:
             entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}
 
             fact_counts = await conn.fetch(
-                """
+                f"""
                 SELECT ue.entity_id, COUNT(*) as cnt
-                FROM unit_entities ue
-                JOIN memory_units mu ON ue.unit_id = mu.id
+                FROM {fq_table("unit_entities")} ue
+                JOIN {fq_table("memory_units")} mu ON ue.unit_id = mu.id
                 WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
                 GROUP BY ue.entity_id
                 """,
@@ -3115,8 +3616,8 @@ Guidelines:
             pool = await self._get_pool()
             async with pool.acquire() as acquired_conn:
                 entity_rows = await acquired_conn.fetch(
-                    """
-                    SELECT id, canonical_name FROM entities
+                    f"""
+                    SELECT id, canonical_name FROM {fq_table("entities")}
                     WHERE id = ANY($1) AND bank_id = $2
                     """,
                     entity_uuids,
@@ -3125,10 +3626,10 @@ Guidelines:
                 entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}
 
                 fact_counts = await acquired_conn.fetch(
-                    """
+                    f"""
                     SELECT ue.entity_id, COUNT(*) as cnt
-                    FROM unit_entities ue
-                    JOIN memory_units mu ON ue.unit_id = mu.id
+                    FROM {fq_table("unit_entities")} ue
+                    JOIN {fq_table("memory_units")} mu ON ue.unit_id = mu.id
                     WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
                     GROUP BY ue.entity_id
                     """,
@@ -3150,10 +3651,17 @@ Guidelines:
         if not entities_to_process:
             return
 
+        # Use internal context if not provided (for internal/background calls)
+        from hindsight_api.models import RequestContext as RC
+
+        ctx = request_context if request_context is not None else RC()
+
         # Process all entities in PARALLEL (LLM calls are the bottleneck)
         async def process_entity(entity_id: str, entity_name: str):
             try:
-                await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None, conn=conn)
+                await self.regenerate_entity_observations(
+                    bank_id, entity_id, entity_name, version=None, conn=conn, request_context=ctx
+                )
             except Exception as e:
                 logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
 
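`_regenerate_observations_sync` accepts an optional `request_context` and falls back to a freshly constructed `RequestContext()`, the same idiom used for opinion extraction earlier in this file. The motivation is the context-variable plumbing: background work runs outside any HTTP request, so there is no tenant header to inherit, and each task must bind its own schema (presumably the single-tenant default for an empty context). Any custom job code driving the engine would follow the same pattern:

```python
from hindsight_api.models import RequestContext

async def nightly_report(engine) -> None:
    # No inbound request here, so build a default (internal) context;
    # the engine authenticates it and binds the schema inside this task.
    ctx = RequestContext()
    banks = await engine.list_banks(request_context=ctx)
    print(f"{len(banks)} banks visible to the internal context")
```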
@@ -3167,76 +3675,367 @@ Guidelines:
             task_dict: Dict with 'bank_id' and either:
                 - 'entity_ids' (list): Process multiple entities
                 - 'entity_id', 'entity_name': Process single entity (legacy)
+
+        Raises:
+            ValueError: If required fields are missing
+            Exception: Any exception from regenerate_entity_observations (propagates to execute_task for retry)
         """
-        try:
-            bank_id = task_dict.get("bank_id")
+        bank_id = task_dict.get("bank_id")
+        # Use internal request context for background tasks
+        from hindsight_api.models import RequestContext
 
-            # New format: multiple entity_ids
-            if "entity_ids" in task_dict:
-                entity_ids = task_dict.get("entity_ids", [])
-                min_facts = task_dict.get("min_facts", 5)
+        internal_context = RequestContext()
 
-                if not bank_id or not entity_ids:
-                    logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
-                    return
+        # New format: multiple entity_ids
+        if "entity_ids" in task_dict:
+            entity_ids = task_dict.get("entity_ids", [])
+            min_facts = task_dict.get("min_facts", 5)
 
-                # Process each entity
-                pool = await self._get_pool()
-                async with pool.acquire() as conn:
-                    for entity_id in entity_ids:
-                        try:
-                            # Fetch entity name and check fact count
-                            import uuid as uuid_module
+            if not bank_id or not entity_ids:
+                raise ValueError(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
+
+            # Process each entity
+            pool = await self._get_pool()
+            async with pool.acquire() as conn:
+                for entity_id in entity_ids:
+                    try:
+                        # Fetch entity name and check fact count
+                        import uuid as uuid_module
+
+                        entity_uuid = uuid_module.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+
+                        # First check if entity exists
+                        entity_exists = await conn.fetchrow(
+                            f"SELECT canonical_name FROM {fq_table('entities')} WHERE id = $1 AND bank_id = $2",
+                            entity_uuid,
+                            bank_id,
+                        )
+
+                        if not entity_exists:
+                            logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
+                            continue
 
-                            entity_uuid = uuid_module.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+                        entity_name = entity_exists["canonical_name"]
 
-                            # First check if entity exists
-                            entity_exists = await conn.fetchrow(
-                                "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
+                        # Count facts linked to this entity
+                        fact_count = (
+                            await conn.fetchval(
+                                f"SELECT COUNT(*) FROM {fq_table('unit_entities')} WHERE entity_id = $1",
                                 entity_uuid,
-                                bank_id,
                             )
+                            or 0
+                        )
 
-                            if not entity_exists:
-                                logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
-                                continue
+                        # Only regenerate if entity has enough facts
+                        if fact_count >= min_facts:
+                            await self.regenerate_entity_observations(
+                                bank_id, entity_id, entity_name, version=None, request_context=internal_context
+                            )
+                        else:
+                            logger.debug(
+                                f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
+                            )
 
-                            entity_name = entity_exists["canonical_name"]
+                    except Exception as e:
+                        # Log but continue processing other entities - individual entity failures
+                        # shouldn't fail the whole batch
+                        logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
+                        continue
 
-                            # Count facts linked to this entity
-                            fact_count = (
-                                await conn.fetchval(
-                                    "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1", entity_uuid
-                                )
-                                or 0
-                            )
+        # Legacy format: single entity
+        else:
+            entity_id = task_dict.get("entity_id")
+            entity_name = task_dict.get("entity_name")
+            version = task_dict.get("version")
 
-                            # Only regenerate if entity has enough facts
-                            if fact_count >= min_facts:
-                                await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
-                            else:
-                                logger.debug(
-                                    f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
-                                )
+            if not all([bank_id, entity_id, entity_name]):
+                raise ValueError(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
 
-                        except Exception as e:
-                            logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
-                            continue
+            # Type assertions after validation
+            assert isinstance(bank_id, str) and isinstance(entity_id, str) and isinstance(entity_name, str)
+            await self.regenerate_entity_observations(
+                bank_id, entity_id, entity_name, version=version, request_context=internal_context
+            )
 
-            # Legacy format: single entity
-            else:
-                entity_id = task_dict.get("entity_id")
-                entity_name = task_dict.get("entity_name")
-                version = task_dict.get("version")
+    # =========================================================================
+    # Statistics & Operations (for HTTP API layer)
+    # =========================================================================
 
-                if not all([bank_id, entity_id, entity_name]):
-                    logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
-                    return
+    async def get_bank_stats(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Get statistics about memory nodes and links for a bank."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
 
-                await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version)
+        async with acquire_with_retry(pool) as conn:
+            # Get node counts by fact_type
+            node_stats = await conn.fetch(
+                f"""
+                SELECT fact_type, COUNT(*) as count
+                FROM {fq_table("memory_units")}
+                WHERE bank_id = $1
+                GROUP BY fact_type
+                """,
+                bank_id,
+            )
 
-        except Exception as e:
-            logger.error(f"[OBSERVATIONS] Error regenerating observations: {e}")
-            import traceback
+            # Get link counts by link_type
+            link_stats = await conn.fetch(
+                f"""
+                SELECT ml.link_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY ml.link_type
+                """,
+                bank_id,
+            )
 
-            traceback.print_exc()
+            # Get link counts by fact_type (from nodes)
+            link_fact_type_stats = await conn.fetch(
+                f"""
+                SELECT mu.fact_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY mu.fact_type
+                """,
+                bank_id,
+            )
+
+            # Get link counts by fact_type AND link_type
+            link_breakdown_stats = await conn.fetch(
+                f"""
+                SELECT mu.fact_type, ml.link_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY mu.fact_type, ml.link_type
+                """,
+                bank_id,
+            )
+
+            # Get pending and failed operations counts
+            ops_stats = await conn.fetch(
+                f"""
+                SELECT status, COUNT(*) as count
+                FROM {fq_table("async_operations")}
+                WHERE bank_id = $1
+                GROUP BY status
+                """,
+                bank_id,
+            )
+
+            return {
+                "bank_id": bank_id,
+                "node_counts": {row["fact_type"]: row["count"] for row in node_stats},
+                "link_counts": {row["link_type"]: row["count"] for row in link_stats},
+                "link_counts_by_fact_type": {row["fact_type"]: row["count"] for row in link_fact_type_stats},
+                "link_breakdown": [
+                    {"fact_type": row["fact_type"], "link_type": row["link_type"], "count": row["count"]}
+                    for row in link_breakdown_stats
+                ],
+                "operations": {row["status"]: row["count"] for row in ops_stats},
+            }
+
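`_regenerate_observations_task` also sheds its blanket `try/except` with `traceback.print_exc()`: missing fields now raise `ValueError` and unexpected errors propagate to `execute_task`, which owns retry, while per-entity failures inside a batch are still caught and logged so one bad entity cannot sink the rest. The new `get_bank_stats` rolls five GROUP BY queries into a single payload; note that links are attributed through `from_unit_id`, i.e. to the bank and fact type of the link's source unit. Roughly what a caller gets back (counts and link-type names invented for illustration):

```python
stats = await engine.get_bank_stats("bank-42", request_context=ctx)
# {
#     "bank_id": "bank-42",
#     "node_counts": {"world": 812, "experience": 75, "opinion": 31},
#     "link_counts": {"temporal": 340, "semantic": 1210},      # by link_type
#     "link_counts_by_fact_type": {"world": 1290, "experience": 260},
#     "link_breakdown": [
#         {"fact_type": "world", "link_type": "semantic", "count": 1050},
#         # ... one row per (fact_type, link_type) pair
#     ],
#     "operations": {"pending": 2, "completed": 41},           # async_operations
# }
```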
+    async def get_entity(
+        self,
+        bank_id: str,
+        entity_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Get entity details including metadata and observations."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            entity_row = await conn.fetchrow(
+                f"""
+                SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
+                FROM {fq_table("entities")}
+                WHERE bank_id = $1 AND id = $2
+                """,
+                bank_id,
+                uuid.UUID(entity_id),
+            )
+
+        if not entity_row:
+            return None
+
+        # Get observations for the entity
+        observations = await self.get_entity_observations(bank_id, entity_id, limit=20, request_context=request_context)
+
+        return {
+            "id": str(entity_row["id"]),
+            "canonical_name": entity_row["canonical_name"],
+            "mention_count": entity_row["mention_count"],
+            "first_seen": entity_row["first_seen"].isoformat() if entity_row["first_seen"] else None,
+            "last_seen": entity_row["last_seen"].isoformat() if entity_row["last_seen"] else None,
+            "metadata": entity_row["metadata"] or {},
+            "observations": observations,
+        }
+
+    async def list_operations(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
+        """List async operations for a bank."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            operations = await conn.fetch(
+                f"""
+                SELECT operation_id, bank_id, operation_type, created_at, status, error_message, result_metadata
+                FROM {fq_table("async_operations")}
+                WHERE bank_id = $1
+                ORDER BY created_at DESC
+                """,
+                bank_id,
+            )
+
+        def parse_metadata(metadata):
+            if metadata is None:
+                return {}
+            if isinstance(metadata, str):
+                import json
+
+                return json.loads(metadata)
+            return metadata
+
+        return [
+            {
+                "id": str(row["operation_id"]),
+                "task_type": row["operation_type"],
+                "items_count": parse_metadata(row["result_metadata"]).get("items_count", 0),
+                "document_id": parse_metadata(row["result_metadata"]).get("document_id"),
+                "created_at": row["created_at"].isoformat(),
+                "status": row["status"],
+                "error_message": row["error_message"],
+            }
+            for row in operations
+        ]
+
+    async def cancel_operation(
+        self,
+        bank_id: str,
+        operation_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Cancel a pending async operation."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        op_uuid = uuid.UUID(operation_id)
+
+        async with acquire_with_retry(pool) as conn:
+            # Check if operation exists and belongs to this memory bank
+            result = await conn.fetchrow(
+                f"SELECT bank_id FROM {fq_table('async_operations')} WHERE operation_id = $1 AND bank_id = $2",
+                op_uuid,
+                bank_id,
+            )
+
+            if not result:
+                raise ValueError(f"Operation {operation_id} not found for bank {bank_id}")
+
+            # Delete the operation
+            await conn.execute(f"DELETE FROM {fq_table('async_operations')} WHERE operation_id = $1", op_uuid)
+
+        return {
+            "success": True,
+            "message": f"Operation {operation_id} cancelled",
+            "operation_id": operation_id,
+            "bank_id": bank_id,
+        }
+
+    async def update_bank(
+        self,
+        bank_id: str,
+        *,
+        name: str | None = None,
+        background: str | None = None,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Update bank name and/or background."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            if name is not None:
+                await conn.execute(
+                    f"""
+                    UPDATE {fq_table("banks")}
+                    SET name = $2, updated_at = NOW()
+                    WHERE bank_id = $1
+                    """,
+                    bank_id,
+                    name,
+                )
+
+            if background is not None:
+                await conn.execute(
+                    f"""
+                    UPDATE {fq_table("banks")}
+                    SET background = $2, updated_at = NOW()
+                    WHERE bank_id = $1
+                    """,
+                    bank_id,
+                    background,
+                )
+
+        # Return updated profile
+        return await self.get_bank_profile(bank_id, request_context=request_context)
+
+    async def submit_async_retain(
+        self,
+        bank_id: str,
+        contents: list[dict[str, Any]],
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Submit a batch retain operation to run asynchronously."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        import json
+
+        operation_id = uuid.uuid4()
+
+        # Insert operation record into database
+        async with acquire_with_retry(pool) as conn:
+            await conn.execute(
+                f"""
+                INSERT INTO {fq_table("async_operations")} (operation_id, bank_id, operation_type, result_metadata)
+                VALUES ($1, $2, $3, $4)
+                """,
+                operation_id,
+                bank_id,
+                "retain",
+                json.dumps({"items_count": len(contents)}),
+            )
+
+        # Submit task to background queue
+        await self._task_backend.submit_task(
+            {
+                "type": "batch_retain",
+                "operation_id": str(operation_id),
+                "bank_id": bank_id,
+                "contents": contents,
+            }
+        )
+
+        logger.info(f"Retain task queued for bank_id={bank_id}, {len(contents)} items, operation_id={operation_id}")
+
+        return {
+            "operation_id": str(operation_id),
+            "items_count": len(contents),
+        }
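Together with `list_operations` and `cancel_operation` above, `submit_async_retain` completes a small job lifecycle: record a row in `async_operations`, enqueue a `batch_retain` task on the task backend, and let clients poll or withdraw by operation id. One caveat: despite its docstring, `cancel_operation` deletes the row unconditionally, so cancelling an operation whose worker has already started only removes the bookkeeping. A caller-side sketch (the item payload keys are illustrative; the diff shows only that each item is a dict):

```python
ctx = RequestContext()

op = await engine.submit_async_retain(
    "bank-42",
    [{"content": "Met Dana at the Rust meetup."}],  # item shape assumed
    request_context=ctx,
)

ops = await engine.list_operations("bank-42", request_context=ctx)
status = next(o["status"] for o in ops if o["id"] == op["operation_id"])

if status == "pending":
    await engine.cancel_operation("bank-42", op["operation_id"], request_context=ctx)
```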