vec-inf 0.4.0.post1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vec_inf/__init__.py CHANGED
@@ -0,0 +1 @@
1
+ """vec_inf package."""
vec_inf/cli/__init__.py CHANGED
@@ -0,0 +1 @@
1
+ """vec_inf cli package."""
vec_inf/cli/_cli.py CHANGED
@@ -1,23 +1,22 @@
1
- import os
1
+ """Command line interface for Vector Inference."""
2
+
2
3
  import time
3
- from typing import Optional, cast
4
+ from typing import Optional, Union
4
5
 
5
6
  import click
6
-
7
- import polars as pl
8
- from rich.columns import Columns
9
7
  from rich.console import Console
10
8
  from rich.live import Live
11
- from rich.panel import Panel
12
9
 
13
10
  import vec_inf.cli._utils as utils
11
+ from vec_inf.cli._helper import LaunchHelper, ListHelper, MetricsHelper, StatusHelper
12
+
14
13
 
15
14
  CONSOLE = Console()
16
15
 
17
16
 
18
17
  @click.group()
19
- def cli():
20
- """Vector Inference CLI"""
18
+ def cli() -> None:
19
+ """Vector Inference CLI."""
21
20
  pass
22
21
 
23
22
 
@@ -35,11 +34,30 @@ def cli():
35
34
  type=int,
36
35
  help="Maximum number of sequences to process in a single request",
37
36
  )
37
+ @click.option(
38
+ "--gpu-memory-utilization",
39
+ type=float,
40
+ help="GPU memory utilization, default to 0.9",
41
+ )
42
+ @click.option(
43
+ "--enable-prefix-caching",
44
+ is_flag=True,
45
+ help="Enables automatic prefix caching",
46
+ )
47
+ @click.option(
48
+ "--enable-chunked-prefill",
49
+ is_flag=True,
50
+ help="Enable chunked prefill, enabled by default if max number of sequences > 32k",
51
+ )
52
+ @click.option(
53
+ "--max-num-batched-tokens",
54
+ type=int,
55
+ help="Maximum number of batched tokens per iteration, defaults to 2048 if --enable-chunked-prefill is set, else None",
56
+ )
38
57
  @click.option(
39
58
  "--partition",
40
59
  type=str,
41
- default="a40",
42
- help="Type of compute partition, default to a40",
60
+ help="Type of compute partition",
43
61
  )
44
62
  @click.option(
45
63
  "--num-nodes",
@@ -47,7 +65,7 @@ def cli():
47
65
  help="Number of nodes to use, default to suggested resource allocation for model",
48
66
  )
49
67
  @click.option(
50
- "--num-gpus",
68
+ "--gpus-per-node",
51
69
  type=int,
52
70
  help="Number of GPUs/node to use, default to suggested resource allocation for model",
53
71
  )
@@ -66,36 +84,36 @@ def cli():
66
84
  type=int,
67
85
  help="Vocabulary size, this option is intended for custom models",
68
86
  )
69
- @click.option(
70
- "--data-type", type=str, default="auto", help="Model data type, default to auto"
71
- )
87
+ @click.option("--data-type", type=str, help="Model data type")
72
88
  @click.option(
73
89
  "--venv",
74
90
  type=str,
75
- default="singularity",
76
- help="Path to virtual environment, default to preconfigured singularity container",
91
+ help="Path to virtual environment",
77
92
  )
78
93
  @click.option(
79
94
  "--log-dir",
80
95
  type=str,
81
- default="default",
82
- help="Path to slurm log directory, default to .vec-inf-logs in user home directory",
96
+ help="Path to slurm log directory",
83
97
  )
84
98
  @click.option(
85
99
  "--model-weights-parent-dir",
86
100
  type=str,
87
- default="/model-weights",
88
- help="Path to parent directory containing model weights, default to '/model-weights' for supported models",
101
+ help="Path to parent directory containing model weights",
89
102
  )
90
103
  @click.option(
91
104
  "--pipeline-parallelism",
92
- type=str,
93
- help="Enable pipeline parallelism, accepts 'True' or 'False', default to 'True' for supported models",
105
+ is_flag=True,
106
+ help="Enable pipeline parallelism, enabled by default for supported models",
107
+ )
108
+ @click.option(
109
+ "--compilation-config",
110
+ type=click.Choice(["0", "3"]),
111
+ help="torch.compile optimization level, accepts '0' or '3', default to '0', which means no optimization is applied",
94
112
  )
95
113
  @click.option(
96
114
  "--enforce-eager",
97
- type=str,
98
- help="Always use eager-mode PyTorch, accepts 'True' or 'False', default to 'False' for custom models if not set",
115
+ is_flag=True,
116
+ help="Always use eager-mode PyTorch",
99
117
  )
100
118
  @click.option(
101
119
  "--json-mode",
@@ -104,74 +122,23 @@ def cli():
104
122
  )
105
123
  def launch(
106
124
  model_name: str,
107
- model_family: Optional[str] = None,
108
- model_variant: Optional[str] = None,
109
- max_model_len: Optional[int] = None,
110
- max_num_seqs: Optional[int] = None,
111
- partition: Optional[str] = None,
112
- num_nodes: Optional[int] = None,
113
- num_gpus: Optional[int] = None,
114
- qos: Optional[str] = None,
115
- time: Optional[str] = None,
116
- vocab_size: Optional[int] = None,
117
- data_type: Optional[str] = None,
118
- venv: Optional[str] = None,
119
- log_dir: Optional[str] = None,
120
- model_weights_parent_dir: Optional[str] = None,
121
- pipeline_parallelism: Optional[str] = None,
122
- enforce_eager: Optional[str] = None,
123
- json_mode: bool = False,
125
+ **cli_kwargs: Optional[Union[str, int, bool]],
124
126
  ) -> None:
125
- """
126
- Launch a model on the cluster
127
- """
128
-
129
- if isinstance(pipeline_parallelism, str):
130
- pipeline_parallelism = (
131
- "True" if pipeline_parallelism.lower() == "true" else "False"
132
- )
133
-
134
- launch_script_path = os.path.join(
135
- os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "launch_server.sh"
136
- )
137
- launch_cmd = f"bash {launch_script_path}"
138
-
139
- models_df = utils.load_models_df()
140
-
141
- if model_name in models_df["model_name"].to_list():
142
- default_args = utils.load_default_args(models_df, model_name)
143
- for arg in default_args:
144
- if arg in locals() and locals()[arg] is not None:
145
- default_args[arg] = locals()[arg]
146
- renamed_arg = arg.replace("_", "-")
147
- launch_cmd += f" --{renamed_arg} {default_args[arg]}"
148
- else:
149
- model_args = models_df.columns
150
- model_args.remove("model_name")
151
- model_args.remove("model_type")
152
- for arg in model_args:
153
- if locals()[arg] is not None:
154
- renamed_arg = arg.replace("_", "-")
155
- launch_cmd += f" --{renamed_arg} {locals()[arg]}"
156
-
157
- output = utils.run_bash_command(launch_cmd)
158
-
159
- slurm_job_id = output.split(" ")[-1].strip().strip("\n")
160
- output_lines = output.split("\n")[:-2]
161
-
162
- table = utils.create_table(key_title="Job Config", value_title="Value")
163
- table.add_row("Slurm Job ID", slurm_job_id, style="blue")
164
- output_dict = {"slurm_job_id": slurm_job_id}
127
+ """Launch a model on the cluster."""
128
+ try:
129
+ launch_helper = LaunchHelper(model_name, cli_kwargs)
165
130
 
166
- for line in output_lines:
167
- key, value = line.split(": ")
168
- table.add_row(key, value)
169
- output_dict[key.lower().replace(" ", "_")] = value
131
+ launch_helper.set_env_vars()
132
+ launch_command = launch_helper.build_launch_command()
133
+ command_output, stderr = utils.run_bash_command(launch_command)
134
+ if stderr:
135
+ raise click.ClickException(f"Error: {stderr}")
136
+ launch_helper.post_launch_processing(command_output, CONSOLE)
170
137
 
171
- if json_mode:
172
- click.echo(output_dict)
173
- else:
174
- CONSOLE.print(table)
138
+ except click.ClickException as e:
139
+ raise e
140
+ except Exception as e:
141
+ raise click.ClickException(f"Launch failed: {str(e)}") from e
175
142
 
176
143
 
177
144
  @cli.command("status")
@@ -189,79 +156,25 @@ def launch(
189
156
  def status(
190
157
  slurm_job_id: int, log_dir: Optional[str] = None, json_mode: bool = False
191
158
  ) -> None:
192
- """
193
- Get the status of a running model on the cluster
194
- """
159
+ """Get the status of a running model on the cluster."""
195
160
  status_cmd = f"scontrol show job {slurm_job_id} --oneliner"
196
- output = utils.run_bash_command(status_cmd)
161
+ output, stderr = utils.run_bash_command(status_cmd)
162
+ if stderr:
163
+ raise click.ClickException(f"Error: {stderr}")
197
164
 
198
- slurm_job_name = "UNAVAILABLE"
199
- status = "SHUTDOWN"
200
- base_url = "UNAVAILABLE"
201
-
202
- try:
203
- slurm_job_name = output.split(" ")[1].split("=")[1]
204
- slurm_job_state = output.split(" ")[9].split("=")[1]
205
- except IndexError:
206
- # Job ID not found
207
- slurm_job_state = "UNAVAILABLE"
208
-
209
- # If Slurm job is currently PENDING
210
- if slurm_job_state == "PENDING":
211
- slurm_job_pending_reason = output.split(" ")[10].split("=")[1]
212
- status = "PENDING"
213
- # If Slurm job is currently RUNNING
214
- elif slurm_job_state == "RUNNING":
215
- # Check whether the server is ready, if yes, run model health check to further determine status
216
- server_status = utils.is_server_running(slurm_job_name, slurm_job_id, log_dir)
217
- # If server status is a tuple, then server status is "FAILED"
218
- if isinstance(server_status, tuple):
219
- status = server_status[0]
220
- slurm_job_failed_reason = server_status[1]
221
- elif server_status == "RUNNING":
222
- model_status = utils.model_health_check(
223
- slurm_job_name, slurm_job_id, log_dir
224
- )
225
- if model_status == "READY":
226
- # Only set base_url if model is ready to serve requests
227
- base_url = utils.get_base_url(slurm_job_name, slurm_job_id, log_dir)
228
- status = "READY"
229
- else:
230
- # If model is not ready, then status must be "FAILED"
231
- status = model_status[0]
232
- slurm_job_failed_reason = str(model_status[1])
233
- else:
234
- status = server_status
165
+ status_helper = StatusHelper(slurm_job_id, output, log_dir)
235
166
 
167
+ status_helper.process_job_state()
236
168
  if json_mode:
237
- status_dict = {
238
- "model_name": slurm_job_name,
239
- "model_status": status,
240
- "base_url": base_url,
241
- }
242
- if "slurm_job_pending_reason" in locals():
243
- status_dict["pending_reason"] = slurm_job_pending_reason
244
- if "slurm_job_failed_reason" in locals():
245
- status_dict["failed_reason"] = slurm_job_failed_reason
246
- click.echo(f"{status_dict}")
169
+ status_helper.output_json()
247
170
  else:
248
- table = utils.create_table(key_title="Job Status", value_title="Value")
249
- table.add_row("Model Name", slurm_job_name)
250
- table.add_row("Model Status", status, style="blue")
251
- if "slurm_job_pending_reason" in locals():
252
- table.add_row("Reason", slurm_job_pending_reason)
253
- if "slurm_job_failed_reason" in locals():
254
- table.add_row("Reason", slurm_job_failed_reason)
255
- table.add_row("Base URL", base_url)
256
- CONSOLE.print(table)
171
+ status_helper.output_table(CONSOLE)
257
172
 
258
173
 
259
174
  @cli.command("shutdown")
260
175
  @click.argument("slurm_job_id", type=int, nargs=1)
261
176
  def shutdown(slurm_job_id: int) -> None:
262
- """
263
- Shutdown a running model on the cluster
264
- """
177
+ """Shutdown a running model on the cluster."""
265
178
  shutdown_cmd = f"scancel {slurm_job_id}"
266
179
  utils.run_bash_command(shutdown_cmd)
267
180
  click.echo(f"Shutting down model with Slurm Job ID: {slurm_job_id}")
@@ -275,109 +188,41 @@ def shutdown(slurm_job_id: int) -> None:
275
188
  help="Output in JSON string",
276
189
  )
277
190
  def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> None:
278
- """
279
- List all available models, or get default setup of a specific model
280
- """
281
-
282
- def list_model(model_name: str, models_df: pl.DataFrame, json_mode: bool):
283
- if model_name not in models_df["model_name"].to_list():
284
- raise ValueError(f"Model name {model_name} not found in available models")
285
-
286
- excluded_keys = {"venv", "log_dir"}
287
- model_row = models_df.filter(models_df["model_name"] == model_name)
288
-
289
- if json_mode:
290
- filtered_model_row = model_row.drop(excluded_keys, strict=False)
291
- click.echo(filtered_model_row.to_dicts()[0])
292
- return
293
- table = utils.create_table(key_title="Model Config", value_title="Value")
294
- for row in model_row.to_dicts():
295
- for key, value in row.items():
296
- if key not in excluded_keys:
297
- table.add_row(key, str(value))
298
- CONSOLE.print(table)
299
-
300
- def list_all(models_df: pl.DataFrame, json_mode: bool):
301
- if json_mode:
302
- click.echo(models_df["model_name"].to_list())
303
- return
304
- panels = []
305
- model_type_colors = {
306
- "LLM": "cyan",
307
- "VLM": "bright_blue",
308
- "Text Embedding": "purple",
309
- "Reward Modeling": "bright_magenta",
310
- }
311
-
312
- models_df = models_df.with_columns(
313
- pl.when(pl.col("model_type") == "LLM")
314
- .then(0)
315
- .when(pl.col("model_type") == "VLM")
316
- .then(1)
317
- .when(pl.col("model_type") == "Text Embedding")
318
- .then(2)
319
- .when(pl.col("model_type") == "Reward Modeling")
320
- .then(3)
321
- .otherwise(-1)
322
- .alias("model_type_order")
323
- )
324
-
325
- models_df = models_df.sort("model_type_order")
326
- models_df = models_df.drop("model_type_order")
327
-
328
- for row in models_df.to_dicts():
329
- panel_color = model_type_colors.get(row["model_type"], "white")
330
- styled_text = (
331
- f"[magenta]{row['model_family']}[/magenta]-{row['model_variant']}"
332
- )
333
- panels.append(Panel(styled_text, expand=True, border_style=panel_color))
334
- CONSOLE.print(Columns(panels, equal=True))
335
-
336
- models_df = utils.load_models_df()
337
-
338
- if model_name:
339
- list_model(model_name, models_df, json_mode)
340
- else:
341
- list_all(models_df, json_mode)
191
+ """List all available models, or get default setup of a specific model."""
192
+ list_helper = ListHelper(model_name, json_mode)
193
+ list_helper.process_list_command(CONSOLE)
342
194
 
343
195
 
344
196
  @cli.command("metrics")
345
197
  @click.argument("slurm_job_id", type=int, nargs=1)
346
198
  @click.option(
347
- "--log-dir",
348
- type=str,
349
- help="Path to slurm log directory. This is required if --log-dir was set in model launch",
199
+ "--log-dir", type=str, help="Path to slurm log directory (if used during launch)"
350
200
  )
351
201
  def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:
352
- """
353
- Stream performance metrics to the console
354
- """
355
- status_cmd = f"scontrol show job {slurm_job_id} --oneliner"
356
- output = utils.run_bash_command(status_cmd)
357
- slurm_job_name = output.split(" ")[1].split("=")[1]
202
+ """Stream real-time performance metrics from the model endpoint."""
203
+ helper = MetricsHelper(slurm_job_id, log_dir)
204
+
205
+ # Check if metrics URL is ready
206
+ if not helper.metrics_url.startswith("http"):
207
+ table = utils.create_table("Metric", "Value")
208
+ helper.display_failed_metrics(
209
+ table, f"Metrics endpoint unavailable - {helper.metrics_url}"
210
+ )
211
+ CONSOLE.print(table)
212
+ return
358
213
 
359
214
  with Live(refresh_per_second=1, console=CONSOLE) as live:
360
215
  while True:
361
- out_logs = utils.read_slurm_log(
362
- slurm_job_name, slurm_job_id, "out", log_dir
363
- )
364
- # if out_logs is a string, then it is an error message
365
- if isinstance(out_logs, str):
366
- live.update(out_logs)
367
- break
368
- out_logs = cast(list, out_logs)
369
- latest_metrics = utils.get_latest_metric(out_logs)
370
- # if latest_metrics is a string, then it is an error message
371
- if isinstance(latest_metrics, str):
372
- live.update(latest_metrics)
373
- break
374
- latest_metrics = cast(dict, latest_metrics)
375
- table = utils.create_table(key_title="Metric", value_title="Value")
376
- for key, value in latest_metrics.items():
377
- table.add_row(key, value)
216
+ metrics = helper.fetch_metrics()
217
+ table = utils.create_table("Metric", "Value")
378
218
 
379
- live.update(table)
219
+ if isinstance(metrics, str):
220
+ # Show status information if metrics aren't available
221
+ helper.display_failed_metrics(table, metrics)
222
+ else:
223
+ helper.display_metrics(table, metrics)
380
224
 
225
+ live.update(table)
381
226
  time.sleep(2)
382
227
 
383
228
 
vec_inf/cli/_config.py ADDED
@@ -0,0 +1,87 @@
1
+ """Model configuration."""
2
+
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field
7
+ from typing_extensions import Literal
8
+
9
+
10
+ QOS = Literal[
11
+ "normal",
12
+ "m",
13
+ "m2",
14
+ "m3",
15
+ "m4",
16
+ "m5",
17
+ "long",
18
+ "deadline",
19
+ "high",
20
+ "scavenger",
21
+ "llm",
22
+ "a100",
23
+ ]
24
+
25
+ PARTITION = Literal["a40", "a100", "t4v1", "t4v2", "rtx6000"]
26
+
27
+ DATA_TYPE = Literal["auto", "float16", "bfloat16", "float32"]
28
+
29
+
30
+ class ModelConfig(BaseModel):
31
+ """Pydantic model for validating and managing model deployment configurations."""
32
+
33
+ model_name: str = Field(..., min_length=3, pattern=r"^[a-zA-Z0-9\-_\.]+$")
34
+ model_family: str = Field(..., min_length=2)
35
+ model_variant: Optional[str] = Field(
36
+ default=None, description="Specific variant/version of the model family"
37
+ )
38
+ model_type: Literal["LLM", "VLM", "Text_Embedding", "Reward_Modeling"] = Field(
39
+ ..., description="Type of model architecture"
40
+ )
41
+ gpus_per_node: int = Field(..., gt=0, le=8, description="GPUs per node")
42
+ num_nodes: int = Field(..., gt=0, le=16, description="Number of nodes")
43
+ vocab_size: int = Field(..., gt=0, le=1_000_000)
44
+ max_model_len: int = Field(
45
+ ..., gt=0, le=1_010_000, description="Maximum context length supported"
46
+ )
47
+ max_num_seqs: int = Field(
48
+ default=256, gt=0, le=1024, description="Maximum concurrent request sequences"
49
+ )
50
+ compilation_config: int = Field(
51
+ default=0,
52
+ gt=-1,
53
+ le=4,
54
+ description="torch.compile optimization level",
55
+ )
56
+ gpu_memory_utilization: float = Field(
57
+ default=0.9, gt=0.0, le=1.0, description="GPU memory utilization"
58
+ )
59
+ pipeline_parallelism: bool = Field(
60
+ default=True, description="Enable pipeline parallelism"
61
+ )
62
+ enforce_eager: bool = Field(default=False, description="Force eager mode execution")
63
+ qos: Union[QOS, str] = Field(default="m2", description="Quality of Service tier")
64
+ time: str = Field(
65
+ default="08:00:00",
66
+ pattern=r"^\d{2}:\d{2}:\d{2}$",
67
+ description="HH:MM:SS time limit",
68
+ )
69
+ partition: Union[PARTITION, str] = Field(
70
+ default="a40", description="GPU partition type"
71
+ )
72
+ data_type: Union[DATA_TYPE, str] = Field(
73
+ default="auto", description="Model precision format"
74
+ )
75
+ venv: str = Field(
76
+ default="singularity", description="Virtual environment/container system"
77
+ )
78
+ log_dir: Path = Field(
79
+ default=Path("~/.vec-inf-logs").expanduser(), description="Log directory path"
80
+ )
81
+ model_weights_parent_dir: Path = Field(
82
+ default=Path("/model-weights"), description="Base directory for model weights"
83
+ )
84
+
85
+ model_config = ConfigDict(
86
+ extra="forbid", str_strip_whitespace=True, validate_default=True, frozen=True
87
+ )