speedy-utils 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +1 -5
- llm_utils/chat_format/display.py +17 -4
- llm_utils/chat_format/transform.py +9 -9
- llm_utils/group_messages.py +1 -1
- llm_utils/lm/async_lm/__init__.py +7 -0
- llm_utils/lm/async_lm/_utils.py +201 -0
- llm_utils/lm/async_lm/async_llm_task.py +509 -0
- llm_utils/lm/async_lm/async_lm.py +387 -0
- llm_utils/lm/async_lm/async_lm_base.py +405 -0
- llm_utils/lm/async_lm/lm_specific.py +136 -0
- llm_utils/lm/utils.py +1 -3
- llm_utils/scripts/vllm_load_balancer.py +244 -147
- speedy_utils/__init__.py +3 -1
- speedy_utils/common/notebook_utils.py +4 -4
- speedy_utils/common/report_manager.py +2 -3
- speedy_utils/common/utils_cache.py +233 -3
- speedy_utils/common/utils_io.py +2 -0
- speedy_utils/scripts/mpython.py +1 -3
- {speedy_utils-1.1.5.dist-info → speedy_utils-1.1.7.dist-info}/METADATA +1 -1
- speedy_utils-1.1.7.dist-info/RECORD +39 -0
- llm_utils/lm/async_lm.py +0 -942
- llm_utils/lm/chat_html.py +0 -246
- llm_utils/lm/lm_json.py +0 -68
- llm_utils/lm/sync_lm.py +0 -943
- speedy_utils-1.1.5.dist-info/RECORD +0 -37
- {speedy_utils-1.1.5.dist-info → speedy_utils-1.1.7.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.5.dist-info → speedy_utils-1.1.7.dist-info}/entry_points.txt +0 -0
llm_utils/lm/async_lm/async_lm.py (new file)
@@ -0,0 +1,387 @@
+# from ._utils import *
+from typing import (
+    Any,
+    List,
+    Literal,
+    Optional,
+    Type,
+    cast,
+)
+
+from loguru import logger
+from openai import AuthenticationError, BadRequestError, RateLimitError
+from pydantic import BaseModel
+from speedy_utils import jloads
+
+# from llm_utils.lm.async_lm.async_llm_task import OutputModelType
+from llm_utils.lm.async_lm.async_lm_base import AsyncLMBase
+
+from ._utils import (
+    LegacyMsgs,
+    Messages,
+    OutputModelType,
+    ParsedOutput,
+    RawMsgs,
+)
+
+
+def jloads_safe(content: str) -> Any:
+    # if contain ```json, remove it
+    if "```json" in content:
+        content = content.split("```json")[1].strip().split("```")[0].strip()
+    try:
+        return jloads(content)
+    except Exception as e:
+        logger.error(
+            f"Failed to parse JSON content: {content[:100]}... with error: {e}"
+        )
+        raise ValueError(f"Invalid JSON content: {content}") from e
+
+
+class AsyncLM(AsyncLMBase):
+    """Unified **async** language‑model wrapper with optional JSON parsing."""
+
+    def __init__(
+        self,
+        model: str,
+        *,
+        response_model: Optional[type[BaseModel]] = None,
+        temperature: float = 0.0,
+        max_tokens: int = 2_000,
+        host: str = "localhost",
+        port: Optional[int | str] = None,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        cache: bool = True,
+        think: Literal[True, False, None] = None,
+        add_json_schema_to_instruction: Optional[bool] = None,
+        use_beta: bool = False,
+        ports: Optional[List[int]] = None,
+        top_p: float = 1.0,
+        presence_penalty: float = 0.0,
+        top_k: int = 1,
+        repetition_penalty: float = 1.0,
+        frequency_penalty: Optional[float] = None,
+    ) -> None:
+        super().__init__(
+            host=host,
+            port=port,
+            ports=ports,
+            base_url=base_url,
+            cache=cache,
+            api_key=api_key,
+        )
+
+        # Model behavior options
+        self.response_model = response_model
+        self.think = think
+        self._use_beta = use_beta
+        self.add_json_schema_to_instruction = add_json_schema_to_instruction
+        if not use_beta:
+            self.add_json_schema_to_instruction = True
+
+        # Store all model-related parameters in model_kwargs
+        self.model_kwargs = dict(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            presence_penalty=presence_penalty,
+        )
+        self.extra_body = dict(
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
+        )
+
+    async def _unified_client_call(
+        self,
+        messages: list[dict],
+        extra_body: Optional[dict] = None,
+        cache_suffix: str = "",
+    ) -> dict:
+        """Unified method for all client interactions with caching and error handling."""
+        converted_messages = self._convert_messages(messages)
+        cache_key = None
+        completion = None
+
+        # Handle caching
+        if self._cache:
+            cache_data = {
+                "messages": converted_messages,
+                "model_kwargs": self.model_kwargs,
+                "extra_body": extra_body or {},
+                "cache_suffix": cache_suffix,
+            }
+            cache_key = self._cache_key(cache_data, {}, str)
+            completion = self._load_cache(cache_key)
+
+        # Check for cached error responses
+        if (
+            completion
+            and isinstance(completion, dict)
+            and "error" in completion
+            and completion["error"]
+        ):
+            error_type = completion.get("error_type", "Unknown")
+            error_message = completion.get("error_message", "Cached error")
+            logger.warning(f"Found cached error ({error_type}): {error_message}")
+            raise ValueError(f"Cached {error_type}: {error_message}")
+
+        try:
+            # Get completion from API if not cached
+            if not completion:
+                call_kwargs = {
+                    "messages": converted_messages,
+                    **self.model_kwargs,
+                }
+                if extra_body:
+                    call_kwargs["extra_body"] = extra_body
+
+                completion = await self.client.chat.completions.create(**call_kwargs)
+
+                if hasattr(completion, "model_dump"):
+                    completion = completion.model_dump()
+                if cache_key:
+                    self._dump_cache(cache_key, completion)
+
+        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+            logger.error(error_msg)
+            if isinstance(exc, BadRequestError) and cache_key:
+                error_response = {
+                    "error": True,
+                    "error_type": "BadRequestError",
+                    "error_message": str(exc),
+                    "choices": [],
+                }
+                self._dump_cache(cache_key, error_response)
+                logger.debug(f"Cached BadRequestError for key: {cache_key}")
+            raise
+
+        return completion
+
+    async def _call_and_parse(
+        self,
+        messages: list[dict],
+        response_model: Type[OutputModelType],
+        json_schema: dict,
+    ) -> tuple[dict, list[dict], OutputModelType]:
+        """Unified call and parse with cache and error handling."""
+        if self._use_beta:
+            return await self._call_and_parse_with_beta(
+                messages, response_model, json_schema
+            )
+
+        choice = None
+        try:
+            # Use unified client call
+            completion = await self._unified_client_call(
+                messages,
+                extra_body={**self.extra_body},
+                cache_suffix=f"_parse_{response_model.__name__}",
+            )
+
+            # Parse the response
+            choice = completion["choices"][0]["message"]
+            if "content" not in choice:
+                raise ValueError("Response choice must contain 'content' field.")
+
+            content = choice["content"]
+            if not content:
+                raise ValueError("Response content is empty")
+
+            parsed = response_model.model_validate(jloads_safe(content))
+
+        except Exception as e:
+            # Try fallback to beta mode if regular parsing fails
+            if not isinstance(
+                e, (AuthenticationError, RateLimitError, BadRequestError)
+            ):
+                content = choice.get("content", "N/A") if choice else "N/A"
+                logger.info(
+                    f"Regular parsing failed due to wrong format or content, now falling back to beta mode: {content=}, {e=}"
+                )
+                try:
+                    return await self._call_and_parse_with_beta(
+                        messages, response_model, json_schema
+                    )
+                except Exception as beta_e:
+                    logger.warning(f"Beta mode fallback also failed: {beta_e}")
+                    choice_info = choice if choice is not None else "N/A"
+                    raise ValueError(
+                        f"Failed to parse model response with both regular and beta modes. "
+                        f"Regular error: {e}. Beta error: {beta_e}. "
+                        f"Model response message: {choice_info}"
+                    ) from e
+            raise
+
+        assistant_msg = self._extract_assistant_message(choice)
+        full_messages = messages + [assistant_msg]
+
+        return completion, full_messages, cast(OutputModelType, parsed)
+
+    async def _call_and_parse_with_beta(
+        self,
+        messages: list[dict],
+        response_model: Type[OutputModelType],
+        json_schema: dict,
+    ) -> tuple[dict, list[dict], OutputModelType]:
+        """Call and parse for beta mode with guided JSON."""
+        choice = None
+        try:
+            # Use unified client call with guided JSON
+            completion = await self._unified_client_call(
+                messages,
+                extra_body={"guided_json": json_schema, **self.extra_body},
+                cache_suffix=f"_beta_parse_{response_model.__name__}",
+            )
+
+            # Parse the response
+            choice = completion["choices"][0]["message"]
+            parsed = self._parse_complete_output(completion, response_model)
+
+        except Exception as e:
+            choice_info = choice if choice is not None else "N/A"
+            raise ValueError(
+                f"Failed to parse model response: {e}\nModel response message: {choice_info}"
+            ) from e
+
+        assistant_msg = self._extract_assistant_message(choice)
+        full_messages = messages + [assistant_msg]
+
+        return completion, full_messages, cast(OutputModelType, parsed)
+
+    def _extract_assistant_message(self, choice): # -> dict[str, str] | dict[str, Any]:
+        # TODO this current assume choice is a dict with "reasoning_content" and "content"
+        has_reasoning = False
+        if "reasoning_content" in choice and isinstance(
+            choice["reasoning_content"], str
+        ):
+            reasoning_content = choice["reasoning_content"].strip()
+            has_reasoning = True
+
+        content = choice["content"]
+        _content = content.lstrip("\n")
+        if has_reasoning:
+            assistant_msg = {
+                "role": "assistant",
+                "content": f"<think>\n{reasoning_content}\n</think>\n\n{_content}",
+            }
+        else:
+            assistant_msg = {"role": "assistant", "content": _content}
+
+        return assistant_msg
+
+    async def __call__(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[RawMsgs] = None,
+    ): # -> tuple[Any | dict[Any, Any], list[ChatCompletionMessagePar...:
+        """Unified async call for language model, returns (assistant_message.model_dump(), messages)."""
+        if (prompt is None) == (messages is None):
+            raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
+
+        if prompt is not None:
+            messages = [{"role": "user", "content": prompt}]
+
+        assert messages is not None
+
+        openai_msgs: Messages = (
+            self._convert_messages(cast(LegacyMsgs, messages))
+            if isinstance(messages[0], dict)
+            else cast(Messages, messages)
+        )
+
+        assert self.model_kwargs["model"] is not None, (
+            "Model must be set before making a call."
+        )
+
+        # Use unified client call
+        raw_response = await self._unified_client_call(
+            list(openai_msgs), cache_suffix="_call"
+        )
+
+        if hasattr(raw_response, "model_dump"):
+            raw_response = raw_response.model_dump() # type: ignore
+
+        # Extract the assistant's message
+        assistant_msg = raw_response["choices"][0]["message"]
+        # Build the full messages list (input + assistant reply)
+        full_messages = list(messages) + [
+            {"role": assistant_msg["role"], "content": assistant_msg["content"]}
+        ]
+        # Return the OpenAI message as model_dump (if available) and the messages list
+        if hasattr(assistant_msg, "model_dump"):
+            msg_dump = assistant_msg.model_dump()
+        else:
+            msg_dump = dict(assistant_msg)
+        return msg_dump, full_messages
+
+    async def parse(
+        self,
+        instruction,
+        prompt,
+    ) -> ParsedOutput[BaseModel]:
+        """Parse response using guided JSON generation. Returns (parsed.model_dump(), messages)."""
+        if not self._use_beta:
+            assert self.add_json_schema_to_instruction, (
+                "add_json_schema_to_instruction must be True when use_beta is False. otherwise model will not be able to parse the response."
+            )
+
+        assert self.response_model is not None, "response_model must be set at init."
+        json_schema = self.response_model.model_json_schema()
+
+        # Build system message content in a single, clear block
+        assert instruction is not None, "Instruction must be provided."
+        assert prompt is not None, "Prompt must be provided."
+        system_content = instruction
+
+        # Add schema if needed
+        system_content = self.build_system_prompt(
+            self.response_model,
+            self.add_json_schema_to_instruction,
+            json_schema,
+            system_content,
+            think=self.think,
+        )
+
+        messages = [
+            {"role": "system", "content": system_content},
+            {"role": "user", "content": prompt},
+        ] # type: ignore
+
+        completion, full_messages, parsed = await self._call_and_parse(
+            messages,
+            self.response_model,
+            json_schema,
+        )
+
+        return ParsedOutput(
+            messages=full_messages,
+            parsed=cast(BaseModel, parsed),
+            completion=completion,
+            model_kwargs=self.model_kwargs,
+        )
+
+    def _parse_complete_output(
+        self, completion: Any, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """Parse completion output to response model."""
+        if hasattr(completion, "model_dump"):
+            completion = completion.model_dump()
+
+        if "choices" not in completion or not completion["choices"]:
+            raise ValueError("No choices in OpenAI response")
+
+        content = completion["choices"][0]["message"]["content"]
+        if not content:
+            raise ValueError("Response content is empty")
+
+        try:
+            data = jloads(content)
+            return response_model.model_validate(data)
+        except Exception as exc:
+            raise ValueError(
+                f"Failed to validate against response model {response_model.__name__}: {exc}\nRaw content: {content}"
+            ) from exc
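
For readers skimming the diff, here is a minimal usage sketch of the new `AsyncLM` wrapper. The import path, constructor parameters, and the `__call__`/`parse` return shapes are taken from the hunk above; the model name, port, and the `Answer` schema are illustrative assumptions, attribute access on `ParsedOutput` is inferred from its constructor call, and the actual client/endpoint wiring lives in `async_lm_base.py`, which this hunk does not show.

```python
import asyncio

from pydantic import BaseModel

from llm_utils.lm.async_lm.async_lm import AsyncLM


class Answer(BaseModel):
    # Hypothetical response schema; any Pydantic model can serve as response_model.
    answer: str
    confidence: float


async def main() -> None:
    # Plain chat: __call__ returns (assistant_message_dict, full_message_list).
    lm = AsyncLM(model="my-model", port=8000)  # assumes an OpenAI-compatible server
    msg, history = await lm(prompt="Say hello in one word.")
    print(msg["content"])

    # Structured output: parse() builds a system prompt from `instruction`
    # (embedding the JSON schema unless use_beta handles it) and validates the
    # reply against response_model. With use_beta=True the schema is sent as
    # extra_body={"guided_json": ...} for guided decoding instead.
    structured = AsyncLM(
        model="my-model",
        port=8000,
        response_model=Answer,
        use_beta=True,
    )
    result = await structured.parse(
        instruction="Answer the question and rate your confidence from 0 to 1.",
        prompt="What is the capital of France?",
    )
    print(result.parsed)  # an Answer instance, per ParsedOutput(parsed=...)


asyncio.run(main())
```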
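
The module-level `jloads_safe` helper can also be exercised on its own: per the hunk above it strips a leading "```json" fence before parsing and raises `ValueError` when parsing fails. A small sketch, assuming `speedy_utils.jloads` accepts plain JSON strings the way `json.loads` does:

```python
from llm_utils.lm.async_lm.async_lm import jloads_safe

# Fenced and bare JSON both parse to the same dict.
print(jloads_safe('```json\n{"city": "Paris"}\n```'))  # {'city': 'Paris'}
print(jloads_safe('{"city": "Paris"}'))                # {'city': 'Paris'}
```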