graphiti-core 0.17.11__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; consult the package's registry page for more details.

graphiti_core/graphiti.py CHANGED
@@ -26,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
26
26
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
27
27
  from graphiti_core.driver.driver import GraphDriver
28
28
  from graphiti_core.driver.neo4j_driver import Neo4jDriver
29
- from graphiti_core.edges import EntityEdge, EpisodicEdge
29
+ from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
30
30
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
31
31
  from graphiti_core.graphiti_types import GraphitiClients
32
32
  from graphiti_core.helpers import (
@@ -93,8 +93,11 @@ load_dotenv()
93
93
 
94
94
  class AddEpisodeResults(BaseModel):
95
95
  episode: EpisodicNode
96
+ episodic_edges: list[EpisodicEdge]
96
97
  nodes: list[EntityNode]
97
98
  edges: list[EntityEdge]
99
+ communities: list[CommunityNode]
100
+ community_edges: list[CommunityEdge]
98
101
 
99
102
 
100
103
  class Graphiti:
@@ -113,7 +116,7 @@ class Graphiti:
113
116
  """
114
117
  Initialize a Graphiti instance.
115
118
 
116
- This constructor sets up a connection to the Neo4j database and initializes
119
+ This constructor sets up a connection to a graph database and initializes
117
120
  the LLM client for natural language processing tasks.
118
121
 
119
122
  Parameters
@@ -148,11 +151,11 @@ class Graphiti:
148
151
 
149
152
  Notes
150
153
  -----
151
- This method establishes a connection to the Neo4j database using the provided
154
+ This method establishes a connection to a graph database (Neo4j by default) using the provided
152
155
  credentials. It also sets up the LLM client, either using the provided client
153
156
  or by creating a default OpenAIClient.
154
157
 
155
- The default database name is set to 'neo4j'. If a different database name
158
+ The default database name is defined during the driver’s construction. If a different database name
156
159
  is required, it should be specified in the URI or set separately after
157
160
  initialization.
158
161
 
@@ -520,9 +523,12 @@ class Graphiti:
520
523
  self.driver, [episode], episodic_edges, hydrated_nodes, entity_edges, self.embedder
521
524
  )
522
525
 
526
+ communities = []
527
+ community_edges = []
528
+
523
529
  # Update any communities
524
530
  if update_communities:
525
- await semaphore_gather(
531
+ communities, community_edges = await semaphore_gather(
526
532
  *[
527
533
  update_community(self.driver, self.llm_client, self.embedder, node)
528
534
  for node in nodes
@@ -532,7 +538,14 @@ class Graphiti:
532
538
  end = time()
533
539
  logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
534
540
 
535
- return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
541
+ return AddEpisodeResults(
542
+ episode=episode,
543
+ episodic_edges=episodic_edges,
544
+ nodes=hydrated_nodes,
545
+ edges=entity_edges,
546
+ communities=communities,
547
+ community_edges=community_edges,
548
+ )
536
549
 
537
550
  except Exception as e:
538
551
  raise e
@@ -817,7 +830,9 @@ class Graphiti:
817
830
  except Exception as e:
818
831
  raise e
819
832
 
820
- async def build_communities(self, group_ids: list[str] | None = None) -> list[CommunityNode]:
833
+ async def build_communities(
834
+ self, group_ids: list[str] | None = None
835
+ ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
821
836
  """
822
837
  Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
823
838
  the content of these communities.
@@ -846,7 +861,7 @@ class Graphiti:
846
861
  max_coroutines=self.max_coroutines,
847
862
  )
848
863
 
849
- return community_nodes
864
+ return community_nodes, community_edges
850
865
 
851
866
  async def search(
852
867
  self,
@@ -959,7 +974,7 @@ class Graphiti:
959
974
 
960
975
  nodes = await get_mentioned_nodes(self.driver, episodes)
961
976
 
962
- return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
977
+ return SearchResults(edges=edges, nodes=nodes)
963
978
 
964
979
  async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
965
980
  if source_node.name_embedding is None:
graphiti_core/helpers.py CHANGED
@@ -28,6 +28,7 @@ from numpy._typing import NDArray
28
28
  from pydantic import BaseModel
29
29
  from typing_extensions import LiteralString
30
30
 
31
+ from graphiti_core.driver.driver import GraphProvider
31
32
  from graphiti_core.errors import GroupIdValidationError
32
33
 
33
34
  load_dotenv()
@@ -52,12 +53,12 @@ def parse_db_date(neo_date: neo4j_time.DateTime | str | None) -> datetime | None
52
53
  )
53
54
 
54
55
 
55
- def get_default_group_id(db_type: str) -> str:
56
+ def get_default_group_id(provider: GraphProvider) -> str:
56
57
  """
57
58
  This function differentiates the default group id based on the database type.
58
59
  For most databases, the default group id is an empty string, while there are database types that require a specific default group id.
59
60
  """
60
- if db_type == 'falkordb':
61
+ if provider == GraphProvider.FALKORDB:
61
62
  return '_'
62
63
  else:
63
64
  return ''
@@ -14,43 +14,117 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ from graphiti_core.driver.driver import GraphProvider
18
+
17
19
  EPISODIC_EDGE_SAVE = """
18
- MATCH (episode:Episodic {uuid: $episode_uuid})
19
- MATCH (node:Entity {uuid: $entity_uuid})
20
- MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
21
- SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
22
- RETURN r.uuid AS uuid"""
20
+ MATCH (episode:Episodic {uuid: $episode_uuid})
21
+ MATCH (node:Entity {uuid: $entity_uuid})
22
+ MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
23
+ SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
24
+ RETURN e.uuid AS uuid
25
+ """
23
26
 
24
27
  EPISODIC_EDGE_SAVE_BULK = """
25
28
  UNWIND $episodic_edges AS edge
26
- MATCH (episode:Episodic {uuid: edge.source_node_uuid})
27
- MATCH (node:Entity {uuid: edge.target_node_uuid})
28
- MERGE (episode)-[r:MENTIONS {uuid: edge.uuid}]->(node)
29
- SET r = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
30
- RETURN r.uuid AS uuid
29
+ MATCH (episode:Episodic {uuid: edge.source_node_uuid})
30
+ MATCH (node:Entity {uuid: edge.target_node_uuid})
31
+ MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
32
+ SET e = {uuid: edge.uuid, group_id: edge.group_id, created_at: edge.created_at}
33
+ RETURN e.uuid AS uuid
34
+ """
35
+
36
+ EPISODIC_EDGE_RETURN = """
37
+ e.uuid AS uuid,
38
+ e.group_id AS group_id,
39
+ n.uuid AS source_node_uuid,
40
+ m.uuid AS target_node_uuid,
41
+ e.created_at AS created_at
31
42
  """
32
43
 
33
- ENTITY_EDGE_SAVE = """
34
- MATCH (source:Entity {uuid: $edge_data.source_uuid})
35
- MATCH (target:Entity {uuid: $edge_data.target_uuid})
36
- MERGE (source)-[r:RELATES_TO {uuid: $edge_data.uuid}]->(target)
37
- SET r = $edge_data
38
- WITH r CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", $edge_data.fact_embedding)
39
- RETURN r.uuid AS uuid"""
40
-
41
- ENTITY_EDGE_SAVE_BULK = """
42
- UNWIND $entity_edges AS edge
43
- MATCH (source:Entity {uuid: edge.source_node_uuid})
44
- MATCH (target:Entity {uuid: edge.target_node_uuid})
45
- MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
46
- SET r = edge
47
- WITH r, edge CALL db.create.setRelationshipVectorProperty(r, "fact_embedding", edge.fact_embedding)
48
- RETURN edge.uuid AS uuid
44
+
45
+ def get_entity_edge_save_query(provider: GraphProvider) -> str:
46
+ if provider == GraphProvider.FALKORDB:
47
+ return """
48
+ MATCH (source:Entity {uuid: $edge_data.source_uuid})
49
+ MATCH (target:Entity {uuid: $edge_data.target_uuid})
50
+ MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
51
+ SET e = $edge_data
52
+ RETURN e.uuid AS uuid
53
+ """
54
+
55
+ return """
56
+ MATCH (source:Entity {uuid: $edge_data.source_uuid})
57
+ MATCH (target:Entity {uuid: $edge_data.target_uuid})
58
+ MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
59
+ SET e = $edge_data
60
+ WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
61
+ RETURN e.uuid AS uuid
62
+ """
63
+
64
+
65
+ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
66
+ if provider == GraphProvider.FALKORDB:
67
+ return """
68
+ UNWIND $entity_edges AS edge
69
+ MATCH (source:Entity {uuid: edge.source_node_uuid})
70
+ MATCH (target:Entity {uuid: edge.target_node_uuid})
71
+ MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
72
+ SET r = {uuid: edge.uuid, name: edge.name, group_id: edge.group_id, fact: edge.fact, episodes: edge.episodes,
73
+ created_at: edge.created_at, expired_at: edge.expired_at, valid_at: edge.valid_at, invalid_at: edge.invalid_at, fact_embedding: vecf32(edge.fact_embedding)}
74
+ WITH r, edge
75
+ RETURN edge.uuid AS uuid
76
+ """
77
+
78
+ return """
79
+ UNWIND $entity_edges AS edge
80
+ MATCH (source:Entity {uuid: edge.source_node_uuid})
81
+ MATCH (target:Entity {uuid: edge.target_node_uuid})
82
+ MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
83
+ SET e = edge
84
+ WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
85
+ RETURN edge.uuid AS uuid
86
+ """
87
+
88
+
89
+ ENTITY_EDGE_RETURN = """
90
+ e.uuid AS uuid,
91
+ n.uuid AS source_node_uuid,
92
+ m.uuid AS target_node_uuid,
93
+ e.group_id AS group_id,
94
+ e.name AS name,
95
+ e.fact AS fact,
96
+ e.episodes AS episodes,
97
+ e.created_at AS created_at,
98
+ e.expired_at AS expired_at,
99
+ e.valid_at AS valid_at,
100
+ e.invalid_at AS invalid_at,
101
+ properties(e) AS attributes
49
102
  """
50
103
 
51
- COMMUNITY_EDGE_SAVE = """
52
- MATCH (community:Community {uuid: $community_uuid})
53
- MATCH (node:Entity | Community {uuid: $entity_uuid})
54
- MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
55
- SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
56
- RETURN r.uuid AS uuid"""
104
+
105
+ def get_community_edge_save_query(provider: GraphProvider) -> str:
106
+ if provider == GraphProvider.FALKORDB:
107
+ return """
108
+ MATCH (community:Community {uuid: $community_uuid})
109
+ MATCH (node {uuid: $entity_uuid})
110
+ MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
111
+ SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
112
+ RETURN e.uuid AS uuid
113
+ """
114
+
115
+ return """
116
+ MATCH (community:Community {uuid: $community_uuid})
117
+ MATCH (node:Entity | Community {uuid: $entity_uuid})
118
+ MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
119
+ SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
120
+ RETURN e.uuid AS uuid
121
+ """
122
+
123
+
124
+ COMMUNITY_EDGE_RETURN = """
125
+ e.uuid AS uuid,
126
+ e.group_id AS group_id,
127
+ n.uuid AS source_node_uuid,
128
+ m.uuid AS target_node_uuid,
129
+ e.created_at AS created_at
130
+ """
@@ -14,39 +14,120 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ from typing import Any
18
+
19
+ from graphiti_core.driver.driver import GraphProvider
20
+
17
21
  EPISODIC_NODE_SAVE = """
18
- MERGE (n:Episodic {uuid: $uuid})
19
- SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
20
- entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
21
- RETURN n.uuid AS uuid"""
22
+ MERGE (n:Episodic {uuid: $uuid})
23
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
24
+ entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
25
+ RETURN n.uuid AS uuid
26
+ """
22
27
 
23
28
  EPISODIC_NODE_SAVE_BULK = """
24
29
  UNWIND $episodes AS episode
25
30
  MERGE (n:Episodic {uuid: episode.uuid})
26
- SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
27
- source: episode.source, content: episode.content,
31
+ SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
32
+ source: episode.source, content: episode.content,
28
33
  entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
29
34
  RETURN n.uuid AS uuid
30
35
  """
31
36
 
32
- ENTITY_NODE_SAVE = """
33
- MERGE (n:Entity {uuid: $entity_data.uuid})
34
- SET n:$($labels)
37
+ EPISODIC_NODE_RETURN = """
38
+ e.content AS content,
39
+ e.created_at AS created_at,
40
+ e.valid_at AS valid_at,
41
+ e.uuid AS uuid,
42
+ e.name AS name,
43
+ e.group_id AS group_id,
44
+ e.source_description AS source_description,
45
+ e.source AS source,
46
+ e.entity_edges AS entity_edges
47
+ """
48
+
49
+
50
+ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
51
+ if provider == GraphProvider.FALKORDB:
52
+ return f"""
53
+ MERGE (n:Entity {{uuid: $entity_data.uuid}})
54
+ SET n:{labels}
55
+ SET n = $entity_data
56
+ RETURN n.uuid AS uuid
57
+ """
58
+
59
+ return f"""
60
+ MERGE (n:Entity {{uuid: $entity_data.uuid}})
61
+ SET n:{labels}
35
62
  SET n = $entity_data
36
63
  WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
37
- RETURN n.uuid AS uuid"""
38
-
39
- ENTITY_NODE_SAVE_BULK = """
40
- UNWIND $nodes AS node
41
- MERGE (n:Entity {uuid: node.uuid})
42
- SET n:$(node.labels)
43
- SET n = node
44
- WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
45
- RETURN n.uuid AS uuid
64
+ RETURN n.uuid AS uuid
65
+ """
66
+
67
+
68
+ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
69
+ if provider == GraphProvider.FALKORDB:
70
+ queries = []
71
+ for node in nodes:
72
+ for label in node['labels']:
73
+ queries.append(
74
+ (
75
+ f"""
76
+ UNWIND $nodes AS node
77
+ MERGE (n:Entity {{uuid: node.uuid}})
78
+ SET n:{label}
79
+ SET n = node
80
+ WITH n, node
81
+ SET n.name_embedding = vecf32(node.name_embedding)
82
+ RETURN n.uuid AS uuid
83
+ """,
84
+ {'nodes': [node]},
85
+ )
86
+ )
87
+ return queries
88
+
89
+ return """
90
+ UNWIND $nodes AS node
91
+ MERGE (n:Entity {uuid: node.uuid})
92
+ SET n:$(node.labels)
93
+ SET n = node
94
+ WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
95
+ RETURN n.uuid AS uuid
96
+ """
97
+
98
+
99
+ ENTITY_NODE_RETURN = """
100
+ n.uuid AS uuid,
101
+ n.name AS name,
102
+ n.group_id AS group_id,
103
+ n.created_at AS created_at,
104
+ n.summary AS summary,
105
+ labels(n) AS labels,
106
+ properties(n) AS attributes
46
107
  """
47
108
 
48
- COMMUNITY_NODE_SAVE = """
109
+
110
+ def get_community_node_save_query(provider: GraphProvider) -> str:
111
+ if provider == GraphProvider.FALKORDB:
112
+ return """
113
+ MERGE (n:Community {uuid: $uuid})
114
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: $name_embedding}
115
+ RETURN n.uuid AS uuid
116
+ """
117
+
118
+ return """
49
119
  MERGE (n:Community {uuid: $uuid})
50
120
  SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
51
121
  WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
52
- RETURN n.uuid AS uuid"""
122
+ RETURN n.uuid AS uuid
123
+ """
124
+
125
+
126
+ COMMUNITY_NODE_RETURN = """
127
+ n.uuid AS uuid,
128
+ n.name AS name,
129
+ n.name_embedding AS name_embedding,
130
+ n.group_id AS group_id,
131
+ n.summary AS summary,
132
+ n.created_at AS created_at
133
+ """