hindsight-api 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hindsight_api/api/mcp.py CHANGED
@@ -121,11 +121,7 @@ class MCPMiddleware:
         self.app = app
         self.memory = memory
         self.mcp_server = create_mcp_server(memory)
-        # Use sse_app - http_app requires lifespan management that's complex with middleware
-        import warnings
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", DeprecationWarning)
-            self.mcp_app = self.mcp_server.sse_app()
+        self.mcp_app = self.mcp_server.http_app()
 
     async def __call__(self, scope, receive, send):
         if scope["type"] != "http":
hindsight_api/config.py CHANGED
@@ -29,6 +29,7 @@ ENV_HOST = "HINDSIGHT_API_HOST"
 ENV_PORT = "HINDSIGHT_API_PORT"
 ENV_LOG_LEVEL = "HINDSIGHT_API_LOG_LEVEL"
 ENV_MCP_ENABLED = "HINDSIGHT_API_MCP_ENABLED"
+ENV_GRAPH_RETRIEVER = "HINDSIGHT_API_GRAPH_RETRIEVER"
 
 # Default values
 DEFAULT_DATABASE_URL = "pg0"
@@ -45,6 +46,7 @@ DEFAULT_HOST = "0.0.0.0"
 DEFAULT_PORT = 8888
 DEFAULT_LOG_LEVEL = "info"
 DEFAULT_MCP_ENABLED = True
+DEFAULT_GRAPH_RETRIEVER = "bfs"  # Options: "bfs", "mpfp"
 
 # Required embedding dimension for database schema
 EMBEDDING_DIMENSION = 384
@@ -79,6 +81,9 @@ class HindsightConfig:
     log_level: str
     mcp_enabled: bool
 
+    # Recall
+    graph_retriever: str
+
     @classmethod
     def from_env(cls) -> "HindsightConfig":
         """Create configuration from environment variables."""
@@ -107,6 +112,9 @@ class HindsightConfig:
             port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),
             log_level=os.getenv(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL),
             mcp_enabled=os.getenv(ENV_MCP_ENABLED, str(DEFAULT_MCP_ENABLED)).lower() == "true",
+
+            # Recall
+            graph_retriever=os.getenv(ENV_GRAPH_RETRIEVER, DEFAULT_GRAPH_RETRIEVER),
         )
 
     def get_llm_base_url(self) -> str:
@@ -147,6 +155,7 @@ class HindsightConfig:
         logger.info(f"LLM: provider={self.llm_provider}, model={self.llm_model}")
         logger.info(f"Embeddings: provider={self.embeddings_provider}")
         logger.info(f"Reranker: provider={self.reranker_provider}")
+        logger.info(f"Graph retriever: {self.graph_retriever}")
 
 
 def get_config() -> HindsightConfig:
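
The new setting is read once in `from_env()`, so it can be switched per process. A minimal sketch using only names shown in this diff:

```python
# Select the graph retrieval strategy via the new environment variable.
import os

os.environ["HINDSIGHT_API_GRAPH_RETRIEVER"] = "mpfp"  # default is "bfs"

from hindsight_api.config import get_config

config = get_config()
assert config.graph_retriever == "mpfp"
```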
@@ -101,12 +101,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
         )
 
         logger.info(f"Reranker: initializing local provider with model {self.model_name}")
-        # Disable lazy loading (meta tensors) which causes issues with newer transformers/accelerate
-        # Setting low_cpu_mem_usage=False and device_map=None ensures tensors are fully materialized
-        self._model = CrossEncoder(
-            self.model_name,
-            model_kwargs={"low_cpu_mem_usage": False, "device_map": None},
-        )
+        self._model = CrossEncoder(self.model_name)
         logger.info("Reranker: local provider initialized")
 
     def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
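
The simplified constructor defers device placement and memory handling to sentence-transformers' defaults. For reference, a self-contained sketch of the `CrossEncoder` API used above (the model name is illustrative, not necessarily what this package loads):

```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
scores = model.predict([("what changed in 0.1.5?", "The reranker init was simplified.")])
print(scores)  # one relevance score per (query, passage) pair
```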
@@ -175,9 +175,13 @@ class LLMProvider:
         is_reasoning_model = any(x in model_lower for x in ["gpt-5", "o1", "o3"])
 
         # For GPT-4 and GPT-4.1 models, cap max_completion_tokens to 32000
+        # For GPT-4o models, cap to 16384
         is_gpt4_model = any(x in model_lower for x in ["gpt-4.1", "gpt-4-"])
+        is_gpt4o_model = "gpt-4o" in model_lower
         if max_completion_tokens is not None:
-            if is_gpt4_model and max_completion_tokens > 32000:
+            if is_gpt4o_model and max_completion_tokens > 16384:
+                max_completion_tokens = 16384
+            elif is_gpt4_model and max_completion_tokens > 32000:
                 max_completion_tokens = 32000
             # For reasoning models, max_completion_tokens includes reasoning + output tokens
             # Enforce minimum of 16000 to ensure enough space for both
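
A standalone restatement of the new capping branch, with worked examples (the function name is invented for illustration):

```python
def cap_completion_tokens(model: str, requested: int) -> int:
    m = model.lower()
    if "gpt-4o" in m and requested > 16384:
        return 16384   # gpt-4o family: 16384 output-token cap
    if any(x in m for x in ("gpt-4.1", "gpt-4-")) and requested > 32000:
        return 32000   # other GPT-4 variants: 32000 cap
    return requested

assert cap_completion_tokens("gpt-4o-mini", 32000) == 16384
assert cap_completion_tokens("gpt-4.1", 64000) == 32000
assert cap_completion_tokens("gpt-5", 64000) == 64000  # reasoning models handled separately
```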
@@ -268,9 +272,9 @@ class LLMProvider:
                 raise
 
             except APIStatusError as e:
-                # Fast fail on 4xx client errors (except 429 rate limit and 498 which is treated as server error)
-                if 400 <= e.status_code < 500 and e.status_code not in (429, 498):
-                    logger.error(f"Client error (HTTP {e.status_code}), not retrying: {str(e)}")
+                # Fast fail only on 401 (unauthorized) and 403 (forbidden) - these won't recover with retries
+                if e.status_code in (401, 403):
+                    logger.error(f"Auth error (HTTP {e.status_code}), not retrying: {str(e)}")
                     raise
 
                 last_exception = e
@@ -408,13 +412,13 @@ class LLMProvider:
                 raise
 
             except genai_errors.APIError as e:
-                # Fast fail on 4xx client errors (except 429 rate limit)
-                if e.code and 400 <= e.code < 500 and e.code != 429:
-                    logger.error(f"Gemini client error (HTTP {e.code}), not retrying: {str(e)}")
+                # Fast fail only on 401 (unauthorized) and 403 (forbidden) - these won't recover with retries
+                if e.code in (401, 403):
+                    logger.error(f"Gemini auth error (HTTP {e.code}), not retrying: {str(e)}")
                     raise
 
-                # Retry on 429 and 5xx
-                if e.code in (429, 500, 502, 503, 504):
+                # Retry on retryable errors (rate limits, server errors, and other client errors like 400)
+                if e.code in (400, 429, 500, 502, 503, 504) or (e.code and e.code >= 500):
                     last_exception = e
                     if attempt < max_retries:
                         backoff = min(initial_backoff * (2 ** attempt), max_backoff)
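
Both provider paths now follow the same policy: abort immediately only on auth failures, and retry everything else with exponential backoff. A condensed sketch (names and default values invented for illustration):

```python
def is_fatal(status: int | None) -> bool:
    # 401/403 will not recover on retry; 400/429/5xx may be transient
    return status in (401, 403)

def next_backoff(attempt: int, initial: float = 1.0, cap: float = 30.0) -> float:
    return min(initial * (2 ** attempt), cap)  # mirrors the backoff line above
```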
@@ -1156,22 +1156,22 @@ class MemoryEngine:
         aggregated_timings = {"semantic": 0.0, "bm25": 0.0, "graph": 0.0, "temporal": 0.0}
 
         detected_temporal_constraint = None
-        for idx, (ft_semantic, ft_bm25, ft_graph, ft_temporal, ft_timings, ft_temporal_constraint) in enumerate(all_retrievals):
+        for idx, retrieval_result in enumerate(all_retrievals):
             # Log fact types in this retrieval batch
             ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
-            logger.debug(f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(ft_semantic)}, bm25={len(ft_bm25)}, graph={len(ft_graph)}, temporal={len(ft_temporal) if ft_temporal else 0}")
+            logger.debug(f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}")
 
-            semantic_results.extend(ft_semantic)
-            bm25_results.extend(ft_bm25)
-            graph_results.extend(ft_graph)
-            if ft_temporal:
-                temporal_results.extend(ft_temporal)
+            semantic_results.extend(retrieval_result.semantic)
+            bm25_results.extend(retrieval_result.bm25)
+            graph_results.extend(retrieval_result.graph)
+            if retrieval_result.temporal:
+                temporal_results.extend(retrieval_result.temporal)
             # Track max timing for each method (since they run in parallel across fact types)
-            for method, duration in ft_timings.items():
-                aggregated_timings[method] = max(aggregated_timings[method], duration)
+            for method, duration in retrieval_result.timings.items():
+                aggregated_timings[method] = max(aggregated_timings.get(method, 0.0), duration)
             # Capture temporal constraint (same across all fact types)
-            if ft_temporal_constraint:
-                detected_temporal_constraint = ft_temporal_constraint
+            if retrieval_result.temporal_constraint:
+                detected_temporal_constraint = retrieval_result.temporal_constraint
 
         # If no temporal results from any fact type, set to None
         if not temporal_results:
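
`all_retrievals` now yields a typed object instead of a 6-tuple. Its definition is not part of this diff; the attribute accesses above imply a shape roughly like this hypothetical sketch:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ParallelRetrievalResult:  # hypothetical reconstruction, not the shipped code
    semantic: list = field(default_factory=list)
    bm25: list = field(default_factory=list)
    graph: list = field(default_factory=list)
    temporal: Optional[list] = None               # None when temporal retrieval didn't run
    timings: dict = field(default_factory=dict)   # per-method durations in seconds
    temporal_constraint: Optional[object] = None  # same constraint across fact types
```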
@@ -1203,49 +1203,57 @@ class MemoryEngine:
             temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
         log_buffer.append(f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}")
 
-        # Record retrieval results for tracer (convert typed results to old format)
+        # Record retrieval results for tracer - per fact type
         if tracer:
             # Convert RetrievalResult to old tuple format for tracer
             def to_tuple_format(results):
                 return [(r.id, r.__dict__) for r in results]
 
-            # Add semantic retrieval results
-            tracer.add_retrieval_results(
-                method_name="semantic",
-                results=to_tuple_format(semantic_results),
-                duration_seconds=aggregated_timings["semantic"],
-                score_field="similarity",
-                metadata={"limit": thinking_budget}
-            )
+            # Add retrieval results per fact type (to show parallel execution in UI)
+            for idx, rr in enumerate(all_retrievals):
+                ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
 
-            # Add BM25 retrieval results
-            tracer.add_retrieval_results(
-                method_name="bm25",
-                results=to_tuple_format(bm25_results),
-                duration_seconds=aggregated_timings["bm25"],
-                score_field="bm25_score",
-                metadata={"limit": thinking_budget}
-            )
+                # Add semantic retrieval results for this fact type
+                tracer.add_retrieval_results(
+                    method_name="semantic",
+                    results=to_tuple_format(rr.semantic),
+                    duration_seconds=rr.timings.get("semantic", 0.0),
+                    score_field="similarity",
+                    metadata={"limit": thinking_budget},
+                    fact_type=ft_name
+                )
 
-            # Add graph retrieval results
-            tracer.add_retrieval_results(
-                method_name="graph",
-                results=to_tuple_format(graph_results),
-                duration_seconds=aggregated_timings["graph"],
-                score_field="similarity",  # Graph uses similarity for activation
-                metadata={"budget": thinking_budget}
-            )
+                # Add BM25 retrieval results for this fact type
+                tracer.add_retrieval_results(
+                    method_name="bm25",
+                    results=to_tuple_format(rr.bm25),
+                    duration_seconds=rr.timings.get("bm25", 0.0),
+                    score_field="bm25_score",
+                    metadata={"limit": thinking_budget},
+                    fact_type=ft_name
+                )
 
-            # Add temporal retrieval results if present
-            if temporal_results:
+                # Add graph retrieval results for this fact type
                 tracer.add_retrieval_results(
-                    method_name="temporal",
-                    results=to_tuple_format(temporal_results),
-                    duration_seconds=aggregated_timings["temporal"],
-                    score_field="temporal_score",
-                    metadata={"budget": thinking_budget}
+                    method_name="graph",
+                    results=to_tuple_format(rr.graph),
+                    duration_seconds=rr.timings.get("graph", 0.0),
+                    score_field="activation",
+                    metadata={"budget": thinking_budget},
+                    fact_type=ft_name
                 )
 
+                # Add temporal retrieval results for this fact type (even if empty, to show it ran)
+                if rr.temporal is not None:
+                    tracer.add_retrieval_results(
+                        method_name="temporal",
+                        results=to_tuple_format(rr.temporal),
+                        duration_seconds=rr.timings.get("temporal", 0.0),
+                        score_field="temporal_score",
+                        metadata={"budget": thinking_budget},
+                        fact_type=ft_name
+                    )
+
             # Record entry points (from semantic results) for legacy graph view
             for rank, retrieval in enumerate(semantic_results[:10], start=1):  # Top 10 as entry points
                 tracer.add_entry_point(retrieval.id, retrieval.text, retrieval.similarity or 0.0, rank)
@@ -1287,31 +1295,24 @@ class MemoryEngine:
         step_duration = time.time() - step_start
         log_buffer.append(f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s")
 
-        if tracer:
-            # Convert to old format for tracer
-            results_dict = [sr.to_dict() for sr in scored_results]
-            tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
-                             for mc in merged_candidates]
-            tracer.add_reranked(results_dict, tracer_merged)
-            tracer.add_phase_metric("reranking", step_duration, {
-                "reranker_type": "cross-encoder",
-                "candidates_reranked": len(scored_results)
-            })
-
         # Step 4.5: Combine cross-encoder score with retrieval signals
         # This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
         if scored_results:
-            # Normalize RRF scores to [0, 1] range
+            # Normalize RRF scores to [0, 1] range using min-max normalization
            rrf_scores = [sr.candidate.rrf_score for sr in scored_results]
-            max_rrf = max(rrf_scores) if rrf_scores else 1.0
+            max_rrf = max(rrf_scores) if rrf_scores else 0.0
             min_rrf = min(rrf_scores) if rrf_scores else 0.0
-            rrf_range = max_rrf - min_rrf if max_rrf > min_rrf else 1.0
+            rrf_range = max_rrf - min_rrf  # Don't force to 1.0, let fallback handle it
 
             # Calculate recency based on occurred_start (more recent = higher score)
             now = utcnow()
             for sr in scored_results:
-                # Normalize RRF score
-                sr.rrf_normalized = (sr.candidate.rrf_score - min_rrf) / rrf_range if rrf_range > 0 else 0.5
+                # Normalize RRF score (0-1 range, 0.5 if all same)
+                if rrf_range > 0:
+                    sr.rrf_normalized = (sr.candidate.rrf_score - min_rrf) / rrf_range
+                else:
+                    # All RRF scores are the same, use neutral value
+                    sr.rrf_normalized = 0.5
 
                 # Calculate recency (decay over 365 days, minimum 0.1)
                 sr.recency = 0.5  # default for missing dates
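
The fix is plain min-max normalization with an explicit degenerate case: previously, identical RRF scores normalized to 0.0 (range forced to 1.0); now they map to a neutral 0.5. A standalone illustration:

```python
def normalize_rrf(scores: list[float]) -> list[float]:
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.5] * len(scores)  # no spread: neutral value
    return [(s - lo) / (hi - lo) for s in scores]

print(normalize_rrf([0.016, 0.016]))      # [0.5, 0.5] (old code produced 0.0)
print(normalize_rrf([0.01, 0.02, 0.03]))  # [0.0, 0.5, 1.0]
```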
@@ -1343,6 +1344,17 @@ class MemoryEngine:
         scored_results.sort(key=lambda x: x.weight, reverse=True)
         log_buffer.append(f" [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)")
 
+        # Add reranked results to tracer AFTER combined scoring (so normalized values are included)
+        if tracer:
+            results_dict = [sr.to_dict() for sr in scored_results]
+            tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
+                             for mc in merged_candidates]
+            tracer.add_reranked(results_dict, tracer_merged)
+            tracer.add_phase_metric("reranking", step_duration, {
+                "reranker_type": "cross-encoder",
+                "candidates_reranked": len(scored_results)
+            })
+
         # Step 5: Truncate to thinking_budget * 2 for token filtering
         rerank_limit = thinking_budget * 2
         top_scored = scored_results[:rerank_limit]
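
Moving the tracer call below step 4.6 means the recorded results include the normalized fields. The blended score itself, per the `[4.6]` log line, works out to the following (a sketch; the field names are illustrative):

```python
def combined_weight(cross_encoder: float, rrf_norm: float,
                    temporal: float, recency: float) -> float:
    return 0.6 * cross_encoder + 0.2 * rrf_norm + 0.1 * temporal + 0.1 * recency

print(combined_weight(0.9, 0.5, 0.0, 0.5))  # 0.69
```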
@@ -3,13 +3,27 @@ Search module for memory retrieval.
 
 Provides modular search architecture:
 - Retrieval: 4-way parallel (semantic + BM25 + graph + temporal)
+- Graph retrieval: Pluggable strategies (BFS, PPR)
 - Reranking: Pluggable strategies (heuristic, cross-encoder)
 """
 
-from .retrieval import retrieve_parallel
+from .retrieval import (
+    retrieve_parallel,
+    get_default_graph_retriever,
+    set_default_graph_retriever,
+    ParallelRetrievalResult,
+)
+from .graph_retrieval import GraphRetriever, BFSGraphRetriever
+from .mpfp_retrieval import MPFPGraphRetriever
 from .reranking import CrossEncoderReranker
 
 __all__ = [
     "retrieve_parallel",
+    "get_default_graph_retriever",
+    "set_default_graph_retriever",
+    "ParallelRetrievalResult",
+    "GraphRetriever",
+    "BFSGraphRetriever",
+    "MPFPGraphRetriever",
     "CrossEncoderReranker",
 ]
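
Note that the docstring advertises "BFS, PPR" while the second shipped strategy is exported as `MPFPGraphRetriever`. Hypothetical usage of the new pluggable-retriever API (the package path of this search module is not shown in the diff, so the import below is a guess):

```python
from hindsight_api.search import (  # module path is a guess
    set_default_graph_retriever,
    MPFPGraphRetriever,
)

set_default_graph_retriever(MPFPGraphRetriever())  # replace the default BFS strategy
```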
@@ -0,0 +1,235 @@
+"""
+Graph retrieval strategies for memory recall.
+
+This module provides an abstraction for graph-based memory retrieval,
+allowing different algorithms (BFS spreading activation, PPR, etc.) to be
+swapped without changing the rest of the recall pipeline.
+"""
+
+from abc import ABC, abstractmethod
+from typing import List, Optional
+from datetime import datetime
+import logging
+
+from .types import RetrievalResult
+from ..db_utils import acquire_with_retry
+
+logger = logging.getLogger(__name__)
+
+
+class GraphRetriever(ABC):
+    """
+    Abstract base class for graph-based memory retrieval.
+
+    Implementations traverse the memory graph (entity links, temporal links,
+    causal links) to find relevant facts that might not be found by
+    semantic or keyword search alone.
+    """
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Return identifier for this retrieval strategy (e.g., 'bfs', 'mpfp')."""
+        pass
+
+    @abstractmethod
+    async def retrieve(
+        self,
+        pool,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+        query_text: Optional[str] = None,
+        semantic_seeds: Optional[List[RetrievalResult]] = None,
+        temporal_seeds: Optional[List[RetrievalResult]] = None,
+    ) -> List[RetrievalResult]:
+        """
+        Retrieve relevant facts via graph traversal.
+
+        Args:
+            pool: Database connection pool
+            query_embedding_str: Query embedding as string (for finding entry points)
+            bank_id: Memory bank identifier
+            fact_type: Fact type to filter ('world', 'experience', 'opinion', 'observation')
+            budget: Maximum number of nodes to explore/return
+            query_text: Original query text (optional, for some strategies)
+            semantic_seeds: Pre-computed semantic entry points (from semantic retrieval)
+            temporal_seeds: Pre-computed temporal entry points (from temporal retrieval)
+
+        Returns:
+            List of RetrievalResult objects with activation scores set
+        """
+        pass
+
+
+class BFSGraphRetriever(GraphRetriever):
+    """
+    Graph retrieval using BFS-style spreading activation.
+
+    Starting from semantic entry points, spreads activation through
+    the memory graph (entity, temporal, causal links) using breadth-first
+    traversal with decaying activation.
+
+    This is the original Hindsight graph retrieval algorithm.
+    """
+
+    def __init__(
+        self,
+        entry_point_limit: int = 5,
+        entry_point_threshold: float = 0.5,
+        activation_decay: float = 0.8,
+        min_activation: float = 0.1,
+        batch_size: int = 20,
+    ):
+        """
+        Initialize BFS graph retriever.
+
+        Args:
+            entry_point_limit: Maximum number of entry points to start from
+            entry_point_threshold: Minimum semantic similarity for entry points
+            activation_decay: Decay factor per hop (activation *= decay)
+            min_activation: Minimum activation to continue spreading
+            batch_size: Number of nodes to process per batch (for neighbor fetching)
+        """
+        self.entry_point_limit = entry_point_limit
+        self.entry_point_threshold = entry_point_threshold
+        self.activation_decay = activation_decay
+        self.min_activation = min_activation
+        self.batch_size = batch_size
+
+    @property
+    def name(self) -> str:
+        return "bfs"
+
+    async def retrieve(
+        self,
+        pool,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+        query_text: Optional[str] = None,
+        semantic_seeds: Optional[List[RetrievalResult]] = None,
+        temporal_seeds: Optional[List[RetrievalResult]] = None,
+    ) -> List[RetrievalResult]:
+        """
+        Retrieve facts using BFS spreading activation.
+
+        Algorithm:
+        1. Find entry points (top semantic matches above threshold)
+        2. BFS traversal: visit neighbors, propagate decaying activation
+        3. Boost causal links (causes, enables, prevents)
+        4. Return visited nodes up to budget
+
+        Note: BFS finds its own entry points via embedding search.
+        The semantic_seeds and temporal_seeds parameters are accepted
+        for interface compatibility but not used.
+        """
+        async with acquire_with_retry(pool) as conn:
+            return await self._retrieve_with_conn(
+                conn, query_embedding_str, bank_id, fact_type, budget
+            )
+
+    async def _retrieve_with_conn(
+        self,
+        conn,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+    ) -> List[RetrievalResult]:
+        """Internal implementation with connection."""
+
+        # Step 1: Find entry points
+        entry_points = await conn.fetch(
+            """
+            SELECT id, text, context, event_date, occurred_start, occurred_end,
+                   mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+                   1 - (embedding <=> $1::vector) AS similarity
+            FROM memory_units
+            WHERE bank_id = $2
+              AND embedding IS NOT NULL
+              AND fact_type = $3
+              AND (1 - (embedding <=> $1::vector)) >= $4
+            ORDER BY embedding <=> $1::vector
+            LIMIT $5
+            """,
+            query_embedding_str, bank_id, fact_type,
+            self.entry_point_threshold, self.entry_point_limit
+        )
+
+        if not entry_points:
+            return []
+
+        # Step 2: BFS spreading activation
+        visited = set()
+        results = []
+        queue = [
+            (RetrievalResult.from_db_row(dict(r)), r["similarity"])
+            for r in entry_points
+        ]
+        budget_remaining = budget
+
+        while queue and budget_remaining > 0:
+            # Collect a batch of nodes to process
+            batch_nodes = []
+            batch_activations = {}
+
+            while queue and len(batch_nodes) < self.batch_size and budget_remaining > 0:
+                current, activation = queue.pop(0)
+                unit_id = current.id
+
+                if unit_id not in visited:
+                    visited.add(unit_id)
+                    budget_remaining -= 1
+                    current.activation = activation
+                    results.append(current)
+                    batch_nodes.append(current.id)
+                    batch_activations[unit_id] = activation
+
+            # Batch fetch neighbors
+            if batch_nodes and budget_remaining > 0:
+                max_neighbors = len(batch_nodes) * 20
+                neighbors = await conn.fetch(
+                    """
+                    SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end,
+                           mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type,
+                           mu.document_id, mu.chunk_id,
+                           ml.weight, ml.link_type, ml.from_unit_id
+                    FROM memory_links ml
+                    JOIN memory_units mu ON ml.to_unit_id = mu.id
+                    WHERE ml.from_unit_id = ANY($1::uuid[])
+                      AND ml.weight >= $2
+                      AND mu.fact_type = $3
+                    ORDER BY ml.weight DESC
+                    LIMIT $4
+                    """,
+                    batch_nodes, self.min_activation, fact_type, max_neighbors
+                )
+
+                for n in neighbors:
+                    neighbor_id = str(n["id"])
+                    if neighbor_id not in visited:
+                        parent_id = str(n["from_unit_id"])
+                        parent_activation = batch_activations.get(parent_id, 0.5)
+
+                        # Boost causal links
+                        link_type = n["link_type"]
+                        base_weight = n["weight"]
+
+                        if link_type in ("causes", "caused_by"):
+                            causal_boost = 2.0
+                        elif link_type in ("enables", "prevents"):
+                            causal_boost = 1.5
+                        else:
+                            causal_boost = 1.0
+
+                        effective_weight = base_weight * causal_boost
+                        new_activation = parent_activation * effective_weight * self.activation_decay
+
+                        if new_activation > self.min_activation:
+                            neighbor_result = RetrievalResult.from_db_row(dict(n))
+                            queue.append((neighbor_result, new_activation))
+
+        return results
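
An arithmetic check of the activation update in the inner loop, using the constructor defaults:

```python
# new_activation = parent_activation * (weight * causal_boost) * activation_decay
parent_activation = 0.9  # e.g. an entry point's cosine similarity
weight = 0.7             # link weight from memory_links
causal_boost = 2.0       # "causes"/"caused_by" links
decay = 0.8              # activation_decay default

print(parent_activation * weight * causal_boost * decay)  # 1.008
# A strong causal link can leave a neighbor more activated than its parent;
# the spread only stops once activation falls to min_activation (0.1).
```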