hindsight-api 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/admin/__init__.py +1 -0
- hindsight_api/admin/cli.py +311 -0
- hindsight_api/alembic/versions/f1a2b3c4d5e6_add_memory_links_composite_index.py +44 -0
- hindsight_api/alembic/versions/g2a3b4c5d6e7_add_tags_column.py +48 -0
- hindsight_api/alembic/versions/h3c4d5e6f7g8_mental_models_v4.py +112 -0
- hindsight_api/alembic/versions/i4d5e6f7g8h9_delete_opinions.py +41 -0
- hindsight_api/alembic/versions/j5e6f7g8h9i0_mental_model_versions.py +95 -0
- hindsight_api/alembic/versions/k6f7g8h9i0j1_add_directive_subtype.py +58 -0
- hindsight_api/alembic/versions/l7g8h9i0j1k2_add_worker_columns.py +109 -0
- hindsight_api/alembic/versions/m8h9i0j1k2l3_mental_model_id_to_text.py +41 -0
- hindsight_api/alembic/versions/n9i0j1k2l3m4_learnings_and_pinned_reflections.py +134 -0
- hindsight_api/alembic/versions/o0j1k2l3m4n5_migrate_mental_models_data.py +113 -0
- hindsight_api/alembic/versions/p1k2l3m4n5o6_new_knowledge_architecture.py +194 -0
- hindsight_api/alembic/versions/q2l3m4n5o6p7_fix_mental_model_fact_type.py +50 -0
- hindsight_api/alembic/versions/r3m4n5o6p7q8_add_reflect_response_to_reflections.py +47 -0
- hindsight_api/alembic/versions/s4n5o6p7q8r9_add_consolidated_at_to_memory_units.py +53 -0
- hindsight_api/alembic/versions/t5o6p7q8r9s0_rename_mental_models_to_observations.py +134 -0
- hindsight_api/alembic/versions/u6p7q8r9s0t1_mental_models_text_id.py +41 -0
- hindsight_api/alembic/versions/v7q8r9s0t1u2_add_max_tokens_to_mental_models.py +50 -0
- hindsight_api/api/http.py +1406 -118
- hindsight_api/api/mcp.py +11 -196
- hindsight_api/config.py +359 -27
- hindsight_api/engine/consolidation/__init__.py +5 -0
- hindsight_api/engine/consolidation/consolidator.py +859 -0
- hindsight_api/engine/consolidation/prompts.py +69 -0
- hindsight_api/engine/cross_encoder.py +706 -88
- hindsight_api/engine/db_budget.py +284 -0
- hindsight_api/engine/db_utils.py +11 -0
- hindsight_api/engine/directives/__init__.py +5 -0
- hindsight_api/engine/directives/models.py +37 -0
- hindsight_api/engine/embeddings.py +553 -29
- hindsight_api/engine/entity_resolver.py +8 -5
- hindsight_api/engine/interface.py +40 -17
- hindsight_api/engine/llm_wrapper.py +744 -68
- hindsight_api/engine/memory_engine.py +2505 -1017
- hindsight_api/engine/mental_models/__init__.py +14 -0
- hindsight_api/engine/mental_models/models.py +53 -0
- hindsight_api/engine/query_analyzer.py +4 -3
- hindsight_api/engine/reflect/__init__.py +18 -0
- hindsight_api/engine/reflect/agent.py +933 -0
- hindsight_api/engine/reflect/models.py +109 -0
- hindsight_api/engine/reflect/observations.py +186 -0
- hindsight_api/engine/reflect/prompts.py +483 -0
- hindsight_api/engine/reflect/tools.py +437 -0
- hindsight_api/engine/reflect/tools_schema.py +250 -0
- hindsight_api/engine/response_models.py +168 -4
- hindsight_api/engine/retain/bank_utils.py +79 -201
- hindsight_api/engine/retain/fact_extraction.py +424 -195
- hindsight_api/engine/retain/fact_storage.py +35 -12
- hindsight_api/engine/retain/link_utils.py +29 -24
- hindsight_api/engine/retain/orchestrator.py +24 -43
- hindsight_api/engine/retain/types.py +11 -2
- hindsight_api/engine/search/graph_retrieval.py +43 -14
- hindsight_api/engine/search/link_expansion_retrieval.py +391 -0
- hindsight_api/engine/search/mpfp_retrieval.py +362 -117
- hindsight_api/engine/search/reranking.py +2 -2
- hindsight_api/engine/search/retrieval.py +848 -201
- hindsight_api/engine/search/tags.py +172 -0
- hindsight_api/engine/search/think_utils.py +42 -141
- hindsight_api/engine/search/trace.py +12 -1
- hindsight_api/engine/search/tracer.py +26 -6
- hindsight_api/engine/search/types.py +21 -3
- hindsight_api/engine/task_backend.py +113 -106
- hindsight_api/engine/utils.py +1 -152
- hindsight_api/extensions/__init__.py +10 -1
- hindsight_api/extensions/builtin/tenant.py +5 -1
- hindsight_api/extensions/context.py +10 -1
- hindsight_api/extensions/operation_validator.py +81 -4
- hindsight_api/extensions/tenant.py +26 -0
- hindsight_api/main.py +69 -6
- hindsight_api/mcp_local.py +12 -53
- hindsight_api/mcp_tools.py +494 -0
- hindsight_api/metrics.py +433 -48
- hindsight_api/migrations.py +141 -1
- hindsight_api/models.py +3 -3
- hindsight_api/pg0.py +53 -0
- hindsight_api/server.py +39 -2
- hindsight_api/worker/__init__.py +11 -0
- hindsight_api/worker/main.py +296 -0
- hindsight_api/worker/poller.py +486 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/METADATA +16 -6
- hindsight_api-0.4.0.dist-info/RECORD +112 -0
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/entry_points.txt +2 -0
- hindsight_api/engine/retain/observation_regeneration.py +0 -254
- hindsight_api/engine/search/observation_utils.py +0 -125
- hindsight_api/engine/search/scoring.py +0 -159
- hindsight_api-0.2.1.dist-info/RECORD +0 -75
- {hindsight_api-0.2.1.dist-info → hindsight_api-0.4.0.dist-info}/WHEEL +0 -0
hindsight_api/engine/search/mpfp_retrieval.py

@@ -9,6 +9,7 @@ propagation from Approximate PPR.
 
 Key properties:
 - Sublinear in graph size (threshold pruning bounds active nodes)
+- Lazy edge loading: only loads edges for frontier nodes, not entire graph
 - Predefined patterns capture different retrieval intents
 - All patterns run in parallel, results fused via RRF
 - No LLM in the loop during traversal
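As a rough illustration of the forward-push step these properties describe (a sketch under assumed alpha and threshold values, not the package's code): each hop keeps an alpha share of a node's probability mass as score and pushes the rest along weighted edges, and any node whose mass falls below the threshold is dropped, which is what keeps the active set small regardless of graph size.

def forward_push_hop(frontier, scores, neighbors, alpha=0.15, threshold=1e-4):
    # Sketch only. frontier: node -> mass; neighbors(node) yields (node, weight) pairs.
    next_frontier = {}
    for node, mass in frontier.items():
        if mass < threshold:  # pruning bounds the active set
            continue
        scores[node] = scores.get(node, 0.0) + alpha * mass
        for nbr, weight in neighbors(node):
            next_frontier[nbr] = next_frontier.get(nbr, 0.0) + (1 - alpha) * mass * weight
    return next_frontier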
@@ -22,7 +23,8 @@ from dataclasses import dataclass, field
 from ..db_utils import acquire_with_retry
 from ..memory_engine import fq_table
 from .graph_retrieval import GraphRetriever
-from .
+from .tags import TagsMatch
+from .types import MPFPTimings, RetrievalResult
 
 logger = logging.getLogger(__name__)
 
@@ -41,11 +43,27 @@ class EdgeTarget:
 
 
 @dataclass
-class TypedAdjacency:
-    """
+class EdgeCache:
+    """
+    Cache for lazily-loaded edges.
+
+    Grows per-hop as edges are loaded for frontier nodes.
+    Shared across patterns to avoid redundant loads.
+    Loads ALL edge types at once to minimize DB queries.
+    Thread-safe via asyncio lock to prevent redundant concurrent loads.
+    """
 
-    # edge_type -> from_node_id -> list of
+    # edge_type -> from_node_id -> list of EdgeTarget
     graphs: dict[str, dict[str, list[EdgeTarget]]] = field(default_factory=dict)
+    # Track which nodes have been fully loaded (all edge types)
+    _fully_loaded: set[str] = field(default_factory=set)
+    # Timing stats
+    db_queries: int = 0
+    edge_load_time: float = 0.0
+    # Detailed hop timing for debugging
+    hop_details: list[dict] = field(default_factory=list)
+    # Lock to prevent redundant concurrent loads
+    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)
 
     def get_neighbors(self, edge_type: str, node_id: str) -> list[EdgeTarget]:
         """Get neighbors for a node via a specific edge type."""
@@ -63,6 +81,31 @@ class TypedAdjacency:
 
         return [EdgeTarget(node_id=n.node_id, weight=n.weight / total) for n in neighbors]
 
+    def is_fully_loaded(self, node_id: str) -> bool:
+        """Check if all edges for this node have been loaded."""
+        return node_id in self._fully_loaded
+
+    def get_uncached(self, node_ids: list[str]) -> list[str]:
+        """Get node IDs that haven't been fully loaded yet."""
+        return [n for n in node_ids if not self.is_fully_loaded(n)]
+
+    def add_all_edges(self, edges_by_type: dict[str, dict[str, list[EdgeTarget]]], all_queried: list[str]):
+        """
+        Add loaded edges to the cache (all edge types at once).
+
+        Args:
+            edges_by_type: Dict mapping edge_type -> from_node_id -> list of EdgeTarget
+            all_queried: All node IDs that were queried (marks them as fully loaded)
+        """
+        for edge_type, edges in edges_by_type.items():
+            if edge_type not in self.graphs:
+                self.graphs[edge_type] = {}
+            for node_id, neighbors in edges.items():
+                self.graphs[edge_type][node_id] = neighbors
+
+        # Mark all queried nodes as fully loaded (even if they have no edges)
+        self._fully_loaded.update(all_queried)
+
 
 @dataclass
 class PatternResult:
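The intended round trip for the new cache methods, as a minimal sketch run inside an async function (the node IDs and pool are placeholders, not values from the package): get_uncached filters out already-loaded frontier nodes, add_all_edges stores one batched load and marks every queried node as fully loaded, so repeat visits hit memory rather than the database.

cache = EdgeCache()
frontier = ["node-a", "node-b"]            # hypothetical IDs
to_load = cache.get_uncached(frontier)     # both uncached on the first hop
edges = await load_all_edges_for_frontier(pool, to_load)
cache.add_all_edges(edges, to_load)        # marked loaded even if edgeless
assert cache.get_uncached(frontier) == []  # later hops skip the DB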
@@ -109,66 +152,249 @@ class SeedNode:
 
 
 # -----------------------------------------------------------------------------
-#
+# Lazy Edge Loading
 # -----------------------------------------------------------------------------
 
 
-def mpfp_traverse(
-
-
-
-
-) -> PatternResult:
+async def load_all_edges_for_frontier(
+    pool,
+    node_ids: list[str],
+    top_k_per_type: int = 20,
+) -> dict[str, dict[str, list[EdgeTarget]]]:
     """
-
+    Load top-k edges per (node, edge_type) for frontier nodes.
+
+    Uses a LATERAL join to efficiently fetch only the top-k edges per type,
+    avoiding loading hundreds of entity edges when only 20 are needed.
+
+    Requires composite index: (from_unit_id, link_type, weight DESC)
 
     Args:
-
-
-
-        config: Algorithm parameters
+        pool: Database connection pool
+        node_ids: Frontier node IDs to load edges for
+        top_k_per_type: Max edges to load per (node, link_type) pair
 
     Returns:
-
+        Dict mapping edge_type -> from_node_id -> list of EdgeTarget
     """
-    if not
-        return
+    if not node_ids:
+        return {}
+
+    async with acquire_with_retry(pool) as conn:
+        # Use LATERAL join to get top-k per (from_node, link_type)
+        # This leverages the composite index for efficient early termination
+        rows = await conn.fetch(
+            f"""
+            WITH frontier(node_id) AS (SELECT unnest($1::uuid[]))
+            SELECT f.node_id as from_unit_id, lt.link_type, edges.to_unit_id, edges.weight
+            FROM frontier f
+            CROSS JOIN (VALUES ('semantic'), ('temporal'), ('entity'), ('causes'), ('caused_by')) AS lt(link_type)
+            CROSS JOIN LATERAL (
+                SELECT ml.to_unit_id, ml.weight
+                FROM {fq_table("memory_links")} ml
+                WHERE ml.from_unit_id = f.node_id
+                  AND ml.link_type = lt.link_type
+                  AND ml.weight >= 0.1
+                ORDER BY ml.weight DESC
+                LIMIT $2
+            ) edges
+            """,
+            node_ids,
+            top_k_per_type,
+        )
+
+    # Group by edge_type -> from_node -> neighbors
+    result: dict[str, dict[str, list[EdgeTarget]]] = defaultdict(lambda: defaultdict(list))
+    for row in rows:
+        edge_type = row["link_type"]
+        from_id = str(row["from_unit_id"])
+        to_id = str(row["to_unit_id"])
+        weight = row["weight"]
+        result[edge_type][from_id].append(EdgeTarget(node_id=to_id, weight=weight))
+
+    # Convert nested defaultdicts to regular dicts
+    return {edge_type: dict(edges) for edge_type, edges in result.items()}
+
+
+# -----------------------------------------------------------------------------
+# Core Algorithm (Async with Lazy Loading)
+# -----------------------------------------------------------------------------
+
+
+@dataclass
+class PatternState:
+    """State for a pattern traversal between hops."""
+
+    pattern: list[str]
+    hop_index: int
+    scores: dict[str, float]
+    frontier: dict[str, float]
 
-    scores: dict[str, float] = {}
 
-
+def _init_pattern_state(seeds: list[SeedNode], pattern: list[str]) -> PatternState:
+    """Initialize pattern state from seeds."""
+    if not seeds:
+        return PatternState(pattern=pattern, hop_index=0, scores={}, frontier={})
+
     total_seed_score = sum(s.score for s in seeds)
     if total_seed_score == 0:
-        total_seed_score = len(seeds)
+        total_seed_score = len(seeds)
+
+    frontier = {s.node_id: s.score / total_seed_score for s in seeds}
+    return PatternState(pattern=pattern, hop_index=0, scores={}, frontier=frontier)
+
 
-
+def _execute_hop(state: PatternState, cache: EdgeCache, config: MPFPConfig) -> set[str]:
+    """
+    Execute ONE hop of traversal, return frontier nodes for next hop.
+
+    This is a pure function that uses cached edges (no DB access).
+    Returns set of uncached nodes needed for next hop.
+    """
+    if state.hop_index >= len(state.pattern):
+        return set()
 
-
-    for edge_type in pattern:
-        next_frontier: dict[str, float] = {}
+    edge_type = state.pattern[state.hop_index]
 
-
-
-
+    # Collect active nodes above threshold
+    active_nodes = [node_id for node_id, mass in state.frontier.items() if mass >= config.threshold]
+    if not active_nodes:
+        state.frontier = {}
+        return set()
 
-
-
+    # Propagate mass using cached edges
+    next_frontier: dict[str, float] = {}
+    uncached_for_next: set[str] = set()
 
-
-
-
+    for node_id, mass in state.frontier.items():
+        if mass < config.threshold:
+            continue
 
-
-
+        # Keep α portion for this node
+        state.scores[node_id] = state.scores.get(node_id, 0) + config.alpha * mass
 
-
+        # Push (1-α) to neighbors
+        push_mass = (1 - config.alpha) * mass
+        neighbors = cache.get_normalized_neighbors(edge_type, node_id, config.top_k_neighbors)
 
-
-
+        for neighbor in neighbors:
+            next_frontier[neighbor.node_id] = next_frontier.get(neighbor.node_id, 0) + push_mass * neighbor.weight
+            # Track if we'll need edges for this node in the next hop
+            if not cache.is_fully_loaded(neighbor.node_id):
+                uncached_for_next.add(neighbor.node_id)
+
+    state.frontier = next_frontier
+    state.hop_index += 1
+
+    return uncached_for_next
+
+
+def _finalize_pattern(state: PatternState, config: MPFPConfig) -> PatternResult:
+    """Finalize pattern by adding remaining frontier mass to scores."""
+    for node_id, mass in state.frontier.items():
         if mass >= config.threshold:
-            scores[node_id] = scores.get(node_id, 0) + mass
+            state.scores[node_id] = state.scores.get(node_id, 0) + mass
+
+    return PatternResult(pattern=state.pattern, scores=state.scores)
+
+
+async def mpfp_traverse_hop_synchronized(
+    pool,
+    pattern_jobs: list[tuple[list[SeedNode], list[str]]],
+    config: MPFPConfig,
+    cache: EdgeCache,
+) -> list[PatternResult]:
+    """
+    Execute ALL patterns with hop-synchronized edge loading.
+
+    Instead of running each pattern independently (causing multiple DB queries),
+    this function:
+    1. Runs hop 1 for ALL patterns (using pre-warmed seed edges)
+    2. Collects ALL unique hop-2 frontier nodes across patterns
+    3. Pre-warms hop-2 edges in ONE query
+    4. Runs hop 2 for ALL patterns
+
+    This reduces DB queries from O(patterns * hops) to O(hops).
+
+    Args:
+        pool: Database connection pool
+        pattern_jobs: List of (seeds, pattern) tuples
+        config: Algorithm parameters
+        cache: Shared edge cache (should be pre-warmed with seed edges)
+
+    Returns:
+        List of PatternResult for each pattern
+    """
+    import time
+
+    # Initialize all pattern states
+    states = [_init_pattern_state(seeds, pattern) for seeds, pattern in pattern_jobs]
+
+    # Determine max hops (all patterns should be same length, but be safe)
+    max_hops = max((len(p) for _, p in pattern_jobs), default=0)
+
+    # Detailed timing for debugging
+    hop_times: list[dict] = []
+
+    # Execute hop-by-hop across ALL patterns
+    for hop in range(max_hops):
+        hop_start = time.time()
+        hop_timing = {"hop": hop, "patterns_executed": 0, "uncached_count": 0, "load_time": 0.0}
+
+        # Execute this hop for all patterns, collect uncached nodes for next hop
+        all_uncached: set[str] = set()
+        exec_start = time.time()
+        for state in states:
+            if state.hop_index < len(state.pattern):
+                uncached = _execute_hop(state, cache, config)
+                all_uncached.update(uncached)
+                hop_timing["patterns_executed"] += 1
+        hop_timing["exec_time"] = time.time() - exec_start
+
+        # Pre-warm edges for ALL uncached nodes before next hop
+        hop_timing["uncached_count"] = len(all_uncached)
+        if all_uncached:
+            uncached_list = list(all_uncached - cache._fully_loaded)
+            hop_timing["uncached_after_filter"] = len(uncached_list)
+            if uncached_list:
+                load_start = time.time()
+                edges_by_type = await load_all_edges_for_frontier(pool, uncached_list, config.top_k_neighbors)
+                hop_timing["load_time"] = time.time() - load_start
+                cache.edge_load_time += hop_timing["load_time"]
+                cache.db_queries += 1
+                cache.add_all_edges(edges_by_type, uncached_list)
+                hop_timing["edges_loaded"] = sum(
+                    len(neighbors) for edges in edges_by_type.values() for neighbors in edges.values()
+                )
+
+        hop_timing["total_time"] = time.time() - hop_start
+        hop_times.append(hop_timing)
+
+    # Store hop timing details in cache for logging
+    cache.hop_details = hop_times
+
+    # Finalize all patterns
+    return [_finalize_pattern(state, config) for state in states]
 
-
+
+async def mpfp_traverse_async(
+    pool,
+    seeds: list[SeedNode],
+    pattern: list[str],
+    config: MPFPConfig,
+    cache: EdgeCache,
+) -> PatternResult:
+    """
+    Async Forward Push traversal with lazy edge loading.
+
+    NOTE: For better performance with multiple patterns, use mpfp_traverse_hop_synchronized().
+    This function is kept for single-pattern use cases.
+    """
+    if not seeds:
+        return PatternResult(pattern=pattern, scores={})
+
+    results = await mpfp_traverse_hop_synchronized(pool, [(seeds, pattern)], config, cache)
+    return results[0] if results else PatternResult(pattern=pattern, scores={})
 
 
 def rrf_fusion(
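To put the O(patterns * hops) to O(hops) claim in concrete terms (illustrative numbers, not measurements from the package): with 6 patterns of 2 hops each, independent traversals could issue up to 6 * 2 = 12 edge queries, while the synchronized version issues one seed pre-warm plus at most one batched load per later hop, i.e. 3 queries. A hypothetical driver, with seed lists and edge-type patterns invented for illustration:

# Sketch only: semantic_seeds, temporal_seeds, pool, and config are placeholders.
cache = EdgeCache()
jobs = [
    (semantic_seeds, ["semantic", "entity"]),
    (semantic_seeds, ["entity", "semantic"]),
    (temporal_seeds, ["temporal", "semantic"]),
]
# Pre-warm seed edges in one query, as retrieve() does, so hop 1 is cache-only.
seed_ids = list({s.node_id for seeds, _ in jobs for s in seeds})
cache.add_all_edges(await load_all_edges_for_frontier(pool, seed_ids), seed_ids)
results = await mpfp_traverse_hop_synchronized(pool, jobs, config, cache)
# cache.db_queries is now bounded by the hop count, not len(jobs) * hops.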
@@ -210,38 +436,6 @@ def rrf_fusion(
 # -----------------------------------------------------------------------------
 
 
-async def load_typed_adjacency(pool, bank_id: str) -> TypedAdjacency:
-    """
-    Load all edges for a bank, split by edge type.
-
-    Single query, then organize in-memory for fast traversal.
-    """
-    async with acquire_with_retry(pool) as conn:
-        rows = await conn.fetch(
-            f"""
-            SELECT ml.from_unit_id, ml.to_unit_id, ml.link_type, ml.weight
-            FROM {fq_table("memory_links")} ml
-            JOIN {fq_table("memory_units")} mu ON ml.from_unit_id = mu.id
-            WHERE mu.bank_id = $1
-              AND ml.weight >= 0.1
-            ORDER BY ml.from_unit_id, ml.weight DESC
-            """,
-            bank_id,
-        )
-
-    graphs: dict[str, dict[str, list[EdgeTarget]]] = defaultdict(lambda: defaultdict(list))
-
-    for row in rows:
-        from_id = str(row["from_unit_id"])
-        to_id = str(row["to_unit_id"])
-        link_type = row["link_type"]
-        weight = row["weight"]
-
-        graphs[link_type][from_id].append(EdgeTarget(node_id=to_id, weight=weight))
-
-    return TypedAdjacency(graphs=dict(graphs))
-
-
 async def fetch_memory_units_by_ids(
     pool,
     node_ids: list[str],
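The eager load_typed_adjacency pass removed above pulled every edge in a bank up front; the lazy path instead leans on the composite index added by the f1a2b3c4d5e6_add_memory_links_composite_index.py migration in the file list. A plausible shape for that index (the name and exact DDL below are assumptions, not copied from the migration), which lets the LATERAL subquery walk each (from_unit_id, link_type) group in weight order and stop after top_k rows:

# Hypothetical sketch; the real migration may differ in index name and options.
await conn.execute(
    """
    CREATE INDEX IF NOT EXISTS ix_memory_links_from_type_weight
    ON memory_links (from_unit_id, link_type, weight DESC)
    """
)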
@@ -255,7 +449,7 @@ async def fetch_memory_units_by_ids(
         rows = await conn.fetch(
             f"""
             SELECT id, text, context, event_date, occurred_start, occurred_end,
-                   mentioned_at,
+                   mentioned_at, embedding, fact_type, document_id, chunk_id, tags
             FROM {fq_table("memory_units")}
             WHERE id = ANY($1::uuid[])
               AND fact_type = $2
@@ -274,10 +468,10 @@ async def fetch_memory_units_by_ids(
 
 class MPFPGraphRetriever(GraphRetriever):
     """
-    Graph retrieval using Meta-Path Forward Push.
+    Graph retrieval using Meta-Path Forward Push with lazy edge loading.
 
     Runs predefined patterns in parallel from semantic and temporal seeds,
-
+    loading edges on-demand per hop instead of loading entire graph upfront.
     """
 
     def __init__(self, config: MPFPConfig | None = None):
@@ -287,8 +481,13 @@ class MPFPGraphRetriever(GraphRetriever):
         Args:
             config: Algorithm configuration (uses defaults if None)
         """
-
-
+        if config is None:
+            # Read top_k_neighbors from global config
+            from ...config import get_config
+
+            global_config = get_config()
+            config = MPFPConfig(top_k_neighbors=global_config.mpfp_top_k_neighbors)
+        self.config = config
 
     @property
     def name(self) -> str:
@@ -304,9 +503,12 @@ class MPFPGraphRetriever(GraphRetriever):
         query_text: str | None = None,
         semantic_seeds: list[RetrievalResult] | None = None,
         temporal_seeds: list[RetrievalResult] | None = None,
-
+        adjacency=None,  # Ignored - kept for interface compatibility
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
+    ) -> tuple[list[RetrievalResult], MPFPTimings | None]:
         """
-        Retrieve facts using MPFP algorithm.
+        Retrieve facts using MPFP algorithm with lazy edge loading.
 
         Args:
             pool: Database connection pool
@@ -317,12 +519,15 @@ class MPFPGraphRetriever(GraphRetriever):
             query_text: Original query text (optional)
             semantic_seeds: Pre-computed semantic entry points
             temporal_seeds: Pre-computed temporal entry points
+            adjacency: Ignored (kept for interface compatibility)
+            tags: Optional list of tags for visibility filtering (OR matching)
 
         Returns:
-            List of RetrievalResult with activation scores
+            Tuple of (List of RetrievalResult with activation scores, MPFPTimings)
         """
-
-
+        import time
+
+        timings = MPFPTimings(fact_type=fact_type)
 
         # Convert seeds to SeedNode format
         semantic_seed_nodes = self._convert_seeds(semantic_seeds, "similarity")
@@ -330,54 +535,88 @@ class MPFPGraphRetriever(GraphRetriever):
 
         # If no semantic seeds provided, fall back to finding our own
         if not semantic_seed_nodes:
-
+            seeds_start = time.time()
+            semantic_seed_nodes = await self._find_semantic_seeds(
+                pool, query_embedding_str, bank_id, fact_type, tags=tags, tags_match=tags_match
+            )
+            timings.seeds_time = time.time() - seeds_start
+            logger.debug(
+                f"[MPFP] Found {len(semantic_seed_nodes)} semantic seeds for fact_type={fact_type} (tags={tags}, tags_match={tags_match})"
+            )
 
-        #
-
+        # Collect all pattern jobs
+        pattern_jobs = []
 
         # Patterns from semantic seeds
         for pattern in self.config.patterns_semantic:
             if semantic_seed_nodes:
-                tasks.append(
-                    asyncio.to_thread(
-                        mpfp_traverse,
-                        semantic_seed_nodes,
-                        pattern,
-                        adjacency,
-                        self.config,
-                    )
-                )
+                pattern_jobs.append((semantic_seed_nodes, pattern))
 
         # Patterns from temporal seeds
         for pattern in self.config.patterns_temporal:
             if temporal_seed_nodes:
-                tasks.append(
-                    asyncio.to_thread(
-                        mpfp_traverse,
-                        temporal_seed_nodes,
-                        pattern,
-                        adjacency,
-                        self.config,
-                    )
-                )
-
-        if not tasks:
-            return []
+                pattern_jobs.append((temporal_seed_nodes, pattern))
 
-
-
+        if not pattern_jobs:
+            logger.debug(
+                f"[MPFP] No pattern jobs (semantic_seeds={len(semantic_seed_nodes)}, temporal_seeds={len(temporal_seed_nodes)})"
+            )
+            return [], timings
+
+        timings.pattern_count = len(pattern_jobs)
+
+        # Shared edge cache across all patterns
+        cache = EdgeCache()
+
+        # Pre-warm cache with ALL seed node edges BEFORE running patterns
+        # This prevents redundant DB queries at hop 1
+        all_seed_ids = list({s.node_id for seeds, _ in pattern_jobs for s in seeds})
+        if all_seed_ids:
+            import time as time_module
+
+            prewarm_start = time_module.time()
+            edges_by_type = await load_all_edges_for_frontier(pool, all_seed_ids, self.config.top_k_neighbors)
+            cache.edge_load_time += time_module.time() - prewarm_start
+            cache.db_queries += 1
+            cache.add_all_edges(edges_by_type, all_seed_ids)
+
+        # Run all patterns with HOP-SYNCHRONIZED edge loading
+        # This batches hop-2 edge loads across ALL patterns into ONE query
+        # Reduces DB queries from O(patterns * hops) to O(hops)
+        step_start = time.time()
+        pattern_results = await mpfp_traverse_hop_synchronized(pool, pattern_jobs, self.config, cache)
+        timings.traverse = time.time() - step_start
+
+        # Record edge loading stats from cache
+        timings.edge_count = sum(len(neighbors) for g in cache.graphs.values() for neighbors in g.values())
+        timings.db_queries = cache.db_queries
+        timings.edge_load_time = cache.edge_load_time
+        timings.hop_details = cache.hop_details
 
         # Fuse results
+        step_start = time.time()
         fused = rrf_fusion(pattern_results, top_k=budget)
+        timings.fusion = time.time() - step_start
 
         if not fused:
-
+            logger.debug(f"[MPFP] No fused results after RRF fusion (pattern_count={len(pattern_results)})")
+            return [], timings
 
-        # Get top result IDs
+        # Get top result IDs
         result_ids = [node_id for node_id, score in fused][:budget]
 
         # Fetch full details
+        step_start = time.time()
         results = await fetch_memory_units_by_ids(pool, result_ids, fact_type)
+        timings.fetch = time.time() - step_start
+
+        # Filter results by tags (graph traversal may have picked up unfiltered memories)
+        if tags:
+            from .tags import filter_results_by_tags
+
+            results = filter_results_by_tags(results, tags, match=tags_match)
+
+        timings.result_count = len(results)
 
         # Add activation scores from fusion
         score_map = {node_id: score for node_id, score in fused}
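rrf_fusion's body is unchanged and not shown in this diff; for reference, standard Reciprocal Rank Fusion scores each node by summing 1/(k + rank) over every per-pattern ranking it appears in. A minimal sketch (the k = 60 constant below is the conventional default, an assumption here, not necessarily the package's value):

def rrf(rankings: list[list[str]], k: int = 60) -> dict[str, float]:
    # rankings: one best-first list of node IDs per pattern.
    fused: dict[str, float] = {}
    for ranking in rankings:
        for rank, node_id in enumerate(ranking, start=1):
            fused[node_id] = fused.get(node_id, 0.0) + 1.0 / (k + rank)
    return fused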
@@ -387,7 +626,7 @@ class MPFPGraphRetriever(GraphRetriever):
         # Sort by activation
         results.sort(key=lambda r: r.activation or 0, reverse=True)
 
-        return results
+        return results, timings
 
     def _convert_seeds(
         self,
@@ -415,8 +654,17 @@ class MPFPGraphRetriever(GraphRetriever):
         fact_type: str,
         limit: int = 20,
         threshold: float = 0.3,
+        tags: list[str] | None = None,
+        tags_match: TagsMatch = "any",
     ) -> list[SeedNode]:
         """Fallback: find semantic seeds via embedding search."""
+        from .tags import build_tags_where_clause_simple
+
+        tags_clause = build_tags_where_clause_simple(tags, 6, match=tags_match)
+        params = [query_embedding_str, bank_id, fact_type, threshold, limit]
+        if tags:
+            params.append(tags)
+
         async with acquire_with_retry(pool) as conn:
             rows = await conn.fetch(
                 f"""
@@ -426,14 +674,11 @@ class MPFPGraphRetriever(GraphRetriever):
                 AND embedding IS NOT NULL
                 AND fact_type = $3
                 AND (1 - (embedding <=> $1::vector)) >= $4
+                {tags_clause}
                 ORDER BY embedding <=> $1::vector
                 LIMIT $5
                 """,
-
-                bank_id,
-                fact_type,
-                threshold,
-                limit,
+                *params,
             )
 
         return [SeedNode(node_id=str(r["id"]), score=r["similarity"]) for r in rows]
hindsight_api/engine/search/reranking.py

@@ -44,7 +44,7 @@ class CrossEncoderReranker:
         await cross_encoder.initialize()
         self._initialized = True
 
-    def rerank(self, query: str, candidates: list[MergedCandidate]) -> list[ScoredResult]:
+    async def rerank(self, query: str, candidates: list[MergedCandidate]) -> list[ScoredResult]:
         """
         Rerank candidates using cross-encoder scores.
 
@@ -85,7 +85,7 @@ class CrossEncoderReranker:
             pairs.append([query, doc_text])
 
         # Get cross-encoder scores
-        scores = self.cross_encoder.predict(pairs)
+        scores = await self.cross_encoder.predict(pairs)
 
         # Normalize scores using sigmoid to [0, 1] range
         # Cross-encoder returns logits which can be negative
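The sigmoid normalization mentioned in the closing comment maps raw cross-encoder logits onto [0, 1]; a minimal sketch of the usual formula (the package's exact implementation is not shown in this hunk):

import math

def sigmoid(logit: float) -> float:
    # Squashes any real-valued logit into (0, 1); a logit of 0.0 maps to 0.5.
    return 1.0 / (1.0 + math.exp(-logit))

# e.g. sigmoid(-2.0) ~= 0.12, sigmoid(0.0) == 0.5, sigmoid(4.0) ~= 0.98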