graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/edges.py
CHANGED
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from datetime import datetime
|
|
@@ -21,38 +22,25 @@ from time import time
|
|
|
21
22
|
from typing import Any
|
|
22
23
|
from uuid import uuid4
|
|
23
24
|
|
|
24
|
-
from neo4j import AsyncDriver
|
|
25
25
|
from pydantic import BaseModel, Field
|
|
26
26
|
from typing_extensions import LiteralString
|
|
27
27
|
|
|
28
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
28
29
|
from graphiti_core.embedder import EmbedderClient
|
|
29
30
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
30
|
-
from graphiti_core.helpers import
|
|
31
|
+
from graphiti_core.helpers import parse_db_date
|
|
31
32
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
32
|
-
|
|
33
|
-
|
|
33
|
+
COMMUNITY_EDGE_RETURN,
|
|
34
|
+
EPISODIC_EDGE_RETURN,
|
|
34
35
|
EPISODIC_EDGE_SAVE,
|
|
36
|
+
get_community_edge_save_query,
|
|
37
|
+
get_entity_edge_return_query,
|
|
38
|
+
get_entity_edge_save_query,
|
|
35
39
|
)
|
|
36
40
|
from graphiti_core.nodes import Node
|
|
37
41
|
|
|
38
42
|
logger = logging.getLogger(__name__)
|
|
39
43
|
|
|
40
|
-
ENTITY_EDGE_RETURN: LiteralString = """
|
|
41
|
-
RETURN
|
|
42
|
-
e.uuid AS uuid,
|
|
43
|
-
startNode(e).uuid AS source_node_uuid,
|
|
44
|
-
endNode(e).uuid AS target_node_uuid,
|
|
45
|
-
e.created_at AS created_at,
|
|
46
|
-
e.name AS name,
|
|
47
|
-
e.group_id AS group_id,
|
|
48
|
-
e.fact AS fact,
|
|
49
|
-
e.episodes AS episodes,
|
|
50
|
-
e.expired_at AS expired_at,
|
|
51
|
-
e.valid_at AS valid_at,
|
|
52
|
-
e.invalid_at AS invalid_at,
|
|
53
|
-
properties(e) AS attributes
|
|
54
|
-
"""
|
|
55
|
-
|
|
56
44
|
|
|
57
45
|
class Edge(BaseModel, ABC):
|
|
58
46
|
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
|
@@ -62,21 +50,71 @@ class Edge(BaseModel, ABC):
|
|
|
62
50
|
created_at: datetime
|
|
63
51
|
|
|
64
52
|
@abstractmethod
|
|
65
|
-
async def save(self, driver:
|
|
66
|
-
|
|
67
|
-
async def delete(self, driver:
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
53
|
+
async def save(self, driver: GraphDriver): ...
|
|
54
|
+
|
|
55
|
+
async def delete(self, driver: GraphDriver):
|
|
56
|
+
if driver.graph_operations_interface:
|
|
57
|
+
return await driver.graph_operations_interface.edge_delete(self, driver)
|
|
58
|
+
|
|
59
|
+
if driver.provider == GraphProvider.KUZU:
|
|
60
|
+
await driver.execute_query(
|
|
61
|
+
"""
|
|
62
|
+
MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
63
|
+
DELETE e
|
|
64
|
+
""",
|
|
65
|
+
uuid=self.uuid,
|
|
66
|
+
)
|
|
67
|
+
await driver.execute_query(
|
|
68
|
+
"""
|
|
69
|
+
MATCH (e:RelatesToNode_ {uuid: $uuid})
|
|
70
|
+
DETACH DELETE e
|
|
71
|
+
""",
|
|
72
|
+
uuid=self.uuid,
|
|
73
|
+
)
|
|
74
|
+
else:
|
|
75
|
+
await driver.execute_query(
|
|
76
|
+
"""
|
|
77
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
78
|
+
DELETE e
|
|
79
|
+
""",
|
|
80
|
+
uuid=self.uuid,
|
|
81
|
+
)
|
|
76
82
|
|
|
77
83
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
|
78
84
|
|
|
79
|
-
|
|
85
|
+
@classmethod
|
|
86
|
+
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
87
|
+
if driver.graph_operations_interface:
|
|
88
|
+
return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
|
|
89
|
+
|
|
90
|
+
if driver.provider == GraphProvider.KUZU:
|
|
91
|
+
await driver.execute_query(
|
|
92
|
+
"""
|
|
93
|
+
MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
|
|
94
|
+
WHERE e.uuid IN $uuids
|
|
95
|
+
DELETE e
|
|
96
|
+
""",
|
|
97
|
+
uuids=uuids,
|
|
98
|
+
)
|
|
99
|
+
await driver.execute_query(
|
|
100
|
+
"""
|
|
101
|
+
MATCH (e:RelatesToNode_)
|
|
102
|
+
WHERE e.uuid IN $uuids
|
|
103
|
+
DETACH DELETE e
|
|
104
|
+
""",
|
|
105
|
+
uuids=uuids,
|
|
106
|
+
)
|
|
107
|
+
else:
|
|
108
|
+
await driver.execute_query(
|
|
109
|
+
"""
|
|
110
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
|
|
111
|
+
WHERE e.uuid IN $uuids
|
|
112
|
+
DELETE e
|
|
113
|
+
""",
|
|
114
|
+
uuids=uuids,
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
logger.debug(f'Deleted Edges: {uuids}')
|
|
80
118
|
|
|
81
119
|
def __hash__(self):
|
|
82
120
|
return hash(self.uuid)
|
|
@@ -87,11 +125,11 @@ class Edge(BaseModel, ABC):
|
|
|
87
125
|
return False
|
|
88
126
|
|
|
89
127
|
@classmethod
|
|
90
|
-
async def get_by_uuid(cls, driver:
|
|
128
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
91
129
|
|
|
92
130
|
|
|
93
131
|
class EpisodicEdge(Edge):
|
|
94
|
-
async def save(self, driver:
|
|
132
|
+
async def save(self, driver: GraphDriver):
|
|
95
133
|
result = await driver.execute_query(
|
|
96
134
|
EPISODIC_EDGE_SAVE,
|
|
97
135
|
episode_uuid=self.source_node_uuid,
|
|
@@ -99,27 +137,21 @@ class EpisodicEdge(Edge):
|
|
|
99
137
|
uuid=self.uuid,
|
|
100
138
|
group_id=self.group_id,
|
|
101
139
|
created_at=self.created_at,
|
|
102
|
-
database_=DEFAULT_DATABASE,
|
|
103
140
|
)
|
|
104
141
|
|
|
105
|
-
logger.debug(f'Saved edge to
|
|
142
|
+
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
|
106
143
|
|
|
107
144
|
return result
|
|
108
145
|
|
|
109
146
|
@classmethod
|
|
110
|
-
async def get_by_uuid(cls, driver:
|
|
147
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
111
148
|
records, _, _ = await driver.execute_query(
|
|
112
149
|
"""
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
n.uuid AS source_node_uuid,
|
|
118
|
-
m.uuid AS target_node_uuid,
|
|
119
|
-
e.created_at AS created_at
|
|
120
|
-
""",
|
|
150
|
+
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
151
|
+
RETURN
|
|
152
|
+
"""
|
|
153
|
+
+ EPISODIC_EDGE_RETURN,
|
|
121
154
|
uuid=uuid,
|
|
122
|
-
database_=DEFAULT_DATABASE,
|
|
123
155
|
routing_='r',
|
|
124
156
|
)
|
|
125
157
|
|
|
@@ -130,20 +162,15 @@ class EpisodicEdge(Edge):
|
|
|
130
162
|
return edges[0]
|
|
131
163
|
|
|
132
164
|
@classmethod
|
|
133
|
-
async def get_by_uuids(cls, driver:
|
|
165
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
134
166
|
records, _, _ = await driver.execute_query(
|
|
135
167
|
"""
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
n.uuid AS source_node_uuid,
|
|
142
|
-
m.uuid AS target_node_uuid,
|
|
143
|
-
e.created_at AS created_at
|
|
144
|
-
""",
|
|
168
|
+
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
|
169
|
+
WHERE e.uuid IN $uuids
|
|
170
|
+
RETURN
|
|
171
|
+
"""
|
|
172
|
+
+ EPISODIC_EDGE_RETURN,
|
|
145
173
|
uuids=uuids,
|
|
146
|
-
database_=DEFAULT_DATABASE,
|
|
147
174
|
routing_='r',
|
|
148
175
|
)
|
|
149
176
|
|
|
@@ -156,7 +183,7 @@ class EpisodicEdge(Edge):
|
|
|
156
183
|
@classmethod
|
|
157
184
|
async def get_by_group_ids(
|
|
158
185
|
cls,
|
|
159
|
-
driver:
|
|
186
|
+
driver: GraphDriver,
|
|
160
187
|
group_ids: list[str],
|
|
161
188
|
limit: int | None = None,
|
|
162
189
|
uuid_cursor: str | None = None,
|
|
@@ -166,24 +193,21 @@ class EpisodicEdge(Edge):
|
|
|
166
193
|
|
|
167
194
|
records, _, _ = await driver.execute_query(
|
|
168
195
|
"""
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
196
|
+
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
|
197
|
+
WHERE e.group_id IN $group_ids
|
|
198
|
+
"""
|
|
172
199
|
+ cursor_query
|
|
173
200
|
+ """
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
ORDER BY e.uuid DESC
|
|
181
|
-
"""
|
|
201
|
+
RETURN
|
|
202
|
+
"""
|
|
203
|
+
+ EPISODIC_EDGE_RETURN
|
|
204
|
+
+ """
|
|
205
|
+
ORDER BY e.uuid DESC
|
|
206
|
+
"""
|
|
182
207
|
+ limit_query,
|
|
183
208
|
group_ids=group_ids,
|
|
184
209
|
uuid=uuid_cursor,
|
|
185
210
|
limit=limit,
|
|
186
|
-
database_=DEFAULT_DATABASE,
|
|
187
211
|
routing_='r',
|
|
188
212
|
)
|
|
189
213
|
|
|
@@ -226,13 +250,31 @@ class EntityEdge(Edge):
|
|
|
226
250
|
|
|
227
251
|
return self.fact_embedding
|
|
228
252
|
|
|
229
|
-
async def load_fact_embedding(self, driver:
|
|
230
|
-
|
|
253
|
+
async def load_fact_embedding(self, driver: GraphDriver):
|
|
254
|
+
if driver.graph_operations_interface:
|
|
255
|
+
return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
|
|
256
|
+
|
|
257
|
+
query = """
|
|
231
258
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
232
259
|
RETURN e.fact_embedding AS fact_embedding
|
|
233
260
|
"""
|
|
261
|
+
|
|
262
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
263
|
+
query = """
|
|
264
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
265
|
+
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
|
266
|
+
"""
|
|
267
|
+
|
|
268
|
+
if driver.provider == GraphProvider.KUZU:
|
|
269
|
+
query = """
|
|
270
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
|
|
271
|
+
RETURN e.fact_embedding AS fact_embedding
|
|
272
|
+
"""
|
|
273
|
+
|
|
234
274
|
records, _, _ = await driver.execute_query(
|
|
235
|
-
query,
|
|
275
|
+
query,
|
|
276
|
+
uuid=self.uuid,
|
|
277
|
+
routing_='r',
|
|
236
278
|
)
|
|
237
279
|
|
|
238
280
|
if len(records) == 0:
|
|
@@ -240,7 +282,7 @@ class EntityEdge(Edge):
|
|
|
240
282
|
|
|
241
283
|
self.fact_embedding = records[0]['fact_embedding']
|
|
242
284
|
|
|
243
|
-
async def save(self, driver:
|
|
285
|
+
async def save(self, driver: GraphDriver):
|
|
244
286
|
edge_data: dict[str, Any] = {
|
|
245
287
|
'source_uuid': self.source_node_uuid,
|
|
246
288
|
'target_uuid': self.target_node_uuid,
|
|
@@ -256,138 +298,209 @@ class EntityEdge(Edge):
|
|
|
256
298
|
'invalid_at': self.invalid_at,
|
|
257
299
|
}
|
|
258
300
|
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
301
|
+
if driver.provider == GraphProvider.KUZU:
|
|
302
|
+
edge_data['attributes'] = json.dumps(self.attributes)
|
|
303
|
+
result = await driver.execute_query(
|
|
304
|
+
get_entity_edge_save_query(driver.provider),
|
|
305
|
+
**edge_data,
|
|
306
|
+
)
|
|
307
|
+
else:
|
|
308
|
+
edge_data.update(self.attributes or {})
|
|
309
|
+
result = await driver.execute_query(
|
|
310
|
+
get_entity_edge_save_query(driver.provider),
|
|
311
|
+
edge_data=edge_data,
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
|
268
315
|
|
|
269
316
|
return result
|
|
270
317
|
|
|
271
318
|
@classmethod
|
|
272
|
-
async def get_by_uuid(cls, driver:
|
|
319
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
320
|
+
match_query = """
|
|
321
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
322
|
+
"""
|
|
323
|
+
if driver.provider == GraphProvider.KUZU:
|
|
324
|
+
match_query = """
|
|
325
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
|
|
326
|
+
"""
|
|
327
|
+
|
|
273
328
|
records, _, _ = await driver.execute_query(
|
|
329
|
+
match_query
|
|
330
|
+
+ """
|
|
331
|
+
RETURN
|
|
274
332
|
"""
|
|
275
|
-
|
|
276
|
-
"""
|
|
277
|
-
+ ENTITY_EDGE_RETURN,
|
|
333
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
278
334
|
uuid=uuid,
|
|
279
|
-
database_=DEFAULT_DATABASE,
|
|
280
335
|
routing_='r',
|
|
281
336
|
)
|
|
282
337
|
|
|
283
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
338
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
284
339
|
|
|
285
340
|
if len(edges) == 0:
|
|
286
341
|
raise EdgeNotFoundError(uuid)
|
|
287
342
|
return edges[0]
|
|
288
343
|
|
|
289
344
|
@classmethod
|
|
290
|
-
async def
|
|
345
|
+
async def get_between_nodes(
|
|
346
|
+
cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
|
|
347
|
+
):
|
|
348
|
+
match_query = """
|
|
349
|
+
MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
|
350
|
+
"""
|
|
351
|
+
if driver.provider == GraphProvider.KUZU:
|
|
352
|
+
match_query = """
|
|
353
|
+
MATCH (n:Entity {uuid: $source_node_uuid})
|
|
354
|
+
-[:RELATES_TO]->(e:RelatesToNode_)
|
|
355
|
+
-[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
|
356
|
+
"""
|
|
357
|
+
|
|
358
|
+
records, _, _ = await driver.execute_query(
|
|
359
|
+
match_query
|
|
360
|
+
+ """
|
|
361
|
+
RETURN
|
|
362
|
+
"""
|
|
363
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
364
|
+
source_node_uuid=source_node_uuid,
|
|
365
|
+
target_node_uuid=target_node_uuid,
|
|
366
|
+
routing_='r',
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
370
|
+
|
|
371
|
+
return edges
|
|
372
|
+
|
|
373
|
+
@classmethod
|
|
374
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
291
375
|
if len(uuids) == 0:
|
|
292
376
|
return []
|
|
293
377
|
|
|
378
|
+
match_query = """
|
|
379
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
380
|
+
"""
|
|
381
|
+
if driver.provider == GraphProvider.KUZU:
|
|
382
|
+
match_query = """
|
|
383
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
384
|
+
"""
|
|
385
|
+
|
|
294
386
|
records, _, _ = await driver.execute_query(
|
|
387
|
+
match_query
|
|
388
|
+
+ """
|
|
389
|
+
WHERE e.uuid IN $uuids
|
|
390
|
+
RETURN
|
|
295
391
|
"""
|
|
296
|
-
|
|
297
|
-
WHERE e.uuid IN $uuids
|
|
298
|
-
"""
|
|
299
|
-
+ ENTITY_EDGE_RETURN,
|
|
392
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
300
393
|
uuids=uuids,
|
|
301
|
-
database_=DEFAULT_DATABASE,
|
|
302
394
|
routing_='r',
|
|
303
395
|
)
|
|
304
396
|
|
|
305
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
397
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
306
398
|
|
|
307
399
|
return edges
|
|
308
400
|
|
|
309
401
|
@classmethod
|
|
310
402
|
async def get_by_group_ids(
|
|
311
403
|
cls,
|
|
312
|
-
driver:
|
|
404
|
+
driver: GraphDriver,
|
|
313
405
|
group_ids: list[str],
|
|
314
406
|
limit: int | None = None,
|
|
315
407
|
uuid_cursor: str | None = None,
|
|
408
|
+
with_embeddings: bool = False,
|
|
316
409
|
):
|
|
317
410
|
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
|
|
318
411
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
412
|
+
with_embeddings_query: LiteralString = (
|
|
413
|
+
""",
|
|
414
|
+
e.fact_embedding AS fact_embedding
|
|
415
|
+
"""
|
|
416
|
+
if with_embeddings
|
|
417
|
+
else ''
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
match_query = """
|
|
421
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
422
|
+
"""
|
|
423
|
+
if driver.provider == GraphProvider.KUZU:
|
|
424
|
+
match_query = """
|
|
425
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
426
|
+
"""
|
|
319
427
|
|
|
320
428
|
records, _, _ = await driver.execute_query(
|
|
429
|
+
match_query
|
|
430
|
+
+ """
|
|
431
|
+
WHERE e.group_id IN $group_ids
|
|
321
432
|
"""
|
|
322
|
-
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
323
|
-
WHERE e.group_id IN $group_ids
|
|
324
|
-
"""
|
|
325
433
|
+ cursor_query
|
|
326
|
-
+ ENTITY_EDGE_RETURN
|
|
327
434
|
+ """
|
|
328
|
-
|
|
329
|
-
|
|
435
|
+
RETURN
|
|
436
|
+
"""
|
|
437
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
438
|
+
+ with_embeddings_query
|
|
439
|
+
+ """
|
|
440
|
+
ORDER BY e.uuid DESC
|
|
441
|
+
"""
|
|
330
442
|
+ limit_query,
|
|
331
443
|
group_ids=group_ids,
|
|
332
444
|
uuid=uuid_cursor,
|
|
333
445
|
limit=limit,
|
|
334
|
-
database_=DEFAULT_DATABASE,
|
|
335
446
|
routing_='r',
|
|
336
447
|
)
|
|
337
448
|
|
|
338
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
449
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
339
450
|
|
|
340
451
|
if len(edges) == 0:
|
|
341
452
|
raise GroupsEdgesNotFoundError(group_ids)
|
|
342
453
|
return edges
|
|
343
454
|
|
|
344
455
|
@classmethod
|
|
345
|
-
async def get_by_node_uuid(cls, driver:
|
|
346
|
-
|
|
456
|
+
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
|
457
|
+
match_query = """
|
|
458
|
+
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
|
459
|
+
"""
|
|
460
|
+
if driver.provider == GraphProvider.KUZU:
|
|
461
|
+
match_query = """
|
|
462
|
+
MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
347
463
|
"""
|
|
348
|
-
|
|
349
|
-
"""
|
|
350
|
-
+ ENTITY_EDGE_RETURN
|
|
351
|
-
)
|
|
464
|
+
|
|
352
465
|
records, _, _ = await driver.execute_query(
|
|
353
|
-
|
|
466
|
+
match_query
|
|
467
|
+
+ """
|
|
468
|
+
RETURN
|
|
469
|
+
"""
|
|
470
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
471
|
+
node_uuid=node_uuid,
|
|
472
|
+
routing_='r',
|
|
354
473
|
)
|
|
355
474
|
|
|
356
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
475
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
357
476
|
|
|
358
477
|
return edges
|
|
359
478
|
|
|
360
479
|
|
|
361
480
|
class CommunityEdge(Edge):
|
|
362
|
-
async def save(self, driver:
|
|
481
|
+
async def save(self, driver: GraphDriver):
|
|
363
482
|
result = await driver.execute_query(
|
|
364
|
-
|
|
483
|
+
get_community_edge_save_query(driver.provider),
|
|
365
484
|
community_uuid=self.source_node_uuid,
|
|
366
485
|
entity_uuid=self.target_node_uuid,
|
|
367
486
|
uuid=self.uuid,
|
|
368
487
|
group_id=self.group_id,
|
|
369
488
|
created_at=self.created_at,
|
|
370
|
-
database_=DEFAULT_DATABASE,
|
|
371
489
|
)
|
|
372
490
|
|
|
373
|
-
logger.debug(f'Saved edge to
|
|
491
|
+
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
|
374
492
|
|
|
375
493
|
return result
|
|
376
494
|
|
|
377
495
|
@classmethod
|
|
378
|
-
async def get_by_uuid(cls, driver:
|
|
496
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
379
497
|
records, _, _ = await driver.execute_query(
|
|
380
498
|
"""
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
n.uuid AS source_node_uuid,
|
|
386
|
-
m.uuid AS target_node_uuid,
|
|
387
|
-
e.created_at AS created_at
|
|
388
|
-
""",
|
|
499
|
+
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
|
|
500
|
+
RETURN
|
|
501
|
+
"""
|
|
502
|
+
+ COMMUNITY_EDGE_RETURN,
|
|
389
503
|
uuid=uuid,
|
|
390
|
-
database_=DEFAULT_DATABASE,
|
|
391
504
|
routing_='r',
|
|
392
505
|
)
|
|
393
506
|
|
|
@@ -396,20 +509,15 @@ class CommunityEdge(Edge):
|
|
|
396
509
|
return edges[0]
|
|
397
510
|
|
|
398
511
|
@classmethod
|
|
399
|
-
async def get_by_uuids(cls, driver:
|
|
512
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
400
513
|
records, _, _ = await driver.execute_query(
|
|
401
514
|
"""
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
n.uuid AS source_node_uuid,
|
|
408
|
-
m.uuid AS target_node_uuid,
|
|
409
|
-
e.created_at AS created_at
|
|
410
|
-
""",
|
|
515
|
+
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
|
516
|
+
WHERE e.uuid IN $uuids
|
|
517
|
+
RETURN
|
|
518
|
+
"""
|
|
519
|
+
+ COMMUNITY_EDGE_RETURN,
|
|
411
520
|
uuids=uuids,
|
|
412
|
-
database_=DEFAULT_DATABASE,
|
|
413
521
|
routing_='r',
|
|
414
522
|
)
|
|
415
523
|
|
|
@@ -420,7 +528,7 @@ class CommunityEdge(Edge):
|
|
|
420
528
|
@classmethod
|
|
421
529
|
async def get_by_group_ids(
|
|
422
530
|
cls,
|
|
423
|
-
driver:
|
|
531
|
+
driver: GraphDriver,
|
|
424
532
|
group_ids: list[str],
|
|
425
533
|
limit: int | None = None,
|
|
426
534
|
uuid_cursor: str | None = None,
|
|
@@ -430,24 +538,21 @@ class CommunityEdge(Edge):
|
|
|
430
538
|
|
|
431
539
|
records, _, _ = await driver.execute_query(
|
|
432
540
|
"""
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
541
|
+
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
|
542
|
+
WHERE e.group_id IN $group_ids
|
|
543
|
+
"""
|
|
436
544
|
+ cursor_query
|
|
437
545
|
+ """
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
ORDER BY e.uuid DESC
|
|
445
|
-
"""
|
|
546
|
+
RETURN
|
|
547
|
+
"""
|
|
548
|
+
+ COMMUNITY_EDGE_RETURN
|
|
549
|
+
+ """
|
|
550
|
+
ORDER BY e.uuid DESC
|
|
551
|
+
"""
|
|
446
552
|
+ limit_query,
|
|
447
553
|
group_ids=group_ids,
|
|
448
554
|
uuid=uuid_cursor,
|
|
449
555
|
limit=limit,
|
|
450
|
-
database_=DEFAULT_DATABASE,
|
|
451
556
|
routing_='r',
|
|
452
557
|
)
|
|
453
558
|
|
|
@@ -463,38 +568,45 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
|
463
568
|
group_id=record['group_id'],
|
|
464
569
|
source_node_uuid=record['source_node_uuid'],
|
|
465
570
|
target_node_uuid=record['target_node_uuid'],
|
|
466
|
-
created_at=record['created_at']
|
|
571
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
467
572
|
)
|
|
468
573
|
|
|
469
574
|
|
|
470
|
-
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
575
|
+
def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
|
|
576
|
+
episodes = record['episodes']
|
|
577
|
+
if provider == GraphProvider.KUZU:
|
|
578
|
+
attributes = json.loads(record['attributes']) if record['attributes'] else {}
|
|
579
|
+
else:
|
|
580
|
+
attributes = record['attributes']
|
|
581
|
+
attributes.pop('uuid', None)
|
|
582
|
+
attributes.pop('source_node_uuid', None)
|
|
583
|
+
attributes.pop('target_node_uuid', None)
|
|
584
|
+
attributes.pop('fact', None)
|
|
585
|
+
attributes.pop('fact_embedding', None)
|
|
586
|
+
attributes.pop('name', None)
|
|
587
|
+
attributes.pop('group_id', None)
|
|
588
|
+
attributes.pop('episodes', None)
|
|
589
|
+
attributes.pop('created_at', None)
|
|
590
|
+
attributes.pop('expired_at', None)
|
|
591
|
+
attributes.pop('valid_at', None)
|
|
592
|
+
attributes.pop('invalid_at', None)
|
|
593
|
+
|
|
471
594
|
edge = EntityEdge(
|
|
472
595
|
uuid=record['uuid'],
|
|
473
596
|
source_node_uuid=record['source_node_uuid'],
|
|
474
597
|
target_node_uuid=record['target_node_uuid'],
|
|
475
598
|
fact=record['fact'],
|
|
599
|
+
fact_embedding=record.get('fact_embedding'),
|
|
476
600
|
name=record['name'],
|
|
477
601
|
group_id=record['group_id'],
|
|
478
|
-
episodes=
|
|
479
|
-
created_at=record['created_at']
|
|
602
|
+
episodes=episodes,
|
|
603
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
480
604
|
expired_at=parse_db_date(record['expired_at']),
|
|
481
605
|
valid_at=parse_db_date(record['valid_at']),
|
|
482
606
|
invalid_at=parse_db_date(record['invalid_at']),
|
|
483
|
-
attributes=
|
|
607
|
+
attributes=attributes,
|
|
484
608
|
)
|
|
485
609
|
|
|
486
|
-
edge.attributes.pop('uuid', None)
|
|
487
|
-
edge.attributes.pop('source_node_uuid', None)
|
|
488
|
-
edge.attributes.pop('target_node_uuid', None)
|
|
489
|
-
edge.attributes.pop('fact', None)
|
|
490
|
-
edge.attributes.pop('name', None)
|
|
491
|
-
edge.attributes.pop('group_id', None)
|
|
492
|
-
edge.attributes.pop('episodes', None)
|
|
493
|
-
edge.attributes.pop('created_at', None)
|
|
494
|
-
edge.attributes.pop('expired_at', None)
|
|
495
|
-
edge.attributes.pop('valid_at', None)
|
|
496
|
-
edge.attributes.pop('invalid_at', None)
|
|
497
|
-
|
|
498
610
|
return edge
|
|
499
611
|
|
|
500
612
|
|
|
@@ -504,13 +616,16 @@ def get_community_edge_from_record(record: Any):
|
|
|
504
616
|
group_id=record['group_id'],
|
|
505
617
|
source_node_uuid=record['source_node_uuid'],
|
|
506
618
|
target_node_uuid=record['target_node_uuid'],
|
|
507
|
-
created_at=record['created_at']
|
|
619
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
508
620
|
)
|
|
509
621
|
|
|
510
622
|
|
|
511
623
|
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
|
512
|
-
|
|
624
|
+
# filter out falsey values from edges
|
|
625
|
+
filtered_edges = [edge for edge in edges if edge.fact]
|
|
626
|
+
|
|
627
|
+
if len(filtered_edges) == 0:
|
|
513
628
|
return
|
|
514
|
-
fact_embeddings = await embedder.create_batch([edge.fact for edge in
|
|
515
|
-
for edge, fact_embedding in zip(
|
|
629
|
+
fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
|
|
630
|
+
for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
|
|
516
631
|
edge.fact_embedding = fact_embedding
|