graphiti-core 0.18.9__py3-none-any.whl → 0.19.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +1 -0
- graphiti_core/driver/neptune_driver.py +299 -0
- graphiti_core/edges.py +35 -7
- graphiti_core/graphiti.py +2 -0
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/openai_base_client.py +12 -2
- graphiti_core/llm_client/openai_client.py +10 -2
- graphiti_core/migrations/__init__.py +0 -0
- graphiti_core/migrations/neo4j_node_group_labels.py +53 -0
- graphiti_core/models/edges/edge_db_queries.py +104 -54
- graphiti_core/models/nodes/node_db_queries.py +165 -65
- graphiti_core/nodes.py +121 -51
- graphiti_core/search/search.py +29 -9
- graphiti_core/search/search_utils.py +878 -267
- graphiti_core/utils/bulk_utils.py +6 -3
- graphiti_core/utils/maintenance/edge_operations.py +36 -13
- graphiti_core/utils/maintenance/graph_data_operations.py +59 -7
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0rc2.dist-info}/METADATA +44 -6
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0rc2.dist-info}/RECORD +21 -18
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0rc2.dist-info}/WHEEL +0 -0
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0rc2.dist-info}/licenses/LICENSE +0 -0
graphiti_core/nodes.py
CHANGED
|
@@ -31,11 +31,13 @@ from graphiti_core.errors import NodeNotFoundError
|
|
|
31
31
|
from graphiti_core.helpers import parse_db_date
|
|
32
32
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
33
33
|
COMMUNITY_NODE_RETURN,
|
|
34
|
+
COMMUNITY_NODE_RETURN_NEPTUNE,
|
|
34
35
|
ENTITY_NODE_RETURN,
|
|
35
36
|
EPISODIC_NODE_RETURN,
|
|
36
|
-
|
|
37
|
+
EPISODIC_NODE_RETURN_NEPTUNE,
|
|
37
38
|
get_community_node_save_query,
|
|
38
39
|
get_entity_node_save_query,
|
|
40
|
+
get_episode_node_save_query,
|
|
39
41
|
)
|
|
40
42
|
from graphiti_core.utils.datetime_utils import utc_now
|
|
41
43
|
|
|
@@ -89,23 +91,24 @@ class Node(BaseModel, ABC):
|
|
|
89
91
|
async def save(self, driver: GraphDriver): ...
|
|
90
92
|
|
|
91
93
|
async def delete(self, driver: GraphDriver):
|
|
92
|
-
|
|
93
|
-
|
|
94
|
+
match driver.provider:
|
|
95
|
+
case GraphProvider.NEO4J:
|
|
94
96
|
await driver.execute_query(
|
|
95
|
-
|
|
96
|
-
MATCH (n:{label} {{uuid: $uuid}})
|
|
97
|
-
DETACH DELETE n
|
|
98
|
-
""",
|
|
99
|
-
uuid=self.uuid,
|
|
100
|
-
)
|
|
101
|
-
else:
|
|
102
|
-
await driver.execute_query(
|
|
103
|
-
"""
|
|
97
|
+
"""
|
|
104
98
|
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
|
105
99
|
DETACH DELETE n
|
|
106
100
|
""",
|
|
107
|
-
|
|
108
|
-
|
|
101
|
+
uuid=self.uuid,
|
|
102
|
+
)
|
|
103
|
+
case _: # FalkorDB and Neptune
|
|
104
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
105
|
+
await driver.execute_query(
|
|
106
|
+
f"""
|
|
107
|
+
MATCH (n:{label} {{uuid: $uuid}})
|
|
108
|
+
DETACH DELETE n
|
|
109
|
+
""",
|
|
110
|
+
uuid=self.uuid,
|
|
111
|
+
)
|
|
109
112
|
|
|
110
113
|
logger.debug(f'Deleted Node: {self.uuid}')
|
|
111
114
|
|
|
@@ -119,28 +122,30 @@ class Node(BaseModel, ABC):
|
|
|
119
122
|
|
|
120
123
|
@classmethod
|
|
121
124
|
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
125
|
+
match driver.provider:
|
|
126
|
+
case GraphProvider.NEO4J:
|
|
127
|
+
async with driver.session() as session:
|
|
128
|
+
await session.run(
|
|
129
|
+
"""
|
|
130
|
+
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
|
131
|
+
CALL {
|
|
132
|
+
WITH n
|
|
133
|
+
DETACH DELETE n
|
|
134
|
+
} IN TRANSACTIONS OF $batch_size ROWS
|
|
135
|
+
""",
|
|
136
|
+
group_id=group_id,
|
|
137
|
+
batch_size=batch_size,
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
case _: # FalkorDB and Neptune
|
|
141
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
142
|
+
await driver.execute_query(
|
|
143
|
+
f"""
|
|
144
|
+
MATCH (n:{label} {{group_id: $group_id}})
|
|
138
145
|
DETACH DELETE n
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
batch_size=batch_size,
|
|
143
|
-
)
|
|
146
|
+
""",
|
|
147
|
+
group_id=group_id,
|
|
148
|
+
)
|
|
144
149
|
|
|
145
150
|
@classmethod
|
|
146
151
|
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
|
|
@@ -189,11 +194,25 @@ class EpisodicNode(Node):
|
|
|
189
194
|
)
|
|
190
195
|
|
|
191
196
|
async def save(self, driver: GraphDriver):
|
|
197
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
198
|
+
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
199
|
+
'episode_content',
|
|
200
|
+
[
|
|
201
|
+
{
|
|
202
|
+
'uuid': self.uuid,
|
|
203
|
+
'group_id': self.group_id,
|
|
204
|
+
'source': self.source.value,
|
|
205
|
+
'content': self.content,
|
|
206
|
+
'source_description': self.source_description,
|
|
207
|
+
}
|
|
208
|
+
],
|
|
209
|
+
)
|
|
192
210
|
result = await driver.execute_query(
|
|
193
|
-
|
|
211
|
+
get_episode_node_save_query(driver.provider),
|
|
194
212
|
uuid=self.uuid,
|
|
195
213
|
name=self.name,
|
|
196
214
|
group_id=self.group_id,
|
|
215
|
+
group_label='Episodic_' + self.group_id.replace('-', ''),
|
|
197
216
|
source_description=self.source_description,
|
|
198
217
|
content=self.content,
|
|
199
218
|
entity_edges=self.entity_edges,
|
|
@@ -213,7 +232,11 @@ class EpisodicNode(Node):
|
|
|
213
232
|
MATCH (e:Episodic {uuid: $uuid})
|
|
214
233
|
RETURN
|
|
215
234
|
"""
|
|
216
|
-
+
|
|
235
|
+
+ (
|
|
236
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
237
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
238
|
+
else EPISODIC_NODE_RETURN
|
|
239
|
+
),
|
|
217
240
|
uuid=uuid,
|
|
218
241
|
routing_='r',
|
|
219
242
|
)
|
|
@@ -233,7 +256,11 @@ class EpisodicNode(Node):
|
|
|
233
256
|
WHERE e.uuid IN $uuids
|
|
234
257
|
RETURN DISTINCT
|
|
235
258
|
"""
|
|
236
|
-
+
|
|
259
|
+
+ (
|
|
260
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
261
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
262
|
+
else EPISODIC_NODE_RETURN
|
|
263
|
+
),
|
|
237
264
|
uuids=uuids,
|
|
238
265
|
routing_='r',
|
|
239
266
|
)
|
|
@@ -262,7 +289,11 @@ class EpisodicNode(Node):
|
|
|
262
289
|
+ """
|
|
263
290
|
RETURN DISTINCT
|
|
264
291
|
"""
|
|
265
|
-
+
|
|
292
|
+
+ (
|
|
293
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
294
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
295
|
+
else EPISODIC_NODE_RETURN
|
|
296
|
+
)
|
|
266
297
|
+ """
|
|
267
298
|
ORDER BY uuid DESC
|
|
268
299
|
"""
|
|
@@ -284,7 +315,11 @@ class EpisodicNode(Node):
|
|
|
284
315
|
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
|
285
316
|
RETURN DISTINCT
|
|
286
317
|
"""
|
|
287
|
-
+
|
|
318
|
+
+ (
|
|
319
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
320
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
321
|
+
else EPISODIC_NODE_RETURN
|
|
322
|
+
),
|
|
288
323
|
entity_node_uuid=entity_node_uuid,
|
|
289
324
|
routing_='r',
|
|
290
325
|
)
|
|
@@ -311,11 +346,18 @@ class EntityNode(Node):
|
|
|
311
346
|
return self.name_embedding
|
|
312
347
|
|
|
313
348
|
async def load_name_embedding(self, driver: GraphDriver):
|
|
314
|
-
|
|
349
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
350
|
+
query: LiteralString = """
|
|
351
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
352
|
+
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
315
353
|
"""
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
354
|
+
else:
|
|
355
|
+
query: LiteralString = """
|
|
356
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
357
|
+
RETURN n.name_embedding AS name_embedding
|
|
358
|
+
"""
|
|
359
|
+
records, _, _ = await driver.execute_query(
|
|
360
|
+
query,
|
|
319
361
|
uuid=self.uuid,
|
|
320
362
|
routing_='r',
|
|
321
363
|
)
|
|
@@ -336,7 +378,10 @@ class EntityNode(Node):
|
|
|
336
378
|
}
|
|
337
379
|
entity_data.update(self.attributes or {})
|
|
338
380
|
|
|
339
|
-
|
|
381
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
382
|
+
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
|
383
|
+
|
|
384
|
+
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
|
|
340
385
|
|
|
341
386
|
result = await driver.execute_query(
|
|
342
387
|
get_entity_node_save_query(driver.provider, labels),
|
|
@@ -433,8 +478,13 @@ class CommunityNode(Node):
|
|
|
433
478
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
|
434
479
|
|
|
435
480
|
async def save(self, driver: GraphDriver):
|
|
481
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
482
|
+
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
483
|
+
'community_name',
|
|
484
|
+
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
|
485
|
+
)
|
|
436
486
|
result = await driver.execute_query(
|
|
437
|
-
get_community_node_save_query(driver.provider),
|
|
487
|
+
get_community_node_save_query(driver.provider), # type: ignore
|
|
438
488
|
uuid=self.uuid,
|
|
439
489
|
name=self.name,
|
|
440
490
|
group_id=self.group_id,
|
|
@@ -457,11 +507,19 @@ class CommunityNode(Node):
|
|
|
457
507
|
return self.name_embedding
|
|
458
508
|
|
|
459
509
|
async def load_name_embedding(self, driver: GraphDriver):
|
|
460
|
-
|
|
510
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
511
|
+
query: LiteralString = """
|
|
512
|
+
MATCH (c:Community {uuid: $uuid})
|
|
513
|
+
RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
461
514
|
"""
|
|
515
|
+
else:
|
|
516
|
+
query: LiteralString = """
|
|
462
517
|
MATCH (c:Community {uuid: $uuid})
|
|
463
518
|
RETURN c.name_embedding AS name_embedding
|
|
464
|
-
"""
|
|
519
|
+
"""
|
|
520
|
+
|
|
521
|
+
records, _, _ = await driver.execute_query(
|
|
522
|
+
query,
|
|
465
523
|
uuid=self.uuid,
|
|
466
524
|
routing_='r',
|
|
467
525
|
)
|
|
@@ -478,7 +536,11 @@ class CommunityNode(Node):
|
|
|
478
536
|
MATCH (n:Community {uuid: $uuid})
|
|
479
537
|
RETURN
|
|
480
538
|
"""
|
|
481
|
-
+
|
|
539
|
+
+ (
|
|
540
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
541
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
542
|
+
else COMMUNITY_NODE_RETURN
|
|
543
|
+
),
|
|
482
544
|
uuid=uuid,
|
|
483
545
|
routing_='r',
|
|
484
546
|
)
|
|
@@ -498,7 +560,11 @@ class CommunityNode(Node):
|
|
|
498
560
|
WHERE n.uuid IN $uuids
|
|
499
561
|
RETURN
|
|
500
562
|
"""
|
|
501
|
-
+
|
|
563
|
+
+ (
|
|
564
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
565
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
566
|
+
else COMMUNITY_NODE_RETURN
|
|
567
|
+
),
|
|
502
568
|
uuids=uuids,
|
|
503
569
|
routing_='r',
|
|
504
570
|
)
|
|
@@ -527,7 +593,11 @@ class CommunityNode(Node):
|
|
|
527
593
|
+ """
|
|
528
594
|
RETURN
|
|
529
595
|
"""
|
|
530
|
-
+
|
|
596
|
+
+ (
|
|
597
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
598
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
599
|
+
else COMMUNITY_NODE_RETURN
|
|
600
|
+
)
|
|
531
601
|
+ """
|
|
532
602
|
ORDER BY n.uuid DESC
|
|
533
603
|
"""
|
graphiti_core/search/search.py
CHANGED
|
@@ -21,6 +21,7 @@ from time import time
|
|
|
21
21
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
22
22
|
from graphiti_core.driver.driver import GraphDriver
|
|
23
23
|
from graphiti_core.edges import EntityEdge
|
|
24
|
+
from graphiti_core.embedder.client import EMBEDDING_DIM
|
|
24
25
|
from graphiti_core.errors import SearchRerankerError
|
|
25
26
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
26
27
|
from graphiti_core.helpers import semaphore_gather
|
|
@@ -29,6 +30,7 @@ from graphiti_core.search.search_config import (
|
|
|
29
30
|
DEFAULT_SEARCH_LIMIT,
|
|
30
31
|
CommunityReranker,
|
|
31
32
|
CommunitySearchConfig,
|
|
33
|
+
CommunitySearchMethod,
|
|
32
34
|
EdgeReranker,
|
|
33
35
|
EdgeSearchConfig,
|
|
34
36
|
EdgeSearchMethod,
|
|
@@ -81,11 +83,29 @@ async def search(
|
|
|
81
83
|
|
|
82
84
|
if query.strip() == '':
|
|
83
85
|
return SearchResults()
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
86
|
+
|
|
87
|
+
if (
|
|
88
|
+
config.edge_config
|
|
89
|
+
and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
|
|
90
|
+
or config.edge_config
|
|
91
|
+
and EdgeReranker.mmr == config.edge_config.reranker
|
|
92
|
+
or config.node_config
|
|
93
|
+
and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
|
|
94
|
+
or config.node_config
|
|
95
|
+
and NodeReranker.mmr == config.node_config.reranker
|
|
96
|
+
or (
|
|
97
|
+
config.community_config
|
|
98
|
+
and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
|
|
99
|
+
)
|
|
100
|
+
or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
|
|
101
|
+
):
|
|
102
|
+
search_vector = (
|
|
103
|
+
query_vector
|
|
104
|
+
if query_vector is not None
|
|
105
|
+
else await embedder.create(input_data=[query.replace('\n', ' ')])
|
|
106
|
+
)
|
|
107
|
+
else:
|
|
108
|
+
search_vector = [0.0] * EMBEDDING_DIM
|
|
89
109
|
|
|
90
110
|
# if group_ids is empty, set it to None
|
|
91
111
|
group_ids = group_ids if group_ids and group_ids != [''] else None
|
|
@@ -99,7 +119,7 @@ async def search(
|
|
|
99
119
|
driver,
|
|
100
120
|
cross_encoder,
|
|
101
121
|
query,
|
|
102
|
-
|
|
122
|
+
search_vector,
|
|
103
123
|
group_ids,
|
|
104
124
|
config.edge_config,
|
|
105
125
|
search_filter,
|
|
@@ -112,7 +132,7 @@ async def search(
|
|
|
112
132
|
driver,
|
|
113
133
|
cross_encoder,
|
|
114
134
|
query,
|
|
115
|
-
|
|
135
|
+
search_vector,
|
|
116
136
|
group_ids,
|
|
117
137
|
config.node_config,
|
|
118
138
|
search_filter,
|
|
@@ -125,7 +145,7 @@ async def search(
|
|
|
125
145
|
driver,
|
|
126
146
|
cross_encoder,
|
|
127
147
|
query,
|
|
128
|
-
|
|
148
|
+
search_vector,
|
|
129
149
|
group_ids,
|
|
130
150
|
config.episode_config,
|
|
131
151
|
search_filter,
|
|
@@ -136,7 +156,7 @@ async def search(
|
|
|
136
156
|
driver,
|
|
137
157
|
cross_encoder,
|
|
138
158
|
query,
|
|
139
|
-
|
|
159
|
+
search_vector,
|
|
140
160
|
group_ids,
|
|
141
161
|
config.community_config,
|
|
142
162
|
config.limit,
|