graphiti-core 0.4.3 → 0.5.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of graphiti-core has been flagged as potentially problematic.
- graphiti_core/cross_encoder/client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
- graphiti_core/edges.py +13 -10
- graphiti_core/graphiti.py +25 -27
- graphiti_core/helpers.py +25 -0
- graphiti_core/llm_client/anthropic_client.py +4 -1
- graphiti_core/llm_client/client.py +45 -5
- graphiti_core/llm_client/errors.py +8 -0
- graphiti_core/llm_client/groq_client.py +4 -1
- graphiti_core/llm_client/openai_client.py +71 -7
- graphiti_core/llm_client/openai_generic_client.py +163 -0
- graphiti_core/nodes.py +16 -12
- graphiti_core/prompts/dedupe_edges.py +20 -17
- graphiti_core/prompts/dedupe_nodes.py +15 -1
- graphiti_core/prompts/eval.py +17 -14
- graphiti_core/prompts/extract_edge_dates.py +15 -7
- graphiti_core/prompts/extract_edges.py +18 -19
- graphiti_core/prompts/extract_nodes.py +11 -21
- graphiti_core/prompts/invalidate_edges.py +13 -25
- graphiti_core/prompts/summarize_nodes.py +17 -16
- graphiti_core/search/search.py +5 -5
- graphiti_core/search/search_utils.py +54 -13
- graphiti_core/utils/__init__.py +0 -15
- graphiti_core/utils/bulk_utils.py +22 -15
- graphiti_core/utils/datetime_utils.py +42 -0
- graphiti_core/utils/maintenance/community_operations.py +13 -9
- graphiti_core/utils/maintenance/edge_operations.py +26 -19
- graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
- graphiti_core/utils/maintenance/node_operations.py +19 -13
- graphiti_core/utils/maintenance/temporal_operations.py +16 -7
- {graphiti_core-0.4.3.dist-info → graphiti_core-0.5.0.dist-info}/METADATA +1 -1
- graphiti_core-0.5.0.dist-info/RECORD +60 -0
- graphiti_core-0.4.3.dist-info/RECORD +0 -58
- {graphiti_core-0.4.3.dist-info → graphiti_core-0.5.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.4.3.dist-info → graphiti_core-0.5.0.dist-info}/WHEEL +0 -0
graphiti_core/cross_encoder/client.py
CHANGED

@@ -34,7 +34,7 @@ class CrossEncoderClient(ABC):
             passages (list[str]): A list of passages to rank.

         Returns:
-
+            list[tuple[str, float]]: A list of tuples containing the passage and its score,
                 sorted in descending order of relevance.
         """
         pass
graphiti_core/cross_encoder/openai_reranker_client.py
CHANGED

@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-import asyncio
 import logging
 from typing import Any

@@ -22,6 +21,7 @@ import openai
 from openai import AsyncOpenAI
 from pydantic import BaseModel

+from ..helpers import semaphore_gather
 from ..llm_client import LLMConfig, RateLimitError
 from ..prompts import Message
 from .client import CrossEncoderClient

@@ -75,7 +75,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
             for passage in passages
         ]
         try:
-            responses = await asyncio.gather(
+            responses = await semaphore_gather(
                 *[
                     self.client.chat.completions.create(
                         model=DEFAULT_MODEL,
graphiti_core/edges.py
CHANGED

@@ -27,7 +27,7 @@ from typing_extensions import LiteralString

 from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
-from graphiti_core.helpers import DEFAULT_DATABASE,
+from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
 from graphiti_core.models.edges.edge_db_queries import (
     COMMUNITY_EDGE_SAVE,
     ENTITY_EDGE_SAVE,

@@ -142,10 +142,11 @@ class EpisodicEdge(Edge):
         cls,
         driver: AsyncDriver,
         group_ids: list[str],
-        limit: int =
+        limit: int | None = None,
         created_at: datetime | None = None,
     ):
         cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

         records, _, _ = await driver.execute_query(
             """

@@ -161,8 +162,8 @@ class EpisodicEdge(Edge):
             m.uuid AS target_node_uuid,
             e.created_at AS created_at
             ORDER BY e.uuid DESC
-            LIMIT $limit
-            """,
+            """
+            + limit_query,
             group_ids=group_ids,
             created_at=created_at,
             limit=limit,

@@ -294,10 +295,11 @@ class EntityEdge(Edge):
         cls,
         driver: AsyncDriver,
         group_ids: list[str],
-        limit: int =
+        limit: int | None = None,
         created_at: datetime | None = None,
     ):
         cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

         records, _, _ = await driver.execute_query(
             """

@@ -320,8 +322,8 @@ class EntityEdge(Edge):
             e.valid_at AS valid_at,
             e.invalid_at AS invalid_at
             ORDER BY e.uuid DESC
-            LIMIT $limit
-            """,
+            """
+            + limit_query,
             group_ids=group_ids,
             created_at=created_at,
             limit=limit,

@@ -400,10 +402,11 @@ class CommunityEdge(Edge):
         cls,
         driver: AsyncDriver,
         group_ids: list[str],
-        limit: int =
+        limit: int | None = None,
         created_at: datetime | None = None,
     ):
         cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
+        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

         records, _, _ = await driver.execute_query(
             """

@@ -419,8 +422,8 @@ class CommunityEdge(Edge):
             m.uuid AS target_node_uuid,
             e.created_at AS created_at
             ORDER BY e.uuid DESC
-            LIMIT $limit
-            """,
+            """
+            + limit_query,
             group_ids=group_ids,
             created_at=created_at,
             limit=limit,
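The pattern is identical in all three edge classes: the Cypher text only gains a LIMIT clause when a limit is actually supplied, and the limit parameter is still passed through unconditionally. Below is a minimal sketch of the idea, assuming a neo4j AsyncDriver and an abbreviated, illustrative query (the real queries in edges.py return full edge records):

from typing_extensions import LiteralString


async def edge_uuids_by_group_ids(driver, group_ids: list[str], limit: int | None = None):
    # Only append LIMIT when the caller asked for one; limit=None now means "no cap".
    limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
        WHERE e.group_id IN $group_ids
        RETURN e.uuid AS uuid
        ORDER BY e.uuid DESC
        """
        + limit_query,
        group_ids=group_ids,
        limit=limit,  # ignored by Cypher when the clause is absent, mirroring the diff
    )
    return [record['uuid'] for record in records]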
graphiti_core/graphiti.py
CHANGED

@@ -14,9 +14,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-import asyncio
 import logging
-from datetime import datetime
+from datetime import datetime
 from time import time

 from dotenv import load_dotenv

@@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
 from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
-from graphiti_core.helpers import DEFAULT_DATABASE
+from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search

@@ -43,10 +42,6 @@ from graphiti_core.search.search_utils import (
     get_relevant_edges,
     get_relevant_nodes,
 )
-from graphiti_core.utils import (
-    build_episodic_edges,
-    retrieve_episodes,
-)
 from graphiti_core.utils.bulk_utils import (
     RawEpisode,
     add_nodes_and_edges_bulk,

@@ -57,12 +52,14 @@ from graphiti_core.utils.bulk_utils import (
     resolve_edge_pointers,
     retrieve_previous_episodes_bulk,
 )
+from graphiti_core.utils.datetime_utils import utc_now
 from graphiti_core.utils.maintenance.community_operations import (
     build_communities,
     remove_communities,
     update_community,
 )
 from graphiti_core.utils.maintenance.edge_operations import (
+    build_episodic_edges,
     dedupe_extracted_edge,
     extract_edges,
     resolve_edge_contradictions,

@@ -71,6 +68,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
 from graphiti_core.utils.maintenance.graph_data_operations import (
     EPISODE_WINDOW_LEN,
     build_indices_and_constraints,
+    retrieve_episodes,
 )
 from graphiti_core.utils.maintenance.node_operations import (
     extract_nodes,

@@ -313,7 +311,7 @@ class Graphiti:
             start = time()

             entity_edges: list[EntityEdge] = []
-            now =
+            now = utc_now()

             previous_episodes = await self.retrieve_episodes(
                 reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]

@@ -341,13 +339,13 @@ class Graphiti:

             # Calculate Embeddings

-            await asyncio.gather(
+            await semaphore_gather(
                 *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
             )

             # Find relevant nodes already in the graph
             existing_nodes_lists: list[list[EntityNode]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
                 )
             )

@@ -355,7 +353,7 @@ class Graphiti:
             # Resolve extracted nodes with nodes already in the graph and extract facts
             logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

-            (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
+            (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
                 resolve_extracted_nodes(
                     self.llm_client,
                     extracted_nodes,

@@ -375,7 +373,7 @@ class Graphiti:
             )

             # calculate embeddings
-            await asyncio.gather(
+            await semaphore_gather(
                 *[
                     edge.generate_embedding(self.embedder)
                     for edge in extracted_edges_with_resolved_pointers

@@ -384,7 +382,7 @@ class Graphiti:

             # Resolve extracted edges with related edges already in the graph
             related_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,

@@ -405,7 +403,7 @@ class Graphiti:
             )

             existing_source_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,

@@ -420,7 +418,7 @@ class Graphiti:
             )

             existing_target_edges_list: list[list[EntityEdge]] = list(
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         get_relevant_edges(
                             self.driver,

@@ -469,7 +467,7 @@ class Graphiti:

             # Update any communities
             if update_communities:
-                await asyncio.gather(
+                await semaphore_gather(
                     *[
                         update_community(self.driver, self.llm_client, self.embedder, node)
                         for node in nodes

@@ -522,7 +520,7 @@ class Graphiti:
         """
         try:
             start = time()
-            now =
+            now = utc_now()

             episodes = [
                 EpisodicNode(

@@ -539,7 +537,7 @@ class Graphiti:
             ]

             # Save all the episodes
-            await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
+            await semaphore_gather(*[episode.save(self.driver) for episode in episodes])

             # Get previous episode context for each episode
             episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)

@@ -552,19 +550,19 @@ class Graphiti:
             ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)

             # Generate embeddings
-            await asyncio.gather(
+            await semaphore_gather(
                 *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
                 *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
             )

             # Dedupe extracted nodes, compress extracted edges
-            (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
+            (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
                 dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
                 extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
             )

             # save nodes to KG
-            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+            await semaphore_gather(*[node.save(self.driver) for node in nodes])

             # re-map edge pointers so that they don't point to discard dupe nodes
             extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(

@@ -575,7 +573,7 @@ class Graphiti:
             )

             # save episodic edges to KG
-            await asyncio.gather(
+            await semaphore_gather(
                 *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
             )

@@ -588,7 +586,7 @@ class Graphiti:
             # invalidate edges

             # save edges to KG
-            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+            await semaphore_gather(*[edge.save(self.driver) for edge in edges])

             end = time()
             logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')

@@ -611,12 +609,12 @@ class Graphiti:
             self.driver, self.llm_client, group_ids
         )

-        await asyncio.gather(
+        await semaphore_gather(
             *[node.generate_name_embedding(self.embedder) for node in community_nodes]
         )

-        await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
-        await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
+        await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
+        await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])

        return community_nodes

@@ -699,7 +697,7 @@ class Graphiti:
    async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
        episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

-        edges_list = await asyncio.gather(
+        edges_list = await semaphore_gather(
            *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
        )
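Timestamps in graphiti.py now come from utc_now() in the new graphiti_core/utils/datetime_utils.py module (+42 lines, not shown in this excerpt). A plausible sketch of the helper as called above, assuming it is just a timezone-aware wrapper around datetime.now; the real implementation may differ:

# Hypothetical sketch only; graphiti_core/utils/datetime_utils.py is not shown in this diff.
from datetime import datetime, timezone


def utc_now() -> datetime:
    # Timezone-aware "now" in UTC, replacing ad-hoc datetime.now(...) calls.
    return datetime.now(timezone.utc)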
graphiti_core/helpers.py
CHANGED

@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import asyncio
 import os
+from collections.abc import Coroutine
 from datetime import datetime

 import numpy as np

@@ -25,6 +27,7 @@ load_dotenv()

 DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
 USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
+SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
 MAX_REFLEXION_ITERATIONS = 2
 DEFAULT_PAGE_LIMIT = 20

@@ -57,6 +60,12 @@ def lucene_sanitize(query: str) -> str:
             ':': r'\:',
             '\\': r'\\',
             '/': r'\/',
+            'O': r'\O',
+            'R': r'\R',
+            'N': r'\N',
+            'T': r'\T',
+            'A': r'\A',
+            'D': r'\D',
         }
     )

@@ -74,3 +83,19 @@ def normalize_l2(embedding: list[float]) -> list[float]:
     else:
         norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
         return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
+
+
+# Use this instead of asyncio.gather() to bound coroutines
+async def semaphore_gather(
+    *coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT, return_exceptions=True
+):
+    semaphore = asyncio.Semaphore(max_coroutines)
+
+    async def _wrap_coroutine(coroutine):
+        async with semaphore:
+            return await coroutine
+
+    return await asyncio.gather(
+        *(_wrap_coroutine(coroutine) for coroutine in coroutines),
+        return_exceptions=return_exceptions,
+    )
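semaphore_gather is a drop-in replacement for asyncio.gather that caps concurrency at SEMAPHORE_LIMIT (20 unless overridden through the environment variable of the same name) and returns exceptions in place by default. A small self-contained usage sketch; the fetch coroutine is made up for illustration:

import asyncio

from graphiti_core.helpers import semaphore_gather


async def fetch(i: int) -> int:
    # Stand-in for an embedding or chat-completion call.
    await asyncio.sleep(0.1)
    return i


async def main() -> None:
    # At most 20 of these 100 coroutines run at once; with return_exceptions=True
    # (the default), a failing coroutine yields its exception object instead of raising.
    results = await semaphore_gather(*(fetch(i) for i in range(100)))
    print(len(results))


asyncio.run(main())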
graphiti_core/llm_client/anthropic_client.py
CHANGED

@@ -20,6 +20,7 @@ import typing

 import anthropic
 from anthropic import AsyncAnthropic
+from pydantic import BaseModel

 from ..prompts.models import Message
 from .client import LLMClient

@@ -46,7 +47,9 @@ class AnthropicClient(LLMClient):
             max_retries=1,
         )

-    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         system_message = messages[0]
         user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
             {'role': 'assistant', 'content': '{'}
graphiti_core/llm_client/client.py
CHANGED

@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod

 import httpx
 from diskcache import Cache
+from pydantic import BaseModel
 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

 from ..prompts.models import Message

@@ -55,6 +56,28 @@ class LLMClient(ABC):
         self.cache_enabled = cache
         self.cache_dir = Cache(DEFAULT_CACHE_DIR)  # Create a cache directory

+    def _clean_input(self, input: str) -> str:
+        """Clean input string of invalid unicode and control characters.
+
+        Args:
+            input: Raw input string to be cleaned
+
+        Returns:
+            Cleaned string safe for LLM processing
+        """
+        # Clean any invalid Unicode
+        cleaned = input.encode('utf-8', errors='ignore').decode('utf-8')
+
+        # Remove zero-width characters and other invisible unicode
+        zero_width = '\u200b\u200c\u200d\ufeff\u2060'
+        for char in zero_width:
+            cleaned = cleaned.replace(char, '')
+
+        # Remove control characters except newlines, returns, and tabs
+        cleaned = ''.join(char for char in cleaned if ord(char) >= 32 or char in '\n\r\t')
+
+        return cleaned
+
     @retry(
         stop=stop_after_attempt(4),
         wait=wait_random_exponential(multiplier=10, min=5, max=120),
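_clean_input drops invalid UTF-8, zero-width and other invisible code points, and control characters other than newline, carriage return, and tab. A quick illustration of the effect, re-implementing the same steps as a free function so the snippet runs without constructing a client:

def clean_input(text: str) -> str:
    # Same three cleaning steps as LLMClient._clean_input.
    cleaned = text.encode('utf-8', errors='ignore').decode('utf-8')
    for char in '\u200b\u200c\u200d\ufeff\u2060':
        cleaned = cleaned.replace(char, '')
    return ''.join(c for c in cleaned if ord(c) >= 32 or c in '\n\r\t')


print(repr(clean_input('hello\u200bworld\x07\n')))  # 'helloworld\n'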
@@ -66,14 +89,18 @@
         else None,
         reraise=True,
     )
-    async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def _generate_response_with_retry(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         try:
-            return await self._generate_response(messages)
+            return await self._generate_response(messages, response_model)
         except (httpx.HTTPStatusError, RateLimitError) as e:
             raise e

     @abstractmethod
-    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         pass

     def _get_cache_key(self, messages: list[Message]) -> str:

@@ -82,7 +109,17 @@
         key_str = f'{self.model}:{message_str}'
         return hashlib.md5(key_str.encode()).hexdigest()

-    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
+        if response_model is not None:
+            serialized_model = json.dumps(response_model.model_json_schema())
+            messages[
+                -1
+            ].content += (
+                f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
+            )
+
         if self.cache_enabled:
             cache_key = self._get_cache_key(messages)

@@ -91,7 +128,10 @@
             logger.debug(f'Cache hit for {cache_key}')
             return cached_response

-        response = await self._generate_response_with_retry(messages)
+        for message in messages:
+            message.content = self._clean_input(message.content)
+
+        response = await self._generate_response_with_retry(messages, response_model)

         if self.cache_enabled:
             self.cache_dir.set(cache_key, response)
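For clients that do not support structured outputs natively, the base class simply serializes the Pydantic model's JSON schema and appends it to the final message before dispatching. The snippet below reproduces that step in isolation; the Entity model is illustrative and not part of graphiti-core:

import json

from pydantic import BaseModel


class Entity(BaseModel):
    # Illustrative schema; any Pydantic model can be passed as response_model.
    name: str
    summary: str


# Mirrors what LLMClient.generate_response appends to messages[-1].content
# when a response_model is supplied.
serialized_model = json.dumps(Entity.model_json_schema())
print(f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}')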
graphiti_core/llm_client/errors.py
CHANGED

@@ -21,3 +21,11 @@ class RateLimitError(Exception):
     def __init__(self, message='Rate limit exceeded. Please try again later.'):
         self.message = message
         super().__init__(self.message)
+
+
+class RefusalError(Exception):
+    """Exception raised when the LLM refuses to generate a response."""
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(self.message)
graphiti_core/llm_client/groq_client.py
CHANGED

@@ -21,6 +21,7 @@ import typing
 import groq
 from groq import AsyncGroq
 from groq.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel

 from ..prompts.models import Message
 from .client import LLMClient

@@ -43,7 +44,9 @@ class GroqClient(LLMClient):

         self.client = AsyncGroq(api_key=config.api_key)

-    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+    async def _generate_response(
+        self, messages: list[Message], response_model: type[BaseModel] | None = None
+    ) -> dict[str, typing.Any]:
         msgs: list[ChatCompletionMessageParam] = []
         for m in messages:
             if m.role == 'user':
|
@@ -14,18 +14,19 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
import json
|
|
18
17
|
import logging
|
|
19
18
|
import typing
|
|
19
|
+
from typing import ClassVar
|
|
20
20
|
|
|
21
21
|
import openai
|
|
22
22
|
from openai import AsyncOpenAI
|
|
23
23
|
from openai.types.chat import ChatCompletionMessageParam
|
|
24
|
+
from pydantic import BaseModel
|
|
24
25
|
|
|
25
26
|
from ..prompts.models import Message
|
|
26
27
|
from .client import LLMClient
|
|
27
28
|
from .config import LLMConfig
|
|
28
|
-
from .errors import RateLimitError
|
|
29
|
+
from .errors import RateLimitError, RefusalError
|
|
29
30
|
|
|
30
31
|
logger = logging.getLogger(__name__)
|
|
31
32
|
|
|
@@ -53,6 +54,9 @@ class OpenAIClient(LLMClient):
|
|
|
53
54
|
Generates a response from the language model based on the provided messages.
|
|
54
55
|
"""
|
|
55
56
|
|
|
57
|
+
# Class-level constants
|
|
58
|
+
MAX_RETRIES: ClassVar[int] = 2
|
|
59
|
+
|
|
56
60
|
def __init__(
|
|
57
61
|
self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
|
|
58
62
|
):
|
|
@@ -65,6 +69,10 @@ class OpenAIClient(LLMClient):
|
|
|
65
69
|
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
|
|
66
70
|
|
|
67
71
|
"""
|
|
72
|
+
# removed caching to simplify the `generate_response` override
|
|
73
|
+
if cache:
|
|
74
|
+
raise NotImplementedError('Caching is not implemented for OpenAI')
|
|
75
|
+
|
|
68
76
|
if config is None:
|
|
69
77
|
config = LLMConfig()
|
|
70
78
|
|
|
@@ -75,25 +83,81 @@ class OpenAIClient(LLMClient):
|
|
|
75
83
|
else:
|
|
76
84
|
self.client = client
|
|
77
85
|
|
|
78
|
-
async def _generate_response(
|
|
86
|
+
async def _generate_response(
|
|
87
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
88
|
+
) -> dict[str, typing.Any]:
|
|
79
89
|
openai_messages: list[ChatCompletionMessageParam] = []
|
|
80
90
|
for m in messages:
|
|
91
|
+
m.content = self._clean_input(m.content)
|
|
81
92
|
if m.role == 'user':
|
|
82
93
|
openai_messages.append({'role': 'user', 'content': m.content})
|
|
83
94
|
elif m.role == 'system':
|
|
84
95
|
openai_messages.append({'role': 'system', 'content': m.content})
|
|
85
96
|
try:
|
|
86
|
-
response = await self.client.chat.completions.
|
|
97
|
+
response = await self.client.beta.chat.completions.parse(
|
|
87
98
|
model=self.model or DEFAULT_MODEL,
|
|
88
99
|
messages=openai_messages,
|
|
89
100
|
temperature=self.temperature,
|
|
90
101
|
max_tokens=self.max_tokens,
|
|
91
|
-
response_format=
|
|
102
|
+
response_format=response_model, # type: ignore
|
|
92
103
|
)
|
|
93
|
-
|
|
94
|
-
|
|
104
|
+
|
|
105
|
+
response_object = response.choices[0].message
|
|
106
|
+
|
|
107
|
+
if response_object.parsed:
|
|
108
|
+
return response_object.parsed.model_dump()
|
|
109
|
+
elif response_object.refusal:
|
|
110
|
+
raise RefusalError(response_object.refusal)
|
|
111
|
+
else:
|
|
112
|
+
raise Exception(f'Invalid response from LLM: {response_object.model_dump()}')
|
|
113
|
+
except openai.LengthFinishReasonError as e:
|
|
114
|
+
raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
|
|
95
115
|
except openai.RateLimitError as e:
|
|
96
116
|
raise RateLimitError from e
|
|
97
117
|
except Exception as e:
|
|
98
118
|
logger.error(f'Error in generating LLM response: {e}')
|
|
99
119
|
raise
|
|
120
|
+
|
|
121
|
+
async def generate_response(
|
|
122
|
+
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
|
123
|
+
) -> dict[str, typing.Any]:
|
|
124
|
+
retry_count = 0
|
|
125
|
+
last_error = None
|
|
126
|
+
|
|
127
|
+
while retry_count <= self.MAX_RETRIES:
|
|
128
|
+
try:
|
|
129
|
+
response = await self._generate_response(messages, response_model)
|
|
130
|
+
return response
|
|
131
|
+
except (RateLimitError, RefusalError):
|
|
132
|
+
# These errors should not trigger retries
|
|
133
|
+
raise
|
|
134
|
+
except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError):
|
|
135
|
+
# Let OpenAI's client handle these retries
|
|
136
|
+
raise
|
|
137
|
+
except Exception as e:
|
|
138
|
+
last_error = e
|
|
139
|
+
|
|
140
|
+
# Don't retry if we've hit the max retries
|
|
141
|
+
if retry_count >= self.MAX_RETRIES:
|
|
142
|
+
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
|
|
143
|
+
raise
|
|
144
|
+
|
|
145
|
+
retry_count += 1
|
|
146
|
+
|
|
147
|
+
# Construct a detailed error message for the LLM
|
|
148
|
+
error_context = (
|
|
149
|
+
f'The previous response attempt was invalid. '
|
|
150
|
+
f'Error type: {e.__class__.__name__}. '
|
|
151
|
+
f'Error details: {str(e)}. '
|
|
152
|
+
f'Please try again with a valid response, ensuring the output matches '
|
|
153
|
+
f'the expected format and constraints.'
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
error_message = Message(role='user', content=error_context)
|
|
157
|
+
messages.append(error_message)
|
|
158
|
+
logger.warning(
|
|
159
|
+
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
# If we somehow get here, raise the last error
|
|
163
|
+
raise last_error or Exception('Max retries exceeded with no specific error')
|