graphiti-core 0.17.11__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +20 -2
- graphiti_core/driver/falkordb_driver.py +16 -9
- graphiti_core/driver/neo4j_driver.py +8 -6
- graphiti_core/edges.py +73 -99
- graphiti_core/graph_queries.py +51 -97
- graphiti_core/graphiti.py +24 -9
- graphiti_core/helpers.py +3 -2
- graphiti_core/models/edges/edge_db_queries.py +106 -32
- graphiti_core/models/nodes/node_db_queries.py +101 -20
- graphiti_core/nodes.py +113 -128
- graphiti_core/prompts/dedupe_nodes.py +1 -1
- graphiti_core/prompts/extract_edges.py +4 -4
- graphiti_core/prompts/extract_nodes.py +12 -10
- graphiti_core/search/search.py +44 -32
- graphiti_core/search/search_config.py +8 -4
- graphiti_core/search/search_filters.py +5 -5
- graphiti_core/search/search_utils.py +154 -189
- graphiti_core/utils/bulk_utils.py +3 -5
- graphiti_core/utils/maintenance/community_operations.py +11 -7
- graphiti_core/utils/maintenance/edge_operations.py +19 -50
- graphiti_core/utils/maintenance/graph_data_operations.py +14 -29
- graphiti_core/utils/maintenance/node_operations.py +11 -55
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/METADATA +11 -3
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/RECORD +26 -26
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/driver.py
CHANGED
|
@@ -14,14 +14,21 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import copy
|
|
17
18
|
import logging
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from collections.abc import Coroutine
|
|
21
|
+
from enum import Enum
|
|
20
22
|
from typing import Any
|
|
21
23
|
|
|
22
24
|
logger = logging.getLogger(__name__)
|
|
23
25
|
|
|
24
26
|
|
|
27
|
+
class GraphProvider(Enum):
|
|
28
|
+
NEO4J = 'neo4j'
|
|
29
|
+
FALKORDB = 'falkordb'
|
|
30
|
+
|
|
31
|
+
|
|
25
32
|
class GraphDriverSession(ABC):
|
|
26
33
|
async def __aenter__(self):
|
|
27
34
|
return self
|
|
@@ -45,10 +52,11 @@ class GraphDriverSession(ABC):
|
|
|
45
52
|
|
|
46
53
|
|
|
47
54
|
class GraphDriver(ABC):
|
|
48
|
-
provider:
|
|
55
|
+
provider: GraphProvider
|
|
49
56
|
fulltext_syntax: str = (
|
|
50
57
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
|
51
58
|
)
|
|
59
|
+
_database: str
|
|
52
60
|
|
|
53
61
|
@abstractmethod
|
|
54
62
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
|
@@ -63,5 +71,15 @@ class GraphDriver(ABC):
|
|
|
63
71
|
raise NotImplementedError()
|
|
64
72
|
|
|
65
73
|
@abstractmethod
|
|
66
|
-
def delete_all_indexes(self
|
|
74
|
+
def delete_all_indexes(self) -> Coroutine:
|
|
67
75
|
raise NotImplementedError()
|
|
76
|
+
|
|
77
|
+
def with_database(self, database: str) -> 'GraphDriver':
|
|
78
|
+
"""
|
|
79
|
+
Returns a shallow copy of this driver with a different default database.
|
|
80
|
+
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
|
81
|
+
"""
|
|
82
|
+
cloned = copy.copy(self)
|
|
83
|
+
cloned._database = database
|
|
84
|
+
|
|
85
|
+
return cloned
|
|
@@ -32,7 +32,7 @@ else:
|
|
|
32
32
|
'Install it with: pip install graphiti-core[falkordb]'
|
|
33
33
|
) from None
|
|
34
34
|
|
|
35
|
-
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
|
35
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
36
36
|
|
|
37
37
|
logger = logging.getLogger(__name__)
|
|
38
38
|
|
|
@@ -71,7 +71,7 @@ class FalkorDriverSession(GraphDriverSession):
|
|
|
71
71
|
|
|
72
72
|
|
|
73
73
|
class FalkorDriver(GraphDriver):
|
|
74
|
-
provider
|
|
74
|
+
provider = GraphProvider.FALKORDB
|
|
75
75
|
|
|
76
76
|
def __init__(
|
|
77
77
|
self,
|
|
@@ -90,12 +90,13 @@ class FalkorDriver(GraphDriver):
|
|
|
90
90
|
The default parameters assume a local (on-premises) FalkorDB instance.
|
|
91
91
|
"""
|
|
92
92
|
super().__init__()
|
|
93
|
+
|
|
94
|
+
self._database = database
|
|
93
95
|
if falkor_db is not None:
|
|
94
96
|
# If a FalkorDB instance is provided, use it directly
|
|
95
97
|
self.client = falkor_db
|
|
96
98
|
else:
|
|
97
99
|
self.client = FalkorDB(host=host, port=port, username=username, password=password)
|
|
98
|
-
self._database = database
|
|
99
100
|
|
|
100
101
|
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
|
|
101
102
|
|
|
@@ -106,8 +107,7 @@ class FalkorDriver(GraphDriver):
|
|
|
106
107
|
return self.client.select_graph(graph_name)
|
|
107
108
|
|
|
108
109
|
async def execute_query(self, cypher_query_, **kwargs: Any):
|
|
109
|
-
|
|
110
|
-
graph = self._get_graph(graph_name)
|
|
110
|
+
graph = self._get_graph(self._database)
|
|
111
111
|
|
|
112
112
|
# Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
|
|
113
113
|
params = convert_datetimes_to_strings(dict(kwargs))
|
|
@@ -119,7 +119,7 @@ class FalkorDriver(GraphDriver):
|
|
|
119
119
|
# check if index already exists
|
|
120
120
|
logger.info(f'Index already exists: {e}')
|
|
121
121
|
return None
|
|
122
|
-
logger.error(f'Error executing FalkorDB query: {e}')
|
|
122
|
+
logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
|
|
123
123
|
raise
|
|
124
124
|
|
|
125
125
|
# Convert the result header to a list of strings
|
|
@@ -151,13 +151,20 @@ class FalkorDriver(GraphDriver):
|
|
|
151
151
|
elif hasattr(self.client.connection, 'close'):
|
|
152
152
|
await self.client.connection.close()
|
|
153
153
|
|
|
154
|
-
async def delete_all_indexes(self
|
|
155
|
-
database = database_ or self._database
|
|
154
|
+
async def delete_all_indexes(self) -> None:
|
|
156
155
|
await self.execute_query(
|
|
157
156
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
158
|
-
database_=database,
|
|
159
157
|
)
|
|
160
158
|
|
|
159
|
+
def clone(self, database: str) -> 'GraphDriver':
|
|
160
|
+
"""
|
|
161
|
+
Returns a shallow copy of this driver with a different default database.
|
|
162
|
+
Reuses the same connection (e.g. FalkorDB, Neo4j).
|
|
163
|
+
"""
|
|
164
|
+
cloned = FalkorDriver(falkor_db=self.client, database=database)
|
|
165
|
+
|
|
166
|
+
return cloned
|
|
167
|
+
|
|
161
168
|
|
|
162
169
|
def convert_datetimes_to_strings(obj):
|
|
163
170
|
if isinstance(obj, dict):
|
|
@@ -21,13 +21,13 @@ from typing import Any
|
|
|
21
21
|
from neo4j import AsyncGraphDatabase, EagerResult
|
|
22
22
|
from typing_extensions import LiteralString
|
|
23
23
|
|
|
24
|
-
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
|
24
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
25
25
|
|
|
26
26
|
logger = logging.getLogger(__name__)
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class Neo4jDriver(GraphDriver):
|
|
30
|
-
provider
|
|
30
|
+
provider = GraphProvider.NEO4J
|
|
31
31
|
|
|
32
32
|
def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
|
|
33
33
|
super().__init__()
|
|
@@ -45,7 +45,11 @@ class Neo4jDriver(GraphDriver):
|
|
|
45
45
|
params = {}
|
|
46
46
|
params.setdefault('database_', self._database)
|
|
47
47
|
|
|
48
|
-
|
|
48
|
+
try:
|
|
49
|
+
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
|
|
50
|
+
except Exception as e:
|
|
51
|
+
logger.error(f'Error executing Neo4j query: {e}\n{cypher_query_}\n{params}')
|
|
52
|
+
raise
|
|
49
53
|
|
|
50
54
|
return result
|
|
51
55
|
|
|
@@ -56,9 +60,7 @@ class Neo4jDriver(GraphDriver):
|
|
|
56
60
|
async def close(self) -> None:
|
|
57
61
|
return await self.client.close()
|
|
58
62
|
|
|
59
|
-
def delete_all_indexes(self
|
|
60
|
-
database = database_ or self._database
|
|
63
|
+
def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
|
|
61
64
|
return self.client.execute_query(
|
|
62
65
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
63
|
-
database_=database,
|
|
64
66
|
)
|
graphiti_core/edges.py
CHANGED
|
@@ -29,29 +29,17 @@ from graphiti_core.embedder import EmbedderClient
|
|
|
29
29
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
30
30
|
from graphiti_core.helpers import parse_db_date
|
|
31
31
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
32
|
-
|
|
33
|
-
|
|
32
|
+
COMMUNITY_EDGE_RETURN,
|
|
33
|
+
ENTITY_EDGE_RETURN,
|
|
34
|
+
EPISODIC_EDGE_RETURN,
|
|
34
35
|
EPISODIC_EDGE_SAVE,
|
|
36
|
+
get_community_edge_save_query,
|
|
37
|
+
get_entity_edge_save_query,
|
|
35
38
|
)
|
|
36
39
|
from graphiti_core.nodes import Node
|
|
37
40
|
|
|
38
41
|
logger = logging.getLogger(__name__)
|
|
39
42
|
|
|
40
|
-
ENTITY_EDGE_RETURN: LiteralString = """
|
|
41
|
-
RETURN
|
|
42
|
-
e.uuid AS uuid,
|
|
43
|
-
startNode(e).uuid AS source_node_uuid,
|
|
44
|
-
endNode(e).uuid AS target_node_uuid,
|
|
45
|
-
e.created_at AS created_at,
|
|
46
|
-
e.name AS name,
|
|
47
|
-
e.group_id AS group_id,
|
|
48
|
-
e.fact AS fact,
|
|
49
|
-
e.episodes AS episodes,
|
|
50
|
-
e.expired_at AS expired_at,
|
|
51
|
-
e.valid_at AS valid_at,
|
|
52
|
-
e.invalid_at AS invalid_at,
|
|
53
|
-
properties(e) AS attributes"""
|
|
54
|
-
|
|
55
43
|
|
|
56
44
|
class Edge(BaseModel, ABC):
|
|
57
45
|
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
|
@@ -66,9 +54,9 @@ class Edge(BaseModel, ABC):
|
|
|
66
54
|
async def delete(self, driver: GraphDriver):
|
|
67
55
|
result = await driver.execute_query(
|
|
68
56
|
"""
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
57
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
58
|
+
DELETE e
|
|
59
|
+
""",
|
|
72
60
|
uuid=self.uuid,
|
|
73
61
|
)
|
|
74
62
|
|
|
@@ -107,14 +95,10 @@ class EpisodicEdge(Edge):
|
|
|
107
95
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
108
96
|
records, _, _ = await driver.execute_query(
|
|
109
97
|
"""
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
n.uuid AS source_node_uuid,
|
|
115
|
-
m.uuid AS target_node_uuid,
|
|
116
|
-
e.created_at AS created_at
|
|
117
|
-
""",
|
|
98
|
+
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
99
|
+
RETURN
|
|
100
|
+
"""
|
|
101
|
+
+ EPISODIC_EDGE_RETURN,
|
|
118
102
|
uuid=uuid,
|
|
119
103
|
routing_='r',
|
|
120
104
|
)
|
|
@@ -129,15 +113,11 @@ class EpisodicEdge(Edge):
|
|
|
129
113
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
130
114
|
records, _, _ = await driver.execute_query(
|
|
131
115
|
"""
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
n.uuid AS source_node_uuid,
|
|
138
|
-
m.uuid AS target_node_uuid,
|
|
139
|
-
e.created_at AS created_at
|
|
140
|
-
""",
|
|
116
|
+
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
|
117
|
+
WHERE e.uuid IN $uuids
|
|
118
|
+
RETURN
|
|
119
|
+
"""
|
|
120
|
+
+ EPISODIC_EDGE_RETURN,
|
|
141
121
|
uuids=uuids,
|
|
142
122
|
routing_='r',
|
|
143
123
|
)
|
|
@@ -161,19 +141,17 @@ class EpisodicEdge(Edge):
|
|
|
161
141
|
|
|
162
142
|
records, _, _ = await driver.execute_query(
|
|
163
143
|
"""
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
144
|
+
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
|
145
|
+
WHERE e.group_id IN $group_ids
|
|
146
|
+
"""
|
|
167
147
|
+ cursor_query
|
|
168
148
|
+ """
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
ORDER BY e.uuid DESC
|
|
176
|
-
"""
|
|
149
|
+
RETURN
|
|
150
|
+
"""
|
|
151
|
+
+ EPISODIC_EDGE_RETURN
|
|
152
|
+
+ """
|
|
153
|
+
ORDER BY e.uuid DESC
|
|
154
|
+
"""
|
|
177
155
|
+ limit_query,
|
|
178
156
|
group_ids=group_ids,
|
|
179
157
|
uuid=uuid_cursor,
|
|
@@ -221,11 +199,14 @@ class EntityEdge(Edge):
|
|
|
221
199
|
return self.fact_embedding
|
|
222
200
|
|
|
223
201
|
async def load_fact_embedding(self, driver: GraphDriver):
|
|
224
|
-
|
|
202
|
+
records, _, _ = await driver.execute_query(
|
|
203
|
+
"""
|
|
225
204
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
226
205
|
RETURN e.fact_embedding AS fact_embedding
|
|
227
|
-
|
|
228
|
-
|
|
206
|
+
""",
|
|
207
|
+
uuid=self.uuid,
|
|
208
|
+
routing_='r',
|
|
209
|
+
)
|
|
229
210
|
|
|
230
211
|
if len(records) == 0:
|
|
231
212
|
raise EdgeNotFoundError(self.uuid)
|
|
@@ -251,7 +232,7 @@ class EntityEdge(Edge):
|
|
|
251
232
|
edge_data.update(self.attributes or {})
|
|
252
233
|
|
|
253
234
|
result = await driver.execute_query(
|
|
254
|
-
|
|
235
|
+
get_entity_edge_save_query(driver.provider),
|
|
255
236
|
edge_data=edge_data,
|
|
256
237
|
)
|
|
257
238
|
|
|
@@ -263,8 +244,9 @@ class EntityEdge(Edge):
|
|
|
263
244
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
264
245
|
records, _, _ = await driver.execute_query(
|
|
265
246
|
"""
|
|
266
|
-
|
|
267
|
-
|
|
247
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
248
|
+
RETURN
|
|
249
|
+
"""
|
|
268
250
|
+ ENTITY_EDGE_RETURN,
|
|
269
251
|
uuid=uuid,
|
|
270
252
|
routing_='r',
|
|
@@ -283,9 +265,10 @@ class EntityEdge(Edge):
|
|
|
283
265
|
|
|
284
266
|
records, _, _ = await driver.execute_query(
|
|
285
267
|
"""
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
268
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
269
|
+
WHERE e.uuid IN $uuids
|
|
270
|
+
RETURN
|
|
271
|
+
"""
|
|
289
272
|
+ ENTITY_EDGE_RETURN,
|
|
290
273
|
uuids=uuids,
|
|
291
274
|
routing_='r',
|
|
@@ -314,22 +297,21 @@ class EntityEdge(Edge):
|
|
|
314
297
|
else ''
|
|
315
298
|
)
|
|
316
299
|
|
|
317
|
-
|
|
300
|
+
records, _, _ = await driver.execute_query(
|
|
318
301
|
"""
|
|
319
302
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
320
303
|
WHERE e.group_id IN $group_ids
|
|
321
304
|
"""
|
|
322
305
|
+ cursor_query
|
|
306
|
+
+ """
|
|
307
|
+
RETURN
|
|
308
|
+
"""
|
|
323
309
|
+ ENTITY_EDGE_RETURN
|
|
324
310
|
+ with_embeddings_query
|
|
325
311
|
+ """
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
+ limit_query
|
|
329
|
-
)
|
|
330
|
-
|
|
331
|
-
records, _, _ = await driver.execute_query(
|
|
332
|
-
query,
|
|
312
|
+
ORDER BY e.uuid DESC
|
|
313
|
+
"""
|
|
314
|
+
+ limit_query,
|
|
333
315
|
group_ids=group_ids,
|
|
334
316
|
uuid=uuid_cursor,
|
|
335
317
|
limit=limit,
|
|
@@ -344,13 +326,15 @@ class EntityEdge(Edge):
|
|
|
344
326
|
|
|
345
327
|
@classmethod
|
|
346
328
|
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
|
347
|
-
|
|
329
|
+
records, _, _ = await driver.execute_query(
|
|
348
330
|
"""
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
331
|
+
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
|
332
|
+
RETURN
|
|
333
|
+
"""
|
|
334
|
+
+ ENTITY_EDGE_RETURN,
|
|
335
|
+
node_uuid=node_uuid,
|
|
336
|
+
routing_='r',
|
|
352
337
|
)
|
|
353
|
-
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
|
|
354
338
|
|
|
355
339
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
356
340
|
|
|
@@ -360,7 +344,7 @@ class EntityEdge(Edge):
|
|
|
360
344
|
class CommunityEdge(Edge):
|
|
361
345
|
async def save(self, driver: GraphDriver):
|
|
362
346
|
result = await driver.execute_query(
|
|
363
|
-
|
|
347
|
+
get_community_edge_save_query(driver.provider),
|
|
364
348
|
community_uuid=self.source_node_uuid,
|
|
365
349
|
entity_uuid=self.target_node_uuid,
|
|
366
350
|
uuid=self.uuid,
|
|
@@ -376,14 +360,10 @@ class CommunityEdge(Edge):
|
|
|
376
360
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
377
361
|
records, _, _ = await driver.execute_query(
|
|
378
362
|
"""
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
n.uuid AS source_node_uuid,
|
|
384
|
-
m.uuid AS target_node_uuid,
|
|
385
|
-
e.created_at AS created_at
|
|
386
|
-
""",
|
|
363
|
+
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
|
|
364
|
+
RETURN
|
|
365
|
+
"""
|
|
366
|
+
+ COMMUNITY_EDGE_RETURN,
|
|
387
367
|
uuid=uuid,
|
|
388
368
|
routing_='r',
|
|
389
369
|
)
|
|
@@ -396,15 +376,11 @@ class CommunityEdge(Edge):
|
|
|
396
376
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
397
377
|
records, _, _ = await driver.execute_query(
|
|
398
378
|
"""
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
n.uuid AS source_node_uuid,
|
|
405
|
-
m.uuid AS target_node_uuid,
|
|
406
|
-
e.created_at AS created_at
|
|
407
|
-
""",
|
|
379
|
+
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
|
380
|
+
WHERE e.uuid IN $uuids
|
|
381
|
+
RETURN
|
|
382
|
+
"""
|
|
383
|
+
+ COMMUNITY_EDGE_RETURN,
|
|
408
384
|
uuids=uuids,
|
|
409
385
|
routing_='r',
|
|
410
386
|
)
|
|
@@ -426,19 +402,17 @@ class CommunityEdge(Edge):
|
|
|
426
402
|
|
|
427
403
|
records, _, _ = await driver.execute_query(
|
|
428
404
|
"""
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
405
|
+
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
|
406
|
+
WHERE e.group_id IN $group_ids
|
|
407
|
+
"""
|
|
432
408
|
+ cursor_query
|
|
433
409
|
+ """
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
ORDER BY e.uuid DESC
|
|
441
|
-
"""
|
|
410
|
+
RETURN
|
|
411
|
+
"""
|
|
412
|
+
+ COMMUNITY_EDGE_RETURN
|
|
413
|
+
+ """
|
|
414
|
+
ORDER BY e.uuid DESC
|
|
415
|
+
"""
|
|
442
416
|
+ limit_query,
|
|
443
417
|
group_ids=group_ids,
|
|
444
418
|
uuid=uuid_cursor,
|
graphiti_core/graph_queries.py
CHANGED
|
@@ -5,16 +5,9 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB,
|
|
|
5
5
|
supporting index creation, fulltext search, and bulk operations.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from typing import Any
|
|
9
|
-
|
|
10
8
|
from typing_extensions import LiteralString
|
|
11
9
|
|
|
12
|
-
from graphiti_core.
|
|
13
|
-
ENTITY_EDGE_SAVE_BULK,
|
|
14
|
-
)
|
|
15
|
-
from graphiti_core.models.nodes.node_db_queries import (
|
|
16
|
-
ENTITY_NODE_SAVE_BULK,
|
|
17
|
-
)
|
|
10
|
+
from graphiti_core.driver.driver import GraphProvider
|
|
18
11
|
|
|
19
12
|
# Mapping from Neo4j fulltext index names to FalkorDB node labels
|
|
20
13
|
NEO4J_TO_FALKORDB_MAPPING = {
|
|
@@ -25,8 +18,8 @@ NEO4J_TO_FALKORDB_MAPPING = {
|
|
|
25
18
|
}
|
|
26
19
|
|
|
27
20
|
|
|
28
|
-
def get_range_indices(
|
|
29
|
-
if
|
|
21
|
+
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
22
|
+
if provider == GraphProvider.FALKORDB:
|
|
30
23
|
return [
|
|
31
24
|
# Entity node
|
|
32
25
|
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
|
|
@@ -41,109 +34,70 @@ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
|
|
|
41
34
|
# HAS_MEMBER edge
|
|
42
35
|
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
43
36
|
]
|
|
44
|
-
else:
|
|
45
|
-
return [
|
|
46
|
-
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
|
47
|
-
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
|
48
|
-
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
|
49
|
-
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
|
50
|
-
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
|
51
|
-
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
52
|
-
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
|
53
|
-
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
|
54
|
-
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
|
55
|
-
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
|
56
|
-
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
|
57
|
-
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
|
58
|
-
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
|
59
|
-
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
|
60
|
-
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
|
61
|
-
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
|
62
|
-
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
|
63
|
-
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
|
64
|
-
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
|
65
|
-
]
|
|
66
37
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
38
|
+
return [
|
|
39
|
+
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
|
40
|
+
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
|
41
|
+
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
|
|
42
|
+
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
|
|
43
|
+
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
|
|
44
|
+
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
45
|
+
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
|
|
46
|
+
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
|
|
47
|
+
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
|
|
48
|
+
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
|
|
49
|
+
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
|
|
50
|
+
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
|
|
51
|
+
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
|
|
52
|
+
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
|
|
53
|
+
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
|
|
54
|
+
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
|
|
55
|
+
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
|
|
56
|
+
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
|
|
57
|
+
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
|
|
58
|
+
]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
62
|
+
if provider == GraphProvider.FALKORDB:
|
|
70
63
|
return [
|
|
71
64
|
"""CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
|
|
72
65
|
"""CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
|
|
73
66
|
"""CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
|
|
74
67
|
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
|
75
68
|
]
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
69
|
+
|
|
70
|
+
return [
|
|
71
|
+
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
|
72
|
+
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
|
73
|
+
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
|
74
|
+
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
|
75
|
+
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
|
76
|
+
FOR (n:Community) ON EACH [n.name, n.group_id]""",
|
|
77
|
+
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
|
|
78
|
+
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
|
|
79
|
+
]
|
|
87
80
|
|
|
88
81
|
|
|
89
|
-
def get_nodes_query(
|
|
90
|
-
if
|
|
82
|
+
def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
|
|
83
|
+
if provider == GraphProvider.FALKORDB:
|
|
91
84
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
92
85
|
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
|
93
|
-
|
|
94
|
-
|
|
86
|
+
|
|
87
|
+
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
|
95
88
|
|
|
96
89
|
|
|
97
|
-
def get_vector_cosine_func_query(vec1, vec2,
|
|
98
|
-
if
|
|
90
|
+
def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
|
91
|
+
if provider == GraphProvider.FALKORDB:
|
|
99
92
|
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
|
|
100
93
|
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
|
101
|
-
else:
|
|
102
|
-
return f'vector.similarity.cosine({vec1}, {vec2})'
|
|
103
94
|
|
|
95
|
+
return f'vector.similarity.cosine({vec1}, {vec2})'
|
|
104
96
|
|
|
105
|
-
|
|
106
|
-
|
|
97
|
+
|
|
98
|
+
def get_relationships_query(name: str, provider: GraphProvider) -> str:
|
|
99
|
+
if provider == GraphProvider.FALKORDB:
|
|
107
100
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
108
101
|
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
|
|
114
|
-
if db_type == 'falkordb':
|
|
115
|
-
queries = []
|
|
116
|
-
for node in nodes:
|
|
117
|
-
for label in node['labels']:
|
|
118
|
-
queries.append(
|
|
119
|
-
(
|
|
120
|
-
f"""
|
|
121
|
-
UNWIND $nodes AS node
|
|
122
|
-
MERGE (n:Entity {{uuid: node.uuid}})
|
|
123
|
-
SET n:{label}
|
|
124
|
-
SET n = node
|
|
125
|
-
WITH n, node
|
|
126
|
-
SET n.name_embedding = vecf32(node.name_embedding)
|
|
127
|
-
RETURN n.uuid AS uuid
|
|
128
|
-
""",
|
|
129
|
-
{'nodes': [node]},
|
|
130
|
-
)
|
|
131
|
-
)
|
|
132
|
-
return queries
|
|
133
|
-
else:
|
|
134
|
-
return ENTITY_NODE_SAVE_BULK
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
|
|
138
|
-
if db_type == 'falkordb':
|
|
139
|
-
return """
|
|
140
|
-
UNWIND $entity_edges AS edge
|
|
141
|
-
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
|
142
|
-
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
|
143
|
-
MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
|
|
144
|
-
SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
|
|
145
|
-
created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
|
|
146
|
-
WITH r, edge
|
|
147
|
-
RETURN edge.uuid AS uuid"""
|
|
148
|
-
else:
|
|
149
|
-
return ENTITY_EDGE_SAVE_BULK
|
|
102
|
+
|
|
103
|
+
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|