vec-inf 0.7.3__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
-"""Class for generating Slurm scripts to run vLLM servers.
+"""Class for generating Slurm scripts to run inference servers.
 
-This module provides functionality to generate Slurm scripts for running vLLM servers
-in both single-node and multi-node configurations.
+This module provides functionality to generate Slurm scripts for running inference
+servers in both single-node and multi-node configurations.
 """
 
 from datetime import datetime
@@ -14,11 +14,11 @@ from vec_inf.client._slurm_templates import (
     BATCH_SLURM_SCRIPT_TEMPLATE,
     SLURM_SCRIPT_TEMPLATE,
 )
-from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME
+from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME, IMAGE_PATH
 
 
 class SlurmScriptGenerator:
-    """A class to generate Slurm scripts for running vLLM servers.
+    """A class to generate Slurm scripts for running inference servers.
 
     This class handles the generation of Slurm scripts for both single-node and
     multi-node configurations, supporting different virtualization environments
@@ -32,6 +32,7 @@ class SlurmScriptGenerator:
 
     def __init__(self, params: dict[str, Any]):
         self.params = params
+        self.engine = params.get("engine", "vllm")
         self.is_multinode = int(self.params["num_nodes"]) > 1
         self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
         self.additional_binds = (
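
The new engine field is the single switch that drives template selection downstream. A minimal sketch of the dispatch it enables (standalone, not package code; the params values are made up):

    # Sketch of the engine/topology dispatch introduced in this release.
    params = {"engine": "sglang", "num_nodes": 2}

    engine = params.get("engine", "vllm")   # same default as __init__ above
    is_multinode = int(params["num_nodes"]) > 1

    if is_multinode and engine == "vllm":
        setup_key = "multinode_vllm"        # Ray-based setup
    elif is_multinode and engine == "sglang":
        setup_key = "multinode_sglang"      # NCCL init-address setup
    else:
        setup_key = "single_node"
    print(setup_key)  # -> multinode_sglang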
@@ -111,6 +112,7 @@ class SlurmScriptGenerator:
         server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
         server_script.append(
             SLURM_SCRIPT_TEMPLATE["bind_path"].format(
+                work_dir=self.params.get("work_dir", str(Path.home())),
                 model_weights_path=self.model_weights_path,
                 additional_binds=self.additional_binds,
             )
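
The bind path now also mounts a per-user cache directory, {work_dir}/.vec-inf-cache, over $HOME/.cache inside the container. A quick sketch of the rendered export line, using the bind_path template from this diff (paths are invented, and APPTAINER is an assumed container module name):

    bind_path_template = (
        "export APPTAINER_BINDPATH=$APPTAINER_BINDPATH,/dev,/tmp,"
        "{work_dir}/.vec-inf-cache:$HOME/.cache,{model_weights_path}{additional_binds}"
    )
    print(bind_path_template.format(
        work_dir="/scratch/alice",           # params.get("work_dir", str(Path.home()))
        model_weights_path="/models/Meta-Llama-3-8B",
        additional_binds="",
    ))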
@@ -123,16 +125,17 @@ class SlurmScriptGenerator:
         server_script.append(
             SLURM_SCRIPT_TEMPLATE["imports"].format(src_dir=self.params["src_dir"])
         )
-        if self.is_multinode:
+
+        if self.is_multinode and self.engine == "vllm":
             server_setup_str = "\n".join(
-                SLURM_SCRIPT_TEMPLATE["server_setup"]["multinode"]
+                SLURM_SCRIPT_TEMPLATE["server_setup"]["multinode_vllm"]
             ).format(gpus_per_node=self.params["gpus_per_node"])
             if self.use_container:
                 server_setup_str = server_setup_str.replace(
                     "CONTAINER_PLACEHOLDER",
                     SLURM_SCRIPT_TEMPLATE["container_command"].format(
-                        model_weights_path=self.model_weights_path,
                         env_str=self.env_str,
+                        image_path=IMAGE_PATH[self.engine],
                     ),
                 )
             else:
@@ -140,12 +143,16 @@ class SlurmScriptGenerator:
                     "CONTAINER_PLACEHOLDER",
                     "\\",
                 )
+        elif self.is_multinode and self.engine == "sglang":
+            server_setup_str = "\n".join(
+                SLURM_SCRIPT_TEMPLATE["server_setup"]["multinode_sglang"]
+            )
         else:
             server_setup_str = "\n".join(
                 SLURM_SCRIPT_TEMPLATE["server_setup"]["single_node"]
             )
         server_script.append(server_setup_str)
-        server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["find_vllm_port"]))
+        server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["find_server_port"]))
         server_script.append(
             "\n".join(SLURM_SCRIPT_TEMPLATE["write_to_json"]).format(
                 log_dir=self.params["log_dir"], model_name=self.params["model_name"]
@@ -154,38 +161,85 @@ class SlurmScriptGenerator:
         return "\n".join(server_script)
 
     def _generate_launch_cmd(self) -> str:
-        """Generate the vLLM server launch command.
+        """Generate the inference server launch command.
 
-        Creates the command to launch the vLLM server, handling different virtualization
-        environments (venv or singularity/apptainer).
+        Creates the command to launch the inference server, handling different
+        virtualization environments (venv or singularity/apptainer).
 
         Returns
         -------
         str
             Server launch command.
         """
-        launcher_script = ["\n"]
+        if self.is_multinode and self.engine == "sglang":
+            return self._generate_multinode_sglang_launch_cmd()
+
+        launch_cmd = ["\n"]
         if self.use_container:
-            launcher_script.append(
+            launch_cmd.append(
                 SLURM_SCRIPT_TEMPLATE["container_command"].format(
-                    model_weights_path=self.model_weights_path,
                     env_str=self.env_str,
+                    image_path=IMAGE_PATH[self.engine],
                 )
             )
 
-        launcher_script.append(
-            "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
+        launch_cmd.append(
+            "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"][self.engine]).format(  # type: ignore[literal-required]
                 model_weights_path=self.model_weights_path,
                 model_name=self.params["model_name"],
             )
         )
 
-        for arg, value in self.params["vllm_args"].items():
+        for arg, value in self.params["engine_args"].items():
+            if isinstance(value, bool):
+                launch_cmd.append(f" {arg} \\")
+            else:
+                launch_cmd.append(f" {arg} {value} \\")
+
+        # A known bug in vLLM requires setting backend to ray for multi-node
+        # Remove this when the bug is fixed
+        if self.is_multinode:
+            launch_cmd.append(" --distributed-executor-backend ray \\")
+
+        return "\n".join(launch_cmd).rstrip(" \\")
+
+    def _generate_multinode_sglang_launch_cmd(self) -> str:
+        """Generate the launch command for multi-node sglang setup.
+
+        Returns
+        -------
+        str
+            Multi-node sglang launch command.
+        """
+        launch_cmd = "\n" + "\n".join(
+            SLURM_SCRIPT_TEMPLATE["launch_cmd"]["sglang_multinode"]
+        ).format(
+            num_nodes=self.params["num_nodes"],
+            model_weights_path=self.model_weights_path,
+            model_name=self.params["model_name"],
+        )
+
+        container_placeholder = "\\"
+        if self.use_container:
+            container_placeholder = SLURM_SCRIPT_TEMPLATE["container_command"].format(
+                env_str=self.env_str,
+                image_path=IMAGE_PATH[self.engine],
+            )
+        launch_cmd = launch_cmd.replace(
+            "CONTAINER_PLACEHOLDER",
+            container_placeholder,
+        )
+
+        engine_arg_str = ""
+        for arg, value in self.params["engine_args"].items():
             if isinstance(value, bool):
-                launcher_script.append(f" {arg} \\")
+                engine_arg_str += f" {arg} \\\n"
             else:
-                launcher_script.append(f" {arg} {value} \\")
-        return "\n".join(launcher_script)
+                engine_arg_str += f" {arg} {value} \\\n"
+
+        return launch_cmd.replace(
+            "SGLANG_ARGS_PLACEHOLDER", engine_arg_str.rstrip("\\\n")
+        )
 
     def write_to_log_dir(self) -> Path:
         """Write the generated Slurm script to the log directory.
@@ -212,7 +266,7 @@ class BatchSlurmScriptGenerator:
     """A class to generate Slurm scripts for batch mode.
 
     This class handles the generation of Slurm scripts for batch mode, which
-    launches multiple vLLM servers with different configurations in parallel.
+    launches multiple inference servers with different configurations in parallel.
     """
 
     def __init__(self, params: dict[str, Any]):
@@ -246,7 +300,7 @@ class BatchSlurmScriptGenerator:
         return script_path
 
     def _generate_model_launch_script(self, model_name: str) -> Path:
-        """Generate the bash script for launching individual vLLM servers.
+        """Generate the bash script for launching individual inference servers.
 
         Parameters
         ----------
@@ -256,7 +310,7 @@ class BatchSlurmScriptGenerator:
         Returns
         -------
         Path
-            The bash script path for launching the vLLM server.
+            The bash script path for launching the inference server.
         """
         # Generate the bash script content
         script_content = []
@@ -266,6 +320,7 @@ class BatchSlurmScriptGenerator:
         script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
         script_content.append(
             BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
+                work_dir=self.params.get("work_dir", str(Path.home())),
                 model_weights_path=model_params["model_weights_path"],
                 additional_binds=model_params["additional_binds"],
             )
@@ -286,21 +341,23 @@ class BatchSlurmScriptGenerator:
         if self.use_container:
             script_content.append(
                 BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format(
-                    model_weights_path=model_params["model_weights_path"],
+                    image_path=IMAGE_PATH[model_params["engine"]],
                 )
             )
         script_content.append(
-            "\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format(
+            "\n".join(
+                BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"][model_params["engine"]]
+            ).format(
                 model_weights_path=model_params["model_weights_path"],
                 model_name=model_name,
             )
         )
-        for arg, value in model_params["vllm_args"].items():
+        for arg, value in model_params["engine_args"].items():
             if isinstance(value, bool):
                 script_content.append(f" {arg} \\")
             else:
                 script_content.append(f" {arg} {value} \\")
-        script_content[-1] = script_content[-1].replace("\\", "")
+        script_content[-1] = script_content[-1].rstrip(" \\")
         # Write the bash script to the log directory
         launch_script_path = self._write_to_log_dir(
             script_content, f"launch_{model_name}.sh"
@@ -339,12 +396,12 @@ class BatchSlurmScriptGenerator:
         return "\n".join(shebang)
 
     def generate_batch_slurm_script(self) -> Path:
-        """Generate the Slurm script for launching multiple vLLM servers in batch mode.
+        """Generate the Slurm script for launching multiple inference servers in batch.
 
         Returns
         -------
         Path
-            The Slurm script for launching multiple vLLM servers in batch mode.
+            The Slurm script for launching multiple inference servers in batch.
         """
         script_content = []
 
@@ -9,7 +9,7 @@ from typing import TypedDict
 from vec_inf.client._slurm_vars import (
     CONTAINER_LOAD_CMD,
     CONTAINER_MODULE_NAME,
-    IMAGE_PATH,
+    PYTHON_VERSION,
 )
 
 
@@ -38,12 +38,33 @@ class ServerSetupConfig(TypedDict):
     ----------
     single_node : list[str]
         Setup commands for single-node deployments
-    multinode : list[str]
-        Setup commands for multi-node deployments, including Ray initialization
+    multinode_vllm : list[str]
+        Setup commands for multi-node vLLM deployments
+    multinode_sglang : list[str]
+        Setup commands for multi-node SGLang deployments
     """
 
     single_node: list[str]
-    multinode: list[str]
+    multinode_vllm: list[str]
+    multinode_sglang: list[str]
+
+
+class LaunchCmdConfig(TypedDict):
+    """TypedDict for launch command configuration.
+
+    Parameters
+    ----------
+    vllm : list[str]
+        Launch commands for vLLM inference server
+    sglang : list[str]
+        Launch commands for SGLang inference server
+    sglang_multinode : list[str]
+        Launch commands for multi-node SGLang inference server
+    """
+
+    vllm: list[str]
+    sglang: list[str]
+    sglang_multinode: list[str]
 
 
 class SlurmScriptTemplate(TypedDict):
@@ -65,12 +86,12 @@ class SlurmScriptTemplate(TypedDict):
         Template for virtual environment activation
     server_setup : ServerSetupConfig
         Server initialization commands for different deployment modes
-    find_vllm_port : list[str]
-        Commands to find available ports for vLLM server
+    find_server_port : list[str]
+        Commands to find available ports for inference server
     write_to_json : list[str]
         Commands to write server configuration to JSON
-    launch_cmd : list[str]
-        vLLM server launch commands
+    launch_cmd : LaunchCmdConfig
+        Inference server launch commands
     """
 
     shebang: ShebangConfig
@@ -80,33 +101,31 @@ class SlurmScriptTemplate(TypedDict):
     container_command: str
     activate_venv: str
     server_setup: ServerSetupConfig
-    find_vllm_port: list[str]
+    find_server_port: list[str]
     write_to_json: list[str]
-    launch_cmd: list[str]
+    launch_cmd: LaunchCmdConfig
 
 
 SLURM_SCRIPT_TEMPLATE: SlurmScriptTemplate = {
     "shebang": {
         "base": "#!/bin/bash",
         "multinode": [
-            "#SBATCH --exclusive",
-            "#SBATCH --tasks-per-node=1",
+            "#SBATCH --ntasks-per-node=1",
         ],
     },
     "container_setup": [
         CONTAINER_LOAD_CMD,
-        f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop",
     ],
     "imports": "source {src_dir}/find_port.sh",
-    "bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}",
-    "container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --containall {IMAGE_PATH} \\",
+    "bind_path": f"export {CONTAINER_MODULE_NAME_UPPER}_BINDPATH=${CONTAINER_MODULE_NAME_UPPER}_BINDPATH,/dev,/tmp,{{work_dir}}/.vec-inf-cache:$HOME/.cache,{{model_weights_path}}{{additional_binds}}",
+    "container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --containall {{image_path}} \\",
    "activate_venv": "source {venv}/bin/activate",
    "server_setup": {
        "single_node": [
            "\n# Find available port",
-            "head_node_ip=${SLURMD_NODENAME}",
+            "head_node=${SLURMD_NODENAME}",
        ],
-        "multinode": [
+        "multinode_vllm": [
            "\n# Get list of nodes",
            'nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")',
            "nodes_array=($nodes)",
@@ -130,7 +149,7 @@ SLURM_SCRIPT_TEMPLATE: SlurmScriptTemplate = {
             " fi",
             "fi",
             "\n# Start Ray head node",
-            "head_node_port=$(find_available_port $head_node_ip 8080 65535)",
+            "head_node_port=$(find_available_port $head_node 8080 65535)",
             "ray_head=$head_node_ip:$head_node_port",
             'echo "Ray Head IP: $ray_head"',
             'echo "Starting HEAD at $head_node"',
@@ -151,10 +170,19 @@ SLURM_SCRIPT_TEMPLATE: SlurmScriptTemplate = {
             " sleep 5",
             "done",
         ],
+        "multinode_sglang": [
+            "\n# Set NCCL initialization address using the hostname of the head node",
+            'nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")',
+            "nodes_array=($nodes)",
+            "head_node=${nodes_array[0]}",
+            "NCCL_PORT=$(find_available_port $head_node 8000 65535)",
+            'NCCL_INIT_ADDR="${head_node}:${NCCL_PORT}"',
+            'echo "[INFO] NCCL_INIT_ADDR: $NCCL_INIT_ADDR"',
+        ],
     },
-    "find_vllm_port": [
-        "\nvllm_port_number=$(find_available_port $head_node_ip 8080 65535)",
-        'server_address="http://${head_node_ip}:${vllm_port_number}/v1"',
+    "find_server_port": [
+        "\nserver_port_number=$(find_available_port $head_node 8080 65535)",
+        'server_address="http://${head_node}:${server_port_number}/v1"',
     ],
     "write_to_json": [
         '\njson_path="{log_dir}/{model_name}.$SLURM_JOB_ID/{model_name}.$SLURM_JOB_ID.json"',
@@ -163,12 +191,39 @@ SLURM_SCRIPT_TEMPLATE: SlurmScriptTemplate = {
         ' "$json_path" > temp.json \\',
         ' && mv temp.json "$json_path"',
     ],
-    "launch_cmd": [
-        "vllm serve {model_weights_path} \\",
-        " --served-model-name {model_name} \\",
-        ' --host "0.0.0.0" \\',
-        " --port $vllm_port_number \\",
-    ],
+    "launch_cmd": {
+        "vllm": [
+            "vllm serve {model_weights_path} \\",
+            " --served-model-name {model_name} \\",
+            ' --host "0.0.0.0" \\',
+            " --port $server_port_number \\",
+        ],
+        "sglang": [
+            f"{PYTHON_VERSION} -m sglang.launch_server \\",
+            " --model-path {model_weights_path} \\",
+            " --served-model-name {model_name} \\",
+            ' --host "0.0.0.0" \\',
+            " --port $server_port_number \\",
+        ],
+        "sglang_multinode": [
+            "for ((i = 0; i < $SLURM_JOB_NUM_NODES; i++)); do",
+            " node_i=${{nodes_array[$i]}}",
+            ' echo "Launching SGLang server on $node_i"',
+            ' srun --ntasks=1 --nodes=1 -w "$node_i" \\',
+            " CONTAINER_PLACEHOLDER",
+            f" {PYTHON_VERSION} -m sglang.launch_server \\",
+            " --model-path {model_weights_path} \\",
+            " --served-model-name {model_name} \\",
+            ' --host "0.0.0.0" \\',
+            " --port $server_port_number \\",
+            ' --nccl-init-addr "$NCCL_INIT_ADDR" \\',
+            " --nnodes {num_nodes} \\",
+            ' --node-rank "$i" \\',
+            "SGLANG_ARGS_PLACEHOLDER &",
+            "done",
+            "\nwait",
+        ],
+    },
 }
 
 
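The sglang_multinode entry is rendered in two passes: .format() fills the model fields (the doubled braces keep bash expansions like ${nodes_array[$i]} literal), then the two placeholders are string-replaced. A sketch of that rendering, assuming vec-inf 0.8.1 is installed (argument values are made up):

    from vec_inf.client._slurm_templates import SLURM_SCRIPT_TEMPLATE

    rendered = "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]["sglang_multinode"]).format(
        num_nodes=2,
        model_weights_path="/models/Meta-Llama-3-8B",
        model_name="Meta-Llama-3-8B",
    )
    rendered = rendered.replace("CONTAINER_PLACEHOLDER", "\\")          # no-container case
    rendered = rendered.replace("SGLANG_ARGS_PLACEHOLDER", " --tp 2")   # sample extra args
    print(rendered)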
@@ -184,7 +239,7 @@ class BatchSlurmScriptTemplate(TypedDict):
     permission_update : str
         Command to update permissions of the script
     launch_model_scripts : list[str]
-        Commands to launch the vLLM server
+        Commands to run server launch scripts
     """
 
     shebang: str
@@ -220,7 +275,7 @@ class BatchModelLaunchScriptTemplate(TypedDict):
     server_address_setup : list[str]
         Commands to setup the server address
     launch_cmd : list[str]
-        Commands to launch the vLLM server
+        Commands to launch the inference server
     container_command : str
         Commands to setup the container command
     """
@@ -230,19 +285,19 @@ class BatchModelLaunchScriptTemplate(TypedDict):
     bind_path: str
     server_address_setup: list[str]
     write_to_json: list[str]
-    launch_cmd: list[str]
+    launch_cmd: dict[str, list[str]]
     container_command: str
 
 
 BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE: BatchModelLaunchScriptTemplate = {
     "shebang": "#!/bin/bash\n",
     "container_setup": f"{CONTAINER_LOAD_CMD}\n",
-    "bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}",
+    "bind_path": f"export {CONTAINER_MODULE_NAME_UPPER}_BINDPATH=${CONTAINER_MODULE_NAME_UPPER}_BINDPATH,/dev,/tmp,{{work_dir}}/.vec-inf-cache:$HOME/.cache,{{model_weights_path}}{{additional_binds}}",
     "server_address_setup": [
         "source {src_dir}/find_port.sh",
         "head_node_ip=${{SLURMD_NODENAME}}",
-        "vllm_port_number=$(find_available_port $head_node_ip 8080 65535)",
-        'server_address="http://${{head_node_ip}}:${{vllm_port_number}}/v1"\n',
+        "server_port_number=$(find_available_port $head_node_ip 8080 65535)",
+        'server_address="http://${{head_node_ip}}:${{server_port_number}}/v1"\n',
         "echo $server_address\n",
     ],
     "write_to_json": [
@@ -253,11 +308,20 @@ BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE: BatchModelLaunchScriptTemplate = {
         ' "$json_path" > temp_{model_name}.json \\',
         ' && mv temp_{model_name}.json "$json_path"\n',
     ],
-    "container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {IMAGE_PATH} \\",
-    "launch_cmd": [
-        "vllm serve {model_weights_path} \\",
-        " --served-model-name {model_name} \\",
-        ' --host "0.0.0.0" \\',
-        " --port $vllm_port_number \\",
-    ],
+    "container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {{image_path}} \\",
+    "launch_cmd": {
+        "vllm": [
+            "vllm serve {model_weights_path} \\",
+            " --served-model-name {model_name} \\",
+            ' --host "0.0.0.0" \\',
+            " --port $server_port_number \\",
+        ],
+        "sglang": [
+            f"{PYTHON_VERSION} -m sglang.launch_server \\",
+            " --model-path {model_weights_path} \\",
+            " --served-model-name {model_name} \\",
+            ' --host "0.0.0.0" \\',
+            " --port $server_port_number \\",
+        ],
+    },
 }
@@ -52,7 +52,11 @@ def load_env_config() -> dict[str, Any]:
 _config = load_env_config()
 
 # Extract path values
-IMAGE_PATH = _config["paths"]["image_path"]
+IMAGE_PATH = {
+    "vllm": _config["paths"]["vllm_image_path"],
+    "sglang": _config["paths"]["sglang_image_path"],
+}
+CACHED_MODEL_CONFIG_PATH = Path(_config["paths"]["cached_model_config_path"])
 
 # Extract containerization info
 CONTAINER_LOAD_CMD = _config["containerization"]["module_load_cmd"]
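
IMAGE_PATH is now an engine-to-image mapping rather than a single path, and the cached model config location is exposed as a Path. A sketch of the "paths" section the environment config is now expected to carry (key names come from this diff; the values are placeholders):

    from pathlib import Path

    _config = {
        "paths": {
            "vllm_image_path": "/opt/images/vllm.sif",        # placeholder path
            "sglang_image_path": "/opt/images/sglang.sif",    # placeholder path
            "cached_model_config_path": "/opt/vec-inf/models.yaml",
        }
    }
    IMAGE_PATH = {
        "vllm": _config["paths"]["vllm_image_path"],
        "sglang": _config["paths"]["sglang_image_path"],
    }
    CACHED_MODEL_CONFIG_PATH = Path(_config["paths"]["cached_model_config_path"])
    print(IMAGE_PATH["sglang"])  # -> /opt/images/sglang.sif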
@@ -78,9 +82,14 @@ RESOURCE_TYPE: TypeAlias = create_literal_type( # type: ignore[valid-type]
     _config["allowed_values"]["resource_type"]
 )
 
-# Extract required arguments, for launching jobs that don't have a default value and
-# their corresponding environment variables
-REQUIRED_ARGS: dict[str, str] = _config["required_args"]
+# Model types available derived from the cached model config
+MODEL_TYPES: TypeAlias = create_literal_type(_config["model_types"])  # type: ignore[valid-type]
+
+# Required arguments for launching jobs and corresponding environment variables
+REQUIRED_ARGS: dict[str, str | None] = _config["required_args"]
+
+# Running sglang requires python version
+PYTHON_VERSION: str = _config["python_version"]
 
 # Extract default arguments
 DEFAULT_ARGS: dict[str, str] = _config["default_args"]
vec_inf/client/_utils.py CHANGED
@@ -16,7 +16,7 @@ import yaml
 
 from vec_inf.client._client_vars import MODEL_READY_SIGNATURE
 from vec_inf.client._exceptions import MissingRequiredFieldsError
-from vec_inf.client._slurm_vars import CACHED_CONFIG_DIR, REQUIRED_ARGS
+from vec_inf.client._slurm_vars import CACHED_MODEL_CONFIG_PATH, REQUIRED_ARGS
 from vec_inf.client.config import ModelConfig
 from vec_inf.client.models import ModelStatus
 
@@ -77,7 +77,7 @@ def read_slurm_log(
             json_content: dict[str, str] = json.load(file)
             return json_content
         else:
-            with file_path.open("r") as file:
+            with file_path.open("r", errors="replace") as file:
                 return file.readlines()
     except FileNotFoundError:
         return f"LOG FILE NOT FOUND: {file_path}"
@@ -249,7 +249,7 @@ def load_config(config_path: Optional[str] = None) -> list[ModelConfig]:
     -----
     Configuration is loaded from:
     1. User path: specified by config_path
-    2. Default path: package's config/models.yaml or CACHED_CONFIG if it exists
+    2. Default path: package's config/models.yaml or CACHED_MODEL_CONFIG_PATH if exists
     3. Environment variable: specified by VEC_INF_CONFIG environment variable
        and merged with default config
 
@@ -303,8 +303,8 @@ def load_config(config_path: Optional[str] = None) -> list[ModelConfig]:
 
     # 2. Otherwise, load default config
     default_path = (
-        CACHED_CONFIG_DIR / "models.yaml"
-        if CACHED_CONFIG_DIR.exists()
+        CACHED_MODEL_CONFIG_PATH
+        if CACHED_MODEL_CONFIG_PATH.exists()
         else Path(__file__).resolve().parent.parent / "config" / "models.yaml"
     )
     config = load_yaml_config(default_path)
@@ -444,10 +444,13 @@ def check_required_fields(params: dict[str, Any]) -> dict[str, Any]:
     params : dict[str, Any]
         Dictionary of parameters to check.
     """
-    env_overrides = {}
+    env_overrides: dict[str, str] = {}
+
+    if not REQUIRED_ARGS:
+        return env_overrides
     for arg in REQUIRED_ARGS:
         if not params.get(arg):
-            default_value = os.getenv(REQUIRED_ARGS[arg])
+            default_value = os.getenv(str(REQUIRED_ARGS[arg]))
             if default_value:
                 params[arg] = default_value
                 env_overrides[arg] = default_value
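
The loop fills missing required params from environment variables named in REQUIRED_ARGS, which may now be empty or carry None values (hence the guard and the str() cast). A standalone sketch with a hypothetical mapping entry:

    import os

    REQUIRED_ARGS: dict[str, str | None] = {"account": "VEC_INF_ACCOUNT"}  # hypothetical entry
    params: dict[str, str | None] = {"account": None}
    env_overrides: dict[str, str] = {}

    os.environ["VEC_INF_ACCOUNT"] = "my-slurm-account"
    for arg in REQUIRED_ARGS:
        if not params.get(arg):
            default_value = os.getenv(str(REQUIRED_ARGS[arg]))
            if default_value:
                params[arg] = default_value
                env_overrides[arg] = default_value
    print(env_overrides)  # {'account': 'my-slurm-account'}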
vec_inf/client/config.py CHANGED
@@ -8,13 +8,13 @@ from pathlib import Path
 from typing import Any, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field
-from typing_extensions import Literal
 
 from vec_inf.client._slurm_vars import (
     DEFAULT_ARGS,
     MAX_CPUS_PER_TASK,
     MAX_GPUS_PER_NODE,
     MAX_NUM_NODES,
+    MODEL_TYPES,
     PARTITION,
     QOS,
     RESOURCE_TYPE,
@@ -66,8 +66,12 @@ class ModelConfig(BaseModel):
         Directory path for storing logs
     model_weights_parent_dir : Path, optional
         Base directory containing model weights
+    engine: str, optional
+        Inference engine to be used, supports 'vllm' and 'sglang'
     vllm_args : dict[str, Any], optional
         Additional arguments for vLLM engine configuration
+    sglang_args : dict[str, Any], optional
+        Additional arguments for SGLang engine configuration
 
     Notes
     -----
@@ -75,14 +79,16 @@ class ModelConfig(BaseModel):
     configured to be immutable (frozen) and forbids extra fields.
     """
 
+    model_config = ConfigDict(
+        extra="ignore", str_strip_whitespace=True, validate_default=True, frozen=True
+    )
+
     model_name: str = Field(..., min_length=3, pattern=r"^[a-zA-Z0-9\-_\.]+$")
     model_family: str = Field(..., min_length=2)
     model_variant: Optional[str] = Field(
         default=None, description="Specific variant/version of the model family"
     )
-    model_type: Literal["LLM", "VLM", "Text_Embedding", "Reward_Modeling"] = Field(
-        ..., description="Type of model architecture"
-    )
+    model_type: MODEL_TYPES = Field(..., description="Type of model architecture")
     gpus_per_node: int = Field(
         ..., gt=0, le=MAX_GPUS_PER_NODE, description="GPUs per node"
     )
@@ -148,12 +154,16 @@ class ModelConfig(BaseModel):
         default=Path(DEFAULT_ARGS["model_weights_parent_dir"]),
         description="Base directory for model weights",
     )
+    engine: Optional[str] = Field(
+        default="vllm",
+        description="Inference engine to be used, supports 'vllm' and 'sglang'",
+    )
     vllm_args: Optional[dict[str, Any]] = Field(
         default={}, description="vLLM engine arguments"
     )
+    sglang_args: Optional[dict[str, Any]] = Field(
+        default={}, description="SGLang engine arguments"
+    )
     env: Optional[dict[str, Any]] = Field(
         default={}, description="Environment variables to be set"
     )
-    model_config = ConfigDict(
-        extra="forbid", str_strip_whitespace=True, validate_default=True, frozen=True
-    )
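
With both vllm_args and sglang_args on the model config, callers need to pick the dict matching the selected engine when building the engine_args the generators consume. A hypothetical helper showing that wiring (glue code, not part of the package):

    from typing import Any

    def select_engine_args(config: dict[str, Any]) -> dict[str, Any]:
        """Pick the per-engine args dict; hypothetical glue, not package code."""
        engine = config.get("engine") or "vllm"   # matches the Field default above
        key = "sglang_args" if engine == "sglang" else "vllm_args"
        return config.get(key) or {}

    cfg = {"engine": "sglang", "sglang_args": {"--mem-fraction-static": 0.85}}
    print(select_engine_args(cfg))  # {'--mem-fraction-static': 0.85}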