graphiti-core 0.21.0rc13__py3-none-any.whl → 0.22.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic.
- graphiti_core/driver/driver.py +7 -212
- graphiti_core/driver/neo4j_driver.py +0 -49
- graphiti_core/driver/neptune_driver.py +43 -26
- graphiti_core/llm_client/client.py +7 -2
- graphiti_core/llm_client/gemini_client.py +3 -1
- graphiti_core/llm_client/openai_base_client.py +2 -1
- graphiti_core/llm_client/openai_generic_client.py +2 -1
- graphiti_core/prompts/extract_nodes.py +39 -34
- graphiti_core/prompts/summarize_nodes.py +20 -17
- graphiti_core/utils/maintenance/edge_operations.py +2 -0
- graphiti_core/utils/maintenance/node_operations.py +90 -51
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0rc1.dist-info}/METADATA +1 -1
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0rc1.dist-info}/RECORD +15 -15
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0rc1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0rc1.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/driver.py
CHANGED
@@ -14,29 +14,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import asyncio
 import copy
 import logging
 import os
 from abc import ABC, abstractmethod
 from collections.abc import Coroutine
-from datetime import datetime
 from enum import Enum
 from typing import Any
 
 from dotenv import load_dotenv
 
-from graphiti_core.embedder.client import EMBEDDING_DIM
-
-try:
-    from opensearchpy import AsyncOpenSearch, helpers
-
-    _HAS_OPENSEARCH = True
-except ImportError:
-    OpenSearch = None
-    helpers = None
-    _HAS_OPENSEARCH = False
-
 logger = logging.getLogger(__name__)
 
 DEFAULT_SIZE = 10
@@ -56,91 +43,6 @@ class GraphProvider(Enum):
     NEPTUNE = 'neptune'
 
 
-aoss_indices = [
-    {
-        'index_name': ENTITY_INDEX_NAME,
-        'body': {
-            'settings': {'index': {'knn': True}},
-            'mappings': {
-                'properties': {
-                    'uuid': {'type': 'keyword'},
-                    'name': {'type': 'text'},
-                    'summary': {'type': 'text'},
-                    'group_id': {'type': 'keyword'},
-                    'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
-                    'name_embedding': {
-                        'type': 'knn_vector',
-                        'dimension': EMBEDDING_DIM,
-                        'method': {
-                            'engine': 'faiss',
-                            'space_type': 'cosinesimil',
-                            'name': 'hnsw',
-                            'parameters': {'ef_construction': 128, 'm': 16},
-                        },
-                    },
-                }
-            },
-        },
-    },
-    {
-        'index_name': COMMUNITY_INDEX_NAME,
-        'body': {
-            'mappings': {
-                'properties': {
-                    'uuid': {'type': 'keyword'},
-                    'name': {'type': 'text'},
-                    'group_id': {'type': 'keyword'},
-                }
-            }
-        },
-    },
-    {
-        'index_name': EPISODE_INDEX_NAME,
-        'body': {
-            'mappings': {
-                'properties': {
-                    'uuid': {'type': 'keyword'},
-                    'content': {'type': 'text'},
-                    'source': {'type': 'text'},
-                    'source_description': {'type': 'text'},
-                    'group_id': {'type': 'keyword'},
-                    'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
-                    'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
-                }
-            }
-        },
-    },
-    {
-        'index_name': ENTITY_EDGE_INDEX_NAME,
-        'body': {
-            'settings': {'index': {'knn': True}},
-            'mappings': {
-                'properties': {
-                    'uuid': {'type': 'keyword'},
-                    'name': {'type': 'text'},
-                    'fact': {'type': 'text'},
-                    'group_id': {'type': 'keyword'},
-                    'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
-                    'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
-                    'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
-                    'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
-                    'fact_embedding': {
-                        'type': 'knn_vector',
-                        'dimension': EMBEDDING_DIM,
-                        'method': {
-                            'engine': 'faiss',
-                            'space_type': 'cosinesimil',
-                            'name': 'hnsw',
-                            'parameters': {'ef_construction': 128, 'm': 16},
-                        },
-                    },
-                }
-            },
-        },
-    },
-]
-
-
 class GraphDriverSession(ABC):
     provider: GraphProvider
 
@@ -171,7 +73,7 @@ class GraphDriver(ABC):
         ''  # Neo4j (default) syntax does not require a prefix for fulltext queries
     )
     _database: str
-    aoss_client:
+    aoss_client: Any  # type: ignore
 
     @abstractmethod
    def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@@ -199,119 +101,6 @@ class GraphDriver(ABC):
 
         return cloned
 
-    async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
-        # No matter what happens above, always return True
-        return self.delete_aoss_indices()
-
-    async def create_aoss_indices(self):
-        client = self.aoss_client
-        if not client:
-            logger.warning('No OpenSearch client found')
-            return
-
-        for index in aoss_indices:
-            alias_name = index['index_name']
-
-            # If alias already exists, skip (idempotent behavior)
-            if await client.indices.exists_alias(name=alias_name):
-                continue
-
-            # Build a physical index name with timestamp
-            ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
-            physical_index_name = f'{alias_name}_{ts_suffix}'
-
-            # Create the index
-            await client.indices.create(index=physical_index_name, body=index['body'])
-
-            # Point alias to it
-            await client.indices.put_alias(index=physical_index_name, name=alias_name)
-
-        # Allow some time for index creation
-        await asyncio.sleep(1)
-
-    async def delete_aoss_indices(self):
-        client = self.aoss_client
-
-        if not client:
-            logger.warning('No OpenSearch client found')
-            return
-
-        for entry in aoss_indices:
-            alias_name = entry['index_name']
-
-            try:
-                # Resolve alias → indices
-                alias_info = await client.indices.get_alias(name=alias_name)
-                indices = list(alias_info.keys())
-
-                if not indices:
-                    logger.info(f"No indices found for alias '{alias_name}'")
-                    continue
-
-                for index in indices:
-                    if await client.indices.exists(index=index):
-                        await client.indices.delete(index=index)
-                        logger.info(f"Deleted index '{index}' (alias: {alias_name})")
-                    else:
-                        logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
-
-            except Exception as e:
-                logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
-
-    async def clear_aoss_indices(self):
-        client = self.aoss_client
-
-        if not client:
-            logger.warning('No OpenSearch client found')
-            return
-
-        for index in aoss_indices:
-            index_name = index['index_name']
-
-            if await client.indices.exists(index=index_name):
-                try:
-                    # Delete all documents but keep the index
-                    response = await client.delete_by_query(
-                        index=index_name,
-                        body={'query': {'match_all': {}}},
-                    )
-                    logger.info(f"Cleared index '{index_name}': {response}")
-                except Exception as e:
-                    logger.error(f"Error clearing index '{index_name}': {e}")
-            else:
-                logger.warning(f"Index '{index_name}' does not exist")
-
-    async def save_to_aoss(self, name: str, data: list[dict]) -> int:
-        client = self.aoss_client
-        if not client or not helpers:
-            logger.warning('No OpenSearch client found')
-            return 0
-
-        for index in aoss_indices:
-            if name.lower() == index['index_name']:
-                to_index = []
-                for d in data:
-                    doc = {}
-                    for p in index['body']['mappings']['properties']:
-                        if p in d:  # protect against missing fields
-                            doc[p] = d[p]
-
-                    item = {
-                        '_index': name,
-                        '_id': d['uuid'],
-                        '_routing': d.get('group_id'),
-                        '_source': doc,
-                    }
-                    to_index.append(item)
-
-                success, failed = await helpers.async_bulk(
-                    client, to_index, stats_only=True, request_timeout=60
-                )
-
-                return success if failed == 0 else success
-
-        return 0
-
     def build_fulltext_query(
         self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
     ) -> str:
@@ -320,3 +109,9 @@ class GraphDriver(ABC):
         Only implemented by providers that need custom fulltext query building.
         """
         raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
+
+    async def save_to_aoss(self, name: str, data: list[dict]) -> int:
+        return 0
+
+    async def clear_aoss_indices(self):
+        return 1
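Note: the net effect of this file is that the OpenSearch/AOSS index definitions and management coroutines move out of the abstract GraphDriver, which now exposes only no-op stubs. A minimal sketch of what that means for callers, assuming a positional uri parameter and placeholder connection values:

    import asyncio

    from graphiti_core.driver.neo4j_driver import Neo4jDriver

    async def main() -> None:
        # uri/user/password are placeholder values
        driver = Neo4jDriver('bolt://localhost:7687', user='neo4j', password='password')
        # Inherited from the new GraphDriver stubs: no OpenSearch dependency and
        # no bulk indexing -- the call simply reports zero documents written.
        written = await driver.save_to_aoss('entities', [{'uuid': 'abc', 'name': 'Alice'}])
        assert written == 0

    asyncio.run(main())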
graphiti_core/driver/neo4j_driver.py
CHANGED
@@ -22,28 +22,9 @@ from neo4j import AsyncGraphDatabase, EagerResult
 from typing_extensions import LiteralString
 
 from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
-from graphiti_core.helpers import semaphore_gather
 
 logger = logging.getLogger(__name__)
 
-try:
-    import boto3
-    from opensearchpy import (
-        AIOHttpConnection,
-        AsyncOpenSearch,
-        AWSV4SignerAuth,
-        Urllib3AWSV4SignerAuth,
-        Urllib3HttpConnection,
-    )
-
-    _HAS_OPENSEARCH = True
-except ImportError:
-    boto3 = None
-    OpenSearch = None
-    Urllib3AWSV4SignerAuth = None
-    Urllib3HttpConnection = None
-    _HAS_OPENSEARCH = False
-
 
 class Neo4jDriver(GraphDriver):
     provider = GraphProvider.NEO4J
@@ -54,11 +35,6 @@ class Neo4jDriver(GraphDriver):
         user: str | None,
         password: str | None,
         database: str = 'neo4j',
-        aoss_host: str | None = None,
-        aoss_port: int | None = None,
-        aws_profile_name: str | None = None,
-        aws_region: str | None = None,
-        aws_service: str | None = None,
     ):
         super().__init__()
         self.client = AsyncGraphDatabase.driver(
@@ -68,24 +44,6 @@ class Neo4jDriver(GraphDriver):
         self._database = database
 
         self.aoss_client = None
-        if aoss_host and aoss_port and boto3 is not None:
-            try:
-                region = aws_region
-                service = aws_service
-                credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
-                auth = AWSV4SignerAuth(credentials, region or '', service or '')
-
-                self.aoss_client = AsyncOpenSearch(
-                    hosts=[{'host': aoss_host, 'port': aoss_port}],
-                    auth=auth,
-                    use_ssl=True,
-                    verify_certs=True,
-                    connection_class=AIOHttpConnection,
-                    pool_maxsize=20,
-                )  # type: ignore
-            except Exception as e:
-                logger.warning(f'Failed to initialize OpenSearch client: {e}')
-                self.aoss_client = None
 
     async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
         # Check if database_ is provided in kwargs.
@@ -111,13 +69,6 @@ class Neo4jDriver(GraphDriver):
         return await self.client.close()
 
     def delete_all_indexes(self) -> Coroutine:
-        if self.aoss_client:
-            return semaphore_gather(
-                self.client.execute_query(
-                    'CALL db.indexes() YIELD name DROP INDEX name',
-                ),
-                self.delete_aoss_indices(),
-            )
         return self.client.execute_query(
             'CALL db.indexes() YIELD name DROP INDEX name',
         )
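Migration note: the five AOSS/AWS keyword arguments are gone from Neo4jDriver.__init__, and delete_all_indexes no longer fans out to OpenSearch. A before/after sketch; the connection values are placeholders:

    # 0.21.0rc13 -- these kwargs were accepted and are now removed:
    # Neo4jDriver(uri, user, password, aoss_host=..., aoss_port=...,
    #             aws_profile_name=..., aws_region=..., aws_service=...)

    # 0.22.0rc1 -- only the core Neo4j parameters remain:
    driver = Neo4jDriver(
        'bolt://localhost:7687',  # assumed positional uri parameter
        user='neo4j',
        password='password',
        database='neo4j',
    )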
graphiti_core/driver/neptune_driver.py
CHANGED
@@ -22,21 +22,16 @@ from typing import Any
 
 import boto3
 from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
-from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
+from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
 
-from graphiti_core.driver.driver import (
-    DEFAULT_SIZE,
-    GraphDriver,
-    GraphDriverSession,
-    GraphProvider,
-)
+from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
 
 logger = logging.getLogger(__name__)
+DEFAULT_SIZE = 10
 
-neptune_aoss_indices = [
+aoss_indices = [
     {
         'index_name': 'node_name_and_summary',
-        'alias_name': 'entities',
         'body': {
             'mappings': {
                 'properties': {
@@ -54,7 +49,6 @@ neptune_aoss_indices = [
     },
     {
         'index_name': 'community_name',
-        'alias_name': 'communities',
         'body': {
             'mappings': {
                 'properties': {
@@ -71,7 +65,6 @@ neptune_aoss_indices = [
     },
     {
         'index_name': 'episode_content',
-        'alias_name': 'episodes',
         'body': {
             'mappings': {
                 'properties': {
@@ -95,7 +88,6 @@ neptune_aoss_indices = [
     },
     {
         'index_name': 'edge_name_and_fact',
-        'alias_name': 'facts',
         'body': {
             'mappings': {
                 'properties': {
@@ -228,27 +220,52 @@ class NeptuneDriver(GraphDriver):
     async def _delete_all_data(self) -> Any:
         return await self.execute_query('MATCH (n) DETACH DELETE n')
 
+    def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
+        return self.delete_all_indexes_impl()
+
+    async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
+        # No matter what happens above, always return True
+        return self.delete_aoss_indices()
+
     async def create_aoss_indices(self):
-        for index in neptune_aoss_indices:
+        for index in aoss_indices:
             index_name = index['index_name']
             client = self.aoss_client
-            if not client:
-                raise ValueError(
-                    'You must provide an AOSS endpoint to create an OpenSearch driver.'
-                )
             if not client.indices.exists(index=index_name):
-
-
-                alias_name = index.get('alias_name', index_name)
-
-                if not client.indices.exists_alias(name=alias_name, index=index_name):
-                    await client.indices.put_alias(index=index_name, name=alias_name)
-
+                client.indices.create(index=index_name, body=index['body'])
         # Sleep for 1 minute to let the index creation complete
         await asyncio.sleep(60)
 
-    def
-
+    async def delete_aoss_indices(self):
+        for index in aoss_indices:
+            index_name = index['index_name']
+            client = self.aoss_client
+            if client.indices.exists(index=index_name):
+                client.indices.delete(index=index_name)
+
+    def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
+        for index in aoss_indices:
+            if name.lower() == index['index_name']:
+                index['query']['query']['multi_match']['query'] = query_text
+                query = {'size': limit, 'query': index['query']}
+                resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
+                return resp
+        return {}
+
+    def save_to_aoss(self, name: str, data: list[dict]) -> int:
+        for index in aoss_indices:
+            if name.lower() == index['index_name']:
+                to_index = []
+                for d in data:
+                    item = {'_index': name, '_id': d['uuid']}
+                    for p in index['body']['mappings']['properties']:
+                        if p in d:
+                            item[p] = d[p]
+                    to_index.append(item)
+                success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
+                return success
+
+        return 0
 
 
 class NeptuneDriverSession(GraphDriverSession):
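Note: the AOSS index definitions (renamed from neptune_aoss_indices and stripped of their alias_name keys) and the persistence helpers now live directly on NeptuneDriver, using the synchronous opensearchpy helpers.bulk. A usage sketch, assuming an already-constructed NeptuneDriver named driver; the documents are placeholders:

    # Only keys declared in the matching index mapping are copied into the
    # bulk request; unknown keys in each dict are silently dropped.
    docs = [
        {'uuid': 'n-1', 'name': 'Alice', 'summary': 'Example entity', 'group_id': 'g1'},
    ]
    written = driver.save_to_aoss('node_name_and_summary', docs)
    print(f'{written} documents indexed')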
graphiti_core/llm_client/client.py
CHANGED
@@ -33,12 +33,16 @@ DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
 
 
-def get_extraction_language_instruction() -> str:
+def get_extraction_language_instruction(group_id: str | None = None) -> str:
     """Returns instruction for language extraction behavior.
 
     Override this function to customize language extraction:
     - Return empty string to disable multilingual instructions
     - Return custom instructions for specific language requirements
+    - Use group_id to provide different instructions per group/partition
+
+    Args:
+        group_id: Optional partition identifier for the graph
 
     Returns:
         str: Language instruction to append to system messages
@@ -142,6 +146,7 @@ class LLMClient(ABC):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -155,7 +160,7 @@ class LLMClient(ABC):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         if self.cache_enabled and self.cache_dir is not None:
             cache_key = self._get_cache_key(messages)
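Note: get_extraction_language_instruction is documented as an override point, and the new group_id parameter lets an override vary by graph partition. A sketch of one way to customize it, with an illustrative group-to-language map; depending on how each client module imports the function, you may need to patch it where it is used rather than where it is defined:

    import graphiti_core.llm_client.client as llm_client_module

    LANGUAGE_BY_GROUP = {'tenant-fr': 'French', 'tenant-de': 'German'}  # illustrative

    def per_group_instruction(group_id: str | None = None) -> str:
        language = LANGUAGE_BY_GROUP.get(group_id or '', 'the language of the input')
        return f'\n\nAlways respond in {language}.'

    llm_client_module.get_extraction_language_instruction = per_group_instruction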
graphiti_core/llm_client/gemini_client.py
CHANGED
@@ -357,6 +357,7 @@ class GeminiClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """
         Generate a response from the Gemini language model with retry logic and error handling.
@@ -367,6 +368,7 @@ class GeminiClient(LLMClient):
             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
             max_tokens (int | None): The maximum number of tokens to generate in the response.
             model_size (ModelSize): The size of the model to use (small or medium).
+            group_id (str | None): Optional partition identifier for the graph.
 
         Returns:
             dict[str, typing.Any]: The response from the language model.
@@ -376,7 +378,7 @@ class GeminiClient(LLMClient):
         last_output = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count < self.MAX_RETRIES:
             try:
graphiti_core/llm_client/openai_base_client.py
CHANGED
@@ -175,6 +175,7 @@ class BaseOpenAIClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         """Generate a response with retry logic and error handling."""
         if max_tokens is None:
@@ -184,7 +185,7 @@ class BaseOpenAIClient(LLMClient):
         last_error = None
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
graphiti_core/llm_client/openai_generic_client.py
CHANGED
@@ -120,6 +120,7 @@ class OpenAIGenericClient(LLMClient):
         response_model: type[BaseModel] | None = None,
         max_tokens: int | None = None,
         model_size: ModelSize = ModelSize.medium,
+        group_id: str | None = None,
     ) -> dict[str, typing.Any]:
         if max_tokens is None:
             max_tokens = self.max_tokens
@@ -136,7 +137,7 @@ class OpenAIGenericClient(LLMClient):
         )
 
         # Add multilingual extraction instructions
-        messages[0].content += get_extraction_language_instruction()
+        messages[0].content += get_extraction_language_instruction(group_id)
 
         while retry_count <= self.MAX_RETRIES:
             try:
graphiti_core/prompts/extract_nodes.py
CHANGED
@@ -23,39 +23,44 @@ from .prompt_helpers import to_prompt_json
 
 
 class ExtractedEntity(BaseModel):
-    name: str = Field(..., description=
+    name: str = Field(..., description="Name of the extracted entity")
     entity_type_id: int = Field(
-        description=
-
+        description="ID of the classified entity type. "
+        "Must be one of the provided entity_type_id integers.",
     )
 
 
 class ExtractedEntities(BaseModel):
-    extracted_entities: list[ExtractedEntity] = Field(
+    extracted_entities: list[ExtractedEntity] = Field(
+        ..., description="List of extracted entities"
+    )
 
 
 class MissedEntities(BaseModel):
-    missed_entities: list[str] = Field(
+    missed_entities: list[str] = Field(
+        ..., description="Names of entities that weren't extracted"
+    )
 
 
 class EntityClassificationTriple(BaseModel):
-    uuid: str = Field(description=
-    name: str = Field(description=
+    uuid: str = Field(description="UUID of the entity")
+    name: str = Field(description="Name of the entity")
     entity_type: str | None = Field(
-        default=None,
+        default=None,
+        description="Type of the entity. Must be one of the provided types or None",
     )
 
 
 class EntityClassification(BaseModel):
     entity_classifications: list[EntityClassificationTriple] = Field(
-        ..., description=
+        ..., description="List of entities classification triples."
     )
 
 
 class EntitySummary(BaseModel):
     summary: str = Field(
         ...,
-        description=
+        description="Summary containing the important information about the entity. Under 8 sentences.",
     )
 
@@ -123,8 +128,8 @@ reference entities. Only extract distinct entities from the CURRENT MESSAGE. Don
 {context['custom_prompt']}
 """
     return [
-        Message(role=
-        Message(role=
+        Message(role="system", content=sys_prompt),
+        Message(role="user", content=user_prompt),
     ]
@@ -156,8 +161,8 @@ Guidelines:
 3. Do NOT extract any properties that contain dates
 """
     return [
-        Message(role=
-        Message(role=
+        Message(role="system", content=sys_prompt),
+        Message(role="user", content=user_prompt),
     ]
@@ -187,8 +192,8 @@ Guidelines:
 4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
 """
     return [
-        Message(role=
-        Message(role=
+        Message(role="system", content=sys_prompt),
+        Message(role="user", content=user_prompt),
     ]
@@ -211,8 +216,8 @@ Given the above previous messages, current message, and list of extracted entiti
 extracted.
 """
     return [
-        Message(role=
-        Message(role=
+        Message(role="system", content=sys_prompt),
+        Message(role="user", content=user_prompt),
     ]
@@ -243,19 +248,19 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
 3. If none of the provided entity types accurately classify an extracted node, the type should be set to None
 """
     return [
-        Message(role=
-        Message(role=
+        Message(role="system", content=sys_prompt),
+        Message(role="user", content=user_prompt),
     ]
 
 
 def extract_attributes(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
-            role=
-            content=
+            role="system",
+            content="You are a helpful assistant that extracts entity properties from the provided text.",
         ),
         Message(
-            role=
+            role="user",
             content=f"""
 
         <MESSAGES>
@@ -281,11 +286,11 @@ def extract_attributes(context: dict[str, Any]) -> list[Message]:
 def extract_summary(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
-            role=
-            content=
+            role="system",
+            content="You are a helpful assistant that extracts entity summaries from the provided text.",
         ),
         Message(
-            role=
+            role="user",
             content=f"""
 
         <MESSAGES>
@@ -300,7 +305,7 @@ def extract_summary(context: dict[str, Any]) -> list[Message]:
         1. Do not hallucinate entity summary information if they cannot be found in the current context.
         2. Only use the provided MESSAGES and ENTITY to set attribute values.
         3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
-
+        4. Keep the summary concise and to the point. SUMMARIES MUST BE LESS THAN 8 SENTENCES.
 
         <ENTITY>
         {context['node']}
@@ -311,11 +316,11 @@ def extract_summary(context: dict[str, Any]) -> list[Message]:
 
 
 versions: Versions = {
-
-
-
-
-
-
-
+    "extract_message": extract_message,
+    "extract_json": extract_json,
+    "extract_text": extract_text,
+    "reflexion": reflexion,
+    "extract_summary": extract_summary,
+    "classify_nodes": classify_nodes,
+    "extract_attributes": extract_attributes,
 }
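Note: the versions registry at the bottom maps prompt names to their builder functions. A sketch of consuming it directly; the context dict mirrors the keys that node_operations._build_episode_context (further down in this diff) produces for extract_summary, and the content values are placeholders:

    from graphiti_core.prompts.extract_nodes import versions

    context = {
        'node': {'name': 'Alice', 'summary': '', 'entity_types': ['Entity'], 'attributes': {}},
        'episode_content': 'Alice met Bob at the conference.',  # placeholder text
        'previous_episodes': [],
    }
    messages = versions['extract_summary'](context)
    for message in messages:
        print(message.role, message.content[:60])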
graphiti_core/prompts/summarize_nodes.py
CHANGED
@@ -25,12 +25,14 @@ from .prompt_helpers import to_prompt_json
 class Summary(BaseModel):
     summary: str = Field(
         ...,
-        description=
+        description="Summary containing the important information about the entity. Under 8 sentences",
     )
 
 
 class SummaryDescription(BaseModel):
-    description: str = Field(
+    description: str = Field(
+        ..., description="One sentence description of the provided summary"
+    )
 
 
 class Prompt(Protocol):
@@ -48,15 +50,15 @@ class Versions(TypedDict):
 def summarize_pair(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
-            role=
-            content=
+            role="system",
+            content="You are a helpful assistant that combines summaries.",
         ),
         Message(
-            role=
+            role="user",
             content=f"""
         Synthesize the information from the following two summaries into a single succinct summary.
 
-
+        IMPORTANT: Keep the summary concise and to the point. SUMMARIES MUST BE LESS THAN 8 SENTENCES.
 
         Summaries:
         {to_prompt_json(context['node_summaries'], indent=2)}
@@ -68,11 +70,11 @@ def summarize_pair(context: dict[str, Any]) -> list[Message]:
 def summarize_context(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
-            role=
-            content=
+            role="system",
+            content="You are a helpful assistant that generates a summary and attributes from provided text.",
         ),
         Message(
-            role=
+            role="user",
             content=f"""
 
         <MESSAGES>
@@ -82,7 +84,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
 
         Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
         information from the provided MESSAGES. Your summary should also only contain information relevant to the
-        provided ENTITY.
+        provided ENTITY.
 
         In addition, extract any values for the provided entity properties based on their descriptions.
         If the value of the entity property cannot be found in the current context, set the value of the property to the Python value None.
@@ -90,6 +92,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
         Guidelines:
         1. Do not hallucinate entity property values if they cannot be found in the current context.
         2. Only use the provided messages, entity, and entity context to set attribute values.
+        3. Keep the summary concise and to the point. SUMMARIES MUST BE LESS THAN 8 SENTENCES.
 
         <ENTITY>
         {context['node_name']}
@@ -110,14 +113,14 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
 def summary_description(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
-            role=
-            content=
+            role="system",
+            content="You are a helpful assistant that describes provided contents in a single sentence.",
         ),
         Message(
-            role=
+            role="user",
             content=f"""
         Create a short one sentence description of the summary that explains what kind of information is summarized.
-        Summaries must be under
+        Summaries must be under 8 sentences.
 
         Summary:
         {to_prompt_json(context['summary'], indent=2)}
@@ -127,7 +130,7 @@ def summary_description(context: dict[str, Any]) -> list[Message]:
 
 
 versions: Versions = {
-
-
-
+    "summarize_pair": summarize_pair,
+    "summarize_context": summarize_context,
+    "summary_description": summary_description,
 }
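Note: both the prompt text and the Summary response model now repeat the under-8-sentences constraint. A sketch of how the model validates a client response; the dict stands in for what generate_response returns:

    from graphiti_core.prompts.summarize_nodes import Summary

    llm_response = {'summary': 'Alice is a software engineer who met Bob.'}  # placeholder
    summary = Summary(**llm_response)
    print(summary.summary)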
graphiti_core/utils/maintenance/edge_operations.py
CHANGED
@@ -139,6 +139,7 @@ async def extract_edges(
         prompt_library.extract_edges.edge(context),
         response_model=ExtractedEdges,
         max_tokens=extract_edges_max_tokens,
+        group_id=group_id,
     )
     edges_data = ExtractedEdges(**llm_response).edges
 
@@ -150,6 +151,7 @@ async def extract_edges(
             prompt_library.extract_edges.reflexion(context),
             response_model=MissingFacts,
             max_tokens=extract_edges_max_tokens,
+            group_id=group_id,
         )
 
         missing_facts = reflexion_response.get('missing_facts', [])
graphiti_core/utils/maintenance/node_operations.py
CHANGED
@@ -64,6 +64,7 @@ async def extract_nodes_reflexion(
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
     node_names: list[str],
+    group_id: str | None = None,
 ) -> list[str]:
     # Prepare context for LLM
     context = {
@@ -73,7 +74,9 @@ async def extract_nodes_reflexion(
     }
 
     llm_response = await llm_client.generate_response(
-        prompt_library.extract_nodes.reflexion(context),
+        prompt_library.extract_nodes.reflexion(context),
+        MissedEntities,
+        group_id=group_id,
     )
     missed_entities = llm_response.get('missed_entities', [])
 
@@ -129,16 +132,19 @@ async def extract_nodes(
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_message(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.text:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_text(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
     elif episode.source == EpisodeType.json:
         llm_response = await llm_client.generate_response(
             prompt_library.extract_nodes.extract_json(context),
             response_model=ExtractedEntities,
+            group_id=episode.group_id,
         )
 
     response_object = ExtractedEntities(**llm_response)
@@ -152,6 +158,7 @@ async def extract_nodes(
             episode,
             previous_episodes,
             [entity.name for entity in extracted_entities],
+            episode.group_id,
         )
 
         entities_missed = len(missing_entities) != 0
@@ -478,63 +485,95 @@ async def extract_attributes_from_node(
     entity_type: type[BaseModel] | None = None,
     should_summarize_node: NodeSummaryFilter | None = None,
 ) -> EntityNode:
-
-
-
-
-        'attributes': node.attributes,
-    }
+    # Extract attributes if entity type is defined and has attributes
+    llm_response = await _extract_entity_attributes(
+        llm_client, node, episode, previous_episodes, entity_type
+    )
 
-
-
-
-
-            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
-        ),
-    }
+    # Extract summary if needed
+    await _extract_entity_summary(
+        llm_client, node, episode, previous_episodes, should_summarize_node
+    )
 
-
-
-
-        'previous_episodes': (
-            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
-        ),
-    }
+    node.attributes.update(llm_response)
+
+    return node
 
-
-
+
+async def _extract_entity_attributes(
+    llm_client: LLMClient,
+    node: EntityNode,
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+    entity_type: type[BaseModel] | None,
+) -> dict[str, Any]:
+    if entity_type is None or len(entity_type.model_fields) == 0:
+        return {}
+
+    attributes_context = _build_episode_context(
+        # should not include summary
+        node_data={
+            'name': node.name,
+            'entity_types': node.labels,
+            'attributes': node.attributes,
+        },
+        episode=episode,
+        previous_episodes=previous_episodes,
     )
 
-    llm_response = (
-        (
-
-
-
-            model_size=ModelSize.small,
-        )
-    )
-    if has_entity_attributes
-    else {}
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.extract_attributes(attributes_context),
+        response_model=entity_type,
+        model_size=ModelSize.small,
+        group_id=node.group_id,
     )
 
-    #
-
-    if should_summarize_node is not None:
-        generate_summary = await should_summarize_node(node)
-
-        # Conditionally generate summary
-        if generate_summary:
-            summary_response = await llm_client.generate_response(
-                prompt_library.extract_nodes.extract_summary(summary_context),
-                response_model=EntitySummary,
-                model_size=ModelSize.small,
-            )
-            node.summary = summary_response.get('summary', '')
+    # validate response
+    entity_type(**llm_response)
 
-
-    entity_type(**llm_response)
-    node_attributes = {key: value for key, value in llm_response.items()}
+    return llm_response
 
-    node.attributes.update(node_attributes)
 
-
+async def _extract_entity_summary(
+    llm_client: LLMClient,
+    node: EntityNode,
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+    should_summarize_node: NodeSummaryFilter | None,
+) -> None:
+    if should_summarize_node is not None and not await should_summarize_node(node):
+        return
+
+    summary_context = _build_episode_context(
+        node_data={
+            'name': node.name,
+            'summary': node.summary,
+            'entity_types': node.labels,
+            'attributes': node.attributes,
+        },
+        episode=episode,
+        previous_episodes=previous_episodes,
+    )
+
+    summary_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.extract_summary(summary_context),
+        response_model=EntitySummary,
+        model_size=ModelSize.small,
+        group_id=node.group_id,
+    )
+
+    node.summary = summary_response.get('summary', '')
+
+
+def _build_episode_context(
+    node_data: dict[str, Any],
+    episode: EpisodicNode | None,
+    previous_episodes: list[EpisodicNode] | None,
+) -> dict[str, Any]:
+    return {
+        'node': node_data,
+        'episode_content': episode.content if episode is not None else '',
+        'previous_episodes': (
+            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
+        ),
+    }
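Note: the refactor splits the old inline logic of extract_attributes_from_node into _extract_entity_attributes, _extract_entity_summary, and a shared _build_episode_context, and threads group_id through every LLM call. A calling sketch; the Person model is hypothetical and the leading positional arguments are abbreviated from the surrounding code:

    from pydantic import BaseModel, Field

    class Person(BaseModel):  # hypothetical custom entity type
        occupation: str | None = Field(default=None, description='What the person does for work')

    # (inside an async function)
    node = await extract_attributes_from_node(
        llm_client,
        node,
        episode=episode,
        previous_episodes=previous_episodes,
        entity_type=Person,
        # None means "always regenerate the summary" under the new helper logic
        should_summarize_node=None,
    )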
{graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0rc1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: graphiti-core
-Version: 0.21.0rc13
+Version: 0.22.0rc1
 Summary: A temporal graph building library
 Project-URL: Homepage, https://help.getzep.com/graphiti/graphiti/overview
 Project-URL: Repository, https://github.com/getzep/graphiti
{graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0rc1.dist-info}/RECORD
CHANGED
@@ -13,11 +13,11 @@ graphiti_core/cross_encoder/client.py,sha256=KLsbfWKOEaAV3adFe3XZlAeb-gje9_sVKCV
 graphiti_core/cross_encoder/gemini_reranker_client.py,sha256=hmITG5YIib52nrKvINwRi4xTfAO1U4jCCaEVIwImHw0,6208
 graphiti_core/cross_encoder/openai_reranker_client.py,sha256=WHMl6Q6gEslR2EzjwpFSZt2Kh6bnu8alkLvzmi0MDtg,4674
 graphiti_core/driver/__init__.py,sha256=kCWimqQU19airu5gKwCmZtZuXkDfaQfKSUhMDoL-rTA,626
-graphiti_core/driver/driver.py,sha256=
+graphiti_core/driver/driver.py,sha256=sF6CkGLNPIvUgrmWkVws7TvQCskRHiQKJze4Y4ibMmI,3357
 graphiti_core/driver/falkordb_driver.py,sha256=Q-dImfK4O2bkikqFzo0Wg2g7iFFRSuzy_c6u82tX6-M,9361
 graphiti_core/driver/kuzu_driver.py,sha256=RcWu8E0CCdofrFe34NmCeqfuhaZr_7ZN5jqDkI3VQMI,5453
-graphiti_core/driver/neo4j_driver.py,sha256=
-graphiti_core/driver/neptune_driver.py,sha256=
+graphiti_core/driver/neo4j_driver.py,sha256=xiMUvGpW-XFM_2ab5nJJTHoi_LM7CvVZVq6ZO0BbNwc,2380
+graphiti_core/driver/neptune_driver.py,sha256=dyQcaA5VnpNA_XkaWdvgGN3Q0QqbxWcVIud--yT8qhE,11266
 graphiti_core/embedder/__init__.py,sha256=EL564ZuE-DZjcuKNUK_exMn_XHXm2LdO9fzdXePVKL4,179
 graphiti_core/embedder/azure_openai.py,sha256=OyomPwC1fIsddI-3n6g00kQFdQznZorBhHwkQKCLUok,2384
 graphiti_core/embedder/client.py,sha256=BXFMXvuPWxaAzPaPILnxtqQQ4JWBFQv9GdBLOXUWgwE,1158
@@ -27,14 +27,14 @@ graphiti_core/embedder/voyage.py,sha256=oJHAZiNqjdEJOKgoKfGWcxK2-Ewqn5UB3vrBwIwP
 graphiti_core/llm_client/__init__.py,sha256=QgBWUiCeBp6YiA_xqyrDvJ9jIyy1hngH8g7FWahN3nw,776
 graphiti_core/llm_client/anthropic_client.py,sha256=xTFcrgMDK77BwnChBhYj51Jaa2mRNI850oJv2pKZI0A,12892
 graphiti_core/llm_client/azure_openai_client.py,sha256=ekERggAekbb7enes1RJqdRChf_mjaZTFXsnMbxO7azQ,2497
-graphiti_core/llm_client/client.py,sha256=
+graphiti_core/llm_client/client.py,sha256=xF3KtXbgP0jC6nKHtIiP5m9dNzxuZaqqQHCKiexijjU,7053
 graphiti_core/llm_client/config.py,sha256=pivp29CDIbDPqgw5NF9Ok2AwcqTV5z5_Q1bgNs1CDGs,2560
 graphiti_core/llm_client/errors.py,sha256=pn6brRiLW60DAUIXJYKBT6MInrS4ueuH1hNLbn_JbQo,1243
-graphiti_core/llm_client/gemini_client.py,sha256=
+graphiti_core/llm_client/gemini_client.py,sha256=ohwuvJ-YTJ67xr6t5UYwSFo87WsyHeMiu8vNCifHod0,17850
 graphiti_core/llm_client/groq_client.py,sha256=bYLE_cg1QEhugsJOXh4b1vPbxagKeMWqk48240GCzMs,2922
-graphiti_core/llm_client/openai_base_client.py,sha256=
+graphiti_core/llm_client/openai_base_client.py,sha256=HGt4CyyFCSZyCBwR__IbUUBF0V6Qwr9Ydu_XLtXPIA8,8533
 graphiti_core/llm_client/openai_client.py,sha256=AuaCFQFMJEGzBkFVouccq3XentmWRIKW0RLRBCUMm7Y,3763
-graphiti_core/llm_client/openai_generic_client.py,sha256=
+graphiti_core/llm_client/openai_generic_client.py,sha256=UseKg9rCqXizAdG1xGGU-jnfwuWJCvVkf-legT0MqjQ,7052
 graphiti_core/llm_client/utils.py,sha256=zKpxXEbKa369m4W7RDEf-m56kH46V1Mx3RowcWZEWWs,1000
 graphiti_core/migrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 graphiti_core/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -48,12 +48,12 @@ graphiti_core/prompts/dedupe_nodes.py,sha256=YNNo19Cq8koLVoLCafpjYJOy5nmRZ-tEWhv
 graphiti_core/prompts/eval.py,sha256=GWFkfZoPfY8U7mV8Ngd_5a2S2fHS7KjajChntxv1UEY,5360
 graphiti_core/prompts/extract_edge_dates.py,sha256=3Drs3CmvP0gJN5BidWSxrNvLet3HPoTybU3BUIAoc0Y,4218
 graphiti_core/prompts/extract_edges.py,sha256=-yOIvCPwxIAXeqYpNCzouE6i3WfdsexzRXFmcXpQpAg,7113
-graphiti_core/prompts/extract_nodes.py,sha256=
+graphiti_core/prompts/extract_nodes.py,sha256=jMD-XRi4U3kjp9smHtA_kvnMBGWBfpBoKc45IoTIZs0,11360
 graphiti_core/prompts/invalidate_edges.py,sha256=yfpcs_pyctnoM77ULPZXEtKW0oHr1MeLsJzC5yrE-o4,3547
 graphiti_core/prompts/lib.py,sha256=DCyHePM4_q-CptTpEXGO_dBv9k7xDtclEaB1dGu7EcI,4092
 graphiti_core/prompts/models.py,sha256=NgxdbPHJpBEcpbXovKyScgpBc73Q-GIW-CBDlBtDjto,894
 graphiti_core/prompts/prompt_helpers.py,sha256=dpWbB8IYAqAZoU5qBx896jozKiQJTng4dGzWewZ_s4c,814
-graphiti_core/prompts/summarize_nodes.py,sha256=
+graphiti_core/prompts/summarize_nodes.py,sha256=7WnjRgYo1Z9bfnUWaUUXbiaLqygGLpemvB5inhhq44Y,4340
 graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 graphiti_core/search/search.py,sha256=2kj7fybSFv6Fnf_cfEUhJhrpfzNtmkPPZ0hV3BQCDqg,18387
 graphiti_core/search/search_config.py,sha256=v_rUHsu1yo5OuPfEm21lSuXexQs-o8qYwSSemW2QWhU,4165
@@ -69,13 +69,13 @@ graphiti_core/utils/datetime_utils.py,sha256=J-zYSq7-H-2n9hYOXNIun12kM10vNX9mMAT
 graphiti_core/utils/maintenance/__init__.py,sha256=vW4H1KyapTl-OOz578uZABYcpND4wPx3Vt6aAPaXh78,301
 graphiti_core/utils/maintenance/community_operations.py,sha256=3IMxfOacZAYtZKebyYtWJYNZPLOPlS8Il-lzitEkoos,10681
 graphiti_core/utils/maintenance/dedup_helpers.py,sha256=B7k6KkB6Sii8PZCWNNTvsNiy4BNTNWpoLeGgrPLq6BE,9220
-graphiti_core/utils/maintenance/edge_operations.py,sha256=
+graphiti_core/utils/maintenance/edge_operations.py,sha256=1hlcJRFnxthGkSr07QyDcOVug7N8dQj5aIENJ17JrpA,26564
 graphiti_core/utils/maintenance/graph_data_operations.py,sha256=42icj3S_ELAJ-NK3jVS_rg_243dmnaZOyUitJj_uJ-M,6085
-graphiti_core/utils/maintenance/node_operations.py,sha256=
+graphiti_core/utils/maintenance/node_operations.py,sha256=ARng4x_pCpfA3g4bM0BncOkxBPaQ2IsdIaYfVq3V3X0,19603
 graphiti_core/utils/maintenance/temporal_operations.py,sha256=wq1I4kqeIoswit6sPohug91FEwrGaVnJ06g1vkJjSLY,3442
 graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 graphiti_core/utils/ontology_utils/entity_types_utils.py,sha256=4eVgxLWY6Q8k9cRJ5pW59IYF--U4nXZsZIGOVb_yHfQ,1285
-graphiti_core-0.21.0rc13.dist-info/METADATA,sha256=
-graphiti_core-0.21.0rc13.dist-info/WHEEL,sha256=
-graphiti_core-0.21.0rc13.dist-info/licenses/LICENSE,sha256=
-graphiti_core-0.21.0rc13.dist-info/RECORD,,
+graphiti_core-0.22.0rc1.dist-info/METADATA,sha256=NlIXn-TmrQ-_u-6CI6I7sEC7ioBKvQIKEl0oyqRq4YM,27084
+graphiti_core-0.22.0rc1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+graphiti_core-0.22.0rc1.dist-info/licenses/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
+graphiti_core-0.22.0rc1.dist-info/RECORD,,
{graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0rc1.dist-info}/WHEEL
File without changes
{graphiti_core-0.21.0rc13.dist-info → graphiti_core-0.22.0rc1.dist-info}/licenses/LICENSE
File without changes