hindsight-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -9
- hindsight_api/alembic/env.py +5 -8
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
- hindsight_api/api/__init__.py +10 -10
- hindsight_api/api/http.py +575 -593
- hindsight_api/api/mcp.py +31 -33
- hindsight_api/banner.py +13 -6
- hindsight_api/config.py +17 -12
- hindsight_api/engine/__init__.py +9 -9
- hindsight_api/engine/cross_encoder.py +23 -27
- hindsight_api/engine/db_utils.py +5 -4
- hindsight_api/engine/embeddings.py +22 -21
- hindsight_api/engine/entity_resolver.py +81 -75
- hindsight_api/engine/llm_wrapper.py +74 -88
- hindsight_api/engine/memory_engine.py +663 -673
- hindsight_api/engine/query_analyzer.py +100 -97
- hindsight_api/engine/response_models.py +105 -106
- hindsight_api/engine/retain/__init__.py +9 -16
- hindsight_api/engine/retain/bank_utils.py +34 -58
- hindsight_api/engine/retain/chunk_storage.py +4 -12
- hindsight_api/engine/retain/deduplication.py +9 -28
- hindsight_api/engine/retain/embedding_processing.py +4 -11
- hindsight_api/engine/retain/embedding_utils.py +3 -4
- hindsight_api/engine/retain/entity_processing.py +7 -17
- hindsight_api/engine/retain/fact_extraction.py +155 -165
- hindsight_api/engine/retain/fact_storage.py +11 -23
- hindsight_api/engine/retain/link_creation.py +11 -39
- hindsight_api/engine/retain/link_utils.py +166 -95
- hindsight_api/engine/retain/observation_regeneration.py +39 -52
- hindsight_api/engine/retain/orchestrator.py +72 -62
- hindsight_api/engine/retain/types.py +49 -43
- hindsight_api/engine/search/__init__.py +15 -1
- hindsight_api/engine/search/fusion.py +6 -15
- hindsight_api/engine/search/graph_retrieval.py +234 -0
- hindsight_api/engine/search/mpfp_retrieval.py +438 -0
- hindsight_api/engine/search/observation_utils.py +9 -16
- hindsight_api/engine/search/reranking.py +4 -7
- hindsight_api/engine/search/retrieval.py +388 -193
- hindsight_api/engine/search/scoring.py +5 -7
- hindsight_api/engine/search/temporal_extraction.py +8 -11
- hindsight_api/engine/search/think_utils.py +115 -39
- hindsight_api/engine/search/trace.py +68 -38
- hindsight_api/engine/search/tracer.py +49 -35
- hindsight_api/engine/search/types.py +22 -16
- hindsight_api/engine/task_backend.py +21 -26
- hindsight_api/engine/utils.py +25 -10
- hindsight_api/main.py +21 -40
- hindsight_api/mcp_local.py +190 -0
- hindsight_api/metrics.py +44 -30
- hindsight_api/migrations.py +10 -8
- hindsight_api/models.py +60 -72
- hindsight_api/pg0.py +64 -337
- hindsight_api/server.py +3 -6
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +6 -5
- hindsight_api-0.1.6.dist-info/RECORD +64 -0
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.1.4.dist-info/RECORD +0 -61
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
hindsight_api/engine/search/mpfp_retrieval.py
@@ -0,0 +1,438 @@
+"""
+Meta-Path Forward Push (MPFP) graph retrieval.
+
+A sublinear graph traversal algorithm for memory retrieval over heterogeneous
+graphs with multiple edge types (semantic, temporal, causal, entity).
+
+Combines meta-path patterns from HIN literature with Forward Push local
+propagation from Approximate PPR.
+
+Key properties:
+- Sublinear in graph size (threshold pruning bounds active nodes)
+- Predefined patterns capture different retrieval intents
+- All patterns run in parallel, results fused via RRF
+- No LLM in the loop during traversal
+"""
+
+import asyncio
+import logging
+from collections import defaultdict
+from dataclasses import dataclass, field
+
+from ..db_utils import acquire_with_retry
+from .graph_retrieval import GraphRetriever
+from .types import RetrievalResult
+
+logger = logging.getLogger(__name__)
+
+
+# -----------------------------------------------------------------------------
+# Data Classes
+# -----------------------------------------------------------------------------
+
+
+@dataclass
+class EdgeTarget:
+    """A neighbor node with its edge weight."""
+
+    node_id: str
+    weight: float
+
+
+@dataclass
+class TypedAdjacency:
+    """Adjacency lists split by edge type."""
+
+    # edge_type -> from_node_id -> list of (to_node_id, weight)
+    graphs: dict[str, dict[str, list[EdgeTarget]]] = field(default_factory=dict)
+
+    def get_neighbors(self, edge_type: str, node_id: str) -> list[EdgeTarget]:
+        """Get neighbors for a node via a specific edge type."""
+        return self.graphs.get(edge_type, {}).get(node_id, [])
+
+    def get_normalized_neighbors(self, edge_type: str, node_id: str, top_k: int) -> list[EdgeTarget]:
+        """Get top-k neighbors with weights normalized to sum to 1."""
+        neighbors = self.get_neighbors(edge_type, node_id)[:top_k]
+        if not neighbors:
+            return []
+
+        total = sum(n.weight for n in neighbors)
+        if total == 0:
+            return []
+
+        return [EdgeTarget(node_id=n.node_id, weight=n.weight / total) for n in neighbors]
+
+
+@dataclass
+class PatternResult:
+    """Result from a single pattern traversal."""
+
+    pattern: list[str]
+    scores: dict[str, float]  # node_id -> accumulated mass
+
+
+@dataclass
+class MPFPConfig:
+    """Configuration for MPFP algorithm."""
+
+    alpha: float = 0.15  # teleport/keep probability
+    threshold: float = 1e-6  # mass pruning threshold (lower = explore more)
+    top_k_neighbors: int = 20  # fan-out limit per node
+
+    # Patterns from semantic seeds
+    patterns_semantic: list[list[str]] = field(
+        default_factory=lambda: [
+            ["semantic", "semantic"],  # topic expansion
+            ["entity", "temporal"],  # entity timeline
+            ["semantic", "causes"],  # reasoning chains (forward)
+            ["semantic", "caused_by"],  # reasoning chains (backward)
+            ["entity", "semantic"],  # entity context
+        ]
+    )
+
+    # Patterns from temporal seeds
+    patterns_temporal: list[list[str]] = field(
+        default_factory=lambda: [
+            ["temporal", "semantic"],  # what was happening then
+            ["temporal", "entity"],  # who was involved then
+        ]
+    )
+
+
+@dataclass
+class SeedNode:
+    """An entry point node with its initial score."""
+
+    node_id: str
+    score: float  # initial mass (e.g., similarity score)
+
+
+# -----------------------------------------------------------------------------
+# Core Algorithm
+# -----------------------------------------------------------------------------
+
+
+def mpfp_traverse(
+    seeds: list[SeedNode],
+    pattern: list[str],
+    adjacency: TypedAdjacency,
+    config: MPFPConfig,
+) -> PatternResult:
+    """
+    Forward Push traversal following a meta-path pattern.
+
+    Args:
+        seeds: Entry point nodes with initial scores
+        pattern: Sequence of edge types to follow
+        adjacency: Typed adjacency structure
+        config: Algorithm parameters
+
+    Returns:
+        PatternResult with accumulated scores per node
+    """
+    if not seeds:
+        return PatternResult(pattern=pattern, scores={})
+
+    scores: dict[str, float] = {}
+
+    # Initialize frontier with seed masses (normalized)
+    total_seed_score = sum(s.score for s in seeds)
+    if total_seed_score == 0:
+        total_seed_score = len(seeds)  # fallback to uniform
+
+    frontier: dict[str, float] = {s.node_id: s.score / total_seed_score for s in seeds}
+
+    # Follow pattern hop by hop
+    for edge_type in pattern:
+        next_frontier: dict[str, float] = {}
+
+        for node_id, mass in frontier.items():
+            if mass < config.threshold:
+                continue
+
+            # Keep α portion for this node
+            scores[node_id] = scores.get(node_id, 0) + config.alpha * mass
+
+            # Push (1-α) to neighbors
+            push_mass = (1 - config.alpha) * mass
+            neighbors = adjacency.get_normalized_neighbors(edge_type, node_id, config.top_k_neighbors)
+
+            for neighbor in neighbors:
+                next_frontier[neighbor.node_id] = next_frontier.get(neighbor.node_id, 0) + push_mass * neighbor.weight
+
+        frontier = next_frontier
+
+    # Final frontier nodes get their remaining mass
+    for node_id, mass in frontier.items():
+        if mass >= config.threshold:
+            scores[node_id] = scores.get(node_id, 0) + mass
+
+    return PatternResult(pattern=pattern, scores=scores)
+
+
+def rrf_fusion(
+    results: list[PatternResult],
+    k: int = 60,
+    top_k: int = 50,
+) -> list[tuple[str, float]]:
+    """
+    Reciprocal Rank Fusion to combine pattern results.
+
+    Args:
+        results: List of pattern results
+        k: RRF constant (higher = more uniform weighting)
+        top_k: Number of results to return
+
+    Returns:
+        List of (node_id, fused_score) tuples, sorted by score descending
+    """
+    fused: dict[str, float] = {}
+
+    for result in results:
+        if not result.scores:
+            continue
+
+        # Rank nodes by their score in this pattern
+        ranked = sorted(result.scores.keys(), key=lambda n: result.scores[n], reverse=True)
+
+        for rank, node_id in enumerate(ranked):
+            fused[node_id] = fused.get(node_id, 0) + 1.0 / (k + rank + 1)
+
+    # Sort by fused score and return top-k
+    sorted_results = sorted(fused.items(), key=lambda x: x[1], reverse=True)
+
+    return sorted_results[:top_k]
+
+
+# -----------------------------------------------------------------------------
+# Database Loading
+# -----------------------------------------------------------------------------
+
+
+async def load_typed_adjacency(pool, bank_id: str) -> TypedAdjacency:
+    """
+    Load all edges for a bank, split by edge type.
+
+    Single query, then organize in-memory for fast traversal.
+    """
+    async with acquire_with_retry(pool) as conn:
+        rows = await conn.fetch(
+            """
+            SELECT ml.from_unit_id, ml.to_unit_id, ml.link_type, ml.weight
+            FROM memory_links ml
+            JOIN memory_units mu ON ml.from_unit_id = mu.id
+            WHERE mu.bank_id = $1
+              AND ml.weight >= 0.1
+            ORDER BY ml.from_unit_id, ml.weight DESC
+            """,
+            bank_id,
+        )
+
+    graphs: dict[str, dict[str, list[EdgeTarget]]] = defaultdict(lambda: defaultdict(list))
+
+    for row in rows:
+        from_id = str(row["from_unit_id"])
+        to_id = str(row["to_unit_id"])
+        link_type = row["link_type"]
+        weight = row["weight"]
+
+        graphs[link_type][from_id].append(EdgeTarget(node_id=to_id, weight=weight))
+
+    return TypedAdjacency(graphs=dict(graphs))
+
+
+async def fetch_memory_units_by_ids(
+    pool,
+    node_ids: list[str],
+    fact_type: str,
+) -> list[RetrievalResult]:
+    """Fetch full memory unit details for a list of node IDs."""
+    if not node_ids:
+        return []
+
+    async with acquire_with_retry(pool) as conn:
+        rows = await conn.fetch(
+            """
+            SELECT id, text, context, event_date, occurred_start, occurred_end,
+                   mentioned_at, access_count, embedding, fact_type, document_id, chunk_id
+            FROM memory_units
+            WHERE id = ANY($1::uuid[])
+              AND fact_type = $2
+            """,
+            node_ids,
+            fact_type,
+        )
+
+    return [RetrievalResult.from_db_row(dict(r)) for r in rows]
+
+
+# -----------------------------------------------------------------------------
+# Graph Retriever Implementation
+# -----------------------------------------------------------------------------
+
+
+class MPFPGraphRetriever(GraphRetriever):
+    """
+    Graph retrieval using Meta-Path Forward Push.
+
+    Runs predefined patterns in parallel from semantic and temporal seeds,
+    then fuses results via RRF.
+    """
+
+    def __init__(self, config: MPFPConfig | None = None):
+        """
+        Initialize MPFP retriever.
+
+        Args:
+            config: Algorithm configuration (uses defaults if None)
+        """
+        self.config = config or MPFPConfig()
+        self._adjacency_cache: dict[str, TypedAdjacency] = {}
+
+    @property
+    def name(self) -> str:
+        return "mpfp"
+
+    async def retrieve(
+        self,
+        pool,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+        query_text: str | None = None,
+        semantic_seeds: list[RetrievalResult] | None = None,
+        temporal_seeds: list[RetrievalResult] | None = None,
+    ) -> list[RetrievalResult]:
+        """
+        Retrieve facts using MPFP algorithm.
+
+        Args:
+            pool: Database connection pool
+            query_embedding_str: Query embedding (used for fallback seed finding)
+            bank_id: Memory bank ID
+            fact_type: Fact type to filter
+            budget: Maximum results to return
+            query_text: Original query text (optional)
+            semantic_seeds: Pre-computed semantic entry points
+            temporal_seeds: Pre-computed temporal entry points
+
+        Returns:
+            List of RetrievalResult with activation scores
+        """
+        # Load typed adjacency (could cache per bank_id with TTL)
+        adjacency = await load_typed_adjacency(pool, bank_id)
+
+        # Convert seeds to SeedNode format
+        semantic_seed_nodes = self._convert_seeds(semantic_seeds, "similarity")
+        temporal_seed_nodes = self._convert_seeds(temporal_seeds, "temporal_score")
+
+        # If no semantic seeds provided, fall back to finding our own
+        if not semantic_seed_nodes:
+            semantic_seed_nodes = await self._find_semantic_seeds(pool, query_embedding_str, bank_id, fact_type)
+
+        # Run all patterns in parallel
+        tasks = []
+
+        # Patterns from semantic seeds
+        for pattern in self.config.patterns_semantic:
+            if semantic_seed_nodes:
+                tasks.append(
+                    asyncio.to_thread(
+                        mpfp_traverse,
+                        semantic_seed_nodes,
+                        pattern,
+                        adjacency,
+                        self.config,
+                    )
+                )
+
+        # Patterns from temporal seeds
+        for pattern in self.config.patterns_temporal:
+            if temporal_seed_nodes:
+                tasks.append(
+                    asyncio.to_thread(
+                        mpfp_traverse,
+                        temporal_seed_nodes,
+                        pattern,
+                        adjacency,
+                        self.config,
+                    )
+                )
+
+        if not tasks:
+            return []
+
+        # Gather pattern results
+        pattern_results = await asyncio.gather(*tasks)
+
+        # Fuse results
+        fused = rrf_fusion(pattern_results, top_k=budget)
+
+        if not fused:
+            return []
+
+        # Get top result IDs (don't exclude seeds - they may be highly relevant)
+        result_ids = [node_id for node_id, score in fused][:budget]
+
+        # Fetch full details
+        results = await fetch_memory_units_by_ids(pool, result_ids, fact_type)
+
+        # Add activation scores from fusion
+        score_map = {node_id: score for node_id, score in fused}
+        for result in results:
+            result.activation = score_map.get(result.id, 0.0)
+
+        # Sort by activation
+        results.sort(key=lambda r: r.activation or 0, reverse=True)
+
+        return results
+
+    def _convert_seeds(
+        self,
+        seeds: list[RetrievalResult] | None,
+        score_attr: str,
+    ) -> list[SeedNode]:
+        """Convert RetrievalResult seeds to SeedNode format."""
+        if not seeds:
+            return []
+
+        result = []
+        for seed in seeds:
+            score = getattr(seed, score_attr, None)
+            if score is None:
+                score = seed.activation or seed.similarity or 1.0
+            result.append(SeedNode(node_id=seed.id, score=score))
+
+        return result
+
+    async def _find_semantic_seeds(
+        self,
+        pool,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        limit: int = 20,
+        threshold: float = 0.3,
+    ) -> list[SeedNode]:
+        """Fallback: find semantic seeds via embedding search."""
+        async with acquire_with_retry(pool) as conn:
+            rows = await conn.fetch(
+                """
+                SELECT id, 1 - (embedding <=> $1::vector) AS similarity
+                FROM memory_units
+                WHERE bank_id = $2
+                  AND embedding IS NOT NULL
+                  AND fact_type = $3
+                  AND (1 - (embedding <=> $1::vector)) >= $4
+                ORDER BY embedding <=> $1::vector
+                LIMIT $5
+                """,
+                query_embedding_str,
+                bank_id,
+                fact_type,
+                threshold,
+                limit,
+            )
+
+        return [SeedNode(node_id=str(r["id"]), score=r["similarity"]) for r in rows]
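Note (not part of the diff): the new module above composes three pieces — per-pattern Forward Push traversal, RRF fusion, and a retriever class that wires them to the database. The following is a minimal sketch of how mpfp_traverse and rrf_fusion fit together on a hand-built toy graph; the node IDs, edge types, and weights are illustrative assumptions, and the import assumes the package is installed.

# Illustrative sketch only — toy data, not taken from the package.
from hindsight_api.engine.search.mpfp_retrieval import (
    EdgeTarget,
    MPFPConfig,
    SeedNode,
    TypedAdjacency,
    mpfp_traverse,
    rrf_fusion,
)

# Two edge types over three hypothetical nodes "a", "b", "c".
adjacency = TypedAdjacency(
    graphs={
        "semantic": {"a": [EdgeTarget(node_id="b", weight=0.9), EdgeTarget(node_id="c", weight=0.4)]},
        "entity": {"b": [EdgeTarget(node_id="c", weight=1.0)]},
    }
)

config = MPFPConfig()  # defaults: alpha=0.15, threshold=1e-6, top_k_neighbors=20
seeds = [SeedNode(node_id="a", score=1.0)]

# Each meta-path pattern is pushed independently from the same seeds...
per_pattern = [
    mpfp_traverse(seeds, ["semantic", "entity"], adjacency, config),
    mpfp_traverse(seeds, ["semantic", "semantic"], adjacency, config),
]

# ...and the per-pattern score maps are combined by reciprocal rank fusion.
for node_id, fused_score in rrf_fusion(per_pattern, top_k=5):
    print(node_id, round(fused_score, 4))

In the packaged retriever, MPFPGraphRetriever.retrieve does the same thing with seeds coming from embedding/temporal search and adjacency loaded from the memory_links table, fanning patterns out with asyncio.to_thread and asyncio.gather.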
hindsight_api/engine/search/observation_utils.py
@@ -6,7 +6,7 @@ about an entity, without personality influence.
 """
 
 import logging
-
+
 from pydantic import BaseModel, Field
 
 from ..response_models import MemoryFact
@@ -16,18 +16,17 @@ logger = logging.getLogger(__name__)
 
 class Observation(BaseModel):
     """An observation about an entity."""
+
     observation: str = Field(description="The observation text - a factual statement about the entity")
 
 
 class ObservationExtractionResponse(BaseModel):
     """Response containing extracted observations."""
-    observations: List[Observation] = Field(
-        default_factory=list,
-        description="List of observations about the entity"
-    )
 
+    observations: list[Observation] = Field(default_factory=list, description="List of observations about the entity")
 
-def format_facts_for_observation_prompt(facts: List[MemoryFact]) -> str:
+
+def format_facts_for_observation_prompt(facts: list[MemoryFact]) -> str:
     """Format facts as text for observation extraction prompt."""
     import json
 
@@ -35,9 +34,7 @@ def format_facts_for_observation_prompt(facts: List[MemoryFact]) -> str:
         return "[]"
     formatted = []
     for fact in facts:
-        fact_obj = {
-            "text": fact.text
-        }
+        fact_obj = {"text": fact.text}
 
         # Add context if available
         if fact.context:
@@ -92,11 +89,7 @@ def get_observation_system_message() -> str:
     return "You are an objective observer synthesizing facts about an entity. Generate clear, factual observations without opinions or personality influence. Be concise and accurate."
 
 
-async def extract_observations_from_facts(
-    llm_config,
-    entity_name: str,
-    facts: List[MemoryFact]
-) -> List[str]:
+async def extract_observations_from_facts(llm_config, entity_name: str, facts: list[MemoryFact]) -> list[str]:
     """
     Extract observations from facts about an entity using LLM.
 
@@ -118,10 +111,10 @@ async def extract_observations_from_facts(
     result = await llm_config.call(
         messages=[
             {"role": "system", "content": get_observation_system_message()},
-            {"role": "user", "content": prompt}
+            {"role": "user", "content": prompt},
         ],
         response_format=ObservationExtractionResponse,
-        scope="memory_extract_observation"
+        scope="memory_extract_observation",
    )
 
     observations = [op.observation for op in result.observations]
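Note (not part of the diff): the observation_utils.py changes are a typing and formatting modernization (typing.List replaced by the built-in list, multi-line definitions collapsed, trailing commas added); behavior is unchanged. As a small sketch of what the reshaped response model accepts — assuming Pydantic v2 (model_validate) and a made-up payload:

from pydantic import BaseModel, Field

# Mirrors the models in observation_utils.py after the change.
class Observation(BaseModel):
    observation: str = Field(description="The observation text - a factual statement about the entity")

class ObservationExtractionResponse(BaseModel):
    observations: list[Observation] = Field(default_factory=list, description="List of observations about the entity")

# Hypothetical structured-output payload returned by the LLM call.
payload = {"observations": [{"observation": "The entity was founded in 2019."}]}
parsed = ObservationExtractionResponse.model_validate(payload)
print([o.observation for o in parsed.observations])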
hindsight_api/engine/search/reranking.py
@@ -2,7 +2,6 @@
 Cross-encoder neural reranking for search results.
 """
 
-from typing import List
 from .types import MergedCandidate, ScoredResult
 
 
@@ -24,14 +23,11 @@ class CrossEncoderReranker:
         """
         if cross_encoder is None:
             from hindsight_api.engine.cross_encoder import create_cross_encoder_from_env
+
             cross_encoder = create_cross_encoder_from_env()
         self.cross_encoder = cross_encoder
 
-    def rerank(
-        self,
-        query: str,
-        candidates: List[MergedCandidate]
-    ) -> List[ScoredResult]:
+    def rerank(self, query: str, candidates: list[MergedCandidate]) -> list[ScoredResult]:
         """
         Rerank candidates using cross-encoder scores.
 
@@ -77,6 +73,7 @@ class CrossEncoderReranker:
         # Normalize scores using sigmoid to [0, 1] range
         # Cross-encoder returns logits which can be negative
         import numpy as np
+
         def sigmoid(x):
             return 1 / (1 + np.exp(-x))
 
@@ -89,7 +86,7 @@ class CrossEncoderReranker:
                 candidate=candidate,
                 cross_encoder_score=float(raw_score),
                 cross_encoder_score_normalized=float(norm_score),
-                weight=float(norm_score)  # Initial weight is just cross-encoder score
+                weight=float(norm_score),  # Initial weight is just cross-encoder score
             )
             scored_results.append(scored_result)
 
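Note (not part of the diff): the reranking.py hunks keep the scoring behavior intact — raw cross-encoder logits are still squashed into [0, 1] with a sigmoid before being stored as the initial weight. A self-contained sketch of that normalization (the logit values are made up):

import numpy as np

def sigmoid(x):
    # Cross-encoder logits can be negative; sigmoid maps them into [0, 1].
    return 1 / (1 + np.exp(-x))

raw_scores = np.array([-2.3, 0.1, 4.7])  # hypothetical cross-encoder logits
for raw, norm in zip(raw_scores, sigmoid(raw_scores)):
    print(f"logit={raw:+.2f} -> weight={norm:.3f}")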