hindsight-api 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +38 -0
- hindsight_api/api/__init__.py +105 -0
- hindsight_api/api/http.py +1872 -0
- hindsight_api/api/mcp.py +157 -0
- hindsight_api/engine/__init__.py +47 -0
- hindsight_api/engine/cross_encoder.py +97 -0
- hindsight_api/engine/db_utils.py +93 -0
- hindsight_api/engine/embeddings.py +113 -0
- hindsight_api/engine/entity_resolver.py +575 -0
- hindsight_api/engine/llm_wrapper.py +269 -0
- hindsight_api/engine/memory_engine.py +3095 -0
- hindsight_api/engine/query_analyzer.py +519 -0
- hindsight_api/engine/response_models.py +222 -0
- hindsight_api/engine/retain/__init__.py +50 -0
- hindsight_api/engine/retain/bank_utils.py +423 -0
- hindsight_api/engine/retain/chunk_storage.py +82 -0
- hindsight_api/engine/retain/deduplication.py +104 -0
- hindsight_api/engine/retain/embedding_processing.py +62 -0
- hindsight_api/engine/retain/embedding_utils.py +54 -0
- hindsight_api/engine/retain/entity_processing.py +90 -0
- hindsight_api/engine/retain/fact_extraction.py +1027 -0
- hindsight_api/engine/retain/fact_storage.py +176 -0
- hindsight_api/engine/retain/link_creation.py +121 -0
- hindsight_api/engine/retain/link_utils.py +651 -0
- hindsight_api/engine/retain/orchestrator.py +405 -0
- hindsight_api/engine/retain/types.py +206 -0
- hindsight_api/engine/search/__init__.py +15 -0
- hindsight_api/engine/search/fusion.py +122 -0
- hindsight_api/engine/search/observation_utils.py +132 -0
- hindsight_api/engine/search/reranking.py +103 -0
- hindsight_api/engine/search/retrieval.py +503 -0
- hindsight_api/engine/search/scoring.py +161 -0
- hindsight_api/engine/search/temporal_extraction.py +64 -0
- hindsight_api/engine/search/think_utils.py +255 -0
- hindsight_api/engine/search/trace.py +215 -0
- hindsight_api/engine/search/tracer.py +447 -0
- hindsight_api/engine/search/types.py +160 -0
- hindsight_api/engine/task_backend.py +223 -0
- hindsight_api/engine/utils.py +203 -0
- hindsight_api/metrics.py +227 -0
- hindsight_api/migrations.py +163 -0
- hindsight_api/models.py +309 -0
- hindsight_api/pg0.py +425 -0
- hindsight_api/web/__init__.py +12 -0
- hindsight_api/web/server.py +143 -0
- hindsight_api-0.0.13.dist-info/METADATA +41 -0
- hindsight_api-0.0.13.dist-info/RECORD +48 -0
- hindsight_api-0.0.13.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,3095 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Memory Engine for Memory Banks.
|
|
3
|
+
|
|
4
|
+
This implements a sophisticated memory architecture that combines:
|
|
5
|
+
1. Temporal links: Memories connected by time proximity
|
|
6
|
+
2. Semantic links: Memories connected by meaning/similarity
|
|
7
|
+
3. Entity links: Memories connected by shared entities (PERSON, ORG, etc.)
|
|
8
|
+
4. Spreading activation: Search through the graph with activation decay
|
|
9
|
+
5. Dynamic weighting: Recency and frequency-based importance
|
|
10
|
+
"""
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
from datetime import datetime, timedelta, timezone
|
|
14
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, TypedDict
|
|
15
|
+
import asyncpg
|
|
16
|
+
import asyncio
|
|
17
|
+
from .embeddings import Embeddings, SentenceTransformersEmbeddings
|
|
18
|
+
from .cross_encoder import CrossEncoderModel
|
|
19
|
+
import time
|
|
20
|
+
import numpy as np
|
|
21
|
+
import uuid
|
|
22
|
+
import logging
|
|
23
|
+
from pydantic import BaseModel, Field
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class _RetainContentRequired(TypedDict):
    """Required keys of RetainContentDict.

    Declared on a separate total TypedDict so type checkers actually enforce
    ``content``; the subclass below adds the optional keys with total=False.
    """

    content: str  # Text content to store


class RetainContentDict(_RetainContentRequired, total=False):
    """Type definition for content items in retain_batch_async.

    ``content`` is required (inherited from a total base TypedDict); every
    other key is optional because this class is declared with total=False.

    Fields:
        content: Text content to store (required)
        context: Context about the content (optional)
        event_date: When the content occurred (optional, defaults to now)
        metadata: Custom key-value metadata (optional)
        document_id: Document ID for this content item (optional)
    """

    context: str
    event_date: datetime
    metadata: Dict[str, str]
    document_id: str
|
|
41
|
+
|
|
42
|
+
from .query_analyzer import QueryAnalyzer
|
|
43
|
+
from .search.scoring import (
|
|
44
|
+
calculate_recency_weight,
|
|
45
|
+
calculate_frequency_weight,
|
|
46
|
+
)
|
|
47
|
+
from .entity_resolver import EntityResolver
|
|
48
|
+
from .retain import embedding_utils, bank_utils
|
|
49
|
+
from .search import think_utils, observation_utils
|
|
50
|
+
from .llm_wrapper import LLMConfig
|
|
51
|
+
from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation
|
|
52
|
+
from .task_backend import TaskBackend, AsyncIOQueueBackend
|
|
53
|
+
from .search.reranking import CrossEncoderReranker
|
|
54
|
+
from ..pg0 import EmbeddedPostgres
|
|
55
|
+
from enum import Enum
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Budget(str, Enum):
    """Budget tiers accepted by recall/reflect operations.

    Because the enum mixes in ``str``, each member compares equal to its
    lowercase string value (e.g. ``Budget.LOW == "low"``).
    """

    LOW = "low"
    MID = "mid"
    HIGH = "high"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def utcnow() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    now = datetime.now(tz=timezone.utc)
    return now
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Module-level logger used by the memory engine code in this module.
logger = logging.getLogger(name=__name__)
|
|
72
|
+
|
|
73
|
+
from .db_utils import acquire_with_retry, retry_with_backoff
|
|
74
|
+
|
|
75
|
+
import tiktoken
|
|
76
|
+
from dateutil import parser as date_parser
|
|
77
|
+
|
|
78
|
+
# Module-level cache for the tiktoken encoding used in token budget filtering.
# Stays None until _get_tiktoken_encoding() runs for the first time.
_TIKTOKEN_ENCODING = None


def _get_tiktoken_encoding():
    """Return the shared cl100k_base tiktoken encoding, building it on first use."""
    global _TIKTOKEN_ENCODING
    if _TIKTOKEN_ENCODING is not None:
        return _TIKTOKEN_ENCODING
    _TIKTOKEN_ENCODING = tiktoken.get_encoding("cl100k_base")
    return _TIKTOKEN_ENCODING
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class MemoryEngine:
|
|
90
|
+
"""
|
|
91
|
+
Advanced memory system using temporal and semantic linking with PostgreSQL.
|
|
92
|
+
|
|
93
|
+
This class provides:
|
|
94
|
+
- Embedding generation for semantic search
|
|
95
|
+
- Entity, temporal, and semantic link creation
|
|
96
|
+
- Think operations for formulating answers with opinions
|
|
97
|
+
- bank profile and personality management
|
|
98
|
+
"""
|
|
99
|
+
|
|
100
|
+
def __init__(
    self,
    db_url: str,
    memory_llm_provider: str,
    memory_llm_api_key: str,
    memory_llm_model: str,
    memory_llm_base_url: Optional[str] = None,
    embeddings: Optional[Embeddings] = None,
    cross_encoder: Optional[CrossEncoderModel] = None,
    query_analyzer: Optional[QueryAnalyzer] = None,
    pool_min_size: int = 5,
    pool_max_size: int = 100,
    task_backend: Optional[TaskBackend] = None,
):
    """
    Initialize the temporal + semantic memory system.

    This only configures in-memory objects (plus an eager tiktoken encoding
    lookup). The asyncpg pool, the entity resolver and the embedded pg0 server
    are created later by ``initialize()``. That method also calls ``load()``
    on the embeddings, cross-encoder and query analyzer.

    Args:
        db_url: PostgreSQL connection URL (postgresql://user:pass@host:port/dbname),
            or the literal string "pg0" to start an embedded PostgreSQL server
            during ``initialize()``. Required.
        memory_llm_provider: LLM provider for memory operations, e.g. "openai",
            "groq", or "ollama". Required.
        memory_llm_api_key: API key for the LLM provider. Required.
        memory_llm_model: Model name to use for all memory operations (put/think/opinions). Required.
        memory_llm_base_url: Base URL for the LLM API. Optional. When omitted, it is
            chosen from the provider name (case-insensitive):
              - groq: https://api.groq.com/openai/v1
              - ollama: http://localhost:11434/v1
              - any other provider: "" (empty string)
        embeddings: Embeddings implementation to use. If not provided, uses
            SentenceTransformersEmbeddings("BAAI/bge-small-en-v1.5").
        cross_encoder: Cross-encoder model handed to CrossEncoderReranker. When None,
            the reranker's own default is used (see CrossEncoderReranker).
        query_analyzer: Query analyzer implementation to use. If not provided,
            uses DateparserQueryAnalyzer.
        pool_min_size: Minimum number of connections in the pool (default: 5)
        pool_max_size: Maximum number of connections in the pool (default: 100).
            Increase for parallel think/search operations (e.g., 200-300 for 100+ parallel thinks)
        task_backend: Custom task backend for async task execution. If not provided,
            uses AsyncIOQueueBackend(batch_size=100, batch_interval=1.0).

    Raises:
        ValueError: if ``db_url`` is empty.
    """
    if not db_url:
        raise ValueError("Database url is required")
    # Track pg0 instance (if used)
    self._pg0: Optional[EmbeddedPostgres] = None

    # "pg0" means: start an embedded server in initialize() and take its URL.
    # Until then db_url stays None.
    self._use_pg0 = db_url == "pg0"
    self.db_url = db_url if not self._use_pg0 else None

    # Set default base URL if not provided (provider match is case-insensitive)
    if memory_llm_base_url is None:
        provider = memory_llm_provider.lower()
        if provider == "groq":
            memory_llm_base_url = "https://api.groq.com/openai/v1"
        elif provider == "ollama":
            memory_llm_base_url = "http://localhost:11434/v1"
        else:
            memory_llm_base_url = ""

    # Connection pool (will be created in initialize())
    self._pool = None
    self._initialized = False
    self._pool_min_size = pool_min_size
    self._pool_max_size = pool_max_size

    # Initialize entity resolver (will be created in initialize())
    self.entity_resolver = None

    # Initialize embeddings
    if embeddings is not None:
        self.embeddings = embeddings
    else:
        self.embeddings = SentenceTransformersEmbeddings("BAAI/bge-small-en-v1.5")

    # Initialize query analyzer
    if query_analyzer is not None:
        self.query_analyzer = query_analyzer
    else:
        from .query_analyzer import DateparserQueryAnalyzer
        self.query_analyzer = DateparserQueryAnalyzer()

    # Initialize LLM configuration
    self._llm_config = LLMConfig(
        provider=memory_llm_provider,
        api_key=memory_llm_api_key,
        base_url=memory_llm_base_url,
        model=memory_llm_model,
    )

    # Store client and model for convenience (deprecated: use _llm_config.call() instead)
    self._llm_client = self._llm_config._client
    self._llm_model = self._llm_config.model

    # Initialize cross-encoder reranker (cached for performance)
    self._cross_encoder_reranker = CrossEncoderReranker(cross_encoder=cross_encoder)

    # Initialize task backend
    self._task_backend = task_backend or AsyncIOQueueBackend(
        batch_size=100,
        batch_interval=1.0
    )

    # Backpressure: cap concurrent searches so they cannot exhaust the pool.
    # Each search can use 2-4 connections, so 10 concurrent searches use
    # roughly 20-40 connections.
    self._search_semaphore = asyncio.Semaphore(10)

    # Backpressure for put operations: each put_batch holds a connection for
    # its whole transaction, so at most 5 run at once to limit pool usage and
    # write contention.
    self._put_semaphore = asyncio.Semaphore(5)

    # initialize encoding eagerly to avoid delaying the first time
    _get_tiktoken_encoding()
|
|
209
|
+
|
|
210
|
+
async def _handle_access_count_update(self, task_dict: Dict[str, Any]):
    """
    Bump ``access_count`` by one for each memory unit listed in the task.

    Args:
        task_dict: Task payload whose ``'node_ids'`` entry is a list of memory
            unit IDs given as UUID strings. A missing or empty list does nothing.

    Parsing or database errors are logged and not re-raised.
    """
    ids = task_dict.get('node_ids', [])
    if not ids:
        return

    pool = await self._get_pool()
    try:
        # Parse the string IDs into uuid.UUID values for the uuid[] parameter.
        parsed_ids = list(map(uuid.UUID, ids))
        async with acquire_with_retry(pool) as conn:
            await conn.execute(
                "UPDATE memory_units SET access_count = access_count + 1 WHERE id = ANY($1::uuid[])",
                parsed_ids
            )
    except Exception as exc:
        logger.error(f"Access count handler: Error updating access counts: {exc}")
|
|
232
|
+
|
|
233
|
+
async def _handle_batch_retain(self, task_dict: Dict[str, Any]):
    """
    Handler for batch retain tasks.

    Runs ``retain_batch_async`` on the task payload. A failure is logged
    (with its traceback) and then re-raised. That lets ``execute_task``
    apply its retry / mark-failed handling instead of deleting the
    operation record as if the retain had succeeded.

    Args:
        task_dict: Dict with 'bank_id' and 'contents' (list of content dicts;
            defaults to an empty list when absent).

    Raises:
        Exception: whatever went wrong while processing, after it is logged.
    """
    try:
        bank_id = task_dict.get('bank_id')
        contents = task_dict.get('contents', [])

        logger.info(f"[BATCH_RETAIN_TASK] Starting background batch retain for bank_id={bank_id}, {len(contents)} items")

        await self.retain_batch_async(
            bank_id=bank_id,
            contents=contents
        )

        logger.info(f"[BATCH_RETAIN_TASK] Completed background batch retain for bank_id={bank_id}")
    except Exception as e:
        # logger.exception records the traceback through logging (instead of
        # printing to stderr); re-raise so the task runner sees the failure.
        logger.exception(f"Batch retain handler: Error processing batch retain: {e}")
        raise
|
|
256
|
+
|
|
257
|
+
async def execute_task(self, task_dict: Dict[str, Any]):
    """
    Dispatch a serialized task dict to its handler and manage its lifecycle.

    Called by the task backend with a plain (network-serializable) dict.
    ``task_dict['type']`` selects the handler. Optional keys are
    ``operation_id`` (a row in ``async_operations``) and ``retry_count``.

    Lifecycle:
      * With an ``operation_id`` whose row no longer exists, the operation
        counts as cancelled and is skipped. If that check itself fails, the
        error is logged and processing continues.
      * After the handler succeeds, or for an unknown task type (never
        retried), the operation row is deleted.
      * When the handler raises, the task is resubmitted with
        ``retry_count + 1`` until 3 retries have happened. After that the
        operation row is marked as failed.

    Args:
        task_dict: Task dictionary with 'type' key and other payload data.
            Example: {'type': 'access_count_update', 'node_ids': [...]}
    """
    # Task type -> name of the handler method on this instance.
    routes = (
        ('access_count_update', '_handle_access_count_update'),
        ('reinforce_opinion', '_handle_reinforce_opinion'),
        ('form_opinion', '_handle_form_opinion'),
        ('batch_put', '_handle_batch_retain'),
        ('regenerate_observations', '_handle_regenerate_observations'),
    )
    task_type = task_dict.get('type')
    operation_id = task_dict.get('operation_id')
    retry_count = task_dict.get('retry_count', 0)
    max_retries = 3

    if operation_id:
        try:
            pool = await self._get_pool()
            async with acquire_with_retry(pool) as conn:
                row = await conn.fetchrow(
                    "SELECT id FROM async_operations WHERE id = $1",
                    uuid.UUID(operation_id)
                )
                if not row:
                    # The row is gone: the operation was cancelled.
                    logger.info(f"Skipping cancelled operation: {operation_id}")
                    return
        except Exception as exc:
            # Could not verify the status; run the task anyway.
            logger.error(f"Failed to check operation status {operation_id}: {exc}")

    try:
        handler_name = next((name for key, name in routes if task_type == key), None)
        if handler_name is None:
            logger.error(f"Unknown task type: {task_type}")
            # Unknown task types are dropped, not retried.
            if operation_id:
                await self._delete_operation_record(operation_id)
            return

        await getattr(self, handler_name)(task_dict)

        # Success: the operation record is no longer needed.
        if operation_id:
            await self._delete_operation_record(operation_id)

    except Exception as exc:
        logger.error(f"Task execution failed (attempt {retry_count + 1}/{max_retries + 1}): {task_type}, error: {exc}")
        import traceback
        error_traceback = traceback.format_exc()
        traceback.print_exc()

        if retry_count >= max_retries:
            logger.error(f"Max retries exceeded for task {task_type}, marking as failed")
            if operation_id:
                await self._mark_operation_failed(operation_id, str(exc), error_traceback)
            return

        task_dict['retry_count'] = retry_count + 1
        logger.info(f"Rescheduling task {task_type} (retry {retry_count + 1}/{max_retries})")
        await self._task_backend.submit_task(task_dict)
|
|
329
|
+
|
|
330
|
+
async def _delete_operation_record(self, operation_id: str):
    """Remove the ``async_operations`` row whose id is *operation_id*.

    Best-effort: any failure (including an operation_id that is not a valid
    UUID string) is logged and swallowed.
    """
    delete_sql = "DELETE FROM async_operations WHERE id = $1"
    try:
        pool = await self._get_pool()
        async with acquire_with_retry(pool) as conn:
            await conn.execute(delete_sql, uuid.UUID(operation_id))
    except Exception as exc:
        logger.error(f"Failed to delete async operation record {operation_id}: {exc}")
|
|
341
|
+
|
|
342
|
+
async def _mark_operation_failed(self, operation_id: str, error_message: str, error_traceback: str):
    """Flag an ``async_operations`` row as failed and record why.

    The stored ``error_message`` is the message followed by the traceback,
    capped at 5000 characters. Any failure while updating is logged and
    swallowed.
    """
    max_stored_chars = 5000
    try:
        pool = await self._get_pool()
        # Slicing leaves strings at or under the cap untouched.
        details = f"{error_message}\n\nTraceback:\n{error_traceback}"[:max_stored_chars]

        async with acquire_with_retry(pool) as conn:
            await conn.execute(
                """
                UPDATE async_operations
                SET status = 'failed', error_message = $2
                WHERE id = $1
                """,
                uuid.UUID(operation_id),
                details
            )
        logger.info(f"Marked async operation as failed: {operation_id}")
    except Exception as exc:
        logger.error(f"Failed to mark operation as failed {operation_id}: {exc}")
|
|
363
|
+
|
|
364
|
+
async def initialize(self):
    """Initialize the connection pool, models, and background workers.

    Starts pg0 (when the engine was created with ``db_url="pg0"``) while the
    embeddings, cross-encoder and query-analyzer models load in a 3-thread
    pool. It then creates the asyncpg pool and the entity resolver, and
    starts the task backend. A second call after success returns immediately.

    NOTE(review): there is no guard against concurrent first calls.
    ``_get_pool()`` calls this lazily, and ``_initialized`` is only set at the
    very end, so two coroutines racing before initialization finishes could
    both run it.
    """
    if self._initialized:
        return

    import concurrent.futures
    from urllib.parse import urlsplit

    # Inside a coroutine, get_running_loop() is the non-deprecated way to get the loop.
    loop = asyncio.get_running_loop()

    async def start_pg0():
        """Start pg0 if configured."""
        if self._use_pg0:
            self._pg0 = EmbeddedPostgres()
            self.db_url = await self._pg0.ensure_running()

    def load_embeddings():
        """Load embedding model (CPU-bound)."""
        self.embeddings.load()

    def load_cross_encoder():
        """Load cross-encoder model (CPU-bound)."""
        self._cross_encoder_reranker.cross_encoder.load()

    def load_query_analyzer():
        """Load query analyzer model (CPU-bound)."""
        self.query_analyzer.load()

    def redact_db_url(url):
        """Return *url* with any password masked as '***' so it is safe to log."""
        try:
            parts = urlsplit(url)
            password = parts.password
        except (TypeError, ValueError):
            return "<unparseable database url>"
        if not password:
            return url
        userinfo, _, hostinfo = parts.netloc.rpartition("@")
        username = userinfo.split(":", 1)[0]
        return parts._replace(netloc=f"{username}:***@{hostinfo}").geturl()

    # pg0 startup is async (IO-bound); model loads are sync, so they run in
    # worker threads. One worker per model lets all three load concurrently.
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        pg0_task = asyncio.create_task(start_pg0())
        embeddings_future = loop.run_in_executor(executor, load_embeddings)
        cross_encoder_future = loop.run_in_executor(executor, load_cross_encoder)
        query_analyzer_future = loop.run_in_executor(executor, load_query_analyzer)

        # Wait for all to complete
        await asyncio.gather(
            pg0_task, embeddings_future, cross_encoder_future, query_analyzer_future
        )

    # Never log the raw URL: it may embed the database password.
    logger.info(f"Connecting to PostgreSQL at {redact_db_url(self.db_url)}")

    # Pool sized for read-heavy workloads with many parallel think/search
    # operations.
    self._pool = await asyncpg.create_pool(
        self.db_url,
        min_size=self._pool_min_size,
        max_size=self._pool_max_size,
        command_timeout=60,
        statement_cache_size=0,  # Disable prepared statement cache
        timeout=30,  # Connection acquisition timeout (seconds)
    )

    # Initialize entity resolver with pool
    self.entity_resolver = EntityResolver(self._pool)

    # Set executor for task backend and initialize
    self._task_backend.set_executor(self.execute_task)
    await self._task_backend.initialize()

    self._initialized = True
    logger.info("Memory system initialized (pool and task backend started)")
|
|
434
|
+
|
|
435
|
+
async def _get_pool(self) -> asyncpg.Pool:
    """Return the asyncpg connection pool, calling ``initialize()`` first when needed."""
    if self._initialized:
        return self._pool
    await self.initialize()
    return self._pool
|
|
440
|
+
|
|
441
|
+
async def _acquire_connection(self):
    """
    Acquire a connection from the pool with retry logic.

    Acquisition is wrapped in ``retry_with_backoff`` (imported from
    ``.db_utils``), so transient connection errors are retried.

    Returns:
        The connection obtained by awaiting ``pool.acquire()``. This is a plain
        connection, not a context manager: the caller owns it and must return
        it with ``pool.release(conn)``. ``acquire_with_retry(pool)`` is the
        context-manager alternative.
    """
    pool = await self._get_pool()

    async def acquire():
        return await pool.acquire()

    # The imported helper is ``retry_with_backoff``; the previous call to an
    # undefined ``_retry_with_backoff`` raised NameError on every call.
    return await retry_with_backoff(acquire)
|
|
454
|
+
|
|
455
|
+
async def health_check(self) -> dict:
    """
    Probe the database with ``SELECT 1``.

    Returns:
        ``{"status": "healthy", "database": "connected"}`` when the query
        yields 1; ``{"status": "unhealthy", "database": "unexpected response"}``
        for any other value; ``{"status": "unhealthy", "database": "error",
        "error": <message>}`` if obtaining the pool or running the query raises.
    """
    try:
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            probe = await conn.fetchval("SELECT 1")
        healthy = probe == 1
    except Exception as exc:
        return {"status": "unhealthy", "database": "error", "error": str(exc)}

    if healthy:
        return {"status": "healthy", "database": "connected"}
    return {"status": "unhealthy", "database": "unexpected response"}
|
|
472
|
+
|
|
473
|
+
async def close(self):
    """Tear down the engine: task backend, connection pool, then owned pg0.

    Order: the task backend is shut down, the asyncpg pool (if any) is
    terminated and cleared, the initialized flag is reset, and the embedded
    pg0 server is stopped if this engine started one.
    """
    logger.info("close() started")

    await self._task_backend.shutdown()

    pool = self._pool
    if pool is not None:
        pool.terminate()
        self._pool = None

    self._initialized = False

    pg0 = self._pg0
    if pg0 is None:
        return
    logger.info("Stopping pg0...")
    await pg0.stop()
    self._pg0 = None
    logger.info("pg0 stopped")
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
async def wait_for_background_tasks(self):
    """
    Block until the task backend reports that no tasks are pending.

    Only backends that expose ``wait_for_pending_tasks`` are awaited; for any
    other backend this returns immediately. Useful in tests that must see the
    effects of background work (such as opinion reinforcement) before asserting.
    """
    backend = self._task_backend
    if not hasattr(backend, 'wait_for_pending_tasks'):
        return
    await backend.wait_for_pending_tasks()
|
|
504
|
+
|
|
505
|
+
def _format_readable_date(self, dt: datetime) -> str:
    """
    Format a datetime as "<full month name> <year>" for temporal matching.

    The day of the month is intentionally dropped, e.g.:
        - datetime(2024, 6, 3)   -> "June 2024"
        - datetime(2024, 1, 15)  -> "January 2024"
        - datetime(2023, 12, 31) -> "December 2023"

    This helps queries like "camping in June" match facts that happened in June.
    The month name comes from ``strftime("%B")``, so it follows the current
    locale (English under the default C locale).

    Args:
        dt: datetime object to format

    Returns:
        Readable "Month Year" string
    """
    # Could be extended to include the day for very specific dates if needed.
    return dt.strftime("%B %Y")
|
|
530
|
+
|
|
531
|
+
async def _find_duplicate_facts_batch(
    self,
    conn,
    bank_id: str,
    texts: List[str],
    embeddings: List[List[float]],
    event_date: datetime,
    time_window_hours: int = 24,
    similarity_threshold: float = 0.95
) -> List[bool]:
    """
    Check which facts are duplicates using semantic similarity + temporal window.

    All existing memory units of the bank whose ``event_date`` lies within
    +/- ``time_window_hours`` of *event_date* are fetched with a single query.
    Cosine similarity between each new embedding and every existing one is
    then computed in NumPy. A new fact is a duplicate when its best similarity
    is strictly greater than ``similarity_threshold``.

    Args:
        conn: Database connection
        bank_id: bank identifier
        texts: List of fact texts to check. Only its emptiness / length is used
            here; the comparison itself relies on ``embeddings``.
        embeddings: Embeddings of the new facts, assumed to be in the same
            order and of the same length as ``texts`` (TODO confirm with callers)
        event_date: Event date for temporal filtering
        time_window_hours: Hours before/after event_date to search (default: 24)
        similarity_threshold: Minimum cosine similarity to consider duplicate (default: 0.95)

    Returns:
        List of plain bools - True if fact is a duplicate (should skip), False if new
    """
    if not texts:
        return []

    # Clamp the window to the datetime range instead of overflowing. The
    # clamped bounds keep event_date's tzinfo so naive and aware values are
    # never mixed.
    window = timedelta(hours=time_window_hours)
    try:
        time_lower = event_date - window
    except OverflowError:
        time_lower = datetime.min.replace(tzinfo=event_date.tzinfo)
    try:
        time_upper = event_date + window
    except OverflowError:
        time_upper = datetime.max.replace(tzinfo=event_date.tzinfo)

    # Fetch ALL existing facts in the time window once (much faster than N queries)
    existing_facts = await conn.fetch(
        """
        SELECT id, text, embedding
        FROM memory_units
        WHERE bank_id = $1
          AND event_date BETWEEN $2 AND $3
        """,
        bank_id, time_lower, time_upper
    )

    # If no existing facts, nothing is duplicate
    if not existing_facts:
        return [False] * len(texts)

    def to_vector(raw) -> np.ndarray:
        # pgvector values may arrive as a "[1.0, 2.0, ...]" string or as a sequence.
        if isinstance(raw, str):
            raw = json.loads(raw)
        return np.asarray(raw, dtype=np.float32).ravel()

    existing = np.vstack([to_vector(row['embedding']) for row in existing_facts])

    # Normalize rows so the dot product below is a true cosine similarity even
    # if stored embeddings are not unit length. Zero vectors stay zero
    # (similarity 0) instead of dividing by zero.
    norms = np.linalg.norm(existing, axis=1, keepdims=True)
    existing_unit = np.divide(existing, norms, out=np.zeros_like(existing), where=norms > 0)

    is_duplicate: List[bool] = []
    for embedding in embeddings:
        vec = np.asarray(embedding, dtype=np.float32).ravel()
        norm = float(np.linalg.norm(vec))
        if norm == 0.0:
            # A zero vector has no direction, so it cannot be similar to anything.
            is_duplicate.append(False)
            continue
        similarities = existing_unit @ (vec / norm)
        is_duplicate.append(bool(similarities.max() > similarity_threshold))

    return is_duplicate
|
|
632
|
+
|
|
633
|
+
def retain(
    self,
    bank_id: str,
    content: str,
    context: str = "",
    event_date: Optional[datetime] = None,
) -> List[str]:
    """
    Blocking convenience wrapper that stores *content* through ``retain_async``.

    The coroutine is driven by ``asyncio.run``, which starts a fresh event
    loop, so this cannot be used from code already running inside an event
    loop. Async callers should ``await retain_async(...)`` directly.

    Args:
        bank_id: Unique identifier for the bank
        content: Text content to store
        context: Context about when/why this memory was formed
        event_date: When the event occurred (defaults to now)

    Returns:
        List of created unit IDs
    """
    pending = self.retain_async(bank_id, content, context, event_date)
    return asyncio.run(pending)
|
|
657
|
+
|
|
658
|
+
async def retain_async(
    self,
    bank_id: str,
    content: str,
    context: str = "",
    event_date: Optional[datetime] = None,
    document_id: Optional[str] = None,
    fact_type_override: Optional[str] = None,
    confidence_score: Optional[float] = None,
) -> List[str]:
    """
    Store one piece of content as memory units (async).

    Thin wrapper that sends a single-item batch to ``retain_batch_async``,
    so single and batch retains share one code path.

    Args:
        bank_id: Unique identifier for the bank
        content: Text content to store
        context: Context about when/why this memory was formed
        event_date: When the event occurred (defaults to now); passed through as-is, including None
        document_id: Optional document ID for tracking (always upserts if document already exists);
            only added to the item when truthy
        fact_type_override: Override fact type ('world', 'bank', 'opinion')
        confidence_score: Confidence score for opinions (0.0 to 1.0)

    Returns:
        The unit IDs created for this item, or [] if the batch call returned nothing
    """
    item: RetainContentDict = {"content": content, "context": context, "event_date": event_date}
    if document_id:
        item["document_id"] = document_id

    per_item_ids = await self.retain_batch_async(
        bank_id=bank_id,
        contents=[item],
        fact_type_override=fact_type_override,
        confidence_score=confidence_score,
    )
    if not per_item_ids:
        return []
    return per_item_ids[0]
|
|
704
|
+
|
|
705
|
+
async def retain_batch_async(
    self,
    bank_id: str,
    contents: List[RetainContentDict],
    document_id: Optional[str] = None,
    fact_type_override: Optional[str] = None,
    confidence_score: Optional[float] = None,
) -> List[List[str]]:
    """
    Store multiple content items as memory units in ONE batch operation.

    This is more efficient than calling retain_async once per item: the
    items are handed to the retain orchestrator together, so fact
    extraction, embedding generation and database writes are batched
    (per sub-batch when the input is chunked, see below).

    Chunking: if the combined length of all "content" strings exceeds
    600,000 characters, the items are split, in order, into sub-batches of
    roughly that size. Items are never split, so an oversized item ends up
    alone in its own sub-batch. Sub-batches are processed sequentially and
    their per-item results are concatenated in input order.

    The caller's ``contents`` list and the dicts inside it are never
    modified. When the deprecated batch-level ``document_id`` has to be
    applied, it is written into shallow copies of the affected items.

    Args:
        bank_id: Unique identifier for the bank
        contents: List of dicts with keys:
            - "content" (required): Text content to store
            - "context" (optional): Context about the memory
            - "event_date" (optional): When the event occurred
            - "document_id" (optional): Document ID for this specific content item
        document_id: **DEPRECATED** - Use "document_id" key in each content dict instead.
            Applies the same document_id to ALL content items that don't specify their own.
        fact_type_override: Override fact type for all facts ('world', 'bank', 'opinion')
        confidence_score: Confidence score for opinions (0.0 to 1.0)

    Returns:
        List of lists of unit IDs (one list per content item); an empty list
        when ``contents`` is empty.

    Example (new style - per-content document_id):
        unit_ids = await memory.retain_batch_async(
            bank_id="user123",
            contents=[
                {"content": "Alice works at Google", "document_id": "doc1"},
                {"content": "Bob loves Python", "document_id": "doc2"},
                {"content": "More about Alice", "document_id": "doc1"},
            ]
        )
        # Returns: [["unit-id-1"], ["unit-id-2"], ["unit-id-3"]]

    Example (deprecated style - batch-level document_id):
        unit_ids = await memory.retain_batch_async(
            bank_id="user123",
            contents=[
                {"content": "Alice works at Google"},
                {"content": "Bob loves Python"},
            ],
            document_id="meeting-2024-01-15"
        )
        # Returns: [["unit-id-1"], ["unit-id-2"]]
    """
    start_time = time.time()

    if not contents:
        return []

    # Apply batch-level document_id to contents that don't have their own (backwards compatibility).
    # Use shallow copies rather than writing into the caller's dicts, so a list
    # reused for another call does not silently keep this call's document_id.
    if document_id:
        contents = [
            item if "document_id" in item else {**item, "document_id": document_id}
            for item in contents
        ]

    # Auto-chunk large batches by character count to avoid timeouts and memory issues
    total_chars = sum(len(item.get("content", "")) for item in contents)

    CHARS_PER_BATCH = 600_000

    if total_chars > CHARS_PER_BATCH:
        # Split into smaller batches based on character count
        logger.info(f"Large batch detected ({total_chars:,} chars from {len(contents)} items). Splitting into sub-batches of ~{CHARS_PER_BATCH:,} chars each...")

        sub_batches = []
        current_batch = []
        current_batch_chars = 0

        for item in contents:
            item_chars = len(item.get("content", ""))

            # If adding this item would exceed the limit, start a new batch
            # (unless current batch is empty - then we must include it even if it's large)
            if current_batch and current_batch_chars + item_chars > CHARS_PER_BATCH:
                sub_batches.append(current_batch)
                current_batch = [item]
                current_batch_chars = item_chars
            else:
                current_batch.append(item)
                current_batch_chars += item_chars

        # Add the last batch
        if current_batch:
            sub_batches.append(current_batch)

        logger.info(f"Split into {len(sub_batches)} sub-batches: {[len(b) for b in sub_batches]} items each")

        # Process each sub-batch using internal method (skip chunking check)
        all_results = []
        for i, sub_batch in enumerate(sub_batches, 1):
            sub_batch_chars = sum(len(item.get("content", "")) for item in sub_batch)
            logger.info(f"Processing sub-batch {i}/{len(sub_batches)}: {len(sub_batch)} items, {sub_batch_chars:,} chars")

            # NOTE(review): is_first_batch is one flag per sub-batch, not per
            # document. A per-item document_id that first appears in a later
            # sub-batch is therefore processed with is_first_batch=False —
            # confirm the orchestrator still upserts that document correctly.
            sub_results = await self._retain_batch_async_internal(
                bank_id=bank_id,
                contents=sub_batch,
                document_id=document_id,
                is_first_batch=i == 1,  # Only upsert on first batch
                fact_type_override=fact_type_override,
                confidence_score=confidence_score
            )
            all_results.extend(sub_results)

        total_time = time.time() - start_time
        logger.info(f"RETAIN_BATCH_ASYNC (chunked) COMPLETE: {len(all_results)} results from {len(contents)} contents in {total_time:.3f}s")
        return all_results

    # Small batch - use internal method directly
    return await self._retain_batch_async_internal(
        bank_id=bank_id,
        contents=contents,
        document_id=document_id,
        is_first_batch=True,
        fact_type_override=fact_type_override,
        confidence_score=confidence_score
    )
|
|
832
|
+
|
|
833
|
+
async def _retain_batch_async_internal(
    self,
    bank_id: str,
    contents: List[RetainContentDict],
    document_id: Optional[str] = None,
    is_first_batch: bool = True,
    fact_type_override: Optional[str] = None,
    confidence_score: Optional[float] = None,
) -> List[List[str]]:
    """
    Internal method for batch processing without chunking logic.

    No size checks or splitting happen here. retain_batch_async calls this
    either with the whole input (when it totals at most ~600,000 content
    characters) or once per sub-batch of roughly that size. A single
    oversized item can still arrive here on its own.

    The whole call runs while holding ``self._put_semaphore``, which bounds
    how many retains execute concurrently (backpressure). The actual work is
    delegated to ``orchestrator.retain_batch``, and its result is returned
    unchanged.

    Args:
        bank_id: Unique identifier for the bank
        contents: List of dicts with content, context, event_date (and optionally document_id)
        document_id: Optional batch-level document ID forwarded to the orchestrator
        is_first_batch: Whether this is the first batch of a chunked operation
            (for chunked operations, only delete on first batch)
        fact_type_override: Override fact type for all facts
        confidence_score: Confidence score for opinions

    Returns:
        Whatever orchestrator.retain_batch returns; per the annotation, one
        list of unit IDs per content item.
    """
    # Backpressure: limit concurrent retains to prevent database contention
    async with self._put_semaphore:
        # Use the new modular orchestrator
        from .retain import orchestrator

        pool = await self._get_pool()
        return await orchestrator.retain_batch(
            pool=pool,
            embeddings_model=self.embeddings,
            llm_config=self._llm_config,
            entity_resolver=self.entity_resolver,
            task_backend=self._task_backend,
            format_date_fn=self._format_readable_date,
            duplicate_checker_fn=self._find_duplicate_facts_batch,
            regenerate_observations_fn=self._regenerate_observations_sync,
            bank_id=bank_id,
            contents_dicts=contents,
            document_id=document_id,
            is_first_batch=is_first_batch,
            fact_type_override=fact_type_override,
            confidence_score=confidence_score
        )
|
|
880
|
+
|
|
881
|
+
def recall(
    self,
    bank_id: str,
    query: str,
    fact_type: str,
    budget: Budget = Budget.MID,
    max_tokens: int = 4096,
    enable_trace: bool = False,
) -> RecallResultModel:
    """
    Recall memories for a single fact type (synchronous wrapper).

    This is a synchronous wrapper around recall_async() for convenience. It
    wraps ``fact_type`` in a one-element list and runs recall_async via
    ``asyncio.run``. Because of that it must not be called from code that is
    already running inside an event loop (asyncio.run raises RuntimeError
    there); async callers should await recall_async() directly, which is
    also faster.

    Args:
        bank_id: bank ID to recall for
        query: Recall query
        fact_type: Single fact type to filter on (e.g. 'world', 'bank', or 'opinion')
        budget: Budget level for graph traversal (low=100, mid=300, high=600 units)
        max_tokens: Maximum tokens to return (counts only 'text' field, default 4096)
        enable_trace: If True, request a detailed trace object

    Returns:
        The RecallResultModel produced by recall_async (results plus optional
        trace), not a plain tuple.
    """
    # Run async version synchronously
    return asyncio.run(self.recall_async(
        bank_id, query, [fact_type], budget, max_tokens, enable_trace
    ))
|
|
911
|
+
|
|
912
|
+
async def recall_async(
    self,
    bank_id: str,
    query: str,
    fact_type: List[str],
    budget: Budget = Budget.MID,
    max_tokens: int = 4096,
    enable_trace: bool = False,
    question_date: Optional[datetime] = None,
    include_entities: bool = False,
    max_entity_tokens: int = 1024,
    include_chunks: bool = False,
    max_chunk_tokens: int = 8192,
) -> RecallResultModel:
    """
    Recall memories using N*4-way parallel retrieval (N fact types x 4 retrieval methods).

    Maps ``budget`` to a graph-traversal budget and delegates to
    _search_with_retries, which implements the RECALL pipeline:
      1. Retrieval: for each fact type, parallel semantic-vector, BM25 keyword
         and graph retrieval, plus temporal-graph retrieval when a temporal
         constraint is detected in the query
      2. Merge: combine the ranked lists with Reciprocal Rank Fusion (RRF)
      3. Rerank: score with the cross-encoder, then blend that score with the
         RRF, temporal-proximity and recency signals
      4. Token filter: trim the results to the ``max_tokens`` budget

    Concurrency and retries: the whole call (including retries) runs while
    holding ``self._search_semaphore``. Connection-type database errors
    (asyncpg TooManyConnectionsError, CannotConnectNowError, or a
    PostgresError whose message mentions "connection") are retried up to 3
    times with 0.5s, 1s, 2s backoff. Any other error, or the last failed
    attempt, is re-raised unchanged.

    Args:
        bank_id: bank ID to recall for
        query: Recall query
        fact_type: List of fact types to recall (e.g., ['world', 'bank'])
        budget: Budget level for graph traversal (low=100, mid=300, high=600 units).
            A value outside Budget.LOW/MID/HIGH raises KeyError.
        max_tokens: Maximum tokens to return (counts only 'text' field, default 4096).
            Results are trimmed to fit this token budget.
        enable_trace: Whether to return trace for debugging (deprecated)
        question_date: Optional date when question was asked (for temporal filtering)
        include_entities: Whether to include entity observations in the response
        max_entity_tokens: Maximum tokens for entity observations (default 1024)
        include_chunks: Whether to include raw chunks in the response
        max_chunk_tokens: Maximum tokens for chunks (default 8192)

    Returns:
        RecallResultModel containing:
        - results: List of MemoryFact objects
        - trace: Optional trace information for debugging
        - entities: Optional dict of entity states (if include_entities=True)
        - chunks: Optional dict of chunks (if include_chunks=True)
    """
    # Map budget enum to thinking_budget number
    budget_mapping = {
        Budget.LOW: 100,
        Budget.MID: 300,
        Budget.HIGH: 600
    }
    thinking_budget = budget_mapping[budget]

    # Backpressure: limit concurrent recalls to prevent overwhelming the database
    async with self._search_semaphore:
        # Retry loop for connection errors
        max_retries = 3
        for attempt in range(max_retries + 1):
            try:
                return await self._search_with_retries(
                    bank_id, query, fact_type, thinking_budget, max_tokens, enable_trace, question_date,
                    include_entities, max_entity_tokens, include_chunks, max_chunk_tokens
                )
            except Exception as e:
                # Only transient connection problems are worth retrying
                is_connection_error = (
                    isinstance(e, (asyncpg.TooManyConnectionsError, asyncpg.CannotConnectNowError))
                    or (isinstance(e, asyncpg.PostgresError) and 'connection' in str(e).lower())
                )

                if is_connection_error and attempt < max_retries:
                    # Wait with exponential backoff before retry
                    wait_time = 0.5 * (2 ** attempt)  # 0.5s, 1s, 2s
                    logger.warning(
                        f"Connection error on search attempt {attempt + 1}/{max_retries + 1}: {str(e)}. "
                        f"Retrying in {wait_time:.1f}s..."
                    )
                    await asyncio.sleep(wait_time)
                else:
                    # Not a connection error or out of retries - raise
                    raise
        # Defensive only: the final attempt either returns or re-raises above,
        # so this line is not expected to be reached.
        raise Exception("Exceeded maximum retries for search due to connection errors.")
|
|
996
|
+
|
|
997
|
+
async def _search_with_retries(
|
|
998
|
+
self,
|
|
999
|
+
bank_id: str,
|
|
1000
|
+
query: str,
|
|
1001
|
+
fact_type: List[str],
|
|
1002
|
+
thinking_budget: int,
|
|
1003
|
+
max_tokens: int,
|
|
1004
|
+
enable_trace: bool,
|
|
1005
|
+
question_date: Optional[datetime] = None,
|
|
1006
|
+
include_entities: bool = False,
|
|
1007
|
+
max_entity_tokens: int = 500,
|
|
1008
|
+
include_chunks: bool = False,
|
|
1009
|
+
max_chunk_tokens: int = 8192,
|
|
1010
|
+
) -> RecallResultModel:
|
|
1011
|
+
"""
|
|
1012
|
+
Search implementation with modular retrieval and reranking.
|
|
1013
|
+
|
|
1014
|
+
Architecture:
|
|
1015
|
+
1. Retrieval: 4-way parallel (semantic, keyword, graph, temporal graph)
|
|
1016
|
+
2. Merge: RRF to combine ranked lists
|
|
1017
|
+
3. Reranking: Pluggable strategy (heuristic or cross-encoder)
|
|
1018
|
+
4. Diversity: MMR with λ=0.5
|
|
1019
|
+
5. Token Filter: Limit results to max_tokens budget
|
|
1020
|
+
|
|
1021
|
+
Args:
|
|
1022
|
+
bank_id: bank IDentifier
|
|
1023
|
+
query: Search query
|
|
1024
|
+
fact_type: Type of facts to search
|
|
1025
|
+
thinking_budget: Nodes to explore in graph traversal
|
|
1026
|
+
max_tokens: Maximum tokens to return (counts only 'text' field)
|
|
1027
|
+
enable_trace: Whether to return search trace (deprecated)
|
|
1028
|
+
include_entities: Whether to include entity observations
|
|
1029
|
+
max_entity_tokens: Maximum tokens for entity observations
|
|
1030
|
+
include_chunks: Whether to include raw chunks
|
|
1031
|
+
max_chunk_tokens: Maximum tokens for chunks
|
|
1032
|
+
|
|
1033
|
+
Returns:
|
|
1034
|
+
RecallResultModel with results, trace, optional entities, and optional chunks
|
|
1035
|
+
"""
|
|
1036
|
+
# Initialize tracer if requested
|
|
1037
|
+
from .search.tracer import SearchTracer
|
|
1038
|
+
tracer = SearchTracer(query, thinking_budget, max_tokens) if enable_trace else None
|
|
1039
|
+
if tracer:
|
|
1040
|
+
tracer.start()
|
|
1041
|
+
|
|
1042
|
+
pool = await self._get_pool()
|
|
1043
|
+
search_start = time.time()
|
|
1044
|
+
|
|
1045
|
+
# Buffer logs for clean output in concurrent scenarios
|
|
1046
|
+
search_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
|
|
1047
|
+
log_buffer = []
|
|
1048
|
+
log_buffer.append(f"[SEARCH {search_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})")
|
|
1049
|
+
|
|
1050
|
+
try:
|
|
1051
|
+
# Step 1: Generate query embedding (for semantic search)
|
|
1052
|
+
step_start = time.time()
|
|
1053
|
+
query_embedding = embedding_utils.generate_embedding(self.embeddings, query)
|
|
1054
|
+
step_duration = time.time() - step_start
|
|
1055
|
+
log_buffer.append(f" [1] Generate query embedding: {step_duration:.3f}s")
|
|
1056
|
+
|
|
1057
|
+
if tracer:
|
|
1058
|
+
tracer.record_query_embedding(query_embedding)
|
|
1059
|
+
tracer.add_phase_metric("generate_query_embedding", step_duration)
|
|
1060
|
+
|
|
1061
|
+
# Step 2: N*4-Way Parallel Retrieval (N fact types × 4 retrieval methods)
|
|
1062
|
+
step_start = time.time()
|
|
1063
|
+
query_embedding_str = str(query_embedding)
|
|
1064
|
+
|
|
1065
|
+
from .search.retrieval import retrieve_parallel
|
|
1066
|
+
|
|
1067
|
+
# Track each retrieval start time
|
|
1068
|
+
retrieval_start = time.time()
|
|
1069
|
+
|
|
1070
|
+
# Run retrieval for each fact type in parallel
|
|
1071
|
+
retrieval_tasks = [
|
|
1072
|
+
retrieve_parallel(
|
|
1073
|
+
pool, query, query_embedding_str, bank_id, ft, thinking_budget,
|
|
1074
|
+
question_date, self.query_analyzer
|
|
1075
|
+
)
|
|
1076
|
+
for ft in fact_type
|
|
1077
|
+
]
|
|
1078
|
+
all_retrievals = await asyncio.gather(*retrieval_tasks)
|
|
1079
|
+
|
|
1080
|
+
# Combine all results from all fact types and aggregate timings
|
|
1081
|
+
semantic_results = []
|
|
1082
|
+
bm25_results = []
|
|
1083
|
+
graph_results = []
|
|
1084
|
+
temporal_results = []
|
|
1085
|
+
aggregated_timings = {"semantic": 0.0, "bm25": 0.0, "graph": 0.0, "temporal": 0.0}
|
|
1086
|
+
|
|
1087
|
+
detected_temporal_constraint = None
|
|
1088
|
+
for idx, (ft_semantic, ft_bm25, ft_graph, ft_temporal, ft_timings, ft_temporal_constraint) in enumerate(all_retrievals):
|
|
1089
|
+
# Log fact types in this retrieval batch
|
|
1090
|
+
ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
|
|
1091
|
+
logger.debug(f"[SEARCH {search_id}] Fact type '{ft_name}': semantic={len(ft_semantic)}, bm25={len(ft_bm25)}, graph={len(ft_graph)}, temporal={len(ft_temporal) if ft_temporal else 0}")
|
|
1092
|
+
|
|
1093
|
+
semantic_results.extend(ft_semantic)
|
|
1094
|
+
bm25_results.extend(ft_bm25)
|
|
1095
|
+
graph_results.extend(ft_graph)
|
|
1096
|
+
if ft_temporal:
|
|
1097
|
+
temporal_results.extend(ft_temporal)
|
|
1098
|
+
# Track max timing for each method (since they run in parallel across fact types)
|
|
1099
|
+
for method, duration in ft_timings.items():
|
|
1100
|
+
aggregated_timings[method] = max(aggregated_timings[method], duration)
|
|
1101
|
+
# Capture temporal constraint (same across all fact types)
|
|
1102
|
+
if ft_temporal_constraint:
|
|
1103
|
+
detected_temporal_constraint = ft_temporal_constraint
|
|
1104
|
+
|
|
1105
|
+
# If no temporal results from any fact type, set to None
|
|
1106
|
+
if not temporal_results:
|
|
1107
|
+
temporal_results = None
|
|
1108
|
+
|
|
1109
|
+
# Sort combined results by score (descending) so higher-scored results
|
|
1110
|
+
# get better ranks in the trace, regardless of fact type
|
|
1111
|
+
semantic_results.sort(key=lambda r: r.similarity if hasattr(r, 'similarity') else 0, reverse=True)
|
|
1112
|
+
bm25_results.sort(key=lambda r: r.bm25_score if hasattr(r, 'bm25_score') else 0, reverse=True)
|
|
1113
|
+
graph_results.sort(key=lambda r: r.activation if hasattr(r, 'activation') else 0, reverse=True)
|
|
1114
|
+
if temporal_results:
|
|
1115
|
+
temporal_results.sort(key=lambda r: r.combined_score if hasattr(r, 'combined_score') else 0, reverse=True)
|
|
1116
|
+
|
|
1117
|
+
retrieval_duration = time.time() - retrieval_start
|
|
1118
|
+
|
|
1119
|
+
step_duration = time.time() - step_start
|
|
1120
|
+
total_retrievals = len(fact_type) * (4 if temporal_results else 3)
|
|
1121
|
+
# Format per-method timings
|
|
1122
|
+
timing_parts = [
|
|
1123
|
+
f"semantic={len(semantic_results)}({aggregated_timings['semantic']:.3f}s)",
|
|
1124
|
+
f"bm25={len(bm25_results)}({aggregated_timings['bm25']:.3f}s)",
|
|
1125
|
+
f"graph={len(graph_results)}({aggregated_timings['graph']:.3f}s)"
|
|
1126
|
+
]
|
|
1127
|
+
temporal_info = ""
|
|
1128
|
+
if detected_temporal_constraint:
|
|
1129
|
+
start_dt, end_dt = detected_temporal_constraint
|
|
1130
|
+
temporal_count = len(temporal_results) if temporal_results else 0
|
|
1131
|
+
timing_parts.append(f"temporal={temporal_count}({aggregated_timings['temporal']:.3f}s)")
|
|
1132
|
+
temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
|
|
1133
|
+
log_buffer.append(f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}")
|
|
1134
|
+
|
|
1135
|
+
# Record retrieval results for tracer (convert typed results to old format)
|
|
1136
|
+
if tracer:
|
|
1137
|
+
# Convert RetrievalResult to old tuple format for tracer
|
|
1138
|
+
def to_tuple_format(results):
|
|
1139
|
+
return [(r.id, r.__dict__) for r in results]
|
|
1140
|
+
|
|
1141
|
+
# Add semantic retrieval results
|
|
1142
|
+
tracer.add_retrieval_results(
|
|
1143
|
+
method_name="semantic",
|
|
1144
|
+
results=to_tuple_format(semantic_results),
|
|
1145
|
+
duration_seconds=aggregated_timings["semantic"],
|
|
1146
|
+
score_field="similarity",
|
|
1147
|
+
metadata={"limit": thinking_budget}
|
|
1148
|
+
)
|
|
1149
|
+
|
|
1150
|
+
# Add BM25 retrieval results
|
|
1151
|
+
tracer.add_retrieval_results(
|
|
1152
|
+
method_name="bm25",
|
|
1153
|
+
results=to_tuple_format(bm25_results),
|
|
1154
|
+
duration_seconds=aggregated_timings["bm25"],
|
|
1155
|
+
score_field="bm25_score",
|
|
1156
|
+
metadata={"limit": thinking_budget}
|
|
1157
|
+
)
|
|
1158
|
+
|
|
1159
|
+
# Add graph retrieval results
|
|
1160
|
+
tracer.add_retrieval_results(
|
|
1161
|
+
method_name="graph",
|
|
1162
|
+
results=to_tuple_format(graph_results),
|
|
1163
|
+
duration_seconds=aggregated_timings["graph"],
|
|
1164
|
+
score_field="similarity", # Graph uses similarity for activation
|
|
1165
|
+
metadata={"budget": thinking_budget}
|
|
1166
|
+
)
|
|
1167
|
+
|
|
1168
|
+
# Add temporal retrieval results if present
|
|
1169
|
+
if temporal_results:
|
|
1170
|
+
tracer.add_retrieval_results(
|
|
1171
|
+
method_name="temporal",
|
|
1172
|
+
results=to_tuple_format(temporal_results),
|
|
1173
|
+
duration_seconds=aggregated_timings["temporal"],
|
|
1174
|
+
score_field="temporal_score",
|
|
1175
|
+
metadata={"budget": thinking_budget}
|
|
1176
|
+
)
|
|
1177
|
+
|
|
1178
|
+
# Record entry points (from semantic results) for legacy graph view
|
|
1179
|
+
for rank, retrieval in enumerate(semantic_results[:10], start=1): # Top 10 as entry points
|
|
1180
|
+
tracer.add_entry_point(retrieval.id, retrieval.text, retrieval.similarity or 0.0, rank)
|
|
1181
|
+
|
|
1182
|
+
tracer.add_phase_metric("parallel_retrieval", step_duration, {
|
|
1183
|
+
"semantic_count": len(semantic_results),
|
|
1184
|
+
"bm25_count": len(bm25_results),
|
|
1185
|
+
"graph_count": len(graph_results),
|
|
1186
|
+
"temporal_count": len(temporal_results) if temporal_results else 0
|
|
1187
|
+
})
|
|
1188
|
+
|
|
1189
|
+
# Step 3: Merge with RRF
|
|
1190
|
+
step_start = time.time()
|
|
1191
|
+
from .search.fusion import reciprocal_rank_fusion
|
|
1192
|
+
|
|
1193
|
+
# Merge 3 or 4 result lists depending on temporal constraint
|
|
1194
|
+
if temporal_results:
|
|
1195
|
+
merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results, temporal_results])
|
|
1196
|
+
else:
|
|
1197
|
+
merged_candidates = reciprocal_rank_fusion([semantic_results, bm25_results, graph_results])
|
|
1198
|
+
|
|
1199
|
+
step_duration = time.time() - step_start
|
|
1200
|
+
log_buffer.append(f" [3] RRF merge: {len(merged_candidates)} unique candidates in {step_duration:.3f}s")
|
|
1201
|
+
|
|
1202
|
+
if tracer:
|
|
1203
|
+
# Convert MergedCandidate to old tuple format for tracer
|
|
1204
|
+
tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
|
|
1205
|
+
for mc in merged_candidates]
|
|
1206
|
+
tracer.add_rrf_merged(tracer_merged)
|
|
1207
|
+
tracer.add_phase_metric("rrf_merge", step_duration, {"candidates_merged": len(merged_candidates)})
|
|
1208
|
+
|
|
1209
|
+
# Step 4: Rerank using cross-encoder (MergedCandidate -> ScoredResult)
|
|
1210
|
+
step_start = time.time()
|
|
1211
|
+
reranker_instance = self._cross_encoder_reranker
|
|
1212
|
+
log_buffer.append(f" [4] Using cross-encoder reranker")
|
|
1213
|
+
|
|
1214
|
+
# Rerank using cross-encoder
|
|
1215
|
+
scored_results = reranker_instance.rerank(query, merged_candidates)
|
|
1216
|
+
|
|
1217
|
+
step_duration = time.time() - step_start
|
|
1218
|
+
log_buffer.append(f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s")
|
|
1219
|
+
|
|
1220
|
+
if tracer:
|
|
1221
|
+
# Convert to old format for tracer
|
|
1222
|
+
results_dict = [sr.to_dict() for sr in scored_results]
|
|
1223
|
+
tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
|
|
1224
|
+
for mc in merged_candidates]
|
|
1225
|
+
tracer.add_reranked(results_dict, tracer_merged)
|
|
1226
|
+
tracer.add_phase_metric("reranking", step_duration, {
|
|
1227
|
+
"reranker_type": "cross-encoder",
|
|
1228
|
+
"candidates_reranked": len(scored_results)
|
|
1229
|
+
})
|
|
1230
|
+
|
|
1231
|
+
# Step 4.5: Combine cross-encoder score with retrieval signals
|
|
1232
|
+
# This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
|
|
1233
|
+
if scored_results:
|
|
1234
|
+
# Normalize RRF scores to [0, 1] range
|
|
1235
|
+
rrf_scores = [sr.candidate.rrf_score for sr in scored_results]
|
|
1236
|
+
max_rrf = max(rrf_scores) if rrf_scores else 1.0
|
|
1237
|
+
min_rrf = min(rrf_scores) if rrf_scores else 0.0
|
|
1238
|
+
rrf_range = max_rrf - min_rrf if max_rrf > min_rrf else 1.0
|
|
1239
|
+
|
|
1240
|
+
# Calculate recency based on occurred_start (more recent = higher score)
|
|
1241
|
+
now = utcnow()
|
|
1242
|
+
for sr in scored_results:
|
|
1243
|
+
# Normalize RRF score
|
|
1244
|
+
sr.rrf_normalized = (sr.candidate.rrf_score - min_rrf) / rrf_range if rrf_range > 0 else 0.5
|
|
1245
|
+
|
|
1246
|
+
# Calculate recency (decay over 365 days, minimum 0.1)
|
|
1247
|
+
sr.recency = 0.5 # default for missing dates
|
|
1248
|
+
if sr.retrieval.occurred_start:
|
|
1249
|
+
occurred = sr.retrieval.occurred_start
|
|
1250
|
+
if hasattr(occurred, 'tzinfo') and occurred.tzinfo is None:
|
|
1251
|
+
from datetime import timezone
|
|
1252
|
+
occurred = occurred.replace(tzinfo=timezone.utc)
|
|
1253
|
+
days_ago = (now - occurred).total_seconds() / 86400
|
|
1254
|
+
sr.recency = max(0.1, 1.0 - (days_ago / 365)) # Linear decay over 1 year
|
|
1255
|
+
|
|
1256
|
+
# Get temporal proximity if available (already 0-1)
|
|
1257
|
+
sr.temporal = sr.retrieval.temporal_proximity if sr.retrieval.temporal_proximity is not None else 0.5
|
|
1258
|
+
|
|
1259
|
+
# Weighted combination
|
|
1260
|
+
# Cross-encoder: 60% (semantic relevance)
|
|
1261
|
+
# RRF: 20% (retrieval consensus)
|
|
1262
|
+
# Temporal proximity: 10% (time relevance for temporal queries)
|
|
1263
|
+
# Recency: 10% (prefer recent facts)
|
|
1264
|
+
sr.combined_score = (
|
|
1265
|
+
0.6 * sr.cross_encoder_score_normalized +
|
|
1266
|
+
0.2 * sr.rrf_normalized +
|
|
1267
|
+
0.1 * sr.temporal +
|
|
1268
|
+
0.1 * sr.recency
|
|
1269
|
+
)
|
|
1270
|
+
sr.weight = sr.combined_score # Update weight for final ranking
|
|
1271
|
+
|
|
1272
|
+
# Re-sort by combined score
|
|
1273
|
+
scored_results.sort(key=lambda x: x.weight, reverse=True)
|
|
1274
|
+
log_buffer.append(f" [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)")
|
|
1275
|
+
|
|
1276
|
+
# Step 5: Truncate to thinking_budget * 2 for token filtering
|
|
1277
|
+
rerank_limit = thinking_budget * 2
|
|
1278
|
+
top_scored = scored_results[:rerank_limit]
|
|
1279
|
+
log_buffer.append(f" [5] Truncated to top {len(top_scored)} results")
|
|
1280
|
+
|
|
1281
|
+
# Step 6: Token budget filtering
|
|
1282
|
+
step_start = time.time()
|
|
1283
|
+
|
|
1284
|
+
# Convert to dict for token filtering (backward compatibility)
|
|
1285
|
+
top_dicts = [sr.to_dict() for sr in top_scored]
|
|
1286
|
+
filtered_dicts, total_tokens = self._filter_by_token_budget(top_dicts, max_tokens)
|
|
1287
|
+
|
|
1288
|
+
# Convert back to list of IDs and filter scored_results
|
|
1289
|
+
filtered_ids = {d["id"] for d in filtered_dicts}
|
|
1290
|
+
top_scored = [sr for sr in top_scored if sr.id in filtered_ids]
|
|
1291
|
+
|
|
1292
|
+
step_duration = time.time() - step_start
|
|
1293
|
+
log_buffer.append(f" [6] Token filtering: {len(top_scored)} results, {total_tokens}/{max_tokens} tokens in {step_duration:.3f}s")
|
|
1294
|
+
|
|
1295
|
+
if tracer:
|
|
1296
|
+
tracer.add_phase_metric("token_filtering", step_duration, {
|
|
1297
|
+
"results_selected": len(top_scored),
|
|
1298
|
+
"tokens_used": total_tokens,
|
|
1299
|
+
"max_tokens": max_tokens
|
|
1300
|
+
})
|
|
1301
|
+
|
|
1302
|
+
# Record visits for all retrieved nodes
|
|
1303
|
+
if tracer:
|
|
1304
|
+
for sr in scored_results:
|
|
1305
|
+
tracer.visit_node(
|
|
1306
|
+
node_id=sr.id,
|
|
1307
|
+
text=sr.retrieval.text,
|
|
1308
|
+
context=sr.retrieval.context or "",
|
|
1309
|
+
event_date=sr.retrieval.occurred_start,
|
|
1310
|
+
access_count=sr.retrieval.access_count,
|
|
1311
|
+
is_entry_point=(sr.id in [ep.node_id for ep in tracer.entry_points]),
|
|
1312
|
+
parent_node_id=None, # In parallel retrieval, there's no clear parent
|
|
1313
|
+
link_type=None,
|
|
1314
|
+
link_weight=None,
|
|
1315
|
+
activation=sr.candidate.rrf_score, # Use RRF score as activation
|
|
1316
|
+
semantic_similarity=sr.retrieval.similarity or 0.0,
|
|
1317
|
+
recency=sr.recency,
|
|
1318
|
+
frequency=0.0,
|
|
1319
|
+
final_weight=sr.weight
|
|
1320
|
+
)
|
|
1321
|
+
|
|
1322
|
+
# Step 8: Queue access count updates for visited nodes
|
|
1323
|
+
visited_ids = list(set([sr.id for sr in scored_results[:50]])) # Top 50
|
|
1324
|
+
if visited_ids:
|
|
1325
|
+
await self._task_backend.submit_task({
|
|
1326
|
+
'type': 'access_count_update',
|
|
1327
|
+
'node_ids': visited_ids
|
|
1328
|
+
})
|
|
1329
|
+
log_buffer.append(f" [7] Queued access count updates for {len(visited_ids)} nodes")
|
|
1330
|
+
|
|
1331
|
+
# Log fact_type distribution in results
|
|
1332
|
+
fact_type_counts = {}
|
|
1333
|
+
for sr in top_scored:
|
|
1334
|
+
ft = sr.retrieval.fact_type
|
|
1335
|
+
fact_type_counts[ft] = fact_type_counts.get(ft, 0) + 1
|
|
1336
|
+
|
|
1337
|
+
total_time = time.time() - search_start
|
|
1338
|
+
fact_type_summary = ", ".join([f"{ft}={count}" for ft, count in sorted(fact_type_counts.items())])
|
|
1339
|
+
log_buffer.append(f"[SEARCH {search_id}] Complete: {len(top_scored)} results ({fact_type_summary}) ({total_tokens} tokens) in {total_time:.3f}s")
|
|
1340
|
+
|
|
1341
|
+
# Log all buffered logs at once
|
|
1342
|
+
logger.info("\n" + "\n".join(log_buffer))
|
|
1343
|
+
|
|
1344
|
+
# Convert ScoredResult to dicts with ISO datetime strings
|
|
1345
|
+
top_results_dicts = []
|
|
1346
|
+
for sr in top_scored:
|
|
1347
|
+
result_dict = sr.to_dict()
|
|
1348
|
+
# Convert datetime objects to ISO strings for JSON serialization
|
|
1349
|
+
if result_dict.get("occurred_start"):
|
|
1350
|
+
occurred_start = result_dict["occurred_start"]
|
|
1351
|
+
result_dict["occurred_start"] = occurred_start.isoformat() if hasattr(occurred_start, 'isoformat') else occurred_start
|
|
1352
|
+
if result_dict.get("occurred_end"):
|
|
1353
|
+
occurred_end = result_dict["occurred_end"]
|
|
1354
|
+
result_dict["occurred_end"] = occurred_end.isoformat() if hasattr(occurred_end, 'isoformat') else occurred_end
|
|
1355
|
+
if result_dict.get("mentioned_at"):
|
|
1356
|
+
mentioned_at = result_dict["mentioned_at"]
|
|
1357
|
+
result_dict["mentioned_at"] = mentioned_at.isoformat() if hasattr(mentioned_at, 'isoformat') else mentioned_at
|
|
1358
|
+
top_results_dicts.append(result_dict)
|
|
1359
|
+
|
|
1360
|
+
# Get entities for each fact if include_entities is requested
|
|
1361
|
+
fact_entity_map = {} # unit_id -> list of (entity_id, entity_name)
|
|
1362
|
+
if include_entities and top_scored:
|
|
1363
|
+
unit_ids = [uuid.UUID(sr.id) for sr in top_scored]
|
|
1364
|
+
if unit_ids:
|
|
1365
|
+
async with acquire_with_retry(pool) as entity_conn:
|
|
1366
|
+
entity_rows = await entity_conn.fetch(
|
|
1367
|
+
"""
|
|
1368
|
+
SELECT ue.unit_id, e.id as entity_id, e.canonical_name
|
|
1369
|
+
FROM unit_entities ue
|
|
1370
|
+
JOIN entities e ON ue.entity_id = e.id
|
|
1371
|
+
WHERE ue.unit_id = ANY($1::uuid[])
|
|
1372
|
+
""",
|
|
1373
|
+
unit_ids
|
|
1374
|
+
)
|
|
1375
|
+
for row in entity_rows:
|
|
1376
|
+
unit_id = str(row['unit_id'])
|
|
1377
|
+
if unit_id not in fact_entity_map:
|
|
1378
|
+
fact_entity_map[unit_id] = []
|
|
1379
|
+
fact_entity_map[unit_id].append({
|
|
1380
|
+
'entity_id': str(row['entity_id']),
|
|
1381
|
+
'canonical_name': row['canonical_name']
|
|
1382
|
+
})
|
|
1383
|
+
|
|
1384
|
+
# Convert results to MemoryFact objects
|
|
1385
|
+
memory_facts = []
|
|
1386
|
+
for result_dict in top_results_dicts:
|
|
1387
|
+
result_id = str(result_dict.get("id"))
|
|
1388
|
+
# Get entity names for this fact
|
|
1389
|
+
entity_names = None
|
|
1390
|
+
if include_entities and result_id in fact_entity_map:
|
|
1391
|
+
entity_names = [e['canonical_name'] for e in fact_entity_map[result_id]]
|
|
1392
|
+
|
|
1393
|
+
memory_facts.append(MemoryFact(
|
|
1394
|
+
id=result_id,
|
|
1395
|
+
text=result_dict.get("text"),
|
|
1396
|
+
fact_type=result_dict.get("fact_type", "world"),
|
|
1397
|
+
entities=entity_names,
|
|
1398
|
+
context=result_dict.get("context"),
|
|
1399
|
+
occurred_start=result_dict.get("occurred_start"),
|
|
1400
|
+
occurred_end=result_dict.get("occurred_end"),
|
|
1401
|
+
mentioned_at=result_dict.get("mentioned_at"),
|
|
1402
|
+
document_id=result_dict.get("document_id"),
|
|
1403
|
+
chunk_id=result_dict.get("chunk_id"),
|
|
1404
|
+
activation=result_dict.get("weight") # Use final weight as activation
|
|
1405
|
+
))
|
|
1406
|
+
|
|
1407
|
+
# Fetch entity observations if requested
|
|
1408
|
+
entities_dict = None
|
|
1409
|
+
if include_entities and fact_entity_map:
|
|
1410
|
+
# Collect unique entities in order of fact relevance (preserving order from top_scored)
|
|
1411
|
+
# Use a list to maintain order, but track seen entities to avoid duplicates
|
|
1412
|
+
entities_ordered = [] # list of (entity_id, entity_name) tuples
|
|
1413
|
+
seen_entity_ids = set()
|
|
1414
|
+
|
|
1415
|
+
# Iterate through facts in relevance order
|
|
1416
|
+
for sr in top_scored:
|
|
1417
|
+
unit_id = sr.id
|
|
1418
|
+
if unit_id in fact_entity_map:
|
|
1419
|
+
for entity in fact_entity_map[unit_id]:
|
|
1420
|
+
entity_id = entity['entity_id']
|
|
1421
|
+
entity_name = entity['canonical_name']
|
|
1422
|
+
if entity_id not in seen_entity_ids:
|
|
1423
|
+
entities_ordered.append((entity_id, entity_name))
|
|
1424
|
+
seen_entity_ids.add(entity_id)
|
|
1425
|
+
|
|
1426
|
+
# Fetch observations for each entity (respect token budget, in order)
|
|
1427
|
+
entities_dict = {}
|
|
1428
|
+
total_entity_tokens = 0
|
|
1429
|
+
encoding = _get_tiktoken_encoding()
|
|
1430
|
+
|
|
1431
|
+
for entity_id, entity_name in entities_ordered:
|
|
1432
|
+
if total_entity_tokens >= max_entity_tokens:
|
|
1433
|
+
break
|
|
1434
|
+
|
|
1435
|
+
observations = await self.get_entity_observations(bank_id, entity_id, limit=5)
|
|
1436
|
+
|
|
1437
|
+
# Calculate tokens for this entity's observations
|
|
1438
|
+
entity_tokens = 0
|
|
1439
|
+
included_observations = []
|
|
1440
|
+
for obs in observations:
|
|
1441
|
+
obs_tokens = len(encoding.encode(obs.text))
|
|
1442
|
+
if total_entity_tokens + entity_tokens + obs_tokens <= max_entity_tokens:
|
|
1443
|
+
included_observations.append(obs)
|
|
1444
|
+
entity_tokens += obs_tokens
|
|
1445
|
+
else:
|
|
1446
|
+
break
|
|
1447
|
+
|
|
1448
|
+
if included_observations:
|
|
1449
|
+
entities_dict[entity_name] = EntityState(
|
|
1450
|
+
entity_id=entity_id,
|
|
1451
|
+
canonical_name=entity_name,
|
|
1452
|
+
observations=included_observations
|
|
1453
|
+
)
|
|
1454
|
+
total_entity_tokens += entity_tokens
|
|
1455
|
+
|
|
1456
|
+
# Fetch chunks if requested
|
|
1457
|
+
chunks_dict = None
|
|
1458
|
+
if include_chunks and top_scored:
|
|
1459
|
+
from .response_models import ChunkInfo
|
|
1460
|
+
|
|
1461
|
+
# Collect chunk_ids in order of fact relevance (preserving order from top_scored)
|
|
1462
|
+
# Use a list to maintain order, but track seen chunks to avoid duplicates
|
|
1463
|
+
chunk_ids_ordered = []
|
|
1464
|
+
seen_chunk_ids = set()
|
|
1465
|
+
for sr in top_scored:
|
|
1466
|
+
chunk_id = sr.retrieval.chunk_id
|
|
1467
|
+
if chunk_id and chunk_id not in seen_chunk_ids:
|
|
1468
|
+
chunk_ids_ordered.append(chunk_id)
|
|
1469
|
+
seen_chunk_ids.add(chunk_id)
|
|
1470
|
+
|
|
1471
|
+
if chunk_ids_ordered:
|
|
1472
|
+
# Fetch chunk data from database using chunk_ids (no ORDER BY to preserve input order)
|
|
1473
|
+
async with acquire_with_retry(pool) as conn:
|
|
1474
|
+
chunks_rows = await conn.fetch(
|
|
1475
|
+
"""
|
|
1476
|
+
SELECT chunk_id, chunk_text, chunk_index
|
|
1477
|
+
FROM chunks
|
|
1478
|
+
WHERE chunk_id = ANY($1::text[])
|
|
1479
|
+
""",
|
|
1480
|
+
chunk_ids_ordered
|
|
1481
|
+
)
|
|
1482
|
+
|
|
1483
|
+
# Create a lookup dict for fast access
|
|
1484
|
+
chunks_lookup = {row['chunk_id']: row for row in chunks_rows}
|
|
1485
|
+
|
|
1486
|
+
# Apply token limit and build chunks_dict in the order of chunk_ids_ordered
|
|
1487
|
+
chunks_dict = {}
|
|
1488
|
+
total_chunk_tokens = 0
|
|
1489
|
+
encoding = _get_tiktoken_encoding()
|
|
1490
|
+
|
|
1491
|
+
for chunk_id in chunk_ids_ordered:
|
|
1492
|
+
if chunk_id not in chunks_lookup:
|
|
1493
|
+
continue
|
|
1494
|
+
|
|
1495
|
+
row = chunks_lookup[chunk_id]
|
|
1496
|
+
chunk_text = row['chunk_text']
|
|
1497
|
+
chunk_tokens = len(encoding.encode(chunk_text))
|
|
1498
|
+
|
|
1499
|
+
# Check if adding this chunk would exceed the limit
|
|
1500
|
+
if total_chunk_tokens + chunk_tokens > max_chunk_tokens:
|
|
1501
|
+
# Truncate the chunk to fit within the remaining budget
|
|
1502
|
+
remaining_tokens = max_chunk_tokens - total_chunk_tokens
|
|
1503
|
+
if remaining_tokens > 0:
|
|
1504
|
+
# Truncate to remaining tokens
|
|
1505
|
+
truncated_text = encoding.decode(encoding.encode(chunk_text)[:remaining_tokens])
|
|
1506
|
+
chunks_dict[chunk_id] = ChunkInfo(
|
|
1507
|
+
chunk_text=truncated_text,
|
|
1508
|
+
chunk_index=row['chunk_index'],
|
|
1509
|
+
truncated=True
|
|
1510
|
+
)
|
|
1511
|
+
total_chunk_tokens = max_chunk_tokens
|
|
1512
|
+
# Stop adding more chunks once we hit the limit
|
|
1513
|
+
break
|
|
1514
|
+
else:
|
|
1515
|
+
chunks_dict[chunk_id] = ChunkInfo(
|
|
1516
|
+
chunk_text=chunk_text,
|
|
1517
|
+
chunk_index=row['chunk_index'],
|
|
1518
|
+
truncated=False
|
|
1519
|
+
)
|
|
1520
|
+
total_chunk_tokens += chunk_tokens
|
|
1521
|
+
|
|
1522
|
+
# Finalize trace if enabled
|
|
1523
|
+
trace_dict = None
|
|
1524
|
+
if tracer:
|
|
1525
|
+
trace = tracer.finalize(top_results_dicts)
|
|
1526
|
+
trace_dict = trace.to_dict() if trace else None
|
|
1527
|
+
|
|
1528
|
+
return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)
|
|
1529
|
+
|
|
1530
|
+
except Exception as e:
|
|
1531
|
+
log_buffer.append(f"[SEARCH {search_id}] ERROR after {time.time() - search_start:.3f}s: {str(e)}")
|
|
1532
|
+
logger.error("\n" + "\n".join(log_buffer))
|
|
1533
|
+
raise Exception(f"Failed to search memories: {str(e)}")
|
|
1534
|
+
|
|
1535
|
+
def _filter_by_token_budget(
|
|
1536
|
+
self,
|
|
1537
|
+
results: List[Dict[str, Any]],
|
|
1538
|
+
max_tokens: int
|
|
1539
|
+
) -> Tuple[List[Dict[str, Any]], int]:
|
|
1540
|
+
"""
|
|
1541
|
+
Filter results to fit within token budget.
|
|
1542
|
+
|
|
1543
|
+
Counts tokens only for the 'text' field using tiktoken (cl100k_base encoding).
|
|
1544
|
+
Stops before including a fact that would exceed the budget.
|
|
1545
|
+
|
|
1546
|
+
Args:
|
|
1547
|
+
results: List of search results
|
|
1548
|
+
max_tokens: Maximum tokens allowed
|
|
1549
|
+
|
|
1550
|
+
Returns:
|
|
1551
|
+
Tuple of (filtered_results, total_tokens_used)
|
|
1552
|
+
"""
|
|
1553
|
+
encoding = _get_tiktoken_encoding()
|
|
1554
|
+
|
|
1555
|
+
filtered_results = []
|
|
1556
|
+
total_tokens = 0
|
|
1557
|
+
|
|
1558
|
+
for result in results:
|
|
1559
|
+
text = result.get("text", "")
|
|
1560
|
+
text_tokens = len(encoding.encode(text))
|
|
1561
|
+
|
|
1562
|
+
# Check if adding this result would exceed budget
|
|
1563
|
+
if total_tokens + text_tokens <= max_tokens:
|
|
1564
|
+
filtered_results.append(result)
|
|
1565
|
+
total_tokens += text_tokens
|
|
1566
|
+
else:
|
|
1567
|
+
# Stop before including a fact that would exceed limit
|
|
1568
|
+
break
|
|
1569
|
+
|
|
1570
|
+
return filtered_results, total_tokens
|
|
1571
|
+
|
|
1572
|
+
async def get_document(self, document_id: str, bank_id: str) -> Optional[Dict[str, Any]]:
|
|
1573
|
+
"""
|
|
1574
|
+
Retrieve document metadata and statistics.
|
|
1575
|
+
|
|
1576
|
+
Args:
|
|
1577
|
+
document_id: Document ID to retrieve
|
|
1578
|
+
bank_id: bank ID that owns the document
|
|
1579
|
+
|
|
1580
|
+
Returns:
|
|
1581
|
+
Dictionary with document info or None if not found
|
|
1582
|
+
"""
|
|
1583
|
+
pool = await self._get_pool()
|
|
1584
|
+
async with acquire_with_retry(pool) as conn:
|
|
1585
|
+
doc = await conn.fetchrow(
|
|
1586
|
+
"""
|
|
1587
|
+
SELECT d.id, d.bank_id, d.original_text, d.content_hash,
|
|
1588
|
+
d.created_at, d.updated_at, COUNT(mu.id) as unit_count
|
|
1589
|
+
FROM documents d
|
|
1590
|
+
LEFT JOIN memory_units mu ON mu.document_id = d.id
|
|
1591
|
+
WHERE d.id = $1 AND d.bank_id = $2
|
|
1592
|
+
GROUP BY d.id, d.bank_id, d.original_text, d.content_hash, d.created_at, d.updated_at
|
|
1593
|
+
""",
|
|
1594
|
+
document_id, bank_id
|
|
1595
|
+
)
|
|
1596
|
+
|
|
1597
|
+
if not doc:
|
|
1598
|
+
return None
|
|
1599
|
+
|
|
1600
|
+
return {
|
|
1601
|
+
"id": doc["id"],
|
|
1602
|
+
"bank_id": doc["bank_id"],
|
|
1603
|
+
"original_text": doc["original_text"],
|
|
1604
|
+
"content_hash": doc["content_hash"],
|
|
1605
|
+
"memory_unit_count": doc["unit_count"],
|
|
1606
|
+
"created_at": doc["created_at"],
|
|
1607
|
+
"updated_at": doc["updated_at"]
|
|
1608
|
+
}
|
|
1609
|
+
|
|
1610
|
+
async def delete_document(self, document_id: str, bank_id: str) -> Dict[str, int]:
|
|
1611
|
+
"""
|
|
1612
|
+
Delete a document and all its associated memory units and links.
|
|
1613
|
+
|
|
1614
|
+
Args:
|
|
1615
|
+
document_id: Document ID to delete
|
|
1616
|
+
bank_id: bank ID that owns the document
|
|
1617
|
+
|
|
1618
|
+
Returns:
|
|
1619
|
+
Dictionary with counts of deleted items
|
|
1620
|
+
"""
|
|
1621
|
+
pool = await self._get_pool()
|
|
1622
|
+
async with acquire_with_retry(pool) as conn:
|
|
1623
|
+
async with conn.transaction():
|
|
1624
|
+
# Count units before deletion
|
|
1625
|
+
units_count = await conn.fetchval(
|
|
1626
|
+
"SELECT COUNT(*) FROM memory_units WHERE document_id = $1",
|
|
1627
|
+
document_id
|
|
1628
|
+
)
|
|
1629
|
+
|
|
1630
|
+
# Delete document (cascades to memory_units and all their links)
|
|
1631
|
+
deleted = await conn.fetchval(
|
|
1632
|
+
"DELETE FROM documents WHERE id = $1 AND bank_id = $2 RETURNING id",
|
|
1633
|
+
document_id, bank_id
|
|
1634
|
+
)
|
|
1635
|
+
|
|
1636
|
+
return {
|
|
1637
|
+
"document_deleted": 1 if deleted else 0,
|
|
1638
|
+
"memory_units_deleted": units_count if deleted else 0
|
|
1639
|
+
}
|
|
1640
|
+
|
|
1641
|
+
async def delete_memory_unit(self, unit_id: str) -> Dict[str, Any]:
|
|
1642
|
+
"""
|
|
1643
|
+
Delete a single memory unit and all its associated links.
|
|
1644
|
+
|
|
1645
|
+
Due to CASCADE DELETE constraints, this will automatically delete:
|
|
1646
|
+
- All links from this unit (memory_links where from_unit_id = unit_id)
|
|
1647
|
+
- All links to this unit (memory_links where to_unit_id = unit_id)
|
|
1648
|
+
- All entity associations (unit_entities where unit_id = unit_id)
|
|
1649
|
+
|
|
1650
|
+
Args:
|
|
1651
|
+
unit_id: UUID of the memory unit to delete
|
|
1652
|
+
|
|
1653
|
+
Returns:
|
|
1654
|
+
Dictionary with deletion result
|
|
1655
|
+
"""
|
|
1656
|
+
pool = await self._get_pool()
|
|
1657
|
+
async with acquire_with_retry(pool) as conn:
|
|
1658
|
+
async with conn.transaction():
|
|
1659
|
+
# Delete the memory unit (cascades to links and associations)
|
|
1660
|
+
deleted = await conn.fetchval(
|
|
1661
|
+
"DELETE FROM memory_units WHERE id = $1 RETURNING id",
|
|
1662
|
+
unit_id
|
|
1663
|
+
)
|
|
1664
|
+
|
|
1665
|
+
return {
|
|
1666
|
+
"success": deleted is not None,
|
|
1667
|
+
"unit_id": str(deleted) if deleted else None,
|
|
1668
|
+
"message": "Memory unit and all its links deleted successfully" if deleted else "Memory unit not found"
|
|
1669
|
+
}
|
|
1670
|
+
|
|
1671
|
+
async def delete_bank(self, bank_id: str, fact_type: Optional[str] = None) -> Dict[str, int]:
|
|
1672
|
+
"""
|
|
1673
|
+
Delete all data for a specific agent (multi-tenant cleanup).
|
|
1674
|
+
|
|
1675
|
+
This is much more efficient than dropping all tables and allows
|
|
1676
|
+
multiple agents to coexist in the same database.
|
|
1677
|
+
|
|
1678
|
+
Deletes (with CASCADE):
|
|
1679
|
+
- All memory units for this bank (optionally filtered by fact_type)
|
|
1680
|
+
- All entities for this bank (if deleting all memory units)
|
|
1681
|
+
- All associated links, unit-entity associations, and co-occurrences
|
|
1682
|
+
|
|
1683
|
+
Args:
|
|
1684
|
+
bank_id: bank ID to delete
|
|
1685
|
+
fact_type: Optional fact type filter (world, bank, opinion). If provided, only deletes memories of that type.
|
|
1686
|
+
|
|
1687
|
+
Returns:
|
|
1688
|
+
Dictionary with counts of deleted items
|
|
1689
|
+
"""
|
|
1690
|
+
pool = await self._get_pool()
|
|
1691
|
+
async with acquire_with_retry(pool) as conn:
|
|
1692
|
+
async with conn.transaction():
|
|
1693
|
+
try:
|
|
1694
|
+
if fact_type:
|
|
1695
|
+
# Delete only memories of a specific fact type
|
|
1696
|
+
units_count = await conn.fetchval(
|
|
1697
|
+
"SELECT COUNT(*) FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
|
|
1698
|
+
bank_id, fact_type
|
|
1699
|
+
)
|
|
1700
|
+
await conn.execute(
|
|
1701
|
+
"DELETE FROM memory_units WHERE bank_id = $1 AND fact_type = $2",
|
|
1702
|
+
bank_id, fact_type
|
|
1703
|
+
)
|
|
1704
|
+
|
|
1705
|
+
# Note: We don't delete entities when fact_type is specified,
|
|
1706
|
+
# as they may be referenced by other memory units
|
|
1707
|
+
return {
|
|
1708
|
+
"memory_units_deleted": units_count,
|
|
1709
|
+
"entities_deleted": 0
|
|
1710
|
+
}
|
|
1711
|
+
else:
|
|
1712
|
+
# Delete all data for the bank
|
|
1713
|
+
units_count = await conn.fetchval("SELECT COUNT(*) FROM memory_units WHERE bank_id = $1", bank_id)
|
|
1714
|
+
entities_count = await conn.fetchval("SELECT COUNT(*) FROM entities WHERE bank_id = $1", bank_id)
|
|
1715
|
+
documents_count = await conn.fetchval("SELECT COUNT(*) FROM documents WHERE bank_id = $1", bank_id)
|
|
1716
|
+
|
|
1717
|
+
# Delete documents (cascades to chunks)
|
|
1718
|
+
await conn.execute("DELETE FROM documents WHERE bank_id = $1", bank_id)
|
|
1719
|
+
|
|
1720
|
+
# Delete memory units (cascades to unit_entities, memory_links)
|
|
1721
|
+
await conn.execute("DELETE FROM memory_units WHERE bank_id = $1", bank_id)
|
|
1722
|
+
|
|
1723
|
+
# Delete entities (cascades to unit_entities, entity_cooccurrences, memory_links with entity_id)
|
|
1724
|
+
await conn.execute("DELETE FROM entities WHERE bank_id = $1", bank_id)
|
|
1725
|
+
|
|
1726
|
+
return {
|
|
1727
|
+
"memory_units_deleted": units_count,
|
|
1728
|
+
"entities_deleted": entities_count,
|
|
1729
|
+
"documents_deleted": documents_count
|
|
1730
|
+
}
|
|
1731
|
+
|
|
1732
|
+
except Exception as e:
|
|
1733
|
+
raise Exception(f"Failed to delete agent data: {str(e)}")
|
|
1734
|
+
|
|
1735
|
+
async def get_graph_data(self, bank_id: Optional[str] = None, fact_type: Optional[str] = None):
|
|
1736
|
+
"""
|
|
1737
|
+
Get graph data for visualization.
|
|
1738
|
+
|
|
1739
|
+
Args:
|
|
1740
|
+
bank_id: Filter by bank ID
|
|
1741
|
+
fact_type: Filter by fact type (world, bank, opinion)
|
|
1742
|
+
|
|
1743
|
+
Returns:
|
|
1744
|
+
Dict with nodes, edges, and table_rows
|
|
1745
|
+
"""
|
|
1746
|
+
pool = await self._get_pool()
|
|
1747
|
+
async with acquire_with_retry(pool) as conn:
|
|
1748
|
+
# Get memory units, optionally filtered by bank_id and fact_type
|
|
1749
|
+
query_conditions = []
|
|
1750
|
+
query_params = []
|
|
1751
|
+
param_count = 0
|
|
1752
|
+
|
|
1753
|
+
if bank_id:
|
|
1754
|
+
param_count += 1
|
|
1755
|
+
query_conditions.append(f"bank_id = ${param_count}")
|
|
1756
|
+
query_params.append(bank_id)
|
|
1757
|
+
|
|
1758
|
+
if fact_type:
|
|
1759
|
+
param_count += 1
|
|
1760
|
+
query_conditions.append(f"fact_type = ${param_count}")
|
|
1761
|
+
query_params.append(fact_type)
|
|
1762
|
+
|
|
1763
|
+
where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
|
|
1764
|
+
|
|
1765
|
+
units = await conn.fetch(f"""
|
|
1766
|
+
SELECT id, text, event_date, context, occurred_start, occurred_end, mentioned_at, document_id, chunk_id, fact_type
|
|
1767
|
+
FROM memory_units
|
|
1768
|
+
{where_clause}
|
|
1769
|
+
ORDER BY mentioned_at DESC NULLS LAST, event_date DESC
|
|
1770
|
+
LIMIT 1000
|
|
1771
|
+
""", *query_params)
|
|
1772
|
+
|
|
1773
|
+
# Get links, filtering to only include links between units of the selected agent
|
|
1774
|
+
unit_ids = [row['id'] for row in units]
|
|
1775
|
+
if unit_ids:
|
|
1776
|
+
links = await conn.fetch("""
|
|
1777
|
+
SELECT
|
|
1778
|
+
ml.from_unit_id,
|
|
1779
|
+
ml.to_unit_id,
|
|
1780
|
+
ml.link_type,
|
|
1781
|
+
ml.weight,
|
|
1782
|
+
e.canonical_name as entity_name
|
|
1783
|
+
FROM memory_links ml
|
|
1784
|
+
LEFT JOIN entities e ON ml.entity_id = e.id
|
|
1785
|
+
WHERE ml.from_unit_id = ANY($1::uuid[]) AND ml.to_unit_id = ANY($1::uuid[])
|
|
1786
|
+
ORDER BY ml.link_type, ml.weight DESC
|
|
1787
|
+
""", unit_ids)
|
|
1788
|
+
else:
|
|
1789
|
+
links = []
|
|
1790
|
+
|
|
1791
|
+
# Get entity information
|
|
1792
|
+
unit_entities = await conn.fetch("""
|
|
1793
|
+
SELECT ue.unit_id, e.canonical_name
|
|
1794
|
+
FROM unit_entities ue
|
|
1795
|
+
JOIN entities e ON ue.entity_id = e.id
|
|
1796
|
+
ORDER BY ue.unit_id
|
|
1797
|
+
""")
|
|
1798
|
+
|
|
1799
|
+
# Build entity mapping
|
|
1800
|
+
entity_map = {}
|
|
1801
|
+
for row in unit_entities:
|
|
1802
|
+
unit_id = row['unit_id']
|
|
1803
|
+
entity_name = row['canonical_name']
|
|
1804
|
+
if unit_id not in entity_map:
|
|
1805
|
+
entity_map[unit_id] = []
|
|
1806
|
+
entity_map[unit_id].append(entity_name)
|
|
1807
|
+
|
|
1808
|
+
# Build nodes
|
|
1809
|
+
nodes = []
|
|
1810
|
+
for row in units:
|
|
1811
|
+
unit_id = row['id']
|
|
1812
|
+
text = row['text']
|
|
1813
|
+
event_date = row['event_date']
|
|
1814
|
+
context = row['context']
|
|
1815
|
+
|
|
1816
|
+
entities = entity_map.get(unit_id, [])
|
|
1817
|
+
entity_count = len(entities)
|
|
1818
|
+
|
|
1819
|
+
# Color by entity count
|
|
1820
|
+
if entity_count == 0:
|
|
1821
|
+
color = "#e0e0e0"
|
|
1822
|
+
elif entity_count == 1:
|
|
1823
|
+
color = "#90caf9"
|
|
1824
|
+
else:
|
|
1825
|
+
color = "#42a5f5"
|
|
1826
|
+
|
|
1827
|
+
nodes.append({
|
|
1828
|
+
"data": {
|
|
1829
|
+
"id": str(unit_id),
|
|
1830
|
+
"label": f"{text[:30]}..." if len(text) > 30 else text,
|
|
1831
|
+
"text": text,
|
|
1832
|
+
"date": event_date.isoformat() if event_date else "",
|
|
1833
|
+
"context": context if context else "",
|
|
1834
|
+
"entities": ", ".join(entities) if entities else "None",
|
|
1835
|
+
"color": color
|
|
1836
|
+
}
|
|
1837
|
+
})
|
|
1838
|
+
|
|
1839
|
+
# Build edges
|
|
1840
|
+
edges = []
|
|
1841
|
+
for row in links:
|
|
1842
|
+
from_id = str(row['from_unit_id'])
|
|
1843
|
+
to_id = str(row['to_unit_id'])
|
|
1844
|
+
link_type = row['link_type']
|
|
1845
|
+
weight = row['weight']
|
|
1846
|
+
entity_name = row['entity_name']
|
|
1847
|
+
|
|
1848
|
+
# Color by link type
|
|
1849
|
+
if link_type == 'temporal':
|
|
1850
|
+
color = "#00bcd4"
|
|
1851
|
+
line_style = "dashed"
|
|
1852
|
+
elif link_type == 'semantic':
|
|
1853
|
+
color = "#ff69b4"
|
|
1854
|
+
line_style = "solid"
|
|
1855
|
+
elif link_type == 'entity':
|
|
1856
|
+
color = "#ffd700"
|
|
1857
|
+
line_style = "solid"
|
|
1858
|
+
else:
|
|
1859
|
+
color = "#999999"
|
|
1860
|
+
line_style = "solid"
|
|
1861
|
+
|
|
1862
|
+
edges.append({
|
|
1863
|
+
"data": {
|
|
1864
|
+
"id": f"{from_id}-{to_id}-{link_type}",
|
|
1865
|
+
"source": from_id,
|
|
1866
|
+
"target": to_id,
|
|
1867
|
+
"linkType": link_type,
|
|
1868
|
+
"weight": weight,
|
|
1869
|
+
"entityName": entity_name if entity_name else "",
|
|
1870
|
+
"color": color,
|
|
1871
|
+
"lineStyle": line_style
|
|
1872
|
+
}
|
|
1873
|
+
})
|
|
1874
|
+
|
|
1875
|
+
# Build table rows
|
|
1876
|
+
table_rows = []
|
|
1877
|
+
for row in units:
|
|
1878
|
+
unit_id = row['id']
|
|
1879
|
+
entities = entity_map.get(unit_id, [])
|
|
1880
|
+
|
|
1881
|
+
table_rows.append({
|
|
1882
|
+
"id": str(unit_id),
|
|
1883
|
+
"text": row['text'],
|
|
1884
|
+
"context": row['context'] if row['context'] else "N/A",
|
|
1885
|
+
"occurred_start": row['occurred_start'].isoformat() if row['occurred_start'] else None,
|
|
1886
|
+
"occurred_end": row['occurred_end'].isoformat() if row['occurred_end'] else None,
|
|
1887
|
+
"mentioned_at": row['mentioned_at'].isoformat() if row['mentioned_at'] else None,
|
|
1888
|
+
"date": row['event_date'].strftime("%Y-%m-%d %H:%M") if row['event_date'] else "N/A", # Deprecated, kept for backwards compatibility
|
|
1889
|
+
"entities": ", ".join(entities) if entities else "None",
|
|
1890
|
+
"document_id": row['document_id'],
|
|
1891
|
+
"chunk_id": row['chunk_id'] if row['chunk_id'] else None,
|
|
1892
|
+
"fact_type": row['fact_type']
|
|
1893
|
+
})
|
|
1894
|
+
|
|
1895
|
+
return {
|
|
1896
|
+
"nodes": nodes,
|
|
1897
|
+
"edges": edges,
|
|
1898
|
+
"table_rows": table_rows,
|
|
1899
|
+
"total_units": len(units)
|
|
1900
|
+
}
|
|
1901
|
+
|
|
1902
|
+
async def list_memory_units(
|
|
1903
|
+
self,
|
|
1904
|
+
bank_id: Optional[str] = None,
|
|
1905
|
+
fact_type: Optional[str] = None,
|
|
1906
|
+
search_query: Optional[str] = None,
|
|
1907
|
+
limit: int = 100,
|
|
1908
|
+
offset: int = 0
|
|
1909
|
+
):
|
|
1910
|
+
"""
|
|
1911
|
+
List memory units for table view with optional full-text search.
|
|
1912
|
+
|
|
1913
|
+
Args:
|
|
1914
|
+
bank_id: Filter by bank ID
|
|
1915
|
+
fact_type: Filter by fact type (world, bank, opinion)
|
|
1916
|
+
search_query: Full-text search query (searches text and context fields)
|
|
1917
|
+
limit: Maximum number of results to return
|
|
1918
|
+
offset: Offset for pagination
|
|
1919
|
+
|
|
1920
|
+
Returns:
|
|
1921
|
+
Dict with items (list of memory units) and total count
|
|
1922
|
+
"""
|
|
1923
|
+
pool = await self._get_pool()
|
|
1924
|
+
async with acquire_with_retry(pool) as conn:
|
|
1925
|
+
# Build query conditions
|
|
1926
|
+
query_conditions = []
|
|
1927
|
+
query_params = []
|
|
1928
|
+
param_count = 0
|
|
1929
|
+
|
|
1930
|
+
if bank_id:
|
|
1931
|
+
param_count += 1
|
|
1932
|
+
query_conditions.append(f"bank_id = ${param_count}")
|
|
1933
|
+
query_params.append(bank_id)
|
|
1934
|
+
|
|
1935
|
+
if fact_type:
|
|
1936
|
+
param_count += 1
|
|
1937
|
+
query_conditions.append(f"fact_type = ${param_count}")
|
|
1938
|
+
query_params.append(fact_type)
|
|
1939
|
+
|
|
1940
|
+
if search_query:
|
|
1941
|
+
# Full-text search on text and context fields using ILIKE
|
|
1942
|
+
param_count += 1
|
|
1943
|
+
query_conditions.append(f"(text ILIKE ${param_count} OR context ILIKE ${param_count})")
|
|
1944
|
+
query_params.append(f"%{search_query}%")
|
|
1945
|
+
|
|
1946
|
+
where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
|
|
1947
|
+
|
|
1948
|
+
# Get total count
|
|
1949
|
+
count_query = f"""
|
|
1950
|
+
SELECT COUNT(*) as total
|
|
1951
|
+
FROM memory_units
|
|
1952
|
+
{where_clause}
|
|
1953
|
+
"""
|
|
1954
|
+
count_result = await conn.fetchrow(count_query, *query_params)
|
|
1955
|
+
total = count_result['total']
|
|
1956
|
+
|
|
1957
|
+
# Get units with limit and offset
|
|
1958
|
+
param_count += 1
|
|
1959
|
+
limit_param = f"${param_count}"
|
|
1960
|
+
query_params.append(limit)
|
|
1961
|
+
|
|
1962
|
+
param_count += 1
|
|
1963
|
+
offset_param = f"${param_count}"
|
|
1964
|
+
query_params.append(offset)
|
|
1965
|
+
|
|
1966
|
+
units = await conn.fetch(f"""
|
|
1967
|
+
SELECT id, text, event_date, context, fact_type, mentioned_at, occurred_start, occurred_end, chunk_id
|
|
1968
|
+
FROM memory_units
|
|
1969
|
+
{where_clause}
|
|
1970
|
+
ORDER BY mentioned_at DESC NULLS LAST, created_at DESC
|
|
1971
|
+
LIMIT {limit_param} OFFSET {offset_param}
|
|
1972
|
+
""", *query_params)
|
|
1973
|
+
|
|
1974
|
+
# Get entity information for these units
|
|
1975
|
+
if units:
|
|
1976
|
+
unit_ids = [row['id'] for row in units]
|
|
1977
|
+
unit_entities = await conn.fetch("""
|
|
1978
|
+
SELECT ue.unit_id, e.canonical_name
|
|
1979
|
+
FROM unit_entities ue
|
|
1980
|
+
JOIN entities e ON ue.entity_id = e.id
|
|
1981
|
+
WHERE ue.unit_id = ANY($1::uuid[])
|
|
1982
|
+
ORDER BY ue.unit_id
|
|
1983
|
+
""", unit_ids)
|
|
1984
|
+
else:
|
|
1985
|
+
unit_entities = []
|
|
1986
|
+
|
|
1987
|
+
# Build entity mapping
|
|
1988
|
+
entity_map = {}
|
|
1989
|
+
for row in unit_entities:
|
|
1990
|
+
unit_id = row['unit_id']
|
|
1991
|
+
entity_name = row['canonical_name']
|
|
1992
|
+
if unit_id not in entity_map:
|
|
1993
|
+
entity_map[unit_id] = []
|
|
1994
|
+
entity_map[unit_id].append(entity_name)
|
|
1995
|
+
|
|
1996
|
+
# Build result items
|
|
1997
|
+
items = []
|
|
1998
|
+
for row in units:
|
|
1999
|
+
unit_id = row['id']
|
|
2000
|
+
entities = entity_map.get(unit_id, [])
|
|
2001
|
+
|
|
2002
|
+
items.append({
|
|
2003
|
+
"id": str(unit_id),
|
|
2004
|
+
"text": row['text'],
|
|
2005
|
+
"context": row['context'] if row['context'] else "",
|
|
2006
|
+
"date": row['event_date'].isoformat() if row['event_date'] else "",
|
|
2007
|
+
"fact_type": row['fact_type'],
|
|
2008
|
+
"mentioned_at": row['mentioned_at'].isoformat() if row['mentioned_at'] else None,
|
|
2009
|
+
"occurred_start": row['occurred_start'].isoformat() if row['occurred_start'] else None,
|
|
2010
|
+
"occurred_end": row['occurred_end'].isoformat() if row['occurred_end'] else None,
|
|
2011
|
+
"entities": ", ".join(entities) if entities else "",
|
|
2012
|
+
"chunk_id": row['chunk_id'] if row['chunk_id'] else None
|
|
2013
|
+
})
|
|
2014
|
+
|
|
2015
|
+
return {
|
|
2016
|
+
"items": items,
|
|
2017
|
+
"total": total,
|
|
2018
|
+
"limit": limit,
|
|
2019
|
+
"offset": offset
|
|
2020
|
+
}
|
|
2021
|
+
|
|
2022
|
+
async def list_documents(
|
|
2023
|
+
self,
|
|
2024
|
+
bank_id: str,
|
|
2025
|
+
search_query: Optional[str] = None,
|
|
2026
|
+
limit: int = 100,
|
|
2027
|
+
offset: int = 0
|
|
2028
|
+
):
|
|
2029
|
+
"""
|
|
2030
|
+
List documents with optional search and pagination.
|
|
2031
|
+
|
|
2032
|
+
Args:
|
|
2033
|
+
bank_id: bank ID (required)
|
|
2034
|
+
search_query: Search in document ID
|
|
2035
|
+
limit: Maximum number of results
|
|
2036
|
+
offset: Offset for pagination
|
|
2037
|
+
|
|
2038
|
+
Returns:
|
|
2039
|
+
Dict with items (list of documents without original_text) and total count
|
|
2040
|
+
"""
|
|
2041
|
+
pool = await self._get_pool()
|
|
2042
|
+
async with acquire_with_retry(pool) as conn:
|
|
2043
|
+
# Build query conditions
|
|
2044
|
+
query_conditions = []
|
|
2045
|
+
query_params = []
|
|
2046
|
+
param_count = 0
|
|
2047
|
+
|
|
2048
|
+
param_count += 1
|
|
2049
|
+
query_conditions.append(f"bank_id = ${param_count}")
|
|
2050
|
+
query_params.append(bank_id)
|
|
2051
|
+
|
|
2052
|
+
if search_query:
|
|
2053
|
+
# Search in document ID
|
|
2054
|
+
param_count += 1
|
|
2055
|
+
query_conditions.append(f"id ILIKE ${param_count}")
|
|
2056
|
+
query_params.append(f"%{search_query}%")
|
|
2057
|
+
|
|
2058
|
+
where_clause = "WHERE " + " AND ".join(query_conditions) if query_conditions else ""
|
|
2059
|
+
|
|
2060
|
+
# Get total count
|
|
2061
|
+
count_query = f"""
|
|
2062
|
+
SELECT COUNT(*) as total
|
|
2063
|
+
FROM documents
|
|
2064
|
+
{where_clause}
|
|
2065
|
+
"""
|
|
2066
|
+
count_result = await conn.fetchrow(count_query, *query_params)
|
|
2067
|
+
total = count_result['total']
|
|
2068
|
+
|
|
2069
|
+
# Get documents with limit and offset (without original_text for performance)
|
|
2070
|
+
param_count += 1
|
|
2071
|
+
limit_param = f"${param_count}"
|
|
2072
|
+
query_params.append(limit)
|
|
2073
|
+
|
|
2074
|
+
param_count += 1
|
|
2075
|
+
offset_param = f"${param_count}"
|
|
2076
|
+
query_params.append(offset)
|
|
2077
|
+
|
|
2078
|
+
documents = await conn.fetch(f"""
|
|
2079
|
+
SELECT
|
|
2080
|
+
id,
|
|
2081
|
+
bank_id,
|
|
2082
|
+
content_hash,
|
|
2083
|
+
created_at,
|
|
2084
|
+
updated_at,
|
|
2085
|
+
LENGTH(original_text) as text_length,
|
|
2086
|
+
retain_params
|
|
2087
|
+
FROM documents
|
|
2088
|
+
{where_clause}
|
|
2089
|
+
ORDER BY created_at DESC
|
|
2090
|
+
LIMIT {limit_param} OFFSET {offset_param}
|
|
2091
|
+
""", *query_params)
|
|
2092
|
+
|
|
2093
|
+
# Get memory unit count for each document
|
|
2094
|
+
if documents:
|
|
2095
|
+
doc_ids = [(row['id'], row['bank_id']) for row in documents]
|
|
2096
|
+
|
|
2097
|
+
# Create placeholders for the query
|
|
2098
|
+
placeholders = []
|
|
2099
|
+
params_for_count = []
|
|
2100
|
+
for i, (doc_id, bank_id_val) in enumerate(doc_ids):
|
|
2101
|
+
idx_doc = i * 2 + 1
|
|
2102
|
+
idx_agent = i * 2 + 2
|
|
2103
|
+
placeholders.append(f"(document_id = ${idx_doc} AND bank_id = ${idx_agent})")
|
|
2104
|
+
params_for_count.extend([doc_id, bank_id_val])
|
|
2105
|
+
|
|
2106
|
+
where_clause_count = " OR ".join(placeholders)
|
|
2107
|
+
|
|
2108
|
+
unit_counts = await conn.fetch(f"""
|
|
2109
|
+
SELECT document_id, bank_id, COUNT(*) as unit_count
|
|
2110
|
+
FROM memory_units
|
|
2111
|
+
WHERE {where_clause_count}
|
|
2112
|
+
GROUP BY document_id, bank_id
|
|
2113
|
+
""", *params_for_count)
|
|
2114
|
+
else:
|
|
2115
|
+
unit_counts = []
|
|
2116
|
+
|
|
2117
|
+
# Build count mapping
|
|
2118
|
+
count_map = {(row['document_id'], row['bank_id']): row['unit_count'] for row in unit_counts}
|
|
2119
|
+
|
|
2120
|
+
# Build result items
|
|
2121
|
+
items = []
|
|
2122
|
+
for row in documents:
|
|
2123
|
+
doc_id = row['id']
|
|
2124
|
+
bank_id_val = row['bank_id']
|
|
2125
|
+
unit_count = count_map.get((doc_id, bank_id_val), 0)
|
|
2126
|
+
|
|
2127
|
+
items.append({
|
|
2128
|
+
"id": doc_id,
|
|
2129
|
+
"bank_id": bank_id_val,
|
|
2130
|
+
"content_hash": row['content_hash'],
|
|
2131
|
+
"created_at": row['created_at'].isoformat() if row['created_at'] else "",
|
|
2132
|
+
"updated_at": row['updated_at'].isoformat() if row['updated_at'] else "",
|
|
2133
|
+
"text_length": row['text_length'] or 0,
|
|
2134
|
+
"memory_unit_count": unit_count,
|
|
2135
|
+
"retain_params": row['retain_params'] if row['retain_params'] else None
|
|
2136
|
+
})
|
|
2137
|
+
|
|
2138
|
+
return {
|
|
2139
|
+
"items": items,
|
|
2140
|
+
"total": total,
|
|
2141
|
+
"limit": limit,
|
|
2142
|
+
"offset": offset
|
|
2143
|
+
}
|
|
2144
|
+
|
|
2145
|
+
async def get_document(
    self,
    document_id: str,
    bank_id: str
):
    """
    Fetch a single document of a bank, including its original_text.

    Args:
        document_id: Document ID
        bank_id: bank ID

    Returns:
        Dict with id, bank_id, original_text, content_hash, ISO-formatted
        created_at / updated_at ("" when unset), memory_unit_count and
        retain_params (None when falsy); or None if no document matches.
    """
    def _iso_or_empty(value):
        # Missing timestamps are reported as "" rather than None.
        return value.isoformat() if value else ""

    pool = await self._get_pool()
    async with acquire_with_retry(pool) as conn:
        record = await conn.fetchrow("""
            SELECT
                id,
                bank_id,
                original_text,
                content_hash,
                created_at,
                updated_at,
                retain_params
            FROM documents
            WHERE id = $1 AND bank_id = $2
        """, document_id, bank_id)

        if not record:
            return None

        # Number of memory units extracted from this document.
        count_record = await conn.fetchrow("""
            SELECT COUNT(*) as unit_count
            FROM memory_units
            WHERE document_id = $1 AND bank_id = $2
        """, document_id, bank_id)

    memory_units_total = count_record['unit_count'] if count_record else 0

    return {
        "id": record['id'],
        "bank_id": record['bank_id'],
        "original_text": record['original_text'],
        "content_hash": record['content_hash'],
        "created_at": _iso_or_empty(record['created_at']),
        "updated_at": _iso_or_empty(record['updated_at']),
        "memory_unit_count": memory_units_total,
        "retain_params": record['retain_params'] or None,
    }
|
|
2195
|
+
|
|
2196
|
+
async def get_chunk(
    self,
    chunk_id: str
):
    """
    Look up one stored chunk by its ID.

    Args:
        chunk_id: Chunk ID (format: bank_id_document_id_chunk_index)

    Returns:
        Dict with chunk_id, document_id, bank_id, chunk_index, chunk_text and
        ISO-formatted created_at ("" when unset), or None if no chunk matches.
    """
    pool = await self._get_pool()
    async with acquire_with_retry(pool) as conn:
        row = await conn.fetchrow("""
            SELECT
                chunk_id,
                document_id,
                bank_id,
                chunk_index,
                chunk_text,
                created_at
            FROM chunks
            WHERE chunk_id = $1
        """, chunk_id)

    if not row:
        return None

    # Plain columns are copied verbatim; only the timestamp needs formatting.
    payload = {
        column: row[column]
        for column in ('chunk_id', 'document_id', 'bank_id', 'chunk_index', 'chunk_text')
    }
    created = row['created_at']
    payload["created_at"] = created.isoformat() if created else ""
    return payload
|
|
2234
|
+
|
|
2235
|
+
async def _evaluate_opinion_update_async(
    self,
    opinion_text: str,
    opinion_confidence: float,
    new_event_text: str,
    entity_name: str,
) -> Optional[Dict[str, Any]]:
    """
    Ask the LLM whether an existing opinion should change given a new event.

    Args:
        opinion_text: Current opinion text (includes reasons)
        opinion_confidence: Current confidence score (0.0-1.0); may be None
            when the stored value is NULL
        new_event_text: Text of the new event(s)
        entity_name: Name of the entity this opinion is about

    Returns:
        Dict with 'action' ('keep' | 'update', lower-cased; any other LLM answer
        is treated as 'keep'), 'reasoning', 'new_confidence' (clamped to
        [0.0, 1.0]) and 'new_text' (stripped revised text when action is
        'update', otherwise None); or None when nothing would change or the
        evaluation failed.
    """
    from pydantic import BaseModel, Field

    class OpinionEvaluation(BaseModel):
        """Evaluation of whether an opinion should be updated."""
        action: str = Field(description="Action to take: 'keep' (no change) or 'update' (modify opinion)")
        reasoning: str = Field(description="Brief explanation of why this action was chosen")
        new_confidence: float = Field(description="New confidence score (0.0-1.0). Can be higher, lower, or same as before.")
        new_opinion_text: Optional[str] = Field(
            default=None,
            description="If action is 'update', the revised opinion text that acknowledges the previous view. Otherwise None."
        )

    try:
        # A NULL stored confidence must not crash prompt formatting.
        confidence_display = (
            f"{opinion_confidence:.2f}" if opinion_confidence is not None else "unknown"
        )
        evaluation_prompt = (
            "You are evaluating whether an existing opinion should be updated based on new information.\n"
            "\n"
            f"ENTITY: {entity_name}\n"
            "\n"
            "EXISTING OPINION:\n"
            f"{opinion_text}\n"
            f"Current confidence: {confidence_display}\n"
            "\n"
            "NEW EVENT:\n"
            f"{new_event_text}\n"
            "\n"
            "Evaluate whether this new event:\n"
            "1. REINFORCES the opinion (increase confidence, keep text)\n"
            "2. WEAKENS the opinion (decrease confidence, keep text)\n"
            "3. CHANGES the opinion (update both text and confidence, noting \"Previously I thought X, but now Y...\")\n"
            "4. IRRELEVANT (keep everything as is)\n"
            "\n"
            "Guidelines:\n"
            "- Only suggest 'update' action if the new event genuinely contradicts or significantly modifies the opinion\n"
            "- If updating the text, acknowledge the previous opinion and explain the change\n"
            "- Confidence should reflect accumulated evidence (0.0 = no confidence, 1.0 = very confident)\n"
            "- Small changes in confidence are normal; large jumps should be rare"
        )

        result = await self._llm_config.call(
            messages=[
                {"role": "system", "content": "You evaluate and update opinions based on new information."},
                {"role": "user", "content": evaluation_prompt}
            ],
            response_format=OpinionEvaluation,
            scope="memory_evaluate_opinion",
            temperature=0.3  # Lower temperature for more consistent evaluation
        )

        # The model may answer 'Update', 'reinforce', etc.; anything other than
        # an explicit update is a confidence-only change at most.
        action = (result.action or "").strip().lower()
        if action not in ('keep', 'update'):
            action = 'keep'

        # Keep confidence inside its documented 0.0-1.0 range.
        new_confidence = min(1.0, max(0.0, float(result.new_confidence)))

        new_text = None
        if action == 'update':
            new_text = (result.new_opinion_text or "").strip() or None

        # Only return updates if something actually changed.
        confidence_unchanged = (
            opinion_confidence is not None
            and abs(new_confidence - opinion_confidence) < 0.01
        )
        if confidence_unchanged and new_text is None:
            return None

        return {
            'action': action,
            'reasoning': result.reasoning,
            'new_confidence': new_confidence,
            'new_text': new_text
        }

    except Exception as e:
        logger.warning(f"Failed to evaluate opinion update: {str(e)}")
        return None
|
|
2315
|
+
|
|
2316
|
+
async def _handle_form_opinion(self, task_dict: Dict[str, Any]):
|
|
2317
|
+
"""
|
|
2318
|
+
Handler for form opinion tasks.
|
|
2319
|
+
|
|
2320
|
+
Args:
|
|
2321
|
+
task_dict: Dict with keys: 'bank_id', 'answer_text', 'query'
|
|
2322
|
+
"""
|
|
2323
|
+
bank_id = task_dict['bank_id']
|
|
2324
|
+
answer_text = task_dict['answer_text']
|
|
2325
|
+
query = task_dict['query']
|
|
2326
|
+
|
|
2327
|
+
await self._extract_and_store_opinions_async(
|
|
2328
|
+
bank_id=bank_id,
|
|
2329
|
+
answer_text=answer_text,
|
|
2330
|
+
query=query
|
|
2331
|
+
)
|
|
2332
|
+
|
|
2333
|
+
async def _handle_reinforce_opinion(self, task_dict: Dict[str, Any]):
|
|
2334
|
+
"""
|
|
2335
|
+
Handler for reinforce opinion tasks.
|
|
2336
|
+
|
|
2337
|
+
Args:
|
|
2338
|
+
task_dict: Dict with keys: 'bank_id', 'created_unit_ids', 'unit_texts', 'unit_entities'
|
|
2339
|
+
"""
|
|
2340
|
+
bank_id = task_dict['bank_id']
|
|
2341
|
+
created_unit_ids = task_dict['created_unit_ids']
|
|
2342
|
+
unit_texts = task_dict['unit_texts']
|
|
2343
|
+
unit_entities = task_dict['unit_entities']
|
|
2344
|
+
|
|
2345
|
+
await self._reinforce_opinions_async(
|
|
2346
|
+
bank_id=bank_id,
|
|
2347
|
+
created_unit_ids=created_unit_ids,
|
|
2348
|
+
unit_texts=unit_texts,
|
|
2349
|
+
unit_entities=unit_entities
|
|
2350
|
+
)
|
|
2351
|
+
|
|
2352
|
+
async def _reinforce_opinions_async(
|
|
2353
|
+
self,
|
|
2354
|
+
bank_id: str,
|
|
2355
|
+
created_unit_ids: List[str],
|
|
2356
|
+
unit_texts: List[str],
|
|
2357
|
+
unit_entities: List[List[Dict[str, str]]],
|
|
2358
|
+
):
|
|
2359
|
+
"""
|
|
2360
|
+
Background task to reinforce opinions based on newly ingested events.
|
|
2361
|
+
|
|
2362
|
+
This runs asynchronously and does not block the put operation.
|
|
2363
|
+
|
|
2364
|
+
Args:
|
|
2365
|
+
bank_id: bank ID
|
|
2366
|
+
created_unit_ids: List of newly created memory unit IDs
|
|
2367
|
+
unit_texts: Texts of the newly created units
|
|
2368
|
+
unit_entities: Entities extracted from each unit
|
|
2369
|
+
"""
|
|
2370
|
+
try:
|
|
2371
|
+
# Extract all unique entity names from the new units
|
|
2372
|
+
entity_names = set()
|
|
2373
|
+
for entities_list in unit_entities:
|
|
2374
|
+
for entity in entities_list:
|
|
2375
|
+
# Handle both Entity objects and dicts
|
|
2376
|
+
if hasattr(entity, 'text'):
|
|
2377
|
+
entity_names.add(entity.text)
|
|
2378
|
+
elif isinstance(entity, dict):
|
|
2379
|
+
entity_names.add(entity['text'])
|
|
2380
|
+
|
|
2381
|
+
if not entity_names:
|
|
2382
|
+
return
|
|
2383
|
+
|
|
2384
|
+
|
|
2385
|
+
pool = await self._get_pool()
|
|
2386
|
+
async with acquire_with_retry(pool) as conn:
|
|
2387
|
+
# Find all opinions related to these entities
|
|
2388
|
+
opinions = await conn.fetch(
|
|
2389
|
+
"""
|
|
2390
|
+
SELECT DISTINCT mu.id, mu.text, mu.confidence_score, e.canonical_name
|
|
2391
|
+
FROM memory_units mu
|
|
2392
|
+
JOIN unit_entities ue ON mu.id = ue.unit_id
|
|
2393
|
+
JOIN entities e ON ue.entity_id = e.id
|
|
2394
|
+
WHERE mu.bank_id = $1
|
|
2395
|
+
AND mu.fact_type = 'opinion'
|
|
2396
|
+
AND e.canonical_name = ANY($2::text[])
|
|
2397
|
+
""",
|
|
2398
|
+
bank_id,
|
|
2399
|
+
list(entity_names)
|
|
2400
|
+
)
|
|
2401
|
+
|
|
2402
|
+
if not opinions:
|
|
2403
|
+
return
|
|
2404
|
+
|
|
2405
|
+
|
|
2406
|
+
# Use cached LLM config
|
|
2407
|
+
if self._llm_config is None:
|
|
2408
|
+
logger.error("[REINFORCE] LLM config not available, skipping opinion reinforcement")
|
|
2409
|
+
return
|
|
2410
|
+
|
|
2411
|
+
# Evaluate each opinion against the new events
|
|
2412
|
+
updates_to_apply = []
|
|
2413
|
+
for opinion in opinions:
|
|
2414
|
+
opinion_id = str(opinion['id'])
|
|
2415
|
+
opinion_text = opinion['text']
|
|
2416
|
+
opinion_confidence = opinion['confidence_score']
|
|
2417
|
+
entity_name = opinion['canonical_name']
|
|
2418
|
+
|
|
2419
|
+
# Find all new events mentioning this entity
|
|
2420
|
+
relevant_events = []
|
|
2421
|
+
for unit_text, entities_list in zip(unit_texts, unit_entities):
|
|
2422
|
+
if any(e['text'] == entity_name for e in entities_list):
|
|
2423
|
+
relevant_events.append(unit_text)
|
|
2424
|
+
|
|
2425
|
+
if not relevant_events:
|
|
2426
|
+
continue
|
|
2427
|
+
|
|
2428
|
+
# Combine all relevant events
|
|
2429
|
+
combined_events = "\n".join(relevant_events)
|
|
2430
|
+
|
|
2431
|
+
# Evaluate if opinion should be updated
|
|
2432
|
+
evaluation = await self._evaluate_opinion_update_async(
|
|
2433
|
+
opinion_text,
|
|
2434
|
+
opinion_confidence,
|
|
2435
|
+
combined_events,
|
|
2436
|
+
entity_name
|
|
2437
|
+
)
|
|
2438
|
+
|
|
2439
|
+
if evaluation:
|
|
2440
|
+
updates_to_apply.append({
|
|
2441
|
+
'opinion_id': opinion_id,
|
|
2442
|
+
'evaluation': evaluation
|
|
2443
|
+
})
|
|
2444
|
+
|
|
2445
|
+
# Apply all updates in a single transaction
|
|
2446
|
+
if updates_to_apply:
|
|
2447
|
+
async with conn.transaction():
|
|
2448
|
+
for update in updates_to_apply:
|
|
2449
|
+
opinion_id = update['opinion_id']
|
|
2450
|
+
evaluation = update['evaluation']
|
|
2451
|
+
|
|
2452
|
+
if evaluation['action'] == 'update' and evaluation['new_text']:
|
|
2453
|
+
# Update both text and confidence
|
|
2454
|
+
await conn.execute(
|
|
2455
|
+
"""
|
|
2456
|
+
UPDATE memory_units
|
|
2457
|
+
SET text = $1, confidence_score = $2, updated_at = NOW()
|
|
2458
|
+
WHERE id = $3
|
|
2459
|
+
""",
|
|
2460
|
+
evaluation['new_text'],
|
|
2461
|
+
evaluation['new_confidence'],
|
|
2462
|
+
uuid.UUID(opinion_id)
|
|
2463
|
+
)
|
|
2464
|
+
else:
|
|
2465
|
+
# Only update confidence
|
|
2466
|
+
await conn.execute(
|
|
2467
|
+
"""
|
|
2468
|
+
UPDATE memory_units
|
|
2469
|
+
SET confidence_score = $1, updated_at = NOW()
|
|
2470
|
+
WHERE id = $2
|
|
2471
|
+
""",
|
|
2472
|
+
evaluation['new_confidence'],
|
|
2473
|
+
uuid.UUID(opinion_id)
|
|
2474
|
+
)
|
|
2475
|
+
|
|
2476
|
+
else:
|
|
2477
|
+
pass # No opinions to update
|
|
2478
|
+
|
|
2479
|
+
except Exception as e:
|
|
2480
|
+
logger.error(f"[REINFORCE] Error during opinion reinforcement: {str(e)}")
|
|
2481
|
+
import traceback
|
|
2482
|
+
traceback.print_exc()
|
|
2483
|
+
|
|
2484
|
+
# ==================== bank profile Methods ====================
|
|
2485
|
+
|
|
2486
|
+
async def get_bank_profile(self, bank_id: str) -> "bank_utils.BankProfile":
    """
    Return the profile (name, personality traits, background) of a bank.

    Thin wrapper over ``bank_utils.get_bank_profile`` using this engine's
    connection pool. Per the previous documentation, that helper creates the
    bank with default values when it does not exist yet.

    Args:
        bank_id: bank identifier

    Returns:
        BankProfile with name, typed PersonalityTraits, and background
    """
    db_pool = await self._get_pool()
    profile = await bank_utils.get_bank_profile(db_pool, bank_id)
    return profile
|
|
2499
|
+
|
|
2500
|
+
async def update_bank_personality(
    self,
    bank_id: str,
    personality: Dict[str, float]
) -> None:
    """
    Persist new personality traits for a bank.

    Thin wrapper over ``bank_utils.update_bank_personality`` using this
    engine's connection pool.

    Args:
        bank_id: bank identifier
        personality: Dict with Big Five traits + bias_strength (all 0-1)
    """
    db_pool = await self._get_pool()
    await bank_utils.update_bank_personality(db_pool, bank_id, personality)
|
|
2514
|
+
|
|
2515
|
+
async def merge_bank_background(
    self,
    bank_id: str,
    new_info: str,
    update_personality: bool = True
) -> dict:
    """
    Merge new background information into a bank's existing background.

    Delegates to ``bank_utils.merge_bank_background`` with this engine's pool
    and cached LLM config. As previously documented, the merge normalizes the
    text to first person ("I"), resolves conflicts, and can optionally infer
    personality traits from the merged background.

    Args:
        bank_id: bank identifier
        new_info: New background information to add/merge
        update_personality: If True, infer Big Five traits from background (default: True)

    Returns:
        Dict with 'background' (str) and optionally 'personality' (dict) keys
    """
    db_pool = await self._get_pool()
    merged = await bank_utils.merge_bank_background(
        db_pool, self._llm_config, bank_id, new_info, update_personality
    )
    return merged
|
|
2538
|
+
|
|
2539
|
+
async def list_banks(self) -> list:
    """
    List all banks in the system.

    Thin wrapper over ``bank_utils.list_banks`` using this engine's pool.

    Returns:
        List of dicts with bank_id, name, personality, background, created_at, updated_at
    """
    db_pool = await self._get_pool()
    banks = await bank_utils.list_banks(db_pool)
    return banks
|
|
2548
|
+
|
|
2549
|
+
# ==================== Reflect Methods ====================
|
|
2550
|
+
|
|
2551
|
+
async def reflect_async(
    self,
    bank_id: str,
    query: str,
    budget: Budget = Budget.LOW,
    context: Optional[str] = None,
) -> ReflectResult:
    """
    Reflect and formulate an answer using bank identity, world facts, and opinions.

    This method:
    1. Recalls agent, world and opinion facts in a single multi-fact-type search
    2. Loads the bank profile (name, personality, background)
    3. Uses the LLM to formulate an answer from the facts and profile
    4. Submits a 'form_opinion' background task that extracts and stores new
       opinions from the answer
    5. Returns the plain text answer and the facts used

    Args:
        bank_id: bank identifier
        query: Question to answer
        budget: Budget level for memory exploration (low=100, mid=300, high=600 units)
        context: Additional context string to include in LLM prompt (not used in recall)

    Returns:
        ReflectResult containing:
        - text: The stripped LLM answer text
        - based_on: Dict with 'world', 'agent', and 'opinion' fact lists (MemoryFact objects)
        - new_opinions: Always empty here; opinions are extracted asynchronously

    Raises:
        ValueError: If no LLM config is available.
    """
    # Use cached LLM config
    if self._llm_config is None:
        raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")

    # Steps 1-3: Run multi-fact-type search (12-way retrieval: 4 methods × 3 fact types)
    search_result = await self.recall_async(
        bank_id=bank_id,
        query=query,
        budget=budget,
        max_tokens=4096,
        enable_trace=False,
        fact_type=['agent', 'world', 'opinion'],
        include_entities=True
    )

    all_results = search_result.results
    logger.info(f"[THINK] Search returned {len(all_results)} results")

    # Split results by fact type for structured response. Agent facts are
    # requested as 'agent' above (filtering on 'bank' alone dropped every one
    # of them); 'bank' is still accepted for results tagged with that name.
    agent_results = [r for r in all_results if r.fact_type in ('agent', 'bank')]
    world_results = [r for r in all_results if r.fact_type == 'world']
    opinion_results = [r for r in all_results if r.fact_type == 'opinion']

    logger.info(f"[THINK] Split results - agent: {len(agent_results)}, world: {len(world_results)}, opinion: {len(opinion_results)}")

    # Format facts for LLM
    agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
    world_facts_text = think_utils.format_facts_for_prompt(world_results)
    opinion_facts_text = think_utils.format_facts_for_prompt(opinion_results)

    logger.info(f"[THINK] Formatted facts - agent: {len(agent_facts_text)} chars, world: {len(world_facts_text)} chars, opinion: {len(opinion_facts_text)} chars")

    # Get bank profile (name, personality + background)
    profile = await self.get_bank_profile(bank_id)
    name = profile["name"]
    personality = profile["personality"]  # Typed as PersonalityTraits
    background = profile["background"]

    # Build the prompt
    prompt = think_utils.build_think_prompt(
        agent_facts_text=agent_facts_text,
        world_facts_text=world_facts_text,
        opinion_facts_text=opinion_facts_text,
        query=query,
        name=name,
        personality=personality,
        background=background,
        context=context,
    )

    logger.info(f"[THINK] Full prompt length: {len(prompt)} chars")

    system_message = think_utils.get_system_message(personality)

    answer_text = await self._llm_config.call(
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": prompt}
        ],
        scope="memory_think",
        temperature=0.9,
        max_tokens=1000
    )

    answer_text = answer_text.strip()

    # Submit form_opinion task for background processing
    await self._task_backend.submit_task({
        'type': 'form_opinion',
        'bank_id': bank_id,
        'answer_text': answer_text,
        'query': query
    })

    # Return response with facts split by type
    return ReflectResult(
        text=answer_text,
        based_on={
            "world": world_results,
            "agent": agent_results,
            "opinion": opinion_results
        },
        new_opinions=[]  # Opinions are being extracted asynchronously
    )
|
|
2665
|
+
|
|
2666
|
+
async def _extract_and_store_opinions_async(
    self,
    bank_id: str,
    answer_text: str,
    query: str
):
    """
    Background task: extract opinions from a reflect answer and retain them.

    Runs asynchronously and does not block the reflect response. Every
    extracted opinion is retained as an 'opinion' fact with the current UTC
    time as its event date. Failures are logged, never raised. Each opinion is
    stored independently, so one failing retain does not discard the
    remaining opinions.

    Args:
        bank_id: bank identifier
        answer_text: The generated answer text
        query: The original query
    """
    try:
        # Extract opinions from the answer
        new_opinions = await think_utils.extract_opinions_from_text(
            self._llm_config, text=answer_text, query=query
        )
    except Exception as e:
        logger.warning(f"[THINK] Failed to extract/store opinions: {str(e)}")
        return

    if not new_opinions:
        return

    from datetime import datetime, timezone
    # One timestamp shared by all opinions formed from this answer.
    current_time = datetime.now(timezone.utc)
    for opinion in new_opinions:
        try:
            await self.retain_async(
                bank_id=bank_id,
                content=opinion.opinion,
                context=f"formed during thinking about: {query}",
                event_date=current_time,
                fact_type_override='opinion',
                confidence_score=opinion.confidence
            )
        except Exception as e:
            # Keep storing the remaining opinions.
            logger.warning(f"[THINK] Failed to extract/store opinions: {str(e)}")
|
|
2704
|
+
|
|
2705
|
+
async def get_entity_observations(
    self,
    bank_id: str,
    entity_id: str,
    limit: int = 10
) -> List[EntityObservation]:
    """
    Return the most recent 'observation' memory units linked to an entity.

    Args:
        bank_id: bank identifier
        entity_id: Entity UUID (string form) to get observations for
        limit: Maximum number of observations to return

    Returns:
        List of EntityObservation objects ordered by mentioned_at descending;
        mentioned_at is ISO-formatted or None when unset.
    """
    pool = await self._get_pool()
    async with acquire_with_retry(pool) as conn:
        rows = await conn.fetch(
            """
            SELECT mu.text, mu.mentioned_at
            FROM memory_units mu
            JOIN unit_entities ue ON mu.id = ue.unit_id
            WHERE mu.bank_id = $1
              AND mu.fact_type = 'observation'
              AND ue.entity_id = $2
            ORDER BY mu.mentioned_at DESC
            LIMIT $3
            """,
            bank_id, uuid.UUID(entity_id), limit
        )

    return [
        EntityObservation(
            text=row['text'],
            mentioned_at=row['mentioned_at'].isoformat() if row['mentioned_at'] else None
        )
        for row in rows
    ]
|
|
2746
|
+
|
|
2747
|
+
async def list_entities(
    self,
    bank_id: str,
    limit: int = 100
) -> List[Dict[str, Any]]:
    """
    List a bank's entities, most-mentioned (then most recently seen) first.

    Args:
        bank_id: bank identifier
        limit: Maximum number of entities to return

    Returns:
        List of entity dicts with id (str), canonical_name, mention_count,
        first_seen / last_seen (ISO strings or None) and metadata (dict).
    """
    def _decode_metadata(raw):
        # metadata may come back as a dict, a JSON string, or NULL.
        if raw is None:
            return {}
        if not isinstance(raw, str):
            return raw
        import json
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {}

    def _iso_or_none(value):
        return value.isoformat() if value else None

    pool = await self._get_pool()
    async with acquire_with_retry(pool) as conn:
        rows = await conn.fetch(
            """
            SELECT id, canonical_name, mention_count, first_seen, last_seen, metadata
            FROM entities
            WHERE bank_id = $1
            ORDER BY mention_count DESC, last_seen DESC
            LIMIT $2
            """,
            bank_id, limit
        )

    result = []
    for record in rows:
        decoded = _decode_metadata(record['metadata'])
        result.append({
            'id': str(record['id']),
            'canonical_name': record['canonical_name'],
            'mention_count': record['mention_count'],
            'first_seen': _iso_or_none(record['first_seen']),
            'last_seen': _iso_or_none(record['last_seen']),
            'metadata': decoded
        })
    return result
|
|
2797
|
+
|
|
2798
|
+
async def get_entity_state(
    self,
    bank_id: str,
    entity_id: str,
    entity_name: str,
    limit: int = 10
) -> EntityState:
    """
    Build the current state (mental model) of an entity from its observations.

    Args:
        bank_id: bank identifier
        entity_id: Entity UUID
        entity_name: Canonical name of the entity
        limit: Maximum number of observations to include

    Returns:
        EntityState holding the entity's id, name and latest observations
    """
    latest = await self.get_entity_observations(bank_id, entity_id, limit)
    state = EntityState(
        entity_id=entity_id,
        canonical_name=entity_name,
        observations=latest
    )
    return state
|
|
2823
|
+
|
|
2824
|
+
async def regenerate_entity_observations(
    self,
    bank_id: str,
    entity_id: str,
    entity_name: str,
    version: str | None = None
) -> List[str]:
    """
    Rebuild the LLM-synthesized observations for one entity.

    Steps:
        1. If ``version`` is given, skip when the entity's current last_seen
           (ISO-formatted) no longer equals it.
        2. Load up to 50 of the most recent 'world'/'agent' facts linked to the
           entity; facts without occurred_start are ordered last.
        3. Use the LLM to synthesize observations (no personality).
        4. Embed the observations (outside the write transaction).
        5. In one transaction, delete the entity's old observations and insert
           the new ones, linking each to the entity.

    Args:
        bank_id: bank identifier
        entity_id: Entity UUID (string form; a uuid.UUID is also accepted)
        entity_name: Canonical name of the entity
        version: Entity's last_seen timestamp when task was created (for deduplication)

    Returns:
        List of created observation IDs (empty when skipped or nothing to store)
    """
    entity_uuid = entity_id if isinstance(entity_id, uuid.UUID) else uuid.UUID(str(entity_id))
    pool = await self._get_pool()

    # Step 1: Check version for deduplication
    if version:
        async with acquire_with_retry(pool) as conn:
            current_last_seen = await conn.fetchval(
                """
                SELECT last_seen
                FROM entities
                WHERE id = $1 AND bank_id = $2
                """,
                entity_uuid, bank_id
            )

        if current_last_seen and current_last_seen.isoformat() != version:
            return []

    # Step 2: Get facts mentioning this entity (exclude observations themselves).
    # NULLS LAST: PostgreSQL sorts NULLs first under DESC, which would let
    # undated facts push the most recent ones out of the LIMIT.
    async with acquire_with_retry(pool) as conn:
        rows = await conn.fetch(
            """
            SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
            FROM memory_units mu
            JOIN unit_entities ue ON mu.id = ue.unit_id
            WHERE mu.bank_id = $1
              AND ue.entity_id = $2
              AND mu.fact_type IN ('world', 'agent')
            ORDER BY mu.occurred_start DESC NULLS LAST
            LIMIT 50
            """,
            bank_id, entity_uuid
        )

    if not rows:
        return []

    # Convert to MemoryFact objects for the observation extraction
    facts = [
        MemoryFact(
            id=str(row['id']),
            text=row['text'],
            fact_type=row['fact_type'],
            context=row['context'],
            occurred_start=row['occurred_start'].isoformat() if row['occurred_start'] else None
        )
        for row in rows
    ]

    # Step 3: Extract observations using LLM (no personality)
    observations = await observation_utils.extract_observations_from_facts(
        self._llm_config,
        entity_name,
        facts
    )

    if not observations:
        return []

    # Step 4: Generate embeddings before opening the write transaction so the
    # deleted rows are not kept locked while embeddings are computed.
    embeddings = await embedding_utils.generate_embeddings_batch(
        self.embeddings, observations
    )

    # Step 5: Delete old observations and insert new ones in a transaction
    created_ids = []
    async with acquire_with_retry(pool) as conn:
        async with conn.transaction():
            # Delete old observations for this entity
            await conn.execute(
                """
                DELETE FROM memory_units
                WHERE id IN (
                    SELECT mu.id
                    FROM memory_units mu
                    JOIN unit_entities ue ON mu.id = ue.unit_id
                    WHERE mu.bank_id = $1
                      AND mu.fact_type = 'observation'
                      AND ue.entity_id = $2
                )
                """,
                bank_id, entity_uuid
            )

            # Insert new observations
            current_time = utcnow()
            for obs_text, embedding in zip(observations, embeddings):
                result = await conn.fetchrow(
                    """
                    INSERT INTO memory_units (
                        bank_id, text, embedding, context, event_date,
                        occurred_start, occurred_end, mentioned_at,
                        fact_type, access_count
                    )
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
                    RETURNING id
                    """,
                    bank_id,
                    obs_text,
                    str(embedding),
                    f"observation about {entity_name}",
                    current_time,
                    current_time,
                    current_time,
                    current_time
                )
                obs_id = str(result['id'])
                created_ids.append(obs_id)

                # Link observation to entity
                await conn.execute(
                    """
                    INSERT INTO unit_entities (unit_id, entity_id)
                    VALUES ($1, $2)
                    """,
                    uuid.UUID(obs_id), entity_uuid
                )

    # Single consolidated log line
    logger.info(f"[OBSERVATIONS] {entity_name}: {len(facts)} facts -> {len(created_ids)} observations")
    return created_ids
|
|
2969
|
+
|
|
2970
|
+
async def _regenerate_observations_sync(
    self,
    bank_id: str,
    entity_ids: List[str],
    min_facts: int = 5
) -> None:
    """
    Regenerate observations for entities synchronously (called during retain).

    For each entity ID this looks up the entity's canonical name within
    ``bank_id``, counts the entity's rows in ``unit_entities``, and calls
    ``regenerate_entity_observations`` only when that count reaches
    ``min_facts``. Entities not (yet) present in the bank are skipped.
    Failures are logged per entity (with traceback) so one bad entity does
    not stop processing of the rest.

    A pooled connection is held only for the two lookup queries of each
    entity and is released before regeneration runs, so the connection is
    not kept idle while ``regenerate_entity_observations`` does its work.

    Args:
        bank_id: Bank identifier
        entity_ids: List of entity IDs to process (str or UUID)
        min_facts: Minimum linked units required to regenerate observations
    """
    if not bank_id or not entity_ids:
        return

    pool = await self._get_pool()
    for entity_id in entity_ids:
        try:
            entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id

            # Short-lived connection: only the lookup queries run on it, and it
            # is returned to the pool before the (possibly slow) regeneration.
            async with pool.acquire() as conn:
                # Check if entity exists in this bank
                entity_row = await conn.fetchrow(
                    "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
                    entity_uuid, bank_id
                )

                if not entity_row:
                    logger.debug(f"[OBSERVATIONS] Entity {entity_id} not yet in bank {bank_id}, skipping")
                    continue

                # Count units linked to this entity.
                # NOTE(review): observation units are also linked via
                # unit_entities, so they appear to be included in this count —
                # confirm whether they should count toward min_facts.
                fact_count = await conn.fetchval(
                    "SELECT COUNT(*) FROM unit_entities WHERE entity_id = $1",
                    entity_uuid
                ) or 0

            entity_name = entity_row['canonical_name']

            # Only regenerate if entity has enough facts
            if fact_count >= min_facts:
                await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None)
            else:
                logger.debug(f"[OBSERVATIONS] Skipping {entity_name} ({fact_count} facts < {min_facts} threshold)")

        except Exception as e:
            # Best-effort per entity: record the failure with traceback, keep going.
            logger.exception(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
            continue
|
|
3020
|
+
|
|
3021
|
+
async def _handle_regenerate_observations(self, task_dict: Dict[str, Any]):
    """
    Handler for regenerate_observations tasks.

    Args:
        task_dict: Dict with 'bank_id' and either:
            - 'entity_ids' (list): Process multiple entities. Optional
              'min_facts' (default 5) is the per-entity threshold.
            - 'entity_id', 'entity_name': Process single entity (legacy).
              Optional 'version' is passed through unchanged.

    Errors are logged (with traceback) rather than raised to the caller.
    """
    try:
        bank_id = task_dict.get('bank_id')

        # New format: multiple entity_ids
        if 'entity_ids' in task_dict:
            entity_ids = task_dict.get('entity_ids', [])
            min_facts = task_dict.get('min_facts', 5)

            if not bank_id or not entity_ids:
                logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
                return

            # Reuse the retain-path routine instead of duplicating its
            # lookup / threshold / regenerate loop here.
            await self._regenerate_observations_sync(bank_id, entity_ids, min_facts=min_facts)
            return

        # Legacy format: single entity
        entity_id = task_dict.get('entity_id')
        entity_name = task_dict.get('entity_name')
        version = task_dict.get('version')

        if not all([bank_id, entity_id, entity_name]):
            logger.error(f"[OBSERVATIONS] Missing required fields in task: {task_dict}")
            return

        await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version)

    except Exception as e:
        # Route the traceback through the logger instead of printing to stderr.
        logger.exception(f"[OBSERVATIONS] Error regenerating observations: {e}")
|
|
3095
|
+
|