graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
@@ -18,10 +18,10 @@ import logging
18
18
  from collections import defaultdict
19
19
  from time import time
20
20
 
21
- from neo4j import AsyncDriver
22
-
23
21
  from graphiti_core.cross_encoder.client import CrossEncoderClient
22
+ from graphiti_core.driver.driver import GraphDriver
24
23
  from graphiti_core.edges import EntityEdge
24
+ from graphiti_core.embedder.client import EMBEDDING_DIM
25
25
  from graphiti_core.errors import SearchRerankerError
26
26
  from graphiti_core.graphiti_types import GraphitiClients
27
27
  from graphiti_core.helpers import semaphore_gather
@@ -30,6 +30,7 @@ from graphiti_core.search.search_config import (
30
30
  DEFAULT_SEARCH_LIMIT,
31
31
  CommunityReranker,
32
32
  CommunitySearchConfig,
33
+ CommunitySearchMethod,
33
34
  EdgeReranker,
34
35
  EdgeSearchConfig,
35
36
  EdgeSearchMethod,
@@ -73,34 +74,53 @@ async def search(
73
74
  center_node_uuid: str | None = None,
74
75
  bfs_origin_node_uuids: list[str] | None = None,
75
76
  query_vector: list[float] | None = None,
77
+ driver: GraphDriver | None = None,
76
78
  ) -> SearchResults:
77
79
  start = time()
78
80
 
79
- driver = clients.driver
81
+ driver = driver or clients.driver
80
82
  embedder = clients.embedder
81
83
  cross_encoder = clients.cross_encoder
82
84
 
83
85
  if query.strip() == '':
84
- return SearchResults(
85
- edges=[],
86
- nodes=[],
87
- episodes=[],
88
- communities=[],
86
+ return SearchResults()
87
+
88
+ if (
89
+ config.edge_config
90
+ and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
91
+ or config.edge_config
92
+ and EdgeReranker.mmr == config.edge_config.reranker
93
+ or config.node_config
94
+ and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
95
+ or config.node_config
96
+ and NodeReranker.mmr == config.node_config.reranker
97
+ or (
98
+ config.community_config
99
+ and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
89
100
  )
90
- query_vector = (
91
- query_vector
92
- if query_vector is not None
93
- else await embedder.create(input_data=[query.replace('\n', ' ')])
94
- )
101
+ or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
102
+ ):
103
+ search_vector = (
104
+ query_vector
105
+ if query_vector is not None
106
+ else await embedder.create(input_data=[query.replace('\n', ' ')])
107
+ )
108
+ else:
109
+ search_vector = [0.0] * EMBEDDING_DIM
95
110
 
96
111
  # if group_ids is empty, set it to None
97
- group_ids = group_ids if group_ids else None
98
- edges, nodes, episodes, communities = await semaphore_gather(
112
+ group_ids = group_ids if group_ids and group_ids != [''] else None
113
+ (
114
+ (edges, edge_reranker_scores),
115
+ (nodes, node_reranker_scores),
116
+ (episodes, episode_reranker_scores),
117
+ (communities, community_reranker_scores),
118
+ ) = await semaphore_gather(
99
119
  edge_search(
100
120
  driver,
101
121
  cross_encoder,
102
122
  query,
103
- query_vector,
123
+ search_vector,
104
124
  group_ids,
105
125
  config.edge_config,
106
126
  search_filter,
@@ -113,7 +133,7 @@ async def search(
113
133
  driver,
114
134
  cross_encoder,
115
135
  query,
116
- query_vector,
136
+ search_vector,
117
137
  group_ids,
118
138
  config.node_config,
119
139
  search_filter,
@@ -126,7 +146,7 @@ async def search(
126
146
  driver,
127
147
  cross_encoder,
128
148
  query,
129
- query_vector,
149
+ search_vector,
130
150
  group_ids,
131
151
  config.episode_config,
132
152
  search_filter,
@@ -137,7 +157,7 @@ async def search(
137
157
  driver,
138
158
  cross_encoder,
139
159
  query,
140
- query_vector,
160
+ search_vector,
141
161
  group_ids,
142
162
  config.community_config,
143
163
  config.limit,
@@ -147,9 +167,13 @@ async def search(
147
167
 
148
168
  results = SearchResults(
149
169
  edges=edges,
170
+ edge_reranker_scores=edge_reranker_scores,
150
171
  nodes=nodes,
172
+ node_reranker_scores=node_reranker_scores,
151
173
  episodes=episodes,
174
+ episode_reranker_scores=episode_reranker_scores,
152
175
  communities=communities,
176
+ community_reranker_scores=community_reranker_scores,
153
177
  )
154
178
 
155
179
  latency = (time() - start) * 1000
@@ -160,7 +184,7 @@ async def search(
160
184
 
161
185
 
162
186
  async def edge_search(
163
- driver: AsyncDriver,
187
+ driver: GraphDriver,
164
188
  cross_encoder: CrossEncoderClient,
165
189
  query: str,
166
190
  query_vector: list[float],
@@ -171,51 +195,72 @@ async def edge_search(
171
195
  bfs_origin_node_uuids: list[str] | None = None,
172
196
  limit=DEFAULT_SEARCH_LIMIT,
173
197
  reranker_min_score: float = 0,
174
- ) -> list[EntityEdge]:
198
+ ) -> tuple[list[EntityEdge], list[float]]:
175
199
  if config is None:
176
- return []
200
+ return [], []
177
201
 
178
- search_results: list[list[EntityEdge]] = list(
179
- await semaphore_gather(
180
- *[
181
- edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
182
- edge_similarity_search(
183
- driver,
184
- query_vector,
185
- None,
186
- None,
187
- search_filter,
188
- group_ids,
189
- 2 * limit,
190
- config.sim_min_score,
191
- ),
192
- edge_bfs_search(
193
- driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
194
- ),
195
- ]
202
+ # Build search tasks based on configured search methods
203
+ search_tasks = []
204
+ if EdgeSearchMethod.bm25 in config.search_methods:
205
+ search_tasks.append(
206
+ edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
207
+ )
208
+ if EdgeSearchMethod.cosine_similarity in config.search_methods:
209
+ search_tasks.append(
210
+ edge_similarity_search(
211
+ driver,
212
+ query_vector,
213
+ None,
214
+ None,
215
+ search_filter,
216
+ group_ids,
217
+ 2 * limit,
218
+ config.sim_min_score,
219
+ )
220
+ )
221
+ if EdgeSearchMethod.bfs in config.search_methods:
222
+ search_tasks.append(
223
+ edge_bfs_search(
224
+ driver,
225
+ bfs_origin_node_uuids,
226
+ config.bfs_max_depth,
227
+ search_filter,
228
+ group_ids,
229
+ 2 * limit,
230
+ )
196
231
  )
197
- )
232
+
233
+ # Execute only the configured search methods
234
+ search_results: list[list[EntityEdge]] = []
235
+ if search_tasks:
236
+ search_results = list(await semaphore_gather(*search_tasks))
198
237
 
199
238
  if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
200
239
  source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
201
240
  search_results.append(
202
241
  await edge_bfs_search(
203
- driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
242
+ driver,
243
+ source_node_uuids,
244
+ config.bfs_max_depth,
245
+ search_filter,
246
+ group_ids,
247
+ 2 * limit,
204
248
  )
205
249
  )
206
250
 
207
251
  edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
208
252
 
209
253
  reranked_uuids: list[str] = []
254
+ edge_scores: list[float] = []
210
255
  if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
211
256
  search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
212
257
 
213
- reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
258
+ reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
214
259
  elif config.reranker == EdgeReranker.mmr:
215
260
  search_result_uuids_and_vectors = await get_embeddings_for_edges(
216
261
  driver, list(edge_uuid_map.values())
217
262
  )
218
- reranked_uuids = maximal_marginal_relevance(
263
+ reranked_uuids, edge_scores = maximal_marginal_relevance(
219
264
  query_vector,
220
265
  search_result_uuids_and_vectors,
221
266
  config.mmr_lambda,
@@ -227,12 +272,13 @@ async def edge_search(
227
272
  reranked_uuids = [
228
273
  fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
229
274
  ]
275
+ edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
230
276
  elif config.reranker == EdgeReranker.node_distance:
231
277
  if center_node_uuid is None:
232
278
  raise SearchRerankerError('No center node provided for Node Distance reranker')
233
279
 
234
280
  # use rrf as a preliminary sort
235
- sorted_result_uuids = rrf(
281
+ sorted_result_uuids, node_scores = rrf(
236
282
  [[edge.uuid for edge in result] for result in search_results],
237
283
  min_score=reranker_min_score,
238
284
  )
@@ -245,7 +291,7 @@ async def edge_search(
245
291
 
246
292
  source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
247
293
 
248
- reranked_node_uuids = await node_distance_reranker(
294
+ reranked_node_uuids, edge_scores = await node_distance_reranker(
249
295
  driver, source_uuids, center_node_uuid, min_score=reranker_min_score
250
296
  )
251
297
 
@@ -257,11 +303,11 @@ async def edge_search(
257
303
  if config.reranker == EdgeReranker.episode_mentions:
258
304
  reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
259
305
 
260
- return reranked_edges[:limit]
306
+ return reranked_edges[:limit], edge_scores[:limit]
261
307
 
262
308
 
263
309
  async def node_search(
264
- driver: AsyncDriver,
310
+ driver: GraphDriver,
265
311
  cross_encoder: CrossEncoderClient,
266
312
  query: str,
267
313
  query_vector: list[float],
@@ -272,29 +318,54 @@ async def node_search(
272
318
  bfs_origin_node_uuids: list[str] | None = None,
273
319
  limit=DEFAULT_SEARCH_LIMIT,
274
320
  reranker_min_score: float = 0,
275
- ) -> list[EntityNode]:
321
+ ) -> tuple[list[EntityNode], list[float]]:
276
322
  if config is None:
277
- return []
323
+ return [], []
278
324
 
279
- search_results: list[list[EntityNode]] = list(
280
- await semaphore_gather(
281
- *[
282
- node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
283
- node_similarity_search(
284
- driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
285
- ),
286
- node_bfs_search(
287
- driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
288
- ),
289
- ]
325
+ # Build search tasks based on configured search methods
326
+ search_tasks = []
327
+ if NodeSearchMethod.bm25 in config.search_methods:
328
+ search_tasks.append(
329
+ node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
290
330
  )
291
- )
331
+ if NodeSearchMethod.cosine_similarity in config.search_methods:
332
+ search_tasks.append(
333
+ node_similarity_search(
334
+ driver,
335
+ query_vector,
336
+ search_filter,
337
+ group_ids,
338
+ 2 * limit,
339
+ config.sim_min_score,
340
+ )
341
+ )
342
+ if NodeSearchMethod.bfs in config.search_methods:
343
+ search_tasks.append(
344
+ node_bfs_search(
345
+ driver,
346
+ bfs_origin_node_uuids,
347
+ search_filter,
348
+ config.bfs_max_depth,
349
+ group_ids,
350
+ 2 * limit,
351
+ )
352
+ )
353
+
354
+ # Execute only the configured search methods
355
+ search_results: list[list[EntityNode]] = []
356
+ if search_tasks:
357
+ search_results = list(await semaphore_gather(*search_tasks))
292
358
 
293
359
  if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
294
360
  origin_node_uuids = [node.uuid for result in search_results for node in result]
295
361
  search_results.append(
296
362
  await node_bfs_search(
297
- driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
363
+ driver,
364
+ origin_node_uuids,
365
+ search_filter,
366
+ config.bfs_max_depth,
367
+ group_ids,
368
+ 2 * limit,
298
369
  )
299
370
  )
300
371
 
@@ -302,14 +373,15 @@ async def node_search(
302
373
  node_uuid_map = {node.uuid: node for result in search_results for node in result}
303
374
 
304
375
  reranked_uuids: list[str] = []
376
+ node_scores: list[float] = []
305
377
  if config.reranker == NodeReranker.rrf:
306
- reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
378
+ reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
307
379
  elif config.reranker == NodeReranker.mmr:
308
380
  search_result_uuids_and_vectors = await get_embeddings_for_nodes(
309
381
  driver, list(node_uuid_map.values())
310
382
  )
311
383
 
312
- reranked_uuids = maximal_marginal_relevance(
384
+ reranked_uuids, node_scores = maximal_marginal_relevance(
313
385
  query_vector,
314
386
  search_result_uuids_and_vectors,
315
387
  config.mmr_lambda,
@@ -324,27 +396,28 @@ async def node_search(
324
396
  for name, score in reranked_node_names
325
397
  if score >= reranker_min_score
326
398
  ]
399
+ node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
327
400
  elif config.reranker == NodeReranker.episode_mentions:
328
- reranked_uuids = await episode_mentions_reranker(
401
+ reranked_uuids, node_scores = await episode_mentions_reranker(
329
402
  driver, search_result_uuids, min_score=reranker_min_score
330
403
  )
331
404
  elif config.reranker == NodeReranker.node_distance:
332
405
  if center_node_uuid is None:
333
406
  raise SearchRerankerError('No center node provided for Node Distance reranker')
334
- reranked_uuids = await node_distance_reranker(
407
+ reranked_uuids, node_scores = await node_distance_reranker(
335
408
  driver,
336
- rrf(search_result_uuids, min_score=reranker_min_score),
409
+ rrf(search_result_uuids, min_score=reranker_min_score)[0],
337
410
  center_node_uuid,
338
411
  min_score=reranker_min_score,
339
412
  )
340
413
 
341
414
  reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
342
415
 
343
- return reranked_nodes[:limit]
416
+ return reranked_nodes[:limit], node_scores[:limit]
344
417
 
345
418
 
346
419
  async def episode_search(
347
- driver: AsyncDriver,
420
+ driver: GraphDriver,
348
421
  cross_encoder: CrossEncoderClient,
349
422
  query: str,
350
423
  _query_vector: list[float],
@@ -353,10 +426,9 @@ async def episode_search(
353
426
  search_filter: SearchFilters,
354
427
  limit=DEFAULT_SEARCH_LIMIT,
355
428
  reranker_min_score: float = 0,
356
- ) -> list[EpisodicNode]:
429
+ ) -> tuple[list[EpisodicNode], list[float]]:
357
430
  if config is None:
358
- return []
359
-
431
+ return [], []
360
432
  search_results: list[list[EpisodicNode]] = list(
361
433
  await semaphore_gather(
362
434
  *[
@@ -369,12 +441,13 @@ async def episode_search(
369
441
  episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
370
442
 
371
443
  reranked_uuids: list[str] = []
444
+ episode_scores: list[float] = []
372
445
  if config.reranker == EpisodeReranker.rrf:
373
- reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
446
+ reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
374
447
 
375
448
  elif config.reranker == EpisodeReranker.cross_encoder:
376
449
  # use rrf as a preliminary reranker
377
- rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
450
+ rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
378
451
  rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
379
452
 
380
453
  content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
@@ -385,14 +458,15 @@ async def episode_search(
385
458
  for content, score in reranked_contents
386
459
  if score >= reranker_min_score
387
460
  ]
461
+ episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
388
462
 
389
463
  reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
390
464
 
391
- return reranked_episodes[:limit]
465
+ return reranked_episodes[:limit], episode_scores[:limit]
392
466
 
393
467
 
394
468
  async def community_search(
395
- driver: AsyncDriver,
469
+ driver: GraphDriver,
396
470
  cross_encoder: CrossEncoderClient,
397
471
  query: str,
398
472
  query_vector: list[float],
@@ -400,9 +474,9 @@ async def community_search(
400
474
  config: CommunitySearchConfig | None,
401
475
  limit=DEFAULT_SEARCH_LIMIT,
402
476
  reranker_min_score: float = 0,
403
- ) -> list[CommunityNode]:
477
+ ) -> tuple[list[CommunityNode], list[float]]:
404
478
  if config is None:
405
- return []
479
+ return [], []
406
480
 
407
481
  search_results: list[list[CommunityNode]] = list(
408
482
  await semaphore_gather(
@@ -421,14 +495,15 @@ async def community_search(
421
495
  }
422
496
 
423
497
  reranked_uuids: list[str] = []
498
+ community_scores: list[float] = []
424
499
  if config.reranker == CommunityReranker.rrf:
425
- reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
500
+ reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
426
501
  elif config.reranker == CommunityReranker.mmr:
427
502
  search_result_uuids_and_vectors = await get_embeddings_for_communities(
428
503
  driver, list(community_uuid_map.values())
429
504
  )
430
505
 
431
- reranked_uuids = maximal_marginal_relevance(
506
+ reranked_uuids, community_scores = maximal_marginal_relevance(
432
507
  query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
433
508
  )
434
509
  elif config.reranker == CommunityReranker.cross_encoder:
@@ -437,7 +512,8 @@ async def community_search(
437
512
  reranked_uuids = [
438
513
  name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
439
514
  ]
515
+ community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
440
516
 
441
517
  reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
442
518
 
443
- return reranked_communities[:limit]
519
+ return reranked_communities[:limit], community_scores[:limit]
@@ -119,7 +119,42 @@ class SearchConfig(BaseModel):
119
119
 
120
120
 
121
121
  class SearchResults(BaseModel):
122
- edges: list[EntityEdge]
123
- nodes: list[EntityNode]
124
- episodes: list[EpisodicNode]
125
- communities: list[CommunityNode]
122
+ edges: list[EntityEdge] = Field(default_factory=list)
123
+ edge_reranker_scores: list[float] = Field(default_factory=list)
124
+ nodes: list[EntityNode] = Field(default_factory=list)
125
+ node_reranker_scores: list[float] = Field(default_factory=list)
126
+ episodes: list[EpisodicNode] = Field(default_factory=list)
127
+ episode_reranker_scores: list[float] = Field(default_factory=list)
128
+ communities: list[CommunityNode] = Field(default_factory=list)
129
+ community_reranker_scores: list[float] = Field(default_factory=list)
130
+
131
+ @classmethod
132
+ def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
133
+ """
134
+ Merge multiple SearchResults objects into a single SearchResults object.
135
+
136
+ Parameters
137
+ ----------
138
+ results_list : list[SearchResults]
139
+ List of SearchResults objects to merge
140
+
141
+ Returns
142
+ -------
143
+ SearchResults
144
+ A single SearchResults object containing all results
145
+ """
146
+ if not results_list:
147
+ return cls()
148
+
149
+ merged = cls()
150
+ for result in results_list:
151
+ merged.edges.extend(result.edges)
152
+ merged.edge_reranker_scores.extend(result.edge_reranker_scores)
153
+ merged.nodes.extend(result.nodes)
154
+ merged.node_reranker_scores.extend(result.node_reranker_scores)
155
+ merged.episodes.extend(result.episodes)
156
+ merged.episode_reranker_scores.extend(result.episode_reranker_scores)
157
+ merged.communities.extend(result.communities)
158
+ merged.community_reranker_scores.extend(result.community_reranker_scores)
159
+
160
+ return merged