graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +61 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +582 -255
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +21 -14
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +94 -50
- graphiti_core/llm_client/openai_client.py +28 -8
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +23 -14
- graphiti_core/prompts/extract_nodes.py +73 -32
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +109 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +286 -126
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +320 -158
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/edges.py
CHANGED
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from datetime import datetime
|
|
@@ -24,35 +25,22 @@ from uuid import uuid4
|
|
|
24
25
|
from pydantic import BaseModel, Field
|
|
25
26
|
from typing_extensions import LiteralString
|
|
26
27
|
|
|
27
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
28
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
28
29
|
from graphiti_core.embedder import EmbedderClient
|
|
29
30
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
30
31
|
from graphiti_core.helpers import parse_db_date
|
|
31
32
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
32
|
-
|
|
33
|
-
|
|
33
|
+
COMMUNITY_EDGE_RETURN,
|
|
34
|
+
EPISODIC_EDGE_RETURN,
|
|
34
35
|
EPISODIC_EDGE_SAVE,
|
|
36
|
+
get_community_edge_save_query,
|
|
37
|
+
get_entity_edge_return_query,
|
|
38
|
+
get_entity_edge_save_query,
|
|
35
39
|
)
|
|
36
40
|
from graphiti_core.nodes import Node
|
|
37
41
|
|
|
38
42
|
logger = logging.getLogger(__name__)
|
|
39
43
|
|
|
40
|
-
ENTITY_EDGE_RETURN: LiteralString = """
|
|
41
|
-
RETURN
|
|
42
|
-
e.uuid AS uuid,
|
|
43
|
-
startNode(e).uuid AS source_node_uuid,
|
|
44
|
-
endNode(e).uuid AS target_node_uuid,
|
|
45
|
-
e.created_at AS created_at,
|
|
46
|
-
e.name AS name,
|
|
47
|
-
e.group_id AS group_id,
|
|
48
|
-
e.fact AS fact,
|
|
49
|
-
e.episodes AS episodes,
|
|
50
|
-
e.expired_at AS expired_at,
|
|
51
|
-
e.valid_at AS valid_at,
|
|
52
|
-
e.invalid_at AS invalid_at,
|
|
53
|
-
properties(e) AS attributes
|
|
54
|
-
"""
|
|
55
|
-
|
|
56
44
|
|
|
57
45
|
class Edge(BaseModel, ABC):
|
|
58
46
|
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
|
@@ -65,17 +53,68 @@ class Edge(BaseModel, ABC):
|
|
|
65
53
|
async def save(self, driver: GraphDriver): ...
|
|
66
54
|
|
|
67
55
|
async def delete(self, driver: GraphDriver):
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
56
|
+
if driver.graph_operations_interface:
|
|
57
|
+
return await driver.graph_operations_interface.edge_delete(self, driver)
|
|
58
|
+
|
|
59
|
+
if driver.provider == GraphProvider.KUZU:
|
|
60
|
+
await driver.execute_query(
|
|
61
|
+
"""
|
|
62
|
+
MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
63
|
+
DELETE e
|
|
64
|
+
""",
|
|
65
|
+
uuid=self.uuid,
|
|
66
|
+
)
|
|
67
|
+
await driver.execute_query(
|
|
68
|
+
"""
|
|
69
|
+
MATCH (e:RelatesToNode_ {uuid: $uuid})
|
|
70
|
+
DETACH DELETE e
|
|
71
|
+
""",
|
|
72
|
+
uuid=self.uuid,
|
|
73
|
+
)
|
|
74
|
+
else:
|
|
75
|
+
await driver.execute_query(
|
|
76
|
+
"""
|
|
77
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
78
|
+
DELETE e
|
|
79
|
+
""",
|
|
80
|
+
uuid=self.uuid,
|
|
81
|
+
)
|
|
75
82
|
|
|
76
83
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
|
77
84
|
|
|
78
|
-
|
|
85
|
+
@classmethod
|
|
86
|
+
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
87
|
+
if driver.graph_operations_interface:
|
|
88
|
+
return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
|
|
89
|
+
|
|
90
|
+
if driver.provider == GraphProvider.KUZU:
|
|
91
|
+
await driver.execute_query(
|
|
92
|
+
"""
|
|
93
|
+
MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
|
|
94
|
+
WHERE e.uuid IN $uuids
|
|
95
|
+
DELETE e
|
|
96
|
+
""",
|
|
97
|
+
uuids=uuids,
|
|
98
|
+
)
|
|
99
|
+
await driver.execute_query(
|
|
100
|
+
"""
|
|
101
|
+
MATCH (e:RelatesToNode_)
|
|
102
|
+
WHERE e.uuid IN $uuids
|
|
103
|
+
DETACH DELETE e
|
|
104
|
+
""",
|
|
105
|
+
uuids=uuids,
|
|
106
|
+
)
|
|
107
|
+
else:
|
|
108
|
+
await driver.execute_query(
|
|
109
|
+
"""
|
|
110
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
|
|
111
|
+
WHERE e.uuid IN $uuids
|
|
112
|
+
DELETE e
|
|
113
|
+
""",
|
|
114
|
+
uuids=uuids,
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
logger.debug(f'Deleted Edges: {uuids}')
|
|
79
118
|
|
|
80
119
|
def __hash__(self):
|
|
81
120
|
return hash(self.uuid)
|
|
@@ -108,14 +147,10 @@ class EpisodicEdge(Edge):
|
|
|
108
147
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
109
148
|
records, _, _ = await driver.execute_query(
|
|
110
149
|
"""
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
n.uuid AS source_node_uuid,
|
|
116
|
-
m.uuid AS target_node_uuid,
|
|
117
|
-
e.created_at AS created_at
|
|
118
|
-
""",
|
|
150
|
+
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
151
|
+
RETURN
|
|
152
|
+
"""
|
|
153
|
+
+ EPISODIC_EDGE_RETURN,
|
|
119
154
|
uuid=uuid,
|
|
120
155
|
routing_='r',
|
|
121
156
|
)
|
|
@@ -130,15 +165,11 @@ class EpisodicEdge(Edge):
|
|
|
130
165
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
131
166
|
records, _, _ = await driver.execute_query(
|
|
132
167
|
"""
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
n.uuid AS source_node_uuid,
|
|
139
|
-
m.uuid AS target_node_uuid,
|
|
140
|
-
e.created_at AS created_at
|
|
141
|
-
""",
|
|
168
|
+
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
|
169
|
+
WHERE e.uuid IN $uuids
|
|
170
|
+
RETURN
|
|
171
|
+
"""
|
|
172
|
+
+ EPISODIC_EDGE_RETURN,
|
|
142
173
|
uuids=uuids,
|
|
143
174
|
routing_='r',
|
|
144
175
|
)
|
|
@@ -162,19 +193,17 @@ class EpisodicEdge(Edge):
|
|
|
162
193
|
|
|
163
194
|
records, _, _ = await driver.execute_query(
|
|
164
195
|
"""
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
196
|
+
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
|
197
|
+
WHERE e.group_id IN $group_ids
|
|
198
|
+
"""
|
|
168
199
|
+ cursor_query
|
|
169
200
|
+ """
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
ORDER BY e.uuid DESC
|
|
177
|
-
"""
|
|
201
|
+
RETURN
|
|
202
|
+
"""
|
|
203
|
+
+ EPISODIC_EDGE_RETURN
|
|
204
|
+
+ """
|
|
205
|
+
ORDER BY e.uuid DESC
|
|
206
|
+
"""
|
|
178
207
|
+ limit_query,
|
|
179
208
|
group_ids=group_ids,
|
|
180
209
|
uuid=uuid_cursor,
|
|
@@ -222,11 +251,31 @@ class EntityEdge(Edge):
|
|
|
222
251
|
return self.fact_embedding
|
|
223
252
|
|
|
224
253
|
async def load_fact_embedding(self, driver: GraphDriver):
|
|
225
|
-
|
|
254
|
+
if driver.graph_operations_interface:
|
|
255
|
+
return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
|
|
256
|
+
|
|
257
|
+
query = """
|
|
226
258
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
227
259
|
RETURN e.fact_embedding AS fact_embedding
|
|
228
260
|
"""
|
|
229
|
-
|
|
261
|
+
|
|
262
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
263
|
+
query = """
|
|
264
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
265
|
+
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
|
266
|
+
"""
|
|
267
|
+
|
|
268
|
+
if driver.provider == GraphProvider.KUZU:
|
|
269
|
+
query = """
|
|
270
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
|
|
271
|
+
RETURN e.fact_embedding AS fact_embedding
|
|
272
|
+
"""
|
|
273
|
+
|
|
274
|
+
records, _, _ = await driver.execute_query(
|
|
275
|
+
query,
|
|
276
|
+
uuid=self.uuid,
|
|
277
|
+
routing_='r',
|
|
278
|
+
)
|
|
230
279
|
|
|
231
280
|
if len(records) == 0:
|
|
232
281
|
raise EdgeNotFoundError(self.uuid)
|
|
@@ -249,12 +298,18 @@ class EntityEdge(Edge):
|
|
|
249
298
|
'invalid_at': self.invalid_at,
|
|
250
299
|
}
|
|
251
300
|
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
301
|
+
if driver.provider == GraphProvider.KUZU:
|
|
302
|
+
edge_data['attributes'] = json.dumps(self.attributes)
|
|
303
|
+
result = await driver.execute_query(
|
|
304
|
+
get_entity_edge_save_query(driver.provider),
|
|
305
|
+
**edge_data,
|
|
306
|
+
)
|
|
307
|
+
else:
|
|
308
|
+
edge_data.update(self.attributes or {})
|
|
309
|
+
result = await driver.execute_query(
|
|
310
|
+
get_entity_edge_save_query(driver.provider),
|
|
311
|
+
edge_data=edge_data,
|
|
312
|
+
)
|
|
258
313
|
|
|
259
314
|
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
|
260
315
|
|
|
@@ -262,37 +317,84 @@ class EntityEdge(Edge):
|
|
|
262
317
|
|
|
263
318
|
@classmethod
|
|
264
319
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
320
|
+
match_query = """
|
|
321
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
322
|
+
"""
|
|
323
|
+
if driver.provider == GraphProvider.KUZU:
|
|
324
|
+
match_query = """
|
|
325
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
|
|
326
|
+
"""
|
|
327
|
+
|
|
265
328
|
records, _, _ = await driver.execute_query(
|
|
329
|
+
match_query
|
|
330
|
+
+ """
|
|
331
|
+
RETURN
|
|
266
332
|
"""
|
|
267
|
-
|
|
268
|
-
"""
|
|
269
|
-
+ ENTITY_EDGE_RETURN,
|
|
333
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
270
334
|
uuid=uuid,
|
|
271
335
|
routing_='r',
|
|
272
336
|
)
|
|
273
337
|
|
|
274
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
338
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
275
339
|
|
|
276
340
|
if len(edges) == 0:
|
|
277
341
|
raise EdgeNotFoundError(uuid)
|
|
278
342
|
return edges[0]
|
|
279
343
|
|
|
344
|
+
@classmethod
|
|
345
|
+
async def get_between_nodes(
|
|
346
|
+
cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
|
|
347
|
+
):
|
|
348
|
+
match_query = """
|
|
349
|
+
MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
|
350
|
+
"""
|
|
351
|
+
if driver.provider == GraphProvider.KUZU:
|
|
352
|
+
match_query = """
|
|
353
|
+
MATCH (n:Entity {uuid: $source_node_uuid})
|
|
354
|
+
-[:RELATES_TO]->(e:RelatesToNode_)
|
|
355
|
+
-[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
|
356
|
+
"""
|
|
357
|
+
|
|
358
|
+
records, _, _ = await driver.execute_query(
|
|
359
|
+
match_query
|
|
360
|
+
+ """
|
|
361
|
+
RETURN
|
|
362
|
+
"""
|
|
363
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
364
|
+
source_node_uuid=source_node_uuid,
|
|
365
|
+
target_node_uuid=target_node_uuid,
|
|
366
|
+
routing_='r',
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
370
|
+
|
|
371
|
+
return edges
|
|
372
|
+
|
|
280
373
|
@classmethod
|
|
281
374
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
282
375
|
if len(uuids) == 0:
|
|
283
376
|
return []
|
|
284
377
|
|
|
378
|
+
match_query = """
|
|
379
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
380
|
+
"""
|
|
381
|
+
if driver.provider == GraphProvider.KUZU:
|
|
382
|
+
match_query = """
|
|
383
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
384
|
+
"""
|
|
385
|
+
|
|
285
386
|
records, _, _ = await driver.execute_query(
|
|
387
|
+
match_query
|
|
388
|
+
+ """
|
|
389
|
+
WHERE e.uuid IN $uuids
|
|
390
|
+
RETURN
|
|
286
391
|
"""
|
|
287
|
-
|
|
288
|
-
WHERE e.uuid IN $uuids
|
|
289
|
-
"""
|
|
290
|
-
+ ENTITY_EDGE_RETURN,
|
|
392
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
291
393
|
uuids=uuids,
|
|
292
394
|
routing_='r',
|
|
293
395
|
)
|
|
294
396
|
|
|
295
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
397
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
296
398
|
|
|
297
399
|
return edges
|
|
298
400
|
|
|
@@ -303,20 +405,40 @@ class EntityEdge(Edge):
|
|
|
303
405
|
group_ids: list[str],
|
|
304
406
|
limit: int | None = None,
|
|
305
407
|
uuid_cursor: str | None = None,
|
|
408
|
+
with_embeddings: bool = False,
|
|
306
409
|
):
|
|
307
410
|
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
|
|
308
411
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
412
|
+
with_embeddings_query: LiteralString = (
|
|
413
|
+
""",
|
|
414
|
+
e.fact_embedding AS fact_embedding
|
|
415
|
+
"""
|
|
416
|
+
if with_embeddings
|
|
417
|
+
else ''
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
match_query = """
|
|
421
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
422
|
+
"""
|
|
423
|
+
if driver.provider == GraphProvider.KUZU:
|
|
424
|
+
match_query = """
|
|
425
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
426
|
+
"""
|
|
309
427
|
|
|
310
428
|
records, _, _ = await driver.execute_query(
|
|
429
|
+
match_query
|
|
430
|
+
+ """
|
|
431
|
+
WHERE e.group_id IN $group_ids
|
|
311
432
|
"""
|
|
312
|
-
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
313
|
-
WHERE e.group_id IN $group_ids
|
|
314
|
-
"""
|
|
315
433
|
+ cursor_query
|
|
316
|
-
+ ENTITY_EDGE_RETURN
|
|
317
434
|
+ """
|
|
318
|
-
|
|
319
|
-
|
|
435
|
+
RETURN
|
|
436
|
+
"""
|
|
437
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
438
|
+
+ with_embeddings_query
|
|
439
|
+
+ """
|
|
440
|
+
ORDER BY e.uuid DESC
|
|
441
|
+
"""
|
|
320
442
|
+ limit_query,
|
|
321
443
|
group_ids=group_ids,
|
|
322
444
|
uuid=uuid_cursor,
|
|
@@ -324,7 +446,7 @@ class EntityEdge(Edge):
|
|
|
324
446
|
routing_='r',
|
|
325
447
|
)
|
|
326
448
|
|
|
327
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
449
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
328
450
|
|
|
329
451
|
if len(edges) == 0:
|
|
330
452
|
raise GroupsEdgesNotFoundError(group_ids)
|
|
@@ -332,15 +454,25 @@ class EntityEdge(Edge):
|
|
|
332
454
|
|
|
333
455
|
@classmethod
|
|
334
456
|
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
|
335
|
-
|
|
457
|
+
match_query = """
|
|
458
|
+
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
|
459
|
+
"""
|
|
460
|
+
if driver.provider == GraphProvider.KUZU:
|
|
461
|
+
match_query = """
|
|
462
|
+
MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
463
|
+
"""
|
|
464
|
+
|
|
465
|
+
records, _, _ = await driver.execute_query(
|
|
466
|
+
match_query
|
|
467
|
+
+ """
|
|
468
|
+
RETURN
|
|
336
469
|
"""
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
470
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
471
|
+
node_uuid=node_uuid,
|
|
472
|
+
routing_='r',
|
|
340
473
|
)
|
|
341
|
-
records, _, _ = await driver.execute_query(query, node_uuid=node_uuid, routing_='r')
|
|
342
474
|
|
|
343
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
475
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
344
476
|
|
|
345
477
|
return edges
|
|
346
478
|
|
|
@@ -348,7 +480,7 @@ class EntityEdge(Edge):
|
|
|
348
480
|
class CommunityEdge(Edge):
|
|
349
481
|
async def save(self, driver: GraphDriver):
|
|
350
482
|
result = await driver.execute_query(
|
|
351
|
-
|
|
483
|
+
get_community_edge_save_query(driver.provider),
|
|
352
484
|
community_uuid=self.source_node_uuid,
|
|
353
485
|
entity_uuid=self.target_node_uuid,
|
|
354
486
|
uuid=self.uuid,
|
|
@@ -364,14 +496,10 @@ class CommunityEdge(Edge):
|
|
|
364
496
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
365
497
|
records, _, _ = await driver.execute_query(
|
|
366
498
|
"""
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
n.uuid AS source_node_uuid,
|
|
372
|
-
m.uuid AS target_node_uuid,
|
|
373
|
-
e.created_at AS created_at
|
|
374
|
-
""",
|
|
499
|
+
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
|
|
500
|
+
RETURN
|
|
501
|
+
"""
|
|
502
|
+
+ COMMUNITY_EDGE_RETURN,
|
|
375
503
|
uuid=uuid,
|
|
376
504
|
routing_='r',
|
|
377
505
|
)
|
|
@@ -384,15 +512,11 @@ class CommunityEdge(Edge):
|
|
|
384
512
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
385
513
|
records, _, _ = await driver.execute_query(
|
|
386
514
|
"""
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
n.uuid AS source_node_uuid,
|
|
393
|
-
m.uuid AS target_node_uuid,
|
|
394
|
-
e.created_at AS created_at
|
|
395
|
-
""",
|
|
515
|
+
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
|
516
|
+
WHERE e.uuid IN $uuids
|
|
517
|
+
RETURN
|
|
518
|
+
"""
|
|
519
|
+
+ COMMUNITY_EDGE_RETURN,
|
|
396
520
|
uuids=uuids,
|
|
397
521
|
routing_='r',
|
|
398
522
|
)
|
|
@@ -414,19 +538,17 @@ class CommunityEdge(Edge):
|
|
|
414
538
|
|
|
415
539
|
records, _, _ = await driver.execute_query(
|
|
416
540
|
"""
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
541
|
+
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
|
|
542
|
+
WHERE e.group_id IN $group_ids
|
|
543
|
+
"""
|
|
420
544
|
+ cursor_query
|
|
421
545
|
+ """
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
ORDER BY e.uuid DESC
|
|
429
|
-
"""
|
|
546
|
+
RETURN
|
|
547
|
+
"""
|
|
548
|
+
+ COMMUNITY_EDGE_RETURN
|
|
549
|
+
+ """
|
|
550
|
+
ORDER BY e.uuid DESC
|
|
551
|
+
"""
|
|
430
552
|
+ limit_query,
|
|
431
553
|
group_ids=group_ids,
|
|
432
554
|
uuid=uuid_cursor,
|
|
@@ -450,34 +572,41 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
|
450
572
|
)
|
|
451
573
|
|
|
452
574
|
|
|
453
|
-
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
575
|
+
def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
|
|
576
|
+
episodes = record['episodes']
|
|
577
|
+
if provider == GraphProvider.KUZU:
|
|
578
|
+
attributes = json.loads(record['attributes']) if record['attributes'] else {}
|
|
579
|
+
else:
|
|
580
|
+
attributes = record['attributes']
|
|
581
|
+
attributes.pop('uuid', None)
|
|
582
|
+
attributes.pop('source_node_uuid', None)
|
|
583
|
+
attributes.pop('target_node_uuid', None)
|
|
584
|
+
attributes.pop('fact', None)
|
|
585
|
+
attributes.pop('fact_embedding', None)
|
|
586
|
+
attributes.pop('name', None)
|
|
587
|
+
attributes.pop('group_id', None)
|
|
588
|
+
attributes.pop('episodes', None)
|
|
589
|
+
attributes.pop('created_at', None)
|
|
590
|
+
attributes.pop('expired_at', None)
|
|
591
|
+
attributes.pop('valid_at', None)
|
|
592
|
+
attributes.pop('invalid_at', None)
|
|
593
|
+
|
|
454
594
|
edge = EntityEdge(
|
|
455
595
|
uuid=record['uuid'],
|
|
456
596
|
source_node_uuid=record['source_node_uuid'],
|
|
457
597
|
target_node_uuid=record['target_node_uuid'],
|
|
458
598
|
fact=record['fact'],
|
|
599
|
+
fact_embedding=record.get('fact_embedding'),
|
|
459
600
|
name=record['name'],
|
|
460
601
|
group_id=record['group_id'],
|
|
461
|
-
episodes=
|
|
602
|
+
episodes=episodes,
|
|
462
603
|
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
463
604
|
expired_at=parse_db_date(record['expired_at']),
|
|
464
605
|
valid_at=parse_db_date(record['valid_at']),
|
|
465
606
|
invalid_at=parse_db_date(record['invalid_at']),
|
|
466
|
-
attributes=
|
|
607
|
+
attributes=attributes,
|
|
467
608
|
)
|
|
468
609
|
|
|
469
|
-
edge.attributes.pop('uuid', None)
|
|
470
|
-
edge.attributes.pop('source_node_uuid', None)
|
|
471
|
-
edge.attributes.pop('target_node_uuid', None)
|
|
472
|
-
edge.attributes.pop('fact', None)
|
|
473
|
-
edge.attributes.pop('name', None)
|
|
474
|
-
edge.attributes.pop('group_id', None)
|
|
475
|
-
edge.attributes.pop('episodes', None)
|
|
476
|
-
edge.attributes.pop('created_at', None)
|
|
477
|
-
edge.attributes.pop('expired_at', None)
|
|
478
|
-
edge.attributes.pop('valid_at', None)
|
|
479
|
-
edge.attributes.pop('invalid_at', None)
|
|
480
|
-
|
|
481
610
|
return edge
|
|
482
611
|
|
|
483
612
|
|
|
@@ -492,8 +621,11 @@ def get_community_edge_from_record(record: Any):
|
|
|
492
621
|
|
|
493
622
|
|
|
494
623
|
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
|
|
495
|
-
|
|
624
|
+
# filter out falsey values from edges
|
|
625
|
+
filtered_edges = [edge for edge in edges if edge.fact]
|
|
626
|
+
|
|
627
|
+
if len(filtered_edges) == 0:
|
|
496
628
|
return
|
|
497
|
-
fact_embeddings = await embedder.create_batch([edge.fact for edge in
|
|
498
|
-
for edge, fact_embedding in zip(
|
|
629
|
+
fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
|
|
630
|
+
for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
|
|
499
631
|
edge.fact_embedding = fact_embedding
|
|
@@ -17,7 +17,7 @@ limitations under the License.
|
|
|
17
17
|
import logging
|
|
18
18
|
from typing import Any
|
|
19
19
|
|
|
20
|
-
from openai import AsyncAzureOpenAI
|
|
20
|
+
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
|
21
21
|
|
|
22
22
|
from .client import EmbedderClient
|
|
23
23
|
|
|
@@ -25,9 +25,16 @@ logger = logging.getLogger(__name__)
|
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class AzureOpenAIEmbedderClient(EmbedderClient):
|
|
28
|
-
"""Wrapper class for
|
|
28
|
+
"""Wrapper class for Azure OpenAI that implements the EmbedderClient interface.
|
|
29
29
|
|
|
30
|
-
|
|
30
|
+
Supports both AsyncAzureOpenAI and AsyncOpenAI (with Azure v1 API endpoint).
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
azure_client: AsyncAzureOpenAI | AsyncOpenAI,
|
|
36
|
+
model: str = 'text-embedding-3-small',
|
|
37
|
+
):
|
|
31
38
|
self.azure_client = azure_client
|
|
32
39
|
self.model = model
|
|
33
40
|
|
graphiti_core/embedder/client.py
CHANGED
|
@@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import os
|
|
17
18
|
from abc import ABC, abstractmethod
|
|
18
19
|
from collections.abc import Iterable
|
|
19
20
|
|
|
20
21
|
from pydantic import BaseModel, Field
|
|
21
22
|
|
|
22
|
-
EMBEDDING_DIM = 1024
|
|
23
|
+
EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024))
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class EmbedderConfig(BaseModel):
|