hindsight-api 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- hindsight_api/__init__.py +10 -9
- hindsight_api/alembic/env.py +5 -8
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
- hindsight_api/api/__init__.py +10 -10
- hindsight_api/api/http.py +575 -593
- hindsight_api/api/mcp.py +30 -28
- hindsight_api/banner.py +13 -6
- hindsight_api/config.py +9 -13
- hindsight_api/engine/__init__.py +9 -9
- hindsight_api/engine/cross_encoder.py +22 -21
- hindsight_api/engine/db_utils.py +5 -4
- hindsight_api/engine/embeddings.py +22 -21
- hindsight_api/engine/entity_resolver.py +81 -75
- hindsight_api/engine/llm_wrapper.py +61 -79
- hindsight_api/engine/memory_engine.py +603 -625
- hindsight_api/engine/query_analyzer.py +100 -97
- hindsight_api/engine/response_models.py +105 -106
- hindsight_api/engine/retain/__init__.py +9 -16
- hindsight_api/engine/retain/bank_utils.py +34 -58
- hindsight_api/engine/retain/chunk_storage.py +4 -12
- hindsight_api/engine/retain/deduplication.py +9 -28
- hindsight_api/engine/retain/embedding_processing.py +4 -11
- hindsight_api/engine/retain/embedding_utils.py +3 -4
- hindsight_api/engine/retain/entity_processing.py +7 -17
- hindsight_api/engine/retain/fact_extraction.py +155 -165
- hindsight_api/engine/retain/fact_storage.py +11 -23
- hindsight_api/engine/retain/link_creation.py +11 -39
- hindsight_api/engine/retain/link_utils.py +166 -95
- hindsight_api/engine/retain/observation_regeneration.py +39 -52
- hindsight_api/engine/retain/orchestrator.py +72 -62
- hindsight_api/engine/retain/types.py +49 -43
- hindsight_api/engine/search/__init__.py +5 -5
- hindsight_api/engine/search/fusion.py +6 -15
- hindsight_api/engine/search/graph_retrieval.py +22 -23
- hindsight_api/engine/search/mpfp_retrieval.py +76 -92
- hindsight_api/engine/search/observation_utils.py +9 -16
- hindsight_api/engine/search/reranking.py +4 -7
- hindsight_api/engine/search/retrieval.py +87 -66
- hindsight_api/engine/search/scoring.py +5 -7
- hindsight_api/engine/search/temporal_extraction.py +8 -11
- hindsight_api/engine/search/think_utils.py +115 -39
- hindsight_api/engine/search/trace.py +68 -39
- hindsight_api/engine/search/tracer.py +44 -35
- hindsight_api/engine/search/types.py +20 -17
- hindsight_api/engine/task_backend.py +21 -26
- hindsight_api/engine/utils.py +25 -10
- hindsight_api/main.py +21 -40
- hindsight_api/mcp_local.py +190 -0
- hindsight_api/metrics.py +44 -30
- hindsight_api/migrations.py +10 -8
- hindsight_api/models.py +60 -72
- hindsight_api/pg0.py +22 -23
- hindsight_api/server.py +3 -6
- hindsight_api-0.1.7.dist-info/METADATA +178 -0
- hindsight_api-0.1.7.dist-info/RECORD +64 -0
- {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.7.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.1.5.dist-info/METADATA +0 -42
- hindsight_api-0.1.5.dist-info/RECORD +0 -63
- {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.7.dist-info}/WHEEL +0 -0
`hindsight_api/engine/search/mpfp_retrieval.py`

(Deleted lines that the upstream diff viewer truncated are reproduced below as rendered, not reconstructed.)

```diff
@@ -16,13 +16,12 @@ Key properties:
 
 import asyncio
 import logging
-from dataclasses import dataclass, field
-from typing import List, Dict, Optional, Tuple
 from collections import defaultdict
+from dataclasses import dataclass, field
 
-from .types import RetrievalResult
-from .graph_retrieval import GraphRetriever
 from ..db_utils import acquire_with_retry
+from .graph_retrieval import GraphRetriever
+from .types import RetrievalResult
 
 logger = logging.getLogger(__name__)
 
@@ -31,9 +30,11 @@ logger = logging.getLogger(__name__)
 # Data Classes
 # -----------------------------------------------------------------------------
 
+
 @dataclass
 class EdgeTarget:
     """A neighbor node with its edge weight."""
+
     node_id: str
     weight: float
 
@@ -41,19 +42,15 @@ class EdgeTarget:
 @dataclass
 class TypedAdjacency:
     """Adjacency lists split by edge type."""
+
     # edge_type -> from_node_id -> list of (to_node_id, weight)
-    graphs:
+    graphs: dict[str, dict[str, list[EdgeTarget]]] = field(default_factory=dict)
 
-    def get_neighbors(self, edge_type: str, node_id: str) ->
+    def get_neighbors(self, edge_type: str, node_id: str) -> list[EdgeTarget]:
         """Get neighbors for a node via a specific edge type."""
         return self.graphs.get(edge_type, {}).get(node_id, [])
 
-    def get_normalized_neighbors(
-        self,
-        edge_type: str,
-        node_id: str,
-        top_k: int
-    ) -> List[EdgeTarget]:
+    def get_normalized_neighbors(self, edge_type: str, node_id: str, top_k: int) -> list[EdgeTarget]:
         """Get top-k neighbors with weights normalized to sum to 1."""
         neighbors = self.get_neighbors(edge_type, node_id)[:top_k]
         if not neighbors:
@@ -63,45 +60,49 @@ class TypedAdjacency:
         if total == 0:
             return []
 
-        return [
-            EdgeTarget(node_id=n.node_id, weight=n.weight / total)
-            for n in neighbors
-        ]
+        return [EdgeTarget(node_id=n.node_id, weight=n.weight / total) for n in neighbors]
 
 
 @dataclass
 class PatternResult:
     """Result from a single pattern traversal."""
-
-
+
+    pattern: list[str]
+    scores: dict[str, float]  # node_id -> accumulated mass
 
 
 @dataclass
 class MPFPConfig:
     """Configuration for MPFP algorithm."""
-
-
-
+
+    alpha: float = 0.15  # teleport/keep probability
+    threshold: float = 1e-6  # mass pruning threshold (lower = explore more)
+    top_k_neighbors: int = 20  # fan-out limit per node
 
     # Patterns from semantic seeds
-    patterns_semantic:
-    [
-
-
-
-
-
+    patterns_semantic: list[list[str]] = field(
+        default_factory=lambda: [
+            ["semantic", "semantic"],  # topic expansion
+            ["entity", "temporal"],  # entity timeline
+            ["semantic", "causes"],  # reasoning chains (forward)
+            ["semantic", "caused_by"],  # reasoning chains (backward)
+            ["entity", "semantic"],  # entity context
+        ]
+    )
 
     # Patterns from temporal seeds
-    patterns_temporal:
-    [
-
-
+    patterns_temporal: list[list[str]] = field(
+        default_factory=lambda: [
+            ["temporal", "semantic"],  # what was happening then
+            ["temporal", "entity"],  # who was involved then
+        ]
+    )
 
 
 @dataclass
 class SeedNode:
     """An entry point node with its initial score."""
+
     node_id: str
     score: float  # initial mass (e.g., similarity score)
 
@@ -110,9 +111,10 @@ class SeedNode:
 # Core Algorithm
 # -----------------------------------------------------------------------------
 
+
 def mpfp_traverse(
-    seeds:
-    pattern:
+    seeds: list[SeedNode],
+    pattern: list[str],
     adjacency: TypedAdjacency,
     config: MPFPConfig,
 ) -> PatternResult:
@@ -131,20 +133,18 @@ def mpfp_traverse(
     if not seeds:
         return PatternResult(pattern=pattern, scores={})
 
-    scores:
+    scores: dict[str, float] = {}
 
     # Initialize frontier with seed masses (normalized)
     total_seed_score = sum(s.score for s in seeds)
     if total_seed_score == 0:
         total_seed_score = len(seeds)  # fallback to uniform
 
-    frontier:
-        s.node_id: s.score / total_seed_score for s in seeds
-    }
+    frontier: dict[str, float] = {s.node_id: s.score / total_seed_score for s in seeds}
 
     # Follow pattern hop by hop
     for edge_type in pattern:
-        next_frontier:
+        next_frontier: dict[str, float] = {}
 
         for node_id, mass in frontier.items():
             if mass < config.threshold:
@@ -155,15 +155,10 @@ def mpfp_traverse(
 
             # Push (1-α) to neighbors
             push_mass = (1 - config.alpha) * mass
-            neighbors = adjacency.get_normalized_neighbors(
-                edge_type, node_id, config.top_k_neighbors
-            )
+            neighbors = adjacency.get_normalized_neighbors(edge_type, node_id, config.top_k_neighbors)
 
             for neighbor in neighbors:
-                next_frontier[neighbor.node_id] = (
-                    next_frontier.get(neighbor.node_id, 0) +
-                    push_mass * neighbor.weight
-                )
+                next_frontier[neighbor.node_id] = next_frontier.get(neighbor.node_id, 0) + push_mass * neighbor.weight
 
         frontier = next_frontier
 
```
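Taken together, `mpfp_traverse` is a forward-push walk along a fixed edge-type pattern: at each hop a node keeps an `alpha` fraction of its mass and pushes the rest to its top-k weight-normalized neighbors, with `threshold` pruning negligible frontier entries. Below is a minimal standalone sketch of that loop on a toy two-hop graph. The dataclass shapes mirror the diff, but the toy edges are invented, and the keep/deposit bookkeeping is a guess at the parts of the function the hunks do not show.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class EdgeTarget:
    node_id: str
    weight: float


# Toy typed adjacency over four invented nodes.
graphs = {
    "semantic": {"a": [EdgeTarget("b", 0.6), EdgeTarget("c", 0.4)]},
    "entity": {"b": [EdgeTarget("d", 1.0)]},
}


def normalized_neighbors(edge_type: str, node_id: str, top_k: int = 20) -> list[EdgeTarget]:
    neighbors = graphs.get(edge_type, {}).get(node_id, [])[:top_k]
    total = sum(n.weight for n in neighbors)
    return [EdgeTarget(n.node_id, n.weight / total) for n in neighbors] if total else []


def traverse(seeds: dict[str, float], pattern: list[str], alpha: float = 0.15, threshold: float = 1e-6) -> dict[str, float]:
    scores: dict[str, float] = defaultdict(float)
    frontier = dict(seeds)  # node_id -> mass, assumed already normalized
    for edge_type in pattern:
        next_frontier: dict[str, float] = defaultdict(float)
        for node_id, mass in frontier.items():
            if mass < threshold:
                continue
            scores[node_id] += alpha * mass  # keep alpha at the current node (assumed)
            for n in normalized_neighbors(edge_type, node_id):
                next_frontier[n.node_id] += (1 - alpha) * mass * n.weight
        frontier = next_frontier
    for node_id, mass in frontier.items():  # deposit surviving mass (assumed)
        scores[node_id] += mass
    return dict(scores)


print(traverse({"a": 1.0}, ["semantic", "entity"]))
# ≈ {'a': 0.15, 'b': 0.0765, 'c': 0.051, 'd': 0.4335}
```

Mass flows `a -> {b, c}` on the semantic hop, then `b -> d` on the entity hop, so the two-hop target `d` ends up with the largest score.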
The same line joins and annotation fixes land in `rrf_fusion`:

```diff
@@ -176,10 +171,10 @@ def mpfp_traverse(
 
 
 def rrf_fusion(
-    results:
+    results: list[PatternResult],
     k: int = 60,
     top_k: int = 50,
-) ->
+) -> list[tuple[str, float]]:
     """
     Reciprocal Rank Fusion to combine pattern results.
 
@@ -191,28 +186,20 @@ def rrf_fusion(
     Returns:
         List of (node_id, fused_score) tuples, sorted by score descending
     """
-    fused:
+    fused: dict[str, float] = {}
 
     for result in results:
         if not result.scores:
             continue
 
         # Rank nodes by their score in this pattern
-        ranked = sorted(
-            result.scores.keys(),
-            key=lambda n: result.scores[n],
-            reverse=True
-        )
+        ranked = sorted(result.scores.keys(), key=lambda n: result.scores[n], reverse=True)
 
         for rank, node_id in enumerate(ranked):
             fused[node_id] = fused.get(node_id, 0) + 1.0 / (k + rank + 1)
 
     # Sort by fused score and return top-k
-    sorted_results = sorted(
-        fused.items(),
-        key=lambda x: x[1],
-        reverse=True
-    )
+    sorted_results = sorted(fused.items(), key=lambda x: x[1], reverse=True)
 
     return sorted_results[:top_k]
 
```
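The RRF arithmetic is easy to verify by hand: each pattern contributes `1 / (k + rank + 1)` per node, so with the default `k=60` a first-place vote is worth 1/61, and consensus across patterns beats a single high rank. A small self-contained check (node names and scores invented):

```python
def rrf(rankings: list[dict[str, float]], k: int = 60) -> list[tuple[str, float]]:
    fused: dict[str, float] = {}
    for scores in rankings:
        ranked = sorted(scores, key=scores.get, reverse=True)
        for rank, node in enumerate(ranked):
            fused[node] = fused.get(node, 0) + 1.0 / (k + rank + 1)
    return sorted(fused.items(), key=lambda x: x[1], reverse=True)


pattern_a = {"n1": 0.9, "n2": 0.5}  # ranks: n1 -> 0, n2 -> 1
pattern_b = {"n2": 0.8, "n3": 0.1}  # ranks: n2 -> 0, n3 -> 1

print(rrf([pattern_a, pattern_b]))
# n2 wins with 1/62 + 1/61 ≈ 0.0325; n1 and n3 each get a single
# vote (1/61 ≈ 0.0164 and 1/62 ≈ 0.0161) despite n1's top rank in pattern_a.
```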
The remaining hunks in this module cover database loading and the `MPFPGraphRetriever` class:

```diff
@@ -221,6 +208,7 @@ def rrf_fusion(
 # Database Loading
 # -----------------------------------------------------------------------------
 
+
 async def load_typed_adjacency(pool, bank_id: str) -> TypedAdjacency:
     """
     Load all edges for a bank, split by edge type.
@@ -237,31 +225,27 @@ async def load_typed_adjacency(pool, bank_id: str) -> TypedAdjacency:
             AND ml.weight >= 0.1
         ORDER BY ml.from_unit_id, ml.weight DESC
         """,
-        bank_id
+        bank_id,
     )
 
-    graphs:
-        lambda: defaultdict(list)
-    )
+    graphs: dict[str, dict[str, list[EdgeTarget]]] = defaultdict(lambda: defaultdict(list))
 
     for row in rows:
-        from_id = str(row[
-        to_id = str(row[
-        link_type = row[
-        weight = row[
+        from_id = str(row["from_unit_id"])
+        to_id = str(row["to_unit_id"])
+        link_type = row["link_type"]
+        weight = row["weight"]
 
-        graphs[link_type][from_id].append(
-            EdgeTarget(node_id=to_id, weight=weight)
-        )
+        graphs[link_type][from_id].append(EdgeTarget(node_id=to_id, weight=weight))
 
     return TypedAdjacency(graphs=dict(graphs))
 
 
 async def fetch_memory_units_by_ids(
     pool,
-    node_ids:
+    node_ids: list[str],
     fact_type: str,
-) ->
+) -> list[RetrievalResult]:
     """Fetch full memory unit details for a list of node IDs."""
     if not node_ids:
         return []
@@ -276,7 +260,7 @@ async def fetch_memory_units_by_ids(
             AND fact_type = $2
         """,
         node_ids,
-        fact_type
+        fact_type,
     )
 
     return [RetrievalResult.from_db_row(dict(r)) for r in rows]
@@ -286,6 +270,7 @@ async def fetch_memory_units_by_ids(
 # Graph Retriever Implementation
 # -----------------------------------------------------------------------------
 
+
 class MPFPGraphRetriever(GraphRetriever):
     """
     Graph retrieval using Meta-Path Forward Push.
@@ -294,7 +279,7 @@ class MPFPGraphRetriever(GraphRetriever):
     then fuses results via RRF.
     """
 
-    def __init__(self, config:
+    def __init__(self, config: MPFPConfig | None = None):
         """
         Initialize MPFP retriever.
 
@@ -302,7 +287,7 @@ class MPFPGraphRetriever(GraphRetriever):
             config: Algorithm configuration (uses defaults if None)
         """
         self.config = config or MPFPConfig()
-        self._adjacency_cache:
+        self._adjacency_cache: dict[str, TypedAdjacency] = {}
 
     @property
     def name(self) -> str:
@@ -315,10 +300,10 @@ class MPFPGraphRetriever(GraphRetriever):
         bank_id: str,
         fact_type: str,
         budget: int,
-        query_text:
-        semantic_seeds:
-        temporal_seeds:
-    ) ->
+        query_text: str | None = None,
+        semantic_seeds: list[RetrievalResult] | None = None,
+        temporal_seeds: list[RetrievalResult] | None = None,
+    ) -> list[RetrievalResult]:
         """
         Retrieve facts using MPFP algorithm.
 
@@ -339,14 +324,12 @@ class MPFPGraphRetriever(GraphRetriever):
         adjacency = await load_typed_adjacency(pool, bank_id)
 
         # Convert seeds to SeedNode format
-        semantic_seed_nodes = self._convert_seeds(semantic_seeds,
-        temporal_seed_nodes = self._convert_seeds(temporal_seeds,
+        semantic_seed_nodes = self._convert_seeds(semantic_seeds, "similarity")
+        temporal_seed_nodes = self._convert_seeds(temporal_seeds, "temporal_score")
 
         # If no semantic seeds provided, fall back to finding our own
         if not semantic_seed_nodes:
-            semantic_seed_nodes = await self._find_semantic_seeds(
-                pool, query_embedding_str, bank_id, fact_type
-            )
+            semantic_seed_nodes = await self._find_semantic_seeds(pool, query_embedding_str, bank_id, fact_type)
 
         # Run all patterns in parallel
         tasks = []
@@ -407,9 +390,9 @@ class MPFPGraphRetriever(GraphRetriever):
 
     def _convert_seeds(
         self,
-        seeds:
+        seeds: list[RetrievalResult] | None,
         score_attr: str,
-    ) ->
+    ) -> list[SeedNode]:
         """Convert RetrievalResult seeds to SeedNode format."""
         if not seeds:
             return []
@@ -431,7 +414,7 @@ class MPFPGraphRetriever(GraphRetriever):
         fact_type: str,
         limit: int = 20,
         threshold: float = 0.3,
-    ) ->
+    ) -> list[SeedNode]:
         """Fallback: find semantic seeds via embedding search."""
         async with acquire_with_retry(pool) as conn:
             rows = await conn.fetch(
@@ -445,10 +428,11 @@ class MPFPGraphRetriever(GraphRetriever):
                 ORDER BY embedding <=> $1::vector
                 LIMIT $5
                 """,
-                query_embedding_str,
+                query_embedding_str,
+                bank_id,
+                fact_type,
+                threshold,
+                limit,
             )
 
-        return [
-            SeedNode(node_id=str(r['id']), score=r['similarity'])
-            for r in rows
-        ]
+        return [SeedNode(node_id=str(r["id"]), score=r["similarity"]) for r in rows]
```
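Stepping back, nearly all of the churn in this module follows two mechanical patterns: `typing.List`/`Dict`/`Optional` annotations become PEP 585 builtin generics and PEP 604 `X | None` unions, and multi-line calls are joined onto one line with trailing commas added. This is consistent with a Ruff/Black-style reformat at a longer line length, though that is an inference from the hunks, not something the package states. In miniature (both functions below are illustrative, not from the package):

```python
# Before: Python 3.8-era typing imports, as in the removed lines above.
from typing import Dict, List, Optional


def get_neighbors_old(graphs: Dict[str, List[str]], node: Optional[str]) -> List[str]:
    return graphs.get(node or "", [])


# After: PEP 585 builtin generics (3.9+) and PEP 604 unions (3.10+).
def get_neighbors_new(graphs: dict[str, list[str]], node: str | None) -> list[str]:
    return graphs.get(node or "", [])
```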
`hindsight_api/engine/search/observation_utils.py`

```diff
@@ -6,7 +6,7 @@ about an entity, without personality influence.
 """
 
 import logging
-
+
 from pydantic import BaseModel, Field
 
 from ..response_models import MemoryFact
@@ -16,18 +16,17 @@ logger = logging.getLogger(__name__)
 
 class Observation(BaseModel):
     """An observation about an entity."""
+
     observation: str = Field(description="The observation text - a factual statement about the entity")
 
 
 class ObservationExtractionResponse(BaseModel):
     """Response containing extracted observations."""
-    observations: List[Observation] = Field(
-        default_factory=list,
-        description="List of observations about the entity"
-    )
 
+    observations: list[Observation] = Field(default_factory=list, description="List of observations about the entity")
 
-def format_facts_for_observation_prompt(facts: List[MemoryFact]) -> str:
+
+def format_facts_for_observation_prompt(facts: list[MemoryFact]) -> str:
     """Format facts as text for observation extraction prompt."""
     import json
 
@@ -35,9 +34,7 @@ def format_facts_for_observation_prompt(facts: List[MemoryFact]) -> str:
         return "[]"
     formatted = []
     for fact in facts:
-        fact_obj = {
-            "text": fact.text
-        }
+        fact_obj = {"text": fact.text}
 
         # Add context if available
         if fact.context:
@@ -92,11 +89,7 @@ def get_observation_system_message() -> str:
     return "You are an objective observer synthesizing facts about an entity. Generate clear, factual observations without opinions or personality influence. Be concise and accurate."
 
 
-async def extract_observations_from_facts(
-    llm_config,
-    entity_name: str,
-    facts: List[MemoryFact]
-) -> List[str]:
+async def extract_observations_from_facts(llm_config, entity_name: str, facts: list[MemoryFact]) -> list[str]:
     """
     Extract observations from facts about an entity using LLM.
 
@@ -118,10 +111,10 @@ async def extract_observations_from_facts(
     result = await llm_config.call(
         messages=[
             {"role": "system", "content": get_observation_system_message()},
-            {"role": "user", "content": prompt}
+            {"role": "user", "content": prompt},
         ],
         response_format=ObservationExtractionResponse,
-        scope="memory_extract_observation"
+        scope="memory_extract_observation",
    )
 
    observations = [op.observation for op in result.observations]
```
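The `response_format=ObservationExtractionResponse` argument implies the LLM wrapper parses the completion into the pydantic model before `result.observations` is read. A sketch of that contract, assuming pydantic v2 (`model_validate`; v1 would use `parse_obj`) and an invented payload:

```python
from pydantic import BaseModel, Field


class Observation(BaseModel):
    observation: str = Field(description="A factual statement about the entity")


class ObservationExtractionResponse(BaseModel):
    observations: list[Observation] = Field(default_factory=list)


# Hypothetical raw JSON from the model, validated into typed objects.
raw = {"observations": [{"observation": "Ada prefers morning meetings."}]}
result = ObservationExtractionResponse.model_validate(raw)
print([o.observation for o in result.observations])
# ['Ada prefers morning meetings.']
```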
`hindsight_api/engine/search/reranking.py`

```diff
@@ -2,7 +2,6 @@
 Cross-encoder neural reranking for search results.
 """
 
-from typing import List
 from .types import MergedCandidate, ScoredResult
 
 
@@ -24,14 +23,11 @@ class CrossEncoderReranker:
         """
         if cross_encoder is None:
             from hindsight_api.engine.cross_encoder import create_cross_encoder_from_env
+
             cross_encoder = create_cross_encoder_from_env()
         self.cross_encoder = cross_encoder
 
-    def rerank(
-        self,
-        query: str,
-        candidates: List[MergedCandidate]
-    ) -> List[ScoredResult]:
+    def rerank(self, query: str, candidates: list[MergedCandidate]) -> list[ScoredResult]:
         """
         Rerank candidates using cross-encoder scores.
 
@@ -77,6 +73,7 @@ class CrossEncoderReranker:
         # Normalize scores using sigmoid to [0, 1] range
         # Cross-encoder returns logits which can be negative
         import numpy as np
+
         def sigmoid(x):
             return 1 / (1 + np.exp(-x))
 
@@ -89,7 +86,7 @@ class CrossEncoderReranker:
                 candidate=candidate,
                 cross_encoder_score=float(raw_score),
                 cross_encoder_score_normalized=float(norm_score),
-                weight=float(norm_score)  # Initial weight is just cross-encoder score
+                weight=float(norm_score),  # Initial weight is just cross-encoder score
             )
             scored_results.append(scored_result)
 
```
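The sigmoid normalization touched in the last two hunks is worth a quick numeric check: it maps raw cross-encoder logits into [0, 1] monotonically, so relative ranking is unchanged while the downstream `weight` field gets a bounded score. With invented logits:

```python
import numpy as np


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


# Invented logits: strong match, borderline, clear non-match.
logits = np.array([4.2, 0.3, -2.8])
print(sigmoid(logits).round(3))
# [0.985 0.574 0.057]
```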