vec-inf 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vec_inf/README.md CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  * `launch`: Specify a model family and other optional parameters to launch an OpenAI compatible inference server.
4
4
  * `batch-launch`: Specify a list of models to launch multiple OpenAI compatible inference servers at the same time.
5
- * `status`: Check the model status by providing its Slurm job ID.
5
+ * `status`: Check the status of all `vec-inf` jobs, or a specific job by providing its job ID.
6
6
  * `metrics`: Streams performance metrics to the console.
7
7
  * `shutdown`: Shut down a model by providing its Slurm job ID.
8
8
  * `list`: List all available model names, or view the default/cached configuration of a specific model.
@@ -14,6 +14,7 @@ Use `--help` to see all available options
14
14
 
15
15
  * `launch_model`: Launch an OpenAI compatible inference server.
16
16
  * `batch_launch_models`: Launch multiple OpenAI compatible inference servers.
17
+ * `fetch_running_jobs`: Get the running `vec-inf` job IDs.
17
18
  * `get_status`: Get the status of a running model.
18
19
  * `get_metrics`: Get the performance metrics of a running model.
19
20
  * `shutdown_model`: Shut down a running model.
vec_inf/cli/_cli.py CHANGED
@@ -30,6 +30,7 @@ from vec_inf.cli._helper import (
30
30
  BatchLaunchResponseFormatter,
31
31
  LaunchResponseFormatter,
32
32
  ListCmdDisplay,
33
+ ListStatusDisplay,
33
34
  MetricsResponseFormatter,
34
35
  StatusResponseFormatter,
35
36
  )
@@ -69,6 +70,16 @@ def cli() -> None:
69
70
  type=int,
70
71
  help="Number of GPUs/node to use, default to suggested resource allocation for model",
71
72
  )
73
+ @click.option(
74
+ "--cpus-per-task",
75
+ type=int,
76
+ help="Number of CPU cores per task",
77
+ )
78
+ @click.option(
79
+ "--mem-per-node",
80
+ type=str,
81
+ help="Memory allocation per node in GB format (e.g., '32G')",
82
+ )
72
83
  @click.option(
73
84
  "--account",
74
85
  "-A",
@@ -165,6 +176,10 @@ def launch(
165
176
  Number of nodes to use
166
177
  - gpus_per_node : int, optional
167
178
  Number of GPUs per node
179
+ - cpus_per_task : int, optional
180
+ Number of CPU cores per task
181
+ - mem_per_node : str, optional
182
+ Memory allocation per node in GB format (e.g., '32G')
168
183
  - account : str, optional
169
184
  Charge resources used by this job to specified account
170
185
  - work_dir : str, optional
@@ -299,14 +314,14 @@ def batch_launch(
299
314
  raise click.ClickException(f"Batch launch failed: {str(e)}") from e
300
315
 
301
316
 
302
- @cli.command("status", help="Check the status of a running model on the cluster.")
303
- @click.argument("slurm_job_id", type=str, nargs=1)
317
+ @cli.command("status", help="Check the status of running vec-inf jobs on the cluster.")
318
+ @click.argument("slurm_job_id", required=False)
304
319
  @click.option(
305
320
  "--json-mode",
306
321
  is_flag=True,
307
322
  help="Output in JSON string",
308
323
  )
309
- def status(slurm_job_id: str, json_mode: bool = False) -> None:
324
+ def status(slurm_job_id: Optional[str] = None, json_mode: bool = False) -> None:
310
325
  """Get the status of a running model on the cluster.
311
326
 
312
327
  Parameters
@@ -324,14 +339,28 @@ def status(slurm_job_id: str, json_mode: bool = False) -> None:
324
339
  try:
325
340
  # Start the client and get model inference server status
326
341
  client = VecInfClient()
327
- status_response = client.get_status(slurm_job_id)
342
+ if not slurm_job_id:
343
+ slurm_job_ids = client.fetch_running_jobs()
344
+ if not slurm_job_ids:
345
+ click.echo("No running jobs found.")
346
+ return
347
+ else:
348
+ slurm_job_ids = [slurm_job_id]
349
+ responses = []
350
+ for job_id in slurm_job_ids:
351
+ responses.append(client.get_status(job_id))
352
+
328
353
  # Display status information
329
- status_formatter = StatusResponseFormatter(status_response)
330
- if json_mode:
331
- status_formatter.output_json()
354
+ if slurm_job_id:
355
+ status_formatter = StatusResponseFormatter(responses[0])
356
+ if json_mode:
357
+ status_formatter.output_json()
358
+ else:
359
+ status_info_table = status_formatter.output_table()
360
+ CONSOLE.print(status_info_table)
332
361
  else:
333
- status_info_table = status_formatter.output_table()
334
- CONSOLE.print(status_info_table)
362
+ list_status_display = ListStatusDisplay(slurm_job_ids, responses, json_mode)
363
+ list_status_display.display_multiple_status_output(CONSOLE)
335
364
 
336
365
  except click.ClickException as e:
337
366
  raise e
@@ -447,7 +476,7 @@ def metrics(slurm_job_id: str) -> None:
447
476
  metrics_formatter.format_metrics()
448
477
 
449
478
  live.update(metrics_formatter.table)
450
- time.sleep(2)
479
+ time.sleep(1)
451
480
  except click.ClickException as e:
452
481
  raise e
453
482
  except Exception as e:
vec_inf/cli/_helper.py CHANGED
@@ -36,6 +36,43 @@ class LaunchResponseFormatter:
36
36
  self.model_name = model_name
37
37
  self.params = params
38
38
 
39
+ def _add_resource_allocation_details(self, table: Table) -> None:
40
+ """Add resource allocation details to the table."""
41
+ optional_fields = [
42
+ ("account", "Account"),
43
+ ("work_dir", "Working Directory"),
44
+ ("resource_type", "Resource Type"),
45
+ ("partition", "Partition"),
46
+ ("qos", "QoS"),
47
+ ]
48
+ for key, label in optional_fields:
49
+ if self.params.get(key):
50
+ table.add_row(label, self.params[key])
51
+
52
+ def _add_vllm_config(self, table: Table) -> None:
53
+ """Add vLLM configuration details to the table."""
54
+ if self.params.get("vllm_args"):
55
+ table.add_row("vLLM Arguments:", style="magenta")
56
+ for arg, value in self.params["vllm_args"].items():
57
+ table.add_row(f" {arg}:", str(value))
58
+
59
+ def _add_env_vars(self, table: Table) -> None:
60
+ """Add environment variable configuration details to the table."""
61
+ if self.params.get("env"):
62
+ table.add_row("Environment Variables", style="magenta")
63
+ for arg, value in self.params["env"].items():
64
+ table.add_row(f" {arg}:", str(value))
65
+
66
+ def _add_bind_paths(self, table: Table) -> None:
67
+ """Add bind path configuration details to the table."""
68
+ if self.params.get("bind"):
69
+ table.add_row("Bind Paths", style="magenta")
70
+ for path in self.params["bind"].split(","):
71
+ host = target = path
72
+ if ":" in path:
73
+ host, target = path.split(":")
74
+ table.add_row(f" {host}:", target)
75
+
39
76
  def format_table_output(self) -> Table:
40
77
  """Format output as rich Table.
41
78
 
@@ -59,16 +96,7 @@ class LaunchResponseFormatter:
59
96
  table.add_row("Vocabulary Size", self.params["vocab_size"])
60
97
 
61
98
  # Add resource allocation details
62
- if self.params.get("account"):
63
- table.add_row("Account", self.params["account"])
64
- if self.params.get("work_dir"):
65
- table.add_row("Working Directory", self.params["work_dir"])
66
- if self.params.get("resource_type"):
67
- table.add_row("Resource Type", self.params["resource_type"])
68
- if self.params.get("partition"):
69
- table.add_row("Partition", self.params["partition"])
70
- if self.params.get("qos"):
71
- table.add_row("QoS", self.params["qos"])
99
+ self._add_resource_allocation_details(table)
72
100
  table.add_row("Time Limit", self.params["time"])
73
101
  table.add_row("Num Nodes", self.params["num_nodes"])
74
102
  table.add_row("GPUs/Node", self.params["gpus_per_node"])
@@ -76,21 +104,18 @@ class LaunchResponseFormatter:
76
104
  table.add_row("Memory/Node", self.params["mem_per_node"])
77
105
 
78
106
  # Add job config details
107
+ if self.params.get("venv"):
108
+ table.add_row("Virtual Environment", self.params["venv"])
79
109
  table.add_row(
80
110
  "Model Weights Directory",
81
111
  str(Path(self.params["model_weights_parent_dir"], self.model_name)),
82
112
  )
83
113
  table.add_row("Log Directory", self.params["log_dir"])
84
114
 
85
- # Add vLLM configuration details
86
- table.add_row("vLLM Arguments:", style="magenta")
87
- for arg, value in self.params["vllm_args"].items():
88
- table.add_row(f" {arg}:", str(value))
89
-
90
- # Add Environment Variable Configuration Details
91
- table.add_row("Environment Variables", style="magenta")
92
- for arg, value in self.params["env"].items():
93
- table.add_row(f" {arg}:", str(value))
115
+ # Add configuration details
116
+ self._add_vllm_config(table)
117
+ self._add_env_vars(table)
118
+ self._add_bind_paths(table)
94
119
 
95
120
  return table
96
121
 
@@ -226,6 +251,62 @@ class StatusResponseFormatter:
226
251
  return table
227
252
 
228
253
 
254
+ class ListStatusDisplay:
255
+ """CLI Helper class for formatting a list of StatusResponse.
256
+
257
+ A formatter class that handles the presentation of multiple job statuses
258
+ in a table format.
259
+
260
+ Parameters
261
+ ----------
262
+ statuses : list[StatusResponse]
263
+ List of model status information
264
+ """
265
+
266
+ def __init__(
267
+ self,
268
+ job_ids: list[str],
269
+ statuses: list[StatusResponse],
270
+ json_mode: bool = False,
271
+ ):
272
+ self.job_ids = job_ids
273
+ self.statuses = statuses
274
+ self.json_mode = json_mode
275
+
276
+ self.table = Table(show_header=True, header_style="bold magenta")
277
+ self.table.add_column("Job ID")
278
+ self.table.add_column("Model Name")
279
+ self.table.add_column("Status", style="blue")
280
+ self.table.add_column("Base URL")
281
+
282
+ def display_multiple_status_output(self, console: Console) -> None:
283
+ """Format and display all model statuses.
284
+
285
+ Formats each model's status and adds it to the table.
286
+ """
287
+ if self.json_mode:
288
+ json_data = [
289
+ {
290
+ "job_id": status.model_name,
291
+ "model_name": status.model_name,
292
+ "model_status": status.server_status,
293
+ "base_url": status.base_url,
294
+ }
295
+ for status in self.statuses
296
+ ]
297
+ click.echo(json.dumps(json_data, indent=4))
298
+ return
299
+
300
+ for i, status in enumerate(self.statuses):
301
+ self.table.add_row(
302
+ self.job_ids[i],
303
+ status.model_name,
304
+ status.server_status,
305
+ status.base_url,
306
+ )
307
+ console.print(self.table)
308
+
309
+
229
310
  class MetricsResponseFormatter:
230
311
  """CLI Helper class for formatting MetricsResponse.
231
312
 
vec_inf/client/_helper.py CHANGED
@@ -31,6 +31,7 @@ from vec_inf.client._slurm_script_generator import (
31
31
  BatchSlurmScriptGenerator,
32
32
  SlurmScriptGenerator,
33
33
  )
34
+ from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME, IMAGE_PATH
34
35
  from vec_inf.client.config import ModelConfig
35
36
  from vec_inf.client.models import (
36
37
  BatchLaunchResponse,
@@ -195,23 +196,14 @@ class ModelLauncher:
195
196
  print(f"WARNING: Could not parse env var: {line}")
196
197
  return env_vars
197
198
 
198
- def _get_launch_params(self) -> dict[str, Any]:
199
- """Prepare launch parameters, set log dir, and validate required fields.
199
+ def _apply_cli_overrides(self, params: dict[str, Any]) -> None:
200
+ """Apply CLI argument overrides to params.
200
201
 
201
- Returns
202
- -------
203
- dict[str, Any]
204
- Dictionary of prepared launch parameters
205
-
206
- Raises
207
- ------
208
- MissingRequiredFieldsError
209
- If required fields are missing or tensor parallel size is not specified
210
- when using multiple GPUs
202
+ Parameters
203
+ ----------
204
+ params : dict[str, Any]
205
+ Dictionary of launch parameters to override
211
206
  """
212
- params = self.model_config.model_dump(exclude_none=True)
213
-
214
- # Override config defaults with CLI arguments
215
207
  if self.kwargs.get("vllm_args"):
216
208
  vllm_args = self._process_vllm_args(self.kwargs["vllm_args"])
217
209
  for key, value in vllm_args.items():
@@ -224,13 +216,29 @@ class ModelLauncher:
224
216
  params["env"][key] = str(value)
225
217
  del self.kwargs["env"]
226
218
 
219
+ if self.kwargs.get("bind") and params.get("bind"):
220
+ params["bind"] = f"{params['bind']},{self.kwargs['bind']}"
221
+ del self.kwargs["bind"]
222
+
227
223
  for key, value in self.kwargs.items():
228
224
  params[key] = value
229
225
 
230
- # Check for required fields without default vals, will raise an error if missing
231
- utils.check_required_fields(params)
226
+ def _validate_resource_allocation(self, params: dict[str, Any]) -> None:
227
+ """Validate resource allocation and parallelization settings.
232
228
 
233
- # Validate resource allocation and parallelization settings
229
+ Parameters
230
+ ----------
231
+ params : dict[str, Any]
232
+ Dictionary of launch parameters to validate
233
+
234
+ Raises
235
+ ------
236
+ MissingRequiredFieldsError
237
+ If tensor parallel size is not specified when using multiple GPUs
238
+ ValueError
239
+ If total # of GPUs requested is not a power of two
240
+ If mismatch between total # of GPUs requested and parallelization settings
241
+ """
234
242
  if (
235
243
  int(params["gpus_per_node"]) > 1
236
244
  and params["vllm_args"].get("--tensor-parallel-size") is None
@@ -251,19 +259,18 @@ class ModelLauncher:
251
259
  "Mismatch between total number of GPUs requested and parallelization settings"
252
260
  )
253
261
 
254
- # Convert gpus_per_node and resource_type to gres
255
- resource_type = params.get("resource_type")
256
- if resource_type:
257
- params["gres"] = f"gpu:{resource_type}:{params['gpus_per_node']}"
258
- else:
259
- params["gres"] = f"gpu:{params['gpus_per_node']}"
262
+ def _setup_log_files(self, params: dict[str, Any]) -> None:
263
+ """Set up log directory and file paths.
260
264
 
261
- # Create log directory
265
+ Parameters
266
+ ----------
267
+ params : dict[str, Any]
268
+ Dictionary of launch parameters to set up log files
269
+ """
262
270
  params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser()
263
271
  params["log_dir"].mkdir(parents=True, exist_ok=True)
264
272
  params["src_dir"] = SRC_DIR
265
273
 
266
- # Construct slurm log file paths
267
274
  params["out_file"] = (
268
275
  f"{params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out"
269
276
  )
@@ -274,6 +281,35 @@ class ModelLauncher:
274
281
  f"{params['log_dir']}/{self.model_name}.$SLURM_JOB_ID/{self.model_name}.$SLURM_JOB_ID.json"
275
282
  )
276
283
 
284
+ def _get_launch_params(self) -> dict[str, Any]:
285
+ """Prepare launch parameters, set log dir, and validate required fields.
286
+
287
+ Returns
288
+ -------
289
+ dict[str, Any]
290
+ Dictionary of prepared launch parameters
291
+ """
292
+ params = self.model_config.model_dump(exclude_none=True)
293
+
294
+ # Override config defaults with CLI arguments
295
+ self._apply_cli_overrides(params)
296
+
297
+ # Check for required fields without default vals, will raise an error if missing
298
+ utils.check_required_fields(params)
299
+
300
+ # Validate resource allocation and parallelization settings
301
+ self._validate_resource_allocation(params)
302
+
303
+ # Convert gpus_per_node and resource_type to gres
304
+ resource_type = params.get("resource_type")
305
+ if resource_type:
306
+ params["gres"] = f"gpu:{resource_type}:{params['gpus_per_node']}"
307
+ else:
308
+ params["gres"] = f"gpu:{params['gpus_per_node']}"
309
+
310
+ # Setup log files
311
+ self._setup_log_files(params)
312
+
277
313
  # Convert path to string for JSON serialization
278
314
  for field in params:
279
315
  if field in ["vllm_args", "env"]:
@@ -332,6 +368,10 @@ class ModelLauncher:
332
368
  job_log_dir / f"{self.model_name}.{self.slurm_job_id}.sbatch"
333
369
  )
334
370
 
371
+ # Replace venv with image path if using container
372
+ if self.params["venv"] == CONTAINER_MODULE_NAME:
373
+ self.params["venv"] = IMAGE_PATH
374
+
335
375
  with job_json.open("w") as file:
336
376
  json.dump(self.params, file, indent=4)
337
377
 
@@ -429,16 +469,15 @@ class BatchModelLauncher:
429
469
  If required fields are missing or tensor parallel size is not specified
430
470
  when using multiple GPUs
431
471
  """
432
- params: dict[str, Any] = {
433
- "models": {},
472
+ common_params: dict[str, Any] = {
434
473
  "slurm_job_name": self.slurm_job_name,
435
474
  "src_dir": str(SRC_DIR),
436
475
  "account": account,
437
476
  "work_dir": work_dir,
438
477
  }
439
478
 
440
- # Check for required fields without default vals, will raise an error if missing
441
- utils.check_required_fields(params)
479
+ params: dict[str, Any] = common_params.copy()
480
+ params["models"] = {}
442
481
 
443
482
  for i, (model_name, config) in enumerate(self.model_configs.items()):
444
483
  params["models"][model_name] = config.model_dump(exclude_none=True)
@@ -515,6 +554,16 @@ class BatchModelLauncher:
515
554
  raise ValueError(
516
555
  f"Mismatch found for {arg}: {params[arg]} != {params['models'][model_name][arg]}, check your configuration"
517
556
  )
557
+ # Check for required fields and return environment variable overrides
558
+ env_overrides = utils.check_required_fields(
559
+ {**params["models"][model_name], **common_params}
560
+ )
561
+
562
+ for arg, value in env_overrides.items():
563
+ if arg in common_params:
564
+ params[arg] = value
565
+ else:
566
+ params["models"][model_name][arg] = value
518
567
 
519
568
  return params
520
569
 
@@ -678,7 +727,7 @@ class ModelStatusMonitor:
678
727
  Basic status information for the job
679
728
  """
680
729
  try:
681
- job_name = self.job_status["JobName"]
730
+ job_name = self.job_status["JobName"].removesuffix("-vec-inf")
682
731
  job_state = self.job_status["JobState"]
683
732
  except KeyError:
684
733
  job_name = "UNAVAILABLE"
@@ -14,6 +14,7 @@ from vec_inf.client._slurm_templates import (
14
14
  BATCH_SLURM_SCRIPT_TEMPLATE,
15
15
  SLURM_SCRIPT_TEMPLATE,
16
16
  )
17
+ from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME
17
18
 
18
19
 
19
20
  class SlurmScriptGenerator:
@@ -32,24 +33,35 @@ class SlurmScriptGenerator:
32
33
  def __init__(self, params: dict[str, Any]):
33
34
  self.params = params
34
35
  self.is_multinode = int(self.params["num_nodes"]) > 1
35
- self.use_container = (
36
- self.params["venv"] == "singularity" or self.params["venv"] == "apptainer"
36
+ self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
37
+ self.additional_binds = (
38
+ f",{self.params['bind']}" if self.params.get("bind") else ""
37
39
  )
38
- self.additional_binds = self.params.get("bind", "")
39
- if self.additional_binds:
40
- self.additional_binds = f" --bind {self.additional_binds}"
41
40
  self.model_weights_path = str(
42
41
  Path(self.params["model_weights_parent_dir"], self.params["model_name"])
43
42
  )
43
+ self.env_str = self._generate_env_str()
44
+
45
+ def _generate_env_str(self) -> str:
46
+ """Generate the environment variables string for the Slurm script.
47
+
48
+ Returns
49
+ -------
50
+ str
51
+ Formatted env vars string for container or shell export commands.
52
+ """
44
53
  env_dict: dict[str, str] = self.params.get("env", {})
45
- # Create string of environment variables
46
- self.env_str = ""
47
- for key, val in env_dict.items():
48
- if len(self.env_str) == 0:
49
- self.env_str = "--env "
50
- else:
51
- self.env_str += ","
52
- self.env_str += key + "=" + val
54
+
55
+ if not env_dict:
56
+ return ""
57
+
58
+ if self.use_container:
59
+ # Format for container: --env KEY1=VAL1,KEY2=VAL2
60
+ env_pairs = [f"{key}={val}" for key, val in env_dict.items()]
61
+ return f"--env {','.join(env_pairs)}"
62
+ # Format for shell: export KEY1=VAL1\nexport KEY2=VAL2
63
+ export_lines = [f"export {key}={val}" for key, val in env_dict.items()]
64
+ return "\n".join(export_lines)
53
65
 
54
66
  def _generate_script_content(self) -> str:
55
67
  """Generate the complete Slurm script content.
@@ -77,6 +89,8 @@ class SlurmScriptGenerator:
77
89
  for arg, value in SLURM_JOB_CONFIG_ARGS.items():
78
90
  if self.params.get(value):
79
91
  shebang.append(f"#SBATCH --{arg}={self.params[value]}")
92
+ if value == "model_name":
93
+ shebang[-1] += "-vec-inf"
80
94
  if self.is_multinode:
81
95
  shebang += SLURM_SCRIPT_TEMPLATE["shebang"]["multinode"]
82
96
  return "\n".join(shebang)
@@ -95,7 +109,17 @@ class SlurmScriptGenerator:
95
109
  server_script = ["\n"]
96
110
  if self.use_container:
97
111
  server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
98
- server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["env_vars"]))
112
+ server_script.append(
113
+ SLURM_SCRIPT_TEMPLATE["bind_path"].format(
114
+ model_weights_path=self.model_weights_path,
115
+ additional_binds=self.additional_binds,
116
+ )
117
+ )
118
+ else:
119
+ server_script.append(
120
+ SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
121
+ )
122
+ server_script.append(self.env_str)
99
123
  server_script.append(
100
124
  SLURM_SCRIPT_TEMPLATE["imports"].format(src_dir=self.params["src_dir"])
101
125
  )
@@ -108,10 +132,14 @@ class SlurmScriptGenerator:
108
132
  "CONTAINER_PLACEHOLDER",
109
133
  SLURM_SCRIPT_TEMPLATE["container_command"].format(
110
134
  model_weights_path=self.model_weights_path,
111
- additional_binds=self.additional_binds,
112
135
  env_str=self.env_str,
113
136
  ),
114
137
  )
138
+ else:
139
+ server_setup_str = server_setup_str.replace(
140
+ "CONTAINER_PLACEHOLDER",
141
+ "\\",
142
+ )
115
143
  else:
116
144
  server_setup_str = "\n".join(
117
145
  SLURM_SCRIPT_TEMPLATE["server_setup"]["single_node"]
@@ -141,14 +169,10 @@ class SlurmScriptGenerator:
141
169
  launcher_script.append(
142
170
  SLURM_SCRIPT_TEMPLATE["container_command"].format(
143
171
  model_weights_path=self.model_weights_path,
144
- additional_binds=self.additional_binds,
145
172
  env_str=self.env_str,
146
173
  )
147
174
  )
148
- else:
149
- launcher_script.append(
150
- SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
151
- )
175
+
152
176
  launcher_script.append(
153
177
  "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
154
178
  model_weights_path=self.model_weights_path,
@@ -194,15 +218,13 @@ class BatchSlurmScriptGenerator:
194
218
  def __init__(self, params: dict[str, Any]):
195
219
  self.params = params
196
220
  self.script_paths: list[Path] = []
197
- self.use_container = (
198
- self.params["venv"] == "singularity" or self.params["venv"] == "apptainer"
199
- )
221
+ self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
200
222
  for model_name in self.params["models"]:
201
- self.params["models"][model_name]["additional_binds"] = ""
202
- if self.params["models"][model_name].get("bind"):
203
- self.params["models"][model_name]["additional_binds"] = (
204
- f" --bind {self.params['models'][model_name]['bind']}"
205
- )
223
+ self.params["models"][model_name]["additional_binds"] = (
224
+ f",{self.params['models'][model_name]['bind']}"
225
+ if self.params["models"][model_name].get("bind")
226
+ else ""
227
+ )
206
228
  self.params["models"][model_name]["model_weights_path"] = str(
207
229
  Path(
208
230
  self.params["models"][model_name]["model_weights_parent_dir"],
@@ -242,7 +264,12 @@ class BatchSlurmScriptGenerator:
242
264
  script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["shebang"])
243
265
  if self.use_container:
244
266
  script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
245
- script_content.append("\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["env_vars"]))
267
+ script_content.append(
268
+ BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
269
+ model_weights_path=model_params["model_weights_path"],
270
+ additional_binds=model_params["additional_binds"],
271
+ )
272
+ )
246
273
  script_content.append(
247
274
  "\n".join(
248
275
  BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["server_address_setup"]
@@ -260,7 +287,6 @@ class BatchSlurmScriptGenerator:
260
287
  script_content.append(
261
288
  BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format(
262
289
  model_weights_path=model_params["model_weights_path"],
263
- additional_binds=model_params["additional_binds"],
264
290
  )
265
291
  )
266
292
  script_content.append(
@@ -304,6 +330,8 @@ class BatchSlurmScriptGenerator:
304
330
  model_params = self.params["models"][model_name]
305
331
  if model_params.get(value) and value not in ["out_file", "err_file"]:
306
332
  shebang.append(f"#SBATCH --{arg}={model_params[value]}")
333
+ if value == "model_name":
334
+ shebang[-1] += "-vec-inf"
307
335
  shebang[-1] += "\n"
308
336
  shebang.append(BATCH_SLURM_SCRIPT_TEMPLATE["hetjob"])
309
337
  # Remove the last hetjob line