graphiti-core 0.17.11__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the registry's advisory page for this release for more details.

@@ -80,12 +80,7 @@ async def search(
80
80
  cross_encoder = clients.cross_encoder
81
81
 
82
82
  if query.strip() == '':
83
- return SearchResults(
84
- edges=[],
85
- nodes=[],
86
- episodes=[],
87
- communities=[],
88
- )
83
+ return SearchResults()
89
84
  query_vector = (
90
85
  query_vector
91
86
  if query_vector is not None
@@ -94,7 +89,12 @@ async def search(
94
89
 
95
90
  # if group_ids is empty, set it to None
96
91
  group_ids = group_ids if group_ids and group_ids != [''] else None
97
- edges, nodes, episodes, communities = await semaphore_gather(
92
+ (
93
+ (edges, edge_reranker_scores),
94
+ (nodes, node_reranker_scores),
95
+ (episodes, episode_reranker_scores),
96
+ (communities, community_reranker_scores),
97
+ ) = await semaphore_gather(
98
98
  edge_search(
99
99
  driver,
100
100
  cross_encoder,
@@ -146,9 +146,13 @@ async def search(
146
146
 
147
147
  results = SearchResults(
148
148
  edges=edges,
149
+ edge_reranker_scores=edge_reranker_scores,
149
150
  nodes=nodes,
151
+ node_reranker_scores=node_reranker_scores,
150
152
  episodes=episodes,
153
+ episode_reranker_scores=episode_reranker_scores,
151
154
  communities=communities,
155
+ community_reranker_scores=community_reranker_scores,
152
156
  )
153
157
 
154
158
  latency = (time() - start) * 1000
@@ -170,9 +174,9 @@ async def edge_search(
170
174
  bfs_origin_node_uuids: list[str] | None = None,
171
175
  limit=DEFAULT_SEARCH_LIMIT,
172
176
  reranker_min_score: float = 0,
173
- ) -> list[EntityEdge]:
177
+ ) -> tuple[list[EntityEdge], list[float]]:
174
178
  if config is None:
175
- return []
179
+ return [], []
176
180
  search_results: list[list[EntityEdge]] = list(
177
181
  await semaphore_gather(
178
182
  *[
@@ -215,15 +219,16 @@ async def edge_search(
215
219
  edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
216
220
 
217
221
  reranked_uuids: list[str] = []
222
+ edge_scores: list[float] = []
218
223
  if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
219
224
  search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
220
225
 
221
- reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
226
+ reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
222
227
  elif config.reranker == EdgeReranker.mmr:
223
228
  search_result_uuids_and_vectors = await get_embeddings_for_edges(
224
229
  driver, list(edge_uuid_map.values())
225
230
  )
226
- reranked_uuids = maximal_marginal_relevance(
231
+ reranked_uuids, edge_scores = maximal_marginal_relevance(
227
232
  query_vector,
228
233
  search_result_uuids_and_vectors,
229
234
  config.mmr_lambda,
@@ -235,12 +240,13 @@ async def edge_search(
235
240
  reranked_uuids = [
236
241
  fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
237
242
  ]
243
+ edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
238
244
  elif config.reranker == EdgeReranker.node_distance:
239
245
  if center_node_uuid is None:
240
246
  raise SearchRerankerError('No center node provided for Node Distance reranker')
241
247
 
242
248
  # use rrf as a preliminary sort
243
- sorted_result_uuids = rrf(
249
+ sorted_result_uuids, node_scores = rrf(
244
250
  [[edge.uuid for edge in result] for result in search_results],
245
251
  min_score=reranker_min_score,
246
252
  )
@@ -253,7 +259,7 @@ async def edge_search(
253
259
 
254
260
  source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
255
261
 
256
- reranked_node_uuids = await node_distance_reranker(
262
+ reranked_node_uuids, edge_scores = await node_distance_reranker(
257
263
  driver, source_uuids, center_node_uuid, min_score=reranker_min_score
258
264
  )
259
265
 
@@ -265,7 +271,7 @@ async def edge_search(
265
271
  if config.reranker == EdgeReranker.episode_mentions:
266
272
  reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
267
273
 
268
- return reranked_edges[:limit]
274
+ return reranked_edges[:limit], edge_scores[:limit]
269
275
 
270
276
 
271
277
  async def node_search(
@@ -280,9 +286,9 @@ async def node_search(
280
286
  bfs_origin_node_uuids: list[str] | None = None,
281
287
  limit=DEFAULT_SEARCH_LIMIT,
282
288
  reranker_min_score: float = 0,
283
- ) -> list[EntityNode]:
289
+ ) -> tuple[list[EntityNode], list[float]]:
284
290
  if config is None:
285
- return []
291
+ return [], []
286
292
  search_results: list[list[EntityNode]] = list(
287
293
  await semaphore_gather(
288
294
  *[
@@ -319,14 +325,15 @@ async def node_search(
319
325
  node_uuid_map = {node.uuid: node for result in search_results for node in result}
320
326
 
321
327
  reranked_uuids: list[str] = []
328
+ node_scores: list[float] = []
322
329
  if config.reranker == NodeReranker.rrf:
323
- reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
330
+ reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
324
331
  elif config.reranker == NodeReranker.mmr:
325
332
  search_result_uuids_and_vectors = await get_embeddings_for_nodes(
326
333
  driver, list(node_uuid_map.values())
327
334
  )
328
335
 
329
- reranked_uuids = maximal_marginal_relevance(
336
+ reranked_uuids, node_scores = maximal_marginal_relevance(
330
337
  query_vector,
331
338
  search_result_uuids_and_vectors,
332
339
  config.mmr_lambda,
@@ -341,23 +348,24 @@ async def node_search(
341
348
  for name, score in reranked_node_names
342
349
  if score >= reranker_min_score
343
350
  ]
351
+ node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
344
352
  elif config.reranker == NodeReranker.episode_mentions:
345
- reranked_uuids = await episode_mentions_reranker(
353
+ reranked_uuids, node_scores = await episode_mentions_reranker(
346
354
  driver, search_result_uuids, min_score=reranker_min_score
347
355
  )
348
356
  elif config.reranker == NodeReranker.node_distance:
349
357
  if center_node_uuid is None:
350
358
  raise SearchRerankerError('No center node provided for Node Distance reranker')
351
- reranked_uuids = await node_distance_reranker(
359
+ reranked_uuids, node_scores = await node_distance_reranker(
352
360
  driver,
353
- rrf(search_result_uuids, min_score=reranker_min_score),
361
+ rrf(search_result_uuids, min_score=reranker_min_score)[0],
354
362
  center_node_uuid,
355
363
  min_score=reranker_min_score,
356
364
  )
357
365
 
358
366
  reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
359
367
 
360
- return reranked_nodes[:limit]
368
+ return reranked_nodes[:limit], node_scores[:limit]
361
369
 
362
370
 
363
371
  async def episode_search(
@@ -370,9 +378,9 @@ async def episode_search(
370
378
  search_filter: SearchFilters,
371
379
  limit=DEFAULT_SEARCH_LIMIT,
372
380
  reranker_min_score: float = 0,
373
- ) -> list[EpisodicNode]:
381
+ ) -> tuple[list[EpisodicNode], list[float]]:
374
382
  if config is None:
375
- return []
383
+ return [], []
376
384
  search_results: list[list[EpisodicNode]] = list(
377
385
  await semaphore_gather(
378
386
  *[
@@ -385,12 +393,13 @@ async def episode_search(
385
393
  episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
386
394
 
387
395
  reranked_uuids: list[str] = []
396
+ episode_scores: list[float] = []
388
397
  if config.reranker == EpisodeReranker.rrf:
389
- reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
398
+ reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
390
399
 
391
400
  elif config.reranker == EpisodeReranker.cross_encoder:
392
401
  # use rrf as a preliminary reranker
393
- rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
402
+ rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
394
403
  rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
395
404
 
396
405
  content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
@@ -401,10 +410,11 @@ async def episode_search(
401
410
  for content, score in reranked_contents
402
411
  if score >= reranker_min_score
403
412
  ]
413
+ episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
404
414
 
405
415
  reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
406
416
 
407
- return reranked_episodes[:limit]
417
+ return reranked_episodes[:limit], episode_scores[:limit]
408
418
 
409
419
 
410
420
  async def community_search(
@@ -416,9 +426,9 @@ async def community_search(
416
426
  config: CommunitySearchConfig | None,
417
427
  limit=DEFAULT_SEARCH_LIMIT,
418
428
  reranker_min_score: float = 0,
419
- ) -> list[CommunityNode]:
429
+ ) -> tuple[list[CommunityNode], list[float]]:
420
430
  if config is None:
421
- return []
431
+ return [], []
422
432
 
423
433
  search_results: list[list[CommunityNode]] = list(
424
434
  await semaphore_gather(
@@ -437,14 +447,15 @@ async def community_search(
437
447
  }
438
448
 
439
449
  reranked_uuids: list[str] = []
450
+ community_scores: list[float] = []
440
451
  if config.reranker == CommunityReranker.rrf:
441
- reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
452
+ reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
442
453
  elif config.reranker == CommunityReranker.mmr:
443
454
  search_result_uuids_and_vectors = await get_embeddings_for_communities(
444
455
  driver, list(community_uuid_map.values())
445
456
  )
446
457
 
447
- reranked_uuids = maximal_marginal_relevance(
458
+ reranked_uuids, community_scores = maximal_marginal_relevance(
448
459
  query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
449
460
  )
450
461
  elif config.reranker == CommunityReranker.cross_encoder:
@@ -453,7 +464,8 @@ async def community_search(
453
464
  reranked_uuids = [
454
465
  name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
455
466
  ]
467
+ community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
456
468
 
457
469
  reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
458
470
 
459
- return reranked_communities[:limit]
471
+ return reranked_communities[:limit], community_scores[:limit]
@@ -119,7 +119,11 @@ class SearchConfig(BaseModel):
119
119
 
120
120
 
121
121
  class SearchResults(BaseModel):
122
- edges: list[EntityEdge]
123
- nodes: list[EntityNode]
124
- episodes: list[EpisodicNode]
125
- communities: list[CommunityNode]
122
+ edges: list[EntityEdge] = Field(default_factory=list)
123
+ edge_reranker_scores: list[float] = Field(default_factory=list)
124
+ nodes: list[EntityNode] = Field(default_factory=list)
125
+ node_reranker_scores: list[float] = Field(default_factory=list)
126
+ episodes: list[EpisodicNode] = Field(default_factory=list)
127
+ episode_reranker_scores: list[float] = Field(default_factory=list)
128
+ communities: list[CommunityNode] = Field(default_factory=list)
129
+ community_reranker_scores: list[float] = Field(default_factory=list)
@@ -72,7 +72,7 @@ def edge_search_filter_query_constructor(
72
72
 
73
73
  if filters.edge_types is not None:
74
74
  edge_types = filters.edge_types
75
- edge_types_filter = '\nAND r.name in $edge_types'
75
+ edge_types_filter = '\nAND e.name in $edge_types'
76
76
  filter_query += edge_types_filter
77
77
  filter_params['edge_types'] = edge_types
78
78
 
@@ -88,7 +88,7 @@ def edge_search_filter_query_constructor(
88
88
  filter_params['valid_at_' + str(j)] = date_filter.date
89
89
 
90
90
  and_filters = [
91
- '(r.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
91
+ '(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
92
92
  for j, date_filter in enumerate(or_list)
93
93
  ]
94
94
  and_filter_query = ''
@@ -113,7 +113,7 @@ def edge_search_filter_query_constructor(
113
113
  filter_params['invalid_at_' + str(j)] = date_filter.date
114
114
 
115
115
  and_filters = [
116
- '(r.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
116
+ '(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
117
117
  for j, date_filter in enumerate(or_list)
118
118
  ]
119
119
  and_filter_query = ''
@@ -138,7 +138,7 @@ def edge_search_filter_query_constructor(
138
138
  filter_params['created_at_' + str(j)] = date_filter.date
139
139
 
140
140
  and_filters = [
141
- '(r.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
141
+ '(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
142
142
  for j, date_filter in enumerate(or_list)
143
143
  ]
144
144
  and_filter_query = ''
@@ -163,7 +163,7 @@ def edge_search_filter_query_constructor(
163
163
  filter_params['expired_at_' + str(j)] = date_filter.date
164
164
 
165
165
  and_filters = [
166
- '(r.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
166
+ '(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
167
167
  for j, date_filter in enumerate(or_list)
168
168
  ]
169
169
  and_filter_query = ''