graphiti-core 0.6.1__tar.gz → 0.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/PKG-INFO +1 -1
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/graphiti.py +5 -1
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/models/nodes/node_db_queries.py +4 -2
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/nodes.py +27 -10
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/extract_nodes.py +43 -1
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/summarize_nodes.py +12 -6
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/search/search_utils.py +14 -6
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/bulk_utils.py +17 -1
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/node_operations.py +58 -6
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/pyproject.toml +1 -1
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/LICENSE +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/README.md +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/cross_encoder/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/cross_encoder/client.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/cross_encoder/openai_reranker_client.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/edges.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/embedder/client.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/embedder/openai.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/embedder/voyage.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/helpers.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/openai_generic_client.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/models/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/models/edges/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/models/edges/edge_db_queries.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/models/nodes/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/eval.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/prompts/prompt_helpers.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/search/search.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/search/search_config.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/search/search_config_recipes.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/search/search_filters.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/datetime_utils.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/community_operations.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/graph_data_operations.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/utils.py +0 -0
|
@@ -262,6 +262,7 @@ class Graphiti:
|
|
|
262
262
|
group_id: str = '',
|
|
263
263
|
uuid: str | None = None,
|
|
264
264
|
update_communities: bool = False,
|
|
265
|
+
entity_types: dict[str, BaseModel] | None = None,
|
|
265
266
|
) -> AddEpisodeResults:
|
|
266
267
|
"""
|
|
267
268
|
Process an episode and update the graph.
|
|
@@ -336,7 +337,9 @@ class Graphiti:
|
|
|
336
337
|
|
|
337
338
|
# Extract entities as nodes
|
|
338
339
|
|
|
339
|
-
extracted_nodes = await extract_nodes(
|
|
340
|
+
extracted_nodes = await extract_nodes(
|
|
341
|
+
self.llm_client, episode, previous_episodes, entity_types
|
|
342
|
+
)
|
|
340
343
|
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
|
|
341
344
|
|
|
342
345
|
# Calculate Embeddings
|
|
@@ -362,6 +365,7 @@ class Graphiti:
|
|
|
362
365
|
existing_nodes_lists,
|
|
363
366
|
episode,
|
|
364
367
|
previous_episodes,
|
|
368
|
+
entity_types,
|
|
365
369
|
),
|
|
366
370
|
extract_edges(
|
|
367
371
|
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
|
@@ -31,14 +31,16 @@ EPISODIC_NODE_SAVE_BULK = """
|
|
|
31
31
|
|
|
32
32
|
ENTITY_NODE_SAVE = """
|
|
33
33
|
MERGE (n:Entity {uuid: $uuid})
|
|
34
|
-
SET n
|
|
34
|
+
SET n:$($labels)
|
|
35
|
+
SET n = $entity_data
|
|
35
36
|
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
|
|
36
37
|
RETURN n.uuid AS uuid"""
|
|
37
38
|
|
|
38
39
|
ENTITY_NODE_SAVE_BULK = """
|
|
39
40
|
UNWIND $nodes AS node
|
|
40
41
|
MERGE (n:Entity {uuid: node.uuid})
|
|
41
|
-
SET n
|
|
42
|
+
SET n:$(node.labels)
|
|
43
|
+
SET n = node
|
|
42
44
|
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
|
|
43
45
|
RETURN n.uuid AS uuid
|
|
44
46
|
"""
|
|
@@ -255,6 +255,9 @@ class EpisodicNode(Node):
|
|
|
255
255
|
class EntityNode(Node):
|
|
256
256
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
257
257
|
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
|
|
258
|
+
attributes: dict[str, Any] = Field(
|
|
259
|
+
default={}, description='Additional attributes of the node. Dependent on node labels'
|
|
260
|
+
)
|
|
258
261
|
|
|
259
262
|
async def generate_name_embedding(self, embedder: EmbedderClient):
|
|
260
263
|
start = time()
|
|
@@ -266,14 +269,21 @@ class EntityNode(Node):
|
|
|
266
269
|
return self.name_embedding
|
|
267
270
|
|
|
268
271
|
async def save(self, driver: AsyncDriver):
|
|
272
|
+
entity_data: dict[str, Any] = {
|
|
273
|
+
'uuid': self.uuid,
|
|
274
|
+
'name': self.name,
|
|
275
|
+
'name_embedding': self.name_embedding,
|
|
276
|
+
'group_id': self.group_id,
|
|
277
|
+
'summary': self.summary,
|
|
278
|
+
'created_at': self.created_at,
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
entity_data.update(self.attributes or {})
|
|
282
|
+
|
|
269
283
|
result = await driver.execute_query(
|
|
270
284
|
ENTITY_NODE_SAVE,
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
group_id=self.group_id,
|
|
274
|
-
summary=self.summary,
|
|
275
|
-
name_embedding=self.name_embedding,
|
|
276
|
-
created_at=self.created_at,
|
|
285
|
+
labels=self.labels + ['Entity'],
|
|
286
|
+
entity_data=entity_data,
|
|
277
287
|
database_=DEFAULT_DATABASE,
|
|
278
288
|
)
|
|
279
289
|
|
|
@@ -292,7 +302,9 @@ class EntityNode(Node):
|
|
|
292
302
|
n.name_embedding AS name_embedding,
|
|
293
303
|
n.group_id AS group_id,
|
|
294
304
|
n.created_at AS created_at,
|
|
295
|
-
n.summary AS summary
|
|
305
|
+
n.summary AS summary,
|
|
306
|
+
labels(n) AS labels,
|
|
307
|
+
properties(n) AS attributes
|
|
296
308
|
""",
|
|
297
309
|
uuid=uuid,
|
|
298
310
|
database_=DEFAULT_DATABASE,
|
|
@@ -317,7 +329,9 @@ class EntityNode(Node):
|
|
|
317
329
|
n.name_embedding AS name_embedding,
|
|
318
330
|
n.group_id AS group_id,
|
|
319
331
|
n.created_at AS created_at,
|
|
320
|
-
n.summary AS summary
|
|
332
|
+
n.summary AS summary,
|
|
333
|
+
labels(n) AS labels,
|
|
334
|
+
properties(n) AS attributes
|
|
321
335
|
""",
|
|
322
336
|
uuids=uuids,
|
|
323
337
|
database_=DEFAULT_DATABASE,
|
|
@@ -351,7 +365,9 @@ class EntityNode(Node):
|
|
|
351
365
|
n.name_embedding AS name_embedding,
|
|
352
366
|
n.group_id AS group_id,
|
|
353
367
|
n.created_at AS created_at,
|
|
354
|
-
n.summary AS summary
|
|
368
|
+
n.summary AS summary,
|
|
369
|
+
labels(n) AS labels,
|
|
370
|
+
properties(n) AS attributes
|
|
355
371
|
ORDER BY n.uuid DESC
|
|
356
372
|
"""
|
|
357
373
|
+ limit_query,
|
|
@@ -503,9 +519,10 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
|
503
519
|
name=record['name'],
|
|
504
520
|
group_id=record['group_id'],
|
|
505
521
|
name_embedding=record['name_embedding'],
|
|
506
|
-
labels=['
|
|
522
|
+
labels=record['labels'],
|
|
507
523
|
created_at=record['created_at'].to_native(),
|
|
508
524
|
summary=record['summary'],
|
|
525
|
+
attributes=record['attributes'],
|
|
509
526
|
)
|
|
510
527
|
|
|
511
528
|
|
|
@@ -30,11 +30,19 @@ class MissedEntities(BaseModel):
|
|
|
30
30
|
missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
|
|
31
31
|
|
|
32
32
|
|
|
33
|
+
class EntityClassification(BaseModel):
|
|
34
|
+
entity_classification: str = Field(
|
|
35
|
+
...,
|
|
36
|
+
description='Dictionary of entity classifications. Key is the entity name and value is the entity type',
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
33
40
|
class Prompt(Protocol):
|
|
34
41
|
extract_message: PromptVersion
|
|
35
42
|
extract_json: PromptVersion
|
|
36
43
|
extract_text: PromptVersion
|
|
37
44
|
reflexion: PromptVersion
|
|
45
|
+
classify_nodes: PromptVersion
|
|
38
46
|
|
|
39
47
|
|
|
40
48
|
class Versions(TypedDict):
|
|
@@ -42,6 +50,7 @@ class Versions(TypedDict):
|
|
|
42
50
|
extract_json: PromptFunction
|
|
43
51
|
extract_text: PromptFunction
|
|
44
52
|
reflexion: PromptFunction
|
|
53
|
+
classify_nodes: PromptFunction
|
|
45
54
|
|
|
46
55
|
|
|
47
56
|
def extract_message(context: dict[str, Any]) -> list[Message]:
|
|
@@ -66,6 +75,7 @@ Guidelines:
|
|
|
66
75
|
4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
|
|
67
76
|
5. Be as explicit as possible in your node names, using full names.
|
|
68
77
|
6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
|
|
78
|
+
7. Extract preferences as their own nodes
|
|
69
79
|
"""
|
|
70
80
|
return [
|
|
71
81
|
Message(role='system', content=sys_prompt),
|
|
@@ -109,7 +119,7 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
|
|
|
109
119
|
|
|
110
120
|
{context['custom_prompt']}
|
|
111
121
|
|
|
112
|
-
Given the
|
|
122
|
+
Given the above text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:
|
|
113
123
|
|
|
114
124
|
Guidelines:
|
|
115
125
|
1. Extract significant entities, concepts, or actors mentioned in the conversation.
|
|
@@ -147,9 +157,41 @@ extracted.
|
|
|
147
157
|
]
|
|
148
158
|
|
|
149
159
|
|
|
160
|
+
def classify_nodes(context: dict[str, Any]) -> list[Message]:
|
|
161
|
+
sys_prompt = """You are an AI assistant that classifies entity nodes given the context from which they were extracted"""
|
|
162
|
+
|
|
163
|
+
user_prompt = f"""
|
|
164
|
+
<PREVIOUS MESSAGES>
|
|
165
|
+
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
|
|
166
|
+
</PREVIOUS MESSAGES>
|
|
167
|
+
<CURRENT MESSAGE>
|
|
168
|
+
{context["episode_content"]}
|
|
169
|
+
</CURRENT MESSAGE>
|
|
170
|
+
|
|
171
|
+
<EXTRACTED ENTITIES>
|
|
172
|
+
{context['extracted_entities']}
|
|
173
|
+
</EXTRACTED ENTITIES>
|
|
174
|
+
|
|
175
|
+
<ENTITY TYPES>
|
|
176
|
+
{context['entity_types']}
|
|
177
|
+
</ENTITY TYPES>
|
|
178
|
+
|
|
179
|
+
Given the above conversation, extracted entities, and provided entity types, classify the extracted entities.
|
|
180
|
+
|
|
181
|
+
Guidelines:
|
|
182
|
+
1. Each entity must have exactly one type
|
|
183
|
+
2. If none of the provided entity types accurately classify an extracted node, the type should be set to None
|
|
184
|
+
"""
|
|
185
|
+
return [
|
|
186
|
+
Message(role='system', content=sys_prompt),
|
|
187
|
+
Message(role='user', content=user_prompt),
|
|
188
|
+
]
|
|
189
|
+
|
|
190
|
+
|
|
150
191
|
versions: Versions = {
|
|
151
192
|
'extract_message': extract_message,
|
|
152
193
|
'extract_json': extract_json,
|
|
153
194
|
'extract_text': extract_text,
|
|
154
195
|
'reflexion': reflexion,
|
|
196
|
+
'classify_nodes': classify_nodes,
|
|
155
197
|
}
|
|
@@ -24,7 +24,8 @@ from .models import Message, PromptFunction, PromptVersion
|
|
|
24
24
|
|
|
25
25
|
class Summary(BaseModel):
|
|
26
26
|
summary: str = Field(
|
|
27
|
-
...,
|
|
27
|
+
...,
|
|
28
|
+
description='Summary containing the important information about the entity. Under 500 words',
|
|
28
29
|
)
|
|
29
30
|
|
|
30
31
|
|
|
@@ -68,7 +69,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
|
|
|
68
69
|
return [
|
|
69
70
|
Message(
|
|
70
71
|
role='system',
|
|
71
|
-
content='You are a helpful assistant that
|
|
72
|
+
content='You are a helpful assistant that extracts entity properties from the provided text.',
|
|
72
73
|
),
|
|
73
74
|
Message(
|
|
74
75
|
role='user',
|
|
@@ -79,18 +80,23 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
|
|
|
79
80
|
{json.dumps(context['episode_content'], indent=2)}
|
|
80
81
|
</MESSAGES>
|
|
81
82
|
|
|
82
|
-
Given the above MESSAGES and the following ENTITY name
|
|
83
|
-
information from the provided MESSAGES
|
|
84
|
-
provided ENTITY.
|
|
83
|
+
Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
|
|
84
|
+
information from the provided MESSAGES. Your summary should also only contain information relevant to the
|
|
85
|
+
provided ENTITY. Summaries must be under 500 words.
|
|
85
86
|
|
|
86
|
-
|
|
87
|
+
In addition, extract any values for the provided entity properties based on their descriptions.
|
|
87
88
|
|
|
88
89
|
<ENTITY>
|
|
89
90
|
{context['node_name']}
|
|
90
91
|
</ENTITY>
|
|
92
|
+
|
|
91
93
|
<ENTITY CONTEXT>
|
|
92
94
|
{context['node_summary']}
|
|
93
95
|
</ENTITY CONTEXT>
|
|
96
|
+
|
|
97
|
+
<ATTRIBUTES>
|
|
98
|
+
{json.dumps(context['attributes'], indent=2)}
|
|
99
|
+
</ATTRIBUTES>
|
|
94
100
|
""",
|
|
95
101
|
),
|
|
96
102
|
]
|
|
@@ -97,7 +97,9 @@ async def get_mentioned_nodes(
|
|
|
97
97
|
n.name AS name,
|
|
98
98
|
n.name_embedding AS name_embedding,
|
|
99
99
|
n.created_at AS created_at,
|
|
100
|
-
n.summary AS summary
|
|
100
|
+
n.summary AS summary,
|
|
101
|
+
labels(n) AS labels,
|
|
102
|
+
properties(n) AS attributes
|
|
101
103
|
""",
|
|
102
104
|
uuids=episode_uuids,
|
|
103
105
|
database_=DEFAULT_DATABASE,
|
|
@@ -223,8 +225,8 @@ async def edge_similarity_search(
|
|
|
223
225
|
|
|
224
226
|
query: LiteralString = (
|
|
225
227
|
"""
|
|
226
|
-
|
|
227
|
-
|
|
228
|
+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
229
|
+
"""
|
|
228
230
|
+ group_filter_query
|
|
229
231
|
+ filter_query
|
|
230
232
|
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
|
@@ -341,7 +343,9 @@ async def node_fulltext_search(
|
|
|
341
343
|
n.name AS name,
|
|
342
344
|
n.name_embedding AS name_embedding,
|
|
343
345
|
n.created_at AS created_at,
|
|
344
|
-
n.summary AS summary
|
|
346
|
+
n.summary AS summary,
|
|
347
|
+
labels(n) AS labels,
|
|
348
|
+
properties(n) AS attributes
|
|
345
349
|
ORDER BY score DESC
|
|
346
350
|
LIMIT $limit
|
|
347
351
|
""",
|
|
@@ -390,7 +394,9 @@ async def node_similarity_search(
|
|
|
390
394
|
n.name AS name,
|
|
391
395
|
n.name_embedding AS name_embedding,
|
|
392
396
|
n.created_at AS created_at,
|
|
393
|
-
n.summary AS summary
|
|
397
|
+
n.summary AS summary,
|
|
398
|
+
labels(n) AS labels,
|
|
399
|
+
properties(n) AS attributes
|
|
394
400
|
ORDER BY score DESC
|
|
395
401
|
LIMIT $limit
|
|
396
402
|
""",
|
|
@@ -427,7 +433,9 @@ async def node_bfs_search(
|
|
|
427
433
|
n.name AS name,
|
|
428
434
|
n.name_embedding AS name_embedding,
|
|
429
435
|
n.created_at AS created_at,
|
|
430
|
-
n.summary AS summary
|
|
436
|
+
n.summary AS summary,
|
|
437
|
+
labels(n) AS labels,
|
|
438
|
+
properties(n) AS attributes
|
|
431
439
|
LIMIT $limit
|
|
432
440
|
""",
|
|
433
441
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
@@ -23,6 +23,7 @@ from math import ceil
|
|
|
23
23
|
from neo4j import AsyncDriver, AsyncManagedTransaction
|
|
24
24
|
from numpy import dot, sqrt
|
|
25
25
|
from pydantic import BaseModel
|
|
26
|
+
from typing_extensions import Any
|
|
26
27
|
|
|
27
28
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
|
28
29
|
from graphiti_core.helpers import semaphore_gather
|
|
@@ -109,8 +110,23 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
109
110
|
episodes = [dict(episode) for episode in episodic_nodes]
|
|
110
111
|
for episode in episodes:
|
|
111
112
|
episode['source'] = str(episode['source'].value)
|
|
113
|
+
nodes: list[dict[str, Any]] = []
|
|
114
|
+
for node in entity_nodes:
|
|
115
|
+
entity_data: dict[str, Any] = {
|
|
116
|
+
'uuid': node.uuid,
|
|
117
|
+
'name': node.name,
|
|
118
|
+
'name_embedding': node.name_embedding,
|
|
119
|
+
'group_id': node.group_id,
|
|
120
|
+
'summary': node.summary,
|
|
121
|
+
'created_at': node.created_at,
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
entity_data.update(node.attributes or {})
|
|
125
|
+
entity_data['labels'] = list(set(node.labels + ['Entity']))
|
|
126
|
+
nodes.append(entity_data)
|
|
127
|
+
|
|
112
128
|
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
|
113
|
-
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=
|
|
129
|
+
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
|
|
114
130
|
await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
|
|
115
131
|
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
|
|
116
132
|
|
{graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/node_operations.py
RENAMED
|
@@ -14,15 +14,19 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import ast
|
|
17
18
|
import logging
|
|
18
19
|
from time import time
|
|
19
20
|
|
|
21
|
+
import pydantic
|
|
22
|
+
from pydantic import BaseModel
|
|
23
|
+
|
|
20
24
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
|
21
25
|
from graphiti_core.llm_client import LLMClient
|
|
22
26
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
|
23
27
|
from graphiti_core.prompts import prompt_library
|
|
24
28
|
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
|
25
|
-
from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
|
|
29
|
+
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
|
|
26
30
|
from graphiti_core.prompts.summarize_nodes import Summary
|
|
27
31
|
from graphiti_core.utils.datetime_utils import utc_now
|
|
28
32
|
|
|
@@ -114,6 +118,7 @@ async def extract_nodes(
|
|
|
114
118
|
llm_client: LLMClient,
|
|
115
119
|
episode: EpisodicNode,
|
|
116
120
|
previous_episodes: list[EpisodicNode],
|
|
121
|
+
entity_types: dict[str, BaseModel] | None = None,
|
|
117
122
|
) -> list[EntityNode]:
|
|
118
123
|
start = time()
|
|
119
124
|
extracted_node_names: list[str] = []
|
|
@@ -144,15 +149,35 @@ async def extract_nodes(
|
|
|
144
149
|
for entity in missing_entities:
|
|
145
150
|
custom_prompt += f'\n{entity},'
|
|
146
151
|
|
|
152
|
+
node_classification_context = {
|
|
153
|
+
'episode_content': episode.content,
|
|
154
|
+
'previous_episodes': [ep.content for ep in previous_episodes],
|
|
155
|
+
'extracted_entities': extracted_node_names,
|
|
156
|
+
'entity_types': entity_types.keys() if entity_types is not None else [],
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
node_classifications: dict[str, str | None] = {}
|
|
160
|
+
|
|
161
|
+
if entity_types is not None:
|
|
162
|
+
llm_response = await llm_client.generate_response(
|
|
163
|
+
prompt_library.extract_nodes.classify_nodes(node_classification_context),
|
|
164
|
+
response_model=EntityClassification,
|
|
165
|
+
)
|
|
166
|
+
response_string = llm_response.get('entity_classification', '{}')
|
|
167
|
+
node_classifications.update(ast.literal_eval(response_string))
|
|
168
|
+
|
|
147
169
|
end = time()
|
|
148
170
|
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
|
|
149
171
|
# Convert the extracted data into EntityNode objects
|
|
150
172
|
new_nodes = []
|
|
151
173
|
for name in extracted_node_names:
|
|
174
|
+
entity_type = node_classifications.get(name)
|
|
175
|
+
labels = ['Entity'] if entity_type is None else ['Entity', entity_type]
|
|
176
|
+
|
|
152
177
|
new_node = EntityNode(
|
|
153
178
|
name=name,
|
|
154
179
|
group_id=episode.group_id,
|
|
155
|
-
labels=
|
|
180
|
+
labels=labels,
|
|
156
181
|
summary='',
|
|
157
182
|
created_at=utc_now(),
|
|
158
183
|
)
|
|
@@ -218,6 +243,7 @@ async def resolve_extracted_nodes(
|
|
|
218
243
|
existing_nodes_lists: list[list[EntityNode]],
|
|
219
244
|
episode: EpisodicNode | None = None,
|
|
220
245
|
previous_episodes: list[EpisodicNode] | None = None,
|
|
246
|
+
entity_types: dict[str, BaseModel] | None = None,
|
|
221
247
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
|
222
248
|
uuid_map: dict[str, str] = {}
|
|
223
249
|
resolved_nodes: list[EntityNode] = []
|
|
@@ -225,7 +251,12 @@ async def resolve_extracted_nodes(
|
|
|
225
251
|
await semaphore_gather(
|
|
226
252
|
*[
|
|
227
253
|
resolve_extracted_node(
|
|
228
|
-
llm_client,
|
|
254
|
+
llm_client,
|
|
255
|
+
extracted_node,
|
|
256
|
+
existing_nodes,
|
|
257
|
+
episode,
|
|
258
|
+
previous_episodes,
|
|
259
|
+
entity_types,
|
|
229
260
|
)
|
|
230
261
|
for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists)
|
|
231
262
|
]
|
|
@@ -245,6 +276,7 @@ async def resolve_extracted_node(
|
|
|
245
276
|
existing_nodes: list[EntityNode],
|
|
246
277
|
episode: EpisodicNode | None = None,
|
|
247
278
|
previous_episodes: list[EpisodicNode] | None = None,
|
|
279
|
+
entity_types: dict[str, BaseModel] | None = None,
|
|
248
280
|
) -> tuple[EntityNode, dict[str, str]]:
|
|
249
281
|
start = time()
|
|
250
282
|
|
|
@@ -273,19 +305,39 @@ async def resolve_extracted_node(
|
|
|
273
305
|
'previous_episodes': [ep.content for ep in previous_episodes]
|
|
274
306
|
if previous_episodes is not None
|
|
275
307
|
else [],
|
|
308
|
+
'attributes': [],
|
|
276
309
|
}
|
|
277
310
|
|
|
278
|
-
|
|
311
|
+
entity_type_classes: tuple[BaseModel, ...] = tuple()
|
|
312
|
+
if entity_types is not None: # type: ignore
|
|
313
|
+
entity_type_classes = entity_type_classes + tuple(
|
|
314
|
+
filter(
|
|
315
|
+
lambda x: x is not None, # type: ignore
|
|
316
|
+
[entity_types.get(entity_type) for entity_type in extracted_node.labels], # type: ignore
|
|
317
|
+
)
|
|
318
|
+
)
|
|
319
|
+
|
|
320
|
+
for entity_type in entity_type_classes:
|
|
321
|
+
for field_name in entity_type.model_fields:
|
|
322
|
+
summary_context.get('attributes', []).append(field_name) # type: ignore
|
|
323
|
+
|
|
324
|
+
entity_attributes_model = pydantic.create_model( # type: ignore
|
|
325
|
+
'EntityAttributes',
|
|
326
|
+
__base__=entity_type_classes + (Summary,), # type: ignore
|
|
327
|
+
)
|
|
328
|
+
|
|
329
|
+
llm_response, node_attributes_response = await semaphore_gather(
|
|
279
330
|
llm_client.generate_response(
|
|
280
331
|
prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
|
|
281
332
|
),
|
|
282
333
|
llm_client.generate_response(
|
|
283
334
|
prompt_library.summarize_nodes.summarize_context(summary_context),
|
|
284
|
-
response_model=
|
|
335
|
+
response_model=entity_attributes_model,
|
|
285
336
|
),
|
|
286
337
|
)
|
|
287
338
|
|
|
288
|
-
extracted_node.summary =
|
|
339
|
+
extracted_node.summary = node_attributes_response.get('summary', '')
|
|
340
|
+
extracted_node.attributes.update(node_attributes_response)
|
|
289
341
|
|
|
290
342
|
is_duplicate: bool = llm_response.get('is_duplicate', False)
|
|
291
343
|
uuid: str | None = llm_response.get('uuid', None)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/cross_encoder/bge_reranker_client.py
RENAMED
|
File without changes
|
|
File without changes
|
{graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/cross_encoder/openai_reranker_client.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/llm_client/openai_generic_client.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/community_operations.py
RENAMED
|
File without changes
|
{graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/edge_operations.py
RENAMED
|
File without changes
|
{graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/graph_data_operations.py
RENAMED
|
File without changes
|
{graphiti_core-0.6.1 → graphiti_core-0.7.0}/graphiti_core/utils/maintenance/temporal_operations.py
RENAMED
|
File without changes
|
|
File without changes
|