graphiti-core 0.18.9__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +4 -0
- graphiti_core/driver/falkordb_driver.py +3 -14
- graphiti_core/driver/kuzu_driver.py +175 -0
- graphiti_core/driver/neptune_driver.py +301 -0
- graphiti_core/edges.py +155 -62
- graphiti_core/graph_queries.py +31 -2
- graphiti_core/graphiti.py +6 -1
- graphiti_core/helpers.py +8 -8
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/openai_base_client.py +12 -2
- graphiti_core/llm_client/openai_client.py +10 -2
- graphiti_core/migrations/__init__.py +0 -0
- graphiti_core/migrations/neo4j_node_group_labels.py +114 -0
- graphiti_core/models/edges/edge_db_queries.py +205 -76
- graphiti_core/models/nodes/node_db_queries.py +253 -74
- graphiti_core/nodes.py +271 -98
- graphiti_core/search/search.py +42 -12
- graphiti_core/search/search_config.py +4 -0
- graphiti_core/search/search_filters.py +35 -22
- graphiti_core/search/search_utils.py +1329 -392
- graphiti_core/utils/bulk_utils.py +50 -15
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +39 -32
- graphiti_core/utils/maintenance/edge_operations.py +47 -13
- graphiti_core/utils/maintenance/graph_data_operations.py +100 -15
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/METADATA +87 -13
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/RECORD +29 -25
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/search/search.py
CHANGED
|
@@ -21,6 +21,7 @@ from time import time
|
|
|
21
21
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
22
22
|
from graphiti_core.driver.driver import GraphDriver
|
|
23
23
|
from graphiti_core.edges import EntityEdge
|
|
24
|
+
from graphiti_core.embedder.client import EMBEDDING_DIM
|
|
24
25
|
from graphiti_core.errors import SearchRerankerError
|
|
25
26
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
26
27
|
from graphiti_core.helpers import semaphore_gather
|
|
@@ -29,6 +30,7 @@ from graphiti_core.search.search_config import (
|
|
|
29
30
|
DEFAULT_SEARCH_LIMIT,
|
|
30
31
|
CommunityReranker,
|
|
31
32
|
CommunitySearchConfig,
|
|
33
|
+
CommunitySearchMethod,
|
|
32
34
|
EdgeReranker,
|
|
33
35
|
EdgeSearchConfig,
|
|
34
36
|
EdgeSearchMethod,
|
|
@@ -81,11 +83,29 @@ async def search(
|
|
|
81
83
|
|
|
82
84
|
if query.strip() == '':
|
|
83
85
|
return SearchResults()
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
86
|
+
|
|
87
|
+
if (
|
|
88
|
+
config.edge_config
|
|
89
|
+
and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
|
|
90
|
+
or config.edge_config
|
|
91
|
+
and EdgeReranker.mmr == config.edge_config.reranker
|
|
92
|
+
or config.node_config
|
|
93
|
+
and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
|
|
94
|
+
or config.node_config
|
|
95
|
+
and NodeReranker.mmr == config.node_config.reranker
|
|
96
|
+
or (
|
|
97
|
+
config.community_config
|
|
98
|
+
and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
|
|
99
|
+
)
|
|
100
|
+
or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
|
|
101
|
+
):
|
|
102
|
+
search_vector = (
|
|
103
|
+
query_vector
|
|
104
|
+
if query_vector is not None
|
|
105
|
+
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
|
106
|
+
)
|
|
107
|
+
else:
|
|
108
|
+
search_vector = [0.0] * EMBEDDING_DIM
|
|
89
109
|
|
|
90
110
|
# if group_ids is empty, set it to None
|
|
91
111
|
group_ids = group_ids if group_ids and group_ids != [''] else None
|
|
@@ -99,7 +119,7 @@ async def search(
|
|
|
99
119
|
driver,
|
|
100
120
|
cross_encoder,
|
|
101
121
|
query,
|
|
102
|
-
|
|
122
|
+
search_vector,
|
|
103
123
|
group_ids,
|
|
104
124
|
config.edge_config,
|
|
105
125
|
search_filter,
|
|
@@ -112,7 +132,7 @@ async def search(
|
|
|
112
132
|
driver,
|
|
113
133
|
cross_encoder,
|
|
114
134
|
query,
|
|
115
|
-
|
|
135
|
+
search_vector,
|
|
116
136
|
group_ids,
|
|
117
137
|
config.node_config,
|
|
118
138
|
search_filter,
|
|
@@ -125,7 +145,7 @@ async def search(
|
|
|
125
145
|
driver,
|
|
126
146
|
cross_encoder,
|
|
127
147
|
query,
|
|
128
|
-
|
|
148
|
+
search_vector,
|
|
129
149
|
group_ids,
|
|
130
150
|
config.episode_config,
|
|
131
151
|
search_filter,
|
|
@@ -136,7 +156,7 @@ async def search(
|
|
|
136
156
|
driver,
|
|
137
157
|
cross_encoder,
|
|
138
158
|
query,
|
|
139
|
-
|
|
159
|
+
search_vector,
|
|
140
160
|
group_ids,
|
|
141
161
|
config.community_config,
|
|
142
162
|
config.limit,
|
|
@@ -305,12 +325,20 @@ async def node_search(
|
|
|
305
325
|
search_tasks = []
|
|
306
326
|
if NodeSearchMethod.bm25 in config.search_methods:
|
|
307
327
|
search_tasks.append(
|
|
308
|
-
node_fulltext_search(
|
|
328
|
+
node_fulltext_search(
|
|
329
|
+
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
|
330
|
+
)
|
|
309
331
|
)
|
|
310
332
|
if NodeSearchMethod.cosine_similarity in config.search_methods:
|
|
311
333
|
search_tasks.append(
|
|
312
334
|
node_similarity_search(
|
|
313
|
-
driver,
|
|
335
|
+
driver,
|
|
336
|
+
query_vector,
|
|
337
|
+
search_filter,
|
|
338
|
+
group_ids,
|
|
339
|
+
2 * limit,
|
|
340
|
+
config.sim_min_score,
|
|
341
|
+
config.use_local_indexes,
|
|
314
342
|
)
|
|
315
343
|
)
|
|
316
344
|
if NodeSearchMethod.bfs in config.search_methods:
|
|
@@ -406,7 +434,9 @@ async def episode_search(
|
|
|
406
434
|
search_results: list[list[EpisodicNode]] = list(
|
|
407
435
|
await semaphore_gather(
|
|
408
436
|
*[
|
|
409
|
-
episode_fulltext_search(
|
|
437
|
+
episode_fulltext_search(
|
|
438
|
+
driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
|
|
439
|
+
),
|
|
410
440
|
]
|
|
411
441
|
)
|
|
412
442
|
)
|
|
@@ -24,6 +24,7 @@ from graphiti_core.search.search_utils import (
|
|
|
24
24
|
DEFAULT_MIN_SCORE,
|
|
25
25
|
DEFAULT_MMR_LAMBDA,
|
|
26
26
|
MAX_SEARCH_DEPTH,
|
|
27
|
+
USE_HNSW,
|
|
27
28
|
)
|
|
28
29
|
|
|
29
30
|
DEFAULT_SEARCH_LIMIT = 10
|
|
@@ -91,6 +92,7 @@ class NodeSearchConfig(BaseModel):
|
|
|
91
92
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
92
93
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
93
94
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
95
|
+
use_local_indexes: bool = Field(default=USE_HNSW)
|
|
94
96
|
|
|
95
97
|
|
|
96
98
|
class EpisodeSearchConfig(BaseModel):
|
|
@@ -99,6 +101,7 @@ class EpisodeSearchConfig(BaseModel):
|
|
|
99
101
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
100
102
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
101
103
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
104
|
+
use_local_indexes: bool = Field(default=USE_HNSW)
|
|
102
105
|
|
|
103
106
|
|
|
104
107
|
class CommunitySearchConfig(BaseModel):
|
|
@@ -107,6 +110,7 @@ class CommunitySearchConfig(BaseModel):
|
|
|
107
110
|
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
108
111
|
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
109
112
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
113
|
+
use_local_indexes: bool = Field(default=USE_HNSW)
|
|
110
114
|
|
|
111
115
|
|
|
112
116
|
class SearchConfig(BaseModel):
|
|
@@ -20,6 +20,8 @@ from typing import Any
|
|
|
20
20
|
|
|
21
21
|
from pydantic import BaseModel, Field
|
|
22
22
|
|
|
23
|
+
from graphiti_core.driver.driver import GraphProvider
|
|
24
|
+
|
|
23
25
|
|
|
24
26
|
class ComparisonOperator(Enum):
|
|
25
27
|
equals = '='
|
|
@@ -54,16 +56,21 @@ class SearchFilters(BaseModel):
|
|
|
54
56
|
|
|
55
57
|
def node_search_filter_query_constructor(
|
|
56
58
|
filters: SearchFilters,
|
|
57
|
-
|
|
58
|
-
|
|
59
|
+
provider: GraphProvider,
|
|
60
|
+
) -> tuple[list[str], dict[str, Any]]:
|
|
61
|
+
filter_queries: list[str] = []
|
|
59
62
|
filter_params: dict[str, Any] = {}
|
|
60
63
|
|
|
61
64
|
if filters.node_labels is not None:
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
+
if provider == GraphProvider.KUZU:
|
|
66
|
+
node_label_filter = 'list_has_all(n.labels, $labels)'
|
|
67
|
+
filter_params['labels'] = filters.node_labels
|
|
68
|
+
else:
|
|
69
|
+
node_labels = '|'.join(filters.node_labels)
|
|
70
|
+
node_label_filter = 'n:' + node_labels
|
|
71
|
+
filter_queries.append(node_label_filter)
|
|
65
72
|
|
|
66
|
-
return
|
|
73
|
+
return filter_queries, filter_params
|
|
67
74
|
|
|
68
75
|
|
|
69
76
|
def date_filter_query_constructor(
|
|
@@ -81,23 +88,29 @@ def date_filter_query_constructor(
|
|
|
81
88
|
|
|
82
89
|
def edge_search_filter_query_constructor(
|
|
83
90
|
filters: SearchFilters,
|
|
84
|
-
|
|
85
|
-
|
|
91
|
+
provider: GraphProvider,
|
|
92
|
+
) -> tuple[list[str], dict[str, Any]]:
|
|
93
|
+
filter_queries: list[str] = []
|
|
86
94
|
filter_params: dict[str, Any] = {}
|
|
87
95
|
|
|
88
96
|
if filters.edge_types is not None:
|
|
89
97
|
edge_types = filters.edge_types
|
|
90
|
-
|
|
91
|
-
filter_query += edge_types_filter
|
|
98
|
+
filter_queries.append('e.name in $edge_types')
|
|
92
99
|
filter_params['edge_types'] = edge_types
|
|
93
100
|
|
|
94
101
|
if filters.node_labels is not None:
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
102
|
+
if provider == GraphProvider.KUZU:
|
|
103
|
+
node_label_filter = (
|
|
104
|
+
'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
|
|
105
|
+
)
|
|
106
|
+
filter_params['labels'] = filters.node_labels
|
|
107
|
+
else:
|
|
108
|
+
node_labels = '|'.join(filters.node_labels)
|
|
109
|
+
node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
|
|
110
|
+
filter_queries.append(node_label_filter)
|
|
98
111
|
|
|
99
112
|
if filters.valid_at is not None:
|
|
100
|
-
valid_at_filter = '
|
|
113
|
+
valid_at_filter = '('
|
|
101
114
|
for i, or_list in enumerate(filters.valid_at):
|
|
102
115
|
for j, date_filter in enumerate(or_list):
|
|
103
116
|
if date_filter.comparison_operator not in [
|
|
@@ -125,10 +138,10 @@ def edge_search_filter_query_constructor(
|
|
|
125
138
|
else:
|
|
126
139
|
valid_at_filter += ' OR '
|
|
127
140
|
|
|
128
|
-
|
|
141
|
+
filter_queries.append(valid_at_filter)
|
|
129
142
|
|
|
130
143
|
if filters.invalid_at is not None:
|
|
131
|
-
invalid_at_filter = '
|
|
144
|
+
invalid_at_filter = '('
|
|
132
145
|
for i, or_list in enumerate(filters.invalid_at):
|
|
133
146
|
for j, date_filter in enumerate(or_list):
|
|
134
147
|
if date_filter.comparison_operator not in [
|
|
@@ -156,10 +169,10 @@ def edge_search_filter_query_constructor(
|
|
|
156
169
|
else:
|
|
157
170
|
invalid_at_filter += ' OR '
|
|
158
171
|
|
|
159
|
-
|
|
172
|
+
filter_queries.append(invalid_at_filter)
|
|
160
173
|
|
|
161
174
|
if filters.created_at is not None:
|
|
162
|
-
created_at_filter = '
|
|
175
|
+
created_at_filter = '('
|
|
163
176
|
for i, or_list in enumerate(filters.created_at):
|
|
164
177
|
for j, date_filter in enumerate(or_list):
|
|
165
178
|
if date_filter.comparison_operator not in [
|
|
@@ -187,10 +200,10 @@ def edge_search_filter_query_constructor(
|
|
|
187
200
|
else:
|
|
188
201
|
created_at_filter += ' OR '
|
|
189
202
|
|
|
190
|
-
|
|
203
|
+
filter_queries.append(created_at_filter)
|
|
191
204
|
|
|
192
205
|
if filters.expired_at is not None:
|
|
193
|
-
expired_at_filter = '
|
|
206
|
+
expired_at_filter = '('
|
|
194
207
|
for i, or_list in enumerate(filters.expired_at):
|
|
195
208
|
for j, date_filter in enumerate(or_list):
|
|
196
209
|
if date_filter.comparison_operator not in [
|
|
@@ -218,6 +231,6 @@ def edge_search_filter_query_constructor(
|
|
|
218
231
|
else:
|
|
219
232
|
expired_at_filter += ' OR '
|
|
220
233
|
|
|
221
|
-
|
|
234
|
+
filter_queries.append(expired_at_filter)
|
|
222
235
|
|
|
223
|
-
return
|
|
236
|
+
return filter_queries, filter_params
|