hindsight-api 0.0.20-py3-none-any.whl → 0.1.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- hindsight_api/api/__init__.py +2 -4
- hindsight_api/api/http.py +28 -78
- hindsight_api/api/mcp.py +2 -1
- hindsight_api/cli.py +0 -1
- hindsight_api/engine/cross_encoder.py +6 -1
- hindsight_api/engine/embeddings.py +6 -1
- hindsight_api/engine/entity_resolver.py +56 -29
- hindsight_api/engine/llm_wrapper.py +97 -5
- hindsight_api/engine/memory_engine.py +264 -139
- hindsight_api/engine/response_models.py +15 -17
- hindsight_api/engine/retain/bank_utils.py +23 -33
- hindsight_api/engine/retain/entity_processing.py +5 -5
- hindsight_api/engine/retain/fact_extraction.py +85 -23
- hindsight_api/engine/retain/fact_storage.py +1 -1
- hindsight_api/engine/retain/link_creation.py +12 -6
- hindsight_api/engine/retain/link_utils.py +50 -56
- hindsight_api/engine/retain/observation_regeneration.py +264 -0
- hindsight_api/engine/retain/orchestrator.py +31 -44
- hindsight_api/engine/retain/types.py +14 -0
- hindsight_api/engine/search/retrieval.py +2 -2
- hindsight_api/engine/search/think_utils.py +59 -30
- hindsight_api/migrations.py +54 -32
- hindsight_api/models.py +1 -2
- hindsight_api/pg0.py +17 -36
- {hindsight_api-0.0.20.dist-info → hindsight_api-0.1.0.dist-info}/METADATA +2 -3
- hindsight_api-0.1.0.dist-info/RECORD +51 -0
- hindsight_api-0.0.20.dist-info/RECORD +0 -50
- {hindsight_api-0.0.20.dist-info → hindsight_api-0.1.0.dist-info}/WHEEL +0 -0
- {hindsight_api-0.0.20.dist-info → hindsight_api-0.1.0.dist-info}/entry_points.txt +0 -0
@@ -48,7 +48,7 @@ from .entity_resolver import EntityResolver
 from .retain import embedding_utils, bank_utils
 from .search import think_utils, observation_utils
 from .llm_wrapper import LLMConfig
-from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation
+from .response_models import RecallResult as RecallResultModel, ReflectResult, MemoryFact, EntityState, EntityObservation, VALID_RECALL_FACT_TYPES
 from .task_backend import TaskBackend, AsyncIOQueueBackend
 from .search.reranking import CrossEncoderReranker
 from ..pg0 import EmbeddedPostgres
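The notable change here is the new VALID_RECALL_FACT_TYPES export, which recall_async uses further down to reject unknown fact types up front. A minimal sketch of pre-checking input against the same constant (a hypothetical helper; the set's exact contents are not shown in this diff, though 'experience', 'world', and 'opinion' appear as valid values elsewhere in it):

    from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES

    def check_fact_types(fact_types):
        # Mirrors the early validation added to recall_async below.
        invalid = set(fact_types) - VALID_RECALL_FACT_TYPES
        if invalid:
            raise ValueError(f"Invalid fact type(s): {', '.join(sorted(invalid))}")

    check_fact_types(["world", "opinion"])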
@@ -110,12 +110,14 @@ class MemoryEngine:
         pool_min_size: int = 5,
         pool_max_size: int = 100,
         task_backend: Optional[TaskBackend] = None,
+        run_migrations: bool = True,
     ):
         """
         Initialize the temporal + semantic memory system.

         Args:
             db_url: PostgreSQL connection URL (postgresql://user:pass@host:port/dbname). Required.
+                Also supports pg0 URLs: "pg0" or "pg0://instance-name" or "pg0://instance-name:port"
             memory_llm_provider: LLM provider for memory operations: "openai", "groq", or "ollama". Required.
             memory_llm_api_key: API key for the LLM provider. Required.
             memory_llm_model: Model name to use for all memory operations (put/think/opinions). Required.
@@ -129,16 +131,38 @@ class MemoryEngine:
             pool_max_size: Maximum number of connections in the pool (default: 100)
                 Increase for parallel think/search operations (e.g., 200-300 for 100+ parallel thinks)
             task_backend: Custom task backend for async task execution. If not provided, uses AsyncIOQueueBackend
+            run_migrations: Whether to run database migrations during initialize(). Default: True
         """
         if not db_url:
             raise ValueError("Database url is required")
         # Track pg0 instance (if used)
         self._pg0: Optional[EmbeddedPostgres] = None
+        self._pg0_instance_name: Optional[str] = None

         # Initialize PostgreSQL connection URL
         # The actual URL will be set during initialize() after starting the server
-
-
+        # Supports: "pg0" (default instance), "pg0://instance-name" (named instance), or regular postgresql:// URL
+        if db_url == "pg0":
+            self._use_pg0 = True
+            self._pg0_instance_name = "hindsight"
+            self._pg0_port = None  # Use default port
+            self.db_url = None
+        elif db_url.startswith("pg0://"):
+            self._use_pg0 = True
+            # Parse instance name and optional port: pg0://instance-name or pg0://instance-name:port
+            url_part = db_url[6:]  # Remove "pg0://"
+            if ":" in url_part:
+                self._pg0_instance_name, port_str = url_part.rsplit(":", 1)
+                self._pg0_port = int(port_str)
+            else:
+                self._pg0_instance_name = url_part or "hindsight"
+                self._pg0_port = None  # Use default port
+            self.db_url = None
+        else:
+            self._use_pg0 = False
+            self._pg0_instance_name = None
+            self._pg0_port = None
+            self.db_url = db_url


         # Set default base URL if not provided
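Taken together with the docstring above, db_url now accepts three forms. A hedged construction sketch (the LLM keyword values are placeholders and other constructor arguments are omitted):

    from hindsight_api.engine.memory_engine import MemoryEngine

    llm_kwargs = dict(
        memory_llm_provider="openai",        # or "groq" / "ollama"
        memory_llm_api_key="sk-placeholder",
        memory_llm_model="my-model",
    )

    MemoryEngine(db_url="postgresql://user:pass@localhost:5432/hindsight", **llm_kwargs)  # regular URL
    MemoryEngine(db_url="pg0", **llm_kwargs)                    # embedded Postgres, default "hindsight" instance
    MemoryEngine(db_url="pg0://analytics", **llm_kwargs)        # named instance, default port
    MemoryEngine(db_url="pg0://analytics:55432", **llm_kwargs)  # named instance on an explicit port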
@@ -155,6 +179,7 @@ class MemoryEngine:
         self._initialized = False
         self._pool_min_size = pool_min_size
         self._pool_max_size = pool_max_size
+        self._run_migrations = run_migrations

         # Initialize entity resolver (will be created in initialize())
         self.entity_resolver = None
@@ -378,8 +403,16 @@ class MemoryEngine:
         async def start_pg0():
             """Start pg0 if configured."""
             if self._use_pg0:
-
-                self.
+                kwargs = {"name": self._pg0_instance_name}
+                if self._pg0_port is not None:
+                    kwargs["port"] = self._pg0_port
+                pg0 = EmbeddedPostgres(**kwargs)
+                # Check if pg0 is already running before we start it
+                was_already_running = await pg0.is_running()
+                self.db_url = await pg0.ensure_running()
+                # Only track pg0 (to stop later) if WE started it
+                if not was_already_running:
+                    self._pg0 = pg0

         def load_embeddings():
             """Load embedding model (CPU-bound)."""
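start_pg0 now records the EmbeddedPostgres handle only when this engine actually started the server, so shutdown will not stop an instance another process was already using. The same ensure-and-track pattern in isolation, using only the calls visible in this hunk (a sketch, not engine code):

    from hindsight_api.pg0 import EmbeddedPostgres

    async def ensure_pg0(name: str, port: int | None = None):
        kwargs = {"name": name}
        if port is not None:
            kwargs["port"] = port
        pg0 = EmbeddedPostgres(**kwargs)
        was_already_running = await pg0.is_running()
        db_url = await pg0.ensure_running()
        # Keep a handle for later shutdown only if this call started the server.
        owned = None if was_already_running else pg0
        return db_url, owned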
@@ -408,6 +441,12 @@ class MemoryEngine:
             pg0_task, embeddings_future, cross_encoder_future, query_analyzer_future
         )

+        # Run database migrations if enabled
+        if self._run_migrations:
+            from ..migrations import run_migrations
+            logger.info("Running database migrations...")
+            run_migrations(self.db_url)
+
         logger.info(f"Connecting to PostgreSQL at {self.db_url}")

         # Create connection pool
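Because migrations are now gated on the run_migrations flag, a deployment that constructs engines with run_migrations=False can apply schema changes out of band with the same helper the engine imports here. A hedged sketch (the connection URL is a placeholder):

    from hindsight_api.migrations import run_migrations

    # Run once per release, before starting engines built with run_migrations=False.
    run_migrations("postgresql://user:pass@localhost:5432/hindsight")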
@@ -869,7 +908,6 @@ class MemoryEngine:
             task_backend=self._task_backend,
             format_date_fn=self._format_readable_date,
             duplicate_checker_fn=self._find_duplicate_facts_batch,
-            regenerate_observations_fn=self._regenerate_observations_sync,
             bank_id=bank_id,
             contents_dicts=contents,
             document_id=document_id,
@@ -955,11 +993,19 @@ class MemoryEngine:
             - entities: Optional dict of entity states (if include_entities=True)
             - chunks: Optional dict of chunks (if include_chunks=True)
         """
+        # Validate fact types early
+        invalid_types = set(fact_type) - VALID_RECALL_FACT_TYPES
+        if invalid_types:
+            raise ValueError(
+                f"Invalid fact type(s): {', '.join(sorted(invalid_types))}. "
+                f"Must be one of: {', '.join(sorted(VALID_RECALL_FACT_TYPES))}"
+            )
+
         # Map budget enum to thinking_budget number
         budget_mapping = {
             Budget.LOW: 100,
             Budget.MID: 300,
-            Budget.HIGH:
+            Budget.HIGH: 1000
         }
         thinking_budget = budget_mapping[budget]

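Two behavioural points from this hunk: unknown fact types now fail fast with a ValueError instead of silently retrieving nothing, and Budget.HIGH maps to a concrete thinking budget (LOW=100, MID=300, HIGH=1000). A hedged call sketch using only keyword names that appear in this diff; other recall_async parameters are omitted:

    # Inside an async context, with an initialized MemoryEngine `engine`:
    try:
        result = await engine.recall_async(
            bank_id="my-bank",
            query="what did we decide about the launch?",
            fact_type=["world", "opinion"],  # checked against VALID_RECALL_FACT_TYPES
            include_entities=True,
        )
    except ValueError as exc:
        # Raised early when fact_type contains anything outside VALID_RECALL_FACT_TYPES.
        print(exc)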
@@ -1040,12 +1086,12 @@ class MemoryEngine:
         tracer.start()

         pool = await self._get_pool()
-
+        recall_start = time.time()

         # Buffer logs for clean output in concurrent scenarios
-
+        recall_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
         log_buffer = []
-        log_buffer.append(f"[
+        log_buffer.append(f"[RECALL {recall_id}] Query: '{query[:50]}...' (budget={thinking_budget}, max_tokens={max_tokens})")

         try:
             # Step 1: Generate query embedding (for semantic search)
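The recall path now buffers its per-request log lines and emits them with a single logger call, so output from concurrent recalls does not interleave. The same pattern in isolation (a generic sketch, not engine code):

    import logging
    import time

    logger = logging.getLogger(__name__)

    def handle_request(bank_id: str) -> None:
        request_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
        log_buffer = [f"[RECALL {request_id}] started"]
        # ... each step appends to log_buffer instead of logging immediately ...
        log_buffer.append(f"[RECALL {request_id}] finished")
        # One logger call per request keeps concurrent output readable.
        logger.info("\n" + "\n".join(log_buffer))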
@@ -1088,7 +1134,7 @@ class MemoryEngine:
             for idx, (ft_semantic, ft_bm25, ft_graph, ft_temporal, ft_timings, ft_temporal_constraint) in enumerate(all_retrievals):
                 # Log fact types in this retrieval batch
                 ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
-                logger.debug(f"[
+                logger.debug(f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(ft_semantic)}, bm25={len(ft_bm25)}, graph={len(ft_graph)}, temporal={len(ft_temporal) if ft_temporal else 0}")

                 semantic_results.extend(ft_semantic)
                 bm25_results.extend(ft_bm25)
@@ -1209,7 +1255,6 @@ class MemoryEngine:
             # Step 4: Rerank using cross-encoder (MergedCandidate -> ScoredResult)
             step_start = time.time()
             reranker_instance = self._cross_encoder_reranker
-            log_buffer.append(f"  [4] Using cross-encoder reranker")

             # Rerank using cross-encoder
             scored_results = reranker_instance.rerank(query, merged_candidates)
@@ -1334,12 +1379,7 @@ class MemoryEngine:
                 ft = sr.retrieval.fact_type
                 fact_type_counts[ft] = fact_type_counts.get(ft, 0) + 1

-            total_time = time.time() - search_start
             fact_type_summary = ", ".join([f"{ft}={count}" for ft, count in sorted(fact_type_counts.items())])
-            log_buffer.append(f"[SEARCH {search_id}] Complete: {len(top_scored)} results ({fact_type_summary}) ({total_tokens} tokens) in {total_time:.3f}s")
-
-            # Log all buffered logs at once
-            logger.info("\n" + "\n".join(log_buffer))

             # Convert ScoredResult to dicts with ISO datetime strings
             top_results_dicts = []
@@ -1401,11 +1441,12 @@ class MemoryEngine:
                     mentioned_at=result_dict.get("mentioned_at"),
                     document_id=result_dict.get("document_id"),
                     chunk_id=result_dict.get("chunk_id"),
-                    activation=result_dict.get("weight")  # Use final weight as activation
                 ))

             # Fetch entity observations if requested
             entities_dict = None
+            total_entity_tokens = 0
+            total_chunk_tokens = 0
             if include_entities and fact_entity_map:
                 # Collect unique entities in order of fact relevance (preserving order from top_scored)
                 # Use a list to maintain order, but track seen entities to avoid duplicates
@@ -1425,7 +1466,6 @@ class MemoryEngine:

                 # Fetch observations for each entity (respect token budget, in order)
                 entities_dict = {}
-                total_entity_tokens = 0
                 encoding = _get_tiktoken_encoding()

                 for entity_id, entity_name in entities_ordered:
@@ -1485,7 +1525,6 @@ class MemoryEngine:

                 # Apply token limit and build chunks_dict in the order of chunk_ids_ordered
                 chunks_dict = {}
-                total_chunk_tokens = 0
                 encoding = _get_tiktoken_encoding()

                 for chunk_id in chunk_ids_ordered:
@@ -1525,10 +1564,17 @@ class MemoryEngine:
             trace = tracer.finalize(top_results_dicts)
             trace_dict = trace.to_dict() if trace else None

+            # Log final recall stats
+            total_time = time.time() - recall_start
+            num_chunks = len(chunks_dict) if chunks_dict else 0
+            num_entities = len(entities_dict) if entities_dict else 0
+            log_buffer.append(f"[RECALL {recall_id}] Complete: {len(top_scored)} facts ({total_tokens} tok), {num_chunks} chunks ({total_chunk_tokens} tok), {num_entities} entities ({total_entity_tokens} tok) | {fact_type_summary} | {total_time:.3f}s")
+            logger.info("\n" + "\n".join(log_buffer))
+
             return RecallResultModel(results=memory_facts, trace=trace_dict, entities=entities_dict, chunks=chunks_dict)

         except Exception as e:
-            log_buffer.append(f"[
+            log_buffer.append(f"[RECALL {recall_id}] ERROR after {time.time() - recall_start:.3f}s: {str(e)}")
             logger.error("\n" + "\n".join(log_buffer))
             raise Exception(f"Failed to search memories: {str(e)}")

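The return shape is unchanged: RecallResult still carries results, trace, entities, and chunks; only the summary logging moved so it can report chunk and entity token totals. A small consumption sketch, assuming `result` came from recall_async as in the earlier example (only attributes that appear in this diff are used):

    print(f"{len(result.results)} facts returned")
    for fact in result.results:
        print(fact.fact_type)  # e.g. 'experience', 'world', or 'opinion'

    if result.entities:
        print(f"{len(result.entities)} entities with observations attached")
    if result.chunks:
        print(f"{len(result.chunks)} source chunks returned")
    if result.trace:
        print("trace keys:", sorted(result.trace))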
@@ -2502,14 +2548,14 @@ Guidelines:
     async def update_bank_disposition(
         self,
         bank_id: str,
-        disposition: Dict[str,
+        disposition: Dict[str, int]
     ) -> None:
         """
         Update bank disposition traits.

         Args:
             bank_id: bank IDentifier
-            disposition: Dict with
+            disposition: Dict with skepticism, literalism, empathy (all 1-5)
         """
         pool = await self._get_pool()
         await bank_utils.update_bank_disposition(pool, bank_id, disposition)
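The disposition payload is now documented as three integer traits on a 1-5 scale. A hedged call sketch:

    # Inside an async context, with an initialized MemoryEngine `engine`:
    await engine.update_bank_disposition(
        bank_id="my-bank",
        disposition={"skepticism": 4, "literalism": 2, "empathy": 5},
    )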
@@ -2584,7 +2630,13 @@ Guidelines:
         if self._llm_config is None:
             raise ValueError("Memory LLM API key not set. Set HINDSIGHT_API_LLM_API_KEY environment variable.")

+        reflect_start = time.time()
+        reflect_id = f"{bank_id[:8]}-{int(time.time() * 1000) % 100000}"
+        log_buffer = []
+        log_buffer.append(f"[REFLECT {reflect_id}] Query: '{query[:50]}...'")
+
         # Steps 1-3: Run multi-fact-type search (12-way retrieval: 4 methods × 3 fact types)
+        recall_start = time.time()
         search_result = await self.recall_async(
             bank_id=bank_id,
             query=query,
@@ -2594,24 +2646,22 @@ Guidelines:
             fact_type=['experience', 'world', 'opinion'],
             include_entities=True
         )
+        recall_time = time.time() - recall_start

         all_results = search_result.results
-        logger.info(f"[THINK] Search returned {len(all_results)} results")

         # Split results by fact type for structured response
         agent_results = [r for r in all_results if r.fact_type == 'experience']
         world_results = [r for r in all_results if r.fact_type == 'world']
         opinion_results = [r for r in all_results if r.fact_type == 'opinion']

-
+        log_buffer.append(f"[REFLECT {reflect_id}] Recall: {len(all_results)} facts (experience={len(agent_results)}, world={len(world_results)}, opinion={len(opinion_results)}) in {recall_time:.3f}s")

         # Format facts for LLM
         agent_facts_text = think_utils.format_facts_for_prompt(agent_results)
         world_facts_text = think_utils.format_facts_for_prompt(world_results)
         opinion_facts_text = think_utils.format_facts_for_prompt(opinion_results)

-        logger.info(f"[THINK] Formatted facts - agent: {len(agent_facts_text)} chars, world: {len(world_facts_text)} chars, opinion: {len(opinion_facts_text)} chars")
-
         # Get bank profile (name, disposition + background)
         profile = await self.get_bank_profile(bank_id)
         name = profile["name"]
@@ -2630,10 +2680,11 @@ Guidelines:
             context=context,
         )

-
+        log_buffer.append(f"[REFLECT {reflect_id}] Prompt: {len(prompt)} chars")

         system_message = think_utils.get_system_message(disposition)

+        llm_start = time.time()
         answer_text = await self._llm_config.call(
             messages=[
                 {"role": "system", "content": system_message},
@@ -2643,6 +2694,7 @@ Guidelines:
             temperature=0.9,
             max_tokens=1000
         )
+        llm_time = time.time() - llm_start

         answer_text = answer_text.strip()

@@ -2654,6 +2706,10 @@ Guidelines:
             'query': query
         })

+        total_time = time.time() - reflect_start
+        log_buffer.append(f"[REFLECT {reflect_id}] Complete: {len(answer_text)} chars response, LLM {llm_time:.3f}s, total {total_time:.3f}s")
+        logger.info("\n" + "\n".join(log_buffer))
+
         # Return response with facts split by type
         return ReflectResult(
             text=answer_text,
@@ -2702,7 +2758,7 @@ Guidelines:
             )

         except Exception as e:
-            logger.warning(f"[
+            logger.warning(f"[REFLECT] Failed to extract/store opinions: {str(e)}")

     async def get_entity_observations(
         self,
@@ -2828,7 +2884,8 @@ Guidelines:
         bank_id: str,
         entity_id: str,
         entity_name: str,
-        version: str | None = None
+        version: str | None = None,
+        conn=None
     ) -> List[str]:
         """
         Regenerate observations for an entity by:
@@ -2843,43 +2900,58 @@ Guidelines:
             entity_id: Entity UUID
             entity_name: Canonical name of the entity
             version: Entity's last_seen timestamp when task was created (for deduplication)
+            conn: Optional database connection (for transactional atomicity with caller)

         Returns:
             List of created observation IDs
         """
         pool = await self._get_pool()
+        entity_uuid = uuid.UUID(entity_id)

-        #
-
-
-
-
-
-
-            WHERE id = $1 AND bank_id = $2
-            """,
-            uuid.UUID(entity_id), bank_id
-        )
+        # Helper to run a query with provided conn or acquire one
+        async def fetch_with_conn(query, *args):
+            if conn is not None:
+                return await conn.fetch(query, *args)
+            else:
+                async with acquire_with_retry(pool) as acquired_conn:
+                    return await acquired_conn.fetch(query, *args)

-
-
+        async def fetchval_with_conn(query, *args):
+            if conn is not None:
+                return await conn.fetchval(query, *args)
+            else:
+                async with acquire_with_retry(pool) as acquired_conn:
+                    return await acquired_conn.fetchval(query, *args)

-        # Step
-
-
+        # Step 1: Check version for deduplication
+        if version:
+            current_last_seen = await fetchval_with_conn(
                 """
-            SELECT
-            FROM
-
-            WHERE mu.bank_id = $1
-            AND ue.entity_id = $2
-            AND mu.fact_type IN ('world', 'experience')
-            ORDER BY mu.occurred_start DESC
-            LIMIT 50
+                SELECT last_seen
+                FROM entities
+                WHERE id = $1 AND bank_id = $2
                 """,
-
+                entity_uuid, bank_id
             )

+            if current_last_seen and current_last_seen.isoformat() != version:
+                return []
+
+        # Step 2: Get all facts mentioning this entity (exclude observations themselves)
+        rows = await fetch_with_conn(
+            """
+            SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
+            FROM memory_units mu
+            JOIN unit_entities ue ON mu.id = ue.unit_id
+            WHERE mu.bank_id = $1
+            AND ue.entity_id = $2
+            AND mu.fact_type IN ('world', 'experience')
+            ORDER BY mu.occurred_start DESC
+            LIMIT 50
+            """,
+            bank_id, entity_uuid
+        )
+
         if not rows:
             return []

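With the optional conn argument, a caller that is already inside an asyncpg transaction can make the version check, fact fetch, and the delete-and-reinsert below atomic with its own writes; without conn, the method acquires a connection and manages its own transaction. A hedged sketch of both call styles (pool handling simplified, identifiers are placeholders):

    # Standalone: the method acquires a connection and opens its own transaction.
    ids = await engine.regenerate_entity_observations(
        bank_id="my-bank",
        entity_id=str(entity_uuid),
        entity_name="Acme Corp",
    )

    # Transactional with the caller: reuse an existing connection and transaction.
    async with pool.acquire() as conn:
        async with conn.transaction():
            # ... caller's own writes on `conn` ...
            ids = await engine.regenerate_entity_observations(
                bank_id="my-bank",
                entity_id=str(entity_uuid),
                entity_name="Acme Corp",
                conn=conn,
            )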
@@ -2905,120 +2977,173 @@ Guidelines:
         if not observations:
             return []

-        # Step 4: Delete old observations and insert new ones
-
-
-
-
+        # Step 4: Delete old observations and insert new ones
+        # If conn provided, we're already in a transaction - don't start another
+        # If conn is None, acquire one and start a transaction
+        async def do_db_operations(db_conn):
+            # Delete old observations for this entity
+            await db_conn.execute(
+                """
+                DELETE FROM memory_units
+                WHERE id IN (
+                    SELECT mu.id
+                    FROM memory_units mu
+                    JOIN unit_entities ue ON mu.id = ue.unit_id
+                    WHERE mu.bank_id = $1
+                    AND mu.fact_type = 'observation'
+                    AND ue.entity_id = $2
+                )
+                """,
+                bank_id, entity_uuid
+            )
+
+            # Generate embeddings for new observations
+            embeddings = await embedding_utils.generate_embeddings_batch(
+                self.embeddings, observations
+            )
+
+            # Insert new observations
+            current_time = utcnow()
+            created_ids = []
+
+            for obs_text, embedding in zip(observations, embeddings):
+                result = await db_conn.fetchrow(
                     """
-
-
-
-
-            JOIN unit_entities ue ON mu.id = ue.unit_id
-            WHERE mu.bank_id = $1
-            AND mu.fact_type = 'observation'
-            AND ue.entity_id = $2
+                    INSERT INTO memory_units (
+                        bank_id, text, embedding, context, event_date,
+                        occurred_start, occurred_end, mentioned_at,
+                        fact_type, access_count
                     )
+                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
+                    RETURNING id
                     """,
-            bank_id,
+                    bank_id,
+                    obs_text,
+                    str(embedding),
+                    f"observation about {entity_name}",
+                    current_time,
+                    current_time,
+                    current_time,
+                    current_time
                 )
+                obs_id = str(result['id'])
+                created_ids.append(obs_id)

-        #
-
-
+                # Link observation to entity
+                await db_conn.execute(
+                    """
+                    INSERT INTO unit_entities (unit_id, entity_id)
+                    VALUES ($1, $2)
+                    """,
+                    uuid.UUID(obs_id), entity_uuid
                 )

-
-        current_time = utcnow()
-        created_ids = []
+            return created_ids

-
-
-
-
-
-
-
-
-            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
-            RETURNING id
-            """,
-            bank_id,
-            obs_text,
-            str(embedding),
-            f"observation about {entity_name}",
-            current_time,
-            current_time,
-            current_time,
-            current_time
-        )
-        obs_id = str(result['id'])
-        created_ids.append(obs_id)
-
-        # Link observation to entity
-        await conn.execute(
-            """
-            INSERT INTO unit_entities (unit_id, entity_id)
-            VALUES ($1, $2)
-            """,
-            uuid.UUID(obs_id), uuid.UUID(entity_id)
-        )
-
-        # Single consolidated log line
-        logger.info(f"[OBSERVATIONS] {entity_name}: {len(facts)} facts -> {len(created_ids)} observations")
-        return created_ids
+        if conn is not None:
+            # Use provided connection (already in a transaction)
+            return await do_db_operations(conn)
+        else:
+            # Acquire connection and start our own transaction
+            async with acquire_with_retry(pool) as acquired_conn:
+                async with acquired_conn.transaction():
+                    return await do_db_operations(acquired_conn)

     async def _regenerate_observations_sync(
         self,
         bank_id: str,
         entity_ids: List[str],
-        min_facts: int = 5
+        min_facts: int = 5,
+        conn=None
     ) -> None:
         """
         Regenerate observations for entities synchronously (called during retain).

+        Processes entities in PARALLEL for faster execution.
+
         Args:
             bank_id: Bank identifier
             entity_ids: List of entity IDs to process
             min_facts: Minimum facts required to regenerate observations
+            conn: Optional database connection (for transactional atomicity)
         """
         if not bank_id or not entity_ids:
             return

-
-
-        for entity_id in entity_ids:
-            try:
-                entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
-
-                # Check if entity exists
-                entity_exists = await conn.fetchrow(
-                    "SELECT canonical_name FROM entities WHERE id = $1 AND bank_id = $2",
-                    entity_uuid, bank_id
-                )
+        # Convert to UUIDs
+        entity_uuids = [uuid.UUID(eid) if isinstance(eid, str) else eid for eid in entity_ids]

-
-
-
+        # Use provided connection or acquire a new one
+        if conn is not None:
+            # Use the provided connection (transactional with caller)
+            entity_rows = await conn.fetch(
+                """
+                SELECT id, canonical_name FROM entities
+                WHERE id = ANY($1) AND bank_id = $2
+                """,
+                entity_uuids, bank_id
+            )
+            entity_names = {row['id']: row['canonical_name'] for row in entity_rows}

-
+            fact_counts = await conn.fetch(
+                """
+                SELECT ue.entity_id, COUNT(*) as cnt
+                FROM unit_entities ue
+                JOIN memory_units mu ON ue.unit_id = mu.id
+                WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
+                GROUP BY ue.entity_id
+                """,
+                entity_uuids, bank_id
+            )
+            entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+        else:
+            # Acquire a new connection (standalone call)
+            pool = await self._get_pool()
+            async with pool.acquire() as acquired_conn:
+                entity_rows = await acquired_conn.fetch(
+                    """
+                    SELECT id, canonical_name FROM entities
+                    WHERE id = ANY($1) AND bank_id = $2
+                    """,
+                    entity_uuids, bank_id
+                )
+                entity_names = {row['id']: row['canonical_name'] for row in entity_rows}

-
-
-
-
-
+                fact_counts = await acquired_conn.fetch(
+                    """
+                    SELECT ue.entity_id, COUNT(*) as cnt
+                    FROM unit_entities ue
+                    JOIN memory_units mu ON ue.unit_id = mu.id
+                    WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
+                    GROUP BY ue.entity_id
+                    """,
+                    entity_uuids, bank_id
+                )
+                entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+
+        # Filter entities that meet the threshold
+        entities_to_process = []
+        for entity_id in entity_ids:
+            entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+            if entity_uuid not in entity_names:
+                continue
+            fact_count = entity_fact_counts.get(entity_uuid, 0)
+            if fact_count >= min_facts:
+                entities_to_process.append((entity_id, entity_names[entity_uuid]))
+
+        if not entities_to_process:
+            return

-
-
-
-
-
+        # Process all entities in PARALLEL (LLM calls are the bottleneck)
+        async def process_entity(entity_id: str, entity_name: str):
+            try:
+                await self.regenerate_entity_observations(bank_id, entity_id, entity_name, version=None, conn=conn)
+            except Exception as e:
+                logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")

-
-
-
+        await asyncio.gather(*[
+            process_entity(eid, name) for eid, name in entities_to_process
+        ])

     async def _handle_regenerate_observations(self, task_dict: Dict[str, Any]):
         """