vec-inf 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vec_inf/README.md +2 -1
- vec_inf/cli/_cli.py +39 -10
- vec_inf/cli/_helper.py +100 -19
- vec_inf/client/_helper.py +80 -31
- vec_inf/client/_slurm_script_generator.py +58 -30
- vec_inf/client/_slurm_templates.py +27 -12
- vec_inf/client/_utils.py +58 -6
- vec_inf/client/api.py +55 -2
- vec_inf/client/models.py +6 -0
- vec_inf/config/models.yaml +47 -99
- vec_inf/find_port.sh +10 -1
- {vec_inf-0.7.1.dist-info → vec_inf-0.7.3.dist-info}/METADATA +7 -6
- vec_inf-0.7.3.dist-info/RECORD +27 -0
- {vec_inf-0.7.1.dist-info → vec_inf-0.7.3.dist-info}/WHEEL +1 -1
- vec_inf-0.7.1.dist-info/RECORD +0 -27
- {vec_inf-0.7.1.dist-info → vec_inf-0.7.3.dist-info}/entry_points.txt +0 -0
- {vec_inf-0.7.1.dist-info → vec_inf-0.7.3.dist-info}/licenses/LICENSE +0 -0
vec_inf/README.md
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
* `launch`: Specify a model family and other optional parameters to launch an OpenAI compatible inference server.
|
|
4
4
|
* `batch-launch`: Specify a list of models to launch multiple OpenAI compatible inference servers at the same time.
|
|
5
|
-
* `status`: Check the
|
|
5
|
+
* `status`: Check the status of all `vec-inf` jobs, or a specific job by providing its job ID.
|
|
6
6
|
* `metrics`: Streams performance metrics to the console.
|
|
7
7
|
* `shutdown`: Shutdown a model by providing its Slurm job ID.
|
|
8
8
|
* `list`: List all available model names, or view the default/cached configuration of a specific model.
|
|
@@ -14,6 +14,7 @@ Use `--help` to see all available options
|
|
|
14
14
|
|
|
15
15
|
* `launch_model`: Launch an OpenAI compatible inference server.
|
|
16
16
|
* `batch_launch_models`: Launch multiple OpenAI compatible inference servers.
|
|
17
|
+
* `fetch_running_jobs`: Get the running `vec-inf` job IDs.
|
|
17
18
|
* `get_status`: Get the status of a running model.
|
|
18
19
|
* `get_metrics`: Get the performance metrics of a running model.
|
|
19
20
|
* `shutdown_model`: Shutdown a running model.
|
vec_inf/cli/_cli.py
CHANGED
|
@@ -30,6 +30,7 @@ from vec_inf.cli._helper import (
|
|
|
30
30
|
BatchLaunchResponseFormatter,
|
|
31
31
|
LaunchResponseFormatter,
|
|
32
32
|
ListCmdDisplay,
|
|
33
|
+
ListStatusDisplay,
|
|
33
34
|
MetricsResponseFormatter,
|
|
34
35
|
StatusResponseFormatter,
|
|
35
36
|
)
|
|
@@ -69,6 +70,16 @@ def cli() -> None:
|
|
|
69
70
|
type=int,
|
|
70
71
|
help="Number of GPUs/node to use, default to suggested resource allocation for model",
|
|
71
72
|
)
|
|
73
|
+
@click.option(
|
|
74
|
+
"--cpus-per-task",
|
|
75
|
+
type=int,
|
|
76
|
+
help="Number of CPU cores per task",
|
|
77
|
+
)
|
|
78
|
+
@click.option(
|
|
79
|
+
"--mem-per-node",
|
|
80
|
+
type=str,
|
|
81
|
+
help="Memory allocation per node in GB format (e.g., '32G')",
|
|
82
|
+
)
|
|
72
83
|
@click.option(
|
|
73
84
|
"--account",
|
|
74
85
|
"-A",
|
|
@@ -165,6 +176,10 @@ def launch(
|
|
|
165
176
|
Number of nodes to use
|
|
166
177
|
- gpus_per_node : int, optional
|
|
167
178
|
Number of GPUs per node
|
|
179
|
+
- cpus_per_task : int, optional
|
|
180
|
+
Number of CPU cores per task
|
|
181
|
+
- mem_per_node : str, optional
|
|
182
|
+
Memory allocation per node in GB format (e.g., '32G')
|
|
168
183
|
- account : str, optional
|
|
169
184
|
Charge resources used by this job to specified account
|
|
170
185
|
- work_dir : str, optional
|
|
@@ -299,14 +314,14 @@ def batch_launch(
|
|
|
299
314
|
raise click.ClickException(f"Batch launch failed: {str(e)}") from e
|
|
300
315
|
|
|
301
316
|
|
|
302
|
-
@cli.command("status", help="Check the status of
|
|
303
|
-
@click.argument("slurm_job_id",
|
|
317
|
+
@cli.command("status", help="Check the status of running vec-inf jobs on the cluster.")
|
|
318
|
+
@click.argument("slurm_job_id", required=False)
|
|
304
319
|
@click.option(
|
|
305
320
|
"--json-mode",
|
|
306
321
|
is_flag=True,
|
|
307
322
|
help="Output in JSON string",
|
|
308
323
|
)
|
|
309
|
-
def status(slurm_job_id: str, json_mode: bool = False) -> None:
|
|
324
|
+
def status(slurm_job_id: Optional[str] = None, json_mode: bool = False) -> None:
|
|
310
325
|
"""Get the status of a running model on the cluster.
|
|
311
326
|
|
|
312
327
|
Parameters
|
|
@@ -324,14 +339,28 @@ def status(slurm_job_id: str, json_mode: bool = False) -> None:
|
|
|
324
339
|
try:
|
|
325
340
|
# Start the client and get model inference server status
|
|
326
341
|
client = VecInfClient()
|
|
327
|
-
|
|
342
|
+
if not slurm_job_id:
|
|
343
|
+
slurm_job_ids = client.fetch_running_jobs()
|
|
344
|
+
if not slurm_job_ids:
|
|
345
|
+
click.echo("No running jobs found.")
|
|
346
|
+
return
|
|
347
|
+
else:
|
|
348
|
+
slurm_job_ids = [slurm_job_id]
|
|
349
|
+
responses = []
|
|
350
|
+
for job_id in slurm_job_ids:
|
|
351
|
+
responses.append(client.get_status(job_id))
|
|
352
|
+
|
|
328
353
|
# Display status information
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
354
|
+
if slurm_job_id:
|
|
355
|
+
status_formatter = StatusResponseFormatter(responses[0])
|
|
356
|
+
if json_mode:
|
|
357
|
+
status_formatter.output_json()
|
|
358
|
+
else:
|
|
359
|
+
status_info_table = status_formatter.output_table()
|
|
360
|
+
CONSOLE.print(status_info_table)
|
|
332
361
|
else:
|
|
333
|
-
|
|
334
|
-
|
|
362
|
+
list_status_display = ListStatusDisplay(slurm_job_ids, responses, json_mode)
|
|
363
|
+
list_status_display.display_multiple_status_output(CONSOLE)
|
|
335
364
|
|
|
336
365
|
except click.ClickException as e:
|
|
337
366
|
raise e
|
|
@@ -447,7 +476,7 @@ def metrics(slurm_job_id: str) -> None:
|
|
|
447
476
|
metrics_formatter.format_metrics()
|
|
448
477
|
|
|
449
478
|
live.update(metrics_formatter.table)
|
|
450
|
-
time.sleep(
|
|
479
|
+
time.sleep(1)
|
|
451
480
|
except click.ClickException as e:
|
|
452
481
|
raise e
|
|
453
482
|
except Exception as e:
|
vec_inf/cli/_helper.py
CHANGED
|
@@ -36,6 +36,43 @@ class LaunchResponseFormatter:
|
|
|
36
36
|
self.model_name = model_name
|
|
37
37
|
self.params = params
|
|
38
38
|
|
|
39
|
+
def _add_resource_allocation_details(self, table: Table) -> None:
|
|
40
|
+
"""Add resource allocation details to the table."""
|
|
41
|
+
optional_fields = [
|
|
42
|
+
("account", "Account"),
|
|
43
|
+
("work_dir", "Working Directory"),
|
|
44
|
+
("resource_type", "Resource Type"),
|
|
45
|
+
("partition", "Partition"),
|
|
46
|
+
("qos", "QoS"),
|
|
47
|
+
]
|
|
48
|
+
for key, label in optional_fields:
|
|
49
|
+
if self.params.get(key):
|
|
50
|
+
table.add_row(label, self.params[key])
|
|
51
|
+
|
|
52
|
+
def _add_vllm_config(self, table: Table) -> None:
|
|
53
|
+
"""Add vLLM configuration details to the table."""
|
|
54
|
+
if self.params.get("vllm_args"):
|
|
55
|
+
table.add_row("vLLM Arguments:", style="magenta")
|
|
56
|
+
for arg, value in self.params["vllm_args"].items():
|
|
57
|
+
table.add_row(f" {arg}:", str(value))
|
|
58
|
+
|
|
59
|
+
def _add_env_vars(self, table: Table) -> None:
|
|
60
|
+
"""Add environment variable configuration details to the table."""
|
|
61
|
+
if self.params.get("env"):
|
|
62
|
+
table.add_row("Environment Variables", style="magenta")
|
|
63
|
+
for arg, value in self.params["env"].items():
|
|
64
|
+
table.add_row(f" {arg}:", str(value))
|
|
65
|
+
|
|
66
|
+
def _add_bind_paths(self, table: Table) -> None:
|
|
67
|
+
"""Add bind path configuration details to the table."""
|
|
68
|
+
if self.params.get("bind"):
|
|
69
|
+
table.add_row("Bind Paths", style="magenta")
|
|
70
|
+
for path in self.params["bind"].split(","):
|
|
71
|
+
host = target = path
|
|
72
|
+
if ":" in path:
|
|
73
|
+
host, target = path.split(":")
|
|
74
|
+
table.add_row(f" {host}:", target)
|
|
75
|
+
|
|
39
76
|
def format_table_output(self) -> Table:
|
|
40
77
|
"""Format output as rich Table.
|
|
41
78
|
|
|
@@ -59,16 +96,7 @@ class LaunchResponseFormatter:
|
|
|
59
96
|
table.add_row("Vocabulary Size", self.params["vocab_size"])
|
|
60
97
|
|
|
61
98
|
# Add resource allocation details
|
|
62
|
-
|
|
63
|
-
table.add_row("Account", self.params["account"])
|
|
64
|
-
if self.params.get("work_dir"):
|
|
65
|
-
table.add_row("Working Directory", self.params["work_dir"])
|
|
66
|
-
if self.params.get("resource_type"):
|
|
67
|
-
table.add_row("Resource Type", self.params["resource_type"])
|
|
68
|
-
if self.params.get("partition"):
|
|
69
|
-
table.add_row("Partition", self.params["partition"])
|
|
70
|
-
if self.params.get("qos"):
|
|
71
|
-
table.add_row("QoS", self.params["qos"])
|
|
99
|
+
self._add_resource_allocation_details(table)
|
|
72
100
|
table.add_row("Time Limit", self.params["time"])
|
|
73
101
|
table.add_row("Num Nodes", self.params["num_nodes"])
|
|
74
102
|
table.add_row("GPUs/Node", self.params["gpus_per_node"])
|
|
@@ -76,21 +104,18 @@ class LaunchResponseFormatter:
|
|
|
76
104
|
table.add_row("Memory/Node", self.params["mem_per_node"])
|
|
77
105
|
|
|
78
106
|
# Add job config details
|
|
107
|
+
if self.params.get("venv"):
|
|
108
|
+
table.add_row("Virtual Environment", self.params["venv"])
|
|
79
109
|
table.add_row(
|
|
80
110
|
"Model Weights Directory",
|
|
81
111
|
str(Path(self.params["model_weights_parent_dir"], self.model_name)),
|
|
82
112
|
)
|
|
83
113
|
table.add_row("Log Directory", self.params["log_dir"])
|
|
84
114
|
|
|
85
|
-
# Add
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
# Add Environment Variable Configuration Details
|
|
91
|
-
table.add_row("Environment Variables", style="magenta")
|
|
92
|
-
for arg, value in self.params["env"].items():
|
|
93
|
-
table.add_row(f" {arg}:", str(value))
|
|
115
|
+
# Add configuration details
|
|
116
|
+
self._add_vllm_config(table)
|
|
117
|
+
self._add_env_vars(table)
|
|
118
|
+
self._add_bind_paths(table)
|
|
94
119
|
|
|
95
120
|
return table
|
|
96
121
|
|
|
@@ -226,6 +251,62 @@ class StatusResponseFormatter:
|
|
|
226
251
|
return table
|
|
227
252
|
|
|
228
253
|
|
|
254
|
+
class ListStatusDisplay:
|
|
255
|
+
"""CLI Helper class for formatting a list of StatusResponse.
|
|
256
|
+
|
|
257
|
+
A formatter class that handles the presentation of multiple job statuses
|
|
258
|
+
in a table format.
|
|
259
|
+
|
|
260
|
+
Parameters
|
|
261
|
+
----------
|
|
262
|
+
statuses : list[StatusResponse]
|
|
263
|
+
List of model status information
|
|
264
|
+
"""
|
|
265
|
+
|
|
266
|
+
def __init__(
|
|
267
|
+
self,
|
|
268
|
+
job_ids: list[str],
|
|
269
|
+
statuses: list[StatusResponse],
|
|
270
|
+
json_mode: bool = False,
|
|
271
|
+
):
|
|
272
|
+
self.job_ids = job_ids
|
|
273
|
+
self.statuses = statuses
|
|
274
|
+
self.json_mode = json_mode
|
|
275
|
+
|
|
276
|
+
self.table = Table(show_header=True, header_style="bold magenta")
|
|
277
|
+
self.table.add_column("Job ID")
|
|
278
|
+
self.table.add_column("Model Name")
|
|
279
|
+
self.table.add_column("Status", style="blue")
|
|
280
|
+
self.table.add_column("Base URL")
|
|
281
|
+
|
|
282
|
+
def display_multiple_status_output(self, console: Console) -> None:
|
|
283
|
+
"""Format and display all model statuses.
|
|
284
|
+
|
|
285
|
+
Formats each model's status and adds it to the table.
|
|
286
|
+
"""
|
|
287
|
+
if self.json_mode:
|
|
288
|
+
json_data = [
|
|
289
|
+
{
|
|
290
|
+
"job_id": status.model_name,
|
|
291
|
+
"model_name": status.model_name,
|
|
292
|
+
"model_status": status.server_status,
|
|
293
|
+
"base_url": status.base_url,
|
|
294
|
+
}
|
|
295
|
+
for status in self.statuses
|
|
296
|
+
]
|
|
297
|
+
click.echo(json.dumps(json_data, indent=4))
|
|
298
|
+
return
|
|
299
|
+
|
|
300
|
+
for i, status in enumerate(self.statuses):
|
|
301
|
+
self.table.add_row(
|
|
302
|
+
self.job_ids[i],
|
|
303
|
+
status.model_name,
|
|
304
|
+
status.server_status,
|
|
305
|
+
status.base_url,
|
|
306
|
+
)
|
|
307
|
+
console.print(self.table)
|
|
308
|
+
|
|
309
|
+
|
|
229
310
|
class MetricsResponseFormatter:
|
|
230
311
|
"""CLI Helper class for formatting MetricsResponse.
|
|
231
312
|
|
vec_inf/client/_helper.py
CHANGED
|
@@ -31,6 +31,7 @@ from vec_inf.client._slurm_script_generator import (
|
|
|
31
31
|
BatchSlurmScriptGenerator,
|
|
32
32
|
SlurmScriptGenerator,
|
|
33
33
|
)
|
|
34
|
+
from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME, IMAGE_PATH
|
|
34
35
|
from vec_inf.client.config import ModelConfig
|
|
35
36
|
from vec_inf.client.models import (
|
|
36
37
|
BatchLaunchResponse,
|
|
@@ -195,23 +196,14 @@ class ModelLauncher:
|
|
|
195
196
|
print(f"WARNING: Could not parse env var: {line}")
|
|
196
197
|
return env_vars
|
|
197
198
|
|
|
198
|
-
def
|
|
199
|
-
"""
|
|
199
|
+
def _apply_cli_overrides(self, params: dict[str, Any]) -> None:
|
|
200
|
+
"""Apply CLI argument overrides to params.
|
|
200
201
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
dict[str, Any]
|
|
204
|
-
Dictionary of
|
|
205
|
-
|
|
206
|
-
Raises
|
|
207
|
-
------
|
|
208
|
-
MissingRequiredFieldsError
|
|
209
|
-
If required fields are missing or tensor parallel size is not specified
|
|
210
|
-
when using multiple GPUs
|
|
202
|
+
Parameters
|
|
203
|
+
----------
|
|
204
|
+
params : dict[str, Any]
|
|
205
|
+
Dictionary of launch parameters to override
|
|
211
206
|
"""
|
|
212
|
-
params = self.model_config.model_dump(exclude_none=True)
|
|
213
|
-
|
|
214
|
-
# Override config defaults with CLI arguments
|
|
215
207
|
if self.kwargs.get("vllm_args"):
|
|
216
208
|
vllm_args = self._process_vllm_args(self.kwargs["vllm_args"])
|
|
217
209
|
for key, value in vllm_args.items():
|
|
@@ -224,13 +216,29 @@ class ModelLauncher:
|
|
|
224
216
|
params["env"][key] = str(value)
|
|
225
217
|
del self.kwargs["env"]
|
|
226
218
|
|
|
219
|
+
if self.kwargs.get("bind") and params.get("bind"):
|
|
220
|
+
params["bind"] = f"{params['bind']},{self.kwargs['bind']}"
|
|
221
|
+
del self.kwargs["bind"]
|
|
222
|
+
|
|
227
223
|
for key, value in self.kwargs.items():
|
|
228
224
|
params[key] = value
|
|
229
225
|
|
|
230
|
-
|
|
231
|
-
|
|
226
|
+
def _validate_resource_allocation(self, params: dict[str, Any]) -> None:
|
|
227
|
+
"""Validate resource allocation and parallelization settings.
|
|
232
228
|
|
|
233
|
-
|
|
229
|
+
Parameters
|
|
230
|
+
----------
|
|
231
|
+
params : dict[str, Any]
|
|
232
|
+
Dictionary of launch parameters to validate
|
|
233
|
+
|
|
234
|
+
Raises
|
|
235
|
+
------
|
|
236
|
+
MissingRequiredFieldsError
|
|
237
|
+
If tensor parallel size is not specified when using multiple GPUs
|
|
238
|
+
ValueError
|
|
239
|
+
If total # of GPUs requested is not a power of two
|
|
240
|
+
If mismatch between total # of GPUs requested and parallelization settings
|
|
241
|
+
"""
|
|
234
242
|
if (
|
|
235
243
|
int(params["gpus_per_node"]) > 1
|
|
236
244
|
and params["vllm_args"].get("--tensor-parallel-size") is None
|
|
@@ -251,19 +259,18 @@ class ModelLauncher:
|
|
|
251
259
|
"Mismatch between total number of GPUs requested and parallelization settings"
|
|
252
260
|
)
|
|
253
261
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
if resource_type:
|
|
257
|
-
params["gres"] = f"gpu:{resource_type}:{params['gpus_per_node']}"
|
|
258
|
-
else:
|
|
259
|
-
params["gres"] = f"gpu:{params['gpus_per_node']}"
|
|
262
|
+
def _setup_log_files(self, params: dict[str, Any]) -> None:
|
|
263
|
+
"""Set up log directory and file paths.
|
|
260
264
|
|
|
261
|
-
|
|
265
|
+
Parameters
|
|
266
|
+
----------
|
|
267
|
+
params : dict[str, Any]
|
|
268
|
+
Dictionary of launch parameters to set up log files
|
|
269
|
+
"""
|
|
262
270
|
params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser()
|
|
263
271
|
params["log_dir"].mkdir(parents=True, exist_ok=True)
|
|
264
272
|
params["src_dir"] = SRC_DIR
|
|
265
273
|
|
|
266
|
-
# Construct slurm log file paths
|
|
267
274
|
params["out_file"] = (
|
|
268
275
|
f"{params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out"
|
|
269
276
|
)
|
|
@@ -274,6 +281,35 @@ class ModelLauncher:
|
|
|
274
281
|
f"{params['log_dir']}/{self.model_name}.$SLURM_JOB_ID/{self.model_name}.$SLURM_JOB_ID.json"
|
|
275
282
|
)
|
|
276
283
|
|
|
284
|
+
def _get_launch_params(self) -> dict[str, Any]:
|
|
285
|
+
"""Prepare launch parameters, set log dir, and validate required fields.
|
|
286
|
+
|
|
287
|
+
Returns
|
|
288
|
+
-------
|
|
289
|
+
dict[str, Any]
|
|
290
|
+
Dictionary of prepared launch parameters
|
|
291
|
+
"""
|
|
292
|
+
params = self.model_config.model_dump(exclude_none=True)
|
|
293
|
+
|
|
294
|
+
# Override config defaults with CLI arguments
|
|
295
|
+
self._apply_cli_overrides(params)
|
|
296
|
+
|
|
297
|
+
# Check for required fields without default vals, will raise an error if missing
|
|
298
|
+
utils.check_required_fields(params)
|
|
299
|
+
|
|
300
|
+
# Validate resource allocation and parallelization settings
|
|
301
|
+
self._validate_resource_allocation(params)
|
|
302
|
+
|
|
303
|
+
# Convert gpus_per_node and resource_type to gres
|
|
304
|
+
resource_type = params.get("resource_type")
|
|
305
|
+
if resource_type:
|
|
306
|
+
params["gres"] = f"gpu:{resource_type}:{params['gpus_per_node']}"
|
|
307
|
+
else:
|
|
308
|
+
params["gres"] = f"gpu:{params['gpus_per_node']}"
|
|
309
|
+
|
|
310
|
+
# Setup log files
|
|
311
|
+
self._setup_log_files(params)
|
|
312
|
+
|
|
277
313
|
# Convert path to string for JSON serialization
|
|
278
314
|
for field in params:
|
|
279
315
|
if field in ["vllm_args", "env"]:
|
|
@@ -332,6 +368,10 @@ class ModelLauncher:
|
|
|
332
368
|
job_log_dir / f"{self.model_name}.{self.slurm_job_id}.sbatch"
|
|
333
369
|
)
|
|
334
370
|
|
|
371
|
+
# Replace venv with image path if using container
|
|
372
|
+
if self.params["venv"] == CONTAINER_MODULE_NAME:
|
|
373
|
+
self.params["venv"] = IMAGE_PATH
|
|
374
|
+
|
|
335
375
|
with job_json.open("w") as file:
|
|
336
376
|
json.dump(self.params, file, indent=4)
|
|
337
377
|
|
|
@@ -429,16 +469,15 @@ class BatchModelLauncher:
|
|
|
429
469
|
If required fields are missing or tensor parallel size is not specified
|
|
430
470
|
when using multiple GPUs
|
|
431
471
|
"""
|
|
432
|
-
|
|
433
|
-
"models": {},
|
|
472
|
+
common_params: dict[str, Any] = {
|
|
434
473
|
"slurm_job_name": self.slurm_job_name,
|
|
435
474
|
"src_dir": str(SRC_DIR),
|
|
436
475
|
"account": account,
|
|
437
476
|
"work_dir": work_dir,
|
|
438
477
|
}
|
|
439
478
|
|
|
440
|
-
|
|
441
|
-
|
|
479
|
+
params: dict[str, Any] = common_params.copy()
|
|
480
|
+
params["models"] = {}
|
|
442
481
|
|
|
443
482
|
for i, (model_name, config) in enumerate(self.model_configs.items()):
|
|
444
483
|
params["models"][model_name] = config.model_dump(exclude_none=True)
|
|
@@ -515,6 +554,16 @@ class BatchModelLauncher:
|
|
|
515
554
|
raise ValueError(
|
|
516
555
|
f"Mismatch found for {arg}: {params[arg]} != {params['models'][model_name][arg]}, check your configuration"
|
|
517
556
|
)
|
|
557
|
+
# Check for required fields and return environment variable overrides
|
|
558
|
+
env_overrides = utils.check_required_fields(
|
|
559
|
+
{**params["models"][model_name], **common_params}
|
|
560
|
+
)
|
|
561
|
+
|
|
562
|
+
for arg, value in env_overrides.items():
|
|
563
|
+
if arg in common_params:
|
|
564
|
+
params[arg] = value
|
|
565
|
+
else:
|
|
566
|
+
params["models"][model_name][arg] = value
|
|
518
567
|
|
|
519
568
|
return params
|
|
520
569
|
|
|
@@ -678,7 +727,7 @@ class ModelStatusMonitor:
|
|
|
678
727
|
Basic status information for the job
|
|
679
728
|
"""
|
|
680
729
|
try:
|
|
681
|
-
job_name = self.job_status["JobName"]
|
|
730
|
+
job_name = self.job_status["JobName"].removesuffix("-vec-inf")
|
|
682
731
|
job_state = self.job_status["JobState"]
|
|
683
732
|
except KeyError:
|
|
684
733
|
job_name = "UNAVAILABLE"
|
|
@@ -14,6 +14,7 @@ from vec_inf.client._slurm_templates import (
|
|
|
14
14
|
BATCH_SLURM_SCRIPT_TEMPLATE,
|
|
15
15
|
SLURM_SCRIPT_TEMPLATE,
|
|
16
16
|
)
|
|
17
|
+
from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class SlurmScriptGenerator:
|
|
@@ -32,24 +33,35 @@ class SlurmScriptGenerator:
|
|
|
32
33
|
def __init__(self, params: dict[str, Any]):
|
|
33
34
|
self.params = params
|
|
34
35
|
self.is_multinode = int(self.params["num_nodes"]) > 1
|
|
35
|
-
self.use_container =
|
|
36
|
-
|
|
36
|
+
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
|
|
37
|
+
self.additional_binds = (
|
|
38
|
+
f",{self.params['bind']}" if self.params.get("bind") else ""
|
|
37
39
|
)
|
|
38
|
-
self.additional_binds = self.params.get("bind", "")
|
|
39
|
-
if self.additional_binds:
|
|
40
|
-
self.additional_binds = f" --bind {self.additional_binds}"
|
|
41
40
|
self.model_weights_path = str(
|
|
42
41
|
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
|
|
43
42
|
)
|
|
43
|
+
self.env_str = self._generate_env_str()
|
|
44
|
+
|
|
45
|
+
def _generate_env_str(self) -> str:
|
|
46
|
+
"""Generate the environment variables string for the Slurm script.
|
|
47
|
+
|
|
48
|
+
Returns
|
|
49
|
+
-------
|
|
50
|
+
str
|
|
51
|
+
Formatted env vars string for container or shell export commands.
|
|
52
|
+
"""
|
|
44
53
|
env_dict: dict[str, str] = self.params.get("env", {})
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
54
|
+
|
|
55
|
+
if not env_dict:
|
|
56
|
+
return ""
|
|
57
|
+
|
|
58
|
+
if self.use_container:
|
|
59
|
+
# Format for container: --env KEY1=VAL1,KEY2=VAL2
|
|
60
|
+
env_pairs = [f"{key}={val}" for key, val in env_dict.items()]
|
|
61
|
+
return f"--env {','.join(env_pairs)}"
|
|
62
|
+
# Format for shell: export KEY1=VAL1\nexport KEY2=VAL2
|
|
63
|
+
export_lines = [f"export {key}={val}" for key, val in env_dict.items()]
|
|
64
|
+
return "\n".join(export_lines)
|
|
53
65
|
|
|
54
66
|
def _generate_script_content(self) -> str:
|
|
55
67
|
"""Generate the complete Slurm script content.
|
|
@@ -77,6 +89,8 @@ class SlurmScriptGenerator:
|
|
|
77
89
|
for arg, value in SLURM_JOB_CONFIG_ARGS.items():
|
|
78
90
|
if self.params.get(value):
|
|
79
91
|
shebang.append(f"#SBATCH --{arg}={self.params[value]}")
|
|
92
|
+
if value == "model_name":
|
|
93
|
+
shebang[-1] += "-vec-inf"
|
|
80
94
|
if self.is_multinode:
|
|
81
95
|
shebang += SLURM_SCRIPT_TEMPLATE["shebang"]["multinode"]
|
|
82
96
|
return "\n".join(shebang)
|
|
@@ -95,7 +109,17 @@ class SlurmScriptGenerator:
|
|
|
95
109
|
server_script = ["\n"]
|
|
96
110
|
if self.use_container:
|
|
97
111
|
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
|
|
98
|
-
|
|
112
|
+
server_script.append(
|
|
113
|
+
SLURM_SCRIPT_TEMPLATE["bind_path"].format(
|
|
114
|
+
model_weights_path=self.model_weights_path,
|
|
115
|
+
additional_binds=self.additional_binds,
|
|
116
|
+
)
|
|
117
|
+
)
|
|
118
|
+
else:
|
|
119
|
+
server_script.append(
|
|
120
|
+
SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
|
|
121
|
+
)
|
|
122
|
+
server_script.append(self.env_str)
|
|
99
123
|
server_script.append(
|
|
100
124
|
SLURM_SCRIPT_TEMPLATE["imports"].format(src_dir=self.params["src_dir"])
|
|
101
125
|
)
|
|
@@ -108,10 +132,14 @@ class SlurmScriptGenerator:
|
|
|
108
132
|
"CONTAINER_PLACEHOLDER",
|
|
109
133
|
SLURM_SCRIPT_TEMPLATE["container_command"].format(
|
|
110
134
|
model_weights_path=self.model_weights_path,
|
|
111
|
-
additional_binds=self.additional_binds,
|
|
112
135
|
env_str=self.env_str,
|
|
113
136
|
),
|
|
114
137
|
)
|
|
138
|
+
else:
|
|
139
|
+
server_setup_str = server_setup_str.replace(
|
|
140
|
+
"CONTAINER_PLACEHOLDER",
|
|
141
|
+
"\\",
|
|
142
|
+
)
|
|
115
143
|
else:
|
|
116
144
|
server_setup_str = "\n".join(
|
|
117
145
|
SLURM_SCRIPT_TEMPLATE["server_setup"]["single_node"]
|
|
@@ -141,14 +169,10 @@ class SlurmScriptGenerator:
|
|
|
141
169
|
launcher_script.append(
|
|
142
170
|
SLURM_SCRIPT_TEMPLATE["container_command"].format(
|
|
143
171
|
model_weights_path=self.model_weights_path,
|
|
144
|
-
additional_binds=self.additional_binds,
|
|
145
172
|
env_str=self.env_str,
|
|
146
173
|
)
|
|
147
174
|
)
|
|
148
|
-
|
|
149
|
-
launcher_script.append(
|
|
150
|
-
SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
|
|
151
|
-
)
|
|
175
|
+
|
|
152
176
|
launcher_script.append(
|
|
153
177
|
"\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
|
|
154
178
|
model_weights_path=self.model_weights_path,
|
|
@@ -194,15 +218,13 @@ class BatchSlurmScriptGenerator:
|
|
|
194
218
|
def __init__(self, params: dict[str, Any]):
|
|
195
219
|
self.params = params
|
|
196
220
|
self.script_paths: list[Path] = []
|
|
197
|
-
self.use_container =
|
|
198
|
-
self.params["venv"] == "singularity" or self.params["venv"] == "apptainer"
|
|
199
|
-
)
|
|
221
|
+
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
|
|
200
222
|
for model_name in self.params["models"]:
|
|
201
|
-
self.params["models"][model_name]["additional_binds"] =
|
|
202
|
-
|
|
203
|
-
self.params["models"][model_name]
|
|
204
|
-
|
|
205
|
-
|
|
223
|
+
self.params["models"][model_name]["additional_binds"] = (
|
|
224
|
+
f",{self.params['models'][model_name]['bind']}"
|
|
225
|
+
if self.params["models"][model_name].get("bind")
|
|
226
|
+
else ""
|
|
227
|
+
)
|
|
206
228
|
self.params["models"][model_name]["model_weights_path"] = str(
|
|
207
229
|
Path(
|
|
208
230
|
self.params["models"][model_name]["model_weights_parent_dir"],
|
|
@@ -242,7 +264,12 @@ class BatchSlurmScriptGenerator:
|
|
|
242
264
|
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["shebang"])
|
|
243
265
|
if self.use_container:
|
|
244
266
|
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
|
|
245
|
-
script_content.append(
|
|
267
|
+
script_content.append(
|
|
268
|
+
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
|
|
269
|
+
model_weights_path=model_params["model_weights_path"],
|
|
270
|
+
additional_binds=model_params["additional_binds"],
|
|
271
|
+
)
|
|
272
|
+
)
|
|
246
273
|
script_content.append(
|
|
247
274
|
"\n".join(
|
|
248
275
|
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["server_address_setup"]
|
|
@@ -260,7 +287,6 @@ class BatchSlurmScriptGenerator:
|
|
|
260
287
|
script_content.append(
|
|
261
288
|
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format(
|
|
262
289
|
model_weights_path=model_params["model_weights_path"],
|
|
263
|
-
additional_binds=model_params["additional_binds"],
|
|
264
290
|
)
|
|
265
291
|
)
|
|
266
292
|
script_content.append(
|
|
@@ -304,6 +330,8 @@ class BatchSlurmScriptGenerator:
|
|
|
304
330
|
model_params = self.params["models"][model_name]
|
|
305
331
|
if model_params.get(value) and value not in ["out_file", "err_file"]:
|
|
306
332
|
shebang.append(f"#SBATCH --{arg}={model_params[value]}")
|
|
333
|
+
if value == "model_name":
|
|
334
|
+
shebang[-1] += "-vec-inf"
|
|
307
335
|
shebang[-1] += "\n"
|
|
308
336
|
shebang.append(BATCH_SLURM_SCRIPT_TEMPLATE["hetjob"])
|
|
309
337
|
# Remove the last hetjob line
|