vec-inf 0.5.0-py3-none-any.whl → 0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vec_inf/cli/_helper.py CHANGED
@@ -1,609 +1,312 @@
-"""Command line interface for Vector Inference."""
+"""Helper classes for the CLI.
+
+This module provides formatting and display classes for the command-line interface,
+handling the presentation of model information, status updates, and metrics.
+"""
 
-import json
-import os
-import time
 from pathlib import Path
-from typing import Any, Optional, Union, cast
-from urllib.parse import urlparse, urlunparse
+from typing import Any, Union
 
 import click
-import requests
 from rich.columns import Columns
 from rich.console import Console
 from rich.panel import Panel
 from rich.table import Table
 
-import vec_inf.cli._utils as utils
-from vec_inf.cli._config import ModelConfig
-
-
-VLLM_TASK_MAP = {
-    "LLM": "generate",
-    "VLM": "generate",
-    "Text_Embedding": "embed",
-    "Reward_Modeling": "reward",
-}
-
-REQUIRED_FIELDS = {
-    "model_family",
-    "model_type",
-    "gpus_per_node",
-    "num_nodes",
-    "vocab_size",
-    "max_model_len",
-}
-
-BOOLEAN_FIELDS = {
-    "pipeline_parallelism",
-    "enforce_eager",
-    "enable_prefix_caching",
-    "enable_chunked_prefill",
-}
-
-LD_LIBRARY_PATH = "/scratch/ssd001/pkgs/cudnn-11.7-v8.5.0.96/lib/:/scratch/ssd001/pkgs/cuda-11.7/targets/x86_64-linux/lib/"
-SRC_DIR = str(Path(__file__).parent.parent)
-
-
-class LaunchHelper:
-    def __init__(
-        self, model_name: str, cli_kwargs: dict[str, Optional[Union[str, int, bool]]]
-    ):
-        self.model_name = model_name
-        self.cli_kwargs = cli_kwargs
-        self.model_config = self._get_model_configuration()
-        self.params = self._get_launch_params()
-
-    def _get_model_configuration(self) -> ModelConfig:
-        """Load and validate model configuration."""
-        model_configs = utils.load_config()
-        config = next(
-            (m for m in model_configs if m.model_name == self.model_name), None
-        )
+from vec_inf.cli._utils import create_table
+from vec_inf.cli._vars import MODEL_TYPE_COLORS, MODEL_TYPE_PRIORITY
+from vec_inf.client import ModelConfig, ModelInfo, StatusResponse
 
-        if config:
-            return config
-        # If model config not found, load path from CLI args or fallback to default
-        model_weights_parent_dir = self.cli_kwargs.get(
-            "model_weights_parent_dir", model_configs[0].model_weights_parent_dir
-        )
-        model_weights_path = Path(cast(str, model_weights_parent_dir), self.model_name)
-        # Only give a warning msg if weights exist but config missing
-        if model_weights_path.exists():
-            click.echo(
-                click.style(
-                    f"Warning: '{self.model_name}' configuration not found in config, please ensure model configuration are properly set in command arguments",
-                    fg="yellow",
-                )
-            )
-            # Return a dummy model config object with model name and weights parent dir
-            return ModelConfig(
-                model_name=self.model_name,
-                model_family="model_family_placeholder",
-                model_type="LLM",
-                gpus_per_node=1,
-                num_nodes=1,
-                vocab_size=1000,
-                max_model_len=8192,
-                model_weights_parent_dir=Path(cast(str, model_weights_parent_dir)),
-            )
-        raise click.ClickException(
-            f"'{self.model_name}' not found in configuration and model weights "
-            f"not found at expected path '{model_weights_path}'"
-        )
 
-    def _get_launch_params(self) -> dict[str, Any]:
-        """Merge config defaults with CLI overrides."""
-        params = self.model_config.model_dump()
-
-        # Process boolean fields
-        for bool_field in BOOLEAN_FIELDS:
-            if self.cli_kwargs[bool_field]:
-                params[bool_field] = True
-
-        # Merge other overrides
-        for key, value in self.cli_kwargs.items():
-            if value is not None and key not in [
-                "json_mode",
-                *BOOLEAN_FIELDS,
-            ]:
-                params[key] = value
-
-        # Validate required fields
-        if not REQUIRED_FIELDS.issubset(set(params.keys())):
-            raise click.ClickException(
-                f"Missing required fields: {REQUIRED_FIELDS - set(params.keys())}"
-            )
+class LaunchResponseFormatter:
+    """CLI Helper class for formatting LaunchResponse.
 
-        # Create log directory
-        params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser()
-        params["log_dir"].mkdir(parents=True, exist_ok=True)
-
-        # Convert to string for JSON serialization
-        for field in params:
-            params[field] = str(params[field])
-
-        return params
-
-    def set_env_vars(self) -> None:
-        """Set environment variables for the launch command."""
-        os.environ["MODEL_NAME"] = self.model_name
-        os.environ["MAX_MODEL_LEN"] = self.params["max_model_len"]
-        os.environ["MAX_LOGPROBS"] = self.params["vocab_size"]
-        os.environ["DATA_TYPE"] = self.params["data_type"]
-        os.environ["MAX_NUM_SEQS"] = self.params["max_num_seqs"]
-        os.environ["GPU_MEMORY_UTILIZATION"] = self.params["gpu_memory_utilization"]
-        os.environ["TASK"] = VLLM_TASK_MAP[self.params["model_type"]]
-        os.environ["PIPELINE_PARALLELISM"] = self.params["pipeline_parallelism"]
-        os.environ["COMPILATION_CONFIG"] = self.params["compilation_config"]
-        os.environ["SRC_DIR"] = SRC_DIR
-        os.environ["MODEL_WEIGHTS"] = str(
-            Path(self.params["model_weights_parent_dir"], self.model_name)
-        )
-        os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH
-        os.environ["VENV_BASE"] = self.params["venv"]
-        os.environ["LOG_DIR"] = self.params["log_dir"]
-
-        if self.params.get("enable_prefix_caching"):
-            os.environ["ENABLE_PREFIX_CACHING"] = self.params["enable_prefix_caching"]
-        if self.params.get("enable_chunked_prefill"):
-            os.environ["ENABLE_CHUNKED_PREFILL"] = self.params["enable_chunked_prefill"]
-        if self.params.get("max_num_batched_tokens"):
-            os.environ["MAX_NUM_BATCHED_TOKENS"] = self.params["max_num_batched_tokens"]
-        if self.params.get("enforce_eager"):
-            os.environ["ENFORCE_EAGER"] = self.params["enforce_eager"]
-
-    def build_launch_command(self) -> str:
-        """Construct the full launch command with parameters."""
-        # Base command
-        command_list = ["sbatch"]
-        # Append options
-        command_list.extend(["--job-name", f"{self.model_name}"])
-        command_list.extend(["--partition", f"{self.params['partition']}"])
-        command_list.extend(["--qos", f"{self.params['qos']}"])
-        command_list.extend(["--time", f"{self.params['time']}"])
-        command_list.extend(["--nodes", f"{self.params['num_nodes']}"])
-        command_list.extend(["--gpus-per-node", f"{self.params['gpus_per_node']}"])
-        command_list.extend(
-            [
-                "--output",
-                f"{self.params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out",
-            ]
-        )
-        command_list.extend(
-            [
-                "--error",
-                f"{self.params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.err",
-            ]
-        )
-        # Add slurm script
-        slurm_script = "vllm.slurm"
-        if int(self.params["num_nodes"]) > 1:
-            slurm_script = "multinode_vllm.slurm"
-        command_list.append(f"{SRC_DIR}/{slurm_script}")
-        return " ".join(command_list)
-
-    def format_table_output(self, job_id: str) -> Table:
-        """Format output as rich Table."""
-        table = utils.create_table(key_title="Job Config", value_title="Value")
-        # Add rows
-        table.add_row("Slurm Job ID", job_id, style="blue")
+    A formatter class that handles the presentation of model launch information
+    in both table and JSON formats.
+
+    Parameters
+    ----------
+    model_name : str
+        Name of the launched model
+    params : dict[str, Any]
+        Launch parameters and configuration
+    """
+
+    def __init__(self, model_name: str, params: dict[str, Any]):
+        self.model_name = model_name
+        self.params = params
+
+    def format_table_output(self) -> Table:
+        """Format output as rich Table.
+
+        Returns
+        -------
+        Table
+            Rich table containing formatted launch information including:
+            - Job configuration
+            - Model details
+            - Resource allocation
+            - vLLM configuration
+        """
+        table = create_table(key_title="Job Config", value_title="Value")
+
+        # Add key information with consistent styling
+        table.add_row("Slurm Job ID", self.params["slurm_job_id"], style="blue")
         table.add_row("Job Name", self.model_name)
+
+        # Add model details
         table.add_row("Model Type", self.params["model_type"])
+        table.add_row("Vocabulary Size", self.params["vocab_size"])
+
+        # Add resource allocation details
         table.add_row("Partition", self.params["partition"])
         table.add_row("QoS", self.params["qos"])
         table.add_row("Time Limit", self.params["time"])
         table.add_row("Num Nodes", self.params["num_nodes"])
         table.add_row("GPUs/Node", self.params["gpus_per_node"])
-        table.add_row("Data Type", self.params["data_type"])
-        table.add_row("Vocabulary Size", self.params["vocab_size"])
-        table.add_row("Max Model Length", self.params["max_model_len"])
-        table.add_row("Max Num Seqs", self.params["max_num_seqs"])
-        table.add_row("GPU Memory Utilization", self.params["gpu_memory_utilization"])
-        table.add_row("Compilation Config", self.params["compilation_config"])
-        table.add_row("Pipeline Parallelism", self.params["pipeline_parallelism"])
-        if self.params.get("enable_prefix_caching"):
-            table.add_row("Enable Prefix Caching", self.params["enable_prefix_caching"])
-        if self.params.get("enable_chunked_prefill"):
-            table.add_row(
-                "Enable Chunked Prefill", self.params["enable_chunked_prefill"]
-            )
-        if self.params.get("max_num_batched_tokens"):
-            table.add_row(
-                "Max Num Batched Tokens", self.params["max_num_batched_tokens"]
-            )
-        if self.params.get("enforce_eager"):
-            table.add_row("Enforce Eager", self.params["enforce_eager"])
-        table.add_row("Model Weights Directory", os.environ.get("MODEL_WEIGHTS"))
+        table.add_row("CPUs/Task", self.params["cpus_per_task"])
+        table.add_row("Memory/Node", self.params["mem_per_node"])
+
+        # Add job config details
+        table.add_row(
+            "Model Weights Directory",
+            str(Path(self.params["model_weights_parent_dir"], self.model_name)),
+        )
         table.add_row("Log Directory", self.params["log_dir"])
 
-        return table
+        # Add vLLM configuration details
+        table.add_row("vLLM Arguments:", style="magenta")
+        for arg, value in self.params["vllm_args"].items():
+            table.add_row(f" {arg}:", str(value))
 
-    def post_launch_processing(self, output: str, console: Console) -> None:
-        """Process and display launch output."""
-        json_mode = bool(self.cli_kwargs.get("json_mode", False))
-        slurm_job_id = output.split(" ")[-1].strip().strip("\n")
-        self.params["slurm_job_id"] = slurm_job_id
-        job_json = Path(
-            self.params["log_dir"],
-            f"{self.model_name}.{slurm_job_id}",
-            f"{self.model_name}.{slurm_job_id}.json",
-        )
-        job_json.parent.mkdir(parents=True, exist_ok=True)
-        job_json.touch(exist_ok=True)
+        return table
 
-        with job_json.open("w") as file:
-            json.dump(self.params, file, indent=4)
-        if json_mode:
-            click.echo(self.params)
-        else:
-            table = self.format_table_output(slurm_job_id)
-            console.print(table)
-
-
-class StatusHelper:
-    def __init__(self, slurm_job_id: int, output: str, log_dir: Optional[str] = None):
-        self.slurm_job_id = slurm_job_id
-        self.output = output
-        self.log_dir = log_dir
-        self.status_info = self._get_base_status_data()
-
-    def _get_base_status_data(self) -> dict[str, Union[str, None]]:
-        """Extract basic job status information from scontrol output."""
-        try:
-            job_name = self.output.split(" ")[1].split("=")[1]
-            job_state = self.output.split(" ")[9].split("=")[1]
-        except IndexError:
-            job_name = "UNAVAILABLE"
-            job_state = "UNAVAILABLE"
-
-        return {
-            "model_name": job_name,
-            "status": "UNAVAILABLE",
-            "base_url": "UNAVAILABLE",
-            "state": job_state,
-            "pending_reason": None,
-            "failed_reason": None,
-        }
 
-    def process_job_state(self) -> None:
-        """Process different job states and update status information."""
-        if self.status_info["state"] == "PENDING":
-            self.process_pending_state()
-        elif self.status_info["state"] == "RUNNING":
-            self.process_running_state()
-
-    def check_model_health(self) -> None:
-        """Check model health and update status accordingly."""
-        status, status_code = utils.model_health_check(
-            cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir
-        )
-        if status == "READY":
-            self.status_info["base_url"] = utils.get_base_url(
-                cast(str, self.status_info["model_name"]),
-                self.slurm_job_id,
-                self.log_dir,
-            )
-            self.status_info["status"] = status
-        else:
-            self.status_info["status"], self.status_info["failed_reason"] = (
-                status,
-                cast(str, status_code),
-            )
+class StatusResponseFormatter:
+    """CLI Helper class for formatting StatusResponse.
 
-    def process_running_state(self) -> None:
-        """Process RUNNING job state and check server status."""
-        server_status = utils.is_server_running(
-            cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir
-        )
+    A formatter class that handles the presentation of model status information
+    in both table and JSON formats.
 
-        if isinstance(server_status, tuple):
-            self.status_info["status"], self.status_info["failed_reason"] = (
-                server_status
-            )
-            return
+    Parameters
+    ----------
+    status_info : StatusResponse
+        Status information to format
+    """
 
-        if server_status == "RUNNING":
-            self.check_model_health()
-        else:
-            self.status_info["status"] = server_status
-
-    def process_pending_state(self) -> None:
-        """Process PENDING job state."""
-        try:
-            self.status_info["pending_reason"] = self.output.split(" ")[10].split("=")[
-                1
-            ]
-            self.status_info["status"] = "PENDING"
-        except IndexError:
-            self.status_info["pending_reason"] = "Unknown pending reason"
+    def __init__(self, status_info: StatusResponse):
+        self.status_info = status_info
 
     def output_json(self) -> None:
-        """Format and output JSON data."""
+        """Format and output JSON data.
+
+        Outputs a JSON object containing:
+        - model_name
+        - model_status
+        - base_url
+        - pending_reason (if applicable)
+        - failed_reason (if applicable)
+        """
         json_data = {
-            "model_name": self.status_info["model_name"],
-            "model_status": self.status_info["status"],
-            "base_url": self.status_info["base_url"],
+            "model_name": self.status_info.model_name,
+            "model_status": self.status_info.server_status,
+            "base_url": self.status_info.base_url,
         }
-        if self.status_info["pending_reason"]:
-            json_data["pending_reason"] = self.status_info["pending_reason"]
-        if self.status_info["failed_reason"]:
-            json_data["failed_reason"] = self.status_info["failed_reason"]
+        if self.status_info.pending_reason:
+            json_data["pending_reason"] = self.status_info.pending_reason
+        if self.status_info.failed_reason:
+            json_data["failed_reason"] = self.status_info.failed_reason
         click.echo(json_data)
 
-    def output_table(self, console: Console) -> None:
-        """Create and display rich table."""
-        table = utils.create_table(key_title="Job Status", value_title="Value")
-        table.add_row("Model Name", self.status_info["model_name"])
-        table.add_row("Model Status", self.status_info["status"], style="blue")
+    def output_table(self) -> Table:
+        """Create and display rich table.
+
+        Returns
+        -------
+        Table
+            Rich table containing formatted status information including:
+            - Model name
+            - Status
+            - Base URL
+            - Error information (if applicable)
+        """
+        table = create_table(key_title="Job Status", value_title="Value")
+        table.add_row("Model Name", self.status_info.model_name)
+        table.add_row("Model Status", self.status_info.server_status, style="blue")
+
+        if self.status_info.pending_reason:
+            table.add_row("Pending Reason", self.status_info.pending_reason)
+        if self.status_info.failed_reason:
+            table.add_row("Failed Reason", self.status_info.failed_reason)
+
+        table.add_row("Base URL", self.status_info.base_url)
+        return table
+
 
-        if self.status_info["pending_reason"]:
-            table.add_row("Pending Reason", self.status_info["pending_reason"])
-        if self.status_info["failed_reason"]:
-            table.add_row("Failed Reason", self.status_info["failed_reason"])
+class MetricsResponseFormatter:
+    """CLI Helper class for formatting MetricsResponse.
 
-        table.add_row("Base URL", self.status_info["base_url"])
-        console.print(table)
+    A formatter class that handles the presentation of model metrics
+    in a table format.
 
+    Parameters
+    ----------
+    metrics : Union[dict[str, float], str]
+        Dictionary of metrics or error message
+    """
 
-class MetricsHelper:
-    def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None):
-        self.slurm_job_id = slurm_job_id
-        self.log_dir = log_dir
-        self.status_info = self._get_status_info()
-        self.metrics_url = self._build_metrics_url()
+    def __init__(self, metrics: Union[dict[str, float], str]):
+        self.metrics = self._set_metrics(metrics)
+        self.table = create_table("Metric", "Value")
         self.enabled_prefix_caching = self._check_prefix_caching()
 
-        self._prev_prompt_tokens: float = 0.0
-        self._prev_generation_tokens: float = 0.0
-        self._last_updated: Optional[float] = None
-        self._last_throughputs = {"prompt": 0.0, "generation": 0.0}
-
-    def _get_status_info(self) -> dict[str, Union[str, None]]:
-        """Retrieve status info using existing StatusHelper."""
-        status_cmd = f"scontrol show job {self.slurm_job_id} --oneliner"
-        output, stderr = utils.run_bash_command(status_cmd)
-        if stderr:
-            raise click.ClickException(f"Error: {stderr}")
-        status_helper = StatusHelper(self.slurm_job_id, output, self.log_dir)
-        return status_helper.status_info
-
-    def _build_metrics_url(self) -> str:
-        """Construct metrics endpoint URL from base URL with version stripping."""
-        if self.status_info.get("state") == "PENDING":
-            return "Pending resources for server initialization"
-
-        base_url = utils.get_base_url(
-            cast(str, self.status_info["model_name"]),
-            self.slurm_job_id,
-            self.log_dir,
-        )
-        if not base_url.startswith("http"):
-            return "Server not ready"
+    def _set_metrics(self, metrics: Union[dict[str, float], str]) -> dict[str, float]:
+        """Set the metrics attribute.
 
-        parsed = urlparse(base_url)
-        clean_path = parsed.path.replace("/v1", "", 1).rstrip("/")
-        return urlunparse(
-            (parsed.scheme, parsed.netloc, f"{clean_path}/metrics", "", "", "")
-        )
+        Parameters
+        ----------
+        metrics : Union[dict[str, float], str]
+            Raw metrics data
 
-    def _check_prefix_caching(self) -> bool:
-        """Check if prefix caching is enabled."""
-        job_json = utils.read_slurm_log(
-            cast(str, self.status_info["model_name"]),
-            self.slurm_job_id,
-            "json",
-            self.log_dir,
-        )
-        if isinstance(job_json, str):
-            return False
-        return bool(cast(dict[str, str], job_json).get("enable_prefix_caching", False))
-
-    def fetch_metrics(self) -> Union[dict[str, float], str]:
-        """Fetch metrics from the endpoint."""
-        try:
-            response = requests.get(self.metrics_url, timeout=3)
-            response.raise_for_status()
-            current_metrics = self._parse_metrics(response.text)
-            current_time = time.time()
-
-            # Set defaults using last known throughputs
-            current_metrics.setdefault(
-                "prompt_tokens_per_sec", self._last_throughputs["prompt"]
-            )
-            current_metrics.setdefault(
-                "generation_tokens_per_sec", self._last_throughputs["generation"]
-            )
-
-            if self._last_updated is None:
-                self._prev_prompt_tokens = current_metrics.get(
-                    "total_prompt_tokens", 0.0
-                )
-                self._prev_generation_tokens = current_metrics.get(
-                    "total_generation_tokens", 0.0
-                )
-                self._last_updated = current_time
-                return current_metrics
-
-            time_diff = current_time - self._last_updated
-            if time_diff > 0:
-                current_prompt = current_metrics.get("total_prompt_tokens", 0.0)
-                current_gen = current_metrics.get("total_generation_tokens", 0.0)
-
-                delta_prompt = current_prompt - self._prev_prompt_tokens
-                delta_gen = current_gen - self._prev_generation_tokens
-
-                # Only update throughputs when we have new tokens
-                prompt_tps = (
-                    delta_prompt / time_diff
-                    if delta_prompt > 0
-                    else self._last_throughputs["prompt"]
-                )
-                gen_tps = (
-                    delta_gen / time_diff
-                    if delta_gen > 0
-                    else self._last_throughputs["generation"]
-                )
-
-                current_metrics["prompt_tokens_per_sec"] = prompt_tps
-                current_metrics["generation_tokens_per_sec"] = gen_tps
-
-                # Persist calculated values regardless of activity
-                self._last_throughputs["prompt"] = prompt_tps
-                self._last_throughputs["generation"] = gen_tps
-
-                # Update tracking state
-                self._prev_prompt_tokens = current_prompt
-                self._prev_generation_tokens = current_gen
-                self._last_updated = current_time
-
-            # Calculate average latency if data is available
-            if (
-                "request_latency_sum" in current_metrics
-                and "request_latency_count" in current_metrics
-            ):
-                latency_sum = current_metrics["request_latency_sum"]
-                latency_count = current_metrics["request_latency_count"]
-                current_metrics["avg_request_latency"] = (
-                    latency_sum / latency_count if latency_count > 0 else 0.0
-                )
-
-            return current_metrics
-
-        except requests.RequestException as e:
-            return f"Metrics request failed, `metrics` endpoint might not be ready yet: {str(e)}"
-
-    def _parse_metrics(self, metrics_text: str) -> dict[str, float]:
-        """Parse metrics with latency count and sum."""
-        key_metrics = {
-            "vllm:prompt_tokens_total": "total_prompt_tokens",
-            "vllm:generation_tokens_total": "total_generation_tokens",
-            "vllm:e2e_request_latency_seconds_sum": "request_latency_sum",
-            "vllm:e2e_request_latency_seconds_count": "request_latency_count",
-            "vllm:request_queue_time_seconds_sum": "queue_time_sum",
-            "vllm:request_success_total": "successful_requests_total",
-            "vllm:num_requests_running": "requests_running",
-            "vllm:num_requests_waiting": "requests_waiting",
-            "vllm:num_requests_swapped": "requests_swapped",
-            "vllm:gpu_cache_usage_perc": "gpu_cache_usage",
-            "vllm:cpu_cache_usage_perc": "cpu_cache_usage",
-        }
+        Returns
+        -------
+        dict[str, float]
+            Processed metrics dictionary
+        """
+        return metrics if isinstance(metrics, dict) else {}
 
-        if self.enabled_prefix_caching:
-            key_metrics["vllm:gpu_prefix_cache_hit_rate"] = "gpu_prefix_cache_hit_rate"
-            key_metrics["vllm:cpu_prefix_cache_hit_rate"] = "cpu_prefix_cache_hit_rate"
-
-        parsed: dict[str, float] = {}
-        for line in metrics_text.split("\n"):
-            if line.startswith("#") or not line.strip():
-                continue
-
-            parts = line.split()
-            if len(parts) < 2:
-                continue
-
-            metric_name = parts[0].split("{")[0]
-            if metric_name in key_metrics:
-                try:
-                    parsed[key_metrics[metric_name]] = float(parts[1])
-                except (ValueError, IndexError):
-                    continue
-        return parsed
-
-    def display_failed_metrics(self, table: Table, metrics: str) -> None:
-        table.add_row("Server State", self.status_info["state"], style="yellow")
-        table.add_row("Message", metrics)
-
-    def display_metrics(self, table: Table, metrics: dict[str, float]) -> None:
+    def _check_prefix_caching(self) -> bool:
+        """Check if prefix caching is enabled.
+
+        Returns
+        -------
+        bool
+            True if prefix caching metrics are present
+        """
+        return self.metrics.get("gpu_prefix_cache_hit_rate") is not None
+
+    def format_failed_metrics(self, message: str) -> None:
+        """Format error message for failed metrics collection.
+
+        Parameters
+        ----------
+        message : str
+            Error message to display
+        """
+        self.table.add_row("ERROR", message)
+
+    def format_metrics(self) -> None:
+        """Format and display all available metrics.
+
+        Formats and adds to the table:
+        - Throughput metrics
+        - Request queue metrics
+        - Cache usage metrics
+        - Prefix cache metrics (if enabled)
+        - Latency metrics
+        - Token counts
+        """
         # Throughput metrics
-        table.add_row(
+        self.table.add_row(
            "Prompt Throughput",
-            f"{metrics.get('prompt_tokens_per_sec', 0):.1f} tokens/s",
+            f"{self.metrics.get('prompt_tokens_per_sec', 0):.1f} tokens/s",
        )
-        table.add_row(
+        self.table.add_row(
            "Generation Throughput",
-            f"{metrics.get('generation_tokens_per_sec', 0):.1f} tokens/s",
+            f"{self.metrics.get('generation_tokens_per_sec', 0):.1f} tokens/s",
        )
 
        # Request queue metrics
-        table.add_row(
+        self.table.add_row(
            "Requests Running",
-            f"{metrics.get('requests_running', 0):.0f} reqs",
+            f"{self.metrics.get('requests_running', 0):.0f} reqs",
        )
-        table.add_row(
+        self.table.add_row(
            "Requests Waiting",
-            f"{metrics.get('requests_waiting', 0):.0f} reqs",
+            f"{self.metrics.get('requests_waiting', 0):.0f} reqs",
        )
-        table.add_row(
+        self.table.add_row(
            "Requests Swapped",
-            f"{metrics.get('requests_swapped', 0):.0f} reqs",
+            f"{self.metrics.get('requests_swapped', 0):.0f} reqs",
        )
 
        # Cache usage metrics
-        table.add_row(
+        self.table.add_row(
            "GPU Cache Usage",
-            f"{metrics.get('gpu_cache_usage', 0) * 100:.1f}%",
+            f"{self.metrics.get('gpu_cache_usage', 0) * 100:.1f}%",
        )
-        table.add_row(
+        self.table.add_row(
            "CPU Cache Usage",
-            f"{metrics.get('cpu_cache_usage', 0) * 100:.1f}%",
+            f"{self.metrics.get('cpu_cache_usage', 0) * 100:.1f}%",
        )
 
        if self.enabled_prefix_caching:
-            table.add_row(
+            self.table.add_row(
                "GPU Prefix Cache Hit Rate",
-                f"{metrics.get('gpu_prefix_cache_hit_rate', 0) * 100:.1f}%",
+                f"{self.metrics.get('gpu_prefix_cache_hit_rate', 0) * 100:.1f}%",
            )
-            table.add_row(
+            self.table.add_row(
                "CPU Prefix Cache Hit Rate",
-                f"{metrics.get('cpu_prefix_cache_hit_rate', 0) * 100:.1f}%",
+                f"{self.metrics.get('cpu_prefix_cache_hit_rate', 0) * 100:.1f}%",
            )
 
        # Show average latency if available
-        if "avg_request_latency" in metrics:
-            table.add_row(
+        if "avg_request_latency" in self.metrics:
+            self.table.add_row(
                "Avg Request Latency",
-                f"{metrics['avg_request_latency']:.1f} s",
+                f"{self.metrics['avg_request_latency']:.1f} s",
            )
 
        # Token counts
-        table.add_row(
+        self.table.add_row(
            "Total Prompt Tokens",
-            f"{metrics.get('total_prompt_tokens', 0):.0f} tokens",
+            f"{self.metrics.get('total_prompt_tokens', 0):.0f} tokens",
        )
-        table.add_row(
+        self.table.add_row(
            "Total Generation Tokens",
-            f"{metrics.get('total_generation_tokens', 0):.0f} tokens",
+            f"{self.metrics.get('total_generation_tokens', 0):.0f} tokens",
        )
-        table.add_row(
+        self.table.add_row(
            "Successful Requests",
-            f"{metrics.get('successful_requests_total', 0):.0f} reqs",
+            f"{self.metrics.get('successful_requests_total', 0):.0f} reqs",
        )
 
 
-class ListHelper:
-    """Helper class for handling model listing functionality."""
+class ListCmdDisplay:
+    """CLI Helper class for displaying model listing functionality.
 
-    def __init__(self, model_name: Optional[str] = None, json_mode: bool = False):
-        self.model_name = model_name
-        self.json_mode = json_mode
-        self.model_configs = utils.load_config()
+    A display class that handles the presentation of model listings
+    in both table and JSON formats.
 
-    def get_single_model_config(self) -> ModelConfig:
-        """Get configuration for a specific model."""
-        config = next(
-            (c for c in self.model_configs if c.model_name == self.model_name), None
-        )
-        if not config:
-            raise click.ClickException(
-                f"Model '{self.model_name}' not found in configuration"
-            )
-        return config
+    Parameters
+    ----------
+    console : Console
+        Rich console instance for output
+    json_mode : bool, default=False
+        Whether to output in JSON format
+    """
+
+    def __init__(self, console: Console, json_mode: bool = False):
+        self.console = console
+        self.json_mode = json_mode
+        self.model_config = None
+        self.model_names: list[str] = []
 
-    def format_single_model_output(
+    def _format_single_model_output(
         self, config: ModelConfig
     ) -> Union[dict[str, Any], Table]:
-        """Format output for a single model."""
+        """Format output table for a single model.
+
+        Parameters
+        ----------
+        config : ModelConfig
+            Model configuration to format
+
+        Returns
+        -------
+        Union[dict[str, Any], Table]
+            Either a dictionary for JSON output or a Rich table
+        """
         if self.json_mode:
             # Exclude non-essential fields from JSON output
             excluded = {"venv", "log_dir"}
@@ -614,62 +317,84 @@ class ListHelper:
             )
             return config_dict
 
-        table = utils.create_table(key_title="Model Config", value_title="Value")
+        table = create_table(key_title="Model Config", value_title="Value")
         for field, value in config.model_dump().items():
-            if field not in {"venv", "log_dir"}:
+            if field not in {"venv", "log_dir", "vllm_args"}:
                 table.add_row(field, str(value))
+            if field == "vllm_args":
+                table.add_row("vLLM Arguments:", style="magenta")
+                for vllm_arg, vllm_value in value.items():
+                    table.add_row(f" {vllm_arg}:", str(vllm_value))
         return table
 
-    def format_all_models_output(self) -> Union[list[str], list[Panel]]:
-        """Format output for all models."""
-        if self.json_mode:
-            return [config.model_name for config in self.model_configs]
-
+    def _format_all_models_output(
+        self, model_infos: list[ModelInfo]
+    ) -> Union[list[str], list[Panel]]:
+        """Format output table for all models.
+
+        Parameters
+        ----------
+        model_infos : list[ModelInfo]
+            List of model information to format
+
+        Returns
+        -------
+        Union[list[str], list[Panel]]
+            Either a list of model names or a list of formatted panels
+
+        Notes
+        -----
+        Models are sorted by type priority and color-coded based on their type.
+        """
         # Sort by model type priority
-        type_priority = {"LLM": 0, "VLM": 1, "Text_Embedding": 2, "Reward_Modeling": 3}
-        sorted_configs = sorted(
-            self.model_configs, key=lambda x: type_priority.get(x.model_type, 4)
+        sorted_model_infos = sorted(
+            model_infos,
+            key=lambda x: MODEL_TYPE_PRIORITY.get(x.model_type, 4),
         )
 
         # Create panels with color coding
-        model_type_colors = {
-            "LLM": "cyan",
-            "VLM": "bright_blue",
-            "Text_Embedding": "purple",
-            "Reward_Modeling": "bright_magenta",
-        }
-
         panels = []
-        for config in sorted_configs:
-            color = model_type_colors.get(config.model_type, "white")
-            variant = config.model_variant or ""
-            display_text = f"[magenta]{config.model_family}[/magenta]"
+        for model_info in sorted_model_infos:
+            color = MODEL_TYPE_COLORS.get(model_info.model_type, "white")
+            variant = model_info.variant or ""
+            display_text = f"[magenta]{model_info.family}[/magenta]"
             if variant:
                 display_text += f"-{variant}"
             panels.append(Panel(display_text, expand=True, border_style=color))
 
         return panels
 
-    def process_list_command(self, console: Console) -> None:
-        """Process the list command and display output."""
-        try:
-            if self.model_name:
-                # Handle single model case
-                config = self.get_single_model_config()
-                output = self.format_single_model_output(config)
-                if self.json_mode:
-                    click.echo(output)
-                else:
-                    console.print(output)
-            # Handle all models case
-            elif self.json_mode:
-                # JSON output for all models is just a list of names
-                model_names = [config.model_name for config in self.model_configs]
-                click.echo(model_names)
-            else:
-                # Rich output for all models is a list of panels
-                panels = self.format_all_models_output()
-                if isinstance(panels, list):  # This helps mypy understand the type
-                    console.print(Columns(panels, equal=True))
-        except Exception as e:
-            raise click.ClickException(str(e)) from e
+    def display_single_model_output(self, config: ModelConfig) -> None:
+        """Display the output for a single model.
+
+        Parameters
+        ----------
+        config : ModelConfig
+            Model configuration to display
+        """
+        output = self._format_single_model_output(config)
+        if self.json_mode:
+            click.echo(output)
+        else:
+            self.console.print(output)
+
+    def display_all_models_output(self, model_infos: list[ModelInfo]) -> None:
+        """Display the output for all models.
+
+        Parameters
+        ----------
+        model_infos : list[ModelInfo]
+            List of model information to display
+
+        Notes
+        -----
+        Output format depends on json_mode:
+        - JSON: List of model names
+        - Table: Color-coded panels with model information
+        """
+        if self.json_mode:
+            model_names = [info.name for info in model_infos]
+            click.echo(model_names)
+        else:
+            panels = self._format_all_models_output(model_infos)
+            self.console.print(Columns(panels, equal=True))
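
Taken as a whole, the new file side of this diff moves data gathering (scontrol queries, HTTP requests to the vLLM metrics endpoint, throughput bookkeeping) out of the CLI layer and into the vec_inf.client package; the remaining classes only format values they are handed. Below is a minimal sketch of how the new MetricsResponseFormatter composes, using only the signatures visible in this diff; the metrics dict is a made-up placeholder, not output from a real server:

# Sketch: drive the formatter with an already-fetched metrics payload.
from rich.console import Console

from vec_inf.cli._helper import MetricsResponseFormatter

console = Console()

# Hypothetical metrics in the dict-of-floats shape format_metrics() reads.
metrics = {
    "prompt_tokens_per_sec": 512.0,
    "generation_tokens_per_sec": 128.0,
    "requests_running": 2.0,
    "requests_waiting": 0.0,
    "requests_swapped": 0.0,
    "gpu_cache_usage": 0.42,  # fractions; the formatter renders percentages
    "cpu_cache_usage": 0.05,
    "avg_request_latency": 1.8,
    "total_prompt_tokens": 10240.0,
    "total_generation_tokens": 2560.0,
    "successful_requests_total": 17.0,
}

formatter = MetricsResponseFormatter(metrics)
formatter.format_metrics()  # populates formatter.table
console.print(formatter.table)

# If the client hands back an error string instead of a dict, the same
# class degrades to a single ERROR row:
failed = MetricsResponseFormatter("metrics endpoint not ready")
failed.format_failed_metrics("metrics endpoint not ready")
console.print(failed.table)

StatusResponseFormatter and ListCmdDisplay follow the same pattern: they are constructed with already-fetched client objects (StatusResponse, ModelConfig, ModelInfo) and expose format/display methods, which keeps the CLI commands thin.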