vec-inf 0.4.0.post1__py3-none-any.whl → 0.5.0__py3-none-any.whl

vec_inf/cli/_helper.py ADDED
@@ -0,0 +1,675 @@
+ """Command line interface for Vector Inference."""
+
+ import json
+ import os
+ import time
+ from pathlib import Path
+ from typing import Any, Optional, Union, cast
+ from urllib.parse import urlparse, urlunparse
+
+ import click
+ import requests
+ from rich.columns import Columns
+ from rich.console import Console
+ from rich.panel import Panel
+ from rich.table import Table
+
+ import vec_inf.cli._utils as utils
+ from vec_inf.cli._config import ModelConfig
+
+
+ VLLM_TASK_MAP = {
+     "LLM": "generate",
+     "VLM": "generate",
+     "Text_Embedding": "embed",
+     "Reward_Modeling": "reward",
+ }
+
+ REQUIRED_FIELDS = {
+     "model_family",
+     "model_type",
+     "gpus_per_node",
+     "num_nodes",
+     "vocab_size",
+     "max_model_len",
+ }
+
+ BOOLEAN_FIELDS = {
+     "pipeline_parallelism",
+     "enforce_eager",
+     "enable_prefix_caching",
+     "enable_chunked_prefill",
+ }
+
+ LD_LIBRARY_PATH = "/scratch/ssd001/pkgs/cudnn-11.7-v8.5.0.96/lib/:/scratch/ssd001/pkgs/cuda-11.7/targets/x86_64-linux/lib/"
+ SRC_DIR = str(Path(__file__).parent.parent)
+
+
+ class LaunchHelper:
+     def __init__(
+         self, model_name: str, cli_kwargs: dict[str, Optional[Union[str, int, bool]]]
+     ):
+         self.model_name = model_name
+         self.cli_kwargs = cli_kwargs
+         self.model_config = self._get_model_configuration()
+         self.params = self._get_launch_params()
+
+     def _get_model_configuration(self) -> ModelConfig:
+         """Load and validate model configuration."""
+         model_configs = utils.load_config()
+         config = next(
+             (m for m in model_configs if m.model_name == self.model_name), None
+         )
+
+         if config:
+             return config
+         # If model config not found, load path from CLI args or fallback to default
+         model_weights_parent_dir = self.cli_kwargs.get(
+             "model_weights_parent_dir", model_configs[0].model_weights_parent_dir
+         )
+         model_weights_path = Path(cast(str, model_weights_parent_dir), self.model_name)
+         # Only give a warning message if the weights exist but the config entry is missing
+         if model_weights_path.exists():
+             click.echo(
+                 click.style(
+                     f"Warning: '{self.model_name}' configuration not found in config, please ensure the model configuration is properly set in command arguments",
+                     fg="yellow",
+                 )
+             )
+             # Return a dummy model config object with model name and weights parent dir
+             return ModelConfig(
+                 model_name=self.model_name,
+                 model_family="model_family_placeholder",
+                 model_type="LLM",
+                 gpus_per_node=1,
+                 num_nodes=1,
+                 vocab_size=1000,
+                 max_model_len=8192,
+                 model_weights_parent_dir=Path(cast(str, model_weights_parent_dir)),
+             )
+         raise click.ClickException(
+             f"'{self.model_name}' not found in configuration and model weights "
+             f"not found at expected path '{model_weights_path}'"
+         )
+
+     def _get_launch_params(self) -> dict[str, Any]:
+         """Merge config defaults with CLI overrides."""
+         params = self.model_config.model_dump()
+
+         # Process boolean fields
+         for bool_field in BOOLEAN_FIELDS:
+             if self.cli_kwargs[bool_field]:
+                 params[bool_field] = True
+
+         # Merge other overrides
+         for key, value in self.cli_kwargs.items():
+             if value is not None and key not in [
+                 "json_mode",
+                 *BOOLEAN_FIELDS,
+             ]:
+                 params[key] = value
+
+         # Validate required fields
+         if not REQUIRED_FIELDS.issubset(set(params.keys())):
+             raise click.ClickException(
+                 f"Missing required fields: {REQUIRED_FIELDS - set(params.keys())}"
+             )
+
+         # Create log directory
+         params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser()
+         params["log_dir"].mkdir(parents=True, exist_ok=True)
+
+         # Convert to string for JSON serialization
+         for field in params:
+             params[field] = str(params[field])
+
+         return params
+
+     def set_env_vars(self) -> None:
+         """Set environment variables for the launch command."""
+         os.environ["MODEL_NAME"] = self.model_name
+         os.environ["MAX_MODEL_LEN"] = self.params["max_model_len"]
+         os.environ["MAX_LOGPROBS"] = self.params["vocab_size"]
+         os.environ["DATA_TYPE"] = self.params["data_type"]
+         os.environ["MAX_NUM_SEQS"] = self.params["max_num_seqs"]
+         os.environ["GPU_MEMORY_UTILIZATION"] = self.params["gpu_memory_utilization"]
+         os.environ["TASK"] = VLLM_TASK_MAP[self.params["model_type"]]
+         os.environ["PIPELINE_PARALLELISM"] = self.params["pipeline_parallelism"]
+         os.environ["COMPILATION_CONFIG"] = self.params["compilation_config"]
+         os.environ["SRC_DIR"] = SRC_DIR
+         os.environ["MODEL_WEIGHTS"] = str(
+             Path(self.params["model_weights_parent_dir"], self.model_name)
+         )
+         os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH
+         os.environ["VENV_BASE"] = self.params["venv"]
+         os.environ["LOG_DIR"] = self.params["log_dir"]
+
+         if self.params.get("enable_prefix_caching"):
+             os.environ["ENABLE_PREFIX_CACHING"] = self.params["enable_prefix_caching"]
+         if self.params.get("enable_chunked_prefill"):
+             os.environ["ENABLE_CHUNKED_PREFILL"] = self.params["enable_chunked_prefill"]
+         if self.params.get("max_num_batched_tokens"):
+             os.environ["MAX_NUM_BATCHED_TOKENS"] = self.params["max_num_batched_tokens"]
+         if self.params.get("enforce_eager"):
+             os.environ["ENFORCE_EAGER"] = self.params["enforce_eager"]
+
+     def build_launch_command(self) -> str:
+         """Construct the full launch command with parameters."""
+         # Base command
+         command_list = ["sbatch"]
+         # Append options
+         command_list.extend(["--job-name", f"{self.model_name}"])
+         command_list.extend(["--partition", f"{self.params['partition']}"])
+         command_list.extend(["--qos", f"{self.params['qos']}"])
+         command_list.extend(["--time", f"{self.params['time']}"])
+         command_list.extend(["--nodes", f"{self.params['num_nodes']}"])
+         command_list.extend(["--gpus-per-node", f"{self.params['gpus_per_node']}"])
+         command_list.extend(
+             [
+                 "--output",
+                 f"{self.params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out",
+             ]
+         )
+         command_list.extend(
+             [
+                 "--error",
+                 f"{self.params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.err",
+             ]
+         )
+         # Add slurm script
+         slurm_script = "vllm.slurm"
+         if int(self.params["num_nodes"]) > 1:
+             slurm_script = "multinode_vllm.slurm"
+         command_list.append(f"{SRC_DIR}/{slurm_script}")
+         return " ".join(command_list)
+
+     def format_table_output(self, job_id: str) -> Table:
+         """Format output as rich Table."""
+         table = utils.create_table(key_title="Job Config", value_title="Value")
+         # Add rows
+         table.add_row("Slurm Job ID", job_id, style="blue")
+         table.add_row("Job Name", self.model_name)
+         table.add_row("Model Type", self.params["model_type"])
+         table.add_row("Partition", self.params["partition"])
+         table.add_row("QoS", self.params["qos"])
+         table.add_row("Time Limit", self.params["time"])
+         table.add_row("Num Nodes", self.params["num_nodes"])
+         table.add_row("GPUs/Node", self.params["gpus_per_node"])
+         table.add_row("Data Type", self.params["data_type"])
+         table.add_row("Vocabulary Size", self.params["vocab_size"])
+         table.add_row("Max Model Length", self.params["max_model_len"])
+         table.add_row("Max Num Seqs", self.params["max_num_seqs"])
+         table.add_row("GPU Memory Utilization", self.params["gpu_memory_utilization"])
+         table.add_row("Compilation Config", self.params["compilation_config"])
+         table.add_row("Pipeline Parallelism", self.params["pipeline_parallelism"])
+         if self.params.get("enable_prefix_caching"):
+             table.add_row("Enable Prefix Caching", self.params["enable_prefix_caching"])
+         if self.params.get("enable_chunked_prefill"):
+             table.add_row(
+                 "Enable Chunked Prefill", self.params["enable_chunked_prefill"]
+             )
+         if self.params.get("max_num_batched_tokens"):
+             table.add_row(
+                 "Max Num Batched Tokens", self.params["max_num_batched_tokens"]
+             )
+         if self.params.get("enforce_eager"):
+             table.add_row("Enforce Eager", self.params["enforce_eager"])
+         table.add_row("Model Weights Directory", os.environ.get("MODEL_WEIGHTS"))
+         table.add_row("Log Directory", self.params["log_dir"])
+
+         return table
+
+     def post_launch_processing(self, output: str, console: Console) -> None:
+         """Process and display launch output."""
+         json_mode = bool(self.cli_kwargs.get("json_mode", False))
+         slurm_job_id = output.split(" ")[-1].strip().strip("\n")
+         self.params["slurm_job_id"] = slurm_job_id
+         job_json = Path(
+             self.params["log_dir"],
+             f"{self.model_name}.{slurm_job_id}",
+             f"{self.model_name}.{slurm_job_id}.json",
+         )
+         job_json.parent.mkdir(parents=True, exist_ok=True)
+         job_json.touch(exist_ok=True)
+
+         with job_json.open("w") as file:
+             json.dump(self.params, file, indent=4)
+         if json_mode:
+             click.echo(self.params)
+         else:
+             table = self.format_table_output(slurm_job_id)
+             console.print(table)
+
+
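For orientation, and not part of the released file: a rough sketch of how the `launch` CLI command presumably drives LaunchHelper, using only the methods defined above. The model name and the kwargs are illustrative; the boolean keys mirror BOOLEAN_FIELDS and must be present, and any option left unset falls back to the named model's config entry, which is assumed to exist.

from rich.console import Console

import vec_inf.cli._utils as utils
from vec_inf.cli._helper import LaunchHelper

# Illustrative CLI arguments; keys mirror the launch command's options.
cli_kwargs = {
    "pipeline_parallelism": False,
    "enforce_eager": False,
    "enable_prefix_caching": False,
    "enable_chunked_prefill": False,
    "json_mode": False,
}

helper = LaunchHelper("Meta-Llama-3.1-8B-Instruct", cli_kwargs)  # example model name
helper.set_env_vars()
launch_cmd = helper.build_launch_command()  # "sbatch --job-name ... vllm.slurm"
output, stderr = utils.run_bash_command(launch_cmd)
if stderr:
    raise RuntimeError(stderr)
helper.post_launch_processing(output, Console())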
+ class StatusHelper:
+     def __init__(self, slurm_job_id: int, output: str, log_dir: Optional[str] = None):
+         self.slurm_job_id = slurm_job_id
+         self.output = output
+         self.log_dir = log_dir
+         self.status_info = self._get_base_status_data()
+
+     def _get_base_status_data(self) -> dict[str, Union[str, None]]:
+         """Extract basic job status information from scontrol output."""
+         try:
+             job_name = self.output.split(" ")[1].split("=")[1]
+             job_state = self.output.split(" ")[9].split("=")[1]
+         except IndexError:
+             job_name = "UNAVAILABLE"
+             job_state = "UNAVAILABLE"
+
+         return {
+             "model_name": job_name,
+             "status": "UNAVAILABLE",
+             "base_url": "UNAVAILABLE",
+             "state": job_state,
+             "pending_reason": None,
+             "failed_reason": None,
+         }
+
+     def process_job_state(self) -> None:
+         """Process different job states and update status information."""
+         if self.status_info["state"] == "PENDING":
+             self.process_pending_state()
+         elif self.status_info["state"] == "RUNNING":
+             self.process_running_state()
+
+     def check_model_health(self) -> None:
+         """Check model health and update status accordingly."""
+         status, status_code = utils.model_health_check(
+             cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir
+         )
+         if status == "READY":
+             self.status_info["base_url"] = utils.get_base_url(
+                 cast(str, self.status_info["model_name"]),
+                 self.slurm_job_id,
+                 self.log_dir,
+             )
+             self.status_info["status"] = status
+         else:
+             self.status_info["status"], self.status_info["failed_reason"] = (
+                 status,
+                 cast(str, status_code),
+             )
+
+     def process_running_state(self) -> None:
+         """Process RUNNING job state and check server status."""
+         server_status = utils.is_server_running(
+             cast(str, self.status_info["model_name"]), self.slurm_job_id, self.log_dir
+         )
+
+         if isinstance(server_status, tuple):
+             self.status_info["status"], self.status_info["failed_reason"] = (
+                 server_status
+             )
+             return
+
+         if server_status == "RUNNING":
+             self.check_model_health()
+         else:
+             self.status_info["status"] = server_status
+
+     def process_pending_state(self) -> None:
+         """Process PENDING job state."""
+         try:
+             self.status_info["pending_reason"] = self.output.split(" ")[10].split("=")[
+                 1
+             ]
+             self.status_info["status"] = "PENDING"
+         except IndexError:
+             self.status_info["pending_reason"] = "Unknown pending reason"
+
+     def output_json(self) -> None:
+         """Format and output JSON data."""
+         json_data = {
+             "model_name": self.status_info["model_name"],
+             "model_status": self.status_info["status"],
+             "base_url": self.status_info["base_url"],
+         }
+         if self.status_info["pending_reason"]:
+             json_data["pending_reason"] = self.status_info["pending_reason"]
+         if self.status_info["failed_reason"]:
+             json_data["failed_reason"] = self.status_info["failed_reason"]
+         click.echo(json_data)
+
+     def output_table(self, console: Console) -> None:
+         """Create and display rich table."""
+         table = utils.create_table(key_title="Job Status", value_title="Value")
+         table.add_row("Model Name", self.status_info["model_name"])
+         table.add_row("Model Status", self.status_info["status"], style="blue")
+
+         if self.status_info["pending_reason"]:
+             table.add_row("Pending Reason", self.status_info["pending_reason"])
+         if self.status_info["failed_reason"]:
+             table.add_row("Failed Reason", self.status_info["failed_reason"])
+
+         table.add_row("Base URL", self.status_info["base_url"])
+         console.print(table)
+
+
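Again not part of the released file: a minimal sketch of how StatusHelper is meant to be fed, mirroring what MetricsHelper._get_status_info does below, i.e. run `scontrol show job <id> --oneliner`, hand the output to the helper, then render a table or JSON. The job id is a placeholder.

from rich.console import Console

import vec_inf.cli._utils as utils
from vec_inf.cli._helper import StatusHelper

slurm_job_id = 1234567  # placeholder Slurm job id
output, stderr = utils.run_bash_command(f"scontrol show job {slurm_job_id} --oneliner")
if stderr:
    raise RuntimeError(stderr)

helper = StatusHelper(slurm_job_id, output)
helper.process_job_state()      # fills in PENDING / RUNNING / READY / FAILED details
helper.output_table(Console())  # or helper.output_json() for machine-readable output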
+ class MetricsHelper:
+     def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None):
+         self.slurm_job_id = slurm_job_id
+         self.log_dir = log_dir
+         self.status_info = self._get_status_info()
+         self.metrics_url = self._build_metrics_url()
+         self.enabled_prefix_caching = self._check_prefix_caching()
+
+         self._prev_prompt_tokens: float = 0.0
+         self._prev_generation_tokens: float = 0.0
+         self._last_updated: Optional[float] = None
+         self._last_throughputs = {"prompt": 0.0, "generation": 0.0}
+
+     def _get_status_info(self) -> dict[str, Union[str, None]]:
+         """Retrieve status info using existing StatusHelper."""
+         status_cmd = f"scontrol show job {self.slurm_job_id} --oneliner"
+         output, stderr = utils.run_bash_command(status_cmd)
+         if stderr:
+             raise click.ClickException(f"Error: {stderr}")
+         status_helper = StatusHelper(self.slurm_job_id, output, self.log_dir)
+         return status_helper.status_info
+
+     def _build_metrics_url(self) -> str:
+         """Construct metrics endpoint URL from base URL with version stripping."""
+         if self.status_info.get("state") == "PENDING":
+             return "Pending resources for server initialization"
+
+         base_url = utils.get_base_url(
+             cast(str, self.status_info["model_name"]),
+             self.slurm_job_id,
+             self.log_dir,
+         )
+         if not base_url.startswith("http"):
+             return "Server not ready"
+
+         parsed = urlparse(base_url)
+         clean_path = parsed.path.replace("/v1", "", 1).rstrip("/")
+         return urlunparse(
+             (parsed.scheme, parsed.netloc, f"{clean_path}/metrics", "", "", "")
+         )
+
+     def _check_prefix_caching(self) -> bool:
+         """Check if prefix caching is enabled."""
+         job_json = utils.read_slurm_log(
+             cast(str, self.status_info["model_name"]),
+             self.slurm_job_id,
+             "json",
+             self.log_dir,
+         )
+         if isinstance(job_json, str):
+             return False
+         return bool(cast(dict[str, str], job_json).get("enable_prefix_caching", False))
+
+     def fetch_metrics(self) -> Union[dict[str, float], str]:
+         """Fetch metrics from the endpoint."""
+         try:
+             response = requests.get(self.metrics_url, timeout=3)
+             response.raise_for_status()
+             current_metrics = self._parse_metrics(response.text)
+             current_time = time.time()
+
+             # Set defaults using last known throughputs
+             current_metrics.setdefault(
+                 "prompt_tokens_per_sec", self._last_throughputs["prompt"]
+             )
+             current_metrics.setdefault(
+                 "generation_tokens_per_sec", self._last_throughputs["generation"]
+             )
+
+             if self._last_updated is None:
+                 self._prev_prompt_tokens = current_metrics.get(
+                     "total_prompt_tokens", 0.0
+                 )
+                 self._prev_generation_tokens = current_metrics.get(
+                     "total_generation_tokens", 0.0
+                 )
+                 self._last_updated = current_time
+                 return current_metrics
+
+             time_diff = current_time - self._last_updated
+             if time_diff > 0:
+                 current_prompt = current_metrics.get("total_prompt_tokens", 0.0)
+                 current_gen = current_metrics.get("total_generation_tokens", 0.0)
+
+                 delta_prompt = current_prompt - self._prev_prompt_tokens
+                 delta_gen = current_gen - self._prev_generation_tokens
+
+                 # Only update throughputs when we have new tokens
+                 prompt_tps = (
+                     delta_prompt / time_diff
+                     if delta_prompt > 0
+                     else self._last_throughputs["prompt"]
+                 )
+                 gen_tps = (
+                     delta_gen / time_diff
+                     if delta_gen > 0
+                     else self._last_throughputs["generation"]
+                 )
+
+                 current_metrics["prompt_tokens_per_sec"] = prompt_tps
+                 current_metrics["generation_tokens_per_sec"] = gen_tps
+
+                 # Persist calculated values regardless of activity
+                 self._last_throughputs["prompt"] = prompt_tps
+                 self._last_throughputs["generation"] = gen_tps
+
+                 # Update tracking state
+                 self._prev_prompt_tokens = current_prompt
+                 self._prev_generation_tokens = current_gen
+                 self._last_updated = current_time
+
+             # Calculate average latency if data is available
+             if (
+                 "request_latency_sum" in current_metrics
+                 and "request_latency_count" in current_metrics
+             ):
+                 latency_sum = current_metrics["request_latency_sum"]
+                 latency_count = current_metrics["request_latency_count"]
+                 current_metrics["avg_request_latency"] = (
+                     latency_sum / latency_count if latency_count > 0 else 0.0
+                 )
+
+             return current_metrics
+
+         except requests.RequestException as e:
+             return f"Metrics request failed, `metrics` endpoint might not be ready yet: {str(e)}"
+
+     def _parse_metrics(self, metrics_text: str) -> dict[str, float]:
+         """Parse metrics with latency count and sum."""
+         key_metrics = {
+             "vllm:prompt_tokens_total": "total_prompt_tokens",
+             "vllm:generation_tokens_total": "total_generation_tokens",
+             "vllm:e2e_request_latency_seconds_sum": "request_latency_sum",
+             "vllm:e2e_request_latency_seconds_count": "request_latency_count",
+             "vllm:request_queue_time_seconds_sum": "queue_time_sum",
+             "vllm:request_success_total": "successful_requests_total",
+             "vllm:num_requests_running": "requests_running",
+             "vllm:num_requests_waiting": "requests_waiting",
+             "vllm:num_requests_swapped": "requests_swapped",
+             "vllm:gpu_cache_usage_perc": "gpu_cache_usage",
+             "vllm:cpu_cache_usage_perc": "cpu_cache_usage",
+         }
+
+         if self.enabled_prefix_caching:
+             key_metrics["vllm:gpu_prefix_cache_hit_rate"] = "gpu_prefix_cache_hit_rate"
+             key_metrics["vllm:cpu_prefix_cache_hit_rate"] = "cpu_prefix_cache_hit_rate"
+
+         parsed: dict[str, float] = {}
+         for line in metrics_text.split("\n"):
+             if line.startswith("#") or not line.strip():
+                 continue
+
+             parts = line.split()
+             if len(parts) < 2:
+                 continue
+
+             metric_name = parts[0].split("{")[0]
+             if metric_name in key_metrics:
+                 try:
+                     parsed[key_metrics[metric_name]] = float(parts[1])
+                 except (ValueError, IndexError):
+                     continue
+         return parsed
+
+     def display_failed_metrics(self, table: Table, metrics: str) -> None:
+         table.add_row("Server State", self.status_info["state"], style="yellow")
+         table.add_row("Message", metrics)
+
+     def display_metrics(self, table: Table, metrics: dict[str, float]) -> None:
+         # Throughput metrics
+         table.add_row(
+             "Prompt Throughput",
+             f"{metrics.get('prompt_tokens_per_sec', 0):.1f} tokens/s",
+         )
+         table.add_row(
+             "Generation Throughput",
+             f"{metrics.get('generation_tokens_per_sec', 0):.1f} tokens/s",
+         )
+
+         # Request queue metrics
+         table.add_row(
+             "Requests Running",
+             f"{metrics.get('requests_running', 0):.0f} reqs",
+         )
+         table.add_row(
+             "Requests Waiting",
+             f"{metrics.get('requests_waiting', 0):.0f} reqs",
+         )
+         table.add_row(
+             "Requests Swapped",
+             f"{metrics.get('requests_swapped', 0):.0f} reqs",
+         )
+
+         # Cache usage metrics
+         table.add_row(
+             "GPU Cache Usage",
+             f"{metrics.get('gpu_cache_usage', 0) * 100:.1f}%",
+         )
+         table.add_row(
+             "CPU Cache Usage",
+             f"{metrics.get('cpu_cache_usage', 0) * 100:.1f}%",
+         )
+
+         if self.enabled_prefix_caching:
+             table.add_row(
+                 "GPU Prefix Cache Hit Rate",
+                 f"{metrics.get('gpu_prefix_cache_hit_rate', 0) * 100:.1f}%",
+             )
+             table.add_row(
+                 "CPU Prefix Cache Hit Rate",
+                 f"{metrics.get('cpu_prefix_cache_hit_rate', 0) * 100:.1f}%",
+             )
+
+         # Show average latency if available
+         if "avg_request_latency" in metrics:
+             table.add_row(
+                 "Avg Request Latency",
+                 f"{metrics['avg_request_latency']:.1f} s",
+             )
+
+         # Token counts
+         table.add_row(
+             "Total Prompt Tokens",
+             f"{metrics.get('total_prompt_tokens', 0):.0f} tokens",
+         )
+         table.add_row(
+             "Total Generation Tokens",
+             f"{metrics.get('total_generation_tokens', 0):.0f} tokens",
+         )
+         table.add_row(
+             "Successful Requests",
+             f"{metrics.get('successful_requests_total', 0):.0f} reqs",
+         )
+
+
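For context, and not part of the diff: `_parse_metrics` reads the Prometheus text exposition served by the vLLM server's `/metrics` endpoint, and `fetch_metrics` turns successive counter totals into throughput figures. A self-contained sketch with made-up numbers shows the same idea; the sample lines and the 30-second polling interval are illustrative.

# Illustrative only; real lines come from the server's /metrics endpoint and carry labels.
sample = """\
# HELP vllm:prompt_tokens_total Number of prefill tokens processed.
vllm:prompt_tokens_total{model_name="demo"} 1000.0
vllm:generation_tokens_total{model_name="demo"} 4000.0
vllm:num_requests_running{model_name="demo"} 2.0
"""

def parse(text: str) -> dict[str, float]:
    # Same approach as _parse_metrics: skip comments and blank lines, strip the
    # label block from the metric name, keep the sample value as a float.
    out: dict[str, float] = {}
    for line in text.splitlines():
        if not line.strip() or line.startswith("#"):
            continue
        name, value = line.split()[:2]
        out[name.split("{")[0]] = float(value)
    return out

first = parse(sample)
# A poll 30 seconds later; the totals are cumulative counters, so throughput is the delta.
later_prompt_total, later_generation_total, elapsed = 1600.0, 10000.0, 30.0
prompt_tps = (later_prompt_total - first["vllm:prompt_tokens_total"]) / elapsed
generation_tps = (later_generation_total - first["vllm:generation_tokens_total"]) / elapsed
print(prompt_tps, generation_tps)  # 20.0 200.0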
+ class ListHelper:
+     """Helper class for handling model listing functionality."""
+
+     def __init__(self, model_name: Optional[str] = None, json_mode: bool = False):
+         self.model_name = model_name
+         self.json_mode = json_mode
+         self.model_configs = utils.load_config()
+
+     def get_single_model_config(self) -> ModelConfig:
+         """Get configuration for a specific model."""
+         config = next(
+             (c for c in self.model_configs if c.model_name == self.model_name), None
+         )
+         if not config:
+             raise click.ClickException(
+                 f"Model '{self.model_name}' not found in configuration"
+             )
+         return config
+
+     def format_single_model_output(
+         self, config: ModelConfig
+     ) -> Union[dict[str, Any], Table]:
+         """Format output for a single model."""
+         if self.json_mode:
+             # Exclude non-essential fields from JSON output
+             excluded = {"venv", "log_dir"}
+             config_dict = config.model_dump(exclude=excluded)
+             # Convert Path objects to strings
+             config_dict["model_weights_parent_dir"] = str(
+                 config_dict["model_weights_parent_dir"]
+             )
+             return config_dict
+
+         table = utils.create_table(key_title="Model Config", value_title="Value")
+         for field, value in config.model_dump().items():
+             if field not in {"venv", "log_dir"}:
+                 table.add_row(field, str(value))
+         return table
+
+     def format_all_models_output(self) -> Union[list[str], list[Panel]]:
+         """Format output for all models."""
+         if self.json_mode:
+             return [config.model_name for config in self.model_configs]
+
+         # Sort by model type priority
+         type_priority = {"LLM": 0, "VLM": 1, "Text_Embedding": 2, "Reward_Modeling": 3}
+         sorted_configs = sorted(
+             self.model_configs, key=lambda x: type_priority.get(x.model_type, 4)
+         )
+
+         # Create panels with color coding
+         model_type_colors = {
+             "LLM": "cyan",
+             "VLM": "bright_blue",
+             "Text_Embedding": "purple",
+             "Reward_Modeling": "bright_magenta",
+         }
+
+         panels = []
+         for config in sorted_configs:
+             color = model_type_colors.get(config.model_type, "white")
+             variant = config.model_variant or ""
+             display_text = f"[magenta]{config.model_family}[/magenta]"
+             if variant:
+                 display_text += f"-{variant}"
+             panels.append(Panel(display_text, expand=True, border_style=color))
+
+         return panels
+
+     def process_list_command(self, console: Console) -> None:
+         """Process the list command and display output."""
+         try:
+             if self.model_name:
+                 # Handle single model case
+                 config = self.get_single_model_config()
+                 output = self.format_single_model_output(config)
+                 if self.json_mode:
+                     click.echo(output)
+                 else:
+                     console.print(output)
+             # Handle all models case
+             elif self.json_mode:
+                 # JSON output for all models is just a list of names
+                 model_names = [config.model_name for config in self.model_configs]
+                 click.echo(model_names)
+             else:
+                 # Rich output for all models is a list of panels
+                 panels = self.format_all_models_output()
+                 if isinstance(panels, list):  # This helps mypy understand the type
+                     console.print(Columns(panels, equal=True))
+         except Exception as e:
+             raise click.ClickException(str(e)) from e
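Finally, a hedged sketch (not part of the released file) of how the `list` command might exercise ListHelper; the model name is illustrative and must exist in the bundled configuration.

from rich.console import Console

from vec_inf.cli._helper import ListHelper

console = Console()

# All configured models, rendered as colour-coded Rich panels.
ListHelper().process_list_command(console)

# One model's config; json_mode=True prints a plain dict instead of a table.
ListHelper(model_name="Meta-Llama-3.1-8B-Instruct", json_mode=True).process_list_command(console)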