hindsight-api 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/api/mcp.py +1 -5
- hindsight_api/config.py +9 -0
- hindsight_api/engine/cross_encoder.py +1 -6
- hindsight_api/engine/llm_wrapper.py +33 -15
- hindsight_api/engine/memory_engine.py +71 -59
- hindsight_api/engine/search/__init__.py +15 -1
- hindsight_api/engine/search/graph_retrieval.py +235 -0
- hindsight_api/engine/search/mpfp_retrieval.py +454 -0
- hindsight_api/engine/search/retrieval.py +337 -163
- hindsight_api/engine/search/trace.py +1 -0
- hindsight_api/engine/search/tracer.py +8 -3
- hindsight_api/engine/search/types.py +4 -1
- hindsight_api/pg0.py +54 -326
- {hindsight_api-0.1.3.dist-info → hindsight_api-0.1.5.dist-info}/METADATA +6 -5
- {hindsight_api-0.1.3.dist-info → hindsight_api-0.1.5.dist-info}/RECORD +17 -15
- {hindsight_api-0.1.3.dist-info → hindsight_api-0.1.5.dist-info}/WHEEL +0 -0
- {hindsight_api-0.1.3.dist-info → hindsight_api-0.1.5.dist-info}/entry_points.txt +0 -0
hindsight_api/api/mcp.py
CHANGED

@@ -121,11 +121,7 @@ class MCPMiddleware:
         self.app = app
         self.memory = memory
         self.mcp_server = create_mcp_server(memory)
-
-        import warnings
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", DeprecationWarning)
-            self.mcp_app = self.mcp_server.sse_app()
+        self.mcp_app = self.mcp_server.http_app()

     async def __call__(self, scope, receive, send):
         if scope["type"] != "http":
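The removed sse_app() call was FastMCP's deprecated SSE transport; http_app() serves the newer streamable HTTP transport instead, so the DeprecationWarning suppression is no longer needed. A minimal sketch of how such a middleware can delegate to the mounted MCP ASGI app (the "/mcp" path and surrounding wiring are assumptions for illustration, not taken from this diff):

    # Sketch: route MCP traffic to the ASGI app returned by http_app().
    class MCPMiddlewareSketch:
        def __init__(self, app, mcp_app):
            self.app = app          # the wrapped ASGI application
            self.mcp_app = mcp_app  # e.g. mcp_server.http_app()

        async def __call__(self, scope, receive, send):
            if scope["type"] == "http" and scope["path"].startswith("/mcp"):
                await self.mcp_app(scope, receive, send)  # delegate to MCP
            else:
                await self.app(scope, receive, send)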
hindsight_api/config.py
CHANGED

@@ -29,6 +29,7 @@ ENV_HOST = "HINDSIGHT_API_HOST"
 ENV_PORT = "HINDSIGHT_API_PORT"
 ENV_LOG_LEVEL = "HINDSIGHT_API_LOG_LEVEL"
 ENV_MCP_ENABLED = "HINDSIGHT_API_MCP_ENABLED"
+ENV_GRAPH_RETRIEVER = "HINDSIGHT_API_GRAPH_RETRIEVER"

 # Default values
 DEFAULT_DATABASE_URL = "pg0"
@@ -45,6 +46,7 @@ DEFAULT_HOST = "0.0.0.0"
 DEFAULT_PORT = 8888
 DEFAULT_LOG_LEVEL = "info"
 DEFAULT_MCP_ENABLED = True
+DEFAULT_GRAPH_RETRIEVER = "bfs"  # Options: "bfs", "mpfp"

 # Required embedding dimension for database schema
 EMBEDDING_DIMENSION = 384
@@ -79,6 +81,9 @@ class HindsightConfig:
     log_level: str
     mcp_enabled: bool

+    # Recall
+    graph_retriever: str
+
     @classmethod
     def from_env(cls) -> "HindsightConfig":
         """Create configuration from environment variables."""
@@ -107,6 +112,9 @@ class HindsightConfig:
             port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),
             log_level=os.getenv(ENV_LOG_LEVEL, DEFAULT_LOG_LEVEL),
             mcp_enabled=os.getenv(ENV_MCP_ENABLED, str(DEFAULT_MCP_ENABLED)).lower() == "true",
+
+            # Recall
+            graph_retriever=os.getenv(ENV_GRAPH_RETRIEVER, DEFAULT_GRAPH_RETRIEVER),
         )

     def get_llm_base_url(self) -> str:
@@ -147,6 +155,7 @@ class HindsightConfig:
         logger.info(f"LLM: provider={self.llm_provider}, model={self.llm_model}")
         logger.info(f"Embeddings: provider={self.embeddings_provider}")
         logger.info(f"Reranker: provider={self.reranker_provider}")
+        logger.info(f"Graph retriever: {self.graph_retriever}")


 def get_config() -> HindsightConfig:
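This makes the graph retrieval strategy switchable without code changes. A sketch of driving the choice from the environment (the env var name, default, and retriever class names come from this diff; the dispatch mapping itself is an illustrative assumption, not the package's actual wiring):

    import os

    from hindsight_api.engine.search import BFSGraphRetriever, MPFPGraphRetriever

    # Hypothetical dispatch table for illustration only.
    RETRIEVERS = {"bfs": BFSGraphRetriever, "mpfp": MPFPGraphRetriever}

    name = os.getenv("HINDSIGHT_API_GRAPH_RETRIEVER", "bfs")
    retriever = RETRIEVERS[name]()  # BFSGraphRetriever() unless overridden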
hindsight_api/engine/cross_encoder.py
CHANGED

@@ -101,12 +101,7 @@ class LocalSTCrossEncoder(CrossEncoderModel):
         )

         logger.info(f"Reranker: initializing local provider with model {self.model_name}")
-
-        # Setting low_cpu_mem_usage=False and device_map=None ensures tensors are fully materialized
-        self._model = CrossEncoder(
-            self.model_name,
-            model_kwargs={"low_cpu_mem_usage": False, "device_map": None},
-        )
+        self._model = CrossEncoder(self.model_name)
         logger.info("Reranker: local provider initialized")

     def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
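For context, sentence-transformers' CrossEncoder scores (query, passage) pairs directly, which is how the reranker consumes it. A minimal sketch (the model name here is an illustrative choice, not read from the diff):

    from sentence_transformers import CrossEncoder

    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # illustrative model
    scores = model.predict([
        ("when did the migration finish", "The cutover completed on March 3."),
        ("when did the migration finish", "Notes from the team lunch."),
    ])
    # Higher score = more relevant; these scores drive the reranking step.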
hindsight_api/engine/llm_wrapper.py
CHANGED

@@ -170,24 +170,42 @@ class LLMProvider:
             "messages": messages,
         }

-        if max_completion_tokens is not None:
-            call_params["max_completion_tokens"] = max_completion_tokens
         # Check if model supports reasoning parameter (o1, o3, gpt-5 families)
         model_lower = self.model.lower()
         is_reasoning_model = any(x in model_lower for x in ["gpt-5", "o1", "o3"])

+        # For GPT-4 and GPT-4.1 models, cap max_completion_tokens to 32000
+        # For GPT-4o models, cap to 16384
+        is_gpt4_model = any(x in model_lower for x in ["gpt-4.1", "gpt-4-"])
+        is_gpt4o_model = "gpt-4o" in model_lower
+        if max_completion_tokens is not None:
+            if is_gpt4o_model and max_completion_tokens > 16384:
+                max_completion_tokens = 16384
+            elif is_gpt4_model and max_completion_tokens > 32000:
+                max_completion_tokens = 32000
+            # For reasoning models, max_completion_tokens includes reasoning + output tokens
+            # Enforce minimum of 16000 to ensure enough space for both
+            if is_reasoning_model and max_completion_tokens < 16000:
+                max_completion_tokens = 16000
+            call_params["max_completion_tokens"] = max_completion_tokens
+
         # GPT-5/o1/o3 family doesn't support custom temperature (only default 1)
         if temperature is not None and not is_reasoning_model:
             call_params["temperature"] = temperature

+        # Set reasoning_effort for reasoning models (OpenAI gpt-5, o1, o3)
+        if is_reasoning_model and self.provider == "openai":
+            call_params["reasoning_effort"] = self.reasoning_effort
+
         # Provider-specific parameters
         if self.provider == "groq":
             call_params["seed"] = DEFAULT_LLM_SEED
-
-
-
-            "
-
+            extra_body = {"service_tier": "auto"}
+            # Only add reasoning parameters for reasoning models
+            if is_reasoning_model:
+                extra_body["reasoning_effort"] = self.reasoning_effort
+                extra_body["include_reasoning"] = False
+            call_params["extra_body"] = extra_body

         last_exception = None

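The clamping above is pure arithmetic on the model name and the requested budget, so it is easy to check in isolation. A standalone sketch of the same rules (the function name is illustrative):

    def clamp_completion_tokens(model, requested):
        # Mirrors the hunk above: 16384 cap for gpt-4o, 32000 for gpt-4/gpt-4.1,
        # and a 16000 floor for reasoning models (reasoning + output share it).
        if requested is None:
            return None
        m = model.lower()
        if "gpt-4o" in m:
            requested = min(requested, 16384)
        elif any(x in m for x in ["gpt-4.1", "gpt-4-"]):
            requested = min(requested, 32000)
        if any(x in m for x in ["gpt-5", "o1", "o3"]):
            requested = max(requested, 16000)
        return requested

    assert clamp_completion_tokens("gpt-4o-mini", 50000) == 16384
    assert clamp_completion_tokens("o3-mini", 4096) == 16000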
@@ -254,9 +272,9 @@ class LLMProvider:
                 raise

         except APIStatusError as e:
-            # Fast fail on
-            if
-                logger.error(f"
+            # Fast fail only on 401 (unauthorized) and 403 (forbidden) - these won't recover with retries
+            if e.status_code in (401, 403):
+                logger.error(f"Auth error (HTTP {e.status_code}), not retrying: {str(e)}")
                 raise

             last_exception = e
@@ -394,13 +412,13 @@ class LLMProvider:
                 raise

         except genai_errors.APIError as e:
-            # Fast fail on
-            if e.code
-                logger.error(f"Gemini
+            # Fast fail only on 401 (unauthorized) and 403 (forbidden) - these won't recover with retries
+            if e.code in (401, 403):
+                logger.error(f"Gemini auth error (HTTP {e.code}), not retrying: {str(e)}")
                 raise

-            # Retry on
-            if e.code in (429, 500, 502, 503, 504):
+            # Retry on retryable errors (rate limits, server errors, and other client errors like 400)
+            if e.code in (400, 429, 500, 502, 503, 504) or (e.code and e.code >= 500):
                 last_exception = e
                 if attempt < max_retries:
                     backoff = min(initial_backoff * (2 ** attempt), max_backoff)
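Both handlers follow the same policy: fail fast on auth errors, exponential backoff on everything retryable. A generic, self-contained sketch of that pattern (names and signature are illustrative, not the wrapper's actual API):

    import time

    class ApiError(Exception):
        def __init__(self, code):
            self.code = code

    FATAL = {401, 403}                          # retrying cannot fix bad credentials
    RETRYABLE = {400, 429, 500, 502, 503, 504}

    def call_with_retries(fn, max_retries=3, initial_backoff=1.0, max_backoff=30.0):
        for attempt in range(max_retries + 1):
            try:
                return fn()
            except ApiError as e:
                if e.code in FATAL:
                    raise                        # fast fail
                if e.code not in RETRYABLE and not (e.code and e.code >= 500):
                    raise                        # unknown error class: surface it
                if attempt == max_retries:
                    raise                        # retry budget exhausted
                time.sleep(min(initial_backoff * (2 ** attempt), max_backoff))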
hindsight_api/engine/memory_engine.py
CHANGED

@@ -1156,22 +1156,22 @@ class MemoryEngine:
         aggregated_timings = {"semantic": 0.0, "bm25": 0.0, "graph": 0.0, "temporal": 0.0}

         detected_temporal_constraint = None
-        for idx,
+        for idx, retrieval_result in enumerate(all_retrievals):
             # Log fact types in this retrieval batch
             ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"
-            logger.debug(f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(
+            logger.debug(f"[RECALL {recall_id}] Fact type '{ft_name}': semantic={len(retrieval_result.semantic)}, bm25={len(retrieval_result.bm25)}, graph={len(retrieval_result.graph)}, temporal={len(retrieval_result.temporal) if retrieval_result.temporal else 0}")

-            semantic_results.extend(
-            bm25_results.extend(
-            graph_results.extend(
-            if
-                temporal_results.extend(
+            semantic_results.extend(retrieval_result.semantic)
+            bm25_results.extend(retrieval_result.bm25)
+            graph_results.extend(retrieval_result.graph)
+            if retrieval_result.temporal:
+                temporal_results.extend(retrieval_result.temporal)
             # Track max timing for each method (since they run in parallel across fact types)
-            for method, duration in
-                aggregated_timings[method] = max(aggregated_timings
+            for method, duration in retrieval_result.timings.items():
+                aggregated_timings[method] = max(aggregated_timings.get(method, 0.0), duration)
             # Capture temporal constraint (same across all fact types)
-            if
-                detected_temporal_constraint =
+            if retrieval_result.temporal_constraint:
+                detected_temporal_constraint = retrieval_result.temporal_constraint

         # If no temporal results from any fact type, set to None
         if not temporal_results:
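Since the per-fact-type retrievals run concurrently, the wall-clock cost of each method is the maximum across fact types rather than the sum, which is exactly what the max() aggregation captures. A toy illustration with invented timings:

    timings_by_fact_type = [
        {"semantic": 0.12, "bm25": 0.05, "graph": 0.30},   # one fact type
        {"semantic": 0.20, "bm25": 0.04, "graph": 0.10},   # another, in parallel
    ]

    aggregated = {}
    for timings in timings_by_fact_type:
        for method, duration in timings.items():
            aggregated[method] = max(aggregated.get(method, 0.0), duration)

    print(aggregated)  # {'semantic': 0.2, 'bm25': 0.05, 'graph': 0.3}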
@@ -1203,49 +1203,57 @@ class MemoryEngine:
             temporal_info = f" | temporal_range={start_dt.strftime('%Y-%m-%d')} to {end_dt.strftime('%Y-%m-%d')}"
         log_buffer.append(f" [2] {total_retrievals}-way retrieval ({len(fact_type)} fact_types): {', '.join(timing_parts)} in {step_duration:.3f}s{temporal_info}")

-        # Record retrieval results for tracer
+        # Record retrieval results for tracer - per fact type
         if tracer:
             # Convert RetrievalResult to old tuple format for tracer
             def to_tuple_format(results):
                 return [(r.id, r.__dict__) for r in results]

-            # Add
-
-
-                results=to_tuple_format(semantic_results),
-                duration_seconds=aggregated_timings["semantic"],
-                score_field="similarity",
-                metadata={"limit": thinking_budget}
-            )
+            # Add retrieval results per fact type (to show parallel execution in UI)
+            for idx, rr in enumerate(all_retrievals):
+                ft_name = fact_type[idx] if idx < len(fact_type) else "unknown"

-
-
-
-
-
-
-
-
+                # Add semantic retrieval results for this fact type
+                tracer.add_retrieval_results(
+                    method_name="semantic",
+                    results=to_tuple_format(rr.semantic),
+                    duration_seconds=rr.timings.get("semantic", 0.0),
+                    score_field="similarity",
+                    metadata={"limit": thinking_budget},
+                    fact_type=ft_name
+                )

-
-
-
-
-
-
-
-
+                # Add BM25 retrieval results for this fact type
+                tracer.add_retrieval_results(
+                    method_name="bm25",
+                    results=to_tuple_format(rr.bm25),
+                    duration_seconds=rr.timings.get("bm25", 0.0),
+                    score_field="bm25_score",
+                    metadata={"limit": thinking_budget},
+                    fact_type=ft_name
+                )

-
-            if temporal_results:
+                # Add graph retrieval results for this fact type
                 tracer.add_retrieval_results(
-                    method_name="
-                    results=to_tuple_format(
-                    duration_seconds=
-                    score_field="
-                    metadata={"budget": thinking_budget}
+                    method_name="graph",
+                    results=to_tuple_format(rr.graph),
+                    duration_seconds=rr.timings.get("graph", 0.0),
+                    score_field="activation",
+                    metadata={"budget": thinking_budget},
+                    fact_type=ft_name
                 )

+                # Add temporal retrieval results for this fact type (even if empty, to show it ran)
+                if rr.temporal is not None:
+                    tracer.add_retrieval_results(
+                        method_name="temporal",
+                        results=to_tuple_format(rr.temporal),
+                        duration_seconds=rr.timings.get("temporal", 0.0),
+                        score_field="temporal_score",
+                        metadata={"budget": thinking_budget},
+                        fact_type=ft_name
+                    )
+
             # Record entry points (from semantic results) for legacy graph view
             for rank, retrieval in enumerate(semantic_results[:10], start=1):  # Top 10 as entry points
                 tracer.add_entry_point(retrieval.id, retrieval.text, retrieval.similarity or 0.0, rank)
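Note the temporal branch tests rr.temporal is not None rather than truthiness: an empty list means temporal retrieval ran and found nothing (still worth tracing), while None means it never ran. A tiny illustration:

    ran_but_empty = []    # temporal retrieval executed, zero hits
    never_ran = None      # no temporal constraint detected, retrieval skipped

    for temporal in (ran_but_empty, never_ran):
        if temporal is not None:
            print(f"traced temporal with {len(temporal)} results")  # fires once, for []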
@@ -1287,31 +1295,24 @@ class MemoryEngine:
         step_duration = time.time() - step_start
         log_buffer.append(f" [4] Reranking: {len(scored_results)} candidates scored in {step_duration:.3f}s")

-        if tracer:
-            # Convert to old format for tracer
-            results_dict = [sr.to_dict() for sr in scored_results]
-            tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
-                             for mc in merged_candidates]
-            tracer.add_reranked(results_dict, tracer_merged)
-            tracer.add_phase_metric("reranking", step_duration, {
-                "reranker_type": "cross-encoder",
-                "candidates_reranked": len(scored_results)
-            })
-
         # Step 4.5: Combine cross-encoder score with retrieval signals
         # This preserves retrieval work (RRF, temporal, recency) instead of pure cross-encoder ranking
         if scored_results:
-            # Normalize RRF scores to [0, 1] range
+            # Normalize RRF scores to [0, 1] range using min-max normalization
             rrf_scores = [sr.candidate.rrf_score for sr in scored_results]
-            max_rrf = max(rrf_scores) if rrf_scores else
+            max_rrf = max(rrf_scores) if rrf_scores else 0.0
             min_rrf = min(rrf_scores) if rrf_scores else 0.0
-            rrf_range = max_rrf - min_rrf
+            rrf_range = max_rrf - min_rrf  # Don't force to 1.0, let fallback handle it

             # Calculate recency based on occurred_start (more recent = higher score)
             now = utcnow()
             for sr in scored_results:
-                # Normalize RRF score
-
+                # Normalize RRF score (0-1 range, 0.5 if all same)
+                if rrf_range > 0:
+                    sr.rrf_normalized = (sr.candidate.rrf_score - min_rrf) / rrf_range
+                else:
+                    # All RRF scores are the same, use neutral value
+                    sr.rrf_normalized = 0.5

                 # Calculate recency (decay over 365 days, minimum 0.1)
                 sr.recency = 0.5  # default for missing dates
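Min-max normalization maps the raw RRF scores onto [0, 1] so they can be blended with the cross-encoder score; when every candidate has the same RRF score the range collapses to zero and the code falls back to a neutral 0.5. A standalone sketch with invented scores:

    rrf_scores = [0.032, 0.047, 0.016]   # invented example values

    lo, hi = min(rrf_scores), max(rrf_scores)
    rng = hi - lo
    normalized = [(s - lo) / rng if rng > 0 else 0.5 for s in rrf_scores]
    print([round(n, 3) for n in normalized])  # [0.516, 1.0, 0.0]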
@@ -1343,6 +1344,17 @@ class MemoryEngine:
         scored_results.sort(key=lambda x: x.weight, reverse=True)
         log_buffer.append(f" [4.6] Combined scoring: cross_encoder(0.6) + rrf(0.2) + temporal(0.1) + recency(0.1)")

+        # Add reranked results to tracer AFTER combined scoring (so normalized values are included)
+        if tracer:
+            results_dict = [sr.to_dict() for sr in scored_results]
+            tracer_merged = [(mc.id, mc.retrieval.__dict__, {"rrf_score": mc.rrf_score, **mc.source_ranks})
+                             for mc in merged_candidates]
+            tracer.add_reranked(results_dict, tracer_merged)
+            tracer.add_phase_metric("reranking", step_duration, {
+                "reranker_type": "cross-encoder",
+                "candidates_reranked": len(scored_results)
+            })
+
         # Step 5: Truncate to thinking_budget * 2 for token filtering
         rerank_limit = thinking_budget * 2
         top_scored = scored_results[:rerank_limit]
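The [4.6] log line fixes the blend weights; as a formula, with each input already normalized to [0, 1] per the previous hunk (the function wrapper is a sketch, only the weights come from the diff):

    def combined_weight(cross_encoder, rrf_normalized, temporal, recency):
        # Weights from the [4.6] log line above.
        return 0.6 * cross_encoder + 0.2 * rrf_normalized + 0.1 * temporal + 0.1 * recency

    # A candidate the cross-encoder loves but retrieval ranked mid-pack:
    print(combined_weight(0.9, 0.4, 0.0, 0.5))  # 0.67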
hindsight_api/engine/search/__init__.py
CHANGED

@@ -3,13 +3,27 @@ Search module for memory retrieval.

 Provides modular search architecture:
 - Retrieval: 4-way parallel (semantic + BM25 + graph + temporal)
+- Graph retrieval: Pluggable strategies (BFS, PPR)
 - Reranking: Pluggable strategies (heuristic, cross-encoder)
 """

-from .retrieval import
+from .retrieval import (
+    retrieve_parallel,
+    get_default_graph_retriever,
+    set_default_graph_retriever,
+    ParallelRetrievalResult,
+)
+from .graph_retrieval import GraphRetriever, BFSGraphRetriever
+from .mpfp_retrieval import MPFPGraphRetriever
 from .reranking import CrossEncoderReranker

 __all__ = [
     "retrieve_parallel",
+    "get_default_graph_retriever",
+    "set_default_graph_retriever",
+    "ParallelRetrievalResult",
+    "GraphRetriever",
+    "BFSGraphRetriever",
+    "MPFPGraphRetriever",
     "CrossEncoderReranker",
 ]
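The new exports make the strategy swappable at runtime. A plausible use (only the names come from the diff; the argument shape and return value are assumptions):

    from hindsight_api.engine.search import (
        MPFPGraphRetriever,
        set_default_graph_retriever,
        get_default_graph_retriever,
    )

    # Assumed: the setter takes a retriever instance and the getter returns it.
    set_default_graph_retriever(MPFPGraphRetriever())
    assert get_default_graph_retriever().name == "mpfp"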
hindsight_api/engine/search/graph_retrieval.py
ADDED

@@ -0,0 +1,235 @@
+"""
+Graph retrieval strategies for memory recall.
+
+This module provides an abstraction for graph-based memory retrieval,
+allowing different algorithms (BFS spreading activation, PPR, etc.) to be
+swapped without changing the rest of the recall pipeline.
+"""
+
+from abc import ABC, abstractmethod
+from typing import List, Optional
+from datetime import datetime
+import logging
+
+from .types import RetrievalResult
+from ..db_utils import acquire_with_retry
+
+logger = logging.getLogger(__name__)
+
+
+class GraphRetriever(ABC):
+    """
+    Abstract base class for graph-based memory retrieval.
+
+    Implementations traverse the memory graph (entity links, temporal links,
+    causal links) to find relevant facts that might not be found by
+    semantic or keyword search alone.
+    """
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Return identifier for this retrieval strategy (e.g., 'bfs', 'mpfp')."""
+        pass
+
+    @abstractmethod
+    async def retrieve(
+        self,
+        pool,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+        query_text: Optional[str] = None,
+        semantic_seeds: Optional[List[RetrievalResult]] = None,
+        temporal_seeds: Optional[List[RetrievalResult]] = None,
+    ) -> List[RetrievalResult]:
+        """
+        Retrieve relevant facts via graph traversal.
+
+        Args:
+            pool: Database connection pool
+            query_embedding_str: Query embedding as string (for finding entry points)
+            bank_id: Memory bank identifier
+            fact_type: Fact type to filter ('world', 'experience', 'opinion', 'observation')
+            budget: Maximum number of nodes to explore/return
+            query_text: Original query text (optional, for some strategies)
+            semantic_seeds: Pre-computed semantic entry points (from semantic retrieval)
+            temporal_seeds: Pre-computed temporal entry points (from temporal retrieval)
+
+        Returns:
+            List of RetrievalResult objects with activation scores set
+        """
+        pass
+
+
+class BFSGraphRetriever(GraphRetriever):
+    """
+    Graph retrieval using BFS-style spreading activation.
+
+    Starting from semantic entry points, spreads activation through
+    the memory graph (entity, temporal, causal links) using breadth-first
+    traversal with decaying activation.
+
+    This is the original Hindsight graph retrieval algorithm.
+    """
+
+    def __init__(
+        self,
+        entry_point_limit: int = 5,
+        entry_point_threshold: float = 0.5,
+        activation_decay: float = 0.8,
+        min_activation: float = 0.1,
+        batch_size: int = 20,
+    ):
+        """
+        Initialize BFS graph retriever.
+
+        Args:
+            entry_point_limit: Maximum number of entry points to start from
+            entry_point_threshold: Minimum semantic similarity for entry points
+            activation_decay: Decay factor per hop (activation *= decay)
+            min_activation: Minimum activation to continue spreading
+            batch_size: Number of nodes to process per batch (for neighbor fetching)
+        """
+        self.entry_point_limit = entry_point_limit
+        self.entry_point_threshold = entry_point_threshold
+        self.activation_decay = activation_decay
+        self.min_activation = min_activation
+        self.batch_size = batch_size
+
+    @property
+    def name(self) -> str:
+        return "bfs"
+
+    async def retrieve(
+        self,
+        pool,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+        query_text: Optional[str] = None,
+        semantic_seeds: Optional[List[RetrievalResult]] = None,
+        temporal_seeds: Optional[List[RetrievalResult]] = None,
+    ) -> List[RetrievalResult]:
+        """
+        Retrieve facts using BFS spreading activation.
+
+        Algorithm:
+        1. Find entry points (top semantic matches above threshold)
+        2. BFS traversal: visit neighbors, propagate decaying activation
+        3. Boost causal links (causes, enables, prevents)
+        4. Return visited nodes up to budget
+
+        Note: BFS finds its own entry points via embedding search.
+        The semantic_seeds and temporal_seeds parameters are accepted
+        for interface compatibility but not used.
+        """
+        async with acquire_with_retry(pool) as conn:
+            return await self._retrieve_with_conn(
+                conn, query_embedding_str, bank_id, fact_type, budget
+            )
+
+    async def _retrieve_with_conn(
+        self,
+        conn,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+    ) -> List[RetrievalResult]:
+        """Internal implementation with connection."""
+
+        # Step 1: Find entry points
+        entry_points = await conn.fetch(
+            """
+            SELECT id, text, context, event_date, occurred_start, occurred_end,
+                   mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+                   1 - (embedding <=> $1::vector) AS similarity
+            FROM memory_units
+            WHERE bank_id = $2
+              AND embedding IS NOT NULL
+              AND fact_type = $3
+              AND (1 - (embedding <=> $1::vector)) >= $4
+            ORDER BY embedding <=> $1::vector
+            LIMIT $5
+            """,
+            query_embedding_str, bank_id, fact_type,
+            self.entry_point_threshold, self.entry_point_limit
+        )
+
+        if not entry_points:
+            return []
+
+        # Step 2: BFS spreading activation
+        visited = set()
+        results = []
+        queue = [
+            (RetrievalResult.from_db_row(dict(r)), r["similarity"])
+            for r in entry_points
+        ]
+        budget_remaining = budget
+
+        while queue and budget_remaining > 0:
+            # Collect a batch of nodes to process
+            batch_nodes = []
+            batch_activations = {}
+
+            while queue and len(batch_nodes) < self.batch_size and budget_remaining > 0:
+                current, activation = queue.pop(0)
+                unit_id = current.id
+
+                if unit_id not in visited:
+                    visited.add(unit_id)
+                    budget_remaining -= 1
+                    current.activation = activation
+                    results.append(current)
+                    batch_nodes.append(current.id)
+                    batch_activations[unit_id] = activation
+
+            # Batch fetch neighbors
+            if batch_nodes and budget_remaining > 0:
+                max_neighbors = len(batch_nodes) * 20
+                neighbors = await conn.fetch(
+                    """
+                    SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end,
+                           mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type,
+                           mu.document_id, mu.chunk_id,
+                           ml.weight, ml.link_type, ml.from_unit_id
+                    FROM memory_links ml
+                    JOIN memory_units mu ON ml.to_unit_id = mu.id
+                    WHERE ml.from_unit_id = ANY($1::uuid[])
+                      AND ml.weight >= $2
+                      AND mu.fact_type = $3
+                    ORDER BY ml.weight DESC
+                    LIMIT $4
+                    """,
+                    batch_nodes, self.min_activation, fact_type, max_neighbors
+                )
+
+                for n in neighbors:
+                    neighbor_id = str(n["id"])
+                    if neighbor_id not in visited:
+                        parent_id = str(n["from_unit_id"])
+                        parent_activation = batch_activations.get(parent_id, 0.5)
+
+                        # Boost causal links
+                        link_type = n["link_type"]
+                        base_weight = n["weight"]
+
+                        if link_type in ("causes", "caused_by"):
+                            causal_boost = 2.0
+                        elif link_type in ("enables", "prevents"):
+                            causal_boost = 1.5
+                        else:
+                            causal_boost = 1.0
+
+                        effective_weight = base_weight * causal_boost
+                        new_activation = parent_activation * effective_weight * self.activation_decay
+
+                        if new_activation > self.min_activation:
+                            neighbor_result = RetrievalResult.from_db_row(dict(n))
+                            queue.append((neighbor_result, new_activation))
+
+        return results