gpustack-runtime 0.1.39.post2__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. gpustack_runtime/__main__.py +7 -3
  2. gpustack_runtime/_version.py +2 -2
  3. gpustack_runtime/_version_appendix.py +1 -1
  4. gpustack_runtime/cmds/__init__.py +2 -0
  5. gpustack_runtime/cmds/deployer.py +84 -2
  6. gpustack_runtime/cmds/images.py +2 -0
  7. gpustack_runtime/deployer/__init__.py +2 -0
  8. gpustack_runtime/deployer/__types__.py +52 -28
  9. gpustack_runtime/deployer/__utils__.py +99 -112
  10. gpustack_runtime/deployer/cdi/__init__.py +81 -0
  11. gpustack_runtime/deployer/cdi/__types__.py +667 -0
  12. gpustack_runtime/deployer/cdi/thead.py +103 -0
  13. gpustack_runtime/deployer/docker.py +42 -24
  14. gpustack_runtime/deployer/kuberentes.py +8 -4
  15. gpustack_runtime/deployer/podman.py +41 -23
  16. gpustack_runtime/detector/__init__.py +62 -3
  17. gpustack_runtime/detector/__types__.py +11 -0
  18. gpustack_runtime/detector/__utils__.py +23 -0
  19. gpustack_runtime/detector/amd.py +17 -9
  20. gpustack_runtime/detector/hygon.py +6 -1
  21. gpustack_runtime/detector/iluvatar.py +20 -5
  22. gpustack_runtime/detector/mthreads.py +8 -12
  23. gpustack_runtime/detector/nvidia.py +365 -168
  24. gpustack_runtime/detector/pyacl/__init__.py +9 -1
  25. gpustack_runtime/detector/pyamdgpu/__init__.py +8 -0
  26. gpustack_runtime/detector/pycuda/__init__.py +9 -1
  27. gpustack_runtime/detector/pydcmi/__init__.py +9 -2
  28. gpustack_runtime/detector/pyhgml/__init__.py +5879 -0
  29. gpustack_runtime/detector/pyhgml/libhgml.so +0 -0
  30. gpustack_runtime/detector/pyhgml/libuki.so +0 -0
  31. gpustack_runtime/detector/pyhsa/__init__.py +9 -0
  32. gpustack_runtime/detector/pyixml/__init__.py +89 -164
  33. gpustack_runtime/detector/pyrocmcore/__init__.py +42 -24
  34. gpustack_runtime/detector/pyrocmsmi/__init__.py +141 -138
  35. gpustack_runtime/detector/thead.py +733 -0
  36. gpustack_runtime/envs.py +128 -55
  37. {gpustack_runtime-0.1.39.post2.dist-info → gpustack_runtime-0.1.40.dist-info}/METADATA +4 -2
  38. gpustack_runtime-0.1.40.dist-info/RECORD +55 -0
  39. gpustack_runtime/detector/pymtml/__init__.py +0 -770
  40. gpustack_runtime-0.1.39.post2.dist-info/RECORD +0 -49
  41. {gpustack_runtime-0.1.39.post2.dist-info → gpustack_runtime-0.1.40.dist-info}/WHEEL +0 -0
  42. {gpustack_runtime-0.1.39.post2.dist-info → gpustack_runtime-0.1.40.dist-info}/entry_points.txt +0 -0
  43. {gpustack_runtime-0.1.39.post2.dist-info → gpustack_runtime-0.1.40.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,103 @@
1
+ from __future__ import annotations as __future_annotations__
2
+
3
+ from ...detector import (
4
+ Devices,
5
+ ManufacturerEnum,
6
+ detect_devices,
7
+ filter_devices_by_manufacturer,
8
+ )
9
+ from .__types__ import (
10
+ Config,
11
+ ConfigContainerEdits,
12
+ ConfigDevice,
13
+ Generator,
14
+ manufacturer_to_config_kind,
15
+ )
16
+
17
+
18
class THeadGenerator(Generator):
    """
    CDI generator for T-Head devices.
    """

    def __init__(self):
        super().__init__(ManufacturerEnum.THEAD)

    def generate(self, devices: Devices | None = None) -> Config | None:
        """
        Generate the CDI configuration for T-Head devices.

        Each device is exposed under its index and, when it reports one,
        under its UUID as well. An additional "all" device grants access to
        every device node collected.

        Args:
            devices: The detected devices.
                If None, all available devices are considered.
                Otherwise, only the devices of this generator's
                manufacturer are kept.

        Returns:
            The Config object, or None if not supported
            (no matching devices, or no config kind for the manufacturer).

        """
        if devices is None:
            devices = detect_devices(manufacturer=self.manufacturer)
        else:
            devices = filter_devices_by_manufacturer(
                devices,
                manufacturer=self.manufacturer,
            )

        if not devices:
            return None

        kind = manufacturer_to_config_kind(self.manufacturer)
        if not kind:
            return None

        # Device nodes every container needs regardless of which device is
        # requested. Kept in one place so the per-device and "all" edits
        # cannot drift apart.
        common_device_nodes = [
            "/dev/alixpu",
            "/dev/alixpu_ctl",
        ]

        cdi_devices: list[ConfigDevice] = []
        all_container_edits_device_nodes = list(common_device_nodes)

        for dev in devices:
            if not dev:
                continue

            dev_node = f"/dev/alixpu_ppu{dev.index}"
            all_container_edits_device_nodes.append(dev_node)

            # Add specific container edits for each device
            cdi_container_edits = ConfigContainerEdits(
                device_nodes=[*common_device_nodes, dev_node],
            )

            # Always addressable by index; addressable by UUID only when the
            # device reports one, so that no entry with an empty name is
            # emitted into the configuration.
            names = [str(dev.index)]
            if dev.uuid:
                names.append(dev.uuid)

            cdi_devices.extend(
                ConfigDevice(
                    name=name,
                    container_edits=cdi_container_edits,
                )
                for name in names
            )

        if not cdi_devices:
            return None

        # Add common container edits for all devices
        cdi_devices.append(
            ConfigDevice(
                name="all",
                container_edits=ConfigContainerEdits(
                    device_nodes=all_container_edits_device_nodes,
                ),
            ),
        )

        return Config(
            kind=kind,
            devices=cdi_devices,
        )
@@ -22,7 +22,7 @@ import docker.models.images
22
22
  import docker.models.volumes
23
23
  import docker.types
24
24
  from dataclasses_json import dataclass_json
25
- from docker.utils import parse_repository_tag
25
+ from gpustack_runner import split_image
26
26
  from tqdm import tqdm
27
27
 
28
28
  from .. import envs
@@ -48,11 +48,13 @@ from .__types__ import (
48
48
  )
49
49
  from .__utils__ import (
50
50
  _MiB,
51
+ adjust_image_with_envs,
51
52
  bytes_to_human_readable,
52
- replace_image_with,
53
+ isexception,
53
54
  safe_json,
54
55
  sensitive_env_var,
55
56
  )
57
+ from .cdi import generate_config as cdi_generate_config
56
58
 
57
59
  if TYPE_CHECKING:
58
60
  from collections.abc import Callable, Generator
@@ -146,16 +148,11 @@ class DockerWorkloadPlan(WorkloadPlan):
146
148
  # Default and validate in the base class.
147
149
  super().validate_and_default()
148
150
 
149
- # Adjust default image namespace if needed.
150
- if namespace := envs.GPUSTACK_RUNTIME_DEPLOY_DEFAULT_CONTAINER_NAMESPACE:
151
- self.pause_image = replace_image_with(
152
- image=self.pause_image,
153
- namespace=namespace,
154
- )
155
- self.unhealthy_restart_image = replace_image_with(
156
- image=self.unhealthy_restart_image,
157
- namespace=namespace,
158
- )
151
+ # Adjust images.
152
+ self.pause_image = adjust_image_with_envs(self.pause_image)
153
+ self.unhealthy_restart_image = adjust_image_with_envs(
154
+ self.unhealthy_restart_image,
155
+ )
159
156
 
160
157
 
161
158
  @dataclass_json
@@ -330,7 +327,7 @@ class DockerDeployer(EndoscopicDeployer):
330
327
  if envs.GPUSTACK_RUNTIME_DEPLOY.lower() not in ("auto", _NAME):
331
328
  return supported
332
329
 
333
- client = DockerDeployer._get_client()
330
+ client = DockerDeployer._get_client(timeout=3)
334
331
  if client:
335
332
  try:
336
333
  supported = client.ping()
@@ -340,16 +337,24 @@ class DockerDeployer(EndoscopicDeployer):
340
337
  "Connected to Docker API server: %s",
341
338
  version_info,
342
339
  )
343
- except docker.errors.APIError:
344
- debug_log_exception(logger, "Failed to connect to Docker API server")
340
+ except docker.errors.APIError as e:
341
+ if not isexception(e, FileNotFoundError):
342
+ debug_log_exception(
343
+ logger,
344
+ "Failed to connect to Docker API server",
345
+ )
345
346
 
346
347
  return supported
347
348
 
348
349
  @staticmethod
349
- def _get_client() -> docker.DockerClient | None:
350
+ def _get_client(**kwargs) -> docker.DockerClient | None:
350
351
  """
351
352
  Return a Docker client.
352
353
 
354
+ Args:
355
+ **kwargs:
356
+ Additional arguments to pass to docker.from_env().
357
+
353
358
  Returns:
354
359
  A Docker client if available, None otherwise.
355
360
 
@@ -365,9 +370,9 @@ class DockerDeployer(EndoscopicDeployer):
365
370
  os_env = os.environ.copy()
366
371
  if envs.GPUSTACK_RUNTIME_DOCKER_HOST:
367
372
  os_env["DOCKER_HOST"] = envs.GPUSTACK_RUNTIME_DOCKER_HOST
368
- client = docker.from_env(environment=os_env)
373
+ client = docker.from_env(environment=os_env, **kwargs)
369
374
  except docker.errors.DockerException as e:
370
- if "FileNotFoundError" not in str(e):
375
+ if not isexception(e, FileNotFoundError):
371
376
  debug_log_exception(logger, "Failed to get Docker client")
372
377
 
373
378
  return client
@@ -428,8 +433,7 @@ class DockerDeployer(EndoscopicDeployer):
428
433
  try:
429
434
  logger.info("Pulling image %s", image)
430
435
 
431
- repo, tag = parse_repository_tag(image)
432
- tag = tag or "latest"
436
+ repo, tag = split_image(image, fill_blank_tag=True)
433
437
  auth_config = None
434
438
  if (
435
439
  envs.GPUSTACK_RUNTIME_DEPLOY_DEFAULT_CONTAINER_REGISTRY_USERNAME
@@ -840,7 +844,7 @@ class DockerDeployer(EndoscopicDeployer):
840
844
  msg = f"Failed to upload ephemeral files to container {container.name}"
841
845
  raise OperationError(msg)
842
846
 
843
- def _create_containers(
847
+ def _create_containers( # noqa: C901
844
848
  self,
845
849
  workload: DockerWorkloadPlan,
846
850
  ephemeral_volume_name_mapping: dict[str, str],
@@ -953,7 +957,9 @@ class DockerDeployer(EndoscopicDeployer):
953
957
 
954
958
  r_k_runtime_env = workload.resource_key_runtime_env_mapping or {}
955
959
  r_k_backend_env = workload.resource_key_backend_env_mapping or {}
956
- vd_env, vd_cdis, vd_values = self.get_visible_devices_values()
960
+ vd_manus, vd_env, vd_cdis, vd_values = (
961
+ self.get_visible_devices_materials()
962
+ )
957
963
  for r_k, r_v in c.resources.items():
958
964
  match r_k:
959
965
  case "cpu":
@@ -995,6 +1001,14 @@ class DockerDeployer(EndoscopicDeployer):
995
1001
 
996
1002
  privileged = create_options.get("privileged", False)
997
1003
 
1004
+ # Generate CDI config if not yet.
1005
+ if cdi and envs.GPUSTACK_RUNTIME_DEPLOY_CDI_SPECS_GENERATE:
1006
+ for re in runtime_env:
1007
+ cdi_generate_config(
1008
+ manufacturer=vd_manus[re],
1009
+ output=envs.GPUSTACK_RUNTIME_DEPLOY_CDI_SPECS_DIRECTORY,
1010
+ )
1011
+
998
1012
  # Configure device access environment variable.
999
1013
  if r_v == "all" and backend_env:
1000
1014
  # Configure privileged if requested all devices.
@@ -1213,8 +1227,12 @@ class DockerDeployer(EndoscopicDeployer):
1213
1227
  self_container_envs: dict[str, str] = dict(
1214
1228
  item.split("=", 1) for item in self_container.attrs["Config"].get("Env", [])
1215
1229
  )
1216
- self_image_envs: dict[str, str] = dict(
1217
- item.split("=", 1) for item in self_image.attrs["Config"].get("Env", [])
1230
+ self_image_envs: dict[str, str] = (
1231
+ dict(
1232
+ item.split("=", 1) for item in self_image.attrs["Config"].get("Env", [])
1233
+ )
1234
+ if self_image.attrs["Config"]
1235
+ else {}
1218
1236
  )
1219
1237
  mirrored_envs: dict[str, str] = {
1220
1238
  # Filter out gpustack-internal envs and same-as-image envs.
@@ -319,7 +319,7 @@ class KubernetesDeployer(EndoscopicDeployer):
319
319
  if client:
320
320
  try:
321
321
  version_api = kubernetes.client.VersionApi(client)
322
- version_info = version_api.get_code()
322
+ version_info = version_api.get_code(_request_timeout=3)
323
323
  supported = version_info is not None
324
324
  if envs.GPUSTACK_RUNTIME_LOG_EXCEPTION:
325
325
  logger.debug(
@@ -337,10 +337,14 @@ class KubernetesDeployer(EndoscopicDeployer):
337
337
  return supported
338
338
 
339
339
  @staticmethod
340
- def _get_client() -> kubernetes.client.ApiClient | None:
340
+ def _get_client(**kwargs) -> kubernetes.client.ApiClient | None:
341
341
  """
342
342
  Return a Kubernetes API client.
343
343
 
344
+ Args:
345
+ **kwargs:
346
+ Additional arguments to pass to the Kubernetes config loader.
347
+
344
348
  Returns:
345
349
  A Kubernetes API client if the configuration is valid, None otherwise.
346
350
 
@@ -353,7 +357,7 @@ class KubernetesDeployer(EndoscopicDeployer):
353
357
  contextlib.redirect_stdout(dev_null),
354
358
  contextlib.redirect_stderr(dev_null),
355
359
  ):
356
- kubernetes.config.load_config()
360
+ kubernetes.config.load_config(**kwargs)
357
361
  client = kubernetes.client.ApiClient()
358
362
  client.user_agent = "gpustack/runtime"
359
363
  except kubernetes.config.config_exception.ConfigException:
@@ -989,7 +993,7 @@ class KubernetesDeployer(EndoscopicDeployer):
989
993
  resources: dict[str, str] = {}
990
994
  r_k_runtime_env = workload.resource_key_runtime_env_mapping or {}
991
995
  r_k_backend_env = workload.resource_key_backend_env_mapping or {}
992
- vd_env, _, vd_values = self.get_visible_devices_values()
996
+ _, vd_env, _, vd_values = self.get_visible_devices_materials()
993
997
  for r_k, r_v in c.resources.items():
994
998
  if r_k in ("cpu", "memory"):
995
999
  resources[r_k] = str(r_v)
@@ -23,7 +23,7 @@ import podman.domain.images
23
23
  import podman.domain.volumes
24
24
  import podman.errors
25
25
  from dataclasses_json import dataclass_json
26
- from podman.api import parse_repository
26
+ from gpustack_runner import split_image
27
27
  from podman.domain.containers_create import CreateMixin
28
28
  from tqdm import tqdm
29
29
 
@@ -51,11 +51,13 @@ from .__types__ import (
51
51
  )
52
52
  from .__utils__ import (
53
53
  _MiB,
54
+ adjust_image_with_envs,
54
55
  bytes_to_human_readable,
55
- replace_image_with,
56
+ isexception,
56
57
  safe_json,
57
58
  sensitive_env_var,
58
59
  )
60
+ from .cdi import generate_config as cdi_generate_config
59
61
 
60
62
  if TYPE_CHECKING:
61
63
  from collections.abc import Callable, Generator
@@ -149,16 +151,11 @@ class PodmanWorkloadPlan(WorkloadPlan):
149
151
  # Default and validate in the base class.
150
152
  super().validate_and_default()
151
153
 
152
- # Adjust default image namespace if needed.
153
- if namespace := envs.GPUSTACK_RUNTIME_DEPLOY_DEFAULT_CONTAINER_NAMESPACE:
154
- self.pause_image = replace_image_with(
155
- image=self.pause_image,
156
- namespace=namespace,
157
- )
158
- self.unhealthy_restart_image = replace_image_with(
159
- image=self.unhealthy_restart_image,
160
- namespace=namespace,
161
- )
154
+ # Adjust images.
155
+ self.pause_image = adjust_image_with_envs(self.pause_image)
156
+ self.unhealthy_restart_image = adjust_image_with_envs(
157
+ self.unhealthy_restart_image,
158
+ )
162
159
 
163
160
 
164
161
  @dataclass_json
@@ -333,7 +330,7 @@ class PodmanDeployer(EndoscopicDeployer):
333
330
  if envs.GPUSTACK_RUNTIME_DEPLOY.lower() not in ("auto", _NAME):
334
331
  return supported
335
332
 
336
- client = PodmanDeployer._get_client()
333
+ client = PodmanDeployer._get_client(timeout=3)
337
334
  if client:
338
335
  try:
339
336
  supported = client.ping()
@@ -343,16 +340,24 @@ class PodmanDeployer(EndoscopicDeployer):
343
340
  "Connected to Podman API server: %s",
344
341
  version_info,
345
342
  )
346
- except podman.errors.APIError:
347
- debug_log_exception(logger, "Failed to connect to Podman API server")
343
+ except podman.errors.APIError as e:
344
+ if not isexception(e, FileNotFoundError):
345
+ debug_log_exception(
346
+ logger,
347
+ "Failed to connect to Podman API server",
348
+ )
348
349
 
349
350
  return supported
350
351
 
351
352
  @staticmethod
352
- def _get_client() -> podman.PodmanClient | None:
353
+ def _get_client(**kwargs) -> podman.PodmanClient | None:
353
354
  """
354
355
  Return a Podman client.
355
356
 
357
+ Args:
358
+ **kwargs:
359
+ Additional arguments to pass to podman.from_env().
360
+
356
361
  Returns:
357
362
  A Podman client if available, None otherwise.
358
363
 
@@ -368,9 +373,9 @@ class PodmanDeployer(EndoscopicDeployer):
368
373
  os_env = os.environ.copy()
369
374
  if envs.GPUSTACK_RUNTIME_PODMAN_HOST:
370
375
  os_env["CONTAINER_HOST"] = envs.GPUSTACK_RUNTIME_PODMAN_HOST
371
- client = podman.from_env(environment=os_env)
376
+ client = podman.from_env(environment=os_env, **kwargs)
372
377
  except podman.errors.DockerException as e:
373
- if "FileNotFoundError" not in str(e):
378
+ if not isexception(e, FileNotFoundError):
374
379
  debug_log_exception(logger, "Failed to get Podman client")
375
380
 
376
381
  return client
@@ -431,8 +436,7 @@ class PodmanDeployer(EndoscopicDeployer):
431
436
  try:
432
437
  logger.info("Pulling image %s", image)
433
438
 
434
- repo, tag = parse_repository(image)
435
- tag = tag or "latest"
439
+ repo, tag = split_image(image, fill_blank_tag=True)
436
440
  auth_config = None
437
441
  if (
438
442
  envs.GPUSTACK_RUNTIME_DEPLOY_DEFAULT_CONTAINER_REGISTRY_USERNAME
@@ -949,7 +953,9 @@ class PodmanDeployer(EndoscopicDeployer):
949
953
  if c.resources:
950
954
  r_k_runtime_env = workload.resource_key_runtime_env_mapping or {}
951
955
  r_k_backend_env = workload.resource_key_backend_env_mapping or {}
952
- vd_env, vd_cdis, vd_values = self.get_visible_devices_values()
956
+ vd_manus, vd_env, vd_cdis, vd_values = (
957
+ self.get_visible_devices_materials()
958
+ )
953
959
  for r_k, r_v in c.resources.items():
954
960
  match r_k:
955
961
  case "cpu":
@@ -991,6 +997,14 @@ class PodmanDeployer(EndoscopicDeployer):
991
997
 
992
998
  privileged = create_options.get("privileged", False)
993
999
 
1000
+ # Generate CDI config if not yet.
1001
+ if envs.GPUSTACK_RUNTIME_DEPLOY_CDI_SPECS_GENERATE:
1002
+ for re in runtime_env:
1003
+ cdi_generate_config(
1004
+ manufacturer=vd_manus[re],
1005
+ output=envs.GPUSTACK_RUNTIME_DEPLOY_CDI_SPECS_DIRECTORY,
1006
+ )
1007
+
994
1008
  # Configure device access environment variable.
995
1009
  if r_v == "all" and backend_env:
996
1010
  # Configure privileged if requested all devices.
@@ -1189,8 +1203,12 @@ class PodmanDeployer(EndoscopicDeployer):
1189
1203
  self_container_envs: dict[str, str] = dict(
1190
1204
  item.split("=", 1) for item in self_container.attrs["Config"].get("Env", [])
1191
1205
  )
1192
- self_image_envs: dict[str, str] = dict(
1193
- item.split("=", 1) for item in self_image.attrs["Config"].get("Env", [])
1206
+ self_image_envs: dict[str, str] = (
1207
+ dict(
1208
+ item.split("=", 1) for item in self_image.attrs["Config"].get("Env", [])
1209
+ )
1210
+ if self_image.attrs["Config"]
1211
+ else {}
1194
1212
  )
1195
1213
  mirrored_envs: dict[str, str] = {
1196
1214
  # Filter out gpustack-internal envs and same-as-image envs.
@@ -24,6 +24,7 @@ from .iluvatar import IluvatarDetector
24
24
  from .metax import MetaXDetector
25
25
  from .mthreads import MThreadsDetector
26
26
  from .nvidia import NVIDIADetector
27
+ from .thead import THeadDetector
27
28
 
28
29
  logger = logging.getLogger(__package__)
29
30
 
@@ -36,6 +37,7 @@ _DETECTORS: list[Detector] = [
36
37
  MetaXDetector(),
37
38
  MThreadsDetector(),
38
39
  NVIDIADetector(),
40
+ THeadDetector(),
39
41
  ]
40
42
  """
41
43
  List of all detectors.
@@ -60,7 +62,10 @@ def supported_list() -> list[Detector]:
60
62
  return [det for det in _DETECTORS if det.is_supported()]
61
63
 
62
64
 
63
- def detect_backend(fast: bool = True) -> str | list[str]:
65
+ def detect_backend(
66
+ fast: bool = True,
67
+ manufacturer: ManufacturerEnum = None,
68
+ ) -> str | list[str]:
64
69
  """
65
70
  Detect all supported backend.
66
71
 
@@ -68,12 +73,21 @@ def detect_backend(fast: bool = True) -> str | list[str]:
68
73
  fast:
69
74
  If True, return the first detected backend.
70
75
  Otherwise, return a list of all detected backends.
76
+ manufacturer:
77
+ Manufacturer to filter the detection, implies `fast=True`.
78
+ If None, detect all available manufacturers.
71
79
 
72
80
  Returns:
73
81
  A string of the detected backend if `fast` is True and a backend is found.
74
82
  A list of detected backends if `fast` is False.
75
83
 
76
84
  """
85
+ if manufacturer:
86
+ det = _DETECTORS_MAP.get(manufacturer)
87
+ if det and det.is_supported():
88
+ return det.backend
89
+ return ""
90
+
77
91
  backends: list[str] = []
78
92
 
79
93
  for det in _DETECTORS:
@@ -88,7 +102,10 @@ def detect_backend(fast: bool = True) -> str | list[str]:
88
102
  return backends
89
103
 
90
104
 
91
- def detect_devices(fast: bool = True) -> Devices:
105
+ def detect_devices(
106
+ fast: bool = True,
107
+ manufacturer: ManufacturerEnum = None,
108
+ ) -> Devices:
92
109
  """
93
110
  Detect all available devices.
94
111
 
@@ -96,6 +113,9 @@ def detect_devices(fast: bool = True) -> Devices:
96
113
  fast:
97
114
  If True, return devices from the first supported detector.
98
115
  Otherwise, return devices from all supported detectors.
116
+ manufacturer:
117
+ Manufacturer to filter the detection, implies `fast=True`.
118
+ If None, detect all available manufacturers.
99
119
 
100
120
  Returns:
101
121
  A list of detected devices.
@@ -105,6 +125,18 @@ def detect_devices(fast: bool = True) -> Devices:
105
125
  If detection fails for the target detector specified by the `GPUSTACK_RUNTIME_DETECT` environment variable.
106
126
 
107
127
  """
128
+ if manufacturer:
129
+ det = _DETECTORS_MAP.get(manufacturer)
130
+ if det and det.is_supported():
131
+ try:
132
+ return det.detect()
133
+ except Exception:
134
+ detect_target = envs.GPUSTACK_RUNTIME_DETECT.lower()
135
+ if detect_target == det.name:
136
+ raise
137
+ debug_log_exception(logger, "Failed to detect devices for %s", det.name)
138
+ return []
139
+
108
140
  devices: Devices = []
109
141
 
110
142
  for det in _DETECTORS:
@@ -128,6 +160,7 @@ def detect_devices(fast: bool = True) -> Devices:
128
160
  def get_devices_topologies(
129
161
  devices: Devices | None = None,
130
162
  fast: bool = True,
163
+ manufacturer: ManufacturerEnum = None,
131
164
  ) -> list[Topology]:
132
165
  """
133
166
  Get the topology information of the given devices.
@@ -140,6 +173,9 @@ def get_devices_topologies(
140
173
  If True, return topologies from the first supported detector.
141
174
  Otherwise, return topologies from all supported detectors.
142
175
  Only works when `devices` is None.
176
+ manufacturer:
177
+ Manufacturer to filter the detection.
178
+ If None, detect all available manufacturers.
143
179
 
144
180
  Returns:
145
181
  A list of Topology objects for each manufacturer group.
@@ -147,7 +183,7 @@ def get_devices_topologies(
147
183
  """
148
184
  group = False
149
185
  if not devices:
150
- devices = detect_devices(fast=fast)
186
+ devices = detect_devices(fast=fast, manufacturer=manufacturer)
151
187
  if not devices:
152
188
  return []
153
189
  group = True and not fast
@@ -160,6 +196,7 @@ def get_devices_topologies(
160
196
 
161
197
  # Get topology for each group.
162
198
  topologies: list[Topology] = []
199
+
163
200
  for manu, devs in group_devices.items():
164
201
  det = _DETECTORS_MAP.get(manu)
165
202
  if det is not None:
@@ -172,6 +209,7 @@ def get_devices_topologies(
172
209
  if detect_target == det.name:
173
210
  raise
174
211
  debug_log_exception(logger, "Failed to get topology for %s", det.name)
212
+
175
213
  return topologies
176
214
 
177
215
 
@@ -197,6 +235,26 @@ def group_devices_by_manufacturer(
197
235
  return group_devices
198
236
 
199
237
 
238
def filter_devices_by_manufacturer(
    devices: Devices | None,
    manufacturer: ManufacturerEnum,
) -> Devices:
    """
    Keep only the devices that belong to the given manufacturer.

    Args:
        devices:
            The devices to filter. None or an empty list yields an empty list.
        manufacturer:
            The manufacturer each returned device must match.

    Returns:
        A new list with the matching devices, in their original order.

    """
    # Nothing to filter: always hand back a fresh empty list.
    if not devices:
        return []

    return [
        candidate for candidate in devices if candidate.manufacturer == manufacturer
    ]
256
+
257
+
200
258
  __all__ = [
201
259
  "Device",
202
260
  "Devices",
@@ -205,6 +263,7 @@ __all__ = [
205
263
  "backend_to_manufacturer",
206
264
  "detect_backend",
207
265
  "detect_devices",
266
+ "filter_devices_by_manufacturer",
208
267
  "get_devices_topologies",
209
268
  "group_devices_by_manufacturer",
210
269
  "manufacturer_to_backend",
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  from abc import ABC, abstractmethod
4
4
  from dataclasses import dataclass
5
5
  from enum import Enum
6
+ from functools import lru_cache
6
7
  from typing import Any
7
8
 
8
9
  from dataclasses_json import dataclass_json
@@ -45,6 +46,10 @@ class ManufacturerEnum(str, Enum):
45
46
  """
46
47
  NVIDIA Corporation
47
48
  """
49
+ THEAD = "thead"
50
+ """
51
+ T-Head Semiconductor Co., Ltd.
52
+ """
48
53
  UNKNOWN = "unknown"
49
54
  """
50
55
  Unknown Manufacturer
@@ -63,6 +68,7 @@ _MANUFACTURER_BACKEND_MAPPING: dict[ManufacturerEnum, str] = {
63
68
  ManufacturerEnum.METAX: "maca",
64
69
  ManufacturerEnum.MTHREADS: "musa",
65
70
  ManufacturerEnum.NVIDIA: "cuda",
71
+ ManufacturerEnum.THEAD: "hggc",
66
72
  }
67
73
  """
68
74
  Mapping of manufacturer to runtime backend,
@@ -70,6 +76,7 @@ which should map to the gpustack-runner's backend names.
70
76
  """
71
77
 
72
78
 
79
+ @lru_cache
73
80
  def manufacturer_to_backend(manufacturer: ManufacturerEnum) -> str:
74
81
  """
75
82
  Convert manufacturer to runtime backend,
@@ -92,6 +99,7 @@ def manufacturer_to_backend(manufacturer: ManufacturerEnum) -> str:
92
99
  return ManufacturerEnum.UNKNOWN.value
93
100
 
94
101
 
102
+ @lru_cache
95
103
  def backend_to_manufacturer(backend: str) -> ManufacturerEnum:
96
104
  """
97
105
  Convert runtime backend to manufacturer,
@@ -449,6 +457,9 @@ class Detector(ABC):
449
457
  """
450
458
 
451
459
  manufacturer: ManufacturerEnum = ManufacturerEnum.UNKNOWN
460
+ """
461
+ Manufacturer of the detector.
462
+ """
452
463
 
453
464
  @staticmethod
454
465
  @abstractmethod
@@ -951,3 +951,26 @@ def bitmask_to_str(bitmask_list: list) -> str:
951
951
  offset += get_bits_size()
952
952
 
953
953
  return list_to_range_str(sorted(bits_lists))
954
+
955
+
956
def get_physical_function_by_bdf(bdf: str) -> str:
    """
    Resolve the physical function BDF of a PCI device.

    Looks for a "physfn" entry under the device's sysfs directory and,
    if present, returns the name of the path it resolves to.

    Args:
        bdf:
            The PCI device BDF address (e.g., "0000:00:1f.0").

    Returns:
        The physical function BDF if found, otherwise the original BDF
        (including when `bdf` is empty or any error occurs during lookup).

    """
    if not bdf:
        return bdf

    # Best effort: any failure while probing sysfs falls through to
    # returning the BDF unchanged.
    with contextlib.suppress(Exception):
        device_dir = Path(f"/sys/bus/pci/devices/{bdf}")
        physfn_link = device_dir / "physfn"
        if device_dir.exists() and physfn_link.exists():
            return physfn_link.resolve().name

    return bdf