hindsight-api 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +2 -0
- hindsight_api/alembic/env.py +24 -1
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +14 -4
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +54 -13
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +18 -7
- hindsight_api/api/http.py +234 -228
- hindsight_api/api/mcp.py +14 -3
- hindsight_api/engine/__init__.py +12 -1
- hindsight_api/engine/entity_resolver.py +38 -37
- hindsight_api/engine/interface.py +592 -0
- hindsight_api/engine/llm_wrapper.py +176 -6
- hindsight_api/engine/memory_engine.py +993 -217
- hindsight_api/engine/retain/bank_utils.py +13 -12
- hindsight_api/engine/retain/chunk_storage.py +3 -2
- hindsight_api/engine/retain/fact_storage.py +10 -7
- hindsight_api/engine/retain/link_utils.py +17 -16
- hindsight_api/engine/retain/observation_regeneration.py +17 -16
- hindsight_api/engine/retain/orchestrator.py +2 -3
- hindsight_api/engine/retain/types.py +25 -8
- hindsight_api/engine/search/graph_retrieval.py +6 -5
- hindsight_api/engine/search/mpfp_retrieval.py +8 -7
- hindsight_api/engine/search/retrieval.py +12 -11
- hindsight_api/engine/search/think_utils.py +1 -1
- hindsight_api/engine/search/tracer.py +1 -1
- hindsight_api/engine/task_backend.py +32 -0
- hindsight_api/extensions/__init__.py +66 -0
- hindsight_api/extensions/base.py +81 -0
- hindsight_api/extensions/builtin/__init__.py +18 -0
- hindsight_api/extensions/builtin/tenant.py +33 -0
- hindsight_api/extensions/context.py +110 -0
- hindsight_api/extensions/http.py +89 -0
- hindsight_api/extensions/loader.py +125 -0
- hindsight_api/extensions/operation_validator.py +325 -0
- hindsight_api/extensions/tenant.py +63 -0
- hindsight_api/main.py +1 -1
- hindsight_api/mcp_local.py +7 -1
- hindsight_api/migrations.py +54 -10
- hindsight_api/models.py +15 -0
- hindsight_api/pg0.py +1 -1
- {hindsight_api-0.1.11.dist-info → hindsight_api-0.1.12.dist-info}/METADATA +1 -1
- hindsight_api-0.1.12.dist-info/RECORD +74 -0
- hindsight_api-0.1.11.dist-info/RECORD +0 -64
- {hindsight_api-0.1.11.dist-info → hindsight_api-0.1.12.dist-info}/WHEEL +0 -0
- {hindsight_api-0.1.11.dist-info → hindsight_api-0.1.12.dist-info}/entry_points.txt +0 -0
hindsight_api/engine/memory_engine.py:

```diff
@@ -10,39 +10,122 @@ This implements a sophisticated memory architecture that combines:
 """
 
 
 import asyncio
+import contextvars
 import logging
 import time
 import uuid
 from datetime import UTC, datetime, timedelta
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any
 
-
-
-from pydantic import BaseModel, Field
+# Context variable for current schema (async-safe, per-task isolation)
+_current_schema: contextvars.ContextVar[str] = contextvars.ContextVar("current_schema", default="public")
 
-from .cross_encoder import CrossEncoderModel
-from .embeddings import Embeddings, create_embeddings_from_env
 
-
+def get_current_schema() -> str:
+    """Get the current schema from context (default: 'public')."""
+    return _current_schema.get()
+
+
+def fq_table(table_name: str) -> str:
+    """
+    Get fully-qualified table name with current schema.
+
+    Example:
+        fq_table("memory_units") -> "public.memory_units"
+        fq_table("memory_units") -> "tenant_xyz.memory_units" (if schema is set)
+    """
+    return f"{get_current_schema()}.{table_name}"
+
+
+# Tables that must be schema-qualified (for runtime validation)
+_PROTECTED_TABLES = frozenset(
+    [
+        "memory_units",
+        "memory_links",
+        "unit_entities",
+        "entities",
+        "entity_cooccurrences",
+        "banks",
+        "documents",
+        "chunks",
+        "async_operations",
+    ]
+)
+
+# Enable runtime SQL validation (can be disabled in production for performance)
+_VALIDATE_SQL_SCHEMAS = True
+
+
+class UnqualifiedTableError(Exception):
+    """Raised when SQL contains unqualified table references."""
+
     pass
 
 
-
-"""
+def validate_sql_schema(sql: str) -> None:
+    """
+    Validate that SQL doesn't contain unqualified table references.
+
+    This is a runtime safety check to prevent cross-tenant data access.
+    Raises UnqualifiedTableError if any protected table is referenced
+    without a schema prefix.
+
+    Args:
+        sql: The SQL query to validate
 
-
-
-    context: Context about the content (optional)
-    event_date: When the content occurred (optional, defaults to now)
-    metadata: Custom key-value metadata (optional)
-    document_id: Document ID for this content item (optional)
+    Raises:
+        UnqualifiedTableError: If unqualified table reference found
     """
+    if not _VALIDATE_SQL_SCHEMAS:
+        return
+
+    import re
+
+    sql_upper = sql.upper()
+
+    for table in _PROTECTED_TABLES:
+        table_upper = table.upper()
+
+        # Pattern: SQL keyword followed by unqualified table name
+        # Matches: FROM memory_units, JOIN memory_units, INTO memory_units, UPDATE memory_units
+        patterns = [
+            rf"FROM\s+{table_upper}(?:\s|$|,|\)|;)",
+            rf"JOIN\s+{table_upper}(?:\s|$|,|\)|;)",
+            rf"INTO\s+{table_upper}(?:\s|$|\()",
+            rf"UPDATE\s+{table_upper}(?:\s|$)",
+            rf"DELETE\s+FROM\s+{table_upper}(?:\s|$|;)",
+        ]
+
+        for pattern in patterns:
+            match = re.search(pattern, sql_upper)
+            if match:
+                # Check if it's actually qualified (preceded by schema.)
+                # Look backwards from match to see if there's a dot
+                start = match.start()
+                # Find the table name position in the match
+                table_pos = sql_upper.find(table_upper, start)
+                if table_pos > 0:
+                    # Check character before table name (skip whitespace)
+                    prefix = sql[:table_pos].rstrip()
+                    if not prefix.endswith("."):
+                        raise UnqualifiedTableError(
+                            f"Unqualified table reference '{table}' in SQL. "
+                            f"Use fq_table('{table}') for schema safety. "
+                            f"SQL snippet: ...{sql[max(0, start - 10) : start + 50]}..."
+                        )
 
-
-
-
-
-
+
+import asyncpg
+import numpy as np
+from pydantic import BaseModel, Field
+
+from .cross_encoder import CrossEncoderModel
+from .embeddings import Embeddings, create_embeddings_from_env
+from .interface import MemoryEngineInterface
+
+if TYPE_CHECKING:
+    from hindsight_api.extensions import OperationValidatorExtension, TenantExtension
+    from hindsight_api.models import RequestContext
 
 
 from enum import Enum
```
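The hunk above introduces the contextvar-based schema scoping that the rest of this diff applies to every SQL string. A minimal standalone sketch of the mechanism (a reimplementation for illustration, not an import of the package):

```python
import contextvars

# Per-task schema, mirroring the diff: defaults to "public",
# overridden after tenant authentication.
_current_schema: contextvars.ContextVar[str] = contextvars.ContextVar(
    "current_schema", default="public"
)


def fq_table(table_name: str) -> str:
    # Qualify a table name with whatever schema the current task is using.
    return f"{_current_schema.get()}.{table_name}"


assert fq_table("memory_units") == "public.memory_units"
_current_schema.set("tenant_xyz")
assert fq_table("memory_units") == "tenant_xyz.memory_units"
```

Because each asyncio task runs in its own copy of the context, a `set()` made while handling one tenant's request is not visible to concurrently running tasks; that is what makes the per-request `_authenticate_tenant` pattern later in this diff safe.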
```diff
@@ -54,6 +137,7 @@ from .query_analyzer import QueryAnalyzer
 from .response_models import VALID_RECALL_FACT_TYPES, EntityObservation, EntityState, MemoryFact, ReflectResult
 from .response_models import RecallResult as RecallResultModel
 from .retain import bank_utils, embedding_utils
+from .retain.types import RetainContentDict
 from .search import observation_utils, think_utils
 from .search.reranking import CrossEncoderReranker
 from .task_backend import AsyncIOQueueBackend, TaskBackend
@@ -91,7 +175,7 @@ def _get_tiktoken_encoding():
     return _TIKTOKEN_ENCODING
 
 
-class MemoryEngine:
+class MemoryEngine(MemoryEngineInterface):
     """
     Advanced memory system using temporal and semantic linking with PostgreSQL.
 
@@ -116,6 +200,8 @@ class MemoryEngine:
         pool_max_size: int = 100,
         task_backend: TaskBackend | None = None,
         run_migrations: bool = True,
+        operation_validator: "OperationValidatorExtension | None" = None,
+        tenant_extension: "TenantExtension | None" = None,
     ):
         """
         Initialize the temporal + semantic memory system.
@@ -137,6 +223,10 @@ class MemoryEngine:
             pool_max_size: Maximum number of connections in the pool (default: 100)
             task_backend: Custom task backend. If not provided, uses AsyncIOQueueBackend.
             run_migrations: Whether to run database migrations during initialize(). Default: True
+            operation_validator: Optional extension to validate operations before execution.
+                If provided, retain/recall/reflect operations will be validated.
+            tenant_extension: Optional extension for multi-tenancy and API key authentication.
+                If provided, operations require a RequestContext for authentication.
         """
         # Load config from environment for any missing parameters
         from ..config import get_config
@@ -147,6 +237,9 @@ class MemoryEngine:
         db_url = db_url or config.database_url
         memory_llm_provider = memory_llm_provider or config.llm_provider
         memory_llm_api_key = memory_llm_api_key or config.llm_api_key
+        # Ollama doesn't require an API key
+        if not memory_llm_api_key and memory_llm_provider != "ollama":
+            raise ValueError("LLM API key is required. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
         memory_llm_model = memory_llm_model or config.llm_model
         memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
         # Track pg0 instance (if used)
```
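With the added check, only non-Ollama providers fail fast without an API key. A sketch of pointing the engine at a local Ollama using the constructor parameters visible in this diff (the model name and base URL are illustrative assumptions, not values from the package):

```python
# Sketch: provider "ollama" skips the API-key requirement added above.
# The memory_llm_* parameter names appear in this diff; the model and URL
# are placeholder assumptions for a typical local Ollama setup.
engine = MemoryEngine(
    memory_llm_provider="ollama",
    memory_llm_model="llama3.1",
    memory_llm_base_url="http://localhost:11434/v1",
)
```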
```diff
@@ -243,6 +336,60 @@ class MemoryEngine:
         # initialize encoding eagerly to avoid delaying the first time
         _get_tiktoken_encoding()
 
+        # Store operation validator extension (optional)
+        self._operation_validator = operation_validator
+
+        # Store tenant extension (optional)
+        self._tenant_extension = tenant_extension
+
+    async def _validate_operation(self, validation_coro) -> None:
+        """
+        Run validation if an operation validator is configured.
+
+        Args:
+            validation_coro: Coroutine that returns a ValidationResult
+
+        Raises:
+            OperationValidationError: If validation fails
+        """
+        if self._operation_validator is None:
+            return
+
+        from hindsight_api.extensions import OperationValidationError
+
+        result = await validation_coro
+        if not result.allowed:
+            raise OperationValidationError(result.reason or "Operation not allowed")
+
+    async def _authenticate_tenant(self, request_context: "RequestContext | None") -> str:
+        """
+        Authenticate tenant and set schema in context variable.
+
+        The schema is stored in a contextvar for async-safe, per-task isolation.
+        Use fq_table(table_name) to get fully-qualified table names.
+
+        Args:
+            request_context: The request context with API key. Required if tenant_extension is configured.
+
+        Returns:
+            Schema name that was set in the context.
+
+        Raises:
+            AuthenticationError: If authentication fails or request_context is missing when required.
+        """
+        if self._tenant_extension is None:
+            _current_schema.set("public")
+            return "public"
+
+        from hindsight_api.extensions import AuthenticationError
+
+        if request_context is None:
+            raise AuthenticationError("RequestContext is required when tenant extension is configured")
+
+        tenant_context = await self._tenant_extension.authenticate(request_context)
+        _current_schema.set(tenant_context.schema_name)
+        return tenant_context.schema_name
+
     async def _handle_access_count_update(self, task_dict: dict[str, Any]):
         """
         Handler for access count update tasks.
```
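From `_authenticate_tenant` above, a tenant extension needs an async `authenticate(request_context)` returning an object that exposes `schema_name`. The diff imports the real base class (`TenantExtension` in `hindsight_api.extensions`); the sketch below is a hypothetical, self-contained stand-in rather than a subclass of that API:

```python
from dataclasses import dataclass


@dataclass
class TenantContext:
    schema_name: str


class StaticTenantExtension:
    """Hypothetical extension: maps API keys to tenant schemas."""

    def __init__(self, keys_to_schemas: dict[str, str]):
        self._keys = keys_to_schemas

    async def authenticate(self, request_context) -> TenantContext:
        # The attribute name "api_key" is an assumption about RequestContext.
        schema = self._keys.get(getattr(request_context, "api_key", ""))
        if schema is None:
            raise PermissionError("unknown API key")  # stand-in for AuthenticationError
        return TenantContext(schema_name=schema)
```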
```diff
@@ -260,7 +407,8 @@ class MemoryEngine:
             uuid_list = [uuid.UUID(nid) for nid in node_ids]
             async with acquire_with_retry(pool) as conn:
                 await conn.execute(
-                    "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
+                    f"UPDATE {fq_table('memory_units')} SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
+                    uuid_list,
                 )
         except Exception as e:
             logger.error(f"Access count handler: Error updating access counts: {e}")
@@ -274,13 +422,19 @@ class MemoryEngine:
         """
         try:
             bank_id = task_dict.get("bank_id")
+            if not bank_id:
+                raise ValueError("bank_id is required for batch retain task")
             contents = task_dict.get("contents", [])
 
             logger.info(
                 f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
             )
 
-
+            # Use internal request context for background tasks
+            from hindsight_api.models import RequestContext
+
+            internal_context = RequestContext()
+            await self.retain_batch_async(bank_id=bank_id, contents=contents, request_context=internal_context)
 
             logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
         except Exception as e:
@@ -311,7 +465,8 @@ class MemoryEngine:
             pool = await self._get_pool()
             async with acquire_with_retry(pool) as conn:
                 result = await conn.fetchrow(
-                    "SELECT operation_id FROM async_operations WHERE operation_id = $1",
+                    f"SELECT operation_id FROM {fq_table('async_operations')} WHERE operation_id = $1",
+                    uuid.UUID(operation_id),
                 )
                 if not result:
                     # Operation was cancelled, skip processing
@@ -369,7 +524,9 @@ class MemoryEngine:
         try:
             pool = await self._get_pool()
             async with acquire_with_retry(pool) as conn:
-                await conn.execute(
+                await conn.execute(
+                    f"DELETE FROM {fq_table('async_operations')} WHERE operation_id = $1", uuid.UUID(operation_id)
+                )
         except Exception as e:
             logger.error(f"Failed to delete async operation record {operation_id}: {e}")
 
@@ -383,8 +540,8 @@ class MemoryEngine:
 
         async with acquire_with_retry(pool) as conn:
             await conn.execute(
-                """
-                UPDATE async_operations
+                f"""
+                UPDATE {fq_table("async_operations")}
                 SET status = 'failed', error_message = $2
                 WHERE operation_id = $1
                 """,
@@ -413,7 +570,7 @@ class MemoryEngine:
             kwargs = {"name": self._pg0_instance_name}
             if self._pg0_port is not None:
                 kwargs["port"] = self._pg0_port
-            pg0 = EmbeddedPostgres(**kwargs)
+            pg0 = EmbeddedPostgres(**kwargs)  # type: ignore[invalid-argument-type] - dict kwargs
             # Check if pg0 is already running before we start it
             was_already_running = await pg0.is_running()
             self.db_url = await pg0.ensure_running()
@@ -460,6 +617,8 @@ class MemoryEngine:
         if self._run_migrations:
             from ..migrations import run_migrations
 
+            if not self.db_url:
+                raise ValueError("Database URL is required for migrations")
             logger.info("Running database migrations...")
             run_migrations(self.db_url)
 
@@ -628,9 +787,9 @@ class MemoryEngine:
 
             fetch_start = time_mod.time()
             existing_facts = await conn.fetch(
-                """
+                f"""
                 SELECT id, text, embedding
-                FROM memory_units
+                FROM {fq_table("memory_units")}
                 WHERE bank_id = $1
                   AND event_date BETWEEN $2 AND $3
                 """,
@@ -692,6 +851,7 @@ class MemoryEngine:
         content: str,
         context: str = "",
         event_date: datetime | None = None,
+        request_context: "RequestContext | None" = None,
     ) -> list[str]:
         """
         Store content as memory units (synchronous wrapper).
@@ -704,12 +864,16 @@ class MemoryEngine:
             content: Text content to store
             context: Context about when/why this memory was formed
             event_date: When the event occurred (defaults to now)
+            request_context: Request context for authentication (optional, uses internal context if not provided)
 
         Returns:
             List of created unit IDs
         """
         # Run async version synchronously
-
+        from hindsight_api.models import RequestContext as RC
+
+        ctx = request_context if request_context is not None else RC()
+        return asyncio.run(self.retain_async(bank_id, content, context, event_date, request_context=ctx))
 
     async def retain_async(
         self,
```
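For single-tenant deployments the synchronous wrapper can still be called without a `RequestContext`; it now builds an internal one. A sketch (assumes `engine` is an already-initialized `MemoryEngine` and the call is made from synchronous code, since the wrapper uses `asyncio.run`):

```python
# Sketch: request_context is optional on the sync wrapper; omitting it
# makes the wrapper construct an internal RequestContext().
unit_ids = engine.retain(
    "bank-1",                          # placeholder bank id
    "Alice moved to Berlin in 2023.",  # placeholder content
    context="user chat",
)
```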
```diff
@@ -720,6 +884,8 @@ class MemoryEngine:
         document_id: str | None = None,
         fact_type_override: str | None = None,
         confidence_score: float | None = None,
+        *,
+        request_context: "RequestContext",
     ) -> list[str]:
         """
         Store content as memory units with temporal and semantic links (ASYNC version).
@@ -734,12 +900,15 @@ class MemoryEngine:
             document_id: Optional document ID for tracking (always upserts if document already exists)
             fact_type_override: Override fact type ('world', 'experience', 'opinion')
             confidence_score: Confidence score for opinions (0.0 to 1.0)
+            request_context: Request context for authentication.
 
         Returns:
             List of created unit IDs
         """
         # Build content dict
-        content_dict: RetainContentDict = {"content": content, "context": context
+        content_dict: RetainContentDict = {"content": content, "context": context}  # type: ignore[typeddict-item] - building incrementally
+        if event_date:
+            content_dict["event_date"] = event_date
         if document_id:
             content_dict["document_id"] = document_id
 
@@ -747,6 +916,7 @@ class MemoryEngine:
         result = await self.retain_batch_async(
             bank_id=bank_id,
             contents=[content_dict],
+            request_context=request_context,
             fact_type_override=fact_type_override,
             confidence_score=confidence_score,
         )
@@ -758,6 +928,8 @@ class MemoryEngine:
         self,
         bank_id: str,
         contents: list[RetainContentDict],
+        *,
+        request_context: "RequestContext",
         document_id: str | None = None,
         fact_type_override: str | None = None,
         confidence_score: float | None = None,
@@ -813,6 +985,24 @@ class MemoryEngine:
         if not contents:
             return []
 
+        # Authenticate tenant and set schema in context (for fq_table())
+        await self._authenticate_tenant(request_context)
+
+        # Validate operation if validator is configured
+        contents_copy = [dict(c) for c in contents]  # Convert TypedDict to regular dict for extension
+        if self._operation_validator:
+            from hindsight_api.extensions import RetainContext
+
+            ctx = RetainContext(
+                bank_id=bank_id,
+                contents=contents_copy,
+                request_context=request_context,
+                document_id=document_id,
+                fact_type_override=fact_type_override,
+                confidence_score=confidence_score,
+            )
+            await self._validate_operation(self._operation_validator.validate_retain(ctx))
+
         # Apply batch-level document_id to contents that don't have their own (backwards compatibility)
         if document_id:
             for item in contents:
```
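An operation validator only has to return something with `allowed` and `reason` from its `validate_*` coroutines, per `_validate_operation` earlier in this diff. A hypothetical validator that caps retain batch sizes (names invented for illustration; the real base class is `OperationValidatorExtension` per the imports above):

```python
from dataclasses import dataclass


@dataclass
class ValidationResult:
    allowed: bool
    reason: str | None = None


class MaxBatchValidator:
    """Hypothetical validator: rejects oversized retain batches."""

    def __init__(self, max_items: int = 100):
        self.max_items = max_items

    async def validate_retain(self, ctx) -> ValidationResult:
        if len(ctx.contents) > self.max_items:
            return ValidationResult(False, f"batch too large (> {self.max_items} items)")
        return ValidationResult(True)

    async def on_retain_complete(self, result_ctx) -> None:
        # Post-operation hook; the engine logs and swallows errors raised here.
        pass
```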
```diff
@@ -876,17 +1066,39 @@ class MemoryEngine:
             logger.info(
                 f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s"
             )
-
+            result = all_results
+        else:
+            # Small batch - use internal method directly
+            result = await self._retain_batch_async_internal(
+                bank_id=bank_id,
+                contents=contents,
+                document_id=document_id,
+                is_first_batch=True,
+                fact_type_override=fact_type_override,
+                confidence_score=confidence_score,
+            )
 
-        #
-
-
-
-
-
-
-
+        # Call post-operation hook if validator is configured
+        if self._operation_validator:
+            from hindsight_api.extensions import RetainResult
+
+            result_ctx = RetainResult(
+                bank_id=bank_id,
+                contents=contents_copy,
+                request_context=request_context,
+                document_id=document_id,
+                fact_type_override=fact_type_override,
+                confidence_score=confidence_score,
+                unit_ids=result,
+                success=True,
+                error=None,
+            )
+            try:
+                await self._operation_validator.on_retain_complete(result_ctx)
+            except Exception as e:
+                logger.warning(f"Post-retain hook error (non-fatal): {e}")
+
+        return result
 
     async def _retain_batch_async_internal(
         self,
@@ -961,22 +1173,36 @@ class MemoryEngine:
         Returns:
             Tuple of (results, trace)
         """
-        # Run async version synchronously
-
+        # Run async version synchronously - deprecated sync method, passing None for request_context
+        from hindsight_api.models import RequestContext
+
+        return asyncio.run(
+            self.recall_async(
+                bank_id,
+                query,
+                budget=budget,
+                max_tokens=max_tokens,
+                enable_trace=enable_trace,
+                fact_type=[fact_type],
+                request_context=RequestContext(),
+            )
+        )
 
     async def recall_async(
         self,
         bank_id: str,
         query: str,
-
-        budget: Budget =
+        *,
+        budget: Budget | None = None,
         max_tokens: int = 4096,
         enable_trace: bool = False,
+        fact_type: list[str] | None = None,
         question_date: datetime | None = None,
         include_entities: bool = False,
-        max_entity_tokens: int =
+        max_entity_tokens: int = 500,
         include_chunks: bool = False,
         max_chunk_tokens: int = 8192,
+        request_context: "RequestContext",
     ) -> RecallResultModel:
         """
         Recall memories using N*4-way parallel retrieval (N fact types × 4 retrieval methods).
```
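With the new keyword-only signature, a recall call looks roughly like this (a sketch inside an async function, assuming an initialized engine; `fact_type` defaults to all valid types when omitted):

```python
from hindsight_api.models import RequestContext

# Sketch: budget defaults to Budget.MID internally when left as None.
result = await engine.recall_async(
    "bank-1",
    "Where does Alice live?",
    include_entities=True,
    request_context=RequestContext(),
)
```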
@@ -1010,6 +1236,13 @@ class MemoryEngine:
|
|
|
1010
1236
|
- entities: Optional dict of entity states (if include_entities=True)
|
|
1011
1237
|
- chunks: Optional dict of chunks (if include_chunks=True)
|
|
1012
1238
|
"""
|
|
1239
|
+
# Authenticate tenant and set schema in context (for fq_table())
|
|
1240
|
+
await self._authenticate_tenant(request_context)
|
|
1241
|
+
|
|
1242
|
+
# Default to all fact types if not specified
|
|
1243
|
+
if fact_type is None:
|
|
1244
|
+
fact_type = list(VALID_RECALL_FACT_TYPES)
|
|
1245
|
+
|
|
1013
1246
|
# Validate fact types early
|
|
1014
1247
|
invalid_types = set(fact_type) - VALID_RECALL_FACT_TYPES
|
|
1015
1248
|
if invalid_types:
|
|
@@ -1018,17 +1251,40 @@ class MemoryEngine:
|
|
|
1018
1251
|
f"Must be one of: {', '.join(sorted(VALID_RECALL_FACT_TYPES))}"
|
|
1019
1252
|
)
|
|
1020
1253
|
|
|
1021
|
-
#
|
|
1254
|
+
# Validate operation if validator is configured
|
|
1255
|
+
if self._operation_validator:
|
|
1256
|
+
from hindsight_api.extensions import RecallContext
|
|
1257
|
+
|
|
1258
|
+
ctx = RecallContext(
|
|
1259
|
+
bank_id=bank_id,
|
|
1260
|
+
query=query,
|
|
1261
|
+
request_context=request_context,
|
|
1262
|
+
budget=budget,
|
|
1263
|
+
max_tokens=max_tokens,
|
|
1264
|
+
enable_trace=enable_trace,
|
|
1265
|
+
fact_types=list(fact_type),
|
|
1266
|
+
question_date=question_date,
|
|
1267
|
+
include_entities=include_entities,
|
|
1268
|
+
max_entity_tokens=max_entity_tokens,
|
|
1269
|
+
include_chunks=include_chunks,
|
|
1270
|
+
max_chunk_tokens=max_chunk_tokens,
|
|
1271
|
+
)
|
|
1272
|
+
await self._validate_operation(self._operation_validator.validate_recall(ctx))
|
|
1273
|
+
|
|
1274
|
+
# Map budget enum to thinking_budget number (default to MID if None)
|
|
1022
1275
|
budget_mapping = {Budget.LOW: 100, Budget.MID: 300, Budget.HIGH: 1000}
|
|
1023
|
-
|
|
1276
|
+
effective_budget = budget if budget is not None else Budget.MID
|
|
1277
|
+
thinking_budget = budget_mapping[effective_budget]
|
|
1024
1278
|
|
|
1025
1279
|
# Backpressure: limit concurrent recalls to prevent overwhelming the database
|
|
1280
|
+
result = None
|
|
1281
|
+
error_msg = None
|
|
1026
1282
|
async with self._search_semaphore:
|
|
1027
1283
|
# Retry loop for connection errors
|
|
1028
1284
|
max_retries = 3
|
|
1029
1285
|
for attempt in range(max_retries + 1):
|
|
1030
1286
|
try:
|
|
1031
|
-
|
|
1287
|
+
result = await self._search_with_retries(
|
|
1032
1288
|
bank_id,
|
|
1033
1289
|
query,
|
|
1034
1290
|
fact_type,
|
|
@@ -1040,7 +1296,9 @@ class MemoryEngine:
|
|
|
1040
1296
|
max_entity_tokens,
|
|
1041
1297
|
include_chunks,
|
|
1042
1298
|
max_chunk_tokens,
|
|
1299
|
+
request_context,
|
|
1043
1300
|
)
|
|
1301
|
+
break # Success - exit retry loop
|
|
1044
1302
|
except Exception as e:
|
|
1045
1303
|
# Check if it's a connection error
|
|
1046
1304
|
is_connection_error = (
|
|
@@ -1058,9 +1316,89 @@ class MemoryEngine:
|
|
|
1058
1316
|
)
|
|
1059
1317
|
await asyncio.sleep(wait_time)
|
|
1060
1318
|
else:
|
|
1061
|
-
# Not a connection error or out of retries - raise
|
|
1319
|
+
# Not a connection error or out of retries - call post-hook and raise
|
|
1320
|
+
error_msg = str(e)
|
|
1321
|
+
if self._operation_validator:
|
|
1322
|
+
from hindsight_api.extensions.operation_validator import RecallResult
|
|
1323
|
+
|
|
1324
|
+
result_ctx = RecallResult(
|
|
1325
|
+
bank_id=bank_id,
|
|
1326
|
+
query=query,
|
|
1327
|
+
request_context=request_context,
|
|
1328
|
+
budget=budget,
|
|
1329
|
+
max_tokens=max_tokens,
|
|
1330
|
+
enable_trace=enable_trace,
|
|
1331
|
+
fact_types=list(fact_type),
|
|
1332
|
+
question_date=question_date,
|
|
1333
|
+
include_entities=include_entities,
|
|
1334
|
+
max_entity_tokens=max_entity_tokens,
|
|
1335
|
+
include_chunks=include_chunks,
|
|
1336
|
+
max_chunk_tokens=max_chunk_tokens,
|
|
1337
|
+
result=None,
|
|
1338
|
+
success=False,
|
|
1339
|
+
error=error_msg,
|
|
1340
|
+
)
|
|
1341
|
+
try:
|
|
1342
|
+
await self._operation_validator.on_recall_complete(result_ctx)
|
|
1343
|
+
except Exception as hook_err:
|
|
1344
|
+
logger.warning(f"Post-recall hook error (non-fatal): {hook_err}")
|
|
1062
1345
|
raise
|
|
1063
|
-
|
|
1346
|
+
else:
|
|
1347
|
+
# Exceeded max retries
|
|
1348
|
+
error_msg = "Exceeded maximum retries for search due to connection errors."
|
|
1349
|
+
if self._operation_validator:
|
|
1350
|
+
from hindsight_api.extensions.operation_validator import RecallResult
|
|
1351
|
+
|
|
1352
|
+
result_ctx = RecallResult(
|
|
1353
|
+
bank_id=bank_id,
|
|
1354
|
+
query=query,
|
|
1355
|
+
request_context=request_context,
|
|
1356
|
+
budget=budget,
|
|
1357
|
+
max_tokens=max_tokens,
|
|
1358
|
+
enable_trace=enable_trace,
|
|
1359
|
+
fact_types=list(fact_type),
|
|
1360
|
+
question_date=question_date,
|
|
1361
|
+
include_entities=include_entities,
|
|
1362
|
+
max_entity_tokens=max_entity_tokens,
|
|
1363
|
+
include_chunks=include_chunks,
|
|
1364
|
+
max_chunk_tokens=max_chunk_tokens,
|
|
1365
|
+
result=None,
|
|
1366
|
+
success=False,
|
|
1367
|
+
error=error_msg,
|
|
1368
|
+
)
|
|
1369
|
+
try:
|
|
1370
|
+
await self._operation_validator.on_recall_complete(result_ctx)
|
|
1371
|
+
except Exception as hook_err:
|
|
1372
|
+
logger.warning(f"Post-recall hook error (non-fatal): {hook_err}")
|
|
1373
|
+
raise Exception(error_msg)
|
|
1374
|
+
|
|
1375
|
+
# Call post-operation hook for success
|
|
1376
|
+
if self._operation_validator and result is not None:
|
|
1377
|
+
from hindsight_api.extensions.operation_validator import RecallResult
|
|
1378
|
+
|
|
1379
|
+
result_ctx = RecallResult(
|
|
1380
|
+
bank_id=bank_id,
|
|
1381
|
+
query=query,
|
|
1382
|
+
request_context=request_context,
|
|
1383
|
+
budget=budget,
|
|
1384
|
+
max_tokens=max_tokens,
|
|
1385
|
+
enable_trace=enable_trace,
|
|
1386
|
+
fact_types=list(fact_type),
|
|
1387
|
+
question_date=question_date,
|
|
1388
|
+
include_entities=include_entities,
|
|
1389
|
+
max_entity_tokens=max_entity_tokens,
|
|
1390
|
+
include_chunks=include_chunks,
|
|
1391
|
+
max_chunk_tokens=max_chunk_tokens,
|
|
1392
|
+
result=result,
|
|
1393
|
+
success=True,
|
|
1394
|
+
error=None,
|
|
1395
|
+
)
|
|
1396
|
+
try:
|
|
1397
|
+
await self._operation_validator.on_recall_complete(result_ctx)
|
|
1398
|
+
except Exception as e:
|
|
1399
|
+
logger.warning(f"Post-recall hook error (non-fatal): {e}")
|
|
1400
|
+
|
|
1401
|
+
return result
|
|
1064
1402
|
|
|
1065
1403
|
async def _search_with_retries(
|
|
1066
1404
|
self,
|
|
@@ -1075,6 +1413,7 @@ class MemoryEngine:
|
|
|
1075
1413
|
max_entity_tokens: int = 500,
|
|
1076
1414
|
include_chunks: bool = False,
|
|
1077
1415
|
max_chunk_tokens: int = 8192,
|
|
1416
|
+
request_context: "RequestContext" = None,
|
|
1078
1417
|
) -> RecallResultModel:
|
|
1079
1418
|
"""
|
|
1080
1419
|
Search implementation with modular retrieval and reranking.
|
|
@@ -1465,10 +1804,10 @@ class MemoryEngine:
|
|
|
1465
1804
|
if unit_ids:
|
|
1466
1805
|
async with acquire_with_retry(pool) as entity_conn:
|
|
1467
1806
|
entity_rows = await entity_conn.fetch(
|
|
1468
|
-
"""
|
|
1807
|
+
f"""
|
|
1469
1808
|
SELECT ue.unit_id, e.id as entity_id, e.canonical_name
|
|
1470
|
-
FROM unit_entities ue
|
|
1471
|
-
JOIN entities e ON ue.entity_id = e.id
|
|
1809
|
+
FROM {fq_table("unit_entities")} ue
|
|
1810
|
+
JOIN {fq_table("entities")} e ON ue.entity_id = e.id
|
|
1472
1811
|
WHERE ue.unit_id = ANY($1::uuid[])
|
|
1473
1812
|
""",
|
|
1474
1813
|
unit_ids,
|
|
@@ -1534,7 +1873,9 @@ class MemoryEngine:
|
|
|
1534
1873
|
if total_entity_tokens >= max_entity_tokens:
|
|
1535
1874
|
break
|
|
1536
1875
|
|
|
1537
|
-
observations = await self.get_entity_observations(
|
|
1876
|
+
observations = await self.get_entity_observations(
|
|
1877
|
+
bank_id, entity_id, limit=5, request_context=request_context
|
|
1878
|
+
)
|
|
1538
1879
|
|
|
1539
1880
|
# Calculate tokens for this entity's observations
|
|
1540
1881
|
entity_tokens = 0
|
|
@@ -1572,9 +1913,9 @@ class MemoryEngine:
|
|
|
1572
1913
|
# Fetch chunk data from database using chunk_ids (no ORDER BY to preserve input order)
|
|
1573
1914
|
async with acquire_with_retry(pool) as conn:
|
|
1574
1915
|
chunks_rows = await conn.fetch(
|
|
1575
|
-
"""
|
|
1916
|
+
f"""
|
|
1576
1917
|
SELECT chunk_id, chunk_text, chunk_index
|
|
1577
|
-
FROM chunks
|
|
1918
|
+
FROM {fq_table("chunks")}
|
|
1578
1919
|
WHERE chunk_id = ANY($1::text[])
|
|
1579
1920
|
""",
|
|
1580
1921
|
chunk_ids_ordered,
|
|
@@ -1671,25 +2012,33 @@ class MemoryEngine:
|
|
|
1671
2012
|
|
|
1672
2013
|
return filtered_results, total_tokens
|
|
1673
2014
|
|
|
1674
|
-
async def get_document(
|
|
2015
|
+
async def get_document(
|
|
2016
|
+
self,
|
|
2017
|
+
document_id: str,
|
|
2018
|
+
bank_id: str,
|
|
2019
|
+
*,
|
|
2020
|
+
request_context: "RequestContext",
|
|
2021
|
+
) -> dict[str, Any] | None:
|
|
1675
2022
|
"""
|
|
1676
2023
|
Retrieve document metadata and statistics.
|
|
1677
2024
|
|
|
1678
2025
|
Args:
|
|
1679
2026
|
document_id: Document ID to retrieve
|
|
1680
2027
|
bank_id: bank ID that owns the document
|
|
2028
|
+
request_context: Request context for authentication.
|
|
1681
2029
|
|
|
1682
2030
|
Returns:
|
|
1683
2031
|
Dictionary with document info or None if not found
|
|
1684
2032
|
"""
|
|
2033
|
+
await self._authenticate_tenant(request_context)
|
|
1685
2034
|
pool = await self._get_pool()
|
|
1686
2035
|
async with acquire_with_retry(pool) as conn:
|
|
1687
2036
|
doc = await conn.fetchrow(
|
|
1688
|
-
"""
|
|
2037
|
+
f"""
|
|
1689
2038
|
SELECT d.id, d.bank_id, d.original_text, d.content_hash,
|
|
1690
2039
|
d.created_at, d.updated_at, COUNT(mu.id) as unit_count
|
|
1691
|
-
FROM documents d
|
|
1692
|
-
LEFT JOIN memory_units mu ON mu.document_id = d.id
|
|
2040
|
+
FROM {fq_table("documents")} d
|
|
2041
|
+
LEFT JOIN {fq_table("memory_units")} mu ON mu.document_id = d.id
|
|
1693
2042
|
WHERE d.id = $1 AND d.bank_id = $2
|
|
1694
2043
|
GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
|
|
1695
2044
|
""",
|
|
@@ -1706,37 +2055,52 @@ class MemoryEngine:
|
|
|
1706
2055
|
"original_text": doc["original_text"],
|
|
1707
2056
|
"content_hash": doc["content_hash"],
|
|
1708
2057
|
"memory_unit_count": doc["unit_count"],
|
|
1709
|
-
"created_at": doc["created_at"],
|
|
1710
|
-
"updated_at": doc["updated_at"],
|
|
2058
|
+
"created_at": doc["created_at"].isoformat() if doc["created_at"] else None,
|
|
2059
|
+
"updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else None,
|
|
1711
2060
|
}
|
|
1712
2061
|
|
|
1713
|
-
async def delete_document(
|
|
2062
|
+
async def delete_document(
|
|
2063
|
+
self,
|
|
2064
|
+
document_id: str,
|
|
2065
|
+
bank_id: str,
|
|
2066
|
+
*,
|
|
2067
|
+
request_context: "RequestContext",
|
|
2068
|
+
) -> dict[str, int]:
|
|
1714
2069
|
"""
|
|
1715
2070
|
Delete a document and all its associated memory units and links.
|
|
1716
2071
|
|
|
1717
2072
|
Args:
|
|
1718
2073
|
document_id: Document ID to delete
|
|
1719
2074
|
bank_id: bank ID that owns the document
|
|
2075
|
+
request_context: Request context for authentication.
|
|
1720
2076
|
|
|
1721
2077
|
Returns:
|
|
1722
2078
|
Dictionary with counts of deleted items
|
|
1723
2079
|
"""
|
|
2080
|
+
await self._authenticate_tenant(request_context)
|
|
1724
2081
|
pool = await self._get_pool()
|
|
1725
2082
|
async with acquire_with_retry(pool) as conn:
|
|
1726
2083
|
async with conn.transaction():
|
|
1727
2084
|
# Count units before deletion
|
|
1728
2085
|
units_count = await conn.fetchval(
|
|
1729
|
-
"SELECT COUNT(*) FROM memory_units WHERE document_id = $1", document_id
|
|
2086
|
+
f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE document_id = $1", document_id
|
|
1730
2087
|
)
|
|
1731
2088
|
|
|
1732
2089
|
# Delete document (cascades to memory_units and all their links)
|
|
1733
2090
|
deleted = await conn.fetchval(
|
|
1734
|
-
"DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id",
|
|
2091
|
+
f"DELETE FROM {fq_table('documents')} WHERE id = $1 AND bank_id = $2 RETURNING id",
|
|
2092
|
+
document_id,
|
|
2093
|
+
bank_id,
|
|
1735
2094
|
)
|
|
1736
2095
|
|
|
1737
2096
|
return {"document_deleted": 1 if deleted else 0, "memory_units_deleted": units_count if deleted else 0}
|
|
1738
2097
|
|
|
1739
|
-
async def delete_memory_unit(
|
|
2098
|
+
async def delete_memory_unit(
|
|
2099
|
+
self,
|
|
2100
|
+
unit_id: str,
|
|
2101
|
+
*,
|
|
2102
|
+
request_context: "RequestContext",
|
|
2103
|
+
) -> dict[str, Any]:
|
|
1740
2104
|
"""
|
|
1741
2105
|
Delete a single memory unit and all its associated links.
|
|
1742
2106
|
|
|
@@ -1747,15 +2111,19 @@ class MemoryEngine:
|
|
|
1747
2111
|
|
|
1748
2112
|
Args:
|
|
1749
2113
|
unit_id: UUID of the memory unit to delete
|
|
2114
|
+
request_context: Request context for authentication.
|
|
1750
2115
|
|
|
1751
2116
|
Returns:
|
|
1752
2117
|
Dictionary with deletion result
|
|
1753
2118
|
"""
|
|
2119
|
+
await self._authenticate_tenant(request_context)
|
|
1754
2120
|
pool = await self._get_pool()
|
|
1755
2121
|
async with acquire_with_retry(pool) as conn:
|
|
1756
2122
|
async with conn.transaction():
|
|
1757
2123
|
# Delete the memory unit (cascades to links and associations)
|
|
1758
|
-
deleted = await conn.fetchval(
|
|
2124
|
+
deleted = await conn.fetchval(
|
|
2125
|
+
f"DELETE FROM {fq_table('memory_units')} WHERE id = $1 RETURNING id", unit_id
|
|
2126
|
+
)
|
|
1759
2127
|
|
|
1760
2128
|
return {
|
|
1761
2129
|
"success": deleted is not None,
|
|
@@ -1765,7 +2133,13 @@ class MemoryEngine:
|
|
|
1765
2133
|
else "Memory unit not found",
|
|
1766
2134
|
}
|
|
1767
2135
|
|
|
1768
|
-
async def delete_bank(
|
|
2136
|
+
async def delete_bank(
|
|
2137
|
+
self,
|
|
2138
|
+
bank_id: str,
|
|
2139
|
+
fact_type: str | None = None,
|
|
2140
|
+
*,
|
|
2141
|
+
request_context: "RequestContext",
|
|
2142
|
+
) -> dict[str, int]:
|
|
1769
2143
|
"""
|
|
1770
2144
|
Delete all data for a specific agent (multi-tenant cleanup).
|
|
1771
2145
|
|
|
@@ -1780,10 +2154,12 @@ class MemoryEngine:
|
|
|
1780
2154
|
Args:
|
|
1781
2155
|
bank_id: bank ID to delete
|
|
1782
2156
|
fact_type: Optional fact type filter (world, experience, opinion). If provided, only deletes memories of that type.
|
|
2157
|
+
request_context: Request context for authentication.
|
|
1783
2158
|
|
|
1784
2159
|
Returns:
|
|
1785
2160
|
Dictionary with counts of deleted items
|
|
1786
2161
|
"""
|
|
2162
|
+
await self._authenticate_tenant(request_context)
|
|
1787
2163
|
pool = await self._get_pool()
|
|
1788
2164
|
async with acquire_with_retry(pool) as conn:
|
|
1789
2165
|
# Ensure connection is not in read-only mode (can happen with connection poolers)
|
|
@@ -1793,12 +2169,14 @@ class MemoryEngine:
|
|
|
1793
2169
|
if fact_type:
|
|
1794
2170
|
# Delete only memories of a specific fact type
|
|
1795
2171
|
units_count = await conn.fetchval(
|
|
1796
|
-
"SELECT COUNT(*) FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
|
|
2172
|
+
f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = $2",
|
|
1797
2173
|
bank_id,
|
|
1798
2174
|
fact_type,
|
|
1799
2175
|
)
|
|
1800
2176
|
await conn.execute(
|
|
1801
|
-
"DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
|
|
2177
|
+
f"DELETE FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = $2",
|
|
2178
|
+
bank_id,
|
|
2179
|
+
fact_type,
|
|
1802
2180
|
)
|
|
1803
2181
|
|
|
1804
2182
|
# Note: We don't delete entities when fact_type is specified,
|
|
@@ -1807,26 +2185,26 @@ class MemoryEngine:
|
|
|
1807
2185
|
else:
|
|
1808
2186
|
# Delete all data for the bank
|
|
1809
2187
|
units_count = await conn.fetchval(
|
|
1810
|
-
"SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id
|
|
2188
|
+
f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1", bank_id
|
|
1811
2189
|
)
|
|
1812
2190
|
entities_count = await conn.fetchval(
|
|
1813
|
-
"SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id
|
|
2191
|
+
f"SELECT COUNT(*) FROM {fq_table('entities')} WHERE bank_id = $1", bank_id
|
|
1814
2192
|
)
|
|
1815
2193
|
documents_count = await conn.fetchval(
|
|
1816
|
-
"SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id
|
|
2194
|
+
f"SELECT COUNT(*) FROM {fq_table('documents')} WHERE bank_id = $1", bank_id
|
|
1817
2195
|
)
|
|
1818
2196
|
|
|
1819
2197
|
# Delete documents (cascades to chunks)
|
|
1820
|
-
await conn.execute("DELETE FROM documents WHERE bank_id = $1", bank_id)
|
|
2198
|
+
await conn.execute(f"DELETE FROM {fq_table('documents')} WHERE bank_id = $1", bank_id)
|
|
1821
2199
|
|
|
1822
2200
|
# Delete memory units (cascades to unit_entities, memory_links)
|
|
1823
|
-
await conn.execute("DELETE FROM memory_units WHERE bank_id = $1", bank_id)
|
|
2201
|
+
await conn.execute(f"DELETE FROM {fq_table('memory_units')} WHERE bank_id = $1", bank_id)
|
|
1824
2202
|
|
|
1825
2203
|
# Delete entities (cascades to unit_entities, entity_cooccurrences, memory_links with entity_id)
|
|
1826
|
-
await conn.execute("DELETE FROM entities WHERE bank_id = $1", bank_id)
|
|
2204
|
+
await conn.execute(f"DELETE FROM {fq_table('entities')} WHERE bank_id = $1", bank_id)
|
|
1827
2205
|
|
|
1828
2206
|
# Delete the bank profile itself
|
|
1829
|
-
await conn.execute("DELETE FROM banks WHERE bank_id = $1", bank_id)
|
|
2207
|
+
await conn.execute(f"DELETE FROM {fq_table('banks')} WHERE bank_id = $1", bank_id)
|
|
1830
2208
|
|
|
1831
2209
|
return {
|
|
1832
2210
|
"memory_units_deleted": units_count,
|
|
@@ -1838,17 +2216,25 @@ class MemoryEngine:
|
|
|
1838
2216
|
except Exception as e:
|
|
1839
2217
|
raise Exception(f"Failed to delete agent data: {str(e)}")
|
|
1840
2218
|
|
|
1841
|
-
async def get_graph_data(
|
|
2219
|
+
async def get_graph_data(
|
|
2220
|
+
self,
|
|
2221
|
+
bank_id: str | None = None,
|
|
2222
|
+
fact_type: str | None = None,
|
|
2223
|
+
*,
|
|
2224
|
+
request_context: "RequestContext",
|
|
2225
|
+
):
|
|
1842
2226
|
"""
|
|
1843
2227
|
Get graph data for visualization.
|
|
1844
2228
|
|
|
1845
2229
|
Args:
|
|
1846
2230
|
bank_id: Filter by bank ID
|
|
1847
2231
|
fact_type: Filter by fact type (world, experience, opinion)
|
|
2232
|
+
request_context: Request context for authentication.
|
|
1848
2233
|
|
|
1849
2234
|
Returns:
|
|
1850
2235
|
Dict with nodes, edges, and table_rows
|
|
1851
2236
|
"""
|
|
2237
|
+
await self._authenticate_tenant(request_context)
|
|
1852
2238
|
pool = await self._get_pool()
|
|
1853
2239
|
async with acquire_with_retry(pool) as conn:
|
|
1854
2240
|
# Get memory units, optionally filtered by bank_id and fact_type
|
|
@@ -1871,7 +2257,7 @@ class MemoryEngine:
|
|
|
1871
2257
|
units = await conn.fetch(
|
|
1872
2258
|
f"""
|
|
1873
2259
|
SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
|
|
1874
|
-
FROM memory_units
|
|
2260
|
+
FROM {fq_table("memory_units")}
|
|
1875
2261
|
{where_clause}
|
|
1876
2262
|
ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
|
|
1877
2263
|
LIMIT 1000
|
|
@@ -1884,15 +2270,15 @@ class MemoryEngine:
|
|
|
1884
2270
|
unit_ids = [row["id"] for row in units]
|
|
1885
2271
|
if unit_ids:
|
|
1886
2272
|
links = await conn.fetch(
|
|
1887
|
-
"""
|
|
2273
|
+
f"""
|
|
1888
2274
|
SELECT DISTINCT ON (LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid))
|
|
1889
2275
|
ml.from_unit_id,
|
|
1890
2276
|
ml.to_unit_id,
|
|
1891
2277
|
ml.link_type,
|
|
1892
2278
|
ml.weight,
|
|
1893
2279
|
e.canonical_name as entity_name
|
|
1894
|
-
FROM memory_links ml
|
|
1895
|
-
LEFT JOIN entities e ON ml.entity_id = e.id
|
|
2280
|
+
FROM {fq_table("memory_links")} ml
|
|
2281
|
+
LEFT JOIN {fq_table("entities")} e ON ml.entity_id = e.id
|
|
1896
2282
|
WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
|
|
1897
2283
|
ORDER BY LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid), ml.weight DESC
|
|
1898
2284
|
""",
|
|
@@ -1902,10 +2288,10 @@ class MemoryEngine:
|
|
|
1902
2288
|
links = []
|
|
1903
2289
|
|
|
1904
2290
|
# Get entity information
|
|
1905
|
-
unit_entities = await conn.fetch("""
|
|
2291
|
+
unit_entities = await conn.fetch(f"""
|
|
1906
2292
|
SELECT ue.unit_id, e.canonical_name
|
|
1907
|
-
FROM unit_entities ue
|
|
1908
|
-
JOIN entities e ON ue.entity_id = e.id
|
|
2293
|
+
FROM {fq_table("unit_entities")} ue
|
|
2294
|
+
JOIN {fq_table("entities")} e ON ue.entity_id = e.id
|
|
1909
2295
|
ORDER BY ue.unit_id
|
|
1910
2296
|
""")
|
|
1911
2297
|
|
|
@@ -2017,11 +2403,13 @@ class MemoryEngine:
|
|
|
2017
2403
|
|
|
2018
2404
|
async def list_memory_units(
|
|
2019
2405
|
self,
|
|
2020
|
-
bank_id: str
|
|
2406
|
+
bank_id: str,
|
|
2407
|
+
*,
|
|
2021
2408
|
fact_type: str | None = None,
|
|
2022
2409
|
search_query: str | None = None,
|
|
2023
2410
|
limit: int = 100,
|
|
2024
2411
|
offset: int = 0,
|
|
2412
|
+
request_context: "RequestContext",
|
|
2025
2413
|
):
|
|
2026
2414
|
"""
|
|
2027
2415
|
List memory units for table view with optional full-text search.
|
|
@@ -2032,10 +2420,12 @@ class MemoryEngine:
|
|
|
2032
2420
|
search_query: Full-text search query (searches text and context fields)
|
|
2033
2421
|
limit: Maximum number of results to return
|
|
2034
2422
|
offset: Offset for pagination
|
|
2423
|
+
request_context: Request context for authentication.
|
|
2035
2424
|
|
|
2036
2425
|
Returns:
|
|
2037
2426
|
Dict with items (list of memory units) and total count
|
|
2038
2427
|
"""
|
|
2428
|
+
await self._authenticate_tenant(request_context)
|
|
2039
2429
|
pool = await self._get_pool()
|
|
2040
2430
|
async with acquire_with_retry(pool) as conn:
|
|
2041
2431
|
# Build query conditions
|
|
@@ -2064,7 +2454,7 @@ class MemoryEngine:
|
|
|
2064
2454
|
# Get total count
|
|
2065
2455
|
count_query = f"""
|
|
2066
2456
|
SELECT COUNT(*) as total
|
|
2067
|
-
FROM memory_units
|
|
2457
|
+
FROM {fq_table("memory_units")}
|
|
2068
2458
|
{where_clause}
|
|
2069
2459
|
"""
|
|
2070
2460
|
count_result = await conn.fetchrow(count_query, *query_params)
|
|
@@ -2082,7 +2472,7 @@ class MemoryEngine:
|
|
|
2082
2472
|
units = await conn.fetch(
|
|
2083
2473
|
f"""
|
|
2084
2474
|
SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id
|
|
2085
|
-
FROM memory_units
|
|
2475
|
+
FROM {fq_table("memory_units")}
|
|
2086
2476
|
{where_clause}
|
|
2087
2477
|
ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
|
|
2088
2478
|
LIMIT {limit_param} OFFSET {offset_param}
|
|
@@ -2094,10 +2484,10 @@ class MemoryEngine:
|
|
|
2094
2484
|
if units:
|
|
2095
2485
|
unit_ids = [row["id"] for row in units]
|
|
2096
2486
|
unit_entities = await conn.fetch(
|
|
2097
|
-
"""
|
|
2487
|
+
f"""
|
|
2098
2488
|
SELECT ue.unit_id, e.canonical_name
|
|
2099
|
-
FROM unit_entities ue
|
|
2100
|
-
JOIN entities e ON ue.entity_id = e.id
|
|
2489
|
+
FROM {fq_table("unit_entities")} ue
|
|
2490
|
+
JOIN {fq_table("entities")} e ON ue.entity_id = e.id
|
|
2101
2491
|
WHERE ue.unit_id = ANY($1::uuid[])
|
|
2102
2492
|
ORDER BY ue.unit_id
|
|
2103
2493
|
""",
|
|
@@ -2138,7 +2528,15 @@ class MemoryEngine:
|
|
|
2138
2528
|
|
|
2139
2529
|
return {"items": items, "total": total, "limit": limit, "offset": offset}
|
|
2140
2530
|
|
|
2141
|
-
async def list_documents(
|
|
2531
|
+
async def list_documents(
|
|
2532
|
+
self,
|
|
2533
|
+
bank_id: str,
|
|
2534
|
+
*,
|
|
2535
|
+
search_query: str | None = None,
|
|
2536
|
+
limit: int = 100,
|
|
2537
|
+
offset: int = 0,
|
|
2538
|
+
request_context: "RequestContext",
|
|
2539
|
+
):
|
|
2142
2540
|
"""
|
|
2143
2541
|
List documents with optional search and pagination.
|
|
2144
2542
|
|
|
@@ -2147,10 +2545,12 @@ class MemoryEngine:
|
|
|
2147
2545
|
search_query: Search in document ID
|
|
2148
2546
|
limit: Maximum number of results
|
|
2149
2547
|
offset: Offset for pagination
|
|
2548
|
+
request_context: Request context for authentication.
|
|
2150
2549
|
|
|
2151
2550
|
Returns:
|
|
2152
2551
|
Dict with items (list of documents without original_text) and total count
|
|
2153
2552
|
"""
|
|
2553
|
+
await self._authenticate_tenant(request_context)
|
|
2154
2554
|
pool = await self._get_pool()
|
|
2155
2555
|
async with acquire_with_retry(pool) as conn:
|
|
2156
2556
|
# Build query conditions
|
|
@@ -2173,7 +2573,7 @@ class MemoryEngine:
|
|
|
2173
2573
|
# Get total count
|
|
2174
2574
|
count_query = f"""
|
|
2175
2575
|
SELECT COUNT(*) as total
|
|
2176
|
-
FROM documents
|
|
2576
|
+
FROM {fq_table("documents")}
|
|
2177
2577
|
{where_clause}
|
|
2178
2578
|
"""
|
|
2179
2579
|
count_result = await conn.fetchrow(count_query, *query_params)
|
|
@@ -2198,7 +2598,7 @@ class MemoryEngine:
|
|
|
2198
2598
|
updated_at,
|
|
2199
2599
|
LENGTH(original_text) as text_length,
|
|
2200
2600
|
retain_params
|
|
2201
|
-
FROM documents
|
|
2601
|
+
FROM {fq_table("documents")}
|
|
2202
2602
|
{where_clause}
|
|
2203
2603
|
ORDER BY created_at DESC
|
|
2204
2604
|
LIMIT {limit_param} OFFSET {offset_param}
|
|
@@ -2224,7 +2624,7 @@ class MemoryEngine:
|
|
|
2224
2624
|
unit_counts = await conn.fetch(
|
|
2225
2625
|
f"""
|
|
2226
2626
|
SELECT document_id, bank_id, COUNT(*) as unit_count
|
|
2227
|
-
FROM memory_units
|
|
2627
|
+
FROM {fq_table("memory_units")}
|
|
2228
2628
|
WHERE {where_clause_count}
|
|
2229
2629
|
GROUP BY document_id, bank_id
|
|
2230
2630
|
""",
|
|
@@ -2258,75 +2658,27 @@ class MemoryEngine:
|
|
|
2258
2658
|
|
|
2259
2659
|
return {"items": items, "total": total, "limit": limit, "offset": offset}
|
|
2260
2660
|
|
|
2261
|
-
async def
|
|
2262
|
-
|
|
2263
|
-
|
|
2264
|
-
|
|
2265
|
-
|
|
2266
|
-
|
|
2267
|
-
bank_id: bank ID
|
|
2268
|
-
|
|
2269
|
-
Returns:
|
|
2270
|
-
Dict with document details including original_text, or None if not found
|
|
2271
|
-
"""
|
|
2272
|
-
pool = await self._get_pool()
|
|
2273
|
-
async with acquire_with_retry(pool) as conn:
|
|
2274
|
-
doc = await conn.fetchrow(
|
|
2275
|
-
"""
|
|
2276
|
-
SELECT
|
|
2277
|
-
id,
|
|
2278
|
-
bank_id,
|
|
2279
|
-
original_text,
|
|
2280
|
-
content_hash,
|
|
2281
|
-
created_at,
|
|
2282
|
-
updated_at,
|
|
2283
|
-
retain_params
|
|
2284
|
-
FROM documents
|
|
2285
|
-
WHERE id = $1 AND bank_id = $2
|
|
2286
|
-
""",
|
|
2287
|
-
document_id,
|
|
2288
|
-
bank_id,
|
|
2289
|
-
)
|
|
2290
|
-
|
|
2291
|
-
if not doc:
|
|
2292
|
-
return None
|
|
2293
|
-
|
|
2294
|
-
# Get memory unit count
|
|
2295
|
-
unit_count_row = await conn.fetchrow(
|
|
2296
|
-
"""
|
|
2297
|
-
SELECT COUNT(*) as unit_count
|
|
2298
|
-
FROM memory_units
|
|
2299
|
-
WHERE document_id = $1 AND bank_id = $2
|
|
2300
|
-
""",
|
|
2301
|
-
document_id,
|
|
2302
|
-
bank_id,
|
|
2303
|
-
)
|
|
2304
|
-
|
|
2305
|
-
return {
|
|
2306
|
-
"id": doc["id"],
|
|
2307
|
-
"bank_id": doc["bank_id"],
|
|
2308
|
-
"original_text": doc["original_text"],
|
|
2309
|
-
"content_hash": doc["content_hash"],
|
|
2310
|
-
"created_at": doc["created_at"].isoformat() if doc["created_at"] else "",
|
|
2311
|
-
"updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else "",
|
|
2312
|
-
"memory_unit_count": unit_count_row["unit_count"] if unit_count_row else 0,
|
|
2313
|
-
"retain_params": doc["retain_params"] if doc["retain_params"] else None,
|
|
2314
|
-
}
|
|
2315
|
-
|
|
2316
|
-
async def get_chunk(self, chunk_id: str):
|
|
2661
|
+
async def get_chunk(
|
|
2662
|
+
self,
|
|
2663
|
+
chunk_id: str,
|
|
2664
|
+
*,
|
|
2665
|
+
request_context: "RequestContext",
|
|
2666
|
+
):
|
|
2317
2667
|
"""
|
|
2318
2668
|
Get a specific chunk by its ID.
|
|
2319
2669
|
|
|
2320
2670
|
Args:
|
|
2321
2671
|
chunk_id: Chunk ID (format: bank_id_document_id_chunk_index)
|
|
2672
|
+
request_context: Request context for authentication.
|
|
2322
2673
|
|
|
2323
2674
|
Returns:
|
|
2324
2675
|
Dict with chunk details including chunk_text, or None if not found
|
|
2325
2676
|
"""
|
|
2677
|
+
await self._authenticate_tenant(request_context)
|
|
2326
2678
|
pool = await self._get_pool()
|
|
2327
2679
|
async with acquire_with_retry(pool) as conn:
|
|
2328
2680
|
chunk = await conn.fetchrow(
|
|
2329
|
-
"""
|
|
2681
|
+
f"""
|
|
2330
2682
|
SELECT
|
|
2331
2683
|
chunk_id,
|
|
2332
2684
|
document_id,
|
|
@@ -2334,7 +2686,7 @@ class MemoryEngine:
|
|
|
2334
2686
|
chunk_index,
|
|
2335
2687
|
chunk_text,
|
|
2336
2688
|
created_at
|
|
2337
|
-
FROM chunks
|
|
2689
|
+
FROM {fq_table("chunks")}
|
|
2338
2690
|
WHERE chunk_id = $1
|
|
2339
2691
|
""",
|
|
2340
2692
|
chunk_id,
|
|
@@ -2500,11 +2852,11 @@ Guidelines:
         async with acquire_with_retry(pool) as conn:
             # Find all opinions related to these entities
             opinions = await conn.fetch(
-                """
+                f"""
                 SELECT DISTINCT mu.id, mu.text, mu.confidence_score, e.canonical_name
-                FROM memory_units mu
-                JOIN unit_entities ue ON mu.id = ue.unit_id
-                JOIN entities e ON ue.entity_id = e.id
+                FROM {fq_table("memory_units")} mu
+                JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
+                JOIN {fq_table("entities")} e ON ue.entity_id = e.id
                 WHERE mu.bank_id = $1
                 AND mu.fact_type = 'opinion'
                 AND e.canonical_name = ANY($2::text[])
@@ -2559,8 +2911,8 @@ Guidelines:
             if evaluation["action"] == "update" and evaluation["new_text"]:
                 # Update both text and confidence
                 await conn.execute(
-                    """
-                    UPDATE memory_units
+                    f"""
+                    UPDATE {fq_table("memory_units")}
                     SET text = $1, confidence_score = $2, updated_at = NOW()
                     WHERE id = $3
                     """,
@@ -2571,8 +2923,8 @@ Guidelines:
             else:
                 # Only update confidence
                 await conn.execute(
-                    """
-                    UPDATE memory_units
+                    f"""
+                    UPDATE {fq_table("memory_units")}
                     SET confidence_score = $1, updated_at = NOW()
                     WHERE id = $2
                     """,
@@ -2591,32 +2943,61 @@ Guidelines:

     # ==================== bank profile Methods ====================

-    async def get_bank_profile(
+    async def get_bank_profile(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
         """
         Get bank profile (name, disposition + background).
         Auto-creates agent with default values if not exists.

         Args:
             bank_id: bank IDentifier
+            request_context: Request context for authentication.

         Returns:
-
+            Dict with name, disposition traits, and background
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
-
-
-
+        profile = await bank_utils.get_bank_profile(pool, bank_id)
+        disposition = profile["disposition"]
+        return {
+            "bank_id": bank_id,
+            "name": profile["name"],
+            "disposition": disposition,
+            "background": profile["background"],
+        }
+
+    async def update_bank_disposition(
+        self,
+        bank_id: str,
+        disposition: dict[str, int],
+        *,
+        request_context: "RequestContext",
+    ) -> None:
         """
         Update bank disposition traits.

         Args:
             bank_id: bank IDentifier
             disposition: Dict with skepticism, literalism, empathy (all 1-5)
+            request_context: Request context for authentication.
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         await bank_utils.update_bank_disposition(pool, bank_id, disposition)

-    async def merge_bank_background(
+    async def merge_bank_background(
+        self,
+        bank_id: str,
+        new_info: str,
+        *,
+        update_disposition: bool = True,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
         """
         Merge new background information with existing background using LLM.
         Normalizes to first person ("I") and resolves conflicts.
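
Note: every public method in the hunk above now takes a keyword-only request_context and authenticates before touching the pool. A sketch of a post-upgrade call site (the RequestContext constructor arguments, if any, are not shown in this diff, so a bare default is used here):

    from hindsight_api.models import RequestContext

    async def show_profile(engine, bank_id: str) -> None:
        ctx = RequestContext()  # default context; tenant setups may need more
        # request_context is keyword-only, so 0.1.11-style positional calls fail fast
        profile = await engine.get_bank_profile(bank_id, request_context=ctx)
        print(profile["name"], profile["disposition"], profile["background"])
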
@@ -2626,20 +3007,30 @@ Guidelines:
             bank_id: bank IDentifier
             new_info: New background information to add/merge
             update_disposition: If True, infer Big Five traits from background (default: True)
+            request_context: Request context for authentication.

         Returns:
             Dict with 'background' (str) and optionally 'disposition' (dict) keys
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         return await bank_utils.merge_bank_background(pool, self._llm_config, bank_id, new_info, update_disposition)

-    async def list_banks(
+    async def list_banks(
+        self,
+        *,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
         """
         List all agents in the system.

+        Args:
+            request_context: Request context for authentication.
+
         Returns:
             List of dicts with bank_id, name, disposition, background, created_at, updated_at
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         return await bank_utils.list_banks(pool)

@@ -2649,8 +3040,10 @@ Guidelines:
         self,
         bank_id: str,
         query: str,
-
-
+        *,
+        budget: Budget | None = None,
+        context: str | None = None,
+        request_context: "RequestContext",
     ) -> ReflectResult:
         """
         Reflect and formulate an answer using bank identity, world facts, and opinions.
@@ -2679,6 +3072,22 @@ Guidelines:
         if self._llm_config is None:
             raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")

+        # Authenticate tenant and set schema in context (for fq_table())
+        await self._authenticate_tenant(request_context)
+
+        # Validate operation if validator is configured
+        if self._operation_validator:
+            from hindsight_api.extensions import ReflectContext
+
+            ctx = ReflectContext(
+                bank_id=bank_id,
+                query=query,
+                request_context=request_context,
+                budget=budget,
+                context=context,
+            )
+            await self._validate_operation(self._operation_validator.validate_reflect(ctx))
+
         reflect_start = time.time()
         reflect_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
         log_buffer = []
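
Note: validate_reflect runs before any retrieval or LLM work, so a configured validator can veto the operation outright. The validator contract beyond this call site is not shown in the diff; a sketch of what an extension might plug in (class and attribute names here are assumptions):

    from hindsight_api.extensions import ReflectContext

    class QuotaValidator:
        # Sketch: reject oversized reflect queries before any work happens.
        def __init__(self, max_query_chars: int = 2000) -> None:
            self.max_query_chars = max_query_chars

        async def validate_reflect(self, ctx: ReflectContext) -> None:
            # _validate_operation() above presumably awaits this and surfaces errors.
            if len(ctx.query) > self.max_query_chars:
                raise ValueError(f"query too long for bank {ctx.bank_id}")
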
@@ -2694,6 +3103,7 @@ Guidelines:
             enable_trace=False,
             fact_type=["experience", "world", "opinion"],
             include_entities=True,
+            request_context=request_context,
         )
         recall_time = time.time() - recall_start

@@ -2714,7 +3124,7 @@ Guidelines:
         opinion_facts_text = think_utils.format_facts_for_prompt(opinion_results)

         # Get bank profile (name, disposition + background)
-        profile = await self.get_bank_profile(bank_id)
+        profile = await self.get_bank_profile(bank_id, request_context=request_context)
         name = profile["name"]
         disposition = profile["disposition"]  # Typed as DispositionTraits
         background = profile["background"]
@@ -2758,12 +3168,33 @@ Guidelines:
         logger.info("\n" + "\n".join(log_buffer))

         # Return response with facts split by type
-
+        result = ReflectResult(
             text=answer_text,
             based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
             new_opinions=[],  # Opinions are being extracted asynchronously
         )

+        # Call post-operation hook if validator is configured
+        if self._operation_validator:
+            from hindsight_api.extensions.operation_validator import ReflectResultContext
+
+            result_ctx = ReflectResultContext(
+                bank_id=bank_id,
+                query=query,
+                request_context=request_context,
+                budget=budget,
+                context=context,
+                result=result,
+                success=True,
+                error=None,
+            )
+            try:
+                await self._operation_validator.on_reflect_complete(result_ctx)
+            except Exception as e:
+                logger.warning(f"Post-reflect hook error (non-fatal): {e}")
+
+        return result
+
     async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str, query: str):
         """
         Background task to extract and store opinions from think response.
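
Note the asymmetry in the hunk above: the pre-hook may abort a reflect, while the post-hook is wrapped in try/except so a misbehaving extension can only log a warning, never discard an already-computed answer. A sketch of a post-hook consumer (field access mirrors the ReflectResultContext construction above; everything else is assumed):

    from hindsight_api.extensions.operation_validator import ReflectResultContext

    class ReflectAuditor:
        # Sketch: record completed reflections; failures here stay non-fatal.
        async def on_reflect_complete(self, ctx: ReflectResultContext) -> None:
            if ctx.success and ctx.result is not None:
                print(
                    f"[audit] bank={ctx.bank_id} query={ctx.query!r} "
                    f"answer_chars={len(ctx.result.text)}"
                )
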
@@ -2784,6 +3215,10 @@ Guidelines:
             from datetime import datetime

             current_time = datetime.now(UTC)
+            # Use internal request context for background tasks
+            from hindsight_api.models import RequestContext
+
+            internal_context = RequestContext()
             for opinion in new_opinions:
                 await self.retain_async(
                     bank_id=bank_id,
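
Note: a background task has no inbound request to borrow a context from, so the hunk above fabricates a default RequestContext(); presumably this authenticates as the internal/default tenant. The same pattern recurs in _regenerate_observations_sync and the task handler further down. In isolation:

    from hindsight_api.models import RequestContext

    # Sketch: background jobs supply their own default context because no
    # user-supplied context exists to forward.
    internal_context = RequestContext()
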
@@ -2792,12 +3227,20 @@ Guidelines:
                     event_date=current_time,
                     fact_type_override="opinion",
                     confidence_score=opinion.confidence,
+                    request_context=internal_context,
                 )

         except Exception as e:
             logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")

-    async def get_entity_observations(
+    async def get_entity_observations(
+        self,
+        bank_id: str,
+        entity_id: str,
+        *,
+        limit: int = 10,
+        request_context: "RequestContext",
+    ) -> list[Any]:
         """
         Get observations linked to an entity.

@@ -2805,17 +3248,19 @@ Guidelines:
             bank_id: bank IDentifier
             entity_id: Entity UUID to get observations for
             limit: Maximum number of observations to return
+            request_context: Request context for authentication.

         Returns:
             List of EntityObservation objects
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             rows = await conn.fetch(
-                """
+                f"""
                 SELECT mu.text, mu.mentioned_at
-                FROM memory_units mu
-                JOIN unit_entities ue ON mu.id = ue.unit_id
+                FROM {fq_table("memory_units")} mu
+                JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
                 WHERE mu.bank_id = $1
                 AND mu.fact_type = 'observation'
                 AND ue.entity_id = $2
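
Note the split these queries maintain: the table identifier is interpolated into the SQL text through fq_table (Postgres cannot bind identifiers as parameters), while every value still travels as a $n bind parameter. A standalone asyncpg illustration of the same split, with an invented schema argument:

    import asyncpg

    async def fetch_observation_rows(conn: asyncpg.Connection, schema: str, bank_id: str):
        # Identifier baked into the SQL text; values bound, never interpolated.
        query = f"""
            SELECT mu.text, mu.mentioned_at
            FROM {schema}.memory_units mu
            WHERE mu.bank_id = $1 AND mu.fact_type = 'observation'
        """
        return await conn.fetch(query, bank_id)
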
@@ -2833,23 +3278,31 @@ Guidelines:
                 observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
             return observations

-    async def list_entities(
+    async def list_entities(
+        self,
+        bank_id: str,
+        *,
+        limit: int = 100,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
         """
         List all entities for a bank.

         Args:
             bank_id: bank IDentifier
             limit: Maximum number of entities to return
+            request_context: Request context for authentication.

         Returns:
             List of entity dicts with id, canonical_name, mention_count, first_seen, last_seen
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             rows = await conn.fetch(
-                """
+                f"""
                 SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
-                FROM entities
+                FROM {fq_table("entities")}
                 WHERE bank_id = $1
                 ORDER BY mention_count DESC, last_seen DESC
                 LIMIT $2
@@ -2884,7 +3337,15 @@ Guidelines:
             )
             return entities

-    async def get_entity_state(
+    async def get_entity_state(
+        self,
+        bank_id: str,
+        entity_id: str,
+        entity_name: str,
+        *,
+        limit: int = 10,
+        request_context: "RequestContext",
+    ) -> EntityState:
         """
         Get the current state (mental model) of an entity.

@@ -2893,16 +3354,26 @@ Guidelines:
             entity_id: Entity UUID
             entity_name: Canonical name of the entity
             limit: Maximum number of observations to include
+            request_context: Request context for authentication.

         Returns:
             EntityState with observations
         """
-        observations = await self.get_entity_observations(
+        observations = await self.get_entity_observations(
+            bank_id, entity_id, limit=limit, request_context=request_context
+        )
         return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=observations)

     async def regenerate_entity_observations(
-        self,
-
+        self,
+        bank_id: str,
+        entity_id: str,
+        entity_name: str,
+        *,
+        version: str | None = None,
+        conn=None,
+        request_context: "RequestContext",
+    ) -> None:
         """
         Regenerate observations for an entity by:
         1. Checking version for deduplication (if provided)
@@ -2917,10 +3388,9 @@ Guidelines:
             entity_name: Canonical name of the entity
             version: Entity's last_seen timestamp when task was created (for deduplication)
             conn: Optional database connection (for transactional atomicity with caller)
-
-        Returns:
-            List of created observation IDs
+            request_context: Request context for authentication.
         """
+        await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         entity_uuid = uuid.UUID(entity_id)

@@ -2942,9 +3412,9 @@ Guidelines:
         # Step 1: Check version for deduplication
         if version:
             current_last_seen = await fetchval_with_conn(
-                """
+                f"""
                 SELECT last_seen
-                FROM entities
+                FROM {fq_table("entities")}
                 WHERE id = $1 AND bank_id = $2
                 """,
                 entity_uuid,
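
Note: the version check above is an optimistic staleness guard. If the entity's current last_seen no longer matches the value captured when the regeneration task was enqueued, a newer task already covers this entity. The comparison itself falls outside this hunk; in sketch form:

    def is_stale(version: str | None, current_last_seen) -> bool:
        # Sketch of the dedup guard implied by the docstring and query above:
        # a mismatch means the entity changed after this task was enqueued.
        return bool(version) and current_last_seen is not None and str(current_last_seen) != version
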
@@ -2956,10 +3426,10 @@ Guidelines:

         # Step 2: Get all facts mentioning this entity (exclude observations themselves)
         rows = await fetch_with_conn(
-            """
+            f"""
             SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
-            FROM memory_units mu
-            JOIN unit_entities ue ON mu.id = ue.unit_id
+            FROM {fq_table("memory_units")} mu
+            JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
             WHERE mu.bank_id = $1
             AND ue.entity_id = $2
             AND mu.fact_type IN ('world', 'experience')
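
Note: fetchval_with_conn and fetch_with_conn are not defined in these hunks. From their use they run the query on the caller-supplied conn when one exists (staying inside the caller's transaction) and otherwise acquire a connection from the pool, most likely as local closures. A plausible shape, written here as a factory over pool and conn:

    # Sketch only; the real helpers live elsewhere in memory_engine.py.
    def make_fetch_with_conn(pool, conn):
        async def fetch_with_conn(query: str, *args):
            if conn is not None:
                return await conn.fetch(query, *args)  # caller's transaction
            async with pool.acquire() as acquired:
                return await acquired.fetch(query, *args)
        return fetch_with_conn
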
@@ -2999,12 +3469,12 @@ Guidelines:
         async def do_db_operations(db_conn):
             # Delete old observations for this entity
             await db_conn.execute(
-                """
-                DELETE FROM memory_units
+                f"""
+                DELETE FROM {fq_table("memory_units")}
                 WHERE id IN (
                     SELECT mu.id
-                    FROM memory_units mu
-                    JOIN unit_entities ue ON mu.id = ue.unit_id
+                    FROM {fq_table("memory_units")} mu
+                    JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
                     WHERE mu.bank_id = $1
                     AND mu.fact_type = 'observation'
                     AND ue.entity_id = $2
@@ -3023,8 +3493,8 @@ Guidelines:

             for obs_text, embedding in zip(observations, embeddings):
                 result = await db_conn.fetchrow(
-                    """
-                    INSERT INTO memory_units (
+                    f"""
+                    INSERT INTO {fq_table("memory_units")} (
                         bank_id, text, embedding, context, event_date,
                         occurred_start, occurred_end, mentioned_at,
                         fact_type, access_count
@@ -3046,8 +3516,8 @@ Guidelines:

                 # Link observation to entity
                 await db_conn.execute(
-                    """
-                    INSERT INTO unit_entities (unit_id, entity_id)
+                    f"""
+                    INSERT INTO {fq_table("unit_entities")} (unit_id, entity_id)
                     VALUES ($1, $2)
                     """,
                     uuid.UUID(obs_id),
@@ -3066,7 +3536,12 @@ Guidelines:
                 return await do_db_operations(acquired_conn)

     async def _regenerate_observations_sync(
-        self,
+        self,
+        bank_id: str,
+        entity_ids: list[str],
+        min_facts: int = 5,
+        conn=None,
+        request_context: "RequestContext | None" = None,
     ) -> None:
         """
         Regenerate observations for entities synchronously (called during retain).
@@ -3089,8 +3564,8 @@ Guidelines:
         if conn is not None:
             # Use the provided connection (transactional with caller)
             entity_rows = await conn.fetch(
-                """
-                SELECT id, canonical_name FROM entities
+                f"""
+                SELECT id, canonical_name FROM {fq_table("entities")}
                 WHERE id = ANY($1) AND bank_id = $2
                 """,
                 entity_uuids,
@@ -3099,10 +3574,10 @@ Guidelines:
             entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

             fact_counts = await conn.fetch(
-                """
+                f"""
                 SELECT ue.entity_id, COUNT(*) as cnt
-                FROM unit_entities ue
-                JOIN memory_units mu ON ue.unit_id = mu.id
+                FROM {fq_table("unit_entities")} ue
+                JOIN {fq_table("memory_units")} mu ON ue.unit_id = mu.id
                 WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
                 GROUP BY ue.entity_id
                 """,
@@ -3115,8 +3590,8 @@ Guidelines:
             pool = await self._get_pool()
             async with pool.acquire() as acquired_conn:
                 entity_rows = await acquired_conn.fetch(
-                    """
-                    SELECT id, canonical_name FROM entities
+                    f"""
+                    SELECT id, canonical_name FROM {fq_table("entities")}
                     WHERE id = ANY($1) AND bank_id = $2
                     """,
                     entity_uuids,
@@ -3125,10 +3600,10 @@ Guidelines:
                 entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

                 fact_counts = await acquired_conn.fetch(
-                    """
+                    f"""
                     SELECT ue.entity_id, COUNT(*) as cnt
-                    FROM unit_entities ue
-                    JOIN memory_units mu ON ue.unit_id = mu.id
+                    FROM {fq_table("unit_entities")} ue
+                    JOIN {fq_table("memory_units")} mu ON ue.unit_id = mu.id
                     WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
                     GROUP BY ue.entity_id
                     """,
@@ -3150,10 +3625,17 @@ Guidelines:
         if not entities_to_process:
             return

+        # Use internal context if not provided (for internal/background calls)
+        from hindsight_api.models import RequestContext as RC
+
+        ctx = request_context if request_context is not None else RC()
+
         # Process all entities in PARALLEL (LLM calls are the bottleneck)
         async def process_entity(entity_id: str, entity_name: str):
             try:
-                await self.regenerate_entity_observations(
+                await self.regenerate_entity_observations(
+                    bank_id, entity_id, entity_name, version=None, conn=conn, request_context=ctx
+                )
             except Exception as e:
                 logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
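
Note: the fan-out that actually runs process_entity concurrently sits outside this hunk, but with LLM latency dominating, the usual shape is an asyncio.gather over the pending entities (the element shape of entities_to_process is assumed here):

    import asyncio

    async def fan_out(process_entity, entities_to_process) -> None:
        # Sketch: each process_entity() already catches its own exceptions,
        # so a failing entity does not cancel the others.
        await asyncio.gather(
            *(process_entity(entity_id, name) for entity_id, name in entities_to_process)
        )
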
@@ -3170,6 +3652,10 @@ Guidelines:
         """
         try:
             bank_id = task_dict.get("bank_id")
+            # Use internal request context for background tasks
+            from hindsight_api.models import RequestContext
+
+            internal_context = RequestContext()

             # New format: multiple entity_ids
             if "entity_ids" in task_dict:
@@ -3192,7 +3678,7 @@ Guidelines:

                 # First check if entity exists
                 entity_exists = await conn.fetchrow(
-                    "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
+                    f"SELECT canonical_name FROM {fq_table('entities')} WHERE id = $1 AND bank_id = $2",
                     entity_uuid,
                     bank_id,
                 )
@@ -3206,14 +3692,17 @@ Guidelines:
                 # Count facts linked to this entity
                 fact_count = (
                     await conn.fetchval(
-                        "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1",
+                        f"SELECT COUNT(*) FROM {fq_table('unit_entities')} WHERE entity_id = $1",
+                        entity_uuid,
                     )
                     or 0
                 )

                 # Only regenerate if entity has enough facts
                 if fact_count >= min_facts:
-                    await self.regenerate_entity_observations(
+                    await self.regenerate_entity_observations(
+                        bank_id, entity_id, entity_name, version=None, request_context=internal_context
+                    )
                 else:
                     logger.debug(
                         f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
@@ -3233,10 +3722,297 @@ Guidelines:
                 logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
                 return

-
+            # Type assertions after validation
+            assert isinstance(bank_id, str) and isinstance(entity_id, str) and isinstance(entity_name, str)
+            await self.regenerate_entity_observations(
+                bank_id, entity_id, entity_name, version=version, request_context=internal_context
+            )

         except Exception as e:
             logger.error(f"[OBSERVATIONS] Error regenerating observations: {e}")
             import traceback

             traceback.print_exc()
+
+    # =========================================================================
+    # Statistics & Operations (for HTTP API layer)
+    # =========================================================================
+
+    async def get_bank_stats(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Get statistics about memory nodes and links for a bank."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            # Get node counts by fact_type
+            node_stats = await conn.fetch(
+                f"""
+                SELECT fact_type, COUNT(*) as count
+                FROM {fq_table("memory_units")}
+                WHERE bank_id = $1
+                GROUP BY fact_type
+                """,
+                bank_id,
+            )
+
+            # Get link counts by link_type
+            link_stats = await conn.fetch(
+                f"""
+                SELECT ml.link_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY ml.link_type
+                """,
+                bank_id,
+            )
+
+            # Get link counts by fact_type (from nodes)
+            link_fact_type_stats = await conn.fetch(
+                f"""
+                SELECT mu.fact_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY mu.fact_type
+                """,
+                bank_id,
+            )
+
+            # Get link counts by fact_type AND link_type
+            link_breakdown_stats = await conn.fetch(
+                f"""
+                SELECT mu.fact_type, ml.link_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY mu.fact_type, ml.link_type
+                """,
+                bank_id,
+            )
+
+            # Get pending and failed operations counts
+            ops_stats = await conn.fetch(
+                f"""
+                SELECT status, COUNT(*) as count
+                FROM {fq_table("async_operations")}
+                WHERE bank_id = $1
+                GROUP BY status
+                """,
+                bank_id,
+            )
+
+            return {
+                "bank_id": bank_id,
+                "node_counts": {row["fact_type"]: row["count"] for row in node_stats},
+                "link_counts": {row["link_type"]: row["count"] for row in link_stats},
+                "link_counts_by_fact_type": {row["fact_type"]: row["count"] for row in link_fact_type_stats},
+                "link_breakdown": [
+                    {"fact_type": row["fact_type"], "link_type": row["link_type"], "count": row["count"]}
+                    for row in link_breakdown_stats
+                ],
+                "operations": {row["status"]: row["count"] for row in ops_stats},
+            }
+
+    async def get_entity(
+        self,
+        bank_id: str,
+        entity_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Get entity details including metadata and observations."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            entity_row = await conn.fetchrow(
+                f"""
+                SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
+                FROM {fq_table("entities")}
+                WHERE bank_id = $1 AND id = $2
+                """,
+                bank_id,
+                uuid.UUID(entity_id),
+            )
+
+            if not entity_row:
+                return None
+
+            # Get observations for the entity
+            observations = await self.get_entity_observations(bank_id, entity_id, limit=20, request_context=request_context)
+
+            return {
+                "id": str(entity_row["id"]),
+                "canonical_name": entity_row["canonical_name"],
+                "mention_count": entity_row["mention_count"],
+                "first_seen": entity_row["first_seen"].isoformat() if entity_row["first_seen"] else None,
+                "last_seen": entity_row["last_seen"].isoformat() if entity_row["last_seen"] else None,
+                "metadata": entity_row["metadata"] or {},
+                "observations": observations,
+            }
+
+    async def list_operations(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
+        """List async operations for a bank."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            operations = await conn.fetch(
+                f"""
+                SELECT operation_id, bank_id, operation_type, created_at, status, error_message, result_metadata
+                FROM {fq_table("async_operations")}
+                WHERE bank_id = $1
+                ORDER BY created_at DESC
+                """,
+                bank_id,
+            )
+
+        def parse_metadata(metadata):
+            if metadata is None:
+                return {}
+            if isinstance(metadata, str):
+                import json
+
+                return json.loads(metadata)
+            return metadata
+
+        return [
+            {
+                "id": str(row["operation_id"]),
+                "task_type": row["operation_type"],
+                "items_count": parse_metadata(row["result_metadata"]).get("items_count", 0),
+                "document_id": parse_metadata(row["result_metadata"]).get("document_id"),
+                "created_at": row["created_at"].isoformat(),
+                "status": row["status"],
+                "error_message": row["error_message"],
+            }
+            for row in operations
+        ]
+
+    async def cancel_operation(
+        self,
+        bank_id: str,
+        operation_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Cancel a pending async operation."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        op_uuid = uuid.UUID(operation_id)
+
+        async with acquire_with_retry(pool) as conn:
+            # Check if operation exists and belongs to this memory bank
+            result = await conn.fetchrow(
+                f"SELECT bank_id FROM {fq_table('async_operations')} WHERE operation_id = $1 AND bank_id = $2",
+                op_uuid,
+                bank_id,
+            )
+
+            if not result:
+                raise ValueError(f"Operation {operation_id} not found for bank {bank_id}")
+
+            # Delete the operation
+            await conn.execute(f"DELETE FROM {fq_table('async_operations')} WHERE operation_id = $1", op_uuid)
+
+        return {
+            "success": True,
+            "message": f"Operation {operation_id} cancelled",
+            "operation_id": operation_id,
+            "bank_id": bank_id,
+        }
+
+    async def update_bank(
+        self,
+        bank_id: str,
+        *,
+        name: str | None = None,
+        background: str | None = None,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Update bank name and/or background."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            if name is not None:
+                await conn.execute(
+                    f"""
+                    UPDATE {fq_table("banks")}
+                    SET name = $2, updated_at = NOW()
+                    WHERE bank_id = $1
+                    """,
+                    bank_id,
+                    name,
+                )
+
+            if background is not None:
+                await conn.execute(
+                    f"""
+                    UPDATE {fq_table("banks")}
+                    SET background = $2, updated_at = NOW()
+                    WHERE bank_id = $1
+                    """,
+                    bank_id,
+                    background,
+                )
+
+        # Return updated profile
+        return await self.get_bank_profile(bank_id, request_context=request_context)
+
+    async def submit_async_retain(
+        self,
+        bank_id: str,
+        contents: list[dict[str, Any]],
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Submit a batch retain operation to run asynchronously."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        import json
+
+        operation_id = uuid.uuid4()
+
+        # Insert operation record into database
+        async with acquire_with_retry(pool) as conn:
+            await conn.execute(
+                f"""
+                INSERT INTO {fq_table("async_operations")} (operation_id, bank_id, operation_type, result_metadata)
+                VALUES ($1, $2, $3, $4)
+                """,
+                operation_id,
+                bank_id,
+                "retain",
+                json.dumps({"items_count": len(contents)}),
+            )
+
+        # Submit task to background queue
+        await self._task_backend.submit_task(
+            {
+                "type": "batch_retain",
+                "operation_id": str(operation_id),
+                "bank_id": bank_id,
+                "contents": contents,
+            }
+        )
+
+        logger.info(f"Retain task queued for bank_id={bank_id}, {len(contents)} items, operation_id={operation_id}")
+
+        return {
+            "operation_id": str(operation_id),
+            "items_count": len(contents),
+        }