graphiti-core 0.21.0rc13__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +4 -211
- graphiti_core/driver/falkordb_driver.py +31 -3
- graphiti_core/driver/graph_operations/graph_operations.py +195 -0
- graphiti_core/driver/neo4j_driver.py +0 -49
- graphiti_core/driver/neptune_driver.py +43 -26
- graphiti_core/driver/search_interface/__init__.py +0 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +11 -34
- graphiti_core/graphiti.py +459 -326
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/llm_client/anthropic_client.py +64 -45
- graphiti_core/llm_client/client.py +67 -19
- graphiti_core/llm_client/gemini_client.py +73 -54
- graphiti_core/llm_client/openai_base_client.py +65 -43
- graphiti_core/llm_client/openai_generic_client.py +65 -43
- graphiti_core/models/edges/edge_db_queries.py +1 -0
- graphiti_core/models/nodes/node_db_queries.py +1 -0
- graphiti_core/nodes.py +26 -99
- graphiti_core/prompts/dedupe_edges.py +4 -4
- graphiti_core/prompts/dedupe_nodes.py +10 -10
- graphiti_core/prompts/extract_edges.py +4 -4
- graphiti_core/prompts/extract_nodes.py +26 -28
- graphiti_core/prompts/prompt_helpers.py +18 -2
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +22 -24
- graphiti_core/search/search_filters.py +0 -38
- graphiti_core/search/search_helpers.py +4 -4
- graphiti_core/search/search_utils.py +84 -220
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +16 -28
- graphiti_core/utils/maintenance/community_operations.py +4 -1
- graphiti_core/utils/maintenance/edge_operations.py +26 -15
- graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
- graphiti_core/utils/maintenance/node_operations.py +98 -51
- graphiti_core/utils/maintenance/temporal_operations.py +4 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/METADATA +7 -3
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/RECORD +41 -35
- /graphiti_core/{utils/maintenance/utils.py → driver/graph_operations/__init__.py} +0 -0
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/driver.py
CHANGED
|
@@ -14,28 +14,18 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
import asyncio
|
|
18
17
|
import copy
|
|
19
18
|
import logging
|
|
20
19
|
import os
|
|
21
20
|
from abc import ABC, abstractmethod
|
|
22
21
|
from collections.abc import Coroutine
|
|
23
|
-
from datetime import datetime
|
|
24
22
|
from enum import Enum
|
|
25
23
|
from typing import Any
|
|
26
24
|
|
|
27
25
|
from dotenv import load_dotenv
|
|
28
26
|
|
|
29
|
-
from graphiti_core.
|
|
30
|
-
|
|
31
|
-
try:
|
|
32
|
-
from opensearchpy import AsyncOpenSearch, helpers
|
|
33
|
-
|
|
34
|
-
_HAS_OPENSEARCH = True
|
|
35
|
-
except ImportError:
|
|
36
|
-
OpenSearch = None
|
|
37
|
-
helpers = None
|
|
38
|
-
_HAS_OPENSEARCH = False
|
|
27
|
+
from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
|
|
28
|
+
from graphiti_core.driver.search_interface.search_interface import SearchInterface
|
|
39
29
|
|
|
40
30
|
logger = logging.getLogger(__name__)
|
|
41
31
|
|
|
@@ -56,91 +46,6 @@ class GraphProvider(Enum):
|
|
|
56
46
|
NEPTUNE = 'neptune'
|
|
57
47
|
|
|
58
48
|
|
|
59
|
-
aoss_indices = [
|
|
60
|
-
{
|
|
61
|
-
'index_name': ENTITY_INDEX_NAME,
|
|
62
|
-
'body': {
|
|
63
|
-
'settings': {'index': {'knn': True}},
|
|
64
|
-
'mappings': {
|
|
65
|
-
'properties': {
|
|
66
|
-
'uuid': {'type': 'keyword'},
|
|
67
|
-
'name': {'type': 'text'},
|
|
68
|
-
'summary': {'type': 'text'},
|
|
69
|
-
'group_id': {'type': 'keyword'},
|
|
70
|
-
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
71
|
-
'name_embedding': {
|
|
72
|
-
'type': 'knn_vector',
|
|
73
|
-
'dimension': EMBEDDING_DIM,
|
|
74
|
-
'method': {
|
|
75
|
-
'engine': 'faiss',
|
|
76
|
-
'space_type': 'cosinesimil',
|
|
77
|
-
'name': 'hnsw',
|
|
78
|
-
'parameters': {'ef_construction': 128, 'm': 16},
|
|
79
|
-
},
|
|
80
|
-
},
|
|
81
|
-
}
|
|
82
|
-
},
|
|
83
|
-
},
|
|
84
|
-
},
|
|
85
|
-
{
|
|
86
|
-
'index_name': COMMUNITY_INDEX_NAME,
|
|
87
|
-
'body': {
|
|
88
|
-
'mappings': {
|
|
89
|
-
'properties': {
|
|
90
|
-
'uuid': {'type': 'keyword'},
|
|
91
|
-
'name': {'type': 'text'},
|
|
92
|
-
'group_id': {'type': 'keyword'},
|
|
93
|
-
}
|
|
94
|
-
}
|
|
95
|
-
},
|
|
96
|
-
},
|
|
97
|
-
{
|
|
98
|
-
'index_name': EPISODE_INDEX_NAME,
|
|
99
|
-
'body': {
|
|
100
|
-
'mappings': {
|
|
101
|
-
'properties': {
|
|
102
|
-
'uuid': {'type': 'keyword'},
|
|
103
|
-
'content': {'type': 'text'},
|
|
104
|
-
'source': {'type': 'text'},
|
|
105
|
-
'source_description': {'type': 'text'},
|
|
106
|
-
'group_id': {'type': 'keyword'},
|
|
107
|
-
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
108
|
-
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
109
|
-
}
|
|
110
|
-
}
|
|
111
|
-
},
|
|
112
|
-
},
|
|
113
|
-
{
|
|
114
|
-
'index_name': ENTITY_EDGE_INDEX_NAME,
|
|
115
|
-
'body': {
|
|
116
|
-
'settings': {'index': {'knn': True}},
|
|
117
|
-
'mappings': {
|
|
118
|
-
'properties': {
|
|
119
|
-
'uuid': {'type': 'keyword'},
|
|
120
|
-
'name': {'type': 'text'},
|
|
121
|
-
'fact': {'type': 'text'},
|
|
122
|
-
'group_id': {'type': 'keyword'},
|
|
123
|
-
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
124
|
-
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
125
|
-
'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
126
|
-
'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
|
|
127
|
-
'fact_embedding': {
|
|
128
|
-
'type': 'knn_vector',
|
|
129
|
-
'dimension': EMBEDDING_DIM,
|
|
130
|
-
'method': {
|
|
131
|
-
'engine': 'faiss',
|
|
132
|
-
'space_type': 'cosinesimil',
|
|
133
|
-
'name': 'hnsw',
|
|
134
|
-
'parameters': {'ef_construction': 128, 'm': 16},
|
|
135
|
-
},
|
|
136
|
-
},
|
|
137
|
-
}
|
|
138
|
-
},
|
|
139
|
-
},
|
|
140
|
-
},
|
|
141
|
-
]
|
|
142
|
-
|
|
143
|
-
|
|
144
49
|
class GraphDriverSession(ABC):
|
|
145
50
|
provider: GraphProvider
|
|
146
51
|
|
|
@@ -171,7 +76,8 @@ class GraphDriver(ABC):
|
|
|
171
76
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
|
172
77
|
)
|
|
173
78
|
_database: str
|
|
174
|
-
|
|
79
|
+
search_interface: SearchInterface | None = None
|
|
80
|
+
graph_operations_interface: GraphOperationsInterface | None = None
|
|
175
81
|
|
|
176
82
|
@abstractmethod
|
|
177
83
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
|
@@ -199,119 +105,6 @@ class GraphDriver(ABC):
|
|
|
199
105
|
|
|
200
106
|
return cloned
|
|
201
107
|
|
|
202
|
-
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
|
203
|
-
# No matter what happens above, always return True
|
|
204
|
-
return self.delete_aoss_indices()
|
|
205
|
-
|
|
206
|
-
async def create_aoss_indices(self):
|
|
207
|
-
client = self.aoss_client
|
|
208
|
-
if not client:
|
|
209
|
-
logger.warning('No OpenSearch client found')
|
|
210
|
-
return
|
|
211
|
-
|
|
212
|
-
for index in aoss_indices:
|
|
213
|
-
alias_name = index['index_name']
|
|
214
|
-
|
|
215
|
-
# If alias already exists, skip (idempotent behavior)
|
|
216
|
-
if await client.indices.exists_alias(name=alias_name):
|
|
217
|
-
continue
|
|
218
|
-
|
|
219
|
-
# Build a physical index name with timestamp
|
|
220
|
-
ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
|
|
221
|
-
physical_index_name = f'{alias_name}_{ts_suffix}'
|
|
222
|
-
|
|
223
|
-
# Create the index
|
|
224
|
-
await client.indices.create(index=physical_index_name, body=index['body'])
|
|
225
|
-
|
|
226
|
-
# Point alias to it
|
|
227
|
-
await client.indices.put_alias(index=physical_index_name, name=alias_name)
|
|
228
|
-
|
|
229
|
-
# Allow some time for index creation
|
|
230
|
-
await asyncio.sleep(1)
|
|
231
|
-
|
|
232
|
-
async def delete_aoss_indices(self):
|
|
233
|
-
client = self.aoss_client
|
|
234
|
-
|
|
235
|
-
if not client:
|
|
236
|
-
logger.warning('No OpenSearch client found')
|
|
237
|
-
return
|
|
238
|
-
|
|
239
|
-
for entry in aoss_indices:
|
|
240
|
-
alias_name = entry['index_name']
|
|
241
|
-
|
|
242
|
-
try:
|
|
243
|
-
# Resolve alias → indices
|
|
244
|
-
alias_info = await client.indices.get_alias(name=alias_name)
|
|
245
|
-
indices = list(alias_info.keys())
|
|
246
|
-
|
|
247
|
-
if not indices:
|
|
248
|
-
logger.info(f"No indices found for alias '{alias_name}'")
|
|
249
|
-
continue
|
|
250
|
-
|
|
251
|
-
for index in indices:
|
|
252
|
-
if await client.indices.exists(index=index):
|
|
253
|
-
await client.indices.delete(index=index)
|
|
254
|
-
logger.info(f"Deleted index '{index}' (alias: {alias_name})")
|
|
255
|
-
else:
|
|
256
|
-
logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
|
|
257
|
-
|
|
258
|
-
except Exception as e:
|
|
259
|
-
logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
|
|
260
|
-
|
|
261
|
-
async def clear_aoss_indices(self):
|
|
262
|
-
client = self.aoss_client
|
|
263
|
-
|
|
264
|
-
if not client:
|
|
265
|
-
logger.warning('No OpenSearch client found')
|
|
266
|
-
return
|
|
267
|
-
|
|
268
|
-
for index in aoss_indices:
|
|
269
|
-
index_name = index['index_name']
|
|
270
|
-
|
|
271
|
-
if await client.indices.exists(index=index_name):
|
|
272
|
-
try:
|
|
273
|
-
# Delete all documents but keep the index
|
|
274
|
-
response = await client.delete_by_query(
|
|
275
|
-
index=index_name,
|
|
276
|
-
body={'query': {'match_all': {}}},
|
|
277
|
-
)
|
|
278
|
-
logger.info(f"Cleared index '{index_name}': {response}")
|
|
279
|
-
except Exception as e:
|
|
280
|
-
logger.error(f"Error clearing index '{index_name}': {e}")
|
|
281
|
-
else:
|
|
282
|
-
logger.warning(f"Index '{index_name}' does not exist")
|
|
283
|
-
|
|
284
|
-
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
|
285
|
-
client = self.aoss_client
|
|
286
|
-
if not client or not helpers:
|
|
287
|
-
logger.warning('No OpenSearch client found')
|
|
288
|
-
return 0
|
|
289
|
-
|
|
290
|
-
for index in aoss_indices:
|
|
291
|
-
if name.lower() == index['index_name']:
|
|
292
|
-
to_index = []
|
|
293
|
-
for d in data:
|
|
294
|
-
doc = {}
|
|
295
|
-
for p in index['body']['mappings']['properties']:
|
|
296
|
-
if p in d: # protect against missing fields
|
|
297
|
-
doc[p] = d[p]
|
|
298
|
-
|
|
299
|
-
item = {
|
|
300
|
-
'_index': name,
|
|
301
|
-
'_id': d['uuid'],
|
|
302
|
-
'_routing': d.get('group_id'),
|
|
303
|
-
'_source': doc,
|
|
304
|
-
}
|
|
305
|
-
to_index.append(item)
|
|
306
|
-
|
|
307
|
-
success, failed = await helpers.async_bulk(
|
|
308
|
-
client, to_index, stats_only=True, request_timeout=60
|
|
309
|
-
)
|
|
310
|
-
|
|
311
|
-
return success if failed == 0 else success
|
|
312
|
-
|
|
313
|
-
return 0
|
|
314
|
-
|
|
315
108
|
def build_fulltext_query(
|
|
316
109
|
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
|
|
317
110
|
) -> str:
|
|
graphiti_core/driver/falkordb_driver.py
CHANGED
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import asyncio
|
|
17
18
|
import logging
|
|
18
19
|
from typing import TYPE_CHECKING, Any
|
|
19
20
|
|
|
@@ -191,9 +192,36 @@ class FalkorDriver(GraphDriver):
|
|
|
191
192
|
await self.client.connection.close()
|
|
192
193
|
|
|
193
194
|
async def delete_all_indexes(self) -> None:
|
|
194
|
-
await self.execute_query(
|
|
195
|
-
|
|
196
|
-
|
|
195
|
+
result = await self.execute_query('CALL db.indexes()')
|
|
196
|
+
if not result:
|
|
197
|
+
return
|
|
198
|
+
|
|
199
|
+
records, _, _ = result
|
|
200
|
+
drop_tasks = []
|
|
201
|
+
|
|
202
|
+
for record in records:
|
|
203
|
+
label = record['label']
|
|
204
|
+
entity_type = record['entitytype']
|
|
205
|
+
|
|
206
|
+
for field_name, index_type in record['types'].items():
|
|
207
|
+
if 'RANGE' in index_type:
|
|
208
|
+
drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))
|
|
209
|
+
elif 'FULLTEXT' in index_type:
|
|
210
|
+
if entity_type == 'NODE':
|
|
211
|
+
drop_tasks.append(
|
|
212
|
+
self.execute_query(
|
|
213
|
+
f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
|
|
214
|
+
)
|
|
215
|
+
)
|
|
216
|
+
elif entity_type == 'RELATIONSHIP':
|
|
217
|
+
drop_tasks.append(
|
|
218
|
+
self.execute_query(
|
|
219
|
+
f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
|
|
220
|
+
)
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
if drop_tasks:
|
|
224
|
+
await asyncio.gather(*drop_tasks)
|
|
197
225
|
|
|
198
226
|
def clone(self, database: str) -> 'GraphDriver':
|
|
199
227
|
"""
|
|
graphiti_core/driver/graph_operations/graph_operations.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GraphOperationsInterface(BaseModel):
|
|
23
|
+
"""
|
|
24
|
+
Interface for updating graph mutation behavior.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
# -----------------
|
|
28
|
+
# Node: Save/Delete
|
|
29
|
+
# -----------------
|
|
30
|
+
|
|
31
|
+
async def node_save(self, node: Any, driver: Any) -> None:
|
|
32
|
+
"""Persist (create or update) a single node."""
|
|
33
|
+
raise NotImplementedError
|
|
34
|
+
|
|
35
|
+
async def node_delete(self, node: Any, driver: Any) -> None:
|
|
36
|
+
raise NotImplementedError
|
|
37
|
+
|
|
38
|
+
async def node_save_bulk(
|
|
39
|
+
self,
|
|
40
|
+
_cls: Any, # kept for parity; callers won't pass it
|
|
41
|
+
driver: Any,
|
|
42
|
+
transaction: Any,
|
|
43
|
+
nodes: list[Any],
|
|
44
|
+
batch_size: int = 100,
|
|
45
|
+
) -> None:
|
|
46
|
+
"""Persist (create or update) many nodes in batches."""
|
|
47
|
+
raise NotImplementedError
|
|
48
|
+
|
|
49
|
+
async def node_delete_by_group_id(
|
|
50
|
+
self,
|
|
51
|
+
_cls: Any,
|
|
52
|
+
driver: Any,
|
|
53
|
+
group_id: str,
|
|
54
|
+
batch_size: int = 100,
|
|
55
|
+
) -> None:
|
|
56
|
+
raise NotImplementedError
|
|
57
|
+
|
|
58
|
+
async def node_delete_by_uuids(
|
|
59
|
+
self,
|
|
60
|
+
_cls: Any,
|
|
61
|
+
driver: Any,
|
|
62
|
+
uuids: list[str],
|
|
63
|
+
group_id: str | None = None,
|
|
64
|
+
batch_size: int = 100,
|
|
65
|
+
) -> None:
|
|
66
|
+
raise NotImplementedError
|
|
67
|
+
|
|
68
|
+
# --------------------------
|
|
69
|
+
# Node: Embeddings (load)
|
|
70
|
+
# --------------------------
|
|
71
|
+
|
|
72
|
+
async def node_load_embeddings(self, node: Any, driver: Any) -> None:
|
|
73
|
+
"""
|
|
74
|
+
Load embedding vectors for a single node into the instance (e.g., set node.embedding or similar).
|
|
75
|
+
"""
|
|
76
|
+
raise NotImplementedError
|
|
77
|
+
|
|
78
|
+
async def node_load_embeddings_bulk(
|
|
79
|
+
self,
|
|
80
|
+
_cls: Any,
|
|
81
|
+
driver: Any,
|
|
82
|
+
transaction: Any,
|
|
83
|
+
nodes: list[Any],
|
|
84
|
+
batch_size: int = 100,
|
|
85
|
+
) -> None:
|
|
86
|
+
"""
|
|
87
|
+
Load embedding vectors for many nodes in batches. Mutates the provided node instances.
|
|
88
|
+
"""
|
|
89
|
+
raise NotImplementedError
|
|
90
|
+
|
|
91
|
+
# --------------------------
|
|
92
|
+
# EpisodicNode: Save/Delete
|
|
93
|
+
# --------------------------
|
|
94
|
+
|
|
95
|
+
async def episodic_node_save(self, node: Any, driver: Any) -> None:
|
|
96
|
+
"""Persist (create or update) a single episodic node."""
|
|
97
|
+
raise NotImplementedError
|
|
98
|
+
|
|
99
|
+
async def episodic_node_delete(self, node: Any, driver: Any) -> None:
|
|
100
|
+
raise NotImplementedError
|
|
101
|
+
|
|
102
|
+
async def episodic_node_save_bulk(
|
|
103
|
+
self,
|
|
104
|
+
_cls: Any,
|
|
105
|
+
driver: Any,
|
|
106
|
+
transaction: Any,
|
|
107
|
+
nodes: list[Any],
|
|
108
|
+
batch_size: int = 100,
|
|
109
|
+
) -> None:
|
|
110
|
+
"""Persist (create or update) many episodic nodes in batches."""
|
|
111
|
+
raise NotImplementedError
|
|
112
|
+
|
|
113
|
+
async def episodic_edge_save_bulk(
|
|
114
|
+
self,
|
|
115
|
+
_cls: Any,
|
|
116
|
+
driver: Any,
|
|
117
|
+
transaction: Any,
|
|
118
|
+
episodic_edges: list[Any],
|
|
119
|
+
batch_size: int = 100,
|
|
120
|
+
) -> None:
|
|
121
|
+
"""Persist (create or update) many episodic edges in batches."""
|
|
122
|
+
raise NotImplementedError
|
|
123
|
+
|
|
124
|
+
async def episodic_node_delete_by_group_id(
|
|
125
|
+
self,
|
|
126
|
+
_cls: Any,
|
|
127
|
+
driver: Any,
|
|
128
|
+
group_id: str,
|
|
129
|
+
batch_size: int = 100,
|
|
130
|
+
) -> None:
|
|
131
|
+
raise NotImplementedError
|
|
132
|
+
|
|
133
|
+
async def episodic_node_delete_by_uuids(
|
|
134
|
+
self,
|
|
135
|
+
_cls: Any,
|
|
136
|
+
driver: Any,
|
|
137
|
+
uuids: list[str],
|
|
138
|
+
group_id: str | None = None,
|
|
139
|
+
batch_size: int = 100,
|
|
140
|
+
) -> None:
|
|
141
|
+
raise NotImplementedError
|
|
142
|
+
|
|
143
|
+
# -----------------
|
|
144
|
+
# Edge: Save/Delete
|
|
145
|
+
# -----------------
|
|
146
|
+
|
|
147
|
+
async def edge_save(self, edge: Any, driver: Any) -> None:
|
|
148
|
+
"""Persist (create or update) a single edge."""
|
|
149
|
+
raise NotImplementedError
|
|
150
|
+
|
|
151
|
+
async def edge_delete(self, edge: Any, driver: Any) -> None:
|
|
152
|
+
raise NotImplementedError
|
|
153
|
+
|
|
154
|
+
async def edge_save_bulk(
|
|
155
|
+
self,
|
|
156
|
+
_cls: Any,
|
|
157
|
+
driver: Any,
|
|
158
|
+
transaction: Any,
|
|
159
|
+
edges: list[Any],
|
|
160
|
+
batch_size: int = 100,
|
|
161
|
+
) -> None:
|
|
162
|
+
"""Persist (create or update) many edges in batches."""
|
|
163
|
+
raise NotImplementedError
|
|
164
|
+
|
|
165
|
+
async def edge_delete_by_uuids(
|
|
166
|
+
self,
|
|
167
|
+
_cls: Any,
|
|
168
|
+
driver: Any,
|
|
169
|
+
uuids: list[str],
|
|
170
|
+
group_id: str | None = None,
|
|
171
|
+
) -> None:
|
|
172
|
+
raise NotImplementedError
|
|
173
|
+
|
|
174
|
+
# -----------------
|
|
175
|
+
# Edge: Embeddings (load)
|
|
176
|
+
# -----------------
|
|
177
|
+
|
|
178
|
+
async def edge_load_embeddings(self, edge: Any, driver: Any) -> None:
|
|
179
|
+
"""
|
|
180
|
+
Load embedding vectors for a single edge into the instance (e.g., set edge.embedding or similar).
|
|
181
|
+
"""
|
|
182
|
+
raise NotImplementedError
|
|
183
|
+
|
|
184
|
+
async def edge_load_embeddings_bulk(
|
|
185
|
+
self,
|
|
186
|
+
_cls: Any,
|
|
187
|
+
driver: Any,
|
|
188
|
+
transaction: Any,
|
|
189
|
+
edges: list[Any],
|
|
190
|
+
batch_size: int = 100,
|
|
191
|
+
) -> None:
|
|
192
|
+
"""
|
|
193
|
+
Load embedding vectors for many edges in batches. Mutates the provided edge instances.
|
|
194
|
+
"""
|
|
195
|
+
raise NotImplementedError
|
|
graphiti_core/driver/neo4j_driver.py
CHANGED
|
@@ -22,28 +22,9 @@ from neo4j import AsyncGraphDatabase, EagerResult
|
|
|
22
22
|
from typing_extensions import LiteralString
|
|
23
23
|
|
|
24
24
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
25
|
-
from graphiti_core.helpers import semaphore_gather
|
|
26
25
|
|
|
27
26
|
logger = logging.getLogger(__name__)
|
|
28
27
|
|
|
29
|
-
try:
|
|
30
|
-
import boto3
|
|
31
|
-
from opensearchpy import (
|
|
32
|
-
AIOHttpConnection,
|
|
33
|
-
AsyncOpenSearch,
|
|
34
|
-
AWSV4SignerAuth,
|
|
35
|
-
Urllib3AWSV4SignerAuth,
|
|
36
|
-
Urllib3HttpConnection,
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
_HAS_OPENSEARCH = True
|
|
40
|
-
except ImportError:
|
|
41
|
-
boto3 = None
|
|
42
|
-
OpenSearch = None
|
|
43
|
-
Urllib3AWSV4SignerAuth = None
|
|
44
|
-
Urllib3HttpConnection = None
|
|
45
|
-
_HAS_OPENSEARCH = False
|
|
46
|
-
|
|
47
28
|
|
|
48
29
|
class Neo4jDriver(GraphDriver):
|
|
49
30
|
provider = GraphProvider.NEO4J
|
|
@@ -54,11 +35,6 @@ class Neo4jDriver(GraphDriver):
|
|
|
54
35
|
user: str | None,
|
|
55
36
|
password: str | None,
|
|
56
37
|
database: str = 'neo4j',
|
|
57
|
-
aoss_host: str | None = None,
|
|
58
|
-
aoss_port: int | None = None,
|
|
59
|
-
aws_profile_name: str | None = None,
|
|
60
|
-
aws_region: str | None = None,
|
|
61
|
-
aws_service: str | None = None,
|
|
62
38
|
):
|
|
63
39
|
super().__init__()
|
|
64
40
|
self.client = AsyncGraphDatabase.driver(
|
|
@@ -68,24 +44,6 @@ class Neo4jDriver(GraphDriver):
|
|
|
68
44
|
self._database = database
|
|
69
45
|
|
|
70
46
|
self.aoss_client = None
|
|
71
|
-
if aoss_host and aoss_port and boto3 is not None:
|
|
72
|
-
try:
|
|
73
|
-
region = aws_region
|
|
74
|
-
service = aws_service
|
|
75
|
-
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
|
|
76
|
-
auth = AWSV4SignerAuth(credentials, region or '', service or '')
|
|
77
|
-
|
|
78
|
-
self.aoss_client = AsyncOpenSearch(
|
|
79
|
-
hosts=[{'host': aoss_host, 'port': aoss_port}],
|
|
80
|
-
auth=auth,
|
|
81
|
-
use_ssl=True,
|
|
82
|
-
verify_certs=True,
|
|
83
|
-
connection_class=AIOHttpConnection,
|
|
84
|
-
pool_maxsize=20,
|
|
85
|
-
) # type: ignore
|
|
86
|
-
except Exception as e:
|
|
87
|
-
logger.warning(f'Failed to initialize OpenSearch client: {e}')
|
|
88
|
-
self.aoss_client = None
|
|
89
47
|
|
|
90
48
|
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
|
|
91
49
|
# Check if database_ is provided in kwargs.
|
|
@@ -111,13 +69,6 @@ class Neo4jDriver(GraphDriver):
|
|
|
111
69
|
return await self.client.close()
|
|
112
70
|
|
|
113
71
|
def delete_all_indexes(self) -> Coroutine:
|
|
114
|
-
if self.aoss_client:
|
|
115
|
-
return semaphore_gather(
|
|
116
|
-
self.client.execute_query(
|
|
117
|
-
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
118
|
-
),
|
|
119
|
-
self.delete_aoss_indices(),
|
|
120
|
-
)
|
|
121
72
|
return self.client.execute_query(
|
|
122
73
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
123
74
|
)
|
|
graphiti_core/driver/neptune_driver.py
CHANGED
|
@@ -22,21 +22,16 @@ from typing import Any
|
|
|
22
22
|
|
|
23
23
|
import boto3
|
|
24
24
|
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
|
25
|
-
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
|
|
25
|
+
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
|
26
26
|
|
|
27
|
-
from graphiti_core.driver.driver import (
|
|
28
|
-
DEFAULT_SIZE,
|
|
29
|
-
GraphDriver,
|
|
30
|
-
GraphDriverSession,
|
|
31
|
-
GraphProvider,
|
|
32
|
-
)
|
|
27
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
33
28
|
|
|
34
29
|
logger = logging.getLogger(__name__)
|
|
30
|
+
DEFAULT_SIZE = 10
|
|
35
31
|
|
|
36
|
-
neptune_aoss_indices = [
|
32
|
+
aoss_indices = [
|
|
37
33
|
{
|
|
38
34
|
'index_name': 'node_name_and_summary',
|
|
39
|
-
'alias_name': 'entities',
|
|
40
35
|
'body': {
|
|
41
36
|
'mappings': {
|
|
42
37
|
'properties': {
|
|
@@ -54,7 +49,6 @@ neptune_aoss_indices = [
|
|
|
54
49
|
},
|
|
55
50
|
{
|
|
56
51
|
'index_name': 'community_name',
|
|
57
|
-
'alias_name': 'communities',
|
|
58
52
|
'body': {
|
|
59
53
|
'mappings': {
|
|
60
54
|
'properties': {
|
|
@@ -71,7 +65,6 @@ neptune_aoss_indices = [
|
|
|
71
65
|
},
|
|
72
66
|
{
|
|
73
67
|
'index_name': 'episode_content',
|
|
74
|
-
'alias_name': 'episodes',
|
|
75
68
|
'body': {
|
|
76
69
|
'mappings': {
|
|
77
70
|
'properties': {
|
|
@@ -95,7 +88,6 @@ neptune_aoss_indices = [
|
|
|
95
88
|
},
|
|
96
89
|
{
|
|
97
90
|
'index_name': 'edge_name_and_fact',
|
|
98
|
-
'alias_name': 'facts',
|
|
99
91
|
'body': {
|
|
100
92
|
'mappings': {
|
|
101
93
|
'properties': {
|
|
@@ -228,27 +220,52 @@ class NeptuneDriver(GraphDriver):
|
|
|
228
220
|
async def _delete_all_data(self) -> Any:
|
|
229
221
|
return await self.execute_query('MATCH (n) DETACH DELETE n')
|
|
230
222
|
|
|
223
|
+
def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
|
|
224
|
+
return self.delete_all_indexes_impl()
|
|
225
|
+
|
|
226
|
+
async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
|
|
227
|
+
# No matter what happens above, always return True
|
|
228
|
+
return self.delete_aoss_indices()
|
|
229
|
+
|
|
231
230
|
async def create_aoss_indices(self):
|
|
232
|
-
for index in neptune_aoss_indices:
|
|
231
|
+
for index in aoss_indices:
|
|
233
232
|
index_name = index['index_name']
|
|
234
233
|
client = self.aoss_client
|
|
235
|
-
if not client:
|
|
236
|
-
raise ValueError(
|
|
237
|
-
'You must provide an AOSS endpoint to create an OpenSearch driver.'
|
|
238
|
-
)
|
|
239
234
|
if not client.indices.exists(index=index_name):
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
alias_name = index.get('alias_name', index_name)
|
|
243
|
-
|
|
244
|
-
if not client.indices.exists_alias(name=alias_name, index=index_name):
|
|
245
|
-
await client.indices.put_alias(index=index_name, name=alias_name)
|
|
246
|
-
|
|
235
|
+
client.indices.create(index=index_name, body=index['body'])
|
|
247
236
|
# Sleep for 1 minute to let the index creation complete
|
|
248
237
|
await asyncio.sleep(60)
|
|
249
238
|
|
|
250
|
-
def
|
|
251
|
-
|
|
239
|
+
async def delete_aoss_indices(self):
|
|
240
|
+
for index in aoss_indices:
|
|
241
|
+
index_name = index['index_name']
|
|
242
|
+
client = self.aoss_client
|
|
243
|
+
if client.indices.exists(index=index_name):
|
|
244
|
+
client.indices.delete(index=index_name)
|
|
245
|
+
|
|
246
|
+
def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
|
|
247
|
+
for index in aoss_indices:
|
|
248
|
+
if name.lower() == index['index_name']:
|
|
249
|
+
index['query']['query']['multi_match']['query'] = query_text
|
|
250
|
+
query = {'size': limit, 'query': index['query']}
|
|
251
|
+
resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
|
|
252
|
+
return resp
|
|
253
|
+
return {}
|
|
254
|
+
|
|
255
|
+
def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
|
256
|
+
for index in aoss_indices:
|
|
257
|
+
if name.lower() == index['index_name']:
|
|
258
|
+
to_index = []
|
|
259
|
+
for d in data:
|
|
260
|
+
item = {'_index': name, '_id': d['uuid']}
|
|
261
|
+
for p in index['body']['mappings']['properties']:
|
|
262
|
+
if p in d:
|
|
263
|
+
item[p] = d[p]
|
|
264
|
+
to_index.append(item)
|
|
265
|
+
success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
|
|
266
|
+
return success
|
|
267
|
+
|
|
268
|
+
return 0
|
|
252
269
|
|
|
253
270
|
|
|
254
271
|
class NeptuneDriverSession(GraphDriverSession):
|
|
graphiti_core/driver/search_interface/__init__.py
File without changes
|