hindsight-api 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (46)
  1. hindsight_api/admin/__init__.py +1 -0
  2. hindsight_api/admin/cli.py +252 -0
  3. hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
  4. hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
  5. hindsight_api/api/http.py +282 -20
  6. hindsight_api/api/mcp.py +47 -52
  7. hindsight_api/config.py +238 -6
  8. hindsight_api/engine/cross_encoder.py +599 -86
  9. hindsight_api/engine/db_budget.py +284 -0
  10. hindsight_api/engine/db_utils.py +11 -0
  11. hindsight_api/engine/embeddings.py +453 -26
  12. hindsight_api/engine/entity_resolver.py +8 -5
  13. hindsight_api/engine/interface.py +8 -4
  14. hindsight_api/engine/llm_wrapper.py +241 -27
  15. hindsight_api/engine/memory_engine.py +609 -122
  16. hindsight_api/engine/query_analyzer.py +4 -3
  17. hindsight_api/engine/response_models.py +38 -0
  18. hindsight_api/engine/retain/fact_extraction.py +388 -192
  19. hindsight_api/engine/retain/fact_storage.py +34 -8
  20. hindsight_api/engine/retain/link_utils.py +24 -16
  21. hindsight_api/engine/retain/orchestrator.py +52 -17
  22. hindsight_api/engine/retain/types.py +9 -0
  23. hindsight_api/engine/search/graph_retrieval.py +42 -13
  24. hindsight_api/engine/search/link_expansion_retrieval.py +256 -0
  25. hindsight_api/engine/search/mpfp_retrieval.py +362 -117
  26. hindsight_api/engine/search/reranking.py +2 -2
  27. hindsight_api/engine/search/retrieval.py +847 -200
  28. hindsight_api/engine/search/tags.py +172 -0
  29. hindsight_api/engine/search/think_utils.py +1 -1
  30. hindsight_api/engine/search/trace.py +12 -0
  31. hindsight_api/engine/search/tracer.py +24 -1
  32. hindsight_api/engine/search/types.py +21 -0
  33. hindsight_api/engine/task_backend.py +109 -18
  34. hindsight_api/engine/utils.py +1 -1
  35. hindsight_api/extensions/context.py +10 -1
  36. hindsight_api/main.py +56 -4
  37. hindsight_api/metrics.py +433 -48
  38. hindsight_api/migrations.py +141 -1
  39. hindsight_api/models.py +3 -1
  40. hindsight_api/pg0.py +53 -0
  41. hindsight_api/server.py +39 -2
  42. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/METADATA +5 -1
  43. hindsight_api-0.3.0.dist-info/RECORD +82 -0
  44. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/entry_points.txt +1 -0
  45. hindsight_api-0.2.0.dist-info/RECORD +0 -75
  46. {hindsight_api-0.2.0.dist-info → hindsight_api-0.3.0.dist-info}/WHEEL +0 -0
@@ -18,8 +18,10 @@ from ...config import get_config
18
18
  from ..db_utils import acquire_with_retry
19
19
  from ..memory_engine import fq_table
20
20
  from .graph_retrieval import BFSGraphRetriever, GraphRetriever
21
+ from .link_expansion_retrieval import LinkExpansionRetriever
21
22
  from .mpfp_retrieval import MPFPGraphRetriever
22
- from .types import RetrievalResult
23
+ from .tags import TagsMatch, build_tags_where_clause_simple
24
+ from .types import MPFPTimings, RetrievalResult
23
25
 
24
26
  logger = logging.getLogger(__name__)
25
27
 
@@ -34,6 +36,20 @@ class ParallelRetrievalResult:
34
36
  temporal: list[RetrievalResult] | None
35
37
  timings: dict[str, float] = field(default_factory=dict)
36
38
  temporal_constraint: tuple | None = None # (start_date, end_date)
39
+ mpfp_timings: list[MPFPTimings] = field(default_factory=list) # MPFP sub-step timings per fact type
40
+ max_conn_wait: float = 0.0 # Maximum connection acquisition wait time across all methods
41
+
42
+
43
+ @dataclass
44
+ class MultiFactTypeRetrievalResult:
45
+ """Result from retrieval across all fact types."""
46
+
47
+ # Results per fact type
48
+ results_by_fact_type: dict[str, ParallelRetrievalResult]
49
+ # Aggregate timings
50
+ timings: dict[str, float] = field(default_factory=dict)
51
+ # Max connection wait across all operations
52
+ max_conn_wait: float = 0.0
37
53
 
38
54
 
39
55
  # Default graph retriever instance (can be overridden)
@@ -48,13 +64,18 @@ def get_default_graph_retriever() -> GraphRetriever:
48
64
  retriever_type = config.graph_retriever.lower()
49
65
  if retriever_type == "mpfp":
50
66
  _default_graph_retriever = MPFPGraphRetriever()
51
- logger.info("Using MPFP graph retriever")
67
+ logger.info(
68
+ f"Using MPFP graph retriever (top_k_neighbors={_default_graph_retriever.config.top_k_neighbors})"
69
+ )
52
70
  elif retriever_type == "bfs":
53
71
  _default_graph_retriever = BFSGraphRetriever()
54
72
  logger.info("Using BFS graph retriever")
73
+ elif retriever_type == "link_expansion":
74
+ _default_graph_retriever = LinkExpansionRetriever()
75
+ logger.info("Using LinkExpansion graph retriever")
55
76
  else:
56
- logger.warning(f"Unknown graph retriever '{retriever_type}', falling back to MPFP")
57
- _default_graph_retriever = MPFPGraphRetriever()
77
+ logger.warning(f"Unknown graph retriever '{retriever_type}', falling back to link_expansion")
78
+ _default_graph_retriever = LinkExpansionRetriever()
58
79
  return _default_graph_retriever
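
The selector above keys off config.graph_retriever ("mpfp", "bfs", or "link_expansion") and now falls back to LinkExpansionRetriever. A minimal override sketch, assuming the module paths match the files-changed list above and using the set_default_graph_retriever hook shown in the next hunk:

# Pin the retriever explicitly instead of relying on config.graph_retriever;
# module paths are taken from the files-changed list and may differ in practice.
from hindsight_api.engine.search.link_expansion_retrieval import LinkExpansionRetriever
from hindsight_api.engine.search.retrieval import set_default_graph_retriever

set_default_graph_retriever(LinkExpansionRetriever())
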
59
80
 
60
81
 
@@ -65,7 +86,12 @@ def set_default_graph_retriever(retriever: GraphRetriever) -> None:
65
86
 
66
87
 
67
88
  async def retrieve_semantic(
68
- conn, query_emb_str: str, bank_id: str, fact_type: str, limit: int
89
+ conn,
90
+ query_emb_str: str,
91
+ bank_id: str,
92
+ fact_type: str,
93
+ limit: int,
94
+ tags: list[str] | None = None,
69
95
  ) -> list[RetrievalResult]:
70
96
  """
71
97
  Semantic retrieval via vector similarity.
@@ -76,31 +102,44 @@ async def retrieve_semantic(
76
102
  agent_id: bank ID
77
103
  fact_type: Fact type to filter
78
104
  limit: Maximum results to return
105
+ tags: Optional list of tags for visibility filtering (OR matching)
79
106
 
80
107
  Returns:
81
108
  List of RetrievalResult objects
82
109
  """
110
+ from .tags import TagsMatch, build_tags_where_clause_simple
111
+
112
+ tags_clause = build_tags_where_clause_simple(tags, 5)
113
+ params = [query_emb_str, bank_id, fact_type, limit]
114
+ if tags:
115
+ params.append(tags)
116
+
83
117
  results = await conn.fetch(
84
118
  f"""
85
- SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
119
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
86
120
  1 - (embedding <=> $1::vector) AS similarity
87
121
  FROM {fq_table("memory_units")}
88
122
  WHERE bank_id = $2
89
123
  AND embedding IS NOT NULL
90
124
  AND fact_type = $3
91
125
  AND (1 - (embedding <=> $1::vector)) >= 0.3
126
+ {tags_clause}
92
127
  ORDER BY embedding <=> $1::vector
93
128
  LIMIT $4
94
129
  """,
95
- query_emb_str,
96
- bank_id,
97
- fact_type,
98
- limit,
130
+ *params,
99
131
  )
100
132
  return [RetrievalResult.from_db_row(dict(r)) for r in results]
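
retrieve_semantic appends the tags parameter only when tags are supplied and splices the generated clause into the WHERE block. tags.py itself is not shown in this diff, so the following is an assumed sketch of what build_tags_where_clause_simple plausibly returns, consistent with how it is called here (parameter index, optional table alias, "any" vs "all" matching):

# Assumed behavior (tags.py not included in this excerpt): empty string when no tags
# are given, otherwise an AND clause bound to the supplied parameter index; "any" uses
# array overlap (&&) and "all" uses containment (@>) on a text[] tags column.
def build_tags_where_clause_simple(tags, param_index, table_alias="", match="any"):
    if not tags:
        return ""
    op = "&&" if match == "any" else "@>"
    return f"AND {table_alias}tags {op} ${param_index}::text[]"
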
101
133
 
102
134
 
103
- async def retrieve_bm25(conn, query_text: str, bank_id: str, fact_type: str, limit: int) -> list[RetrievalResult]:
135
+ async def retrieve_bm25(
136
+ conn,
137
+ query_text: str,
138
+ bank_id: str,
139
+ fact_type: str,
140
+ limit: int,
141
+ tags: list[str] | None = None,
142
+ ) -> list[RetrievalResult]:
104
143
  """
105
144
  BM25 keyword retrieval via full-text search.
106
145
 
@@ -110,12 +149,15 @@ async def retrieve_bm25(conn, query_text: str, bank_id: str, fact_type: str, lim
110
149
  agent_id: bank ID
111
150
  fact_type: Fact type to filter
112
151
  limit: Maximum results to return
152
+ tags: Optional list of tags for visibility filtering (OR matching)
113
153
 
114
154
  Returns:
115
155
  List of RetrievalResult objects
116
156
  """
117
157
  import re
118
158
 
159
+ from .tags import TagsMatch, build_tags_where_clause_simple
160
+
119
161
  # Sanitize query text: remove special characters that have meaning in tsquery
120
162
  # Keep only alphanumeric characters and spaces
121
163
  sanitized_text = re.sub(r"[^\w\s]", " ", query_text.lower())
@@ -131,25 +173,394 @@ async def retrieve_bm25(conn, query_text: str, bank_id: str, fact_type: str, lim
131
173
  # This prevents empty results when some terms are missing
132
174
  query_tsquery = " | ".join(tokens)
133
175
 
176
+ tags_clause = build_tags_where_clause_simple(tags, 5)
177
+ params = [query_tsquery, bank_id, fact_type, limit]
178
+ if tags:
179
+ params.append(tags)
180
+
134
181
  results = await conn.fetch(
135
182
  f"""
136
- SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
183
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
137
184
  ts_rank_cd(search_vector, to_tsquery('english', $1)) AS bm25_score
138
185
  FROM {fq_table("memory_units")}
139
186
  WHERE bank_id = $2
140
187
  AND fact_type = $3
141
188
  AND search_vector @@ to_tsquery('english', $1)
189
+ {tags_clause}
142
190
  ORDER BY bm25_score DESC
143
191
  LIMIT $4
144
192
  """,
145
- query_tsquery,
146
- bank_id,
147
- fact_type,
148
- limit,
193
+ *params,
149
194
  )
150
195
  return [RetrievalResult.from_db_row(dict(r)) for r in results]
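
The sanitization in retrieve_bm25 strips tsquery metacharacters and ORs the surviving tokens so documents matching any term still rank. A standalone sketch of that preparation (the helper name is illustrative, not part of the package):

import re

# Mirror of the BM25 query preparation above: lowercase, drop tsquery metacharacters,
# and OR-join the remaining tokens.
def to_or_tsquery(query_text: str) -> str | None:
    sanitized = re.sub(r"[^\w\s]", " ", query_text.lower())
    tokens = [t for t in sanitized.split() if t]
    return " | ".join(tokens) if tokens else None

# to_or_tsquery("What changed in Q3, 2024?") -> "what | changed | in | q3 | 2024"
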
151
196
 
152
197
 
198
+ async def retrieve_semantic_bm25_combined(
199
+ conn,
200
+ query_emb_str: str,
201
+ query_text: str,
202
+ bank_id: str,
203
+ fact_types: list[str],
204
+ limit: int,
205
+ tags: list[str] | None = None,
206
+ tags_match: TagsMatch = "any",
207
+ ) -> dict[str, tuple[list[RetrievalResult], list[RetrievalResult]]]:
208
+ """
209
+ Combined semantic + BM25 retrieval for multiple fact types in a single query.
210
+
211
+ Uses CTEs with window functions to get top-N results per fact type per method,
212
+ all in one database round-trip.
213
+
214
+ Args:
215
+ conn: Database connection
216
+ query_emb_str: Query embedding as string
217
+ query_text: Query text for BM25
218
+ bank_id: Bank ID
219
+ fact_types: List of fact types to retrieve
220
+ limit: Maximum results per method per fact type
221
+
222
+ Returns:
223
+ Dict mapping fact_type -> (semantic_results, bm25_results)
224
+ """
225
+ import re
226
+
227
+ # Sanitize query text for BM25 (same as retrieve_bm25)
228
+ sanitized_text = re.sub(r"[^\w\s]", " ", query_text.lower())
229
+ tokens = [token for token in sanitized_text.split() if token]
230
+
231
+ # If no valid tokens for BM25, just run semantic
232
+ if not tokens:
233
+ tags_clause = build_tags_where_clause_simple(tags, 5, match=tags_match)
234
+ params = [query_emb_str, bank_id, fact_types, limit]
235
+ if tags:
236
+ params.append(tags)
237
+ results = await conn.fetch(
238
+ f"""
239
+ WITH semantic_ranked AS (
240
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
241
+ 1 - (embedding <=> $1::vector) AS similarity,
242
+ NULL::float AS bm25_score,
243
+ 'semantic' AS source,
244
+ ROW_NUMBER() OVER (PARTITION BY fact_type ORDER BY embedding <=> $1::vector) AS rn
245
+ FROM {fq_table("memory_units")}
246
+ WHERE bank_id = $2
247
+ AND embedding IS NOT NULL
248
+ AND fact_type = ANY($3)
249
+ AND (1 - (embedding <=> $1::vector)) >= 0.3
250
+ {tags_clause}
251
+ )
252
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
253
+ similarity, bm25_score, source
254
+ FROM semantic_ranked
255
+ WHERE rn <= $4
256
+ """,
257
+ *params,
258
+ )
259
+ # Group by fact_type
260
+ result_dict: dict[str, tuple[list[RetrievalResult], list[RetrievalResult]]] = {
261
+ ft: ([], []) for ft in fact_types
262
+ }
263
+ for r in results:
264
+ row = dict(r)
265
+ ft = row.get("fact_type")
266
+ row.pop("source", None)
267
+ if ft in result_dict:
268
+ result_dict[ft][0].append(RetrievalResult.from_db_row(row))
269
+ return result_dict
270
+
271
+ query_tsquery = " | ".join(tokens)
272
+
273
+ # Build tags clause - param 6 if tags provided
274
+ tags_clause = build_tags_where_clause_simple(tags, 6, match=tags_match)
275
+ params = [query_emb_str, bank_id, fact_types, limit, query_tsquery]
276
+ if tags:
277
+ params.append(tags)
278
+
279
+ # Combined CTE query for both semantic and BM25 across all fact types
280
+ # Uses window functions to limit per fact_type per method
281
+ results = await conn.fetch(
282
+ f"""
283
+ WITH semantic_ranked AS (
284
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
285
+ 1 - (embedding <=> $1::vector) AS similarity,
286
+ NULL::float AS bm25_score,
287
+ 'semantic' AS source,
288
+ ROW_NUMBER() OVER (PARTITION BY fact_type ORDER BY embedding <=> $1::vector) AS rn
289
+ FROM {fq_table("memory_units")}
290
+ WHERE bank_id = $2
291
+ AND embedding IS NOT NULL
292
+ AND fact_type = ANY($3)
293
+ AND (1 - (embedding <=> $1::vector)) >= 0.3
294
+ {tags_clause}
295
+ ),
296
+ bm25_ranked AS (
297
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
298
+ NULL::float AS similarity,
299
+ ts_rank_cd(search_vector, to_tsquery('english', $5)) AS bm25_score,
300
+ 'bm25' AS source,
301
+ ROW_NUMBER() OVER (PARTITION BY fact_type ORDER BY ts_rank_cd(search_vector, to_tsquery('english', $5)) DESC) AS rn
302
+ FROM {fq_table("memory_units")}
303
+ WHERE bank_id = $2
304
+ AND fact_type = ANY($3)
305
+ AND search_vector @@ to_tsquery('english', $5)
306
+ {tags_clause}
307
+ ),
308
+ semantic AS (
309
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
310
+ similarity, bm25_score, source
311
+ FROM semantic_ranked WHERE rn <= $4
312
+ ),
313
+ bm25 AS (
314
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
315
+ similarity, bm25_score, source
316
+ FROM bm25_ranked WHERE rn <= $4
317
+ )
318
+ SELECT * FROM semantic
319
+ UNION ALL
320
+ SELECT * FROM bm25
321
+ """,
322
+ *params,
323
+ )
324
+
325
+ # Group results by fact_type and source
326
+ result_dict: dict[str, tuple[list[RetrievalResult], list[RetrievalResult]]] = {ft: ([], []) for ft in fact_types}
327
+ for r in results:
328
+ row = dict(r)
329
+ source = row.pop("source", None)
330
+ ft = row.get("fact_type")
331
+ if ft in result_dict:
332
+ if source == "semantic":
333
+ result_dict[ft][0].append(RetrievalResult.from_db_row(row))
334
+ else:
335
+ result_dict[ft][1].append(RetrievalResult.from_db_row(row))
336
+
337
+ return result_dict
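
A short consumption sketch for the combined query: one round-trip yields a (semantic_results, bm25_results) pair per fact type. The fact type names and tag below are placeholders, not values defined by the package:

async def show_combined(conn, query_emb_str: str, query_text: str, bank_id: str):
    # Placeholder fact types and tag; limit applies per method per fact type.
    by_type = await retrieve_semantic_bm25_combined(
        conn, query_emb_str, query_text, bank_id,
        fact_types=["world", "experience"], limit=20, tags=["project-x"],
    )
    for fact_type, (semantic_hits, bm25_hits) in by_type.items():
        print(fact_type, len(semantic_hits), len(bm25_hits))
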
338
+
339
+
340
+ async def retrieve_temporal_combined(
341
+ conn,
342
+ query_emb_str: str,
343
+ bank_id: str,
344
+ fact_types: list[str],
345
+ start_date: datetime,
346
+ end_date: datetime,
347
+ budget: int,
348
+ semantic_threshold: float = 0.1,
349
+ tags: list[str] | None = None,
350
+ tags_match: TagsMatch = "any",
351
+ ) -> dict[str, list[RetrievalResult]]:
352
+ """
353
+ Temporal retrieval for multiple fact types in a single query.
354
+
355
+ Batches the entry point query using window functions to get top-N per fact type,
356
+ then runs spreading for each fact type.
357
+
358
+ Args:
359
+ conn: Database connection
360
+ query_emb_str: Query embedding as string
361
+ bank_id: Bank ID
362
+ fact_types: List of fact types to retrieve
363
+ start_date: Start of time range
364
+ end_date: End of time range
365
+ budget: Node budget for spreading per fact type
366
+ semantic_threshold: Minimum semantic similarity to include
367
+
368
+ Returns:
369
+ Dict mapping fact_type -> list of RetrievalResult
370
+ """
371
+ from ..memory_engine import fq_table
372
+
373
+ # Ensure dates are timezone-aware
374
+ if start_date.tzinfo is None:
375
+ start_date = start_date.replace(tzinfo=UTC)
376
+ if end_date.tzinfo is None:
377
+ end_date = end_date.replace(tzinfo=UTC)
378
+
379
+ # Build tags clause
380
+ tags_clause = build_tags_where_clause_simple(tags, 7, match=tags_match)
381
+ params = [query_emb_str, bank_id, fact_types, start_date, end_date, semantic_threshold]
382
+ if tags:
383
+ params.append(tags)
384
+
385
+ # Batch query: Get entry points for ALL fact types at once with window function
386
+ entry_points = await conn.fetch(
387
+ f"""
388
+ WITH ranked_entries AS (
389
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
390
+ 1 - (embedding <=> $1::vector) AS similarity,
391
+ ROW_NUMBER() OVER (PARTITION BY fact_type ORDER BY COALESCE(occurred_start, mentioned_at, occurred_end) DESC, embedding <=> $1::vector) AS rn
392
+ FROM {fq_table("memory_units")}
393
+ WHERE bank_id = $2
394
+ AND fact_type = ANY($3)
395
+ AND embedding IS NOT NULL
396
+ AND (
397
+ (occurred_start IS NOT NULL AND occurred_end IS NOT NULL
398
+ AND occurred_start <= $5 AND occurred_end >= $4)
399
+ OR
400
+ (mentioned_at IS NOT NULL AND mentioned_at BETWEEN $4 AND $5)
401
+ OR
402
+ (occurred_start IS NOT NULL AND occurred_start BETWEEN $4 AND $5)
403
+ OR
404
+ (occurred_end IS NOT NULL AND occurred_end BETWEEN $4 AND $5)
405
+ )
406
+ AND (1 - (embedding <=> $1::vector)) >= $6
407
+ {tags_clause}
408
+ )
409
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags, similarity
410
+ FROM ranked_entries
411
+ WHERE rn <= 10
412
+ """,
413
+ *params,
414
+ )
415
+
416
+ if not entry_points:
417
+ return {ft: [] for ft in fact_types}
418
+
419
+ # Group entry points by fact type
420
+ entries_by_ft: dict[str, list] = {ft: [] for ft in fact_types}
421
+ for ep in entry_points:
422
+ ft = ep["fact_type"]
423
+ if ft in entries_by_ft:
424
+ entries_by_ft[ft].append(ep)
425
+
426
+ # Calculate shared temporal parameters
427
+ total_days = (end_date - start_date).total_seconds() / 86400
428
+ mid_date = start_date + (end_date - start_date) / 2
429
+
430
+ # Process each fact type (spreading needs to stay per fact type due to link filtering)
431
+ results_by_ft: dict[str, list[RetrievalResult]] = {}
432
+
433
+ for ft in fact_types:
434
+ ft_entry_points = entries_by_ft.get(ft, [])
435
+ if not ft_entry_points:
436
+ results_by_ft[ft] = []
437
+ continue
438
+
439
+ results = []
440
+ visited = set()
441
+ node_scores = {}
442
+
443
+ # Process entry points
444
+ for ep in ft_entry_points:
445
+ unit_id = str(ep["id"])
446
+ visited.add(unit_id)
447
+
448
+ # Calculate temporal proximity
449
+ best_date = None
450
+ if ep["occurred_start"] is not None and ep["occurred_end"] is not None:
451
+ best_date = ep["occurred_start"] + (ep["occurred_end"] - ep["occurred_start"]) / 2
452
+ elif ep["occurred_start"] is not None:
453
+ best_date = ep["occurred_start"]
454
+ elif ep["occurred_end"] is not None:
455
+ best_date = ep["occurred_end"]
456
+ elif ep["mentioned_at"] is not None:
457
+ best_date = ep["mentioned_at"]
458
+
459
+ if best_date:
460
+ days_from_mid = abs((best_date - mid_date).total_seconds() / 86400)
461
+ temporal_proximity = 1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
462
+ else:
463
+ temporal_proximity = 0.5
464
+
465
+ ep_result = RetrievalResult.from_db_row(dict(ep))
466
+ ep_result.temporal_score = temporal_proximity
467
+ ep_result.temporal_proximity = temporal_proximity
468
+ results.append(ep_result)
469
+ node_scores[unit_id] = (ep["similarity"], 1.0)
470
+
471
+ # Spreading through temporal links (same as single-fact-type version)
472
+ frontier = list(node_scores.keys())
473
+ budget_remaining = budget - len(ft_entry_points)
474
+ batch_size = 20
475
+
476
+ # Build tags clause for spreading (use param 6 since 1-5 are used)
477
+ spreading_tags_clause = build_tags_where_clause_simple(tags, 6, table_alias="mu.", match=tags_match)
478
+
479
+ while frontier and budget_remaining > 0:
480
+ batch_ids = frontier[:batch_size]
481
+ frontier = frontier[batch_size:]
482
+
483
+ spreading_params = [query_emb_str, batch_ids, ft, semantic_threshold, batch_size * 10]
484
+ if tags:
485
+ spreading_params.append(tags)
486
+
487
+ neighbors = await conn.fetch(
488
+ f"""
489
+ SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type, mu.document_id, mu.chunk_id, mu.tags,
490
+ ml.weight, ml.link_type, ml.from_unit_id,
491
+ 1 - (mu.embedding <=> $1::vector) AS similarity
492
+ FROM {fq_table("memory_links")} ml
493
+ JOIN {fq_table("memory_units")} mu ON ml.to_unit_id = mu.id
494
+ WHERE ml.from_unit_id = ANY($2::uuid[])
495
+ AND ml.link_type IN ('temporal', 'causes', 'caused_by', 'enables', 'prevents')
496
+ AND ml.weight >= 0.1
497
+ AND mu.fact_type = $3
498
+ AND mu.embedding IS NOT NULL
499
+ AND (1 - (mu.embedding <=> $1::vector)) >= $4
500
+ {spreading_tags_clause}
501
+ ORDER BY ml.weight DESC
502
+ LIMIT $5
503
+ """,
504
+ *spreading_params,
505
+ )
506
+
507
+ for n in neighbors:
508
+ neighbor_id = str(n["id"])
509
+ if neighbor_id in visited:
510
+ continue
511
+
512
+ visited.add(neighbor_id)
513
+ budget_remaining -= 1
514
+
515
+ parent_id = str(n["from_unit_id"])
516
+ _, parent_temporal_score = node_scores.get(parent_id, (0.5, 0.5))
517
+
518
+ neighbor_best_date = None
519
+ if n["occurred_start"] is not None and n["occurred_end"] is not None:
520
+ neighbor_best_date = n["occurred_start"] + (n["occurred_end"] - n["occurred_start"]) / 2
521
+ elif n["occurred_start"] is not None:
522
+ neighbor_best_date = n["occurred_start"]
523
+ elif n["occurred_end"] is not None:
524
+ neighbor_best_date = n["occurred_end"]
525
+ elif n["mentioned_at"] is not None:
526
+ neighbor_best_date = n["mentioned_at"]
527
+
528
+ if neighbor_best_date:
529
+ days_from_mid = abs((neighbor_best_date - mid_date).total_seconds() / 86400)
530
+ neighbor_temporal_proximity = (
531
+ 1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
532
+ )
533
+ else:
534
+ neighbor_temporal_proximity = 0.3
535
+
536
+ link_type = n["link_type"]
537
+ if link_type in ("causes", "caused_by"):
538
+ causal_boost = 2.0
539
+ elif link_type in ("enables", "prevents"):
540
+ causal_boost = 1.5
541
+ else:
542
+ causal_boost = 1.0
543
+
544
+ propagated_temporal = parent_temporal_score * n["weight"] * causal_boost * 0.7
545
+ combined_temporal = max(neighbor_temporal_proximity, propagated_temporal)
546
+
547
+ neighbor_result = RetrievalResult.from_db_row(dict(n))
548
+ neighbor_result.temporal_score = combined_temporal
549
+ neighbor_result.temporal_proximity = neighbor_temporal_proximity
550
+ results.append(neighbor_result)
551
+
552
+ if budget_remaining > 0 and combined_temporal > 0.2:
553
+ node_scores[neighbor_id] = (n["similarity"], combined_temporal)
554
+ frontier.append(neighbor_id)
555
+
556
+ if budget_remaining <= 0:
557
+ break
558
+
559
+ results_by_ft[ft] = results
560
+
561
+ return results_by_ft
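
Neighbor scoring in both temporal paths combines date proximity with a score propagated through the link: causal links are boosted (2.0 for causes/caused_by, 1.5 for enables/prevents), the propagated value decays by 0.7, the neighbor keeps the larger of the two, and spreading continues only while the combined score exceeds 0.2. A worked sketch of that arithmetic:

def neighbor_temporal_score(parent_score: float, weight: float, link_type: str,
                            proximity: float) -> float:
    # Causal boost and 0.7 decay, as in the spreading loops above.
    boost = {"causes": 2.0, "caused_by": 2.0, "enables": 1.5, "prevents": 1.5}.get(link_type, 1.0)
    propagated = parent_score * weight * boost * 0.7
    return max(proximity, propagated)

# Example: a 'causes' link of weight 0.6 from a parent scored 1.0 propagates
# 1.0 * 0.6 * 2.0 * 0.7 = 0.84, which beats a proximity of 0.5, so the neighbor
# scores 0.84 and stays on the frontier (> 0.2).
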
562
+
563
+
153
564
  async def retrieve_temporal(
154
565
  conn,
155
566
  query_emb_str: str,
@@ -159,6 +570,7 @@ async def retrieve_temporal(
159
570
  end_date: datetime,
160
571
  budget: int,
161
572
  semantic_threshold: float = 0.1,
573
+ tags: list[str] | None = None,
162
574
  ) -> list[RetrievalResult]:
163
575
  """
164
576
  Temporal retrieval with spreading activation.
@@ -177,6 +589,7 @@ async def retrieve_temporal(
177
589
  end_date: End of time range
178
590
  budget: Node budget for spreading
179
591
  semantic_threshold: Minimum semantic similarity to include
592
+ tags: Optional list of tags for visibility filtering (OR matching)
180
593
 
181
594
  Returns:
182
595
  List of RetrievalResult objects with temporal scores
@@ -188,9 +601,16 @@ async def retrieve_temporal(
188
601
  if end_date.tzinfo is None:
189
602
  end_date = end_date.replace(tzinfo=UTC)
190
603
 
604
+ from .tags import TagsMatch, build_tags_where_clause_simple
605
+
606
+ tags_clause = build_tags_where_clause_simple(tags, 7)
607
+ params = [query_emb_str, bank_id, fact_type, start_date, end_date, semantic_threshold]
608
+ if tags:
609
+ params.append(tags)
610
+
191
611
  entry_points = await conn.fetch(
192
612
  f"""
193
- SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
613
+ SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id, tags,
194
614
  1 - (embedding <=> $1::vector) AS similarity
195
615
  FROM {fq_table("memory_units")}
196
616
  WHERE bank_id = $2
@@ -210,15 +630,11 @@ async def retrieve_temporal(
210
630
  (occurred_end IS NOT NULL AND occurred_end BETWEEN $4 AND $5)
211
631
  )
212
632
  AND (1 - (embedding <=> $1::vector)) >= $6
633
+ {tags_clause}
213
634
  ORDER BY COALESCE(occurred_start, mentioned_at, occurred_end) DESC, (embedding <=> $1::vector) ASC
214
635
  LIMIT 10
215
636
  """,
216
- query_emb_str,
217
- bank_id,
218
- fact_type,
219
- start_date,
220
- end_date,
221
- semantic_threshold,
637
+ *params,
222
638
  )
223
639
 
224
640
  if not entry_points:
@@ -260,94 +676,101 @@ async def retrieve_temporal(
260
676
  ep_result.temporal_proximity = temporal_proximity
261
677
  results.append(ep_result)
262
678
 
263
- # Spread through temporal links
264
- queue = [
265
- (RetrievalResult.from_db_row(dict(ep)), ep["similarity"], 1.0) for ep in entry_points
266
- ] # (unit, semantic_sim, temporal_score)
679
+ # Spread through temporal links using BATCHED neighbor fetching
680
+ # Map node_id -> (semantic_sim, temporal_score) for propagation
681
+ node_scores = {str(ep["id"]): (ep["similarity"], 1.0) for ep in entry_points}
682
+ frontier = list(node_scores.keys()) # Current batch of nodes to expand
267
683
  budget_remaining = budget - len(entry_points)
684
+ batch_size = 20 # Process this many nodes per DB query
685
+
686
+ while frontier and budget_remaining > 0:
687
+ # Take a batch from frontier
688
+ batch_ids = frontier[:batch_size]
689
+ frontier = frontier[batch_size:]
690
+
691
+ # Batch fetch all neighbors for this batch of nodes
692
+ neighbors = await conn.fetch(
693
+ f"""
694
+ SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type, mu.document_id, mu.chunk_id,
695
+ ml.weight, ml.link_type, ml.from_unit_id,
696
+ 1 - (mu.embedding <=> $1::vector) AS similarity
697
+ FROM {fq_table("memory_links")} ml
698
+ JOIN {fq_table("memory_units")} mu ON ml.to_unit_id = mu.id
699
+ WHERE ml.from_unit_id = ANY($2::uuid[])
700
+ AND ml.link_type IN ('temporal', 'causes', 'caused_by', 'enables', 'prevents')
701
+ AND ml.weight >= 0.1
702
+ AND mu.fact_type = $3
703
+ AND mu.embedding IS NOT NULL
704
+ AND (1 - (mu.embedding <=> $1::vector)) >= $4
705
+ ORDER BY ml.weight DESC
706
+ LIMIT $5
707
+ """,
708
+ query_emb_str,
709
+ batch_ids,
710
+ fact_type,
711
+ semantic_threshold,
712
+ batch_size * 10, # Allow up to 10 neighbors per node in batch
713
+ )
268
714
 
269
- while queue and budget_remaining > 0:
270
- current, semantic_sim, temporal_score = queue.pop(0)
271
- current_id = current.id
272
-
273
- # Get neighbors via temporal and causal links
274
- if budget_remaining > 0:
275
- neighbors = await conn.fetch(
276
- f"""
277
- SELECT mu.id, mu.text, mu.context, mu.event_date, mu.occurred_start, mu.occurred_end, mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type, mu.document_id, mu.chunk_id,
278
- ml.weight, ml.link_type,
279
- 1 - (mu.embedding <=> $1::vector) AS similarity
280
- FROM {fq_table("memory_links")} ml
281
- JOIN {fq_table("memory_units")} mu ON ml.to_unit_id = mu.id
282
- WHERE ml.from_unit_id = $2
283
- AND ml.link_type IN ('temporal', 'causes', 'caused_by', 'enables', 'prevents')
284
- AND ml.weight >= 0.1
285
- AND mu.fact_type = $3
286
- AND mu.embedding IS NOT NULL
287
- AND (1 - (mu.embedding <=> $1::vector)) >= $4
288
- ORDER BY ml.weight DESC
289
- LIMIT 10
290
- """,
291
- query_emb_str,
292
- current.id,
293
- fact_type,
294
- semantic_threshold,
295
- )
296
-
297
- for n in neighbors:
298
- neighbor_id = str(n["id"])
299
- if neighbor_id in visited:
300
- continue
301
-
302
- visited.add(neighbor_id)
303
- budget_remaining -= 1
304
-
305
- # Calculate temporal score for neighbor using best available date
306
- neighbor_best_date = None
307
- if n["occurred_start"] is not None and n["occurred_end"] is not None:
308
- neighbor_best_date = n["occurred_start"] + (n["occurred_end"] - n["occurred_start"]) / 2
309
- elif n["occurred_start"] is not None:
310
- neighbor_best_date = n["occurred_start"]
311
- elif n["occurred_end"] is not None:
312
- neighbor_best_date = n["occurred_end"]
313
- elif n["mentioned_at"] is not None:
314
- neighbor_best_date = n["mentioned_at"]
315
-
316
- if neighbor_best_date:
317
- days_from_mid = abs((neighbor_best_date - mid_date).total_seconds() / 86400)
318
- neighbor_temporal_proximity = (
319
- 1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
320
- )
321
- else:
322
- neighbor_temporal_proximity = 0.3 # Lower score if no temporal data
323
-
324
- # Boost causal links (same as graph retrieval)
325
- link_type = n["link_type"]
326
- if link_type in ("causes", "caused_by"):
327
- causal_boost = 2.0
328
- elif link_type in ("enables", "prevents"):
329
- causal_boost = 1.5
330
- else:
331
- causal_boost = 1.0
332
-
333
- # Propagate temporal score through links (decay, with causal boost)
334
- propagated_temporal = temporal_score * n["weight"] * causal_boost * 0.7
335
-
336
- # Combined temporal score
337
- combined_temporal = max(neighbor_temporal_proximity, propagated_temporal)
338
-
339
- # Create RetrievalResult with temporal scores
340
- neighbor_result = RetrievalResult.from_db_row(dict(n))
341
- neighbor_result.temporal_score = combined_temporal
342
- neighbor_result.temporal_proximity = neighbor_temporal_proximity
343
- results.append(neighbor_result)
344
-
345
- # Add to queue for further spreading
346
- if budget_remaining > 0 and combined_temporal > 0.2:
347
- queue.append((neighbor_result, n["similarity"], combined_temporal))
348
-
349
- if budget_remaining <= 0:
350
- break
715
+ for n in neighbors:
716
+ neighbor_id = str(n["id"])
717
+ if neighbor_id in visited:
718
+ continue
719
+
720
+ visited.add(neighbor_id)
721
+ budget_remaining -= 1
722
+
723
+ # Get parent's scores for propagation
724
+ parent_id = str(n["from_unit_id"])
725
+ _, parent_temporal_score = node_scores.get(parent_id, (0.5, 0.5))
726
+
727
+ # Calculate temporal score for neighbor using best available date
728
+ neighbor_best_date = None
729
+ if n["occurred_start"] is not None and n["occurred_end"] is not None:
730
+ neighbor_best_date = n["occurred_start"] + (n["occurred_end"] - n["occurred_start"]) / 2
731
+ elif n["occurred_start"] is not None:
732
+ neighbor_best_date = n["occurred_start"]
733
+ elif n["occurred_end"] is not None:
734
+ neighbor_best_date = n["occurred_end"]
735
+ elif n["mentioned_at"] is not None:
736
+ neighbor_best_date = n["mentioned_at"]
737
+
738
+ if neighbor_best_date:
739
+ days_from_mid = abs((neighbor_best_date - mid_date).total_seconds() / 86400)
740
+ neighbor_temporal_proximity = (
741
+ 1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
742
+ )
743
+ else:
744
+ neighbor_temporal_proximity = 0.3 # Lower score if no temporal data
745
+
746
+ # Boost causal links (same as graph retrieval)
747
+ link_type = n["link_type"]
748
+ if link_type in ("causes", "caused_by"):
749
+ causal_boost = 2.0
750
+ elif link_type in ("enables", "prevents"):
751
+ causal_boost = 1.5
752
+ else:
753
+ causal_boost = 1.0
754
+
755
+ # Propagate temporal score through links (decay, with causal boost)
756
+ propagated_temporal = parent_temporal_score * n["weight"] * causal_boost * 0.7
757
+
758
+ # Combined temporal score
759
+ combined_temporal = max(neighbor_temporal_proximity, propagated_temporal)
760
+
761
+ # Create RetrievalResult with temporal scores
762
+ neighbor_result = RetrievalResult.from_db_row(dict(n))
763
+ neighbor_result.temporal_score = combined_temporal
764
+ neighbor_result.temporal_proximity = neighbor_temporal_proximity
765
+ results.append(neighbor_result)
766
+
767
+ # Track scores for propagation and add to frontier
768
+ if budget_remaining > 0 and combined_temporal > 0.2:
769
+ node_scores[neighbor_id] = (n["similarity"], combined_temporal)
770
+ frontier.append(neighbor_id)
771
+
772
+ if budget_remaining <= 0:
773
+ break
351
774
 
352
775
  return results
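
The rewritten loop replaces one neighbor query per node with one query per frontier batch (batch_size = 20), tracking per-node scores for propagation. A schematic of that batched expansion, with fetch_neighbor_ids standing in for the SQL shown above:

async def spread_batched(entry_ids, budget, fetch_neighbor_ids, batch_size=20):
    # Schematic only: expand a frontier in batches until the node budget is spent.
    visited = set(entry_ids)
    frontier = list(entry_ids)
    while frontier and budget > 0:
        batch, frontier = frontier[:batch_size], frontier[batch_size:]
        for neighbor_id in await fetch_neighbor_ids(batch):
            if neighbor_id in visited:
                continue
            visited.add(neighbor_id)
            budget -= 1
            frontier.append(neighbor_id)
            if budget <= 0:
                break
    return visited
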
353
776
 
@@ -362,6 +785,8 @@ async def retrieve_parallel(
362
785
  question_date: datetime | None = None,
363
786
  query_analyzer: Optional["QueryAnalyzer"] = None,
364
787
  graph_retriever: GraphRetriever | None = None,
788
+ temporal_constraint: tuple | None = None, # Pre-extracted temporal constraint
789
+ tags: list[str] | None = None, # Visibility scope tags for filtering
365
790
  ) -> ParallelRetrievalResult:
366
791
  """
367
792
  Run 3-way or 4-way parallel retrieval (adds temporal if detected).
@@ -376,42 +801,58 @@ async def retrieve_parallel(
376
801
  question_date: Optional date when question was asked (for temporal filtering)
377
802
  query_analyzer: Query analyzer to use (defaults to TransformerQueryAnalyzer)
378
803
  graph_retriever: Graph retrieval strategy (defaults to configured retriever)
804
+ temporal_constraint: Pre-extracted temporal constraint (optional)
805
+ tags: Optional list of tags for visibility filtering (OR matching)
379
806
 
380
807
  Returns:
381
808
  ParallelRetrievalResult with semantic, bm25, graph, temporal results and timings
382
809
  """
383
- from .temporal_extraction import extract_temporal_constraint
384
-
385
- temporal_constraint = extract_temporal_constraint(query_text, reference_date=question_date, analyzer=query_analyzer)
386
-
387
810
  retriever = graph_retriever or get_default_graph_retriever()
388
811
 
389
- if retriever.name == "mpfp":
812
+ # Use optimized parallel path for MPFP and LinkExpansion (runs all methods truly in parallel)
813
+ # BFS uses legacy path that extracts temporal constraint upfront
814
+ if retriever.name in ("mpfp", "link_expansion"):
390
815
  return await _retrieve_parallel_mpfp(
391
- pool, query_text, query_embedding_str, bank_id, fact_type, thinking_budget, temporal_constraint, retriever
816
+ pool,
817
+ query_text,
818
+ query_embedding_str,
819
+ bank_id,
820
+ fact_type,
821
+ thinking_budget,
822
+ temporal_constraint,
823
+ retriever,
824
+ question_date,
825
+ query_analyzer,
826
+ tags=tags,
392
827
  )
393
828
  else:
829
+ # For BFS, extract temporal constraint upfront (legacy path)
830
+ if temporal_constraint is None:
831
+ from .temporal_extraction import extract_temporal_constraint
832
+
833
+ temporal_constraint = extract_temporal_constraint(
834
+ query_text, reference_date=question_date, analyzer=query_analyzer
835
+ )
394
836
  return await _retrieve_parallel_bfs(
395
- pool, query_text, query_embedding_str, bank_id, fact_type, thinking_budget, temporal_constraint, retriever
837
+ pool,
838
+ query_text,
839
+ query_embedding_str,
840
+ bank_id,
841
+ fact_type,
842
+ thinking_budget,
843
+ temporal_constraint,
844
+ retriever,
845
+ tags=tags,
396
846
  )
397
847
 
398
848
 
399
- @dataclass
400
- class _SemanticGraphResult:
401
- """Internal result from semantic→graph chain."""
402
-
403
- semantic: list[RetrievalResult]
404
- graph: list[RetrievalResult]
405
- semantic_time: float
406
- graph_time: float
407
-
408
-
409
849
  @dataclass
410
850
  class _TimedResult:
411
851
  """Internal result with timing."""
412
852
 
413
853
  results: list[RetrievalResult]
414
854
  time: float
855
+ conn_wait: float = 0.0 # Connection acquisition wait time
415
856
 
416
857
 
417
858
  async def _retrieve_parallel_mpfp(
@@ -423,60 +864,103 @@ async def _retrieve_parallel_mpfp(
423
864
  thinking_budget: int,
424
865
  temporal_constraint: tuple | None,
425
866
  retriever: GraphRetriever,
867
+ question_date: datetime | None = None,
868
+ query_analyzer=None,
869
+ tags: list[str] | None = None,
426
870
  ) -> ParallelRetrievalResult:
427
871
  """
428
- MPFP retrieval with optimized parallelization.
872
+ MPFP retrieval with true parallelization.
873
+
874
+ All methods run independently in parallel:
875
+ - Semantic: vector similarity search
876
+ - BM25: keyword search
877
+ - Graph: MPFP traversal (does its own semantic seeds internally)
878
+ - Temporal: date extraction (if needed) + date-range search
429
879
 
430
- Runs 2-3 parallel task chains:
431
- - Task 1: Semantic Graph (chained, graph uses semantic seeds)
432
- - Task 2: BM25 (independent)
433
- - Task 3: Temporal (if constraint detected)
880
+ Temporal extraction runs IN PARALLEL with other retrievals, so even if
881
+ dateparser is slow, it doesn't block semantic/BM25/graph.
434
882
  """
435
883
  import time
436
884
 
437
- async def run_semantic_then_graph() -> _SemanticGraphResult:
438
- """Chain: semantic retrieval → graph retrieval (using semantic as seeds)."""
885
+ async def run_semantic() -> _TimedResult:
886
+ """Independent semantic retrieval."""
439
887
  start = time.time()
888
+ acquire_start = time.time()
440
889
  async with acquire_with_retry(pool) as conn:
441
- semantic = await retrieve_semantic(conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget)
442
- semantic_time = time.time() - start
890
+ conn_wait = time.time() - acquire_start
891
+ results = await retrieve_semantic(
892
+ conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget, tags=tags
893
+ )
894
+ return _TimedResult(results, time.time() - start, conn_wait)
443
895
 
444
- # Get temporal seeds if needed (quick query, part of this chain)
445
- temporal_seeds = None
446
- if temporal_constraint:
447
- tc_start, tc_end = temporal_constraint
448
- async with acquire_with_retry(pool) as conn:
449
- temporal_seeds = await _get_temporal_entry_points(
450
- conn, query_embedding_str, bank_id, fact_type, tc_start, tc_end, limit=20
451
- )
896
+ async def run_bm25() -> _TimedResult:
897
+ """Independent BM25 retrieval."""
898
+ start = time.time()
899
+ acquire_start = time.time()
900
+ async with acquire_with_retry(pool) as conn:
901
+ conn_wait = time.time() - acquire_start
902
+ results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget, tags=tags)
903
+ return _TimedResult(results, time.time() - start, conn_wait)
452
904
 
453
- # Run graph with seeds
905
+ async def run_graph() -> tuple[list[RetrievalResult], float, MPFPTimings | None]:
906
+ """Independent graph retrieval - does its own semantic seeds."""
454
907
  start = time.time()
455
- graph = await retriever.retrieve(
908
+
909
+ # MPFP does its own semantic seeds via _find_semantic_seeds
910
+ # Note: temporal_seeds not used here to avoid dependency on temporal extraction
911
+ results, mpfp_timing = await retriever.retrieve(
456
912
  pool=pool,
457
913
  query_embedding_str=query_embedding_str,
458
914
  bank_id=bank_id,
459
915
  fact_type=fact_type,
460
916
  budget=thinking_budget,
461
917
  query_text=query_text,
462
- semantic_seeds=semantic,
463
- temporal_seeds=temporal_seeds,
918
+ semantic_seeds=None, # Let MPFP find its own seeds
919
+ temporal_seeds=None, # Don't wait for temporal extraction
920
+ tags=tags,
464
921
  )
465
- graph_time = time.time() - start
922
+ return results, time.time() - start, mpfp_timing
466
923
 
467
- return _SemanticGraphResult(semantic, graph, semantic_time, graph_time)
924
+ @dataclass
925
+ class _TemporalWithConstraint:
926
+ """Temporal results with the extracted constraint."""
468
927
 
469
- async def run_bm25() -> _TimedResult:
470
- """Independent BM25 retrieval."""
471
- start = time.time()
472
- async with acquire_with_retry(pool) as conn:
473
- results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
474
- return _TimedResult(results, time.time() - start)
928
+ results: list[RetrievalResult]
929
+ time: float
930
+ constraint: tuple | None
931
+ extraction_time: float # Time spent in query analyzer (dateparser)
932
+ conn_wait: float = 0.0 # Connection acquisition wait time
475
933
 
476
- async def run_temporal(tc_start, tc_end) -> _TimedResult:
477
- """Temporal retrieval (uses its own entry point finding)."""
934
+ async def run_temporal_with_extraction() -> _TemporalWithConstraint:
935
+ """
936
+ Extract temporal constraint AND run temporal retrieval.
937
+
938
+ This runs in parallel with semantic/BM25/graph, so dateparser
939
+ latency doesn't block other retrievals.
940
+ """
478
941
  start = time.time()
942
+
943
+ # Use pre-provided constraint if available
944
+ tc = temporal_constraint
945
+ extraction_time = 0.0
946
+
947
+ # Otherwise extract from query (this is the potentially slow dateparser call)
948
+ if tc is None:
949
+ from .temporal_extraction import extract_temporal_constraint
950
+
951
+ extraction_start = time.time()
952
+ tc = extract_temporal_constraint(query_text, reference_date=question_date, analyzer=query_analyzer)
953
+ extraction_time = time.time() - extraction_start
954
+
955
+ # If no temporal constraint found, return empty (but still report extraction time)
956
+ if tc is None:
957
+ return _TemporalWithConstraint([], time.time() - start, None, extraction_time, 0.0)
958
+
959
+ # Run temporal retrieval with the extracted constraint
960
+ tc_start, tc_end = tc
961
+ acquire_start = time.time()
479
962
  async with acquire_with_retry(pool) as conn:
963
+ conn_wait = time.time() - acquire_start
480
964
  results = await retrieve_temporal(
481
965
  conn,
482
966
  query_embedding_str,
@@ -487,46 +971,36 @@ async def _retrieve_parallel_mpfp(
487
971
  budget=thinking_budget,
488
972
  semantic_threshold=0.1,
489
973
  )
490
- return _TimedResult(results, time.time() - start)
491
-
492
- # Run parallel task chains
493
- if temporal_constraint:
494
- tc_start, tc_end = temporal_constraint
495
- sg_result, bm25_result, temporal_result = await asyncio.gather(
496
- run_semantic_then_graph(),
497
- run_bm25(),
498
- run_temporal(tc_start, tc_end),
499
- )
500
- return ParallelRetrievalResult(
501
- semantic=sg_result.semantic,
502
- bm25=bm25_result.results,
503
- graph=sg_result.graph,
504
- temporal=temporal_result.results,
505
- timings={
506
- "semantic": sg_result.semantic_time,
507
- "graph": sg_result.graph_time,
508
- "bm25": bm25_result.time,
509
- "temporal": temporal_result.time,
510
- },
511
- temporal_constraint=temporal_constraint,
512
- )
513
- else:
514
- sg_result, bm25_result = await asyncio.gather(
515
- run_semantic_then_graph(),
516
- run_bm25(),
517
- )
518
- return ParallelRetrievalResult(
519
- semantic=sg_result.semantic,
520
- bm25=bm25_result.results,
521
- graph=sg_result.graph,
522
- temporal=None,
523
- timings={
524
- "semantic": sg_result.semantic_time,
525
- "graph": sg_result.graph_time,
526
- "bm25": bm25_result.time,
527
- },
528
- temporal_constraint=None,
529
- )
974
+ return _TemporalWithConstraint(results, time.time() - start, tc, extraction_time, conn_wait)
975
+
976
+ # Run ALL methods in parallel (including temporal extraction!)
977
+ semantic_result, bm25_result, graph_result, temporal_result = await asyncio.gather(
978
+ run_semantic(),
979
+ run_bm25(),
980
+ run_graph(),
981
+ run_temporal_with_extraction(),
982
+ )
983
+ graph_results, graph_time, mpfp_timing = graph_result
984
+
985
+ # Compute max connection wait across all methods (graph handles its own connections)
986
+ max_conn_wait = max(semantic_result.conn_wait, bm25_result.conn_wait, temporal_result.conn_wait)
987
+
988
+ return ParallelRetrievalResult(
989
+ semantic=semantic_result.results,
990
+ bm25=bm25_result.results,
991
+ graph=graph_results,
992
+ temporal=temporal_result.results if temporal_result.results else None,
993
+ timings={
994
+ "semantic": semantic_result.time,
995
+ "bm25": bm25_result.time,
996
+ "graph": graph_time,
997
+ "temporal": temporal_result.time,
998
+ "temporal_extraction": temporal_result.extraction_time,
999
+ },
1000
+ temporal_constraint=temporal_result.constraint,
1001
+ mpfp_timings=[mpfp_timing] if mpfp_timing else [],
1002
+ max_conn_wait=max_conn_wait,
1003
+ )
530
1004
 
531
1005
 
532
1006
  async def _get_temporal_entry_points(
@@ -615,6 +1089,7 @@ async def _retrieve_parallel_bfs(
615
1089
  thinking_budget: int,
616
1090
  temporal_constraint: tuple | None,
617
1091
  retriever: GraphRetriever,
1092
+ tags: list[str] | None = None,
618
1093
  ) -> ParallelRetrievalResult:
619
1094
  """BFS retrieval: all methods run in parallel (original behavior)."""
620
1095
  import time
@@ -622,24 +1097,27 @@ async def _retrieve_parallel_bfs(
622
1097
  async def run_semantic() -> _TimedResult:
623
1098
  start = time.time()
624
1099
  async with acquire_with_retry(pool) as conn:
625
- results = await retrieve_semantic(conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget)
1100
+ results = await retrieve_semantic(
1101
+ conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget, tags=tags
1102
+ )
626
1103
  return _TimedResult(results, time.time() - start)
627
1104
 
628
1105
  async def run_bm25() -> _TimedResult:
629
1106
  start = time.time()
630
1107
  async with acquire_with_retry(pool) as conn:
631
- results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
1108
+ results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget, tags=tags)
632
1109
  return _TimedResult(results, time.time() - start)
633
1110
 
634
1111
  async def run_graph() -> _TimedResult:
635
1112
  start = time.time()
636
- results = await retriever.retrieve(
1113
+ results, _ = await retriever.retrieve(
637
1114
  pool=pool,
638
1115
  query_embedding_str=query_embedding_str,
639
1116
  bank_id=bank_id,
640
1117
  fact_type=fact_type,
641
1118
  budget=thinking_budget,
642
1119
  query_text=query_text,
1120
+ tags=tags,
643
1121
  )
644
1122
  return _TimedResult(results, time.time() - start)
645
1123
 
@@ -655,6 +1133,7 @@ async def _retrieve_parallel_bfs(
655
1133
  tc_end,
656
1134
  budget=thinking_budget,
657
1135
  semantic_threshold=0.1,
1136
+ tags=tags,
658
1137
  )
659
1138
  return _TimedResult(results, time.time() - start)
660
1139
 
@@ -697,3 +1176,171 @@ async def _retrieve_parallel_bfs(
697
1176
  },
698
1177
  temporal_constraint=None,
699
1178
  )
1179
+
1180
+
1181
+ async def retrieve_all_fact_types_parallel(
1182
+ pool,
1183
+ query_text: str,
1184
+ query_embedding_str: str,
1185
+ bank_id: str,
1186
+ fact_types: list[str],
1187
+ thinking_budget: int,
1188
+ question_date: datetime | None = None,
1189
+ query_analyzer: Optional["QueryAnalyzer"] = None,
1190
+ graph_retriever: GraphRetriever | None = None,
1191
+ tags: list[str] | None = None,
1192
+ tags_match: TagsMatch = "any",
1193
+ ) -> MultiFactTypeRetrievalResult:
1194
+ """
1195
+ Optimized retrieval for multiple fact types using batched queries.
1196
+
1197
+ This reduces database round-trips by:
1198
+ 1. Combining semantic + BM25 into one CTE query for ALL fact types (1 query instead of 2N)
1199
+ 2. Running graph retrieval per fact type in parallel (N parallel tasks)
1200
+ 3. Running temporal retrieval per fact type in parallel (N parallel tasks)
1201
+
1202
+ Args:
1203
+ pool: Database connection pool
1204
+ query_text: Query text
1205
+ query_embedding_str: Query embedding as string
1206
+ bank_id: Bank ID
1207
+ fact_types: List of fact types to retrieve
1208
+ thinking_budget: Budget for graph traversal and retrieval limits
1209
+ question_date: Optional date when question was asked (for temporal filtering)
1210
+ query_analyzer: Query analyzer to use (defaults to TransformerQueryAnalyzer)
1211
+ graph_retriever: Graph retrieval strategy (defaults to configured retriever)
1212
+
1213
+ Returns:
1214
+ MultiFactTypeRetrievalResult with results organized by fact type
1215
+ """
1216
+ import time
1217
+
1218
+ retriever = graph_retriever or get_default_graph_retriever()
1219
+ start_time = time.time()
1220
+ timings: dict[str, float] = {}
1221
+
1222
+ # Step 1: Extract temporal constraint first (CPU work, no DB)
1223
+ # Do this before DB queries so we know if we need temporal retrieval
1224
+ temporal_extraction_start = time.time()
1225
+ from .temporal_extraction import extract_temporal_constraint
1226
+
1227
+ temporal_constraint = extract_temporal_constraint(query_text, reference_date=question_date, analyzer=query_analyzer)
1228
+ temporal_extraction_time = time.time() - temporal_extraction_start
1229
+ timings["temporal_extraction"] = temporal_extraction_time
1230
+
1231
+ # Step 2: Run semantic + BM25 + temporal combined in ONE connection!
1232
+ # This reduces connection usage from 2 to 1 for these operations
1233
+ semantic_bm25_start = time.time()
1234
+ temporal_results_by_ft: dict[str, list[RetrievalResult]] = {}
1235
+ temporal_time = 0.0
1236
+
1237
+ async with acquire_with_retry(pool) as conn:
1238
+ conn_wait = time.time() - semantic_bm25_start
1239
+
1240
+ # Semantic + BM25 combined
1241
+ semantic_bm25_results = await retrieve_semantic_bm25_combined(
1242
+ conn,
1243
+ query_embedding_str,
1244
+ query_text,
1245
+ bank_id,
1246
+ fact_types,
1247
+ thinking_budget,
1248
+ tags=tags,
1249
+ tags_match=tags_match,
1250
+ )
1251
+ semantic_bm25_time = time.time() - semantic_bm25_start
1252
+
1253
+ # Temporal combined (if constraint detected) - same connection!
1254
+ if temporal_constraint:
1255
+ tc_start, tc_end = temporal_constraint
1256
+ temporal_start = time.time()
1257
+ temporal_results_by_ft = await retrieve_temporal_combined(
1258
+ conn,
1259
+ query_embedding_str,
1260
+ bank_id,
1261
+ fact_types,
1262
+ tc_start,
1263
+ tc_end,
1264
+ budget=thinking_budget,
1265
+ semantic_threshold=0.1,
1266
+ tags=tags,
1267
+ tags_match=tags_match,
1268
+ )
1269
+ temporal_time = time.time() - temporal_start
1270
+
1271
+ timings["semantic_bm25_combined"] = semantic_bm25_time
1272
+ timings["temporal_combined"] = temporal_time
1273
+
1274
+ # Step 3: Run graph retrieval for each fact type in parallel
1275
+ async def run_graph_for_fact_type(ft: str) -> tuple[str, list[RetrievalResult], float, MPFPTimings | None]:
1276
+ graph_start = time.time()
1277
+ results, mpfp_timing = await retriever.retrieve(
1278
+ pool=pool,
1279
+ query_embedding_str=query_embedding_str,
1280
+ bank_id=bank_id,
1281
+ fact_type=ft,
1282
+ budget=thinking_budget,
1283
+ query_text=query_text,
1284
+ semantic_seeds=None,
1285
+ temporal_seeds=None,
1286
+ tags=tags,
1287
+ tags_match=tags_match,
1288
+ )
1289
+ return ft, results, time.time() - graph_start, mpfp_timing
1290
+
1291
+ # Run graph for all fact types in parallel
1292
+ graph_tasks = [run_graph_for_fact_type(ft) for ft in fact_types]
1293
+ graph_results_list = await asyncio.gather(*graph_tasks)
1294
+
1295
+ # Organize results by fact type
1296
+ results_by_fact_type: dict[str, ParallelRetrievalResult] = {}
1297
+ max_conn_wait = conn_wait # Single connection for semantic+bm25+temporal
1298
+ all_mpfp_timings: list[MPFPTimings] = []
1299
+
1300
+ for ft in fact_types:
1301
+ # Get semantic + bm25 results for this fact type
1302
+ semantic_results, bm25_results = semantic_bm25_results.get(ft, ([], []))
1303
+
1304
+ # Find graph results for this fact type
1305
+ graph_results = []
1306
+ graph_time = 0.0
1307
+ mpfp_timing = None
1308
+ for gr in graph_results_list:
1309
+ if gr[0] == ft:
1310
+ graph_results = gr[1]
1311
+ graph_time = gr[2]
1312
+ mpfp_timing = gr[3]
1313
+ if mpfp_timing:
1314
+ all_mpfp_timings.append(mpfp_timing)
1315
+ break
1316
+
1317
+ # Get temporal results for this fact type from combined result
1318
+ temporal_results = temporal_results_by_ft.get(ft) if temporal_constraint else None
1319
+ if temporal_results is not None and len(temporal_results) == 0:
1320
+ temporal_results = None
1321
+
1322
+ results_by_fact_type[ft] = ParallelRetrievalResult(
1323
+ semantic=semantic_results,
1324
+ bm25=bm25_results,
1325
+ graph=graph_results,
1326
+ temporal=temporal_results,
1327
+ timings={
1328
+ "semantic": semantic_bm25_time / 2, # Approximate split
1329
+ "bm25": semantic_bm25_time / 2,
1330
+ "graph": graph_time,
1331
+ "temporal": temporal_time, # Same for all fact types (single query)
1332
+ "temporal_extraction": temporal_extraction_time,
1333
+ },
1334
+ temporal_constraint=temporal_constraint,
1335
+ mpfp_timings=[mpfp_timing] if mpfp_timing else [],
1336
+ max_conn_wait=max_conn_wait,
1337
+ )
1338
+
1339
+ total_time = time.time() - start_time
1340
+ timings["total"] = total_time
1341
+
1342
+ return MultiFactTypeRetrievalResult(
1343
+ results_by_fact_type=results_by_fact_type,
1344
+ timings=timings,
1345
+ max_conn_wait=max_conn_wait,
1346
+ )
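
A minimal end-to-end call sketch for the new multi-fact-type entry point, assuming an asyncpg-style pool; the fact type names and tag are placeholders, and the return shape is the MultiFactTypeRetrievalResult dataclass added earlier in this file:

async def example(pool, query_text: str, query_embedding_str: str, bank_id: str):
    result = await retrieve_all_fact_types_parallel(
        pool, query_text, query_embedding_str, bank_id,
        fact_types=["world", "experience"], thinking_budget=100,
        tags=["project-x"], tags_match="any",
    )
    for fact_type, per_type in result.results_by_fact_type.items():
        print(fact_type, len(per_type.semantic), len(per_type.bm25),
              len(per_type.graph), per_type.timings.get("graph"))
    print("total:", result.timings["total"], "max conn wait:", result.max_conn_wait)
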