speedy-utils 1.1.6__tar.gz → 1.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/PKG-INFO +1 -1
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/pyproject.toml +42 -10
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/__init__.py +1 -5
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/chat_format/transform.py +9 -9
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/group_messages.py +1 -1
- speedy_utils-1.1.8/src/llm_utils/lm/async_lm/__init__.py +7 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/lm/async_lm/_utils.py +7 -4
- speedy_utils-1.1.8/src/llm_utils/lm/async_lm/async_llm_task.py +516 -0
- speedy_utils-1.1.8/src/llm_utils/lm/async_lm/async_lm.py +387 -0
- speedy_utils-1.1.8/src/llm_utils/lm/async_lm/async_lm_base.py +407 -0
- speedy_utils-1.1.8/src/llm_utils/lm/async_lm/lm_specific.py +136 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/lm/utils.py +1 -3
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/scripts/vllm_load_balancer.py +49 -37
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/__init__.py +3 -1
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/notebook_utils.py +4 -4
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/report_manager.py +2 -3
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/utils_cache.py +233 -3
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/utils_io.py +2 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/scripts/mpython.py +1 -3
- speedy_utils-1.1.6/src/llm_utils/lm/async_lm/__init__.py +0 -2
- speedy_utils-1.1.6/src/llm_utils/lm/async_lm/async_llm_task.py +0 -154
- speedy_utils-1.1.6/src/llm_utils/lm/async_lm/async_lm.py +0 -779
- speedy_utils-1.1.6/src/llm_utils/lm/chat_html.py +0 -246
- speedy_utils-1.1.6/src/llm_utils/lm/lm_json.py +0 -68
- speedy_utils-1.1.6/src/llm_utils/lm/sync_lm.py +0 -943
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/README.md +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/chat_format/__init__.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/chat_format/display.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/chat_format/utils.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/lm/__init__.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/scripts/README.md +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/llm_utils/scripts/vllm_serve.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/all.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/__init__.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/clock.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/function_decorator.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/logger.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/utils_misc.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/common/utils_print.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/multi_worker/__init__.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/multi_worker/process.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/multi_worker/thread.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/scripts/__init__.py +0 -0
- {speedy_utils-1.1.6 → speedy_utils-1.1.8}/src/speedy_utils/scripts/openapi_client_codegen.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "speedy-utils"
|
|
3
|
-
version = "1.1.
|
|
3
|
+
version = "1.1.8"
|
|
4
4
|
description = "Fast and easy-to-use package for data science"
|
|
5
5
|
authors = ["AnhVTH <anhvth.226@gmail.com>"]
|
|
6
6
|
readme = "README.md"
|
|
@@ -66,16 +66,48 @@ svllm-lb = "llm_utils.scripts.vllm_load_balancer:run_load_balancer"
|
|
|
66
66
|
openapi_client_codegen = "speedy_utils.scripts.openapi_client_codegen:main"
|
|
67
67
|
|
|
68
68
|
|
|
69
|
+
[tool.ruff]
|
|
70
|
+
exclude = [
|
|
71
|
+
"**/*.ipynb",
|
|
72
|
+
"notebooks/**/*.ipynb",
|
|
73
|
+
"legacy",
|
|
74
|
+
"**/__pycache__",
|
|
75
|
+
"**/.cache",
|
|
76
|
+
"**/.ruff_cache",
|
|
77
|
+
"**/.pytest_cache",
|
|
78
|
+
"**/.ipynb_checkpoints",
|
|
79
|
+
"**/.venv",
|
|
80
|
+
"**/.vscode",
|
|
81
|
+
"**/*.egg-info",
|
|
82
|
+
"**/*.lock",
|
|
83
|
+
"poetry.lock",
|
|
84
|
+
"Pipfile.lock",
|
|
85
|
+
"package-lock.json",
|
|
86
|
+
"yarn.lock",
|
|
87
|
+
"unsloth_compiled_cache",
|
|
88
|
+
"unsloth_training_checkpoints",
|
|
89
|
+
]
|
|
90
|
+
target-version = "py310"
|
|
91
|
+
unsafe-fixes = true # allow deletions Ruff marks unsafe
|
|
92
|
+
[tool.ruff.lint]
|
|
93
|
+
ignore = [
|
|
94
|
+
"E401", # multiple imports on one line
|
|
95
|
+
"E402", # module level import not at top of file
|
|
96
|
+
"E501", # line too long
|
|
97
|
+
"F401", # unused import
|
|
98
|
+
"F403", # wildcard import
|
|
99
|
+
"F405", # name may be undefined, from wildcard import
|
|
100
|
+
"F841", # local variable assigned but never used
|
|
101
|
+
"E722", # do not use bare except
|
|
102
|
+
"E731", # do not assign a lambda expression, use a def
|
|
103
|
+
"E741", # ambiguous variable name
|
|
104
|
+
"E902", # io error
|
|
105
|
+
]
|
|
106
|
+
unfixable = ["E401", "E402", "E501", "F401", "F403"]
|
|
107
|
+
extend-select = ["F"] # keep all pyflakes rules
|
|
108
|
+
|
|
109
|
+
|
|
69
110
|
[tool.ruff.format]
|
|
70
111
|
quote-style = "double"
|
|
71
112
|
line-ending = "lf"
|
|
72
113
|
docstring-code-format = true
|
|
73
|
-
[tool.ruff]
|
|
74
|
-
exclude = ["**/*.ipynb", "poly_frontend_controler/*", "poly_client/", "legacy"]
|
|
75
|
-
ignore = [
|
|
76
|
-
"E501", # Line too long
|
|
77
|
-
"F401", # Unused import
|
|
78
|
-
"F403", # Wildcard import
|
|
79
|
-
"F841", # Local variable is assigned to but never used
|
|
80
|
-
"T201", # Use of `print` statement
|
|
81
|
-
]
|
|
@@ -10,7 +10,6 @@ from .chat_format import (
|
|
|
10
10
|
transform_messages_to_chatml,
|
|
11
11
|
)
|
|
12
12
|
from .lm.async_lm import AsyncLLMTask, AsyncLM
|
|
13
|
-
from .lm.sync_lm import LM, LLMTask
|
|
14
13
|
|
|
15
14
|
__all__ = [
|
|
16
15
|
"transform_messages",
|
|
@@ -21,10 +20,7 @@ __all__ = [
|
|
|
21
20
|
"display_conversations",
|
|
22
21
|
"build_chatml_input",
|
|
23
22
|
"format_msgs",
|
|
24
|
-
# "group_messages_by_len",
|
|
25
|
-
"LM",
|
|
26
|
-
"AsyncLM",
|
|
27
23
|
"display_chat_messages_as_html",
|
|
28
|
-
"
|
|
24
|
+
"AsyncLM",
|
|
29
25
|
"AsyncLLMTask",
|
|
30
26
|
]
|
|
@@ -16,9 +16,9 @@ def identify_format(item):
|
|
|
16
16
|
def _transform_sharegpt_to_chatml(
|
|
17
17
|
item, default_system_message="You are a helpful assistant.", print_msg=False
|
|
18
18
|
):
|
|
19
|
-
assert isinstance(
|
|
20
|
-
item
|
|
21
|
-
)
|
|
19
|
+
assert isinstance(item, dict), (
|
|
20
|
+
"The item is not in the correct format. Please check the format of the item."
|
|
21
|
+
)
|
|
22
22
|
|
|
23
23
|
messages = []
|
|
24
24
|
system_msg = item.get("system", "")
|
|
@@ -116,16 +116,16 @@ def transform_messages_to_chatml(input_data, input_format="auto"):
|
|
|
116
116
|
input_data = deepcopy(input_data)
|
|
117
117
|
if isinstance(input_data, list):
|
|
118
118
|
input_format = "chatlm"
|
|
119
|
-
assert (
|
|
120
|
-
|
|
121
|
-
)
|
|
119
|
+
assert input_data[0].get("role") is not None, (
|
|
120
|
+
"The input format is not recognized. Please specify the input format."
|
|
121
|
+
)
|
|
122
122
|
elif isinstance(input_data, dict):
|
|
123
123
|
input_data = _transform_sharegpt_to_chatml(input_data)
|
|
124
124
|
input_format = "sharegpt"
|
|
125
125
|
elif isinstance(input_data, str):
|
|
126
|
-
assert (
|
|
127
|
-
"
|
|
128
|
-
)
|
|
126
|
+
assert "<|im_end|>" in input_data, (
|
|
127
|
+
"The input format is not recognized. Please specify the input format."
|
|
128
|
+
)
|
|
129
129
|
input_format = "chatlm"
|
|
130
130
|
parts = input_data.split("<|im_end|>")
|
|
131
131
|
input_data = []
|
|
@@ -76,7 +76,7 @@ def group_messages_by_len(
|
|
|
76
76
|
"""
|
|
77
77
|
if messages is None:
|
|
78
78
|
raise ValueError("messages parameter cannot be None")
|
|
79
|
-
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
79
|
+
from transformers.models.auto.tokenization_auto import AutoTokenizer # type: ignore
|
|
80
80
|
|
|
81
81
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
82
82
|
|
|
@@ -48,13 +48,17 @@ def _yellow(t):
|
|
|
48
48
|
return _color(33, t)
|
|
49
49
|
|
|
50
50
|
|
|
51
|
-
TParsed = TypeVar("TParsed", bound=BaseModel)
|
|
51
|
+
# TParsed = TypeVar("TParsed", bound=BaseModel)
|
|
52
52
|
|
|
53
|
+
InputModelType = TypeVar("InputModelType", bound=BaseModel)
|
|
54
|
+
OutputModelType = TypeVar("OutputModelType", bound=BaseModel)
|
|
53
55
|
|
|
54
|
-
|
|
56
|
+
|
|
57
|
+
class ParsedOutput(TypedDict, Generic[OutputModelType]):
|
|
55
58
|
messages: List
|
|
56
59
|
completion: Any
|
|
57
|
-
parsed:
|
|
60
|
+
parsed: OutputModelType
|
|
61
|
+
model_kwargs: Dict[str, Any]
|
|
58
62
|
|
|
59
63
|
|
|
60
64
|
# --------------------------------------------------------------------------- #
|
|
@@ -185,7 +189,6 @@ __all__ = [
|
|
|
185
189
|
"Messages",
|
|
186
190
|
"LegacyMsgs",
|
|
187
191
|
"RawMsgs",
|
|
188
|
-
"TParsed",
|
|
189
192
|
"ParsedOutput",
|
|
190
193
|
"get_tokenizer",
|
|
191
194
|
"inspect_word_probs_async",
|
|
@@ -0,0 +1,516 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Async LLM Task module for handling language model interactions with structured input/output.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import pathlib
|
|
7
|
+
from abc import ABC
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union, cast
|
|
10
|
+
|
|
11
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
from pytest import Cache
|
|
14
|
+
from speedy_utils import jdumps
|
|
15
|
+
from speedy_utils.all import dump_json_or_pickle, identify
|
|
16
|
+
|
|
17
|
+
from llm_utils.chat_format.display import get_conversation_one_turn
|
|
18
|
+
from llm_utils.lm.async_lm._utils import InputModelType, OutputModelType, ParsedOutput
|
|
19
|
+
from llm_utils.lm.async_lm.async_lm import AsyncLM
|
|
20
|
+
|
|
21
|
+
# Type aliases for chat-message payloads handled by the task class below.
# TModel is a Pydantic-model type variable; it is not referenced elsewhere
# in this module (kept for external importers — verify before removing).
TModel = TypeVar("TModel", bound=BaseModel)
# OpenAI-typed chat messages: a list of ChatCompletionMessageParam dicts.
Messages = List[ChatCompletionMessageParam]
# Plain string-to-string message dicts (presumably {"role": ..., "content": ...}
# — the alias itself only fixes the str -> str shape; confirm with callers).
LegacyMsgs = List[Dict[str, str]]
# Either message representation; not referenced elsewhere in this module.
RawMsgs = Union[Messages, LegacyMsgs]

# Default configuration values live as DEFAULT_* class attributes on
# AsyncLLMTask rather than as module-level constants.
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class LMConfiguration:
    """Configuration class for language model parameters.

    Every field defaults to ``None`` except ``cache`` (``True``) and
    ``use_beta`` (``False``). An instance is expanded with ``to_dict()`` into
    keyword arguments for the language-model client.
    """

    model: Optional[str] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    host: Optional[str] = None
    port: Optional[Union[int, str]] = None
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    cache: Optional[bool] = True
    think: Optional[Literal[True, False]] = None
    add_json_schema_to_instruction: Optional[bool] = None
    use_beta: Optional[bool] = False
    ports: Optional[List[int]] = None
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary format.

        Returns:
            A mapping from every dataclass field name (in declaration order)
            to its current value. The copy is shallow: mutable values such as
            ``ports`` are the same objects held by this instance.
        """
        # Local import keeps the module's import block untouched.
        from dataclasses import fields

        # Derive the keys from the dataclass fields so this mapping cannot
        # drift out of sync when a field is added, renamed or removed.
        return {f.name: getattr(self, f.name) for f in fields(self)}
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class AsyncLLMTask(ABC, Generic[InputModelType, OutputModelType]):
    """
    Abstract base class for asynchronous language model tasks with structured I/O.

    This class provides a framework for creating LLM tasks with strongly typed
    input and output models, automatic training data collection, and support
    for both thinking and non-thinking modes.

    Subclasses declare their models either through generic parameters
    (``class MyTask(AsyncLLMTask[MyInput, MyOutput])``) or through the
    ``InputModel`` / ``OutputModel`` class attributes. The subclass docstring
    is sent to the model as the instruction on every call.

    Type Parameters:
        InputModelType: Pydantic model type for task input
        OutputModelType: Pydantic model type for task output
    """

    # Optional class-attribute alternative to generic parameters.
    InputModel: type[InputModelType]
    OutputModel: type[OutputModelType]

    # default class attributes for configuration; an explicit None passed to
    # __init__ falls back to these.
    DEFAULT_MODEL: Optional[str] = None
    # Directory where collected training data is written (see __call__).
    DEFAULT_CACHE_DIR: Optional[pathlib.Path] = None
    DEFAULT_TEMPERATURE: Optional[float] = None
    DEFAULT_MAX_TOKENS: Optional[int] = None
    DEFAULT_HOST: Optional[str] = None
    DEFAULT_PORT: Optional[Union[int, str]] = None
    DEFAULT_TOP_P: Optional[float] = None
    DEFAULT_PRESENCE_PENALTY: Optional[float] = None
    DEFAULT_TOP_K: Optional[int] = None
    DEFAULT_REPETITION_PENALTY: Optional[float] = None
    DEFAULT_CACHE: Optional[bool] = True
    DEFAULT_THINK: Optional[Literal[True, False]] = None
    DEFAULT_PORTS: Optional[List[int]] = None
    DEFAULT_USE_BETA: Optional[bool] = False
    DEFAULT_ADD_JSON_SCHEMA_TO_INSTRUCTION: Optional[bool] = True
    # NOTE(review): not read anywhere in this module; data collection is
    # controlled by IS_DATA_COLLECTION via _should_collect_data().
    DEFAULT_COLLECT_DATA: Optional[bool] = None
    DEFAULT_BASE_URL: Optional[str] = None
    DEFAULT_API_KEY: Optional[str] = None

    # When True, every call also writes a training example to the cache dir.
    IS_DATA_COLLECTION: bool = False

    def __init__(
        self,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        host: Optional[str] = None,
        port: Optional[Union[int, str]] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        cache: Optional[bool] = None,
        think: Optional[Literal[True, False]] = None,
        add_json_schema_to_instruction: Optional[bool] = None,
        use_beta: Optional[bool] = None,
        ports: Optional[List[int]] = None,
        top_p: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
    ) -> None:
        """
        Initialize the AsyncLLMTask with language model configuration.

        All arguments are optional. Any argument left as ``None`` falls back
        to the matching ``DEFAULT_*`` class attribute, so subclasses can set
        defaults declaratively. The AsyncLM client itself is created lazily
        on first access of :attr:`lm`.
        """

        def pick(value: Any, default: Any) -> Any:
            # ``None`` means "not given": use the class-level default.
            return default if value is None else value

        self._config = LMConfiguration(
            model=pick(model, self.DEFAULT_MODEL),
            temperature=pick(temperature, self.DEFAULT_TEMPERATURE),
            max_tokens=pick(max_tokens, self.DEFAULT_MAX_TOKENS),
            host=pick(host, self.DEFAULT_HOST),
            port=pick(port, self.DEFAULT_PORT),
            base_url=pick(base_url, self.DEFAULT_BASE_URL),
            api_key=pick(api_key, self.DEFAULT_API_KEY),
            cache=pick(cache, self.DEFAULT_CACHE),
            think=pick(think, self.DEFAULT_THINK),
            add_json_schema_to_instruction=pick(
                add_json_schema_to_instruction,
                self.DEFAULT_ADD_JSON_SCHEMA_TO_INSTRUCTION,
            ),
            use_beta=pick(use_beta, self.DEFAULT_USE_BETA),
            ports=pick(ports, self.DEFAULT_PORTS),
            top_p=pick(top_p, self.DEFAULT_TOP_P),
            presence_penalty=pick(presence_penalty, self.DEFAULT_PRESENCE_PENALTY),
            top_k=pick(top_k, self.DEFAULT_TOP_K),
            repetition_penalty=pick(
                repetition_penalty, self.DEFAULT_REPETITION_PENALTY
            ),
        )
        self._lm: Optional[AsyncLM] = None

    @property
    def lm(self) -> AsyncLM:
        """
        Lazy-loaded AsyncLM instance with proper configuration.

        Returns:
            Configured AsyncLM instance for this task, built once from the
            stored configuration with the task's output model as
            ``response_model``.
        """
        if self._lm is None:
            self._lm = AsyncLM(
                **self._config.to_dict(),
                response_model=self._get_output_model_type(),
            )
        return self._lm

    @classmethod
    def _resolve_generic_arg(cls, index: int) -> Optional[type]:
        """
        Find the concrete class bound to generic parameter ``index``.

        Walks the MRO and inspects each class's own ``__orig_bases__``, so it
        works when mixins precede the ``AsyncLLMTask[...]`` base and for
        grandchildren of a parameterized subclass.

        Args:
            index: 0 for the input model, 1 for the output model.

        Returns:
            The bound class, or None if no concrete binding is found
            (unbound TypeVars are skipped).
        """
        for klass in cls.__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                origin = getattr(base, "__origin__", None)
                args = getattr(base, "__args__", ())
                if (
                    isinstance(origin, type)
                    and issubclass(origin, AsyncLLMTask)
                    and len(args) >= 2
                    and isinstance(args[index], type)
                ):
                    return args[index]
        return None

    def _get_output_model_type(self) -> type[OutputModelType]:
        """
        Extract the output model type from generic type arguments.

        Falls back to the ``OutputModel`` class attribute.

        Returns:
            The OutputModelType class

        Raises:
            TypeError: If output model type cannot be determined
        """
        model_type = self._resolve_generic_arg(1)
        if model_type is not None:
            return model_type  # type: ignore[return-value]

        output_model = getattr(self, "OutputModel", None)
        if output_model is not None:
            return output_model

        raise TypeError(
            f"{self.__class__.__name__} must define OutputModel as a class attribute "
            "or use proper generic typing with AsyncLLMTask[InputModel, OutputModel]"
        )

    def _get_input_model_type(self) -> type[InputModelType]:
        """
        Extract the input model type from generic type arguments.

        Falls back to the ``InputModel`` class attribute, as the error
        message promises.

        Returns:
            The InputModelType class

        Raises:
            TypeError: If input model type cannot be determined
        """
        model_type = self._resolve_generic_arg(0)
        if model_type is not None:
            return model_type  # type: ignore[return-value]

        input_model = getattr(self, "InputModel", None)
        if input_model is not None:
            return input_model

        raise TypeError(
            f"{self.__class__.__name__} must define InputModel as a class attribute "
            "or use proper generic typing with AsyncLLMTask[InputModel, OutputModel]"
        )

    def _validate_and_convert_input(self, data: Union[BaseModel, dict]) -> BaseModel:
        """
        Validate and convert input data to the expected input model type.

        Any ``BaseModel`` instance is passed through unchanged (it is not
        checked against the configured input model); a dict is used as
        keyword arguments to construct the input model.

        Args:
            data: Input data as BaseModel instance or dictionary

        Returns:
            Validated BaseModel instance

        Raises:
            TypeError: If input data cannot be converted to InputModel
        """
        if isinstance(data, BaseModel):
            return data

        input_model_type = self._get_input_model_type()
        if isinstance(input_model_type, type) and issubclass(
            input_model_type, BaseModel
        ):
            try:
                return input_model_type(**data)
            except Exception as e:
                raise TypeError(
                    f"Failed to convert input data to {input_model_type.__name__}: {e}"
                ) from e

        raise TypeError("InputModel must be a subclass of BaseModel")

    def _validate_output_model(self) -> type[BaseModel]:
        """
        Validate that the output model is properly configured.

        Returns:
            The validated output model type

        Raises:
            TypeError: If output model is not a valid BaseModel subclass
        """
        output_model_type = self._get_output_model_type()
        if not (
            isinstance(output_model_type, type)
            and issubclass(output_model_type, BaseModel)
        ):
            raise TypeError("OutputModel must be a subclass of BaseModel")
        return output_model_type

    async def _base_call(
        self, data: Union[BaseModel, dict]
    ) -> ParsedOutput[OutputModelType]:
        """
        Core method that handles language model interaction with type safety.

        The class docstring (``self.__doc__``) is the instruction and the
        input model's JSON serialization is the prompt.

        Args:
            data: Input data as BaseModel instance or dictionary

        Returns:
            Parsed output from the language model

        Raises:
            TypeError: If input/output models are not properly configured
        """
        # Validate input and output models
        validated_input = self._validate_and_convert_input(data)
        self._validate_output_model()

        # Execute the language model call
        return cast(
            ParsedOutput[OutputModelType],
            await self.lm.parse(
                instruction=self.__doc__ or "",
                prompt=validated_input.model_dump_json(),
            ),
        )

    def _create_no_think_messages(self, think_messages: Messages) -> Messages:
        """
        Convert thinking mode messages to non-thinking mode.

        On a deep copy: replaces ``/think`` with ``/no_think`` in the first
        message's string content. If there is more than one message, it also
        replaces everything up to and including ``</think>`` in the last
        message's string content with an empty ``<think>`` block.

        Args:
            think_messages: Original messages with thinking mode enabled

        Returns:
            Messages converted to non-thinking mode (input is not modified)
        """
        if not think_messages:
            return think_messages

        # Create deep copy to avoid modifying original
        no_think_messages = copy.deepcopy(think_messages)

        # Update system message
        if no_think_messages and "content" in no_think_messages[0]:
            system_content = no_think_messages[0]["content"]
            if isinstance(system_content, str):
                no_think_messages[0]["content"] = system_content.replace(
                    "/think", "/no_think"
                )

        # Update assistant message (last message)
        if len(no_think_messages) > 1 and "content" in no_think_messages[-1]:
            assistant_content = no_think_messages[-1]["content"]
            if isinstance(assistant_content, str) and "</think>" in assistant_content:
                # Extract content after thinking block
                post_think_content = assistant_content.split("</think>", 1)[1].strip()
                no_think_messages[-1]["content"] = (
                    f"<think>\n\n</think>\n\n{post_think_content}"
                )

        return no_think_messages

    def _save_training_data(
        self,
        input_data: InputModelType,
        think_messages: Messages,
        no_think_messages: Messages,
        model_kwargs: Dict[str, Any],
        cache_dir: pathlib.Path,
        expected_response: Optional[OutputModelType] = None,
        label: Optional[str] = None,
    ) -> None:
        """
        Save training data to cache directory.

        Writes ``<cache_dir>/<ClassName>/<identify(input)>.json``, creating
        directories as needed, so repeated inputs overwrite the same file.

        Args:
            input_data: Input data for the task (a model instance)
            think_messages: Messages with thinking mode
            no_think_messages: Messages without thinking mode
            model_kwargs: Model configuration used
            cache_dir: Directory to save training data (str is accepted too)
            expected_response: Expected response for validation
            label: Optional label for the training data
        """
        # Create unique identifier for this input
        input_id = identify(input_data.model_dump())
        class_cache_dir = pathlib.Path(cache_dir) / self.__class__.__name__
        class_cache_dir.mkdir(parents=True, exist_ok=True)

        # Prepare combined training data
        training_data = {
            "think_messages": think_messages,
            "no_think_messages": no_think_messages,
            "model_kwargs": model_kwargs,
            "input_data": input_data.model_dump(),
            "label": label,
        }

        if expected_response is not None:
            training_data["expected_response"] = expected_response.model_dump()

        # Save to file
        training_file = class_cache_dir / f"{input_id}.json"
        dump_json_or_pickle(training_data, str(training_file))

    async def _generate_training_data_with_thinking_mode(
        self,
        input_data: InputModelType,
        expected_response: Optional[OutputModelType] = None,
        label: Optional[str] = None,
        cache_dir: Optional[pathlib.Path] = None,
    ) -> OutputModelType:
        """
        Generate training data for both thinking and non-thinking modes.

        This method executes the task in thinking mode, then creates equivalent
        non-thinking mode data for training purposes. Both versions are saved
        to the cache directory for later use in model training.

        Args:
            input_data: Input data for the task (model instance or dict)
            expected_response: Expected response for validation
            label: Optional label for the training data
            cache_dir: Directory to save training data; defaults to
                ``DEFAULT_CACHE_DIR`` read at call time

        Returns:
            Parsed output from the language model

        Raises:
            TypeError: If the input cannot be converted to the input model
            ValueError: If no cache directory is given or configured
        """
        # Normalize dict input up front: _save_training_data needs a model
        # instance (it calls .model_dump()).
        validated_input = self._validate_and_convert_input(input_data)

        # Resolve the target directory at call time (a parameter default would
        # be frozen at class creation and ignore subclass overrides), and do
        # it before spending an LLM call whose data could not be saved.
        target_dir = cache_dir if cache_dir is not None else self.DEFAULT_CACHE_DIR
        if target_dir is None:
            raise ValueError(
                f"{self.__class__.__name__}: data collection is enabled but no "
                "cache directory is configured; set DEFAULT_CACHE_DIR or pass cache_dir"
            )

        # Execute the base call to get thinking mode data
        output = await self._base_call(validated_input)
        parsed_result = output["parsed"]
        think_messages = output["messages"]

        # Create non-thinking mode equivalent
        no_think_messages = self._create_no_think_messages(think_messages)

        # Save training data
        self._save_training_data(
            input_data=validated_input,
            think_messages=think_messages,
            no_think_messages=no_think_messages,
            model_kwargs=output["model_kwargs"],
            cache_dir=pathlib.Path(target_dir),
            expected_response=expected_response,
            label=label,
        )

        return parsed_result

    def _should_collect_data(self) -> bool:
        """
        Determine if training data should be collected for this call.

        Returns:
            The value of ``IS_DATA_COLLECTION``
        """
        return self.IS_DATA_COLLECTION

    async def __call__(
        self,
        input_data: InputModelType,
        expected_response: Optional[OutputModelType] = None,
        label: Optional[str] = None,
        **kwargs: Any,
    ) -> OutputModelType:
        """
        Execute the LLM task with the provided input data.

        This is the main entry point for task execution. If data collection
        is enabled (``IS_DATA_COLLECTION`` is truthy), training data is also
        generated and saved under ``DEFAULT_CACHE_DIR``.

        Args:
            input_data: Input data conforming to InputModelType (or a dict)
            expected_response: Expected response, stored with collected data
            label: Optional label for training data categorization
            **kwargs: Accepted for forward compatibility; currently ignored

        Returns:
            Parsed output conforming to OutputModelType
        """
        if self._should_collect_data():
            return await self._generate_training_data_with_thinking_mode(
                input_data=input_data,
                expected_response=expected_response,
                label=label,
            )
        output = await self._base_call(input_data)
        return output["parsed"]

    def generate_training_data(
        self, input_json: str, output_json: str
    ) -> Dict[str, Any]:
        """
        Build one training example from an existing input/output pair.

        No language model call is made: the class docstring becomes the
        system message, ``input_json`` the user turn and ``output_json`` the
        assistant turn.

        Args:
            input_json: Task input serialized as a JSON string
            output_json: Task output serialized as a JSON string

        Returns:
            ``{"messages": ...}`` wrapping the conversation built by
            ``get_conversation_one_turn``

        Raises:
            AssertionError: If either argument is not a ``str`` (note: these
                checks are skipped when Python runs with ``-O``)
        """
        system_prompt = self.__doc__ or ""
        assert isinstance(input_json, str), "Input must be a JSON string"
        assert isinstance(output_json, str), "Output must be a JSON string"
        messages = get_conversation_one_turn(
            system_msg=system_prompt,
            user_msg=input_json,
            assistant_msg=output_json,
        )

        return {"messages": messages}

    async def arun(
        self,
        input_data: InputModelType,
        expected_response: Optional[OutputModelType] = None,
        label: Optional[str] = None,
        **kwargs: Any,
    ) -> OutputModelType:
        """
        Compatibility alias for other LLMTask implementations.

        Delegates through ``self(...)`` so a subclass that overrides
        ``__call__`` is honored (a class-level ``arun = __call__`` alias
        would stay bound to this base implementation).
        """
        return await self(
            input_data, expected_response=expected_response, label=label, **kwargs
        )

    async def __aenter__(self):
        """Enter the async context; returns the task itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Close the lazily-created language model client, if any.

        Prefers the client's own ``aclose()``. Otherwise it closes the
        underlying HTTP client of its last-used connection when one exists.
        The cached client is then dropped so a reused task builds a fresh
        one. Exceptions from the ``async with`` body are not suppressed.
        """
        lm = self._lm
        if lm is None:
            return

        close = getattr(lm, "aclose", None)
        if callable(close):
            await close()
        else:
            # NOTE(review): relies on AsyncLM internals (_last_client._client);
            # guarded because _last_client may be unset if no call was made.
            last_client = getattr(lm, "_last_client", None)
            raw_client = getattr(last_client, "_client", None)
            raw_close = getattr(raw_client, "aclose", None)
            if callable(raw_close):
                await raw_close()

        self._lm = None