speedy-utils 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +3 -3
- llm_utils/lm/__init__.py +12 -0
- llm_utils/lm/base_lm.py +337 -0
- llm_utils/lm/chat_session.py +115 -0
- llm_utils/lm/pydantic_lm.py +195 -0
- llm_utils/lm/text_lm.py +130 -0
- llm_utils/lm/utils.py +130 -0
- {speedy_utils-1.0.0.dist-info → speedy_utils-1.0.1.dist-info}/METADATA +1 -1
- {speedy_utils-1.0.0.dist-info → speedy_utils-1.0.1.dist-info}/RECORD +11 -6
- llm_utils/lm.py +0 -742
- {speedy_utils-1.0.0.dist-info → speedy_utils-1.0.1.dist-info}/WHEEL +0 -0
- {speedy_utils-1.0.0.dist-info → speedy_utils-1.0.1.dist-info}/entry_points.txt +0 -0
llm_utils/lm/text_lm.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from typing import Any, List, Optional, Union, Tuple
|
|
2
|
+
from .base_lm import LM
|
|
3
|
+
import random
|
|
4
|
+
import logging
|
|
5
|
+
import json
|
|
6
|
+
from speedy_utils import identify_uuid
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TextLM(LM):
    """
    Language model that returns outputs as plain text (str).

    Specializes the base ``LM`` pipeline with string-oriented caching and
    output parsing: dict outputs are JSON-encoded, any other non-str output
    is coerced to ``str`` with a warning.
    """

    def _generate_cache_key(
        self,
        messages: List[Any],
        effective_kwargs: dict,
    ) -> str:
        """Generate a deterministic cache key based on input parameters."""
        cache_key_list = self._generate_cache_key_base(messages, effective_kwargs)
        return identify_uuid(str(cache_key_list))

    def _check_cache(
        self,
        effective_cache: bool,
        messages: List[Any],
        effective_kwargs: dict,
    ) -> Tuple[Optional[str], Optional[str]]:
        """Check if a result is cached and return it if available.

        Returns ``(cache_key, cached_text)``. ``cached_text`` is ``None`` on
        a miss; when caching is disabled both elements are ``None``. Cached
        entries that are neither ``str`` nor a non-empty list of ``str`` are
        ignored with a warning.
        """
        if not effective_cache:
            return None, None

        id_for_cache = self._generate_cache_key(messages, effective_kwargs)
        cached_result = self.load_cache(id_for_cache)

        if cached_result is not None:
            if isinstance(cached_result, str):
                return id_for_cache, cached_result
            if (
                isinstance(cached_result, list)
                and cached_result
                and isinstance(cached_result[0], str)
            ):
                return id_for_cache, cached_result[0]
            logger.warning(
                f"Cached result for {id_for_cache} has unexpected type {type(cached_result)}. Ignoring cache."
            )

        return id_for_cache, None

    def _parse_llm_output(self, llm_output: Any) -> str:
        """Parse the LLM output into a plain string.

        Dicts are JSON-encoded; any other non-str value is stringified with
        a warning.
        """
        if isinstance(llm_output, dict):
            return json.dumps(llm_output)

        if isinstance(llm_output, str):
            return llm_output
        logger.warning(
            f"LLM output type mismatch. Expected str, got {type(llm_output)}. Returning str."
        )
        return str(llm_output)

    def _store_in_cache(
        self, effective_cache: bool, id_for_cache: Optional[str], result: str
    ):
        """Store the result in cache if caching is enabled."""
        self._store_in_cache_base(effective_cache, id_for_cache, result)

    def forward_messages(
        self,
        messages: List[Any],
        cache: Optional[bool] = None,
        port: Optional[int] = None,
        use_loadbalance: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Run the prepare -> cache -> call -> parse -> store pipeline.

        Returns the (possibly cached) text completion for *messages*.
        """
        # 1. Prepare inputs
        effective_kwargs, effective_cache, current_port, dspy_main_input = (
            self._prepare_call_inputs(
                messages, max_tokens, port, use_loadbalance, cache, **kwargs
            )
        )

        # 2. Check cache
        cache_id, cached_result = self._check_cache(
            effective_cache, messages, effective_kwargs
        )
        # FIX: compare against None instead of truthiness so a legitimately
        # cached empty string is returned rather than triggering a redundant
        # LLM call on every invocation.
        if cached_result is not None:
            return cached_result

        # 3. Call LLM
        llm_output = self._call_llm(
            dspy_main_input, current_port, use_loadbalance, **effective_kwargs
        )

        # 4. Parse output
        result: str = self._parse_llm_output(llm_output)

        # 5. Store in cache
        self._store_in_cache(effective_cache, cache_id, result)

        return result

    def _call_llm(
        self,
        dspy_main_input: List[Any],
        current_port: Optional[int],
        use_loadbalance: Optional[bool],
        **kwargs,
    ):
        """Call the LLM once (no retry) and return the raw message content."""
        # Use messages directly
        messages = dspy_main_input

        response = self.openai_client.chat.completions.create(
            model=self.model,
            messages=messages,
            **kwargs,
        )

        # Release this request's slot in the load-balancer usage counter.
        if current_port and use_loadbalance is True:
            self._update_port_use(current_port, -1)

        return response.choices[0].message.content
|
llm_utils/lm/utils.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
import fcntl
|
|
2
|
+
import os
|
|
3
|
+
import tempfile
|
|
4
|
+
import time
|
|
5
|
+
from typing import List, Dict
|
|
6
|
+
import numpy as np
|
|
7
|
+
from loguru import logger
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _clear_port_use(ports):
|
|
11
|
+
for port in ports:
|
|
12
|
+
file_counter: str = f"/tmp/port_use_counter_{port}.npy"
|
|
13
|
+
if os.path.exists(file_counter):
|
|
14
|
+
os.remove(file_counter)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _atomic_save(array: np.ndarray, filename: str):
|
|
18
|
+
tmp_dir = os.path.dirname(filename) or "."
|
|
19
|
+
with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
|
|
20
|
+
np.save(tmp, array)
|
|
21
|
+
temp_name = tmp.name
|
|
22
|
+
os.replace(temp_name, filename)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _update_port_use(port: int, increment: int) -> None:
    """Adjust the persisted usage counter of *port* by *increment*.

    The counter is a one-element numpy array stored in /tmp; an exclusive
    flock on a companion ``.lock`` file serializes concurrent updates
    across processes. A missing or unreadable counter file starts from 0.
    """
    counter_path: str = f"/tmp/port_use_counter_{port}.npy"
    lock_path: str = f"/tmp/port_use_counter_{port}.lock"
    with open(lock_path, "w") as lock_handle:
        fcntl.flock(lock_handle, fcntl.LOCK_EX)
        try:
            if not os.path.exists(counter_path):
                counter: np.ndarray = np.array([0], dtype=np.int64)
            else:
                try:
                    counter = np.load(counter_path)
                except Exception as e:
                    # Corrupted files are logged and reset rather than crashing.
                    logger.warning(f"Corrupted usage file {counter_path}: {e}")
                    counter = np.array([0])
            counter[0] += increment
            _atomic_save(counter, counter_path)
        finally:
            fcntl.flock(lock_handle, fcntl.LOCK_UN)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _pick_least_used_port(ports: List[int]) -> int:
    """Return the port with the smallest usage counter and bump its count.

    A global lock file in /tmp serializes the read-pick-increment sequence
    across processes. Raises ``ValueError`` when no port can be chosen.
    """
    global_lock_file = "/tmp/ports.lock"
    with open(global_lock_file, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            usage: Dict[int, int] = {}
            for candidate in ports:
                counter_path = f"/tmp/port_use_counter_{candidate}.npy"
                counter = np.array([0])
                if os.path.exists(counter_path):
                    try:
                        counter = np.load(counter_path)
                    except Exception as e:
                        # Unreadable counters are treated as zero usage.
                        logger.warning(f"Corrupted usage file {counter_path}: {e}")
                        counter = np.array([0])
                usage[candidate] = counter[0]
            if not usage:
                # NOTE: with a non-empty `ports` this branch is unreachable;
                # both messages kept for parity with the original contract.
                if ports:
                    raise ValueError("Port usage data is empty, cannot pick a port.")
                raise ValueError("No ports provided to pick from.")
            least_used = min(usage, key=usage.get)
            _update_port_use(least_used, 1)
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)
    return least_used
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def retry_on_exception(max_retries=10, exceptions=(Exception,), sleep_time=3):
    """Decorator factory: retry a method on transient LLM transport errors.

    Only litellm ``APIError``/``Timeout`` trigger a sleep-and-retry loop
    (attempt count travels through the ``retry_count`` kwarg, capped at
    *max_retries*); context-window errors and any other exception are
    re-raised immediately.
    """

    def decorator(func):
        from functools import wraps

        def wrapper(self, *args, **kwargs):
            attempts = kwargs.get("retry_count", 0)
            last_exception = None
            while attempts <= max_retries:
                try:
                    return func(self, *args, **kwargs)
                except exceptions as e:
                    # Imported lazily so the decorator itself has no hard dep.
                    import litellm

                    transient = isinstance(
                        e, (litellm.exceptions.APIError, litellm.exceptions.Timeout)
                    )
                    if transient:
                        base_url_info = kwargs.get(
                            "base_url", getattr(self, "base_url", None)
                        )
                        logger.warning(
                            f"[{base_url_info=}] {type(e).__name__}: {str(e)[:100]}, will sleep for {sleep_time}s and retry"
                        )
                        time.sleep(sleep_time)
                        attempts += 1
                        kwargs["retry_count"] = attempts
                        last_exception = e
                        continue
                    if hasattr(
                        litellm.exceptions, "ContextWindowExceededError"
                    ) and isinstance(e, litellm.exceptions.ContextWindowExceededError):
                        logger.error(f"Context window exceeded: {e}")
                        raise
                    logger.error(f"Generic error during LLM call: {e}")
                    import traceback

                    traceback.print_exc()
                    raise
            logger.error(f"Retry limit exceeded, error: {last_exception}")
            if last_exception:
                raise last_exception
            raise ValueError("Retry limit exceeded with no specific error.")

        return wraps(func)(wrapper)

    return decorator
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def forward_only(func):
    """Wrap *func* so every call starts with a fresh ``retry_count`` of 0."""
    from functools import wraps

    def _fresh_call(self, *args, **kwargs):
        kwargs["retry_count"] = 0
        return func(self, *args, **kwargs)

    return wraps(func)(_fresh_call)
|
|
@@ -1,7 +1,12 @@
|
|
|
1
|
-
llm_utils/__init__.py,sha256=
|
|
1
|
+
llm_utils/__init__.py,sha256=ujANp90z6cBoRAufbDbRs85KsDaghKc4-wzEugl6meg,717
|
|
2
2
|
llm_utils/chat_format.py,sha256=ZY2HYv3FPL2xiMxbbO-huIwT5LZrcJm_if_us-2eSZ4,15094
|
|
3
3
|
llm_utils/group_messages.py,sha256=GKMQkenQf-6DD_1EJa11UBj7-VfkGT7xVhR_B_zMzqY,3868
|
|
4
|
-
llm_utils/lm.py,sha256=
|
|
4
|
+
llm_utils/lm/__init__.py,sha256=UZ5Ij-d5Xf64luuf1aAiDE1CFRoqEgoCU8xWuuLMBis,226
|
|
5
|
+
llm_utils/lm/base_lm.py,sha256=daT9mueYSZtHE0QMmwRQT_kvjC82KJmZ-u6_1gOIsrw,12560
|
|
6
|
+
llm_utils/lm/chat_session.py,sha256=rkjj1caGu03o7I3d8q_HJzHA9CVzV-Q-mz6iWgqHno0,3301
|
|
7
|
+
llm_utils/lm/pydantic_lm.py,sha256=f-eFFCoCwbPUTZAoS3s217upzUSUHxu-kKlZz6Jl5Ck,6628
|
|
8
|
+
llm_utils/lm/text_lm.py,sha256=xt41ZKDFjIfAe0e2rdvTWePNq5tdvdPzW64UJ8thuXw,4069
|
|
9
|
+
llm_utils/lm/utils.py,sha256=-fDNueiXKQI6RDoNHJYNyORomf2XlCf2doJZ3GEV2Io,4762
|
|
5
10
|
llm_utils/lm_classification.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
11
|
llm_utils/load_chat_dataset.py,sha256=hsPPlOmZEDqsg7GQD7SgxTEGJky6JDm6jnG3JHbpjb4,1895
|
|
7
12
|
llm_utils/scripts/vllm_load_balancer.py,sha256=uSjGd_jOmI9W9eVOhiOXbeUnZkQq9xG4bCVzhmpupcA,16096
|
|
@@ -21,7 +26,7 @@ speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
|
|
|
21
26
|
speedy_utils/multi_worker/process.py,sha256=XwQlffxzRFnCVeKjDNBZDwFfUQHiJiuFA12MRGJVru8,6708
|
|
22
27
|
speedy_utils/multi_worker/thread.py,sha256=9pXjvgjD0s0Hp0cZ6I3M0ndp1OlYZ1yvqbs_bcun_Kw,12775
|
|
23
28
|
speedy_utils/scripts/mpython.py,sha256=ZzkBWI5Xw3vPoMx8xQt2x4mOFRjtwWqfvAJ5_ngyWgw,3816
|
|
24
|
-
speedy_utils-1.0.
|
|
25
|
-
speedy_utils-1.0.
|
|
26
|
-
speedy_utils-1.0.
|
|
27
|
-
speedy_utils-1.0.
|
|
29
|
+
speedy_utils-1.0.1.dist-info/METADATA,sha256=U-8IfrZTleZAsydTg-22oTyjpzsW0pKTPV3P7Jp2WVw,7165
|
|
30
|
+
speedy_utils-1.0.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
31
|
+
speedy_utils-1.0.1.dist-info/entry_points.txt,sha256=fsv8_lMg62BeswoUHrqfj2u6q2l4YcDCw7AgQFg6GRw,61
|
|
32
|
+
speedy_utils-1.0.1.dist-info/RECORD,,
|