graphiti-core: graphiti_core-0.20.4-py3-none-any.whl → graphiti_core-0.21.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +28 -0
- graphiti_core/driver/falkordb_driver.py +112 -0
- graphiti_core/driver/kuzu_driver.py +1 -0
- graphiti_core/driver/neo4j_driver.py +10 -2
- graphiti_core/driver/neptune_driver.py +4 -6
- graphiti_core/edges.py +67 -7
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +35 -6
- graphiti_core/graphiti.py +27 -23
- graphiti_core/graphiti_types.py +0 -1
- graphiti_core/helpers.py +2 -2
- graphiti_core/llm_client/client.py +19 -4
- graphiti_core/llm_client/gemini_client.py +4 -2
- graphiti_core/llm_client/openai_base_client.py +3 -2
- graphiti_core/llm_client/openai_generic_client.py +3 -2
- graphiti_core/models/edges/edge_db_queries.py +36 -16
- graphiti_core/models/nodes/node_db_queries.py +30 -10
- graphiti_core/nodes.py +126 -25
- graphiti_core/prompts/dedupe_edges.py +40 -29
- graphiti_core/prompts/dedupe_nodes.py +51 -34
- graphiti_core/prompts/eval.py +3 -3
- graphiti_core/prompts/extract_edges.py +17 -9
- graphiti_core/prompts/extract_nodes.py +10 -9
- graphiti_core/prompts/prompt_helpers.py +3 -3
- graphiti_core/prompts/summarize_nodes.py +5 -5
- graphiti_core/search/search_filters.py +53 -0
- graphiti_core/search/search_helpers.py +5 -7
- graphiti_core/search/search_utils.py +227 -57
- graphiti_core/utils/bulk_utils.py +168 -69
- graphiti_core/utils/maintenance/community_operations.py +8 -20
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +187 -50
- graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
- graphiti_core/utils/maintenance/node_operations.py +244 -88
- graphiti_core/utils/maintenance/temporal_operations.py +0 -4
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/METADATA +7 -1
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/RECORD +39 -38
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -25,7 +25,7 @@ from openai.types.chat import ChatCompletionMessageParam
|
|
|
25
25
|
from pydantic import BaseModel
|
|
26
26
|
|
|
27
27
|
from ..prompts.models import Message
|
|
28
|
-
from .client import
|
|
28
|
+
from .client import LLMClient, get_extraction_language_instruction
|
|
29
29
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
|
30
30
|
from .errors import RateLimitError, RefusalError
|
|
31
31
|
|
|
@@ -175,6 +175,7 @@ class BaseOpenAIClient(LLMClient):
|
|
|
175
175
|
response_model: type[BaseModel] | None = None,
|
|
176
176
|
max_tokens: int | None = None,
|
|
177
177
|
model_size: ModelSize = ModelSize.medium,
|
|
178
|
+
group_id: str | None = None,
|
|
178
179
|
) -> dict[str, typing.Any]:
|
|
179
180
|
"""Generate a response with retry logic and error handling."""
|
|
180
181
|
if max_tokens is None:
|
|
@@ -184,7 +185,7 @@ class BaseOpenAIClient(LLMClient):
|
|
|
184
185
|
last_error = None
|
|
185
186
|
|
|
186
187
|
# Add multilingual extraction instructions
|
|
187
|
-
messages[0].content +=
|
|
188
|
+
messages[0].content += get_extraction_language_instruction(group_id)
|
|
188
189
|
|
|
189
190
|
while retry_count <= self.MAX_RETRIES:
|
|
190
191
|
try:
|
|
@@ -25,7 +25,7 @@ from openai.types.chat import ChatCompletionMessageParam
|
|
|
25
25
|
from pydantic import BaseModel
|
|
26
26
|
|
|
27
27
|
from ..prompts.models import Message
|
|
28
|
-
from .client import
|
|
28
|
+
from .client import LLMClient, get_extraction_language_instruction
|
|
29
29
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
|
|
30
30
|
from .errors import RateLimitError, RefusalError
|
|
31
31
|
|
|
@@ -120,6 +120,7 @@ class OpenAIGenericClient(LLMClient):
|
|
|
120
120
|
response_model: type[BaseModel] | None = None,
|
|
121
121
|
max_tokens: int | None = None,
|
|
122
122
|
model_size: ModelSize = ModelSize.medium,
|
|
123
|
+
group_id: str | None = None,
|
|
123
124
|
) -> dict[str, typing.Any]:
|
|
124
125
|
if max_tokens is None:
|
|
125
126
|
max_tokens = self.max_tokens
|
|
@@ -136,7 +137,7 @@ class OpenAIGenericClient(LLMClient):
|
|
|
136
137
|
)
|
|
137
138
|
|
|
138
139
|
# Add multilingual extraction instructions
|
|
139
|
-
messages[0].content +=
|
|
140
|
+
messages[0].content += get_extraction_language_instruction(group_id)
|
|
140
141
|
|
|
141
142
|
while retry_count <= self.MAX_RETRIES:
|
|
142
143
|
try:
|
|
@@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """
|
|
|
60
60
|
"""
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
|
63
|
+
def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
|
|
64
64
|
match provider:
|
|
65
65
|
case GraphProvider.FALKORDB:
|
|
66
66
|
return """
|
|
@@ -99,17 +99,28 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
|
|
|
99
99
|
RETURN e.uuid AS uuid
|
|
100
100
|
"""
|
|
101
101
|
case _: # Neo4j
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
102
|
+
save_embedding_query = (
|
|
103
|
+
"""WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
|
|
104
|
+
if not has_aoss
|
|
105
|
+
else ''
|
|
106
|
+
)
|
|
107
|
+
return (
|
|
108
|
+
(
|
|
109
|
+
"""
|
|
110
|
+
MATCH (source:Entity {uuid: $edge_data.source_uuid})
|
|
111
|
+
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
|
112
|
+
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
|
113
|
+
SET e = $edge_data
|
|
114
|
+
"""
|
|
115
|
+
+ save_embedding_query
|
|
116
|
+
)
|
|
117
|
+
+ """
|
|
108
118
|
RETURN e.uuid AS uuid
|
|
109
|
-
|
|
119
|
+
"""
|
|
120
|
+
)
|
|
110
121
|
|
|
111
122
|
|
|
112
|
-
def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
|
123
|
+
def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
|
|
113
124
|
match provider:
|
|
114
125
|
case GraphProvider.FALKORDB:
|
|
115
126
|
return """
|
|
@@ -152,15 +163,24 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
|
|
|
152
163
|
RETURN e.uuid AS uuid
|
|
153
164
|
"""
|
|
154
165
|
case _:
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
166
|
+
save_embedding_query = (
|
|
167
|
+
'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
|
|
168
|
+
if not has_aoss
|
|
169
|
+
else ''
|
|
170
|
+
)
|
|
171
|
+
return (
|
|
172
|
+
"""
|
|
173
|
+
UNWIND $entity_edges AS edge
|
|
174
|
+
MATCH (source:Entity {uuid: edge.source_node_uuid})
|
|
175
|
+
MATCH (target:Entity {uuid: edge.target_node_uuid})
|
|
176
|
+
MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
|
|
177
|
+
SET e = edge
|
|
178
|
+
"""
|
|
179
|
+
+ save_embedding_query
|
|
180
|
+
+ """
|
|
162
181
|
RETURN edge.uuid AS uuid
|
|
163
182
|
"""
|
|
183
|
+
)
|
|
164
184
|
|
|
165
185
|
|
|
166
186
|
def get_entity_edge_return_query(provider: GraphProvider) -> str:
|
|
@@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """
|
|
|
126
126
|
"""
|
|
127
127
|
|
|
128
128
|
|
|
129
|
-
def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
|
129
|
+
def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
|
|
130
130
|
match provider:
|
|
131
131
|
case GraphProvider.FALKORDB:
|
|
132
132
|
return f"""
|
|
@@ -161,16 +161,27 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
|
|
|
161
161
|
RETURN n.uuid AS uuid
|
|
162
162
|
"""
|
|
163
163
|
case _:
|
|
164
|
-
|
|
164
|
+
save_embedding_query = (
|
|
165
|
+
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
|
|
166
|
+
if not has_aoss
|
|
167
|
+
else ''
|
|
168
|
+
)
|
|
169
|
+
return (
|
|
170
|
+
f"""
|
|
165
171
|
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
|
166
172
|
SET n:{labels}
|
|
167
173
|
SET n = $entity_data
|
|
168
|
-
|
|
174
|
+
"""
|
|
175
|
+
+ save_embedding_query
|
|
176
|
+
+ """
|
|
169
177
|
RETURN n.uuid AS uuid
|
|
170
178
|
"""
|
|
179
|
+
)
|
|
171
180
|
|
|
172
181
|
|
|
173
|
-
def get_entity_node_save_bulk_query(
|
|
182
|
+
def get_entity_node_save_bulk_query(
|
|
183
|
+
provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
|
|
184
|
+
) -> str | Any:
|
|
174
185
|
match provider:
|
|
175
186
|
case GraphProvider.FALKORDB:
|
|
176
187
|
queries = []
|
|
@@ -222,14 +233,23 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
|
|
|
222
233
|
RETURN n.uuid AS uuid
|
|
223
234
|
"""
|
|
224
235
|
case _: # Neo4j
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
236
|
+
save_embedding_query = (
|
|
237
|
+
'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
|
|
238
|
+
if not has_aoss
|
|
239
|
+
else ''
|
|
240
|
+
)
|
|
241
|
+
return (
|
|
242
|
+
"""
|
|
243
|
+
UNWIND $nodes AS node
|
|
244
|
+
MERGE (n:Entity {uuid: node.uuid})
|
|
245
|
+
SET n:$(node.labels)
|
|
246
|
+
SET n = node
|
|
247
|
+
"""
|
|
248
|
+
+ save_embedding_query
|
|
249
|
+
+ """
|
|
231
250
|
RETURN n.uuid AS uuid
|
|
232
251
|
"""
|
|
252
|
+
)
|
|
233
253
|
|
|
234
254
|
|
|
235
255
|
def get_entity_node_return_query(provider: GraphProvider) -> str:
|
graphiti_core/nodes.py
CHANGED
|
@@ -26,7 +26,14 @@ from uuid import uuid4
|
|
|
26
26
|
from pydantic import BaseModel, Field
|
|
27
27
|
from typing_extensions import LiteralString
|
|
28
28
|
|
|
29
|
-
from graphiti_core.driver.driver import
|
|
29
|
+
from graphiti_core.driver.driver import (
|
|
30
|
+
COMMUNITY_INDEX_NAME,
|
|
31
|
+
ENTITY_EDGE_INDEX_NAME,
|
|
32
|
+
ENTITY_INDEX_NAME,
|
|
33
|
+
EPISODE_INDEX_NAME,
|
|
34
|
+
GraphDriver,
|
|
35
|
+
GraphProvider,
|
|
36
|
+
)
|
|
30
37
|
from graphiti_core.embedder import EmbedderClient
|
|
31
38
|
from graphiti_core.errors import NodeNotFoundError
|
|
32
39
|
from graphiti_core.helpers import parse_db_date
|
|
@@ -94,13 +101,39 @@ class Node(BaseModel, ABC):
|
|
|
94
101
|
async def delete(self, driver: GraphDriver):
|
|
95
102
|
match driver.provider:
|
|
96
103
|
case GraphProvider.NEO4J:
|
|
97
|
-
await driver.execute_query(
|
|
104
|
+
records, _, _ = await driver.execute_query(
|
|
98
105
|
"""
|
|
99
|
-
MATCH (n
|
|
106
|
+
MATCH (n {uuid: $uuid})
|
|
107
|
+
WHERE n:Entity OR n:Episodic OR n:Community
|
|
108
|
+
OPTIONAL MATCH (n)-[r]-()
|
|
109
|
+
WITH collect(r.uuid) AS edge_uuids, n
|
|
100
110
|
DETACH DELETE n
|
|
111
|
+
RETURN edge_uuids
|
|
101
112
|
""",
|
|
102
113
|
uuid=self.uuid,
|
|
103
114
|
)
|
|
115
|
+
|
|
116
|
+
edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
|
|
117
|
+
|
|
118
|
+
if driver.aoss_client:
|
|
119
|
+
# Delete the node from OpenSearch indices
|
|
120
|
+
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
|
|
121
|
+
await driver.aoss_client.delete(
|
|
122
|
+
index=index,
|
|
123
|
+
id=self.uuid,
|
|
124
|
+
params={'routing': self.group_id},
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
# Bulk delete the detached edges
|
|
128
|
+
if edge_uuids:
|
|
129
|
+
actions = []
|
|
130
|
+
for eid in edge_uuids:
|
|
131
|
+
actions.append(
|
|
132
|
+
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
await driver.aoss_client.bulk(body=actions)
|
|
136
|
+
|
|
104
137
|
case GraphProvider.KUZU:
|
|
105
138
|
for label in ['Episodic', 'Community']:
|
|
106
139
|
await driver.execute_query(
|
|
@@ -162,6 +195,32 @@ class Node(BaseModel, ABC):
|
|
|
162
195
|
group_id=group_id,
|
|
163
196
|
batch_size=batch_size,
|
|
164
197
|
)
|
|
198
|
+
|
|
199
|
+
if driver.aoss_client:
|
|
200
|
+
await driver.aoss_client.delete_by_query(
|
|
201
|
+
index=EPISODE_INDEX_NAME,
|
|
202
|
+
body={'query': {'term': {'group_id': group_id}}},
|
|
203
|
+
params={'routing': group_id},
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
await driver.aoss_client.delete_by_query(
|
|
207
|
+
index=ENTITY_INDEX_NAME,
|
|
208
|
+
body={'query': {'term': {'group_id': group_id}}},
|
|
209
|
+
params={'routing': group_id},
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
await driver.aoss_client.delete_by_query(
|
|
213
|
+
index=COMMUNITY_INDEX_NAME,
|
|
214
|
+
body={'query': {'term': {'group_id': group_id}}},
|
|
215
|
+
params={'routing': group_id},
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
await driver.aoss_client.delete_by_query(
|
|
219
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
220
|
+
body={'query': {'term': {'group_id': group_id}}},
|
|
221
|
+
params={'routing': group_id},
|
|
222
|
+
)
|
|
223
|
+
|
|
165
224
|
case GraphProvider.KUZU:
|
|
166
225
|
for label in ['Episodic', 'Community']:
|
|
167
226
|
await driver.execute_query(
|
|
@@ -240,6 +299,23 @@ class Node(BaseModel, ABC):
|
|
|
240
299
|
)
|
|
241
300
|
case _: # Neo4J, Neptune
|
|
242
301
|
async with driver.session() as session:
|
|
302
|
+
# Collect all edge UUIDs before deleting nodes
|
|
303
|
+
result = await session.run(
|
|
304
|
+
"""
|
|
305
|
+
MATCH (n:Entity|Episodic|Community)
|
|
306
|
+
WHERE n.uuid IN $uuids
|
|
307
|
+
MATCH (n)-[r]-()
|
|
308
|
+
RETURN collect(r.uuid) AS edge_uuids
|
|
309
|
+
""",
|
|
310
|
+
uuids=uuids,
|
|
311
|
+
)
|
|
312
|
+
|
|
313
|
+
record = await result.single()
|
|
314
|
+
edge_uuids: list[str] = (
|
|
315
|
+
record['edge_uuids'] if record and record['edge_uuids'] else []
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
# Now delete the nodes in batches
|
|
243
319
|
await session.run(
|
|
244
320
|
"""
|
|
245
321
|
MATCH (n:Entity|Episodic|Community)
|
|
@@ -253,6 +329,20 @@ class Node(BaseModel, ABC):
|
|
|
253
329
|
batch_size=batch_size,
|
|
254
330
|
)
|
|
255
331
|
|
|
332
|
+
if driver.aoss_client:
|
|
333
|
+
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
|
|
334
|
+
await driver.aoss_client.delete_by_query(
|
|
335
|
+
index=index,
|
|
336
|
+
body={'query': {'terms': {'uuid': uuids}}},
|
|
337
|
+
)
|
|
338
|
+
|
|
339
|
+
if edge_uuids:
|
|
340
|
+
actions = [
|
|
341
|
+
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
|
|
342
|
+
for eid in edge_uuids
|
|
343
|
+
]
|
|
344
|
+
await driver.aoss_client.bulk(body=actions)
|
|
345
|
+
|
|
256
346
|
@classmethod
|
|
257
347
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
258
348
|
|
|
@@ -273,20 +363,6 @@ class EpisodicNode(Node):
|
|
|
273
363
|
)
|
|
274
364
|
|
|
275
365
|
async def save(self, driver: GraphDriver):
|
|
276
|
-
if driver.provider == GraphProvider.NEPTUNE:
|
|
277
|
-
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
278
|
-
'episode_content',
|
|
279
|
-
[
|
|
280
|
-
{
|
|
281
|
-
'uuid': self.uuid,
|
|
282
|
-
'group_id': self.group_id,
|
|
283
|
-
'source': self.source.value,
|
|
284
|
-
'content': self.content,
|
|
285
|
-
'source_description': self.source_description,
|
|
286
|
-
}
|
|
287
|
-
],
|
|
288
|
-
)
|
|
289
|
-
|
|
290
366
|
episode_args = {
|
|
291
367
|
'uuid': self.uuid,
|
|
292
368
|
'name': self.name,
|
|
@@ -299,6 +375,12 @@ class EpisodicNode(Node):
|
|
|
299
375
|
'source': self.source.value,
|
|
300
376
|
}
|
|
301
377
|
|
|
378
|
+
if driver.aoss_client:
|
|
379
|
+
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
380
|
+
'episodes',
|
|
381
|
+
[episode_args],
|
|
382
|
+
)
|
|
383
|
+
|
|
302
384
|
result = await driver.execute_query(
|
|
303
385
|
get_episode_node_save_query(driver.provider), **episode_args
|
|
304
386
|
)
|
|
@@ -433,6 +515,22 @@ class EntityNode(Node):
|
|
|
433
515
|
MATCH (n:Entity {uuid: $uuid})
|
|
434
516
|
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
435
517
|
"""
|
|
518
|
+
elif driver.aoss_client:
|
|
519
|
+
resp = await driver.aoss_client.search(
|
|
520
|
+
body={
|
|
521
|
+
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
|
522
|
+
'size': 1,
|
|
523
|
+
},
|
|
524
|
+
index=ENTITY_INDEX_NAME,
|
|
525
|
+
params={'routing': self.group_id},
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
if resp['hits']['hits']:
|
|
529
|
+
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
|
|
530
|
+
return
|
|
531
|
+
else:
|
|
532
|
+
raise NodeNotFoundError(self.uuid)
|
|
533
|
+
|
|
436
534
|
else:
|
|
437
535
|
query: LiteralString = """
|
|
438
536
|
MATCH (n:Entity {uuid: $uuid})
|
|
@@ -470,11 +568,11 @@ class EntityNode(Node):
|
|
|
470
568
|
entity_data.update(self.attributes or {})
|
|
471
569
|
labels = ':'.join(self.labels + ['Entity'])
|
|
472
570
|
|
|
473
|
-
if driver.
|
|
474
|
-
driver.save_to_aoss(
|
|
571
|
+
if driver.aoss_client:
|
|
572
|
+
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
|
475
573
|
|
|
476
574
|
result = await driver.execute_query(
|
|
477
|
-
get_entity_node_save_query(driver.provider, labels),
|
|
575
|
+
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
|
|
478
576
|
entity_data=entity_data,
|
|
479
577
|
)
|
|
480
578
|
|
|
@@ -569,8 +667,8 @@ class CommunityNode(Node):
|
|
|
569
667
|
|
|
570
668
|
async def save(self, driver: GraphDriver):
|
|
571
669
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
572
|
-
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
573
|
-
'
|
|
670
|
+
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
671
|
+
'communities',
|
|
574
672
|
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
|
575
673
|
)
|
|
576
674
|
result = await driver.execute_query(
|
|
@@ -770,9 +868,12 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|
|
770
868
|
|
|
771
869
|
|
|
772
870
|
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
|
773
|
-
|
|
871
|
+
# filter out falsey values from nodes
|
|
872
|
+
filtered_nodes = [node for node in nodes if node.name]
|
|
873
|
+
|
|
874
|
+
if not filtered_nodes:
|
|
774
875
|
return
|
|
775
876
|
|
|
776
|
-
name_embeddings = await embedder.create_batch([node.name for node in
|
|
777
|
-
for node, name_embedding in zip(
|
|
877
|
+
name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
|
|
878
|
+
for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True):
|
|
778
879
|
node.name_embedding = name_embedding
|
|
@@ -25,11 +25,11 @@ from .prompt_helpers import to_prompt_json
|
|
|
25
25
|
class EdgeDuplicate(BaseModel):
|
|
26
26
|
duplicate_facts: list[int] = Field(
|
|
27
27
|
...,
|
|
28
|
-
description='List of
|
|
28
|
+
description='List of idx values of any duplicate facts. If no duplicate facts are found, default to empty list.',
|
|
29
29
|
)
|
|
30
30
|
contradicted_facts: list[int] = Field(
|
|
31
31
|
...,
|
|
32
|
-
description='List of
|
|
32
|
+
description='List of idx values of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
|
|
33
33
|
)
|
|
34
34
|
fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
|
|
35
35
|
|
|
@@ -67,11 +67,11 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
|
|
67
67
|
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
|
|
68
68
|
|
|
69
69
|
<EXISTING EDGES>
|
|
70
|
-
{to_prompt_json(context['related_edges'],
|
|
70
|
+
{to_prompt_json(context['related_edges'], indent=2)}
|
|
71
71
|
</EXISTING EDGES>
|
|
72
72
|
|
|
73
73
|
<NEW EDGE>
|
|
74
|
-
{to_prompt_json(context['extracted_edges'],
|
|
74
|
+
{to_prompt_json(context['extracted_edges'], indent=2)}
|
|
75
75
|
</NEW EDGE>
|
|
76
76
|
|
|
77
77
|
Task:
|
|
@@ -98,7 +98,7 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
|
|
|
98
98
|
Given the following context, find all of the duplicates in a list of facts:
|
|
99
99
|
|
|
100
100
|
Facts:
|
|
101
|
-
{to_prompt_json(context['edges'],
|
|
101
|
+
{to_prompt_json(context['edges'], indent=2)}
|
|
102
102
|
|
|
103
103
|
Task:
|
|
104
104
|
If any facts in Facts is a duplicate of another fact, return a new fact with one of their uuid's.
|
|
@@ -124,37 +124,48 @@ def resolve_edge(context: dict[str, Any]) -> list[Message]:
|
|
|
124
124
|
Message(
|
|
125
125
|
role='user',
|
|
126
126
|
content=f"""
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
127
|
+
Task:
|
|
128
|
+
You will receive TWO separate lists of facts. Each list uses 'idx' as its index field, starting from 0.
|
|
129
|
+
|
|
130
|
+
1. DUPLICATE DETECTION:
|
|
131
|
+
- If the NEW FACT represents identical factual information as any fact in EXISTING FACTS, return those idx values in duplicate_facts.
|
|
132
|
+
- Facts with similar information that contain key differences should NOT be marked as duplicates.
|
|
133
|
+
- Return idx values from EXISTING FACTS.
|
|
134
|
+
- If no duplicates, return an empty list for duplicate_facts.
|
|
135
|
+
|
|
136
|
+
2. FACT TYPE CLASSIFICATION:
|
|
137
|
+
- Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
|
|
138
|
+
- Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
|
|
139
|
+
|
|
140
|
+
3. CONTRADICTION DETECTION:
|
|
141
|
+
- Based on FACT INVALIDATION CANDIDATES and NEW FACT, determine which facts the new fact contradicts.
|
|
142
|
+
- Return idx values from FACT INVALIDATION CANDIDATES.
|
|
143
|
+
- If no contradictions, return an empty list for contradicted_facts.
|
|
144
|
+
|
|
145
|
+
IMPORTANT:
|
|
146
|
+
- duplicate_facts: Use ONLY 'idx' values from EXISTING FACTS
|
|
147
|
+
- contradicted_facts: Use ONLY 'idx' values from FACT INVALIDATION CANDIDATES
|
|
148
|
+
- These are two separate lists with independent idx ranges starting from 0
|
|
149
|
+
|
|
150
|
+
Guidelines:
|
|
151
|
+
1. Some facts may be very similar but will have key differences, particularly around numeric values in the facts.
|
|
152
|
+
Do not mark these facts as duplicates.
|
|
153
|
+
|
|
154
|
+
<FACT TYPES>
|
|
155
|
+
{context['edge_types']}
|
|
156
|
+
</FACT TYPES>
|
|
157
|
+
|
|
131
158
|
<EXISTING FACTS>
|
|
132
159
|
{context['existing_edges']}
|
|
133
160
|
</EXISTING FACTS>
|
|
161
|
+
|
|
134
162
|
<FACT INVALIDATION CANDIDATES>
|
|
135
163
|
{context['edge_invalidation_candidates']}
|
|
136
164
|
</FACT INVALIDATION CANDIDATES>
|
|
137
|
-
|
|
138
|
-
<FACT TYPES>
|
|
139
|
-
{context['edge_types']}
|
|
140
|
-
</FACT TYPES>
|
|
141
|
-
|
|
142
165
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return an empty list.
|
|
147
|
-
|
|
148
|
-
Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
|
|
149
|
-
Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
|
|
150
|
-
|
|
151
|
-
Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts.
|
|
152
|
-
Return a list containing all idx's of the facts that are contradicted by the NEW FACT.
|
|
153
|
-
If there are no contradicted facts, return an empty list.
|
|
154
|
-
|
|
155
|
-
Guidelines:
|
|
156
|
-
1. Some facts may be very similar but will have key differences, particularly around numeric values in the facts.
|
|
157
|
-
Do not mark these facts as duplicates.
|
|
166
|
+
<NEW FACT>
|
|
167
|
+
{context['new_edge']}
|
|
168
|
+
</NEW FACT>
|
|
158
169
|
""",
|
|
159
170
|
),
|
|
160
171
|
]
|