graphiti-core 0.18.0__py3-none-any.whl → 0.18.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -14,14 +14,21 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import copy
17
18
  import logging
18
19
  from abc import ABC, abstractmethod
19
20
  from collections.abc import Coroutine
21
+ from enum import Enum
20
22
  from typing import Any
21
23
 
22
24
  logger = logging.getLogger(__name__)
23
25
 
24
26
 
27
+ class GraphProvider(Enum):
28
+ NEO4J = 'neo4j'
29
+ FALKORDB = 'falkordb'
30
+
31
+
25
32
  class GraphDriverSession(ABC):
26
33
  async def __aenter__(self):
27
34
  return self
@@ -45,10 +52,11 @@ class GraphDriverSession(ABC):
45
52
 
46
53
 
47
54
  class GraphDriver(ABC):
48
- provider: str
55
+ provider: GraphProvider
49
56
  fulltext_syntax: str = (
50
57
  '' # Neo4j (default) syntax does not require a prefix for fulltext queries
51
58
  )
59
+ _database: str
52
60
 
53
61
  @abstractmethod
54
62
  def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@@ -63,5 +71,15 @@ class GraphDriver(ABC):
63
71
  raise NotImplementedError()
64
72
 
65
73
  @abstractmethod
66
- def delete_all_indexes(self, database_: str | None = None) -> Coroutine:
74
+ def delete_all_indexes(self) -> Coroutine:
67
75
  raise NotImplementedError()
76
+
77
+ def with_database(self, database: str) -> 'GraphDriver':
78
+ """
79
+ Returns a shallow copy of this driver with a different default database.
80
+ Reuses the same connection (e.g. FalkorDB, Neo4j).
81
+ """
82
+ cloned = copy.copy(self)
83
+ cloned._database = database
84
+
85
+ return cloned
@@ -32,7 +32,7 @@ else:
32
32
  'Install it with: pip install graphiti-core[falkordb]'
33
33
  ) from None
34
34
 
35
- from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
35
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
36
36
 
37
37
  logger = logging.getLogger(__name__)
38
38
 
@@ -71,7 +71,7 @@ class FalkorDriverSession(GraphDriverSession):
71
71
 
72
72
 
73
73
  class FalkorDriver(GraphDriver):
74
- provider: str = 'falkordb'
74
+ provider = GraphProvider.FALKORDB
75
75
 
76
76
  def __init__(
77
77
  self,
@@ -90,12 +90,13 @@ class FalkorDriver(GraphDriver):
90
90
  The default parameters assume a local (on-premises) FalkorDB instance.
91
91
  """
92
92
  super().__init__()
93
+
94
+ self._database = database
93
95
  if falkor_db is not None:
94
96
  # If a FalkorDB instance is provided, use it directly
95
97
  self.client = falkor_db
96
98
  else:
97
99
  self.client = FalkorDB(host=host, port=port, username=username, password=password)
98
- self._database = database
99
100
 
100
101
  self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
101
102
 
@@ -106,8 +107,7 @@ class FalkorDriver(GraphDriver):
106
107
  return self.client.select_graph(graph_name)
107
108
 
108
109
  async def execute_query(self, cypher_query_, **kwargs: Any):
109
- graph_name = kwargs.pop('database_', self._database)
110
- graph = self._get_graph(graph_name)
110
+ graph = self._get_graph(self._database)
111
111
 
112
112
  # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
113
113
  params = convert_datetimes_to_strings(dict(kwargs))
@@ -119,7 +119,7 @@ class FalkorDriver(GraphDriver):
119
119
  # check if index already exists
120
120
  logger.info(f'Index already exists: {e}')
121
121
  return None
122
- logger.error(f'Error executing FalkorDB query: {e}')
122
+ logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
123
123
  raise
124
124
 
125
125
  # Convert the result header to a list of strings
@@ -151,13 +151,20 @@ class FalkorDriver(GraphDriver):
151
151
  elif hasattr(self.client.connection, 'close'):
152
152
  await self.client.connection.close()
153
153
 
154
- async def delete_all_indexes(self, database_: str | None = None) -> None:
155
- database = database_ or self._database
154
+ async def delete_all_indexes(self) -> None:
156
155
  await self.execute_query(
157
156
  'CALL db.indexes() YIELD name DROP INDEX name',
158
- database_=database,
159
157
  )
160
158
 
159
+ def clone(self, database: str) -> 'GraphDriver':
160
+ """
161
+ Returns a shallow copy of this driver with a different default database.
162
+ Reuses the same connection (e.g. FalkorDB, Neo4j).
163
+ """
164
+ cloned = FalkorDriver(falkor_db=self.client, database=database)
165
+
166
+ return cloned
167
+
161
168
 
162
169
  def convert_datetimes_to_strings(obj):
163
170
  if isinstance(obj, dict):
@@ -21,13 +21,13 @@ from typing import Any
21
21
  from neo4j import AsyncGraphDatabase, EagerResult
22
22
  from typing_extensions import LiteralString
23
23
 
24
- from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
24
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
25
25
 
26
26
  logger = logging.getLogger(__name__)
27
27
 
28
28
 
29
29
  class Neo4jDriver(GraphDriver):
30
- provider: str = 'neo4j'
30
+ provider = GraphProvider.NEO4J
31
31
 
32
32
  def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
33
33
  super().__init__()
@@ -45,7 +45,11 @@ class Neo4jDriver(GraphDriver):
45
45
  params = {}
46
46
  params.setdefault('database_', self._database)
47
47
 
48
- result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
48
+ try:
49
+ result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
50
+ except Exception as e:
51
+ logger.error(f'Error executing Neo4j query: {e}\n{cypher_query_}\n{params}')
52
+ raise
49
53
 
50
54
  return result
51
55
 
@@ -56,9 +60,7 @@ class Neo4jDriver(GraphDriver):
56
60
  async def close(self) -> None:
57
61
  return await self.client.close()
58
62
 
59
- def delete_all_indexes(self, database_: str | None = None) -> Coroutine[Any, Any, EagerResult]:
60
- database = database_ or self._database
63
+ def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
61
64
  return self.client.execute_query(
62
65
  'CALL db.indexes() YIELD name DROP INDEX name',
63
- database_=database,
64
66
  )
graphiti_core/edges.py CHANGED
@@ -29,29 +29,17 @@ from graphiti_core.embedder import EmbedderClient
29
29
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
30
  from graphiti_core.helpers import parse_db_date
31
31
  from graphiti_core.models.edges.edge_db_queries import (
32
- COMMUNITY_EDGE_SAVE,
33
- ENTITY_EDGE_SAVE,
32
+ COMMUNITY_EDGE_RETURN,
33
+ ENTITY_EDGE_RETURN,
34
+ EPISODIC_EDGE_RETURN,
34
35
  EPISODIC_EDGE_SAVE,
36
+ get_community_edge_save_query,
37
+ get_entity_edge_save_query,
35
38
  )
36
39
  from graphiti_core.nodes import Node
37
40
 
38
41
  logger = logging.getLogger(__name__)
39
42
 
40
- ENTITY_EDGE_RETURN: LiteralString = """
41
- RETURN
42
- e.uuid AS uuid,
43
- startNode(e).uuid AS source_node_uuid,
44
- endNode(e).uuid AS target_node_uuid,
45
- e.created_at AS created_at,
46
- e.name AS name,
47
- e.group_id AS group_id,
48
- e.fact AS fact,
49
- e.episodes AS episodes,
50
- e.expired_at AS expired_at,
51
- e.valid_at AS valid_at,
52
- e.invalid_at AS invalid_at,
53
- properties(e) AS attributes"""
54
-
55
43
 
56
44
  class Edge(BaseModel, ABC):
57
45
  uuid: str = Field(default_factory=lambda: str(uuid4()))
@@ -66,9 +54,9 @@ class Edge(BaseModel, ABC):
66
54
  async def delete(self, driver: GraphDriver):
67
55
  result = await driver.execute_query(
68
56
  """
69
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
70
- DELETE e
71
- """,
57
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
58
+ DELETE e
59
+ """,
72
60
  uuid=self.uuid,
73
61
  )
74
62
 
@@ -107,14 +95,10 @@ class EpisodicEdge(Edge):
107
95
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
108
96
  records, _, _ = await driver.execute_query(
109
97
  """
110
- MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
111
- RETURN
112
- e.uuid As uuid,
113
- e.group_id AS group_id,
114
- n.uuid AS source_node_uuid,
115
- m.uuid AS target_node_uuid,
116
- e.created_at AS created_at
117
- """,
98
+ MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
99
+ RETURN
100
+ """
101
+ + EPISODIC_EDGE_RETURN,
118
102
  uuid=uuid,
119
103
  routing_='r',
120
104
  )
@@ -129,15 +113,11 @@ class EpisodicEdge(Edge):
129
113
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
130
114
  records, _, _ = await driver.execute_query(
131
115
  """
132
- MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
133
- WHERE e.uuid IN $uuids
134
- RETURN
135
- e.uuid As uuid,
136
- e.group_id AS group_id,
137
- n.uuid AS source_node_uuid,
138
- m.uuid AS target_node_uuid,
139
- e.created_at AS created_at
140
- """,
116
+ MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
117
+ WHERE e.uuid IN $uuids
118
+ RETURN
119
+ """
120
+ + EPISODIC_EDGE_RETURN,
141
121
  uuids=uuids,
142
122
  routing_='r',
143
123
  )
@@ -161,19 +141,17 @@ class EpisodicEdge(Edge):
161
141
 
162
142
  records, _, _ = await driver.execute_query(
163
143
  """
164
- MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
165
- WHERE e.group_id IN $group_ids
166
- """
144
+ MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
145
+ WHERE e.group_id IN $group_ids
146
+ """
167
147
  + cursor_query
168
148
  + """
169
- RETURN
170
- e.uuid As uuid,
171
- e.group_id AS group_id,
172
- n.uuid AS source_node_uuid,
173
- m.uuid AS target_node_uuid,
174
- e.created_at AS created_at
175
- ORDER BY e.uuid DESC
176
- """
149
+ RETURN
150
+ """
151
+ + EPISODIC_EDGE_RETURN
152
+ + """
153
+ ORDER BY e.uuid DESC
154
+ """
177
155
  + limit_query,
178
156
  group_ids=group_ids,
179
157
  uuid=uuid_cursor,
@@ -221,11 +199,14 @@ class EntityEdge(Edge):
221
199
  return self.fact_embedding
222
200
 
223
201
  async def load_fact_embedding(self, driver: GraphDriver):
224
- query: LiteralString = """
202
+ records, _, _ = await driver.execute_query(
203
+ """
225
204
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
226
205
  RETURN e.fact_embedding AS fact_embedding
227
- """
228
- records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
206
+ """,
207
+ uuid=self.uuid,
208
+ routing_='r',
209
+ )
229
210
 
230
211
  if len(records) == 0:
231
212
  raise EdgeNotFoundError(self.uuid)
@@ -251,7 +232,7 @@ class EntityEdge(Edge):
251
232
  edge_data.update(self.attributes or {})
252
233
 
253
234
  result = await driver.execute_query(
254
- ENTITY_EDGE_SAVE,
235
+ get_entity_edge_save_query(driver.provider),
255
236
  edge_data=edge_data,
256
237
  )
257
238
 
@@ -263,8 +244,9 @@ class EntityEdge(Edge):
263
244
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
264
245
  records, _, _ = await driver.execute_query(
265
246
  """
266
- MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
267
- """
247
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
248
+ RETURN
249
+ """
268
250
  + ENTITY_EDGE_RETURN,
269
251
  uuid=uuid,
270
252
  routing_='r',
@@ -283,9 +265,10 @@ class EntityEdge(Edge):
283
265
 
284
266
  records, _, _ = await driver.execute_query(
285
267
  """
286
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
287
- WHERE e.uuid IN $uuids
288
- """
268
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
269
+ WHERE e.uuid IN $uuids
270
+ RETURN
271
+ """
289
272
  + ENTITY_EDGE_RETURN,
290
273
  uuids=uuids,
291
274
  routing_='r',
@@ -314,22 +297,21 @@ class EntityEdge(Edge):
314
297
  else ''
315
298
  )
316
299
 
317
- query: LiteralString = (
300
+ records, _, _ = await driver.execute_query(
318
301
  """
319
302
  MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
320
303
  WHERE e.group_id IN $group_ids
321
304
  """
322
305
  + cursor_query
306
+ + """
307
+ RETURN
308
+ """
323
309
  + ENTITY_EDGE_RETURN
324
310
  + with_embeddings_query
325
311
  + """
326
- ORDER BY e.uuid DESC
327
- """
328
- + limit_query
329
- )
330
-
331
- records, _, _ = await driver.execute_query(
332
- query,
312
+ ORDER BY e.uuid DESC
313
+ """
314
+ + limit_query,
333
315
  group_ids=group_ids,
334
316
  uuid=uuid_cursor,
335
317
  limit=limit,
@@ -344,13 +326,15 @@ class EntityEdge(Edge):
344
326
 
345
327
  @classmethod
346
328
  async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
347
- query: LiteralString = (
329
+ records, _, _ = await driver.execute_query(
348
330
  """
349
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
350
- """
351
- + ENTITY_EDGE_RETURN
331
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
332
+ RETURN
333
+ """
334
+ + ENTITY_EDGE_RETURN,
335
+ node_uuid=node_uuid,
336
+ routing_='r',
352
337
  )
353
- records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
354
338
 
355
339
  edges = [get_entity_edge_from_record(record) for record in records]
356
340
 
@@ -360,7 +344,7 @@ class EntityEdge(Edge):
360
344
  class CommunityEdge(Edge):
361
345
  async def save(self, driver: GraphDriver):
362
346
  result = await driver.execute_query(
363
- COMMUNITY_EDGE_SAVE,
347
+ get_community_edge_save_query(driver.provider),
364
348
  community_uuid=self.source_node_uuid,
365
349
  entity_uuid=self.target_node_uuid,
366
350
  uuid=self.uuid,
@@ -376,14 +360,10 @@ class CommunityEdge(Edge):
376
360
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
377
361
  records, _, _ = await driver.execute_query(
378
362
  """
379
- MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
380
- RETURN
381
- e.uuid As uuid,
382
- e.group_id AS group_id,
383
- n.uuid AS source_node_uuid,
384
- m.uuid AS target_node_uuid,
385
- e.created_at AS created_at
386
- """,
363
+ MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
364
+ RETURN
365
+ """
366
+ + COMMUNITY_EDGE_RETURN,
387
367
  uuid=uuid,
388
368
  routing_='r',
389
369
  )
@@ -396,15 +376,11 @@ class CommunityEdge(Edge):
396
376
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
397
377
  records, _, _ = await driver.execute_query(
398
378
  """
399
- MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
400
- WHERE e.uuid IN $uuids
401
- RETURN
402
- e.uuid As uuid,
403
- e.group_id AS group_id,
404
- n.uuid AS source_node_uuid,
405
- m.uuid AS target_node_uuid,
406
- e.created_at AS created_at
407
- """,
379
+ MATCH (n:Community)-[e:HAS_MEMBER]->(m)
380
+ WHERE e.uuid IN $uuids
381
+ RETURN
382
+ """
383
+ + COMMUNITY_EDGE_RETURN,
408
384
  uuids=uuids,
409
385
  routing_='r',
410
386
  )
@@ -426,19 +402,17 @@ class CommunityEdge(Edge):
426
402
 
427
403
  records, _, _ = await driver.execute_query(
428
404
  """
429
- MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
430
- WHERE e.group_id IN $group_ids
431
- """
405
+ MATCH (n:Community)-[e:HAS_MEMBER]->(m)
406
+ WHERE e.group_id IN $group_ids
407
+ """
432
408
  + cursor_query
433
409
  + """
434
- RETURN
435
- e.uuid As uuid,
436
- e.group_id AS group_id,
437
- n.uuid AS source_node_uuid,
438
- m.uuid AS target_node_uuid,
439
- e.created_at AS created_at
440
- ORDER BY e.uuid DESC
441
- """
410
+ RETURN
411
+ """
412
+ + COMMUNITY_EDGE_RETURN
413
+ + """
414
+ ORDER BY e.uuid DESC
415
+ """
442
416
  + limit_query,
443
417
  group_ids=group_ids,
444
418
  uuid=uuid_cursor,
@@ -5,16 +5,9 @@ This module provides database-agnostic query generation for Neo4j and FalkorDB,
5
5
  supporting index creation, fulltext search, and bulk operations.
6
6
  """
7
7
 
8
- from typing import Any
9
-
10
8
  from typing_extensions import LiteralString
11
9
 
12
- from graphiti_core.models.edges.edge_db_queries import (
13
- ENTITY_EDGE_SAVE_BULK,
14
- )
15
- from graphiti_core.models.nodes.node_db_queries import (
16
- ENTITY_NODE_SAVE_BULK,
17
- )
10
+ from graphiti_core.driver.driver import GraphProvider
18
11
 
19
12
  # Mapping from Neo4j fulltext index names to FalkorDB node labels
20
13
  NEO4J_TO_FALKORDB_MAPPING = {
@@ -25,8 +18,8 @@ NEO4J_TO_FALKORDB_MAPPING = {
25
18
  }
26
19
 
27
20
 
28
- def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
29
- if db_type == 'falkordb':
21
+ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
22
+ if provider == GraphProvider.FALKORDB:
30
23
  return [
31
24
  # Entity node
32
25
  'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
@@ -41,109 +34,70 @@ def get_range_indices(db_type: str = 'neo4j') -> list[LiteralString]:
41
34
  # HAS_MEMBER edge
42
35
  'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
43
36
  ]
44
- else:
45
- return [
46
- 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
47
- 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
48
- 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
49
- 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
50
- 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
51
- 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
52
- 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
53
- 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
54
- 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
55
- 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
56
- 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
57
- 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
58
- 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
59
- 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
60
- 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
61
- 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
62
- 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
63
- 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
64
- 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
65
- ]
66
37
 
67
-
68
- def get_fulltext_indices(db_type: str = 'neo4j') -> list[LiteralString]:
69
- if db_type == 'falkordb':
38
+ return [
39
+ 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
40
+ 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
41
+ 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
42
+ 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
43
+ 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
44
+ 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
45
+ 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
46
+ 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
47
+ 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
48
+ 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
49
+ 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
50
+ 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
51
+ 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
52
+ 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
53
+ 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
54
+ 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
55
+ 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
56
+ 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
57
+ 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
58
+ ]
59
+
60
+
61
+ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
62
+ if provider == GraphProvider.FALKORDB:
70
63
  return [
71
64
  """CREATE FULLTEXT INDEX FOR (e:Episodic) ON (e.content, e.source, e.source_description, e.group_id)""",
72
65
  """CREATE FULLTEXT INDEX FOR (n:Entity) ON (n.name, n.summary, n.group_id)""",
73
66
  """CREATE FULLTEXT INDEX FOR (n:Community) ON (n.name, n.group_id)""",
74
67
  """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
75
68
  ]
76
- else:
77
- return [
78
- """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
79
- FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
80
- """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
81
- FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
82
- """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
83
- FOR (n:Community) ON EACH [n.name, n.group_id]""",
84
- """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
85
- FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
86
- ]
69
+
70
+ return [
71
+ """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
72
+ FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
73
+ """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
74
+ FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
75
+ """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
76
+ FOR (n:Community) ON EACH [n.name, n.group_id]""",
77
+ """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
78
+ FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
79
+ ]
87
80
 
88
81
 
89
- def get_nodes_query(db_type: str = 'neo4j', name: str = '', query: str | None = None) -> str:
90
- if db_type == 'falkordb':
82
+ def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
83
+ if provider == GraphProvider.FALKORDB:
91
84
  label = NEO4J_TO_FALKORDB_MAPPING[name]
92
85
  return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
93
- else:
94
- return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
86
+
87
+ return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
95
88
 
96
89
 
97
- def get_vector_cosine_func_query(vec1, vec2, db_type: str = 'neo4j') -> str:
98
- if db_type == 'falkordb':
90
+ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
91
+ if provider == GraphProvider.FALKORDB:
99
92
  # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
100
93
  return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
101
- else:
102
- return f'vector.similarity.cosine({vec1}, {vec2})'
103
94
 
95
+ return f'vector.similarity.cosine({vec1}, {vec2})'
104
96
 
105
- def get_relationships_query(name: str, db_type: str = 'neo4j') -> str:
106
- if db_type == 'falkordb':
97
+
98
+ def get_relationships_query(name: str, provider: GraphProvider) -> str:
99
+ if provider == GraphProvider.FALKORDB:
107
100
  label = NEO4J_TO_FALKORDB_MAPPING[name]
108
101
  return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
109
- else:
110
- return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
111
-
112
-
113
- def get_entity_node_save_bulk_query(nodes, db_type: str = 'neo4j') -> str | Any:
114
- if db_type == 'falkordb':
115
- queries = []
116
- for node in nodes:
117
- for label in node['labels']:
118
- queries.append(
119
- (
120
- f"""
121
- UNWIND $nodes AS node
122
- MERGE (n:Entity {{uuid: node.uuid}})
123
- SET n:{label}
124
- SET n = node
125
- WITH n, node
126
- SET n.name_embedding = vecf32(node.name_embedding)
127
- RETURN n.uuid AS uuid
128
- """,
129
- {'nodes': [node]},
130
- )
131
- )
132
- return queries
133
- else:
134
- return ENTITY_NODE_SAVE_BULK
135
-
136
-
137
- def get_entity_edge_save_bulk_query(db_type: str = 'neo4j') -> str:
138
- if db_type == 'falkordb':
139
- return """
140
- UNWIND $entity_edges AS edge
141
- MATCH (source:Entity {uuid: edge.source_node_uuid})
142
- MATCH (target:Entity {uuid: edge.target_node_uuid})
143
- MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
144
- SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
145
- created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
146
- WITH r, edge
147
- RETURN edge.uuid AS uuid"""
148
- else:
149
- return ENTITY_EDGE_SAVE_BULK
102
+
103
+ return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'