speedy-utils 1.1.23__py3-none-any.whl → 1.1.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/lm/llm_task.py DELETED
@@ -1,614 +0,0 @@
- # type: ignore
-
- """
- Simplified LLM Task module for handling language model interactions with structured input/output.
- """
-
- import os
- import subprocess
- from typing import Any, Dict, List, Optional, Type, Union, cast
-
- import requests
- from loguru import logger
- from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
- from openai.types.chat import ChatCompletionMessageParam
- from pydantic import BaseModel
-
- from .utils import (
-     _extract_port_from_vllm_cmd,
-     _start_vllm_server,
-     _kill_vllm_on_port,
-     _is_server_running,
-     get_base_client,
-     _is_lora_path,
-     _get_port_from_client,
-     _load_lora_adapter,
-     _unload_lora_adapter,
-     kill_all_vllm_processes,
-     stop_vllm_process,
- )
- from .base_prompt_builder import BasePromptBuilder
-
- # Type aliases for better readability
- Messages = List[ChatCompletionMessageParam]
-
-
- class LLMTask:
-     """LLM task with structured input/output handling."""
-
-     def __init__(
-         self,
-         instruction: Optional[str] = None,
-         input_model: Union[Type[BaseModel], type[str]] = str,
-         output_model: Optional[Union[Type[BaseModel], Type[str]]] = None,
-         client: Union[OpenAI, int, str, None] = None,
-         cache=True,
-         is_reasoning_model: bool = False,
-         force_lora_unload: bool = False,
-         lora_path: Optional[str] = None,
-         vllm_cmd: Optional[str] = None,
-         vllm_timeout: int = 1200,
-         vllm_reuse: bool = True,
-         **model_kwargs,
-     ):
-         """Initialize LLMTask."""
-         self.instruction = instruction
-         self.input_model = input_model
-         self.output_model = output_model
-         self.model_kwargs = model_kwargs
-         self.is_reasoning_model = is_reasoning_model
-         self.force_lora_unload = force_lora_unload
-         self.lora_path = lora_path
-         self.vllm_cmd = vllm_cmd
-         self.vllm_timeout = vllm_timeout
-         self.vllm_reuse = vllm_reuse
-         self.vllm_process: Optional[subprocess.Popen] = None
-         self.last_ai_response = None  # Store raw response from client
-
-         # Handle VLLM server startup if vllm_cmd is provided
-         if self.vllm_cmd:
-             port = _extract_port_from_vllm_cmd(self.vllm_cmd)
-             reuse_existing = False
-
-             if self.vllm_reuse:
-                 try:
-                     reuse_client = get_base_client(port, cache=False)
-                     models_response = reuse_client.models.list()
-                     if getattr(models_response, "data", None):
-                         reuse_existing = True
-                         logger.info(
-                             f"VLLM server already running on port {port}, "
-                             "reusing existing server (vllm_reuse=True)"
-                         )
-                     else:
-                         logger.info(
-                             f"No models returned from VLLM server on port {port}; "
-                             "starting a new server"
-                         )
-                 except Exception as exc:
-                     logger.info(
-                         f"Unable to reach VLLM server on port {port} (list_models failed): {exc}. "
-                         "Starting a new server."
-                     )
-
-             if not self.vllm_reuse:
-                 if _is_server_running(port):
-                     logger.info(
-                         f"VLLM server already running on port {port}, killing it first (vllm_reuse=False)"
-                     )
-                     _kill_vllm_on_port(port)
-                 logger.info(f"Starting new VLLM server on port {port}")
-                 self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
-             elif not reuse_existing:
-                 logger.info(f"Starting VLLM server on port {port}")
-                 self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
-
-             # Set client to use the VLLM server port if not explicitly provided
-             if client is None:
-                 client = port
-
-         self.client = get_base_client(client, cache=cache, vllm_cmd=self.vllm_cmd, vllm_process=self.vllm_process)
-         # check connection of client
-         try:
-             self.client.models.list()
-         except Exception as e:
-             logger.error(f"Failed to connect to OpenAI client: {str(e)}, base_url={self.client.base_url}")
-             raise e
-
-         if not self.model_kwargs.get("model", ""):
-             self.model_kwargs["model"] = self.client.models.list().data[0].id
-
-         # Handle LoRA loading if lora_path is provided
-         if self.lora_path:
-             self._load_lora_adapter()
-
-     def cleanup_vllm_server(self) -> None:
-         """Stop the VLLM server process if it was started by this instance."""
-         if self.vllm_process is not None:
-             stop_vllm_process(self.vllm_process)
-             self.vllm_process = None
-
-     def __enter__(self):
-         """Context manager entry."""
-         return self
-
-     def __exit__(self, exc_type, exc_val, exc_tb):
-         """Context manager exit with cleanup."""
-         self.cleanup_vllm_server()
-
-     def _load_lora_adapter(self) -> None:
-         """
-         Load LoRA adapter from the specified lora_path.
-
-         This method:
-         1. Validates that lora_path is a valid LoRA directory
-         2. Checks if LoRA is already loaded (unless force_lora_unload is True)
-         3. Loads the LoRA adapter and updates the model name
-         """
-         if not self.lora_path:
-             return
-
-         if not _is_lora_path(self.lora_path):
-             raise ValueError(
-                 f"Invalid LoRA path '{self.lora_path}': "
-                 "Directory must contain 'adapter_config.json'"
-             )
-
-         logger.info(f"Loading LoRA adapter from: {self.lora_path}")
-
-         # Get the expected LoRA name (basename of the path)
-         lora_name = os.path.basename(self.lora_path.rstrip('/\\'))
-         if not lora_name:  # Handle edge case of empty basename
-             lora_name = os.path.basename(os.path.dirname(self.lora_path))
-
-         # Get list of available models to check if LoRA is already loaded
-         try:
-             available_models = [m.id for m in self.client.models.list().data]
-         except Exception as e:
-             logger.warning(f"Failed to list models, proceeding with LoRA load: {str(e)[:100]}")
-             available_models = []
-
-         # Check if LoRA is already loaded
-         if lora_name in available_models and not self.force_lora_unload:
-             logger.info(f"LoRA adapter '{lora_name}' is already loaded, using existing model")
-             self.model_kwargs["model"] = lora_name
-             return
-
-         # Force unload if requested
-         if self.force_lora_unload and lora_name in available_models:
-             logger.info(f"Force unloading LoRA adapter '{lora_name}' before reloading")
-             port = _get_port_from_client(self.client)
-             if port is not None:
-                 try:
-                     LLMTask.unload_lora(port, lora_name)
-                     logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
-                 except Exception as e:
-                     logger.warning(f"Failed to unload LoRA adapter: {str(e)[:100]}")
-
-         # Get port from client for API calls
-         port = _get_port_from_client(self.client)
-         if port is None:
-             raise ValueError(
-                 f"Cannot load LoRA adapter '{self.lora_path}': "
-                 "Unable to determine port from client base_url. "
-                 "LoRA loading requires a client initialized with port number."
-             )
-
-         try:
-             # Load the LoRA adapter
-             loaded_lora_name = _load_lora_adapter(self.lora_path, port)
-             logger.info(f"Successfully loaded LoRA adapter: {loaded_lora_name}")
-
-             # Update model name to the loaded LoRA name
-             self.model_kwargs["model"] = loaded_lora_name
-
-         except requests.RequestException as e:
-             # Check if the error is due to LoRA already being loaded
-             error_msg = str(e)
-             if "400" in error_msg or "Bad Request" in error_msg:
-                 logger.info(f"LoRA adapter may already be loaded, attempting to use '{lora_name}'")
-                 # Refresh the model list to check if it's now available
-                 try:
-                     updated_models = [m.id for m in self.client.models.list().data]
-                     if lora_name in updated_models:
-                         logger.info(f"Found LoRA adapter '{lora_name}' in updated model list")
-                         self.model_kwargs["model"] = lora_name
-                         return
-                 except Exception:
-                     pass  # Fall through to original error
-
-             raise ValueError(
-                 f"Failed to load LoRA adapter from '{self.lora_path}': {error_msg[:100]}"
-             )
-
-     def unload_lora_adapter(self, lora_path: str) -> None:
-         """
-         Unload a LoRA adapter.
-
-         Args:
-             lora_path: Path to the LoRA adapter directory to unload
-
-         Raises:
-             ValueError: If unable to determine port from client
-         """
-         port = _get_port_from_client(self.client)
-         if port is None:
-             raise ValueError(
-                 "Cannot unload LoRA adapter: "
-                 "Unable to determine port from client base_url. "
-                 "LoRA operations require a client initialized with port number."
-             )
-
-         _unload_lora_adapter(lora_path, port)
-         lora_name = os.path.basename(lora_path.rstrip('/\\'))
-         logger.info(f"Unloaded LoRA adapter: {lora_name}")
-
-     @staticmethod
-     def unload_lora(port: int, lora_name: str) -> None:
-         """Static method to unload a LoRA adapter by name.
-
-         Args:
-             port: Port number for the API endpoint
-             lora_name: Name of the LoRA adapter to unload
-
-         Raises:
-             requests.RequestException: If the API call fails
-         """
-         try:
-             response = requests.post(
-                 f'http://localhost:{port}/v1/unload_lora_adapter',
-                 headers={'accept': 'application/json', 'Content-Type': 'application/json'},
-                 json={"lora_name": lora_name, "lora_int_id": 0}
-             )
-             response.raise_for_status()
-             logger.info(f"Successfully unloaded LoRA adapter: {lora_name}")
-         except requests.RequestException as e:
-             logger.error(f"Error unloading LoRA adapter '{lora_name}': {str(e)[:100]}")
-             raise
-
-     def _prepare_input(self, input_data: Union[str, BaseModel, List[Dict]]) -> Messages:
-         """Convert input to messages format."""
-         if isinstance(input_data, list):
-             assert isinstance(input_data[0], dict) and "role" in input_data[0], (
-                 "If input_data is a list, it must be a list of messages with 'role' and 'content' keys."
-             )
-             return cast(Messages, input_data)
-         else:
-             # Convert input to string format
-             if isinstance(input_data, str):
-                 user_content = input_data
-             elif hasattr(input_data, "model_dump_json"):
-                 user_content = input_data.model_dump_json()
-             elif isinstance(input_data, dict):
-                 user_content = str(input_data)
-             else:
-                 user_content = str(input_data)
-
-             # Build messages
-             messages = (
-                 [
-                     {"role": "system", "content": self.instruction},
-                 ]
-                 if self.instruction is not None
-                 else []
-             )
-
-             messages.append({"role": "user", "content": user_content})
-             return cast(Messages, messages)
-
-     def text_completion(
-         self, input_data: Union[str, BaseModel, list[Dict]], **runtime_kwargs
-     ) -> List[Dict[str, Any]]:
-         """Execute LLM task and return text responses."""
-         # Prepare messages
-         messages = self._prepare_input(input_data)
-
-         # Merge runtime kwargs with default model kwargs (runtime takes precedence)
-         effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
-         model_name = effective_kwargs.get("model", self.model_kwargs["model"])
-
-         # Extract model name from kwargs for API call
-         api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}
-
-         try:
-             completion = self.client.chat.completions.create(
-                 model=model_name, messages=messages, **api_kwargs
-             )
-             # Store raw response from client
-             self.last_ai_response = completion
-         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
-             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
-             logger.error(error_msg)
-             raise
-         except Exception as e:
-             is_length_error = "Length" in str(e) or "maximum context length" in str(e)
-             if is_length_error:
-                 raise ValueError(
-                     f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
-                 )
-             # Re-raise all other exceptions
-             raise
-         # print(completion)
-
-         results: List[Dict[str, Any]] = []
-         for choice in completion.choices:
-             choice_messages = cast(
-                 Messages,
-                 messages + [{"role": "assistant", "content": choice.message.content}],
-             )
-             result_dict = {"parsed": choice.message.content, "messages": choice_messages}
-
-             # Add reasoning content if this is a reasoning model
-             if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
-                 result_dict["reasoning_content"] = choice.message.reasoning_content
-
-             results.append(result_dict)
-         return results
-
-     def pydantic_parse(
-         self,
-         input_data: Union[str, BaseModel, list[Dict]],
-         response_model: Optional[Type[BaseModel]] | Type[str] = None,
-         **runtime_kwargs,
-     ) -> List[Dict[str, Any]]:
-         """Execute LLM task and return parsed Pydantic model responses."""
-         # Prepare messages
-         messages = self._prepare_input(input_data)
-
-         # Merge runtime kwargs with default model kwargs (runtime takes precedence)
-         effective_kwargs = {**self.model_kwargs, **runtime_kwargs}
-         model_name = effective_kwargs.get("model", self.model_kwargs["model"])
-
-         # Extract model name from kwargs for API call
-         api_kwargs = {k: v for k, v in effective_kwargs.items() if k != "model"}
-
-         pydantic_model_to_use_opt = response_model or self.output_model
-         if pydantic_model_to_use_opt is None:
-             raise ValueError(
-                 "No response model specified. Either set output_model in constructor or pass response_model parameter."
-             )
-         pydantic_model_to_use: Type[BaseModel] = cast(
-             Type[BaseModel], pydantic_model_to_use_opt
-         )
-         try:
-             completion = self.client.chat.completions.parse(
-                 model=model_name,
-                 messages=messages,
-                 response_format=pydantic_model_to_use,
-                 **api_kwargs,
-             )
-             # Store raw response from client
-             self.last_ai_response = completion
-         except (AuthenticationError, RateLimitError, BadRequestError) as exc:
-             error_msg = f"OpenAI API error ({type(exc).__name__}): {exc}"
-             logger.error(error_msg)
-             raise
-         except Exception as e:
-             is_length_error = "Length" in str(e) or "maximum context length" in str(e)
-             if is_length_error:
-                 raise ValueError(
-                     f"Input too long for model {model_name}. Error: {str(e)[:100]}..."
-                 )
-             # Re-raise all other exceptions
-             raise
-
-         results: List[Dict[str, Any]] = []
-         for choice in completion.choices:  # type: ignore[attr-defined]
-             choice_messages = cast(
-                 Messages,
-                 messages + [{"role": "assistant", "content": choice.message.content}],
-             )
-
-             # Ensure consistent Pydantic model output for both fresh and cached responses
-             parsed_content = choice.message.parsed  # type: ignore[attr-defined]
-             if isinstance(parsed_content, dict):
-                 # Cached response: validate dict back to Pydantic model
-                 parsed_content = pydantic_model_to_use.model_validate(parsed_content)
-             elif not isinstance(parsed_content, pydantic_model_to_use):
-                 # Fallback: ensure it's the correct type
-                 parsed_content = pydantic_model_to_use.model_validate(parsed_content)
-
-             result_dict = {"parsed": parsed_content, "messages": choice_messages}
-
-             # Add reasoning content if this is a reasoning model
-             if self.is_reasoning_model and hasattr(choice.message, 'reasoning_content'):
-                 result_dict["reasoning_content"] = choice.message.reasoning_content
-
-             results.append(result_dict)
-         return results
-
-     def __call__(
-         self,
-         input_data: Union[str, BaseModel, list[Dict]],
-         response_model: Optional[Type[BaseModel] | Type[str]] = None,
-         two_step_parse_pydantic=False,
-         **runtime_kwargs,
-     ) -> List[Dict[str, Any]]:
-         """Execute LLM task. Delegates to text() or parse() based on output_model."""
-         pydantic_model_to_use = response_model or self.output_model
-
-         if pydantic_model_to_use is str or pydantic_model_to_use is None:
-             return self.text_completion(input_data, **runtime_kwargs)
-         elif two_step_parse_pydantic:
-             # step 1: get text completions
-             results = self.text_completion(input_data, **runtime_kwargs)
-             parsed_results = []
-             for result in results:
-                 response_text = result["parsed"]
-                 messages = result["messages"]
-                 # check if the pydantic_model_to_use is validated
-                 if "</think>" in response_text:
-                     response_text = response_text.split("</think>")[1]
-                 try:
-                     parsed = pydantic_model_to_use.model_validate_json(response_text)
-                 except Exception:
-                     # Failed to parse JSON, falling back to LLM parsing
-                     # use model to parse the response_text
-                     _parsed_messages = [
-                         {
-                             "role": "system",
-                             "content": "You are a helpful assistant that extracts JSON from text.",
-                         },
-                         {
-                             "role": "user",
-                             "content": f"Extract JSON from the following text:\n{response_text}",
-                         },
-                     ]
-                     parsed_result = self.pydantic_parse(
-                         _parsed_messages,
-                         response_model=pydantic_model_to_use,
-                         **runtime_kwargs,
-                     )[0]
-                     parsed = parsed_result["parsed"]
-                 # ---
-                 parsed_results.append({"parsed": parsed, "messages": messages})
-             return parsed_results
-
-         else:
-             return self.pydantic_parse(
-                 input_data, response_model=response_model, **runtime_kwargs
-             )
-
-     # Backward compatibility aliases
-     def text(self, *args, **kwargs) -> List[Dict[str, Any]]:
-         """Alias for text_completion() for backward compatibility."""
-         return self.text_completion(*args, **kwargs)
-
-     def parse(self, *args, **kwargs) -> List[Dict[str, Any]]:
-         """Alias for pydantic_parse() for backward compatibility."""
-         return self.pydantic_parse(*args, **kwargs)
-
-     @classmethod
-     def from_prompt_builder(
-         cls, builder: BasePromptBuilder,
-         client: Union[OpenAI, int, str, None] = None,
-         cache=True,
-         is_reasoning_model: bool = False,
-         lora_path: Optional[str] = None,
-         vllm_cmd: Optional[str] = None,
-         vllm_timeout: int = 120,
-         vllm_reuse: bool = True,
-         **model_kwargs,
-     ) -> "LLMTask":
-         """
-         Create an LLMTask instance from a BasePromptBuilder instance.
-
-         This method extracts the instruction, input model, and output model
-         from the provided builder and initializes an LLMTask accordingly.
-
-         Args:
-             builder: BasePromptBuilder instance
-             client: OpenAI client, port number, or base_url string
-             cache: Whether to use cached responses (default True)
-             is_reasoning_model: Whether model is reasoning model (default False)
-             lora_path: Optional path to LoRA adapter directory
-             vllm_cmd: Optional VLLM command to start server automatically
-             vllm_timeout: Timeout in seconds to wait for VLLM server (default 120)
-             vllm_reuse: If True (default), reuse existing server on target port
-             **model_kwargs: Additional model parameters
-         """
-         instruction = builder.get_instruction()
-         input_model = builder.get_input_model()
-         output_model = builder.get_output_model()
-
-         # Extract data from the builder to initialize LLMTask
-         return LLMTask(
-             instruction=instruction,
-             input_model=input_model,
-             output_model=output_model,
-             client=client,
-             cache=cache,
-             is_reasoning_model=is_reasoning_model,
-             lora_path=lora_path,
-             vllm_cmd=vllm_cmd,
-             vllm_timeout=vllm_timeout,
-             vllm_reuse=vllm_reuse,
-             **model_kwargs,
-         )
-
-     @staticmethod
-     def list_models(client: Union[OpenAI, int, str, None] = None) -> List[str]:
-         """
-         List available models from the OpenAI client.
-
-         Args:
-             client: OpenAI client, port number, or base_url string
-
-         Returns:
-             List of available model names.
-         """
-         client = get_base_client(client, cache=False)
-         models = client.models.list().data
-         return [m.id for m in models]
-
-     @staticmethod
-     def kill_all_vllm() -> int:
-         """Kill all tracked VLLM server processes."""
-         return kill_all_vllm_processes()
-
-     @staticmethod
-     def kill_vllm_on_port(port: int) -> bool:
-         """
-         Kill VLLM server running on a specific port.
-
-         Args:
-             port: Port number to kill server on
-
-         Returns:
-             True if a server was killed, False if no server was running
-         """
-         return _kill_vllm_on_port(port)
-
-
- # Example usage:
- if __name__ == "__main__":
-     # Example 1: Using VLLM with reuse (default behavior)
-     vllm_command = (
-         "vllm serve saves/vng/dpo/01 -tp 4 --port 8001 "
-         "--gpu-memory-utilization 0.9 --served-model-name sft --quantization experts_int8"
-     )
-
-     print("Example 1: Using VLLM with server reuse (default)")
-     # Create LLM instance - will reuse existing server if running on port 8001
-     with LLMTask(vllm_cmd=vllm_command) as llm:  # vllm_reuse=True by default
-         result = llm.text("Hello, how are you?")
-         print("Response:", result[0]["parsed"])
-
-     print("\nExample 2: Force restart server (vllm_reuse=False)")
-     # This will kill any existing server on port 8001 and start fresh
-     with LLMTask(vllm_cmd=vllm_command, vllm_reuse=False) as llm:
-         result = llm.text("Tell me a joke")
-         print("Joke:", result[0]["parsed"])
-
-     print("\nExample 3: Multiple instances with reuse")
-     # First instance starts the server
-     llm1 = LLMTask(vllm_cmd=vllm_command)  # Starts server or reuses existing
-
-     # Second instance reuses the same server
-     llm2 = LLMTask(vllm_cmd=vllm_command)  # Reuses server on port 8001
-
-     try:
-         result1 = llm1.text("What's the weather like?")
-         result2 = llm2.text("How's the traffic?")
-         print("Weather response:", result1[0]["parsed"])
-         print("Traffic response:", result2[0]["parsed"])
-     finally:
-         # Only cleanup if we started the process
-         llm1.cleanup_vllm_server()
-         llm2.cleanup_vllm_server()  # Won't do anything if process not owned
-
-     print("\nExample 4: Different ports")
-     # These will start separate servers
-     llm_8001 = LLMTask(vllm_cmd="vllm serve model1 --port 8001", vllm_reuse=True)
-     llm_8002 = LLMTask(vllm_cmd="vllm serve model2 --port 8002", vllm_reuse=True)
-
-     print("\nExample 5: Kill all VLLM servers")
-     # Kill all tracked VLLM processes
-     killed_count = LLMTask.kill_all_vllm()
-     print(f"Killed {killed_count} VLLM servers")
-
-     print("\nYou can check VLLM server output at: /tmp/vllm.txt")
-     print("Server reuse behavior:")
-     print("- vllm_reuse=True (default): Reuse existing server on target port")
-     print("- vllm_reuse=False: Kill existing server first, then start fresh")
-
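
The example block in the removed file only exercises plain text completion; the class also supported structured output through a Pydantic output_model, routed to pydantic_parse() by __call__. A minimal sketch of that path, assuming an OpenAI-compatible server already listening on port 8001 and the pre-removal import path llm_utils.lm.llm_task; the Sentiment model and the prompt text are illustrative only, not part of the package.

# Hypothetical sketch (not part of the published diff): structured output with the removed LLMTask.
from pydantic import BaseModel

from llm_utils.lm.llm_task import LLMTask  # import path as of speedy-utils 1.1.23


class Sentiment(BaseModel):
    label: str
    confidence: float


# client=8001 assumes a compatible server is already running on that port; no vllm_cmd is passed.
task = LLMTask(
    instruction="Classify the sentiment of the user's message.",
    output_model=Sentiment,
    client=8001,
)
results = task("The new release fixed every issue I had.")
print(results[0]["parsed"])  # a Sentiment instance returned by pydantic_parse()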