hindsight-api 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package as they appear in the supported public registries, and is provided for informational purposes only.
- hindsight_api/admin/__init__.py +1 -0
- hindsight_api/admin/cli.py +311 -0
- hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
- hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
- hindsight_api/alembic/versions/h3c4d5e6f7g8_mental_models_v4.py +112 -0
- hindsight_api/alembic/versions/i4d5e6f7g8h9_delete_opinions.py +41 -0
- hindsight_api/alembic/versions/j5e6f7g8h9i0_mental_model_versions.py +95 -0
- hindsight_api/alembic/versions/k6f7g8h9i0j1_add_directive_subtype.py +58 -0
- hindsight_api/alembic/versions/l7g8h9i0j1k2_add_worker_columns.py +109 -0
- hindsight_api/alembic/versions/m8h9i0j1k2l3_mental_model_id_to_text.py +41 -0
- hindsight_api/alembic/versions/n9i0j1k2l3m4_learnings_and_pinned_reflections.py +134 -0
- hindsight_api/alembic/versions/o0j1k2l3m4n5_migrate_mental_models_data.py +113 -0
- hindsight_api/alembic/versions/p1k2l3m4n5o6_new_knowledge_architecture.py +194 -0
- hindsight_api/alembic/versions/q2l3m4n5o6p7_fix_mental_model_fact_type.py +50 -0
- hindsight_api/alembic/versions/r3m4n5o6p7q8_add_reflect_response_to_reflections.py +47 -0
- hindsight_api/alembic/versions/s4n5o6p7q8r9_add_consolidated_at_to_memory_units.py +53 -0
- hindsight_api/alembic/versions/t5o6p7q8r9s0_rename_mental_models_to_observations.py +134 -0
- hindsight_api/alembic/versions/u6p7q8r9s0t1_mental_models_text_id.py +41 -0
- hindsight_api/alembic/versions/v7q8r9s0t1u2_add_max_tokens_to_mental_models.py +50 -0
- hindsight_api/api/http.py +1406 -118
- hindsight_api/api/mcp.py +11 -196
- hindsight_api/config.py +359 -27
- hindsight_api/engine/consolidation/__init__.py +5 -0
- hindsight_api/engine/consolidation/consolidator.py +859 -0
- hindsight_api/engine/consolidation/prompts.py +69 -0
- hindsight_api/engine/cross_encoder.py +706 -88
- hindsight_api/engine/db_budget.py +284 -0
- hindsight_api/engine/db_utils.py +11 -0
- hindsight_api/engine/directives/__init__.py +5 -0
- hindsight_api/engine/directives/models.py +37 -0
- hindsight_api/engine/embeddings.py +553 -29
- hindsight_api/engine/entity_resolver.py +8 -5
- hindsight_api/engine/interface.py +40 -17
- hindsight_api/engine/llm_wrapper.py +744 -68
- hindsight_api/engine/memory_engine.py +2505 -1017
- hindsight_api/engine/mental_models/__init__.py +14 -0
- hindsight_api/engine/mental_models/models.py +53 -0
- hindsight_api/engine/query_analyzer.py +4 -3
- hindsight_api/engine/reflect/__init__.py +18 -0
- hindsight_api/engine/reflect/agent.py +933 -0
- hindsight_api/engine/reflect/models.py +109 -0
- hindsight_api/engine/reflect/observations.py +186 -0
- hindsight_api/engine/reflect/prompts.py +483 -0
- hindsight_api/engine/reflect/tools.py +437 -0
- hindsight_api/engine/reflect/tools_schema.py +250 -0
- hindsight_api/engine/response_models.py +168 -4
- hindsight_api/engine/retain/bank_utils.py +79 -201
- hindsight_api/engine/retain/fact_extraction.py +424 -195
- hindsight_api/engine/retain/fact_storage.py +35 -12
- hindsight_api/engine/retain/link_utils.py +29 -24
- hindsight_api/engine/retain/orchestrator.py +24 -43
- hindsight_api/engine/retain/types.py +11 -2
- hindsight_api/engine/search/graph_retrieval.py +43 -14
- hindsight_api/engine/search/link_expansion_retrieval.py +391 -0
- hindsight_api/engine/search/mpfp_retrieval.py +362 -117
- hindsight_api/engine/search/reranking.py +2 -2
- hindsight_api/engine/search/retrieval.py +848 -201
- hindsight_api/engine/search/tags.py +172 -0
- hindsight_api/engine/search/think_utils.py +42 -141
- hindsight_api/engine/search/trace.py +12 -1
- hindsight_api/engine/search/tracer.py +26 -6
- hindsight_api/engine/search/types.py +21 -3
- hindsight_api/engine/task_backend.py +113 -106
- hindsight_api/engine/utils.py +1 -152
- hindsight_api/extensions/__init__.py +10 -1
- hindsight_api/extensions/builtin/tenant.py +5 -1
- hindsight_api/extensions/context.py +10 -1
- hindsight_api/extensions/operation_validator.py +81 -4
- hindsight_api/extensions/tenant.py +26 -0
- hindsight_api/main.py +69 -6
- hindsight_api/mcp_local.py +12 -53
- hindsight_api/mcp_tools.py +494 -0
- hindsight_api/metrics.py +433 -48
- hindsight_api/migrations.py +141 -1
- hindsight_api/models.py +3 -3
- hindsight_api/pg0.py +53 -0
- hindsight_api/server.py +39 -2
- hindsight_api/worker/__init__.py +11 -0
- hindsight_api/worker/main.py +296 -0
- hindsight_api/worker/poller.py +486 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/METADATA +16 -6
- hindsight_api-0.4.0.dist-info/RECORD +112 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/entry_points.txt +2 -0
- hindsight_api/engine/retain/observation_regeneration.py +0 -254
- hindsight_api/engine/search/observation_utils.py +0 -125
- hindsight_api/engine/search/scoring.py +0 -159
- hindsight_api-0.2.1.dist-info/RECORD +0 -75
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/WHEEL +0 -0
|
@@ -11,6 +11,7 @@ This implements a sophisticated memory architecture that combines:
|
|
|
11
11
|
|
|
12
12
|
import asyncio
|
|
13
13
|
import contextvars
|
|
14
|
+
import json
|
|
14
15
|
import logging
|
|
15
16
|
import time
|
|
16
17
|
import uuid
|
|
@@ -18,6 +19,8 @@ from datetime import UTC, datetime, timedelta
|
|
|
18
19
|
from typing import TYPE_CHECKING, Any
|
|
19
20
|
|
|
20
21
|
from ..config import get_config
|
|
22
|
+
from ..metrics import get_metrics_collector
|
|
23
|
+
from .db_budget import budgeted_operation
|
|
21
24
|
|
|
22
25
|
# Context variable for current schema (async-safe, per-task isolation)
|
|
23
26
|
_current_schema: contextvars.ContextVar[str] = contextvars.ContextVar("current_schema", default="public")
|
|
@@ -132,17 +135,31 @@ if TYPE_CHECKING:
|
|
|
132
135
|
|
|
133
136
|
from enum import Enum
|
|
134
137
|
|
|
135
|
-
from ..
|
|
138
|
+
from ..metrics import get_metrics_collector
|
|
139
|
+
from ..pg0 import EmbeddedPostgres, parse_pg0_url
|
|
136
140
|
from .entity_resolver import EntityResolver
|
|
137
141
|
from .llm_wrapper import LLMConfig
|
|
138
142
|
from .query_analyzer import QueryAnalyzer
|
|
139
|
-
from .
|
|
143
|
+
from .reflect import run_reflect_agent
|
|
144
|
+
from .reflect.tools import tool_expand, tool_recall, tool_search_mental_models, tool_search_observations
|
|
145
|
+
from .response_models import (
|
|
146
|
+
VALID_RECALL_FACT_TYPES,
|
|
147
|
+
EntityObservation,
|
|
148
|
+
EntityState,
|
|
149
|
+
LLMCallTrace,
|
|
150
|
+
MemoryFact,
|
|
151
|
+
ObservationRef,
|
|
152
|
+
ReflectResult,
|
|
153
|
+
TokenUsage,
|
|
154
|
+
ToolCallTrace,
|
|
155
|
+
)
|
|
140
156
|
from .response_models import RecallResult as RecallResultModel
|
|
141
157
|
from .retain import bank_utils, embedding_utils
|
|
142
158
|
from .retain.types import RetainContentDict
|
|
143
|
-
from .search import
|
|
159
|
+
from .search import think_utils
|
|
144
160
|
from .search.reranking import CrossEncoderReranker
|
|
145
|
-
from .
|
|
161
|
+
from .search.tags import TagsMatch
|
|
162
|
+
from .task_backend import BrokerTaskBackend, SyncTaskBackend, TaskBackend
|
|
146
163
|
|
|
147
164
|
|
|
148
165
|
class Budget(str, Enum):
|
|
@@ -195,11 +212,26 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
195
212
|
memory_llm_api_key: str | None = None,
|
|
196
213
|
memory_llm_model: str | None = None,
|
|
197
214
|
memory_llm_base_url: str | None = None,
|
|
215
|
+
# Per-operation LLM config (optional, falls back to memory_llm_* params)
|
|
216
|
+
retain_llm_provider: str | None = None,
|
|
217
|
+
retain_llm_api_key: str | None = None,
|
|
218
|
+
retain_llm_model: str | None = None,
|
|
219
|
+
retain_llm_base_url: str | None = None,
|
|
220
|
+
reflect_llm_provider: str | None = None,
|
|
221
|
+
reflect_llm_api_key: str | None = None,
|
|
222
|
+
reflect_llm_model: str | None = None,
|
|
223
|
+
reflect_llm_base_url: str | None = None,
|
|
224
|
+
consolidation_llm_provider: str | None = None,
|
|
225
|
+
consolidation_llm_api_key: str | None = None,
|
|
226
|
+
consolidation_llm_model: str | None = None,
|
|
227
|
+
consolidation_llm_base_url: str | None = None,
|
|
198
228
|
embeddings: Embeddings | None = None,
|
|
199
229
|
cross_encoder: CrossEncoderModel | None = None,
|
|
200
230
|
query_analyzer: QueryAnalyzer | None = None,
|
|
201
|
-
pool_min_size: int =
|
|
202
|
-
pool_max_size: int =
|
|
231
|
+
pool_min_size: int | None = None,
|
|
232
|
+
pool_max_size: int | None = None,
|
|
233
|
+
db_command_timeout: int | None = None,
|
|
234
|
+
db_acquire_timeout: int | None = None,
|
|
203
235
|
task_backend: TaskBackend | None = None,
|
|
204
236
|
run_migrations: bool = True,
|
|
205
237
|
operation_validator: "OperationValidatorExtension | None" = None,
|
|
@@ -220,12 +252,26 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
220
252
|
memory_llm_api_key: API key for the LLM provider. Defaults to HINDSIGHT_API_LLM_API_KEY env var.
|
|
221
253
|
memory_llm_model: Model name. Defaults to HINDSIGHT_API_LLM_MODEL env var.
|
|
222
254
|
memory_llm_base_url: Base URL for the LLM API. Defaults based on provider.
|
|
255
|
+
retain_llm_provider: LLM provider for retain operations. Falls back to memory_llm_provider.
|
|
256
|
+
retain_llm_api_key: API key for retain LLM. Falls back to memory_llm_api_key.
|
|
257
|
+
retain_llm_model: Model for retain operations. Falls back to memory_llm_model.
|
|
258
|
+
retain_llm_base_url: Base URL for retain LLM. Falls back to memory_llm_base_url.
|
|
259
|
+
reflect_llm_provider: LLM provider for reflect operations. Falls back to memory_llm_provider.
|
|
260
|
+
reflect_llm_api_key: API key for reflect LLM. Falls back to memory_llm_api_key.
|
|
261
|
+
reflect_llm_model: Model for reflect operations. Falls back to memory_llm_model.
|
|
262
|
+
reflect_llm_base_url: Base URL for reflect LLM. Falls back to memory_llm_base_url.
|
|
263
|
+
consolidation_llm_provider: LLM provider for consolidation operations. Falls back to memory_llm_provider.
|
|
264
|
+
consolidation_llm_api_key: API key for consolidation LLM. Falls back to memory_llm_api_key.
|
|
265
|
+
consolidation_llm_model: Model for consolidation operations. Falls back to memory_llm_model.
|
|
266
|
+
consolidation_llm_base_url: Base URL for consolidation LLM. Falls back to memory_llm_base_url.
|
|
223
267
|
embeddings: Embeddings implementation. If not provided, created from env vars.
|
|
224
268
|
cross_encoder: Cross-encoder model. If not provided, created from env vars.
|
|
225
269
|
query_analyzer: Query analyzer implementation. If not provided, uses DateparserQueryAnalyzer.
|
|
226
|
-
pool_min_size: Minimum number of connections in the pool
|
|
227
|
-
pool_max_size: Maximum number of connections in the pool
|
|
228
|
-
|
|
270
|
+
pool_min_size: Minimum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MIN_SIZE.
|
|
271
|
+
pool_max_size: Maximum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MAX_SIZE.
|
|
272
|
+
db_command_timeout: PostgreSQL command timeout in seconds. Defaults to HINDSIGHT_API_DB_COMMAND_TIMEOUT.
|
|
273
|
+
db_acquire_timeout: Connection acquisition timeout in seconds. Defaults to HINDSIGHT_API_DB_ACQUIRE_TIMEOUT.
|
|
274
|
+
task_backend: Custom task backend. If not provided, uses BrokerTaskBackend for distributed processing.
|
|
229
275
|
run_migrations: Whether to run database migrations during initialize(). Default: True
|
|
230
276
|
operation_validator: Optional extension to validate operations before execution.
|
|
231
277
|
If provided, retain/recall/reflect operations will be validated.
|
|
@@ -252,38 +298,21 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
252
298
|
db_url = db_url or config.database_url
|
|
253
299
|
memory_llm_provider = memory_llm_provider or config.llm_provider
|
|
254
300
|
memory_llm_api_key = memory_llm_api_key or config.llm_api_key
|
|
255
|
-
# Ollama
|
|
256
|
-
if not memory_llm_api_key and memory_llm_provider
|
|
301
|
+
# Ollama and mock don't require an API key
|
|
302
|
+
if not memory_llm_api_key and memory_llm_provider not in ("ollama", "mock"):
|
|
257
303
|
raise ValueError("LLM API key is required. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
|
|
258
304
|
memory_llm_model = memory_llm_model or config.llm_model
|
|
259
305
|
memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
|
|
260
306
|
# Track pg0 instance (if used)
|
|
261
307
|
self._pg0: EmbeddedPostgres | None = None
|
|
262
|
-
self._pg0_instance_name: str | None = None
|
|
263
308
|
|
|
264
309
|
# Initialize PostgreSQL connection URL
|
|
265
310
|
# The actual URL will be set during initialize() after starting the server
|
|
266
311
|
# Supports: "pg0" (default instance), "pg0://instance-name" (named instance), or regular postgresql:// URL
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
self._pg0_instance_name = "hindsight"
|
|
270
|
-
self._pg0_port = None # Use default port
|
|
271
|
-
self.db_url = None
|
|
272
|
-
elif db_url.startswith("pg0://"):
|
|
273
|
-
self._use_pg0 = True
|
|
274
|
-
# Parse instance name and optional port: pg0://instance-name or pg0://instance-name:port
|
|
275
|
-
url_part = db_url[6:] # Remove "pg0://"
|
|
276
|
-
if ":" in url_part:
|
|
277
|
-
self._pg0_instance_name, port_str = url_part.rsplit(":", 1)
|
|
278
|
-
self._pg0_port = int(port_str)
|
|
279
|
-
else:
|
|
280
|
-
self._pg0_instance_name = url_part or "hindsight"
|
|
281
|
-
self._pg0_port = None # Use default port
|
|
312
|
+
self._use_pg0, self._pg0_instance_name, self._pg0_port = parse_pg0_url(db_url)
|
|
313
|
+
if self._use_pg0:
|
|
282
314
|
self.db_url = None
|
|
283
315
|
else:
|
|
284
|
-
self._use_pg0 = False
|
|
285
|
-
self._pg0_instance_name = None
|
|
286
|
-
self._pg0_port = None
|
|
287
316
|
self.db_url = db_url
|
|
288
317
|
|
|
289
318
|
# Set default base URL if not provided
|
|
@@ -298,8 +327,10 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
298
327
|
# Connection pool (will be created in initialize())
|
|
299
328
|
self._pool = None
|
|
300
329
|
self._initialized = False
|
|
301
|
-
self._pool_min_size = pool_min_size
|
|
302
|
-
self._pool_max_size = pool_max_size
|
|
330
|
+
self._pool_min_size = pool_min_size if pool_min_size is not None else config.db_pool_min_size
|
|
331
|
+
self._pool_max_size = pool_max_size if pool_max_size is not None else config.db_pool_max_size
|
|
332
|
+
self._db_command_timeout = db_command_timeout if db_command_timeout is not None else config.db_command_timeout
|
|
333
|
+
self._db_acquire_timeout = db_acquire_timeout if db_acquire_timeout is not None else config.db_acquire_timeout
|
|
303
334
|
self._run_migrations = run_migrations
|
|
304
335
|
|
|
305
336
|
# Initialize entity resolver (will be created in initialize())
|
|
@@ -319,7 +350,7 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
319
350
|
|
|
320
351
|
self.query_analyzer = DateparserQueryAnalyzer()
|
|
321
352
|
|
|
322
|
-
# Initialize LLM configuration
|
|
353
|
+
# Initialize LLM configuration (default, used as fallback)
|
|
323
354
|
self._llm_config = LLMConfig(
|
|
324
355
|
provider=memory_llm_provider,
|
|
325
356
|
api_key=memory_llm_api_key,
|
|
@@ -331,17 +362,84 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
331
362
|
self._llm_client = self._llm_config._client
|
|
332
363
|
self._llm_model = self._llm_config.model
|
|
333
364
|
|
|
365
|
+
# Initialize per-operation LLM configs (fall back to default if not specified)
|
|
366
|
+
# Retain LLM config - for fact extraction (benefits from strong structured output)
|
|
367
|
+
retain_provider = retain_llm_provider or config.retain_llm_provider or memory_llm_provider
|
|
368
|
+
retain_api_key = retain_llm_api_key or config.retain_llm_api_key or memory_llm_api_key
|
|
369
|
+
retain_model = retain_llm_model or config.retain_llm_model or memory_llm_model
|
|
370
|
+
retain_base_url = retain_llm_base_url or config.retain_llm_base_url or memory_llm_base_url
|
|
371
|
+
# Apply provider-specific base URL defaults for retain
|
|
372
|
+
if retain_base_url is None:
|
|
373
|
+
if retain_provider.lower() == "groq":
|
|
374
|
+
retain_base_url = "https://api.groq.com/openai/v1"
|
|
375
|
+
elif retain_provider.lower() == "ollama":
|
|
376
|
+
retain_base_url = "http://localhost:11434/v1"
|
|
377
|
+
else:
|
|
378
|
+
retain_base_url = ""
|
|
379
|
+
|
|
380
|
+
self._retain_llm_config = LLMConfig(
|
|
381
|
+
provider=retain_provider,
|
|
382
|
+
api_key=retain_api_key,
|
|
383
|
+
base_url=retain_base_url,
|
|
384
|
+
model=retain_model,
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
# Reflect LLM config - for think/observe operations (can use lighter models)
|
|
388
|
+
reflect_provider = reflect_llm_provider or config.reflect_llm_provider or memory_llm_provider
|
|
389
|
+
reflect_api_key = reflect_llm_api_key or config.reflect_llm_api_key or memory_llm_api_key
|
|
390
|
+
reflect_model = reflect_llm_model or config.reflect_llm_model or memory_llm_model
|
|
391
|
+
reflect_base_url = reflect_llm_base_url or config.reflect_llm_base_url or memory_llm_base_url
|
|
392
|
+
# Apply provider-specific base URL defaults for reflect
|
|
393
|
+
if reflect_base_url is None:
|
|
394
|
+
if reflect_provider.lower() == "groq":
|
|
395
|
+
reflect_base_url = "https://api.groq.com/openai/v1"
|
|
396
|
+
elif reflect_provider.lower() == "ollama":
|
|
397
|
+
reflect_base_url = "http://localhost:11434/v1"
|
|
398
|
+
else:
|
|
399
|
+
reflect_base_url = ""
|
|
400
|
+
|
|
401
|
+
self._reflect_llm_config = LLMConfig(
|
|
402
|
+
provider=reflect_provider,
|
|
403
|
+
api_key=reflect_api_key,
|
|
404
|
+
base_url=reflect_base_url,
|
|
405
|
+
model=reflect_model,
|
|
406
|
+
)
|
|
407
|
+
|
|
408
|
+
# Consolidation LLM config - for mental model consolidation (can use efficient models)
|
|
409
|
+
consolidation_provider = consolidation_llm_provider or config.consolidation_llm_provider or memory_llm_provider
|
|
410
|
+
consolidation_api_key = consolidation_llm_api_key or config.consolidation_llm_api_key or memory_llm_api_key
|
|
411
|
+
consolidation_model = consolidation_llm_model or config.consolidation_llm_model or memory_llm_model
|
|
412
|
+
consolidation_base_url = consolidation_llm_base_url or config.consolidation_llm_base_url or memory_llm_base_url
|
|
413
|
+
# Apply provider-specific base URL defaults for consolidation
|
|
414
|
+
if consolidation_base_url is None:
|
|
415
|
+
if consolidation_provider.lower() == "groq":
|
|
416
|
+
consolidation_base_url = "https://api.groq.com/openai/v1"
|
|
417
|
+
elif consolidation_provider.lower() == "ollama":
|
|
418
|
+
consolidation_base_url = "http://localhost:11434/v1"
|
|
419
|
+
else:
|
|
420
|
+
consolidation_base_url = ""
|
|
421
|
+
|
|
422
|
+
self._consolidation_llm_config = LLMConfig(
|
|
423
|
+
provider=consolidation_provider,
|
|
424
|
+
api_key=consolidation_api_key,
|
|
425
|
+
base_url=consolidation_base_url,
|
|
426
|
+
model=consolidation_model,
|
|
427
|
+
)
|
|
428
|
+
|
|
334
429
|
# Initialize cross-encoder reranker (cached for performance)
|
|
335
430
|
self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)
|
|
336
431
|
|
|
337
432
|
# Initialize task backend
|
|
338
|
-
|
|
433
|
+
# If no custom backend provided, use BrokerTaskBackend which stores tasks in PostgreSQL
|
|
434
|
+
# The pool_getter lambda will return the pool once it's initialized
|
|
435
|
+
self._task_backend = task_backend or BrokerTaskBackend(
|
|
436
|
+
pool_getter=lambda: self._pool,
|
|
437
|
+
schema_getter=get_current_schema,
|
|
438
|
+
)
|
|
339
439
|
|
|
340
440
|
# Backpressure mechanism: limit concurrent searches to prevent overwhelming the database
|
|
341
|
-
#
|
|
342
|
-
|
|
343
|
-
# we use ~20-40 connections max, staying well within pool limits
|
|
344
|
-
self._search_semaphore = asyncio.Semaphore(10)
|
|
441
|
+
# Configurable via HINDSIGHT_API_RECALL_MAX_CONCURRENT (default: 50)
|
|
442
|
+
self._search_semaphore = asyncio.Semaphore(get_config().recall_max_concurrent)
|
|
345
443
|
|
|
346
444
|
# Backpressure for put operations: limit concurrent puts to prevent database contention
|
|
347
445
|
# Each put_batch holds a connection for the entire transaction, so we limit to 5
|
|
@@ -401,35 +499,19 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
401
499
|
if request_context is None:
|
|
402
500
|
raise AuthenticationError("RequestContext is required when tenant extension is configured")
|
|
403
501
|
|
|
502
|
+
# For internal/background operations (e.g., worker tasks), skip extension authentication
|
|
503
|
+
# if the schema has already been set by execute_task via the _schema field.
|
|
504
|
+
if request_context.internal:
|
|
505
|
+
current = _current_schema.get()
|
|
506
|
+
if current and current != "public":
|
|
507
|
+
return current
|
|
508
|
+
|
|
404
509
|
# Let AuthenticationError propagate - HTTP layer will convert to 401
|
|
405
510
|
tenant_context = await self._tenant_extension.authenticate(request_context)
|
|
406
511
|
|
|
407
512
|
_current_schema.set(tenant_context.schema_name)
|
|
408
513
|
return tenant_context.schema_name
|
|
409
514
|
|
|
410
|
-
async def _handle_access_count_update(self, task_dict: dict[str, Any]):
|
|
411
|
-
"""
|
|
412
|
-
Handler for access count update tasks.
|
|
413
|
-
|
|
414
|
-
Args:
|
|
415
|
-
task_dict: Dict with 'node_ids' key containing list of node IDs to update
|
|
416
|
-
|
|
417
|
-
Raises:
|
|
418
|
-
Exception: Any exception from database operations (propagates to execute_task for retry)
|
|
419
|
-
"""
|
|
420
|
-
node_ids = task_dict.get("node_ids", [])
|
|
421
|
-
if not node_ids:
|
|
422
|
-
return
|
|
423
|
-
|
|
424
|
-
pool = await self._get_pool()
|
|
425
|
-
# Convert string UUIDs to UUID type for faster matching
|
|
426
|
-
uuid_list = [uuid.UUID(nid) for nid in node_ids]
|
|
427
|
-
async with acquire_with_retry(pool) as conn:
|
|
428
|
-
await conn.execute(
|
|
429
|
-
f"UPDATE {fq_table('memory_units')} SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
|
|
430
|
-
uuid_list,
|
|
431
|
-
)
|
|
432
|
-
|
|
433
515
|
async def _handle_batch_retain(self, task_dict: dict[str, Any]):
|
|
434
516
|
"""
|
|
435
517
|
Handler for batch retain tasks.
|
|
@@ -450,14 +532,113 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
450
532
|
f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items"
|
|
451
533
|
)
|
|
452
534
|
|
|
453
|
-
# Use internal request context for background tasks
|
|
535
|
+
# Use internal request context for background tasks (skips tenant auth when schema is pre-set)
|
|
454
536
|
from hindsight_api.models import RequestContext
|
|
455
537
|
|
|
456
|
-
internal_context = RequestContext()
|
|
538
|
+
internal_context = RequestContext(internal=True)
|
|
457
539
|
await self.retain_batch_async(bank_id=bank_id, contents=contents, request_context=internal_context)
|
|
458
540
|
|
|
459
541
|
logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
|
|
460
542
|
|
|
543
|
+
async def _handle_consolidation(self, task_dict: dict[str, Any]):
|
|
544
|
+
"""
|
|
545
|
+
Handler for consolidation tasks.
|
|
546
|
+
|
|
547
|
+
Consolidates new memories into mental models for a bank.
|
|
548
|
+
|
|
549
|
+
Args:
|
|
550
|
+
task_dict: Dict with 'bank_id'
|
|
551
|
+
|
|
552
|
+
Raises:
|
|
553
|
+
ValueError: If bank_id is missing
|
|
554
|
+
Exception: Any exception from consolidation (propagates to execute_task for retry)
|
|
555
|
+
"""
|
|
556
|
+
bank_id = task_dict.get("bank_id")
|
|
557
|
+
if not bank_id:
|
|
558
|
+
raise ValueError("bank_id is required for consolidation task")
|
|
559
|
+
|
|
560
|
+
from hindsight_api.models import RequestContext
|
|
561
|
+
|
|
562
|
+
from .consolidation import run_consolidation_job
|
|
563
|
+
|
|
564
|
+
internal_context = RequestContext(internal=True)
|
|
565
|
+
result = await run_consolidation_job(
|
|
566
|
+
memory_engine=self,
|
|
567
|
+
bank_id=bank_id,
|
|
568
|
+
request_context=internal_context,
|
|
569
|
+
)
|
|
570
|
+
|
|
571
|
+
logger.info(f"[CONSOLIDATION] bank={bank_id} completed: {result.get('memories_processed', 0)} processed")
|
|
572
|
+
|
|
573
|
+
async def _handle_refresh_mental_model(self, task_dict: dict[str, Any]):
|
|
574
|
+
"""
|
|
575
|
+
Handler for refresh_mental_model tasks.
|
|
576
|
+
|
|
577
|
+
Re-runs the source query through reflect and updates the mental model content.
|
|
578
|
+
|
|
579
|
+
Args:
|
|
580
|
+
task_dict: Dict with 'bank_id', 'mental_model_id', 'operation_id'
|
|
581
|
+
|
|
582
|
+
Raises:
|
|
583
|
+
ValueError: If required fields are missing
|
|
584
|
+
Exception: Any exception from reflect/update (propagates to execute_task for retry)
|
|
585
|
+
"""
|
|
586
|
+
bank_id = task_dict.get("bank_id")
|
|
587
|
+
mental_model_id = task_dict.get("mental_model_id")
|
|
588
|
+
|
|
589
|
+
if not bank_id or not mental_model_id:
|
|
590
|
+
raise ValueError("bank_id and mental_model_id are required for refresh_mental_model task")
|
|
591
|
+
|
|
592
|
+
logger.info(f"[REFRESH_MENTAL_MODEL_TASK] Starting for bank_id={bank_id}, mental_model_id={mental_model_id}")
|
|
593
|
+
|
|
594
|
+
from hindsight_api.models import RequestContext
|
|
595
|
+
|
|
596
|
+
internal_context = RequestContext(internal=True)
|
|
597
|
+
|
|
598
|
+
# Get the current mental model to get source_query
|
|
599
|
+
mental_model = await self.get_mental_model(bank_id, mental_model_id, request_context=internal_context)
|
|
600
|
+
if not mental_model:
|
|
601
|
+
raise ValueError(f"Mental model {mental_model_id} not found in bank {bank_id}")
|
|
602
|
+
|
|
603
|
+
source_query = mental_model["source_query"]
|
|
604
|
+
|
|
605
|
+
# Run reflect to generate new content, excluding the mental model being refreshed
|
|
606
|
+
reflect_result = await self.reflect_async(
|
|
607
|
+
bank_id=bank_id,
|
|
608
|
+
query=source_query,
|
|
609
|
+
request_context=internal_context,
|
|
610
|
+
exclude_mental_model_ids=[mental_model_id],
|
|
611
|
+
)
|
|
612
|
+
|
|
613
|
+
generated_content = reflect_result.text or "No content generated"
|
|
614
|
+
|
|
615
|
+
# Build reflect_response payload to store
|
|
616
|
+
reflect_response = {
|
|
617
|
+
"text": reflect_result.text,
|
|
618
|
+
"based_on": {
|
|
619
|
+
fact_type: [
|
|
620
|
+
{
|
|
621
|
+
"id": str(fact.id),
|
|
622
|
+
"text": fact.text,
|
|
623
|
+
"type": fact_type,
|
|
624
|
+
}
|
|
625
|
+
for fact in facts
|
|
626
|
+
]
|
|
627
|
+
for fact_type, facts in reflect_result.based_on.items()
|
|
628
|
+
},
|
|
629
|
+
}
|
|
630
|
+
|
|
631
|
+
# Update the mental model with the generated content and reflect_response
|
|
632
|
+
await self.update_mental_model(
|
|
633
|
+
bank_id=bank_id,
|
|
634
|
+
mental_model_id=mental_model_id,
|
|
635
|
+
content=generated_content,
|
|
636
|
+
reflect_response=reflect_response,
|
|
637
|
+
request_context=internal_context,
|
|
638
|
+
)
|
|
639
|
+
|
|
640
|
+
logger.info(f"[REFRESH_MENTAL_MODEL_TASK] Completed for bank_id={bank_id}, mental_model_id={mental_model_id}")
|
|
641
|
+
|
|
461
642
|
async def execute_task(self, task_dict: dict[str, Any]):
|
|
462
643
|
"""
|
|
463
644
|
Execute a task by routing it to the appropriate handler.
|
|
@@ -467,13 +648,18 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
467
648
|
|
|
468
649
|
Args:
|
|
469
650
|
task_dict: Task dictionary with 'type' key and other payload data
|
|
470
|
-
Example: {'type': '
|
|
651
|
+
Example: {'type': 'batch_retain', 'bank_id': '...', 'contents': [...]}
|
|
471
652
|
"""
|
|
472
653
|
task_type = task_dict.get("type")
|
|
473
654
|
operation_id = task_dict.get("operation_id")
|
|
474
655
|
retry_count = task_dict.get("retry_count", 0)
|
|
475
656
|
max_retries = 3
|
|
476
657
|
|
|
658
|
+
# Set schema context for multi-tenant task execution
|
|
659
|
+
schema = task_dict.pop("_schema", None)
|
|
660
|
+
if schema:
|
|
661
|
+
_current_schema.set(schema)
|
|
662
|
+
|
|
477
663
|
# Check if operation was cancelled (only for tasks with operation_id)
|
|
478
664
|
if operation_id:
|
|
479
665
|
try:
|
|
@@ -492,16 +678,12 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
492
678
|
# Continue with processing if we can't check status
|
|
493
679
|
|
|
494
680
|
try:
|
|
495
|
-
if task_type == "
|
|
496
|
-
await self._handle_access_count_update(task_dict)
|
|
497
|
-
elif task_type == "reinforce_opinion":
|
|
498
|
-
await self._handle_reinforce_opinion(task_dict)
|
|
499
|
-
elif task_type == "form_opinion":
|
|
500
|
-
await self._handle_form_opinion(task_dict)
|
|
501
|
-
elif task_type == "batch_retain":
|
|
681
|
+
if task_type == "batch_retain":
|
|
502
682
|
await self._handle_batch_retain(task_dict)
|
|
503
|
-
elif task_type == "
|
|
504
|
-
await self.
|
|
683
|
+
elif task_type == "consolidation":
|
|
684
|
+
await self._handle_consolidation(task_dict)
|
|
685
|
+
elif task_type == "refresh_mental_model":
|
|
686
|
+
await self._handle_refresh_mental_model(task_dict)
|
|
505
687
|
else:
|
|
506
688
|
logger.error(f"Unknown task type: {task_type}")
|
|
507
689
|
# Don't retry unknown task types
|
|
@@ -509,9 +691,9 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
509
691
|
await self._delete_operation_record(operation_id)
|
|
510
692
|
return
|
|
511
693
|
|
|
512
|
-
# Task succeeded -
|
|
694
|
+
# Task succeeded - mark operation as completed
|
|
513
695
|
if operation_id:
|
|
514
|
-
await self.
|
|
696
|
+
await self._mark_operation_completed(operation_id)
|
|
515
697
|
|
|
516
698
|
except Exception as e:
|
|
517
699
|
# Task failed - check if we should retry
|
|
@@ -557,7 +739,7 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
557
739
|
await conn.execute(
|
|
558
740
|
f"""
|
|
559
741
|
UPDATE {fq_table("async_operations")}
|
|
560
|
-
SET status = 'failed', error_message = $2
|
|
742
|
+
SET status = 'failed', error_message = $2, updated_at = NOW()
|
|
561
743
|
WHERE operation_id = $1
|
|
562
744
|
""",
|
|
563
745
|
uuid.UUID(operation_id),
|
|
@@ -567,6 +749,23 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
567
749
|
except Exception as e:
|
|
568
750
|
logger.error(f"Failed to mark operation as failed {operation_id}: {e}")
|
|
569
751
|
|
|
752
|
+
async def _mark_operation_completed(self, operation_id: str):
|
|
753
|
+
"""Helper to mark an operation as completed in the database."""
|
|
754
|
+
try:
|
|
755
|
+
pool = await self._get_pool()
|
|
756
|
+
async with acquire_with_retry(pool) as conn:
|
|
757
|
+
await conn.execute(
|
|
758
|
+
f"""
|
|
759
|
+
UPDATE {fq_table("async_operations")}
|
|
760
|
+
SET status = 'completed', updated_at = NOW(), completed_at = NOW()
|
|
761
|
+
WHERE operation_id = $1
|
|
762
|
+
""",
|
|
763
|
+
uuid.UUID(operation_id),
|
|
764
|
+
)
|
|
765
|
+
logger.info(f"Marked async operation as completed: {operation_id}")
|
|
766
|
+
except Exception as e:
|
|
767
|
+
logger.error(f"Failed to mark operation as completed {operation_id}: {e}")
|
|
768
|
+
|
|
570
769
|
async def initialize(self):
|
|
571
770
|
"""Initialize the connection pool, models, and background workers.
|
|
572
771
|
|
|
@@ -618,9 +817,44 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
618
817
|
await loop.run_in_executor(None, self.query_analyzer.load)
|
|
619
818
|
|
|
620
819
|
async def verify_llm():
|
|
621
|
-
"""Verify LLM
|
|
820
|
+
"""Verify LLM connections are working for all unique configs."""
|
|
622
821
|
if not self._skip_llm_verification:
|
|
822
|
+
# Verify default config
|
|
623
823
|
await self._llm_config.verify_connection()
|
|
824
|
+
# Verify retain config if different from default
|
|
825
|
+
retain_is_different = (
|
|
826
|
+
self._retain_llm_config.provider != self._llm_config.provider
|
|
827
|
+
or self._retain_llm_config.model != self._llm_config.model
|
|
828
|
+
)
|
|
829
|
+
if retain_is_different:
|
|
830
|
+
await self._retain_llm_config.verify_connection()
|
|
831
|
+
# Verify reflect config if different from default and retain
|
|
832
|
+
reflect_is_different = (
|
|
833
|
+
self._reflect_llm_config.provider != self._llm_config.provider
|
|
834
|
+
or self._reflect_llm_config.model != self._llm_config.model
|
|
835
|
+
) and (
|
|
836
|
+
self._reflect_llm_config.provider != self._retain_llm_config.provider
|
|
837
|
+
or self._reflect_llm_config.model != self._retain_llm_config.model
|
|
838
|
+
)
|
|
839
|
+
if reflect_is_different:
|
|
840
|
+
await self._reflect_llm_config.verify_connection()
|
|
841
|
+
# Verify consolidation config if different from all others
|
|
842
|
+
consolidation_is_different = (
|
|
843
|
+
(
|
|
844
|
+
self._consolidation_llm_config.provider != self._llm_config.provider
|
|
845
|
+
or self._consolidation_llm_config.model != self._llm_config.model
|
|
846
|
+
)
|
|
847
|
+
and (
|
|
848
|
+
self._consolidation_llm_config.provider != self._retain_llm_config.provider
|
|
849
|
+
or self._consolidation_llm_config.model != self._retain_llm_config.model
|
|
850
|
+
)
|
|
851
|
+
and (
|
|
852
|
+
self._consolidation_llm_config.provider != self._reflect_llm_config.provider
|
|
853
|
+
or self._consolidation_llm_config.model != self._reflect_llm_config.model
|
|
854
|
+
)
|
|
855
|
+
)
|
|
856
|
+
if consolidation_is_different:
|
|
857
|
+
await self._consolidation_llm_config.verify_connection()
|
|
624
858
|
|
|
625
859
|
# Build list of initialization tasks
|
|
626
860
|
init_tasks = [
|
|
@@ -642,13 +876,17 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
642
876
|
|
|
643
877
|
# Run database migrations if enabled
|
|
644
878
|
if self._run_migrations:
|
|
645
|
-
from ..migrations import run_migrations
|
|
879
|
+
from ..migrations import ensure_embedding_dimension, run_migrations
|
|
646
880
|
|
|
647
881
|
if not self.db_url:
|
|
648
882
|
raise ValueError("Database URL is required for migrations")
|
|
649
883
|
logger.info("Running database migrations...")
|
|
650
884
|
run_migrations(self.db_url)
|
|
651
885
|
|
|
886
|
+
# Ensure embedding column dimension matches the model's dimension
|
|
887
|
+
# This is done after migrations and after embeddings.initialize()
|
|
888
|
+
ensure_embedding_dimension(self.db_url, self.embeddings.dimension)
|
|
889
|
+
|
|
652
890
|
logger.info(f"Connecting to PostgreSQL at {self.db_url}")
|
|
653
891
|
|
|
654
892
|
# Create connection pool
|
|
@@ -658,9 +896,9 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
658
896
|
self.db_url,
|
|
659
897
|
min_size=self._pool_min_size,
|
|
660
898
|
max_size=self._pool_max_size,
|
|
661
|
-
command_timeout=
|
|
899
|
+
command_timeout=self._db_command_timeout,
|
|
662
900
|
statement_cache_size=0, # Disable prepared statement cache
|
|
663
|
-
timeout=
|
|
901
|
+
timeout=self._db_acquire_timeout, # Connection acquisition timeout (seconds)
|
|
664
902
|
)
|
|
665
903
|
|
|
666
904
|
# Initialize entity resolver with pool
|
|
@@ -743,8 +981,7 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
743
981
|
"""
|
|
744
982
|
Wait for all pending background tasks to complete.
|
|
745
983
|
|
|
746
|
-
This is useful in tests to ensure background tasks
|
|
747
|
-
complete before making assertions.
|
|
984
|
+
This is useful in tests to ensure background tasks complete before making assertions.
|
|
748
985
|
"""
|
|
749
986
|
if hasattr(self._task_backend, "wait_for_pending_tasks"):
|
|
750
987
|
await self._task_backend.wait_for_pending_tasks()
|
|
@@ -967,7 +1204,9 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
967
1204
|
document_id: str | None = None,
|
|
968
1205
|
fact_type_override: str | None = None,
|
|
969
1206
|
confidence_score: float | None = None,
|
|
970
|
-
|
|
1207
|
+
document_tags: list[str] | None = None,
|
|
1208
|
+
return_usage: bool = False,
|
|
1209
|
+
):
|
|
971
1210
|
"""
|
|
972
1211
|
Store multiple content items as memory units in ONE batch operation.
|
|
973
1212
|
|
|
@@ -988,9 +1227,11 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
988
1227
|
Applies the same document_id to ALL content items that don't specify their own.
|
|
989
1228
|
fact_type_override: Override fact type for all facts ('world', 'experience', 'opinion')
|
|
990
1229
|
confidence_score: Confidence score for opinions (0.0 to 1.0)
|
|
1230
|
+
return_usage: If True, returns tuple of (unit_ids, TokenUsage). Default False for backward compatibility.
|
|
991
1231
|
|
|
992
1232
|
Returns:
|
|
993
|
-
List of lists of unit IDs (one list per content item)
|
|
1233
|
+
If return_usage=False: List of lists of unit IDs (one list per content item)
|
|
1234
|
+
If return_usage=True: Tuple of (unit_ids, TokenUsage)
|
|
994
1235
|
|
|
995
1236
|
Example (new style - per-content document_id):
|
|
996
1237
|
unit_ids = await memory.retain_batch_async(
|
|
@@ -1017,6 +1258,8 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1017
1258
|
start_time = time.time()
|
|
1018
1259
|
|
|
1019
1260
|
if not contents:
|
|
1261
|
+
if return_usage:
|
|
1262
|
+
return [], TokenUsage()
|
|
1020
1263
|
return []
|
|
1021
1264
|
|
|
1022
1265
|
# Authenticate tenant and set schema in context (for fq_table())
|
|
@@ -1046,6 +1289,7 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1046
1289
|
# Auto-chunk large batches by character count to avoid timeouts and memory issues
|
|
1047
1290
|
# Calculate total character count
|
|
1048
1291
|
total_chars = sum(len(item.get("content", "")) for item in contents)
|
|
1292
|
+
total_usage = TokenUsage()
|
|
1049
1293
|
|
|
1050
1294
|
CHARS_PER_BATCH = 600_000
|
|
1051
1295
|
|
|
@@ -1078,7 +1322,7 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1078
1322
|
|
|
1079
1323
|
logger.info(f"Split into {len(sub_batches)} sub-batches: {[len(b) for b in sub_batches]} items each")
|
|
1080
1324
|
|
|
1081
|
-
# Process each sub-batch
|
|
1325
|
+
# Process each sub-batch
|
|
1082
1326
|
all_results = []
|
|
1083
1327
|
for i, sub_batch in enumerate(sub_batches, 1):
|
|
1084
1328
|
sub_batch_chars = sum(len(item.get("content", "")) for item in sub_batch)
|
|
@@ -1086,15 +1330,17 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1086
1330
|
f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars"
|
|
1087
1331
|
)
|
|
1088
1332
|
|
|
1089
|
-
sub_results = await self._retain_batch_async_internal(
|
|
1333
|
+
sub_results, sub_usage = await self._retain_batch_async_internal(
|
|
1090
1334
|
bank_id=bank_id,
|
|
1091
1335
|
contents=sub_batch,
|
|
1092
1336
|
document_id=document_id,
|
|
1093
1337
|
is_first_batch=i == 1, # Only upsert on first batch
|
|
1094
1338
|
fact_type_override=fact_type_override,
|
|
1095
1339
|
confidence_score=confidence_score,
|
|
1340
|
+
document_tags=document_tags,
|
|
1096
1341
|
)
|
|
1097
1342
|
all_results.extend(sub_results)
|
|
1343
|
+
total_usage = total_usage + sub_usage
|
|
1098
1344
|
|
|
1099
1345
|
total_time = time.time() - start_time
|
|
1100
1346
|
logger.info(
|
|
@@ -1103,13 +1349,14 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1103
1349
|
result = all_results
|
|
1104
1350
|
else:
|
|
1105
1351
|
# Small batch - use internal method directly
|
|
1106
|
-
result = await self._retain_batch_async_internal(
|
|
1352
|
+
result, total_usage = await self._retain_batch_async_internal(
|
|
1107
1353
|
bank_id=bank_id,
|
|
1108
1354
|
contents=contents,
|
|
1109
1355
|
document_id=document_id,
|
|
1110
1356
|
is_first_batch=True,
|
|
1111
1357
|
fact_type_override=fact_type_override,
|
|
1112
1358
|
confidence_score=confidence_score,
|
|
1359
|
+
document_tags=document_tags,
|
|
1113
1360
|
)
|
|
1114
1361
|
|
|
1115
1362
|
# Call post-operation hook if validator is configured
|
|
@@ -1132,6 +1379,19 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1132
1379
|
except Exception as e:
|
|
1133
1380
|
logger.warning(f"Post-retain hook error (non-fatal): {e}")
|
|
1134
1381
|
|
|
1382
|
+
# Trigger consolidation as a tracked async operation if enabled
|
|
1383
|
+
from ..config import get_config
|
|
1384
|
+
|
|
1385
|
+
config = get_config()
|
|
1386
|
+
if config.enable_observations:
|
|
1387
|
+
try:
|
|
1388
|
+
await self.submit_async_consolidation(bank_id=bank_id, request_context=request_context)
|
|
1389
|
+
except Exception as e:
|
|
1390
|
+
# Log but don't fail the retain - consolidation is non-critical
|
|
1391
|
+
logger.warning(f"Failed to submit consolidation task for bank {bank_id}: {e}")
|
|
1392
|
+
|
|
1393
|
+
if return_usage:
|
|
1394
|
+
return result, total_usage
|
|
1135
1395
|
return result
|
|
1136
1396
|
|
|
1137
1397
|
async def _retain_batch_async_internal(
|
|
@@ -1142,7 +1402,8 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1142
1402
|
is_first_batch: bool = True,
|
|
1143
1403
|
fact_type_override: str | None = None,
|
|
1144
1404
|
confidence_score: float | None = None,
|
|
1145
|
-
|
|
1405
|
+
document_tags: list[str] | None = None,
|
|
1406
|
+
) -> tuple[list[list[str]], "TokenUsage"]:
|
|
1146
1407
|
"""
|
|
1147
1408
|
Internal method for batch processing without chunking logic.
|
|
1148
1409
|
|
|
@@ -1158,6 +1419,10 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1158
1419
|
is_first_batch: Whether this is the first batch (for chunked operations, only delete on first batch)
|
|
1159
1420
|
fact_type_override: Override fact type for all facts
|
|
1160
1421
|
confidence_score: Confidence score for opinions
|
|
1422
|
+
document_tags: Tags applied to all items in this batch
|
|
1423
|
+
|
|
1424
|
+
Returns:
|
|
1425
|
+
Tuple of (unit ID lists, token usage for fact extraction)
|
|
1161
1426
|
"""
|
|
1162
1427
|
# Backpressure: limit concurrent retains to prevent database contention
|
|
1163
1428
|
async with self._put_semaphore:
|
|
@@ -1168,9 +1433,8 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1168
1433
|
return await orchestrator.retain_batch(
|
|
1169
1434
|
pool=pool,
|
|
1170
1435
|
embeddings_model=self.embeddings,
|
|
1171
|
-
llm_config=self.
|
|
1436
|
+
llm_config=self._retain_llm_config,
|
|
1172
1437
|
entity_resolver=self.entity_resolver,
|
|
1173
|
-
task_backend=self._task_backend,
|
|
1174
1438
|
format_date_fn=self._format_readable_date,
|
|
1175
1439
|
duplicate_checker_fn=self._find_duplicate_facts_batch,
|
|
1176
1440
|
bank_id=bank_id,
|
|
@@ -1179,6 +1443,7 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1179
1443
|
is_first_batch=is_first_batch,
|
|
1180
1444
|
fact_type_override=fact_type_override,
|
|
1181
1445
|
confidence_score=confidence_score,
|
|
1446
|
+
document_tags=document_tags,
|
|
1182
1447
|
)
|
|
1183
1448
|
|
|
1184
1449
|
def recall(
|
|
@@ -1237,6 +1502,10 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1237
1502
|
include_chunks: bool = False,
|
|
1238
1503
|
max_chunk_tokens: int = 8192,
|
|
1239
1504
|
request_context: "RequestContext",
|
|
1505
|
+
tags: list[str] | None = None,
|
|
1506
|
+
tags_match: TagsMatch = "any",
|
|
1507
|
+
_connection_budget: int | None = None,
|
|
1508
|
+
_quiet: bool = False,
|
|
1240
1509
|
) -> RecallResultModel:
|
|
1241
1510
|
"""
|
|
1242
1511
|
Recall memories using N*4-way parallel retrieval (N fact types × 4 retrieval methods).
|
|
@@ -1262,6 +1531,8 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1262
1531
|
max_entity_tokens: Maximum tokens for entity observations (default 500)
|
|
1263
1532
|
include_chunks: Whether to include raw chunks in the response
|
|
1264
1533
|
max_chunk_tokens: Maximum tokens for chunks (default 8192)
|
|
1534
|
+
tags: Optional list of tags for visibility filtering (OR matching - returns
|
|
1535
|
+
memories that have at least one matching tag)
|
|
1265
1536
|
|
|
1266
1537
|
Returns:
|
|
1267
1538
|
RecallResultModel containing:
|
|
@@ -1285,6 +1556,12 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1285
1556
|
f"Must be one of: {', '.join(sorted(VALID_RECALL_FACT_TYPES))}"
|
|
1286
1557
|
)
|
|
1287
1558
|
|
|
1559
|
+
# Filter out 'opinion' - opinions are no longer returned from recall
|
|
1560
|
+
fact_type = [ft for ft in fact_type if ft != "opinion"]
|
|
1561
|
+
if not fact_type:
|
|
1562
|
+
# All requested types were opinions - return empty result
|
|
1563
|
+
return RecallResultModel(results=[], entities={}, chunks={})
|
|
1564
|
+
|
|
1288
1565
|
# Validate operation if validator is configured
|
|
1289
1566
|
if self._operation_validator:
|
|
1290
1567
|
from hindsight_api.extensions import RecallContext
|
|
@@ -1310,10 +1587,17 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1310
1587
|
effective_budget = budget if budget is not None else Budget.MID
|
|
1311
1588
|
thinking_budget = budget_mapping[effective_budget]
|
|
1312
1589
|
|
|
1590
|
+
# Log recall start with tags if present (skip if quiet mode for internal operations)
|
|
1591
|
+
if not _quiet:
|
|
1592
|
+
tags_info = f", tags={tags} ({tags_match})" if tags else ""
|
|
1593
|
+
logger.info(f"[RECALL {bank_id[:8]}] Starting recall for query: {query[:50]}...{tags_info}")
|
|
1594
|
+
|
|
1313
1595
|
# Backpressure: limit concurrent recalls to prevent overwhelming the database
|
|
1314
1596
|
result = None
|
|
1315
1597
|
error_msg = None
|
|
1598
|
+
semaphore_wait_start = time.time()
|
|
1316
1599
|
async with self._search_semaphore:
|
|
1600
|
+
semaphore_wait = time.time() - semaphore_wait_start
|
|
1317
1601
|
# Retry loop for connection errors
|
|
1318
1602
|
max_retries = 3
|
|
1319
1603
|
for attempt in range(max_retries + 1):
|
|
@@ -1331,6 +1615,11 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1331
1615
|
include_chunks,
|
|
1332
1616
|
max_chunk_tokens,
|
|
1333
1617
|
request_context,
|
|
1618
|
+
semaphore_wait=semaphore_wait,
|
|
1619
|
+
tags=tags,
|
|
1620
|
+
tags_match=tags_match,
|
|
1621
|
+
connection_budget=_connection_budget,
|
|
1622
|
+
quiet=_quiet,
|
|
1334
1623
|
)
|
|
1335
1624
|
break # Success - exit retry loop
|
|
1336
1625
|
except Exception as e:
|
|
@@ -1448,6 +1737,11 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1448
1737
|
include_chunks: bool = False,
|
|
1449
1738
|
max_chunk_tokens: int = 8192,
|
|
1450
1739
|
request_context: "RequestContext" = None,
|
|
1740
|
+
semaphore_wait: float = 0.0,
|
|
1741
|
+
tags: list[str] | None = None,
|
|
1742
|
+
tags_match: TagsMatch = "any",
|
|
1743
|
+
connection_budget: int | None = None,
|
|
1744
|
+
quiet: bool = False,
|
|
1451
1745
|
) -> RecallResultModel:
|
|
1452
1746
|
"""
|
|
1453
1747
|
Search implementation with modular retrieval and reranking.
|
|
@@ -1477,7 +1771,9 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1477
1771
|
# Initialize tracer if requested
|
|
1478
1772
|
from .search.tracer import SearchTracer
|
|
1479
1773
|
|
|
1480
|
-
tracer =
|
|
1774
|
+
tracer = (
|
|
1775
|
+
SearchTracer(query, thinking_budget, max_tokens, tags=tags, tags_match=tags_match) if enable_trace else None
|
|
1776
|
+
)
|
|
1481
1777
|
if tracer:
|
|
1482
1778
|
tracer.start()
|
|
1483
1779
|
|
|
@@ -1487,8 +1783,9 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1487
1783
|
# Buffer logs for clean output in concurrent scenarios
|
|
1488
1784
|
recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
|
|
1489
1785
|
log_buffer = []
|
|
1786
|
+
tags_info = f", tags={tags}, tags_match={tags_match}" if tags else ""
|
|
1490
1787
|
log_buffer.append(
|
|
1491
|
-
f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})"
|
|
1788
|
+
f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens}{tags_info})"
|
|
1492
1789
|
)
|
|
1493
1790
|
|
|
1494
1791
|
try:
|
|
@@ -1502,37 +1799,70 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1502
1799
|
tracer.record_query_embedding(query_embedding)
|
|
1503
1800
|
tracer.add_phase_metric("generate_query_embedding", step_duration)
|
|
1504
1801
|
|
|
1505
|
-
# Step 2:
|
|
1802
|
+
# Step 2: Optimized parallel retrieval using batched queries
|
|
1803
|
+
# - Semantic + BM25 combined in 1 CTE query for ALL fact types
|
|
1804
|
+
# - Graph runs per fact type (complex traversal)
|
|
1805
|
+
# - Temporal runs per fact type (if constraint detected)
|
|
1506
1806
|
step_start = time.time()
|
|
1507
1807
|
query_embedding_str = str(query_embedding)
|
|
1508
1808
|
|
|
1509
|
-
from .search.retrieval import
|
|
1809
|
+
from .search.retrieval import (
|
|
1810
|
+
get_default_graph_retriever,
|
|
1811
|
+
retrieve_all_fact_types_parallel,
|
|
1812
|
+
)
|
|
1510
1813
|
|
|
1511
1814
|
# Track each retrieval start time
|
|
1512
1815
|
retrieval_start = time.time()
|
|
1513
1816
|
|
|
1514
|
-
# Run retrieval
|
|
1515
|
-
|
|
1516
|
-
|
|
1517
|
-
|
|
1817
|
+
# Run optimized retrieval with connection budget
|
|
1818
|
+
config = get_config()
|
|
1819
|
+
effective_connection_budget = (
|
|
1820
|
+
connection_budget if connection_budget is not None else config.recall_connection_budget
|
|
1821
|
+
)
|
|
1822
|
+
async with budgeted_operation(
|
|
1823
|
+
max_connections=effective_connection_budget,
|
|
1824
|
+
operation_id=f"recall-{recall_id}",
|
|
1825
|
+
) as op:
|
|
1826
|
+
budgeted_pool = op.wrap_pool(pool)
|
|
1827
|
+
parallel_start = time.time()
|
|
1828
|
+
multi_result = await retrieve_all_fact_types_parallel(
|
|
1829
|
+
budgeted_pool,
|
|
1830
|
+
query,
|
|
1831
|
+
query_embedding_str,
|
|
1832
|
+
bank_id,
|
|
1833
|
+
fact_type, # Pass all fact types at once
|
|
1834
|
+
thinking_budget,
|
|
1835
|
+
question_date,
|
|
1836
|
+
self.query_analyzer,
|
|
1837
|
+
tags=tags,
|
|
1838
|
+
tags_match=tags_match,
|
|
1518
1839
|
)
|
|
1519
|
-
|
|
1520
|
-
]
|
|
1521
|
-
all_retrievals = await asyncio.gather(*retrieval_tasks)
|
|
1840
|
+
parallel_duration = time.time() - parallel_start
|
|
1522
1841
|
|
|
1523
1842
|
# Combine all results from all fact types and aggregate timings
|
|
1524
1843
|
semantic_results = []
|
|
1525
1844
|
bm25_results = []
|
|
1526
1845
|
graph_results = []
|
|
1527
1846
|
temporal_results = []
|
|
1528
|
-
aggregated_timings = {
|
|
1847
|
+
aggregated_timings = {
|
|
1848
|
+
"semantic": 0.0,
|
|
1849
|
+
"bm25": 0.0,
|
|
1850
|
+
"graph": 0.0,
|
|
1851
|
+
"temporal": 0.0,
|
|
1852
|
+
"temporal_extraction": 0.0,
|
|
1853
|
+
}
|
|
1854
|
+
all_mpfp_timings = []
|
|
1529
1855
|
|
|
1530
1856
|
detected_temporal_constraint = None
|
|
1531
|
-
|
|
1857
|
+
max_conn_wait = multi_result.max_conn_wait
|
|
1858
|
+
for ft in fact_type:
|
|
1859
|
+
retrieval_result = multi_result.results_by_fact_type.get(ft)
|
|
1860
|
+
if not retrieval_result:
|
|
1861
|
+
continue
|
|
1862
|
+
|
|
1532
1863
|
# Log fact types in this retrieval batch
|
|
1533
|
-
ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
|
|
1534
1864
|
logger.debug(
|
|
1535
|
-
f"[RECALL {recall_id}] Fact type '{
|
|
1865
|
+
f"[RECALL {recall_id}] Fact type '{ft}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
|
|
1536
1866
|
)
|
|
1537
1867
|
|
|
1538
1868
|
semantic_results.extend(retrieval_result.semantic)
|
|
@@ -1570,6 +1900,7 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1570
1900
|
f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
|
|
1571
1901
|
f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
|
|
1572
1902
|
f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)",
|
|
1903
|
+
f"temporal_extraction={aggregated_timings['temporal_extraction']:.3f}s",
|
|
1573
1904
|
]
|
|
1574
1905
|
temporal_info = ""
|
|
1575
1906
|
if detected_temporal_constraint:
|
|
@@ -1578,9 +1909,41 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1578
1909
|
timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
|
|
1579
1910
|
temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
|
|
1580
1911
|
log_buffer.append(
|
|
1581
|
-
f" [2]
|
|
1912
|
+
f" [2] Parallel retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {parallel_duration:.3f}s{temporal_info}"
|
|
1582
1913
|
)
|
|
1583
1914
|
|
|
1915
|
+
# Log graph retriever timing breakdown if available
|
|
1916
|
+
if all_mpfp_timings:
|
|
1917
|
+
retriever_name = get_default_graph_retriever().name.upper()
|
|
1918
|
+
mpfp_total = all_mpfp_timings[0] # Take first fact type's timing as representative
|
|
1919
|
+
mpfp_parts = [
|
|
1920
|
+
f"db_queries={mpfp_total.db_queries}",
|
|
1921
|
+
f"edge_load={mpfp_total.edge_load_time:.3f}s",
|
|
1922
|
+
f"edges={mpfp_total.edge_count}",
|
|
1923
|
+
f"patterns={mpfp_total.pattern_count}",
|
|
1924
|
+
]
|
|
1925
|
+
if mpfp_total.seeds_time > 0.01:
|
|
1926
|
+
mpfp_parts.append(f"seeds={mpfp_total.seeds_time:.3f}s")
|
|
1927
|
+
if mpfp_total.fusion > 0.001:
|
|
1928
|
+
mpfp_parts.append(f"fusion={mpfp_total.fusion:.3f}s")
|
|
1929
|
+
if mpfp_total.fetch > 0.001:
|
|
1930
|
+
mpfp_parts.append(f"fetch={mpfp_total.fetch:.3f}s")
|
|
1931
|
+
log_buffer.append(f" [{retriever_name}] {', '.join(mpfp_parts)}")
|
|
1932
|
+
# Log detailed hop timing for debugging slow queries
|
|
1933
|
+
if mpfp_total.hop_details:
|
|
1934
|
+
for hd in mpfp_total.hop_details:
|
|
1935
|
+
log_buffer.append(
|
|
1936
|
+
f" hop{hd['hop']}: exec={hd.get('exec_time', 0) * 1000:.0f}ms, "
|
|
1937
|
+
f"uncached={hd.get('uncached_after_filter', 0)}, "
|
|
1938
|
+
f"load={hd.get('load_time', 0) * 1000:.0f}ms, "
|
|
1939
|
+
f"edges={hd.get('edges_loaded', 0)}"
|
|
1940
|
+
)
|
|
1941
|
+
|
|
1942
|
+
# Record temporal constraint in tracer if detected
|
|
1943
|
+
if tracer and detected_temporal_constraint:
|
|
1944
|
+
start_dt, end_dt = detected_temporal_constraint
|
|
1945
|
+
tracer.record_temporal_constraint(start_dt, end_dt)
|
|
1946
|
+
|
|
1584
1947
|
# Record retrieval results for tracer - per fact type
|
|
1585
1948
|
if tracer:
|
|
1586
1949
|
# Convert RetrievalResult to old tuple format for tracer
|
|
@@ -1588,8 +1951,10 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1588
1951
|
return [(r.id, r.__dict__) for r in results]
|
|
1589
1952
|
|
|
1590
1953
|
# Add retrieval results per fact type (to show parallel execution in UI)
|
|
1591
|
-
for
|
|
1592
|
-
|
|
1954
|
+
for ft_name in fact_type:
|
|
1955
|
+
rr = multi_result.results_by_fact_type.get(ft_name)
|
|
1956
|
+
if not rr:
|
|
1957
|
+
continue
|
|
1593
1958
|
|
|
1594
1959
|
# Add semantic retrieval results for this fact type
|
|
1595
1960
|
tracer.add_retrieval_results(
|
|
@@ -1621,14 +1986,22 @@ class MemoryEngine(MemoryEngineInterface):
|
|
|
1621
1986
|
fact_type=ft_name,
|
|
1622
1987
|
)
|
|
1623
1988
|
|
|
1624
|
-
# Add temporal retrieval results for this fact type
|
|
1625
|
-
|
|
1989
|
+
# Add temporal retrieval results for this fact type
|
|
1990
|
+
# Show temporal even with 0 results if constraint was detected
|
|
1991
|
+
if rr.temporal is not None or rr.temporal_constraint is not None:
|
|
1992
|
+
temporal_metadata = {"budget": thinking_budget}
|
|
1993
|
+
if rr.temporal_constraint:
|
|
1994
|
+
start_dt, end_dt = rr.temporal_constraint
|
|
1995
|
+
temporal_metadata["constraint"] = {
|
|
1996
|
+
"start": start_dt.isoformat() if start_dt else None,
|
|
1997
|
+
"end": end_dt.isoformat() if end_dt else None,
|
|
1998
|
+
}
|
|
1626
1999
|
tracer.add_retrieval_results(
|
|
1627
2000
|
method_name="temporal",
|
|
1628
|
-
results=to_tuple_format(rr.temporal),
|
|
2001
|
+
results=to_tuple_format(rr.temporal or []),
|
|
1629
2002
|
duration_seconds=rr.timings.get("temporal", 0.0),
|
|
1630
2003
|
score_field="temporal_score",
|
|
1631
|
-
metadata=
|
|
2004
|
+
metadata=temporal_metadata,
|
|
1632
2005
|
fact_type=ft_name,
|
|
1633
2006
|
)
|
|
1634
2007
|
|
|
@@ -1678,11 +2051,24 @@ class MemoryEngine(MemoryEngineInterface):
             # Ensure reranker is initialized (for lazy initialization mode)
             await reranker_instance.ensure_initialized()
 
+            # Pre-filter candidates to reduce reranking cost (RRF already provides good ranking)
+            # This is especially important for remote rerankers with network latency
+            reranker_max_candidates = get_config().reranker_max_candidates
+            pre_filtered_count = 0
+            if len(merged_candidates) > reranker_max_candidates:
+                # Sort by RRF score and take top candidates
+                merged_candidates.sort(key=lambda mc: mc.rrf_score, reverse=True)
+                pre_filtered_count = len(merged_candidates) - reranker_max_candidates
+                merged_candidates = merged_candidates[:reranker_max_candidates]
+
             # Rerank using cross-encoder
-            scored_results = reranker_instance.rerank(query, merged_candidates)
+            scored_results = await reranker_instance.rerank(query, merged_candidates)
 
             step_duration = time.time() - step_start
-
+            pre_filter_note = f" (pre-filtered {pre_filtered_count})" if pre_filtered_count > 0 else ""
+            log_buffer.append(
+                f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s{pre_filter_note}"
+            )
 
             # Step 4.5: Combine cross-encoder score with retrieval signals
             # This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
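The pre-filter above caps how many candidates reach the cross-encoder by keeping only the top entries by RRF score, which matters most when the reranker is a remote service with per-call latency. A minimal standalone sketch of the same idea (the `rrf_score` attribute is the only assumption carried over from the diff; the cap value is invented):

```python
from dataclasses import dataclass

@dataclass
class Candidate:
    id: str
    rrf_score: float  # fused retrieval score; assumed field name from the diff

def prefilter_for_rerank(candidates: list[Candidate], max_candidates: int = 128):
    """Keep the top-N candidates by RRF score and report how many were dropped."""
    if len(candidates) <= max_candidates:
        return candidates, 0
    ranked = sorted(candidates, key=lambda c: c.rrf_score, reverse=True)
    return ranked[:max_candidates], len(candidates) - max_candidates

kept, dropped = prefilter_for_rerank([Candidate("a", 0.9), Candidate("b", 0.4)], max_candidates=1)
assert [c.id for c in kept] == ["a"] and dropped == 1
```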
@@ -1786,7 +2172,6 @@ class MemoryEngine(MemoryEngineInterface):
                         text=sr.retrieval.text,
                         context=sr.retrieval.context or "",
                         event_date=sr.retrieval.occurred_start,
-                        access_count=sr.retrieval.access_count,
                         is_entry_point=(sr.id in [ep.node_id for ep in tracer.entry_points]),
                         parent_node_id=None,  # In parallel retrieval, there's no clear parent
                         link_type=None,
@@ -1798,12 +2183,6 @@ class MemoryEngine(MemoryEngineInterface):
                         final_weight=sr.weight,
                     )
 
-            # Step 8: Queue access count updates for visited nodes
-            visited_ids = list(set([sr.id for sr in scored_results[:50]]))  # Top 50
-            if visited_ids:
-                await self._task_backend.submit_task({"type": "access_count_update", "node_ids": visited_ids})
-                log_buffer.append(f" [7] Queued access count updates for {len(visited_ids)} nodes")
-
             # Log fact_type distribution in results
             fact_type_counts = {}
             for sr in top_scored:
@@ -1878,6 +2257,7 @@ class MemoryEngine(MemoryEngineInterface):
                         mentioned_at=result_dict.get("mentioned_at"),
                         document_id=result_dict.get("document_id"),
                         chunk_id=result_dict.get("chunk_id"),
+                        tags=result_dict.get("tags"),
                     )
                 )
 
@@ -1902,35 +2282,15 @@ class MemoryEngine(MemoryEngineInterface):
                     entities_ordered.append((entity_id, entity_name))
                     seen_entity_ids.add(entity_id)
 
-                #
+                # Return entities with empty observations (summaries now live in mental models)
                 entities_dict = {}
-                encoding = _get_tiktoken_encoding()
-
                 for entity_id, entity_name in entities_ordered:
-
-
-
-
-                        bank_id, entity_id, limit=5, request_context=request_context
+                    entities_dict[entity_name] = EntityState(
+                        entity_id=entity_id,
+                        canonical_name=entity_name,
+                        observations=[],  # Mental models provide this now
                     )
 
-                    # Calculate tokens for this entity's observations
-                    entity_tokens = 0
-                    included_observations = []
-                    for obs in observations:
-                        obs_tokens = len(encoding.encode(obs.text))
-                        if total_entity_tokens + entity_tokens + obs_tokens <= max_entity_tokens:
-                            included_observations.append(obs)
-                            entity_tokens += obs_tokens
-                        else:
-                            break
-
-                    if included_observations:
-                        entities_dict[entity_name] = EntityState(
-                            entity_id=entity_id, canonical_name=entity_name, observations=included_observations
-                        )
-                        total_entity_tokens += entity_tokens
-
                 # Fetch chunks if requested
                 chunks_dict = None
                 if include_chunks and top_scored:
@@ -2002,16 +2362,25 @@ class MemoryEngine(MemoryEngineInterface):
             total_time = time.time() - recall_start
             num_chunks = len(chunks_dict) if chunks_dict else 0
             num_entities = len(entities_dict) if entities_dict else 0
+            # Include wait times in log if significant
+            wait_parts = []
+            if semaphore_wait > 0.01:
+                wait_parts.append(f"sem={semaphore_wait:.3f}s")
+            if max_conn_wait > 0.01:
+                wait_parts.append(f"conn={max_conn_wait:.3f}s")
+            wait_info = f" | waits: {', '.join(wait_parts)}" if wait_parts else ""
             log_buffer.append(
-                f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s"
+                f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s{wait_info}"
             )
-
+            if not quiet:
+                logger.info("\n" + "\n".join(log_buffer))
 
             return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)
 
         except Exception as e:
             log_buffer.append(f"[RECALL {recall_id}] ERROR after {time.time() - recall_start:.3f}s: {str(e)}")
-
+            if not quiet:
+                logger.error("\n" + "\n".join(log_buffer))
             raise Exception(f"Failed to search memories: {str(e)}")
 
     def _filter_by_token_budget(
@@ -2073,11 +2442,11 @@ class MemoryEngine(MemoryEngineInterface):
             doc = await conn.fetchrow(
                 f"""
                 SELECT d.id, d.bank_id, d.original_text, d.content_hash,
-                       d.created_at, d.updated_at, COUNT(mu.id) as unit_count
+                       d.created_at, d.updated_at, d.tags, COUNT(mu.id) as unit_count
                 FROM {fq_table("documents")} d
                 LEFT JOIN {fq_table("memory_units")} mu ON mu.document_id = d.id
                 WHERE d.id = $1 AND d.bank_id = $2
-                GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
+                GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at, d.tags
                 """,
                 document_id,
                 bank_id,
@@ -2094,6 +2463,7 @@ class MemoryEngine(MemoryEngineInterface):
                 "memory_unit_count": doc["unit_count"],
                 "created_at": doc["created_at"].isoformat() if doc["created_at"] else None,
                 "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else None,
+                "tags": list(doc["tags"]) if doc["tags"] else [],
             }
 
     async def delete_document(
@@ -2118,10 +2488,12 @@ class MemoryEngine(MemoryEngineInterface):
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             async with conn.transaction():
-                #
-
-                    f"SELECT
+                # Get memory unit IDs before deletion (for mental model invalidation)
+                unit_rows = await conn.fetch(
+                    f"SELECT id FROM {fq_table('memory_units')} WHERE document_id = $1", document_id
                 )
+                unit_ids = [str(row["id"]) for row in unit_rows]
+                units_count = len(unit_ids)
 
                 # Delete document (cascades to memory_units and all their links)
                 deleted = await conn.fetchval(
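Note the ordering inside the transaction: the child `memory_units` IDs are captured before the cascading delete, because afterwards the rows are gone and nothing would be left to feed into mental-model invalidation. A hedged sketch of the pattern with plain asyncpg (the unqualified table names and the `invalidate` callback are placeholders, not the package's actual helpers):

```python
import asyncpg

async def delete_with_capture(conn: asyncpg.Connection, document_id, bank_id, invalidate):
    async with conn.transaction():
        # 1. Capture child IDs while the rows still exist.
        rows = await conn.fetch("SELECT id FROM memory_units WHERE document_id = $1", document_id)
        unit_ids = [str(r["id"]) for r in rows]
        # 2. Delete the parent; memory_units rows disappear via ON DELETE CASCADE.
        deleted = await conn.fetchval(
            "DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id",
            document_id, bank_id,
        )
        # 3. Invalidate stale references only if the delete actually happened.
        if deleted and unit_ids:
            await invalidate(conn, bank_id, unit_ids)
        return {"document_deleted": 1 if deleted else 0,
                "memory_units_deleted": len(unit_ids) if deleted else 0}
```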
@@ -2130,6 +2502,10 @@ class MemoryEngine(MemoryEngineInterface):
                     bank_id,
                 )
 
+                # Invalidate deleted fact IDs from mental models
+                if deleted and unit_ids:
+                    await self._invalidate_facts_from_mental_models(conn, bank_id, unit_ids)
+
                 return {"document_deleted": 1 if deleted else 0, "memory_units_deleted": units_count if deleted else 0}
 
     async def delete_memory_unit(
@@ -2157,11 +2533,18 @@ class MemoryEngine(MemoryEngineInterface):
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
             async with conn.transaction():
+                # Get bank_id before deletion (for mental model invalidation)
+                bank_id = await conn.fetchval(f"SELECT bank_id FROM {fq_table('memory_units')} WHERE id = $1", unit_id)
+
                 # Delete the memory unit (cascades to links and associations)
                 deleted = await conn.fetchval(
                     f"DELETE FROM {fq_table('memory_units')} WHERE id = $1 RETURNING id", unit_id
                 )
 
+                # Invalidate deleted fact ID from mental models
+                if deleted and bank_id:
+                    await self._invalidate_facts_from_mental_models(conn, bank_id, [str(deleted)])
+
                 return {
                     "success": deleted is not None,
                     "unit_id": str(deleted) if deleted else None,
@@ -2253,11 +2636,85 @@ class MemoryEngine(MemoryEngineInterface):
         except Exception as e:
             raise Exception(f"Failed to delete agent data: {str(e)}")
 
+    async def clear_observations(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, int]:
+        """
+        Clear all observations for a bank (consolidated knowledge).
+
+        Args:
+            bank_id: Bank ID to clear observations for
+            request_context: Request context for authentication.
+
+        Returns:
+            Dictionary with count of deleted observations
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+        async with acquire_with_retry(pool) as conn:
+            async with conn.transaction():
+                # Count observations before deletion
+                count = await conn.fetchval(
+                    f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = 'observation'",
+                    bank_id,
+                )
+
+                # Delete all observations
+                await conn.execute(
+                    f"DELETE FROM {fq_table('memory_units')} WHERE bank_id = $1 AND fact_type = 'observation'",
+                    bank_id,
+                )
+
+                # Reset consolidation timestamp
+                await conn.execute(
+                    f"UPDATE {fq_table('banks')} SET last_consolidated_at = NULL WHERE bank_id = $1",
+                    bank_id,
+                )
+
+                return {"deleted_count": count or 0}
+
+    async def run_consolidation(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, int]:
+        """
+        Run memory consolidation to create/update mental models.
+
+        Args:
+            bank_id: Bank ID to run consolidation for
+            request_context: Request context for authentication.
+
+        Returns:
+            Dictionary with consolidation stats
+        """
+        await self._authenticate_tenant(request_context)
+
+        from .consolidation import run_consolidation_job
+
+        result = await run_consolidation_job(
+            memory_engine=self,
+            bank_id=bank_id,
+            request_context=request_context,
+        )
+
+        return {
+            "processed": result.get("processed", 0),
+            "created": result.get("created", 0),
+            "updated": result.get("updated", 0),
+            "skipped": result.get("skipped", 0),
+        }
+
     async def get_graph_data(
         self,
         bank_id: str | None = None,
         fact_type: str | None = None,
         *,
+        limit: int = 1000,
         request_context: "RequestContext",
     ):
         """
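Together, `clear_observations` and `run_consolidation` give callers a reset-and-rebuild cycle for consolidated knowledge. A hypothetical caller (the engine instance and RequestContext value are assumed; only the two method signatures come from this diff):

```python
async def rebuild_consolidated_knowledge(engine, bank_id: str, request_context) -> None:
    # Drop all existing observations and reset last_consolidated_at ...
    cleared = await engine.clear_observations(bank_id, request_context=request_context)
    # ... then regenerate mental models from the remaining memories.
    stats = await engine.run_consolidation(bank_id, request_context=request_context)
    print(
        f"cleared={cleared['deleted_count']} processed={stats['processed']} "
        f"created={stats['created']} updated={stats['updated']} skipped={stats['skipped']}"
    )
```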
@@ -2266,10 +2723,11 @@ class MemoryEngine(MemoryEngineInterface):
         Args:
             bank_id: Filter by bank ID
             fact_type: Filter by fact type (world, experience, opinion)
+            limit: Maximum number of items to return (default: 1000)
             request_context: Request context for authentication.
 
         Returns:
-            Dict with nodes, edges, and
+            Dict with nodes, edges, table_rows, total_units, and limit
         """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
@@ -2291,21 +2749,46 @@ class MemoryEngine(MemoryEngineInterface):
 
             where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
 
+            # Get total count first
+            total_count_result = await conn.fetchrow(
+                f"""
+                SELECT COUNT(*) as total
+                FROM {fq_table("memory_units")}
+                {where_clause}
+                """,
+                *query_params,
+            )
+            total_count = total_count_result["total"] if total_count_result else 0
+
+            # Get units with limit
+            param_count += 1
             units = await conn.fetch(
                 f"""
-                SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
+                SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type, tags, created_at, proof_count, source_memory_ids
                 FROM {fq_table("memory_units")}
                 {where_clause}
                 ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
-                LIMIT
+                LIMIT ${param_count}
                 """,
                 *query_params,
+                limit,
             )
 
             # Get links, filtering to only include links between units of the selected agent
             # Use DISTINCT ON with LEAST/GREATEST to deduplicate bidirectional links
             unit_ids = [row["id"] for row in units]
-
+            unit_id_set = set(unit_ids)
+
+            # Collect source memory IDs from observations
+            source_memory_ids = []
+            for unit in units:
+                if unit["source_memory_ids"]:
+                    source_memory_ids.extend(unit["source_memory_ids"])
+            source_memory_ids = list(set(source_memory_ids))  # Deduplicate
+
+            # Fetch links involving both visible units AND source memories
+            all_relevant_ids = unit_ids + source_memory_ids
+            if all_relevant_ids:
                 links = await conn.fetch(
                     f"""
                     SELECT DISTINCT ON (LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid))
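The `param_count += 1` bookkeeping exists because asyncpg uses numbered placeholders ($1, $2, ...) that must line up one-to-one with the positional arguments, so the LIMIT placeholder has to be allocated after whatever the WHERE clause consumed. A self-contained sketch of the pattern (the filter columns are illustrative):

```python
def build_units_query(bank_id: str | None, fact_type: str | None, limit: int):
    conditions, params = [], []
    param_count = 0
    if bank_id is not None:
        param_count += 1
        conditions.append(f"bank_id = ${param_count}")
        params.append(bank_id)
    if fact_type is not None:
        param_count += 1
        conditions.append(f"fact_type = ${param_count}")
        params.append(fact_type)
    where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
    param_count += 1  # reserve the next numbered placeholder for LIMIT
    sql = f"SELECT id FROM memory_units {where_clause} LIMIT ${param_count}"
    params.append(limit)
    return sql, params

sql, params = build_units_query("bank-1", None, 1000)
assert sql == "SELECT id FROM memory_units WHERE bank_id = $1 LIMIT $2"
assert params == ["bank-1", 1000]
```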
@@ -2316,14 +2799,69 @@ class MemoryEngine(MemoryEngineInterface):
                            e.canonical_name as entity_name
                     FROM {fq_table("memory_links")} ml
                     LEFT JOIN {fq_table("entities")} e ON ml.entity_id = e.id
-                    WHERE ml.from_unit_id = ANY($1::uuid[])
+                    WHERE ml.from_unit_id = ANY($1::uuid[]) OR ml.to_unit_id = ANY($1::uuid[])
                     ORDER BY LEAST(ml.from_unit_id, ml.to_unit_id), GREATEST(ml.from_unit_id, ml.to_unit_id), ml.link_type, COALESCE(ml.entity_id, '00000000-0000-0000-0000-000000000000'::uuid), ml.weight DESC
                     """,
-
+                    all_relevant_ids,
                 )
             else:
                 links = []
 
+            # Copy links from source memories to observations
+            # Observations inherit links from their source memories via source_memory_ids
+            # Build a map from source_id to observation_ids
+            source_to_observations = {}
+            for unit in units:
+                if unit["source_memory_ids"]:
+                    for source_id in unit["source_memory_ids"]:
+                        if source_id not in source_to_observations:
+                            source_to_observations[source_id] = []
+                        source_to_observations[source_id].append(unit["id"])
+
+            copied_links = []
+            for link in links:
+                from_id = link["from_unit_id"]
+                to_id = link["to_unit_id"]
+
+                # Get observations that should inherit this link
+                from_observations = source_to_observations.get(from_id, [])
+                to_observations = source_to_observations.get(to_id, [])
+
+                # If from_id is a source memory, copy links to its observations
+                if from_observations:
+                    for obs_id in from_observations:
+                        # Only include if the target is visible
+                        if to_id in unit_id_set or to_observations:
+                            target = to_observations[0] if to_observations and to_id not in unit_id_set else to_id
+                            if target in unit_id_set:
+                                copied_links.append(
+                                    {
+                                        "from_unit_id": obs_id,
+                                        "to_unit_id": target,
+                                        "link_type": link["link_type"],
+                                        "weight": link["weight"],
+                                        "entity_name": link["entity_name"],
+                                    }
+                                )
+
+                # If to_id is a source memory, copy links to its observations
+                if to_observations and from_id in unit_id_set:
+                    for obs_id in to_observations:
+                        copied_links.append(
+                            {
+                                "from_unit_id": from_id,
+                                "to_unit_id": obs_id,
+                                "link_type": link["link_type"],
+                                "weight": link["weight"],
+                                "entity_name": link["entity_name"],
+                            }
+                        )
+
+            # Keep only direct links between visible nodes
+            direct_links = [
+                link for link in links if link["from_unit_id"] in unit_id_set and link["to_unit_id"] in unit_id_set
+            ]
+
             # Get entity information
             unit_entities = await conn.fetch(f"""
                 SELECT ue.unit_id, e.canonical_name
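The copy step projects links that touch hidden source memories onto the visible observations derived from them, so the rendered graph stays connected even though the sources themselves are not drawn. A toy walkthrough with invented IDs:

```python
# Observation "obs-1" was consolidated from source memory "src-9" (not rendered),
# and the stored graph has a link src-9 -> mem-2, where mem-2 is a visible unit.
source_to_observations = {"src-9": ["obs-1"]}
unit_id_set = {"obs-1", "mem-2"}
links = [{"from_unit_id": "src-9", "to_unit_id": "mem-2",
          "link_type": "semantic", "weight": 0.8, "entity_name": None}]

copied_links = []
for link in links:
    for obs_id in source_to_observations.get(link["from_unit_id"], []):
        if link["to_unit_id"] in unit_id_set:
            copied_links.append({**link, "from_unit_id": obs_id})

assert copied_links[0]["from_unit_id"] == "obs-1"  # rendered as obs-1 -> mem-2
```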
@@ -2341,6 +2879,18 @@ class MemoryEngine(MemoryEngineInterface):
                     entity_map[unit_id] = []
                 entity_map[unit_id].append(entity_name)
 
+            # For observations, inherit entities from source memories
+            for unit in units:
+                if unit["source_memory_ids"] and unit["id"] not in entity_map:
+                    # Collect entities from all source memories
+                    source_entities = []
+                    for source_id in unit["source_memory_ids"]:
+                        if source_id in entity_map:
+                            source_entities.extend(entity_map[source_id])
+                    if source_entities:
+                        # Deduplicate while preserving order
+                        entity_map[unit["id"]] = list(dict.fromkeys(source_entities))
+
             # Build nodes
             nodes = []
             for row in units:
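The `dict.fromkeys` idiom deduplicates while keeping first-seen order, which a plain `set` would not guarantee:

```python
names = ["Bob", "Alice", "Bob", "Carol", "Alice"]
assert list(dict.fromkeys(names)) == ["Bob", "Alice", "Carol"]  # first-seen order kept
```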
@@ -2374,14 +2924,15 @@ class MemoryEngine(MemoryEngineInterface):
                     }
                 )
 
-            # Build edges
+            # Build edges (combine direct links and copied links from sources)
             edges = []
-
+            all_links = direct_links + copied_links
+            for row in all_links:
                 from_id = str(row["from_unit_id"])
                 to_id = str(row["to_unit_id"])
                 link_type = row["link_type"]
                 weight = row["weight"]
-                entity_name = row
+                entity_name = row.get("entity_name")
 
                 # Color by link type
                 if link_type == "temporal":
@@ -2433,10 +2984,13 @@ class MemoryEngine(MemoryEngineInterface):
                     "document_id": row["document_id"],
                     "chunk_id": row["chunk_id"] if row["chunk_id"] else None,
                     "fact_type": row["fact_type"],
+                    "tags": list(row["tags"]) if row["tags"] else [],
+                    "created_at": row["created_at"].isoformat() if row["created_at"] else None,
+                    "proof_count": row["proof_count"] if row["proof_count"] else None,
                 }
             )
 
-        return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units":
+        return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": total_count, "limit": limit}
 
     async def list_memory_units(
         self,
@@ -2565,6 +3119,97 @@ class MemoryEngine(MemoryEngineInterface):
 
         return {"items": items, "total": total, "limit": limit, "offset": offset}
 
+    async def get_memory_unit(
+        self,
+        bank_id: str,
+        memory_id: str,
+        request_context: "RequestContext",
+    ):
+        """
+        Get a single memory unit by ID.
+
+        Args:
+            bank_id: Bank ID
+            memory_id: Memory unit ID
+            request_context: Request context for authentication.
+
+        Returns:
+            Dict with memory unit data or None if not found
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+        async with acquire_with_retry(pool) as conn:
+            # Get the memory unit (include source_memory_ids for mental models)
+            row = await conn.fetchrow(
+                f"""
+                SELECT id, text, context, event_date, occurred_start, occurred_end,
+                       mentioned_at, fact_type, document_id, chunk_id, tags, source_memory_ids
+                FROM {fq_table("memory_units")}
+                WHERE id = $1 AND bank_id = $2
+                """,
+                memory_id,
+                bank_id,
+            )
+
+            if not row:
+                return None
+
+            # Get entity information
+            entities_rows = await conn.fetch(
+                f"""
+                SELECT e.canonical_name
+                FROM {fq_table("unit_entities")} ue
+                JOIN {fq_table("entities")} e ON ue.entity_id = e.id
+                WHERE ue.unit_id = $1
+                """,
+                row["id"],
+            )
+            entities = [r["canonical_name"] for r in entities_rows]
+
+            result = {
+                "id": str(row["id"]),
+                "text": row["text"],
+                "context": row["context"] if row["context"] else "",
+                "date": row["event_date"].isoformat() if row["event_date"] else "",
+                "type": row["fact_type"],
+                "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
+                "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
+                "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
+                "entities": entities,
+                "document_id": row["document_id"] if row["document_id"] else None,
+                "chunk_id": str(row["chunk_id"]) if row["chunk_id"] else None,
+                "tags": row["tags"] if row["tags"] else [],
+            }
+
+            # For observations, include source_memory_ids and fetch source_memories
+            if row["fact_type"] == "observation" and row["source_memory_ids"]:
+                source_ids = row["source_memory_ids"]
+                result["source_memory_ids"] = [str(sid) for sid in source_ids]
+
+                # Fetch source memories
+                source_rows = await conn.fetch(
+                    f"""
+                    SELECT id, text, fact_type, context, occurred_start, mentioned_at
+                    FROM {fq_table("memory_units")}
+                    WHERE id = ANY($1::uuid[])
+                    ORDER BY mentioned_at DESC NULLS LAST
+                    """,
+                    source_ids,
+                )
+                result["source_memories"] = [
+                    {
+                        "id": str(r["id"]),
+                        "text": r["text"],
+                        "type": r["fact_type"],
+                        "context": r["context"],
+                        "occurred_start": r["occurred_start"].isoformat() if r["occurred_start"] else None,
+                        "mentioned_at": r["mentioned_at"].isoformat() if r["mentioned_at"] else None,
+                    }
+                    for r in source_rows
+                ]
+
+            return result
+
     async def list_documents(
         self,
         bank_id: str,
@@ -2741,264 +3386,24 @@ class MemoryEngine(MemoryEngineInterface):
             "created_at": chunk["created_at"].isoformat() if chunk["created_at"] else "",
         }
 
-
+    # ==================== bank profile Methods ====================
+
+    async def get_bank_profile(
         self,
-
-
-
-
-    ) -> dict[str, Any] | None:
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
         """
-
-
-        Args:
-            opinion_text: Current opinion text (includes reasons)
-            opinion_confidence: Current confidence score (0.0-1.0)
-            new_event_text: Text of the new event
-            entity_name: Name of the entity this opinion is about
-
-        Returns:
-            Dict with 'action' ('keep'|'update'), 'new_confidence', 'new_text' (if action=='update')
-            or None if no changes needed
-        """
-
-        class OpinionEvaluation(BaseModel):
-            """Evaluation of whether an opinion should be updated."""
-
-            action: str = Field(description="Action to take: 'keep' (no change) or 'update' (modify opinion)")
-            reasoning: str = Field(description="Brief explanation of why this action was chosen")
-            new_confidence: float = Field(
-                description="New confidence score (0.0-1.0). Can be higher, lower, or same as before."
-            )
-            new_opinion_text: str | None = Field(
-                default=None,
-                description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None.",
-            )
-
-        evaluation_prompt = f"""You are evaluating whether an existing opinion should be updated based on new information.
-
-ENTITY: {entity_name}
-
-EXISTING OPINION:
-{opinion_text}
-Current confidence: {opinion_confidence:.2f}
-
-NEW EVENT:
-{new_event_text}
-
-Evaluate whether this new event:
-1. REINFORCES the opinion (increase confidence, keep text)
-2. WEAKENS the opinion (decrease confidence, keep text)
-3. CHANGES the opinion (update both text and confidence, noting "Previously I thought X, but now Y...")
-4. IRRELEVANT (keep everything as is)
-
-Guidelines:
-- Only suggest 'update' action if the new event genuinely contradicts or significantly modifies the opinion
-- If updating the text, acknowledge the previous opinion and explain the change
-- Confidence should reflect accumulated evidence (0.0 = no confidence, 1.0 = very confident)
-- Small changes in confidence are normal; large jumps should be rare"""
-
-        try:
-            result = await self._llm_config.call(
-                messages=[
-                    {"role": "system", "content": "You evaluate and update opinions based on new information."},
-                    {"role": "user", "content": evaluation_prompt},
-                ],
-                response_format=OpinionEvaluation,
-                scope="memory_evaluate_opinion",
-                temperature=0.3,  # Lower temperature for more consistent evaluation
-            )
-
-            # Only return updates if something actually changed
-            if result.action == "keep" and abs(result.new_confidence - opinion_confidence) < 0.01:
-                return None
-
-            return {
-                "action": result.action,
-                "reasoning": result.reasoning,
-                "new_confidence": result.new_confidence,
-                "new_text": result.new_opinion_text if result.action == "update" else None,
-            }
-
-        except Exception as e:
-            logger.warning(f"Failed to evaluate opinion update: {str(e)}")
-            return None
-
-    async def _handle_form_opinion(self, task_dict: dict[str, Any]):
-        """
-        Handler for form opinion tasks.
-
-        Args:
-            task_dict: Dict with keys: 'bank_id', 'answer_text', 'query', 'tenant_id'
-        """
-        bank_id = task_dict["bank_id"]
-        answer_text = task_dict["answer_text"]
-        query = task_dict["query"]
-        tenant_id = task_dict.get("tenant_id")
-
-        await self._extract_and_store_opinions_async(
-            bank_id=bank_id, answer_text=answer_text, query=query, tenant_id=tenant_id
-        )
-
-    async def _handle_reinforce_opinion(self, task_dict: dict[str, Any]):
-        """
-        Handler for reinforce opinion tasks.
-
-        Args:
-            task_dict: Dict with keys: 'bank_id', 'created_unit_ids', 'unit_texts', 'unit_entities'
-        """
-        bank_id = task_dict["bank_id"]
-        created_unit_ids = task_dict["created_unit_ids"]
-        unit_texts = task_dict["unit_texts"]
-        unit_entities = task_dict["unit_entities"]
-
-        await self._reinforce_opinions_async(
-            bank_id=bank_id, created_unit_ids=created_unit_ids, unit_texts=unit_texts, unit_entities=unit_entities
-        )
-
-    async def _reinforce_opinions_async(
-        self,
-        bank_id: str,
-        created_unit_ids: list[str],
-        unit_texts: list[str],
-        unit_entities: list[list[dict[str, str]]],
-    ):
-        """
-        Background task to reinforce opinions based on newly ingested events.
-
-        This runs asynchronously and does not block the put operation.
-
-        Args:
-            bank_id: bank ID
-            created_unit_ids: List of newly created memory unit IDs
-            unit_texts: Texts of the newly created units
-            unit_entities: Entities extracted from each unit
-        """
-        try:
-            # Extract all unique entity names from the new units
-            entity_names = set()
-            for entities_list in unit_entities:
-                for entity in entities_list:
-                    # Handle both Entity objects and dicts
-                    if hasattr(entity, "text"):
-                        entity_names.add(entity.text)
-                    elif isinstance(entity, dict):
-                        entity_names.add(entity["text"])
-
-            if not entity_names:
-                return
-
-            pool = await self._get_pool()
-            async with acquire_with_retry(pool) as conn:
-                # Find all opinions related to these entities
-                opinions = await conn.fetch(
-                    f"""
-                    SELECT DISTINCT mu.id, mu.text, mu.confidence_score, e.canonical_name
-                    FROM {fq_table("memory_units")} mu
-                    JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
-                    JOIN {fq_table("entities")} e ON ue.entity_id = e.id
-                    WHERE mu.bank_id = $1
-                    AND mu.fact_type = 'opinion'
-                    AND e.canonical_name = ANY($2::text[])
-                    """,
-                    bank_id,
-                    list(entity_names),
-                )
-
-                if not opinions:
-                    return
-
-                # Use cached LLM config
-                if self._llm_config is None:
-                    logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
-                    return
-
-                # Evaluate each opinion against the new events
-                updates_to_apply = []
-                for opinion in opinions:
-                    opinion_id = str(opinion["id"])
-                    opinion_text = opinion["text"]
-                    opinion_confidence = opinion["confidence_score"]
-                    entity_name = opinion["canonical_name"]
-
-                    # Find all new events mentioning this entity
-                    relevant_events = []
-                    for unit_text, entities_list in zip(unit_texts, unit_entities):
-                        if any(e["text"] == entity_name for e in entities_list):
-                            relevant_events.append(unit_text)
-
-                    if not relevant_events:
-                        continue
-
-                    # Combine all relevant events
-                    combined_events = "\n".join(relevant_events)
-
-                    # Evaluate if opinion should be updated
-                    evaluation = await self._evaluate_opinion_update_async(
-                        opinion_text, opinion_confidence, combined_events, entity_name
-                    )
-
-                    if evaluation:
-                        updates_to_apply.append({"opinion_id": opinion_id, "evaluation": evaluation})
-
-                # Apply all updates in a single transaction
-                if updates_to_apply:
-                    async with conn.transaction():
-                        for update in updates_to_apply:
-                            opinion_id = update["opinion_id"]
-                            evaluation = update["evaluation"]
-
-                            if evaluation["action"] == "update" and evaluation["new_text"]:
-                                # Update both text and confidence
-                                await conn.execute(
-                                    f"""
-                                    UPDATE {fq_table("memory_units")}
-                                    SET text = $1, confidence_score = $2, updated_at = NOW()
-                                    WHERE id = $3
-                                    """,
-                                    evaluation["new_text"],
-                                    evaluation["new_confidence"],
-                                    uuid.UUID(opinion_id),
-                                )
-                            else:
-                                # Only update confidence
-                                await conn.execute(
-                                    f"""
-                                    UPDATE {fq_table("memory_units")}
-                                    SET confidence_score = $1, updated_at = NOW()
-                                    WHERE id = $2
-                                    """,
-                                    evaluation["new_confidence"],
-                                    uuid.UUID(opinion_id),
-                                )
-
-                else:
-                    pass  # No opinions to update
-
-        except Exception as e:
-            logger.error(f"[REINFORCE] Error during opinion reinforcement: {str(e)}")
-            import traceback
-
-            traceback.print_exc()
-
-    # ==================== bank profile Methods ====================
-
-    async def get_bank_profile(
-        self,
-        bank_id: str,
-        *,
-        request_context: "RequestContext",
-    ) -> dict[str, Any]:
-        """
-        Get bank profile (name, disposition + background).
-        Auto-creates agent with default values if not exists.
+        Get bank profile (name, disposition + mission).
+        Auto-creates agent with default values if not exists.
 
         Args:
             bank_id: bank IDentifier
             request_context: Request context for authentication.
 
         Returns:
-            Dict with name, disposition traits, and
+            Dict with name, disposition traits, and mission
         """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
@@ -3008,7 +3413,7 @@ Guidelines:
             "bank_id": bank_id,
             "name": profile["name"],
             "disposition": disposition,
-            "
+            "mission": profile["mission"],
         }
 
     async def update_bank_disposition(
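After the background-to-mission rename, the profile payload looks roughly like this (values invented; the disposition trait keys follow the Big Five framing mentioned in the removed docstrings and are an assumption):

```python
profile = {
    "bank_id": "bank-1",
    "name": "Support Agent",
    "disposition": {"openness": 0.7, "conscientiousness": 0.8},  # assumed trait keys
    "mission": "I help users debug their integrations.",
}
```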
@@ -3030,31 +3435,51 @@ Guidelines:
         pool = await self._get_pool()
         await bank_utils.update_bank_disposition(pool, bank_id, disposition)
 
-    async def
+    async def set_bank_mission(
+        self,
+        bank_id: str,
+        mission: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """
+        Set the mission for a bank.
+
+        Args:
+            bank_id: bank IDentifier
+            mission: The mission text
+            request_context: Request context for authentication.
+
+        Returns:
+            Dict with bank_id and mission.
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+        await bank_utils.set_bank_mission(pool, bank_id, mission)
+        return {"bank_id": bank_id, "mission": mission}
+
+    async def merge_bank_mission(
         self,
         bank_id: str,
         new_info: str,
         *,
-        update_disposition: bool = True,
         request_context: "RequestContext",
     ) -> dict[str, Any]:
         """
-        Merge new
+        Merge new mission information with existing mission using LLM.
         Normalizes to first person ("I") and resolves conflicts.
-        Optionally infers disposition traits from the merged background.
 
         Args:
             bank_id: bank IDentifier
-            new_info: New
-            update_disposition: If True, infer Big Five traits from background (default: True)
+            new_info: New mission information to add/merge
             request_context: Request context for authentication.
 
         Returns:
-            Dict with '
+            Dict with 'mission' (str) key
         """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
-        return await bank_utils.
+        return await bank_utils.merge_bank_mission(pool, self._reflect_llm_config, bank_id, new_info)
 
     async def list_banks(
         self,
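The two mission methods cover the overwrite and merge cases. A hypothetical caller (only the signatures come from this diff; the engine instance and RequestContext are assumed):

```python
async def update_mission(engine, bank_id: str, request_context) -> str:
    # Replace the mission outright ...
    await engine.set_bank_mission(bank_id, "I track release regressions.", request_context=request_context)
    # ... or let the LLM fold new information into the existing text.
    merged = await engine.merge_bank_mission(
        bank_id, "Also watch for performance regressions.", request_context=request_context
    )
    return merged["mission"]
```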
@@ -3068,7 +3493,7 @@ Guidelines:
             request_context: Request context for authentication.
 
         Returns:
-            List of dicts with bank_id, name, disposition,
+            List of dicts with bank_id, name, disposition, mission, created_at, updated_at
         """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
@@ -3086,35 +3511,44 @@ Guidelines:
         max_tokens: int = 4096,
         response_schema: dict | None = None,
         request_context: "RequestContext",
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
+        exclude_mental_model_ids: list[str] | None = None,
     ) -> ReflectResult:
         """
-        Reflect and formulate an answer using
+        Reflect and formulate an answer using an agentic loop with tools.
 
-
-        1.
-        2.
-        3.
-        4.
-
-
-
+        The reflect agent iteratively uses tools to:
+        1. lookup: Get mental models (synthesized knowledge)
+        2. recall: Search facts (semantic + temporal retrieval)
+        3. learn: Create/update mental models with new insights
+        4. expand: Get chunk/document context for memories
+
+        The agent starts with empty context and must call tools to gather
+        information. On the last iteration, tools are removed to force a
+        final text response.
 
         Args:
             bank_id: bank identifier
             query: Question to answer
-            budget: Budget level
-            context: Additional context string to include in
-
+            budget: Budget level (currently unused, reserved for future)
+            context: Additional context string to include in agent prompt
+            max_tokens: Max tokens (currently unused, reserved for future)
+            response_schema: Optional JSON Schema for structured output (not yet supported)
+            tags: Optional tags to filter memories
+            tags_match: How to match tags - "any" (OR), "all" (AND)
+            exclude_mental_model_ids: Optional list of mental model IDs to exclude from search
+                (used when refreshing a mental model to avoid circular reference)
 
         Returns:
             ReflectResult containing:
-            - text: Plain text answer
-            - based_on:
-            - new_opinions:
-            - structured_output:
+            - text: Plain text answer
+            - based_on: Empty dict (agent retrieves facts dynamically)
+            - new_opinions: Empty list
+            - structured_output: None (not yet supported for agentic reflect)
         """
         # Use cached LLM config
-        if self.
+        if self._reflect_llm_config is None:
             raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
 
         # Authenticate tenant and set schema in context (for fq_table())
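A hypothetical invocation showing the new filter parameters (the method is assumed to be exposed as `reflect(...)` on the engine; argument values are illustrative; per the docstring, "all" requires every tag to match while "any" matches on overlap):

```python
async def ask(engine, request_context):
    result = await engine.reflect(
        bank_id="bank-1",
        query="What did we learn about checkout latency?",
        tags=["checkout", "latency"],
        tags_match="all",  # both tags must be present; "any" would OR them
        request_context=request_context,
    )
    return result.text
```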
@@ -3135,121 +3569,312 @@ Guidelines:
|
|
|
3135
3569
|
|
|
3136
3570
|
reflect_start = time.time()
|
|
3137
3571
|
reflect_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
|
|
3138
|
-
|
|
3139
|
-
|
|
3572
|
+
tags_info = f", tags={tags} ({tags_match})" if tags else ""
|
|
3573
|
+
logger.info(f"[REFLECT {reflect_id}] Starting agentic reflect for query: {query[:50]}...{tags_info}")
|
|
3140
3574
|
|
|
3141
|
-
#
|
|
3142
|
-
|
|
3143
|
-
|
|
3144
|
-
|
|
3145
|
-
|
|
3146
|
-
|
|
3147
|
-
|
|
3148
|
-
|
|
3149
|
-
|
|
3150
|
-
|
|
3151
|
-
|
|
3152
|
-
|
|
3153
|
-
|
|
3575
|
+
# Get bank profile for agent identity
|
|
3576
|
+
profile = await self.get_bank_profile(bank_id, request_context=request_context)
|
|
3577
|
+
|
|
3578
|
+
# NOTE: Mental models are NOT pre-loaded to keep the initial prompt small.
|
|
3579
|
+
# The agent can call lookup() to list available models if needed.
|
|
3580
|
+
# This is critical for banks with many mental models to avoid huge prompts.
|
|
3581
|
+
|
|
3582
|
+
# Compute max iterations based on budget
|
|
3583
|
+
config = get_config()
|
|
3584
|
+
base_max_iterations = config.reflect_max_iterations
|
|
3585
|
+
# Budget multipliers: low=0.5x, mid=1x, high=2x
|
|
3586
|
+
budget_multipliers = {Budget.LOW: 0.5, Budget.MID: 1.0, Budget.HIGH: 2.0}
|
|
3587
|
+
effective_budget = budget or Budget.LOW
|
|
3588
|
+
max_iterations = max(1, int(base_max_iterations * budget_multipliers.get(effective_budget, 1.0)))
|
|
3589
|
+
|
|
3590
|
+
# Run agentic loop - acquire connections only when needed for DB operations
|
|
3591
|
+
# (not held during LLM calls which can be slow)
|
|
3592
|
+
pool = await self._get_pool()
|
|
3154
3593
|
|
|
3155
|
-
|
|
3594
|
+
# Get bank stats for freshness info
|
|
3595
|
+
bank_stats = await self.get_bank_stats(bank_id, request_context=request_context)
|
|
3596
|
+
last_consolidated_at = bank_stats.last_consolidated_at if hasattr(bank_stats, "last_consolidated_at") else None
|
|
3597
|
+
pending_consolidation = bank_stats.pending_consolidation if hasattr(bank_stats, "pending_consolidation") else 0
|
|
3156
3598
|
|
|
3157
|
-
#
|
|
3158
|
-
|
|
3159
|
-
world_results = [r for r in all_results if r.fact_type == "world"]
|
|
3160
|
-
opinion_results = [r for r in all_results if r.fact_type == "opinion"]
|
|
3599
|
+
# Create tool callbacks that acquire connections only when needed
|
|
3600
|
+
from .retain import embedding_utils
|
|
3161
3601
|
|
|
3162
|
-
|
|
3163
|
-
|
|
3602
|
+
async def search_mental_models_fn(q: str, max_results: int = 5) -> dict[str, Any]:
|
|
3603
|
+
# Generate embedding for the query
|
|
3604
|
+
embeddings = await embedding_utils.generate_embeddings_batch(self.embeddings, [q])
|
|
3605
|
+
query_embedding = embeddings[0]
|
|
3606
|
+
async with pool.acquire() as conn:
|
|
3607
|
+
return await tool_search_mental_models(
|
|
3608
|
+
conn,
|
|
3609
|
+
bank_id,
|
|
3610
|
+
q,
|
|
3611
|
+
query_embedding,
|
|
3612
|
+
max_results=max_results,
|
|
3613
|
+
tags=tags,
|
|
3614
|
+
tags_match=tags_match,
|
|
3615
|
+
exclude_ids=exclude_mental_model_ids,
|
|
3616
|
+
)
|
|
3617
|
+
|
|
3618
|
+
async def search_observations_fn(q: str, max_tokens: int = 5000) -> dict[str, Any]:
|
|
3619
|
+
return await tool_search_observations(
|
|
3620
|
+
self,
|
|
3621
|
+
bank_id,
|
|
3622
|
+
q,
|
|
3623
|
+
request_context,
|
|
3624
|
+
max_tokens=max_tokens,
|
|
3625
|
+
tags=tags,
|
|
3626
|
+
tags_match=tags_match,
|
|
3627
|
+
last_consolidated_at=last_consolidated_at,
|
|
3628
|
+
pending_consolidation=pending_consolidation,
|
|
3629
|
+
)
|
|
3630
|
+
|
|
3631
|
+
async def recall_fn(q: str, max_tokens: int = 4096) -> dict[str, Any]:
|
|
3632
|
+
return await tool_recall(
|
|
3633
|
+
self, bank_id, q, request_context, max_tokens=max_tokens, tags=tags, tags_match=tags_match
|
|
3634
|
+
)
|
|
3635
|
+
|
|
3636
|
+
async def expand_fn(memory_ids: list[str], depth: str) -> dict[str, Any]:
|
|
3637
|
+
async with pool.acquire() as conn:
|
|
3638
|
+
return await tool_expand(conn, bank_id, memory_ids, depth)
|
|
3639
|
+
|
|
3640
|
+
# Load directives from the dedicated directives table
|
|
3641
|
+
# Directives are hard rules that must be followed in all responses
|
|
3642
|
+
directives_raw = await self.list_directives(
|
|
3643
|
+
bank_id=bank_id,
|
|
3644
|
+
tags=tags,
|
|
3645
|
+
tags_match=tags_match,
|
|
3646
|
+
active_only=True,
|
|
3647
|
+
request_context=request_context,
|
|
3164
3648
|
)
|
|
3649
|
+
# Convert directive format to the expected format for reflect agent
|
|
3650
|
+
# The agent expects: name, description (optional), observations (list of {title, content})
|
|
3651
|
+
directives = [
|
|
3652
|
+
{
|
|
3653
|
+
"name": d["name"],
|
|
3654
|
+
"description": d["content"], # Use content as description
|
|
3655
|
+
"observations": [], # Directives use content directly, not observations
|
|
3656
|
+
}
|
|
3657
|
+
for d in directives_raw
|
|
3658
|
+
]
|
|
3659
|
+
if directives:
|
|
3660
|
+
logger.info(f"[REFLECT {reflect_id}] Loaded {len(directives)} directives")
|
|
3165
3661
|
|
|
3166
|
-
#
|
|
3167
|
-
|
|
3168
|
-
|
|
3169
|
-
|
|
3662
|
+
# Check if the bank has any mental models
|
|
3663
|
+
async with pool.acquire() as conn:
|
|
3664
|
+
mental_model_count = await conn.fetchval(
|
|
3665
|
+
f"SELECT COUNT(*) FROM {fq_table('mental_models')} WHERE bank_id = $1",
|
|
3666
|
+
bank_id,
|
|
3667
|
+
)
|
|
3668
|
+
has_mental_models = mental_model_count > 0
|
|
3669
|
+
if has_mental_models:
|
|
3670
|
+
logger.info(f"[REFLECT {reflect_id}] Bank has {mental_model_count} mental models")
|
|
3170
3671
|
|
|
3171
|
-
#
|
|
3172
|
-
|
|
3173
|
-
|
|
3174
|
-
|
|
3175
|
-
background = profile["background"]
|
|
3176
|
-
|
|
3177
|
-
# Build the prompt
|
|
3178
|
-
prompt = think_utils.build_think_prompt(
|
|
3179
|
-
agent_facts_text=agent_facts_text,
|
|
3180
|
-
world_facts_text=world_facts_text,
|
|
3181
|
-
opinion_facts_text=opinion_facts_text,
|
|
3672
|
+
# Run the agent
|
|
3673
|
+
agent_result = await run_reflect_agent(
|
|
3674
|
+
llm_config=self._reflect_llm_config,
|
|
3675
|
+
bank_id=bank_id,
|
|
3182
3676
|
query=query,
|
|
3183
|
-
|
|
3184
|
-
|
|
3185
|
-
|
|
3677
|
+
bank_profile=profile,
|
|
3678
|
+
search_mental_models_fn=search_mental_models_fn,
|
|
3679
|
+
search_observations_fn=search_observations_fn,
|
|
3680
|
+
recall_fn=recall_fn,
|
|
3681
|
+
expand_fn=expand_fn,
|
|
3186
3682
|
context=context,
|
|
3683
|
+
max_iterations=max_iterations,
|
|
3684
|
+
max_tokens=max_tokens,
|
|
3685
|
+
response_schema=response_schema,
|
|
3686
|
+
directives=directives,
|
|
3687
|
+
has_mental_models=has_mental_models,
|
|
3688
|
+
budget=effective_budget,
|
|
3187
3689
|
)
|
|
3188
3690
|
|
|
3189
|
-
|
|
3190
|
-
|
|
3191
|
-
|
|
3192
|
-
|
|
3193
|
-
|
|
3194
|
-
# Prepare response_format if schema provided
|
|
3195
|
-
response_format = None
|
|
3196
|
-
if response_schema is not None:
|
|
3197
|
-
# Wrapper class to provide Pydantic-like interface for raw JSON schemas
|
|
3198
|
-
class JsonSchemaWrapper:
|
|
3199
|
-
def __init__(self, schema: dict):
|
|
3200
|
-
self._schema = schema
|
|
3201
|
-
|
|
3202
|
-
def model_json_schema(self):
|
|
3203
|
-
return self._schema
|
|
3204
|
-
|
|
3205
|
-
response_format = JsonSchemaWrapper(response_schema)
|
|
3206
|
-
|
|
3207
|
-
llm_start = time.time()
|
|
3208
|
-
result = await self._llm_config.call(
|
|
3209
|
-
messages=messages,
|
|
3210
|
-
scope="memory_reflect",
|
|
3211
|
-
max_completion_tokens=max_tokens,
|
|
3212
|
-
response_format=response_format,
|
|
3213
|
-
skip_validation=True if response_format else False,
|
|
3214
|
-
# Don't enforce strict_schema - not all providers support it and may retry forever
|
|
3215
|
-
# Soft enforcement (schema in prompt + json_object mode) is sufficient
|
|
3216
|
-
strict_schema=False,
|
|
3691
|
+
total_time = time.time() - reflect_start
|
|
3692
|
+
logger.info(
|
|
3693
|
+
f"[REFLECT {reflect_id}] Complete: {len(agent_result.text)} chars, "
|
|
3694
|
+
f"{agent_result.iterations} iterations, {agent_result.tools_called} tool calls | {total_time:.3f}s"
|
|
3217
3695
|
)
|
|
3218
|
-
llm_time = time.time() - llm_start
|
|
3219
3696
|
|
|
3220
|
-
#
|
|
3221
|
-
|
|
3222
|
-
|
|
3223
|
-
|
|
3224
|
-
|
|
3225
|
-
|
|
3226
|
-
|
|
3227
|
-
|
|
3697
|
+
# Convert agent tool trace to ToolCallTrace objects
|
|
3698
|
+
tool_trace_result = [
|
|
3699
|
+
ToolCallTrace(
|
|
3700
|
+
tool=tc.tool,
|
|
3701
|
+
reason=tc.reason,
|
|
3702
|
+
input=tc.input,
|
|
3703
|
+
output=tc.output,
|
|
3704
|
+
duration_ms=tc.duration_ms,
|
|
3705
|
+
iteration=tc.iteration,
|
|
3706
|
+
)
|
|
3707
|
+
for tc in agent_result.tool_trace
|
|
3708
|
+
]
|
|
3228
3709
|
|
|
3229
|
-
#
|
|
3230
|
-
|
|
3231
|
-
|
|
3232
|
-
|
|
3233
|
-
|
|
3234
|
-
|
|
3235
|
-
|
|
3236
|
-
|
|
3237
|
-
|
|
3238
|
-
|
|
3239
|
-
|
|
3710
|
+
# Convert agent LLM trace to LLMCallTrace objects
|
|
3711
|
+
llm_trace_result = [LLMCallTrace(scope=lc.scope, duration_ms=lc.duration_ms) for lc in agent_result.llm_trace]
|
|
3712
|
+
|
|
3713
|
+
# Extract memories from recall tool outputs - only include memories the agent actually used
|
|
3714
|
+
# agent_result.used_memory_ids contains validated IDs from the done action
|
|
3715
|
+
used_memory_ids_set = set(agent_result.used_memory_ids) if agent_result.used_memory_ids else set()
|
|
3716
|
+
based_on: dict[str, list[MemoryFact]] = {"world": [], "experience": [], "opinion": [], "observation": []}
|
|
3717
|
+
seen_memory_ids: set[str] = set()
|
|
3718
|
+
for tc in agent_result.tool_trace:
|
|
3719
|
+
if tc.tool == "recall" and "memories" in tc.output:
|
|
3720
|
+
for memory_data in tc.output["memories"]:
|
|
3721
|
+
memory_id = memory_data.get("id")
|
|
3722
|
+
# Only include memories that the agent declared as used (or all if none specified)
|
|
3723
|
+
if memory_id and memory_id not in seen_memory_ids:
|
|
3724
|
+
if used_memory_ids_set and memory_id not in used_memory_ids_set:
|
|
3725
|
+
continue # Skip memories not actually used by the agent
|
|
3726
|
+
seen_memory_ids.add(memory_id)
|
|
3727
|
+
fact_type = memory_data.get("type", "world")
|
|
3728
|
+
if fact_type in based_on:
|
|
3729
|
+
based_on[fact_type].append(
|
|
3730
|
+
MemoryFact(
|
|
3731
|
+
id=memory_id,
|
|
3732
|
+
text=memory_data.get("text", ""),
|
|
3733
|
+
fact_type=fact_type,
|
|
3734
|
+
context=None,
|
|
3735
|
+
occurred_start=memory_data.get("occurred"),
|
|
3736
|
+
occurred_end=memory_data.get("occurred"),
|
|
3737
|
+
)
|
|
3738
|
+
)
|
|
3240
3739
|
|
|
3241
|
-
|
|
3242
|
-
|
|
3243
|
-
|
|
3740
|
+
# Extract mental models from tool outputs - only include models the agent actually used
|
|
3741
|
+
# agent_result.used_mental_model_ids contains validated IDs from the done action
|
|
3742
|
+
used_model_ids_set = set(agent_result.used_mental_model_ids) if agent_result.used_mental_model_ids else set()
|
|
3743
|
+
based_on["mental-models"] = []
|
|
3744
|
+
seen_model_ids: set[str] = set()
|
|
+        for tc in agent_result.tool_trace:
+            if tc.tool == "get_mental_model":
+                # Single model lookup (with full details)
+                if tc.output.get("found") and "model" in tc.output:
+                    model = tc.output["model"]
+                    model_id = model.get("id")
+                    if model_id and model_id not in seen_model_ids:
+                        # Only include models that the agent declared as used (or all if none specified)
+                        if used_model_ids_set and model_id not in used_model_ids_set:
+                            continue  # Skip models not actually used by the agent
+                        seen_model_ids.add(model_id)
+                        # Add to based_on as MemoryFact with type "mental-models"
+                        model_name = model.get("name", "")
+                        model_summary = model.get("summary") or model.get("description", "")
+                        based_on["mental-models"].append(
+                            MemoryFact(
+                                id=model_id,
+                                text=f"{model_name}: {model_summary}",
+                                fact_type="mental-models",
+                                context=f"{model.get('type', 'concept')} ({model.get('subtype', 'structural')})",
+                                occurred_start=None,
+                                occurred_end=None,
+                            )
+                        )
+            elif tc.tool == "search_mental_models":
+                # Search mental models - include all returned models (filtered by used_model_ids_set if specified)
+                for model in tc.output.get("mental_models", []):
+                    model_id = model.get("id")
+                    if model_id and model_id not in seen_model_ids:
+                        # Only include models that the agent declared as used (or all if none specified)
+                        if used_model_ids_set and model_id not in used_model_ids_set:
+                            continue  # Skip models not actually used by the agent
+                        seen_model_ids.add(model_id)
+                        # Add to based_on as MemoryFact with type "mental-models"
+                        model_name = model.get("name", "")
+                        model_summary = model.get("summary") or model.get("description", "")
+                        based_on["mental-models"].append(
+                            MemoryFact(
+                                id=model_id,
+                                text=f"{model_name}: {model_summary}",
+                                fact_type="mental-models",
+                                context=f"{model.get('type', 'concept')} ({model.get('subtype', 'structural')})",
+                                occurred_start=None,
+                                occurred_end=None,
+                            )
+                        )
+            elif tc.tool == "search_mental_models":
+                # Search mental models - include all returned mental models (filtered by used_mental_model_ids_set if specified)
+                used_mental_model_ids_set = (
+                    set(agent_result.used_mental_model_ids) if agent_result.used_mental_model_ids else set()
+                )
+                for mental_model in tc.output.get("mental_models", []):
+                    mental_model_id = mental_model.get("id")
+                    if mental_model_id and mental_model_id not in seen_model_ids:
+                        # Only include mental models that the agent declared as used (or all if none specified)
+                        if used_mental_model_ids_set and mental_model_id not in used_mental_model_ids_set:
+                            continue  # Skip mental models not actually used by the agent
+                        seen_model_ids.add(mental_model_id)
+                        # Add to based_on as MemoryFact with type "mental-models" (mental models are synthesized knowledge)
+                        mental_model_name = mental_model.get("name", "")
+                        mental_model_content = mental_model.get("content", "")
+                        based_on["mental-models"].append(
+                            MemoryFact(
+                                id=mental_model_id,
+                                text=f"{mental_model_name}: {mental_model_content}",
+                                fact_type="mental-models",
+                                context="mental model (user-curated)",
+                                occurred_start=None,
+                                occurred_end=None,
+                            )
+                        )
+            # List all models lookup - don't add to based_on (too verbose, just a listing)
+
+        # Add directives to based_on["mental-models"] (they are mental models with subtype='directive')
+        for directive in directives:
+            # Extract summary from observations
+            summary_parts: list[str] = []
+            for obs in directive.get("observations", []):
+                # Support both Pydantic Observation objects and dicts
+                if hasattr(obs, "content"):
+                    content = obs.content
+                    title = obs.title
+                else:
+                    content = obs.get("content", "")
+                    title = obs.get("title", "")
+                if title and content:
+                    summary_parts.append(f"{title}: {content}")
+                elif content:
+                    summary_parts.append(content)
+
+            # Fallback to description if no observations
+            if not summary_parts and directive.get("description"):
+                summary_parts.append(directive["description"])
+
+            directive_name = directive.get("name", "")
+            directive_summary = "; ".join(summary_parts) if summary_parts else ""
+            based_on["mental-models"].append(
+                MemoryFact(
+                    id=directive.get("id", ""),
+                    text=f"{directive_name}: {directive_summary}",
+                    fact_type="mental-models",
+                    context="directive (directive)",
+                    occurred_start=None,
+                    occurred_end=None,
+                )
+            )
+
+        # Build directives_applied from agent result
+        from hindsight_api.engine.response_models import DirectiveRef
+
+        directives_applied_result = [
+            DirectiveRef(id=d.id, name=d.name, content=d.content) for d in agent_result.directives_applied
+        ]
+
+        # Convert agent usage to TokenUsage format
+        from hindsight_api.engine.response_models import TokenUsage
+
+        usage = TokenUsage(
+            input_tokens=agent_result.usage.input_tokens,
+            output_tokens=agent_result.usage.output_tokens,
+            total_tokens=agent_result.usage.total_tokens,
         )
-        logger.info("\n" + "\n".join(log_buffer))

-        # Return response with
+        # Return response (compatible with existing API)
         result = ReflectResult(
-            text=
-            based_on=
-            new_opinions=[], #
-            structured_output=structured_output,
+            text=agent_result.text,
+            based_on=based_on,
+            new_opinions=[],  # Learnings stored as mental models
+            structured_output=agent_result.structured_output,
+            usage=usage,
+            tool_trace=tool_trace_result,
+            llm_trace=llm_trace_result,
+            directives_applied=directives_applied_result,
         )

         # Call post-operation hook if validator is configured
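The block above folds `get_mental_model` and `search_mental_models` tool calls into `based_on["mental-models"]`, deduplicating by model id and honoring the agent's declared used-model set. A minimal sketch of that fold, using plain dicts for the tool-trace records (the real entries are objects with `.tool`/`.output`; the shapes here are assumptions for illustration):

```python
# Sketch only, assuming dict-shaped tool-call records.
def fold_models_into_based_on(tool_trace: list[dict], used_model_ids: list[str] | None) -> list[dict]:
    used = set(used_model_ids or [])
    seen: set[str] = set()
    out: list[dict] = []
    for tc in tool_trace:
        if tc["tool"] == "get_mental_model" and tc["output"].get("found"):
            models = [tc["output"]["model"]]
        elif tc["tool"] == "search_mental_models":
            models = tc["output"].get("mental_models", [])
        else:
            continue  # listings are skipped, as in the diff
        for m in models:
            mid = m.get("id")
            if not mid or mid in seen or (used and mid not in used):
                continue  # dedupe, and honor the agent's declared-used filter
            seen.add(mid)
            out.append({"id": mid, "text": f"{m.get('name', '')}: {m.get('summary') or m.get('content', '')}"})
    return out
```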
@@ -3273,48 +3898,6 @@ Guidelines:

         return result

-    async def _extract_and_store_opinions_async(
-        self, bank_id: str, answer_text: str, query: str, tenant_id: str | None = None
-    ):
-        """
-        Background task to extract and store opinions from think response.
-
-        This runs asynchronously and does not block the think response.
-
-        Args:
-            bank_id: Bank identifier
-            answer_text: The generated answer text
-            query: The original query
-            tenant_id: Tenant identifier for internal authentication
-        """
-        try:
-            # Extract opinions from the answer
-            new_opinions = await think_utils.extract_opinions_from_text(self._llm_config, text=answer_text, query=query)
-
-            # Store new opinions
-            if new_opinions:
-                from datetime import datetime
-
-                current_time = datetime.now(UTC)
-                # Use internal context with tenant_id for background authentication
-                # Extension can check internal=True to bypass normal auth
-                from hindsight_api.models import RequestContext
-
-                internal_context = RequestContext(tenant_id=tenant_id, internal=True)
-                for opinion in new_opinions:
-                    await self.retain_async(
-                        bank_id=bank_id,
-                        content=opinion.opinion,
-                        context=f"formed during thinking about: {query}",
-                        event_date=current_time,
-                        fact_type_override="opinion",
-                        confidence_score=opinion.confidence,
-                        request_context=internal_context,
-                    )
-
-        except Exception as e:
-            logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")
-
     async def get_entity_observations(
         self,
         bank_id: str,
@@ -3324,73 +3907,69 @@ Guidelines:

         request_context: "RequestContext",
     ) -> list[Any]:
         """
-        Get observations
+        Get observations for an entity.
+
+        NOTE: Entity observations/summaries have been moved to mental models.
+        This method returns an empty list. Use mental models for entity summaries.

         Args:
             bank_id: Bank identifier
             entity_id: Entity UUID to get observations for
-            limit:
+            limit: Ignored (kept for backwards compatibility)
             request_context: Request context for authentication.

         Returns:
-
+            Empty list (observations now in mental models)
         """
         await self._authenticate_tenant(request_context)
-
-        async with acquire_with_retry(pool) as conn:
-            rows = await conn.fetch(
-                f"""
-                SELECT mu.text, mu.mentioned_at
-                FROM {fq_table("memory_units")} mu
-                JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
-                WHERE mu.bank_id = $1
-                  AND mu.fact_type = 'observation'
-                  AND ue.entity_id = $2
-                ORDER BY mu.mentioned_at DESC
-                LIMIT $3
-                """,
-                bank_id,
-                uuid.UUID(entity_id),
-                limit,
-            )
-
-            observations = []
-            for row in rows:
-                mentioned_at = row["mentioned_at"].isoformat() if row["mentioned_at"] else None
-                observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
-            return observations
+        return []

     async def list_entities(
         self,
         bank_id: str,
         *,
         limit: int = 100,
+        offset: int = 0,
         request_context: "RequestContext",
-    ) ->
+    ) -> dict[str, Any]:
         """
-        List all entities for a bank.
+        List all entities for a bank with pagination.

         Args:
             bank_id: Bank identifier
             limit: Maximum number of entities to return
+            offset: Offset for pagination
             request_context: Request context for authentication.

         Returns:
-
+            Dict with items, total, limit, offset
         """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
+            # Get total count
+            total_row = await conn.fetchrow(
+                f"""
+                SELECT COUNT(*) as total
+                FROM {fq_table("entities")}
+                WHERE bank_id = $1
+                """,
+                bank_id,
+            )
+            total = total_row["total"] if total_row else 0
+
+            # Get paginated entities
             rows = await conn.fetch(
                 f"""
                 SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
                 FROM {fq_table("entities")}
                 WHERE bank_id = $1
-                ORDER BY mention_count DESC, last_seen DESC
-                LIMIT $2
+                ORDER BY mention_count DESC, last_seen DESC, id ASC
+                LIMIT $2 OFFSET $3
                 """,
                 bank_id,
                 limit,
+                offset,
             )

             entities = []
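The pagination added here pairs a COUNT query with the LIMIT/OFFSET page and appends `id ASC` as a tie-breaker, so rows with equal `mention_count`/`last_seen` do not shuffle between pages. A sketch of the same contract, assuming an asyncpg connection and a plain `entities` table standing in for the `fq_table()` helper:

```python
import asyncpg  # assumed driver, matching the diff's conn.fetch/fetchrow usage

async def page_entities(conn: "asyncpg.Connection", bank_id: str, limit: int = 100, offset: int = 0):
    total = await conn.fetchval("SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id)
    rows = await conn.fetch(
        """
        SELECT id, canonical_name
        FROM entities
        WHERE bank_id = $1
        ORDER BY mention_count DESC, last_seen DESC, id ASC  -- id ASC keeps ties stable across pages
        LIMIT $2 OFFSET $3
        """,
        bank_id, limit, offset,
    )
    return {"items": [dict(r) for r in rows], "total": total or 0, "limit": limit, "offset": offset}
```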
@@ -3417,7 +3996,91 @@ Guidelines:

                     "metadata": metadata,
                 }
             )
-            return
+            return {
+                "items": entities,
+                "total": total,
+                "limit": limit,
+                "offset": offset,
+            }
+
+    async def list_tags(
+        self,
+        bank_id: str,
+        *,
+        pattern: str | None = None,
+        limit: int = 100,
+        offset: int = 0,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """
+        List all unique tags for a bank with usage counts.
+
+        Use this to discover available tags or expand wildcard patterns.
+        Supports '*' as wildcard for flexible matching (case-insensitive):
+        - 'user:*' matches user:alice, user:bob
+        - '*-admin' matches role-admin, super-admin
+        - 'env*-prod' matches env-prod, environment-prod
+
+        Args:
+            bank_id: Bank identifier
+            pattern: Wildcard pattern to filter tags (use '*' as wildcard, case-insensitive)
+            limit: Maximum number of tags to return
+            offset: Offset for pagination
+            request_context: Request context for authentication.
+
+        Returns:
+            Dict with items (list of {tag, count}), total, limit, offset
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+        async with acquire_with_retry(pool) as conn:
+            # Build pattern filter if provided (convert * to % for ILIKE)
+            pattern_clause = ""
+            params: list[Any] = [bank_id]
+            if pattern:
+                # Convert wildcard pattern: * -> % for SQL ILIKE
+                sql_pattern = pattern.replace("*", "%")
+                pattern_clause = "AND tag ILIKE $2"
+                params.append(sql_pattern)
+
+            # Get total count of distinct tags matching pattern
+            total_row = await conn.fetchrow(
+                f"""
+                SELECT COUNT(DISTINCT tag) as total
+                FROM {fq_table("memory_units")}, unnest(tags) AS tag
+                WHERE bank_id = $1 AND tags IS NOT NULL AND tags != '{{}}'
+                {pattern_clause}
+                """,
+                *params,
+            )
+            total = total_row["total"] if total_row else 0
+
+            # Get paginated tags with counts, ordered by frequency
+            limit_param = len(params) + 1
+            offset_param = len(params) + 2
+            params.extend([limit, offset])
+
+            rows = await conn.fetch(
+                f"""
+                SELECT tag, COUNT(*) as count
+                FROM {fq_table("memory_units")}, unnest(tags) AS tag
+                WHERE bank_id = $1 AND tags IS NOT NULL AND tags != '{{}}'
+                {pattern_clause}
+                GROUP BY tag
+                ORDER BY count DESC, tag ASC
+                LIMIT ${limit_param} OFFSET ${offset_param}
+                """,
+                *params,
+            )
+
+            items = [{"tag": row["tag"], "count": row["count"]} for row in rows]
+
+            return {
+                "items": items,
+                "total": total,
+                "limit": limit,
+                "offset": offset,
+            }

     async def get_entity_state(
         self,
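The wildcard support in `list_tags` reduces to a single substitution, `*` to `%`, before the pattern reaches a case-insensitive ILIKE. The translation in isolation (a sketch; note that literal `%` or `_` characters inside tags are not escaped by this scheme):

```python
def to_ilike_pattern(pattern: str) -> str:
    # '*' becomes the SQL LIKE/ILIKE multi-character wildcard '%'
    return pattern.replace("*", "%")

assert to_ilike_pattern("user:*") == "user:%"        # user:alice, user:bob
assert to_ilike_pattern("*-admin") == "%-admin"      # role-admin, super-admin
assert to_ilike_pattern("env*-prod") == "env%-prod"  # env-prod, environment-prod
```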
@@ -3429,22 +4092,23 @@ Guidelines:

         request_context: "RequestContext",
     ) -> EntityState:
         """
-        Get the current state
+        Get the current state of an entity.
+
+        NOTE: Entity observations/summaries have been moved to mental models.
+        This method returns an entity with empty observations.

         Args:
             bank_id: Bank identifier
             entity_id: Entity UUID
             entity_name: Canonical name of the entity
-            limit: Maximum number of observations to include
+            limit: Maximum number of observations to include (kept for backwards compat)
             request_context: Request context for authentication.

         Returns:
-            EntityState with observations
+            EntityState with empty observations (summaries now in mental models)
         """
-
-
-        )
-        return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=observations)
+        await self._authenticate_tenant(request_context)
+        return EntityState(entity_id=entity_id, canonical_name=entity_name, observations=[])

     async def regenerate_entity_observations(
         self,
@@ -3455,533 +4119,1228 @@ Guidelines:

         version: str | None = None,
         conn=None,
         request_context: "RequestContext",
-    ) ->
+    ) -> list[str]:
         """
-        Regenerate observations for an entity
-
-
-
-        4. Deleting old observations for this entity
-        5. Storing new observations linked to the entity
+        Regenerate observations for an entity.
+
+        NOTE: Entity observations/summaries have been moved to mental models.
+        This method is now a no-op and returns an empty list.

         Args:
             bank_id: Bank identifier
             entity_id: Entity UUID
             entity_name: Canonical name of the entity
             version: Entity's last_seen timestamp when task was created (for deduplication)
-            conn: Optional database connection (
+            conn: Optional database connection (ignored)
             request_context: Request context for authentication.
+
+        Returns:
+            Empty list (observations now in mental models)
         """
         await self._authenticate_tenant(request_context)
-
-        entity_uuid = uuid.UUID(entity_id)
-
-        # Helper to run a query with provided conn or acquire one
-        async def fetch_with_conn(query, *args):
-            if conn is not None:
-                return await conn.fetch(query, *args)
-            else:
-                async with acquire_with_retry(pool) as acquired_conn:
-                    return await acquired_conn.fetch(query, *args)
+        return []

-
-
-
-
-
-
+    # =========================================================================
+    # Statistics & Operations (for HTTP API layer)
+    # =========================================================================
+
+    async def get_bank_stats(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Get statistics about memory nodes and links for a bank."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()

-
-
-
+        async with acquire_with_retry(pool) as conn:
+            # Get node counts by fact_type
+            node_stats = await conn.fetch(
+                f"""
+                SELECT fact_type, COUNT(*) as count
+                FROM {fq_table("memory_units")}
+                WHERE bank_id = $1
+                GROUP BY fact_type
+                """,
+                bank_id,
+            )
+
+            # Get link counts by link_type
+            link_stats = await conn.fetch(
                 f"""
-                SELECT
+                SELECT ml.link_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY ml.link_type
+                """,
+                bank_id,
+            )
+
+            # Get link counts by fact_type (from nodes)
+            link_fact_type_stats = await conn.fetch(
+                f"""
+                SELECT mu.fact_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY mu.fact_type
+                """,
+                bank_id,
+            )
+
+            # Get link counts by fact_type AND link_type
+            link_breakdown_stats = await conn.fetch(
+                f"""
+                SELECT mu.fact_type, ml.link_type, COUNT(*) as count
+                FROM {fq_table("memory_links")} ml
+                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
+                WHERE mu.bank_id = $1
+                GROUP BY mu.fact_type, ml.link_type
+                """,
+                bank_id,
+            )
+
+            # Get pending and failed operations counts
+            ops_stats = await conn.fetch(
+                f"""
+                SELECT status, COUNT(*) as count
+                FROM {fq_table("async_operations")}
+                WHERE bank_id = $1
+                GROUP BY status
+                """,
+                bank_id,
+            )
+
+            return {
+                "bank_id": bank_id,
+                "node_counts": {row["fact_type"]: row["count"] for row in node_stats},
+                "link_counts": {row["link_type"]: row["count"] for row in link_stats},
+                "link_counts_by_fact_type": {row["fact_type"]: row["count"] for row in link_fact_type_stats},
+                "link_breakdown": [
+                    {"fact_type": row["fact_type"], "link_type": row["link_type"], "count": row["count"]}
+                    for row in link_breakdown_stats
+                ],
+                "operations": {row["status"]: row["count"] for row in ops_stats},
+            }
+
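A hypothetical call site for the new `get_bank_stats`; the `engine` instance and `ctx` request context here are assumptions, not part of the diff:

```python
async def print_bank_stats(engine, ctx):
    stats = await engine.get_bank_stats("my-bank", request_context=ctx)
    print(stats["node_counts"])   # counts keyed by fact_type, e.g. {"observation": 8}
    print(stats["operations"])    # async operation counts keyed by status
```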
+    async def get_entity(
+        self,
+        bank_id: str,
+        entity_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Get entity details including metadata and observations."""
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            entity_row = await conn.fetchrow(
+                f"""
+                SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
                 FROM {fq_table("entities")}
-                WHERE
+                WHERE bank_id = $1 AND id = $2
                 """,
-                entity_uuid,
                 bank_id,
+                uuid.UUID(entity_id),
             )

-
-
+            if not entity_row:
+                return None

-            #
-
+            # Get observations for the entity
+            observations = await self.get_entity_observations(bank_id, entity_id, limit=20, request_context=request_context)
+
+            return {
+                "id": str(entity_row["id"]),
+                "canonical_name": entity_row["canonical_name"],
+                "mention_count": entity_row["mention_count"],
+                "first_seen": entity_row["first_seen"].isoformat() if entity_row["first_seen"] else None,
+                "last_seen": entity_row["last_seen"].isoformat() if entity_row["last_seen"] else None,
+                "metadata": entity_row["metadata"] or {},
+                "observations": observations,
+            }
+
+    def _parse_observations(self, observations_raw: list):
+        """Parse raw observation dicts into typed Observation models.
+
+        Returns list of Observation models with computed trend/evidence_span/evidence_count.
+        """
+        from .reflect.observations import Observation, ObservationEvidence
+
+        observations: list[Observation] = []
+        for obs in observations_raw:
+            if not isinstance(obs, dict):
+                continue
+
+            try:
+                parsed = Observation(
+                    title=obs.get("title", ""),
+                    content=obs.get("content", ""),
+                    evidence=[
+                        ObservationEvidence(
+                            memory_id=ev.get("memory_id", ""),
+                            quote=ev.get("quote", ""),
+                            relevance=ev.get("relevance", ""),
+                            timestamp=ev.get("timestamp"),
+                        )
+                        for ev in obs.get("evidence", [])
+                        if isinstance(ev, dict)
+                    ],
+                    created_at=obs.get("created_at"),
+                )
+                observations.append(parsed)
+            except Exception as e:
+                logger.warning(f"Failed to parse observation: {e}")
+                continue
+
+        return observations
+
+    async def _count_memories_since(
+        self,
+        bank_id: str,
+        since_timestamp: str | None,
+        pool=None,
+    ) -> int:
+        """
+        Count memories created after a given timestamp.
+
+        Args:
+            bank_id: Bank identifier
+            since_timestamp: ISO timestamp string. If None, returns total count.
+            pool: Optional database pool (uses default if not provided)
+
+        Returns:
+            Number of memories created since the timestamp
+        """
+        if pool is None:
+            pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            if since_timestamp:
+                # Parse the timestamp
+                from datetime import datetime
+
+                try:
+                    ts = datetime.fromisoformat(since_timestamp.replace("Z", "+00:00"))
+                except ValueError:
+                    # Invalid timestamp, return total count
+                    ts = None
+
+                if ts:
+                    count = await conn.fetchval(
+                        f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1 AND created_at > $2",
+                        bank_id,
+                        ts,
+                    )
+                    return count or 0
+
+            # No timestamp or invalid, return total count
+            count = await conn.fetchval(
+                f"SELECT COUNT(*) FROM {fq_table('memory_units')} WHERE bank_id = $1",
+                bank_id,
+            )
+            return count or 0
+
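`_count_memories_since` rewrites a trailing "Z" as an explicit UTC offset before parsing, because `datetime.fromisoformat` rejects the "Z" suffix on Python versions before 3.11:

```python
from datetime import datetime

# The "Z" handling used above, in isolation.
ts = datetime.fromisoformat("2024-05-01T12:00:00Z".replace("Z", "+00:00"))
print(ts.tzinfo)  # timezone.utc
```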
+    async def _invalidate_facts_from_mental_models(
+        self,
+        conn,
+        bank_id: str,
+        fact_ids: list[str],
+    ) -> int:
+        """
+        Remove fact IDs from observation source_memory_ids when memories are deleted.
+
+        Observations are stored in memory_units with fact_type='observation'
+        and have a source_memory_ids column (UUID[]) tracking their source memories.
+
+        Args:
+            conn: Database connection
+            bank_id: Bank identifier
+            fact_ids: List of fact IDs to remove from observations
+
+        Returns:
+            Number of observations updated
+        """
+        if not fact_ids:
+            return 0
+
+        # Convert string IDs to UUIDs for the array comparison
+        import uuid as uuid_module
+
+        fact_uuids = [uuid_module.UUID(fid) for fid in fact_ids]
+
+        # Update observations (memory_units with fact_type='observation')
+        # by removing the deleted fact IDs from source_memory_ids
+        # Use array subtraction: source_memory_ids - deleted_ids
+        result = await conn.execute(
             f"""
-
-
-
-
-
-
-
-
+            UPDATE {fq_table("memory_units")}
+            SET source_memory_ids = (
+                SELECT COALESCE(array_agg(elem), ARRAY[]::uuid[])
+                FROM unnest(source_memory_ids) AS elem
+                WHERE elem != ALL($2::uuid[])
+            ),
+            updated_at = NOW()
+            WHERE bank_id = $1
+              AND fact_type = 'observation'
+              AND source_memory_ids && $2::uuid[]
             """,
             bank_id,
-
+            fact_uuids,
         )

-
-
+        # Parse the result to get number of updated rows
+        updated_count = int(result.split()[-1]) if result and "UPDATE" in result else 0
+        if updated_count > 0:
+            logger.info(
+                f"[OBSERVATIONS] Invalidated {len(fact_ids)} fact IDs from {updated_count} observations in bank {bank_id}"
+            )
+        return updated_count

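The UPDATE above rewrites a `uuid[]` column by unnesting it, filtering out the deleted ids, and re-aggregating; `COALESCE(..., ARRAY[]::uuid[])` keeps the column non-NULL when nothing survives, and the `&&` overlap predicate limits the write to affected rows. The idiom in isolation (a sketch with an illustrative table name, assuming asyncpg):

```python
async def remove_deleted_ids(conn, deleted_ids: list) -> str:
    return await conn.execute(
        """
        UPDATE observations
        SET source_memory_ids = (
            SELECT COALESCE(array_agg(elem), ARRAY[]::uuid[])  -- stay non-NULL when empty
            FROM unnest(source_memory_ids) AS elem
            WHERE elem != ALL($1::uuid[])                      -- drop every deleted id
        )
        WHERE source_memory_ids && $1::uuid[]                  -- touch only affected rows
        """,
        deleted_ids,
    )
```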
-
-
-
-
-
-
-
-
-
-
-
-
+    # =========================================================================
+    # MENTAL MODELS (CONSOLIDATED) - Read-only access to auto-consolidated mental models
+    # =========================================================================
+
+    async def list_mental_models_consolidated(
+        self,
+        bank_id: str,
+        *,
+        tags: list[str] | None = None,
+        tags_match: str = "any",
+        limit: int = 100,
+        offset: int = 0,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
+        """List auto-consolidated observations for a bank.
+
+        Observations are stored in memory_units with fact_type='observation'.
+        They are automatically created and updated by the consolidation engine.
+
+        Args:
+            bank_id: Bank identifier
+            tags: Optional tags to filter by
+            tags_match: How to match tags - 'any', 'all', or 'exact'
+            limit: Maximum number of results
+            offset: Offset for pagination
+            request_context: Request context for authentication
+
+        Returns:
+            List of observation dicts
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            # Build tag filter
+            tag_filter = ""
+            params: list[Any] = [bank_id, limit, offset]
+            if tags:
+                if tags_match == "all":
+                    tag_filter = " AND tags @> $4::varchar[]"
+                elif tags_match == "exact":
+                    tag_filter = " AND tags = $4::varchar[]"
+                else:  # any
+                    tag_filter = " AND tags && $4::varchar[]"
+                params.append(tags)
+
+            rows = await conn.fetch(
+                f"""
+                SELECT id, bank_id, text, proof_count, history, tags, source_memory_ids, created_at, updated_at
+                FROM {fq_table("memory_units")}
+                WHERE bank_id = $1 AND fact_type = 'observation' {tag_filter}
+                ORDER BY updated_at DESC NULLS LAST
+                LIMIT $2 OFFSET $3
+                """,
+                *params,
             )

-
-            observations = await observation_utils.extract_observations_from_facts(self._llm_config, entity_name, facts)
+            return [self._row_to_observation_consolidated(row) for row in rows]

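The `tags_match` modes map onto three Postgres array operators; the same table recurs in `list_mental_models` and `list_directives` below. Summarized as a sketch (`$4` is simply the slot these particular queries bind the tag array to; note that `=` on arrays is element-order-sensitive in Postgres):

```python
TAG_FILTERS = {
    "any":   " AND tags && $4::varchar[]",  # overlap: at least one tag in common
    "all":   " AND tags @> $4::varchar[]",  # contains: row has every requested tag
    "exact": " AND tags = $4::varchar[]",   # equality: element-for-element, order-sensitive
}
```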
-
-
+    async def get_observation_consolidated(
+        self,
+        bank_id: str,
+        observation_id: str,
+        *,
+        include_source_memories: bool = True,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Get a single observation by ID.
+
+        Args:
+            bank_id: Bank identifier
+            observation_id: Observation ID
+            include_source_memories: Whether to include full source memory details
+            request_context: Request context for authentication
+
+        Returns:
+            Observation dict or None if not found
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()

-
-
-        # If conn is None, acquire one and start a transaction
-        async def do_db_operations(db_conn):
-            # Delete old observations for this entity
-            await db_conn.execute(
+        async with acquire_with_retry(pool) as conn:
+            row = await conn.fetchrow(
                 f"""
-
-
-
-                FROM {fq_table("memory_units")} mu
-                JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
-                WHERE mu.bank_id = $1
-                  AND mu.fact_type = 'observation'
-                  AND ue.entity_id = $2
-                )
+                SELECT id, bank_id, text, proof_count, history, tags, source_memory_ids, created_at, updated_at
+                FROM {fq_table("memory_units")}
+                WHERE bank_id = $1 AND id = $2 AND fact_type = 'observation'
                 """,
                 bank_id,
-
+                observation_id,
             )

-
-
+            if not row:
+                return None

-
-            current_time = utcnow()
-            created_ids = []
+            result = self._row_to_observation_consolidated(row)

-
-
+            # Fetch source memories if requested and source_memory_ids exist
+            if include_source_memories and result.get("source_memory_ids"):
+                source_ids = [uuid.UUID(sid) if isinstance(sid, str) else sid for sid in result["source_memory_ids"]]
+                source_rows = await conn.fetch(
                     f"""
-
-
-
-
-                    )
-                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
-                    RETURNING id
+                    SELECT id, text, fact_type, context, occurred_start, mentioned_at
+                    FROM {fq_table("memory_units")}
+                    WHERE id = ANY($1::uuid[])
+                    ORDER BY mentioned_at DESC NULLS LAST
                     """,
-
-                    obs_text,
-                    str(embedding),
-                    f"observation about {entity_name}",
-                    current_time,
-                    current_time,
-                    current_time,
-                    current_time,
+                    source_ids,
                 )
-
-
+                result["source_memories"] = [
+                    {
+                        "id": str(r["id"]),
+                        "text": r["text"],
+                        "type": r["fact_type"],
+                        "context": r["context"],
+                        "occurred_start": r["occurred_start"].isoformat() if r["occurred_start"] else None,
+                        "mentioned_at": r["mentioned_at"].isoformat() if r["mentioned_at"] else None,
+                    }
+                    for r in source_rows
+                ]

-
-            await db_conn.execute(
-                f"""
-                INSERT INTO {fq_table("unit_entities")} (unit_id, entity_id)
-                VALUES ($1, $2)
-                """,
-                uuid.UUID(obs_id),
-                entity_uuid,
-            )
+            return result

-
+    def _row_to_observation_consolidated(self, row: Any) -> dict[str, Any]:
+        """Convert a database row to an observation dict."""
+        import json

-
-
-
-
-
-        async with acquire_with_retry(pool) as acquired_conn:
-            async with acquired_conn.transaction():
-                return await do_db_operations(acquired_conn)
+        history = row["history"]
+        if isinstance(history, str):
+            history = json.loads(history)
+        elif history is None:
+            history = []

-
+        # Convert source_memory_ids to strings
+        source_memory_ids = row.get("source_memory_ids") or []
+        source_memory_ids = [str(sid) for sid in source_memory_ids]
+
+        return {
+            "id": str(row["id"]),
+            "bank_id": row["bank_id"],
+            "text": row["text"],
+            "proof_count": row["proof_count"] or 1,
+            "history": history,
+            "tags": row["tags"] or [],
+            "source_memory_ids": source_memory_ids,
+            "source_memories": [],  # Populated separately when fetching full details
+            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
+            "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
+        }
+
+    # =========================================================================
+    # MENTAL MODELS CRUD
+    # =========================================================================
+
+    async def list_mental_models(
         self,
         bank_id: str,
-
-
-
-
-
-        ""
-
-
-        Processes entities in PARALLEL for faster execution.
+        *,
+        tags: list[str] | None = None,
+        tags_match: str = "any",
+        limit: int = 100,
+        offset: int = 0,
+        request_context: "RequestContext",
+    ) -> list[dict[str, Any]]:
+        """List pinned mental models for a bank.

         Args:
             bank_id: Bank identifier
-
-
-
-
-
-            return
+            tags: Optional tags to filter by
+            tags_match: How to match tags - 'any', 'all', or 'exact'
+            limit: Maximum number of results
+            offset: Offset for pagination
+            request_context: Request context for authentication

-
-
-
+        Returns:
+            List of pinned mental model dicts
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()

-
-
+        async with acquire_with_retry(pool) as conn:
+            # Build tag filter
+            tag_filter = ""
+            params: list[Any] = [bank_id, limit, offset]
+            if tags:
+                if tags_match == "all":
+                    tag_filter = " AND tags @> $4::varchar[]"
+                elif tags_match == "exact":
+                    tag_filter = " AND tags = $4::varchar[]"
+                else:  # any
+                    tag_filter = " AND tags && $4::varchar[]"
+                params.append(tags)

-
-            if conn is not None:
-                # Use the provided connection (transactional with caller)
-                entity_rows = await conn.fetch(
+            rows = await conn.fetch(
                 f"""
-                SELECT id,
-
+                SELECT id, bank_id, name, source_query, content, tags,
+                       last_refreshed_at, created_at, reflect_response,
+                       max_tokens, trigger
+                FROM {fq_table("mental_models")}
+                WHERE bank_id = $1 {tag_filter}
+                ORDER BY last_refreshed_at DESC
+                LIMIT $2 OFFSET $3
                 """,
-
-                bank_id,
+                *params,
             )
-            entity_names = {row["id"]: row["canonical_name"] for row in entity_rows}

-
+            return [self._row_to_mental_model(row) for row in rows]
+
+    async def get_mental_model(
+        self,
+        bank_id: str,
+        mental_model_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Get a single pinned mental model by ID.
+
+        Args:
+            bank_id: Bank identifier
+            mental_model_id: Pinned mental model UUID
+            request_context: Request context for authentication
+
+        Returns:
+            Pinned mental model dict or None if not found
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            row = await conn.fetchrow(
                 f"""
-                SELECT
-
-
-
-
+                SELECT id, bank_id, name, source_query, content, tags,
+                       last_refreshed_at, created_at, reflect_response,
+                       max_tokens, trigger
+                FROM {fq_table("mental_models")}
+                WHERE bank_id = $1 AND id = $2
                 """,
-                entity_uuids,
                 bank_id,
+                mental_model_id,
             )
-
-
-
-
-
-
+
+            return self._row_to_mental_model(row) if row else None
+
+    async def create_mental_model(
+        self,
+        bank_id: str,
+        name: str,
+        source_query: str,
+        content: str,
+        *,
+        mental_model_id: str | None = None,
+        tags: list[str] | None = None,
+        max_tokens: int | None = None,
+        trigger: dict[str, Any] | None = None,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Create a new pinned mental model.
+
+        Args:
+            bank_id: Bank identifier
+            name: Human-readable name for the mental model
+            source_query: The query that generated this mental model
+            content: The synthesized content
+            mental_model_id: Optional UUID for the mental model (auto-generated if not provided)
+            tags: Optional tags for scoped visibility
+            max_tokens: Token limit for content generation during refresh
+            trigger: Trigger settings (e.g., refresh_after_consolidation)
+            request_context: Request context for authentication
+
+        Returns:
+            The created pinned mental model dict
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        # Generate embedding for the content
+        embedding_text = f"{name} {content}"
+        embedding = await embedding_utils.generate_embeddings_batch(self.embeddings, [embedding_text])
+        # Convert embedding to string for asyncpg vector type
+        embedding_str = str(embedding[0]) if embedding else None
+
+        async with acquire_with_retry(pool) as conn:
+            if mental_model_id:
+                row = await conn.fetchrow(
                     f"""
-
-
+                    INSERT INTO {fq_table("mental_models")}
+                        (id, bank_id, name, source_query, content, embedding, tags, max_tokens, trigger)
+                    VALUES ($1, $2, $3, $4, $5, $6, $7, COALESCE($8, 2048), COALESCE($9, '{{"refresh_after_consolidation": false}}'::jsonb))
+                    RETURNING id, bank_id, name, source_query, content, tags,
+                              last_refreshed_at, created_at, reflect_response,
+                              max_tokens, trigger
                     """,
-
+                    mental_model_id,
                     bank_id,
+                    name,
+                    source_query,
+                    content,
+                    embedding_str,
+                    tags or [],
+                    max_tokens,
+                    json.dumps(trigger) if trigger else None,
                 )
-
-
-            fact_counts = await acquired_conn.fetch(
+            else:
+                row = await conn.fetchrow(
                     f"""
-
-
-
-
-
+                    INSERT INTO {fq_table("mental_models")}
+                        (bank_id, name, source_query, content, embedding, tags, max_tokens, trigger)
+                    VALUES ($1, $2, $3, $4, $5, $6, COALESCE($7, 2048), COALESCE($8, '{{"refresh_after_consolidation": false}}'::jsonb))
+                    RETURNING id, bank_id, name, source_query, content, tags,
+                              last_refreshed_at, created_at, reflect_response,
+                              max_tokens, trigger
                     """,
-                entity_uuids,
                     bank_id,
+                    name,
+                    source_query,
+                    content,
+                    embedding_str,
+                    tags or [],
+                    max_tokens,
+                    json.dumps(trigger) if trigger else None,
                 )
-            entity_fact_counts = {row["entity_id"]: row["cnt"] for row in fact_counts}

-
-
-        for entity_id in entity_ids:
-            entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
-            if entity_uuid not in entity_names:
-                continue
-            fact_count = entity_fact_counts.get(entity_uuid, 0)
-            if fact_count >= min_facts:
-                entities_to_process.append((entity_id, entity_names[entity_uuid]))
+        logger.info(f"[MENTAL_MODELS] Created pinned mental model '{name}' for bank {bank_id}")
+        return self._row_to_mental_model(row)

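A hypothetical call site for `create_mental_model` as defined above; the `engine` instance and `ctx` request context are assumptions, not part of the diff:

```python
async def pin_team_model(engine, ctx):
    model = await engine.create_mental_model(
        "my-bank",
        name="Team working style",
        source_query="How does the team prefer to collaborate?",
        content="Prefers async written updates; pairs on tricky changes.",
        tags=["team:core"],
        request_context=ctx,
    )
    return model["id"]
```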
-
-
+    async def refresh_mental_model(
+        self,
+        bank_id: str,
+        mental_model_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Refresh a pinned mental model by re-running its source query.

-
-
+        This method:
+        1. Gets the pinned mental model
+        2. Runs the source_query through reflect
+        3. Updates the content with the new synthesis
+        4. Updates last_refreshed_at

-
+        Args:
+            bank_id: Bank identifier
+            mental_model_id: Pinned mental model UUID
+            request_context: Request context for authentication

-
-
-
-
-
-
-
-
+        Returns:
+            Updated pinned mental model dict or None if not found
+        """
+        await self._authenticate_tenant(request_context)
+
+        # Get the current mental model
+        mental_model = await self.get_mental_model(bank_id, mental_model_id, request_context=request_context)
+        if not mental_model:
+            return None
+
+        # Run reflect with the source query, excluding the mental model being refreshed
+        reflect_result = await self.reflect_async(
+            bank_id=bank_id,
+            query=mental_model["source_query"],
+            request_context=request_context,
+            exclude_mental_model_ids=[mental_model_id],
+        )
+
+        # Build reflect_response payload to store
+        reflect_response_payload = {
+            "text": reflect_result.text,
+            "based_on": {
+                fact_type: [
+                    {
+                        "id": str(fact.id),
+                        "text": fact.text,
+                        "type": fact_type,
+                    }
+                    for fact in facts
+                ]
+                for fact_type, facts in reflect_result.based_on.items()
+            },
+            "mental_models": [],  # Mental models are included in based_on["mental-models"]
+        }

-
+        # Update the mental model with new content and reflect_response
+        return await self.update_mental_model(
+            bank_id,
+            mental_model_id,
+            content=reflect_result.text,
+            reflect_response=reflect_response_payload,
+            request_context=request_context,
+        )
+
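A hypothetical refresh call (engine/ctx assumed); note that `exclude_mental_model_ids` keeps the model from citing its own previous content during the re-run:

```python
async def refresh(engine, ctx, model_id: str):
    refreshed = await engine.refresh_mental_model("my-bank", model_id, request_context=ctx)
    return refreshed["content"] if refreshed else None
```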
+    async def update_mental_model(
+        self,
+        bank_id: str,
+        mental_model_id: str,
+        *,
+        name: str | None = None,
+        content: str | None = None,
+        source_query: str | None = None,
+        max_tokens: int | None = None,
+        tags: list[str] | None = None,
+        trigger: dict[str, Any] | None = None,
+        reflect_response: dict[str, Any] | None = None,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Update a pinned mental model.
+
+        Args:
+            bank_id: Bank identifier
+            mental_model_id: Pinned mental model UUID
+            name: New name (if changing)
+            content: New content (if changing)
+            source_query: New source query (if changing)
+            max_tokens: New max tokens (if changing)
+            tags: New tags (if changing)
+            trigger: New trigger settings (if changing)
+            reflect_response: Full reflect API response payload (if changing)
+            request_context: Request context for authentication

-
+        Returns:
+            Updated pinned mental model dict or None if not found
         """
-
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        async with acquire_with_retry(pool) as conn:
+            # Build dynamic update
+            updates = []
+            params: list[Any] = [bank_id, mental_model_id]
+            param_idx = 3
+
+            if name is not None:
+                updates.append(f"name = ${param_idx}")
+                params.append(name)
+                param_idx += 1
+
+            if content is not None:
+                updates.append(f"content = ${param_idx}")
+                params.append(content)
+                param_idx += 1
+                updates.append("last_refreshed_at = NOW()")
+                # Also update embedding (convert to string for asyncpg vector type)
+                embedding_text = f"{name or ''} {content}"
+                embedding = await embedding_utils.generate_embeddings_batch(self.embeddings, [embedding_text])
+                if embedding:
+                    updates.append(f"embedding = ${param_idx}")
+                    params.append(str(embedding[0]))
+                    param_idx += 1
+
+            if reflect_response is not None:
+                updates.append(f"reflect_response = ${param_idx}")
+                params.append(json.dumps(reflect_response))
+                param_idx += 1
+
+            if source_query is not None:
+                updates.append(f"source_query = ${param_idx}")
+                params.append(source_query)
+                param_idx += 1
+
+            if max_tokens is not None:
+                updates.append(f"max_tokens = ${param_idx}")
+                params.append(max_tokens)
+                param_idx += 1
+
+            if tags is not None:
+                updates.append(f"tags = ${param_idx}")
+                params.append(tags)
+                param_idx += 1
+
+            if trigger is not None:
+                updates.append(f"trigger = ${param_idx}")
+                params.append(json.dumps(trigger))
+                param_idx += 1
+
+            if not updates:
+                return None
+
+            query = f"""
+                UPDATE {fq_table("mental_models")}
+                SET {", ".join(updates)}
+                WHERE bank_id = $1 AND id = $2
+                RETURNING id, bank_id, name, source_query, content, tags,
+                          last_refreshed_at, created_at, reflect_response,
+                          max_tokens, trigger
+            """
+
+            row = await conn.fetchrow(query, *params)
+
+            return self._row_to_mental_model(row) if row else None
+
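The `$n` bookkeeping in `update_mental_model` (fixed WHERE parameters first, then one positional slot per SET clause) generalizes to a small helper. A self-contained sketch of the pattern, not the shipped code:

```python
def build_update(table: str, key: dict, changes: dict) -> tuple[str, list]:
    params = list(key.values())          # $1, $2, ... fixed WHERE params first
    sets, idx = [], len(params) + 1
    for col, val in changes.items():
        sets.append(f"{col} = ${idx}")   # each SET clause takes the next $n slot
        params.append(val)
        idx += 1
    where = " AND ".join(f"{c} = ${i + 1}" for i, c in enumerate(key))
    return f"UPDATE {table} SET {', '.join(sets)} WHERE {where}", params

sql, params = build_update("mental_models", {"bank_id": "b1", "id": "m1"}, {"name": "new"})
# -> "UPDATE mental_models SET name = $3 WHERE bank_id = $1 AND id = $2"
```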
+    async def delete_mental_model(
+        self,
+        bank_id: str,
+        mental_model_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> bool:
+        """Delete a pinned mental model.

         Args:
-
-
-
+            bank_id: Bank identifier
+            mental_model_id: Pinned mental model UUID
+            request_context: Request context for authentication

-
-
-            Exception: Any exception from regenerate_entity_observations (propagates to execute_task for retry)
+        Returns:
+            True if deleted, False if not found
         """
-
-
-        from hindsight_api.models import RequestContext
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()

-
+        async with acquire_with_retry(pool) as conn:
+            result = await conn.execute(
+                f"DELETE FROM {fq_table('mental_models')} WHERE bank_id = $1 AND id = $2",
+                bank_id,
+                mental_model_id,
+            )

-
-        if "entity_ids" in task_dict:
-            entity_ids = task_dict.get("entity_ids", [])
-            min_facts = task_dict.get("min_facts", 5)
+            return result == "DELETE 1"

-
-
+    def _row_to_mental_model(self, row) -> dict[str, Any]:
+        """Convert a database row to a mental model dict."""
+        reflect_response = row.get("reflect_response")
+        # Parse JSON string to dict if needed (asyncpg may return JSONB as string)
+        if isinstance(reflect_response, str):
+            try:
+                reflect_response = json.loads(reflect_response)
+            except json.JSONDecodeError:
+                reflect_response = None
+        trigger = row.get("trigger")
+        if isinstance(trigger, str):
+            try:
+                trigger = json.loads(trigger)
+            except json.JSONDecodeError:
+                trigger = None
+        return {
+            "id": str(row["id"]),
+            "bank_id": row["bank_id"],
+            "name": row["name"],
+            "source_query": row["source_query"],
+            "content": row["content"],
+            "tags": row["tags"] or [],
+            "max_tokens": row.get("max_tokens"),
+            "trigger": trigger,
+            "last_refreshed_at": row["last_refreshed_at"].isoformat() if row["last_refreshed_at"] else None,
+            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
+            "reflect_response": reflect_response,
+        }

3757
|
-
|
|
3758
|
-
|
|
3759
|
-
|
|
3760
|
-
for entity_id in entity_ids:
|
|
3761
|
-
try:
|
|
3762
|
-
# Fetch entity name and check fact count
|
|
3763
|
-
import uuid as uuid_module
|
|
4949
|
+
# =========================================================================
|
|
4950
|
+
# Directives - Hard rules injected into prompts
|
|
4951
|
+
# =========================================================================
|
|
3764
4952
|
|
|
3765
|
-
|
|
4953
|
+
async def list_directives(
|
|
4954
|
+
self,
|
|
4955
|
+
bank_id: str,
|
|
4956
|
+
*,
|
|
4957
|
+
tags: list[str] | None = None,
|
|
4958
|
+
tags_match: str = "any",
|
|
4959
|
+
active_only: bool = True,
|
|
4960
|
+
limit: int = 100,
|
|
4961
|
+
offset: int = 0,
|
|
4962
|
+
request_context: "RequestContext",
|
|
4963
|
+
) -> list[dict[str, Any]]:
|
|
4964
|
+
"""List directives for a bank.
|
|
3766
4965
|
|
|
3767
|
-
|
|
3768
|
-
|
|
3769
|
-
|
|
3770
|
-
|
|
3771
|
-
|
|
3772
|
-
|
|
4966
|
+
Args:
|
|
4967
|
+
bank_id: Bank identifier
|
|
4968
|
+
tags: Optional tags to filter by
|
|
4969
|
+
tags_match: How to match tags - 'any', 'all', or 'exact'
|
|
4970
|
+
active_only: Only return active directives (default True)
|
|
4971
|
+
limit: Maximum number of results
|
|
4972
|
+
offset: Offset for pagination
|
|
4973
|
+
request_context: Request context for authentication
|
|
3773
4974
|
|
|
3774
|
-
|
|
3775
|
-
|
|
3776
|
-
|
|
4975
|
+
Returns:
|
|
4976
|
+
List of directive dicts
|
|
4977
|
+
"""
|
|
4978
|
+
await self._authenticate_tenant(request_context)
|
|
4979
|
+
pool = await self._get_pool()
|
|
3777
4980
|
|
|
3778
|
-
|
|
4981
|
+
async with acquire_with_retry(pool) as conn:
|
|
4982
|
+
# Build filters
|
|
4983
|
+
filters = ["bank_id = $1"]
|
|
4984
|
+
params: list[Any] = [bank_id]
|
|
4985
|
+
param_idx = 2
|
|
4986
|
+
|
|
4987
|
+
if active_only:
|
|
4988
|
+
filters.append("is_active = TRUE")
|
|
4989
|
+
|
|
4990
|
+
if tags:
|
|
4991
|
+
if tags_match == "all":
|
|
4992
|
+
filters.append(f"tags @> ${param_idx}::varchar[]")
|
|
4993
|
+
elif tags_match == "exact":
|
|
4994
|
+
filters.append(f"tags = ${param_idx}::varchar[]")
|
|
4995
|
+
else: # any
|
|
4996
|
+
filters.append(f"tags && ${param_idx}::varchar[]")
|
|
4997
|
+
params.append(tags)
|
|
4998
|
+
param_idx += 1
|
|
4999
|
+
|
|
5000
|
+
params.extend([limit, offset])
|
|
3779
5001
|
|
|
3780
|
-
|
|
3781
|
-
|
|
3782
|
-
|
|
3783
|
-
|
|
3784
|
-
|
|
3785
|
-
|
|
3786
|
-
|
|
3787
|
-
|
|
5002
|
+
rows = await conn.fetch(
|
|
5003
|
+
f"""
|
|
5004
|
+
SELECT id, bank_id, name, content, priority, is_active, tags, created_at, updated_at
|
|
5005
|
+
FROM {fq_table("directives")}
|
|
5006
|
+
WHERE {" AND ".join(filters)}
|
|
5007
|
+
ORDER BY priority DESC, created_at DESC
|
|
5008
|
+
LIMIT ${param_idx} OFFSET ${param_idx + 1}
|
|
5009
|
+
""",
|
|
5010
|
+
*params,
|
|
5011
|
+
)
|
|
3788
5012
|
|
|
3789
|
-
|
|
3790
|
-
if fact_count >= min_facts:
|
|
3791
|
-
await self.regenerate_entity_observations(
|
|
3792
|
-
bank_id, entity_id, entity_name, version=None, request_context=internal_context
|
|
3793
|
-
)
|
|
3794
|
-
else:
|
|
3795
|
-
logger.debug(
|
|
3796
|
-
f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)"
|
|
3797
|
-
)
|
|
5013
|
+
return [self._row_to_directive(row) for row in rows]
|
|
3798
5014
|
|
|
-
-
-
-
-
+    async def get_directive(
+        self,
+        bank_id: str,
+        directive_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Get a single directive by ID.
 
-
-
-
-
-        version = task_dict.get("version")
+        Args:
+            bank_id: Bank identifier
+            directive_id: Directive UUID
+            request_context: Request context for authentication
 
-
-
+        Returns:
+            Directive dict or None if not found
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
 
-
-
-
-            bank_id,
+        async with acquire_with_retry(pool) as conn:
+            row = await conn.fetchrow(
+                f"""
+                SELECT id, bank_id, name, content, priority, is_active, tags, created_at, updated_at
+                FROM {fq_table("directives")}
+                WHERE bank_id = $1 AND id = $2
+                """,
+                bank_id,
+                directive_id,
             )
 
-
-    # Statistics & Operations (for HTTP API layer)
-    # =========================================================================
+            return self._row_to_directive(row) if row else None
 
-    async def
+    async def create_directive(
         self,
         bank_id: str,
+        name: str,
+        content: str,
         *,
+        priority: int = 0,
+        is_active: bool = True,
+        tags: list[str] | None = None,
         request_context: "RequestContext",
     ) -> dict[str, Any]:
-        """
+        """Create a new directive.
+
+        Args:
+            bank_id: Bank identifier
+            name: Human-readable name for the directive
+            content: The directive text to inject into prompts
+            priority: Higher priority directives are injected first (default 0)
+            is_active: Whether this directive is active (default True)
+            tags: Optional tags for filtering
+            request_context: Request context for authentication
+
+        Returns:
+            The created directive dict
+        """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
 
         async with acquire_with_retry(pool) as conn:
-
-            node_stats = await conn.fetch(
+            row = await conn.fetchrow(
                 f"""
-
-
-
-
+                INSERT INTO {fq_table("directives")}
+                    (bank_id, name, content, priority, is_active, tags)
+                VALUES ($1, $2, $3, $4, $5, $6)
+                RETURNING id, bank_id, name, content, priority, is_active, tags, created_at, updated_at
                 """,
                 bank_id,
+                name,
+                content,
+                priority,
+                is_active,
+                tags or [],
             )
 
-
-
-                f"""
-                SELECT ml.link_type, COUNT(*) as count
-                FROM {fq_table("memory_links")} ml
-                JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
-                WHERE mu.bank_id = $1
-                GROUP BY ml.link_type
-                """,
-                bank_id,
-            )
+            logger.info(f"[DIRECTIVES] Created directive '{name}' for bank {bank_id}")
+            return self._row_to_directive(row)
 
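A create-then-read round trip, using the same assumed `engine`/`ctx` names; the returned dict carries every column in the RETURNING clause, already converted by `_row_to_directive`:

    directive = await engine.create_directive(
        "bank-1",
        name="tone",
        content="Always answer concisely.",
        priority=10,          # injected before lower-priority directives
        tags=["prod"],
        request_context=ctx,
    )
    fetched = await engine.get_directive("bank-1", directive["id"], request_context=ctx)
    assert fetched == directive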
-
-
-
-
-
-
-
-
-
-
-
+    async def update_directive(
+        self,
+        bank_id: str,
+        directive_id: str,
+        *,
+        name: str | None = None,
+        content: str | None = None,
+        priority: int | None = None,
+        is_active: bool | None = None,
+        tags: list[str] | None = None,
+        request_context: "RequestContext",
+    ) -> dict[str, Any] | None:
+        """Update a directive.
 
-
-
-
-
-
-
-
-
-
-                bank_id,
-            )
+        Args:
+            bank_id: Bank identifier
+            directive_id: Directive UUID
+            name: New name (optional)
+            content: New content (optional)
+            priority: New priority (optional)
+            is_active: New active status (optional)
+            tags: New tags (optional)
+            request_context: Request context for authentication
 
-
-
+        Returns:
+            Updated directive dict or None if not found
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        # Build update query dynamically
+        updates = ["updated_at = now()"]
+        params: list[Any] = []
+        param_idx = 1
+
+        if name is not None:
+            updates.append(f"name = ${param_idx}")
+            params.append(name)
+            param_idx += 1
+
+        if content is not None:
+            updates.append(f"content = ${param_idx}")
+            params.append(content)
+            param_idx += 1
+
+        if priority is not None:
+            updates.append(f"priority = ${param_idx}")
+            params.append(priority)
+            param_idx += 1
+
+        if is_active is not None:
+            updates.append(f"is_active = ${param_idx}")
+            params.append(is_active)
+            param_idx += 1
+
+        if tags is not None:
+            updates.append(f"tags = ${param_idx}")
+            params.append(tags)
+            param_idx += 1
+
+        params.extend([bank_id, directive_id])
+
+        async with acquire_with_retry(pool) as conn:
+            row = await conn.fetchrow(
                 f"""
-
-
-                WHERE bank_id = $1
-
+                UPDATE {fq_table("directives")}
+                SET {", ".join(updates)}
+                WHERE bank_id = ${param_idx} AND id = ${param_idx + 1}
+                RETURNING id, bank_id, name, content, priority, is_active, tags, created_at, updated_at
                 """,
-
+                *params,
             )
 
-        return {
-            "bank_id": bank_id,
-            "node_counts": {row["fact_type"]: row["count"] for row in node_stats},
-            "link_counts": {row["link_type"]: row["count"] for row in link_stats},
-            "link_counts_by_fact_type": {row["fact_type"]: row["count"] for row in link_fact_type_stats},
-            "link_breakdown": [
-                {"fact_type": row["fact_type"], "link_type": row["link_type"], "count": row["count"]}
-                for row in link_breakdown_stats
-            ],
-            "operations": {row["status"]: row["count"] for row in ops_stats},
-        }
+            return self._row_to_directive(row) if row else None
 
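The placeholder bookkeeping matters because asyncpg uses positional `$n` parameters: each provided field claims the next index, and `bank_id`/`id` always take the final two slots. An illustrative re-creation of what a partial update produces (the helper name is ours, not the package's):

    def build_update(fields: dict) -> tuple[str, list]:
        # Mirrors the dynamic SET construction above.
        updates, params, idx = ["updated_at = now()"], [], 1
        for col, val in fields.items():
            if val is not None:
                updates.append(f"{col} = ${idx}")
                params.append(val)
                idx += 1
        return f"SET {', '.join(updates)} WHERE bank_id = ${idx} AND id = ${idx + 1}", params

    sql, params = build_update({"content": "new text", "is_active": False})
    # SET updated_at = now(), content = $1, is_active = $2 WHERE bank_id = $3 AND id = $4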
-    async def
+    async def delete_directive(
         self,
         bank_id: str,
-
+        directive_id: str,
         *,
         request_context: "RequestContext",
-    ) ->
-        """
+    ) -> bool:
+        """Delete a directive.
+
+        Args:
+            bank_id: Bank identifier
+            directive_id: Directive UUID
+            request_context: Request context for authentication
+
+        Returns:
+            True if deleted, False if not found
+        """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
 
         async with acquire_with_retry(pool) as conn:
-
-                f"""
-                SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
-                FROM {fq_table("entities")}
-                WHERE bank_id = $1 AND id = $2
-                """,
+            result = await conn.execute(
+                f"DELETE FROM {fq_table('directives')} WHERE bank_id = $1 AND id = $2",
                 bank_id,
-
+                directive_id,
             )
 
-
-            return None
-
-        # Get observations for the entity
-        observations = await self.get_entity_observations(bank_id, entity_id, limit=20, request_context=request_context)
+            return result == "DELETE 1"
 
+    def _row_to_directive(self, row) -> dict[str, Any]:
+        """Convert a database row to a directive dict."""
         return {
-            "id": str(
-            "
-            "
-            "
-            "
-            "
-            "
+            "id": str(row["id"]),
+            "bank_id": row["bank_id"],
+            "name": row["name"],
+            "content": row["content"],
+            "priority": row["priority"],
+            "is_active": row["is_active"],
+            "tags": row["tags"] or [],
+            "created_at": row["created_at"].isoformat() if row["created_at"] else None,
+            "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
         }
 
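`delete_directive` leans on asyncpg's `Connection.execute` returning the server's command tag, so `"DELETE 1"` means exactly one row went away and `"DELETE 0"` means the id missed. Sketch, same assumed names:

    deleted = await engine.delete_directive("bank-1", directive["id"], request_context=ctx)
    if not deleted:
        # command tag was "DELETE 0": no such directive in this bank
        ...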
     async def list_operations(
         self,
         bank_id: str,
         *,
+        status: str | None = None,
+        limit: int = 20,
+        offset: int = 0,
         request_context: "RequestContext",
-    ) ->
-        """List async operations for a bank.
+    ) -> dict[str, Any]:
+        """List async operations for a bank with optional filtering and pagination.
+
+        Args:
+            bank_id: Bank identifier
+            status: Optional status filter (pending, completed, failed)
+            limit: Maximum number of operations to return (default 20)
+            offset: Number of operations to skip (default 0)
+            request_context: Request context for authentication
+
+        Returns:
+            Dict with total count and list of operations, sorted by most recent first
+        """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
 
         async with acquire_with_retry(pool) as conn:
+            # Build WHERE clause
+            where_conditions = ["bank_id = $1"]
+            params: list[Any] = [bank_id]
+
+            if status:
+                # Map API status to DB statuses (pending includes processing)
+                if status == "pending":
+                    where_conditions.append("status IN ('pending', 'processing')")
+                else:
+                    where_conditions.append(f"status = ${len(params) + 1}")
+                    params.append(status)
+
+            where_clause = " AND ".join(where_conditions)
+
+            # Get total count (with filter)
+            total_row = await conn.fetchrow(
+                f"SELECT COUNT(*) as total FROM {fq_table('async_operations')} WHERE {where_clause}",
+                *params,
+            )
+            total = total_row["total"] if total_row else 0
+
+            # Get operations with pagination
             operations = await conn.fetch(
                 f"""
-                SELECT operation_id,
+                SELECT operation_id, operation_type, created_at, status, error_message
                 FROM {fq_table("async_operations")}
-                WHERE
+                WHERE {where_clause}
                 ORDER BY created_at DESC
+                LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}
                 """,
-
+                *params,
+                limit,
+                offset,
             )
 
-
-
-
-
-
+            return {
+                "total": total,
+                "operations": [
+                    {
+                        "id": str(row["operation_id"]),
+                        "task_type": row["operation_type"],
+                        "items_count": 0,
+                        "document_id": None,
+                        "created_at": row["created_at"].isoformat(),
+                        # Map DB status to API status (processing -> pending for simplicity)
+                        "status": "pending" if row["status"] in ("pending", "processing") else row["status"],
+                        "error_message": row["error_message"],
+                    }
+                    for row in operations
+                ],
+            }
 
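Note the symmetric status mapping in `list_operations`: an API-side `pending` filter expands to `IN ('pending', 'processing')`, and DB rows are folded back the same way on output, so callers only ever see `pending`, `completed`, or `failed`. Same assumed names as above:

    page = await engine.list_operations("bank-1", status="pending", limit=10, request_context=ctx)
    print(page["total"])
    for op in page["operations"]:
        print(op["id"], op["task_type"], op["status"])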
-
-
+    async def get_operation_status(
+        self,
+        bank_id: str,
+        operation_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Get the status of a specific async operation.
 
-
-
-
-
-
-
-
-
+        Returns:
+            - status: "pending", "completed", or "failed"
+            - updated_at: last update timestamp
+            - completed_at: completion timestamp (if completed)
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+
+        op_uuid = uuid.UUID(operation_id)
+
+        async with acquire_with_retry(pool) as conn:
+            row = await conn.fetchrow(
+                f"""
+                SELECT operation_id, operation_type, created_at, updated_at, completed_at, status, error_message
+                FROM {fq_table("async_operations")}
+                WHERE operation_id = $1 AND bank_id = $2
+                """,
+                op_uuid,
+                bank_id,
+            )
+
+            if row:
+                # Map DB status to API status (processing -> pending for simplicity)
+                db_status = row["status"]
+                api_status = "pending" if db_status in ("pending", "processing") else db_status
+                return {
+                    "operation_id": operation_id,
+                    "status": api_status,
+                    "operation_type": row["operation_type"],
+                    "created_at": row["created_at"].isoformat() if row["created_at"] else None,
+                    "updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
+                    "completed_at": row["completed_at"].isoformat() if row["completed_at"] else None,
                     "error_message": row["error_message"],
                 }
-
-
+            else:
+                # Operation not found
+                return {
+                    "operation_id": operation_id,
+                    "status": "not_found",
+                    "operation_type": None,
+                    "created_at": None,
+                    "updated_at": None,
+                    "completed_at": None,
+                    "error_message": None,
+                }
 
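Because a missing operation is reported in-band as `status == "not_found"` rather than by raising, a poller only has to branch on the status string. A minimal sketch built on the assumed names above:

    import asyncio

    async def wait_for(engine, bank_id, operation_id, ctx, interval=2.0):
        # Poll until the operation leaves the pending state.
        while True:
            op = await engine.get_operation_status(bank_id, operation_id, request_context=ctx)
            if op["status"] in ("completed", "failed", "not_found"):
                return op
            await asyncio.sleep(interval)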
     async def cancel_operation(
         self,
@@ -4022,10 +5381,10 @@ Guidelines:
         bank_id: str,
         *,
         name: str | None = None,
-
+        mission: str | None = None,
         request_context: "RequestContext",
     ) -> dict[str, Any]:
-        """Update bank name and/or
+        """Update bank name and/or mission."""
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
 
@@ -4041,33 +5400,72 @@ Guidelines:
                 name,
             )
 
-        if
+        if mission is not None:
             await conn.execute(
                 f"""
                 UPDATE {fq_table("banks")}
-                SET
+                SET mission = $2, updated_at = NOW()
                 WHERE bank_id = $1
                 """,
                 bank_id,
-
+                mission,
             )
 
         # Return updated profile
         return await self.get_bank_profile(bank_id, request_context=request_context)
 
-    async def
+    async def _submit_async_operation(
         self,
         bank_id: str,
-
+        operation_type: str,
+        task_type: str,
+        task_payload: dict[str, Any],
         *,
-
+        result_metadata: dict[str, Any] | None = None,
+        dedupe_by_bank: bool = False,
     ) -> dict[str, Any]:
-        """
-
-
+        """Generic helper to submit an async operation.
+
+        Args:
+            bank_id: Bank identifier
+            operation_type: Operation type for the async_operations record (e.g., 'consolidation', 'retain')
+            task_type: Task type for the task payload (e.g., 'consolidation', 'batch_retain')
+            task_payload: Additional task payload fields (operation_id and bank_id are added automatically)
+            result_metadata: Optional metadata to store with the operation record
+            dedupe_by_bank: If True, skip creating a new task if one is already pending for this bank+operation_type
 
+        Returns:
+            Dict with operation_id and optionally deduplicated=True if an existing task was found
+        """
         import json
 
+        pool = await self._get_pool()
+
+        # Check for existing pending task if deduplication is enabled
+        # Note: We only check 'pending', not 'processing', because a processing task
+        # uses a watermark from when it started - new memories added after that point
+        # would need another consolidation run to be processed.
+        if dedupe_by_bank:
+            async with acquire_with_retry(pool) as conn:
+                existing = await conn.fetchrow(
+                    f"""
+                    SELECT operation_id FROM {fq_table("async_operations")}
+                    WHERE bank_id = $1 AND operation_type = $2 AND status = 'pending'
+                    LIMIT 1
+                    """,
+                    bank_id,
+                    operation_type,
+                )
+                if existing:
+                    logger.debug(
+                        f"{operation_type} task already pending for bank_id={bank_id}, "
+                        f"skipping duplicate (existing operation_id={existing['operation_id']})"
+                    )
+                    return {
+                        "operation_id": str(existing["operation_id"]),
+                        "deduplicated": True,
+                    }
+
         operation_id = uuid.uuid4()
 
         # Insert operation record into database
@@ -4079,23 +5477,113 @@ Guidelines:
             """,
             operation_id,
             bank_id,
-
-            json.dumps({
+            operation_type,
+            json.dumps(result_metadata or {}),
         )
 
-        #
-
-
-
-
-
-
-
-        )
+        # Build and submit task payload
+        full_payload = {
+            "type": task_type,
+            "operation_id": str(operation_id),
+            "bank_id": bank_id,
+            **task_payload,
+        }
+
+        await self._task_backend.submit_task(full_payload)
 
-        logger.info(f"
+        logger.info(f"{operation_type} task queued for bank_id={bank_id}, operation_id={operation_id}")
 
         return {
             "operation_id": str(operation_id),
-            "items_count": len(contents),
         }
+
|
+
async def submit_async_retain(
|
|
5501
|
+
self,
|
|
5502
|
+
bank_id: str,
|
|
5503
|
+
contents: list[dict[str, Any]],
|
|
5504
|
+
*,
|
|
5505
|
+
request_context: "RequestContext",
|
|
5506
|
+
document_tags: list[str] | None = None,
|
|
5507
|
+
) -> dict[str, Any]:
|
|
5508
|
+
"""Submit a batch retain operation to run asynchronously."""
|
|
5509
|
+
await self._authenticate_tenant(request_context)
|
|
5510
|
+
|
|
5511
|
+
task_payload: dict[str, Any] = {"contents": contents}
|
|
5512
|
+
if document_tags:
|
|
5513
|
+
task_payload["document_tags"] = document_tags
|
|
5514
|
+
|
|
5515
|
+
result = await self._submit_async_operation(
|
|
5516
|
+
bank_id=bank_id,
|
|
5517
|
+
operation_type="retain",
|
|
5518
|
+
task_type="batch_retain",
|
|
5519
|
+
task_payload=task_payload,
|
|
5520
|
+
result_metadata={"items_count": len(contents)},
|
|
5521
|
+
dedupe_by_bank=False,
|
|
5522
|
+
)
|
|
5523
|
+
|
|
5524
|
+
result["items_count"] = len(contents)
|
|
5525
|
+
return result
|
|
5526
|
+
|
|
+    async def submit_async_consolidation(
+        self,
+        bank_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Submit a consolidation operation to run asynchronously.
+
+        Deduplicates by bank_id - if there's already a pending consolidation for this bank,
+        returns the existing operation_id instead of creating a new one.
+
+        Args:
+            bank_id: Bank identifier
+            request_context: Request context for authentication
+
+        Returns:
+            Dict with operation_id
+        """
+        await self._authenticate_tenant(request_context)
+        return await self._submit_async_operation(
+            bank_id=bank_id,
+            operation_type="consolidation",
+            task_type="consolidation",
+            task_payload={},
+            dedupe_by_bank=True,
+        )
+
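The `dedupe_by_bank=True` path means back-to-back consolidation requests collapse onto one queued task; the second call signals the short-circuit via `deduplicated`. Sketch, same assumed names:

    first = await engine.submit_async_consolidation("bank-1", request_context=ctx)
    second = await engine.submit_async_consolidation("bank-1", request_context=ctx)
    # While the first task is still 'pending', the second call returns its id.
    assert second["operation_id"] == first["operation_id"]
    assert second.get("deduplicated") is True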
+    async def submit_async_refresh_mental_model(
+        self,
+        bank_id: str,
+        mental_model_id: str,
+        *,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """Submit an async mental model refresh operation.
+
+        This schedules a background task to re-run the source query and update the content.
+
+        Args:
+            bank_id: Bank identifier
+            mental_model_id: Mental model UUID to refresh
+            request_context: Request context for authentication
+
+        Returns:
+            Dict with operation_id
+        """
+        await self._authenticate_tenant(request_context)
+
+        # Verify mental model exists
+        mental_model = await self.get_mental_model(bank_id, mental_model_id, request_context=request_context)
+        if not mental_model:
+            raise ValueError(f"Mental model {mental_model_id} not found in bank {bank_id}")
+
+        return await self._submit_async_operation(
+            bank_id=bank_id,
+            operation_type="refresh_mental_model",
+            task_type="refresh_mental_model",
+            task_payload={
+                "mental_model_id": mental_model_id,
+            },
+            result_metadata={"mental_model_id": mental_model_id, "name": mental_model["name"]},
+            dedupe_by_bank=False,
+        )
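`submit_async_refresh_mental_model` is the one submitter that validates before queueing: an unknown id raises `ValueError` synchronously, while the refresh itself completes in the background (pollable with the `wait_for` sketch shown earlier). Same assumed names:

    try:
        op = await engine.submit_async_refresh_mental_model(
            "bank-1", mental_model_id, request_context=ctx
        )
    except ValueError:
        ...  # no such mental model in this bank
    else:
        final = await wait_for(engine, "bank-1", op["operation_id"], ctx)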