graphiti-core 0.22.0rc3__py3-none-any.whl → 0.22.0rc5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.


graphiti_core/graphiti.py CHANGED
@@ -63,6 +63,7 @@ from graphiti_core.search.search_utils import (
     get_mentioned_nodes,
 )
 from graphiti_core.telemetry import capture_event
+from graphiti_core.tracer import Tracer, create_tracer
 from graphiti_core.utils.bulk_utils import (
     RawEpisode,
     add_nodes_and_edges_bulk,
@@ -136,6 +137,8 @@ class Graphiti:
         store_raw_episode_content: bool = True,
         graph_driver: GraphDriver | None = None,
         max_coroutines: int | None = None,
+        tracer: Tracer | None = None,
+        trace_span_prefix: str = 'graphiti',
     ):
         """
         Initialize a Graphiti instance.
@@ -168,6 +171,10 @@ class Graphiti:
         max_coroutines : int | None, optional
             The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
             If not set, the Graphiti default is used.
+        tracer : Tracer | None, optional
+            An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
+        trace_span_prefix : str, optional
+            Prefix to prepend to all span names. Defaults to 'graphiti'.

         Returns
         -------
@@ -210,11 +217,18 @@ class Graphiti:
         else:
             self.cross_encoder = OpenAIRerankerClient()

+        # Initialize tracer
+        self.tracer = create_tracer(tracer, trace_span_prefix)
+
+        # Set tracer on clients
+        self.llm_client.set_tracer(self.tracer)
+
         self.clients = GraphitiClients(
             driver=self.driver,
             llm_client=self.llm_client,
             embedder=self.embedder,
             cross_encoder=self.cross_encoder,
+            tracer=self.tracer,
         )

         # Capture telemetry event
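
With the tracer wired through the constructor, callers can enable distributed tracing by passing any OpenTelemetry tracer; it is propagated to the LLM client and the shared GraphitiClients container. A minimal sketch of that wiring from application code follows; the Neo4j URI, the credentials, and the top-level `from graphiti_core import Graphiti` import path are illustrative assumptions rather than part of this diff:

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

from graphiti_core import Graphiti

# Export spans to stdout; any OpenTelemetry-compatible exporter works the same way.
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)

graphiti = Graphiti(
    'bolt://localhost:7687',  # placeholder connection details
    'neo4j',
    'password',
    tracer=trace.get_tracer(__name__),  # new in this release
    trace_span_prefix='graphiti',  # default value, shown for clarity
)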
@@ -339,6 +353,227 @@ class Graphiti:
         """
         await build_indices_and_constraints(self.driver, delete_existing)

+    async def _extract_and_resolve_nodes(
+        self,
+        episode: EpisodicNode,
+        previous_episodes: list[EpisodicNode],
+        entity_types: dict[str, type[BaseModel]] | None,
+        excluded_entity_types: list[str] | None,
+    ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
+        """Extract nodes from episode and resolve against existing graph."""
+        extracted_nodes = await extract_nodes(
+            self.clients, episode, previous_episodes, entity_types, excluded_entity_types
+        )
+
+        nodes, uuid_map, duplicates = await resolve_extracted_nodes(
+            self.clients,
+            extracted_nodes,
+            episode,
+            previous_episodes,
+            entity_types,
+        )
+
+        return nodes, uuid_map, duplicates
+
+    async def _extract_and_resolve_edges(
+        self,
+        episode: EpisodicNode,
+        extracted_nodes: list[EntityNode],
+        previous_episodes: list[EpisodicNode],
+        edge_type_map: dict[tuple[str, str], list[str]],
+        group_id: str,
+        edge_types: dict[str, type[BaseModel]] | None,
+        nodes: list[EntityNode],
+        uuid_map: dict[str, str],
+    ) -> tuple[list[EntityEdge], list[EntityEdge]]:
+        """Extract edges from episode and resolve against existing graph."""
+        extracted_edges = await extract_edges(
+            self.clients,
+            episode,
+            extracted_nodes,
+            previous_episodes,
+            edge_type_map,
+            group_id,
+            edge_types,
+        )
+
+        edges = resolve_edge_pointers(extracted_edges, uuid_map)
+
+        resolved_edges, invalidated_edges = await resolve_extracted_edges(
+            self.clients,
+            edges,
+            episode,
+            nodes,
+            edge_types or {},
+            edge_type_map,
+        )
+
+        return resolved_edges, invalidated_edges
+
+    async def _process_episode_data(
+        self,
+        episode: EpisodicNode,
+        nodes: list[EntityNode],
+        entity_edges: list[EntityEdge],
+        now: datetime,
+    ) -> tuple[list[EpisodicEdge], EpisodicNode]:
+        """Process and save episode data to the graph."""
+        episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
+        episode.entity_edges = [edge.uuid for edge in entity_edges]
+
+        if not self.store_raw_episode_content:
+            episode.content = ''
+
+        await add_nodes_and_edges_bulk(
+            self.driver,
+            [episode],
+            episodic_edges,
+            nodes,
+            entity_edges,
+            self.embedder,
+        )
+
+        return episodic_edges, episode
+
+    async def _extract_and_dedupe_nodes_bulk(
+        self,
+        episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
+        edge_type_map: dict[tuple[str, str], list[str]],
+        edge_types: dict[str, type[BaseModel]] | None,
+        entity_types: dict[str, type[BaseModel]] | None,
+        excluded_entity_types: list[str] | None,
+    ) -> tuple[
+        dict[str, list[EntityNode]],
+        dict[str, str],
+        list[list[EntityEdge]],
+    ]:
+        """Extract nodes and edges from all episodes and deduplicate."""
+        # Extract all nodes and edges for each episode
+        extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
+            self.clients,
+            episode_context,
+            edge_type_map=edge_type_map,
+            edge_types=edge_types,
+            entity_types=entity_types,
+            excluded_entity_types=excluded_entity_types,
+        )
+
+        # Dedupe extracted nodes in memory
+        nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
+            self.clients, extracted_nodes_bulk, episode_context, entity_types
+        )
+
+        return nodes_by_episode, uuid_map, extracted_edges_bulk
+
+    async def _resolve_nodes_and_edges_bulk(
+        self,
+        nodes_by_episode: dict[str, list[EntityNode]],
+        edges_by_episode: dict[str, list[EntityEdge]],
+        episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
+        entity_types: dict[str, type[BaseModel]] | None,
+        edge_types: dict[str, type[BaseModel]] | None,
+        edge_type_map: dict[tuple[str, str], list[str]],
+        episodes: list[EpisodicNode],
+    ) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
+        """Resolve nodes and edges against the existing graph."""
+        nodes_by_uuid: dict[str, EntityNode] = {
+            node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
+        }
+
+        # Get unique nodes per episode
+        nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
+        nodes_uuid_set: set[str] = set()
+        for episode, _ in episode_context:
+            nodes_by_episode_unique[episode.uuid] = []
+            nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
+            for node in nodes:
+                if node.uuid not in nodes_uuid_set:
+                    nodes_by_episode_unique[episode.uuid].append(node)
+                    nodes_uuid_set.add(node.uuid)
+
+        # Resolve nodes
+        node_results = await semaphore_gather(
+            *[
+                resolve_extracted_nodes(
+                    self.clients,
+                    nodes_by_episode_unique[episode.uuid],
+                    episode,
+                    previous_episodes,
+                    entity_types,
+                )
+                for episode, previous_episodes in episode_context
+            ]
+        )
+
+        resolved_nodes: list[EntityNode] = []
+        uuid_map: dict[str, str] = {}
+        for result in node_results:
+            resolved_nodes.extend(result[0])
+            uuid_map.update(result[1])
+
+        # Update nodes_by_uuid with resolved nodes
+        for resolved_node in resolved_nodes:
+            nodes_by_uuid[resolved_node.uuid] = resolved_node
+
+        # Update nodes_by_episode_unique with resolved pointers
+        for episode_uuid, nodes in nodes_by_episode_unique.items():
+            updated_nodes: list[EntityNode] = []
+            for node in nodes:
+                updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
+                updated_node = nodes_by_uuid[updated_node_uuid]
+                updated_nodes.append(updated_node)
+            nodes_by_episode_unique[episode_uuid] = updated_nodes
+
+        # Extract attributes for resolved nodes
+        hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
+            *[
+                extract_attributes_from_nodes(
+                    self.clients,
+                    nodes_by_episode_unique[episode.uuid],
+                    episode,
+                    previous_episodes,
+                    entity_types,
+                )
+                for episode, previous_episodes in episode_context
+            ]
+        )
+
+        final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]
+
+        # Resolve edges with updated pointers
+        edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
+        edges_uuid_set: set[str] = set()
+        for episode_uuid, edges in edges_by_episode.items():
+            edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
+            edges_by_episode_unique[episode_uuid] = []
+
+            for edge in edges_with_updated_pointers:
+                if edge.uuid not in edges_uuid_set:
+                    edges_by_episode_unique[episode_uuid].append(edge)
+                    edges_uuid_set.add(edge.uuid)
+
+        edge_results = await semaphore_gather(
+            *[
+                resolve_extracted_edges(
+                    self.clients,
+                    edges_by_episode_unique[episode.uuid],
+                    episode,
+                    final_hydrated_nodes,
+                    edge_types or {},
+                    edge_type_map,
+                )
+                for episode in episodes
+            ]
+        )
+
+        resolved_edges: list[EntityEdge] = []
+        invalidated_edges: list[EntityEdge] = []
+        for result in edge_results:
+            resolved_edges.extend(result[0])
+            invalidated_edges.extend(result[1])
+
+        return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
+
     async def retrieve_episodes(
         self,
         reference_time: datetime,
@@ -444,133 +679,138 @@ class Graphiti:
         background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
         return {"message": "Episode processing started"}
         """
-        try:
-            start = time()
-            now = utc_now()
-
-            validate_entity_types(entity_types)
-
-            validate_excluded_entity_types(excluded_entity_types, entity_types)
-            validate_group_id(group_id)
-            # if group_id is None, use the default group id by the provider
-            group_id = group_id or get_default_group_id(self.driver.provider)
-
-            previous_episodes = (
-                await self.retrieve_episodes(
-                    reference_time,
-                    last_n=RELEVANT_SCHEMA_LIMIT,
-                    group_ids=[group_id],
-                    source=source,
-                )
-                if previous_episode_uuids is None
-                else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
-            )
+        start = time()
+        now = utc_now()
+
+        validate_entity_types(entity_types)

-            episode = (
-                await EpisodicNode.get_by_uuid(self.driver, uuid)
-                if uuid is not None
-                else EpisodicNode(
-                    name=name,
-                    group_id=group_id,
-                    labels=[],
-                    source=source,
-                    content=episode_body,
-                    source_description=source_description,
-                    created_at=now,
-                    valid_at=reference_time,
+        validate_excluded_entity_types(excluded_entity_types, entity_types)
+        validate_group_id(group_id)
+        # if group_id is None, use the default group id by the provider
+        group_id = group_id or get_default_group_id(self.driver.provider)
+
+        with self.tracer.start_span('add_episode') as span:
+            try:
+                # Retrieve previous episodes for context
+                previous_episodes = (
+                    await self.retrieve_episodes(
+                        reference_time,
+                        last_n=RELEVANT_SCHEMA_LIMIT,
+                        group_ids=[group_id],
+                        source=source,
+                    )
+                    if previous_episode_uuids is None
+                    else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
                 )
-            )

-            # Create default edge type map
-            edge_type_map_default = (
-                {('Entity', 'Entity'): list(edge_types.keys())}
-                if edge_types is not None
-                else {('Entity', 'Entity'): []}
-            )
+                # Get or create episode
+                episode = (
+                    await EpisodicNode.get_by_uuid(self.driver, uuid)
+                    if uuid is not None
+                    else EpisodicNode(
+                        name=name,
+                        group_id=group_id,
+                        labels=[],
+                        source=source,
+                        content=episode_body,
+                        source_description=source_description,
+                        created_at=now,
+                        valid_at=reference_time,
+                    )
+                )

-            # Extract entities as nodes
+                # Create default edge type map
+                edge_type_map_default = (
+                    {('Entity', 'Entity'): list(edge_types.keys())}
+                    if edge_types is not None
+                    else {('Entity', 'Entity'): []}
+                )

-            extracted_nodes = await extract_nodes(
-                self.clients, episode, previous_episodes, entity_types, excluded_entity_types
-            )
+                # Extract and resolve nodes
+                extracted_nodes = await extract_nodes(
+                    self.clients, episode, previous_episodes, entity_types, excluded_entity_types
+                )

-            # Extract edges and resolve nodes
-            (nodes, uuid_map, _), extracted_edges = await semaphore_gather(
-                resolve_extracted_nodes(
+                nodes, uuid_map, _ = await resolve_extracted_nodes(
                     self.clients,
                     extracted_nodes,
                     episode,
                     previous_episodes,
                     entity_types,
-                ),
-                extract_edges(
-                    self.clients,
+                )
+
+                # Extract and resolve edges in parallel with attribute extraction
+                resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
                     episode,
                     extracted_nodes,
                     previous_episodes,
                     edge_type_map or edge_type_map_default,
                     group_id,
                     edge_types,
-                ),
-                max_coroutines=self.max_coroutines,
-            )
-
-            edges = resolve_edge_pointers(extracted_edges, uuid_map)
-
-            (resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
-                resolve_extracted_edges(
-                    self.clients,
-                    edges,
-                    episode,
                     nodes,
-                    edge_types or {},
-                    edge_type_map or edge_type_map_default,
-                ),
-                extract_attributes_from_nodes(
+                    uuid_map,
+                )
+
+                # Extract node attributes
+                hydrated_nodes = await extract_attributes_from_nodes(
                     self.clients, nodes, episode, previous_episodes, entity_types
-                ),
-                max_coroutines=self.max_coroutines,
-            )
+                )

-            entity_edges = resolved_edges + invalidated_edges
+                entity_edges = resolved_edges + invalidated_edges

-            episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
+                # Process and save episode data
+                episodic_edges, episode = await self._process_episode_data(
+                    episode, hydrated_nodes, entity_edges, now
+                )

-            episode.entity_edges = [edge.uuid for edge in entity_edges]
+                # Update communities if requested
+                communities = []
+                community_edges = []
+                if update_communities:
+                    communities, community_edges = await semaphore_gather(
+                        *[
+                            update_community(self.driver, self.llm_client, self.embedder, node)
+                            for node in nodes
+                        ],
+                        max_coroutines=self.max_coroutines,
+                    )

-            if not self.store_raw_episode_content:
-                episode.content = ''
+                end = time()
+
+                # Add span attributes
+                span.add_attributes(
+                    {
+                        'episode.uuid': episode.uuid,
+                        'episode.source': source.value,
+                        'episode.reference_time': reference_time.isoformat(),
+                        'group_id': group_id,
+                        'node.count': len(hydrated_nodes),
+                        'edge.count': len(entity_edges),
+                        'edge.invalidated_count': len(invalidated_edges),
+                        'previous_episodes.count': len(previous_episodes),
+                        'entity_types.count': len(entity_types) if entity_types else 0,
+                        'edge_types.count': len(edge_types) if edge_types else 0,
+                        'update_communities': update_communities,
+                        'communities.count': len(communities) if update_communities else 0,
+                        'duration_ms': (end - start) * 1000,
+                    }
+                )

-            await add_nodes_and_edges_bulk(
-                self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
-            )
+                logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

-            communities = []
-            community_edges = []
-
-            # Update any communities
-            if update_communities:
-                communities, community_edges = await semaphore_gather(
-                    *[
-                        update_community(self.driver, self.llm_client, self.embedder, node)
-                        for node in nodes
-                    ],
-                    max_coroutines=self.max_coroutines,
+                return AddEpisodeResults(
+                    episode=episode,
+                    episodic_edges=episodic_edges,
+                    nodes=hydrated_nodes,
+                    edges=entity_edges,
+                    communities=communities,
+                    community_edges=community_edges,
                 )
-            end = time()
-            logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
-
-            return AddEpisodeResults(
-                episode=episode,
-                episodic_edges=episodic_edges,
-                nodes=hydrated_nodes,
-                edges=entity_edges,
-                communities=communities,
-                community_edges=community_edges,
-            )

-        except Exception as e:
-            raise e
+            except Exception as e:
+                span.set_status('error', str(e))
+                span.record_exception(e)
+                raise e

     async def add_episode_bulk(
         self,
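
The reworked add_episode path wraps the whole pipeline in a span, attaches episode, node, and edge counts as attributes, and records an error status plus the exception before re-raising on failure. The span interface exercised here is visible only through its call sites (start_span used as a context manager, plus add_attributes, set_status, and record_exception); the actual graphiti_core.tracer module is not part of this diff, so the stand-in below is an inferred sketch of that duck-typed interface, not the shipped implementation:

from contextlib import contextmanager
from typing import Any


class PrintingSpan:
    """Stand-in span exposing the methods the new code path calls."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.attributes: dict[str, Any] = {}

    def add_attributes(self, attributes: dict[str, Any]) -> None:
        self.attributes.update(attributes)

    def set_status(self, status: str, description: str | None = None) -> None:
        print(f'{self.name}: status={status} ({description})')

    def record_exception(self, exc: BaseException) -> None:
        print(f'{self.name}: exception={exc!r}')


class PrintingTracer:
    """Stand-in tracer whose start_span works as a context manager."""

    @contextmanager
    def start_span(self, name: str):
        span = PrintingSpan(name)
        try:
            yield span
        finally:
            print(f'{name}: attributes={span.attributes}')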
@@ -617,248 +857,141 @@ class Graphiti:
         If these operations are required, use the `add_episode` method instead for each
         individual episode.
         """
-        try:
-            start = time()
-            now = utc_now()
-
-            # if group_id is None, use the default group id by the provider
-            group_id = group_id or get_default_group_id(self.driver.provider)
-            validate_group_id(group_id)
-
-            # Create default edge type map
-            edge_type_map_default = (
-                {('Entity', 'Entity'): list(edge_types.keys())}
-                if edge_types is not None
-                else {('Entity', 'Entity'): []}
-            )
+        with self.tracer.start_span('add_episode_bulk') as bulk_span:
+            bulk_span.add_attributes({'episode.count': len(bulk_episodes)})

-            episodes = [
-                await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
-                if episode.uuid is not None
-                else EpisodicNode(
-                    name=episode.name,
-                    labels=[],
-                    source=episode.source,
-                    content=episode.content,
-                    source_description=episode.source_description,
-                    group_id=group_id,
-                    created_at=now,
-                    valid_at=episode.reference_time,
+            try:
+                start = time()
+                now = utc_now()
+
+                # if group_id is None, use the default group id by the provider
+                group_id = group_id or get_default_group_id(self.driver.provider)
+                validate_group_id(group_id)
+
+                # Create default edge type map
+                edge_type_map_default = (
+                    {('Entity', 'Entity'): list(edge_types.keys())}
+                    if edge_types is not None
+                    else {('Entity', 'Entity'): []}
                 )
-                for episode in bulk_episodes
-            ]
-
-            episodes_by_uuid: dict[str, EpisodicNode] = {
-                episode.uuid: episode for episode in episodes
-            }
-
-            # Save all episodes
-            await add_nodes_and_edges_bulk(
-                driver=self.driver,
-                episodic_nodes=episodes,
-                episodic_edges=[],
-                entity_nodes=[],
-                entity_edges=[],
-                embedder=self.embedder,
-            )
-
-            # Get previous episode context for each episode
-            episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)

-            # Extract all nodes and edges for each episode
-            extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
-                self.clients,
-                episode_context,
-                edge_type_map=edge_type_map or edge_type_map_default,
-                edge_types=edge_types,
-                entity_types=entity_types,
-                excluded_entity_types=excluded_entity_types,
-            )
-
-            # Dedupe extracted nodes in memory
-            nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
-                self.clients, extracted_nodes_bulk, episode_context, entity_types
-            )
+                episodes = [
+                    await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
+                    if episode.uuid is not None
+                    else EpisodicNode(
+                        name=episode.name,
+                        labels=[],
+                        source=episode.source,
+                        content=episode.content,
+                        source_description=episode.source_description,
+                        group_id=group_id,
+                        created_at=now,
+                        valid_at=episode.reference_time,
+                    )
+                    for episode in bulk_episodes
+                ]

-            # Create Episodic Edges
-            episodic_edges: list[EpisodicEdge] = []
-            for episode_uuid, nodes in nodes_by_episode.items():
-                episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
+                # Save all episodes
+                await add_nodes_and_edges_bulk(
+                    driver=self.driver,
+                    episodic_nodes=episodes,
+                    episodic_edges=[],
+                    entity_nodes=[],
+                    entity_edges=[],
+                    embedder=self.embedder,
+                )

-            # re-map edge pointers so that they don't point to discard dupe nodes
-            extracted_edges_bulk_updated: list[list[EntityEdge]] = [
-                resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
-            ]
+                # Get previous episode context for each episode
+                episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)

-            # Dedupe extracted edges in memory
-            edges_by_episode = await dedupe_edges_bulk(
-                self.clients,
-                extracted_edges_bulk_updated,
-                episode_context,
-                [],
-                edge_types or {},
-                edge_type_map or edge_type_map_default,
-            )
+                # Extract and dedupe nodes and edges
+                (
+                    nodes_by_episode,
+                    uuid_map,
+                    extracted_edges_bulk,
+                ) = await self._extract_and_dedupe_nodes_bulk(
+                    episode_context,
+                    edge_type_map or edge_type_map_default,
+                    edge_types,
+                    entity_types,
+                    excluded_entity_types,
+                )

-            # Extract node attributes
-            nodes_by_uuid: dict[str, EntityNode] = {
-                node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
-            }
+                # Create Episodic Edges
+                episodic_edges: list[EpisodicEdge] = []
+                for episode_uuid, nodes in nodes_by_episode.items():
+                    episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))

-            extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
-            for node in nodes_by_uuid.values():
-                episode_uuids: list[str] = []
-                for episode_uuid, mentioned_nodes in nodes_by_episode.items():
-                    for mentioned_node in mentioned_nodes:
-                        if node.uuid == mentioned_node.uuid:
-                            episode_uuids.append(episode_uuid)
-                            break
-
-                episode_mentions: list[EpisodicNode] = [
-                    episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
-                ]
-                episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
-
-                extract_attributes_params.append((node, episode_mentions))
-
-            new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
-                *[
-                    extract_attributes_from_nodes(
-                        self.clients,
-                        [params[0]],
-                        params[1][0],
-                        params[1][0:],
-                        entity_types,
-                    )
-                    for params in extract_attributes_params
+                # Re-map edge pointers and dedupe edges
+                extracted_edges_bulk_updated: list[list[EntityEdge]] = [
+                    resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
                 ]
-            )

-            hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]
-
-            # Update nodes_by_uuid map with the hydrated nodes
-            for hydrated_node in hydrated_nodes:
-                nodes_by_uuid[hydrated_node.uuid] = hydrated_node
-
-            # Resolve nodes and edges against the existing graph
-            nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
-            nodes_uuid_set: set[str] = set()
-            for episode, _ in episode_context:
-                nodes_by_episode_unique[episode.uuid] = []
-                nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
-                for node in nodes:
-                    if node.uuid not in nodes_uuid_set:
-                        nodes_by_episode_unique[episode.uuid].append(node)
-                        nodes_uuid_set.add(node.uuid)
-
-            node_results = await semaphore_gather(
-                *[
-                    resolve_extracted_nodes(
-                        self.clients,
-                        nodes_by_episode_unique[episode.uuid],
-                        episode,
-                        previous_episodes,
-                        entity_types,
-                    )
-                    for episode, previous_episodes in episode_context
-                ]
-            )
+                edges_by_episode = await dedupe_edges_bulk(
+                    self.clients,
+                    extracted_edges_bulk_updated,
+                    episode_context,
+                    [],
+                    edge_types or {},
+                    edge_type_map or edge_type_map_default,
+                )

-            resolved_nodes: list[EntityNode] = []
-            uuid_map: dict[str, str] = {}
-            node_duplicates: list[tuple[EntityNode, EntityNode]] = []
-            for result in node_results:
-                resolved_nodes.extend(result[0])
-                uuid_map.update(result[1])
-                node_duplicates.extend(result[2])
-
-            # Update nodes_by_uuid map with the resolved nodes
-            for resolved_node in resolved_nodes:
-                nodes_by_uuid[resolved_node.uuid] = resolved_node
-
-            # update nodes_by_episode_unique mapping
-            for episode_uuid, nodes in nodes_by_episode_unique.items():
-                updated_nodes: list[EntityNode] = []
-                for node in nodes:
-                    updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
-                    updated_node = nodes_by_uuid[updated_node_uuid]
-                    updated_nodes.append(updated_node)
-
-                nodes_by_episode_unique[episode_uuid] = updated_nodes
-
-            hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
-                *[
-                    extract_attributes_from_nodes(
-                        self.clients,
-                        nodes_by_episode_unique[episode.uuid],
-                        episode,
-                        previous_episodes,
-                        entity_types,
-                    )
-                    for episode, previous_episodes in episode_context
-                ]
-            )
+                # Resolve nodes and edges against the existing graph
+                (
+                    final_hydrated_nodes,
+                    resolved_edges,
+                    invalidated_edges,
+                    final_uuid_map,
+                ) = await self._resolve_nodes_and_edges_bulk(
+                    nodes_by_episode,
+                    edges_by_episode,
+                    episode_context,
+                    entity_types,
+                    edge_types,
+                    edge_type_map or edge_type_map_default,
+                    episodes,
+                )

-            final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]
-
-            edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
-            edges_uuid_set: set[str] = set()
-            for episode_uuid, edges in edges_by_episode.items():
-                edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
-                edges_by_episode_unique[episode_uuid] = []
-
-                for edge in edges_with_updated_pointers:
-                    if edge.uuid not in edges_uuid_set:
-                        edges_by_episode_unique[episode_uuid].append(edge)
-                        edges_uuid_set.add(edge.uuid)
-
-            edge_results = await semaphore_gather(
-                *[
-                    resolve_extracted_edges(
-                        self.clients,
-                        edges_by_episode_unique[episode.uuid],
-                        episode,
-                        hydrated_nodes,
-                        edge_types or {},
-                        edge_type_map or edge_type_map_default,
-                    )
-                    for episode in episodes
-                ]
-            )
+                # Resolved pointers for episodic edges
+                resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)
+
+                # save data to KG
+                await add_nodes_and_edges_bulk(
+                    self.driver,
+                    episodes,
+                    resolved_episodic_edges,
+                    final_hydrated_nodes,
+                    resolved_edges + invalidated_edges,
+                    self.embedder,
+                )

-            resolved_edges: list[EntityEdge] = []
-            invalidated_edges: list[EntityEdge] = []
-            for result in edge_results:
-                resolved_edges.extend(result[0])
-                invalidated_edges.extend(result[1])
-
-            # Resolved pointers for episodic edges
-            resolved_episodic_edges = resolve_edge_pointers(episodic_edges, uuid_map)
-
-            # save data to KG
-            await add_nodes_and_edges_bulk(
-                self.driver,
-                episodes,
-                resolved_episodic_edges,
-                final_hydrated_nodes,
-                resolved_edges + invalidated_edges,
-                self.embedder,
-            )
+                end = time()

-            end = time()
-            logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
+                # Add span attributes
+                bulk_span.add_attributes(
+                    {
+                        'group_id': group_id,
+                        'node.count': len(final_hydrated_nodes),
+                        'edge.count': len(resolved_edges + invalidated_edges),
+                        'duration_ms': (end - start) * 1000,
+                    }
+                )

-            return AddBulkEpisodeResults(
-                episodes=episodes,
-                episodic_edges=resolved_episodic_edges,
-                nodes=final_hydrated_nodes,
-                edges=resolved_edges + invalidated_edges,
-                communities=[],
-                community_edges=[],
-            )
+                logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
+
+                return AddBulkEpisodeResults(
+                    episodes=episodes,
+                    episodic_edges=resolved_episodic_edges,
+                    nodes=final_hydrated_nodes,
+                    edges=resolved_edges + invalidated_edges,
+                    communities=[],
+                    community_edges=[],
+                )

-        except Exception as e:
-            raise e
+            except Exception as e:
+                bulk_span.set_status('error', str(e))
+                bulk_span.record_exception(e)
+                raise e

     async def build_communities(
         self, group_ids: list[str] | None = None
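
For completeness, a single traced ingestion call against the instance configured above might look as follows; the EpisodeType and utc_now import paths follow graphiti-core's public examples and are assumptions, and the span name shown simply combines the 'graphiti' prefix with the 'add_episode' span started in this diff:

import asyncio

from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.datetime_utils import utc_now


async def main() -> None:
    result = await graphiti.add_episode(
        name='release-note',
        episode_body='graphiti-core 0.22.0rc5 adds optional tracing support.',
        source=EpisodeType.text,
        source_description='changelog',
        reference_time=utc_now(),
    )
    # With a tracer configured, a 'graphiti.add_episode' span is emitted carrying
    # episode.uuid, node.count, edge.count and duration_ms attributes.
    print(len(result.nodes), len(result.edges))


asyncio.run(main())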