graphiti-core 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/edges.py +107 -61
- graphiti_core/errors.py +18 -0
- graphiti_core/graphiti.py +62 -30
- graphiti_core/llm_client/__init__.py +2 -1
- graphiti_core/llm_client/anthropic_client.py +9 -1
- graphiti_core/llm_client/client.py +17 -10
- graphiti_core/llm_client/errors.py +6 -0
- graphiti_core/llm_client/groq_client.py +4 -0
- graphiti_core/llm_client/openai_client.py +4 -0
- graphiti_core/nodes.py +183 -58
- graphiti_core/prompts/extract_nodes.py +43 -1
- graphiti_core/prompts/lib.py +6 -0
- graphiti_core/prompts/summarize_nodes.py +79 -0
- graphiti_core/py.typed +1 -0
- graphiti_core/search/search.py +5 -2
- graphiti_core/search/search_utils.py +101 -165
- graphiti_core/utils/bulk_utils.py +31 -3
- graphiti_core/utils/maintenance/community_operations.py +155 -0
- graphiti_core/utils/maintenance/edge_operations.py +27 -7
- graphiti_core/utils/maintenance/graph_data_operations.py +28 -8
- graphiti_core/utils/maintenance/node_operations.py +27 -1
- {graphiti_core-0.2.2.dist-info → graphiti_core-0.3.0.dist-info}/METADATA +8 -3
- graphiti_core-0.3.0.dist-info/RECORD +41 -0
- graphiti_core/utils/utils.py +0 -60
- graphiti_core-0.2.2.dist-info/RECORD +0 -37
- {graphiti_core-0.2.2.dist-info → graphiti_core-0.3.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.2.2.dist-info → graphiti_core-0.3.0.dist-info}/WHEEL +0 -0
graphiti_core/edges.py
CHANGED
|
@@ -18,11 +18,13 @@ import logging
|
|
|
18
18
|
from abc import ABC, abstractmethod
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from time import time
|
|
21
|
+
from typing import Any
|
|
21
22
|
from uuid import uuid4
|
|
22
23
|
|
|
23
24
|
from neo4j import AsyncDriver
|
|
24
25
|
from pydantic import BaseModel, Field
|
|
25
26
|
|
|
27
|
+
from graphiti_core.errors import EdgeNotFoundError
|
|
26
28
|
from graphiti_core.helpers import parse_db_date
|
|
27
29
|
from graphiti_core.llm_client.config import EMBEDDING_DIM
|
|
28
30
|
from graphiti_core.nodes import Node
|
|
@@ -32,6 +34,7 @@ logger = logging.getLogger(__name__)
|
|
|
32
34
|
|
|
33
35
|
class Edge(BaseModel, ABC):
|
|
34
36
|
uuid: str = Field(default_factory=lambda: uuid4().hex)
|
|
37
|
+
group_id: str | None = Field(description='partition of the graph')
|
|
35
38
|
source_node_uuid: str
|
|
36
39
|
target_node_uuid: str
|
|
37
40
|
created_at: datetime
|
|
@@ -39,8 +42,18 @@ class Edge(BaseModel, ABC):
|
|
|
39
42
|
@abstractmethod
|
|
40
43
|
async def save(self, driver: AsyncDriver): ...
|
|
41
44
|
|
|
42
|
-
|
|
43
|
-
|
|
45
|
+
async def delete(self, driver: AsyncDriver):
|
|
46
|
+
result = await driver.execute_query(
|
|
47
|
+
"""
|
|
48
|
+
MATCH (n)-[e {uuid: $uuid}]->(m)
|
|
49
|
+
DELETE e
|
|
50
|
+
""",
|
|
51
|
+
uuid=self.uuid,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
logger.info(f'Deleted Edge: {self.uuid}')
|
|
55
|
+
|
|
56
|
+
return result
|
|
44
57
|
|
|
45
58
|
def __hash__(self):
|
|
46
59
|
return hash(self.uuid)
|
|
@@ -61,11 +74,12 @@ class EpisodicEdge(Edge):
|
|
|
61
74
|
MATCH (episode:Episodic {uuid: $episode_uuid})
|
|
62
75
|
MATCH (node:Entity {uuid: $entity_uuid})
|
|
63
76
|
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
|
|
64
|
-
SET r = {uuid: $uuid, created_at: $created_at}
|
|
77
|
+
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
|
65
78
|
RETURN r.uuid AS uuid""",
|
|
66
79
|
episode_uuid=self.source_node_uuid,
|
|
67
80
|
entity_uuid=self.target_node_uuid,
|
|
68
81
|
uuid=self.uuid,
|
|
82
|
+
group_id=self.group_id,
|
|
69
83
|
created_at=self.created_at,
|
|
70
84
|
)
|
|
71
85
|
|
|
@@ -73,26 +87,14 @@ class EpisodicEdge(Edge):
|
|
|
73
87
|
|
|
74
88
|
return result
|
|
75
89
|
|
|
76
|
-
async def delete(self, driver: AsyncDriver):
|
|
77
|
-
result = await driver.execute_query(
|
|
78
|
-
"""
|
|
79
|
-
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
80
|
-
DELETE e
|
|
81
|
-
""",
|
|
82
|
-
uuid=self.uuid,
|
|
83
|
-
)
|
|
84
|
-
|
|
85
|
-
logger.info(f'Deleted Edge: {self.uuid}')
|
|
86
|
-
|
|
87
|
-
return result
|
|
88
|
-
|
|
89
90
|
@classmethod
|
|
90
91
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
91
92
|
records, _, _ = await driver.execute_query(
|
|
92
93
|
"""
|
|
93
94
|
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
|
|
94
95
|
RETURN
|
|
95
|
-
e.uuid As uuid,
|
|
96
|
+
e.uuid As uuid,
|
|
97
|
+
e.group_id AS group_id,
|
|
96
98
|
n.uuid AS source_node_uuid,
|
|
97
99
|
m.uuid AS target_node_uuid,
|
|
98
100
|
e.created_at AS created_at
|
|
@@ -100,20 +102,11 @@ class EpisodicEdge(Edge):
|
|
|
100
102
|
uuid=uuid,
|
|
101
103
|
)
|
|
102
104
|
|
|
103
|
-
edges
|
|
104
|
-
|
|
105
|
-
for record in records:
|
|
106
|
-
edges.append(
|
|
107
|
-
EpisodicEdge(
|
|
108
|
-
uuid=record['uuid'],
|
|
109
|
-
source_node_uuid=record['source_node_uuid'],
|
|
110
|
-
target_node_uuid=record['target_node_uuid'],
|
|
111
|
-
created_at=record['created_at'].to_native(),
|
|
112
|
-
)
|
|
113
|
-
)
|
|
105
|
+
edges = [get_episodic_edge_from_record(record) for record in records]
|
|
114
106
|
|
|
115
107
|
logger.info(f'Found Edge: {uuid}')
|
|
116
|
-
|
|
108
|
+
if len(edges) == 0:
|
|
109
|
+
raise EdgeNotFoundError(uuid)
|
|
117
110
|
return edges[0]
|
|
118
111
|
|
|
119
112
|
|
|
@@ -153,7 +146,7 @@ class EntityEdge(Edge):
|
|
|
153
146
|
MATCH (source:Entity {uuid: $source_uuid})
|
|
154
147
|
MATCH (target:Entity {uuid: $target_uuid})
|
|
155
148
|
MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
|
|
156
|
-
SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
|
|
149
|
+
SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
|
|
157
150
|
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
|
|
158
151
|
valid_at: $valid_at, invalid_at: $invalid_at}
|
|
159
152
|
RETURN r.uuid AS uuid""",
|
|
@@ -161,6 +154,7 @@ class EntityEdge(Edge):
|
|
|
161
154
|
target_uuid=self.target_node_uuid,
|
|
162
155
|
uuid=self.uuid,
|
|
163
156
|
name=self.name,
|
|
157
|
+
group_id=self.group_id,
|
|
164
158
|
fact=self.fact,
|
|
165
159
|
fact_embedding=self.fact_embedding,
|
|
166
160
|
episodes=self.episodes,
|
|
@@ -174,19 +168,6 @@ class EntityEdge(Edge):
|
|
|
174
168
|
|
|
175
169
|
return result
|
|
176
170
|
|
|
177
|
-
async def delete(self, driver: AsyncDriver):
|
|
178
|
-
result = await driver.execute_query(
|
|
179
|
-
"""
|
|
180
|
-
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
181
|
-
DELETE e
|
|
182
|
-
""",
|
|
183
|
-
uuid=self.uuid,
|
|
184
|
-
)
|
|
185
|
-
|
|
186
|
-
logger.info(f'Deleted Edge: {self.uuid}')
|
|
187
|
-
|
|
188
|
-
return result
|
|
189
|
-
|
|
190
171
|
@classmethod
|
|
191
172
|
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
192
173
|
records, _, _ = await driver.execute_query(
|
|
@@ -198,6 +179,7 @@ class EntityEdge(Edge):
|
|
|
198
179
|
m.uuid AS target_node_uuid,
|
|
199
180
|
e.created_at AS created_at,
|
|
200
181
|
e.name AS name,
|
|
182
|
+
e.group_id AS group_id,
|
|
201
183
|
e.fact AS fact,
|
|
202
184
|
e.fact_embedding AS fact_embedding,
|
|
203
185
|
e.episodes AS episodes,
|
|
@@ -208,25 +190,89 @@ class EntityEdge(Edge):
|
|
|
208
190
|
uuid=uuid,
|
|
209
191
|
)
|
|
210
192
|
|
|
211
|
-
edges
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
193
|
+
edges = [get_entity_edge_from_record(record) for record in records]
|
|
194
|
+
|
|
195
|
+
logger.info(f'Found Edge: {uuid}')
|
|
196
|
+
if len(edges) == 0:
|
|
197
|
+
raise EdgeNotFoundError(uuid)
|
|
198
|
+
return edges[0]
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class CommunityEdge(Edge):
|
|
202
|
+
async def save(self, driver: AsyncDriver):
|
|
203
|
+
result = await driver.execute_query(
|
|
204
|
+
"""
|
|
205
|
+
MATCH (community:Community {uuid: $community_uuid})
|
|
206
|
+
MATCH (node:Entity | Community {uuid: $entity_uuid})
|
|
207
|
+
MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
|
|
208
|
+
SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
|
|
209
|
+
RETURN r.uuid AS uuid""",
|
|
210
|
+
community_uuid=self.source_node_uuid,
|
|
211
|
+
entity_uuid=self.target_node_uuid,
|
|
212
|
+
uuid=self.uuid,
|
|
213
|
+
group_id=self.group_id,
|
|
214
|
+
created_at=self.created_at,
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
logger.info(f'Saved edge to neo4j: {self.uuid}')
|
|
218
|
+
|
|
219
|
+
return result
|
|
220
|
+
|
|
221
|
+
@classmethod
|
|
222
|
+
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
|
|
223
|
+
records, _, _ = await driver.execute_query(
|
|
224
|
+
"""
|
|
225
|
+
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
|
|
226
|
+
RETURN
|
|
227
|
+
e.uuid As uuid,
|
|
228
|
+
e.group_id AS group_id,
|
|
229
|
+
n.uuid AS source_node_uuid,
|
|
230
|
+
m.uuid AS target_node_uuid,
|
|
231
|
+
e.created_at AS created_at
|
|
232
|
+
""",
|
|
233
|
+
uuid=uuid,
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
edges = [get_community_edge_from_record(record) for record in records]
|
|
229
237
|
|
|
230
238
|
logger.info(f'Found Edge: {uuid}')
|
|
231
239
|
|
|
232
240
|
return edges[0]
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# Edge helpers
|
|
244
|
+
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
|
245
|
+
return EpisodicEdge(
|
|
246
|
+
uuid=record['uuid'],
|
|
247
|
+
group_id=record['group_id'],
|
|
248
|
+
source_node_uuid=record['source_node_uuid'],
|
|
249
|
+
target_node_uuid=record['target_node_uuid'],
|
|
250
|
+
created_at=record['created_at'].to_native(),
|
|
251
|
+
)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def get_entity_edge_from_record(record: Any) -> EntityEdge:
|
|
255
|
+
return EntityEdge(
|
|
256
|
+
uuid=record['uuid'],
|
|
257
|
+
source_node_uuid=record['source_node_uuid'],
|
|
258
|
+
target_node_uuid=record['target_node_uuid'],
|
|
259
|
+
fact=record['fact'],
|
|
260
|
+
name=record['name'],
|
|
261
|
+
group_id=record['group_id'],
|
|
262
|
+
episodes=record['episodes'],
|
|
263
|
+
fact_embedding=record['fact_embedding'],
|
|
264
|
+
created_at=record['created_at'].to_native(),
|
|
265
|
+
expired_at=parse_db_date(record['expired_at']),
|
|
266
|
+
valid_at=parse_db_date(record['valid_at']),
|
|
267
|
+
invalid_at=parse_db_date(record['invalid_at']),
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def get_community_edge_from_record(record: Any):
|
|
272
|
+
return CommunityEdge(
|
|
273
|
+
uuid=record['uuid'],
|
|
274
|
+
group_id=record['group_id'],
|
|
275
|
+
source_node_uuid=record['source_node_uuid'],
|
|
276
|
+
target_node_uuid=record['target_node_uuid'],
|
|
277
|
+
created_at=record['created_at'].to_native(),
|
|
278
|
+
)
|
graphiti_core/errors.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
class GraphitiError(Exception):
|
|
2
|
+
"""Base exception class for Graphiti Core."""
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class EdgeNotFoundError(GraphitiError):
|
|
6
|
+
"""Raised when an edge is not found."""
|
|
7
|
+
|
|
8
|
+
def __init__(self, uuid: str):
|
|
9
|
+
self.message = f'edge {uuid} not found'
|
|
10
|
+
super().__init__(self.message)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NodeNotFoundError(GraphitiError):
|
|
14
|
+
"""Raised when a node is not found."""
|
|
15
|
+
|
|
16
|
+
def __init__(self, uuid: str):
|
|
17
|
+
self.message = f'node {uuid} not found'
|
|
18
|
+
super().__init__(self.message)
|
graphiti_core/graphiti.py
CHANGED
|
@@ -18,7 +18,6 @@ import asyncio
|
|
|
18
18
|
import logging
|
|
19
19
|
from datetime import datetime
|
|
20
20
|
from time import time
|
|
21
|
-
from typing import Callable
|
|
22
21
|
|
|
23
22
|
from dotenv import load_dotenv
|
|
24
23
|
from neo4j import AsyncGraphDatabase
|
|
@@ -47,6 +46,10 @@ from graphiti_core.utils.bulk_utils import (
|
|
|
47
46
|
resolve_edge_pointers,
|
|
48
47
|
retrieve_previous_episodes_bulk,
|
|
49
48
|
)
|
|
49
|
+
from graphiti_core.utils.maintenance.community_operations import (
|
|
50
|
+
build_communities,
|
|
51
|
+
remove_communities,
|
|
52
|
+
)
|
|
50
53
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
51
54
|
extract_edges,
|
|
52
55
|
resolve_extracted_edges,
|
|
@@ -120,7 +123,7 @@ class Graphiti:
|
|
|
120
123
|
|
|
121
124
|
Parameters
|
|
122
125
|
----------
|
|
123
|
-
|
|
126
|
+
self
|
|
124
127
|
|
|
125
128
|
Returns
|
|
126
129
|
-------
|
|
@@ -151,7 +154,7 @@ class Graphiti:
|
|
|
151
154
|
|
|
152
155
|
Parameters
|
|
153
156
|
----------
|
|
154
|
-
|
|
157
|
+
self
|
|
155
158
|
|
|
156
159
|
Returns
|
|
157
160
|
-------
|
|
@@ -178,6 +181,7 @@ class Graphiti:
|
|
|
178
181
|
self,
|
|
179
182
|
reference_time: datetime,
|
|
180
183
|
last_n: int = EPISODE_WINDOW_LEN,
|
|
184
|
+
group_ids: list[str | None] | None = None,
|
|
181
185
|
) -> list[EpisodicNode]:
|
|
182
186
|
"""
|
|
183
187
|
Retrieve the last n episodic nodes from the graph.
|
|
@@ -191,6 +195,8 @@ class Graphiti:
|
|
|
191
195
|
The reference time to retrieve episodes before.
|
|
192
196
|
last_n : int, optional
|
|
193
197
|
The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
|
|
198
|
+
group_ids : list[str | None], optional
|
|
199
|
+
The group ids to return data from.
|
|
194
200
|
|
|
195
201
|
Returns
|
|
196
202
|
-------
|
|
@@ -202,7 +208,7 @@ class Graphiti:
|
|
|
202
208
|
The actual retrieval is performed by the `retrieve_episodes` function
|
|
203
209
|
from the `graphiti_core.utils` module.
|
|
204
210
|
"""
|
|
205
|
-
return await retrieve_episodes(self.driver, reference_time, last_n)
|
|
211
|
+
return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
|
|
206
212
|
|
|
207
213
|
async def add_episode(
|
|
208
214
|
self,
|
|
@@ -211,8 +217,8 @@ class Graphiti:
|
|
|
211
217
|
source_description: str,
|
|
212
218
|
reference_time: datetime,
|
|
213
219
|
source: EpisodeType = EpisodeType.message,
|
|
214
|
-
|
|
215
|
-
|
|
220
|
+
group_id: str | None = None,
|
|
221
|
+
uuid: str | None = None,
|
|
216
222
|
):
|
|
217
223
|
"""
|
|
218
224
|
Process an episode and update the graph.
|
|
@@ -232,10 +238,10 @@ class Graphiti:
|
|
|
232
238
|
The reference time for the episode.
|
|
233
239
|
source : EpisodeType, optional
|
|
234
240
|
The type of the episode. Defaults to EpisodeType.message.
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
241
|
+
group_id : str | None
|
|
242
|
+
An id for the graph partition the episode is a part of.
|
|
243
|
+
uuid : str | None
|
|
244
|
+
Optional uuid of the episode.
|
|
239
245
|
|
|
240
246
|
Returns
|
|
241
247
|
-------
|
|
@@ -266,9 +272,12 @@ class Graphiti:
|
|
|
266
272
|
embedder = self.llm_client.get_embedder()
|
|
267
273
|
now = datetime.now()
|
|
268
274
|
|
|
269
|
-
previous_episodes = await self.retrieve_episodes(
|
|
275
|
+
previous_episodes = await self.retrieve_episodes(
|
|
276
|
+
reference_time, last_n=3, group_ids=[group_id]
|
|
277
|
+
)
|
|
270
278
|
episode = EpisodicNode(
|
|
271
279
|
name=name,
|
|
280
|
+
group_id=group_id,
|
|
272
281
|
labels=[],
|
|
273
282
|
source=source,
|
|
274
283
|
content=episode_body,
|
|
@@ -276,6 +285,7 @@ class Graphiti:
|
|
|
276
285
|
created_at=now,
|
|
277
286
|
valid_at=reference_time,
|
|
278
287
|
)
|
|
288
|
+
episode.uuid = uuid if uuid is not None else episode.uuid
|
|
279
289
|
|
|
280
290
|
# Extract entities as nodes
|
|
281
291
|
|
|
@@ -299,7 +309,9 @@ class Graphiti:
|
|
|
299
309
|
|
|
300
310
|
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
|
|
301
311
|
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
|
|
302
|
-
extract_edges(
|
|
312
|
+
extract_edges(
|
|
313
|
+
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
|
|
314
|
+
),
|
|
303
315
|
)
|
|
304
316
|
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
|
|
305
317
|
nodes.extend(mentioned_nodes)
|
|
@@ -388,11 +400,7 @@ class Graphiti:
|
|
|
388
400
|
|
|
389
401
|
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
|
|
390
402
|
|
|
391
|
-
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
|
|
392
|
-
mentioned_nodes,
|
|
393
|
-
episode,
|
|
394
|
-
now,
|
|
395
|
-
)
|
|
403
|
+
episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
|
|
396
404
|
|
|
397
405
|
logger.info(f'Built episodic edges: {episodic_edges}')
|
|
398
406
|
|
|
@@ -405,18 +413,10 @@ class Graphiti:
|
|
|
405
413
|
end = time()
|
|
406
414
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
|
407
415
|
|
|
408
|
-
if success_callback:
|
|
409
|
-
await success_callback(episode)
|
|
410
416
|
except Exception as e:
|
|
411
|
-
|
|
412
|
-
await error_callback(episode, e)
|
|
413
|
-
else:
|
|
414
|
-
raise e
|
|
417
|
+
raise e
|
|
415
418
|
|
|
416
|
-
async def add_episode_bulk(
|
|
417
|
-
self,
|
|
418
|
-
bulk_episodes: list[RawEpisode],
|
|
419
|
-
):
|
|
419
|
+
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
|
|
420
420
|
"""
|
|
421
421
|
Process multiple episodes in bulk and update the graph.
|
|
422
422
|
|
|
@@ -427,6 +427,8 @@ class Graphiti:
|
|
|
427
427
|
----------
|
|
428
428
|
bulk_episodes : list[RawEpisode]
|
|
429
429
|
A list of RawEpisode objects to be processed and added to the graph.
|
|
430
|
+
group_id : str | None
|
|
431
|
+
An id for the graph partition the episode is a part of.
|
|
430
432
|
|
|
431
433
|
Returns
|
|
432
434
|
-------
|
|
@@ -463,6 +465,7 @@ class Graphiti:
|
|
|
463
465
|
source=episode.source,
|
|
464
466
|
content=episode.content,
|
|
465
467
|
source_description=episode.source_description,
|
|
468
|
+
group_id=group_id,
|
|
466
469
|
created_at=now,
|
|
467
470
|
valid_at=episode.reference_time,
|
|
468
471
|
)
|
|
@@ -527,7 +530,26 @@ class Graphiti:
|
|
|
527
530
|
except Exception as e:
|
|
528
531
|
raise e
|
|
529
532
|
|
|
530
|
-
async def
|
|
533
|
+
async def build_communities(self):
|
|
534
|
+
embedder = self.llm_client.get_embedder()
|
|
535
|
+
|
|
536
|
+
# Clear existing communities
|
|
537
|
+
await remove_communities(self.driver)
|
|
538
|
+
|
|
539
|
+
community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
|
|
540
|
+
|
|
541
|
+
await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
|
|
542
|
+
|
|
543
|
+
await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
|
|
544
|
+
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
|
|
545
|
+
|
|
546
|
+
async def search(
|
|
547
|
+
self,
|
|
548
|
+
query: str,
|
|
549
|
+
center_node_uuid: str | None = None,
|
|
550
|
+
group_ids: list[str | None] | None = None,
|
|
551
|
+
num_results=10,
|
|
552
|
+
):
|
|
531
553
|
"""
|
|
532
554
|
Perform a hybrid search on the knowledge graph.
|
|
533
555
|
|
|
@@ -540,6 +562,8 @@ class Graphiti:
|
|
|
540
562
|
The search query string.
|
|
541
563
|
center_node_uuid: str, optional
|
|
542
564
|
Facts will be reranked based on proximity to this node
|
|
565
|
+
group_ids : list[str | None] | None, optional
|
|
566
|
+
The graph partitions to return data from.
|
|
543
567
|
num_results : int, optional
|
|
544
568
|
The maximum number of results to return. Defaults to 10.
|
|
545
569
|
|
|
@@ -562,6 +586,7 @@ class Graphiti:
|
|
|
562
586
|
num_episodes=0,
|
|
563
587
|
num_edges=num_results,
|
|
564
588
|
num_nodes=0,
|
|
589
|
+
group_ids=group_ids,
|
|
565
590
|
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
|
|
566
591
|
reranker=reranker,
|
|
567
592
|
)
|
|
@@ -590,7 +615,10 @@ class Graphiti:
|
|
|
590
615
|
)
|
|
591
616
|
|
|
592
617
|
async def get_nodes_by_query(
|
|
593
|
-
self,
|
|
618
|
+
self,
|
|
619
|
+
query: str,
|
|
620
|
+
group_ids: list[str | None] | None = None,
|
|
621
|
+
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
594
622
|
) -> list[EntityNode]:
|
|
595
623
|
"""
|
|
596
624
|
Retrieve nodes from the graph database based on a text query.
|
|
@@ -602,6 +630,8 @@ class Graphiti:
|
|
|
602
630
|
----------
|
|
603
631
|
query : str
|
|
604
632
|
The text query to search for in the graph.
|
|
633
|
+
group_ids : list[str | None] | None, optional
|
|
634
|
+
The graph partitions to return data from.
|
|
605
635
|
limit : int | None, optional
|
|
606
636
|
The maximum number of results to return per search method.
|
|
607
637
|
If None, a default limit will be applied.
|
|
@@ -626,5 +656,7 @@ class Graphiti:
|
|
|
626
656
|
"""
|
|
627
657
|
embedder = self.llm_client.get_embedder()
|
|
628
658
|
query_embedding = await generate_embedding(embedder, query)
|
|
629
|
-
relevant_nodes = await hybrid_node_search(
|
|
659
|
+
relevant_nodes = await hybrid_node_search(
|
|
660
|
+
[query], [query_embedding], self.driver, group_ids, limit
|
|
661
|
+
)
|
|
630
662
|
return relevant_nodes
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from .client import LLMClient
|
|
2
2
|
from .config import LLMConfig
|
|
3
|
+
from .errors import RateLimitError
|
|
3
4
|
from .openai_client import OpenAIClient
|
|
4
5
|
|
|
5
|
-
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
|
|
6
|
+
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']
|
|
@@ -18,12 +18,14 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
20
|
|
|
21
|
+
import anthropic
|
|
21
22
|
from anthropic import AsyncAnthropic
|
|
22
23
|
from openai import AsyncOpenAI
|
|
23
24
|
|
|
24
25
|
from ..prompts.models import Message
|
|
25
26
|
from .client import LLMClient
|
|
26
27
|
from .config import LLMConfig
|
|
28
|
+
from .errors import RateLimitError
|
|
27
29
|
|
|
28
30
|
logger = logging.getLogger(__name__)
|
|
29
31
|
|
|
@@ -35,7 +37,11 @@ class AnthropicClient(LLMClient):
|
|
|
35
37
|
if config is None:
|
|
36
38
|
config = LLMConfig()
|
|
37
39
|
super().__init__(config, cache)
|
|
38
|
-
self.client = AsyncAnthropic(
|
|
40
|
+
self.client = AsyncAnthropic(
|
|
41
|
+
api_key=config.api_key,
|
|
42
|
+
# we'll use tenacity to retry
|
|
43
|
+
max_retries=1,
|
|
44
|
+
)
|
|
39
45
|
|
|
40
46
|
def get_embedder(self) -> typing.Any:
|
|
41
47
|
openai_client = AsyncOpenAI()
|
|
@@ -58,6 +64,8 @@ class AnthropicClient(LLMClient):
|
|
|
58
64
|
)
|
|
59
65
|
|
|
60
66
|
return json.loads('{' + result.content[0].text) # type: ignore
|
|
67
|
+
except anthropic.RateLimitError as e:
|
|
68
|
+
raise RateLimitError from e
|
|
61
69
|
except Exception as e:
|
|
62
70
|
logger.error(f'Error in generating LLM response: {e}')
|
|
63
71
|
raise
|
|
@@ -22,10 +22,11 @@ from abc import ABC, abstractmethod
|
|
|
22
22
|
|
|
23
23
|
import httpx
|
|
24
24
|
from diskcache import Cache
|
|
25
|
-
from tenacity import retry, retry_if_exception, stop_after_attempt,
|
|
25
|
+
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
|
26
26
|
|
|
27
27
|
from ..prompts.models import Message
|
|
28
28
|
from .config import LLMConfig
|
|
29
|
+
from .errors import RateLimitError
|
|
29
30
|
|
|
30
31
|
DEFAULT_TEMPERATURE = 0
|
|
31
32
|
DEFAULT_CACHE_DIR = './llm_cache'
|
|
@@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache'
|
|
|
33
34
|
logger = logging.getLogger(__name__)
|
|
34
35
|
|
|
35
36
|
|
|
36
|
-
def
|
|
37
|
+
def is_server_or_retry_error(exception):
|
|
38
|
+
if isinstance(exception, RateLimitError):
|
|
39
|
+
return True
|
|
40
|
+
|
|
37
41
|
return (
|
|
38
42
|
isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
|
|
39
43
|
)
|
|
@@ -56,18 +60,21 @@ class LLMClient(ABC):
|
|
|
56
60
|
pass
|
|
57
61
|
|
|
58
62
|
@retry(
|
|
59
|
-
stop=stop_after_attempt(
|
|
60
|
-
wait=
|
|
61
|
-
retry=retry_if_exception(
|
|
63
|
+
stop=stop_after_attempt(4),
|
|
64
|
+
wait=wait_random_exponential(multiplier=10, min=5, max=120),
|
|
65
|
+
retry=retry_if_exception(is_server_or_retry_error),
|
|
66
|
+
after=lambda retry_state: logger.warning(
|
|
67
|
+
f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
|
|
68
|
+
)
|
|
69
|
+
if retry_state.attempt_number > 1
|
|
70
|
+
else None,
|
|
71
|
+
reraise=True,
|
|
62
72
|
)
|
|
63
73
|
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
64
74
|
try:
|
|
65
75
|
return await self._generate_response(messages)
|
|
66
|
-
except httpx.HTTPStatusError as e:
|
|
67
|
-
|
|
68
|
-
raise Exception(f'LLM request error: {e}') from e
|
|
69
|
-
else:
|
|
70
|
-
raise
|
|
76
|
+
except (httpx.HTTPStatusError, RateLimitError) as e:
|
|
77
|
+
raise e
|
|
71
78
|
|
|
72
79
|
@abstractmethod
|
|
73
80
|
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
|
@@ -18,6 +18,7 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
20
|
|
|
21
|
+
import groq
|
|
21
22
|
from groq import AsyncGroq
|
|
22
23
|
from groq.types.chat import ChatCompletionMessageParam
|
|
23
24
|
from openai import AsyncOpenAI
|
|
@@ -25,6 +26,7 @@ from openai import AsyncOpenAI
|
|
|
25
26
|
from ..prompts.models import Message
|
|
26
27
|
from .client import LLMClient
|
|
27
28
|
from .config import LLMConfig
|
|
29
|
+
from .errors import RateLimitError
|
|
28
30
|
|
|
29
31
|
logger = logging.getLogger(__name__)
|
|
30
32
|
|
|
@@ -59,6 +61,8 @@ class GroqClient(LLMClient):
|
|
|
59
61
|
)
|
|
60
62
|
result = response.choices[0].message.content or ''
|
|
61
63
|
return json.loads(result)
|
|
64
|
+
except groq.RateLimitError as e:
|
|
65
|
+
raise RateLimitError from e
|
|
62
66
|
except Exception as e:
|
|
63
67
|
logger.error(f'Error in generating LLM response: {e}')
|
|
64
68
|
raise
|
|
@@ -18,12 +18,14 @@ import json
|
|
|
18
18
|
import logging
|
|
19
19
|
import typing
|
|
20
20
|
|
|
21
|
+
import openai
|
|
21
22
|
from openai import AsyncOpenAI
|
|
22
23
|
from openai.types.chat import ChatCompletionMessageParam
|
|
23
24
|
|
|
24
25
|
from ..prompts.models import Message
|
|
25
26
|
from .client import LLMClient
|
|
26
27
|
from .config import LLMConfig
|
|
28
|
+
from .errors import RateLimitError
|
|
27
29
|
|
|
28
30
|
logger = logging.getLogger(__name__)
|
|
29
31
|
|
|
@@ -59,6 +61,8 @@ class OpenAIClient(LLMClient):
|
|
|
59
61
|
)
|
|
60
62
|
result = response.choices[0].message.content or ''
|
|
61
63
|
return json.loads(result)
|
|
64
|
+
except openai.RateLimitError as e:
|
|
65
|
+
raise RateLimitError from e
|
|
62
66
|
except Exception as e:
|
|
63
67
|
logger.error(f'Error in generating LLM response: {e}')
|
|
64
68
|
raise
|