hindsight-api 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/api/mcp.py +1 -5
- hindsight_api/config.py +9 -0
- hindsight_api/engine/cross_encoder.py +1 -6
- hindsight_api/engine/llm_wrapper.py +13 -9
- hindsight_api/engine/memory_engine.py +71 -59
- hindsight_api/engine/search/__init__.py +15 -1
- hindsight_api/engine/search/graph_retrieval.py +235 -0
- hindsight_api/engine/search/mpfp_retrieval.py +454 -0
- hindsight_api/engine/search/retrieval.py +337 -163
- hindsight_api/engine/search/trace.py +1 -0
- hindsight_api/engine/search/tracer.py +8 -3
- hindsight_api/engine/search/types.py +4 -1
- hindsight_api/pg0.py +54 -326
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.5.dist-info}/METADATA +6 -5
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.5.dist-info}/RECORD +17 -15
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.5.dist-info}/WHEEL +0 -0
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.5.dist-info}/entry_points.txt +0 -0
|
@@ -4,15 +4,61 @@ Retrieval module for 4-way parallel search.
|
|
|
4
4
|
Implements:
|
|
5
5
|
1. Semantic retrieval (vector similarity)
|
|
6
6
|
2. BM25 retrieval (keyword/full-text search)
|
|
7
|
-
3. Graph retrieval (
|
|
7
|
+
3. Graph retrieval (via pluggable GraphRetriever interface)
|
|
8
8
|
4. Temporal retrieval (time-aware search with spreading)
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
|
-
from typing import List, Dict,
|
|
11
|
+
from typing import List, Dict, Optional
|
|
12
|
+
from dataclasses import dataclass, field
|
|
12
13
|
from datetime import datetime
|
|
13
14
|
import asyncio
|
|
15
|
+
import logging
|
|
14
16
|
from ..db_utils import acquire_with_retry
|
|
15
17
|
from .types import RetrievalResult
|
|
18
|
+
from .graph_retrieval import GraphRetriever, BFSGraphRetriever
|
|
19
|
+
from .mpfp_retrieval import MPFPGraphRetriever
|
|
20
|
+
from ...config import get_config
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class ParallelRetrievalResult:
|
|
27
|
+
"""Result from parallel retrieval across all methods."""
|
|
28
|
+
semantic: List[RetrievalResult]
|
|
29
|
+
bm25: List[RetrievalResult]
|
|
30
|
+
graph: List[RetrievalResult]
|
|
31
|
+
temporal: Optional[List[RetrievalResult]]
|
|
32
|
+
timings: Dict[str, float] = field(default_factory=dict)
|
|
33
|
+
temporal_constraint: Optional[tuple] = None # (start_date, end_date)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Default graph retriever instance (can be overridden)
|
|
37
|
+
_default_graph_retriever: Optional[GraphRetriever] = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_default_graph_retriever() -> GraphRetriever:
|
|
41
|
+
"""Get or create the default graph retriever based on config."""
|
|
42
|
+
global _default_graph_retriever
|
|
43
|
+
if _default_graph_retriever is None:
|
|
44
|
+
config = get_config()
|
|
45
|
+
retriever_type = config.graph_retriever.lower()
|
|
46
|
+
if retriever_type == "mpfp":
|
|
47
|
+
_default_graph_retriever = MPFPGraphRetriever()
|
|
48
|
+
logger.info("Using MPFP graph retriever")
|
|
49
|
+
elif retriever_type == "bfs":
|
|
50
|
+
_default_graph_retriever = BFSGraphRetriever()
|
|
51
|
+
logger.info("Using BFS graph retriever")
|
|
52
|
+
else:
|
|
53
|
+
logger.warning(f"Unknown graph retriever '{retriever_type}', falling back to MPFP")
|
|
54
|
+
_default_graph_retriever = MPFPGraphRetriever()
|
|
55
|
+
return _default_graph_retriever
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def set_default_graph_retriever(retriever: GraphRetriever) -> None:
|
|
59
|
+
"""Set the default graph retriever (for configuration/testing)."""
|
|
60
|
+
global _default_graph_retriever
|
|
61
|
+
_default_graph_retriever = retriever
|
|
16
62
|
|
|
17
63
|
|
|
18
64
|
async def retrieve_semantic(
|
|
@@ -105,121 +151,6 @@ async def retrieve_bm25(
|
|
|
105
151
|
return [RetrievalResult.from_db_row(dict(r)) for r in results]
|
|
106
152
|
|
|
107
153
|
|
|
108
|
-
async def retrieve_graph(
|
|
109
|
-
conn,
|
|
110
|
-
query_emb_str: str,
|
|
111
|
-
bank_id: str,
|
|
112
|
-
fact_type: str,
|
|
113
|
-
budget: int
|
|
114
|
-
) -> List[RetrievalResult]:
|
|
115
|
-
"""
|
|
116
|
-
Graph retrieval via spreading activation.
|
|
117
|
-
|
|
118
|
-
Args:
|
|
119
|
-
conn: Database connection
|
|
120
|
-
query_emb_str: Query embedding as string
|
|
121
|
-
agent_id: bank ID
|
|
122
|
-
fact_type: Fact type to filter
|
|
123
|
-
budget: Node budget for graph traversal
|
|
124
|
-
|
|
125
|
-
Returns:
|
|
126
|
-
List of RetrievalResult objects
|
|
127
|
-
"""
|
|
128
|
-
# Find entry points
|
|
129
|
-
entry_points = await conn.fetch(
|
|
130
|
-
"""
|
|
131
|
-
SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
|
|
132
|
-
1 - (embedding <=> $1::vector) AS similarity
|
|
133
|
-
FROM memory_units
|
|
134
|
-
WHERE bank_id = $2
|
|
135
|
-
AND embedding IS NOT NULL
|
|
136
|
-
AND fact_type = $3
|
|
137
|
-
AND (1 - (embedding <=> $1::vector)) >= 0.5
|
|
138
|
-
ORDER BY embedding <=> $1::vector
|
|
139
|
-
LIMIT 5
|
|
140
|
-
""",
|
|
141
|
-
query_emb_str, bank_id, fact_type
|
|
142
|
-
)
|
|
143
|
-
|
|
144
|
-
if not entry_points:
|
|
145
|
-
return []
|
|
146
|
-
|
|
147
|
-
# BFS-style spreading activation with batched neighbor fetching
|
|
148
|
-
visited = set()
|
|
149
|
-
results = []
|
|
150
|
-
queue = [(RetrievalResult.from_db_row(dict(r)), r["similarity"]) for r in entry_points]
|
|
151
|
-
budget_remaining = budget
|
|
152
|
-
|
|
153
|
-
# Process nodes in batches to reduce DB roundtrips
|
|
154
|
-
batch_size = 20 # Fetch neighbors for up to 20 nodes at once
|
|
155
|
-
|
|
156
|
-
while queue and budget_remaining > 0:
|
|
157
|
-
# Collect a batch of nodes to process
|
|
158
|
-
batch_nodes = []
|
|
159
|
-
batch_activations = {}
|
|
160
|
-
|
|
161
|
-
while queue and len(batch_nodes) < batch_size and budget_remaining > 0:
|
|
162
|
-
current, activation = queue.pop(0)
|
|
163
|
-
unit_id = current.id
|
|
164
|
-
|
|
165
|
-
if unit_id not in visited:
|
|
166
|
-
visited.add(unit_id)
|
|
167
|
-
budget_remaining -= 1
|
|
168
|
-
results.append(current)
|
|
169
|
-
batch_nodes.append(current.id)
|
|
170
|
-
batch_activations[unit_id] = activation
|
|
171
|
-
|
|
172
|
-
# Batch fetch neighbors for all nodes in this batch
|
|
173
|
-
# Fetch top weighted neighbors (batch_size * 20 = ~400 for good distribution)
|
|
174
|
-
if batch_nodes and budget_remaining > 0:
|
|
175
|
-
max_neighbors = len(batch_nodes) * 20
|
|
176
|
-
neighbors = await conn.fetch(
|
|
177
|
-
"""
|
|
178
|
-
SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end, mu.mentioned_at,
|
|
179
|
-
mu.access_count, mu.embedding, mu.fact_type, mu.document_id, mu.chunk_id,
|
|
180
|
-
ml.weight, ml.link_type, ml.from_unit_id
|
|
181
|
-
FROM memory_links ml
|
|
182
|
-
JOIN memory_units mu ON ml.to_unit_id = mu.id
|
|
183
|
-
WHERE ml.from_unit_id = ANY($1::uuid[])
|
|
184
|
-
AND ml.weight >= 0.1
|
|
185
|
-
AND mu.fact_type = $2
|
|
186
|
-
ORDER BY ml.weight DESC
|
|
187
|
-
LIMIT $3
|
|
188
|
-
""",
|
|
189
|
-
batch_nodes, fact_type, max_neighbors
|
|
190
|
-
)
|
|
191
|
-
|
|
192
|
-
for n in neighbors:
|
|
193
|
-
neighbor_id = str(n["id"])
|
|
194
|
-
if neighbor_id not in visited:
|
|
195
|
-
# Get parent activation
|
|
196
|
-
parent_id = str(n["from_unit_id"])
|
|
197
|
-
activation = batch_activations.get(parent_id, 0.5)
|
|
198
|
-
|
|
199
|
-
# Boost activation for causal links (they're high-value relationships)
|
|
200
|
-
link_type = n["link_type"]
|
|
201
|
-
base_weight = n["weight"]
|
|
202
|
-
|
|
203
|
-
# Causal links get 1.5-2.0x boost depending on type
|
|
204
|
-
if link_type in ("causes", "caused_by"):
|
|
205
|
-
# Direct causation - very strong relationship
|
|
206
|
-
causal_boost = 2.0
|
|
207
|
-
elif link_type in ("enables", "prevents"):
|
|
208
|
-
# Conditional causation - strong but not as direct
|
|
209
|
-
causal_boost = 1.5
|
|
210
|
-
else:
|
|
211
|
-
# Temporal, semantic, entity links - standard weight
|
|
212
|
-
causal_boost = 1.0
|
|
213
|
-
|
|
214
|
-
effective_weight = base_weight * causal_boost
|
|
215
|
-
new_activation = activation * effective_weight * 0.8
|
|
216
|
-
if new_activation > 0.1:
|
|
217
|
-
neighbor_result = RetrievalResult.from_db_row(dict(n))
|
|
218
|
-
queue.append((neighbor_result, new_activation))
|
|
219
|
-
|
|
220
|
-
return results
|
|
221
|
-
|
|
222
|
-
|
|
223
154
|
async def retrieve_temporal(
|
|
224
155
|
conn,
|
|
225
156
|
query_emb_str: str,
|
|
@@ -419,8 +350,9 @@ async def retrieve_parallel(
|
|
|
419
350
|
fact_type: str,
|
|
420
351
|
thinking_budget: int,
|
|
421
352
|
question_date: Optional[datetime] = None,
|
|
422
|
-
query_analyzer: Optional["QueryAnalyzer"] = None
|
|
423
|
-
|
|
353
|
+
query_analyzer: Optional["QueryAnalyzer"] = None,
|
|
354
|
+
graph_retriever: Optional[GraphRetriever] = None,
|
|
355
|
+
) -> ParallelRetrievalResult:
|
|
424
356
|
"""
|
|
425
357
|
Run 3-way or 4-way parallel retrieval (adds temporal if detected).
|
|
426
358
|
|
|
@@ -428,76 +360,318 @@ async def retrieve_parallel(
|
|
|
428
360
|
pool: Database connection pool
|
|
429
361
|
query_text: Query text
|
|
430
362
|
query_embedding_str: Query embedding as string
|
|
431
|
-
|
|
363
|
+
bank_id: Bank ID
|
|
432
364
|
fact_type: Fact type to filter
|
|
433
365
|
thinking_budget: Budget for graph traversal and retrieval limits
|
|
434
366
|
question_date: Optional date when question was asked (for temporal filtering)
|
|
435
367
|
query_analyzer: Query analyzer to use (defaults to TransformerQueryAnalyzer)
|
|
368
|
+
graph_retriever: Graph retrieval strategy (defaults to configured retriever)
|
|
436
369
|
|
|
437
370
|
Returns:
|
|
438
|
-
|
|
439
|
-
Each results list contains RetrievalResult objects
|
|
440
|
-
temporal_results is None if no temporal constraint detected
|
|
441
|
-
timings is a dict with per-method latencies in seconds
|
|
442
|
-
temporal_constraint is the (start_date, end_date) tuple if detected, else None
|
|
371
|
+
ParallelRetrievalResult with semantic, bm25, graph, temporal results and timings
|
|
443
372
|
"""
|
|
444
|
-
# Detect temporal constraint
|
|
445
373
|
from .temporal_extraction import extract_temporal_constraint
|
|
446
|
-
import time
|
|
447
374
|
|
|
448
375
|
temporal_constraint = extract_temporal_constraint(
|
|
449
376
|
query_text, reference_date=question_date, analyzer=query_analyzer
|
|
450
377
|
)
|
|
451
378
|
|
|
452
|
-
|
|
453
|
-
|
|
379
|
+
retriever = graph_retriever or get_default_graph_retriever()
|
|
380
|
+
|
|
381
|
+
if retriever.name == "mpfp":
|
|
382
|
+
return await _retrieve_parallel_mpfp(
|
|
383
|
+
pool, query_text, query_embedding_str, bank_id, fact_type,
|
|
384
|
+
thinking_budget, temporal_constraint, retriever
|
|
385
|
+
)
|
|
386
|
+
else:
|
|
387
|
+
return await _retrieve_parallel_bfs(
|
|
388
|
+
pool, query_text, query_embedding_str, bank_id, fact_type,
|
|
389
|
+
thinking_budget, temporal_constraint, retriever
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
@dataclass
|
|
394
|
+
class _SemanticGraphResult:
|
|
395
|
+
"""Internal result from semantic→graph chain."""
|
|
396
|
+
semantic: List[RetrievalResult]
|
|
397
|
+
graph: List[RetrievalResult]
|
|
398
|
+
semantic_time: float
|
|
399
|
+
graph_time: float
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
@dataclass
|
|
403
|
+
class _TimedResult:
|
|
404
|
+
"""Internal result with timing."""
|
|
405
|
+
results: List[RetrievalResult]
|
|
406
|
+
time: float
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
async def _retrieve_parallel_mpfp(
|
|
410
|
+
pool,
|
|
411
|
+
query_text: str,
|
|
412
|
+
query_embedding_str: str,
|
|
413
|
+
bank_id: str,
|
|
414
|
+
fact_type: str,
|
|
415
|
+
thinking_budget: int,
|
|
416
|
+
temporal_constraint: Optional[tuple],
|
|
417
|
+
retriever: GraphRetriever,
|
|
418
|
+
) -> ParallelRetrievalResult:
|
|
419
|
+
"""
|
|
420
|
+
MPFP retrieval with optimized parallelization.
|
|
421
|
+
|
|
422
|
+
Runs 2-3 parallel task chains:
|
|
423
|
+
- Task 1: Semantic → Graph (chained, graph uses semantic seeds)
|
|
424
|
+
- Task 2: BM25 (independent)
|
|
425
|
+
- Task 3: Temporal (if constraint detected)
|
|
426
|
+
"""
|
|
427
|
+
import time
|
|
428
|
+
|
|
429
|
+
async def run_semantic_then_graph() -> _SemanticGraphResult:
|
|
430
|
+
"""Chain: semantic retrieval → graph retrieval (using semantic as seeds)."""
|
|
431
|
+
start = time.time()
|
|
432
|
+
async with acquire_with_retry(pool) as conn:
|
|
433
|
+
semantic = await retrieve_semantic(
|
|
434
|
+
conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget
|
|
435
|
+
)
|
|
436
|
+
semantic_time = time.time() - start
|
|
437
|
+
|
|
438
|
+
# Get temporal seeds if needed (quick query, part of this chain)
|
|
439
|
+
temporal_seeds = None
|
|
440
|
+
if temporal_constraint:
|
|
441
|
+
tc_start, tc_end = temporal_constraint
|
|
442
|
+
async with acquire_with_retry(pool) as conn:
|
|
443
|
+
temporal_seeds = await _get_temporal_entry_points(
|
|
444
|
+
conn, query_embedding_str, bank_id, fact_type,
|
|
445
|
+
tc_start, tc_end, limit=20
|
|
446
|
+
)
|
|
447
|
+
|
|
448
|
+
# Run graph with seeds
|
|
449
|
+
start = time.time()
|
|
450
|
+
graph = await retriever.retrieve(
|
|
451
|
+
pool=pool,
|
|
452
|
+
query_embedding_str=query_embedding_str,
|
|
453
|
+
bank_id=bank_id,
|
|
454
|
+
fact_type=fact_type,
|
|
455
|
+
budget=thinking_budget,
|
|
456
|
+
query_text=query_text,
|
|
457
|
+
semantic_seeds=semantic,
|
|
458
|
+
temporal_seeds=temporal_seeds,
|
|
459
|
+
)
|
|
460
|
+
graph_time = time.time() - start
|
|
461
|
+
|
|
462
|
+
return _SemanticGraphResult(semantic, graph, semantic_time, graph_time)
|
|
463
|
+
|
|
464
|
+
async def run_bm25() -> _TimedResult:
|
|
465
|
+
"""Independent BM25 retrieval."""
|
|
454
466
|
start = time.time()
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
return
|
|
467
|
+
async with acquire_with_retry(pool) as conn:
|
|
468
|
+
results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
|
|
469
|
+
return _TimedResult(results, time.time() - start)
|
|
458
470
|
|
|
459
|
-
async def
|
|
471
|
+
async def run_temporal(tc_start, tc_end) -> _TimedResult:
|
|
472
|
+
"""Temporal retrieval (uses its own entry point finding)."""
|
|
473
|
+
start = time.time()
|
|
460
474
|
async with acquire_with_retry(pool) as conn:
|
|
461
|
-
|
|
475
|
+
results = await retrieve_temporal(
|
|
476
|
+
conn, query_embedding_str, bank_id, fact_type,
|
|
477
|
+
tc_start, tc_end, budget=thinking_budget, semantic_threshold=0.1
|
|
478
|
+
)
|
|
479
|
+
return _TimedResult(results, time.time() - start)
|
|
480
|
+
|
|
481
|
+
# Run parallel task chains
|
|
482
|
+
if temporal_constraint:
|
|
483
|
+
tc_start, tc_end = temporal_constraint
|
|
484
|
+
sg_result, bm25_result, temporal_result = await asyncio.gather(
|
|
485
|
+
run_semantic_then_graph(),
|
|
486
|
+
run_bm25(),
|
|
487
|
+
run_temporal(tc_start, tc_end),
|
|
488
|
+
)
|
|
489
|
+
return ParallelRetrievalResult(
|
|
490
|
+
semantic=sg_result.semantic,
|
|
491
|
+
bm25=bm25_result.results,
|
|
492
|
+
graph=sg_result.graph,
|
|
493
|
+
temporal=temporal_result.results,
|
|
494
|
+
timings={
|
|
495
|
+
"semantic": sg_result.semantic_time,
|
|
496
|
+
"graph": sg_result.graph_time,
|
|
497
|
+
"bm25": bm25_result.time,
|
|
498
|
+
"temporal": temporal_result.time,
|
|
499
|
+
},
|
|
500
|
+
temporal_constraint=temporal_constraint,
|
|
501
|
+
)
|
|
502
|
+
else:
|
|
503
|
+
sg_result, bm25_result = await asyncio.gather(
|
|
504
|
+
run_semantic_then_graph(),
|
|
505
|
+
run_bm25(),
|
|
506
|
+
)
|
|
507
|
+
return ParallelRetrievalResult(
|
|
508
|
+
semantic=sg_result.semantic,
|
|
509
|
+
bm25=bm25_result.results,
|
|
510
|
+
graph=sg_result.graph,
|
|
511
|
+
temporal=None,
|
|
512
|
+
timings={
|
|
513
|
+
"semantic": sg_result.semantic_time,
|
|
514
|
+
"graph": sg_result.graph_time,
|
|
515
|
+
"bm25": bm25_result.time,
|
|
516
|
+
},
|
|
517
|
+
temporal_constraint=None,
|
|
518
|
+
)
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
async def _get_temporal_entry_points(
|
|
522
|
+
conn,
|
|
523
|
+
query_embedding_str: str,
|
|
524
|
+
bank_id: str,
|
|
525
|
+
fact_type: str,
|
|
526
|
+
start_date: datetime,
|
|
527
|
+
end_date: datetime,
|
|
528
|
+
limit: int = 20,
|
|
529
|
+
semantic_threshold: float = 0.1,
|
|
530
|
+
) -> List[RetrievalResult]:
|
|
531
|
+
"""Get temporal entry points (facts in date range with semantic relevance)."""
|
|
532
|
+
from datetime import timezone
|
|
533
|
+
|
|
534
|
+
if start_date.tzinfo is None:
|
|
535
|
+
start_date = start_date.replace(tzinfo=timezone.utc)
|
|
536
|
+
if end_date.tzinfo is None:
|
|
537
|
+
end_date = end_date.replace(tzinfo=timezone.utc)
|
|
538
|
+
|
|
539
|
+
rows = await conn.fetch(
|
|
540
|
+
"""
|
|
541
|
+
SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at,
|
|
542
|
+
access_count, embedding, fact_type, document_id, chunk_id,
|
|
543
|
+
1 - (embedding <=> $1::vector) AS similarity
|
|
544
|
+
FROM memory_units
|
|
545
|
+
WHERE bank_id = $2
|
|
546
|
+
AND fact_type = $3
|
|
547
|
+
AND embedding IS NOT NULL
|
|
548
|
+
AND (
|
|
549
|
+
(occurred_start IS NOT NULL AND occurred_end IS NOT NULL
|
|
550
|
+
AND occurred_start <= $5 AND occurred_end >= $4)
|
|
551
|
+
OR (mentioned_at IS NOT NULL AND mentioned_at BETWEEN $4 AND $5)
|
|
552
|
+
OR (occurred_start IS NOT NULL AND occurred_start BETWEEN $4 AND $5)
|
|
553
|
+
OR (occurred_end IS NOT NULL AND occurred_end BETWEEN $4 AND $5)
|
|
554
|
+
)
|
|
555
|
+
AND (1 - (embedding <=> $1::vector)) >= $6
|
|
556
|
+
ORDER BY COALESCE(occurred_start, mentioned_at, occurred_end) DESC,
|
|
557
|
+
(embedding <=> $1::vector) ASC
|
|
558
|
+
LIMIT $7
|
|
559
|
+
""",
|
|
560
|
+
query_embedding_str, bank_id, fact_type, start_date, end_date, semantic_threshold, limit
|
|
561
|
+
)
|
|
562
|
+
|
|
563
|
+
results = []
|
|
564
|
+
total_days = max((end_date - start_date).total_seconds() / 86400, 1)
|
|
565
|
+
mid_date = start_date + (end_date - start_date) / 2
|
|
566
|
+
|
|
567
|
+
for row in rows:
|
|
568
|
+
result = RetrievalResult.from_db_row(dict(row))
|
|
569
|
+
|
|
570
|
+
# Calculate temporal proximity score
|
|
571
|
+
best_date = None
|
|
572
|
+
if row["occurred_start"] and row["occurred_end"]:
|
|
573
|
+
best_date = row["occurred_start"] + (row["occurred_end"] - row["occurred_start"]) / 2
|
|
574
|
+
elif row["occurred_start"]:
|
|
575
|
+
best_date = row["occurred_start"]
|
|
576
|
+
elif row["occurred_end"]:
|
|
577
|
+
best_date = row["occurred_end"]
|
|
578
|
+
elif row["mentioned_at"]:
|
|
579
|
+
best_date = row["mentioned_at"]
|
|
580
|
+
|
|
581
|
+
if best_date:
|
|
582
|
+
days_from_mid = abs((best_date - mid_date).total_seconds() / 86400)
|
|
583
|
+
result.temporal_proximity = 1.0 - min(days_from_mid / (total_days / 2), 1.0)
|
|
584
|
+
else:
|
|
585
|
+
result.temporal_proximity = 0.5
|
|
462
586
|
|
|
463
|
-
|
|
587
|
+
result.temporal_score = result.temporal_proximity
|
|
588
|
+
results.append(result)
|
|
589
|
+
|
|
590
|
+
return results
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
async def _retrieve_parallel_bfs(
|
|
594
|
+
pool,
|
|
595
|
+
query_text: str,
|
|
596
|
+
query_embedding_str: str,
|
|
597
|
+
bank_id: str,
|
|
598
|
+
fact_type: str,
|
|
599
|
+
thinking_budget: int,
|
|
600
|
+
temporal_constraint: Optional[tuple],
|
|
601
|
+
retriever: GraphRetriever,
|
|
602
|
+
) -> ParallelRetrievalResult:
|
|
603
|
+
"""BFS retrieval: all methods run in parallel (original behavior)."""
|
|
604
|
+
import time
|
|
605
|
+
|
|
606
|
+
async def run_semantic() -> _TimedResult:
|
|
607
|
+
start = time.time()
|
|
464
608
|
async with acquire_with_retry(pool) as conn:
|
|
465
|
-
|
|
609
|
+
results = await retrieve_semantic(conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget)
|
|
610
|
+
return _TimedResult(results, time.time() - start)
|
|
466
611
|
|
|
467
|
-
async def
|
|
612
|
+
async def run_bm25() -> _TimedResult:
|
|
613
|
+
start = time.time()
|
|
468
614
|
async with acquire_with_retry(pool) as conn:
|
|
469
|
-
|
|
615
|
+
results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
|
|
616
|
+
return _TimedResult(results, time.time() - start)
|
|
617
|
+
|
|
618
|
+
async def run_graph() -> _TimedResult:
|
|
619
|
+
start = time.time()
|
|
620
|
+
results = await retriever.retrieve(
|
|
621
|
+
pool=pool,
|
|
622
|
+
query_embedding_str=query_embedding_str,
|
|
623
|
+
bank_id=bank_id,
|
|
624
|
+
fact_type=fact_type,
|
|
625
|
+
budget=thinking_budget,
|
|
626
|
+
query_text=query_text,
|
|
627
|
+
)
|
|
628
|
+
return _TimedResult(results, time.time() - start)
|
|
470
629
|
|
|
471
|
-
async def run_temporal(
|
|
630
|
+
async def run_temporal(tc_start, tc_end) -> _TimedResult:
|
|
631
|
+
start = time.time()
|
|
472
632
|
async with acquire_with_retry(pool) as conn:
|
|
473
|
-
|
|
633
|
+
results = await retrieve_temporal(
|
|
474
634
|
conn, query_embedding_str, bank_id, fact_type,
|
|
475
|
-
|
|
635
|
+
tc_start, tc_end, budget=thinking_budget, semantic_threshold=0.1
|
|
476
636
|
)
|
|
637
|
+
return _TimedResult(results, time.time() - start)
|
|
477
638
|
|
|
478
|
-
# Run retrievals in parallel with timing
|
|
479
|
-
timings = {}
|
|
480
639
|
if temporal_constraint:
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
640
|
+
tc_start, tc_end = temporal_constraint
|
|
641
|
+
semantic_r, bm25_r, graph_r, temporal_r = await asyncio.gather(
|
|
642
|
+
run_semantic(),
|
|
643
|
+
run_bm25(),
|
|
644
|
+
run_graph(),
|
|
645
|
+
run_temporal(tc_start, tc_end),
|
|
646
|
+
)
|
|
647
|
+
return ParallelRetrievalResult(
|
|
648
|
+
semantic=semantic_r.results,
|
|
649
|
+
bm25=bm25_r.results,
|
|
650
|
+
graph=graph_r.results,
|
|
651
|
+
temporal=temporal_r.results,
|
|
652
|
+
timings={
|
|
653
|
+
"semantic": semantic_r.time,
|
|
654
|
+
"bm25": bm25_r.time,
|
|
655
|
+
"graph": graph_r.time,
|
|
656
|
+
"temporal": temporal_r.time,
|
|
657
|
+
},
|
|
658
|
+
temporal_constraint=temporal_constraint,
|
|
487
659
|
)
|
|
488
|
-
semantic_results, _, timings["semantic"] = results[0]
|
|
489
|
-
bm25_results, _, timings["bm25"] = results[1]
|
|
490
|
-
graph_results, _, timings["graph"] = results[2]
|
|
491
|
-
temporal_results, _, timings["temporal"] = results[3]
|
|
492
660
|
else:
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
661
|
+
semantic_r, bm25_r, graph_r = await asyncio.gather(
|
|
662
|
+
run_semantic(),
|
|
663
|
+
run_bm25(),
|
|
664
|
+
run_graph(),
|
|
665
|
+
)
|
|
666
|
+
return ParallelRetrievalResult(
|
|
667
|
+
semantic=semantic_r.results,
|
|
668
|
+
bm25=bm25_r.results,
|
|
669
|
+
graph=graph_r.results,
|
|
670
|
+
temporal=None,
|
|
671
|
+
timings={
|
|
672
|
+
"semantic": semantic_r.time,
|
|
673
|
+
"bm25": bm25_r.time,
|
|
674
|
+
"graph": graph_r.time,
|
|
675
|
+
},
|
|
676
|
+
temporal_constraint=None,
|
|
497
677
|
)
|
|
498
|
-
semantic_results, _, timings["semantic"] = results[0]
|
|
499
|
-
bm25_results, _, timings["bm25"] = results[1]
|
|
500
|
-
graph_results, _, timings["graph"] = results[2]
|
|
501
|
-
temporal_results = None
|
|
502
|
-
|
|
503
|
-
return semantic_results, bm25_results, graph_results, temporal_results, timings, temporal_constraint
|
|
@@ -108,6 +108,7 @@ class RetrievalResult(BaseModel):
|
|
|
108
108
|
class RetrievalMethodResults(BaseModel):
|
|
109
109
|
"""Results from a single retrieval method."""
|
|
110
110
|
method_name: Literal["semantic", "bm25", "graph", "temporal"] = Field(description="Name of retrieval method")
|
|
111
|
+
fact_type: Optional[str] = Field(default=None, description="Fact type this retrieval was for (world, experience, opinion)")
|
|
111
112
|
results: List[RetrievalResult] = Field(description="Retrieved results with ranks")
|
|
112
113
|
duration_seconds: float = Field(description="Time taken for this retrieval")
|
|
113
114
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Method-specific metadata")
|
|
@@ -289,7 +289,8 @@ class SearchTracer:
|
|
|
289
289
|
results: List[tuple], # List of (doc_id, data) tuples
|
|
290
290
|
duration_seconds: float,
|
|
291
291
|
score_field: str, # e.g., "similarity", "bm25_score"
|
|
292
|
-
metadata: Optional[Dict[str, Any]] = None
|
|
292
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
293
|
+
fact_type: Optional[str] = None
|
|
293
294
|
):
|
|
294
295
|
"""
|
|
295
296
|
Record results from a single retrieval method.
|
|
@@ -300,6 +301,7 @@ class SearchTracer:
|
|
|
300
301
|
duration_seconds: Time taken for this retrieval
|
|
301
302
|
score_field: Field name containing the score in data dict
|
|
302
303
|
metadata: Optional metadata about this retrieval method
|
|
304
|
+
fact_type: Fact type this retrieval was for (world, experience, opinion)
|
|
303
305
|
"""
|
|
304
306
|
retrieval_results = []
|
|
305
307
|
for rank, (doc_id, data) in enumerate(results, start=1):
|
|
@@ -313,7 +315,7 @@ class SearchTracer:
|
|
|
313
315
|
text=data.get("text", ""),
|
|
314
316
|
context=data.get("context", ""),
|
|
315
317
|
event_date=data.get("event_date"),
|
|
316
|
-
fact_type=data.get("fact_type"),
|
|
318
|
+
fact_type=data.get("fact_type") or fact_type,
|
|
317
319
|
score=score,
|
|
318
320
|
score_name=score_field,
|
|
319
321
|
)
|
|
@@ -322,6 +324,7 @@ class SearchTracer:
|
|
|
322
324
|
self.retrieval_results.append(
|
|
323
325
|
RetrievalMethodResults(
|
|
324
326
|
method_name=method_name,
|
|
327
|
+
fact_type=fact_type,
|
|
325
328
|
results=retrieval_results,
|
|
326
329
|
duration_seconds=duration_seconds,
|
|
327
330
|
metadata=metadata or {},
|
|
@@ -367,8 +370,10 @@ class SearchTracer:
|
|
|
367
370
|
rank_change = rrf_rank - rank # Positive = moved up
|
|
368
371
|
|
|
369
372
|
# Extract score components (only include non-None values)
|
|
373
|
+
# Keys from ScoredResult.to_dict(): cross_encoder_score, cross_encoder_score_normalized,
|
|
374
|
+
# rrf_normalized, temporal, recency, combined_score, weight
|
|
370
375
|
score_components = {}
|
|
371
|
-
for key in ["
|
|
376
|
+
for key in ["cross_encoder_score", "cross_encoder_score_normalized", "rrf_score", "rrf_normalized", "temporal", "recency", "combined_score"]:
|
|
372
377
|
if key in result and result[key] is not None:
|
|
373
378
|
score_components[key] = result[key]
|
|
374
379
|
|
|
@@ -31,8 +31,9 @@ class RetrievalResult:
|
|
|
31
31
|
embedding: Optional[List[float]] = None
|
|
32
32
|
|
|
33
33
|
# Retrieval-specific scores (only one will be set depending on retrieval method)
|
|
34
|
-
similarity: Optional[float] = None # Semantic
|
|
34
|
+
similarity: Optional[float] = None # Semantic retrieval
|
|
35
35
|
bm25_score: Optional[float] = None # BM25 retrieval
|
|
36
|
+
activation: Optional[float] = None # Graph retrieval (spreading activation)
|
|
36
37
|
temporal_score: Optional[float] = None # Temporal retrieval
|
|
37
38
|
temporal_proximity: Optional[float] = None # Temporal retrieval
|
|
38
39
|
|
|
@@ -54,6 +55,7 @@ class RetrievalResult:
|
|
|
54
55
|
embedding=row.get("embedding"),
|
|
55
56
|
similarity=row.get("similarity"),
|
|
56
57
|
bm25_score=row.get("bm25_score"),
|
|
58
|
+
activation=row.get("activation"),
|
|
57
59
|
temporal_score=row.get("temporal_score"),
|
|
58
60
|
temporal_proximity=row.get("temporal_proximity"),
|
|
59
61
|
)
|
|
@@ -152,6 +154,7 @@ class ScoredResult:
|
|
|
152
154
|
result["cross_encoder_score"] = self.cross_encoder_score
|
|
153
155
|
result["cross_encoder_score_normalized"] = self.cross_encoder_score_normalized
|
|
154
156
|
result["rrf_normalized"] = self.rrf_normalized
|
|
157
|
+
result["temporal"] = self.temporal
|
|
155
158
|
result["recency"] = self.recency
|
|
156
159
|
result["combined_score"] = self.combined_score
|
|
157
160
|
result["weight"] = self.weight
|