xpk 0.14.4__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. integration/README.md +19 -0
  2. integration/gcluster_a3mega_test.py +11 -0
  3. integration/gcluster_a3ultra_test.py +11 -0
  4. integration/gcluster_a4_test.py +11 -0
  5. xpk/blueprints/a3mega/config-map.yaml.tftpl +15 -0
  6. xpk/blueprints/a3mega/storage_crd.yaml +52 -0
  7. xpk/blueprints/a3ultra/config-map.yaml.tftpl +15 -0
  8. xpk/blueprints/a3ultra/mlgru-disable.yaml +59 -0
  9. xpk/blueprints/a3ultra/nccl-installer.yaml +95 -0
  10. xpk/blueprints/a3ultra/storage_crd.yaml +52 -0
  11. xpk/blueprints/a4/config-map.yaml.tftpl +15 -0
  12. xpk/blueprints/a4/nccl-rdma-installer-a4.yaml +66 -0
  13. xpk/blueprints/a4/storage_crd.yaml +52 -0
  14. xpk/commands/cluster.py +89 -32
  15. xpk/commands/cluster_gcluster.py +25 -5
  16. xpk/commands/cluster_gcluster_test.py +16 -3
  17. xpk/commands/cluster_test.py +353 -7
  18. xpk/commands/config.py +3 -5
  19. xpk/commands/inspector.py +5 -3
  20. xpk/commands/kind.py +3 -1
  21. xpk/commands/managed_ml_diagnostics.py +249 -0
  22. xpk/commands/managed_ml_diagnostics_test.py +146 -0
  23. xpk/commands/storage.py +8 -10
  24. xpk/commands/workload.py +143 -142
  25. xpk/commands/workload_test.py +160 -118
  26. xpk/core/blueprint/blueprint_generator.py +73 -33
  27. xpk/core/blueprint/blueprint_test.py +9 -0
  28. xpk/core/blueprint/testing/data/a3_mega.yaml +129 -0
  29. xpk/core/blueprint/testing/data/a3_mega_spot.yaml +125 -0
  30. xpk/core/blueprint/testing/data/a3_ultra.yaml +173 -0
  31. xpk/core/blueprint/testing/data/a4.yaml +185 -0
  32. xpk/core/capacity.py +48 -8
  33. xpk/core/capacity_test.py +32 -1
  34. xpk/core/cluster.py +55 -104
  35. xpk/core/cluster_test.py +170 -0
  36. xpk/core/commands.py +4 -10
  37. xpk/core/config.py +88 -7
  38. xpk/core/config_test.py +67 -11
  39. xpk/core/docker_container.py +3 -1
  40. xpk/core/docker_image.py +10 -6
  41. xpk/core/docker_resources.py +1 -10
  42. xpk/core/gcloud_context.py +18 -12
  43. xpk/core/gcloud_context_test.py +111 -1
  44. xpk/core/kjob.py +17 -19
  45. xpk/core/kueue_manager.py +205 -51
  46. xpk/core/kueue_manager_test.py +158 -4
  47. xpk/core/nap.py +13 -14
  48. xpk/core/nodepool.py +37 -43
  49. xpk/core/nodepool_test.py +42 -19
  50. xpk/core/pathways.py +23 -0
  51. xpk/core/pathways_test.py +57 -0
  52. xpk/core/resources.py +84 -27
  53. xpk/core/scheduling.py +144 -133
  54. xpk/core/scheduling_test.py +298 -6
  55. xpk/core/system_characteristics.py +256 -19
  56. xpk/core/system_characteristics_test.py +128 -5
  57. xpk/core/telemetry.py +263 -0
  58. xpk/core/telemetry_test.py +211 -0
  59. xpk/core/vertex.py +4 -3
  60. xpk/core/workload_decorators/tcpx_decorator.py +5 -1
  61. xpk/main.py +33 -13
  62. xpk/parser/cluster.py +40 -67
  63. xpk/parser/cluster_test.py +83 -3
  64. xpk/parser/common.py +84 -0
  65. xpk/parser/storage.py +10 -0
  66. xpk/parser/storage_test.py +47 -0
  67. xpk/parser/workload.py +14 -29
  68. xpk/parser/workload_test.py +3 -49
  69. xpk/telemetry_uploader.py +29 -0
  70. xpk/templates/arm_gpu_workload_crate.yaml.j2 +46 -0
  71. xpk/templates/kueue_gke_default_topology.yaml.j2 +1 -1
  72. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +3 -8
  73. xpk/utils/console.py +41 -10
  74. xpk/utils/console_test.py +106 -0
  75. xpk/utils/feature_flags.py +10 -1
  76. xpk/utils/file.py +4 -1
  77. xpk/utils/topology.py +4 -0
  78. xpk/utils/user_agent.py +35 -0
  79. xpk/utils/user_agent_test.py +44 -0
  80. xpk/utils/user_input.py +48 -0
  81. xpk/utils/user_input_test.py +92 -0
  82. xpk/utils/validation.py +2 -13
  83. xpk/utils/versions.py +31 -0
  84. xpk-0.16.0.dist-info/METADATA +127 -0
  85. xpk-0.16.0.dist-info/RECORD +168 -0
  86. xpk-0.14.4.dist-info/METADATA +0 -1645
  87. xpk-0.14.4.dist-info/RECORD +0 -139
  88. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/WHEEL +0 -0
  89. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/entry_points.txt +0 -0
  90. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/licenses/LICENSE +0 -0
  91. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/top_level.txt +0 -0
xpk/commands/cluster.py CHANGED
@@ -17,7 +17,9 @@ limitations under the License.
17
17
  from tabulate import tabulate
18
18
 
19
19
  from ..utils.feature_flags import FeatureFlags
20
- from ..core.capacity import H100_DEVICE_TYPE, H200_DEVICE_TYPE, B200_DEVICE_TYPE, get_reservation_deployment_type
20
+ from ..utils.versions import ReleaseChannel
21
+ from ..core.pathways import get_pathways_machine_types
22
+ from ..core.capacity import H100_DEVICE_TYPE, get_reservation_deployment_type
21
23
  from ..core.cluster import (
22
24
  get_all_clusters_programmatic,
23
25
  get_cluster_credentials,
@@ -26,7 +28,6 @@ from ..core.cluster import (
26
28
  set_jobset_on_cluster,
27
29
  set_pathways_job_on_cluster,
28
30
  setup_k8s_env,
29
- disable_mglru_on_cluster,
30
31
  count_nodes_on_cluster,
31
32
  update_cluster_with_gcpfilestore_driver_if_necessary,
32
33
  update_cluster_with_gcsfuse_driver_if_necessary,
@@ -38,6 +39,8 @@ from ..core.cluster import (
38
39
  from ..core.cluster_private import authorize_private_cluster_access_if_necessary
39
40
  from ..core.commands import run_command_for_value, run_command_with_updates
40
41
  from ..core.config import VERTEX_TENSORBOARD_FEATURE_FLAG
42
+ from ..core.telemetry import MetricsCollector, MetricsEventMetadataKey
43
+ from ..core.capacity import get_capacity_type
41
44
  from ..core.gcloud_context import (
42
45
  add_zone_and_project,
43
46
  get_gke_control_plane_version,
@@ -71,9 +74,9 @@ from ..core.system_characteristics import (
71
74
  )
72
75
  from ..core.vertex import create_vertex_tensorboard
73
76
  from ..core.workload import get_workload_list
74
- from ..utils.console import get_user_input, xpk_exit, xpk_print
77
+ from ..utils.console import ask_for_user_consent, xpk_exit, xpk_print
75
78
  from ..utils.file import write_tmp_file
76
- from ..utils.execution_context import is_dry_run
79
+ from ..utils.execution_context import is_dry_run, is_quiet
77
80
  from ..utils.validation import validate_dependencies_list, SystemDependency, should_validate_dependencies
78
81
  from . import cluster_gcluster
79
82
  from .common import set_cluster_command, validate_sub_slicing_system
@@ -81,6 +84,7 @@ from jinja2 import Environment, FileSystemLoader
81
84
  from ..utils.templates import get_templates_absolute_path
82
85
  import shutil
83
86
  import os
87
+ from .managed_ml_diagnostics import install_mldiagnostics_prerequisites
84
88
 
85
89
  CLUSTER_PREHEAT_JINJA_FILE = 'cluster_preheat.yaml.j2'
86
90
 
@@ -207,6 +211,25 @@ def _validate_cluster_create_args(args, system: SystemCharacteristics):
207
211
  if FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing:
208
212
  validate_sub_slicing_system(system)
209
213
  _validate_sub_slicing_reservation(args)
214
+ if args.enable_pathways:
215
+ _validate_pathways_machine(args)
216
+
217
+
218
+ def _validate_pathways_machine(args):
219
+ return_code, result = get_pathways_machine_types(
220
+ project=args.project, zone=args.zone
221
+ )
222
+ if return_code != 0:
223
+ xpk_print('Error: Unable to retrieve available pathways machine types')
224
+ xpk_exit(1)
225
+
226
+ if args.pathways_gce_machine_type not in result:
227
+ xpk_print(
228
+ 'Error: Invalid --pathways-gce-machine-type. Specify machine type that'
229
+ ' has at least 100GB of memory and at least 49 CPUs.'
230
+ )
231
+ xpk_print(f'Available machine types: {", ".join(result)}')
232
+ xpk_exit(1)
210
233
 
211
234
 
212
235
  def _validate_sub_slicing_reservation(args):
@@ -258,20 +281,19 @@ def cluster_create(args) -> None:
258
281
  xpk_print('Fetching system characteristics failed!')
259
282
  xpk_exit(return_code)
260
283
 
261
- _validate_cluster_create_args(args, system)
262
-
263
284
  xpk_print(f'Starting cluster create for cluster {args.cluster}:', flush=True)
264
285
  add_zone_and_project(args)
265
286
 
266
- if system.device_type in cluster_gcluster.supported_device_types:
267
- xpk_print(
268
- 'Creating the cluster using Cluster Toolkit. Machine Type:'
269
- f' {system.gce_machine_type} ...'
270
- )
271
- cluster_gcluster.cluster_create(args)
272
- xpk_exit(0)
287
+ _validate_cluster_create_args(args, system)
288
+ _log_cluster_create_telemetry(args)
289
+
290
+ release_channel = (
291
+ ReleaseChannel.REGULAR if args.gke_version else ReleaseChannel.RAPID
292
+ )
273
293
 
274
- return_code, gke_server_config = get_gke_server_config(args)
294
+ return_code, gke_server_config = get_gke_server_config(
295
+ args, release_channel=release_channel
296
+ )
275
297
  if return_code != 0 or gke_server_config is None:
276
298
  xpk_exit(return_code)
277
299
 
@@ -281,8 +303,20 @@ def cluster_create(args) -> None:
281
303
  if return_code != 0 or gke_control_plane_version is None:
282
304
  xpk_exit(return_code)
283
305
 
306
+ if system.device_type in cluster_gcluster.supported_device_types:
307
+ xpk_print(
308
+ 'Creating the cluster using Cluster Toolkit. Machine Type:'
309
+ f' {system.gce_machine_type} ...'
310
+ )
311
+ cluster_gcluster.cluster_create(
312
+ args,
313
+ gke_control_plane_version=gke_control_plane_version,
314
+ release_channel=release_channel,
315
+ )
316
+ xpk_exit(0)
317
+
284
318
  create_cluster_command_code = create_cluster_if_necessary(
285
- args, gke_control_plane_version, system
319
+ args, gke_control_plane_version, system, release_channel=release_channel
286
320
  )
287
321
  if create_cluster_command_code != 0:
288
322
  xpk_exit(create_cluster_command_code)
@@ -407,6 +441,13 @@ def cluster_create(args) -> None:
407
441
  # pylint: disable=line-too-long
408
442
  f' https://console.cloud.google.com/kubernetes/clusters/details/{get_cluster_location(args.project, args.cluster, args.zone)}/{args.cluster}/details?project={args.project}'
409
443
  )
444
+
445
+ if args.managed_mldiagnostics:
446
+ return_code = install_mldiagnostics_prerequisites()
447
+ if return_code != 0:
448
+ xpk_print('Installation of MLDiagnostics failed.')
449
+ xpk_exit(return_code)
450
+
410
451
  xpk_exit(0)
411
452
 
412
453
 
@@ -964,7 +1005,7 @@ def update_coredns() -> int:
964
1005
 
965
1006
  # 6. Scale up coredns and verify readiness
966
1007
  scale_up_coredns(replicas=15)
967
- verify_coredns_readiness(timeout=120)
1008
+ verify_coredns_readiness()
968
1009
 
969
1010
  xpk_print('The CoreDNS setup process has been completed.')
970
1011
 
@@ -1022,7 +1063,10 @@ def update_coredns_if_necessary() -> int:
1022
1063
 
1023
1064
 
1024
1065
  def create_cluster_if_necessary(
1025
- args, gke_control_plane_version: str, system: SystemCharacteristics
1066
+ args,
1067
+ gke_control_plane_version: str,
1068
+ system: SystemCharacteristics,
1069
+ release_channel: ReleaseChannel,
1026
1070
  ) -> int:
1027
1071
  """Creates cluster if not present in the project.
1028
1072
 
@@ -1043,7 +1087,7 @@ def create_cluster_if_necessary(
1043
1087
  return 0
1044
1088
  else:
1045
1089
  return run_gke_cluster_create_command(
1046
- args, gke_control_plane_version, system
1090
+ args, gke_control_plane_version, system, release_channel=release_channel
1047
1091
  )
1048
1092
 
1049
1093
 
@@ -1056,7 +1100,7 @@ def run_gke_cluster_delete_command(args) -> int:
1056
1100
  Returns:
1057
1101
  0 if successful and 1 otherwise.
1058
1102
  """
1059
- if not args.force:
1103
+ if not is_quiet():
1060
1104
  xpk_print('Get the name of the workloads in the cluster.')
1061
1105
  args.filter_by_status = 'EVERYTHING'
1062
1106
  return_code, return_value = get_workload_list(args)
@@ -1067,10 +1111,9 @@ def run_gke_cluster_delete_command(args) -> int:
1067
1111
  # Ignore Column Names line.
1068
1112
  if len(return_value) > 1:
1069
1113
  workloads = [x.split(' ')[0] for x in return_value.splitlines()][1:]
1070
- if workloads and not get_user_input(
1114
+ if workloads and not ask_for_user_consent(
1071
1115
  f'Planning to delete {len(workloads)} workloads in the cluster'
1072
- f' {args.cluster} including {workloads}. \nDo you wish to delete: y'
1073
- ' (yes) / n (no):\n'
1116
+ f' {args.cluster} including {workloads}. \nDo you wish to delete?'
1074
1117
  ):
1075
1118
  xpk_print('Skipping delete command.')
1076
1119
  return 0
@@ -1115,7 +1158,10 @@ def run_gke_clusters_list_command(args) -> int:
1115
1158
 
1116
1159
 
1117
1160
  def run_gke_cluster_create_command(
1118
- args, gke_control_plane_version: str, system: SystemCharacteristics
1161
+ args,
1162
+ gke_control_plane_version: str,
1163
+ system: SystemCharacteristics,
1164
+ release_channel: ReleaseChannel,
1119
1165
  ) -> int:
1120
1166
  """Run the Create GKE Cluster request.
1121
1167
 
@@ -1155,9 +1201,10 @@ def run_gke_cluster_create_command(
1155
1201
  ' --enable-dns-access'
1156
1202
  ' --autoscaling-profile=optimize-utilization'
1157
1203
  ' --labels=gke_product_type=xpk'
1204
+ f' --release-channel={release_channel.value.lower()}'
1158
1205
  )
1159
1206
 
1160
- if args.gke_version or system.accelerator_type == AcceleratorType.GPU:
1207
+ if args.gke_version:
1161
1208
  command += ' --no-enable-autoupgrade'
1162
1209
 
1163
1210
  enable_ip_alias = False
@@ -1199,7 +1246,8 @@ def run_gke_cluster_create_command(
1199
1246
 
1200
1247
  if args.enable_lustre_csi_driver:
1201
1248
  addons.append('LustreCsiDriver')
1202
- command += ' --enable-legacy-lustre-port'
1249
+ if args.enable_legacy_lustre_port:
1250
+ command += ' --enable-legacy-lustre-port'
1203
1251
 
1204
1252
  if hasattr(args, 'enable_mtc') and args.enable_mtc:
1205
1253
  addons.append('HighScaleCheckpointing')
@@ -1285,7 +1333,7 @@ def _install_kueue(
1285
1333
  else:
1286
1334
  # Determine total chips based on user specified topology.
1287
1335
  total_chips = get_total_chips_requested_from_args(args, system)
1288
- kueue_manager = KueueManager()
1336
+ kueue_manager = KueueManager(args.project, args.zone)
1289
1337
  return kueue_manager.install_or_upgrade(
1290
1338
  KueueConfig(
1291
1339
  system,
@@ -1299,7 +1347,7 @@ def _install_kueue(
1299
1347
  configure_sub_slicing=(
1300
1348
  FeatureFlags.SUB_SLICING_ENABLED and args.sub_slicing
1301
1349
  ),
1302
- ),
1350
+ )
1303
1351
  )
1304
1352
 
1305
1353
 
@@ -1315,8 +1363,17 @@ def prepare_gpus(system: SystemCharacteristics):
1315
1363
  if install_nri_code != 0:
1316
1364
  xpk_exit(install_nri_code)
1317
1365
 
1318
- if system.device_type in [H200_DEVICE_TYPE, B200_DEVICE_TYPE]:
1319
- xpk_print('Disabling MGLRU')
1320
- err_code = disable_mglru_on_cluster()
1321
- if err_code > 0:
1322
- xpk_exit(err_code)
1366
+
1367
+ def _log_cluster_create_telemetry(args) -> None:
1368
+ if FeatureFlags.TELEMETRY_ENABLED:
1369
+ capacity_type, _ = get_capacity_type(args)
1370
+ MetricsCollector.log_custom(
1371
+ name='cluster_create',
1372
+ metadata={
1373
+ MetricsEventMetadataKey.ZONE: args.zone,
1374
+ MetricsEventMetadataKey.SYSTEM_CHARACTERISTICS: (
1375
+ args.tpu_type if args.tpu_type else args.device_type
1376
+ ),
1377
+ MetricsEventMetadataKey.PROVISIONING_MODE: capacity_type.value,
1378
+ },
1379
+ )
@@ -17,6 +17,7 @@ limitations under the License.
17
17
  import os
18
18
 
19
19
  from ..utils.feature_flags import FeatureFlags
20
+ from ..utils.versions import ReleaseChannel
20
21
  from ..utils.execution_context import is_dry_run
21
22
  from ..core.kueue_manager import KueueConfig, KueueManager
22
23
  from ..core.nap import enable_autoprovisioning_on_cluster
@@ -51,11 +52,15 @@ gcluster_working_dir = os.path.abspath('xpkclusters/gcluster-out')
51
52
  gcloud_cfg_path = os.path.expanduser('~/.config/gcloud')
52
53
 
53
54
 
54
- def cluster_create(args) -> None:
55
+ def cluster_create(
56
+ args, gke_control_plane_version: str, release_channel: ReleaseChannel
57
+ ) -> None:
55
58
  """Function around cluster creation using Cluster toolkit.
56
59
 
57
60
  Args:
58
61
  args: user provided arguments for running the command.
62
+ gke_control_plane_version: the GKE version used for the new cluster.
63
+ release_channel: the release channel used for the new cluster.
59
64
 
60
65
  Returns:
61
66
  0 if successful and 1 otherwise.
@@ -79,7 +84,13 @@ def cluster_create(args) -> None:
79
84
  )
80
85
  gcm = prepare_gcluster_manager(remote_state_client)
81
86
 
82
- bp = generate_blueprint(blueprint_name=unique_name, args=args, prefix=prefix)
87
+ bp = generate_blueprint(
88
+ blueprint_name=unique_name,
89
+ args=args,
90
+ prefix=prefix,
91
+ gke_control_plane_version=gke_control_plane_version,
92
+ release_channel=release_channel,
93
+ )
83
94
 
84
95
  # staging: sending the blueprint file(s) to gcluster's working directory
85
96
  if is_dry_run():
@@ -141,7 +152,7 @@ def __install_kueue(args) -> int:
141
152
  else:
142
153
  # Determine total chips based on user specified topology.
143
154
  total_chips = get_total_chips_requested_from_args(args, system)
144
- kueue_manager = KueueManager()
155
+ kueue_manager = KueueManager(args.project, args.zone)
145
156
 
146
157
  tolerations = [{
147
158
  'key': 'components.gke.io/gke-managed-components',
@@ -149,7 +160,6 @@ def __install_kueue(args) -> int:
149
160
  'value': 'true',
150
161
  'effect': 'NoSchedule',
151
162
  }]
152
-
153
163
  kueue_manager.install_or_upgrade(
154
164
  KueueConfig(
155
165
  system,
@@ -287,7 +297,11 @@ def validate_state_gcs_bucket(args):
287
297
 
288
298
 
289
299
  def generate_blueprint(
290
- blueprint_name, args, prefix=None
300
+ blueprint_name,
301
+ args,
302
+ gke_control_plane_version: str,
303
+ release_channel: ReleaseChannel,
304
+ prefix=None,
291
305
  ) -> BlueprintGeneratorOutput:
292
306
  capacity_type, return_code = get_capacity_type(args)
293
307
  if return_code != 0:
@@ -342,6 +356,8 @@ def generate_blueprint(
342
356
  system_node_pool_machine_type=args.default_pool_cpu_machine_type,
343
357
  system_node_pool_min_node_count=args.default_pool_cpu_num_nodes,
344
358
  gcs_bucket=args.cluster_state_gcs_bucket,
359
+ cluster_version=gke_control_plane_version,
360
+ release_channel=release_channel,
345
361
  )
346
362
  if args.device_type == a3ultra_device_type:
347
363
  num_nodes = args.num_nodes if not args.num_nodes is None else 2
@@ -360,6 +376,8 @@ def generate_blueprint(
360
376
  system_node_pool_machine_type=args.default_pool_cpu_machine_type,
361
377
  system_node_pool_min_node_count=args.default_pool_cpu_num_nodes,
362
378
  gcs_bucket=args.cluster_state_gcs_bucket,
379
+ cluster_version=gke_control_plane_version,
380
+ release_channel=release_channel,
363
381
  )
364
382
  if args.device_type == a4_device_type:
365
383
  num_nodes = args.num_nodes if not args.num_nodes is None else 2
@@ -376,6 +394,8 @@ def generate_blueprint(
376
394
  capacity_type=capacity_type,
377
395
  system_node_pool_machine_type=args.default_pool_cpu_machine_type,
378
396
  system_node_pool_min_node_count=args.default_pool_cpu_num_nodes,
397
+ cluster_version=gke_control_plane_version,
398
+ release_channel=release_channel,
379
399
  )
380
400
  xpk_print('Device type is not supported.')
381
401
  xpk_exit(1)
@@ -20,7 +20,8 @@ import pytest
20
20
 
21
21
  from xpk.commands.cluster_gcluster import cluster_create
22
22
  from xpk.core.kueue_manager import KueueConfig
23
- from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics
23
+ from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics, DockerPlatform, GpuConfig
24
+ from xpk.utils.versions import ReleaseChannel
24
25
 
25
26
 
26
27
  @pytest.fixture
@@ -96,6 +97,8 @@ def test_install_kueue_standard(
96
97
  accelerator_type=AcceleratorType.GPU,
97
98
  device_type="h100-mega-80gb-8",
98
99
  supports_sub_slicing=False,
100
+ docker_platform=DockerPlatform.ARM,
101
+ gpu_config=GpuConfig(requires_topology=True),
99
102
  )
100
103
  mock_cluster_create_deps["get_system_characteristics"].return_value = (
101
104
  mock_system,
@@ -103,7 +106,11 @@ def test_install_kueue_standard(
103
106
  )
104
107
  mock_get_total_chips.return_value = 16
105
108
 
106
- cluster_create(mock_args)
109
+ cluster_create(
110
+ mock_args,
111
+ release_channel=ReleaseChannel.RAPID,
112
+ gke_control_plane_version="1.2.3",
113
+ )
107
114
 
108
115
  mock_cluster_create_deps["xpk_exit"].assert_called_with(0)
109
116
  mock_kueue_manager = mock_cluster_create_deps["KueueManager"]
@@ -143,6 +150,8 @@ def test_install_kueue_with_autoprovisioning(
143
150
  accelerator_type=AcceleratorType.GPU,
144
151
  device_type="h100-mega-80gb-8",
145
152
  supports_sub_slicing=False,
153
+ docker_platform=DockerPlatform.ARM,
154
+ gpu_config=GpuConfig(requires_topology=True),
146
155
  )
147
156
  mock_cluster_create_deps["get_system_characteristics"].return_value = (
148
157
  mock_system,
@@ -153,7 +162,11 @@ def test_install_kueue_with_autoprovisioning(
153
162
  mock_autoprovisioning_config.maximum_chips = 128
154
163
  mock_enable_autoprovisioning.return_value = (mock_autoprovisioning_config, 0)
155
164
 
156
- cluster_create(mock_args)
165
+ cluster_create(
166
+ mock_args,
167
+ release_channel=ReleaseChannel.RAPID,
168
+ gke_control_plane_version="1.2.3",
169
+ )
157
170
 
158
171
  mock_cluster_create_deps["xpk_exit"].assert_called_with(0)
159
172
  mock_enable_autoprovisioning.assert_called_once_with(mock_args, mock_system)