hindsight-api 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

@@ -0,0 +1,454 @@
+ """
+ Meta-Path Forward Push (MPFP) graph retrieval.
+
+ A sublinear graph traversal algorithm for memory retrieval over heterogeneous
+ graphs with multiple edge types (semantic, temporal, causal, entity).
+
+ Combines meta-path patterns from the heterogeneous information network (HIN)
+ literature with the Forward Push local propagation used in approximate
+ personalized PageRank (PPR).
+
+ Key properties:
+ - Sublinear in graph size (threshold pruning bounds the active node set)
+ - Predefined patterns capture different retrieval intents
+ - All patterns run in parallel, results fused via RRF
+ - No LLM in the loop during traversal
+ """
+
+ import asyncio
+ import logging
+ from dataclasses import dataclass, field
+ from typing import List, Dict, Optional, Tuple
+ from collections import defaultdict
+
+ from .types import RetrievalResult
+ from .graph_retrieval import GraphRetriever
+ from ..db_utils import acquire_with_retry
+
+ logger = logging.getLogger(__name__)
+
+
+ # -----------------------------------------------------------------------------
+ # Data Classes
+ # -----------------------------------------------------------------------------
+
+ @dataclass
+ class EdgeTarget:
+     """A neighbor node with its edge weight."""
+     node_id: str
+     weight: float
+
+
+ @dataclass
+ class TypedAdjacency:
+     """Adjacency lists split by edge type."""
+     # edge_type -> from_node_id -> list of (to_node_id, weight)
+     graphs: Dict[str, Dict[str, List[EdgeTarget]]] = field(default_factory=dict)
+
+     def get_neighbors(self, edge_type: str, node_id: str) -> List[EdgeTarget]:
+         """Get neighbors for a node via a specific edge type."""
+         return self.graphs.get(edge_type, {}).get(node_id, [])
+
+     def get_normalized_neighbors(
+         self,
+         edge_type: str,
+         node_id: str,
+         top_k: int
+     ) -> List[EdgeTarget]:
+         """Get top-k neighbors with weights normalized to sum to 1."""
+         neighbors = self.get_neighbors(edge_type, node_id)[:top_k]
+         if not neighbors:
+             return []
+
+         total = sum(n.weight for n in neighbors)
+         if total == 0:
+             return []
+
+         return [
+             EdgeTarget(node_id=n.node_id, weight=n.weight / total)
+             for n in neighbors
+         ]
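+
+ # Worked example (illustrative, toy values): with two neighbors of raw weight
+ # 3.0 and 1.0, truncation to top_k happens first, then the kept weights are
+ # rescaled to sum to 1.
+ #
+ # >>> adj = TypedAdjacency(graphs={
+ # ...     'semantic': {'a': [EdgeTarget('b', 3.0), EdgeTarget('c', 1.0)]}
+ # ... })
+ # >>> adj.get_normalized_neighbors('semantic', 'a', top_k=2)
+ # [EdgeTarget(node_id='b', weight=0.75), EdgeTarget(node_id='c', weight=0.25)]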
+
+
+ @dataclass
+ class PatternResult:
+     """Result from a single pattern traversal."""
+     pattern: List[str]
+     scores: Dict[str, float]  # node_id -> accumulated mass
+
+
+ @dataclass
+ class MPFPConfig:
+     """Configuration for MPFP algorithm."""
+     alpha: float = 0.15  # fraction of mass kept at each visited node (teleport probability in PPR terms)
+     threshold: float = 1e-6  # mass pruning threshold (lower = explore more)
+     top_k_neighbors: int = 20  # fan-out limit per node
+
+     # Patterns from semantic seeds
+     patterns_semantic: List[List[str]] = field(default_factory=lambda: [
+         ['semantic', 'semantic'],   # topic expansion
+         ['entity', 'temporal'],     # entity timeline
+         ['semantic', 'causes'],     # reasoning chains (forward)
+         ['semantic', 'caused_by'],  # reasoning chains (backward)
+         ['entity', 'semantic'],     # entity context
+     ])
+
+     # Patterns from temporal seeds
+     patterns_temporal: List[List[str]] = field(default_factory=lambda: [
+         ['temporal', 'semantic'],   # what was happening then
+         ['temporal', 'entity'],     # who was involved then
+     ])
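+
+ # Illustrative sketch (hypothetical values, not shipped defaults): patterns
+ # are plain lists of edge-type strings, so a caller can bias the traversal,
+ # e.g. toward causal chains only.
+ #
+ # >>> cfg = MPFPConfig(
+ # ...     alpha=0.2,
+ # ...     patterns_semantic=[['semantic', 'causes'], ['semantic', 'caused_by']],
+ # ... )
+ # >>> cfg.threshold  # untouched fields keep their defaults
+ # 1e-06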
+
+
+ @dataclass
+ class SeedNode:
+     """An entry point node with its initial score."""
+     node_id: str
+     score: float  # initial mass (e.g., similarity score)
+
+
+ # -----------------------------------------------------------------------------
+ # Core Algorithm
+ # -----------------------------------------------------------------------------
+
+ def mpfp_traverse(
+     seeds: List[SeedNode],
+     pattern: List[str],
+     adjacency: TypedAdjacency,
+     config: MPFPConfig,
+ ) -> PatternResult:
+     """
+     Forward Push traversal following a meta-path pattern.
+
+     Args:
+         seeds: Entry point nodes with initial scores
+         pattern: Sequence of edge types to follow
+         adjacency: Typed adjacency structure
+         config: Algorithm parameters
+
+     Returns:
+         PatternResult with accumulated scores per node
+     """
+     if not seeds:
+         return PatternResult(pattern=pattern, scores={})
+
+     scores: Dict[str, float] = {}
+
+     # Initialize frontier with seed masses (normalized)
+     total_seed_score = sum(s.score for s in seeds)
+     if total_seed_score == 0:
+         total_seed_score = len(seeds)  # fallback to uniform
+
+     frontier: Dict[str, float] = {
+         s.node_id: s.score / total_seed_score for s in seeds
+     }
+
+     # Follow pattern hop by hop
+     for edge_type in pattern:
+         next_frontier: Dict[str, float] = {}
+
+         for node_id, mass in frontier.items():
+             if mass < config.threshold:
+                 continue
+
+             # Keep α portion for this node
+             scores[node_id] = scores.get(node_id, 0) + config.alpha * mass
+
+             # Push (1-α) to neighbors; if a node has no neighbors of this
+             # edge type, its push mass is simply dropped (mass may leak)
+             push_mass = (1 - config.alpha) * mass
+             neighbors = adjacency.get_normalized_neighbors(
+                 edge_type, node_id, config.top_k_neighbors
+             )
+
+             for neighbor in neighbors:
+                 next_frontier[neighbor.node_id] = (
+                     next_frontier.get(neighbor.node_id, 0) +
+                     push_mass * neighbor.weight
+                 )
+
+         frontier = next_frontier
+
+     # Final frontier nodes get their remaining mass
+     for node_id, mass in frontier.items():
+         if mass >= config.threshold:
+             scores[node_id] = scores.get(node_id, 0) + mass
+
+     return PatternResult(pattern=pattern, scores=scores)
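+
+ # Worked example (illustrative, toy graph): one seed 'a' whose only 'semantic'
+ # neighbor is 'b'. With the default alpha=0.15, 'a' keeps 0.15 of its unit
+ # mass and pushes 0.85 to 'b'; 'b' ends on the final frontier and keeps it all.
+ #
+ # >>> adj = TypedAdjacency(graphs={'semantic': {'a': [EdgeTarget('b', 1.0)]}})
+ # >>> mpfp_traverse([SeedNode('a', 1.0)], ['semantic'], adj, MPFPConfig()).scores
+ # {'a': 0.15, 'b': 0.85}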
+
+
+ def rrf_fusion(
+     results: List[PatternResult],
+     k: int = 60,
+     top_k: int = 50,
+ ) -> List[Tuple[str, float]]:
+     """
+     Reciprocal Rank Fusion to combine pattern results.
+
+     Args:
+         results: List of pattern results
+         k: RRF constant (higher = more uniform weighting)
+         top_k: Number of results to return
+
+     Returns:
+         List of (node_id, fused_score) tuples, sorted by score descending
+     """
+     fused: Dict[str, float] = {}
+
+     for result in results:
+         if not result.scores:
+             continue
+
+         # Rank nodes by their score in this pattern
+         ranked = sorted(
+             result.scores.keys(),
+             key=lambda n: result.scores[n],
+             reverse=True
+         )
+
+         for rank, node_id in enumerate(ranked):
+             fused[node_id] = fused.get(node_id, 0) + 1.0 / (k + rank + 1)
+
+     # Sort by fused score and return top-k
+     sorted_results = sorted(
+         fused.items(),
+         key=lambda x: x[1],
+         reverse=True
+     )
+
+     return sorted_results[:top_k]
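+
+ # Worked example (illustrative): rank 0 contributes 1/(60+1) and rank 1
+ # contributes 1/(60+2), so a node ranked first by two patterns ('y')
+ # outscores a node ranked first by only one ('x').
+ #
+ # >>> fused = rrf_fusion([
+ # ...     PatternResult(['semantic', 'semantic'], {'x': 0.9, 'y': 0.1}),
+ # ...     PatternResult(['entity', 'temporal'], {'y': 0.8}),
+ # ... ])
+ # >>> [node for node, _ in fused]
+ # ['y', 'x']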
+
+
+ # -----------------------------------------------------------------------------
+ # Database Loading
+ # -----------------------------------------------------------------------------
+
+ async def load_typed_adjacency(pool, bank_id: str) -> TypedAdjacency:
+     """
+     Load all edges for a bank, split by edge type.
+
+     Single query, then organize in-memory for fast traversal.
+     """
+     async with acquire_with_retry(pool) as conn:
+         rows = await conn.fetch(
+             """
+             SELECT ml.from_unit_id, ml.to_unit_id, ml.link_type, ml.weight
+             FROM memory_links ml
+             JOIN memory_units mu ON ml.from_unit_id = mu.id
+             WHERE mu.bank_id = $1
+               AND ml.weight >= 0.1
+             ORDER BY ml.from_unit_id, ml.weight DESC
+             """,
+             bank_id
+         )
+
+     graphs: Dict[str, Dict[str, List[EdgeTarget]]] = defaultdict(
+         lambda: defaultdict(list)
+     )
+
+     for row in rows:
+         from_id = str(row['from_unit_id'])
+         to_id = str(row['to_unit_id'])
+         link_type = row['link_type']
+         weight = row['weight']
+
+         graphs[link_type][from_id].append(
+             EdgeTarget(node_id=to_id, weight=weight)
+         )
+
+     return TypedAdjacency(graphs=dict(graphs))
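+
+ # Shape sketch (illustrative UUIDs): for a bank with one semantic and one
+ # causal edge out of the same unit, the returned structure looks like
+ #
+ #     {'semantic': {'<uuid-1>': [EdgeTarget(node_id='<uuid-2>', weight=0.9)]},
+ #      'causes':   {'<uuid-1>': [EdgeTarget(node_id='<uuid-3>', weight=0.4)]}}
+ #
+ # The ORDER BY above keeps each per-node list sorted by weight, which is what
+ # TypedAdjacency.get_normalized_neighbors relies on when slicing off top_k.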
+
+
+ async def fetch_memory_units_by_ids(
+     pool,
+     node_ids: List[str],
+     fact_type: str,
+ ) -> List[RetrievalResult]:
+     """Fetch full memory unit details for a list of node IDs."""
+     if not node_ids:
+         return []
+
+     async with acquire_with_retry(pool) as conn:
+         rows = await conn.fetch(
+             """
+             SELECT id, text, context, event_date, occurred_start, occurred_end,
+                    mentioned_at, access_count, embedding, fact_type, document_id, chunk_id
+             FROM memory_units
+             WHERE id = ANY($1::uuid[])
+               AND fact_type = $2
+             """,
+             node_ids,
+             fact_type
+         )
+
+     return [RetrievalResult.from_db_row(dict(r)) for r in rows]
+
+
+ # -----------------------------------------------------------------------------
+ # Graph Retriever Implementation
+ # -----------------------------------------------------------------------------
+
+ class MPFPGraphRetriever(GraphRetriever):
+     """
+     Graph retrieval using Meta-Path Forward Push.
+
+     Runs predefined patterns in parallel from semantic and temporal seeds,
+     then fuses results via RRF.
+     """
+
+     def __init__(self, config: Optional[MPFPConfig] = None):
+         """
+         Initialize MPFP retriever.
+
+         Args:
+             config: Algorithm configuration (uses defaults if None)
+         """
+         self.config = config or MPFPConfig()
+         # Reserved for per-bank adjacency caching; retrieve() does not consult
+         # it yet and reloads the adjacency on every call
+         self._adjacency_cache: Dict[str, TypedAdjacency] = {}
+
+     @property
+     def name(self) -> str:
+         return "mpfp"
+
+     async def retrieve(
+         self,
+         pool,
+         query_embedding_str: str,
+         bank_id: str,
+         fact_type: str,
+         budget: int,
+         query_text: Optional[str] = None,
+         semantic_seeds: Optional[List[RetrievalResult]] = None,
+         temporal_seeds: Optional[List[RetrievalResult]] = None,
+     ) -> List[RetrievalResult]:
+         """
+         Retrieve facts using MPFP algorithm.
+
+         Args:
+             pool: Database connection pool
+             query_embedding_str: Query embedding (used for fallback seed finding)
+             bank_id: Memory bank ID
+             fact_type: Fact type to filter
+             budget: Maximum results to return
+             query_text: Original query text (optional)
+             semantic_seeds: Pre-computed semantic entry points
+             temporal_seeds: Pre-computed temporal entry points
+
+         Returns:
+             List of RetrievalResult with activation scores
+         """
+         # Load typed adjacency (could cache per bank_id with TTL)
+         adjacency = await load_typed_adjacency(pool, bank_id)
+
+         # Convert seeds to SeedNode format
+         semantic_seed_nodes = self._convert_seeds(semantic_seeds, 'similarity')
+         temporal_seed_nodes = self._convert_seeds(temporal_seeds, 'temporal_score')
+
+         # If no semantic seeds provided, fall back to finding our own
+         if not semantic_seed_nodes:
+             semantic_seed_nodes = await self._find_semantic_seeds(
+                 pool, query_embedding_str, bank_id, fact_type
+             )
+
+         # Run all patterns in parallel; the traversal is pure-Python CPU work,
+         # so push it off the event loop via asyncio.to_thread
+         tasks = []
+
+         # Patterns from semantic seeds
+         for pattern in self.config.patterns_semantic:
+             if semantic_seed_nodes:
+                 tasks.append(
+                     asyncio.to_thread(
+                         mpfp_traverse,
+                         semantic_seed_nodes,
+                         pattern,
+                         adjacency,
+                         self.config,
+                     )
+                 )
+
+         # Patterns from temporal seeds
+         for pattern in self.config.patterns_temporal:
+             if temporal_seed_nodes:
+                 tasks.append(
+                     asyncio.to_thread(
+                         mpfp_traverse,
+                         temporal_seed_nodes,
+                         pattern,
+                         adjacency,
+                         self.config,
+                     )
+                 )
+
+         if not tasks:
+             return []
+
+         # Gather pattern results
+         pattern_results = await asyncio.gather(*tasks)
+
+         # Fuse results
+         fused = rrf_fusion(pattern_results, top_k=budget)
+
+         if not fused:
+             return []
+
+         # Take top result IDs (seeds are not excluded; they may be highly relevant)
+         result_ids = [node_id for node_id, score in fused][:budget]
+
+         # Fetch full details
+         results = await fetch_memory_units_by_ids(pool, result_ids, fact_type)
+
+         # Add activation scores from fusion
+         score_map = {node_id: score for node_id, score in fused}
+         for result in results:
+             result.activation = score_map.get(result.id, 0.0)
+
+         # Sort by activation
+         results.sort(key=lambda r: r.activation or 0, reverse=True)
+
+         return results
+
+     def _convert_seeds(
+         self,
+         seeds: Optional[List[RetrievalResult]],
+         score_attr: str,
+     ) -> List[SeedNode]:
+         """Convert RetrievalResult seeds to SeedNode format."""
+         if not seeds:
+             return []
+
+         result = []
+         for seed in seeds:
+             score = getattr(seed, score_attr, None)
+             if score is None:
+                 score = seed.activation or seed.similarity or 1.0
+             result.append(SeedNode(node_id=seed.id, score=score))
+
+         return result
+
+     async def _find_semantic_seeds(
+         self,
+         pool,
+         query_embedding_str: str,
+         bank_id: str,
+         fact_type: str,
+         limit: int = 20,
+         threshold: float = 0.3,
+     ) -> List[SeedNode]:
+         """Fallback: find semantic seeds via embedding search."""
+         async with acquire_with_retry(pool) as conn:
+             rows = await conn.fetch(
+                 """
+                 SELECT id, 1 - (embedding <=> $1::vector) AS similarity
+                 FROM memory_units
+                 WHERE bank_id = $2
+                   AND embedding IS NOT NULL
+                   AND fact_type = $3
+                   AND (1 - (embedding <=> $1::vector)) >= $4
+                 ORDER BY embedding <=> $1::vector
+                 LIMIT $5
+                 """,
+                 query_embedding_str, bank_id, fact_type, threshold, limit
+             )
+
+         return [
+             SeedNode(node_id=str(r['id']), score=r['similarity'])
+             for r in rows
+         ]
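+
+
+ # Usage sketch (illustrative; assumes `pool` is an asyncpg-style pool and the
+ # caller supplies a pgvector-formatted embedding string; neither is built here):
+ #
+ # async def example(pool, query_embedding_str: str) -> List[RetrievalResult]:
+ #     retriever = MPFPGraphRetriever(MPFPConfig())
+ #     return await retriever.retrieve(
+ #         pool,
+ #         query_embedding_str,
+ #         bank_id="...",    # caller's memory bank ID
+ #         fact_type="...",  # whatever fact types the schema defines
+ #         budget=20,        # max results after RRF fusion
+ #     )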