graphiti-core 0.19.0rc2__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +3 -0
- graphiti_core/driver/falkordb_driver.py +3 -14
- graphiti_core/driver/kuzu_driver.py +175 -0
- graphiti_core/driver/neptune_driver.py +2 -0
- graphiti_core/edges.py +148 -83
- graphiti_core/graph_queries.py +31 -2
- graphiti_core/graphiti.py +4 -1
- graphiti_core/helpers.py +7 -12
- graphiti_core/migrations/neo4j_node_group_labels.py +64 -3
- graphiti_core/models/edges/edge_db_queries.py +121 -42
- graphiti_core/models/nodes/node_db_queries.py +102 -23
- graphiti_core/nodes.py +169 -66
- graphiti_core/search/search.py +13 -3
- graphiti_core/search/search_config.py +4 -0
- graphiti_core/search/search_filters.py +35 -22
- graphiti_core/search/search_utils.py +693 -382
- graphiti_core/utils/bulk_utils.py +50 -18
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +39 -32
- graphiti_core/utils/maintenance/edge_operations.py +19 -8
- graphiti_core/utils/maintenance/graph_data_operations.py +79 -48
- {graphiti_core-0.19.0rc2.dist-info → graphiti_core-0.20.0.dist-info}/METADATA +123 -53
- {graphiti_core-0.19.0rc2.dist-info → graphiti_core-0.20.0.dist-info}/RECORD +25 -24
- {graphiti_core-0.19.0rc2.dist-info → graphiti_core-0.20.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.19.0rc2.dist-info → graphiti_core-0.20.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/driver.py
CHANGED
|
@@ -27,10 +27,13 @@ logger = logging.getLogger(__name__)
|
|
|
27
27
|
class GraphProvider(Enum):
|
|
28
28
|
NEO4J = 'neo4j'
|
|
29
29
|
FALKORDB = 'falkordb'
|
|
30
|
+
KUZU = 'kuzu'
|
|
30
31
|
NEPTUNE = 'neptune'
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
class GraphDriverSession(ABC):
|
|
35
|
+
provider: GraphProvider
|
|
36
|
+
|
|
34
37
|
async def __aenter__(self):
|
|
35
38
|
return self
|
|
36
39
|
|
|
@@ -15,7 +15,6 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import logging
|
|
18
|
-
from datetime import datetime
|
|
19
18
|
from typing import TYPE_CHECKING, Any
|
|
20
19
|
|
|
21
20
|
if TYPE_CHECKING:
|
|
@@ -33,11 +32,14 @@ else:
|
|
|
33
32
|
) from None
|
|
34
33
|
|
|
35
34
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
35
|
+
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
|
36
36
|
|
|
37
37
|
logger = logging.getLogger(__name__)
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
class FalkorDriverSession(GraphDriverSession):
|
|
41
|
+
provider = GraphProvider.FALKORDB
|
|
42
|
+
|
|
41
43
|
def __init__(self, graph: FalkorGraph):
|
|
42
44
|
self.graph = graph
|
|
43
45
|
|
|
@@ -164,16 +166,3 @@ class FalkorDriver(GraphDriver):
|
|
|
164
166
|
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
|
165
167
|
|
|
166
168
|
return cloned
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
def convert_datetimes_to_strings(obj):
|
|
170
|
-
if isinstance(obj, dict):
|
|
171
|
-
return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
|
|
172
|
-
elif isinstance(obj, list):
|
|
173
|
-
return [convert_datetimes_to_strings(item) for item in obj]
|
|
174
|
-
elif isinstance(obj, tuple):
|
|
175
|
-
return tuple(convert_datetimes_to_strings(item) for item in obj)
|
|
176
|
-
elif isinstance(obj, datetime):
|
|
177
|
-
return obj.isoformat()
|
|
178
|
-
else:
|
|
179
|
-
return obj
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
import kuzu
|
|
21
|
+
|
|
22
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
# Kuzu requires an explicit schema.
|
|
27
|
+
# As Kuzu currently does not support creating full text indexes on edge properties,
|
|
28
|
+
# we work around this by representing (n:Entity)-[:RELATES_TO]->(m:Entity) as
|
|
29
|
+
# (n)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m).
|
|
30
|
+
SCHEMA_QUERIES = """
|
|
31
|
+
CREATE NODE TABLE IF NOT EXISTS Episodic (
|
|
32
|
+
uuid STRING PRIMARY KEY,
|
|
33
|
+
name STRING,
|
|
34
|
+
group_id STRING,
|
|
35
|
+
created_at TIMESTAMP,
|
|
36
|
+
source STRING,
|
|
37
|
+
source_description STRING,
|
|
38
|
+
content STRING,
|
|
39
|
+
valid_at TIMESTAMP,
|
|
40
|
+
entity_edges STRING[]
|
|
41
|
+
);
|
|
42
|
+
CREATE NODE TABLE IF NOT EXISTS Entity (
|
|
43
|
+
uuid STRING PRIMARY KEY,
|
|
44
|
+
name STRING,
|
|
45
|
+
group_id STRING,
|
|
46
|
+
labels STRING[],
|
|
47
|
+
created_at TIMESTAMP,
|
|
48
|
+
name_embedding FLOAT[],
|
|
49
|
+
summary STRING,
|
|
50
|
+
attributes STRING
|
|
51
|
+
);
|
|
52
|
+
CREATE NODE TABLE IF NOT EXISTS Community (
|
|
53
|
+
uuid STRING PRIMARY KEY,
|
|
54
|
+
name STRING,
|
|
55
|
+
group_id STRING,
|
|
56
|
+
created_at TIMESTAMP,
|
|
57
|
+
name_embedding FLOAT[],
|
|
58
|
+
summary STRING
|
|
59
|
+
);
|
|
60
|
+
CREATE NODE TABLE IF NOT EXISTS RelatesToNode_ (
|
|
61
|
+
uuid STRING PRIMARY KEY,
|
|
62
|
+
group_id STRING,
|
|
63
|
+
created_at TIMESTAMP,
|
|
64
|
+
name STRING,
|
|
65
|
+
fact STRING,
|
|
66
|
+
fact_embedding FLOAT[],
|
|
67
|
+
episodes STRING[],
|
|
68
|
+
expired_at TIMESTAMP,
|
|
69
|
+
valid_at TIMESTAMP,
|
|
70
|
+
invalid_at TIMESTAMP,
|
|
71
|
+
attributes STRING
|
|
72
|
+
);
|
|
73
|
+
CREATE REL TABLE IF NOT EXISTS RELATES_TO(
|
|
74
|
+
FROM Entity TO RelatesToNode_,
|
|
75
|
+
FROM RelatesToNode_ TO Entity
|
|
76
|
+
);
|
|
77
|
+
CREATE REL TABLE IF NOT EXISTS MENTIONS(
|
|
78
|
+
FROM Episodic TO Entity,
|
|
79
|
+
uuid STRING PRIMARY KEY,
|
|
80
|
+
group_id STRING,
|
|
81
|
+
created_at TIMESTAMP
|
|
82
|
+
);
|
|
83
|
+
CREATE REL TABLE IF NOT EXISTS HAS_MEMBER(
|
|
84
|
+
FROM Community TO Entity,
|
|
85
|
+
FROM Community TO Community,
|
|
86
|
+
uuid STRING,
|
|
87
|
+
group_id STRING,
|
|
88
|
+
created_at TIMESTAMP
|
|
89
|
+
);
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class KuzuDriver(GraphDriver):
|
|
94
|
+
provider: GraphProvider = GraphProvider.KUZU
|
|
95
|
+
|
|
96
|
+
def __init__(
|
|
97
|
+
self,
|
|
98
|
+
db: str = ':memory:',
|
|
99
|
+
max_concurrent_queries: int = 1,
|
|
100
|
+
):
|
|
101
|
+
super().__init__()
|
|
102
|
+
self.db = kuzu.Database(db)
|
|
103
|
+
|
|
104
|
+
self.setup_schema()
|
|
105
|
+
|
|
106
|
+
self.client = kuzu.AsyncConnection(self.db, max_concurrent_queries=max_concurrent_queries)
|
|
107
|
+
|
|
108
|
+
async def execute_query(
|
|
109
|
+
self, cypher_query_: str, **kwargs: Any
|
|
110
|
+
) -> tuple[list[dict[str, Any]] | list[list[dict[str, Any]]], None, None]:
|
|
111
|
+
params = {k: v for k, v in kwargs.items() if v is not None}
|
|
112
|
+
# Kuzu does not support these parameters.
|
|
113
|
+
params.pop('database_', None)
|
|
114
|
+
params.pop('routing_', None)
|
|
115
|
+
|
|
116
|
+
try:
|
|
117
|
+
results = await self.client.execute(cypher_query_, parameters=params)
|
|
118
|
+
except Exception as e:
|
|
119
|
+
params = {k: (v[:5] if isinstance(v, list) else v) for k, v in params.items()}
|
|
120
|
+
logger.error(f'Error executing Kuzu query: {e}\n{cypher_query_}\n{params}')
|
|
121
|
+
raise
|
|
122
|
+
|
|
123
|
+
if not results:
|
|
124
|
+
return [], None, None
|
|
125
|
+
|
|
126
|
+
if isinstance(results, list):
|
|
127
|
+
dict_results = [list(result.rows_as_dict()) for result in results]
|
|
128
|
+
else:
|
|
129
|
+
dict_results = list(results.rows_as_dict())
|
|
130
|
+
return dict_results, None, None # type: ignore
|
|
131
|
+
|
|
132
|
+
def session(self, _database: str | None = None) -> GraphDriverSession:
|
|
133
|
+
return KuzuDriverSession(self)
|
|
134
|
+
|
|
135
|
+
async def close(self):
|
|
136
|
+
# Do not explicitly close the connection, instead rely on GC.
|
|
137
|
+
pass
|
|
138
|
+
|
|
139
|
+
def delete_all_indexes(self, database_: str):
|
|
140
|
+
pass
|
|
141
|
+
|
|
142
|
+
def setup_schema(self):
|
|
143
|
+
conn = kuzu.Connection(self.db)
|
|
144
|
+
conn.execute(SCHEMA_QUERIES)
|
|
145
|
+
conn.close()
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class KuzuDriverSession(GraphDriverSession):
|
|
149
|
+
provider = GraphProvider.KUZU
|
|
150
|
+
|
|
151
|
+
def __init__(self, driver: KuzuDriver):
|
|
152
|
+
self.driver = driver
|
|
153
|
+
|
|
154
|
+
async def __aenter__(self):
|
|
155
|
+
return self
|
|
156
|
+
|
|
157
|
+
async def __aexit__(self, exc_type, exc, tb):
|
|
158
|
+
# No cleanup needed for Kuzu, but method must exist.
|
|
159
|
+
pass
|
|
160
|
+
|
|
161
|
+
async def close(self):
|
|
162
|
+
# Do not close the session here, as we're reusing the driver connection.
|
|
163
|
+
pass
|
|
164
|
+
|
|
165
|
+
async def execute_write(self, func, *args, **kwargs):
|
|
166
|
+
# Directly await the provided async function with `self` as the transaction/session
|
|
167
|
+
return await func(self, *args, **kwargs)
|
|
168
|
+
|
|
169
|
+
async def run(self, query: str | list, **kwargs: Any) -> Any:
|
|
170
|
+
if isinstance(query, list):
|
|
171
|
+
for cypher, params in query:
|
|
172
|
+
await self.driver.execute_query(cypher, **params)
|
|
173
|
+
else:
|
|
174
|
+
await self.driver.execute_query(query, **kwargs)
|
|
175
|
+
return None
|
|
@@ -271,6 +271,8 @@ class NeptuneDriver(GraphDriver):
|
|
|
271
271
|
|
|
272
272
|
|
|
273
273
|
class NeptuneDriverSession(GraphDriverSession):
|
|
274
|
+
provider = GraphProvider.NEPTUNE
|
|
275
|
+
|
|
274
276
|
def __init__(self, driver: NeptuneDriver): # type: ignore[reportUnknownArgumentType]
|
|
275
277
|
self.driver = driver
|
|
276
278
|
|
graphiti_core/edges.py
CHANGED
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from datetime import datetime
|
|
@@ -30,11 +31,10 @@ from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
|
30
31
|
from graphiti_core.helpers import parse_db_date
|
|
31
32
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
32
33
|
COMMUNITY_EDGE_RETURN,
|
|
33
|
-
ENTITY_EDGE_RETURN,
|
|
34
|
-
ENTITY_EDGE_RETURN_NEPTUNE,
|
|
35
34
|
EPISODIC_EDGE_RETURN,
|
|
36
35
|
EPISODIC_EDGE_SAVE,
|
|
37
36
|
get_community_edge_save_query,
|
|
37
|
+
get_entity_edge_return_query,
|
|
38
38
|
get_entity_edge_save_query,
|
|
39
39
|
)
|
|
40
40
|
from graphiti_core.nodes import Node
|
|
@@ -53,33 +53,63 @@ class Edge(BaseModel, ABC):
|
|
|
53
53
|
async def save(self, driver: GraphDriver): ...
|
|
54
54
|
|
|
55
55
|
async def delete(self, driver: GraphDriver):
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
56
|
+
if driver.provider == GraphProvider.KUZU:
|
|
57
|
+
await driver.execute_query(
|
|
58
|
+
"""
|
|
59
|
+
MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
60
|
+
DELETE e
|
|
61
|
+
""",
|
|
62
|
+
uuid=self.uuid,
|
|
63
|
+
)
|
|
64
|
+
await driver.execute_query(
|
|
65
|
+
"""
|
|
66
|
+
MATCH (e:RelatesToNode_ {uuid: $uuid})
|
|
67
|
+
DETACH DELETE e
|
|
68
|
+
""",
|
|
69
|
+
uuid=self.uuid,
|
|
70
|
+
)
|
|
71
|
+
else:
|
|
72
|
+
await driver.execute_query(
|
|
73
|
+
"""
|
|
74
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
75
|
+
DELETE e
|
|
76
|
+
""",
|
|
77
|
+
uuid=self.uuid,
|
|
78
|
+
)
|
|
63
79
|
|
|
64
80
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
|
65
81
|
|
|
66
|
-
return result
|
|
67
|
-
|
|
68
82
|
@classmethod
|
|
69
83
|
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
84
|
+
if driver.provider == GraphProvider.KUZU:
|
|
85
|
+
await driver.execute_query(
|
|
86
|
+
"""
|
|
87
|
+
MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
|
|
88
|
+
WHERE e.uuid IN $uuids
|
|
89
|
+
DELETE e
|
|
90
|
+
""",
|
|
91
|
+
uuids=uuids,
|
|
92
|
+
)
|
|
93
|
+
await driver.execute_query(
|
|
94
|
+
"""
|
|
95
|
+
MATCH (e:RelatesToNode_)
|
|
96
|
+
WHERE e.uuid IN $uuids
|
|
97
|
+
DETACH DELETE e
|
|
98
|
+
""",
|
|
99
|
+
uuids=uuids,
|
|
100
|
+
)
|
|
101
|
+
else:
|
|
102
|
+
await driver.execute_query(
|
|
103
|
+
"""
|
|
104
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
|
|
105
|
+
WHERE e.uuid IN $uuids
|
|
106
|
+
DELETE e
|
|
107
|
+
""",
|
|
108
|
+
uuids=uuids,
|
|
109
|
+
)
|
|
78
110
|
|
|
79
111
|
logger.debug(f'Deleted Edges: {uuids}')
|
|
80
112
|
|
|
81
|
-
return result
|
|
82
|
-
|
|
83
113
|
def __hash__(self):
|
|
84
114
|
return hash(self.uuid)
|
|
85
115
|
|
|
@@ -166,7 +196,7 @@ class EpisodicEdge(Edge):
|
|
|
166
196
|
"""
|
|
167
197
|
+ EPISODIC_EDGE_RETURN
|
|
168
198
|
+ """
|
|
169
|
-
ORDER BY e.uuid DESC
|
|
199
|
+
ORDER BY e.uuid DESC
|
|
170
200
|
"""
|
|
171
201
|
+ limit_query,
|
|
172
202
|
group_ids=group_ids,
|
|
@@ -215,15 +245,21 @@ class EntityEdge(Edge):
|
|
|
215
245
|
return self.fact_embedding
|
|
216
246
|
|
|
217
247
|
async def load_fact_embedding(self, driver: GraphDriver):
|
|
218
|
-
|
|
219
|
-
query: LiteralString = """
|
|
248
|
+
query = """
|
|
220
249
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
250
|
+
RETURN e.fact_embedding AS fact_embedding
|
|
251
|
+
"""
|
|
252
|
+
|
|
253
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
254
|
+
query = """
|
|
255
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
221
256
|
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
|
222
257
|
"""
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
258
|
+
|
|
259
|
+
if driver.provider == GraphProvider.KUZU:
|
|
260
|
+
query = """
|
|
261
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
|
|
262
|
+
RETURN e.fact_embedding AS fact_embedding
|
|
227
263
|
"""
|
|
228
264
|
|
|
229
265
|
records, _, _ = await driver.execute_query(
|
|
@@ -253,15 +289,22 @@ class EntityEdge(Edge):
|
|
|
253
289
|
'invalid_at': self.invalid_at,
|
|
254
290
|
}
|
|
255
291
|
|
|
256
|
-
|
|
292
|
+
if driver.provider == GraphProvider.KUZU:
|
|
293
|
+
edge_data['attributes'] = json.dumps(self.attributes)
|
|
294
|
+
result = await driver.execute_query(
|
|
295
|
+
get_entity_edge_save_query(driver.provider),
|
|
296
|
+
**edge_data,
|
|
297
|
+
)
|
|
298
|
+
else:
|
|
299
|
+
edge_data.update(self.attributes or {})
|
|
257
300
|
|
|
258
|
-
|
|
259
|
-
|
|
301
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
302
|
+
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
|
260
303
|
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
304
|
+
result = await driver.execute_query(
|
|
305
|
+
get_entity_edge_save_query(driver.provider),
|
|
306
|
+
edge_data=edge_data,
|
|
307
|
+
)
|
|
265
308
|
|
|
266
309
|
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
|
267
310
|
|
|
@@ -269,21 +312,25 @@ class EntityEdge(Edge):
|
|
|
269
312
|
|
|
270
313
|
@classmethod
|
|
271
314
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
272
|
-
|
|
273
|
-
"""
|
|
315
|
+
match_query = """
|
|
274
316
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
317
|
+
"""
|
|
318
|
+
if driver.provider == GraphProvider.KUZU:
|
|
319
|
+
match_query = """
|
|
320
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
|
|
321
|
+
"""
|
|
322
|
+
|
|
323
|
+
records, _, _ = await driver.execute_query(
|
|
324
|
+
match_query
|
|
325
|
+
+ """
|
|
275
326
|
RETURN
|
|
276
327
|
"""
|
|
277
|
-
+ (
|
|
278
|
-
ENTITY_EDGE_RETURN_NEPTUNE
|
|
279
|
-
if driver.provider == GraphProvider.NEPTUNE
|
|
280
|
-
else ENTITY_EDGE_RETURN
|
|
281
|
-
),
|
|
328
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
282
329
|
uuid=uuid,
|
|
283
330
|
routing_='r',
|
|
284
331
|
)
|
|
285
332
|
|
|
286
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
333
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
287
334
|
|
|
288
335
|
if len(edges) == 0:
|
|
289
336
|
raise EdgeNotFoundError(uuid)
|
|
@@ -294,22 +341,26 @@ class EntityEdge(Edge):
|
|
|
294
341
|
if len(uuids) == 0:
|
|
295
342
|
return []
|
|
296
343
|
|
|
297
|
-
|
|
298
|
-
"""
|
|
344
|
+
match_query = """
|
|
299
345
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
346
|
+
"""
|
|
347
|
+
if driver.provider == GraphProvider.KUZU:
|
|
348
|
+
match_query = """
|
|
349
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
350
|
+
"""
|
|
351
|
+
|
|
352
|
+
records, _, _ = await driver.execute_query(
|
|
353
|
+
match_query
|
|
354
|
+
+ """
|
|
300
355
|
WHERE e.uuid IN $uuids
|
|
301
356
|
RETURN
|
|
302
357
|
"""
|
|
303
|
-
+ (
|
|
304
|
-
ENTITY_EDGE_RETURN_NEPTUNE
|
|
305
|
-
if driver.provider == GraphProvider.NEPTUNE
|
|
306
|
-
else ENTITY_EDGE_RETURN
|
|
307
|
-
),
|
|
358
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
308
359
|
uuids=uuids,
|
|
309
360
|
routing_='r',
|
|
310
361
|
)
|
|
311
362
|
|
|
312
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
363
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
313
364
|
|
|
314
365
|
return edges
|
|
315
366
|
|
|
@@ -332,23 +383,27 @@ class EntityEdge(Edge):
|
|
|
332
383
|
else ''
|
|
333
384
|
)
|
|
334
385
|
|
|
335
|
-
|
|
336
|
-
"""
|
|
386
|
+
match_query = """
|
|
337
387
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
388
|
+
"""
|
|
389
|
+
if driver.provider == GraphProvider.KUZU:
|
|
390
|
+
match_query = """
|
|
391
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
392
|
+
"""
|
|
393
|
+
|
|
394
|
+
records, _, _ = await driver.execute_query(
|
|
395
|
+
match_query
|
|
396
|
+
+ """
|
|
338
397
|
WHERE e.group_id IN $group_ids
|
|
339
398
|
"""
|
|
340
399
|
+ cursor_query
|
|
341
400
|
+ """
|
|
342
401
|
RETURN
|
|
343
402
|
"""
|
|
344
|
-
+ (
|
|
345
|
-
ENTITY_EDGE_RETURN_NEPTUNE
|
|
346
|
-
if driver.provider == GraphProvider.NEPTUNE
|
|
347
|
-
else ENTITY_EDGE_RETURN
|
|
348
|
-
)
|
|
403
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
349
404
|
+ with_embeddings_query
|
|
350
405
|
+ """
|
|
351
|
-
ORDER BY e.uuid DESC
|
|
406
|
+
ORDER BY e.uuid DESC
|
|
352
407
|
"""
|
|
353
408
|
+ limit_query,
|
|
354
409
|
group_ids=group_ids,
|
|
@@ -357,7 +412,7 @@ class EntityEdge(Edge):
|
|
|
357
412
|
routing_='r',
|
|
358
413
|
)
|
|
359
414
|
|
|
360
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
415
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
361
416
|
|
|
362
417
|
if len(edges) == 0:
|
|
363
418
|
raise GroupsEdgesNotFoundError(group_ids)
|
|
@@ -365,21 +420,25 @@ class EntityEdge(Edge):
|
|
|
365
420
|
|
|
366
421
|
@classmethod
|
|
367
422
|
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
|
368
|
-
|
|
369
|
-
"""
|
|
423
|
+
match_query = """
|
|
370
424
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
|
425
|
+
"""
|
|
426
|
+
if driver.provider == GraphProvider.KUZU:
|
|
427
|
+
match_query = """
|
|
428
|
+
MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
429
|
+
"""
|
|
430
|
+
|
|
431
|
+
records, _, _ = await driver.execute_query(
|
|
432
|
+
match_query
|
|
433
|
+
+ """
|
|
371
434
|
RETURN
|
|
372
435
|
"""
|
|
373
|
-
+ (
|
|
374
|
-
ENTITY_EDGE_RETURN_NEPTUNE
|
|
375
|
-
if driver.provider == GraphProvider.NEPTUNE
|
|
376
|
-
else ENTITY_EDGE_RETURN
|
|
377
|
-
),
|
|
436
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
378
437
|
node_uuid=node_uuid,
|
|
379
438
|
routing_='r',
|
|
380
439
|
)
|
|
381
440
|
|
|
382
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
441
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
383
442
|
|
|
384
443
|
return edges
|
|
385
444
|
|
|
@@ -479,7 +538,25 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
|
479
538
|
)
|
|
480
539
|
|
|
481
540
|
|
|
482
|
-
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
541
|
+
def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
|
|
542
|
+
episodes = record['episodes']
|
|
543
|
+
if provider == GraphProvider.KUZU:
|
|
544
|
+
attributes = json.loads(record['attributes']) if record['attributes'] else {}
|
|
545
|
+
else:
|
|
546
|
+
attributes = record['attributes']
|
|
547
|
+
attributes.pop('uuid', None)
|
|
548
|
+
attributes.pop('source_node_uuid', None)
|
|
549
|
+
attributes.pop('target_node_uuid', None)
|
|
550
|
+
attributes.pop('fact', None)
|
|
551
|
+
attributes.pop('fact_embedding', None)
|
|
552
|
+
attributes.pop('name', None)
|
|
553
|
+
attributes.pop('group_id', None)
|
|
554
|
+
attributes.pop('episodes', None)
|
|
555
|
+
attributes.pop('created_at', None)
|
|
556
|
+
attributes.pop('expired_at', None)
|
|
557
|
+
attributes.pop('valid_at', None)
|
|
558
|
+
attributes.pop('invalid_at', None)
|
|
559
|
+
|
|
483
560
|
edge = EntityEdge(
|
|
484
561
|
uuid=record['uuid'],
|
|
485
562
|
source_node_uuid=record['source_node_uuid'],
|
|
@@ -488,26 +565,14 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
|
488
565
|
fact_embedding=record.get('fact_embedding'),
|
|
489
566
|
name=record['name'],
|
|
490
567
|
group_id=record['group_id'],
|
|
491
|
-
episodes=
|
|
568
|
+
episodes=episodes,
|
|
492
569
|
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
493
570
|
expired_at=parse_db_date(record['expired_at']),
|
|
494
571
|
valid_at=parse_db_date(record['valid_at']),
|
|
495
572
|
invalid_at=parse_db_date(record['invalid_at']),
|
|
496
|
-
attributes=
|
|
573
|
+
attributes=attributes,
|
|
497
574
|
)
|
|
498
575
|
|
|
499
|
-
edge.attributes.pop('uuid', None)
|
|
500
|
-
edge.attributes.pop('source_node_uuid', None)
|
|
501
|
-
edge.attributes.pop('target_node_uuid', None)
|
|
502
|
-
edge.attributes.pop('fact', None)
|
|
503
|
-
edge.attributes.pop('name', None)
|
|
504
|
-
edge.attributes.pop('group_id', None)
|
|
505
|
-
edge.attributes.pop('episodes', None)
|
|
506
|
-
edge.attributes.pop('created_at', None)
|
|
507
|
-
edge.attributes.pop('expired_at', None)
|
|
508
|
-
edge.attributes.pop('valid_at', None)
|
|
509
|
-
edge.attributes.pop('invalid_at', None)
|
|
510
|
-
|
|
511
576
|
return edge
|
|
512
577
|
|
|
513
578
|
|
graphiti_core/graph_queries.py
CHANGED
|
@@ -16,6 +16,13 @@ NEO4J_TO_FALKORDB_MAPPING = {
|
|
|
16
16
|
'episode_content': 'Episodic',
|
|
17
17
|
'edge_name_and_fact': 'RELATES_TO',
|
|
18
18
|
}
|
|
19
|
+
# Mapping from fulltext index names to Kuzu node labels
|
|
20
|
+
INDEX_TO_LABEL_KUZU_MAPPING = {
|
|
21
|
+
'node_name_and_summary': 'Entity',
|
|
22
|
+
'community_name': 'Community',
|
|
23
|
+
'episode_content': 'Episodic',
|
|
24
|
+
'edge_name_and_fact': 'RelatesToNode_',
|
|
25
|
+
}
|
|
19
26
|
|
|
20
27
|
|
|
21
28
|
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
@@ -35,6 +42,9 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
|
35
42
|
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
36
43
|
]
|
|
37
44
|
|
|
45
|
+
if provider == GraphProvider.KUZU:
|
|
46
|
+
return []
|
|
47
|
+
|
|
38
48
|
return [
|
|
39
49
|
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
|
40
50
|
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
|
@@ -68,6 +78,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
|
68
78
|
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
|
69
79
|
]
|
|
70
80
|
|
|
81
|
+
if provider == GraphProvider.KUZU:
|
|
82
|
+
return [
|
|
83
|
+
"CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
|
|
84
|
+
"CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
|
|
85
|
+
"CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
|
|
86
|
+
"CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
|
|
87
|
+
]
|
|
88
|
+
|
|
71
89
|
return [
|
|
72
90
|
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
|
73
91
|
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
|
@@ -80,11 +98,15 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
|
80
98
|
]
|
|
81
99
|
|
|
82
100
|
|
|
83
|
-
def get_nodes_query(
|
|
101
|
+
def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
|
|
84
102
|
if provider == GraphProvider.FALKORDB:
|
|
85
103
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
86
104
|
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
|
87
105
|
|
|
106
|
+
if provider == GraphProvider.KUZU:
|
|
107
|
+
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
|
108
|
+
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
|
|
109
|
+
|
|
88
110
|
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
|
89
111
|
|
|
90
112
|
|
|
@@ -93,12 +115,19 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
|
|
93
115
|
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
|
|
94
116
|
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
|
95
117
|
|
|
118
|
+
if provider == GraphProvider.KUZU:
|
|
119
|
+
return f'array_cosine_similarity({vec1}, {vec2})'
|
|
120
|
+
|
|
96
121
|
return f'vector.similarity.cosine({vec1}, {vec2})'
|
|
97
122
|
|
|
98
123
|
|
|
99
|
-
def get_relationships_query(name: str, provider: GraphProvider) -> str:
|
|
124
|
+
def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
|
|
100
125
|
if provider == GraphProvider.FALKORDB:
|
|
101
126
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
102
127
|
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
|
103
128
|
|
|
129
|
+
if provider == GraphProvider.KUZU:
|
|
130
|
+
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
|
131
|
+
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
|
|
132
|
+
|
|
104
133
|
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|