graphiti-core 0.19.0rc3__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. (The original page linked to further details here; that link is not preserved in this copy.)

@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import json
17
18
  import logging
18
19
  import typing
19
20
  from datetime import datetime
@@ -22,20 +23,21 @@ import numpy as np
22
23
  from pydantic import BaseModel, Field
23
24
  from typing_extensions import Any
24
25
 
25
- from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
26
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
26
27
  from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
27
28
  from graphiti_core.embedder import EmbedderClient
28
29
  from graphiti_core.graphiti_types import GraphitiClients
29
30
  from graphiti_core.helpers import normalize_l2, semaphore_gather
30
31
  from graphiti_core.models.edges.edge_db_queries import (
31
- EPISODIC_EDGE_SAVE_BULK,
32
32
  get_entity_edge_save_bulk_query,
33
+ get_episodic_edge_save_bulk_query,
33
34
  )
34
35
  from graphiti_core.models.nodes.node_db_queries import (
35
36
  get_entity_node_save_bulk_query,
36
37
  get_episode_node_save_bulk_query,
37
38
  )
38
39
  from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
40
+ from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
39
41
  from graphiti_core.utils.maintenance.edge_operations import (
40
42
  extract_edges,
41
43
  resolve_extracted_edge,
@@ -116,11 +118,16 @@ async def add_nodes_and_edges_bulk_tx(
116
118
  episodes = [dict(episode) for episode in episodic_nodes]
117
119
  for episode in episodes:
118
120
  episode['source'] = str(episode['source'].value)
119
- episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
120
- nodes: list[dict[str, Any]] = []
121
+ episode.pop('labels', None)
122
+ if driver.provider == GraphProvider.NEO4J:
123
+ episode['group_label'] = 'Episodic_' + episode['group_id'].replace('-', '')
124
+
125
+ nodes = []
126
+
121
127
  for node in entity_nodes:
122
128
  if node.name_embedding is None:
123
129
  await node.generate_name_embedding(embedder)
130
+
124
131
  entity_data: dict[str, Any] = {
125
132
  'uuid': node.uuid,
126
133
  'name': node.name,
@@ -130,13 +137,19 @@ async def add_nodes_and_edges_bulk_tx(
130
137
  'created_at': node.created_at,
131
138
  }
132
139
 
133
- entity_data.update(node.attributes or {})
134
- entity_data['labels'] = list(
135
- set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
136
- )
140
+ entity_data['labels'] = list(set(node.labels + ['Entity']))
141
+ if driver.provider == GraphProvider.KUZU:
142
+ attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
143
+ entity_data['attributes'] = json.dumps(attributes)
144
+ else:
145
+ entity_data.update(node.attributes or {})
146
+ entity_data['labels'] = list(
147
+ set(node.labels + ['Entity', 'Entity_' + node.group_id.replace('-', '')])
148
+ )
149
+
137
150
  nodes.append(entity_data)
138
151
 
139
- edges: list[dict[str, Any]] = []
152
+ edges = []
140
153
  for edge in entity_edges:
141
154
  if edge.fact_embedding is None:
142
155
  await edge.generate_embedding(embedder)
@@ -155,17 +168,36 @@ async def add_nodes_and_edges_bulk_tx(
155
168
  'invalid_at': edge.invalid_at,
156
169
  }
157
170
 
158
- edge_data.update(edge.attributes or {})
171
+ if driver.provider == GraphProvider.KUZU:
172
+ attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
173
+ edge_data['attributes'] = json.dumps(attributes)
174
+ else:
175
+ edge_data.update(edge.attributes or {})
176
+
159
177
  edges.append(edge_data)
160
178
 
161
- await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
162
- entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
163
- await tx.run(entity_node_save_bulk, nodes=nodes)
164
- await tx.run(
165
- EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
166
- )
167
- entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
168
- await tx.run(entity_edge_save_bulk, entity_edges=edges)
179
+ if driver.provider == GraphProvider.KUZU:
180
+ # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
181
+ episode_query = get_episode_node_save_bulk_query(driver.provider)
182
+ for episode in episodes:
183
+ await tx.run(episode_query, **episode)
184
+ entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
185
+ for node in nodes:
186
+ await tx.run(entity_node_query, **node)
187
+ entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
188
+ for edge in edges:
189
+ await tx.run(entity_edge_query, **edge)
190
+ episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
191
+ for edge in episodic_edges:
192
+ await tx.run(episodic_edge_query, **edge.model_dump())
193
+ else:
194
+ await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
195
+ await tx.run(get_entity_node_save_bulk_query(driver.provider, nodes), nodes=nodes)
196
+ await tx.run(
197
+ get_episodic_edge_save_bulk_query(driver.provider),
198
+ episodic_edges=[edge.model_dump() for edge in episodic_edges],
199
+ )
200
+ await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges)
169
201
 
170
202
 
171
203
  async def extract_nodes_and_edges_bulk(
@@ -40,3 +40,16 @@ def ensure_utc(dt: datetime | None) -> datetime | None:
40
40
  return dt.astimezone(timezone.utc)
41
41
 
42
42
  return dt
43
+
44
+
45
+ def convert_datetimes_to_strings(obj):
46
+ if isinstance(obj, dict):
47
+ return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
48
+ elif isinstance(obj, list):
49
+ return [convert_datetimes_to_strings(item) for item in obj]
50
+ elif isinstance(obj, tuple):
51
+ return tuple(convert_datetimes_to_strings(item) for item in obj)
52
+ elif isinstance(obj, datetime):
53
+ return obj.isoformat()
54
+ else:
55
+ return obj
@@ -4,11 +4,12 @@ from collections import defaultdict
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
7
- from graphiti_core.driver.driver import GraphDriver
7
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
8
8
  from graphiti_core.edges import CommunityEdge
9
9
  from graphiti_core.embedder import EmbedderClient
10
10
  from graphiti_core.helpers import semaphore_gather
11
11
  from graphiti_core.llm_client import LLMClient
12
+ from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN
12
13
  from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
13
14
  from graphiti_core.prompts import prompt_library
14
15
  from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
@@ -33,11 +34,11 @@ async def get_community_clusters(
33
34
  if group_ids is None:
34
35
  group_id_values, _, _ = await driver.execute_query(
35
36
  """
36
- MATCH (n:Entity)
37
- WHERE n.group_id IS NOT NULL
38
- RETURN
39
- collect(DISTINCT n.group_id) AS group_ids
40
- """,
37
+ MATCH (n:Entity)
38
+ WHERE n.group_id IS NOT NULL
39
+ RETURN
40
+ collect(DISTINCT n.group_id) AS group_ids
41
+ """
41
42
  )
42
43
 
43
44
  group_ids = group_id_values[0]['group_ids'] if group_id_values else []
@@ -46,14 +47,21 @@ async def get_community_clusters(
46
47
  projection: dict[str, list[Neighbor]] = {}
47
48
  nodes = await EntityNode.get_by_group_ids(driver, [group_id])
48
49
  for node in nodes:
49
- records, _, _ = await driver.execute_query(
50
+ match_query = """
51
+ MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
52
+ """
53
+ if driver.provider == GraphProvider.KUZU:
54
+ match_query = """
55
+ MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
50
56
  """
51
- MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
52
- WITH count(r) AS count, m.uuid AS uuid
53
- RETURN
54
- uuid,
55
- count
56
- """,
57
+ records, _, _ = await driver.execute_query(
58
+ match_query
59
+ + """
60
+ WITH count(e) AS count, m.uuid AS uuid
61
+ RETURN
62
+ uuid,
63
+ count
64
+ """,
57
65
  uuid=node.uuid,
58
66
  group_id=group_id,
59
67
  )
@@ -235,9 +243,9 @@ async def build_communities(
235
243
  async def remove_communities(driver: GraphDriver):
236
244
  await driver.execute_query(
237
245
  """
238
- MATCH (c:Community)
239
- DETACH DELETE c
240
- """,
246
+ MATCH (c:Community)
247
+ DETACH DELETE c
248
+ """
241
249
  )
242
250
 
243
251
 
@@ -247,14 +255,10 @@ async def determine_entity_community(
247
255
  # Check if the node is already part of a community
248
256
  records, _, _ = await driver.execute_query(
249
257
  """
250
- MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
251
- RETURN
252
- c.uuid AS uuid,
253
- c.name AS name,
254
- c.group_id AS group_id,
255
- c.created_at AS created_at,
256
- c.summary AS summary
257
- """,
258
+ MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
259
+ RETURN
260
+ """
261
+ + COMMUNITY_NODE_RETURN,
258
262
  entity_uuid=entity.uuid,
259
263
  )
260
264
 
@@ -262,16 +266,19 @@ async def determine_entity_community(
262
266
  return get_community_node_from_record(records[0]), False
263
267
 
264
268
  # If the node has no community, add it to the mode community of surrounding entities
269
+ match_query = """
270
+ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
271
+ """
272
+ if driver.provider == GraphProvider.KUZU:
273
+ match_query = """
274
+ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
275
+ """
265
276
  records, _, _ = await driver.execute_query(
277
+ match_query
278
+ + """
279
+ RETURN
266
280
  """
267
- MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
268
- RETURN
269
- c.uuid AS uuid,
270
- c.name AS name,
271
- c.group_id AS group_id,
272
- c.created_at AS created_at,
273
- c.summary AS summary
274
- """,
281
+ + COMMUNITY_NODE_RETURN,
275
282
  entity_uuid=entity.uuid,
276
283
  )
277
284
 
@@ -531,17 +531,28 @@ async def filter_existing_duplicate_of_edges(
531
531
  routing_='r',
532
532
  )
533
533
  else:
534
- query: LiteralString = """
535
- UNWIND $duplicate_node_uuids AS duplicate_tuple
536
- MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
537
- RETURN DISTINCT
538
- n.uuid AS source_uuid,
539
- m.uuid AS target_uuid
540
- """
534
+ if driver.provider == GraphProvider.KUZU:
535
+ query = """
536
+ UNWIND $duplicate_node_uuids AS duplicate
537
+ MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
538
+ RETURN DISTINCT
539
+ n.uuid AS source_uuid,
540
+ m.uuid AS target_uuid
541
+ """
542
+ duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
543
+ else:
544
+ query: LiteralString = """
545
+ UNWIND $duplicate_node_uuids AS duplicate_tuple
546
+ MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
547
+ RETURN DISTINCT
548
+ n.uuid AS source_uuid,
549
+ m.uuid AS target_uuid
550
+ """
551
+ duplicate_node_uuids = list(duplicate_nodes_map.keys())
541
552
 
542
553
  records, _, _ = await driver.execute_query(
543
554
  query,
544
- duplicate_node_uuids=list(duplicate_nodes_map.keys()),
555
+ duplicate_node_uuids=duplicate_node_uuids,
545
556
  routing_='r',
546
557
  )
547
558
 
@@ -53,10 +53,29 @@ async def build_indices_and_constraints(driver: GraphDriver, delete_existing: bo
53
53
  for name in index_names
54
54
  ]
55
55
  )
56
+
56
57
  range_indices: list[LiteralString] = get_range_indices(driver.provider)
57
58
 
58
59
  fulltext_indices: list[LiteralString] = get_fulltext_indices(driver.provider)
59
60
 
61
+ if driver.provider == GraphProvider.KUZU:
62
+ # Skip creating fulltext indices if they already exist. Need to do this manually
63
+ # until Kuzu supports `IF NOT EXISTS` for indices.
64
+ result, _, _ = await driver.execute_query('CALL SHOW_INDEXES() RETURN *;')
65
+ if len(result) > 0:
66
+ fulltext_indices = []
67
+
68
+ # Only load the `fts` extension if it's not already loaded, otherwise throw an error.
69
+ result, _, _ = await driver.execute_query('CALL SHOW_LOADED_EXTENSIONS() RETURN *;')
70
+ if len(result) == 0:
71
+ fulltext_indices.insert(
72
+ 0,
73
+ """
74
+ INSTALL fts;
75
+ LOAD fts;
76
+ """,
77
+ )
78
+
60
79
  index_queries: list[LiteralString] = range_indices + fulltext_indices
61
80
 
62
81
  await semaphore_gather(
@@ -76,10 +95,19 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
76
95
  await tx.run('MATCH (n) DETACH DELETE n')
77
96
 
78
97
  async def delete_group_ids(tx):
79
- await tx.run(
80
- 'MATCH (n) WHERE (n:Entity OR n:Episodic OR n:Community) AND n.group_id IN $group_ids DETACH DELETE n',
81
- group_ids=group_ids,
82
- )
98
+ labels = ['Entity', 'Episodic', 'Community']
99
+ if driver.provider == GraphProvider.KUZU:
100
+ labels.append('RelatesToNode_')
101
+
102
+ for label in labels:
103
+ await tx.run(
104
+ f"""
105
+ MATCH (n:{label})
106
+ WHERE n.group_id IN $group_ids
107
+ DETACH DELETE n
108
+ """,
109
+ group_ids=group_ids,
110
+ )
83
111
 
84
112
  if group_ids is None:
85
113
  await session.execute_write(delete_all)
@@ -108,18 +136,23 @@ async def retrieve_episodes(
108
136
  Returns:
109
137
  list[EpisodicNode]: A list of EpisodicNode objects representing the retrieved episodes.
110
138
  """
111
- group_id_filter: LiteralString = (
112
- '\nAND e.group_id IN $group_ids' if group_ids and len(group_ids) > 0 else ''
113
- )
114
- source_filter: LiteralString = '\nAND e.source = $source' if source is not None else ''
139
+
140
+ query_params: dict = {}
141
+ query_filter = ''
142
+ if group_ids and len(group_ids) > 0:
143
+ query_filter += '\nAND e.group_id IN $group_ids'
144
+ query_params['group_ids'] = group_ids
145
+
146
+ if source is not None:
147
+ query_filter += '\nAND e.source = $source'
148
+ query_params['source'] = source.name
115
149
 
116
150
  query: LiteralString = (
117
151
  """
118
- MATCH (e:Episodic)
119
- WHERE e.valid_at <= $reference_time
120
- """
121
- + group_id_filter
122
- + source_filter
152
+ MATCH (e:Episodic)
153
+ WHERE e.valid_at <= $reference_time
154
+ """
155
+ + query_filter
123
156
  + """
124
157
  RETURN
125
158
  """
@@ -136,9 +169,8 @@ async def retrieve_episodes(
136
169
  result, _, _ = await driver.execute_query(
137
170
  query,
138
171
  reference_time=reference_time,
139
- source=source.name if source is not None else None,
140
172
  num_episodes=last_n,
141
- group_ids=group_ids,
173
+ **query_params,
142
174
  )
143
175
 
144
176
  episodes = [get_episodic_node_from_record(record) for record in result]
@@ -148,41 +180,39 @@ async def retrieve_episodes(
148
180
  async def build_dynamic_indexes(driver: GraphDriver, group_id: str):
149
181
  # Make sure indices exist for this group_id in Neo4j
150
182
  if driver.provider == GraphProvider.NEO4J:
151
- await semaphore_gather(
152
- driver.execute_query(
153
- """CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS
154
- FOR (e:"""
155
- + 'Episodic_'
156
- + group_id.replace('-', '')
157
- + """) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
158
- episode_content='episode_content_' + group_id.replace('-', ''),
159
- ),
160
- driver.execute_query(
161
- """CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:"""
162
- + 'Entity_'
163
- + group_id.replace('-', '')
164
- + """) ON EACH [n.name, n.summary, n.group_id]""",
165
- node_name_and_summary='node_name_and_summary_' + group_id.replace('-', ''),
166
- ),
167
- driver.execute_query(
168
- """CREATE FULLTEXT INDEX $community_name IF NOT EXISTS
169
- FOR (n:"""
170
- + 'Community_'
171
- + group_id.replace('-', '')
172
- + """) ON EACH [n.name, n.group_id]""",
173
- community_name='Community_' + group_id.replace('-', ''),
174
- ),
175
- driver.execute_query(
176
- """CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS
177
- FOR (n:"""
178
- + 'Entity_'
179
- + group_id.replace('-', '')
180
- + """)
183
+ await driver.execute_query(
184
+ """CREATE FULLTEXT INDEX $episode_content IF NOT EXISTS
185
+ FOR (e:"""
186
+ + 'Episodic_'
187
+ + group_id.replace('-', '')
188
+ + """) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
189
+ episode_content='episode_content_' + group_id.replace('-', ''),
190
+ )
191
+ await driver.execute_query(
192
+ """CREATE FULLTEXT INDEX $node_name_and_summary IF NOT EXISTS FOR (n:"""
193
+ + 'Entity_'
194
+ + group_id.replace('-', '')
195
+ + """) ON EACH [n.name, n.summary, n.group_id]""",
196
+ node_name_and_summary='node_name_and_summary_' + group_id.replace('-', ''),
197
+ )
198
+ await driver.execute_query(
199
+ """CREATE FULLTEXT INDEX $community_name IF NOT EXISTS
200
+ FOR (n:"""
201
+ + 'Community_'
202
+ + group_id.replace('-', '')
203
+ + """) ON EACH [n.name, n.group_id]""",
204
+ community_name='Community_' + group_id.replace('-', ''),
205
+ )
206
+ await driver.execute_query(
207
+ """CREATE VECTOR INDEX $group_entity_vector IF NOT EXISTS
208
+ FOR (n:"""
209
+ + 'Entity_'
210
+ + group_id.replace('-', '')
211
+ + """)
181
212
  ON n.embedding
182
213
  OPTIONS { indexConfig: {
183
214
  `vector.dimensions`: 1024,
184
215
  `vector.similarity_function`: 'cosine'
185
216
  }}""",
186
- group_entity_vector='group_entity_vector_' + group_id.replace('-', ''),
187
- ),
217
+ group_entity_vector='group_entity_vector_' + group_id.replace('-', ''),
188
218
  )