graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/nodes.py
CHANGED
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from datetime import datetime
|
|
@@ -22,33 +23,30 @@ from time import time
|
|
|
22
23
|
from typing import Any
|
|
23
24
|
from uuid import uuid4
|
|
24
25
|
|
|
25
|
-
from neo4j import AsyncDriver
|
|
26
26
|
from pydantic import BaseModel, Field
|
|
27
27
|
from typing_extensions import LiteralString
|
|
28
28
|
|
|
29
|
+
from graphiti_core.driver.driver import (
|
|
30
|
+
GraphDriver,
|
|
31
|
+
GraphProvider,
|
|
32
|
+
)
|
|
29
33
|
from graphiti_core.embedder import EmbedderClient
|
|
30
34
|
from graphiti_core.errors import NodeNotFoundError
|
|
31
|
-
from graphiti_core.helpers import
|
|
35
|
+
from graphiti_core.helpers import parse_db_date
|
|
32
36
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
37
|
+
COMMUNITY_NODE_RETURN,
|
|
38
|
+
COMMUNITY_NODE_RETURN_NEPTUNE,
|
|
39
|
+
EPISODIC_NODE_RETURN,
|
|
40
|
+
EPISODIC_NODE_RETURN_NEPTUNE,
|
|
41
|
+
get_community_node_save_query,
|
|
42
|
+
get_entity_node_return_query,
|
|
43
|
+
get_entity_node_save_query,
|
|
44
|
+
get_episode_node_save_query,
|
|
36
45
|
)
|
|
37
46
|
from graphiti_core.utils.datetime_utils import utc_now
|
|
38
47
|
|
|
39
48
|
logger = logging.getLogger(__name__)
|
|
40
49
|
|
|
41
|
-
ENTITY_NODE_RETURN: LiteralString = """
|
|
42
|
-
RETURN
|
|
43
|
-
n.uuid As uuid,
|
|
44
|
-
n.name AS name,
|
|
45
|
-
n.group_id AS group_id,
|
|
46
|
-
n.created_at AS created_at,
|
|
47
|
-
n.summary AS summary,
|
|
48
|
-
labels(n) AS labels,
|
|
49
|
-
properties(n) AS attributes
|
|
50
|
-
"""
|
|
51
|
-
|
|
52
50
|
|
|
53
51
|
class EpisodeType(Enum):
|
|
54
52
|
"""
|
|
@@ -94,22 +92,63 @@ class Node(BaseModel, ABC):
|
|
|
94
92
|
created_at: datetime = Field(default_factory=lambda: utc_now())
|
|
95
93
|
|
|
96
94
|
@abstractmethod
|
|
97
|
-
async def save(self, driver:
|
|
98
|
-
|
|
99
|
-
async def delete(self, driver:
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
95
|
+
async def save(self, driver: GraphDriver): ...
|
|
96
|
+
|
|
97
|
+
async def delete(self, driver: GraphDriver):
|
|
98
|
+
if driver.graph_operations_interface:
|
|
99
|
+
return await driver.graph_operations_interface.node_delete(self, driver)
|
|
100
|
+
|
|
101
|
+
match driver.provider:
|
|
102
|
+
case GraphProvider.NEO4J:
|
|
103
|
+
records, _, _ = await driver.execute_query(
|
|
104
|
+
"""
|
|
105
|
+
MATCH (n {uuid: $uuid})
|
|
106
|
+
WHERE n:Entity OR n:Episodic OR n:Community
|
|
107
|
+
OPTIONAL MATCH (n)-[r]-()
|
|
108
|
+
WITH collect(r.uuid) AS edge_uuids, n
|
|
109
|
+
DETACH DELETE n
|
|
110
|
+
RETURN edge_uuids
|
|
111
|
+
""",
|
|
112
|
+
uuid=self.uuid,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
case GraphProvider.KUZU:
|
|
116
|
+
for label in ['Episodic', 'Community']:
|
|
117
|
+
await driver.execute_query(
|
|
118
|
+
f"""
|
|
119
|
+
MATCH (n:{label} {{uuid: $uuid}})
|
|
120
|
+
DETACH DELETE n
|
|
121
|
+
""",
|
|
122
|
+
uuid=self.uuid,
|
|
123
|
+
)
|
|
124
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
125
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
126
|
+
await driver.execute_query(
|
|
127
|
+
"""
|
|
128
|
+
MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
|
|
129
|
+
DETACH DELETE e
|
|
130
|
+
""",
|
|
131
|
+
uuid=self.uuid,
|
|
132
|
+
)
|
|
133
|
+
await driver.execute_query(
|
|
134
|
+
"""
|
|
135
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
136
|
+
DETACH DELETE n
|
|
137
|
+
""",
|
|
138
|
+
uuid=self.uuid,
|
|
139
|
+
)
|
|
140
|
+
case _: # FalkorDB, Neptune
|
|
141
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
142
|
+
await driver.execute_query(
|
|
143
|
+
f"""
|
|
144
|
+
MATCH (n:{label} {{uuid: $uuid}})
|
|
145
|
+
DETACH DELETE n
|
|
146
|
+
""",
|
|
147
|
+
uuid=self.uuid,
|
|
148
|
+
)
|
|
108
149
|
|
|
109
150
|
logger.debug(f'Deleted Node: {self.uuid}')
|
|
110
151
|
|
|
111
|
-
return result
|
|
112
|
-
|
|
113
152
|
def __hash__(self):
|
|
114
153
|
return hash(self.uuid)
|
|
115
154
|
|
|
@@ -119,23 +158,138 @@ class Node(BaseModel, ABC):
|
|
|
119
158
|
return False
|
|
120
159
|
|
|
121
160
|
@classmethod
|
|
122
|
-
async def delete_by_group_id(cls, driver:
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
161
|
+
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
|
|
162
|
+
if driver.graph_operations_interface:
|
|
163
|
+
return await driver.graph_operations_interface.node_delete_by_group_id(
|
|
164
|
+
cls, driver, group_id, batch_size
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
match driver.provider:
|
|
168
|
+
case GraphProvider.NEO4J:
|
|
169
|
+
async with driver.session() as session:
|
|
170
|
+
await session.run(
|
|
171
|
+
"""
|
|
172
|
+
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
|
173
|
+
CALL (n) {
|
|
174
|
+
DETACH DELETE n
|
|
175
|
+
} IN TRANSACTIONS OF $batch_size ROWS
|
|
176
|
+
""",
|
|
177
|
+
group_id=group_id,
|
|
178
|
+
batch_size=batch_size,
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
case GraphProvider.KUZU:
|
|
182
|
+
for label in ['Episodic', 'Community']:
|
|
183
|
+
await driver.execute_query(
|
|
184
|
+
f"""
|
|
185
|
+
MATCH (n:{label} {{group_id: $group_id}})
|
|
186
|
+
DETACH DELETE n
|
|
187
|
+
""",
|
|
188
|
+
group_id=group_id,
|
|
189
|
+
)
|
|
190
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
191
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
192
|
+
await driver.execute_query(
|
|
193
|
+
"""
|
|
194
|
+
MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
|
|
195
|
+
DETACH DELETE e
|
|
196
|
+
""",
|
|
197
|
+
group_id=group_id,
|
|
198
|
+
)
|
|
199
|
+
await driver.execute_query(
|
|
200
|
+
"""
|
|
201
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
202
|
+
DETACH DELETE n
|
|
203
|
+
""",
|
|
204
|
+
group_id=group_id,
|
|
205
|
+
)
|
|
206
|
+
case _: # FalkorDB, Neptune
|
|
207
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
208
|
+
await driver.execute_query(
|
|
209
|
+
f"""
|
|
210
|
+
MATCH (n:{label} {{group_id: $group_id}})
|
|
211
|
+
DETACH DELETE n
|
|
212
|
+
""",
|
|
213
|
+
group_id=group_id,
|
|
214
|
+
)
|
|
131
215
|
|
|
132
|
-
|
|
216
|
+
@classmethod
|
|
217
|
+
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
|
|
218
|
+
if driver.graph_operations_interface:
|
|
219
|
+
return await driver.graph_operations_interface.node_delete_by_uuids(
|
|
220
|
+
cls, driver, uuids, group_id=None, batch_size=batch_size
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
match driver.provider:
|
|
224
|
+
case GraphProvider.FALKORDB:
|
|
225
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
226
|
+
await driver.execute_query(
|
|
227
|
+
f"""
|
|
228
|
+
MATCH (n:{label})
|
|
229
|
+
WHERE n.uuid IN $uuids
|
|
230
|
+
DETACH DELETE n
|
|
231
|
+
""",
|
|
232
|
+
uuids=uuids,
|
|
233
|
+
)
|
|
234
|
+
case GraphProvider.KUZU:
|
|
235
|
+
for label in ['Episodic', 'Community']:
|
|
236
|
+
await driver.execute_query(
|
|
237
|
+
f"""
|
|
238
|
+
MATCH (n:{label})
|
|
239
|
+
WHERE n.uuid IN $uuids
|
|
240
|
+
DETACH DELETE n
|
|
241
|
+
""",
|
|
242
|
+
uuids=uuids,
|
|
243
|
+
)
|
|
244
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
245
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
246
|
+
await driver.execute_query(
|
|
247
|
+
"""
|
|
248
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
|
|
249
|
+
WHERE n.uuid IN $uuids
|
|
250
|
+
DETACH DELETE e
|
|
251
|
+
""",
|
|
252
|
+
uuids=uuids,
|
|
253
|
+
)
|
|
254
|
+
await driver.execute_query(
|
|
255
|
+
"""
|
|
256
|
+
MATCH (n:Entity)
|
|
257
|
+
WHERE n.uuid IN $uuids
|
|
258
|
+
DETACH DELETE n
|
|
259
|
+
""",
|
|
260
|
+
uuids=uuids,
|
|
261
|
+
)
|
|
262
|
+
case _: # Neo4J, Neptune
|
|
263
|
+
async with driver.session() as session:
|
|
264
|
+
# Collect all edge UUIDs before deleting nodes
|
|
265
|
+
await session.run(
|
|
266
|
+
"""
|
|
267
|
+
MATCH (n:Entity|Episodic|Community)
|
|
268
|
+
WHERE n.uuid IN $uuids
|
|
269
|
+
MATCH (n)-[r]-()
|
|
270
|
+
RETURN collect(r.uuid) AS edge_uuids
|
|
271
|
+
""",
|
|
272
|
+
uuids=uuids,
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
# Now delete the nodes in batches
|
|
276
|
+
await session.run(
|
|
277
|
+
"""
|
|
278
|
+
MATCH (n:Entity|Episodic|Community)
|
|
279
|
+
WHERE n.uuid IN $uuids
|
|
280
|
+
CALL (n) {
|
|
281
|
+
DETACH DELETE n
|
|
282
|
+
} IN TRANSACTIONS OF $batch_size ROWS
|
|
283
|
+
""",
|
|
284
|
+
uuids=uuids,
|
|
285
|
+
batch_size=batch_size,
|
|
286
|
+
)
|
|
133
287
|
|
|
134
288
|
@classmethod
|
|
135
|
-
async def get_by_uuid(cls, driver:
|
|
289
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
136
290
|
|
|
137
291
|
@classmethod
|
|
138
|
-
async def get_by_uuids(cls, driver:
|
|
292
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
|
|
139
293
|
|
|
140
294
|
|
|
141
295
|
class EpisodicNode(Node):
|
|
@@ -150,42 +304,43 @@ class EpisodicNode(Node):
|
|
|
150
304
|
default_factory=list,
|
|
151
305
|
)
|
|
152
306
|
|
|
153
|
-
async def save(self, driver:
|
|
307
|
+
async def save(self, driver: GraphDriver):
|
|
308
|
+
if driver.graph_operations_interface:
|
|
309
|
+
return await driver.graph_operations_interface.episodic_node_save(self, driver)
|
|
310
|
+
|
|
311
|
+
episode_args = {
|
|
312
|
+
'uuid': self.uuid,
|
|
313
|
+
'name': self.name,
|
|
314
|
+
'group_id': self.group_id,
|
|
315
|
+
'source_description': self.source_description,
|
|
316
|
+
'content': self.content,
|
|
317
|
+
'entity_edges': self.entity_edges,
|
|
318
|
+
'created_at': self.created_at,
|
|
319
|
+
'valid_at': self.valid_at,
|
|
320
|
+
'source': self.source.value,
|
|
321
|
+
}
|
|
322
|
+
|
|
154
323
|
result = await driver.execute_query(
|
|
155
|
-
|
|
156
|
-
uuid=self.uuid,
|
|
157
|
-
name=self.name,
|
|
158
|
-
group_id=self.group_id,
|
|
159
|
-
source_description=self.source_description,
|
|
160
|
-
content=self.content,
|
|
161
|
-
entity_edges=self.entity_edges,
|
|
162
|
-
created_at=self.created_at,
|
|
163
|
-
valid_at=self.valid_at,
|
|
164
|
-
source=self.source.value,
|
|
165
|
-
database_=DEFAULT_DATABASE,
|
|
324
|
+
get_episode_node_save_query(driver.provider), **episode_args
|
|
166
325
|
)
|
|
167
326
|
|
|
168
|
-
logger.debug(f'Saved Node to
|
|
327
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
169
328
|
|
|
170
329
|
return result
|
|
171
330
|
|
|
172
331
|
@classmethod
|
|
173
|
-
async def get_by_uuid(cls, driver:
|
|
332
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
174
333
|
records, _, _ = await driver.execute_query(
|
|
175
334
|
"""
|
|
176
|
-
|
|
177
|
-
RETURN
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
e.source AS source,
|
|
185
|
-
e.entity_edges AS entity_edges
|
|
186
|
-
""",
|
|
335
|
+
MATCH (e:Episodic {uuid: $uuid})
|
|
336
|
+
RETURN
|
|
337
|
+
"""
|
|
338
|
+
+ (
|
|
339
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
340
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
341
|
+
else EPISODIC_NODE_RETURN
|
|
342
|
+
),
|
|
187
343
|
uuid=uuid,
|
|
188
|
-
database_=DEFAULT_DATABASE,
|
|
189
344
|
routing_='r',
|
|
190
345
|
)
|
|
191
346
|
|
|
@@ -197,23 +352,19 @@ class EpisodicNode(Node):
|
|
|
197
352
|
return episodes[0]
|
|
198
353
|
|
|
199
354
|
@classmethod
|
|
200
|
-
async def get_by_uuids(cls, driver:
|
|
355
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
201
356
|
records, _, _ = await driver.execute_query(
|
|
202
357
|
"""
|
|
203
|
-
|
|
358
|
+
MATCH (e:Episodic)
|
|
359
|
+
WHERE e.uuid IN $uuids
|
|
204
360
|
RETURN DISTINCT
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
e.source_description AS source_description,
|
|
212
|
-
e.source AS source,
|
|
213
|
-
e.entity_edges AS entity_edges
|
|
214
|
-
""",
|
|
361
|
+
"""
|
|
362
|
+
+ (
|
|
363
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
364
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
365
|
+
else EPISODIC_NODE_RETURN
|
|
366
|
+
),
|
|
215
367
|
uuids=uuids,
|
|
216
|
-
database_=DEFAULT_DATABASE,
|
|
217
368
|
routing_='r',
|
|
218
369
|
)
|
|
219
370
|
|
|
@@ -224,7 +375,7 @@ class EpisodicNode(Node):
|
|
|
224
375
|
@classmethod
|
|
225
376
|
async def get_by_group_ids(
|
|
226
377
|
cls,
|
|
227
|
-
driver:
|
|
378
|
+
driver: GraphDriver,
|
|
228
379
|
group_ids: list[str],
|
|
229
380
|
limit: int | None = None,
|
|
230
381
|
uuid_cursor: str | None = None,
|
|
@@ -234,27 +385,25 @@ class EpisodicNode(Node):
|
|
|
234
385
|
|
|
235
386
|
records, _, _ = await driver.execute_query(
|
|
236
387
|
"""
|
|
237
|
-
|
|
238
|
-
|
|
388
|
+
MATCH (e:Episodic)
|
|
389
|
+
WHERE e.group_id IN $group_ids
|
|
390
|
+
"""
|
|
239
391
|
+ cursor_query
|
|
240
392
|
+ """
|
|
241
393
|
RETURN DISTINCT
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
ORDER BY e.uuid DESC
|
|
252
|
-
"""
|
|
394
|
+
"""
|
|
395
|
+
+ (
|
|
396
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
397
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
398
|
+
else EPISODIC_NODE_RETURN
|
|
399
|
+
)
|
|
400
|
+
+ """
|
|
401
|
+
ORDER BY uuid DESC
|
|
402
|
+
"""
|
|
253
403
|
+ limit_query,
|
|
254
404
|
group_ids=group_ids,
|
|
255
405
|
uuid=uuid_cursor,
|
|
256
406
|
limit=limit,
|
|
257
|
-
database_=DEFAULT_DATABASE,
|
|
258
407
|
routing_='r',
|
|
259
408
|
)
|
|
260
409
|
|
|
@@ -263,23 +412,18 @@ class EpisodicNode(Node):
|
|
|
263
412
|
return episodes
|
|
264
413
|
|
|
265
414
|
@classmethod
|
|
266
|
-
async def get_by_entity_node_uuid(cls, driver:
|
|
415
|
+
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
|
267
416
|
records, _, _ = await driver.execute_query(
|
|
268
417
|
"""
|
|
269
|
-
|
|
418
|
+
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
|
270
419
|
RETURN DISTINCT
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
e.source_description AS source_description,
|
|
278
|
-
e.source AS source,
|
|
279
|
-
e.entity_edges AS entity_edges
|
|
280
|
-
""",
|
|
420
|
+
"""
|
|
421
|
+
+ (
|
|
422
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
423
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
424
|
+
else EPISODIC_NODE_RETURN
|
|
425
|
+
),
|
|
281
426
|
entity_node_uuid=entity_node_uuid,
|
|
282
|
-
database_=DEFAULT_DATABASE,
|
|
283
427
|
routing_='r',
|
|
284
428
|
)
|
|
285
429
|
|
|
@@ -304,13 +448,25 @@ class EntityNode(Node):
|
|
|
304
448
|
|
|
305
449
|
return self.name_embedding
|
|
306
450
|
|
|
307
|
-
async def load_name_embedding(self, driver:
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
451
|
+
async def load_name_embedding(self, driver: GraphDriver):
|
|
452
|
+
if driver.graph_operations_interface:
|
|
453
|
+
return await driver.graph_operations_interface.node_load_embeddings(self, driver)
|
|
454
|
+
|
|
455
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
456
|
+
query: LiteralString = """
|
|
457
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
458
|
+
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
459
|
+
"""
|
|
460
|
+
|
|
461
|
+
else:
|
|
462
|
+
query: LiteralString = """
|
|
463
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
464
|
+
RETURN n.name_embedding AS name_embedding
|
|
465
|
+
"""
|
|
312
466
|
records, _, _ = await driver.execute_query(
|
|
313
|
-
query,
|
|
467
|
+
query,
|
|
468
|
+
uuid=self.uuid,
|
|
469
|
+
routing_='r',
|
|
314
470
|
)
|
|
315
471
|
|
|
316
472
|
if len(records) == 0:
|
|
@@ -318,7 +474,10 @@ class EntityNode(Node):
|
|
|
318
474
|
|
|
319
475
|
self.name_embedding = records[0]['name_embedding']
|
|
320
476
|
|
|
321
|
-
async def save(self, driver:
|
|
477
|
+
async def save(self, driver: GraphDriver):
|
|
478
|
+
if driver.graph_operations_interface:
|
|
479
|
+
return await driver.graph_operations_interface.node_save(self, driver)
|
|
480
|
+
|
|
322
481
|
entity_data: dict[str, Any] = {
|
|
323
482
|
'uuid': self.uuid,
|
|
324
483
|
'name': self.name,
|
|
@@ -328,35 +487,39 @@ class EntityNode(Node):
|
|
|
328
487
|
'created_at': self.created_at,
|
|
329
488
|
}
|
|
330
489
|
|
|
331
|
-
|
|
490
|
+
if driver.provider == GraphProvider.KUZU:
|
|
491
|
+
entity_data['attributes'] = json.dumps(self.attributes)
|
|
492
|
+
entity_data['labels'] = list(set(self.labels + ['Entity']))
|
|
493
|
+
result = await driver.execute_query(
|
|
494
|
+
get_entity_node_save_query(driver.provider, labels=''),
|
|
495
|
+
**entity_data,
|
|
496
|
+
)
|
|
497
|
+
else:
|
|
498
|
+
entity_data.update(self.attributes or {})
|
|
499
|
+
labels = ':'.join(self.labels + ['Entity'])
|
|
332
500
|
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
database_=DEFAULT_DATABASE,
|
|
338
|
-
)
|
|
501
|
+
result = await driver.execute_query(
|
|
502
|
+
get_entity_node_save_query(driver.provider, labels),
|
|
503
|
+
entity_data=entity_data,
|
|
504
|
+
)
|
|
339
505
|
|
|
340
|
-
logger.debug(f'Saved Node to
|
|
506
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
341
507
|
|
|
342
508
|
return result
|
|
343
509
|
|
|
344
510
|
@classmethod
|
|
345
|
-
async def get_by_uuid(cls, driver:
|
|
346
|
-
query = (
|
|
347
|
-
"""
|
|
348
|
-
MATCH (n:Entity {uuid: $uuid})
|
|
349
|
-
"""
|
|
350
|
-
+ ENTITY_NODE_RETURN
|
|
351
|
-
)
|
|
511
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
352
512
|
records, _, _ = await driver.execute_query(
|
|
353
|
-
|
|
513
|
+
"""
|
|
514
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
515
|
+
RETURN
|
|
516
|
+
"""
|
|
517
|
+
+ get_entity_node_return_query(driver.provider),
|
|
354
518
|
uuid=uuid,
|
|
355
|
-
database_=DEFAULT_DATABASE,
|
|
356
519
|
routing_='r',
|
|
357
520
|
)
|
|
358
521
|
|
|
359
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
522
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
360
523
|
|
|
361
524
|
if len(nodes) == 0:
|
|
362
525
|
raise NodeNotFoundError(uuid)
|
|
@@ -364,50 +527,63 @@ class EntityNode(Node):
|
|
|
364
527
|
return nodes[0]
|
|
365
528
|
|
|
366
529
|
@classmethod
|
|
367
|
-
async def get_by_uuids(cls, driver:
|
|
530
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
368
531
|
records, _, _ = await driver.execute_query(
|
|
369
532
|
"""
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
533
|
+
MATCH (n:Entity)
|
|
534
|
+
WHERE n.uuid IN $uuids
|
|
535
|
+
RETURN
|
|
536
|
+
"""
|
|
537
|
+
+ get_entity_node_return_query(driver.provider),
|
|
373
538
|
uuids=uuids,
|
|
374
|
-
database_=DEFAULT_DATABASE,
|
|
375
539
|
routing_='r',
|
|
376
540
|
)
|
|
377
541
|
|
|
378
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
542
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
379
543
|
|
|
380
544
|
return nodes
|
|
381
545
|
|
|
382
546
|
@classmethod
|
|
383
547
|
async def get_by_group_ids(
|
|
384
548
|
cls,
|
|
385
|
-
driver:
|
|
549
|
+
driver: GraphDriver,
|
|
386
550
|
group_ids: list[str],
|
|
387
551
|
limit: int | None = None,
|
|
388
552
|
uuid_cursor: str | None = None,
|
|
553
|
+
with_embeddings: bool = False,
|
|
389
554
|
):
|
|
390
555
|
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
|
|
391
556
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
557
|
+
with_embeddings_query: LiteralString = (
|
|
558
|
+
""",
|
|
559
|
+
n.name_embedding AS name_embedding
|
|
560
|
+
"""
|
|
561
|
+
if with_embeddings
|
|
562
|
+
else ''
|
|
563
|
+
)
|
|
392
564
|
|
|
393
565
|
records, _, _ = await driver.execute_query(
|
|
394
566
|
"""
|
|
395
|
-
|
|
396
|
-
|
|
567
|
+
MATCH (n:Entity)
|
|
568
|
+
WHERE n.group_id IN $group_ids
|
|
569
|
+
"""
|
|
397
570
|
+ cursor_query
|
|
398
|
-
+ ENTITY_NODE_RETURN
|
|
399
571
|
+ """
|
|
400
|
-
|
|
401
|
-
|
|
572
|
+
RETURN
|
|
573
|
+
"""
|
|
574
|
+
+ get_entity_node_return_query(driver.provider)
|
|
575
|
+
+ with_embeddings_query
|
|
576
|
+
+ """
|
|
577
|
+
ORDER BY n.uuid DESC
|
|
578
|
+
"""
|
|
402
579
|
+ limit_query,
|
|
403
580
|
group_ids=group_ids,
|
|
404
581
|
uuid=uuid_cursor,
|
|
405
582
|
limit=limit,
|
|
406
|
-
database_=DEFAULT_DATABASE,
|
|
407
583
|
routing_='r',
|
|
408
584
|
)
|
|
409
585
|
|
|
410
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
586
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
411
587
|
|
|
412
588
|
return nodes
|
|
413
589
|
|
|
@@ -416,19 +592,23 @@ class CommunityNode(Node):
|
|
|
416
592
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
417
593
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
|
418
594
|
|
|
419
|
-
async def save(self, driver:
|
|
595
|
+
async def save(self, driver: GraphDriver):
|
|
596
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
597
|
+
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
598
|
+
'communities',
|
|
599
|
+
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
|
600
|
+
)
|
|
420
601
|
result = await driver.execute_query(
|
|
421
|
-
|
|
602
|
+
get_community_node_save_query(driver.provider), # type: ignore
|
|
422
603
|
uuid=self.uuid,
|
|
423
604
|
name=self.name,
|
|
424
605
|
group_id=self.group_id,
|
|
425
606
|
summary=self.summary,
|
|
426
607
|
name_embedding=self.name_embedding,
|
|
427
608
|
created_at=self.created_at,
|
|
428
|
-
database_=DEFAULT_DATABASE,
|
|
429
609
|
)
|
|
430
610
|
|
|
431
|
-
logger.debug(f'Saved Node to
|
|
611
|
+
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
432
612
|
|
|
433
613
|
return result
|
|
434
614
|
|
|
@@ -441,13 +621,22 @@ class CommunityNode(Node):
|
|
|
441
621
|
|
|
442
622
|
return self.name_embedding
|
|
443
623
|
|
|
444
|
-
async def load_name_embedding(self, driver:
|
|
445
|
-
|
|
624
|
+
async def load_name_embedding(self, driver: GraphDriver):
|
|
625
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
626
|
+
query: LiteralString = """
|
|
627
|
+
MATCH (c:Community {uuid: $uuid})
|
|
628
|
+
RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
629
|
+
"""
|
|
630
|
+
else:
|
|
631
|
+
query: LiteralString = """
|
|
446
632
|
MATCH (c:Community {uuid: $uuid})
|
|
447
633
|
RETURN c.name_embedding AS name_embedding
|
|
448
|
-
|
|
634
|
+
"""
|
|
635
|
+
|
|
449
636
|
records, _, _ = await driver.execute_query(
|
|
450
|
-
query,
|
|
637
|
+
query,
|
|
638
|
+
uuid=self.uuid,
|
|
639
|
+
routing_='r',
|
|
451
640
|
)
|
|
452
641
|
|
|
453
642
|
if len(records) == 0:
|
|
@@ -456,19 +645,18 @@ class CommunityNode(Node):
|
|
|
456
645
|
self.name_embedding = records[0]['name_embedding']
|
|
457
646
|
|
|
458
647
|
@classmethod
|
|
459
|
-
async def get_by_uuid(cls, driver:
|
|
648
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
460
649
|
records, _, _ = await driver.execute_query(
|
|
461
650
|
"""
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
651
|
+
MATCH (c:Community {uuid: $uuid})
|
|
652
|
+
RETURN
|
|
653
|
+
"""
|
|
654
|
+
+ (
|
|
655
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
656
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
657
|
+
else COMMUNITY_NODE_RETURN
|
|
658
|
+
),
|
|
470
659
|
uuid=uuid,
|
|
471
|
-
database_=DEFAULT_DATABASE,
|
|
472
660
|
routing_='r',
|
|
473
661
|
)
|
|
474
662
|
|
|
@@ -480,19 +668,19 @@ class CommunityNode(Node):
|
|
|
480
668
|
return nodes[0]
|
|
481
669
|
|
|
482
670
|
@classmethod
|
|
483
|
-
async def get_by_uuids(cls, driver:
|
|
671
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
484
672
|
records, _, _ = await driver.execute_query(
|
|
485
673
|
"""
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
674
|
+
MATCH (c:Community)
|
|
675
|
+
WHERE c.uuid IN $uuids
|
|
676
|
+
RETURN
|
|
677
|
+
"""
|
|
678
|
+
+ (
|
|
679
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
680
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
681
|
+
else COMMUNITY_NODE_RETURN
|
|
682
|
+
),
|
|
494
683
|
uuids=uuids,
|
|
495
|
-
database_=DEFAULT_DATABASE,
|
|
496
684
|
routing_='r',
|
|
497
685
|
)
|
|
498
686
|
|
|
@@ -503,33 +691,35 @@ class CommunityNode(Node):
|
|
|
503
691
|
@classmethod
|
|
504
692
|
async def get_by_group_ids(
|
|
505
693
|
cls,
|
|
506
|
-
driver:
|
|
694
|
+
driver: GraphDriver,
|
|
507
695
|
group_ids: list[str],
|
|
508
696
|
limit: int | None = None,
|
|
509
697
|
uuid_cursor: str | None = None,
|
|
510
698
|
):
|
|
511
|
-
cursor_query: LiteralString = 'AND
|
|
699
|
+
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
|
|
512
700
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
513
701
|
|
|
514
702
|
records, _, _ = await driver.execute_query(
|
|
515
703
|
"""
|
|
516
|
-
|
|
517
|
-
|
|
704
|
+
MATCH (c:Community)
|
|
705
|
+
WHERE c.group_id IN $group_ids
|
|
706
|
+
"""
|
|
518
707
|
+ cursor_query
|
|
519
708
|
+ """
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
709
|
+
RETURN
|
|
710
|
+
"""
|
|
711
|
+
+ (
|
|
712
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
713
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
714
|
+
else COMMUNITY_NODE_RETURN
|
|
715
|
+
)
|
|
716
|
+
+ """
|
|
717
|
+
ORDER BY c.uuid DESC
|
|
718
|
+
"""
|
|
528
719
|
+ limit_query,
|
|
529
720
|
group_ids=group_ids,
|
|
530
721
|
uuid=uuid_cursor,
|
|
531
722
|
limit=limit,
|
|
532
|
-
database_=DEFAULT_DATABASE,
|
|
533
723
|
routing_='r',
|
|
534
724
|
)
|
|
535
725
|
|
|
@@ -540,10 +730,18 @@ class CommunityNode(Node):
|
|
|
540
730
|
|
|
541
731
|
# Node helpers
|
|
542
732
|
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
733
|
+
created_at = parse_db_date(record['created_at'])
|
|
734
|
+
valid_at = parse_db_date(record['valid_at'])
|
|
735
|
+
|
|
736
|
+
if created_at is None:
|
|
737
|
+
raise ValueError(f'created_at cannot be None for episode {record.get("uuid", "unknown")}')
|
|
738
|
+
if valid_at is None:
|
|
739
|
+
raise ValueError(f'valid_at cannot be None for episode {record.get("uuid", "unknown")}')
|
|
740
|
+
|
|
543
741
|
return EpisodicNode(
|
|
544
742
|
content=record['content'],
|
|
545
|
-
created_at=
|
|
546
|
-
valid_at=
|
|
743
|
+
created_at=created_at,
|
|
744
|
+
valid_at=valid_at,
|
|
547
745
|
uuid=record['uuid'],
|
|
548
746
|
group_id=record['group_id'],
|
|
549
747
|
source=EpisodeType.from_str(record['source']),
|
|
@@ -553,24 +751,35 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
|
553
751
|
)
|
|
554
752
|
|
|
555
753
|
|
|
556
|
-
def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
754
|
+
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
|
|
755
|
+
if provider == GraphProvider.KUZU:
|
|
756
|
+
attributes = json.loads(record['attributes']) if record['attributes'] else {}
|
|
757
|
+
else:
|
|
758
|
+
attributes = record['attributes']
|
|
759
|
+
attributes.pop('uuid', None)
|
|
760
|
+
attributes.pop('name', None)
|
|
761
|
+
attributes.pop('group_id', None)
|
|
762
|
+
attributes.pop('name_embedding', None)
|
|
763
|
+
attributes.pop('summary', None)
|
|
764
|
+
attributes.pop('created_at', None)
|
|
765
|
+
attributes.pop('labels', None)
|
|
766
|
+
|
|
767
|
+
labels = record.get('labels', [])
|
|
768
|
+
group_id = record.get('group_id')
|
|
769
|
+
if 'Entity_' + group_id.replace('-', '') in labels:
|
|
770
|
+
labels.remove('Entity_' + group_id.replace('-', ''))
|
|
771
|
+
|
|
557
772
|
entity_node = EntityNode(
|
|
558
773
|
uuid=record['uuid'],
|
|
559
774
|
name=record['name'],
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
775
|
+
name_embedding=record.get('name_embedding'),
|
|
776
|
+
group_id=group_id,
|
|
777
|
+
labels=labels,
|
|
778
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
563
779
|
summary=record['summary'],
|
|
564
|
-
attributes=
|
|
780
|
+
attributes=attributes,
|
|
565
781
|
)
|
|
566
782
|
|
|
567
|
-
entity_node.attributes.pop('uuid', None)
|
|
568
|
-
entity_node.attributes.pop('name', None)
|
|
569
|
-
entity_node.attributes.pop('group_id', None)
|
|
570
|
-
entity_node.attributes.pop('name_embedding', None)
|
|
571
|
-
entity_node.attributes.pop('summary', None)
|
|
572
|
-
entity_node.attributes.pop('created_at', None)
|
|
573
|
-
|
|
574
783
|
return entity_node
|
|
575
784
|
|
|
576
785
|
|
|
@@ -580,12 +789,18 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|
|
580
789
|
name=record['name'],
|
|
581
790
|
group_id=record['group_id'],
|
|
582
791
|
name_embedding=record['name_embedding'],
|
|
583
|
-
created_at=record['created_at']
|
|
792
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
584
793
|
summary=record['summary'],
|
|
585
794
|
)
|
|
586
795
|
|
|
587
796
|
|
|
588
797
|
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
|
589
|
-
|
|
590
|
-
for node
|
|
798
|
+
# filter out falsey values from nodes
|
|
799
|
+
filtered_nodes = [node for node in nodes if node.name]
|
|
800
|
+
|
|
801
|
+
if not filtered_nodes:
|
|
802
|
+
return
|
|
803
|
+
|
|
804
|
+
name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
|
|
805
|
+
for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True):
|
|
591
806
|
node.name_embedding = name_embedding
|