speedy-utils 1.1.22__py3-none-any.whl → 1.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +19 -7
- llm_utils/chat_format/__init__.py +2 -0
- llm_utils/chat_format/display.py +115 -44
- llm_utils/lm/__init__.py +20 -2
- llm_utils/lm/llm.py +413 -0
- llm_utils/lm/llm_signature.py +35 -0
- llm_utils/lm/mixins.py +379 -0
- llm_utils/lm/openai_memoize.py +18 -7
- llm_utils/lm/signature.py +271 -0
- llm_utils/lm/utils.py +61 -76
- speedy_utils/__init__.py +28 -1
- speedy_utils/all.py +30 -1
- speedy_utils/common/utils_io.py +36 -26
- speedy_utils/common/utils_misc.py +25 -1
- speedy_utils/multi_worker/thread.py +145 -58
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.24.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.24.dist-info}/RECORD +19 -17
- llm_utils/lm/llm_task.py +0 -614
- llm_utils/lm/lm.py +0 -207
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.24.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.22.dist-info → speedy_utils-1.1.24.dist-info}/entry_points.txt +0 -0
llm_utils/lm/llm.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
# type: ignore
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Simplified LLM Task module for handling language model interactions with structured input/output.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import subprocess
|
|
9
|
+
from typing import Any, Dict, List, Optional, Type, Union, cast
|
|
10
|
+
|
|
11
|
+
import requests
|
|
12
|
+
from loguru import logger
|
|
13
|
+
from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
|
|
14
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
15
|
+
from pydantic import BaseModel
|
|
16
|
+
|
|
17
|
+
from speedy_utils.common.utils_io import jdumps
|
|
18
|
+
|
|
19
|
+
from .utils import (
|
|
20
|
+
_extract_port_from_vllm_cmd,
|
|
21
|
+
_start_vllm_server,
|
|
22
|
+
_kill_vllm_on_port,
|
|
23
|
+
_is_server_running,
|
|
24
|
+
get_base_client,
|
|
25
|
+
_is_lora_path,
|
|
26
|
+
_get_port_from_client,
|
|
27
|
+
_load_lora_adapter,
|
|
28
|
+
_unload_lora_adapter,
|
|
29
|
+
kill_all_vllm_processes,
|
|
30
|
+
stop_vllm_process,
|
|
31
|
+
)
|
|
32
|
+
from .base_prompt_builder import BasePromptBuilder
|
|
33
|
+
from .mixins import (
|
|
34
|
+
TemperatureRangeMixin,
|
|
35
|
+
TwoStepPydanticMixin,
|
|
36
|
+
VLLMMixin,
|
|
37
|
+
ModelUtilsMixin,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Type aliases for better readability
|
|
41
|
+
Messages = List[ChatCompletionMessageParam]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class LLM(
    TemperatureRangeMixin,
    TwoStepPydanticMixin,
    VLLMMixin,
    ModelUtilsMixin,
):
    """Language-model task wrapper with structured input/output handling.

    An instance bundles an optional system instruction, an optional output
    schema, an OpenAI-compatible client and default request parameters.
    Calling the instance sends one request and returns one result dict per
    completion choice.

    Server management (``_setup_vllm_server``, ``cleanup_vllm_server``),
    LoRA loading (``_load_lora_adapter``), ``temperature_range_sampling``
    and ``two_step_pydantic_parse`` are not defined in this class; they are
    expected to come from the mixin base classes.
    """

    def __init__(
        self,
        instruction: Optional[str] = None,
        input_model: Union[Type[BaseModel], type[str]] = str,
        output_model: Optional[Union[Type[BaseModel], Type[str]]] = None,
        client: Union[OpenAI, int, str, None] = None,
        cache=True,
        is_reasoning_model: bool = False,
        force_lora_unload: bool = False,
        lora_path: Optional[str] = None,
        vllm_cmd: Optional[str] = None,
        vllm_timeout: int = 1200,
        vllm_reuse: bool = True,
        **model_kwargs,
    ):
        """Configure the task, connect the client and resolve the model name.

        Args:
            instruction: System prompt prepended to non-list inputs; omitted
                from the conversation when ``None``.
            input_model: Declared input type; only stored by this class.
            output_model: Default response schema. ``None`` or ``str`` selects
                plain-text completion in ``__inner_call__``.
            client: Forwarded to ``get_base_client``. When ``None`` and
                ``vllm_cmd`` is given, the port extracted from the command is
                used instead.
            cache: Forwarded to ``get_base_client`` and stored on the instance.
            is_reasoning_model: When true, result dicts also carry
                ``reasoning_content`` if the response message exposes it.
            force_lora_unload: Stored only; presumably read by the LoRA
                mixin -- confirm there.
            lora_path: When truthy, ``self._load_lora_adapter()`` is invoked
                at the end of initialisation.
            vllm_cmd: When truthy, ``self._setup_vllm_server()`` is invoked
                before the client is created.
            vllm_timeout: Stored only; presumably the server start-up timeout
                used by the vLLM mixin -- confirm there.
            vllm_reuse: Stored only; presumably consumed by the vLLM mixin.
            **model_kwargs: Default request parameters. If ``"model"`` is
                missing or empty, the first model id listed by the client is
                used.

        Raises:
            Exception: Whatever ``client.models.list()`` raises when the
                connectivity check fails (logged, then re-raised unchanged).
        """
        self.instruction = instruction
        self.input_model = input_model
        self.output_model = output_model
        self.model_kwargs = model_kwargs
        self.is_reasoning_model = is_reasoning_model
        self.force_lora_unload = force_lora_unload
        self.lora_path = lora_path
        self.vllm_cmd = vllm_cmd
        self.vllm_timeout = vllm_timeout
        self.vllm_reuse = vllm_reuse
        self.vllm_process: Optional[subprocess.Popen] = None
        self.last_ai_response = None  # raw response object of the latest request
        self.cache = cache

        if self.vllm_cmd:
            self._setup_vllm_server()
            # An explicitly supplied client always wins over the port taken
            # from the vLLM command line.
            port = _extract_port_from_vllm_cmd(self.vllm_cmd)
            if client is None:
                client = port

        self.client = get_base_client(
            client,
            cache=cache,
            vllm_cmd=self.vllm_cmd,
            vllm_process=self.vllm_process,
        )

        # Listing the models doubles as the connectivity check; the result is
        # kept so the default model can be resolved without a second request.
        try:
            available_models = self.client.models.list()
        except Exception as e:
            logger.error(f"Failed to connect to OpenAI client: {str(e)}, base_url={self.client.base_url}")
            raise

        if not self.model_kwargs.get("model", ""):
            self.model_kwargs["model"] = available_models.data[0].id

        if self.lora_path:
            self._load_lora_adapter()

    def __enter__(self):
        """Return the instance itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Call ``cleanup_vllm_server()`` when the ``with`` block ends."""
        self.cleanup_vllm_server()

    def _prepare_input(self, input_data: Union[str, BaseModel, List[Dict]]) -> Messages:
        """Convert ``input_data`` into a chat message list.

        A list is treated as an already-built conversation and returned
        unchanged (the system instruction is NOT prepended). Any other value
        is rendered to a string and wrapped in a single user message, preceded
        by the system instruction when one is configured.
        """
        if isinstance(input_data, list):
            assert isinstance(input_data[0], dict) and "role" in input_data[0], (
                "If input_data is a list, it must be a list of messages with 'role' and 'content' keys."
            )
            return cast(Messages, input_data)

        # Render the payload as text: strings pass through, objects exposing
        # model_dump_json() are serialised with it, dicts go through jdumps.
        if isinstance(input_data, str):
            user_content = input_data
        elif hasattr(input_data, "model_dump_json"):
            user_content = input_data.model_dump_json()
        elif isinstance(input_data, dict):
            user_content = jdumps(input_data)
        else:
            user_content = str(input_data)

        messages = []
        if self.instruction is not None:
            messages.append({"role": "system", "content": self.instruction})
        messages.append({"role": "user", "content": user_content})
        return cast(Messages, messages)

    def text_completion(self, input_data: Union[str, BaseModel, list[Dict]], **runtime_kwargs) -> List[Dict[str, Any]]:
        """Run a plain chat completion and return one dict per choice.

        Args:
            input_data: Anything accepted by ``_prepare_input``.
            **runtime_kwargs: Per-call request parameters; they override the
                defaults given to the constructor.

        Returns:
            A list of dicts with ``"parsed"`` (the message text) and
            ``"messages"`` (the request conversation plus the assistant
            reply). ``"reasoning_content"`` is added when
            ``is_reasoning_model`` is set and the message has that attribute.

        Raises:
            AuthenticationError, RateLimitError, BadRequestError: Logged and
                re-raised unchanged.
            ValueError: For any other exception whose text contains
                ``"Length"`` or ``"maximum context length"``.
        """
        messages = self._prepare_input(input_data)

        # Runtime parameters take precedence over the instance defaults. The
        # merged mapping already contains the default model, so indexing it
        # directly also works when only the caller supplied a model.
        effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
        model_name = effective_kwargs["model"]
        api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}

        try:
            completion = self.client.chat.completions.create(model=model_name, messages=messages, **api_kwargs)
            self.last_ai_response = completion
        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
            logger.error(error_msg)
            raise
        except Exception as e:
            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
            if is_length_error:
                raise ValueError(f"Input too long for model {model_name}. Error: {str(e)[:100]}...") from e
            raise

        results: List[Dict[str, Any]] = []
        for choice in completion.choices:
            choice_messages = cast(
                Messages,
                messages + [{"role": "assistant", "content": choice.message.content}],
            )
            result_dict = {"parsed": choice.message.content, "messages": choice_messages}

            if self.is_reasoning_model and hasattr(choice.message, "reasoning_content"):
                result_dict["reasoning_content"] = choice.message.reasoning_content

            results.append(result_dict)
        return results

    def pydantic_parse(
        self,
        input_data: Union[str, BaseModel, list[Dict]],
        response_model: Optional[Union[Type[BaseModel], Type[str]]] = None,
        **runtime_kwargs,
    ) -> List[Dict[str, Any]]:
        """Run a structured completion and return validated model instances.

        Args:
            input_data: Anything accepted by ``_prepare_input``.
            response_model: Schema for this call; falls back to the
                ``output_model`` given to the constructor.
            **runtime_kwargs: Per-call request parameters; they override the
                defaults given to the constructor.

        Returns:
            A list of dicts with ``"parsed"`` (an instance of the response
            model) and ``"messages"``; ``"reasoning_content"`` is added under
            the same condition as in ``text_completion``.

        Raises:
            ValueError: If no response model is available, or for a
                non-OpenAI-typed exception whose text contains ``"Length"``
                or ``"maximum context length"``.
            AuthenticationError, RateLimitError, BadRequestError: Logged and
                re-raised unchanged.
        """
        messages = self._prepare_input(input_data)

        # Same merge rule as text_completion: runtime parameters win.
        effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
        model_name = effective_kwargs["model"]
        api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}

        pydantic_model_to_use_opt = response_model or self.output_model
        if pydantic_model_to_use_opt is None:
            raise ValueError(
                "No response model specified. Either set output_model in constructor or pass response_model parameter."
            )
        pydantic_model_to_use: Type[BaseModel] = cast(Type[BaseModel], pydantic_model_to_use_opt)

        try:
            completion = self.client.chat.completions.parse(
                model=model_name,
                messages=messages,
                response_format=pydantic_model_to_use,
                **api_kwargs,
            )
            self.last_ai_response = completion
        except (AuthenticationError, RateLimitError, BadRequestError) as exc:
            error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
            logger.error(error_msg)
            raise
        except Exception as e:
            is_length_error = "Length" in str(e) or "maximum context length" in str(e)
            if is_length_error:
                raise ValueError(f"Input too long for model {model_name}. Error: {str(e)[:100]}...") from e
            raise

        results: List[Dict[str, Any]] = []
        for choice in completion.choices:  # type: ignore[attr-defined]
            choice_messages = cast(
                Messages,
                messages + [{"role": "assistant", "content": choice.message.content}],
            )

            # Responses may carry the parsed payload as a plain dict (the
            # original code attributes this to cached responses) or as some
            # other object; anything that is not already an instance of the
            # requested model is validated into one.
            parsed_content = choice.message.parsed  # type: ignore[attr-defined]
            if not isinstance(parsed_content, pydantic_model_to_use):
                parsed_content = pydantic_model_to_use.model_validate(parsed_content)

            result_dict = {"parsed": parsed_content, "messages": choice_messages}

            if self.is_reasoning_model and hasattr(choice.message, "reasoning_content"):
                result_dict["reasoning_content"] = choice.message.reasoning_content

            results.append(result_dict)
        return results

    def __call__(
        self,
        input_data: Union[str, BaseModel, list[Dict]],
        response_model: Optional[Type[BaseModel] | Type[str]] = None,
        two_step_parse_pydantic: bool = False,
        temperature_ranges: Optional[tuple[float, float]] = None,
        n: int = 1,
        cache=None,
        **openai_client_kwargs,
    ) -> List[Dict[str, Any]]:
        """Execute the task and record the resulting conversation.

        Args:
            input_data: Input data (string, BaseModel, or message list).
            response_model: Optional response model override.
            two_step_parse_pydantic: Route through
                ``two_step_pydantic_parse`` when a non-``str`` response model
                is in effect.
            temperature_ranges: If set, ``(min_temp, max_temp)`` handed to
                ``temperature_range_sampling``; that call's result is
                returned directly and no history entry is recorded.
            n: Number of choices requested; must be >= 2 when
                ``temperature_ranges`` is used.
            cache: When not ``None``, passed to ``client.set_cache`` if the
                client has that method (a warning is logged otherwise). The
                previous setting is not restored afterwards.
            **openai_client_kwargs: Additional per-call request parameters.

        Returns:
            List of response dictionaries, one per choice.

        Raises:
            ValueError: If ``temperature_ranges`` is given with ``n < 2``.
        """
        if cache is not None:
            if hasattr(self.client, "set_cache"):
                self.client.set_cache(cache)
            else:
                logger.warning("Client does not support caching.")

        if temperature_ranges is not None:
            if n < 2:
                raise ValueError(f"n must be >= 2 when using temperature_ranges, got {n}")
            return self.temperature_range_sampling(
                input_data,
                temperature_ranges=temperature_ranges,
                n=n,
                response_model=response_model,
                **openai_client_kwargs,
            )
        openai_client_kwargs["n"] = n

        pydantic_model = response_model or self.output_model
        if two_step_parse_pydantic and pydantic_model not in (str, None):
            choices = self.two_step_pydantic_parse(
                input_data,
                response_model=pydantic_model,
                **openai_client_kwargs,
            )
        else:
            choices = self.__inner_call__(
                input_data,
                response_model=response_model,
                two_step_parse_pydantic=False,
                **openai_client_kwargs,
            )

        # Remember the first choice's conversation. Earlier entries are cut
        # to the most recent 100 before the new one is appended.
        _last_conv = choices[0]["messages"] if choices else []
        history = getattr(self, "_last_conversations", [])[-100:]
        history.append(_last_conv)
        self._last_conversations = history
        return choices

    def inspect_history(self, idx: int = -1, k_last_messages: int = 2) -> List[Dict[str, Any]]:
        """Display a recorded conversation through ``show_chat_v2``.

        Args:
            idx: Index into the recorded conversations (default: latest).
            k_last_messages: When positive, only that many trailing messages
                are shown; otherwise the whole conversation is shown.

        Returns:
            Whatever ``show_chat_v2`` returns for the selected messages.

        Raises:
            ValueError: If no call has recorded a conversation yet.
        """
        if not hasattr(self, "_last_conversations"):
            raise ValueError("No message history available. Make a call first.")

        from llm_utils import show_chat_v2

        conv = self._last_conversations[idx]
        if k_last_messages > 0:
            conv = conv[-k_last_messages:]
        return show_chat_v2(conv)

    def __inner_call__(
        self,
        input_data: Union[str, BaseModel, list[Dict]],
        response_model: Optional[Type[BaseModel] | Type[str]] = None,
        two_step_parse_pydantic: bool = False,
        **runtime_kwargs,
    ) -> List[Dict[str, Any]]:
        """Dispatch to ``text_completion`` or ``pydantic_parse``.

        Plain-text completion is used when the effective model
        (``response_model`` or the instance's ``output_model``) is ``str`` or
        ``None``; otherwise the structured path is used.

        Note: ``two_step_parse_pydantic`` is accepted but ignored here; use
        the public ``__call__`` method, which routes to the mixin.
        """
        pydantic_model_to_use = response_model or self.output_model

        if pydantic_model_to_use is str or pydantic_model_to_use is None:
            return self.text_completion(input_data, **runtime_kwargs)
        return self.pydantic_parse(
            input_data,
            response_model=response_model,
            **runtime_kwargs,
        )

    # Backward compatibility aliases
    def text(self, *args, **kwargs) -> List[Dict[str, Any]]:
        """Alias for text_completion() for backward compatibility."""
        return self.text_completion(*args, **kwargs)

    def parse(self, *args, **kwargs) -> List[Dict[str, Any]]:
        """Alias for pydantic_parse() for backward compatibility."""
        return self.pydantic_parse(*args, **kwargs)

    @classmethod
    def from_prompt_builder(
        cls,
        builder: BasePromptBuilder,
        client: Union[OpenAI, int, str, None] = None,
        cache=True,
        is_reasoning_model: bool = False,
        lora_path: Optional[str] = None,
        vllm_cmd: Optional[str] = None,
        vllm_timeout: int = 120,
        vllm_reuse: bool = True,
        **model_kwargs,
    ) -> "LLM":
        """Create an instance configured from a ``BasePromptBuilder``.

        The instruction, input model and output model are read from the
        builder; every other argument is forwarded to the constructor.
        Being a classmethod, it instantiates ``cls``, so calling it on a
        subclass yields an instance of that subclass.

        Args:
            builder: BasePromptBuilder instance.
            client: OpenAI client, port number, or base_url string.
            cache: Whether to use cached responses (default True).
            is_reasoning_model: Whether model is reasoning model (default False).
            lora_path: Optional path to LoRA adapter directory.
            vllm_cmd: Optional VLLM command to start server automatically.
            vllm_timeout: Timeout in seconds to wait for VLLM server
                (default 120; note the constructor's own default is 1200).
            vllm_reuse: If True (default), reuse existing server on target port.
            **model_kwargs: Additional model parameters.
        """
        instruction = builder.get_instruction()
        input_model = builder.get_input_model()
        output_model = builder.get_output_model()

        return cls(
            instruction=instruction,
            input_model=input_model,
            output_model=output_model,
            client=client,
            cache=cache,
            is_reasoning_model=is_reasoning_model,
            lora_path=lora_path,
            vllm_cmd=vllm_cmd,
            vllm_timeout=vllm_timeout,
            vllm_reuse=vllm_reuse,
            **model_kwargs,
        )
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLM-as-a-Judge implementation with template support and SFT export utilities.
|
|
3
|
+
|
|
4
|
+
This module provides a base class for creating LLM judges with structured
|
|
5
|
+
prompts, variable substitution, and export capabilities for fine-tuning.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from typing import Any, Dict, List, Optional, Type, Union
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
from ..chat_format import get_conversation_one_turn
|
|
12
|
+
from .llm import LLM
|
|
13
|
+
from .signature import Signature
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LLMSignature(LLM):
    """``LLM`` variant whose prompt and output schema come from a Signature."""

    def __init__(self, signature: Type[Signature], **kwargs):
        """Build the task from ``signature``.

        Args:
            signature: Signature class. Its ``get_instruction()`` and
                ``get_output_model()`` results are used as the defaults for
                the ``instruction`` and ``output_model`` constructor
                arguments.
            **kwargs: Forwarded to ``LLM.__init__``. Explicit ``instruction``
                or ``output_model`` entries take precedence over the values
                derived from the signature.
        """
        self.signature = signature
        # Starts empty; nothing in this class appends to it.
        self.sft_data: List[Dict[str, Any]] = []

        # Both signature accessors are always called; setdefault only fills
        # in the keys the caller left out.
        signature_instruction = signature.get_instruction()
        signature_output_model = signature.get_output_model()
        kwargs.setdefault("instruction", signature_instruction)
        kwargs.setdefault("output_model", signature_output_model)

        super().__init__(**kwargs)
|