hindsight-api 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (46)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +252 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/api/http.py +282 -20
  6. hindsight_api/api/mcp.py +47 -52
  7. hindsight_api/config.py +238 -6
  8. hindsight_api/engine/cross_encoder.py +599 -86
  9. hindsight_api/engine/db_budget.py +284 -0
  10. hindsight_api/engine/db_utils.py +11 -0
  11. hindsight_api/engine/embeddings.py +453 -26
  12. hindsight_api/engine/entity_resolver.py +8 -5
  13. hindsight_api/engine/interface.py +8 -4
  14. hindsight_api/engine/llm_wrapper.py +241 -27
  15. hindsight_api/engine/memory_engine.py +609 -122
  16. hindsight_api/engine/query_analyzer.py +4 -3
  17. hindsight_api/engine/response_models.py +38 -0
  18. hindsight_api/engine/retain/fact_extraction.py +388 -192
  19. hindsight_api/engine/retain/fact_storage.py +34 -8
  20. hindsight_api/engine/retain/link_utils.py +24 -16
  21. hindsight_api/engine/retain/orchestrator.py +52 -17
  22. hindsight_api/engine/retain/types.py +9 -0
  23. hindsight_api/engine/search/graph_retrieval.py +42 -13
  24. hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
  25. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  26. hindsight_api/engine/search/reranking.py +2 -2
  27. hindsight_api/engine/search/retrieval.py +847 -200
  28. hindsight_api/engine/search/tags.py +172 -0
  29. hindsight_api/engine/search/think_utils.py +1 -1
  30. hindsight_api/engine/search/trace.py +12 -0
  31. hindsight_api/engine/search/tracer.py +24 -1
  32. hindsight_api/engine/search/types.py +21 -0
  33. hindsight_api/engine/task_backend.py +109 -18
  34. hindsight_api/engine/utils.py +1 -1
  35. hindsight_api/extensions/context.py +10 -1
  36. hindsight_api/main.py +56 -4
  37. hindsight_api/metrics.py +433 -48
  38. hindsight_api/migrations.py +141 -1
  39. hindsight_api/models.py +3 -1
  40. hindsight_api/pg0.py +53 -0
  41. hindsight_api/server.py +39 -2
  42. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
  43. hindsight_api-0.3.0.dist-info/RECORD +82 -0
  44. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
  45. hindsight_api-0.2.0.dist-info/RECORD +0 -75
  46. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0
hindsight_api/engine/search/mpfp_retrieval.py +362 -117
@@ -9,6 +9,7 @@ propagation from Approximate PPR.
9
9
 
10
10
  Key properties:
11
11
  - Sublinear in graph size (threshold pruning bounds active nodes)
12
+ - Lazy edge loading: only loads edges for frontier nodes, not entire graph
12
13
  - Predefined patterns capture different retrieval intents
13
14
  - All patterns run in parallel, results fused via RRF
14
15
  - No LLM in the loop during traversal
@@ -22,7 +23,8 @@ from dataclasses import dataclass, field
22
23
  from ..db_utils import acquire_with_retry
23
24
  from ..memory_engine import fq_table
24
25
  from .graph_retrieval import GraphRetriever
25
- from .types import RetrievalResult
26
+ from .tags import TagsMatch
27
+ from .types import MPFPTimings, RetrievalResult
26
28
 
27
29
  logger = logging.getLogger(__name__)
28
30
 
@@ -41,11 +43,27 @@ class EdgeTarget:
41
43
 
42
44
 
43
45
  @dataclass
44
- class TypedAdjacency:
45
- """Adjacency lists split by edge type."""
46
+ class EdgeCache:
47
+ """
48
+ Cache for lazily-loaded edges.
49
+
50
+ Grows per-hop as edges are loaded for frontier nodes.
51
+ Shared across patterns to avoid redundant loads.
52
+ Loads ALL edge types at once to minimize DB queries.
53
+ Thread-safe via asyncio lock to prevent redundant concurrent loads.
54
+ """
46
55
 
47
- # edge_type -> from_node_id -> list of (to_node_id, weight)
56
+ # edge_type -> from_node_id -> list of EdgeTarget
48
57
  graphs: dict[str, dict[str, list[EdgeTarget]]] = field(default_factory=dict)
58
+ # Track which nodes have been fully loaded (all edge types)
59
+ _fully_loaded: set[str] = field(default_factory=set)
60
+ # Timing stats
61
+ db_queries: int = 0
62
+ edge_load_time: float = 0.0
63
+ # Detailed hop timing for debugging
64
+ hop_details: list[dict] = field(default_factory=list)
65
+ # Lock to prevent redundant concurrent loads
66
+ _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
49
67
 
50
68
  def get_neighbors(self, edge_type: str, node_id: str) -> list[EdgeTarget]:
51
69
  """Get neighbors for a node via a specific edge type."""
@@ -63,6 +81,31 @@ class TypedAdjacency:
63
81
 
64
82
  return [EdgeTarget(node_id=n.node_id, weight=n.weight / total) for n in neighbors]
65
83
 
84
+ def is_fully_loaded(self, node_id: str) -> bool:
85
+ """Check if all edges for this node have been loaded."""
86
+ return node_id in self._fully_loaded
87
+
88
+ def get_uncached(self, node_ids: list[str]) -> list[str]:
89
+ """Get node IDs that haven't been fully loaded yet."""
90
+ return [n for n in node_ids if not self.is_fully_loaded(n)]
91
+
92
+ def add_all_edges(self, edges_by_type: dict[str, dict[str, list[EdgeTarget]]], all_queried: list[str]):
93
+ """
94
+ Add loaded edges to the cache (all edge types at once).
95
+
96
+ Args:
97
+ edges_by_type: Dict mapping edge_type -> from_node_id -> list of EdgeTarget
98
+ all_queried: All node IDs that were queried (marks them as fully loaded)
99
+ """
100
+ for edge_type, edges in edges_by_type.items():
101
+ if edge_type not in self.graphs:
102
+ self.graphs[edge_type] = {}
103
+ for node_id, neighbors in edges.items():
104
+ self.graphs[edge_type][node_id] = neighbors
105
+
106
+ # Mark all queried nodes as fully loaded (even if they have no edges)
107
+ self._fully_loaded.update(all_queried)
108
+
66
109
 
67
110
  @dataclass
68
111
  class PatternResult:
@@ -109,66 +152,249 @@ class SeedNode:
109
152
 
110
153
 
111
154
  # -----------------------------------------------------------------------------
112
- # Core Algorithm
155
+ # Lazy Edge Loading
113
156
  # -----------------------------------------------------------------------------
114
157
 
115
158
 
116
- def mpfp_traverse(
117
- seeds: list[SeedNode],
118
- pattern: list[str],
119
- adjacency: TypedAdjacency,
120
- config: MPFPConfig,
121
- ) -> PatternResult:
159
+ async def load_all_edges_for_frontier(
160
+ pool,
161
+ node_ids: list[str],
162
+ top_k_per_type: int = 20,
163
+ ) -> dict[str, dict[str, list[EdgeTarget]]]:
122
164
  """
123
- Forward Push traversal following a meta-path pattern.
165
+ Load top-k edges per (node, edge_type) for frontier nodes.
166
+
167
+ Uses a LATERAL join to efficiently fetch only the top-k edges per type,
168
+ avoiding loading hundreds of entity edges when only 20 are needed.
169
+
170
+ Requires composite index: (from_unit_id, link_type, weight DESC)
124
171
 
125
172
  Args:
126
- seeds: Entry point nodes with initial scores
127
- pattern: Sequence of edge types to follow
128
- adjacency: Typed adjacency structure
129
- config: Algorithm parameters
173
+ pool: Database connection pool
174
+ node_ids: Frontier node IDs to load edges for
175
+ top_k_per_type: Max edges to load per (node, link_type) pair
130
176
 
131
177
  Returns:
132
- PatternResult with accumulated scores per node
178
+ Dict mapping edge_type -> from_node_id -> list of EdgeTarget
133
179
  """
134
- if not seeds:
135
- return PatternResult(pattern=pattern, scores={})
180
+ if not node_ids:
181
+ return {}
182
+
183
+ async with acquire_with_retry(pool) as conn:
184
+ # Use LATERAL join to get top-k per (from_node, link_type)
185
+ # This leverages the composite index for efficient early termination
186
+ rows = await conn.fetch(
187
+ f"""
188
+ WITH frontier(node_id) AS (SELECT unnest($1::uuid[]))
189
+ SELECT f.node_id as from_unit_id, lt.link_type, edges.to_unit_id, edges.weight
190
+ FROM frontier f
191
+ CROSS JOIN (VALUES ('semantic'), ('temporal'), ('entity'), ('causes'), ('caused_by')) AS lt(link_type)
192
+ CROSS JOIN LATERAL (
193
+ SELECT ml.to_unit_id, ml.weight
194
+ FROM {fq_table("memory_links")} ml
195
+ WHERE ml.from_unit_id = f.node_id
196
+ AND ml.link_type = lt.link_type
197
+ AND ml.weight >= 0.1
198
+ ORDER BY ml.weight DESC
199
+ LIMIT $2
200
+ ) edges
201
+ """,
202
+ node_ids,
203
+ top_k_per_type,
204
+ )
205
+
206
+ # Group by edge_type -> from_node -> neighbors
207
+ result: dict[str, dict[str, list[EdgeTarget]]] = defaultdict(lambda: defaultdict(list))
208
+ for row in rows:
209
+ edge_type = row["link_type"]
210
+ from_id = str(row["from_unit_id"])
211
+ to_id = str(row["to_unit_id"])
212
+ weight = row["weight"]
213
+ result[edge_type][from_id].append(EdgeTarget(node_id=to_id, weight=weight))
214
+
215
+ # Convert nested defaultdicts to regular dicts
216
+ return {edge_type: dict(edges) for edge_type, edges in result.items()}
217
+
218
+
219
+ # -----------------------------------------------------------------------------
220
+ # Core Algorithm (Async with Lazy Loading)
221
+ # -----------------------------------------------------------------------------
222
+
223
+
224
+ @dataclass
225
+ class PatternState:
226
+ """State for a pattern traversal between hops."""
227
+
228
+ pattern: list[str]
229
+ hop_index: int
230
+ scores: dict[str, float]
231
+ frontier: dict[str, float]
136
232
 
137
- scores: dict[str, float] = {}
138
233
 
139
- # Initialize frontier with seed masses (normalized)
234
+ def _init_pattern_state(seeds: list[SeedNode], pattern: list[str]) -> PatternState:
235
+ """Initialize pattern state from seeds."""
236
+ if not seeds:
237
+ return PatternState(pattern=pattern, hop_index=0, scores={}, frontier={})
238
+
140
239
  total_seed_score = sum(s.score for s in seeds)
141
240
  if total_seed_score == 0:
142
- total_seed_score = len(seeds) # fallback to uniform
241
+ total_seed_score = len(seeds)
242
+
243
+ frontier = {s.node_id: s.score / total_seed_score for s in seeds}
244
+ return PatternState(pattern=pattern, hop_index=0, scores={}, frontier=frontier)
245
+
143
246
 
144
- frontier: dict[str, float] = {s.node_id: s.score / total_seed_score for s in seeds}
247
+ def _execute_hop(state: PatternState, cache: EdgeCache, config: MPFPConfig) -> set[str]:
248
+ """
249
+ Execute ONE hop of traversal, return frontier nodes for next hop.
250
+
251
+ This is a pure function that uses cached edges (no DB access).
252
+ Returns set of uncached nodes needed for next hop.
253
+ """
254
+ if state.hop_index >= len(state.pattern):
255
+ return set()
145
256
 
146
- # Follow pattern hop by hop
147
- for edge_type in pattern:
148
- next_frontier: dict[str, float] = {}
257
+ edge_type = state.pattern[state.hop_index]
149
258
 
150
- for node_id, mass in frontier.items():
151
- if mass < config.threshold:
152
- continue
259
+ # Collect active nodes above threshold
260
+ active_nodes = [node_id for node_id, mass in state.frontier.items() if mass >= config.threshold]
261
+ if not active_nodes:
262
+ state.frontier = {}
263
+ return set()
153
264
 
154
- # Keep α portion for this node
155
- scores[node_id] = scores.get(node_id, 0) + config.alpha * mass
265
+ # Propagate mass using cached edges
266
+ next_frontier: dict[str, float] = {}
267
+ uncached_for_next: set[str] = set()
156
268
 
157
- # Push (1-α) to neighbors
158
- push_mass = (1 - config.alpha) * mass
159
- neighbors = adjacency.get_normalized_neighbors(edge_type, node_id, config.top_k_neighbors)
269
+ for node_id, mass in state.frontier.items():
270
+ if mass < config.threshold:
271
+ continue
160
272
 
161
- for neighbor in neighbors:
162
- next_frontier[neighbor.node_id] = next_frontier.get(neighbor.node_id, 0) + push_mass * neighbor.weight
273
+ # Keep α portion for this node
274
+ state.scores[node_id] = state.scores.get(node_id, 0) + config.alpha * mass
163
275
 
164
- frontier = next_frontier
276
+ # Push (1-α) to neighbors
277
+ push_mass = (1 - config.alpha) * mass
278
+ neighbors = cache.get_normalized_neighbors(edge_type, node_id, config.top_k_neighbors)
165
279
 
166
- # Final frontier nodes get their remaining mass
167
- for node_id, mass in frontier.items():
280
+ for neighbor in neighbors:
281
+ next_frontier[neighbor.node_id] = next_frontier.get(neighbor.node_id, 0) + push_mass * neighbor.weight
282
+ # Track if we'll need edges for this node in the next hop
283
+ if not cache.is_fully_loaded(neighbor.node_id):
284
+ uncached_for_next.add(neighbor.node_id)
285
+
286
+ state.frontier = next_frontier
287
+ state.hop_index += 1
288
+
289
+ return uncached_for_next
290
+
291
+
292
+ def _finalize_pattern(state: PatternState, config: MPFPConfig) -> PatternResult:
293
+ """Finalize pattern by adding remaining frontier mass to scores."""
294
+ for node_id, mass in state.frontier.items():
168
295
  if mass >= config.threshold:
169
- scores[node_id] = scores.get(node_id, 0) + mass
296
+ state.scores[node_id] = state.scores.get(node_id, 0) + mass
297
+
298
+ return PatternResult(pattern=state.pattern, scores=state.scores)
299
+
300
+
301
+ async def mpfp_traverse_hop_synchronized(
302
+ pool,
303
+ pattern_jobs: list[tuple[list[SeedNode], list[str]]],
304
+ config: MPFPConfig,
305
+ cache: EdgeCache,
306
+ ) -> list[PatternResult]:
307
+ """
308
+ Execute ALL patterns with hop-synchronized edge loading.
309
+
310
+ Instead of running each pattern independently (causing multiple DB queries),
311
+ this function:
312
+ 1. Runs hop 1 for ALL patterns (using pre-warmed seed edges)
313
+ 2. Collects ALL unique hop-2 frontier nodes across patterns
314
+ 3. Pre-warms hop-2 edges in ONE query
315
+ 4. Runs hop 2 for ALL patterns
316
+
317
+ This reduces DB queries from O(patterns * hops) to O(hops).
318
+
319
+ Args:
320
+ pool: Database connection pool
321
+ pattern_jobs: List of (seeds, pattern) tuples
322
+ config: Algorithm parameters
323
+ cache: Shared edge cache (should be pre-warmed with seed edges)
324
+
325
+ Returns:
326
+ List of PatternResult for each pattern
327
+ """
328
+ import time
329
+
330
+ # Initialize all pattern states
331
+ states = [_init_pattern_state(seeds, pattern) for seeds, pattern in pattern_jobs]
332
+
333
+ # Determine max hops (all patterns should be same length, but be safe)
334
+ max_hops = max((len(p) for _, p in pattern_jobs), default=0)
335
+
336
+ # Detailed timing for debugging
337
+ hop_times: list[dict] = []
338
+
339
+ # Execute hop-by-hop across ALL patterns
340
+ for hop in range(max_hops):
341
+ hop_start = time.time()
342
+ hop_timing = {"hop": hop, "patterns_executed": 0, "uncached_count": 0, "load_time": 0.0}
343
+
344
+ # Execute this hop for all patterns, collect uncached nodes for next hop
345
+ all_uncached: set[str] = set()
346
+ exec_start = time.time()
347
+ for state in states:
348
+ if state.hop_index < len(state.pattern):
349
+ uncached = _execute_hop(state, cache, config)
350
+ all_uncached.update(uncached)
351
+ hop_timing["patterns_executed"] += 1
352
+ hop_timing["exec_time"] = time.time() - exec_start
353
+
354
+ # Pre-warm edges for ALL uncached nodes before next hop
355
+ hop_timing["uncached_count"] = len(all_uncached)
356
+ if all_uncached:
357
+ uncached_list = list(all_uncached - cache._fully_loaded)
358
+ hop_timing["uncached_after_filter"] = len(uncached_list)
359
+ if uncached_list:
360
+ load_start = time.time()
361
+ edges_by_type = await load_all_edges_for_frontier(pool, uncached_list, config.top_k_neighbors)
362
+ hop_timing["load_time"] = time.time() - load_start
363
+ cache.edge_load_time += hop_timing["load_time"]
364
+ cache.db_queries += 1
365
+ cache.add_all_edges(edges_by_type, uncached_list)
366
+ hop_timing["edges_loaded"] = sum(
367
+ len(neighbors) for edges in edges_by_type.values() for neighbors in edges.values()
368
+ )
369
+
370
+ hop_timing["total_time"] = time.time() - hop_start
371
+ hop_times.append(hop_timing)
372
+
373
+ # Store hop timing details in cache for logging
374
+ cache.hop_details = hop_times
375
+
376
+ # Finalize all patterns
377
+ return [_finalize_pattern(state, config) for state in states]
170
378
 
171
- return PatternResult(pattern=pattern, scores=scores)
379
+
380
+ async def mpfp_traverse_async(
381
+ pool,
382
+ seeds: list[SeedNode],
383
+ pattern: list[str],
384
+ config: MPFPConfig,
385
+ cache: EdgeCache,
386
+ ) -> PatternResult:
387
+ """
388
+ Async Forward Push traversal with lazy edge loading.
389
+
390
+ NOTE: For better performance with multiple patterns, use mpfp_traverse_hop_synchronized().
391
+ This function is kept for single-pattern use cases.
392
+ """
393
+ if not seeds:
394
+ return PatternResult(pattern=pattern, scores={})
395
+
396
+ results = await mpfp_traverse_hop_synchronized(pool, [(seeds, pattern)], config, cache)
397
+ return results[0] if results else PatternResult(pattern=pattern, scores={})
172
398
 
173
399
 
174
400
  def rrf_fusion(
@@ -210,38 +436,6 @@ def rrf_fusion(
210
436
  # -----------------------------------------------------------------------------
211
437
 
212
438
 
213
- async def load_typed_adjacency(pool, bank_id: str) -> TypedAdjacency:
214
- """
215
- Load all edges for a bank, split by edge type.
216
-
217
- Single query, then organize in-memory for fast traversal.
218
- """
219
- async with acquire_with_retry(pool) as conn:
220
- rows = await conn.fetch(
221
- f"""
222
- SELECT ml.from_unit_id, ml.to_unit_id, ml.link_type, ml.weight
223
- FROM {fq_table("memory_links")} ml
224
- JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
225
- WHERE mu.bank_id = $1
226
- AND ml.weight >= 0.1
227
- ORDER BY ml.from_unit_id, ml.weight DESC
228
- """,
229
- bank_id,
230
- )
231
-
232
- graphs: dict[str, dict[str, list[EdgeTarget]]] = defaultdict(lambda: defaultdict(list))
233
-
234
- for row in rows:
235
- from_id = str(row["from_unit_id"])
236
- to_id = str(row["to_unit_id"])
237
- link_type = row["link_type"]
238
- weight = row["weight"]
239
-
240
- graphs[link_type][from_id].append(EdgeTarget(node_id=to_id, weight=weight))
241
-
242
- return TypedAdjacency(graphs=dict(graphs))
243
-
244
-
245
439
  async def fetch_memory_units_by_ids(
246
440
  pool,
247
441
  node_ids: list[str],
@@ -255,7 +449,7 @@ async def fetch_memory_units_by_ids(
255
449
  rows = await conn.fetch(
256
450
  f"""
257
451
  SELECT id, text, context, event_date, occurred_start, occurred_end,
258
- mentioned_at, access_count, embedding, fact_type, document_id, chunk_id
452
+ mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags
259
453
  FROM {fq_table("memory_units")}
260
454
  WHERE id = ANY($1::uuid[])
261
455
  AND fact_type = $2
@@ -274,10 +468,10 @@ async def fetch_memory_units_by_ids(
274
468
 
275
469
  class MPFPGraphRetriever(GraphRetriever):
276
470
  """
277
- Graph retrieval using Meta-Path Forward Push.
471
+ Graph retrieval using Meta-Path Forward Push with lazy edge loading.
278
472
 
279
473
  Runs predefined patterns in parallel from semantic and temporal seeds,
280
- then fuses results via RRF.
474
+ loading edges on-demand per hop instead of loading entire graph upfront.
281
475
  """
282
476
 
283
477
  def __init__(self, config: MPFPConfig | None = None):
@@ -287,8 +481,13 @@ class MPFPGraphRetriever(GraphRetriever):
287
481
  Args:
288
482
  config: Algorithm configuration (uses defaults if None)
289
483
  """
290
- self.config = config or MPFPConfig()
291
- self._adjacency_cache: dict[str, TypedAdjacency] = {}
484
+ if config is None:
485
+ # Read top_k_neighbors from global config
486
+ from ...config import get_config
487
+
488
+ global_config = get_config()
489
+ config = MPFPConfig(top_k_neighbors=global_config.mpfp_top_k_neighbors)
490
+ self.config = config
292
491
 
293
492
  @property
294
493
  def name(self) -> str:
@@ -304,9 +503,12 @@ class MPFPGraphRetriever(GraphRetriever):
304
503
  query_text: str | None = None,
305
504
  semantic_seeds: list[RetrievalResult] | None = None,
306
505
  temporal_seeds: list[RetrievalResult] | None = None,
307
- ) -> list[RetrievalResult]:
506
+ adjacency=None, # Ignored - kept for interface compatibility
507
+ tags: list[str] | None = None,
508
+ tags_match: TagsMatch = "any",
509
+ ) -> tuple[list[RetrievalResult], MPFPTimings | None]:
308
510
  """
309
- Retrieve facts using MPFP algorithm.
511
+ Retrieve facts using MPFP algorithm with lazy edge loading.
310
512
 
311
513
  Args:
312
514
  pool: Database connection pool
@@ -317,12 +519,15 @@ class MPFPGraphRetriever(GraphRetriever):
317
519
  query_text: Original query text (optional)
318
520
  semantic_seeds: Pre-computed semantic entry points
319
521
  temporal_seeds: Pre-computed temporal entry points
522
+ adjacency: Ignored (kept for interface compatibility)
523
+ tags: Optional list of tags for visibility filtering (OR matching)
320
524
 
321
525
  Returns:
322
- List of RetrievalResult with activation scores
526
+ Tuple of (List of RetrievalResult with activation scores, MPFPTimings)
323
527
  """
324
- # Load typed adjacency (could cache per bank_id with TTL)
325
- adjacency = await load_typed_adjacency(pool, bank_id)
528
+ import time
529
+
530
+ timings = MPFPTimings(fact_type=fact_type)
326
531
 
327
532
  # Convert seeds to SeedNode format
328
533
  semantic_seed_nodes = self._convert_seeds(semantic_seeds, "similarity")
@@ -330,54 +535,88 @@ class MPFPGraphRetriever(GraphRetriever):
330
535
 
331
536
  # If no semantic seeds provided, fall back to finding our own
332
537
  if not semantic_seed_nodes:
333
- semantic_seed_nodes = await self._find_semantic_seeds(pool, query_embedding_str, bank_id, fact_type)
538
+ seeds_start = time.time()
539
+ semantic_seed_nodes = await self._find_semantic_seeds(
540
+ pool, query_embedding_str, bank_id, fact_type, tags=tags, tags_match=tags_match
541
+ )
542
+ timings.seeds_time = time.time() - seeds_start
543
+ logger.debug(
544
+ f"[MPFP] Found {len(semantic_seed_nodes)} semantic seeds for fact_type={fact_type} (tags={tags}, tags_match={tags_match})"
545
+ )
334
546
 
335
- # Run all patterns in parallel
336
- tasks = []
547
+ # Collect all pattern jobs
548
+ pattern_jobs = []
337
549
 
338
550
  # Patterns from semantic seeds
339
551
  for pattern in self.config.patterns_semantic:
340
552
  if semantic_seed_nodes:
341
- tasks.append(
342
- asyncio.to_thread(
343
- mpfp_traverse,
344
- semantic_seed_nodes,
345
- pattern,
346
- adjacency,
347
- self.config,
348
- )
349
- )
553
+ pattern_jobs.append((semantic_seed_nodes, pattern))
350
554
 
351
555
  # Patterns from temporal seeds
352
556
  for pattern in self.config.patterns_temporal:
353
557
  if temporal_seed_nodes:
354
- tasks.append(
355
- asyncio.to_thread(
356
- mpfp_traverse,
357
- temporal_seed_nodes,
358
- pattern,
359
- adjacency,
360
- self.config,
361
- )
362
- )
363
-
364
- if not tasks:
365
- return []
558
+ pattern_jobs.append((temporal_seed_nodes, pattern))
366
559
 
367
- # Gather pattern results
368
- pattern_results = await asyncio.gather(*tasks)
560
+ if not pattern_jobs:
561
+ logger.debug(
562
+ f"[MPFP] No pattern jobs (semantic_seeds={len(semantic_seed_nodes)}, temporal_seeds={len(temporal_seed_nodes)})"
563
+ )
564
+ return [], timings
565
+
566
+ timings.pattern_count = len(pattern_jobs)
567
+
568
+ # Shared edge cache across all patterns
569
+ cache = EdgeCache()
570
+
571
+ # Pre-warm cache with ALL seed node edges BEFORE running patterns
572
+ # This prevents redundant DB queries at hop 1
573
+ all_seed_ids = list({s.node_id for seeds, _ in pattern_jobs for s in seeds})
574
+ if all_seed_ids:
575
+ import time as time_module
576
+
577
+ prewarm_start = time_module.time()
578
+ edges_by_type = await load_all_edges_for_frontier(pool, all_seed_ids, self.config.top_k_neighbors)
579
+ cache.edge_load_time += time_module.time() - prewarm_start
580
+ cache.db_queries += 1
581
+ cache.add_all_edges(edges_by_type, all_seed_ids)
582
+
583
+ # Run all patterns with HOP-SYNCHRONIZED edge loading
584
+ # This batches hop-2 edge loads across ALL patterns into ONE query
585
+ # Reduces DB queries from O(patterns * hops) to O(hops)
586
+ step_start = time.time()
587
+ pattern_results = await mpfp_traverse_hop_synchronized(pool, pattern_jobs, self.config, cache)
588
+ timings.traverse = time.time() - step_start
589
+
590
+ # Record edge loading stats from cache
591
+ timings.edge_count = sum(len(neighbors) for g in cache.graphs.values() for neighbors in g.values())
592
+ timings.db_queries = cache.db_queries
593
+ timings.edge_load_time = cache.edge_load_time
594
+ timings.hop_details = cache.hop_details
369
595
 
370
596
  # Fuse results
597
+ step_start = time.time()
371
598
  fused = rrf_fusion(pattern_results, top_k=budget)
599
+ timings.fusion = time.time() - step_start
372
600
 
373
601
  if not fused:
374
- return []
602
+ logger.debug(f"[MPFP] No fused results after RRF fusion (pattern_count={len(pattern_results)})")
603
+ return [], timings
375
604
 
376
- # Get top result IDs (don't exclude seeds - they may be highly relevant)
605
+ # Get top result IDs
377
606
  result_ids = [node_id for node_id, score in fused][:budget]
378
607
 
379
608
  # Fetch full details
609
+ step_start = time.time()
380
610
  results = await fetch_memory_units_by_ids(pool, result_ids, fact_type)
611
+ timings.fetch = time.time() - step_start
612
+
613
+ # Filter results by tags (graph traversal may have picked up unfiltered memories)
614
+ if tags:
615
+ from .tags import filter_results_by_tags
616
+
617
+ results = filter_results_by_tags(results, tags, match=tags_match)
618
+
619
+ timings.result_count = len(results)
381
620
 
382
621
  # Add activation scores from fusion
383
622
  score_map = {node_id: score for node_id, score in fused}
@@ -387,7 +626,7 @@ class MPFPGraphRetriever(GraphRetriever):
387
626
  # Sort by activation
388
627
  results.sort(key=lambda r: r.activation or 0, reverse=True)
389
628
 
390
- return results
629
+ return results, timings
391
630
 
392
631
  def _convert_seeds(
393
632
  self,
@@ -415,8 +654,17 @@ class MPFPGraphRetriever(GraphRetriever):
415
654
  fact_type: str,
416
655
  limit: int = 20,
417
656
  threshold: float = 0.3,
657
+ tags: list[str] | None = None,
658
+ tags_match: TagsMatch = "any",
418
659
  ) -> list[SeedNode]:
419
660
  """Fallback: find semantic seeds via embedding search."""
661
+ from .tags import build_tags_where_clause_simple
662
+
663
+ tags_clause = build_tags_where_clause_simple(tags, 6, match=tags_match)
664
+ params = [query_embedding_str, bank_id, fact_type, threshold, limit]
665
+ if tags:
666
+ params.append(tags)
667
+
420
668
  async with acquire_with_retry(pool) as conn:
421
669
  rows = await conn.fetch(
422
670
  f"""
@@ -426,14 +674,11 @@ class MPFPGraphRetriever(GraphRetriever):
426
674
  AND embedding IS NOT NULL
427
675
  AND fact_type = $3
428
676
  AND (1 - (embedding <=> $1::vector)) >= $4
677
+ {tags_clause}
429
678
  ORDER BY embedding <=> $1::vector
430
679
  LIMIT $5
431
680
  """,
432
- query_embedding_str,
433
- bank_id,
434
- fact_type,
435
- threshold,
436
- limit,
681
+ *params,
437
682
  )
438
683
 
439
684
  return [SeedNode(node_id=str(r["id"]), score=r["similarity"]) for r in rows]
hindsight_api/engine/search/reranking.py +2 -2
@@ -44,7 +44,7 @@ class CrossEncoderReranker:
44
44
  await cross_encoder.initialize()
45
45
  self._initialized = True
46
46
 
47
- def rerank(self, query: str, candidates: list[MergedCandidate]) -> list[ScoredResult]:
47
+ async def rerank(self, query: str, candidates: list[MergedCandidate]) -> list[ScoredResult]:
48
48
  """
49
49
  Rerank candidates using cross-encoder scores.
50
50
 
@@ -85,7 +85,7 @@ class CrossEncoderReranker:
85
85
  pairs.append([query, doc_text])
86
86
 
87
87
  # Get cross-encoder scores
88
- scores = self.cross_encoder.predict(pairs)
88
+ scores = await self.cross_encoder.predict(pairs)
89
89
 
90
90
  # Normalize scores using sigmoid to [0, 1] range
91
91
  # Cross-encoder returns logits which can be negative