graphiti-core 0.15.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -41,6 +41,10 @@ DEFAULT_MODEL = 'gemini-2.5-flash-lite-preview-06-17'
 
 
 class GeminiRerankerClient(CrossEncoderClient):
+    """
+    Google Gemini Reranker Client
+    """
+
     def __init__(
         self,
         config: LLMConfig | None = None,
@@ -57,7 +61,6 @@ class GeminiRerankerClient(CrossEncoderClient):
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
             client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
-
         if config is None:
             config = LLMConfig()
 
@@ -47,15 +47,27 @@ class GeminiEmbedder(EmbedderClient):
     Google Gemini Embedder Client
     """
 
-    def __init__(self, config: GeminiEmbedderConfig | None = None):
+    def __init__(
+        self,
+        config: GeminiEmbedderConfig | None = None,
+        client: 'genai.Client | None' = None,
+    ):
+        """
+        Initialize the GeminiEmbedder with the provided configuration and client.
+
+        Args:
+            config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including API key, model, base URL, temperature, and max tokens.
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
+        """
         if config is None:
             config = GeminiEmbedderConfig()
+
         self.config = config
 
-        # Configure the Gemini API
-        self.client = genai.Client(
-            api_key=config.api_key,
-        )
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
 
     async def create(
         self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
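
The embedder can now reuse an existing google-genai client instead of always constructing its own. A minimal sketch of the new injection point (the import path and API key handling are assumptions, not shown in this diff):

from google import genai

from graphiti_core.embedder.gemini import GeminiEmbedder  # assumed import path

# Reuse one async-capable client; when omitted, GeminiEmbedder builds its own from config.api_key.
shared_client = genai.Client(api_key='YOUR_GEMINI_API_KEY')
embedder = GeminiEmbedder(client=shared_client)  # config defaults to GeminiEmbedderConfig()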
graphiti_core/graphiti.py CHANGED
@@ -57,7 +57,6 @@ from graphiti_core.utils.bulk_utils import (
     add_nodes_and_edges_bulk,
     dedupe_edges_bulk,
     dedupe_nodes_bulk,
-    extract_edge_dates_bulk,
     extract_nodes_and_edges_bulk,
     resolve_edge_pointers,
     retrieve_previous_episodes_bulk,
@@ -508,7 +507,7 @@ class Graphiti:
 
         entity_edges = resolved_edges + invalidated_edges + duplicate_of_edges
 
-        episodic_edges = build_episodic_edges(nodes, episode, now)
+        episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
 
         episode.entity_edges = [edge.uuid for edge in entity_edges]
 
@@ -536,8 +535,16 @@ class Graphiti:
         except Exception as e:
             raise e
 
-    #### WIP: USE AT YOUR OWN RISK ####
-    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
+    ##### EXPERIMENTAL #####
+    async def add_episode_bulk(
+        self,
+        bulk_episodes: list[RawEpisode],
+        group_id: str = '',
+        entity_types: dict[str, BaseModel] | None = None,
+        excluded_entity_types: list[str] | None = None,
+        edge_types: dict[str, BaseModel] | None = None,
+        edge_type_map: dict[tuple[str, str], list[str]] | None = None,
+    ):
         """
         Process multiple episodes in bulk and update the graph.
 
@@ -580,8 +587,17 @@ class Graphiti:
 
         validate_group_id(group_id)
 
+        # Create default edge type map
+        edge_type_map_default = (
+            {('Entity', 'Entity'): list(edge_types.keys())}
+            if edge_types is not None
+            else {('Entity', 'Entity'): []}
+        )
+
         episodes = [
-            EpisodicNode(
+            await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
+            if episode.uuid is not None
+            else EpisodicNode(
                 name=episode.name,
                 labels=[],
                 source=episode.source,
@@ -594,68 +610,106 @@ class Graphiti:
             for episode in bulk_episodes
         ]
 
-        # Save all the episodes
-        await semaphore_gather(
-            *[episode.save(self.driver) for episode in episodes],
-            max_coroutines=self.max_coroutines,
+        episodes_by_uuid: dict[str, EpisodicNode] = {
+            episode.uuid: episode for episode in episodes
+        }
+
+        # Save all episodes
+        await add_nodes_and_edges_bulk(
+            driver=self.driver,
+            episodic_nodes=episodes,
+            episodic_edges=[],
+            entity_nodes=[],
+            entity_edges=[],
+            embedder=self.embedder,
         )
 
         # Get previous episode context for each episode
-        episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
+        episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
 
-        # Extract all nodes and edges
-        (
-            extracted_nodes,
-            extracted_edges,
-            episodic_edges,
-        ) = await extract_nodes_and_edges_bulk(self.clients, episode_pairs, None, None)
-
-        # Generate embeddings
-        await semaphore_gather(
-            *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
-            *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
-            max_coroutines=self.max_coroutines,
+        # Extract all nodes and edges for each episode
+        extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
+            self.clients,
+            episode_context,
+            edge_type_map=edge_type_map or edge_type_map_default,
+            edge_types=edge_types,
+            entity_types=entity_types,
+            excluded_entity_types=excluded_entity_types,
         )
 
-        # Dedupe extracted nodes, compress extracted edges
-        (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
-            dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
-            extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
-            max_coroutines=self.max_coroutines,
+        # Dedupe extracted nodes in memory
+        nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
+            self.clients, extracted_nodes_bulk, episode_context, entity_types
         )
 
-        # save nodes to KG
-        await semaphore_gather(
-            *[node.save(self.driver) for node in nodes],
-            max_coroutines=self.max_coroutines,
-        )
+        episodic_edges: list[EpisodicEdge] = []
+        for episode_uuid, nodes in nodes_by_episode.items():
+            episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
 
         # re-map edge pointers so that they don't point to discard dupe nodes
-        extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
-            extracted_edges_timestamped, uuid_map
-        )
-        episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
-            episodic_edges, uuid_map
-        )
+        extracted_edges_bulk_updated: list[list[EntityEdge]] = [
+            resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
+        ]
 
-        # save episodic edges to KG
-        await semaphore_gather(
-            *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
-            max_coroutines=self.max_coroutines,
+        # Dedupe extracted edges in memory
+        edges_by_episode = await dedupe_edges_bulk(
+            self.clients,
+            extracted_edges_bulk_updated,
+            episode_context,
+            [],
+            edge_types or {},
+            edge_type_map or edge_type_map_default,
         )
 
-        # Dedupe extracted edges
-        edges = await dedupe_edges_bulk(
-            self.driver, self.llm_client, extracted_edges_with_resolved_pointers
+        # Extract node attributes
+        nodes_by_uuid: dict[str, EntityNode] = {
+            node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
+        }
+
+        extract_attributes_params: list[tuple[EntityNode, list[EpisodicNode]]] = []
+        for node in nodes_by_uuid.values():
+            episode_uuids: list[str] = []
+            for episode_uuid, mentioned_nodes in nodes_by_episode.items():
+                for mentioned_node in mentioned_nodes:
+                    if node.uuid == mentioned_node.uuid:
+                        episode_uuids.append(episode_uuid)
+                        break
+
+            episode_mentions: list[EpisodicNode] = [
+                episodes_by_uuid[episode_uuid] for episode_uuid in episode_uuids
+            ]
+            episode_mentions.sort(key=lambda x: x.valid_at, reverse=True)
+
+            extract_attributes_params.append((node, episode_mentions))
+
+        new_hydrated_nodes: list[list[EntityNode]] = await semaphore_gather(
+            *[
+                extract_attributes_from_nodes(
+                    self.clients,
+                    [params[0]],
+                    params[1][0],
+                    params[1][0:],
+                    entity_types,
+                )
+                for params in extract_attributes_params
+            ]
         )
-        logger.debug(f'extracted edge length: {len(edges)}')
 
-        # invalidate edges
+        hydrated_nodes = [node for nodes in new_hydrated_nodes for node in nodes]
 
-        # save edges to KG
-        await semaphore_gather(
-            *[edge.save(self.driver) for edge in edges],
-            max_coroutines=self.max_coroutines,
+        # TODO: Resolve nodes and edges against the existing graph
+        edges_by_uuid: dict[str, EntityEdge] = {
+            edge.uuid: edge for edges in edges_by_episode.values() for edge in edges
+        }
+
+        # save data to KG
+        await add_nodes_and_edges_bulk(
+            self.driver,
+            episodes,
+            episodic_edges,
+            hydrated_nodes,
+            list(edges_by_uuid.values()),
+            self.embedder,
         )
 
         end = time()
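
The experimental bulk path now accepts the same typing hooks as the single-episode path. A usage sketch under the assumption that RawEpisode keeps its existing fields; the Person/WorksAt models, group id, and episode content are invented for illustration:

from datetime import datetime, timezone

from pydantic import BaseModel, Field

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode


class Person(BaseModel):
    """Hypothetical custom entity type."""
    occupation: str | None = Field(None, description='Occupation of the person')


class WorksAt(BaseModel):
    """Hypothetical custom edge type."""


async def ingest(graphiti: Graphiti) -> None:
    episodes = [
        RawEpisode(
            name='episode-1',
            content='Alice joined Acme Corp as a data engineer.',
            source=EpisodeType.text,
            source_description='example text',
            reference_time=datetime.now(timezone.utc),
        )
    ]
    # New in 0.16.0: entity/edge type information flows through the bulk path.
    await graphiti.add_episode_bulk(
        episodes,
        group_id='demo-group',
        entity_types={'Person': Person},
        edge_types={'WorksAt': WorksAt},
        edge_type_map={('Person', 'Entity'): ['WorksAt']},
    )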
@@ -828,7 +882,7 @@ class Graphiti:
             await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
         )[0]
 
-        resolved_edge, invalidated_edges = await resolve_extracted_edge(
+        resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
            self.llm_client,
            updated_edge,
            related_edges,
@@ -76,6 +76,7 @@ class GeminiClient(LLMClient):
         cache: bool = False,
         max_tokens: int = DEFAULT_MAX_TOKENS,
         thinking_config: types.ThinkingConfig | None = None,
+        client: 'genai.Client | None' = None,
     ):
         """
         Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
@@ -85,7 +86,7 @@ class GeminiClient(LLMClient):
             cache (bool): Whether to use caching for responses. Defaults to False.
             thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
                 Only use with models that support thinking (gemini-2.5+). Defaults to None.
-
+            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
         """
         if config is None:
             config = LLMConfig()
@@ -93,10 +94,12 @@ class GeminiClient(LLMClient):
         super().__init__(config, cache)
 
         self.model = config.model
-        # Configure the Gemini API
-        self.client = genai.Client(
-            api_key=config.api_key,
-        )
+
+        if client is None:
+            self.client = genai.Client(api_key=config.api_key)
+        else:
+            self.client = client
+
         self.max_tokens = max_tokens
         self.thinking_config = thinking_config
 
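
With the same parameter on GeminiClient, and on GeminiRerankerClient (whose docstring above documents it), one google-genai client can back every Gemini-based component; the embedder shown earlier accepts it as well. A minimal sketch, with import paths assumed from the package layout:

from google import genai

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient  # assumed path
from graphiti_core.llm_client.gemini_client import GeminiClient  # assumed path

# One client, one connection pool and API key, shared by LLM calls and reranking.
shared = genai.Client(api_key='YOUR_GEMINI_API_KEY')

llm = GeminiClient(client=shared)
reranker = GeminiRerankerClient(client=shared)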
@@ -23,9 +23,9 @@ from .models import Message, PromptFunction, PromptVersion
 
 
 class EdgeDuplicate(BaseModel):
-    duplicate_fact_id: int = Field(
+    duplicate_facts: list[int] = Field(
         ...,
-        description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
+        description='List of ids of any duplicate facts. If no duplicate facts are found, default to empty list.',
     )
     contradicted_facts: list[int] = Field(
         ...,
@@ -75,8 +75,9 @@ def edge(context: dict[str, Any]) -> list[Message]:
        </NEW EDGE>
 
        Task:
-       If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact.
-       If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return -1.
+       If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact
+       as part of the list of duplicate_facts.
+       If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return an empty list.
 
        Guidelines:
        1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
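
A minimal sketch of what this schema change means downstream: 0.15.x reported a single duplicate_fact_id (-1 for none), while 0.16.0's duplicate_facts carries zero or more ids. The model below copies the field from the diff; the example values are invented:

from pydantic import BaseModel, Field


class EdgeDuplicate(BaseModel):
    duplicate_facts: list[int] = Field(
        ...,
        description='List of ids of any duplicate facts. If no duplicate facts are found, default to empty list.',
    )


no_dupes = EdgeDuplicate(duplicate_facts=[])    # was duplicate_fact_id = -1
multi = EdgeDuplicate(duplicate_facts=[2, 5])   # a new edge can now match several existing facts

for fact_id in multi.duplicate_facts:
    print(f'new edge duplicates existing fact {fact_id}')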
@@ -32,9 +32,9 @@ class NodeDuplicate(BaseModel):
         ...,
         description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
     )
-    additional_duplicates: list[int] = Field(
+    duplicates: list[int] = Field(
         ...,
-        description='idx of additional duplicate entities. Use this list if the entity has multiple duplicates among existing entities.',
+        description='idx of all duplicate entities.',
     )
 
 
@@ -94,7 +94,7 @@ def node(context: dict[str, Any]) -> list[Message]:
        1. Compare `new_entity` against each item in `existing_entities`.
        2. If it refers to the same real‐world object or concept, collect its index.
        3. Let `duplicate_idx` = the *first* collected index, or –1 if none.
-       4. Let `additional_duplicates` = the list of *any other* collected indices (empty list if none).
+       4. Let `duplicates` = the list of *all* collected indices (empty list if none).
 
        Also return the full name of the NEW ENTITY (whether it is the name of the NEW ENTITY, a node it
        is a duplicate of, or a combination of the two).
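
For the node prompt the shift is analogous: duplicate_idx still names the first match, but duplicates now lists every match instead of only the extras. A tiny sketch with invented indices:

collected = [3, 7, 9]  # hypothetical indices the model judged to be duplicates

# 0.15.x shape: first index plus the remainder in additional_duplicates
duplicate_idx = collected[0] if collected else -1
additional_duplicates = collected[1:]
all_old = ([duplicate_idx] if duplicate_idx != -1 else []) + additional_duplicates

# 0.16.0 shape: duplicates already holds every collected index
duplicates = collected

assert all_old == duplicates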