hindsight-api 0.0.21__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
@@ -6,6 +6,9 @@ import time
  import logging
  from typing import List
  from datetime import timedelta, datetime, timezone
+ from uuid import UUID
+
+ from .types import EntityLink

  logger = logging.getLogger(__name__)

@@ -202,47 +205,24 @@ async def extract_entities_batch_optimized(

  # Resolve ALL entities in one batch call
  if all_entities_flat:
- # [6.2.2] Batch resolve entities
+ # [6.2.2] Batch resolve entities - single call with per-entity dates
  substep_6_2_2_start = time.time()
- # Group by date for batch resolution (round to hour to reduce buckets)
- entities_by_date = {}
- for idx, (unit_id, local_idx, fact_date) in enumerate(entity_to_unit):
- # Round to hour to group facts from same time period
- date_key = fact_date.replace(minute=0, second=0, microsecond=0)
- if date_key not in entities_by_date:
- entities_by_date[date_key] = []
- entities_by_date[date_key].append((idx, all_entities_flat[idx]))
-
- _log(log_buffer, f" [6.2.2] Grouped into {len(entities_by_date)} date buckets, resolving sequentially...", level='debug')
-
- # Resolve all date groups SEQUENTIALLY using main transaction connection
- # This prevents race conditions where parallel tasks create duplicate entities
- resolved_entity_ids = [None] * len(all_entities_flat)
-
- for date_idx, (date_key, entities_group) in enumerate(entities_by_date.items(), 1):
- date_bucket_start = time.time()
- indices = [idx for idx, _ in entities_group]
- entities_data = [entity_data for _, entity_data in entities_group]
- # Use the first fact's date for this bucket (all should be in same hour)
- fact_date = entity_to_unit[indices[0]][2]
-
- # Use main transaction connection to ensure consistency
- batch_resolved = await entity_resolver.resolve_entities_batch(
- bank_id=bank_id,
- entities_data=entities_data,
- context=context,
- unit_event_date=fact_date,
- conn=conn # Use main transaction connection
- )

- if len(entities_by_date) <= 10: # Only log individual buckets if there aren't too many
- _log(log_buffer, f" [6.2.2.{date_idx}] Resolved {len(entities_data)} entities in {time.time() - date_bucket_start:.3f}s", level='debug')
-
- # Map results back to resolved_entity_ids
- for idx, entity_id in zip(indices, batch_resolved):
- resolved_entity_ids[idx] = entity_id
+ # Add per-entity dates to entity data for batch resolution
+ for idx, (unit_id, local_idx, fact_date) in enumerate(entity_to_unit):
+ all_entities_flat[idx]['event_date'] = fact_date
+
+ # Resolve ALL entities in ONE batch call (much faster than sequential buckets)
+ # INSERT ... ON CONFLICT handles any race conditions at the DB level
+ resolved_entity_ids = await entity_resolver.resolve_entities_batch(
+ bank_id=bank_id,
+ entities_data=all_entities_flat,
+ context=context,
+ unit_event_date=None, # Not used when per-entity dates provided
+ conn=conn # Use main transaction connection
+ )

- _log(log_buffer, f" [6.2.2] Resolve entities: {len(all_entities_flat)} entities across {len(entities_by_date)} buckets in {time.time() - substep_6_2_2_start:.3f}s", level='debug')
+ _log(log_buffer, f" [6.2.2] Resolve entities: {len(all_entities_flat)} entities in single batch in {time.time() - substep_6_2_2_start:.3f}s", level='debug')

  # [6.2.3] Create unit-entity links in BATCH
  substep_6_2_3_start = time.time()
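Note on the race-condition comment in this hunk: the single-call design depends on the resolver performing its entity upsert with INSERT ... ON CONFLICT, so concurrent retain batches cannot create duplicate entities even without the old per-hour sequential buckets. The sketch below illustrates that pattern only; it is not the package's actual resolver and assumes a UNIQUE (bank_id, canonical_name) constraint on the entities table that appears elsewhere in this diff.

```python
# Minimal sketch (assumption, not the package's resolve_entities_batch): one
# INSERT ... ON CONFLICT round trip that returns ids for both new and existing
# entities, so concurrent writers cannot create duplicates.
from typing import List
from uuid import UUID


async def upsert_entities(conn, bank_id: str, names: List[str]) -> List[UUID]:
    # Dedupe first: ON CONFLICT cannot touch the same row twice in one statement.
    unique_names = list(dict.fromkeys(names))
    rows = await conn.fetch(
        """
        INSERT INTO entities (bank_id, canonical_name)
        SELECT $1, t.name FROM unnest($2::text[]) AS t(name)
        ON CONFLICT (bank_id, canonical_name)
        DO UPDATE SET canonical_name = EXCLUDED.canonical_name  -- no-op so RETURNING also yields existing rows
        RETURNING id, canonical_name
        """,
        bank_id, unique_names,
    )
    by_name = {r["canonical_name"]: r["id"] for r in rows}
    # Preserve the caller's order, including repeated mentions.
    return [by_name[name] for name in names]
```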
@@ -305,10 +285,14 @@ async def extract_entities_batch_optimized(
  # Only link each new unit to the most recent MAX_LINKS_PER_ENTITY units
  MAX_LINKS_PER_ENTITY = 50 # Limit to prevent explosion when entity appears in many facts
  link_gen_start = time.time()
- links = []
+ links: List[EntityLink] = []
  new_unit_set = set(unit_ids) # Units from this batch

+ def to_uuid(val) -> UUID:
+ return UUID(val) if isinstance(val, str) else val
+
  for entity_id, units_with_entity in entity_to_units.items():
+ entity_uuid = to_uuid(entity_id)
  # Separate new units (from this batch) and existing units
  new_units = [u for u in units_with_entity if str(u) in new_unit_set or u in new_unit_set]
  existing_units = [u for u in units_with_entity if str(u) not in new_unit_set and u not in new_unit_set]
@@ -318,15 +302,15 @@ async def extract_entities_batch_optimized(
  new_units_to_link = new_units[-MAX_LINKS_PER_ENTITY:] if len(new_units) > MAX_LINKS_PER_ENTITY else new_units
  for i, unit_id_1 in enumerate(new_units_to_link):
  for unit_id_2 in new_units_to_link[i+1:]:
- links.append((unit_id_1, unit_id_2, 'entity', 1.0, entity_id))
- links.append((unit_id_2, unit_id_1, 'entity', 1.0, entity_id))
+ links.append(EntityLink(from_unit_id=to_uuid(unit_id_1), to_unit_id=to_uuid(unit_id_2), entity_id=entity_uuid))
+ links.append(EntityLink(from_unit_id=to_uuid(unit_id_2), to_unit_id=to_uuid(unit_id_1), entity_id=entity_uuid))

  # Link new units to LIMITED existing units (most recent)
  existing_to_link = existing_units[-MAX_LINKS_PER_ENTITY:] # Take most recent
  for new_unit in new_units:
  for existing_unit in existing_to_link:
- links.append((new_unit, existing_unit, 'entity', 1.0, entity_id))
- links.append((existing_unit, new_unit, 'entity', 1.0, entity_id))
+ links.append(EntityLink(from_unit_id=to_uuid(new_unit), to_unit_id=to_uuid(existing_unit), entity_id=entity_uuid))
+ links.append(EntityLink(from_unit_id=to_uuid(existing_unit), to_unit_id=to_uuid(new_unit), entity_id=entity_uuid))

  _log(log_buffer, f" [6.3.3] Generate {len(links)} links: {time.time() - link_gen_start:.3f}s", level='debug')
  _log(log_buffer, f" [6.3] Entity link creation: {len(links)} links for {len(all_entity_ids)} unique entities in {time.time() - substep_start:.3f}s", level='debug')
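For scale, the MAX_LINKS_PER_ENTITY cap above bounds per-entity link growth. The arithmetic below is a rough worst case for a single entity, under the illustrative assumption that a batch brings in 50 new units and at least 50 existing units are retained; it is not package code.

```python
# Worst-case directed links contributed by ONE entity under MAX_LINKS_PER_ENTITY = 50.
# n = new units linked pairwise (capped at 50), m = retained existing units (capped at 50).
def worst_case_links(n: int = 50, m: int = 50) -> int:
    new_new = n * (n - 1)      # every ordered pair of capped new units
    new_existing = 2 * n * m   # each new unit <-> each retained existing unit, both directions
    return new_new + new_existing


print(worst_case_links())  # 2450 + 5000 = 7450 directed links for one heavily mentioned entity
```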
@@ -346,7 +330,7 @@ async def create_temporal_links_batch_per_fact(
  unit_ids: List[str],
  time_window_hours: int = 24,
  log_buffer: List[str] = None,
- ):
+ ) -> int:
  """
  Create temporal links for multiple units, each with their own event_date.

@@ -359,9 +343,12 @@ async def create_temporal_links_batch_per_fact(
  unit_ids: List of unit IDs
  time_window_hours: Time window in hours for temporal links
  log_buffer: Optional buffer for logging
+
+ Returns:
+ Number of temporal links created
  """
  if not unit_ids:
- return
+ return 0

  try:
  import time as time_mod
@@ -417,6 +404,8 @@ async def create_temporal_links_batch_per_fact(
  )
  _log(log_buffer, f" [7.4] Insert {len(links)} temporal links: {time_mod.time() - insert_start:.3f}s")

+ return len(links)
+
  except Exception as e:
  logger.error(f"Failed to create temporal links: {str(e)}")
  import traceback
@@ -432,7 +421,7 @@ async def create_semantic_links_batch(
  top_k: int = 5,
  threshold: float = 0.7,
  log_buffer: List[str] = None,
- ):
+ ) -> int:
  """
  Create semantic links for multiple units efficiently.

@@ -446,9 +435,12 @@ async def create_semantic_links_batch(
  top_k: Number of top similar units to link
  threshold: Minimum similarity threshold
  log_buffer: Optional buffer for logging
+
+ Returns:
+ Number of semantic links created
  """
  if not unit_ids or not embeddings:
- return
+ return 0

  try:
  import time as time_mod
@@ -539,6 +531,8 @@ async def create_semantic_links_batch(
  )
  _log(log_buffer, f" [8.3] Insert {len(all_links)} semantic links: {time_mod.time() - insert_start:.3f}s")

+ return len(all_links)
+
  except Exception as e:
  logger.error(f"Failed to create semantic links: {str(e)}")
  import traceback
@@ -546,7 +540,7 @@ async def create_semantic_links_batch(
  raise


- async def insert_entity_links_batch(conn, links: List[tuple], chunk_size: int = 50000):
+ async def insert_entity_links_batch(conn, links: List[EntityLink], chunk_size: int = 50000):
  """
  Insert all entity links using COPY to temp table + INSERT for maximum speed.

@@ -556,7 +550,7 @@ async def insert_entity_links_batch(conn, links: List[tuple], chunk_size: int =

  Args:
  conn: Database connection
- links: List of tuples (from_unit_id, to_unit_id, link_type, weight, entity_id)
+ links: List of EntityLink objects
  chunk_size: Number of rows per batch (default 50000)
  """
  if not links:
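For context on the "COPY to temp table + INSERT" claim in this docstring: with asyncpg the pattern is typically copy_records_to_table into the _temp_entity_links staging table that appears in the next hunk, followed by a single INSERT ... SELECT into the real links table. The sketch below is an illustration only: the staging-table name and record shape come from this diff, but the staging-table column names, the destination table memory_links, and its conflict handling are assumptions, not the package's actual SQL.

```python
# Hedged sketch of the COPY + INSERT flush step (assumed column and table names).
async def flush_entity_links(conn, records):
    # records are (from_unit_id, to_unit_id, link_type, weight, entity_id) tuples,
    # matching the conversion loop shown in the following hunk.
    await conn.copy_records_to_table(
        "_temp_entity_links",
        records=records,
        columns=["from_unit_id", "to_unit_id", "link_type", "weight", "entity_id"],
    )
    await conn.execute(
        """
        INSERT INTO memory_links (from_unit_id, to_unit_id, link_type, weight, entity_id)
        SELECT from_unit_id, to_unit_id, link_type, weight, entity_id
        FROM _temp_entity_links
        ON CONFLICT DO NOTHING  -- skip links that already exist
        """
    )
```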
@@ -585,16 +579,16 @@ async def insert_entity_links_batch(conn, links: List[tuple], chunk_size: int =
  await conn.execute("TRUNCATE _temp_entity_links")
  logger.debug(f" [9.2] Truncate temp table: {time_mod.time() - truncate_start:.3f}s")

- # Convert links to proper format for COPY
+ # Convert EntityLink objects to tuples for COPY
  convert_start = time_mod.time()
  records = []
- for from_id, to_id, link_type, weight, entity_id in links:
+ for link in links:
  records.append((
- uuid_mod.UUID(from_id) if isinstance(from_id, str) else from_id,
- uuid_mod.UUID(to_id) if isinstance(to_id, str) else to_id,
- link_type,
- weight,
- uuid_mod.UUID(str(entity_id)) if entity_id and not isinstance(entity_id, uuid_mod.UUID) else entity_id
+ link.from_unit_id,
+ link.to_unit_id,
+ link.link_type,
+ link.weight,
+ link.entity_id
  ))
  logger.debug(f" [9.3] Convert {len(records)} records: {time_mod.time() - convert_start:.3f}s")

@@ -0,0 +1,264 @@
+ """
+ Observation regeneration for retain pipeline.
+
+ Regenerates entity observations as part of the retain transaction.
+ """
+ import logging
+ import time
+ import uuid
+ from datetime import datetime, timezone
+ from typing import List, Dict, Optional
+
+ from ..search import observation_utils
+ from . import embedding_utils
+ from ..db_utils import acquire_with_retry
+ from .types import EntityLink
+
+ logger = logging.getLogger(__name__)
+
+
+ def utcnow():
+ """Get current UTC time."""
+ return datetime.now(timezone.utc)
+
+
+ # Simple dataclass-like container for facts (avoid importing from memory_engine)
+ class MemoryFactForObservation:
+ def __init__(self, id: str, text: str, fact_type: str, context: str, occurred_start: Optional[str]):
+ self.id = id
+ self.text = text
+ self.fact_type = fact_type
+ self.context = context
+ self.occurred_start = occurred_start
+
+
+ async def regenerate_observations_batch(
+ conn,
+ embeddings_model,
+ llm_config,
+ bank_id: str,
+ entity_links: List[EntityLink],
+ log_buffer: List[str] = None
+ ) -> None:
+ """
+ Regenerate observations for top entities in this batch.
+
+ Called INSIDE the retain transaction for atomicity - if observations
+ fail, the entire retain batch is rolled back.
+
+ Args:
+ conn: Database connection (from the retain transaction)
+ embeddings_model: Embeddings model for generating observation embeddings
+ llm_config: LLM configuration for observation extraction
+ bank_id: Bank identifier
+ entity_links: Entity links from this batch
+ log_buffer: Optional log buffer for timing
+ """
+ TOP_N_ENTITIES = 5
+ MIN_FACTS_THRESHOLD = 5
+
+ if not entity_links:
+ return
+
+ # Count mentions per entity in this batch
+ entity_mention_counts: Dict[str, int] = {}
+ for link in entity_links:
+ if link.entity_id:
+ entity_id = str(link.entity_id)
+ entity_mention_counts[entity_id] = entity_mention_counts.get(entity_id, 0) + 1
+
+ if not entity_mention_counts:
+ return
+
+ # Sort by mention count descending and take top N
+ sorted_entities = sorted(
+ entity_mention_counts.items(),
+ key=lambda x: x[1],
+ reverse=True
+ )
+ entities_to_process = [e[0] for e in sorted_entities[:TOP_N_ENTITIES]]
+
+ obs_start = time.time()
+
+ # Convert to UUIDs
+ entity_uuids = [uuid.UUID(eid) if isinstance(eid, str) else eid for eid in entities_to_process]
+
+ # Batch query for entity names
+ entity_rows = await conn.fetch(
+ """
+ SELECT id, canonical_name FROM entities
+ WHERE id = ANY($1) AND bank_id = $2
+ """,
+ entity_uuids, bank_id
+ )
+ entity_names = {row['id']: row['canonical_name'] for row in entity_rows}
+
+ # Batch query for fact counts
+ fact_counts = await conn.fetch(
+ """
+ SELECT ue.entity_id, COUNT(*) as cnt
+ FROM unit_entities ue
+ JOIN memory_units mu ON ue.unit_id = mu.id
+ WHERE ue.entity_id = ANY($1) AND mu.bank_id = $2
+ GROUP BY ue.entity_id
+ """,
+ entity_uuids, bank_id
+ )
+ entity_fact_counts = {row['entity_id']: row['cnt'] for row in fact_counts}
+
+ # Filter entities that meet the threshold
+ entities_with_names = []
+ for entity_id in entities_to_process:
+ entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+ if entity_uuid not in entity_names:
+ continue
+ fact_count = entity_fact_counts.get(entity_uuid, 0)
+ if fact_count >= MIN_FACTS_THRESHOLD:
+ entities_with_names.append((entity_id, entity_names[entity_uuid]))
+
+ if not entities_with_names:
+ return
+
+ # Process entities SEQUENTIALLY (asyncpg doesn't allow concurrent queries on same connection)
+ # We must use the same connection to stay in the retain transaction
+ total_observations = 0
+
+ for entity_id, entity_name in entities_with_names:
+ try:
+ obs_ids = await _regenerate_entity_observations(
+ conn, embeddings_model, llm_config,
+ bank_id, entity_id, entity_name
+ )
+ total_observations += len(obs_ids)
+ except Exception as e:
+ logger.error(f"[OBSERVATIONS] Error processing entity {entity_id}: {e}")
+
+ obs_time = time.time() - obs_start
+ if log_buffer is not None:
+ log_buffer.append(f"[11] Observations: {total_observations} observations for {len(entities_with_names)} entities in {obs_time:.3f}s")
+
+
+ async def _regenerate_entity_observations(
+ conn,
+ embeddings_model,
+ llm_config,
+ bank_id: str,
+ entity_id: str,
+ entity_name: str
+ ) -> List[str]:
+ """
+ Regenerate observations for a single entity.
+
+ Uses the provided connection (part of retain transaction).
+
+ Args:
+ conn: Database connection (from the retain transaction)
+ embeddings_model: Embeddings model
+ llm_config: LLM configuration
+ bank_id: Bank identifier
+ entity_id: Entity UUID
+ entity_name: Canonical name of the entity
+
+ Returns:
+ List of created observation IDs
+ """
+ entity_uuid = uuid.UUID(entity_id) if isinstance(entity_id, str) else entity_id
+
+ # Get all facts mentioning this entity (exclude observations themselves)
+ rows = await conn.fetch(
+ """
+ SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.fact_type
+ FROM memory_units mu
+ JOIN unit_entities ue ON mu.id = ue.unit_id
+ WHERE mu.bank_id = $1
+ AND ue.entity_id = $2
+ AND mu.fact_type IN ('world', 'experience')
+ ORDER BY mu.occurred_start DESC
+ LIMIT 50
+ """,
+ bank_id, entity_uuid
+ )
+
+ if not rows:
+ return []
+
+ # Convert to fact objects for observation extraction
+ facts = []
+ for row in rows:
+ occurred_start = row['occurred_start'].isoformat() if row['occurred_start'] else None
+ facts.append(MemoryFactForObservation(
+ id=str(row['id']),
+ text=row['text'],
+ fact_type=row['fact_type'],
+ context=row['context'],
+ occurred_start=occurred_start
+ ))
+
+ # Extract observations using LLM
+ observations = await observation_utils.extract_observations_from_facts(
+ llm_config,
+ entity_name,
+ facts
+ )
+
+ if not observations:
+ return []
+
+ # Delete old observations for this entity
+ await conn.execute(
+ """
+ DELETE FROM memory_units
+ WHERE id IN (
+ SELECT mu.id
+ FROM memory_units mu
+ JOIN unit_entities ue ON mu.id = ue.unit_id
+ WHERE mu.bank_id = $1
+ AND mu.fact_type = 'observation'
+ AND ue.entity_id = $2
+ )
+ """,
+ bank_id, entity_uuid
+ )
+
+ # Generate embeddings for new observations
+ embeddings = await embedding_utils.generate_embeddings_batch(
+ embeddings_model, observations
+ )
+
+ # Insert new observations
+ current_time = utcnow()
+ created_ids = []
+
+ for obs_text, embedding in zip(observations, embeddings):
+ result = await conn.fetchrow(
+ """
+ INSERT INTO memory_units (
+ bank_id, text, embedding, context, event_date,
+ occurred_start, occurred_end, mentioned_at,
+ fact_type, access_count
+ )
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'observation', 0)
+ RETURNING id
+ """,
+ bank_id,
+ obs_text,
+ str(embedding),
+ f"observation about {entity_name}",
+ current_time,
+ current_time,
+ current_time,
+ current_time
+ )
+ obs_id = str(result['id'])
+ created_ids.append(obs_id)
+
+ # Link observation to entity
+ await conn.execute(
+ """
+ INSERT INTO unit_entities (unit_id, entity_id)
+ VALUES ($1, $2)
+ """,
+ uuid.UUID(obs_id), entity_uuid
+ )
+
+ return created_ids
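The docstrings above hinge on this module being called inside the retain transaction. The sketch below shows that calling shape with asyncpg; the pool, the elided surrounding steps, and the function name retain_with_observations are illustrative, while the real call site is retain_batch, shown later in this diff.

```python
# Hedged sketch of the intended calling pattern: one transaction covers units,
# links, and observation regeneration, so a failure rolls everything back together.
async def retain_with_observations(pool, embeddings_model, llm_config, bank_id, entity_links):
    async with pool.acquire() as conn:
        async with conn.transaction():  # commits or rolls back as a unit
            # ... store units and create links on the same conn ...
            await regenerate_observations_batch(
                conn, embeddings_model, llm_config, bank_id, entity_links, log_buffer=[]
            )
```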
@@ -17,7 +17,7 @@ def utcnow():
  """Get current UTC time."""
  return datetime.now(timezone.utc)

- from .types import RetainContent, ExtractedFact, ProcessedFact
+ from .types import RetainContent, ExtractedFact, ProcessedFact, EntityLink
  from . import (
  fact_extraction,
  embedding_processing,
@@ -25,7 +25,8 @@ from . import (
  chunk_storage,
  fact_storage,
  entity_processing,
- link_creation
+ link_creation,
+ observation_regeneration
  )

  logger = logging.getLogger(__name__)
@@ -39,7 +40,6 @@ async def retain_batch(
  task_backend,
  format_date_fn,
  duplicate_checker_fn,
- regenerate_observations_fn,
  bank_id: str,
  contents_dicts: List[Dict[str, Any]],
  document_id: Optional[str] = None,
@@ -58,7 +58,6 @@ async def retain_batch(
  task_backend: Task backend for background jobs
  format_date_fn: Function to format datetime to readable string
  duplicate_checker_fn: Function to check for duplicate facts
- regenerate_observations_fn: Async function to regenerate observations for entities
  bank_id: Bank identifier
  contents_dicts: List of content dictionaries
  document_id: Optional document ID
@@ -288,50 +287,59 @@ async def retain_batch(

  # Create temporal links
  step_start = time.time()
- await link_creation.create_temporal_links_batch(conn, bank_id, unit_ids)
- log_buffer.append(f"[7] Temporal links: {time.time() - step_start:.3f}s")
+ temporal_link_count = await link_creation.create_temporal_links_batch(conn, bank_id, unit_ids)
+ log_buffer.append(f"[7] Temporal links: {temporal_link_count} links in {time.time() - step_start:.3f}s")

  # Create semantic links
  step_start = time.time()
  embeddings_for_links = [fact.embedding for fact in non_duplicate_facts]
- await link_creation.create_semantic_links_batch(conn, bank_id, unit_ids, embeddings_for_links)
- log_buffer.append(f"[8] Semantic links: {time.time() - step_start:.3f}s")
+ semantic_link_count = await link_creation.create_semantic_links_batch(conn, bank_id, unit_ids, embeddings_for_links)
+ log_buffer.append(f"[8] Semantic links: {semantic_link_count} links in {time.time() - step_start:.3f}s")

  # Insert entity links
  step_start = time.time()
  if entity_links:
  await entity_processing.insert_entity_links_batch(conn, entity_links)
- log_buffer.append(f"[9] Entity links: {time.time() - step_start:.3f}s")
+ log_buffer.append(f"[9] Entity links: {len(entity_links) if entity_links else 0} links in {time.time() - step_start:.3f}s")

  # Create causal links
  step_start = time.time()
  causal_link_count = await link_creation.create_causal_links_batch(conn, unit_ids, non_duplicate_facts)
  log_buffer.append(f"[10] Causal links: {causal_link_count} links in {time.time() - step_start:.3f}s")

+ # Regenerate observations INSIDE transaction for atomicity
+ await observation_regeneration.regenerate_observations_batch(
+ conn,
+ embeddings_model,
+ llm_config,
+ bank_id,
+ entity_links,
+ log_buffer
+ )
+
  # Map results back to original content items
  result_unit_ids = _map_results_to_contents(
  contents, extracted_facts, is_duplicate_flags, unit_ids
  )

- total_time = time.time() - start_time
- log_buffer.append(f"{'='*60}")
- log_buffer.append(f"RETAIN_BATCH COMPLETE: {len(unit_ids)} units in {total_time:.3f}s")
- if document_ids_added:
- log_buffer.append(f"Documents: {', '.join(document_ids_added)}")
- log_buffer.append(f"{'='*60}")
-
- logger.info("\n" + "\n".join(log_buffer) + "\n")
-
- # Trigger background tasks AFTER transaction commits
+ # Trigger background tasks AFTER transaction commits (opinion reinforcement only)
  await _trigger_background_tasks(
  task_backend,
- regenerate_observations_fn,
  bank_id,
  unit_ids,
- non_duplicate_facts,
- entity_links
+ non_duplicate_facts
  )

+ # Log final summary
+ total_time = time.time() - start_time
+ log_buffer.append(f"{'='*60}")
+ log_buffer.append(f"RETAIN_BATCH COMPLETE: {len(unit_ids)} units in {total_time:.3f}s")
+ if document_ids_added:
+ log_buffer.append(f"Documents: {', '.join(document_ids_added)}")
+ log_buffer.append(f"{'='*60}")
+
+ logger.info("\n" + "\n".join(log_buffer) + "\n")
+
  return result_unit_ids


@@ -367,13 +375,11 @@ def _map_results_to_contents(

  async def _trigger_background_tasks(
  task_backend,
- regenerate_observations_fn,
  bank_id: str,
  unit_ids: List[str],
  facts: List[ProcessedFact],
- entity_links: List
  ) -> None:
- """Trigger opinion reinforcement and observation regeneration (sync)."""
+ """Trigger opinion reinforcement as background task (after transaction commits)."""
  # Trigger opinion reinforcement if there are entities
  fact_entities = [[e.name for e in fact.entities] for fact in facts]
  if any(fact_entities):
@@ -384,22 +390,3 @@ async def _trigger_background_tasks(
  'unit_texts': [fact.fact_text for fact in facts],
  'unit_entities': fact_entities
  })
-
- # Regenerate observations synchronously for top entities
- TOP_N_ENTITIES = 5
- MIN_FACTS_THRESHOLD = 5
-
- if entity_links and regenerate_observations_fn:
- unique_entity_ids = set()
- for link in entity_links:
- # links are tuples: (unit_id, entity_id, confidence)
- if len(link) >= 2 and link[1]:
- unique_entity_ids.add(str(link[1]))
-
- if unique_entity_ids:
- # Run observation regeneration synchronously
- await regenerate_observations_fn(
- bank_id=bank_id,
- entity_ids=list(unique_entity_ids)[:TOP_N_ENTITIES],
- min_facts=MIN_FACTS_THRESHOLD
- )
@@ -176,6 +176,20 @@ class ProcessedFact:
  )


+ @dataclass
+ class EntityLink:
+ """
+ Link between two memory units through a shared entity.
+
+ Used for entity-based graph connections in the memory graph.
+ """
+ from_unit_id: UUID
+ to_unit_id: UUID
+ entity_id: UUID
+ link_type: str = 'entity'
+ weight: float = 1.0
+
+
  @dataclass
  class RetainBatch:
  """
@@ -170,9 +170,9 @@ async def retrieve_graph(
  batch_activations[unit_id] = activation

  # Batch fetch neighbors for all nodes in this batch
- # Fetch top weighted neighbors (batch_size * 10 = ~200 for good distribution)
+ # Fetch top weighted neighbors (batch_size * 20 = ~400 for good distribution)
  if batch_nodes and budget_remaining > 0:
- max_neighbors = len(batch_nodes) * 10
+ max_neighbors = len(batch_nodes) * 20
  neighbors = await conn.fetch(
  """
  SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end, mu.mentioned_at,