xpk 0.14.0__py3-none-any.whl → 0.14.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. integration/__init__.py +15 -0
  2. integration/docker_manager_test.py +102 -0
  3. integration/gcluster_a3mega_test.py +204 -0
  4. integration/gcluster_a3ultra_test.py +176 -0
  5. integration/gcluster_a4_test.py +176 -0
  6. integration/gcluster_test.py +107 -0
  7. xpk/commands/cluster.py +17 -4
  8. xpk/commands/cluster_gcluster.py +4 -0
  9. xpk/commands/cluster_test.py +92 -0
  10. xpk/commands/common.py +6 -0
  11. xpk/commands/kind.py +1 -0
  12. xpk/commands/workload.py +41 -7
  13. xpk/commands/workload_test.py +81 -0
  14. xpk/core/blueprint/testing/__init__.py +15 -0
  15. xpk/core/config.py +1 -1
  16. xpk/core/kueue_manager.py +60 -20
  17. xpk/core/kueue_manager_test.py +52 -20
  18. xpk/core/system_characteristics.py +16 -4
  19. xpk/core/system_characteristics_test.py +73 -0
  20. xpk/templates/cluster_preheat.yaml.j2 +31 -0
  21. xpk/templates/filestore-pv.yaml +17 -0
  22. xpk/templates/filestore-pvc.yaml +11 -0
  23. xpk/templates/filestore-sc.yaml +10 -0
  24. xpk/templates/fuse-pv.yaml +17 -0
  25. xpk/templates/fuse-pvc.yaml +13 -0
  26. xpk/templates/kueue_config.yaml.j2 +95 -0
  27. xpk/templates/kueue_gke_default_topology.yaml.j2 +10 -0
  28. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +14 -0
  29. xpk/templates/mtc-cpc.yaml +15 -0
  30. xpk/templates/volume_bundle.yaml +7 -0
  31. xpk/utils/templates.py +14 -1
  32. xpk/utils/topology.py +9 -0
  33. xpk/utils/topology_test.py +21 -1
  34. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/METADATA +1 -1
  35. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/RECORD +39 -18
  36. xpk-0.14.1.dist-info/top_level.txt +2 -0
  37. xpk-0.14.0.dist-info/top_level.txt +0 -1
  38. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/WHEEL +0 -0
  39. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/entry_points.txt +0 -0
  40. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/licenses/LICENSE +0 -0
integration/gcluster_test.py ADDED
@@ -0,0 +1,107 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from xpk.core.docker_manager import DockerManager
+from xpk.core.gcluster_manager import GclusterManager
+from xpk.core.blueprint.blueprint_generator import BlueprintGenerator
+import os
+import pytest
+import shutil
+
+ctk_gcloud_cfg = os.getenv("GCLOUD_CFG_PATH")
+project_id = os.getenv("PROJECT_ID")
+region = os.getenv("REGION")
+zone = os.getenv("ZONE")
+auth_cidr = os.getenv("AUTH_CIDR")
+cluster_name = os.getenv("GKE_ML_TEST_CLUSTER_NAME")
+
+uploads_dir = "uploads"
+
+
+def prepare_test(docker_path: str, bp_path: str) -> None:
+  if not os.path.exists(docker_path):
+    os.makedirs(docker_path)
+  if not os.path.exists(bp_path):
+    os.makedirs(bp_path)
+
+
+@pytest.mark.skip(reason="Credentails not working. Skipping for now")
+def test_create_deployment():
+  assert project_id is not None
+  assert region is not None
+  assert zone is not None
+  assert auth_cidr is not None
+  assert ctk_gcloud_cfg is not None
+  assert cluster_name is not None
+
+  pwd = os.getcwd()
+  test_docker_working_dir = os.path.join(
+      pwd, "xpkclusters/tests/xpk_test_docker_dir"
+  )
+  test_bp_dir = os.path.join(pwd, "xpkclusters/tests/xpk_test_bp_dir")
+  prepare_test(test_docker_working_dir, test_bp_dir)
+  blueprint_name = "my-test-blueprint"
+  prefix = "prefix"
+
+  docker_manager = DockerManager(
+      gcloud_cfg_path=ctk_gcloud_cfg, working_dir=test_docker_working_dir
+  )
+  docker_manager.initialize()
+
+  bpm = BlueprintGenerator(storage_path=test_bp_dir)
+  ml_gke_blueprint = bpm.generate_gke_ml_blueprint(
+      cluster_name=cluster_name,
+      blueprint_name=blueprint_name,
+      prefix=prefix,
+      region=region,
+      project_id=project_id,
+      auth_cidr=auth_cidr,
+  )
+  blueprint_test_path = os.path.join(
+      test_bp_dir, prefix, f"{blueprint_name}.yaml"
+  )
+  # there are no files in ghcp stage for this blueprint
+  blueprint_deps_test_path = ""
+
+  assert ml_gke_blueprint.blueprint_file == blueprint_test_path
+  assert ml_gke_blueprint.blueprint_dependencies == blueprint_deps_test_path
+
+  assert os.path.exists(blueprint_test_path)
+
+  gcluster_manager = GclusterManager(
+      gcluster_command_runner=docker_manager, remote_state_client=None
+  )
+
+  staged_bp_path = gcluster_manager.stage_files(
+      blueprint_file=ml_gke_blueprint.blueprint_file,
+      blueprint_dependencies=ml_gke_blueprint.blueprint_dependencies,
+      prefix=prefix,
+  )
+
+  assert staged_bp_path == os.path.join(
+      "/out", uploads_dir, prefix, f"{blueprint_name}.yaml"
+  )
+
+  gcluster_manager.deploy(
+      blueprint_path=staged_bp_path,
+      deployment_name=blueprint_name,
+      prefix=prefix,
+  )
+  gcluster_manager.destroy_deployment(
+      deployment_name=blueprint_name, prefix=prefix
+  )
+  shutil.rmtree(test_docker_working_dir)
+  shutil.rmtree(test_bp_dir)
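The integration test above is skipped by default and drives a full blueprint generate, stage, deploy, and destroy cycle against a real project. For a local run, the environment variables it reads would need to be set first; a minimal sketch with placeholder values (the variable names come from the test itself, everything else is illustrative):

# Hypothetical local setup for integration/gcluster_test.py.
# Variable names are taken from the test above; all values are placeholders.
import os

os.environ.update({
    "GCLOUD_CFG_PATH": "/home/me/.config/gcloud",  # gcloud config made available to gcluster
    "PROJECT_ID": "my-gcp-project",
    "REGION": "us-central1",
    "ZONE": "us-central1-a",
    "AUTH_CIDR": "10.0.0.0/8",  # authorized-network CIDR for the test cluster
    "GKE_ML_TEST_CLUSTER_NAME": "xpk-gcluster-it",
})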
xpk/commands/cluster.py CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 
 from tabulate import tabulate
 
+from ..utils.feature_flags import FeatureFlags
 from ..core.capacity import H100_DEVICE_TYPE, H200_DEVICE_TYPE, B200_DEVICE_TYPE
 from ..core.cluster import (
     get_all_clusters_programmatic,
@@ -75,9 +76,9 @@ from ..utils.file import write_tmp_file
 from ..utils.execution_context import is_dry_run
 from ..utils.validation import validate_dependencies_list, SystemDependency, should_validate_dependencies
 from . import cluster_gcluster
-from .common import set_cluster_command
+from .common import set_cluster_command, validate_sub_slicing_system
 from jinja2 import Environment, FileSystemLoader
-from ..utils.templates import TEMPLATE_PATH
+from ..utils.templates import get_templates_absolute_path
 import shutil
 import os
 
@@ -200,6 +201,11 @@ def cluster_adapt(args) -> None:
   xpk_exit(0)
 
 
+def _validate_cluster_create_args(args, system: SystemCharacteristics):
+  if FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing:
+    validate_sub_slicing_system(system)
+
+
 def cluster_create(args) -> None:
   """Function around cluster creation.
 
@@ -212,12 +218,14 @@ def cluster_create(args) -> None:
       SystemDependency.KJOB,
       SystemDependency.GCLOUD,
   ])
-  system, return_code = get_system_characteristics(args)
 
+  system, return_code = get_system_characteristics(args)
   if return_code > 0 or system is None:
     xpk_print('Fetching system characteristics failed!')
     xpk_exit(return_code)
 
+  _validate_cluster_create_args(args, system)
+
   xpk_print(f'Starting cluster create for cluster {args.cluster}:', flush=True)
   add_zone_and_project(args)
 
@@ -426,7 +434,9 @@ def cluster_cacheimage(args) -> None:
       system.accelerator_type
   ].accelerator_label
 
-  template_env = Environment(loader=FileSystemLoader(TEMPLATE_PATH))
+  template_env = Environment(
+      loader=FileSystemLoader(searchpath=get_templates_absolute_path())
+  )
   cluster_preheat_yaml = template_env.get_template(CLUSTER_PREHEAT_JINJA_FILE)
   rendered_yaml = cluster_preheat_yaml.render(
       cachekey=args.cache_key,
@@ -1251,6 +1261,9 @@ def install_kueue(args, system: SystemCharacteristics, autoprovisioning_config):
           memory_limit=args.memory_limit,
          cpu_limit=args.cpu_limit,
          is_pathways_cluster=args.enable_pathways,
+          configure_sub_slicing=(
+              FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing
+          ),
      ),
  )
 
xpk/commands/cluster_gcluster.py CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 
 import os
 
+from ..utils.feature_flags import FeatureFlags
 from ..utils.execution_context import is_dry_run
 from ..core.kueue_manager import KueueConfig, KueueManager
 from ..core.nap import enable_autoprovisioning_on_cluster
@@ -159,6 +160,9 @@ def __install_kueue(args) -> int:
          cpu_limit=args.cpu_limit,
          is_pathways_cluster=args.enable_pathways,
          flex=args.flex,
+          configure_sub_slicing=(
+              FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing
+          ),
      ),
      tolerations=tolerations,
  )
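Both the cluster.py and cluster_gcluster.py hunks wire the new option the same way: Sub-slicing validation and the configure_sub_slicing setting only take effect when the feature flag is on and the user requested sub-slicing. A self-contained sketch of that gate follows (class and attribute names mirror the hunks above; this is a simplified stand-in, not the actual xpk module):

# Stand-alone sketch of the feature-flag gate added in cluster create.
from argparse import Namespace
from dataclasses import dataclass


class FeatureFlags:
  SUB_SLICING_ENABLED = False  # presumably toggled by xpk's own flag handling


@dataclass
class System:
  device_type: str
  supports_sub_slicing: bool


def validate_sub_slicing_system(system: System) -> None:
  if not system.supports_sub_slicing:
    raise SystemExit(f'Error: {system.device_type} does not support Sub-slicing.')


def _validate_cluster_create_args(args, system: System) -> None:
  # Validation only fires when the flag is on AND the user asked for sub-slicing.
  if FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing:
    validate_sub_slicing_system(system)


# With the flag off (the default here), the request is simply ignored:
_validate_cluster_create_args(
    Namespace(sub_slicing=True), System('l4-1', supports_sub_slicing=False)
)

The unit tests added in cluster_test.py below exercise exactly this behavior against the real implementation.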
xpk/commands/cluster_test.py ADDED
@@ -0,0 +1,92 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from argparse import Namespace
+from dataclasses import dataclass
+from unittest.mock import MagicMock
+import pytest
+
+from xpk.commands.cluster import _validate_cluster_create_args
+from xpk.core.system_characteristics import SystemCharacteristics, UserFacingNameToSystemCharacteristics
+from xpk.utils.feature_flags import FeatureFlags
+
+
+@dataclass
+class _Mocks:
+  common_print_mock: MagicMock
+  common_exit_mock: MagicMock
+
+
+@pytest.fixture
+def mock_common_print_and_exit(mocker):
+  common_print_mock = mocker.patch(
+      'xpk.commands.common.xpk_print',
+      return_value=None,
+  )
+  common_exit_mock = mocker.patch(
+      'xpk.commands.common.xpk_exit',
+      return_value=None,
+  )
+  return _Mocks(
+      common_print_mock=common_print_mock, common_exit_mock=common_exit_mock
+  )
+
+
+DEFAULT_TEST_SYSTEM: SystemCharacteristics = (
+    UserFacingNameToSystemCharacteristics['l4-1']
+)
+SUB_SLICING_SYSTEM: SystemCharacteristics = (
+    UserFacingNameToSystemCharacteristics['v6e-4x4']
+)
+
+
+def test_validate_cluster_create_args_for_correct_args_pass(
+    mock_common_print_and_exit: _Mocks,
+):
+  args = Namespace()
+
+  _validate_cluster_create_args(args, DEFAULT_TEST_SYSTEM)
+
+  assert mock_common_print_and_exit.common_print_mock.call_count == 0
+  assert mock_common_print_and_exit.common_exit_mock.call_count == 0
+
+
+def test_validate_cluster_create_args_for_correct_sub_slicing_args_pass(
+    mock_common_print_and_exit: _Mocks,
+):
+  FeatureFlags.SUB_SLICING_ENABLED = True
+  args = Namespace(sub_slicing=True)
+
+  _validate_cluster_create_args(args, SUB_SLICING_SYSTEM)
+
+  assert mock_common_print_and_exit.common_print_mock.call_count == 0
+  assert mock_common_print_and_exit.common_exit_mock.call_count == 0
+
+
+def test_validate_cluster_create_args_for_not_supported_system_throws(
+    mock_common_print_and_exit: _Mocks,
+):
+  FeatureFlags.SUB_SLICING_ENABLED = True
+  args = Namespace(sub_slicing=True)
+
+  _validate_cluster_create_args(args, DEFAULT_TEST_SYSTEM)
+
+  assert mock_common_print_and_exit.common_print_mock.call_count == 1
+  assert (
+      mock_common_print_and_exit.common_print_mock.call_args[0][0]
+      == 'Error: l4-1 does not support Sub-slicing.'
+  )
+  assert mock_common_print_and_exit.common_exit_mock.call_count == 1
xpk/commands/common.py CHANGED
@@ -67,3 +67,9 @@ def is_TAS_possible(
       system_characteristics.device_type != H100_MEGA_DEVICE_TYPE
       or capacity_type == CapacityType.RESERVATION
   )
+
+
+def validate_sub_slicing_system(system: SystemCharacteristics):
+  if not system.supports_sub_slicing:
+    xpk_print(f'Error: {system.device_type} does not support Sub-slicing.')
+    xpk_exit(1)
xpk/commands/kind.py CHANGED
@@ -110,6 +110,7 @@ def cluster_create(args) -> None:
          cpu_limit=0,
          is_pathways_cluster=False,
          flex=False,
+          configure_sub_slicing=False,
      ),
  )
 
xpk/commands/workload.py CHANGED
@@ -52,7 +52,7 @@ from ..core.pathways import (
     get_user_workload_for_pathways,
     try_to_delete_pathwaysjob_first,
 )
-from ..core.resources import get_cluster_capacity_type, get_cluster_system_characteristics
+from ..core.resources import get_cluster_capacity_type, get_cluster_system_characteristics, SystemCharacteristics
 from ..core.resources import CLUSTER_METADATA_CONFIGMAP, get_cluster_configmap
 from ..core.scheduling import (
     check_if_workload_can_schedule,
@@ -78,6 +78,7 @@ from ..core.storage import (
 from ..core.system_characteristics import (
     AcceleratorType,
     get_system_characteristics,
+    compute_vms_per_slice,
 )
 from ..core.vertex import create_vertex_experiment
 from ..core.workload import (
@@ -98,7 +99,8 @@ from ..utils.file import write_tmp_file
 from ..utils.execution_context import is_dry_run
 from ..utils.validation import validate_dependencies_list, SystemDependency, should_validate_dependencies
 from . import cluster_gcluster
-from .common import is_TAS_possible
+from .common import is_TAS_possible, validate_sub_slicing_system
+from ..utils.topology import is_topology_contained
 from ..utils.feature_flags import FeatureFlags
 
 WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
@@ -120,8 +122,8 @@ spec:
     replicas: {args.num_slices}
     template:
       spec:
-        parallelism: {system.vms_per_slice}  # Equal to the number of VMs per slice
-        completions: {system.vms_per_slice}  # Same as the above.
+        parallelism: {vms_per_slice}  # Equal to the number of VMs per slice (or sub-slice).
+        completions: {vms_per_slice}  # Same as the above.
         backoffLimit: 0  # When any pod fails, the job is failed
        {pod_failure_policy}
        template:
@@ -280,6 +282,8 @@ PW_WORKLOAD_CREATE_YAML = """
 {user_workload}
 """
 
+SUB_SLICING_TOPOLOGIES = ['2x2', '2x4', '4x4', '4x8', '8x8', '8x16', '16x16']
+
 
 def workload_create_pathways(args) -> None:
   """Run jobset apply command for a file, specifically for Pathways.
@@ -330,13 +334,14 @@ def workload_create(args) -> None:
     )
     xpk_exit(1)
 
-  xpk_print('Starting workload create', flush=True)
   system, return_code = get_system_characteristics(args)
-
   if return_code > 0 or system is None:
     xpk_print('Fetching system characteristics failed!')
     xpk_exit(return_code)
 
+  if FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing_topology is not None:
+    _validate_sub_slicing_topology(system, args.sub_slicing_topology)
+
   if not check_if_workload_can_schedule(args, system):
     xpk_exit(1)
 
@@ -558,8 +563,14 @@
     )
     yml_string = WORKLOAD_CREATE_YAML.format(
         args=args,
-        system=system,
         container=container,
+        vms_per_slice=(
+            compute_vms_per_slice(args.sub_slicing_topology)
+            if system.accelerator_type == AcceleratorType['TPU']
+            and FeatureFlags.SUB_SLICING_ENABLED
+            and args.sub_slicing_topology is not None
+            else system.vms_per_slice
+        ),
         affinity=get_cpu_affinity(system.accelerator_type),
         accelerator_label=create_accelerator_label(
             system.accelerator_type, system
@@ -667,6 +678,29 @@ def workload_create(args) -> None:
   xpk_exit(0)
 
 
+def _validate_sub_slicing_topology(
+    system_characteristics: SystemCharacteristics, sub_slicing_topology: str
+) -> None:
+  if sub_slicing_topology not in SUB_SLICING_TOPOLOGIES:
+    xpk_print(
+        f'Error: --sub-slicing-topology={sub_slicing_topology} shape is'
+        f' invalid. It has to be one of: {", ".join(SUB_SLICING_TOPOLOGIES)}.'
+    )
+    xpk_exit(1)
+
+  if not is_topology_contained(
+      contained=sub_slicing_topology, container=system_characteristics.topology
+  ):
+    xpk_print(
+        f'Error: --sub-slicing-topology={sub_slicing_topology} shape is too'
+        ' large. The shape cannot be bigger than'
+        f' {system_characteristics.topology}.'
+    )
+    xpk_exit(1)
+
+  validate_sub_slicing_system(system_characteristics)
+
+
 def get_restart_exit_codes(args) -> list:
   exit_codes = [42]
   exit_codes.extend(range(127, 256, 1))
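The new _validate_sub_slicing_topology helper relies on is_topology_contained (xpk/utils/topology.py) to reject sub-slice shapes larger than the cluster's slice topology, while compute_vms_per_slice sizes the Job's parallelism for the requested sub-slice. Hypothetical, simplified stand-ins for those helpers follow, shown only to illustrate the idea; the real implementations may differ, for example in how chips map to VMs:

# Illustrative stand-ins for the topology helpers referenced above (hypothetical).
def _dims(topology: str) -> list[int]:
  # '4x8' -> [4, 8]
  return [int(d) for d in topology.split('x')]


def topology_contained(contained: str, container: str) -> bool:
  # A sub-slice fits if it has the same rank and no dimension exceeds the slice's.
  small, big = _dims(contained), _dims(container)
  return len(small) == len(big) and all(s <= b for s, b in zip(small, big))


def chips_in_topology(topology: str) -> int:
  # Total chip count is the product of the dimensions, e.g. '4x4' -> 16.
  count = 1
  for d in _dims(topology):
    count *= d
  return count


assert topology_contained('4x4', '8x8')
assert not topology_contained('16x16', '8x8')  # triggers the "too large" error above
assert chips_in_topology('4x8') == 32  # VM count would further depend on chips per VM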
xpk/commands/workload_test.py ADDED
@@ -0,0 +1,81 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import dataclasses
+from unittest.mock import MagicMock, patch
+import pytest
+from ..core.system_characteristics import SystemCharacteristics
+from .workload import _validate_sub_slicing_topology
+
+
+SYSTEM_CHARACTERISTICS = SystemCharacteristics(
+    topology='8x8',
+    vms_per_slice=1,
+    gke_accelerator='nvidia-l4',
+    gce_machine_type='g2-standard-12',
+    chips_per_vm=1,
+    accelerator_type=1,
+    device_type='l4-1',
+    supports_sub_slicing=True,
+    requires_workload_policy=False,
+)
+
+
+@pytest.fixture(autouse=True)
+def xpk_print(mocker):
+  return mocker.patch('xpk.commands.workload.xpk_print')
+
+
+def test_validate_sub_slicing_topology_exits_for_unsupported_topology(
+    xpk_print,
+):
+  with pytest.raises(SystemExit):
+    _validate_sub_slicing_topology(SYSTEM_CHARACTERISTICS, '2x1')
+
+  assert (
+      'shape is invalid. It has to be one of' in xpk_print.mock_calls[0].args[0]
+  )
+
+
+def test_validate_sub_slicing_topology_exits_for_too_large_topology(xpk_print):
+  with pytest.raises(SystemExit):
+    _validate_sub_slicing_topology(SYSTEM_CHARACTERISTICS, '16x16')
+
+  assert (
+      'shape is too large. The shape cannot be'
+      in xpk_print.mock_calls[0].args[0]
+  )
+
+
+def test_validate_sub_slicing_topology_does_nothing_for_supported_topology():
+  _validate_sub_slicing_topology(SYSTEM_CHARACTERISTICS, '4x4')
+
+
+@patch('xpk.commands.common.xpk_print')
+def test_validate_sub_slicing_topology_fails_for_unsupported_system(
+    common_xpk_print: MagicMock,
+):
+  unsupported_system = dataclasses.replace(
+      SYSTEM_CHARACTERISTICS, supports_sub_slicing=False
+  )
+
+  with pytest.raises(SystemExit):
+    _validate_sub_slicing_topology(unsupported_system, '4x4')
+
+  assert (
+      'l4-1 does not support Sub-slicing.'
+      in common_xpk_print.mock_calls[0].args[0]
+  )
xpk/core/blueprint/testing/__init__.py ADDED
@@ -0,0 +1,15 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
xpk/core/config.py CHANGED
@@ -22,7 +22,7 @@ from ..utils import file
 from ..utils.console import xpk_print
 
 # This is the version for XPK PyPI package
-__version__ = 'v0.14.0'
+__version__ = 'v0.14.1'
 XPK_CURRENT_VERSION = __version__
 XPK_CONFIG_FILE = os.path.expanduser('~/.config/xpk/config.yaml')
 
xpk/core/kueue_manager.py CHANGED
@@ -39,14 +39,16 @@ from ..core.commands import (
 )
 from ..utils.file import write_tmp_file
 from ..utils.console import xpk_print, xpk_exit
-from ..utils.templates import TEMPLATE_PATH
+from ..utils.templates import TEMPLATE_PATH, get_templates_absolute_path
 
 WAIT_FOR_KUEUE_TIMEOUT = "10m"
 CLUSTER_QUEUE_NAME = "cluster-queue"
 LOCAL_QUEUE_NAME = "multislice-queue"
+SUB_SLICE_TOPOLOGY_NAME = "sub-slice-topology"
 KUEUE_CONFIG_JINJA_FILE = "kueue_config.yaml.j2"
-KUEUE_TOPOLOGY_JINJA_FILE = "kueue_topology.yaml.j2"
+KUEUE_GKE_DEFAULT_TOPOLOGY_JINJA_FILE = "kueue_gke_default_topology.yaml.j2"
 KUEUE_CONTROLLER_MANAGER_JINJA_FILE = "kueue_controller_manager.yaml.j2"
+KUEUE_SUB_SLICING_TOPOLOGY_JINJA_FILE = "kueue_sub_slicing_topology.yaml.j2"
 MEMORY_SIZE_PER_VM = 1.2
 MIN_MEMORY_LIMIT_SIZE = 4096
 KUEUE_VERSION = "v0.14.1"
@@ -58,12 +60,19 @@ class KueueConfig:
   total_chips: int
   cpu_limit: int
   memory_limit: str
+  configure_sub_slicing: bool
   is_pathways_cluster: bool = False
   autoprovisioning_enabled: bool = False
   flex: bool = False
   num_slices: int = 1
 
 
+@dataclass
+class _NameAndYaml:
+  name: str
+  yaml: str
+
+
 class KueueManager:
   """Manages the installation and configuration of Kueue on an XPK cluster."""
 
@@ -73,7 +82,12 @@ class KueueManager:
       template_path=TEMPLATE_PATH,
   ):
     self.kueue_version = kueue_version
-    self.template_env = Environment(loader=FileSystemLoader(template_path))
+
+    self.template_env = Environment(
+        loader=FileSystemLoader(
+            searchpath=get_templates_absolute_path(template_path)
+        )
+    )
 
   def install_or_upgrade(
       self,
@@ -87,7 +101,7 @@ class KueueManager:
     Args:
       tolerations: An optional list of tolerations to apply to the kueue-controller-manager.
     """
-    return_code, installed_version = self.__get_installed_kueue_version()
+    return_code, installed_version = self.get_installed_kueue_version()
 
     if return_code == 0:
       if installed_version and installed_version > self.kueue_version:
@@ -107,7 +121,7 @@
 
     return self.__configure(kueue_config)
 
-  def __get_installed_kueue_version(self) -> tuple[int, str | None]:
+  def get_installed_kueue_version(self) -> tuple[int, str | None]:
     command = (
         "kubectl get deployment kueue-controller-manager -n kueue-system -o"
         " jsonpath='{.spec.template.spec.containers[0].image}'"
@@ -208,6 +222,13 @@
     """
     template = self.template_env.get_template(KUEUE_CONFIG_JINJA_FILE)
 
+    topology_name_and_yaml = self.__get_topology_name_and_yaml(
+        kueue_config.system, kueue_config.configure_sub_slicing
+    )
+    topology_name = (
+        topology_name_and_yaml.name if topology_name_and_yaml else None
+    )
+
     # The manager builds the context internally based on its opinionated logic
     context = self.__build_template_context(
         system=kueue_config.system,
@@ -218,18 +239,16 @@
         num_slices=kueue_config.num_slices,
         cpu_limit=kueue_config.cpu_limit,
         memory_limit=kueue_config.memory_limit,
+        topology_name=topology_name,
     )
 
-    rendered_manifest = template.render(context)
+    config_yaml = template.render(context)
+    yamls = [config_yaml]
 
-    if kueue_config.system.device_type in [
-        H100_MEGA_DEVICE_TYPE,
-        H200_DEVICE_TYPE,
-        B200_DEVICE_TYPE,
-    ]:
-      topology_yaml = self.template_env.get_template(KUEUE_TOPOLOGY_JINJA_FILE)
-      rendered_manifest = topology_yaml.render() + rendered_manifest
+    if topology_name_and_yaml:
+      yamls.append(topology_name_and_yaml.yaml)
 
+    rendered_manifest = "\n---\n".join(yamls)
     return_code = self.__apply_manifest(rendered_manifest)
     if return_code != 0:
       return return_code
@@ -246,6 +265,7 @@
       num_slices: int,
       cpu_limit: int,
       memory_limit: str,
+      topology_name: str | None,
   ) -> Dict[str, Any]:
     """Prepares the context for the Jinja2 template."""
     # Main accelerator flavor
@@ -267,13 +287,7 @@
       key, value = machine_label.split(":", 1)
       node_labels_dict[key] = value.strip()
 
-    topology_label = ""
-    if system.device_type in [
-        H100_MEGA_DEVICE_TYPE,
-        H200_DEVICE_TYPE,
-        B200_DEVICE_TYPE,
-    ]:
-      topology_label = 'topologyName: "gke-default"'
+    topology_label = f"topologyName: {topology_name}" if topology_name else ""
 
     flavors = [{
         "name": main_flavor_name,
@@ -335,6 +349,32 @@
         "admission_checks": admission_checks,
     }
 
+  def __get_topology_name_and_yaml(
+      self, system: SystemCharacteristics, configure_sub_slicing: bool
+  ) -> _NameAndYaml | None:
+    if system.device_type in [
+        H100_MEGA_DEVICE_TYPE,
+        H200_DEVICE_TYPE,
+        B200_DEVICE_TYPE,
+    ]:
+      return _NameAndYaml(
+          name="gke-default",
+          yaml=self.template_env.get_template(
+              KUEUE_GKE_DEFAULT_TOPOLOGY_JINJA_FILE
+          ).render(),
+      )
+    elif configure_sub_slicing:
+      return _NameAndYaml(
+          name=SUB_SLICE_TOPOLOGY_NAME,
+          yaml=self.template_env.get_template(
+              KUEUE_SUB_SLICING_TOPOLOGY_JINJA_FILE
+          ).render({
+              "sub_slice_topology_name": SUB_SLICE_TOPOLOGY_NAME,
+          }),
+      )
+    else:
+      return None
+
   def __apply_manifest(self, manifest: str) -> int:
     task = "Applying Kueue Custom Resources"
     if is_dry_run():
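The Kueue manager now renders the optional topology document separately (gke-default for the listed GPU types, a sub-slicing topology when configure_sub_slicing is set, otherwise none) and joins it with the main config as a multi-document YAML stream. A condensed, self-contained sketch of that assembly step, with template rendering replaced by placeholder strings:

# Sketch of the manifest assembly introduced above; the YAML bodies are
# placeholders standing in for the rendered Jinja templates.
def assemble_manifest(config_yaml: str, topology_yaml: str | None) -> str:
  yamls = [config_yaml]
  if topology_yaml is not None:
    yamls.append(topology_yaml)
  # kubectl accepts several documents in one stream separated by '---'.
  return "\n---\n".join(yamls)


print(assemble_manifest("kind: ClusterQueue", "kind: Topology"))  # two documents
print(assemble_manifest("kind: ClusterQueue", None))  # config only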