graphiti-core 0.11.6rc9__tar.gz → 0.12.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the package's registry page for more details.

Files changed (74)
  1. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/PKG-INFO +14 -5
  2. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/README.md +10 -2
  3. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  4. graphiti_core-0.12.0/graphiti_core/driver/__init__.py +17 -0
  5. graphiti_core-0.12.0/graphiti_core/driver/driver.py +66 -0
  6. graphiti_core-0.12.0/graphiti_core/driver/falkordb_driver.py +132 -0
  7. graphiti_core-0.12.0/graphiti_core/driver/neo4j_driver.py +61 -0
  8. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/edges.py +66 -40
  9. graphiti_core-0.12.0/graphiti_core/embedder/azure_openai.py +64 -0
  10. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/embedder/gemini.py +14 -3
  11. graphiti_core-0.12.0/graphiti_core/graph_queries.py +149 -0
  12. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/graphiti.py +41 -14
  13. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/graphiti_types.py +2 -2
  14. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/helpers.py +9 -4
  15. graphiti_core-0.12.0/graphiti_core/llm_client/__init__.py +22 -0
  16. graphiti_core-0.12.0/graphiti_core/llm_client/azure_openai_client.py +73 -0
  17. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/llm_client/gemini_client.py +4 -1
  18. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/models/edges/edge_db_queries.py +2 -4
  19. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/nodes.py +31 -31
  20. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/dedupe_edges.py +52 -1
  21. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/dedupe_nodes.py +79 -4
  22. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/extract_edges.py +50 -5
  23. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/invalidate_edges.py +1 -1
  24. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/search/search.py +6 -10
  25. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/search/search_filters.py +23 -9
  26. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/search/search_utils.py +250 -189
  27. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/bulk_utils.py +38 -11
  28. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/maintenance/community_operations.py +6 -7
  29. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/maintenance/edge_operations.py +149 -19
  30. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
  31. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/maintenance/node_operations.py +52 -71
  32. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/pyproject.toml +10 -3
  33. graphiti_core-0.11.6rc9/graphiti_core/llm_client/__init__.py +0 -6
  34. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/LICENSE +0 -0
  35. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/__init__.py +0 -0
  36. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/cross_encoder/__init__.py +0 -0
  37. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
  38. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/cross_encoder/client.py +0 -0
  39. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/embedder/__init__.py +0 -0
  40. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/embedder/client.py +0 -0
  41. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/embedder/openai.py +0 -0
  42. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/embedder/voyage.py +0 -0
  43. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/errors.py +0 -0
  44. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/llm_client/anthropic_client.py +0 -0
  45. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/llm_client/client.py +0 -0
  46. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/llm_client/config.py +0 -0
  47. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/llm_client/errors.py +0 -0
  48. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/llm_client/groq_client.py +0 -0
  49. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/llm_client/openai_client.py +0 -0
  50. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/llm_client/openai_generic_client.py +0 -0
  51. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/llm_client/utils.py +0 -0
  52. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/models/__init__.py +0 -0
  53. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/models/edges/__init__.py +0 -0
  54. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/models/nodes/__init__.py +0 -0
  55. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  56. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/__init__.py +0 -0
  57. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/eval.py +0 -0
  58. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  59. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/extract_nodes.py +0 -0
  60. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/lib.py +0 -0
  61. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/models.py +0 -0
  62. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/prompt_helpers.py +0 -0
  63. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/prompts/summarize_nodes.py +0 -0
  64. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/py.typed +0 -0
  65. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/search/__init__.py +0 -0
  66. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/search/search_config.py +0 -0
  67. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/search/search_config_recipes.py +0 -0
  68. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/search/search_helpers.py +0 -0
  69. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/__init__.py +0 -0
  70. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/datetime_utils.py +0 -0
  71. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/maintenance/__init__.py +0 -0
  72. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  73. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/maintenance/utils.py +0 -0
  74. {graphiti_core-0.11.6rc9 → graphiti_core-0.12.0}/graphiti_core/utils/ontology_utils/entity_types_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: graphiti-core
3
- Version: 0.11.6rc9
3
+ Version: 0.12.0
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -17,12 +17,13 @@ Provides-Extra: google-genai
17
17
  Provides-Extra: groq
18
18
  Requires-Dist: anthropic (>=0.49.0) ; extra == "anthropic"
19
19
  Requires-Dist: diskcache (>=5.6.3)
20
+ Requires-Dist: falkordb (>=1.1.2,<2.0.0)
20
21
  Requires-Dist: google-genai (>=1.8.0) ; extra == "google-genai"
21
22
  Requires-Dist: groq (>=0.2.0) ; extra == "groq"
22
- Requires-Dist: neo4j (>=5.23.0)
23
+ Requires-Dist: neo4j (>=5.26.0)
23
24
  Requires-Dist: numpy (>=1.0.0)
24
25
  Requires-Dist: openai (>=1.53.0)
25
- Requires-Dist: pydantic (>=2.8.2)
26
+ Requires-Dist: pydantic (>=2.11.5)
26
27
  Requires-Dist: python-dotenv (>=1.0.1)
27
28
  Requires-Dist: tenacity (>=9.0.0)
28
29
  Project-URL: Homepage, https://help.getzep.com/graphiti/graphiti/overview
@@ -136,7 +137,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
136
137
  Requirements:
137
138
 
138
139
  - Python 3.10 or higher
139
- - Neo4j 5.26 or higher (serves as the embeddings storage backend)
140
+ - Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
140
141
  - OpenAI API key (for LLM inference and embedding)
141
142
 
142
143
  > [!IMPORTANT]
@@ -236,7 +237,7 @@ Graphiti supports Azure OpenAI for both LLM inference and embeddings. To use Azu
236
237
  ```python
237
238
  from openai import AsyncAzureOpenAI
238
239
  from graphiti_core import Graphiti
239
- from graphiti_core.llm_client import OpenAIClient
240
+ from graphiti_core.llm_client import LLMConfig, OpenAIClient
240
241
  from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
241
242
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
242
243
 
@@ -252,12 +253,19 @@ azure_openai_client = AsyncAzureOpenAI(
252
253
  azure_endpoint=azure_endpoint
253
254
  )
254
255
 
256
+ # Create LLM Config with your Azure deployed model names
257
+ azure_llm_config = LLMConfig(
258
+ small_model="gpt-4.1-nano",
259
+ model="gpt-4.1-mini",
260
+ )
261
+
255
262
  # Initialize Graphiti with Azure OpenAI clients
256
263
  graphiti = Graphiti(
257
264
  "bolt://localhost:7687",
258
265
  "neo4j",
259
266
  "password",
260
267
  llm_client=OpenAIClient(
268
+ llm_config=azure_llm_config,
261
269
  client=azure_openai_client
262
270
  ),
263
271
  embedder=OpenAIEmbedder(
@@ -268,6 +276,7 @@ graphiti = Graphiti(
268
276
  ),
269
277
  # Optional: Configure the OpenAI cross encoder with Azure OpenAI
270
278
  cross_encoder=OpenAIRerankerClient(
279
+ llm_config=azure_llm_config,
271
280
  client=azure_openai_client
272
281
  )
273
282
  )
@@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
105
105
  Requirements:
106
106
 
107
107
  - Python 3.10 or higher
108
- - Neo4j 5.26 or higher (serves as the embeddings storage backend)
108
+ - Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
109
109
  - OpenAI API key (for LLM inference and embedding)
110
110
 
111
111
  > [!IMPORTANT]
@@ -205,7 +205,7 @@ Graphiti supports Azure OpenAI for both LLM inference and embeddings. To use Azu
205
205
  ```python
206
206
  from openai import AsyncAzureOpenAI
207
207
  from graphiti_core import Graphiti
208
- from graphiti_core.llm_client import OpenAIClient
208
+ from graphiti_core.llm_client import LLMConfig, OpenAIClient
209
209
  from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
210
210
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
211
211
 
@@ -221,12 +221,19 @@ azure_openai_client = AsyncAzureOpenAI(
221
221
  azure_endpoint=azure_endpoint
222
222
  )
223
223
 
224
+ # Create LLM Config with your Azure deployed model names
225
+ azure_llm_config = LLMConfig(
226
+ small_model="gpt-4.1-nano",
227
+ model="gpt-4.1-mini",
228
+ )
229
+
224
230
  # Initialize Graphiti with Azure OpenAI clients
225
231
  graphiti = Graphiti(
226
232
  "bolt://localhost:7687",
227
233
  "neo4j",
228
234
  "password",
229
235
  llm_client=OpenAIClient(
236
+ llm_config=azure_llm_config,
230
237
  client=azure_openai_client
231
238
  ),
232
239
  embedder=OpenAIEmbedder(
@@ -237,6 +244,7 @@ graphiti = Graphiti(
237
244
  ),
238
245
  # Optional: Configure the OpenAI cross encoder with Azure OpenAI
239
246
  cross_encoder=OpenAIRerankerClient(
247
+ llm_config=azure_llm_config,
240
248
  client=azure_openai_client
241
249
  )
242
250
  )
@@ -106,7 +106,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
106
106
  if len(top_logprobs) == 0:
107
107
  continue
108
108
  norm_logprobs = np.exp(top_logprobs[0].logprob)
109
- if bool(top_logprobs[0].token):
109
+ if top_logprobs[0].token.strip().split(' ')[0].lower() == 'true':
110
110
  scores.append(norm_logprobs)
111
111
  else:
112
112
  scores.append(1 - norm_logprobs)
@@ -0,0 +1,17 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ __all__ = ['GraphDriver', 'Neo4jDriver', 'FalkorDriver']
@@ -0,0 +1,66 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from abc import ABC, abstractmethod
19
+ from collections.abc import Coroutine
20
+ from typing import Any
21
+
22
+ from graphiti_core.helpers import DEFAULT_DATABASE
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class GraphDriverSession(ABC):
28
+ async def __aenter__(self):
29
+ return self
30
+
31
+ @abstractmethod
32
+ async def __aexit__(self, exc_type, exc, tb):
33
+ # No cleanup needed for Falkor, but method must exist
34
+ pass
35
+
36
+ @abstractmethod
37
+ async def run(self, query: str, **kwargs: Any) -> Any:
38
+ raise NotImplementedError()
39
+
40
+ @abstractmethod
41
+ async def close(self):
42
+ raise NotImplementedError()
43
+
44
+ @abstractmethod
45
+ async def execute_write(self, func, *args, **kwargs):
46
+ raise NotImplementedError()
47
+
48
+
49
+ class GraphDriver(ABC):
50
+ provider: str
51
+
52
+ @abstractmethod
53
+ def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
54
+ raise NotImplementedError()
55
+
56
+ @abstractmethod
57
+ def session(self, database: str) -> GraphDriverSession:
58
+ raise NotImplementedError()
59
+
60
+ @abstractmethod
61
+ def close(self):
62
+ raise NotImplementedError()
63
+
64
+ @abstractmethod
65
+ def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
66
+ raise NotImplementedError()
@@ -0,0 +1,132 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from collections.abc import Coroutine
19
+ from datetime import datetime
20
+ from typing import Any
21
+
22
+ from falkordb import Graph as FalkorGraph # type: ignore
23
+ from falkordb.asyncio import FalkorDB # type: ignore
24
+
25
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
26
+ from graphiti_core.helpers import DEFAULT_DATABASE
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class FalkorDriverSession(GraphDriverSession):
32
+ def __init__(self, graph: FalkorGraph):
33
+ self.graph = graph
34
+
35
+ async def __aenter__(self):
36
+ return self
37
+
38
+ async def __aexit__(self, exc_type, exc, tb):
39
+ # No cleanup needed for Falkor, but method must exist
40
+ pass
41
+
42
+ async def close(self):
43
+ # No explicit close needed for FalkorDB, but method must exist
44
+ pass
45
+
46
+ async def execute_write(self, func, *args, **kwargs):
47
+ # Directly await the provided async function with `self` as the transaction/session
48
+ return await func(self, *args, **kwargs)
49
+
50
+ async def run(self, query: str | list, **kwargs: Any) -> Any:
51
+ # FalkorDB does not support argument for Label Set, so it's converted into an array of queries
52
+ if isinstance(query, list):
53
+ for cypher, params in query:
54
+ params = convert_datetimes_to_strings(params)
55
+ await self.graph.query(str(cypher), params)
56
+ else:
57
+ params = dict(kwargs)
58
+ params = convert_datetimes_to_strings(params)
59
+ await self.graph.query(str(query), params)
60
+ # Assuming `graph.query` is async (ideal); otherwise, wrap in executor
61
+ return None
62
+
63
+
64
+ class FalkorDriver(GraphDriver):
65
+ provider: str = 'falkordb'
66
+
67
+ def __init__(
68
+ self,
69
+ uri: str,
70
+ user: str,
71
+ password: str,
72
+ ):
73
+ super().__init__()
74
+ if user and password:
75
+ uri_parts = uri.split('://', 1)
76
+ uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'
77
+
78
+ self.client = FalkorDB.from_url(
79
+ url=uri,
80
+ )
81
+
82
+ def _get_graph(self, graph_name: str | None) -> FalkorGraph:
83
+ # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
84
+ if graph_name is None:
85
+ graph_name = 'DEFAULT_DATABASE'
86
+ return self.client.select_graph(graph_name)
87
+
88
+ async def execute_query(self, cypher_query_, **kwargs: Any):
89
+ graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
90
+ graph = self._get_graph(graph_name)
91
+
92
+ # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
93
+ params = convert_datetimes_to_strings(dict(kwargs))
94
+
95
+ try:
96
+ result = await graph.query(cypher_query_, params)
97
+ except Exception as e:
98
+ if 'already indexed' in str(e):
99
+ # check if index already exists
100
+ logger.info(f'Index already exists: {e}')
101
+ return None
102
+ logger.error(f'Error executing FalkorDB query: {e}')
103
+ raise
104
+
105
+ # Convert the result header to a list of strings
106
+ header = [h[1].decode('utf-8') for h in result.header]
107
+ return result.result_set, header, None
108
+
109
+ def session(self, database: str | None) -> GraphDriverSession:
110
+ return FalkorDriverSession(self._get_graph(database))
111
+
112
+ async def close(self) -> None:
113
+ await self.client.connection.close()
114
+
115
+ async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
116
+ return self.execute_query(
117
+ 'CALL db.indexes() YIELD name DROP INDEX name',
118
+ database_=database_,
119
+ )
120
+
121
+
122
+ def convert_datetimes_to_strings(obj):
123
+ if isinstance(obj, dict):
124
+ return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
125
+ elif isinstance(obj, list):
126
+ return [convert_datetimes_to_strings(item) for item in obj]
127
+ elif isinstance(obj, tuple):
128
+ return tuple(convert_datetimes_to_strings(item) for item in obj)
129
+ elif isinstance(obj, datetime):
130
+ return obj.isoformat()
131
+ else:
132
+ return obj
@@ -0,0 +1,61 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import logging
18
+ from collections.abc import Coroutine
19
+ from typing import Any
20
+
21
+ from neo4j import AsyncGraphDatabase
22
+ from typing_extensions import LiteralString
23
+
24
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
25
+ from graphiti_core.helpers import DEFAULT_DATABASE
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class Neo4jDriver(GraphDriver):
31
+ provider: str = 'neo4j'
32
+
33
+ def __init__(
34
+ self,
35
+ uri: str,
36
+ user: str | None,
37
+ password: str | None,
38
+ ):
39
+ super().__init__()
40
+ self.client = AsyncGraphDatabase.driver(
41
+ uri=uri,
42
+ auth=(user or '', password or ''),
43
+ )
44
+
45
+ async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:
46
+ params = kwargs.pop('params', None)
47
+ result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
48
+
49
+ return result
50
+
51
+ def session(self, database: str) -> GraphDriverSession:
52
+ return self.client.session(database=database) # type: ignore
53
+
54
+ async def close(self) -> None:
55
+ return await self.client.close()
56
+
57
+ def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
58
+ return self.client.execute_query(
59
+ 'CALL db.indexes() YIELD name DROP INDEX name',
60
+ database_=database_,
61
+ )
@@ -21,10 +21,10 @@ from time import time
21
21
  from typing import Any
22
22
  from uuid import uuid4
23
23
 
24
- from neo4j import AsyncDriver
25
24
  from pydantic import BaseModel, Field
26
25
  from typing_extensions import LiteralString
27
26
 
27
+ from graphiti_core.driver.driver import GraphDriver
28
28
  from graphiti_core.embedder import EmbedderClient
29
29
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
30
  from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
@@ -49,7 +49,9 @@ ENTITY_EDGE_RETURN: LiteralString = """
49
49
  e.episodes AS episodes,
50
50
  e.expired_at AS expired_at,
51
51
  e.valid_at AS valid_at,
52
- e.invalid_at AS invalid_at"""
52
+ e.invalid_at AS invalid_at,
53
+ properties(e) AS attributes
54
+ """
53
55
 
54
56
 
55
57
  class Edge(BaseModel, ABC):
@@ -60,9 +62,9 @@ class Edge(BaseModel, ABC):
60
62
  created_at: datetime
61
63
 
62
64
  @abstractmethod
63
- async def save(self, driver: AsyncDriver): ...
65
+ async def save(self, driver: GraphDriver): ...
64
66
 
65
- async def delete(self, driver: AsyncDriver):
67
+ async def delete(self, driver: GraphDriver):
66
68
  result = await driver.execute_query(
67
69
  """
68
70
  MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
@@ -85,11 +87,11 @@ class Edge(BaseModel, ABC):
85
87
  return False
86
88
 
87
89
  @classmethod
88
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
90
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
89
91
 
90
92
 
91
93
  class EpisodicEdge(Edge):
92
- async def save(self, driver: AsyncDriver):
94
+ async def save(self, driver: GraphDriver):
93
95
  result = await driver.execute_query(
94
96
  EPISODIC_EDGE_SAVE,
95
97
  episode_uuid=self.source_node_uuid,
@@ -100,12 +102,12 @@ class EpisodicEdge(Edge):
100
102
  database_=DEFAULT_DATABASE,
101
103
  )
102
104
 
103
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
105
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
104
106
 
105
107
  return result
106
108
 
107
109
  @classmethod
108
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
110
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
109
111
  records, _, _ = await driver.execute_query(
110
112
  """
111
113
  MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
@@ -128,7 +130,7 @@ class EpisodicEdge(Edge):
128
130
  return edges[0]
129
131
 
130
132
  @classmethod
131
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
133
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
132
134
  records, _, _ = await driver.execute_query(
133
135
  """
134
136
  MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
@@ -154,7 +156,7 @@ class EpisodicEdge(Edge):
154
156
  @classmethod
155
157
  async def get_by_group_ids(
156
158
  cls,
157
- driver: AsyncDriver,
159
+ driver: GraphDriver,
158
160
  group_ids: list[str],
159
161
  limit: int | None = None,
160
162
  uuid_cursor: str | None = None,
@@ -209,6 +211,9 @@ class EntityEdge(Edge):
209
211
  invalid_at: datetime | None = Field(
210
212
  default=None, description='datetime of when the fact stopped being true'
211
213
  )
214
+ attributes: dict[str, Any] = Field(
215
+ default={}, description='Additional attributes of the edge. Dependent on edge name'
216
+ )
212
217
 
213
218
  async def generate_embedding(self, embedder: EmbedderClient):
214
219
  start = time()
@@ -221,7 +226,7 @@ class EntityEdge(Edge):
221
226
 
222
227
  return self.fact_embedding
223
228
 
224
- async def load_fact_embedding(self, driver: AsyncDriver):
229
+ async def load_fact_embedding(self, driver: GraphDriver):
225
230
  query: LiteralString = """
226
231
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
227
232
  RETURN e.fact_embedding AS fact_embedding
@@ -235,30 +240,36 @@ class EntityEdge(Edge):
235
240
 
236
241
  self.fact_embedding = records[0]['fact_embedding']
237
242
 
238
- async def save(self, driver: AsyncDriver):
243
+ async def save(self, driver: GraphDriver):
244
+ edge_data: dict[str, Any] = {
245
+ 'source_uuid': self.source_node_uuid,
246
+ 'target_uuid': self.target_node_uuid,
247
+ 'uuid': self.uuid,
248
+ 'name': self.name,
249
+ 'group_id': self.group_id,
250
+ 'fact': self.fact,
251
+ 'fact_embedding': self.fact_embedding,
252
+ 'episodes': self.episodes,
253
+ 'created_at': self.created_at,
254
+ 'expired_at': self.expired_at,
255
+ 'valid_at': self.valid_at,
256
+ 'invalid_at': self.invalid_at,
257
+ }
258
+
259
+ edge_data.update(self.attributes or {})
260
+
239
261
  result = await driver.execute_query(
240
262
  ENTITY_EDGE_SAVE,
241
- source_uuid=self.source_node_uuid,
242
- target_uuid=self.target_node_uuid,
243
- uuid=self.uuid,
244
- name=self.name,
245
- group_id=self.group_id,
246
- fact=self.fact,
247
- fact_embedding=self.fact_embedding,
248
- episodes=self.episodes,
249
- created_at=self.created_at,
250
- expired_at=self.expired_at,
251
- valid_at=self.valid_at,
252
- invalid_at=self.invalid_at,
263
+ edge_data=edge_data,
253
264
  database_=DEFAULT_DATABASE,
254
265
  )
255
266
 
256
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
267
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
257
268
 
258
269
  return result
259
270
 
260
271
  @classmethod
261
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
272
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
262
273
  records, _, _ = await driver.execute_query(
263
274
  """
264
275
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
@@ -276,7 +287,7 @@ class EntityEdge(Edge):
276
287
  return edges[0]
277
288
 
278
289
  @classmethod
279
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
290
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
280
291
  if len(uuids) == 0:
281
292
  return []
282
293
 
@@ -298,7 +309,7 @@ class EntityEdge(Edge):
298
309
  @classmethod
299
310
  async def get_by_group_ids(
300
311
  cls,
301
- driver: AsyncDriver,
312
+ driver: GraphDriver,
302
313
  group_ids: list[str],
303
314
  limit: int | None = None,
304
315
  uuid_cursor: str | None = None,
@@ -331,11 +342,11 @@ class EntityEdge(Edge):
331
342
  return edges
332
343
 
333
344
  @classmethod
334
- async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
345
+ async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
335
346
  query: LiteralString = (
336
347
  """
337
- MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
338
- """
348
+ MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
349
+ """
339
350
  + ENTITY_EDGE_RETURN
340
351
  )
341
352
  records, _, _ = await driver.execute_query(
@@ -348,7 +359,7 @@ class EntityEdge(Edge):
348
359
 
349
360
 
350
361
  class CommunityEdge(Edge):
351
- async def save(self, driver: AsyncDriver):
362
+ async def save(self, driver: GraphDriver):
352
363
  result = await driver.execute_query(
353
364
  COMMUNITY_EDGE_SAVE,
354
365
  community_uuid=self.source_node_uuid,
@@ -359,12 +370,12 @@ class CommunityEdge(Edge):
359
370
  database_=DEFAULT_DATABASE,
360
371
  )
361
372
 
362
- logger.debug(f'Saved edge to neo4j: {self.uuid}')
373
+ logger.debug(f'Saved edge to Graph: {self.uuid}')
363
374
 
364
375
  return result
365
376
 
366
377
  @classmethod
367
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
378
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
368
379
  records, _, _ = await driver.execute_query(
369
380
  """
370
381
  MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
@@ -385,7 +396,7 @@ class CommunityEdge(Edge):
385
396
  return edges[0]
386
397
 
387
398
  @classmethod
388
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
399
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
389
400
  records, _, _ = await driver.execute_query(
390
401
  """
391
402
  MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
@@ -409,7 +420,7 @@ class CommunityEdge(Edge):
409
420
  @classmethod
410
421
  async def get_by_group_ids(
411
422
  cls,
412
- driver: AsyncDriver,
423
+ driver: GraphDriver,
413
424
  group_ids: list[str],
414
425
  limit: int | None = None,
415
426
  uuid_cursor: str | None = None,
@@ -452,12 +463,12 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
452
463
  group_id=record['group_id'],
453
464
  source_node_uuid=record['source_node_uuid'],
454
465
  target_node_uuid=record['target_node_uuid'],
455
- created_at=record['created_at'].to_native(),
466
+ created_at=parse_db_date(record['created_at']), # type: ignore
456
467
  )
457
468
 
458
469
 
459
470
  def get_entity_edge_from_record(record: Any) -> EntityEdge:
460
- return EntityEdge(
471
+ edge = EntityEdge(
461
472
  uuid=record['uuid'],
462
473
  source_node_uuid=record['source_node_uuid'],
463
474
  target_node_uuid=record['target_node_uuid'],
@@ -465,12 +476,27 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
465
476
  name=record['name'],
466
477
  group_id=record['group_id'],
467
478
  episodes=record['episodes'],
468
- created_at=record['created_at'].to_native(),
479
+ created_at=parse_db_date(record['created_at']), # type: ignore
469
480
  expired_at=parse_db_date(record['expired_at']),
470
481
  valid_at=parse_db_date(record['valid_at']),
471
482
  invalid_at=parse_db_date(record['invalid_at']),
483
+ attributes=record['attributes'],
472
484
  )
473
485
 
486
+ edge.attributes.pop('uuid', None)
487
+ edge.attributes.pop('source_node_uuid', None)
488
+ edge.attributes.pop('target_node_uuid', None)
489
+ edge.attributes.pop('fact', None)
490
+ edge.attributes.pop('name', None)
491
+ edge.attributes.pop('group_id', None)
492
+ edge.attributes.pop('episodes', None)
493
+ edge.attributes.pop('created_at', None)
494
+ edge.attributes.pop('expired_at', None)
495
+ edge.attributes.pop('valid_at', None)
496
+ edge.attributes.pop('invalid_at', None)
497
+
498
+ return edge
499
+
474
500
 
475
501
  def get_community_edge_from_record(record: Any):
476
502
  return CommunityEdge(
@@ -478,7 +504,7 @@ def get_community_edge_from_record(record: Any):
478
504
  group_id=record['group_id'],
479
505
  source_node_uuid=record['source_node_uuid'],
480
506
  target_node_uuid=record['target_node_uuid'],
481
- created_at=record['created_at'].to_native(),
507
+ created_at=parse_db_date(record['created_at']), # type: ignore
482
508
  )
483
509
 
484
510