graphiti-core 0.17.11__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

graphiti_core/nodes.py CHANGED
@@ -25,29 +25,22 @@ from uuid import uuid4
25
25
  from pydantic import BaseModel, Field
26
26
  from typing_extensions import LiteralString
27
27
 
28
- from graphiti_core.driver.driver import GraphDriver
28
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
29
29
  from graphiti_core.embedder import EmbedderClient
30
30
  from graphiti_core.errors import NodeNotFoundError
31
31
  from graphiti_core.helpers import parse_db_date
32
32
  from graphiti_core.models.nodes.node_db_queries import (
33
- COMMUNITY_NODE_SAVE,
34
- ENTITY_NODE_SAVE,
33
+ COMMUNITY_NODE_RETURN,
34
+ ENTITY_NODE_RETURN,
35
+ EPISODIC_NODE_RETURN,
35
36
  EPISODIC_NODE_SAVE,
37
+ get_community_node_save_query,
38
+ get_entity_node_save_query,
36
39
  )
37
40
  from graphiti_core.utils.datetime_utils import utc_now
38
41
 
39
42
  logger = logging.getLogger(__name__)
40
43
 
41
- ENTITY_NODE_RETURN: LiteralString = """
42
- RETURN
43
- n.uuid As uuid,
44
- n.name AS name,
45
- n.group_id AS group_id,
46
- n.created_at AS created_at,
47
- n.summary AS summary,
48
- labels(n) AS labels,
49
- properties(n) AS attributes"""
50
-
51
44
 
52
45
  class EpisodeType(Enum):
53
46
  """
@@ -96,18 +89,26 @@ class Node(BaseModel, ABC):
96
89
  async def save(self, driver: GraphDriver): ...
97
90
 
98
91
  async def delete(self, driver: GraphDriver):
99
- result = await driver.execute_query(
100
- """
101
- MATCH (n:Entity|Episodic|Community {uuid: $uuid})
102
- DETACH DELETE n
103
- """,
104
- uuid=self.uuid,
105
- )
92
+ if driver.provider == GraphProvider.FALKORDB:
93
+ for label in ['Entity', 'Episodic', 'Community']:
94
+ await driver.execute_query(
95
+ f"""
96
+ MATCH (n:{label} {{uuid: $uuid}})
97
+ DETACH DELETE n
98
+ """,
99
+ uuid=self.uuid,
100
+ )
101
+ else:
102
+ await driver.execute_query(
103
+ """
104
+ MATCH (n:Entity|Episodic|Community {uuid: $uuid})
105
+ DETACH DELETE n
106
+ """,
107
+ uuid=self.uuid,
108
+ )
106
109
 
107
110
  logger.debug(f'Deleted Node: {self.uuid}')
108
111
 
109
- return result
110
-
111
112
  def __hash__(self):
112
113
  return hash(self.uuid)
113
114
 
@@ -118,15 +119,23 @@ class Node(BaseModel, ABC):
118
119
 
119
120
  @classmethod
120
121
  async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
121
- await driver.execute_query(
122
- """
123
- MATCH (n:Entity|Episodic|Community {group_id: $group_id})
124
- DETACH DELETE n
125
- """,
126
- group_id=group_id,
127
- )
128
-
129
- return 'SUCCESS'
122
+ if driver.provider == GraphProvider.FALKORDB:
123
+ for label in ['Entity', 'Episodic', 'Community']:
124
+ await driver.execute_query(
125
+ f"""
126
+ MATCH (n:{label} {{group_id: $group_id}})
127
+ DETACH DELETE n
128
+ """,
129
+ group_id=group_id,
130
+ )
131
+ else:
132
+ await driver.execute_query(
133
+ """
134
+ MATCH (n:Entity|Episodic|Community {group_id: $group_id})
135
+ DETACH DELETE n
136
+ """,
137
+ group_id=group_id,
138
+ )
130
139
 
131
140
  @classmethod
132
141
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@@ -169,17 +178,10 @@ class EpisodicNode(Node):
169
178
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
170
179
  records, _, _ = await driver.execute_query(
171
180
  """
172
- MATCH (e:Episodic {uuid: $uuid})
173
- RETURN e.content AS content,
174
- e.created_at AS created_at,
175
- e.valid_at AS valid_at,
176
- e.uuid AS uuid,
177
- e.name AS name,
178
- e.group_id AS group_id,
179
- e.source_description AS source_description,
180
- e.source AS source,
181
- e.entity_edges AS entity_edges
182
- """,
181
+ MATCH (e:Episodic {uuid: $uuid})
182
+ RETURN
183
+ """
184
+ + EPISODIC_NODE_RETURN,
183
185
  uuid=uuid,
184
186
  routing_='r',
185
187
  )
@@ -195,18 +197,11 @@ class EpisodicNode(Node):
195
197
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
196
198
  records, _, _ = await driver.execute_query(
197
199
  """
198
- MATCH (e:Episodic) WHERE e.uuid IN $uuids
200
+ MATCH (e:Episodic)
201
+ WHERE e.uuid IN $uuids
199
202
  RETURN DISTINCT
200
- e.content AS content,
201
- e.created_at AS created_at,
202
- e.valid_at AS valid_at,
203
- e.uuid AS uuid,
204
- e.name AS name,
205
- e.group_id AS group_id,
206
- e.source_description AS source_description,
207
- e.source AS source,
208
- e.entity_edges AS entity_edges
209
- """,
203
+ """
204
+ + EPISODIC_NODE_RETURN,
210
205
  uuids=uuids,
211
206
  routing_='r',
212
207
  )
@@ -228,22 +223,17 @@ class EpisodicNode(Node):
228
223
 
229
224
  records, _, _ = await driver.execute_query(
230
225
  """
231
- MATCH (e:Episodic) WHERE e.group_id IN $group_ids
232
- """
226
+ MATCH (e:Episodic)
227
+ WHERE e.group_id IN $group_ids
228
+ """
233
229
  + cursor_query
234
230
  + """
235
231
  RETURN DISTINCT
236
- e.content AS content,
237
- e.created_at AS created_at,
238
- e.valid_at AS valid_at,
239
- e.uuid AS uuid,
240
- e.name AS name,
241
- e.group_id AS group_id,
242
- e.source_description AS source_description,
243
- e.source AS source,
244
- e.entity_edges AS entity_edges
245
- ORDER BY e.uuid DESC
246
- """
232
+ """
233
+ + EPISODIC_NODE_RETURN
234
+ + """
235
+ ORDER BY uuid DESC
236
+ """
247
237
  + limit_query,
248
238
  group_ids=group_ids,
249
239
  uuid=uuid_cursor,
@@ -259,18 +249,10 @@ class EpisodicNode(Node):
259
249
  async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
260
250
  records, _, _ = await driver.execute_query(
261
251
  """
262
- MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
252
+ MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
263
253
  RETURN DISTINCT
264
- e.content AS content,
265
- e.created_at AS created_at,
266
- e.valid_at AS valid_at,
267
- e.uuid AS uuid,
268
- e.name AS name,
269
- e.group_id AS group_id,
270
- e.source_description AS source_description,
271
- e.source AS source,
272
- e.entity_edges AS entity_edges
273
- """,
254
+ """
255
+ + EPISODIC_NODE_RETURN,
274
256
  entity_node_uuid=entity_node_uuid,
275
257
  routing_='r',
276
258
  )
@@ -297,11 +279,14 @@ class EntityNode(Node):
297
279
  return self.name_embedding
298
280
 
299
281
  async def load_name_embedding(self, driver: GraphDriver):
300
- query: LiteralString = """
282
+ records, _, _ = await driver.execute_query(
283
+ """
301
284
  MATCH (n:Entity {uuid: $uuid})
302
285
  RETURN n.name_embedding AS name_embedding
303
- """
304
- records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
286
+ """,
287
+ uuid=self.uuid,
288
+ routing_='r',
289
+ )
305
290
 
306
291
  if len(records) == 0:
307
292
  raise NodeNotFoundError(self.uuid)
@@ -317,12 +302,12 @@ class EntityNode(Node):
317
302
  'summary': self.summary,
318
303
  'created_at': self.created_at,
319
304
  }
320
-
321
305
  entity_data.update(self.attributes or {})
322
306
 
307
+ labels = ':'.join(self.labels + ['Entity'])
308
+
323
309
  result = await driver.execute_query(
324
- ENTITY_NODE_SAVE,
325
- labels=self.labels + ['Entity'],
310
+ get_entity_node_save_query(driver.provider, labels),
326
311
  entity_data=entity_data,
327
312
  )
328
313
 
@@ -332,14 +317,12 @@ class EntityNode(Node):
332
317
 
333
318
  @classmethod
334
319
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
335
- query = (
336
- """
337
- MATCH (n:Entity {uuid: $uuid})
338
- """
339
- + ENTITY_NODE_RETURN
340
- )
341
320
  records, _, _ = await driver.execute_query(
342
- query,
321
+ """
322
+ MATCH (n:Entity {uuid: $uuid})
323
+ RETURN
324
+ """
325
+ + ENTITY_NODE_RETURN,
343
326
  uuid=uuid,
344
327
  routing_='r',
345
328
  )
@@ -355,8 +338,10 @@ class EntityNode(Node):
355
338
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
356
339
  records, _, _ = await driver.execute_query(
357
340
  """
358
- MATCH (n:Entity) WHERE n.uuid IN $uuids
359
- """
341
+ MATCH (n:Entity)
342
+ WHERE n.uuid IN $uuids
343
+ RETURN
344
+ """
360
345
  + ENTITY_NODE_RETURN,
361
346
  uuids=uuids,
362
347
  routing_='r',
@@ -379,22 +364,26 @@ class EntityNode(Node):
379
364
  limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
380
365
  with_embeddings_query: LiteralString = (
381
366
  """,
382
- n.name_embedding AS name_embedding
383
- """
367
+ n.name_embedding AS name_embedding
368
+ """
384
369
  if with_embeddings
385
370
  else ''
386
371
  )
387
372
 
388
373
  records, _, _ = await driver.execute_query(
389
374
  """
390
- MATCH (n:Entity) WHERE n.group_id IN $group_ids
391
- """
375
+ MATCH (n:Entity)
376
+ WHERE n.group_id IN $group_ids
377
+ """
392
378
  + cursor_query
379
+ + """
380
+ RETURN
381
+ """
393
382
  + ENTITY_NODE_RETURN
394
383
  + with_embeddings_query
395
384
  + """
396
- ORDER BY n.uuid DESC
397
- """
385
+ ORDER BY n.uuid DESC
386
+ """
398
387
  + limit_query,
399
388
  group_ids=group_ids,
400
389
  uuid=uuid_cursor,
@@ -413,7 +402,7 @@ class CommunityNode(Node):
413
402
 
414
403
  async def save(self, driver: GraphDriver):
415
404
  result = await driver.execute_query(
416
- COMMUNITY_NODE_SAVE,
405
+ get_community_node_save_query(driver.provider),
417
406
  uuid=self.uuid,
418
407
  name=self.name,
419
408
  group_id=self.group_id,
@@ -436,11 +425,14 @@ class CommunityNode(Node):
436
425
  return self.name_embedding
437
426
 
438
427
  async def load_name_embedding(self, driver: GraphDriver):
439
- query: LiteralString = """
428
+ records, _, _ = await driver.execute_query(
429
+ """
440
430
  MATCH (c:Community {uuid: $uuid})
441
431
  RETURN c.name_embedding AS name_embedding
442
- """
443
- records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
432
+ """,
433
+ uuid=self.uuid,
434
+ routing_='r',
435
+ )
444
436
 
445
437
  if len(records) == 0:
446
438
  raise NodeNotFoundError(self.uuid)
@@ -451,14 +443,10 @@ class CommunityNode(Node):
451
443
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
452
444
  records, _, _ = await driver.execute_query(
453
445
  """
454
- MATCH (n:Community {uuid: $uuid})
455
- RETURN
456
- n.uuid As uuid,
457
- n.name AS name,
458
- n.group_id AS group_id,
459
- n.created_at AS created_at,
460
- n.summary AS summary
461
- """,
446
+ MATCH (n:Community {uuid: $uuid})
447
+ RETURN
448
+ """
449
+ + COMMUNITY_NODE_RETURN,
462
450
  uuid=uuid,
463
451
  routing_='r',
464
452
  )
@@ -474,14 +462,11 @@ class CommunityNode(Node):
474
462
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
475
463
  records, _, _ = await driver.execute_query(
476
464
  """
477
- MATCH (n:Community) WHERE n.uuid IN $uuids
478
- RETURN
479
- n.uuid As uuid,
480
- n.name AS name,
481
- n.group_id AS group_id,
482
- n.created_at AS created_at,
483
- n.summary AS summary
484
- """,
465
+ MATCH (n:Community)
466
+ WHERE n.uuid IN $uuids
467
+ RETURN
468
+ """
469
+ + COMMUNITY_NODE_RETURN,
485
470
  uuids=uuids,
486
471
  routing_='r',
487
472
  )
@@ -503,18 +488,17 @@ class CommunityNode(Node):
503
488
 
504
489
  records, _, _ = await driver.execute_query(
505
490
  """
506
- MATCH (n:Community) WHERE n.group_id IN $group_ids
507
- """
491
+ MATCH (n:Community)
492
+ WHERE n.group_id IN $group_ids
493
+ """
508
494
  + cursor_query
509
495
  + """
510
- RETURN
511
- n.uuid As uuid,
512
- n.name AS name,
513
- n.group_id AS group_id,
514
- n.created_at AS created_at,
515
- n.summary AS summary
516
- ORDER BY n.uuid DESC
517
- """
496
+ RETURN
497
+ """
498
+ + COMMUNITY_NODE_RETURN
499
+ + """
500
+ ORDER BY n.uuid DESC
501
+ """
518
502
  + limit_query,
519
503
  group_ids=group_ids,
520
504
  uuid=uuid_cursor,
@@ -586,6 +570,7 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
586
570
  async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
587
571
  if not nodes: # Handle empty list case
588
572
  return
573
+
589
574
  name_embeddings = await embedder.create_batch([node.name for node in nodes])
590
575
  for node, name_embedding in zip(nodes, name_embeddings, strict=True):
591
576
  node.name_embedding = name_embedding
@@ -34,7 +34,7 @@ class NodeDuplicate(BaseModel):
34
34
  )
35
35
  duplicates: list[int] = Field(
36
36
  ...,
37
- description='idx of all duplicate entities.',
37
+ description='idx of all entities that are a duplicate of the entity with the above id.',
38
38
  )
39
39
 
40
40
 
@@ -68,6 +68,10 @@ def edge(context: dict[str, Any]) -> list[Message]:
68
68
  Message(
69
69
  role='user',
70
70
  content=f"""
71
+ <FACT TYPES>
72
+ {context['edge_types']}
73
+ </FACT TYPES>
74
+
71
75
  <PREVIOUS_MESSAGES>
72
76
  {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
73
77
  </PREVIOUS_MESSAGES>
@@ -84,10 +88,6 @@ def edge(context: dict[str, Any]) -> list[Message]:
84
88
  {context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
85
89
  </REFERENCE_TIME>
86
90
 
87
- <FACT TYPES>
88
- {context['edge_types']}
89
- </FACT TYPES>
90
-
91
91
  # TASK
92
92
  Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
93
93
  Only extract facts that:
@@ -75,6 +75,10 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
75
75
  Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""
76
76
 
77
77
  user_prompt = f"""
78
+ <ENTITY TYPES>
79
+ {context['entity_types']}
80
+ </ENTITY TYPES>
81
+
78
82
  <PREVIOUS MESSAGES>
79
83
  {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
80
84
  </PREVIOUS MESSAGES>
@@ -83,10 +87,6 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
83
87
  {context['episode_content']}
84
88
  </CURRENT MESSAGE>
85
89
 
86
- <ENTITY TYPES>
87
- {context['entity_types']}
88
- </ENTITY TYPES>
89
-
90
90
  Instructions:
91
91
 
92
92
  You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
@@ -124,15 +124,16 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
124
124
  Your primary task is to extract and classify relevant entities from JSON files"""
125
125
 
126
126
  user_prompt = f"""
127
+ <ENTITY TYPES>
128
+ {context['entity_types']}
129
+ </ENTITY TYPES>
130
+
127
131
  <SOURCE DESCRIPTION>:
128
132
  {context['source_description']}
129
133
  </SOURCE DESCRIPTION>
130
134
  <JSON>
131
135
  {context['episode_content']}
132
136
  </JSON>
133
- <ENTITY TYPES>
134
- {context['entity_types']}
135
- </ENTITY TYPES>
136
137
 
137
138
  {context['custom_prompt']}
138
139
 
@@ -155,13 +156,14 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
155
156
  Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""
156
157
 
157
158
  user_prompt = f"""
158
- <TEXT>
159
- {context['episode_content']}
160
- </TEXT>
161
159
  <ENTITY TYPES>
162
160
  {context['entity_types']}
163
161
  </ENTITY TYPES>
164
162
 
163
+ <TEXT>
164
+ {context['episode_content']}
165
+ </TEXT>
166
+
165
167
  Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
166
168
  For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
167
169
  Indicate the classified entity type by providing its entity_type_id.