speedy-utils 1.1.26__py3-none-any.whl → 1.1.28__py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- llm_utils/__init__.py +16 -4
- llm_utils/chat_format/__init__.py +10 -10
- llm_utils/chat_format/display.py +33 -21
- llm_utils/chat_format/transform.py +17 -19
- llm_utils/chat_format/utils.py +6 -4
- llm_utils/group_messages.py +17 -14
- llm_utils/lm/__init__.py +6 -5
- llm_utils/lm/async_lm/__init__.py +1 -0
- llm_utils/lm/async_lm/_utils.py +10 -9
- llm_utils/lm/async_lm/async_llm_task.py +141 -137
- llm_utils/lm/async_lm/async_lm.py +48 -42
- llm_utils/lm/async_lm/async_lm_base.py +59 -60
- llm_utils/lm/async_lm/lm_specific.py +4 -3
- llm_utils/lm/base_prompt_builder.py +93 -70
- llm_utils/lm/llm.py +126 -108
- llm_utils/lm/llm_signature.py +4 -2
- llm_utils/lm/lm_base.py +72 -73
- llm_utils/lm/mixins.py +102 -62
- llm_utils/lm/openai_memoize.py +124 -87
- llm_utils/lm/signature.py +105 -92
- llm_utils/lm/utils.py +42 -23
- llm_utils/scripts/vllm_load_balancer.py +23 -30
- llm_utils/scripts/vllm_serve.py +8 -7
- llm_utils/vector_cache/__init__.py +9 -3
- llm_utils/vector_cache/cli.py +1 -1
- llm_utils/vector_cache/core.py +59 -63
- llm_utils/vector_cache/types.py +7 -5
- llm_utils/vector_cache/utils.py +12 -8
- speedy_utils/__imports.py +244 -0
- speedy_utils/__init__.py +90 -194
- speedy_utils/all.py +125 -227
- speedy_utils/common/clock.py +37 -42
- speedy_utils/common/function_decorator.py +6 -12
- speedy_utils/common/logger.py +43 -52
- speedy_utils/common/notebook_utils.py +13 -21
- speedy_utils/common/patcher.py +21 -17
- speedy_utils/common/report_manager.py +42 -44
- speedy_utils/common/utils_cache.py +152 -169
- speedy_utils/common/utils_io.py +137 -103
- speedy_utils/common/utils_misc.py +15 -21
- speedy_utils/common/utils_print.py +22 -28
- speedy_utils/multi_worker/process.py +66 -79
- speedy_utils/multi_worker/thread.py +78 -155
- speedy_utils/scripts/mpython.py +38 -36
- speedy_utils/scripts/openapi_client_codegen.py +10 -10
- {speedy_utils-1.1.26.dist-info → speedy_utils-1.1.28.dist-info}/METADATA +1 -1
- speedy_utils-1.1.28.dist-info/RECORD +57 -0
- vision_utils/README.md +202 -0
- vision_utils/__init__.py +5 -0
- vision_utils/io_utils.py +470 -0
- vision_utils/plot.py +345 -0
- speedy_utils-1.1.26.dist-info/RECORD +0 -52
- {speedy_utils-1.1.26.dist-info → speedy_utils-1.1.28.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.26.dist-info → speedy_utils-1.1.28.dist-info}/entry_points.txt +0 -0
llm_utils/lm/llm.py
CHANGED
@@ -10,35 +10,36 @@ from typing import Any, Dict, List, Optional, Type, Union, cast
 
 import requests
 from loguru import logger
-from openai import
+from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
 from speedy_utils.common.utils_io import jdumps
 
+from .base_prompt_builder import BasePromptBuilder
+from .mixins import (
+    ModelUtilsMixin,
+    TemperatureRangeMixin,
+    TwoStepPydanticMixin,
+    VLLMMixin,
+)
 from .utils import (
     _extract_port_from_vllm_cmd,
-    _start_vllm_server,
-    _kill_vllm_on_port,
-    _is_server_running,
-    get_base_client,
-    _is_lora_path,
     _get_port_from_client,
+    _is_lora_path,
+    _is_server_running,
+    _kill_vllm_on_port,
     _load_lora_adapter,
+    _start_vllm_server,
     _unload_lora_adapter,
+    get_base_client,
     kill_all_vllm_processes,
     stop_vllm_process,
 )
-
-from .mixins import (
-    TemperatureRangeMixin,
-    TwoStepPydanticMixin,
-    VLLMMixin,
-    ModelUtilsMixin,
-)
+
 
 # Type aliases for better readability
-Messages =
+Messages = list[ChatCompletionMessageParam]
 
 
 class LLM(
@@ -51,15 +52,15 @@ class LLM(
 
     def __init__(
         self,
-        instruction:
-        input_model:
-        output_model:
-        client:
+        instruction: str | None = None,
+        input_model: type[BaseModel] | type[str] = str,
+        output_model: type[BaseModel] | type[str] = None,
+        client: OpenAI | int | str | None = None,
         cache=True,
         is_reasoning_model: bool = False,
         force_lora_unload: bool = False,
-        lora_path:
-        vllm_cmd:
+        lora_path: str | None = None,
+        vllm_cmd: str | None = None,
         vllm_timeout: int = 1200,
         vllm_reuse: bool = True,
         **model_kwargs,
@@ -75,7 +76,7 @@ class LLM(
         self.vllm_cmd = vllm_cmd
         self.vllm_timeout = vllm_timeout
         self.vllm_reuse = vllm_reuse
-        self.vllm_process:
+        self.vllm_process: subprocess.Popen | None = None
         self.last_ai_response = None  # Store raw response from client
         self.cache = cache
 
@@ -88,16 +89,20 @@ class LLM(
         if client is None:
             client = port
 
-        self.client = get_base_client(
+        self.client = get_base_client(
+            client, cache=cache, vllm_cmd=self.vllm_cmd, vllm_process=self.vllm_process
+        )
        # check connection of client
         try:
             self.client.models.list()
         except Exception as e:
-            logger.error(
+            logger.error(
+                f'Failed to connect to OpenAI client: {str(e)}, base_url={self.client.base_url}'
+            )
             raise e
 
-        if not self.model_kwargs.get(
-            self.model_kwargs[
+        if not self.model_kwargs.get('model', ''):
+            self.model_kwargs['model'] = self.client.models.list().data[0].id
 
         # Handle LoRA loading if lora_path is provided
         if self.lora_path:
@@ -111,102 +116,112 @@ class LLM(
         """Context manager exit with cleanup."""
         self.cleanup_vllm_server()
 
-    def _prepare_input(self, input_data:
+    def _prepare_input(self, input_data: str | BaseModel | list[dict]) -> Messages:
         """Convert input to messages format."""
         if isinstance(input_data, list):
-            assert isinstance(input_data[0], dict) and
+            assert isinstance(input_data[0], dict) and 'role' in input_data[0], (
                 "If input_data is a list, it must be a list of messages with 'role' and 'content' keys."
             )
             return cast(Messages, input_data)
+        # Convert input to string format
+        if isinstance(input_data, str):
+            user_content = input_data
+        elif hasattr(input_data, 'model_dump_json'):
+            user_content = input_data.model_dump_json()
+        elif isinstance(input_data, dict):
+            user_content = jdumps(input_data)
         else:
-
-
-
-
-
-
-
-
-
-
-        # Build messages
-        messages = (
-            [
-                {"role": "system", "content": self.instruction},
-            ]
-            if self.instruction is not None
-            else []
-        )
+            user_content = str(input_data)
+
+        # Build messages
+        messages = (
+            [
+                {'role': 'system', 'content': self.instruction},
+            ]
+            if self.instruction is not None
+            else []
+        )
 
-
-
+        messages.append({'role': 'user', 'content': user_content})
+        return cast(Messages, messages)
 
-    def text_completion(
+    def text_completion(
+        self, input_data: str | BaseModel | list[dict], **runtime_kwargs
+    ) -> list[dict[str, Any]]:
         """Execute LLM task and return text responses."""
         # Prepare messages
         messages = self._prepare_input(input_data)
 
         # Merge runtime kwargs with default model kwargs (runtime takes precedence)
         effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
-        model_name = effective_kwargs.get(
+        model_name = effective_kwargs.get('model', self.model_kwargs['model'])
 
         # Extract model name from kwargs for API call
-        api_kwargs = {k: v for k, v in effective_kwargs.items() if k !=
+        api_kwargs = {k: v for k, v in effective_kwargs.items() if k != 'model'}
 
         try:
-            completion = self.client.chat.completions.create(
+            completion = self.client.chat.completions.create(
+                model=model_name, messages=messages, **api_kwargs
+            )
             # Store raw response from client
             self.last_ai_response = completion
         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
-            error_msg = f
+            error_msg = f'OpenAI API error ({type(exc).__name__}): {exc}'
             logger.error(error_msg)
             raise
         except Exception as e:
-            is_length_error =
+            is_length_error = 'Length' in str(e) or 'maximum context length' in str(e)
             if is_length_error:
-                raise ValueError(
+                raise ValueError(
+                    f'Input too long for model {model_name}. Error: {str(e)[:100]}...'
+                ) from e
             # Re-raise all other exceptions
             raise
         # print(completion)
 
-        results:
+        results: list[dict[str, Any]] = []
         for choice in completion.choices:
             choice_messages = cast(
                 Messages,
-                messages + [{
+                messages + [{'role': 'assistant', 'content': choice.message.content}],
             )
-            result_dict = {
+            result_dict = {
+                'parsed': choice.message.content,
+                'messages': choice_messages,
+            }
 
             # Add reasoning content if this is a reasoning model
-            if self.is_reasoning_model and hasattr(choice.message,
-                result_dict[
+            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
+                result_dict['reasoning_content'] = choice.message.reasoning_content
 
             results.append(result_dict)
         return results
 
     def pydantic_parse(
         self,
-        input_data:
-        response_model:
+        input_data: str | BaseModel | list[dict],
+        response_model: type[BaseModel] | None | type[str] = None,
         **runtime_kwargs,
-    ) ->
+    ) -> list[dict[str, Any]]:
         """Execute LLM task and return parsed Pydantic model responses."""
         # Prepare messages
         messages = self._prepare_input(input_data)
 
         # Merge runtime kwargs with default model kwargs (runtime takes precedence)
         effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
-        model_name = effective_kwargs.get(
+        model_name = effective_kwargs.get('model', self.model_kwargs['model'])
 
         # Extract model name from kwargs for API call
-        api_kwargs = {k: v for k, v in effective_kwargs.items() if k !=
+        api_kwargs = {k: v for k, v in effective_kwargs.items() if k != 'model'}
 
         pydantic_model_to_use_opt = response_model or self.output_model
         if pydantic_model_to_use_opt is None:
             raise ValueError(
-
+                'No response model specified. Either set output_model in constructor or pass response_model parameter.'
             )
-        pydantic_model_to_use:
+        pydantic_model_to_use: type[BaseModel] = cast(
+            type[BaseModel], pydantic_model_to_use_opt
+        )
         try:
             completion = self.client.chat.completions.parse(
                 model=model_name,
@@ -217,21 +232,22 @@ class LLM(
             # Store raw response from client
             self.last_ai_response = completion
         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
-            error_msg = f
+            error_msg = f'OpenAI API error ({type(exc).__name__}): {exc}'
             logger.error(error_msg)
             raise
         except Exception as e:
-            is_length_error =
+            is_length_error = 'Length' in str(e) or 'maximum context length' in str(e)
             if is_length_error:
-                raise ValueError(
-
+                raise ValueError(
+                    f'Input too long for model {model_name}. Error: {str(e)[:100]}...'
+                ) from e
             raise
 
-        results:
+        results: list[dict[str, Any]] = []
        for choice in completion.choices:  # type: ignore[attr-defined]
             choice_messages = cast(
                 Messages,
-                messages + [{
+                messages + [{'role': 'assistant', 'content': choice.message.content}],
             )
 
             # Ensure consistent Pydantic model output for both fresh and cached responses
@@ -243,25 +259,25 @@ class LLM(
                 # Fallback: ensure it's the correct type
                 parsed_content = pydantic_model_to_use.model_validate(parsed_content)
 
-            result_dict = {
+            result_dict = {'parsed': parsed_content, 'messages': choice_messages}
 
             # Add reasoning content if this is a reasoning model
-            if self.is_reasoning_model and hasattr(choice.message,
-                result_dict[
+            if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
+                result_dict['reasoning_content'] = choice.message.reasoning_content
 
             results.append(result_dict)
         return results
 
     def __call__(
         self,
-        input_data:
-        response_model:
+        input_data: str | BaseModel | list[dict],
+        response_model: type[BaseModel] | type[str] | None = None,
         two_step_parse_pydantic: bool = False,
-        temperature_ranges:
+        temperature_ranges: tuple[float, float] | None = None,
         n: int = 1,
         cache=None,
         **openai_client_kwargs,
-    ) ->
+    ) -> list[dict[str, Any]]:
         """
         Execute LLM task.
 
@@ -277,14 +293,16 @@ class LLM(
             List of response dictionaries
         """
         if cache is not None:
-            if hasattr(self.client,
+            if hasattr(self.client, 'set_cache'):
                 self.client.set_cache(cache)
             else:
-                logger.warning(
+                logger.warning('Client does not support caching.')
         # Handle temperature range sampling
         if temperature_ranges is not None:
             if n < 2:
-                raise ValueError(
+                raise ValueError(
+                    f'n must be >= 2 when using temperature_ranges, got {n}'
+                )
             return self.temperature_range_sampling(
                 input_data,
                 temperature_ranges=temperature_ranges,
@@ -292,7 +310,7 @@ class LLM(
                 response_model=response_model,
                 **openai_client_kwargs,
             )
-        openai_client_kwargs[
+        openai_client_kwargs['n'] = n
 
         # Handle two-step Pydantic parsing
         pydantic_model = response_model or self.output_model
@@ -311,33 +329,34 @@ class LLM(
         )
 
         # Track conversation history
-        _last_conv = choices[0][
-        if not hasattr(self,
+        _last_conv = choices[0]['messages'] if choices else []
+        if not hasattr(self, '_last_conversations'):
             self._last_conversations = []
         else:
             self._last_conversations = self._last_conversations[-100:]
         self._last_conversations.append(_last_conv)
         return choices
 
-    def inspect_history(
+    def inspect_history(
+        self, idx: int = -1, k_last_messages: int = 2
+    ) -> list[dict[str, Any]]:
         """Inspect the message history of a specific response choice."""
-        if hasattr(self,
+        if hasattr(self, '_last_conversations'):
             from llm_utils import show_chat_v2
 
             conv = self._last_conversations[idx]
             if k_last_messages > 0:
                 conv = conv[-k_last_messages:]
             return show_chat_v2(conv)
-
-        raise ValueError("No message history available. Make a call first.")
+        raise ValueError('No message history available. Make a call first.')
 
     def __inner_call__(
         self,
-        input_data:
-        response_model:
+        input_data: str | BaseModel | list[dict],
+        response_model: type[BaseModel] | type[str] | None = None,
         two_step_parse_pydantic: bool = False,
         **runtime_kwargs,
-    ) ->
+    ) -> list[dict[str, Any]]:
         """
         Internal call handler. Delegates to text() or parse() based on model.
 
@@ -348,34 +367,33 @@ class LLM(
 
         if pydantic_model_to_use is str or pydantic_model_to_use is None:
             return self.text_completion(input_data, **runtime_kwargs)
-
-
-
-
-
-        )
+        return self.pydantic_parse(
+            input_data,
+            response_model=response_model,
+            **runtime_kwargs,
+        )
 
     # Backward compatibility aliases
-    def text(self, *args, **kwargs) ->
+    def text(self, *args, **kwargs) -> list[dict[str, Any]]:
         """Alias for text_completion() for backward compatibility."""
         return self.text_completion(*args, **kwargs)
 
-    def parse(self, *args, **kwargs) ->
+    def parse(self, *args, **kwargs) -> list[dict[str, Any]]:
         """Alias for pydantic_parse() for backward compatibility."""
         return self.pydantic_parse(*args, **kwargs)
 
     @classmethod
     def from_prompt_builder(
-
-        client:
+        cls: BasePromptBuilder,
+        client: OpenAI | int | str | None = None,
         cache=True,
         is_reasoning_model: bool = False,
-        lora_path:
-        vllm_cmd:
+        lora_path: str | None = None,
+        vllm_cmd: str | None = None,
         vllm_timeout: int = 120,
         vllm_reuse: bool = True,
         **model_kwargs,
-    ) ->
+    ) -> 'LLM':
         """
         Create an LLMTask instance from a BasePromptBuilder instance.
 
@@ -393,9 +411,9 @@ class LLM(
             vllm_reuse: If True (default), reuse existing server on target port
             **model_kwargs: Additional model parameters
         """
-        instruction =
-        input_model =
-        output_model =
+        instruction = cls.get_instruction()
+        input_model = cls.get_input_model()
+        output_model = cls.get_output_model()
 
         # Extract data from the builder to initialize LLMTask
         return LLM(
llm_utils/lm/llm_signature.py
CHANGED
@@ -7,7 +7,9 @@ prompts, variable substitution, and export capabilities for fine-tuning.
 
 import json
 from typing import Any, Dict, List, Optional, Type, Union
+
 from pydantic import BaseModel
+
 from ..chat_format import get_conversation_one_turn
 from .llm import LLM
 from .signature import Signature
@@ -16,7 +18,7 @@ from .signature import Signature
 class LLMSignature(LLM):
     """Base class for LLM judges with template support and SFT export."""
 
-    def __init__(self, signature:
+    def __init__(self, signature: type[Signature], **kwargs):
         """
         Initialize LLMJudgeBase.
 
@@ -26,7 +28,7 @@ class LLMSignature(LLM):
             **kwargs: Additional arguments passed to LLMTask
         """
         self.signature = signature
-        self.sft_data:
+        self.sft_data: list[dict[str, Any]] = []  # Store SFT training examples
 
         # Set instruction from signature if available
         kwargs.setdefault("instruction", signature.get_instruction())