hindsight-api 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. hindsight_api/__init__.py +38 -0
  2. hindsight_api/api/__init__.py +105 -0
  3. hindsight_api/api/http.py +1872 -0
  4. hindsight_api/api/mcp.py +157 -0
  5. hindsight_api/engine/__init__.py +47 -0
  6. hindsight_api/engine/cross_encoder.py +97 -0
  7. hindsight_api/engine/db_utils.py +93 -0
  8. hindsight_api/engine/embeddings.py +113 -0
  9. hindsight_api/engine/entity_resolver.py +575 -0
  10. hindsight_api/engine/llm_wrapper.py +269 -0
  11. hindsight_api/engine/memory_engine.py +3095 -0
  12. hindsight_api/engine/query_analyzer.py +519 -0
  13. hindsight_api/engine/response_models.py +222 -0
  14. hindsight_api/engine/retain/__init__.py +50 -0
  15. hindsight_api/engine/retain/bank_utils.py +423 -0
  16. hindsight_api/engine/retain/chunk_storage.py +82 -0
  17. hindsight_api/engine/retain/deduplication.py +104 -0
  18. hindsight_api/engine/retain/embedding_processing.py +62 -0
  19. hindsight_api/engine/retain/embedding_utils.py +54 -0
  20. hindsight_api/engine/retain/entity_processing.py +90 -0
  21. hindsight_api/engine/retain/fact_extraction.py +1027 -0
  22. hindsight_api/engine/retain/fact_storage.py +176 -0
  23. hindsight_api/engine/retain/link_creation.py +121 -0
  24. hindsight_api/engine/retain/link_utils.py +651 -0
  25. hindsight_api/engine/retain/orchestrator.py +405 -0
  26. hindsight_api/engine/retain/types.py +206 -0
  27. hindsight_api/engine/search/__init__.py +15 -0
  28. hindsight_api/engine/search/fusion.py +122 -0
  29. hindsight_api/engine/search/observation_utils.py +132 -0
  30. hindsight_api/engine/search/reranking.py +103 -0
  31. hindsight_api/engine/search/retrieval.py +503 -0
  32. hindsight_api/engine/search/scoring.py +161 -0
  33. hindsight_api/engine/search/temporal_extraction.py +64 -0
  34. hindsight_api/engine/search/think_utils.py +255 -0
  35. hindsight_api/engine/search/trace.py +215 -0
  36. hindsight_api/engine/search/tracer.py +447 -0
  37. hindsight_api/engine/search/types.py +160 -0
  38. hindsight_api/engine/task_backend.py +223 -0
  39. hindsight_api/engine/utils.py +203 -0
  40. hindsight_api/metrics.py +227 -0
  41. hindsight_api/migrations.py +163 -0
  42. hindsight_api/models.py +309 -0
  43. hindsight_api/pg0.py +425 -0
  44. hindsight_api/web/__init__.py +12 -0
  45. hindsight_api/web/server.py +143 -0
  46. hindsight_api-0.0.13.dist-info/METADATA +41 -0
  47. hindsight_api-0.0.13.dist-info/RECORD +48 -0
  48. hindsight_api-0.0.13.dist-info/WHEEL +4 -0
hindsight_api/engine/retain/link_utils.py
@@ -0,0 +1,651 @@
+ """
+ Link creation utilities for temporal, semantic, and entity links.
+ """
+
+ import time
+ import logging
+ from typing import List
+ from datetime import timedelta, datetime, timezone
+
+ logger = logging.getLogger(__name__)
+
+
+ def _normalize_datetime(dt):
+     """Normalize datetime to be timezone-aware (UTC) for consistent comparison."""
+     if dt is None:
+         return None
+     if dt.tzinfo is None:
+         # Naive datetime - assume UTC
+         return dt.replace(tzinfo=timezone.utc)
+     return dt
+
+
+ def compute_temporal_links(
+     new_units: dict,
+     candidates: list,
+     time_window_hours: int = 24,
+ ) -> list:
+     """
+     Compute temporal links between new units and candidate neighbors.
+
+     This is a pure function that takes query results and returns link tuples,
+     making it easy to test without database access.
+
+     Args:
+         new_units: Dict mapping unit_id (str) to event_date (datetime)
+         candidates: List of dicts with 'id' and 'event_date' keys (candidate neighbors)
+         time_window_hours: Time window in hours for temporal links
+
+     Returns:
+         List of tuples: (from_unit_id, to_unit_id, 'temporal', weight, None)
+     """
+     if not new_units:
+         return []
+
+     links = []
+     for unit_id, unit_event_date in new_units.items():
+         # Normalize unit_event_date for consistent comparison
+         unit_event_date_norm = _normalize_datetime(unit_event_date)
+
+         # Calculate time window bounds with overflow protection
+         try:
+             time_lower = unit_event_date_norm - timedelta(hours=time_window_hours)
+         except OverflowError:
+             time_lower = datetime.min.replace(tzinfo=timezone.utc)
+         try:
+             time_upper = unit_event_date_norm + timedelta(hours=time_window_hours)
+         except OverflowError:
+             time_upper = datetime.max.replace(tzinfo=timezone.utc)
+
+         # Filter candidates within this unit's time window
+         matching_neighbors = [
+             (row['id'], row['event_date'])
+             for row in candidates
+             if time_lower <= _normalize_datetime(row['event_date']) <= time_upper
+         ][:10]  # Limit to top 10
+
+         for recent_id, recent_event_date in matching_neighbors:
+             # Calculate temporal proximity weight
+             time_diff_hours = abs((unit_event_date_norm - _normalize_datetime(recent_event_date)).total_seconds() / 3600)
+             weight = max(0.3, 1.0 - (time_diff_hours / time_window_hours))
+             links.append((unit_id, str(recent_id), 'temporal', weight, None))
+
+     return links
+
+
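A minimal usage sketch for compute_temporal_links (unit IDs and dates below are illustrative; candidate rows are shown as plain dicts, which the function reads the same way as database records):

    from datetime import datetime, timezone, timedelta

    base = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
    new_units = {"unit-a": base}
    candidates = [
        {"id": "unit-b", "event_date": base + timedelta(hours=6)},   # inside the 24h window
        {"id": "unit-c", "event_date": base + timedelta(hours=40)},  # outside the window, filtered out
    ]
    links = compute_temporal_links(new_units, candidates, time_window_hours=24)
    # [("unit-a", "unit-b", "temporal", 0.75, None)]  -- weight = 1.0 - 6/24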
+ def compute_temporal_query_bounds(
+     new_units: dict,
+     time_window_hours: int = 24,
+ ) -> tuple:
+     """
+     Compute the min/max date bounds for querying temporal neighbors.
+
+     Args:
+         new_units: Dict mapping unit_id (str) to event_date (datetime)
+         time_window_hours: Time window in hours
+
+     Returns:
+         Tuple of (min_date, max_date) with overflow protection
+     """
+     if not new_units:
+         return None, None
+
+     # Normalize all dates to be timezone-aware to avoid comparison issues
+     all_dates = [_normalize_datetime(d) for d in new_units.values()]
+
+     try:
+         min_date = min(all_dates) - timedelta(hours=time_window_hours)
+     except OverflowError:
+         min_date = datetime.min.replace(tzinfo=timezone.utc)
+
+     try:
+         max_date = max(all_dates) + timedelta(hours=time_window_hours)
+     except OverflowError:
+         max_date = datetime.max.replace(tzinfo=timezone.utc)
+
+     return min_date, max_date
+
+
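These bounds feed the single candidate query in create_temporal_links_batch_per_fact further down; a quick sketch of the expected output (illustrative dates):

    from datetime import datetime, timezone

    units = {
        "unit-a": datetime(2024, 1, 1, tzinfo=timezone.utc),
        "unit-b": datetime(2024, 1, 3, tzinfo=timezone.utc),
    }
    lo, hi = compute_temporal_query_bounds(units, time_window_hours=24)
    # lo == 2023-12-31 00:00 UTC, hi == 2024-01-04 00:00 UTC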
+ def _log(log_buffer, message, level='info'):
+     """Helper to log to buffer if available, otherwise use logger."""
+     if log_buffer is not None:
+         log_buffer.append(message)
+     else:
+         if level == 'info':
+             logger.info(message)
+         else:
+             logger.log(logging.WARNING if level == 'warning' else logging.ERROR, message)
+
+
+ async def extract_entities_batch_optimized(
+     entity_resolver,
+     conn,
+     bank_id: str,
+     unit_ids: List[str],
+     sentences: List[str],
+     context: str,
+     fact_dates: List,
+     llm_entities: List[List[dict]],
+     log_buffer: List[str] = None,
+ ) -> List[tuple]:
+     """
+     Process LLM-extracted entities for ALL facts in batch.
+
+     Uses entities provided by the LLM (no spaCy needed), then resolves
+     and links them in bulk.
+
+     Args:
+         entity_resolver: EntityResolver instance for entity resolution
+         conn: Database connection
+         bank_id: Bank identifier
+         unit_ids: List of unit IDs
+         sentences: List of fact sentences
+         context: Context string
+         fact_dates: List of fact dates
+         llm_entities: List of entity lists from LLM extraction
+         log_buffer: Optional buffer for logging
+
+     Returns:
+         List of tuples for batch insertion: (from_unit_id, to_unit_id, link_type, weight, entity_id)
+     """
+     try:
+         # Step 1: Convert LLM entities to the format expected by entity resolver
+         substep_start = time.time()
+         all_entities = []
+         for entity_list in llm_entities:
+             # Convert List[Entity] or List[dict] to List[Dict] format
+             formatted_entities = []
+             for ent in entity_list:
+                 # Handle both Entity objects and dicts
+                 if hasattr(ent, 'text'):
+                     # Entity objects only have 'text', default type to 'CONCEPT'
+                     formatted_entities.append({'text': ent.text, 'type': 'CONCEPT'})
+                 elif isinstance(ent, dict):
+                     formatted_entities.append({'text': ent.get('text', ''), 'type': ent.get('type', 'CONCEPT')})
+             all_entities.append(formatted_entities)
+
+         total_entities = sum(len(ents) for ents in all_entities)
+         _log(log_buffer, f" [6.1] Process LLM entities: {total_entities} entities from {len(sentences)} facts in {time.time() - substep_start:.3f}s")
+
+         # Step 2: Resolve entities in BATCH (much faster!)
+         substep_start = time.time()
+         step_6_2_start = time.time()
+
+         # [6.2.1] Prepare all entities for batch resolution
+         substep_6_2_1_start = time.time()
+         all_entities_flat = []
+         entity_to_unit = []  # Maps flat index to (unit_id, local_index)
+
+         for unit_id, entities, fact_date in zip(unit_ids, all_entities, fact_dates):
+             if not entities:
+                 continue
+
+             for local_idx, entity in enumerate(entities):
+                 all_entities_flat.append({
+                     'text': entity['text'],
+                     'type': entity['type'],
+                     'nearby_entities': entities,
+                 })
+                 entity_to_unit.append((unit_id, local_idx, fact_date))
+         _log(log_buffer, f" [6.2.1] Prepare entities: {len(all_entities_flat)} entities in {time.time() - substep_6_2_1_start:.3f}s")
+
+         # Resolve ALL entities in one batch call
+         if all_entities_flat:
+             # [6.2.2] Batch resolve entities
+             substep_6_2_2_start = time.time()
+             # Group by date for batch resolution (round to hour to reduce buckets)
+             entities_by_date = {}
+             for idx, (unit_id, local_idx, fact_date) in enumerate(entity_to_unit):
+                 # Round to hour to group facts from same time period
+                 date_key = fact_date.replace(minute=0, second=0, microsecond=0)
+                 if date_key not in entities_by_date:
+                     entities_by_date[date_key] = []
+                 entities_by_date[date_key].append((idx, all_entities_flat[idx]))
+
+             _log(log_buffer, f" [6.2.2] Grouped into {len(entities_by_date)} date buckets, resolving in parallel...")
+
+             # Resolve all date groups in PARALLEL using asyncio.gather
+             resolved_entity_ids = [None] * len(all_entities_flat)
+
+             # Prepare all resolution tasks
+             async def resolve_date_bucket(date_idx, date_key, entities_group):
+                 date_bucket_start = time.time()
+                 indices = [idx for idx, _ in entities_group]
+                 entities_data = [entity_data for _, entity_data in entities_group]
+                 # Use the first fact's date for this bucket (all should be in same hour)
+                 fact_date = entity_to_unit[indices[0]][2]
+
+                 # Pass conn=None to let each parallel task acquire its own connection
+                 batch_resolved = await entity_resolver.resolve_entities_batch(
+                     bank_id=bank_id,
+                     entities_data=entities_data,
+                     context=context,
+                     unit_event_date=fact_date,
+                     conn=None  # Each task gets its own connection from pool
+                 )
+
+                 if len(entities_by_date) <= 10:  # Only log individual buckets if there aren't too many
+                     _log(log_buffer, f" [6.2.2.{date_idx}] Resolved {len(entities_data)} entities in {time.time() - date_bucket_start:.3f}s")
+
+                 return indices, batch_resolved
+
+             # Execute all resolution tasks in parallel
+             import asyncio
+             tasks = [
+                 resolve_date_bucket(date_idx, date_key, entities_group)
+                 for date_idx, (date_key, entities_group) in enumerate(entities_by_date.items(), 1)
+             ]
+             results = await asyncio.gather(*tasks)
+
+             # Map results back to resolved_entity_ids
+             for indices, batch_resolved in results:
+                 for idx, entity_id in zip(indices, batch_resolved):
+                     resolved_entity_ids[idx] = entity_id
+
+             _log(log_buffer, f" [6.2.2] Resolve entities: {len(all_entities_flat)} entities across {len(entities_by_date)} buckets in {time.time() - substep_6_2_2_start:.3f}s")
+
+             # [6.2.3] Create unit-entity links in BATCH
+             substep_6_2_3_start = time.time()
+             # Map resolved entities back to units and collect all (unit, entity) pairs
+             unit_to_entity_ids = {}
+             unit_entity_pairs = []
+             for idx, (unit_id, local_idx, fact_date) in enumerate(entity_to_unit):
+                 if unit_id not in unit_to_entity_ids:
+                     unit_to_entity_ids[unit_id] = []
+
+                 entity_id = resolved_entity_ids[idx]
+                 unit_to_entity_ids[unit_id].append(entity_id)
+                 unit_entity_pairs.append((unit_id, entity_id))
+
+             # Batch insert all unit-entity links (MUCH faster!)
+             await entity_resolver.link_units_to_entities_batch(unit_entity_pairs, conn=conn)
+             _log(log_buffer, f" [6.2.3] Create unit-entity links (batched): {len(unit_entity_pairs)} links in {time.time() - substep_6_2_3_start:.3f}s")
+
+             _log(log_buffer, f" [6.2] Entity resolution (batched): {len(all_entities_flat)} entities resolved in {time.time() - step_6_2_start:.3f}s")
+         else:
+             unit_to_entity_ids = {}
+             _log(log_buffer, f" [6.2] Entity resolution (batched): 0 entities in {time.time() - step_6_2_start:.3f}s")
+
+         # Step 3: Create entity links between units that share entities
+         substep_start = time.time()
+         # Collect all unique entity IDs
+         all_entity_ids = set()
+         for entity_ids in unit_to_entity_ids.values():
+             all_entity_ids.update(entity_ids)
+
+         _log(log_buffer, f" [6.3] Creating entity links for {len(all_entity_ids)} unique entities...")
+
+         # Find all units that reference these entities (ONE batched query)
+         entity_to_units = {}
+         if all_entity_ids:
+             query_start = time.time()
+             import uuid
+             entity_id_list = [uuid.UUID(eid) if isinstance(eid, str) else eid for eid in all_entity_ids]
+             rows = await conn.fetch(
+                 """
+                 SELECT entity_id, unit_id
+                 FROM unit_entities
+                 WHERE entity_id = ANY($1::uuid[])
+                 """,
+                 entity_id_list
+             )
+             _log(log_buffer, f" [6.3.1] Query unit_entities: {len(rows)} rows in {time.time() - query_start:.3f}s")
+
+             # Group by entity_id
+             group_start = time.time()
+             for row in rows:
+                 entity_id = row['entity_id']
+                 if entity_id not in entity_to_units:
+                     entity_to_units[entity_id] = []
+                 entity_to_units[entity_id].append(row['unit_id'])
+             _log(log_buffer, f" [6.3.2] Group by entity_id: {time.time() - group_start:.3f}s")
+
+         # Create bidirectional links between units that share entities
+         link_gen_start = time.time()
+         links = []
+         for entity_id, units_with_entity in entity_to_units.items():
+             # For each pair of units with this entity, create bidirectional links
+             for i, unit_id_1 in enumerate(units_with_entity):
+                 for unit_id_2 in units_with_entity[i+1:]:
+                     # Bidirectional links
+                     links.append((unit_id_1, unit_id_2, 'entity', 1.0, entity_id))
+                     links.append((unit_id_2, unit_id_1, 'entity', 1.0, entity_id))
+
+         _log(log_buffer, f" [6.3.3] Generate {len(links)} links: {time.time() - link_gen_start:.3f}s")
+         _log(log_buffer, f" [6.3] Entity link creation: {len(links)} links for {len(all_entity_ids)} unique entities in {time.time() - substep_start:.3f}s")
+
+         return links
+
+     except Exception as e:
+         logger.error(f"Failed to extract entities in batch: {str(e)}")
+         import traceback
+         traceback.print_exc()
+         raise
+
+
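A hedged sketch of the call shape and the llm_entities payload this function expects (resolver is an EntityResolver and conn a database connection supplied by the retain pipeline; all literal values below are illustrative):

    from datetime import datetime, timezone

    fact_date = datetime(2024, 1, 1, tzinfo=timezone.utc)
    llm_entities = [
        [{"text": "Alice", "type": "PERSON"}, {"text": "Acme", "type": "ORG"}],  # entities for fact 0
        [{"text": "Acme", "type": "ORG"}],                                       # entities for fact 1
    ]
    entity_links = await extract_entities_batch_optimized(
        entity_resolver=resolver,   # EntityResolver instance, as described in the docstring
        conn=conn,                  # database connection, as elsewhere in this module
        bank_id="bank-123",
        unit_ids=["unit-0", "unit-1"],
        sentences=["Alice joined Acme.", "Acme shipped v2."],
        context="conversation notes",
        fact_dates=[fact_date, fact_date],
        llm_entities=llm_entities,
    )
    # Returns directed tuples like (unit_id, other_unit_id, "entity", 1.0, entity_uuid),
    # one pair per two units that resolve to the same entity.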
+ async def create_temporal_links_batch_per_fact(
+     conn,
+     bank_id: str,
+     unit_ids: List[str],
+     time_window_hours: int = 24,
+     log_buffer: List[str] = None,
+ ):
+     """
+     Create temporal links for multiple units, each with their own event_date.
+
+     Queries the event_date for each unit from the database and creates temporal
+     links based on individual dates (supports per-fact dating).
+
+     Args:
+         conn: Database connection
+         bank_id: Bank identifier
+         unit_ids: List of unit IDs
+         time_window_hours: Time window in hours for temporal links
+         log_buffer: Optional buffer for logging
+     """
+     if not unit_ids:
+         return
+
+     try:
+         import time as time_mod
+
+         # Get the event_date for each new unit
+         fetch_dates_start = time_mod.time()
+         rows = await conn.fetch(
+             """
+             SELECT id, event_date
+             FROM memory_units
+             WHERE id::text = ANY($1)
+             """,
+             unit_ids
+         )
+         new_units = {str(row['id']): row['event_date'] for row in rows}
+         _log(log_buffer, f" [7.1] Fetch event_dates for {len(unit_ids)} units: {time_mod.time() - fetch_dates_start:.3f}s")
+
+         # Fetch ALL potential temporal neighbors in ONE query (much faster!)
+         # Get time range across all units with overflow protection
+         min_date, max_date = compute_temporal_query_bounds(new_units, time_window_hours)
+
+         fetch_neighbors_start = time_mod.time()
+         all_candidates = await conn.fetch(
+             """
+             SELECT id, event_date
+             FROM memory_units
+             WHERE bank_id = $1
+               AND event_date BETWEEN $2 AND $3
+               AND id::text != ALL($4)
+             ORDER BY event_date DESC
+             """,
+             bank_id,
+             min_date,
+             max_date,
+             unit_ids
+         )
+         _log(log_buffer, f" [7.2] Fetch {len(all_candidates)} candidate neighbors (1 query): {time_mod.time() - fetch_neighbors_start:.3f}s")
+
+         # Filter and create links in memory (much faster than N queries)
+         link_gen_start = time_mod.time()
+         links = compute_temporal_links(new_units, all_candidates, time_window_hours)
+         _log(log_buffer, f" [7.3] Generate {len(links)} temporal links: {time_mod.time() - link_gen_start:.3f}s")
+
+         if links:
+             insert_start = time_mod.time()
+             await conn.executemany(
+                 """
+                 INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id)
+                 VALUES ($1, $2, $3, $4, $5)
+                 ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+                 """,
+                 links
+             )
+             _log(log_buffer, f" [7.4] Insert {len(links)} temporal links: {time_mod.time() - insert_start:.3f}s")
+
+     except Exception as e:
+         logger.error(f"Failed to create temporal links: {str(e)}")
+         import traceback
+         traceback.print_exc()
+         raise
+
+
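A sketch of how these batch link writers might be driven together (an asyncpg-style connection pool is assumed from the $1 placeholders and conn.fetch/executemany calls; the wrapper and its names are illustrative, not taken from the package):

    async def link_new_units(pool, bank_id, unit_ids, embeddings, log_buffer=None):
        # Acquire one connection and run both batch writers against it.
        async with pool.acquire() as conn:
            await create_temporal_links_batch_per_fact(
                conn, bank_id, unit_ids, time_window_hours=24, log_buffer=log_buffer
            )
            await create_semantic_links_batch(
                conn, bank_id, unit_ids, embeddings, top_k=5, threshold=0.7, log_buffer=log_buffer
            )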
+ async def create_semantic_links_batch(
+     conn,
+     bank_id: str,
+     unit_ids: List[str],
+     embeddings: List[List[float]],
+     top_k: int = 5,
+     threshold: float = 0.7,
+     log_buffer: List[str] = None,
+ ):
+     """
+     Create semantic links for multiple units efficiently.
+
+     For each unit, finds similar units and creates links.
+
+     Args:
+         conn: Database connection
+         bank_id: Bank identifier
+         unit_ids: List of unit IDs
+         embeddings: List of embedding vectors
+         top_k: Number of top similar units to link
+         threshold: Minimum similarity threshold
+         log_buffer: Optional buffer for logging
+     """
+     if not unit_ids or not embeddings:
+         return
+
+     try:
+         import time as time_mod
+         import numpy as np
+
+         # Fetch ALL existing units with embeddings in ONE query
+         fetch_start = time_mod.time()
+         all_existing = await conn.fetch(
+             """
+             SELECT id, embedding
+             FROM memory_units
+             WHERE bank_id = $1
+               AND embedding IS NOT NULL
+               AND id::text != ALL($2)
+             """,
+             bank_id,
+             unit_ids
+         )
+         _log(log_buffer, f" [8.1] Fetch {len(all_existing)} existing embeddings (1 query): {time_mod.time() - fetch_start:.3f}s")
+
+         # Convert to numpy for vectorized similarity computation
+         compute_start = time_mod.time()
+         all_links = []
+
+         if all_existing:
+             # Convert existing embeddings to numpy array
+             existing_ids = [str(row['id']) for row in all_existing]
+             # Stack embeddings as 2D array: (num_embeddings, embedding_dim)
+             embedding_arrays = []
+             for row in all_existing:
+                 raw_emb = row['embedding']
+                 # Handle different pgvector formats
+                 if isinstance(raw_emb, str):
+                     # Parse string format: "[1.0, 2.0, ...]"
+                     import json
+                     emb = np.array(json.loads(raw_emb), dtype=np.float32)
+                 elif isinstance(raw_emb, (list, tuple)):
+                     emb = np.array(raw_emb, dtype=np.float32)
+                 else:
+                     # Try direct conversion (works for numpy arrays, pgvector objects, etc.)
+                     emb = np.array(raw_emb, dtype=np.float32)
+
+                 # Ensure it's 1D
+                 if emb.ndim != 1:
+                     raise ValueError(f"Expected 1D embedding, got shape {emb.shape}")
+                 embedding_arrays.append(emb)
+
+             if not embedding_arrays:
+                 existing_embeddings = np.array([])
+             elif len(embedding_arrays) == 1:
+                 # Single embedding: reshape to (1, dim)
+                 existing_embeddings = embedding_arrays[0].reshape(1, -1)
+             else:
+                 # Multiple embeddings: vstack
+                 existing_embeddings = np.vstack(embedding_arrays)
+
+             # For each new unit, compute similarities with ALL existing units
+             for unit_id, new_embedding in zip(unit_ids, embeddings):
+                 new_emb_array = np.array(new_embedding)
+
+                 # Compute cosine similarities (dot product for normalized vectors)
+                 similarities = np.dot(existing_embeddings, new_emb_array)
+
+                 # Find top-k above threshold
+                 # Get indices of similarities above threshold
+                 above_threshold = np.where(similarities >= threshold)[0]
+
+                 if len(above_threshold) > 0:
+                     # Sort by similarity (descending) and take top-k
+                     sorted_indices = above_threshold[np.argsort(-similarities[above_threshold])][:top_k]
+
+                     for idx in sorted_indices:
+                         similar_id = existing_ids[idx]
+                         similarity = float(similarities[idx])
+                         all_links.append((unit_id, similar_id, 'semantic', similarity, None))
+
+         _log(log_buffer, f" [8.2] Compute similarities & generate {len(all_links)} semantic links: {time_mod.time() - compute_start:.3f}s")
+
+         if all_links:
+             insert_start = time_mod.time()
+             await conn.executemany(
+                 """
+                 INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id)
+                 VALUES ($1, $2, $3, $4, $5)
+                 ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+                 """,
+                 all_links
+             )
+             _log(log_buffer, f" [8.3] Insert {len(all_links)} semantic links: {time_mod.time() - insert_start:.3f}s")
+
+     except Exception as e:
+         logger.error(f"Failed to create semantic links: {str(e)}")
+         import traceback
+         traceback.print_exc()
+         raise
+
+
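The core of the similarity step in isolation, as a self-contained numpy sketch mirroring the thresholded top-k selection above (it assumes unit-normalized embeddings, as the dot-product comment does; the vectors are illustrative):

    import numpy as np

    def top_k_above_threshold(existing, query, k=5, threshold=0.7):
        """Return (index, similarity) pairs, best first."""
        sims = existing @ query                       # cosine similarity for normalized rows
        hits = np.where(sims >= threshold)[0]         # keep only candidates above threshold
        ranked = hits[np.argsort(-sims[hits])][:k]    # sort descending, truncate to top-k
        return [(int(i), float(sims[i])) for i in ranked]

    existing = np.array([[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]], dtype=np.float32)  # 3 stored units
    query = np.array([1.0, 0.0], dtype=np.float32)                               # new unit
    print(top_k_above_threshold(existing, query))     # [(0, 1.0), (1, 0.8)]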
+ async def insert_entity_links_batch(conn, links: List[tuple]):
+     """
+     Insert all entity links in a single batch.
+
+     Args:
+         conn: Database connection
+         links: List of tuples (from_unit_id, to_unit_id, link_type, weight, entity_id)
+     """
+     if not links:
+         return
+
+     await conn.executemany(
+         """
+         INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id)
+         VALUES ($1, $2, $3, $4, $5)
+         ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+         """,
+         links
+     )
+
+
+ async def create_causal_links_batch(
+     conn,
+     unit_ids: List[str],
+     causal_relations_per_fact: List[List[dict]],
+ ) -> int:
+     """
+     Create causal links between facts based on LLM-extracted causal relationships.
+
+     Args:
+         conn: Database connection
+         unit_ids: List of unit IDs (in same order as causal_relations_per_fact)
+         causal_relations_per_fact: List of causal relations for each fact.
+             Each element is a list of dicts with:
+             - target_fact_index: Index into unit_ids for the target fact
+             - relation_type: "causes", "caused_by", "enables", or "prevents"
+             - strength: Float in [0.0, 1.0] representing relationship strength
+
+     Returns:
+         Number of causal links created
+
+     Causal link types:
+         - "causes": This fact directly causes the target fact (forward causation)
+         - "caused_by": This fact was caused by the target fact (backward causation)
+         - "enables": This fact enables/allows the target fact (enablement)
+         - "prevents": This fact prevents/blocks the target fact (prevention)
+     """
+     if not unit_ids or not causal_relations_per_fact:
+         return 0
+
+     try:
+         import time as time_mod
+         create_start = time_mod.time()
+
+         # Build links list
+         links = []
+         for fact_idx, causal_relations in enumerate(causal_relations_per_fact):
+             if not causal_relations:
+                 continue
+
+             from_unit_id = unit_ids[fact_idx]
+
+             for relation in causal_relations:
+                 target_idx = relation['target_fact_index']
+                 relation_type = relation['relation_type']
+                 strength = relation.get('strength', 1.0)
+
+                 # Validate relation_type - must match database constraint
+                 valid_types = {'causes', 'caused_by', 'enables', 'prevents'}
+                 if relation_type not in valid_types:
+                     logger.error(
+                         f"Invalid relation_type '{relation_type}' (type: {type(relation_type).__name__}) "
+                         f"from fact {fact_idx}. Must be one of: {valid_types}. "
+                         f"Relation data: {relation}"
+                     )
+                     continue
+
+                 # Validate target index
+                 if target_idx < 0 or target_idx >= len(unit_ids):
+                     logger.warning(f"Invalid target_fact_index {target_idx} in causal relation from fact {fact_idx}")
+                     continue
+
+                 to_unit_id = unit_ids[target_idx]
+
+                 # Don't create self-links
+                 if from_unit_id == to_unit_id:
+                     continue
+
+                 # Add the causal link
+                 # link_type is the relation_type (e.g., "causes", "caused_by")
+                 # weight is the strength of the relationship
+                 links.append((from_unit_id, to_unit_id, relation_type, strength, None))
+
+         if links:
+             insert_start = time_mod.time()
+             try:
+                 await conn.executemany(
+                     """
+                     INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id)
+                     VALUES ($1, $2, $3, $4, $5)
+                     ON CONFLICT (from_unit_id, to_unit_id, link_type, COALESCE(entity_id, '00000000-0000-0000-0000-000000000000'::uuid)) DO NOTHING
+                     """,
+                     links
+                 )
+             except Exception as db_error:
+                 # Log the actual data being inserted for debugging
+                 logger.error(f"Database insert failed for causal links. Error: {db_error}")
+                 logger.error(f"Attempted to insert {len(links)} links. First few:")
+                 for i, link in enumerate(links[:3]):
+                     logger.error(f" Link {i}: from={link[0]}, to={link[1]}, type='{link[2]}' (repr={repr(link[2])}), weight={link[3]}, entity={link[4]}")
+                 raise
+
+         return len(links)
+
+     except Exception as e:
+         logger.error(f"Failed to create causal links: {str(e)}")
+         import traceback
+         traceback.print_exc()
+         raise
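A sketch of the expected causal_relations_per_fact payload and the link rows it produces (unit IDs are illustrative; conn is a database connection as in the functions above):

    unit_ids = ["unit-0", "unit-1", "unit-2"]
    causal_relations_per_fact = [
        [{"target_fact_index": 1, "relation_type": "causes", "strength": 0.9}],    # fact 0 -> fact 1
        [],                                                                         # fact 1: no relations
        [{"target_fact_index": 0, "relation_type": "caused_by", "strength": 0.6}],  # fact 2 <- fact 0
    ]
    n = await create_causal_links_batch(conn, unit_ids, causal_relations_per_fact)
    # n == 2; inserted rows: ("unit-0", "unit-1", "causes", 0.9, None)
    #                        ("unit-2", "unit-0", "caused_by", 0.6, None)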