vec-inf 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,674 @@
+ """Helper classes for the model.
+
+ This module provides utility classes for managing model deployment, status monitoring,
+ metrics collection, and model registry operations.
+ """
+
+ import json
+ import os
+ import time
+ import warnings
+ from pathlib import Path
+ from typing import Any, Optional, Union, cast
+ from urllib.parse import urlparse, urlunparse
+
+ import requests
+
+ import vec_inf.client._utils as utils
+ from vec_inf.client._client_vars import (
+     KEY_METRICS,
+     REQUIRED_FIELDS,
+     SRC_DIR,
+     VLLM_SHORT_TO_LONG_MAP,
+ )
+ from vec_inf.client._exceptions import (
+     MissingRequiredFieldsError,
+     ModelConfigurationError,
+     ModelNotFoundError,
+     SlurmJobError,
+ )
+ from vec_inf.client._slurm_script_generator import SlurmScriptGenerator
+ from vec_inf.client.config import ModelConfig
+ from vec_inf.client.models import (
+     LaunchResponse,
+     ModelInfo,
+     ModelStatus,
+     ModelType,
+     StatusResponse,
+ )
+ from vec_inf.client.slurm_vars import (
+     LD_LIBRARY_PATH,
+     VLLM_NCCL_SO_PATH,
+ )
+
+
+ class ModelLauncher:
+     """Helper class for handling inference server launch.
+
+     A class that manages the launch process of inference servers, including
+     configuration validation, parameter preparation, and SLURM job submission.
+
+     Parameters
+     ----------
+     model_name : str
+         Name of the model to launch
+     kwargs : dict[str, Any], optional
+         Optional launch keyword arguments to override default configuration
+     """
+
+     def __init__(self, model_name: str, kwargs: Optional[dict[str, Any]]):
+         """Initialize the model launcher.
+
+         Parameters
+         ----------
+         model_name: str
+             Name of the model to launch
+         kwargs: Optional[dict[str, Any]]
+             Optional launch keyword arguments to override default configuration
+         """
+         self.model_name = model_name
+         self.kwargs = kwargs or {}
+         self.slurm_job_id = ""
+         self.slurm_script_path = Path("")
+         self.model_config = self._get_model_configuration()
+         self.params = self._get_launch_params()
+
+     def _warn(self, message: str) -> None:
+         """Warn the user about a potential issue.
+
+         Parameters
+         ----------
+         message : str
+             Warning message to display
+         """
+         warnings.warn(message, UserWarning, stacklevel=2)
+
+     def _get_model_configuration(self) -> ModelConfig:
+         """Load and validate model configuration.
+
+         Returns
+         -------
+         ModelConfig
+             Validated configuration for the model
+
+         Raises
+         ------
+         ModelNotFoundError
+             If model weights parent directory cannot be determined
+         ModelConfigurationError
+             If model configuration is not found and weights don't exist
+         """
+         model_configs = utils.load_config()
+         config = next(
+             (m for m in model_configs if m.model_name == self.model_name), None
+         )
+
+         if config:
+             return config
+
+         # If model config not found, check for path from CLI kwargs or use fallback
+         model_weights_parent_dir = self.kwargs.get(
+             "model_weights_parent_dir",
+             model_configs[0].model_weights_parent_dir if model_configs else None,
+         )
+
+         if not model_weights_parent_dir:
+             raise ModelNotFoundError(
+                 "Could not determine model weights parent directory"
+             )
+
+         model_weights_path = Path(model_weights_parent_dir, self.model_name)
+
+         # Only give a warning if weights exist but config missing
+         if model_weights_path.exists():
+             self._warn(
+                 f"Warning: '{self.model_name}' configuration not found in config, please ensure model configuration is properly set in command arguments",
+             )
+             # Return a dummy model config object with model name and weights parent dir
+             return ModelConfig(
+                 model_name=self.model_name,
+                 model_family="model_family_placeholder",
+                 model_type="LLM",
+                 gpus_per_node=1,
+                 num_nodes=1,
+                 vocab_size=1000,
+                 model_weights_parent_dir=Path(str(model_weights_parent_dir)),
+             )
+
+         raise ModelConfigurationError(
+             f"'{self.model_name}' not found in configuration and model weights "
+             f"not found at expected path '{model_weights_path}'"
+         )
+
+     def _process_vllm_args(self, arg_string: str) -> dict[str, Any]:
+         """Process the vllm_args string into a dictionary.
+
+         Parameters
+         ----------
+         arg_string : str
+             Comma-separated string of vLLM arguments
+
+         Returns
+         -------
+         dict[str, Any]
+             Processed vLLM arguments as key-value pairs
+         """
+         vllm_args: dict[str, str | bool] = {}
+         for arg in arg_string.split(","):
+             if "=" in arg:
+                 key, value = arg.split("=")
+                 if key.strip() in VLLM_SHORT_TO_LONG_MAP:
+                     key = VLLM_SHORT_TO_LONG_MAP[key.strip()]
+                 vllm_args[key.strip()] = value.strip()
+             elif "-O" in arg.strip():
+                 key = VLLM_SHORT_TO_LONG_MAP["-O"]
+                 vllm_args[key] = arg.strip()[2:].strip()
+             else:
+                 vllm_args[arg.strip()] = True
+         return vllm_args
+
+     def _get_launch_params(self) -> dict[str, Any]:
+         """Prepare launch parameters, set log dir, and validate required fields.
+
+         Returns
+         -------
+         dict[str, Any]
+             Dictionary of prepared launch parameters
+
+         Raises
+         ------
+         MissingRequiredFieldsError
+             If required fields are missing or tensor parallel size is not specified
+             when using multiple GPUs
+         """
+         params = self.model_config.model_dump(exclude_none=True)
+
+         # Override config defaults with CLI arguments
+         if self.kwargs.get("vllm_args"):
+             vllm_args = self._process_vllm_args(self.kwargs["vllm_args"])
+             for key, value in vllm_args.items():
+                 params["vllm_args"][key] = value
+             del self.kwargs["vllm_args"]
+
+         for key, value in self.kwargs.items():
+             params[key] = value
+
+         # Validate required fields and vllm args
+         if not REQUIRED_FIELDS.issubset(set(params.keys())):
+             raise MissingRequiredFieldsError(
+                 f"Missing required fields: {REQUIRED_FIELDS - set(params.keys())}"
+             )
+         if (
+             int(params["gpus_per_node"]) > 1
+             and params["vllm_args"].get("--tensor-parallel-size") is None
+         ):
+             raise MissingRequiredFieldsError(
+                 "--tensor-parallel-size is required when gpus_per_node > 1"
+             )
+
+         # Create log directory
+         params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser()
+         params["log_dir"].mkdir(parents=True, exist_ok=True)
+         params["src_dir"] = SRC_DIR
+
+         # Construct slurm log file paths
+         params["out_file"] = (
+             f"{params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out"
+         )
+         params["err_file"] = (
+             f"{params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.err"
+         )
+         params["json_file"] = (
+             f"{params['log_dir']}/{self.model_name}.$SLURM_JOB_ID/{self.model_name}.$SLURM_JOB_ID.json"
+         )
+
+         # Convert path to string for JSON serialization
+         for field in params:
+             if field == "vllm_args":
+                 continue
+             params[field] = str(params[field])
+
+         return params
+
+     def _set_env_vars(self) -> None:
+         """Set environment variables for the launch command."""
+         os.environ["LD_LIBRARY_PATH"] = LD_LIBRARY_PATH
+         os.environ["VLLM_NCCL_SO_PATH"] = VLLM_NCCL_SO_PATH
+
+     def _build_launch_command(self) -> str:
+         """Generate the slurm script and construct the launch command.
+
+         Returns
+         -------
+         str
+             Complete SLURM launch command
+         """
+         self.slurm_script_path = SlurmScriptGenerator(self.params).write_to_log_dir()
+         return f"sbatch {self.slurm_script_path}"
+
+     def launch(self) -> LaunchResponse:
+         """Launch the model.
+
+         Returns
+         -------
+         LaunchResponse
+             Response object containing launch details and status
+
+         Raises
+         ------
+         SlurmJobError
+             If SLURM job submission fails
+         """
+         # Set environment variables
+         self._set_env_vars()
+
+         # Build and execute the launch command
+         command_output, stderr = utils.run_bash_command(self._build_launch_command())
+
+         if stderr:
+             raise SlurmJobError(f"Error: {stderr}")
+
+         # Extract slurm job id from command output
+         self.slurm_job_id = command_output.split(" ")[-1].strip().strip("\n")
+         self.params["slurm_job_id"] = self.slurm_job_id
+
+         # Create log directory and job json file, move slurm script to job log directory
+         job_log_dir = Path(
+             self.params["log_dir"], f"{self.model_name}.{self.slurm_job_id}"
+         )
+         job_log_dir.mkdir(parents=True, exist_ok=True)
+
+         job_json = Path(
+             job_log_dir,
+             f"{self.model_name}.{self.slurm_job_id}.json",
+         )
+         job_json.touch(exist_ok=True)
+
+         self.slurm_script_path.rename(
+             job_log_dir / f"{self.model_name}.{self.slurm_job_id}.slurm"
+         )
+
+         with job_json.open("w") as file:
+             json.dump(self.params, file, indent=4)
+
+         return LaunchResponse(
+             slurm_job_id=int(self.slurm_job_id),
+             model_name=self.model_name,
+             config=self.params,
+             raw_output=command_output,
+         )
+
+
+ class ModelStatusMonitor:
+     """Class for handling server status information and monitoring.
+
+     A class that monitors and reports the status of deployed model servers,
+     including job state and health checks.
+
+     Parameters
+     ----------
+     slurm_job_id : int
+         ID of the SLURM job to monitor
+     log_dir : str, optional
+         Base directory containing log files
+     """
+
+     def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None):
+         self.slurm_job_id = slurm_job_id
+         self.output = self._get_raw_status_output()
+         self.log_dir = log_dir
+         self.status_info = self._get_base_status_data()
+
+     def _get_raw_status_output(self) -> str:
+         """Get the raw server status output from slurm.
+
+         Returns
+         -------
+         str
+             Raw output from scontrol command
+
+         Raises
+         ------
+         SlurmJobError
+             If status check fails
+         """
+         status_cmd = f"scontrol show job {self.slurm_job_id} --oneliner"
+         output, stderr = utils.run_bash_command(status_cmd)
+         if stderr:
+             raise SlurmJobError(f"Error: {stderr}")
+         return output
+
+     def _get_base_status_data(self) -> StatusResponse:
+         """Extract basic job status information from scontrol output.
+
+         Returns
+         -------
+         StatusResponse
+             Basic status information for the job
+         """
+         try:
+             job_name = self.output.split(" ")[1].split("=")[1]
+             job_state = self.output.split(" ")[9].split("=")[1]
+         except IndexError:
+             job_name = "UNAVAILABLE"
+             job_state = ModelStatus.UNAVAILABLE
+
+         return StatusResponse(
+             model_name=job_name,
+             server_status=ModelStatus.UNAVAILABLE,
+             job_state=job_state,
+             raw_output=self.output,
+             base_url="UNAVAILABLE",
+             pending_reason=None,
+             failed_reason=None,
+         )
+
+     def _check_model_health(self) -> None:
+         """Check model health and update status accordingly."""
+         status, status_code = utils.model_health_check(
+             self.status_info.model_name, self.slurm_job_id, self.log_dir
+         )
+         if status == ModelStatus.READY:
+             self.status_info.base_url = utils.get_base_url(
+                 self.status_info.model_name,
+                 self.slurm_job_id,
+                 self.log_dir,
+             )
+             self.status_info.server_status = status
+         else:
+             self.status_info.server_status = status
+             self.status_info.failed_reason = cast(str, status_code)
+
+     def _process_running_state(self) -> None:
+         """Process RUNNING job state and check server status."""
+         server_status = utils.is_server_running(
+             self.status_info.model_name, self.slurm_job_id, self.log_dir
+         )
+
+         if isinstance(server_status, tuple):
+             self.status_info.server_status, self.status_info.failed_reason = (
+                 server_status
+             )
+             return
+
+         if server_status == "RUNNING":
+             self._check_model_health()
+         else:
+             self.status_info.server_status = cast(ModelStatus, server_status)
+
+     def _process_pending_state(self) -> None:
+         """Process PENDING job state and update status information."""
+         try:
+             self.status_info.pending_reason = self.output.split(" ")[10].split("=")[1]
+             self.status_info.server_status = ModelStatus.PENDING
+         except IndexError:
+             self.status_info.pending_reason = "Unknown pending reason"
+
+     def process_model_status(self) -> StatusResponse:
+         """Process different job states and update status information.
+
+         Returns
+         -------
+         StatusResponse
+             Complete status information for the model
+         """
+         if self.status_info.job_state == ModelStatus.PENDING:
+             self._process_pending_state()
+         elif self.status_info.job_state == "RUNNING":
+             self._process_running_state()
+
+         return self.status_info
+
+
+ class PerformanceMetricsCollector:
+     """Class for handling metrics collection and processing.
+
+     A class that collects and processes performance metrics from running model servers,
+     including throughput and latency measurements.
+
+     Parameters
+     ----------
+     slurm_job_id : int
+         ID of the SLURM job to collect metrics from
+     log_dir : str, optional
+         Directory containing log files
+     """
+
+     def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None):
+         self.slurm_job_id = slurm_job_id
+         self.log_dir = log_dir
+         self.status_info = self._get_status_info()
+         self.metrics_url = self._build_metrics_url()
+         self.enabled_prefix_caching = self._check_prefix_caching()
+
+         self._prev_prompt_tokens: float = 0.0
+         self._prev_generation_tokens: float = 0.0
+         self._last_updated: Optional[float] = None
+         self._last_throughputs = {"prompt": 0.0, "generation": 0.0}
+
+     def _get_status_info(self) -> StatusResponse:
+         """Retrieve status info using ModelStatusMonitor.
+
+         Returns
+         -------
+         StatusResponse
+             Current status information for the model
+         """
+         status_helper = ModelStatusMonitor(self.slurm_job_id, self.log_dir)
+         return status_helper.process_model_status()
+
+     def _build_metrics_url(self) -> str:
+         """Construct metrics endpoint URL from base URL with version stripping.
+
+         Returns
+         -------
+         str
+             Complete metrics endpoint URL or status message
+         """
+         if self.status_info.job_state == ModelStatus.PENDING:
+             return "Pending resources for server initialization"
+
+         base_url = utils.get_base_url(
+             self.status_info.model_name,
+             self.slurm_job_id,
+             self.log_dir,
+         )
+         if not base_url.startswith("http"):
+             return "Server not ready"
+
+         parsed = urlparse(base_url)
+         clean_path = parsed.path.replace("/v1", "", 1).rstrip("/")
+         return urlunparse(
+             (parsed.scheme, parsed.netloc, f"{clean_path}/metrics", "", "", "")
+         )
+
+     def _check_prefix_caching(self) -> bool:
+         """Check if prefix caching is enabled.
+
+         Returns
+         -------
+         bool
+             True if prefix caching is enabled, False otherwise
+         """
+         job_json = utils.read_slurm_log(
+             self.status_info.model_name,
+             self.slurm_job_id,
+             "json",
+             self.log_dir,
+         )
+         if isinstance(job_json, str):
+             return False
+         return bool(cast(dict[str, str], job_json).get("enable_prefix_caching", False))
+
+     def _parse_metrics(self, metrics_text: str) -> dict[str, float]:
+         """Parse metrics with latency count and sum.
+
+         Parameters
+         ----------
+         metrics_text : str
+             Raw metrics text from the server
+
+         Returns
+         -------
+         dict[str, float]
+             Parsed metrics as key-value pairs
+         """
+         key_metrics = KEY_METRICS
+
+         if self.enabled_prefix_caching:
+             key_metrics["vllm:gpu_prefix_cache_hit_rate"] = "gpu_prefix_cache_hit_rate"
+             key_metrics["vllm:cpu_prefix_cache_hit_rate"] = "cpu_prefix_cache_hit_rate"
+
+         parsed: dict[str, float] = {}
+         for line in metrics_text.split("\n"):
+             if line.startswith("#") or not line.strip():
+                 continue
+
+             parts = line.split()
+             if len(parts) < 2:
+                 continue
+
+             metric_name = parts[0].split("{")[0]
+             if metric_name in key_metrics:
+                 try:
+                     parsed[key_metrics[metric_name]] = float(parts[1])
+                 except (ValueError, IndexError):
+                     continue
+         return parsed
+
+     def fetch_metrics(self) -> Union[dict[str, float], str]:
+         """Fetch metrics from the endpoint.
+
+         Returns
+         -------
+         Union[dict[str, float], str]
+             Dictionary of metrics or error message if request fails
+         """
+         try:
+             response = requests.get(self.metrics_url, timeout=3)
+             response.raise_for_status()
+             current_metrics = self._parse_metrics(response.text)
+             current_time = time.time()
+
+             # Set defaults using last known throughputs
+             current_metrics.setdefault(
+                 "prompt_tokens_per_sec", self._last_throughputs["prompt"]
+             )
+             current_metrics.setdefault(
+                 "generation_tokens_per_sec", self._last_throughputs["generation"]
+             )
+
+             if self._last_updated is None:
+                 self._prev_prompt_tokens = current_metrics.get(
+                     "total_prompt_tokens", 0.0
+                 )
+                 self._prev_generation_tokens = current_metrics.get(
+                     "total_generation_tokens", 0.0
+                 )
+                 self._last_updated = current_time
+                 return current_metrics
+
+             time_diff = current_time - self._last_updated
+             if time_diff > 0:
+                 current_prompt = current_metrics.get("total_prompt_tokens", 0.0)
+                 current_gen = current_metrics.get("total_generation_tokens", 0.0)
+
+                 delta_prompt = current_prompt - self._prev_prompt_tokens
+                 delta_gen = current_gen - self._prev_generation_tokens
+
+                 # Only update throughputs when we have new tokens
+                 prompt_tps = (
+                     delta_prompt / time_diff
+                     if delta_prompt > 0
+                     else self._last_throughputs["prompt"]
+                 )
+                 gen_tps = (
+                     delta_gen / time_diff
+                     if delta_gen > 0
+                     else self._last_throughputs["generation"]
+                 )
+
+                 current_metrics["prompt_tokens_per_sec"] = prompt_tps
+                 current_metrics["generation_tokens_per_sec"] = gen_tps
+
+                 # Persist calculated values regardless of activity
+                 self._last_throughputs["prompt"] = prompt_tps
+                 self._last_throughputs["generation"] = gen_tps
+
+                 # Update tracking state
+                 self._prev_prompt_tokens = current_prompt
+                 self._prev_generation_tokens = current_gen
+                 self._last_updated = current_time
+
+             # Calculate average latency if data is available
+             if (
+                 "request_latency_sum" in current_metrics
+                 and "request_latency_count" in current_metrics
+             ):
+                 latency_sum = current_metrics["request_latency_sum"]
+                 latency_count = current_metrics["request_latency_count"]
+                 current_metrics["avg_request_latency"] = (
+                     latency_sum / latency_count if latency_count > 0 else 0.0
+                 )
+
+             return current_metrics
+
+         except requests.RequestException as e:
+             return f"Metrics request failed, `metrics` endpoint might not be ready yet: {str(e)}"
+
+
+ class ModelRegistry:
+     """Class for handling model listing and configuration management.
+
+     A class that provides functionality for listing available models and
+     managing their configurations.
+     """
+
+     def __init__(self) -> None:
+         """Initialize the model registry."""
+         self.model_configs = utils.load_config()
+
+     def get_all_models(self) -> list[ModelInfo]:
+         """Get all available models.
+
+         Returns
+         -------
+         list[ModelInfo]
+             List of information about all available models
+         """
+         available_models = []
+         for config in self.model_configs:
+             info = ModelInfo(
+                 name=config.model_name,
+                 family=config.model_family,
+                 variant=config.model_variant,
+                 model_type=ModelType(config.model_type),
+                 config=config.model_dump(exclude={"model_name", "venv", "log_dir"}),
+             )
+             available_models.append(info)
+         return available_models
+
+     def get_single_model_config(self, model_name: str) -> ModelConfig:
+         """Get configuration for a specific model.
+
+         Parameters
+         ----------
+         model_name : str
+             Name of the model to retrieve configuration for
+
+         Returns
+         -------
+         ModelConfig
+             Configuration for the specified model
+
+         Raises
+         ------
+         ModelNotFoundError
+             If the specified model is not found in configuration
+         """
+         config = next(
+             (c for c in self.model_configs if c.model_name == model_name), None
+         )
+         if not config:
+             raise ModelNotFoundError(f"Model '{model_name}' not found in configuration")
+         return config
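
For orientation, a minimal usage sketch of the four helper classes added in this file. It is illustrative only: the diff header does not show the file path, so the import below assumes a vec_inf.client._helpers module, and the model name and vLLM override are placeholder values rather than anything taken from the package.

# Hypothetical usage sketch; import path, model name, and kwargs are assumptions.
from vec_inf.client._helpers import (  # module path guessed; not shown in this diff
    ModelLauncher,
    ModelRegistry,
    ModelStatusMonitor,
    PerformanceMetricsCollector,
)
from vec_inf.client.models import ModelStatus

# Inspect the configured models.
registry = ModelRegistry()
for model in registry.get_all_models():
    print(model.name, model.model_type)

# Launch a model, overriding one vLLM argument (placeholder values).
launcher = ModelLauncher("Meta-Llama-3.1-8B", {"vllm_args": "--max-model-len=8192"})
launched = launcher.launch()

# Check job status and, once the server reports ready, pull throughput metrics.
monitor = ModelStatusMonitor(launched.slurm_job_id)
status = monitor.process_model_status()
if status.server_status == ModelStatus.READY:
    collector = PerformanceMetricsCollector(launched.slurm_job_id)
    metrics = collector.fetch_metrics()  # dict of floats, or an error string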