vec-inf 0.7.3__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vec_inf/cli/_cli.py CHANGED
@@ -132,10 +132,20 @@ def cli() -> None:
     type=str,
     help="Path to parent directory containing model weights",
 )
+@click.option(
+    "--engine",
+    type=str,
+    help="Inference engine to use, supports 'vllm' and 'sglang'",
+)
 @click.option(
     "--vllm-args",
     type=str,
-    help="vLLM engine arguments to be set, use the format as specified in vLLM documentation and separate arguments with commas, e.g. --vllm-args '--max-model-len=8192,--max-num-seqs=256,--enable-prefix-caching'",
+    help="vLLM engine arguments to be set, use the format as specified in vLLM serve documentation and separate arguments with commas, e.g. --vllm-args '--max-model-len=8192,--max-num-seqs=256,--enable-prefix-caching'",
+)
+@click.option(
+    "--sglang-args",
+    type=str,
+    help="SGLang engine arguments to be set, use the format as specified in SGLang Server Arguments documentation and separate arguments with commas, e.g. --sglang-args '--context-length=8192,--mem-fraction-static=0.85'",
 )
 @click.option(
     "--json-mode",
@@ -150,7 +160,7 @@ def cli() -> None:
 @click.option(
     "--config",
     type=str,
-    help="Path to a model config yaml file to use in place of the default",
+    help="Path to a model config yaml file to use in place of the default, you can also set VEC_INF_MODEL_CONFIG to the path to the model config file",
 )
 def launch(
     model_name: str,
@@ -201,7 +211,9 @@ def launch(
     - model_weights_parent_dir : str, optional
        Path to model weights directory
     - vllm_args : str, optional
-       vLLM engine arguments
+       vllm engine arguments
+    - sglang_args : str, optional
+       sglang engine arguments
     - env : str, optional
        Environment variables
     - config : str, optional
@@ -229,6 +241,10 @@ def launch(
     if json_mode:
         click.echo(json.dumps(launch_response.config))
     else:
+        if launch_response.config.get("engine_inferred"):
+            CONSOLE.print(
+                "Warning: Inference engine inferred from engine-specific args"
+            )
         launch_formatter = LaunchResponseFormatter(
             model_name, launch_response.config
         )
vec_inf/cli/_helper.py CHANGED
@@ -15,7 +15,7 @@ from rich.panel import Panel
 from rich.table import Table
 
 from vec_inf.cli._utils import create_table
-from vec_inf.cli._vars import MODEL_TYPE_COLORS, MODEL_TYPE_PRIORITY
+from vec_inf.cli._vars import ENGINE_NAME_MAP, MODEL_TYPE_COLORS, MODEL_TYPE_PRIORITY
 from vec_inf.client import ModelConfig, ModelInfo, StatusResponse
 
 
@@ -49,11 +49,12 @@ class LaunchResponseFormatter:
         if self.params.get(key):
             table.add_row(label, self.params[key])
 
-    def _add_vllm_config(self, table: Table) -> None:
-        """Add vLLM configuration details to the table."""
-        if self.params.get("vllm_args"):
-            table.add_row("vLLM Arguments:", style="magenta")
-            for arg, value in self.params["vllm_args"].items():
+    def _add_engine_config(self, table: Table) -> None:
+        """Add inference engine configuration details to the table."""
+        if self.params.get("engine_args"):
+            engine_name = ENGINE_NAME_MAP[self.params["engine"]]
+            table.add_row(f"{engine_name} Arguments:", style="magenta")
+            for arg, value in self.params["engine_args"].items():
                 table.add_row(f" {arg}:", str(value))
 
     def _add_env_vars(self, table: Table) -> None:
@@ -111,9 +112,10 @@ class LaunchResponseFormatter:
             str(Path(self.params["model_weights_parent_dir"], self.model_name)),
         )
         table.add_row("Log Directory", self.params["log_dir"])
+        table.add_row("Inference Engine", ENGINE_NAME_MAP[self.params["engine"]])
 
         # Add configuration details
-        self._add_vllm_config(table)
+        self._add_engine_config(table)
         self._add_env_vars(table)
         self._add_bind_paths(table)
 
@@ -185,6 +187,10 @@ class BatchLaunchResponseFormatter:
         table.add_row(
             "Memory/Node", f" {self.params['models'][model_name]['mem_per_node']}"
         )
+        table.add_row(
+            "Inference Engine",
+            f" {ENGINE_NAME_MAP[self.params['models'][model_name]['engine']]}",
+        )
 
         return table
 
@@ -479,14 +485,19 @@ class ListCmdDisplay:
             )
             return json.dumps(config_dict, indent=4)
 
+        excluded_list = ["venv", "log_dir"]
+
         table = create_table(key_title="Model Config", value_title="Value")
         for field, value in config.model_dump().items():
-            if field not in {"venv", "log_dir", "vllm_args"} and value:
+            if "args" in field:
+                if not value:
+                    continue
+                engine_name = ENGINE_NAME_MAP[field.split("_")[0]]
+                table.add_row(f"{engine_name} Arguments:", style="magenta")
+                for engine_arg, engine_value in value.items():
+                    table.add_row(f" {engine_arg}:", str(engine_value))
+            elif field not in excluded_list and value:
                 table.add_row(field, str(value))
-            if field == "vllm_args":
-                table.add_row("vLLM Arguments:", style="magenta")
-                for vllm_arg, vllm_value in value.items():
-                    table.add_row(f" {vllm_arg}:", str(vllm_value))
         return table
 
     def _format_all_models_output(
vec_inf/cli/_vars.py CHANGED
@@ -1,32 +1,47 @@
 """Constants for CLI rendering.
 
-This module defines constant mappings for model type priorities and colors
+This module defines mappings for model type priorities, colors, and engine name mappings
 used in the CLI display formatting.
+"""
 
-Constants
----------
-MODEL_TYPE_PRIORITY : dict
-    Mapping of model types to their display priority (lower numbers shown first)
+from typing import get_args
 
-MODEL_TYPE_COLORS : dict
-    Mapping of model types to their display colors in Rich
+from vec_inf.client._slurm_vars import MODEL_TYPES
 
-Notes
------
-These constants are used primarily by the ListCmdDisplay class to ensure
-consistent sorting and color coding of different model types in the CLI output.
-"""
 
-MODEL_TYPE_PRIORITY = {
-    "LLM": 0,
-    "VLM": 1,
-    "Text_Embedding": 2,
-    "Reward_Modeling": 3,
-}
+# Extract model type values from the Literal type
+_MODEL_TYPES = get_args(MODEL_TYPES)
+
+# Rich color options (prioritizing current colors, with fallbacks for additional types)
+_RICH_COLORS = [
+    "cyan",
+    "bright_blue",
+    "purple",
+    "bright_magenta",
+    "green",
+    "yellow",
+    "bright_green",
+    "bright_yellow",
+    "red",
+    "bright_red",
+    "blue",
+    "magenta",
+    "bright_cyan",
+    "white",
+    "bright_white",
+]
 
+# Mapping of model types to their display priority (lower numbers shown first)
+MODEL_TYPE_PRIORITY = {model_type: idx for idx, model_type in enumerate(_MODEL_TYPES)}
+
+# Mapping of model types to their display colors in Rich
 MODEL_TYPE_COLORS = {
-    "LLM": "cyan",
-    "VLM": "bright_blue",
-    "Text_Embedding": "purple",
-    "Reward_Modeling": "bright_magenta",
+    model_type: _RICH_COLORS[idx % len(_RICH_COLORS)]
+    for idx, model_type in enumerate(_MODEL_TYPES)
+}
+
+# Inference engine choice and name mapping
+ENGINE_NAME_MAP = {
+    "vllm": "vLLM",
+    "sglang": "SGLang",
 }
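Note: the rewritten module derives both mappings from the MODEL_TYPES Literal instead of hard-coding them, so new model types pick up a priority and a color automatically. A minimal sketch of the pattern, assuming MODEL_TYPES is a Literal with the four values the old constants listed (the real definition lives in vec_inf.client._slurm_vars):

from typing import Literal, get_args

# Stand-in for vec_inf.client._slurm_vars.MODEL_TYPES (assumed shape).
MODEL_TYPES = Literal["LLM", "VLM", "Text_Embedding", "Reward_Modeling"]

_MODEL_TYPES = get_args(MODEL_TYPES)  # -> ("LLM", "VLM", "Text_Embedding", "Reward_Modeling")
_RICH_COLORS = ["cyan", "bright_blue", "purple", "bright_magenta"]

# Priority follows declaration order; colors cycle if there are more types than colors.
MODEL_TYPE_PRIORITY = {t: i for i, t in enumerate(_MODEL_TYPES)}
MODEL_TYPE_COLORS = {t: _RICH_COLORS[i % len(_RICH_COLORS)] for i, t in enumerate(_MODEL_TYPES)}

assert MODEL_TYPE_PRIORITY["LLM"] == 0
assert MODEL_TYPE_COLORS["VLM"] == "bright_blue"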
vec_inf/client/_client_vars.py CHANGED
@@ -49,7 +49,7 @@ SLURM_JOB_CONFIG_ARGS = {
     "time": "time",
     "nodes": "num_nodes",
     "exclude": "exclude",
-    "nodelist": "node_list",
+    "nodelist": "nodelist",
     "gres": "gres",
     "cpus-per-task": "cpus_per_task",
     "mem": "mem_per_node",
@@ -61,13 +61,43 @@ SLURM_JOB_CONFIG_ARGS = {
 VLLM_SHORT_TO_LONG_MAP = {
     "-tp": "--tensor-parallel-size",
     "-pp": "--pipeline-parallel-size",
+    "-n": "--nnodes",
+    "-r": "--node-rank",
+    "-dcp": "--decode-context-parallel-size",
+    "-pcp": "--prefill-context-parallel-size",
     "-dp": "--data-parallel-size",
+    "-dpn": "--data-parallel-rank",
+    "-dpr": "--data-parallel-start-rank",
     "-dpl": "--data-parallel-size-local",
     "-dpa": "--data-parallel-address",
     "-dpp": "--data-parallel-rpc-port",
+    "-dpb": "--data-parallel-backend",
+    "-dph": "--data-parallel-hybrid-lb",
+    "-dpe": "--data-parallel-external-lb",
     "-O": "--compilation-config",
     "-q": "--quantization",
 }
 
+# SGLang engine args mapping between short and long names
+SGLANG_SHORT_TO_LONG_MAP = {
+    "--tp": "--tensor-parallel-size",
+    "--tp-size": "--tensor-parallel-size",
+    "--pp": "--pipeline-parallel-size",
+    "--pp-size": "--pipeline-parallel-size",
+    "--dp": "--data-parallel-size",
+    "--dp-size": "--data-parallel-size",
+    "--ep": "--expert-parallel-size",
+    "--ep-size": "--expert-parallel-expert-size",
+}
+
+# Mapping of engine short names to their argument mappings
+ENGINE_SHORT_TO_LONG_MAP = {
+    "vllm": VLLM_SHORT_TO_LONG_MAP,
+    "sglang": SGLANG_SHORT_TO_LONG_MAP,
+}
+
 # Required matching arguments for batch mode
 BATCH_MODE_REQUIRED_MATCHING_ARGS = ["venv", "log_dir"]
+
+# Supported engines
+SUPPORTED_ENGINES = ["vllm", "sglang"]
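Note: flag normalization is now a two-level lookup, engine name first, then short-to-long flag map. A sketch of how a caller would use it; normalize is a hypothetical helper for illustration, and the maps below are abbreviated stand-ins for the full tables above:

# Abbreviated stand-ins for the maps defined in _client_vars.py.
VLLM_SHORT_TO_LONG_MAP = {"-tp": "--tensor-parallel-size"}
SGLANG_SHORT_TO_LONG_MAP = {"--tp": "--tensor-parallel-size", "--tp-size": "--tensor-parallel-size"}
ENGINE_SHORT_TO_LONG_MAP = {"vllm": VLLM_SHORT_TO_LONG_MAP, "sglang": SGLANG_SHORT_TO_LONG_MAP}

def normalize(engine: str, flag: str) -> str:
    """Return the canonical long-form flag for an engine; unknown flags pass through."""
    return ENGINE_SHORT_TO_LONG_MAP[engine].get(flag, flag)

assert normalize("vllm", "-tp") == "--tensor-parallel-size"
assert normalize("sglang", "--tp-size") == "--tensor-parallel-size"
assert normalize("sglang", "--context-length") == "--context-length"  # passes through unchanged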
vec_inf/client/_helper.py CHANGED
@@ -17,9 +17,10 @@ import requests
 import vec_inf.client._utils as utils
 from vec_inf.client._client_vars import (
     BATCH_MODE_REQUIRED_MATCHING_ARGS,
+    ENGINE_SHORT_TO_LONG_MAP,
     KEY_METRICS,
     SRC_DIR,
-    VLLM_SHORT_TO_LONG_MAP,
+    SUPPORTED_ENGINES,
 )
 from vec_inf.client._exceptions import (
     MissingRequiredFieldsError,
@@ -63,6 +64,7 @@ class ModelLauncher:
         self.slurm_job_id = ""
         self.slurm_script_path = Path("")
         self.model_config = self._get_model_configuration(self.kwargs.get("config"))
+        self.engine = ""
         self.params = self._get_launch_params()
 
     def _warn(self, message: str) -> None:
@@ -137,32 +139,38 @@
                 f"not found at expected path '{model_weights_path}'"
             )
 
-    def _process_vllm_args(self, arg_string: str) -> dict[str, Any]:
-        """Process the vllm_args string into a dictionary.
+    def _process_engine_args(
+        self, arg_string: str, engine_choice: str
+    ) -> dict[str, Any]:
+        """Process the engine_args string into a dictionary.
 
         Parameters
         ----------
         arg_string : str
-            Comma-separated string of vLLM arguments
+            Comma-separated string of inference engine arguments
 
         Returns
         -------
         dict[str, Any]
-            Processed vLLM arguments as key-value pairs
+            Processed inference engine arguments as key-value pairs
         """
-        vllm_args: dict[str, str | bool] = {}
+        engine_args: dict[str, str | bool] = {}
+        engine_arg_map = ENGINE_SHORT_TO_LONG_MAP[engine_choice]
+
         for arg in arg_string.split(","):
             if "=" in arg:
                 key, value = arg.split("=")
-                if key.strip() in VLLM_SHORT_TO_LONG_MAP:
-                    key = VLLM_SHORT_TO_LONG_MAP[key.strip()]
-                vllm_args[key.strip()] = value.strip()
+                if key.strip() in engine_arg_map:
+                    key = engine_arg_map[key.strip()]
+                engine_args[key.strip()] = value.strip()
             elif "-O" in arg.strip():
-                key = VLLM_SHORT_TO_LONG_MAP["-O"]
-                vllm_args[key] = arg.strip()[2:].strip()
+                if engine_choice != "vllm":
+                    raise ValueError("-O is only supported for vLLM")
+                key = engine_arg_map["-O"]
+                engine_args[key] = arg.strip()[2:].strip()
             else:
-                vllm_args[arg.strip()] = True
-        return vllm_args
+                engine_args[arg.strip()] = True
+        return engine_args
 
     def _process_env_vars(self, env_arg: str) -> dict[str, str]:
         """Process the env string into a dictionary of environment variables.
@@ -196,6 +204,63 @@
                 print(f"WARNING: Could not parse env var: {line}")
         return env_vars
 
+    def _engine_check_override(self, params: dict[str, Any]) -> None:
+        """Check for engine override in CLI args and warn user.
+
+        Parameters
+        ----------
+        params : dict[str, Any]
+            Dictionary of launch parameters to check
+        """
+
+        def overwrite_engine_args(params: dict[str, Any]) -> None:
+            engine_args = self._process_engine_args(
+                self.kwargs[f"{self.engine}_args"], self.engine
+            )
+            for key, value in engine_args.items():
+                params["engine_args"][key] = value
+            del self.kwargs[f"{self.engine}_args"]
+
+        # Infer engine name from engine-specific args if provided
+        extracted_engine = ""
+        for engine in SUPPORTED_ENGINES:
+            if self.kwargs.get(f"{engine}_args"):
+                if not extracted_engine:
+                    extracted_engine = engine
+                else:
+                    raise ValueError(
+                        "Cannot provide engine-specific args for multiple engines, please choose one"
+                    )
+        # Check for mismatch between provided engine arg and engine-specific args
+        input_engine = self.kwargs.get("engine", "")
+
+        if input_engine and extracted_engine:
+            if input_engine != extracted_engine:
+                raise ValueError(
+                    f"Mismatch between provided engine '{input_engine}' and engine-specific args '{extracted_engine}'"
+                )
+            self.engine = input_engine
+            params["engine_args"] = params[f"{self.engine}_args"]
+            overwrite_engine_args(params)
+        elif input_engine:
+            # Only engine arg in CLI, use default engine args from config
+            self.engine = input_engine
+            params["engine_args"] = params[f"{self.engine}_args"]
+        elif extracted_engine:
+            # Only engine-specific args in CLI, infer engine and warn user
+            self.engine = extracted_engine
+            params["engine_inferred"] = True
+            params["engine_args"] = params[f"{self.engine}_args"]
+            overwrite_engine_args(params)
+        else:
+            # No engine-related args in CLI, use defaults from config
+            self.engine = params.get("engine", "vllm")
+            params["engine_args"] = params[f"{self.engine}_args"]
+
+        # Remove $ENGINE_NAME_args from params as they won't get populated to sjob json.
+        for engine in SUPPORTED_ENGINES:
+            del params[f"{engine}_args"]
+
     def _apply_cli_overrides(self, params: dict[str, Any]) -> None:
         """Apply CLI argument overrides to params.
 
@@ -204,11 +269,7 @@
         params : dict[str, Any]
             Dictionary of launch parameters to override
         """
-        if self.kwargs.get("vllm_args"):
-            vllm_args = self._process_vllm_args(self.kwargs["vllm_args"])
-            for key, value in vllm_args.items():
-                params["vllm_args"][key] = value
-            del self.kwargs["vllm_args"]
+        self._engine_check_override(params)
 
         if self.kwargs.get("env"):
             env_vars = self._process_env_vars(self.kwargs["env"])
@@ -241,7 +302,7 @@
         """
         if (
             int(params["gpus_per_node"]) > 1
-            and params["vllm_args"].get("--tensor-parallel-size") is None
+            and params["engine_args"].get("--tensor-parallel-size") is None
         ):
             raise MissingRequiredFieldsError(
                 "--tensor-parallel-size is required when gpus_per_node > 1"
@@ -252,8 +313,8 @@
             raise ValueError("Total number of GPUs requested must be a power of two")
 
         total_parallel_sizes = int(
-            params["vllm_args"].get("--tensor-parallel-size", "1")
-        ) * int(params["vllm_args"].get("--pipeline-parallel-size", "1"))
+            params["engine_args"].get("--tensor-parallel-size", "1")
+        ) * int(params["engine_args"].get("--pipeline-parallel-size", "1"))
         if total_gpus_requested != total_parallel_sizes:
             raise ValueError(
                 "Mismatch between total number of GPUs requested and parallelization settings"
@@ -297,6 +358,11 @@
         # Check for required fields without default vals, will raise an error if missing
         utils.check_required_fields(params)
 
+        if not params.get("work_dir"):
+            # This is last resort, work dir should always be a required field to avoid
+            # blowing up user home directory unless intended
+            params["work_dir"] = str(Path.home())
+
         # Validate resource allocation and parallelization settings
         self._validate_resource_allocation(params)
 
@@ -312,7 +378,8 @@
 
         # Convert path to string for JSON serialization
         for field in params:
-            if field in ["vllm_args", "env"]:
+            # Keep structured fields (dicts/bools) intact
+            if field in ["engine_args", "env", "engine_inferred"]:
                 continue
             params[field] = str(params[field])
 
@@ -342,6 +409,10 @@
         SlurmJobError
             If SLURM job submission fails
         """
+        # Create cache directory if it doesn't exist
+        cache_dir = Path(self.params["work_dir"], ".vec-inf-cache").expanduser()
+        cache_dir.mkdir(parents=True, exist_ok=True)
+
         # Build and execute the launch command
         command_output, stderr = utils.run_bash_command(self._build_launch_command())
 
@@ -370,7 +441,7 @@
 
         # Replace venv with image path if using container
         if self.params["venv"] == CONTAINER_MODULE_NAME:
-            self.params["venv"] = IMAGE_PATH
+            self.params["venv"] = IMAGE_PATH[self.params["engine"]]
 
         with job_json.open("w") as file:
             json.dump(self.params, file, indent=4)
@@ -453,6 +524,53 @@ class BatchModelLauncher:
 
         return model_configs_dict
 
+    def _validate_resource_and_parallel_settings(
+        self,
+        config: ModelConfig,
+        model_engine_args: dict[str, Any] | None,
+        model_name: str,
+    ) -> None:
+        """Validate resource allocation and parallelization settings for each model.
+
+        Parameters
+        ----------
+        config : ModelConfig
+            Configuration of the model to validate
+        model_engine_args : dict[str, Any] | None
+            Inference engine arguments of the model to validate
+        model_name : str
+            Name of the model to validate
+
+        Raises
+        ------
+        MissingRequiredFieldsError
+            If tensor parallel size is not specified when using multiple GPUs
+        ValueError
+            If total # of GPUs requested is not a power of two
+            If mismatch between total # of GPUs requested and parallelization settings
+        """
+        if (
+            int(config.gpus_per_node) > 1
+            and (model_engine_args or {}).get("--tensor-parallel-size") is None
+        ):
+            raise MissingRequiredFieldsError(
+                f"--tensor-parallel-size is required when gpus_per_node > 1, check your configuration for {model_name}"
+            )
+
+        total_gpus_requested = int(config.gpus_per_node) * int(config.num_nodes)
+        if not utils.is_power_of_two(total_gpus_requested):
+            raise ValueError(
+                f"Total number of GPUs requested must be a power of two, check your configuration for {model_name}"
+            )
+
+        total_parallel_sizes = int(
+            (model_engine_args or {}).get("--tensor-parallel-size", "1")
+        ) * int((model_engine_args or {}).get("--pipeline-parallel-size", "1"))
+        if total_gpus_requested != total_parallel_sizes:
+            raise ValueError(
+                f"Mismatch between total number of GPUs requested and parallelization settings, check your configuration for {model_name}"
+            )
+
     def _get_launch_params(
         self, account: Optional[str] = None, work_dir: Optional[str] = None
     ) -> dict[str, Any]:
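Note: the extracted helper enforces the same arithmetic as the single-model path: total GPUs (gpus_per_node x num_nodes) must be a power of two and equal tensor-parallel x pipeline-parallel. A worked example; the is_power_of_two body is an assumption about what utils.is_power_of_two does:

def is_power_of_two(n: int) -> bool:
    # Standard bit trick; assumed equivalent to utils.is_power_of_two.
    return n > 0 and (n & (n - 1)) == 0

gpus_per_node, num_nodes = 4, 2
tensor_parallel, pipeline_parallel = 4, 2

total_gpus = gpus_per_node * num_nodes                    # 8
assert is_power_of_two(total_gpus)                        # 8 is 2**3, passes
assert total_gpus == tensor_parallel * pipeline_parallel  # 4 * 2 == 8, passes
# tp=4, pp=1 on the same 8 GPUs would fail the second check (4 != 8).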
@@ -483,28 +601,15 @@
             params["models"][model_name] = config.model_dump(exclude_none=True)
             params["models"][model_name]["het_group_id"] = i
 
-            # Validate resource allocation and parallelization settings
-            if (
-                int(config.gpus_per_node) > 1
-                and (config.vllm_args or {}).get("--tensor-parallel-size") is None
-            ):
-                raise MissingRequiredFieldsError(
-                    f"--tensor-parallel-size is required when gpus_per_node > 1, check your configuration for {model_name}"
-                )
+            model_engine_args = getattr(config, f"{config.engine}_args", None)
+            params["models"][model_name]["engine_args"] = model_engine_args
+            for engine in SUPPORTED_ENGINES:
+                del params["models"][model_name][f"{engine}_args"]
 
-            total_gpus_requested = int(config.gpus_per_node) * int(config.num_nodes)
-            if not utils.is_power_of_two(total_gpus_requested):
-                raise ValueError(
-                    f"Total number of GPUs requested must be a power of two, check your configuration for {model_name}"
-                )
-
-            total_parallel_sizes = int(
-                (config.vllm_args or {}).get("--tensor-parallel-size", "1")
-            ) * int((config.vllm_args or {}).get("--pipeline-parallel-size", "1"))
-            if total_gpus_requested != total_parallel_sizes:
-                raise ValueError(
-                    f"Mismatch between total number of GPUs requested and parallelization settings, check your configuration for {model_name}"
-                )
+            # Validate resource allocation and parallelization settings
+            self._validate_resource_and_parallel_settings(
+                config, model_engine_args, model_name
+            )
 
             # Convert gpus_per_node and resource_type to gres
             params["models"][model_name]["gres"] = (
@@ -565,6 +670,10 @@
             else:
                 params["models"][model_name][arg] = value
 
+        if not params.get("work_dir"):
+            # This is last resort, work dir should always be a required field to avoid
+            # blowing up user home directory unless intended
+            params["work_dir"] = str(Path.home())
         return params
 
     def _build_launch_command(self) -> str:
@@ -593,6 +702,10 @@
         SlurmJobError
             If SLURM job submission fails
         """
+        # Create cache directory if it doesn't exist
+        cache_dir = Path(self.params["work_dir"], ".vec-inf-cache").expanduser()
+        cache_dir.mkdir(parents=True, exist_ok=True)
+
         # Build and execute the launch command
         command_output, stderr = utils.run_bash_command(self._build_launch_command())
 