graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/search/search.py
CHANGED
|
@@ -18,10 +18,10 @@ import logging
|
|
|
18
18
|
from collections import defaultdict
|
|
19
19
|
from time import time
|
|
20
20
|
|
|
21
|
-
from neo4j import AsyncDriver
|
|
22
|
-
|
|
23
21
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
22
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
24
23
|
from graphiti_core.edges import EntityEdge
|
|
24
|
+
from graphiti_core.embedder.client import EMBEDDING_DIM
|
|
25
25
|
from graphiti_core.errors import SearchRerankerError
|
|
26
26
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
27
27
|
from graphiti_core.helpers import semaphore_gather
|
|
@@ -30,6 +30,7 @@ from graphiti_core.search.search_config import (
|
|
|
30
30
|
DEFAULT_SEARCH_LIMIT,
|
|
31
31
|
CommunityReranker,
|
|
32
32
|
CommunitySearchConfig,
|
|
33
|
+
CommunitySearchMethod,
|
|
33
34
|
EdgeReranker,
|
|
34
35
|
EdgeSearchConfig,
|
|
35
36
|
EdgeSearchMethod,
|
|
@@ -73,34 +74,53 @@ async def search(
|
|
|
73
74
|
center_node_uuid: str | None = None,
|
|
74
75
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
75
76
|
query_vector: list[float] | None = None,
|
|
77
|
+
driver: GraphDriver | None = None,
|
|
76
78
|
) -> SearchResults:
|
|
77
79
|
start = time()
|
|
78
80
|
|
|
79
|
-
driver = clients.driver
|
|
81
|
+
driver = driver or clients.driver
|
|
80
82
|
embedder = clients.embedder
|
|
81
83
|
cross_encoder = clients.cross_encoder
|
|
82
84
|
|
|
83
85
|
if query.strip() == '':
|
|
84
|
-
return SearchResults(
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
86
|
+
return SearchResults()
|
|
87
|
+
|
|
88
|
+
if (
|
|
89
|
+
config.edge_config
|
|
90
|
+
and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
|
|
91
|
+
or config.edge_config
|
|
92
|
+
and EdgeReranker.mmr == config.edge_config.reranker
|
|
93
|
+
or config.node_config
|
|
94
|
+
and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
|
|
95
|
+
or config.node_config
|
|
96
|
+
and NodeReranker.mmr == config.node_config.reranker
|
|
97
|
+
or (
|
|
98
|
+
config.community_config
|
|
99
|
+
and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
|
|
89
100
|
)
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
101
|
+
or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
|
|
102
|
+
):
|
|
103
|
+
search_vector = (
|
|
104
|
+
query_vector
|
|
105
|
+
if query_vector is not None
|
|
106
|
+
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
|
107
|
+
)
|
|
108
|
+
else:
|
|
109
|
+
search_vector = [0.0] * EMBEDDING_DIM
|
|
95
110
|
|
|
96
111
|
# if group_ids is empty, set it to None
|
|
97
|
-
group_ids = group_ids if group_ids else None
|
|
98
|
-
|
|
112
|
+
group_ids = group_ids if group_ids and group_ids != [''] else None
|
|
113
|
+
(
|
|
114
|
+
(edges, edge_reranker_scores),
|
|
115
|
+
(nodes, node_reranker_scores),
|
|
116
|
+
(episodes, episode_reranker_scores),
|
|
117
|
+
(communities, community_reranker_scores),
|
|
118
|
+
) = await semaphore_gather(
|
|
99
119
|
edge_search(
|
|
100
120
|
driver,
|
|
101
121
|
cross_encoder,
|
|
102
122
|
query,
|
|
103
|
-
|
|
123
|
+
search_vector,
|
|
104
124
|
group_ids,
|
|
105
125
|
config.edge_config,
|
|
106
126
|
search_filter,
|
|
@@ -113,7 +133,7 @@ async def search(
|
|
|
113
133
|
driver,
|
|
114
134
|
cross_encoder,
|
|
115
135
|
query,
|
|
116
|
-
|
|
136
|
+
search_vector,
|
|
117
137
|
group_ids,
|
|
118
138
|
config.node_config,
|
|
119
139
|
search_filter,
|
|
@@ -126,7 +146,7 @@ async def search(
|
|
|
126
146
|
driver,
|
|
127
147
|
cross_encoder,
|
|
128
148
|
query,
|
|
129
|
-
|
|
149
|
+
search_vector,
|
|
130
150
|
group_ids,
|
|
131
151
|
config.episode_config,
|
|
132
152
|
search_filter,
|
|
@@ -137,7 +157,7 @@ async def search(
|
|
|
137
157
|
driver,
|
|
138
158
|
cross_encoder,
|
|
139
159
|
query,
|
|
140
|
-
|
|
160
|
+
search_vector,
|
|
141
161
|
group_ids,
|
|
142
162
|
config.community_config,
|
|
143
163
|
config.limit,
|
|
@@ -147,9 +167,13 @@ async def search(
|
|
|
147
167
|
|
|
148
168
|
results = SearchResults(
|
|
149
169
|
edges=edges,
|
|
170
|
+
edge_reranker_scores=edge_reranker_scores,
|
|
150
171
|
nodes=nodes,
|
|
172
|
+
node_reranker_scores=node_reranker_scores,
|
|
151
173
|
episodes=episodes,
|
|
174
|
+
episode_reranker_scores=episode_reranker_scores,
|
|
152
175
|
communities=communities,
|
|
176
|
+
community_reranker_scores=community_reranker_scores,
|
|
153
177
|
)
|
|
154
178
|
|
|
155
179
|
latency = (time() - start) * 1000
|
|
@@ -160,7 +184,7 @@ async def search(
|
|
|
160
184
|
|
|
161
185
|
|
|
162
186
|
async def edge_search(
|
|
163
|
-
driver:
|
|
187
|
+
driver: GraphDriver,
|
|
164
188
|
cross_encoder: CrossEncoderClient,
|
|
165
189
|
query: str,
|
|
166
190
|
query_vector: list[float],
|
|
@@ -171,51 +195,72 @@ async def edge_search(
|
|
|
171
195
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
172
196
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
173
197
|
reranker_min_score: float = 0,
|
|
174
|
-
) -> list[EntityEdge]:
|
|
198
|
+
) -> tuple[list[EntityEdge], list[float]]:
|
|
175
199
|
if config is None:
|
|
176
|
-
return []
|
|
200
|
+
return [], []
|
|
177
201
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
202
|
+
# Build search tasks based on configured search methods
|
|
203
|
+
search_tasks = []
|
|
204
|
+
if EdgeSearchMethod.bm25 in config.search_methods:
|
|
205
|
+
search_tasks.append(
|
|
206
|
+
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
|
207
|
+
)
|
|
208
|
+
if EdgeSearchMethod.cosine_similarity in config.search_methods:
|
|
209
|
+
search_tasks.append(
|
|
210
|
+
edge_similarity_search(
|
|
211
|
+
driver,
|
|
212
|
+
query_vector,
|
|
213
|
+
None,
|
|
214
|
+
None,
|
|
215
|
+
search_filter,
|
|
216
|
+
group_ids,
|
|
217
|
+
2 * limit,
|
|
218
|
+
config.sim_min_score,
|
|
219
|
+
)
|
|
220
|
+
)
|
|
221
|
+
if EdgeSearchMethod.bfs in config.search_methods:
|
|
222
|
+
search_tasks.append(
|
|
223
|
+
edge_bfs_search(
|
|
224
|
+
driver,
|
|
225
|
+
bfs_origin_node_uuids,
|
|
226
|
+
config.bfs_max_depth,
|
|
227
|
+
search_filter,
|
|
228
|
+
group_ids,
|
|
229
|
+
2 * limit,
|
|
230
|
+
)
|
|
196
231
|
)
|
|
197
|
-
|
|
232
|
+
|
|
233
|
+
# Execute only the configured search methods
|
|
234
|
+
search_results: list[list[EntityEdge]] = []
|
|
235
|
+
if search_tasks:
|
|
236
|
+
search_results = list(await semaphore_gather(*search_tasks))
|
|
198
237
|
|
|
199
238
|
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
|
200
239
|
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
|
|
201
240
|
search_results.append(
|
|
202
241
|
await edge_bfs_search(
|
|
203
|
-
driver,
|
|
242
|
+
driver,
|
|
243
|
+
source_node_uuids,
|
|
244
|
+
config.bfs_max_depth,
|
|
245
|
+
search_filter,
|
|
246
|
+
group_ids,
|
|
247
|
+
2 * limit,
|
|
204
248
|
)
|
|
205
249
|
)
|
|
206
250
|
|
|
207
251
|
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
|
208
252
|
|
|
209
253
|
reranked_uuids: list[str] = []
|
|
254
|
+
edge_scores: list[float] = []
|
|
210
255
|
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
|
|
211
256
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
212
257
|
|
|
213
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
258
|
+
reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
214
259
|
elif config.reranker == EdgeReranker.mmr:
|
|
215
260
|
search_result_uuids_and_vectors = await get_embeddings_for_edges(
|
|
216
261
|
driver, list(edge_uuid_map.values())
|
|
217
262
|
)
|
|
218
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
263
|
+
reranked_uuids, edge_scores = maximal_marginal_relevance(
|
|
219
264
|
query_vector,
|
|
220
265
|
search_result_uuids_and_vectors,
|
|
221
266
|
config.mmr_lambda,
|
|
@@ -227,12 +272,13 @@ async def edge_search(
|
|
|
227
272
|
reranked_uuids = [
|
|
228
273
|
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
|
|
229
274
|
]
|
|
275
|
+
edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
|
|
230
276
|
elif config.reranker == EdgeReranker.node_distance:
|
|
231
277
|
if center_node_uuid is None:
|
|
232
278
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
233
279
|
|
|
234
280
|
# use rrf as a preliminary sort
|
|
235
|
-
sorted_result_uuids = rrf(
|
|
281
|
+
sorted_result_uuids, node_scores = rrf(
|
|
236
282
|
[[edge.uuid for edge in result] for result in search_results],
|
|
237
283
|
min_score=reranker_min_score,
|
|
238
284
|
)
|
|
@@ -245,7 +291,7 @@ async def edge_search(
|
|
|
245
291
|
|
|
246
292
|
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
|
|
247
293
|
|
|
248
|
-
reranked_node_uuids = await node_distance_reranker(
|
|
294
|
+
reranked_node_uuids, edge_scores = await node_distance_reranker(
|
|
249
295
|
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
|
|
250
296
|
)
|
|
251
297
|
|
|
@@ -257,11 +303,11 @@ async def edge_search(
|
|
|
257
303
|
if config.reranker == EdgeReranker.episode_mentions:
|
|
258
304
|
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
|
|
259
305
|
|
|
260
|
-
return reranked_edges[:limit]
|
|
306
|
+
return reranked_edges[:limit], edge_scores[:limit]
|
|
261
307
|
|
|
262
308
|
|
|
263
309
|
async def node_search(
|
|
264
|
-
driver:
|
|
310
|
+
driver: GraphDriver,
|
|
265
311
|
cross_encoder: CrossEncoderClient,
|
|
266
312
|
query: str,
|
|
267
313
|
query_vector: list[float],
|
|
@@ -272,29 +318,54 @@ async def node_search(
|
|
|
272
318
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
273
319
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
274
320
|
reranker_min_score: float = 0,
|
|
275
|
-
) -> list[EntityNode]:
|
|
321
|
+
) -> tuple[list[EntityNode], list[float]]:
|
|
276
322
|
if config is None:
|
|
277
|
-
return []
|
|
323
|
+
return [], []
|
|
278
324
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
|
|
285
|
-
),
|
|
286
|
-
node_bfs_search(
|
|
287
|
-
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
|
|
288
|
-
),
|
|
289
|
-
]
|
|
325
|
+
# Build search tasks based on configured search methods
|
|
326
|
+
search_tasks = []
|
|
327
|
+
if NodeSearchMethod.bm25 in config.search_methods:
|
|
328
|
+
search_tasks.append(
|
|
329
|
+
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
|
|
290
330
|
)
|
|
291
|
-
|
|
331
|
+
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
|
332
|
+
search_tasks.append(
|
|
333
|
+
node_similarity_search(
|
|
334
|
+
driver,
|
|
335
|
+
query_vector,
|
|
336
|
+
search_filter,
|
|
337
|
+
group_ids,
|
|
338
|
+
2 * limit,
|
|
339
|
+
config.sim_min_score,
|
|
340
|
+
)
|
|
341
|
+
)
|
|
342
|
+
if NodeSearchMethod.bfs in config.search_methods:
|
|
343
|
+
search_tasks.append(
|
|
344
|
+
node_bfs_search(
|
|
345
|
+
driver,
|
|
346
|
+
bfs_origin_node_uuids,
|
|
347
|
+
search_filter,
|
|
348
|
+
config.bfs_max_depth,
|
|
349
|
+
group_ids,
|
|
350
|
+
2 * limit,
|
|
351
|
+
)
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
# Execute only the configured search methods
|
|
355
|
+
search_results: list[list[EntityNode]] = []
|
|
356
|
+
if search_tasks:
|
|
357
|
+
search_results = list(await semaphore_gather(*search_tasks))
|
|
292
358
|
|
|
293
359
|
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
|
|
294
360
|
origin_node_uuids = [node.uuid for result in search_results for node in result]
|
|
295
361
|
search_results.append(
|
|
296
362
|
await node_bfs_search(
|
|
297
|
-
driver,
|
|
363
|
+
driver,
|
|
364
|
+
origin_node_uuids,
|
|
365
|
+
search_filter,
|
|
366
|
+
config.bfs_max_depth,
|
|
367
|
+
group_ids,
|
|
368
|
+
2 * limit,
|
|
298
369
|
)
|
|
299
370
|
)
|
|
300
371
|
|
|
@@ -302,14 +373,15 @@ async def node_search(
|
|
|
302
373
|
node_uuid_map = {node.uuid: node for result in search_results for node in result}
|
|
303
374
|
|
|
304
375
|
reranked_uuids: list[str] = []
|
|
376
|
+
node_scores: list[float] = []
|
|
305
377
|
if config.reranker == NodeReranker.rrf:
|
|
306
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
378
|
+
reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
307
379
|
elif config.reranker == NodeReranker.mmr:
|
|
308
380
|
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
|
|
309
381
|
driver, list(node_uuid_map.values())
|
|
310
382
|
)
|
|
311
383
|
|
|
312
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
384
|
+
reranked_uuids, node_scores = maximal_marginal_relevance(
|
|
313
385
|
query_vector,
|
|
314
386
|
search_result_uuids_and_vectors,
|
|
315
387
|
config.mmr_lambda,
|
|
@@ -324,27 +396,28 @@ async def node_search(
|
|
|
324
396
|
for name, score in reranked_node_names
|
|
325
397
|
if score >= reranker_min_score
|
|
326
398
|
]
|
|
399
|
+
node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
|
|
327
400
|
elif config.reranker == NodeReranker.episode_mentions:
|
|
328
|
-
reranked_uuids = await episode_mentions_reranker(
|
|
401
|
+
reranked_uuids, node_scores = await episode_mentions_reranker(
|
|
329
402
|
driver, search_result_uuids, min_score=reranker_min_score
|
|
330
403
|
)
|
|
331
404
|
elif config.reranker == NodeReranker.node_distance:
|
|
332
405
|
if center_node_uuid is None:
|
|
333
406
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
334
|
-
reranked_uuids = await node_distance_reranker(
|
|
407
|
+
reranked_uuids, node_scores = await node_distance_reranker(
|
|
335
408
|
driver,
|
|
336
|
-
rrf(search_result_uuids, min_score=reranker_min_score),
|
|
409
|
+
rrf(search_result_uuids, min_score=reranker_min_score)[0],
|
|
337
410
|
center_node_uuid,
|
|
338
411
|
min_score=reranker_min_score,
|
|
339
412
|
)
|
|
340
413
|
|
|
341
414
|
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
|
342
415
|
|
|
343
|
-
return reranked_nodes[:limit]
|
|
416
|
+
return reranked_nodes[:limit], node_scores[:limit]
|
|
344
417
|
|
|
345
418
|
|
|
346
419
|
async def episode_search(
|
|
347
|
-
driver:
|
|
420
|
+
driver: GraphDriver,
|
|
348
421
|
cross_encoder: CrossEncoderClient,
|
|
349
422
|
query: str,
|
|
350
423
|
_query_vector: list[float],
|
|
@@ -353,10 +426,9 @@ async def episode_search(
|
|
|
353
426
|
search_filter: SearchFilters,
|
|
354
427
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
355
428
|
reranker_min_score: float = 0,
|
|
356
|
-
) -> list[EpisodicNode]:
|
|
429
|
+
) -> tuple[list[EpisodicNode], list[float]]:
|
|
357
430
|
if config is None:
|
|
358
|
-
return []
|
|
359
|
-
|
|
431
|
+
return [], []
|
|
360
432
|
search_results: list[list[EpisodicNode]] = list(
|
|
361
433
|
await semaphore_gather(
|
|
362
434
|
*[
|
|
@@ -369,12 +441,13 @@ async def episode_search(
|
|
|
369
441
|
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
|
|
370
442
|
|
|
371
443
|
reranked_uuids: list[str] = []
|
|
444
|
+
episode_scores: list[float] = []
|
|
372
445
|
if config.reranker == EpisodeReranker.rrf:
|
|
373
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
446
|
+
reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
374
447
|
|
|
375
448
|
elif config.reranker == EpisodeReranker.cross_encoder:
|
|
376
449
|
# use rrf as a preliminary reranker
|
|
377
|
-
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
450
|
+
rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
378
451
|
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
379
452
|
|
|
380
453
|
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
|
|
@@ -385,14 +458,15 @@ async def episode_search(
|
|
|
385
458
|
for content, score in reranked_contents
|
|
386
459
|
if score >= reranker_min_score
|
|
387
460
|
]
|
|
461
|
+
episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
|
|
388
462
|
|
|
389
463
|
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
|
|
390
464
|
|
|
391
|
-
return reranked_episodes[:limit]
|
|
465
|
+
return reranked_episodes[:limit], episode_scores[:limit]
|
|
392
466
|
|
|
393
467
|
|
|
394
468
|
async def community_search(
|
|
395
|
-
driver:
|
|
469
|
+
driver: GraphDriver,
|
|
396
470
|
cross_encoder: CrossEncoderClient,
|
|
397
471
|
query: str,
|
|
398
472
|
query_vector: list[float],
|
|
@@ -400,9 +474,9 @@ async def community_search(
|
|
|
400
474
|
config: CommunitySearchConfig | None,
|
|
401
475
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
402
476
|
reranker_min_score: float = 0,
|
|
403
|
-
) -> list[CommunityNode]:
|
|
477
|
+
) -> tuple[list[CommunityNode], list[float]]:
|
|
404
478
|
if config is None:
|
|
405
|
-
return []
|
|
479
|
+
return [], []
|
|
406
480
|
|
|
407
481
|
search_results: list[list[CommunityNode]] = list(
|
|
408
482
|
await semaphore_gather(
|
|
@@ -421,14 +495,15 @@ async def community_search(
|
|
|
421
495
|
}
|
|
422
496
|
|
|
423
497
|
reranked_uuids: list[str] = []
|
|
498
|
+
community_scores: list[float] = []
|
|
424
499
|
if config.reranker == CommunityReranker.rrf:
|
|
425
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
500
|
+
reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
426
501
|
elif config.reranker == CommunityReranker.mmr:
|
|
427
502
|
search_result_uuids_and_vectors = await get_embeddings_for_communities(
|
|
428
503
|
driver, list(community_uuid_map.values())
|
|
429
504
|
)
|
|
430
505
|
|
|
431
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
506
|
+
reranked_uuids, community_scores = maximal_marginal_relevance(
|
|
432
507
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
|
|
433
508
|
)
|
|
434
509
|
elif config.reranker == CommunityReranker.cross_encoder:
|
|
@@ -437,7 +512,8 @@ async def community_search(
|
|
|
437
512
|
reranked_uuids = [
|
|
438
513
|
name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
|
|
439
514
|
]
|
|
515
|
+
community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
|
|
440
516
|
|
|
441
517
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
|
442
518
|
|
|
443
|
-
return reranked_communities[:limit]
|
|
519
|
+
return reranked_communities[:limit], community_scores[:limit]
|
|
@@ -119,7 +119,42 @@ class SearchConfig(BaseModel):
|
|
|
119
119
|
|
|
120
120
|
|
|
121
121
|
class SearchResults(BaseModel):
|
|
122
|
-
edges: list[EntityEdge]
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
122
|
+
edges: list[EntityEdge] = Field(default_factory=list)
|
|
123
|
+
edge_reranker_scores: list[float] = Field(default_factory=list)
|
|
124
|
+
nodes: list[EntityNode] = Field(default_factory=list)
|
|
125
|
+
node_reranker_scores: list[float] = Field(default_factory=list)
|
|
126
|
+
episodes: list[EpisodicNode] = Field(default_factory=list)
|
|
127
|
+
episode_reranker_scores: list[float] = Field(default_factory=list)
|
|
128
|
+
communities: list[CommunityNode] = Field(default_factory=list)
|
|
129
|
+
community_reranker_scores: list[float] = Field(default_factory=list)
|
|
130
|
+
|
|
131
|
+
@classmethod
|
|
132
|
+
def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
|
|
133
|
+
"""
|
|
134
|
+
Merge multiple SearchResults objects into a single SearchResults object.
|
|
135
|
+
|
|
136
|
+
Parameters
|
|
137
|
+
----------
|
|
138
|
+
results_list : list[SearchResults]
|
|
139
|
+
List of SearchResults objects to merge
|
|
140
|
+
|
|
141
|
+
Returns
|
|
142
|
+
-------
|
|
143
|
+
SearchResults
|
|
144
|
+
A single SearchResults object containing all results
|
|
145
|
+
"""
|
|
146
|
+
if not results_list:
|
|
147
|
+
return cls()
|
|
148
|
+
|
|
149
|
+
merged = cls()
|
|
150
|
+
for result in results_list:
|
|
151
|
+
merged.edges.extend(result.edges)
|
|
152
|
+
merged.edge_reranker_scores.extend(result.edge_reranker_scores)
|
|
153
|
+
merged.nodes.extend(result.nodes)
|
|
154
|
+
merged.node_reranker_scores.extend(result.node_reranker_scores)
|
|
155
|
+
merged.episodes.extend(result.episodes)
|
|
156
|
+
merged.episode_reranker_scores.extend(result.episode_reranker_scores)
|
|
157
|
+
merged.communities.extend(result.communities)
|
|
158
|
+
merged.community_reranker_scores.extend(result.community_reranker_scores)
|
|
159
|
+
|
|
160
|
+
return merged
|