speedy-utils 1.1.21__py3-none-any.whl → 1.1.22__py3-none-any.whl

llm_utils/__init__.py CHANGED
@@ -6,6 +6,15 @@ from llm_utils.lm.base_prompt_builder import BasePromptBuilder
 
 LLM = LLMTask
 
+# Convenience functions for killing VLLM servers
+def kill_all_vllm() -> int:
+    """Kill all tracked VLLM server processes. Returns number of processes killed."""
+    return LLMTask.kill_all_vllm()
+
+def kill_vllm_on_port(port: int) -> bool:
+    """Kill VLLM server on specific port. Returns True if server was killed."""
+    return LLMTask.kill_vllm_on_port(port)
+
 from .chat_format import (
     build_chatml_input,
     display_chat_messages_as_html,
@@ -35,5 +44,7 @@ __all__ = [
     "get_model_name",
     "VectorCache",
     "BasePromptBuilder",
-    "LLM"
+    "LLM",
+    "kill_all_vllm",
+    "kill_vllm_on_port"
 ]
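The new exports make the kill helpers callable directly from the `llm_utils` package. A minimal usage sketch (the port number is illustrative only):

```python
# Hypothetical usage of the convenience wrappers added to llm_utils/__init__.py.
import llm_utils

# Stop a single tracked VLLM server by port (returns True if one was killed).
was_killed = llm_utils.kill_vllm_on_port(8001)

# Stop every VLLM server process tracked by llm_utils (returns the number killed).
killed_count = llm_utils.kill_all_vllm()
print(was_killed, killed_count)
```

Both wrappers simply delegate to the corresponding `LLMTask` static methods shown further down in this diff.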
llm_utils/lm/llm_task.py CHANGED
@@ -5,6 +5,7 @@ Simplified LLM Task module for handling language model interactions with structu
 """
 
 import os
+import subprocess
 from typing import Any, Dict, List, Optional, Type, Union, cast
 
 import requests
@@ -13,177 +14,27 @@ from openai import OpenAI, AuthenticationError, BadRequestError, RateLimitError
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
+from .utils import (
+    _extract_port_from_vllm_cmd,
+    _start_vllm_server,
+    _kill_vllm_on_port,
+    _is_server_running,
+    get_base_client,
+    _is_lora_path,
+    _get_port_from_client,
+    _load_lora_adapter,
+    _unload_lora_adapter,
+    kill_all_vllm_processes,
+    stop_vllm_process,
+)
 from .base_prompt_builder import BasePromptBuilder
 
 # Type aliases for better readability
 Messages = List[ChatCompletionMessageParam]
 
 
-def get_base_client(
-    client: Union[OpenAI, int, str, None] = None, cache: bool = True, api_key="abc"
-) -> OpenAI:
-    """Get OpenAI client from port number, base_url string, or existing client."""
-    from llm_utils import MOpenAI
-
-    open_ai_class = OpenAI if not cache else MOpenAI
-    if client is None:
-        return open_ai_class()
-    elif isinstance(client, int):
-        return open_ai_class(base_url=f"http://localhost:{client}/v1", api_key=api_key)
-    elif isinstance(client, str):
-        return open_ai_class(base_url=client, api_key=api_key)
-    elif isinstance(client, OpenAI):
-        return client
-    else:
-        raise ValueError(
-            "Invalid client type. Must be OpenAI instance, port number (int), base_url (str), or None."
-        )
-
-
-def _is_lora_path(path: str) -> bool:
-    """Check if the given path is a LoRA adapter directory.
-
-    Args:
-        path: Path to check
-
-    Returns:
-        True if the path contains adapter_config.json, False otherwise
-    """
-    if not os.path.isdir(path):
-        return False
-    adapter_config_path = os.path.join(path, 'adapter_config.json')
-    return os.path.isfile(adapter_config_path)
-
-
-def _get_port_from_client(client: OpenAI) -> Optional[int]:
-    """Extract port number from OpenAI client base_url.
-
-    Args:
-        client: OpenAI client instance
-
-    Returns:
-        Port number if found, None otherwise
-    """
-    if hasattr(client, 'base_url') and client.base_url:
-        base_url = str(client.base_url)
-        if 'localhost:' in base_url:
-            try:
-                # Extract port from localhost:PORT/v1 format
-                port_part = base_url.split('localhost:')[1].split('/')[0]
-                return int(port_part)
-            except (IndexError, ValueError):
-                pass
-    return None
-
-
-def _load_lora_adapter(lora_path: str, port: int) -> str:
-    """Load a LoRA adapter from the specified path.
-
-    Args:
-        lora_path: Path to the LoRA adapter directory
-        port: Port number for the API endpoint
-
-    Returns:
-        Name of the loaded LoRA adapter
-
-    Raises:
-        requests.RequestException: If the API call fails
-    """
-    lora_name = os.path.basename(lora_path.rstrip('/\\'))
-    if not lora_name:  # Handle edge case of empty basename
-        lora_name = os.path.basename(os.path.dirname(lora_path))
-
-    response = requests.post(
-        f'http://localhost:{port}/v1/load_lora_adapter',
-        headers={'accept': 'application/json', 'Content-Type': 'application/json'},
-        json={"lora_name": lora_name, "lora_path": os.path.abspath(lora_path)}
-    )
-    response.raise_for_status()
-    return lora_name
-
-
-def _unload_lora_adapter(lora_path: str, port: int) -> None:
-    """Unload the current LoRA adapter.
-
-    Args:
-        lora_path: Path to the LoRA adapter directory
-        port: Port number for the API endpoint
-    """
-    try:
-        lora_name = os.path.basename(lora_path.rstrip('/\\'))
-        if not lora_name:  # Handle edge case of empty basename
-            lora_name = os.path.basename(os.path.dirname(lora_path))
-
-        response = requests.post(
-            f'http://localhost:{port}/v1/unload_lora_adapter',
-            headers={'accept': 'application/json', 'Content-Type': 'application/json'},
-            json={"lora_name": lora_name, "lora_int_id": 0}
-        )
-        response.raise_for_status()
-    except requests.RequestException as e:
-        logger.warning(f"Error unloading LoRA adapter: {str(e)[:100]}")
-
-
 class LLMTask:
-    """
-    Language model task with structured input/output and optional system instruction.
-
-    Supports str or Pydantic models for both input and output. Automatically handles
-    message formatting and response parsing.
-
-    Two main APIs:
-    - text(): Returns raw text responses as list of dicts (alias for text_completion)
-    - parse(): Returns parsed Pydantic model responses as list of dicts (alias for pydantic_parse)
-    - __call__(): Backward compatibility method that delegates based on output_model
-
-    Example:
-        ```python
-        from pydantic import BaseModel
-        from llm_utils.lm.llm_task import LLMTask
-
-        class EmailOutput(BaseModel):
-            content: str
-            estimated_read_time: int
-
-        # Set up task with Pydantic output model
-        task = LLMTask(
-            instruction="Generate professional email content.",
-            output_model=EmailOutput,
-            client=OpenAI(),
-            temperature=0.7
-        )
-
-        # Use parse() for structured output
-        results = task.parse("Write a meeting follow-up email")
-        result = results[0]
-        print(result["parsed"].content, result["parsed"].estimated_read_time)
-
-        # Use text() for plain text output
-        results = task.text("Write a meeting follow-up email")
-        text_result = results[0]
-        print(text_result["parsed"])
-
-        # Multiple responses
-        results = task.parse("Write a meeting follow-up email", n=3)
-        for result in results:
-            print(f"Content: {result['parsed'].content}")
-
-        # Override parameters at runtime
-        results = task.text(
-            "Write a meeting follow-up email",
-            temperature=0.9,
-            n=2,
-            max_tokens=500
-        )
-        for result in results:
-            print(result["parsed"])
-
-        # Backward compatibility (uses output_model to choose method)
-        results = task("Write a meeting follow-up email")  # Calls parse()
-        result = results[0]
-        print(result["parsed"].content)
-        ```
-    """
+    """LLM task with structured input/output handling."""
 
     def __init__(
         self,
@@ -195,29 +46,12 @@ class LLMTask:
         is_reasoning_model: bool = False,
         force_lora_unload: bool = False,
         lora_path: Optional[str] = None,
+        vllm_cmd: Optional[str] = None,
+        vllm_timeout: int = 1200,
+        vllm_reuse: bool = True,
         **model_kwargs,
     ):
-        """
-        Initialize the LLMTask.
-
-        Args:
-            instruction: Optional system instruction for the task
-            input_model: Input type (str or BaseModel subclass)
-            output_model: Output BaseModel type
-            client: OpenAI client, port number, or base_url string
-            cache: Whether to use cached responses (default True)
-            is_reasoning_model: Whether the model is a reasoning model (o1-preview, o1-mini, etc.)
-                that outputs reasoning_content separately from content (default False)
-            force_lora_unload: If True, forces unloading of any existing LoRA adapter before loading
-                a new one when lora_path is provided (default False)
-            lora_path: Optional path to LoRA adapter directory. If provided, will load the LoRA
-                and use it as the model. Takes precedence over model parameter.
-            **model_kwargs: Additional model parameters including:
-                - temperature: Controls randomness (0.0 to 2.0)
-                - n: Number of responses to generate (when n > 1, returns list)
-                - max_tokens: Maximum tokens in response
-                - model: Model name (auto-detected if not provided)
-        """
+        """Initialize LLMTask."""
         self.instruction = instruction
         self.input_model = input_model
         self.output_model = output_model
@@ -225,15 +59,55 @@
         self.is_reasoning_model = is_reasoning_model
         self.force_lora_unload = force_lora_unload
         self.lora_path = lora_path
+        self.vllm_cmd = vllm_cmd
+        self.vllm_timeout = vllm_timeout
+        self.vllm_reuse = vllm_reuse
+        self.vllm_process: Optional[subprocess.Popen] = None
         self.last_ai_response = None  # Store raw response from client
 
-        # if cache:
-        #     print("Caching is enabled will use llm_utils.MOpenAI")
+        # Handle VLLM server startup if vllm_cmd is provided
+        if self.vllm_cmd:
+            port = _extract_port_from_vllm_cmd(self.vllm_cmd)
+            reuse_existing = False
 
-        #     self.client = MOpenAI(base_url=base_url, api_key=api_key)
-        # else:
-        #     self.client = OpenAI(base_url=base_url, api_key=api_key)
-        self.client = get_base_client(client, cache=cache)
+            if self.vllm_reuse:
+                try:
+                    reuse_client = get_base_client(port, cache=False)
+                    models_response = reuse_client.models.list()
+                    if getattr(models_response, "data", None):
+                        reuse_existing = True
+                        logger.info(
+                            f"VLLM server already running on port {port}, "
+                            "reusing existing server (vllm_reuse=True)"
+                        )
+                    else:
+                        logger.info(
+                            f"No models returned from VLLM server on port {port}; "
+                            "starting a new server"
+                        )
+                except Exception as exc:
+                    logger.info(
+                        f"Unable to reach VLLM server on port {port} (list_models failed): {exc}. "
+                        "Starting a new server."
+                    )
+
+            if not self.vllm_reuse:
+                if _is_server_running(port):
+                    logger.info(
+                        f"VLLM server already running on port {port}, killing it first (vllm_reuse=False)"
+                    )
+                    _kill_vllm_on_port(port)
+                logger.info(f"Starting new VLLM server on port {port}")
+                self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+            elif not reuse_existing:
+                logger.info(f"Starting VLLM server on port {port}")
+                self.vllm_process = _start_vllm_server(self.vllm_cmd, self.vllm_timeout)
+
+            # Set client to use the VLLM server port if not explicitly provided
+            if client is None:
+                client = port
+
+        self.client = get_base_client(client, cache=cache, vllm_cmd=self.vllm_cmd, vllm_process=self.vllm_process)
         # check connection of client
         try:
             self.client.models.list()
@@ -247,8 +121,20 @@
         # Handle LoRA loading if lora_path is provided
         if self.lora_path:
             self._load_lora_adapter()
-
-        print(self.model_kwargs)
+
+    def cleanup_vllm_server(self) -> None:
+        """Stop the VLLM server process if it was started by this instance."""
+        if self.vllm_process is not None:
+            stop_vllm_process(self.vllm_process)
+            self.vllm_process = None
+
+    def __enter__(self):
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit with cleanup."""
+        self.cleanup_vllm_server()
 
     def _load_lora_adapter(self) -> None:
         """
@@ -413,23 +299,7 @@
     def text_completion(
         self, input_data: Union[str, BaseModel, list[Dict]], **runtime_kwargs
     ) -> List[Dict[str, Any]]:
-        """
-        Execute the LLM task and return text responses.
-
-        Args:
-            input_data: Input as string or BaseModel
-            **runtime_kwargs: Runtime model parameters that override defaults
-                - temperature: Controls randomness (0.0 to 2.0)
-                - n: Number of responses to generate
-                - max_tokens: Maximum tokens in response
-                - model: Model name override
-                - Any other model parameters supported by OpenAI API
-
-        Returns:
-            List of dicts [{'parsed': text_response, 'messages': messages}, ...]
-            When n=1: List contains one dict
-            When n>1: List contains multiple dicts
-        """
+        """Execute LLM task and return text responses."""
         # Prepare messages
         messages = self._prepare_input(input_data)
 
@@ -481,29 +351,7 @@
         response_model: Optional[Type[BaseModel]] | Type[str] = None,
         **runtime_kwargs,
     ) -> List[Dict[str, Any]]:
-        """
-        Execute the LLM task and return parsed Pydantic model responses.
-
-        Args:
-            input_data: Input as string or BaseModel
-            response_model: Pydantic model for response parsing (overrides default)
-            **runtime_kwargs: Runtime model parameters that override defaults
-                - temperature: Controls randomness (0.0 to 2.0)
-                - n: Number of responses to generate
-                - max_tokens: Maximum tokens in response
-                - model: Model name override
-                - Any other model parameters supported by OpenAI API
-
-        Returns:
-            List of dicts [{'parsed': parsed_model, 'messages': messages}, ...]
-            When n=1: List contains one dict
-            When n>1: List contains multiple dicts
-
-        Note:
-            This method ensures consistent Pydantic model output for both fresh and cached responses.
-            When responses are cached and loaded back, the parsed content is re-validated to maintain
-            type consistency between first-time and subsequent calls.
-        """
+        """Execute LLM task and return parsed Pydantic model responses."""
         # Prepare messages
         messages = self._prepare_input(input_data)
 
@@ -576,20 +424,7 @@
         two_step_parse_pydantic=False,
         **runtime_kwargs,
     ) -> List[Dict[str, Any]]:
-        """
-        Execute the LLM task. Delegates to text() or parse() based on output_model.
-
-        This method maintains backward compatibility by automatically choosing
-        between text and parse methods based on the output_model configuration.
-
-        Args:
-            input_data: Input as string or BaseModel
-            response_model: Optional override for output model
-            **runtime_kwargs: Runtime model parameters
-
-        Returns:
-            List of dicts [{'parsed': response, 'messages': messages}, ...]
-        """
+        """Execute LLM task. Delegates to text() or parse() based on output_model."""
         pydantic_model_to_use = response_model or self.output_model
 
         if pydantic_model_to_use is str or pydantic_model_to_use is None:
@@ -606,10 +441,8 @@
                 response_text = response_text.split("</think>")[1]
             try:
                 parsed = pydantic_model_to_use.model_validate_json(response_text)
-            except Exception as e:
-                # logger.info(
-                #     f"Warning: Failed to parsed JSON, Falling back to LLM parsing. Error: {str(e)[:100]}..."
-                # )
+            except Exception:
+                # Failed to parse JSON, falling back to LLM parsing
                 # use model to parse the response_text
                 _parsed_messages = [
                     {
@@ -652,6 +485,9 @@
         cache=True,
         is_reasoning_model: bool = False,
         lora_path: Optional[str] = None,
+        vllm_cmd: Optional[str] = None,
+        vllm_timeout: int = 120,
+        vllm_reuse: bool = True,
         **model_kwargs,
     ) -> "LLMTask":
         """
@@ -659,6 +495,17 @@
 
         This method extracts the instruction, input model, and output model
         from the provided builder and initializes an LLMTask accordingly.
+
+        Args:
+            builder: BasePromptBuilder instance
+            client: OpenAI client, port number, or base_url string
+            cache: Whether to use cached responses (default True)
+            is_reasoning_model: Whether model is reasoning model (default False)
+            lora_path: Optional path to LoRA adapter directory
+            vllm_cmd: Optional VLLM command to start server automatically
+            vllm_timeout: Timeout in seconds to wait for VLLM server (default 120)
+            vllm_reuse: If True (default), reuse existing server on target port
+            **model_kwargs: Additional model parameters
         """
         instruction = builder.get_instruction()
         input_model = builder.get_input_model()
@@ -673,6 +520,9 @@
             cache=cache,
             is_reasoning_model=is_reasoning_model,
             lora_path=lora_path,
+            vllm_cmd=vllm_cmd,
+            vllm_timeout=vllm_timeout,
+            vllm_reuse=vllm_reuse,
             **model_kwargs,
         )
 
@@ -691,3 +541,74 @@
         models = client.models.list().data
         return [m.id for m in models]
 
+    @staticmethod
+    def kill_all_vllm() -> int:
+        """Kill all tracked VLLM server processes."""
+        return kill_all_vllm_processes()
+
+    @staticmethod
+    def kill_vllm_on_port(port: int) -> bool:
+        """
+        Kill VLLM server running on a specific port.
+
+        Args:
+            port: Port number to kill server on
+
+        Returns:
+            True if a server was killed, False if no server was running
+        """
+        return _kill_vllm_on_port(port)
+
+
+# Example usage:
+if __name__ == "__main__":
+    # Example 1: Using VLLM with reuse (default behavior)
+    vllm_command = (
+        "vllm serve saves/vng/dpo/01 -tp 4 --port 8001 "
+        "--gpu-memory-utilization 0.9 --served-model-name sft --quantization experts_int8"
+    )
+
+    print("Example 1: Using VLLM with server reuse (default)")
+    # Create LLM instance - will reuse existing server if running on port 8001
+    with LLMTask(vllm_cmd=vllm_command) as llm:  # vllm_reuse=True by default
+        result = llm.text("Hello, how are you?")
+        print("Response:", result[0]["parsed"])
+
+    print("\nExample 2: Force restart server (vllm_reuse=False)")
+    # This will kill any existing server on port 8001 and start fresh
+    with LLMTask(vllm_cmd=vllm_command, vllm_reuse=False) as llm:
+        result = llm.text("Tell me a joke")
+        print("Joke:", result[0]["parsed"])
+
+    print("\nExample 3: Multiple instances with reuse")
+    # First instance starts the server
+    llm1 = LLMTask(vllm_cmd=vllm_command)  # Starts server or reuses existing
+
+    # Second instance reuses the same server
+    llm2 = LLMTask(vllm_cmd=vllm_command)  # Reuses server on port 8001
+
+    try:
+        result1 = llm1.text("What's the weather like?")
+        result2 = llm2.text("How's the traffic?")
+        print("Weather response:", result1[0]["parsed"])
+        print("Traffic response:", result2[0]["parsed"])
+    finally:
+        # Only cleanup if we started the process
+        llm1.cleanup_vllm_server()
+        llm2.cleanup_vllm_server()  # Won't do anything if process not owned
+
+    print("\nExample 4: Different ports")
+    # These will start separate servers
+    llm_8001 = LLMTask(vllm_cmd="vllm serve model1 --port 8001", vllm_reuse=True)
+    llm_8002 = LLMTask(vllm_cmd="vllm serve model2 --port 8002", vllm_reuse=True)
+
+    print("\nExample 5: Kill all VLLM servers")
+    # Kill all tracked VLLM processes
+    killed_count = LLMTask.kill_all_vllm()
+    print(f"Killed {killed_count} VLLM servers")
+
+    print("\nYou can check VLLM server output at: /tmp/vllm.txt")
+    print("Server reuse behavior:")
+    print("- vllm_reuse=True (default): Reuse existing server on target port")
+    print("- vllm_reuse=False: Kill existing server first, then start fresh")
+
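The helpers now imported from `.utils` (such as `_extract_port_from_vllm_cmd` and `_start_vllm_server`) live in a module that is not part of this diff, so their implementations are not shown here. As a rough idea of the port-extraction step the new constructor relies on, here is a hedged sketch; the function name, the regex, and the fallback to vLLM's usual default port 8000 are assumptions, not the package's actual code:

```python
import re

def extract_port_from_vllm_cmd_sketch(vllm_cmd: str, default: int = 8000) -> int:
    """Illustrative only: pull the --port value out of a `vllm serve ...` command string."""
    match = re.search(r"--port[=\s]+(\d+)", vllm_cmd)
    return int(match.group(1)) if match else default

# With the Example 1 command above this returns 8001.
print(extract_port_from_vllm_cmd_sketch(
    "vllm serve saves/vng/dpo/01 -tp 4 --port 8001 --gpu-memory-utilization 0.9"
))
```

Whatever the real implementation does, the extracted port drives both the reuse check and the OpenAI client base URL (`http://localhost:<port>/v1`), as seen in the `__init__` changes above.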