vec-inf 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vec_inf/client/_utils.py CHANGED
@@ -15,9 +15,10 @@ import requests
  import yaml

  from vec_inf.client._client_vars import MODEL_READY_SIGNATURE
+ from vec_inf.client._exceptions import MissingRequiredFieldsError
+ from vec_inf.client._slurm_vars import CACHED_CONFIG_DIR, REQUIRED_ARGS
  from vec_inf.client.config import ModelConfig
  from vec_inf.client.models import ModelStatus
- from vec_inf.client.slurm_vars import CACHED_CONFIG


  def run_bash_command(command: str) -> tuple[str, str]:
@@ -41,9 +42,9 @@ def run_bash_command(command: str) -> tuple[str, str]:

  def read_slurm_log(
      slurm_job_name: str,
-     slurm_job_id: int,
+     slurm_job_id: str,
      slurm_log_type: str,
-     log_dir: Optional[Union[str, Path]],
+     log_dir: str,
  ) -> Union[list[str], str, dict[str, str]]:
      """Read the slurm log file.

@@ -51,12 +52,12 @@
      ----------
      slurm_job_name : str
          Name of the SLURM job
-     slurm_job_id : int
+     slurm_job_id : str
          ID of the SLURM job
      slurm_log_type : str
          Type of log file to read ('out', 'err', or 'json')
-     log_dir : Optional[Union[str, Path]]
-         Directory containing log files, if None uses default location
+     log_dir : str
+         Directory containing log files

      Returns
      -------
@@ -66,31 +67,11 @@ def read_slurm_log(
          - dict[str, str] for 'json' logs
          - str for error messages if file not found
      """
-     if not log_dir:
-         # Default log directory
-         models_dir = Path.home() / ".vec-inf-logs"
-         # Iterate over all dirs in models_dir, sorted by dir name length in desc order
-         for directory in sorted(
-             [d for d in models_dir.iterdir() if d.is_dir()],
-             key=lambda d: len(d.name),
-             reverse=True,
-         ):
-             if directory.name in slurm_job_name:
-                 log_dir = directory
-                 break
-     else:
-         log_dir = Path(log_dir)
-
-     # If log_dir is still not set, then didn't find the log dir at default location
-     if not log_dir:
-         return "LOG DIR NOT FOUND"
-
      try:
-         file_path = (
-             log_dir
-             / Path(f"{slurm_job_name}.{slurm_job_id}")
-             / f"{slurm_job_name}.{slurm_job_id}.{slurm_log_type}"
-         )
+         if "+" in slurm_job_id:
+             main_job_id, het_job_id = slurm_job_id.split("+")
+             slurm_job_id = str(int(main_job_id) + int(het_job_id))
+         file_path = Path(log_dir, f"{slurm_job_name}.{slurm_job_id}.{slurm_log_type}")
          if slurm_log_type == "json":
              with file_path.open("r") as file:
                  json_content: dict[str, str] = json.load(file)
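The heterogeneous-job handling added above maps a composite SLURM job ID to the component ID used in the log file name. A minimal sketch of that mapping (the job ID value is made up for illustration):

```python
# Sketch of the het-job mapping introduced above; the ID is illustrative only.
slurm_job_id = "12345+1"  # SLURM heterogeneous-job component notation
if "+" in slurm_job_id:
    main_job_id, het_job_id = slurm_job_id.split("+")
    slurm_job_id = str(int(main_job_id) + int(het_job_id))
assert slurm_job_id == "12346"  # the ID that appears in <name>.<id>.<type> log files
```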
@@ -103,7 +84,7 @@ def read_slurm_log(


  def is_server_running(
-     slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str]
+     slurm_job_name: str, slurm_job_id: str, log_dir: str
  ) -> Union[str, ModelStatus, tuple[ModelStatus, str]]:
      """Check if a model is ready to serve requests.

@@ -111,9 +92,9 @@ def is_server_running(
      ----------
      slurm_job_name : str
          Name of the SLURM job
-     slurm_job_id : int
+     slurm_job_id : str
          ID of the SLURM job
-     log_dir : Optional[str]
+     log_dir : str
          Directory containing log files

      Returns
@@ -138,16 +119,16 @@ def is_server_running(
      return status


- def get_base_url(slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str]) -> str:
+ def get_base_url(slurm_job_name: str, slurm_job_id: str, log_dir: str) -> str:
      """Get the base URL of a model.

      Parameters
      ----------
      slurm_job_name : str
          Name of the SLURM job
-     slurm_job_id : int
+     slurm_job_id : str
          ID of the SLURM job
-     log_dir : Optional[str]
+     log_dir : str
          Directory containing log files

      Returns
@@ -164,7 +145,7 @@ def get_base_url(slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str])


  def model_health_check(
-     slurm_job_name: str, slurm_job_id: int, log_dir: Optional[str]
+     slurm_job_name: str, slurm_job_id: str, log_dir: str
  ) -> tuple[ModelStatus, Union[str, int]]:
      """Check the health of a running model on the cluster.

@@ -172,9 +153,9 @@ def model_health_check(
      ----------
      slurm_job_name : str
          Name of the SLURM job
-     slurm_job_id : int
+     slurm_job_id : str
          ID of the SLURM job
-     log_dir : Optional[str]
+     log_dir : str
          Directory containing log files

      Returns
@@ -199,12 +180,17 @@ def model_health_check(
          return (ModelStatus.FAILED, str(e))


- def load_config() -> list[ModelConfig]:
+ def load_config(config_path: Optional[str] = None) -> list[ModelConfig]:
      """Load the model configuration.

      Loads configuration from default and user-specified paths, merging them
      if both exist. User configuration takes precedence over default values.

+     Parameters
+     ----------
+     config_path : Optional[str]
+         Path to the configuration file
+
      Returns
      -------
      list[ModelConfig]
@@ -213,44 +199,80 @@ def load_config() -> list[ModelConfig]:
      Notes
      -----
      Configuration is loaded from:
-     1. Default path: package's config/models.yaml
-     2. User path: specified by VEC_INF_CONFIG environment variable
+     1. User path: specified by config_path
+     2. Default path: package's config/models.yaml or CACHED_CONFIG if it exists
+     3. Environment variable: specified by VEC_INF_CONFIG environment variable
+        and merged with default config

      If user configuration exists, it will be merged with default configuration,
      with user values taking precedence for overlapping fields.
      """
+
+     def load_yaml_config(path: Path) -> dict[str, Any]:
+         """Load YAML config with error handling."""
+         try:
+             with path.open() as f:
+                 return yaml.safe_load(f) or {}
+         except FileNotFoundError as err:
+             raise FileNotFoundError(f"Could not find config: {path}") from err
+         except yaml.YAMLError as err:
+             raise ValueError(f"Error parsing YAML config at {path}: {err}") from err
+
+     def process_config(config: dict[str, Any]) -> list[ModelConfig]:
+         """Process the config based on the config type."""
+         return [
+             ModelConfig(model_name=name, **model_data)
+             for name, model_data in config.get("models", {}).items()
+         ]
+
+     def resolve_config_path_from_env_var() -> Path | None:
+         """Resolve the config path from the environment variable."""
+         config_dir = os.getenv("VEC_INF_CONFIG_DIR")
+         config_path = os.getenv("VEC_INF_MODEL_CONFIG")
+         if config_path:
+             return Path(config_path)
+         if config_dir:
+             return Path(config_dir, "models.yaml")
+         return None
+
+     def update_config(
+         config: dict[str, Any], user_config: dict[str, Any]
+     ) -> dict[str, Any]:
+         """Update the config with the user config."""
+         for name, data in user_config.get("models", {}).items():
+             if name in config.get("models", {}):
+                 config["models"][name].update(data)
+             else:
+                 config.setdefault("models", {})[name] = data
+
+         return config
+
+     # 1. If config_path is given, use only that
+     if config_path:
+         config = load_yaml_config(Path(config_path))
+         return process_config(config)
+
+     # 2. Otherwise, load default config
      default_path = (
-         CACHED_CONFIG
-         if CACHED_CONFIG.exists()
+         CACHED_CONFIG_DIR / "models.yaml"
+         if CACHED_CONFIG_DIR.exists()
          else Path(__file__).resolve().parent.parent / "config" / "models.yaml"
      )
+     config = load_yaml_config(default_path)
+
+     # 3. If user config exists, merge it
+     user_path = resolve_config_path_from_env_var()
+     if user_path and user_path.exists():
+         user_config = load_yaml_config(user_path)
+         config = update_config(config, user_config)
+     elif user_path:
+         warnings.warn(
+             f"WARNING: Could not find user config: {str(user_path)}, revert to default config located at {default_path}",
+             UserWarning,
+             stacklevel=2,
+         )

-     config: dict[str, Any] = {}
-     with open(default_path) as f:
-         config = yaml.safe_load(f) or {}
-
-     user_path = os.getenv("VEC_INF_CONFIG")
-     if user_path:
-         user_path_obj = Path(user_path)
-         if user_path_obj.exists():
-             with open(user_path_obj) as f:
-                 user_config = yaml.safe_load(f) or {}
-             for name, data in user_config.get("models", {}).items():
-                 if name in config.get("models", {}):
-                     config["models"][name].update(data)
-                 else:
-                     config.setdefault("models", {})[name] = data
-         else:
-             warnings.warn(
-                 f"WARNING: Could not find user config: {user_path}, revert to default config located at {default_path}",
-                 UserWarning,
-                 stacklevel=2,
-             )
-
-     return [
-         ModelConfig(model_name=name, **model_data)
-         for name, model_data in config.get("models", {}).items()
-     ]
+     return process_config(config)


  def parse_launch_output(output: str) -> tuple[str, dict[str, str]]:
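A minimal usage sketch of the reworked load_config above, based only on the signature and environment variables shown in this diff; the paths are placeholders:

```python
import os

from vec_inf.client._utils import load_config

# 1. An explicit path wins outright and is not merged with the defaults.
models = load_config(config_path="/path/to/models.yaml")

# 2. Without an argument, the default config (CACHED_CONFIG_DIR/models.yaml or the
#    packaged config/models.yaml) is loaded and, if set, merged with the file named
#    by VEC_INF_MODEL_CONFIG or with VEC_INF_CONFIG_DIR/models.yaml.
os.environ["VEC_INF_MODEL_CONFIG"] = "/path/to/overrides.yaml"
models = load_config()
print([m.model_name for m in models])
```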
@@ -285,3 +307,100 @@ def parse_launch_output(output: str) -> tuple[str, dict[str, str]]:
          config_dict[key.lower().replace(" ", "_")] = value

      return slurm_job_id, config_dict
+
+
+ def is_power_of_two(n: int) -> bool:
+     """Check if a number is a power of two.
+
+     Parameters
+     ----------
+     n : int
+         The number to check
+     """
+     return n > 0 and (n & (n - 1)) == 0
+
+
+ def find_matching_dirs(
+     log_dir: Path,
+     model_family: Optional[str] = None,
+     model_name: Optional[str] = None,
+     job_id: Optional[int] = None,
+     before_job_id: Optional[int] = None,
+ ) -> list[Path]:
+     """
+     Find log directories based on filtering criteria.
+
+     Parameters
+     ----------
+     log_dir : Path
+         The base directory containing model family directories.
+     model_family : str, optional
+         Filter to only search inside this family.
+     model_name : str, optional
+         Filter to only match model names.
+     job_id : int, optional
+         Filter to only match this exact SLURM job ID.
+     before_job_id : int, optional
+         Filter to only include job IDs less than this value.
+
+     Returns
+     -------
+     list[Path]
+         List of directories that match the criteria and can be deleted.
+     """
+     matched = []
+
+     if not log_dir.exists() or not log_dir.is_dir():
+         raise FileNotFoundError(f"Log directory does not exist: {log_dir}")
+
+     if not model_family and not model_name and not job_id and not before_job_id:
+         return [log_dir]
+
+     for family_dir in log_dir.iterdir():
+         if not family_dir.is_dir():
+             continue
+         if model_family and family_dir.name != model_family:
+             continue
+
+         if model_family and not model_name and not job_id and not before_job_id:
+             return [family_dir]
+
+         for job_dir in family_dir.iterdir():
+             if not job_dir.is_dir():
+                 continue
+
+             try:
+                 name_part, id_part = job_dir.name.rsplit(".", 1)
+                 parsed_id = int(id_part)
+             except ValueError:
+                 continue
+
+             if model_name and name_part != model_name:
+                 continue
+             if job_id is not None and parsed_id != job_id:
+                 continue
+             if before_job_id is not None and parsed_id >= before_job_id:
+                 continue
+
+             matched.append(job_dir)
+
+     return matched
+
+
+ def check_required_fields(params: dict[str, Any]) -> None:
+     """Check for required fields without default vals and their corresponding env vars.
+
+     Parameters
+     ----------
+     params : dict[str, Any]
+         Dictionary of parameters to check.
+     """
+     for arg in REQUIRED_ARGS:
+         if not params.get(arg):
+             default_value = os.getenv(REQUIRED_ARGS[arg])
+             if default_value:
+                 params[arg] = default_value
+             else:
+                 raise MissingRequiredFieldsError(
+                     f"{arg} is required, please set it in the command arguments or environment variables"
+                 )
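A usage sketch for the new find_matching_dirs helper, assuming the ~/.vec-inf-logs/<model_family>/<model_name>.<job_id> directory layout implied by the parsing above; the family name, model name, and cutoff are placeholders:

```python
from pathlib import Path

from vec_inf.client._utils import find_matching_dirs

# Collect every job directory for one model whose SLURM job ID is below a cutoff.
stale_dirs = find_matching_dirs(
    log_dir=Path.home() / ".vec-inf-logs",
    model_family="example-family",  # placeholder family directory name
    model_name="example-model",     # matched against <model_name>.<job_id> dirs
    before_job_id=1000000,          # placeholder job-ID cutoff
)
for job_dir in stale_dirs:
    print(job_dir)
```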
vec_inf/client/api.py CHANGED
@@ -10,8 +10,10 @@ vec_inf.client._helper : Helper classes for model inference server management
  vec_inf.client.models : Data models for API responses
  """

+ import shutil
  import time
  import warnings
+ from pathlib import Path
  from typing import Any, Optional, Union

  from vec_inf.client._exceptions import (
@@ -19,14 +21,16 @@ from vec_inf.client._exceptions import (
      SlurmJobError,
  )
  from vec_inf.client._helper import (
+     BatchModelLauncher,
      ModelLauncher,
      ModelRegistry,
      ModelStatusMonitor,
      PerformanceMetricsCollector,
  )
- from vec_inf.client._utils import run_bash_command
+ from vec_inf.client._utils import find_matching_dirs, run_bash_command
  from vec_inf.client.config import ModelConfig
  from vec_inf.client.models import (
+     BatchLaunchResponse,
      LaunchOptions,
      LaunchResponse,
      MetricsResponse,
@@ -60,6 +64,9 @@ class VecInfClient:
      wait_until_ready(slurm_job_id, timeout_seconds, poll_interval_seconds, log_dir)
          Wait for a model to become ready

+     cleanup_logs(log_dir, model_name, model_family, job_id, dry_run)
+         Remove logs from the log directory.
+
      Examples
      --------
      >>> from vec_inf.api import VecInfClient
@@ -145,17 +152,42 @@ class VecInfClient:
          model_launcher = ModelLauncher(model_name, options_dict)
          return model_launcher.launch()

-     def get_status(
-         self, slurm_job_id: int, log_dir: Optional[str] = None
-     ) -> StatusResponse:
+     def batch_launch_models(
+         self,
+         model_names: list[str],
+         batch_config: Optional[str] = None,
+         account: Optional[str] = None,
+         work_dir: Optional[str] = None,
+     ) -> BatchLaunchResponse:
+         """Launch multiple models on the cluster.
+
+         Parameters
+         ----------
+         model_names : list[str]
+             List of model names to launch
+
+         Returns
+         -------
+         BatchLaunchResponse
+             Response containing launch details for each model
+
+         Raises
+         ------
+         ModelConfigurationError
+             If the model configuration is invalid
+         """
+         model_launcher = BatchModelLauncher(
+             model_names, batch_config, account, work_dir
+         )
+         return model_launcher.launch()
+
+     def get_status(self, slurm_job_id: str) -> StatusResponse:
          """Get the status of a running model.

          Parameters
          ----------
-         slurm_job_id : int
+         slurm_job_id : str
              The SLURM job ID to check
-         log_dir : str, optional
-             Path to the SLURM log directory. If None, uses default location

          Returns
          -------
@@ -167,20 +199,16 @@ class VecInfClient:
              - Base URL (if ready)
              - Error information (if failed)
          """
-         model_status_monitor = ModelStatusMonitor(slurm_job_id, log_dir)
+         model_status_monitor = ModelStatusMonitor(slurm_job_id)
          return model_status_monitor.process_model_status()

-     def get_metrics(
-         self, slurm_job_id: int, log_dir: Optional[str] = None
-     ) -> MetricsResponse:
+     def get_metrics(self, slurm_job_id: str) -> MetricsResponse:
          """Get the performance metrics of a running model.

          Parameters
          ----------
-         slurm_job_id : int
+         slurm_job_id : str
              The SLURM job ID to get metrics for
-         log_dir : str, optional
-             Path to the SLURM log directory. If None, uses default location

          Returns
          -------
@@ -190,9 +218,7 @@ class VecInfClient:
              - Performance metrics or error message
              - Timestamp of collection
          """
-         performance_metrics_collector = PerformanceMetricsCollector(
-             slurm_job_id, log_dir
-         )
+         performance_metrics_collector = PerformanceMetricsCollector(slurm_job_id)

          metrics: Union[dict[str, float], str]
          if not performance_metrics_collector.metrics_url.startswith("http"):
@@ -206,12 +232,12 @@ class VecInfClient:
              timestamp=time.time(),
          )

-     def shutdown_model(self, slurm_job_id: int) -> bool:
+     def shutdown_model(self, slurm_job_id: str) -> bool:
          """Shutdown a running model.

          Parameters
          ----------
-         slurm_job_id : int
+         slurm_job_id : str
              The SLURM job ID to shut down

          Returns
@@ -232,23 +258,20 @@ class VecInfClient:

      def wait_until_ready(
          self,
-         slurm_job_id: int,
+         slurm_job_id: str,
          timeout_seconds: int = 1800,
          poll_interval_seconds: int = 10,
-         log_dir: Optional[str] = None,
      ) -> StatusResponse:
          """Wait until a model is ready or fails.

          Parameters
          ----------
-         slurm_job_id : int
+         slurm_job_id : str
              The SLURM job ID to wait for
          timeout_seconds : int, optional
              Maximum time to wait in seconds, by default 1800 (30 mins)
          poll_interval_seconds : int, optional
              How often to check status in seconds, by default 10
-         log_dir : str, optional
-             Path to the SLURM log directory. If None, uses default location

          Returns
          -------
@@ -273,7 +296,7 @@ class VecInfClient:
          start_time = time.time()

          while True:
-             status_info = self.get_status(slurm_job_id, log_dir)
+             status_info = self.get_status(slurm_job_id)

              if status_info.server_status == ModelStatus.READY:
                  return status_info
@@ -300,3 +323,51 @@ class VecInfClient:

              # Wait before checking again
              time.sleep(poll_interval_seconds)
+
+     def cleanup_logs(
+         self,
+         log_dir: Optional[Union[str, Path]] = None,
+         model_family: Optional[str] = None,
+         model_name: Optional[str] = None,
+         job_id: Optional[int] = None,
+         before_job_id: Optional[int] = None,
+         dry_run: bool = False,
+     ) -> list[Path]:
+         """Remove logs from the log directory.
+
+         Parameters
+         ----------
+         log_dir : str or Path, optional
+             Root directory containing log files. Defaults to ~/.vec-inf-logs.
+         model_family : str, optional
+             Only delete logs for this model family.
+         model_name : str, optional
+             Only delete logs for this model name.
+         job_id : int, optional
+             If provided, only match directories with this exact SLURM job ID.
+         before_job_id : int, optional
+             If provided, only delete logs with job ID less than this value.
+         dry_run : bool
+             If True, return matching files without deleting them.
+
+         Returns
+         -------
+         list[Path]
+             List of deleted (or matched if dry_run) log file paths.
+         """
+         log_root = Path(log_dir) if log_dir else Path.home() / ".vec-inf-logs"
+         matched = find_matching_dirs(
+             log_dir=log_root,
+             model_family=model_family,
+             model_name=model_name,
+             job_id=job_id,
+             before_job_id=before_job_id,
+         )
+
+         if dry_run:
+             return matched
+
+         for path in matched:
+             shutil.rmtree(path)
+
+         return matched
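A client-level sketch tying the new and changed methods together, using the import path from the docstring example in this file; the model names, account, job ID, and cutoff below are placeholders:

```python
from vec_inf.api import VecInfClient

client = VecInfClient()

# Launch several models at once (model names and account are placeholders).
batch = client.batch_launch_models(["model-a", "model-b"], account="my-account")

# Status, metrics, and wait_until_ready now take a string job ID and no log_dir.
status = client.get_status("12345678")

# Preview which log directories would be removed, then actually delete them.
would_remove = client.cleanup_logs(
    model_family="example-family", before_job_id=1000000, dry_run=True
)
client.cleanup_logs(model_family="example-family", before_job_id=1000000)
```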