xpk 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. xpk/commands/cluster.py +29 -30
  2. xpk/commands/cluster_gcluster.py +19 -14
  3. xpk/commands/cluster_test.py +1 -21
  4. xpk/commands/common.py +39 -6
  5. xpk/commands/common_test.py +170 -0
  6. xpk/commands/info.py +9 -5
  7. xpk/commands/inspector.py +33 -4
  8. xpk/commands/inspector_test.py +142 -0
  9. xpk/commands/workload.py +35 -17
  10. xpk/commands/workload_test.py +70 -3
  11. xpk/core/blueprint/blueprint_generator.py +19 -8
  12. xpk/core/blueprint/testing/data/a3_ultra.yaml +3 -1
  13. xpk/core/blueprint/testing/data/a4.yaml +3 -1
  14. xpk/core/capacity.py +37 -17
  15. xpk/core/capacity_test.py +66 -1
  16. xpk/core/cluster.py +10 -10
  17. xpk/core/cluster_private.py +3 -3
  18. xpk/core/cluster_test.py +29 -2
  19. xpk/core/docker_container.py +55 -30
  20. xpk/core/docker_manager.py +4 -4
  21. xpk/core/docker_resources.py +4 -1
  22. xpk/core/kueue_manager.py +6 -8
  23. xpk/core/kueue_manager_test.py +4 -5
  24. xpk/core/nap.py +14 -3
  25. xpk/core/nodepool.py +46 -13
  26. xpk/core/nodepool_test.py +143 -8
  27. xpk/core/pathways.py +4 -8
  28. xpk/core/remote_state/fuse_remote_state.py +1 -1
  29. xpk/core/scheduling.py +16 -13
  30. xpk/core/scheduling_test.py +15 -7
  31. xpk/core/system_characteristics.py +6 -0
  32. xpk/core/telemetry.py +11 -1
  33. xpk/core/telemetry_test.py +39 -0
  34. xpk/core/testing/commands_tester.py +26 -0
  35. xpk/core/testing/commands_tester_test.py +20 -1
  36. xpk/core/workload_decorators/rdma_decorator.py +9 -0
  37. xpk/parser/cluster.py +11 -1
  38. xpk/parser/cluster_test.py +59 -1
  39. xpk/parser/common.py +11 -0
  40. xpk/parser/storage.py +3 -3
  41. xpk/utils/console.py +1 -1
  42. xpk/utils/feature_flags.py +7 -3
  43. {xpk-1.0.0.dist-info → xpk-1.1.1.dist-info}/METADATA +37 -21
  44. {xpk-1.0.0.dist-info → xpk-1.1.1.dist-info}/RECORD +48 -55
  45. xpk-1.1.1.dist-info/top_level.txt +1 -0
  46. integration/README.md +0 -19
  47. integration/__init__.py +0 -15
  48. integration/docker_manager_test.py +0 -102
  49. integration/gcluster_a3mega_test.py +0 -215
  50. integration/gcluster_a3ultra_test.py +0 -187
  51. integration/gcluster_a4_test.py +0 -187
  52. integration/gcluster_test.py +0 -107
  53. xpk/utils/user_input.py +0 -48
  54. xpk/utils/user_input_test.py +0 -92
  55. xpk-1.0.0.dist-info/top_level.txt +0 -2
  56. {xpk-1.0.0.dist-info → xpk-1.1.1.dist-info}/WHEEL +0 -0
  57. {xpk-1.0.0.dist-info → xpk-1.1.1.dist-info}/entry_points.txt +0 -0
  58. {xpk-1.0.0.dist-info → xpk-1.1.1.dist-info}/licenses/LICENSE +0 -0
xpk/core/cluster.py CHANGED
@@ -158,7 +158,7 @@ def install_nri_on_cluster() -> int:
 
 
 def get_cluster_nodes_info() -> list[dict]:
-  """Get list of cluster's nodes descrition in yaml format
+  """Get list of cluster's nodes description in yaml format
 
   Returns:
     List of nodes info yaml objects.
@@ -393,11 +393,13 @@ def project_id_to_project_number(project_id: str) -> str:
 def setup_k8s_env(args) -> k8s_client.ApiClient:
   add_zone_and_project(args)
   get_cluster_credentials(args)
-  args.project_number = (
-      project_id_to_project_number(args.project)
-      if not args.dry_run
-      else abs(hash(args.project) % (10**12))  # 12 digit hash
-  )
+  # Use provided project number if available, otherwise fetch via API
+  if getattr(args, 'project_number', None):
+    xpk_print(f'Using provided project number: {args.project_number}')
+  elif args.dry_run:
+    args.project_number = abs(hash(args.project) % (10**12))  # 12 digit hash
+  else:
+    args.project_number = project_id_to_project_number(args.project)
 
   config.load_kube_config()
   return k8s_client.ApiClient()
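
The new resolution order in setup_k8s_env is: an explicitly provided project number wins, a dry run falls back to a deterministic 12-digit hash (no API call), and only otherwise is the number fetched from the API. A minimal standalone sketch of that precedence (resolve_project_number is a hypothetical helper; the lookup callable stands in for project_id_to_project_number):

    def resolve_project_number(provided, project, dry_run, lookup):
      """Precedence: explicit value > dry-run hash > API lookup."""
      if provided:  # caller-supplied project number wins
        return provided
      if dry_run:   # deterministic stand-in, avoids the API call
        return abs(hash(project) % (10**12))  # 12-digit hash
      return lookup(project)  # real lookup, e.g. project_id_to_project_number
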
@@ -716,10 +718,8 @@ def get_cluster_credentials(args) -> int:
       location=location,
       dns_endpoint=True,
   )
-  if return_code != 0:
-    return return_code
 
-  if not _are_credentials_valid():
+  if return_code != 0 or not _are_credentials_valid():
     xpk_print('Detected error. Retrying without --dns-endpoint flag...')
     return_code = _get_credentials(
         project=args.project,
@@ -751,6 +751,6 @@ def _get_credentials(
 def _are_credentials_valid() -> bool:
   kubectl_command = 'kubectl get pods'
   kubectl_return_code = run_command_with_updates(
-      kubectl_command, 'Test kubectl credentials'
+      kubectl_command, 'Test kubectl credentials', verbose=False
   )
   return kubectl_return_code == 0
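
Taken together, the two hunks above remove the early return so that either a non-zero exit code from the --dns-endpoint attempt or credentials that fail the (now quiet, verbose=False) kubectl probe trigger the same fallback. A small sketch of the combined control flow, assuming _get_credentials and _are_credentials_valid behave as shown in the diff:

    def get_credentials_with_fallback(get_credentials, credentials_valid) -> int:
      # First attempt uses the DNS-based control-plane endpoint.
      return_code = get_credentials(dns_endpoint=True)
      if return_code != 0 or not credentials_valid():
        # Both failure modes now funnel into one retry without --dns-endpoint.
        return_code = get_credentials(dns_endpoint=False)
      return return_code
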
xpk/core/cluster_private.py CHANGED
@@ -61,7 +61,7 @@ def authorize_private_cluster_access_if_necessary(args) -> int:
   if new_authorized_networks_needed or not is_current_machine_in_network:
     return update_cluster_new_authorized_networks(args, authorized_networks)
 
-  xpk_print("Current machine's IP adrress is already authorized.")
+  xpk_print("Current machine's IP address is already authorized.")
   return 0
 
 
@@ -84,7 +84,7 @@ def add_current_machine_to_networks_if_needed(
       is_current_machine_in_any_network(authorized_networks)
   )
   if is_current_machine_in_network_return_code != 0:
-    xpk_print("Error on checking current machine's IP adrress.")
+    xpk_print("Error on checking current machine's IP address.")
     return is_current_machine_in_network_return_code, False, authorized_networks
 
   if not is_current_machine_in_network:
@@ -148,7 +148,7 @@ def is_cluster_private(args) -> bool:
 
 
 def get_cluster_authorized_networks(args) -> list[str]:
-  """Retreives the networks list that are authorized to have access to Control Plane.
+  """Retrieves the networks list that are authorized to have access to Control Plane.
   Args:
     args: user provided arguments for running the command.
 
xpk/core/cluster_test.py CHANGED
@@ -41,11 +41,15 @@ def command_args(mocker: MockerFixture):
   return mocker.Mock(cluster="cluster", project="project", zone="zone")
 
 
-def test_get_cluster_credentials_returns_1_when_retrieval_command_fails(
+def test_get_cluster_credentials_returns_1_when_retrieval_commands_fail(
     commands_tester: CommandsTester, command_args
 ):
   commands_tester.set_result_for_command(
-      (1, ""), "gcloud container clusters get-credentials"
+      (1, ""), "gcloud container clusters get-credentials", " --dns-endpoint"
+  )
+  commands_tester.set_result_for_command(
+      (1, ""),
+      "gcloud container clusters get-credentials",
   )
   assert get_cluster_credentials(command_args) == 1
 
@@ -95,6 +99,29 @@ def test_get_cluster_credentials_retries_without_dns_when_dns_retrieval_fails(
   assert len(non_dns_endpoint_commands) == 1
 
 
+def test_get_cluster_credentials_retries_without_dns_when_dns_retrieval_returns_error(
+    commands_tester: CommandsTester, command_args
+):
+  commands_tester.set_result_for_command(
+      (1, ""), "gcloud container clusters get-credentials", "--dns-endpoint"
+  )
+  commands_tester.set_result_for_command(
+      (0, ""),
+      "gcloud container clusters get-credentials",
+  )
+
+  assert get_cluster_credentials(command_args) == 0
+
+  non_dns_endpoint_commands = [
+      c
+      for c in commands_tester.get_matching_commands(
+          "gcloud container clusters get-credentials"
+      )
+      if "dns-endpoint" not in c
+  ]
+  assert len(non_dns_endpoint_commands) == 1
+
+
 def test_update_cluster_with_lustre_driver_if_necessary_with_default_port_runs_correct_checks(
     commands_tester: CommandsTester, command_args
 ):
xpk/core/docker_container.py CHANGED
@@ -17,9 +17,7 @@ limitations under the License.
 from ..utils.console import xpk_exit, xpk_print
 from .docker_image import setup_docker_image
 from .docker_resources import (
-    add_container_ports,
     add_image_pull_policy_for_pw_or_gpu,
-    add_jax_coordinator_port,
     get_env_container,
     get_main_container_resources,
     get_volume_mounts,
@@ -32,12 +30,18 @@ from .system_characteristics import (
 )
 
 
-def get_main_and_sidecar_container(args, system, docker_image) -> str:
+def get_main_and_sidecar_container(
+    args,
+    system: SystemCharacteristics,
+    docker_image: str,
+    parallel_containers: int,
+) -> str:
   """Generate yaml for main and sidecar container.
   Args:
     args: user provided arguments for running the command.
     system: system characteristics
     docker_image: docker image
+    parallel_containers: number of containers to run per VM.
 
   Returns:
     str:
@@ -46,7 +50,9 @@ def get_main_and_sidecar_container(args, system, docker_image) -> str:
   resource_type = AcceleratorTypeToAcceleratorCharacteristics[
       system.accelerator_type
   ].resource_type
-  main_container = get_main_container(args, system, docker_image, resource_type)
+  main_container = get_main_container(
+      args, system, docker_image, resource_type, parallel_containers
+  )
   yaml = """- name: stacktrace-explorer
   image: busybox:1.28
   args: [/bin/sh, -c, "check_signal() (while [ ! -f /shared-volume/stacktrace_signal ]; do sleep 1; done; pid=$(pidof 'tail'); kill $pid;); check_signal & while [ ! -d /tmp/debugging ]; do sleep 60; done; while [ ! -e /tmp/debugging/* ]; do sleep 60; done; tail -n+1 -f /tmp/debugging/*; exit 0;"]
@@ -61,13 +67,20 @@ def get_main_and_sidecar_container(args, system, docker_image) -> str:
   return yaml.format(main_container=main_container)
 
 
-def get_main_container(args, system, docker_image, resource_type) -> str:
+def get_main_container(
+    args,
+    system: SystemCharacteristics,
+    docker_image: str,
+    resource_type,
+    parallel_containers: int,
+) -> str:
   """Generate yaml for main container including the xpk command.
   Args:
     args: user provided arguments for running the command.
     system: system characteristics
     docker_image: docker image
    resource_type: The label to describe the resource type for TPUs/GPUs/CPUs.
+    parallel_containers: number of containers to run per VM.
 
   Returns:
     str:
@@ -112,13 +125,12 @@ def get_main_container(args, system, docker_image, resource_type) -> str:
         'touch /shared-volume/stacktrace_signal; '
     )
 
-  yaml = """- name: {docker_name}
+  containers = []
+  container_yaml = """
+- name: {docker_name}
   image: {docker_image}
   {image_pull_policy}
   env: {env}
-  ports:
-  {container_ports}
-  {jax_coordinator_port}
   securityContext:
     privileged: true
   command:
@@ -145,37 +157,46 @@ def get_main_container(args, system, docker_image, resource_type) -> str:
       limits:
        {resources}
 """
+  docker_name = get_main_container_docker_image(args, system)
   volume_mounts = get_volume_mounts(args, system)
   if volume_mounts != '':
-    yaml += """
+    container_yaml += """
      volumeMounts:
      {volume_mounts}
 """
-  return yaml.format(
-      args=args,
-      system=system,
-      image_pull_policy=add_image_pull_policy_for_pw_or_gpu(args, system),
-      env=get_env_container(args, system),
-      container_ports=add_container_ports(args, system),
-      jax_coordinator_port=add_jax_coordinator_port(system),
-      docker_name=get_main_container_docker_image(args, system),
-      docker_image=docker_image,
-      gsutil_test_command=gsutil_test_command,
-      command=command,
-      tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
-      gpu_workload_terminate_command=gpu_workload_terminate_command,
-      xpk_internal_commands=xpk_internal_commands,
-      resources=get_main_container_resources(args, system, resource_type),
-      volume_mounts=volume_mounts,
-  )
+  env = get_env_container(args, system)
+  image_pull_policy = add_image_pull_policy_for_pw_or_gpu(args, system)
+  for i in range(parallel_containers):
+    docker_name_sufix = f'-{i + 1}' if parallel_containers > 1 else ''
+    containers.append(
+        container_yaml.format(
+            args=args,
+            system=system,
+            image_pull_policy=image_pull_policy,
+            env=env,
+            docker_name=f'{docker_name}{docker_name_sufix}',
+            docker_image=docker_image,
+            gsutil_test_command=gsutil_test_command,
+            command=command,
+            tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
+            gpu_workload_terminate_command=gpu_workload_terminate_command,
+            xpk_internal_commands=xpk_internal_commands,
+            resources=get_main_container_resources(args, system, resource_type),
+            volume_mounts=volume_mounts,
+        )
+    )
+  return ''.join(containers)
 
 
-def get_user_workload_container(args, system: SystemCharacteristics):
+def get_user_workload_container(
+    args, system: SystemCharacteristics, parallel_containers: int
+):
   """Deploy user workload container
 
   Args:
     args: user provided args.
     system: system characteristics.
+    parallel_containers: number of containers to run per VM.
 
   Returns:
     container: main container
@@ -202,11 +223,15 @@ def get_user_workload_container(args, system: SystemCharacteristics):
        'Sidecar container to display stack traces for TPU workloads will also'
        ' be deployed.'
    )
-    container = get_main_and_sidecar_container(args, system, docker_image)
+    container = get_main_and_sidecar_container(
+        args, system, docker_image, parallel_containers
+    )
     # Get GKE debugging dashboard only when sidecar container is deployed for TPU workloads
     debugging_dashboard_id = get_gke_debugging_dashboard(args)
   else:
-    container = get_main_container(args, system, docker_image, resource_type)
+    container = get_main_container(
+        args, system, docker_image, resource_type, parallel_containers
+    )
   return container, debugging_dashboard_id
 
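
For illustration, the loop above only suffixes container names when more than one container is requested: with parallel_containers=3 and a base name of my-job (a made-up example), the rendered containers are my-job-1, my-job-2, my-job-3, while parallel_containers=1 keeps the bare my-job. A sketch of just that naming rule:

    def container_names(base: str, parallel_containers: int) -> list[str]:
      # Mirrors the docker_name_sufix logic in get_main_container above.
      return [
          f'{base}-{i + 1}' if parallel_containers > 1 else base
          for i in range(parallel_containers)
      ]

    assert container_names('my-job', 1) == ['my-job']
    assert container_names('my-job', 3) == ['my-job-1', 'my-job-2', 'my-job-3']
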
xpk/core/docker_manager.py CHANGED
@@ -44,7 +44,7 @@ class CommandRunner(ABC):
 
   @abstractmethod
   def initialize(self) -> None:
-    """initialize is a method that should implement all steps neccessary to run command.
+    """initialize is a method that should implement all steps necessary to run command.
 
     Returns:
       None
@@ -95,7 +95,7 @@ class DockerManager(CommandRunner):
   - gcloud_cfg_path (str) : path to directory containing gcloud configuration
   - working_dir (str) : path to directory in which gcluster deployment directory will be saved
   - client (DockerClient) : docker client
-  - nocache (bool) : wheter to use docker cache when building image
+  - nocache (bool) : whether to use docker cache when building image
   - img_name (str) : name of docker image to create
   - container_name (str) : name of the container that will be created from img_name
   - rm_container_after (bool) : if set to True, docker container in which command is executed will be removed after each execution.
@@ -294,12 +294,12 @@ class DockerManager(CommandRunner):
       xpk_print(f"error while building image {self.img_name}: {e.msg}")
       xpk_exit(dockerBuildErrorCode)
     except APIError as e:
-      xpk_print(f"erro while building image {self.img_name}: {e.explanation}")
+      xpk_print(f"error while building image {self.img_name}: {e.explanation}")
       xpk_exit(dockerBuildErrorCode)
     except TypeError as e:
       xpk_print(f"TypeError while building image {self.img_name}: {e.args}")
       xpk_exit(dockerBuildErrorCode)
-    xpk_print("Docker image build succesfully.")
+    xpk_print("Docker image build successfully.")
     os.remove(self.dockerfile_path)
     tmp_dockerfile_dir = "/".join(self.dockerfile_path.split("/")[:-1])
     os.rmdir(tmp_dockerfile_dir)
xpk/core/docker_resources.py CHANGED
@@ -53,7 +53,10 @@ def get_main_container_resources(
     offset_vCPUs = int(system.chips_per_vm) * 0.95
     return f'{resource_type}: {offset_vCPUs}'
 
-  return f'{resource_type}: {system.chips_per_vm}'
+  return (
+      f'{resource_type}:'
+      f' {int(system.chips_per_vm / system.parallel_containers)}'
+  )
 
 
 def get_env_container(args, system: SystemCharacteristics) -> str:
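
The per-container limit is now the VM's chip count divided evenly across the parallel containers. As a worked example (values and resource label chosen only for illustration), chips_per_vm=8 with parallel_containers=2 yields a limit of 4 per container:

    chips_per_vm, parallel_containers = 8, 2  # example values
    resource_type = 'nvidia.com/gpu'          # illustrative resource label
    print(f'{resource_type}: {int(chips_per_vm / parallel_containers)}')
    # -> nvidia.com/gpu: 4
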
xpk/core/kueue_manager.py CHANGED
@@ -41,8 +41,8 @@ from ..utils.console import xpk_print, xpk_exit, ask_for_user_consent
 from ..utils.templates import TEMPLATE_PATH, get_templates_absolute_path
 from packaging.version import Version
 
-KUEUE_VERSION = Version("v0.14.3")
-LATEST_BREAKING_VERSION = Version("v0.14.0")
+KUEUE_VERSION = Version("v0.15.2")
+LATEST_BREAKING_VERSION = Version("v0.15.0")
 WAIT_FOR_KUEUE_TIMEOUT = "10m"
 CLUSTER_QUEUE_NAME = "cluster-queue"
 LOCAL_QUEUE_NAME = "multislice-queue"
@@ -290,6 +290,7 @@ class KueueManager:
         cpu_limit=cpu_limit,
         memory_limit=memory_limit,
         topology_name=topology_name,
+        configure_super_slicing=kueue_config.configure_super_slicing,
     )
 
     config_yaml = template.render(context)
@@ -316,6 +317,7 @@ class KueueManager:
       cpu_limit: int,
       memory_limit: str,
       topology_name: str | None,
+      configure_super_slicing: bool,
   ) -> Dict[str, Any]:
     """Prepares the context for the Jinja2 template."""
     # Main accelerator flavor
@@ -328,11 +330,7 @@
       key, value = accelerator_label.split(":", 1)
       node_labels_dict[key] = value.strip()
 
-    if system.supports_super_slicing:
-      node_labels_dict["cloud.google.com/gke-tpu-partition-4x4x4-state"] = (
-          "HEALTHY"
-      )
-    elif not autoprovisioning:
+    if not autoprovisioning and not configure_super_slicing:
       machine_label = create_machine_label(system)
       if machine_label:
         key, value = machine_label.split(":", 1)
@@ -383,7 +381,7 @@
     })
 
     admission_checks = []
-    if system.supports_super_slicing:
+    if configure_super_slicing:
       admission_checks.append("ss-kueue-operator")
     if flex and is_queued_cluster(num_slices, system.accelerator_type):
       admission_checks.append("dws-prov")
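
The bump to v0.15.2 also moves the breaking-change marker to v0.15.0. Judging by the tests below, the upgrade path only runs the preparation/consent flow when the installed version predates LATEST_BREAKING_VERSION; a sketch of that comparison (needs_preparation is a hypothetical helper name, the semantics are inferred from the tests):

    from packaging.version import Version

    LATEST_BREAKING_VERSION = Version("v0.15.0")

    def needs_preparation(installed: Version) -> bool:
      # Consent/preparation only when crossing a breaking release.
      return installed < LATEST_BREAKING_VERSION

    assert not needs_preparation(Version("0.15.0"))  # the "no preparation" test
    assert needs_preparation(Version("0.14.3"))      # upgrade across 0.15.0
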
xpk/core/kueue_manager_test.py CHANGED
@@ -113,7 +113,7 @@ def test_install_or_upgrade_when_outdated(
   result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)
 
   assert result == 0
-  mock_commands.assert_command_run("kubectl apply", "v0.14.3/manifests.yaml")
+  mock_commands.assert_command_run("kubectl apply", "v0.15.2/manifests.yaml")
   mock_commands.assert_command_run("kubectl apply -f", "/tmp/")
 
 
@@ -126,7 +126,7 @@ def test_install_or_upgrade_when_not_installed(
   result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)
 
   assert result == 0
-  mock_commands.assert_command_run("kubectl apply", "v0.14.3/manifests.yaml")
+  mock_commands.assert_command_run("kubectl apply", "v0.15.2/manifests.yaml")
   mock_commands.assert_command_run("kubectl apply -f", "/tmp/")
 
 
@@ -135,7 +135,7 @@ def test_upgrade_when_no_breaking_changes_between_versions_no_preparation_needed(
     kueue_manager: KueueManager,
     mock_ask_for_user_consent: MagicMock,
 ):
-  set_installed_kueue_version(mock_commands, Version("0.14.0"))
+  set_installed_kueue_version(mock_commands, Version("0.15.0"))
 
   kueue_manager.install_or_upgrade(KUEUE_CONFIG)
 
@@ -162,7 +162,7 @@ def test_upgrade_with_breaking_changes_between_versions_runs_preparation(
   assert result == 0
   mock_ask_for_user_consent.assert_called_once()
   assert (
-      "CHANGELOG/CHANGELOG-0.14.md"
+      "CHANGELOG/CHANGELOG-0.15.md"
      in mock_ask_for_user_consent.mock_calls[0].args[0]
   )
   mock_commands.assert_command_run(
@@ -492,7 +492,6 @@ def test_configure_generates_correct_manifest_with_super_slicing(
   assert resource_flavor["spec"]["topologyName"] == "super-slice-topology"
   assert resource_flavor["spec"]["nodeLabels"] == {
       "cloud.google.com/gke-tpu-accelerator": "tpu7x",
-      "cloud.google.com/gke-tpu-partition-4x4x4-state": "HEALTHY",
   }
   topology = _first(doc for doc in manifest_docs if doc["kind"] == "Topology")
   assert topology["metadata"]["name"] == "super-slice-topology"
xpk/core/nap.py CHANGED
@@ -24,7 +24,8 @@ from .capacity import (
     CapacityType,
     get_capacity_node_selectors_from_capacity_type,
     get_capacity_type,
-    verify_reservation_exists,
+    get_reservations_list,
+    verify_reservations_exist,
 )
 from .commands import run_command_with_updates, run_commands
 from .gcloud_context import get_cluster_location
@@ -345,14 +346,24 @@ def get_autoprovisioning_node_selector_args(args) -> tuple[str, int]:
   )
   if return_code != 0:
     return node_selector_args, return_code
-  return_code = verify_reservation_exists(args)
+  return_code = verify_reservations_exist(args)
   if return_code > 0:
     xpk_print('Unable to verify reservation name saved in config map.')
     return node_selector_args, return_code
 
   # Check if reservation id is valid. Shared function with cluster creation.
+  reservation_name = None
+  if capacity_type_str == CapacityType.RESERVATION.name:
+    reservations = get_reservations_list(args)
+    if len(reservations) > 1:
+      xpk_print('Error: NAP based clusters only support a single reservation.')
+      return node_selector_args, 1
+    reservation_name = reservations[0] if len(reservations) > 0 else None
+
   node_selector_args, return_code = (
-      get_capacity_node_selectors_from_capacity_type(args, capacity_type_str)
+      get_capacity_node_selectors_from_capacity_type(
+          capacity_type_str, reservation_name
+      )
   )
   if return_code != 0:
     xpk_print('Unable to get node selectors from capacity type.')
xpk/core/nodepool.py CHANGED
@@ -14,7 +14,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from typing import List
+from typing import Iterator, List
+from itertools import cycle
 
 from ..utils.feature_flags import FeatureFlags
 from ..utils.console import ask_for_user_consent, xpk_print
@@ -25,6 +26,7 @@ from .capacity import (
     CapacityType,
     get_capacity_arguments_from_capacity_type,
     get_capacity_type,
+    get_reservations_list,
     print_reservations,
 )
 from .commands import run_command_for_value, run_commands, FailedCommand
@@ -85,12 +87,6 @@ def run_gke_node_pool_create_command(
     max_nodes = system.vms_per_slice
   else:
     max_nodes = 1000
-  capacity_args, return_code = get_capacity_arguments_from_capacity_type(
-      args, capacity_type, max_nodes, system.accelerator_type
-  )
-  if return_code > 0:
-    xpk_print('Parsing capacity arguments failed!')
-    return return_code
 
   desired_node_pool_count = (
       1 if system.accelerator_type == AcceleratorType.GPU else args.num_slices
@@ -274,9 +270,34 @@ def run_gke_node_pool_create_command(
 
   create_commands = []
   create_task_names = []
-  for node_pool_name in desired_node_pool_names:
-    if node_pool_name in node_pools_to_remain:
-      continue
+  node_pools_to_create = [
+      np for np in desired_node_pool_names if np not in node_pools_to_remain
+  ]
+
+  reservations_iter: Iterator[str] | None = None
+  if capacity_type == CapacityType.RESERVATION:
+    reservations = get_reservations_list(args)
+    if (
+        _validate_reservation_count(reservations, len(node_pools_to_create))
+        != 0
+    ):
+      return 1
+    reservations_iter = (
+        cycle(reservations) if len(reservations) == 1 else iter(reservations)
+    )
+
+  for node_pool_name in node_pools_to_create:
+    capacity_args, return_code = get_capacity_arguments_from_capacity_type(
+        args,
+        capacity_type,
+        max_nodes,
+        system.accelerator_type,
+        reservation_name=next(reservations_iter) if reservations_iter else None,
+    )
+    if return_code > 0:
+      xpk_print('Parsing capacity arguments failed!')
+      return return_code
+
     command = (
         'gcloud beta container node-pools create'
         f' {node_pool_name}'
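
The iterator choice above means a single reservation is reused for every new node pool (cycle), while a list whose length matches the pool count is consumed one-to-one (iter). A small sketch of the resulting assignment (names are illustrative):

    from itertools import cycle

    def assign_reservations(reservations, node_pools):
      # One reservation: shared by all pools. N reservations: one per pool.
      it = cycle(reservations) if len(reservations) == 1 else iter(reservations)
      return {pool: next(it) for pool in node_pools}

    assert assign_reservations(['res-a'], ['np-0', 'np-1']) == {
        'np-0': 'res-a', 'np-1': 'res-a'}
    assert assign_reservations(['res-a', 'res-b'], ['np-0', 'np-1']) == {
        'np-0': 'res-a', 'np-1': 'res-b'}
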
@@ -632,7 +653,7 @@ def ensure_resource_policy_exists(
 ) -> None:
   return_code, _ = run_command_for_value(
       (
-          'gcloud compute resource-policies describe'
+          'gcloud beta compute resource-policies describe'
           f' {resource_policy_name}'
           f' --project={project}'
           f' --region={zone_to_region(zone)}'
@@ -643,13 +664,12 @@ def ensure_resource_policy_exists(
   if return_code == 0:
     return
 
-  # TODO: b/465696970 - Verify the flag below before launching SUPER_SLICING:
   accelerator_topology_mode = (
       ' --accelerator-topology-mode=PROVISION_ONLY' if super_slicing else ''
   )
   return_code, _ = run_command_for_value(
       (
-          'gcloud compute resource-policies create workload-policy'
+          'gcloud beta compute resource-policies create workload-policy'
           f' {resource_policy_name} --project={project} --region={zone_to_region(zone)} --type=HIGH_THROUGHPUT'
           f' --accelerator-topology={topology}{accelerator_topology_mode}'
       ),
@@ -658,3 +678,16 @@ def ensure_resource_policy_exists(
 
   if return_code != 0:
     raise RuntimeError('Unable to create resource policy')
+
+
+def _validate_reservation_count(
+    reservations: List[str], num_node_pools_to_create: int
+) -> int:
+  """Validate that reservation count matches new nodepool count or is 1."""
+  if len(reservations) > 1 and len(reservations) != num_node_pools_to_create:
+    xpk_print(
+        f'Error: Number of reservations ({len(reservations)}) must match'
+        f' the number of NEW nodepools ({num_node_pools_to_create}) or be 1.'
+    )
+    return 1
+  return 0
+ return 0