xpk-0.8.0-py3-none-any.whl → xpk-0.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. xpk/commands/batch.py +5 -6
  2. xpk/commands/cluster.py +246 -73
  3. xpk/commands/cluster_gcluster.py +27 -0
  4. xpk/commands/common.py +40 -1
  5. xpk/commands/kjob_common.py +13 -1
  6. xpk/commands/run.py +4 -5
  7. xpk/commands/shell.py +2 -2
  8. xpk/commands/storage.py +24 -6
  9. xpk/commands/workload.py +66 -27
  10. xpk/core/blueprint/blueprint_generator.py +115 -47
  11. xpk/core/capacity.py +66 -6
  12. xpk/core/cluster.py +282 -13
  13. xpk/core/config.py +1 -65
  14. xpk/core/docker_manager.py +1 -1
  15. xpk/core/docker_resources.py +145 -72
  16. xpk/core/filestore.py +2 -6
  17. xpk/core/gcsfuse.py +22 -4
  18. xpk/core/jobset.py +143 -0
  19. xpk/core/kjob.py +21 -18
  20. xpk/core/kueue.py +194 -4
  21. xpk/core/mtc.py +195 -0
  22. xpk/core/network.py +23 -1
  23. xpk/core/nodepool.py +17 -4
  24. xpk/core/pathways.py +2 -3
  25. xpk/core/resources.py +21 -0
  26. xpk/core/storage.py +1 -95
  27. xpk/core/system_characteristics.py +1 -1
  28. xpk/core/workload.py +1 -45
  29. xpk/core/workload_decorators/rdma_decorator.py +8 -10
  30. xpk/core/workload_decorators/tcpx_decorator.py +185 -0
  31. xpk/core/workload_decorators/tcpxo_decorator.py +22 -14
  32. xpk/parser/cluster.py +589 -389
  33. xpk/parser/storage.py +12 -3
  34. xpk/parser/workload.py +21 -3
  35. xpk/utils/kubectl.py +4 -1
  36. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/METADATA +178 -96
  37. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/RECORD +41 -38
  38. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/WHEEL +1 -1
  39. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/entry_points.txt +0 -0
  40. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/licenses/LICENSE +0 -0
  41. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/top_level.txt +0 -0
xpk/commands/storage.py CHANGED
@@ -29,6 +29,7 @@ from ..core.cluster import (
     setup_k8s_env,
     update_cluster_with_parallelstore_driver_if_necessary,
     update_cluster_with_pd_driver_if_necessary,
+    update_cluster_with_lustre_driver_if_necessary,
     update_cluster_with_gcpfilestore_driver_if_necessary,
     update_cluster_with_gcsfuse_driver_if_necessary,
     update_cluster_with_workload_identity_if_necessary,
@@ -45,6 +46,7 @@ from ..core.storage import (
     GCS_FUSE_TYPE,
     GCE_PD_TYPE,
     PARALLELSTORE_TYPE,
+    LUSTRE_TYPE,
     STORAGE_CRD_PLURAL,
     XPK_API_GROUP_NAME,
     XPK_API_GROUP_VERSION,
@@ -86,7 +88,6 @@ def storage_create(args: Namespace) -> None:
         args.vol,
         args.access_mode,
         filestore_network,
-        args.mount_options,
     )
 
   k8s_api_client = setup_k8s_env(args)
@@ -162,7 +163,6 @@ def storage_attach(args: Namespace) -> None:
         args.vol,
         args.access_mode,
         filestore_network,
-        args.mount_options,
     )
 
   elif args.type == GCS_FUSE_TYPE:
@@ -178,14 +178,18 @@ def storage_attach(args: Namespace) -> None:
       manifest = list(yaml.safe_load_all(f))
     else:
       manifest = gcsfuse.manifest(
-          args.name, args.bucket, args.size, args.mount_options
+          args.name,
+          args.bucket,
+          args.size,
+          args.mount_options,
+          args.prefetch_metadata,
       )
 
-  elif args.type in [PARALLELSTORE_TYPE, GCE_PD_TYPE]:
+  elif args.type in [PARALLELSTORE_TYPE, GCE_PD_TYPE, LUSTRE_TYPE]:
     if args.manifest is None:
       xpk_print(
-          "Parallelstore and PersistentDisk are currently supported only with"
-          " --manifest"
+          "Parallelstore, PersistentDisk, and Lustre are currently supported"
+          " only with --manifest"
       )
       xpk_exit(1)
 
@@ -232,6 +236,11 @@ def enable_csi_drivers_if_necessary(args: Namespace) -> None:
     if return_code > 0:
       xpk_exit(return_code)
 
+  if args.type == LUSTRE_TYPE:
+    return_code = update_cluster_with_lustre_driver_if_necessary(args)
+    if return_code > 0:
+      xpk_exit(return_code)
+
 
 def storage_list(args: Namespace) -> None:
   k8s_api_client = setup_k8s_env(args)
@@ -323,3 +332,12 @@ def delete_storage_resources(k8s_api_client: ApiClient, storage: Storage):
       storage.name,
       "Storage",
   )
+
+  # remove kubernetes.io/pvc-protection
+  delete_resource(
+      lambda name: core_api.patch_namespaced_persistent_volume_claim(
+          name, "default", {"metadata": {"finalizers": None}}
+      ),
+      storage.pvc,
+      "Persistent Volume Claim finalizers",
+  )
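The new cleanup step in delete_storage_resources clears the claim's finalizers so deletion is not left hanging on kubernetes.io/pvc-protection. A minimal standalone sketch of the same patch with the official Kubernetes Python client (not part of the diff; the PVC name and namespace are hypothetical examples):

# Minimal sketch, assuming a reachable cluster and the `kubernetes` package.
from kubernetes import client, config

config.load_kube_config()  # or load_incluster_config() when running in a pod
core_api = client.CoreV1Api()

# Clearing the finalizers (e.g. kubernetes.io/pvc-protection) lets a pending
# delete of the PersistentVolumeClaim complete.
core_api.patch_namespaced_persistent_volume_claim(
    name="my-lustre-pvc",  # hypothetical PVC name
    namespace="default",
    body={"metadata": {"finalizers": None}},
)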
xpk/commands/workload.py CHANGED
@@ -15,27 +15,24 @@ limitations under the License.
 """
 
 from ..core.blueprint.blueprint_generator import (
-    get_subnetworks_for_a3mega,
-    get_subnetworks_for_a3ultra,
-    get_subnetworks_for_a4,
+    a3high_device_type,
+    a3mega_device_type,
+    a3ultra_device_type,
+    a4_device_type,
 )
 from ..core.cluster import (
     XPK_SA,
-    create_xpk_k8s_service_account,
+    setup_k8s_service_accounts,
     get_cluster_credentials,
     setup_k8s_env,
 )
 from ..core.commands import run_command_with_updates, run_commands
-from ..core.config import (
-    VERTEX_TENSORBOARD_FEATURE_FLAG,
-    XPK_CURRENT_VERSION,
-    parse_env_config,
-)
+from ..core.config import (VERTEX_TENSORBOARD_FEATURE_FLAG, XPK_CURRENT_VERSION)
 from ..core.docker_container import (
     get_main_container_docker_image,
     get_user_workload_container,
 )
-from ..core.docker_resources import get_volumes
+from ..core.docker_resources import get_volumes, parse_env_config
 from ..core.gcloud_context import add_zone_and_project
 from ..core.kueue import LOCAL_QUEUE_NAME
 from ..core.monitoring import get_gke_outlier_dashboard
@@ -43,6 +40,7 @@ from ..core.nap import (
     get_autoprovisioning_node_selector_args,
     is_autoprovisioning_enabled,
 )
+from ..core.network import get_cluster_subnetworks
 from ..core.pathways import (
     append_custom_colocated_python_sidecar,
     append_custom_pathways_proxy_server,
@@ -54,6 +52,10 @@ from ..core.pathways import (
     get_user_workload_for_pathways,
     try_to_delete_pathwaysjob_first,
 )
+from ..core.resources import get_cluster_capacity_type, get_cluster_system_characteristics
+from ..core.capacity import (
+    CapacityType,
+)
 from ..core.resources import CLUSTER_METADATA_CONFIGMAP, get_cluster_configmap
 from ..core.scheduling import (
     check_if_workload_can_schedule,
@@ -69,6 +71,7 @@ from ..core.storage import (
     GCP_FILESTORE_TYPE,
     GCS_FUSE_TYPE,
     PARALLELSTORE_TYPE,
+    LUSTRE_TYPE,
     Storage,
     add_bucket_iam_members,
     get_storage_annotations,
@@ -80,7 +83,6 @@ from ..core.system_characteristics import (
 )
 from ..core.vertex import create_vertex_experiment
 from ..core.workload import (
-    add_gpu_rxdm_container,
     check_if_workload_exists,
     get_workload_list,
     wait_for_job_completion,
@@ -89,11 +91,13 @@ from ..core.workload import (
 from ..core.workload_decorators import (
     rdma_decorator,
     storage_decorator,
+    tcpx_decorator,
     tcpxo_decorator,
 )
 from ..utils.console import get_user_input, xpk_exit, xpk_print
 from ..utils.file import write_tmp_file
 from . import cluster_gcluster
+from .common import is_TAS_possible
 
 WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
 kind: JobSet
@@ -126,6 +130,8 @@ spec:
             {storage_annotations}
           spec:
             schedulerName: {args.scheduler}
+            imagePullSecrets:
+            - name: {args.docker_image_pull_secret}
             restartPolicy: Never
             {affinity}
             nodeSelector:
@@ -139,6 +145,8 @@ spec:
             containers:
             {container}
             serviceAccountName: {service_account}
+            tolerations:
+            {tpu_toleration}
             volumes:
             {volumes}
 """
@@ -178,6 +186,8 @@ spec:
             {gpu_scheduler}
             priorityClassName: {args.priority}
             restartPolicy: Never
+            imagePullSecrets:
+            - name: {args.docker_image_pull_secret}
             hostNetwork: true
             dnsPolicy: ClusterFirstWithHostNet
             terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
@@ -216,11 +226,12 @@ spec:
           metadata:
             labels:
               xpk.google.com/workload: {args.workload}
-            annotations:
-              kueue.x-k8s.io/podset-preferred-topology: "cloud.google.com/gce-topology-host"
+            annotations: {annotations}
           spec:
             priorityClassName: {args.priority}
             restartPolicy: Never
+            imagePullSecrets:
+            - name: {args.docker_image_pull_secret}
             dnsPolicy: ClusterFirstWithHostNet
             terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
             serviceAccountName: {service_account}
@@ -294,7 +305,7 @@ def workload_create(args) -> None:
     0 if successful and 1 otherwise.
   """
   k8s_api_client = setup_k8s_env(args)
-  create_xpk_k8s_service_account()
+  setup_k8s_service_accounts()
 
   workload_exists = check_if_workload_exists(args)
 
@@ -350,7 +361,7 @@ def workload_create(args) -> None:
     if not tensorboard_config:
       xpk_exit(1)
 
-  parse_env_config(args, tensorboard_config, system)
+  parse_env_config(args, tensorboard_config)
 
   autoprovisioning_args = ''
   autoprovisioning_enabled, return_code = is_autoprovisioning_enabled(
@@ -385,6 +396,9 @@ def workload_create(args) -> None:
   pd_storages: list[Storage] = list(
       filter(lambda storage: storage.type == GCE_PD_TYPE, storages)
   )
+  lustre_storages: list[Storage] = list(
+      filter(lambda storage: storage.type == LUSTRE_TYPE, storages)
+  )
   if len(gcs_fuse_storages) > 0:
     service_account = XPK_SA
     xpk_print(f'Detected gcsfuse Storages to add: {gcs_fuse_storages}')
@@ -406,7 +420,7 @@ def workload_create(args) -> None:
         f' {parallelstore_storages}'
     )
   else:
-    xpk_print('No gcp filestore instances to add detected.')
+    xpk_print('No gcp parallelstore instances to add detected.')
 
   if len(pd_storages) > 0:
     service_account = XPK_SA
@@ -414,11 +428,18 @@ def workload_create(args) -> None:
   else:
     xpk_print('No gce persistent disk instances to add detected.')
 
+  if len(lustre_storages) > 0:
+    service_account = XPK_SA
+    xpk_print(f'Detected managed lustre instances to add: {lustre_storages}')
+  else:
+    xpk_print('No managed lustre instances to add detected.')
+
   all_storages = (
       gcs_fuse_storages
       + gcpfilestore_storages
       + parallelstore_storages
      + pd_storages
+      + lustre_storages
   )
 
   # Currently failure policy rules are supported for Pathways workloads. b/408465881
@@ -450,26 +471,41 @@ def workload_create(args) -> None:
   )
   if return_code != 0:
     xpk_exit(return_code)
+  system_characteristics = get_cluster_system_characteristics(args)
+  capacity_type = get_cluster_capacity_type(args)
+
+  annotations = (
+      ''
+      if not is_TAS_possible(
+          system_characteristics,
+          capacity_type,
+          flex=True if capacity_type == CapacityType.FLEX_START else False,
+      )
+      else (
+          'kueue.x-k8s.io/podset-preferred-topology:'
+          ' "cloud.google.com/gce-topology-host"'
+      )
+  )
 
-  if system.device_type in cluster_gcluster.supported_device_types:
+  if (
+      system.device_type in cluster_gcluster.supported_device_types
+      or system.device_type == a3high_device_type
+  ):
     yml_string = A3_GPU_WORKLOAD_CREATE_YAML.format(
         args=args,
         container=container,
         service_account=XPK_SA,
         failure_policy_rules=failure_policy_rules,
         pod_failure_policy=pod_failure_policy,
+        annotations=annotations,
     )
 
-    if args.device_type == cluster_gcluster.a3mega_device_type:
-      sub_networks = get_subnetworks_for_a3mega(args.cluster)
+    sub_networks = get_cluster_subnetworks(args)
+    if args.device_type == a3high_device_type:
+      yml_string = tcpx_decorator.decorate_jobset(yml_string)
+    elif args.device_type == a3mega_device_type:
       yml_string = tcpxo_decorator.decorate_jobset(yml_string, sub_networks)
-
-    if args.device_type == cluster_gcluster.a3ultra_device_type:
-      sub_networks = get_subnetworks_for_a3ultra(args.cluster)
-      yml_string = rdma_decorator.decorate_jobset(yml_string, sub_networks)
-
-    if args.device_type == cluster_gcluster.a4_device_type:
-      sub_networks = get_subnetworks_for_a4()
+    elif args.device_type in [a3ultra_device_type, a4_device_type]:
       yml_string = rdma_decorator.decorate_jobset(yml_string, sub_networks)
 
     if all_storages:
@@ -487,7 +523,6 @@ def workload_create(args) -> None:
         failure_policy_rules=failure_policy_rules,
         pod_failure_policy=pod_failure_policy,
     )
-    yml_string = add_gpu_rxdm_container(yml_string, system, all_storages)
 
   elif args.use_pathways and ensure_pathways_workload_prerequisites(
       args, system
@@ -524,6 +559,10 @@ def workload_create(args) -> None:
             get_storage_annotations(all_storages)
         ),
         service_account=service_account,
+        tpu_toleration="""
+            - operator: "Exists"
+              key: google.com/tpu
+        """ if system.accelerator_type == AcceleratorType['TPU'] else '',
         failure_policy_rules=failure_policy_rules,
         pod_failure_policy=pod_failure_policy,
     )
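The hard-coded kueue.x-k8s.io/podset-preferred-topology annotation is replaced by an {annotations} placeholder that is filled only when topology-aware scheduling is possible for the cluster's capacity type. A minimal sketch of that selection (not part of the diff; is_TAS_possible below is a simplified stand-in for xpk.commands.common.is_TAS_possible, and the capacity rules shown are illustrative assumptions):

# Minimal sketch, not the xpk implementation.
from enum import Enum


class CapacityType(Enum):
  ON_DEMAND = "on_demand"
  SPOT = "spot"
  RESERVATION = "reservation"
  FLEX_START = "flex_start"


def is_TAS_possible(capacity_type: CapacityType, flex: bool) -> bool:
  # Assumption for illustration: TAS applies to reservations and DWS flex-start.
  return capacity_type == CapacityType.RESERVATION or flex


def tas_annotation(capacity_type: CapacityType) -> str:
  flex = capacity_type == CapacityType.FLEX_START
  if not is_TAS_possible(capacity_type, flex=flex):
    return ''  # the template then renders "annotations:" with no entries
  return (
      'kueue.x-k8s.io/podset-preferred-topology:'
      ' "cloud.google.com/gce-topology-host"'
  )


print(tas_annotation(CapacityType.RESERVATION))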
xpk/core/blueprint/blueprint_generator.py CHANGED
@@ -22,7 +22,9 @@ from ruamel import yaml
 
 from ...utils.console import xpk_exit, xpk_print
 from ...utils.file import ensure_directory_exists
+
 from ..capacity import (
+    H100_DEVICE_TYPE,
     B200_DEVICE_TYPE,
     H100_MEGA_DEVICE_TYPE,
     H200_DEVICE_TYPE,
@@ -30,10 +32,11 @@ from ..capacity import (
 )
 from ..system_characteristics import get_system_characteristics_by_device_type
 from .blueprint_definitions import Blueprint, DeploymentGroup, DeploymentModule
-
+from ..kueue import KUEUE_VERSION
 
 yaml = yaml.YAML()
 
+a3high_device_type = H100_DEVICE_TYPE
 a3mega_device_type = H100_MEGA_DEVICE_TYPE
 a3ultra_device_type = H200_DEVICE_TYPE
 a4_device_type = B200_DEVICE_TYPE
@@ -49,21 +52,7 @@ blueprint_dependencies_dir = {
 }
 
 cluster_toolkit_url = "github.com/GoogleCloudPlatform/cluster-toolkit"
-cluster_toolkit_version = "v1.48.0"
-
-
-def get_subnetworks_for_a3mega(cluster_name: str) -> list[str]:
-  return [f"{cluster_name}-gpunet-{i}-subnet" for i in range(8)]
-
-
-def get_subnetworks_for_a3ultra(cluster_name: str) -> list[str]:
-  return [f"{cluster_name}-sub-1"] + [
-      f"{cluster_name}-rdma-sub-{i}" for i in range(8)
-  ]
-
-
-def get_subnetworks_for_a4() -> list[str]:
-  return ["gvnic-1"] + [f"rdma-{i}" for i in range(8)]
+cluster_toolkit_version = "v1.57.1"
 
 
 class BlueprintGeneratorOutput:
@@ -106,6 +95,8 @@ class BlueprintGenerator:
       group_placement_max_distance: int = 2,
       subnetwork_cidr_suffix: int = 24,
       reservation: str | None = None,
+      reservation_placement_policy: dict[str, str] | None = None,
+      reservation_maintenance_interval: str = "PERIODIC",
       gcs_bucket: Optional[str | None] = None,
       capacity_type: CapacityType = CapacityType.ON_DEMAND,
       system_node_pool_min_node_count: int = 2,
@@ -156,7 +147,6 @@ class BlueprintGenerator:
         source="modules/scheduler/gke-cluster",
         use=[primary_vpc_name, gpu_subnets_name],
         settings={
-            "release_channel": "RAPID",
            "prefix_with_deployment_name": False,
            "name_suffix": cluster_name,
            "enable_private_endpoint": False,
@@ -190,27 +180,42 @@ class BlueprintGenerator:
            "group_placement_max_distance": group_placement_max_distance,
        },
    )
+    nodepool_used_deps = ["gke_cluster", gpu_subnets_name]
 
    a3_megagpu_pool_0 = DeploymentModule(
        id="a3_megagpu_pool_0",
        source="modules/compute/gke-node-pool",
-        use=["gke_cluster", gpu_subnets_name, "group_placement_0"],
+        use=nodepool_used_deps,
        settings={
            "name": f"{cluster_name}-a3-megagpu-pool-0",
            "machine_type": system.gce_machine_type,
-            "static_node_count": num_nodes,
            "zones": [zone],
-            "host_maintenance_interval": "PERIODIC",
+            "host_maintenance_interval": reservation_maintenance_interval,
            "reservation_affinity": self._getblock_reservation_affinity(
                reservation
            ),
            "run_workload_script": False,
            "spot": capacity_type == CapacityType.SPOT,
            "max_pods_per_node": 32,
-            "auto_upgrade": True,
+            "guest_accelerator": [{
+                "type": "nvidia-h100-mega-80gb",
+                "count": 8,
+                "gpu_driver_installation_config": {
+                    "gpu_driver_version": "LATEST"
+                },
+            }],
+            "auto_upgrade": (
+                True if capacity_type != CapacityType.FLEX_START else False
+            ),
        },
        outputs=["instructions"],
    )
+    if capacity_type == CapacityType.FLEX_START:
+      a3_megagpu_pool_0.settings.update(self.get_dws_flex_start())
+    else:
+      a3_megagpu_pool_0.settings.update({"static_node_count": num_nodes})
+
+    set_placement_policy = capacity_type != CapacityType.SPOT
    num_chips = num_nodes * system.chips_per_vm
    workload = DeploymentModule(
        id="workload_component_install",
@@ -219,9 +224,17 @@ class BlueprintGenerator:
        settings={
            "kueue": {
                "install": True,
-                "version": "v0.10.0",  # TAS feature-gates is enabled in CT
+                "version": KUEUE_VERSION,  # TAS feature-gates is enabled in CT
                "config_path": f'$(ghpc_stage("{blueprint_name}"))/kueue-xpk-configuration.yaml.tftpl',
-                "config_template_vars": {"num_chips": num_chips},
+                "config_template_vars": {
+                    "num_chips": num_chips,
+                    "reservation": (
+                        1 if capacity_type == CapacityType.RESERVATION else 0
+                    ),
+                    "flex_start": (
+                        1 if capacity_type == CapacityType.FLEX_START else 0
+                    ),
+                },
            },
            "jobset": {"install": True, "version": "v0.7.2"},
            "apply_manifests": [{
@@ -251,18 +264,27 @@ class BlueprintGenerator:
            }]
        },
    )
+
+    print(reservation_placement_policy)
+    if reservation_placement_policy is not None:
+      a3_megagpu_pool_0.settings["placement_policy"] = (
+          reservation_placement_policy
+      )
+
    primary_group = DeploymentGroup(
        group="primary",
        modules=[
            primary_vpc,
            gpunets,
            gke_cluster,
-            group_placement_0,
            a3_megagpu_pool_0,
            workload,
            workload_configmap,
        ],
    )
+    if set_placement_policy and reservation_placement_policy is None:
+      a3_megagpu_pool_0.use.append(group_placement_0.id)
+      primary_group.modules.append(group_placement_0)
    a3_mega_blueprint = Blueprint(
        terraform_backend_defaults=self._getblock_terraform_backend(
            gcs_bucket, cluster_name, prefix
@@ -478,14 +500,22 @@ class BlueprintGenerator:
        source="modules/scheduler/gke-cluster",
        use=[net_0_id],
        settings={
-            "release_channel": "RAPID",
-            "version_prefix": "1.31.",
-            "maintenance_exclusions": [{
-                "name": "no-minor-or-node-upgrades-indefinite",
-                "start_time": "2024-12-01T00:00:00Z",
-                "end_time": "2025-12-22T00:00:00Z",
-                "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
-            }],
+            "release_channel": (
+                "UNSPECIFIED"
+                if capacity_type == CapacityType.FLEX_START
+                else "RAPID"
+            ),
+            "version_prefix": "1.32.",
+            "maintenance_exclusions": (
+                []
+                if capacity_type == CapacityType.FLEX_START
+                else [{
+                    "name": "no-minor-or-node-upgrades-indefinite",
+                    "start_time": "2024-12-01T00:00:00Z",
+                    "end_time": "2025-12-22T00:00:00Z",
+                    "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
+                }]
+            ),
            "prefix_with_deployment_name": False,
            "name_suffix": cluster_name,
            "system_node_pool_machine_type": system_node_pool_machine_type,
@@ -534,9 +564,10 @@ class BlueprintGenerator:
        use=[cluster_id],
        settings={
            "machine_type": system.gce_machine_type,
-            "auto_upgrade": True,
+            "auto_upgrade": (
+                True if capacity_type != CapacityType.FLEX_START else False
+            ),
            "zones": [zone],
-            "static_node_count": num_nodes,
            "spot": capacity_type == CapacityType.SPOT,
            "reservation_affinity": self._getblock_reservation_affinity(
                reservation
@@ -562,6 +593,10 @@ class BlueprintGenerator:
        },
        outputs=["instructions"],
    )
+    if capacity_type == CapacityType.FLEX_START:
+      gpu_pool.settings.update(self.get_dws_flex_start())
+    else:
+      gpu_pool.settings.update({"static_node_count": num_nodes})
 
    num_chips = num_nodes * system.chips_per_vm
    workload_manager_install_id = "workload-manager-install"
@@ -572,9 +607,14 @@ class BlueprintGenerator:
        settings={
            "kueue": {
                "install": True,
-                "version": "v0.10.0",  # TAS feature-gates is enabled in CT
+                "version": KUEUE_VERSION,  # TAS feature-gates is enabled in CT
                "config_path": f'$(ghpc_stage("{blueprint_name}"))/kueue-xpk-configuration.yaml.tftpl',
-                "config_template_vars": {"num_chips": num_chips},
+                "config_template_vars": {
+                    "num_chips": num_chips,
+                    "flex_start": (
+                        1 if capacity_type == CapacityType.FLEX_START else 0
+                    ),
+                },
            },
            "jobset": {"install": True, "version": "v0.7.2"},
            "apply_manifests": [
@@ -777,13 +817,21 @@ class BlueprintGenerator:
                f" {cluster_name}-rdma-net.subnetwork_interfaces_gke))"
            ),
            "version_prefix": "1.32.",
-            "release_channel": "RAPID",
-            "maintenance_exclusions": [{
-                "name": "no-minor-or-node-upgrades-indefinite",
-                "start_time": "2024-12-01T00:00:00Z",
-                "end_time": "2025-12-22T00:00:00Z",
-                "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
-            }],
+            "release_channel": (
+                "UNSPECIFIED"
+                if capacity_type == CapacityType.FLEX_START
+                else "RAPID"
+            ),
+            "maintenance_exclusions": (
+                []
+                if capacity_type == CapacityType.FLEX_START
+                else [{
+                    "name": "no-minor-or-node-upgrades-indefinite",
+                    "start_time": "2024-12-01T00:00:00Z",
+                    "end_time": "2025-12-22T00:00:00Z",
+                    "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
+                }]
+            ),
        },
        outputs=["instructions"],
    )
@@ -800,10 +848,11 @@ class BlueprintGenerator:
        use=[cluster_id],
        settings={
            "machine_type": system.gce_machine_type,
-            "auto_upgrade": True,
+            "auto_upgrade": (
+                True if capacity_type != CapacityType.FLEX_START else False
+            ),
            "zones": [zone],
            "disk_type": "hyperdisk-balanced",
-            "static_node_count": num_nodes,
            "local_ssd_count_ephemeral_storage": 32,
            "spot": capacity_type == CapacityType.SPOT,
            "reservation_affinity": self._getblock_reservation_affinity(
@@ -830,6 +879,10 @@ class BlueprintGenerator:
        },
        outputs=["instructions"],
    )
+    if capacity_type == CapacityType.FLEX_START:
+      gpu_pool.settings.update(self.get_dws_flex_start())
+    else:
+      gpu_pool.settings.update({"static_node_count": num_nodes})
 
    num_chips = num_nodes * system.chips_per_vm
    workload_manager_install_id = "workload-manager-install"
@@ -840,9 +893,14 @@ class BlueprintGenerator:
        settings={
            "kueue": {
                "install": True,
-                "version": "v0.10.0",  # TAS feature-gates is enabled in CT
+                "version": KUEUE_VERSION,  # TAS feature-gates is enabled in CT
                "config_path": f'$(ghpc_stage("{blueprint_name}"))/kueue-xpk-configuration.yaml.tftpl',
-                "config_template_vars": {"num_chips": num_chips},
+                "config_template_vars": {
+                    "num_chips": num_chips,
+                    "flex_start": (
+                        1 if capacity_type == CapacityType.FLEX_START else 0
+                    ),
+                },
            },
            "jobset": {"install": True, "version": "v0.7.2"},
            "apply_manifests": [
@@ -992,6 +1050,16 @@ class BlueprintGenerator:
    )
    return deployment_files_path
 
+  def get_dws_flex_start(self) -> dict:
+    return {
+        "enable_flex_start": True,
+        "enable_queued_provisioning": True,
+        "autoscaling_total_min_nodes": 0,
+        "release_channel": "UNSPECIFIED",
+        "auto_repair": False,
+        "auto_upgrade": False,
+    }
+
 
 yaml.register_class(Blueprint)
 yaml.register_class(DeploymentGroup)
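Across the blueprint generator, a FLEX_START capacity type now swaps the static node count for the DWS flex-start settings returned by get_dws_flex_start(). A minimal sketch of how a node-pool settings dict ends up configured either way (not part of the diff; the NodePool dataclass and the sample values are hypothetical stand-ins for the cluster-toolkit gke-node-pool module driven by BlueprintGenerator):

# Minimal sketch, assuming only the standard library.
from dataclasses import dataclass, field
from enum import Enum


class CapacityType(Enum):
  ON_DEMAND = "on_demand"
  SPOT = "spot"
  FLEX_START = "flex_start"


def get_dws_flex_start() -> dict:
  # Mirrors the helper added in this release.
  return {
      "enable_flex_start": True,
      "enable_queued_provisioning": True,
      "autoscaling_total_min_nodes": 0,
      "release_channel": "UNSPECIFIED",
      "auto_repair": False,
      "auto_upgrade": False,
  }


@dataclass
class NodePool:  # hypothetical stand-in for a DeploymentModule's settings holder
  settings: dict = field(default_factory=dict)


pool = NodePool(settings={"machine_type": "a3-megagpu-8g", "spot": False})
capacity_type = CapacityType.FLEX_START
if capacity_type == CapacityType.FLEX_START:
  pool.settings.update(get_dws_flex_start())  # queued provisioning, no auto-upgrade
else:
  pool.settings.update({"static_node_count": 4})
print(pool.settings)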