graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +70 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +635 -260
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +37 -15
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +92 -48
  24. graphiti_core/llm_client/openai_client.py +39 -9
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +24 -15
  33. graphiti_core/prompts/extract_nodes.py +76 -35
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +110 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/content_chunking.py +702 -0
  45. graphiti_core/utils/datetime_utils.py +13 -0
  46. graphiti_core/utils/maintenance/community_operations.py +62 -38
  47. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  48. graphiti_core/utils/maintenance/edge_operations.py +306 -156
  49. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  50. graphiti_core/utils/maintenance/node_operations.py +466 -206
  51. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  52. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  53. graphiti_core/utils/text_utils.py +53 -0
  54. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
  55. graphiti_core-0.25.3.dist-info/RECORD +87 -0
  56. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
  57. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  58. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  59. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/edges.py CHANGED
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import json
17
18
  import logging
18
19
  from abc import ABC, abstractmethod
19
20
  from datetime import datetime
@@ -24,35 +25,22 @@ from uuid import uuid4
24
25
  from pydantic import BaseModel, Field
25
26
  from typing_extensions import LiteralString
26
27
 
27
- from graphiti_core.driver.driver import GraphDriver
28
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
28
29
  from graphiti_core.embedder import EmbedderClient
29
30
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
31
  from graphiti_core.helpers import parse_db_date
31
32
  from graphiti_core.models.edges.edge_db_queries import (
32
- COMMUNITY_EDGE_SAVE,
33
- ENTITY_EDGE_SAVE,
33
+ COMMUNITY_EDGE_RETURN,
34
+ EPISODIC_EDGE_RETURN,
34
35
  EPISODIC_EDGE_SAVE,
36
+ get_community_edge_save_query,
37
+ get_entity_edge_return_query,
38
+ get_entity_edge_save_query,
35
39
  )
36
40
  from graphiti_core.nodes import Node
37
41
 
38
42
  logger = logging.getLogger(__name__)
39
43
 
40
- ENTITY_EDGE_RETURN: LiteralString = """
41
- RETURN
42
- e.uuid AS uuid,
43
- startNode(e).uuid AS source_node_uuid,
44
- endNode(e).uuid AS target_node_uuid,
45
- e.created_at AS created_at,
46
- e.name AS name,
47
- e.group_id AS group_id,
48
- e.fact AS fact,
49
- e.episodes AS episodes,
50
- e.expired_at AS expired_at,
51
- e.valid_at AS valid_at,
52
- e.invalid_at AS invalid_at,
53
- properties(e) AS attributes
54
- """
55
-
56
44
 
57
45
  class Edge(BaseModel, ABC):
58
46
  uuid: str = Field(default_factory=lambda: str(uuid4()))
@@ -65,17 +53,68 @@ class Edge(BaseModel, ABC):
65
53
  async def save(self, driver: GraphDriver): ...
66
54
 
67
55
  async def delete(self, driver: GraphDriver):
68
- result = await driver.execute_query(
69
- """
70
- MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
71
- DELETE e
72
- """,
73
- uuid=self.uuid,
74
- )
56
+ if driver.graph_operations_interface:
57
+ return await driver.graph_operations_interface.edge_delete(self, driver)
58
+
59
+ if driver.provider == GraphProvider.KUZU:
60
+ await driver.execute_query(
61
+ """
62
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
63
+ DELETE e
64
+ """,
65
+ uuid=self.uuid,
66
+ )
67
+ await driver.execute_query(
68
+ """
69
+ MATCH (e:RelatesToNode_ {uuid: $uuid})
70
+ DETACH DELETE e
71
+ """,
72
+ uuid=self.uuid,
73
+ )
74
+ else:
75
+ await driver.execute_query(
76
+ """
77
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
78
+ DELETE e
79
+ """,
80
+ uuid=self.uuid,
81
+ )
75
82
 
76
83
  logger.debug(f'Deleted Edge: {self.uuid}')
77
84
 
78
- return result
85
+ @classmethod
86
+ async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
87
+ if driver.graph_operations_interface:
88
+ return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
89
+
90
+ if driver.provider == GraphProvider.KUZU:
91
+ await driver.execute_query(
92
+ """
93
+ MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
94
+ WHERE e.uuid IN $uuids
95
+ DELETE e
96
+ """,
97
+ uuids=uuids,
98
+ )
99
+ await driver.execute_query(
100
+ """
101
+ MATCH (e:RelatesToNode_)
102
+ WHERE e.uuid IN $uuids
103
+ DETACH DELETE e
104
+ """,
105
+ uuids=uuids,
106
+ )
107
+ else:
108
+ await driver.execute_query(
109
+ """
110
+ MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
111
+ WHERE e.uuid IN $uuids
112
+ DELETE e
113
+ """,
114
+ uuids=uuids,
115
+ )
116
+
117
+ logger.debug(f'Deleted Edges: {uuids}')
79
118
 
80
119
  def __hash__(self):
81
120
  return hash(self.uuid)
@@ -108,14 +147,10 @@ class EpisodicEdge(Edge):
108
147
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
109
148
  records, _, _ = await driver.execute_query(
110
149
  """
111
- MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
112
- RETURN
113
- e.uuid As uuid,
114
- e.group_id AS group_id,
115
- n.uuid AS source_node_uuid,
116
- m.uuid AS target_node_uuid,
117
- e.created_at AS created_at
118
- """,
150
+ MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
151
+ RETURN
152
+ """
153
+ + EPISODIC_EDGE_RETURN,
119
154
  uuid=uuid,
120
155
  routing_='r',
121
156
  )
@@ -130,15 +165,11 @@ class EpisodicEdge(Edge):
130
165
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
131
166
  records, _, _ = await driver.execute_query(
132
167
  """
133
- MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
134
- WHERE e.uuid IN $uuids
135
- RETURN
136
- e.uuid As uuid,
137
- e.group_id AS group_id,
138
- n.uuid AS source_node_uuid,
139
- m.uuid AS target_node_uuid,
140
- e.created_at AS created_at
141
- """,
168
+ MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
169
+ WHERE e.uuid IN $uuids
170
+ RETURN
171
+ """
172
+ + EPISODIC_EDGE_RETURN,
142
173
  uuids=uuids,
143
174
  routing_='r',
144
175
  )
@@ -162,19 +193,17 @@ class EpisodicEdge(Edge):
162
193
 
163
194
  records, _, _ = await driver.execute_query(
164
195
  """
165
- MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
166
- WHERE e.group_id IN $group_ids
167
- """
196
+ MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
197
+ WHERE e.group_id IN $group_ids
198
+ """
168
199
  + cursor_query
169
200
  + """
170
- RETURN
171
- e.uuid As uuid,
172
- e.group_id AS group_id,
173
- n.uuid AS source_node_uuid,
174
- m.uuid AS target_node_uuid,
175
- e.created_at AS created_at
176
- ORDER BY e.uuid DESC
177
- """
201
+ RETURN
202
+ """
203
+ + EPISODIC_EDGE_RETURN
204
+ + """
205
+ ORDER BY e.uuid DESC
206
+ """
178
207
  + limit_query,
179
208
  group_ids=group_ids,
180
209
  uuid=uuid_cursor,
@@ -222,11 +251,31 @@ class EntityEdge(Edge):
222
251
  return self.fact_embedding
223
252
 
224
253
  async def load_fact_embedding(self, driver: GraphDriver):
225
- query: LiteralString = """
254
+ if driver.graph_operations_interface:
255
+ return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
256
+
257
+ query = """
226
258
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
227
259
  RETURN e.fact_embedding AS fact_embedding
228
260
  """
229
- records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
261
+
262
+ if driver.provider == GraphProvider.NEPTUNE:
263
+ query = """
264
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
265
+ RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
266
+ """
267
+
268
+ if driver.provider == GraphProvider.KUZU:
269
+ query = """
270
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
271
+ RETURN e.fact_embedding AS fact_embedding
272
+ """
273
+
274
+ records, _, _ = await driver.execute_query(
275
+ query,
276
+ uuid=self.uuid,
277
+ routing_='r',
278
+ )
230
279
 
231
280
  if len(records) == 0:
232
281
  raise EdgeNotFoundError(self.uuid)
@@ -249,12 +298,18 @@ class EntityEdge(Edge):
249
298
  'invalid_at': self.invalid_at,
250
299
  }
251
300
 
252
- edge_data.update(self.attributes or {})
253
-
254
- result = await driver.execute_query(
255
- ENTITY_EDGE_SAVE,
256
- edge_data=edge_data,
257
- )
301
+ if driver.provider == GraphProvider.KUZU:
302
+ edge_data['attributes'] = json.dumps(self.attributes)
303
+ result = await driver.execute_query(
304
+ get_entity_edge_save_query(driver.provider),
305
+ **edge_data,
306
+ )
307
+ else:
308
+ edge_data.update(self.attributes or {})
309
+ result = await driver.execute_query(
310
+ get_entity_edge_save_query(driver.provider),
311
+ edge_data=edge_data,
312
+ )
258
313
 
259
314
  logger.debug(f'Saved edge to Graph: {self.uuid}')
260
315
 
@@ -262,37 +317,84 @@ class EntityEdge(Edge):
262
317
 
263
318
  @classmethod
264
319
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
320
+ match_query = """
321
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
322
+ """
323
+ if driver.provider == GraphProvider.KUZU:
324
+ match_query = """
325
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
326
+ """
327
+
265
328
  records, _, _ = await driver.execute_query(
329
+ match_query
330
+ + """
331
+ RETURN
266
332
  """
267
- MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
268
- """
269
- + ENTITY_EDGE_RETURN,
333
+ + get_entity_edge_return_query(driver.provider),
270
334
  uuid=uuid,
271
335
  routing_='r',
272
336
  )
273
337
 
274
- edges = [get_entity_edge_from_record(record) for record in records]
338
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
275
339
 
276
340
  if len(edges) == 0:
277
341
  raise EdgeNotFoundError(uuid)
278
342
  return edges[0]
279
343
 
344
+ @classmethod
345
+ async def get_between_nodes(
346
+ cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
347
+ ):
348
+ match_query = """
349
+ MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
350
+ """
351
+ if driver.provider == GraphProvider.KUZU:
352
+ match_query = """
353
+ MATCH (n:Entity {uuid: $source_node_uuid})
354
+ -[:RELATES_TO]->(e:RelatesToNode_)
355
+ -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
356
+ """
357
+
358
+ records, _, _ = await driver.execute_query(
359
+ match_query
360
+ + """
361
+ RETURN
362
+ """
363
+ + get_entity_edge_return_query(driver.provider),
364
+ source_node_uuid=source_node_uuid,
365
+ target_node_uuid=target_node_uuid,
366
+ routing_='r',
367
+ )
368
+
369
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
370
+
371
+ return edges
372
+
280
373
  @classmethod
281
374
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
282
375
  if len(uuids) == 0:
283
376
  return []
284
377
 
378
+ match_query = """
379
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
380
+ """
381
+ if driver.provider == GraphProvider.KUZU:
382
+ match_query = """
383
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
384
+ """
385
+
285
386
  records, _, _ = await driver.execute_query(
387
+ match_query
388
+ + """
389
+ WHERE e.uuid IN $uuids
390
+ RETURN
286
391
  """
287
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
288
- WHERE e.uuid IN $uuids
289
- """
290
- + ENTITY_EDGE_RETURN,
392
+ + get_entity_edge_return_query(driver.provider),
291
393
  uuids=uuids,
292
394
  routing_='r',
293
395
  )
294
396
 
295
- edges = [get_entity_edge_from_record(record) for record in records]
397
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
296
398
 
297
399
  return edges
298
400
 
@@ -303,20 +405,40 @@ class EntityEdge(Edge):
303
405
  group_ids: list[str],
304
406
  limit: int | None = None,
305
407
  uuid_cursor: str | None = None,
408
+ with_embeddings: bool = False,
306
409
  ):
307
410
  cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
308
411
  limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
412
+ with_embeddings_query: LiteralString = (
413
+ """,
414
+ e.fact_embedding AS fact_embedding
415
+ """
416
+ if with_embeddings
417
+ else ''
418
+ )
419
+
420
+ match_query = """
421
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
422
+ """
423
+ if driver.provider == GraphProvider.KUZU:
424
+ match_query = """
425
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
426
+ """
309
427
 
310
428
  records, _, _ = await driver.execute_query(
429
+ match_query
430
+ + """
431
+ WHERE e.group_id IN $group_ids
311
432
  """
312
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
313
- WHERE e.group_id IN $group_ids
314
- """
315
433
  + cursor_query
316
- + ENTITY_EDGE_RETURN
317
434
  + """
318
- ORDER BY e.uuid DESC
319
- """
435
+ RETURN
436
+ """
437
+ + get_entity_edge_return_query(driver.provider)
438
+ + with_embeddings_query
439
+ + """
440
+ ORDER BY e.uuid DESC
441
+ """
320
442
  + limit_query,
321
443
  group_ids=group_ids,
322
444
  uuid=uuid_cursor,
@@ -324,7 +446,7 @@ class EntityEdge(Edge):
324
446
  routing_='r',
325
447
  )
326
448
 
327
- edges = [get_entity_edge_from_record(record) for record in records]
449
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
328
450
 
329
451
  if len(edges) == 0:
330
452
  raise GroupsEdgesNotFoundError(group_ids)
@@ -332,15 +454,25 @@ class EntityEdge(Edge):
332
454
 
333
455
  @classmethod
334
456
  async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
335
- query: LiteralString = (
457
+ match_query = """
458
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
459
+ """
460
+ if driver.provider == GraphProvider.KUZU:
461
+ match_query = """
462
+ MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
463
+ """
464
+
465
+ records, _, _ = await driver.execute_query(
466
+ match_query
467
+ + """
468
+ RETURN
336
469
  """
337
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
338
- """
339
- + ENTITY_EDGE_RETURN
470
+ + get_entity_edge_return_query(driver.provider),
471
+ node_uuid=node_uuid,
472
+ routing_='r',
340
473
  )
341
- records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
342
474
 
343
- edges = [get_entity_edge_from_record(record) for record in records]
475
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
344
476
 
345
477
  return edges
346
478
 
@@ -348,7 +480,7 @@ class EntityEdge(Edge):
348
480
  class CommunityEdge(Edge):
349
481
  async def save(self, driver: GraphDriver):
350
482
  result = await driver.execute_query(
351
- COMMUNITY_EDGE_SAVE,
483
+ get_community_edge_save_query(driver.provider),
352
484
  community_uuid=self.source_node_uuid,
353
485
  entity_uuid=self.target_node_uuid,
354
486
  uuid=self.uuid,
@@ -364,14 +496,10 @@ class CommunityEdge(Edge):
364
496
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
365
497
  records, _, _ = await driver.execute_query(
366
498
  """
367
- MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
368
- RETURN
369
- e.uuid As uuid,
370
- e.group_id AS group_id,
371
- n.uuid AS source_node_uuid,
372
- m.uuid AS target_node_uuid,
373
- e.created_at AS created_at
374
- """,
499
+ MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
500
+ RETURN
501
+ """
502
+ + COMMUNITY_EDGE_RETURN,
375
503
  uuid=uuid,
376
504
  routing_='r',
377
505
  )
@@ -384,15 +512,11 @@ class CommunityEdge(Edge):
384
512
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
385
513
  records, _, _ = await driver.execute_query(
386
514
  """
387
- MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
388
- WHERE e.uuid IN $uuids
389
- RETURN
390
- e.uuid As uuid,
391
- e.group_id AS group_id,
392
- n.uuid AS source_node_uuid,
393
- m.uuid AS target_node_uuid,
394
- e.created_at AS created_at
395
- """,
515
+ MATCH (n:Community)-[e:HAS_MEMBER]->(m)
516
+ WHERE e.uuid IN $uuids
517
+ RETURN
518
+ """
519
+ + COMMUNITY_EDGE_RETURN,
396
520
  uuids=uuids,
397
521
  routing_='r',
398
522
  )
@@ -414,19 +538,17 @@ class CommunityEdge(Edge):
414
538
 
415
539
  records, _, _ = await driver.execute_query(
416
540
  """
417
- MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
418
- WHERE e.group_id IN $group_ids
419
- """
541
+ MATCH (n:Community)-[e:HAS_MEMBER]->(m)
542
+ WHERE e.group_id IN $group_ids
543
+ """
420
544
  + cursor_query
421
545
  + """
422
- RETURN
423
- e.uuid As uuid,
424
- e.group_id AS group_id,
425
- n.uuid AS source_node_uuid,
426
- m.uuid AS target_node_uuid,
427
- e.created_at AS created_at
428
- ORDER BY e.uuid DESC
429
- """
546
+ RETURN
547
+ """
548
+ + COMMUNITY_EDGE_RETURN
549
+ + """
550
+ ORDER BY e.uuid DESC
551
+ """
430
552
  + limit_query,
431
553
  group_ids=group_ids,
432
554
  uuid=uuid_cursor,
@@ -450,34 +572,41 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
450
572
  )
451
573
 
452
574
 
453
- def get_entity_edge_from_record(record: Any) -> EntityEdge:
575
+ def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
576
+ episodes = record['episodes']
577
+ if provider == GraphProvider.KUZU:
578
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
579
+ else:
580
+ attributes = record['attributes']
581
+ attributes.pop('uuid', None)
582
+ attributes.pop('source_node_uuid', None)
583
+ attributes.pop('target_node_uuid', None)
584
+ attributes.pop('fact', None)
585
+ attributes.pop('fact_embedding', None)
586
+ attributes.pop('name', None)
587
+ attributes.pop('group_id', None)
588
+ attributes.pop('episodes', None)
589
+ attributes.pop('created_at', None)
590
+ attributes.pop('expired_at', None)
591
+ attributes.pop('valid_at', None)
592
+ attributes.pop('invalid_at', None)
593
+
454
594
  edge = EntityEdge(
455
595
  uuid=record['uuid'],
456
596
  source_node_uuid=record['source_node_uuid'],
457
597
  target_node_uuid=record['target_node_uuid'],
458
598
  fact=record['fact'],
599
+ fact_embedding=record.get('fact_embedding'),
459
600
  name=record['name'],
460
601
  group_id=record['group_id'],
461
- episodes=record['episodes'],
602
+ episodes=episodes,
462
603
  created_at=parse_db_date(record['created_at']), # type: ignore
463
604
  expired_at=parse_db_date(record['expired_at']),
464
605
  valid_at=parse_db_date(record['valid_at']),
465
606
  invalid_at=parse_db_date(record['invalid_at']),
466
- attributes=record['attributes'],
607
+ attributes=attributes,
467
608
  )
468
609
 
469
- edge.attributes.pop('uuid', None)
470
- edge.attributes.pop('source_node_uuid', None)
471
- edge.attributes.pop('target_node_uuid', None)
472
- edge.attributes.pop('fact', None)
473
- edge.attributes.pop('name', None)
474
- edge.attributes.pop('group_id', None)
475
- edge.attributes.pop('episodes', None)
476
- edge.attributes.pop('created_at', None)
477
- edge.attributes.pop('expired_at', None)
478
- edge.attributes.pop('valid_at', None)
479
- edge.attributes.pop('invalid_at', None)
480
-
481
610
  return edge
482
611
 
483
612
 
@@ -492,8 +621,11 @@ def get_community_edge_from_record(record: Any):
492
621
 
493
622
 
494
623
  async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
495
- if len(edges) == 0:
624
+ # filter out falsey values from edges
625
+ filtered_edges = [edge for edge in edges if edge.fact]
626
+
627
+ if len(filtered_edges) == 0:
496
628
  return
497
- fact_embeddings = await embedder.create_batch([edge.fact for edge in edges])
498
- for edge, fact_embedding in zip(edges, fact_embeddings, strict=True):
629
+ fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
630
+ for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
499
631
  edge.fact_embedding = fact_embedding
@@ -17,7 +17,7 @@ limitations under the License.
17
17
  import logging
18
18
  from typing import Any
19
19
 
20
- from openai import AsyncAzureOpenAI
20
+ from openai import AsyncAzureOpenAI, AsyncOpenAI
21
21
 
22
22
  from .client import EmbedderClient
23
23
 
@@ -25,9 +25,16 @@ logger = logging.getLogger(__name__)
25
25
 
26
26
 
27
27
  class AzureOpenAIEmbedderClient(EmbedderClient):
28
- """Wrapper class for AsyncAzureOpenAI that implements the EmbedderClient interface."""
28
+ """Wrapper class for Azure OpenAI that implements the EmbedderClient interface.
29
29
 
30
- def __init__(self, azure_client: AsyncAzureOpenAI, model: str = 'text-embedding-3-small'):
30
+ Supports both AsyncAzureOpenAI and AsyncOpenAI (with Azure v1 API endpoint).
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ azure_client: AsyncAzureOpenAI | AsyncOpenAI,
36
+ model: str = 'text-embedding-3-small',
37
+ ):
31
38
  self.azure_client = azure_client
32
39
  self.model = model
33
40
 
@@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import os
17
18
  from abc import ABC, abstractmethod
18
19
  from collections.abc import Iterable
19
20
 
20
21
  from pydantic import BaseModel, Field
21
22
 
22
- EMBEDDING_DIM = 1024
23
+ EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024))
23
24
 
24
25
 
25
26
  class EmbedderConfig(BaseModel):