graphiti-core 0.20.4__py3-none-any.whl → 0.21.0rc2__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +224 -0
- graphiti_core/driver/falkordb_driver.py +1 -0
- graphiti_core/driver/kuzu_driver.py +1 -0
- graphiti_core/driver/neo4j_driver.py +59 -2
- graphiti_core/driver/neptune_driver.py +26 -45
- graphiti_core/edges.py +61 -4
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graphiti.py +21 -5
- graphiti_core/models/edges/edge_db_queries.py +36 -16
- graphiti_core/models/nodes/node_db_queries.py +30 -10
- graphiti_core/nodes.py +120 -22
- graphiti_core/search/search_filters.py +53 -0
- graphiti_core/search/search_utils.py +225 -57
- graphiti_core/utils/bulk_utils.py +23 -3
- graphiti_core/utils/maintenance/edge_operations.py +39 -5
- graphiti_core/utils/maintenance/graph_data_operations.py +9 -5
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc2.dist-info}/METADATA +4 -1
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc2.dist-info}/RECORD +20 -20
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc2.dist-info}/WHEEL +0 -0
- {graphiti_core-0.20.4.dist-info → graphiti_core-0.21.0rc2.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/driver.py
CHANGED
|
@@ -14,15 +14,40 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import asyncio
|
|
17
18
|
import copy
|
|
18
19
|
import logging
|
|
20
|
+
import os
|
|
19
21
|
from abc import ABC, abstractmethod
|
|
20
22
|
from collections.abc import Coroutine
|
|
23
|
+
from datetime import datetime
|
|
21
24
|
from enum import Enum
|
|
22
25
|
from typing import Any
|
|
23
26
|
|
|
27
|
+
from dotenv import load_dotenv
|
|
28
|
+
|
|
29
|
+
from graphiti_core.embedder.client import EMBEDDING_DIM
|
|
30
|
+
|
|
31
|
+
try:
|
|
32
|
+
from opensearchpy import AsyncOpenSearch, helpers
|
|
33
|
+
|
|
34
|
+
_HAS_OPENSEARCH = True
|
|
35
|
+
except ImportError:
|
|
36
|
+
OpenSearch = None
|
|
37
|
+
helpers = None
|
|
38
|
+
_HAS_OPENSEARCH = False
|
|
39
|
+
|
|
24
40
|
logger = logging.getLogger(__name__)
|
|
25
41
|
|
|
42
|
+
DEFAULT_SIZE = 10
|
|
43
|
+
|
|
44
|
+
load_dotenv()
|
|
45
|
+
|
|
46
|
+
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
|
|
47
|
+
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
|
|
48
|
+
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
|
|
49
|
+
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
|
|
50
|
+
|
|
26
51
|
|
|
27
52
|
class GraphProvider(Enum):
|
|
28
53
|
NEO4J = 'neo4j'
|
|
@@ -31,6 +56,91 @@ class GraphProvider(Enum):
|
|
|
31
56
|
NEPTUNE = 'neptune'
|
|
32
57
|
|
|
33
58
|
|
|
59
|
+
aoss_indices = [
|
|
60
|
+
{
|
|
61
|
+
'index_name': ENTITY_INDEX_NAME,
|
|
62
|
+
'body': {
|
|
63
|
+
'settings': {'index': {'knn': True}},
|
|
64
|
+
'mappings': {
|
|
65
|
+
'properties': {
|
|
66
|
+
'uuid': {'type': 'keyword'},
|
|
67
|
+
'name': {'type': 'text'},
|
|
68
|
+
'summary': {'type': 'text'},
|
|
69
|
+
'group_id': {'type': 'keyword'},
|
|
70
|
+
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
71
|
+
'name_embedding': {
|
|
72
|
+
'type': 'knn_vector',
|
|
73
|
+
'dimension': EMBEDDING_DIM,
|
|
74
|
+
'method': {
|
|
75
|
+
'engine': 'faiss',
|
|
76
|
+
'space_type': 'cosinesimil',
|
|
77
|
+
'name': 'hnsw',
|
|
78
|
+
'parameters': {'ef_construction': 128, 'm': 16},
|
|
79
|
+
},
|
|
80
|
+
},
|
|
81
|
+
}
|
|
82
|
+
},
|
|
83
|
+
},
|
|
84
|
+
},
|
|
85
|
+
{
|
|
86
|
+
'index_name': COMMUNITY_INDEX_NAME,
|
|
87
|
+
'body': {
|
|
88
|
+
'mappings': {
|
|
89
|
+
'properties': {
|
|
90
|
+
'uuid': {'type': 'keyword'},
|
|
91
|
+
'name': {'type': 'text'},
|
|
92
|
+
'group_id': {'type': 'keyword'},
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
},
|
|
96
|
+
},
|
|
97
|
+
{
|
|
98
|
+
'index_name': EPISODE_INDEX_NAME,
|
|
99
|
+
'body': {
|
|
100
|
+
'mappings': {
|
|
101
|
+
'properties': {
|
|
102
|
+
'uuid': {'type': 'keyword'},
|
|
103
|
+
'content': {'type': 'text'},
|
|
104
|
+
'source': {'type': 'text'},
|
|
105
|
+
'source_description': {'type': 'text'},
|
|
106
|
+
'group_id': {'type': 'keyword'},
|
|
107
|
+
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
108
|
+
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
},
|
|
112
|
+
},
|
|
113
|
+
{
|
|
114
|
+
'index_name': ENTITY_EDGE_INDEX_NAME,
|
|
115
|
+
'body': {
|
|
116
|
+
'settings': {'index': {'knn': True}},
|
|
117
|
+
'mappings': {
|
|
118
|
+
'properties': {
|
|
119
|
+
'uuid': {'type': 'keyword'},
|
|
120
|
+
'name': {'type': 'text'},
|
|
121
|
+
'fact': {'type': 'text'},
|
|
122
|
+
'group_id': {'type': 'keyword'},
|
|
123
|
+
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
124
|
+
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
125
|
+
'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
126
|
+
'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
127
|
+
'fact_embedding': {
|
|
128
|
+
'type': 'knn_vector',
|
|
129
|
+
'dimension': EMBEDDING_DIM,
|
|
130
|
+
'method': {
|
|
131
|
+
'engine': 'faiss',
|
|
132
|
+
'space_type': 'cosinesimil',
|
|
133
|
+
'name': 'hnsw',
|
|
134
|
+
'parameters': {'ef_construction': 128, 'm': 16},
|
|
135
|
+
},
|
|
136
|
+
},
|
|
137
|
+
}
|
|
138
|
+
},
|
|
139
|
+
},
|
|
140
|
+
},
|
|
141
|
+
]
|
|
142
|
+
|
|
143
|
+
|
|
34
144
|
class GraphDriverSession(ABC):
|
|
35
145
|
provider: GraphProvider
|
|
36
146
|
|
|
@@ -61,6 +171,7 @@ class GraphDriver(ABC):
|
|
|
61
171
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
|
62
172
|
)
|
|
63
173
|
_database: str
|
|
174
|
+
aoss_client: AsyncOpenSearch | None # type: ignore
|
|
64
175
|
|
|
65
176
|
@abstractmethod
|
|
66
177
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
|
@@ -87,3 +198,116 @@ class GraphDriver(ABC):
|
|
|
87
198
|
cloned._database = database
|
|
88
199
|
|
|
89
200
|
return cloned
|
|
201
|
+
|
|
202
|
+
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
|
203
|
+
# No matter what happens above, always return True
|
|
204
|
+
return self.delete_aoss_indices()
|
|
205
|
+
|
|
206
|
+
async def create_aoss_indices(self):
|
|
207
|
+
client = self.aoss_client
|
|
208
|
+
if not client:
|
|
209
|
+
logger.warning('No OpenSearch client found')
|
|
210
|
+
return
|
|
211
|
+
|
|
212
|
+
for index in aoss_indices:
|
|
213
|
+
alias_name = index['index_name']
|
|
214
|
+
|
|
215
|
+
# If alias already exists, skip (idempotent behavior)
|
|
216
|
+
if await client.indices.exists_alias(name=alias_name):
|
|
217
|
+
continue
|
|
218
|
+
|
|
219
|
+
# Build a physical index name with timestamp
|
|
220
|
+
ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
|
|
221
|
+
physical_index_name = f'{alias_name}_{ts_suffix}'
|
|
222
|
+
|
|
223
|
+
# Create the index
|
|
224
|
+
await client.indices.create(index=physical_index_name, body=index['body'])
|
|
225
|
+
|
|
226
|
+
# Point alias to it
|
|
227
|
+
await client.indices.put_alias(index=physical_index_name, name=alias_name)
|
|
228
|
+
|
|
229
|
+
# Allow some time for index creation
|
|
230
|
+
await asyncio.sleep(1)
|
|
231
|
+
|
|
232
|
+
async def delete_aoss_indices(self):
|
|
233
|
+
client = self.aoss_client
|
|
234
|
+
|
|
235
|
+
if not client:
|
|
236
|
+
logger.warning('No OpenSearch client found')
|
|
237
|
+
return
|
|
238
|
+
|
|
239
|
+
for entry in aoss_indices:
|
|
240
|
+
alias_name = entry['index_name']
|
|
241
|
+
|
|
242
|
+
try:
|
|
243
|
+
# Resolve alias → indices
|
|
244
|
+
alias_info = await client.indices.get_alias(name=alias_name)
|
|
245
|
+
indices = list(alias_info.keys())
|
|
246
|
+
|
|
247
|
+
if not indices:
|
|
248
|
+
logger.info(f"No indices found for alias '{alias_name}'")
|
|
249
|
+
continue
|
|
250
|
+
|
|
251
|
+
for index in indices:
|
|
252
|
+
if await client.indices.exists(index=index):
|
|
253
|
+
await client.indices.delete(index=index)
|
|
254
|
+
logger.info(f"Deleted index '{index}' (alias: {alias_name})")
|
|
255
|
+
else:
|
|
256
|
+
logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
|
|
257
|
+
|
|
258
|
+
except Exception as e:
|
|
259
|
+
logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
|
|
260
|
+
|
|
261
|
+
async def clear_aoss_indices(self):
|
|
262
|
+
client = self.aoss_client
|
|
263
|
+
|
|
264
|
+
if not client:
|
|
265
|
+
logger.warning('No OpenSearch client found')
|
|
266
|
+
return
|
|
267
|
+
|
|
268
|
+
for index in aoss_indices:
|
|
269
|
+
index_name = index['index_name']
|
|
270
|
+
|
|
271
|
+
if await client.indices.exists(index=index_name):
|
|
272
|
+
try:
|
|
273
|
+
# Delete all documents but keep the index
|
|
274
|
+
response = await client.delete_by_query(
|
|
275
|
+
index=index_name,
|
|
276
|
+
body={'query': {'match_all': {}}},
|
|
277
|
+
)
|
|
278
|
+
logger.info(f"Cleared index '{index_name}': {response}")
|
|
279
|
+
except Exception as e:
|
|
280
|
+
logger.error(f"Error clearing index '{index_name}': {e}")
|
|
281
|
+
else:
|
|
282
|
+
logger.warning(f"Index '{index_name}' does not exist")
|
|
283
|
+
|
|
284
|
+
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
|
285
|
+
client = self.aoss_client
|
|
286
|
+
if not client or not helpers:
|
|
287
|
+
logger.warning('No OpenSearch client found')
|
|
288
|
+
return 0
|
|
289
|
+
|
|
290
|
+
for index in aoss_indices:
|
|
291
|
+
if name.lower() == index['index_name']:
|
|
292
|
+
to_index = []
|
|
293
|
+
for d in data:
|
|
294
|
+
doc = {}
|
|
295
|
+
for p in index['body']['mappings']['properties']:
|
|
296
|
+
if p in d: # protect against missing fields
|
|
297
|
+
doc[p] = d[p]
|
|
298
|
+
|
|
299
|
+
item = {
|
|
300
|
+
'_index': name,
|
|
301
|
+
'_id': d['uuid'],
|
|
302
|
+
'_routing': d.get('group_id'),
|
|
303
|
+
'_source': doc,
|
|
304
|
+
}
|
|
305
|
+
to_index.append(item)
|
|
306
|
+
|
|
307
|
+
success, failed = await helpers.async_bulk(
|
|
308
|
+
client, to_index, stats_only=True, request_timeout=60
|
|
309
|
+
)
|
|
310
|
+
|
|
311
|
+
return success if failed == 0 else success
|
|
312
|
+
|
|
313
|
+
return 0
|
|
graphiti_core/driver/neo4j_driver.py
CHANGED
|
@@ -22,14 +22,44 @@ from neo4j import AsyncGraphDatabase, EagerResult
|
|
|
22
22
|
from typing_extensions import LiteralString
|
|
23
23
|
|
|
24
24
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
25
|
+
from graphiti_core.helpers import semaphore_gather
|
|
25
26
|
|
|
26
27
|
logger = logging.getLogger(__name__)
|
|
27
28
|
|
|
29
|
+
try:
|
|
30
|
+
import boto3
|
|
31
|
+
from opensearchpy import (
|
|
32
|
+
AIOHttpConnection,
|
|
33
|
+
AsyncOpenSearch,
|
|
34
|
+
AWSV4SignerAuth,
|
|
35
|
+
Urllib3AWSV4SignerAuth,
|
|
36
|
+
Urllib3HttpConnection,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
_HAS_OPENSEARCH = True
|
|
40
|
+
except ImportError:
|
|
41
|
+
boto3 = None
|
|
42
|
+
OpenSearch = None
|
|
43
|
+
Urllib3AWSV4SignerAuth = None
|
|
44
|
+
Urllib3HttpConnection = None
|
|
45
|
+
_HAS_OPENSEARCH = False
|
|
46
|
+
|
|
28
47
|
|
|
29
48
|
class Neo4jDriver(GraphDriver):
|
|
30
49
|
provider = GraphProvider.NEO4J
|
|
31
50
|
|
|
32
|
-
def __init__(
|
|
51
|
+
def __init__(
|
|
52
|
+
self,
|
|
53
|
+
uri: str,
|
|
54
|
+
user: str | None,
|
|
55
|
+
password: str | None,
|
|
56
|
+
database: str = 'neo4j',
|
|
57
|
+
aoss_host: str | None = None,
|
|
58
|
+
aoss_port: int | None = None,
|
|
59
|
+
aws_profile_name: str | None = None,
|
|
60
|
+
aws_region: str | None = None,
|
|
61
|
+
aws_service: str | None = None,
|
|
62
|
+
):
|
|
33
63
|
super().__init__()
|
|
34
64
|
self.client = AsyncGraphDatabase.driver(
|
|
35
65
|
uri=uri,
|
|
@@ -37,6 +67,26 @@ class Neo4jDriver(GraphDriver):
|
|
|
37
67
|
)
|
|
38
68
|
self._database = database
|
|
39
69
|
|
|
70
|
+
self.aoss_client = None
|
|
71
|
+
if aoss_host and aoss_port and boto3 is not None:
|
|
72
|
+
try:
|
|
73
|
+
region = aws_region
|
|
74
|
+
service = aws_service
|
|
75
|
+
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
|
|
76
|
+
auth = AWSV4SignerAuth(credentials, region or '', service or '')
|
|
77
|
+
|
|
78
|
+
self.aoss_client = AsyncOpenSearch(
|
|
79
|
+
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
|
80
|
+
auth=auth,
|
|
81
|
+
use_ssl=True,
|
|
82
|
+
verify_certs=True,
|
|
83
|
+
connection_class=AIOHttpConnection,
|
|
84
|
+
pool_maxsize=20,
|
|
85
|
+
) # type: ignore
|
|
86
|
+
except Exception as e:
|
|
87
|
+
logger.warning(f'Failed to initialize OpenSearch client: {e}')
|
|
88
|
+
self.aoss_client = None
|
|
89
|
+
|
|
40
90
|
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
|
41
91
|
# Check if database_ is provided in kwargs.
|
|
42
92
|
# If not populated, set the value to retain backwards compatibility
|
|
@@ -60,7 +110,14 @@ class Neo4jDriver(GraphDriver):
|
|
|
60
110
|
async def close(self) -> None:
|
|
61
111
|
return await self.client.close()
|
|
62
112
|
|
|
63
|
-
def delete_all_indexes(self) -> Coroutine
|
|
113
|
+
def delete_all_indexes(self) -> Coroutine:
|
|
114
|
+
if self.aoss_client:
|
|
115
|
+
return semaphore_gather(
|
|
116
|
+
self.client.execute_query(
|
|
117
|
+
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
118
|
+
),
|
|
119
|
+
self.delete_aoss_indices(),
|
|
120
|
+
)
|
|
64
121
|
return self.client.execute_query(
|
|
65
122
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
66
123
|
)
|
|
graphiti_core/driver/neptune_driver.py
CHANGED
|
@@ -22,16 +22,21 @@ from typing import Any
|
|
|
22
22
|
|
|
23
23
|
import boto3
|
|
24
24
|
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
|
25
|
-
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
|
25
|
+
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
|
26
26
|
|
|
27
|
-
from graphiti_core.driver.driver import
|
|
27
|
+
from graphiti_core.driver.driver import (
|
|
28
|
+
DEFAULT_SIZE,
|
|
29
|
+
GraphDriver,
|
|
30
|
+
GraphDriverSession,
|
|
31
|
+
GraphProvider,
|
|
32
|
+
)
|
|
28
33
|
|
|
29
34
|
logger = logging.getLogger(__name__)
|
|
30
|
-
DEFAULT_SIZE = 10
|
|
31
35
|
|
|
32
|
-
|
|
36
|
+
neptune_aoss_indices = [
|
|
33
37
|
{
|
|
34
38
|
'index_name': 'node_name_and_summary',
|
|
39
|
+
'alias_name': 'entities',
|
|
35
40
|
'body': {
|
|
36
41
|
'mappings': {
|
|
37
42
|
'properties': {
|
|
@@ -49,6 +54,7 @@ aoss_indices = [
|
|
|
49
54
|
},
|
|
50
55
|
{
|
|
51
56
|
'index_name': 'community_name',
|
|
57
|
+
'alias_name': 'communities',
|
|
52
58
|
'body': {
|
|
53
59
|
'mappings': {
|
|
54
60
|
'properties': {
|
|
@@ -65,6 +71,7 @@ aoss_indices = [
|
|
|
65
71
|
},
|
|
66
72
|
{
|
|
67
73
|
'index_name': 'episode_content',
|
|
74
|
+
'alias_name': 'episodes',
|
|
68
75
|
'body': {
|
|
69
76
|
'mappings': {
|
|
70
77
|
'properties': {
|
|
@@ -88,6 +95,7 @@ aoss_indices = [
|
|
|
88
95
|
},
|
|
89
96
|
{
|
|
90
97
|
'index_name': 'edge_name_and_fact',
|
|
98
|
+
'alias_name': 'facts',
|
|
91
99
|
'body': {
|
|
92
100
|
'mappings': {
|
|
93
101
|
'properties': {
|
|
@@ -220,54 +228,27 @@ class NeptuneDriver(GraphDriver):
|
|
|
220
228
|
async def _delete_all_data(self) -> Any:
|
|
221
229
|
return await self.execute_query('MATCH (n) DETACH DELETE n')
|
|
222
230
|
|
|
223
|
-
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
|
224
|
-
return self.delete_all_indexes_impl()
|
|
225
|
-
|
|
226
|
-
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
|
227
|
-
# No matter what happens above, always return True
|
|
228
|
-
return self.delete_aoss_indices()
|
|
229
|
-
|
|
230
231
|
async def create_aoss_indices(self):
|
|
231
|
-
for index in
|
|
232
|
+
for index in neptune_aoss_indices:
|
|
232
233
|
index_name = index['index_name']
|
|
233
234
|
client = self.aoss_client
|
|
235
|
+
if not client:
|
|
236
|
+
raise ValueError(
|
|
237
|
+
'You must provide an AOSS endpoint to create an OpenSearch driver.'
|
|
238
|
+
)
|
|
234
239
|
if not client.indices.exists(index=index_name):
|
|
235
|
-
client.indices.create(index=index_name, body=index['body'])
|
|
240
|
+
await client.indices.create(index=index_name, body=index['body'])
|
|
241
|
+
|
|
242
|
+
alias_name = index.get('alias_name', index_name)
|
|
243
|
+
|
|
244
|
+
if not client.indices.exists_alias(name=alias_name, index=index_name):
|
|
245
|
+
await client.indices.put_alias(index=index_name, name=alias_name)
|
|
246
|
+
|
|
236
247
|
# Sleep for 1 minute to let the index creation complete
|
|
237
248
|
await asyncio.sleep(60)
|
|
238
249
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
index_name = index['index_name']
|
|
242
|
-
client = self.aoss_client
|
|
243
|
-
if client.indices.exists(index=index_name):
|
|
244
|
-
client.indices.delete(index=index_name)
|
|
245
|
-
|
|
246
|
-
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
|
247
|
-
for index in aoss_indices:
|
|
248
|
-
if name.lower() == index['index_name']:
|
|
249
|
-
index['query']['query']['multi_match']['query'] = query_text
|
|
250
|
-
query = {'size': limit, 'query': index['query']}
|
|
251
|
-
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
|
|
252
|
-
return resp
|
|
253
|
-
return {}
|
|
254
|
-
|
|
255
|
-
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
|
256
|
-
for index in aoss_indices:
|
|
257
|
-
if name.lower() == index['index_name']:
|
|
258
|
-
to_index = []
|
|
259
|
-
for d in data:
|
|
260
|
-
item = {'_index': name}
|
|
261
|
-
for p in index['body']['mappings']['properties']:
|
|
262
|
-
item[p] = d[p]
|
|
263
|
-
to_index.append(item)
|
|
264
|
-
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
|
265
|
-
if failed > 0:
|
|
266
|
-
return success
|
|
267
|
-
else:
|
|
268
|
-
return 0
|
|
269
|
-
|
|
270
|
-
return 0
|
|
250
|
+
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
|
251
|
+
return self.delete_all_indexes_impl()
|
|
271
252
|
|
|
272
253
|
|
|
273
254
|
class NeptuneDriverSession(GraphDriverSession):
|
graphiti_core/edges.py
CHANGED
|
@@ -25,7 +25,7 @@ from uuid import uuid4
|
|
|
25
25
|
from pydantic import BaseModel, Field
|
|
26
26
|
from typing_extensions import LiteralString
|
|
27
27
|
|
|
28
|
-
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
28
|
+
from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
|
|
29
29
|
from graphiti_core.embedder import EmbedderClient
|
|
30
30
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
31
31
|
from graphiti_core.helpers import parse_db_date
|
|
@@ -77,6 +77,13 @@ class Edge(BaseModel, ABC):
|
|
|
77
77
|
uuid=self.uuid,
|
|
78
78
|
)
|
|
79
79
|
|
|
80
|
+
if driver.aoss_client:
|
|
81
|
+
await driver.aoss_client.delete(
|
|
82
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
83
|
+
id=self.uuid,
|
|
84
|
+
params={'routing': self.group_id},
|
|
85
|
+
)
|
|
86
|
+
|
|
80
87
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
|
81
88
|
|
|
82
89
|
@classmethod
|
|
@@ -108,6 +115,12 @@ class Edge(BaseModel, ABC):
|
|
|
108
115
|
uuids=uuids,
|
|
109
116
|
)
|
|
110
117
|
|
|
118
|
+
if driver.aoss_client:
|
|
119
|
+
await driver.aoss_client.delete_by_query(
|
|
120
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
121
|
+
body={'query': {'terms': {'uuid': uuids}}},
|
|
122
|
+
)
|
|
123
|
+
|
|
111
124
|
logger.debug(f'Deleted Edges: {uuids}')
|
|
112
125
|
|
|
113
126
|
def __hash__(self):
|
|
@@ -255,6 +268,21 @@ class EntityEdge(Edge):
|
|
|
255
268
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
256
269
|
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
|
257
270
|
"""
|
|
271
|
+
elif driver.aoss_client:
|
|
272
|
+
resp = await driver.aoss_client.search(
|
|
273
|
+
body={
|
|
274
|
+
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
|
275
|
+
'size': 1,
|
|
276
|
+
},
|
|
277
|
+
index=ENTITY_EDGE_INDEX_NAME,
|
|
278
|
+
params={'routing': self.group_id},
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
if resp['hits']['hits']:
|
|
282
|
+
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
|
|
283
|
+
return
|
|
284
|
+
else:
|
|
285
|
+
raise EdgeNotFoundError(self.uuid)
|
|
258
286
|
|
|
259
287
|
if driver.provider == GraphProvider.KUZU:
|
|
260
288
|
query = """
|
|
@@ -292,14 +320,14 @@ class EntityEdge(Edge):
|
|
|
292
320
|
if driver.provider == GraphProvider.KUZU:
|
|
293
321
|
edge_data['attributes'] = json.dumps(self.attributes)
|
|
294
322
|
result = await driver.execute_query(
|
|
295
|
-
get_entity_edge_save_query(driver.provider),
|
|
323
|
+
get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
|
|
296
324
|
**edge_data,
|
|
297
325
|
)
|
|
298
326
|
else:
|
|
299
327
|
edge_data.update(self.attributes or {})
|
|
300
328
|
|
|
301
|
-
if driver.
|
|
302
|
-
driver.save_to_aoss(
|
|
329
|
+
if driver.aoss_client:
|
|
330
|
+
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
|
303
331
|
|
|
304
332
|
result = await driver.execute_query(
|
|
305
333
|
get_entity_edge_save_query(driver.provider),
|
|
@@ -336,6 +364,35 @@ class EntityEdge(Edge):
|
|
|
336
364
|
raise EdgeNotFoundError(uuid)
|
|
337
365
|
return edges[0]
|
|
338
366
|
|
|
367
|
+
@classmethod
|
|
368
|
+
async def get_between_nodes(
|
|
369
|
+
cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
|
|
370
|
+
):
|
|
371
|
+
match_query = """
|
|
372
|
+
MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
|
373
|
+
"""
|
|
374
|
+
if driver.provider == GraphProvider.KUZU:
|
|
375
|
+
match_query = """
|
|
376
|
+
MATCH (n:Entity {uuid: $source_node_uuid})
|
|
377
|
+
-[:RELATES_TO]->(e:RelatesToNode_)
|
|
378
|
+
-[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
|
|
379
|
+
"""
|
|
380
|
+
|
|
381
|
+
records, _, _ = await driver.execute_query(
|
|
382
|
+
match_query
|
|
383
|
+
+ """
|
|
384
|
+
RETURN
|
|
385
|
+
"""
|
|
386
|
+
+ get_entity_edge_return_query(driver.provider),
|
|
387
|
+
source_node_uuid=source_node_uuid,
|
|
388
|
+
target_node_uuid=target_node_uuid,
|
|
389
|
+
routing_='r',
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
393
|
+
|
|
394
|
+
return edges
|
|
395
|
+
|
|
339
396
|
@classmethod
|
|
340
397
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
341
398
|
if len(uuids) == 0:
|
graphiti_core/embedder/client.py
CHANGED
|
@@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import os
|
|
17
18
|
from abc import ABC, abstractmethod
|
|
18
19
|
from collections.abc import Iterable
|
|
19
20
|
|
|
20
21
|
from pydantic import BaseModel, Field
|
|
21
22
|
|
|
22
|
-
EMBEDDING_DIM = 1024
|
|
23
|
+
EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024))
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class EmbedderConfig(BaseModel):
|
graphiti_core/graphiti.py
CHANGED
|
@@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
|
|
|
60
60
|
from graphiti_core.search.search_filters import SearchFilters
|
|
61
61
|
from graphiti_core.search.search_utils import (
|
|
62
62
|
RELEVANT_SCHEMA_LIMIT,
|
|
63
|
-
get_edge_invalidation_candidates,
|
|
64
63
|
get_mentioned_nodes,
|
|
65
|
-
get_relevant_edges,
|
|
66
64
|
)
|
|
67
65
|
from graphiti_core.telemetry import capture_event
|
|
68
66
|
from graphiti_core.utils.bulk_utils import (
|
|
@@ -1037,10 +1035,28 @@ class Graphiti:
|
|
|
1037
1035
|
|
|
1038
1036
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
|
1039
1037
|
|
|
1040
|
-
|
|
1038
|
+
valid_edges = await EntityEdge.get_between_nodes(
|
|
1039
|
+
self.driver, edge.source_node_uuid, edge.target_node_uuid
|
|
1040
|
+
)
|
|
1041
|
+
|
|
1042
|
+
related_edges = (
|
|
1043
|
+
await search(
|
|
1044
|
+
self.clients,
|
|
1045
|
+
updated_edge.fact,
|
|
1046
|
+
group_ids=[updated_edge.group_id],
|
|
1047
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1048
|
+
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
|
|
1049
|
+
)
|
|
1050
|
+
).edges
|
|
1041
1051
|
existing_edges = (
|
|
1042
|
-
await
|
|
1043
|
-
|
|
1052
|
+
await search(
|
|
1053
|
+
self.clients,
|
|
1054
|
+
updated_edge.fact,
|
|
1055
|
+
group_ids=[updated_edge.group_id],
|
|
1056
|
+
config=EDGE_HYBRID_SEARCH_RRF,
|
|
1057
|
+
search_filter=SearchFilters(),
|
|
1058
|
+
)
|
|
1059
|
+
).edges
|
|
1044
1060
|
|
|
1045
1061
|
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
|
|
1046
1062
|
self.llm_client,
|