graphiti-core 0.2.2 → 0.3.0 (py3-none-any wheels)

This diff shows the contents of two publicly available versions of this package, as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in that registry.

Potentially problematic release.


This version of graphiti-core may be problematic; see the registry's advisory for this release for more details.

graphiti_core/edges.py CHANGED
@@ -18,11 +18,13 @@ import logging
18
18
  from abc import ABC, abstractmethod
19
19
  from datetime import datetime
20
20
  from time import time
21
+ from typing import Any
21
22
  from uuid import uuid4
22
23
 
23
24
  from neo4j import AsyncDriver
24
25
  from pydantic import BaseModel, Field
25
26
 
27
+ from graphiti_core.errors import EdgeNotFoundError
26
28
  from graphiti_core.helpers import parse_db_date
27
29
  from graphiti_core.llm_client.config import EMBEDDING_DIM
28
30
  from graphiti_core.nodes import Node
@@ -32,6 +34,7 @@ logger = logging.getLogger(__name__)
32
34
 
33
35
  class Edge(BaseModel, ABC):
34
36
  uuid: str = Field(default_factory=lambda: uuid4().hex)
37
+ group_id: str | None = Field(description='partition of the graph')
35
38
  source_node_uuid: str
36
39
  target_node_uuid: str
37
40
  created_at: datetime
@@ -39,8 +42,18 @@ class Edge(BaseModel, ABC):
39
42
  @abstractmethod
40
43
  async def save(self, driver: AsyncDriver): ...
41
44
 
42
- @abstractmethod
43
- async def delete(self, driver: AsyncDriver): ...
45
+ async def delete(self, driver: AsyncDriver):
46
+ result = await driver.execute_query(
47
+ """
48
+ MATCH (n)-[e {uuid: $uuid}]->(m)
49
+ DELETE e
50
+ """,
51
+ uuid=self.uuid,
52
+ )
53
+
54
+ logger.info(f'Deleted Edge: {self.uuid}')
55
+
56
+ return result
44
57
 
45
58
  def __hash__(self):
46
59
  return hash(self.uuid)
@@ -61,11 +74,12 @@ class EpisodicEdge(Edge):
61
74
  MATCH (episode:Episodic {uuid: $episode_uuid})
62
75
  MATCH (node:Entity {uuid: $entity_uuid})
63
76
  MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
64
- SET r = {uuid: $uuid, created_at: $created_at}
77
+ SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
65
78
  RETURN r.uuid AS uuid""",
66
79
  episode_uuid=self.source_node_uuid,
67
80
  entity_uuid=self.target_node_uuid,
68
81
  uuid=self.uuid,
82
+ group_id=self.group_id,
69
83
  created_at=self.created_at,
70
84
  )
71
85
 
@@ -73,26 +87,14 @@ class EpisodicEdge(Edge):
73
87
 
74
88
  return result
75
89
 
76
- async def delete(self, driver: AsyncDriver):
77
- result = await driver.execute_query(
78
- """
79
- MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
80
- DELETE e
81
- """,
82
- uuid=self.uuid,
83
- )
84
-
85
- logger.info(f'Deleted Edge: {self.uuid}')
86
-
87
- return result
88
-
89
90
  @classmethod
90
91
  async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
91
92
  records, _, _ = await driver.execute_query(
92
93
  """
93
94
  MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
94
95
  RETURN
95
- e.uuid As uuid,
96
+ e.uuid As uuid,
97
+ e.group_id AS group_id,
96
98
  n.uuid AS source_node_uuid,
97
99
  m.uuid AS target_node_uuid,
98
100
  e.created_at AS created_at
@@ -100,20 +102,11 @@ class EpisodicEdge(Edge):
100
102
  uuid=uuid,
101
103
  )
102
104
 
103
- edges: list[EpisodicEdge] = []
104
-
105
- for record in records:
106
- edges.append(
107
- EpisodicEdge(
108
- uuid=record['uuid'],
109
- source_node_uuid=record['source_node_uuid'],
110
- target_node_uuid=record['target_node_uuid'],
111
- created_at=record['created_at'].to_native(),
112
- )
113
- )
105
+ edges = [get_episodic_edge_from_record(record) for record in records]
114
106
 
115
107
  logger.info(f'Found Edge: {uuid}')
116
-
108
+ if len(edges) == 0:
109
+ raise EdgeNotFoundError(uuid)
117
110
  return edges[0]
118
111
 
119
112
 
@@ -153,7 +146,7 @@ class EntityEdge(Edge):
153
146
  MATCH (source:Entity {uuid: $source_uuid})
154
147
  MATCH (target:Entity {uuid: $target_uuid})
155
148
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
156
- SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
149
+ SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
157
150
  episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
158
151
  valid_at: $valid_at, invalid_at: $invalid_at}
159
152
  RETURN r.uuid AS uuid""",
@@ -161,6 +154,7 @@ class EntityEdge(Edge):
161
154
  target_uuid=self.target_node_uuid,
162
155
  uuid=self.uuid,
163
156
  name=self.name,
157
+ group_id=self.group_id,
164
158
  fact=self.fact,
165
159
  fact_embedding=self.fact_embedding,
166
160
  episodes=self.episodes,
@@ -174,19 +168,6 @@ class EntityEdge(Edge):
174
168
 
175
169
  return result
176
170
 
177
- async def delete(self, driver: AsyncDriver):
178
- result = await driver.execute_query(
179
- """
180
- MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
181
- DELETE e
182
- """,
183
- uuid=self.uuid,
184
- )
185
-
186
- logger.info(f'Deleted Edge: {self.uuid}')
187
-
188
- return result
189
-
190
171
  @classmethod
191
172
  async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
192
173
  records, _, _ = await driver.execute_query(
@@ -198,6 +179,7 @@ class EntityEdge(Edge):
198
179
  m.uuid AS target_node_uuid,
199
180
  e.created_at AS created_at,
200
181
  e.name AS name,
182
+ e.group_id AS group_id,
201
183
  e.fact AS fact,
202
184
  e.fact_embedding AS fact_embedding,
203
185
  e.episodes AS episodes,
@@ -208,25 +190,89 @@ class EntityEdge(Edge):
208
190
  uuid=uuid,
209
191
  )
210
192
 
211
- edges: list[EntityEdge] = []
212
-
213
- for record in records:
214
- edges.append(
215
- EntityEdge(
216
- uuid=record['uuid'],
217
- source_node_uuid=record['source_node_uuid'],
218
- target_node_uuid=record['target_node_uuid'],
219
- fact=record['fact'],
220
- name=record['name'],
221
- episodes=record['episodes'],
222
- fact_embedding=record['fact_embedding'],
223
- created_at=record['created_at'].to_native(),
224
- expired_at=parse_db_date(record['expired_at']),
225
- valid_at=parse_db_date(record['valid_at']),
226
- invalid_at=parse_db_date(record['invalid_at']),
227
- )
228
- )
193
+ edges = [get_entity_edge_from_record(record) for record in records]
194
+
195
+ logger.info(f'Found Edge: {uuid}')
196
+ if len(edges) == 0:
197
+ raise EdgeNotFoundError(uuid)
198
+ return edges[0]
199
+
200
+
201
+ class CommunityEdge(Edge):
202
+ async def save(self, driver: AsyncDriver):
203
+ result = await driver.execute_query(
204
+ """
205
+ MATCH (community:Community {uuid: $community_uuid})
206
+ MATCH (node:Entity | Community {uuid: $entity_uuid})
207
+ MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
208
+ SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
209
+ RETURN r.uuid AS uuid""",
210
+ community_uuid=self.source_node_uuid,
211
+ entity_uuid=self.target_node_uuid,
212
+ uuid=self.uuid,
213
+ group_id=self.group_id,
214
+ created_at=self.created_at,
215
+ )
216
+
217
+ logger.info(f'Saved edge to neo4j: {self.uuid}')
218
+
219
+ return result
220
+
221
+ @classmethod
222
+ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
223
+ records, _, _ = await driver.execute_query(
224
+ """
225
+ MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
226
+ RETURN
227
+ e.uuid As uuid,
228
+ e.group_id AS group_id,
229
+ n.uuid AS source_node_uuid,
230
+ m.uuid AS target_node_uuid,
231
+ e.created_at AS created_at
232
+ """,
233
+ uuid=uuid,
234
+ )
235
+
236
+ edges = [get_community_edge_from_record(record) for record in records]
229
237
 
230
238
  logger.info(f'Found Edge: {uuid}')
231
239
 
232
240
  return edges[0]
241
+
242
+
243
+ # Edge helpers
244
+ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
245
+ return EpisodicEdge(
246
+ uuid=record['uuid'],
247
+ group_id=record['group_id'],
248
+ source_node_uuid=record['source_node_uuid'],
249
+ target_node_uuid=record['target_node_uuid'],
250
+ created_at=record['created_at'].to_native(),
251
+ )
252
+
253
+
254
+ def get_entity_edge_from_record(record: Any) -> EntityEdge:
255
+ return EntityEdge(
256
+ uuid=record['uuid'],
257
+ source_node_uuid=record['source_node_uuid'],
258
+ target_node_uuid=record['target_node_uuid'],
259
+ fact=record['fact'],
260
+ name=record['name'],
261
+ group_id=record['group_id'],
262
+ episodes=record['episodes'],
263
+ fact_embedding=record['fact_embedding'],
264
+ created_at=record['created_at'].to_native(),
265
+ expired_at=parse_db_date(record['expired_at']),
266
+ valid_at=parse_db_date(record['valid_at']),
267
+ invalid_at=parse_db_date(record['invalid_at']),
268
+ )
269
+
270
+
271
+ def get_community_edge_from_record(record: Any):
272
+ return CommunityEdge(
273
+ uuid=record['uuid'],
274
+ group_id=record['group_id'],
275
+ source_node_uuid=record['source_node_uuid'],
276
+ target_node_uuid=record['target_node_uuid'],
277
+ created_at=record['created_at'].to_native(),
278
+ )
@@ -0,0 +1,18 @@
1
+ class GraphitiError(Exception):
2
+ """Base exception class for Graphiti Core."""
3
+
4
+
5
+ class EdgeNotFoundError(GraphitiError):
6
+ """Raised when an edge is not found."""
7
+
8
+ def __init__(self, uuid: str):
9
+ self.message = f'edge {uuid} not found'
10
+ super().__init__(self.message)
11
+
12
+
13
+ class NodeNotFoundError(GraphitiError):
14
+ """Raised when a node is not found."""
15
+
16
+ def __init__(self, uuid: str):
17
+ self.message = f'node {uuid} not found'
18
+ super().__init__(self.message)
graphiti_core/graphiti.py CHANGED
@@ -18,7 +18,6 @@ import asyncio
18
18
  import logging
19
19
  from datetime import datetime
20
20
  from time import time
21
- from typing import Callable
22
21
 
23
22
  from dotenv import load_dotenv
24
23
  from neo4j import AsyncGraphDatabase
@@ -47,6 +46,10 @@ from graphiti_core.utils.bulk_utils import (
47
46
  resolve_edge_pointers,
48
47
  retrieve_previous_episodes_bulk,
49
48
  )
49
+ from graphiti_core.utils.maintenance.community_operations import (
50
+ build_communities,
51
+ remove_communities,
52
+ )
50
53
  from graphiti_core.utils.maintenance.edge_operations import (
51
54
  extract_edges,
52
55
  resolve_extracted_edges,
@@ -120,7 +123,7 @@ class Graphiti:
120
123
 
121
124
  Parameters
122
125
  ----------
123
- None
126
+ self
124
127
 
125
128
  Returns
126
129
  -------
@@ -151,7 +154,7 @@ class Graphiti:
151
154
 
152
155
  Parameters
153
156
  ----------
154
- None
157
+ self
155
158
 
156
159
  Returns
157
160
  -------
@@ -178,6 +181,7 @@ class Graphiti:
178
181
  self,
179
182
  reference_time: datetime,
180
183
  last_n: int = EPISODE_WINDOW_LEN,
184
+ group_ids: list[str | None] | None = None,
181
185
  ) -> list[EpisodicNode]:
182
186
  """
183
187
  Retrieve the last n episodic nodes from the graph.
@@ -191,6 +195,8 @@ class Graphiti:
191
195
  The reference time to retrieve episodes before.
192
196
  last_n : int, optional
193
197
  The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
198
+ group_ids : list[str | None], optional
199
+ The group ids to return data from.
194
200
 
195
201
  Returns
196
202
  -------
@@ -202,7 +208,7 @@ class Graphiti:
202
208
  The actual retrieval is performed by the `retrieve_episodes` function
203
209
  from the `graphiti_core.utils` module.
204
210
  """
205
- return await retrieve_episodes(self.driver, reference_time, last_n)
211
+ return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
206
212
 
207
213
  async def add_episode(
208
214
  self,
@@ -211,8 +217,8 @@ class Graphiti:
211
217
  source_description: str,
212
218
  reference_time: datetime,
213
219
  source: EpisodeType = EpisodeType.message,
214
- success_callback: Callable | None = None,
215
- error_callback: Callable | None = None,
220
+ group_id: str | None = None,
221
+ uuid: str | None = None,
216
222
  ):
217
223
  """
218
224
  Process an episode and update the graph.
@@ -232,10 +238,10 @@ class Graphiti:
232
238
  The reference time for the episode.
233
239
  source : EpisodeType, optional
234
240
  The type of the episode. Defaults to EpisodeType.message.
235
- success_callback : Callable | None, optional
236
- A callback function to be called upon successful processing.
237
- error_callback : Callable | None, optional
238
- A callback function to be called if an error occurs during processing.
241
+ group_id : str | None
242
+ An id for the graph partition the episode is a part of.
243
+ uuid : str | None
244
+ Optional uuid of the episode.
239
245
 
240
246
  Returns
241
247
  -------
@@ -266,9 +272,12 @@ class Graphiti:
266
272
  embedder = self.llm_client.get_embedder()
267
273
  now = datetime.now()
268
274
 
269
- previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
275
+ previous_episodes = await self.retrieve_episodes(
276
+ reference_time, last_n=3, group_ids=[group_id]
277
+ )
270
278
  episode = EpisodicNode(
271
279
  name=name,
280
+ group_id=group_id,
272
281
  labels=[],
273
282
  source=source,
274
283
  content=episode_body,
@@ -276,6 +285,7 @@ class Graphiti:
276
285
  created_at=now,
277
286
  valid_at=reference_time,
278
287
  )
288
+ episode.uuid = uuid if uuid is not None else episode.uuid
279
289
 
280
290
  # Extract entities as nodes
281
291
 
@@ -299,7 +309,9 @@ class Graphiti:
299
309
 
300
310
  (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
301
311
  resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
302
- extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
312
+ extract_edges(
313
+ self.llm_client, episode, extracted_nodes, previous_episodes, group_id
314
+ ),
303
315
  )
304
316
  logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
305
317
  nodes.extend(mentioned_nodes)
@@ -388,11 +400,7 @@ class Graphiti:
388
400
 
389
401
  logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
390
402
 
391
- episodic_edges: list[EpisodicEdge] = build_episodic_edges(
392
- mentioned_nodes,
393
- episode,
394
- now,
395
- )
403
+ episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
396
404
 
397
405
  logger.info(f'Built episodic edges: {episodic_edges}')
398
406
 
@@ -405,18 +413,10 @@ class Graphiti:
405
413
  end = time()
406
414
  logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
407
415
 
408
- if success_callback:
409
- await success_callback(episode)
410
416
  except Exception as e:
411
- if error_callback:
412
- await error_callback(episode, e)
413
- else:
414
- raise e
417
+ raise e
415
418
 
416
- async def add_episode_bulk(
417
- self,
418
- bulk_episodes: list[RawEpisode],
419
- ):
419
+ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
420
420
  """
421
421
  Process multiple episodes in bulk and update the graph.
422
422
 
@@ -427,6 +427,8 @@ class Graphiti:
427
427
  ----------
428
428
  bulk_episodes : list[RawEpisode]
429
429
  A list of RawEpisode objects to be processed and added to the graph.
430
+ group_id : str | None
431
+ An id for the graph partition the episode is a part of.
430
432
 
431
433
  Returns
432
434
  -------
@@ -463,6 +465,7 @@ class Graphiti:
463
465
  source=episode.source,
464
466
  content=episode.content,
465
467
  source_description=episode.source_description,
468
+ group_id=group_id,
466
469
  created_at=now,
467
470
  valid_at=episode.reference_time,
468
471
  )
@@ -527,7 +530,26 @@ class Graphiti:
527
530
  except Exception as e:
528
531
  raise e
529
532
 
530
- async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
533
+ async def build_communities(self):
534
+ embedder = self.llm_client.get_embedder()
535
+
536
+ # Clear existing communities
537
+ await remove_communities(self.driver)
538
+
539
+ community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
540
+
541
+ await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
542
+
543
+ await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
544
+ await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
545
+
546
+ async def search(
547
+ self,
548
+ query: str,
549
+ center_node_uuid: str | None = None,
550
+ group_ids: list[str | None] | None = None,
551
+ num_results=10,
552
+ ):
531
553
  """
532
554
  Perform a hybrid search on the knowledge graph.
533
555
 
@@ -540,6 +562,8 @@ class Graphiti:
540
562
  The search query string.
541
563
  center_node_uuid: str, optional
542
564
  Facts will be reranked based on proximity to this node
565
+ group_ids : list[str | None] | None, optional
566
+ The graph partitions to return data from.
543
567
  num_results : int, optional
544
568
  The maximum number of results to return. Defaults to 10.
545
569
 
@@ -562,6 +586,7 @@ class Graphiti:
562
586
  num_episodes=0,
563
587
  num_edges=num_results,
564
588
  num_nodes=0,
589
+ group_ids=group_ids,
565
590
  search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
566
591
  reranker=reranker,
567
592
  )
@@ -590,7 +615,10 @@ class Graphiti:
590
615
  )
591
616
 
592
617
  async def get_nodes_by_query(
593
- self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
618
+ self,
619
+ query: str,
620
+ group_ids: list[str | None] | None = None,
621
+ limit: int = RELEVANT_SCHEMA_LIMIT,
594
622
  ) -> list[EntityNode]:
595
623
  """
596
624
  Retrieve nodes from the graph database based on a text query.
@@ -602,6 +630,8 @@ class Graphiti:
602
630
  ----------
603
631
  query : str
604
632
  The text query to search for in the graph.
633
+ group_ids : list[str | None] | None, optional
634
+ The graph partitions to return data from.
605
635
  limit : int | None, optional
606
636
  The maximum number of results to return per search method.
607
637
  If None, a default limit will be applied.
@@ -626,5 +656,7 @@ class Graphiti:
626
656
  """
627
657
  embedder = self.llm_client.get_embedder()
628
658
  query_embedding = await generate_embedding(embedder, query)
629
- relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
659
+ relevant_nodes = await hybrid_node_search(
660
+ [query], [query_embedding], self.driver, group_ids, limit
661
+ )
630
662
  return relevant_nodes
@@ -1,5 +1,6 @@
1
1
  from .client import LLMClient
2
2
  from .config import LLMConfig
3
+ from .errors import RateLimitError
3
4
  from .openai_client import OpenAIClient
4
5
 
5
- __all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
6
+ __all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']
@@ -18,12 +18,14 @@ import json
18
18
  import logging
19
19
  import typing
20
20
 
21
+ import anthropic
21
22
  from anthropic import AsyncAnthropic
22
23
  from openai import AsyncOpenAI
23
24
 
24
25
  from ..prompts.models import Message
25
26
  from .client import LLMClient
26
27
  from .config import LLMConfig
28
+ from .errors import RateLimitError
27
29
 
28
30
  logger = logging.getLogger(__name__)
29
31
 
@@ -35,7 +37,11 @@ class AnthropicClient(LLMClient):
35
37
  if config is None:
36
38
  config = LLMConfig()
37
39
  super().__init__(config, cache)
38
- self.client = AsyncAnthropic(api_key=config.api_key)
40
+ self.client = AsyncAnthropic(
41
+ api_key=config.api_key,
42
+ # we'll use tenacity to retry
43
+ max_retries=1,
44
+ )
39
45
 
40
46
  def get_embedder(self) -> typing.Any:
41
47
  openai_client = AsyncOpenAI()
@@ -58,6 +64,8 @@ class AnthropicClient(LLMClient):
58
64
  )
59
65
 
60
66
  return json.loads('{' + result.content[0].text) # type: ignore
67
+ except anthropic.RateLimitError as e:
68
+ raise RateLimitError from e
61
69
  except Exception as e:
62
70
  logger.error(f'Error in generating LLM response: {e}')
63
71
  raise
@@ -22,10 +22,11 @@ from abc import ABC, abstractmethod
22
22
 
23
23
  import httpx
24
24
  from diskcache import Cache
25
- from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
25
+ from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
26
26
 
27
27
  from ..prompts.models import Message
28
28
  from .config import LLMConfig
29
+ from .errors import RateLimitError
29
30
 
30
31
  DEFAULT_TEMPERATURE = 0
31
32
  DEFAULT_CACHE_DIR = './llm_cache'
@@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache'
33
34
  logger = logging.getLogger(__name__)
34
35
 
35
36
 
36
- def is_server_error(exception):
37
+ def is_server_or_retry_error(exception):
38
+ if isinstance(exception, RateLimitError):
39
+ return True
40
+
37
41
  return (
38
42
  isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
39
43
  )
@@ -56,18 +60,21 @@ class LLMClient(ABC):
56
60
  pass
57
61
 
58
62
  @retry(
59
- stop=stop_after_attempt(3),
60
- wait=wait_exponential(multiplier=1, min=4, max=10),
61
- retry=retry_if_exception(is_server_error),
63
+ stop=stop_after_attempt(4),
64
+ wait=wait_random_exponential(multiplier=10, min=5, max=120),
65
+ retry=retry_if_exception(is_server_or_retry_error),
66
+ after=lambda retry_state: logger.warning(
67
+ f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
68
+ )
69
+ if retry_state.attempt_number > 1
70
+ else None,
71
+ reraise=True,
62
72
  )
63
73
  async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
64
74
  try:
65
75
  return await self._generate_response(messages)
66
- except httpx.HTTPStatusError as e:
67
- if not is_server_error(e):
68
- raise Exception(f'LLM request error: {e}') from e
69
- else:
70
- raise
76
+ except (httpx.HTTPStatusError, RateLimitError) as e:
77
+ raise e
71
78
 
72
79
  @abstractmethod
73
80
  async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
@@ -0,0 +1,6 @@
1
+ class RateLimitError(Exception):
2
+ """Exception raised when the rate limit is exceeded."""
3
+
4
+ def __init__(self, message='Rate limit exceeded. Please try again later.'):
5
+ self.message = message
6
+ super().__init__(self.message)
@@ -18,6 +18,7 @@ import json
18
18
  import logging
19
19
  import typing
20
20
 
21
+ import groq
21
22
  from groq import AsyncGroq
22
23
  from groq.types.chat import ChatCompletionMessageParam
23
24
  from openai import AsyncOpenAI
@@ -25,6 +26,7 @@ from openai import AsyncOpenAI
25
26
  from ..prompts.models import Message
26
27
  from .client import LLMClient
27
28
  from .config import LLMConfig
29
+ from .errors import RateLimitError
28
30
 
29
31
  logger = logging.getLogger(__name__)
30
32
 
@@ -59,6 +61,8 @@ class GroqClient(LLMClient):
59
61
  )
60
62
  result = response.choices[0].message.content or ''
61
63
  return json.loads(result)
64
+ except groq.RateLimitError as e:
65
+ raise RateLimitError from e
62
66
  except Exception as e:
63
67
  logger.error(f'Error in generating LLM response: {e}')
64
68
  raise
@@ -18,12 +18,14 @@ import json
18
18
  import logging
19
19
  import typing
20
20
 
21
+ import openai
21
22
  from openai import AsyncOpenAI
22
23
  from openai.types.chat import ChatCompletionMessageParam
23
24
 
24
25
  from ..prompts.models import Message
25
26
  from .client import LLMClient
26
27
  from .config import LLMConfig
28
+ from .errors import RateLimitError
27
29
 
28
30
  logger = logging.getLogger(__name__)
29
31
 
@@ -59,6 +61,8 @@ class OpenAIClient(LLMClient):
59
61
  )
60
62
  result = response.choices[0].message.content or ''
61
63
  return json.loads(result)
64
+ except openai.RateLimitError as e:
65
+ raise RateLimitError from e
62
66
  except Exception as e:
63
67
  logger.error(f'Error in generating LLM response: {e}')
64
68
  raise