graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/edges.py CHANGED
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import json
17
18
  import logging
18
19
  from abc import ABC, abstractmethod
19
20
  from datetime import datetime
@@ -21,38 +22,25 @@ from time import time
21
22
  from typing import Any
22
23
  from uuid import uuid4
23
24
 
24
- from neo4j import AsyncDriver
25
25
  from pydantic import BaseModel, Field
26
26
  from typing_extensions import LiteralString
27
27
 
28
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
28
29
  from graphiti_core.embedder import EmbedderClient
29
30
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
- from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
31
+ from graphiti_core.helpers import parse_db_date
31
32
  from graphiti_core.models.edges.edge_db_queries import (
32
- COMMUNITY_EDGE_SAVE,
33
- ENTITY_EDGE_SAVE,
33
+ COMMUNITY_EDGE_RETURN,
34
+ EPISODIC_EDGE_RETURN,
34
35
  EPISODIC_EDGE_SAVE,
36
+ get_community_edge_save_query,
37
+ get_entity_edge_return_query,
38
+ get_entity_edge_save_query,
35
39
  )
36
40
  from graphiti_core.nodes import Node
37
41
 
38
42
  logger = logging.getLogger(__name__)
39
43
 
40
- ENTITY_EDGE_RETURN: LiteralString = """
41
- RETURN
42
- e.uuid AS uuid,
43
- startNode(e).uuid AS source_node_uuid,
44
- endNode(e).uuid AS target_node_uuid,
45
- e.created_at AS created_at,
46
- e.name AS name,
47
- e.group_id AS group_id,
48
- e.fact AS fact,
49
- e.episodes AS episodes,
50
- e.expired_at AS expired_at,
51
- e.valid_at AS valid_at,
52
- e.invalid_at AS invalid_at,
53
- properties(e) AS attributes
54
- """
55
-
56
44
 
57
45
  class Edge(BaseModel, ABC):
58
46
  uuid: str = Field(default_factory=lambda: str(uuid4()))
@@ -62,21 +50,71 @@ class Edge(BaseModel, ABC):
62
50
  created_at: datetime
63
51
 
64
52
  @abstractmethod
65
- async def save(self, driver: AsyncDriver): ...
66
-
67
- async def delete(self, driver: AsyncDriver):
68
- result = await driver.execute_query(
69
- """
70
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
71
- DELETE e
72
- """,
73
- uuid=self.uuid,
74
- database_=DEFAULT_DATABASE,
75
- )
53
+ async def save(self, driver: GraphDriver): ...
54
+
55
+ async def delete(self, driver: GraphDriver):
56
+ if driver.graph_operations_interface:
57
+ return await driver.graph_operations_interface.edge_delete(self, driver)
58
+
59
+ if driver.provider == GraphProvider.KUZU:
60
+ await driver.execute_query(
61
+ """
62
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
63
+ DELETE e
64
+ """,
65
+ uuid=self.uuid,
66
+ )
67
+ await driver.execute_query(
68
+ """
69
+ MATCH (e:RelatesToNode_ {uuid: $uuid})
70
+ DETACH DELETE e
71
+ """,
72
+ uuid=self.uuid,
73
+ )
74
+ else:
75
+ await driver.execute_query(
76
+ """
77
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
78
+ DELETE e
79
+ """,
80
+ uuid=self.uuid,
81
+ )
76
82
 
77
83
  logger.debug(f'Deleted Edge: {self.uuid}')
78
84
 
79
- return result
85
+ @classmethod
86
+ async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
87
+ if driver.graph_operations_interface:
88
+ return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
89
+
90
+ if driver.provider == GraphProvider.KUZU:
91
+ await driver.execute_query(
92
+ """
93
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
94
+ WHERE e.uuid IN $uuids
95
+ DELETE e
96
+ """,
97
+ uuids=uuids,
98
+ )
99
+ await driver.execute_query(
100
+ """
101
+ MATCH (e:RelatesToNode_)
102
+ WHERE e.uuid IN $uuids
103
+ DETACH DELETE e
104
+ """,
105
+ uuids=uuids,
106
+ )
107
+ else:
108
+ await driver.execute_query(
109
+ """
110
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
111
+ WHERE e.uuid IN $uuids
112
+ DELETE e
113
+ """,
114
+ uuids=uuids,
115
+ )
116
+
117
+ logger.debug(f'Deleted Edges: {uuids}')
80
118
 
81
119
  def __hash__(self):
82
120
  return hash(self.uuid)
@@ -87,11 +125,11 @@ class Edge(BaseModel, ABC):
87
125
  return False
88
126
 
89
127
  @classmethod
90
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
128
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
91
129
 
92
130
 
93
131
  class EpisodicEdge(Edge):
94
- async def save(self, driver: AsyncDriver):
132
+ async def save(self, driver: GraphDriver):
95
133
  result = await driver.execute_query(
96
134
  EPISODIC_EDGE_SAVE,
97
135
  episode_uuid=self.source_node_uuid,
@@ -99,27 +137,21 @@ class EpisodicEdge(Edge):
99
137
  uuid=self.uuid,
100
138
  group_id=self.group_id,
101
139
  created_at=self.created_at,
102
- database_=DEFAULT_DATABASE,
103
140
  )
104
141
 
105
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
142
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
106
143
 
107
144
  return result
108
145
 
109
146
  @classmethod
110
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
147
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
111
148
  records, _, _ = await driver.execute_query(
112
149
  """
113
- MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
114
- RETURN
115
- e.uuid As uuid,
116
- e.group_id AS group_id,
117
- n.uuid AS source_node_uuid,
118
- m.uuid AS target_node_uuid,
119
- e.created_at AS created_at
120
- """,
150
+ MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
151
+ RETURN
152
+ """
153
+ + EPISODIC_EDGE_RETURN,
121
154
  uuid=uuid,
122
- database_=DEFAULT_DATABASE,
123
155
  routing_='r',
124
156
  )
125
157
 
@@ -130,20 +162,15 @@ class EpisodicEdge(Edge):
130
162
  return edges[0]
131
163
 
132
164
  @classmethod
133
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
165
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
134
166
  records, _, _ = await driver.execute_query(
135
167
  """
136
- MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
137
- WHERE e.uuid IN $uuids
138
- RETURN
139
- e.uuid As uuid,
140
- e.group_id AS group_id,
141
- n.uuid AS source_node_uuid,
142
- m.uuid AS target_node_uuid,
143
- e.created_at AS created_at
144
- """,
168
+ MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
169
+ WHERE e.uuid IN $uuids
170
+ RETURN
171
+ """
172
+ + EPISODIC_EDGE_RETURN,
145
173
  uuids=uuids,
146
- database_=DEFAULT_DATABASE,
147
174
  routing_='r',
148
175
  )
149
176
 
@@ -156,7 +183,7 @@ class EpisodicEdge(Edge):
156
183
  @classmethod
157
184
  async def get_by_group_ids(
158
185
  cls,
159
- driver: AsyncDriver,
186
+ driver: GraphDriver,
160
187
  group_ids: list[str],
161
188
  limit: int | None = None,
162
189
  uuid_cursor: str | None = None,
@@ -166,24 +193,21 @@ class EpisodicEdge(Edge):
166
193
 
167
194
  records, _, _ = await driver.execute_query(
168
195
  """
169
- MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
170
- WHERE e.group_id IN $group_ids
171
- """
196
+ MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
197
+ WHERE e.group_id IN $group_ids
198
+ """
172
199
  + cursor_query
173
200
  + """
174
- RETURN
175
- e.uuid As uuid,
176
- e.group_id AS group_id,
177
- n.uuid AS source_node_uuid,
178
- m.uuid AS target_node_uuid,
179
- e.created_at AS created_at
180
- ORDER BY e.uuid DESC
181
- """
201
+ RETURN
202
+ """
203
+ + EPISODIC_EDGE_RETURN
204
+ + """
205
+ ORDER BY e.uuid DESC
206
+ """
182
207
  + limit_query,
183
208
  group_ids=group_ids,
184
209
  uuid=uuid_cursor,
185
210
  limit=limit,
186
- database_=DEFAULT_DATABASE,
187
211
  routing_='r',
188
212
  )
189
213
 
@@ -226,13 +250,31 @@ class EntityEdge(Edge):
226
250
 
227
251
  return self.fact_embedding
228
252
 
229
- async def load_fact_embedding(self, driver: AsyncDriver):
230
- query: LiteralString = """
253
+ async def load_fact_embedding(self, driver: GraphDriver):
254
+ if driver.graph_operations_interface:
255
+ return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
256
+
257
+ query = """
231
258
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
232
259
  RETURN e.fact_embedding AS fact_embedding
233
260
  """
261
+
262
+ if driver.provider == GraphProvider.NEPTUNE:
263
+ query = """
264
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
265
+ RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
266
+ """
267
+
268
+ if driver.provider == GraphProvider.KUZU:
269
+ query = """
270
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
271
+ RETURN e.fact_embedding AS fact_embedding
272
+ """
273
+
234
274
  records, _, _ = await driver.execute_query(
235
- query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
275
+ query,
276
+ uuid=self.uuid,
277
+ routing_='r',
236
278
  )
237
279
 
238
280
  if len(records) == 0:
@@ -240,7 +282,7 @@ class EntityEdge(Edge):
240
282
 
241
283
  self.fact_embedding = records[0]['fact_embedding']
242
284
 
243
- async def save(self, driver: AsyncDriver):
285
+ async def save(self, driver: GraphDriver):
244
286
  edge_data: dict[str, Any] = {
245
287
  'source_uuid': self.source_node_uuid,
246
288
  'target_uuid': self.target_node_uuid,
@@ -256,138 +298,209 @@ class EntityEdge(Edge):
256
298
  'invalid_at': self.invalid_at,
257
299
  }
258
300
 
259
- edge_data.update(self.attributes or {})
260
-
261
- result = await driver.execute_query(
262
- ENTITY_EDGE_SAVE,
263
- edge_data=edge_data,
264
- database_=DEFAULT_DATABASE,
265
- )
266
-
267
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
301
+ if driver.provider == GraphProvider.KUZU:
302
+ edge_data['attributes'] = json.dumps(self.attributes)
303
+ result = await driver.execute_query(
304
+ get_entity_edge_save_query(driver.provider),
305
+ **edge_data,
306
+ )
307
+ else:
308
+ edge_data.update(self.attributes or {})
309
+ result = await driver.execute_query(
310
+ get_entity_edge_save_query(driver.provider),
311
+ edge_data=edge_data,
312
+ )
313
+
314
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
268
315
 
269
316
  return result
270
317
 
271
318
  @classmethod
272
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
319
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
320
+ match_query = """
321
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
322
+ """
323
+ if driver.provider == GraphProvider.KUZU:
324
+ match_query = """
325
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
326
+ """
327
+
273
328
  records, _, _ = await driver.execute_query(
329
+ match_query
330
+ + """
331
+ RETURN
274
332
  """
275
- MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
276
- """
277
- + ENTITY_EDGE_RETURN,
333
+ + get_entity_edge_return_query(driver.provider),
278
334
  uuid=uuid,
279
- database_=DEFAULT_DATABASE,
280
335
  routing_='r',
281
336
  )
282
337
 
283
- edges = [get_entity_edge_from_record(record) for record in records]
338
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
284
339
 
285
340
  if len(edges) == 0:
286
341
  raise EdgeNotFoundError(uuid)
287
342
  return edges[0]
288
343
 
289
344
  @classmethod
290
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
345
+ async def get_between_nodes(
346
+ cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
347
+ ):
348
+ match_query = """
349
+ MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
350
+ """
351
+ if driver.provider == GraphProvider.KUZU:
352
+ match_query = """
353
+ MATCH (n:Entity {uuid: $source_node_uuid})
354
+ -[:RELATES_TO]->(e:RelatesToNode_)
355
+ -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
356
+ """
357
+
358
+ records, _, _ = await driver.execute_query(
359
+ match_query
360
+ + """
361
+ RETURN
362
+ """
363
+ + get_entity_edge_return_query(driver.provider),
364
+ source_node_uuid=source_node_uuid,
365
+ target_node_uuid=target_node_uuid,
366
+ routing_='r',
367
+ )
368
+
369
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
370
+
371
+ return edges
372
+
373
+ @classmethod
374
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
291
375
  if len(uuids) == 0:
292
376
  return []
293
377
 
378
+ match_query = """
379
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
380
+ """
381
+ if driver.provider == GraphProvider.KUZU:
382
+ match_query = """
383
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
384
+ """
385
+
294
386
  records, _, _ = await driver.execute_query(
387
+ match_query
388
+ + """
389
+ WHERE e.uuid IN $uuids
390
+ RETURN
295
391
  """
296
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
297
- WHERE e.uuid IN $uuids
298
- """
299
- + ENTITY_EDGE_RETURN,
392
+ + get_entity_edge_return_query(driver.provider),
300
393
  uuids=uuids,
301
- database_=DEFAULT_DATABASE,
302
394
  routing_='r',
303
395
  )
304
396
 
305
- edges = [get_entity_edge_from_record(record) for record in records]
397
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
306
398
 
307
399
  return edges
308
400
 
309
401
  @classmethod
310
402
  async def get_by_group_ids(
311
403
  cls,
312
- driver: AsyncDriver,
404
+ driver: GraphDriver,
313
405
  group_ids: list[str],
314
406
  limit: int | None = None,
315
407
  uuid_cursor: str | None = None,
408
+ with_embeddings: bool = False,
316
409
  ):
317
410
  cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
318
411
  limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
412
+ with_embeddings_query: LiteralString = (
413
+ """,
414
+ e.fact_embedding AS fact_embedding
415
+ """
416
+ if with_embeddings
417
+ else ''
418
+ )
419
+
420
+ match_query = """
421
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
422
+ """
423
+ if driver.provider == GraphProvider.KUZU:
424
+ match_query = """
425
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
426
+ """
319
427
 
320
428
  records, _, _ = await driver.execute_query(
429
+ match_query
430
+ + """
431
+ WHERE e.group_id IN $group_ids
321
432
  """
322
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
323
- WHERE e.group_id IN $group_ids
324
- """
325
433
  + cursor_query
326
- + ENTITY_EDGE_RETURN
327
434
  + """
328
- ORDER BY e.uuid DESC
329
- """
435
+ RETURN
436
+ """
437
+ + get_entity_edge_return_query(driver.provider)
438
+ + with_embeddings_query
439
+ + """
440
+ ORDER BY e.uuid DESC
441
+ """
330
442
  + limit_query,
331
443
  group_ids=group_ids,
332
444
  uuid=uuid_cursor,
333
445
  limit=limit,
334
- database_=DEFAULT_DATABASE,
335
446
  routing_='r',
336
447
  )
337
448
 
338
- edges = [get_entity_edge_from_record(record) for record in records]
449
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
339
450
 
340
451
  if len(edges) == 0:
341
452
  raise GroupsEdgesNotFoundError(group_ids)
342
453
  return edges
343
454
 
344
455
  @classmethod
345
- async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
346
- query: LiteralString = (
456
+ async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
457
+ match_query = """
458
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
459
+ """
460
+ if driver.provider == GraphProvider.KUZU:
461
+ match_query = """
462
+ MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
347
463
  """
348
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
349
- """
350
- + ENTITY_EDGE_RETURN
351
- )
464
+
352
465
  records, _, _ = await driver.execute_query(
353
- query, node_uuid=node_uuid, database_=DEFAULT_DATABASE, routing_='r'
466
+ match_query
467
+ + """
468
+ RETURN
469
+ """
470
+ + get_entity_edge_return_query(driver.provider),
471
+ node_uuid=node_uuid,
472
+ routing_='r',
354
473
  )
355
474
 
356
- edges = [get_entity_edge_from_record(record) for record in records]
475
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
357
476
 
358
477
  return edges
359
478
 
360
479
 
361
480
  class CommunityEdge(Edge):
362
- async def save(self, driver: AsyncDriver):
481
+ async def save(self, driver: GraphDriver):
363
482
  result = await driver.execute_query(
364
- COMMUNITY_EDGE_SAVE,
483
+ get_community_edge_save_query(driver.provider),
365
484
  community_uuid=self.source_node_uuid,
366
485
  entity_uuid=self.target_node_uuid,
367
486
  uuid=self.uuid,
368
487
  group_id=self.group_id,
369
488
  created_at=self.created_at,
370
- database_=DEFAULT_DATABASE,
371
489
  )
372
490
 
373
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
491
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
374
492
 
375
493
  return result
376
494
 
377
495
  @classmethod
378
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
496
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
379
497
  records, _, _ = await driver.execute_query(
380
498
  """
381
- MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
382
- RETURN
383
- e.uuid As uuid,
384
- e.group_id AS group_id,
385
- n.uuid AS source_node_uuid,
386
- m.uuid AS target_node_uuid,
387
- e.created_at AS created_at
388
- """,
499
+ MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
500
+ RETURN
501
+ """
502
+ + COMMUNITY_EDGE_RETURN,
389
503
  uuid=uuid,
390
- database_=DEFAULT_DATABASE,
391
504
  routing_='r',
392
505
  )
393
506
 
@@ -396,20 +509,15 @@ class CommunityEdge(Edge):
396
509
  return edges[0]
397
510
 
398
511
  @classmethod
399
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
512
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
400
513
  records, _, _ = await driver.execute_query(
401
514
  """
402
- MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
403
- WHERE e.uuid IN $uuids
404
- RETURN
405
- e.uuid As uuid,
406
- e.group_id AS group_id,
407
- n.uuid AS source_node_uuid,
408
- m.uuid AS target_node_uuid,
409
- e.created_at AS created_at
410
- """,
515
+ MATCH (n:Community)-[e:HAS_MEMBER]->(m)
516
+ WHERE e.uuid IN $uuids
517
+ RETURN
518
+ """
519
+ + COMMUNITY_EDGE_RETURN,
411
520
  uuids=uuids,
412
- database_=DEFAULT_DATABASE,
413
521
  routing_='r',
414
522
  )
415
523
 
@@ -420,7 +528,7 @@ class CommunityEdge(Edge):
420
528
  @classmethod
421
529
  async def get_by_group_ids(
422
530
  cls,
423
- driver: AsyncDriver,
531
+ driver: GraphDriver,
424
532
  group_ids: list[str],
425
533
  limit: int | None = None,
426
534
  uuid_cursor: str | None = None,
@@ -430,24 +538,21 @@ class CommunityEdge(Edge):
430
538
 
431
539
  records, _, _ = await driver.execute_query(
432
540
  """
433
- MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
434
- WHERE e.group_id IN $group_ids
435
- """
541
+ MATCH (n:Community)-[e:HAS_MEMBER]->(m)
542
+ WHERE e.group_id IN $group_ids
543
+ """
436
544
  + cursor_query
437
545
  + """
438
- RETURN
439
- e.uuid As uuid,
440
- e.group_id AS group_id,
441
- n.uuid AS source_node_uuid,
442
- m.uuid AS target_node_uuid,
443
- e.created_at AS created_at
444
- ORDER BY e.uuid DESC
445
- """
546
+ RETURN
547
+ """
548
+ + COMMUNITY_EDGE_RETURN
549
+ + """
550
+ ORDER BY e.uuid DESC
551
+ """
446
552
  + limit_query,
447
553
  group_ids=group_ids,
448
554
  uuid=uuid_cursor,
449
555
  limit=limit,
450
- database_=DEFAULT_DATABASE,
451
556
  routing_='r',
452
557
  )
453
558
 
@@ -463,38 +568,45 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
463
568
  group_id=record['group_id'],
464
569
  source_node_uuid=record['source_node_uuid'],
465
570
  target_node_uuid=record['target_node_uuid'],
466
- created_at=record['created_at'].to_native(),
571
+ created_at=parse_db_date(record['created_at']), # type: ignore
467
572
  )
468
573
 
469
574
 
470
- def get_entity_edge_from_record(record: Any) -> EntityEdge:
575
+ def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
576
+ episodes = record['episodes']
577
+ if provider == GraphProvider.KUZU:
578
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
579
+ else:
580
+ attributes = record['attributes']
581
+ attributes.pop('uuid', None)
582
+ attributes.pop('source_node_uuid', None)
583
+ attributes.pop('target_node_uuid', None)
584
+ attributes.pop('fact', None)
585
+ attributes.pop('fact_embedding', None)
586
+ attributes.pop('name', None)
587
+ attributes.pop('group_id', None)
588
+ attributes.pop('episodes', None)
589
+ attributes.pop('created_at', None)
590
+ attributes.pop('expired_at', None)
591
+ attributes.pop('valid_at', None)
592
+ attributes.pop('invalid_at', None)
593
+
471
594
  edge = EntityEdge(
472
595
  uuid=record['uuid'],
473
596
  source_node_uuid=record['source_node_uuid'],
474
597
  target_node_uuid=record['target_node_uuid'],
475
598
  fact=record['fact'],
599
+ fact_embedding=record.get('fact_embedding'),
476
600
  name=record['name'],
477
601
  group_id=record['group_id'],
478
- episodes=record['episodes'],
479
- created_at=record['created_at'].to_native(),
602
+ episodes=episodes,
603
+ created_at=parse_db_date(record['created_at']), # type: ignore
480
604
  expired_at=parse_db_date(record['expired_at']),
481
605
  valid_at=parse_db_date(record['valid_at']),
482
606
  invalid_at=parse_db_date(record['invalid_at']),
483
- attributes=record['attributes'],
607
+ attributes=attributes,
484
608
  )
485
609
 
486
- edge.attributes.pop('uuid', None)
487
- edge.attributes.pop('source_node_uuid', None)
488
- edge.attributes.pop('target_node_uuid', None)
489
- edge.attributes.pop('fact', None)
490
- edge.attributes.pop('name', None)
491
- edge.attributes.pop('group_id', None)
492
- edge.attributes.pop('episodes', None)
493
- edge.attributes.pop('created_at', None)
494
- edge.attributes.pop('expired_at', None)
495
- edge.attributes.pop('valid_at', None)
496
- edge.attributes.pop('invalid_at', None)
497
-
498
610
  return edge
499
611
 
500
612
 
@@ -504,13 +616,16 @@ def get_community_edge_from_record(record: Any):
504
616
  group_id=record['group_id'],
505
617
  source_node_uuid=record['source_node_uuid'],
506
618
  target_node_uuid=record['target_node_uuid'],
507
- created_at=record['created_at'].to_native(),
619
+ created_at=parse_db_date(record['created_at']), # type: ignore
508
620
  )
509
621
 
510
622
 
511
623
  async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
512
- if len(edges) == 0:
624
+ # filter out falsey values from edges
625
+ filtered_edges = [edge for edge in edges if edge.fact]
626
+
627
+ if len(filtered_edges) == 0:
513
628
  return
514
- fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
515
- for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
629
+ fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
630
+ for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
516
631
  edge.fact_embedding = fact_embedding