xpk 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. xpk/commands/cluster.py +10 -11
  2. xpk/commands/cluster_gcluster.py +2 -1
  3. xpk/commands/common.py +3 -3
  4. xpk/commands/info.py +12 -12
  5. xpk/commands/job.py +12 -10
  6. xpk/commands/kjob_common.py +2 -1
  7. xpk/commands/storage.py +1 -1
  8. xpk/commands/workload.py +12 -6
  9. xpk/core/blueprint/blueprint_generator.py +7 -7
  10. xpk/core/blueprint/blueprint_test.py +218 -0
  11. xpk/core/capacity.py +3 -1
  12. xpk/core/cluster.py +9 -7
  13. xpk/core/cluster_private.py +5 -1
  14. xpk/core/commands.py +3 -3
  15. xpk/core/config.py +3 -4
  16. xpk/core/config_test.py +71 -0
  17. xpk/core/docker_manager.py +1 -1
  18. xpk/core/docker_resources.py +1 -1
  19. xpk/core/filestore.py +7 -2
  20. xpk/core/gcloud_context.py +2 -2
  21. xpk/core/kjob.py +2 -1
  22. xpk/core/kueue.py +6 -2
  23. xpk/core/nap.py +4 -4
  24. xpk/core/nodepool_test.py +82 -0
  25. xpk/core/resources.py +1 -7
  26. xpk/core/storage.py +14 -14
  27. xpk/core/system_characteristics.py +1 -1
  28. xpk/core/workload.py +11 -0
  29. xpk/core/workload_decorators/rdma_decorator.py +3 -2
  30. xpk/core/workload_decorators/storage_decorator.py +2 -1
  31. xpk/core/workload_decorators/tcpx_decorator.py +4 -2
  32. xpk/core/workload_decorators/tcpx_decorator_test.py +267 -0
  33. xpk/core/workload_decorators/tcpxo_decorator.py +2 -1
  34. xpk/core/workload_test.py +28 -0
  35. xpk/main.py +9 -10
  36. xpk/parser/cluster.py +67 -49
  37. xpk/parser/common.py +45 -36
  38. xpk/parser/storage.py +12 -13
  39. xpk/parser/workload.py +57 -39
  40. xpk/utils/console.py +2 -1
  41. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/METADATA +4 -1
  42. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/RECORD +46 -41
  43. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/WHEEL +0 -0
  44. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/entry_points.txt +0 -0
  45. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/licenses/LICENSE +0 -0
  46. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/top_level.txt +0 -0
xpk/commands/cluster.py CHANGED
@@ -92,7 +92,7 @@ def cluster_adapt(args) -> None:
92
92
 
93
93
  system, return_code = get_system_characteristics(args)
94
94
 
95
- if return_code > 0:
95
+ if return_code > 0 or system is None:
96
96
  xpk_print('Fetching system characteristics failed!')
97
97
  xpk_exit(return_code)
98
98
 
@@ -141,8 +141,6 @@ def cluster_adapt(args) -> None:
141
141
  if not tensorboard_config:
142
142
  xpk_exit(1)
143
143
 
144
- # Provision node pools dynamically based on incoming workloads:
145
- # Currently autoprovisioning is not supported with Pathways.
146
144
  autoprovisioning_config = None
147
145
  if args.enable_autoprovisioning:
148
146
  xpk_print('Enabling Autoprovisioning')
@@ -201,7 +199,7 @@ def cluster_create(args) -> None:
201
199
  """
202
200
  system, return_code = get_system_characteristics(args)
203
201
 
204
- if return_code > 0:
202
+ if return_code > 0 or system is None:
205
203
  xpk_print('Fetching system characteristics failed!')
206
204
  xpk_exit(return_code)
207
205
 
@@ -217,13 +215,13 @@ def cluster_create(args) -> None:
217
215
  xpk_exit(0)
218
216
 
219
217
  return_code, gke_server_config = get_gke_server_config(args)
220
- if return_code != 0:
218
+ if return_code != 0 or gke_server_config is None:
221
219
  xpk_exit(return_code)
222
220
 
223
221
  return_code, gke_control_plane_version = get_gke_control_plane_version(
224
222
  args, gke_server_config
225
223
  )
226
- if return_code != 0:
224
+ if return_code != 0 or gke_control_plane_version is None:
227
225
  xpk_exit(return_code)
228
226
 
229
227
  create_cluster_command_code = create_cluster_if_necessary(
@@ -294,7 +292,7 @@ def cluster_create(args) -> None:
294
292
  # Provision node pools dynamically based on incoming workloads:
295
293
  # Currently autoprovisioning is not supported with Pathways.
296
294
  autoprovisioning_config = None
297
- if not args.enable_pathways and args.enable_autoprovisioning:
295
+ if args.enable_autoprovisioning:
298
296
  xpk_print('Enabling Autoprovisioning')
299
297
  autoprovisioning_config, return_code = enable_autoprovisioning_on_cluster(
300
298
  args, system
@@ -398,7 +396,7 @@ def cluster_cacheimage(args) -> None:
398
396
  get_cluster_credentials(args)
399
397
  system, return_code = get_system_characteristics(args)
400
398
 
401
- if return_code > 0:
399
+ if return_code > 0 or system is None:
402
400
  xpk_print('Fetching system characteristics failed!')
403
401
  xpk_exit(return_code)
404
402
 
@@ -808,6 +806,7 @@ def scale_up_coredns(args, replicas: int = 15, namespace: str = 'kube-system'):
808
806
 
809
807
  def check_deployment_exists(args, deployment_name: str, namespace: str) -> bool:
810
808
  """Check for the existence of a specific Deployment in a given namespace."""
809
+ # TODO: rewrite this to be more obvious, check if it is correct
811
810
  command = (
812
811
  f'kubectl get deployment {deployment_name} -n'
813
812
  f' {namespace} --ignore-not-found'
@@ -815,11 +814,11 @@ def check_deployment_exists(args, deployment_name: str, namespace: str) -> bool:
815
814
  result = run_command_with_updates(
816
815
  command, 'Waiting for kubeDNS to be checked.', args
817
816
  )
818
- return result
817
+ return result != 0
819
818
 
820
819
 
821
820
  def verify_coredns_readiness(
822
- args, timeout: int = 120, namespace: str = 'kube-system'
821
+ args, timeout: int = 240, namespace: str = 'kube-system'
823
822
  ):
824
823
  """Verifies CoreDNS readiness using kubectl wait commands."""
825
824
  xpk_print('Now verifying CoreDNS readiness...')
@@ -874,7 +873,7 @@ def cleanup_coredns_repo(coredns_repo_full_path: str):
874
873
  xpk_print(f'Error deleting directory {coredns_repo_full_path}: {e}')
875
874
 
876
875
 
877
- def update_coredns(args):
876
+ def update_coredns(args) -> int:
878
877
  """Updates and deploys CoreDNS within a cluster.
879
878
 
880
879
  Args:
xpk/commands/cluster_gcluster.py CHANGED
@@ -310,4 +310,5 @@ def generate_blueprint(
310
310
  system_node_pool_machine_type=args.default_pool_cpu_machine_type,
311
311
  system_node_pool_min_node_count=args.default_pool_cpu_num_nodes,
312
312
  )
313
- return None
313
+ xpk_print('Device type is not supported.')
314
+ xpk_exit(1)
xpk/commands/common.py CHANGED
@@ -50,8 +50,8 @@ def set_cluster_command(args) -> int:
50
50
 
51
51
 
52
52
  def is_TAS_possible(
53
- system_characteristics: SystemCharacteristics,
54
- capacity_type: CapacityType,
53
+ system_characteristics: SystemCharacteristics | None,
54
+ capacity_type: CapacityType | None,
55
55
  flex: bool,
56
56
  ) -> bool:
57
57
  """Check cluster's machine_type and capacity type to determine if Kueue TAS is possible
@@ -71,7 +71,7 @@ def is_TAS_possible(
71
71
  xpk_print('capacity_type data was not found in configmaps.')
72
72
  xpk_exit(1)
73
73
 
74
- if flex:
74
+ if not flex:
75
75
  return False
76
76
 
77
77
  if (
xpk/commands/info.py CHANGED
@@ -51,19 +51,19 @@ def info(args: Namespace) -> None:
51
51
  cqs = run_kueuectl_list_clusterqueue(args)
52
52
  quotas = get_nominal_quotas(cqs)
53
53
 
54
- if lq:
54
+ if lq and lqs is not None:
55
55
  print_formatted_lqs(lqs, quotas)
56
56
 
57
57
  if cq:
58
58
  print_formatted_cqs(cqs, quotas)
59
59
 
60
60
 
61
- def get_nominal_quotas(cqs: list[dict]) -> dict[str, dict[str, str]]:
61
+ def get_nominal_quotas(cqs: str) -> dict[str, dict[str, str]]:
62
62
  """Get quotas from clusterqueues.
63
63
  This function retrieves how much of resource in each flavor is assigned to cluster queue.
64
64
  It parses flavors of passed cluster queues.
65
65
  Args:
66
- - cqs - list of cluster queues.
66
+ - cqs - string containing a list of cluster queues in JSON format.
67
67
  Returns:
68
68
  - dictionary of cluster queues resources quotas in format:
69
69
  {cq_name:{"flavorName:resourceName":quota}}
@@ -75,7 +75,7 @@ def get_nominal_quotas(cqs: list[dict]) -> dict[str, dict[str, str]]:
75
75
  xpk_print(cqs)
76
76
  xpk_exit(1)
77
77
 
78
- quotas = {}
78
+ quotas: dict[str, dict] = {}
79
79
  for cq in cq_list:
80
80
  spec = cq['spec']
81
81
  cq_name = cq['metadata']['name']
@@ -89,7 +89,7 @@ def get_nominal_quotas(cqs: list[dict]) -> dict[str, dict[str, str]]:
89
89
  return quotas
90
90
 
91
91
 
92
- def print_formatted_cqs(cqs: list[dict], nominalQuotas) -> None:
92
+ def print_formatted_cqs(cqs: str, nominalQuotas) -> None:
93
93
  try:
94
94
  cq_list = json.loads(cqs)['items']
95
95
  except ValueError:
@@ -105,7 +105,7 @@ def print_formatted_cqs(cqs: list[dict], nominalQuotas) -> None:
105
105
  )
106
106
 
107
107
 
108
- def print_formatted_lqs(lqs: list[dict], nominalQuotas) -> None:
108
+ def print_formatted_lqs(lqs: str, nominalQuotas) -> None:
109
109
  try:
110
110
  lq_list = json.loads(lqs)['items']
111
111
  except ValueError:
@@ -143,18 +143,18 @@ def parse_queue_lists(
143
143
 
144
144
 
145
145
  def get_flavors_resources_reservations(
146
- cq_name: str, flavors_res: list[dict]
146
+ cq_name: str, flavors_res: dict
147
147
  ) -> dict[str, dict[str, str]]:
148
148
  """Get usage of flavors resources.
149
149
  This function parser flavorsReservation section of clusterQueue of LocalQueue.
150
150
  Args:
151
151
  - cq_name - name of ClusterQueue to which flavors belong.
152
- - flavors_res - list of reservations made by flavors
152
+ - flavors_res - dict of reservations made by flavors
153
153
  Returns:
154
154
  Dict containing usage of each resource in flavor for each flavor in cluster or local queue.
155
155
  Dict format: {cq_name: {{flavor:resource}:reservation}}
156
156
  """
157
- reservations = {}
157
+ reservations: dict[str, dict] = {}
158
158
  reservations[cq_name] = {}
159
159
  for flavor_name, flavor_resources_reservation_list in flavors_res.items():
160
160
  for resource in flavor_resources_reservation_list:
@@ -167,15 +167,15 @@ def get_flavors_resources_reservations(
167
167
 
168
168
  def get_flavors_usage(
169
169
  q_entry: dict, res_field: str, flavor_resource_quotas: dict
170
- ) -> list[dict]:
170
+ ) -> dict[str, str]:
171
171
  """Parse q_entry to retrieve list of each resource usage in flavour.
172
172
  Args:
173
173
  q_entry - single entry into either LocalQueue or ClusterQueue structured as json
174
174
  flavor_resource_quotas - nominalQuota of flavors resource usage for each clusterqueue
175
175
  Returns:
176
- list of dicts where each list entry is in format (key, entry) where:
176
+ Dict where for each (key, value):
177
177
  - key is flavorName:resourceName
178
- - entry is flavorResourceReservation/flavorResourceQuota
178
+ - value is string formatted as 'flavorResourceReservation/flavorResourceQuota'
179
179
  """
180
180
  status = q_entry['status']
181
181
  flavors_res = status[res_field]
xpk/commands/job.py CHANGED
@@ -18,6 +18,7 @@ import re
18
18
  import sys
19
19
 
20
20
  from ruamel.yaml import YAML
21
+ from typing import cast
21
22
 
22
23
  from ..core.commands import run_command_for_value, run_command_with_updates
23
24
  from ..core.cluster import get_cluster_credentials
@@ -84,7 +85,7 @@ def job_info(args):
84
85
 
85
86
 
86
87
  def get_profile(job_yaml: dict) -> str:
87
- containers = (
88
+ containers: list[dict] = (
88
89
  job_yaml.get('spec', {})
89
90
  .get('template', {})
90
91
  .get('spec', {})
@@ -96,13 +97,13 @@ def get_profile(job_yaml: dict) -> str:
96
97
 
97
98
 
98
99
  def get_mounts(job_yaml: dict) -> list[dict]:
99
- containers = (
100
+ containers: list[dict] = (
100
101
  job_yaml.get('spec', {})
101
102
  .get('template', {})
102
103
  .get('spec', {})
103
104
  .get('containers', [])
104
105
  )
105
- mounts = next(iter(containers), {}).get('volumeMounts', [])
106
+ mounts: list[dict] = next(iter(containers), {}).get('volumeMounts', [])
106
107
  return mounts
107
108
 
108
109
 
@@ -112,23 +113,24 @@ def get_kjob_env_vars(job_desc_text: str) -> list[tuple[str, str]]:
112
113
  return search_res
113
114
 
114
115
 
115
- def get_pods(pods_text: str) -> list[str]:
116
+ def get_pods(pods_text: str) -> list[dict[str, str]]:
116
117
  pods_lines = pods_text.strip().split('\n')
117
- pods_lines = [line.split() for line in pods_lines]
118
+ pods_lines_tokenized = [line.split() for line in pods_lines]
118
119
  return [
119
120
  {
120
- 'Name': line[0],
121
- 'Status': line[2],
121
+ 'Name': tokens[0],
122
+ 'Status': tokens[2],
122
123
  }
123
- for line in pods_lines
124
+ for tokens in pods_lines_tokenized
124
125
  ]
125
126
 
126
127
 
127
128
  def get_script_name(job_yaml: dict) -> str | None:
128
- return (
129
+ return cast(
130
+ str | None,
129
131
  job_yaml.get('metadata', {})
130
132
  .get('annotations', {})
131
- .get('kjobctl.x-k8s.io/script', '')
133
+ .get('kjobctl.x-k8s.io/script', ''),
132
134
  )
133
135
 
134
136
 
xpk/commands/kjob_common.py CHANGED
@@ -33,6 +33,7 @@ from ..core.resources import get_cluster_capacity_type, get_cluster_system_chara
33
33
  def add_gpu_networking_annotations_to_command(args, cmd: str) -> str:
34
34
  gpu_type = get_gpu_type_from_cluster(args)
35
35
 
36
+ annotations: tuple
36
37
  if gpu_type == H100_MEGA_DEVICE_TYPE:
37
38
  annotations = get_a3mega_pod_template_annotations(args)
38
39
  elif gpu_type == H200_DEVICE_TYPE:
@@ -40,7 +41,7 @@ def add_gpu_networking_annotations_to_command(args, cmd: str) -> str:
40
41
  elif gpu_type == B200_DEVICE_TYPE:
41
42
  annotations = get_a4_pod_template_annotations(args)
42
43
  else:
43
- annotations = []
44
+ annotations = tuple()
44
45
 
45
46
  flags = [
46
47
  f" --pod-template-annotation {annotation} " for annotation in annotations
xpk/commands/storage.py CHANGED
@@ -141,7 +141,7 @@ def storage_delete(args: Namespace) -> None:
141
141
 
142
142
  def storage_attach(args: Namespace) -> None:
143
143
  add_zone_and_project(args)
144
- manifest = [{}]
144
+ manifest: list[dict] = [{}]
145
145
  if args.type == GCP_FILESTORE_TYPE:
146
146
  if args.instance is None:
147
147
  args.instance = args.name
xpk/commands/workload.py CHANGED
@@ -84,6 +84,7 @@ from ..core.system_characteristics import (
84
84
  from ..core.vertex import create_vertex_experiment
85
85
  from ..core.workload import (
86
86
  check_if_workload_exists,
87
+ get_jobsets_list_gcp_link,
87
88
  get_workload_list,
88
89
  wait_for_job_completion,
89
90
  zone_to_region,
@@ -226,7 +227,8 @@ spec:
226
227
  metadata:
227
228
  labels:
228
229
  xpk.google.com/workload: {args.workload}
229
- annotations: {annotations}
230
+ annotations:
231
+ {annotations}
230
232
  spec:
231
233
  priorityClassName: {args.priority}
232
234
  restartPolicy: Never
@@ -319,7 +321,7 @@ def workload_create(args) -> None:
319
321
  xpk_print('Starting workload create', flush=True)
320
322
  system, return_code = get_system_characteristics(args)
321
323
 
322
- if return_code > 0:
324
+ if return_code > 0 or system is None:
323
325
  xpk_print('Fetching system characteristics failed!')
324
326
  xpk_exit(return_code)
325
327
 
@@ -345,7 +347,7 @@ def workload_create(args) -> None:
345
347
  ):
346
348
  xpk_print(
347
349
  'Warning: Cluster has been created using XPK version:'
348
- f' {cluster_config_map["xpk_version"]} but the XPK version you are'
350
+ f' {cluster_xpk_version} but the XPK version you are'
349
351
  f' using to schedule workload is: {XPK_CURRENT_VERSION}. Some features'
350
352
  ' might not be available for this cluster. We recommend to'
351
353
  ' upgrade/downgrade your XPK version or cluster by running `xpk'
@@ -354,7 +356,7 @@ def workload_create(args) -> None:
354
356
 
355
357
  debugging_dashboard_id = None
356
358
 
357
- tensorboard_config = {}
359
+ tensorboard_config: dict | None = {}
358
360
  if VERTEX_TENSORBOARD_FEATURE_FLAG and args.use_vertex_tensorboard:
359
361
  tensorboard_config = create_vertex_experiment(args)
360
362
  # exit if failed to create Experiment in Vertex AI
@@ -450,8 +452,8 @@ def workload_create(args) -> None:
450
452
  - action: FailJobSet
451
453
  onJobFailureReasons:
452
454
  - PodFailurePolicy"""
453
- restart_on_exit_codes = get_restart_exit_codes(args)
454
- restart_on_exit_codes = ','.join(map(str, restart_on_exit_codes))
455
+ restart_on_exit_codes_list = get_restart_exit_codes(args)
456
+ restart_on_exit_codes = ','.join(map(str, restart_on_exit_codes_list))
455
457
  pod_failure_policy = f"""
456
458
  podFailurePolicy:
457
459
  rules:
@@ -760,4 +762,8 @@ def workload_list(args) -> None:
760
762
  xpk_print(f'List Job request returned ERROR {return_code}')
761
763
  xpk_exit(return_code)
762
764
  xpk_print(f'Workload List Output:\n{return_value}')
765
+
766
+ workload_list_gcp_link = get_jobsets_list_gcp_link(project=args.project)
767
+ xpk_print(f'See your workloads in Cloud Console: {workload_list_gcp_link}')
768
+
763
769
  xpk_exit(0)
xpk/core/blueprint/blueprint_generator.py CHANGED
@@ -34,7 +34,7 @@ from ..system_characteristics import get_system_characteristics_by_device_type
34
34
  from .blueprint_definitions import Blueprint, DeploymentGroup, DeploymentModule
35
35
  from ..kueue import KUEUE_VERSION
36
36
 
37
- yaml = yaml.YAML()
37
+ yaml_parser = yaml.YAML()
38
38
 
39
39
  a3high_device_type = H100_DEVICE_TYPE
40
40
  a3mega_device_type = H100_MEGA_DEVICE_TYPE
@@ -52,7 +52,7 @@ blueprint_dependencies_dir = {
52
52
  }
53
53
 
54
54
  cluster_toolkit_url = "github.com/GoogleCloudPlatform/cluster-toolkit"
55
- cluster_toolkit_version = "v1.57.1"
55
+ cluster_toolkit_version = "v1.62.2"
56
56
 
57
57
 
58
58
  class BlueprintGeneratorOutput:
@@ -1019,7 +1019,7 @@ class BlueprintGenerator:
1019
1019
  ) -> str:
1020
1020
  blueprint_path = self._get_blueprint_path(blueprint_name, prefix)
1021
1021
  with open(blueprint_path, "w+", encoding="utf-8") as blueprint_file:
1022
- yaml.dump(xpk_blueprint, blueprint_file)
1022
+ yaml_parser.dump(xpk_blueprint, blueprint_file)
1023
1023
  return blueprint_path
1024
1024
 
1025
1025
  def _get_blueprint_path(self, blueprint_name, prefix: str = ""):
@@ -1033,7 +1033,7 @@ class BlueprintGenerator:
1033
1033
  ensure_directory_exists(storage_path_with_prefix)
1034
1034
  return storage_path_with_prefix
1035
1035
 
1036
- def blueprint_exists(self, blueprint_name, prefix: str = ""):
1036
+ def blueprint_exists(self, blueprint_name, prefix: str = "") -> bool:
1037
1037
  blueprint_path = self._get_blueprint_path(blueprint_name, prefix)
1038
1038
  return os.path.exists(blueprint_path)
1039
1039
 
@@ -1061,6 +1061,6 @@ class BlueprintGenerator:
1061
1061
  }
1062
1062
 
1063
1063
 
1064
- yaml.register_class(Blueprint)
1065
- yaml.register_class(DeploymentGroup)
1066
- yaml.register_class(DeploymentModule)
1064
+ yaml_parser.register_class(Blueprint)
1065
+ yaml_parser.register_class(DeploymentGroup)
1066
+ yaml_parser.register_class(DeploymentModule)
xpk/core/blueprint/blueprint_test.py ADDED
@@ -0,0 +1,218 @@
1
+ """
2
+ Copyright 2024 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import os
18
+ import shutil
19
+
20
+ import ruamel.yaml
21
+
22
+ from xpk.core.blueprint.blueprint_definitions import Blueprint
23
+ from xpk.core.blueprint.blueprint_generator import BlueprintGenerator
24
+ from xpk.core.capacity import CapacityType
25
+
26
+ yaml = ruamel.yaml.YAML()
27
+
28
+ yaml.register_class(Blueprint)
29
+
30
+ a3_yaml_test_path = "src/xpk/core/blueprint/testing/data/a3_mega.yaml"
31
+ a3_spot_yaml_test_path = "src/xpk/core/blueprint/testing/data/a3_mega_spot.yaml"
32
+ a3_ultra_yaml_test_path = "src/xpk/core/blueprint/testing/data/a3_ultra.yaml"
33
+ a4_yaml_test_path = "src/xpk/core/blueprint/testing/data/a4.yaml"
34
+ config_map_filename = "config-map.yaml.tftpl"
35
+ kueue_conf_filename = "kueue-xpk-configuration.yaml.tftpl"
36
+ tmp_test_dir = "/tmp/xpk_test"
37
+
38
+
39
+ def prepare_test():
40
+ if os.path.exists(tmp_test_dir):
41
+ shutil.rmtree(tmp_test_dir)
42
+ os.mkdir(tmp_test_dir)
43
+
44
+
45
+ def test_generate_a3_mega_blueprint():
46
+ prepare_test()
47
+ blueprint_name = "xpk-gke-a3-megagpu"
48
+ bp_generator = BlueprintGenerator(tmp_test_dir)
49
+ bp = bp_generator.generate_a3_mega_blueprint(
50
+ project_id="foo",
51
+ cluster_name="bar",
52
+ blueprint_name=blueprint_name,
53
+ prefix="prefix",
54
+ region="us-central1",
55
+ zone="us-central1-c",
56
+ auth_cidr="10.0.0.0/32",
57
+ reservation_placement_policy={
58
+ "type": "COMPACT",
59
+ "name": "test-reservation-placement",
60
+ },
61
+ reservation="test-reservation",
62
+ capacity_type=CapacityType.RESERVATION,
63
+ system_node_pool_min_node_count=5,
64
+ )
65
+
66
+ assert bp.blueprint_file.endswith("/prefix/xpk-gke-a3-megagpu.yaml")
67
+
68
+ with open(a3_yaml_test_path, encoding="utf-8") as stream:
69
+ ctk_yaml = yaml.load(stream)
70
+ with open(bp.blueprint_file, encoding="utf-8") as generated_blueprint:
71
+ ctk_test = yaml.load(generated_blueprint)
72
+ assert ctk_yaml.blueprint_name == ctk_test.blueprint_name
73
+ assert ctk_test.terraform_backend_defaults is None
74
+ assert ctk_yaml.toolkit_modules_url == ctk_test.toolkit_modules_url
75
+ assert (
76
+ ctk_yaml.toolkit_modules_version == ctk_test.toolkit_modules_version
77
+ )
78
+ assert ctk_yaml.vars == ctk_test.vars
79
+ assert ctk_test.deployment_groups == ctk_yaml.deployment_groups
80
+ assert os.path.exists(
81
+ os.path.join(
82
+ tmp_test_dir, "prefix", blueprint_name, config_map_filename
83
+ )
84
+ )
85
+ assert os.path.exists(
86
+ os.path.join(
87
+ tmp_test_dir, "prefix", blueprint_name, kueue_conf_filename
88
+ )
89
+ )
90
+
91
+ shutil.rmtree(tmp_test_dir)
92
+
93
+
94
+ def test_generate_a3_mega_spot_blueprint():
95
+ prepare_test()
96
+ blueprint_name = "xpk-gke-a3-megagpu"
97
+ bp_generator = BlueprintGenerator(tmp_test_dir)
98
+ bp = bp_generator.generate_a3_mega_blueprint(
99
+ project_id="foo",
100
+ cluster_name="bar",
101
+ blueprint_name=blueprint_name,
102
+ prefix="prefix",
103
+ region="us-central1",
104
+ zone="us-central1-c",
105
+ auth_cidr="10.0.0.0/32",
106
+ capacity_type=CapacityType.SPOT,
107
+ system_node_pool_min_node_count=5,
108
+ )
109
+
110
+ assert bp.blueprint_file.endswith("/prefix/xpk-gke-a3-megagpu.yaml")
111
+
112
+ with open(a3_spot_yaml_test_path, encoding="utf-8") as stream:
113
+ ctk_yaml = yaml.load(stream)
114
+ with open(bp.blueprint_file, encoding="utf-8") as generated_blueprint:
115
+ ctk_test = yaml.load(generated_blueprint)
116
+ assert ctk_yaml.blueprint_name == ctk_test.blueprint_name
117
+ assert ctk_test.terraform_backend_defaults is None
118
+ assert ctk_yaml.toolkit_modules_url == ctk_test.toolkit_modules_url
119
+ assert (
120
+ ctk_yaml.toolkit_modules_version == ctk_test.toolkit_modules_version
121
+ )
122
+ assert ctk_yaml.vars == ctk_test.vars
123
+ assert ctk_test.deployment_groups == ctk_yaml.deployment_groups
124
+
125
+ shutil.rmtree(tmp_test_dir)
126
+
127
+
128
+ def test_generate_a3_ultra_blueprint():
129
+ prepare_test()
130
+ blueprint_name = "xpk-gke-a3-ultra"
131
+ bp_generator = BlueprintGenerator(tmp_test_dir)
132
+ bp = bp_generator.generate_a3_ultra_blueprint(
133
+ project_id="foo",
134
+ cluster_name="gke-a3-ultra",
135
+ blueprint_name=blueprint_name,
136
+ region="us-central1",
137
+ zone="us-central1-c",
138
+ auth_cidr="10.0.0.0/32",
139
+ reservation="test-reservation",
140
+ system_node_pool_machine_type="e2-standard-16",
141
+ capacity_type=CapacityType.RESERVATION,
142
+ gcs_bucket="test-bucket",
143
+ prefix="testdir",
144
+ )
145
+ with open(a3_ultra_yaml_test_path, encoding="utf-8") as stream:
146
+ ctk_yaml = yaml.load(stream)
147
+ with open(bp.blueprint_file, encoding="utf-8") as generated_blueprint:
148
+ ctk_test = yaml.load(generated_blueprint)
149
+ assert ctk_yaml.blueprint_name == ctk_test.blueprint_name
150
+ assert (
151
+ ctk_yaml.terraform_backend_defaults
152
+ == ctk_test.terraform_backend_defaults
153
+ )
154
+ assert ctk_yaml.toolkit_modules_url == ctk_test.toolkit_modules_url
155
+ assert (
156
+ ctk_yaml.toolkit_modules_version == ctk_test.toolkit_modules_version
157
+ )
158
+ assert ctk_test.deployment_groups == ctk_yaml.deployment_groups
159
+ assert os.path.exists(
160
+ os.path.join(
161
+ tmp_test_dir, "testdir", blueprint_name, "mlgru-disable.yaml"
162
+ )
163
+ )
164
+ assert os.path.exists(
165
+ os.path.join(
166
+ tmp_test_dir, "testdir", blueprint_name, "nccl-installer.yaml"
167
+ )
168
+ )
169
+
170
+ shutil.rmtree(tmp_test_dir)
171
+
172
+
173
+ def test_generate_a4_blueprint():
174
+ prepare_test()
175
+ blueprint_name = "xpk-gke-a4"
176
+ bp_generator = BlueprintGenerator(tmp_test_dir)
177
+ bp = bp_generator.generate_a4_blueprint(
178
+ project_id="foo",
179
+ cluster_name="gke-a4",
180
+ blueprint_name=blueprint_name,
181
+ region="us-central1",
182
+ zone="us-central1-c",
183
+ auth_cidr="10.0.0.0/32",
184
+ reservation="test-reservation",
185
+ system_node_pool_machine_type="e2-standard-16",
186
+ capacity_type=CapacityType.RESERVATION,
187
+ gcs_bucket="test-bucket",
188
+ prefix="testdir",
189
+ )
190
+ with open(a4_yaml_test_path, encoding="utf-8") as stream:
191
+ ctk_yaml = yaml.load(stream)
192
+ with open(bp.blueprint_file, encoding="utf-8") as generated_blueprint:
193
+ ctk_test = yaml.load(generated_blueprint)
194
+ assert ctk_yaml.blueprint_name == ctk_test.blueprint_name
195
+ assert (
196
+ ctk_yaml.terraform_backend_defaults
197
+ == ctk_test.terraform_backend_defaults
198
+ )
199
+ assert ctk_yaml.toolkit_modules_url == ctk_test.toolkit_modules_url
200
+ assert (
201
+ ctk_yaml.toolkit_modules_version == ctk_test.toolkit_modules_version
202
+ )
203
+ assert ctk_test.deployment_groups == ctk_yaml.deployment_groups
204
+ assert os.path.exists(
205
+ os.path.join(
206
+ tmp_test_dir, "testdir", blueprint_name, "storage_crd.yaml"
207
+ )
208
+ )
209
+ assert os.path.exists(
210
+ os.path.join(
211
+ tmp_test_dir,
212
+ "testdir",
213
+ blueprint_name,
214
+ "nccl-rdma-installer-a4.yaml",
215
+ )
216
+ )
217
+
218
+ shutil.rmtree(tmp_test_dir)
xpk/core/capacity.py CHANGED
@@ -195,10 +195,12 @@ def get_capacity_arguments_from_capacity_type(
195
195
  capacity_args = '--spot'
196
196
  case CapacityType.FLEX_START:
197
197
  capacity_args = (
198
- ' --flex-start --enable-queued-provisioning --enable-autoscaling'
198
+ ' --flex-start --enable-autoscaling'
199
199
  ' --location-policy=ANY --reservation-affinity=none'
200
200
  f' --no-enable-autorepair --max-nodes={max_nodes}'
201
201
  )
202
+ if args.num_slices <= 1:
203
+ capacity_args += ' --enable-queued-provisioning'
202
204
  case CapacityType.RESERVATION:
203
205
  capacity_args = (
204
206
  f'--reservation-affinity=specific --reservation={args.reservation}'