graphiti-core 0.21.0rc13__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/driver/driver.py +4 -211
- graphiti_core/driver/falkordb_driver.py +31 -3
- graphiti_core/driver/graph_operations/graph_operations.py +195 -0
- graphiti_core/driver/neo4j_driver.py +0 -49
- graphiti_core/driver/neptune_driver.py +43 -26
- graphiti_core/driver/search_interface/__init__.py +0 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +11 -34
- graphiti_core/graphiti.py +459 -326
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/llm_client/anthropic_client.py +64 -45
- graphiti_core/llm_client/client.py +67 -19
- graphiti_core/llm_client/gemini_client.py +73 -54
- graphiti_core/llm_client/openai_base_client.py +65 -43
- graphiti_core/llm_client/openai_generic_client.py +65 -43
- graphiti_core/models/edges/edge_db_queries.py +1 -0
- graphiti_core/models/nodes/node_db_queries.py +1 -0
- graphiti_core/nodes.py +26 -99
- graphiti_core/prompts/dedupe_edges.py +4 -4
- graphiti_core/prompts/dedupe_nodes.py +10 -10
- graphiti_core/prompts/extract_edges.py +4 -4
- graphiti_core/prompts/extract_nodes.py +26 -28
- graphiti_core/prompts/prompt_helpers.py +18 -2
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +22 -24
- graphiti_core/search/search_filters.py +0 -38
- graphiti_core/search/search_helpers.py +4 -4
- graphiti_core/search/search_utils.py +84 -220
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +16 -28
- graphiti_core/utils/maintenance/community_operations.py +4 -1
- graphiti_core/utils/maintenance/edge_operations.py +26 -15
- graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
- graphiti_core/utils/maintenance/node_operations.py +98 -51
- graphiti_core/utils/maintenance/temporal_operations.py +4 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/METADATA +7 -3
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/RECORD +41 -35
- /graphiti_core/{utils/maintenance/utils.py → driver/graph_operations/__init__.py} +0 -0
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/llm_client/openai_generic_client.py
CHANGED

@@ -120,13 +120,12 @@ class OpenAIGenericClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
+        prompt_name: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens

-        retry_count = 0
-        last_error = None
-
         if response_model is not None:
             serialized_model = json.dumps(response_model.model_json_schema())
             messages[
@@ -136,44 +135,67 @@ class OpenAIGenericClient(LLMClient):
         )

         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
-        … (old lines 140-158 not rendered in the source diff)
+        messages[0].content += get_extraction_language_instruction(group_id)
+
+        # Wrap entire operation in tracing span
+        with self.tracer.start_span('llm.generate') as span:
+            attributes = {
+                'llm.provider': 'openai',
+                'model.size': model_size.value,
+                'max_tokens': max_tokens,
+            }
+            if prompt_name:
+                attributes['prompt.name'] = prompt_name
+            span.add_attributes(attributes)
+
+            retry_count = 0
+            last_error = None
+
+            while retry_count <= self.MAX_RETRIES:
+                try:
+                    response = await self._generate_response(
+                        messages, response_model, max_tokens=max_tokens, model_size=model_size
+                    )
+                    return response
+                except (RateLimitError, RefusalError):
+                    # These errors should not trigger retries
+                    span.set_status('error', str(last_error))
                     raise
-        … (old lines 160-179 not rendered in the source diff)
+                except (
+                    openai.APITimeoutError,
+                    openai.APIConnectionError,
+                    openai.InternalServerError,
+                ):
+                    # Let OpenAI's client handle these retries
+                    span.set_status('error', str(last_error))
+                    raise
+                except Exception as e:
+                    last_error = e
+
+                    # Don't retry if we've hit the max retries
+                    if retry_count >= self.MAX_RETRIES:
+                        logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
+                        span.set_status('error', str(e))
+                        span.record_exception(e)
+                        raise
+
+                    retry_count += 1
+
+                    # Construct a detailed error message for the LLM
+                    error_context = (
+                        f'The previous response attempt was invalid. '
+                        f'Error type: {e.__class__.__name__}. '
+                        f'Error details: {str(e)}. '
+                        f'Please try again with a valid response, ensuring the output matches '
+                        f'the expected format and constraints.'
+                    )
+
+                    error_message = Message(role='user', content=error_context)
+                    messages.append(error_message)
+                    logger.warning(
+                        f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
+                    )
+
+            # If we somehow get here, raise the last error
+            span.set_status('error', str(last_error))
+            raise last_error or Exception('Max retries exceeded with no specific error')
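The span interface used above — start_span, add_attributes, set_status, record_exception — comes from the new graphiti_core/tracer.py module (+193 lines), whose definitions are not included in this rendered diff. A minimal no-op stand-in consistent with just those call sites might look like the following sketch; the class names here are assumptions, not the module's actual API:

# Hypothetical sketch of the span interface implied by the diff above.
# Only the method names (start_span, add_attributes, set_status,
# record_exception) are taken from the diff; everything else is assumed.
from contextlib import contextmanager
from typing import Any


class NoOpSpan:
    def add_attributes(self, attributes: dict[str, Any]) -> None:
        pass

    def set_status(self, status: str, description: str | None = None) -> None:
        pass

    def record_exception(self, exception: BaseException) -> None:
        pass


class NoOpTracer:
    @contextmanager
    def start_span(self, name: str):
        # Yields a span object for the duration of the traced operation.
        yield NoOpSpan()

With a no-op tracer installed, the retry loop above runs unchanged; a real implementation would presumably forward these calls to something like an OpenTelemetry span.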
graphiti_core/models/edges/edge_db_queries.py
CHANGED

@@ -68,6 +68,7 @@ def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False)
         MATCH (target:Entity {uuid: $edge_data.target_uuid})
         MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
         SET e = $edge_data
+        SET e.fact_embedding = vecf32($edge_data.fact_embedding)
         RETURN e.uuid AS uuid
         """
     case GraphProvider.NEPTUNE:
graphiti_core/models/nodes/node_db_queries.py
CHANGED

@@ -133,6 +133,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: b
         MERGE (n:Entity {{uuid: $entity_data.uuid}})
         SET n:{labels}
         SET n = $entity_data
+        SET n.name_embedding = vecf32($entity_data.name_embedding)
         RETURN n.uuid AS uuid
         """
     case GraphProvider.KUZU:
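Both query-builder hunks touch the FalkorDB branch: FalkorDB stores embeddings in a dedicated vector type, so after the blanket SET assignment the embedding property is re-set through an explicit vecf32() cast. A usage sketch under those assumptions (the driver and the parameter shape stand in for a connected FalkorDB GraphDriver; the query mirrors the hunk above):

# Hypothetical usage sketch; execute_query and the $entity_data parameter
# shape follow the patterns visible elsewhere in this diff.
async def save_entity_with_vector(driver) -> object:
    query = """
    MERGE (n:Entity {uuid: $entity_data.uuid})
    SET n = $entity_data
    SET n.name_embedding = vecf32($entity_data.name_embedding)
    RETURN n.uuid AS uuid
    """
    return await driver.execute_query(
        query,
        entity_data={
            'uuid': 'node-123',
            'name': 'Alice',
            # Cast to a 32-bit float vector by vecf32(), so FalkorDB can
            # use it for vector similarity search.
            'name_embedding': [0.12, -0.07, 0.33],
        },
    )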
graphiti_core/nodes.py
CHANGED

@@ -27,10 +27,6 @@ from pydantic import BaseModel, Field
 from typing_extensions import LiteralString

 from graphiti_core.driver.driver import (
-    COMMUNITY_INDEX_NAME,
-    ENTITY_EDGE_INDEX_NAME,
-    ENTITY_INDEX_NAME,
-    EPISODE_INDEX_NAME,
     GraphDriver,
     GraphProvider,
 )
@@ -99,6 +95,9 @@ class Node(BaseModel, ABC):
     async def save(self, driver: GraphDriver): ...

     async def delete(self, driver: GraphDriver):
+        if driver.graph_operations_interface:
+            return await driver.graph_operations_interface.node_delete(self, driver)
+
         match driver.provider:
             case GraphProvider.NEO4J:
                 records, _, _ = await driver.execute_query(
@@ -113,27 +112,6 @@ class Node(BaseModel, ABC):
                     uuid=self.uuid,
                 )

-                edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
-
-                if driver.aoss_client:
-                    # Delete the node from OpenSearch indices
-                    for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
-                        await driver.aoss_client.delete(
-                            index=index,
-                            id=self.uuid,
-                            params={'routing': self.group_id},
-                        )
-
-                    # Bulk delete the detached edges
-                    if edge_uuids:
-                        actions = []
-                        for eid in edge_uuids:
-                            actions.append(
-                                {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
-                            )
-
-                        await driver.aoss_client.bulk(body=actions)
-
             case GraphProvider.KUZU:
                 for label in ['Episodic', 'Community']:
                     await driver.execute_query(
@@ -181,14 +159,18 @@ class Node(BaseModel, ABC):

     @classmethod
     async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
+        if driver.graph_operations_interface:
+            return await driver.graph_operations_interface.node_delete_by_group_id(
+                cls, driver, group_id, batch_size
+            )
+
         match driver.provider:
             case GraphProvider.NEO4J:
                 async with driver.session() as session:
                     await session.run(
                         """
                         MATCH (n:Entity|Episodic|Community {group_id: $group_id})
-                        CALL {
-                            WITH n
+                        CALL (n) {
                             DETACH DELETE n
                         } IN TRANSACTIONS OF $batch_size ROWS
                         """,
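Beyond the new delegation hook, the Cypher here moves from the importing-WITH subquery form (CALL { WITH n ... }) to Neo4j's variable-scoped syntax (CALL (n) { ... }, available since Neo4j 5.23 and the recommended replacement for the deprecated form); the batched IN TRANSACTIONS delete behaves the same. The identical substitution appears again in delete_by_uuids below.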
@@ -196,31 +178,6 @@ class Node(BaseModel, ABC):
                         batch_size=batch_size,
                     )

-                if driver.aoss_client:
-                    await driver.aoss_client.delete_by_query(
-                        index=EPISODE_INDEX_NAME,
-                        body={'query': {'term': {'group_id': group_id}}},
-                        params={'routing': group_id},
-                    )
-
-                    await driver.aoss_client.delete_by_query(
-                        index=ENTITY_INDEX_NAME,
-                        body={'query': {'term': {'group_id': group_id}}},
-                        params={'routing': group_id},
-                    )
-
-                    await driver.aoss_client.delete_by_query(
-                        index=COMMUNITY_INDEX_NAME,
-                        body={'query': {'term': {'group_id': group_id}}},
-                        params={'routing': group_id},
-                    )
-
-                    await driver.aoss_client.delete_by_query(
-                        index=ENTITY_EDGE_INDEX_NAME,
-                        body={'query': {'term': {'group_id': group_id}}},
-                        params={'routing': group_id},
-                    )
-
             case GraphProvider.KUZU:
                 for label in ['Episodic', 'Community']:
                     await driver.execute_query(
@@ -258,6 +215,11 @@ class Node(BaseModel, ABC):

     @classmethod
     async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
+        if driver.graph_operations_interface:
+            return await driver.graph_operations_interface.node_delete_by_uuids(
+                cls, driver, uuids, group_id=None, batch_size=batch_size
+            )
+
         match driver.provider:
             case GraphProvider.FALKORDB:
                 for label in ['Entity', 'Episodic', 'Community']:
@@ -300,7 +262,7 @@ class Node(BaseModel, ABC):
             case _:  # Neo4J, Neptune
                 async with driver.session() as session:
                     # Collect all edge UUIDs before deleting nodes
-                    result = await session.run(
+                    await session.run(
                         """
                         MATCH (n:Entity|Episodic|Community)
                         WHERE n.uuid IN $uuids
@@ -310,18 +272,12 @@ class Node(BaseModel, ABC):
                         uuids=uuids,
                     )

-                    record = await result.single()
-                    edge_uuids: list[str] = (
-                        record['edge_uuids'] if record and record['edge_uuids'] else []
-                    )
-
                     # Now delete the nodes in batches
                     await session.run(
                         """
                         MATCH (n:Entity|Episodic|Community)
                         WHERE n.uuid IN $uuids
-                        CALL {
-                            WITH n
+                        CALL (n) {
                             DETACH DELETE n
                         } IN TRANSACTIONS OF $batch_size ROWS
                         """,
@@ -329,20 +285,6 @@ class Node(BaseModel, ABC):
                         batch_size=batch_size,
                     )

-                    if driver.aoss_client:
-                        for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
-                            await driver.aoss_client.delete_by_query(
-                                index=index,
-                                body={'query': {'terms': {'uuid': uuids}}},
-                            )
-
-                        if edge_uuids:
-                            actions = [
-                                {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
-                                for eid in edge_uuids
-                            ]
-                            await driver.aoss_client.bulk(body=actions)
-
     @classmethod
     async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@@ -363,6 +305,9 @@ class EpisodicNode(Node):
         )

     async def save(self, driver: GraphDriver):
+        if driver.graph_operations_interface:
+            return await driver.graph_operations_interface.episodic_node_save(self, driver)
+
         episode_args = {
             'uuid': self.uuid,
             'name': self.name,
@@ -375,12 +320,6 @@ class EpisodicNode(Node):
             'source': self.source.value,
         }

-        if driver.aoss_client:
-            await driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
-                'episodes',
-                [episode_args],
-            )
-
         result = await driver.execute_query(
             get_episode_node_save_query(driver.provider), **episode_args
         )
@@ -510,26 +449,14 @@ class EntityNode(Node):
         return self.name_embedding

     async def load_name_embedding(self, driver: GraphDriver):
+        if driver.graph_operations_interface:
+            return await driver.graph_operations_interface.node_load_embeddings(self, driver)
+
         if driver.provider == GraphProvider.NEPTUNE:
             query: LiteralString = """
                 MATCH (n:Entity {uuid: $uuid})
                 RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
             """
-        elif driver.aoss_client:
-            resp = await driver.aoss_client.search(
-                body={
-                    'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
-                    'size': 1,
-                },
-                index=ENTITY_INDEX_NAME,
-                params={'routing': self.group_id},
-            )
-
-            if resp['hits']['hits']:
-                self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
-                return
-            else:
-                raise NodeNotFoundError(self.uuid)

         else:
             query: LiteralString = """
@@ -548,6 +475,9 @@ class EntityNode(Node):
             self.name_embedding = records[0]['name_embedding']

     async def save(self, driver: GraphDriver):
+        if driver.graph_operations_interface:
+            return await driver.graph_operations_interface.node_save(self, driver)
+
         entity_data: dict[str, Any] = {
             'uuid': self.uuid,
             'name': self.name,
@@ -568,11 +498,8 @@ class EntityNode(Node):
         entity_data.update(self.attributes or {})
         labels = ':'.join(self.labels + ['Entity'])

-        if driver.aoss_client:
-            await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data])  # pyright: ignore reportAttributeAccessIssue
-
         result = await driver.execute_query(
-            get_entity_node_save_query(driver.provider, labels … (rest of old line not rendered in the source diff)
+            get_entity_node_save_query(driver.provider, labels),
             entity_data=entity_data,
         )
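All of the hooks added to nodes.py delegate to the new graphiti_core/driver/graph_operations/graph_operations.py module (+195 lines): when a driver exposes a graph_operations_interface, node save, delete, and embedding-load calls are routed through it before any provider-specific Cypher runs. The base class itself is not shown in this diff, so the following sketch of a custom implementation is an assumption built only from the method names and call shapes above:

# Hypothetical sketch; method names and argument shapes are taken from the
# call sites in nodes.py above, but the real interface in
# graph_operations.py may differ.
class LoggingGraphOperations:
    async def node_save(self, node, driver):
        print(f'saving node {node.uuid}')
        # ...issue the backend-specific save here...

    async def node_delete(self, node, driver):
        print(f'deleting node {node.uuid}')

    async def node_delete_by_group_id(self, cls, driver, group_id, batch_size=100):
        print(f'deleting group {group_id} in batches of {batch_size}')

    async def node_delete_by_uuids(self, cls, driver, uuids, group_id=None, batch_size=100):
        print(f'deleting {len(uuids)} nodes')

    async def node_load_embeddings(self, node, driver):
        print(f'loading embeddings for {node.uuid}')

    async def episodic_node_save(self, node, driver):
        print(f'saving episode {node.uuid}')

Because every hook returns the interface's result, a driver that sets graph_operations_interface takes over persistence entirely (this is how the OpenSearch/AOSS code removed above appears to have been relocated), while drivers that leave it unset fall through to the existing match driver.provider branches.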
graphiti_core/prompts/dedupe_edges.py
CHANGED

@@ -67,13 +67,13 @@ def edge(context: dict[str, Any]) -> list[Message]:
         Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.

         <EXISTING EDGES>
-        {to_prompt_json(context['related_edges'], indent=2)}
+        {to_prompt_json(context['related_edges'])}
         </EXISTING EDGES>

         <NEW EDGE>
-        {to_prompt_json(context['extracted_edges'], indent=2)}
+        {to_prompt_json(context['extracted_edges'])}
         </NEW EDGE>
-
+
         Task:
         If the New Edges represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact
         as part of the list of duplicate_facts.
@@ -98,7 +98,7 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
         Given the following context, find all of the duplicates in a list of facts:

         Facts:
-        {to_prompt_json(context['edges'], indent=2)}
+        {to_prompt_json(context['edges'])}

         Task:
         If any facts in Facts is a duplicate of another fact, return a new fact with one of their uuid's.
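The only substantive change in these prompt hunks is dropping the explicit indent=2 argument from to_prompt_json calls; the same edit repeats through dedupe_nodes.py, extract_edges.py, and extract_nodes.py below, and pairs with the new minified default in prompt_helpers.py at the end of this diff (see the example there).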
graphiti_core/prompts/dedupe_nodes.py
CHANGED

@@ -64,20 +64,20 @@ def node(context: dict[str, Any]) -> list[Message]:
             role='user',
             content=f"""
         <PREVIOUS MESSAGES>
-        {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
+        {to_prompt_json([ep for ep in context['previous_episodes']])}
         </PREVIOUS MESSAGES>
         <CURRENT MESSAGE>
         {context['episode_content']}
         </CURRENT MESSAGE>
         <NEW ENTITY>
-        {to_prompt_json(context['extracted_node'], indent=2)}
+        {to_prompt_json(context['extracted_node'])}
         </NEW ENTITY>
         <ENTITY TYPE DESCRIPTION>
-        {to_prompt_json(context['entity_type_description'], indent=2)}
+        {to_prompt_json(context['entity_type_description'])}
         </ENTITY TYPE DESCRIPTION>

         <EXISTING ENTITIES>
-        {to_prompt_json(context['existing_nodes'], indent=2)}
+        {to_prompt_json(context['existing_nodes'])}
         </EXISTING ENTITIES>

         Given the above EXISTING ENTITIES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW ENTITY extracted from the conversation
@@ -125,13 +125,13 @@ def nodes(context: dict[str, Any]) -> list[Message]:
             role='user',
             content=f"""
         <PREVIOUS MESSAGES>
-        {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
+        {to_prompt_json([ep for ep in context['previous_episodes']])}
         </PREVIOUS MESSAGES>
         <CURRENT MESSAGE>
         {context['episode_content']}
         </CURRENT MESSAGE>
-
-
+
+
         Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
         Each entity in ENTITIES is represented as a JSON object with the following structure:
         {{
@@ -142,11 +142,11 @@ def nodes(context: dict[str, Any]) -> list[Message]:
         }}

         <ENTITIES>
-        {to_prompt_json(context['extracted_nodes'], indent=2)}
+        {to_prompt_json(context['extracted_nodes'])}
         </ENTITIES>

         <EXISTING ENTITIES>
-        {to_prompt_json(context['existing_nodes'], indent=2)}
+        {to_prompt_json(context['existing_nodes'])}
         </EXISTING ENTITIES>

         Each entry in EXISTING ENTITIES is an object with the following structure:
@@ -197,7 +197,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
         Given the following context, deduplicate a list of nodes:

         Nodes:
-        {to_prompt_json(context['nodes'], indent=2)}
+        {to_prompt_json(context['nodes'])}

         Task:
         1. Group nodes together such that all duplicate nodes are in the same list of uuids
graphiti_core/prompts/extract_edges.py
CHANGED

@@ -80,7 +80,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
         </FACT TYPES>

         <PREVIOUS_MESSAGES>
-        {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
+        {to_prompt_json([ep for ep in context['previous_episodes']])}
         </PREVIOUS_MESSAGES>

         <CURRENT_MESSAGE>
@@ -88,7 +88,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
         </CURRENT_MESSAGE>

         <ENTITIES>
-        {to_prompt_json(context['nodes'], indent=2)}
+        {to_prompt_json(context['nodes'])}
         </ENTITIES>

         <REFERENCE_TIME>
@@ -141,7 +141,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:

     user_prompt = f"""
     <PREVIOUS MESSAGES>
-    {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
+    {to_prompt_json([ep for ep in context['previous_episodes']])}
     </PREVIOUS MESSAGES>
     <CURRENT MESSAGE>
     {context['episode_content']}
@@ -175,7 +175,7 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
             content=f"""

         <MESSAGE>
-        {to_prompt_json(context['episode_content'], indent=2)}
+        {to_prompt_json(context['episode_content'])}
         </MESSAGE>
         <REFERENCE TIME>
         {context['reference_time']}
graphiti_core/prompts/extract_nodes.py
CHANGED

@@ -18,8 +18,11 @@ from typing import Any, Protocol, TypedDict

 from pydantic import BaseModel, Field

+from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS
+
 from .models import Message, PromptFunction, PromptVersion
 from .prompt_helpers import to_prompt_json
+from .snippets import summary_instructions


 class ExtractedEntity(BaseModel):
@@ -42,7 +45,8 @@ class EntityClassificationTriple(BaseModel):
     uuid: str = Field(description='UUID of the entity')
     name: str = Field(description='Name of the entity')
     entity_type: str | None = Field(
-        default=None, … (rest of old line not rendered in the source diff)
+        default=None,
+        description='Type of the entity. Must be one of the provided types or None',
     )
@@ -55,7 +59,7 @@ class EntityClassification(BaseModel):
 class EntitySummary(BaseModel):
     summary: str = Field(
         ...,
-        description='Summary containing the important information about the entity. Under … (rest of old line not rendered in the source diff)
+        description=f'Summary containing the important information about the entity. Under {MAX_SUMMARY_CHARS} characters.',
     )
@@ -89,7 +93,7 @@ def extract_message(context: dict[str, Any]) -> list[Message]:
         </ENTITY TYPES>

         <PREVIOUS MESSAGES>
-        {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
+        {to_prompt_json([ep for ep in context['previous_episodes']])}
         </PREVIOUS MESSAGES>

         <CURRENT MESSAGE>
@@ -197,7 +201,7 @@ def reflexion(context: dict[str, Any]) -> list[Message]:

     user_prompt = f"""
     <PREVIOUS MESSAGES>
-    {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
+    {to_prompt_json([ep for ep in context['previous_episodes']])}
     </PREVIOUS MESSAGES>
     <CURRENT MESSAGE>
     {context['episode_content']}
@@ -221,22 +225,22 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:

     user_prompt = f"""
     <PREVIOUS MESSAGES>
-    {to_prompt_json([ep for ep in context['previous_episodes']], indent=2)}
+    {to_prompt_json([ep for ep in context['previous_episodes']])}
     </PREVIOUS MESSAGES>
     <CURRENT MESSAGE>
     {context['episode_content']}
     </CURRENT MESSAGE>
-
+
     <EXTRACTED ENTITIES>
     {context['extracted_entities']}
     </EXTRACTED ENTITIES>
-
+
     <ENTITY TYPES>
     {context['entity_types']}
     </ENTITY TYPES>
-
+
     Given the above conversation, extracted entities, and provided entity types and their descriptions, classify the extracted entities.
-
+
     Guidelines:
     1. Each entity must have exactly one type
     2. Only use the provided ENTITY TYPES as types, do not use additional types to classify entities.
@@ -257,19 +261,18 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
         Message(
             role='user',
             content=f"""
-
-        <MESSAGES>
-        {to_prompt_json(context['previous_episodes'], indent=2)}
-        {to_prompt_json(context['episode_content'], indent=2)}
-        </MESSAGES>
-
-        Given the above MESSAGES and the following ENTITY, update any of its attributes based on the information provided
+        Given the MESSAGES and the following ENTITY, update any of its attributes based on the information provided
         in MESSAGES. Use the provided attribute descriptions to better understand how each attribute should be determined.

         Guidelines:
         1. Do not hallucinate entity property values if they cannot be found in the current context.
         2. Only use the provided MESSAGES and ENTITY to set attribute values.
-
+
+        <MESSAGES>
+        {to_prompt_json(context['previous_episodes'])}
+        {to_prompt_json(context['episode_content'])}
+        </MESSAGES>
+
         <ENTITY>
         {context['node']}
         </ENTITY>
@@ -287,21 +290,16 @@ def extract_summary(context: dict[str, Any]) -> list[Message]:
         Message(
             role='user',
             content=f"""
+        Given the MESSAGES and the ENTITY, update the summary that combines relevant information about the entity
+        from the messages and relevant information from the existing summary.
+
+        {summary_instructions}

         <MESSAGES>
-        {to_prompt_json(context['previous_episodes'], indent=2)}
-        {to_prompt_json(context['episode_content'], indent=2)}
+        {to_prompt_json(context['previous_episodes'])}
+        {to_prompt_json(context['episode_content'])}
         </MESSAGES>

-        Given the above MESSAGES and the following ENTITY, update the summary that combines relevant information about the entity
-        from the messages and relevant information from the existing summary.
-
-        Guidelines:
-        1. Do not hallucinate entity summary information if they cannot be found in the current context.
-        2. Only use the provided MESSAGES and ENTITY to set attribute values.
-        3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
-        Summaries must be no longer than 250 words.
-
         <ENTITY>
         {context['node']}
         </ENTITY>
graphiti_core/prompts/prompt_helpers.py
CHANGED

@@ -1,17 +1,33 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 import json
 from typing import Any

 DO_NOT_ESCAPE_UNICODE = '\nDo not escape unicode characters.\n'


-def to_prompt_json(data: Any, ensure_ascii: bool = False, indent: int = 2) -> str:
+def to_prompt_json(data: Any, ensure_ascii: bool = False, indent: int | None = None) -> str:
     """
     Serialize data to JSON for use in prompts.

     Args:
         data: The data to serialize
         ensure_ascii: If True, escape non-ASCII characters. If False (default), preserve them.
-        indent: Number of spaces for indentation
+        indent: Number of spaces for indentation. Defaults to None (minified).

     Returns:
         JSON string representation of the data