hindsight-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -9
- hindsight_api/alembic/env.py +5 -8
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
- hindsight_api/api/__init__.py +10 -10
- hindsight_api/api/http.py +575 -593
- hindsight_api/api/mcp.py +31 -33
- hindsight_api/banner.py +13 -6
- hindsight_api/config.py +17 -12
- hindsight_api/engine/__init__.py +9 -9
- hindsight_api/engine/cross_encoder.py +23 -27
- hindsight_api/engine/db_utils.py +5 -4
- hindsight_api/engine/embeddings.py +22 -21
- hindsight_api/engine/entity_resolver.py +81 -75
- hindsight_api/engine/llm_wrapper.py +74 -88
- hindsight_api/engine/memory_engine.py +663 -673
- hindsight_api/engine/query_analyzer.py +100 -97
- hindsight_api/engine/response_models.py +105 -106
- hindsight_api/engine/retain/__init__.py +9 -16
- hindsight_api/engine/retain/bank_utils.py +34 -58
- hindsight_api/engine/retain/chunk_storage.py +4 -12
- hindsight_api/engine/retain/deduplication.py +9 -28
- hindsight_api/engine/retain/embedding_processing.py +4 -11
- hindsight_api/engine/retain/embedding_utils.py +3 -4
- hindsight_api/engine/retain/entity_processing.py +7 -17
- hindsight_api/engine/retain/fact_extraction.py +155 -165
- hindsight_api/engine/retain/fact_storage.py +11 -23
- hindsight_api/engine/retain/link_creation.py +11 -39
- hindsight_api/engine/retain/link_utils.py +166 -95
- hindsight_api/engine/retain/observation_regeneration.py +39 -52
- hindsight_api/engine/retain/orchestrator.py +72 -62
- hindsight_api/engine/retain/types.py +49 -43
- hindsight_api/engine/search/__init__.py +15 -1
- hindsight_api/engine/search/fusion.py +6 -15
- hindsight_api/engine/search/graph_retrieval.py +234 -0
- hindsight_api/engine/search/mpfp_retrieval.py +438 -0
- hindsight_api/engine/search/observation_utils.py +9 -16
- hindsight_api/engine/search/reranking.py +4 -7
- hindsight_api/engine/search/retrieval.py +388 -193
- hindsight_api/engine/search/scoring.py +5 -7
- hindsight_api/engine/search/temporal_extraction.py +8 -11
- hindsight_api/engine/search/think_utils.py +115 -39
- hindsight_api/engine/search/trace.py +68 -38
- hindsight_api/engine/search/tracer.py +49 -35
- hindsight_api/engine/search/types.py +22 -16
- hindsight_api/engine/task_backend.py +21 -26
- hindsight_api/engine/utils.py +25 -10
- hindsight_api/main.py +21 -40
- hindsight_api/mcp_local.py +190 -0
- hindsight_api/metrics.py +44 -30
- hindsight_api/migrations.py +10 -8
- hindsight_api/models.py +60 -72
- hindsight_api/pg0.py +64 -337
- hindsight_api/server.py +3 -6
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +6 -5
- hindsight_api-0.1.6.dist-info/RECORD +64 -0
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.1.4.dist-info/RECORD +0 -61
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
hindsight_api/engine/retain/types.py

@@ -6,8 +6,7 @@ from content input to fact storage.
 """

 from dataclasses import dataclass, field
-from
-from datetime import datetime
+from datetime import UTC, datetime
 from uuid import UUID


@@ -18,16 +17,18 @@ class RetainContent:

     Represents a single piece of content to extract facts from.
     """
+
     content: str
     context: str = ""
-    event_date:
-    metadata:
+    event_date: datetime | None = None
+    metadata: dict[str, str] = field(default_factory=dict)

     def __post_init__(self):
         """Ensure event_date is set."""
         if self.event_date is None:
-            from datetime import datetime
-
+            from datetime import datetime
+
+            self.event_date = datetime.now(UTC)


 @dataclass
@@ -37,6 +38,7 @@ class ChunkMetadata:

     Used to track which facts were extracted from which chunks.
     """
+
     chunk_text: str
     fact_count: int
     content_index: int  # Index of the source content
@@ -50,9 +52,10 @@ class EntityRef:

     Entities are extracted by the LLM during fact extraction.
     """
+
     name: str
-    canonical_name:
-    entity_id:
+    canonical_name: str | None = None  # Resolved canonical name
+    entity_id: UUID | None = None  # Resolved entity ID


 @dataclass
@@ -62,6 +65,7 @@ class CausalRelation:

     Represents how one fact causes, enables, or prevents another.
     """
+
     relation_type: str  # "causes", "enables", "prevents", "caused_by"
     target_fact_index: int  # Index of the target fact in the batch
     strength: float = 1.0  # Strength of the causal relationship
@@ -74,20 +78,21 @@ class ExtractedFact:

     This is the raw output from fact extraction before processing.
     """
+
     fact_text: str
     fact_type: str  # "world", "experience", "opinion", "observation"
-    entities:
-    occurred_start:
-    occurred_end:
-    where:
-    causal_relations:
+    entities: list[str] = field(default_factory=list)
+    occurred_start: datetime | None = None
+    occurred_end: datetime | None = None
+    where: str | None = None  # WHERE the fact occurred or is about
+    causal_relations: list[CausalRelation] = field(default_factory=list)

     # Context from the content item
     content_index: int = 0  # Which content this fact came from
     chunk_index: int = 0  # Which chunk this fact came from
     context: str = ""
-    mentioned_at:
-    metadata:
+    mentioned_at: datetime | None = None
+    metadata: dict[str, str] = field(default_factory=dict)


 @dataclass
@@ -97,37 +102,38 @@ class ProcessedFact:

     Includes resolved entities, embeddings, and all necessary fields.
     """
+
     # Core fact data
     fact_text: str
     fact_type: str
-    embedding:
+    embedding: list[float]

     # Temporal data
-    occurred_start:
-    occurred_end:
+    occurred_start: datetime | None
+    occurred_end: datetime | None
     mentioned_at: datetime

     # Context and metadata
     context: str
-    metadata:
+    metadata: dict[str, str]

     # Location data
-    where:
+    where: str | None = None

     # Entities
-    entities:
+    entities: list[EntityRef] = field(default_factory=list)

     # Causal relations
-    causal_relations:
+    causal_relations: list[CausalRelation] = field(default_factory=list)

     # Chunk reference
-    chunk_id:
+    chunk_id: str | None = None

     # Document reference (denormalized for query performance)
-    document_id:
+    document_id: str | None = None

     # DB fields (set after insertion)
-    unit_id:
+    unit_id: UUID | None = None

     @property
     def is_duplicate(self) -> bool:
@@ -136,10 +142,8 @@ class ProcessedFact:

     @staticmethod
     def from_extracted_fact(
-        extracted_fact:
-
-        chunk_id: Optional[str] = None
-    ) -> 'ProcessedFact':
+        extracted_fact: "ExtractedFact", embedding: list[float], chunk_id: str | None = None
+    ) -> "ProcessedFact":
         """
         Create ProcessedFact from ExtractedFact.

@@ -151,12 +155,12 @@ class ProcessedFact:
         Returns:
             ProcessedFact ready for storage
         """
-        from datetime import datetime
+        from datetime import datetime

         # Use occurred dates only if explicitly provided by LLM
         occurred_start = extracted_fact.occurred_start
         occurred_end = extracted_fact.occurred_end
-        mentioned_at = extracted_fact.mentioned_at or datetime.now(
+        mentioned_at = extracted_fact.mentioned_at or datetime.now(UTC)

         # Convert entity strings to EntityRef objects
         entities = [EntityRef(name=name) for name in extracted_fact.entities]
@@ -172,7 +176,7 @@ class ProcessedFact:
             metadata=extracted_fact.metadata,
             entities=entities,
             causal_relations=extracted_fact.causal_relations,
-            chunk_id=chunk_id
+            chunk_id=chunk_id,
         )


@@ -183,10 +187,11 @@ class EntityLink:

     Used for entity-based graph connections in the memory graph.
     """
+
     from_unit_id: UUID
     to_unit_id: UUID
     entity_id: UUID
-    link_type: str =
+    link_type: str = "entity"
     weight: float = 1.0


@@ -197,24 +202,25 @@ class RetainBatch:

     Tracks all facts, chunks, and metadata for a batch operation.
     """
+
     bank_id: str
-    contents:
-    document_id:
-    fact_type_override:
-    confidence_score:
+    contents: list[RetainContent]
+    document_id: str | None = None
+    fact_type_override: str | None = None
+    confidence_score: float | None = None

     # Extracted data (populated during processing)
-    extracted_facts:
-    processed_facts:
-    chunks:
+    extracted_facts: list[ExtractedFact] = field(default_factory=list)
+    processed_facts: list[ProcessedFact] = field(default_factory=list)
+    chunks: list[ChunkMetadata] = field(default_factory=list)

     # Results (populated after storage)
-    unit_ids_by_content:
+    unit_ids_by_content: list[list[str]] = field(default_factory=list)

-    def get_facts_for_content(self, content_index: int) ->
+    def get_facts_for_content(self, content_index: int) -> list[ExtractedFact]:
         """Get all extracted facts for a specific content item."""
         return [f for f in self.extracted_facts if f.content_index == content_index]

-    def get_chunks_for_content(self, content_index: int) ->
+    def get_chunks_for_content(self, content_index: int) -> list[ChunkMetadata]:
         """Get all chunks for a specific content item."""
         return [c for c in self.chunks if c.content_index == content_index]
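
The retain dataclasses above replace the older typing-style annotations with `X | None` unions, give every collection field a `default_factory`, and switch timestamp defaults to timezone-aware `datetime.now(UTC)`. A minimal sketch of how the updated types behave (it assumes the import path matches hindsight_api/engine/retain/types.py from the file list; the sample values are made up):

# Minimal sketch of the updated retain types; assumes the import path matches
# hindsight_api/engine/retain/types.py from the file list, values are illustrative.
from hindsight_api.engine.retain.types import ExtractedFact, ProcessedFact, RetainContent

# event_date now defaults to a timezone-aware datetime.now(UTC) via __post_init__.
content = RetainContent(content="Alice moved to Berlin in 2023.", context="chat")
assert content.event_date is not None and content.event_date.tzinfo is not None

# Collection fields use default_factory, so each instance gets its own list/dict.
fact = ExtractedFact(fact_text="Alice moved to Berlin.", fact_type="world")
assert fact.entities == [] and fact.metadata == {}

# from_extracted_fact() attaches the embedding and optional chunk reference;
# mentioned_at falls back to datetime.now(UTC) when the fact does not carry one.
processed = ProcessedFact.from_extracted_fact(fact, embedding=[0.0] * 8, chunk_id=None)
assert processed.mentioned_at.tzinfo is not None
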
hindsight_api/engine/search/__init__.py

@@ -3,13 +3,27 @@ Search module for memory retrieval.

 Provides modular search architecture:
 - Retrieval: 4-way parallel (semantic + BM25 + graph + temporal)
+- Graph retrieval: Pluggable strategies (BFS, PPR)
 - Reranking: Pluggable strategies (heuristic, cross-encoder)
 """

-from .
+from .graph_retrieval import BFSGraphRetriever, GraphRetriever
+from .mpfp_retrieval import MPFPGraphRetriever
 from .reranking import CrossEncoderReranker
+from .retrieval import (
+    ParallelRetrievalResult,
+    get_default_graph_retriever,
+    retrieve_parallel,
+    set_default_graph_retriever,
+)

 __all__ = [
     "retrieve_parallel",
+    "get_default_graph_retriever",
+    "set_default_graph_retriever",
+    "ParallelRetrievalResult",
+    "GraphRetriever",
+    "BFSGraphRetriever",
+    "MPFPGraphRetriever",
     "CrossEncoderReranker",
 ]
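
The search package now exports a pluggable graph-retrieval API alongside `retrieve_parallel`. A hedged usage sketch follows; it assumes `set_default_graph_retriever()` accepts a `GraphRetriever` instance and `get_default_graph_retriever()` returns the active one, which the exported names suggest but this diff does not show directly:

# Hedged usage sketch of the pluggable graph-retriever API (assumed signatures).
from hindsight_api.engine.search import (
    BFSGraphRetriever,
    get_default_graph_retriever,
    set_default_graph_retriever,
)

# Configure a more exploratory BFS traversal (parameters from BFSGraphRetriever.__init__).
set_default_graph_retriever(
    BFSGraphRetriever(entry_point_limit=8, entry_point_threshold=0.4, activation_decay=0.7)
)
print(get_default_graph_retriever().name)  # expected: "bfs"
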
hindsight_api/engine/search/fusion.py

@@ -2,15 +2,12 @@
 Helper functions for hybrid search (semantic + BM25 + graph).
 """

-from typing import
-import asyncio
-from .types import RetrievalResult, MergedCandidate
+from typing import Any

+from .types import MergedCandidate, RetrievalResult

-
-
-    k: int = 60
-) -> List[MergedCandidate]:
+
+def reciprocal_rank_fusion(result_lists: list[list[RetrievalResult]], k: int = 60) -> list[MergedCandidate]:
     """
     Merge multiple ranked result lists using Reciprocal Rank Fusion.

@@ -73,20 +70,14 @@ def reciprocal_rank_fusion(
         sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True), start=1
     ):
         merged_candidate = MergedCandidate(
-            retrieval=all_retrievals[doc_id],
-            rrf_score=rrf_score,
-            rrf_rank=rrf_rank,
-            source_ranks=source_ranks[doc_id]
+            retrieval=all_retrievals[doc_id], rrf_score=rrf_score, rrf_rank=rrf_rank, source_ranks=source_ranks[doc_id]
         )
         merged_results.append(merged_candidate)

     return merged_results


-def normalize_scores_on_deltas(
-    results: List[Dict[str, Any]],
-    score_keys: List[str]
-) -> List[Dict[str, Any]]:
+def normalize_scores_on_deltas(results: list[dict[str, Any]], score_keys: list[str]) -> list[dict[str, Any]]:
     """
     Normalize scores based on deltas (min-max normalization within result set).

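
`reciprocal_rank_fusion()` now takes the ranked result lists and the RRF constant in a single-line signature. As a standalone illustration of the standard RRF scoring it implements, each item accumulates 1 / (k + rank) over the lists it appears in, with 1-based ranks and k defaulting to 60; the packaged function operates on `RetrievalResult` lists and returns `MergedCandidate` objects rather than the toy strings used here:

# Standalone illustration of Reciprocal Rank Fusion with 1-based ranks and k=60.
def rrf_scores(result_lists: list[list[str]], k: int = 60) -> dict[str, float]:
    scores: dict[str, float] = {}
    for ranked in result_lists:
        for rank, doc_id in enumerate(ranked, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return scores

semantic = ["a", "b", "c"]   # ranked by embedding similarity
bm25 = ["a", "d", "b"]       # ranked by keyword match
print(rrf_scores([semantic, bm25]))
# a: 1/61 + 1/61, b: 1/62 + 1/63, c: 1/63, d: 1/62
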
hindsight_api/engine/search/graph_retrieval.py (new file)

@@ -0,0 +1,234 @@
+"""
+Graph retrieval strategies for memory recall.
+
+This module provides an abstraction for graph-based memory retrieval,
+allowing different algorithms (BFS spreading activation, PPR, etc.) to be
+swapped without changing the rest of the recall pipeline.
+"""
+
+import logging
+from abc import ABC, abstractmethod
+
+from ..db_utils import acquire_with_retry
+from .types import RetrievalResult
+
+logger = logging.getLogger(__name__)
+
+
+class GraphRetriever(ABC):
+    """
+    Abstract base class for graph-based memory retrieval.
+
+    Implementations traverse the memory graph (entity links, temporal links,
+    causal links) to find relevant facts that might not be found by
+    semantic or keyword search alone.
+    """
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Return identifier for this retrieval strategy (e.g., 'bfs', 'mpfp')."""
+        pass
+
+    @abstractmethod
+    async def retrieve(
+        self,
+        pool,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+        query_text: str | None = None,
+        semantic_seeds: list[RetrievalResult] | None = None,
+        temporal_seeds: list[RetrievalResult] | None = None,
+    ) -> list[RetrievalResult]:
+        """
+        Retrieve relevant facts via graph traversal.
+
+        Args:
+            pool: Database connection pool
+            query_embedding_str: Query embedding as string (for finding entry points)
+            bank_id: Memory bank identifier
+            fact_type: Fact type to filter ('world', 'experience', 'opinion', 'observation')
+            budget: Maximum number of nodes to explore/return
+            query_text: Original query text (optional, for some strategies)
+            semantic_seeds: Pre-computed semantic entry points (from semantic retrieval)
+            temporal_seeds: Pre-computed temporal entry points (from temporal retrieval)
+
+        Returns:
+            List of RetrievalResult objects with activation scores set
+        """
+        pass
+
+
+class BFSGraphRetriever(GraphRetriever):
+    """
+    Graph retrieval using BFS-style spreading activation.
+
+    Starting from semantic entry points, spreads activation through
+    the memory graph (entity, temporal, causal links) using breadth-first
+    traversal with decaying activation.
+
+    This is the original Hindsight graph retrieval algorithm.
+    """
+
+    def __init__(
+        self,
+        entry_point_limit: int = 5,
+        entry_point_threshold: float = 0.5,
+        activation_decay: float = 0.8,
+        min_activation: float = 0.1,
+        batch_size: int = 20,
+    ):
+        """
+        Initialize BFS graph retriever.
+
+        Args:
+            entry_point_limit: Maximum number of entry points to start from
+            entry_point_threshold: Minimum semantic similarity for entry points
+            activation_decay: Decay factor per hop (activation *= decay)
+            min_activation: Minimum activation to continue spreading
+            batch_size: Number of nodes to process per batch (for neighbor fetching)
+        """
+        self.entry_point_limit = entry_point_limit
+        self.entry_point_threshold = entry_point_threshold
+        self.activation_decay = activation_decay
+        self.min_activation = min_activation
+        self.batch_size = batch_size
+
+    @property
+    def name(self) -> str:
+        return "bfs"
+
+    async def retrieve(
+        self,
+        pool,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+        query_text: str | None = None,
+        semantic_seeds: list[RetrievalResult] | None = None,
+        temporal_seeds: list[RetrievalResult] | None = None,
+    ) -> list[RetrievalResult]:
+        """
+        Retrieve facts using BFS spreading activation.
+
+        Algorithm:
+        1. Find entry points (top semantic matches above threshold)
+        2. BFS traversal: visit neighbors, propagate decaying activation
+        3. Boost causal links (causes, enables, prevents)
+        4. Return visited nodes up to budget
+
+        Note: BFS finds its own entry points via embedding search.
+        The semantic_seeds and temporal_seeds parameters are accepted
+        for interface compatibility but not used.
+        """
+        async with acquire_with_retry(pool) as conn:
+            return await self._retrieve_with_conn(conn, query_embedding_str, bank_id, fact_type, budget)
+
+    async def _retrieve_with_conn(
+        self,
+        conn,
+        query_embedding_str: str,
+        bank_id: str,
+        fact_type: str,
+        budget: int,
+    ) -> list[RetrievalResult]:
+        """Internal implementation with connection."""
+
+        # Step 1: Find entry points
+        entry_points = await conn.fetch(
+            """
+            SELECT id, text, context, event_date, occurred_start, occurred_end,
+                   mentioned_at, access_count, embedding, fact_type, document_id, chunk_id,
+                   1 - (embedding <=> $1::vector) AS similarity
+            FROM memory_units
+            WHERE bank_id = $2
+              AND embedding IS NOT NULL
+              AND fact_type = $3
+              AND (1 - (embedding <=> $1::vector)) >= $4
+            ORDER BY embedding <=> $1::vector
+            LIMIT $5
+            """,
+            query_embedding_str,
+            bank_id,
+            fact_type,
+            self.entry_point_threshold,
+            self.entry_point_limit,
+        )
+
+        if not entry_points:
+            return []
+
+        # Step 2: BFS spreading activation
+        visited = set()
+        results = []
+        queue = [(RetrievalResult.from_db_row(dict(r)), r["similarity"]) for r in entry_points]
+        budget_remaining = budget
+
+        while queue and budget_remaining > 0:
+            # Collect a batch of nodes to process
+            batch_nodes = []
+            batch_activations = {}
+
+            while queue and len(batch_nodes) < self.batch_size and budget_remaining > 0:
+                current, activation = queue.pop(0)
+                unit_id = current.id
+
+                if unit_id not in visited:
+                    visited.add(unit_id)
+                    budget_remaining -= 1
+                    current.activation = activation
+                    results.append(current)
+                    batch_nodes.append(current.id)
+                    batch_activations[unit_id] = activation
+
+            # Batch fetch neighbors
+            if batch_nodes and budget_remaining > 0:
+                max_neighbors = len(batch_nodes) * 20
+                neighbors = await conn.fetch(
+                    """
+                    SELECT mu.id, mu.text, mu.context, mu.occurred_start, mu.occurred_end,
+                           mu.mentioned_at, mu.access_count, mu.embedding, mu.fact_type,
+                           mu.document_id, mu.chunk_id,
+                           ml.weight, ml.link_type, ml.from_unit_id
+                    FROM memory_links ml
+                    JOIN memory_units mu ON ml.to_unit_id = mu.id
+                    WHERE ml.from_unit_id = ANY($1::uuid[])
+                      AND ml.weight >= $2
+                      AND mu.fact_type = $3
+                    ORDER BY ml.weight DESC
+                    LIMIT $4
+                    """,
+                    batch_nodes,
+                    self.min_activation,
+                    fact_type,
+                    max_neighbors,
+                )
+
+                for n in neighbors:
+                    neighbor_id = str(n["id"])
+                    if neighbor_id not in visited:
+                        parent_id = str(n["from_unit_id"])
+                        parent_activation = batch_activations.get(parent_id, 0.5)
+
+                        # Boost causal links
+                        link_type = n["link_type"]
+                        base_weight = n["weight"]
+
+                        if link_type in ("causes", "caused_by"):
+                            causal_boost = 2.0
+                        elif link_type in ("enables", "prevents"):
+                            causal_boost = 1.5
+                        else:
+                            causal_boost = 1.0
+
+                        effective_weight = base_weight * causal_boost
+                        new_activation = parent_activation * effective_weight * self.activation_decay
+
+                        if new_activation > self.min_activation:
+                            neighbor_result = RetrievalResult.from_db_row(dict(n))
+                            queue.append((neighbor_result, new_activation))
+
+        return results