vec-inf 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
vec_inf/cli/_cli.py CHANGED
@@ -1,17 +1,14 @@
1
1
  """Command line interface for Vector Inference."""
2
2
 
3
- import os
4
3
  import time
5
- from typing import Any, Dict, Optional
4
+ from typing import Optional, Union
6
5
 
7
6
  import click
8
- import polars as pl
9
- from rich.columns import Columns
10
7
  from rich.console import Console
11
8
  from rich.live import Live
12
- from rich.panel import Panel
13
9
 
14
10
  import vec_inf.cli._utils as utils
11
+ from vec_inf.cli._helper import LaunchHelper, ListHelper, MetricsHelper, StatusHelper
15
12
 
16
13
 
17
14
  CONSOLE = Console()
@@ -37,11 +34,30 @@ def cli() -> None:
37
34
  type=int,
38
35
  help="Maximum number of sequences to process in a single request",
39
36
  )
37
+ @click.option(
38
+ "--gpu-memory-utilization",
39
+ type=float,
40
+ help="GPU memory utilization, default to 0.9",
41
+ )
42
+ @click.option(
43
+ "--enable-prefix-caching",
44
+ is_flag=True,
45
+ help="Enables automatic prefix caching",
46
+ )
47
+ @click.option(
48
+ "--enable-chunked-prefill",
49
+ is_flag=True,
50
+ help="Enable chunked prefill, enabled by default if max number of sequences > 32k",
51
+ )
52
+ @click.option(
53
+ "--max-num-batched-tokens",
54
+ type=int,
55
+ help="Maximum number of batched tokens per iteration, defaults to 2048 if --enable-chunked-prefill is set, else None",
56
+ )
40
57
  @click.option(
41
58
  "--partition",
42
59
  type=str,
43
- default="a40",
44
- help="Type of compute partition, default to a40",
60
+ help="Type of compute partition",
45
61
  )
46
62
  @click.option(
47
63
  "--num-nodes",
@@ -49,7 +65,7 @@ def cli() -> None:
49
65
  help="Number of nodes to use, default to suggested resource allocation for model",
50
66
  )
51
67
  @click.option(
52
- "--num-gpus",
68
+ "--gpus-per-node",
53
69
  type=int,
54
70
  help="Number of GPUs/node to use, default to suggested resource allocation for model",
55
71
  )
@@ -68,36 +84,36 @@ def cli() -> None:
68
84
  type=int,
69
85
  help="Vocabulary size, this option is intended for custom models",
70
86
  )
71
- @click.option(
72
- "--data-type", type=str, default="auto", help="Model data type, default to auto"
73
- )
87
+ @click.option("--data-type", type=str, help="Model data type")
74
88
  @click.option(
75
89
  "--venv",
76
90
  type=str,
77
- default="singularity",
78
- help="Path to virtual environment, default to preconfigured singularity container",
91
+ help="Path to virtual environment",
79
92
  )
80
93
  @click.option(
81
94
  "--log-dir",
82
95
  type=str,
83
- default="default",
84
- help="Path to slurm log directory, default to .vec-inf-logs in user home directory",
96
+ help="Path to slurm log directory",
85
97
  )
86
98
  @click.option(
87
99
  "--model-weights-parent-dir",
88
100
  type=str,
89
- default="/model-weights",
90
- help="Path to parent directory containing model weights, default to '/model-weights' for supported models",
101
+ help="Path to parent directory containing model weights",
91
102
  )
92
103
  @click.option(
93
104
  "--pipeline-parallelism",
94
- type=str,
95
- help="Enable pipeline parallelism, accepts 'True' or 'False', default to 'True' for supported models",
105
+ is_flag=True,
106
+ help="Enable pipeline parallelism, enabled by default for supported models",
107
+ )
108
+ @click.option(
109
+ "--compilation-config",
110
+ type=click.Choice(["0", "3"]),
111
+ help="torch.compile optimization level, accepts '0' or '3', default to '0', which means no optimization is applied",
96
112
  )
97
113
  @click.option(
98
114
  "--enforce-eager",
99
- type=str,
100
- help="Always use eager-mode PyTorch, accepts 'True' or 'False', default to 'False' for custom models if not set",
115
+ is_flag=True,
116
+ help="Always use eager-mode PyTorch",
101
117
  )
102
118
  @click.option(
103
119
  "--json-mode",
@@ -106,77 +122,23 @@ def cli() -> None:
106
122
  )
107
123
  def launch(
108
124
  model_name: str,
109
- model_family: Optional[str] = None,
110
- model_variant: Optional[str] = None,
111
- max_model_len: Optional[int] = None,
112
- max_num_seqs: Optional[int] = None,
113
- partition: Optional[str] = None,
114
- num_nodes: Optional[int] = None,
115
- num_gpus: Optional[int] = None,
116
- qos: Optional[str] = None,
117
- time: Optional[str] = None,
118
- vocab_size: Optional[int] = None,
119
- data_type: Optional[str] = None,
120
- venv: Optional[str] = None,
121
- log_dir: Optional[str] = None,
122
- model_weights_parent_dir: Optional[str] = None,
123
- pipeline_parallelism: Optional[str] = None,
124
- enforce_eager: Optional[str] = None,
125
- json_mode: bool = False,
125
+ **cli_kwargs: Optional[Union[str, int, bool]],
126
126
  ) -> None:
127
127
  """Launch a model on the cluster."""
128
- if isinstance(pipeline_parallelism, str):
129
- pipeline_parallelism = (
130
- "True" if pipeline_parallelism.lower() == "true" else "False"
131
- )
132
-
133
- launch_script_path = os.path.join(
134
- os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "launch_server.sh"
135
- )
136
- launch_cmd = f"bash {launch_script_path}"
137
-
138
- models_df = utils.load_models_df()
139
-
140
- models_df = models_df.with_columns(
141
- pl.col("model_type").replace("Reward Modeling", "Reward_Modeling")
142
- )
143
- models_df = models_df.with_columns(
144
- pl.col("model_type").replace("Text Embedding", "Text_Embedding")
145
- )
146
-
147
- if model_name in models_df["model_name"].to_list():
148
- default_args = utils.load_default_args(models_df, model_name)
149
- for arg in default_args:
150
- if arg in locals() and locals()[arg] is not None:
151
- default_args[arg] = locals()[arg]
152
- renamed_arg = arg.replace("_", "-")
153
- launch_cmd += f" --{renamed_arg} {default_args[arg]}"
154
- else:
155
- model_args = models_df.columns
156
- model_args.remove("model_name")
157
- for arg in model_args:
158
- if locals()[arg] is not None:
159
- renamed_arg = arg.replace("_", "-")
160
- launch_cmd += f" --{renamed_arg} {locals()[arg]}"
161
-
162
- output = utils.run_bash_command(launch_cmd)
163
-
164
- slurm_job_id = output.split(" ")[-1].strip().strip("\n")
165
- output_lines = output.split("\n")[:-2]
166
-
167
- table = utils.create_table(key_title="Job Config", value_title="Value")
168
- table.add_row("Slurm Job ID", slurm_job_id, style="blue")
169
- output_dict = {"slurm_job_id": slurm_job_id}
128
+ try:
129
+ launch_helper = LaunchHelper(model_name, cli_kwargs)
170
130
 
171
- for line in output_lines:
172
- key, value = line.split(": ")
173
- table.add_row(key, value)
174
- output_dict[key.lower().replace(" ", "_")] = value
131
+ launch_helper.set_env_vars()
132
+ launch_command = launch_helper.build_launch_command()
133
+ command_output, stderr = utils.run_bash_command(launch_command)
134
+ if stderr:
135
+ raise click.ClickException(f"Error: {stderr}")
136
+ launch_helper.post_launch_processing(command_output, CONSOLE)
175
137
 
176
- if json_mode:
177
- click.echo(output_dict)
178
- else:
179
- CONSOLE.print(table)
138
+ except click.ClickException as e:
139
+ raise e
140
+ except Exception as e:
141
+ raise click.ClickException(f"Launch failed: {str(e)}") from e
180
142
 
181
143
 
182
144
  @cli.command("status")
@@ -196,122 +158,17 @@ def status(
196
158
  ) -> None:
197
159
  """Get the status of a running model on the cluster."""
198
160
  status_cmd = f"scontrol show job {slurm_job_id} --oneliner"
199
- output = utils.run_bash_command(status_cmd)
200
-
201
- base_data = _get_base_status_data(output)
202
- status_info = _process_job_state(output, base_data, slurm_job_id, log_dir)
203
- _display_status(status_info, json_mode)
204
-
205
-
206
- def _get_base_status_data(output: str) -> Dict[str, Any]:
207
- """Extract basic job status information from scontrol output."""
208
- try:
209
- job_name = output.split(" ")[1].split("=")[1]
210
- job_state = output.split(" ")[9].split("=")[1]
211
- except IndexError:
212
- job_name = "UNAVAILABLE"
213
- job_state = "UNAVAILABLE"
214
-
215
- return {
216
- "model_name": job_name,
217
- "status": "SHUTDOWN",
218
- "base_url": "UNAVAILABLE",
219
- "state": job_state,
220
- "pending_reason": None,
221
- "failed_reason": None,
222
- }
223
-
224
-
225
- def _process_job_state(
226
- output: str, status_info: Dict[str, Any], slurm_job_id: int, log_dir: Optional[str]
227
- ) -> Dict[str, Any]:
228
- """Process different job states and update status information."""
229
- if status_info["state"] == "PENDING":
230
- _process_pending_state(output, status_info)
231
- elif status_info["state"] == "RUNNING":
232
- _handle_running_state(status_info, slurm_job_id, log_dir)
233
- return status_info
234
-
235
-
236
- def _process_pending_state(output: str, status_info: Dict[str, Any]) -> None:
237
- """Handle PENDING job state."""
238
- try:
239
- status_info["pending_reason"] = output.split(" ")[10].split("=")[1]
240
- status_info["status"] = "PENDING"
241
- except IndexError:
242
- status_info["pending_reason"] = "Unknown pending reason"
161
+ output, stderr = utils.run_bash_command(status_cmd)
162
+ if stderr:
163
+ raise click.ClickException(f"Error: {stderr}")
243
164
 
165
+ status_helper = StatusHelper(slurm_job_id, output, log_dir)
244
166
 
245
- def _handle_running_state(
246
- status_info: Dict[str, Any], slurm_job_id: int, log_dir: Optional[str]
247
- ) -> None:
248
- """Handle RUNNING job state and check server status."""
249
- server_status = utils.is_server_running(
250
- status_info["model_name"], slurm_job_id, log_dir
251
- )
252
-
253
- if isinstance(server_status, tuple):
254
- status_info["status"], status_info["failed_reason"] = server_status
255
- return
256
-
257
- if server_status == "RUNNING":
258
- _check_model_health(status_info, slurm_job_id, log_dir)
259
- else:
260
- status_info["status"] = server_status
261
-
262
-
263
- def _check_model_health(
264
- status_info: Dict[str, Any], slurm_job_id: int, log_dir: Optional[str]
265
- ) -> None:
266
- """Check model health and update status accordingly."""
267
- model_status = utils.model_health_check(
268
- status_info["model_name"], slurm_job_id, log_dir
269
- )
270
- status, failed_reason = model_status
271
- if status == "READY":
272
- status_info["base_url"] = utils.get_base_url(
273
- status_info["model_name"], slurm_job_id, log_dir
274
- )
275
- status_info["status"] = status
276
- else:
277
- status_info["status"], status_info["failed_reason"] = status, failed_reason
278
-
279
-
280
- def _display_status(status_info: Dict[str, Any], json_mode: bool) -> None:
281
- """Display the status information in appropriate format."""
167
+ status_helper.process_job_state()
282
168
  if json_mode:
283
- _output_json(status_info)
169
+ status_helper.output_json()
284
170
  else:
285
- _output_table(status_info)
286
-
287
-
288
- def _output_json(status_info: Dict[str, Any]) -> None:
289
- """Format and output JSON data."""
290
- json_data = {
291
- "model_name": status_info["model_name"],
292
- "model_status": status_info["status"],
293
- "base_url": status_info["base_url"],
294
- }
295
- if status_info["pending_reason"]:
296
- json_data["pending_reason"] = status_info["pending_reason"]
297
- if status_info["failed_reason"]:
298
- json_data["failed_reason"] = status_info["failed_reason"]
299
- click.echo(json_data)
300
-
301
-
302
- def _output_table(status_info: Dict[str, Any]) -> None:
303
- """Create and display rich table."""
304
- table = utils.create_table(key_title="Job Status", value_title="Value")
305
- table.add_row("Model Name", status_info["model_name"])
306
- table.add_row("Model Status", status_info["status"], style="blue")
307
-
308
- if status_info["pending_reason"]:
309
- table.add_row("Pending Reason", status_info["pending_reason"])
310
- if status_info["failed_reason"]:
311
- table.add_row("Failed Reason", status_info["failed_reason"])
312
-
313
- table.add_row("Base URL", status_info["base_url"])
314
- CONSOLE.print(table)
171
+ status_helper.output_table(CONSOLE)
315
172
 
316
173
 
317
174
  @cli.command("shutdown")
@@ -332,105 +189,40 @@ def shutdown(slurm_job_id: int) -> None:
332
189
  )
333
190
  def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> None:
334
191
  """List all available models, or get default setup of a specific model."""
335
-
336
- def list_model(model_name: str, models_df: pl.DataFrame, json_mode: bool) -> None:
337
- if model_name not in models_df["model_name"].to_list():
338
- raise ValueError(f"Model name {model_name} not found in available models")
339
-
340
- excluded_keys = {"venv", "log_dir"}
341
- model_row = models_df.filter(models_df["model_name"] == model_name)
342
-
343
- if json_mode:
344
- filtered_model_row = model_row.drop(excluded_keys, strict=False)
345
- click.echo(filtered_model_row.to_dicts()[0])
346
- return
347
- table = utils.create_table(key_title="Model Config", value_title="Value")
348
- for row in model_row.to_dicts():
349
- for key, value in row.items():
350
- if key not in excluded_keys:
351
- table.add_row(key, str(value))
352
- CONSOLE.print(table)
353
-
354
- def list_all(models_df: pl.DataFrame, json_mode: bool) -> None:
355
- if json_mode:
356
- click.echo(models_df["model_name"].to_list())
357
- return
358
- panels = []
359
- model_type_colors = {
360
- "LLM": "cyan",
361
- "VLM": "bright_blue",
362
- "Text Embedding": "purple",
363
- "Reward Modeling": "bright_magenta",
364
- }
365
-
366
- models_df = models_df.with_columns(
367
- pl.when(pl.col("model_type") == "LLM")
368
- .then(0)
369
- .when(pl.col("model_type") == "VLM")
370
- .then(1)
371
- .when(pl.col("model_type") == "Text Embedding")
372
- .then(2)
373
- .when(pl.col("model_type") == "Reward Modeling")
374
- .then(3)
375
- .otherwise(-1)
376
- .alias("model_type_order")
377
- )
378
-
379
- models_df = models_df.sort("model_type_order")
380
- models_df = models_df.drop("model_type_order")
381
-
382
- for row in models_df.to_dicts():
383
- panel_color = model_type_colors.get(row["model_type"], "white")
384
- if row["model_variant"] == "None":
385
- styled_text = f"[magenta]{row['model_family']}[/magenta]"
386
- else:
387
- styled_text = (
388
- f"[magenta]{row['model_family']}[/magenta]-{row['model_variant']}"
389
- )
390
- panels.append(Panel(styled_text, expand=True, border_style=panel_color))
391
- CONSOLE.print(Columns(panels, equal=True))
392
-
393
- models_df = utils.load_models_df()
394
-
395
- if model_name:
396
- list_model(model_name, models_df, json_mode)
397
- else:
398
- list_all(models_df, json_mode)
192
+ list_helper = ListHelper(model_name, json_mode)
193
+ list_helper.process_list_command(CONSOLE)
399
194
 
400
195
 
401
196
  @cli.command("metrics")
402
197
  @click.argument("slurm_job_id", type=int, nargs=1)
403
198
  @click.option(
404
- "--log-dir",
405
- type=str,
406
- help="Path to slurm log directory. This is required if --log-dir was set in model launch",
199
+ "--log-dir", type=str, help="Path to slurm log directory (if used during launch)"
407
200
  )
408
201
  def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:
409
- """Stream performance metrics to the console."""
410
- status_cmd = f"scontrol show job {slurm_job_id} --oneliner"
411
- output = utils.run_bash_command(status_cmd)
412
- slurm_job_name = output.split(" ")[1].split("=")[1]
202
+ """Stream real-time performance metrics from the model endpoint."""
203
+ helper = MetricsHelper(slurm_job_id, log_dir)
204
+
205
+ # Check if metrics URL is ready
206
+ if not helper.metrics_url.startswith("http"):
207
+ table = utils.create_table("Metric", "Value")
208
+ helper.display_failed_metrics(
209
+ table, f"Metrics endpoint unavailable - {helper.metrics_url}"
210
+ )
211
+ CONSOLE.print(table)
212
+ return
413
213
 
414
214
  with Live(refresh_per_second=1, console=CONSOLE) as live:
415
215
  while True:
416
- out_logs = utils.read_slurm_log(
417
- slurm_job_name, slurm_job_id, "out", log_dir
418
- )
419
- # if out_logs is a string, then it is an error message
420
- if isinstance(out_logs, str):
421
- live.update(out_logs)
422
- break
423
- latest_metrics = utils.get_latest_metric(out_logs)
424
- # if latest_metrics is a string, then it is an error message
425
- if isinstance(latest_metrics, str):
426
- live.update(latest_metrics)
427
- break
428
- table = utils.create_table(key_title="Metric", value_title="Value")
429
- for key, value in latest_metrics.items():
430
- table.add_row(key, value)
216
+ metrics = helper.fetch_metrics()
217
+ table = utils.create_table("Metric", "Value")
431
218
 
432
- live.update(table)
219
+ if isinstance(metrics, str):
220
+ # Show status information if metrics aren't available
221
+ helper.display_failed_metrics(table, metrics)
222
+ else:
223
+ helper.display_metrics(table, metrics)
433
224
 
225
+ live.update(table)
434
226
  time.sleep(2)
435
227
 
436
228
 
vec_inf/cli/_config.py ADDED
@@ -0,0 +1,87 @@
1
+ """Model configuration."""
2
+
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field
7
+ from typing_extensions import Literal
8
+
9
+
10
+ QOS = Literal[
11
+ "normal",
12
+ "m",
13
+ "m2",
14
+ "m3",
15
+ "m4",
16
+ "m5",
17
+ "long",
18
+ "deadline",
19
+ "high",
20
+ "scavenger",
21
+ "llm",
22
+ "a100",
23
+ ]
24
+
25
+ PARTITION = Literal["a40", "a100", "t4v1", "t4v2", "rtx6000"]
26
+
27
+ DATA_TYPE = Literal["auto", "float16", "bfloat16", "float32"]
28
+
29
+
30
+ class ModelConfig(BaseModel):
31
+ """Pydantic model for validating and managing model deployment configurations."""
32
+
33
+ model_name: str = Field(..., min_length=3, pattern=r"^[a-zA-Z0-9\-_\.]+$")
34
+ model_family: str = Field(..., min_length=2)
35
+ model_variant: Optional[str] = Field(
36
+ default=None, description="Specific variant/version of the model family"
37
+ )
38
+ model_type: Literal["LLM", "VLM", "Text_Embedding", "Reward_Modeling"] = Field(
39
+ ..., description="Type of model architecture"
40
+ )
41
+ gpus_per_node: int = Field(..., gt=0, le=8, description="GPUs per node")
42
+ num_nodes: int = Field(..., gt=0, le=16, description="Number of nodes")
43
+ vocab_size: int = Field(..., gt=0, le=1_000_000)
44
+ max_model_len: int = Field(
45
+ ..., gt=0, le=1_010_000, description="Maximum context length supported"
46
+ )
47
+ max_num_seqs: int = Field(
48
+ default=256, gt=0, le=1024, description="Maximum concurrent request sequences"
49
+ )
50
+ compilation_config: int = Field(
51
+ default=0,
52
+ gt=-1,
53
+ le=4,
54
+ description="torch.compile optimization level",
55
+ )
56
+ gpu_memory_utilization: float = Field(
57
+ default=0.9, gt=0.0, le=1.0, description="GPU memory utilization"
58
+ )
59
+ pipeline_parallelism: bool = Field(
60
+ default=True, description="Enable pipeline parallelism"
61
+ )
62
+ enforce_eager: bool = Field(default=False, description="Force eager mode execution")
63
+ qos: Union[QOS, str] = Field(default="m2", description="Quality of Service tier")
64
+ time: str = Field(
65
+ default="08:00:00",
66
+ pattern=r"^\d{2}:\d{2}:\d{2}$",
67
+ description="HH:MM:SS time limit",
68
+ )
69
+ partition: Union[PARTITION, str] = Field(
70
+ default="a40", description="GPU partition type"
71
+ )
72
+ data_type: Union[DATA_TYPE, str] = Field(
73
+ default="auto", description="Model precision format"
74
+ )
75
+ venv: str = Field(
76
+ default="singularity", description="Virtual environment/container system"
77
+ )
78
+ log_dir: Path = Field(
79
+ default=Path("~/.vec-inf-logs").expanduser(), description="Log directory path"
80
+ )
81
+ model_weights_parent_dir: Path = Field(
82
+ default=Path("/model-weights"), description="Base directory for model weights"
83
+ )
84
+
85
+ model_config = ConfigDict(
86
+ extra="forbid", str_strip_whitespace=True, validate_default=True, frozen=True
87
+ )