xpk 0.14.4__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. integration/README.md +19 -0
  2. integration/gcluster_a3mega_test.py +11 -0
  3. integration/gcluster_a3ultra_test.py +11 -0
  4. integration/gcluster_a4_test.py +11 -0
  5. xpk/blueprints/a3mega/config-map.yaml.tftpl +15 -0
  6. xpk/blueprints/a3mega/storage_crd.yaml +52 -0
  7. xpk/blueprints/a3ultra/config-map.yaml.tftpl +15 -0
  8. xpk/blueprints/a3ultra/mlgru-disable.yaml +59 -0
  9. xpk/blueprints/a3ultra/nccl-installer.yaml +95 -0
  10. xpk/blueprints/a3ultra/storage_crd.yaml +52 -0
  11. xpk/blueprints/a4/config-map.yaml.tftpl +15 -0
  12. xpk/blueprints/a4/nccl-rdma-installer-a4.yaml +66 -0
  13. xpk/blueprints/a4/storage_crd.yaml +52 -0
  14. xpk/commands/cluster.py +89 -32
  15. xpk/commands/cluster_gcluster.py +25 -5
  16. xpk/commands/cluster_gcluster_test.py +16 -3
  17. xpk/commands/cluster_test.py +353 -7
  18. xpk/commands/config.py +3 -5
  19. xpk/commands/inspector.py +5 -3
  20. xpk/commands/kind.py +3 -1
  21. xpk/commands/managed_ml_diagnostics.py +249 -0
  22. xpk/commands/managed_ml_diagnostics_test.py +146 -0
  23. xpk/commands/storage.py +8 -10
  24. xpk/commands/workload.py +143 -142
  25. xpk/commands/workload_test.py +160 -118
  26. xpk/core/blueprint/blueprint_generator.py +73 -33
  27. xpk/core/blueprint/blueprint_test.py +9 -0
  28. xpk/core/blueprint/testing/data/a3_mega.yaml +129 -0
  29. xpk/core/blueprint/testing/data/a3_mega_spot.yaml +125 -0
  30. xpk/core/blueprint/testing/data/a3_ultra.yaml +173 -0
  31. xpk/core/blueprint/testing/data/a4.yaml +185 -0
  32. xpk/core/capacity.py +48 -8
  33. xpk/core/capacity_test.py +32 -1
  34. xpk/core/cluster.py +55 -104
  35. xpk/core/cluster_test.py +170 -0
  36. xpk/core/commands.py +4 -10
  37. xpk/core/config.py +88 -7
  38. xpk/core/config_test.py +67 -11
  39. xpk/core/docker_container.py +3 -1
  40. xpk/core/docker_image.py +10 -6
  41. xpk/core/docker_resources.py +1 -10
  42. xpk/core/gcloud_context.py +18 -12
  43. xpk/core/gcloud_context_test.py +111 -1
  44. xpk/core/kjob.py +17 -19
  45. xpk/core/kueue_manager.py +205 -51
  46. xpk/core/kueue_manager_test.py +158 -4
  47. xpk/core/nap.py +13 -14
  48. xpk/core/nodepool.py +37 -43
  49. xpk/core/nodepool_test.py +42 -19
  50. xpk/core/pathways.py +23 -0
  51. xpk/core/pathways_test.py +57 -0
  52. xpk/core/resources.py +84 -27
  53. xpk/core/scheduling.py +144 -133
  54. xpk/core/scheduling_test.py +298 -6
  55. xpk/core/system_characteristics.py +256 -19
  56. xpk/core/system_characteristics_test.py +128 -5
  57. xpk/core/telemetry.py +263 -0
  58. xpk/core/telemetry_test.py +211 -0
  59. xpk/core/vertex.py +4 -3
  60. xpk/core/workload_decorators/tcpx_decorator.py +5 -1
  61. xpk/main.py +33 -13
  62. xpk/parser/cluster.py +40 -67
  63. xpk/parser/cluster_test.py +83 -3
  64. xpk/parser/common.py +84 -0
  65. xpk/parser/storage.py +10 -0
  66. xpk/parser/storage_test.py +47 -0
  67. xpk/parser/workload.py +14 -29
  68. xpk/parser/workload_test.py +3 -49
  69. xpk/telemetry_uploader.py +29 -0
  70. xpk/templates/arm_gpu_workload_crate.yaml.j2 +46 -0
  71. xpk/templates/kueue_gke_default_topology.yaml.j2 +1 -1
  72. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +3 -8
  73. xpk/utils/console.py +41 -10
  74. xpk/utils/console_test.py +106 -0
  75. xpk/utils/feature_flags.py +10 -1
  76. xpk/utils/file.py +4 -1
  77. xpk/utils/topology.py +4 -0
  78. xpk/utils/user_agent.py +35 -0
  79. xpk/utils/user_agent_test.py +44 -0
  80. xpk/utils/user_input.py +48 -0
  81. xpk/utils/user_input_test.py +92 -0
  82. xpk/utils/validation.py +2 -13
  83. xpk/utils/versions.py +31 -0
  84. xpk-0.16.0.dist-info/METADATA +127 -0
  85. xpk-0.16.0.dist-info/RECORD +168 -0
  86. xpk-0.14.4.dist-info/METADATA +0 -1645
  87. xpk-0.14.4.dist-info/RECORD +0 -139
  88. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/WHEEL +0 -0
  89. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/entry_points.txt +0 -0
  90. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/licenses/LICENSE +0 -0
  91. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/top_level.txt +0 -0
@@ -16,10 +16,13 @@ limitations under the License.
16
16
 
17
17
  import dataclasses
18
18
  from unittest.mock import MagicMock, patch
19
+ import yaml
19
20
  import pytest
20
- from ..core.system_characteristics import SystemCharacteristics, AcceleratorType
21
- from .workload import _validate_sub_slicing_topology, _validate_sub_slicing_availability
22
- from packaging.version import Version
21
+
22
+ from ..core.scheduling import WorkloadScheduling
23
+ from ..core.system_characteristics import DockerPlatform, SystemCharacteristics, AcceleratorType, UserFacingNameToSystemCharacteristics, GpuConfig
24
+ from .workload import workload_create
25
+ from .cluster_test import construct_args
23
26
 
24
27
 
25
28
  SYSTEM_CHARACTERISTICS = SystemCharacteristics(
@@ -32,133 +35,172 @@ SYSTEM_CHARACTERISTICS = SystemCharacteristics(
32
35
  device_type='l4-1',
33
36
  supports_sub_slicing=True,
34
37
  requires_workload_policy=False,
38
+ docker_platform=DockerPlatform.AMD,
35
39
  )
36
40
 
37
41
 
38
- @pytest.fixture(autouse=True)
42
+ @dataclasses.dataclass
43
+ class _WorkloadCreateMocks:
44
+ """Holds all the mocked dependencies for the workload_create function."""
45
+
46
+ get_user_workload_container: MagicMock
47
+ get_gpu_scheduler: MagicMock
48
+ get_storages_to_mount: MagicMock
49
+ add_bucket_iam_members: MagicMock
50
+ get_gke_outlier_dashboard: MagicMock
51
+ check_if_workload_exists: MagicMock
52
+ get_cluster_configmap: MagicMock
53
+ check_if_workload_can_schedule: MagicMock
54
+ setup_k8s_env: MagicMock
55
+ setup_k8s_service_accounts: MagicMock
56
+ validate_dependencies_list: MagicMock
57
+ write_tmp_file: MagicMock
58
+ get_cluster_capacity_type: MagicMock
59
+ is_TAS_possible: MagicMock
60
+ get_cluster_location: MagicMock
61
+ xpk_exit: MagicMock
62
+ run_command_with_updates: MagicMock
63
+ ensure_resource_policy_exists: MagicMock
64
+ get_cluster_subnetworks: MagicMock
65
+
66
+
67
+ @pytest.fixture
39
68
  def xpk_print(mocker):
40
69
  return mocker.patch('xpk.commands.workload.xpk_print')
41
70
 
42
71
 
43
- def test_validate_sub_slicing_topology_exits_for_unsupported_topology(
44
- xpk_print: MagicMock,
45
- ):
46
- with pytest.raises(SystemExit):
47
- _validate_sub_slicing_topology(SYSTEM_CHARACTERISTICS, '2x1')
48
-
49
- assert (
50
- 'shape is invalid. It has to be one of' in xpk_print.mock_calls[0].args[0]
51
- )
52
-
53
-
54
- def test_validate_sub_slicing_topology_exits_for_too_large_topology(
55
- xpk_print: MagicMock,
56
- ):
57
- with pytest.raises(SystemExit):
58
- _validate_sub_slicing_topology(SYSTEM_CHARACTERISTICS, '16x16')
59
-
60
- assert (
61
- 'shape is too large. The shape cannot be'
62
- in xpk_print.mock_calls[0].args[0]
63
- )
64
-
65
-
66
- def test_validate_sub_slicing_topology_does_nothing_for_supported_topology():
67
- _validate_sub_slicing_topology(SYSTEM_CHARACTERISTICS, '4x4')
68
-
69
-
70
- def test_validate_sub_slicing_availability_exits_when_getting_topologies_fails(
71
- xpk_print: MagicMock, mocker
72
- ):
73
- mocker.patch(
74
- 'xpk.commands.workload.has_sub_slicing_enabled',
75
- return_value=(1, None),
76
- )
77
- with pytest.raises(SystemExit):
78
- _validate_sub_slicing_availability()
79
-
80
- assert (
81
- 'Unable to validate sub-slicing support'
82
- in xpk_print.mock_calls[0].args[0]
83
- )
84
-
85
-
86
- def test_validate_sub_slicing_availability_exits_when_subslicing_topology_is_not_defined(
87
- xpk_print: MagicMock, mocker
88
- ):
89
- mocker.patch(
90
- 'xpk.commands.workload.has_sub_slicing_enabled',
91
- return_value=(0, False),
92
- )
93
- with pytest.raises(SystemExit):
94
- _validate_sub_slicing_availability()
95
-
96
- assert (
97
- 'Cluster has not been not set up for Sub-slicing.'
98
- in xpk_print.mock_calls[0].args[0]
72
+ @pytest.fixture
73
+ def workload_create_mocks(mocker) -> _WorkloadCreateMocks:
74
+ """Mocks all dependencies for the workload_create function."""
75
+ return _WorkloadCreateMocks(
76
+ get_user_workload_container=mocker.patch(
77
+ 'xpk.commands.workload.get_user_workload_container',
78
+ return_value=('', None),
79
+ ),
80
+ get_gpu_scheduler=mocker.patch(
81
+ 'xpk.commands.workload.get_gpu_scheduler', return_value=('', 0)
82
+ ),
83
+ get_storages_to_mount=mocker.patch(
84
+ 'xpk.commands.workload.get_storages_to_mount', return_value=[]
85
+ ),
86
+ add_bucket_iam_members=mocker.patch(
87
+ 'xpk.commands.workload.add_bucket_iam_members'
88
+ ),
89
+ get_gke_outlier_dashboard=mocker.patch(
90
+ 'xpk.commands.workload.get_gke_outlier_dashboard'
91
+ ),
92
+ check_if_workload_exists=mocker.patch(
93
+ 'xpk.commands.workload.check_if_workload_exists', return_value=False
94
+ ),
95
+ get_cluster_configmap=mocker.patch(
96
+ 'xpk.commands.workload.get_cluster_configmap', return_value={}
97
+ ),
98
+ check_if_workload_can_schedule=mocker.patch(
99
+ 'xpk.commands.workload.check_if_workload_can_schedule',
100
+ return_value=WorkloadScheduling.AVAILABLE,
101
+ ),
102
+ setup_k8s_env=mocker.patch('xpk.commands.workload.setup_k8s_env'),
103
+ setup_k8s_service_accounts=mocker.patch(
104
+ 'xpk.commands.workload.setup_k8s_service_accounts'
105
+ ),
106
+ validate_dependencies_list=mocker.patch(
107
+ 'xpk.commands.workload.validate_dependencies_list'
108
+ ),
109
+ write_tmp_file=mocker.patch('xpk.commands.workload.write_tmp_file'),
110
+ get_cluster_capacity_type=mocker.patch(
111
+ 'xpk.commands.workload.get_cluster_capacity_type',
112
+ return_value='on-demand',
113
+ ),
114
+ is_TAS_possible=mocker.patch(
115
+ 'xpk.commands.workload.is_TAS_possible', return_value=False
116
+ ),
117
+ get_cluster_location=mocker.patch(
118
+ 'xpk.commands.workload.get_cluster_location',
119
+ return_value='us-central1',
120
+ ),
121
+ xpk_exit=mocker.patch('xpk.commands.workload.xpk_exit'),
122
+ run_command_with_updates=mocker.patch(
123
+ 'xpk.commands.workload.run_command_with_updates', return_value=0
124
+ ),
125
+ ensure_resource_policy_exists=mocker.patch(
126
+ 'xpk.commands.workload.ensure_resource_policy_exists'
127
+ ),
128
+ get_cluster_subnetworks=mocker.patch(
129
+ 'xpk.commands.workload.get_cluster_subnetworks', return_value=[]
130
+ ),
99
131
  )
100
132
 
101
133
 
102
- def test_validate_sub_slicing_availability_exits_when_kueue_version_cannot_be_determined(
103
- xpk_print: MagicMock, mocker
134
+ def test_workload_create_for_a4x_has_arm_toleration(
135
+ workload_create_mocks: _WorkloadCreateMocks,
104
136
  ):
105
- mocker.patch(
106
- 'xpk.commands.workload.has_sub_slicing_enabled',
107
- return_value=(0, True),
137
+ """Tests that the generated YAML for an A4X workload has arm64 toleration."""
138
+ # Copy and overwrite the decorator with a no-op lambda.
139
+ gb200_system_chars = UserFacingNameToSystemCharacteristics['gb200-4']
140
+ gb200_system_chars_no_decorator = dataclasses.replace(
141
+ gb200_system_chars,
142
+ gpu_config=GpuConfig(
143
+ requires_topology=False, jobset_decorator_fn=lambda yml, *_: yml
144
+ ),
108
145
  )
109
- mocker.patch(
110
- 'xpk.commands.workload.KueueManager.get_installed_kueue_version',
111
- return_value=(1, None),
112
- )
113
- with pytest.raises(SystemExit):
114
- _validate_sub_slicing_availability()
115
-
116
- assert 'Unable to validate sub-slicing' in xpk_print.mock_calls[0].args[0]
146
+ # Patch the function that returns the system characteristics
147
+ # to return our modified object.
148
+ with patch(
149
+ 'xpk.commands.workload.get_system_characteristics',
150
+ return_value=(gb200_system_chars_no_decorator, 0),
151
+ ):
152
+ args = construct_args(
153
+ device_type='gb200-4',
154
+ workload='test-workload',
155
+ command='echo hello',
156
+ num_nodes=1,
157
+ restart_on_exit_codes=None,
158
+ )
159
+ workload_create(args)
160
+
161
+ assert workload_create_mocks.write_tmp_file.called
162
+ yaml_content = workload_create_mocks.write_tmp_file.call_args[0][0]
163
+ jobset = yaml.safe_load(yaml_content)
164
+
165
+ tolerations = jobset['spec']['replicatedJobs'][0]['template']['spec'][
166
+ 'template'
167
+ ]['spec']['tolerations']
168
+ assert {
169
+ 'key': 'kubernetes.io/arch',
170
+ 'operator': 'Equal',
171
+ 'value': 'arm64',
172
+ 'effect': 'NoSchedule',
173
+ } in tolerations
174
+
175
+
176
+ def test_workload_create_dry_run_with_output_file(mocker):
177
+ args = MagicMock()
178
+ args.workload = 'test-workload'
179
+ args.output_manifest_file = 'manifest.yaml'
180
+ args.use_pathways = False
181
+ args.use_vertex_tensorboard = False
182
+ args.project = 'test-project'
183
+ args.cluster = 'test-cluster'
184
+ args.zone = 'test-zone'
185
+ args.sub_slicing_topology = None
186
+
187
+ # Mock dependencies to avoid external calls and simulate state
188
+ mocker.patch('xpk.utils.execution_context.dry_run', True)
189
+ mocks = {
190
+ 'get_system_characteristics': (SYSTEM_CHARACTERISTICS, 0),
191
+ 'get_user_workload_container': ('container_yaml', None),
192
+ 'write_tmp_file': 'tmp_file',
193
+ 'parse_env_config': None,
194
+ }
195
+ for name, return_value in mocks.items():
196
+ mocker.patch(f'xpk.commands.workload.{name}', return_value=return_value)
197
+
198
+ mock_open = mocker.patch('builtins.open', mocker.mock_open())
117
199
 
118
-
119
- def test_validate_sub_slicing_availability_exits_when_kueue_version_does_not_meet_minimum_requirements(
120
- xpk_print: MagicMock, mocker
121
- ):
122
- mocker.patch(
123
- 'xpk.commands.workload.has_sub_slicing_enabled',
124
- return_value=(0, True),
125
- )
126
- mocker.patch(
127
- 'xpk.commands.workload.KueueManager.get_installed_kueue_version',
128
- return_value=(0, Version('0.0.0')),
129
- )
130
200
  with pytest.raises(SystemExit):
131
- _validate_sub_slicing_availability()
132
-
133
- assert 'The minimal required version is' in xpk_print.mock_calls[0].args[0]
134
-
201
+ workload_create(args)
135
202
 
136
- def test_validate_sub_slicing_availability_does_nothing_when_cluster_is_correctly_configured_for_subslicing(
137
- mocker,
138
- ):
139
- mocker.patch(
140
- 'xpk.commands.workload.has_sub_slicing_enabled',
141
- return_value=(0, True),
142
- )
143
- mocker.patch(
144
- 'xpk.commands.workload.KueueManager.get_installed_kueue_version',
145
- return_value=(0, Version('0.13.0')),
146
- )
147
- _validate_sub_slicing_availability()
148
-
149
-
150
- @patch('xpk.commands.common.xpk_print')
151
- def test_validate_sub_slicing_topology_fails_for_unsupported_system(
152
- common_xpk_print: MagicMock,
153
- ):
154
- unsupported_system = dataclasses.replace(
155
- SYSTEM_CHARACTERISTICS, supports_sub_slicing=False
156
- )
157
-
158
- with pytest.raises(SystemExit):
159
- _validate_sub_slicing_topology(unsupported_system, '4x4')
160
-
161
- assert (
162
- 'l4-1 does not support Sub-slicing.'
163
- in common_xpk_print.mock_calls[0].args[0]
164
- )
203
+ mock_open.assert_called_once_with('manifest.yaml', 'w', encoding='utf-8')
204
+ written_content = mock_open.return_value.write.call_args[0][0]
205
+ assert 'test-workload' in written_content
206
+ assert 'cloud.google.com/gke-tpu-topology: 8x8' in written_content
@@ -19,15 +19,20 @@ import shutil
19
19
  from typing import Optional
20
20
 
21
21
  from ruamel import yaml
22
+ from packaging.version import parse
22
23
 
23
24
  from ...utils.console import xpk_exit, xpk_print
25
+ from ...utils.versions import ReleaseChannel
24
26
  from ...utils.file import ensure_directory_exists
25
27
 
28
+
26
29
  from ..capacity import (
27
30
  H100_DEVICE_TYPE,
28
31
  B200_DEVICE_TYPE,
29
32
  H100_MEGA_DEVICE_TYPE,
30
33
  H200_DEVICE_TYPE,
34
+ GB200_DEVICE_TYPE,
35
+ GB200_DEVICE_TYPE_NOLSSD,
31
36
  CapacityType,
32
37
  )
33
38
  from ..system_characteristics import get_system_characteristics_by_device_type
@@ -39,6 +44,7 @@ a3high_device_type = H100_DEVICE_TYPE
39
44
  a3mega_device_type = H100_MEGA_DEVICE_TYPE
40
45
  a3ultra_device_type = H200_DEVICE_TYPE
41
46
  a4_device_type = B200_DEVICE_TYPE
47
+ a4x_device_types = (GB200_DEVICE_TYPE, GB200_DEVICE_TYPE_NOLSSD)
42
48
  supported_device_types = {
43
49
  a3mega_device_type,
44
50
  a3ultra_device_type,
@@ -84,6 +90,8 @@ class BlueprintGenerator:
84
90
  region: str,
85
91
  zone: str,
86
92
  auth_cidr: str,
93
+ cluster_version: str,
94
+ release_channel: ReleaseChannel,
87
95
  prefix: str = "",
88
96
  num_nodes: int = 2,
89
97
  pods_ip_cidr_range: str = "10.4.0.0/14",
@@ -142,11 +150,17 @@ class BlueprintGenerator:
142
150
  },
143
151
  )
144
152
 
153
+ sanitized_version = cluster_version.replace("-", "+", 1)
154
+ version = parse(sanitized_version)
155
+ version_prefix = f"{version.major}.{version.minor}"
145
156
  gke_cluster = DeploymentModule(
146
157
  id="gke_cluster",
147
158
  source="modules/scheduler/gke-cluster",
148
159
  use=[primary_vpc_name, gpu_subnets_name],
149
160
  settings={
161
+ "release_channel": release_channel.value,
162
+ "version_prefix": version_prefix,
163
+ "min_master_version": cluster_version,
150
164
  "prefix_with_deployment_name": False,
151
165
  "name_suffix": cluster_name,
152
166
  "enable_private_endpoint": False,
@@ -171,6 +185,16 @@ class BlueprintGenerator:
171
185
  },
172
186
  outputs=["instructions"],
173
187
  )
188
+ if release_channel != ReleaseChannel.RAPID:
189
+ gke_cluster.set_setting(
190
+ "maintenance_exclusions",
191
+ [{
192
+ "name": "no-minor-or-node-upgrades-indefinite",
193
+ "start_time": "2024-12-01T00:00:00Z",
194
+ "end_time": "2026-01-16T00:00:00Z",
195
+ "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
196
+ }],
197
+ )
174
198
 
175
199
  group_placement_0 = DeploymentModule(
176
200
  id="group_placement_0",
@@ -215,6 +239,9 @@ class BlueprintGenerator:
215
239
  else:
216
240
  a3_megagpu_pool_0.update_settings({"static_node_count": num_nodes})
217
241
 
242
+ if release_channel == ReleaseChannel.RAPID:
243
+ a3_megagpu_pool_0.set_setting("auto_upgrade", True)
244
+
218
245
  set_placement_policy = capacity_type != CapacityType.SPOT
219
246
  workload = DeploymentModule(
220
247
  id="workload_component_install",
@@ -391,6 +418,8 @@ class BlueprintGenerator:
391
418
  zone: str,
392
419
  auth_cidr: str,
393
420
  system_node_pool_machine_type: str,
421
+ cluster_version: str,
422
+ release_channel: ReleaseChannel,
394
423
  reservation: Optional[str | None] = None,
395
424
  gcs_bucket: Optional[str | None] = None,
396
425
  num_nodes: int = 2,
@@ -480,28 +509,19 @@ class BlueprintGenerator:
480
509
  },
481
510
  },
482
511
  )
512
+
513
+ sanitized_version = cluster_version.replace("-", "+", 1)
514
+ version = parse(sanitized_version)
515
+ version_prefix = f"{version.major}.{version.minor}"
483
516
  cluster_id = f"{cluster_name}-a3-ultragpu-cluster"
484
517
  a3_ultra_cluster = DeploymentModule(
485
518
  id=cluster_id,
486
519
  source="modules/scheduler/gke-cluster",
487
520
  use=[net_0_id],
488
521
  settings={
489
- "release_channel": (
490
- "UNSPECIFIED"
491
- if capacity_type == CapacityType.FLEX_START
492
- else "RAPID"
493
- ),
494
- "version_prefix": "1.32.",
495
- "maintenance_exclusions": (
496
- []
497
- if capacity_type == CapacityType.FLEX_START
498
- else [{
499
- "name": "no-minor-or-node-upgrades-indefinite",
500
- "start_time": "2024-12-01T00:00:00Z",
501
- "end_time": "2025-12-22T00:00:00Z",
502
- "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
503
- }]
504
- ),
522
+ "release_channel": release_channel.value,
523
+ "version_prefix": version_prefix,
524
+ "min_cluster_version": cluster_version,
505
525
  "prefix_with_deployment_name": False,
506
526
  "name_suffix": cluster_name,
507
527
  "system_node_pool_machine_type": system_node_pool_machine_type,
@@ -537,6 +557,17 @@ class BlueprintGenerator:
537
557
  },
538
558
  outputs=["instructions"],
539
559
  )
560
+ if release_channel != ReleaseChannel.RAPID:
561
+ a3_ultra_cluster.set_setting(
562
+ "maintenance_exclusions",
563
+ [{
564
+ "name": "no-minor-or-node-upgrades-indefinite",
565
+ "start_time": "2024-12-01T00:00:00Z",
566
+ "end_time": "2026-01-16T00:00:00Z",
567
+ "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
568
+ }],
569
+ )
570
+
540
571
  system, _ = get_system_characteristics_by_device_type(a3ultra_device_type)
541
572
  if system is None:
542
573
  xpk_print(
@@ -584,6 +615,9 @@ class BlueprintGenerator:
584
615
  else:
585
616
  gpu_pool.update_settings({"static_node_count": num_nodes})
586
617
 
618
+ if release_channel == ReleaseChannel.RAPID:
619
+ gpu_pool.set_setting("auto_upgrade", True)
620
+
587
621
  workload_manager_install_id = "workload-manager-install"
588
622
  workload_manager_install = DeploymentModule(
589
623
  id=workload_manager_install_id,
@@ -674,6 +708,8 @@ class BlueprintGenerator:
674
708
  zone: str,
675
709
  auth_cidr: str,
676
710
  system_node_pool_machine_type: str,
711
+ cluster_version: str,
712
+ release_channel: ReleaseChannel,
677
713
  reservation: Optional[str | None] = None,
678
714
  gcs_bucket: Optional[str | None] = None,
679
715
  num_nodes: int = 2,
@@ -761,12 +797,19 @@ class BlueprintGenerator:
761
797
  },
762
798
  },
763
799
  )
800
+
801
+ sanitized_version = cluster_version.replace("-", "+", 1)
802
+ version = parse(sanitized_version)
803
+ version_prefix = f"{version.major}.{version.minor}"
764
804
  cluster_id = f"{cluster_name}-a4-cluster"
765
805
  a4_cluster = DeploymentModule(
766
806
  id=cluster_id,
767
807
  source="modules/scheduler/gke-cluster",
768
808
  use=[net_0_id],
769
809
  settings={
810
+ "release_channel": release_channel.value,
811
+ "version_prefix": version_prefix,
812
+ "min_cluster_version": cluster_version,
770
813
  "system_node_pool_machine_type": system_node_pool_machine_type,
771
814
  "system_node_pool_node_count": {
772
815
  "total_min_nodes": system_node_pool_min_node_count,
@@ -791,25 +834,20 @@ class BlueprintGenerator:
791
834
  " alias_ip_range=[]}],"
792
835
  f" {cluster_name}-rdma-net.subnetwork_interfaces_gke))"
793
836
  ),
794
- "version_prefix": "1.32.",
795
- "release_channel": (
796
- "UNSPECIFIED"
797
- if capacity_type == CapacityType.FLEX_START
798
- else "RAPID"
799
- ),
800
- "maintenance_exclusions": (
801
- []
802
- if capacity_type == CapacityType.FLEX_START
803
- else [{
804
- "name": "no-minor-or-node-upgrades-indefinite",
805
- "start_time": "2024-12-01T00:00:00Z",
806
- "end_time": "2025-12-22T00:00:00Z",
807
- "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
808
- }]
809
- ),
810
837
  },
811
838
  outputs=["instructions"],
812
839
  )
840
+ if release_channel != ReleaseChannel.RAPID:
841
+ a4_cluster.set_setting(
842
+ "maintenance_exclusions",
843
+ [{
844
+ "name": "no-minor-or-node-upgrades-indefinite",
845
+ "start_time": "2024-12-01T00:00:00Z",
846
+ "end_time": "2026-01-16T00:00:00Z",
847
+ "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
848
+ }],
849
+ )
850
+
813
851
  system, _ = get_system_characteristics_by_device_type(a4_device_type)
814
852
  if system is None:
815
853
  xpk_print(
@@ -859,6 +897,9 @@ class BlueprintGenerator:
859
897
  else:
860
898
  gpu_pool.update_settings({"static_node_count": num_nodes})
861
899
 
900
+ if release_channel == ReleaseChannel.RAPID:
901
+ gpu_pool.set_setting("auto_upgrade", True)
902
+
862
903
  workload_manager_install_id = "workload-manager-install"
863
904
  workload_manager_install = DeploymentModule(
864
905
  id=workload_manager_install_id,
@@ -1019,7 +1060,6 @@ class BlueprintGenerator:
1019
1060
  "enable_flex_start": True,
1020
1061
  "enable_queued_provisioning": True,
1021
1062
  "autoscaling_total_min_nodes": 0,
1022
- "release_channel": "UNSPECIFIED",
1023
1063
  "auto_repair": False,
1024
1064
  "auto_upgrade": False,
1025
1065
  }
@@ -22,6 +22,7 @@ import ruamel.yaml
22
22
  from xpk.core.blueprint.blueprint_definitions import Blueprint
23
23
  from xpk.core.blueprint.blueprint_generator import BlueprintGenerator
24
24
  from xpk.core.capacity import CapacityType
25
+ from xpk.utils.versions import ReleaseChannel
25
26
 
26
27
  yaml = ruamel.yaml.YAML()
27
28
 
@@ -60,6 +61,8 @@ def test_generate_a3_mega_blueprint():
60
61
  reservation="test-reservation",
61
62
  capacity_type=CapacityType.RESERVATION,
62
63
  system_node_pool_min_node_count=5,
64
+ release_channel=ReleaseChannel.RAPID,
65
+ cluster_version="1.2.3",
63
66
  )
64
67
 
65
68
  assert bp.blueprint_file.endswith("/prefix/xpk-gke-a3-megagpu.yaml")
@@ -99,6 +102,8 @@ def test_generate_a3_mega_spot_blueprint():
99
102
  auth_cidr="10.0.0.0/32",
100
103
  capacity_type=CapacityType.SPOT,
101
104
  system_node_pool_min_node_count=5,
105
+ release_channel=ReleaseChannel.RAPID,
106
+ cluster_version="1.2.3",
102
107
  )
103
108
 
104
109
  assert bp.blueprint_file.endswith("/prefix/xpk-gke-a3-megagpu.yaml")
@@ -135,6 +140,8 @@ def test_generate_a3_ultra_blueprint():
135
140
  capacity_type=CapacityType.RESERVATION,
136
141
  gcs_bucket="test-bucket",
137
142
  prefix="testdir",
143
+ release_channel=ReleaseChannel.RAPID,
144
+ cluster_version="1.2.3",
138
145
  )
139
146
  with open(a3_ultra_yaml_test_path, encoding="utf-8") as stream:
140
147
  ctk_yaml = yaml.load(stream)
@@ -180,6 +187,8 @@ def test_generate_a4_blueprint():
180
187
  capacity_type=CapacityType.RESERVATION,
181
188
  gcs_bucket="test-bucket",
182
189
  prefix="testdir",
190
+ release_channel=ReleaseChannel.RAPID,
191
+ cluster_version="1.2.3",
183
192
  )
184
193
  with open(a4_yaml_test_path, encoding="utf-8") as stream:
185
194
  ctk_yaml = yaml.load(stream)