hindsight-api 0.0.21__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -2
- hindsight_api/alembic/README +1 -0
- hindsight_api/alembic/env.py +146 -0
- hindsight_api/alembic/script.py.mako +28 -0
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +274 -0
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +70 -0
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +39 -0
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +48 -0
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +62 -0
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +65 -0
- hindsight_api/api/__init__.py +2 -4
- hindsight_api/api/http.py +112 -164
- hindsight_api/api/mcp.py +2 -1
- hindsight_api/config.py +154 -0
- hindsight_api/engine/__init__.py +7 -2
- hindsight_api/engine/cross_encoder.py +225 -16
- hindsight_api/engine/embeddings.py +198 -19
- hindsight_api/engine/entity_resolver.py +56 -29
- hindsight_api/engine/llm_wrapper.py +147 -106
- hindsight_api/engine/memory_engine.py +337 -192
- hindsight_api/engine/response_models.py +15 -17
- hindsight_api/engine/retain/bank_utils.py +25 -35
- hindsight_api/engine/retain/entity_processing.py +5 -5
- hindsight_api/engine/retain/fact_extraction.py +86 -24
- hindsight_api/engine/retain/fact_storage.py +1 -1
- hindsight_api/engine/retain/link_creation.py +12 -6
- hindsight_api/engine/retain/link_utils.py +50 -56
- hindsight_api/engine/retain/observation_regeneration.py +264 -0
- hindsight_api/engine/retain/orchestrator.py +31 -44
- hindsight_api/engine/retain/types.py +14 -0
- hindsight_api/engine/search/reranking.py +6 -10
- hindsight_api/engine/search/retrieval.py +2 -2
- hindsight_api/engine/search/think_utils.py +59 -30
- hindsight_api/engine/search/tracer.py +1 -1
- hindsight_api/main.py +201 -0
- hindsight_api/migrations.py +61 -39
- hindsight_api/models.py +1 -2
- hindsight_api/pg0.py +17 -36
- hindsight_api/server.py +43 -0
- {hindsight_api-0.0.21.dist-info → hindsight_api-0.1.1.dist-info}/METADATA +2 -3
- hindsight_api-0.1.1.dist-info/RECORD +60 -0
- hindsight_api-0.1.1.dist-info/entry_points.txt +2 -0
- hindsight_api/cli.py +0 -128
- hindsight_api/web/__init__.py +0 -12
- hindsight_api/web/server.py +0 -109
- hindsight_api-0.0.21.dist-info/RECORD +0 -50
- hindsight_api-0.0.21.dist-info/entry_points.txt +0 -2
- {hindsight_api-0.0.21.dist-info → hindsight_api-0.1.1.dist-info}/WHEEL +0 -0
```diff
@@ -6,6 +6,9 @@ import time
 import logging
 from typing import List
 from datetime import timedelta, datetime, timezone
+from uuid import UUID
+
+from .types import EntityLink

 logger = logging.getLogger(__name__)

```
```diff
@@ -202,47 +205,24 @@ async def extract_entities_batch_optimized(

     # Resolve ALL entities in one batch call
     if all_entities_flat:
-        # [6.2.2] Batch resolve entities
+        # [6.2.2] Batch resolve entities - single call with per-entity dates
         substep_6_2_2_start = time.time()
-        # Group by date for batch resolution (round to hour to reduce buckets)
-        entities_by_date = {}
-        for idx, (unit_id, local_idx, fact_date) in enumerate(entity_to_unit):
-            # Round to hour to group facts from same time period
-            date_key = fact_date.replace(minute=0, second=0, microsecond=0)
-            if date_key not in entities_by_date:
-                entities_by_date[date_key] = []
-            entities_by_date[date_key].append((idx, all_entities_flat[idx]))
-
-        _log(log_buffer, f" [6.2.2] Grouped into {len(entities_by_date)} date buckets, resolving sequentially...", level='debug')
-
-        # Resolve all date groups SEQUENTIALLY using main transaction connection
-        # This prevents race conditions where parallel tasks create duplicate entities
-        resolved_entity_ids = [None] * len(all_entities_flat)
-
-        for date_idx, (date_key, entities_group) in enumerate(entities_by_date.items(), 1):
-            date_bucket_start = time.time()
-            indices = [idx for idx, _ in entities_group]
-            entities_data = [entity_data for _, entity_data in entities_group]
-            # Use the first fact's date for this bucket (all should be in same hour)
-            fact_date = entity_to_unit[indices[0]][2]
-
-            # Use main transaction connection to ensure consistency
-            batch_resolved = await entity_resolver.resolve_entities_batch(
-                bank_id=bank_id,
-                entities_data=entities_data,
-                context=context,
-                unit_event_date=fact_date,
-                conn=conn  # Use main transaction connection
-            )

-
-
-
-
-
-
+        # Add per-entity dates to entity data for batch resolution
+        for idx, (unit_id, local_idx, fact_date) in enumerate(entity_to_unit):
+            all_entities_flat[idx]['event_date'] = fact_date
+
+        # Resolve ALL entities in ONE batch call (much faster than sequential buckets)
+        # INSERT ... ON CONFLICT handles any race conditions at the DB level
+        resolved_entity_ids = await entity_resolver.resolve_entities_batch(
+            bank_id=bank_id,
+            entities_data=all_entities_flat,
+            context=context,
+            unit_event_date=None,  # Not used when per-entity dates provided
+            conn=conn  # Use main transaction connection
+        )

-        _log(log_buffer, f" [6.2.2] Resolve entities: {len(all_entities_flat)} entities
+        _log(log_buffer, f" [6.2.2] Resolve entities: {len(all_entities_flat)} entities in single batch in {time.time() - substep_6_2_2_start:.3f}s", level='debug')

         # [6.2.3] Create unit-entity links in BATCH
         substep_6_2_3_start = time.time()
```
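In practice this hunk collapses what used to be one resolver call per hour bucket into a single call per retain batch, with each entity dict carrying its own `event_date`. Below is a minimal sketch of the new calling pattern; the resolver object, connection, and entity dict contents are stand-ins, and only the `event_date` key plus the `resolve_entities_batch` keyword arguments come from the diff above:

```python
from datetime import datetime
from typing import Any, Dict, List, Tuple

async def resolve_all_entities_once(
    entity_resolver: Any,
    bank_id: str,
    context: str,
    conn: Any,
    all_entities_flat: List[Dict[str, Any]],
    entity_to_unit: List[Tuple[str, int, datetime]],
) -> List[Any]:
    # Attach each fact's date to its entity payload...
    for idx, (_unit_id, _local_idx, fact_date) in enumerate(entity_to_unit):
        all_entities_flat[idx]["event_date"] = fact_date

    # ...then resolve everything in one round-trip instead of one call per hour bucket.
    return await entity_resolver.resolve_entities_batch(
        bank_id=bank_id,
        entities_data=all_entities_flat,
        context=context,
        unit_event_date=None,  # per-entity dates are used instead
        conn=conn,
    )
```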
```diff
@@ -305,10 +285,14 @@ async def extract_entities_batch_optimized(
     # Only link each new unit to the most recent MAX_LINKS_PER_ENTITY units
     MAX_LINKS_PER_ENTITY = 50  # Limit to prevent explosion when entity appears in many facts
     link_gen_start = time.time()
-    links = []
+    links: List[EntityLink] = []
     new_unit_set = set(unit_ids)  # Units from this batch

+    def to_uuid(val) -> UUID:
+        return UUID(val) if isinstance(val, str) else val
+
     for entity_id, units_with_entity in entity_to_units.items():
+        entity_uuid = to_uuid(entity_id)
         # Separate new units (from this batch) and existing units
         new_units = [u for u in units_with_entity if str(u) in new_unit_set or u in new_unit_set]
         existing_units = [u for u in units_with_entity if str(u) not in new_unit_set and u not in new_unit_set]
```
```diff
@@ -318,15 +302,15 @@ async def extract_entities_batch_optimized(
         new_units_to_link = new_units[-MAX_LINKS_PER_ENTITY:] if len(new_units) > MAX_LINKS_PER_ENTITY else new_units
         for i, unit_id_1 in enumerate(new_units_to_link):
             for unit_id_2 in new_units_to_link[i+1:]:
-                links.append((unit_id_1, unit_id_2,
-                links.append((unit_id_2, unit_id_1,
+                links.append(EntityLink(from_unit_id=to_uuid(unit_id_1), to_unit_id=to_uuid(unit_id_2), entity_id=entity_uuid))
+                links.append(EntityLink(from_unit_id=to_uuid(unit_id_2), to_unit_id=to_uuid(unit_id_1), entity_id=entity_uuid))

         # Link new units to LIMITED existing units (most recent)
         existing_to_link = existing_units[-MAX_LINKS_PER_ENTITY:]  # Take most recent
         for new_unit in new_units:
             for existing_unit in existing_to_link:
-                links.append((new_unit, existing_unit,
-                links.append((existing_unit, new_unit,
+                links.append(EntityLink(from_unit_id=to_uuid(new_unit), to_unit_id=to_uuid(existing_unit), entity_id=entity_uuid))
+                links.append(EntityLink(from_unit_id=to_uuid(existing_unit), to_unit_id=to_uuid(new_unit), entity_id=entity_uuid))

     _log(log_buffer, f" [6.3.3] Generate {len(links)} links: {time.time() - link_gen_start:.3f}s", level='debug')
     _log(log_buffer, f" [6.3] Entity link creation: {len(links)} links for {len(all_entity_ids)} unique entities in {time.time() - substep_start:.3f}s", level='debug')
```
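The loops above now emit typed `EntityLink` records in both directions and cap fan-out at `MAX_LINKS_PER_ENTITY`. A self-contained sketch of that pairing logic follows; the unit IDs are randomly generated just to show the shape, and `EntityLink` is the dataclass added to `hindsight_api/engine/retain/types.py` later in this diff:

```python
from typing import List
from uuid import UUID, uuid4

from hindsight_api.engine.retain.types import EntityLink

def pair_units_for_entity(units: List[UUID], entity_id: UUID, max_links: int = 50) -> List[EntityLink]:
    """Create bidirectional entity links between the most recent units sharing an entity."""
    to_link = units[-max_links:] if len(units) > max_links else units
    links: List[EntityLink] = []
    for i, a in enumerate(to_link):
        for b in to_link[i + 1:]:
            links.append(EntityLink(from_unit_id=a, to_unit_id=b, entity_id=entity_id))
            links.append(EntityLink(from_unit_id=b, to_unit_id=a, entity_id=entity_id))
    return links

# Three units sharing one entity yield 3 pairs, i.e. 6 directed links.
assert len(pair_units_for_entity([uuid4() for _ in range(3)], uuid4())) == 6
```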
```diff
@@ -346,7 +330,7 @@ async def create_temporal_links_batch_per_fact(
     unit_ids: List[str],
     time_window_hours: int = 24,
     log_buffer: List[str] = None,
-):
+) -> int:
     """
     Create temporal links for multiple units, each with their own event_date.

@@ -359,9 +343,12 @@ async def create_temporal_links_batch_per_fact(
         unit_ids: List of unit IDs
         time_window_hours: Time window in hours for temporal links
         log_buffer: Optional buffer for logging
+
+    Returns:
+        Number of temporal links created
     """
     if not unit_ids:
-        return
+        return 0

     try:
         import time as time_mod
@@ -417,6 +404,8 @@ async def create_temporal_links_batch_per_fact(
         )
         _log(log_buffer, f" [7.4] Insert {len(links)} temporal links: {time_mod.time() - insert_start:.3f}s")

+        return len(links)
+
     except Exception as e:
         logger.error(f"Failed to create temporal links: {str(e)}")
         import traceback
```
```diff
@@ -432,7 +421,7 @@ async def create_semantic_links_batch(
     top_k: int = 5,
     threshold: float = 0.7,
     log_buffer: List[str] = None,
-):
+) -> int:
     """
     Create semantic links for multiple units efficiently.

@@ -446,9 +435,12 @@ async def create_semantic_links_batch(
         top_k: Number of top similar units to link
         threshold: Minimum similarity threshold
         log_buffer: Optional buffer for logging
+
+    Returns:
+        Number of semantic links created
     """
     if not unit_ids or not embeddings:
-        return
+        return 0

     try:
         import time as time_mod
@@ -539,6 +531,8 @@ async def create_semantic_links_batch(
         )
         _log(log_buffer, f" [8.3] Insert {len(all_links)} semantic links: {time_mod.time() - insert_start:.3f}s")

+        return len(all_links)
+
     except Exception as e:
         logger.error(f"Failed to create semantic links: {str(e)}")
         import traceback
@@ -546,7 +540,7 @@ async def create_semantic_links_batch(
         raise


-async def insert_entity_links_batch(conn, links: List[
+async def insert_entity_links_batch(conn, links: List[EntityLink], chunk_size: int = 50000):
     """
     Insert all entity links using COPY to temp table + INSERT for maximum speed.

```
```diff
@@ -556,7 +550,7 @@ async def insert_entity_links_batch(conn, links: List[tuple], chunk_size: int =

     Args:
         conn: Database connection
-        links: List of
+        links: List of EntityLink objects
         chunk_size: Number of rows per batch (default 50000)
     """
     if not links:
@@ -585,16 +579,16 @@ async def insert_entity_links_batch(conn, links: List[tuple], chunk_size: int =
         await conn.execute("TRUNCATE _temp_entity_links")
         logger.debug(f" [9.2] Truncate temp table: {time_mod.time() - truncate_start:.3f}s")

-        # Convert
+        # Convert EntityLink objects to tuples for COPY
         convert_start = time_mod.time()
         records = []
-        for
+        for link in links:
             records.append((
-
-
-                link_type,
-                weight,
-
+                link.from_unit_id,
+                link.to_unit_id,
+                link.link_type,
+                link.weight,
+                link.entity_id
             ))
         logger.debug(f" [9.3] Convert {len(records)} records: {time_mod.time() - convert_start:.3f}s")

```
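Per its docstring, `insert_entity_links_batch` stages the converted tuples with COPY into a `_temp_entity_links` table before the final INSERT. A rough sketch of that staging step using asyncpg's `copy_records_to_table` is shown below; the temp-table column names and the destination table are assumptions, and only the tuple order (from, to, type, weight, entity) is taken from the diff:

```python
from typing import List

import asyncpg

from hindsight_api.engine.retain.types import EntityLink

async def stage_entity_links(conn: asyncpg.Connection, links: List[EntityLink]) -> None:
    # Flatten dataclasses into the column order used by the temp table.
    records = [
        (l.from_unit_id, l.to_unit_id, l.link_type, l.weight, l.entity_id)
        for l in links
    ]
    await conn.execute("TRUNCATE _temp_entity_links")
    # COPY is far faster than executemany for tens of thousands of rows.
    await conn.copy_records_to_table(
        "_temp_entity_links",
        records=records,
        columns=["from_unit_id", "to_unit_id", "link_type", "weight", "entity_id"],  # assumed names
    )
    # The real function then INSERTs from the temp table into the links table
    # (the target table name is not shown in this hunk).
```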
```diff
@@ -0,0 +1,264 @@
+"""
+Observation regeneration for retain pipeline.
+
+Regenerates entity observations as part of the retain transaction.
+"""
+import logging
+import time
+import uuid
+from datetime import datetime, timezone
+from typing import List, Dict, Optional
+
+from ..search import observation_utils
+from . import embedding_utils
+from ..db_utils import acquire_with_retry
+from .types import EntityLink
+
+logger = logging.getLogger(__name__)
+
+
+def utcnow():
+    """Get current UTC time."""
+    return datetime.now(timezone.utc)
+
+
+# Simple dataclass-like container for facts (avoid importing from memory_engine)
+class MemoryFactForObservation:
+    def __init__(self, id: str, text: str, fact_type: str, context: str, occurred_start: Optional[str]):
+        self.id = id
+        self.text = text
+        self.fact_type = fact_type
+        self.context = context
+        self.occurred_start = occurred_start
+
+
+async def regenerate_observations_batch(
+    conn,
+    embeddings_model,
+    llm_config,
+    bank_id: str,
+    entity_links: List[EntityLink],
+    log_buffer: List[str] = None
+) -> None:
+    """
+    Regenerate observations for top entities in this batch.
+
+    Called INSIDE the retain transaction for atomicity - if observations
+    fail, the entire retain batch is rolled back.
+
+    Args:
+        conn: Database connection (from the retain transaction)
+        embeddings_model: Embeddings model for generating observation embeddings
+        llm_config: LLM configuration for observation extraction
+        bank_id: Bank identifier
+        entity_links: Entity links from this batch
+        log_buffer: Optional log buffer for timing
+    """
+    TOP_N_ENTITIES = 5
+    MIN_FACTS_THRESHOLD = 5
+
+    if not entity_links:
+        return
+
+    # Count mentions per entity in this batch
+    entity_mention_counts: Dict[str, int] = {}
+    for link in entity_links:
+        if link.entity_id:
+            entity_id = str(link.entity_id)
+            entity_mention_counts[entity_id] = entity_mention_counts.get(entity_id, 0) + 1
+
+    if not entity_mention_counts:
+        return
+
+    # Sort by mention count descending and take top N
+    sorted_entities = sorted(
+        entity_mention_counts.items(),
+        key=lambda x: x[1],
+        reverse=True
+    )
+    entities_to_process = [e[0] for e in sorted_entities[:TOP_N_ENTITIES]]
+
+    obs_start = time.time()
+
+    # Convert to UUIDs
+    entity_uuids = [uuid.UUID(eid) if isinstance(eid, str) else eid for eid in entities_to_process]
+
+    # Batch query for entity names
+    entity_rows = await conn.fetch(
+        """
+        SELECT id, canonical_name FROM entities
+        WHERE id = ANY($1) AND bank_id = $2
+        """,
+        entity_uuids, bank_id
+    )
+    entity_names = {row['id']: row['canonical_name'] for row in entity_rows}
+
+    # Batch query for fact counts
+    fact_counts = await conn.fetch(
+        """
+        SELECT ue.entity_id, COUNT(*) as cnt
+        FROM unit_entities ue
+        JOIN memory_units mu ON ue.unit_id = mu.id
+        WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
+        GROUP BY ue.entity_id
+        """,
+        entity_uuids, bank_id
+    )
+    entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+
+    # Filter entities that meet the threshold
+    entities_with_names = []
+    for entity_id in entities_to_process:
+        entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+        if entity_uuid not in entity_names:
+            continue
+        fact_count = entity_fact_counts.get(entity_uuid, 0)
+        if fact_count >= MIN_FACTS_THRESHOLD:
+            entities_with_names.append((entity_id, entity_names[entity_uuid]))
+
+    if not entities_with_names:
+        return
+
+    # Process entities SEQUENTIALLY (asyncpg doesn't allow concurrent queries on same connection)
+    # We must use the same connection to stay in the retain transaction
+    total_observations = 0
+
+    for entity_id, entity_name in entities_with_names:
+        try:
+            obs_ids = await _regenerate_entity_observations(
+                conn, embeddings_model, llm_config,
+                bank_id, entity_id, entity_name
+            )
+            total_observations += len(obs_ids)
+        except Exception as e:
+            logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
+
+    obs_time = time.time() - obs_start
+    if log_buffer is not None:
+        log_buffer.append(f"[11] Observations: {total_observations} observations for {len(entities_with_names)} entities in {obs_time:.3f}s")
+
+
+async def _regenerate_entity_observations(
+    conn,
+    embeddings_model,
+    llm_config,
+    bank_id: str,
+    entity_id: str,
+    entity_name: str
+) -> List[str]:
+    """
+    Regenerate observations for a single entity.
+
+    Uses the provided connection (part of retain transaction).
+
+    Args:
+        conn: Database connection (from the retain transaction)
+        embeddings_model: Embeddings model
+        llm_config: LLM configuration
+        bank_id: Bank identifier
+        entity_id: Entity UUID
+        entity_name: Canonical name of the entity
+
+    Returns:
+        List of created observation IDs
+    """
+    entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+
+    # Get all facts mentioning this entity (exclude observations themselves)
+    rows = await conn.fetch(
+        """
+        SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
+        FROM memory_units mu
+        JOIN unit_entities ue ON mu.id = ue.unit_id
+        WHERE mu.bank_id = $1
+          AND ue.entity_id = $2
+          AND mu.fact_type IN ('world', 'experience')
+        ORDER BY mu.occurred_start DESC
+        LIMIT 50
+        """,
+        bank_id, entity_uuid
+    )
+
+    if not rows:
+        return []
+
+    # Convert to fact objects for observation extraction
+    facts = []
+    for row in rows:
+        occurred_start = row['occurred_start'].isoformat() if row['occurred_start'] else None
+        facts.append(MemoryFactForObservation(
+            id=str(row['id']),
+            text=row['text'],
+            fact_type=row['fact_type'],
+            context=row['context'],
+            occurred_start=occurred_start
+        ))
+
+    # Extract observations using LLM
+    observations = await observation_utils.extract_observations_from_facts(
+        llm_config,
+        entity_name,
+        facts
+    )
+
+    if not observations:
+        return []
+
+    # Delete old observations for this entity
+    await conn.execute(
+        """
+        DELETE FROM memory_units
+        WHERE id IN (
+            SELECT mu.id
+            FROM memory_units mu
+            JOIN unit_entities ue ON mu.id = ue.unit_id
+            WHERE mu.bank_id = $1
+              AND mu.fact_type = 'observation'
+              AND ue.entity_id = $2
+        )
+        """,
+        bank_id, entity_uuid
+    )
+
+    # Generate embeddings for new observations
+    embeddings = await embedding_utils.generate_embeddings_batch(
+        embeddings_model, observations
+    )
+
+    # Insert new observations
+    current_time = utcnow()
+    created_ids = []
+
+    for obs_text, embedding in zip(observations, embeddings):
+        result = await conn.fetchrow(
+            """
+            INSERT INTO memory_units (
+                bank_id, text, embedding, context, event_date,
+                occurred_start, occurred_end, mentioned_at,
+                fact_type, access_count
+            )
+            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
+            RETURNING id
+            """,
+            bank_id,
+            obs_text,
+            str(embedding),
+            f"observation about {entity_name}",
+            current_time,
+            current_time,
+            current_time,
+            current_time
+        )
+        obs_id = str(result['id'])
+        created_ids.append(obs_id)
+
+        # Link observation to entity
+        await conn.execute(
+            """
+            INSERT INTO unit_entities (unit_id, entity_id)
+            VALUES ($1, $2)
+            """,
+            uuid.UUID(obs_id), entity_uuid
+        )
+
+    return created_ids
```
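The key design point of the new module is that it runs on the caller's connection, so it inherits the retain transaction: if observation extraction or any INSERT fails, the whole batch rolls back together. A minimal usage sketch, assuming an asyncpg pool and `entity_links` built earlier in the pipeline:

```python
import asyncpg

from hindsight_api.engine.retain import observation_regeneration

async def retain_with_observations(pool: asyncpg.Pool, embeddings_model, llm_config,
                                   bank_id: str, entity_links, log_buffer: list) -> None:
    async with pool.acquire() as conn:
        async with conn.transaction():
            # ... store facts, create temporal/semantic/entity/causal links ...
            # Same connection, same transaction: a failure here rolls everything back.
            await observation_regeneration.regenerate_observations_batch(
                conn, embeddings_model, llm_config, bank_id, entity_links, log_buffer
            )
```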
```diff
@@ -17,7 +17,7 @@ def utcnow():
     """Get current UTC time."""
     return datetime.now(timezone.utc)

-from .types import RetainContent, ExtractedFact, ProcessedFact
+from .types import RetainContent, ExtractedFact, ProcessedFact, EntityLink
 from . import (
     fact_extraction,
     embedding_processing,
@@ -25,7 +25,8 @@ from . import (
     chunk_storage,
     fact_storage,
     entity_processing,
-    link_creation
+    link_creation,
+    observation_regeneration
 )

 logger = logging.getLogger(__name__)
@@ -39,7 +40,6 @@ async def retain_batch(
     task_backend,
     format_date_fn,
     duplicate_checker_fn,
-    regenerate_observations_fn,
     bank_id: str,
     contents_dicts: List[Dict[str, Any]],
     document_id: Optional[str] = None,
@@ -58,7 +58,6 @@ async def retain_batch(
         task_backend: Task backend for background jobs
         format_date_fn: Function to format datetime to readable string
         duplicate_checker_fn: Function to check for duplicate facts
-        regenerate_observations_fn: Async function to regenerate observations for entities
         bank_id: Bank identifier
         contents_dicts: List of content dictionaries
         document_id: Optional document ID
```
```diff
@@ -288,50 +287,59 @@ async def retain_batch(

     # Create temporal links
     step_start = time.time()
-    await link_creation.create_temporal_links_batch(conn, bank_id, unit_ids)
-    log_buffer.append(f"[7] Temporal links: {time.time() - step_start:.3f}s")
+    temporal_link_count = await link_creation.create_temporal_links_batch(conn, bank_id, unit_ids)
+    log_buffer.append(f"[7] Temporal links: {temporal_link_count} links in {time.time() - step_start:.3f}s")

     # Create semantic links
     step_start = time.time()
     embeddings_for_links = [fact.embedding for fact in non_duplicate_facts]
-    await link_creation.create_semantic_links_batch(conn, bank_id, unit_ids, embeddings_for_links)
-    log_buffer.append(f"[8] Semantic links: {time.time() - step_start:.3f}s")
+    semantic_link_count = await link_creation.create_semantic_links_batch(conn, bank_id, unit_ids, embeddings_for_links)
+    log_buffer.append(f"[8] Semantic links: {semantic_link_count} links in {time.time() - step_start:.3f}s")

     # Insert entity links
     step_start = time.time()
     if entity_links:
         await entity_processing.insert_entity_links_batch(conn, entity_links)
-    log_buffer.append(f"[9] Entity links: {time.time() - step_start:.3f}s")
+    log_buffer.append(f"[9] Entity links: {len(entity_links) if entity_links else 0} links in {time.time() - step_start:.3f}s")

     # Create causal links
     step_start = time.time()
     causal_link_count = await link_creation.create_causal_links_batch(conn, unit_ids, non_duplicate_facts)
     log_buffer.append(f"[10] Causal links: {causal_link_count} links in {time.time() - step_start:.3f}s")

+    # Regenerate observations INSIDE transaction for atomicity
+    await observation_regeneration.regenerate_observations_batch(
+        conn,
+        embeddings_model,
+        llm_config,
+        bank_id,
+        entity_links,
+        log_buffer
+    )
+
     # Map results back to original content items
     result_unit_ids = _map_results_to_contents(
         contents, extracted_facts, is_duplicate_flags, unit_ids
     )

-
-    log_buffer.append(f"{'='*60}")
-    log_buffer.append(f"RETAIN_BATCH COMPLETE: {len(unit_ids)} units in {total_time:.3f}s")
-    if document_ids_added:
-        log_buffer.append(f"Documents: {', '.join(document_ids_added)}")
-    log_buffer.append(f"{'='*60}")
-
-    logger.info("\n" + "\n".join(log_buffer) + "\n")
-
-    # Trigger background tasks AFTER transaction commits
+    # Trigger background tasks AFTER transaction commits (opinion reinforcement only)
     await _trigger_background_tasks(
         task_backend,
-        regenerate_observations_fn,
         bank_id,
         unit_ids,
-        non_duplicate_facts
-        entity_links
+        non_duplicate_facts
     )

+    # Log final summary
+    total_time = time.time() - start_time
+    log_buffer.append(f"{'='*60}")
+    log_buffer.append(f"RETAIN_BATCH COMPLETE: {len(unit_ids)} units in {total_time:.3f}s")
+    if document_ids_added:
+        log_buffer.append(f"Documents: {', '.join(document_ids_added)}")
+    log_buffer.append(f"{'='*60}")
+
+    logger.info("\n" + "\n".join(log_buffer) + "\n")
+
     return result_unit_ids


```
```diff
@@ -367,13 +375,11 @@ def _map_results_to_contents(

 async def _trigger_background_tasks(
     task_backend,
-    regenerate_observations_fn,
     bank_id: str,
     unit_ids: List[str],
     facts: List[ProcessedFact],
-    entity_links: List
 ) -> None:
-    """Trigger opinion reinforcement
+    """Trigger opinion reinforcement as background task (after transaction commits)."""
     # Trigger opinion reinforcement if there are entities
     fact_entities = [[e.name for e in fact.entities] for fact in facts]
     if any(fact_entities):
@@ -384,22 +390,3 @@ async def _trigger_background_tasks(
             'unit_texts': [fact.fact_text for fact in facts],
             'unit_entities': fact_entities
         })
-
-    # Regenerate observations synchronously for top entities
-    TOP_N_ENTITIES = 5
-    MIN_FACTS_THRESHOLD = 5
-
-    if entity_links and regenerate_observations_fn:
-        unique_entity_ids = set()
-        for link in entity_links:
-            # links are tuples: (unit_id, entity_id, confidence)
-            if len(link) >= 2 and link[1]:
-                unique_entity_ids.add(str(link[1]))
-
-        if unique_entity_ids:
-            # Run observation regeneration synchronously
-            await regenerate_observations_fn(
-                bank_id=bank_id,
-                entity_ids=list(unique_entity_ids)[:TOP_N_ENTITIES],
-                min_facts=MIN_FACTS_THRESHOLD
-            )
```
```diff
@@ -176,6 +176,20 @@ class ProcessedFact:
     )


+@dataclass
+class EntityLink:
+    """
+    Link between two memory units through a shared entity.
+
+    Used for entity-based graph connections in the memory graph.
+    """
+    from_unit_id: UUID
+    to_unit_id: UUID
+    entity_id: UUID
+    link_type: str = 'entity'
+    weight: float = 1.0
+
+
 @dataclass
 class RetainBatch:
     """
```
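The new `EntityLink` dataclass replaces the positional tuples the link helpers previously passed around; `link_type` and `weight` default to `'entity'` and `1.0`, so most call sites only supply the three UUIDs:

```python
from uuid import uuid4

from hindsight_api.engine.retain.types import EntityLink

link = EntityLink(from_unit_id=uuid4(), to_unit_id=uuid4(), entity_id=uuid4())
assert link.link_type == "entity" and link.weight == 1.0
```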
```diff
@@ -10,10 +10,8 @@ class CrossEncoderReranker:
     """
     Neural reranking using a cross-encoder model.

-
-
-    - Small model (80MB)
-    - Trained for passage re-ranking
+    Configured via environment variables (see cross_encoder.py).
+    Default local model is cross-encoder/ms-marco-MiniLM-L-6-v2.
     """

     def __init__(self, cross_encoder=None):
@@ -21,14 +19,12 @@ class CrossEncoderReranker:
         Initialize cross-encoder reranker.

         Args:
-            cross_encoder:
-
-                (loaded lazily for faster startup)
+            cross_encoder: CrossEncoderModel instance. If None, creates one from
+                environment variables (defaults to local provider)
         """
         if cross_encoder is None:
-            from hindsight_api.engine.cross_encoder import
-
-            cross_encoder = SentenceTransformersCrossEncoder()
+            from hindsight_api.engine.cross_encoder import create_cross_encoder_from_env
+            cross_encoder = create_cross_encoder_from_env()
         self.cross_encoder = cross_encoder

     def rerank(
```
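With this change `CrossEncoderReranker` no longer hard-codes `SentenceTransformersCrossEncoder`; a bare constructor builds whatever cross-encoder the environment configures, defaulting to the local ms-marco model. A usage sketch, assuming the class lives in `hindsight_api/engine/search/reranking.py` as the file list above suggests (the specific environment variable names are defined in `cross_encoder.py` and not shown in this diff):

```python
from hindsight_api.engine.cross_encoder import create_cross_encoder_from_env
from hindsight_api.engine.search.reranking import CrossEncoderReranker

# Default path: provider and model are read from the environment.
reranker = CrossEncoderReranker()

# Equivalent explicit form, handy when one cross-encoder instance is shared.
shared = create_cross_encoder_from_env()
reranker = CrossEncoderReranker(cross_encoder=shared)
```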
|