graphiti-core 0.22.0rc5__py3-none-any.whl → 0.22.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -24,6 +24,9 @@ from typing import Any
24
24
 
25
25
  from dotenv import load_dotenv
26
26
 
27
+ from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
28
+ from graphiti_core.driver.search_interface.search_interface import SearchInterface
29
+
27
30
  logger = logging.getLogger(__name__)
28
31
 
29
32
  DEFAULT_SIZE = 10
@@ -73,7 +76,8 @@ class GraphDriver(ABC):
73
76
  '' # Neo4j (default) syntax does not require a prefix for fulltext queries
74
77
  )
75
78
  _database: str
76
- aoss_client: Any # type: ignore
79
+ search_interface: SearchInterface | None = None
80
+ graph_operations_interface: GraphOperationsInterface | None = None
77
81
 
78
82
  @abstractmethod
79
83
  def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@@ -109,9 +113,3 @@ class GraphDriver(ABC):
109
113
  Only implemented by providers that need custom fulltext query building.
110
114
  """
111
115
  raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
112
-
113
- async def save_to_aoss(self, name: str, data: list[dict]) -> int:
114
- return 0
115
-
116
- async def clear_aoss_indices(self):
117
- return 1
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import asyncio
18
+ import datetime
17
19
  import logging
18
20
  from typing import TYPE_CHECKING, Any
19
21
 
@@ -191,9 +193,36 @@ class FalkorDriver(GraphDriver):
191
193
  await self.client.connection.close()
192
194
 
193
195
  async def delete_all_indexes(self) -> None:
194
- await self.execute_query(
195
- 'CALL db.indexes() YIELD name DROP INDEX name',
196
- )
196
+ result = await self.execute_query('CALL db.indexes()')
197
+ if not result:
198
+ return
199
+
200
+ records, _, _ = result
201
+ drop_tasks = []
202
+
203
+ for record in records:
204
+ label = record['label']
205
+ entity_type = record['entitytype']
206
+
207
+ for field_name, index_type in record['types'].items():
208
+ if 'RANGE' in index_type:
209
+ drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))
210
+ elif 'FULLTEXT' in index_type:
211
+ if entity_type == 'NODE':
212
+ drop_tasks.append(
213
+ self.execute_query(
214
+ f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
215
+ )
216
+ )
217
+ elif entity_type == 'RELATIONSHIP':
218
+ drop_tasks.append(
219
+ self.execute_query(
220
+ f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
221
+ )
222
+ )
223
+
224
+ if drop_tasks:
225
+ await asyncio.gather(*drop_tasks)
197
226
 
198
227
  def clone(self, database: str) -> 'GraphDriver':
199
228
  """
@@ -204,6 +233,28 @@ class FalkorDriver(GraphDriver):
204
233
 
205
234
  return cloned
206
235
 
236
async def health_check(self) -> None:
    """Check FalkorDB connectivity by running a simple query.

    Returns:
        None when the probe query completes.

    Raises:
        Exception: whatever ``execute_query`` raised; it is logged and then
            re-raised unchanged.
    """
    try:
        await self.execute_query('MATCH (n) RETURN 1 LIMIT 1')
    except Exception as e:
        # Report through logging rather than print() so a library consumer
        # controls where (and whether) the failure message is emitted.
        logging.getLogger(__name__).error('FalkorDB health check failed: %s', e)
        raise
244
+
245
+ @staticmethod
246
+ def convert_datetimes_to_strings(obj):
247
+ if isinstance(obj, dict):
248
+ return {k: FalkorDriver.convert_datetimes_to_strings(v) for k, v in obj.items()}
249
+ elif isinstance(obj, list):
250
+ return [FalkorDriver.convert_datetimes_to_strings(item) for item in obj]
251
+ elif isinstance(obj, tuple):
252
+ return tuple(FalkorDriver.convert_datetimes_to_strings(item) for item in obj)
253
+ elif isinstance(obj, datetime):
254
+ return obj.isoformat()
255
+ else:
256
+ return obj
257
+
207
258
  def sanitize(self, query: str) -> str:
208
259
  """
209
260
  Replace FalkorDB special characters with whitespace.
File without changes
@@ -0,0 +1,195 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from typing import Any
18
+
19
+ from pydantic import BaseModel
20
+
21
+
22
class GraphOperationsInterface(BaseModel):
    """
    Hook points for replacing the default graph persistence behavior.

    Every method here raises NotImplementedError. The ``_cls`` parameter of the
    bulk / by-id hooks receives the model class the operation was invoked on;
    callers pass it positionally as the first argument.
    """

    # ------------------------------------------------------------------
    # Entity nodes: save / delete
    # ------------------------------------------------------------------

    async def node_save(self, node: Any, driver: Any) -> None:
        """Create or update one node."""
        raise NotImplementedError

    async def node_delete(self, node: Any, driver: Any) -> None:
        """Delete one node."""
        raise NotImplementedError

    async def node_save_bulk(
        self,
        _cls: Any,  # model class the bulk operation was invoked on
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Create or update many nodes, processed in batches."""
        raise NotImplementedError

    async def node_delete_by_group_id(
        self,
        _cls: Any,
        driver: Any,
        group_id: str,
        batch_size: int = 100,
    ) -> None:
        """Delete the nodes belonging to ``group_id``."""
        raise NotImplementedError

    async def node_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
        batch_size: int = 100,
    ) -> None:
        """Delete the nodes identified by ``uuids``."""
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Entity nodes: embedding loading
    # ------------------------------------------------------------------

    async def node_load_embeddings(self, node: Any, driver: Any) -> None:
        """
        Populate the embedding vectors of one node in place (e.g., by setting node.embedding or similar).
        """
        raise NotImplementedError

    async def node_load_embeddings_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """
        Populate the embedding vectors of many nodes in batches; the given node instances are mutated.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Episodic nodes and edges: save / delete
    # ------------------------------------------------------------------

    async def episodic_node_save(self, node: Any, driver: Any) -> None:
        """Create or update one episodic node."""
        raise NotImplementedError

    async def episodic_node_delete(self, node: Any, driver: Any) -> None:
        """Delete one episodic node."""
        raise NotImplementedError

    async def episodic_node_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Create or update many episodic nodes, processed in batches."""
        raise NotImplementedError

    async def episodic_edge_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        episodic_edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Create or update many episodic edges, processed in batches."""
        raise NotImplementedError

    async def episodic_node_delete_by_group_id(
        self,
        _cls: Any,
        driver: Any,
        group_id: str,
        batch_size: int = 100,
    ) -> None:
        """Delete the episodic nodes belonging to ``group_id``."""
        raise NotImplementedError

    async def episodic_node_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
        batch_size: int = 100,
    ) -> None:
        """Delete the episodic nodes identified by ``uuids``."""
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Edges: save / delete
    # ------------------------------------------------------------------

    async def edge_save(self, edge: Any, driver: Any) -> None:
        """Create or update one edge."""
        raise NotImplementedError

    async def edge_delete(self, edge: Any, driver: Any) -> None:
        """Delete one edge."""
        raise NotImplementedError

    async def edge_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Create or update many edges, processed in batches."""
        raise NotImplementedError

    async def edge_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
    ) -> None:
        """Delete the edges identified by ``uuids``."""
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Edges: embedding loading
    # ------------------------------------------------------------------

    async def edge_load_embeddings(self, edge: Any, driver: Any) -> None:
        """
        Populate the embedding vectors of one edge in place (e.g., by setting edge.embedding or similar).
        """
        raise NotImplementedError

    async def edge_load_embeddings_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """
        Populate the embedding vectors of many edges in batches; the given edge instances are mutated.
        """
        raise NotImplementedError
@@ -72,3 +72,12 @@ class Neo4jDriver(GraphDriver):
72
72
  return self.client.execute_query(
73
73
  'CALL db.indexes() YIELD name DROP INDEX name',
74
74
  )
75
+
76
async def health_check(self) -> None:
    """Check Neo4j connectivity by running the driver's verify_connectivity method.

    Returns:
        None when ``self.client.verify_connectivity()`` completes.

    Raises:
        Exception: whatever ``verify_connectivity`` raised; it is logged and
            then re-raised unchanged.
    """
    # Imported locally so this method does not depend on a module-level import.
    import logging

    try:
        await self.client.verify_connectivity()
    except Exception as e:
        # Report through logging rather than print() so a library consumer
        # controls where (and whether) the failure message is emitted.
        logging.getLogger(__name__).error('Neo4j health check failed: %s', e)
        raise
File without changes
@@ -0,0 +1,89 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from typing import Any
18
+
19
+ from pydantic import BaseModel
20
+
21
+
22
class SearchInterface(BaseModel):
    """
    Hook points for plugging in custom search logic.

    Every method here raises NotImplementedError.
    """

    # ---------- EDGE SEARCH (async) ----------

    async def edge_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Fulltext search over edges for ``query``."""
        raise NotImplementedError

    async def edge_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        source_node_uuid: str | None,
        target_node_uuid: str | None,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Vector similarity search over edges for ``search_vector``."""
        raise NotImplementedError

    # ---------- NODE SEARCH (async) ----------

    async def node_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Fulltext search over nodes for ``query``."""
        raise NotImplementedError

    async def node_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Vector similarity search over nodes for ``search_vector``."""
        raise NotImplementedError

    # ---------- EPISODE SEARCH (async) ----------

    async def episode_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,  # accepted for signature parity; an implementation may ignore it
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Fulltext search over episodes for ``query``."""
        raise NotImplementedError

    # ---------- SEARCH FILTERS (sync) ----------

    def build_node_search_filters(self, search_filters: Any) -> Any:
        """Build the node-side filter representation from ``search_filters``."""
        raise NotImplementedError

    def build_edge_search_filters(self, search_filters: Any) -> Any:
        """Build the edge-side filter representation from ``search_filters``."""
        raise NotImplementedError

    class Config:
        arbitrary_types_allowed = True
graphiti_core/edges.py CHANGED
@@ -25,7 +25,7 @@ from uuid import uuid4
25
25
  from pydantic import BaseModel, Field
26
26
  from typing_extensions import LiteralString
27
27
 
28
- from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
28
+ from graphiti_core.driver.driver import GraphDriver, GraphProvider
29
29
  from graphiti_core.embedder import EmbedderClient
30
30
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
31
31
  from graphiti_core.helpers import parse_db_date
@@ -53,6 +53,9 @@ class Edge(BaseModel, ABC):
53
53
  async def save(self, driver: GraphDriver): ...
54
54
 
55
55
  async def delete(self, driver: GraphDriver):
56
+ if driver.graph_operations_interface:
57
+ return await driver.graph_operations_interface.edge_delete(self, driver)
58
+
56
59
  if driver.provider == GraphProvider.KUZU:
57
60
  await driver.execute_query(
58
61
  """
@@ -77,17 +80,13 @@ class Edge(BaseModel, ABC):
77
80
  uuid=self.uuid,
78
81
  )
79
82
 
80
- if driver.aoss_client:
81
- await driver.aoss_client.delete(
82
- index=ENTITY_EDGE_INDEX_NAME,
83
- id=self.uuid,
84
- params={'routing': self.group_id},
85
- )
86
-
87
83
  logger.debug(f'Deleted Edge: {self.uuid}')
88
84
 
89
85
  @classmethod
90
86
  async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
87
+ if driver.graph_operations_interface:
88
+ return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
89
+
91
90
  if driver.provider == GraphProvider.KUZU:
92
91
  await driver.execute_query(
93
92
  """
@@ -115,12 +114,6 @@ class Edge(BaseModel, ABC):
115
114
  uuids=uuids,
116
115
  )
117
116
 
118
- if driver.aoss_client:
119
- await driver.aoss_client.delete_by_query(
120
- index=ENTITY_EDGE_INDEX_NAME,
121
- body={'query': {'terms': {'uuid': uuids}}},
122
- )
123
-
124
117
  logger.debug(f'Deleted Edges: {uuids}')
125
118
 
126
119
  def __hash__(self):
@@ -258,6 +251,9 @@ class EntityEdge(Edge):
258
251
  return self.fact_embedding
259
252
 
260
253
  async def load_fact_embedding(self, driver: GraphDriver):
254
+ if driver.graph_operations_interface:
255
+ return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
256
+
261
257
  query = """
262
258
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
263
259
  RETURN e.fact_embedding AS fact_embedding
@@ -268,21 +264,6 @@ class EntityEdge(Edge):
268
264
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
269
265
  RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
270
266
  """
271
- elif driver.aoss_client:
272
- resp = await driver.aoss_client.search(
273
- body={
274
- 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
275
- 'size': 1,
276
- },
277
- index=ENTITY_EDGE_INDEX_NAME,
278
- params={'routing': self.group_id},
279
- )
280
-
281
- if resp['hits']['hits']:
282
- self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
283
- return
284
- else:
285
- raise EdgeNotFoundError(self.uuid)
286
267
 
287
268
  if driver.provider == GraphProvider.KUZU:
288
269
  query = """
@@ -320,15 +301,11 @@ class EntityEdge(Edge):
320
301
  if driver.provider == GraphProvider.KUZU:
321
302
  edge_data['attributes'] = json.dumps(self.attributes)
322
303
  result = await driver.execute_query(
323
- get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
304
+ get_entity_edge_save_query(driver.provider),
324
305
  **edge_data,
325
306
  )
326
307
  else:
327
308
  edge_data.update(self.attributes or {})
328
-
329
- if driver.aoss_client:
330
- await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
331
-
332
309
  result = await driver.execute_query(
333
310
  get_entity_edge_save_query(driver.provider),
334
311
  edge_data=edge_data,
@@ -68,6 +68,7 @@ def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False)
68
68
  MATCH (target:Entity {uuid: $edge_data.target_uuid})
69
69
  MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
70
70
  SET e = $edge_data
71
+ SET e.fact_embedding = vecf32($edge_data.fact_embedding)
71
72
  RETURN e.uuid AS uuid
72
73
  """
73
74
  case GraphProvider.NEPTUNE:
@@ -133,6 +133,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: b
133
133
  MERGE (n:Entity {{uuid: $entity_data.uuid}})
134
134
  SET n:{labels}
135
135
  SET n = $entity_data
136
+ SET n.name_embedding = vecf32($entity_data.name_embedding)
136
137
  RETURN n.uuid AS uuid
137
138
  """
138
139
  case GraphProvider.KUZU: