graphiti-core 0.17.4__py3-none-any.whl → 0.25.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +70 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +635 -260
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +37 -15
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +92 -48
  24. graphiti_core/llm_client/openai_client.py +39 -9
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +24 -15
  33. graphiti_core/prompts/extract_nodes.py +76 -35
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +110 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/content_chunking.py +702 -0
  45. graphiti_core/utils/datetime_utils.py +13 -0
  46. graphiti_core/utils/maintenance/community_operations.py +62 -38
  47. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  48. graphiti_core/utils/maintenance/edge_operations.py +306 -156
  49. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  50. graphiti_core/utils/maintenance/node_operations.py +466 -206
  51. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  52. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  53. graphiti_core/utils/text_utils.py +53 -0
  54. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/METADATA +221 -87
  55. graphiti_core-0.25.3.dist-info/RECORD +87 -0
  56. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/WHEEL +1 -1
  57. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  58. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  59. {graphiti_core-0.17.4.dist-info → graphiti_core-0.25.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/graphiti.py CHANGED
@@ -24,18 +24,34 @@ from typing_extensions import LiteralString
 
 from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
+from graphiti_core.decorators import handle_multiple_group_ids
 from graphiti_core.driver.driver import GraphDriver
 from graphiti_core.driver.neo4j_driver import Neo4jDriver
-from graphiti_core.edges import EntityEdge, EpisodicEdge
+from graphiti_core.edges import (
+    CommunityEdge,
+    Edge,
+    EntityEdge,
+    EpisodicEdge,
+    create_entity_edge_embeddings,
+)
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
+from graphiti_core.errors import NodeNotFoundError
 from graphiti_core.graphiti_types import GraphitiClients
 from graphiti_core.helpers import (
+    get_default_group_id,
     semaphore_gather,
     validate_excluded_entity_types,
     validate_group_id,
 )
 from graphiti_core.llm_client import LLMClient, OpenAIClient
-from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.nodes import (
+    CommunityNode,
+    EntityNode,
+    EpisodeType,
+    EpisodicNode,
+    Node,
+    create_entity_node_embeddings,
+)
 from graphiti_core.search.search import SearchConfig, search
 from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
 from graphiti_core.search.search_config_recipes import (
@@ -46,11 +62,10 @@ from graphiti_core.search.search_config_recipes import (
 from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import (
     RELEVANT_SCHEMA_LIMIT,
-    get_edge_invalidation_candidates,
     get_mentioned_nodes,
-    get_relevant_edges,
 )
 from graphiti_core.telemetry import capture_event
+from graphiti_core.tracer import Tracer, create_tracer
 from graphiti_core.utils.bulk_utils import (
     RawEpisode,
     add_nodes_and_edges_bulk,
@@ -67,7 +82,6 @@ from graphiti_core.utils.maintenance.community_operations import (
     update_community,
 )
 from graphiti_core.utils.maintenance.edge_operations import (
-    build_duplicate_of_edges,
     build_episodic_edges,
     extract_edges,
     resolve_extracted_edge,
@@ -75,7 +89,6 @@ from graphiti_core.utils.maintenance.edge_operations import (
 )
 from graphiti_core.utils.maintenance.graph_data_operations import (
     EPISODE_WINDOW_LEN,
-    build_indices_and_constraints,
     retrieve_episodes,
 )
 from graphiti_core.utils.maintenance.node_operations import (
@@ -92,6 +105,23 @@ load_dotenv()
 
 class AddEpisodeResults(BaseModel):
     episode: EpisodicNode
+    episodic_edges: list[EpisodicEdge]
+    nodes: list[EntityNode]
+    edges: list[EntityEdge]
+    communities: list[CommunityNode]
+    community_edges: list[CommunityEdge]
+
+
+class AddBulkEpisodeResults(BaseModel):
+    episodes: list[EpisodicNode]
+    episodic_edges: list[EpisodicEdge]
+    nodes: list[EntityNode]
+    edges: list[EntityEdge]
+    communities: list[CommunityNode]
+    community_edges: list[CommunityEdge]
+
+
+class AddTripletResults(BaseModel):
     nodes: list[EntityNode]
     edges: list[EntityEdge]
 
@@ -108,11 +138,13 @@ class Graphiti:
         store_raw_episode_content: bool = True,
         graph_driver: GraphDriver | None = None,
         max_coroutines: int | None = None,
+        tracer: Tracer | None = None,
+        trace_span_prefix: str = 'graphiti',
     ):
         """
         Initialize a Graphiti instance.
 
-        This constructor sets up a connection to the Neo4j database and initializes
+        This constructor sets up a connection to a graph database and initializes
         the LLM client for natural language processing tasks.
 
         Parameters
@@ -140,6 +172,10 @@ class Graphiti:
         max_coroutines : int | None, optional
            The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
            If not set, the Graphiti default is used.
+        tracer : Tracer | None, optional
+            An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
+        trace_span_prefix : str, optional
+            Prefix to prepend to all span names. Defaults to 'graphiti'.
 
         Returns
         -------
@@ -147,11 +183,11 @@ class Graphiti:
 
         Notes
        -----
-        This method establishes a connection to the Neo4j database using the provided
+        This method establishes a connection to a graph database (Neo4j by default) using the provided
         credentials. It also sets up the LLM client, either using the provided client
         or by creating a default OpenAIClient.
 
-        The default database name is set to 'neo4j'. If a different database name
+        The default database name is defined during the driver's construction. If a different database name
         is required, it should be specified in the URI or set separately after
         initialization.
 
@@ -182,11 +218,18 @@ class Graphiti:
         else:
             self.cross_encoder = OpenAIRerankerClient()
 
+        # Initialize tracer
+        self.tracer = create_tracer(tracer, trace_span_prefix)
+
+        # Set tracer on clients
+        self.llm_client.set_tracer(self.tracer)
+
         self.clients = GraphitiClients(
             driver=self.driver,
             llm_client=self.llm_client,
             embedder=self.embedder,
             cross_encoder=self.cross_encoder,
+            tracer=self.tracer,
         )
 
         # Capture telemetry event
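
The constructor changes above add optional tracer and trace_span_prefix arguments, wrap them with create_tracer, and hand the resulting tracer to the LLM client and to GraphitiClients. A minimal usage sketch follows; the OpenTelemetry setup, the Neo4j URI, and the credentials are illustrative assumptions, not part of this diff:

    # Hedged sketch: wiring an OpenTelemetry tracer into Graphiti 0.25.x.
    # URI, credentials, and the service name are placeholder values.
    from opentelemetry import trace
    from graphiti_core import Graphiti

    otel_tracer = trace.get_tracer('my-service')

    graphiti = Graphiti(
        'bolt://localhost:7687',
        'neo4j',
        'password',
        tracer=otel_tracer,                  # omit to keep tracing as a no-op
        trace_span_prefix='myapp.graphiti',  # defaults to 'graphiti'
    )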
@@ -298,25 +341,249 @@ class Graphiti:
         -----
         This method should typically be called once during the initial setup of the
         knowledge graph or when updating the database schema. It uses the
-        `build_indices_and_constraints` function from the
-        `graphiti_core.utils.maintenance.graph_data_operations` module to perform
+        driver's `build_indices_and_constraints` method to perform
         the actual database operations.
 
         The specific indices and constraints created depend on the implementation
-        of the `build_indices_and_constraints` function. Refer to that function's
-        documentation for details on the exact database schema modifications.
+        of the driver's `build_indices_and_constraints` method. Refer to the specific
+        driver documentation for details on the exact database schema modifications.
 
         Caution: Running this method on a large existing database may take some time
         and could impact database performance during execution.
         """
-        await build_indices_and_constraints(self.driver, delete_existing)
+        await self.driver.build_indices_and_constraints(delete_existing)
+
+    async def _extract_and_resolve_nodes(
+        self,
+        episode: EpisodicNode,
+        previous_episodes: list[EpisodicNode],
+        entity_types: dict[str, type[BaseModel]] | None,
+        excluded_entity_types: list[str] | None,
+    ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
+        """Extract nodes from episode and resolve against existing graph."""
+        extracted_nodes = await extract_nodes(
+            self.clients, episode, previous_episodes, entity_types, excluded_entity_types
+        )
+
+        nodes, uuid_map, duplicates = await resolve_extracted_nodes(
+            self.clients,
+            extracted_nodes,
+            episode,
+            previous_episodes,
+            entity_types,
+        )
+
+        return nodes, uuid_map, duplicates
+
+    async def _extract_and_resolve_edges(
+        self,
+        episode: EpisodicNode,
+        extracted_nodes: list[EntityNode],
+        previous_episodes: list[EpisodicNode],
+        edge_type_map: dict[tuple[str, str], list[str]],
+        group_id: str,
+        edge_types: dict[str, type[BaseModel]] | None,
+        nodes: list[EntityNode],
+        uuid_map: dict[str, str],
+        custom_extraction_instructions: str | None = None,
+    ) -> tuple[list[EntityEdge], list[EntityEdge]]:
+        """Extract edges from episode and resolve against existing graph."""
+        extracted_edges = await extract_edges(
+            self.clients,
+            episode,
+            extracted_nodes,
+            previous_episodes,
+            edge_type_map,
+            group_id,
+            edge_types,
+            custom_extraction_instructions,
+        )
+
+        edges = resolve_edge_pointers(extracted_edges, uuid_map)
+
+        resolved_edges, invalidated_edges = await resolve_extracted_edges(
+            self.clients,
+            edges,
+            episode,
+            nodes,
+            edge_types or {},
+            edge_type_map,
+        )
+
+        return resolved_edges, invalidated_edges
+
+    async def _process_episode_data(
+        self,
+        episode: EpisodicNode,
+        nodes: list[EntityNode],
+        entity_edges: list[EntityEdge],
+        now: datetime,
+    ) -> tuple[list[EpisodicEdge], EpisodicNode]:
+        """Process and save episode data to the graph."""
+        episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
+        episode.entity_edges = [edge.uuid for edge in entity_edges]
+
+        if not self.store_raw_episode_content:
+            episode.content = ''
+
+        await add_nodes_and_edges_bulk(
+            self.driver,
+            [episode],
+            episodic_edges,
+            nodes,
+            entity_edges,
+            self.embedder,
+        )
+
+        return episodic_edges, episode
+
+    async def _extract_and_dedupe_nodes_bulk(
+        self,
+        episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
+        edge_type_map: dict[tuple[str, str], list[str]],
+        edge_types: dict[str, type[BaseModel]] | None,
+        entity_types: dict[str, type[BaseModel]] | None,
+        excluded_entity_types: list[str] | None,
+    ) -> tuple[
+        dict[str, list[EntityNode]],
+        dict[str, str],
+        list[list[EntityEdge]],
+    ]:
+        """Extract nodes and edges from all episodes and deduplicate."""
+        # Extract all nodes and edges for each episode
+        extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
+            self.clients,
+            episode_context,
+            edge_type_map=edge_type_map,
+            edge_types=edge_types,
+            entity_types=entity_types,
+            excluded_entity_types=excluded_entity_types,
+        )
+
+        # Dedupe extracted nodes in memory
+        nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
+            self.clients, extracted_nodes_bulk, episode_context, entity_types
+        )
+
+        return nodes_by_episode, uuid_map, extracted_edges_bulk
+
+    async def _resolve_nodes_and_edges_bulk(
+        self,
+        nodes_by_episode: dict[str, list[EntityNode]],
+        edges_by_episode: dict[str, list[EntityEdge]],
+        episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
+        entity_types: dict[str, type[BaseModel]] | None,
+        edge_types: dict[str, type[BaseModel]] | None,
+        edge_type_map: dict[tuple[str, str], list[str]],
+        episodes: list[EpisodicNode],
+    ) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
+        """Resolve nodes and edges against the existing graph."""
+        nodes_by_uuid: dict[str, EntityNode] = {
+            node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
+        }
+
+        # Get unique nodes per episode
+        nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
+        nodes_uuid_set: set[str] = set()
+        for episode, _ in episode_context:
+            nodes_by_episode_unique[episode.uuid] = []
+            nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
+            for node in nodes:
+                if node.uuid not in nodes_uuid_set:
+                    nodes_by_episode_unique[episode.uuid].append(node)
+                    nodes_uuid_set.add(node.uuid)
+
+        # Resolve nodes
+        node_results = await semaphore_gather(
+            *[
+                resolve_extracted_nodes(
+                    self.clients,
+                    nodes_by_episode_unique[episode.uuid],
+                    episode,
+                    previous_episodes,
+                    entity_types,
+                )
+                for episode, previous_episodes in episode_context
+            ]
+        )
+
+        resolved_nodes: list[EntityNode] = []
+        uuid_map: dict[str, str] = {}
+        for result in node_results:
+            resolved_nodes.extend(result[0])
+            uuid_map.update(result[1])
+
+        # Update nodes_by_uuid with resolved nodes
+        for resolved_node in resolved_nodes:
+            nodes_by_uuid[resolved_node.uuid] = resolved_node
+
+        # Update nodes_by_episode_unique with resolved pointers
+        for episode_uuid, nodes in nodes_by_episode_unique.items():
+            updated_nodes: list[EntityNode] = []
+            for node in nodes:
+                updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
+                updated_node = nodes_by_uuid[updated_node_uuid]
+                updated_nodes.append(updated_node)
+            nodes_by_episode_unique[episode_uuid] = updated_nodes
+
+        # Extract attributes for resolved nodes
+        hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
+            *[
+                extract_attributes_from_nodes(
+                    self.clients,
+                    nodes_by_episode_unique[episode.uuid],
+                    episode,
+                    previous_episodes,
+                    entity_types,
+                )
+                for episode, previous_episodes in episode_context
+            ]
+        )
+
+        final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]
+
+        # Resolve edges with updated pointers
+        edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
+        edges_uuid_set: set[str] = set()
+        for episode_uuid, edges in edges_by_episode.items():
+            edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
+            edges_by_episode_unique[episode_uuid] = []
+
+            for edge in edges_with_updated_pointers:
+                if edge.uuid not in edges_uuid_set:
+                    edges_by_episode_unique[episode_uuid].append(edge)
+                    edges_uuid_set.add(edge.uuid)
+
+        edge_results = await semaphore_gather(
+            *[
+                resolve_extracted_edges(
+                    self.clients,
+                    edges_by_episode_unique[episode.uuid],
+                    episode,
+                    final_hydrated_nodes,
+                    edge_types or {},
+                    edge_type_map,
+                )
+                for episode in episodes
+            ]
+        )
 
+        resolved_edges: list[EntityEdge] = []
+        invalidated_edges: list[EntityEdge] = []
+        for result in edge_results:
+            resolved_edges.extend(result[0])
+            invalidated_edges.extend(result[1])
+
+        return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
+
+    @handle_multiple_group_ids
     async def retrieve_episodes(
         self,
         reference_time: datetime,
         last_n: int = EPISODE_WINDOW_LEN,
         group_ids: list[str] | None = None,
         source: EpisodeType | None = None,
+        driver: GraphDriver | None = None,
     ) -> list[EpisodicNode]:
         """
         Retrieve the last n episodic nodes from the graph.
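
As the first part of the hunk above shows, schema setup is now delegated to the active driver's own build_indices_and_constraints, so the same call works for whichever backend the instance was built with (Neo4j, FalkorDB, Kuzu, or Neptune, per the driver modules in this release). A minimal sketch of the unchanged call site, assuming an already-constructed graphiti instance:

    # Hedged sketch: one-time schema setup; delete_existing is the flag
    # discussed in the docstring above.
    async def setup_schema(graphiti) -> None:
        await graphiti.build_indices_and_constraints(delete_existing=False)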
@@ -343,7 +610,10 @@ class Graphiti:
         The actual retrieval is performed by the `retrieve_episodes` function
         from the `graphiti_core.utils` module.
         """
-        return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source)
+        if driver is None:
+            driver = self.clients.driver
+
+        return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
 
     async def add_episode(
         self,
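
retrieve_episodes now accepts an optional driver argument (falling back to self.clients.driver) and is wrapped by the new handle_multiple_group_ids decorator. A sketch of a typical call, with an illustrative group id:

    # Hedged sketch of the updated retrieve_episodes signature.
    from datetime import datetime, timezone

    async def recent_context(graphiti):
        # Last five episodes before "now" for one group; driver can be omitted.
        return await graphiti.retrieve_episodes(
            reference_time=datetime.now(timezone.utc),
            last_n=5,
            group_ids=['my-group'],
        )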
@@ -352,14 +622,15 @@
         source_description: str,
         reference_time: datetime,
         source: EpisodeType = EpisodeType.message,
-        group_id: str = '',
+        group_id: str | None = None,
         uuid: str | None = None,
         update_communities: bool = False,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
         excluded_entity_types: list[str] | None = None,
         previous_episode_uuids: list[str] | None = None,
-        edge_types: dict[str, BaseModel] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
+        custom_extraction_instructions: str | None = None,
    ) -> AddEpisodeResults:
        """
        Process an episode and update the graph.
@@ -394,6 +665,9 @@ class Graphiti:
         previous_episode_uuids : list[str] | None
             Optional. list of episode uuids to use as the previous episodes. If this is not provided,
             the most recent episodes by created_at date will be used.
+        custom_extraction_instructions : str | None
+            Optional. Custom extraction instructions string to be included in the extract entities and extract edges prompts.
+            This allows for additional instructions or context to guide the extraction process.
 
         Returns
         -------
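
Taken together, the signature and docstring changes above mean group_id may now be omitted (a provider default is used), typed entity and edge models are passed as type[BaseModel], and an optional custom_extraction_instructions string is forwarded to the extraction prompts, with the richer AddEpisodeResults coming back. A hedged sketch of one call; names and strings are illustrative:

    from datetime import datetime, timezone
    from graphiti_core.nodes import EpisodeType

    async def ingest(graphiti):
        results = await graphiti.add_episode(
            name='support-ticket-42',
            episode_body='Alice reported a login failure on the billing page.',
            source_description='helpdesk transcript',
            reference_time=datetime.now(timezone.utc),
            source=EpisodeType.message,
            custom_extraction_instructions='Treat product names as entities.',
        )
        # episodic_edges, communities, and community_edges are new result
        # fields in this release.
        return results.nodes, results.edges, results.episodic_edges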
@@ -416,133 +690,161 @@
             background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
             return {"message": "Episode processing started"}
         """
-        try:
-            start = time()
-            now = utc_now()
+        start = time()
+        now = utc_now()
+
+        validate_entity_types(entity_types)
+        validate_excluded_entity_types(excluded_entity_types, entity_types)
 
-            validate_entity_types(entity_types)
-            validate_excluded_entity_types(excluded_entity_types, entity_types)
+        if group_id is None:
+            # if group_id is None, use the default group id by the provider
+            # and the preset database name will be used
+            group_id = get_default_group_id(self.driver.provider)
+        else:
             validate_group_id(group_id)
+            if group_id != self.driver._database:
+                # if group_id is provided, use it as the database name
+                self.driver = self.driver.clone(database=group_id)
+                self.clients.driver = self.driver
 
-            previous_episodes = (
-                await self.retrieve_episodes(
-                    reference_time,
-                    last_n=RELEVANT_SCHEMA_LIMIT,
-                    group_ids=[group_id],
-                    source=source,
+        with self.tracer.start_span('add_episode') as span:
+            try:
+                # Retrieve previous episodes for context
+                previous_episodes = (
+                    await self.retrieve_episodes(
+                        reference_time,
+                        last_n=RELEVANT_SCHEMA_LIMIT,
+                        group_ids=[group_id],
+                        source=source,
+                    )
+                    if previous_episode_uuids is None
+                    else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
                 )
-                if previous_episode_uuids is None
-                else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
-            )
 
-            episode = (
-                await EpisodicNode.get_by_uuid(self.driver, uuid)
-                if uuid is not None
-                else EpisodicNode(
-                    name=name,
-                    group_id=group_id,
-                    labels=[],
-                    source=source,
-                    content=episode_body,
-                    source_description=source_description,
-                    created_at=now,
-                    valid_at=reference_time,
+                # Get or create episode
+                episode = (
+                    await EpisodicNode.get_by_uuid(self.driver, uuid)
+                    if uuid is not None
+                    else EpisodicNode(
+                        name=name,
+                        group_id=group_id,
+                        labels=[],
+                        source=source,
+                        content=episode_body,
+                        source_description=source_description,
+                        created_at=now,
+                        valid_at=reference_time,
+                    )
                 )
-            )
 
-            # Create default edge type map
-            edge_type_map_default = (
-                {('Entity', 'Entity'): list(edge_types.keys())}
-                if edge_types is not None
-                else {('Entity', 'Entity'): []}
-            )
-
-            # Extract entities as nodes
+                # Create default edge type map
+                edge_type_map_default = (
+                    {('Entity', 'Entity'): list(edge_types.keys())}
+                    if edge_types is not None
+                    else {('Entity', 'Entity'): []}
+                )
 
-            extracted_nodes = await extract_nodes(
-                self.clients, episode, previous_episodes, entity_types, excluded_entity_types
-            )
+                # Extract and resolve nodes
+                extracted_nodes = await extract_nodes(
+                    self.clients,
+                    episode,
+                    previous_episodes,
+                    entity_types,
+                    excluded_entity_types,
+                    custom_extraction_instructions,
+                )
 
-            # Extract edges and resolve nodes
-            (nodes, uuid_map, node_duplicates), extracted_edges = await semaphore_gather(
-                resolve_extracted_nodes(
+                nodes, uuid_map, _ = await resolve_extracted_nodes(
                     self.clients,
                     extracted_nodes,
                     episode,
                     previous_episodes,
                     entity_types,
-                ),
-                extract_edges(
-                    self.clients,
+                )
+
+                # Extract and resolve edges in parallel with attribute extraction
+                resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
                     episode,
                     extracted_nodes,
                     previous_episodes,
                     edge_type_map or edge_type_map_default,
                     group_id,
                     edge_types,
-                ),
-                max_coroutines=self.max_coroutines,
-            )
-
-            edges = resolve_edge_pointers(extracted_edges, uuid_map)
-
-            (resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
-                resolve_extracted_edges(
-                    self.clients,
-                    edges,
-                    episode,
                     nodes,
-                    edge_types or {},
-                    edge_type_map or edge_type_map_default,
-                ),
-                extract_attributes_from_nodes(
-                    self.clients, nodes, episode, previous_episodes, entity_types
-                ),
-                max_coroutines=self.max_coroutines,
-            )
+                    uuid_map,
+                    custom_extraction_instructions,
+                )
 
-            duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
+                # Extract node attributes
+                hydrated_nodes = await extract_attributes_from_nodes(
+                    self.clients, nodes, episode, previous_episodes, entity_types
+                )
 
-            entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
+                entity_edges = resolved_edges + invalidated_edges
 
-            episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
+                # Process and save episode data
+                episodic_edges, episode = await self._process_episode_data(
+                    episode, hydrated_nodes, entity_edges, now
+                )
 
-            episode.entity_edges = [edge.uuid for edge in entity_edges]
+                # Update communities if requested
+                communities = []
+                community_edges = []
+                if update_communities:
+                    communities, community_edges = await semaphore_gather(
+                        *[
+                            update_community(self.driver, self.llm_client, self.embedder, node)
+                            for node in nodes
+                        ],
+                        max_coroutines=self.max_coroutines,
+                    )
 
-            if not self.store_raw_episode_content:
-                episode.content = ''
+                end = time()
+
+                # Add span attributes
+                span.add_attributes(
+                    {
+                        'episode.uuid': episode.uuid,
+                        'episode.source': source.value,
+                        'episode.reference_time': reference_time.isoformat(),
+                        'group_id': group_id,
+                        'node.count': len(hydrated_nodes),
+                        'edge.count': len(entity_edges),
+                        'edge.invalidated_count': len(invalidated_edges),
+                        'previous_episodes.count': len(previous_episodes),
+                        'entity_types.count': len(entity_types) if entity_types else 0,
+                        'edge_types.count': len(edge_types) if edge_types else 0,
+                        'update_communities': update_communities,
+                        'communities.count': len(communities) if update_communities else 0,
+                        'duration_ms': (end - start) * 1000,
+                    }
+                )
 
-            await add_nodes_and_edges_bulk(
-                self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
-            )
+                logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
 
-            # Update any communities
-            if update_communities:
-                await semaphore_gather(
-                    *[
-                        update_community(self.driver, self.llm_client, self.embedder, node)
-                        for node in nodes
-                    ],
-                    max_coroutines=self.max_coroutines,
+                return AddEpisodeResults(
+                    episode=episode,
+                    episodic_edges=episodic_edges,
+                    nodes=hydrated_nodes,
+                    edges=entity_edges,
+                    communities=communities,
+                    community_edges=community_edges,
                 )
-            end = time()
-            logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
-
-            return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
 
-        except Exception as e:
-            raise e
+            except Exception as e:
+                span.set_status('error', str(e))
+                span.record_exception(e)
+                raise e
 
-    ##### EXPERIMENTAL #####
     async def add_episode_bulk(
         self,
         bulk_episodes: list[RawEpisode],
-        group_id: str = '',
-        entity_types: dict[str, BaseModel] | None = None,
+        group_id: str | None = None,
+        entity_types: dict[str, type[BaseModel]] | None = None,
         excluded_entity_types: list[str] | None = None,
-        edge_types: dict[str, BaseModel] | None = None,
+        edge_types: dict[str, type[BaseModel]] | None = None,
         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
-    ):
+    ) -> AddBulkEpisodeResults:
         """
         Process multiple episodes in bulk and update the graph.
 
@@ -558,7 +860,7 @@
 
         Returns
         -------
-        None
+        AddBulkEpisodeResults
 
         Notes
         -----
@@ -579,156 +881,167 @@
         If these operations are required, use the `add_episode` method instead for each
         individual episode.
         """
-        try:
-            start = time()
-            now = utc_now()
+        with self.tracer.start_span('add_episode_bulk') as bulk_span:
+            bulk_span.add_attributes({'episode.count': len(bulk_episodes)})
 
-            validate_group_id(group_id)
-
-            # Create default edge type map
-            edge_type_map_default = (
-                {('Entity', 'Entity'): list(edge_types.keys())}
-                if edge_types is not None
-                else {('Entity', 'Entity'): []}
-            )
-
-            episodes = [
-                await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
-                if episode.uuid is not None
-                else EpisodicNode(
-                    name=episode.name,
-                    labels=[],
-                    source=episode.source,
-                    content=episode.content,
-                    source_description=episode.source_description,
-                    group_id=group_id,
-                    created_at=now,
-                    valid_at=episode.reference_time,
+            try:
+                start = time()
+                now = utc_now()
+
+                # if group_id is None, use the default group id by the provider
+                if group_id is None:
+                    group_id = get_default_group_id(self.driver.provider)
+                else:
+                    validate_group_id(group_id)
+                    if group_id != self.driver._database:
+                        # if group_id is provided, use it as the database name
+                        self.driver = self.driver.clone(database=group_id)
+                        self.clients.driver = self.driver
+
+                # Create default edge type map
+                edge_type_map_default = (
+                    {('Entity', 'Entity'): list(edge_types.keys())}
+                    if edge_types is not None
+                    else {('Entity', 'Entity'): []}
                 )
-                for episode in bulk_episodes
-            ]
 
-            episodes_by_uuid: dict[str, EpisodicNode] = {
-                episode.uuid: episode for episode in episodes
-            }
+                episodes = [
+                    await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
+                    if episode.uuid is not None
+                    else EpisodicNode(
+                        name=episode.name,
+                        labels=[],
+                        source=episode.source,
+                        content=episode.content,
+                        source_description=episode.source_description,
+                        group_id=group_id,
+                        created_at=now,
+                        valid_at=episode.reference_time,
+                    )
+                    for episode in bulk_episodes
+                ]
 
-            # Save all episodes
-            await add_nodes_and_edges_bulk(
-                driver=self.driver,
-                episodic_nodes=episodes,
-                episodic_edges=[],
-                entity_nodes=[],
-                entity_edges=[],
-                embedder=self.embedder,
-            )
+                # Save all episodes
+                await add_nodes_and_edges_bulk(
+                    driver=self.driver,
+                    episodic_nodes=episodes,
+                    episodic_edges=[],
+                    entity_nodes=[],
+                    entity_edges=[],
+                    embedder=self.embedder,
+                )
 
-            # Get previous episode context for each episode
-            episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
+                # Get previous episode context for each episode
+                episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
 
-            # Extract all nodes and edges for each episode
-            extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
-                self.clients,
-                episode_context,
-                edge_type_map=edge_type_map or edge_type_map_default,
-                edge_types=edge_types,
-                entity_types=entity_types,
-                excluded_entity_types=excluded_entity_types,
-            )
-
-            # Dedupe extracted nodes in memory
-            nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
-                self.clients, extracted_nodes_bulk, episode_context, entity_types
-            )
+                # Extract and dedupe nodes and edges
+                (
+                    nodes_by_episode,
+                    uuid_map,
+                    extracted_edges_bulk,
+                ) = await self._extract_and_dedupe_nodes_bulk(
+                    episode_context,
+                    edge_type_map or edge_type_map_default,
+                    edge_types,
+                    entity_types,
+                    excluded_entity_types,
+                )
 
-            episodic_edges: list[EpisodicEdge] = []
-            for episode_uuid, nodes in nodes_by_episode.items():
-                episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
+                # Create Episodic Edges
+                episodic_edges: list[EpisodicEdge] = []
+                for episode_uuid, nodes in nodes_by_episode.items():
+                    episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
 
-            # re-map edge pointers so that they don't point to discard dupe nodes
-            extracted_edges_bulk_updated: list[list[EntityEdge]] = [
-                resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
-            ]
+                # Re-map edge pointers and dedupe edges
+                extracted_edges_bulk_updated: list[list[EntityEdge]] = [
+                    resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
+                ]
 
-            # Dedupe extracted edges in memory
-            edges_by_episode = await dedupe_edges_bulk(
-                self.clients,
-                extracted_edges_bulk_updated,
-                episode_context,
-                [],
-                edge_types or {},
-                edge_type_map or edge_type_map_default,
-            )
+                edges_by_episode = await dedupe_edges_bulk(
+                    self.clients,
+                    extracted_edges_bulk_updated,
+                    episode_context,
+                    [],
+                    edge_types or {},
+                    edge_type_map or edge_type_map_default,
+                )
 
-            # Extract node attributes
-            nodes_by_uuid: dict[str, EntityNode] = {
-                node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
-            }
+                # Resolve nodes and edges against the existing graph
+                (
+                    final_hydrated_nodes,
+                    resolved_edges,
+                    invalidated_edges,
+                    final_uuid_map,
+                ) = await self._resolve_nodes_and_edges_bulk(
+                    nodes_by_episode,
+                    edges_by_episode,
+                    episode_context,
+                    entity_types,
+                    edge_types,
+                    edge_type_map or edge_type_map_default,
+                    episodes,
+                )
 
-            extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
-            for node in nodes_by_uuid.values():
-                episode_uuids: list[str] = []
-                for episode_uuid, mentioned_nodes in nodes_by_episode.items():
-                    for mentioned_node in mentioned_nodes:
-                        if node.uuid == mentioned_node.uuid:
-                            episode_uuids.append(episode_uuid)
-                            break
-
-                episode_mentions: list[EpisodicNode] = [
-                    episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
-                ]
-                episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
-
-                extract_attributes_params.append((node, episode_mentions))
-
-            new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
-                *[
-                    extract_attributes_from_nodes(
-                        self.clients,
-                        [params[0]],
-                        params[1][0],
-                        params[1][0:],
-                        entity_types,
-                    )
-                    for params in extract_attributes_params
-                ]
-            )
+                # Resolved pointers for episodic edges
+                resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)
+
+                # save data to KG
+                await add_nodes_and_edges_bulk(
+                    self.driver,
+                    episodes,
+                    resolved_episodic_edges,
+                    final_hydrated_nodes,
+                    resolved_edges + invalidated_edges,
+                    self.embedder,
+                )
 
-            hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]
+                end = time()
 
-            # TODO: Resolve nodes and edges against the existing graph
-            edges_by_uuid: dict[str, EntityEdge] = {
-                edge.uuid: edge for edges in edges_by_episode.values() for edge in edges
-            }
+                # Add span attributes
+                bulk_span.add_attributes(
+                    {
+                        'group_id': group_id,
+                        'node.count': len(final_hydrated_nodes),
+                        'edge.count': len(resolved_edges + invalidated_edges),
+                        'duration_ms': (end - start) * 1000,
+                    }
+                )
 
-            # save data to KG
-            await add_nodes_and_edges_bulk(
-                self.driver,
-                episodes,
-                episodic_edges,
-                hydrated_nodes,
-                list(edges_by_uuid.values()),
-                self.embedder,
-            )
+                logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
 
-            end = time()
-            logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
+                return AddBulkEpisodeResults(
+                    episodes=episodes,
+                    episodic_edges=resolved_episodic_edges,
+                    nodes=final_hydrated_nodes,
+                    edges=resolved_edges + invalidated_edges,
+                    communities=[],
+                    community_edges=[],
+                )
 
-        except Exception as e:
-            raise e
+            except Exception as e:
+                bulk_span.set_status('error', str(e))
+                bulk_span.record_exception(e)
+                raise e
 
-    async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
+    @handle_multiple_group_ids
+    async def build_communities(
+        self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
+    ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
         """
         Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
         the content of these communities.
         ----------
-        query : list[str] | None
+        group_ids : list[str] | None
            Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
         """
+        if driver is None:
+            driver = self.clients.driver
+
         # Clear existing communities
-        await remove_communities(self.driver)
+        await remove_communities(driver)
 
         community_nodes, community_edges = await build_communities(
-            self.driver, self.llm_client, group_ids
+            driver, self.llm_client, group_ids
         )
 
         await semaphore_gather(
@@ -737,16 +1050,17 @@
         )
 
         await semaphore_gather(
-            *[node.save(self.driver) for node in community_nodes],
+            *[node.save(driver) for node in community_nodes],
             max_coroutines=self.max_coroutines,
         )
         await semaphore_gather(
-            *[edge.save(self.driver) for edge in community_edges],
+            *[edge.save(driver) for edge in community_edges],
             max_coroutines=self.max_coroutines,
         )
 
-        return community_nodes
+        return community_nodes, community_edges
 
+    @handle_multiple_group_ids
     async def search(
         self,
         query: str,
@@ -754,6 +1068,7 @@
         group_ids: list[str] | None = None,
         num_results=DEFAULT_SEARCH_LIMIT,
         search_filter: SearchFilters | None = None,
+        driver: GraphDriver | None = None,
     ) -> list[EntityEdge]:
         """
         Perform a hybrid search on the knowledge graph.
@@ -800,7 +1115,8 @@
                 group_ids,
                 search_config,
                 search_filter if search_filter is not None else SearchFilters(),
-                center_node_uuid,
+                driver=driver,
+                center_node_uuid=center_node_uuid,
             )
         ).edges
 
@@ -820,6 +1136,7 @@
             query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
         )
 
+    @handle_multiple_group_ids
     async def search_(
         self,
         query: str,
@@ -828,6 +1145,7 @@
         center_node_uuid: str | None = None,
         bfs_origin_node_uuids: list[str] | None = None,
         search_filter: SearchFilters | None = None,
+        driver: GraphDriver | None = None,
     ) -> SearchResults:
         """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
         than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
@@ -844,6 +1162,7 @@
             search_filter if search_filter is not None else SearchFilters(),
             center_node_uuid,
             bfs_origin_node_uuids,
+            driver=driver,
         )
 
     async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
@@ -858,9 +1177,11 @@
 
         nodes = await get_mentioned_nodes(self.driver, episodes)
 
-        return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
+        return SearchResults(edges=edges, nodes=nodes)
 
-    async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
+    async def add_triplet(
+        self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
+    ) -> AddTripletResults:
         if source_node.name_embedding is None:
             await source_node.generate_name_embedding(self.embedder)
         if target_node.name_embedding is None:
@@ -868,21 +1189,74 @@
         if edge.fact_embedding is None:
             await edge.generate_embedding(self.embedder)
 
-        resolved_nodes, uuid_map, _ = await resolve_extracted_nodes(
-            self.clients,
-            [source_node, target_node],
-        )
+        try:
+            resolved_source = await EntityNode.get_by_uuid(self.driver, source_node.uuid)
+        except NodeNotFoundError:
+            resolved_source_nodes, _, _ = await resolve_extracted_nodes(
+                self.clients,
+                [source_node],
+            )
+            resolved_source = resolved_source_nodes[0]
 
-        updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
+        try:
+            resolved_target = await EntityNode.get_by_uuid(self.driver, target_node.uuid)
+        except NodeNotFoundError:
+            resolved_target_nodes, _, _ = await resolve_extracted_nodes(
+                self.clients,
+                [target_node],
+            )
+            resolved_target = resolved_target_nodes[0]
+
+        nodes = [resolved_source, resolved_target]
+
+        # Merge user-provided properties from original nodes into resolved nodes (excluding uuid)
+        # Update attributes dictionary (merge rather than replace)
+        if source_node.attributes:
+            resolved_source.attributes.update(source_node.attributes)
+        if target_node.attributes:
+            resolved_target.attributes.update(target_node.attributes)
+
+        # Update summary if provided by user (non-empty string)
+        if source_node.summary:
+            resolved_source.summary = source_node.summary
+        if target_node.summary:
+            resolved_target.summary = target_node.summary
+
+        # Update labels (merge with existing)
+        if source_node.labels:
+            resolved_source.labels = list(set(resolved_source.labels) | set(source_node.labels))
+        if target_node.labels:
+            resolved_target.labels = list(set(resolved_target.labels) | set(target_node.labels))
+
+        edge.source_node_uuid = resolved_source.uuid
+        edge.target_node_uuid = resolved_target.uuid
+
+        valid_edges = await EntityEdge.get_between_nodes(
+            self.driver, edge.source_node_uuid, edge.target_node_uuid
+        )
 
-        related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
+        related_edges = (
+            await search(
+                self.clients,
+                edge.fact,
+                group_ids=[edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
+            )
+        ).edges
         existing_edges = (
-            await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
-        )[0]
+            await search(
+                self.clients,
+                edge.fact,
+                group_ids=[edge.group_id],
+                config=EDGE_HYBRID_SEARCH_RRF,
+                search_filter=SearchFilters(),
+            )
+        ).edges
 
         resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
             self.llm_client,
-            updated_edge,
+            edge,
             related_edges,
             existing_edges,
             EpisodicNode(
@@ -894,11 +1268,17 @@
                 entity_edges=[],
                 group_id=edge.group_id,
             ),
+            None,
+            None,
         )
 
-        await add_nodes_and_edges_bulk(
-            self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder
-        )
+        edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
+
+        await create_entity_edge_embeddings(self.embedder, edges)
+        await create_entity_node_embeddings(self.embedder, nodes)
+
+        await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
+        return AddTripletResults(edges=edges, nodes=nodes)
 
     async def remove_episode(self, episode_uuid: str):
         # Find the episode to be deleted
@@ -925,12 +1305,7 @@
             if record['episode_count'] == 1:
                 nodes_to_delete.append(node)
 
-        await semaphore_gather(
-            *[node.delete(self.driver) for node in nodes_to_delete],
-            max_coroutines=self.max_coroutines,
-        )
-        await semaphore_gather(
-            *[edge.delete(self.driver) for edge in edges_to_delete],
-            max_coroutines=self.max_coroutines,
-        )
+        await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
+        await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
+
         await episode.delete(self.driver)
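
Two of the return-type changes above, as a caller sees them: build_communities now returns the community edges alongside the nodes, and add_triplet returns an AddTripletResults. A hedged sketch of the former; the group id is illustrative:

    async def rebuild_communities(graphiti):
        # build_communities now returns (community_nodes, community_edges).
        community_nodes, community_edges = await graphiti.build_communities(
            group_ids=['my-group'],
        )
        return community_nodes, community_edges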