hindsight-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. hindsight_api/__init__.py +10 -9
  2. hindsight_api/alembic/env.py +5 -8
  3. hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
  4. hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
  5. hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
  6. hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
  7. hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
  8. hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
  9. hindsight_api/api/__init__.py +10 -10
  10. hindsight_api/api/http.py +575 -593
  11. hindsight_api/api/mcp.py +31 -33
  12. hindsight_api/banner.py +13 -6
  13. hindsight_api/config.py +17 -12
  14. hindsight_api/engine/__init__.py +9 -9
  15. hindsight_api/engine/cross_encoder.py +23 -27
  16. hindsight_api/engine/db_utils.py +5 -4
  17. hindsight_api/engine/embeddings.py +22 -21
  18. hindsight_api/engine/entity_resolver.py +81 -75
  19. hindsight_api/engine/llm_wrapper.py +74 -88
  20. hindsight_api/engine/memory_engine.py +663 -673
  21. hindsight_api/engine/query_analyzer.py +100 -97
  22. hindsight_api/engine/response_models.py +105 -106
  23. hindsight_api/engine/retain/__init__.py +9 -16
  24. hindsight_api/engine/retain/bank_utils.py +34 -58
  25. hindsight_api/engine/retain/chunk_storage.py +4 -12
  26. hindsight_api/engine/retain/deduplication.py +9 -28
  27. hindsight_api/engine/retain/embedding_processing.py +4 -11
  28. hindsight_api/engine/retain/embedding_utils.py +3 -4
  29. hindsight_api/engine/retain/entity_processing.py +7 -17
  30. hindsight_api/engine/retain/fact_extraction.py +155 -165
  31. hindsight_api/engine/retain/fact_storage.py +11 -23
  32. hindsight_api/engine/retain/link_creation.py +11 -39
  33. hindsight_api/engine/retain/link_utils.py +166 -95
  34. hindsight_api/engine/retain/observation_regeneration.py +39 -52
  35. hindsight_api/engine/retain/orchestrator.py +72 -62
  36. hindsight_api/engine/retain/types.py +49 -43
  37. hindsight_api/engine/search/__init__.py +15 -1
  38. hindsight_api/engine/search/fusion.py +6 -15
  39. hindsight_api/engine/search/graph_retrieval.py +234 -0
  40. hindsight_api/engine/search/mpfp_retrieval.py +438 -0
  41. hindsight_api/engine/search/observation_utils.py +9 -16
  42. hindsight_api/engine/search/reranking.py +4 -7
  43. hindsight_api/engine/search/retrieval.py +388 -193
  44. hindsight_api/engine/search/scoring.py +5 -7
  45. hindsight_api/engine/search/temporal_extraction.py +8 -11
  46. hindsight_api/engine/search/think_utils.py +115 -39
  47. hindsight_api/engine/search/trace.py +68 -38
  48. hindsight_api/engine/search/tracer.py +49 -35
  49. hindsight_api/engine/search/types.py +22 -16
  50. hindsight_api/engine/task_backend.py +21 -26
  51. hindsight_api/engine/utils.py +25 -10
  52. hindsight_api/main.py +21 -40
  53. hindsight_api/mcp_local.py +190 -0
  54. hindsight_api/metrics.py +44 -30
  55. hindsight_api/migrations.py +10 -8
  56. hindsight_api/models.py +60 -72
  57. hindsight_api/pg0.py +64 -337
  58. hindsight_api/server.py +3 -6
  59. {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +6 -5
  60. hindsight_api-0.1.6.dist-info/RECORD +64 -0
  61. {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
  62. hindsight_api-0.1.4.dist-info/RECORD +0 -61
  63. {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
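The headline change in this release is a rework of graph retrieval: the old spreading-activation `retrieve_graph` is deleted from `retrieval.py` and replaced by a pluggable `GraphRetriever` interface, with two implementations shipped as new modules, `graph_retrieval.py` (BFS) and `mpfp_retrieval.py` (MPFP). A new `mcp_local.py` module and an extra console entry point also appear. Strategy selection is config-driven; here is a minimal sketch of what that looks like from the outside, assuming only the `get_config()` accessor and `graph_retriever` field that are visible in the diff below:

```python
# Sketch based on the retrieval.py diff below: get_config() and the
# graph_retriever field are shown there; treat everything else as assumption.
from hindsight_api.config import get_config

config = get_config()
print(config.graph_retriever)  # "mpfp" or "bfs"; unknown values fall back to MPFP
```

The diff below covers the most substantially changed file, `hindsight_api/engine/search/retrieval.py` (+388 -193).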
hindsight_api/engine/search/retrieval.py
@@ -4,24 +4,68 @@ Retrieval module for 4-way parallel search.
 Implements:
 1. Semantic retrieval (vector similarity)
 2. BM25 retrieval (keyword/full-text search)
-3. Graph retrieval (spreading activation)
+3. Graph retrieval (via pluggable GraphRetriever interface)
 4. Temporal retrieval (time-aware search with spreading)
 """

-from typing import List, Dict, Any, Tuple, Optional
-from datetime import datetime
 import asyncio
+import logging
+from dataclasses import dataclass, field
+from datetime import UTC, datetime
+from typing import Optional
+
+from ...config import get_config
 from ..db_utils import acquire_with_retry
+from .graph_retrieval import BFSGraphRetriever, GraphRetriever
+from .mpfp_retrieval import MPFPGraphRetriever
 from .types import RetrievalResult

+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ParallelRetrievalResult:
+    """Result from parallel retrieval across all methods."""
+
+    semantic: list[RetrievalResult]
+    bm25: list[RetrievalResult]
+    graph: list[RetrievalResult]
+    temporal: list[RetrievalResult] | None
+    timings: dict[str, float] = field(default_factory=dict)
+    temporal_constraint: tuple | None = None  # (start_date, end_date)
+
+
+# Default graph retriever instance (can be overridden)
+_default_graph_retriever: GraphRetriever | None = None
+
+
+def get_default_graph_retriever() -> GraphRetriever:
+    """Get or create the default graph retriever based on config."""
+    global _default_graph_retriever
+    if _default_graph_retriever is None:
+        config = get_config()
+        retriever_type = config.graph_retriever.lower()
+        if retriever_type == "mpfp":
+            _default_graph_retriever = MPFPGraphRetriever()
+            logger.info("Using MPFP graph retriever")
+        elif retriever_type == "bfs":
+            _default_graph_retriever = BFSGraphRetriever()
+            logger.info("Using BFS graph retriever")
+        else:
+            logger.warning(f"Unknown graph retriever '{retriever_type}', falling back to MPFP")
+            _default_graph_retriever = MPFPGraphRetriever()
+    return _default_graph_retriever
+
+
+def set_default_graph_retriever(retriever: GraphRetriever) -> None:
+    """Set the default graph retriever (for configuration/testing)."""
+    global _default_graph_retriever
+    _default_graph_retriever = retriever
+

 async def retrieve_semantic(
-    conn,
-    query_emb_str: str,
-    bank_id: str,
-    fact_type: str,
-    limit: int
-) -> List[RetrievalResult]:
+    conn, query_emb_str: str, bank_id: str, fact_type: str, limit: int
+) -> list[RetrievalResult]:
     """
     Semantic retrieval via vector similarity.

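The module now keeps a process-wide default retriever that `retrieve_parallel` falls back to, and `set_default_graph_retriever` exists explicitly for configuration and testing. A hypothetical stub, inferred from the call sites later in this diff (a `name` attribute plus an async `retrieve(...)` taking the keyword arguments shown there):

```python
# Hypothetical test stub; GraphRetriever's exact protocol is not shown in
# this diff, only the `name` attribute and the retrieve(...) kwargs used below.
from hindsight_api.engine.search.retrieval import set_default_graph_retriever

class NullGraphRetriever:
    name = "bfs"  # routes retrieve_parallel down the all-parallel BFS path

    async def retrieve(self, *, pool, query_embedding_str, bank_id, fact_type,
                       budget, query_text, semantic_seeds=None, temporal_seeds=None):
        return []  # no graph hits; fusion sees only semantic/BM25/temporal

set_default_graph_retriever(NullGraphRetriever())
```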
@@ -47,18 +91,15 @@ async def retrieve_semantic(
         ORDER BY embedding <=> $1::vector
         LIMIT $4
         """,
-        query_emb_str, bank_id, fact_type, limit
+        query_emb_str,
+        bank_id,
+        fact_type,
+        limit,
     )
     return [RetrievalResult.from_db_row(dict(r)) for r in results]


-async def retrieve_bm25(
-    conn,
-    query_text: str,
-    bank_id: str,
-    fact_type: str,
-    limit: int
-) -> List[RetrievalResult]:
+async def retrieve_bm25(conn, query_text: str, bank_id: str, fact_type: str, limit: int) -> list[RetrievalResult]:
     """
     BM25 keyword retrieval via full-text search.

@@ -76,7 +117,7 @@ async def retrieve_bm25(

     # Sanitize query text: remove special characters that have meaning in tsquery
     # Keep only alphanumeric characters and spaces
-    sanitized_text = re.sub(r'[^\w\s]', ' ', query_text.lower())
+    sanitized_text = re.sub(r"[^\w\s]", " ", query_text.lower())

     # Split and filter empty strings
     tokens = [token for token in sanitized_text.split() if token]
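The quote-style change above is pure formatting, but the sanitizer itself is worth a worked example: characters such as `&`, `|`, `!`, and `'` are operators in Postgres `tsquery` syntax, so they are blanked out before the token list is built. A standalone check of the exact regex:

```python
import re

# Same regex as retrieve_bm25: strip everything but word chars and whitespace.
text = "What's new with C++ & Rust?"
sanitized = re.sub(r"[^\w\s]", " ", text.lower())
tokens = [t for t in sanitized.split() if t]
print(tokens)  # ['what', 's', 'new', 'with', 'c', 'rust']
```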
@@ -100,126 +141,14 @@ async def retrieve_bm25(
         ORDER BY bm25_score DESC
         LIMIT $4
         """,
-        query_tsquery, bank_id, fact_type, limit
+        query_tsquery,
+        bank_id,
+        fact_type,
+        limit,
     )
     return [RetrievalResult.from_db_row(dict(r)) for r in results]


-async def retrieve_graph(
-    conn,
-    query_emb_str: str,
-    bank_id: str,
-    fact_type: str,
-    budget: int
-) -> List[RetrievalResult]:
-    """
-    Graph retrieval via spreading activation.
-
-    Args:
-        conn: Database connection
-        query_emb_str: Query embedding as string
-        agent_id: bank ID
-        fact_type: Fact type to filter
-        budget: Node budget for graph traversal
-
-    Returns:
-        List of RetrievalResult objects
-    """
-    # Find entry points
-    entry_points = await conn.fetch(
-        """
-        SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
-               1 - (embedding <=> $1::vector) AS similarity
-        FROM memory_units
-        WHERE bank_id = $2
-          AND embedding IS NOT NULL
-          AND fact_type = $3
-          AND (1 - (embedding <=> $1::vector)) >= 0.5
-        ORDER BY embedding <=> $1::vector
-        LIMIT 5
-        """,
-        query_emb_str, bank_id, fact_type
-    )
-
-    if not entry_points:
-        return []
-
-    # BFS-style spreading activation with batched neighbor fetching
-    visited = set()
-    results = []
-    queue = [(RetrievalResult.from_db_row(dict(r)), r["similarity"]) for r in entry_points]
-    budget_remaining = budget
-
-    # Process nodes in batches to reduce DB roundtrips
-    batch_size = 20  # Fetch neighbors for up to 20 nodes at once
-
-    while queue and budget_remaining > 0:
-        # Collect a batch of nodes to process
-        batch_nodes = []
-        batch_activations = {}
-
-        while queue and len(batch_nodes) < batch_size and budget_remaining > 0:
-            current, activation = queue.pop(0)
-            unit_id = current.id
-
-            if unit_id not in visited:
-                visited.add(unit_id)
-                budget_remaining -= 1
-                results.append(current)
-                batch_nodes.append(current.id)
-                batch_activations[unit_id] = activation
-
-        # Batch fetch neighbors for all nodes in this batch
-        # Fetch top weighted neighbors (batch_size * 20 = ~400 for good distribution)
-        if batch_nodes and budget_remaining > 0:
-            max_neighbors = len(batch_nodes) * 20
-            neighbors = await conn.fetch(
-                """
-                SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end, mu.mentioned_at,
-                       mu.access_count, mu.embedding, mu.fact_type, mu.document_id, mu.chunk_id,
-                       ml.weight, ml.link_type, ml.from_unit_id
-                FROM memory_links ml
-                JOIN memory_units mu ON ml.to_unit_id = mu.id
-                WHERE ml.from_unit_id = ANY($1::uuid[])
-                  AND ml.weight >= 0.1
-                  AND mu.fact_type = $2
-                ORDER BY ml.weight DESC
-                LIMIT $3
-                """,
-                batch_nodes, fact_type, max_neighbors
-            )
-
-            for n in neighbors:
-                neighbor_id = str(n["id"])
-                if neighbor_id not in visited:
-                    # Get parent activation
-                    parent_id = str(n["from_unit_id"])
-                    activation = batch_activations.get(parent_id, 0.5)
-
-                    # Boost activation for causal links (they're high-value relationships)
-                    link_type = n["link_type"]
-                    base_weight = n["weight"]
-
-                    # Causal links get 1.5-2.0x boost depending on type
-                    if link_type in ("causes", "caused_by"):
-                        # Direct causation - very strong relationship
-                        causal_boost = 2.0
-                    elif link_type in ("enables", "prevents"):
-                        # Conditional causation - strong but not as direct
-                        causal_boost = 1.5
-                    else:
-                        # Temporal, semantic, entity links - standard weight
-                        causal_boost = 1.0
-
-                    effective_weight = base_weight * causal_boost
-                    new_activation = activation * effective_weight * 0.8
-                    if new_activation > 0.1:
-                        neighbor_result = RetrievalResult.from_db_row(dict(n))
-                        queue.append((neighbor_result, new_activation))
-
-    return results
-
-
 async def retrieve_temporal(
     conn,
     query_emb_str: str,
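This removal is the core of the refactor: the BFS spreading activation (entry points above a 0.5 similarity floor, batched neighbor fetches, causal-link boosts) leaves `retrieval.py`. The `BFSGraphRetriever` import suggests it now lives in the new `graph_retrieval.py`, though that file's contents are not shown here. The per-edge activation rule the deleted code applied is easy to state in isolation:

```python
# The update rule from the deleted retrieve_graph, extracted verbatim:
# activation decays by edge weight, a causal boost, and a 0.8 per-hop damping;
# nodes scoring below 0.1 are pruned from the frontier.
def next_activation(parent_activation: float, weight: float, link_type: str) -> float:
    boost = {"causes": 2.0, "caused_by": 2.0, "enables": 1.5, "prevents": 1.5}.get(link_type, 1.0)
    return parent_activation * (weight * boost) * 0.8

print(next_activation(0.9, 0.4, "causes"))    # 0.576 -> kept, frontier expands
print(next_activation(0.9, 0.4, "semantic"))  # 0.288 -> kept
print(next_activation(0.3, 0.4, "semantic"))  # 0.096 -> pruned (< 0.1 threshold)
```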
@@ -228,8 +157,8 @@ async def retrieve_temporal(
     start_date: datetime,
     end_date: datetime,
     budget: int,
-    semantic_threshold: float = 0.1
-) -> List[RetrievalResult]:
+    semantic_threshold: float = 0.1,
+) -> list[RetrievalResult]:
     """
     Temporal retrieval with spreading activation.

@@ -251,13 +180,12 @@ async def retrieve_temporal(
     Returns:
         List of RetrievalResult objects with temporal scores
     """
-    from datetime import timezone

     # Ensure start_date and end_date are timezone-aware (UTC) to match database datetimes
     if start_date.tzinfo is None:
-        start_date = start_date.replace(tzinfo=timezone.utc)
+        start_date = start_date.replace(tzinfo=UTC)
     if end_date.tzinfo is None:
-        end_date = end_date.replace(tzinfo=timezone.utc)
+        end_date = end_date.replace(tzinfo=UTC)

     entry_points = await conn.fetch(
         """
@@ -284,7 +212,12 @@ async def retrieve_temporal(
         ORDER BY COALESCE(occurred_start, mentioned_at, occurred_end) DESC, (embedding <=> $1::vector) ASC
         LIMIT 10
         """,
-        query_emb_str, bank_id, fact_type, start_date, end_date, semantic_threshold
+        query_emb_str,
+        bank_id,
+        fact_type,
+        start_date,
+        end_date,
+        semantic_threshold,
     )

     if not entry_points:
@@ -327,7 +260,9 @@ async def retrieve_temporal(
            results.append(ep_result)

     # Spread through temporal links
-    queue = [(RetrievalResult.from_db_row(dict(ep)), ep["similarity"], 1.0) for ep in entry_points]  # (unit, semantic_sim, temporal_score)
+    queue = [
+        (RetrievalResult.from_db_row(dict(ep)), ep["similarity"], 1.0) for ep in entry_points
+    ]  # (unit, semantic_sim, temporal_score)
     budget_remaining = budget - len(entry_points)

     while queue and budget_remaining > 0:
@@ -352,7 +287,10 @@ async def retrieve_temporal(
             ORDER BY ml.weight DESC
             LIMIT 10
             """,
-            query_emb_str, current.id, fact_type, semantic_threshold
+            query_emb_str,
+            current.id,
+            fact_type,
+            semantic_threshold,
         )

         for n in neighbors:
@@ -376,7 +314,9 @@ async def retrieve_temporal(

             if neighbor_best_date:
                 days_from_mid = abs((neighbor_best_date - mid_date).total_seconds() / 86400)
-                neighbor_temporal_proximity = 1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
+                neighbor_temporal_proximity = (
+                    1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
+                )
             else:
                 neighbor_temporal_proximity = 0.3  # Lower score if no temporal data

@@ -418,9 +358,10 @@ async def retrieve_parallel(
     bank_id: str,
     fact_type: str,
     thinking_budget: int,
-    question_date: Optional[datetime] = None,
-    query_analyzer: Optional["QueryAnalyzer"] = None
-) -> Tuple[List[RetrievalResult], List[RetrievalResult], List[RetrievalResult], Optional[List[RetrievalResult]], Dict[str, float], Optional[Tuple[datetime, datetime]]]:
+    question_date: datetime | None = None,
+    query_analyzer: Optional["QueryAnalyzer"] = None,
+    graph_retriever: GraphRetriever | None = None,
+) -> ParallelRetrievalResult:
     """
     Run 3-way or 4-way parallel retrieval (adds temporal if detected).

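This is the breaking change for direct callers: `retrieve_parallel` gains a `graph_retriever` override and returns the `ParallelRetrievalResult` dataclass instead of a 6-tuple. A before/after sketch (the bank and fact-type values are placeholders):

```python
from hindsight_api.engine.search.retrieval import retrieve_parallel

async def search(pool, query_text: str, emb_str: str):
    # 0.1.4: six positional results to unpack, in a fixed order
    # semantic, bm25, graph, temporal, timings, constraint = await retrieve_parallel(...)

    # 0.1.6: one object with named fields
    result = await retrieve_parallel(pool, query_text, emb_str, "my-bank", "fact", 100)
    print(result.timings)           # per-method latencies in seconds
    print(result.temporal is None)  # True unless a date constraint was detected
    return result.semantic + result.bm25 + result.graph
```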
@@ -428,76 +369,330 @@ async def retrieve_parallel(
         pool: Database connection pool
         query_text: Query text
         query_embedding_str: Query embedding as string
-        agent_id: bank ID
+        bank_id: Bank ID
         fact_type: Fact type to filter
         thinking_budget: Budget for graph traversal and retrieval limits
         question_date: Optional date when question was asked (for temporal filtering)
         query_analyzer: Query analyzer to use (defaults to TransformerQueryAnalyzer)
+        graph_retriever: Graph retrieval strategy (defaults to configured retriever)

     Returns:
-        Tuple of (semantic_results, bm25_results, graph_results, temporal_results, timings, temporal_constraint)
-        Each results list contains RetrievalResult objects
-        temporal_results is None if no temporal constraint detected
-        timings is a dict with per-method latencies in seconds
-        temporal_constraint is the (start_date, end_date) tuple if detected, else None
+        ParallelRetrievalResult with semantic, bm25, graph, temporal results and timings
     """
-    # Detect temporal constraint
     from .temporal_extraction import extract_temporal_constraint
+
+    temporal_constraint = extract_temporal_constraint(query_text, reference_date=question_date, analyzer=query_analyzer)
+
+    retriever = graph_retriever or get_default_graph_retriever()
+
+    if retriever.name == "mpfp":
+        return await _retrieve_parallel_mpfp(
+            pool, query_text, query_embedding_str, bank_id, fact_type, thinking_budget, temporal_constraint, retriever
+        )
+    else:
+        return await _retrieve_parallel_bfs(
+            pool, query_text, query_embedding_str, bank_id, fact_type, thinking_budget, temporal_constraint, retriever
+        )
+
+
+@dataclass
+class _SemanticGraphResult:
+    """Internal result from semantic→graph chain."""
+
+    semantic: list[RetrievalResult]
+    graph: list[RetrievalResult]
+    semantic_time: float
+    graph_time: float
+
+
+@dataclass
+class _TimedResult:
+    """Internal result with timing."""
+
+    results: list[RetrievalResult]
+    time: float
+
+
+async def _retrieve_parallel_mpfp(
+    pool,
+    query_text: str,
+    query_embedding_str: str,
+    bank_id: str,
+    fact_type: str,
+    thinking_budget: int,
+    temporal_constraint: tuple | None,
+    retriever: GraphRetriever,
+) -> ParallelRetrievalResult:
+    """
+    MPFP retrieval with optimized parallelization.
+
+    Runs 2-3 parallel task chains:
+    - Task 1: Semantic → Graph (chained, graph uses semantic seeds)
+    - Task 2: BM25 (independent)
+    - Task 3: Temporal (if constraint detected)
+    """
     import time

-    temporal_constraint = extract_temporal_constraint(
-        query_text, reference_date=question_date, analyzer=query_analyzer
-    )
+    async def run_semantic_then_graph() -> _SemanticGraphResult:
+        """Chain: semantic retrieval → graph retrieval (using semantic as seeds)."""
+        start = time.time()
+        async with acquire_with_retry(pool) as conn:
+            semantic = await retrieve_semantic(conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget)
+        semantic_time = time.time() - start
+
+        # Get temporal seeds if needed (quick query, part of this chain)
+        temporal_seeds = None
+        if temporal_constraint:
+            tc_start, tc_end = temporal_constraint
+            async with acquire_with_retry(pool) as conn:
+                temporal_seeds = await _get_temporal_entry_points(
+                    conn, query_embedding_str, bank_id, fact_type, tc_start, tc_end, limit=20
+                )
+
+        # Run graph with seeds
+        start = time.time()
+        graph = await retriever.retrieve(
+            pool=pool,
+            query_embedding_str=query_embedding_str,
+            bank_id=bank_id,
+            fact_type=fact_type,
+            budget=thinking_budget,
+            query_text=query_text,
+            semantic_seeds=semantic,
+            temporal_seeds=temporal_seeds,
+        )
+        graph_time = time.time() - start
+
+        return _SemanticGraphResult(semantic, graph, semantic_time, graph_time)

-    # Wrapper to track timing for each retrieval method
-    async def timed_retrieval(name: str, coro):
+    async def run_bm25() -> _TimedResult:
+        """Independent BM25 retrieval."""
         start = time.time()
-        result = await coro
-        duration = time.time() - start
-        return result, name, duration
+        async with acquire_with_retry(pool) as conn:
+            results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
+        return _TimedResult(results, time.time() - start)

-    async def run_semantic():
+    async def run_temporal(tc_start, tc_end) -> _TimedResult:
+        """Temporal retrieval (uses its own entry point finding)."""
+        start = time.time()
         async with acquire_with_retry(pool) as conn:
-            return await retrieve_semantic(conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget)
+            results = await retrieve_temporal(
+                conn,
+                query_embedding_str,
+                bank_id,
+                fact_type,
+                tc_start,
+                tc_end,
+                budget=thinking_budget,
+                semantic_threshold=0.1,
+            )
+        return _TimedResult(results, time.time() - start)
+
+    # Run parallel task chains
+    if temporal_constraint:
+        tc_start, tc_end = temporal_constraint
+        sg_result, bm25_result, temporal_result = await asyncio.gather(
+            run_semantic_then_graph(),
+            run_bm25(),
+            run_temporal(tc_start, tc_end),
+        )
+        return ParallelRetrievalResult(
+            semantic=sg_result.semantic,
+            bm25=bm25_result.results,
+            graph=sg_result.graph,
+            temporal=temporal_result.results,
+            timings={
+                "semantic": sg_result.semantic_time,
+                "graph": sg_result.graph_time,
+                "bm25": bm25_result.time,
+                "temporal": temporal_result.time,
+            },
+            temporal_constraint=temporal_constraint,
+        )
+    else:
+        sg_result, bm25_result = await asyncio.gather(
+            run_semantic_then_graph(),
+            run_bm25(),
+        )
+        return ParallelRetrievalResult(
+            semantic=sg_result.semantic,
+            bm25=bm25_result.results,
+            graph=sg_result.graph,
+            temporal=None,
+            timings={
+                "semantic": sg_result.semantic_time,
+                "graph": sg_result.graph_time,
+                "bm25": bm25_result.time,
+            },
+            temporal_constraint=None,
+        )
+
+
+async def _get_temporal_entry_points(
+    conn,
+    query_embedding_str: str,
+    bank_id: str,
+    fact_type: str,
+    start_date: datetime,
+    end_date: datetime,
+    limit: int = 20,
+    semantic_threshold: float = 0.1,
+) -> list[RetrievalResult]:
+    """Get temporal entry points (facts in date range with semantic relevance)."""
+
+    if start_date.tzinfo is None:
+        start_date = start_date.replace(tzinfo=UTC)
+    if end_date.tzinfo is None:
+        end_date = end_date.replace(tzinfo=UTC)
+
+    rows = await conn.fetch(
+        """
+        SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at,
+               access_count, embedding, fact_type, document_id, chunk_id,
+               1 - (embedding <=> $1::vector) AS similarity
+        FROM memory_units
+        WHERE bank_id = $2
+          AND fact_type = $3
+          AND embedding IS NOT NULL
+          AND (
+              (occurred_start IS NOT NULL AND occurred_end IS NOT NULL
+               AND occurred_start <= $5 AND occurred_end >= $4)
+              OR (mentioned_at IS NOT NULL AND mentioned_at BETWEEN $4 AND $5)
+              OR (occurred_start IS NOT NULL AND occurred_start BETWEEN $4 AND $5)
+              OR (occurred_end IS NOT NULL AND occurred_end BETWEEN $4 AND $5)
+          )
+          AND (1 - (embedding <=> $1::vector)) >= $6
+        ORDER BY COALESCE(occurred_start, mentioned_at, occurred_end) DESC,
+                 (embedding <=> $1::vector) ASC
+        LIMIT $7
+        """,
+        query_embedding_str,
+        bank_id,
+        fact_type,
+        start_date,
+        end_date,
+        semantic_threshold,
+        limit,
+    )
+
+    results = []
+    total_days = max((end_date - start_date).total_seconds() / 86400, 1)
+    mid_date = start_date + (end_date - start_date) / 2
+
+    for row in rows:
+        result = RetrievalResult.from_db_row(dict(row))
+
+        # Calculate temporal proximity score
+        best_date = None
+        if row["occurred_start"] and row["occurred_end"]:
+            best_date = row["occurred_start"] + (row["occurred_end"] - row["occurred_start"]) / 2
+        elif row["occurred_start"]:
+            best_date = row["occurred_start"]
+        elif row["occurred_end"]:
+            best_date = row["occurred_end"]
+        elif row["mentioned_at"]:
+            best_date = row["mentioned_at"]
+
+        if best_date:
+            days_from_mid = abs((best_date - mid_date).total_seconds() / 86400)
+            result.temporal_proximity = 1.0 - min(days_from_mid / (total_days / 2), 1.0)
+        else:
+            result.temporal_proximity = 0.5

-    async def run_bm25():
+        result.temporal_score = result.temporal_proximity
+        results.append(result)
+
+    return results
+
+
+async def _retrieve_parallel_bfs(
+    pool,
+    query_text: str,
+    query_embedding_str: str,
+    bank_id: str,
+    fact_type: str,
+    thinking_budget: int,
+    temporal_constraint: tuple | None,
+    retriever: GraphRetriever,
+) -> ParallelRetrievalResult:
+    """BFS retrieval: all methods run in parallel (original behavior)."""
+    import time
+
+    async def run_semantic() -> _TimedResult:
+        start = time.time()
         async with acquire_with_retry(pool) as conn:
-            return await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
+            results = await retrieve_semantic(conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget)
+        return _TimedResult(results, time.time() - start)

-    async def run_graph():
+    async def run_bm25() -> _TimedResult:
+        start = time.time()
         async with acquire_with_retry(pool) as conn:
-            return await retrieve_graph(conn, query_embedding_str, bank_id, fact_type, budget=thinking_budget)
+            results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
+        return _TimedResult(results, time.time() - start)
+
+    async def run_graph() -> _TimedResult:
+        start = time.time()
+        results = await retriever.retrieve(
+            pool=pool,
+            query_embedding_str=query_embedding_str,
+            bank_id=bank_id,
+            fact_type=fact_type,
+            budget=thinking_budget,
+            query_text=query_text,
+        )
+        return _TimedResult(results, time.time() - start)

-    async def run_temporal(start_date, end_date):
+    async def run_temporal(tc_start, tc_end) -> _TimedResult:
+        start = time.time()
         async with acquire_with_retry(pool) as conn:
-            return await retrieve_temporal(
-                conn, query_embedding_str, bank_id, fact_type,
-                start_date, end_date, budget=thinking_budget, semantic_threshold=0.1
+            results = await retrieve_temporal(
+                conn,
+                query_embedding_str,
+                bank_id,
+                fact_type,
+                tc_start,
+                tc_end,
+                budget=thinking_budget,
+                semantic_threshold=0.1,
             )
+        return _TimedResult(results, time.time() - start)

-    # Run retrievals in parallel with timing
-    timings = {}
     if temporal_constraint:
-        start_date, end_date = temporal_constraint
-        results = await asyncio.gather(
-            timed_retrieval("semantic", run_semantic()),
-            timed_retrieval("bm25", run_bm25()),
-            timed_retrieval("graph", run_graph()),
-            timed_retrieval("temporal", run_temporal(start_date, end_date))
+        tc_start, tc_end = temporal_constraint
+        semantic_r, bm25_r, graph_r, temporal_r = await asyncio.gather(
+            run_semantic(),
+            run_bm25(),
+            run_graph(),
+            run_temporal(tc_start, tc_end),
+        )
+        return ParallelRetrievalResult(
+            semantic=semantic_r.results,
+            bm25=bm25_r.results,
+            graph=graph_r.results,
+            temporal=temporal_r.results,
+            timings={
+                "semantic": semantic_r.time,
+                "bm25": bm25_r.time,
+                "graph": graph_r.time,
+                "temporal": temporal_r.time,
+            },
+            temporal_constraint=temporal_constraint,
         )
-        semantic_results, _, timings["semantic"] = results[0]
-        bm25_results, _, timings["bm25"] = results[1]
-        graph_results, _, timings["graph"] = results[2]
-        temporal_results, _, timings["temporal"] = results[3]
     else:
-        results = await asyncio.gather(
-            timed_retrieval("semantic", run_semantic()),
-            timed_retrieval("bm25", run_bm25()),
-            timed_retrieval("graph", run_graph())
+        semantic_r, bm25_r, graph_r = await asyncio.gather(
+            run_semantic(),
+            run_bm25(),
+            run_graph(),
+        )
+        return ParallelRetrievalResult(
+            semantic=semantic_r.results,
+            bm25=bm25_r.results,
+            graph=graph_r.results,
+            temporal=None,
+            timings={
+                "semantic": semantic_r.time,
+                "bm25": bm25_r.time,
+                "graph": graph_r.time,
+            },
+            temporal_constraint=None,
         )
-        semantic_results, _, timings["semantic"] = results[0]
-        bm25_results, _, timings["bm25"] = results[1]
-        graph_results, _, timings["graph"] = results[2]
-        temporal_results = None
-
-    return semantic_results, bm25_results, graph_results, temporal_results, timings, temporal_constraint
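Note the scheduling asymmetry between the two paths: MPFP's graph expansion consumes the semantic hits as seeds, so it cannot start until semantic finishes; chaining those two stages inside one task and running BM25 (and temporal) as siblings keeps the critical path at two stages instead of serializing all four. The pattern in isolation, stripped of any hindsight-specific types:

```python
import asyncio, time

# Generic version of _retrieve_parallel_mpfp's task layout: a dependent
# two-stage chain gathered alongside independent work.
async def chained(stage_a, stage_b):
    t0 = time.time()
    a = await stage_a()
    b = await stage_b(a)  # stage_b needs stage_a's output (the seeds)
    return a, b, time.time() - t0

async def main():
    async def semantic(): return ["seed"]
    async def graph(seeds): return [f"expanded:{s}" for s in seeds]
    async def bm25(): return ["kw-hit"]

    (sem, g, _), kw = await asyncio.gather(chained(semantic, graph), bm25())
    print(sem, g, kw)  # ['seed'] ['expanded:seed'] ['kw-hit']

asyncio.run(main())
```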