graphiti-core 0.11.6rc7__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/driver/__init__.py +17 -0
- graphiti_core/driver/driver.py +66 -0
- graphiti_core/driver/falkordb_driver.py +132 -0
- graphiti_core/driver/neo4j_driver.py +61 -0
- graphiti_core/edges.py +66 -40
- graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core/embedder/gemini.py +14 -3
- graphiti_core/graph_queries.py +149 -0
- graphiti_core/graphiti.py +41 -14
- graphiti_core/graphiti_types.py +2 -2
- graphiti_core/helpers.py +17 -30
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/azure_openai_client.py +73 -0
- graphiti_core/llm_client/gemini_client.py +4 -1
- graphiti_core/models/edges/edge_db_queries.py +2 -4
- graphiti_core/nodes.py +31 -31
- graphiti_core/prompts/dedupe_edges.py +52 -1
- graphiti_core/prompts/dedupe_nodes.py +79 -4
- graphiti_core/prompts/extract_edges.py +50 -5
- graphiti_core/prompts/invalidate_edges.py +1 -1
- graphiti_core/search/search.py +25 -55
- graphiti_core/search/search_filters.py +23 -9
- graphiti_core/search/search_utils.py +360 -195
- graphiti_core/utils/bulk_utils.py +38 -11
- graphiti_core/utils/maintenance/community_operations.py +6 -7
- graphiti_core/utils/maintenance/edge_operations.py +149 -19
- graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- graphiti_core/utils/maintenance/node_operations.py +52 -71
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
graphiti_core/search/search.py
CHANGED
|
@@ -18,9 +18,8 @@ import logging
|
|
|
18
18
|
from collections import defaultdict
|
|
19
19
|
from time import time
|
|
20
20
|
|
|
21
|
-
from neo4j import AsyncDriver
|
|
22
|
-
|
|
23
21
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
22
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
24
23
|
from graphiti_core.edges import EntityEdge
|
|
25
24
|
from graphiti_core.errors import SearchRerankerError
|
|
26
25
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
@@ -50,6 +49,9 @@ from graphiti_core.search.search_utils import (
|
|
|
50
49
|
edge_similarity_search,
|
|
51
50
|
episode_fulltext_search,
|
|
52
51
|
episode_mentions_reranker,
|
|
52
|
+
get_embeddings_for_communities,
|
|
53
|
+
get_embeddings_for_edges,
|
|
54
|
+
get_embeddings_for_nodes,
|
|
53
55
|
maximal_marginal_relevance,
|
|
54
56
|
node_bfs_search,
|
|
55
57
|
node_distance_reranker,
|
|
@@ -91,7 +93,7 @@ async def search(
|
|
|
91
93
|
)
|
|
92
94
|
|
|
93
95
|
# if group_ids is empty, set it to None
|
|
94
|
-
group_ids = group_ids if group_ids else None
|
|
96
|
+
group_ids = group_ids if group_ids and group_ids != [''] else None
|
|
95
97
|
edges, nodes, episodes, communities = await semaphore_gather(
|
|
96
98
|
edge_search(
|
|
97
99
|
driver,
|
|
@@ -157,7 +159,7 @@ async def search(
|
|
|
157
159
|
|
|
158
160
|
|
|
159
161
|
async def edge_search(
|
|
160
|
-
driver:
|
|
162
|
+
driver: GraphDriver,
|
|
161
163
|
cross_encoder: CrossEncoderClient,
|
|
162
164
|
query: str,
|
|
163
165
|
query_vector: list[float],
|
|
@@ -171,7 +173,6 @@ async def edge_search(
|
|
|
171
173
|
) -> list[EntityEdge]:
|
|
172
174
|
if config is None:
|
|
173
175
|
return []
|
|
174
|
-
|
|
175
176
|
search_results: list[list[EntityEdge]] = list(
|
|
176
177
|
await semaphore_gather(
|
|
177
178
|
*[
|
|
@@ -209,26 +210,17 @@ async def edge_search(
|
|
|
209
210
|
|
|
210
211
|
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
211
212
|
elif config.reranker == EdgeReranker.mmr:
|
|
212
|
-
await
|
|
213
|
-
|
|
213
|
+
search_result_uuids_and_vectors = await get_embeddings_for_edges(
|
|
214
|
+
driver, list(edge_uuid_map.values())
|
|
214
215
|
)
|
|
215
|
-
search_result_uuids_and_vectors = [
|
|
216
|
-
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
|
|
217
|
-
for result in search_results
|
|
218
|
-
for edge in result
|
|
219
|
-
]
|
|
220
216
|
reranked_uuids = maximal_marginal_relevance(
|
|
221
217
|
query_vector,
|
|
222
218
|
search_result_uuids_and_vectors,
|
|
223
219
|
config.mmr_lambda,
|
|
220
|
+
reranker_min_score,
|
|
224
221
|
)
|
|
225
222
|
elif config.reranker == EdgeReranker.cross_encoder:
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
229
|
-
rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
230
|
-
|
|
231
|
-
fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
|
|
223
|
+
fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
|
|
232
224
|
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
|
|
233
225
|
reranked_uuids = [
|
|
234
226
|
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
|
|
@@ -267,7 +259,7 @@ async def edge_search(
|
|
|
267
259
|
|
|
268
260
|
|
|
269
261
|
async def node_search(
|
|
270
|
-
driver:
|
|
262
|
+
driver: GraphDriver,
|
|
271
263
|
cross_encoder: CrossEncoderClient,
|
|
272
264
|
query: str,
|
|
273
265
|
query_vector: list[float],
|
|
@@ -281,7 +273,6 @@ async def node_search(
|
|
|
281
273
|
) -> list[EntityNode]:
|
|
282
274
|
if config is None:
|
|
283
275
|
return []
|
|
284
|
-
|
|
285
276
|
search_results: list[list[EntityNode]] = list(
|
|
286
277
|
await semaphore_gather(
|
|
287
278
|
*[
|
|
@@ -311,30 +302,23 @@ async def node_search(
|
|
|
311
302
|
if config.reranker == NodeReranker.rrf:
|
|
312
303
|
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
313
304
|
elif config.reranker == NodeReranker.mmr:
|
|
314
|
-
await
|
|
315
|
-
|
|
305
|
+
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
|
|
306
|
+
driver, list(node_uuid_map.values())
|
|
316
307
|
)
|
|
317
|
-
|
|
318
|
-
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
|
|
319
|
-
for result in search_results
|
|
320
|
-
for node in result
|
|
321
|
-
]
|
|
308
|
+
|
|
322
309
|
reranked_uuids = maximal_marginal_relevance(
|
|
323
310
|
query_vector,
|
|
324
311
|
search_result_uuids_and_vectors,
|
|
325
312
|
config.mmr_lambda,
|
|
313
|
+
reranker_min_score,
|
|
326
314
|
)
|
|
327
315
|
elif config.reranker == NodeReranker.cross_encoder:
|
|
328
|
-
|
|
329
|
-
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
330
|
-
rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
331
|
-
|
|
332
|
-
summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
|
|
316
|
+
name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
|
|
333
317
|
|
|
334
|
-
|
|
318
|
+
reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
|
|
335
319
|
reranked_uuids = [
|
|
336
|
-
|
|
337
|
-
for
|
|
320
|
+
name_to_uuid_map[name]
|
|
321
|
+
for name, score in reranked_node_names
|
|
338
322
|
if score >= reranker_min_score
|
|
339
323
|
]
|
|
340
324
|
elif config.reranker == NodeReranker.episode_mentions:
|
|
@@ -357,7 +341,7 @@ async def node_search(
|
|
|
357
341
|
|
|
358
342
|
|
|
359
343
|
async def episode_search(
|
|
360
|
-
driver:
|
|
344
|
+
driver: GraphDriver,
|
|
361
345
|
cross_encoder: CrossEncoderClient,
|
|
362
346
|
query: str,
|
|
363
347
|
_query_vector: list[float],
|
|
@@ -369,7 +353,6 @@ async def episode_search(
|
|
|
369
353
|
) -> list[EpisodicNode]:
|
|
370
354
|
if config is None:
|
|
371
355
|
return []
|
|
372
|
-
|
|
373
356
|
search_results: list[list[EpisodicNode]] = list(
|
|
374
357
|
await semaphore_gather(
|
|
375
358
|
*[
|
|
@@ -405,7 +388,7 @@ async def episode_search(
|
|
|
405
388
|
|
|
406
389
|
|
|
407
390
|
async def community_search(
|
|
408
|
-
driver:
|
|
391
|
+
driver: GraphDriver,
|
|
409
392
|
cross_encoder: CrossEncoderClient,
|
|
410
393
|
query: str,
|
|
411
394
|
query_vector: list[float],
|
|
@@ -437,25 +420,12 @@ async def community_search(
|
|
|
437
420
|
if config.reranker == CommunityReranker.rrf:
|
|
438
421
|
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
439
422
|
elif config.reranker == CommunityReranker.mmr:
|
|
440
|
-
await
|
|
441
|
-
|
|
442
|
-
community.load_name_embedding(driver)
|
|
443
|
-
for result in search_results
|
|
444
|
-
for community in result
|
|
445
|
-
]
|
|
423
|
+
search_result_uuids_and_vectors = await get_embeddings_for_communities(
|
|
424
|
+
driver, list(community_uuid_map.values())
|
|
446
425
|
)
|
|
447
|
-
|
|
448
|
-
(
|
|
449
|
-
community.uuid,
|
|
450
|
-
community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
|
|
451
|
-
)
|
|
452
|
-
for result in search_results
|
|
453
|
-
for community in result
|
|
454
|
-
]
|
|
426
|
+
|
|
455
427
|
reranked_uuids = maximal_marginal_relevance(
|
|
456
|
-
query_vector,
|
|
457
|
-
search_result_uuids_and_vectors,
|
|
458
|
-
config.mmr_lambda,
|
|
428
|
+
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
|
|
459
429
|
)
|
|
460
430
|
elif config.reranker == CommunityReranker.cross_encoder:
|
|
461
431
|
name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
|
|
@@ -42,6 +42,9 @@ class SearchFilters(BaseModel):
|
|
|
42
42
|
node_labels: list[str] | None = Field(
|
|
43
43
|
default=None, description='List of node labels to filter on'
|
|
44
44
|
)
|
|
45
|
+
edge_types: list[str] | None = Field(
|
|
46
|
+
default=None, description='List of edge types to filter on'
|
|
47
|
+
)
|
|
45
48
|
valid_at: list[list[DateFilter]] | None = Field(default=None)
|
|
46
49
|
invalid_at: list[list[DateFilter]] | None = Field(default=None)
|
|
47
50
|
created_at: list[list[DateFilter]] | None = Field(default=None)
|
|
@@ -68,8 +71,19 @@ def edge_search_filter_query_constructor(
|
|
|
68
71
|
filter_query: LiteralString = ''
|
|
69
72
|
filter_params: dict[str, Any] = {}
|
|
70
73
|
|
|
74
|
+
if filters.edge_types is not None:
|
|
75
|
+
edge_types = filters.edge_types
|
|
76
|
+
edge_types_filter = '\nAND r.name in $edge_types'
|
|
77
|
+
filter_query += edge_types_filter
|
|
78
|
+
filter_params['edge_types'] = edge_types
|
|
79
|
+
|
|
80
|
+
if filters.node_labels is not None:
|
|
81
|
+
node_labels = '|'.join(filters.node_labels)
|
|
82
|
+
node_label_filter = '\nAND n:' + node_labels + ' AND m:' + node_labels
|
|
83
|
+
filter_query += node_label_filter
|
|
84
|
+
|
|
71
85
|
if filters.valid_at is not None:
|
|
72
|
-
valid_at_filter = '
|
|
86
|
+
valid_at_filter = '\nAND ('
|
|
73
87
|
for i, or_list in enumerate(filters.valid_at):
|
|
74
88
|
for j, date_filter in enumerate(or_list):
|
|
75
89
|
filter_params['valid_at_' + str(j)] = date_filter.date
|
|
@@ -81,12 +95,12 @@ def edge_search_filter_query_constructor(
|
|
|
81
95
|
and_filter_query = ''
|
|
82
96
|
for j, and_filter in enumerate(and_filters):
|
|
83
97
|
and_filter_query += and_filter
|
|
84
|
-
if j != len(
|
|
98
|
+
if j != len(and_filters) - 1:
|
|
85
99
|
and_filter_query += ' AND '
|
|
86
100
|
|
|
87
101
|
valid_at_filter += and_filter_query
|
|
88
102
|
|
|
89
|
-
if i == len(
|
|
103
|
+
if i == len(filters.valid_at) - 1:
|
|
90
104
|
valid_at_filter += ')'
|
|
91
105
|
else:
|
|
92
106
|
valid_at_filter += ' OR '
|
|
@@ -106,12 +120,12 @@ def edge_search_filter_query_constructor(
|
|
|
106
120
|
and_filter_query = ''
|
|
107
121
|
for j, and_filter in enumerate(and_filters):
|
|
108
122
|
and_filter_query += and_filter
|
|
109
|
-
if j != len(
|
|
123
|
+
if j != len(and_filters) - 1:
|
|
110
124
|
and_filter_query += ' AND '
|
|
111
125
|
|
|
112
126
|
invalid_at_filter += and_filter_query
|
|
113
127
|
|
|
114
|
-
if i == len(
|
|
128
|
+
if i == len(filters.invalid_at) - 1:
|
|
115
129
|
invalid_at_filter += ')'
|
|
116
130
|
else:
|
|
117
131
|
invalid_at_filter += ' OR '
|
|
@@ -131,12 +145,12 @@ def edge_search_filter_query_constructor(
|
|
|
131
145
|
and_filter_query = ''
|
|
132
146
|
for j, and_filter in enumerate(and_filters):
|
|
133
147
|
and_filter_query += and_filter
|
|
134
|
-
if j != len(
|
|
148
|
+
if j != len(and_filters) - 1:
|
|
135
149
|
and_filter_query += ' AND '
|
|
136
150
|
|
|
137
151
|
created_at_filter += and_filter_query
|
|
138
152
|
|
|
139
|
-
if i == len(
|
|
153
|
+
if i == len(filters.created_at) - 1:
|
|
140
154
|
created_at_filter += ')'
|
|
141
155
|
else:
|
|
142
156
|
created_at_filter += ' OR '
|
|
@@ -156,12 +170,12 @@ def edge_search_filter_query_constructor(
|
|
|
156
170
|
and_filter_query = ''
|
|
157
171
|
for j, and_filter in enumerate(and_filters):
|
|
158
172
|
and_filter_query += and_filter
|
|
159
|
-
if j != len(
|
|
173
|
+
if j != len(and_filters) - 1:
|
|
160
174
|
and_filter_query += ' AND '
|
|
161
175
|
|
|
162
176
|
expired_at_filter += and_filter_query
|
|
163
177
|
|
|
164
|
-
if i == len(
|
|
178
|
+
if i == len(filters.expired_at) - 1:
|
|
165
179
|
expired_at_filter += ')'
|
|
166
180
|
else:
|
|
167
181
|
expired_at_filter += ' OR '
|