graphiti-core 0.17.11__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +20 -2
- graphiti_core/driver/falkordb_driver.py +16 -9
- graphiti_core/driver/neo4j_driver.py +8 -6
- graphiti_core/edges.py +73 -99
- graphiti_core/graph_queries.py +51 -97
- graphiti_core/graphiti.py +24 -9
- graphiti_core/helpers.py +3 -2
- graphiti_core/models/edges/edge_db_queries.py +106 -32
- graphiti_core/models/nodes/node_db_queries.py +101 -20
- graphiti_core/nodes.py +113 -128
- graphiti_core/prompts/dedupe_nodes.py +1 -1
- graphiti_core/prompts/extract_edges.py +4 -4
- graphiti_core/prompts/extract_nodes.py +12 -10
- graphiti_core/search/search.py +44 -32
- graphiti_core/search/search_config.py +8 -4
- graphiti_core/search/search_filters.py +5 -5
- graphiti_core/search/search_utils.py +154 -189
- graphiti_core/utils/bulk_utils.py +3 -5
- graphiti_core/utils/maintenance/community_operations.py +11 -7
- graphiti_core/utils/maintenance/edge_operations.py +19 -50
- graphiti_core/utils/maintenance/graph_data_operations.py +14 -29
- graphiti_core/utils/maintenance/node_operations.py +11 -55
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/METADATA +11 -3
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/RECORD +26 -26
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/licenses/LICENSE +0 -0
graphiti_core/search/search.py
CHANGED
|
@@ -80,12 +80,7 @@ async def search(
|
|
|
80
80
|
cross_encoder = clients.cross_encoder
|
|
81
81
|
|
|
82
82
|
if query.strip() == '':
|
|
83
|
-
return SearchResults(
|
|
84
|
-
edges=[],
|
|
85
|
-
nodes=[],
|
|
86
|
-
episodes=[],
|
|
87
|
-
communities=[],
|
|
88
|
-
)
|
|
83
|
+
return SearchResults()
|
|
89
84
|
query_vector = (
|
|
90
85
|
query_vector
|
|
91
86
|
if query_vector is not None
|
|
@@ -94,7 +89,12 @@ async def search(
|
|
|
94
89
|
|
|
95
90
|
# if group_ids is empty, set it to None
|
|
96
91
|
group_ids = group_ids if group_ids and group_ids != [''] else None
|
|
97
|
-
|
|
92
|
+
(
|
|
93
|
+
(edges, edge_reranker_scores),
|
|
94
|
+
(nodes, node_reranker_scores),
|
|
95
|
+
(episodes, episode_reranker_scores),
|
|
96
|
+
(communities, community_reranker_scores),
|
|
97
|
+
) = await semaphore_gather(
|
|
98
98
|
edge_search(
|
|
99
99
|
driver,
|
|
100
100
|
cross_encoder,
|
|
@@ -146,9 +146,13 @@ async def search(
|
|
|
146
146
|
|
|
147
147
|
results = SearchResults(
|
|
148
148
|
edges=edges,
|
|
149
|
+
edge_reranker_scores=edge_reranker_scores,
|
|
149
150
|
nodes=nodes,
|
|
151
|
+
node_reranker_scores=node_reranker_scores,
|
|
150
152
|
episodes=episodes,
|
|
153
|
+
episode_reranker_scores=episode_reranker_scores,
|
|
151
154
|
communities=communities,
|
|
155
|
+
community_reranker_scores=community_reranker_scores,
|
|
152
156
|
)
|
|
153
157
|
|
|
154
158
|
latency = (time() - start) * 1000
|
|
@@ -170,9 +174,9 @@ async def edge_search(
|
|
|
170
174
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
171
175
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
172
176
|
reranker_min_score: float = 0,
|
|
173
|
-
) -> list[EntityEdge]:
|
|
177
|
+
) -> tuple[list[EntityEdge], list[float]]:
|
|
174
178
|
if config is None:
|
|
175
|
-
return []
|
|
179
|
+
return [], []
|
|
176
180
|
search_results: list[list[EntityEdge]] = list(
|
|
177
181
|
await semaphore_gather(
|
|
178
182
|
*[
|
|
@@ -215,15 +219,16 @@ async def edge_search(
|
|
|
215
219
|
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
|
216
220
|
|
|
217
221
|
reranked_uuids: list[str] = []
|
|
222
|
+
edge_scores: list[float] = []
|
|
218
223
|
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
|
|
219
224
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
220
225
|
|
|
221
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
226
|
+
reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
222
227
|
elif config.reranker == EdgeReranker.mmr:
|
|
223
228
|
search_result_uuids_and_vectors = await get_embeddings_for_edges(
|
|
224
229
|
driver, list(edge_uuid_map.values())
|
|
225
230
|
)
|
|
226
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
231
|
+
reranked_uuids, edge_scores = maximal_marginal_relevance(
|
|
227
232
|
query_vector,
|
|
228
233
|
search_result_uuids_and_vectors,
|
|
229
234
|
config.mmr_lambda,
|
|
@@ -235,12 +240,13 @@ async def edge_search(
|
|
|
235
240
|
reranked_uuids = [
|
|
236
241
|
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
|
|
237
242
|
]
|
|
243
|
+
edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
|
|
238
244
|
elif config.reranker == EdgeReranker.node_distance:
|
|
239
245
|
if center_node_uuid is None:
|
|
240
246
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
241
247
|
|
|
242
248
|
# use rrf as a preliminary sort
|
|
243
|
-
sorted_result_uuids = rrf(
|
|
249
|
+
sorted_result_uuids, node_scores = rrf(
|
|
244
250
|
[[edge.uuid for edge in result] for result in search_results],
|
|
245
251
|
min_score=reranker_min_score,
|
|
246
252
|
)
|
|
@@ -253,7 +259,7 @@ async def edge_search(
|
|
|
253
259
|
|
|
254
260
|
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
|
|
255
261
|
|
|
256
|
-
reranked_node_uuids = await node_distance_reranker(
|
|
262
|
+
reranked_node_uuids, edge_scores = await node_distance_reranker(
|
|
257
263
|
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
|
|
258
264
|
)
|
|
259
265
|
|
|
@@ -265,7 +271,7 @@ async def edge_search(
|
|
|
265
271
|
if config.reranker == EdgeReranker.episode_mentions:
|
|
266
272
|
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
|
|
267
273
|
|
|
268
|
-
return reranked_edges[:limit]
|
|
274
|
+
return reranked_edges[:limit], edge_scores[:limit]
|
|
269
275
|
|
|
270
276
|
|
|
271
277
|
async def node_search(
|
|
@@ -280,9 +286,9 @@ async def node_search(
|
|
|
280
286
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
281
287
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
282
288
|
reranker_min_score: float = 0,
|
|
283
|
-
) -> list[EntityNode]:
|
|
289
|
+
) -> tuple[list[EntityNode], list[float]]:
|
|
284
290
|
if config is None:
|
|
285
|
-
return []
|
|
291
|
+
return [], []
|
|
286
292
|
search_results: list[list[EntityNode]] = list(
|
|
287
293
|
await semaphore_gather(
|
|
288
294
|
*[
|
|
@@ -319,14 +325,15 @@ async def node_search(
|
|
|
319
325
|
node_uuid_map = {node.uuid: node for result in search_results for node in result}
|
|
320
326
|
|
|
321
327
|
reranked_uuids: list[str] = []
|
|
328
|
+
node_scores: list[float] = []
|
|
322
329
|
if config.reranker == NodeReranker.rrf:
|
|
323
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
330
|
+
reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
324
331
|
elif config.reranker == NodeReranker.mmr:
|
|
325
332
|
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
|
|
326
333
|
driver, list(node_uuid_map.values())
|
|
327
334
|
)
|
|
328
335
|
|
|
329
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
336
|
+
reranked_uuids, node_scores = maximal_marginal_relevance(
|
|
330
337
|
query_vector,
|
|
331
338
|
search_result_uuids_and_vectors,
|
|
332
339
|
config.mmr_lambda,
|
|
@@ -341,23 +348,24 @@ async def node_search(
|
|
|
341
348
|
for name, score in reranked_node_names
|
|
342
349
|
if score >= reranker_min_score
|
|
343
350
|
]
|
|
351
|
+
node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
|
|
344
352
|
elif config.reranker == NodeReranker.episode_mentions:
|
|
345
|
-
reranked_uuids = await episode_mentions_reranker(
|
|
353
|
+
reranked_uuids, node_scores = await episode_mentions_reranker(
|
|
346
354
|
driver, search_result_uuids, min_score=reranker_min_score
|
|
347
355
|
)
|
|
348
356
|
elif config.reranker == NodeReranker.node_distance:
|
|
349
357
|
if center_node_uuid is None:
|
|
350
358
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
351
|
-
reranked_uuids = await node_distance_reranker(
|
|
359
|
+
reranked_uuids, node_scores = await node_distance_reranker(
|
|
352
360
|
driver,
|
|
353
|
-
rrf(search_result_uuids, min_score=reranker_min_score),
|
|
361
|
+
rrf(search_result_uuids, min_score=reranker_min_score)[0],
|
|
354
362
|
center_node_uuid,
|
|
355
363
|
min_score=reranker_min_score,
|
|
356
364
|
)
|
|
357
365
|
|
|
358
366
|
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
|
359
367
|
|
|
360
|
-
return reranked_nodes[:limit]
|
|
368
|
+
return reranked_nodes[:limit], node_scores[:limit]
|
|
361
369
|
|
|
362
370
|
|
|
363
371
|
async def episode_search(
|
|
@@ -370,9 +378,9 @@ async def episode_search(
|
|
|
370
378
|
search_filter: SearchFilters,
|
|
371
379
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
372
380
|
reranker_min_score: float = 0,
|
|
373
|
-
) -> list[EpisodicNode]:
|
|
381
|
+
) -> tuple[list[EpisodicNode], list[float]]:
|
|
374
382
|
if config is None:
|
|
375
|
-
return []
|
|
383
|
+
return [], []
|
|
376
384
|
search_results: list[list[EpisodicNode]] = list(
|
|
377
385
|
await semaphore_gather(
|
|
378
386
|
*[
|
|
@@ -385,12 +393,13 @@ async def episode_search(
|
|
|
385
393
|
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
|
|
386
394
|
|
|
387
395
|
reranked_uuids: list[str] = []
|
|
396
|
+
episode_scores: list[float] = []
|
|
388
397
|
if config.reranker == EpisodeReranker.rrf:
|
|
389
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
398
|
+
reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
390
399
|
|
|
391
400
|
elif config.reranker == EpisodeReranker.cross_encoder:
|
|
392
401
|
# use rrf as a preliminary reranker
|
|
393
|
-
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
402
|
+
rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
394
403
|
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
395
404
|
|
|
396
405
|
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
|
|
@@ -401,10 +410,11 @@ async def episode_search(
|
|
|
401
410
|
for content, score in reranked_contents
|
|
402
411
|
if score >= reranker_min_score
|
|
403
412
|
]
|
|
413
|
+
episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
|
|
404
414
|
|
|
405
415
|
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
|
|
406
416
|
|
|
407
|
-
return reranked_episodes[:limit]
|
|
417
|
+
return reranked_episodes[:limit], episode_scores[:limit]
|
|
408
418
|
|
|
409
419
|
|
|
410
420
|
async def community_search(
|
|
@@ -416,9 +426,9 @@ async def community_search(
|
|
|
416
426
|
config: CommunitySearchConfig | None,
|
|
417
427
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
418
428
|
reranker_min_score: float = 0,
|
|
419
|
-
) -> list[CommunityNode]:
|
|
429
|
+
) -> tuple[list[CommunityNode], list[float]]:
|
|
420
430
|
if config is None:
|
|
421
|
-
return []
|
|
431
|
+
return [], []
|
|
422
432
|
|
|
423
433
|
search_results: list[list[CommunityNode]] = list(
|
|
424
434
|
await semaphore_gather(
|
|
@@ -437,14 +447,15 @@ async def community_search(
|
|
|
437
447
|
}
|
|
438
448
|
|
|
439
449
|
reranked_uuids: list[str] = []
|
|
450
|
+
community_scores: list[float] = []
|
|
440
451
|
if config.reranker == CommunityReranker.rrf:
|
|
441
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
452
|
+
reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
442
453
|
elif config.reranker == CommunityReranker.mmr:
|
|
443
454
|
search_result_uuids_and_vectors = await get_embeddings_for_communities(
|
|
444
455
|
driver, list(community_uuid_map.values())
|
|
445
456
|
)
|
|
446
457
|
|
|
447
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
458
|
+
reranked_uuids, community_scores = maximal_marginal_relevance(
|
|
448
459
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
|
|
449
460
|
)
|
|
450
461
|
elif config.reranker == CommunityReranker.cross_encoder:
|
|
@@ -453,7 +464,8 @@ async def community_search(
|
|
|
453
464
|
reranked_uuids = [
|
|
454
465
|
name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
|
|
455
466
|
]
|
|
467
|
+
community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
|
|
456
468
|
|
|
457
469
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
|
458
470
|
|
|
459
|
-
return reranked_communities[:limit]
|
|
471
|
+
return reranked_communities[:limit], community_scores[:limit]
|
|
graphiti_core/search/search_config.py
CHANGED
|
@@ -119,7 +119,11 @@ class SearchConfig(BaseModel):
|
|
|
119
119
|
|
|
120
120
|
|
|
121
121
|
class SearchResults(BaseModel):
|
|
122
|
-
edges: list[EntityEdge]
|
|
123
|
-
nodes: list[EntityNode]
|
|
124
|
-
episodes: list[EpisodicNode]
|
|
125
|
-
communities: list[CommunityNode]
|
|
122
|
+
edges: list[EntityEdge] = Field(default_factory=list)
|
|
123
|
+
edge_reranker_scores: list[float] = Field(default_factory=list)
|
|
124
|
+
nodes: list[EntityNode] = Field(default_factory=list)
|
|
125
|
+
node_reranker_scores: list[float] = Field(default_factory=list)
|
|
126
|
+
episodes: list[EpisodicNode] = Field(default_factory=list)
|
|
127
|
+
episode_reranker_scores: list[float] = Field(default_factory=list)
|
|
128
|
+
communities: list[CommunityNode] = Field(default_factory=list)
|
|
129
|
+
community_reranker_scores: list[float] = Field(default_factory=list)
|
|
graphiti_core/search/search_filters.py
CHANGED
|
@@ -72,7 +72,7 @@ def edge_search_filter_query_constructor(
|
|
|
72
72
|
|
|
73
73
|
if filters.edge_types is not None:
|
|
74
74
|
edge_types = filters.edge_types
|
|
75
|
-
edge_types_filter = '\nAND
|
|
75
|
+
edge_types_filter = '\nAND e.name in $edge_types'
|
|
76
76
|
filter_query += edge_types_filter
|
|
77
77
|
filter_params['edge_types'] = edge_types
|
|
78
78
|
|
|
@@ -88,7 +88,7 @@ def edge_search_filter_query_constructor(
|
|
|
88
88
|
filter_params['valid_at_' + str(j)] = date_filter.date
|
|
89
89
|
|
|
90
90
|
and_filters = [
|
|
91
|
-
'(
|
|
91
|
+
'(e.valid_at ' + date_filter.comparison_operator.value + f' $valid_at_{j})'
|
|
92
92
|
for j, date_filter in enumerate(or_list)
|
|
93
93
|
]
|
|
94
94
|
and_filter_query = ''
|
|
@@ -113,7 +113,7 @@ def edge_search_filter_query_constructor(
|
|
|
113
113
|
filter_params['invalid_at_' + str(j)] = date_filter.date
|
|
114
114
|
|
|
115
115
|
and_filters = [
|
|
116
|
-
'(
|
|
116
|
+
'(e.invalid_at ' + date_filter.comparison_operator.value + f' $invalid_at_{j})'
|
|
117
117
|
for j, date_filter in enumerate(or_list)
|
|
118
118
|
]
|
|
119
119
|
and_filter_query = ''
|
|
@@ -138,7 +138,7 @@ def edge_search_filter_query_constructor(
|
|
|
138
138
|
filter_params['created_at_' + str(j)] = date_filter.date
|
|
139
139
|
|
|
140
140
|
and_filters = [
|
|
141
|
-
'(
|
|
141
|
+
'(e.created_at ' + date_filter.comparison_operator.value + f' $created_at_{j})'
|
|
142
142
|
for j, date_filter in enumerate(or_list)
|
|
143
143
|
]
|
|
144
144
|
and_filter_query = ''
|
|
@@ -163,7 +163,7 @@ def edge_search_filter_query_constructor(
|
|
|
163
163
|
filter_params['expired_at_' + str(j)] = date_filter.date
|
|
164
164
|
|
|
165
165
|
and_filters = [
|
|
166
|
-
'(
|
|
166
|
+
'(e.expired_at ' + date_filter.comparison_operator.value + f' $expired_at_{j})'
|
|
167
167
|
for j, date_filter in enumerate(or_list)
|
|
168
168
|
]
|
|
169
169
|
and_filter_query = ''
|