graphiti-core 0.18.8__py3-none-any.whl → 0.19.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +1 -0
- graphiti_core/driver/neptune_driver.py +299 -0
- graphiti_core/edges.py +35 -7
- graphiti_core/graphiti.py +2 -0
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/openai_base_client.py +15 -5
- graphiti_core/llm_client/openai_client.py +16 -6
- graphiti_core/migrations/__init__.py +0 -0
- graphiti_core/migrations/neo4j_node_group_labels.py +53 -0
- graphiti_core/models/edges/edge_db_queries.py +104 -54
- graphiti_core/models/nodes/node_db_queries.py +165 -65
- graphiti_core/nodes.py +121 -51
- graphiti_core/prompts/extract_edges.py +1 -0
- graphiti_core/prompts/extract_nodes.py +1 -1
- graphiti_core/search/search_utils.py +878 -267
- graphiti_core/utils/bulk_utils.py +6 -3
- graphiti_core/utils/maintenance/edge_operations.py +36 -13
- graphiti_core/utils/maintenance/graph_data_operations.py +59 -7
- graphiti_core/utils/maintenance/node_operations.py +7 -3
- {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0rc1.dist-info}/METADATA +44 -6
- {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0rc1.dist-info}/RECORD +23 -20
- {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0rc1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0rc1.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/neptune_driver.py
CHANGED
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import datetime
|
|
19
|
+
import logging
|
|
20
|
+
from collections.abc import Coroutine
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
import boto3
|
|
24
|
+
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
|
|
25
|
+
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
|
26
|
+
|
|
27
|
+
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
# Default number of hits requested by the stored full-text query templates.
DEFAULT_SIZE = 10

# One entry per OpenSearch index used by this module. Every entry carries:
#   'index_name' - the index name,
#   'body'       - the mapping passed when the index is created
#                  ('uuid' is a keyword field, every other field is text),
#   'query'      - a multi_match search template over the text fields.
aoss_indices = [
    {
        'index_name': index_name,
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    **{field: {'type': 'text'} for field in text_fields},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': list(text_fields)}},
            'size': DEFAULT_SIZE,
        },
    }
    for index_name, text_fields in (
        ('node_name_and_summary', ('name', 'summary', 'group_id')),
        ('community_name', ('name', 'group_id')),
        ('episode_content', ('content', 'source', 'source_description', 'group_id')),
        ('edge_name_and_fact', ('name', 'fact', 'group_id')),
    )
]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class NeptuneDriver(GraphDriver):
    """Graph driver that sends Cypher to Neptune and full-text work to OpenSearch.

    The Neptune client is a ``NeptuneGraph`` or ``NeptuneAnalyticsGraph``,
    chosen from the scheme of ``host``. Index management, bulk indexing and
    search go through an ``OpenSearch`` client signed for the ``aoss`` service.
    """

    provider: GraphProvider = GraphProvider.NEPTUNE

    def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
        """This initializes a NeptuneDriver for use with Neptune as a backend

        Args:
            host (str): The Neptune Database or Neptune Analytics host
            aoss_host (str): The OpenSearch host value
            port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
            aoss_port (int, optional): The OpenSearch port. Defaults to 443.

        Raises:
            ValueError: If ``host`` is empty or does not start with
                ``neptune-db://`` or ``neptune-graph://``, or if ``aoss_host``
                is empty.
        """
        if not host:
            raise ValueError('You must provide an endpoint to create a NeptuneDriver')

        if host.startswith('neptune-db://'):
            # This is a Neptune Database Cluster
            endpoint = host.replace('neptune-db://', '')
            self.client = NeptuneGraph(endpoint, port)
            logger.debug('Creating Neptune Database session for %s', host)
        elif host.startswith('neptune-graph://'):
            # This is a Neptune Analytics Graph
            graph_id = host.replace('neptune-graph://', '')
            self.client = NeptuneAnalyticsGraph(graph_id)
            logger.debug('Creating Neptune Graph session for %s', host)
        else:
            raise ValueError(
                'You must provide an endpoint to create a NeptuneDriver as either neptune-db://<endpoint> or neptune-graph://<graphid>'
            )

        if not aoss_host:
            raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')

        # Credentials and region come from the default boto3 session.
        session = boto3.Session()
        self.aoss_client = OpenSearch(
            hosts=[{'host': aoss_host, 'port': aoss_port}],
            http_auth=Urllib3AWSV4SignerAuth(
                session.get_credentials(), session.region_name, 'aoss'
            ),
            use_ssl=True,
            verify_certs=True,
            connection_class=Urllib3HttpConnection,
            pool_maxsize=20,
        )

    def _sanitize_parameters(self, query, params: dict):
        """Convert datetime parameters to ISO strings and patch the query text.

        ``params`` (and any lists/dicts nested in it) is modified in place.

        Args:
            query: A query, or a list of queries; each list element is
                processed against the same ``params``.
            params: Parameter mapping for the query.

        Returns:
            The query (converted with ``str`` if any rewrite happened), or a
            list of processed queries when ``query`` is a list.
        """
        if isinstance(query, list):
            queries = []
            for q in query:
                queries.append(self._sanitize_parameters(q, params))
            return queries
        else:
            for k, v in params.items():
                if isinstance(v, datetime.datetime):
                    # Only the value is converted here; the query text is untouched.
                    params[k] = v.isoformat()
                elif isinstance(v, list):
                    # Handle lists that might contain datetime objects
                    for i, item in enumerate(v):
                        if isinstance(item, datetime.datetime):
                            v[i] = item.isoformat()
                            # NOTE(review): plain substring replace - `$k` also matches
                            # the prefix of longer placeholders, and this runs once per
                            # datetime item. Verify the resulting Cypher.
                            query = str(query).replace(f'${k}', f'datetime(${k})')
                        if isinstance(item, dict):
                            query = self._sanitize_parameters(query, v[i])

                    # If the list contains datetime objects, we need to wrap each element with datetime()
                    # NOTE(review): any string containing 'T' is treated as a datetime
                    # here, not only the ISO strings produced above - confirm intent.
                    if any(isinstance(item, str) and 'T' in item for item in v):
                        # Create a new list expression with datetime() wrapped around each element
                        datetime_list = (
                            '['
                            + ', '.join(
                                f'datetime("{item}")'
                                if isinstance(item, str) and 'T' in item
                                else repr(item)
                                for item in v
                            )
                            + ']'
                        )
                        query = str(query).replace(f'${k}', datetime_list)
                elif isinstance(v, dict):
                    query = self._sanitize_parameters(query, v)
            return query

    async def execute_query(
        self, cypher_query_, **kwargs: Any
    ) -> tuple[dict[str, Any], None, None]:
        """Run one query, or a batch of ``(query, params)`` pairs, against Neptune.

        Args:
            cypher_query_: A single query, or a list whose items are indexed
                as ``item[0]`` (query) and ``item[1]`` (params).
            **kwargs: Parameters for the single-query form. They are not used
                when a list is passed.

        Returns:
            ``(result, None, None)``. For a list, ``result`` is the result of
            the last pair; for an empty list it is an empty list.
        """
        params = dict(kwargs)
        if isinstance(cypher_query_, list):
            # Start from an empty result so an empty batch does not hit an
            # unbound local on return.
            result: Any = []
            for q in cypher_query_:
                result, _, _ = self._run_query(q[0], q[1])
            return result, None, None
        else:
            return self._run_query(cypher_query_, params)

    def _run_query(self, cypher_query_, params):
        """Sanitize the parameters, execute the query, and log then re-raise on failure."""
        cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
        try:
            result = self.client.query(cypher_query_, params=params)
        except Exception as e:
            logger.error('Query: %s', cypher_query_)
            logger.error('Parameters: %s', params)
            logger.error('Error executing query: %s', e)
            # Bare raise keeps the original traceback.
            raise

        return result, None, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        """Return a new session wrapping this driver; ``database`` is not used."""
        return NeptuneDriverSession(driver=self)

    async def close(self) -> None:
        """Close the client held by the Neptune graph wrapper."""
        return self.client.client.close()

    async def _delete_all_data(self) -> Any:
        """Detach-delete every node in the graph."""
        return await self.execute_query('MATCH (n) DETACH DELETE n')

    def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
        """Return a coroutine that deletes all OpenSearch indices when awaited."""
        return self.delete_all_indexes_impl()

    async def delete_all_indexes_impl(self) -> Any:
        """Delete all OpenSearch indices listed in ``aoss_indices``."""
        # The inner coroutine must be awaited; returning it un-awaited meant
        # the deletion never actually ran.
        return await self.delete_aoss_indices()

    async def create_aoss_indices(self):
        """Create every index in ``aoss_indices`` that does not exist yet, then wait."""
        client = self.aoss_client
        for index in aoss_indices:
            index_name = index['index_name']
            if not client.indices.exists(index=index_name):
                client.indices.create(index=index_name, body=index['body'])
        # Sleep for 1 minute to let the index creation complete
        await asyncio.sleep(60)

    async def delete_aoss_indices(self):
        """Delete every index in ``aoss_indices`` that currently exists."""
        client = self.aoss_client
        for index in aoss_indices:
            index_name = index['index_name']
            if client.indices.exists(index=index_name):
                client.indices.delete(index=index_name)

    def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
        """Run a multi_match search against one of the known indices.

        Args:
            name: Index name, matched case-insensitively against ``aoss_indices``.
            query_text: Text to search for.
            limit: Maximum number of hits to request. Defaults to 10.

        Returns:
            The search response, or ``{}`` if ``name`` matches no known index.
        """
        target = name.lower()
        for index in aoss_indices:
            if target != index['index_name']:
                continue
            # Build a fresh request body instead of writing the query text into
            # the shared module-level template, and send the caller's limit as
            # the request size.
            fields = index['query']['query']['multi_match']['fields']
            body = {
                'query': {'multi_match': {'query': query_text, 'fields': list(fields)}},
                'size': limit,
            }
            return self.aoss_client.search(body=body, index=index['index_name'])
        return {}

    def save_to_aoss(self, name: str, data: list[dict]) -> int:
        """Bulk-index documents into one of the known indices.

        Only the properties declared in the index mapping are copied from each
        input dict; a dict missing one of them raises ``KeyError``.

        Args:
            name: Index name, matched case-insensitively against ``aoss_indices``.
            data: Documents to index.

        Returns:
            The number of documents successfully indexed, or 0 if ``name``
            matches no known index.
        """
        target = name.lower()
        for index in aoss_indices:
            if target != index['index_name']:
                continue
            properties = index['body']['mappings']['properties']
            # Use the canonical index name so a differently-cased `name` still
            # targets the index that was matched.
            to_index = [
                {'_index': index['index_name'], **{p: d[p] for p in properties}}
                for d in data
            ]
            success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
            if failed:
                logger.warning(
                    'Failed to index %s of %s documents into %s',
                    failed,
                    len(to_index),
                    index['index_name'],
                )
            return success

        return 0
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class NeptuneDriverSession(GraphDriverSession):
    """Session object that forwards every query to its ``NeptuneDriver``.

    It holds no resources of its own: entering, exiting and closing are no-ops.
    """

    def __init__(self, driver: NeptuneDriver):  # type: ignore[reportUnknownArgumentType]
        self.driver = driver

    async def __aenter__(self):
        """Enter the async context; the session itself is the context value."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Exit the async context. Nothing to release; exceptions are not suppressed."""
        return None

    async def close(self):
        """Do nothing; present so the session interface is complete."""
        return None

    async def execute_write(self, func, *args, **kwargs):
        """Await ``func`` with this session as its first argument and return its result."""
        outcome = await func(self, *args, **kwargs)
        return outcome

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Execute a query, or each query of a list in order, through the driver.

        Returns:
            The driver's result for a single query. For a list, the result of
            the last element, or ``None`` when the list is empty.
        """
        if not isinstance(query, list):
            return await self.driver.execute_query(str(query), **kwargs)

        last_result = None
        for single_query in query:
            last_result = await self.driver.execute_query(single_query, **kwargs)
        return last_result
|
graphiti_core/edges.py
CHANGED
|
@@ -24,13 +24,14 @@ from uuid import uuid4
|
|
|
24
24
|
from pydantic import BaseModel, Field
|
|
25
25
|
from typing_extensions import LiteralString
|
|
26
26
|
|
|
27
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
27
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
28
28
|
from graphiti_core.embedder import EmbedderClient
|
|
29
29
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
30
30
|
from graphiti_core.helpers import parse_db_date
|
|
31
31
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
32
32
|
COMMUNITY_EDGE_RETURN,
|
|
33
33
|
ENTITY_EDGE_RETURN,
|
|
34
|
+
ENTITY_EDGE_RETURN_NEPTUNE,
|
|
34
35
|
EPISODIC_EDGE_RETURN,
|
|
35
36
|
EPISODIC_EDGE_SAVE,
|
|
36
37
|
get_community_edge_save_query,
|
|
@@ -214,11 +215,19 @@ class EntityEdge(Edge):
|
|
|
214
215
|
return self.fact_embedding
|
|
215
216
|
|
|
216
217
|
async def load_fact_embedding(self, driver: GraphDriver):
|
|
217
|
-
|
|
218
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
219
|
+
query: LiteralString = """
|
|
220
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
221
|
+
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
|
218
222
|
"""
|
|
223
|
+
else:
|
|
224
|
+
query: LiteralString = """
|
|
219
225
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
220
226
|
RETURN e.fact_embedding AS fact_embedding
|
|
221
|
-
"""
|
|
227
|
+
"""
|
|
228
|
+
|
|
229
|
+
records, _, _ = await driver.execute_query(
|
|
230
|
+
query,
|
|
222
231
|
uuid=self.uuid,
|
|
223
232
|
routing_='r',
|
|
224
233
|
)
|
|
@@ -246,6 +255,9 @@ class EntityEdge(Edge):
|
|
|
246
255
|
|
|
247
256
|
edge_data.update(self.attributes or {})
|
|
248
257
|
|
|
258
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
259
|
+
driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
|
260
|
+
|
|
249
261
|
result = await driver.execute_query(
|
|
250
262
|
get_entity_edge_save_query(driver.provider),
|
|
251
263
|
edge_data=edge_data,
|
|
@@ -262,7 +274,11 @@ class EntityEdge(Edge):
|
|
|
262
274
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
263
275
|
RETURN
|
|
264
276
|
"""
|
|
265
|
-
+
|
|
277
|
+
+ (
|
|
278
|
+
ENTITY_EDGE_RETURN_NEPTUNE
|
|
279
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
280
|
+
else ENTITY_EDGE_RETURN
|
|
281
|
+
),
|
|
266
282
|
uuid=uuid,
|
|
267
283
|
routing_='r',
|
|
268
284
|
)
|
|
@@ -284,7 +300,11 @@ class EntityEdge(Edge):
|
|
|
284
300
|
WHERE e.uuid IN $uuids
|
|
285
301
|
RETURN
|
|
286
302
|
"""
|
|
287
|
-
+
|
|
303
|
+
+ (
|
|
304
|
+
ENTITY_EDGE_RETURN_NEPTUNE
|
|
305
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
306
|
+
else ENTITY_EDGE_RETURN
|
|
307
|
+
),
|
|
288
308
|
uuids=uuids,
|
|
289
309
|
routing_='r',
|
|
290
310
|
)
|
|
@@ -321,7 +341,11 @@ class EntityEdge(Edge):
|
|
|
321
341
|
+ """
|
|
322
342
|
RETURN
|
|
323
343
|
"""
|
|
324
|
-
+
|
|
344
|
+
+ (
|
|
345
|
+
ENTITY_EDGE_RETURN_NEPTUNE
|
|
346
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
347
|
+
else ENTITY_EDGE_RETURN
|
|
348
|
+
)
|
|
325
349
|
+ with_embeddings_query
|
|
326
350
|
+ """
|
|
327
351
|
ORDER BY e.uuid DESC
|
|
@@ -346,7 +370,11 @@ class EntityEdge(Edge):
|
|
|
346
370
|
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
|
|
347
371
|
RETURN
|
|
348
372
|
"""
|
|
349
|
-
+
|
|
373
|
+
+ (
|
|
374
|
+
ENTITY_EDGE_RETURN_NEPTUNE
|
|
375
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
376
|
+
else ENTITY_EDGE_RETURN
|
|
377
|
+
),
|
|
350
378
|
node_uuid=node_uuid,
|
|
351
379
|
routing_='r',
|
|
352
380
|
)
|
graphiti_core/graphiti.py
CHANGED
|
@@ -89,6 +89,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
|
|
89
89
|
)
|
|
90
90
|
from graphiti_core.utils.maintenance.graph_data_operations import (
|
|
91
91
|
EPISODE_WINDOW_LEN,
|
|
92
|
+
build_dynamic_indexes,
|
|
92
93
|
build_indices_and_constraints,
|
|
93
94
|
retrieve_episodes,
|
|
94
95
|
)
|
|
@@ -450,6 +451,7 @@ class Graphiti:
|
|
|
450
451
|
|
|
451
452
|
validate_excluded_entity_types(excluded_entity_types, entity_types)
|
|
452
453
|
validate_group_id(group_id)
|
|
454
|
+
await build_dynamic_indexes(self.driver, group_id)
|
|
453
455
|
|
|
454
456
|
previous_episodes = (
|
|
455
457
|
await self.retrieve_episodes(
|
|
@@ -31,8 +31,10 @@ from .errors import RateLimitError, RefusalError
|
|
|
31
31
|
|
|
32
32
|
logger = logging.getLogger(__name__)
|
|
33
33
|
|
|
34
|
-
DEFAULT_MODEL = 'gpt-
|
|
35
|
-
DEFAULT_SMALL_MODEL = 'gpt-
|
|
34
|
+
DEFAULT_MODEL = 'gpt-5-mini'
|
|
35
|
+
DEFAULT_SMALL_MODEL = 'gpt-5-nano'
|
|
36
|
+
DEFAULT_REASONING = 'minimal'
|
|
37
|
+
DEFAULT_VERBOSITY = 'low'
|
|
36
38
|
|
|
37
39
|
|
|
38
40
|
class BaseOpenAIClient(LLMClient):
|
|
@@ -51,6 +53,8 @@ class BaseOpenAIClient(LLMClient):
|
|
|
51
53
|
config: LLMConfig | None = None,
|
|
52
54
|
cache: bool = False,
|
|
53
55
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
|
56
|
+
reasoning: str | None = DEFAULT_REASONING,
|
|
57
|
+
verbosity: str | None = DEFAULT_VERBOSITY,
|
|
54
58
|
):
|
|
55
59
|
if cache:
|
|
56
60
|
raise NotImplementedError('Caching is not implemented for OpenAI-based clients')
|
|
@@ -60,6 +64,8 @@ class BaseOpenAIClient(LLMClient):
|
|
|
60
64
|
|
|
61
65
|
super().__init__(config, cache)
|
|
62
66
|
self.max_tokens = max_tokens
|
|
67
|
+
self.reasoning = reasoning
|
|
68
|
+
self.verbosity = verbosity
|
|
63
69
|
|
|
64
70
|
@abstractmethod
|
|
65
71
|
async def _create_completion(
|
|
@@ -81,6 +87,8 @@ class BaseOpenAIClient(LLMClient):
|
|
|
81
87
|
temperature: float | None,
|
|
82
88
|
max_tokens: int,
|
|
83
89
|
response_model: type[BaseModel],
|
|
90
|
+
reasoning: str | None,
|
|
91
|
+
verbosity: str | None,
|
|
84
92
|
) -> Any:
|
|
85
93
|
"""Create a structured completion using the specific client implementation."""
|
|
86
94
|
pass
|
|
@@ -107,10 +115,10 @@ class BaseOpenAIClient(LLMClient):
|
|
|
107
115
|
|
|
108
116
|
def _handle_structured_response(self, response: Any) -> dict[str, Any]:
|
|
109
117
|
"""Handle structured response parsing and validation."""
|
|
110
|
-
response_object = response.
|
|
118
|
+
response_object = response.output_text
|
|
111
119
|
|
|
112
|
-
if response_object
|
|
113
|
-
return
|
|
120
|
+
if response_object:
|
|
121
|
+
return json.loads(response_object)
|
|
114
122
|
elif response_object.refusal:
|
|
115
123
|
raise RefusalError(response_object.refusal)
|
|
116
124
|
else:
|
|
@@ -140,6 +148,8 @@ class BaseOpenAIClient(LLMClient):
|
|
|
140
148
|
temperature=self.temperature,
|
|
141
149
|
max_tokens=max_tokens or self.max_tokens,
|
|
142
150
|
response_model=response_model,
|
|
151
|
+
reasoning=self.reasoning,
|
|
152
|
+
verbosity=self.verbosity,
|
|
143
153
|
)
|
|
144
154
|
return self._handle_structured_response(response)
|
|
145
155
|
else:
|
|
@@ -21,7 +21,7 @@ from openai.types.chat import ChatCompletionMessageParam
|
|
|
21
21
|
from pydantic import BaseModel
|
|
22
22
|
|
|
23
23
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
|
24
|
-
from .openai_base_client import BaseOpenAIClient
|
|
24
|
+
from .openai_base_client import DEFAULT_REASONING, DEFAULT_VERBOSITY, BaseOpenAIClient
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class OpenAIClient(BaseOpenAIClient):
|
|
@@ -41,6 +41,8 @@ class OpenAIClient(BaseOpenAIClient):
|
|
|
41
41
|
cache: bool = False,
|
|
42
42
|
client: typing.Any = None,
|
|
43
43
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
|
44
|
+
reasoning: str = DEFAULT_REASONING,
|
|
45
|
+
verbosity: str = DEFAULT_VERBOSITY,
|
|
44
46
|
):
|
|
45
47
|
"""
|
|
46
48
|
Initialize the OpenAIClient with the provided configuration, cache setting, and client.
|
|
@@ -50,7 +52,7 @@ class OpenAIClient(BaseOpenAIClient):
|
|
|
50
52
|
cache (bool): Whether to use caching for responses. Defaults to False.
|
|
51
53
|
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
|
52
54
|
"""
|
|
53
|
-
super().__init__(config, cache, max_tokens)
|
|
55
|
+
super().__init__(config, cache, max_tokens, reasoning, verbosity)
|
|
54
56
|
|
|
55
57
|
if config is None:
|
|
56
58
|
config = LLMConfig()
|
|
@@ -67,16 +69,22 @@ class OpenAIClient(BaseOpenAIClient):
|
|
|
67
69
|
temperature: float | None,
|
|
68
70
|
max_tokens: int,
|
|
69
71
|
response_model: type[BaseModel],
|
|
72
|
+
reasoning: str | None = None,
|
|
73
|
+
verbosity: str | None = None,
|
|
70
74
|
):
|
|
71
75
|
"""Create a structured completion using OpenAI's beta parse API."""
|
|
72
|
-
|
|
76
|
+
response = await self.client.responses.parse(
|
|
73
77
|
model=model,
|
|
74
|
-
|
|
78
|
+
input=messages, # type: ignore
|
|
75
79
|
temperature=temperature,
|
|
76
|
-
|
|
77
|
-
|
|
80
|
+
max_output_tokens=max_tokens,
|
|
81
|
+
text_format=response_model, # type: ignore
|
|
82
|
+
reasoning={'effort': reasoning} if reasoning is not None else None, # type: ignore
|
|
83
|
+
text={'verbosity': verbosity} if verbosity is not None else None, # type: ignore
|
|
78
84
|
)
|
|
79
85
|
|
|
86
|
+
return response
|
|
87
|
+
|
|
80
88
|
async def _create_completion(
|
|
81
89
|
self,
|
|
82
90
|
model: str,
|
|
@@ -84,6 +92,8 @@ class OpenAIClient(BaseOpenAIClient):
|
|
|
84
92
|
temperature: float | None,
|
|
85
93
|
max_tokens: int,
|
|
86
94
|
response_model: type[BaseModel] | None = None,
|
|
95
|
+
reasoning: str | None = None,
|
|
96
|
+
verbosity: str | None = None,
|
|
87
97
|
):
|
|
88
98
|
"""Create a regular completion with JSON format."""
|
|
89
99
|
return await self.client.chat.completions.create(
|
|
File without changes
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
2
|
+
from graphiti_core.helpers import validate_group_id
|
|
3
|
+
from graphiti_core.utils.maintenance.graph_data_operations import build_dynamic_indexes
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
async def neo4j_node_group_labels(driver: GraphDriver, group_id: str, batch_size: int = 100):
    """Set a group-specific label on the Episodic, Entity and Community nodes of a group.

    After validating ``group_id`` and building the dynamic indexes for it, one
    batched query per node type is run, each in its own session, in the order
    Episodic, Entity, Community. The label passed to each query is
    ``<NodeType>_<group_id without hyphens>``.

    Args:
        driver: Driver used to open sessions and build indexes.
        group_id: Group whose nodes are matched.
        batch_size: Value passed as ``$batch_size`` to
            ``IN TRANSACTIONS OF ... ROWS``. Defaults to 100.
    """
    validate_group_id(group_id)
    await build_dynamic_indexes(driver, group_id)

    label_suffix = group_id.replace('-', '')

    for node_label in ('Episodic', 'Entity', 'Community'):
        # NOTE(review): the label is supplied as the $group_label parameter in
        # `SET n:$group_label`; confirm the target database accepts this form.
        labelling_query = (
            '\n'
            f'MATCH (n:{node_label} {{group_id: $group_id}})\n'
            'CALL {\n'
            'WITH n\n'
            'SET n:$group_label\n'
            '} IN TRANSACTIONS OF $batch_size ROWS'
        )

        async with driver.session() as session:
            await session.run(
                labelling_query,
                group_id=group_id,
                group_label=node_label + '_' + label_suffix,
                batch_size=batch_size,
            )
|