hindsight-api 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/admin/__init__.py +1 -0
- hindsight_api/admin/cli.py +252 -0
- hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
- hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
- hindsight_api/api/http.py +282 -20
- hindsight_api/api/mcp.py +47 -52
- hindsight_api/config.py +238 -6
- hindsight_api/engine/cross_encoder.py +599 -86
- hindsight_api/engine/db_budget.py +284 -0
- hindsight_api/engine/db_utils.py +11 -0
- hindsight_api/engine/embeddings.py +453 -26
- hindsight_api/engine/entity_resolver.py +8 -5
- hindsight_api/engine/interface.py +8 -4
- hindsight_api/engine/llm_wrapper.py +241 -27
- hindsight_api/engine/memory_engine.py +609 -122
- hindsight_api/engine/query_analyzer.py +4 -3
- hindsight_api/engine/response_models.py +38 -0
- hindsight_api/engine/retain/fact_extraction.py +388 -192
- hindsight_api/engine/retain/fact_storage.py +34 -8
- hindsight_api/engine/retain/link_utils.py +24 -16
- hindsight_api/engine/retain/orchestrator.py +52 -17
- hindsight_api/engine/retain/types.py +9 -0
- hindsight_api/engine/search/graph_retrieval.py +42 -13
- hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
- hindsight_api/engine/search/mpfp_retrieval.py +362 -117
- hindsight_api/engine/search/reranking.py +2 -2
- hindsight_api/engine/search/retrieval.py +847 -200
- hindsight_api/engine/search/tags.py +172 -0
- hindsight_api/engine/search/think_utils.py +1 -1
- hindsight_api/engine/search/trace.py +12 -0
- hindsight_api/engine/search/tracer.py +24 -1
- hindsight_api/engine/search/types.py +21 -0
- hindsight_api/engine/task_backend.py +109 -18
- hindsight_api/engine/utils.py +1 -1
- hindsight_api/extensions/context.py +10 -1
- hindsight_api/main.py +56 -4
- hindsight_api/metrics.py +433 -48
- hindsight_api/migrations.py +141 -1
- hindsight_api/models.py +3 -1
- hindsight_api/pg0.py +53 -0
- hindsight_api/server.py +39 -2
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
- hindsight_api-0.3.0.dist-info/RECORD +82 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.2.1.dist-info/RECORD +0 -75
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0
--- a/hindsight_api/engine/memory_engine.py
+++ b/hindsight_api/engine/memory_engine.py
@@ -18,6 +18,8 @@ from datetime import UTC, datetime, timedelta
 from typing import TYPE_CHECKING, Any
 
 from ..config import get_config
+from ..metrics import get_metrics_collector
+from .db_budget import budgeted_operation
 
 # Context variable for current schema (async-safe, per-task isolation)
 _current_schema: contextvars.ContextVar[str] = contextvars.ContextVar("current_schema", default="public")
@@ -132,17 +134,25 @@ if TYPE_CHECKING:
 
 from enum import Enum
 
-from ..pg0 import EmbeddedPostgres
+from ..pg0 import EmbeddedPostgres, parse_pg0_url
 from .entity_resolver import EntityResolver
 from .llm_wrapper import LLMConfig
 from .query_analyzer import QueryAnalyzer
-from .response_models import
+from .response_models import (
+    VALID_RECALL_FACT_TYPES,
+    EntityObservation,
+    EntityState,
+    MemoryFact,
+    ReflectResult,
+    TokenUsage,
+)
 from .response_models import RecallResult as RecallResultModel
 from .retain import bank_utils, embedding_utils
 from .retain.types import RetainContentDict
 from .search import observation_utils, think_utils
 from .search.reranking import CrossEncoderReranker
-from .
+from .search.tags import TagsMatch
+from .task_backend import AsyncIOQueueBackend, NoopTaskBackend, TaskBackend
 
 
 class Budget(str, Enum):
@@ -195,12 +205,25 @@ class MemoryEngine(MemoryEngineInterface):
         memory_llm_api_key: str | None = None,
         memory_llm_model: str | None = None,
         memory_llm_base_url: str | None = None,
+        # Per-operation LLM config (optional, falls back to memory_llm_* params)
+        retain_llm_provider: str | None = None,
+        retain_llm_api_key: str | None = None,
+        retain_llm_model: str | None = None,
+        retain_llm_base_url: str | None = None,
+        reflect_llm_provider: str | None = None,
+        reflect_llm_api_key: str | None = None,
+        reflect_llm_model: str | None = None,
+        reflect_llm_base_url: str | None = None,
         embeddings: Embeddings | None = None,
         cross_encoder: CrossEncoderModel | None = None,
         query_analyzer: QueryAnalyzer | None = None,
-        pool_min_size: int =
-        pool_max_size: int =
+        pool_min_size: int | None = None,
+        pool_max_size: int | None = None,
+        db_command_timeout: int | None = None,
+        db_acquire_timeout: int | None = None,
         task_backend: TaskBackend | None = None,
+        task_batch_size: int | None = None,
+        task_batch_interval: float | None = None,
         run_migrations: bool = True,
        operation_validator: "OperationValidatorExtension | None" = None,
        tenant_extension: "TenantExtension | None" = None,
@@ -220,12 +243,24 @@ class MemoryEngine(MemoryEngineInterface):
             memory_llm_api_key: API key for the LLM provider. Defaults to HINDSIGHT_API_LLM_API_KEY env var.
             memory_llm_model: Model name. Defaults to HINDSIGHT_API_LLM_MODEL env var.
             memory_llm_base_url: Base URL for the LLM API. Defaults based on provider.
+            retain_llm_provider: LLM provider for retain operations. Falls back to memory_llm_provider.
+            retain_llm_api_key: API key for retain LLM. Falls back to memory_llm_api_key.
+            retain_llm_model: Model for retain operations. Falls back to memory_llm_model.
+            retain_llm_base_url: Base URL for retain LLM. Falls back to memory_llm_base_url.
+            reflect_llm_provider: LLM provider for reflect operations. Falls back to memory_llm_provider.
+            reflect_llm_api_key: API key for reflect LLM. Falls back to memory_llm_api_key.
+            reflect_llm_model: Model for reflect operations. Falls back to memory_llm_model.
+            reflect_llm_base_url: Base URL for reflect LLM. Falls back to memory_llm_base_url.
             embeddings: Embeddings implementation. If not provided, created from env vars.
             cross_encoder: Cross-encoder model. If not provided, created from env vars.
             query_analyzer: Query analyzer implementation. If not provided, uses DateparserQueryAnalyzer.
-            pool_min_size: Minimum number of connections in the pool
-            pool_max_size: Maximum number of connections in the pool
+            pool_min_size: Minimum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MIN_SIZE.
+            pool_max_size: Maximum number of connections in the pool. Defaults to HINDSIGHT_API_DB_POOL_MAX_SIZE.
+            db_command_timeout: PostgreSQL command timeout in seconds. Defaults to HINDSIGHT_API_DB_COMMAND_TIMEOUT.
+            db_acquire_timeout: Connection acquisition timeout in seconds. Defaults to HINDSIGHT_API_DB_ACQUIRE_TIMEOUT.
             task_backend: Custom task backend. If not provided, uses AsyncIOQueueBackend.
+            task_batch_size: Background task batch size. Defaults to HINDSIGHT_API_TASK_BACKEND_MEMORY_BATCH_SIZE.
+            task_batch_interval: Background task batch interval in seconds. Defaults to HINDSIGHT_API_TASK_BACKEND_MEMORY_BATCH_INTERVAL.
             run_migrations: Whether to run database migrations during initialize(). Default: True
             operation_validator: Optional extension to validate operations before execution.
                 If provided, retain/recall/reflect operations will be validated.
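
A minimal construction sketch for the per-operation LLM and pool/timeout parameters added above. The import path, provider name, model names, and values are illustrative assumptions, not taken from the diff:

    from hindsight_api.engine.memory_engine import MemoryEngine  # assumed module path

    engine = MemoryEngine(
        db_url="postgresql://localhost:5432/hindsight",
        memory_llm_provider="groq",          # default config, used as the fallback
        memory_llm_api_key="...",            # or HINDSIGHT_API_LLM_API_KEY
        memory_llm_model="default-model",    # placeholder model names
        retain_llm_model="larger-model",     # stronger model for fact extraction
        reflect_llm_model="smaller-model",   # lighter model for think/observe
        pool_min_size=2,                     # falls back to HINDSIGHT_API_DB_POOL_MIN_SIZE if omitted
        pool_max_size=20,
        db_command_timeout=60,
        db_acquire_timeout=10,
        task_batch_size=50,
        task_batch_interval=1.0,
    )
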
@@ -252,38 +287,21 @@ class MemoryEngine(MemoryEngineInterface):
         db_url = db_url or config.database_url
         memory_llm_provider = memory_llm_provider or config.llm_provider
         memory_llm_api_key = memory_llm_api_key or config.llm_api_key
-        # Ollama
-        if not memory_llm_api_key and memory_llm_provider
+        # Ollama and mock don't require an API key
+        if not memory_llm_api_key and memory_llm_provider not in ("ollama", "mock"):
             raise ValueError("LLM API key is required. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
         memory_llm_model = memory_llm_model or config.llm_model
         memory_llm_base_url = memory_llm_base_url or config.get_llm_base_url() or None
         # Track pg0 instance (if used)
         self._pg0: EmbeddedPostgres | None = None
-        self._pg0_instance_name: str | None = None
 
         # Initialize PostgreSQL connection URL
         # The actual URL will be set during initialize() after starting the server
         # Supports: "pg0" (default instance), "pg0://instance-name" (named instance), or regular postgresql:// URL
-
-
-            self._pg0_instance_name = "hindsight"
-            self._pg0_port = None  # Use default port
-            self.db_url = None
-        elif db_url.startswith("pg0://"):
-            self._use_pg0 = True
-            # Parse instance name and optional port: pg0://instance-name or pg0://instance-name:port
-            url_part = db_url[6:]  # Remove "pg0://"
-            if ":" in url_part:
-                self._pg0_instance_name, port_str = url_part.rsplit(":", 1)
-                self._pg0_port = int(port_str)
-            else:
-                self._pg0_instance_name = url_part or "hindsight"
-                self._pg0_port = None  # Use default port
+        self._use_pg0, self._pg0_instance_name, self._pg0_port = parse_pg0_url(db_url)
+        if self._use_pg0:
             self.db_url = None
         else:
-            self._use_pg0 = False
-            self._pg0_instance_name = None
-            self._pg0_port = None
             self.db_url = db_url
 
         # Set default base URL if not provided
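
The inline pg0 URL handling removed above moves into parse_pg0_url. Judging only from the deleted branches and the three-way unpacking, the helper should behave roughly like this; a sketch of expected return values, not the helper's documented contract:

    from hindsight_api.pg0 import parse_pg0_url  # assumed import path

    parse_pg0_url("pg0")                        # (True, "hindsight", None)  default embedded instance
    parse_pg0_url("pg0://analytics")            # (True, "analytics", None)  named instance, default port
    parse_pg0_url("pg0://analytics:5433")       # (True, "analytics", 5433)  named instance, explicit port
    parse_pg0_url("postgresql://db:5432/mem")   # (False, None, None)        regular PostgreSQL URL
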
@@ -298,8 +316,10 @@ class MemoryEngine(MemoryEngineInterface):
         # Connection pool (will be created in initialize())
         self._pool = None
         self._initialized = False
-        self._pool_min_size = pool_min_size
-        self._pool_max_size = pool_max_size
+        self._pool_min_size = pool_min_size if pool_min_size is not None else config.db_pool_min_size
+        self._pool_max_size = pool_max_size if pool_max_size is not None else config.db_pool_max_size
+        self._db_command_timeout = db_command_timeout if db_command_timeout is not None else config.db_command_timeout
+        self._db_acquire_timeout = db_acquire_timeout if db_acquire_timeout is not None else config.db_acquire_timeout
         self._run_migrations = run_migrations
 
         # Initialize entity resolver (will be created in initialize())
@@ -319,7 +339,7 @@ class MemoryEngine(MemoryEngineInterface):
 
         self.query_analyzer = DateparserQueryAnalyzer()
 
-        # Initialize LLM configuration
+        # Initialize LLM configuration (default, used as fallback)
         self._llm_config = LLMConfig(
             provider=memory_llm_provider,
             api_key=memory_llm_api_key,
@@ -331,17 +351,68 @@ class MemoryEngine(MemoryEngineInterface):
         self._llm_client = self._llm_config._client
         self._llm_model = self._llm_config.model
 
+        # Initialize per-operation LLM configs (fall back to default if not specified)
+        # Retain LLM config - for fact extraction (benefits from strong structured output)
+        retain_provider = retain_llm_provider or config.retain_llm_provider or memory_llm_provider
+        retain_api_key = retain_llm_api_key or config.retain_llm_api_key or memory_llm_api_key
+        retain_model = retain_llm_model or config.retain_llm_model or memory_llm_model
+        retain_base_url = retain_llm_base_url or config.retain_llm_base_url or memory_llm_base_url
+        # Apply provider-specific base URL defaults for retain
+        if retain_base_url is None:
+            if retain_provider.lower() == "groq":
+                retain_base_url = "https://api.groq.com/openai/v1"
+            elif retain_provider.lower() == "ollama":
+                retain_base_url = "http://localhost:11434/v1"
+            else:
+                retain_base_url = ""
+
+        self._retain_llm_config = LLMConfig(
+            provider=retain_provider,
+            api_key=retain_api_key,
+            base_url=retain_base_url,
+            model=retain_model,
+        )
+
+        # Reflect LLM config - for think/observe operations (can use lighter models)
+        reflect_provider = reflect_llm_provider or config.reflect_llm_provider or memory_llm_provider
+        reflect_api_key = reflect_llm_api_key or config.reflect_llm_api_key or memory_llm_api_key
+        reflect_model = reflect_llm_model or config.reflect_llm_model or memory_llm_model
+        reflect_base_url = reflect_llm_base_url or config.reflect_llm_base_url or memory_llm_base_url
+        # Apply provider-specific base URL defaults for reflect
+        if reflect_base_url is None:
+            if reflect_provider.lower() == "groq":
+                reflect_base_url = "https://api.groq.com/openai/v1"
+            elif reflect_provider.lower() == "ollama":
+                reflect_base_url = "http://localhost:11434/v1"
+            else:
+                reflect_base_url = ""
+
+        self._reflect_llm_config = LLMConfig(
+            provider=reflect_provider,
+            api_key=reflect_api_key,
+            base_url=reflect_base_url,
+            model=reflect_model,
+        )
+
         # Initialize cross-encoder reranker (cached for performance)
         self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)
 
         # Initialize task backend
-
+        if task_backend:
+            self._task_backend = task_backend
+        elif config.task_backend == "noop":
+            self._task_backend = NoopTaskBackend()
+        else:
+            # Default to memory (AsyncIOQueueBackend)
+            _task_batch_size = task_batch_size if task_batch_size is not None else config.task_backend_memory_batch_size
+            _task_batch_interval = (
+                task_batch_interval if task_batch_interval is not None else config.task_backend_memory_batch_interval
+            )
+            self._task_backend = AsyncIOQueueBackend(batch_size=_task_batch_size, batch_interval=_task_batch_interval)
 
         # Backpressure mechanism: limit concurrent searches to prevent overwhelming the database
-        #
-
-        # we use ~20-40 connections max, staying well within pool limits
-        self._search_semaphore = asyncio.Semaphore(10)
+        # Configurable via HINDSIGHT_API_RECALL_MAX_CONCURRENT (default: 50)
+        self._search_semaphore = asyncio.Semaphore(get_config().recall_max_concurrent)
 
         # Backpressure for put operations: limit concurrent puts to prevent database contention
         # Each put_batch holds a connection for the entire transaction, so we limit to 5
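
The hard-coded Semaphore(10) is gone; concurrency and batching now come from configuration. A hedged sketch of tuning them through the environment variables named in the comments and docstrings above (the actual parsing and defaults live in hindsight_api/config.py, which this diff does not show in full):

    import os

    # Set before the process loads its configuration.
    os.environ["HINDSIGHT_API_RECALL_MAX_CONCURRENT"] = "100"               # recall backpressure semaphore
    os.environ["HINDSIGHT_API_DB_POOL_MAX_SIZE"] = "40"                     # connection pool upper bound
    os.environ["HINDSIGHT_API_TASK_BACKEND_MEMORY_BATCH_SIZE"] = "50"       # AsyncIOQueueBackend batch size
    os.environ["HINDSIGHT_API_TASK_BACKEND_MEMORY_BATCH_INTERVAL"] = "1.0"  # seconds between batches
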
@@ -618,9 +689,27 @@ class MemoryEngine(MemoryEngineInterface):
             await loop.run_in_executor(None, self.query_analyzer.load)
 
         async def verify_llm():
-            """Verify LLM
+            """Verify LLM connections are working for all unique configs."""
             if not self._skip_llm_verification:
+                # Verify default config
                 await self._llm_config.verify_connection()
+                # Verify retain config if different from default
+                retain_is_different = (
+                    self._retain_llm_config.provider != self._llm_config.provider
+                    or self._retain_llm_config.model != self._llm_config.model
+                )
+                if retain_is_different:
+                    await self._retain_llm_config.verify_connection()
+                # Verify reflect config if different from default and retain
+                reflect_is_different = (
+                    self._reflect_llm_config.provider != self._llm_config.provider
+                    or self._reflect_llm_config.model != self._llm_config.model
+                ) and (
+                    self._reflect_llm_config.provider != self._retain_llm_config.provider
+                    or self._reflect_llm_config.model != self._retain_llm_config.model
+                )
+                if reflect_is_different:
+                    await self._reflect_llm_config.verify_connection()
 
         # Build list of initialization tasks
         init_tasks = [
@@ -642,13 +731,17 @@ class MemoryEngine(MemoryEngineInterface):
 
         # Run database migrations if enabled
         if self._run_migrations:
-            from ..migrations import run_migrations
+            from ..migrations import ensure_embedding_dimension, run_migrations
 
             if not self.db_url:
                 raise ValueError("Database URL is required for migrations")
             logger.info("Running database migrations...")
             run_migrations(self.db_url)
 
+            # Ensure embedding column dimension matches the model's dimension
+            # This is done after migrations and after embeddings.initialize()
+            ensure_embedding_dimension(self.db_url, self.embeddings.dimension)
+
         logger.info(f"Connecting to PostgreSQL at {self.db_url}")
 
         # Create connection pool
@@ -658,9 +751,9 @@ class MemoryEngine(MemoryEngineInterface):
             self.db_url,
             min_size=self._pool_min_size,
             max_size=self._pool_max_size,
-            command_timeout=
+            command_timeout=self._db_command_timeout,
             statement_cache_size=0,  # Disable prepared statement cache
-            timeout=
+            timeout=self._db_acquire_timeout,  # Connection acquisition timeout (seconds)
         )
 
         # Initialize entity resolver with pool
@@ -967,7 +1060,9 @@ class MemoryEngine(MemoryEngineInterface):
         document_id: str | None = None,
         fact_type_override: str | None = None,
         confidence_score: float | None = None,
-
+        document_tags: list[str] | None = None,
+        return_usage: bool = False,
+    ):
         """
         Store multiple content items as memory units in ONE batch operation.
 
@@ -988,9 +1083,11 @@ class MemoryEngine(MemoryEngineInterface):
                 Applies the same document_id to ALL content items that don't specify their own.
             fact_type_override: Override fact type for all facts ('world', 'experience', 'opinion')
             confidence_score: Confidence score for opinions (0.0 to 1.0)
+            return_usage: If True, returns tuple of (unit_ids, TokenUsage). Default False for backward compatibility.
 
         Returns:
-            List of lists of unit IDs (one list per content item)
+            If return_usage=False: List of lists of unit IDs (one list per content item)
+            If return_usage=True: Tuple of (unit_ids, TokenUsage)
 
         Example (new style - per-content document_id):
             unit_ids = await memory.retain_batch_async(
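
A usage sketch of the extended retain signature; the bank id, tag values, and the request_context argument are assumptions for illustration:

    from hindsight_api.engine.response_models import TokenUsage  # returned when return_usage=True

    async def retain_with_usage(memory, ctx) -> TokenUsage:
        unit_ids, usage = await memory.retain_batch_async(
            bank_id="support-bot",
            contents=[{"content": "Customer prefers email follow-ups."}],
            document_tags=["crm", "preferences"],   # applied to every item in the batch
            return_usage=True,                      # returns (unit_ids, TokenUsage)
            request_context=ctx,
        )
        print(f"stored {sum(len(ids) for ids in unit_ids)} units")
        return usage
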
@@ -1017,6 +1114,8 @@ class MemoryEngine(MemoryEngineInterface):
         start_time = time.time()
 
         if not contents:
+            if return_usage:
+                return [], TokenUsage()
             return []
 
         # Authenticate tenant and set schema in context (for fq_table())
@@ -1046,6 +1145,7 @@ class MemoryEngine(MemoryEngineInterface):
         # Auto-chunk large batches by character count to avoid timeouts and memory issues
         # Calculate total character count
         total_chars = sum(len(item.get("content", "")) for item in contents)
+        total_usage = TokenUsage()
 
         CHARS_PER_BATCH = 600_000
 
@@ -1086,15 +1186,17 @@ class MemoryEngine(MemoryEngineInterface):
                     f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars"
                 )
 
-                sub_results = await self._retain_batch_async_internal(
+                sub_results, sub_usage = await self._retain_batch_async_internal(
                     bank_id=bank_id,
                     contents=sub_batch,
                     document_id=document_id,
                     is_first_batch=i == 1,  # Only upsert on first batch
                     fact_type_override=fact_type_override,
                     confidence_score=confidence_score,
+                    document_tags=document_tags,
                 )
                 all_results.extend(sub_results)
+                total_usage = total_usage + sub_usage
 
             total_time = time.time() - start_time
             logger.info(
@@ -1103,13 +1205,14 @@ class MemoryEngine(MemoryEngineInterface):
             result = all_results
         else:
             # Small batch - use internal method directly
-            result = await self._retain_batch_async_internal(
+            result, total_usage = await self._retain_batch_async_internal(
                 bank_id=bank_id,
                 contents=contents,
                 document_id=document_id,
                 is_first_batch=True,
                 fact_type_override=fact_type_override,
                 confidence_score=confidence_score,
+                document_tags=document_tags,
             )
 
         # Call post-operation hook if validator is configured
@@ -1132,6 +1235,8 @@ class MemoryEngine(MemoryEngineInterface):
             except Exception as e:
                 logger.warning(f"Post-retain hook error (non-fatal): {e}")
 
+        if return_usage:
+            return result, total_usage
         return result
 
     async def _retain_batch_async_internal(
@@ -1142,7 +1247,8 @@ class MemoryEngine(MemoryEngineInterface):
         is_first_batch: bool = True,
         fact_type_override: str | None = None,
         confidence_score: float | None = None,
-
+        document_tags: list[str] | None = None,
+    ) -> tuple[list[list[str]], "TokenUsage"]:
         """
         Internal method for batch processing without chunking logic.
 
@@ -1158,6 +1264,10 @@ class MemoryEngine(MemoryEngineInterface):
             is_first_batch: Whether this is the first batch (for chunked operations, only delete on first batch)
             fact_type_override: Override fact type for all facts
             confidence_score: Confidence score for opinions
+            document_tags: Tags applied to all items in this batch
+
+        Returns:
+            Tuple of (unit ID lists, token usage for fact extraction)
         """
         # Backpressure: limit concurrent retains to prevent database contention
         async with self._put_semaphore:
@@ -1168,7 +1278,7 @@ class MemoryEngine(MemoryEngineInterface):
             return await orchestrator.retain_batch(
                 pool=pool,
                 embeddings_model=self.embeddings,
-                llm_config=self.
+                llm_config=self._retain_llm_config,
                 entity_resolver=self.entity_resolver,
                 task_backend=self._task_backend,
                 format_date_fn=self._format_readable_date,
@@ -1179,6 +1289,7 @@ class MemoryEngine(MemoryEngineInterface):
                 is_first_batch=is_first_batch,
                 fact_type_override=fact_type_override,
                 confidence_score=confidence_score,
+                document_tags=document_tags,
             )
 
     def recall(
@@ -1237,6 +1348,8 @@ class MemoryEngine(MemoryEngineInterface):
         include_chunks: bool = False,
         max_chunk_tokens: int = 8192,
         request_context: "RequestContext",
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
     ) -> RecallResultModel:
         """
         Recall memories using N*4-way parallel retrieval (N fact types × 4 retrieval methods).
@@ -1262,6 +1375,8 @@ class MemoryEngine(MemoryEngineInterface):
             max_entity_tokens: Maximum tokens for entity observations (default 500)
             include_chunks: Whether to include raw chunks in the response
             max_chunk_tokens: Maximum tokens for chunks (default 8192)
+            tags: Optional list of tags for visibility filtering (OR matching - returns
+                memories that have at least one matching tag)
 
         Returns:
             RecallResultModel containing:
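
A sketch of tag-filtered recall using the new parameters; the bank id, tag values, and field access on the returned facts are illustrative:

    async def recall_by_tags(memory, ctx):
        result = await memory.recall_async(
            bank_id="support-bot",
            query="What did the customer decide about the renewal?",
            tags=["crm", "2024-q3"],   # visibility filter
            tags_match="any",          # "any" = at least one tag must match (OR)
            request_context=ctx,
        )
        for fact in result.results:
            print(fact.text, fact.tags)
        return result
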
@@ -1313,7 +1428,9 @@ class MemoryEngine(MemoryEngineInterface):
         # Backpressure: limit concurrent recalls to prevent overwhelming the database
         result = None
         error_msg = None
+        semaphore_wait_start = time.time()
         async with self._search_semaphore:
+            semaphore_wait = time.time() - semaphore_wait_start
             # Retry loop for connection errors
             max_retries = 3
             for attempt in range(max_retries + 1):
@@ -1331,6 +1448,9 @@ class MemoryEngine(MemoryEngineInterface):
                         include_chunks,
                         max_chunk_tokens,
                         request_context,
+                        semaphore_wait=semaphore_wait,
+                        tags=tags,
+                        tags_match=tags_match,
                     )
                     break  # Success - exit retry loop
                 except Exception as e:
@@ -1448,6 +1568,9 @@ class MemoryEngine(MemoryEngineInterface):
         include_chunks: bool = False,
         max_chunk_tokens: int = 8192,
         request_context: "RequestContext" = None,
+        semaphore_wait: float = 0.0,
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
     ) -> RecallResultModel:
         """
         Search implementation with modular retrieval and reranking.
@@ -1477,7 +1600,9 @@ class MemoryEngine(MemoryEngineInterface):
         # Initialize tracer if requested
         from .search.tracer import SearchTracer
 
-        tracer =
+        tracer = (
+            SearchTracer(query, thinking_budget, max_tokens, tags=tags, tags_match=tags_match) if enable_trace else None
+        )
         if tracer:
             tracer.start()
 
@@ -1487,8 +1612,9 @@ class MemoryEngine(MemoryEngineInterface):
         # Buffer logs for clean output in concurrent scenarios
         recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
         log_buffer = []
+        tags_info = f", tags={tags}, tags_match={tags_match}" if tags else ""
         log_buffer.append(
-            f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})"
+            f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens}{tags_info})"
         )
 
         try:
@@ -1502,37 +1628,67 @@ class MemoryEngine(MemoryEngineInterface):
                 tracer.record_query_embedding(query_embedding)
                 tracer.add_phase_metric("generate_query_embedding", step_duration)
 
-            # Step 2:
+            # Step 2: Optimized parallel retrieval using batched queries
+            # - Semantic + BM25 combined in 1 CTE query for ALL fact types
+            # - Graph runs per fact type (complex traversal)
+            # - Temporal runs per fact type (if constraint detected)
             step_start = time.time()
             query_embedding_str = str(query_embedding)
 
-            from .search.retrieval import
+            from .search.retrieval import (
+                get_default_graph_retriever,
+                retrieve_all_fact_types_parallel,
+            )
 
             # Track each retrieval start time
             retrieval_start = time.time()
 
-            # Run retrieval
-
-
-
+            # Run optimized retrieval with connection budget
+            config = get_config()
+            async with budgeted_operation(
+                max_connections=config.recall_connection_budget,
+                operation_id=f"recall-{recall_id}",
+            ) as op:
+                budgeted_pool = op.wrap_pool(pool)
+                parallel_start = time.time()
+                multi_result = await retrieve_all_fact_types_parallel(
+                    budgeted_pool,
+                    query,
+                    query_embedding_str,
+                    bank_id,
+                    fact_type,  # Pass all fact types at once
+                    thinking_budget,
+                    question_date,
+                    self.query_analyzer,
+                    tags=tags,
+                    tags_match=tags_match,
                 )
-
-            ]
-            all_retrievals = await asyncio.gather(*retrieval_tasks)
+                parallel_duration = time.time() - parallel_start
 
             # Combine all results from all fact types and aggregate timings
             semantic_results = []
             bm25_results = []
             graph_results = []
             temporal_results = []
-            aggregated_timings = {
+            aggregated_timings = {
+                "semantic": 0.0,
+                "bm25": 0.0,
+                "graph": 0.0,
+                "temporal": 0.0,
+                "temporal_extraction": 0.0,
+            }
+            all_mpfp_timings = []
 
             detected_temporal_constraint = None
-
+            max_conn_wait = multi_result.max_conn_wait
+            for ft in fact_type:
+                retrieval_result = multi_result.results_by_fact_type.get(ft)
+                if not retrieval_result:
+                    continue
+
                 # Log fact types in this retrieval batch
-                ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
                 logger.debug(
-                    f"[RECALL {recall_id}] Fact type '{
+                    f"[RECALL {recall_id}] Fact type '{ft}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}"
                 )
 
                 semantic_results.extend(retrieval_result.semantic)
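
The connection-budget pattern introduced above, shown in isolation. Only budgeted_operation, wrap_pool, and the config knob appear in this diff; the wrapped pool exposing the usual acquire() interface is an assumption:

    from hindsight_api.engine.db_budget import budgeted_operation

    async def run_with_budget(pool):
        # Cap how many pool connections a single recall-style operation may hold at once.
        async with budgeted_operation(max_connections=8, operation_id="recall-example") as op:
            budgeted_pool = op.wrap_pool(pool)
            async with budgeted_pool.acquire() as conn:   # acquisition counts against the budget
                return await conn.fetchval("SELECT 1")
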
@@ -1546,6 +1702,8 @@ class MemoryEngine(MemoryEngineInterface):
                 # Capture temporal constraint (same across all fact types)
                 if retrieval_result.temporal_constraint:
                     detected_temporal_constraint = retrieval_result.temporal_constraint
+                # Collect MPFP timings
+                all_mpfp_timings.extend(retrieval_result.mpfp_timings)
 
             # If no temporal results from any fact type, set to None
             if not temporal_results:
@@ -1564,12 +1722,12 @@ class MemoryEngine(MemoryEngineInterface):
             retrieval_duration = time.time() - retrieval_start
 
             step_duration = time.time() - step_start
-
-            # Format per-method timings
+            # Format per-method timings (these are the actual parallel retrieval times)
             timing_parts = [
                 f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
                 f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
                 f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)",
+                f"temporal_extraction={aggregated_timings['temporal_extraction']:.3f}s",
             ]
             temporal_info = ""
             if detected_temporal_constraint:
@@ -1578,9 +1736,41 @@ class MemoryEngine(MemoryEngineInterface):
                 timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
                 temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
             log_buffer.append(
-                f" [2]
+                f" [2] Parallel retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {parallel_duration:.3f}s{temporal_info}"
             )
 
+            # Log graph retriever timing breakdown if available
+            if all_mpfp_timings:
+                retriever_name = get_default_graph_retriever().name.upper()
+                mpfp_total = all_mpfp_timings[0]  # Take first fact type's timing as representative
+                mpfp_parts = [
+                    f"db_queries={mpfp_total.db_queries}",
+                    f"edge_load={mpfp_total.edge_load_time:.3f}s",
+                    f"edges={mpfp_total.edge_count}",
+                    f"patterns={mpfp_total.pattern_count}",
+                ]
+                if mpfp_total.seeds_time > 0.01:
+                    mpfp_parts.append(f"seeds={mpfp_total.seeds_time:.3f}s")
+                if mpfp_total.fusion > 0.001:
+                    mpfp_parts.append(f"fusion={mpfp_total.fusion:.3f}s")
+                if mpfp_total.fetch > 0.001:
+                    mpfp_parts.append(f"fetch={mpfp_total.fetch:.3f}s")
+                log_buffer.append(f" [{retriever_name}] {', '.join(mpfp_parts)}")
+                # Log detailed hop timing for debugging slow queries
+                if mpfp_total.hop_details:
+                    for hd in mpfp_total.hop_details:
+                        log_buffer.append(
+                            f" hop{hd['hop']}: exec={hd.get('exec_time', 0) * 1000:.0f}ms, "
+                            f"uncached={hd.get('uncached_after_filter', 0)}, "
+                            f"load={hd.get('load_time', 0) * 1000:.0f}ms, "
+                            f"edges={hd.get('edges_loaded', 0)}"
+                        )
+
+            # Record temporal constraint in tracer if detected
+            if tracer and detected_temporal_constraint:
+                start_dt, end_dt = detected_temporal_constraint
+                tracer.record_temporal_constraint(start_dt, end_dt)
+
             # Record retrieval results for tracer - per fact type
             if tracer:
                 # Convert RetrievalResult to old tuple format for tracer
@@ -1588,8 +1778,10 @@ class MemoryEngine(MemoryEngineInterface):
                     return [(r.id, r.__dict__) for r in results]
 
                 # Add retrieval results per fact type (to show parallel execution in UI)
-                for
-
+                for ft_name in fact_type:
+                    rr = multi_result.results_by_fact_type.get(ft_name)
+                    if not rr:
+                        continue
 
                     # Add semantic retrieval results for this fact type
                     tracer.add_retrieval_results(
@@ -1621,14 +1813,22 @@ class MemoryEngine(MemoryEngineInterface):
                         fact_type=ft_name,
                     )
 
-                    # Add temporal retrieval results for this fact type
-
+                    # Add temporal retrieval results for this fact type
+                    # Show temporal even with 0 results if constraint was detected
+                    if rr.temporal is not None or rr.temporal_constraint is not None:
+                        temporal_metadata = {"budget": thinking_budget}
+                        if rr.temporal_constraint:
+                            start_dt, end_dt = rr.temporal_constraint
+                            temporal_metadata["constraint"] = {
+                                "start": start_dt.isoformat() if start_dt else None,
+                                "end": end_dt.isoformat() if end_dt else None,
+                            }
                         tracer.add_retrieval_results(
                             method_name="temporal",
-                            results=to_tuple_format(rr.temporal),
+                            results=to_tuple_format(rr.temporal or []),
                             duration_seconds=rr.timings.get("temporal", 0.0),
                             score_field="temporal_score",
-                            metadata=
+                            metadata=temporal_metadata,
                             fact_type=ft_name,
                         )
 
@@ -1678,11 +1878,24 @@ class MemoryEngine(MemoryEngineInterface):
             # Ensure reranker is initialized (for lazy initialization mode)
             await reranker_instance.ensure_initialized()
 
+            # Pre-filter candidates to reduce reranking cost (RRF already provides good ranking)
+            # This is especially important for remote rerankers with network latency
+            reranker_max_candidates = get_config().reranker_max_candidates
+            pre_filtered_count = 0
+            if len(merged_candidates) > reranker_max_candidates:
+                # Sort by RRF score and take top candidates
+                merged_candidates.sort(key=lambda mc: mc.rrf_score, reverse=True)
+                pre_filtered_count = len(merged_candidates) - reranker_max_candidates
+                merged_candidates = merged_candidates[:reranker_max_candidates]
+
             # Rerank using cross-encoder
-            scored_results = reranker_instance.rerank(query, merged_candidates)
+            scored_results = await reranker_instance.rerank(query, merged_candidates)
 
             step_duration = time.time() - step_start
-
+            pre_filter_note = f" (pre-filtered {pre_filtered_count})" if pre_filtered_count > 0 else ""
+            log_buffer.append(
+                f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s{pre_filter_note}"
+            )
 
             # Step 4.5: Combine cross-encoder score with retrieval signals
             # This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
@@ -1732,9 +1945,6 @@ class MemoryEngine(MemoryEngineInterface):
 
             # Re-sort by combined score
             scored_results.sort(key=lambda x: x.weight, reverse=True)
-            log_buffer.append(
-                "  [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)"
-            )
 
             # Add reranked results to tracer AFTER combined scoring (so normalized values are included)
             if tracer:
@@ -1753,7 +1963,6 @@ class MemoryEngine(MemoryEngineInterface):
             # Step 5: Truncate to thinking_budget * 2 for token filtering
             rerank_limit = thinking_budget * 2
             top_scored = scored_results[:rerank_limit]
-            log_buffer.append(f"  [5] Truncated to top {len(top_scored)} results")
 
             # Step 6: Token budget filtering
             step_start = time.time()
@@ -1768,7 +1977,7 @@ class MemoryEngine(MemoryEngineInterface):
 
             step_duration = time.time() - step_start
             log_buffer.append(
-                f" [
+                f" [5] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s"
             )
 
             if tracer:
@@ -1802,7 +2011,6 @@ class MemoryEngine(MemoryEngineInterface):
             visited_ids = list(set([sr.id for sr in scored_results[:50]]))  # Top 50
             if visited_ids:
                 await self._task_backend.submit_task({"type": "access_count_update", "node_ids": visited_ids})
-                log_buffer.append(f"  [7] Queued access count updates for {len(visited_ids)} nodes")
 
             # Log fact_type distribution in results
             fact_type_counts = {}
@@ -1835,6 +2043,7 @@ class MemoryEngine(MemoryEngineInterface):
                 top_results_dicts.append(result_dict)
 
             # Get entities for each fact if include_entities is requested
+            step_start = time.time()
             fact_entity_map = {}  # unit_id -> list of (entity_id, entity_name)
             if include_entities and top_scored:
                 unit_ids = [uuid.UUID(sr.id) for sr in top_scored]
@@ -1856,6 +2065,7 @@ class MemoryEngine(MemoryEngineInterface):
                     fact_entity_map[unit_id].append(
                         {"entity_id": str(row["entity_id"]), "canonical_name": row["canonical_name"]}
                     )
+            entity_map_duration = time.time() - step_start
 
             # Convert results to MemoryFact objects
             memory_facts = []
@@ -1878,10 +2088,12 @@ class MemoryEngine(MemoryEngineInterface):
                         mentioned_at=result_dict.get("mentioned_at"),
                         document_id=result_dict.get("document_id"),
                         chunk_id=result_dict.get("chunk_id"),
+                        tags=result_dict.get("tags"),
                     )
                 )
 
             # Fetch entity observations if requested
+            step_start = time.time()
             entities_dict = None
             total_entity_tokens = 0
             total_chunk_tokens = 0
@@ -1902,7 +2114,13 @@ class MemoryEngine(MemoryEngineInterface):
                         entities_ordered.append((entity_id, entity_name))
                         seen_entity_ids.add(entity_id)
 
-                # Fetch observations
+                # Fetch all observations in a single batched query
+                entity_ids = [eid for eid, _ in entities_ordered]
+                all_observations = await self.get_entity_observations_batch(
+                    bank_id, entity_ids, limit_per_entity=5, request_context=request_context
+                )
+
+                # Build entities_dict respecting token budget, in relevance order
                 entities_dict = {}
                 encoding = _get_tiktoken_encoding()
 
@@ -1910,9 +2128,7 @@ class MemoryEngine(MemoryEngineInterface):
                     if total_entity_tokens >= max_entity_tokens:
                         break
 
-                    observations =
-                        bank_id, entity_id, limit=5, request_context=request_context
-                    )
+                    observations = all_observations.get(entity_id, [])
 
                     # Calculate tokens for this entity's observations
                     entity_tokens = 0
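
Per-entity observation fetches collapse into one batched call here; the method itself is added near the end of this diff. A usage sketch, with placeholder entity ids:

    async def print_observations(memory, ctx):
        observations_by_entity = await memory.get_entity_observations_batch(
            bank_id="support-bot",
            entity_ids=["3f2c...", "9a1b..."],   # entity UUIDs, placeholders here
            limit_per_entity=5,
            request_context=ctx,
        )
        for entity_id, observations in observations_by_entity.items():
            print(entity_id, [obs.text for obs in observations])
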
@@ -1930,8 +2146,10 @@ class MemoryEngine(MemoryEngineInterface):
                             entity_id=entity_id, canonical_name=entity_name, observations=included_observations
                         )
                         total_entity_tokens += entity_tokens
+            entity_obs_duration = time.time() - step_start
 
             # Fetch chunks if requested
+            step_start = time.time()
             chunks_dict = None
             if include_chunks and top_scored:
                 from .response_models import ChunkInfo
@@ -1991,6 +2209,12 @@ class MemoryEngine(MemoryEngineInterface):
                                 chunk_text=chunk_text, chunk_index=row["chunk_index"], truncated=False
                             )
                             total_chunk_tokens += chunk_tokens
+            chunks_duration = time.time() - step_start
+
+            # Log entity/chunk fetch timing (only if any enrichment was requested)
+            log_buffer.append(
+                f" [6] Response enrichment: entity_map={entity_map_duration:.3f}s, entity_obs={entity_obs_duration:.3f}s, chunks={chunks_duration:.3f}s"
+            )
 
             # Finalize trace if enabled
             trace_dict = None
@@ -2002,8 +2226,15 @@ class MemoryEngine(MemoryEngineInterface):
             total_time = time.time() - recall_start
             num_chunks = len(chunks_dict) if chunks_dict else 0
             num_entities = len(entities_dict) if entities_dict else 0
+            # Include wait times in log if significant
+            wait_parts = []
+            if semaphore_wait > 0.01:
+                wait_parts.append(f"sem={semaphore_wait:.3f}s")
+            if max_conn_wait > 0.01:
+                wait_parts.append(f"conn={max_conn_wait:.3f}s")
+            wait_info = f" | waits: {', '.join(wait_parts)}" if wait_parts else ""
             log_buffer.append(
-                f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s"
+                f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s{wait_info}"
             )
             logger.info("\n" + "\n".join(log_buffer))
 
@@ -2073,11 +2304,11 @@ class MemoryEngine(MemoryEngineInterface):
             doc = await conn.fetchrow(
                 f"""
                 SELECT d.id, d.bank_id, d.original_text, d.content_hash,
-                       d.created_at, d.updated_at, COUNT(mu.id) as unit_count
+                       d.created_at, d.updated_at, d.tags, COUNT(mu.id) as unit_count
                 FROM {fq_table("documents")} d
                 LEFT JOIN {fq_table("memory_units")} mu ON mu.document_id = d.id
                 WHERE d.id = $1 AND d.bank_id = $2
-                GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
+                GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at, d.tags
                 """,
                 document_id,
                 bank_id,
@@ -2094,6 +2325,7 @@ class MemoryEngine(MemoryEngineInterface):
             "memory_unit_count": doc["unit_count"],
             "created_at": doc["created_at"].isoformat() if doc["created_at"] else None,
             "updated_at": doc["updated_at"].isoformat() if doc["updated_at"] else None,
+            "tags": list(doc["tags"]) if doc["tags"] else [],
         }
 
     async def delete_document(
@@ -2199,9 +2431,10 @@ class MemoryEngine(MemoryEngineInterface):
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
-            # Ensure connection is not in read-only mode (can happen with connection poolers)
-            await conn.execute("SET SESSION CHARACTERISTICS AS TRANSACTION READ WRITE")
             async with conn.transaction():
+                # Ensure transaction is not in read-only mode (can happen with connection poolers)
+                # Using SET LOCAL so it only affects this transaction, not the session
+                await conn.execute("SET LOCAL transaction_read_only TO off")
                 try:
                     if fact_type:
                         # Delete only memories of a specific fact type
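
Why SET LOCAL: SET SESSION CHARACTERISTICS outlives the statement and can hand a pooled connection back with altered session state, while SET LOCAL is reverted automatically at COMMIT or ROLLBACK. A minimal sketch of the pattern (the table and parameter are illustrative):

    async def delete_bank_memories(pool, bank_id: str):
        async with pool.acquire() as conn:
            async with conn.transaction():
                # Reverted when the transaction ends, so the pooled connection
                # keeps its original session settings.
                await conn.execute("SET LOCAL transaction_read_only TO off")
                await conn.execute("DELETE FROM memory_units WHERE bank_id = $1", bank_id)
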
@@ -2258,6 +2491,7 @@ class MemoryEngine(MemoryEngineInterface):
         bank_id: str | None = None,
         fact_type: str | None = None,
         *,
+        limit: int = 1000,
         request_context: "RequestContext",
     ):
         """
@@ -2266,10 +2500,11 @@ class MemoryEngine(MemoryEngineInterface):
         Args:
             bank_id: Filter by bank ID
             fact_type: Filter by fact type (world, experience, opinion)
+            limit: Maximum number of items to return (default: 1000)
             request_context: Request context for authentication.
 
         Returns:
-            Dict with nodes, edges, and
+            Dict with nodes, edges, table_rows, total_units, and limit
         """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
@@ -2291,15 +2526,29 @@ class MemoryEngine(MemoryEngineInterface):
 
             where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
 
+            # Get total count first
+            total_count_result = await conn.fetchrow(
+                f"""
+                SELECT COUNT(*) as total
+                FROM {fq_table("memory_units")}
+                {where_clause}
+                """,
+                *query_params,
+            )
+            total_count = total_count_result["total"] if total_count_result else 0
+
+            # Get units with limit
+            param_count += 1
             units = await conn.fetch(
                 f"""
                 SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
                 FROM {fq_table("memory_units")}
                 {where_clause}
                 ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
-                LIMIT
+                LIMIT ${param_count}
                 """,
                 *query_params,
+                limit,
             )
 
             # Get links, filtering to only include links between units of the selected agent
@@ -2436,7 +2685,7 @@ class MemoryEngine(MemoryEngineInterface):
                 }
             )
 
-        return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units":
+        return {"nodes": nodes, "edges": edges, "table_rows": table_rows, "total_units": total_count, "limit": limit}
 
     async def list_memory_units(
         self,
@@ -2565,6 +2814,68 @@ class MemoryEngine(MemoryEngineInterface):
 
         return {"items": items, "total": total, "limit": limit, "offset": offset}
 
+    async def get_memory_unit(
+        self,
+        bank_id: str,
+        memory_id: str,
+        request_context: "RequestContext",
+    ):
+        """
+        Get a single memory unit by ID.
+
+        Args:
+            bank_id: Bank ID
+            memory_id: Memory unit ID
+            request_context: Request context for authentication.
+
+        Returns:
+            Dict with memory unit data or None if not found
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+        async with acquire_with_retry(pool) as conn:
+            # Get the memory unit
+            row = await conn.fetchrow(
+                f"""
+                SELECT id, text, context, event_date, occurred_start, occurred_end,
+                       mentioned_at, fact_type, document_id, chunk_id, tags
+                FROM {fq_table("memory_units")}
+                WHERE id = $1 AND bank_id = $2
+                """,
+                memory_id,
+                bank_id,
+            )
+
+            if not row:
+                return None
+
+            # Get entity information
+            entities_rows = await conn.fetch(
+                f"""
+                SELECT e.canonical_name
+                FROM {fq_table("unit_entities")} ue
+                JOIN {fq_table("entities")} e ON ue.entity_id = e.id
+                WHERE ue.unit_id = $1
+                """,
+                row["id"],
+            )
+            entities = [r["canonical_name"] for r in entities_rows]
+
+            return {
+                "id": str(row["id"]),
+                "text": row["text"],
+                "context": row["context"] if row["context"] else "",
+                "date": row["event_date"].isoformat() if row["event_date"] else "",
+                "type": row["fact_type"],
+                "mentioned_at": row["mentioned_at"].isoformat() if row["mentioned_at"] else None,
+                "occurred_start": row["occurred_start"].isoformat() if row["occurred_start"] else None,
+                "occurred_end": row["occurred_end"].isoformat() if row["occurred_end"] else None,
+                "entities": entities,
+                "document_id": row["document_id"] if row["document_id"] else None,
+                "chunk_id": str(row["chunk_id"]) if row["chunk_id"] else None,
+                "tags": row["tags"] if row["tags"] else [],
+            }
+
     async def list_documents(
         self,
         bank_id: str,
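
A usage sketch for the new single-unit lookup added above; the receiver name, bank id, and UUID are placeholders:

    async def show_unit(memory, ctx):
        unit = await memory.get_memory_unit(
            bank_id="support-bot",
            memory_id="1b6d1c2e-0000-0000-0000-000000000000",   # placeholder UUID
            request_context=ctx,
        )
        if unit is None:
            print("not found")
        else:
            print(unit["text"], unit["type"], unit["tags"], unit["entities"])
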
@@ -2799,7 +3110,7 @@ Guidelines:
 - Small changes in confidence are normal; large jumps should be rare"""
 
         try:
-            result = await self.
+            result = await self._reflect_llm_config.call(
                 messages=[
                     {"role": "system", "content": "You evaluate and update opinions based on new information."},
                     {"role": "user", "content": evaluation_prompt},
@@ -2909,7 +3220,7 @@ Guidelines:
             return
 
         # Use cached LLM config
-        if self.
+        if self._reflect_llm_config is None:
             logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
             return
 
@@ -3054,7 +3365,9 @@ Guidelines:
         """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
-        return await bank_utils.merge_bank_background(
+        return await bank_utils.merge_bank_background(
+            pool, self._reflect_llm_config, bank_id, new_info, update_disposition
+        )
 
     async def list_banks(
         self,
@@ -3086,6 +3399,8 @@ Guidelines:
         max_tokens: int = 4096,
         response_schema: dict | None = None,
         request_context: "RequestContext",
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
     ) -> ReflectResult:
         """
         Reflect and formulate an answer using bank identity, world facts, and opinions.
@@ -3114,7 +3429,7 @@ Guidelines:
         - structured_output: Optional dict if response_schema was provided
         """
         # Use cached LLM config
-        if self.
+        if self._reflect_llm_config is None:
             raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")
 
         # Authenticate tenant and set schema in context (for fq_table())
@@ -3140,16 +3455,22 @@ Guidelines:
 
         # Steps 1-3: Run multi-fact-type search (12-way retrieval: 4 methods × 3 fact types)
         recall_start = time.time()
-
-
-
-
-
-
-
-
-
-
+        metrics = get_metrics_collector()
+        with metrics.record_operation(
+            "recall", bank_id=bank_id, source="reflect", budget=budget.value if budget else None
+        ):
+            search_result = await self.recall_async(
+                bank_id=bank_id,
+                query=query,
+                budget=budget,
+                max_tokens=4096,
+                enable_trace=False,
+                fact_type=["experience", "world", "opinion"],
+                include_entities=True,
+                request_context=request_context,
+                tags=tags,
+                tags_match=tags_match,
+            )
         recall_time = time.time() - recall_start
 
         all_results = search_result.results
|
|
|
3205
3526
|
response_format = JsonSchemaWrapper(response_schema)
|
|
3206
3527
|
|
|
3207
3528
|
llm_start = time.time()
|
|
3208
|
-
|
|
3529
|
+
llm_result, usage = await self._reflect_llm_config.call(
|
|
3209
3530
|
messages=messages,
|
|
3210
3531
|
scope="memory_reflect",
|
|
3211
3532
|
max_completion_tokens=max_tokens,
|
|
@@ -3214,17 +3535,18 @@ Guidelines:
             # Don't enforce strict_schema - not all providers support it and may retry forever
             # Soft enforcement (schema in prompt + json_object mode) is sufficient
             strict_schema=False,
+            return_usage=True,
         )
         llm_time = time.time() - llm_start

         # Handle response based on whether structured output was requested
         if response_schema is not None:
-            structured_output =
+            structured_output = llm_result
             answer_text = ""  # Empty for backward compatibility
             log_buffer.append(f"[REFLECT {reflect_id}] Structured output generated")
         else:
             structured_output = None
-            answer_text =
+            answer_text = llm_result.strip()

         # Submit form_opinion task for background processing
         # Pass tenant_id from request context for internal authentication in background task
@@ -3250,6 +3572,7 @@ Guidelines:
             based_on={"world": world_results, "experience": agent_results, "opinion": opinion_results},
             new_opinions=[],  # Opinions are being extracted asynchronously
             structured_output=structured_output,
+            usage=usage,
         )

         # Call post-operation hook if validator is configured
@@ -3289,7 +3612,9 @@ Guidelines:
         """
         try:
             # Extract opinions from the answer
-            new_opinions = await think_utils.extract_opinions_from_text(
+            new_opinions = await think_utils.extract_opinions_from_text(
+                self._reflect_llm_config, text=answer_text, query=query
+            )

             # Store new opinions
             if new_opinions:
@@ -3360,37 +3685,110 @@ Guidelines:
             observations.append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
         return observations

+    async def get_entity_observations_batch(
+        self,
+        bank_id: str,
+        entity_ids: list[str],
+        *,
+        limit_per_entity: int = 5,
+        request_context: "RequestContext",
+    ) -> dict[str, list[Any]]:
+        """
+        Get observations for multiple entities in a single query.
+
+        Args:
+            bank_id: bank IDentifier
+            entity_ids: List of entity UUIDs to get observations for
+            limit_per_entity: Maximum observations per entity
+            request_context: Request context for authentication.
+
+        Returns:
+            Dict mapping entity_id -> list of EntityObservation objects
+        """
+        if not entity_ids:
+            return {}
+
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+        async with acquire_with_retry(pool) as conn:
+            # Use window function to limit observations per entity
+            rows = await conn.fetch(
+                f"""
+                WITH ranked AS (
+                    SELECT
+                        ue.entity_id,
+                        mu.text,
+                        mu.mentioned_at,
+                        ROW_NUMBER() OVER (PARTITION BY ue.entity_id ORDER BY mu.mentioned_at DESC) as rn
+                    FROM {fq_table("memory_units")} mu
+                    JOIN {fq_table("unit_entities")} ue ON mu.id = ue.unit_id
+                    WHERE mu.bank_id = $1
+                      AND mu.fact_type = 'observation'
+                      AND ue.entity_id = ANY($2::uuid[])
+                )
+                SELECT entity_id, text, mentioned_at
+                FROM ranked
+                WHERE rn <= $3
+                ORDER BY entity_id, rn
+                """,
+                bank_id,
+                [uuid.UUID(eid) for eid in entity_ids],
+                limit_per_entity,
+            )
+
+            result: dict[str, list[Any]] = {eid: [] for eid in entity_ids}
+            for row in rows:
+                entity_id = str(row["entity_id"])
+                mentioned_at = row["mentioned_at"].isoformat() if row["mentioned_at"] else None
+                result[entity_id].append(EntityObservation(text=row["text"], mentioned_at=mentioned_at))
+            return result
+
     async def list_entities(
         self,
         bank_id: str,
         *,
         limit: int = 100,
+        offset: int = 0,
         request_context: "RequestContext",
-    ) ->
+    ) -> dict[str, Any]:
         """
-        List all entities for a bank.
+        List all entities for a bank with pagination.

         Args:
             bank_id: bank IDentifier
             limit: Maximum number of entities to return
+            offset: Offset for pagination
             request_context: Request context for authentication.

         Returns:
-
+            Dict with items, total, limit, offset
         """
         await self._authenticate_tenant(request_context)
         pool = await self._get_pool()
         async with acquire_with_retry(pool) as conn:
+            # Get total count
+            total_row = await conn.fetchrow(
+                f"""
+                SELECT COUNT(*) as total
+                FROM {fq_table("entities")}
+                WHERE bank_id = $1
+                """,
+                bank_id,
+            )
+            total = total_row["total"] if total_row else 0
+
+            # Get paginated entities
             rows = await conn.fetch(
                 f"""
                 SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
                 FROM {fq_table("entities")}
                 WHERE bank_id = $1
                 ORDER BY mention_count DESC, last_seen DESC
-                LIMIT $2
+                LIMIT $2 OFFSET $3
                 """,
                 bank_id,
                 limit,
+                offset,
             )

             entities = []
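The new `get_entity_observations_batch` replaces one query per entity with a single query that ranks rows per entity via `ROW_NUMBER() OVER (PARTITION BY ue.entity_id ORDER BY mu.mentioned_at DESC)`. The sketch below is a rough Python equivalent of that per-entity limiting, purely to illustrate the window function's effect; the real implementation does the work in one SQL round trip.

```python
from collections import defaultdict

def take_latest_per_entity(rows, limit_per_entity=5):
    """Group rows by entity_id and keep the most recent `limit_per_entity` of each.

    `rows` is an iterable of dicts with 'entity_id', 'text', 'mentioned_at';
    assumes 'mentioned_at' values are mutually comparable (e.g. all datetimes).
    """
    grouped = defaultdict(list)
    for row in rows:
        grouped[row["entity_id"]].append(row)
    return {
        entity_id: sorted(items, key=lambda r: r["mentioned_at"], reverse=True)[:limit_per_entity]
        for entity_id, items in grouped.items()
    }
```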
@@ -3417,7 +3815,91 @@ Guidelines:
                     "metadata": metadata,
                 }
             )
-            return
+            return {
+                "items": entities,
+                "total": total,
+                "limit": limit,
+                "offset": offset,
+            }
+
+    async def list_tags(
+        self,
+        bank_id: str,
+        *,
+        pattern: str | None = None,
+        limit: int = 100,
+        offset: int = 0,
+        request_context: "RequestContext",
+    ) -> dict[str, Any]:
+        """
+        List all unique tags for a bank with usage counts.
+
+        Use this to discover available tags or expand wildcard patterns.
+        Supports '*' as wildcard for flexible matching (case-insensitive):
+        - 'user:*' matches user:alice, user:bob
+        - '*-admin' matches role-admin, super-admin
+        - 'env*-prod' matches env-prod, environment-prod
+
+        Args:
+            bank_id: Bank identifier
+            pattern: Wildcard pattern to filter tags (use '*' as wildcard, case-insensitive)
+            limit: Maximum number of tags to return
+            offset: Offset for pagination
+            request_context: Request context for authentication.
+
+        Returns:
+            Dict with items (list of {tag, count}), total, limit, offset
+        """
+        await self._authenticate_tenant(request_context)
+        pool = await self._get_pool()
+        async with acquire_with_retry(pool) as conn:
+            # Build pattern filter if provided (convert * to % for ILIKE)
+            pattern_clause = ""
+            params: list[Any] = [bank_id]
+            if pattern:
+                # Convert wildcard pattern: * -> % for SQL ILIKE
+                sql_pattern = pattern.replace("*", "%")
+                pattern_clause = "AND tag ILIKE $2"
+                params.append(sql_pattern)
+
+            # Get total count of distinct tags matching pattern
+            total_row = await conn.fetchrow(
+                f"""
+                SELECT COUNT(DISTINCT tag) as total
+                FROM {fq_table("memory_units")}, unnest(tags) AS tag
+                WHERE bank_id = $1 AND tags IS NOT NULL AND tags != '{{}}'
+                {pattern_clause}
+                """,
+                *params,
+            )
+            total = total_row["total"] if total_row else 0
+
+            # Get paginated tags with counts, ordered by frequency
+            limit_param = len(params) + 1
+            offset_param = len(params) + 2
+            params.extend([limit, offset])
+
+            rows = await conn.fetch(
+                f"""
+                SELECT tag, COUNT(*) as count
+                FROM {fq_table("memory_units")}, unnest(tags) AS tag
+                WHERE bank_id = $1 AND tags IS NOT NULL AND tags != '{{}}'
+                {pattern_clause}
+                GROUP BY tag
+                ORDER BY count DESC, tag ASC
+                LIMIT ${limit_param} OFFSET ${offset_param}
+                """,
+                *params,
+            )
+
+            items = [{"tag": row["tag"], "count": row["count"]} for row in rows]
+
+            return {
+                "items": items,
+                "total": total,
+                "limit": limit,
+                "offset": offset,
+            }

     async def get_entity_state(
         self,
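`list_tags` expands `*` wildcards by rewriting them to SQL `ILIKE` patterns (`*` becomes `%`), so matching is case-insensitive. The following illustration of that translation and of the matches promised by the docstring is not from the package; `fnmatch` is used only to emulate the behaviour locally.

```python
from fnmatch import fnmatch

def to_ilike(pattern: str) -> str:
    # Same rewrite as in the hunk above: '*' becomes the SQL wildcard '%'
    return pattern.replace("*", "%")

assert to_ilike("user:*") == "user:%"
assert fnmatch("user:alice", "user:*")           # matched by ILIKE 'user:%'
assert fnmatch("role-admin", "*-admin")          # matched by ILIKE '%-admin'
assert fnmatch("environment-prod", "env*-prod")  # matched by ILIKE 'env%-prod'
```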
@@ -3540,7 +4022,9 @@ Guidelines:
         )

         # Step 3: Extract observations using LLM (no personality)
-        observations = await observation_utils.extract_observations_from_facts(
+        observations = await observation_utils.extract_observations_from_facts(
+            self._reflect_llm_config, entity_name, facts
+        )

         if not observations:
             return []
@@ -4061,6 +4545,7 @@ Guidelines:
         contents: list[dict[str, Any]],
         *,
         request_context: "RequestContext",
+        document_tags: list[str] | None = None,
     ) -> dict[str, Any]:
         """Submit a batch retain operation to run asynchronously."""
         await self._authenticate_tenant(request_context)
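With the new `document_tags` keyword, the batch-retain task payload (built in the next hunk) carries the tags only when they are supplied. A sketch of the resulting payload shape; every field value here is an invented example, and the structure of the `contents` items is an assumption.

```python
payload = {
    "type": "batch_retain",
    "operation_id": "9f1c2d3e-0000-4000-8000-000000000000",  # example UUID, not real
    "bank_id": "support-bot",
    "contents": [{"content": "Alice prefers email follow-ups."}],  # item shape assumed
    "document_tags": ["user:alice", "source:ticket-1234"],  # omitted entirely when not provided
}
```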
@@ -4084,14 +4569,16 @@ Guidelines:
         )

         # Submit task to background queue
-
-
-
-
-
-
-
-
+        task_payload = {
+            "type": "batch_retain",
+            "operation_id": str(operation_id),
+            "bank_id": bank_id,
+            "contents": contents,
+        }
+        if document_tags:
+            task_payload["document_tags"] = document_tags
+
+        await self._task_backend.submit_task(task_payload)

         logger.info(f"Retain task queued for bank_id={bank_id}, {len(contents)} items, operation_id={operation_id}")
