vec-inf 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,178 @@
1
+ """Class for generating SLURM scripts to run vLLM servers.
2
+
3
+ This module provides functionality to generate SLURM scripts for running vLLM servers
4
+ in both single-node and multi-node configurations.
5
+ """
6
+
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from vec_inf.client._client_vars import (
12
+ SLURM_JOB_CONFIG_ARGS,
13
+ SLURM_SCRIPT_TEMPLATE,
14
+ )
15
+
16
+
17
class SlurmScriptGenerator:
    """Generate SLURM batch scripts that launch vLLM servers.

    Supports single-node and multi-node deployments and two virtualization
    environments (a Python venv or a Singularity container), assembling the
    script from the shared ``SLURM_SCRIPT_TEMPLATE`` fragments.

    Parameters
    ----------
    params : dict[str, Any]
        Configuration parameters for the SLURM script. Contains settings for
        job configuration, model parameters, and virtualization environment.
    """

    def __init__(self, params: dict[str, Any]):
        """Store the configuration and derive convenience attributes.

        Parameters
        ----------
        params : dict[str, Any]
            Configuration parameters for the SLURM script.
        """
        self.params = params
        # Multi-node mode changes both the shebang and the server setup.
        self.is_multinode = int(self.params["num_nodes"]) > 1
        self.use_singularity = self.params["venv"] == "singularity"
        binds = self.params.get("bind", "")
        # Pre-render the extra --bind flag once; empty/missing binds add nothing.
        self.additional_binds = f" --bind {binds}" if binds else binds
        self.model_weights_path = str(
            Path(params["model_weights_parent_dir"], params["model_name"])
        )

    def _generate_script_content(self) -> str:
        """Assemble the complete SLURM script.

        Returns
        -------
        str
            The complete SLURM script as a string.
        """
        sections = [
            self._generate_shebang(),
            self._generate_server_setup(),
            self._generate_launch_cmd(),
        ]
        return "\n".join(sections)

    def _generate_shebang(self) -> str:
        """Build the #SBATCH header from the configured job options.

        Returns
        -------
        str
            SLURM shebang containing job specifications.
        """
        header = [SLURM_SCRIPT_TEMPLATE["shebang"]["base"]]
        # Only emit #SBATCH lines for options the caller actually set.
        header += [
            f"#SBATCH --{arg}={self.params[key]}"
            for arg, key in SLURM_JOB_CONFIG_ARGS.items()
            if self.params.get(key)
        ]
        if self.is_multinode:
            header.extend(SLURM_SCRIPT_TEMPLATE["shebang"]["multinode"])
        return "\n".join(header)

    def _generate_server_setup(self) -> str:
        """Build the server-initialization section of the script.

        Covers optional Singularity setup, environment variables, Ray
        initialization for multi-node runs, vLLM port discovery, and the
        JSON status file.

        Returns
        -------
        str
            Server initialization script content.
        """
        parts = ["\n"]
        if self.use_singularity:
            parts.append("\n".join(SLURM_SCRIPT_TEMPLATE["singularity_setup"]))
        parts.append("\n".join(SLURM_SCRIPT_TEMPLATE["env_vars"]))
        parts.append(
            SLURM_SCRIPT_TEMPLATE["imports"].format(src_dir=self.params["src_dir"])
        )
        if self.is_multinode:
            setup = "\n".join(SLURM_SCRIPT_TEMPLATE["server_setup"]["multinode"])
            if self.use_singularity:
                # The multinode template carries a placeholder where the
                # singularity invocation must be spliced in.
                singularity_cmd = SLURM_SCRIPT_TEMPLATE["singularity_command"].format(
                    model_weights_path=self.model_weights_path,
                    additional_binds=self.additional_binds,
                )
                setup = setup.replace("SINGULARITY_PLACEHOLDER", singularity_cmd)
        else:
            setup = "\n".join(SLURM_SCRIPT_TEMPLATE["server_setup"]["single_node"])
        parts.append(setup)
        parts.append("\n".join(SLURM_SCRIPT_TEMPLATE["find_vllm_port"]))
        parts.append(
            "\n".join(SLURM_SCRIPT_TEMPLATE["write_to_json"]).format(
                log_dir=self.params["log_dir"],
                model_name=self.params["model_name"],
            )
        )
        return "\n".join(parts)

    def _generate_launch_cmd(self) -> str:
        """Build the vLLM server launch command.

        Handles both virtualization environments: a Singularity invocation
        prefix or a venv activation line.

        Returns
        -------
        str
            Server launch command.
        """
        cmd = ["\n"]
        if self.use_singularity:
            singularity_cmd = SLURM_SCRIPT_TEMPLATE["singularity_command"].format(
                model_weights_path=self.model_weights_path,
                additional_binds=self.additional_binds,
            )
            cmd.append(singularity_cmd + " \\")
        else:
            cmd.append(
                SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
            )
        cmd.append(
            "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
                model_weights_path=self.model_weights_path,
                model_name=self.params["model_name"],
            )
        )

        # Boolean args become bare flags, everything else "flag value".
        # NOTE(review): a False boolean still emits its flag — presumably
        # vllm_args only ever contains enabled flags; confirm upstream.
        for arg, value in self.params["vllm_args"].items():
            if isinstance(value, bool):
                cmd.append(f" {arg} \\")
            else:
                cmd.append(f" {arg} {value} \\")
        return "\n".join(cmd)

    def write_to_log_dir(self) -> Path:
        """Write the generated SLURM script to the log directory.

        Creates a timestamped script file in the configured log directory.

        Returns
        -------
        Path
            Path to the generated SLURM script file.
        """
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        target = (
            Path(self.params["log_dir"])
            / f"launch_{self.params['model_name']}_{stamp}.slurm"
        )
        target.write_text(self._generate_script_content())
        return target
@@ -0,0 +1,287 @@
1
+ """Utility functions shared between CLI and API.
2
+
3
+ This module provides utility functions for managing SLURM jobs, server status checks,
4
+ and configuration handling for the vector inference package.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import subprocess
10
+ import warnings
11
+ from pathlib import Path
12
+ from typing import Any, Optional, Union, cast
13
+
14
+ import requests
15
+ import yaml
16
+
17
+ from vec_inf.client._client_vars import MODEL_READY_SIGNATURE
18
+ from vec_inf.client.config import ModelConfig
19
+ from vec_inf.client.models import ModelStatus
20
+ from vec_inf.client.slurm_vars import CACHED_CONFIG
21
+
22
+
23
def run_bash_command(command: str) -> tuple[str, str]:
    """Run *command* through the shell and capture its output.

    Parameters
    ----------
    command : str
        The bash command to execute

    Returns
    -------
    tuple[str, str]
        A tuple containing (stdout, stderr) from the command execution
    """
    # Context manager guarantees the pipes are closed once we have the output.
    with subprocess.Popen(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    ) as proc:
        stdout, stderr = proc.communicate()
    return stdout, stderr
40
+
41
+
42
def read_slurm_log(
    slurm_job_name: str,
    slurm_job_id: int,
    slurm_log_type: str,
    log_dir: Optional[Union[str, Path]],
) -> Union[list[str], str, dict[str, str]]:
    """Read the slurm log file.

    Parameters
    ----------
    slurm_job_name : str
        Name of the SLURM job
    slurm_job_id : int
        ID of the SLURM job
    slurm_log_type : str
        Type of log file to read ('out', 'err', or 'json')
    log_dir : Optional[Union[str, Path]]
        Directory containing log files, if None uses default location

    Returns
    -------
    Union[list[str], str, dict[str, str]]
        Contents of the log file:
        - list[str] for 'out' and 'err' logs
        - dict[str, str] for 'json' logs
        - str for error messages if file not found
    """
    if not log_dir:
        # Default log directory
        models_dir = Path.home() / ".vec-inf-logs"
        # Fix: previously a missing default directory raised an uncaught
        # FileNotFoundError from iterdir(); now it falls through to the
        # "LOG DIR NOT FOUND" return below.
        if models_dir.is_dir():
            # Iterate over all dirs in models_dir, sorted by dir name length
            # in descending order so the most specific prefix matches first.
            for directory in sorted(
                (d for d in models_dir.iterdir() if d.is_dir()),
                key=lambda d: len(d.name),
                reverse=True,
            ):
                if directory.name in slurm_job_name:
                    log_dir = directory
                    break
    else:
        log_dir = Path(log_dir)

    # If log_dir is still not set, then didn't find the log dir at default location
    if not log_dir:
        return "LOG DIR NOT FOUND"

    # Built outside the try block so the except handler can always reference it.
    file_path = (
        Path(log_dir)
        / f"{slurm_job_name}.{slurm_job_id}"
        / f"{slurm_job_name}.{slurm_job_id}.{slurm_log_type}"
    )
    try:
        if slurm_log_type == "json":
            with file_path.open("r") as file:
                json_content: dict[str, str] = json.load(file)
                return json_content
        with file_path.open("r") as file:
            return file.readlines()
    except FileNotFoundError:
        return f"LOG FILE NOT FOUND: {file_path}"
103
+
104
+
105
def is_server_running(
    slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str]
) -> Union[str, ModelStatus, tuple[ModelStatus, str]]:
    """Check if a model is ready to serve requests.

    Parameters
    ----------
    slurm_job_name : str
        Name of the SLURM job
    slurm_job_id : int
        ID of the SLURM job
    log_dir : Optional[str]
        Directory containing log files

    Returns
    -------
    Union[str, ModelStatus, tuple[ModelStatus, str]]
        - str: Error message if logs cannot be read
        - ModelStatus: Current status of the server
        - tuple[ModelStatus, str]: Status and error message if server failed
    """
    log_lines = read_slurm_log(slurm_job_name, slurm_job_id, "err", log_dir)
    # read_slurm_log signals failure with a plain error string.
    if isinstance(log_lines, str):
        return log_lines

    status: Union[str, ModelStatus, tuple[ModelStatus, str]] = ModelStatus.LAUNCHING

    # Later lines win: a ready signature after an earlier error still
    # yields "RUNNING", matching the scan order of the original.
    for line in log_lines:
        if "error" in line.lower():
            status = (ModelStatus.FAILED, line.strip("\n"))
        if MODEL_READY_SIGNATURE in line:
            status = "RUNNING"

    return status
139
+
140
+
141
def get_base_url(slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str]) -> str:
    """Get the base URL of a model.

    Parameters
    ----------
    slurm_job_name : str
        Name of the SLURM job
    slurm_job_id : int
        ID of the SLURM job
    log_dir : Optional[str]
        Directory containing log files

    Returns
    -------
    str
        Base URL of the model server or error message if not found
    """
    log_content = read_slurm_log(slurm_job_name, slurm_job_id, "json", log_dir)
    # A plain string means the log could not be read; pass it through.
    if isinstance(log_content, str):
        return log_content

    json_log = cast(dict[str, str], log_content)
    # Empty or missing address both fall back to the sentinel message.
    return json_log.get("server_address") or "URL NOT FOUND"
164
+
165
+
166
def model_health_check(
    slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str]
) -> tuple[ModelStatus, Union[str, int]]:
    """Check the health of a running model on the cluster.

    Parameters
    ----------
    slurm_job_name : str
        Name of the SLURM job
    slurm_job_id : int
        ID of the SLURM job
    log_dir : Optional[str]
        Directory containing log files

    Returns
    -------
    tuple[ModelStatus, Union[str, int]]
        Tuple containing:
        - ModelStatus: Current status of the model
        - Union[str, int]: Either HTTP status code or error message
    """
    base_url = get_base_url(slurm_job_name, slurm_job_id, log_dir)
    if not base_url.startswith("http"):
        return (ModelStatus.FAILED, base_url)

    # Fix: only rewrite the trailing "/v1" path segment. The previous
    # blanket replace("v1", "health") corrupted URLs whose host or port
    # happened to contain "v1".
    trimmed = base_url.rstrip("/")
    if trimmed.endswith("/v1"):
        health_check_url = trimmed[: -len("v1")] + "health"
    else:
        # Fallback for unexpected URL shapes, preserving the old behavior.
        health_check_url = base_url.replace("v1", "health")

    try:
        # Fix: a timeout prevents a status check from hanging forever on an
        # unresponsive server; Timeout is a RequestException and is caught below.
        response = requests.get(health_check_url, timeout=30)
        # Check if the request was successful
        if response.status_code == 200:
            return (ModelStatus.READY, response.status_code)
        return (ModelStatus.FAILED, response.status_code)
    except requests.exceptions.RequestException as e:
        return (ModelStatus.FAILED, str(e))
200
+
201
+
202
def load_config() -> list[ModelConfig]:
    """Load the model configuration.

    Loads configuration from default and user-specified paths, merging them
    if both exist. User configuration takes precedence over default values.

    Returns
    -------
    list[ModelConfig]
        List of validated model configurations

    Notes
    -----
    Configuration is loaded from:
    1. Default path: package's config/models.yaml
    2. User path: specified by VEC_INF_CONFIG environment variable

    If user configuration exists, it will be merged with default configuration,
    with user values taking precedence for overlapping fields.
    """
    # Prefer the cached config; otherwise fall back to the packaged file.
    if CACHED_CONFIG.exists():
        default_path = CACHED_CONFIG
    else:
        default_path = Path(__file__).resolve().parent.parent / "config" / "models.yaml"

    with open(default_path) as f:
        config: dict[str, Any] = yaml.safe_load(f) or {}

    user_path = os.getenv("VEC_INF_CONFIG")
    if user_path:
        user_path_obj = Path(user_path)
        if user_path_obj.exists():
            with open(user_path_obj) as f:
                user_config = yaml.safe_load(f) or {}
            # Merge per model: user fields override default fields, and
            # models unknown to the default config are added wholesale.
            models = config.setdefault("models", {})
            for name, data in user_config.get("models", {}).items():
                if name in models:
                    models[name].update(data)
                else:
                    models[name] = data
        else:
            warnings.warn(
                f"WARNING: Could not find user config: {user_path}, revert to default config located at {default_path}",
                UserWarning,
                stacklevel=2,
            )

    return [
        ModelConfig(model_name=name, **model_data)
        for name, model_data in config.get("models", {}).items()
    ]
254
+
255
+
256
def parse_launch_output(output: str) -> tuple[str, dict[str, str]]:
    """Parse output from model launch command.

    Parameters
    ----------
    output : str
        Raw output from the launch command

    Returns
    -------
    tuple[str, dict[str, str]]
        Tuple containing:
        - str: SLURM job ID
        - dict[str, str]: Dictionary of parsed configuration parameters

    Notes
    -----
    Extracts the SLURM job ID and configuration parameters from the launch
    command output. Configuration parameters are parsed from key-value pairs
    in the output text.
    """
    # The job ID is the last whitespace-separated token of the output
    # (the trailing word of sbatch's "Submitted batch job <id>" line).
    slurm_job_id = output.split(" ")[-1].strip().strip("\n")

    # Everything except the final two lines (submission message and the
    # trailing blank) is treated as "Key: value" configuration output.
    config_dict: dict[str, str] = {}
    for line in output.split("\n")[:-2]:
        key, sep, value = line.partition(": ")
        if sep:
            config_dict[key.lower().replace(" ", "_")] = value

    return slurm_job_id, config_dict