hindsight-api 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +38 -0
- hindsight_api/api/__init__.py +105 -0
- hindsight_api/api/http.py +1872 -0
- hindsight_api/api/mcp.py +157 -0
- hindsight_api/engine/__init__.py +47 -0
- hindsight_api/engine/cross_encoder.py +97 -0
- hindsight_api/engine/db_utils.py +93 -0
- hindsight_api/engine/embeddings.py +113 -0
- hindsight_api/engine/entity_resolver.py +575 -0
- hindsight_api/engine/llm_wrapper.py +269 -0
- hindsight_api/engine/memory_engine.py +3095 -0
- hindsight_api/engine/query_analyzer.py +519 -0
- hindsight_api/engine/response_models.py +222 -0
- hindsight_api/engine/retain/__init__.py +50 -0
- hindsight_api/engine/retain/bank_utils.py +423 -0
- hindsight_api/engine/retain/chunk_storage.py +82 -0
- hindsight_api/engine/retain/deduplication.py +104 -0
- hindsight_api/engine/retain/embedding_processing.py +62 -0
- hindsight_api/engine/retain/embedding_utils.py +54 -0
- hindsight_api/engine/retain/entity_processing.py +90 -0
- hindsight_api/engine/retain/fact_extraction.py +1027 -0
- hindsight_api/engine/retain/fact_storage.py +176 -0
- hindsight_api/engine/retain/link_creation.py +121 -0
- hindsight_api/engine/retain/link_utils.py +651 -0
- hindsight_api/engine/retain/orchestrator.py +405 -0
- hindsight_api/engine/retain/types.py +206 -0
- hindsight_api/engine/search/__init__.py +15 -0
- hindsight_api/engine/search/fusion.py +122 -0
- hindsight_api/engine/search/observation_utils.py +132 -0
- hindsight_api/engine/search/reranking.py +103 -0
- hindsight_api/engine/search/retrieval.py +503 -0
- hindsight_api/engine/search/scoring.py +161 -0
- hindsight_api/engine/search/temporal_extraction.py +64 -0
- hindsight_api/engine/search/think_utils.py +255 -0
- hindsight_api/engine/search/trace.py +215 -0
- hindsight_api/engine/search/tracer.py +447 -0
- hindsight_api/engine/search/types.py +160 -0
- hindsight_api/engine/task_backend.py +223 -0
- hindsight_api/engine/utils.py +203 -0
- hindsight_api/metrics.py +227 -0
- hindsight_api/migrations.py +163 -0
- hindsight_api/models.py +309 -0
- hindsight_api/pg0.py +425 -0
- hindsight_api/web/__init__.py +12 -0
- hindsight_api/web/server.py +143 -0
- hindsight_api-0.0.13.dist-info/METADATA +41 -0
- hindsight_api-0.0.13.dist-info/RECORD +48 -0
- hindsight_api-0.0.13.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Search tracer for collecting detailed search execution traces.
|
|
3
|
+
|
|
4
|
+
The SearchTracer collects comprehensive information about each step
|
|
5
|
+
of the search process (per-method retrieval, RRF merging, reranking, and spreading-activation graph traversal) for debugging and visualization.
|
|
6
|
+
"""
|
|
7
|
+
import time
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from typing import List, Optional, Dict, Any, Literal
|
|
10
|
+
|
|
11
|
+
from .trace import (
|
|
12
|
+
SearchTrace,
|
|
13
|
+
QueryInfo,
|
|
14
|
+
EntryPoint,
|
|
15
|
+
NodeVisit,
|
|
16
|
+
WeightComponents,
|
|
17
|
+
LinkInfo,
|
|
18
|
+
PruningDecision,
|
|
19
|
+
SearchSummary,
|
|
20
|
+
SearchPhaseMetrics,
|
|
21
|
+
RetrievalResult,
|
|
22
|
+
RetrievalMethodResults,
|
|
23
|
+
RRFMergeResult,
|
|
24
|
+
RerankedResult,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SearchTracer:
    """
    Tracer for collecting detailed search execution information.

    Accumulates per-method retrieval results, the RRF merge, reranking output,
    and graph-traversal details (entry points, node visits, neighbor links,
    pruning decisions, phase timings), then assembles everything into a
    ``SearchTrace`` in :meth:`finalize`.

    Usage:
        tracer = SearchTracer(query="Who is Alice?", budget=50, max_tokens=4096)
        tracer.start()

        # During search...
        tracer.record_query_embedding(embedding)
        tracer.add_entry_point(node_id, text, similarity, rank)
        tracer.visit_node(...)
        tracer.prune_node(...)

        # After search...
        trace = tracer.finalize(final_results)
        json_output = trace.to_json()
    """

    def __init__(self, query: str, budget: int, max_tokens: int):
        """
        Initialize tracer.

        Args:
            query: Search query text
            budget: Maximum nodes to explore
            max_tokens: Maximum tokens to return in results
        """
        self.query_text = query
        self.budget = budget
        self.max_tokens = max_tokens

        # Trace data
        self.query_embedding: Optional[List[float]] = None
        # Wall-clock start (epoch seconds), kept for callers that read it.
        self.start_time: Optional[float] = None
        # Monotonic start used to measure duration; unaffected by clock jumps.
        self._start_perf: Optional[float] = None
        self.entry_points: List[EntryPoint] = []
        self.visits: List[NodeVisit] = []
        self.pruned: List[PruningDecision] = []
        self.phase_metrics: List[SearchPhaseMetrics] = []

        # New 4-way retrieval tracking
        self.retrieval_results: List[RetrievalMethodResults] = []
        self.rrf_merged: List[RRFMergeResult] = []
        self.reranked: List[RerankedResult] = []

        # Tracking state
        self.current_step = 0
        self.nodes_visited_set = set()  # For quick lookups
        # First NodeVisit recorded per node_id. Lookups by node id are O(1)
        # instead of a scan of self.visits, with the same first-match result.
        self._first_visit_by_node: Dict[str, NodeVisit] = {}

        # Link statistics
        self.temporal_links_followed = 0
        self.semantic_links_followed = 0
        self.entity_links_followed = 0

    @staticmethod
    def _clamp_unit(value: float) -> float:
        """Clamp ``value`` into [0.0, 1.0] to absorb float noise (e.g. 1.0000005)."""
        return min(1.0, max(0.0, value))

    def start(self):
        """Start timing the search."""
        self.start_time = time.time()
        self._start_perf = time.perf_counter()

    def record_query_embedding(self, embedding: List[float]):
        """Record the query embedding."""
        self.query_embedding = embedding

    def add_entry_point(self, node_id: str, text: str, similarity: float, rank: int):
        """
        Record an entry point.

        Args:
            node_id: Memory unit ID
            text: Memory unit text
            similarity: Cosine similarity to query (clamped to [0.0, 1.0])
            rank: Rank among entry points (1-based)
        """
        # Clamp similarity to [0.0, 1.0] to handle floating-point precision
        similarity = self._clamp_unit(similarity)

        self.entry_points.append(
            EntryPoint(
                node_id=node_id,
                text=text,
                similarity_score=similarity,
                rank=rank,
            )
        )

    def visit_node(
        self,
        node_id: str,
        text: str,
        context: str,
        event_date: datetime,
        access_count: int,
        is_entry_point: bool,
        parent_node_id: Optional[str],
        link_type: Optional[Literal["temporal", "semantic", "entity"]],
        link_weight: Optional[float],
        activation: float,
        semantic_similarity: float,
        recency: float,
        frequency: float,
        final_weight: float,
    ):
        """
        Record visiting a node.

        Each call advances ``current_step`` and appends a new ``NodeVisit``.
        Duplicate node ids are not rejected.

        Args:
            node_id: Memory unit ID
            text: Memory unit text
            context: Memory unit context
            event_date: When the memory occurred
            access_count: Access count before this search
            is_entry_point: Whether this is an entry point
            parent_node_id: Node that led here (None for entry points)
            link_type: Type of link from parent
            link_weight: Weight of link from parent
            activation: Activation score (not clamped)
            semantic_similarity: Semantic similarity to query (clamped to [0, 1])
            recency: Recency weight (clamped to [0, 1])
            frequency: Frequency weight (clamped to [0, 1])
            final_weight: Combined final weight
        """
        self.current_step += 1
        self.nodes_visited_set.add(node_id)

        # Clamp values to handle floating-point precision issues
        # (sometimes normalization produces values like 1.0000005 instead of 1.0)
        semantic_similarity = self._clamp_unit(semantic_similarity)
        recency = self._clamp_unit(recency)
        frequency = self._clamp_unit(frequency)

        # Per-component contributions for transparency. The fixed factors
        # 0.3/0.3/0.25/0.15 are applied only here, for the trace.
        # NOTE(review): confirm they match the weights the scorer actually uses.
        weights = WeightComponents(
            activation=activation,
            semantic_similarity=semantic_similarity,
            recency=recency,
            frequency=frequency,
            final_weight=final_weight,
            activation_contribution=0.3 * activation,
            semantic_contribution=0.3 * semantic_similarity,
            recency_contribution=0.25 * recency,
            frequency_contribution=0.15 * frequency,
        )

        visit = NodeVisit(
            step=self.current_step,
            node_id=node_id,
            text=text,
            context=context,
            event_date=event_date,
            access_count=access_count,
            is_entry_point=is_entry_point,
            parent_node_id=parent_node_id,
            link_type=link_type,
            link_weight=link_weight,
            weights=weights,
            neighbors_explored=[],
            final_rank=None,  # Will be set later
        )

        self.visits.append(visit)
        # Only the first visit of a node is indexed, so lookups keep the
        # first-match semantics of a linear scan over self.visits.
        self._first_visit_by_node.setdefault(node_id, visit)

        # Track link statistics; a link_type of None is not counted.
        if link_type == "temporal":
            self.temporal_links_followed += 1
        elif link_type == "semantic":
            self.semantic_links_followed += 1
        elif link_type == "entity":
            self.entity_links_followed += 1

    def add_neighbor_link(
        self,
        from_node_id: str,
        to_node_id: str,
        link_type: Literal["temporal", "semantic", "entity"],
        link_weight: float,
        entity_id: Optional[str],
        new_activation: Optional[float],
        followed: bool,
        prune_reason: Optional[str] = None,
        is_supplementary: bool = False,
    ):
        """
        Record a link to a neighbor (whether followed or not).

        The link is attached to the first recorded visit of ``from_node_id``.
        If that node was never visited, the call is a silent no-op.

        Args:
            from_node_id: Source node
            to_node_id: Target node
            link_type: Type of link
            link_weight: Weight of link
            entity_id: Entity ID if link is entity-based
            new_activation: Activation passed to neighbor (None for supplementary links)
            followed: Whether link was followed
            prune_reason: Why link was not followed (if not followed)
            is_supplementary: Whether this is a supplementary link (multiple connections)
        """
        visit = self._first_visit_by_node.get(from_node_id)
        if visit is None:
            # Node not found, skip
            return

        visit.neighbors_explored.append(
            LinkInfo(
                to_node_id=to_node_id,
                link_type=link_type,
                link_weight=link_weight,
                entity_id=entity_id,
                new_activation=new_activation,
                followed=followed,
                prune_reason=prune_reason,
                is_supplementary=is_supplementary,
            )
        )

    def prune_node(
        self,
        node_id: str,
        reason: Literal["already_visited", "activation_too_low", "budget_exhausted"],
        activation: float,
    ):
        """
        Record a node being pruned (not visited).

        Args:
            node_id: Node that was pruned
            reason: Why it was pruned
            activation: Activation value when pruned
        """
        self.pruned.append(
            PruningDecision(
                node_id=node_id,
                reason=reason,
                activation=activation,
                # The step this node would have taken had it been visited next.
                would_have_been_step=self.current_step + 1,
            )
        )

    def add_phase_metric(self, phase_name: str, duration_seconds: float, details: Optional[Dict[str, Any]] = None):
        """
        Record metrics for a search phase.

        Args:
            phase_name: Name of the phase
            duration_seconds: Time taken
            details: Additional phase-specific details (defaults to an empty dict)
        """
        self.phase_metrics.append(
            SearchPhaseMetrics(
                phase_name=phase_name,
                duration_seconds=duration_seconds,
                details=details or {},
            )
        )

    def add_retrieval_results(
        self,
        method_name: Literal["semantic", "bm25", "graph", "temporal"],
        results: List[tuple],  # List of (doc_id, data) tuples
        duration_seconds: float,
        score_field: str,  # e.g., "similarity", "bm25_score"
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Record results from a single retrieval method.

        Results are ranked 1..N in the order given. A missing or None score is
        recorded as 0.0.

        Args:
            method_name: Name of the retrieval method
            results: List of (doc_id, data) tuples from retrieval
            duration_seconds: Time taken for this retrieval
            score_field: Field name containing the score in data dict
            metadata: Optional metadata about this retrieval method
        """
        retrieval_results = []
        for rank, (doc_id, data) in enumerate(results, start=1):
            score = data.get(score_field)
            if score is None:
                score = 0.0
            retrieval_results.append(
                RetrievalResult(
                    rank=rank,
                    node_id=doc_id,
                    text=data.get("text", ""),
                    context=data.get("context", ""),
                    event_date=data.get("event_date"),
                    fact_type=data.get("fact_type"),
                    score=score,
                    score_name=score_field,
                )
            )

        self.retrieval_results.append(
            RetrievalMethodResults(
                method_name=method_name,
                results=retrieval_results,
                duration_seconds=duration_seconds,
                metadata=metadata or {},
            )
        )

    def add_rrf_merged(self, merged_results: List[tuple]):
        """
        Record RRF merged results, replacing any previously recorded merge.

        Args:
            merged_results: List of (doc_id, data, rrf_meta) tuples from RRF merge,
                in final RRF order (ranks are assigned 1..N from list position)
        """
        self.rrf_merged = []
        for rank, (doc_id, data, rrf_meta) in enumerate(merged_results, start=1):
            self.rrf_merged.append(
                RRFMergeResult(
                    node_id=doc_id,
                    text=data.get("text", ""),
                    rrf_score=rrf_meta.get("rrf_score", 0.0),
                    source_ranks=rrf_meta.get("source_ranks", {}),
                    final_rrf_rank=rank,
                )
            )

    def add_reranked(self, reranked_results: List[Dict[str, Any]], rrf_merged: List):
        """
        Record reranked results, replacing any previously recorded reranking.

        Args:
            reranked_results: List of result dicts after reranking (each must have "id")
            rrf_merged: Original RRF merged results for comparison. Only its
                length is used, as the fallback RRF rank (len + 1) for nodes
                missing from the recorded merge.
        """
        # RRF ranks come from the merge recorded by add_rrf_merged().
        # NOTE(review): the fallback rank uses the length of the *argument*
        # rrf_merged. Both are expected to describe the same list; confirm
        # with callers.
        rrf_rank_map = {item.node_id: item.final_rrf_rank for item in self.rrf_merged}
        missing_rank = len(rrf_merged) + 1

        self.reranked = []
        for rank, result in enumerate(reranked_results, start=1):
            node_id = result["id"]
            rrf_rank = rrf_rank_map.get(node_id, missing_rank)
            rank_change = rrf_rank - rank  # Positive = moved up

            # Extract score components (only include non-None values)
            score_components = {}
            for key in ["semantic_similarity", "bm25_score", "rrf_score", "recency_normalized", "frequency_normalized"]:
                if key in result and result[key] is not None:
                    score_components[key] = result[key]

            self.reranked.append(
                RerankedResult(
                    node_id=node_id,
                    text=result.get("text", ""),
                    rerank_score=result.get("weight", 0.0),
                    rerank_rank=rank,
                    rrf_rank=rrf_rank,
                    rank_change=rank_change,
                    score_components=score_components,
                )
            )

    def finalize(self, final_results: List[Dict[str, Any]]) -> SearchTrace:
        """
        Finalize the trace and return the complete SearchTrace object.

        Sets ``final_rank`` (1-based) on the first visit of every node that
        appears in ``final_results``.

        Args:
            final_results: Final ranked results returned to user (each must have "id")

        Returns:
            Complete SearchTrace object

        Raises:
            ValueError: If start() was never called.
        """
        if self.start_time is None:
            raise ValueError("Tracer not started - call start() first")

        if self._start_perf is not None:
            total_duration = time.perf_counter() - self._start_perf
        else:
            # start_time was assigned directly rather than via start();
            # fall back to wall-clock arithmetic.
            total_duration = time.time() - self.start_time

        # Set final ranks on visits based on results
        for rank, result in enumerate(final_results, 1):
            visit = self._first_visit_by_node.get(result["id"])
            if visit is not None:
                visit.final_rank = rank

        # Create query info
        query_info = QueryInfo(
            query_text=self.query_text,
            query_embedding=self.query_embedding or [],
            timestamp=datetime.now(timezone.utc),
            budget=self.budget,
            max_tokens=self.max_tokens,
        )

        # Create summary
        summary = SearchSummary(
            total_nodes_visited=len(self.visits),
            total_nodes_pruned=len(self.pruned),
            entry_points_found=len(self.entry_points),
            budget_used=len(self.visits),
            budget_remaining=self.budget - len(self.visits),
            total_duration_seconds=total_duration,
            results_returned=len(final_results),
            temporal_links_followed=self.temporal_links_followed,
            semantic_links_followed=self.semantic_links_followed,
            entity_links_followed=self.entity_links_followed,
            phase_metrics=self.phase_metrics,
        )

        # Create complete trace
        return SearchTrace(
            query=query_info,
            retrieval_results=self.retrieval_results,
            rrf_merged=self.rrf_merged,
            reranked=self.reranked,
            entry_points=self.entry_points,
            visits=self.visits,
            pruned=self.pruned,
            summary=summary,
            final_results=final_results,
        )
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Type definitions for the recall pipeline.
|
|
3
|
+
|
|
4
|
+
These dataclasses replace Dict[str, Any] types throughout the recall pipeline,
|
|
5
|
+
providing type safety and making data flow explicit.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Optional, List, Dict, Any
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class RetrievalResult:
    """
    Result from a single retrieval method (semantic, BM25, graph, or temporal).

    This represents a raw result from the database query, before merging or reranking.
    """
    id: str
    text: str
    fact_type: str
    context: Optional[str] = None
    event_date: Optional[datetime] = None
    occurred_start: Optional[datetime] = None
    occurred_end: Optional[datetime] = None
    mentioned_at: Optional[datetime] = None
    document_id: Optional[str] = None
    chunk_id: Optional[str] = None
    access_count: int = 0
    embedding: Optional[List[float]] = None

    # Retrieval-specific scores (only one will be set depending on retrieval method)
    similarity: Optional[float] = None  # Semantic/graph retrieval
    bm25_score: Optional[float] = None  # BM25 retrieval
    temporal_score: Optional[float] = None  # Temporal retrieval
    temporal_proximity: Optional[float] = None  # Temporal retrieval

    @classmethod
    def from_db_row(cls, row: Dict[str, Any]) -> "RetrievalResult":
        """
        Create from a database row (asyncpg Record converted to dict).

        ``id``, ``text`` and ``fact_type`` are required (a missing key raises
        ``KeyError``); ``id`` is converted to ``str``. Every other key is
        optional and defaults to None, except ``access_count``, which defaults
        to 0 when the key is missing or its value is None.
        """
        access_count = row.get("access_count")
        return cls(
            id=str(row["id"]),
            text=row["text"],
            fact_type=row["fact_type"],
            context=row.get("context"),
            event_date=row.get("event_date"),
            occurred_start=row.get("occurred_start"),
            occurred_end=row.get("occurred_end"),
            mentioned_at=row.get("mentioned_at"),
            document_id=row.get("document_id"),
            chunk_id=row.get("chunk_id"),
            # A NULL column comes through as None; keep the int field an int.
            access_count=0 if access_count is None else access_count,
            embedding=row.get("embedding"),
            similarity=row.get("similarity"),
            bm25_score=row.get("bm25_score"),
            temporal_score=row.get("temporal_score"),
            temporal_proximity=row.get("temporal_proximity"),
        )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass
class MergedCandidate:
    """
    A retrieval hit after Reciprocal Rank Fusion across retrieval methods.

    Wraps the underlying RetrievalResult together with its fused RRF score,
    its position in the fused list, and the rank it held in each individual
    retrieval method.
    """

    # The wrapped retrieval hit
    retrieval: RetrievalResult

    # Fusion metadata
    rrf_score: float
    rrf_rank: int = field(default=0)
    # retrieval method name -> rank within that method's results
    source_ranks: Dict[str, int] = field(default_factory=dict)

    @property
    def id(self) -> str:
        """ID of the wrapped retrieval hit."""
        wrapped = self.retrieval
        return wrapped.id
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class ScoredResult:
    """
    A fused candidate after reranking and final scoring.

    Holds the MergedCandidate it was derived from, the cross-encoder
    scores, normalized component scores, and the final combined score and
    ranking weight.
    """

    # The fused candidate this result was scored from
    candidate: MergedCandidate

    # Cross-encoder reranking output (raw and normalized)
    cross_encoder_score: float = 0.0
    cross_encoder_score_normalized: float = 0.0

    # Component scores after normalization
    rrf_normalized: float = 0.0
    recency: float = 0.5
    temporal: float = 0.5

    # Final scores; `weight` is the value results are ranked by
    combined_score: float = 0.0
    weight: float = 0.0

    @property
    def id(self) -> str:
        """ID of the underlying candidate."""
        merged = self.candidate
        return merged.id

    @property
    def retrieval(self) -> RetrievalResult:
        """The raw retrieval hit beneath the merged candidate."""
        merged = self.candidate
        return merged.retrieval

    def to_dict(self) -> Dict[str, Any]:
        """
        Flatten this result into the legacy dict representation.

        Keys are added in this order:
        1. The retrieval fields, with ``similarity`` exposed as ``semantic_similarity``.
        2. ``temporal_score`` and ``temporal_proximity``, only when they are not None.
        3. ``rrf_score`` and ``rrf_rank``, then each entry of ``source_ranks`` spread in directly.
        4. The reranking and final-score fields.

        ``activation`` repeats ``weight`` for legacy consumers. The ``temporal``
        field is not included.
        """
        hit = self.retrieval
        merged = self.candidate

        flat = {
            "id": hit.id,
            "text": hit.text,
            "fact_type": hit.fact_type,
            "context": hit.context,
            "event_date": hit.event_date,
            "occurred_start": hit.occurred_start,
            "occurred_end": hit.occurred_end,
            "mentioned_at": hit.mentioned_at,
            "document_id": hit.document_id,
            "chunk_id": hit.chunk_id,
            "access_count": hit.access_count,
            "embedding": hit.embedding,
            "semantic_similarity": hit.similarity,
            "bm25_score": hit.bm25_score,
        }

        optional_temporal = (
            ("temporal_score", hit.temporal_score),
            ("temporal_proximity", hit.temporal_proximity),
        )
        for key, value in optional_temporal:
            if value is not None:
                flat[key] = value

        flat["rrf_score"] = merged.rrf_score
        flat["rrf_rank"] = merged.rrf_rank
        # Per-method ranks land at the top level, keyed by method name.
        flat.update(merged.source_ranks)

        flat.update(
            {
                "cross_encoder_score": self.cross_encoder_score,
                "cross_encoder_score_normalized": self.cross_encoder_score_normalized,
                "rrf_normalized": self.rrf_normalized,
                "recency": self.recency,
                "combined_score": self.combined_score,
                "weight": self.weight,
                "activation": self.weight,  # legacy alias of weight
            }
        )
        return flat
|