vec-inf 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,661 @@
+ """Helper classes for the model.
+
+ This module provides utility classes for managing model deployment, status monitoring,
+ metrics collection, and model registry operations.
+ """
+
+ import json
+ import time
+ import warnings
+ from pathlib import Path
+ from typing import Any, Optional, Union, cast
+ from urllib.parse import urlparse, urlunparse
+
+ import requests
+
+ import vec_inf.client._utils as utils
+ from vec_inf.client._client_vars import (
+     KEY_METRICS,
+     REQUIRED_FIELDS,
+     SRC_DIR,
+     VLLM_SHORT_TO_LONG_MAP,
+ )
+ from vec_inf.client._exceptions import (
+     MissingRequiredFieldsError,
+     ModelConfigurationError,
+     ModelNotFoundError,
+     SlurmJobError,
+ )
+ from vec_inf.client._slurm_script_generator import SlurmScriptGenerator
+ from vec_inf.client.config import ModelConfig
+ from vec_inf.client.models import (
+     LaunchResponse,
+     ModelInfo,
+     ModelStatus,
+     ModelType,
+     StatusResponse,
+ )
+
+
+ class ModelLauncher:
+     """Helper class for handling inference server launch.
+
+     A class that manages the launch process of inference servers, including
+     configuration validation, parameter preparation, and SLURM job submission.
+
+     Parameters
+     ----------
+     model_name : str
+         Name of the model to launch
+     kwargs : dict[str, Any], optional
+         Optional launch keyword arguments to override default configuration
+     """
+
+     def __init__(self, model_name: str, kwargs: Optional[dict[str, Any]] = None):
+         """Initialize the model launcher.
+
+         Parameters
+         ----------
+         model_name: str
+             Name of the model to launch
+         kwargs: Optional[dict[str, Any]]
+             Optional launch keyword arguments to override default configuration
+         """
+         self.model_name = model_name
+         self.kwargs = kwargs or {}
+         self.slurm_job_id = ""
+         self.slurm_script_path = Path("")
+         self.model_config = self._get_model_configuration()
+         self.params = self._get_launch_params()
+
+     def _warn(self, message: str) -> None:
+         """Warn the user about a potential issue.
+
+         Parameters
+         ----------
+         message : str
+             Warning message to display
+         """
+         warnings.warn(message, UserWarning, stacklevel=2)
+
+     def _get_model_configuration(self) -> ModelConfig:
+         """Load and validate model configuration.
+
+         Returns
+         -------
+         ModelConfig
+             Validated configuration for the model
+
+         Raises
+         ------
+         ModelNotFoundError
+             If model weights parent directory cannot be determined
+         ModelConfigurationError
+             If model configuration is not found and weights don't exist
+         """
+         model_configs = utils.load_config()
+         config = next(
+             (m for m in model_configs if m.model_name == self.model_name), None
+         )
+
+         if config:
+             return config
+
+         # If model config not found, check for path from CLI kwargs or use fallback
+         model_weights_parent_dir = self.kwargs.get(
+             "model_weights_parent_dir",
+             model_configs[0].model_weights_parent_dir if model_configs else None,
+         )
+
+         if not model_weights_parent_dir:
+             raise ModelNotFoundError(
+                 "Could not determine model weights parent directory"
+             )
+
+         model_weights_path = Path(model_weights_parent_dir, self.model_name)
+
+         # Only give a warning if weights exist but config missing
+         if model_weights_path.exists():
+             self._warn(
+                 f"Warning: '{self.model_name}' configuration not found in config, please ensure the model configuration is properly set in command arguments",
+             )
+             # Return a dummy model config object with model name and weights parent dir
+             return ModelConfig(
+                 model_name=self.model_name,
+                 model_family="model_family_placeholder",
+                 model_type="LLM",
+                 gpus_per_node=1,
+                 num_nodes=1,
+                 vocab_size=1000,
+                 model_weights_parent_dir=Path(str(model_weights_parent_dir)),
+             )
+
+         raise ModelConfigurationError(
+             f"'{self.model_name}' not found in configuration and model weights "
+             f"not found at expected path '{model_weights_path}'"
+         )
+
+     def _process_vllm_args(self, arg_string: str) -> dict[str, Any]:
+         """Process the vllm_args string into a dictionary.
+
+         Parameters
+         ----------
+         arg_string : str
+             Comma-separated string of vLLM arguments
+
+         Returns
+         -------
+         dict[str, Any]
+             Processed vLLM arguments as key-value pairs
+         """
+         vllm_args: dict[str, str | bool] = {}
+         for arg in arg_string.split(","):
+             if "=" in arg:
+                 key, value = arg.split("=", 1)
+                 if key.strip() in VLLM_SHORT_TO_LONG_MAP:
+                     key = VLLM_SHORT_TO_LONG_MAP[key.strip()]
+                 vllm_args[key.strip()] = value.strip()
+             elif "-O" in arg.strip():
+                 key = VLLM_SHORT_TO_LONG_MAP["-O"]
+                 vllm_args[key] = arg.strip()[2:].strip()
+             else:
+                 vllm_args[arg.strip()] = True
+         return vllm_args
+
+     def _get_launch_params(self) -> dict[str, Any]:
+         """Prepare launch parameters, set log dir, and validate required fields.
+
+         Returns
+         -------
+         dict[str, Any]
+             Dictionary of prepared launch parameters
+
+         Raises
+         ------
+         MissingRequiredFieldsError
+             If required fields are missing or tensor parallel size is not specified
+             when using multiple GPUs
+         """
+         params = self.model_config.model_dump(exclude_none=True)
+
+         # Override config defaults with CLI arguments
+         if self.kwargs.get("vllm_args"):
+             vllm_args = self._process_vllm_args(self.kwargs["vllm_args"])
+             for key, value in vllm_args.items():
+                 params["vllm_args"][key] = value
+             del self.kwargs["vllm_args"]
+
+         for key, value in self.kwargs.items():
+             params[key] = value
+
+         # Validate required fields and vllm args
+         if not REQUIRED_FIELDS.issubset(set(params.keys())):
+             raise MissingRequiredFieldsError(
+                 f"Missing required fields: {REQUIRED_FIELDS - set(params.keys())}"
+             )
+         if (
+             int(params["gpus_per_node"]) > 1
+             and params["vllm_args"].get("--tensor-parallel-size") is None
+         ):
+             raise MissingRequiredFieldsError(
+                 "--tensor-parallel-size is required when gpus_per_node > 1"
+             )
+
+         # Create log directory
+         params["log_dir"] = Path(params["log_dir"], params["model_family"]).expanduser()
+         params["log_dir"].mkdir(parents=True, exist_ok=True)
+         params["src_dir"] = SRC_DIR
+
+         # Construct slurm log file paths
+         params["out_file"] = (
+             f"{params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.out"
+         )
+         params["err_file"] = (
+             f"{params['log_dir']}/{self.model_name}.%j/{self.model_name}.%j.err"
+         )
+         params["json_file"] = (
+             f"{params['log_dir']}/{self.model_name}.$SLURM_JOB_ID/{self.model_name}.$SLURM_JOB_ID.json"
+         )
+
+         # Convert path to string for JSON serialization
+         for field in params:
+             if field == "vllm_args":
+                 continue
+             params[field] = str(params[field])
+
+         return params
+
+     def _build_launch_command(self) -> str:
+         """Generate the slurm script and construct the launch command.
+
+         Returns
+         -------
+         str
+             Complete SLURM launch command
+         """
+         self.slurm_script_path = SlurmScriptGenerator(self.params).write_to_log_dir()
+         return f"sbatch {self.slurm_script_path}"
+
+     def launch(self) -> LaunchResponse:
+         """Launch the model.
+
+         Returns
+         -------
+         LaunchResponse
+             Response object containing launch details and status
+
+         Raises
+         ------
+         SlurmJobError
+             If SLURM job submission fails
+         """
+         # Build and execute the launch command
+         command_output, stderr = utils.run_bash_command(self._build_launch_command())
+
+         if stderr:
+             raise SlurmJobError(f"Error: {stderr}")
+
+         # Extract slurm job id from command output
+         self.slurm_job_id = command_output.split(" ")[-1].strip()
+         self.params["slurm_job_id"] = self.slurm_job_id
+
+         # Create log directory and job json file, move slurm script to job log directory
+         job_log_dir = Path(
+             self.params["log_dir"], f"{self.model_name}.{self.slurm_job_id}"
+         )
+         job_log_dir.mkdir(parents=True, exist_ok=True)
+
+         job_json = Path(
+             job_log_dir,
+             f"{self.model_name}.{self.slurm_job_id}.json",
+         )
+         job_json.touch(exist_ok=True)
+
+         self.slurm_script_path.rename(
+             job_log_dir / f"{self.model_name}.{self.slurm_job_id}.slurm"
+         )
+
+         with job_json.open("w") as file:
+             json.dump(self.params, file, indent=4)
+
+         return LaunchResponse(
+             slurm_job_id=int(self.slurm_job_id),
+             model_name=self.model_name,
+             config=self.params,
+             raw_output=command_output,
+         )
+
+
+ class ModelStatusMonitor:
+     """Class for handling server status information and monitoring.
+
+     A class that monitors and reports the status of deployed model servers,
+     including job state and health checks.
+
+     Parameters
+     ----------
+     slurm_job_id : int
+         ID of the SLURM job to monitor
+     log_dir : str, optional
+         Base directory containing log files
+     """
+
+     def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None):
+         self.slurm_job_id = slurm_job_id
+         self.output = self._get_raw_status_output()
+         self.log_dir = log_dir
+         self.status_info = self._get_base_status_data()
+
+     def _get_raw_status_output(self) -> str:
+         """Get the raw server status output from slurm.
+
+         Returns
+         -------
+         str
+             Raw output from scontrol command
+
+         Raises
+         ------
+         SlurmJobError
+             If status check fails
+         """
+         status_cmd = f"scontrol show job {self.slurm_job_id} --oneliner"
+         output, stderr = utils.run_bash_command(status_cmd)
+         if stderr:
+             raise SlurmJobError(f"Error: {stderr}")
+         return output
+
+     def _get_base_status_data(self) -> StatusResponse:
+         """Extract basic job status information from scontrol output.
+
+         Returns
+         -------
+         StatusResponse
+             Basic status information for the job
+         """
+         try:
+             job_name = self.output.split(" ")[1].split("=")[1]
+             job_state = self.output.split(" ")[9].split("=")[1]
+         except IndexError:
+             job_name = "UNAVAILABLE"
+             job_state = ModelStatus.UNAVAILABLE
+
+         return StatusResponse(
+             model_name=job_name,
+             server_status=ModelStatus.UNAVAILABLE,
+             job_state=job_state,
+             raw_output=self.output,
+             base_url="UNAVAILABLE",
+             pending_reason=None,
+             failed_reason=None,
+         )
+
+     def _check_model_health(self) -> None:
+         """Check model health and update status accordingly."""
+         status, status_code = utils.model_health_check(
+             self.status_info.model_name, self.slurm_job_id, self.log_dir
+         )
+         if status == ModelStatus.READY:
+             self.status_info.base_url = utils.get_base_url(
+                 self.status_info.model_name,
+                 self.slurm_job_id,
+                 self.log_dir,
+             )
+             self.status_info.server_status = status
+         else:
+             self.status_info.server_status = status
+             self.status_info.failed_reason = cast(str, status_code)
+
+     def _process_running_state(self) -> None:
+         """Process RUNNING job state and check server status."""
+         server_status = utils.is_server_running(
+             self.status_info.model_name, self.slurm_job_id, self.log_dir
+         )
+
+         if isinstance(server_status, tuple):
+             self.status_info.server_status, self.status_info.failed_reason = (
+                 server_status
+             )
+             return
+
+         if server_status == "RUNNING":
+             self._check_model_health()
+         else:
+             self.status_info.server_status = cast(ModelStatus, server_status)
+
+     def _process_pending_state(self) -> None:
+         """Process PENDING job state and update status information."""
+         try:
+             self.status_info.pending_reason = self.output.split(" ")[10].split("=")[1]
+             self.status_info.server_status = ModelStatus.PENDING
+         except IndexError:
+             self.status_info.pending_reason = "Unknown pending reason"
+
+     def process_model_status(self) -> StatusResponse:
+         """Process different job states and update status information.
+
+         Returns
+         -------
+         StatusResponse
+             Complete status information for the model
+         """
+         if self.status_info.job_state == ModelStatus.PENDING:
+             self._process_pending_state()
+         elif self.status_info.job_state == "RUNNING":
+             self._process_running_state()
+
+         return self.status_info
+
+
+ class PerformanceMetricsCollector:
+     """Class for handling metrics collection and processing.
+
+     A class that collects and processes performance metrics from running model servers,
+     including throughput and latency measurements.
+
+     Parameters
+     ----------
+     slurm_job_id : int
+         ID of the SLURM job to collect metrics from
+     log_dir : str, optional
+         Directory containing log files
+     """
+
+     def __init__(self, slurm_job_id: int, log_dir: Optional[str] = None):
+         self.slurm_job_id = slurm_job_id
+         self.log_dir = log_dir
+         self.status_info = self._get_status_info()
+         self.metrics_url = self._build_metrics_url()
+         self.enabled_prefix_caching = self._check_prefix_caching()
+
+         self._prev_prompt_tokens: float = 0.0
+         self._prev_generation_tokens: float = 0.0
+         self._last_updated: Optional[float] = None
+         self._last_throughputs = {"prompt": 0.0, "generation": 0.0}
+
+     def _get_status_info(self) -> StatusResponse:
+         """Retrieve status info using ModelStatusMonitor.
+
+         Returns
+         -------
+         StatusResponse
+             Current status information for the model
+         """
+         status_helper = ModelStatusMonitor(self.slurm_job_id, self.log_dir)
+         return status_helper.process_model_status()
+
+     def _build_metrics_url(self) -> str:
+         """Construct metrics endpoint URL from base URL with version stripping.
+
+         Returns
+         -------
+         str
+             Complete metrics endpoint URL or status message
+         """
+         if self.status_info.job_state == ModelStatus.PENDING:
+             return "Pending resources for server initialization"
+
+         base_url = utils.get_base_url(
+             self.status_info.model_name,
+             self.slurm_job_id,
+             self.log_dir,
+         )
+         if not base_url.startswith("http"):
+             return "Server not ready"
+
+         parsed = urlparse(base_url)
+         clean_path = parsed.path.replace("/v1", "", 1).rstrip("/")
+         return urlunparse(
+             (parsed.scheme, parsed.netloc, f"{clean_path}/metrics", "", "", "")
+         )
+
+     def _check_prefix_caching(self) -> bool:
+         """Check if prefix caching is enabled.
+
+         Returns
+         -------
+         bool
+             True if prefix caching is enabled, False otherwise
+         """
+         job_json = utils.read_slurm_log(
+             self.status_info.model_name,
+             self.slurm_job_id,
+             "json",
+             self.log_dir,
+         )
+         if isinstance(job_json, str):
+             return False
+         return bool(cast(dict[str, str], job_json).get("enable_prefix_caching", False))
+
+     def _parse_metrics(self, metrics_text: str) -> dict[str, float]:
+         """Parse metrics with latency count and sum.
+
+         Parameters
+         ----------
+         metrics_text : str
+             Raw metrics text from the server
+
+         Returns
+         -------
+         dict[str, float]
+             Parsed metrics as key-value pairs
+         """
+         key_metrics = dict(KEY_METRICS)
+
+         if self.enabled_prefix_caching:
+             key_metrics["vllm:gpu_prefix_cache_hit_rate"] = "gpu_prefix_cache_hit_rate"
+             key_metrics["vllm:cpu_prefix_cache_hit_rate"] = "cpu_prefix_cache_hit_rate"
+
+         parsed: dict[str, float] = {}
+         for line in metrics_text.split("\n"):
+             if line.startswith("#") or not line.strip():
+                 continue
+
+             parts = line.split()
+             if len(parts) < 2:
+                 continue
+
+             metric_name = parts[0].split("{")[0]
+             if metric_name in key_metrics:
+                 try:
+                     parsed[key_metrics[metric_name]] = float(parts[1])
+                 except (ValueError, IndexError):
+                     continue
+         return parsed
+
+     def fetch_metrics(self) -> Union[dict[str, float], str]:
+         """Fetch metrics from the endpoint.
+
+         Returns
+         -------
+         Union[dict[str, float], str]
+             Dictionary of metrics or error message if request fails
+         """
+         try:
+             response = requests.get(self.metrics_url, timeout=3)
+             response.raise_for_status()
+             current_metrics = self._parse_metrics(response.text)
+             current_time = time.time()
+
+             # Set defaults using last known throughputs
+             current_metrics.setdefault(
+                 "prompt_tokens_per_sec", self._last_throughputs["prompt"]
+             )
+             current_metrics.setdefault(
+                 "generation_tokens_per_sec", self._last_throughputs["generation"]
+             )
+
+             if self._last_updated is None:
+                 self._prev_prompt_tokens = current_metrics.get(
+                     "total_prompt_tokens", 0.0
+                 )
+                 self._prev_generation_tokens = current_metrics.get(
+                     "total_generation_tokens", 0.0
+                 )
+                 self._last_updated = current_time
+                 return current_metrics
+
+             time_diff = current_time - self._last_updated
+             if time_diff > 0:
+                 current_prompt = current_metrics.get("total_prompt_tokens", 0.0)
+                 current_gen = current_metrics.get("total_generation_tokens", 0.0)
+
+                 delta_prompt = current_prompt - self._prev_prompt_tokens
+                 delta_gen = current_gen - self._prev_generation_tokens
+
+                 # Only update throughputs when we have new tokens
+                 prompt_tps = (
+                     delta_prompt / time_diff
+                     if delta_prompt > 0
+                     else self._last_throughputs["prompt"]
+                 )
+                 gen_tps = (
+                     delta_gen / time_diff
+                     if delta_gen > 0
+                     else self._last_throughputs["generation"]
+                 )
+
+                 current_metrics["prompt_tokens_per_sec"] = prompt_tps
+                 current_metrics["generation_tokens_per_sec"] = gen_tps
+
+                 # Persist calculated values regardless of activity
+                 self._last_throughputs["prompt"] = prompt_tps
+                 self._last_throughputs["generation"] = gen_tps
+
+                 # Update tracking state
+                 self._prev_prompt_tokens = current_prompt
+                 self._prev_generation_tokens = current_gen
+                 self._last_updated = current_time
+
+             # Calculate average latency if data is available
+             if (
+                 "request_latency_sum" in current_metrics
+                 and "request_latency_count" in current_metrics
+             ):
+                 latency_sum = current_metrics["request_latency_sum"]
+                 latency_count = current_metrics["request_latency_count"]
+                 current_metrics["avg_request_latency"] = (
+                     latency_sum / latency_count if latency_count > 0 else 0.0
+                 )
+
+             return current_metrics
+
+         except requests.RequestException as e:
+             return f"Metrics request failed, `metrics` endpoint might not be ready yet: {str(e)}"
+
+
+ class ModelRegistry:
+     """Class for handling model listing and configuration management.
+
+     A class that provides functionality for listing available models and
+     managing their configurations.
+     """
+
+     def __init__(self) -> None:
+         """Initialize the model registry."""
+         self.model_configs = utils.load_config()
+
+     def get_all_models(self) -> list[ModelInfo]:
+         """Get all available models.
+
+         Returns
+         -------
+         list[ModelInfo]
+             List of information about all available models
+         """
+         available_models = []
+         for config in self.model_configs:
+             info = ModelInfo(
+                 name=config.model_name,
+                 family=config.model_family,
+                 variant=config.model_variant,
+                 model_type=ModelType(config.model_type),
+                 config=config.model_dump(exclude={"model_name", "venv", "log_dir"}),
+             )
+             available_models.append(info)
+         return available_models
+
+     def get_single_model_config(self, model_name: str) -> ModelConfig:
+         """Get configuration for a specific model.
+
+         Parameters
+         ----------
+         model_name : str
+             Name of the model to retrieve configuration for
+
+         Returns
+         -------
+         ModelConfig
+             Configuration for the specified model
+
+         Raises
+         ------
+         ModelNotFoundError
+             If the specified model is not found in configuration
+         """
+         config = next(
+             (c for c in self.model_configs if c.model_name == model_name), None
+         )
+         if not config:
+             raise ModelNotFoundError(f"Model '{model_name}' not found in configuration")
+         return config
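
For orientation, the sketch below shows how the four helper classes added in this file fit together. It is illustrative only: the diff does not include the file header, so the import path vec_inf.client._helpers is an assumption, and "example-model" is a placeholder model name; only calls that appear in the diff above are used.

    # Illustrative sketch only -- module path and model name are placeholders.
    from vec_inf.client._helpers import (  # assumed module path
        ModelLauncher,
        ModelRegistry,
        ModelStatusMonitor,
        PerformanceMetricsCollector,
    )
    from vec_inf.client.models import ModelStatus

    # Inspect the models known to the configuration.
    registry = ModelRegistry()
    for model in registry.get_all_models():
        print(model.name, model.model_type)

    # Submit a SLURM job for one model, overriding defaults via kwargs.
    launcher = ModelLauncher("example-model", {"vllm_args": "--tensor-parallel-size=2"})
    launch = launcher.launch()

    # Check job/server state; collect metrics once the server reports READY.
    status = ModelStatusMonitor(launch.slurm_job_id).process_model_status()
    if status.server_status == ModelStatus.READY:
        print(PerformanceMetricsCollector(launch.slurm_job_id).fetch_metrics())

Note that these classes shell out to SLURM (sbatch, scontrol), so the sketch only runs on a cluster where those commands and the configured model weights are available.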