xpk 0.15.0__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. integration/README.md +19 -0
  2. xpk/blueprints/a3mega/config-map.yaml.tftpl +15 -0
  3. xpk/blueprints/a3mega/storage_crd.yaml +52 -0
  4. xpk/blueprints/a3ultra/config-map.yaml.tftpl +15 -0
  5. xpk/blueprints/a3ultra/mlgru-disable.yaml +59 -0
  6. xpk/blueprints/a3ultra/nccl-installer.yaml +95 -0
  7. xpk/blueprints/a3ultra/storage_crd.yaml +52 -0
  8. xpk/blueprints/a4/config-map.yaml.tftpl +15 -0
  9. xpk/blueprints/a4/nccl-rdma-installer-a4.yaml +66 -0
  10. xpk/blueprints/a4/storage_crd.yaml +52 -0
  11. xpk/commands/cluster.py +33 -12
  12. xpk/commands/cluster_gcluster_test.py +5 -1
  13. xpk/commands/cluster_test.py +125 -0
  14. xpk/commands/config.py +3 -3
  15. xpk/commands/inspector.py +5 -3
  16. xpk/commands/kind.py +2 -0
  17. xpk/commands/managed_ml_diagnostics.py +249 -0
  18. xpk/commands/managed_ml_diagnostics_test.py +146 -0
  19. xpk/commands/workload.py +125 -139
  20. xpk/commands/workload_test.py +160 -118
  21. xpk/core/blueprint/blueprint_generator.py +3 -0
  22. xpk/core/blueprint/testing/data/a3_mega.yaml +129 -0
  23. xpk/core/blueprint/testing/data/a3_mega_spot.yaml +125 -0
  24. xpk/core/blueprint/testing/data/a3_ultra.yaml +173 -0
  25. xpk/core/blueprint/testing/data/a4.yaml +185 -0
  26. xpk/core/capacity.py +2 -0
  27. xpk/core/cluster.py +18 -47
  28. xpk/core/cluster_test.py +76 -1
  29. xpk/core/config.py +81 -7
  30. xpk/core/config_test.py +67 -11
  31. xpk/core/docker_container.py +3 -1
  32. xpk/core/docker_image.py +10 -6
  33. xpk/core/docker_resources.py +1 -10
  34. xpk/core/kjob.py +17 -16
  35. xpk/core/kueue_manager.py +13 -19
  36. xpk/core/kueue_manager_test.py +27 -1
  37. xpk/core/nap.py +13 -14
  38. xpk/core/nodepool.py +17 -15
  39. xpk/core/nodepool_test.py +25 -4
  40. xpk/core/pathways.py +23 -0
  41. xpk/core/pathways_test.py +57 -0
  42. xpk/core/resources.py +84 -27
  43. xpk/core/scheduling.py +128 -132
  44. xpk/core/scheduling_test.py +215 -2
  45. xpk/core/system_characteristics.py +179 -0
  46. xpk/core/system_characteristics_test.py +49 -1
  47. xpk/core/telemetry.py +4 -4
  48. xpk/core/telemetry_test.py +9 -9
  49. xpk/core/vertex.py +4 -3
  50. xpk/core/workload_decorators/tcpx_decorator.py +5 -1
  51. xpk/main.py +2 -0
  52. xpk/parser/cluster.py +22 -88
  53. xpk/parser/cluster_test.py +41 -0
  54. xpk/parser/common.py +84 -0
  55. xpk/parser/storage.py +10 -0
  56. xpk/parser/storage_test.py +47 -0
  57. xpk/parser/workload.py +14 -41
  58. xpk/parser/workload_test.py +2 -48
  59. xpk/templates/arm_gpu_workload_crate.yaml.j2 +46 -0
  60. xpk/utils/feature_flags.py +3 -0
  61. xpk/utils/validation.py +2 -2
  62. xpk-0.16.1.dist-info/METADATA +127 -0
  63. {xpk-0.15.0.dist-info → xpk-0.16.1.dist-info}/RECORD +67 -48
  64. xpk-0.15.0.dist-info/METADATA +0 -1666
  65. {xpk-0.15.0.dist-info → xpk-0.16.1.dist-info}/WHEEL +0 -0
  66. {xpk-0.15.0.dist-info → xpk-0.16.1.dist-info}/entry_points.txt +0 -0
  67. {xpk-0.15.0.dist-info → xpk-0.16.1.dist-info}/licenses/LICENSE +0 -0
  68. {xpk-0.15.0.dist-info → xpk-0.16.1.dist-info}/top_level.txt +0 -0
xpk/core/config_test.py CHANGED
@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from xpk.core.config import XpkConfig, CFG_BUCKET_KEY, CLUSTER_NAME_KEY, PROJECT_KEY, ZONE_KEY
17
+ from xpk.core.config import FileSystemConfig, InMemoryXpkConfig, CFG_BUCKET_KEY, CLUSTER_NAME_KEY, PROJECT_KEY, ZONE_KEY, _get_version
18
+ from unittest.mock import patch
19
+ from importlib.metadata import PackageNotFoundError
18
20
 
19
21
  import os
20
22
  import pytest
@@ -31,15 +33,60 @@ def _():
31
33
  os.remove(config_tmp_path)
32
34
 
33
35
 
34
- def test_config(_):
35
- cfg = XpkConfig(config_tmp_path)
36
+ @patch('os.getenv', return_value='10.0.0')
37
+ def test_get_version_returns_overriden_value_when_it_is_overriden(_):
38
+ assert _get_version() == '10.0.0'
39
+
40
+
41
+ @patch('os.getenv', return_value='')
42
+ @patch('xpk.core.config.setuptools_get_version', return_value='10.0.0')
43
+ def test_get_version_returns_value_from_setuptools_scm_when_there_is_no_override(
44
+ *_,
45
+ ):
46
+ assert _get_version() == '10.0.0'
47
+
48
+
49
+ @patch('os.getenv', return_value='')
50
+ @patch(
51
+ 'xpk.core.config.setuptools_get_version',
52
+ side_effect=LookupError('unable to find git version'),
53
+ )
54
+ @patch('xpk.core.config.version', return_value='10.0.0')
55
+ def test_get_version_returns_value_from_pip_when_there_is_no_setuptools_could_be_resolved(
56
+ *_,
57
+ ):
58
+ assert _get_version() == '10.0.0'
59
+
60
+
61
+ @patch('os.getenv', return_value='')
62
+ @patch(
63
+ 'xpk.core.config.setuptools_get_version',
64
+ side_effect=LookupError('unable to find git version'),
65
+ )
66
+ @patch(
67
+ 'xpk.core.config.version',
68
+ side_effect=PackageNotFoundError('unable to locate package'),
69
+ )
70
+ def test_get_version_returns_none_when_no_version_could_be_resolved(*_):
71
+ with pytest.raises(LookupError):
72
+ _get_version()
73
+
74
+
75
+ @pytest.mark.parametrize(
76
+ argnames='cfg',
77
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
78
+ )
79
+ def test_config(_, cfg):
36
80
  cfg.set('project-id', 'foo')
37
81
  project_id = cfg.get('project-id')
38
82
  assert project_id == 'foo'
39
83
 
40
84
 
41
- def test_config_get_all(_):
42
- cfg = XpkConfig(config_tmp_path)
85
+ @pytest.mark.parametrize(
86
+ argnames='cfg',
87
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
88
+ )
89
+ def test_config_get_all(_, cfg):
43
90
  cfg.set(PROJECT_KEY, 'foo')
44
91
  cfg.set(CLUSTER_NAME_KEY, 'bar')
45
92
  cfg.set(ZONE_KEY, 'europe-west1-a')
@@ -52,20 +99,29 @@ def test_config_get_all(_):
52
99
  assert cfg_all[CFG_BUCKET_KEY] == 'cfg-bucket'
53
100
 
54
101
 
55
- def test_config_get_empty(_):
56
- cfg = XpkConfig(config_tmp_path)
102
+ @pytest.mark.parametrize(
103
+ argnames='cfg',
104
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
105
+ )
106
+ def test_config_get_empty(_, cfg):
57
107
  val = cfg.get(PROJECT_KEY)
58
108
  assert val is None
59
109
 
60
110
 
61
- def test_config_get_all_empty(_):
62
- cfg = XpkConfig(config_tmp_path)
111
+ @pytest.mark.parametrize(
112
+ argnames='cfg',
113
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
114
+ )
115
+ def test_config_get_all_empty(_, cfg):
63
116
  val = cfg.get_all()
64
117
  assert not val
65
118
 
66
119
 
67
- def test_config_set_incorrect(_):
68
- cfg = XpkConfig(config_tmp_path)
120
+ @pytest.mark.parametrize(
121
+ argnames='cfg',
122
+ argvalues=[(FileSystemConfig(config_tmp_path)), (InMemoryXpkConfig())],
123
+ )
124
+ def test_config_set_incorrect(cfg, _):
69
125
  cfg.set('foo', 'bar')
70
126
  cfg_all = cfg.get_all()
71
127
  assert not cfg_all
xpk/core/docker_container.py CHANGED
@@ -182,7 +182,9 @@ def get_user_workload_container(args, system: SystemCharacteristics):
182
182
  debugging_dashboard_id: id of the GKE dashboard
183
183
  """
184
184
 
185
- setup_docker_image_code, docker_image = setup_docker_image(args)
185
+ setup_docker_image_code, docker_image = setup_docker_image(
186
+ args, system.docker_platform
187
+ )
186
188
  if setup_docker_image_code != 0:
187
189
  xpk_exit(setup_docker_image_code)
188
190
 
xpk/core/docker_image.py CHANGED
@@ -19,6 +19,7 @@ import os
19
19
  import random
20
20
  import string
21
21
 
22
+ from .system_characteristics import DockerPlatform
22
23
  from ..utils.console import xpk_exit, xpk_print
23
24
  from ..utils.file import write_tmp_file
24
25
  from ..utils.execution_context import is_dry_run
@@ -26,7 +27,6 @@ from .commands import run_command_with_updates
26
27
 
27
28
  DEFAULT_DOCKER_IMAGE = 'python:3.10'
28
29
  DEFAULT_SCRIPT_DIR = os.getcwd()
29
- PLATFORM = 'linux/amd64'
30
30
 
31
31
 
32
32
  def validate_docker_image(docker_image, args) -> int:
@@ -63,7 +63,9 @@ def validate_docker_image(docker_image, args) -> int:
63
63
  return 0
64
64
 
65
65
 
66
- def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
66
+ def build_docker_image_from_base_image(
67
+ args, docker_platform: DockerPlatform, verbose=True
68
+ ) -> tuple[int, str]:
67
69
  """Adds script dir to the base docker image and uploads the image.
68
70
 
69
71
  Args:
@@ -97,8 +99,8 @@ def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
97
99
  )
98
100
  tmp = write_tmp_file(docker_file)
99
101
  docker_build_command = (
100
- f'docker buildx build --platform={PLATFORM} -f {str(tmp)} -t'
101
- f' {docker_name} {args.script_dir}'
102
+ f'docker buildx build --platform={docker_platform.value} -f'
103
+ f' {str(tmp)} -t {docker_name} {args.script_dir}'
102
104
  )
103
105
  xpk_print(f'Building {args.script_dir} into docker image.')
104
106
  return_code = run_command_with_updates(
@@ -158,7 +160,9 @@ def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
158
160
  return return_code, cloud_docker_image
159
161
 
160
162
 
161
- def setup_docker_image(args) -> tuple[int, str]:
163
+ def setup_docker_image(
164
+ args, docker_platform: DockerPlatform
165
+ ) -> tuple[int, str]:
162
166
  """Does steps to verify docker args, check image, and build image (if asked).
163
167
 
164
168
  Args:
@@ -177,7 +181,7 @@ def setup_docker_image(args) -> tuple[int, str]:
177
181
  if validate_docker_image_code != 0:
178
182
  xpk_exit(validate_docker_image_code)
179
183
  build_docker_image_code, docker_image = build_docker_image_from_base_image(
180
- args
184
+ args, docker_platform
181
185
  )
182
186
  if build_docker_image_code != 0:
183
187
  xpk_exit(build_docker_image_code)
xpk/core/docker_resources.py CHANGED
@@ -16,7 +16,6 @@ limitations under the License.
16
16
 
17
17
  import os
18
18
  import re
19
- from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
20
19
  from .cluster import setup_k8s_env
21
20
  from .storage import GCS_FUSE_TYPE, GCP_FILESTORE_TYPE, PARALLELSTORE_TYPE, GCE_PD_TYPE, LUSTRE_TYPE, Storage, get_storages_to_mount
22
21
  from .system_characteristics import AcceleratorType, SystemCharacteristics
@@ -109,14 +108,6 @@ def get_gpu_env(args, system) -> str:
109
108
  value: "{args.command}"
110
109
  {custom_envs}"""
111
110
 
112
- gpu_direct_name = 'fastrak'
113
- if args.device_type == H100_DEVICE_TYPE:
114
- gpu_direct_name = 'tcpx'
115
- elif args.device_type == H100_MEGA_DEVICE_TYPE:
116
- gpu_direct_name = 'tcpxo'
117
- elif args.device_type == H200_DEVICE_TYPE:
118
- gpu_direct_name = 'rdma'
119
-
120
111
  gpu_env_dic = {
121
112
  'JAX_COORDINATOR_PORT': '6002',
122
113
  'JAX_COORDINATOR_ADDRESS': (
@@ -129,7 +120,7 @@ def get_gpu_env(args, system) -> str:
129
120
  return gpu_env_yaml.format(
130
121
  args=args,
131
122
  chips_per_vm=system.chips_per_vm,
132
- gpu_direct_name=gpu_direct_name,
123
+ gpu_direct_name=system.gpu_config.gpu_direct_name,
133
124
  custom_envs=format_env_dict(args.env, system),
134
125
  )
135
126
 
xpk/core/kjob.py CHANGED
@@ -25,7 +25,6 @@ from kubernetes.client.rest import ApiException
25
25
  from ..utils import templates
26
26
  from ..utils.execution_context import is_dry_run
27
27
  from ..utils.console import xpk_exit, xpk_print
28
- from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
29
28
  from .cluster import DEFAULT_NAMESPACE, XPK_SA, setup_k8s_env
30
29
  from .commands import (
31
30
  run_command_for_value,
@@ -38,7 +37,7 @@ from .config import (
38
37
  KJOB_SHELL_IMAGE,
39
38
  KJOB_SHELL_INTERACTIVE_COMMAND,
40
39
  KJOB_SHELL_WORKING_DIRECTORY,
41
- xpk_config,
40
+ get_config,
42
41
  )
43
42
  from .network import get_cluster_subnetworks
44
43
  from .system_characteristics import AcceleratorType, SystemCharacteristics
@@ -52,7 +51,6 @@ from .storage import (
52
51
  )
53
52
  from .workload_decorators import (
54
53
  rdma_decorator,
55
- tcpx_decorator,
56
54
  tcpxo_decorator,
57
55
  )
58
56
  from .workload_decorators.tcpxo_decorator import get_tcpxo_deamon_entry
@@ -234,7 +232,7 @@ def get_pod_template_interactive_command() -> str:
234
232
  Returns:
235
233
  str - PodTemplate's interactive command
236
234
  """
237
- pod_command = xpk_config.get(KJOB_SHELL_INTERACTIVE_COMMAND)
235
+ pod_command = get_config().get(KJOB_SHELL_INTERACTIVE_COMMAND)
238
236
  if pod_command is None or len(pod_command) == 0:
239
237
  pod_command = PodTemplateDefaults.INTERACTIVE_COMMAND.value
240
238
 
@@ -260,14 +258,17 @@ def create_app_profile_instance(volume_bundles: list[str]) -> int:
260
258
  )
261
259
 
262
260
 
263
- def decorate_job_template_with_gpu(yml_string: str, gpu_type: str) -> str:
261
+ def decorate_job_template_with_gpu(
262
+ yml_string: str, system: SystemCharacteristics
263
+ ) -> str:
264
264
  job_spec = yaml.safe_load(yml_string)["template"]
265
- if gpu_type == H100_DEVICE_TYPE:
266
- job_spec = tcpx_decorator.decorate_kjob_template(job_spec)
267
- if gpu_type == H100_MEGA_DEVICE_TYPE:
268
- job_spec = tcpxo_decorator.decorate_kjob_template(job_spec)
269
- if gpu_type == H200_DEVICE_TYPE:
270
- job_spec = rdma_decorator.decorate_kjob_template(job_spec)
265
+ kjob_decorator = (
266
+ system.gpu_config.kjob_decorator_fn
267
+ if system.gpu_config and system.gpu_config.kjob_decorator_fn
268
+ else None
269
+ )
270
+ if kjob_decorator:
271
+ job_spec = kjob_decorator(job_spec)
271
272
  job_template_dict = yaml.safe_load(yml_string)
272
273
  job_template_dict["template"] = job_spec
273
274
  yaml_result: str = yaml.dump(job_template_dict, sort_keys=False)
@@ -286,10 +287,10 @@ def create_job_template_instance(
286
287
  Returns:
287
288
  exit_code > 0 if creating JobTemplate fails, 0 otherwise
288
289
  """
289
- job_image = xpk_config.get(KJOB_BATCH_IMAGE)
290
+ job_image = get_config().get(KJOB_BATCH_IMAGE)
290
291
  if job_image is None or len(job_image) == 0:
291
292
  job_image = JobTemplateDefaults.IMAGE.value
292
- working_directory = xpk_config.get(KJOB_BATCH_WORKING_DIRECTORY)
293
+ working_directory = get_config().get(KJOB_BATCH_WORKING_DIRECTORY)
293
294
  if working_directory is None or len(working_directory) == 0:
294
295
  working_directory = JobTemplateDefaults.WORKING_DIRECTORY.value
295
296
  resources = (
@@ -316,7 +317,7 @@ def create_job_template_instance(
316
317
  service_account=service_account,
317
318
  )
318
319
  if system is not None and system.accelerator_type == AcceleratorType.GPU:
319
- yml_string = decorate_job_template_with_gpu(yml_string, system.device_type)
320
+ yml_string = decorate_job_template_with_gpu(yml_string, system)
320
321
 
321
322
  return run_kubectl_apply(
322
323
  yml_string,
@@ -330,10 +331,10 @@ def create_pod_template_instance(service_account: str) -> int:
330
331
  Returns:
331
332
  exit_code > 0 if creating PodTemplate fails, 0 otherwise
332
333
  """
333
- pod_image = xpk_config.get(KJOB_SHELL_IMAGE)
334
+ pod_image = get_config().get(KJOB_SHELL_IMAGE)
334
335
  if pod_image is None or len(pod_image) == 0:
335
336
  pod_image = PodTemplateDefaults.IMAGE.value
336
- working_directory = xpk_config.get(KJOB_SHELL_WORKING_DIRECTORY)
337
+ working_directory = get_config().get(KJOB_SHELL_WORKING_DIRECTORY)
337
338
  if working_directory is None or len(working_directory) == 0:
338
339
  working_directory = PodTemplateDefaults.WORKING_DIRECTORY.value
339
340
 
xpk/core/kueue_manager.py CHANGED
@@ -24,15 +24,13 @@ from jinja2 import Environment, FileSystemLoader
24
24
  from ..utils.topology import get_slice_topology_level, get_topology_product, is_topology_contained
25
25
  from ..utils.kueue import is_queued_cluster
26
26
  from kubernetes.utils import parse_quantity
27
- from .capacity import B200_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
28
- from .scheduling import (
29
- create_accelerator_label,
30
- create_machine_label,
31
- )
32
27
  from .system_characteristics import (
33
28
  SUB_SLICING_TOPOLOGIES,
29
+ AcceleratorType,
34
30
  AcceleratorTypeToAcceleratorCharacteristics,
35
31
  SystemCharacteristics,
32
+ create_accelerator_label,
33
+ create_machine_label,
36
34
  )
37
35
  from ..core.commands import (
38
36
  run_command_for_value,
@@ -321,19 +319,16 @@ class KueueManager:
321
319
  main_flavor_name = f"{num_slices}x{device_type_str}"
322
320
 
323
321
  node_labels_dict = {}
324
- accelerator_label = create_accelerator_label(
325
- system.accelerator_type, system
326
- )
322
+ accelerator_label = create_accelerator_label(system)
327
323
  if accelerator_label:
328
324
  key, value = accelerator_label.split(":", 1)
329
325
  node_labels_dict[key] = value.strip()
330
326
 
331
- machine_label = create_machine_label(
332
- system.accelerator_type, system, autoprovisioning
333
- )
334
- if machine_label:
335
- key, value = machine_label.split(":", 1)
336
- node_labels_dict[key] = value.strip()
327
+ if not autoprovisioning:
328
+ machine_label = create_machine_label(system)
329
+ if machine_label:
330
+ key, value = machine_label.split(":", 1)
331
+ node_labels_dict[key] = value.strip()
337
332
 
338
333
  topology_label = f"topologyName: {topology_name}" if topology_name else ""
339
334
 
@@ -400,11 +395,10 @@ class KueueManager:
400
395
  def __get_topology_name_and_yaml(
401
396
  self, system: SystemCharacteristics, configure_sub_slicing: bool
402
397
  ) -> _NameAndYaml | None:
403
- if system.device_type in [
404
- H100_MEGA_DEVICE_TYPE,
405
- H200_DEVICE_TYPE,
406
- B200_DEVICE_TYPE,
407
- ]:
398
+ if (
399
+ system.accelerator_type == AcceleratorType["GPU"]
400
+ and system.gpu_requires_topology
401
+ ):
408
402
  return _NameAndYaml(
409
403
  name="gke-default",
410
404
  yaml=self.template_env.get_template(
xpk/core/kueue_manager_test.py CHANGED
@@ -22,7 +22,7 @@ import yaml
22
22
  from unittest.mock import MagicMock, patch
23
23
 
24
24
  from xpk.core.kueue_manager import KueueConfig, KueueManager, has_sub_slicing_enabled
25
- from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics, UserFacingNameToSystemCharacteristics
25
+ from xpk.core.system_characteristics import GpuConfig, DockerPlatform, AcceleratorType, SystemCharacteristics, UserFacingNameToSystemCharacteristics
26
26
  from xpk.core.testing.commands_tester import CommandsTester
27
27
  from packaging.version import Version
28
28
 
@@ -35,6 +35,7 @@ TPU_SYSTEM: SystemCharacteristics = SystemCharacteristics(
35
35
  accelerator_type=AcceleratorType.TPU,
36
36
  device_type="v5p-8",
37
37
  supports_sub_slicing=False,
38
+ docker_platform=DockerPlatform.ARM,
38
39
  )
39
40
 
40
41
  KUEUE_CONFIG: KueueConfig = KueueConfig(
@@ -405,6 +406,8 @@ def test_configure_generates_correct_manifest_with_gke_default_topology(
405
406
  accelerator_type=AcceleratorType.GPU,
406
407
  device_type="h100-mega-80gb-8",
407
408
  supports_sub_slicing=False,
409
+ docker_platform=DockerPlatform.ARM,
410
+ gpu_config=GpuConfig(requires_topology=True),
408
411
  ),
409
412
  )
410
413
 
@@ -501,6 +504,29 @@ def test_configure_generates_correct_manifest_with_pathways(
501
504
  assert pathways_rg["flavors"][0]["resources"][1]["nominalQuota"] == "2000G"
502
505
 
503
506
 
507
+ @patch("xpk.core.kueue_manager.write_tmp_file")
508
+ def test_configure_generates_correct_manifest_for_a4x(
509
+ write_tmp_file_mock: MagicMock,
510
+ mock_commands: CommandsTester,
511
+ kueue_manager: KueueManager,
512
+ ):
513
+ """Test that __configure generates correct manifest for a4x GPUs."""
514
+ set_installed_kueue_version(mock_commands, None)
515
+ kueue_config = dataclasses.replace(
516
+ KUEUE_CONFIG,
517
+ system=UserFacingNameToSystemCharacteristics["gb200-4"],
518
+ )
519
+
520
+ kueue_manager.install_or_upgrade(kueue_config)
521
+
522
+ rendered_manifest: str = write_tmp_file_mock.call_args[0][0]
523
+ manifest_docs = list(yaml.safe_load_all(rendered_manifest))
524
+
525
+ # Check that the gke-default topology is present for a4x.
526
+ topology = _first(doc for doc in manifest_docs if doc["kind"] == "Topology")
527
+ assert topology["metadata"]["name"] == "gke-default"
528
+
529
+
504
530
  def test_has_sub_slicing_enabled_returns_exit_code_when_command_fails(
505
531
  mock_commands: CommandsTester,
506
532
  ):
xpk/core/nap.py CHANGED
@@ -30,9 +30,8 @@ from .commands import run_command_with_updates, run_commands
30
30
  from .gcloud_context import get_cluster_location
31
31
  from .nodepool import get_all_nodepools_programmatic
32
32
  from .resources import (
33
- CLUSTER_METADATA_CONFIGMAP,
34
- CLUSTER_RESOURCES_CONFIGMAP,
35
33
  AutoprovisioningConfig,
34
+ ConfigMapType,
36
35
  get_cluster_configmap,
37
36
  )
38
37
  from .scheduling import get_total_chips_requested_from_args
@@ -266,14 +265,12 @@ def is_autoprovisioning_enabled(
266
265
  int of 0 if successful and 1 otherwise.
267
266
  """
268
267
 
269
- resources_configmap_name = f'{args.cluster}-{CLUSTER_RESOURCES_CONFIGMAP}'
270
- cluster_config_map = get_cluster_configmap(resources_configmap_name)
268
+ cluster_config_map = get_cluster_configmap(
269
+ args.cluster, ConfigMapType.RESOURCES
270
+ )
271
271
 
272
272
  if cluster_config_map is None:
273
- xpk_print(
274
- f'Unable to find config map: {resources_configmap_name}.'
275
- ' Autoprovisioning is not enabled.'
276
- )
273
+ xpk_print('Unable to find config map. Autoprovisioning is not enabled.')
277
274
  return False, 0
278
275
 
279
276
  return_code, autoprovisioning_value = get_value_from_map(
@@ -281,8 +278,8 @@ def is_autoprovisioning_enabled(
281
278
  )
282
279
  if return_code != 0:
283
280
  xpk_print(
284
- 'gke_accelerator type not found in config map:'
285
- f' {resources_configmap_name}. Autoprovisioning is not enabled.'
281
+ 'gke_accelerator type not found in config map. Autoprovisioning is not'
282
+ ' enabled.'
286
283
  )
287
284
  return False, 0
288
285
 
@@ -319,8 +316,9 @@ def get_autoprovisioning_node_selector_args(args) -> tuple[str, int]:
319
316
 
320
317
  if capacity_type_str == CapacityType.UNKNOWN.name:
321
318
  # Use default settings from cluster creation.
322
- metadata_configmap_name = f'{args.cluster}-{CLUSTER_METADATA_CONFIGMAP}'
323
- cluster_config_map = get_cluster_configmap(metadata_configmap_name)
319
+ cluster_config_map = get_cluster_configmap(
320
+ args.cluster, ConfigMapType.METADATA
321
+ )
324
322
 
325
323
  # Error out if the metadata config map doesn't exist, and is attempting to use
326
324
  # autoprovisioning.
@@ -363,8 +361,9 @@ def get_autoprovisioning_node_selector_args(args) -> tuple[str, int]:
363
361
 
364
362
 
365
363
  def get_cluster_provisioner(args) -> str:
366
- metadata_configmap_name = f'{args.cluster}-{CLUSTER_METADATA_CONFIGMAP}'
367
- cluster_config_map = get_cluster_configmap(metadata_configmap_name)
364
+ cluster_config_map = get_cluster_configmap(
365
+ args.cluster, ConfigMapType.METADATA
366
+ )
368
367
  cluster_provisioner = 'gcloud'
369
368
  if not cluster_config_map is None:
370
369
  provisioner = cluster_config_map.get('provisioner')
xpk/core/nodepool.py CHANGED
@@ -28,10 +28,9 @@ from .capacity import (
28
28
  from .commands import run_command_for_value, run_commands
29
29
  from .gcloud_context import GkeServerConfig, get_cluster_location, zone_to_region
30
30
  from .resources import (
31
- CLUSTER_CONFIGMAP_YAML,
32
- CLUSTER_RESOURCES_CONFIGMAP,
31
+ ConfigMapType,
33
32
  check_cluster_resources,
34
- create_or_update_cluster_configmap,
33
+ update_cluster_configmap,
35
34
  )
36
35
  from .system_characteristics import AcceleratorType
37
36
 
@@ -247,20 +246,23 @@ def run_gke_node_pool_create_command(
247
246
  )
248
247
  else:
249
248
  resources_data = f'{device_type}: "0"'
250
- resources_configmap_name = f'{args.cluster}-{CLUSTER_RESOURCES_CONFIGMAP}'
251
- resources_yml = CLUSTER_CONFIGMAP_YAML.format(
252
- args=args, name=resources_configmap_name, data=resources_data
249
+ return_code = update_cluster_configmap(
250
+ cluster_name=args.cluster,
251
+ config_map_type=ConfigMapType.RESOURCES,
252
+ data=resources_data,
253
253
  )
254
- configmap_yml = {}
255
- configmap_yml[resources_configmap_name] = resources_yml
256
- return_code = create_or_update_cluster_configmap(configmap_yml)
257
254
  if return_code != 0:
258
255
  return 1
259
256
 
260
257
  placement_args = ''
261
258
  if is_placement_policy_supported(system):
262
259
  placement_policy = get_placement_policy_name(system)
263
- ensure_resource_policy_exists(placement_policy, args, system.topology)
260
+ ensure_resource_policy_exists(
261
+ resource_policy_name=placement_policy,
262
+ project=args.project,
263
+ zone=args.zone,
264
+ topology=system.topology,
265
+ )
264
266
  placement_args = f' --placement-policy={placement_policy}'
265
267
 
266
268
  create_commands = []
@@ -311,7 +313,7 @@ def run_gke_node_pool_create_command(
311
313
  command += (
312
314
  ' --accelerator'
313
315
  f' type={system.gke_accelerator},count={str(system.chips_per_vm)},gpu-driver-version=latest'
314
- f' --no-enable-autoupgrade --scopes={CLOUD_PLATFORM_AUTH_SCOPE_URL}'
316
+ f' --scopes={CLOUD_PLATFORM_AUTH_SCOPE_URL}'
315
317
  )
316
318
  if device_type == H100_MEGA_DEVICE_TYPE:
317
319
  for i in range(1, 9):
@@ -587,14 +589,14 @@ def get_desired_node_pool_names(
587
589
 
588
590
 
589
591
  def ensure_resource_policy_exists(
590
- resource_policy_name: str, args, topology: str
592
+ resource_policy_name: str, project: str, zone: str, topology: str
591
593
  ) -> None:
592
594
  return_code, _ = run_command_for_value(
593
595
  (
594
596
  'gcloud compute resource-policies describe'
595
597
  f' {resource_policy_name} '
596
- f'--project={args.project} '
597
- f'--region={zone_to_region(args.zone)}'
598
+ f'--project={project} '
599
+ f'--region={zone_to_region(zone)}'
598
600
  ),
599
601
  'Retrieve resource policy',
600
602
  )
@@ -605,7 +607,7 @@ def ensure_resource_policy_exists(
605
607
  return_code, _ = run_command_for_value(
606
608
  (
607
609
  'gcloud compute resource-policies create workload-policy'
608
- f' {resource_policy_name} --project={args.project} --region={zone_to_region(args.zone)} --type=HIGH_THROUGHPUT'
610
+ f' {resource_policy_name} --project={project} --region={zone_to_region(zone)} --type=HIGH_THROUGHPUT'
609
611
  f' --accelerator-topology={topology}'
610
612
  ),
611
613
  'Create resource policy',
xpk/core/nodepool_test.py CHANGED
@@ -20,7 +20,7 @@ from xpk.core.nodepool import (
20
20
  get_desired_node_pool_names,
21
21
  run_gke_node_pool_create_command,
22
22
  )
23
- from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics
23
+ from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics, DockerPlatform, GpuConfig
24
24
 
25
25
  CLUSTER_NAME = "running-cucumber"
26
26
 
@@ -96,7 +96,12 @@ def test_ensure_resource_policy_exists_with_existing_policy_retrieves_existing_p
96
96
  mock = mocker.patch(
97
97
  "xpk.core.nodepool.run_command_for_value", return_value=(0, "")
98
98
  )
99
- ensure_resource_policy_exists("resource-policy", args, "2x2x1")
99
+ ensure_resource_policy_exists(
100
+ resource_policy_name="resource-policy",
101
+ project="test-project",
102
+ zone="us-central1-a",
103
+ topology="2x2x1",
104
+ )
100
105
  mock.assert_called_once()
101
106
 
102
107
 
@@ -108,7 +113,12 @@ def test_ensure_resource_policy_exists_without_existing_policy_creates_policy(
108
113
  mock = mocker.patch(
109
114
  "xpk.core.nodepool.run_command_for_value", side_effect=[(1, ""), (0, "")]
110
115
  )
111
- ensure_resource_policy_exists("resource-policy", args, "2x2x1")
116
+ ensure_resource_policy_exists(
117
+ resource_policy_name="resource-policy",
118
+ project="test-project",
119
+ zone="us-central1-a",
120
+ topology="2x2x1",
121
+ )
112
122
  assert mock.call_count == 2
113
123
  assert mock.call_args_list[0].args[1] == "Retrieve resource policy"
114
124
 
@@ -125,7 +135,12 @@ def test_ensure_resource_policy_exits_without_existing_policy_throws_when_creati
125
135
  "xpk.core.nodepool.run_command_for_value",
126
136
  side_effect=[(1, ""), (1, "")],
127
137
  )
128
- ensure_resource_policy_exists("resource-policy", args, "2x2x1")
138
+ ensure_resource_policy_exists(
139
+ resource_policy_name="resource-policy",
140
+ project="test-project",
141
+ zone="us-central1-a",
142
+ topology="2x2x1",
143
+ )
129
144
 
130
145
 
131
146
  @pytest.fixture
@@ -179,6 +194,8 @@ def test_placement_policy_created_for_gpu_with_valid_topology(
179
194
  accelerator_type=AcceleratorType.GPU,
180
195
  device_type="h100-80gb-8",
181
196
  supports_sub_slicing=False,
197
+ docker_platform=DockerPlatform.ARM,
198
+ gpu_config=GpuConfig(requires_topology=True),
182
199
  )
183
200
 
184
201
  run_gke_node_pool_create_command(args, system, "1.2.3")
@@ -209,6 +226,8 @@ def test_placement_policy_not_created_for_gpu_with_invalid_topology(
209
226
  accelerator_type=AcceleratorType.GPU,
210
227
  device_type="h100-80gb-8",
211
228
  supports_sub_slicing=False,
229
+ docker_platform=DockerPlatform.ARM,
230
+ gpu_config=GpuConfig(requires_topology=True),
212
231
  )
213
232
 
214
233
  run_gke_node_pool_create_command(args, system, "1.2.3")
@@ -242,6 +261,7 @@ def test_placement_policy_created_for_tpu7x_with_valid_topology(
242
261
  device_type="tpu7x-8",
243
262
  requires_workload_policy=True,
244
263
  supports_sub_slicing=False,
264
+ docker_platform=DockerPlatform.ARM,
245
265
  )
246
266
 
247
267
  run_gke_node_pool_create_command(args, system, "1.2.3")
@@ -274,6 +294,7 @@ def test_placement_policy_not_created_for_non7x_tpu(
274
294
  accelerator_type=AcceleratorType.TPU,
275
295
  device_type="v6e-4",
276
296
  supports_sub_slicing=True,
297
+ docker_platform=DockerPlatform.ARM,
277
298
  )
278
299
 
279
300
  run_gke_node_pool_create_command(args, system, "1.2.3")
xpk/core/pathways.py CHANGED
@@ -333,3 +333,26 @@ def try_to_delete_pathwaysjob_first(args, workloads) -> bool:
333
333
  xpk_print(f'Delete Workload request returned ERROR {return_code}')
334
334
  return False
335
335
  return True
336
+
337
+
338
+ def get_pathways_machine_types(
339
+ project: str, zone: str
340
+ ) -> tuple[int, list[str]]:
341
+ # Identify machine types with sufficient allocatable capacity to
342
+ # schedule the Pathways pod. This filter ensures the selected node
343
+ # is large enough to handle the control plane workload plus GKE
344
+ # system overhead.
345
+ min_memory_mb = 233 * 1024
346
+ command = (
347
+ 'gcloud compute machine-types list --filter "guestCpus >= 49 AND memoryMb'
348
+ f' >= {min_memory_mb} AND zone = \'{zone}\'" --format="value(name)"'
349
+ f' --project={project}'
350
+ )
351
+ return_code, result = run_command_for_value(
352
+ command=command,
353
+ task='Retrieve available pathways machine types',
354
+ dry_run_return_val='n2-standard-64',
355
+ )
356
+ if return_code != 0:
357
+ return return_code, []
358
+ return 0, result.strip().splitlines()