speedy-utils 1.1.27__py3-none-any.whl → 1.1.29__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
- llm_utils/__init__.py +16 -4
- llm_utils/chat_format/__init__.py +10 -10
- llm_utils/chat_format/display.py +33 -21
- llm_utils/chat_format/transform.py +17 -19
- llm_utils/chat_format/utils.py +6 -4
- llm_utils/group_messages.py +17 -14
- llm_utils/lm/__init__.py +6 -5
- llm_utils/lm/async_lm/__init__.py +1 -0
- llm_utils/lm/async_lm/_utils.py +10 -9
- llm_utils/lm/async_lm/async_llm_task.py +141 -137
- llm_utils/lm/async_lm/async_lm.py +48 -42
- llm_utils/lm/async_lm/async_lm_base.py +59 -60
- llm_utils/lm/async_lm/lm_specific.py +4 -3
- llm_utils/lm/base_prompt_builder.py +93 -70
- llm_utils/lm/llm.py +126 -108
- llm_utils/lm/llm_signature.py +4 -2
- llm_utils/lm/lm_base.py +72 -73
- llm_utils/lm/mixins.py +102 -62
- llm_utils/lm/openai_memoize.py +124 -87
- llm_utils/lm/signature.py +105 -92
- llm_utils/lm/utils.py +42 -23
- llm_utils/scripts/vllm_load_balancer.py +23 -30
- llm_utils/scripts/vllm_serve.py +8 -7
- llm_utils/vector_cache/__init__.py +9 -3
- llm_utils/vector_cache/cli.py +1 -1
- llm_utils/vector_cache/core.py +59 -63
- llm_utils/vector_cache/types.py +7 -5
- llm_utils/vector_cache/utils.py +12 -8
- speedy_utils/__imports.py +244 -0
- speedy_utils/__init__.py +90 -194
- speedy_utils/all.py +125 -227
- speedy_utils/common/clock.py +37 -42
- speedy_utils/common/function_decorator.py +6 -12
- speedy_utils/common/logger.py +43 -52
- speedy_utils/common/notebook_utils.py +13 -21
- speedy_utils/common/patcher.py +21 -17
- speedy_utils/common/report_manager.py +42 -44
- speedy_utils/common/utils_cache.py +152 -169
- speedy_utils/common/utils_io.py +137 -103
- speedy_utils/common/utils_misc.py +15 -21
- speedy_utils/common/utils_print.py +22 -28
- speedy_utils/multi_worker/process.py +66 -79
- speedy_utils/multi_worker/thread.py +78 -155
- speedy_utils/scripts/mpython.py +38 -36
- speedy_utils/scripts/openapi_client_codegen.py +10 -10
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/METADATA +1 -1
- speedy_utils-1.1.29.dist-info/RECORD +57 -0
- vision_utils/README.md +202 -0
- vision_utils/__init__.py +4 -0
- vision_utils/io_utils.py +735 -0
- vision_utils/plot.py +345 -0
- speedy_utils-1.1.27.dist-info/RECORD +0 -52
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/entry_points.txt +0 -0

llm_utils/lm/async_lm/async_llm_task.py
@@ -12,16 +12,17 @@ from venv import logger
 
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
-from speedy_utils.all import dump_json_or_pickle, identify
 
 from llm_utils.chat_format.display import get_conversation_one_turn
 from llm_utils.lm.async_lm._utils import InputModelType, OutputModelType, ParsedOutput
 from llm_utils.lm.async_lm.async_lm import AsyncLM
+from speedy_utils import dump_json_or_pickle, identify
+
 
 # Type aliases for better readability
-TModel = TypeVar(
-Messages =
-LegacyMsgs =
+TModel = TypeVar('TModel', bound=BaseModel)
+Messages = list[ChatCompletionMessageParam]
+LegacyMsgs = list[dict[str, str]]
 RawMsgs = Union[Messages, LegacyMsgs]
 
 # Default configuration constants
@@ -31,38 +32,38 @@ RawMsgs = Union[Messages, LegacyMsgs]
 class LMConfiguration:
     """Configuration class for language model parameters."""
 
-    model:
-    temperature:
-    max_tokens:
-    base_url:
-    api_key:
-    cache:
-    think:
-    add_json_schema_to_instruction:
-    use_beta:
-    ports:
-    top_p:
-    presence_penalty:
-    top_k:
-    repetition_penalty:
-
-    def to_dict(self) ->
+    model: str | None = None
+    temperature: float | None = None
+    max_tokens: int | None = None
+    base_url: str | None = None
+    api_key: str | None = None
+    cache: bool | None = True
+    think: Literal[True, False] | None = None
+    add_json_schema_to_instruction: bool | None = None
+    use_beta: bool | None = False
+    ports: list[int] | None = None
+    top_p: float | None = None
+    presence_penalty: float | None = None
+    top_k: int | None = None
+    repetition_penalty: float | None = None
+
+    def to_dict(self) -> dict[str, Any]:
         """Convert configuration to dictionary format."""
         return {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            'model': self.model,
+            'temperature': self.temperature,
+            'max_tokens': self.max_tokens,
+            'base_url': self.base_url,
+            'api_key': self.api_key,
+            'cache': self.cache,
+            'think': self.think,
+            'add_json_schema_to_instruction': self.add_json_schema_to_instruction,
+            'use_beta': self.use_beta,
+            'ports': self.ports,
+            'top_p': self.top_p,
+            'presence_penalty': self.presence_penalty,
+            'top_k': self.top_k,
+            'repetition_penalty': self.repetition_penalty,
         }
 
 
@@ -83,41 +84,41 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
     OutputModel: OutputModelType
 
     # default class attributes for configuration
-    DEFAULT_MODEL:
-    DEFAULT_CACHE_DIR:
-    DEFAULT_TEMPERATURE:
-    DEFAULT_MAX_TOKENS:
-    DEFAULT_TOP_P:
-    DEFAULT_PRESENCE_PENALTY:
-    DEFAULT_TOP_K:
-    DEFAULT_REPETITION_PENALTY:
-    DEFAULT_CACHE:
-    DEFAULT_THINK:
-    DEFAULT_PORTS:
-    DEFAULT_USE_BETA:
-    DEFAULT_ADD_JSON_SCHEMA_TO_INSTRUCTION:
-    DEFAULT_COLLECT_DATA:
-    DEFAULT_BASE_URL:
-    DEFAULT_API_KEY:
+    DEFAULT_MODEL: str | None = None
+    DEFAULT_CACHE_DIR: pathlib.Path | None = None
+    DEFAULT_TEMPERATURE: float | None = None
+    DEFAULT_MAX_TOKENS: int | None = None
+    DEFAULT_TOP_P: float | None = None
+    DEFAULT_PRESENCE_PENALTY: float | None = None
+    DEFAULT_TOP_K: int | None = None
+    DEFAULT_REPETITION_PENALTY: float | None = None
+    DEFAULT_CACHE: bool | None = True
+    DEFAULT_THINK: Literal[True, False] | None = None
+    DEFAULT_PORTS: list[int] | None = None
+    DEFAULT_USE_BETA: bool | None = False
+    DEFAULT_ADD_JSON_SCHEMA_TO_INSTRUCTION: bool | None = True
+    DEFAULT_COLLECT_DATA: bool | None = None
+    DEFAULT_BASE_URL: str | None = None
+    DEFAULT_API_KEY: str | None = None
 
     IS_DATA_COLLECTION: bool = False
 
     def __init__(
         self,
-        model:
-        temperature:
-        max_tokens:
-        base_url:
-        api_key:
-        cache:
-        think:
-        add_json_schema_to_instruction:
-        use_beta:
-        ports:
-        top_p:
-        presence_penalty:
-        top_k:
-        repetition_penalty:
+        model: str | None = None,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        cache: bool | None = None,
+        think: Literal[True, False] | None = None,
+        add_json_schema_to_instruction: bool | None = None,
+        use_beta: bool | None = None,
+        ports: list[int] | None = None,
+        top_p: float | None = None,
+        presence_penalty: float | None = None,
+        top_k: int | None = None,
+        repetition_penalty: float | None = None,
     ) -> None:
         """
         Initialize the AsyncLLMTask with language model configuration.
@@ -126,31 +127,37 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         """
         self._config = LMConfiguration(
             model=model if model is not None else self.DEFAULT_MODEL,
-            temperature=
-
-
-            max_tokens=
-
-
+            temperature=(
+                temperature if temperature is not None else self.DEFAULT_TEMPERATURE
+            ),
+            max_tokens=(
+                max_tokens if max_tokens is not None else self.DEFAULT_MAX_TOKENS
+            ),
             base_url=base_url if base_url is not None else self.DEFAULT_BASE_URL,
             api_key=api_key if api_key is not None else self.DEFAULT_API_KEY,
             cache=cache if cache is not None else self.DEFAULT_CACHE,
             think=think if think is not None else self.DEFAULT_THINK,
-            add_json_schema_to_instruction=
-
-
+            add_json_schema_to_instruction=(
+                add_json_schema_to_instruction
+                if add_json_schema_to_instruction is not None
+                else self.DEFAULT_ADD_JSON_SCHEMA_TO_INSTRUCTION
+            ),
             use_beta=use_beta if use_beta is not None else self.DEFAULT_USE_BETA,
             ports=ports if ports is not None else self.DEFAULT_PORTS,
             top_p=top_p if top_p is not None else self.DEFAULT_TOP_P,
-            presence_penalty=
-
-
+            presence_penalty=(
+                presence_penalty
+                if presence_penalty is not None
+                else self.DEFAULT_PRESENCE_PENALTY
+            ),
             top_k=top_k if top_k is not None else self.DEFAULT_TOP_K,
-            repetition_penalty=
-
-
+            repetition_penalty=(
+                repetition_penalty
+                if repetition_penalty is not None
+                else self.DEFAULT_REPETITION_PENALTY
+            ),
         )
-        self._lm:
+        self._lm: AsyncLM | None = None
 
     @property
     def lm(self) -> AsyncLM:
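The rewritten `__init__` resolves every constructor argument against the matching `DEFAULT_*` class attribute before storing it in `LMConfiguration`, so subclasses can pin defaults while callers still override per instance. A minimal sketch of that resolution pattern (simplified, hypothetical names, not the package's actual code):

```python
from typing import Any


class TaskDefaults:
    """Illustrative stand-in for AsyncLLMTask's DEFAULT_* class attributes."""

    DEFAULT_TEMPERATURE: float | None = 0.2
    DEFAULT_MAX_TOKENS: int | None = 2000

    def __init__(
        self,
        temperature: float | None = None,
        max_tokens: int | None = None,
    ) -> None:
        # Each argument falls back to the class-level default when left as None.
        self.config: dict[str, Any] = {
            'temperature': temperature if temperature is not None else self.DEFAULT_TEMPERATURE,
            'max_tokens': max_tokens if max_tokens is not None else self.DEFAULT_MAX_TOKENS,
        }


print(TaskDefaults().config)                 # {'temperature': 0.2, 'max_tokens': 2000}
print(TaskDefaults(temperature=0.9).config)  # {'temperature': 0.9, 'max_tokens': 2000}
```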
@@ -178,21 +185,21 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
             TypeError: If output model type cannot be determined
         """
         # Try to get type from generic base classes
-        orig_bases = getattr(self.__class__,
+        orig_bases = getattr(self.__class__, '__orig_bases__', None)
         if (
             orig_bases
-            and hasattr(orig_bases[0],
+            and hasattr(orig_bases[0], '__args__')
             and len(orig_bases[0].__args__) >= 2
         ):
             return orig_bases[0].__args__[1]
 
         # Fallback to class attribute
-        if hasattr(self,
+        if hasattr(self, 'OutputModel'):
             return self.OutputModel  # type: ignore
 
         raise TypeError(
-            f
-
+            f'{self.__class__.__name__} must define OutputModel as a class attribute '
+            'or use proper generic typing with AsyncLLMTask[InputModel, OutputModel]'
         )
 
     def _get_input_model_type(self) -> type[InputModelType]:
@@ -206,20 +213,20 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
             TypeError: If input model type cannot be determined
         """
         # Try to get type from generic base classes
-        orig_bases = getattr(self.__class__,
+        orig_bases = getattr(self.__class__, '__orig_bases__', None)
         if (
             orig_bases
-            and hasattr(orig_bases[0],
+            and hasattr(orig_bases[0], '__args__')
             and len(orig_bases[0].__args__) >= 2
         ):
             return orig_bases[0].__args__[0]
 
         raise TypeError(
-            f
-
+            f'{self.__class__.__name__} must define InputModel as a class attribute '
+            'or use proper generic typing with AsyncLLMTask[InputModel, OutputModel]'
         )
 
-    def _validate_and_convert_input(self, data:
+    def _validate_and_convert_input(self, data: BaseModel | dict) -> BaseModel:
         """
         Validate and convert input data to the expected input model type.
 
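Both helpers recover the concrete input/output models from the generic subscription of the subclass via `__orig_bases__`. A small standalone sketch of that introspection (hypothetical `In`/`Out` models, not from the package):

```python
from typing import Generic, TypeVar

from pydantic import BaseModel

InputT = TypeVar('InputT', bound=BaseModel)
OutputT = TypeVar('OutputT', bound=BaseModel)


class Task(Generic[InputT, OutputT]):
    pass


class In(BaseModel):
    question: str


class Out(BaseModel):
    answer: str


class MyTask(Task[In, Out]):
    pass


# __orig_bases__ keeps the parameterized base class, and __args__ holds the
# (In, Out) tuple, so the task can recover its models at runtime.
base = MyTask.__orig_bases__[0]
print(base.__args__)  # prints the tuple containing In and Out
```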
@@ -243,10 +250,10 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
                 return input_model_type(**data)
             except Exception as e:
                 raise TypeError(
-                    f
+                    f'Failed to convert input data to {input_model_type.__name__}: {e}'
                 ) from e
 
-        raise TypeError(
+        raise TypeError('InputModel must be a subclass of BaseModel')
 
     def _validate_output_model(self) -> type[BaseModel]:
         """
@@ -263,12 +270,10 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
             isinstance(output_model_type, type)
             and issubclass(output_model_type, BaseModel)
         ):
-            raise TypeError(
+            raise TypeError('OutputModel must be a subclass of BaseModel')
         return output_model_type
 
-    async def _base_call(
-        self, data: Union[BaseModel, dict]
-    ) -> ParsedOutput[OutputModelType]:
+    async def _base_call(self, data: BaseModel | dict) -> ParsedOutput[OutputModelType]:
         """
         Core method that handles language model interaction with type safety.
 
@@ -289,7 +294,7 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         return cast(
             ParsedOutput[OutputModelType],
             await self.lm.parse(
-                instruction=self.__doc__ or
+                instruction=self.__doc__ or '',
                 prompt=validated_input.model_dump_json(),
             ),
         )
@@ -311,21 +316,21 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         no_think_messages = copy.deepcopy(think_messages)
 
         # Update system message
-        if no_think_messages and
-            system_content = no_think_messages[0][
+        if no_think_messages and 'content' in no_think_messages[0]:
+            system_content = no_think_messages[0]['content']
             if isinstance(system_content, str):
-                no_think_messages[0][
-
+                no_think_messages[0]['content'] = system_content.replace(
+                    '/think', '/no_think'
                 )
 
         # Update assistant message (last message)
-        if len(no_think_messages) > 1 and
-            assistant_content = no_think_messages[-1][
-            if isinstance(assistant_content, str) and
+        if len(no_think_messages) > 1 and 'content' in no_think_messages[-1]:
+            assistant_content = no_think_messages[-1]['content']
+            if isinstance(assistant_content, str) and '</think>' in assistant_content:
                 # Extract content after thinking block
-                post_think_content = assistant_content.split(
-                no_think_messages[-1][
-                    f
+                post_think_content = assistant_content.split('</think>', 1)[1].strip()
+                no_think_messages[-1]['content'] = (
+                    f'<think>\n\n</think>\n\n{post_think_content}'
                 )
 
         return no_think_messages
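The updated `_create_no_think_messages` turns a "thinking" conversation into its non-thinking twin: the system prompt's `/think` flag becomes `/no_think`, and the assistant turn keeps only the text after `</think>`, behind an empty think block. A standalone sketch of that string transformation on hypothetical messages, following the same logic as the hunk above:

```python
import copy

think_messages = [
    {'role': 'system', 'content': 'You are a careful assistant. /think'},
    {'role': 'user', 'content': 'What is 2 + 2?'},
    {'role': 'assistant', 'content': '<think>\nAdd the numbers.\n</think>\n\n4'},
]

no_think = copy.deepcopy(think_messages)

# System turn: switch the thinking flag off.
if no_think and isinstance(no_think[0].get('content'), str):
    no_think[0]['content'] = no_think[0]['content'].replace('/think', '/no_think')

# Assistant turn: drop the reasoning, keep only what follows </think>.
last = no_think[-1].get('content')
if isinstance(last, str) and '</think>' in last:
    post = last.split('</think>', 1)[1].strip()
    no_think[-1]['content'] = f'<think>\n\n</think>\n\n{post}'

print(no_think[0]['content'])   # ends with /no_think
print(no_think[-1]['content'])  # empty think block followed by "4"
```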
@@ -335,10 +340,10 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         input_data: InputModelType,
         think_messages: Messages,
         no_think_messages: Messages,
-        model_kwargs:
+        model_kwargs: dict[str, Any],
         cache_dir: pathlib.Path,
-        expected_response:
-        label:
+        expected_response: OutputModelType | None = None,
+        label: str | None = None,
     ) -> None:
         """
         Save training data to cache directory.
@@ -359,26 +364,26 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
 
         # Prepare combined training data
         training_data = {
-
-
-
-
-
+            'think_messages': think_messages,
+            'no_think_messages': no_think_messages,
+            'model_kwargs': model_kwargs,
+            'input_data': input_data.model_dump(),
+            'label': label,
         }
 
         if expected_response is not None:
-            training_data[
+            training_data['expected_response'] = expected_response.model_dump()
 
         # Save to file
-        training_file = class_cache_dir / f
+        training_file = class_cache_dir / f'{input_id}.json'
         dump_json_or_pickle(training_data, str(training_file))
 
     async def _generate_training_data_with_thinking_mode(
         self,
         input_data: InputModelType,
-        expected_response:
-        label:
-        cache_dir:
+        expected_response: OutputModelType | None = None,
+        label: str | None = None,
+        cache_dir: pathlib.Path | None = None,
     ) -> OutputModelType:
         """
         Generate training data for both thinking and non-thinking modes.
@@ -398,22 +403,22 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         """
         # Execute the base call to get thinking mode data
        output = await self._base_call(input_data)
-        parsed_result = output[
-        think_messages = output[
+        parsed_result = output['parsed']
+        think_messages = output['messages']
 
         # Create non-thinking mode equivalent
         no_think_messages = self._create_no_think_messages(think_messages)
 
         # Use default cache directory if none provided
         if cache_dir is None:
-            cache_dir = self.DEFAULT_CACHE_DIR or pathlib.Path(
+            cache_dir = self.DEFAULT_CACHE_DIR or pathlib.Path('./cache')
 
         # Save training data
         self._save_training_data(
             input_data=input_data,
             think_messages=think_messages,
             no_think_messages=no_think_messages,
-            model_kwargs=output[
+            model_kwargs=output['model_kwargs'],
             cache_dir=cache_dir,
             expected_response=expected_response,
             label=label,
@@ -433,8 +438,8 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
     async def __call__(
         self,
         input_data: InputModelType,
-        expected_response:
-        label:
+        expected_response: OutputModelType | None = None,
+        label: str | None = None,
         **kwargs: Any,
     ) -> OutputModelType:
         """
@@ -459,13 +464,12 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
                 expected_response=expected_response,
                 label=label,
             )
-
-
-        return output["parsed"]
+        output = await self._base_call(input_data)
+        return output['parsed']
 
     def generate_training_data(
         self, input_json: str, output_json: str
-    ) ->
+    ) -> dict[str, Any]:
         """
         Generate training data in ShareGPT format for the given input/output pair.
 
@@ -488,16 +492,16 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         # "as class attributes to use generate_training_data"
         # )
 
-        system_prompt = self.__doc__ or
-        assert isinstance(input_json, str),
-        assert isinstance(output_json, str),
+        system_prompt = self.__doc__ or ''
+        assert isinstance(input_json, str), 'Input must be a JSON string'
+        assert isinstance(output_json, str), 'Output must be a JSON string'
         messages = get_conversation_one_turn(
             system_msg=system_prompt,
             user_msg=input_json,
             assistant_msg=output_json,
         )
 
-        return {
+        return {'messages': messages}
 
     # Compatibility alias for other LLMTask implementations
     arun = __call__
@@ -506,8 +510,8 @@ class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
         return self
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        if hasattr(self._lm,
-            last_client = self._lm._last_client
+        if hasattr(self._lm, '_last_client'):
+            last_client = self._lm._last_client  # type: ignore
             await last_client._client.aclose()
         else:
-            logger.warning(
+            logger.warning('No last client to close')

llm_utils/lm/async_lm/async_lm.py
@@ -12,10 +12,10 @@ from typing import (
 from loguru import logger
 from openai import AuthenticationError, BadRequestError, OpenAI, RateLimitError
 from pydantic import BaseModel
-from speedy_utils import jloads
 
 # from llm_utils.lm.async_lm.async_llm_task import OutputModelType
 from llm_utils.lm.async_lm.async_lm_base import AsyncLMBase
+from speedy_utils import jloads
 
 from ._utils import (
     LegacyMsgs,
@@ -44,28 +44,32 @@ class AsyncLM(AsyncLMBase):
     def __init__(
         self,
         *,
-        model:
-        response_model:
+        model: str | None = None,
+        response_model: type[BaseModel] | None = None,
         temperature: float = 0.0,
         max_tokens: int = 2_000,
         host: str = "localhost",
-        port:
-        base_url:
-        api_key:
+        port: int | str | None = None,
+        base_url: str | None = None,
+        api_key: str | None = None,
         cache: bool = True,
         think: Literal[True, False, None] = None,
-        add_json_schema_to_instruction:
+        add_json_schema_to_instruction: bool | None = None,
         use_beta: bool = False,
-        ports:
+        ports: list[int] | None = None,
         top_p: float = 1.0,
         presence_penalty: float = 0.0,
         top_k: int = 1,
         repetition_penalty: float = 1.0,
-        frequency_penalty:
+        frequency_penalty: float | None = None,
     ) -> None:
 
         if model is None:
-            models =
+            models = (
+                OpenAI(base_url=f"http://{host}:{port}/v1", api_key="abc")
+                .models.list()
+                .data
+            )
             assert len(models) == 1, f"Found {len(models)} models, please specify one."
             model = models[0].id
             print(f"Using model: {model}")
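When `model` is omitted, the constructor now asks the local OpenAI-compatible server which models it serves and requires exactly one. A hedged sketch of that discovery step, assuming a vLLM-style endpoint on localhost:8000 (the API key is a placeholder because local servers usually ignore it):

```python
from openai import OpenAI

# Assumed local OpenAI-compatible server (e.g. started with `vllm serve ...`).
client = OpenAI(base_url='http://localhost:8000/v1', api_key='abc')

models = client.models.list().data
assert len(models) == 1, f'Found {len(models)} models, please specify one.'
model_id = models[0].id
print(f'Using model: {model_id}')
```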
@@ -86,24 +90,24 @@ class AsyncLM(AsyncLMBase):
             self.add_json_schema_to_instruction = True
 
         # Store all model-related parameters in model_kwargs
-        self.model_kwargs =
-            model
-            temperature
-            max_tokens
-            top_p
-            presence_penalty
-
-        self.extra_body =
-            top_k
-            repetition_penalty
-            frequency_penalty
-
+        self.model_kwargs = {
+            "model": model,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "top_p": top_p,
+            "presence_penalty": presence_penalty,
+        }
+        self.extra_body = {
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+            "frequency_penalty": frequency_penalty,
+        }
 
     async def _unified_client_call(
         self,
         messages: RawMsgs,
-        extra_body:
-        max_tokens:
+        extra_body: dict | None = None,
+        max_tokens: int | None = None,
     ) -> dict:
         """Unified method for all client interactions (caching handled by MAsyncOpenAI)."""
         converted_messages: Messages = (
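The constructor now splits sampling parameters into standard OpenAI arguments (`model_kwargs`) and backend-specific ones (`extra_body`, e.g. `top_k` and `repetition_penalty` for vLLM-style servers). A rough sketch of how such a split could be forwarded on a chat completion call, assuming an OpenAI-compatible endpoint; names and values here are illustrative, not the package's internals:

```python
from openai import AsyncOpenAI

model_kwargs = {
    'model': 'my-local-model',  # placeholder model name
    'temperature': 0.0,
    'max_tokens': 2000,
    'top_p': 1.0,
    'presence_penalty': 0.0,
}
extra_body = {
    'top_k': 1,
    'repetition_penalty': 1.0,
    'frequency_penalty': 0.0,
}


async def ask(client: AsyncOpenAI, messages: list[dict]) -> str:
    # Standard parameters are passed directly; non-standard ones ride along in
    # extra_body, which the openai client forwards verbatim in the request JSON.
    response = await client.chat.completions.create(
        messages=messages,
        extra_body=extra_body,
        **model_kwargs,
    )
    return response.choices[0].message.content or ''
```

Keeping the two dicts separate lets the same call target stock OpenAI endpoints (which reject unknown top-level parameters) and extended local servers without branching.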
@@ -139,7 +143,7 @@ class AsyncLM(AsyncLMBase):
     async def _call_and_parse(
         self,
         messages: list[dict],
-        response_model:
+        response_model: type[OutputModelType],
         json_schema: dict,
     ) -> tuple[dict, list[dict], OutputModelType]:
         """Unified call and parse with cache and error handling."""
@@ -198,7 +202,7 @@ class AsyncLM(AsyncLMBase):
     async def _call_and_parse_with_beta(
         self,
         messages: list[dict],
-        response_model:
+        response_model: type[OutputModelType],
         json_schema: dict,
     ) -> tuple[dict, list[dict], OutputModelType]:
         """Call and parse for beta mode with guided JSON."""
@@ -249,9 +253,9 @@ class AsyncLM(AsyncLMBase):
 
     async def call_with_messages(
         self,
-        prompt:
-        messages:
-        max_tokens:
+        prompt: str | None = None,
+        messages: RawMsgs | None = None,
+        max_tokens: int | None = None,
     ):  # -> tuple[Any | dict[Any, Any], list[ChatCompletionMessagePar...:
         """Unified async call for language model, returns (assistant_message.model_dump(), messages)."""
         if (prompt is None) == (messages is None):
@@ -268,9 +272,9 @@ class AsyncLM(AsyncLMBase):
             else cast(Messages, messages)
         )
 
-        assert
-        "
-        )
+        assert (
+            self.model_kwargs["model"] is not None
+        ), "Model must be set before making a call."
 
         # Use unified client call
         raw_response = await self._unified_client_call(
@@ -293,17 +297,19 @@ class AsyncLM(AsyncLMBase):
         msg_dump = dict(assistant_msg)
         return msg_dump, full_messages
 
-
     def call_sync(
         self,
-        prompt:
-        messages:
-        max_tokens:
+        prompt: str | None = None,
+        messages: RawMsgs | None = None,
+        max_tokens: int | None = None,
     ):
         """Synchronous wrapper around the async __call__ method."""
         import asyncio
-
-
+
+        return asyncio.run(
+            self.__call__(prompt=prompt, messages=messages, max_tokens=max_tokens)
+        )
+
     async def parse(
         self,
         instruction,
@@ -311,9 +317,9 @@ class AsyncLM(AsyncLMBase):
     ) -> ParsedOutput[BaseModel]:
         """Parse response using guided JSON generation. Returns (parsed.model_dump(), messages)."""
         if not self._use_beta:
-            assert
-
-            )
+            assert (
+                self.add_json_schema_to_instruction
+            ), "add_json_schema_to_instruction must be True when use_beta is False. otherwise model will not be able to parse the response."
 
         assert self.response_model is not None, "response_model must be set at init."
         json_schema = self.response_model.model_json_schema()
@@ -351,7 +357,7 @@ class AsyncLM(AsyncLMBase):
         )
 
     def _parse_complete_output(
-        self, completion: Any, response_model:
+        self, completion: Any, response_model: type[BaseModel]
     ) -> BaseModel:
         """Parse completion output to response model."""
         if hasattr(completion, "model_dump"):