vec-inf 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vec_inf/client/_helper.py CHANGED
@@ -17,9 +17,10 @@ import requests
 import vec_inf.client._utils as utils
 from vec_inf.client._client_vars import (
     BATCH_MODE_REQUIRED_MATCHING_ARGS,
+    ENGINE_SHORT_TO_LONG_MAP,
     KEY_METRICS,
     SRC_DIR,
-    VLLM_SHORT_TO_LONG_MAP,
+    SUPPORTED_ENGINES,
 )
 from vec_inf.client._exceptions import (
     MissingRequiredFieldsError,
@@ -63,6 +64,7 @@ class ModelLauncher:
         self.slurm_job_id = ""
         self.slurm_script_path = Path("")
         self.model_config = self._get_model_configuration(self.kwargs.get("config"))
+        self.engine = ""
         self.params = self._get_launch_params()
 
     def _warn(self, message: str) -> None:
@@ -137,32 +139,38 @@ class ModelLauncher:
                 f"not found at expected path '{model_weights_path}'"
             )
 
-    def _process_vllm_args(self, arg_string: str) -> dict[str, Any]:
-        """Process the vllm_args string into a dictionary.
+    def _process_engine_args(
+        self, arg_string: str, engine_choice: str
+    ) -> dict[str, Any]:
+        """Process the engine_args string into a dictionary.
 
         Parameters
         ----------
         arg_string : str
-            Comma-separated string of vLLM arguments
+            Comma-separated string of inference engine arguments
 
         Returns
         -------
         dict[str, Any]
-            Processed vLLM arguments as key-value pairs
+            Processed inference engine arguments as key-value pairs
         """
-        vllm_args: dict[str, str | bool] = {}
+        engine_args: dict[str, str | bool] = {}
+        engine_arg_map = ENGINE_SHORT_TO_LONG_MAP[engine_choice]
+
         for arg in arg_string.split(","):
             if "=" in arg:
                 key, value = arg.split("=")
-                if key.strip() in VLLM_SHORT_TO_LONG_MAP:
-                    key = VLLM_SHORT_TO_LONG_MAP[key.strip()]
-                vllm_args[key.strip()] = value.strip()
+                if key.strip() in engine_arg_map:
+                    key = engine_arg_map[key.strip()]
+                engine_args[key.strip()] = value.strip()
             elif "-O" in arg.strip():
-                key = VLLM_SHORT_TO_LONG_MAP["-O"]
-                vllm_args[key] = arg.strip()[2:].strip()
+                if engine_choice != "vllm":
+                    raise ValueError("-O is only supported for vLLM")
+                key = engine_arg_map["-O"]
+                engine_args[key] = arg.strip()[2:].strip()
             else:
-                vllm_args[arg.strip()] = True
-        return vllm_args
+                engine_args[arg.strip()] = True
+        return engine_args
 
     def _process_env_vars(self, env_arg: str) -> dict[str, str]:
         """Process the env string into a dictionary of environment variables.
@@ -196,6 +204,63 @@ class ModelLauncher:
                 print(f"WARNING: Could not parse env var: {line}")
         return env_vars
 
+    def _engine_check_override(self, params: dict[str, Any]) -> None:
+        """Check for engine override in CLI args and warn user.
+
+        Parameters
+        ----------
+        params : dict[str, Any]
+            Dictionary of launch parameters to check
+        """
+
+        def overwrite_engine_args(params: dict[str, Any]) -> None:
+            engine_args = self._process_engine_args(
+                self.kwargs[f"{self.engine}_args"], self.engine
+            )
+            for key, value in engine_args.items():
+                params["engine_args"][key] = value
+            del self.kwargs[f"{self.engine}_args"]
+
+        # Infer engine name from engine-specific args if provided
+        extracted_engine = ""
+        for engine in SUPPORTED_ENGINES:
+            if self.kwargs.get(f"{engine}_args"):
+                if not extracted_engine:
+                    extracted_engine = engine
+                else:
+                    raise ValueError(
+                        "Cannot provide engine-specific args for multiple engines, please choose one"
+                    )
+        # Check for mismatch between provided engine arg and engine-specific args
+        input_engine = self.kwargs.get("engine", "")
+
+        if input_engine and extracted_engine:
+            if input_engine != extracted_engine:
+                raise ValueError(
+                    f"Mismatch between provided engine '{input_engine}' and engine-specific args '{extracted_engine}'"
+                )
+            self.engine = input_engine
+            params["engine_args"] = params[f"{self.engine}_args"]
+            overwrite_engine_args(params)
+        elif input_engine:
+            # Only engine arg in CLI, use default engine args from config
+            self.engine = input_engine
+            params["engine_args"] = params[f"{self.engine}_args"]
+        elif extracted_engine:
+            # Only engine-specific args in CLI, infer engine and warn user
+            self.engine = extracted_engine
+            params["engine_inferred"] = True
+            params["engine_args"] = params[f"{self.engine}_args"]
+            overwrite_engine_args(params)
+        else:
+            # No engine-related args in CLI, use defaults from config
+            self.engine = params.get("engine", "vllm")
+            params["engine_args"] = params[f"{self.engine}_args"]
+
+        # Remove $ENGINE_NAME_args from params as they won't get populated to sjob json.
+        for engine in SUPPORTED_ENGINES:
+            del params[f"{engine}_args"]
+
     def _apply_cli_overrides(self, params: dict[str, Any]) -> None:
         """Apply CLI argument overrides to params.
 
@@ -204,11 +269,7 @@ class ModelLauncher:
         params : dict[str, Any]
             Dictionary of launch parameters to override
         """
-        if self.kwargs.get("vllm_args"):
-            vllm_args = self._process_vllm_args(self.kwargs["vllm_args"])
-            for key, value in vllm_args.items():
-                params["vllm_args"][key] = value
-            del self.kwargs["vllm_args"]
+        self._engine_check_override(params)
 
         if self.kwargs.get("env"):
             env_vars = self._process_env_vars(self.kwargs["env"])
@@ -241,7 +302,7 @@ class ModelLauncher:
         """
         if (
             int(params["gpus_per_node"]) > 1
-            and params["vllm_args"].get("--tensor-parallel-size") is None
+            and params["engine_args"].get("--tensor-parallel-size") is None
         ):
             raise MissingRequiredFieldsError(
                 "--tensor-parallel-size is required when gpus_per_node > 1"
@@ -252,8 +313,8 @@ class ModelLauncher:
             raise ValueError("Total number of GPUs requested must be a power of two")
 
         total_parallel_sizes = int(
-            params["vllm_args"].get("--tensor-parallel-size", "1")
-        ) * int(params["vllm_args"].get("--pipeline-parallel-size", "1"))
+            params["engine_args"].get("--tensor-parallel-size", "1")
+        ) * int(params["engine_args"].get("--pipeline-parallel-size", "1"))
         if total_gpus_requested != total_parallel_sizes:
             raise ValueError(
                 "Mismatch between total number of GPUs requested and parallelization settings"
@@ -312,7 +373,8 @@ class ModelLauncher:
 
         # Convert path to string for JSON serialization
         for field in params:
-            if field in ["vllm_args", "env"]:
+            # Keep structured fields (dicts/bools) intact
+            if field in ["engine_args", "env", "engine_inferred"]:
                 continue
             params[field] = str(params[field])
 
@@ -370,7 +432,7 @@ class ModelLauncher:
 
         # Replace venv with image path if using container
         if self.params["venv"] == CONTAINER_MODULE_NAME:
-            self.params["venv"] = IMAGE_PATH
+            self.params["venv"] = IMAGE_PATH[self.params["engine"]]
 
         with job_json.open("w") as file:
             json.dump(self.params, file, indent=4)
@@ -453,6 +515,53 @@ class BatchModelLauncher:
 
         return model_configs_dict
 
+    def _validate_resource_and_parallel_settings(
+        self,
+        config: ModelConfig,
+        model_engine_args: dict[str, Any] | None,
+        model_name: str,
+    ) -> None:
+        """Validate resource allocation and parallelization settings for each model.
+
+        Parameters
+        ----------
+        config : ModelConfig
+            Configuration of the model to validate
+        model_engine_args : dict[str, Any] | None
+            Inference engine arguments of the model to validate
+        model_name : str
+            Name of the model to validate
+
+        Raises
+        ------
+        MissingRequiredFieldsError
+            If tensor parallel size is not specified when using multiple GPUs
+        ValueError
+            If total # of GPUs requested is not a power of two
+            If mismatch between total # of GPUs requested and parallelization settings
+        """
+        if (
+            int(config.gpus_per_node) > 1
+            and (model_engine_args or {}).get("--tensor-parallel-size") is None
+        ):
+            raise MissingRequiredFieldsError(
+                f"--tensor-parallel-size is required when gpus_per_node > 1, check your configuration for {model_name}"
+            )
+
+        total_gpus_requested = int(config.gpus_per_node) * int(config.num_nodes)
+        if not utils.is_power_of_two(total_gpus_requested):
+            raise ValueError(
+                f"Total number of GPUs requested must be a power of two, check your configuration for {model_name}"
+            )
+
+        total_parallel_sizes = int(
+            (model_engine_args or {}).get("--tensor-parallel-size", "1")
+        ) * int((model_engine_args or {}).get("--pipeline-parallel-size", "1"))
+        if total_gpus_requested != total_parallel_sizes:
+            raise ValueError(
+                f"Mismatch between total number of GPUs requested and parallelization settings, check your configuration for {model_name}"
+            )
+
     def _get_launch_params(
         self, account: Optional[str] = None, work_dir: Optional[str] = None
     ) -> dict[str, Any]:
@@ -469,43 +578,29 @@ class BatchModelLauncher:
             If required fields are missing or tensor parallel size is not specified
             when using multiple GPUs
         """
-        params: dict[str, Any] = {
-            "models": {},
+        common_params: dict[str, Any] = {
             "slurm_job_name": self.slurm_job_name,
             "src_dir": str(SRC_DIR),
             "account": account,
             "work_dir": work_dir,
         }
 
-        # Check for required fields without default vals, will raise an error if missing
-        utils.check_required_fields(params)
+        params: dict[str, Any] = common_params.copy()
+        params["models"] = {}
 
        for i, (model_name, config) in enumerate(self.model_configs.items()):
             params["models"][model_name] = config.model_dump(exclude_none=True)
             params["models"][model_name]["het_group_id"] = i
 
-            # Validate resource allocation and parallelization settings
-            if (
-                int(config.gpus_per_node) > 1
-                and (config.vllm_args or {}).get("--tensor-parallel-size") is None
-            ):
-                raise MissingRequiredFieldsError(
-                    f"--tensor-parallel-size is required when gpus_per_node > 1, check your configuration for {model_name}"
-                )
+            model_engine_args = getattr(config, f"{config.engine}_args", None)
+            params["models"][model_name]["engine_args"] = model_engine_args
+            for engine in SUPPORTED_ENGINES:
+                del params["models"][model_name][f"{engine}_args"]
 
-            total_gpus_requested = int(config.gpus_per_node) * int(config.num_nodes)
-            if not utils.is_power_of_two(total_gpus_requested):
-                raise ValueError(
-                    f"Total number of GPUs requested must be a power of two, check your configuration for {model_name}"
-                )
-
-            total_parallel_sizes = int(
-                (config.vllm_args or {}).get("--tensor-parallel-size", "1")
-            ) * int((config.vllm_args or {}).get("--pipeline-parallel-size", "1"))
-            if total_gpus_requested != total_parallel_sizes:
-                raise ValueError(
-                    f"Mismatch between total number of GPUs requested and parallelization settings, check your configuration for {model_name}"
-                )
+            # Validate resource allocation and parallelization settings
+            self._validate_resource_and_parallel_settings(
+                config, model_engine_args, model_name
+            )
 
             # Convert gpus_per_node and resource_type to gres
             params["models"][model_name]["gres"] = (
@@ -555,6 +650,16 @@ class BatchModelLauncher:
                    raise ValueError(
                        f"Mismatch found for {arg}: {params[arg]} != {params['models'][model_name][arg]}, check your configuration"
                    )
+            # Check for required fields and return environment variable overrides
+            env_overrides = utils.check_required_fields(
+                {**params["models"][model_name], **common_params}
+            )
+
+            for arg, value in env_overrides.items():
+                if arg in common_params:
+                    params[arg] = value
+                else:
+                    params["models"][model_name][arg] = value
 
         return params
 
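In the {**model_params, **common_params} merge passed to check_required_fields, later keys win, so common_params values shadow per-model ones in the merged view while model-only keys survive. A minimal illustration with made-up field values:

model_params = {"account": "model-acct", "gpus_per_node": "4"}
common_params = {"account": "shared-acct", "work_dir": "/scratch"}

merged = {**model_params, **common_params}
print(merged["account"])        # 'shared-acct' -- common_params overrides
print(merged["gpus_per_node"])  # '4'           -- model-only keys survive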
@@ -718,7 +823,7 @@ class ModelStatusMonitor:
             Basic status information for the job
         """
         try:
-            job_name = self.job_status["JobName"]
+            job_name = self.job_status["JobName"].removesuffix("-vec-inf")
             job_state = self.job_status["JobState"]
         except KeyError:
             job_name = "UNAVAILABLE"
@@ -1,7 +1,7 @@
-"""Class for generating Slurm scripts to run vLLM servers.
+"""Class for generating Slurm scripts to run inference servers.
 
-This module provides functionality to generate Slurm scripts for running vLLM servers
-in both single-node and multi-node configurations.
+This module provides functionality to generate Slurm scripts for running inference
+servers in both single-node and multi-node configurations.
 """
 
 from datetime import datetime
@@ -14,11 +14,11 @@ from vec_inf.client._slurm_templates import (
     BATCH_SLURM_SCRIPT_TEMPLATE,
     SLURM_SCRIPT_TEMPLATE,
 )
-from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME
+from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME, IMAGE_PATH
 
 
 class SlurmScriptGenerator:
-    """A class to generate Slurm scripts for running vLLM servers.
+    """A class to generate Slurm scripts for running inference servers.
 
     This class handles the generation of Slurm scripts for both single-node and
     multi-node configurations, supporting different virtualization environments
@@ -32,11 +32,12 @@ class SlurmScriptGenerator:
 
     def __init__(self, params: dict[str, Any]):
         self.params = params
+        self.engine = params.get("engine", "vllm")
         self.is_multinode = int(self.params["num_nodes"]) > 1
         self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
-        self.additional_binds = self.params.get("bind", "")
-        if self.additional_binds:
-            self.additional_binds = f" --bind {self.additional_binds}"
+        self.additional_binds = (
+            f",{self.params['bind']}" if self.params.get("bind") else ""
+        )
         self.model_weights_path = str(
             Path(self.params["model_weights_parent_dir"], self.params["model_name"])
         )
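
The bind string changed shape here: previously a standalone " --bind <paths>" flag appended to the container command, now a ",<paths>" suffix spliced into a single bind list by the bind_path template. An illustrative sketch; the real SLURM_SCRIPT_TEMPLATE["bind_path"] string may differ:

params = {"bind": "/datasets:/datasets,/home/user/cache"}

additional_binds = f",{params['bind']}" if params.get("bind") else ""
bind_path = "export BIND_PATHS={model_weights_path}{additional_binds}".format(
    model_weights_path="/model-weights/my-model",  # hypothetical path
    additional_binds=additional_binds,
)
print(bind_path)
# export BIND_PATHS=/model-weights/my-model,/datasets:/datasets,/home/user/cache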
@@ -89,6 +90,8 @@ class SlurmScriptGenerator:
         for arg, value in SLURM_JOB_CONFIG_ARGS.items():
             if self.params.get(value):
                 shebang.append(f"#SBATCH --{arg}={self.params[value]}")
+                if value == "model_name":
+                    shebang[-1] += "-vec-inf"
         if self.is_multinode:
             shebang += SLURM_SCRIPT_TEMPLATE["shebang"]["multinode"]
         return "\n".join(shebang)
@@ -107,7 +110,12 @@ class SlurmScriptGenerator:
         server_script = ["\n"]
         if self.use_container:
             server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
-            server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_env_vars"]))
+            server_script.append(
+                SLURM_SCRIPT_TEMPLATE["bind_path"].format(
+                    model_weights_path=self.model_weights_path,
+                    additional_binds=self.additional_binds,
+                )
+            )
         else:
             server_script.append(
                 SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
@@ -116,17 +124,17 @@ class SlurmScriptGenerator:
         server_script.append(
             SLURM_SCRIPT_TEMPLATE["imports"].format(src_dir=self.params["src_dir"])
         )
-        if self.is_multinode:
+
+        if self.is_multinode and self.engine == "vllm":
             server_setup_str = "\n".join(
-                SLURM_SCRIPT_TEMPLATE["server_setup"]["multinode"]
+                SLURM_SCRIPT_TEMPLATE["server_setup"]["multinode_vllm"]
             ).format(gpus_per_node=self.params["gpus_per_node"])
             if self.use_container:
                 server_setup_str = server_setup_str.replace(
                     "CONTAINER_PLACEHOLDER",
                     SLURM_SCRIPT_TEMPLATE["container_command"].format(
-                        model_weights_path=self.model_weights_path,
-                        additional_binds=self.additional_binds,
                         env_str=self.env_str,
+                        image_path=IMAGE_PATH[self.engine],
                     ),
                 )
             else:
@@ -134,12 +142,16 @@ class SlurmScriptGenerator:
                 server_setup_str = server_setup_str.replace(
                     "CONTAINER_PLACEHOLDER",
                     "\\",
                 )
+        elif self.is_multinode and self.engine == "sglang":
+            server_setup_str = "\n".join(
+                SLURM_SCRIPT_TEMPLATE["server_setup"]["multinode_sglang"]
+            )
         else:
             server_setup_str = "\n".join(
                 SLURM_SCRIPT_TEMPLATE["server_setup"]["single_node"]
             )
         server_script.append(server_setup_str)
-        server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["find_vllm_port"]))
+        server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["find_server_port"]))
         server_script.append(
             "\n".join(SLURM_SCRIPT_TEMPLATE["write_to_json"]).format(
                 log_dir=self.params["log_dir"], model_name=self.params["model_name"]
@@ -148,39 +160,85 @@ class SlurmScriptGenerator:
         return "\n".join(server_script)
 
     def _generate_launch_cmd(self) -> str:
-        """Generate the vLLM server launch command.
+        """Generate the inference server launch command.
 
-        Creates the command to launch the vLLM server, handling different virtualization
-        environments (venv or singularity/apptainer).
+        Creates the command to launch the inference server, handling different
+        virtualization environments (venv or singularity/apptainer).
 
         Returns
         -------
         str
             Server launch command.
         """
-        launcher_script = ["\n"]
+        if self.is_multinode and self.engine == "sglang":
+            return self._generate_multinode_sglang_launch_cmd()
+
+        launch_cmd = ["\n"]
         if self.use_container:
-            launcher_script.append(
+            launch_cmd.append(
                 SLURM_SCRIPT_TEMPLATE["container_command"].format(
-                    model_weights_path=self.model_weights_path,
-                    additional_binds=self.additional_binds,
                     env_str=self.env_str,
+                    image_path=IMAGE_PATH[self.engine],
                 )
             )
 
-        launcher_script.append(
-            "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
+        launch_cmd.append(
+            "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"][self.engine]).format(  # type: ignore[literal-required]
                 model_weights_path=self.model_weights_path,
                 model_name=self.params["model_name"],
             )
         )
 
-        for arg, value in self.params["vllm_args"].items():
+        for arg, value in self.params["engine_args"].items():
             if isinstance(value, bool):
-                launcher_script.append(f"    {arg} \\")
+                launch_cmd.append(f"    {arg} \\")
             else:
-                launcher_script.append(f"    {arg} {value} \\")
-        return "\n".join(launcher_script)
+                launch_cmd.append(f"    {arg} {value} \\")
+
+        # A known bug in vLLM requires setting backend to ray for multi-node
+        # Remove this when the bug is fixed
+        if self.is_multinode:
+            launch_cmd.append("    --distributed-executor-backend ray \\")
+
+        return "\n".join(launch_cmd).rstrip(" \\")
+
+    def _generate_multinode_sglang_launch_cmd(self) -> str:
+        """Generate the launch command for multi-node sglang setup.
+
+        Returns
+        -------
+        str
+            Multi-node sglang launch command.
+        """
+        launch_cmd = "\n" + "\n".join(
+            SLURM_SCRIPT_TEMPLATE["launch_cmd"]["sglang_multinode"]
+        ).format(
+            num_nodes=self.params["num_nodes"],
+            model_weights_path=self.model_weights_path,
+            model_name=self.params["model_name"],
+        )
+
+        container_placeholder = "\\"
+        if self.use_container:
+            container_placeholder = SLURM_SCRIPT_TEMPLATE["container_command"].format(
+                env_str=self.env_str,
+                image_path=IMAGE_PATH[self.engine],
+            )
+        launch_cmd = launch_cmd.replace(
+            "CONTAINER_PLACEHOLDER",
+            container_placeholder,
+        )
+
+        engine_arg_str = ""
+        for arg, value in self.params["engine_args"].items():
+            if isinstance(value, bool):
+                engine_arg_str += f"    {arg} \\\n"
+            else:
+                engine_arg_str += f"    {arg} {value} \\\n"
+
+        return launch_cmd.replace(
+            "SGLANG_ARGS_PLACEHOLDER", engine_arg_str.rstrip("\\\n")
+        )
 
     def write_to_log_dir(self) -> Path:
         """Write the generated Slurm script to the log directory.
@@ -207,7 +265,7 @@ class BatchSlurmScriptGenerator:
     """A class to generate Slurm scripts for batch mode.
 
     This class handles the generation of Slurm scripts for batch mode, which
-    launches multiple vLLM servers with different configurations in parallel.
+    launches multiple inference servers with different configurations in parallel.
     """
 
     def __init__(self, params: dict[str, Any]):
@@ -215,11 +273,11 @@ class BatchSlurmScriptGenerator:
         self.script_paths: list[Path] = []
         self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
         for model_name in self.params["models"]:
-            self.params["models"][model_name]["additional_binds"] = ""
-            if self.params["models"][model_name].get("bind"):
-                self.params["models"][model_name]["additional_binds"] = (
-                    f" --bind {self.params['models'][model_name]['bind']}"
-                )
+            self.params["models"][model_name]["additional_binds"] = (
+                f",{self.params['models'][model_name]['bind']}"
+                if self.params["models"][model_name].get("bind")
+                else ""
+            )
             self.params["models"][model_name]["model_weights_path"] = str(
                 Path(
                     self.params["models"][model_name]["model_weights_parent_dir"],
@@ -241,7 +299,7 @@ class BatchSlurmScriptGenerator:
         return script_path
 
     def _generate_model_launch_script(self, model_name: str) -> Path:
-        """Generate the bash script for launching individual vLLM servers.
+        """Generate the bash script for launching individual inference servers.
 
         Parameters
         ----------
@@ -251,7 +309,7 @@ class BatchSlurmScriptGenerator:
         Returns
         -------
         Path
-            The bash script path for launching the vLLM server.
+            The bash script path for launching the inference server.
         """
         # Generate the bash script content
         script_content = []
@@ -259,7 +317,12 @@ class BatchSlurmScriptGenerator:
         script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["shebang"])
         if self.use_container:
             script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
-            script_content.append("\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["env_vars"]))
+            script_content.append(
+                BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
+                    model_weights_path=model_params["model_weights_path"],
+                    additional_binds=model_params["additional_binds"],
+                )
+            )
         script_content.append(
             "\n".join(
                 BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["server_address_setup"]
@@ -276,22 +339,23 @@ class BatchSlurmScriptGenerator:
         if self.use_container:
             script_content.append(
                 BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format(
-                    model_weights_path=model_params["model_weights_path"],
-                    additional_binds=model_params["additional_binds"],
+                    image_path=IMAGE_PATH[model_params["engine"]],
                 )
             )
         script_content.append(
-            "\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format(
+            "\n".join(
+                BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"][model_params["engine"]]
+            ).format(
                 model_weights_path=model_params["model_weights_path"],
                 model_name=model_name,
             )
         )
-        for arg, value in model_params["vllm_args"].items():
+        for arg, value in model_params["engine_args"].items():
             if isinstance(value, bool):
                 script_content.append(f"    {arg} \\")
             else:
                 script_content.append(f"    {arg} {value} \\")
-        script_content[-1] = script_content[-1].replace("\\", "")
+        script_content[-1] = script_content[-1].rstrip(" \\")
         # Write the bash script to the log directory
         launch_script_path = self._write_to_log_dir(
             script_content, f"launch_{model_name}.sh"
@@ -321,6 +385,8 @@ class BatchSlurmScriptGenerator:
             model_params = self.params["models"][model_name]
             if model_params.get(value) and value not in ["out_file", "err_file"]:
                 shebang.append(f"#SBATCH --{arg}={model_params[value]}")
+                if value == "model_name":
+                    shebang[-1] += "-vec-inf"
             shebang[-1] += "\n"
         shebang.append(BATCH_SLURM_SCRIPT_TEMPLATE["hetjob"])
         # Remove the last hetjob line
@@ -328,12 +394,12 @@ class BatchSlurmScriptGenerator:
         return "\n".join(shebang)
 
     def generate_batch_slurm_script(self) -> Path:
-        """Generate the Slurm script for launching multiple vLLM servers in batch mode.
+        """Generate the Slurm script for launching multiple inference servers in batch.
 
         Returns
        -------
         Path
-            The Slurm script for launching multiple vLLM servers in batch mode.
+            The Slurm script for launching multiple inference servers in batch.
         """
         script_content = []