graphiti-core 0.18.9__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +4 -0
- graphiti_core/driver/falkordb_driver.py +3 -14
- graphiti_core/driver/kuzu_driver.py +175 -0
- graphiti_core/driver/neptune_driver.py +301 -0
- graphiti_core/edges.py +155 -62
- graphiti_core/graph_queries.py +31 -2
- graphiti_core/graphiti.py +6 -1
- graphiti_core/helpers.py +8 -8
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/openai_base_client.py +12 -2
- graphiti_core/llm_client/openai_client.py +10 -2
- graphiti_core/migrations/__init__.py +0 -0
- graphiti_core/migrations/neo4j_node_group_labels.py +114 -0
- graphiti_core/models/edges/edge_db_queries.py +205 -76
- graphiti_core/models/nodes/node_db_queries.py +253 -74
- graphiti_core/nodes.py +271 -98
- graphiti_core/search/search.py +42 -12
- graphiti_core/search/search_config.py +4 -0
- graphiti_core/search/search_filters.py +35 -22
- graphiti_core/search/search_utils.py +1329 -392
- graphiti_core/utils/bulk_utils.py +50 -15
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +39 -32
- graphiti_core/utils/maintenance/edge_operations.py +47 -13
- graphiti_core/utils/maintenance/graph_data_operations.py +100 -15
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/METADATA +87 -13
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/RECORD +29 -25
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/edges.py
CHANGED
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from datetime import datetime
|
|
@@ -24,16 +25,16 @@ from uuid import uuid4
|
|
|
24
25
|
from pydantic import BaseModel, Field
|
|
25
26
|
from typing_extensions import LiteralString
|
|
26
27
|
|
|
27
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
28
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
28
29
|
from graphiti_core.embedder import EmbedderClient
|
|
29
30
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
30
31
|
from graphiti_core.helpers import parse_db_date
|
|
31
32
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
32
33
|
COMMUNITY_EDGE_RETURN,
|
|
33
|
-
ENTITY_EDGE_RETURN,
|
|
34
34
|
EPISODIC_EDGE_RETURN,
|
|
35
35
|
EPISODIC_EDGE_SAVE,
|
|
36
36
|
get_community_edge_save_query,
|
|
37
|
+
get_entity_edge_return_query,
|
|
37
38
|
get_entity_edge_save_query,
|
|
38
39
|
)
|
|
39
40
|
from graphiti_core.nodes import Node
|
|
@@ -52,33 +53,63 @@ class Edge(BaseModel, ABC):
|
|
|
52
53
|
async def save(self, driver: GraphDriver): ...
|
|
53
54
|
|
|
54
55
|
async def delete(self, driver: GraphDriver):
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
56
|
+
if driver.provider == GraphProvider.KUZU:
|
|
57
|
+
await driver.execute_query(
|
|
58
|
+
"""
|
|
59
|
+
MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
60
|
+
DELETE e
|
|
61
|
+
""",
|
|
62
|
+
uuid=self.uuid,
|
|
63
|
+
)
|
|
64
|
+
await driver.execute_query(
|
|
65
|
+
"""
|
|
66
|
+
MATCH (e:RelatesToNode_ {uuid: $uuid})
|
|
67
|
+
DETACH DELETE e
|
|
68
|
+
""",
|
|
69
|
+
uuid=self.uuid,
|
|
70
|
+
)
|
|
71
|
+
else:
|
|
72
|
+
await driver.execute_query(
|
|
73
|
+
"""
|
|
74
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
75
|
+
DELETE e
|
|
76
|
+
""",
|
|
77
|
+
uuid=self.uuid,
|
|
78
|
+
)
|
|
62
79
|
|
|
63
80
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
|
64
81
|
|
|
65
|
-
return result
|
|
66
|
-
|
|
67
82
|
@classmethod
|
|
68
83
|
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
84
|
+
if driver.provider == GraphProvider.KUZU:
|
|
85
|
+
await driver.execute_query(
|
|
86
|
+
"""
|
|
87
|
+
MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
|
|
88
|
+
WHERE e.uuid IN $uuids
|
|
89
|
+
DELETE e
|
|
90
|
+
""",
|
|
91
|
+
uuids=uuids,
|
|
92
|
+
)
|
|
93
|
+
await driver.execute_query(
|
|
94
|
+
"""
|
|
95
|
+
MATCH (e:RelatesToNode_)
|
|
96
|
+
WHERE e.uuid IN $uuids
|
|
97
|
+
DETACH DELETE e
|
|
98
|
+
""",
|
|
99
|
+
uuids=uuids,
|
|
100
|
+
)
|
|
101
|
+
else:
|
|
102
|
+
await driver.execute_query(
|
|
103
|
+
"""
|
|
104
|
+
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
|
|
105
|
+
WHERE e.uuid IN $uuids
|
|
106
|
+
DELETE e
|
|
107
|
+
""",
|
|
108
|
+
uuids=uuids,
|
|
109
|
+
)
|
|
77
110
|
|
|
78
111
|
logger.debug(f'Deleted Edges: {uuids}')
|
|
79
112
|
|
|
80
|
-
return result
|
|
81
|
-
|
|
82
113
|
def __hash__(self):
|
|
83
114
|
return hash(self.uuid)
|
|
84
115
|
|
|
@@ -165,7 +196,7 @@ class EpisodicEdge(Edge):
|
|
|
165
196
|
"""
|
|
166
197
|
+ EPISODIC_EDGE_RETURN
|
|
167
198
|
+ """
|
|
168
|
-
ORDER BY e.uuid DESC
|
|
199
|
+
ORDER BY e.uuid DESC
|
|
169
200
|
"""
|
|
170
201
|
+ limit_query,
|
|
171
202
|
group_ids=group_ids,
|
|
@@ -214,11 +245,25 @@ class EntityEdge(Edge):
|
|
|
214
245
|
return self.fact_embedding
|
|
215
246
|
|
|
216
247
|
async def load_fact_embedding(self, driver: GraphDriver):
|
|
217
|
-
|
|
218
|
-
"""
|
|
248
|
+
query = """
|
|
219
249
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
220
250
|
RETURN e.fact_embedding AS fact_embedding
|
|
221
|
-
|
|
251
|
+
"""
|
|
252
|
+
|
|
253
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
254
|
+
query = """
|
|
255
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
256
|
+
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
|
257
|
+
"""
|
|
258
|
+
|
|
259
|
+
if driver.provider == GraphProvider.KUZU:
|
|
260
|
+
query = """
|
|
261
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
|
|
262
|
+
RETURN e.fact_embedding AS fact_embedding
|
|
263
|
+
"""
|
|
264
|
+
|
|
265
|
+
records, _, _ = await driver.execute_query(
|
|
266
|
+
query,
|
|
222
267
|
uuid=self.uuid,
|
|
223
268
|
routing_='r',
|
|
224
269
|
)
|
|
@@ -244,12 +289,22 @@ class EntityEdge(Edge):
|
|
|
244
289
|
'invalid_at': self.invalid_at,
|
|
245
290
|
}
|
|
246
291
|
|
|
247
|
-
|
|
292
|
+
if driver.provider == GraphProvider.KUZU:
|
|
293
|
+
edge_data['attributes'] = json.dumps(self.attributes)
|
|
294
|
+
result = await driver.execute_query(
|
|
295
|
+
get_entity_edge_save_query(driver.provider),
|
|
296
|
+
**edge_data,
|
|
297
|
+
)
|
|
298
|
+
else:
|
|
299
|
+
edge_data.update(self.attributes or {})
|
|
248
300
|
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
301
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
302
|
+
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
|
303
|
+
|
|
304
|
+
result = await driver.execute_query(
|
|
305
|
+
get_entity_edge_save_query(driver.provider),
|
|
306
|
+
edge_data=edge_data,
|
|
307
|
+
)
|
|
253
308
|
|
|
254
309
|
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
|
255
310
|
|
|
@@ -257,17 +312,25 @@ class EntityEdge(Edge):
|
|
|
257
312
|
|
|
258
313
|
@classmethod
|
|
259
314
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
260
|
-
|
|
261
|
-
"""
|
|
315
|
+
match_query = """
|
|
262
316
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
317
|
+
"""
|
|
318
|
+
if driver.provider == GraphProvider.KUZU:
|
|
319
|
+
match_query = """
|
|
320
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
|
|
321
|
+
"""
|
|
322
|
+
|
|
323
|
+
records, _, _ = await driver.execute_query(
|
|
324
|
+
match_query
|
|
325
|
+
+ """
|
|
263
326
|
RETURN
|
|
264
327
|
"""
|
|
265
|
-
+
|
|
328
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
266
329
|
uuid=uuid,
|
|
267
330
|
routing_='r',
|
|
268
331
|
)
|
|
269
332
|
|
|
270
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
333
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
271
334
|
|
|
272
335
|
if len(edges) == 0:
|
|
273
336
|
raise EdgeNotFoundError(uuid)
|
|
@@ -278,18 +341,26 @@ class EntityEdge(Edge):
|
|
|
278
341
|
if len(uuids) == 0:
|
|
279
342
|
return []
|
|
280
343
|
|
|
281
|
-
|
|
282
|
-
"""
|
|
344
|
+
match_query = """
|
|
283
345
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
346
|
+
"""
|
|
347
|
+
if driver.provider == GraphProvider.KUZU:
|
|
348
|
+
match_query = """
|
|
349
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
350
|
+
"""
|
|
351
|
+
|
|
352
|
+
records, _, _ = await driver.execute_query(
|
|
353
|
+
match_query
|
|
354
|
+
+ """
|
|
284
355
|
WHERE e.uuid IN $uuids
|
|
285
356
|
RETURN
|
|
286
357
|
"""
|
|
287
|
-
+
|
|
358
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
288
359
|
uuids=uuids,
|
|
289
360
|
routing_='r',
|
|
290
361
|
)
|
|
291
362
|
|
|
292
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
363
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
293
364
|
|
|
294
365
|
return edges
|
|
295
366
|
|
|
@@ -312,19 +383,27 @@ class EntityEdge(Edge):
|
|
|
312
383
|
else ''
|
|
313
384
|
)
|
|
314
385
|
|
|
315
|
-
|
|
316
|
-
"""
|
|
386
|
+
match_query = """
|
|
317
387
|
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
388
|
+
"""
|
|
389
|
+
if driver.provider == GraphProvider.KUZU:
|
|
390
|
+
match_query = """
|
|
391
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
392
|
+
"""
|
|
393
|
+
|
|
394
|
+
records, _, _ = await driver.execute_query(
|
|
395
|
+
match_query
|
|
396
|
+
+ """
|
|
318
397
|
WHERE e.group_id IN $group_ids
|
|
319
398
|
"""
|
|
320
399
|
+ cursor_query
|
|
321
400
|
+ """
|
|
322
401
|
RETURN
|
|
323
402
|
"""
|
|
324
|
-
+
|
|
403
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
325
404
|
+ with_embeddings_query
|
|
326
405
|
+ """
|
|
327
|
-
ORDER BY e.uuid DESC
|
|
406
|
+
ORDER BY e.uuid DESC
|
|
328
407
|
"""
|
|
329
408
|
+ limit_query,
|
|
330
409
|
group_ids=group_ids,
|
|
@@ -333,7 +412,7 @@ class EntityEdge(Edge):
|
|
|
333
412
|
routing_='r',
|
|
334
413
|
)
|
|
335
414
|
|
|
336
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
415
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
337
416
|
|
|
338
417
|
if len(edges) == 0:
|
|
339
418
|
raise GroupsEdgesNotFoundError(group_ids)
|
|
@@ -341,17 +420,25 @@ class EntityEdge(Edge):
|
|
|
341
420
|
|
|
342
421
|
@classmethod
|
|
343
422
|
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
|
344
|
-
|
|
345
|
-
"""
|
|
423
|
+
match_query = """
|
|
346
424
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
|
425
|
+
"""
|
|
426
|
+
if driver.provider == GraphProvider.KUZU:
|
|
427
|
+
match_query = """
|
|
428
|
+
MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
429
|
+
"""
|
|
430
|
+
|
|
431
|
+
records, _, _ = await driver.execute_query(
|
|
432
|
+
match_query
|
|
433
|
+
+ """
|
|
347
434
|
RETURN
|
|
348
435
|
"""
|
|
349
|
-
+
|
|
436
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
350
437
|
node_uuid=node_uuid,
|
|
351
438
|
routing_='r',
|
|
352
439
|
)
|
|
353
440
|
|
|
354
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
441
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
355
442
|
|
|
356
443
|
return edges
|
|
357
444
|
|
|
@@ -451,7 +538,25 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
|
451
538
|
)
|
|
452
539
|
|
|
453
540
|
|
|
454
|
-
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
541
|
+
def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
|
|
542
|
+
episodes = record['episodes']
|
|
543
|
+
if provider == GraphProvider.KUZU:
|
|
544
|
+
attributes = json.loads(record['attributes']) if record['attributes'] else {}
|
|
545
|
+
else:
|
|
546
|
+
attributes = record['attributes']
|
|
547
|
+
attributes.pop('uuid', None)
|
|
548
|
+
attributes.pop('source_node_uuid', None)
|
|
549
|
+
attributes.pop('target_node_uuid', None)
|
|
550
|
+
attributes.pop('fact', None)
|
|
551
|
+
attributes.pop('fact_embedding', None)
|
|
552
|
+
attributes.pop('name', None)
|
|
553
|
+
attributes.pop('group_id', None)
|
|
554
|
+
attributes.pop('episodes', None)
|
|
555
|
+
attributes.pop('created_at', None)
|
|
556
|
+
attributes.pop('expired_at', None)
|
|
557
|
+
attributes.pop('valid_at', None)
|
|
558
|
+
attributes.pop('invalid_at', None)
|
|
559
|
+
|
|
455
560
|
edge = EntityEdge(
|
|
456
561
|
uuid=record['uuid'],
|
|
457
562
|
source_node_uuid=record['source_node_uuid'],
|
|
@@ -460,26 +565,14 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
|
460
565
|
fact_embedding=record.get('fact_embedding'),
|
|
461
566
|
name=record['name'],
|
|
462
567
|
group_id=record['group_id'],
|
|
463
|
-
episodes=
|
|
568
|
+
episodes=episodes,
|
|
464
569
|
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
465
570
|
expired_at=parse_db_date(record['expired_at']),
|
|
466
571
|
valid_at=parse_db_date(record['valid_at']),
|
|
467
572
|
invalid_at=parse_db_date(record['invalid_at']),
|
|
468
|
-
attributes=
|
|
573
|
+
attributes=attributes,
|
|
469
574
|
)
|
|
470
575
|
|
|
471
|
-
edge.attributes.pop('uuid', None)
|
|
472
|
-
edge.attributes.pop('source_node_uuid', None)
|
|
473
|
-
edge.attributes.pop('target_node_uuid', None)
|
|
474
|
-
edge.attributes.pop('fact', None)
|
|
475
|
-
edge.attributes.pop('name', None)
|
|
476
|
-
edge.attributes.pop('group_id', None)
|
|
477
|
-
edge.attributes.pop('episodes', None)
|
|
478
|
-
edge.attributes.pop('created_at', None)
|
|
479
|
-
edge.attributes.pop('expired_at', None)
|
|
480
|
-
edge.attributes.pop('valid_at', None)
|
|
481
|
-
edge.attributes.pop('invalid_at', None)
|
|
482
|
-
|
|
483
576
|
return edge
|
|
484
577
|
|
|
485
578
|
|
graphiti_core/graph_queries.py
CHANGED
|
@@ -16,6 +16,13 @@ NEO4J_TO_FALKORDB_MAPPING = {
|
|
|
16
16
|
'episode_content': 'Episodic',
|
|
17
17
|
'edge_name_and_fact': 'RELATES_TO',
|
|
18
18
|
}
|
|
19
|
+
# Mapping from fulltext index names to Kuzu node labels
|
|
20
|
+
INDEX_TO_LABEL_KUZU_MAPPING = {
|
|
21
|
+
'node_name_and_summary': 'Entity',
|
|
22
|
+
'community_name': 'Community',
|
|
23
|
+
'episode_content': 'Episodic',
|
|
24
|
+
'edge_name_and_fact': 'RelatesToNode_',
|
|
25
|
+
}
|
|
19
26
|
|
|
20
27
|
|
|
21
28
|
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
@@ -35,6 +42,9 @@ def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
|
35
42
|
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
|
|
36
43
|
]
|
|
37
44
|
|
|
45
|
+
if provider == GraphProvider.KUZU:
|
|
46
|
+
return []
|
|
47
|
+
|
|
38
48
|
return [
|
|
39
49
|
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
|
|
40
50
|
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
|
|
@@ -68,6 +78,14 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
|
68
78
|
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
|
|
69
79
|
]
|
|
70
80
|
|
|
81
|
+
if provider == GraphProvider.KUZU:
|
|
82
|
+
return [
|
|
83
|
+
"CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
|
|
84
|
+
"CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
|
|
85
|
+
"CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
|
|
86
|
+
"CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
|
|
87
|
+
]
|
|
88
|
+
|
|
71
89
|
return [
|
|
72
90
|
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
|
73
91
|
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
|
|
@@ -80,11 +98,15 @@ def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
|
|
|
80
98
|
]
|
|
81
99
|
|
|
82
100
|
|
|
83
|
-
def get_nodes_query(
|
|
101
|
+
def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
|
|
84
102
|
if provider == GraphProvider.FALKORDB:
|
|
85
103
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
86
104
|
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
|
|
87
105
|
|
|
106
|
+
if provider == GraphProvider.KUZU:
|
|
107
|
+
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
|
108
|
+
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
|
|
109
|
+
|
|
88
110
|
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
|
|
89
111
|
|
|
90
112
|
|
|
@@ -93,12 +115,19 @@ def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
|
|
|
93
115
|
# FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity
|
|
94
116
|
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
|
|
95
117
|
|
|
118
|
+
if provider == GraphProvider.KUZU:
|
|
119
|
+
return f'array_cosine_similarity({vec1}, {vec2})'
|
|
120
|
+
|
|
96
121
|
return f'vector.similarity.cosine({vec1}, {vec2})'
|
|
97
122
|
|
|
98
123
|
|
|
99
|
-
def get_relationships_query(name: str, provider: GraphProvider) -> str:
|
|
124
|
+
def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
|
|
100
125
|
if provider == GraphProvider.FALKORDB:
|
|
101
126
|
label = NEO4J_TO_FALKORDB_MAPPING[name]
|
|
102
127
|
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
|
|
103
128
|
|
|
129
|
+
if provider == GraphProvider.KUZU:
|
|
130
|
+
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
|
|
131
|
+
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
|
|
132
|
+
|
|
104
133
|
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
|
graphiti_core/graphiti.py
CHANGED
|
@@ -89,6 +89,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|
|
89
89
|
)
|
|
90
90
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
91
91
|
EPISODE_WINDOW_LEN,
|
|
92
|
+
build_dynamic_indexes,
|
|
92
93
|
build_indices_and_constraints,
|
|
93
94
|
retrieve_episodes,
|
|
94
95
|
)
|
|
@@ -450,6 +451,7 @@ class Graphiti:
|
|
|
450
451
|
|
|
451
452
|
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
|
452
453
|
validate_group_id(group_id)
|
|
454
|
+
await build_dynamic_indexes(self.driver, group_id)
|
|
453
455
|
|
|
454
456
|
previous_episodes = (
|
|
455
457
|
await self.retrieve_episodes(
|
|
@@ -625,6 +627,7 @@ class Graphiti:
|
|
|
625
627
|
# if group_id is None, use the default group id by the provider
|
|
626
628
|
group_id = group_id or get_default_group_id(self.driver.provider)
|
|
627
629
|
validate_group_id(group_id)
|
|
630
|
+
await build_dynamic_indexes(self.driver, group_id)
|
|
628
631
|
|
|
629
632
|
# Create default edge type map
|
|
630
633
|
edge_type_map_default = (
|
|
@@ -1006,6 +1009,8 @@ class Graphiti:
|
|
|
1006
1009
|
if edge.fact_embedding is None:
|
|
1007
1010
|
await edge.generate_embedding(self.embedder)
|
|
1008
1011
|
|
|
1012
|
+
await build_dynamic_indexes(self.driver, source_node.group_id)
|
|
1013
|
+
|
|
1009
1014
|
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
|
1010
1015
|
self.clients,
|
|
1011
1016
|
[source_node, target_node],
|
|
@@ -1068,7 +1073,7 @@ class Graphiti:
|
|
|
1068
1073
|
if record['episode_count'] == 1:
|
|
1069
1074
|
nodes_to_delete.append(node)
|
|
1070
1075
|
|
|
1076
|
+
await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
|
|
1071
1077
|
await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
|
|
1072
1078
|
|
|
1073
|
-
await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
|
|
1074
1079
|
await episode.delete(self.driver)
|
graphiti_core/helpers.py
CHANGED
|
@@ -43,14 +43,14 @@ RUNTIME_QUERY: LiteralString = (
|
|
|
43
43
|
)
|
|
44
44
|
|
|
45
45
|
|
|
46
|
-
def parse_db_date(
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
46
|
+
def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
|
|
47
|
+
if isinstance(input_date, neo4j_time.DateTime):
|
|
48
|
+
return input_date.to_native()
|
|
49
|
+
|
|
50
|
+
if isinstance(input_date, str):
|
|
51
|
+
return datetime.fromisoformat(input_date)
|
|
52
|
+
|
|
53
|
+
return input_date
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
def get_default_group_id(provider: GraphProvider) -> str:
|
|
@@ -31,8 +31,10 @@ from .errors import RateLimitError, RefusalError
|
|
|
31
31
|
|
|
32
32
|
logger = logging.getLogger(__name__)
|
|
33
33
|
|
|
34
|
-
DEFAULT_MODEL = 'gpt-
|
|
35
|
-
DEFAULT_SMALL_MODEL = 'gpt-
|
|
34
|
+
DEFAULT_MODEL = 'gpt-5-mini'
|
|
35
|
+
DEFAULT_SMALL_MODEL = 'gpt-5-nano'
|
|
36
|
+
DEFAULT_REASONING = 'minimal'
|
|
37
|
+
DEFAULT_VERBOSITY = 'low'
|
|
36
38
|
|
|
37
39
|
|
|
38
40
|
class BaseOpenAIClient(LLMClient):
|
|
@@ -51,6 +53,8 @@ class BaseOpenAIClient(LLMClient):
|
|
|
51
53
|
config: LLMConfig | None = None,
|
|
52
54
|
cache: bool = False,
|
|
53
55
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
|
56
|
+
reasoning: str | None = DEFAULT_REASONING,
|
|
57
|
+
verbosity: str | None = DEFAULT_VERBOSITY,
|
|
54
58
|
):
|
|
55
59
|
if cache:
|
|
56
60
|
raise NotImplementedError('Caching is not implemented for OpenAI-based clients')
|
|
@@ -60,6 +64,8 @@ class BaseOpenAIClient(LLMClient):
|
|
|
60
64
|
|
|
61
65
|
super().__init__(config, cache)
|
|
62
66
|
self.max_tokens = max_tokens
|
|
67
|
+
self.reasoning = reasoning
|
|
68
|
+
self.verbosity = verbosity
|
|
63
69
|
|
|
64
70
|
@abstractmethod
|
|
65
71
|
async def _create_completion(
|
|
@@ -81,6 +87,8 @@ class BaseOpenAIClient(LLMClient):
|
|
|
81
87
|
temperature: float | None,
|
|
82
88
|
max_tokens: int,
|
|
83
89
|
response_model: type[BaseModel],
|
|
90
|
+
reasoning: str | None,
|
|
91
|
+
verbosity: str | None,
|
|
84
92
|
) -> Any:
|
|
85
93
|
"""Create a structured completion using the specific client implementation."""
|
|
86
94
|
pass
|
|
@@ -140,6 +148,8 @@ class BaseOpenAIClient(LLMClient):
|
|
|
140
148
|
temperature=self.temperature,
|
|
141
149
|
max_tokens=max_tokens or self.max_tokens,
|
|
142
150
|
response_model=response_model,
|
|
151
|
+
reasoning=self.reasoning,
|
|
152
|
+
verbosity=self.verbosity,
|
|
143
153
|
)
|
|
144
154
|
return self._handle_structured_response(response)
|
|
145
155
|
else:
|
|
@@ -21,7 +21,7 @@ from openai.types.chat import ChatCompletionMessageParam
|
|
|
21
21
|
from pydantic import BaseModel
|
|
22
22
|
|
|
23
23
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
|
24
|
-
from .openai_base_client import BaseOpenAIClient
|
|
24
|
+
from .openai_base_client import DEFAULT_REASONING, DEFAULT_VERBOSITY, BaseOpenAIClient
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class OpenAIClient(BaseOpenAIClient):
|
|
@@ -41,6 +41,8 @@ class OpenAIClient(BaseOpenAIClient):
|
|
|
41
41
|
cache: bool = False,
|
|
42
42
|
client: typing.Any = None,
|
|
43
43
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
|
44
|
+
reasoning: str = DEFAULT_REASONING,
|
|
45
|
+
verbosity: str = DEFAULT_VERBOSITY,
|
|
44
46
|
):
|
|
45
47
|
"""
|
|
46
48
|
Initialize the OpenAIClient with the provided configuration, cache setting, and client.
|
|
@@ -50,7 +52,7 @@ class OpenAIClient(BaseOpenAIClient):
|
|
|
50
52
|
cache (bool): Whether to use caching for responses. Defaults to False.
|
|
51
53
|
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
|
52
54
|
"""
|
|
53
|
-
super().__init__(config, cache, max_tokens)
|
|
55
|
+
super().__init__(config, cache, max_tokens, reasoning, verbosity)
|
|
54
56
|
|
|
55
57
|
if config is None:
|
|
56
58
|
config = LLMConfig()
|
|
@@ -67,6 +69,8 @@ class OpenAIClient(BaseOpenAIClient):
|
|
|
67
69
|
temperature: float | None,
|
|
68
70
|
max_tokens: int,
|
|
69
71
|
response_model: type[BaseModel],
|
|
72
|
+
reasoning: str | None = None,
|
|
73
|
+
verbosity: str | None = None,
|
|
70
74
|
):
|
|
71
75
|
"""Create a structured completion using OpenAI's beta parse API."""
|
|
72
76
|
response = await self.client.responses.parse(
|
|
@@ -75,6 +79,8 @@ class OpenAIClient(BaseOpenAIClient):
|
|
|
75
79
|
temperature=temperature,
|
|
76
80
|
max_output_tokens=max_tokens,
|
|
77
81
|
text_format=response_model, # type: ignore
|
|
82
|
+
reasoning={'effort': reasoning} if reasoning is not None else None, # type: ignore
|
|
83
|
+
text={'verbosity': verbosity} if verbosity is not None else None, # type: ignore
|
|
78
84
|
)
|
|
79
85
|
|
|
80
86
|
return response
|
|
@@ -86,6 +92,8 @@ class OpenAIClient(BaseOpenAIClient):
|
|
|
86
92
|
temperature: float | None,
|
|
87
93
|
max_tokens: int,
|
|
88
94
|
response_model: type[BaseModel] | None = None,
|
|
95
|
+
reasoning: str | None = None,
|
|
96
|
+
verbosity: str | None = None,
|
|
89
97
|
):
|
|
90
98
|
"""Create a regular completion with JSON format."""
|
|
91
99
|
return await self.client.chat.completions.create(
|
|
File without changes
|