vec-inf 0.4.0.post1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vec_inf/__init__.py +1 -0
- vec_inf/cli/__init__.py +1 -0
- vec_inf/cli/_cli.py +88 -243
- vec_inf/cli/_config.py +87 -0
- vec_inf/cli/_helper.py +675 -0
- vec_inf/cli/_utils.py +88 -89
- vec_inf/{models → config}/README.md +54 -0
- vec_inf/config/models.yaml +1274 -0
- vec_inf/multinode_vllm.slurm +61 -29
- vec_inf/vllm.slurm +55 -22
- vec_inf-0.5.0.dist-info/METADATA +210 -0
- vec_inf-0.5.0.dist-info/RECORD +17 -0
- {vec_inf-0.4.0.post1.dist-info → vec_inf-0.5.0.dist-info}/WHEEL +1 -1
- vec_inf-0.5.0.dist-info/entry_points.txt +2 -0
- vec_inf/launch_server.sh +0 -126
- vec_inf/models/models.csv +0 -73
- vec_inf-0.4.0.post1.dist-info/METADATA +0 -120
- vec_inf-0.4.0.post1.dist-info/RECORD +0 -16
- vec_inf-0.4.0.post1.dist-info/entry_points.txt +0 -3
- {vec_inf-0.4.0.post1.dist-info → vec_inf-0.5.0.dist-info/licenses}/LICENSE +0 -0
vec_inf/__init__.py
CHANGED

```diff
@@ -0,0 +1 @@
+"""vec_inf package."""
```
vec_inf/cli/__init__.py
CHANGED

```diff
@@ -0,0 +1 @@
+"""vec_inf cli package."""
```
vec_inf/cli/_cli.py
CHANGED

```diff
@@ -1,23 +1,22 @@
-import os
+"""Command line interface for Vector Inference."""
+
 import time
-from typing import Optional, cast
+from typing import Optional, Union
 
 import click
-
-import polars as pl
-from rich.columns import Columns
 from rich.console import Console
 from rich.live import Live
-from rich.panel import Panel
 
 import vec_inf.cli._utils as utils
+from vec_inf.cli._helper import LaunchHelper, ListHelper, MetricsHelper, StatusHelper
+
 
 CONSOLE = Console()
 
 
 @click.group()
-def cli():
-    """Vector Inference CLI"""
+def cli() -> None:
+    """Vector Inference CLI."""
     pass
 
 
```
```diff
@@ -35,11 +34,30 @@ def cli():
     type=int,
     help="Maximum number of sequences to process in a single request",
 )
+@click.option(
+    "--gpu-memory-utilization",
+    type=float,
+    help="GPU memory utilization, default to 0.9",
+)
+@click.option(
+    "--enable-prefix-caching",
+    is_flag=True,
+    help="Enables automatic prefix caching",
+)
+@click.option(
+    "--enable-chunked-prefill",
+    is_flag=True,
+    help="Enable chunked prefill, enabled by default if max number of sequences > 32k",
+)
+@click.option(
+    "--max-num-batched-tokens",
+    type=int,
+    help="Maximum number of batched tokens per iteration, defaults to 2048 if --enable-chunked-prefill is set, else None",
+)
 @click.option(
     "--partition",
     type=str,
-
-    help="Type of compute partition, default to a40",
+    help="Type of compute partition",
 )
 @click.option(
     "--num-nodes",
```
```diff
@@ -47,7 +65,7 @@ def cli():
     help="Number of nodes to use, default to suggested resource allocation for model",
 )
 @click.option(
-    "--num-gpus",
+    "--gpus-per-node",
     type=int,
     help="Number of GPUs/node to use, default to suggested resource allocation for model",
 )
```
```diff
@@ -66,36 +84,36 @@ def cli():
     type=int,
     help="Vocabulary size, this option is intended for custom models",
 )
-@click.option(
-    "--data-type", type=str, default="auto", help="Model data type, default to auto"
-)
+@click.option("--data-type", type=str, help="Model data type")
 @click.option(
     "--venv",
     type=str,
-
-    help="Path to virtual environment, default to preconfigured singularity container",
+    help="Path to virtual environment",
 )
 @click.option(
     "--log-dir",
     type=str,
-
-    help="Path to slurm log directory, default to .vec-inf-logs in user home directory",
+    help="Path to slurm log directory",
 )
 @click.option(
     "--model-weights-parent-dir",
     type=str,
-
-    help="Path to parent directory containing model weights, default to '/model-weights' for supported models",
+    help="Path to parent directory containing model weights",
 )
 @click.option(
     "--pipeline-parallelism",
-
-    help="Enable pipeline parallelism,
+    is_flag=True,
+    help="Enable pipeline parallelism, enabled by default for supported models",
+)
+@click.option(
+    "--compilation-config",
+    type=click.Choice(["0", "3"]),
+    help="torch.compile optimization level, accepts '0' or '3', default to '0', which means no optimization is applied",
 )
 @click.option(
     "--enforce-eager",
-
-    help="Always use eager-mode PyTorch
+    is_flag=True,
+    help="Always use eager-mode PyTorch",
 )
 @click.option(
     "--json-mode",
```
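The new `--compilation-config` option takes string choices at the CLI boundary (`click.Choice` compares strings) even though the `ModelConfig.compilation_config` field added later in this release is an int, so the value has to be coerced somewhere in between. A minimal sketch of that boundary; the `coerce_compilation_config` helper name is illustrative, not part of this diff:

```python
from typing import Optional

import click


def coerce_compilation_config(value: Optional[str]) -> int:
    """Map the CLI's string choice onto the integer level the config expects."""
    return int(value) if value is not None else 0  # "0" means no optimization


@click.command()
@click.option("--compilation-config", type=click.Choice(["0", "3"]))
def launch(compilation_config: Optional[str]) -> None:
    level = coerce_compilation_config(compilation_config)
    click.echo(f"torch.compile level: {level}")
```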
```diff
@@ -104,74 +122,23 @@ def cli():
 )
 def launch(
     model_name: str,
-
-    model_variant: Optional[str] = None,
-    max_model_len: Optional[int] = None,
-    max_num_seqs: Optional[int] = None,
-    partition: Optional[str] = None,
-    num_nodes: Optional[int] = None,
-    num_gpus: Optional[int] = None,
-    qos: Optional[str] = None,
-    time: Optional[str] = None,
-    vocab_size: Optional[int] = None,
-    data_type: Optional[str] = None,
-    venv: Optional[str] = None,
-    log_dir: Optional[str] = None,
-    model_weights_parent_dir: Optional[str] = None,
-    pipeline_parallelism: Optional[str] = None,
-    enforce_eager: Optional[str] = None,
-    json_mode: bool = False,
+    **cli_kwargs: Optional[Union[str, int, bool]],
 ) -> None:
-    """
-
-
-
-    if isinstance(pipeline_parallelism, str):
-        pipeline_parallelism = (
-            "True" if pipeline_parallelism.lower() == "true" else "False"
-        )
-
-    launch_script_path = os.path.join(
-        os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "launch_server.sh"
-    )
-    launch_cmd = f"bash {launch_script_path}"
-
-    models_df = utils.load_models_df()
-
-    if model_name in models_df["model_name"].to_list():
-        default_args = utils.load_default_args(models_df, model_name)
-        for arg in default_args:
-            if arg in locals() and locals()[arg] is not None:
-                default_args[arg] = locals()[arg]
-            renamed_arg = arg.replace("_", "-")
-            launch_cmd += f" --{renamed_arg} {default_args[arg]}"
-    else:
-        model_args = models_df.columns
-        model_args.remove("model_name")
-        model_args.remove("model_type")
-        for arg in model_args:
-            if locals()[arg] is not None:
-                renamed_arg = arg.replace("_", "-")
-                launch_cmd += f" --{renamed_arg} {locals()[arg]}"
-
-    output = utils.run_bash_command(launch_cmd)
-
-    slurm_job_id = output.split(" ")[-1].strip().strip("\n")
-    output_lines = output.split("\n")[:-2]
-
-    table = utils.create_table(key_title="Job Config", value_title="Value")
-    table.add_row("Slurm Job ID", slurm_job_id, style="blue")
-    output_dict = {"slurm_job_id": slurm_job_id}
+    """Launch a model on the cluster."""
+    try:
+        launch_helper = LaunchHelper(model_name, cli_kwargs)
 
-
-
-
-
+        launch_helper.set_env_vars()
+        launch_command = launch_helper.build_launch_command()
+        command_output, stderr = utils.run_bash_command(launch_command)
+        if stderr:
+            raise click.ClickException(f"Error: {stderr}")
+        launch_helper.post_launch_processing(command_output, CONSOLE)
 
-
-
-
-
+    except click.ClickException as e:
+        raise e
+    except Exception as e:
+        raise click.ClickException(f"Launch failed: {str(e)}") from e
 
 
 @cli.command("status")
```
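The refactored `launch` collapses seventeen explicit keyword parameters into a single `**cli_kwargs` catch-all, so adding a `@click.option` no longer requires touching the signature, and the old `locals()`-scanning merge moves into `LaunchHelper`. A minimal sketch of the pattern, assuming a hypothetical `merge_cli_overrides` helper (the real merging lives in `vec_inf/cli/_helper.py`, which this diff view does not expand):

```python
from typing import Any, Dict, Optional, Union

import click


def merge_cli_overrides(
    defaults: Dict[str, Any],
    cli_kwargs: Dict[str, Optional[Union[str, int, bool]]],
) -> Dict[str, Any]:
    """Overlay user-supplied CLI values onto per-model defaults.

    Click passes None for options the user did not provide, so only
    non-None values should shadow the defaults.
    """
    merged = dict(defaults)
    merged.update({k: v for k, v in cli_kwargs.items() if v is not None})
    return merged


@click.command()
@click.argument("model_name")
@click.option("--gpu-memory-utilization", type=float)
@click.option("--enable-prefix-caching", is_flag=True, default=None)
def launch(model_name: str, **cli_kwargs: Optional[Union[str, int, bool]]) -> None:
    # Each option arrives keyed by its Python-ized name, e.g.
    # {"gpu_memory_utilization": 0.95, "enable_prefix_caching": True}
    defaults = {"gpu_memory_utilization": 0.9, "enable_prefix_caching": False}
    click.echo(merge_cli_overrides(defaults, cli_kwargs))
```

Keeping flag defaults at `None` rather than `False` lets the merge distinguish "not passed" from "explicitly disabled".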
```diff
@@ -189,79 +156,25 @@ def launch(
 def status(
     slurm_job_id: int, log_dir: Optional[str] = None, json_mode: bool = False
 ) -> None:
-    """
-    Get the status of a running model on the cluster
-    """
+    """Get the status of a running model on the cluster."""
     status_cmd = f"scontrol show job {slurm_job_id} --oneliner"
-    output = utils.run_bash_command(status_cmd)
+    output, stderr = utils.run_bash_command(status_cmd)
+    if stderr:
+        raise click.ClickException(f"Error: {stderr}")
 
-
-    status = "SHUTDOWN"
-    base_url = "UNAVAILABLE"
-
-    try:
-        slurm_job_name = output.split(" ")[1].split("=")[1]
-        slurm_job_state = output.split(" ")[9].split("=")[1]
-    except IndexError:
-        # Job ID not found
-        slurm_job_state = "UNAVAILABLE"
-
-    # If Slurm job is currently PENDING
-    if slurm_job_state == "PENDING":
-        slurm_job_pending_reason = output.split(" ")[10].split("=")[1]
-        status = "PENDING"
-    # If Slurm job is currently RUNNING
-    elif slurm_job_state == "RUNNING":
-        # Check whether the server is ready, if yes, run model health check to further determine status
-        server_status = utils.is_server_running(slurm_job_name, slurm_job_id, log_dir)
-        # If server status is a tuple, then server status is "FAILED"
-        if isinstance(server_status, tuple):
-            status = server_status[0]
-            slurm_job_failed_reason = server_status[1]
-        elif server_status == "RUNNING":
-            model_status = utils.model_health_check(
-                slurm_job_name, slurm_job_id, log_dir
-            )
-            if model_status == "READY":
-                # Only set base_url if model is ready to serve requests
-                base_url = utils.get_base_url(slurm_job_name, slurm_job_id, log_dir)
-                status = "READY"
-            else:
-                # If model is not ready, then status must be "FAILED"
-                status = model_status[0]
-                slurm_job_failed_reason = str(model_status[1])
-        else:
-            status = server_status
+    status_helper = StatusHelper(slurm_job_id, output, log_dir)
 
+    status_helper.process_job_state()
     if json_mode:
-        status_dict = {
-            "model_name": slurm_job_name,
-            "model_status": status,
-            "base_url": base_url,
-        }
-        if "slurm_job_pending_reason" in locals():
-            status_dict["pending_reason"] = slurm_job_pending_reason
-        if "slurm_job_failed_reason" in locals():
-            status_dict["failed_reason"] = slurm_job_failed_reason
-        click.echo(f"{status_dict}")
+        status_helper.output_json()
     else:
-
-        table.add_row("Model Name", slurm_job_name)
-        table.add_row("Model Status", status, style="blue")
-        if "slurm_job_pending_reason" in locals():
-            table.add_row("Reason", slurm_job_pending_reason)
-        if "slurm_job_failed_reason" in locals():
-            table.add_row("Reason", slurm_job_failed_reason)
-        table.add_row("Base URL", base_url)
-        CONSOLE.print(table)
+        status_helper.output_table(CONSOLE)
 
 
 @cli.command("shutdown")
 @click.argument("slurm_job_id", type=int, nargs=1)
 def shutdown(slurm_job_id: int) -> None:
-    """
-    Shutdown a running model on the cluster
-    """
+    """Shutdown a running model on the cluster."""
     shutdown_cmd = f"scancel {slurm_job_id}"
     utils.run_bash_command(shutdown_cmd)
     click.echo(f"Shutting down model with Slurm Job ID: {slurm_job_id}")
```
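`StatusHelper` now owns the parsing that the old inline code did with positional indexing like `output.split(" ")[9]`, which silently breaks if Slurm reorders fields. The `--oneliner` output is a space-separated list of `key=value` pairs, so a field-name-based parse is order-independent; a small sketch of that idea (not the actual helper implementation, which this diff does not expand):

```python
def parse_scontrol_oneliner(output: str) -> dict:
    """Turn `scontrol show job <id> --oneliner` output into a field map.

    Example input: "JobId=42 JobName=my-model JobState=RUNNING ..."
    """
    fields = {}
    for token in output.split():
        key, sep, value = token.partition("=")
        if sep:  # keep only well-formed key=value tokens
            fields[key] = value
    return fields


# Usage: fall back to UNAVAILABLE when the job ID is unknown
# state = parse_scontrol_oneliner(output).get("JobState", "UNAVAILABLE")
```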
```diff
@@ -275,109 +188,41 @@ def shutdown(slurm_job_id: int) -> None:
     help="Output in JSON string",
 )
 def list_models(model_name: Optional[str] = None, json_mode: bool = False) -> None:
-    """
-
-
-
-    def list_model(model_name: str, models_df: pl.DataFrame, json_mode: bool):
-        if model_name not in models_df["model_name"].to_list():
-            raise ValueError(f"Model name {model_name} not found in available models")
-
-        excluded_keys = {"venv", "log_dir"}
-        model_row = models_df.filter(models_df["model_name"] == model_name)
-
-        if json_mode:
-            filtered_model_row = model_row.drop(excluded_keys, strict=False)
-            click.echo(filtered_model_row.to_dicts()[0])
-            return
-        table = utils.create_table(key_title="Model Config", value_title="Value")
-        for row in model_row.to_dicts():
-            for key, value in row.items():
-                if key not in excluded_keys:
-                    table.add_row(key, str(value))
-        CONSOLE.print(table)
-
-    def list_all(models_df: pl.DataFrame, json_mode: bool):
-        if json_mode:
-            click.echo(models_df["model_name"].to_list())
-            return
-        panels = []
-        model_type_colors = {
-            "LLM": "cyan",
-            "VLM": "bright_blue",
-            "Text Embedding": "purple",
-            "Reward Modeling": "bright_magenta",
-        }
-
-        models_df = models_df.with_columns(
-            pl.when(pl.col("model_type") == "LLM")
-            .then(0)
-            .when(pl.col("model_type") == "VLM")
-            .then(1)
-            .when(pl.col("model_type") == "Text Embedding")
-            .then(2)
-            .when(pl.col("model_type") == "Reward Modeling")
-            .then(3)
-            .otherwise(-1)
-            .alias("model_type_order")
-        )
-
-        models_df = models_df.sort("model_type_order")
-        models_df = models_df.drop("model_type_order")
-
-        for row in models_df.to_dicts():
-            panel_color = model_type_colors.get(row["model_type"], "white")
-            styled_text = (
-                f"[magenta]{row['model_family']}[/magenta]-{row['model_variant']}"
-            )
-            panels.append(Panel(styled_text, expand=True, border_style=panel_color))
-        CONSOLE.print(Columns(panels, equal=True))
-
-    models_df = utils.load_models_df()
-
-    if model_name:
-        list_model(model_name, models_df, json_mode)
-    else:
-        list_all(models_df, json_mode)
+    """List all available models, or get default setup of a specific model."""
+    list_helper = ListHelper(model_name, json_mode)
+    list_helper.process_list_command(CONSOLE)
 
 
 @cli.command("metrics")
 @click.argument("slurm_job_id", type=int, nargs=1)
 @click.option(
-    "--log-dir",
-    type=str,
-    help="Path to slurm log directory. This is required if --log-dir was set in model launch",
+    "--log-dir", type=str, help="Path to slurm log directory (if used during launch)"
 )
 def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:
-    """
-
-
-
-
-
+    """Stream real-time performance metrics from the model endpoint."""
+    helper = MetricsHelper(slurm_job_id, log_dir)
+
+    # Check if metrics URL is ready
+    if not helper.metrics_url.startswith("http"):
+        table = utils.create_table("Metric", "Value")
+        helper.display_failed_metrics(
+            table, f"Metrics endpoint unavailable - {helper.metrics_url}"
+        )
+        CONSOLE.print(table)
+        return
 
     with Live(refresh_per_second=1, console=CONSOLE) as live:
         while True:
-
-
-            )
-            # if out_logs is a string, then it is an error message
-            if isinstance(out_logs, str):
-                live.update(out_logs)
-                break
-            out_logs = cast(list, out_logs)
-            latest_metrics = utils.get_latest_metric(out_logs)
-            # if latest_metrics is a string, then it is an error message
-            if isinstance(latest_metrics, str):
-                live.update(latest_metrics)
-                break
-            latest_metrics = cast(dict, latest_metrics)
-            table = utils.create_table(key_title="Metric", value_title="Value")
-            for key, value in latest_metrics.items():
-                table.add_row(key, value)
+            metrics = helper.fetch_metrics()
+            table = utils.create_table("Metric", "Value")
 
-
+            if isinstance(metrics, str):
+                # Show status information if metrics aren't available
+                helper.display_failed_metrics(table, metrics)
+            else:
+                helper.display_metrics(table, metrics)
 
+            live.update(table)
             time.sleep(2)
 
 
```
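`MetricsHelper.fetch_metrics` replaces the old approach of scraping Slurm log files (`utils.get_latest_metric` over `out_logs`) with polling the server's metrics endpoint directly. vLLM servers expose Prometheus-format text at `/metrics`; a sketch of the kind of scrape-and-render loop the command now performs, where the endpoint URL and the five-iteration cap are illustrative rather than taken from this diff:

```python
import time
import urllib.request

from rich.console import Console
from rich.live import Live
from rich.table import Table


def scrape_metrics(metrics_url: str) -> dict:
    """Fetch a Prometheus-format metrics page into name -> value pairs."""
    with urllib.request.urlopen(metrics_url, timeout=5) as resp:
        text = resp.read().decode()
    samples = {}
    for line in text.splitlines():
        if line.startswith("#") or " " not in line:
            continue  # skip HELP/TYPE comment lines
        name, value = line.rsplit(" ", 1)
        samples[name] = value
    return samples


console = Console()
with Live(refresh_per_second=1, console=console) as live:
    for _ in range(5):  # the real command loops until interrupted
        table = Table("Metric", "Value")
        for name, value in scrape_metrics("http://localhost:8000/metrics").items():
            table.add_row(name, value)
        live.update(table)
        time.sleep(2)
```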
vec_inf/cli/_config.py
ADDED

@@ -0,0 +1,87 @@

```python
"""Model configuration."""

from pathlib import Path
from typing import Optional, Union

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Literal


QOS = Literal[
    "normal",
    "m",
    "m2",
    "m3",
    "m4",
    "m5",
    "long",
    "deadline",
    "high",
    "scavenger",
    "llm",
    "a100",
]

PARTITION = Literal["a40", "a100", "t4v1", "t4v2", "rtx6000"]

DATA_TYPE = Literal["auto", "float16", "bfloat16", "float32"]


class ModelConfig(BaseModel):
    """Pydantic model for validating and managing model deployment configurations."""

    model_name: str = Field(..., min_length=3, pattern=r"^[a-zA-Z0-9\-_\.]+$")
    model_family: str = Field(..., min_length=2)
    model_variant: Optional[str] = Field(
        default=None, description="Specific variant/version of the model family"
    )
    model_type: Literal["LLM", "VLM", "Text_Embedding", "Reward_Modeling"] = Field(
        ..., description="Type of model architecture"
    )
    gpus_per_node: int = Field(..., gt=0, le=8, description="GPUs per node")
    num_nodes: int = Field(..., gt=0, le=16, description="Number of nodes")
    vocab_size: int = Field(..., gt=0, le=1_000_000)
    max_model_len: int = Field(
        ..., gt=0, le=1_010_000, description="Maximum context length supported"
    )
    max_num_seqs: int = Field(
        default=256, gt=0, le=1024, description="Maximum concurrent request sequences"
    )
    compilation_config: int = Field(
        default=0,
        gt=-1,
        le=4,
        description="torch.compile optimization level",
    )
    gpu_memory_utilization: float = Field(
        default=0.9, gt=0.0, le=1.0, description="GPU memory utilization"
    )
    pipeline_parallelism: bool = Field(
        default=True, description="Enable pipeline parallelism"
    )
    enforce_eager: bool = Field(default=False, description="Force eager mode execution")
    qos: Union[QOS, str] = Field(default="m2", description="Quality of Service tier")
    time: str = Field(
        default="08:00:00",
        pattern=r"^\d{2}:\d{2}:\d{2}$",
        description="HH:MM:SS time limit",
    )
    partition: Union[PARTITION, str] = Field(
        default="a40", description="GPU partition type"
    )
    data_type: Union[DATA_TYPE, str] = Field(
        default="auto", description="Model precision format"
    )
    venv: str = Field(
        default="singularity", description="Virtual environment/container system"
    )
    log_dir: Path = Field(
        default=Path("~/.vec-inf-logs").expanduser(), description="Log directory path"
    )
    model_weights_parent_dir: Path = Field(
        default=Path("/model-weights"), description="Base directory for model weights"
    )

    model_config = ConfigDict(
        extra="forbid", str_strip_whitespace=True, validate_default=True, frozen=True
    )
```