xpk 0.8.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. xpk/commands/batch.py +5 -6
  2. xpk/commands/cluster.py +246 -73
  3. xpk/commands/cluster_gcluster.py +27 -0
  4. xpk/commands/common.py +40 -1
  5. xpk/commands/kjob_common.py +13 -1
  6. xpk/commands/run.py +4 -5
  7. xpk/commands/shell.py +2 -2
  8. xpk/commands/storage.py +24 -6
  9. xpk/commands/workload.py +66 -27
  10. xpk/core/blueprint/blueprint_generator.py +115 -47
  11. xpk/core/capacity.py +66 -6
  12. xpk/core/cluster.py +282 -13
  13. xpk/core/config.py +1 -65
  14. xpk/core/docker_manager.py +1 -1
  15. xpk/core/docker_resources.py +145 -72
  16. xpk/core/filestore.py +2 -6
  17. xpk/core/gcsfuse.py +22 -4
  18. xpk/core/jobset.py +143 -0
  19. xpk/core/kjob.py +21 -18
  20. xpk/core/kueue.py +194 -4
  21. xpk/core/mtc.py +195 -0
  22. xpk/core/network.py +23 -1
  23. xpk/core/nodepool.py +17 -4
  24. xpk/core/pathways.py +2 -3
  25. xpk/core/resources.py +21 -0
  26. xpk/core/storage.py +1 -95
  27. xpk/core/system_characteristics.py +1 -1
  28. xpk/core/workload.py +1 -45
  29. xpk/core/workload_decorators/rdma_decorator.py +8 -10
  30. xpk/core/workload_decorators/tcpx_decorator.py +185 -0
  31. xpk/core/workload_decorators/tcpxo_decorator.py +22 -14
  32. xpk/parser/cluster.py +589 -389
  33. xpk/parser/storage.py +12 -3
  34. xpk/parser/workload.py +21 -3
  35. xpk/utils/kubectl.py +4 -1
  36. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/METADATA +178 -96
  37. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/RECORD +41 -38
  38. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/WHEEL +1 -1
  39. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/entry_points.txt +0 -0
  40. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/licenses/LICENSE +0 -0
  41. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/top_level.txt +0 -0
xpk/core/kueue.py CHANGED
@@ -16,11 +16,13 @@ limitations under the License.
16
16
 
17
17
  from argparse import Namespace
18
18
 
19
+ import math
19
20
  import packaging
20
21
  from packaging.version import Version
21
22
 
22
23
  from ..utils.console import xpk_exit, xpk_print
23
24
  from ..utils.file import write_tmp_file
25
+ from .capacity import B200_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
24
26
  from .commands import (
25
27
  run_command_for_value,
26
28
  run_command_with_updates,
@@ -38,13 +40,28 @@ from .system_characteristics import (
38
40
  SystemCharacteristics,
39
41
  )
40
42
 
41
- KUEUE_VERSION = 'v0.10.0'
43
+ KUEUE_VERSION = 'v0.12.2'
42
44
  CLUSTER_QUEUE_NAME = 'cluster-queue'
43
45
  LOCAL_QUEUE_NAME = 'multislice-queue'
44
46
  WAIT_FOR_KUEUE_TIMEOUT = '5m'
47
+ MEMORY_SIZE_PER_VM = 1.2
48
+ MIN_MEMORY_LIMIT_SIZE = 4096
45
49
 
46
50
  packaging.version.VERSION_PATTERN = r'^v\d+\.\d+\.\d+$'
47
51
 
52
+ topology_yaml = """apiVersion: kueue.x-k8s.io/v1alpha1
53
+ kind: Topology
54
+ metadata:
55
+ name: "gke-default"
56
+ spec:
57
+ levels:
58
+ - nodeLabel: "cloud.google.com/gce-topology-block"
59
+ - nodeLabel: "cloud.google.com/gce-topology-subblock"
60
+ - nodeLabel: "cloud.google.com/gce-topology-host"
61
+ - nodeLabel: "kubernetes.io/hostname"
62
+ ---
63
+ """
64
+
48
65
  cluster_set_crd_yaml = """apiVersion: kueue.x-k8s.io/v1beta1
49
66
  kind: ResourceFlavor
50
67
  metadata:
@@ -53,6 +70,27 @@ spec:
53
70
  nodeLabels:
54
71
  {accelerator_label}
55
72
  {machine_label}
73
+ {topology_label}
74
+ ---
75
+ apiVersion: kueue.x-k8s.io/v1beta1
76
+ kind: AdmissionCheck
77
+ metadata:
78
+ name: dws-prov
79
+ spec:
80
+ controllerName: kueue.x-k8s.io/provisioning-request
81
+ parameters:
82
+ apiGroup: kueue.x-k8s.io
83
+ kind: ProvisioningRequestConfig
84
+ name: dws-config
85
+ ---
86
+ apiVersion: kueue.x-k8s.io/v1beta1
87
+ kind: ProvisioningRequestConfig
88
+ metadata:
89
+ name: dws-config
90
+ spec:
91
+ provisioningClassName: queued-provisioning.gke.io
92
+ managedResources:
93
+ - {managed_resource}
56
94
  ---
57
95
  {pw_resource_flavors}
58
96
  apiVersion: kueue.x-k8s.io/v1beta1
@@ -67,6 +105,7 @@ spec:
67
105
  resourceGroups:
68
106
  {covered_resources_config}
69
107
  {pw_resources_kueue}
108
+ {admission_checks}
70
109
  ---
71
110
  apiVersion: kueue.x-k8s.io/v1beta1
72
111
  kind: LocalQueue
@@ -151,6 +190,99 @@ spec:
151
190
  command: [ "sleep", "inf" ]
152
191
  """
153
192
 
193
# Deployment manifest re-applied to raise the Kueue controller manager's
# memory limit on large clusters.
#
# The manager image tag is derived from KUEUE_VERSION so that applying this
# manifest never rolls the controller back to an older release than the one
# xpk just installed. `{memory_limit_size}` must remain a `str.format`
# placeholder for update_kueue_resources_if_necessary, which is why its braces
# are doubled inside this f-string.
# NOTE(review): the YAML nesting below was reconstructed from the key order of
# the upstream Kueue Deployment manifest - confirm against the packaged file.
kueue_controller_manager_yml = f"""
apiVersion: apps/v1
kind: Deployment
metadata:
  labels:
    app.kubernetes.io/component: controller
    app.kubernetes.io/name: kueue
    control-plane: controller-manager
  name: kueue-controller-manager
  namespace: kueue-system
spec:
  replicas: 1
  selector:
    matchLabels:
      control-plane: controller-manager
  template:
    metadata:
      annotations:
        kubectl.kubernetes.io/default-container: manager
      labels:
        app.kubernetes.io/component: controller
        app.kubernetes.io/name: kueue
        control-plane: controller-manager
    spec:
      containers:
      - args:
        - --config=/controller_manager_config.yaml
        - --zap-log-level=2
        command:
        - /manager
        image: registry.k8s.io/kueue/kueue:{KUEUE_VERSION}
        imagePullPolicy: Always
        livenessProbe:
          httpGet:
            path: /healthz
            port: 8081
          initialDelaySeconds: 15
          periodSeconds: 20
        name: manager
        ports:
        - containerPort: 8082
          name: visibility
          protocol: TCP
        - containerPort: 9443
          name: webhook-server
          protocol: TCP
        readinessProbe:
          httpGet:
            path: /readyz
            port: 8081
          initialDelaySeconds: 5
          periodSeconds: 10
        resources:
          limits:
            cpu: 500m
            memory: {{memory_limit_size}}
          requests:
            cpu: 500m
            memory: 512Mi
        securityContext:
          allowPrivilegeEscalation: false
        volumeMounts:
        - mountPath: /tmp/k8s-webhook-server/serving-certs
          name: cert
          readOnly: true
        - mountPath: /controller_manager_config.yaml
          name: manager-config
          subPath: controller_manager_config.yaml
      - args:
        - --secure-listen-address=0.0.0.0:8443
        - --upstream=http://127.0.0.1:8080/
        - --logtostderr=true
        - --v=10
        image: registry.k8s.io/kubebuilder/kube-rbac-proxy:v0.16.0
        name: kube-rbac-proxy
        ports:
        - containerPort: 8443
          name: https
          protocol: TCP
      securityContext:
        runAsNonRoot: true
      serviceAccountName: kueue-controller-manager
      terminationGracePeriodSeconds: 10
      volumes:
      - name: cert
        secret:
          defaultMode: 420
          secretName: kueue-webhook-server-cert
      - configMap:
          name: kueue-manager-config
        name: manager-config
"""
285
+
154
286
 
155
287
  def verify_kueuectl(args: Namespace) -> None:
156
288
  """Verify if kueuectl is installed.
@@ -267,6 +399,7 @@ def install_kueue_crs(
267
399
  args,
268
400
  system: SystemCharacteristics,
269
401
  autoprovisioning_config: AutoprovisioningConfig | None,
402
+ flex_with_tpu=False,
270
403
  ) -> int:
271
404
  """Install Kueue Custom Resources.
272
405
 
@@ -294,12 +427,29 @@ def install_kueue_crs(
294
427
  else:
295
428
  # Determine total chips based on user specified topology.
296
429
  total_chips = get_total_chips_requested_from_args(args, system)
430
+ if args.flex and flex_with_tpu is False:
431
+ admission_checks = """
432
+ admissionChecks:
433
+ - dws-prov
434
+ """
435
+ else:
436
+ admission_checks = ''
297
437
 
298
438
  covered_resources_config = get_kueue_covered_resources_config(
299
439
  cluster_hardware_name=cluster_hardware_name,
300
440
  resource_type=resource_type,
301
441
  total_chips=total_chips,
302
442
  )
443
+ topology_label = ''
444
+ if system.device_type in [
445
+ H100_MEGA_DEVICE_TYPE,
446
+ H200_DEVICE_TYPE,
447
+ B200_DEVICE_TYPE,
448
+ ]:
449
+ topology_label = 'topologyName: "gke-default"'
450
+ res_type = AcceleratorTypeToAcceleratorCharacteristics[
451
+ system.accelerator_type
452
+ ].resource_type
303
453
  yml_string = cluster_set_crd_yaml.format(
304
454
  system=system,
305
455
  cluster_hardware_name=cluster_hardware_name,
@@ -309,15 +459,22 @@ def install_kueue_crs(
309
459
  machine_label=create_machine_label(
310
460
  system.accelerator_type, system, autoprovisioning_enabled
311
461
  ),
462
+ topology_label=topology_label,
312
463
  covered_resources_config=covered_resources_config,
313
- resource_type=AcceleratorTypeToAcceleratorCharacteristics[
314
- system.accelerator_type
315
- ].resource_type,
464
+ resource_type=res_type,
316
465
  pw_resource_flavors=add_pw_resource_flavors(args),
317
466
  pw_resources_kueue=add_pw_resources_to_kueue(args),
467
+ admission_checks=admission_checks,
468
+ managed_resource=res_type,
318
469
  cluster_queue_name=CLUSTER_QUEUE_NAME,
319
470
  local_queue_name=LOCAL_QUEUE_NAME,
320
471
  )
472
+ if system.device_type in [
473
+ H100_MEGA_DEVICE_TYPE,
474
+ H200_DEVICE_TYPE,
475
+ B200_DEVICE_TYPE,
476
+ ]:
477
+ yml_string = topology_yaml + yml_string
321
478
 
322
479
  tmp = write_tmp_file(yml_string)
323
480
  command = f'kubectl apply -f {str(tmp.file.name)}'
@@ -356,3 +513,36 @@ def get_kueue_covered_resources_config(
356
513
  total_chips=total_chips,
357
514
  )
358
515
  return config_string
516
+
517
+
518
def update_kueue_resources_if_necessary(args):
  """Raise the Kueue controller manager's memory limit to fit the cluster.

  The limit is MEMORY_SIZE_PER_VM MiB per node with a floor of
  MIN_MEMORY_LIMIT_SIZE MiB, and is set by re-applying the controller manager
  Deployment manifest.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if the manifest was applied successfully, otherwise the non-zero return
    code of the apply command. Calls xpk_exit(1) when the nodes cannot be
    listed.
  """
  # Count the rows in Python instead of piping into `wc -l`: a shell pipeline
  # reports the exit status of its last command, so a failing kubectl would be
  # masked, and any extra text in the captured output would break int().
  cmd_list_nodes = 'kubectl get node --no-headers'
  return_code, out = run_command_for_value(
      cmd_list_nodes, 'Count total nodes', args
  )
  if return_code != 0:
    xpk_exit(1)
  total_nodes = sum(1 for line in out.splitlines() if line.strip())

  # 1.2MiB per VM or 4GiB (whichever is greater).
  memory_limit_mib = max(
      math.ceil(total_nodes * MEMORY_SIZE_PER_VM), MIN_MEMORY_LIMIT_SIZE
  )
  new_memory_limit = f'{memory_limit_mib}Mi'
  yml_string = kueue_controller_manager_yml.format(
      memory_limit_size=new_memory_limit,
  )
  tmp = write_tmp_file(yml_string)
  command = f'kubectl apply -f {str(tmp.file.name)}'

  task = 'Updating Kueue Controller Manager resources'
  return_code = run_command_with_updates_retry(command, task, args)
  if return_code != 0:
    xpk_print(f'{task} returned ERROR {return_code}')
  return return_code
xpk/core/mtc.py ADDED
@@ -0,0 +1,195 @@
1
+ """
2
+ Copyright 2024 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import requests
18
+ import yaml
19
+
20
+ from ..core.cluster import JOBSET_VERSION
21
+ from ..core.cluster import setup_k8s_env
22
+ from ..utils import templates
23
+ from ..utils.console import xpk_exit
24
+ from ..utils.console import xpk_print
25
+ from ..utils.kubectl import apply_kubectl_manifest
26
+
27
+
28
+ MTC_CPC_PATH = "/../templates/mtc-cpc.yaml"
29
+
30
+
31
def create_mtc_cpc(
    mtc_gcs_bucket: str,
    mtc_machine_type: str,
    mtc_toleration_key: str,
    mtc_ramdisk_size: str,
) -> dict:
  """Build the MTC checkpoint configuration from its template.

  Loads the template at MTC_CPC_PATH and overwrites four fields of its
  `spec` section with the supplied values.

  Args:
    mtc_gcs_bucket: value stored as `spec.cloudStorageBucketName`.
    mtc_machine_type: value stored under the
      `node.kubernetes.io/instance-type` key of `spec.nodeSelector`.
    mtc_toleration_key: value stored as the `key` of the first entry in
      `spec.tolerations`.
    mtc_ramdisk_size: value stored as `spec.inMemoryVolumeSize`.

  Returns:
    The loaded template with the fields above filled in.
  """
  manifest = templates.load(MTC_CPC_PATH)
  spec = manifest["spec"]

  spec["cloudStorageBucketName"] = mtc_gcs_bucket
  spec["nodeSelector"]["node.kubernetes.io/instance-type"] = mtc_machine_type
  spec["tolerations"][0]["key"] = mtc_toleration_key
  spec["inMemoryVolumeSize"] = mtc_ramdisk_size

  return manifest
58
+
59
+
60
def install_mtc_on_cluster(args, system) -> int:
  """Install MTC on the cluster.

  Validates the MTC arguments, applies the JobSet controller manifest when
  update_jobset_manifest() reports that it needs changes, and then applies
  the MTC checkpoint configuration.

  Args:
    args: user provided arguments for running the command. `mtc_gcs_bucket`
      (a leading "gs://" is stripped) and `mtc_ramdisk_size` are required;
      `mtc_toleration_key` defaults to "google.com/tpu".
    system: system related information; its `gce_machine_type` is used as the
      node selector for the checkpoint configuration.

  Returns:
    return code of the command.
  """
  if args.mtc_gcs_bucket is None:
    xpk_print("MTC GCS bucket is required.")
    xpk_exit(1)
  gcs_scheme = "gs://"
  if args.mtc_gcs_bucket.startswith(gcs_scheme):
    # Strip only the leading scheme, never a later occurrence in the name.
    args.mtc_gcs_bucket = args.mtc_gcs_bucket[len(gcs_scheme) :]

  if args.mtc_ramdisk_size is None:
    xpk_print("MTC ramdisk size is required.")
    xpk_exit(1)

  if args.mtc_toleration_key is None:
    args.mtc_toleration_key = "google.com/tpu"

  k8s_api_client = setup_k8s_env(args)
  jobset_manifest = update_jobset_manifest()
  if jobset_manifest is None:
    # Nothing to change: skip the apply instead of sending a None manifest.
    xpk_print(
        "Updated jobset manifest is empty, not updating the jobset controller."
    )
  else:
    xpk_print("Applying Jobset with MTC Configuration")
    return_code = apply_kubectl_manifest(k8s_api_client, [jobset_manifest])
    if return_code != 0:
      return return_code

  mtc_checkpoint_configuration_crd_data = create_mtc_cpc(
      args.mtc_gcs_bucket,
      system.gce_machine_type,
      args.mtc_toleration_key,
      args.mtc_ramdisk_size,
  )
  xpk_print("Applying MTC Checkpoint Configuration")
  return_code = apply_kubectl_manifest(
      k8s_api_client, [mtc_checkpoint_configuration_crd_data]
  )

  return return_code
107
+
108
+
109
def update_jobset_manifest() -> dict | None:
  """Fetch the JobSet release manifest and raise the controller's resources.

  Downloads manifests.yaml for JOBSET_VERSION, finds the
  jobset-controller-manager Deployment and makes sure its "manager" container
  requests at least 1000m CPU and 1Gi memory and has a memory limit of at
  least 2Gi. Values that are already large enough are left untouched.

  Returns:
    The modified Deployment document when at least one value was raised,
    otherwise None. Calls xpk_exit(1) when the manifest cannot be downloaded.
  """
  manifest_url = f"https://github.com/kubernetes-sigs/jobset/releases/download/{JOBSET_VERSION}/manifests.yaml"
  manifest_content = None
  # Fetch the manifest content
  try:
    response = requests.get(manifest_url, timeout=10)
    response.raise_for_status()  # Raise an exception for HTTP errors
    manifest_content = response.text
  except requests.exceptions.Timeout as e:
    xpk_print(
        f"Error: Request to {manifest_url} timed out after 10 seconds: {e}"
    )
    xpk_exit(1)
  except requests.exceptions.RequestException as e:
    xpk_print(f"Error fetching manifest from {manifest_url}: {e}")
    xpk_exit(1)

  if manifest_content is None:
    xpk_print("Manifest content not found.")
    xpk_exit(1)

  # Minimum values the manager container must end up with.
  new_cpu_request = "1000m"
  new_memory_request = "1Gi"
  new_memory_limit = "2Gi"

  # Walk every YAML document looking for the controller Deployment.
  for yaml_data in yaml.safe_load_all(manifest_content):
    if not (
        yaml_data
        and yaml_data.get("apiVersion") == "apps/v1"
        and yaml_data.get("kind") == "Deployment"
        and yaml_data.get("metadata", {}).get("name")
        == "jobset-controller-manager"
    ):
      continue

    containers = yaml_data["spec"]["template"]["spec"]["containers"]
    for container in containers:
      if container["name"] != "manager":
        continue

      # Create missing sections so the writes below cannot raise KeyError
      # when the manifest omits resources, requests or limits.
      resources = container.setdefault("resources", {})
      requested = resources.setdefault("requests", {})
      limits = resources.setdefault("limits", {})

      update_manifest = False
      if _cpu_millicores(requested.get("cpu", "0m")) < _cpu_millicores(
          new_cpu_request
      ):
        requested["cpu"] = new_cpu_request
        update_manifest = True
      if parse_resource_value(
          requested.get("memory", "0Mi")
      ) < parse_resource_value(new_memory_request):
        requested["memory"] = new_memory_request
        update_manifest = True
      if parse_resource_value(
          limits.get("memory", "0Mi")
      ) < parse_resource_value(new_memory_limit):
        limits["memory"] = new_memory_limit
        update_manifest = True

      if update_manifest:
        xpk_print("Jobset controller updation required.")
        return yaml_data
      # Only the first container named "manager" is inspected.
      break

  xpk_print("Jobset controller no updation required.")
  return None


def _cpu_millicores(quantity) -> int:
  """Convert a CPU quantity ("500m", "2", 2, "0.5") to millicores."""
  text = str(quantity).strip()
  if text.endswith("m"):
    return int(text[:-1])
  # A quantity without the "m" suffix is expressed in whole or fractional
  # cores, so scale it to millicores before it is compared.
  return int(float(text) * 1000)
186
+
187
+
188
def parse_resource_value(value) -> int:
  """Convert a Kubernetes resource quantity into a comparable integer.

  Supported forms and the unit of the result:
    "<n>m"   -> n (CPU millicores)
    "<n>Mi"  -> n (MiB)
    "<n>Gi"  -> n * 1024 (MiB)
    "<n>Ti"  -> n * 1024 * 1024 (MiB)
    "<n>"    -> n, returned unchanged

  Results are only meaningful when compared with another value of the same
  resource kind.

  Args:
    value: the quantity, either a string or an int (YAML loads unquoted
      numbers such as `cpu: 2` as int). Surrounding whitespace is ignored.

  Returns:
    The quantity as an integer in the unit listed above.

  Raises:
    ValueError: if the numeric part is not an integer or the suffix is not
      one of the supported ones.
  """
  text = str(value).strip()
  suffix_multipliers = (
      ("m", 1),
      ("Mi", 1),
      ("Gi", 1024),
      ("Ti", 1024 * 1024),
  )
  for suffix, multiplier in suffix_multipliers:
    if text.endswith(suffix):
      return int(text[: -len(suffix)]) * multiplier
  return int(text)
xpk/core/network.py CHANGED
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from ..utils.console import xpk_print
17
+ from ..utils.console import xpk_exit, xpk_print
18
18
  from ..utils.file import write_tmp_file
19
19
  from .commands import run_command_for_value, run_command_with_updates
20
20
  from .gcloud_context import zone_to_region
@@ -235,6 +235,28 @@ def create_cluster_network_config(args) -> int:
235
235
  return 0
236
236
 
237
237
 
238
def get_cluster_subnetworks(args) -> list[str]:
  """List the GKENetworkParamSet entries reported by kubectl.

  Runs `kubectl get GKENetworkParamSet` and collects the first column of
  every output row after the header row.

  Args:
    args: user provided arguments for running the command.

  Returns:
    list[str]: first-column value of each row following the header. Calls
    xpk_exit with the command's return code when the command fails.
  """
  exit_code, output = run_command_for_value(
      'kubectl get GKENetworkParamSet', 'Get Cluster Networks', args
  )
  if exit_code != 0:
    xpk_print('GKE Cluster Get NetworkParamSet failed')
    xpk_exit(exit_code)

  first_columns = []
  for row in output.splitlines():
    first_columns.append(row.split()[0])

  # The first row is the column header printed by kubectl.
  return first_columns[1:]
258
+
259
+
238
260
  def set_up_cluster_network_for_a3(args) -> int:
239
261
  """Set up GKE Cluster networks, subnets and firewall rules for A3.
240
262
  Note: there are 4 NICs for GPU-GPU bw and 1 NIC for host in an A3 node.
xpk/core/nodepool.py CHANGED
@@ -77,8 +77,12 @@ def run_gke_node_pool_create_command(
77
77
  if return_code > 0:
78
78
  xpk_print('Listing all reservations failed!')
79
79
  return_code = 1
80
+ if system.accelerator_type == AcceleratorType['TPU']:
81
+ max_nodes = system.vms_per_slice
82
+ else:
83
+ max_nodes = 1000
80
84
  capacity_args, return_code = get_capacity_arguments_from_capacity_type(
81
- args, capacity_type
85
+ args, capacity_type, max_nodes
82
86
  )
83
87
  if return_code > 0:
84
88
  xpk_print('Parsing capacity arguments failed!')
@@ -275,7 +279,10 @@ def run_gke_node_pool_create_command(
275
279
  )
276
280
  if system.accelerator_type == AcceleratorType['TPU']:
277
281
  command += f' --node-version={gke_node_pool_version}'
278
- command += f' --num-nodes={system.vms_per_slice}'
282
+ if capacity_type == CapacityType.FLEX_START:
283
+ command += ' --num-nodes=0'
284
+ else:
285
+ command += f' --num-nodes={system.vms_per_slice}'
279
286
  command += ' --placement-type=COMPACT --max-pods-per-node 15'
280
287
  command += (
281
288
  f' --scopes=storage-full,gke-default,{CLOUD_PLATFORM_AUTH_SCOPE_URL}'
@@ -284,7 +291,10 @@ def run_gke_node_pool_create_command(
284
291
  command += f' {args.custom_tpu_nodepool_arguments}'
285
292
  elif system.accelerator_type == AcceleratorType['GPU']:
286
293
  subnet_prefix = f'{args.cluster}-{zone_to_region(args.zone)}'
287
- command += f' --num-nodes={args.num_nodes}'
294
+ if capacity_type == CapacityType.FLEX_START:
295
+ command += ' --num-nodes=0'
296
+ else:
297
+ command += f' --num-nodes={args.num_nodes}'
288
298
  command += (
289
299
  ' --accelerator'
290
300
  f' type={system.gke_accelerator},count={str(system.chips_per_vm)},gpu-driver-version=latest'
@@ -298,7 +308,10 @@ def run_gke_node_pool_create_command(
298
308
  )
299
309
  command += ' --max-pods-per-node=32'
300
310
  elif system.accelerator_type == AcceleratorType['CPU']:
301
- command += f' --num-nodes={system.vms_per_slice}'
311
+ if capacity_type == CapacityType.FLEX_START:
312
+ command += ' --num-nodes=0'
313
+ else:
314
+ command += f' --num-nodes={system.vms_per_slice}'
302
315
  command += (
303
316
  f' --scopes=storage-full,gke-default,{CLOUD_PLATFORM_AUTH_SCOPE_URL}'
304
317
  )
xpk/core/pathways.py CHANGED
@@ -19,8 +19,7 @@ from ..core.docker_container import get_user_workload_container
19
19
  from ..core.gcloud_context import zone_to_region
20
20
  from ..core.nodepool import get_all_nodepools_programmatic
21
21
  from ..utils.console import xpk_exit, xpk_print
22
- from .config import AcceleratorType
23
- from .system_characteristics import SystemCharacteristics
22
+ from .system_characteristics import AcceleratorType, SystemCharacteristics
24
23
 
25
24
 
26
25
  def add_pw_resource_flavors(args):
@@ -211,7 +210,7 @@ def append_custom_pathways_worker(args) -> str:
211
210
  """
212
211
  yaml = """"""
213
212
  if args.server_image or args.custom_pathways_worker_args:
214
- yaml = """- componentType: pathways_worker"""
213
+ yaml = """- componentType: worker"""
215
214
  indentation = (
216
215
  ' ' * 8
217
216
  ) # Currently 8, based on the YAML, may need to update in the future.
xpk/core/resources.py CHANGED
@@ -236,3 +236,24 @@ def get_cluster_system_characteristics(args) -> SystemCharacteristics | None:
236
236
  return system
237
237
 
238
238
  return None
239
+
240
+
241
def get_cluster_capacity_type(args) -> CapacityType | None:
  """Look up the capacity type recorded in the cluster metadata ConfigMap.

  Args:
    args: user provided arguments for running the command.

  Returns:
    The CapacityType member named by the upper-cased `capacity_type` entry of
    the `<cluster>-<CLUSTER_METADATA_CONFIGMAP>` ConfigMap, or None when the
    ConfigMap or that entry is missing.
  """
  config_map = get_cluster_configmap(
      args, f'{args.cluster}-{CLUSTER_METADATA_CONFIGMAP}'
  )
  if config_map is None:
    return None

  capacity_name = config_map.get('capacity_type')
  if capacity_name is None:
    return None
  return CapacityType[capacity_name.upper()]