graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
- graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
- graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/__init__.py +19 -0
- graphiti_core/driver/driver.py +124 -0
- graphiti_core/driver/falkordb_driver.py +362 -0
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +117 -0
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +287 -172
- graphiti_core/embedder/azure_openai.py +71 -0
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/embedder/gemini.py +116 -22
- graphiti_core/embedder/voyage.py +13 -2
- graphiti_core/errors.py +8 -0
- graphiti_core/graph_queries.py +162 -0
- graphiti_core/graphiti.py +705 -193
- graphiti_core/graphiti_types.py +4 -2
- graphiti_core/helpers.py +87 -10
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/anthropic_client.py +159 -56
- graphiti_core/llm_client/azure_openai_client.py +115 -0
- graphiti_core/llm_client/client.py +98 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +290 -41
- graphiti_core/llm_client/groq_client.py +14 -3
- graphiti_core/llm_client/openai_base_client.py +261 -0
- graphiti_core/llm_client/openai_client.py +56 -132
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +420 -205
- graphiti_core/prompts/dedupe_edges.py +46 -32
- graphiti_core/prompts/dedupe_nodes.py +67 -42
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +27 -16
- graphiti_core/prompts/extract_nodes.py +74 -31
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +158 -82
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +126 -35
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1405 -485
- graphiti_core/telemetry/__init__.py +9 -0
- graphiti_core/telemetry/telemetry.py +117 -0
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +364 -285
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +67 -49
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +339 -197
- graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
- graphiti_core/utils/maintenance/node_operations.py +319 -238
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- graphiti_core-0.24.3.dist-info/METADATA +726 -0
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
- graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
import kuzu
|
|
21
|
+
|
|
22
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
# Kuzu requires an explicit schema.
|
|
27
|
+
# As Kuzu currently does not support creating full text indexes on edge properties,
|
|
28
|
+
# we work around this by representing (n:Entity)-[:RELATES_TO]->(m:Entity) as
|
|
29
|
+
# (n)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m).
|
|
30
|
+
SCHEMA_QUERIES = """
|
|
31
|
+
CREATE NODE TABLE IF NOT EXISTS Episodic (
|
|
32
|
+
uuid STRING PRIMARY KEY,
|
|
33
|
+
name STRING,
|
|
34
|
+
group_id STRING,
|
|
35
|
+
created_at TIMESTAMP,
|
|
36
|
+
source STRING,
|
|
37
|
+
source_description STRING,
|
|
38
|
+
content STRING,
|
|
39
|
+
valid_at TIMESTAMP,
|
|
40
|
+
entity_edges STRING[]
|
|
41
|
+
);
|
|
42
|
+
CREATE NODE TABLE IF NOT EXISTS Entity (
|
|
43
|
+
uuid STRING PRIMARY KEY,
|
|
44
|
+
name STRING,
|
|
45
|
+
group_id STRING,
|
|
46
|
+
labels STRING[],
|
|
47
|
+
created_at TIMESTAMP,
|
|
48
|
+
name_embedding FLOAT[],
|
|
49
|
+
summary STRING,
|
|
50
|
+
attributes STRING
|
|
51
|
+
);
|
|
52
|
+
CREATE NODE TABLE IF NOT EXISTS Community (
|
|
53
|
+
uuid STRING PRIMARY KEY,
|
|
54
|
+
name STRING,
|
|
55
|
+
group_id STRING,
|
|
56
|
+
created_at TIMESTAMP,
|
|
57
|
+
name_embedding FLOAT[],
|
|
58
|
+
summary STRING
|
|
59
|
+
);
|
|
60
|
+
CREATE NODE TABLE IF NOT EXISTS RelatesToNode_ (
|
|
61
|
+
uuid STRING PRIMARY KEY,
|
|
62
|
+
group_id STRING,
|
|
63
|
+
created_at TIMESTAMP,
|
|
64
|
+
name STRING,
|
|
65
|
+
fact STRING,
|
|
66
|
+
fact_embedding FLOAT[],
|
|
67
|
+
episodes STRING[],
|
|
68
|
+
expired_at TIMESTAMP,
|
|
69
|
+
valid_at TIMESTAMP,
|
|
70
|
+
invalid_at TIMESTAMP,
|
|
71
|
+
attributes STRING
|
|
72
|
+
);
|
|
73
|
+
CREATE REL TABLE IF NOT EXISTS RELATES_TO(
|
|
74
|
+
FROM Entity TO RelatesToNode_,
|
|
75
|
+
FROM RelatesToNode_ TO Entity
|
|
76
|
+
);
|
|
77
|
+
CREATE REL TABLE IF NOT EXISTS MENTIONS(
|
|
78
|
+
FROM Episodic TO Entity,
|
|
79
|
+
uuid STRING PRIMARY KEY,
|
|
80
|
+
group_id STRING,
|
|
81
|
+
created_at TIMESTAMP
|
|
82
|
+
);
|
|
83
|
+
CREATE REL TABLE IF NOT EXISTS HAS_MEMBER(
|
|
84
|
+
FROM Community TO Entity,
|
|
85
|
+
FROM Community TO Community,
|
|
86
|
+
uuid STRING,
|
|
87
|
+
group_id STRING,
|
|
88
|
+
created_at TIMESTAMP
|
|
89
|
+
);
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class KuzuDriver(GraphDriver):
    """GraphDriver backed by an embedded Kuzu database.

    The schema is created eagerly at construction time because Kuzu requires
    explicit table definitions before any data can be written.
    """

    provider: GraphProvider = GraphProvider.KUZU
    # Kuzu has no OpenSearch integration; kept for interface parity with other drivers.
    aoss_client: None = None

    def __init__(
        self,
        db: str = ':memory:',
        max_concurrent_queries: int = 1,
    ):
        """Open (or create) the database at `db`, build the schema, and open an async connection."""
        super().__init__()
        self.db = kuzu.Database(db)
        self.setup_schema()
        self.client = kuzu.AsyncConnection(self.db, max_concurrent_queries=max_concurrent_queries)

    async def execute_query(
        self, cypher_query_: str, **kwargs: Any
    ) -> tuple[list[dict[str, Any]] | list[list[dict[str, Any]]], None, None]:
        """Execute a Cypher query and return its rows as dicts.

        Keyword arguments become query parameters; None values are dropped.
        The trailing (None, None) mirrors the Neo4j result-tuple shape used
        by callers of GraphDriver.execute_query.
        """
        query_params = {name: value for name, value in kwargs.items() if value is not None}
        # Kuzu does not support these parameters.
        query_params.pop('database_', None)
        query_params.pop('routing_', None)

        try:
            raw = await self.client.execute(cypher_query_, parameters=query_params)
        except Exception as e:
            # Truncate long list parameters so the log line stays readable.
            trimmed = {
                name: (value[:5] if isinstance(value, list) else value)
                for name, value in query_params.items()
            }
            logger.error(f'Error executing Kuzu query: {e}\n{cypher_query_}\n{trimmed}')
            raise

        if not raw:
            return [], None, None

        if isinstance(raw, list):
            rows = [list(single.rows_as_dict()) for single in raw]
        else:
            rows = list(raw.rows_as_dict())
        return rows, None, None  # type: ignore

    def session(self, _database: str | None = None) -> GraphDriverSession:
        """Return a session; the database argument is ignored for Kuzu."""
        return KuzuDriverSession(self)

    async def close(self):
        # Do not explicitly close the connection; rely on garbage collection.
        pass

    def delete_all_indexes(self, database_: str):
        # Nothing to delete: Kuzu's indexes are fixed by the schema.
        pass

    async def build_indices_and_constraints(self, delete_existing: bool = False):
        # Kuzu doesn't support dynamic index creation like Neo4j or FalkorDB;
        # schema and indices are created during setup_schema(). This method is
        # required by the abstract base class, so it is a deliberate no-op.
        pass

    def setup_schema(self):
        """Create all node/rel tables (idempotent thanks to IF NOT EXISTS)."""
        bootstrap = kuzu.Connection(self.db)
        bootstrap.execute(SCHEMA_QUERIES)
        bootstrap.close()
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class KuzuDriverSession(GraphDriverSession):
    """Thin session wrapper that forwards every statement to its KuzuDriver."""

    provider = GraphProvider.KUZU

    def __init__(self, driver: KuzuDriver):
        self.driver = driver

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Nothing to release; the async-context protocol just has to exist.
        pass

    async def close(self):
        # The driver's connection is shared, so the session never closes it.
        pass

    async def execute_write(self, func, *args, **kwargs):
        """Run `func`, passing this session in place of a transaction object."""
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Execute a single query, or a batch of (cypher, params) pairs."""
        if not isinstance(query, list):
            await self.driver.execute_query(query, **kwargs)
            return None
        for statement, bound_params in query:
            await self.driver.execute_query(statement, **bound_params)
        return None
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from collections.abc import Coroutine
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
from neo4j import AsyncGraphDatabase, EagerResult
|
|
22
|
+
from typing_extensions import LiteralString
|
|
23
|
+
|
|
24
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
25
|
+
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
|
|
26
|
+
from graphiti_core.helpers import semaphore_gather
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Neo4jDriver(GraphDriver):
    """GraphDriver implementation backed by a Neo4j database.

    Wraps the async Neo4j Python driver and, when an event loop is already
    running, schedules index/constraint creation at construction time.
    """

    provider = GraphProvider.NEO4J
    # Default group id applied when callers do not scope queries.
    default_group_id: str = ''

    def __init__(
        self,
        uri: str,
        user: str | None,
        password: str | None,
        database: str = 'neo4j',
    ):
        """Create the underlying async driver.

        Args:
            uri: Bolt/Neo4j connection URI.
            user: Username; None is treated as empty-string auth.
            password: Password; None is treated as empty-string auth.
            database: Database that queries run against by default.
        """
        super().__init__()
        self.client = AsyncGraphDatabase.driver(
            uri=uri,
            auth=(user or '', password or ''),
        )
        self._database = database

        # Schedule the indices and constraints to be built.
        import asyncio

        try:
            # Try to get the current event loop.
            loop = asyncio.get_running_loop()
            # Keep a reference to the task: the event loop only holds a weak
            # reference, so an unreferenced task may be garbage-collected
            # before it ever runs.
            self._index_task = loop.create_task(self.build_indices_and_constraints())
        except RuntimeError:
            # No event loop running; callers must invoke
            # build_indices_and_constraints() themselves later.
            self._index_task = None

        self.aoss_client = None

    async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
        """Execute a Cypher query against the configured database.

        Query parameters may be supplied via a `params` dict kwarg; remaining
        keyword arguments are forwarded to the neo4j driver.
        """
        params = kwargs.pop('params', None)
        if params is None:
            params = {}
        # BUG FIX: `database_` is a neo4j-driver argument, not a query
        # parameter. It was previously injected into `parameters_`, where the
        # driver treats it as an (unused) Cypher parameter and session routing
        # silently falls back to the server default database. Route it through
        # kwargs instead, still honoring a value a caller may have placed in
        # `params` for backwards compatibility.
        kwargs.setdefault('database_', params.pop('database_', self._database))

        try:
            result = await self.client.execute_query(cypher_query_, parameters_=params, **kwargs)
        except Exception as e:
            logger.error(f'Error executing Neo4j query: {e}\n{cypher_query_}\n{params}')
            raise

        return result

    def session(self, database: str | None = None) -> GraphDriverSession:
        """Open a session bound to `database`, falling back to the driver default."""
        _database = database or self._database
        return self.client.session(database=_database)  # type: ignore

    async def close(self) -> None:
        """Close the underlying Neo4j driver and its connection pool."""
        return await self.client.close()

    def delete_all_indexes(self) -> Coroutine:
        # NOTE(review): `db.indexes()` was removed in Neo4j 5 (SHOW INDEXES is
        # the replacement) and DROP INDEX cannot consume a yielded variable —
        # confirm this statement against the server version actually targeted.
        return self.client.execute_query(
            'CALL db.indexes() YIELD name DROP INDEX name',
        )

    async def build_indices_and_constraints(self, delete_existing: bool = False):
        """Create the range and fulltext indices this provider relies on.

        Args:
            delete_existing: When True, drop all existing indexes first.
        """
        if delete_existing:
            await self.delete_all_indexes()

        range_indices: list[LiteralString] = get_range_indices(self.provider)
        fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)
        index_queries: list[LiteralString] = range_indices + fulltext_indices

        # Index-creation statements are independent, so run them concurrently.
        await semaphore_gather(
            *[
                self.execute_query(
                    query,
                )
                for query in index_queries
            ]
        )

    async def health_check(self) -> None:
        """Check Neo4j connectivity by running the driver's verify_connectivity method."""
        try:
            await self.client.verify_connectivity()
            return None
        except Exception as e:
            # Log through the module logger (not print) so failures surface
            # in application logs alongside other driver errors.
            logger.error(f'Neo4j health check failed: {e}')
            raise
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import datetime
|
|
19
|
+
import logging
|
|
20
|
+
from collections.abc import Coroutine
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
import boto3
|
|
24
|
+
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
|
25
|
+
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
|
26
|
+
|
|
27
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
# Default number of hits returned by the AOSS query templates below.
DEFAULT_SIZE = 10

# Index definitions for Amazon OpenSearch Serverless (AOSS). Each entry pairs
# the index mapping used at creation time (`body`) with a multi_match query
# template used at search time (`query`).
aoss_indices = [
    # Entity nodes: searched by name and summary.
    {
        'index_name': 'node_name_and_summary',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'summary': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    # Community nodes: searched by name.
    {
        'index_name': 'community_name',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    # Episodic nodes: searched by content and source metadata.
    {
        'index_name': 'episode_content',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'content': {'type': 'text'},
                    'source': {'type': 'text'},
                    'source_description': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {
                'multi_match': {
                    'query': '',
                    'fields': ['content', 'source', 'source_description', 'group_id'],
                }
            },
            'size': DEFAULT_SIZE,
        },
    },
    # Entity edges: searched by name and fact text.
    {
        'index_name': 'edge_name_and_fact',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'fact': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class NeptuneDriver(GraphDriver):
    """GraphDriver backed by Amazon Neptune (Database or Analytics).

    Graph queries run via langchain-aws graph clients; fulltext search is
    delegated to Amazon OpenSearch Serverless (AOSS).
    """

    provider: GraphProvider = GraphProvider.NEPTUNE

    def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
        """This initializes a NeptuneDriver for use with Neptune as a backend

        Args:
            host (str): The Neptune Database or Neptune Analytics host
            aoss_host (str): The OpenSearch host value
            port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
            aoss_port (int, optional): The OpenSearch port. Defaults to 443.
        """
        if not host:
            raise ValueError('You must provide an endpoint to create a NeptuneDriver')

        if host.startswith('neptune-db://'):
            # This is a Neptune Database Cluster
            endpoint = host.replace('neptune-db://', '')
            self.client = NeptuneGraph(endpoint, port)
            logger.debug('Creating Neptune Database session for %s', host)
        elif host.startswith('neptune-graph://'):
            # This is a Neptune Analytics Graph
            graphId = host.replace('neptune-graph://', '')
            self.client = NeptuneAnalyticsGraph(graphId)
            logger.debug('Creating Neptune Graph session for %s', host)
        else:
            raise ValueError(
                'You must provide an endpoint to create a NeptuneDriver as either neptune-db://<endpoint> or neptune-graph://<graphid>'
            )

        if not aoss_host:
            raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')

        # Sign OpenSearch requests with the ambient AWS credentials.
        session = boto3.Session()
        self.aoss_client = OpenSearch(
            hosts=[{'host': aoss_host, 'port': aoss_port}],
            http_auth=Urllib3AWSV4SignerAuth(
                session.get_credentials(), session.region_name, 'aoss'
            ),
            use_ssl=True,
            verify_certs=True,
            connection_class=Urllib3HttpConnection,
            pool_maxsize=20,
        )

    def _sanitize_parameters(self, query, params: dict):
        """Inline datetime parameters into the query text.

        Neptune openCypher has no native datetime parameter type, so datetime
        values (and lists containing them) are converted to ISO strings and
        the matching `$param` references are rewritten to `datetime(...)`
        expressions. NOTE: mutates `params` in place and returns the rewritten
        query (or list of queries).
        """
        if isinstance(query, list):
            queries = []
            for q in query:
                queries.append(self._sanitize_parameters(q, params))
            return queries
        else:
            for k, v in params.items():
                if isinstance(v, datetime.datetime):
                    params[k] = v.isoformat()
                elif isinstance(v, list):
                    # Handle lists that might contain datetime objects
                    for i, item in enumerate(v):
                        if isinstance(item, datetime.datetime):
                            v[i] = item.isoformat()
                            query = str(query).replace(f'${k}', f'datetime(${k})')
                        if isinstance(item, dict):
                            query = self._sanitize_parameters(query, v[i])

                    # If the list contains datetime objects, we need to wrap each element with datetime()
                    if any(isinstance(item, str) and 'T' in item for item in v):
                        # Create a new list expression with datetime() wrapped around each element
                        datetime_list = (
                            '['
                            + ', '.join(
                                f'datetime("{item}")'
                                if isinstance(item, str) and 'T' in item
                                else repr(item)
                                for item in v
                            )
                            + ']'
                        )
                        query = str(query).replace(f'${k}', datetime_list)
                elif isinstance(v, dict):
                    query = self._sanitize_parameters(query, v)
            return query

    async def execute_query(
        self, cypher_query_, **kwargs: Any
    ) -> tuple[dict[str, Any], None, None]:
        """Run one query — or a list of (query, params) pairs — returning the last result."""
        params = dict(kwargs)
        if isinstance(cypher_query_, list):
            # BUG FIX: `result` was previously referenced after the loop and
            # raised NameError when the list was empty; default it explicitly.
            result: Any = {}
            for q in cypher_query_:
                result, _, _ = self._run_query(q[0], q[1])
            return result, None, None
        else:
            return self._run_query(cypher_query_, params)

    def _run_query(self, cypher_query_, params):
        """Sanitize parameters, run the query synchronously, and log failures."""
        cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
        try:
            result = self.client.query(cypher_query_, params=params)
        except Exception as e:
            logger.error('Query: %s', cypher_query_)
            logger.error('Parameters: %s', params)
            logger.error('Error executing query: %s', e)
            raise e

        return result, None, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        """Return a session; Neptune has no per-database scoping here."""
        return NeptuneDriverSession(driver=self)

    async def close(self) -> None:
        """Close the underlying Neptune client connection."""
        return self.client.client.close()

    async def _delete_all_data(self) -> Any:
        """Remove every node and relationship from the graph."""
        return await self.execute_query('MATCH (n) DETACH DELETE n')

    def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
        return self.delete_all_indexes_impl()

    async def delete_all_indexes_impl(self) -> Any:
        # BUG FIX: the coroutine returned by delete_aoss_indices() was handed
        # back unawaited, so the AOSS indices were never actually deleted.
        return await self.delete_aoss_indices()

    async def create_aoss_indices(self):
        """Create every AOSS index declared in aoss_indices that does not exist yet."""
        for index in aoss_indices:
            index_name = index['index_name']
            client = self.aoss_client
            if not client.indices.exists(index=index_name):
                client.indices.create(index=index_name, body=index['body'])
        # Sleep for 1 minute to let the index creation complete
        await asyncio.sleep(60)

    async def delete_aoss_indices(self):
        """Delete every AOSS index declared in aoss_indices that currently exists."""
        for index in aoss_indices:
            index_name = index['index_name']
            client = self.aoss_client
            if client.indices.exists(index=index_name):
                client.indices.delete(index=index_name)

    async def build_indices_and_constraints(self, delete_existing: bool = False):
        # Neptune uses OpenSearch (AOSS) for indexing
        if delete_existing:
            await self.delete_aoss_indices()
        await self.create_aoss_indices()

    def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
        """Run a fulltext query against the named AOSS index.

        Args:
            name: AOSS index name (matched case-insensitively).
            query_text: Text for the multi_match query.
            limit: Maximum number of hits to return.

        Returns:
            The raw OpenSearch response, or {} for an unknown index name.
        """
        for index in aoss_indices:
            if name.lower() == index['index_name']:
                # BUG FIX: build a fresh request body instead of mutating the
                # shared module-level template, and honor `limit` — previously
                # the search body always carried the template's DEFAULT_SIZE.
                body = {
                    'query': {
                        'multi_match': {
                            'query': query_text,
                            'fields': index['query']['query']['multi_match']['fields'],
                        }
                    },
                    'size': limit,
                }
                return self.aoss_client.search(body=body, index=index['index_name'])
        return {}

    def save_to_aoss(self, name: str, data: list[dict]) -> int:
        """Bulk-index `data` into the named AOSS index.

        Only properties declared in the index mapping are copied; each
        document's `uuid` becomes its OpenSearch `_id`.

        Returns:
            The count of successfully indexed documents (0 for unknown names).
        """
        for index in aoss_indices:
            if name.lower() == index['index_name']:
                to_index = []
                for d in data:
                    item = {'_index': name, '_id': d['uuid']}
                    for p in index['body']['mappings']['properties']:
                        if p in d:
                            item[p] = d[p]
                    to_index.append(item)
                success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
                return success

        return 0
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class NeptuneDriverSession(GraphDriverSession):
    """Session facade over NeptuneDriver; Neptune keeps no per-session state."""

    provider = GraphProvider.NEPTUNE

    def __init__(self, driver: NeptuneDriver):  # type: ignore[reportUnknownArgumentType]
        self.driver = driver

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Nothing to tear down; required by the async-context protocol.
        pass

    async def close(self):
        # The driver owns the underlying connection, so there is nothing to close.
        pass

    async def execute_write(self, func, *args, **kwargs):
        """Invoke `func` with this session standing in for a transaction object."""
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Execute one query, or each query in a list, returning the last result."""
        if not isinstance(query, list):
            return await self.driver.execute_query(str(query), **kwargs)
        outcome = None
        for single in query:
            outcome = await self.driver.execute_query(single, **kwargs)
        return outcome
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SearchInterface(BaseModel):
    """
    This is an interface for implementing custom search logic
    """

    async def edge_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Full-text search over edges; implementations return up to `limit` matches."""
        raise NotImplementedError

    async def edge_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        source_node_uuid: str | None,
        target_node_uuid: str | None,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Vector-similarity search over edges, optionally pinned to source/target nodes."""
        raise NotImplementedError

    async def node_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Full-text search over nodes; implementations return up to `limit` matches."""
        raise NotImplementedError

    async def node_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Vector-similarity search over nodes; hits below `min_score` are excluded."""
        raise NotImplementedError

    async def episode_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,  # kept for parity even if unused in your impl
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Full-text search over episodes; implementations return up to `limit` matches."""
        raise NotImplementedError

    # ---------- SEARCH FILTERS (sync) ----------
    def build_node_search_filters(self, search_filters: Any) -> Any:
        """Translate generic search filters into the backend's node-filter form."""
        raise NotImplementedError

    def build_edge_search_filters(self, search_filters: Any) -> Any:
        """Translate generic search filters into the backend's edge-filter form."""
        raise NotImplementedError

    class Config:
        # Pydantic v1-style config: permits fields/arguments of arbitrary
        # (non-pydantic) types such as driver objects.
        arbitrary_types_allowed = True
|