graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +61 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +582 -255
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +21 -14
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +94 -50
- graphiti_core/llm_client/openai_client.py +28 -8
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +23 -14
- graphiti_core/prompts/extract_nodes.py +73 -32
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +109 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +286 -126
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +320 -158
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
|
graphiti_core/prompts/prompt_helpers.py
CHANGED
|
@@ -1 +1,40 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
1
20
|
DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def to_prompt_json(data: Any, ensure_ascii: bool = False, indent: int | None = None) -> str:
|
|
24
|
+
"""
|
|
25
|
+
Serialize data to JSON for use in prompts.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
data: The data to serialize
|
|
29
|
+
ensure_ascii: If True, escape non-ASCII characters. If False (default), preserve them.
|
|
30
|
+
indent: Number of spaces for indentation. Defaults to None (minified).
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
JSON string representation of the data
|
|
34
|
+
|
|
35
|
+
Notes:
|
|
36
|
+
By default (ensure_ascii=False), non-ASCII characters (e.g., Korean, Japanese, Chinese)
|
|
37
|
+
are preserved in their original form in the prompt, making them readable
|
|
38
|
+
in LLM logs and improving model understanding.
|
|
39
|
+
"""
|
|
40
|
+
return json.dumps(data, ensure_ascii=ensure_ascii, indent=indent)
|
|
graphiti_core/prompts/snippets.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
summary_instructions = """Guidelines:
|
|
18
|
+
1. Output only factual content. Never explain what you're doing, why, or mention limitations/constraints.
|
|
19
|
+
2. Only use the provided messages, entity, and entity context to set attribute values.
|
|
20
|
+
3. Keep the summary concise and to the point. STATE FACTS DIRECTLY IN UNDER 250 CHARACTERS.
|
|
21
|
+
|
|
22
|
+
Example summaries:
|
|
23
|
+
BAD: "This is the only activity in the context. The user listened to this song. No other details were provided to include in this summary."
|
|
24
|
+
GOOD: "User played 'Blue Monday' by New Order (electronic genre) on 2024-12-03 at 14:22 UTC."
|
|
25
|
+
BAD: "Based on the messages provided, the user attended a meeting. This summary focuses on that event as it was the main topic discussed."
|
|
26
|
+
GOOD: "User attended Q3 planning meeting with sales team on March 15."
|
|
27
|
+
BAD: "The context shows John ordered pizza. Due to length constraints, other details are omitted from this summary."
|
|
28
|
+
GOOD: "John ordered pepperoni pizza from Mario's at 7:30 PM, delivered to office."
|
|
29
|
+
"""
|
|
graphiti_core/prompts/summarize_nodes.py
CHANGED
|
@@ -14,18 +14,19 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
import json
|
|
18
17
|
from typing import Any, Protocol, TypedDict
|
|
19
18
|
|
|
20
19
|
from pydantic import BaseModel, Field
|
|
21
20
|
|
|
22
21
|
from .models import Message, PromptFunction, PromptVersion
|
|
22
|
+
from .prompt_helpers import to_prompt_json
|
|
23
|
+
from .snippets import summary_instructions
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class Summary(BaseModel):
|
|
26
27
|
summary: str = Field(
|
|
27
28
|
...,
|
|
28
|
-
description='Summary containing the important information about the entity. Under 250
|
|
29
|
+
description='Summary containing the important information about the entity. Under 250 characters',
|
|
29
30
|
)
|
|
30
31
|
|
|
31
32
|
|
|
@@ -55,11 +56,11 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
|
|
|
55
56
|
role='user',
|
|
56
57
|
content=f"""
|
|
57
58
|
Synthesize the information from the following two summaries into a single succinct summary.
|
|
58
|
-
|
|
59
|
-
|
|
59
|
+
|
|
60
|
+
IMPORTANT: Keep the summary concise and to the point. SUMMARIES MUST BE LESS THAN 250 CHARACTERS.
|
|
60
61
|
|
|
61
62
|
Summaries:
|
|
62
|
-
{
|
|
63
|
+
{to_prompt_json(context['node_summaries'])}
|
|
63
64
|
""",
|
|
64
65
|
),
|
|
65
66
|
]
|
|
@@ -69,38 +70,35 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
|
|
|
69
70
|
return [
|
|
70
71
|
Message(
|
|
71
72
|
role='system',
|
|
72
|
-
content='You are a helpful assistant that
|
|
73
|
+
content='You are a helpful assistant that generates a summary and attributes from provided text.',
|
|
73
74
|
),
|
|
74
75
|
Message(
|
|
75
76
|
role='user',
|
|
76
77
|
content=f"""
|
|
77
|
-
|
|
78
|
-
<MESSAGES>
|
|
79
|
-
{json.dumps(context['previous_episodes'], indent=2)}
|
|
80
|
-
{json.dumps(context['episode_content'], indent=2)}
|
|
81
|
-
</MESSAGES>
|
|
82
|
-
|
|
83
|
-
Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
|
|
78
|
+
Given the MESSAGES and the ENTITY name, create a summary for the ENTITY. Your summary must only use
|
|
84
79
|
information from the provided MESSAGES. Your summary should also only contain information relevant to the
|
|
85
|
-
provided ENTITY.
|
|
86
|
-
|
|
80
|
+
provided ENTITY.
|
|
81
|
+
|
|
87
82
|
In addition, extract any values for the provided entity properties based on their descriptions.
|
|
88
83
|
If the value of the entity property cannot be found in the current context, set the value of the property to the Python value None.
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
84
|
+
|
|
85
|
+
{summary_instructions}
|
|
86
|
+
|
|
87
|
+
<MESSAGES>
|
|
88
|
+
{to_prompt_json(context['previous_episodes'])}
|
|
89
|
+
{to_prompt_json(context['episode_content'])}
|
|
90
|
+
</MESSAGES>
|
|
91
|
+
|
|
94
92
|
<ENTITY>
|
|
95
93
|
{context['node_name']}
|
|
96
94
|
</ENTITY>
|
|
97
|
-
|
|
95
|
+
|
|
98
96
|
<ENTITY CONTEXT>
|
|
99
97
|
{context['node_summary']}
|
|
100
98
|
</ENTITY CONTEXT>
|
|
101
|
-
|
|
99
|
+
|
|
102
100
|
<ATTRIBUTES>
|
|
103
|
-
{
|
|
101
|
+
{to_prompt_json(context['attributes'])}
|
|
104
102
|
</ATTRIBUTES>
|
|
105
103
|
""",
|
|
106
104
|
),
|
|
@@ -117,10 +115,10 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
|
|
|
117
115
|
role='user',
|
|
118
116
|
content=f"""
|
|
119
117
|
Create a short one sentence description of the summary that explains what kind of information is summarized.
|
|
120
|
-
Summaries must be under 250
|
|
118
|
+
Summaries must be under 250 characters.
|
|
121
119
|
|
|
122
120
|
Summary:
|
|
123
|
-
{
|
|
121
|
+
{to_prompt_json(context['summary'])}
|
|
124
122
|
""",
|
|
125
123
|
),
|
|
126
124
|
]
|
graphiti_core/search/search.py
CHANGED
|
@@ -21,6 +21,7 @@ from time import time
|
|
|
21
21
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
22
22
|
from graphiti_core.driver.driver import GraphDriver
|
|
23
23
|
from graphiti_core.edges import EntityEdge
|
|
24
|
+
from graphiti_core.embedder.client import EMBEDDING_DIM
|
|
24
25
|
from graphiti_core.errors import SearchRerankerError
|
|
25
26
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
26
27
|
from graphiti_core.helpers import semaphore_gather
|
|
@@ -29,6 +30,7 @@ from graphiti_core.search.search_config import (
|
|
|
29
30
|
DEFAULT_SEARCH_LIMIT,
|
|
30
31
|
CommunityReranker,
|
|
31
32
|
CommunitySearchConfig,
|
|
33
|
+
CommunitySearchMethod,
|
|
32
34
|
EdgeReranker,
|
|
33
35
|
EdgeSearchConfig,
|
|
34
36
|
EdgeSearchMethod,
|
|
@@ -72,34 +74,53 @@ async def search(
|
|
|
72
74
|
center_node_uuid: str | None = None,
|
|
73
75
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
74
76
|
query_vector: list[float] | None = None,
|
|
77
|
+
driver: GraphDriver | None = None,
|
|
75
78
|
) -> SearchResults:
|
|
76
79
|
start = time()
|
|
77
80
|
|
|
78
|
-
driver = clients.driver
|
|
81
|
+
driver = driver or clients.driver
|
|
79
82
|
embedder = clients.embedder
|
|
80
83
|
cross_encoder = clients.cross_encoder
|
|
81
84
|
|
|
82
85
|
if query.strip() == '':
|
|
83
|
-
return SearchResults(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
86
|
+
return SearchResults()
|
|
87
|
+
|
|
88
|
+
if (
|
|
89
|
+
config.edge_config
|
|
90
|
+
and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
|
|
91
|
+
or config.edge_config
|
|
92
|
+
and EdgeReranker.mmr == config.edge_config.reranker
|
|
93
|
+
or config.node_config
|
|
94
|
+
and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
|
|
95
|
+
or config.node_config
|
|
96
|
+
and NodeReranker.mmr == config.node_config.reranker
|
|
97
|
+
or (
|
|
98
|
+
config.community_config
|
|
99
|
+
and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
|
|
88
100
|
)
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
101
|
+
or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
|
|
102
|
+
):
|
|
103
|
+
search_vector = (
|
|
104
|
+
query_vector
|
|
105
|
+
if query_vector is not None
|
|
106
|
+
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
|
107
|
+
)
|
|
108
|
+
else:
|
|
109
|
+
search_vector = [0.0] * EMBEDDING_DIM
|
|
94
110
|
|
|
95
111
|
# if group_ids is empty, set it to None
|
|
96
112
|
group_ids = group_ids if group_ids and group_ids != [''] else None
|
|
97
|
-
|
|
113
|
+
(
|
|
114
|
+
(edges, edge_reranker_scores),
|
|
115
|
+
(nodes, node_reranker_scores),
|
|
116
|
+
(episodes, episode_reranker_scores),
|
|
117
|
+
(communities, community_reranker_scores),
|
|
118
|
+
) = await semaphore_gather(
|
|
98
119
|
edge_search(
|
|
99
120
|
driver,
|
|
100
121
|
cross_encoder,
|
|
101
122
|
query,
|
|
102
|
-
|
|
123
|
+
search_vector,
|
|
103
124
|
group_ids,
|
|
104
125
|
config.edge_config,
|
|
105
126
|
search_filter,
|
|
@@ -112,7 +133,7 @@ async def search(
|
|
|
112
133
|
driver,
|
|
113
134
|
cross_encoder,
|
|
114
135
|
query,
|
|
115
|
-
|
|
136
|
+
search_vector,
|
|
116
137
|
group_ids,
|
|
117
138
|
config.node_config,
|
|
118
139
|
search_filter,
|
|
@@ -125,7 +146,7 @@ async def search(
|
|
|
125
146
|
driver,
|
|
126
147
|
cross_encoder,
|
|
127
148
|
query,
|
|
128
|
-
|
|
149
|
+
search_vector,
|
|
129
150
|
group_ids,
|
|
130
151
|
config.episode_config,
|
|
131
152
|
search_filter,
|
|
@@ -136,7 +157,7 @@ async def search(
|
|
|
136
157
|
driver,
|
|
137
158
|
cross_encoder,
|
|
138
159
|
query,
|
|
139
|
-
|
|
160
|
+
search_vector,
|
|
140
161
|
group_ids,
|
|
141
162
|
config.community_config,
|
|
142
163
|
config.limit,
|
|
@@ -146,9 +167,13 @@ async def search(
|
|
|
146
167
|
|
|
147
168
|
results = SearchResults(
|
|
148
169
|
edges=edges,
|
|
170
|
+
edge_reranker_scores=edge_reranker_scores,
|
|
149
171
|
nodes=nodes,
|
|
172
|
+
node_reranker_scores=node_reranker_scores,
|
|
150
173
|
episodes=episodes,
|
|
174
|
+
episode_reranker_scores=episode_reranker_scores,
|
|
151
175
|
communities=communities,
|
|
176
|
+
community_reranker_scores=community_reranker_scores,
|
|
152
177
|
)
|
|
153
178
|
|
|
154
179
|
latency = (time() - start) * 1000
|
|
@@ -170,50 +195,72 @@ async def edge_search(
|
|
|
170
195
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
171
196
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
172
197
|
reranker_min_score: float = 0,
|
|
173
|
-
) -> list[EntityEdge]:
|
|
198
|
+
) -> tuple[list[EntityEdge], list[float]]:
|
|
174
199
|
if config is None:
|
|
175
|
-
return []
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
200
|
+
return [], []
|
|
201
|
+
|
|
202
|
+
# Build search tasks based on configured search methods
|
|
203
|
+
search_tasks = []
|
|
204
|
+
if EdgeSearchMethod.bm25 in config.search_methods:
|
|
205
|
+
search_tasks.append(
|
|
206
|
+
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
|
207
|
+
)
|
|
208
|
+
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
|
209
|
+
search_tasks.append(
|
|
210
|
+
edge_similarity_search(
|
|
211
|
+
driver,
|
|
212
|
+
query_vector,
|
|
213
|
+
None,
|
|
214
|
+
None,
|
|
215
|
+
search_filter,
|
|
216
|
+
group_ids,
|
|
217
|
+
2 * limit,
|
|
218
|
+
config.sim_min_score,
|
|
219
|
+
)
|
|
220
|
+
)
|
|
221
|
+
if EdgeSearchMethod.bfs in config.search_methods:
|
|
222
|
+
search_tasks.append(
|
|
223
|
+
edge_bfs_search(
|
|
224
|
+
driver,
|
|
225
|
+
bfs_origin_node_uuids,
|
|
226
|
+
config.bfs_max_depth,
|
|
227
|
+
search_filter,
|
|
228
|
+
group_ids,
|
|
229
|
+
2 * limit,
|
|
230
|
+
)
|
|
194
231
|
)
|
|
195
|
-
|
|
232
|
+
|
|
233
|
+
# Execute only the configured search methods
|
|
234
|
+
search_results: list[list[EntityEdge]] = []
|
|
235
|
+
if search_tasks:
|
|
236
|
+
search_results = list(await semaphore_gather(*search_tasks))
|
|
196
237
|
|
|
197
238
|
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
|
198
239
|
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
|
|
199
240
|
search_results.append(
|
|
200
241
|
await edge_bfs_search(
|
|
201
|
-
driver,
|
|
242
|
+
driver,
|
|
243
|
+
source_node_uuids,
|
|
244
|
+
config.bfs_max_depth,
|
|
245
|
+
search_filter,
|
|
246
|
+
group_ids,
|
|
247
|
+
2 * limit,
|
|
202
248
|
)
|
|
203
249
|
)
|
|
204
250
|
|
|
205
251
|
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
|
206
252
|
|
|
207
253
|
reranked_uuids: list[str] = []
|
|
254
|
+
edge_scores: list[float] = []
|
|
208
255
|
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
|
|
209
256
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
210
257
|
|
|
211
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
258
|
+
reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
212
259
|
elif config.reranker == EdgeReranker.mmr:
|
|
213
260
|
search_result_uuids_and_vectors = await get_embeddings_for_edges(
|
|
214
261
|
driver, list(edge_uuid_map.values())
|
|
215
262
|
)
|
|
216
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
263
|
+
reranked_uuids, edge_scores = maximal_marginal_relevance(
|
|
217
264
|
query_vector,
|
|
218
265
|
search_result_uuids_and_vectors,
|
|
219
266
|
config.mmr_lambda,
|
|
@@ -225,12 +272,13 @@ async def edge_search(
|
|
|
225
272
|
reranked_uuids = [
|
|
226
273
|
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
|
|
227
274
|
]
|
|
275
|
+
edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
|
|
228
276
|
elif config.reranker == EdgeReranker.node_distance:
|
|
229
277
|
if center_node_uuid is None:
|
|
230
278
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
231
279
|
|
|
232
280
|
# use rrf as a preliminary sort
|
|
233
|
-
sorted_result_uuids = rrf(
|
|
281
|
+
sorted_result_uuids, node_scores = rrf(
|
|
234
282
|
[[edge.uuid for edge in result] for result in search_results],
|
|
235
283
|
min_score=reranker_min_score,
|
|
236
284
|
)
|
|
@@ -243,7 +291,7 @@ async def edge_search(
|
|
|
243
291
|
|
|
244
292
|
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
|
|
245
293
|
|
|
246
|
-
reranked_node_uuids = await node_distance_reranker(
|
|
294
|
+
reranked_node_uuids, edge_scores = await node_distance_reranker(
|
|
247
295
|
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
|
|
248
296
|
)
|
|
249
297
|
|
|
@@ -255,7 +303,7 @@ async def edge_search(
|
|
|
255
303
|
if config.reranker == EdgeReranker.episode_mentions:
|
|
256
304
|
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
|
|
257
305
|
|
|
258
|
-
return reranked_edges[:limit]
|
|
306
|
+
return reranked_edges[:limit], edge_scores[:limit]
|
|
259
307
|
|
|
260
308
|
|
|
261
309
|
async def node_search(
|
|
@@ -270,28 +318,54 @@ async def node_search(
|
|
|
270
318
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
271
319
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
272
320
|
reranker_min_score: float = 0,
|
|
273
|
-
) -> list[EntityNode]:
|
|
321
|
+
) -> tuple[list[EntityNode], list[float]]:
|
|
274
322
|
if config is None:
|
|
275
|
-
return []
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
),
|
|
283
|
-
node_bfs_search(
|
|
284
|
-
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
|
|
285
|
-
),
|
|
286
|
-
]
|
|
323
|
+
return [], []
|
|
324
|
+
|
|
325
|
+
# Build search tasks based on configured search methods
|
|
326
|
+
search_tasks = []
|
|
327
|
+
if NodeSearchMethod.bm25 in config.search_methods:
|
|
328
|
+
search_tasks.append(
|
|
329
|
+
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
|
287
330
|
)
|
|
288
|
-
|
|
331
|
+
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
|
332
|
+
search_tasks.append(
|
|
333
|
+
node_similarity_search(
|
|
334
|
+
driver,
|
|
335
|
+
query_vector,
|
|
336
|
+
search_filter,
|
|
337
|
+
group_ids,
|
|
338
|
+
2 * limit,
|
|
339
|
+
config.sim_min_score,
|
|
340
|
+
)
|
|
341
|
+
)
|
|
342
|
+
if NodeSearchMethod.bfs in config.search_methods:
|
|
343
|
+
search_tasks.append(
|
|
344
|
+
node_bfs_search(
|
|
345
|
+
driver,
|
|
346
|
+
bfs_origin_node_uuids,
|
|
347
|
+
search_filter,
|
|
348
|
+
config.bfs_max_depth,
|
|
349
|
+
group_ids,
|
|
350
|
+
2 * limit,
|
|
351
|
+
)
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
# Execute only the configured search methods
|
|
355
|
+
search_results: list[list[EntityNode]] = []
|
|
356
|
+
if search_tasks:
|
|
357
|
+
search_results = list(await semaphore_gather(*search_tasks))
|
|
289
358
|
|
|
290
359
|
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
|
291
360
|
origin_node_uuids = [node.uuid for result in search_results for node in result]
|
|
292
361
|
search_results.append(
|
|
293
362
|
await node_bfs_search(
|
|
294
|
-
driver,
|
|
363
|
+
driver,
|
|
364
|
+
origin_node_uuids,
|
|
365
|
+
search_filter,
|
|
366
|
+
config.bfs_max_depth,
|
|
367
|
+
group_ids,
|
|
368
|
+
2 * limit,
|
|
295
369
|
)
|
|
296
370
|
)
|
|
297
371
|
|
|
@@ -299,14 +373,15 @@ async def node_search(
|
|
|
299
373
|
node_uuid_map = {node.uuid: node for result in search_results for node in result}
|
|
300
374
|
|
|
301
375
|
reranked_uuids: list[str] = []
|
|
376
|
+
node_scores: list[float] = []
|
|
302
377
|
if config.reranker == NodeReranker.rrf:
|
|
303
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
378
|
+
reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
304
379
|
elif config.reranker == NodeReranker.mmr:
|
|
305
380
|
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
|
|
306
381
|
driver, list(node_uuid_map.values())
|
|
307
382
|
)
|
|
308
383
|
|
|
309
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
384
|
+
reranked_uuids, node_scores = maximal_marginal_relevance(
|
|
310
385
|
query_vector,
|
|
311
386
|
search_result_uuids_and_vectors,
|
|
312
387
|
config.mmr_lambda,
|
|
@@ -321,23 +396,24 @@ async def node_search(
|
|
|
321
396
|
for name, score in reranked_node_names
|
|
322
397
|
if score >= reranker_min_score
|
|
323
398
|
]
|
|
399
|
+
node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
|
|
324
400
|
elif config.reranker == NodeReranker.episode_mentions:
|
|
325
|
-
reranked_uuids = await episode_mentions_reranker(
|
|
401
|
+
reranked_uuids, node_scores = await episode_mentions_reranker(
|
|
326
402
|
driver, search_result_uuids, min_score=reranker_min_score
|
|
327
403
|
)
|
|
328
404
|
elif config.reranker == NodeReranker.node_distance:
|
|
329
405
|
if center_node_uuid is None:
|
|
330
406
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
331
|
-
reranked_uuids = await node_distance_reranker(
|
|
407
|
+
reranked_uuids, node_scores = await node_distance_reranker(
|
|
332
408
|
driver,
|
|
333
|
-
rrf(search_result_uuids, min_score=reranker_min_score),
|
|
409
|
+
rrf(search_result_uuids, min_score=reranker_min_score)[0],
|
|
334
410
|
center_node_uuid,
|
|
335
411
|
min_score=reranker_min_score,
|
|
336
412
|
)
|
|
337
413
|
|
|
338
414
|
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
|
339
415
|
|
|
340
|
-
return reranked_nodes[:limit]
|
|
416
|
+
return reranked_nodes[:limit], node_scores[:limit]
|
|
341
417
|
|
|
342
418
|
|
|
343
419
|
async def episode_search(
|
|
@@ -350,9 +426,9 @@ async def episode_search(
|
|
|
350
426
|
search_filter: SearchFilters,
|
|
351
427
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
352
428
|
reranker_min_score: float = 0,
|
|
353
|
-
) -> list[EpisodicNode]:
|
|
429
|
+
) -> tuple[list[EpisodicNode], list[float]]:
|
|
354
430
|
if config is None:
|
|
355
|
-
return []
|
|
431
|
+
return [], []
|
|
356
432
|
search_results: list[list[EpisodicNode]] = list(
|
|
357
433
|
await semaphore_gather(
|
|
358
434
|
*[
|
|
@@ -365,12 +441,13 @@ async def episode_search(
|
|
|
365
441
|
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
|
|
366
442
|
|
|
367
443
|
reranked_uuids: list[str] = []
|
|
444
|
+
episode_scores: list[float] = []
|
|
368
445
|
if config.reranker == EpisodeReranker.rrf:
|
|
369
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
446
|
+
reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
370
447
|
|
|
371
448
|
elif config.reranker == EpisodeReranker.cross_encoder:
|
|
372
449
|
# use rrf as a preliminary reranker
|
|
373
|
-
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
450
|
+
rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
374
451
|
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
375
452
|
|
|
376
453
|
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
|
|
@@ -381,10 +458,11 @@ async def episode_search(
|
|
|
381
458
|
for content, score in reranked_contents
|
|
382
459
|
if score >= reranker_min_score
|
|
383
460
|
]
|
|
461
|
+
episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
|
|
384
462
|
|
|
385
463
|
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
|
|
386
464
|
|
|
387
|
-
return reranked_episodes[:limit]
|
|
465
|
+
return reranked_episodes[:limit], episode_scores[:limit]
|
|
388
466
|
|
|
389
467
|
|
|
390
468
|
async def community_search(
|
|
@@ -396,9 +474,9 @@ async def community_search(
|
|
|
396
474
|
config: CommunitySearchConfig | None,
|
|
397
475
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
398
476
|
reranker_min_score: float = 0,
|
|
399
|
-
) -> list[CommunityNode]:
|
|
477
|
+
) -> tuple[list[CommunityNode], list[float]]:
|
|
400
478
|
if config is None:
|
|
401
|
-
return []
|
|
479
|
+
return [], []
|
|
402
480
|
|
|
403
481
|
search_results: list[list[CommunityNode]] = list(
|
|
404
482
|
await semaphore_gather(
|
|
@@ -417,14 +495,15 @@ async def community_search(
|
|
|
417
495
|
}
|
|
418
496
|
|
|
419
497
|
reranked_uuids: list[str] = []
|
|
498
|
+
community_scores: list[float] = []
|
|
420
499
|
if config.reranker == CommunityReranker.rrf:
|
|
421
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
500
|
+
reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
422
501
|
elif config.reranker == CommunityReranker.mmr:
|
|
423
502
|
search_result_uuids_and_vectors = await get_embeddings_for_communities(
|
|
424
503
|
driver, list(community_uuid_map.values())
|
|
425
504
|
)
|
|
426
505
|
|
|
427
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
506
|
+
reranked_uuids, community_scores = maximal_marginal_relevance(
|
|
428
507
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
|
|
429
508
|
)
|
|
430
509
|
elif config.reranker == CommunityReranker.cross_encoder:
|
|
@@ -433,7 +512,8 @@ async def community_search(
|
|
|
433
512
|
reranked_uuids = [
|
|
434
513
|
name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
|
|
435
514
|
]
|
|
515
|
+
community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
|
|
436
516
|
|
|
437
517
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
|
438
518
|
|
|
439
|
-
return reranked_communities[:limit]
|
|
519
|
+
return reranked_communities[:limit], community_scores[:limit]
|
|
graphiti_core/search/search_config.py
CHANGED
|
@@ -119,7 +119,42 @@ class SearchConfig(BaseModel):
|
|
|
119
119
|
|
|
120
120
|
|
|
121
121
|
class SearchResults(BaseModel):
|
|
122
|
-
edges: list[EntityEdge]
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
122
|
+
edges: list[EntityEdge] = Field(default_factory=list)
|
|
123
|
+
edge_reranker_scores: list[float] = Field(default_factory=list)
|
|
124
|
+
nodes: list[EntityNode] = Field(default_factory=list)
|
|
125
|
+
node_reranker_scores: list[float] = Field(default_factory=list)
|
|
126
|
+
episodes: list[EpisodicNode] = Field(default_factory=list)
|
|
127
|
+
episode_reranker_scores: list[float] = Field(default_factory=list)
|
|
128
|
+
communities: list[CommunityNode] = Field(default_factory=list)
|
|
129
|
+
community_reranker_scores: list[float] = Field(default_factory=list)
|
|
130
|
+
|
|
131
|
+
@classmethod
|
|
132
|
+
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
|
|
133
|
+
"""
|
|
134
|
+
Merge multiple SearchResults objects into a single SearchResults object.
|
|
135
|
+
|
|
136
|
+
Parameters
|
|
137
|
+
----------
|
|
138
|
+
results_list : list[SearchResults]
|
|
139
|
+
List of SearchResults objects to merge
|
|
140
|
+
|
|
141
|
+
Returns
|
|
142
|
+
-------
|
|
143
|
+
SearchResults
|
|
144
|
+
A single SearchResults object containing all results
|
|
145
|
+
"""
|
|
146
|
+
if not results_list:
|
|
147
|
+
return cls()
|
|
148
|
+
|
|
149
|
+
merged = cls()
|
|
150
|
+
for result in results_list:
|
|
151
|
+
merged.edges.extend(result.edges)
|
|
152
|
+
merged.edge_reranker_scores.extend(result.edge_reranker_scores)
|
|
153
|
+
merged.nodes.extend(result.nodes)
|
|
154
|
+
merged.node_reranker_scores.extend(result.node_reranker_scores)
|
|
155
|
+
merged.episodes.extend(result.episodes)
|
|
156
|
+
merged.episode_reranker_scores.extend(result.episode_reranker_scores)
|
|
157
|
+
merged.communities.extend(result.communities)
|
|
158
|
+
merged.community_reranker_scores.extend(result.community_reranker_scores)
|
|
159
|
+
|
|
160
|
+
return merged
|