speedy-utils 1.0.0-py3-none-any.whl → 1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +3 -3
- llm_utils/lm/__init__.py +12 -0
- llm_utils/lm/base_lm.py +337 -0
- llm_utils/lm/chat_session.py +115 -0
- llm_utils/lm/pydantic_lm.py +195 -0
- llm_utils/lm/text_lm.py +130 -0
- llm_utils/lm/utils.py +130 -0
- {speedy_utils-1.0.0.dist-info → speedy_utils-1.0.1.dist-info}/METADATA +1 -1
- {speedy_utils-1.0.0.dist-info → speedy_utils-1.0.1.dist-info}/RECORD +11 -6
- llm_utils/lm.py +0 -742
- {speedy_utils-1.0.0.dist-info → speedy_utils-1.0.1.dist-info}/WHEEL +0 -0
- {speedy_utils-1.0.0.dist-info → speedy_utils-1.0.1.dist-info}/entry_points.txt +0 -0
llm_utils/lm.py
DELETED
@@ -1,742 +0,0 @@

```python
import fcntl
import os
import random
import tempfile
from copy import deepcopy
import time
from typing import Any, List, Literal, Optional, TypedDict, Dict, Type, Union, cast


import numpy as np
from loguru import logger
from pydantic import BaseModel
from speedy_utils import dump_json_or_pickle, identify_uuid, load_json_or_pickle


class Message(TypedDict):
    role: Literal["user", "assistant", "system"]
    content: str | BaseModel


class ChatSession:

    def __init__(
        self,
        lm: "OAI_LM",
        system_prompt: Optional[str] = None,
        history: List[Message] = [],  # Default to empty list, deepcopy happens below
        callback=None,
        response_format: Optional[Type[BaseModel]] = None,
    ):
        self.lm = deepcopy(lm)
        self.history = deepcopy(history)  # Deepcopy the provided history
        self.callback = callback
        self.response_format = response_format
        if system_prompt:
            system_message: Message = {
                "role": "system",
                "content": system_prompt,
            }
            self.history.insert(0, system_message)

    def __len__(self):
        return len(self.history)

    def __call__(
        self,
        text,
        response_format: Optional[Type[BaseModel]] = None,
        display=False,
        max_prev_turns=3,
        **kwargs,
    ) -> str | BaseModel:
        current_response_format = response_format or self.response_format
        self.history.append({"role": "user", "content": text})
        output = self.lm(
            messages=self.parse_history(),
            response_format=current_response_format,
            **kwargs,
        )
        # output could be a string or a pydantic model
        if isinstance(output, BaseModel):
            self.history.append({"role": "assistant", "content": output})
        else:
            assert response_format is None
            self.history.append({"role": "assistant", "content": output})
        if display:
            self.inspect_history(max_prev_turns=max_prev_turns)

        if self.callback:
            self.callback(self, output)
        return output

    def send_message(self, text, **kwargs):
        """
        Wrapper around __call__ method for sending messages.
        This maintains compatibility with the test suite.
        """
        return self.__call__(text, **kwargs)

    def parse_history(self, indent=None):
        parsed_history = []
        for m in self.history:
            if isinstance(m["content"], str):
                parsed_history.append(m)
            elif isinstance(m["content"], BaseModel):
                parsed_history.append(
                    {
                        "role": m["role"],
                        "content": m["content"].model_dump_json(indent=indent),
                    }
                )
            else:
                raise ValueError(f"Unexpected content type: {type(m['content'])}")
        return parsed_history

    def inspect_history(self, max_prev_turns=3):
        from llm_utils import display_chat_messages_as_html

        h = self.parse_history(indent=2)
        try:
            from IPython.display import clear_output

            clear_output()
            display_chat_messages_as_html(h[-max_prev_turns * 2 :])
        except:
            pass


def _clear_port_use(ports):
    """
    Clear the usage counters for all ports.
    """
    for port in ports:
        file_counter = f"/tmp/port_use_counter_{port}.npy"
        if os.path.exists(file_counter):
            os.remove(file_counter)


def _atomic_save(array: np.ndarray, filename: str):
    """
    Write `array` to `filename` with an atomic rename to avoid partial writes.
    """
    # The temp file must be on the same filesystem as `filename` to ensure
    # that os.replace() is truly atomic.
    tmp_dir = os.path.dirname(filename) or "."
    with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
        np.save(tmp, array)
        temp_name = tmp.name

    # Atomically rename the temp file to the final name.
    # On POSIX systems, os.replace is an atomic operation.
    os.replace(temp_name, filename)


def _update_port_use(port: int, increment: int):
    """
    Update the usage counter for a given port, safely with an exclusive lock
    and atomic writes to avoid file corruption.
    """
    file_counter = f"/tmp/port_use_counter_{port}.npy"
    file_counter_lock = f"/tmp/port_use_counter_{port}.lock"

    with open(file_counter_lock, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            # If file exists, load it. Otherwise assume zero usage.
            if os.path.exists(file_counter):
                try:
                    counter = np.load(file_counter)
                except Exception as e:
                    # If we fail to load (e.g. file corrupted), start from zero
                    logger.warning(f"Corrupted usage file {file_counter}: {e}")
                    counter = np.array([0])
            else:
                counter = np.array([0])

            # Increment usage and atomically overwrite the old file
            counter[0] += increment
            _atomic_save(counter, file_counter)

        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)


def _pick_least_used_port(ports: List[int]) -> int:
    """
    Pick the least-used port among the provided list, safely under a global lock
    so that no two processes pick a port at the same time.
    """
    global_lock_file = "/tmp/ports.lock"

    with open(global_lock_file, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            port_use: Dict[int, int] = {}
            # Read usage for each port
            for port in ports:
                file_counter = f"/tmp/port_use_counter_{port}.npy"
                if os.path.exists(file_counter):
                    try:
                        counter = np.load(file_counter)
                    except Exception as e:
                        # If the file is corrupted, reset usage to 0
                        logger.warning(f"Corrupted usage file {file_counter}: {e}")
                        counter = np.array([0])
                else:
                    counter = np.array([0])
                port_use[port] = counter[0]

            logger.debug(f"Port use: {port_use}")

            if not port_use:
                if ports:
                    raise ValueError("Port usage data is empty, cannot pick a port.")
                else:
                    raise ValueError("No ports provided to pick from.")

            # Pick the least-used port
            lsp = min(port_use, key=lambda k: port_use[k])

            # Increment usage of that port
            _update_port_use(lsp, 1)

        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)

    return lsp


class OAI_LM:
    """
    A language model supporting chat or text completion requests for use with DSPy modules.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        model_type: Literal["chat", "text"] = "chat",
        temperature: float = 0.0,
        max_tokens: int = 2000,
        cache: bool = True,
        callbacks: Optional[Any] = None,
        num_retries: int = 3,
        provider=None,
        finetuning_model: Optional[str] = None,
        launch_kwargs: Optional[dict[str, Any]] = None,
        host: str = "localhost",
        port: Optional[int] = None,
        ports: Optional[List[int]] = None,
        api_key: Optional[str] = None,
        **kwargs,
    ):
        # Lazy import dspy
        import dspy

        self.ports = ports
        self.host = host
        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")

        # Determine base_url: kwargs["base_url"] > http://host:port > http://host:ports[0]
        resolved_base_url_from_kwarg = kwargs.get("base_url")
        if resolved_base_url_from_kwarg is not None and not isinstance(
            resolved_base_url_from_kwarg, str
        ):
            logger.warning(
                f"base_url in kwargs was not a string ({type(resolved_base_url_from_kwarg)}), ignoring."
            )
            resolved_base_url_from_kwarg = None

        resolved_base_url: Optional[str] = cast(
            Optional[str], resolved_base_url_from_kwarg
        )

        if resolved_base_url is None:
            selected_port = port
            if selected_port is None and ports is not None and len(ports) > 0:
                selected_port = ports[0]

            if selected_port is not None:
                resolved_base_url = f"http://{host}:{selected_port}/v1"
        self.base_url = resolved_base_url

        if model is None:
            if self.base_url:
                try:
                    model_list = (
                        self.list_models()
                    )  # Uses self.base_url and self.api_key
                    if model_list:
                        model_name_from_list = model_list[0]
                        model = f"openai/{model_name_from_list}"
                        logger.info(f"Using default model: {model}")
                    else:
                        logger.warning(
                            f"No models found at {self.base_url}. Please specify a model."
                        )
                except Exception as e:
                    example_cmd = (
                        "LM.start_server('unsloth/gemma-3-1b-it')\n"
                        "# Or manually run: svllm serve --model unsloth/gemma-3-1b-it --gpus 0 -hp localhost:9150"
                    )
                    logger.error(
                        f"Failed to list models from {self.base_url}: {e}\n"
                        f"Make sure your model server is running and accessible.\n"
                        f"Example to start a server:\n{example_cmd}"
                    )
            else:
                logger.warning(
                    "base_url not configured, cannot fetch default model. Please specify a model."
                )

        assert (
            model is not None
        ), "Model name must be provided or discoverable via list_models"

        if not model.startswith("openai/"):
            model = f"openai/{model}"

        dspy_lm_kwargs = kwargs.copy()
        dspy_lm_kwargs["api_key"] = self.api_key  # Ensure dspy.LM gets this

        if self.base_url and "base_url" not in dspy_lm_kwargs:
            dspy_lm_kwargs["base_url"] = self.base_url
        elif (
            self.base_url
            and "base_url" in dspy_lm_kwargs
            and dspy_lm_kwargs["base_url"] != self.base_url
        ):
            # If kwarg['base_url'] exists and differs from derived self.base_url,
            # dspy.LM will use kwarg['base_url']. Update self.base_url to reflect this.
            self.base_url = dspy_lm_kwargs["base_url"]

        self._dspy_lm: dspy.LM = dspy.LM(
            model=model,
            model_type=model_type,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=callbacks,
            num_retries=num_retries,
            provider=provider,
            finetuning_model=finetuning_model,
            launch_kwargs=launch_kwargs,
            # api_key is passed via dspy_lm_kwargs
            **dspy_lm_kwargs,
        )
        # Store the actual kwargs used by dspy.LM
        self.kwargs = self._dspy_lm.kwargs
        self.model = self._dspy_lm.model  # self.model is str

        # Ensure self.base_url and self.api_key are consistent with what dspy.LM is using
        self.base_url = self.kwargs.get("base_url")
        self.api_key = self.kwargs.get("api_key")

        self.do_cache = cache

    @property
    def last_response(self):
        return self._dspy_lm.history[-1]["response"].model_dump()["choices"][0][
            "message"
        ]

    def __call__(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Message]] = None,
        response_format: Optional[Type[BaseModel]] = None,
        cache: Optional[bool] = None,
        retry_count: int = 0,
        port: Optional[int] = None,
        error: Optional[Exception] = None,
        use_loadbalance: Optional[bool] = None,
        must_load_cache: bool = False,
        max_tokens: Optional[int] = None,
        num_retries: int = 10,
        **kwargs,
    ) -> Union[str, BaseModel]:
        if retry_count > num_retries:
            logger.error(f"Retry limit exceeded, error: {error}")
            if error:
                raise error
            raise ValueError("Retry limit exceeded with no specific error.")

        effective_kwargs = {**self.kwargs, **kwargs}
        id_for_cache: Optional[str] = None

        effective_cache = cache if cache is not None else self.do_cache

        if max_tokens is not None:
            effective_kwargs["max_tokens"] = max_tokens

        if response_format:
            assert issubclass(
                response_format, BaseModel
            ), f"response_format must be a Pydantic model class, {type(response_format)} provided"

        cached_result: Optional[Union[str, BaseModel, List[Union[str, BaseModel]]]] = (
            None
        )
        if effective_cache:
            cache_key_list = [
                prompt,
                messages,
                (response_format.model_json_schema() if response_format else None),
                effective_kwargs.get("temperature"),
                effective_kwargs.get("max_tokens"),
                self.model,
            ]
            s = str(cache_key_list)
            id_for_cache = identify_uuid(s)
            cached_result = self.load_cache(id_for_cache)

        if cached_result is not None:
            if response_format:
                if isinstance(cached_result, str):
                    try:
                        import json_repair

                        parsed = json_repair.loads(cached_result)
                        if not isinstance(parsed, dict):
                            raise ValueError("Parsed cached_result is not a dict")
                        # Ensure keys are strings
                        parsed = {str(k): v for k, v in parsed.items()}
                        return response_format(**parsed)
                    except Exception as e_parse:
                        logger.warning(
                            f"Failed to parse cached string for {id_for_cache} into {response_format.__name__}: {e_parse}. Retrying LLM call."
                        )
                elif isinstance(cached_result, response_format):
                    return cached_result
                else:
                    logger.warning(
                        f"Cached result for {id_for_cache} has unexpected type {type(cached_result)}. Expected {response_format.__name__} or str. Retrying LLM call."
                    )
            else:  # No response_format, expect string
                if isinstance(cached_result, str):
                    return cached_result
                else:
                    logger.warning(
                        f"Cached result for {id_for_cache} has unexpected type {type(cached_result)}. Expected str. Retrying LLM call."
                    )

        if (
            must_load_cache and cached_result is None
        ):  # If we are here, cache load failed or was not suitable
            raise ValueError(
                "must_load_cache is True, but failed to load a valid response from cache."
            )

        import litellm

        current_port: int | None = port
        if self.ports and not current_port:
            if use_loadbalance:
                current_port = self.get_least_used_port()
            else:
                current_port = random.choice(self.ports)

        if current_port:
            effective_kwargs["base_url"] = f"http://{self.host}:{current_port}/v1"

        llm_output_or_outputs: Union[str, BaseModel, List[Union[str, BaseModel]]]
        try:
            dspy_main_input: Union[str, List[Message]]
            if messages is not None:
                dspy_main_input = messages
            elif prompt is not None:
                dspy_main_input = prompt
            else:
                # Depending on LM capabilities, this might be valid if other means of generation are used (e.g. tool use)
                # For now, assume one is needed for typical completion/chat.
                # Consider if _dspy_lm can handle None/empty input gracefully or if an error is better.
                # If dspy.LM expects a non-null primary argument, this will fail there.
                # For safety, let's raise if both are None, assuming typical usage.
                raise ValueError(
                    "Either 'prompt' or 'messages' must be provided for the LLM call."
                )

            llm_outputs_list = self._dspy_lm(
                dspy_main_input,  # Pass as positional argument
                response_format=response_format,  # Pass as keyword argument, dspy will handle it in its **kwargs
                **effective_kwargs,
            )

            if not llm_outputs_list:
                raise ValueError("LLM call returned an empty list.")

            # Convert dict outputs to string to match expected return type
            def convert_output(o):
                if isinstance(o, dict):
                    import json

                    return json.dumps(o)
                return o

            if effective_kwargs.get("n", 1) == 1:
                llm_output_or_outputs = convert_output(llm_outputs_list[0])
            else:
                llm_output_or_outputs = [convert_output(o) for o in llm_outputs_list]

        except (litellm.exceptions.APIError, litellm.exceptions.Timeout) as e_llm:
            t = 3
            base_url_info = effective_kwargs.get("base_url", "N/A")
            log_msg = f"[{base_url_info=}] {type(e_llm).__name__}: {str(e_llm)[:100]}, will sleep for {t}s and retry"
            logger.warning(log_msg)  # Always warn on retry for these
            time.sleep(t)
            return self.__call__(
                prompt=prompt,
                messages=messages,
                response_format=response_format,
                cache=cache,
                retry_count=retry_count + 1,
                port=current_port,
                error=e_llm,
                use_loadbalance=use_loadbalance,
                must_load_cache=must_load_cache,
                max_tokens=max_tokens,
                num_retries=num_retries,
                **kwargs,
            )
        except litellm.exceptions.ContextWindowExceededError as e_cwe:
            logger.error(f"Context window exceeded: {e_cwe}")
            raise
        except Exception as e_generic:
            logger.error(f"Generic error during LLM call: {e_generic}")
            import traceback

            traceback.print_exc()
            raise
        finally:
            if (
                current_port and use_loadbalance is True
            ):  # Ensure use_loadbalance is explicitly True
                _update_port_use(current_port, -1)

        if effective_cache and id_for_cache:
            self.dump_cache(id_for_cache, llm_output_or_outputs)

        # Ensure single return if n=1, which is implied by method signature str | BaseModel
        final_output: Union[str, BaseModel]
        if isinstance(llm_output_or_outputs, list):
            # This should ideally not happen if n=1 was handled correctly above.
            # If it's a list, it means n > 1. The method signature needs to change for that.
            # For now, stick to returning the first element if it's a list.
            logger.warning(
                "LLM returned multiple completions; __call__ expects single. Returning first."
            )
            final_output = llm_output_or_outputs[0]
        else:
            final_output = llm_output_or_outputs  # type: ignore  # It's already Union[str, BaseModel]

        if response_format:
            if not isinstance(final_output, response_format):
                if isinstance(final_output, str):
                    logger.warning(
                        f"LLM call returned string, but expected {response_format.__name__}. Attempting parse."
                    )
                    try:
                        import json_repair

                        parsed_dict = json_repair.loads(final_output)
                        if not isinstance(parsed_dict, dict):
                            raise ValueError("Parsed output is not a dict")
                        parsed_dict = {str(k): v for k, v in parsed_dict.items()}
                        parsed_output = response_format(**parsed_dict)
                        if effective_cache and id_for_cache:
                            self.dump_cache(
                                id_for_cache, parsed_output
                            )  # Cache the successfully parsed model
                        return parsed_output
                    except Exception as e_final_parse:
                        logger.error(
                            f"Final attempt to parse LLM string output into {response_format.__name__} failed: {e_final_parse}"
                        )
                        # Retry without cache to force regeneration
                        return self.__call__(
                            prompt=prompt,
                            messages=messages,
                            response_format=response_format,
                            cache=False,
                            retry_count=retry_count + 1,
                            port=current_port,
                            error=e_final_parse,
                            use_loadbalance=use_loadbalance,
                            must_load_cache=False,
                            max_tokens=max_tokens,
                            num_retries=num_retries,
                            **kwargs,
                        )
                else:
                    logger.error(
                        f"LLM output type mismatch. Expected {response_format.__name__} or str, got {type(final_output)}. Raising error."
                    )
                    raise TypeError(
                        f"LLM output type mismatch: expected {response_format.__name__}, got {type(final_output)}"
                    )
            return final_output  # Already a response_format instance
        else:  # No response_format, expect string
            if not isinstance(final_output, str):
                # This could happen if LLM returns structured data and dspy parses it even without response_format
                logger.warning(
                    f"LLM output type mismatch. Expected str, got {type(final_output)}. Attempting to convert to string."
                )
                # Convert to string, or handle as error depending on desired strictness
                return str(final_output)  # Or raise TypeError
            return final_output

    def clear_port_use(self):
        if self.ports:
            _clear_port_use(self.ports)
        else:
            logger.warning("No ports configured to clear usage for.")

    def get_least_used_port(self) -> int:
        if self.ports is None:
            raise ValueError("Ports must be configured to pick the least used port.")
        if not self.ports:
            raise ValueError("Ports list is empty, cannot pick a port.")
        return _pick_least_used_port(self.ports)

    def get_session(
        self,
        system_prompt: Optional[str],
        history: Optional[List[Message]] = None,
        callback=None,
        response_format: Optional[Type[BaseModel]] = None,
        **kwargs,  # kwargs are not used by ChatSession constructor
    ) -> ChatSession:
        actual_history = deepcopy(history) if history is not None else []
        return ChatSession(
            self,
            system_prompt=system_prompt,
            history=actual_history,
            callback=callback,
            response_format=response_format,
            # **kwargs,  # ChatSession constructor does not accept **kwargs
        )

    def dump_cache(
        self, id: str, result: Union[str, BaseModel, List[Union[str, BaseModel]]]
    ):
        try:
            cache_file = f"~/.cache/oai_lm/{self.model}/{id}.pkl"
            cache_file = os.path.expanduser(cache_file)

            dump_json_or_pickle(result, cache_file)
        except Exception as e:
            logger.warning(f"Cache dump failed: {e}")

    def load_cache(
        self, id: str
    ) -> Optional[Union[str, BaseModel, List[Union[str, BaseModel]]]]:
        try:
            cache_file = f"~/.cache/oai_lm/{self.model}/{id}.pkl"
            cache_file = os.path.expanduser(cache_file)
            if not os.path.exists(cache_file):
                return
            return load_json_or_pickle(cache_file)
        except Exception as e:
            logger.warning(f"Cache load failed for {id}: {e}")  # Added id to log
            return None

    def list_models(self) -> List[str]:
        import openai

        if not self.base_url:
            raise ValueError("Cannot list models: base_url is not configured.")
        if not self.api_key:  # api_key should be set by __init__
            logger.warning(
                "API key not available for listing models. Using default 'abc'."
            )

        api_key_str = str(self.api_key) if self.api_key is not None else "abc"
        base_url_str = str(self.base_url) if self.base_url is not None else None
        if isinstance(self.base_url, float):
            raise TypeError(f"base_url must be a string or None, got float: {self.base_url}")
        client = openai.OpenAI(base_url=base_url_str, api_key=api_key_str)
        page = client.models.list()
        return [d.id for d in page.data]

    @property
    def client(self):
        import openai

        if not self.base_url:
            raise ValueError("Cannot create client: base_url is not configured.")
        if not self.api_key:
            logger.warning("API key not available for client. Using default 'abc'.")

        base_url_str = str(self.base_url) if self.base_url is not None else None
        api_key_str = str(self.api_key) if self.api_key is not None else "abc"
        return openai.OpenAI(base_url=base_url_str, api_key=api_key_str)

    def __getattr__(self, name):
        """
        Delegate any attributes not found in OAI_LM to the underlying dspy.LM instance.
        This makes sure any dspy.LM methods not explicitly defined in OAI_LM are still accessible.
        """
        # Check __dict__ directly to avoid recursion via hasattr
        if "_dspy_lm" in self.__dict__ and hasattr(self._dspy_lm, name):
            return getattr(self._dspy_lm, name)
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    @classmethod
    def get_deepseek_chat(
        cls, api_key: Optional[str] = None, max_tokens: int = 2000, **kwargs
    ):
        api_key_to_pass = cast(
            Optional[str], api_key or os.environ.get("DEEPSEEK_API_KEY")
        )
        return cls(  # Use cls instead of OAI_LM
            base_url="https://api.deepseek.com/v1",
            model="deepseek-chat",
            api_key=api_key_to_pass,
            max_tokens=max_tokens,
            **kwargs,
        )

    @classmethod
    def get_deepseek_reasoner(
        cls, api_key: Optional[str] = None, max_tokens: int = 2000, **kwargs
    ):
        api_key_to_pass = cast(
            Optional[str], api_key or os.environ.get("DEEPSEEK_API_KEY")
        )
        return cls(  # Use cls instead of OAI_LM
            base_url="https://api.deepseek.com/v1",
            model="deepseek-reasoner",
            api_key=api_key_to_pass,
            max_tokens=max_tokens,
            **kwargs,
        )

    @classmethod
    def start_server(
        cls, model_name: str, gpus: str = "4567", port: int = 9150, eager: bool = True
    ):
        cmd = f"svllm serve --model {model_name} --gpus {gpus} -hp localhost:{port}"
        if eager:
            cmd += " --eager"
        session_name = f"vllm_{port}"
        is_session_exists = os.system(f"tmux has-session -t {session_name}")
        logger.info(f"Starting server with command: {cmd}")
        if is_session_exists == 0:
            logger.warning(
                f"Session {session_name} exists, please kill it before running the script"
            )
            # ask user if they want to kill the session
            user_input = input(
                f"Session {session_name} exists, do you want to kill it? (y/n): "
            )
            if user_input.lower() == "y":
                os.system(f"tmux kill-session -t {session_name}")
                logger.info(f"Session {session_name} killed")
        os.system(cmd)
        # return subprocess.Popen(shlex.split(cmd))

    # set get_agent as an alias of get_session
    get_agent = get_session


LM = OAI_LM
```