graphiti-core 0.19.0rc3__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -27,10 +27,13 @@ logger = logging.getLogger(__name__)
27
27
  class GraphProvider(Enum):
28
28
  NEO4J = 'neo4j'
29
29
  FALKORDB = 'falkordb'
30
+ KUZU = 'kuzu'
30
31
  NEPTUNE = 'neptune'
31
32
 
32
33
 
33
34
  class GraphDriverSession(ABC):
35
+ provider: GraphProvider
36
+
34
37
  async def __aenter__(self):
35
38
  return self
36
39
 
@@ -15,7 +15,6 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import logging
18
- from datetime import datetime
19
18
  from typing import TYPE_CHECKING, Any
20
19
 
21
20
  if TYPE_CHECKING:
@@ -33,11 +32,14 @@ else:
33
32
  ) from None
34
33
 
35
34
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
35
+ from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
36
36
 
37
37
  logger = logging.getLogger(__name__)
38
38
 
39
39
 
40
40
  class FalkorDriverSession(GraphDriverSession):
41
+ provider = GraphProvider.FALKORDB
42
+
41
43
  def __init__(self, graph: FalkorGraph):
42
44
  self.graph = graph
43
45
 
@@ -164,16 +166,3 @@ class FalkorDriver(GraphDriver):
164
166
  cloned = FalkorDriver(falkor_db=self.client, database=database)
165
167
 
166
168
  return cloned
167
-
168
-
169
- def convert_datetimes_to_strings(obj):
170
- if isinstance(obj, dict):
171
- return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
172
- elif isinstance(obj, list):
173
- return [convert_datetimes_to_strings(item) for item in obj]
174
- elif isinstance(obj, tuple):
175
- return tuple(convert_datetimes_to_strings(item) for item in obj)
176
- elif isinstance(obj, datetime):
177
- return obj.isoformat()
178
- else:
179
- return obj
@@ -0,0 +1,175 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from typing import Any
19
+
20
+ import kuzu
21
+
22
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Kuzu requires an explicit schema.
27
+ # As Kuzu currently does not support creating full text indexes on edge properties,
28
+ # we work around this by representing (n:Entity)-[:RELATES_TO]->(m:Entity) as
29
+ # (n)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m).
30
+ SCHEMA_QUERIES = """
31
+ CREATE NODE TABLE IF NOT EXISTS Episodic (
32
+ uuid STRING PRIMARY KEY,
33
+ name STRING,
34
+ group_id STRING,
35
+ created_at TIMESTAMP,
36
+ source STRING,
37
+ source_description STRING,
38
+ content STRING,
39
+ valid_at TIMESTAMP,
40
+ entity_edges STRING[]
41
+ );
42
+ CREATE NODE TABLE IF NOT EXISTS Entity (
43
+ uuid STRING PRIMARY KEY,
44
+ name STRING,
45
+ group_id STRING,
46
+ labels STRING[],
47
+ created_at TIMESTAMP,
48
+ name_embedding FLOAT[],
49
+ summary STRING,
50
+ attributes STRING
51
+ );
52
+ CREATE NODE TABLE IF NOT EXISTS Community (
53
+ uuid STRING PRIMARY KEY,
54
+ name STRING,
55
+ group_id STRING,
56
+ created_at TIMESTAMP,
57
+ name_embedding FLOAT[],
58
+ summary STRING
59
+ );
60
+ CREATE NODE TABLE IF NOT EXISTS RelatesToNode_ (
61
+ uuid STRING PRIMARY KEY,
62
+ group_id STRING,
63
+ created_at TIMESTAMP,
64
+ name STRING,
65
+ fact STRING,
66
+ fact_embedding FLOAT[],
67
+ episodes STRING[],
68
+ expired_at TIMESTAMP,
69
+ valid_at TIMESTAMP,
70
+ invalid_at TIMESTAMP,
71
+ attributes STRING
72
+ );
73
+ CREATE REL TABLE IF NOT EXISTS RELATES_TO(
74
+ FROM Entity TO RelatesToNode_,
75
+ FROM RelatesToNode_ TO Entity
76
+ );
77
+ CREATE REL TABLE IF NOT EXISTS MENTIONS(
78
+ FROM Episodic TO Entity,
79
+ uuid STRING PRIMARY KEY,
80
+ group_id STRING,
81
+ created_at TIMESTAMP
82
+ );
83
+ CREATE REL TABLE IF NOT EXISTS HAS_MEMBER(
84
+ FROM Community TO Entity,
85
+ FROM Community TO Community,
86
+ uuid STRING,
87
+ group_id STRING,
88
+ created_at TIMESTAMP
89
+ );
90
+ """
91
+
92
+
93
+ class KuzuDriver(GraphDriver):
94
+ provider: GraphProvider = GraphProvider.KUZU
95
+
96
+ def __init__(
97
+ self,
98
+ db: str = ':memory:',
99
+ max_concurrent_queries: int = 1,
100
+ ):
101
+ super().__init__()
102
+ self.db = kuzu.Database(db)
103
+
104
+ self.setup_schema()
105
+
106
+ self.client = kuzu.AsyncConnection(self.db, max_concurrent_queries=max_concurrent_queries)
107
+
108
+ async def execute_query(
109
+ self, cypher_query_: str, **kwargs: Any
110
+ ) -> tuple[list[dict[str, Any]] | list[list[dict[str, Any]]], None, None]:
111
+ params = {k: v for k, v in kwargs.items() if v is not None}
112
+ # Kuzu does not support these parameters.
113
+ params.pop('database_', None)
114
+ params.pop('routing_', None)
115
+
116
+ try:
117
+ results = await self.client.execute(cypher_query_, parameters=params)
118
+ except Exception as e:
119
+ params = {k: (v[:5] if isinstance(v, list) else v) for k, v in params.items()}
120
+ logger.error(f'Error executing Kuzu query: {e}\n{cypher_query_}\n{params}')
121
+ raise
122
+
123
+ if not results:
124
+ return [], None, None
125
+
126
+ if isinstance(results, list):
127
+ dict_results = [list(result.rows_as_dict()) for result in results]
128
+ else:
129
+ dict_results = list(results.rows_as_dict())
130
+ return dict_results, None, None # type: ignore
131
+
132
+ def session(self, _database: str | None = None) -> GraphDriverSession:
133
+ return KuzuDriverSession(self)
134
+
135
+ async def close(self):
136
+ # Do not explicity close the connection, instead rely on GC.
137
+ pass
138
+
139
+ def delete_all_indexes(self, database_: str):
140
+ pass
141
+
142
+ def setup_schema(self):
143
+ conn = kuzu.Connection(self.db)
144
+ conn.execute(SCHEMA_QUERIES)
145
+ conn.close()
146
+
147
+
148
+ class KuzuDriverSession(GraphDriverSession):
149
+ provider = GraphProvider.KUZU
150
+
151
+ def __init__(self, driver: KuzuDriver):
152
+ self.driver = driver
153
+
154
+ async def __aenter__(self):
155
+ return self
156
+
157
+ async def __aexit__(self, exc_type, exc, tb):
158
+ # No cleanup needed for Kuzu, but method must exist.
159
+ pass
160
+
161
+ async def close(self):
162
+ # Do not close the session here, as we're reusing the driver connection.
163
+ pass
164
+
165
+ async def execute_write(self, func, *args, **kwargs):
166
+ # Directly await the provided async function with `self` as the transaction/session
167
+ return await func(self, *args, **kwargs)
168
+
169
+ async def run(self, query: str | list, **kwargs: Any) -> Any:
170
+ if isinstance(query, list):
171
+ for cypher, params in query:
172
+ await self.driver.execute_query(cypher, **params)
173
+ else:
174
+ await self.driver.execute_query(query, **kwargs)
175
+ return None
@@ -271,6 +271,8 @@ class NeptuneDriver(GraphDriver):
271
271
 
272
272
 
273
273
  class NeptuneDriverSession(GraphDriverSession):
274
+ provider = GraphProvider.NEPTUNE
275
+
274
276
  def __init__(self, driver: NeptuneDriver): # type: ignore[reportUnknownArgumentType]
275
277
  self.driver = driver
276
278
 
graphiti_core/edges.py CHANGED
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import json
17
18
  import logging
18
19
  from abc import ABC, abstractmethod
19
20
  from datetime import datetime
@@ -30,11 +31,10 @@ from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
31
  from graphiti_core.helpers import parse_db_date
31
32
  from graphiti_core.models.edges.edge_db_queries import (
32
33
  COMMUNITY_EDGE_RETURN,
33
- ENTITY_EDGE_RETURN,
34
- ENTITY_EDGE_RETURN_NEPTUNE,
35
34
  EPISODIC_EDGE_RETURN,
36
35
  EPISODIC_EDGE_SAVE,
37
36
  get_community_edge_save_query,
37
+ get_entity_edge_return_query,
38
38
  get_entity_edge_save_query,
39
39
  )
40
40
  from graphiti_core.nodes import Node
@@ -53,33 +53,63 @@ class Edge(BaseModel, ABC):
53
53
  async def save(self, driver: GraphDriver): ...
54
54
 
55
55
  async def delete(self, driver: GraphDriver):
56
- result = await driver.execute_query(
57
- """
58
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
59
- DELETE e
60
- """,
61
- uuid=self.uuid,
62
- )
56
+ if driver.provider == GraphProvider.KUZU:
57
+ await driver.execute_query(
58
+ """
59
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
60
+ DELETE e
61
+ """,
62
+ uuid=self.uuid,
63
+ )
64
+ await driver.execute_query(
65
+ """
66
+ MATCH (e:RelatesToNode_ {uuid: $uuid})
67
+ DETACH DELETE e
68
+ """,
69
+ uuid=self.uuid,
70
+ )
71
+ else:
72
+ await driver.execute_query(
73
+ """
74
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
75
+ DELETE e
76
+ """,
77
+ uuid=self.uuid,
78
+ )
63
79
 
64
80
  logger.debug(f'Deleted Edge: {self.uuid}')
65
81
 
66
- return result
67
-
68
82
  @classmethod
69
83
  async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
70
- result = await driver.execute_query(
71
- """
72
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
73
- WHERE e.uuid IN $uuids
74
- DELETE e
75
- """,
76
- uuids=uuids,
77
- )
84
+ if driver.provider == GraphProvider.KUZU:
85
+ await driver.execute_query(
86
+ """
87
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
88
+ WHERE e.uuid IN $uuids
89
+ DELETE e
90
+ """,
91
+ uuids=uuids,
92
+ )
93
+ await driver.execute_query(
94
+ """
95
+ MATCH (e:RelatesToNode_)
96
+ WHERE e.uuid IN $uuids
97
+ DETACH DELETE e
98
+ """,
99
+ uuids=uuids,
100
+ )
101
+ else:
102
+ await driver.execute_query(
103
+ """
104
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
105
+ WHERE e.uuid IN $uuids
106
+ DELETE e
107
+ """,
108
+ uuids=uuids,
109
+ )
78
110
 
79
111
  logger.debug(f'Deleted Edges: {uuids}')
80
112
 
81
- return result
82
-
83
113
  def __hash__(self):
84
114
  return hash(self.uuid)
85
115
 
@@ -166,7 +196,7 @@ class EpisodicEdge(Edge):
166
196
  """
167
197
  + EPISODIC_EDGE_RETURN
168
198
  + """
169
- ORDER BY e.uuid DESC
199
+ ORDER BY e.uuid DESC
170
200
  """
171
201
  + limit_query,
172
202
  group_ids=group_ids,
@@ -215,15 +245,21 @@ class EntityEdge(Edge):
215
245
  return self.fact_embedding
216
246
 
217
247
  async def load_fact_embedding(self, driver: GraphDriver):
218
- if driver.provider == GraphProvider.NEPTUNE:
219
- query: LiteralString = """
248
+ query = """
220
249
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
250
+ RETURN e.fact_embedding AS fact_embedding
251
+ """
252
+
253
+ if driver.provider == GraphProvider.NEPTUNE:
254
+ query = """
255
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
221
256
  RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
222
257
  """
223
- else:
224
- query: LiteralString = """
225
- MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
226
- RETURN e.fact_embedding AS fact_embedding
258
+
259
+ if driver.provider == GraphProvider.KUZU:
260
+ query = """
261
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
262
+ RETURN e.fact_embedding AS fact_embedding
227
263
  """
228
264
 
229
265
  records, _, _ = await driver.execute_query(
@@ -253,15 +289,22 @@ class EntityEdge(Edge):
253
289
  'invalid_at': self.invalid_at,
254
290
  }
255
291
 
256
- edge_data.update(self.attributes or {})
292
+ if driver.provider == GraphProvider.KUZU:
293
+ edge_data['attributes'] = json.dumps(self.attributes)
294
+ result = await driver.execute_query(
295
+ get_entity_edge_save_query(driver.provider),
296
+ **edge_data,
297
+ )
298
+ else:
299
+ edge_data.update(self.attributes or {})
257
300
 
258
- if driver.provider == GraphProvider.NEPTUNE:
259
- driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
301
+ if driver.provider == GraphProvider.NEPTUNE:
302
+ driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
260
303
 
261
- result = await driver.execute_query(
262
- get_entity_edge_save_query(driver.provider),
263
- edge_data=edge_data,
264
- )
304
+ result = await driver.execute_query(
305
+ get_entity_edge_save_query(driver.provider),
306
+ edge_data=edge_data,
307
+ )
265
308
 
266
309
  logger.debug(f'Saved edge to Graph: {self.uuid}')
267
310
 
@@ -269,21 +312,25 @@ class EntityEdge(Edge):
269
312
 
270
313
  @classmethod
271
314
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
272
- records, _, _ = await driver.execute_query(
273
- """
315
+ match_query = """
274
316
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
317
+ """
318
+ if driver.provider == GraphProvider.KUZU:
319
+ match_query = """
320
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
321
+ """
322
+
323
+ records, _, _ = await driver.execute_query(
324
+ match_query
325
+ + """
275
326
  RETURN
276
327
  """
277
- + (
278
- ENTITY_EDGE_RETURN_NEPTUNE
279
- if driver.provider == GraphProvider.NEPTUNE
280
- else ENTITY_EDGE_RETURN
281
- ),
328
+ + get_entity_edge_return_query(driver.provider),
282
329
  uuid=uuid,
283
330
  routing_='r',
284
331
  )
285
332
 
286
- edges = [get_entity_edge_from_record(record) for record in records]
333
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
287
334
 
288
335
  if len(edges) == 0:
289
336
  raise EdgeNotFoundError(uuid)
@@ -294,22 +341,26 @@ class EntityEdge(Edge):
294
341
  if len(uuids) == 0:
295
342
  return []
296
343
 
297
- records, _, _ = await driver.execute_query(
298
- """
344
+ match_query = """
299
345
  MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
346
+ """
347
+ if driver.provider == GraphProvider.KUZU:
348
+ match_query = """
349
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
350
+ """
351
+
352
+ records, _, _ = await driver.execute_query(
353
+ match_query
354
+ + """
300
355
  WHERE e.uuid IN $uuids
301
356
  RETURN
302
357
  """
303
- + (
304
- ENTITY_EDGE_RETURN_NEPTUNE
305
- if driver.provider == GraphProvider.NEPTUNE
306
- else ENTITY_EDGE_RETURN
307
- ),
358
+ + get_entity_edge_return_query(driver.provider),
308
359
  uuids=uuids,
309
360
  routing_='r',
310
361
  )
311
362
 
312
- edges = [get_entity_edge_from_record(record) for record in records]
363
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
313
364
 
314
365
  return edges
315
366
 
@@ -332,23 +383,27 @@ class EntityEdge(Edge):
332
383
  else ''
333
384
  )
334
385
 
335
- records, _, _ = await driver.execute_query(
336
- """
386
+ match_query = """
337
387
  MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
388
+ """
389
+ if driver.provider == GraphProvider.KUZU:
390
+ match_query = """
391
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
392
+ """
393
+
394
+ records, _, _ = await driver.execute_query(
395
+ match_query
396
+ + """
338
397
  WHERE e.group_id IN $group_ids
339
398
  """
340
399
  + cursor_query
341
400
  + """
342
401
  RETURN
343
402
  """
344
- + (
345
- ENTITY_EDGE_RETURN_NEPTUNE
346
- if driver.provider == GraphProvider.NEPTUNE
347
- else ENTITY_EDGE_RETURN
348
- )
403
+ + get_entity_edge_return_query(driver.provider)
349
404
  + with_embeddings_query
350
405
  + """
351
- ORDER BY e.uuid DESC
406
+ ORDER BY e.uuid DESC
352
407
  """
353
408
  + limit_query,
354
409
  group_ids=group_ids,
@@ -357,7 +412,7 @@ class EntityEdge(Edge):
357
412
  routing_='r',
358
413
  )
359
414
 
360
- edges = [get_entity_edge_from_record(record) for record in records]
415
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
361
416
 
362
417
  if len(edges) == 0:
363
418
  raise GroupsEdgesNotFoundError(group_ids)
@@ -365,21 +420,25 @@ class EntityEdge(Edge):
365
420
 
366
421
  @classmethod
367
422
  async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
368
- records, _, _ = await driver.execute_query(
369
- """
423
+ match_query = """
370
424
  MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
425
+ """
426
+ if driver.provider == GraphProvider.KUZU:
427
+ match_query = """
428
+ MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
429
+ """
430
+
431
+ records, _, _ = await driver.execute_query(
432
+ match_query
433
+ + """
371
434
  RETURN
372
435
  """
373
- + (
374
- ENTITY_EDGE_RETURN_NEPTUNE
375
- if driver.provider == GraphProvider.NEPTUNE
376
- else ENTITY_EDGE_RETURN
377
- ),
436
+ + get_entity_edge_return_query(driver.provider),
378
437
  node_uuid=node_uuid,
379
438
  routing_='r',
380
439
  )
381
440
 
382
- edges = [get_entity_edge_from_record(record) for record in records]
441
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
383
442
 
384
443
  return edges
385
444
 
@@ -479,7 +538,25 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
479
538
  )
480
539
 
481
540
 
482
- def get_entity_edge_from_record(record: Any) -> EntityEdge:
541
+ def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
542
+ episodes = record['episodes']
543
+ if provider == GraphProvider.KUZU:
544
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
545
+ else:
546
+ attributes = record['attributes']
547
+ attributes.pop('uuid', None)
548
+ attributes.pop('source_node_uuid', None)
549
+ attributes.pop('target_node_uuid', None)
550
+ attributes.pop('fact', None)
551
+ attributes.pop('fact_embedding', None)
552
+ attributes.pop('name', None)
553
+ attributes.pop('group_id', None)
554
+ attributes.pop('episodes', None)
555
+ attributes.pop('created_at', None)
556
+ attributes.pop('expired_at', None)
557
+ attributes.pop('valid_at', None)
558
+ attributes.pop('invalid_at', None)
559
+
483
560
  edge = EntityEdge(
484
561
  uuid=record['uuid'],
485
562
  source_node_uuid=record['source_node_uuid'],
@@ -488,26 +565,14 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
488
565
  fact_embedding=record.get('fact_embedding'),
489
566
  name=record['name'],
490
567
  group_id=record['group_id'],
491
- episodes=record['episodes'],
568
+ episodes=episodes,
492
569
  created_at=parse_db_date(record['created_at']), # type: ignore
493
570
  expired_at=parse_db_date(record['expired_at']),
494
571
  valid_at=parse_db_date(record['valid_at']),
495
572
  invalid_at=parse_db_date(record['invalid_at']),
496
- attributes=record['attributes'],
573
+ attributes=attributes,
497
574
  )
498
575
 
499
- edge.attributes.pop('uuid', None)
500
- edge.attributes.pop('source_node_uuid', None)
501
- edge.attributes.pop('target_node_uuid', None)
502
- edge.attributes.pop('fact', None)
503
- edge.attributes.pop('name', None)
504
- edge.attributes.pop('group_id', None)
505
- edge.attributes.pop('episodes', None)
506
- edge.attributes.pop('created_at', None)
507
- edge.attributes.pop('expired_at', None)
508
- edge.attributes.pop('valid_at', None)
509
- edge.attributes.pop('invalid_at', None)
510
-
511
576
  return edge
512
577
 
513
578
 
@@ -16,6 +16,13 @@ NEO4J_TO_FALKORDB_MAPPING = {
16
16
  'episode_content': 'Episodic',
17
17
  'edge_name_and_fact': 'RELATES_TO',
18
18
  }
19
+ # Mapping from fulltext index names to Kuzu node labels
20
+ INDEX_TO_LABEL_KUZU_MAPPING = {
21
+ 'node_name_and_summary': 'Entity',
22
+ 'community_name': 'Community',
23
+ 'episode_content': 'Episodic',
24
+ 'edge_name_and_fact': 'RelatesToNode_',
25
+ }
19
26
 
20
27
 
21
28
  def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
@@ -35,6 +42,9 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
35
42
  'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
36
43
  ]
37
44
 
45
+ if provider == GraphProvider.KUZU:
46
+ return []
47
+
38
48
  return [
39
49
  'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
40
50
  'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
@@ -68,6 +78,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
68
78
  """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
69
79
  ]
70
80
 
81
+ if provider == GraphProvider.KUZU:
82
+ return [
83
+ "CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
84
+ "CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
85
+ "CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
86
+ "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
87
+ ]
88
+
71
89
  return [
72
90
  """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
73
91
  FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
@@ -80,11 +98,15 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
80
98
  ]
81
99
 
82
100
 
83
- def get_nodes_query(provider: GraphProvider, name: str = '', query: str | None = None) -> str:
101
+ def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
84
102
  if provider == GraphProvider.FALKORDB:
85
103
  label = NEO4J_TO_FALKORDB_MAPPING[name]
86
104
  return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
87
105
 
106
+ if provider == GraphProvider.KUZU:
107
+ label = INDEX_TO_LABEL_KUZU_MAPPING[name]
108
+ return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
109
+
88
110
  return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
89
111
 
90
112
 
@@ -93,12 +115,19 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
93
115
  # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
94
116
  return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
95
117
 
118
+ if provider == GraphProvider.KUZU:
119
+ return f'array_cosine_similarity({vec1}, {vec2})'
120
+
96
121
  return f'vector.similarity.cosine({vec1}, {vec2})'
97
122
 
98
123
 
99
- def get_relationships_query(name: str, provider: GraphProvider) -> str:
124
+ def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
100
125
  if provider == GraphProvider.FALKORDB:
101
126
  label = NEO4J_TO_FALKORDB_MAPPING[name]
102
127
  return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
103
128
 
129
+ if provider == GraphProvider.KUZU:
130
+ label = INDEX_TO_LABEL_KUZU_MAPPING[name]
131
+ return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
132
+
104
133
  return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'