hindsight-api 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +2 -0
- hindsight_api/alembic/env.py +24 -1
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +14 -4
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +54 -13
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +18 -7
- hindsight_api/api/http.py +253 -230
- hindsight_api/api/mcp.py +14 -3
- hindsight_api/config.py +11 -0
- hindsight_api/daemon.py +204 -0
- hindsight_api/engine/__init__.py +12 -1
- hindsight_api/engine/entity_resolver.py +38 -37
- hindsight_api/engine/interface.py +592 -0
- hindsight_api/engine/llm_wrapper.py +176 -6
- hindsight_api/engine/memory_engine.py +1092 -293
- hindsight_api/engine/retain/bank_utils.py +13 -12
- hindsight_api/engine/retain/chunk_storage.py +3 -2
- hindsight_api/engine/retain/fact_storage.py +10 -7
- hindsight_api/engine/retain/link_utils.py +17 -16
- hindsight_api/engine/retain/observation_regeneration.py +17 -16
- hindsight_api/engine/retain/orchestrator.py +2 -3
- hindsight_api/engine/retain/types.py +25 -8
- hindsight_api/engine/search/graph_retrieval.py +6 -5
- hindsight_api/engine/search/mpfp_retrieval.py +8 -7
- hindsight_api/engine/search/reranking.py +17 -0
- hindsight_api/engine/search/retrieval.py +12 -11
- hindsight_api/engine/search/think_utils.py +1 -1
- hindsight_api/engine/search/tracer.py +1 -1
- hindsight_api/engine/task_backend.py +32 -0
- hindsight_api/extensions/__init__.py +66 -0
- hindsight_api/extensions/base.py +81 -0
- hindsight_api/extensions/builtin/__init__.py +18 -0
- hindsight_api/extensions/builtin/tenant.py +33 -0
- hindsight_api/extensions/context.py +110 -0
- hindsight_api/extensions/http.py +89 -0
- hindsight_api/extensions/loader.py +125 -0
- hindsight_api/extensions/operation_validator.py +325 -0
- hindsight_api/extensions/tenant.py +63 -0
- hindsight_api/main.py +97 -17
- hindsight_api/mcp_local.py +7 -1
- hindsight_api/migrations.py +54 -10
- hindsight_api/models.py +15 -0
- hindsight_api/pg0.py +1 -1
- {hindsight_api-0.1.11.dist-info → hindsight_api-0.1.13.dist-info}/METADATA +1 -1
- hindsight_api-0.1.13.dist-info/RECORD +75 -0
- hindsight_api-0.1.11.dist-info/RECORD +0 -64
- {hindsight_api-0.1.11.dist-info → hindsight_api-0.1.13.dist-info}/WHEEL +0 -0
- {hindsight_api-0.1.11.dist-info → hindsight_api-0.1.13.dist-info}/entry_points.txt +0 -0
hindsight_api/engine/memory_engine.py

@@ -10,39 +10,122 @@ This implements a sophisticated memory architecture that combines:
 """
 
 import asyncio
+import contextvars
 import logging
 import time
 import uuid
 from datetime import UTC, datetime, timedelta
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any
 
-
-
-from pydantic import BaseModel, Field
+# Context variable for current schema (async-safe, per-task isolation)
+_current_schema: contextvars.ContextVar[str] = contextvars.ContextVar("current_schema", default="public")
 
-from .cross_encoder import CrossEncoderModel
-from .embeddings import Embeddings, create_embeddings_from_env
 
-
+def get_current_schema() -> str:
+    """Get the current schema from context (default: 'public')."""
+    return _current_schema.get()
+
+
+def fq_table(table_name: str) -> str:
+    """
+    Get fully-qualified table name with current schema.
+
+    Example:
+        fq_table("memory_units") -> "public.memory_units"
+        fq_table("memory_units") -> "tenant_xyz.memory_units" (if schema is set)
+    """
+    return f"{get_current_schema()}.{table_name}"
+
+
+# Tables that must be schema-qualified (for runtime validation)
+_PROTECTED_TABLES = frozenset(
+    [
+        "memory_units",
+        "memory_links",
+        "unit_entities",
+        "entities",
+        "entity_cooccurrences",
+        "banks",
+        "documents",
+        "chunks",
+        "async_operations",
+    ]
+)
+
+# Enable runtime SQL validation (can be disabled in production for performance)
+_VALIDATE_SQL_SCHEMAS = True
+
+
+class UnqualifiedTableError(Exception):
+    """Raised when SQL contains unqualified table references."""
+
     pass
 
 
-
-    """
+def validate_sql_schema(sql: str) -> None:
+    """
+    Validate that SQL doesn't contain unqualified table references.
 
-
-
-
-
-
-
+    This is a runtime safety check to prevent cross-tenant data access.
+    Raises UnqualifiedTableError if any protected table is referenced
+    without a schema prefix.
+
+    Args:
+        sql: The SQL query to validate
+
+    Raises:
+        UnqualifiedTableError: If unqualified table reference found
     """
+    if not _VALIDATE_SQL_SCHEMAS:
+        return
+
+    import re
+
+    sql_upper = sql.upper()
+
+    for table in _PROTECTED_TABLES:
+        table_upper = table.upper()
+
+        # Pattern: SQL keyword followed by unqualified table name
+        # Matches: FROM memory_units, JOIN memory_units, INTO memory_units, UPDATE memory_units
+        patterns = [
+            rf"FROM\s+{table_upper}(?:\s|$|,|\)|;)",
+            rf"JOIN\s+{table_upper}(?:\s|$|,|\)|;)",
+            rf"INTO\s+{table_upper}(?:\s|$|\()",
+            rf"UPDATE\s+{table_upper}(?:\s|$)",
+            rf"DELETE\s+FROM\s+{table_upper}(?:\s|$|;)",
+        ]
+
+        for pattern in patterns:
+            match = re.search(pattern, sql_upper)
+            if match:
+                # Check if it's actually qualified (preceded by schema.)
+                # Look backwards from match to see if there's a dot
+                start = match.start()
+                # Find the table name position in the match
+                table_pos = sql_upper.find(table_upper, start)
+                if table_pos > 0:
+                    # Check character before table name (skip whitespace)
+                    prefix = sql[:table_pos].rstrip()
+                    if not prefix.endswith("."):
+                        raise UnqualifiedTableError(
+                            f"Unqualified table reference '{table}' in SQL. "
+                            f"Use fq_table('{table}') for schema safety. "
+                            f"SQL snippet: ...{sql[max(0, start - 10) : start + 50]}..."
+                        )
 
-
-
-
-
-
+
+import asyncpg
+import numpy as np
+from pydantic import BaseModel, Field
+
+from .cross_encoder import CrossEncoderModel
+from .embeddings import Embeddings, create_embeddings_from_env
+from .interface import MemoryEngineInterface
+
+if TYPE_CHECKING:
+    from hindsight_api.extensions import OperationValidatorExtension, TenantExtension
+    from hindsight_api.models import RequestContext
 
 
 from enum import Enum
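
The block above is the heart of the 0.1.13 multi-tenancy work: a module-level ContextVar carries the active schema, fq_table() prefixes it onto table names, and validate_sql_schema() is a belt-and-braces check against unqualified references. As a reading aid, here is a minimal standalone sketch (not the package's code) of why a ContextVar is the right tool: every task spawned by asyncio.gather gets its own copy of the context, so one tenant's set() cannot leak into a concurrent request.

import asyncio
import contextvars

_current_schema: contextvars.ContextVar[str] = contextvars.ContextVar("current_schema", default="public")

def fq_table(table_name: str) -> str:
    # Mirrors the helper added in this release: "<schema>.<table>"
    return f"{_current_schema.get()}.{table_name}"

async def handle_request(schema: str) -> str:
    _current_schema.set(schema)      # what _authenticate_tenant does per request
    await asyncio.sleep(0)           # yield so the tasks interleave
    return fq_table("memory_units")  # each task still sees only its own schema

async def main() -> None:
    results = await asyncio.gather(
        handle_request("tenant_a"),
        handle_request("tenant_b"),
        handle_request("public"),
    )
    # ['tenant_a.memory_units', 'tenant_b.memory_units', 'public.memory_units']
    print(results)

asyncio.run(main())
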
@@ -54,6 +137,7 @@ from .query_analyzer import QueryAnalyzer
 from .response_models import VALID_RECALL_FACT_TYPES, EntityObservation, EntityState, MemoryFact, ReflectResult
 from .response_models import RecallResult as RecallResultModel
 from .retain import bank_utils, embedding_utils
+from .retain.types import RetainContentDict
 from .search import observation_utils, think_utils
 from .search.reranking import CrossEncoderReranker
 from .task_backend import AsyncIOQueueBackend, TaskBackend
@@ -91,7 +175,7 @@ def _get_tiktoken_encoding():
     return _TIKTOKEN_ENCODING
 
 
-class MemoryEngine:
+class MemoryEngine(MemoryEngineInterface):
     """
     Advanced memory system using temporal and semantic linking with PostgreSQL.
 
@@ -116,6 +200,10 @@ class MemoryEngine:
         pool_max_size: int = 100,
         task_backend: TaskBackend | None = None,
         run_migrations: bool = True,
+        operation_validator: "OperationValidatorExtension | None" = None,
+        tenant_extension: "TenantExtension | None" = None,
+        skip_llm_verification: bool | None = None,
+        lazy_reranker: bool | None = None,
     ):
         """
         Initialize the temporal + semantic memory system.
@@ -137,16 +225,34 @@ class MemoryEngine:
             pool_max_size: Maximum number of connections in the pool (default: 100)
             task_backend: Custom task backend. If not provided, uses AsyncIOQueueBackend.
             run_migrations: Whether to run database migrations during initialize(). Default: True
+            operation_validator: Optional extension to validate operations before execution.
+                If provided, retain/recall/reflect operations will be validated.
+            tenant_extension: Optional extension for multi-tenancy and API key authentication.
+                If provided, operations require a RequestContext for authentication.
+            skip_llm_verification: Skip LLM connection verification during initialization.
+                Defaults to HINDSIGHT_API_SKIP_LLM_VERIFICATION env var or False.
+            lazy_reranker: Delay reranker initialization until first use. Useful for retain-only
+                operations that don't need the cross-encoder. Defaults to
+                HINDSIGHT_API_LAZY_RERANKER env var or False.
         """
         # Load config from environment for any missing parameters
         from ..config import get_config
 
         config = get_config()
 
+        # Apply optimization flags from config if not explicitly provided
+        self._skip_llm_verification = (
+            skip_llm_verification if skip_llm_verification is not None else config.skip_llm_verification
+        )
+        self._lazy_reranker = lazy_reranker if lazy_reranker is not None else config.lazy_reranker
+
         # Apply defaults from config
         db_url = db_url or config.database_url
         memory_llm_provider = memory_llm_provider or config.llm_provider
         memory_llm_api_key = memory_llm_api_key or config.llm_api_key
+        # Ollama doesn't require an API key
+        if not memory_llm_api_key and memory_llm_provider != "ollama":
+            raise ValueError("LLM API key is required. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
         memory_llm_model = memory_llm_model or config.llm_model
         memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
         # Track pg0 instance (if used)
@@ -243,27 +349,82 @@ class MemoryEngine:
         # initialize encoding eagerly to avoid delaying the first time
         _get_tiktoken_encoding()
 
+        # Store operation validator extension (optional)
+        self._operation_validator = operation_validator
+
+        # Store tenant extension (optional)
+        self._tenant_extension = tenant_extension
+
+    async def _validate_operation(self, validation_coro) -> None:
+        """
+        Run validation if an operation validator is configured.
+
+        Args:
+            validation_coro: Coroutine that returns a ValidationResult
+
+        Raises:
+            OperationValidationError: If validation fails
+        """
+        if self._operation_validator is None:
+            return
+
+        from hindsight_api.extensions import OperationValidationError
+
+        result = await validation_coro
+        if not result.allowed:
+            raise OperationValidationError(result.reason or "Operation not allowed")
+
+    async def _authenticate_tenant(self, request_context: "RequestContext | None") -> str:
+        """
+        Authenticate tenant and set schema in context variable.
+
+        The schema is stored in a contextvar for async-safe, per-task isolation.
+        Use fq_table(table_name) to get fully-qualified table names.
+
+        Args:
+            request_context: The request context with API key. Required if tenant_extension is configured.
+
+        Returns:
+            Schema name that was set in the context.
+
+        Raises:
+            AuthenticationError: If authentication fails or request_context is missing when required.
+        """
+        if self._tenant_extension is None:
+            _current_schema.set("public")
+            return "public"
+
+        from hindsight_api.extensions import AuthenticationError
+
+        if request_context is None:
+            raise AuthenticationError("RequestContext is required when tenant extension is configured")
+
+        tenant_context = await self._tenant_extension.authenticate(request_context)
+        _current_schema.set(tenant_context.schema_name)
+        return tenant_context.schema_name
+
     async def _handle_access_count_update(self, task_dict: dict[str, Any]):
         """
         Handler for access count update tasks.
 
         Args:
             task_dict: Dict with 'node_ids' key containing list of node IDs to update
+
+        Raises:
+            Exception: Any exception from database operations (propagates to execute_task for retry)
         """
         node_ids = task_dict.get("node_ids", [])
         if not node_ids:
            return
 
         pool = await self._get_pool()
-
-
-
-
-
-
-
-        except Exception as e:
-            logger.error(f"Access count handler: Error updating access counts: {e}")
+        # Convert string UUIDs to UUID type for faster matching
+        uuid_list = [uuid.UUID(nid) for nid in node_ids]
+        async with acquire_with_retry(pool) as conn:
+            await conn.execute(
+                f"UPDATE {fq_table('memory_units')} SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
+                uuid_list,
+            )
 
     async def _handle_batch_retain(self, task_dict: dict[str, Any]):
         """
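
Note how little the engine assumes about the extension objects wired in above: authenticate() must return something with a schema_name, and the validator coroutine must resolve to something with allowed and reason. The sketch below illustrates that contract with hypothetical classes; the real base classes ship in hindsight_api/extensions/ and may look different.

from dataclasses import dataclass

@dataclass
class TenantContext:      # assumed shape: only schema_name is consumed above
    schema_name: str

@dataclass
class ValidationResult:   # assumed shape: only .allowed / .reason are consumed above
    allowed: bool
    reason: str | None = None

class StaticTenantExtension:
    """Hypothetical tenant extension: maps API keys to schemas from a fixed dict."""

    def __init__(self, keys_to_schemas: dict[str, str]):
        self._keys = keys_to_schemas

    async def authenticate(self, request_context) -> TenantContext:
        schema = self._keys.get(getattr(request_context, "api_key", None))
        if schema is None:
            raise PermissionError("unknown API key")  # engine expects an auth error here
        return TenantContext(schema_name=schema)

class MaxBatchValidator:
    """Hypothetical operation validator: rejects oversized retain batches."""

    def __init__(self, max_items: int = 100):
        self.max_items = max_items

    async def validate_retain(self, ctx) -> ValidationResult:
        if len(ctx.contents) > self.max_items:
            return ValidationResult(False, f"batch too large (> {self.max_items} items)")
        return ValidationResult(True)

    async def on_retain_complete(self, result_ctx) -> None:
        pass  # audit logging could go here; errors are swallowed as non-fatal
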
@@ -271,23 +432,27 @@ class MemoryEngine:
 
         Args:
             task_dict: Dict with 'bank_id', 'contents'
+
+        Raises:
+            ValueError: If bank_id is missing
+            Exception: Any exception from retain_batch_async (propagates to execute_task for retry)
         """
-
-
-
+        bank_id = task_dict.get("bank_id")
+        if not bank_id:
+            raise ValueError("bank_id is required for batch retain task")
+        contents = task_dict.get("contents", [])
 
-
-
-
+        logger.info(
+            f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
+        )
 
-
+        # Use internal request context for background tasks
+        from hindsight_api.models import RequestContext
 
-
-
-            logger.error(f"Batch retain handler: Error processing batch retain: {e}")
-            import traceback
+        internal_context = RequestContext()
+        await self.retain_batch_async(bank_id=bank_id, contents=contents, request_context=internal_context)
 
-
+        logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
 
     async def execute_task(self, task_dict: dict[str, Any]):
         """
@@ -311,7 +476,8 @@ class MemoryEngine:
             pool = await self._get_pool()
             async with acquire_with_retry(pool) as conn:
                 result = await conn.fetchrow(
-                    "SELECT operation_id FROM async_operations WHERE operation_id = $1",
+                    f"SELECT operation_id FROM {fq_table('async_operations')} WHERE operation_id = $1",
+                    uuid.UUID(operation_id),
                 )
                 if not result:
                     # Operation was cancelled, skip processing
@@ -369,7 +535,9 @@ class MemoryEngine:
         try:
             pool = await self._get_pool()
             async with acquire_with_retry(pool) as conn:
-                await conn.execute(
+                await conn.execute(
+                    f"DELETE FROM {fq_table('async_operations')} WHERE operation_id = $1", uuid.UUID(operation_id)
+                )
         except Exception as e:
             logger.error(f"Failed to delete async operation record {operation_id}: {e}")
 
@@ -383,8 +551,8 @@ class MemoryEngine:
 
         async with acquire_with_retry(pool) as conn:
             await conn.execute(
-                """
-                UPDATE async_operations
+                f"""
+                UPDATE {fq_table("async_operations")}
                 SET status = 'failed', error_message = $2
                 WHERE operation_id = $1
                 """,
@@ -413,7 +581,7 @@ class MemoryEngine:
             kwargs = {"name": self._pg0_instance_name}
             if self._pg0_port is not None:
                 kwargs["port"] = self._pg0_port
-            pg0 = EmbeddedPostgres(**kwargs)
+            pg0 = EmbeddedPostgres(**kwargs)  # type: ignore[invalid-argument-type] - dict kwargs
             # Check if pg0 is already running before we start it
             was_already_running = await pg0.is_running()
             self.db_url = await pg0.ensure_running()
@@ -437,6 +605,8 @@ class MemoryEngine:
                 await loop.run_in_executor(None, lambda: asyncio.run(cross_encoder.initialize()))
             else:
                 await cross_encoder.initialize()
+            # Mark reranker as initialized
+            self._cross_encoder_reranker._initialized = True
 
         async def init_query_analyzer():
             """Initialize query analyzer model."""
@@ -445,21 +615,33 @@ class MemoryEngine:
 
         async def verify_llm():
             """Verify LLM connection is working."""
-
+            if not self._skip_llm_verification:
+                await self._llm_config.verify_connection()
 
-        #
-
+        # Build list of initialization tasks
+        init_tasks = [
             start_pg0(),
             init_embeddings(),
-            init_cross_encoder(),
             init_query_analyzer(),
-
-
+        ]
+
+        # Only init cross-encoder eagerly if not using lazy initialization
+        if not self._lazy_reranker:
+            init_tasks.append(init_cross_encoder())
+
+        # Only verify LLM if not skipping
+        if not self._skip_llm_verification:
+            init_tasks.append(verify_llm())
+
+        # Run pg0 and selected model initializations in parallel
+        await asyncio.gather(*init_tasks)
 
         # Run database migrations if enabled
         if self._run_migrations:
             from ..migrations import run_migrations
 
+            if not self.db_url:
+                raise ValueError("Database URL is required for migrations")
             logger.info("Running database migrations...")
             run_migrations(self.db_url)
 
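
The startup refactor above swaps a fixed asyncio.gather call for a conditionally built task list, so the cross-encoder load and the LLM round-trip can be skipped entirely. A small standalone sketch of the pattern (illustrative names, not the package's code); coroutines that are never appended are never created, so nothing is left un-awaited:

import asyncio

async def init_embeddings() -> str:
    await asyncio.sleep(0.01)  # stand-in for loading a model
    return "embeddings"

async def init_cross_encoder() -> str:
    await asyncio.sleep(0.01)  # stand-in for loading the reranker
    return "cross-encoder"

async def startup(lazy_reranker: bool) -> list[str]:
    # Build the task list conditionally, then run everything concurrently.
    init_tasks = [init_embeddings()]
    if not lazy_reranker:
        init_tasks.append(init_cross_encoder())
    return await asyncio.gather(*init_tasks)

print(asyncio.run(startup(lazy_reranker=True)))   # ['embeddings']
print(asyncio.run(startup(lazy_reranker=False)))  # ['embeddings', 'cross-encoder']
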
@@ -628,9 +810,9 @@ class MemoryEngine:
 
             fetch_start = time_mod.time()
             existing_facts = await conn.fetch(
-                """
+                f"""
                 SELECT id, text, embedding
-                FROM memory_units
+                FROM {fq_table("memory_units")}
                 WHERE bank_id = $1
                   AND event_date BETWEEN $2 AND $3
                 """,
@@ -692,6 +874,7 @@ class MemoryEngine:
         content: str,
         context: str = "",
         event_date: datetime | None = None,
+        request_context: "RequestContext | None" = None,
     ) -> list[str]:
         """
         Store content as memory units (synchronous wrapper).
@@ -704,12 +887,16 @@ class MemoryEngine:
             content: Text content to store
             context: Context about when/why this memory was formed
             event_date: When the event occurred (defaults to now)
+            request_context: Request context for authentication (optional, uses internal context if not provided)
 
         Returns:
             List of created unit IDs
         """
         # Run async version synchronously
-
+        from hindsight_api.models import RequestContext as RC
+
+        ctx = request_context if request_context is not None else RC()
+        return asyncio.run(self.retain_async(bank_id, content, context, event_date, request_context=ctx))
 
     async def retain_async(
         self,
@@ -720,6 +907,8 @@ class MemoryEngine:
         document_id: str | None = None,
         fact_type_override: str | None = None,
         confidence_score: float | None = None,
+        *,
+        request_context: "RequestContext",
     ) -> list[str]:
         """
         Store content as memory units with temporal and semantic links (ASYNC version).
@@ -734,12 +923,15 @@ class MemoryEngine:
             document_id: Optional document ID for tracking (always upserts if document already exists)
             fact_type_override: Override fact type ('world', 'experience', 'opinion')
             confidence_score: Confidence score for opinions (0.0 to 1.0)
+            request_context: Request context for authentication.
 
         Returns:
             List of created unit IDs
         """
         # Build content dict
-        content_dict: RetainContentDict = {"content": content, "context": context
+        content_dict: RetainContentDict = {"content": content, "context": context}  # type: ignore[typeddict-item] - building incrementally
+        if event_date:
+            content_dict["event_date"] = event_date
         if document_id:
             content_dict["document_id"] = document_id
 
@@ -747,6 +939,7 @@ class MemoryEngine:
         result = await self.retain_batch_async(
             bank_id=bank_id,
             contents=[content_dict],
+            request_context=request_context,
             fact_type_override=fact_type_override,
             confidence_score=confidence_score,
         )
@@ -758,6 +951,8 @@ class MemoryEngine:
         self,
         bank_id: str,
         contents: list[RetainContentDict],
+        *,
+        request_context: "RequestContext",
         document_id: str | None = None,
         fact_type_override: str | None = None,
         confidence_score: float | None = None,
@@ -813,6 +1008,24 @@ class MemoryEngine:
         if not contents:
             return []
 
+        # Authenticate tenant and set schema in context (for fq_table())
+        await self._authenticate_tenant(request_context)
+
+        # Validate operation if validator is configured
+        contents_copy = [dict(c) for c in contents]  # Convert TypedDict to regular dict for extension
+        if self._operation_validator:
+            from hindsight_api.extensions import RetainContext
+
+            ctx = RetainContext(
+                bank_id=bank_id,
+                contents=contents_copy,
+                request_context=request_context,
+                document_id=document_id,
+                fact_type_override=fact_type_override,
+                confidence_score=confidence_score,
+            )
+            await self._validate_operation(self._operation_validator.validate_retain(ctx))
+
         # Apply batch-level document_id to contents that don't have their own (backwards compatibility)
         if document_id:
             for item in contents:
@@ -876,17 +1089,39 @@ class MemoryEngine:
             logger.info(
                 f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s"
             )
-
+            result = all_results
+        else:
+            # Small batch - use internal method directly
+            result = await self._retain_batch_async_internal(
+                bank_id=bank_id,
+                contents=contents,
+                document_id=document_id,
+                is_first_batch=True,
+                fact_type_override=fact_type_override,
+                confidence_score=confidence_score,
+            )
 
-        #
-
-
-
-
-
-
-
+        # Call post-operation hook if validator is configured
+        if self._operation_validator:
+            from hindsight_api.extensions import RetainResult
+
+            result_ctx = RetainResult(
+                bank_id=bank_id,
+                contents=contents_copy,
+                request_context=request_context,
+                document_id=document_id,
+                fact_type_override=fact_type_override,
+                confidence_score=confidence_score,
+                unit_ids=result,
+                success=True,
+                error=None,
+            )
+            try:
+                await self._operation_validator.on_retain_complete(result_ctx)
+            except Exception as e:
+                logger.warning(f"Post-retain hook error (non-fatal): {e}")
+
+        return result
 
     async def _retain_batch_async_internal(
         self,
@@ -961,22 +1196,36 @@ class MemoryEngine:
         Returns:
             Tuple of (results, trace)
         """
-        # Run async version synchronously
-
+        # Run async version synchronously - deprecated sync method, passing None for request_context
+        from hindsight_api.models import RequestContext
+
+        return asyncio.run(
+            self.recall_async(
+                bank_id,
+                query,
+                budget=budget,
+                max_tokens=max_tokens,
+                enable_trace=enable_trace,
+                fact_type=[fact_type],
+                request_context=RequestContext(),
+            )
+        )
 
     async def recall_async(
         self,
         bank_id: str,
         query: str,
-
-        budget: Budget =
+        *,
+        budget: Budget | None = None,
         max_tokens: int = 4096,
         enable_trace: bool = False,
+        fact_type: list[str] | None = None,
         question_date: datetime | None = None,
         include_entities: bool = False,
-        max_entity_tokens: int =
+        max_entity_tokens: int = 500,
         include_chunks: bool = False,
         max_chunk_tokens: int = 8192,
+        request_context: "RequestContext",
     ) -> RecallResultModel:
         """
         Recall memories using N*4-way parallel retrieval (N fact types × 4 retrieval methods).
@@ -1010,6 +1259,13 @@ class MemoryEngine:
                 - entities: Optional dict of entity states (if include_entities=True)
                 - chunks: Optional dict of chunks (if include_chunks=True)
         """
+        # Authenticate tenant and set schema in context (for fq_table())
+        await self._authenticate_tenant(request_context)
+
+        # Default to all fact types if not specified
+        if fact_type is None:
+            fact_type = list(VALID_RECALL_FACT_TYPES)
+
         # Validate fact types early
         invalid_types = set(fact_type) - VALID_RECALL_FACT_TYPES
         if invalid_types:
@@ -1018,17 +1274,40 @@ class MemoryEngine:
                 f"Must be one of: {', '.join(sorted(VALID_RECALL_FACT_TYPES))}"
             )
 
-        #
+        # Validate operation if validator is configured
+        if self._operation_validator:
+            from hindsight_api.extensions import RecallContext
+
+            ctx = RecallContext(
+                bank_id=bank_id,
+                query=query,
+                request_context=request_context,
+                budget=budget,
+                max_tokens=max_tokens,
+                enable_trace=enable_trace,
+                fact_types=list(fact_type),
+                question_date=question_date,
+                include_entities=include_entities,
+                max_entity_tokens=max_entity_tokens,
+                include_chunks=include_chunks,
+                max_chunk_tokens=max_chunk_tokens,
+            )
+            await self._validate_operation(self._operation_validator.validate_recall(ctx))
+
+        # Map budget enum to thinking_budget number (default to MID if None)
         budget_mapping = {Budget.LOW: 100, Budget.MID: 300, Budget.HIGH: 1000}
-
+        effective_budget = budget if budget is not None else Budget.MID
+        thinking_budget = budget_mapping[effective_budget]
 
         # Backpressure: limit concurrent recalls to prevent overwhelming the database
+        result = None
+        error_msg = None
         async with self._search_semaphore:
             # Retry loop for connection errors
             max_retries = 3
             for attempt in range(max_retries + 1):
                 try:
-
+                    result = await self._search_with_retries(
                         bank_id,
                         query,
                         fact_type,
@@ -1040,7 +1319,9 @@ class MemoryEngine:
                         max_entity_tokens,
                         include_chunks,
                         max_chunk_tokens,
+                        request_context,
                     )
+                    break  # Success - exit retry loop
                 except Exception as e:
                     # Check if it's a connection error
                     is_connection_error = (
@@ -1058,9 +1339,89 @@ class MemoryEngine:
                         )
                         await asyncio.sleep(wait_time)
                     else:
-                        # Not a connection error or out of retries - raise
+                        # Not a connection error or out of retries - call post-hook and raise
+                        error_msg = str(e)
+                        if self._operation_validator:
+                            from hindsight_api.extensions.operation_validator import RecallResult
+
+                            result_ctx = RecallResult(
+                                bank_id=bank_id,
+                                query=query,
+                                request_context=request_context,
+                                budget=budget,
+                                max_tokens=max_tokens,
+                                enable_trace=enable_trace,
+                                fact_types=list(fact_type),
+                                question_date=question_date,
+                                include_entities=include_entities,
+                                max_entity_tokens=max_entity_tokens,
+                                include_chunks=include_chunks,
+                                max_chunk_tokens=max_chunk_tokens,
+                                result=None,
+                                success=False,
+                                error=error_msg,
+                            )
+                            try:
+                                await self._operation_validator.on_recall_complete(result_ctx)
+                            except Exception as hook_err:
+                                logger.warning(f"Post-recall hook error (non-fatal): {hook_err}")
                         raise
-
+            else:
+                # Exceeded max retries
+                error_msg = "Exceeded maximum retries for search due to connection errors."
+                if self._operation_validator:
+                    from hindsight_api.extensions.operation_validator import RecallResult
+
+                    result_ctx = RecallResult(
+                        bank_id=bank_id,
+                        query=query,
+                        request_context=request_context,
+                        budget=budget,
+                        max_tokens=max_tokens,
+                        enable_trace=enable_trace,
+                        fact_types=list(fact_type),
+                        question_date=question_date,
+                        include_entities=include_entities,
+                        max_entity_tokens=max_entity_tokens,
+                        include_chunks=include_chunks,
+                        max_chunk_tokens=max_chunk_tokens,
+                        result=None,
+                        success=False,
+                        error=error_msg,
+                    )
+                    try:
+                        await self._operation_validator.on_recall_complete(result_ctx)
+                    except Exception as hook_err:
+                        logger.warning(f"Post-recall hook error (non-fatal): {hook_err}")
+                raise Exception(error_msg)
+
+        # Call post-operation hook for success
+        if self._operation_validator and result is not None:
+            from hindsight_api.extensions.operation_validator import RecallResult
+
+            result_ctx = RecallResult(
+                bank_id=bank_id,
+                query=query,
+                request_context=request_context,
+                budget=budget,
+                max_tokens=max_tokens,
+                enable_trace=enable_trace,
+                fact_types=list(fact_type),
+                question_date=question_date,
+                include_entities=include_entities,
+                max_entity_tokens=max_entity_tokens,
+                include_chunks=include_chunks,
+                max_chunk_tokens=max_chunk_tokens,
+                result=result,
+                success=True,
+                error=None,
+            )
+            try:
+                await self._operation_validator.on_recall_complete(result_ctx)
+            except Exception as e:
+                logger.warning(f"Post-recall hook error (non-fatal): {e}")
+
+        return result
 
     async def _search_with_retries(
         self,
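
The retry plumbing above is easier to follow in isolation. A condensed, self-contained sketch of the same shape (retry only connection-type errors, exponential backoff, for/else to report exhaustion); the names are illustrative, not the package's:

import asyncio

class ConnectionLost(Exception):
    pass

_attempts = {"n": 0}

async def flaky_query() -> str:
    # Fails twice, then succeeds: stands in for a query hitting a dropped connection.
    _attempts["n"] += 1
    if _attempts["n"] < 3:
        raise ConnectionLost("socket closed")
    return "rows"

async def recall_with_retries(max_retries: int = 3) -> str:
    result = None
    for attempt in range(max_retries + 1):
        try:
            result = await flaky_query()
            break  # success - exit retry loop
        except ConnectionLost:
            if attempt < max_retries:
                await asyncio.sleep(0.5 * (2 ** attempt))  # 0.5s, 1s, 2s, ...
            # other exception types propagate immediately, as in the engine
    else:
        # for/else: runs only if the loop never hit `break`
        raise RuntimeError("Exceeded maximum retries for search due to connection errors.")
    return result

print(asyncio.run(recall_with_retries()))  # -> rows (after two retries)
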
@@ -1075,6 +1436,7 @@ class MemoryEngine:
         max_entity_tokens: int = 500,
         include_chunks: bool = False,
         max_chunk_tokens: int = 8192,
+        request_context: "RequestContext" = None,
     ) -> RecallResultModel:
         """
         Search implementation with modular retrieval and reranking.
@@ -1302,6 +1664,9 @@ class MemoryEngine:
             step_start = time.time()
             reranker_instance = self._cross_encoder_reranker
 
+            # Ensure reranker is initialized (for lazy initialization mode)
+            await reranker_instance.ensure_initialized()
+
             # Rerank using cross-encoder
             scored_results = reranker_instance.rerank(query, merged_candidates)
 
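
ensure_initialized() is what makes lazy_reranker safe: the first recall that needs the cross-encoder triggers the load, and later calls are no-ops. A minimal sketch of such a guard, assuming idempotent initialization that must also hold up under concurrent recalls (the real CrossEncoderReranker may differ):

import asyncio

class LazyReranker:
    def __init__(self) -> None:
        self._initialized = False
        self._init_lock = asyncio.Lock()

    async def ensure_initialized(self) -> None:
        if self._initialized:
            return  # fast path: no lock once loaded
        async with self._init_lock:
            if not self._initialized:  # re-check after acquiring the lock
                await asyncio.sleep(0.1)  # stand-in for loading the model
                self._initialized = True

    def rerank(self, query: str, candidates: list[str]) -> list[str]:
        assert self._initialized, "call ensure_initialized() first"
        return sorted(candidates)  # stand-in for cross-encoder scoring

The double-checked pattern matters because two concurrent recalls could both see _initialized as False; the lock guarantees the model loads exactly once.
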
@@ -1465,10 +1830,10 @@ class MemoryEngine:
             if unit_ids:
                 async with acquire_with_retry(pool) as entity_conn:
                     entity_rows = await entity_conn.fetch(
-                        """
+                        f"""
                         SELECT ue.unit_id, e.id as entity_id, e.canonical_name
-                        FROM unit_entities ue
-                        JOIN entities e ON ue.entity_id = e.id
+                        FROM {fq_table("unit_entities")} ue
+                        JOIN {fq_table("entities")} e ON ue.entity_id = e.id
                         WHERE ue.unit_id = ANY($1::uuid[])
                         """,
                         unit_ids,
@@ -1534,7 +1899,9 @@ class MemoryEngine:
                 if total_entity_tokens >= max_entity_tokens:
                     break
 
-                observations = await self.get_entity_observations(
+                observations = await self.get_entity_observations(
+                    bank_id, entity_id, limit=5, request_context=request_context
+                )
 
                 # Calculate tokens for this entity's observations
                 entity_tokens = 0
@@ -1572,9 +1939,9 @@ class MemoryEngine:
             # Fetch chunk data from database using chunk_ids (no ORDER BY to preserve input order)
             async with acquire_with_retry(pool) as conn:
                 chunks_rows = await conn.fetch(
-                    """
+                    f"""
                     SELECT chunk_id, chunk_text, chunk_index
-                    FROM chunks
+                    FROM {fq_table("chunks")}
                     WHERE chunk_id = ANY($1::text[])
                     """,
                     chunk_ids_ordered,
@@ -1671,25 +2038,33 @@ class MemoryEngine:
 
         return filtered_results, total_tokens
 
-    async def get_document(
+    async def get_document(
+        self,
+        document_id: str,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
         """
         Retrieve document metadata and statistics.
 
         Args:
             document_id: Document ID to retrieve
             bank_id: bank ID that owns the document
+            request_context: Request context for authentication.
 
         Returns:
             Dictionary with document info or None if not found
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             doc = await conn.fetchrow(
-                """
+                f"""
                 SELECT d.id, d.bank_id, d.original_text, d.content_hash,
                        d.created_at, d.updated_at, COUNT(mu.id) as unit_count
-                FROM documents d
-                LEFT JOIN memory_units mu ON mu.document_id = d.id
+                FROM {fq_table("documents")} d
+                LEFT JOIN {fq_table("memory_units")} mu ON mu.document_id = d.id
                 WHERE d.id = $1 AND d.bank_id = $2
                 GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
                 """,
@@ -1706,37 +2081,52 @@ class MemoryEngine:
                 "original_text": doc["original_text"],
                 "content_hash": doc["content_hash"],
                 "memory_unit_count": doc["unit_count"],
-                "created_at": doc["created_at"],
-                "updated_at": doc["updated_at"],
+                "created_at": doc["created_at"].isoformat() if doc["created_at"] else None,
+                "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else None,
             }
 
-    async def delete_document(
+    async def delete_document(
+        self,
+        document_id: str,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, int]:
         """
         Delete a document and all its associated memory units and links.
 
         Args:
             document_id: Document ID to delete
             bank_id: bank ID that owns the document
+            request_context: Request context for authentication.
 
         Returns:
             Dictionary with counts of deleted items
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             async with conn.transaction():
                 # Count units before deletion
                 units_count = await conn.fetchval(
-                    "SELECT COUNT(*) FROM memory_units WHERE document_id = $1", document_id
+                    f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE document_id = $1", document_id
                 )
 
                 # Delete document (cascades to memory_units and all their links)
                 deleted = await conn.fetchval(
-                    "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id",
+                    f"DELETE FROM {fq_table('documents')} WHERE id = $1 AND bank_id = $2 RETURNING id",
+                    document_id,
+                    bank_id,
                 )
 
                 return {"document_deleted": 1 if deleted else 0, "memory_units_deleted": units_count if deleted else 0}
 
-    async def delete_memory_unit(
+    async def delete_memory_unit(
+        self,
+        unit_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
         """
         Delete a single memory unit and all its associated links.
 
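
A recurring detail in the signature changes above and below: the bare * turns request_context into a keyword-only parameter, so existing positional call sites fail loudly instead of silently binding the wrong argument, a deliberate guard when retrofitting an auth parameter onto established methods. For illustration (hypothetical standalone function, not the package's):

async def delete_document(document_id: str, bank_id: str, *, request_context: object) -> None:
    ...

# await delete_document("d1", "b1", ctx)                   # TypeError: too many positional arguments
# await delete_document("d1", "b1", request_context=ctx)   # OK: must be passed by keyword
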
@@ -1747,15 +2137,19 @@ class MemoryEngine:
 
         Args:
             unit_id: UUID of the memory unit to delete
+            request_context: Request context for authentication.
 
         Returns:
             Dictionary with deletion result
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             async with conn.transaction():
                 # Delete the memory unit (cascades to links and associations)
-                deleted = await conn.fetchval(
+                deleted = await conn.fetchval(
+                    f"DELETE FROM {fq_table('memory_units')} WHERE id = $1 RETURNING id", unit_id
+                )
 
                 return {
                     "success": deleted is not None,
@@ -1765,7 +2159,13 @@ class MemoryEngine:
                     else "Memory unit not found",
                 }
 
-    async def delete_bank(
+    async def delete_bank(
+        self,
+        bank_id: str,
+        fact_type: str | None = None,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, int]:
         """
         Delete all data for a specific agent (multi-tenant cleanup).
 
@@ -1780,10 +2180,12 @@ class MemoryEngine:
         Args:
             bank_id: bank ID to delete
             fact_type: Optional fact type filter (world, experience, opinion). If provided, only deletes memories of that type.
+            request_context: Request context for authentication.
 
         Returns:
             Dictionary with counts of deleted items
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             # Ensure connection is not in read-only mode (can happen with connection poolers)
@@ -1793,12 +2195,14 @@ class MemoryEngine:
                 if fact_type:
                     # Delete only memories of a specific fact type
                     units_count = await conn.fetchval(
-                        "SELECT COUNT(*) FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
+                        f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = $2",
                         bank_id,
                         fact_type,
                     )
                     await conn.execute(
-                        "DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
+                        f"DELETE FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = $2",
+                        bank_id,
+                        fact_type,
                     )
 
                     # Note: We don't delete entities when fact_type is specified,
@@ -1807,26 +2211,26 @@ class MemoryEngine:
                 else:
                     # Delete all data for the bank
                     units_count = await conn.fetchval(
-                        "SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id
+                        f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1", bank_id
                     )
                     entities_count = await conn.fetchval(
-                        "SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id
+                        f"SELECT COUNT(*) FROM {fq_table('entities')} WHERE bank_id = $1", bank_id
                     )
                     documents_count = await conn.fetchval(
-                        "SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id
+                        f"SELECT COUNT(*) FROM {fq_table('documents')} WHERE bank_id = $1", bank_id
                     )
 
                     # Delete documents (cascades to chunks)
-                    await conn.execute("DELETE FROM documents WHERE bank_id = $1", bank_id)
+                    await conn.execute(f"DELETE FROM {fq_table('documents')} WHERE bank_id = $1", bank_id)
 
                     # Delete memory units (cascades to unit_entities, memory_links)
-                    await conn.execute("DELETE FROM memory_units WHERE bank_id = $1", bank_id)
+                    await conn.execute(f"DELETE FROM {fq_table('memory_units')} WHERE bank_id = $1", bank_id)
 
                     # Delete entities (cascades to unit_entities, entity_cooccurrences, memory_links with entity_id)
-                    await conn.execute("DELETE FROM entities WHERE bank_id = $1", bank_id)
+                    await conn.execute(f"DELETE FROM {fq_table('entities')} WHERE bank_id = $1", bank_id)
 
                     # Delete the bank profile itself
-                    await conn.execute("DELETE FROM banks WHERE bank_id = $1", bank_id)
+                    await conn.execute(f"DELETE FROM {fq_table('banks')} WHERE bank_id = $1", bank_id)
 
                     return {
                         "memory_units_deleted": units_count,
@@ -1838,17 +2242,25 @@ class MemoryEngine:
         except Exception as e:
             raise Exception(f"Failed to delete agent data: {str(e)}")
 
-    async def get_graph_data(
+    async def get_graph_data(
+        self,
+        bank_id: str | None = None,
+        fact_type: str | None = None,
+        *,
+        request_context: "RequestContext",
+    ):
         """
         Get graph data for visualization.
 
         Args:
             bank_id: Filter by bank ID
             fact_type: Filter by fact type (world, experience, opinion)
+            request_context: Request context for authentication.
 
         Returns:
             Dict with nodes, edges, and table_rows
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             # Get memory units, optionally filtered by bank_id and fact_type
@@ -1871,7 +2283,7 @@ class MemoryEngine:
             units = await conn.fetch(
                 f"""
                 SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
-                FROM memory_units
+                FROM {fq_table("memory_units")}
                 {where_clause}
                 ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
                 LIMIT 1000
@@ -1884,15 +2296,15 @@ class MemoryEngine:
             unit_ids = [row["id"] for row in units]
             if unit_ids:
                 links = await conn.fetch(
-                    """
+                    f"""
                     SELECT DISTINCT ON (LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid))
                         ml.from_unit_id,
                         ml.to_unit_id,
                         ml.link_type,
                         ml.weight,
                         e.canonical_name as entity_name
-                    FROM memory_links ml
-                    LEFT JOIN entities e ON ml.entity_id = e.id
+                    FROM {fq_table("memory_links")} ml
+                    LEFT JOIN {fq_table("entities")} e ON ml.entity_id = e.id
                     WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
                     ORDER BY LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid), ml.weight DESC
                     """,
@@ -1902,10 +2314,10 @@ class MemoryEngine:
                 links = []
 
             # Get entity information
-            unit_entities = await conn.fetch("""
+            unit_entities = await conn.fetch(f"""
                 SELECT ue.unit_id, e.canonical_name
-                FROM unit_entities ue
-                JOIN entities e ON ue.entity_id = e.id
+                FROM {fq_table("unit_entities")} ue
+                JOIN {fq_table("entities")} e ON ue.entity_id = e.id
                 ORDER BY ue.unit_id
             """)
@@ -2017,11 +2429,13 @@ class MemoryEngine:
 
     async def list_memory_units(
         self,
-        bank_id: str
+        bank_id: str,
+        *,
         fact_type: str | None = None,
         search_query: str | None = None,
         limit: int = 100,
         offset: int = 0,
+        request_context: "RequestContext",
     ):
         """
         List memory units for table view with optional full-text search.
@@ -2032,10 +2446,12 @@ class MemoryEngine:
             search_query: Full-text search query (searches text and context fields)
             limit: Maximum number of results to return
             offset: Offset for pagination
+            request_context: Request context for authentication.
 
         Returns:
             Dict with items (list of memory units) and total count
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             # Build query conditions
@@ -2064,7 +2480,7 @@ class MemoryEngine:
             # Get total count
             count_query = f"""
                 SELECT COUNT(*) as total
-                FROM memory_units
+                FROM {fq_table("memory_units")}
                 {where_clause}
             """
             count_result = await conn.fetchrow(count_query, *query_params)
@@ -2082,7 +2498,7 @@ class MemoryEngine:
             units = await conn.fetch(
                 f"""
                 SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id
-                FROM memory_units
+                FROM {fq_table("memory_units")}
                 {where_clause}
                 ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
                 LIMIT {limit_param} OFFSET {offset_param}
@@ -2094,10 +2510,10 @@ class MemoryEngine:
             if units:
                 unit_ids = [row["id"] for row in units]
                 unit_entities = await conn.fetch(
-                    """
+                    f"""
                     SELECT ue.unit_id, e.canonical_name
-                    FROM unit_entities ue
-                    JOIN entities e ON ue.entity_id = e.id
+                    FROM {fq_table("unit_entities")} ue
+                    JOIN {fq_table("entities")} e ON ue.entity_id = e.id
                     WHERE ue.unit_id = ANY($1::uuid[])
                     ORDER BY ue.unit_id
                     """,
@@ -2138,7 +2554,15 @@ class MemoryEngine:
 
         return {"items": items, "total": total, "limit": limit, "offset": offset}
 
-    async def list_documents(
+    async def list_documents(
+        self,
+        bank_id: str,
+        *,
+        search_query: str | None = None,
+        limit: int = 100,
+        offset: int = 0,
+        request_context: "RequestContext",
+    ):
         """
         List documents with optional search and pagination.
 
@@ -2147,10 +2571,12 @@ class MemoryEngine:
             search_query: Search in document ID
             limit: Maximum number of results
             offset: Offset for pagination
+            request_context: Request context for authentication.
 
         Returns:
             Dict with items (list of documents without original_text) and total count
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             # Build query conditions
@@ -2173,7 +2599,7 @@ class MemoryEngine:
             # Get total count
             count_query = f"""
                 SELECT COUNT(*) as total
-                FROM documents
+                FROM {fq_table("documents")}
                 {where_clause}
             """
             count_result = await conn.fetchrow(count_query, *query_params)
@@ -2198,7 +2624,7 @@ class MemoryEngine:
                     updated_at,
                     LENGTH(original_text) as text_length,
                     retain_params
-                FROM documents
+                FROM {fq_table("documents")}
                 {where_clause}
                 ORDER BY created_at DESC
                 LIMIT {limit_param} OFFSET {offset_param}
@@ -2224,7 +2650,7 @@ class MemoryEngine:
                 unit_counts = await conn.fetch(
                     f"""
                     SELECT document_id, bank_id, COUNT(*) as unit_count
-                    FROM memory_units
+                    FROM {fq_table("memory_units")}
                     WHERE {where_clause_count}
                     GROUP BY document_id, bank_id
                     """,
@@ -2258,75 +2684,27 @@ class MemoryEngine:
 
         return {"items": items, "total": total, "limit": limit, "offset": offset}
 
-    async def
-
-
-
-
-
-            bank_id: bank ID
-
-        Returns:
-            Dict with document details including original_text, or None if not found
-        """
-        pool = await self._get_pool()
-        async with acquire_with_retry(pool) as conn:
-            doc = await conn.fetchrow(
-                """
-                SELECT
-                    id,
-                    bank_id,
-                    original_text,
-                    content_hash,
-                    created_at,
-                    updated_at,
-                    retain_params
-                FROM documents
-                WHERE id = $1 AND bank_id = $2
-                """,
-                document_id,
-                bank_id,
-            )
-
-            if not doc:
-                return None
-
-            # Get memory unit count
-            unit_count_row = await conn.fetchrow(
-                """
-                SELECT COUNT(*) as unit_count
-                FROM memory_units
-                WHERE document_id = $1 AND bank_id = $2
-                """,
-                document_id,
-                bank_id,
-            )
-
-            return {
-                "id": doc["id"],
-                "bank_id": doc["bank_id"],
-                "original_text": doc["original_text"],
-                "content_hash": doc["content_hash"],
-                "created_at": doc["created_at"].isoformat() if doc["created_at"] else "",
-                "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else "",
-                "memory_unit_count": unit_count_row["unit_count"] if unit_count_row else 0,
-                "retain_params": doc["retain_params"] if doc["retain_params"] else None,
-            }
-
-    async def get_chunk(self, chunk_id: str):
+    async def get_chunk(
+        self,
+        chunk_id: str,
+        *,
+        request_context: "RequestContext",
+    ):
         """
         Get a specific chunk by its ID.
 
         Args:
             chunk_id: Chunk ID (format: bank_id_document_id_chunk_index)
+            request_context: Request context for authentication.
 
         Returns:
             Dict with chunk details including chunk_text, or None if not found
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             chunk = await conn.fetchrow(
-                """
+                f"""
                 SELECT
                     chunk_id,
                     document_id,
@@ -2334,7 +2712,7 @@ class MemoryEngine:
|
|
|
2334
2712
|
chunk_index,
|
|
2335
2713
|
chunk_text,
|
|
2336
2714
|
created_at
|
|
2337
|
-
FROM chunks
|
|
2715
|
+
FROM {fq_table("chunks")}
|
|
2338
2716
|
WHERE chunk_id = $1
|
|
2339
2717
|
""",
|
|
2340
2718
|
chunk_id,
|
|
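
Note: `get_chunk` is representative of the signature change applied throughout this release: `request_context` becomes a required keyword-only parameter and each public method authenticates before touching the pool. A call-site sketch (the bare `RequestContext()` construction mirrors what later hunks in this diff do for internal work):

```python
# Sketch of a caller under the new signature; constructing RequestContext
# with no arguments is an assumption based on later hunks in this diff.
from hindsight_api.models import RequestContext

async def fetch_chunk(engine, chunk_id: str):
    ctx = RequestContext()
    return await engine.get_chunk(chunk_id, request_context=ctx)
```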
@@ -2500,11 +2878,11 @@ Guidelines:
         async with acquire_with_retry(pool) as conn:
             # Find all opinions related to these entities
             opinions = await conn.fetch(
-                """
+                f"""
                 SELECT DISTINCT mu.id, mu.text, mu.confidence_score, e.canonical_name
-                FROM memory_units mu
-                JOIN unit_entities ue ON mu.id = ue.unit_id
-                JOIN entities e ON ue.entity_id = e.id
+                FROM {fq_table("memory_units")} mu
+                JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
+                JOIN {fq_table("entities")} e ON ue.entity_id = e.id
                 WHERE mu.bank_id = $1
                 AND mu.fact_type = 'opinion'
                 AND e.canonical_name = ANY($2::text[])
@@ -2559,8 +2937,8 @@ Guidelines:
             if evaluation["action"] == "update" and evaluation["new_text"]:
                 # Update both text and confidence
                 await conn.execute(
-                    """
-                    UPDATE memory_units
+                    f"""
+                    UPDATE {fq_table("memory_units")}
                     SET text = $1, confidence_score = $2, updated_at = NOW()
                     WHERE id = $3
                     """,
@@ -2571,8 +2949,8 @@ Guidelines:
             else:
                 # Only update confidence
                 await conn.execute(
-                    """
-                    UPDATE memory_units
+                    f"""
+                    UPDATE {fq_table("memory_units")}
                     SET confidence_score = $1, updated_at = NOW()
                     WHERE id = $2
                     """,
@@ -2591,32 +2969,61 @@ Guidelines:

     # ==================== bank profile Methods ====================

-    async def get_bank_profile(
+    async def get_bank_profile(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
         """
         Get bank profile (name, disposition + background).
         Auto-creates agent with default values if not exists.

         Args:
             bank_id: bank IDentifier
+            request_context: Request context for authentication.

         Returns:
-
+            Dict with name, disposition traits, and background
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
-
-
-
+        profile = await bank_utils.get_bank_profile(pool, bank_id)
+        disposition = profile["disposition"]
+        return {
+            "bank_id": bank_id,
+            "name": profile["name"],
+            "disposition": disposition,
+            "background": profile["background"],
+        }
+
+    async def update_bank_disposition(
+        self,
+        bank_id: str,
+        disposition: dict[str, int],
+        *,
+        request_context: "RequestContext",
+    ) -> None:
         """
         Update bank disposition traits.

         Args:
             bank_id: bank IDentifier
             disposition: Dict with skepticism, literalism, empathy (all 1-5)
+            request_context: Request context for authentication.
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         await bank_utils.update_bank_disposition(pool, bank_id, disposition)

-    async def merge_bank_background(
+    async def merge_bank_background(
+        self,
+        bank_id: str,
+        new_info: str,
+        *,
+        update_disposition: bool = True,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
         """
         Merge new background information with existing background using LLM.
         Normalizes to first person ("I") and resolves conflicts.
@@ -2626,20 +3033,30 @@ Guidelines:
             bank_id: bank IDentifier
             new_info: New background information to add/merge
             update_disposition: If True, infer Big Five traits from background (default: True)
+            request_context: Request context for authentication.

         Returns:
             Dict with 'background' (str) and optionally 'disposition' (dict) keys
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         return await bank_utils.merge_bank_background(pool, self._llm_config, bank_id, new_info, update_disposition)

-    async def list_banks(
+    async def list_banks(
+        self,
+        *,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
         """
         List all agents in the system.

+        Args:
+            request_context: Request context for authentication.
+
         Returns:
             List of dicts with bank_id, name, disposition, background, created_at, updated_at
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         return await bank_utils.list_banks(pool)

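
Note: the bank-profile methods now return plain dicts whose keys are visible in the hunks above, which makes the new surface easy to exercise:

```python
# Usage sketch for the updated profile API; engine construction and
# RequestContext defaults are assumptions, not part of this diff.
from hindsight_api.models import RequestContext

async def show_profile(engine, bank_id: str) -> None:
    ctx = RequestContext()
    profile = await engine.get_bank_profile(bank_id, request_context=ctx)
    print(profile["name"], profile["disposition"], profile["background"])

    # Disposition traits per the docstring above: skepticism, literalism,
    # empathy, each on a 1-5 scale.
    await engine.update_bank_disposition(
        bank_id,
        {"skepticism": 3, "literalism": 2, "empathy": 4},
        request_context=ctx,
    )
```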
@@ -2649,8 +3066,10 @@ Guidelines:
         self,
         bank_id: str,
         query: str,
-
-
+        *,
+        budget: Budget | None = None,
+        context: str | None = None,
+        request_context: "RequestContext",
     ) -> ReflectResult:
         """
         Reflect and formulate an answer using bank identity, world facts, and opinions.
@@ -2679,6 +3098,22 @@ Guidelines:
         if self._llm_config is None:
             raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")

+        # Authenticate tenant and set schema in context (for fq_table())
+        await self._authenticate_tenant(request_context)
+
+        # Validate operation if validator is configured
+        if self._operation_validator:
+            from hindsight_api.extensions import ReflectContext
+
+            ctx = ReflectContext(
+                bank_id=bank_id,
+                query=query,
+                request_context=request_context,
+                budget=budget,
+                context=context,
+            )
+            await self._validate_operation(self._operation_validator.validate_reflect(ctx))
+
         reflect_start = time.time()
         reflect_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
         log_buffer = []
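
Note: `reflect` now runs an optional pre-operation hook before any retrieval. The validator protocol lives in `hindsight_api/extensions/operation_validator.py` and is not reproduced in this diff; the sketch below is an assumption about its shape, inferred only from the call `self._validate_operation(self._operation_validator.validate_reflect(ctx))` and assuming `validate_reflect` is a coroutine that raises to reject:

```python
# Hypothetical validator sketch; the real protocol in
# hindsight_api.extensions.operation_validator may differ.
class MaxQueryLengthValidator:
    """Rejects reflect calls whose query exceeds a fixed budget."""

    MAX_QUERY_CHARS = 4096

    async def validate_reflect(self, ctx) -> None:
        # ctx carries bank_id, query, request_context, budget, context,
        # matching the ReflectContext fields built in the hunk above.
        if len(ctx.query) > self.MAX_QUERY_CHARS:
            raise ValueError(f"query too long ({len(ctx.query)} chars)")
```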
@@ -2694,6 +3129,7 @@ Guidelines:
             enable_trace=False,
             fact_type=["experience", "world", "opinion"],
             include_entities=True,
+            request_context=request_context,
         )
         recall_time = time.time() - recall_start

@@ -2714,7 +3150,7 @@ Guidelines:
         opinion_facts_text = think_utils.format_facts_for_prompt(opinion_results)

         # Get bank profile (name, disposition + background)
-        profile = await self.get_bank_profile(bank_id)
+        profile = await self.get_bank_profile(bank_id, request_context=request_context)
         name = profile["name"]
         disposition = profile["disposition"]  # Typed as DispositionTraits
         background = profile["background"]
@@ -2758,12 +3194,33 @@ Guidelines:
         logger.info("\n" + "\n".join(log_buffer))

         # Return response with facts split by type
-        return ReflectResult(
+        result = ReflectResult(
             text=answer_text,
             based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
             new_opinions=[],  # Opinions are being extracted asynchronously
         )

+        # Call post-operation hook if validator is configured
+        if self._operation_validator:
+            from hindsight_api.extensions.operation_validator import ReflectResultContext
+
+            result_ctx = ReflectResultContext(
+                bank_id=bank_id,
+                query=query,
+                request_context=request_context,
+                budget=budget,
+                context=context,
+                result=result,
+                success=True,
+                error=None,
+            )
+            try:
+                await self._operation_validator.on_reflect_complete(result_ctx)
+            except Exception as e:
+                logger.warning(f"Post-reflect hook error (non-fatal): {e}")
+
+        return result
+
     async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str, query: str):
         """
         Background task to extract and store opinions from think response.
@@ -2784,6 +3241,10 @@ Guidelines:
         from datetime import datetime

         current_time = datetime.now(UTC)
+        # Use internal request context for background tasks
+        from hindsight_api.models import RequestContext
+
+        internal_context = RequestContext()
         for opinion in new_opinions:
             await self.retain_async(
                 bank_id=bank_id,
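
Note: background tasks have no inbound request, so the diff builds a bare internal `RequestContext()` instead of threading the caller's context through fire-and-forget work. The same pattern, isolated:

```python
# Pattern used for background work in this diff: a fresh internal context
# keeps authenticated signatures uniform. The `content=` kwarg name is an
# assumption; retain_async's full parameter list is not shown in this diff.
from hindsight_api.models import RequestContext

async def store_opinion(engine, bank_id: str, text: str, when) -> None:
    internal_context = RequestContext()
    await engine.retain_async(
        bank_id=bank_id,
        content=text,
        event_date=when,
        fact_type_override="opinion",
        request_context=internal_context,
    )
```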
@@ -2792,12 +3253,20 @@ Guidelines:
                 event_date=current_time,
                 fact_type_override="opinion",
                 confidence_score=opinion.confidence,
+                request_context=internal_context,
             )

         except Exception as e:
             logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")

-    async def get_entity_observations(
+    async def get_entity_observations(
+        self,
+        bank_id: str,
+        entity_id: str,
+        *,
+        limit: int = 10,
+        request_context: "RequestContext",
+    ) -> list[Any]:
         """
         Get observations linked to an entity.

@@ -2805,17 +3274,19 @@ Guidelines:
             bank_id: bank IDentifier
             entity_id: Entity UUID to get observations for
             limit: Maximum number of observations to return
+            request_context: Request context for authentication.

         Returns:
             List of EntityObservation objects
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             rows = await conn.fetch(
-                """
+                f"""
                 SELECT mu.text, mu.mentioned_at
-                FROM memory_units mu
-                JOIN unit_entities ue ON mu.id = ue.unit_id
+                FROM {fq_table("memory_units")} mu
+                JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
                 WHERE mu.bank_id = $1
                 AND mu.fact_type = 'observation'
                 AND ue.entity_id = $2
@@ -2833,23 +3304,31 @@ Guidelines:
             observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
         return observations

-    async def list_entities(
+    async def list_entities(
+        self,
+        bank_id: str,
+        *,
+        limit: int = 100,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
         """
         List all entities for a bank.

         Args:
             bank_id: bank IDentifier
             limit: Maximum number of entities to return
+            request_context: Request context for authentication.

         Returns:
             List of entity dicts with id, canonical_name, mention_count, first_seen, last_seen
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             rows = await conn.fetch(
-                """
+                f"""
                 SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
-                FROM entities
+                FROM {fq_table("entities")}
                 WHERE bank_id = $1
                 ORDER BY mention_count DESC, last_seen DESC
                 LIMIT $2
@@ -2884,7 +3363,15 @@ Guidelines:
             )
         return entities

-    async def get_entity_state(
+    async def get_entity_state(
+        self,
+        bank_id: str,
+        entity_id: str,
+        entity_name: str,
+        *,
+        limit: int = 10,
+        request_context: "RequestContext",
+    ) -> EntityState:
         """
         Get the current state (mental model) of an entity.

@@ -2893,16 +3380,26 @@ Guidelines:
             entity_id: Entity UUID
             entity_name: Canonical name of the entity
             limit: Maximum number of observations to include
+            request_context: Request context for authentication.

         Returns:
             EntityState with observations
         """
-        observations = await self.get_entity_observations(
+        observations = await self.get_entity_observations(
+            bank_id, entity_id, limit=limit, request_context=request_context
+        )
         return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=observations)

     async def regenerate_entity_observations(
-        self,
-
+        self,
+        bank_id: str,
+        entity_id: str,
+        entity_name: str,
+        *,
+        version: str | None = None,
+        conn=None,
+        request_context: "RequestContext",
+    ) -> None:
         """
         Regenerate observations for an entity by:
         1. Checking version for deduplication (if provided)
@@ -2917,10 +3414,9 @@ Guidelines:
             entity_name: Canonical name of the entity
             version: Entity's last_seen timestamp when task was created (for deduplication)
             conn: Optional database connection (for transactional atomicity with caller)
-
-        Returns:
-            List of created observation IDs
+            request_context: Request context for authentication.
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         entity_uuid = uuid.UUID(entity_id)

@@ -2942,9 +3438,9 @@ Guidelines:
         # Step 1: Check version for deduplication
         if version:
             current_last_seen = await fetchval_with_conn(
-                """
+                f"""
                 SELECT last_seen
-                FROM entities
+                FROM {fq_table("entities")}
                 WHERE id = $1 AND bank_id = $2
                 """,
                 entity_uuid,
@@ -2956,10 +3452,10 @@ Guidelines:

         # Step 2: Get all facts mentioning this entity (exclude observations themselves)
         rows = await fetch_with_conn(
-            """
+            f"""
             SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
-            FROM memory_units mu
-            JOIN unit_entities ue ON mu.id = ue.unit_id
+            FROM {fq_table("memory_units")} mu
+            JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
             WHERE mu.bank_id = $1
             AND ue.entity_id = $2
             AND mu.fact_type IN ('world', 'experience')
@@ -2999,12 +3495,12 @@ Guidelines:
         async def do_db_operations(db_conn):
             # Delete old observations for this entity
             await db_conn.execute(
-                """
-                DELETE FROM memory_units
+                f"""
+                DELETE FROM {fq_table("memory_units")}
                 WHERE id IN (
                     SELECT mu.id
-                    FROM memory_units mu
-                    JOIN unit_entities ue ON mu.id = ue.unit_id
+                    FROM {fq_table("memory_units")} mu
+                    JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
                     WHERE mu.bank_id = $1
                     AND mu.fact_type = 'observation'
                     AND ue.entity_id = $2
@@ -3023,8 +3519,8 @@ Guidelines:

             for obs_text, embedding in zip(observations, embeddings):
                 result = await db_conn.fetchrow(
-                    """
-                    INSERT INTO memory_units (
+                    f"""
+                    INSERT INTO {fq_table("memory_units")} (
                         bank_id, text, embedding, context, event_date,
                         occurred_start, occurred_end, mentioned_at,
                         fact_type, access_count
@@ -3046,8 +3542,8 @@ Guidelines:

                 # Link observation to entity
                 await db_conn.execute(
-                    """
-                    INSERT INTO unit_entities (unit_id, entity_id)
+                    f"""
+                    INSERT INTO {fq_table("unit_entities")} (unit_id, entity_id)
                     VALUES ($1, $2)
                     """,
                     uuid.UUID(obs_id),
@@ -3066,7 +3562,12 @@ Guidelines:
                 return await do_db_operations(acquired_conn)

     async def _regenerate_observations_sync(
-        self,
+        self,
+        bank_id: str,
+        entity_ids: list[str],
+        min_facts: int = 5,
+        conn=None,
+        request_context: "RequestContext | None" = None,
     ) -> None:
         """
         Regenerate observations for entities synchronously (called during retain).
@@ -3089,8 +3590,8 @@ Guidelines:
         if conn is not None:
             # Use the provided connection (transactional with caller)
             entity_rows = await conn.fetch(
-                """
-                SELECT id, canonical_name FROM entities
+                f"""
+                SELECT id, canonical_name FROM {fq_table("entities")}
                 WHERE id = ANY($1) AND bank_id = $2
                 """,
                 entity_uuids,
@@ -3099,10 +3600,10 @@ Guidelines:
             entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

             fact_counts = await conn.fetch(
-                """
+                f"""
                 SELECT ue.entity_id, COUNT(*) as cnt
-                FROM unit_entities ue
-                JOIN memory_units mu ON ue.unit_id = mu.id
+                FROM {fq_table("unit_entities")} ue
+                JOIN {fq_table("memory_units")} mu ON ue.unit_id = mu.id
                 WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
                 GROUP BY ue.entity_id
                 """,
@@ -3115,8 +3616,8 @@ Guidelines:
            pool = await self._get_pool()
            async with pool.acquire() as acquired_conn:
                entity_rows = await acquired_conn.fetch(
-                    """
-                    SELECT id, canonical_name FROM entities
+                    f"""
+                    SELECT id, canonical_name FROM {fq_table("entities")}
                     WHERE id = ANY($1) AND bank_id = $2
                     """,
                     entity_uuids,
@@ -3125,10 +3626,10 @@ Guidelines:
                entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

                fact_counts = await acquired_conn.fetch(
-                    """
+                    f"""
                     SELECT ue.entity_id, COUNT(*) as cnt
-                    FROM unit_entities ue
-                    JOIN memory_units mu ON ue.unit_id = mu.id
+                    FROM {fq_table("unit_entities")} ue
+                    JOIN {fq_table("memory_units")} mu ON ue.unit_id = mu.id
                     WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
                     GROUP BY ue.entity_id
                     """,
@@ -3150,10 +3651,17 @@ Guidelines:
         if not entities_to_process:
             return

+        # Use internal context if not provided (for internal/background calls)
+        from hindsight_api.models import RequestContext as RC
+
+        ctx = request_context if request_context is not None else RC()
+
         # Process all entities in PARALLEL (LLM calls are the bottleneck)
         async def process_entity(entity_id: str, entity_name: str):
             try:
-                await self.regenerate_entity_observations(
+                await self.regenerate_entity_observations(
+                    bank_id, entity_id, entity_name, version=None, conn=conn, request_context=ctx
+                )
             except Exception as e:
                 logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")

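
Note: the comment in the hunk above says entities are processed in parallel because LLM calls dominate; the fan-out itself falls outside the hunk. A sketch of what such a fan-out typically looks like:

```python
# Assumed fan-out shape; only the process_entity coroutine and the
# "process in PARALLEL" comment are visible in the hunk above.
import asyncio

async def process_entity(entity_id: str, entity_name: str) -> None:
    ...  # per-entity regeneration, as in the hunk above

async def process_all(entities_to_process: dict[str, str]) -> None:
    # Launch every entity concurrently and wait for all of them.
    await asyncio.gather(
        *(process_entity(eid, name) for eid, name in entities_to_process.items())
    )
```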
@@ -3167,76 +3675,367 @@ Guidelines:
             task_dict: Dict with 'bank_id' and either:
                 - 'entity_ids' (list): Process multiple entities
                 - 'entity_id', 'entity_name': Process single entity (legacy)
+
+        Raises:
+            ValueError: If required fields are missing
+            Exception: Any exception from regenerate_entity_observations (propagates to execute_task for retry)
         """
-
-
+        bank_id = task_dict.get("bank_id")
+        # Use internal request context for background tasks
+        from hindsight_api.models import RequestContext

-
-        if "entity_ids" in task_dict:
-            entity_ids = task_dict.get("entity_ids", [])
-            min_facts = task_dict.get("min_facts", 5)
+        internal_context = RequestContext()

-
-
-
+        # New format: multiple entity_ids
+        if "entity_ids" in task_dict:
+            entity_ids = task_dict.get("entity_ids", [])
+            min_facts = task_dict.get("min_facts", 5)

-
-
-
-
-
-
-
+            if not bank_id or not entity_ids:
+                raise ValueError(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
+
+            # Process each entity
+            pool = await self._get_pool()
+            async with pool.acquire() as conn:
+                for entity_id in entity_ids:
+                    try:
+                        # Fetch entity name and check fact count
+                        import uuid as uuid_module
+
+                        entity_uuid = uuid_module.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+
+                        # First check if entity exists
+                        entity_exists = await conn.fetchrow(
+                            f"SELECT canonical_name FROM {fq_table('entities')} WHERE id = $1 AND bank_id = $2",
+                            entity_uuid,
+                            bank_id,
+                        )
+
+                        if not entity_exists:
+                            logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
+                            continue

-
+                        entity_name = entity_exists["canonical_name"]

-
-
-
+                        # Count facts linked to this entity
+                        fact_count = (
+                            await conn.fetchval(
+                                f"SELECT COUNT(*) FROM {fq_table('unit_entities')} WHERE entity_id = $1",
                                 entity_uuid,
-                                bank_id,
                             )
+                            or 0
+                        )

-
-
-
+                        # Only regenerate if entity has enough facts
+                        if fact_count >= min_facts:
+                            await self.regenerate_entity_observations(
+                                bank_id, entity_id, entity_name, version=None, request_context=internal_context
+                            )
+                        else:
+                            logger.debug(
+                                f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
+                            )

-
+                    except Exception as e:
+                        # Log but continue processing other entities - individual entity failures
+                        # shouldn't fail the whole batch
+                        logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
+                        continue

-
-
-
-
-
-                            or 0
-                        )
+        # Legacy format: single entity
+        else:
+            entity_id = task_dict.get("entity_id")
+            entity_name = task_dict.get("entity_name")
+            version = task_dict.get("version")

-
-
-                            await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
-                        else:
-                            logger.debug(
-                                f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
-                            )
+            if not all([bank_id, entity_id, entity_name]):
+                raise ValueError(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")

-
-
-
+            # Type assertions after validation
+            assert isinstance(bank_id, str) and isinstance(entity_id, str) and isinstance(entity_name, str)
+            await self.regenerate_entity_observations(
+                bank_id, entity_id, entity_name, version=version, request_context=internal_context
+            )

-
-
-
-            entity_name = task_dict.get("entity_name")
-            version = task_dict.get("version")
+    # =========================================================================
+    # Statistics & Operations (for HTTP API layer)
+    # =========================================================================

-
-
-
+    async def get_bank_stats(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Get statistics about memory nodes and links for a bank."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()

-
+        async with acquire_with_retry(pool) as conn:
+            # Get node counts by fact_type
+            node_stats = await conn.fetch(
+                f"""
+                SELECT fact_type, COUNT(*) as count
+                FROM {fq_table("memory_units")}
+                WHERE bank_id = $1
+                GROUP BY fact_type
+                """,
+                bank_id,
+            )

-
-
-
+            # Get link counts by link_type
+            link_stats = await conn.fetch(
+                f"""
+                SELECT ml.link_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY ml.link_type
+                """,
+                bank_id,
+            )

-
+            # Get link counts by fact_type (from nodes)
+            link_fact_type_stats = await conn.fetch(
+                f"""
+                SELECT mu.fact_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY mu.fact_type
+                """,
+                bank_id,
+            )
+
+            # Get link counts by fact_type AND link_type
+            link_breakdown_stats = await conn.fetch(
+                f"""
+                SELECT mu.fact_type, ml.link_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY mu.fact_type, ml.link_type
+                """,
+                bank_id,
+            )
+
+            # Get pending and failed operations counts
+            ops_stats = await conn.fetch(
+                f"""
+                SELECT status, COUNT(*) as count
+                FROM {fq_table("async_operations")}
+                WHERE bank_id = $1
+                GROUP BY status
+                """,
+                bank_id,
+            )
+
+            return {
+                "bank_id": bank_id,
+                "node_counts": {row["fact_type"]: row["count"] for row in node_stats},
+                "link_counts": {row["link_type"]: row["count"] for row in link_stats},
+                "link_counts_by_fact_type": {row["fact_type"]: row["count"] for row in link_fact_type_stats},
+                "link_breakdown": [
+                    {"fact_type": row["fact_type"], "link_type": row["link_type"], "count": row["count"]}
+                    for row in link_breakdown_stats
+                ],
+                "operations": {row["status"]: row["count"] for row in ops_stats},
+            }
+
+    async def get_entity(
+        self,
+        bank_id: str,
+        entity_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Get entity details including metadata and observations."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            entity_row = await conn.fetchrow(
+                f"""
+                SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
+                FROM {fq_table("entities")}
+                WHERE bank_id = $1 AND id = $2
+                """,
+                bank_id,
+                uuid.UUID(entity_id),
+            )
+
+            if not entity_row:
+                return None
+
+            # Get observations for the entity
+            observations = await self.get_entity_observations(bank_id, entity_id, limit=20, request_context=request_context)
+
+            return {
+                "id": str(entity_row["id"]),
+                "canonical_name": entity_row["canonical_name"],
+                "mention_count": entity_row["mention_count"],
+                "first_seen": entity_row["first_seen"].isoformat() if entity_row["first_seen"] else None,
+                "last_seen": entity_row["last_seen"].isoformat() if entity_row["last_seen"] else None,
+                "metadata": entity_row["metadata"] or {},
+                "observations": observations,
+            }
+
+    async def list_operations(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
+        """List async operations for a bank."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            operations = await conn.fetch(
+                f"""
+                SELECT operation_id, bank_id, operation_type, created_at, status, error_message, result_metadata
+                FROM {fq_table("async_operations")}
+                WHERE bank_id = $1
+                ORDER BY created_at DESC
+                """,
+                bank_id,
+            )
+
+            def parse_metadata(metadata):
+                if metadata is None:
+                    return {}
+                if isinstance(metadata, str):
+                    import json
+
+                    return json.loads(metadata)
+                return metadata
+
+            return [
+                {
+                    "id": str(row["operation_id"]),
+                    "task_type": row["operation_type"],
+                    "items_count": parse_metadata(row["result_metadata"]).get("items_count", 0),
+                    "document_id": parse_metadata(row["result_metadata"]).get("document_id"),
+                    "created_at": row["created_at"].isoformat(),
+                    "status": row["status"],
+                    "error_message": row["error_message"],
+                }
+                for row in operations
+            ]
+
+    async def cancel_operation(
+        self,
+        bank_id: str,
+        operation_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Cancel a pending async operation."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        op_uuid = uuid.UUID(operation_id)
+
+        async with acquire_with_retry(pool) as conn:
+            # Check if operation exists and belongs to this memory bank
+            result = await conn.fetchrow(
+                f"SELECT bank_id FROM {fq_table('async_operations')} WHERE operation_id = $1 AND bank_id = $2",
+                op_uuid,
+                bank_id,
+            )
+
+            if not result:
+                raise ValueError(f"Operation {operation_id} not found for bank {bank_id}")
+
+            # Delete the operation
+            await conn.execute(f"DELETE FROM {fq_table('async_operations')} WHERE operation_id = $1", op_uuid)
+
+            return {
+                "success": True,
+                "message": f"Operation {operation_id} cancelled",
+                "operation_id": operation_id,
+                "bank_id": bank_id,
+            }
+
+    async def update_bank(
+        self,
+        bank_id: str,
+        *,
+        name: str | None = None,
+        background: str | None = None,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Update bank name and/or background."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            if name is not None:
+                await conn.execute(
+                    f"""
+                    UPDATE {fq_table("banks")}
+                    SET name = $2, updated_at = NOW()
+                    WHERE bank_id = $1
+                    """,
+                    bank_id,
+                    name,
+                )
+
+            if background is not None:
+                await conn.execute(
+                    f"""
+                    UPDATE {fq_table("banks")}
+                    SET background = $2, updated_at = NOW()
+                    WHERE bank_id = $1
+                    """,
+                    bank_id,
+                    background,
+                )
+
+        # Return updated profile
+        return await self.get_bank_profile(bank_id, request_context=request_context)
+
+    async def submit_async_retain(
+        self,
+        bank_id: str,
+        contents: list[dict[str, Any]],
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Submit a batch retain operation to run asynchronously."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        import json
+
+        operation_id = uuid.uuid4()
+
+        # Insert operation record into database
+        async with acquire_with_retry(pool) as conn:
+            await conn.execute(
+                f"""
+                INSERT INTO {fq_table("async_operations")} (operation_id, bank_id, operation_type, result_metadata)
+                VALUES ($1, $2, $3, $4)
+                """,
+                operation_id,
+                bank_id,
+                "retain",
+                json.dumps({"items_count": len(contents)}),
+            )
+
+        # Submit task to background queue
+        await self._task_backend.submit_task(
+            {
+                "type": "batch_retain",
+                "operation_id": str(operation_id),
+                "bank_id": bank_id,
+                "contents": contents,
+            }
+        )
+
+        logger.info(f"Retain task queued for bank_id={bank_id}, {len(contents)} items, operation_id={operation_id}")
+
+        return {
+            "operation_id": str(operation_id),
+            "items_count": len(contents),
+        }