gpustack-runtime 0.1.39.post2__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. gpustack_runtime/__main__.py +7 -3
  2. gpustack_runtime/_version.py +2 -2
  3. gpustack_runtime/_version_appendix.py +1 -1
  4. gpustack_runtime/cmds/__init__.py +2 -0
  5. gpustack_runtime/cmds/deployer.py +84 -2
  6. gpustack_runtime/cmds/images.py +2 -0
  7. gpustack_runtime/deployer/__init__.py +2 -0
  8. gpustack_runtime/deployer/__types__.py +52 -28
  9. gpustack_runtime/deployer/__utils__.py +99 -112
  10. gpustack_runtime/deployer/cdi/__init__.py +81 -0
  11. gpustack_runtime/deployer/cdi/__types__.py +667 -0
  12. gpustack_runtime/deployer/cdi/thead.py +103 -0
  13. gpustack_runtime/deployer/docker.py +42 -24
  14. gpustack_runtime/deployer/kuberentes.py +8 -4
  15. gpustack_runtime/deployer/podman.py +41 -23
  16. gpustack_runtime/detector/__init__.py +62 -3
  17. gpustack_runtime/detector/__types__.py +11 -0
  18. gpustack_runtime/detector/__utils__.py +23 -0
  19. gpustack_runtime/detector/amd.py +17 -9
  20. gpustack_runtime/detector/hygon.py +6 -1
  21. gpustack_runtime/detector/iluvatar.py +20 -5
  22. gpustack_runtime/detector/mthreads.py +8 -12
  23. gpustack_runtime/detector/nvidia.py +365 -168
  24. gpustack_runtime/detector/pyacl/__init__.py +9 -1
  25. gpustack_runtime/detector/pyamdgpu/__init__.py +8 -0
  26. gpustack_runtime/detector/pycuda/__init__.py +9 -1
  27. gpustack_runtime/detector/pydcmi/__init__.py +9 -2
  28. gpustack_runtime/detector/pyhgml/__init__.py +5879 -0
  29. gpustack_runtime/detector/pyhgml/libhgml.so +0 -0
  30. gpustack_runtime/detector/pyhgml/libuki.so +0 -0
  31. gpustack_runtime/detector/pyhsa/__init__.py +9 -0
  32. gpustack_runtime/detector/pyixml/__init__.py +89 -164
  33. gpustack_runtime/detector/pyrocmcore/__init__.py +42 -24
  34. gpustack_runtime/detector/pyrocmsmi/__init__.py +141 -138
  35. gpustack_runtime/detector/thead.py +733 -0
  36. gpustack_runtime/envs.py +128 -55
  37. {gpustack_runtime-0.1.39.post2.dist-info → gpustack_runtime-0.1.40.dist-info}/METADATA +4 -2
  38. gpustack_runtime-0.1.40.dist-info/RECORD +55 -0
  39. gpustack_runtime/detector/pymtml/__init__.py +0 -770
  40. gpustack_runtime-0.1.39.post2.dist-info/RECORD +0 -49
  41. {gpustack_runtime-0.1.39.post2.dist-info → gpustack_runtime-0.1.40.dist-info}/WHEEL +0 -0
  42. {gpustack_runtime-0.1.39.post2.dist-info → gpustack_runtime-0.1.40.dist-info}/entry_points.txt +0 -0
  43. {gpustack_runtime-0.1.39.post2.dist-info → gpustack_runtime-0.1.40.dist-info}/licenses/LICENSE +0 -0
@@ -7,10 +7,12 @@ import time
7
7
  from argparse import ArgumentParser
8
8
 
9
9
  import argcomplete
10
+ from gpustack_runner.cmds import LoadImagesSubCommand
10
11
 
11
12
  from . import deployer, detector
12
13
  from ._version import commit_id, version
13
14
  from .cmds import (
15
+ CDIGenerateSubCommand,
14
16
  CopyImagesSubCommand,
15
17
  CreateWorkloadSubCommand,
16
18
  DeleteWorkloadsSubCommand,
@@ -71,12 +73,14 @@ def main():
71
73
  InspectWorkloadSubCommand.register(subcommand_parser)
72
74
  DetectDevicesSubCommand.register(subcommand_parser)
73
75
  GetDevicesTopologySubCommand.register(subcommand_parser)
74
- ListImagesSubCommand.register(subcommand_parser)
75
- SaveImagesSubCommand.register(subcommand_parser)
76
- CopyImagesSubCommand.register(subcommand_parser)
77
76
  LogsSelfSubCommand.register(subcommand_parser)
78
77
  ExecSelfSubCommand.register(subcommand_parser)
79
78
  InspectSelfSubCommand.register(subcommand_parser)
79
+ CDIGenerateSubCommand.register(subcommand_parser)
80
+ ListImagesSubCommand.register(subcommand_parser)
81
+ SaveImagesSubCommand.register(subcommand_parser)
82
+ LoadImagesSubCommand.register(subcommand_parser)
83
+ CopyImagesSubCommand.register(subcommand_parser)
80
84
 
81
85
  # Autocomplete
82
86
  argcomplete.autocomplete(parser)
@@ -27,8 +27,8 @@ version_tuple: VERSION_TUPLE
27
27
  __commit_id__: COMMIT_ID
28
28
  commit_id: COMMIT_ID
29
29
 
30
- __version__ = version = '0.1.39.post2'
31
- __version_tuple__ = version_tuple = (0, 1, 39, 'post2')
30
+ __version__ = version = '0.1.40'
31
+ __version_tuple__ = version_tuple = (0, 1, 40)
32
32
  try:
33
33
  from ._version_appendix import git_commit
34
34
  __commit_id__ = commit_id = git_commit
@@ -1 +1 @@
1
- git_commit = "e044bab"
1
+ git_commit = "1f4627e"
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from .deployer import (
4
+ CDIGenerateSubCommand,
4
5
  CreateWorkloadSubCommand,
5
6
  DeleteWorkloadsSubCommand,
6
7
  DeleteWorkloadSubCommand,
@@ -24,6 +25,7 @@ from .images import (
24
25
  )
25
26
 
26
27
  __all__ = [
28
+ "CDIGenerateSubCommand",
27
29
  "CopyImagesSubCommand",
28
30
  "CreateWorkloadSubCommand",
29
31
  "DeleteWorkloadSubCommand",
@@ -5,7 +5,7 @@ import json
5
5
  import os
6
6
  import sys
7
7
  import time
8
- from argparse import REMAINDER
8
+ from argparse import OPTIONAL, REMAINDER
9
9
  from pathlib import Path
10
10
  from typing import TYPE_CHECKING
11
11
 
@@ -25,6 +25,7 @@ from ..deployer import (
25
25
  WorkloadStatusStateEnum,
26
26
  async_logs_self,
27
27
  async_logs_workload,
28
+ cdi_generate_config,
28
29
  create_workload,
29
30
  delete_workload,
30
31
  exec_self,
@@ -34,7 +35,7 @@ from ..deployer import (
34
35
  inspect_workload,
35
36
  list_workloads,
36
37
  )
37
- from ..detector import supported_backends
38
+ from ..detector import supported_backends, supported_manufacturers
38
39
  from .__types__ import SubCommand
39
40
 
40
41
  if TYPE_CHECKING:
@@ -92,6 +93,7 @@ class CreateWorkloadSubCommand(SubCommand):
92
93
  command_script: str | None
93
94
  port: int
94
95
  host_network: bool
96
+ privileged: bool
95
97
  check: bool
96
98
  namespace: str
97
99
  name: str
@@ -133,11 +135,19 @@ class CreateWorkloadSubCommand(SubCommand):
133
135
 
134
136
  deploy_parser.add_argument(
135
137
  "--host-network",
138
+ "--network-host",
136
139
  action="store_true",
137
140
  help="Use host network (default: False)",
138
141
  default=False,
139
142
  )
140
143
 
144
+ deploy_parser.add_argument(
145
+ "--privileged",
146
+ action="store_true",
147
+ help="Run the container in privileged mode (default: False)",
148
+ default=False,
149
+ )
150
+
141
151
  deploy_parser.add_argument(
142
152
  "--check",
143
153
  action="store_true",
@@ -183,6 +193,7 @@ class CreateWorkloadSubCommand(SubCommand):
183
193
  self.command_script = None
184
194
  self.port = args.port
185
195
  self.host_network = args.host_network
196
+ self.privileged = args.privileged
186
197
  self.check = args.check
187
198
  self.namespace = args.namespace
188
199
  self.name = args.name
@@ -237,6 +248,7 @@ class CreateWorkloadSubCommand(SubCommand):
237
248
  execution = ContainerExecution(
238
249
  command_script=self.command_script,
239
250
  args=self.extra_args,
251
+ privileged=self.privileged,
240
252
  )
241
253
  ports = (
242
254
  [
@@ -945,6 +957,76 @@ class InspectSelfSubCommand(SubCommand):
945
957
  print(inspect_self())
946
958
 
947
959
 
960
class CDIGenerateSubCommand(SubCommand):
    """
    Command to generate CDI (Container Device Interface) configurations.

    One configuration is requested per manufacturer returned by
    ``supported_manufacturers()``; each non-empty result is printed, together
    with the path it was saved to when the generator reports one.
    """

    # Output format selected via ``--format`` ("yaml" or "json").
    format: str
    # Directory to save configurations into, or None when not given.
    output: Path | None

    @staticmethod
    def register(parser: _SubParsersAction):
        """
        Register the ``cdi-generate`` (alias ``cdi-gen``) subcommand.

        Args:
            parser:
                The subparsers action to attach the subcommand to.

        """
        cdi_parser = parser.add_parser(
            "cdi-generate",
            help="Generate CDI configurations according to the current environment",
            aliases=["cdi-gen"],
        )

        cdi_parser.add_argument(
            "--format",
            type=str,
            choices=["yaml", "json"],
            default="yaml",
            help="Format of the CDI configurations",
        )

        cdi_parser.add_argument(
            "output",
            nargs=OPTIONAL,
            help="Output directory to save CDI configurations (default: current directory)",
        )

        cdi_parser.set_defaults(func=CDIGenerateSubCommand)

    def __init__(self, args: Namespace):
        """
        Capture parsed arguments and prepare the output directory.

        Args:
            args:
                Parsed command-line arguments (``format`` and ``output``).

        Raises:
            RuntimeError:
                If the output directory cannot be created, or if the output
                path exists but is not a directory.

        """
        self.format = args.format
        self.output = Path(args.output) if args.output else None

        if self.output:
            try:
                if not self.output.exists():
                    self.output.mkdir(parents=True, exist_ok=True)
            except OSError as e:
                msg = f"Failed to prepare output directory '{self.output}' for CDI configurations"
                raise RuntimeError(msg) from e

            # An existing regular file at the output path is rejected here.
            if not self.output.is_dir():
                msg = f"The output path '{self.output}' is not a directory"
                raise RuntimeError(msg)

    def run(self):
        """
        Generate and print CDI configurations for every supported manufacturer.

        Prints a notice when no manufacturer produced any configuration.
        """
        # Only clear the screen on an interactive terminal; otherwise the raw
        # escape sequence would be written into redirected/piped output.
        if sys.stdout.isatty():
            print("\033[2J\033[H", end="")

        # NOTE(review): self.format is parsed but not forwarded to
        # cdi_generate_config here; confirm whether the generator should
        # honour the requested format.
        generated = False
        for manu in supported_manufacturers():
            content, path = cdi_generate_config(
                manufacturer=manu,
                output=self.output,
            )
            if not content:
                continue

            generated = True
            if path:
                print(f"Generated CDI configuration for '{manu}' at {path}:\n")
            else:
                print(f"Generated CDI configuration for '{manu}':\n")
            print(content)
            print()

        if not generated:
            print("No CDI configurations were generated.")
1028
+
1029
+
948
1030
  def format_workloads_json(sts: list[WorkloadStatus]) -> str:
949
1031
  return json.dumps([st.to_dict() for st in sts], indent=2)
950
1032
 
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  from gpustack_runner.cmds import (
4
4
  CopyImagesSubCommand,
5
5
  ListImagesSubCommand,
6
+ LoadImagesSubCommand,
6
7
  PlatformedImage,
7
8
  SaveImagesSubCommand,
8
9
  append_images,
@@ -20,6 +21,7 @@ append_images(
20
21
  __all__ = [
21
22
  "CopyImagesSubCommand",
22
23
  "ListImagesSubCommand",
24
+ "LoadImagesSubCommand",
23
25
  "PlatformedImage",
24
26
  "SaveImagesSubCommand",
25
27
  "append_images",
@@ -31,6 +31,7 @@ from .__types__ import (
31
31
  WorkloadStatus,
32
32
  WorkloadStatusStateEnum,
33
33
  )
34
+ from .cdi import generate_config as cdi_generate_config
34
35
  from .docker import (
35
36
  DockerDeployer,
36
37
  DockerWorkloadPlan,
@@ -602,6 +603,7 @@ __all__ = [
602
603
  "WorkloadStatusStateEnum",
603
604
  "async_logs_self",
604
605
  "async_logs_workload",
606
+ "cdi_generate_config",
605
607
  "create_workload",
606
608
  "delete_workload",
607
609
  "exec_self",
@@ -14,6 +14,7 @@ from dataclasses_json import dataclass_json
14
14
 
15
15
  from .. import envs
16
16
  from ..detector import (
17
+ ManufacturerEnum,
17
18
  Topology,
18
19
  detect_devices,
19
20
  get_devices_topologies,
@@ -21,6 +22,7 @@ from ..detector import (
21
22
  manufacturer_to_backend,
22
23
  )
23
24
  from .__utils__ import (
25
+ adjust_image_with_envs,
24
26
  correct_runner_image,
25
27
  fnv1a_32_hex,
26
28
  fnv1a_64_hex,
@@ -1018,20 +1020,8 @@ class WorkloadPlan(WorkloadSecurity):
1018
1020
  c.files.append(command_script)
1019
1021
  c.execution.command = [command_script_name] # Override command.
1020
1022
  c.execution.command_script = None
1021
- # Add default registry if needed.
1022
- if (
1023
- envs.GPUSTACK_RUNTIME_DEPLOY_DEFAULT_CONTAINER_REGISTRY
1024
- and envs.GPUSTACK_RUNTIME_DEPLOY_DEFAULT_CONTAINER_REGISTRY
1025
- not in ["docker.io", "index.docker.io"]
1026
- ):
1027
- image_registry = (
1028
- envs.GPUSTACK_RUNTIME_DEPLOY_DEFAULT_CONTAINER_NAMESPACE
1029
- )
1030
- image_split = c.image.split("/")
1031
- if len(image_split) == 1:
1032
- c.image = f"{image_registry}/library/{c.image}"
1033
- elif len(image_split) == 2:
1034
- c.image = f"{image_registry}/{c.image}"
1023
+ # Adjust images.
1024
+ c.image = adjust_image_with_envs(c.image)
1035
1025
  # Correct runner image if needed.
1036
1026
  if envs.GPUSTACK_RUNTIME_DEPLOY_CORRECT_RUNNER_IMAGE:
1037
1027
  c.image, ok = correct_runner_image(c.image)
@@ -1279,6 +1269,17 @@ class Deployer(ABC):
1279
1269
  """
1280
1270
  Thread pool for the deployer.
1281
1271
  """
1272
+ _visible_devices_manufacturers: dict[str, ManufacturerEnum] | None = None
1273
+ """
1274
+ Recorded visible devices manufacturers,
1275
+ the key is the runtime visible devices env name,
1276
+ the value is the corresponding manufacturer.
1277
+ For example:
1278
+ {
1279
+ "NVIDIA_VISIBLE_DEVICES": ManufacturerEnum.NVIDIA,
1280
+ "AMD_VISIBLE_DEVICES": ManufacturerEnum.AMD
1281
+ }.
1282
+ """
1282
1283
  _visible_devices_env: dict[str, list[str]] | None = None
1283
1284
  """
1284
1285
  Recorded visible devices envs,
@@ -1358,14 +1359,16 @@ class Deployer(ABC):
1358
1359
 
1359
1360
  def _prepare(self):
1360
1361
  """
1361
- Detect devices once, and construct critical elements for post processing, including:
1362
+ Detect devices once, and construct critical elements for post-processing, including:
1363
+ - Prepare visible devices manufacturers mapping.
1362
1364
  - Prepare visible devices environment variables mapping.
1363
1365
  - Prepare visible devices values mapping.
1364
- - Prepare topology.
1366
+ - Prepare visible devices topologies mapping.
1365
1367
  """
1366
- if self._visible_devices_env:
1368
+ if self._visible_devices_manufacturers is not None:
1367
1369
  return
1368
1370
 
1371
+ self._visible_devices_manufacturers = {}
1369
1372
  self._visible_devices_env = {}
1370
1373
  self._visible_devices_cdis = {}
1371
1374
  self._visible_devices_values = {}
@@ -1379,15 +1382,19 @@ class Deployer(ABC):
1379
1382
  if group_devices:
1380
1383
  for manu, devs in group_devices.items():
1381
1384
  backend = manufacturer_to_backend(manu)
1382
- rk = envs.GPUSTACK_RUNTIME_DETECT_BACKEND_MAP_RESOURCE_KEY.get(backend)
1385
+ resource_key = (
1386
+ envs.GPUSTACK_RUNTIME_DETECT_BACKEND_MAP_RESOURCE_KEY.get(backend)
1387
+ )
1388
+ if resource_key is None:
1389
+ continue
1383
1390
  ren = envs.GPUSTACK_RUNTIME_DEPLOY_RESOURCE_KEY_MAP_RUNTIME_VISIBLE_DEVICES.get(
1384
- rk,
1391
+ resource_key,
1385
1392
  )
1386
1393
  ben_list = envs.GPUSTACK_RUNTIME_DEPLOY_RESOURCE_KEY_MAP_BACKEND_VISIBLE_DEVICES.get(
1387
- rk,
1394
+ resource_key,
1388
1395
  )
1389
- cdi = envs.GPUSTACK_RUNTIME_DEPLOY_RESOURCE_KEY_MAP_CONTAINER_DEVICE_INTERFACES.get(
1390
- rk,
1396
+ cdi = envs.GPUSTACK_RUNTIME_DEPLOY_RESOURCE_KEY_MAP_CDI.get(
1397
+ resource_key,
1391
1398
  )
1392
1399
  if ren and ben_list:
1393
1400
  valued_uuid = (
@@ -1401,6 +1408,8 @@ class Deployer(ABC):
1401
1408
  dev_uuids.append(dev.uuid)
1402
1409
  dev_indexes.append(str(dev.index))
1403
1410
  dev_indexes_alignment[str(dev.index)] = str(dev_i)
1411
+ # Map runtime visible devices env <-> manufacturer.
1412
+ self._visible_devices_manufacturers[ren] = manu
1404
1413
  # Map runtime visible devices env <-> backend visible devices env list.
1405
1414
  self._visible_devices_env[ren] = ben_list
1406
1415
  # Map runtime visible devices env <-> CDI key.
@@ -1433,16 +1442,28 @@ class Deployer(ABC):
1433
1442
  return
1434
1443
 
1435
1444
  # Fallback to unknown backend
1436
- self._visible_devices_env["UNKNOWN_RUNTIME_VISIBLE_DEVICES"] = []
1437
- self._visible_devices_values["UNKNOWN_RUNTIME_VISIBLE_DEVICES"] = ["all"]
1445
+ ren = "UNKNOWN_RUNTIME_VISIBLE_DEVICES"
1446
+ self._visible_devices_manufacturers[ren] = ManufacturerEnum.UNKNOWN
1447
+ self._visible_devices_env[ren] = []
1448
+ self._visible_devices_cdis[ren] = "unknown/devices"
1449
+ self._visible_devices_values[ren] = ["all"]
1438
1450
 
1439
- def get_visible_devices_values(
1451
+ def get_visible_devices_materials(
1440
1452
  self,
1441
- ) -> (dict[str, list[str]], dict[str, str], dict[str, list[str]]):
1453
+ ) -> (
1454
+ dict[str, ManufacturerEnum],
1455
+ dict[str, list[str]],
1456
+ dict[str, str],
1457
+ dict[str, list[str]],
1458
+ ):
1442
1459
  """
1443
1460
  Return the visible devices environment variables, cdis and values mappings.
1444
1461
  For example:
1445
1462
  (
1463
+ {
1464
+ "NVIDIA_VISIBLE_DEVICES": ManufacturerEnum.NVIDIA,
1465
+ "AMD_VISIBLE_DEVICES": ManufacturerEnum.AMD
1466
+ },
1446
1467
  {
1447
1468
  "NVIDIA_VISIBLE_DEVICES": ["CUDA_VISIBLE_DEVICES"],
1448
1469
  "AMD_VISIBLE_DEVICES": ["HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES"]
@@ -1458,10 +1479,12 @@ class Deployer(ABC):
1458
1479
  ).
1459
1480
 
1460
1481
  Returns:
1461
- A tuple of two dictionaries:
1482
+ A tuple of four dictionaries:
1462
1483
  - The first dictionary maps runtime visible devices environment variable names
1463
- to lists of backend visible devices environment variable names.
1484
+ to corresponding manufacturers.
1464
1485
  - The second dictionary maps runtime visible devices environment variable names
1486
+ to lists of backend visible devices environment variable names.
1487
+ - The third dictionary maps runtime visible devices environment variable names
1465
1488
  to corresponding CDI keys.
1466
1489
  - The last dictionary maps runtime visible devices environment variable names
1467
1490
  to lists of device indexes or UUIDs.
@@ -1469,6 +1492,7 @@ class Deployer(ABC):
1469
1492
  """
1470
1493
  self._prepare()
1471
1494
  return (
1495
+ self._visible_devices_manufacturers,
1472
1496
  self._visible_devices_env,
1473
1497
  self._visible_devices_cdis,
1474
1498
  self._visible_devices_values,
@@ -6,11 +6,18 @@ import json
6
6
  import platform
7
7
  import re
8
8
  from functools import lru_cache
9
+ from pathlib import Path
9
10
  from typing import Any
10
11
 
11
12
  import yaml
12
- from gpustack_runner import DockerImage, list_backend_runners
13
+ from gpustack_runner import (
14
+ DockerImage,
15
+ list_backend_runners,
16
+ parse_image,
17
+ replace_image_with,
18
+ )
13
19
 
20
+ from .. import envs
14
21
  from ..detector import backend_to_manufacturer, detect_backend, detect_devices
15
22
  from ..detector.ascend import get_ascend_cann_variant
16
23
 
@@ -338,6 +345,49 @@ def safe_yaml(obj: Any, **kwargs) -> str:
338
345
  return yaml.dump(dict_data, **kwargs)
339
346
 
340
347
 
348
def load_yaml_or_json(path: str | Path) -> list[dict] | dict:
    """
    Load a YAML or JSON file into Python objects.

    The format is chosen from the file suffix, compared case-insensitively:
    ``.yaml``/``.yml`` files are parsed as (possibly multi-document) YAML,
    and ``.json`` files are parsed as JSON.

    Args:
        path:
            The path to the file to load, e.g. a CDI configuration file.

    Returns:
        The loaded object. For YAML, a single document is returned as-is,
        while zero or several documents are returned as a list.

    Raises:
        FileNotFoundError:
            If the file does not exist.
        RuntimeError:
            If the suffix is unsupported or the content cannot be parsed.

    """
    if isinstance(path, str):
        path = Path(path)

    if not path.exists():
        msg = f"File not found: {path}"
        raise FileNotFoundError(msg)

    # Validate the suffix before reading, so unsupported (possibly binary)
    # files get a clear error instead of failing during UTF-8 decoding.
    suffix = path.suffix.lower()
    if suffix not in {".yaml", ".yml", ".json"}:
        msg = f"Unsupported file format: {path.suffix}"
        raise RuntimeError(msg)

    content = path.read_text(encoding="utf-8")

    if suffix == ".json":
        try:
            return json.loads(content)
        except json.JSONDecodeError as e:
            msg = f"Failed to parse JSON file: {path}"
            raise RuntimeError(msg) from e

    try:
        documents = list(yaml.safe_load_all(content))
    except yaml.YAMLError as e:
        msg = f"Failed to parse YAML file: {path}"
        raise RuntimeError(msg) from e

    # Unwrap the common single-document case.
    if len(documents) == 1:
        return documents[0]
    return documents
389
+
390
+
341
391
  @lru_cache
342
392
  def compare_versions(v1: str | None, v2: str | None) -> int:
343
393
  """
@@ -672,140 +722,77 @@ def bytes_to_human_readable(size_in_bytes: int) -> str:
672
722
  return f"{size_in_bytes} B"
673
723
 
674
724
 
675
- def replace_image_with(
676
- image: str,
677
- registry: str | None = None,
678
- namespace: str | None = None,
679
- repository: str | None = None,
680
- ) -> str:
725
+ def sensitive_env_var(name: str) -> bool:
681
726
  """
682
- Replace the registry, namespace, and repository of a Docker image string.
683
-
684
- The given image string is parsed into its components (registry, namespace, repository, tag),
685
- and the specified components are replaced with the provided values.
686
-
687
- The format of a Docker image string is:
688
- [registry/][namespace/]repository[:tag|@digest]
727
+ Check if the given environment variable name is considered sensitive.
689
728
 
690
729
  Args:
691
- image:
692
- The original Docker image string.
693
- registry:
694
- The new registry to use. If None, keep the original registry.
695
- namespace:
696
- The new namespace to use. If None, keep the original namespace.
697
- repository:
698
- The new repository to use. If None, keep the original repository.
730
+ name:
731
+ The environment variable name to check.
699
732
 
700
733
  Returns:
701
- The modified Docker image string.
734
+ True if the name is considered sensitive, False otherwise.
702
735
 
703
736
  """
704
- if not image or (not registry and not namespace and not repository):
705
- return image
706
-
707
- registry = registry.strip() if registry else None
708
- namespace = namespace.strip() if namespace else None
709
- repository = repository.strip() if repository else None
710
-
711
- image_reg, image_ns, image_repo, image_tag = (
712
- None,
713
- None,
714
- None,
715
- None,
716
- )
717
- image_rest = image.strip()
718
-
719
- # Get tag.
720
- parts = image_rest.rsplit("@", maxsplit=1)
721
- if len(parts) == 2:
722
- image_rest, image_tag = parts
723
- else:
724
- parts = image_rest.rsplit(":", maxsplit=1)
725
- if len(parts) == 2 and "/" not in parts[1]:
726
- image_rest, image_tag = parts
727
- if not image_rest:
728
- return image
729
-
730
- # Get repository.
731
- parts = image_rest.rsplit("/", maxsplit=1)
732
- if len(parts) == 2:
733
- image_rest, image_repo = parts
734
- else:
735
- image_rest, image_repo = None, image_rest
736
-
737
- # Get namespace.
738
- if image_rest:
739
- parts = image_rest.rsplit("/", maxsplit=1)
740
- if len(parts) == 2:
741
- image_reg, image_ns = parts
742
- else:
743
- image_reg, image_ns = None, image_rest
744
-
745
- return make_image_with(
746
- repository=repository or image_repo,
747
- registry=registry or image_reg,
748
- namespace=namespace or image_ns,
749
- tag=image_tag,
750
- )
737
+ return name.lower().endswith(_SENSITIVE_ENVS_SUFFIX)
751
738
 
752
739
 
753
- def make_image_with(
754
- repository: str,
755
- registry: str | None = None,
756
- namespace: str | None = None,
757
- tag: str | None = None,
758
- ) -> str:
740
+ def isexception(
741
+ e: Exception,
742
+ target_exception_types: type[Exception] | tuple[type[Exception], ...],
743
+ ) -> bool:
759
744
  """
760
- Make a Docker image string from the given registry, namespace, repository, and tag.
761
-
762
- The format of a Docker image string is:
763
- [registry/][namespace/]repository[:tag|@digest]
745
+ Check if the given exception is caused by any of the target exception types.
764
746
 
765
747
  Args:
766
- repository:
767
- The repository name.
768
- registry:
769
- The registry to use. If None, no registry will be used.
770
- namespace:
771
- The namespace to use. If None, no namespace will be used.
772
- tag:
773
- The tag to use. If None, no tag will be used.
748
+ e:
749
+ The exception to check.
750
+ target_exception_types:
751
+ A tuple of target exception types.
774
752
 
775
753
  Returns:
776
- The Docker image string.
754
+ True if the exception is caused by any of the target exception types, False otherwise.
777
755
 
778
756
  """
779
- if not repository or (not registry and not namespace and not tag):
780
- return repository
757
+ if isinstance(e, target_exception_types):
758
+ return True
781
759
 
782
- image = ""
783
- if registry:
784
- image += f"{registry}/"
785
- if namespace:
786
- image += f"{namespace}/"
787
- elif registry:
788
- image += "library/"
789
- image += repository
790
- if not tag:
791
- return image
792
- if tag.startswith("sha256:"):
793
- image += f"@{tag}"
794
- else:
795
- image += f":{tag}"
796
- return image
760
+ cause = getattr(e, "__cause__", None)
761
+ if cause and isexception(cause, target_exception_types):
762
+ return True
797
763
 
764
+ context = getattr(e, "__context__", None)
765
+ return bool(context and isexception(context, target_exception_types))
798
766
 
799
- def sensitive_env_var(name: str) -> bool:
767
+
768
+ @lru_cache
769
+ def adjust_image_with_envs(image: str) -> str:
800
770
  """
801
- Check if the given environment variable name is considered sensitive.
771
+ Replace the registry and namespace of the given image
772
+ with the default ones from environment variables if applicable.
802
773
 
803
774
  Args:
804
- name:
805
- The environment variable name to check.
775
+ image:
776
+ The image to replace.
806
777
 
807
778
  Returns:
808
- True if the name is considered sensitive, False otherwise.
779
+ The replaced image.
809
780
 
810
781
  """
811
- return name.lower().endswith(_SENSITIVE_ENVS_SUFFIX)
782
+ original_reg, original_ns, _, _ = parse_image(image)
783
+
784
+ target_reg = original_reg or envs.GPUSTACK_RUNTIME_DEPLOY_DEFAULT_CONTAINER_REGISTRY
785
+ target_ns = (
786
+ original_ns
787
+ if original_ns != "gpustack"
788
+ else envs.GPUSTACK_RUNTIME_DEPLOY_DEFAULT_CONTAINER_NAMESPACE
789
+ )
790
+
791
+ if original_reg == target_reg and original_ns == target_ns:
792
+ return image
793
+
794
+ return replace_image_with(
795
+ image,
796
+ registry=target_reg,
797
+ namespace=target_ns,
798
+ )