graphiti-core 0.21.0rc13__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (41):
  1. graphiti_core/driver/driver.py +4 -211
  2. graphiti_core/driver/falkordb_driver.py +31 -3
  3. graphiti_core/driver/graph_operations/graph_operations.py +195 -0
  4. graphiti_core/driver/neo4j_driver.py +0 -49
  5. graphiti_core/driver/neptune_driver.py +43 -26
  6. graphiti_core/driver/search_interface/__init__.py +0 -0
  7. graphiti_core/driver/search_interface/search_interface.py +89 -0
  8. graphiti_core/edges.py +11 -34
  9. graphiti_core/graphiti.py +459 -326
  10. graphiti_core/graphiti_types.py +2 -0
  11. graphiti_core/llm_client/anthropic_client.py +64 -45
  12. graphiti_core/llm_client/client.py +67 -19
  13. graphiti_core/llm_client/gemini_client.py +73 -54
  14. graphiti_core/llm_client/openai_base_client.py +65 -43
  15. graphiti_core/llm_client/openai_generic_client.py +65 -43
  16. graphiti_core/models/edges/edge_db_queries.py +1 -0
  17. graphiti_core/models/nodes/node_db_queries.py +1 -0
  18. graphiti_core/nodes.py +26 -99
  19. graphiti_core/prompts/dedupe_edges.py +4 -4
  20. graphiti_core/prompts/dedupe_nodes.py +10 -10
  21. graphiti_core/prompts/extract_edges.py +4 -4
  22. graphiti_core/prompts/extract_nodes.py +26 -28
  23. graphiti_core/prompts/prompt_helpers.py +18 -2
  24. graphiti_core/prompts/snippets.py +29 -0
  25. graphiti_core/prompts/summarize_nodes.py +22 -24
  26. graphiti_core/search/search_filters.py +0 -38
  27. graphiti_core/search/search_helpers.py +4 -4
  28. graphiti_core/search/search_utils.py +84 -220
  29. graphiti_core/tracer.py +193 -0
  30. graphiti_core/utils/bulk_utils.py +16 -28
  31. graphiti_core/utils/maintenance/community_operations.py +4 -1
  32. graphiti_core/utils/maintenance/edge_operations.py +26 -15
  33. graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
  34. graphiti_core/utils/maintenance/node_operations.py +98 -51
  35. graphiti_core/utils/maintenance/temporal_operations.py +4 -1
  36. graphiti_core/utils/text_utils.py +53 -0
  37. {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/METADATA +7 -3
  38. {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/RECORD +41 -35
  39. /graphiti_core/{utils/maintenance/utils.py → driver/graph_operations/__init__.py} +0 -0
  40. {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/WHEEL +0 -0
  41. {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/licenses/LICENSE +0 -0
@@ -120,13 +120,12 @@ class OpenAIGenericClient(LLMClient):
120
120
  response_model: type[BaseModel] | None = None,
121
121
  max_tokens: int | None = None,
122
122
  model_size: ModelSize = ModelSize.medium,
123
+ group_id: str | None = None,
124
+ prompt_name: str | None = None,
123
125
  ) -> dict[str, typing.Any]:
124
126
  if max_tokens is None:
125
127
  max_tokens = self.max_tokens
126
128
 
127
- retry_count = 0
128
- last_error = None
129
-
130
129
  if response_model is not None:
131
130
  serialized_model = json.dumps(response_model.model_json_schema())
132
131
  messages[
@@ -136,44 +135,67 @@ class OpenAIGenericClient(LLMClient):
136
135
  )
137
136
 
138
137
  # Add multilingual extraction instructions
139
- messages[0].content += get_extraction_language_instruction()
140
-
141
- while retry_count <= self.MAX_RETRIES:
142
- try:
143
- response = await self._generate_response(
144
- messages, response_model, max_tokens=max_tokens, model_size=model_size
145
- )
146
- return response
147
- except (RateLimitError, RefusalError):
148
- # These errors should not trigger retries
149
- raise
150
- except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
151
- # Let OpenAI's client handle these retries
152
- raise
153
- except Exception as e:
154
- last_error = e
155
-
156
- # Don't retry if we've hit the max retries
157
- if retry_count >= self.MAX_RETRIES:
158
- logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
138
+ messages[0].content += get_extraction_language_instruction(group_id)
139
+
140
+ # Wrap entire operation in tracing span
141
+ with self.tracer.start_span('llm.generate') as span:
142
+ attributes = {
143
+ 'llm.provider': 'openai',
144
+ 'model.size': model_size.value,
145
+ 'max_tokens': max_tokens,
146
+ }
147
+ if prompt_name:
148
+ attributes['prompt.name'] = prompt_name
149
+ span.add_attributes(attributes)
150
+
151
+ retry_count = 0
152
+ last_error = None
153
+
154
+ while retry_count <= self.MAX_RETRIES:
155
+ try:
156
+ response = await self._generate_response(
157
+ messages, response_model, max_tokens=max_tokens, model_size=model_size
158
+ )
159
+ return response
160
+ except (RateLimitError, RefusalError):
161
+ # These errors should not trigger retries
162
+ span.set_status('error', str(last_error))
159
163
  raise
160
-
161
- retry_count += 1
162
-
163
- # Construct a detailed error message for the LLM
164
- error_context = (
165
- f'The previous response attempt was invalid. '
166
- f'Error type: {e.__class__.__name__}. '
167
- f'Error details: {str(e)}. '
168
- f'Please try again with a valid response, ensuring the output matches '
169
- f'the expected format and constraints.'
170
- )
171
-
172
- error_message = Message(role='user', content=error_context)
173
- messages.append(error_message)
174
- logger.warning(
175
- f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
176
- )
177
-
178
- # If we somehow get here, raise the last error
179
- raise last_error or Exception('Max retries exceeded with no specific error')
164
+ except (
165
+ openai.APITimeoutError,
166
+ openai.APIConnectionError,
167
+ openai.InternalServerError,
168
+ ):
169
+ # Let OpenAI's client handle these retries
170
+ span.set_status('error', str(last_error))
171
+ raise
172
+ except Exception as e:
173
+ last_error = e
174
+
175
+ # Don't retry if we've hit the max retries
176
+ if retry_count >= self.MAX_RETRIES:
177
+ logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
178
+ span.set_status('error', str(e))
179
+ span.record_exception(e)
180
+ raise
181
+
182
+ retry_count += 1
183
+
184
+ # Construct a detailed error message for the LLM
185
+ error_context = (
186
+ f'The previous response attempt was invalid. '
187
+ f'Error type: {e.__class__.__name__}. '
188
+ f'Error details: {str(e)}. '
189
+ f'Please try again with a valid response, ensuring the output matches '
190
+ f'the expected format and constraints.'
191
+ )
192
+
193
+ error_message = Message(role='user', content=error_context)
194
+ messages.append(error_message)
195
+ logger.warning(
196
+ f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
197
+ )
198
+
199
+ # If we somehow get here, raise the last error
200
+ span.set_status('error', str(last_error))
201
+ raise last_error or Exception('Max retries exceeded with no specific error')
@@ -68,6 +68,7 @@ def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False)
68
68
  MATCH (target:Entity {uuid: $edge_data.target_uuid})
69
69
  MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
70
70
  SET e = $edge_data
71
+ SET e.fact_embedding = vecf32($edge_data.fact_embedding)
71
72
  RETURN e.uuid AS uuid
72
73
  """
73
74
  case GraphProvider.NEPTUNE:
@@ -133,6 +133,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: b
133
133
  MERGE (n:Entity {{uuid: $entity_data.uuid}})
134
134
  SET n:{labels}
135
135
  SET n = $entity_data
136
+ SET n.name_embedding = vecf32($entity_data.name_embedding)
136
137
  RETURN n.uuid AS uuid
137
138
  """
138
139
  case GraphProvider.KUZU:
graphiti_core/nodes.py CHANGED
@@ -27,10 +27,6 @@ from pydantic import BaseModel, Field
27
27
  from typing_extensions import LiteralString
28
28
 
29
29
  from graphiti_core.driver.driver import (
30
- COMMUNITY_INDEX_NAME,
31
- ENTITY_EDGE_INDEX_NAME,
32
- ENTITY_INDEX_NAME,
33
- EPISODE_INDEX_NAME,
34
30
  GraphDriver,
35
31
  GraphProvider,
36
32
  )
@@ -99,6 +95,9 @@ class Node(BaseModel, ABC):
99
95
  async def save(self, driver: GraphDriver): ...
100
96
 
101
97
  async def delete(self, driver: GraphDriver):
98
+ if driver.graph_operations_interface:
99
+ return await driver.graph_operations_interface.node_delete(self, driver)
100
+
102
101
  match driver.provider:
103
102
  case GraphProvider.NEO4J:
104
103
  records, _, _ = await driver.execute_query(
@@ -113,27 +112,6 @@ class Node(BaseModel, ABC):
113
112
  uuid=self.uuid,
114
113
  )
115
114
 
116
- edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
117
-
118
- if driver.aoss_client:
119
- # Delete the node from OpenSearch indices
120
- for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
121
- await driver.aoss_client.delete(
122
- index=index,
123
- id=self.uuid,
124
- params={'routing': self.group_id},
125
- )
126
-
127
- # Bulk delete the detached edges
128
- if edge_uuids:
129
- actions = []
130
- for eid in edge_uuids:
131
- actions.append(
132
- {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
133
- )
134
-
135
- await driver.aoss_client.bulk(body=actions)
136
-
137
115
  case GraphProvider.KUZU:
138
116
  for label in ['Episodic', 'Community']:
139
117
  await driver.execute_query(
@@ -181,14 +159,18 @@ class Node(BaseModel, ABC):
181
159
 
182
160
  @classmethod
183
161
  async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
162
+ if driver.graph_operations_interface:
163
+ return await driver.graph_operations_interface.node_delete_by_group_id(
164
+ cls, driver, group_id, batch_size
165
+ )
166
+
184
167
  match driver.provider:
185
168
  case GraphProvider.NEO4J:
186
169
  async with driver.session() as session:
187
170
  await session.run(
188
171
  """
189
172
  MATCH (n:Entity|Episodic|Community {group_id: $group_id})
190
- CALL {
191
- WITH n
173
+ CALL (n) {
192
174
  DETACH DELETE n
193
175
  } IN TRANSACTIONS OF $batch_size ROWS
194
176
  """,
@@ -196,31 +178,6 @@ class Node(BaseModel, ABC):
196
178
  batch_size=batch_size,
197
179
  )
198
180
 
199
- if driver.aoss_client:
200
- await driver.aoss_client.delete_by_query(
201
- index=EPISODE_INDEX_NAME,
202
- body={'query': {'term': {'group_id': group_id}}},
203
- params={'routing': group_id},
204
- )
205
-
206
- await driver.aoss_client.delete_by_query(
207
- index=ENTITY_INDEX_NAME,
208
- body={'query': {'term': {'group_id': group_id}}},
209
- params={'routing': group_id},
210
- )
211
-
212
- await driver.aoss_client.delete_by_query(
213
- index=COMMUNITY_INDEX_NAME,
214
- body={'query': {'term': {'group_id': group_id}}},
215
- params={'routing': group_id},
216
- )
217
-
218
- await driver.aoss_client.delete_by_query(
219
- index=ENTITY_EDGE_INDEX_NAME,
220
- body={'query': {'term': {'group_id': group_id}}},
221
- params={'routing': group_id},
222
- )
223
-
224
181
  case GraphProvider.KUZU:
225
182
  for label in ['Episodic', 'Community']:
226
183
  await driver.execute_query(
@@ -258,6 +215,11 @@ class Node(BaseModel, ABC):
258
215
 
259
216
  @classmethod
260
217
  async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
218
+ if driver.graph_operations_interface:
219
+ return await driver.graph_operations_interface.node_delete_by_uuids(
220
+ cls, driver, uuids, group_id=None, batch_size=batch_size
221
+ )
222
+
261
223
  match driver.provider:
262
224
  case GraphProvider.FALKORDB:
263
225
  for label in ['Entity', 'Episodic', 'Community']:
@@ -300,7 +262,7 @@ class Node(BaseModel, ABC):
300
262
  case _: # Neo4J, Neptune
301
263
  async with driver.session() as session:
302
264
  # Collect all edge UUIDs before deleting nodes
303
- result = await session.run(
265
+ await session.run(
304
266
  """
305
267
  MATCH (n:Entity|Episodic|Community)
306
268
  WHERE n.uuid IN $uuids
@@ -310,18 +272,12 @@ class Node(BaseModel, ABC):
310
272
  uuids=uuids,
311
273
  )
312
274
 
313
- record = await result.single()
314
- edge_uuids: list[str] = (
315
- record['edge_uuids'] if record and record['edge_uuids'] else []
316
- )
317
-
318
275
  # Now delete the nodes in batches
319
276
  await session.run(
320
277
  """
321
278
  MATCH (n:Entity|Episodic|Community)
322
279
  WHERE n.uuid IN $uuids
323
- CALL {
324
- WITH n
280
+ CALL (n) {
325
281
  DETACH DELETE n
326
282
  } IN TRANSACTIONS OF $batch_size ROWS
327
283
  """,
@@ -329,20 +285,6 @@ class Node(BaseModel, ABC):
329
285
  batch_size=batch_size,
330
286
  )
331
287
 
332
- if driver.aoss_client:
333
- for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
334
- await driver.aoss_client.delete_by_query(
335
- index=index,
336
- body={'query': {'terms': {'uuid': uuids}}},
337
- )
338
-
339
- if edge_uuids:
340
- actions = [
341
- {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
342
- for eid in edge_uuids
343
- ]
344
- await driver.aoss_client.bulk(body=actions)
345
-
346
288
  @classmethod
347
289
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
348
290
 
@@ -363,6 +305,9 @@ class EpisodicNode(Node):
363
305
  )
364
306
 
365
307
  async def save(self, driver: GraphDriver):
308
+ if driver.graph_operations_interface:
309
+ return await driver.graph_operations_interface.episodic_node_save(self, driver)
310
+
366
311
  episode_args = {
367
312
  'uuid': self.uuid,
368
313
  'name': self.name,
@@ -375,12 +320,6 @@ class EpisodicNode(Node):
375
320
  'source': self.source.value,
376
321
  }
377
322
 
378
- if driver.aoss_client:
379
- await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
380
- 'episodes',
381
- [episode_args],
382
- )
383
-
384
323
  result = await driver.execute_query(
385
324
  get_episode_node_save_query(driver.provider), **episode_args
386
325
  )
@@ -510,26 +449,14 @@ class EntityNode(Node):
510
449
  return self.name_embedding
511
450
 
512
451
  async def load_name_embedding(self, driver: GraphDriver):
452
+ if driver.graph_operations_interface:
453
+ return await driver.graph_operations_interface.node_load_embeddings(self, driver)
454
+
513
455
  if driver.provider == GraphProvider.NEPTUNE:
514
456
  query: LiteralString = """
515
457
  MATCH (n:Entity {uuid: $uuid})
516
458
  RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
517
459
  """
518
- elif driver.aoss_client:
519
- resp = await driver.aoss_client.search(
520
- body={
521
- 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
522
- 'size': 1,
523
- },
524
- index=ENTITY_INDEX_NAME,
525
- params={'routing': self.group_id},
526
- )
527
-
528
- if resp['hits']['hits']:
529
- self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
530
- return
531
- else:
532
- raise NodeNotFoundError(self.uuid)
533
460
 
534
461
  else:
535
462
  query: LiteralString = """
@@ -548,6 +475,9 @@ class EntityNode(Node):
548
475
  self.name_embedding = records[0]['name_embedding']
549
476
 
550
477
  async def save(self, driver: GraphDriver):
478
+ if driver.graph_operations_interface:
479
+ return await driver.graph_operations_interface.node_save(self, driver)
480
+
551
481
  entity_data: dict[str, Any] = {
552
482
  'uuid': self.uuid,
553
483
  'name': self.name,
@@ -568,11 +498,8 @@ class EntityNode(Node):
568
498
  entity_data.update(self.attributes or {})
569
499
  labels = ':'.join(self.labels + ['Entity'])
570
500
 
571
- if driver.aoss_client:
572
- await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
573
-
574
501
  result = await driver.execute_query(
575
- get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
502
+ get_entity_node_save_query(driver.provider, labels),
576
503
  entity_data=entity_data,
577
504
  )
578
505
 
@@ -67,13 +67,13 @@ def edge(context: dict[str, Any]) -> list[Message]:
67
67
  Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
68
68
 
69
69
  <EXISTING EDGES>
70
- {to_prompt_json(context['related_edges'], indent=2)}
70
+ {to_prompt_json(context['related_edges'])}
71
71
  </EXISTING EDGES>
72
72
 
73
73
  <NEW EDGE>
74
- {to_prompt_json(context['extracted_edges'], indent=2)}
74
+ {to_prompt_json(context['extracted_edges'])}
75
75
  </NEW EDGE>
76
-
76
+
77
77
  Task:
78
78
  If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact
79
79
  as part of the list of duplicate_facts.
@@ -98,7 +98,7 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
98
98
  Given the following context, find all of the duplicates in a list of facts:
99
99
 
100
100
  Facts:
101
- {to_prompt_json(context['edges'], indent=2)}
101
+ {to_prompt_json(context['edges'])}
102
102
 
103
103
  Task:
104
104
  If any facts in Facts is a duplicate of another fact, return a new fact with one of their uuid's.
@@ -64,20 +64,20 @@ def node(context: dict[str, Any]) -> list[Message]:
64
64
  role='user',
65
65
  content=f"""
66
66
  <PREVIOUS MESSAGES>
67
- {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
67
+ {to_prompt_json([ep for ep in context['previous_episodes']])}
68
68
  </PREVIOUS MESSAGES>
69
69
  <CURRENT MESSAGE>
70
70
  {context['episode_content']}
71
71
  </CURRENT MESSAGE>
72
72
  <NEW ENTITY>
73
- {to_prompt_json(context['extracted_node'], indent=2)}
73
+ {to_prompt_json(context['extracted_node'])}
74
74
  </NEW ENTITY>
75
75
  <ENTITY TYPE DESCRIPTION>
76
- {to_prompt_json(context['entity_type_description'], indent=2)}
76
+ {to_prompt_json(context['entity_type_description'])}
77
77
  </ENTITY TYPE DESCRIPTION>
78
78
 
79
79
  <EXISTING ENTITIES>
80
- {to_prompt_json(context['existing_nodes'], indent=2)}
80
+ {to_prompt_json(context['existing_nodes'])}
81
81
  </EXISTING ENTITIES>
82
82
 
83
83
  Given the above EXISTING ENTITIES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW ENTITY extracted from the conversation
@@ -125,13 +125,13 @@ def nodes(context: dict[str, Any]) -> list[Message]:
125
125
  role='user',
126
126
  content=f"""
127
127
  <PREVIOUS MESSAGES>
128
- {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
128
+ {to_prompt_json([ep for ep in context['previous_episodes']])}
129
129
  </PREVIOUS MESSAGES>
130
130
  <CURRENT MESSAGE>
131
131
  {context['episode_content']}
132
132
  </CURRENT MESSAGE>
133
-
134
-
133
+
134
+
135
135
  Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
136
136
  Each entity in ENTITIES is represented as a JSON object with the following structure:
137
137
  {{
@@ -142,11 +142,11 @@ def nodes(context: dict[str, Any]) -> list[Message]:
142
142
  }}
143
143
 
144
144
  <ENTITIES>
145
- {to_prompt_json(context['extracted_nodes'], indent=2)}
145
+ {to_prompt_json(context['extracted_nodes'])}
146
146
  </ENTITIES>
147
147
 
148
148
  <EXISTING ENTITIES>
149
- {to_prompt_json(context['existing_nodes'], indent=2)}
149
+ {to_prompt_json(context['existing_nodes'])}
150
150
  </EXISTING ENTITIES>
151
151
 
152
152
  Each entry in EXISTING ENTITIES is an object with the following structure:
@@ -197,7 +197,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
197
197
  Given the following context, deduplicate a list of nodes:
198
198
 
199
199
  Nodes:
200
- {to_prompt_json(context['nodes'], indent=2)}
200
+ {to_prompt_json(context['nodes'])}
201
201
 
202
202
  Task:
203
203
  1. Group nodes together such that all duplicate nodes are in the same list of uuids
@@ -80,7 +80,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
80
80
  </FACT TYPES>
81
81
 
82
82
  <PREVIOUS_MESSAGES>
83
- {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
83
+ {to_prompt_json([ep for ep in context['previous_episodes']])}
84
84
  </PREVIOUS_MESSAGES>
85
85
 
86
86
  <CURRENT_MESSAGE>
@@ -88,7 +88,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
88
88
  </CURRENT_MESSAGE>
89
89
 
90
90
  <ENTITIES>
91
- {to_prompt_json(context['nodes'], indent=2)}
91
+ {to_prompt_json(context['nodes'])}
92
92
  </ENTITIES>
93
93
 
94
94
  <REFERENCE_TIME>
@@ -141,7 +141,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
141
141
 
142
142
  user_prompt = f"""
143
143
  <PREVIOUS MESSAGES>
144
- {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
144
+ {to_prompt_json([ep for ep in context['previous_episodes']])}
145
145
  </PREVIOUS MESSAGES>
146
146
  <CURRENT MESSAGE>
147
147
  {context['episode_content']}
@@ -175,7 +175,7 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
175
175
  content=f"""
176
176
 
177
177
  <MESSAGE>
178
- {to_prompt_json(context['episode_content'], indent=2)}
178
+ {to_prompt_json(context['episode_content'])}
179
179
  </MESSAGE>
180
180
  <REFERENCE TIME>
181
181
  {context['reference_time']}
@@ -18,8 +18,11 @@ from typing import Any, Protocol, TypedDict
18
18
 
19
19
  from pydantic import BaseModel, Field
20
20
 
21
+ from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS
22
+
21
23
  from .models import Message, PromptFunction, PromptVersion
22
24
  from .prompt_helpers import to_prompt_json
25
+ from .snippets import summary_instructions
23
26
 
24
27
 
25
28
  class ExtractedEntity(BaseModel):
@@ -42,7 +45,8 @@ class EntityClassificationTriple(BaseModel):
42
45
  uuid: str = Field(description='UUID of the entity')
43
46
  name: str = Field(description='Name of the entity')
44
47
  entity_type: str | None = Field(
45
- default=None, description='Type of the entity. Must be one of the provided types or None'
48
+ default=None,
49
+ description='Type of the entity. Must be one of the provided types or None',
46
50
  )
47
51
 
48
52
 
@@ -55,7 +59,7 @@ class EntityClassification(BaseModel):
55
59
  class EntitySummary(BaseModel):
56
60
  summary: str = Field(
57
61
  ...,
58
- description='Summary containing the important information about the entity. Under 250 words',
62
+ description=f'Summary containing the important information about the entity. Under {MAX_SUMMARY_CHARS} characters.',
59
63
  )
60
64
 
61
65
 
@@ -89,7 +93,7 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
89
93
  </ENTITY TYPES>
90
94
 
91
95
  <PREVIOUS MESSAGES>
92
- {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
96
+ {to_prompt_json([ep for ep in context['previous_episodes']])}
93
97
  </PREVIOUS MESSAGES>
94
98
 
95
99
  <CURRENT MESSAGE>
@@ -197,7 +201,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:
197
201
 
198
202
  user_prompt = f"""
199
203
  <PREVIOUS MESSAGES>
200
- {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
204
+ {to_prompt_json([ep for ep in context['previous_episodes']])}
201
205
  </PREVIOUS MESSAGES>
202
206
  <CURRENT MESSAGE>
203
207
  {context['episode_content']}
@@ -221,22 +225,22 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
221
225
 
222
226
  user_prompt = f"""
223
227
  <PREVIOUS MESSAGES>
224
- {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
228
+ {to_prompt_json([ep for ep in context['previous_episodes']])}
225
229
  </PREVIOUS MESSAGES>
226
230
  <CURRENT MESSAGE>
227
231
  {context['episode_content']}
228
232
  </CURRENT MESSAGE>
229
-
233
+
230
234
  <EXTRACTED ENTITIES>
231
235
  {context['extracted_entities']}
232
236
  </EXTRACTED ENTITIES>
233
-
237
+
234
238
  <ENTITY TYPES>
235
239
  {context['entity_types']}
236
240
  </ENTITY TYPES>
237
-
241
+
238
242
  Given the above conversation, extracted entities, and provided entity types and their descriptions, classify the extracted entities.
239
-
243
+
240
244
  Guidelines:
241
245
  1. Each entity must have exactly one type
242
246
  2. Only use the provided ENTITY TYPES as types, do not use additional types to classify entities.
@@ -257,19 +261,18 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
257
261
  Message(
258
262
  role='user',
259
263
  content=f"""
260
-
261
- <MESSAGES>
262
- {to_prompt_json(context['previous_episodes'], indent=2)}
263
- {to_prompt_json(context['episode_content'], indent=2)}
264
- </MESSAGES>
265
-
266
- Given the above MESSAGES and the following ENTITY, update any of its attributes based on the information provided
264
+ Given the MESSAGES and the following ENTITY, update any of its attributes based on the information provided
267
265
  in MESSAGES. Use the provided attribute descriptions to better understand how each attribute should be determined.
268
266
 
269
267
  Guidelines:
270
268
  1. Do not hallucinate entity property values if they cannot be found in the current context.
271
269
  2. Only use the provided MESSAGES and ENTITY to set attribute values.
272
-
270
+
271
+ <MESSAGES>
272
+ {to_prompt_json(context['previous_episodes'])}
273
+ {to_prompt_json(context['episode_content'])}
274
+ </MESSAGES>
275
+
273
276
  <ENTITY>
274
277
  {context['node']}
275
278
  </ENTITY>
@@ -287,21 +290,16 @@ def extract_summary(context: dict[str, Any]) -> list[Message]:
287
290
  Message(
288
291
  role='user',
289
292
  content=f"""
293
+ Given the MESSAGES and the ENTITY, update the summary that combines relevant information about the entity
294
+ from the messages and relevant information from the existing summary.
295
+
296
+ {summary_instructions}
290
297
 
291
298
  <MESSAGES>
292
- {to_prompt_json(context['previous_episodes'], indent=2)}
293
- {to_prompt_json(context['episode_content'], indent=2)}
299
+ {to_prompt_json(context['previous_episodes'])}
300
+ {to_prompt_json(context['episode_content'])}
294
301
  </MESSAGES>
295
302
 
296
- Given the above MESSAGES and the following ENTITY, update the summary that combines relevant information about the entity
297
- from the messages and relevant information from the existing summary.
298
-
299
- Guidelines:
300
- 1. Do not hallucinate entity summary information if they cannot be found in the current context.
301
- 2. Only use the provided MESSAGES and ENTITY to set attribute values.
302
- 3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
303
- Summaries must be no longer than 250 words.
304
-
305
303
  <ENTITY>
306
304
  {context['node']}
307
305
  </ENTITY>
@@ -1,17 +1,33 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  import json
2
18
  from typing import Any
3
19
 
4
20
  DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'
5
21
 
6
22
 
7
- def to_prompt_json(data: Any, ensure_ascii: bool = False, indent: int = 2) -> str:
23
+ def to_prompt_json(data: Any, ensure_ascii: bool = False, indent: int | None = None) -> str:
8
24
  """
9
25
  Serialize data to JSON for use in prompts.
10
26
 
11
27
  Args:
12
28
  data: The data to serialize
13
29
  ensure_ascii: If True, escape non-ASCII characters. If False (default), preserve them.
14
- indent: Number of spaces for indentation
30
+ indent: Number of spaces for indentation. Defaults to None (minified).
15
31
 
16
32
  Returns:
17
33
  JSON string representation of the data