hindsight-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -9
- hindsight_api/alembic/env.py +5 -8
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
- hindsight_api/api/__init__.py +10 -10
- hindsight_api/api/http.py +575 -593
- hindsight_api/api/mcp.py +31 -33
- hindsight_api/banner.py +13 -6
- hindsight_api/config.py +17 -12
- hindsight_api/engine/__init__.py +9 -9
- hindsight_api/engine/cross_encoder.py +23 -27
- hindsight_api/engine/db_utils.py +5 -4
- hindsight_api/engine/embeddings.py +22 -21
- hindsight_api/engine/entity_resolver.py +81 -75
- hindsight_api/engine/llm_wrapper.py +74 -88
- hindsight_api/engine/memory_engine.py +663 -673
- hindsight_api/engine/query_analyzer.py +100 -97
- hindsight_api/engine/response_models.py +105 -106
- hindsight_api/engine/retain/__init__.py +9 -16
- hindsight_api/engine/retain/bank_utils.py +34 -58
- hindsight_api/engine/retain/chunk_storage.py +4 -12
- hindsight_api/engine/retain/deduplication.py +9 -28
- hindsight_api/engine/retain/embedding_processing.py +4 -11
- hindsight_api/engine/retain/embedding_utils.py +3 -4
- hindsight_api/engine/retain/entity_processing.py +7 -17
- hindsight_api/engine/retain/fact_extraction.py +155 -165
- hindsight_api/engine/retain/fact_storage.py +11 -23
- hindsight_api/engine/retain/link_creation.py +11 -39
- hindsight_api/engine/retain/link_utils.py +166 -95
- hindsight_api/engine/retain/observation_regeneration.py +39 -52
- hindsight_api/engine/retain/orchestrator.py +72 -62
- hindsight_api/engine/retain/types.py +49 -43
- hindsight_api/engine/search/__init__.py +15 -1
- hindsight_api/engine/search/fusion.py +6 -15
- hindsight_api/engine/search/graph_retrieval.py +234 -0
- hindsight_api/engine/search/mpfp_retrieval.py +438 -0
- hindsight_api/engine/search/observation_utils.py +9 -16
- hindsight_api/engine/search/reranking.py +4 -7
- hindsight_api/engine/search/retrieval.py +388 -193
- hindsight_api/engine/search/scoring.py +5 -7
- hindsight_api/engine/search/temporal_extraction.py +8 -11
- hindsight_api/engine/search/think_utils.py +115 -39
- hindsight_api/engine/search/trace.py +68 -38
- hindsight_api/engine/search/tracer.py +49 -35
- hindsight_api/engine/search/types.py +22 -16
- hindsight_api/engine/task_backend.py +21 -26
- hindsight_api/engine/utils.py +25 -10
- hindsight_api/main.py +21 -40
- hindsight_api/mcp_local.py +190 -0
- hindsight_api/metrics.py +44 -30
- hindsight_api/migrations.py +10 -8
- hindsight_api/models.py +60 -72
- hindsight_api/pg0.py +64 -337
- hindsight_api/server.py +3 -6
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +6 -5
- hindsight_api-0.1.6.dist-info/RECORD +64 -0
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.1.4.dist-info/RECORD +0 -61
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
--- a/hindsight_api/engine/search/retrieval.py
+++ b/hindsight_api/engine/search/retrieval.py
@@ -4,24 +4,68 @@ Retrieval module for 4-way parallel search.
 Implements:
 1. Semantic retrieval (vector similarity)
 2. BM25 retrieval (keyword/full-text search)
-3. Graph retrieval (spreading activation)
+3. Graph retrieval (via pluggable GraphRetriever interface)
 4. Temporal retrieval (time-aware search with spreading)
 """
 
-from typing import List, Dict, Any, Tuple, Optional
-from datetime import datetime
 import asyncio
+import logging
+from dataclasses import dataclass, field
+from datetime import UTC, datetime
+from typing import Optional
+
+from ...config import get_config
 from ..db_utils import acquire_with_retry
+from .graph_retrieval import BFSGraphRetriever, GraphRetriever
+from .mpfp_retrieval import MPFPGraphRetriever
 from .types import RetrievalResult
 
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ParallelRetrievalResult:
+    """Result from parallel retrieval across all methods."""
+
+    semantic: list[RetrievalResult]
+    bm25: list[RetrievalResult]
+    graph: list[RetrievalResult]
+    temporal: list[RetrievalResult] | None
+    timings: dict[str, float] = field(default_factory=dict)
+    temporal_constraint: tuple | None = None  # (start_date, end_date)
+
+
+# Default graph retriever instance (can be overridden)
+_default_graph_retriever: GraphRetriever | None = None
+
+
+def get_default_graph_retriever() -> GraphRetriever:
+    """Get or create the default graph retriever based on config."""
+    global _default_graph_retriever
+    if _default_graph_retriever is None:
+        config = get_config()
+        retriever_type = config.graph_retriever.lower()
+        if retriever_type == "mpfp":
+            _default_graph_retriever = MPFPGraphRetriever()
+            logger.info("Using MPFP graph retriever")
+        elif retriever_type == "bfs":
+            _default_graph_retriever = BFSGraphRetriever()
+            logger.info("Using BFS graph retriever")
+        else:
+            logger.warning(f"Unknown graph retriever '{retriever_type}', falling back to MPFP")
+            _default_graph_retriever = MPFPGraphRetriever()
+    return _default_graph_retriever
+
+
+def set_default_graph_retriever(retriever: GraphRetriever) -> None:
+    """Set the default graph retriever (for configuration/testing)."""
+    global _default_graph_retriever
+    _default_graph_retriever = retriever
+
 
 async def retrieve_semantic(
-    conn,
-    query_emb_str: str,
-    bank_id: str,
-    fact_type: str,
-    limit: int
-) -> List[RetrievalResult]:
+    conn, query_emb_str: str, bank_id: str, fact_type: str, limit: int
+) -> list[RetrievalResult]:
     """
     Semantic retrieval via vector similarity.
 
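Note on the new seam introduced above: judging from the call sites in this file, a GraphRetriever exposes a name attribute and an async retrieve() method taking pool, query_embedding_str, bank_id, fact_type, budget, query_text, and, on the MPFP path, semantic_seeds/temporal_seeds, and returning a list of RetrievalResult. A minimal sketch of swapping in a custom strategy for tests, assuming duck typing suffices (the real base class lives in the new graph_retrieval.py):

    from hindsight_api.engine.search.retrieval import set_default_graph_retriever

    class NullGraphRetriever:
        # Hypothetical stand-in that skips graph traversal entirely.
        name = "null"

        async def retrieve(self, *, pool, query_embedding_str, bank_id, fact_type,
                           budget, query_text, semantic_seeds=None, temporal_seeds=None):
            return []  # fusion then relies on semantic/BM25/temporal only

    set_default_graph_retriever(NullGraphRetriever())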
@@ -47,18 +91,15 @@ async def retrieve_semantic(
         ORDER BY embedding <=> $1::vector
         LIMIT $4
         """,
-        query_emb_str, bank_id, fact_type, limit
+        query_emb_str,
+        bank_id,
+        fact_type,
+        limit,
     )
     return [RetrievalResult.from_db_row(dict(r)) for r in results]
 
 
-async def retrieve_bm25(
-    conn,
-    query_text: str,
-    bank_id: str,
-    fact_type: str,
-    limit: int
-) -> List[RetrievalResult]:
+async def retrieve_bm25(conn, query_text: str, bank_id: str, fact_type: str, limit: int) -> list[RetrievalResult]:
     """
     BM25 keyword retrieval via full-text search.
 
@@ -76,7 +117,7 @@ async def retrieve_bm25(
 
     # Sanitize query text: remove special characters that have meaning in tsquery
     # Keep only alphanumeric characters and spaces
-    sanitized_text = re.sub(r'[^\w\s]', ' ', query_text.lower())
+    sanitized_text = re.sub(r"[^\w\s]", " ", query_text.lower())
 
     # Split and filter empty strings
     tokens = [token for token in sanitized_text.split() if token]
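The sanitizer change above is quoting style only; behavior is unchanged. For reference, it strips tsquery metacharacters before tokenizing:

    import re

    sanitized = re.sub(r"[^\w\s]", " ", "don't AND (break)".lower())
    tokens = [t for t in sanitized.split() if t]
    # -> ['don', 't', 'and', 'break']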
@@ -100,126 +141,14 @@ async def retrieve_bm25(
         ORDER BY bm25_score DESC
         LIMIT $4
         """,
-        query_tsquery, bank_id, fact_type, limit
+        query_tsquery,
+        bank_id,
+        fact_type,
+        limit,
     )
     return [RetrievalResult.from_db_row(dict(r)) for r in results]
 
 
-async def retrieve_graph(
-    conn,
-    query_emb_str: str,
-    bank_id: str,
-    fact_type: str,
-    budget: int
-) -> List[RetrievalResult]:
-    """
-    Graph retrieval via spreading activation.
-
-    Args:
-        conn: Database connection
-        query_emb_str: Query embedding as string
-        agent_id: bank ID
-        fact_type: Fact type to filter
-        budget: Node budget for graph traversal
-
-    Returns:
-        List of RetrievalResult objects
-    """
-    # Find entry points
-    entry_points = await conn.fetch(
-        """
-        SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
-               1 - (embedding <=> $1::vector) AS similarity
-        FROM memory_units
-        WHERE bank_id = $2
-          AND embedding IS NOT NULL
-          AND fact_type = $3
-          AND (1 - (embedding <=> $1::vector)) >= 0.5
-        ORDER BY embedding <=> $1::vector
-        LIMIT 5
-        """,
-        query_emb_str, bank_id, fact_type
-    )
-
-    if not entry_points:
-        return []
-
-    # BFS-style spreading activation with batched neighbor fetching
-    visited = set()
-    results = []
-    queue = [(RetrievalResult.from_db_row(dict(r)), r["similarity"]) for r in entry_points]
-    budget_remaining = budget
-
-    # Process nodes in batches to reduce DB roundtrips
-    batch_size = 20  # Fetch neighbors for up to 20 nodes at once
-
-    while queue and budget_remaining > 0:
-        # Collect a batch of nodes to process
-        batch_nodes = []
-        batch_activations = {}
-
-        while queue and len(batch_nodes) < batch_size and budget_remaining > 0:
-            current, activation = queue.pop(0)
-            unit_id = current.id
-
-            if unit_id not in visited:
-                visited.add(unit_id)
-                budget_remaining -= 1
-                results.append(current)
-                batch_nodes.append(current.id)
-                batch_activations[unit_id] = activation
-
-        # Batch fetch neighbors for all nodes in this batch
-        # Fetch top weighted neighbors (batch_size * 20 = ~400 for good distribution)
-        if batch_nodes and budget_remaining > 0:
-            max_neighbors = len(batch_nodes) * 20
-            neighbors = await conn.fetch(
-                """
-                SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end, mu.mentioned_at,
-                       mu.access_count, mu.embedding, mu.fact_type, mu.document_id, mu.chunk_id,
-                       ml.weight, ml.link_type, ml.from_unit_id
-                FROM memory_links ml
-                JOIN memory_units mu ON ml.to_unit_id = mu.id
-                WHERE ml.from_unit_id = ANY($1::uuid[])
-                  AND ml.weight >= 0.1
-                  AND mu.fact_type = $2
-                ORDER BY ml.weight DESC
-                LIMIT $3
-                """,
-                batch_nodes, fact_type, max_neighbors
-            )
-
-            for n in neighbors:
-                neighbor_id = str(n["id"])
-                if neighbor_id not in visited:
-                    # Get parent activation
-                    parent_id = str(n["from_unit_id"])
-                    activation = batch_activations.get(parent_id, 0.5)
-
-                    # Boost activation for causal links (they're high-value relationships)
-                    link_type = n["link_type"]
-                    base_weight = n["weight"]
-
-                    # Causal links get 1.5-2.0x boost depending on type
-                    if link_type in ("causes", "caused_by"):
-                        # Direct causation - very strong relationship
-                        causal_boost = 2.0
-                    elif link_type in ("enables", "prevents"):
-                        # Conditional causation - strong but not as direct
-                        causal_boost = 1.5
-                    else:
-                        # Temporal, semantic, entity links - standard weight
-                        causal_boost = 1.0
-
-                    effective_weight = base_weight * causal_boost
-                    new_activation = activation * effective_weight * 0.8
-                    if new_activation > 0.1:
-                        neighbor_result = RetrievalResult.from_db_row(dict(n))
-                        queue.append((neighbor_result, new_activation))
-
-    return results
-
-
 async def retrieve_temporal(
     conn,
     query_emb_str: str,
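The in-module retrieve_graph() deleted above is not rewritten here: per the file list, BFS spreading activation moves behind BFSGraphRetriever in the new graph_retrieval.py (+234), next to the new MPFP strategy in mpfp_retrieval.py (+438). For comparing retrievers it helps to keep the deleted scoring in view; a condensed restatement (helper name is illustrative):

    def next_activation(parent_activation: float, link_weight: float, link_type: str) -> float:
        # Causal links are boosted before the 0.8 per-hop decay is applied.
        boost = {"causes": 2.0, "caused_by": 2.0, "enables": 1.5, "prevents": 1.5}.get(link_type, 1.0)
        return parent_activation * (link_weight * boost) * 0.8

    # Entry node at similarity 0.9 over a 0.4-weight "causes" link:
    # 0.9 * (0.4 * 2.0) * 0.8 = 0.576, above the 0.1 cutoff, so the neighbor is enqueued.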
@@ -228,8 +157,8 @@ async def retrieve_temporal(
     start_date: datetime,
     end_date: datetime,
     budget: int,
-    semantic_threshold: float = 0.1
-) -> List[RetrievalResult]:
+    semantic_threshold: float = 0.1,
+) -> list[RetrievalResult]:
     """
     Temporal retrieval with spreading activation.
 
@@ -251,13 +180,12 @@ async def retrieve_temporal(
     Returns:
         List of RetrievalResult objects with temporal scores
     """
-    from datetime import timezone
 
     # Ensure start_date and end_date are timezone-aware (UTC) to match database datetimes
     if start_date.tzinfo is None:
-        start_date = start_date.replace(tzinfo=timezone.utc)
+        start_date = start_date.replace(tzinfo=UTC)
     if end_date.tzinfo is None:
-        end_date = end_date.replace(tzinfo=timezone.utc)
+        end_date = end_date.replace(tzinfo=UTC)
 
     entry_points = await conn.fetch(
         """
@@ -284,7 +212,12 @@ async def retrieve_temporal(
         ORDER BY COALESCE(occurred_start, mentioned_at, occurred_end) DESC, (embedding <=> $1::vector) ASC
         LIMIT 10
         """,
-        query_emb_str, bank_id, fact_type, start_date, end_date, semantic_threshold
+        query_emb_str,
+        bank_id,
+        fact_type,
+        start_date,
+        end_date,
+        semantic_threshold,
     )
 
     if not entry_points:
@@ -327,7 +260,9 @@ async def retrieve_temporal(
             results.append(ep_result)
 
     # Spread through temporal links
-    queue = [(RetrievalResult.from_db_row(dict(ep)), ep["similarity"], 1.0) for ep in entry_points]  # (unit, semantic_sim, temporal_score)
+    queue = [
+        (RetrievalResult.from_db_row(dict(ep)), ep["similarity"], 1.0) for ep in entry_points
+    ]  # (unit, semantic_sim, temporal_score)
     budget_remaining = budget - len(entry_points)
 
     while queue and budget_remaining > 0:
@@ -352,7 +287,10 @@ async def retrieve_temporal(
             ORDER BY ml.weight DESC
             LIMIT 10
             """,
-            query_emb_str, current.id, fact_type, semantic_threshold
+            query_emb_str,
+            current.id,
+            fact_type,
+            semantic_threshold,
         )
 
         for n in neighbors:
@@ -376,7 +314,9 @@ async def retrieve_temporal(
 
             if neighbor_best_date:
                 days_from_mid = abs((neighbor_best_date - mid_date).total_seconds() / 86400)
-                neighbor_temporal_proximity = 1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
+                neighbor_temporal_proximity = (
+                    1.0 - min(days_from_mid / (total_days / 2), 1.0) if total_days > 0 else 1.0
+                )
             else:
                 neighbor_temporal_proximity = 0.3  # Lower score if no temporal data
 
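The parenthesized expression above is unchanged math: proximity decays linearly with distance from the midpoint of the detected date range. A worked example with the same formula:

    from datetime import UTC, datetime

    start = datetime(2024, 1, 1, tzinfo=UTC)
    end = datetime(2024, 1, 31, tzinfo=UTC)
    total_days = (end - start).total_seconds() / 86400            # 30.0
    mid = start + (end - start) / 2                               # 2024-01-16 00:00
    best = datetime(2024, 1, 23, tzinfo=UTC)                      # a fact's best date
    days_from_mid = abs((best - mid).total_seconds() / 86400)     # 7.0
    proximity = 1.0 - min(days_from_mid / (total_days / 2), 1.0)  # 1 - 7/15 = ~0.533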
@@ -418,9 +358,10 @@ async def retrieve_parallel(
     bank_id: str,
     fact_type: str,
     thinking_budget: int,
-    question_date: Optional[datetime] = None,
-    query_analyzer: Optional["QueryAnalyzer"] = None
-) -> Tuple[List[RetrievalResult], List[RetrievalResult], List[RetrievalResult], Optional[List[RetrievalResult]], Dict[str, float], Optional[tuple]]:
+    question_date: datetime | None = None,
+    query_analyzer: Optional["QueryAnalyzer"] = None,
+    graph_retriever: GraphRetriever | None = None,
+) -> ParallelRetrievalResult:
     """
     Run 3-way or 4-way parallel retrieval (adds temporal if detected).
 
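This hunk carries the breaking API change: retrieve_parallel() now returns the ParallelRetrievalResult dataclass instead of a 6-tuple. A sketch of the call-site migration (argument values are illustrative):

    # 0.1.4: positional unpacking
    # semantic, bm25, graph, temporal, timings, tc = await retrieve_parallel(...)

    # 0.1.6: one result object
    res = await retrieve_parallel(pool, query_text, query_embedding_str, bank_id, fact_type, thinking_budget)
    semantic, timings = res.semantic, res.timings
    if res.temporal_constraint is not None:
        start_date, end_date = res.temporal_constraint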
@@ -428,76 +369,330 @@ async def retrieve_parallel(
         pool: Database connection pool
         query_text: Query text
         query_embedding_str: Query embedding as string
-        agent_id: bank ID
+        bank_id: Bank ID
         fact_type: Fact type to filter
         thinking_budget: Budget for graph traversal and retrieval limits
         question_date: Optional date when question was asked (for temporal filtering)
         query_analyzer: Query analyzer to use (defaults to TransformerQueryAnalyzer)
+        graph_retriever: Graph retrieval strategy (defaults to configured retriever)
 
     Returns:
-        Tuple of (semantic_results, bm25_results, graph_results, temporal_results, timings, temporal_constraint)
-        Each results list contains RetrievalResult objects
-        temporal_results is None if no temporal constraint detected
-        timings is a dict with per-method latencies in seconds
-        temporal_constraint is the (start_date, end_date) tuple if detected, else None
+        ParallelRetrievalResult with semantic, bm25, graph, temporal results and timings
     """
-    # Detect temporal constraint
     from .temporal_extraction import extract_temporal_constraint
+
+    temporal_constraint = extract_temporal_constraint(query_text, reference_date=question_date, analyzer=query_analyzer)
+
+    retriever = graph_retriever or get_default_graph_retriever()
+
+    if retriever.name == "mpfp":
+        return await _retrieve_parallel_mpfp(
+            pool, query_text, query_embedding_str, bank_id, fact_type, thinking_budget, temporal_constraint, retriever
+        )
+    else:
+        return await _retrieve_parallel_bfs(
+            pool, query_text, query_embedding_str, bank_id, fact_type, thinking_budget, temporal_constraint, retriever
+        )
+
+
+@dataclass
+class _SemanticGraphResult:
+    """Internal result from semantic→graph chain."""
+
+    semantic: list[RetrievalResult]
+    graph: list[RetrievalResult]
+    semantic_time: float
+    graph_time: float
+
+
+@dataclass
+class _TimedResult:
+    """Internal result with timing."""
+
+    results: list[RetrievalResult]
+    time: float
+
+
+async def _retrieve_parallel_mpfp(
+    pool,
+    query_text: str,
+    query_embedding_str: str,
+    bank_id: str,
+    fact_type: str,
+    thinking_budget: int,
+    temporal_constraint: tuple | None,
+    retriever: GraphRetriever,
+) -> ParallelRetrievalResult:
+    """
+    MPFP retrieval with optimized parallelization.
+
+    Runs 2-3 parallel task chains:
+    - Task 1: Semantic → Graph (chained, graph uses semantic seeds)
+    - Task 2: BM25 (independent)
+    - Task 3: Temporal (if constraint detected)
+    """
     import time
 
-
-
-
+    async def run_semantic_then_graph() -> _SemanticGraphResult:
+        """Chain: semantic retrieval → graph retrieval (using semantic as seeds)."""
+        start = time.time()
+        async with acquire_with_retry(pool) as conn:
+            semantic = await retrieve_semantic(conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget)
+        semantic_time = time.time() - start
+
+        # Get temporal seeds if needed (quick query, part of this chain)
+        temporal_seeds = None
+        if temporal_constraint:
+            tc_start, tc_end = temporal_constraint
+            async with acquire_with_retry(pool) as conn:
+                temporal_seeds = await _get_temporal_entry_points(
+                    conn, query_embedding_str, bank_id, fact_type, tc_start, tc_end, limit=20
+                )
+
+        # Run graph with seeds
+        start = time.time()
+        graph = await retriever.retrieve(
+            pool=pool,
+            query_embedding_str=query_embedding_str,
+            bank_id=bank_id,
+            fact_type=fact_type,
+            budget=thinking_budget,
+            query_text=query_text,
+            semantic_seeds=semantic,
+            temporal_seeds=temporal_seeds,
+        )
+        graph_time = time.time() - start
+
+        return _SemanticGraphResult(semantic, graph, semantic_time, graph_time)
 
-
-
+    async def run_bm25() -> _TimedResult:
+        """Independent BM25 retrieval."""
         start = time.time()
-
-
-        return
+        async with acquire_with_retry(pool) as conn:
+            results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
+        return _TimedResult(results, time.time() - start)
 
-    async def
+    async def run_temporal(tc_start, tc_end) -> _TimedResult:
+        """Temporal retrieval (uses its own entry point finding)."""
+        start = time.time()
         async with acquire_with_retry(pool) as conn:
-
+            results = await retrieve_temporal(
+                conn,
+                query_embedding_str,
+                bank_id,
+                fact_type,
+                tc_start,
+                tc_end,
+                budget=thinking_budget,
+                semantic_threshold=0.1,
+            )
+        return _TimedResult(results, time.time() - start)
+
+    # Run parallel task chains
+    if temporal_constraint:
+        tc_start, tc_end = temporal_constraint
+        sg_result, bm25_result, temporal_result = await asyncio.gather(
+            run_semantic_then_graph(),
+            run_bm25(),
+            run_temporal(tc_start, tc_end),
+        )
+        return ParallelRetrievalResult(
+            semantic=sg_result.semantic,
+            bm25=bm25_result.results,
+            graph=sg_result.graph,
+            temporal=temporal_result.results,
+            timings={
+                "semantic": sg_result.semantic_time,
+                "graph": sg_result.graph_time,
+                "bm25": bm25_result.time,
+                "temporal": temporal_result.time,
+            },
+            temporal_constraint=temporal_constraint,
+        )
+    else:
+        sg_result, bm25_result = await asyncio.gather(
+            run_semantic_then_graph(),
+            run_bm25(),
+        )
+        return ParallelRetrievalResult(
+            semantic=sg_result.semantic,
+            bm25=bm25_result.results,
+            graph=sg_result.graph,
+            temporal=None,
+            timings={
+                "semantic": sg_result.semantic_time,
+                "graph": sg_result.graph_time,
+                "bm25": bm25_result.time,
+            },
+            temporal_constraint=None,
+        )
+
+
+async def _get_temporal_entry_points(
+    conn,
+    query_embedding_str: str,
+    bank_id: str,
+    fact_type: str,
+    start_date: datetime,
+    end_date: datetime,
+    limit: int = 20,
+    semantic_threshold: float = 0.1,
+) -> list[RetrievalResult]:
+    """Get temporal entry points (facts in date range with semantic relevance)."""
+
+    if start_date.tzinfo is None:
+        start_date = start_date.replace(tzinfo=UTC)
+    if end_date.tzinfo is None:
+        end_date = end_date.replace(tzinfo=UTC)
+
+    rows = await conn.fetch(
+        """
+        SELECT id, text, context, event_date, occurred_start, occurred_end, mentioned_at,
+               access_count, embedding, fact_type, document_id, chunk_id,
+               1 - (embedding <=> $1::vector) AS similarity
+        FROM memory_units
+        WHERE bank_id = $2
+          AND fact_type = $3
+          AND embedding IS NOT NULL
+          AND (
+              (occurred_start IS NOT NULL AND occurred_end IS NOT NULL
+               AND occurred_start <= $5 AND occurred_end >= $4)
+              OR (mentioned_at IS NOT NULL AND mentioned_at BETWEEN $4 AND $5)
+              OR (occurred_start IS NOT NULL AND occurred_start BETWEEN $4 AND $5)
+              OR (occurred_end IS NOT NULL AND occurred_end BETWEEN $4 AND $5)
+          )
+          AND (1 - (embedding <=> $1::vector)) >= $6
+        ORDER BY COALESCE(occurred_start, mentioned_at, occurred_end) DESC,
+                 (embedding <=> $1::vector) ASC
+        LIMIT $7
+        """,
+        query_embedding_str,
+        bank_id,
+        fact_type,
+        start_date,
+        end_date,
+        semantic_threshold,
+        limit,
+    )
+
+    results = []
+    total_days = max((end_date - start_date).total_seconds() / 86400, 1)
+    mid_date = start_date + (end_date - start_date) / 2
+
+    for row in rows:
+        result = RetrievalResult.from_db_row(dict(row))
+
+        # Calculate temporal proximity score
+        best_date = None
+        if row["occurred_start"] and row["occurred_end"]:
+            best_date = row["occurred_start"] + (row["occurred_end"] - row["occurred_start"]) / 2
+        elif row["occurred_start"]:
+            best_date = row["occurred_start"]
+        elif row["occurred_end"]:
+            best_date = row["occurred_end"]
+        elif row["mentioned_at"]:
+            best_date = row["mentioned_at"]
+
+        if best_date:
+            days_from_mid = abs((best_date - mid_date).total_seconds() / 86400)
+            result.temporal_proximity = 1.0 - min(days_from_mid / (total_days / 2), 1.0)
+        else:
+            result.temporal_proximity = 0.5
 
-
+        result.temporal_score = result.temporal_proximity
+        results.append(result)
+
+    return results
+
+
+async def _retrieve_parallel_bfs(
+    pool,
+    query_text: str,
+    query_embedding_str: str,
+    bank_id: str,
+    fact_type: str,
+    thinking_budget: int,
+    temporal_constraint: tuple | None,
+    retriever: GraphRetriever,
+) -> ParallelRetrievalResult:
+    """BFS retrieval: all methods run in parallel (original behavior)."""
+    import time
+
+    async def run_semantic() -> _TimedResult:
+        start = time.time()
         async with acquire_with_retry(pool) as conn:
-
+            results = await retrieve_semantic(conn, query_embedding_str, bank_id, fact_type, limit=thinking_budget)
+        return _TimedResult(results, time.time() - start)
 
-    async def
+    async def run_bm25() -> _TimedResult:
+        start = time.time()
         async with acquire_with_retry(pool) as conn:
-
+            results = await retrieve_bm25(conn, query_text, bank_id, fact_type, limit=thinking_budget)
+        return _TimedResult(results, time.time() - start)
+
+    async def run_graph() -> _TimedResult:
+        start = time.time()
+        results = await retriever.retrieve(
+            pool=pool,
+            query_embedding_str=query_embedding_str,
+            bank_id=bank_id,
+            fact_type=fact_type,
+            budget=thinking_budget,
+            query_text=query_text,
+        )
+        return _TimedResult(results, time.time() - start)
 
-    async def run_temporal(
+    async def run_temporal(tc_start, tc_end) -> _TimedResult:
+        start = time.time()
         async with acquire_with_retry(pool) as conn:
-
-                conn,
-
+            results = await retrieve_temporal(
+                conn,
+                query_embedding_str,
+                bank_id,
+                fact_type,
+                tc_start,
+                tc_end,
+                budget=thinking_budget,
+                semantic_threshold=0.1,
             )
+        return _TimedResult(results, time.time() - start)
 
-    # Run retrievals in parallel with timing
-    timings = {}
     if temporal_constraint:
-
-
-
-
-
-
+        tc_start, tc_end = temporal_constraint
+        semantic_r, bm25_r, graph_r, temporal_r = await asyncio.gather(
+            run_semantic(),
+            run_bm25(),
+            run_graph(),
+            run_temporal(tc_start, tc_end),
+        )
+        return ParallelRetrievalResult(
+            semantic=semantic_r.results,
+            bm25=bm25_r.results,
+            graph=graph_r.results,
+            temporal=temporal_r.results,
+            timings={
+                "semantic": semantic_r.time,
+                "bm25": bm25_r.time,
+                "graph": graph_r.time,
+                "temporal": temporal_r.time,
+            },
+            temporal_constraint=temporal_constraint,
         )
-        semantic_results, _, timings["semantic"] = results[0]
-        bm25_results, _, timings["bm25"] = results[1]
-        graph_results, _, timings["graph"] = results[2]
-        temporal_results, _, timings["temporal"] = results[3]
     else:
-
-
-
-
+        semantic_r, bm25_r, graph_r = await asyncio.gather(
+            run_semantic(),
+            run_bm25(),
+            run_graph(),
+        )
+        return ParallelRetrievalResult(
+            semantic=semantic_r.results,
+            bm25=bm25_r.results,
+            graph=graph_r.results,
+            temporal=None,
+            timings={
+                "semantic": semantic_r.time,
+                "bm25": bm25_r.time,
+                "graph": graph_r.time,
+            },
+            temporal_constraint=None,
        )
-        semantic_results, _, timings["semantic"] = results[0]
-        bm25_results, _, timings["bm25"] = results[1]
-        graph_results, _, timings["graph"] = results[2]
-        temporal_results = None
-
-    return semantic_results, bm25_results, graph_results, temporal_results, timings, temporal_constraint