speedy-utils 1.1.23__py3-none-any.whl → 1.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/llm.py ADDED
@@ -0,0 +1,413 @@
+ # type: ignore
+
+ """
+ Simplified LLM Task module for handling language model interactions with structured input/output.
+ """
+
+ import os
+ import subprocess
+ from typing import Any, Dict, List, Optional, Type, Union, cast
+
+ import requests
+ from loguru import logger
+ from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
+ from openai.types.chat import ChatCompletionMessageParam
+ from pydantic import BaseModel
+
+ from speedy_utils.common.utils_io import jdumps
+
+ from .utils import (
+     _extract_port_from_vllm_cmd,
+     _start_vllm_server,
+     _kill_vllm_on_port,
+     _is_server_running,
+     get_base_client,
+     _is_lora_path,
+     _get_port_from_client,
+     _load_lora_adapter,
+     _unload_lora_adapter,
+     kill_all_vllm_processes,
+     stop_vllm_process,
+ )
+ from .base_prompt_builder import BasePromptBuilder
+ from .mixins import (
+     TemperatureRangeMixin,
+     TwoStepPydanticMixin,
+     VLLMMixin,
+     ModelUtilsMixin,
+ )
+
+ # Type aliases for better readability
+ Messages = List[ChatCompletionMessageParam]
+
+
+ class LLM(
+     TemperatureRangeMixin,
+     TwoStepPydanticMixin,
+     VLLMMixin,
+     ModelUtilsMixin,
+ ):
+     """LLM task with structured input/output handling."""
+
+     def __init__(
+         self,
+         instruction: Optional[str] = None,
+         input_model: Union[Type[BaseModel], Type[str]] = str,
+         output_model: Optional[Union[Type[BaseModel], Type[str]]] = None,
+         client: Union[OpenAI, int, str, None] = None,
+         cache: bool = True,
+         is_reasoning_model: bool = False,
+         force_lora_unload: bool = False,
+         lora_path: Optional[str] = None,
+         vllm_cmd: Optional[str] = None,
+         vllm_timeout: int = 1200,
+         vllm_reuse: bool = True,
+         **model_kwargs,
+     ):
+         """Initialize LLM."""
+         self.instruction = instruction
+         self.input_model = input_model
+         self.output_model = output_model
+         self.model_kwargs = model_kwargs
+         self.is_reasoning_model = is_reasoning_model
+         self.force_lora_unload = force_lora_unload
+         self.lora_path = lora_path
+         self.vllm_cmd = vllm_cmd
+         self.vllm_timeout = vllm_timeout
+         self.vllm_reuse = vllm_reuse
+         self.vllm_process: Optional[subprocess.Popen] = None
+         self.last_ai_response = None  # Store raw response from client
+         self.cache = cache
+
+         # Handle VLLM server startup if vllm_cmd is provided
+         if self.vllm_cmd:
+             self._setup_vllm_server()
+
+             # Set client to use the VLLM server port if not explicitly provided
+             port = _extract_port_from_vllm_cmd(self.vllm_cmd)
+             if client is None:
+                 client = port
+
+         self.client = get_base_client(client, cache=cache, vllm_cmd=self.vllm_cmd, vllm_process=self.vllm_process)
+         # Check connection of client
+         try:
+             self.client.models.list()
+         except Exception as e:
+             logger.error(f"Failed to connect to OpenAI client: {str(e)}, base_url={self.client.base_url}")
+             raise e
+
+         if not self.model_kwargs.get("model", ""):
+             self.model_kwargs["model"] = self.client.models.list().data[0].id
+
+         # Handle LoRA loading if lora_path is provided
+         if self.lora_path:
+             self._load_lora_adapter()
+
+     def __enter__(self):
+         """Context manager entry."""
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """Context manager exit with cleanup."""
+         self.cleanup_vllm_server()
+
+     def _prepare_input(self, input_data: Union[str, BaseModel, List[Dict]]) -> Messages:
+         """Convert input to messages format."""
+         if isinstance(input_data, list):
+             assert isinstance(input_data[0], dict) and "role" in input_data[0], (
+                 "If input_data is a list, it must be a list of messages with 'role' and 'content' keys."
+             )
+             return cast(Messages, input_data)
+         else:
+             # Convert input to string format
+             if isinstance(input_data, str):
+                 user_content = input_data
+             elif hasattr(input_data, "model_dump_json"):
+                 user_content = input_data.model_dump_json()
+             elif isinstance(input_data, dict):
+                 user_content = jdumps(input_data)
+             else:
+                 user_content = str(input_data)
+
+             # Build messages
+             messages = (
+                 [
+                     {"role": "system", "content": self.instruction},
+                 ]
+                 if self.instruction is not None
+                 else []
+             )
+
+             messages.append({"role": "user", "content": user_content})
+             return cast(Messages, messages)
+     def text_completion(self, input_data: Union[str, BaseModel, List[Dict]], **runtime_kwargs) -> List[Dict[str, Any]]:
+         """Execute LLM task and return text responses."""
+         # Prepare messages
+         messages = self._prepare_input(input_data)
+
+         # Merge runtime kwargs with default model kwargs (runtime takes precedence)
+         effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
+         model_name = effective_kwargs.get("model", self.model_kwargs["model"])
+
+         # Drop the model name from the kwargs forwarded to the API call
+         api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}
+
+         try:
+             completion = self.client.chat.completions.create(model=model_name, messages=messages, **api_kwargs)
+             # Store raw response from client
+             self.last_ai_response = completion
+         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+             logger.error(error_msg)
+             raise
+         except Exception as e:
+             is_length_error = "Length" in str(e) or "maximum context length" in str(e)
+             if is_length_error:
+                 raise ValueError(f"Input too long for model {model_name}. Error: {str(e)[:100]}...")
+             # Re-raise all other exceptions
+             raise
+
+         results: List[Dict[str, Any]] = []
+         for choice in completion.choices:
+             choice_messages = cast(
+                 Messages,
+                 messages + [{"role": "assistant", "content": choice.message.content}],
+             )
+             result_dict = {"parsed": choice.message.content, "messages": choice_messages}
+
+             # Add reasoning content if this is a reasoning model
+             if self.is_reasoning_model and hasattr(choice.message, "reasoning_content"):
+                 result_dict["reasoning_content"] = choice.message.reasoning_content
+
+             results.append(result_dict)
+         return results
+     def pydantic_parse(
+         self,
+         input_data: Union[str, BaseModel, List[Dict]],
+         response_model: Optional[Type[BaseModel] | Type[str]] = None,
+         **runtime_kwargs,
+     ) -> List[Dict[str, Any]]:
+         """Execute LLM task and return parsed Pydantic model responses."""
+         # Prepare messages
+         messages = self._prepare_input(input_data)
+
+         # Merge runtime kwargs with default model kwargs (runtime takes precedence)
+         effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
+         model_name = effective_kwargs.get("model", self.model_kwargs["model"])
+
+         # Drop the model name from the kwargs forwarded to the API call
+         api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}
+
+         pydantic_model_to_use_opt = response_model or self.output_model
+         if pydantic_model_to_use_opt is None:
+             raise ValueError(
+                 "No response model specified. Either set output_model in constructor or pass response_model parameter."
+             )
+         pydantic_model_to_use: Type[BaseModel] = cast(Type[BaseModel], pydantic_model_to_use_opt)
+         try:
+             completion = self.client.chat.completions.parse(
+                 model=model_name,
+                 messages=messages,
+                 response_format=pydantic_model_to_use,
+                 **api_kwargs,
+             )
+             # Store raw response from client
+             self.last_ai_response = completion
+         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
+             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
+             logger.error(error_msg)
+             raise
+         except Exception as e:
+             is_length_error = "Length" in str(e) or "maximum context length" in str(e)
+             if is_length_error:
+                 raise ValueError(f"Input too long for model {model_name}. Error: {str(e)[:100]}...")
+             # Re-raise all other exceptions
+             raise
+
+         results: List[Dict[str, Any]] = []
+         for choice in completion.choices:  # type: ignore[attr-defined]
+             choice_messages = cast(
+                 Messages,
+                 messages + [{"role": "assistant", "content": choice.message.content}],
+             )
+
+             # Ensure consistent Pydantic model output for both fresh and cached
+             # responses: cached responses deserialize to plain dicts, so
+             # re-validate anything that is not already the target model.
+             parsed_content = choice.message.parsed  # type: ignore[attr-defined]
+             if not isinstance(parsed_content, pydantic_model_to_use):
+                 parsed_content = pydantic_model_to_use.model_validate(parsed_content)
+
+             result_dict = {"parsed": parsed_content, "messages": choice_messages}
+
+             # Add reasoning content if this is a reasoning model
+             if self.is_reasoning_model and hasattr(choice.message, "reasoning_content"):
+                 result_dict["reasoning_content"] = choice.message.reasoning_content
+
+             results.append(result_dict)
+         return results
+     def __call__(
+         self,
+         input_data: Union[str, BaseModel, List[Dict]],
+         response_model: Optional[Type[BaseModel] | Type[str]] = None,
+         two_step_parse_pydantic: bool = False,
+         temperature_ranges: Optional[tuple[float, float]] = None,
+         n: int = 1,
+         cache: Optional[bool] = None,
+         **openai_client_kwargs,
+     ) -> List[Dict[str, Any]]:
+         """
+         Execute LLM task.
+
+         Args:
+             input_data: Input data (string, BaseModel, or message list)
+             response_model: Optional response model override
+             two_step_parse_pydantic: Use two-step parsing (text then parse)
+             temperature_ranges: If set, tuple of (min_temp, max_temp) to sample
+             n: Number of temperature samples (only used with temperature_ranges, must be >= 2)
+             cache: Optional per-call override of the client's response cache
+             **openai_client_kwargs: Additional runtime parameters forwarded to the client
+
+         Returns:
+             List of response dictionaries
+         """
+         if cache is not None:
+             if hasattr(self.client, "set_cache"):
+                 self.client.set_cache(cache)
+             else:
+                 logger.warning("Client does not support caching.")
+
+         # Handle temperature range sampling
+         if temperature_ranges is not None:
+             if n < 2:
+                 raise ValueError(f"n must be >= 2 when using temperature_ranges, got {n}")
+             return self.temperature_range_sampling(
+                 input_data,
+                 temperature_ranges=temperature_ranges,
+                 n=n,
+                 response_model=response_model,
+                 **openai_client_kwargs,
+             )
+         openai_client_kwargs["n"] = n
+
+         # Handle two-step Pydantic parsing
+         pydantic_model = response_model or self.output_model
+         if two_step_parse_pydantic and pydantic_model not in (str, None):
+             choices = self.two_step_pydantic_parse(
+                 input_data,
+                 response_model=pydantic_model,
+                 **openai_client_kwargs,
+             )
+         else:
+             choices = self.__inner_call__(
+                 input_data,
+                 response_model=response_model,
+                 two_step_parse_pydantic=False,
+                 **openai_client_kwargs,
+             )
+
+         # Track conversation history (keep at most the last 100 conversations)
+         _last_conv = choices[0]["messages"] if choices else []
+         if not hasattr(self, "_last_conversations"):
+             self._last_conversations = []
+         else:
+             self._last_conversations = self._last_conversations[-100:]
+         self._last_conversations.append(_last_conv)
+         return choices
+     def inspect_history(self, idx: int = -1, k_last_messages: int = 2) -> List[Dict[str, Any]]:
+         """Inspect the message history of a specific response choice."""
+         if hasattr(self, "_last_conversations"):
+             from llm_utils import show_chat_v2
+
+             conv = self._last_conversations[idx]
+             if k_last_messages > 0:
+                 conv = conv[-k_last_messages:]
+             return show_chat_v2(conv)
+         else:
+             raise ValueError("No message history available. Make a call first.")
+
+     def __inner_call__(
+         self,
+         input_data: Union[str, BaseModel, List[Dict]],
+         response_model: Optional[Type[BaseModel] | Type[str]] = None,
+         two_step_parse_pydantic: bool = False,
+         **runtime_kwargs,
+     ) -> List[Dict[str, Any]]:
+         """
+         Internal call handler. Delegates to text_completion() or pydantic_parse()
+         based on the resolved response model.
+
+         Note: two_step_parse_pydantic is deprecated here; use the public
+         __call__ method, which routes to the mixin.
+         """
+         pydantic_model_to_use = response_model or self.output_model
+
+         if pydantic_model_to_use is str or pydantic_model_to_use is None:
+             return self.text_completion(input_data, **runtime_kwargs)
+         else:
+             return self.pydantic_parse(
+                 input_data,
+                 response_model=response_model,
+                 **runtime_kwargs,
+             )
+
+     # Backward compatibility aliases
+     def text(self, *args, **kwargs) -> List[Dict[str, Any]]:
+         """Alias for text_completion() for backward compatibility."""
+         return self.text_completion(*args, **kwargs)
+
+     def parse(self, *args, **kwargs) -> List[Dict[str, Any]]:
+         """Alias for pydantic_parse() for backward compatibility."""
+         return self.pydantic_parse(*args, **kwargs)
+     @classmethod
+     def from_prompt_builder(
+         cls,
+         builder: BasePromptBuilder,
+         client: Union[OpenAI, int, str, None] = None,
+         cache: bool = True,
+         is_reasoning_model: bool = False,
+         lora_path: Optional[str] = None,
+         vllm_cmd: Optional[str] = None,
+         vllm_timeout: int = 120,
+         vllm_reuse: bool = True,
+         **model_kwargs,
+     ) -> "LLM":
+         """
+         Create an LLM instance from a BasePromptBuilder instance.
+
+         This method extracts the instruction, input model, and output model
+         from the provided builder and initializes an LLM accordingly.
+
+         Args:
+             builder: BasePromptBuilder instance
+             client: OpenAI client, port number, or base_url string
+             cache: Whether to use cached responses (default True)
+             is_reasoning_model: Whether model is a reasoning model (default False)
+             lora_path: Optional path to LoRA adapter directory
+             vllm_cmd: Optional VLLM command to start server automatically
+             vllm_timeout: Timeout in seconds to wait for VLLM server (default 120)
+             vllm_reuse: If True (default), reuse existing server on target port
+             **model_kwargs: Additional model parameters
+         """
+         instruction = builder.get_instruction()
+         input_model = builder.get_input_model()
+         output_model = builder.get_output_model()
+
+         # Extract data from the builder to initialize the LLM
+         return cls(
+             instruction=instruction,
+             input_model=input_model,
+             output_model=output_model,
+             client=client,
+             cache=cache,
+             is_reasoning_model=is_reasoning_model,
+             lora_path=lora_path,
+             vllm_cmd=vllm_cmd,
+             vllm_timeout=vllm_timeout,
+             vllm_reuse=vllm_reuse,
+             **model_kwargs,
+         )
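
For orientation, a minimal usage sketch of the class above (not part of the published wheel). It assumes an OpenAI-compatible server on local port 8000; per the constructor, client may be an OpenAI instance, a port number, or a base_url string. The Sentiment schema and the prompt strings are illustrative only.

from pydantic import BaseModel

from llm_utils.lm.llm import LLM


class Sentiment(BaseModel):  # illustrative output schema
    label: str
    confidence: float


# Plain-text completion: with no output_model, __call__ routes to
# text_completion() and "parsed" holds the raw assistant text.
llm = LLM(instruction="Classify the sentiment of the input.", client=8000)
print(llm("I love this library!")[0]["parsed"])

# Structured output: response_model routes through client.chat.completions.parse
# and "parsed" holds a validated Sentiment instance.
print(llm("I love this library!", response_model=Sentiment)[0]["parsed"].label)

# Temperature-range sampling (TemperatureRangeMixin): n >= 2 samples
# spread across (min_temp, max_temp).
variants = llm("Write a tagline.", temperature_ranges=(0.2, 1.0), n=4)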
@@ -0,0 +1,35 @@
+ """
+ LLM-as-a-Judge implementation with template support and SFT export utilities.
+
+ This module provides a base class for creating LLM judges with structured
+ prompts, variable substitution, and export capabilities for fine-tuning.
+ """
+
+ import json
+ from typing import Any, Dict, List, Optional, Type, Union
+
+ from pydantic import BaseModel
+
+ from ..chat_format import get_conversation_one_turn
+ from .llm import LLM
+ from .signature import Signature
+
+
+ class LLMSignature(LLM):
+     """Base class for LLM judges with template support and SFT export."""
+
+     def __init__(self, signature: Type[Signature], **kwargs):
+         """
+         Initialize LLMSignature.
+
+         Args:
+             signature: Signature class providing the instruction and output model
+             **kwargs: Additional arguments passed to LLM
+         """
+         self.signature = signature
+         self.sft_data: List[Dict[str, Any]] = []  # Store SFT training examples
+
+         # Set instruction and output model from the signature unless overridden
+         kwargs.setdefault("instruction", signature.get_instruction())
+         kwargs.setdefault("output_model", signature.get_output_model())
+
+         super().__init__(**kwargs)
+ super().__init__(**kwargs)