graphiti-core 0.20.3__py3-none-any.whl → 0.21.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -14,15 +14,30 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import asyncio
17
18
  import copy
18
19
  import logging
19
20
  from abc import ABC, abstractmethod
20
21
  from collections.abc import Coroutine
22
+ from datetime import datetime
21
23
  from enum import Enum
22
24
  from typing import Any
23
25
 
26
+ from graphiti_core.embedder.client import EMBEDDING_DIM
27
+
28
+ try:
29
+ from opensearchpy import OpenSearch, helpers
30
+
31
+ _HAS_OPENSEARCH = True
32
+ except ImportError:
33
+ OpenSearch = None
34
+ helpers = None
35
+ _HAS_OPENSEARCH = False
36
+
24
37
  logger = logging.getLogger(__name__)
25
38
 
39
+ DEFAULT_SIZE = 10
40
+
26
41
 
27
42
  class GraphProvider(Enum):
28
43
  NEO4J = 'neo4j'
@@ -31,6 +46,93 @@ class GraphProvider(Enum):
31
46
  NEPTUNE = 'neptune'
32
47
 
33
48
 
49
# Date pattern shared by every `date` field in the OpenSearch mappings below.
_AOSS_DATE_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.SSSZ"


def _aoss_field(field_type: str) -> dict[str, Any]:
    """Return a fresh mapping entry for a plain field such as ``keyword`` or ``text``."""
    return {'type': field_type}


def _aoss_date_field() -> dict[str, Any]:
    """Return a fresh ``date`` mapping entry using the shared timestamp pattern."""
    return {'type': 'date', 'format': _AOSS_DATE_FORMAT}


def _aoss_knn_vector_field() -> dict[str, Any]:
    """Return a fresh ``knn_vector`` mapping entry sized by ``EMBEDDING_DIM``.

    NOTE(review): OpenSearch's k-NN documentation names the size parameter
    ``dimension``; ``dims``/``index``/``similarity`` look like Elasticsearch
    ``dense_vector`` options. Values are kept exactly as released - confirm
    against the target AOSS collection before changing them.
    """
    return {
        'type': 'knn_vector',
        'dims': EMBEDDING_DIM,
        'index': True,
        'similarity': 'cosine',
        'method': {
            'engine': 'faiss',
            'space_type': 'cosinesimil',
            'name': 'hnsw',
            'parameters': {'ef_construction': 128, 'm': 16},
        },
    }


def _aoss_index(index_name: str, properties: dict[str, Any]) -> dict[str, Any]:
    """Wrap a property mapping in the ``{'index_name', 'body'}`` entry shape."""
    return {'index_name': index_name, 'body': {'mappings': {'properties': properties}}}


# Index definitions consumed by the driver's AOSS helpers: `index_name` is the
# name looked up by `save_to_aoss` (and used as the alias by
# `create_aoss_indices`), and only the keys listed under `properties` are copied
# into indexed documents.
aoss_indices = [
    _aoss_index(
        'entities',
        {
            'uuid': _aoss_field('keyword'),
            'name': _aoss_field('text'),
            'summary': _aoss_field('text'),
            'group_id': _aoss_field('text'),
            'created_at': _aoss_date_field(),
            'name_embedding': _aoss_knn_vector_field(),
        },
    ),
    _aoss_index(
        'communities',
        {
            'uuid': _aoss_field('keyword'),
            'name': _aoss_field('text'),
            'group_id': _aoss_field('text'),
        },
    ),
    _aoss_index(
        'episodes',
        {
            'uuid': _aoss_field('keyword'),
            'content': _aoss_field('text'),
            'source': _aoss_field('text'),
            'source_description': _aoss_field('text'),
            'group_id': _aoss_field('text'),
            'created_at': _aoss_date_field(),
            'valid_at': _aoss_date_field(),
        },
    ),
    _aoss_index(
        'entity_edges',
        {
            'uuid': _aoss_field('keyword'),
            'name': _aoss_field('text'),
            'fact': _aoss_field('text'),
            'group_id': _aoss_field('text'),
            'created_at': _aoss_date_field(),
            'valid_at': _aoss_date_field(),
            'expired_at': _aoss_date_field(),
            'invalid_at': _aoss_date_field(),
            'fact_embedding': _aoss_knn_vector_field(),
        },
    ),
]
134
+
135
+
34
136
  class GraphDriverSession(ABC):
35
137
  provider: GraphProvider
36
138
 
@@ -61,6 +163,7 @@ class GraphDriver(ABC):
61
163
  '' # Neo4j (default) syntax does not require a prefix for fulltext queries
62
164
  )
63
165
  _database: str
166
+ aoss_client: OpenSearch | None # type: ignore
64
167
 
65
168
  @abstractmethod
66
169
  def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@@ -87,3 +190,70 @@ class GraphDriver(ABC):
87
190
  cloned._database = database
88
191
 
89
192
  return cloned
193
+
194
async def delete_all_indexes_impl(self) -> Any:
    """Delete the OpenSearch indices managed by this driver.

    The call to ``delete_aoss_indices`` is awaited here so the deletion has
    actually run when this coroutine completes. Previously the inner coroutine
    object was returned un-awaited, so a caller that awaited this method once
    never executed the deletion.

    Returns:
        Whatever ``delete_aoss_indices`` returns.
    """
    return await self.delete_aoss_indices()
197
+
198
async def create_aoss_indices(self):
    """Create a timestamped physical index plus alias for each ``aoss_indices`` entry.

    For every entry, the entry's ``index_name`` is treated as an alias. If the
    alias already exists the entry is skipped, which makes repeated calls
    idempotent. Otherwise a physical index named ``<alias>_<UTC timestamp>`` is
    created from the entry's ``body`` and the alias is pointed at it.

    Logs a warning and returns without doing anything when the driver has no
    OpenSearch client. The 60-second settle delay is only paid when at least
    one index was actually created.
    """
    # Local import: the module imports only `datetime` from the datetime package.
    from datetime import timezone

    client = self.aoss_client
    if not client:
        logger.warning('No OpenSearch client found')
        return

    created_any = False
    for index in aoss_indices:
        alias_name = index['index_name']

        # If alias already exists, skip (idempotent behavior)
        if client.indices.exists_alias(name=alias_name):
            continue

        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # the formatted suffix is identical.
        ts_suffix = datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')
        physical_index_name = f'{alias_name}_{ts_suffix}'

        client.indices.create(index=physical_index_name, body=index['body'])
        client.indices.put_alias(index=physical_index_name, name=alias_name)
        created_any = True

    # Allow some time for index creation, but do not stall callers when every
    # alias already existed and nothing new was created.
    if created_any:
        await asyncio.sleep(60)
223
+
224
async def delete_aoss_indices(self):
    """Delete the OpenSearch indices behind every ``aoss_indices`` entry.

    ``create_aoss_indices`` creates physical indices named
    ``<alias>_<timestamp>`` and exposes them through an alias equal to the
    entry's ``index_name``. The delete-index API does not accept alias names,
    so each alias is first resolved to its physical indices and those are
    deleted. An index that exists directly under the entry name (no alias) is
    deleted as before.

    Logs a warning and returns without doing anything when the driver has no
    OpenSearch client.
    """
    client = self.aoss_client
    if not client:
        logger.warning('No OpenSearch client found')
        return

    for index in aoss_indices:
        alias_name = index['index_name']

        if client.indices.exists_alias(name=alias_name):
            # get_alias returns a dict keyed by physical index name.
            for physical_index_name in list(client.indices.get_alias(name=alias_name)):
                client.indices.delete(index=physical_index_name)
        elif client.indices.exists(index=alias_name):
            client.indices.delete(index=alias_name)
235
+
236
def save_to_aoss(self, name: str, data: list[dict]) -> int:
    """Bulk-index ``data`` into the OpenSearch index called ``name``.

    Args:
        name: Index name, matched case-insensitively against the
            ``index_name`` of the entries in ``aoss_indices``.
        data: Source records. Only keys declared in the matched index's
            mapping properties are copied; keys missing from a record are
            skipped. Each record's ``group_id`` (if any) is used as the
            routing value.

    Returns:
        The number of documents ``helpers.bulk`` reports as successfully
        indexed, or 0 when no client/bulk helper is available or ``name``
        matches no known index.
    """
    client = self.aoss_client
    if not client or not helpers:
        logger.warning('No OpenSearch client found')
        return 0

    wanted = name.lower()
    index = next((entry for entry in aoss_indices if entry['index_name'] == wanted), None)
    if index is None:
        return 0

    properties = index['body']['mappings']['properties']
    to_index = [
        {
            # Use the canonical name that was matched, not the caller's casing.
            '_index': index['index_name'],
            '_routing': d.get('group_id'),  # shard routing
            **{p: d[p] for p in properties if p in d},  # protect against missing fields
        }
        for d in data
    ]

    success, failed = helpers.bulk(client, to_index, stats_only=True)
    if failed:
        # The return value only carries the success count, so surface failures here.
        logger.warning(
            'Failed to index %s of %s documents into %s',
            failed,
            len(to_index),
            index['index_name'],
        )

    return success
@@ -74,6 +74,7 @@ class FalkorDriverSession(GraphDriverSession):
74
74
 
75
75
  class FalkorDriver(GraphDriver):
76
76
  provider = GraphProvider.FALKORDB
77
+ aoss_client: None = None
77
78
 
78
79
  def __init__(
79
80
  self,
@@ -92,6 +92,7 @@ SCHEMA_QUERIES = """
92
92
 
93
93
  class KuzuDriver(GraphDriver):
94
94
  provider: GraphProvider = GraphProvider.KUZU
95
+ aoss_client: None = None
95
96
 
96
97
  def __init__(
97
98
  self,
@@ -22,14 +22,35 @@ from neo4j import AsyncGraphDatabase, EagerResult
22
22
  from typing_extensions import LiteralString
23
23
 
24
24
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
25
+ from graphiti_core.helpers import semaphore_gather
25
26
 
26
27
  logger = logging.getLogger(__name__)
27
28
 
29
+ try:
30
+ import boto3
31
+ from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
32
+
33
+ _HAS_OPENSEARCH = True
34
+ except ImportError:
35
+ boto3 = None
36
+ OpenSearch = None
37
+ Urllib3AWSV4SignerAuth = None
38
+ Urllib3HttpConnection = None
39
+ _HAS_OPENSEARCH = False
40
+
28
41
 
29
42
  class Neo4jDriver(GraphDriver):
30
43
  provider = GraphProvider.NEO4J
31
44
 
32
- def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
45
+ def __init__(
46
+ self,
47
+ uri: str,
48
+ user: str | None,
49
+ password: str | None,
50
+ database: str = 'neo4j',
51
+ aoss_host: str | None = None,
52
+ aoss_port: int | None = None,
53
+ ):
33
54
  super().__init__()
34
55
  self.client = AsyncGraphDatabase.driver(
35
56
  uri=uri,
@@ -37,6 +58,24 @@ class Neo4jDriver(GraphDriver):
37
58
  )
38
59
  self._database = database
39
60
 
61
+ self.aoss_client = None
62
+ if aoss_host and aoss_port and boto3 is not None:
63
+ try:
64
+ session = boto3.Session()
65
+ self.aoss_client = OpenSearch( # type: ignore
66
+ hosts=[{'host': aoss_host, 'port': aoss_port}],
67
+ http_auth=Urllib3AWSV4SignerAuth( # type: ignore
68
+ session.get_credentials(), session.region_name, 'aoss'
69
+ ),
70
+ use_ssl=True,
71
+ verify_certs=True,
72
+ connection_class=Urllib3HttpConnection,
73
+ pool_maxsize=20,
74
+ ) # type: ignore
75
+ except Exception as e:
76
+ logger.warning(f'Failed to initialize OpenSearch client: {e}')
77
+ self.aoss_client = None
78
+
40
79
  async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
41
80
  # Check if database_ is provided in kwargs.
42
81
  # If not populated, set the value to retain backwards compatibility
@@ -60,7 +99,14 @@ class Neo4jDriver(GraphDriver):
60
99
  async def close(self) -> None:
61
100
  return await self.client.close()
62
101
 
63
- def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
102
+ def delete_all_indexes(self) -> Coroutine:
103
+ if self.aoss_client:
104
+ return semaphore_gather(
105
+ self.client.execute_query(
106
+ 'CALL db.indexes() YIELD name DROP INDEX name',
107
+ ),
108
+ self.delete_aoss_indices(),
109
+ )
64
110
  return self.client.execute_query(
65
111
  'CALL db.indexes() YIELD name DROP INDEX name',
66
112
  )
@@ -22,16 +22,21 @@ from typing import Any
22
22
 
23
23
  import boto3
24
24
  from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
25
- from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
25
+ from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
26
26
 
27
- from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
27
+ from graphiti_core.driver.driver import (
28
+ DEFAULT_SIZE,
29
+ GraphDriver,
30
+ GraphDriverSession,
31
+ GraphProvider,
32
+ )
28
33
 
29
34
  logger = logging.getLogger(__name__)
30
- DEFAULT_SIZE = 10
31
35
 
32
- aoss_indices = [
36
+ neptune_aoss_indices = [
33
37
  {
34
38
  'index_name': 'node_name_and_summary',
39
+ 'alias_name': 'entities',
35
40
  'body': {
36
41
  'mappings': {
37
42
  'properties': {
@@ -49,6 +54,7 @@ aoss_indices = [
49
54
  },
50
55
  {
51
56
  'index_name': 'community_name',
57
+ 'alias_name': 'communities',
52
58
  'body': {
53
59
  'mappings': {
54
60
  'properties': {
@@ -65,6 +71,7 @@ aoss_indices = [
65
71
  },
66
72
  {
67
73
  'index_name': 'episode_content',
74
+ 'alias_name': 'episodes',
68
75
  'body': {
69
76
  'mappings': {
70
77
  'properties': {
@@ -88,6 +95,7 @@ aoss_indices = [
88
95
  },
89
96
  {
90
97
  'index_name': 'edge_name_and_fact',
98
+ 'alias_name': 'facts',
91
99
  'body': {
92
100
  'mappings': {
93
101
  'properties': {
@@ -220,54 +228,27 @@ class NeptuneDriver(GraphDriver):
220
228
  async def _delete_all_data(self) -> Any:
221
229
  return await self.execute_query('MATCH (n) DETACH DELETE n')
222
230
 
223
- def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
224
- return self.delete_all_indexes_impl()
225
-
226
- async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
227
- # No matter what happens above, always return True
228
- return self.delete_aoss_indices()
229
-
230
231
  async def create_aoss_indices(self):
231
- for index in aoss_indices:
232
+ for index in neptune_aoss_indices:
232
233
  index_name = index['index_name']
233
234
  client = self.aoss_client
235
+ if not client:
236
+ raise ValueError(
237
+ 'You must provide an AOSS endpoint to create an OpenSearch driver.'
238
+ )
234
239
  if not client.indices.exists(index=index_name):
235
240
  client.indices.create(index=index_name, body=index['body'])
241
+
242
+ alias_name = index.get('alias_name', index_name)
243
+
244
+ if not client.indices.exists_alias(name=alias_name, index=index_name):
245
+ client.indices.put_alias(index=index_name, name=alias_name)
246
+
236
247
  # Sleep for 1 minute to let the index creation complete
237
248
  await asyncio.sleep(60)
238
249
 
239
- async def delete_aoss_indices(self):
240
- for index in aoss_indices:
241
- index_name = index['index_name']
242
- client = self.aoss_client
243
- if client.indices.exists(index=index_name):
244
- client.indices.delete(index=index_name)
245
-
246
- def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
247
- for index in aoss_indices:
248
- if name.lower() == index['index_name']:
249
- index['query']['query']['multi_match']['query'] = query_text
250
- query = {'size': limit, 'query': index['query']}
251
- resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
252
- return resp
253
- return {}
254
-
255
- def save_to_aoss(self, name: str, data: list[dict]) -> int:
256
- for index in aoss_indices:
257
- if name.lower() == index['index_name']:
258
- to_index = []
259
- for d in data:
260
- item = {'_index': name}
261
- for p in index['body']['mappings']['properties']:
262
- item[p] = d[p]
263
- to_index.append(item)
264
- success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
265
- if failed > 0:
266
- return success
267
- else:
268
- return 0
269
-
270
- return 0
250
+ def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
251
+ return self.delete_all_indexes_impl()
271
252
 
272
253
 
273
254
  class NeptuneDriverSession(GraphDriverSession):
graphiti_core/edges.py CHANGED
@@ -255,6 +255,21 @@ class EntityEdge(Edge):
255
255
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
256
256
  RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
257
257
  """
258
+ elif driver.aoss_client:
259
+ resp = driver.aoss_client.search(
260
+ body={
261
+ 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
262
+ 'size': 1,
263
+ },
264
+ index='entity_edges',
265
+ routing=self.group_id,
266
+ )
267
+
268
+ if resp['hits']['hits']:
269
+ self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
270
+ return
271
+ else:
272
+ raise EdgeNotFoundError(self.uuid)
258
273
 
259
274
  if driver.provider == GraphProvider.KUZU:
260
275
  query = """
@@ -292,14 +307,14 @@ class EntityEdge(Edge):
292
307
  if driver.provider == GraphProvider.KUZU:
293
308
  edge_data['attributes'] = json.dumps(self.attributes)
294
309
  result = await driver.execute_query(
295
- get_entity_edge_save_query(driver.provider),
310
+ get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
296
311
  **edge_data,
297
312
  )
298
313
  else:
299
314
  edge_data.update(self.attributes or {})
300
315
 
301
- if driver.provider == GraphProvider.NEPTUNE:
302
- driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
316
+ if driver.aoss_client:
317
+ driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
303
318
 
304
319
  result = await driver.execute_query(
305
320
  get_entity_edge_save_query(driver.provider),
@@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import os
17
18
  from abc import ABC, abstractmethod
18
19
  from collections.abc import Iterable
19
20
 
20
21
  from pydantic import BaseModel, Field
21
22
 
22
- EMBEDDING_DIM = 1024
23
+ EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM', 1024))
23
24
 
24
25
 
25
26
  class EmbedderConfig(BaseModel):
graphiti_core/graphiti.py CHANGED
@@ -122,6 +122,11 @@ class AddBulkEpisodeResults(BaseModel):
122
122
  community_edges: list[CommunityEdge]
123
123
 
124
124
 
125
class AddTripletResults(BaseModel):
    # Result container returned by `Graphiti.add_triplet`.

    # Entity nodes passed to the bulk save for the triplet.
    nodes: list[EntityNode]
    # Entity edges passed to the bulk save for the triplet.
    edges: list[EntityEdge]
128
+
129
+
125
130
  class Graphiti:
126
131
  def __init__(
127
132
  self,
@@ -1015,7 +1020,9 @@ class Graphiti:
1015
1020
 
1016
1021
  return SearchResults(edges=edges, nodes=nodes)
1017
1022
 
1018
- async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
1023
+ async def add_triplet(
1024
+ self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
1025
+ ) -> AddTripletResults:
1019
1026
  if source_node.name_embedding is None:
1020
1027
  await source_node.generate_name_embedding(self.embedder)
1021
1028
  if target_node.name_embedding is None:
@@ -1059,6 +1066,7 @@ class Graphiti:
1059
1066
  await create_entity_node_embeddings(self.embedder, nodes)
1060
1067
 
1061
1068
  await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
1069
+ return AddTripletResults(edges=edges, nodes=nodes)
1062
1070
 
1063
1071
  async def remove_episode(self, episode_uuid: str):
1064
1072
  # Find the episode to be deleted
@@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """
60
60
  """
61
61
 
62
62
 
63
- def get_entity_edge_save_query(provider: GraphProvider) -> str:
63
+ def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
64
64
  match provider:
65
65
  case GraphProvider.FALKORDB:
66
66
  return """
@@ -99,17 +99,28 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
99
99
  RETURN e.uuid AS uuid
100
100
  """
101
101
  case _: # Neo4j
102
- return """
103
- MATCH (source:Entity {uuid: $edge_data.source_uuid})
104
- MATCH (target:Entity {uuid: $edge_data.target_uuid})
105
- MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
106
- SET e = $edge_data
107
- WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
102
+ save_embedding_query = (
103
+ """WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
104
+ if not has_aoss
105
+ else ''
106
+ )
107
+ return (
108
+ (
109
+ """
110
+ MATCH (source:Entity {uuid: $edge_data.source_uuid})
111
+ MATCH (target:Entity {uuid: $edge_data.target_uuid})
112
+ MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
113
+ SET e = $edge_data
114
+ """
115
+ + save_embedding_query
116
+ )
117
+ + """
108
118
  RETURN e.uuid AS uuid
109
- """
119
+ """
120
+ )
110
121
 
111
122
 
112
- def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
123
+ def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
113
124
  match provider:
114
125
  case GraphProvider.FALKORDB:
115
126
  return """
@@ -152,15 +163,24 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
152
163
  RETURN e.uuid AS uuid
153
164
  """
154
165
  case _:
155
- return """
156
- UNWIND $entity_edges AS edge
157
- MATCH (source:Entity {uuid: edge.source_node_uuid})
158
- MATCH (target:Entity {uuid: edge.target_node_uuid})
159
- MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
160
- SET e = edge
161
- WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
166
+ save_embedding_query = (
167
+ 'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
168
+ if not has_aoss
169
+ else ''
170
+ )
171
+ return (
172
+ """
173
+ UNWIND $entity_edges AS edge
174
+ MATCH (source:Entity {uuid: edge.source_node_uuid})
175
+ MATCH (target:Entity {uuid: edge.target_node_uuid})
176
+ MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
177
+ SET e = edge
178
+ """
179
+ + save_embedding_query
180
+ + """
162
181
  RETURN edge.uuid AS uuid
163
182
  """
183
+ )
164
184
 
165
185
 
166
186
  def get_entity_edge_return_query(provider: GraphProvider) -> str:
@@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """
126
126
  """
127
127
 
128
128
 
129
- def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
129
+ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
130
130
  match provider:
131
131
  case GraphProvider.FALKORDB:
132
132
  return f"""
@@ -161,16 +161,27 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
161
161
  RETURN n.uuid AS uuid
162
162
  """
163
163
  case _:
164
- return f"""
164
+ save_embedding_query = (
165
+ 'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
166
+ if not has_aoss
167
+ else ''
168
+ )
169
+ return (
170
+ f"""
165
171
  MERGE (n:Entity {{uuid: $entity_data.uuid}})
166
172
  SET n:{labels}
167
173
  SET n = $entity_data
168
- WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
174
+ """
175
+ + save_embedding_query
176
+ + """
169
177
  RETURN n.uuid AS uuid
170
178
  """
179
+ )
171
180
 
172
181
 
173
- def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
182
+ def get_entity_node_save_bulk_query(
183
+ provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
184
+ ) -> str | Any:
174
185
  match provider:
175
186
  case GraphProvider.FALKORDB:
176
187
  queries = []
@@ -222,14 +233,23 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
222
233
  RETURN n.uuid AS uuid
223
234
  """
224
235
  case _: # Neo4j
225
- return """
226
- UNWIND $nodes AS node
227
- MERGE (n:Entity {uuid: node.uuid})
228
- SET n:$(node.labels)
229
- SET n = node
230
- WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
236
+ save_embedding_query = (
237
+ 'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
238
+ if not has_aoss
239
+ else ''
240
+ )
241
+ return (
242
+ """
243
+ UNWIND $nodes AS node
244
+ MERGE (n:Entity {uuid: node.uuid})
245
+ SET n:$(node.labels)
246
+ SET n = node
247
+ """
248
+ + save_embedding_query
249
+ + """
231
250
  RETURN n.uuid AS uuid
232
251
  """
252
+ )
233
253
 
234
254
 
235
255
  def get_entity_node_return_query(provider: GraphProvider) -> str: