graphiti-core 0.21.0rc12__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the registry's release advisory for details.

Files changed (41)
  1. graphiti_core/driver/driver.py +4 -211
  2. graphiti_core/driver/falkordb_driver.py +31 -3
  3. graphiti_core/driver/graph_operations/graph_operations.py +195 -0
  4. graphiti_core/driver/neo4j_driver.py +0 -49
  5. graphiti_core/driver/neptune_driver.py +43 -26
  6. graphiti_core/driver/search_interface/__init__.py +0 -0
  7. graphiti_core/driver/search_interface/search_interface.py +89 -0
  8. graphiti_core/edges.py +11 -34
  9. graphiti_core/graphiti.py +459 -326
  10. graphiti_core/graphiti_types.py +2 -0
  11. graphiti_core/llm_client/anthropic_client.py +64 -45
  12. graphiti_core/llm_client/client.py +67 -19
  13. graphiti_core/llm_client/gemini_client.py +73 -54
  14. graphiti_core/llm_client/openai_base_client.py +65 -43
  15. graphiti_core/llm_client/openai_generic_client.py +65 -43
  16. graphiti_core/models/edges/edge_db_queries.py +1 -0
  17. graphiti_core/models/nodes/node_db_queries.py +1 -0
  18. graphiti_core/nodes.py +26 -99
  19. graphiti_core/prompts/dedupe_edges.py +4 -4
  20. graphiti_core/prompts/dedupe_nodes.py +10 -10
  21. graphiti_core/prompts/extract_edges.py +4 -4
  22. graphiti_core/prompts/extract_nodes.py +26 -28
  23. graphiti_core/prompts/prompt_helpers.py +18 -2
  24. graphiti_core/prompts/snippets.py +29 -0
  25. graphiti_core/prompts/summarize_nodes.py +22 -24
  26. graphiti_core/search/search_filters.py +0 -38
  27. graphiti_core/search/search_helpers.py +4 -4
  28. graphiti_core/search/search_utils.py +84 -220
  29. graphiti_core/tracer.py +193 -0
  30. graphiti_core/utils/bulk_utils.py +16 -28
  31. graphiti_core/utils/maintenance/community_operations.py +4 -1
  32. graphiti_core/utils/maintenance/edge_operations.py +30 -15
  33. graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
  34. graphiti_core/utils/maintenance/node_operations.py +99 -51
  35. graphiti_core/utils/maintenance/temporal_operations.py +4 -1
  36. graphiti_core/utils/text_utils.py +53 -0
  37. {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/METADATA +7 -3
  38. {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/RECORD +41 -35
  39. /graphiti_core/{utils/maintenance/utils.py → driver/graph_operations/__init__.py} +0 -0
  40. {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/WHEEL +0 -0
  41. {graphiti_core-0.21.0rc12.dist-info → graphiti_core-0.22.0.dist-info}/licenses/LICENSE +0 -0
@@ -14,28 +14,18 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- import asyncio
18
17
  import copy
19
18
  import logging
20
19
  import os
21
20
  from abc import ABC, abstractmethod
22
21
  from collections.abc import Coroutine
23
- from datetime import datetime
24
22
  from enum import Enum
25
23
  from typing import Any
26
24
 
27
25
  from dotenv import load_dotenv
28
26
 
29
- from graphiti_core.embedder.client import EMBEDDING_DIM
30
-
31
- try:
32
- from opensearchpy import AsyncOpenSearch, helpers
33
-
34
- _HAS_OPENSEARCH = True
35
- except ImportError:
36
- OpenSearch = None
37
- helpers = None
38
- _HAS_OPENSEARCH = False
27
+ from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
28
+ from graphiti_core.driver.search_interface.search_interface import SearchInterface
39
29
 
40
30
  logger = logging.getLogger(__name__)
41
31
 
@@ -56,91 +46,6 @@ class GraphProvider(Enum):
56
46
  NEPTUNE = 'neptune'
57
47
 
58
48
 
59
- aoss_indices = [
60
- {
61
- 'index_name': ENTITY_INDEX_NAME,
62
- 'body': {
63
- 'settings': {'index': {'knn': True}},
64
- 'mappings': {
65
- 'properties': {
66
- 'uuid': {'type': 'keyword'},
67
- 'name': {'type': 'text'},
68
- 'summary': {'type': 'text'},
69
- 'group_id': {'type': 'keyword'},
70
- 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
71
- 'name_embedding': {
72
- 'type': 'knn_vector',
73
- 'dimension': EMBEDDING_DIM,
74
- 'method': {
75
- 'engine': 'faiss',
76
- 'space_type': 'cosinesimil',
77
- 'name': 'hnsw',
78
- 'parameters': {'ef_construction': 128, 'm': 16},
79
- },
80
- },
81
- }
82
- },
83
- },
84
- },
85
- {
86
- 'index_name': COMMUNITY_INDEX_NAME,
87
- 'body': {
88
- 'mappings': {
89
- 'properties': {
90
- 'uuid': {'type': 'keyword'},
91
- 'name': {'type': 'text'},
92
- 'group_id': {'type': 'keyword'},
93
- }
94
- }
95
- },
96
- },
97
- {
98
- 'index_name': EPISODE_INDEX_NAME,
99
- 'body': {
100
- 'mappings': {
101
- 'properties': {
102
- 'uuid': {'type': 'keyword'},
103
- 'content': {'type': 'text'},
104
- 'source': {'type': 'text'},
105
- 'source_description': {'type': 'text'},
106
- 'group_id': {'type': 'keyword'},
107
- 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
108
- 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
109
- }
110
- }
111
- },
112
- },
113
- {
114
- 'index_name': ENTITY_EDGE_INDEX_NAME,
115
- 'body': {
116
- 'settings': {'index': {'knn': True}},
117
- 'mappings': {
118
- 'properties': {
119
- 'uuid': {'type': 'keyword'},
120
- 'name': {'type': 'text'},
121
- 'fact': {'type': 'text'},
122
- 'group_id': {'type': 'keyword'},
123
- 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
124
- 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
125
- 'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
126
- 'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
127
- 'fact_embedding': {
128
- 'type': 'knn_vector',
129
- 'dimension': EMBEDDING_DIM,
130
- 'method': {
131
- 'engine': 'faiss',
132
- 'space_type': 'cosinesimil',
133
- 'name': 'hnsw',
134
- 'parameters': {'ef_construction': 128, 'm': 16},
135
- },
136
- },
137
- }
138
- },
139
- },
140
- },
141
- ]
142
-
143
-
144
49
  class GraphDriverSession(ABC):
145
50
  provider: GraphProvider
146
51
 
@@ -171,7 +76,8 @@ class GraphDriver(ABC):
171
76
  '' # Neo4j (default) syntax does not require a prefix for fulltext queries
172
77
  )
173
78
  _database: str
174
- aoss_client: AsyncOpenSearch | None # type: ignore
79
+ search_interface: SearchInterface | None = None
80
+ graph_operations_interface: GraphOperationsInterface | None = None
175
81
 
176
82
  @abstractmethod
177
83
  def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@@ -199,119 +105,6 @@ class GraphDriver(ABC):
199
105
 
200
106
  return cloned
201
107
 
202
- async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
203
- # No matter what happens above, always return True
204
- return self.delete_aoss_indices()
205
-
206
- async def create_aoss_indices(self):
207
- client = self.aoss_client
208
- if not client:
209
- logger.warning('No OpenSearch client found')
210
- return
211
-
212
- for index in aoss_indices:
213
- alias_name = index['index_name']
214
-
215
- # If alias already exists, skip (idempotent behavior)
216
- if await client.indices.exists_alias(name=alias_name):
217
- continue
218
-
219
- # Build a physical index name with timestamp
220
- ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
221
- physical_index_name = f'{alias_name}_{ts_suffix}'
222
-
223
- # Create the index
224
- await client.indices.create(index=physical_index_name, body=index['body'])
225
-
226
- # Point alias to it
227
- await client.indices.put_alias(index=physical_index_name, name=alias_name)
228
-
229
- # Allow some time for index creation
230
- await asyncio.sleep(1)
231
-
232
- async def delete_aoss_indices(self):
233
- client = self.aoss_client
234
-
235
- if not client:
236
- logger.warning('No OpenSearch client found')
237
- return
238
-
239
- for entry in aoss_indices:
240
- alias_name = entry['index_name']
241
-
242
- try:
243
- # Resolve alias → indices
244
- alias_info = await client.indices.get_alias(name=alias_name)
245
- indices = list(alias_info.keys())
246
-
247
- if not indices:
248
- logger.info(f"No indices found for alias '{alias_name}'")
249
- continue
250
-
251
- for index in indices:
252
- if await client.indices.exists(index=index):
253
- await client.indices.delete(index=index)
254
- logger.info(f"Deleted index '{index}' (alias: {alias_name})")
255
- else:
256
- logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
257
-
258
- except Exception as e:
259
- logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
260
-
261
- async def clear_aoss_indices(self):
262
- client = self.aoss_client
263
-
264
- if not client:
265
- logger.warning('No OpenSearch client found')
266
- return
267
-
268
- for index in aoss_indices:
269
- index_name = index['index_name']
270
-
271
- if await client.indices.exists(index=index_name):
272
- try:
273
- # Delete all documents but keep the index
274
- response = await client.delete_by_query(
275
- index=index_name,
276
- body={'query': {'match_all': {}}},
277
- )
278
- logger.info(f"Cleared index '{index_name}': {response}")
279
- except Exception as e:
280
- logger.error(f"Error clearing index '{index_name}': {e}")
281
- else:
282
- logger.warning(f"Index '{index_name}' does not exist")
283
-
284
- async def save_to_aoss(self, name: str, data: list[dict]) -> int:
285
- client = self.aoss_client
286
- if not client or not helpers:
287
- logger.warning('No OpenSearch client found')
288
- return 0
289
-
290
- for index in aoss_indices:
291
- if name.lower() == index['index_name']:
292
- to_index = []
293
- for d in data:
294
- doc = {}
295
- for p in index['body']['mappings']['properties']:
296
- if p in d: # protect against missing fields
297
- doc[p] = d[p]
298
-
299
- item = {
300
- '_index': name,
301
- '_id': d['uuid'],
302
- '_routing': d.get('group_id'),
303
- '_source': doc,
304
- }
305
- to_index.append(item)
306
-
307
- success, failed = await helpers.async_bulk(
308
- client, to_index, stats_only=True, request_timeout=60
309
- )
310
-
311
- return success if failed == 0 else success
312
-
313
- return 0
314
-
315
108
  def build_fulltext_query(
316
109
  self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
317
110
  ) -> str:
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import asyncio
17
18
  import logging
18
19
  from typing import TYPE_CHECKING, Any
19
20
 
@@ -191,9 +192,36 @@ class FalkorDriver(GraphDriver):
191
192
  await self.client.connection.close()
192
193
 
193
194
  async def delete_all_indexes(self) -> None:
194
- await self.execute_query(
195
- 'CALL db.indexes() YIELD name DROP INDEX name',
196
- )
195
+ result = await self.execute_query('CALL db.indexes()')
196
+ if not result:
197
+ return
198
+
199
+ records, _, _ = result
200
+ drop_tasks = []
201
+
202
+ for record in records:
203
+ label = record['label']
204
+ entity_type = record['entitytype']
205
+
206
+ for field_name, index_type in record['types'].items():
207
+ if 'RANGE' in index_type:
208
+ drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))
209
+ elif 'FULLTEXT' in index_type:
210
+ if entity_type == 'NODE':
211
+ drop_tasks.append(
212
+ self.execute_query(
213
+ f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
214
+ )
215
+ )
216
+ elif entity_type == 'RELATIONSHIP':
217
+ drop_tasks.append(
218
+ self.execute_query(
219
+ f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
220
+ )
221
+ )
222
+
223
+ if drop_tasks:
224
+ await asyncio.gather(*drop_tasks)
197
225
 
198
226
  def clone(self, database: str) -> 'GraphDriver':
199
227
  """
@@ -0,0 +1,195 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from typing import Any
18
+
19
+ from pydantic import BaseModel
20
+
21
+
22
+ class GraphOperationsInterface(BaseModel):
23
+ """
24
+ Interface for updating graph mutation behavior.
25
+ """
26
+
27
+ # -----------------
28
+ # Node: Save/Delete
29
+ # -----------------
30
+
31
+ async def node_save(self, node: Any, driver: Any) -> None:
32
+ """Persist (create or update) a single node."""
33
+ raise NotImplementedError
34
+
35
+ async def node_delete(self, node: Any, driver: Any) -> None:
36
+ raise NotImplementedError
37
+
38
+ async def node_save_bulk(
39
+ self,
40
+ _cls: Any, # kept for parity; callers won't pass it
41
+ driver: Any,
42
+ transaction: Any,
43
+ nodes: list[Any],
44
+ batch_size: int = 100,
45
+ ) -> None:
46
+ """Persist (create or update) many nodes in batches."""
47
+ raise NotImplementedError
48
+
49
+ async def node_delete_by_group_id(
50
+ self,
51
+ _cls: Any,
52
+ driver: Any,
53
+ group_id: str,
54
+ batch_size: int = 100,
55
+ ) -> None:
56
+ raise NotImplementedError
57
+
58
+ async def node_delete_by_uuids(
59
+ self,
60
+ _cls: Any,
61
+ driver: Any,
62
+ uuids: list[str],
63
+ group_id: str | None = None,
64
+ batch_size: int = 100,
65
+ ) -> None:
66
+ raise NotImplementedError
67
+
68
+ # --------------------------
69
+ # Node: Embeddings (load)
70
+ # --------------------------
71
+
72
+ async def node_load_embeddings(self, node: Any, driver: Any) -> None:
73
+ """
74
+ Load embedding vectors for a single node into the instance (e.g., set node.embedding or similar).
75
+ """
76
+ raise NotImplementedError
77
+
78
+ async def node_load_embeddings_bulk(
79
+ self,
80
+ _cls: Any,
81
+ driver: Any,
82
+ transaction: Any,
83
+ nodes: list[Any],
84
+ batch_size: int = 100,
85
+ ) -> None:
86
+ """
87
+ Load embedding vectors for many nodes in batches. Mutates the provided node instances.
88
+ """
89
+ raise NotImplementedError
90
+
91
+ # --------------------------
92
+ # EpisodicNode: Save/Delete
93
+ # --------------------------
94
+
95
+ async def episodic_node_save(self, node: Any, driver: Any) -> None:
96
+ """Persist (create or update) a single episodic node."""
97
+ raise NotImplementedError
98
+
99
+ async def episodic_node_delete(self, node: Any, driver: Any) -> None:
100
+ raise NotImplementedError
101
+
102
+ async def episodic_node_save_bulk(
103
+ self,
104
+ _cls: Any,
105
+ driver: Any,
106
+ transaction: Any,
107
+ nodes: list[Any],
108
+ batch_size: int = 100,
109
+ ) -> None:
110
+ """Persist (create or update) many episodic nodes in batches."""
111
+ raise NotImplementedError
112
+
113
+ async def episodic_edge_save_bulk(
114
+ self,
115
+ _cls: Any,
116
+ driver: Any,
117
+ transaction: Any,
118
+ episodic_edges: list[Any],
119
+ batch_size: int = 100,
120
+ ) -> None:
121
+ """Persist (create or update) many episodic edges in batches."""
122
+ raise NotImplementedError
123
+
124
+ async def episodic_node_delete_by_group_id(
125
+ self,
126
+ _cls: Any,
127
+ driver: Any,
128
+ group_id: str,
129
+ batch_size: int = 100,
130
+ ) -> None:
131
+ raise NotImplementedError
132
+
133
+ async def episodic_node_delete_by_uuids(
134
+ self,
135
+ _cls: Any,
136
+ driver: Any,
137
+ uuids: list[str],
138
+ group_id: str | None = None,
139
+ batch_size: int = 100,
140
+ ) -> None:
141
+ raise NotImplementedError
142
+
143
+ # -----------------
144
+ # Edge: Save/Delete
145
+ # -----------------
146
+
147
+ async def edge_save(self, edge: Any, driver: Any) -> None:
148
+ """Persist (create or update) a single edge."""
149
+ raise NotImplementedError
150
+
151
+ async def edge_delete(self, edge: Any, driver: Any) -> None:
152
+ raise NotImplementedError
153
+
154
+ async def edge_save_bulk(
155
+ self,
156
+ _cls: Any,
157
+ driver: Any,
158
+ transaction: Any,
159
+ edges: list[Any],
160
+ batch_size: int = 100,
161
+ ) -> None:
162
+ """Persist (create or update) many edges in batches."""
163
+ raise NotImplementedError
164
+
165
+ async def edge_delete_by_uuids(
166
+ self,
167
+ _cls: Any,
168
+ driver: Any,
169
+ uuids: list[str],
170
+ group_id: str | None = None,
171
+ ) -> None:
172
+ raise NotImplementedError
173
+
174
+ # -----------------
175
+ # Edge: Embeddings (load)
176
+ # -----------------
177
+
178
+ async def edge_load_embeddings(self, edge: Any, driver: Any) -> None:
179
+ """
180
+ Load embedding vectors for a single edge into the instance (e.g., set edge.embedding or similar).
181
+ """
182
+ raise NotImplementedError
183
+
184
+ async def edge_load_embeddings_bulk(
185
+ self,
186
+ _cls: Any,
187
+ driver: Any,
188
+ transaction: Any,
189
+ edges: list[Any],
190
+ batch_size: int = 100,
191
+ ) -> None:
192
+ """
193
+ Load embedding vectors for many edges in batches. Mutates the provided edge instances.
194
+ """
195
+ raise NotImplementedError
@@ -22,28 +22,9 @@ from neo4j import AsyncGraphDatabase, EagerResult
22
22
  from typing_extensions import LiteralString
23
23
 
24
24
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
25
- from graphiti_core.helpers import semaphore_gather
26
25
 
27
26
  logger = logging.getLogger(__name__)
28
27
 
29
- try:
30
- import boto3
31
- from opensearchpy import (
32
- AIOHttpConnection,
33
- AsyncOpenSearch,
34
- AWSV4SignerAuth,
35
- Urllib3AWSV4SignerAuth,
36
- Urllib3HttpConnection,
37
- )
38
-
39
- _HAS_OPENSEARCH = True
40
- except ImportError:
41
- boto3 = None
42
- OpenSearch = None
43
- Urllib3AWSV4SignerAuth = None
44
- Urllib3HttpConnection = None
45
- _HAS_OPENSEARCH = False
46
-
47
28
 
48
29
  class Neo4jDriver(GraphDriver):
49
30
  provider = GraphProvider.NEO4J
@@ -54,11 +35,6 @@ class Neo4jDriver(GraphDriver):
54
35
  user: str | None,
55
36
  password: str | None,
56
37
  database: str = 'neo4j',
57
- aoss_host: str | None = None,
58
- aoss_port: int | None = None,
59
- aws_profile_name: str | None = None,
60
- aws_region: str | None = None,
61
- aws_service: str | None = None,
62
38
  ):
63
39
  super().__init__()
64
40
  self.client = AsyncGraphDatabase.driver(
@@ -68,24 +44,6 @@ class Neo4jDriver(GraphDriver):
68
44
  self._database = database
69
45
 
70
46
  self.aoss_client = None
71
- if aoss_host and aoss_port and boto3 is not None:
72
- try:
73
- region = aws_region
74
- service = aws_service
75
- credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
76
- auth = AWSV4SignerAuth(credentials, region or '', service or '')
77
-
78
- self.aoss_client = AsyncOpenSearch(
79
- hosts=[{'host': aoss_host, 'port': aoss_port}],
80
- auth=auth,
81
- use_ssl=True,
82
- verify_certs=True,
83
- connection_class=AIOHttpConnection,
84
- pool_maxsize=20,
85
- ) # type: ignore
86
- except Exception as e:
87
- logger.warning(f'Failed to initialize OpenSearch client: {e}')
88
- self.aoss_client = None
89
47
 
90
48
  async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
91
49
  # Check if database_ is provided in kwargs.
@@ -111,13 +69,6 @@ class Neo4jDriver(GraphDriver):
111
69
  return await self.client.close()
112
70
 
113
71
  def delete_all_indexes(self) -> Coroutine:
114
- if self.aoss_client:
115
- return semaphore_gather(
116
- self.client.execute_query(
117
- 'CALL db.indexes() YIELD name DROP INDEX name',
118
- ),
119
- self.delete_aoss_indices(),
120
- )
121
72
  return self.client.execute_query(
122
73
  'CALL db.indexes() YIELD name DROP INDEX name',
123
74
  )
@@ -22,21 +22,16 @@ from typing import Any
22
22
 
23
23
  import boto3
24
24
  from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
25
- from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
25
+ from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
26
26
 
27
- from graphiti_core.driver.driver import (
28
- DEFAULT_SIZE,
29
- GraphDriver,
30
- GraphDriverSession,
31
- GraphProvider,
32
- )
27
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
33
28
 
34
29
  logger = logging.getLogger(__name__)
30
+ DEFAULT_SIZE = 10
35
31
 
36
- neptune_aoss_indices = [
32
+ aoss_indices = [
37
33
  {
38
34
  'index_name': 'node_name_and_summary',
39
- 'alias_name': 'entities',
40
35
  'body': {
41
36
  'mappings': {
42
37
  'properties': {
@@ -54,7 +49,6 @@ neptune_aoss_indices = [
54
49
  },
55
50
  {
56
51
  'index_name': 'community_name',
57
- 'alias_name': 'communities',
58
52
  'body': {
59
53
  'mappings': {
60
54
  'properties': {
@@ -71,7 +65,6 @@ neptune_aoss_indices = [
71
65
  },
72
66
  {
73
67
  'index_name': 'episode_content',
74
- 'alias_name': 'episodes',
75
68
  'body': {
76
69
  'mappings': {
77
70
  'properties': {
@@ -95,7 +88,6 @@ neptune_aoss_indices = [
95
88
  },
96
89
  {
97
90
  'index_name': 'edge_name_and_fact',
98
- 'alias_name': 'facts',
99
91
  'body': {
100
92
  'mappings': {
101
93
  'properties': {
@@ -228,27 +220,52 @@ class NeptuneDriver(GraphDriver):
228
220
  async def _delete_all_data(self) -> Any:
229
221
  return await self.execute_query('MATCH (n) DETACH DELETE n')
230
222
 
223
+ def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
224
+ return self.delete_all_indexes_impl()
225
+
226
+ async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
227
+ # No matter what happens above, always return True
228
+ return self.delete_aoss_indices()
229
+
231
230
  async def create_aoss_indices(self):
232
- for index in neptune_aoss_indices:
231
+ for index in aoss_indices:
233
232
  index_name = index['index_name']
234
233
  client = self.aoss_client
235
- if not client:
236
- raise ValueError(
237
- 'You must provide an AOSS endpoint to create an OpenSearch driver.'
238
- )
239
234
  if not client.indices.exists(index=index_name):
240
- await client.indices.create(index=index_name, body=index['body'])
241
-
242
- alias_name = index.get('alias_name', index_name)
243
-
244
- if not client.indices.exists_alias(name=alias_name, index=index_name):
245
- await client.indices.put_alias(index=index_name, name=alias_name)
246
-
235
+ client.indices.create(index=index_name, body=index['body'])
247
236
  # Sleep for 1 minute to let the index creation complete
248
237
  await asyncio.sleep(60)
249
238
 
250
- def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
251
- return self.delete_all_indexes_impl()
239
+ async def delete_aoss_indices(self):
240
+ for index in aoss_indices:
241
+ index_name = index['index_name']
242
+ client = self.aoss_client
243
+ if client.indices.exists(index=index_name):
244
+ client.indices.delete(index=index_name)
245
+
246
+ def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
247
+ for index in aoss_indices:
248
+ if name.lower() == index['index_name']:
249
+ index['query']['query']['multi_match']['query'] = query_text
250
+ query = {'size': limit, 'query': index['query']}
251
+ resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
252
+ return resp
253
+ return {}
254
+
255
+ def save_to_aoss(self, name: str, data: list[dict]) -> int:
256
+ for index in aoss_indices:
257
+ if name.lower() == index['index_name']:
258
+ to_index = []
259
+ for d in data:
260
+ item = {'_index': name, '_id': d['uuid']}
261
+ for p in index['body']['mappings']['properties']:
262
+ if p in d:
263
+ item[p] = d[p]
264
+ to_index.append(item)
265
+ success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
266
+ return success
267
+
268
+ return 0
252
269
 
253
270
 
254
271
  class NeptuneDriverSession(GraphDriverSession):
Remaining files were moved or renamed without content changes.