hindsight-api 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +38 -0
- hindsight_api/api/__init__.py +105 -0
- hindsight_api/api/http.py +1872 -0
- hindsight_api/api/mcp.py +157 -0
- hindsight_api/engine/__init__.py +47 -0
- hindsight_api/engine/cross_encoder.py +97 -0
- hindsight_api/engine/db_utils.py +93 -0
- hindsight_api/engine/embeddings.py +113 -0
- hindsight_api/engine/entity_resolver.py +575 -0
- hindsight_api/engine/llm_wrapper.py +269 -0
- hindsight_api/engine/memory_engine.py +3095 -0
- hindsight_api/engine/query_analyzer.py +519 -0
- hindsight_api/engine/response_models.py +222 -0
- hindsight_api/engine/retain/__init__.py +50 -0
- hindsight_api/engine/retain/bank_utils.py +423 -0
- hindsight_api/engine/retain/chunk_storage.py +82 -0
- hindsight_api/engine/retain/deduplication.py +104 -0
- hindsight_api/engine/retain/embedding_processing.py +62 -0
- hindsight_api/engine/retain/embedding_utils.py +54 -0
- hindsight_api/engine/retain/entity_processing.py +90 -0
- hindsight_api/engine/retain/fact_extraction.py +1027 -0
- hindsight_api/engine/retain/fact_storage.py +176 -0
- hindsight_api/engine/retain/link_creation.py +121 -0
- hindsight_api/engine/retain/link_utils.py +651 -0
- hindsight_api/engine/retain/orchestrator.py +405 -0
- hindsight_api/engine/retain/types.py +206 -0
- hindsight_api/engine/search/__init__.py +15 -0
- hindsight_api/engine/search/fusion.py +122 -0
- hindsight_api/engine/search/observation_utils.py +132 -0
- hindsight_api/engine/search/reranking.py +103 -0
- hindsight_api/engine/search/retrieval.py +503 -0
- hindsight_api/engine/search/scoring.py +161 -0
- hindsight_api/engine/search/temporal_extraction.py +64 -0
- hindsight_api/engine/search/think_utils.py +255 -0
- hindsight_api/engine/search/trace.py +215 -0
- hindsight_api/engine/search/tracer.py +447 -0
- hindsight_api/engine/search/types.py +160 -0
- hindsight_api/engine/task_backend.py +223 -0
- hindsight_api/engine/utils.py +203 -0
- hindsight_api/metrics.py +227 -0
- hindsight_api/migrations.py +163 -0
- hindsight_api/models.py +309 -0
- hindsight_api/pg0.py +425 -0
- hindsight_api/web/__init__.py +12 -0
- hindsight_api/web/server.py +143 -0
- hindsight_api-0.0.13.dist-info/METADATA +41 -0
- hindsight_api-0.0.13.dist-info/RECORD +48 -0
- hindsight_api-0.0.13.dist-info/WHEEL +4 -0
hindsight_api/engine/retain/orchestrator.py
@@ -0,0 +1,405 @@
"""
Main orchestrator for the retain pipeline.

Coordinates all retain pipeline modules to store memories efficiently.
"""
import logging
import time
import uuid
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional

from . import bank_utils
from ..db_utils import acquire_with_retry


def utcnow():
    """Get current UTC time."""
    return datetime.now(timezone.utc)

from .types import RetainContent, ExtractedFact, ProcessedFact
from . import (
    fact_extraction,
    embedding_processing,
    deduplication,
    chunk_storage,
    fact_storage,
    entity_processing,
    link_creation
)

logger = logging.getLogger(__name__)


async def retain_batch(
    pool,
    embeddings_model,
    llm_config,
    entity_resolver,
    task_backend,
    format_date_fn,
    duplicate_checker_fn,
    regenerate_observations_fn,
    bank_id: str,
    contents_dicts: List[Dict[str, Any]],
    document_id: Optional[str] = None,
    is_first_batch: bool = True,
    fact_type_override: Optional[str] = None,
    confidence_score: Optional[float] = None,
) -> List[List[str]]:
    """
    Process a batch of content through the retain pipeline.

    Args:
        pool: Database connection pool
        embeddings_model: Embeddings model for generating embeddings
        llm_config: LLM configuration for fact extraction
        entity_resolver: Entity resolver for entity processing
        task_backend: Task backend for background jobs
        format_date_fn: Function to format datetime to readable string
        duplicate_checker_fn: Function to check for duplicate facts
        regenerate_observations_fn: Async function to regenerate observations for entities
        bank_id: Bank identifier
        contents_dicts: List of content dictionaries
        document_id: Optional document ID
        is_first_batch: Whether this is the first batch
        fact_type_override: Override fact type for all facts
        confidence_score: Confidence score for opinions

    Returns:
        List of unit ID lists (one list per content item)
    """
    start_time = time.time()
    total_chars = sum(len(item.get("content", "")) for item in contents_dicts)

    # Buffer all logs
    log_buffer = []
    log_buffer.append(f"{'='*60}")
    log_buffer.append(f"RETAIN_BATCH START: {bank_id}")
    log_buffer.append(f"Batch size: {len(contents_dicts)} content items, {total_chars:,} chars")
    log_buffer.append(f"{'='*60}")

    # Get bank profile
    profile = await bank_utils.get_bank_profile(pool, bank_id)
    agent_name = profile["name"]

    # Convert dicts to RetainContent objects
    contents = []
    for item in contents_dicts:
        content = RetainContent(
            content=item["content"],
            context=item.get("context", ""),
            event_date=item.get("event_date") or utcnow(),
            metadata=item.get("metadata", {})
        )
        contents.append(content)

    # Step 1: Extract facts from all contents
    step_start = time.time()
    extract_opinions = (fact_type_override == 'opinion')

    extracted_facts, chunks = await fact_extraction.extract_facts_from_contents(
        contents,
        llm_config,
        agent_name,
        extract_opinions
    )
    log_buffer.append(f"[1] Extract facts: {len(extracted_facts)} facts, {len(chunks)} chunks from {len(contents)} contents in {time.time() - step_start:.3f}s")

    if not extracted_facts:
        return [[] for _ in contents]

    # Apply fact_type_override if provided
    if fact_type_override:
        for fact in extracted_facts:
            fact.fact_type = fact_type_override

    # Step 2: Augment texts and generate embeddings
    step_start = time.time()
    augmented_texts = embedding_processing.augment_texts_with_dates(extracted_facts, format_date_fn)
    embeddings = await embedding_processing.generate_embeddings_batch(embeddings_model, augmented_texts)
    log_buffer.append(f"[2] Generate embeddings: {len(embeddings)} embeddings in {time.time() - step_start:.3f}s")

    # Step 3: Convert to ProcessedFact objects (without chunk_ids yet)
    processed_facts = [
        ProcessedFact.from_extracted_fact(extracted_fact, embedding)
        for extracted_fact, embedding in zip(extracted_facts, embeddings)
    ]

    # Track document IDs for logging
    document_ids_added = []

    # Group contents by document_id for document tracking and chunk storage
    from collections import defaultdict
    contents_by_doc = defaultdict(list)
    for idx, content_dict in enumerate(contents_dicts):
        doc_id = content_dict.get("document_id")
        contents_by_doc[doc_id].append((idx, content_dict))

    # Step 4: Database transaction
    async with acquire_with_retry(pool) as conn:
        async with conn.transaction():
            # Ensure bank exists
            await fact_storage.ensure_bank_exists(conn, bank_id)

            # Handle document tracking for all documents
            step_start = time.time()
            # Map None document_id to generated UUIDs
            doc_id_mapping = {}  # Maps original doc_id (including None) to actual doc_id used

            if document_id:
                # Legacy: single document_id parameter
                combined_content = "\n".join([c.get("content", "") for c in contents_dicts])
                retain_params = {}
                if contents_dicts:
                    first_item = contents_dicts[0]
                    if first_item.get("context"):
                        retain_params["context"] = first_item["context"]
                    if first_item.get("event_date"):
                        retain_params["event_date"] = first_item["event_date"].isoformat() if hasattr(first_item["event_date"], "isoformat") else str(first_item["event_date"])
                    if first_item.get("metadata"):
                        retain_params["metadata"] = first_item["metadata"]

                await fact_storage.handle_document_tracking(
                    conn, bank_id, document_id, combined_content, is_first_batch, retain_params
                )
                document_ids_added.append(document_id)
                doc_id_mapping[None] = document_id  # For backwards compatibility
            else:
                # Handle per-item document_ids (create documents if any item has document_id or if chunks exist)
                has_any_doc_ids = any(item.get("document_id") for item in contents_dicts)

                if has_any_doc_ids or chunks:
                    for original_doc_id, doc_contents in contents_by_doc.items():
                        actual_doc_id = original_doc_id

                        # Only create document record if:
                        # 1. Item has explicit document_id, OR
                        # 2. There are chunks (need document for chunk storage)
                        should_create_doc = (original_doc_id is not None) or chunks

                        if should_create_doc:
                            if actual_doc_id is None:
                                # No document_id but have chunks - generate one
                                actual_doc_id = str(uuid.uuid4())

                            # Store mapping for later use
                            doc_id_mapping[original_doc_id] = actual_doc_id

                            # Combine content for this document
                            combined_content = "\n".join([c.get("content", "") for _, c in doc_contents])

                            # Extract retain params from first content item
                            retain_params = {}
                            if doc_contents:
                                first_item = doc_contents[0][1]
                                if first_item.get("context"):
                                    retain_params["context"] = first_item["context"]
                                if first_item.get("event_date"):
                                    retain_params["event_date"] = first_item["event_date"].isoformat() if hasattr(first_item["event_date"], "isoformat") else str(first_item["event_date"])
                                if first_item.get("metadata"):
                                    retain_params["metadata"] = first_item["metadata"]

                            await fact_storage.handle_document_tracking(
                                conn, bank_id, actual_doc_id, combined_content, is_first_batch, retain_params
                            )
                            document_ids_added.append(actual_doc_id)

            if document_ids_added:
                log_buffer.append(f"[2.5] Document tracking: {len(document_ids_added)} documents in {time.time() - step_start:.3f}s")

            # Store chunks and map to facts for all documents
            step_start = time.time()
            chunk_id_map_by_doc = {}  # Maps (doc_id, chunk_index) -> chunk_id

            if chunks:
                # Group chunks by their source document
                chunks_by_doc = defaultdict(list)
                for chunk in chunks:
                    # chunk.content_index tells us which content this chunk came from
                    original_doc_id = contents_dicts[chunk.content_index].get("document_id")
                    # Map to actual document_id (handles None -> generated UUID mapping)
                    actual_doc_id = doc_id_mapping.get(original_doc_id, original_doc_id)
                    if actual_doc_id is None and document_id:
                        actual_doc_id = document_id
                    chunks_by_doc[actual_doc_id].append(chunk)

                # Store chunks for each document
                for doc_id, doc_chunks in chunks_by_doc.items():
                    chunk_id_map = await chunk_storage.store_chunks_batch(conn, bank_id, doc_id, doc_chunks)
                    # Store mapping with document context
                    for chunk_idx, chunk_id in chunk_id_map.items():
                        chunk_id_map_by_doc[(doc_id, chunk_idx)] = chunk_id

                log_buffer.append(f"[3] Store chunks: {len(chunks)} chunks for {len(chunks_by_doc)} documents in {time.time() - step_start:.3f}s")

                # Map chunk_ids and document_ids to facts
                for fact, processed_fact in zip(extracted_facts, processed_facts):
                    # Get the original document_id for this fact's source content
                    original_doc_id = contents_dicts[fact.content_index].get("document_id")
                    # Map to actual document_id (handles None -> generated UUID mapping)
                    actual_doc_id = doc_id_mapping.get(original_doc_id, original_doc_id)
                    if actual_doc_id is None and document_id:
                        actual_doc_id = document_id

                    # Set document_id on the fact
                    processed_fact.document_id = actual_doc_id

                    # Map chunk_id if this fact came from a chunk
                    if fact.chunk_index is not None:
                        # Look up chunk_id using (doc_id, chunk_index)
                        chunk_id = chunk_id_map_by_doc.get((actual_doc_id, fact.chunk_index))
                        if chunk_id:
                            processed_fact.chunk_id = chunk_id
            else:
                # No chunks - still need to set document_id on facts
                for fact, processed_fact in zip(extracted_facts, processed_facts):
                    original_doc_id = contents_dicts[fact.content_index].get("document_id")
                    # Map to actual document_id (handles None -> generated UUID mapping)
                    actual_doc_id = doc_id_mapping.get(original_doc_id, original_doc_id)
                    if actual_doc_id is None and document_id:
                        actual_doc_id = document_id
                    processed_fact.document_id = actual_doc_id

            # Deduplication
            step_start = time.time()
            is_duplicate_flags = await deduplication.check_duplicates_batch(
                conn, bank_id, processed_facts, duplicate_checker_fn
            )
            log_buffer.append(f"[4] Deduplication: {sum(is_duplicate_flags)} duplicates in {time.time() - step_start:.3f}s")

            # Filter out duplicates
            non_duplicate_facts = deduplication.filter_duplicates(processed_facts, is_duplicate_flags)

            if not non_duplicate_facts:
                return [[] for _ in contents]

            # Insert facts (document_id is now stored per-fact)
            step_start = time.time()
            unit_ids = await fact_storage.insert_facts_batch(conn, bank_id, non_duplicate_facts)
            log_buffer.append(f"[5] Insert facts: {len(unit_ids)} units in {time.time() - step_start:.3f}s")

            # Process entities
            step_start = time.time()
            entity_links = await entity_processing.process_entities_batch(
                entity_resolver, conn, bank_id, unit_ids, non_duplicate_facts, log_buffer
            )
            log_buffer.append(f"[6] Process entities: {len(entity_links)} links in {time.time() - step_start:.3f}s")

            # Create temporal links
            step_start = time.time()
            await link_creation.create_temporal_links_batch(conn, bank_id, unit_ids)
            log_buffer.append(f"[7] Temporal links: {time.time() - step_start:.3f}s")

            # Create semantic links
            step_start = time.time()
            embeddings_for_links = [fact.embedding for fact in non_duplicate_facts]
            await link_creation.create_semantic_links_batch(conn, bank_id, unit_ids, embeddings_for_links)
            log_buffer.append(f"[8] Semantic links: {time.time() - step_start:.3f}s")

            # Insert entity links
            step_start = time.time()
            if entity_links:
                await entity_processing.insert_entity_links_batch(conn, entity_links)
            log_buffer.append(f"[9] Entity links: {time.time() - step_start:.3f}s")

            # Create causal links
            step_start = time.time()
            causal_link_count = await link_creation.create_causal_links_batch(conn, unit_ids, non_duplicate_facts)
            log_buffer.append(f"[10] Causal links: {causal_link_count} links in {time.time() - step_start:.3f}s")

    # Map results back to original content items
    result_unit_ids = _map_results_to_contents(
        contents, extracted_facts, is_duplicate_flags, unit_ids
    )

    total_time = time.time() - start_time
    log_buffer.append(f"{'='*60}")
    log_buffer.append(f"RETAIN_BATCH COMPLETE: {len(unit_ids)} units in {total_time:.3f}s")
    if document_ids_added:
        log_buffer.append(f"Documents: {', '.join(document_ids_added)}")
    log_buffer.append(f"{'='*60}")

    logger.info("\n" + "\n".join(log_buffer) + "\n")

    # Trigger background tasks AFTER transaction commits
    await _trigger_background_tasks(
        task_backend,
        regenerate_observations_fn,
        bank_id,
        unit_ids,
        non_duplicate_facts,
        entity_links
    )

    return result_unit_ids


def _map_results_to_contents(
    contents: List[RetainContent],
    extracted_facts: List[ExtractedFact],
    is_duplicate_flags: List[bool],
    unit_ids: List[str]
) -> List[List[str]]:
    """
    Map created unit IDs back to original content items.

    Accounts for duplicates when mapping back.
    """
    result_unit_ids = []
    filtered_idx = 0

    # Group facts by content_index
    facts_by_content = {i: [] for i in range(len(contents))}
    for i, fact in enumerate(extracted_facts):
        facts_by_content[fact.content_index].append(i)

    for content_index in range(len(contents)):
        content_unit_ids = []
        for fact_idx in facts_by_content[content_index]:
            if not is_duplicate_flags[fact_idx]:
                content_unit_ids.append(unit_ids[filtered_idx])
                filtered_idx += 1
        result_unit_ids.append(content_unit_ids)

    return result_unit_ids


async def _trigger_background_tasks(
    task_backend,
    regenerate_observations_fn,
    bank_id: str,
    unit_ids: List[str],
    facts: List[ProcessedFact],
    entity_links: List
) -> None:
    """Trigger opinion reinforcement and observation regeneration (sync)."""
    # Trigger opinion reinforcement if there are entities
    fact_entities = [[e.name for e in fact.entities] for fact in facts]
    if any(fact_entities):
        await task_backend.submit_task({
            'type': 'reinforce_opinion',
            'bank_id': bank_id,
            'created_unit_ids': unit_ids,
            'unit_texts': [fact.fact_text for fact in facts],
            'unit_entities': fact_entities
        })

    # Regenerate observations synchronously for top entities
    TOP_N_ENTITIES = 5
    MIN_FACTS_THRESHOLD = 5

    if entity_links and regenerate_observations_fn:
        unique_entity_ids = set()
        for link in entity_links:
            # links are tuples: (unit_id, entity_id, confidence)
            if len(link) >= 2 and link[1]:
                unique_entity_ids.add(str(link[1]))

        if unique_entity_ids:
            # Run observation regeneration synchronously
            await regenerate_observations_fn(
                bank_id=bank_id,
                entity_ids=list(unique_entity_ids)[:TOP_N_ENTITIES],
                min_facts=MIN_FACTS_THRESHOLD
            )
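Example (illustrative, not part of the wheel): a minimal sketch of how retain_batch above might be called. In the released package the pool, models, resolver, task backend, and helper callables are wired together by memory_engine.py; here they are taken as parameters, and the stand-in values for format_date_fn, duplicate_checker_fn, and regenerate_observations_fn are assumptions made for illustration only.

# Illustrative sketch only -- not part of the package contents.
from datetime import datetime, timezone

from hindsight_api.engine.retain.orchestrator import retain_batch


async def store_meeting_note(pool, embeddings_model, llm_config, entity_resolver, task_backend):
    contents = [
        {
            "content": "Alice moved the launch date to March 3rd.",
            "context": "planning meeting",
            "event_date": datetime.now(timezone.utc),
            "metadata": {"source": "notes"},
        }
    ]
    unit_ids_per_content = await retain_batch(
        pool,
        embeddings_model,
        llm_config,
        entity_resolver,
        task_backend,
        format_date_fn=lambda dt: dt.strftime("%B %d, %Y"),  # readable date used to augment embedding text
        duplicate_checker_fn=None,        # assumption: the real engine supplies its own checker
        regenerate_observations_fn=None,  # observation regeneration skipped in this sketch
        bank_id="demo-bank",
        contents_dicts=contents,
    )
    # retain_batch returns one list of created unit IDs per input content item.
    return unit_ids_per_content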
hindsight_api/engine/retain/types.py
@@ -0,0 +1,206 @@
"""
Type definitions for the retain pipeline.

These dataclasses provide type safety throughout the retain operation,
from content input to fact storage.
"""

from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
from datetime import datetime
from uuid import UUID


@dataclass
class RetainContent:
    """
    Input content item to be retained as memories.

    Represents a single piece of content to extract facts from.
    """
    content: str
    context: str = ""
    event_date: Optional[datetime] = None
    metadata: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        """Ensure event_date is set."""
        if self.event_date is None:
            from datetime import datetime, timezone
            self.event_date = datetime.now(timezone.utc)


@dataclass
class ChunkMetadata:
    """
    Metadata about a text chunk.

    Used to track which facts were extracted from which chunks.
    """
    chunk_text: str
    fact_count: int
    content_index: int  # Index of the source content
    chunk_index: int  # Global chunk index across all contents


@dataclass
class EntityRef:
    """
    Reference to an entity mentioned in a fact.

    Entities are extracted by the LLM during fact extraction.
    """
    name: str
    canonical_name: Optional[str] = None  # Resolved canonical name
    entity_id: Optional[UUID] = None  # Resolved entity ID


@dataclass
class CausalRelation:
    """
    Causal relationship between facts.

    Represents how one fact causes, enables, or prevents another.
    """
    relation_type: str  # "causes", "enables", "prevents", "caused_by"
    target_fact_index: int  # Index of the target fact in the batch
    strength: float = 1.0  # Strength of the causal relationship


@dataclass
class ExtractedFact:
    """
    Fact extracted from content by the LLM.

    This is the raw output from fact extraction before processing.
    """
    fact_text: str
    fact_type: str  # "world", "bank", "opinion", "observation"
    entities: List[str] = field(default_factory=list)
    occurred_start: Optional[datetime] = None
    occurred_end: Optional[datetime] = None
    where: Optional[str] = None  # WHERE the fact occurred or is about
    causal_relations: List[CausalRelation] = field(default_factory=list)

    # Context from the content item
    content_index: int = 0  # Which content this fact came from
    chunk_index: int = 0  # Which chunk this fact came from
    context: str = ""
    mentioned_at: Optional[datetime] = None
    metadata: Dict[str, str] = field(default_factory=dict)


@dataclass
class ProcessedFact:
    """
    Fact after processing and ready for storage.

    Includes resolved entities, embeddings, and all necessary fields.
    """
    # Core fact data
    fact_text: str
    fact_type: str
    embedding: List[float]

    # Temporal data
    occurred_start: Optional[datetime]
    occurred_end: Optional[datetime]
    mentioned_at: datetime

    # Context and metadata
    context: str
    metadata: Dict[str, str]

    # Location data
    where: Optional[str] = None

    # Entities
    entities: List[EntityRef] = field(default_factory=list)

    # Causal relations
    causal_relations: List[CausalRelation] = field(default_factory=list)

    # Chunk reference
    chunk_id: Optional[str] = None

    # Document reference (denormalized for query performance)
    document_id: Optional[str] = None

    # DB fields (set after insertion)
    unit_id: Optional[UUID] = None

    @property
    def is_duplicate(self) -> bool:
        """Check if this fact was marked as a duplicate."""
        return self.unit_id is None

    @staticmethod
    def from_extracted_fact(
        extracted_fact: 'ExtractedFact',
        embedding: List[float],
        chunk_id: Optional[str] = None
    ) -> 'ProcessedFact':
        """
        Create ProcessedFact from ExtractedFact.

        Args:
            extracted_fact: Source ExtractedFact
            embedding: Generated embedding vector
            chunk_id: Optional chunk ID

        Returns:
            ProcessedFact ready for storage
        """
        from datetime import datetime, timezone

        # Use occurred dates only if explicitly provided by LLM
        occurred_start = extracted_fact.occurred_start
        occurred_end = extracted_fact.occurred_end
        mentioned_at = extracted_fact.mentioned_at or datetime.now(timezone.utc)

        # Convert entity strings to EntityRef objects
        entities = [EntityRef(name=name) for name in extracted_fact.entities]

        return ProcessedFact(
            fact_text=extracted_fact.fact_text,
            fact_type=extracted_fact.fact_type,
            embedding=embedding,
            occurred_start=occurred_start,
            occurred_end=occurred_end,
            mentioned_at=mentioned_at,
            context=extracted_fact.context,
            metadata=extracted_fact.metadata,
            entities=entities,
            causal_relations=extracted_fact.causal_relations,
            chunk_id=chunk_id
        )


@dataclass
class RetainBatch:
    """
    A batch of content to retain.

    Tracks all facts, chunks, and metadata for a batch operation.
    """
    bank_id: str
    contents: List[RetainContent]
    document_id: Optional[str] = None
    fact_type_override: Optional[str] = None
    confidence_score: Optional[float] = None

    # Extracted data (populated during processing)
    extracted_facts: List[ExtractedFact] = field(default_factory=list)
    processed_facts: List[ProcessedFact] = field(default_factory=list)
    chunks: List[ChunkMetadata] = field(default_factory=list)

    # Results (populated after storage)
    unit_ids_by_content: List[List[str]] = field(default_factory=list)

    def get_facts_for_content(self, content_index: int) -> List[ExtractedFact]:
        """Get all extracted facts for a specific content item."""
        return [f for f in self.extracted_facts if f.content_index == content_index]

    def get_chunks_for_content(self, content_index: int) -> List[ChunkMetadata]:
        """Get all chunks for a specific content item."""
        return [c for c in self.chunks if c.content_index == content_index]
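Example (illustrative, not part of the wheel): a short sketch of the ExtractedFact to ProcessedFact conversion defined above, using a placeholder embedding vector in place of a real model output.

# Illustrative sketch only -- the embedding below is a dummy placeholder.
from hindsight_api.engine.retain.types import ExtractedFact, ProcessedFact

extracted = ExtractedFact(
    fact_text="Alice moved the launch date to March 3rd.",
    fact_type="world",
    entities=["Alice"],
)

processed = ProcessedFact.from_extracted_fact(
    extracted_fact=extracted,
    embedding=[0.0] * 8,  # real vectors come from the embeddings model
)

assert processed.entities[0].name == "Alice"  # entity strings become EntityRef objects
assert processed.unit_id is None              # unit_id is only set after insertion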
hindsight_api/engine/search/__init__.py
@@ -0,0 +1,15 @@
"""
Search module for memory retrieval.

Provides modular search architecture:
- Retrieval: 4-way parallel (semantic + BM25 + graph + temporal)
- Reranking: Pluggable strategies (heuristic, cross-encoder)
"""

from .retrieval import retrieve_parallel
from .reranking import CrossEncoderReranker

__all__ = [
    "retrieve_parallel",
    "CrossEncoderReranker",
]
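Example (illustrative, not part of the wheel): the two names re-exported above are the module's public entry points. Their call signatures live in retrieval.py and reranking.py, which this hunk does not include, so only the import surface is shown.

# Illustrative only: importing the public search entry points.
from hindsight_api.engine.search import CrossEncoderReranker, retrieve_parallel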