speedy-utils 1.1.17__py3-none-any.whl → 1.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +9 -1
- llm_utils/chat_format/display.py +109 -14
- llm_utils/lm/__init__.py +12 -11
- llm_utils/lm/async_lm/async_llm_task.py +1 -10
- llm_utils/lm/async_lm/async_lm.py +13 -4
- llm_utils/lm/async_lm/async_lm_base.py +24 -14
- llm_utils/lm/base_prompt_builder.py +288 -0
- llm_utils/lm/llm_task.py +693 -0
- llm_utils/lm/lm.py +207 -0
- llm_utils/lm/lm_base.py +285 -0
- llm_utils/lm/openai_memoize.py +2 -2
- llm_utils/vector_cache/core.py +285 -89
- speedy_utils/__init__.py +2 -1
- speedy_utils/common/patcher.py +68 -0
- speedy_utils/common/utils_cache.py +6 -6
- speedy_utils/common/utils_io.py +238 -8
- speedy_utils/multi_worker/process.py +180 -192
- speedy_utils/multi_worker/thread.py +94 -2
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.19.dist-info}/METADATA +36 -14
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.19.dist-info}/RECORD +24 -19
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.19.dist-info}/WHEEL +1 -1
- speedy_utils-1.1.19.dist-info/entry_points.txt +5 -0
- speedy_utils-1.1.17.dist-info/entry_points.txt +0 -6
llm_utils/lm/llm_task.py
ADDED
@@ -0,0 +1,693 @@
# type: ignore

"""
Simplified LLM Task module for handling language model interactions with structured input/output.
"""

import os
from typing import Any, Dict, List, Optional, Type, Union, cast

import requests
from loguru import logger
from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from .base_prompt_builder import BasePromptBuilder

# Type aliases for better readability
Messages = List[ChatCompletionMessageParam]


def get_base_client(
    client: Union[OpenAI, int, str, None] = None, cache: bool = True, api_key="abc"
) -> OpenAI:
    """Get OpenAI client from port number, base_url string, or existing client."""
    from llm_utils import MOpenAI

    open_ai_class = OpenAI if not cache else MOpenAI
    if client is None:
        return open_ai_class()
    elif isinstance(client, int):
        return open_ai_class(base_url=f"http://localhost:{client}/v1", api_key=api_key)
    elif isinstance(client, str):
        return open_ai_class(base_url=client, api_key=api_key)
    elif isinstance(client, OpenAI):
        return client
    else:
        raise ValueError(
            "Invalid client type. Must be OpenAI instance, port number (int), base_url (str), or None."
        )


def _is_lora_path(path: str) -> bool:
    """Check if the given path is a LoRA adapter directory.

    Args:
        path: Path to check

    Returns:
        True if the path contains adapter_config.json, False otherwise
    """
    if not os.path.isdir(path):
        return False
    adapter_config_path = os.path.join(path, 'adapter_config.json')
    return os.path.isfile(adapter_config_path)


def _get_port_from_client(client: OpenAI) -> Optional[int]:
    """Extract port number from OpenAI client base_url.

    Args:
        client: OpenAI client instance

    Returns:
        Port number if found, None otherwise
    """
    if hasattr(client, 'base_url') and client.base_url:
        base_url = str(client.base_url)
        if 'localhost:' in base_url:
            try:
                # Extract port from localhost:PORT/v1 format
                port_part = base_url.split('localhost:')[1].split('/')[0]
                return int(port_part)
            except (IndexError, ValueError):
                pass
    return None


def _load_lora_adapter(lora_path: str, port: int) -> str:
    """Load a LoRA adapter from the specified path.

    Args:
        lora_path: Path to the LoRA adapter directory
        port: Port number for the API endpoint

    Returns:
        Name of the loaded LoRA adapter

    Raises:
        requests.RequestException: If the API call fails
    """
    lora_name = os.path.basename(lora_path.rstrip('/\\'))
    if not lora_name:  # Handle edge case of empty basename
        lora_name = os.path.basename(os.path.dirname(lora_path))

    response = requests.post(
        f'http://localhost:{port}/v1/load_lora_adapter',
        headers={'accept': 'application/json', 'Content-Type': 'application/json'},
        json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)}
    )
    response.raise_for_status()
    return lora_name


def _unload_lora_adapter(lora_path: str, port: int) -> None:
    """Unload the current LoRA adapter.

    Args:
        lora_path: Path to the LoRA adapter directory
        port: Port number for the API endpoint
    """
    try:
        lora_name = os.path.basename(lora_path.rstrip('/\\'))
        if not lora_name:  # Handle edge case of empty basename
            lora_name = os.path.basename(os.path.dirname(lora_path))

        response = requests.post(
            f'http://localhost:{port}/v1/unload_lora_adapter',
            headers={'accept': 'application/json', 'Content-Type': 'application/json'},
            json={"lora_name": lora_name, "lora_int_id": 0}
        )
        response.raise_for_status()
    except requests.RequestException as e:
        logger.warning(f"Error unloading LoRA adapter: {str(e)[:100]}")


class LLMTask:
    """
    Language model task with structured input/output and optional system instruction.

    Supports str or Pydantic models for both input and output. Automatically handles
    message formatting and response parsing.

    Two main APIs:
    - text(): Returns raw text responses as list of dicts (alias for text_completion)
    - parse(): Returns parsed Pydantic model responses as list of dicts (alias for pydantic_parse)
    - __call__(): Backward compatibility method that delegates based on output_model

    Example:
        ```python
        from pydantic import BaseModel
        from llm_utils.lm.llm_task import LLMTask

        class EmailOutput(BaseModel):
            content: str
            estimated_read_time: int

        # Set up task with Pydantic output model
        task = LLMTask(
            instruction="Generate professional email content.",
            output_model=EmailOutput,
            client=OpenAI(),
            temperature=0.7
        )

        # Use parse() for structured output
        results = task.parse("Write a meeting follow-up email")
        result = results[0]
        print(result["parsed"].content, result["parsed"].estimated_read_time)

        # Use text() for plain text output
        results = task.text("Write a meeting follow-up email")
        text_result = results[0]
        print(text_result["parsed"])

        # Multiple responses
        results = task.parse("Write a meeting follow-up email", n=3)
        for result in results:
            print(f"Content: {result['parsed'].content}")

        # Override parameters at runtime
        results = task.text(
            "Write a meeting follow-up email",
            temperature=0.9,
            n=2,
            max_tokens=500
        )
        for result in results:
            print(result["parsed"])

        # Backward compatibility (uses output_model to choose method)
        results = task("Write a meeting follow-up email")  # Calls parse()
        result = results[0]
        print(result["parsed"].content)
        ```
    """

    def __init__(
        self,
        instruction: Optional[str] = None,
        input_model: Union[Type[BaseModel], type[str]] = str,
        output_model: Type[BaseModel] | Type[str] = None,
        client: Union[OpenAI, int, str, None] = None,
        cache=True,
        is_reasoning_model: bool = False,
        force_lora_unload: bool = False,
        lora_path: Optional[str] = None,
        **model_kwargs,
    ):
        """
        Initialize the LLMTask.

        Args:
            instruction: Optional system instruction for the task
            input_model: Input type (str or BaseModel subclass)
            output_model: Output BaseModel type
            client: OpenAI client, port number, or base_url string
            cache: Whether to use cached responses (default True)
            is_reasoning_model: Whether the model is a reasoning model (o1-preview, o1-mini, etc.)
                that outputs reasoning_content separately from content (default False)
            force_lora_unload: If True, forces unloading of any existing LoRA adapter before loading
                a new one when lora_path is provided (default False)
            lora_path: Optional path to LoRA adapter directory. If provided, will load the LoRA
                and use it as the model. Takes precedence over model parameter.
            **model_kwargs: Additional model parameters including:
                - temperature: Controls randomness (0.0 to 2.0)
                - n: Number of responses to generate (when n > 1, returns list)
                - max_tokens: Maximum tokens in response
                - model: Model name (auto-detected if not provided)
        """
        self.instruction = instruction
        self.input_model = input_model
        self.output_model = output_model
        self.model_kwargs = model_kwargs
        self.is_reasoning_model = is_reasoning_model
        self.force_lora_unload = force_lora_unload
        self.lora_path = lora_path
        self.last_ai_response = None  # Store raw response from client

        # if cache:
        #     print("Caching is enabled will use llm_utils.MOpenAI")

        #     self.client = MOpenAI(base_url=base_url, api_key=api_key)
        # else:
        #     self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.client = get_base_client(client, cache=cache)
        # check connection of client
        try:
            self.client.models.list()
        except Exception as e:
            logger.error(f"Failed to connect to OpenAI client: {str(e)}, base_url={self.client.base_url}")
            raise e

        if not self.model_kwargs.get("model", ""):
            self.model_kwargs["model"] = self.client.models.list().data[0].id

        # Handle LoRA loading if lora_path is provided
        if self.lora_path:
            self._load_lora_adapter()

        print(self.model_kwargs)

    def _load_lora_adapter(self) -> None:
        """
        Load LoRA adapter from the specified lora_path.

        This method:
        1. Validates that lora_path is a valid LoRA directory
        2. Checks if LoRA is already loaded (unless force_lora_unload is True)
        3. Loads the LoRA adapter and updates the model name
        """
        if not self.lora_path:
            return

        if not _is_lora_path(self.lora_path):
            raise ValueError(
                f"Invalid LoRA path '{self.lora_path}': "
                "Directory must contain 'adapter_config.json'"
            )

        logger.info(f"Loading LoRA adapter from: {self.lora_path}")

        # Get the expected LoRA name (basename of the path)
        lora_name = os.path.basename(self.lora_path.rstrip('/\\'))
        if not lora_name:  # Handle edge case of empty basename
            lora_name = os.path.basename(os.path.dirname(self.lora_path))

        # Get list of available models to check if LoRA is already loaded
        try:
            available_models = [m.id for m in self.client.models.list().data]
        except Exception as e:
            logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
            available_models = []

        # Check if LoRA is already loaded
        if lora_name in available_models and not self.force_lora_unload:
            logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
            self.model_kwargs["model"] = lora_name
            return

        # Force unload if requested
        if self.force_lora_unload and lora_name in available_models:
            logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
            port = _get_port_from_client(self.client)
            if port is not None:
                try:
                    LLMTask.unload_lora(port, lora_name)
                    logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
                except Exception as e:
                    logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")

        # Get port from client for API calls
        port = _get_port_from_client(self.client)
        if port is None:
            raise ValueError(
                f"Cannot load LoRA adapter '{self.lora_path}': "
                "Unable to determine port from client base_url. "
                "LoRA loading requires a client initialized with port number."
            )

        try:
            # Load the LoRA adapter
            loaded_lora_name = _load_lora_adapter(self.lora_path, port)
            logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")

            # Update model name to the loaded LoRA name
            self.model_kwargs["model"] = loaded_lora_name

        except requests.RequestException as e:
            # Check if the error is due to LoRA already being loaded
            error_msg = str(e)
            if "400" in error_msg or "Bad Request" in error_msg:
                logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
                # Refresh the model list to check if it's now available
                try:
                    updated_models = [m.id for m in self.client.models.list().data]
                    if lora_name in updated_models:
                        logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
                        self.model_kwargs["model"] = lora_name
                        return
                except Exception:
                    pass  # Fall through to original error

            raise ValueError(
                f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}"
            )

    def unload_lora_adapter(self, lora_path: str) -> None:
        """
        Unload a LoRA adapter.

        Args:
            lora_path: Path to the LoRA adapter directory to unload

        Raises:
            ValueError: If unable to determine port from client
        """
        port = _get_port_from_client(self.client)
        if port is None:
            raise ValueError(
                "Cannot unload LoRA adapter: "
                "Unable to determine port from client base_url. "
                "LoRA operations require a client initialized with port number."
            )

        _unload_lora_adapter(lora_path, port)
        lora_name = os.path.basename(lora_path.rstrip('/\\'))
        logger.info(f"Unloaded LoRA adapter: {lora_name}")

    @staticmethod
    def unload_lora(port: int, lora_name: str) -> None:
        """Static method to unload a LoRA adapter by name.

        Args:
            port: Port number for the API endpoint
            lora_name: Name of the LoRA adapter to unload

        Raises:
            requests.RequestException: If the API call fails
        """
        try:
            response = requests.post(
                f'http://localhost:{port}/v1/unload_lora_adapter',
                headers={'accept': 'application/json', 'Content-Type': 'application/json'},
                json={"lora_name": lora_name, "lora_int_id": 0}
            )
            response.raise_for_status()
            logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
        except requests.RequestException as e:
            logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
            raise

    def _prepare_input(self, input_data: Union[str, BaseModel, List[Dict]]) -> Messages:
        """Convert input to messages format."""
        if isinstance(input_data, list):
            assert isinstance(input_data[0], dict) and "role" in input_data[0], (
                "If input_data is a list, it must be a list of messages with 'role' and 'content' keys."
            )
            return cast(Messages, input_data)
        else:
            # Convert input to string format
            if isinstance(input_data, str):
                user_content = input_data
            elif hasattr(input_data, "model_dump_json"):
                user_content = input_data.model_dump_json()
            elif isinstance(input_data, dict):
                user_content = str(input_data)
            else:
                user_content = str(input_data)

            # Build messages
            messages = (
                [
                    {"role": "system", "content": self.instruction},
                ]
                if self.instruction is not None
                else []
            )

            messages.append({"role": "user", "content": user_content})
            return cast(Messages, messages)

    def text_completion(
        self, input_data: Union[str, BaseModel, list[Dict]], **runtime_kwargs
    ) -> List[Dict[str, Any]]:
        """
        Execute the LLM task and return text responses.

        Args:
            input_data: Input as string or BaseModel
            **runtime_kwargs: Runtime model parameters that override defaults
                - temperature: Controls randomness (0.0 to 2.0)
                - n: Number of responses to generate
                - max_tokens: Maximum tokens in response
                - model: Model name override
                - Any other model parameters supported by OpenAI API

        Returns:
            List of dicts [{'parsed': text_response, 'messages': messages}, ...]
            When n=1: List contains one dict
            When n>1: List contains multiple dicts
        """
        # Prepare messages
        messages = self._prepare_input(input_data)

        # Merge runtime kwargs with default model kwargs (runtime takes precedence)
        effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
        model_name = effective_kwargs.get("model", self.model_kwargs["model"])

        # Extract model name from kwargs for API call
        api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}

        try:
            completion = self.client.chat.completions.create(
                model=model_name, messages=messages, **api_kwargs
            )
            # Store raw response from client
            self.last_ai_response = completion
        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
            logger.error(error_msg)
            raise
        except Exception as e:
            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
            if is_length_error:
                raise ValueError(
                    f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
                )
            # Re-raise all other exceptions
            raise
        # print(completion)

        results: List[Dict[str, Any]] = []
        for choice in completion.choices:
            choice_messages = cast(
                Messages,
                messages + [{"role": "assistant", "content": choice.message.content}],
            )
            result_dict = {"parsed": choice.message.content, "messages": choice_messages}

            # Add reasoning content if this is a reasoning model
            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
                result_dict["reasoning_content"] = choice.message.reasoning_content

            results.append(result_dict)
        return results

    def pydantic_parse(
        self,
        input_data: Union[str, BaseModel, list[Dict]],
        response_model: Optional[Type[BaseModel]] | Type[str] = None,
        **runtime_kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Execute the LLM task and return parsed Pydantic model responses.

        Args:
            input_data: Input as string or BaseModel
            response_model: Pydantic model for response parsing (overrides default)
            **runtime_kwargs: Runtime model parameters that override defaults
                - temperature: Controls randomness (0.0 to 2.0)
                - n: Number of responses to generate
                - max_tokens: Maximum tokens in response
                - model: Model name override
                - Any other model parameters supported by OpenAI API

        Returns:
            List of dicts [{'parsed': parsed_model, 'messages': messages}, ...]
            When n=1: List contains one dict
            When n>1: List contains multiple dicts

        Note:
            This method ensures consistent Pydantic model output for both fresh and cached responses.
            When responses are cached and loaded back, the parsed content is re-validated to maintain
            type consistency between first-time and subsequent calls.
        """
        # Prepare messages
        messages = self._prepare_input(input_data)

        # Merge runtime kwargs with default model kwargs (runtime takes precedence)
        effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
        model_name = effective_kwargs.get("model", self.model_kwargs["model"])

        # Extract model name from kwargs for API call
        api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}

        pydantic_model_to_use_opt = response_model or self.output_model
        if pydantic_model_to_use_opt is None:
            raise ValueError(
                "No response model specified. Either set output_model in constructor or pass response_model parameter."
            )
        pydantic_model_to_use: Type[BaseModel] = cast(
            Type[BaseModel], pydantic_model_to_use_opt
        )
        try:
            completion = self.client.chat.completions.parse(
                model=model_name,
                messages=messages,
                response_format=pydantic_model_to_use,
                **api_kwargs,
            )
            # Store raw response from client
            self.last_ai_response = completion
        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
            logger.error(error_msg)
            raise
        except Exception as e:
            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
            if is_length_error:
                raise ValueError(
                    f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
                )
            # Re-raise all other exceptions
            raise

        results: List[Dict[str, Any]] = []
        for choice in completion.choices:  # type: ignore[attr-defined]
            choice_messages = cast(
                Messages,
                messages + [{"role": "assistant", "content": choice.message.content}],
            )

            # Ensure consistent Pydantic model output for both fresh and cached responses
            parsed_content = choice.message.parsed  # type: ignore[attr-defined]
            if isinstance(parsed_content, dict):
                # Cached response: validate dict back to Pydantic model
                parsed_content = pydantic_model_to_use.model_validate(parsed_content)
            elif not isinstance(parsed_content, pydantic_model_to_use):
                # Fallback: ensure it's the correct type
                parsed_content = pydantic_model_to_use.model_validate(parsed_content)

            result_dict = {"parsed": parsed_content, "messages": choice_messages}

            # Add reasoning content if this is a reasoning model
            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
                result_dict["reasoning_content"] = choice.message.reasoning_content

            results.append(result_dict)
        return results

    def __call__(
        self,
        input_data: Union[str, BaseModel, list[Dict]],
        response_model: Optional[Type[BaseModel] | Type[str]] = None,
        two_step_parse_pydantic=False,
        **runtime_kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Execute the LLM task. Delegates to text() or parse() based on output_model.

        This method maintains backward compatibility by automatically choosing
        between text and parse methods based on the output_model configuration.

        Args:
            input_data: Input as string or BaseModel
            response_model: Optional override for output model
            **runtime_kwargs: Runtime model parameters

        Returns:
            List of dicts [{'parsed': response, 'messages': messages}, ...]
        """
        pydantic_model_to_use = response_model or self.output_model

        if pydantic_model_to_use is str or pydantic_model_to_use is None:
            return self.text_completion(input_data, **runtime_kwargs)
        elif two_step_parse_pydantic:
            # step 1: get text completions
            results = self.text_completion(input_data, **runtime_kwargs)
            parsed_results = []
            for result in results:
                response_text = result["parsed"]
                messages = result["messages"]
                # check if the pydantic_model_to_use is validated
                if "</think>" in response_text:
                    response_text = response_text.split("</think>")[1]
                try:
                    parsed = pydantic_model_to_use.model_validate_json(response_text)
                except Exception as e:
                    # logger.info(
                    #     f"Warning: Failed to parsed JSON, Falling back to LLM parsing. Error: {str(e)[:100]}..."
                    # )
                    # use model to parse the response_text
                    _parsed_messages = [
                        {
                            "role": "system",
                            "content": "You are a helpful assistant that extracts JSON from text.",
                        },
                        {
                            "role": "user",
                            "content": f"Extract JSON from the following text:\n{response_text}",
                        },
                    ]
                    parsed_result = self.pydantic_parse(
                        _parsed_messages,
                        response_model=pydantic_model_to_use,
                        **runtime_kwargs,
                    )[0]
                    parsed = parsed_result["parsed"]
                    # ---
                parsed_results.append({"parsed": parsed, "messages": messages})
            return parsed_results

        else:
            return self.pydantic_parse(
                input_data, response_model=response_model, **runtime_kwargs
            )

    # Backward compatibility aliases
    def text(self, *args, **kwargs) -> List[Dict[str, Any]]:
        """Alias for text_completion() for backward compatibility."""
        return self.text_completion(*args, **kwargs)

    def parse(self, *args, **kwargs) -> List[Dict[str, Any]]:
        """Alias for pydantic_parse() for backward compatibility."""
        return self.pydantic_parse(*args, **kwargs)

    @classmethod
    def from_prompt_builder(
        builder: BasePromptBuilder,
        client: Union[OpenAI, int, str, None] = None,
        cache=True,
        is_reasoning_model: bool = False,
        lora_path: Optional[str] = None,
        **model_kwargs,
    ) -> "LLMTask":
        """
        Create an LLMTask instance from a BasePromptBuilder instance.

        This method extracts the instruction, input model, and output model
        from the provided builder and initializes an LLMTask accordingly.
        """
        instruction = builder.get_instruction()
        input_model = builder.get_input_model()
        output_model = builder.get_output_model()

        # Extract data from the builder to initialize LLMTask
        return LLMTask(
            instruction=instruction,
            input_model=input_model,
            output_model=output_model,
            client=client,
            cache=cache,
            is_reasoning_model=is_reasoning_model,
            lora_path=lora_path,
            **model_kwargs,
        )

    @staticmethod
    def list_models(client: Union[OpenAI, int, str, None] = None) -> List[str]:
        """
        List available models from the OpenAI client.

        Args:
            client: OpenAI client, port number, or base_url string

        Returns:
            List of available model names.
        """
        client = get_base_client(client, cache=False)
        models = client.models.list().data
        return [m.id for m in models]
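For context on how the new module is meant to be used, here is a minimal sketch based on the class docstring and `get_base_client()` above. The port number (8000) and the LoRA directory path are placeholder assumptions for illustration, not values shipped with the package.

```python
from pydantic import BaseModel
from llm_utils.lm.llm_task import LLMTask

class Sentiment(BaseModel):
    label: str
    confidence: float

# Passing an int as `client` builds a client against http://localhost:8000/v1
# via get_base_client(). The hypothetical lora_path must contain
# adapter_config.json; it is loaded through the server's
# /v1/load_lora_adapter endpoint and then used as the model name.
task = LLMTask(
    instruction="Classify the sentiment of the given text.",
    output_model=Sentiment,
    client=8000,
    lora_path="/path/to/my-lora-adapter",  # hypothetical; omit to use the served base model
    temperature=0.0,
)

results = task.parse("The new release fixed every bug I reported.")
print(results[0]["parsed"].label)
```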