graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +61 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +582 -255
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +21 -14
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +94 -50
  24. graphiti_core/llm_client/openai_client.py +28 -8
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +23 -14
  33. graphiti_core/prompts/extract_nodes.py +73 -32
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +109 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/datetime_utils.py +13 -0
  45. graphiti_core/utils/maintenance/community_operations.py +62 -38
  46. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  47. graphiti_core/utils/maintenance/edge_operations.py +286 -126
  48. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  49. graphiti_core/utils/maintenance/node_operations.py +320 -158
  50. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  51. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  52. graphiti_core/utils/text_utils.py +53 -0
  53. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
  54. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  55. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  56. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  57. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  58. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/graphiti.py CHANGED
@@ -24,18 +24,33 @@ from typing_extensions import LiteralString
24
24
 
25
25
  from graphiti_core.cross_encoder.client import CrossEncoderClient
26
26
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
27
+ from graphiti_core.decorators import handle_multiple_group_ids
27
28
  from graphiti_core.driver.driver import GraphDriver
28
29
  from graphiti_core.driver.neo4j_driver import Neo4jDriver
29
- from graphiti_core.edges import EntityEdge, EpisodicEdge
30
+ from graphiti_core.edges import (
31
+ CommunityEdge,
32
+ Edge,
33
+ EntityEdge,
34
+ EpisodicEdge,
35
+ create_entity_edge_embeddings,
36
+ )
30
37
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
31
38
  from graphiti_core.graphiti_types import GraphitiClients
32
39
  from graphiti_core.helpers import (
40
+ get_default_group_id,
33
41
  semaphore_gather,
34
42
  validate_excluded_entity_types,
35
43
  validate_group_id,
36
44
  )
37
45
  from graphiti_core.llm_client import LLMClient, OpenAIClient
38
- from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
46
+ from graphiti_core.nodes import (
47
+ CommunityNode,
48
+ EntityNode,
49
+ EpisodeType,
50
+ EpisodicNode,
51
+ Node,
52
+ create_entity_node_embeddings,
53
+ )
39
54
  from graphiti_core.search.search import SearchConfig, search
40
55
  from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
41
56
  from graphiti_core.search.search_config_recipes import (
@@ -46,11 +61,10 @@ from graphiti_core.search.search_config_recipes import (
46
61
  from graphiti_core.search.search_filters import SearchFilters
47
62
  from graphiti_core.search.search_utils import (
48
63
  RELEVANT_SCHEMA_LIMIT,
49
- get_edge_invalidation_candidates,
50
64
  get_mentioned_nodes,
51
- get_relevant_edges,
52
65
  )
53
66
  from graphiti_core.telemetry import capture_event
67
+ from graphiti_core.tracer import Tracer, create_tracer
54
68
  from graphiti_core.utils.bulk_utils import (
55
69
  RawEpisode,
56
70
  add_nodes_and_edges_bulk,
@@ -67,7 +81,6 @@ from graphiti_core.utils.maintenance.community_operations import (
67
81
  update_community,
68
82
  )
69
83
  from graphiti_core.utils.maintenance.edge_operations import (
70
- build_duplicate_of_edges,
71
84
  build_episodic_edges,
72
85
  extract_edges,
73
86
  resolve_extracted_edge,
@@ -75,7 +88,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
75
88
  )
76
89
  from graphiti_core.utils.maintenance.graph_data_operations import (
77
90
  EPISODE_WINDOW_LEN,
78
- build_indices_and_constraints,
79
91
  retrieve_episodes,
80
92
  )
81
93
  from graphiti_core.utils.maintenance.node_operations import (
@@ -92,6 +104,23 @@ load_dotenv()
92
104
 
93
105
  class AddEpisodeResults(BaseModel):
94
106
  episode: EpisodicNode
107
+ episodic_edges: list[EpisodicEdge]
108
+ nodes: list[EntityNode]
109
+ edges: list[EntityEdge]
110
+ communities: list[CommunityNode]
111
+ community_edges: list[CommunityEdge]
112
+
113
+
114
+ class AddBulkEpisodeResults(BaseModel):
115
+ episodes: list[EpisodicNode]
116
+ episodic_edges: list[EpisodicEdge]
117
+ nodes: list[EntityNode]
118
+ edges: list[EntityEdge]
119
+ communities: list[CommunityNode]
120
+ community_edges: list[CommunityEdge]
121
+
122
+
123
+ class AddTripletResults(BaseModel):
95
124
  nodes: list[EntityNode]
96
125
  edges: list[EntityEdge]
97
126
 
@@ -108,11 +137,13 @@ class Graphiti:
108
137
  store_raw_episode_content: bool = True,
109
138
  graph_driver: GraphDriver | None = None,
110
139
  max_coroutines: int | None = None,
140
+ tracer: Tracer | None = None,
141
+ trace_span_prefix: str = 'graphiti',
111
142
  ):
112
143
  """
113
144
  Initialize a Graphiti instance.
114
145
 
115
- This constructor sets up a connection to the Neo4j database and initializes
146
+ This constructor sets up a connection to a graph database and initializes
116
147
  the LLM client for natural language processing tasks.
117
148
 
118
149
  Parameters
@@ -140,6 +171,10 @@ class Graphiti:
140
171
  max_coroutines : int | None, optional
141
172
  The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
142
173
  If not set, the Graphiti default is used.
174
+ tracer : Tracer | None, optional
175
+ An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
176
+ trace_span_prefix : str, optional
177
+ Prefix to prepend to all span names. Defaults to 'graphiti'.
143
178
 
144
179
  Returns
145
180
  -------
@@ -147,11 +182,11 @@ class Graphiti:
147
182
 
148
183
  Notes
149
184
  -----
150
- This method establishes a connection to the Neo4j database using the provided
185
+ This method establishes a connection to a graph database (Neo4j by default) using the provided
151
186
  credentials. It also sets up the LLM client, either using the provided client
152
187
  or by creating a default OpenAIClient.
153
188
 
154
- The default database name is set to 'neo4j'. If a different database name
189
+ The default database name is defined during the driver’s construction. If a different database name
155
190
  is required, it should be specified in the URI or set separately after
156
191
  initialization.
157
192
 
@@ -182,11 +217,18 @@ class Graphiti:
182
217
  else:
183
218
  self.cross_encoder = OpenAIRerankerClient()
184
219
 
220
+ # Initialize tracer
221
+ self.tracer = create_tracer(tracer, trace_span_prefix)
222
+
223
+ # Set tracer on clients
224
+ self.llm_client.set_tracer(self.tracer)
225
+
185
226
  self.clients = GraphitiClients(
186
227
  driver=self.driver,
187
228
  llm_client=self.llm_client,
188
229
  embedder=self.embedder,
189
230
  cross_encoder=self.cross_encoder,
231
+ tracer=self.tracer,
190
232
  )
191
233
 
192
234
  # Capture telemetry event
@@ -298,25 +340,247 @@ class Graphiti:
298
340
  -----
299
341
  This method should typically be called once during the initial setup of the
300
342
  knowledge graph or when updating the database schema. It uses the
301
- `build_indices_and_constraints` function from the
302
- `graphiti_core.utils.maintenance.graph_data_operations` module to perform
343
+ driver's `build_indices_and_constraints` method to perform
303
344
  the actual database operations.
304
345
 
305
346
  The specific indices and constraints created depend on the implementation
306
- of the `build_indices_and_constraints` function. Refer to that function's
307
- documentation for details on the exact database schema modifications.
347
+ of the driver's `build_indices_and_constraints` method. Refer to the specific
348
+ driver documentation for details on the exact database schema modifications.
308
349
 
309
350
  Caution: Running this method on a large existing database may take some time
310
351
  and could impact database performance during execution.
311
352
  """
312
- await build_indices_and_constraints(self.driver, delete_existing)
353
+ await self.driver.build_indices_and_constraints(delete_existing)
354
+
355
+ async def _extract_and_resolve_nodes(
356
+ self,
357
+ episode: EpisodicNode,
358
+ previous_episodes: list[EpisodicNode],
359
+ entity_types: dict[str, type[BaseModel]] | None,
360
+ excluded_entity_types: list[str] | None,
361
+ ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
362
+ """Extract nodes from episode and resolve against existing graph."""
363
+ extracted_nodes = await extract_nodes(
364
+ self.clients, episode, previous_episodes, entity_types, excluded_entity_types
365
+ )
366
+
367
+ nodes, uuid_map, duplicates = await resolve_extracted_nodes(
368
+ self.clients,
369
+ extracted_nodes,
370
+ episode,
371
+ previous_episodes,
372
+ entity_types,
373
+ )
374
+
375
+ return nodes, uuid_map, duplicates
376
+
377
+ async def _extract_and_resolve_edges(
378
+ self,
379
+ episode: EpisodicNode,
380
+ extracted_nodes: list[EntityNode],
381
+ previous_episodes: list[EpisodicNode],
382
+ edge_type_map: dict[tuple[str, str], list[str]],
383
+ group_id: str,
384
+ edge_types: dict[str, type[BaseModel]] | None,
385
+ nodes: list[EntityNode],
386
+ uuid_map: dict[str, str],
387
+ ) -> tuple[list[EntityEdge], list[EntityEdge]]:
388
+ """Extract edges from episode and resolve against existing graph."""
389
+ extracted_edges = await extract_edges(
390
+ self.clients,
391
+ episode,
392
+ extracted_nodes,
393
+ previous_episodes,
394
+ edge_type_map,
395
+ group_id,
396
+ edge_types,
397
+ )
313
398
 
399
+ edges = resolve_edge_pointers(extracted_edges, uuid_map)
400
+
401
+ resolved_edges, invalidated_edges = await resolve_extracted_edges(
402
+ self.clients,
403
+ edges,
404
+ episode,
405
+ nodes,
406
+ edge_types or {},
407
+ edge_type_map,
408
+ )
409
+
410
+ return resolved_edges, invalidated_edges
411
+
412
+ async def _process_episode_data(
413
+ self,
414
+ episode: EpisodicNode,
415
+ nodes: list[EntityNode],
416
+ entity_edges: list[EntityEdge],
417
+ now: datetime,
418
+ ) -> tuple[list[EpisodicEdge], EpisodicNode]:
419
+ """Process and save episode data to the graph."""
420
+ episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
421
+ episode.entity_edges = [edge.uuid for edge in entity_edges]
422
+
423
+ if not self.store_raw_episode_content:
424
+ episode.content = ''
425
+
426
+ await add_nodes_and_edges_bulk(
427
+ self.driver,
428
+ [episode],
429
+ episodic_edges,
430
+ nodes,
431
+ entity_edges,
432
+ self.embedder,
433
+ )
434
+
435
+ return episodic_edges, episode
436
+
437
+ async def _extract_and_dedupe_nodes_bulk(
438
+ self,
439
+ episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
440
+ edge_type_map: dict[tuple[str, str], list[str]],
441
+ edge_types: dict[str, type[BaseModel]] | None,
442
+ entity_types: dict[str, type[BaseModel]] | None,
443
+ excluded_entity_types: list[str] | None,
444
+ ) -> tuple[
445
+ dict[str, list[EntityNode]],
446
+ dict[str, str],
447
+ list[list[EntityEdge]],
448
+ ]:
449
+ """Extract nodes and edges from all episodes and deduplicate."""
450
+ # Extract all nodes and edges for each episode
451
+ extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
452
+ self.clients,
453
+ episode_context,
454
+ edge_type_map=edge_type_map,
455
+ edge_types=edge_types,
456
+ entity_types=entity_types,
457
+ excluded_entity_types=excluded_entity_types,
458
+ )
459
+
460
+ # Dedupe extracted nodes in memory
461
+ nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
462
+ self.clients, extracted_nodes_bulk, episode_context, entity_types
463
+ )
464
+
465
+ return nodes_by_episode, uuid_map, extracted_edges_bulk
466
+
467
+ async def _resolve_nodes_and_edges_bulk(
468
+ self,
469
+ nodes_by_episode: dict[str, list[EntityNode]],
470
+ edges_by_episode: dict[str, list[EntityEdge]],
471
+ episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
472
+ entity_types: dict[str, type[BaseModel]] | None,
473
+ edge_types: dict[str, type[BaseModel]] | None,
474
+ edge_type_map: dict[tuple[str, str], list[str]],
475
+ episodes: list[EpisodicNode],
476
+ ) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
477
+ """Resolve nodes and edges against the existing graph."""
478
+ nodes_by_uuid: dict[str, EntityNode] = {
479
+ node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
480
+ }
481
+
482
+ # Get unique nodes per episode
483
+ nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
484
+ nodes_uuid_set: set[str] = set()
485
+ for episode, _ in episode_context:
486
+ nodes_by_episode_unique[episode.uuid] = []
487
+ nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
488
+ for node in nodes:
489
+ if node.uuid not in nodes_uuid_set:
490
+ nodes_by_episode_unique[episode.uuid].append(node)
491
+ nodes_uuid_set.add(node.uuid)
492
+
493
+ # Resolve nodes
494
+ node_results = await semaphore_gather(
495
+ *[
496
+ resolve_extracted_nodes(
497
+ self.clients,
498
+ nodes_by_episode_unique[episode.uuid],
499
+ episode,
500
+ previous_episodes,
501
+ entity_types,
502
+ )
503
+ for episode, previous_episodes in episode_context
504
+ ]
505
+ )
506
+
507
+ resolved_nodes: list[EntityNode] = []
508
+ uuid_map: dict[str, str] = {}
509
+ for result in node_results:
510
+ resolved_nodes.extend(result[0])
511
+ uuid_map.update(result[1])
512
+
513
+ # Update nodes_by_uuid with resolved nodes
514
+ for resolved_node in resolved_nodes:
515
+ nodes_by_uuid[resolved_node.uuid] = resolved_node
516
+
517
+ # Update nodes_by_episode_unique with resolved pointers
518
+ for episode_uuid, nodes in nodes_by_episode_unique.items():
519
+ updated_nodes: list[EntityNode] = []
520
+ for node in nodes:
521
+ updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
522
+ updated_node = nodes_by_uuid[updated_node_uuid]
523
+ updated_nodes.append(updated_node)
524
+ nodes_by_episode_unique[episode_uuid] = updated_nodes
525
+
526
+ # Extract attributes for resolved nodes
527
+ hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
528
+ *[
529
+ extract_attributes_from_nodes(
530
+ self.clients,
531
+ nodes_by_episode_unique[episode.uuid],
532
+ episode,
533
+ previous_episodes,
534
+ entity_types,
535
+ )
536
+ for episode, previous_episodes in episode_context
537
+ ]
538
+ )
539
+
540
+ final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]
541
+
542
+ # Resolve edges with updated pointers
543
+ edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
544
+ edges_uuid_set: set[str] = set()
545
+ for episode_uuid, edges in edges_by_episode.items():
546
+ edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
547
+ edges_by_episode_unique[episode_uuid] = []
548
+
549
+ for edge in edges_with_updated_pointers:
550
+ if edge.uuid not in edges_uuid_set:
551
+ edges_by_episode_unique[episode_uuid].append(edge)
552
+ edges_uuid_set.add(edge.uuid)
553
+
554
+ edge_results = await semaphore_gather(
555
+ *[
556
+ resolve_extracted_edges(
557
+ self.clients,
558
+ edges_by_episode_unique[episode.uuid],
559
+ episode,
560
+ final_hydrated_nodes,
561
+ edge_types or {},
562
+ edge_type_map,
563
+ )
564
+ for episode in episodes
565
+ ]
566
+ )
567
+
568
+ resolved_edges: list[EntityEdge] = []
569
+ invalidated_edges: list[EntityEdge] = []
570
+ for result in edge_results:
571
+ resolved_edges.extend(result[0])
572
+ invalidated_edges.extend(result[1])
573
+
574
+ return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
575
+
576
+ @handle_multiple_group_ids
314
577
  async def retrieve_episodes(
315
578
  self,
316
579
  reference_time: datetime,
317
580
  last_n: int = EPISODE_WINDOW_LEN,
318
581
  group_ids: list[str] | None = None,
319
582
  source: EpisodeType | None = None,
583
+ driver: GraphDriver | None = None,
320
584
  ) -> list[EpisodicNode]:
321
585
  """
322
586
  Retrieve the last n episodic nodes from the graph.
@@ -343,7 +607,10 @@ class Graphiti:
343
607
  The actual retrieval is performed by the `retrieve_episodes` function
344
608
  from the `graphiti_core.utils` module.
345
609
  """
346
- return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
610
+ if driver is None:
611
+ driver = self.clients.driver
612
+
613
+ return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
347
614
 
348
615
  async def add_episode(
349
616
  self,
@@ -352,13 +619,13 @@ class Graphiti:
352
619
  source_description: str,
353
620
  reference_time: datetime,
354
621
  source: EpisodeType = EpisodeType.message,
355
- group_id: str = '',
622
+ group_id: str | None = None,
356
623
  uuid: str | None = None,
357
624
  update_communities: bool = False,
358
- entity_types: dict[str, BaseModel] | None = None,
625
+ entity_types: dict[str, type[BaseModel]] | None = None,
359
626
  excluded_entity_types: list[str] | None = None,
360
627
  previous_episode_uuids: list[str] | None = None,
361
- edge_types: dict[str, BaseModel] | None = None,
628
+ edge_types: dict[str, type[BaseModel]] | None = None,
362
629
  edge_type_map: dict[tuple[str, str], list[str]] | None = None,
363
630
  ) -> AddEpisodeResults:
364
631
  """
@@ -416,133 +683,155 @@ class Graphiti:
416
683
  background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
417
684
  return {"message": "Episode processing started"}
418
685
  """
419
- try:
420
- start = time()
421
- now = utc_now()
686
+ start = time()
687
+ now = utc_now()
688
+
689
+ validate_entity_types(entity_types)
690
+ validate_excluded_entity_types(excluded_entity_types, entity_types)
422
691
 
423
- validate_entity_types(entity_types)
424
- validate_excluded_entity_types(excluded_entity_types, entity_types)
692
+ if group_id is None:
693
+ # if group_id is None, use the default group id by the provider
694
+ # and the preset database name will be used
695
+ group_id = get_default_group_id(self.driver.provider)
696
+ else:
425
697
  validate_group_id(group_id)
698
+ if group_id != self.driver._database:
699
+ # if group_id is provided, use it as the database name
700
+ self.driver = self.driver.clone(database=group_id)
701
+ self.clients.driver = self.driver
426
702
 
427
- previous_episodes = (
428
- await self.retrieve_episodes(
429
- reference_time,
430
- last_n=RELEVANT_SCHEMA_LIMIT,
431
- group_ids=[group_id],
432
- source=source,
703
+ with self.tracer.start_span('add_episode') as span:
704
+ try:
705
+ # Retrieve previous episodes for context
706
+ previous_episodes = (
707
+ await self.retrieve_episodes(
708
+ reference_time,
709
+ last_n=RELEVANT_SCHEMA_LIMIT,
710
+ group_ids=[group_id],
711
+ source=source,
712
+ )
713
+ if previous_episode_uuids is None
714
+ else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
433
715
  )
434
- if previous_episode_uuids is None
435
- else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
436
- )
437
716
 
438
- episode = (
439
- await EpisodicNode.get_by_uuid(self.driver, uuid)
440
- if uuid is not None
441
- else EpisodicNode(
442
- name=name,
443
- group_id=group_id,
444
- labels=[],
445
- source=source,
446
- content=episode_body,
447
- source_description=source_description,
448
- created_at=now,
449
- valid_at=reference_time,
717
+ # Get or create episode
718
+ episode = (
719
+ await EpisodicNode.get_by_uuid(self.driver, uuid)
720
+ if uuid is not None
721
+ else EpisodicNode(
722
+ name=name,
723
+ group_id=group_id,
724
+ labels=[],
725
+ source=source,
726
+ content=episode_body,
727
+ source_description=source_description,
728
+ created_at=now,
729
+ valid_at=reference_time,
730
+ )
450
731
  )
451
- )
452
-
453
- # Create default edge type map
454
- edge_type_map_default = (
455
- {('Entity', 'Entity'): list(edge_types.keys())}
456
- if edge_types is not None
457
- else {('Entity', 'Entity'): []}
458
- )
459
732
 
460
- # Extract entities as nodes
733
+ # Create default edge type map
734
+ edge_type_map_default = (
735
+ {('Entity', 'Entity'): list(edge_types.keys())}
736
+ if edge_types is not None
737
+ else {('Entity', 'Entity'): []}
738
+ )
461
739
 
462
- extracted_nodes = await extract_nodes(
463
- self.clients, episode, previous_episodes, entity_types, excluded_entity_types
464
- )
740
+ # Extract and resolve nodes
741
+ extracted_nodes = await extract_nodes(
742
+ self.clients, episode, previous_episodes, entity_types, excluded_entity_types
743
+ )
465
744
 
466
- # Extract edges and resolve nodes
467
- (nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
468
- resolve_extracted_nodes(
745
+ nodes, uuid_map, _ = await resolve_extracted_nodes(
469
746
  self.clients,
470
747
  extracted_nodes,
471
748
  episode,
472
749
  previous_episodes,
473
750
  entity_types,
474
- ),
475
- extract_edges(
476
- self.clients,
751
+ )
752
+
753
+ # Extract and resolve edges in parallel with attribute extraction
754
+ resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
477
755
  episode,
478
756
  extracted_nodes,
479
757
  previous_episodes,
480
758
  edge_type_map or edge_type_map_default,
481
759
  group_id,
482
760
  edge_types,
483
- ),
484
- max_coroutines=self.max_coroutines,
485
- )
486
-
487
- edges = resolve_edge_pointers(extracted_edges, uuid_map)
488
-
489
- (resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
490
- resolve_extracted_edges(
491
- self.clients,
492
- edges,
493
- episode,
494
761
  nodes,
495
- edge_types or {},
496
- edge_type_map or edge_type_map_default,
497
- ),
498
- extract_attributes_from_nodes(
499
- self.clients, nodes, episode, previous_episodes, entity_types
500
- ),
501
- max_coroutines=self.max_coroutines,
502
- )
762
+ uuid_map,
763
+ )
503
764
 
504
- duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
765
+ # Extract node attributes
766
+ hydrated_nodes = await extract_attributes_from_nodes(
767
+ self.clients, nodes, episode, previous_episodes, entity_types
768
+ )
505
769
 
506
- entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
770
+ entity_edges = resolved_edges + invalidated_edges
507
771
 
508
- episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
772
+ # Process and save episode data
773
+ episodic_edges, episode = await self._process_episode_data(
774
+ episode, hydrated_nodes, entity_edges, now
775
+ )
509
776
 
510
- episode.entity_edges = [edge.uuid for edge in entity_edges]
777
+ # Update communities if requested
778
+ communities = []
779
+ community_edges = []
780
+ if update_communities:
781
+ communities, community_edges = await semaphore_gather(
782
+ *[
783
+ update_community(self.driver, self.llm_client, self.embedder, node)
784
+ for node in nodes
785
+ ],
786
+ max_coroutines=self.max_coroutines,
787
+ )
511
788
 
512
- if not self.store_raw_episode_content:
513
- episode.content = ''
789
+ end = time()
790
+
791
+ # Add span attributes
792
+ span.add_attributes(
793
+ {
794
+ 'episode.uuid': episode.uuid,
795
+ 'episode.source': source.value,
796
+ 'episode.reference_time': reference_time.isoformat(),
797
+ 'group_id': group_id,
798
+ 'node.count': len(hydrated_nodes),
799
+ 'edge.count': len(entity_edges),
800
+ 'edge.invalidated_count': len(invalidated_edges),
801
+ 'previous_episodes.count': len(previous_episodes),
802
+ 'entity_types.count': len(entity_types) if entity_types else 0,
803
+ 'edge_types.count': len(edge_types) if edge_types else 0,
804
+ 'update_communities': update_communities,
805
+ 'communities.count': len(communities) if update_communities else 0,
806
+ 'duration_ms': (end - start) * 1000,
807
+ }
808
+ )
514
809
 
515
- await add_nodes_and_edges_bulk(
516
- self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
517
- )
810
+ logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
518
811
 
519
- # Update any communities
520
- if update_communities:
521
- await semaphore_gather(
522
- *[
523
- update_community(self.driver, self.llm_client, self.embedder, node)
524
- for node in nodes
525
- ],
526
- max_coroutines=self.max_coroutines,
812
+ return AddEpisodeResults(
813
+ episode=episode,
814
+ episodic_edges=episodic_edges,
815
+ nodes=hydrated_nodes,
816
+ edges=entity_edges,
817
+ communities=communities,
818
+ community_edges=community_edges,
527
819
  )
528
- end = time()
529
- logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
530
820
 
531
- return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
821
+ except Exception as e:
822
+ span.set_status('error', str(e))
823
+ span.record_exception(e)
824
+ raise e
532
825
 
533
- except Exception as e:
534
- raise e
535
-
536
- ##### EXPERIMENTAL #####
537
826
  async def add_episode_bulk(
538
827
  self,
539
828
  bulk_episodes: list[RawEpisode],
540
- group_id: str = '',
541
- entity_types: dict[str, BaseModel] | None = None,
829
+ group_id: str | None = None,
830
+ entity_types: dict[str, type[BaseModel]] | None = None,
542
831
  excluded_entity_types: list[str] | None = None,
543
- edge_types: dict[str, BaseModel] | None = None,
832
+ edge_types: dict[str, type[BaseModel]] | None = None,
544
833
  edge_type_map: dict[tuple[str, str], list[str]] | None = None,
545
- ):
834
+ ) -> AddBulkEpisodeResults:
546
835
  """
547
836
  Process multiple episodes in bulk and update the graph.
548
837
 
@@ -558,7 +847,7 @@ class Graphiti:
558
847
 
559
848
  Returns
560
849
  -------
561
- None
850
+ AddBulkEpisodeResults
562
851
 
563
852
  Notes
564
853
  -----
@@ -579,156 +868,167 @@ class Graphiti:
579
868
  If these operations are required, use the `add_episode` method instead for each
580
869
  individual episode.
581
870
  """
582
- try:
583
- start = time()
584
- now = utc_now()
585
-
586
- validate_group_id(group_id)
587
-
588
- # Create default edge type map
589
- edge_type_map_default = (
590
- {('Entity', 'Entity'): list(edge_types.keys())}
591
- if edge_types is not None
592
- else {('Entity', 'Entity'): []}
593
- )
871
+ with self.tracer.start_span('add_episode_bulk') as bulk_span:
872
+ bulk_span.add_attributes({'episode.count': len(bulk_episodes)})
594
873
 
595
- episodes = [
596
- await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
597
- if episode.uuid is not None
598
- else EpisodicNode(
599
- name=episode.name,
600
- labels=[],
601
- source=episode.source,
602
- content=episode.content,
603
- source_description=episode.source_description,
604
- group_id=group_id,
605
- created_at=now,
606
- valid_at=episode.reference_time,
874
+ try:
875
+ start = time()
876
+ now = utc_now()
877
+
878
+ # if group_id is None, use the default group id by the provider
879
+ if group_id is None:
880
+ group_id = get_default_group_id(self.driver.provider)
881
+ else:
882
+ validate_group_id(group_id)
883
+ if group_id != self.driver._database:
884
+ # if group_id is provided, use it as the database name
885
+ self.driver = self.driver.clone(database=group_id)
886
+ self.clients.driver = self.driver
887
+
888
+ # Create default edge type map
889
+ edge_type_map_default = (
890
+ {('Entity', 'Entity'): list(edge_types.keys())}
891
+ if edge_types is not None
892
+ else {('Entity', 'Entity'): []}
607
893
  )
608
- for episode in bulk_episodes
609
- ]
610
894
 
611
- episodes_by_uuid: dict[str, EpisodicNode] = {
612
- episode.uuid: episode for episode in episodes
613
- }
614
-
615
- # Save all episodes
616
- await add_nodes_and_edges_bulk(
617
- driver=self.driver,
618
- episodic_nodes=episodes,
619
- episodic_edges=[],
620
- entity_nodes=[],
621
- entity_edges=[],
622
- embedder=self.embedder,
623
- )
895
+ episodes = [
896
+ await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
897
+ if episode.uuid is not None
898
+ else EpisodicNode(
899
+ name=episode.name,
900
+ labels=[],
901
+ source=episode.source,
902
+ content=episode.content,
903
+ source_description=episode.source_description,
904
+ group_id=group_id,
905
+ created_at=now,
906
+ valid_at=episode.reference_time,
907
+ )
908
+ for episode in bulk_episodes
909
+ ]
624
910
 
625
- # Get previous episode context for each episode
626
- episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
911
+ # Save all episodes
912
+ await add_nodes_and_edges_bulk(
913
+ driver=self.driver,
914
+ episodic_nodes=episodes,
915
+ episodic_edges=[],
916
+ entity_nodes=[],
917
+ entity_edges=[],
918
+ embedder=self.embedder,
919
+ )
627
920
 
628
- # Extract all nodes and edges for each episode
629
- extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
630
- self.clients,
631
- episode_context,
632
- edge_type_map=edge_type_map or edge_type_map_default,
633
- edge_types=edge_types,
634
- entity_types=entity_types,
635
- excluded_entity_types=excluded_entity_types,
636
- )
921
+ # Get previous episode context for each episode
922
+ episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
637
923
 
638
- # Dedupe extracted nodes in memory
639
- nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
640
- self.clients, extracted_nodes_bulk, episode_context, entity_types
641
- )
924
+ # Extract and dedupe nodes and edges
925
+ (
926
+ nodes_by_episode,
927
+ uuid_map,
928
+ extracted_edges_bulk,
929
+ ) = await self._extract_and_dedupe_nodes_bulk(
930
+ episode_context,
931
+ edge_type_map or edge_type_map_default,
932
+ edge_types,
933
+ entity_types,
934
+ excluded_entity_types,
935
+ )
642
936
 
643
- episodic_edges: list[EpisodicEdge] = []
644
- for episode_uuid, nodes in nodes_by_episode.items():
645
- episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
937
+ # Create Episodic Edges
938
+ episodic_edges: list[EpisodicEdge] = []
939
+ for episode_uuid, nodes in nodes_by_episode.items():
940
+ episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
646
941
 
647
- # re-map edge pointers so that they don't point to discard dupe nodes
648
- extracted_edges_bulk_updated: list[list[EntityEdge]] = [
649
- resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
650
- ]
942
+ # Re-map edge pointers and dedupe edges
943
+ extracted_edges_bulk_updated: list[list[EntityEdge]] = [
944
+ resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
945
+ ]
651
946
 
652
- # Dedupe extracted edges in memory
653
- edges_by_episode = await dedupe_edges_bulk(
654
- self.clients,
655
- extracted_edges_bulk_updated,
656
- episode_context,
657
- [],
658
- edge_types or {},
659
- edge_type_map or edge_type_map_default,
660
- )
947
+ edges_by_episode = await dedupe_edges_bulk(
948
+ self.clients,
949
+ extracted_edges_bulk_updated,
950
+ episode_context,
951
+ [],
952
+ edge_types or {},
953
+ edge_type_map or edge_type_map_default,
954
+ )
661
955
 
662
- # Extract node attributes
663
- nodes_by_uuid: dict[str, EntityNode] = {
664
- node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
665
- }
956
+ # Resolve nodes and edges against the existing graph
957
+ (
958
+ final_hydrated_nodes,
959
+ resolved_edges,
960
+ invalidated_edges,
961
+ final_uuid_map,
962
+ ) = await self._resolve_nodes_and_edges_bulk(
963
+ nodes_by_episode,
964
+ edges_by_episode,
965
+ episode_context,
966
+ entity_types,
967
+ edge_types,
968
+ edge_type_map or edge_type_map_default,
969
+ episodes,
970
+ )
666
971
 
667
- extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
668
- for node in nodes_by_uuid.values():
669
- episode_uuids: list[str] = []
670
- for episode_uuid, mentioned_nodes in nodes_by_episode.items():
671
- for mentioned_node in mentioned_nodes:
672
- if node.uuid == mentioned_node.uuid:
673
- episode_uuids.append(episode_uuid)
674
- break
675
-
676
- episode_mentions: list[EpisodicNode] = [
677
- episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
678
- ]
679
- episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
680
-
681
- extract_attributes_params.append((node, episode_mentions))
682
-
683
- new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
684
- *[
685
- extract_attributes_from_nodes(
686
- self.clients,
687
- [params[0]],
688
- params[1][0],
689
- params[1][0:],
690
- entity_types,
691
- )
692
- for params in extract_attributes_params
693
- ]
694
- )
972
+ # Resolved pointers for episodic edges
973
+ resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)
974
+
975
+ # save data to KG
976
+ await add_nodes_and_edges_bulk(
977
+ self.driver,
978
+ episodes,
979
+ resolved_episodic_edges,
980
+ final_hydrated_nodes,
981
+ resolved_edges + invalidated_edges,
982
+ self.embedder,
983
+ )
695
984
 
696
- hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]
985
+ end = time()
697
986
 
698
- # TODO: Resolve nodes and edges against the existing graph
699
- edges_by_uuid: dict[str, EntityEdge] = {
700
- edge.uuid: edge for edges in edges_by_episode.values() for edge in edges
701
- }
987
+ # Add span attributes
988
+ bulk_span.add_attributes(
989
+ {
990
+ 'group_id': group_id,
991
+ 'node.count': len(final_hydrated_nodes),
992
+ 'edge.count': len(resolved_edges + invalidated_edges),
993
+ 'duration_ms': (end - start) * 1000,
994
+ }
995
+ )
702
996
 
703
- # save data to KG
704
- await add_nodes_and_edges_bulk(
705
- self.driver,
706
- episodes,
707
- episodic_edges,
708
- hydrated_nodes,
709
- list(edges_by_uuid.values()),
710
- self.embedder,
711
- )
997
+ logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
712
998
 
713
- end = time()
714
- logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
999
+ return AddBulkEpisodeResults(
1000
+ episodes=episodes,
1001
+ episodic_edges=resolved_episodic_edges,
1002
+ nodes=final_hydrated_nodes,
1003
+ edges=resolved_edges + invalidated_edges,
1004
+ communities=[],
1005
+ community_edges=[],
1006
+ )
715
1007
 
716
- except Exception as e:
717
- raise e
1008
+ except Exception as e:
1009
+ bulk_span.set_status('error', str(e))
1010
+ bulk_span.record_exception(e)
1011
+ raise e
718
1012
 
719
- async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
1013
+ @handle_multiple_group_ids
1014
+ async def build_communities(
1015
+ self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
1016
+ ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
720
1017
  """
721
1018
  Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
722
1019
  the content of these communities.
723
1020
  ----------
724
- query : list[str] | None
1021
+ group_ids : list[str] | None
725
1022
  Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
726
1023
  """
1024
+ if driver is None:
1025
+ driver = self.clients.driver
1026
+
727
1027
  # Clear existing communities
728
- await remove_communities(self.driver)
1028
+ await remove_communities(driver)
729
1029
 
730
1030
  community_nodes, community_edges = await build_communities(
731
- self.driver, self.llm_client, group_ids
1031
+ driver, self.llm_client, group_ids
732
1032
  )
733
1033
 
734
1034
  await semaphore_gather(
@@ -737,16 +1037,17 @@ class Graphiti:
737
1037
  )
738
1038
 
739
1039
  await semaphore_gather(
740
- *[node.save(self.driver) for node in community_nodes],
1040
+ *[node.save(driver) for node in community_nodes],
741
1041
  max_coroutines=self.max_coroutines,
742
1042
  )
743
1043
  await semaphore_gather(
744
- *[edge.save(self.driver) for edge in community_edges],
1044
+ *[edge.save(driver) for edge in community_edges],
745
1045
  max_coroutines=self.max_coroutines,
746
1046
  )
747
1047
 
748
- return community_nodes
1048
+ return community_nodes, community_edges
749
1049
 
1050
+ @handle_multiple_group_ids
750
1051
  async def search(
751
1052
  self,
752
1053
  query: str,
@@ -754,6 +1055,7 @@ class Graphiti:
754
1055
  group_ids: list[str] | None = None,
755
1056
  num_results=DEFAULT_SEARCH_LIMIT,
756
1057
  search_filter: SearchFilters | None = None,
1058
+ driver: GraphDriver | None = None,
757
1059
  ) -> list[EntityEdge]:
758
1060
  """
759
1061
  Perform a hybrid search on the knowledge graph.
@@ -800,7 +1102,8 @@ class Graphiti:
800
1102
  group_ids,
801
1103
  search_config,
802
1104
  search_filter if search_filter is not None else SearchFilters(),
803
- center_node_uuid,
1105
+ driver=driver,
1106
+ center_node_uuid=center_node_uuid,
804
1107
  )
805
1108
  ).edges
806
1109
 
@@ -820,6 +1123,7 @@ class Graphiti:
820
1123
  query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
821
1124
  )
822
1125
 
1126
+ @handle_multiple_group_ids
823
1127
  async def search_(
824
1128
  self,
825
1129
  query: str,
@@ -828,6 +1132,7 @@ class Graphiti:
828
1132
  center_node_uuid: str | None = None,
829
1133
  bfs_origin_node_uuids: list[str] | None = None,
830
1134
  search_filter: SearchFilters | None = None,
1135
+ driver: GraphDriver | None = None,
831
1136
  ) -> SearchResults:
832
1137
  """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
833
1138
  than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
@@ -844,6 +1149,7 @@ class Graphiti:
844
1149
  search_filter if search_filter is not None else SearchFilters(),
845
1150
  center_node_uuid,
846
1151
  bfs_origin_node_uuids,
1152
+ driver=driver,
847
1153
  )
848
1154
 
849
1155
  async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
@@ -858,9 +1164,11 @@ class Graphiti:
858
1164
 
859
1165
  nodes = await get_mentioned_nodes(self.driver, episodes)
860
1166
 
861
- return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
1167
+ return SearchResults(edges=edges, nodes=nodes)
862
1168
 
863
- async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
1169
+ async def add_triplet(
1170
+ self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
1171
+ ) -> AddTripletResults:
864
1172
  if source_node.name_embedding is None:
865
1173
  await source_node.generate_name_embedding(self.embedder)
866
1174
  if target_node.name_embedding is None:
@@ -868,17 +1176,35 @@ class Graphiti:
868
1176
  if edge.fact_embedding is None:
869
1177
  await edge.generate_embedding(self.embedder)
870
1178
 
871
- resolved_nodes, uuid_map, _ = await resolve_extracted_nodes(
1179
+ nodes, uuid_map, _ = await resolve_extracted_nodes(
872
1180
  self.clients,
873
1181
  [source_node, target_node],
874
1182
  )
875
1183
 
876
1184
  updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
877
1185
 
878
- related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
1186
+ valid_edges = await EntityEdge.get_between_nodes(
1187
+ self.driver, edge.source_node_uuid, edge.target_node_uuid
1188
+ )
1189
+
1190
+ related_edges = (
1191
+ await search(
1192
+ self.clients,
1193
+ updated_edge.fact,
1194
+ group_ids=[updated_edge.group_id],
1195
+ config=EDGE_HYBRID_SEARCH_RRF,
1196
+ search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
1197
+ )
1198
+ ).edges
879
1199
  existing_edges = (
880
- await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
881
- )[0]
1200
+ await search(
1201
+ self.clients,
1202
+ updated_edge.fact,
1203
+ group_ids=[updated_edge.group_id],
1204
+ config=EDGE_HYBRID_SEARCH_RRF,
1205
+ search_filter=SearchFilters(),
1206
+ )
1207
+ ).edges
882
1208
 
883
1209
  resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
884
1210
  self.llm_client,
@@ -894,11 +1220,17 @@ class Graphiti:
894
1220
  entity_edges=[],
895
1221
  group_id=edge.group_id,
896
1222
  ),
1223
+ None,
1224
+ None,
897
1225
  )
898
1226
 
899
- await add_nodes_and_edges_bulk(
900
- self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
901
- )
1227
+ edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
1228
+
1229
+ await create_entity_edge_embeddings(self.embedder, edges)
1230
+ await create_entity_node_embeddings(self.embedder, nodes)
1231
+
1232
+ await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
1233
+ return AddTripletResults(edges=edges, nodes=nodes)
902
1234
 
903
1235
  async def remove_episode(self, episode_uuid: str):
904
1236
  # Find the episode to be deleted
@@ -925,12 +1257,7 @@ class Graphiti:
925
1257
  if record['episode_count'] == 1:
926
1258
  nodes_to_delete.append(node)
927
1259
 
928
- await semaphore_gather(
929
- *[node.delete(self.driver) for node in nodes_to_delete],
930
- max_coroutines=self.max_coroutines,
931
- )
932
- await semaphore_gather(
933
- *[edge.delete(self.driver) for edge in edges_to_delete],
934
- max_coroutines=self.max_coroutines,
935
- )
1260
+ await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
1261
+ await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
1262
+
936
1263
  await episode.delete(self.driver)