graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +70 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +635 -260
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +37 -15
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +92 -48
- graphiti_core/llm_client/openai_client.py +39 -9
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +24 -15
- graphiti_core/prompts/extract_nodes.py +76 -35
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +110 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/content_chunking.py +702 -0
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +306 -156
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +466 -206
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
- graphiti_core-0.25.3.dist-info/RECORD +87 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/search/search.py
CHANGED
|
@@ -21,6 +21,7 @@ from time import time
|
|
|
21
21
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
22
22
|
from graphiti_core.driver.driver import GraphDriver
|
|
23
23
|
from graphiti_core.edges import EntityEdge
|
|
24
|
+
from graphiti_core.embedder.client import EMBEDDING_DIM
|
|
24
25
|
from graphiti_core.errors import SearchRerankerError
|
|
25
26
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
26
27
|
from graphiti_core.helpers import semaphore_gather
|
|
@@ -29,6 +30,7 @@ from graphiti_core.search.search_config import (
|
|
|
29
30
|
DEFAULT_SEARCH_LIMIT,
|
|
30
31
|
CommunityReranker,
|
|
31
32
|
CommunitySearchConfig,
|
|
33
|
+
CommunitySearchMethod,
|
|
32
34
|
EdgeReranker,
|
|
33
35
|
EdgeSearchConfig,
|
|
34
36
|
EdgeSearchMethod,
|
|
@@ -72,34 +74,53 @@ async def search(
|
|
|
72
74
|
center_node_uuid: str | None = None,
|
|
73
75
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
74
76
|
query_vector: list[float] | None = None,
|
|
77
|
+
driver: GraphDriver | None = None,
|
|
75
78
|
) -> SearchResults:
|
|
76
79
|
start = time()
|
|
77
80
|
|
|
78
|
-
driver = clients.driver
|
|
81
|
+
driver = driver or clients.driver
|
|
79
82
|
embedder = clients.embedder
|
|
80
83
|
cross_encoder = clients.cross_encoder
|
|
81
84
|
|
|
82
85
|
if query.strip() == '':
|
|
83
|
-
return SearchResults(
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
86
|
+
return SearchResults()
|
|
87
|
+
|
|
88
|
+
if (
|
|
89
|
+
config.edge_config
|
|
90
|
+
and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
|
|
91
|
+
or config.edge_config
|
|
92
|
+
and EdgeReranker.mmr == config.edge_config.reranker
|
|
93
|
+
or config.node_config
|
|
94
|
+
and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
|
|
95
|
+
or config.node_config
|
|
96
|
+
and NodeReranker.mmr == config.node_config.reranker
|
|
97
|
+
or (
|
|
98
|
+
config.community_config
|
|
99
|
+
and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
|
|
88
100
|
)
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
101
|
+
or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
|
|
102
|
+
):
|
|
103
|
+
search_vector = (
|
|
104
|
+
query_vector
|
|
105
|
+
if query_vector is not None
|
|
106
|
+
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
|
107
|
+
)
|
|
108
|
+
else:
|
|
109
|
+
search_vector = [0.0] * EMBEDDING_DIM
|
|
94
110
|
|
|
95
111
|
# if group_ids is empty, set it to None
|
|
96
112
|
group_ids = group_ids if group_ids and group_ids != [''] else None
|
|
97
|
-
|
|
113
|
+
(
|
|
114
|
+
(edges, edge_reranker_scores),
|
|
115
|
+
(nodes, node_reranker_scores),
|
|
116
|
+
(episodes, episode_reranker_scores),
|
|
117
|
+
(communities, community_reranker_scores),
|
|
118
|
+
) = await semaphore_gather(
|
|
98
119
|
edge_search(
|
|
99
120
|
driver,
|
|
100
121
|
cross_encoder,
|
|
101
122
|
query,
|
|
102
|
-
|
|
123
|
+
search_vector,
|
|
103
124
|
group_ids,
|
|
104
125
|
config.edge_config,
|
|
105
126
|
search_filter,
|
|
@@ -112,7 +133,7 @@ async def search(
|
|
|
112
133
|
driver,
|
|
113
134
|
cross_encoder,
|
|
114
135
|
query,
|
|
115
|
-
|
|
136
|
+
search_vector,
|
|
116
137
|
group_ids,
|
|
117
138
|
config.node_config,
|
|
118
139
|
search_filter,
|
|
@@ -125,7 +146,7 @@ async def search(
|
|
|
125
146
|
driver,
|
|
126
147
|
cross_encoder,
|
|
127
148
|
query,
|
|
128
|
-
|
|
149
|
+
search_vector,
|
|
129
150
|
group_ids,
|
|
130
151
|
config.episode_config,
|
|
131
152
|
search_filter,
|
|
@@ -136,7 +157,7 @@ async def search(
|
|
|
136
157
|
driver,
|
|
137
158
|
cross_encoder,
|
|
138
159
|
query,
|
|
139
|
-
|
|
160
|
+
search_vector,
|
|
140
161
|
group_ids,
|
|
141
162
|
config.community_config,
|
|
142
163
|
config.limit,
|
|
@@ -146,9 +167,13 @@ async def search(
|
|
|
146
167
|
|
|
147
168
|
results = SearchResults(
|
|
148
169
|
edges=edges,
|
|
170
|
+
edge_reranker_scores=edge_reranker_scores,
|
|
149
171
|
nodes=nodes,
|
|
172
|
+
node_reranker_scores=node_reranker_scores,
|
|
150
173
|
episodes=episodes,
|
|
174
|
+
episode_reranker_scores=episode_reranker_scores,
|
|
151
175
|
communities=communities,
|
|
176
|
+
community_reranker_scores=community_reranker_scores,
|
|
152
177
|
)
|
|
153
178
|
|
|
154
179
|
latency = (time() - start) * 1000
|
|
@@ -170,50 +195,72 @@ async def edge_search(
|
|
|
170
195
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
171
196
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
172
197
|
reranker_min_score: float = 0,
|
|
173
|
-
) -> list[EntityEdge]:
|
|
198
|
+
) -> tuple[list[EntityEdge], list[float]]:
|
|
174
199
|
if config is None:
|
|
175
|
-
return []
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
200
|
+
return [], []
|
|
201
|
+
|
|
202
|
+
# Build search tasks based on configured search methods
|
|
203
|
+
search_tasks = []
|
|
204
|
+
if EdgeSearchMethod.bm25 in config.search_methods:
|
|
205
|
+
search_tasks.append(
|
|
206
|
+
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
|
207
|
+
)
|
|
208
|
+
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
|
209
|
+
search_tasks.append(
|
|
210
|
+
edge_similarity_search(
|
|
211
|
+
driver,
|
|
212
|
+
query_vector,
|
|
213
|
+
None,
|
|
214
|
+
None,
|
|
215
|
+
search_filter,
|
|
216
|
+
group_ids,
|
|
217
|
+
2 * limit,
|
|
218
|
+
config.sim_min_score,
|
|
219
|
+
)
|
|
220
|
+
)
|
|
221
|
+
if EdgeSearchMethod.bfs in config.search_methods:
|
|
222
|
+
search_tasks.append(
|
|
223
|
+
edge_bfs_search(
|
|
224
|
+
driver,
|
|
225
|
+
bfs_origin_node_uuids,
|
|
226
|
+
config.bfs_max_depth,
|
|
227
|
+
search_filter,
|
|
228
|
+
group_ids,
|
|
229
|
+
2 * limit,
|
|
230
|
+
)
|
|
194
231
|
)
|
|
195
|
-
|
|
232
|
+
|
|
233
|
+
# Execute only the configured search methods
|
|
234
|
+
search_results: list[list[EntityEdge]] = []
|
|
235
|
+
if search_tasks:
|
|
236
|
+
search_results = list(await semaphore_gather(*search_tasks))
|
|
196
237
|
|
|
197
238
|
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
|
198
239
|
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
|
|
199
240
|
search_results.append(
|
|
200
241
|
await edge_bfs_search(
|
|
201
|
-
driver,
|
|
242
|
+
driver,
|
|
243
|
+
source_node_uuids,
|
|
244
|
+
config.bfs_max_depth,
|
|
245
|
+
search_filter,
|
|
246
|
+
group_ids,
|
|
247
|
+
2 * limit,
|
|
202
248
|
)
|
|
203
249
|
)
|
|
204
250
|
|
|
205
251
|
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
|
206
252
|
|
|
207
253
|
reranked_uuids: list[str] = []
|
|
254
|
+
edge_scores: list[float] = []
|
|
208
255
|
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
|
|
209
256
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
210
257
|
|
|
211
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
258
|
+
reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
212
259
|
elif config.reranker == EdgeReranker.mmr:
|
|
213
260
|
search_result_uuids_and_vectors = await get_embeddings_for_edges(
|
|
214
261
|
driver, list(edge_uuid_map.values())
|
|
215
262
|
)
|
|
216
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
263
|
+
reranked_uuids, edge_scores = maximal_marginal_relevance(
|
|
217
264
|
query_vector,
|
|
218
265
|
search_result_uuids_and_vectors,
|
|
219
266
|
config.mmr_lambda,
|
|
@@ -225,12 +272,13 @@ async def edge_search(
|
|
|
225
272
|
reranked_uuids = [
|
|
226
273
|
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
|
|
227
274
|
]
|
|
275
|
+
edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
|
|
228
276
|
elif config.reranker == EdgeReranker.node_distance:
|
|
229
277
|
if center_node_uuid is None:
|
|
230
278
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
231
279
|
|
|
232
280
|
# use rrf as a preliminary sort
|
|
233
|
-
sorted_result_uuids = rrf(
|
|
281
|
+
sorted_result_uuids, node_scores = rrf(
|
|
234
282
|
[[edge.uuid for edge in result] for result in search_results],
|
|
235
283
|
min_score=reranker_min_score,
|
|
236
284
|
)
|
|
@@ -243,7 +291,7 @@ async def edge_search(
|
|
|
243
291
|
|
|
244
292
|
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
|
|
245
293
|
|
|
246
|
-
reranked_node_uuids = await node_distance_reranker(
|
|
294
|
+
reranked_node_uuids, edge_scores = await node_distance_reranker(
|
|
247
295
|
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
|
|
248
296
|
)
|
|
249
297
|
|
|
@@ -255,7 +303,7 @@ async def edge_search(
|
|
|
255
303
|
if config.reranker == EdgeReranker.episode_mentions:
|
|
256
304
|
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
|
|
257
305
|
|
|
258
|
-
return reranked_edges[:limit]
|
|
306
|
+
return reranked_edges[:limit], edge_scores[:limit]
|
|
259
307
|
|
|
260
308
|
|
|
261
309
|
async def node_search(
|
|
@@ -270,28 +318,54 @@ async def node_search(
|
|
|
270
318
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
271
319
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
272
320
|
reranker_min_score: float = 0,
|
|
273
|
-
) -> list[EntityNode]:
|
|
321
|
+
) -> tuple[list[EntityNode], list[float]]:
|
|
274
322
|
if config is None:
|
|
275
|
-
return []
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
),
|
|
283
|
-
node_bfs_search(
|
|
284
|
-
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
|
|
285
|
-
),
|
|
286
|
-
]
|
|
323
|
+
return [], []
|
|
324
|
+
|
|
325
|
+
# Build search tasks based on configured search methods
|
|
326
|
+
search_tasks = []
|
|
327
|
+
if NodeSearchMethod.bm25 in config.search_methods:
|
|
328
|
+
search_tasks.append(
|
|
329
|
+
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
|
287
330
|
)
|
|
288
|
-
|
|
331
|
+
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
|
332
|
+
search_tasks.append(
|
|
333
|
+
node_similarity_search(
|
|
334
|
+
driver,
|
|
335
|
+
query_vector,
|
|
336
|
+
search_filter,
|
|
337
|
+
group_ids,
|
|
338
|
+
2 * limit,
|
|
339
|
+
config.sim_min_score,
|
|
340
|
+
)
|
|
341
|
+
)
|
|
342
|
+
if NodeSearchMethod.bfs in config.search_methods:
|
|
343
|
+
search_tasks.append(
|
|
344
|
+
node_bfs_search(
|
|
345
|
+
driver,
|
|
346
|
+
bfs_origin_node_uuids,
|
|
347
|
+
search_filter,
|
|
348
|
+
config.bfs_max_depth,
|
|
349
|
+
group_ids,
|
|
350
|
+
2 * limit,
|
|
351
|
+
)
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
# Execute only the configured search methods
|
|
355
|
+
search_results: list[list[EntityNode]] = []
|
|
356
|
+
if search_tasks:
|
|
357
|
+
search_results = list(await semaphore_gather(*search_tasks))
|
|
289
358
|
|
|
290
359
|
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
|
291
360
|
origin_node_uuids = [node.uuid for result in search_results for node in result]
|
|
292
361
|
search_results.append(
|
|
293
362
|
await node_bfs_search(
|
|
294
|
-
driver,
|
|
363
|
+
driver,
|
|
364
|
+
origin_node_uuids,
|
|
365
|
+
search_filter,
|
|
366
|
+
config.bfs_max_depth,
|
|
367
|
+
group_ids,
|
|
368
|
+
2 * limit,
|
|
295
369
|
)
|
|
296
370
|
)
|
|
297
371
|
|
|
@@ -299,14 +373,15 @@ async def node_search(
|
|
|
299
373
|
node_uuid_map = {node.uuid: node for result in search_results for node in result}
|
|
300
374
|
|
|
301
375
|
reranked_uuids: list[str] = []
|
|
376
|
+
node_scores: list[float] = []
|
|
302
377
|
if config.reranker == NodeReranker.rrf:
|
|
303
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
378
|
+
reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
304
379
|
elif config.reranker == NodeReranker.mmr:
|
|
305
380
|
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
|
|
306
381
|
driver, list(node_uuid_map.values())
|
|
307
382
|
)
|
|
308
383
|
|
|
309
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
384
|
+
reranked_uuids, node_scores = maximal_marginal_relevance(
|
|
310
385
|
query_vector,
|
|
311
386
|
search_result_uuids_and_vectors,
|
|
312
387
|
config.mmr_lambda,
|
|
@@ -321,23 +396,24 @@ async def node_search(
|
|
|
321
396
|
for name, score in reranked_node_names
|
|
322
397
|
if score >= reranker_min_score
|
|
323
398
|
]
|
|
399
|
+
node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
|
|
324
400
|
elif config.reranker == NodeReranker.episode_mentions:
|
|
325
|
-
reranked_uuids = await episode_mentions_reranker(
|
|
401
|
+
reranked_uuids, node_scores = await episode_mentions_reranker(
|
|
326
402
|
driver, search_result_uuids, min_score=reranker_min_score
|
|
327
403
|
)
|
|
328
404
|
elif config.reranker == NodeReranker.node_distance:
|
|
329
405
|
if center_node_uuid is None:
|
|
330
406
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
331
|
-
reranked_uuids = await node_distance_reranker(
|
|
407
|
+
reranked_uuids, node_scores = await node_distance_reranker(
|
|
332
408
|
driver,
|
|
333
|
-
rrf(search_result_uuids, min_score=reranker_min_score),
|
|
409
|
+
rrf(search_result_uuids, min_score=reranker_min_score)[0],
|
|
334
410
|
center_node_uuid,
|
|
335
411
|
min_score=reranker_min_score,
|
|
336
412
|
)
|
|
337
413
|
|
|
338
414
|
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
|
339
415
|
|
|
340
|
-
return reranked_nodes[:limit]
|
|
416
|
+
return reranked_nodes[:limit], node_scores[:limit]
|
|
341
417
|
|
|
342
418
|
|
|
343
419
|
async def episode_search(
|
|
@@ -350,9 +426,9 @@ async def episode_search(
|
|
|
350
426
|
search_filter: SearchFilters,
|
|
351
427
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
352
428
|
reranker_min_score: float = 0,
|
|
353
|
-
) -> list[EpisodicNode]:
|
|
429
|
+
) -> tuple[list[EpisodicNode], list[float]]:
|
|
354
430
|
if config is None:
|
|
355
|
-
return []
|
|
431
|
+
return [], []
|
|
356
432
|
search_results: list[list[EpisodicNode]] = list(
|
|
357
433
|
await semaphore_gather(
|
|
358
434
|
*[
|
|
@@ -365,12 +441,13 @@ async def episode_search(
|
|
|
365
441
|
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
|
|
366
442
|
|
|
367
443
|
reranked_uuids: list[str] = []
|
|
444
|
+
episode_scores: list[float] = []
|
|
368
445
|
if config.reranker == EpisodeReranker.rrf:
|
|
369
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
446
|
+
reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
370
447
|
|
|
371
448
|
elif config.reranker == EpisodeReranker.cross_encoder:
|
|
372
449
|
# use rrf as a preliminary reranker
|
|
373
|
-
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
450
|
+
rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
374
451
|
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
375
452
|
|
|
376
453
|
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
|
|
@@ -381,10 +458,11 @@ async def episode_search(
|
|
|
381
458
|
for content, score in reranked_contents
|
|
382
459
|
if score >= reranker_min_score
|
|
383
460
|
]
|
|
461
|
+
episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
|
|
384
462
|
|
|
385
463
|
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
|
|
386
464
|
|
|
387
|
-
return reranked_episodes[:limit]
|
|
465
|
+
return reranked_episodes[:limit], episode_scores[:limit]
|
|
388
466
|
|
|
389
467
|
|
|
390
468
|
async def community_search(
|
|
@@ -396,9 +474,9 @@ async def community_search(
|
|
|
396
474
|
config: CommunitySearchConfig | None,
|
|
397
475
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
398
476
|
reranker_min_score: float = 0,
|
|
399
|
-
) -> list[CommunityNode]:
|
|
477
|
+
) -> tuple[list[CommunityNode], list[float]]:
|
|
400
478
|
if config is None:
|
|
401
|
-
return []
|
|
479
|
+
return [], []
|
|
402
480
|
|
|
403
481
|
search_results: list[list[CommunityNode]] = list(
|
|
404
482
|
await semaphore_gather(
|
|
@@ -417,14 +495,15 @@ async def community_search(
|
|
|
417
495
|
}
|
|
418
496
|
|
|
419
497
|
reranked_uuids: list[str] = []
|
|
498
|
+
community_scores: list[float] = []
|
|
420
499
|
if config.reranker == CommunityReranker.rrf:
|
|
421
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
500
|
+
reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
422
501
|
elif config.reranker == CommunityReranker.mmr:
|
|
423
502
|
search_result_uuids_and_vectors = await get_embeddings_for_communities(
|
|
424
503
|
driver, list(community_uuid_map.values())
|
|
425
504
|
)
|
|
426
505
|
|
|
427
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
506
|
+
reranked_uuids, community_scores = maximal_marginal_relevance(
|
|
428
507
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
|
|
429
508
|
)
|
|
430
509
|
elif config.reranker == CommunityReranker.cross_encoder:
|
|
@@ -433,7 +512,8 @@ async def community_search(
|
|
|
433
512
|
reranked_uuids = [
|
|
434
513
|
name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
|
|
435
514
|
]
|
|
515
|
+
community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
|
|
436
516
|
|
|
437
517
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
|
438
518
|
|
|
439
|
-
return reranked_communities[:limit]
|
|
519
|
+
return reranked_communities[:limit], community_scores[:limit]
|
|
@@ -119,7 +119,42 @@ class SearchConfig(BaseModel):
|
|
|
119
119
|
|
|
120
120
|
|
|
121
121
|
class SearchResults(BaseModel):
|
|
122
|
-
edges: list[EntityEdge]
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
122
|
+
edges: list[EntityEdge] = Field(default_factory=list)
|
|
123
|
+
edge_reranker_scores: list[float] = Field(default_factory=list)
|
|
124
|
+
nodes: list[EntityNode] = Field(default_factory=list)
|
|
125
|
+
node_reranker_scores: list[float] = Field(default_factory=list)
|
|
126
|
+
episodes: list[EpisodicNode] = Field(default_factory=list)
|
|
127
|
+
episode_reranker_scores: list[float] = Field(default_factory=list)
|
|
128
|
+
communities: list[CommunityNode] = Field(default_factory=list)
|
|
129
|
+
community_reranker_scores: list[float] = Field(default_factory=list)
|
|
130
|
+
|
|
131
|
+
@classmethod
|
|
132
|
+
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
|
|
133
|
+
"""
|
|
134
|
+
Merge multiple SearchResults objects into a single SearchResults object.
|
|
135
|
+
|
|
136
|
+
Parameters
|
|
137
|
+
----------
|
|
138
|
+
results_list : list[SearchResults]
|
|
139
|
+
List of SearchResults objects to merge
|
|
140
|
+
|
|
141
|
+
Returns
|
|
142
|
+
-------
|
|
143
|
+
SearchResults
|
|
144
|
+
A single SearchResults object containing all results
|
|
145
|
+
"""
|
|
146
|
+
if not results_list:
|
|
147
|
+
return cls()
|
|
148
|
+
|
|
149
|
+
merged = cls()
|
|
150
|
+
for result in results_list:
|
|
151
|
+
merged.edges.extend(result.edges)
|
|
152
|
+
merged.edge_reranker_scores.extend(result.edge_reranker_scores)
|
|
153
|
+
merged.nodes.extend(result.nodes)
|
|
154
|
+
merged.node_reranker_scores.extend(result.node_reranker_scores)
|
|
155
|
+
merged.episodes.extend(result.episodes)
|
|
156
|
+
merged.episode_reranker_scores.extend(result.episode_reranker_scores)
|
|
157
|
+
merged.communities.extend(result.communities)
|
|
158
|
+
merged.community_reranker_scores.extend(result.community_reranker_scores)
|
|
159
|
+
|
|
160
|
+
return merged
|