graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +61 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +582 -255
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +21 -14
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +94 -50
  24. graphiti_core/llm_client/openai_client.py +28 -8
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +23 -14
  33. graphiti_core/prompts/extract_nodes.py +73 -32
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +109 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/datetime_utils.py +13 -0
  45. graphiti_core/utils/maintenance/community_operations.py +62 -38
  46. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  47. graphiti_core/utils/maintenance/edge_operations.py +286 -126
  48. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  49. graphiti_core/utils/maintenance/node_operations.py +320 -158
  50. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  51. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  52. graphiti_core/utils/text_utils.py +53 -0
  53. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
  54. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  55. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  56. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  57. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  58. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/prompts/prompt_helpers.py
@@ -1 +1,40 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import json
+ from typing import Any
+
  DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'
+
+
+ def to_prompt_json(data: Any, ensure_ascii: bool = False, indent: int | None = None) -> str:
+     """
+     Serialize data to JSON for use in prompts.
+
+     Args:
+         data: The data to serialize
+         ensure_ascii: If True, escape non-ASCII characters. If False (default), preserve them.
+         indent: Number of spaces for indentation. Defaults to None (minified).
+
+     Returns:
+         JSON string representation of the data
+
+     Notes:
+         By default (ensure_ascii=False), non-ASCII characters (e.g., Korean, Japanese, Chinese)
+         are preserved in their original form in the prompt, making them readable
+         in LLM logs and improving model understanding.
+     """
+     return json.dumps(data, ensure_ascii=ensure_ascii, indent=indent)
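
The helper is small enough to sanity-check inline. A quick sketch of its behavior (standard `json.dumps` semantics, so the commented outputs are what CPython actually prints):

```python
import json

from graphiti_core.prompts.prompt_helpers import to_prompt_json

# Minified by default, with non-ASCII preserved for readable prompts and logs
print(to_prompt_json({'name': '김철수', 'city': '서울'}))
# {"name": "김철수", "city": "서울"}

# The old json.dumps(..., indent=2) call sites escaped non-ASCII by default
print(json.dumps({'name': '김철수'}))
# {"name": "\uae40\ucca0\uc218"}

# Indentation is still available when a prompt benefits from pretty-printing
print(to_prompt_json({'items': [1, 2]}, indent=2))
```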
graphiti_core/prompts/snippets.py
@@ -0,0 +1,29 @@
+ """
+ Copyright 2024, Zep Software, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ summary_instructions = """Guidelines:
+ 1. Output only factual content. Never explain what you're doing, why, or mention limitations/constraints.
+ 2. Only use the provided messages, entity, and entity context to set attribute values.
+ 3. Keep the summary concise and to the point. STATE FACTS DIRECTLY IN UNDER 250 CHARACTERS.
+
+ Example summaries:
+ BAD: "This is the only activity in the context. The user listened to this song. No other details were provided to include in this summary."
+ GOOD: "User played 'Blue Monday' by New Order (electronic genre) on 2024-12-03 at 14:22 UTC."
+ BAD: "Based on the messages provided, the user attended a meeting. This summary focuses on that event as it was the main topic discussed."
+ GOOD: "User attended Q3 planning meeting with sales team on March 15."
+ BAD: "The context shows John ordered pizza. Due to length constraints, other details are omitted from this summary."
+ GOOD: "John ordered pepperoni pizza from Mario's at 7:30 PM, delivered to office."
+ """
graphiti_core/prompts/summarize_nodes.py
@@ -14,18 +14,19 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """
 
- import json
  from typing import Any, Protocol, TypedDict
 
  from pydantic import BaseModel, Field
 
  from .models import Message, PromptFunction, PromptVersion
+ from .prompt_helpers import to_prompt_json
+ from .snippets import summary_instructions
 
 
  class Summary(BaseModel):
      summary: str = Field(
          ...,
-         description='Summary containing the important information about the entity. Under 250 words',
+         description='Summary containing the important information about the entity. Under 250 characters',
      )
 
 
@@ -55,11 +56,11 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
              role='user',
              content=f"""
          Synthesize the information from the following two summaries into a single succinct summary.
-
-         Summaries must be under 250 words.
+
+         IMPORTANT: Keep the summary concise and to the point. SUMMARIES MUST BE LESS THAN 250 CHARACTERS.
 
          Summaries:
-         {json.dumps(context['node_summaries'], indent=2)}
+         {to_prompt_json(context['node_summaries'])}
          """,
          ),
      ]
@@ -69,38 +70,35 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
      return [
          Message(
              role='system',
-             content='You are a helpful assistant that extracts entity properties from the provided text.',
+             content='You are a helpful assistant that generates a summary and attributes from provided text.',
          ),
          Message(
              role='user',
              content=f"""
-
-         <MESSAGES>
-         {json.dumps(context['previous_episodes'], indent=2)}
-         {json.dumps(context['episode_content'], indent=2)}
-         </MESSAGES>
-
-         Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
+         Given the MESSAGES and the ENTITY name, create a summary for the ENTITY. Your summary must only use
          information from the provided MESSAGES. Your summary should also only contain information relevant to the
-         provided ENTITY. Summaries must be under 250 words.
-
+         provided ENTITY.
+
          In addition, extract any values for the provided entity properties based on their descriptions.
          If the value of the entity property cannot be found in the current context, set the value of the property to the Python value None.
-
-         Guidelines:
-         1. Do not hallucinate entity property values if they cannot be found in the current context.
-         2. Only use the provided messages, entity, and entity context to set attribute values.
-
+
+         {summary_instructions}
+
+         <MESSAGES>
+         {to_prompt_json(context['previous_episodes'])}
+         {to_prompt_json(context['episode_content'])}
+         </MESSAGES>
+
          <ENTITY>
          {context['node_name']}
          </ENTITY>
-
+
          <ENTITY CONTEXT>
          {context['node_summary']}
          </ENTITY CONTEXT>
-
+
          <ATTRIBUTES>
-         {json.dumps(context['attributes'], indent=2)}
+         {to_prompt_json(context['attributes'])}
          </ATTRIBUTES>
          """,
          ),
@@ -117,10 +115,10 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
              role='user',
              content=f"""
          Create a short one sentence description of the summary that explains what kind of information is summarized.
-         Summaries must be under 250 words.
+         Summaries must be under 250 characters.
 
          Summary:
-         {json.dumps(context['summary'], indent=2)}
+         {to_prompt_json(context['summary'])}
          """,
          ),
      ]
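
Taken together, the summarize_nodes changes swap `json.dumps(..., indent=2)` for the shared `to_prompt_json` helper and tighten every length instruction from 250 words to 250 characters. A minimal sketch of how one of these prompt functions is exercised (the `Message` model with `role`/`content` fields comes from graphiti_core/prompts/models.py):

```python
from graphiti_core.prompts.summarize_nodes import summary_description

# Each prompt function maps a plain context dict to a list of Message objects
messages = summary_description({'summary': 'User played Blue Monday by New Order.'})
for message in messages:
    print(message.role, message.content[:80])
```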
graphiti_core/search/search.py
@@ -21,6 +21,7 @@ from time import time
  from graphiti_core.cross_encoder.client import CrossEncoderClient
  from graphiti_core.driver.driver import GraphDriver
  from graphiti_core.edges import EntityEdge
+ from graphiti_core.embedder.client import EMBEDDING_DIM
  from graphiti_core.errors import SearchRerankerError
  from graphiti_core.graphiti_types import GraphitiClients
  from graphiti_core.helpers import semaphore_gather
@@ -29,6 +30,7 @@ from graphiti_core.search.search_config import (
      DEFAULT_SEARCH_LIMIT,
      CommunityReranker,
      CommunitySearchConfig,
+     CommunitySearchMethod,
      EdgeReranker,
      EdgeSearchConfig,
      EdgeSearchMethod,
@@ -72,34 +74,53 @@ async def search(
      center_node_uuid: str | None = None,
      bfs_origin_node_uuids: list[str] | None = None,
      query_vector: list[float] | None = None,
+     driver: GraphDriver | None = None,
  ) -> SearchResults:
      start = time()
 
-     driver = clients.driver
+     driver = driver or clients.driver
      embedder = clients.embedder
      cross_encoder = clients.cross_encoder
 
      if query.strip() == '':
-         return SearchResults(
-             edges=[],
-             nodes=[],
-             episodes=[],
-             communities=[],
+         return SearchResults()
+
+     if (
+         config.edge_config
+         and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
+         or config.edge_config
+         and EdgeReranker.mmr == config.edge_config.reranker
+         or config.node_config
+         and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
+         or config.node_config
+         and NodeReranker.mmr == config.node_config.reranker
+         or (
+             config.community_config
+             and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
          )
-     query_vector = (
-         query_vector
-         if query_vector is not None
-         else await embedder.create(input_data=[query.replace('\n', ' ')])
-     )
+         or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
+     ):
+         search_vector = (
+             query_vector
+             if query_vector is not None
+             else await embedder.create(input_data=[query.replace('\n', ' ')])
+         )
+     else:
+         search_vector = [0.0] * EMBEDDING_DIM
 
      # if group_ids is empty, set it to None
      group_ids = group_ids if group_ids and group_ids != [''] else None
-     edges, nodes, episodes, communities = await semaphore_gather(
+     (
+         (edges, edge_reranker_scores),
+         (nodes, node_reranker_scores),
+         (episodes, episode_reranker_scores),
+         (communities, community_reranker_scores),
+     ) = await semaphore_gather(
          edge_search(
              driver,
              cross_encoder,
             query,
-             query_vector,
+             search_vector,
             group_ids,
             config.edge_config,
             search_filter,
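
The gating block above means the embedder is only called when some configured search method or reranker actually consumes a query vector; otherwise a zero vector of length EMBEDDING_DIM is passed through as a placeholder. A hedged sketch of a config that skips embedding entirely, assuming the `SearchConfig`/`EdgeSearchConfig` field names match their usage in this diff:

```python
from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    SearchConfig,
)

# BM25-only retrieval with RRF reranking: neither needs a query vector,
# so search() falls through to search_vector = [0.0] * EMBEDDING_DIM
# and never awaits embedder.create().
bm25_only_config = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25],
        reranker=EdgeReranker.rrf,
    ),
)
```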
@@ -112,7 +133,7 @@ async def search(
              driver,
              cross_encoder,
              query,
-             query_vector,
+             search_vector,
              group_ids,
              config.node_config,
              search_filter,
@@ -125,7 +146,7 @@ async def search(
              driver,
              cross_encoder,
              query,
-             query_vector,
+             search_vector,
              group_ids,
              config.episode_config,
              search_filter,
@@ -136,7 +157,7 @@ async def search(
              driver,
              cross_encoder,
              query,
-             query_vector,
+             search_vector,
              group_ids,
              config.community_config,
              config.limit,
@@ -146,9 +167,13 @@ async def search(
 
      results = SearchResults(
          edges=edges,
+         edge_reranker_scores=edge_reranker_scores,
          nodes=nodes,
+         node_reranker_scores=node_reranker_scores,
          episodes=episodes,
+         episode_reranker_scores=episode_reranker_scores,
          communities=communities,
+         community_reranker_scores=community_reranker_scores,
      )
 
      latency = (time() - start) * 1000
@@ -170,50 +195,72 @@ async def edge_search(
      bfs_origin_node_uuids: list[str] | None = None,
      limit=DEFAULT_SEARCH_LIMIT,
      reranker_min_score: float = 0,
- ) -> list[EntityEdge]:
+ ) -> tuple[list[EntityEdge], list[float]]:
      if config is None:
-         return []
-     search_results: list[list[EntityEdge]] = list(
-         await semaphore_gather(
-             *[
-                 edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
-                 edge_similarity_search(
-                     driver,
-                     query_vector,
-                     None,
-                     None,
-                     search_filter,
-                     group_ids,
-                     2 * limit,
-                     config.sim_min_score,
-                 ),
-                 edge_bfs_search(
-                     driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
-                 ),
-             ]
+         return [], []
+
+     # Build search tasks based on configured search methods
+     search_tasks = []
+     if EdgeSearchMethod.bm25 in config.search_methods:
+         search_tasks.append(
+             edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
+         )
+     if EdgeSearchMethod.cosine_similarity in config.search_methods:
+         search_tasks.append(
+             edge_similarity_search(
+                 driver,
+                 query_vector,
+                 None,
+                 None,
+                 search_filter,
+                 group_ids,
+                 2 * limit,
+                 config.sim_min_score,
+             )
+         )
+     if EdgeSearchMethod.bfs in config.search_methods:
+         search_tasks.append(
+             edge_bfs_search(
+                 driver,
+                 bfs_origin_node_uuids,
+                 config.bfs_max_depth,
+                 search_filter,
+                 group_ids,
+                 2 * limit,
+             )
          )
-     )
+
+     # Execute only the configured search methods
+     search_results: list[list[EntityEdge]] = []
+     if search_tasks:
+         search_results = list(await semaphore_gather(*search_tasks))
 
      if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
          source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
          search_results.append(
              await edge_bfs_search(
-                 driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
+                 driver,
+                 source_node_uuids,
+                 config.bfs_max_depth,
+                 search_filter,
+                 group_ids,
+                 2 * limit,
              )
          )
 
      edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
 
      reranked_uuids: list[str] = []
+     edge_scores: list[float] = []
      if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
          search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
 
-         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+         reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
      elif config.reranker == EdgeReranker.mmr:
          search_result_uuids_and_vectors = await get_embeddings_for_edges(
              driver, list(edge_uuid_map.values())
          )
-         reranked_uuids = maximal_marginal_relevance(
+         reranked_uuids, edge_scores = maximal_marginal_relevance(
              query_vector,
             search_result_uuids_and_vectors,
             config.mmr_lambda,
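
Every reranker branch above now unpacks a `(uuids, scores)` tuple, so `rrf`, `maximal_marginal_relevance`, and the distance rerankers must each return parallel lists. For intuition, a minimal sketch of reciprocal rank fusion in that shape, using the textbook 1/(rank + k) formulation; the shipped implementation in graphiti_core/search/search_utils.py may differ in constants and details:

```python
from collections import defaultdict


def rrf_sketch(
    results: list[list[str]], rank_const: int = 1, min_score: float = 0
) -> tuple[list[str], list[float]]:
    # Accumulate reciprocal-rank scores for each uuid across all result lists
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result):
            scores[uuid] += 1.0 / (rank + rank_const)

    # Sort by fused score, drop anything below min_score, and return
    # parallel lists so callers can zip uuids with their scores.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    kept = [(uuid, score) for uuid, score in ranked if score >= min_score]
    return [uuid for uuid, _ in kept], [score for _, score in kept]
```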
@@ -225,12 +272,13 @@ async def edge_search(
          reranked_uuids = [
              fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
          ]
+         edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
      elif config.reranker == EdgeReranker.node_distance:
          if center_node_uuid is None:
              raise SearchRerankerError('No center node provided for Node Distance reranker')
 
          # use rrf as a preliminary sort
-         sorted_result_uuids = rrf(
+         sorted_result_uuids, node_scores = rrf(
              [[edge.uuid for edge in result] for result in search_results],
              min_score=reranker_min_score,
          )
@@ -243,7 +291,7 @@ async def edge_search(
 
          source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
 
-         reranked_node_uuids = await node_distance_reranker(
+         reranked_node_uuids, edge_scores = await node_distance_reranker(
              driver, source_uuids, center_node_uuid, min_score=reranker_min_score
          )
 
@@ -255,7 +303,7 @@ async def edge_search(
      if config.reranker == EdgeReranker.episode_mentions:
          reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
 
-     return reranked_edges[:limit]
+     return reranked_edges[:limit], edge_scores[:limit]
 
 
  async def node_search(
@@ -270,28 +318,54 @@ async def node_search(
      bfs_origin_node_uuids: list[str] | None = None,
      limit=DEFAULT_SEARCH_LIMIT,
      reranker_min_score: float = 0,
- ) -> list[EntityNode]:
+ ) -> tuple[list[EntityNode], list[float]]:
      if config is None:
-         return []
-     search_results: list[list[EntityNode]] = list(
-         await semaphore_gather(
-             *[
-                 node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
-                 node_similarity_search(
-                     driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
-                 ),
-                 node_bfs_search(
-                     driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
-                 ),
-             ]
+         return [], []
+
+     # Build search tasks based on configured search methods
+     search_tasks = []
+     if NodeSearchMethod.bm25 in config.search_methods:
+         search_tasks.append(
+             node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
          )
-     )
+     if NodeSearchMethod.cosine_similarity in config.search_methods:
+         search_tasks.append(
+             node_similarity_search(
+                 driver,
+                 query_vector,
+                 search_filter,
+                 group_ids,
+                 2 * limit,
+                 config.sim_min_score,
+             )
+         )
+     if NodeSearchMethod.bfs in config.search_methods:
+         search_tasks.append(
+             node_bfs_search(
+                 driver,
+                 bfs_origin_node_uuids,
+                 search_filter,
+                 config.bfs_max_depth,
+                 group_ids,
+                 2 * limit,
+             )
+         )
+
+     # Execute only the configured search methods
+     search_results: list[list[EntityNode]] = []
+     if search_tasks:
+         search_results = list(await semaphore_gather(*search_tasks))
 
      if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
          origin_node_uuids = [node.uuid for result in search_results for node in result]
          search_results.append(
              await node_bfs_search(
-                 driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
+                 driver,
+                 origin_node_uuids,
+                 search_filter,
+                 config.bfs_max_depth,
+                 group_ids,
+                 2 * limit,
              )
          )
 
@@ -299,14 +373,15 @@ async def node_search(
      node_uuid_map = {node.uuid: node for result in search_results for node in result}
 
      reranked_uuids: list[str] = []
+     node_scores: list[float] = []
      if config.reranker == NodeReranker.rrf:
-         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+         reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
      elif config.reranker == NodeReranker.mmr:
          search_result_uuids_and_vectors = await get_embeddings_for_nodes(
              driver, list(node_uuid_map.values())
          )
 
-         reranked_uuids = maximal_marginal_relevance(
+         reranked_uuids, node_scores = maximal_marginal_relevance(
              query_vector,
              search_result_uuids_and_vectors,
              config.mmr_lambda,
@@ -321,23 +396,24 @@ async def node_search(
              for name, score in reranked_node_names
              if score >= reranker_min_score
          ]
+         node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
      elif config.reranker == NodeReranker.episode_mentions:
-         reranked_uuids = await episode_mentions_reranker(
+         reranked_uuids, node_scores = await episode_mentions_reranker(
              driver, search_result_uuids, min_score=reranker_min_score
          )
      elif config.reranker == NodeReranker.node_distance:
          if center_node_uuid is None:
              raise SearchRerankerError('No center node provided for Node Distance reranker')
-         reranked_uuids = await node_distance_reranker(
+         reranked_uuids, node_scores = await node_distance_reranker(
              driver,
-             rrf(search_result_uuids, min_score=reranker_min_score),
+             rrf(search_result_uuids, min_score=reranker_min_score)[0],
              center_node_uuid,
              min_score=reranker_min_score,
          )
 
      reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
 
-     return reranked_nodes[:limit]
+     return reranked_nodes[:limit], node_scores[:limit]
 
 
  async def episode_search(
@@ -350,9 +426,9 @@ async def episode_search(
      search_filter: SearchFilters,
      limit=DEFAULT_SEARCH_LIMIT,
      reranker_min_score: float = 0,
- ) -> list[EpisodicNode]:
+ ) -> tuple[list[EpisodicNode], list[float]]:
      if config is None:
-         return []
+         return [], []
      search_results: list[list[EpisodicNode]] = list(
          await semaphore_gather(
              *[
@@ -365,12 +441,13 @@ async def episode_search(
      episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
 
      reranked_uuids: list[str] = []
+     episode_scores: list[float] = []
      if config.reranker == EpisodeReranker.rrf:
-         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+         reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
 
      elif config.reranker == EpisodeReranker.cross_encoder:
          # use rrf as a preliminary reranker
-         rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+         rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
          rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
 
          content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
@@ -381,10 +458,11 @@ async def episode_search(
              for content, score in reranked_contents
              if score >= reranker_min_score
          ]
+         episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
 
      reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
 
-     return reranked_episodes[:limit]
+     return reranked_episodes[:limit], episode_scores[:limit]
 
 
  async def community_search(
@@ -396,9 +474,9 @@ async def community_search(
      config: CommunitySearchConfig | None,
      limit=DEFAULT_SEARCH_LIMIT,
      reranker_min_score: float = 0,
- ) -> list[CommunityNode]:
+ ) -> tuple[list[CommunityNode], list[float]]:
      if config is None:
-         return []
+         return [], []
 
      search_results: list[list[CommunityNode]] = list(
          await semaphore_gather(
@@ -417,14 +495,15 @@ async def community_search(
      }
 
      reranked_uuids: list[str] = []
+     community_scores: list[float] = []
      if config.reranker == CommunityReranker.rrf:
-         reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
+         reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
      elif config.reranker == CommunityReranker.mmr:
          search_result_uuids_and_vectors = await get_embeddings_for_communities(
              driver, list(community_uuid_map.values())
          )
 
-         reranked_uuids = maximal_marginal_relevance(
+         reranked_uuids, community_scores = maximal_marginal_relevance(
              query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
          )
      elif config.reranker == CommunityReranker.cross_encoder:
@@ -433,7 +512,8 @@ async def community_search(
          reranked_uuids = [
              name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
          ]
+         community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
 
      reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
 
-     return reranked_communities[:limit]
+     return reranked_communities[:limit], community_scores[:limit]
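
Because each per-type search now returns its scores alongside its results, callers of the top-level `search()` can pair them directly. A small usage sketch, assuming `clients`, `query`, and `config` are already constructed and using the `fact` field that EntityEdge exposes:

```python
from graphiti_core.search.search import search


async def print_scored_edges(clients, query, config):
    results = await search(clients, query, config)
    # Result lists and score lists are parallel and truncated to the same
    # limit, so zip() lines each edge up with its reranker score.
    for edge, score in zip(results.edges, results.edge_reranker_scores):
        print(f'{score:.4f}  {edge.fact}')
```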
graphiti_core/search/search_config.py
@@ -119,7 +119,42 @@ class SearchConfig(BaseModel):
 
 
  class SearchResults(BaseModel):
-     edges: list[EntityEdge]
-     nodes: list[EntityNode]
-     episodes: list[EpisodicNode]
-     communities: list[CommunityNode]
+     edges: list[EntityEdge] = Field(default_factory=list)
+     edge_reranker_scores: list[float] = Field(default_factory=list)
+     nodes: list[EntityNode] = Field(default_factory=list)
+     node_reranker_scores: list[float] = Field(default_factory=list)
+     episodes: list[EpisodicNode] = Field(default_factory=list)
+     episode_reranker_scores: list[float] = Field(default_factory=list)
+     communities: list[CommunityNode] = Field(default_factory=list)
+     community_reranker_scores: list[float] = Field(default_factory=list)
+
+     @classmethod
+     def merge(cls, results_list: list['SearchResults']) -> 'SearchResults':
+         """
+         Merge multiple SearchResults objects into a single SearchResults object.
+
+         Parameters
+         ----------
+         results_list : list[SearchResults]
+             List of SearchResults objects to merge
+
+         Returns
+         -------
+         SearchResults
+             A single SearchResults object containing all results
+         """
+         if not results_list:
+             return cls()
+
+         merged = cls()
+         for result in results_list:
+             merged.edges.extend(result.edges)
+             merged.edge_reranker_scores.extend(result.edge_reranker_scores)
+             merged.nodes.extend(result.nodes)
+             merged.node_reranker_scores.extend(result.node_reranker_scores)
+             merged.episodes.extend(result.episodes)
+             merged.episode_reranker_scores.extend(result.episode_reranker_scores)
+             merged.communities.extend(result.communities)
+             merged.community_reranker_scores.extend(result.community_reranker_scores)
+
+         return merged
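
With defaults on every field, an empty `SearchResults()` is now valid, and `merge` makes fan-out querying (for example, one search per group id or per driver, combined afterward) straightforward. A usage sketch, assuming `results_a` and `results_b` came from earlier `search()` calls:

```python
from graphiti_core.search.search_config import SearchResults

# extend() preserves input order, so merged lists are simple concatenations
merged = SearchResults.merge([results_a, results_b])
assert len(merged.edges) == len(results_a.edges) + len(results_b.edges)
```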