vec-inf 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vec_inf/cli/_cli.py CHANGED
@@ -69,6 +69,16 @@ def cli() -> None:
69
69
  type=int,
70
70
  help="Number of GPUs/node to use, default to suggested resource allocation for model",
71
71
  )
72
+ @click.option(
73
+ "--cpus-per-task",
74
+ type=int,
75
+ help="Number of CPU cores per task",
76
+ )
77
+ @click.option(
78
+ "--mem-per-node",
79
+ type=str,
80
+ help="Memory allocation per node in GB format (e.g., '32G')",
81
+ )
72
82
  @click.option(
73
83
  "--account",
74
84
  "-A",
@@ -165,6 +175,10 @@ def launch(
165
175
  Number of nodes to use
166
176
  - gpus_per_node : int, optional
167
177
  Number of GPUs per node
178
+ - cpus_per_task : int, optional
179
+ Number of CPU cores per task
180
+ - mem_per_node : str, optional
181
+ Memory allocation per node in GB format (e.g., '32G')
168
182
  - account : str, optional
169
183
  Charge resources used by this job to specified account
170
184
  - work_dir : str, optional
@@ -447,7 +461,7 @@ def metrics(slurm_job_id: str) -> None:
447
461
  metrics_formatter.format_metrics()
448
462
 
449
463
  live.update(metrics_formatter.table)
450
- time.sleep(2)
464
+ time.sleep(1)
451
465
  except click.ClickException as e:
452
466
  raise e
453
467
  except Exception as e:
vec_inf/cli/_helper.py CHANGED
@@ -36,6 +36,43 @@ class LaunchResponseFormatter:
36
36
  self.model_name = model_name
37
37
  self.params = params
38
38
 
39
+ def _add_resource_allocation_details(self, table: Table) -> None:
40
+ """Add resource allocation details to the table."""
41
+ optional_fields = [
42
+ ("account", "Account"),
43
+ ("work_dir", "Working Directory"),
44
+ ("resource_type", "Resource Type"),
45
+ ("partition", "Partition"),
46
+ ("qos", "QoS"),
47
+ ]
48
+ for key, label in optional_fields:
49
+ if self.params.get(key):
50
+ table.add_row(label, self.params[key])
51
+
52
+ def _add_vllm_config(self, table: Table) -> None:
53
+ """Add vLLM configuration details to the table."""
54
+ if self.params.get("vllm_args"):
55
+ table.add_row("vLLM Arguments:", style="magenta")
56
+ for arg, value in self.params["vllm_args"].items():
57
+ table.add_row(f" {arg}:", str(value))
58
+
59
+ def _add_env_vars(self, table: Table) -> None:
60
+ """Add environment variable configuration details to the table."""
61
+ if self.params.get("env"):
62
+ table.add_row("Environment Variables", style="magenta")
63
+ for arg, value in self.params["env"].items():
64
+ table.add_row(f" {arg}:", str(value))
65
+
66
+ def _add_bind_paths(self, table: Table) -> None:
67
+ """Add bind path configuration details to the table."""
68
+ if self.params.get("bind"):
69
+ table.add_row("Bind Paths", style="magenta")
70
+ for path in self.params["bind"].split(","):
71
+ host = target = path
72
+ if ":" in path:
73
+ host, target = path.split(":")
74
+ table.add_row(f" {host}:", target)
75
+
39
76
  def format_table_output(self) -> Table:
40
77
  """Format output as rich Table.
41
78
 
@@ -59,16 +96,7 @@ class LaunchResponseFormatter:
59
96
  table.add_row("Vocabulary Size", self.params["vocab_size"])
60
97
 
61
98
  # Add resource allocation details
62
- if self.params.get("account"):
63
- table.add_row("Account", self.params["account"])
64
- if self.params.get("work_dir"):
65
- table.add_row("Working Directory", self.params["work_dir"])
66
- if self.params.get("resource_type"):
67
- table.add_row("Resource Type", self.params["resource_type"])
68
- if self.params.get("partition"):
69
- table.add_row("Partition", self.params["partition"])
70
- if self.params.get("qos"):
71
- table.add_row("QoS", self.params["qos"])
99
+ self._add_resource_allocation_details(table)
72
100
  table.add_row("Time Limit", self.params["time"])
73
101
  table.add_row("Num Nodes", self.params["num_nodes"])
74
102
  table.add_row("GPUs/Node", self.params["gpus_per_node"])
@@ -76,21 +104,18 @@ class LaunchResponseFormatter:
76
104
  table.add_row("Memory/Node", self.params["mem_per_node"])
77
105
 
78
106
  # Add job config details
107
+ if self.params.get("venv"):
108
+ table.add_row("Virtual Environment", self.params["venv"])
79
109
  table.add_row(
80
110
  "Model Weights Directory",
81
111
  str(Path(self.params["model_weights_parent_dir"], self.model_name)),
82
112
  )
83
113
  table.add_row("Log Directory", self.params["log_dir"])
84
114
 
85
- # Add vLLM configuration details
86
- table.add_row("vLLM Arguments:", style="magenta")
87
- for arg, value in self.params["vllm_args"].items():
88
- table.add_row(f" {arg}:", str(value))
89
-
90
- # Add Environment Variable Configuration Details
91
- table.add_row("Environment Variables", style="magenta")
92
- for arg, value in self.params["env"].items():
93
- table.add_row(f" {arg}:", str(value))
115
+ # Add configuration details
116
+ self._add_vllm_config(table)
117
+ self._add_env_vars(table)
118
+ self._add_bind_paths(table)
94
119
 
95
120
  return table
96
121
 
vec_inf/client/_client_vars.py CHANGED
@@ -71,10 +71,3 @@ VLLM_SHORT_TO_LONG_MAP = {
71
71
 
72
72
  # Required matching arguments for batch mode
73
73
  BATCH_MODE_REQUIRED_MATCHING_ARGS = ["venv", "log_dir"]
74
-
75
- # Required arguments for launching jobs that don't have a default value and their
76
- # corresponding environment variables
77
- REQUIRED_ARGS = {
78
- "account": "VEC_INF_ACCOUNT",
79
- "work_dir": "VEC_INF_WORK_DIR",
80
- }
vec_inf/client/_helper.py CHANGED
@@ -31,6 +31,7 @@ from vec_inf.client._slurm_script_generator import (
31
31
  BatchSlurmScriptGenerator,
32
32
  SlurmScriptGenerator,
33
33
  )
34
+ from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME, IMAGE_PATH
34
35
  from vec_inf.client.config import ModelConfig
35
36
  from vec_inf.client.models import (
36
37
  BatchLaunchResponse,
@@ -195,23 +196,14 @@ class ModelLauncher:
195
196
  print(f"WARNING: Could not parse env var: {line}")
196
197
  return env_vars
197
198
 
198
- def _get_launch_params(self) -> dict[str, Any]:
199
- """Prepare launch parameters, set log dir, and validate required fields.
200
-
201
- Returns
202
- -------
203
- dict[str, Any]
204
- Dictionary of prepared launch parameters
199
+ def _apply_cli_overrides(self, params: dict[str, Any]) -> None:
200
+ """Apply CLI argument overrides to params.
205
201
 
206
- Raises
207
- ------
208
- MissingRequiredFieldsError
209
- If required fields are missing or tensor parallel size is not specified
210
- when using multiple GPUs
202
+ Parameters
203
+ ----------
204
+ params : dict[str, Any]
205
+ Dictionary of launch parameters to override
211
206
  """
212
- params = self.model_config.model_dump(exclude_none=True)
213
-
214
- # Override config defaults with CLI arguments
215
207
  if self.kwargs.get("vllm_args"):
216
208
  vllm_args = self._process_vllm_args(self.kwargs["vllm_args"])
217
209
  for key, value in vllm_args.items():
@@ -224,13 +216,29 @@ class ModelLauncher:
224
216
  params["env"][key] = str(value)
225
217
  del self.kwargs["env"]
226
218
 
219
+ if self.kwargs.get("bind") and params.get("bind"):
220
+ params["bind"] = f"{params['bind']},{self.kwargs['bind']}"
221
+ del self.kwargs["bind"]
222
+
227
223
  for key, value in self.kwargs.items():
228
224
  params[key] = value
229
225
 
230
- # Check for required fields without default vals, will raise an error if missing
231
- utils.check_required_fields(params)
226
+ def _validate_resource_allocation(self, params: dict[str, Any]) -> None:
227
+ """Validate resource allocation and parallelization settings.
232
228
 
233
- # Validate resource allocation and parallelization settings
229
+ Parameters
230
+ ----------
231
+ params : dict[str, Any]
232
+ Dictionary of launch parameters to validate
233
+
234
+ Raises
235
+ ------
236
+ MissingRequiredFieldsError
237
+ If tensor parallel size is not specified when using multiple GPUs
238
+ ValueError
239
+ If total # of GPUs requested is not a power of two
240
+ If mismatch between total # of GPUs requested and parallelization settings
241
+ """
234
242
  if (
235
243
  int(params["gpus_per_node"]) > 1
236
244
  and params["vllm_args"].get("--tensor-parallel-size") is None
@@ -251,19 +259,18 @@ class ModelLauncher:
251
259
  "Mismatch between total number of GPUs requested and parallelization settings"
252
260
  )
253
261
 
254
- # Convert gpus_per_node and resource_type to gres
255
- resource_type = params.get("resource_type")
256
- if resource_type:
257
- params["gres"] = f"gpu:{resource_type}:{params['gpus_per_node']}"
258
- else:
259
- params["gres"] = f"gpu:{params['gpus_per_node']}"
262
+ def _setup_log_files(self, params: dict[str, Any]) -> None:
263
+ """Set up log directory and file paths.
260
264
 
261
- # Create log directory
265
+ Parameters
266
+ ----------
267
+ params : dict[str, Any]
268
+ Dictionary of launch parameters to set up log files
269
+ """
262
270
  params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser()
263
271
  params["log_dir"].mkdir(parents=True, exist_ok=True)
264
272
  params["src_dir"] = SRC_DIR
265
273
 
266
- # Construct slurm log file paths
267
274
  params["out_file"] = (
268
275
  f"{params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out"
269
276
  )
@@ -274,6 +281,35 @@ class ModelLauncher:
274
281
  f"{params['log_dir']}/{self.model_name}.$SLURM_JOB_ID/{self.model_name}.$SLURM_JOB_ID.json"
275
282
  )
276
283
 
284
+ def _get_launch_params(self) -> dict[str, Any]:
285
+ """Prepare launch parameters, set log dir, and validate required fields.
286
+
287
+ Returns
288
+ -------
289
+ dict[str, Any]
290
+ Dictionary of prepared launch parameters
291
+ """
292
+ params = self.model_config.model_dump(exclude_none=True)
293
+
294
+ # Override config defaults with CLI arguments
295
+ self._apply_cli_overrides(params)
296
+
297
+ # Check for required fields without default vals, will raise an error if missing
298
+ utils.check_required_fields(params)
299
+
300
+ # Validate resource allocation and parallelization settings
301
+ self._validate_resource_allocation(params)
302
+
303
+ # Convert gpus_per_node and resource_type to gres
304
+ resource_type = params.get("resource_type")
305
+ if resource_type:
306
+ params["gres"] = f"gpu:{resource_type}:{params['gpus_per_node']}"
307
+ else:
308
+ params["gres"] = f"gpu:{params['gpus_per_node']}"
309
+
310
+ # Setup log files
311
+ self._setup_log_files(params)
312
+
277
313
  # Convert path to string for JSON serialization
278
314
  for field in params:
279
315
  if field in ["vllm_args", "env"]:
@@ -332,6 +368,10 @@ class ModelLauncher:
332
368
  job_log_dir / f"{self.model_name}.{self.slurm_job_id}.sbatch"
333
369
  )
334
370
 
371
+ # Replace venv with image path if using container
372
+ if self.params["venv"] == CONTAINER_MODULE_NAME:
373
+ self.params["venv"] = IMAGE_PATH
374
+
335
375
  with job_json.open("w") as file:
336
376
  json.dump(self.params, file, indent=4)
337
377
 
vec_inf/client/_slurm_script_generator.py CHANGED
@@ -14,6 +14,7 @@ from vec_inf.client._slurm_templates import (
14
14
  BATCH_SLURM_SCRIPT_TEMPLATE,
15
15
  SLURM_SCRIPT_TEMPLATE,
16
16
  )
17
+ from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME
17
18
 
18
19
 
19
20
  class SlurmScriptGenerator:
@@ -32,24 +33,35 @@ class SlurmScriptGenerator:
32
33
  def __init__(self, params: dict[str, Any]):
33
34
  self.params = params
34
35
  self.is_multinode = int(self.params["num_nodes"]) > 1
35
- self.use_container = (
36
- self.params["venv"] == "singularity" or self.params["venv"] == "apptainer"
37
- )
36
+ self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
38
37
  self.additional_binds = self.params.get("bind", "")
39
38
  if self.additional_binds:
40
39
  self.additional_binds = f" --bind {self.additional_binds}"
41
40
  self.model_weights_path = str(
42
41
  Path(self.params["model_weights_parent_dir"], self.params["model_name"])
43
42
  )
43
+ self.env_str = self._generate_env_str()
44
+
45
+ def _generate_env_str(self) -> str:
46
+ """Generate the environment variables string for the Slurm script.
47
+
48
+ Returns
49
+ -------
50
+ str
51
+ Formatted env vars string for container or shell export commands.
52
+ """
44
53
  env_dict: dict[str, str] = self.params.get("env", {})
45
- # Create string of environment variables
46
- self.env_str = ""
47
- for key, val in env_dict.items():
48
- if len(self.env_str) == 0:
49
- self.env_str = "--env "
50
- else:
51
- self.env_str += ","
52
- self.env_str += key + "=" + val
54
+
55
+ if not env_dict:
56
+ return ""
57
+
58
+ if self.use_container:
59
+ # Format for container: --env KEY1=VAL1,KEY2=VAL2
60
+ env_pairs = [f"{key}={val}" for key, val in env_dict.items()]
61
+ return f"--env {','.join(env_pairs)}"
62
+ # Format for shell: export KEY1=VAL1\nexport KEY2=VAL2
63
+ export_lines = [f"export {key}={val}" for key, val in env_dict.items()]
64
+ return "\n".join(export_lines)
53
65
 
54
66
  def _generate_script_content(self) -> str:
55
67
  """Generate the complete Slurm script content.
@@ -95,7 +107,12 @@ class SlurmScriptGenerator:
95
107
  server_script = ["\n"]
96
108
  if self.use_container:
97
109
  server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
98
- server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["env_vars"]))
110
+ server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_env_vars"]))
111
+ else:
112
+ server_script.append(
113
+ SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
114
+ )
115
+ server_script.append(self.env_str)
99
116
  server_script.append(
100
117
  SLURM_SCRIPT_TEMPLATE["imports"].format(src_dir=self.params["src_dir"])
101
118
  )
@@ -112,6 +129,11 @@ class SlurmScriptGenerator:
112
129
  env_str=self.env_str,
113
130
  ),
114
131
  )
132
+ else:
133
+ server_setup_str = server_setup_str.replace(
134
+ "CONTAINER_PLACEHOLDER",
135
+ "\\",
136
+ )
115
137
  else:
116
138
  server_setup_str = "\n".join(
117
139
  SLURM_SCRIPT_TEMPLATE["server_setup"]["single_node"]
@@ -145,10 +167,7 @@ class SlurmScriptGenerator:
145
167
  env_str=self.env_str,
146
168
  )
147
169
  )
148
- else:
149
- launcher_script.append(
150
- SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
151
- )
170
+
152
171
  launcher_script.append(
153
172
  "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
154
173
  model_weights_path=self.model_weights_path,
@@ -194,9 +213,7 @@ class BatchSlurmScriptGenerator:
194
213
  def __init__(self, params: dict[str, Any]):
195
214
  self.params = params
196
215
  self.script_paths: list[Path] = []
197
- self.use_container = (
198
- self.params["venv"] == "singularity" or self.params["venv"] == "apptainer"
199
- )
216
+ self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
200
217
  for model_name in self.params["models"]:
201
218
  self.params["models"][model_name]["additional_binds"] = ""
202
219
  if self.params["models"][model_name].get("bind"):
vec_inf/client/_slurm_templates.py CHANGED
@@ -74,7 +74,7 @@ class SlurmScriptTemplate(TypedDict):
74
74
  shebang: ShebangConfig
75
75
  container_setup: list[str]
76
76
  imports: str
77
- env_vars: list[str]
77
+ container_env_vars: list[str]
78
78
  container_command: str
79
79
  activate_venv: str
80
80
  server_setup: ServerSetupConfig
@@ -96,8 +96,8 @@ SLURM_SCRIPT_TEMPLATE: SlurmScriptTemplate = {
96
96
  f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop",
97
97
  ],
98
98
  "imports": "source {src_dir}/find_port.sh",
99
- "env_vars": [
100
- f"export {CONTAINER_MODULE_NAME}_BINDPATH=${CONTAINER_MODULE_NAME}_BINDPATH,$(echo /dev/infiniband* | sed -e 's/ /,/g')"
99
+ "container_env_vars": [
100
+ f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp"
101
101
  ],
102
102
  "container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --bind {{model_weights_path}}{{additional_binds}} --containall {IMAGE_PATH} \\",
103
103
  "activate_venv": "source {venv}/bin/activate",
@@ -112,6 +112,23 @@ SLURM_SCRIPT_TEMPLATE: SlurmScriptTemplate = {
112
112
  "nodes_array=($nodes)",
113
113
  "head_node=${{nodes_array[0]}}",
114
114
  'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)',
115
+ "\n# Check for RDMA devices and set environment variable accordingly",
116
+ "if ! command -v ibv_devices >/dev/null 2>&1; then",
117
+ ' echo "ibv_devices not found; forcing TCP. (No RDMA userland on host?)"',
118
+ " export NCCL_IB_DISABLE=1",
119
+ ' export NCCL_ENV_ARG="--env NCCL_IB_DISABLE=1"',
120
+ "else",
121
+ " # Pick GID index based on link layer (IB vs RoCE)",
122
+ ' if ibv_devinfo 2>/dev/null | grep -q "link_layer:.*Ethernet"; then',
123
+ " # RoCEv2 typically needs a nonzero GID index; 3 is common, try 2 if your fabric uses it",
124
+ " export NCCL_IB_GID_INDEX={{NCCL_IB_GID_INDEX:-3}}",
125
+ ' export NCCL_ENV_ARG="--env NCCL_IB_GID_INDEX={{NCCL_IB_GID_INDEX:-3}}"',
126
+ " else",
127
+ " # Native InfiniBand => GID 0",
128
+ " export NCCL_IB_GID_INDEX={{NCCL_IB_GID_INDEX:-0}}",
129
+ ' export NCCL_ENV_ARG="--env NCCL_IB_GID_INDEX={{NCCL_IB_GID_INDEX:-0}}"',
130
+ " fi",
131
+ "fi",
115
132
  "\n# Start Ray head node",
116
133
  "head_node_port=$(find_available_port $head_node_ip 8080 65535)",
117
134
  "ray_head=$head_node_ip:$head_node_port",
vec_inf/client/_slurm_vars.py CHANGED
@@ -78,5 +78,9 @@ RESOURCE_TYPE: TypeAlias = create_literal_type( # type: ignore[valid-type]
78
78
  _config["allowed_values"]["resource_type"]
79
79
  )
80
80
 
81
+ # Extract required arguments, for launching jobs that don't have a default value and
82
+ # their corresponding environment variables
83
+ REQUIRED_ARGS: dict[str, str] = _config["required_args"]
84
+
81
85
  # Extract default arguments
82
86
  DEFAULT_ARGS: dict[str, str] = _config["default_args"]
vec_inf/client/_utils.py CHANGED
@@ -14,9 +14,9 @@ from typing import Any, Optional, Union, cast
14
14
  import requests
15
15
  import yaml
16
16
 
17
- from vec_inf.client._client_vars import MODEL_READY_SIGNATURE, REQUIRED_ARGS
17
+ from vec_inf.client._client_vars import MODEL_READY_SIGNATURE
18
18
  from vec_inf.client._exceptions import MissingRequiredFieldsError
19
- from vec_inf.client._slurm_vars import CACHED_CONFIG_DIR
19
+ from vec_inf.client._slurm_vars import CACHED_CONFIG_DIR, REQUIRED_ARGS
20
20
  from vec_inf.client.config import ModelConfig
21
21
  from vec_inf.client.models import ModelStatus
22
22
 
@@ -108,15 +108,64 @@ def is_server_running(
108
108
  if isinstance(log_content, str):
109
109
  return log_content
110
110
 
111
- status: Union[str, tuple[ModelStatus, str]] = ModelStatus.LAUNCHING
111
+ # Patterns that indicate fatal errors (not just warnings)
112
+ fatal_error_patterns = [
113
+ "traceback",
114
+ "exception",
115
+ "fatal error",
116
+ "critical error",
117
+ "failed to",
118
+ "could not",
119
+ "unable to",
120
+ "error:",
121
+ ]
122
+
123
+ # Patterns to ignore (non-fatal warnings/info messages)
124
+ ignore_patterns = [
125
+ "deprecated",
126
+ "futurewarning",
127
+ "userwarning",
128
+ "deprecationwarning",
129
+ "slurmstepd: error:", # SLURM cancellation messages (often after server started)
130
+ ]
131
+
132
+ ready_signature_found = False
133
+ fatal_error_line = None
112
134
 
113
135
  for line in log_content:
114
- if "error" in line.lower():
115
- status = (ModelStatus.FAILED, line.strip("\n"))
136
+ line_lower = line.lower()
137
+
138
+ # Check for ready signature first - if found, server is running
116
139
  if MODEL_READY_SIGNATURE in line:
117
- status = "RUNNING"
140
+ ready_signature_found = True
141
+ # Continue checking to see if there are errors after startup
142
+
143
+ # Check for fatal errors (only if we haven't seen ready signature yet)
144
+ if not ready_signature_found:
145
+ # Skip lines that match ignore patterns
146
+ if any(ignore_pattern in line_lower for ignore_pattern in ignore_patterns):
147
+ continue
118
148
 
119
- return status
149
+ # Check for fatal error patterns
150
+ for pattern in fatal_error_patterns:
151
+ if pattern in line_lower:
152
+ # Additional check: skip if it's part of a warning message
153
+ # (warnings often contain "error:" but aren't fatal)
154
+ if "warning" in line_lower and "error:" in line_lower:
155
+ continue
156
+ fatal_error_line = line.strip("\n")
157
+ break
158
+
159
+ # If we found a fatal error, mark as failed
160
+ if fatal_error_line:
161
+ return (ModelStatus.FAILED, fatal_error_line)
162
+
163
+ # If ready signature was found and no fatal errors, server is running
164
+ if ready_signature_found:
165
+ return "RUNNING"
166
+
167
+ # Otherwise, still launching
168
+ return ModelStatus.LAUNCHING
120
169
 
121
170
 
122
171
  def get_base_url(slurm_job_name: str, slurm_job_id: str, log_dir: str) -> str:
vec_inf/client/api.py CHANGED
@@ -81,7 +81,7 @@ class VecInfClient:
81
81
 
82
82
  def __init__(self) -> None:
83
83
  """Initialize the Vector Inference client."""
84
- pass
84
+ self._metrics_collectors: dict[str, PerformanceMetricsCollector] = {}
85
85
 
86
86
  def list_models(self) -> list[ModelInfo]:
87
87
  """List all available models.
@@ -218,7 +218,13 @@ class VecInfClient:
218
218
  - Performance metrics or error message
219
219
  - Timestamp of collection
220
220
  """
221
- performance_metrics_collector = PerformanceMetricsCollector(slurm_job_id)
221
+ # Use cached collector to preserve state between calls to compute throughput
222
+ if slurm_job_id not in self._metrics_collectors:
223
+ self._metrics_collectors[slurm_job_id] = PerformanceMetricsCollector(
224
+ slurm_job_id
225
+ )
226
+
227
+ performance_metrics_collector = self._metrics_collectors[slurm_job_id]
222
228
 
223
229
  metrics: Union[dict[str, float], str]
224
230
  if not performance_metrics_collector.metrics_url.startswith("http"):
vec_inf/client/models.py CHANGED
@@ -194,6 +194,10 @@ class LaunchOptions:
194
194
  Number of nodes to allocate
195
195
  gpus_per_node : int, optional
196
196
  Number of GPUs per node
197
+ cpus_per_task : int, optional
198
+ Number of CPUs per task
199
+ mem_per_node : str, optional
200
+ Memory per node
197
201
  account : str, optional
198
202
  Account name for job scheduling
199
203
  work_dir : str, optional
@@ -232,6 +236,8 @@ class LaunchOptions:
232
236
  resource_type: Optional[str] = None
233
237
  num_nodes: Optional[int] = None
234
238
  gpus_per_node: Optional[int] = None
239
+ cpus_per_task: Optional[int] = None
240
+ mem_per_node: Optional[str] = None
235
241
  account: Optional[str] = None
236
242
  work_dir: Optional[str] = None
237
243
  qos: Optional[str] = None
@@ -15,6 +15,10 @@ allowed_values:
15
15
  partition: []
16
16
  resource_type: ["l40s", "h100"]
17
17
 
18
+ required_args:
19
+ account: "VEC_INF_ACCOUNT"
20
+ work_dir: "VEC_INF_WORK_DIR"
21
+
18
22
  default_args:
19
23
  cpus_per_task: "16"
20
24
  mem_per_node: "64G"