vec-inf 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- """Class for generating SLURM scripts to run vLLM servers.
1
+ """Class for generating Slurm scripts to run vLLM servers.
2
2
 
3
- This module provides functionality to generate SLURM scripts for running vLLM servers
3
+ This module provides functionality to generate Slurm scripts for running vLLM servers
4
4
  in both single-node and multi-node configurations.
5
5
  """
6
6
 
@@ -8,51 +8,56 @@ from datetime import datetime
8
8
  from pathlib import Path
9
9
  from typing import Any
10
10
 
11
- from vec_inf.client._client_vars import (
12
- SLURM_JOB_CONFIG_ARGS,
11
+ from vec_inf.client._client_vars import SLURM_JOB_CONFIG_ARGS
12
+ from vec_inf.client._slurm_templates import (
13
+ BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE,
14
+ BATCH_SLURM_SCRIPT_TEMPLATE,
13
15
  SLURM_SCRIPT_TEMPLATE,
14
16
  )
15
17
 
16
18
 
17
19
  class SlurmScriptGenerator:
18
- """A class to generate SLURM scripts for running vLLM servers.
20
+ """A class to generate Slurm scripts for running vLLM servers.
19
21
 
20
- This class handles the generation of SLURM scripts for both single-node and
22
+ This class handles the generation of Slurm scripts for both single-node and
21
23
  multi-node configurations, supporting different virtualization environments
22
- (venv or singularity).
24
+ (venv or singularity/apptainer).
23
25
 
24
26
  Parameters
25
27
  ----------
26
- params : dict[str, Any]
27
- Configuration parameters for the SLURM script. Contains settings for job
28
- configuration, model parameters, and virtualization environment.
28
+ params : dict[str, Any]
29
+ Configuration parameters for the Slurm script.
29
30
  """
30
31
 
31
32
  def __init__(self, params: dict[str, Any]):
32
- """Initialize the SlurmScriptGenerator with configuration parameters.
33
-
34
- Parameters
35
- ----------
36
- params : dict[str, Any]
37
- Configuration parameters for the SLURM script.
38
- """
39
33
  self.params = params
40
34
  self.is_multinode = int(self.params["num_nodes"]) > 1
41
- self.use_singularity = self.params["venv"] == "singularity"
35
+ self.use_container = (
36
+ self.params["venv"] == "singularity" or self.params["venv"] == "apptainer"
37
+ )
42
38
  self.additional_binds = self.params.get("bind", "")
43
39
  if self.additional_binds:
44
40
  self.additional_binds = f" --bind {self.additional_binds}"
45
41
  self.model_weights_path = str(
46
- Path(params["model_weights_parent_dir"], params["model_name"])
42
+ Path(self.params["model_weights_parent_dir"], self.params["model_name"])
47
43
  )
44
+ env_dict: dict[str, str] = self.params.get("env", {})
45
+ # Create string of environment variables
46
+ self.env_str = ""
47
+ for key, val in env_dict.items():
48
+ if len(self.env_str) == 0:
49
+ self.env_str = "--env "
50
+ else:
51
+ self.env_str += ","
52
+ self.env_str += key + "=" + val
48
53
 
49
54
  def _generate_script_content(self) -> str:
50
- """Generate the complete SLURM script content.
55
+ """Generate the complete Slurm script content.
51
56
 
52
57
  Returns
53
58
  -------
54
59
  str
55
- The complete SLURM script as a string.
60
+ The complete Slurm script as a string.
56
61
  """
57
62
  script_content = []
58
63
  script_content.append(self._generate_shebang())
@@ -61,12 +66,12 @@ class SlurmScriptGenerator:
61
66
  return "\n".join(script_content)
62
67
 
63
68
  def _generate_shebang(self) -> str:
64
- """Generate the SLURM script shebang with job specifications.
69
+ """Generate the Slurm script shebang with job specifications.
65
70
 
66
71
  Returns
67
72
  -------
68
73
  str
69
- SLURM shebang containing job specifications.
74
+ Slurm shebang containing job specifications.
70
75
  """
71
76
  shebang = [SLURM_SCRIPT_TEMPLATE["shebang"]["base"]]
72
77
  for arg, value in SLURM_JOB_CONFIG_ARGS.items():
@@ -88,8 +93,8 @@ class SlurmScriptGenerator:
88
93
  Server initialization script content.
89
94
  """
90
95
  server_script = ["\n"]
91
- if self.use_singularity:
92
- server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["singularity_setup"]))
96
+ if self.use_container:
97
+ server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
93
98
  server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["env_vars"]))
94
99
  server_script.append(
95
100
  SLURM_SCRIPT_TEMPLATE["imports"].format(src_dir=self.params["src_dir"])
@@ -97,13 +102,14 @@ class SlurmScriptGenerator:
97
102
  if self.is_multinode:
98
103
  server_setup_str = "\n".join(
99
104
  SLURM_SCRIPT_TEMPLATE["server_setup"]["multinode"]
100
- )
101
- if self.use_singularity:
105
+ ).format(gpus_per_node=self.params["gpus_per_node"])
106
+ if self.use_container:
102
107
  server_setup_str = server_setup_str.replace(
103
- "SINGULARITY_PLACEHOLDER",
104
- SLURM_SCRIPT_TEMPLATE["singularity_command"].format(
108
+ "CONTAINER_PLACEHOLDER",
109
+ SLURM_SCRIPT_TEMPLATE["container_command"].format(
105
110
  model_weights_path=self.model_weights_path,
106
111
  additional_binds=self.additional_binds,
112
+ env_str=self.env_str,
107
113
  ),
108
114
  )
109
115
  else:
@@ -123,7 +129,7 @@ class SlurmScriptGenerator:
123
129
  """Generate the vLLM server launch command.
124
130
 
125
131
  Creates the command to launch the vLLM server, handling different virtualization
126
- environments (venv or singularity).
132
+ environments (venv or singularity/apptainer).
127
133
 
128
134
  Returns
129
135
  -------
@@ -131,13 +137,13 @@ class SlurmScriptGenerator:
131
137
  Server launch command.
132
138
  """
133
139
  launcher_script = ["\n"]
134
- if self.use_singularity:
140
+ if self.use_container:
135
141
  launcher_script.append(
136
- SLURM_SCRIPT_TEMPLATE["singularity_command"].format(
142
+ SLURM_SCRIPT_TEMPLATE["container_command"].format(
137
143
  model_weights_path=self.model_weights_path,
138
144
  additional_binds=self.additional_binds,
145
+ env_str=self.env_str,
139
146
  )
140
- + " \\"
141
147
  )
142
148
  else:
143
149
  launcher_script.append(
@@ -158,21 +164,183 @@ class SlurmScriptGenerator:
158
164
  return "\n".join(launcher_script)
159
165
 
160
166
  def write_to_log_dir(self) -> Path:
161
- """Write the generated SLURM script to the log directory.
167
+ """Write the generated Slurm script to the log directory.
162
168
 
163
169
  Creates a timestamped script file in the configured log directory.
164
170
 
165
171
  Returns
166
172
  -------
167
173
  Path
168
- Path to the generated SLURM script file.
174
+ Path to the generated Slurm script file.
169
175
  """
170
176
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
171
177
  script_path: Path = (
172
178
  Path(self.params["log_dir"])
173
- / f"launch_{self.params['model_name']}_{timestamp}.slurm"
179
+ / f"launch_{self.params['model_name']}_{timestamp}.sbatch"
174
180
  )
175
181
 
176
182
  content = self._generate_script_content()
177
183
  script_path.write_text(content)
178
184
  return script_path
185
+
186
+
187
class BatchSlurmScriptGenerator:
    """Generate Slurm scripts for batch mode.

    Batch mode launches several vLLM servers with different configurations in
    parallel: one launch (bash) script is written per model, plus a single
    Slurm script that starts every launch script with ``srun``.

    Parameters
    ----------
    params : dict[str, Any]
        Batch launch configuration. This class reads the keys ``venv``,
        ``models``, ``log_dir``, ``src_dir`` and ``slurm_job_name``. Every
        entry of ``params["models"]`` is updated in place with the derived
        keys ``additional_binds`` and ``model_weights_path``.
    """

    # ``venv`` values that mean the servers run inside a container image.
    _CONTAINER_VENVS = ("singularity", "apptainer")

    def __init__(self, params: dict[str, Any]):
        self.params = params
        self.script_paths: list[Path] = []
        self.use_container = self.params["venv"] in self._CONTAINER_VENVS
        for model_name, model_params in self.params["models"].items():
            bind = model_params.get("bind")
            model_params["additional_binds"] = f" --bind {bind}" if bind else ""
            model_params["model_weights_path"] = str(
                Path(model_params["model_weights_parent_dir"], model_name)
            )

    @staticmethod
    def _strip_trailing_continuation(line: str) -> str:
        """Drop only the final line-continuation backslash of ``line``.

        Backslashes anywhere else (continuations inside a multi-line command,
        or backslashes that are part of an argument value) are preserved.
        """
        return line[:-1] if line.endswith("\\") else line

    def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path:
        """Write the generated script to the log directory.

        Parameters
        ----------
        script_content : list[str]
            Script fragments; they are joined with newlines before writing.
        script_name : str
            File name of the script inside ``params["log_dir"]``.

        Returns
        -------
        Path
            The Path object to the generated script file.
        """
        script_path = Path(self.params["log_dir"]) / script_name
        script_path.touch(exist_ok=True)
        script_path.write_text("\n".join(script_content))
        return script_path

    def _generate_model_launch_script(self, model_name: str) -> Path:
        """Generate the bash script for launching individual vLLM servers.

        Parameters
        ----------
        model_name : str
            The name of the model to launch.

        Returns
        -------
        Path
            The bash script path for launching the vLLM server.
        """
        template = BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE
        model_params = self.params["models"][model_name]

        # Generate the bash script content
        script_content = [template["shebang"]]
        if self.use_container:
            script_content.append(template["container_setup"])
        script_content.append("\n".join(template["env_vars"]))
        script_content.append(
            "\n".join(template["server_address_setup"]).format(
                src_dir=self.params["src_dir"]
            )
        )
        script_content.append(
            "\n".join(template["write_to_json"]).format(
                het_group_id=model_params["het_group_id"],
                log_dir=self.params["log_dir"],
                slurm_job_name=self.params["slurm_job_name"],
                model_name=model_name,
            )
        )
        if self.use_container:
            # The container command ends with a continuation, so the launch
            # command appended next runs inside the container.
            script_content.append(
                template["container_command"].format(
                    model_weights_path=model_params["model_weights_path"],
                    additional_binds=model_params["additional_binds"],
                )
            )
        script_content.append(
            "\n".join(template["launch_cmd"]).format(
                model_weights_path=model_params["model_weights_path"],
                model_name=model_name,
            )
        )
        for arg, value in model_params["vllm_args"].items():
            if isinstance(value, bool):
                # Boolean options are emitted as bare flags.
                script_content.append(f" {arg} \\")
            else:
                script_content.append(f" {arg} {value} \\")
        # Terminate the command: remove only the last continuation. The last
        # element may be the whole multi-line launch command (when there are
        # no vLLM args), whose inner continuations must survive.
        script_content[-1] = self._strip_trailing_continuation(script_content[-1])

        # Write the bash script to the log directory
        launch_script_path = self._write_to_log_dir(
            script_content, f"launch_{model_name}.sh"
        )
        self.script_paths.append(launch_script_path)
        return launch_script_path

    def _generate_batch_slurm_script_shebang(self) -> str:
        """Generate the shebang for batch mode Slurm script.

        Returns
        -------
        str
            The shebang for batch mode Slurm script.
        """
        hetjob_line = BATCH_SLURM_SCRIPT_TEMPLATE["hetjob"]
        # Per-model output/error files are passed to ``srun`` instead.
        per_model_excluded = ("out_file", "err_file")

        shebang = [BATCH_SLURM_SCRIPT_TEMPLATE["shebang"]]
        for arg, value in SLURM_JOB_CONFIG_ARGS.items():
            if self.params.get(value):
                shebang.append(f"#SBATCH --{arg}={self.params[value]}")
        shebang.append("#SBATCH --ntasks=1")
        shebang.append("\n")

        for model_name, model_params in self.params["models"].items():
            shebang.append(f"# ===== Resource group for {model_name} =====")
            for arg, value in SLURM_JOB_CONFIG_ARGS.items():
                if model_params.get(value) and value not in per_model_excluded:
                    shebang.append(f"#SBATCH --{arg}={model_params[value]}")
            shebang[-1] += "\n"
            shebang.append(hetjob_line)

        # Remove the last hetjob line (a separator is only needed between
        # resource groups); nothing to remove when there are no models.
        if shebang[-1] == hetjob_line:
            shebang.pop()
        return "\n".join(shebang)

    def generate_batch_slurm_script(self) -> Path:
        """Generate the Slurm script for launching multiple vLLM servers in batch mode.

        Returns
        -------
        Path
            The Slurm script for launching multiple vLLM servers in batch mode.
        """
        script_content = [self._generate_batch_slurm_script_shebang()]

        for model_name, model_params in self.params["models"].items():
            script_content.append(f"# ===== Launching {model_name} =====")
            launch_script_path = str(self._generate_model_launch_script(model_name))
            script_content.append(
                BATCH_SLURM_SCRIPT_TEMPLATE["permission_update"].format(
                    script_name=launch_script_path
                )
            )
            script_content.append(
                "\n".join(BATCH_SLURM_SCRIPT_TEMPLATE["launch_model_scripts"]).format(
                    het_group_id=model_params["het_group_id"],
                    out_file=model_params["out_file"],
                    err_file=model_params["err_file"],
                    script_name=launch_script_path,
                )
            )
        # Every launch script is started in the background; block until all exit.
        script_content.append("wait")

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        script_name = f"{self.params['slurm_job_name']}_{timestamp}.sbatch"
        return self._write_to_log_dir(script_content, script_name)
@@ -0,0 +1,248 @@
1
+ """SLURM script templates for Vector Inference.
2
+
3
+ This module contains the SLURM script templates for Vector Inference, including
4
+ single-node, multi-node, and batch mode templates.
5
+ """
6
+
7
+ from typing import TypedDict
8
+
9
+ from vec_inf.client._slurm_vars import (
10
+ CONTAINER_LOAD_CMD,
11
+ CONTAINER_MODULE_NAME,
12
+ IMAGE_PATH,
13
+ )
14
+
15
+
16
+ CONTAINER_MODULE_NAME_UPPER = CONTAINER_MODULE_NAME.upper()
17
+
18
+
19
class ShebangConfig(TypedDict):
    """Shape of the ``shebang`` section of a Slurm script template.

    Parameters
    ----------
    base : str
        Interpreter line placed at the top of every generated Slurm script
    multinode : list[str]
        Extra ``#SBATCH`` directives used for multi-node configurations
    """

    base: str
    multinode: list[str]
32
+
33
+
34
class ServerSetupConfig(TypedDict):
    """Shape of the ``server_setup`` section of a Slurm script template.

    Parameters
    ----------
    single_node : list[str]
        Shell lines that prepare a single-node deployment
    multinode : list[str]
        Shell lines that prepare a multi-node deployment, including Ray
        initialization
    """

    single_node: list[str]
    multinode: list[str]
47
+
48
+
49
class SlurmScriptTemplate(TypedDict):
    """TypedDict for complete Slurm script template configuration.

    Parameters
    ----------
    shebang : ShebangConfig
        Shebang and Slurm directive configuration
    container_setup : list[str]
        Commands for container setup
    imports : str
        Import statements and source commands
    env_vars : list[str]
        Shell lines exporting environment variables
    container_command : str
        Template for container execution command
    activate_venv : str
        Template for virtual environment activation
    server_setup : ServerSetupConfig
        Server initialization commands for different deployment modes
    find_vllm_port : list[str]
        Commands to find available ports for vLLM server
    write_to_json : list[str]
        Commands to write server configuration to JSON
    launch_cmd : list[str]
        vLLM server launch commands
    """

    shebang: ShebangConfig
    container_setup: list[str]
    imports: str
    env_vars: list[str]
    container_command: str
    activate_venv: str
    server_setup: ServerSetupConfig
    find_vllm_port: list[str]
    write_to_json: list[str]
    launch_cmd: list[str]
84
+
85
+
86
# Template fragments for single-node and multi-node Slurm scripts.
# Fragments containing ``{name}`` fields are filled with ``str.format`` by the
# script generator; in those fragments literal shell braces are doubled.
SLURM_SCRIPT_TEMPLATE: SlurmScriptTemplate = {
    "shebang": {
        "base": "#!/bin/bash",
        "multinode": [
            "#SBATCH --exclusive",
            "#SBATCH --tasks-per-node=1",
        ],
    },
    "container_setup": [
        CONTAINER_LOAD_CMD,
        f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop",
    ],
    "imports": "source {src_dir}/find_port.sh",
    "env_vars": [
        # The container runtime reads the upper-case <RUNTIME>_BINDPATH
        # variable, so the upper-cased module name is used here (the
        # lower-case name is only valid as the executable name).
        f"export {CONTAINER_MODULE_NAME_UPPER}_BINDPATH=${CONTAINER_MODULE_NAME_UPPER}_BINDPATH,$(echo /dev/infiniband* | sed -e 's/ /,/g')"
    ],
    "container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --bind {{model_weights_path}}{{additional_binds}} --containall {IMAGE_PATH} \\",
    "activate_venv": "source {venv}/bin/activate",
    "server_setup": {
        "single_node": [
            "\n# Find available port",
            "head_node_ip=${SLURMD_NODENAME}",
        ],
        # CONTAINER_PLACEHOLDER is substituted by the script generator with
        # the formatted container command when a container is used.
        "multinode": [
            "\n# Get list of nodes",
            'nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")',
            "nodes_array=($nodes)",
            "head_node=${{nodes_array[0]}}",
            'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)',
            "\n# Start Ray head node",
            "head_node_port=$(find_available_port $head_node_ip 8080 65535)",
            "ray_head=$head_node_ip:$head_node_port",
            'echo "Ray Head IP: $ray_head"',
            'echo "Starting HEAD at $head_node"',
            'srun --nodes=1 --ntasks=1 -w "$head_node" \\',
            " CONTAINER_PLACEHOLDER",
            ' ray start --head --node-ip-address="$head_node_ip" --port=$head_node_port \\',
            ' --num-cpus "$SLURM_CPUS_PER_TASK" --num-gpus {gpus_per_node} --block &',
            "sleep 10",
            "\n# Start Ray worker nodes",
            "worker_num=$((SLURM_JOB_NUM_NODES - 1))",
            "for ((i = 1; i <= worker_num; i++)); do",
            " node_i=${{nodes_array[$i]}}",
            ' echo "Starting WORKER $i at $node_i"',
            ' srun --nodes=1 --ntasks=1 -w "$node_i" \\',
            " CONTAINER_PLACEHOLDER",
            ' ray start --address "$ray_head" \\',
            ' --num-cpus "$SLURM_CPUS_PER_TASK" --num-gpus {gpus_per_node} --block &',
            " sleep 5",
            "done",
        ],
    },
    "find_vllm_port": [
        "\nvllm_port_number=$(find_available_port $head_node_ip 8080 65535)",
        'server_address="http://${head_node_ip}:${vllm_port_number}/v1"',
    ],
    "write_to_json": [
        '\njson_path="{log_dir}/{model_name}.$SLURM_JOB_ID/{model_name}.$SLURM_JOB_ID.json"',
        'jq --arg server_addr "$server_address" \\',
        " '. + {{\"server_address\": $server_addr}}' \\",
        ' "$json_path" > temp.json \\',
        ' && mv temp.json "$json_path"',
    ],
    "launch_cmd": [
        "vllm serve {model_weights_path} \\",
        " --served-model-name {model_name} \\",
        ' --host "0.0.0.0" \\',
        " --port $vllm_port_number \\",
    ],
}
156
+
157
+
158
class BatchSlurmScriptTemplate(TypedDict):
    """Shape of the template used to build the batch-mode Slurm script.

    Parameters
    ----------
    shebang : str
        First line of the script
    hetjob : str
        Slurm directive separating heterogeneous job components
    permission_update : str
        Command template that updates the permissions of a launch script
    launch_model_scripts : list[str]
        Command template lines that launch one vLLM server
    """

    shebang: str
    hetjob: str
    permission_update: str
    launch_model_scripts: list[str]
177
+
178
+
179
# Template fragments for the batch-mode Slurm script; ``{name}`` fields are
# filled with ``str.format`` by the batch script generator.
BATCH_SLURM_SCRIPT_TEMPLATE: BatchSlurmScriptTemplate = {
    "shebang": "#!/bin/bash",
    # Separator placed between the resource groups of the job.
    "hetjob": "#SBATCH hetjob\n",
    "permission_update": "chmod +x {script_name}",
    # One backgrounded ``srun`` per model launch script.
    "launch_model_scripts": [
        "\nsrun --het-group={het_group_id} \\",
        " --output={out_file} \\",
        " --error={err_file} \\",
        " {script_name} &\n",
    ],
}
190
+
191
+
192
class BatchModelLaunchScriptTemplate(TypedDict):
    """TypedDict for batch model launch script template configuration.

    Parameters
    ----------
    shebang : str
        Shebang line for the script
    container_setup : str
        Command for container setup
    env_vars : list[str]
        Environment variables to set
    server_address_setup : list[str]
        Commands to setup the server address
    write_to_json : list[str]
        Commands to write the server address to the model's JSON file
    launch_cmd : list[str]
        Commands to launch the vLLM server
    container_command : str
        Template for the container execution command
    """

    shebang: str
    container_setup: str
    env_vars: list[str]
    server_address_setup: list[str]
    write_to_json: list[str]
    launch_cmd: list[str]
    container_command: str
218
+
219
+
220
# Template fragments for the per-model launch script used in batch mode.
# Fragments containing ``{name}`` fields are filled with ``str.format`` by the
# batch script generator; in those fragments literal shell braces are doubled.
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE: BatchModelLaunchScriptTemplate = {
    "shebang": "#!/bin/bash\n",
    "container_setup": f"{CONTAINER_LOAD_CMD}\n",
    "env_vars": [
        # The container runtime reads the upper-case <RUNTIME>_BINDPATH
        # variable, so the upper-cased module name is used here (the
        # lower-case name is only valid as the executable name).
        f"export {CONTAINER_MODULE_NAME_UPPER}_BINDPATH=${CONTAINER_MODULE_NAME_UPPER}_BINDPATH,$(echo /dev/infiniband* | sed -e 's/ /,/g')"
    ],
    "server_address_setup": [
        "source {src_dir}/find_port.sh",
        "head_node_ip=${{SLURMD_NODENAME}}",
        "vllm_port_number=$(find_available_port $head_node_ip 8080 65535)",
        'server_address="http://${{head_node_ip}}:${{vllm_port_number}}/v1"\n',
        "echo $server_address\n",
    ],
    "write_to_json": [
        "het_job_id=$(($SLURM_JOB_ID+{het_group_id}))",
        'json_path="{log_dir}/{slurm_job_name}.$het_job_id/{model_name}.$het_job_id.json"',
        'jq --arg server_addr "$server_address" \\',
        " '. + {{\"server_address\": $server_addr}}' \\",
        ' "$json_path" > temp_{model_name}.json \\',
        ' && mv temp_{model_name}.json "$json_path"\n',
    ],
    "container_command": f"{CONTAINER_MODULE_NAME} exec --nv --bind {{model_weights_path}}{{additional_binds}} --containall {IMAGE_PATH} \\",
    "launch_cmd": [
        "vllm serve {model_weights_path} \\",
        " --served-model-name {model_name} \\",
        ' --host "0.0.0.0" \\',
        " --port $vllm_port_number \\",
    ],
}
@@ -0,0 +1,86 @@
1
+ """Slurm cluster configuration variables."""
2
+
3
+ import os
4
+ import warnings
5
+ from pathlib import Path
6
+ from typing import Any, TypeAlias
7
+
8
+ import yaml
9
+ from typing_extensions import Literal
10
+
11
+
12
+ CACHED_CONFIG_DIR = Path("/model-weights/vec-inf-shared")
13
+
14
+
15
def load_env_config() -> dict[str, Any]:
    """Load the environment configuration.

    The default configuration is read from ``environment.yaml`` in
    ``CACHED_CONFIG_DIR`` when that file exists, otherwise from the
    ``config`` directory next to this package. If the ``VEC_INF_CONFIG_DIR``
    environment variable is set and contains an ``environment.yaml``, its
    top-level keys override the defaults (a shallow ``dict.update``: an
    overridden section replaces the default section as a whole).

    Returns
    -------
    dict[str, Any]
        The merged environment configuration.

    Raises
    ------
    FileNotFoundError
        If the default configuration file cannot be found.
    ValueError
        If a configuration file is not valid YAML.
    """

    def load_yaml_config(path: Path) -> dict[str, Any]:
        """Load YAML config with error handling."""
        try:
            with path.open(encoding="utf-8") as f:
                # An empty file parses to None; treat it as an empty config.
                return yaml.safe_load(f) or {}
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Could not find config: {path}") from err
        except yaml.YAMLError as err:
            raise ValueError(f"Error parsing YAML config at {path}: {err}") from err

    cached_config_path = CACHED_CONFIG_DIR / "environment.yaml"
    default_path = (
        cached_config_path
        if cached_config_path.exists()
        else Path(__file__).resolve().parent.parent / "config" / "environment.yaml"
    )
    config = load_yaml_config(default_path)

    user_path = os.getenv("VEC_INF_CONFIG_DIR")
    if user_path:
        user_path_obj = Path(user_path, "environment.yaml")
        if user_path_obj.exists():
            config.update(load_yaml_config(user_path_obj))
        else:
            # The check is on the file, not the directory, so report the file.
            warnings.warn(
                f"Could not find user config file: {user_path_obj}, "
                f"reverting to default config located at {default_path}",
                UserWarning,
                stacklevel=2,
            )

    return config
50
+
51
+
52
+ _config = load_env_config()
53
+
54
+ # Extract path values
55
+ IMAGE_PATH = _config["paths"]["image_path"]
56
+
57
+ # Extract containerization info
58
+ CONTAINER_LOAD_CMD = _config["containerization"]["module_load_cmd"]
59
+ CONTAINER_MODULE_NAME = _config["containerization"]["module_name"]
60
+
61
+ # Extract limits
62
+ MAX_GPUS_PER_NODE = _config["limits"]["max_gpus_per_node"]
63
+ MAX_NUM_NODES = _config["limits"]["max_num_nodes"]
64
+ MAX_CPUS_PER_TASK = _config["limits"]["max_cpus_per_task"]
65
+
66
+
67
+ # Create dynamic Literal types
68
def create_literal_type(values: list[str], fallback: str = "") -> Any:
    """Build a ``Literal`` type whose members are ``values``.

    When ``values`` is empty, the resulting type has ``fallback`` as its
    only member.
    """
    members = tuple(values) if values else (fallback,)
    # Subscripting Literal with a tuple is the same as listing its members.
    return Literal[members]
73
+
74
+
75
+ QOS: TypeAlias = create_literal_type(_config["allowed_values"]["qos"]) # type: ignore[valid-type]
76
+ PARTITION: TypeAlias = create_literal_type(_config["allowed_values"]["partition"]) # type: ignore[valid-type]
77
+ RESOURCE_TYPE: TypeAlias = create_literal_type( # type: ignore[valid-type]
78
+ _config["allowed_values"]["resource_type"]
79
+ )
80
+
81
+ # Extract required arguments, for launching jobs that don't have a default value and
82
+ # their corresponding environment variables
83
+ REQUIRED_ARGS: dict[str, str] = _config["required_args"]
84
+
85
+ # Extract default arguments
86
+ DEFAULT_ARGS: dict[str, str] = _config["default_args"]