graphiti-core 0.12.0rc5__tar.gz → 0.12.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/PKG-INFO +4 -3
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/README.md +1 -1
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core-0.12.2/graphiti_core/driver/__init__.py +17 -0
- graphiti_core-0.12.2/graphiti_core/driver/driver.py +66 -0
- graphiti_core-0.12.2/graphiti_core/driver/falkordb_driver.py +131 -0
- graphiti_core-0.12.2/graphiti_core/driver/neo4j_driver.py +61 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/edges.py +26 -26
- graphiti_core-0.12.2/graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core-0.12.2/graphiti_core/graph_queries.py +149 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/graphiti.py +21 -8
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/graphiti_types.py +2 -2
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/helpers.py +9 -3
- graphiti_core-0.12.2/graphiti_core/llm_client/__init__.py +22 -0
- graphiti_core-0.12.2/graphiti_core/llm_client/azure_openai_client.py +73 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/nodes.py +31 -31
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/dedupe_nodes.py +5 -1
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/extract_edges.py +2 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/extract_nodes.py +2 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/search/search.py +6 -10
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/search/search_utils.py +243 -187
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/bulk_utils.py +21 -11
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/maintenance/community_operations.py +6 -7
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/maintenance/edge_operations.py +68 -3
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/maintenance/node_operations.py +19 -5
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/pyproject.toml +9 -2
- graphiti_core-0.12.0rc5/graphiti_core/llm_client/__init__.py +0 -6
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/LICENSE +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/cross_encoder/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/cross_encoder/client.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/embedder/client.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/embedder/gemini.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/embedder/openai.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/embedder/voyage.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/llm_client/client.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/llm_client/config.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/llm_client/gemini_client.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/llm_client/openai_client.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/llm_client/openai_generic_client.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/models/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/models/edges/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/models/edges/edge_db_queries.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/models/nodes/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/models/nodes/node_db_queries.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/eval.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/prompt_helpers.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/search/search_config.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/search/search_config_recipes.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/search/search_filters.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/search/search_helpers.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/datetime_utils.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/maintenance/utils.py +0 -0
- {graphiti_core-0.12.0rc5 → graphiti_core-0.12.2}/graphiti_core/utils/ontology_utils/entity_types_utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.12.0rc5
|
|
3
|
+
Version: 0.12.2
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Paul Paliychuk
|
|
@@ -17,9 +17,10 @@ Provides-Extra: google-genai
|
|
|
17
17
|
Provides-Extra: groq
|
|
18
18
|
Requires-Dist: anthropic (>=0.49.0) ; extra == "anthropic"
|
|
19
19
|
Requires-Dist: diskcache (>=5.6.3)
|
|
20
|
+
Requires-Dist: falkordb (>=1.1.2,<2.0.0)
|
|
20
21
|
Requires-Dist: google-genai (>=1.8.0) ; extra == "google-genai"
|
|
21
22
|
Requires-Dist: groq (>=0.2.0) ; extra == "groq"
|
|
22
|
-
Requires-Dist: neo4j (>=5.
|
|
23
|
+
Requires-Dist: neo4j (>=5.26.0)
|
|
23
24
|
Requires-Dist: numpy (>=1.0.0)
|
|
24
25
|
Requires-Dist: openai (>=1.53.0)
|
|
25
26
|
Requires-Dist: pydantic (>=2.11.5)
|
|
@@ -136,7 +137,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
|
|
|
136
137
|
Requirements:
|
|
137
138
|
|
|
138
139
|
- Python 3.10 or higher
|
|
139
|
-
- Neo4j 5.26 or higher (serves as the embeddings storage backend)
|
|
140
|
+
- Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
|
|
140
141
|
- OpenAI API key (for LLM inference and embedding)
|
|
141
142
|
|
|
142
143
|
> [!IMPORTANT]
|
|
@@ -105,7 +105,7 @@ Graphiti is specifically designed to address the challenges of dynamic and frequ
|
|
|
105
105
|
Requirements:
|
|
106
106
|
|
|
107
107
|
- Python 3.10 or higher
|
|
108
|
-
- Neo4j 5.26 or higher (serves as the embeddings storage backend)
|
|
108
|
+
- Neo4j 5.26 / FalkorDB 1.1.2 or higher (serves as the embeddings storage backend)
|
|
109
109
|
- OpenAI API key (for LLM inference and embedding)
|
|
110
110
|
|
|
111
111
|
> [!IMPORTANT]
|
|
@@ -106,7 +106,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
106
106
|
if len(top_logprobs) == 0:
|
|
107
107
|
continue
|
|
108
108
|
norm_logprobs = np.exp(top_logprobs[0].logprob)
|
|
109
|
-
if
|
|
109
|
+
if top_logprobs[0].token.strip().split(' ')[0].lower() == 'true':
|
|
110
110
|
scores.append(norm_logprobs)
|
|
111
111
|
else:
|
|
112
112
|
scores.append(1 - norm_logprobs)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
__all__ = ['GraphDriver', 'Neo4jDriver', 'FalkorDriver']
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from abc import ABC, abstractmethod
|
|
19
|
+
from collections.abc import Coroutine
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GraphDriverSession(ABC):
|
|
28
|
+
async def __aenter__(self):
|
|
29
|
+
return self
|
|
30
|
+
|
|
31
|
+
@abstractmethod
|
|
32
|
+
async def __aexit__(self, exc_type, exc, tb):
|
|
33
|
+
# No cleanup needed for Falkor, but method must exist
|
|
34
|
+
pass
|
|
35
|
+
|
|
36
|
+
@abstractmethod
|
|
37
|
+
async def run(self, query: str, **kwargs: Any) -> Any:
|
|
38
|
+
raise NotImplementedError()
|
|
39
|
+
|
|
40
|
+
@abstractmethod
|
|
41
|
+
async def close(self):
|
|
42
|
+
raise NotImplementedError()
|
|
43
|
+
|
|
44
|
+
@abstractmethod
|
|
45
|
+
async def execute_write(self, func, *args, **kwargs):
|
|
46
|
+
raise NotImplementedError()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class GraphDriver(ABC):
|
|
50
|
+
provider: str
|
|
51
|
+
|
|
52
|
+
@abstractmethod
|
|
53
|
+
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
|
54
|
+
raise NotImplementedError()
|
|
55
|
+
|
|
56
|
+
@abstractmethod
|
|
57
|
+
def session(self, database: str) -> GraphDriverSession:
|
|
58
|
+
raise NotImplementedError()
|
|
59
|
+
|
|
60
|
+
@abstractmethod
|
|
61
|
+
def close(self):
|
|
62
|
+
raise NotImplementedError()
|
|
63
|
+
|
|
64
|
+
@abstractmethod
|
|
65
|
+
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
|
66
|
+
raise NotImplementedError()
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from collections.abc import Coroutine
|
|
19
|
+
from datetime import datetime
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
from falkordb import Graph as FalkorGraph # type: ignore
|
|
23
|
+
from falkordb.asyncio import FalkorDB # type: ignore
|
|
24
|
+
|
|
25
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
|
26
|
+
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class FalkorDriverSession(GraphDriverSession):
|
|
32
|
+
def __init__(self, graph: FalkorGraph):
|
|
33
|
+
self.graph = graph
|
|
34
|
+
|
|
35
|
+
async def __aenter__(self):
|
|
36
|
+
return self
|
|
37
|
+
|
|
38
|
+
async def __aexit__(self, exc_type, exc, tb):
|
|
39
|
+
# No cleanup needed for Falkor, but method must exist
|
|
40
|
+
pass
|
|
41
|
+
|
|
42
|
+
async def close(self):
|
|
43
|
+
# No explicit close needed for FalkorDB, but method must exist
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
async def execute_write(self, func, *args, **kwargs):
|
|
47
|
+
# Directly await the provided async function with `self` as the transaction/session
|
|
48
|
+
return await func(self, *args, **kwargs)
|
|
49
|
+
|
|
50
|
+
async def run(self, query: str | list, **kwargs: Any) -> Any:
|
|
51
|
+
# FalkorDB does not support argument for Label Set, so it's converted into an array of queries
|
|
52
|
+
if isinstance(query, list):
|
|
53
|
+
for cypher, params in query:
|
|
54
|
+
params = convert_datetimes_to_strings(params)
|
|
55
|
+
await self.graph.query(str(cypher), params)
|
|
56
|
+
else:
|
|
57
|
+
params = dict(kwargs)
|
|
58
|
+
params = convert_datetimes_to_strings(params)
|
|
59
|
+
await self.graph.query(str(query), params)
|
|
60
|
+
# Assuming `graph.query` is async (ideal); otherwise, wrap in executor
|
|
61
|
+
return None
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class FalkorDriver(GraphDriver):
|
|
65
|
+
provider: str = 'falkordb'
|
|
66
|
+
|
|
67
|
+
def __init__(
|
|
68
|
+
self,
|
|
69
|
+
uri: str,
|
|
70
|
+
user: str,
|
|
71
|
+
password: str,
|
|
72
|
+
):
|
|
73
|
+
super().__init__()
|
|
74
|
+
uri_parts = uri.split('://', 1)
|
|
75
|
+
uri = f'{uri_parts[0]}://{user}:{password}@{uri_parts[1]}'
|
|
76
|
+
|
|
77
|
+
self.client = FalkorDB(
|
|
78
|
+
host='your-db.falkor.cloud', port=6380, password='your_password', ssl=True
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
|
|
82
|
+
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "DEFAULT_DATABASE"
|
|
83
|
+
if graph_name is None:
|
|
84
|
+
graph_name = 'DEFAULT_DATABASE'
|
|
85
|
+
return self.client.select_graph(graph_name)
|
|
86
|
+
|
|
87
|
+
async def execute_query(self, cypher_query_, **kwargs: Any):
|
|
88
|
+
graph_name = kwargs.pop('database_', DEFAULT_DATABASE)
|
|
89
|
+
graph = self._get_graph(graph_name)
|
|
90
|
+
|
|
91
|
+
# Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
|
|
92
|
+
params = convert_datetimes_to_strings(dict(kwargs))
|
|
93
|
+
|
|
94
|
+
try:
|
|
95
|
+
result = await graph.query(cypher_query_, params)
|
|
96
|
+
except Exception as e:
|
|
97
|
+
if 'already indexed' in str(e):
|
|
98
|
+
# check if index already exists
|
|
99
|
+
logger.info(f'Index already exists: {e}')
|
|
100
|
+
return None
|
|
101
|
+
logger.error(f'Error executing FalkorDB query: {e}')
|
|
102
|
+
raise
|
|
103
|
+
|
|
104
|
+
# Convert the result header to a list of strings
|
|
105
|
+
header = [h[1].decode('utf-8') for h in result.header]
|
|
106
|
+
return result.result_set, header, None
|
|
107
|
+
|
|
108
|
+
def session(self, database: str | None) -> GraphDriverSession:
|
|
109
|
+
return FalkorDriverSession(self._get_graph(database))
|
|
110
|
+
|
|
111
|
+
async def close(self) -> None:
|
|
112
|
+
await self.client.connection.close()
|
|
113
|
+
|
|
114
|
+
async def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
|
115
|
+
return self.execute_query(
|
|
116
|
+
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
117
|
+
database_=database_,
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def convert_datetimes_to_strings(obj):
|
|
122
|
+
if isinstance(obj, dict):
|
|
123
|
+
return {k: convert_datetimes_to_strings(v) for k, v in obj.items()}
|
|
124
|
+
elif isinstance(obj, list):
|
|
125
|
+
return [convert_datetimes_to_strings(item) for item in obj]
|
|
126
|
+
elif isinstance(obj, tuple):
|
|
127
|
+
return tuple(convert_datetimes_to_strings(item) for item in obj)
|
|
128
|
+
elif isinstance(obj, datetime):
|
|
129
|
+
return obj.isoformat()
|
|
130
|
+
else:
|
|
131
|
+
return obj
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from collections.abc import Coroutine
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from neo4j import AsyncGraphDatabase
|
|
22
|
+
from typing_extensions import LiteralString
|
|
23
|
+
|
|
24
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
|
25
|
+
from graphiti_core.helpers import DEFAULT_DATABASE
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Neo4jDriver(GraphDriver):
|
|
31
|
+
provider: str = 'neo4j'
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
uri: str,
|
|
36
|
+
user: str | None,
|
|
37
|
+
password: str | None,
|
|
38
|
+
):
|
|
39
|
+
super().__init__()
|
|
40
|
+
self.client = AsyncGraphDatabase.driver(
|
|
41
|
+
uri=uri,
|
|
42
|
+
auth=(user or '', password or ''),
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> Coroutine:
|
|
46
|
+
params = kwargs.pop('params', None)
|
|
47
|
+
result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
|
|
48
|
+
|
|
49
|
+
return result
|
|
50
|
+
|
|
51
|
+
def session(self, database: str) -> GraphDriverSession:
|
|
52
|
+
return self.client.session(database=database) # type: ignore
|
|
53
|
+
|
|
54
|
+
async def close(self) -> None:
|
|
55
|
+
return await self.client.close()
|
|
56
|
+
|
|
57
|
+
def delete_all_indexes(self, database_: str = DEFAULT_DATABASE) -> Coroutine:
|
|
58
|
+
return self.client.execute_query(
|
|
59
|
+
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
60
|
+
database_=database_,
|
|
61
|
+
)
|
|
@@ -21,10 +21,10 @@ from time import time
|
|
|
21
21
|
from typing import Any
|
|
22
22
|
from uuid import uuid4
|
|
23
23
|
|
|
24
|
-
from neo4j import AsyncDriver
|
|
25
24
|
from pydantic import BaseModel, Field
|
|
26
25
|
from typing_extensions import LiteralString
|
|
27
26
|
|
|
27
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
28
28
|
from graphiti_core.embedder import EmbedderClient
|
|
29
29
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
30
30
|
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
|
|
@@ -62,9 +62,9 @@ class Edge(BaseModel, ABC):
|
|
|
62
62
|
created_at: datetime
|
|
63
63
|
|
|
64
64
|
@abstractmethod
|
|
65
|
-
async def save(self, driver: AsyncDriver): ...
|
|
65
|
+
async def save(self, driver: GraphDriver): ...
|
|
66
66
|
|
|
67
|
-
async def delete(self, driver: AsyncDriver):
|
|
67
|
+
async def delete(self, driver: GraphDriver):
|
|
68
68
|
result = await driver.execute_query(
|
|
69
69
|
"""
|
|
70
70
|
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
|
|
@@ -87,11 +87,11 @@ class Edge(BaseModel, ABC):
|
|
|
87
87
|
return False
|
|
88
88
|
|
|
89
89
|
@classmethod
|
|
90
|
-
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
|
|
90
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
91
91
|
|
|
92
92
|
|
|
93
93
|
class EpisodicEdge(Edge):
|
|
94
|
-
async def save(self, driver: AsyncDriver):
|
|
94
|
+
async def save(self, driver: GraphDriver):
|
|
95
95
|
result = await driver.execute_query(
|
|
96
96
|
EPISODIC_EDGE_SAVE,
|
|
97
97
|
episode_uuid=self.source_node_uuid,
|
|
@@ -102,12 +102,12 @@ class EpisodicEdge(Edge):
|
|
|
102
102
|
database_=DEFAULT_DATABASE,
|
|
103
103
|
)
|
|
104
104
|
|
|
105
|
-
logger.debug(f'Saved edge to
|
|
105
|
+
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
|
106
106
|
|
|
107
107
|
return result
|
|
108
108
|
|
|
109
109
|
@classmethod
|
|
110
|
-
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
110
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
111
111
|
records, _, _ = await driver.execute_query(
|
|
112
112
|
"""
|
|
113
113
|
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
@@ -130,7 +130,7 @@ class EpisodicEdge(Edge):
|
|
|
130
130
|
return edges[0]
|
|
131
131
|
|
|
132
132
|
@classmethod
|
|
133
|
-
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
|
133
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
134
134
|
records, _, _ = await driver.execute_query(
|
|
135
135
|
"""
|
|
136
136
|
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
|
@@ -156,7 +156,7 @@ class EpisodicEdge(Edge):
|
|
|
156
156
|
@classmethod
|
|
157
157
|
async def get_by_group_ids(
|
|
158
158
|
cls,
|
|
159
|
-
driver: AsyncDriver,
|
|
159
|
+
driver: GraphDriver,
|
|
160
160
|
group_ids: list[str],
|
|
161
161
|
limit: int | None = None,
|
|
162
162
|
uuid_cursor: str | None = None,
|
|
@@ -226,7 +226,7 @@ class EntityEdge(Edge):
|
|
|
226
226
|
|
|
227
227
|
return self.fact_embedding
|
|
228
228
|
|
|
229
|
-
async def load_fact_embedding(self, driver: AsyncDriver):
|
|
229
|
+
async def load_fact_embedding(self, driver: GraphDriver):
|
|
230
230
|
query: LiteralString = """
|
|
231
231
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
232
232
|
RETURN e.fact_embedding AS fact_embedding
|
|
@@ -240,7 +240,7 @@ class EntityEdge(Edge):
|
|
|
240
240
|
|
|
241
241
|
self.fact_embedding = records[0]['fact_embedding']
|
|
242
242
|
|
|
243
|
-
async def save(self, driver: AsyncDriver):
|
|
243
|
+
async def save(self, driver: GraphDriver):
|
|
244
244
|
edge_data: dict[str, Any] = {
|
|
245
245
|
'source_uuid': self.source_node_uuid,
|
|
246
246
|
'target_uuid': self.target_node_uuid,
|
|
@@ -264,12 +264,12 @@ class EntityEdge(Edge):
|
|
|
264
264
|
database_=DEFAULT_DATABASE,
|
|
265
265
|
)
|
|
266
266
|
|
|
267
|
-
logger.debug(f'Saved edge to
|
|
267
|
+
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
|
268
268
|
|
|
269
269
|
return result
|
|
270
270
|
|
|
271
271
|
@classmethod
|
|
272
|
-
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
272
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
273
273
|
records, _, _ = await driver.execute_query(
|
|
274
274
|
"""
|
|
275
275
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
@@ -287,7 +287,7 @@ class EntityEdge(Edge):
|
|
|
287
287
|
return edges[0]
|
|
288
288
|
|
|
289
289
|
@classmethod
|
|
290
|
-
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
|
290
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
291
291
|
if len(uuids) == 0:
|
|
292
292
|
return []
|
|
293
293
|
|
|
@@ -309,7 +309,7 @@ class EntityEdge(Edge):
|
|
|
309
309
|
@classmethod
|
|
310
310
|
async def get_by_group_ids(
|
|
311
311
|
cls,
|
|
312
|
-
driver: AsyncDriver,
|
|
312
|
+
driver: GraphDriver,
|
|
313
313
|
group_ids: list[str],
|
|
314
314
|
limit: int | None = None,
|
|
315
315
|
uuid_cursor: str | None = None,
|
|
@@ -342,11 +342,11 @@ class EntityEdge(Edge):
|
|
|
342
342
|
return edges
|
|
343
343
|
|
|
344
344
|
@classmethod
|
|
345
|
-
async def get_by_node_uuid(cls, driver: AsyncDriver, node_uuid: str):
|
|
345
|
+
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
|
|
346
346
|
query: LiteralString = (
|
|
347
347
|
"""
|
|
348
|
-
|
|
349
|
-
|
|
348
|
+
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
|
349
|
+
"""
|
|
350
350
|
+ ENTITY_EDGE_RETURN
|
|
351
351
|
)
|
|
352
352
|
records, _, _ = await driver.execute_query(
|
|
@@ -359,7 +359,7 @@ class EntityEdge(Edge):
|
|
|
359
359
|
|
|
360
360
|
|
|
361
361
|
class CommunityEdge(Edge):
|
|
362
|
-
async def save(self, driver: AsyncDriver):
|
|
362
|
+
async def save(self, driver: GraphDriver):
|
|
363
363
|
result = await driver.execute_query(
|
|
364
364
|
COMMUNITY_EDGE_SAVE,
|
|
365
365
|
community_uuid=self.source_node_uuid,
|
|
@@ -370,12 +370,12 @@ class CommunityEdge(Edge):
|
|
|
370
370
|
database_=DEFAULT_DATABASE,
|
|
371
371
|
)
|
|
372
372
|
|
|
373
|
-
logger.debug(f'Saved edge to
|
|
373
|
+
logger.debug(f'Saved edge to Graph: {self.uuid}')
|
|
374
374
|
|
|
375
375
|
return result
|
|
376
376
|
|
|
377
377
|
@classmethod
|
|
378
|
-
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
378
|
+
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
379
379
|
records, _, _ = await driver.execute_query(
|
|
380
380
|
"""
|
|
381
381
|
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
|
|
@@ -396,7 +396,7 @@ class CommunityEdge(Edge):
|
|
|
396
396
|
return edges[0]
|
|
397
397
|
|
|
398
398
|
@classmethod
|
|
399
|
-
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
|
|
399
|
+
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
400
400
|
records, _, _ = await driver.execute_query(
|
|
401
401
|
"""
|
|
402
402
|
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
|
@@ -420,7 +420,7 @@ class CommunityEdge(Edge):
|
|
|
420
420
|
@classmethod
|
|
421
421
|
async def get_by_group_ids(
|
|
422
422
|
cls,
|
|
423
|
-
driver: AsyncDriver,
|
|
423
|
+
driver: GraphDriver,
|
|
424
424
|
group_ids: list[str],
|
|
425
425
|
limit: int | None = None,
|
|
426
426
|
uuid_cursor: str | None = None,
|
|
@@ -463,7 +463,7 @@ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
|
463
463
|
group_id=record['group_id'],
|
|
464
464
|
source_node_uuid=record['source_node_uuid'],
|
|
465
465
|
target_node_uuid=record['target_node_uuid'],
|
|
466
|
-
created_at=record['created_at']
|
|
466
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
467
467
|
)
|
|
468
468
|
|
|
469
469
|
|
|
@@ -476,7 +476,7 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
|
476
476
|
name=record['name'],
|
|
477
477
|
group_id=record['group_id'],
|
|
478
478
|
episodes=record['episodes'],
|
|
479
|
-
created_at=record['created_at']
|
|
479
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
480
480
|
expired_at=parse_db_date(record['expired_at']),
|
|
481
481
|
valid_at=parse_db_date(record['valid_at']),
|
|
482
482
|
invalid_at=parse_db_date(record['invalid_at']),
|
|
@@ -504,7 +504,7 @@ def get_community_edge_from_record(record: Any):
|
|
|
504
504
|
group_id=record['group_id'],
|
|
505
505
|
source_node_uuid=record['source_node_uuid'],
|
|
506
506
|
target_node_uuid=record['target_node_uuid'],
|
|
507
|
-
created_at=record['created_at']
|
|
507
|
+
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
508
508
|
)
|
|
509
509
|
|
|
510
510
|
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
from openai import AsyncAzureOpenAI
|
|
21
|
+
|
|
22
|
+
from .client import EmbedderClient
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AzureOpenAIEmbedderClient(EmbedderClient):
|
|
28
|
+
"""Wrapper class for AsyncAzureOpenAI that implements the EmbedderClient interface."""
|
|
29
|
+
|
|
30
|
+
def __init__(self, azure_client: AsyncAzureOpenAI, model: str = 'text-embedding-3-small'):
|
|
31
|
+
self.azure_client = azure_client
|
|
32
|
+
self.model = model
|
|
33
|
+
|
|
34
|
+
async def create(self, input_data: str | list[str] | Any) -> list[float]:
|
|
35
|
+
"""Create embeddings using Azure OpenAI client."""
|
|
36
|
+
try:
|
|
37
|
+
# Handle different input types
|
|
38
|
+
if isinstance(input_data, str):
|
|
39
|
+
text_input = [input_data]
|
|
40
|
+
elif isinstance(input_data, list) and all(isinstance(item, str) for item in input_data):
|
|
41
|
+
text_input = input_data
|
|
42
|
+
else:
|
|
43
|
+
# Convert to string list for other types
|
|
44
|
+
text_input = [str(input_data)]
|
|
45
|
+
|
|
46
|
+
response = await self.azure_client.embeddings.create(model=self.model, input=text_input)
|
|
47
|
+
|
|
48
|
+
# Return the first embedding as a list of floats
|
|
49
|
+
return response.data[0].embedding
|
|
50
|
+
except Exception as e:
|
|
51
|
+
logger.error(f'Error in Azure OpenAI embedding: {e}')
|
|
52
|
+
raise
|
|
53
|
+
|
|
54
|
+
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
|
|
55
|
+
"""Create batch embeddings using Azure OpenAI client."""
|
|
56
|
+
try:
|
|
57
|
+
response = await self.azure_client.embeddings.create(
|
|
58
|
+
model=self.model, input=input_data_list
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
return [embedding.embedding for embedding in response.data]
|
|
62
|
+
except Exception as e:
|
|
63
|
+
logger.error(f'Error in Azure OpenAI batch embedding: {e}')
|
|
64
|
+
raise
|