xpk 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. xpk/commands/cluster.py +29 -30
  2. xpk/commands/cluster_gcluster.py +19 -14
  3. xpk/commands/cluster_test.py +1 -21
  4. xpk/commands/common.py +39 -6
  5. xpk/commands/common_test.py +170 -0
  6. xpk/commands/info.py +9 -5
  7. xpk/commands/inspector.py +33 -4
  8. xpk/commands/inspector_test.py +142 -0
  9. xpk/commands/workload.py +22 -8
  10. xpk/commands/workload_test.py +70 -3
  11. xpk/core/blueprint/blueprint_generator.py +19 -8
  12. xpk/core/blueprint/testing/data/a3_ultra.yaml +3 -1
  13. xpk/core/blueprint/testing/data/a4.yaml +3 -1
  14. xpk/core/capacity.py +37 -17
  15. xpk/core/capacity_test.py +66 -1
  16. xpk/core/cluster.py +10 -10
  17. xpk/core/cluster_private.py +3 -3
  18. xpk/core/cluster_test.py +29 -2
  19. xpk/core/docker_container.py +31 -24
  20. xpk/core/docker_manager.py +4 -4
  21. xpk/core/docker_resources.py +4 -1
  22. xpk/core/kueue_manager.py +6 -8
  23. xpk/core/kueue_manager_test.py +4 -5
  24. xpk/core/nap.py +14 -3
  25. xpk/core/nodepool.py +46 -13
  26. xpk/core/nodepool_test.py +143 -8
  27. xpk/core/remote_state/fuse_remote_state.py +1 -1
  28. xpk/core/scheduling.py +4 -1
  29. xpk/core/scheduling_test.py +1 -1
  30. xpk/core/system_characteristics.py +6 -0
  31. xpk/core/telemetry.py +11 -1
  32. xpk/core/telemetry_test.py +39 -0
  33. xpk/core/testing/commands_tester.py +26 -0
  34. xpk/core/testing/commands_tester_test.py +20 -1
  35. xpk/core/workload_decorators/rdma_decorator.py +9 -0
  36. xpk/parser/cluster.py +11 -1
  37. xpk/parser/cluster_test.py +59 -1
  38. xpk/parser/common.py +11 -0
  39. xpk/parser/storage.py +3 -3
  40. xpk/utils/console.py +1 -1
  41. xpk/utils/feature_flags.py +7 -3
  42. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/METADATA +37 -21
  43. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/RECORD +47 -54
  44. xpk-1.1.0.dist-info/top_level.txt +1 -0
  45. integration/README.md +0 -19
  46. integration/__init__.py +0 -15
  47. integration/docker_manager_test.py +0 -102
  48. integration/gcluster_a3mega_test.py +0 -215
  49. integration/gcluster_a3ultra_test.py +0 -187
  50. integration/gcluster_a4_test.py +0 -187
  51. integration/gcluster_test.py +0 -107
  52. xpk/utils/user_input.py +0 -48
  53. xpk/utils/user_input_test.py +0 -92
  54. xpk-1.0.0.dist-info/top_level.txt +0 -2
  55. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/WHEEL +0 -0
  56. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/entry_points.txt +0 -0
  57. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -61,7 +61,7 @@ def authorize_private_cluster_access_if_necessary(args) -> int:
  if new_authorized_networks_needed or not is_current_machine_in_network:
  return update_cluster_new_authorized_networks(args, authorized_networks)

- xpk_print("Current machine's IP adrress is already authorized.")
+ xpk_print("Current machine's IP address is already authorized.")
  return 0


@@ -84,7 +84,7 @@ def add_current_machine_to_networks_if_needed(
  is_current_machine_in_any_network(authorized_networks)
  )
  if is_current_machine_in_network_return_code != 0:
- xpk_print("Error on checking current machine's IP adrress.")
+ xpk_print("Error on checking current machine's IP address.")
  return is_current_machine_in_network_return_code, False, authorized_networks

  if not is_current_machine_in_network:
@@ -148,7 +148,7 @@ def is_cluster_private(args) -> bool:


  def get_cluster_authorized_networks(args) -> list[str]:
- """Retreives the networks list that are authorized to have access to Control Plane.
+ """Retrieves the networks list that are authorized to have access to Control Plane.
  Args:
  args: user provided arguments for running the command.

xpk/core/cluster_test.py CHANGED
@@ -41,11 +41,15 @@ def command_args(mocker: MockerFixture):
  return mocker.Mock(cluster="cluster", project="project", zone="zone")


- def test_get_cluster_credentials_returns_1_when_retrieval_command_fails(
+ def test_get_cluster_credentials_returns_1_when_retrieval_commands_fail(
  commands_tester: CommandsTester, command_args
  ):
  commands_tester.set_result_for_command(
- (1, ""), "gcloud container clusters get-credentials"
+ (1, ""), "gcloud container clusters get-credentials", " --dns-endpoint"
+ )
+ commands_tester.set_result_for_command(
+ (1, ""),
+ "gcloud container clusters get-credentials",
  )
  assert get_cluster_credentials(command_args) == 1

@@ -95,6 +99,29 @@ def test_get_cluster_credentials_retries_without_dns_when_dns_retrieval_fails(
  assert len(non_dns_endpoint_commands) == 1


+ def test_get_cluster_credentials_retries_without_dns_when_dns_retrieval_returns_error(
+ commands_tester: CommandsTester, command_args
+ ):
+ commands_tester.set_result_for_command(
+ (1, ""), "gcloud container clusters get-credentials", "--dns-endpoint"
+ )
+ commands_tester.set_result_for_command(
+ (0, ""),
+ "gcloud container clusters get-credentials",
+ )
+
+ assert get_cluster_credentials(command_args) == 0
+
+ non_dns_endpoint_commands = [
+ c
+ for c in commands_tester.get_matching_commands(
+ "gcloud container clusters get-credentials"
+ )
+ if "dns-endpoint" not in c
+ ]
+ assert len(non_dns_endpoint_commands) == 1
+
+
  def test_update_cluster_with_lustre_driver_if_necessary_with_default_port_runs_correct_checks(
  commands_tester: CommandsTester, command_args
  ):
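The new test above pins down the fallback in get_cluster_credentials: the `--dns-endpoint` variant of `gcloud container clusters get-credentials` is tried first and, when it exits non-zero, the plain command is retried exactly once. A minimal sketch of that control flow, assuming a hypothetical runner callable rather than xpk's actual API:

```python
from typing import Callable, Tuple

# Hypothetical runner: executes a shell command and returns (exit_code, output).
CommandRunner = Callable[[str], Tuple[int, str]]


def get_credentials_with_fallback(run: CommandRunner, cluster: str) -> int:
    """Try the DNS-endpoint variant first; retry without the flag on failure."""
    base = f'gcloud container clusters get-credentials {cluster}'
    code, _ = run(f'{base} --dns-endpoint')
    if code == 0:
        return 0
    # Fallback path asserted by the tests: exactly one non-DNS-endpoint call.
    code, _ = run(base)
    return code
```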
@@ -17,9 +17,7 @@ limitations under the License.
  from ..utils.console import xpk_exit, xpk_print
  from .docker_image import setup_docker_image
  from .docker_resources import (
- add_container_ports,
  add_image_pull_policy_for_pw_or_gpu,
- add_jax_coordinator_port,
  get_env_container,
  get_main_container_resources,
  get_volume_mounts,
@@ -112,13 +110,12 @@ def get_main_container(args, system, docker_image, resource_type) -> str:
  'touch /shared-volume/stacktrace_signal; '
  )

- yaml = """- name: {docker_name}
+ containers = []
+ container_yaml = """
+ - name: {docker_name}
  image: {docker_image}
  {image_pull_policy}
  env: {env}
- ports:
- {container_ports}
- {jax_coordinator_port}
  securityContext:
  privileged: true
  command:
@@ -145,29 +142,39 @@ def get_main_container(args, system, docker_image, resource_type) -> str:
  limits:
  {resources}
  """
+ docker_name = get_main_container_docker_image(args, system)
  volume_mounts = get_volume_mounts(args, system)
  if volume_mounts != '':
- yaml += """
+ container_yaml += """
  volumeMounts:
  {volume_mounts}
  """
- return yaml.format(
- args=args,
- system=system,
- image_pull_policy=add_image_pull_policy_for_pw_or_gpu(args, system),
- env=get_env_container(args, system),
- container_ports=add_container_ports(args, system),
- jax_coordinator_port=add_jax_coordinator_port(system),
- docker_name=get_main_container_docker_image(args, system),
- docker_image=docker_image,
- gsutil_test_command=gsutil_test_command,
- command=command,
- tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
- gpu_workload_terminate_command=gpu_workload_terminate_command,
- xpk_internal_commands=xpk_internal_commands,
- resources=get_main_container_resources(args, system, resource_type),
- volume_mounts=volume_mounts,
- )
+ # pathways job running on 2 parallel containers is not verified yet
+ if args.use_pathways:
+ system.parallel_containers = 1
+
+ env = get_env_container(args, system)
+ image_pull_policy = add_image_pull_policy_for_pw_or_gpu(args, system)
+ for i in range(system.parallel_containers):
+ docker_name_sufix = f'-{i + 1}' if system.parallel_containers > 1 else ''
+ containers.append(
+ container_yaml.format(
+ args=args,
+ system=system,
+ image_pull_policy=image_pull_policy,
+ env=env,
+ docker_name=f'{docker_name}{docker_name_sufix}',
+ docker_image=docker_image,
+ gsutil_test_command=gsutil_test_command,
+ command=command,
+ tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
+ gpu_workload_terminate_command=gpu_workload_terminate_command,
+ xpk_internal_commands=xpk_internal_commands,
+ resources=get_main_container_resources(args, system, resource_type),
+ volume_mounts=volume_mounts,
+ )
+ )
+ return ''.join(containers)


  def get_user_workload_container(args, system: SystemCharacteristics):
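The reworked get_main_container above now emits one container block per entry in system.parallel_containers and appends a numeric suffix to the container name only when more than one block is produced. A minimal standalone sketch of that naming and concatenation; the template and names below are simplified placeholders, not the real manifest:

```python
def render_containers(docker_name: str, parallel_containers: int) -> str:
    """Render one YAML list entry per container; suffix names only when needed."""
    template = '- name: {name}\n  image: example-image\n'
    blocks = []
    for i in range(parallel_containers):
        suffix = f'-{i + 1}' if parallel_containers > 1 else ''
        blocks.append(template.format(name=f'{docker_name}{suffix}'))
    return ''.join(blocks)


# render_containers('main', 1) yields a single block named 'main';
# render_containers('main', 2) yields blocks named 'main-1' and 'main-2'.
```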
@@ -44,7 +44,7 @@ class CommandRunner(ABC):

  @abstractmethod
  def initialize(self) -> None:
- """initialize is a method that should implement all steps neccessary to run command.
+ """initialize is a method that should implement all steps necessary to run command.

  Returns:
  None
@@ -95,7 +95,7 @@ class DockerManager(CommandRunner):
  - gcloud_cfg_path (str) : path to directory containing gcloud configuration
  - working_dir (str) : path to directory in which gcluster deployment directory will be saved
  - client (DockerClient) : docker client
- - nocache (bool) : wheter to use docker cache when building image
+ - nocache (bool) : whether to use docker cache when building image
  - img_name (str) : name of docker image to create
  - container_name (str) : name of the container that will be created from img_name
  - rm_container_after (bool) : if set to True, docker container in which command is executed will be removed after each execution.
@@ -294,12 +294,12 @@ class DockerManager(CommandRunner):
  xpk_print(f"error while building image {self.img_name}: {e.msg}")
  xpk_exit(dockerBuildErrorCode)
  except APIError as e:
- xpk_print(f"erro while building image {self.img_name}: {e.explanation}")
+ xpk_print(f"error while building image {self.img_name}: {e.explanation}")
  xpk_exit(dockerBuildErrorCode)
  except TypeError as e:
  xpk_print(f"TypeError while building image {self.img_name}: {e.args}")
  xpk_exit(dockerBuildErrorCode)
- xpk_print("Docker image build succesfully.")
+ xpk_print("Docker image build successfully.")
  os.remove(self.dockerfile_path)
  tmp_dockerfile_dir = "/".join(self.dockerfile_path.split("/")[:-1])
  os.rmdir(tmp_dockerfile_dir)
@@ -53,7 +53,10 @@ def get_main_container_resources(
  offset_vCPUs = int(system.chips_per_vm) * 0.95
  return f'{resource_type}: {offset_vCPUs}'

- return f'{resource_type}: {system.chips_per_vm}'
+ return (
+ f'{resource_type}:'
+ f' {int(system.chips_per_vm / system.parallel_containers)}'
+ )


  def get_env_container(args, system: SystemCharacteristics) -> str:
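With the change above, the main container's limit becomes the VM's chip count divided evenly across its parallel containers. A tiny illustration of that arithmetic; the resource name used here is only an example:

```python
def per_container_limit(
    resource_type: str, chips_per_vm: int, parallel_containers: int
) -> str:
    """Split a VM's chips evenly across its parallel containers."""
    return f'{resource_type}: {int(chips_per_vm / parallel_containers)}'


# per_container_limit('google.com/tpu', 4, 2) == 'google.com/tpu: 2'
```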
xpk/core/kueue_manager.py CHANGED
@@ -41,8 +41,8 @@ from ..utils.console import xpk_print, xpk_exit, ask_for_user_consent
  from ..utils.templates import TEMPLATE_PATH, get_templates_absolute_path
  from packaging.version import Version

- KUEUE_VERSION = Version("v0.14.3")
- LATEST_BREAKING_VERSION = Version("v0.14.0")
+ KUEUE_VERSION = Version("v0.15.2")
+ LATEST_BREAKING_VERSION = Version("v0.15.0")
  WAIT_FOR_KUEUE_TIMEOUT = "10m"
  CLUSTER_QUEUE_NAME = "cluster-queue"
  LOCAL_QUEUE_NAME = "multislice-queue"
@@ -290,6 +290,7 @@ class KueueManager:
  cpu_limit=cpu_limit,
  memory_limit=memory_limit,
  topology_name=topology_name,
+ configure_super_slicing=kueue_config.configure_super_slicing,
  )

  config_yaml = template.render(context)
@@ -316,6 +317,7 @@ class KueueManager:
  cpu_limit: int,
  memory_limit: str,
  topology_name: str | None,
+ configure_super_slicing: bool,
  ) -> Dict[str, Any]:
  """Prepares the context for the Jinja2 template."""
  # Main accelerator flavor
@@ -328,11 +330,7 @@ class KueueManager:
  key, value = accelerator_label.split(":", 1)
  node_labels_dict[key] = value.strip()

- if system.supports_super_slicing:
- node_labels_dict["cloud.google.com/gke-tpu-partition-4x4x4-state"] = (
- "HEALTHY"
- )
- elif not autoprovisioning:
+ if not autoprovisioning and not configure_super_slicing:
  machine_label = create_machine_label(system)
  if machine_label:
  key, value = machine_label.split(":", 1)
@@ -383,7 +381,7 @@ class KueueManager:
  })

  admission_checks = []
- if system.supports_super_slicing:
+ if configure_super_slicing:
  admission_checks.append("ss-kueue-operator")
  if flex and is_queued_cluster(num_slices, system.accelerator_type):
  admission_checks.append("dws-prov")
@@ -113,7 +113,7 @@ def test_install_or_upgrade_when_outdated(
  result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)

  assert result == 0
- mock_commands.assert_command_run("kubectl apply", "v0.14.3/manifests.yaml")
+ mock_commands.assert_command_run("kubectl apply", "v0.15.2/manifests.yaml")
  mock_commands.assert_command_run("kubectl apply -f", "/tmp/")


@@ -126,7 +126,7 @@ def test_install_or_upgrade_when_not_installed(
  result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)

  assert result == 0
- mock_commands.assert_command_run("kubectl apply", "v0.14.3/manifests.yaml")
+ mock_commands.assert_command_run("kubectl apply", "v0.15.2/manifests.yaml")
  mock_commands.assert_command_run("kubectl apply -f", "/tmp/")


@@ -135,7 +135,7 @@ def test_upgrade_when_no_breaking_changes_between_versions_no_preparation_needed
  kueue_manager: KueueManager,
  mock_ask_for_user_consent: MagicMock,
  ):
- set_installed_kueue_version(mock_commands, Version("0.14.0"))
+ set_installed_kueue_version(mock_commands, Version("0.15.0"))

  kueue_manager.install_or_upgrade(KUEUE_CONFIG)

@@ -162,7 +162,7 @@ def test_upgrade_with_breaking_changes_between_versions_runs_preparation(
  assert result == 0
  mock_ask_for_user_consent.assert_called_once()
  assert (
- "CHANGELOG/CHANGELOG-0.14.md"
+ "CHANGELOG/CHANGELOG-0.15.md"
  in mock_ask_for_user_consent.mock_calls[0].args[0]
  )
  mock_commands.assert_command_run(
@@ -492,7 +492,6 @@ def test_configure_generates_correct_manifest_with_super_slicing(
  assert resource_flavor["spec"]["topologyName"] == "super-slice-topology"
  assert resource_flavor["spec"]["nodeLabels"] == {
  "cloud.google.com/gke-tpu-accelerator": "tpu7x",
- "cloud.google.com/gke-tpu-partition-4x4x4-state": "HEALTHY",
  }
  topology = _first(doc for doc in manifest_docs if doc["kind"] == "Topology")
  assert topology["metadata"]["name"] == "super-slice-topology"
xpk/core/nap.py CHANGED
@@ -24,7 +24,8 @@ from .capacity import (
  CapacityType,
  get_capacity_node_selectors_from_capacity_type,
  get_capacity_type,
- verify_reservation_exists,
+ get_reservations_list,
+ verify_reservations_exist,
  )
  from .commands import run_command_with_updates, run_commands
  from .gcloud_context import get_cluster_location
@@ -345,14 +346,24 @@ def get_autoprovisioning_node_selector_args(args) -> tuple[str, int]:
  )
  if return_code != 0:
  return node_selector_args, return_code
- return_code = verify_reservation_exists(args)
+ return_code = verify_reservations_exist(args)
  if return_code > 0:
  xpk_print('Unable to verify reservation name saved in config map.')
  return node_selector_args, return_code

  # Check if reservation id is valid. Shared function with cluster creation.
+ reservation_name = None
+ if capacity_type_str == CapacityType.RESERVATION.name:
+ reservations = get_reservations_list(args)
+ if len(reservations) > 1:
+ xpk_print('Error: NAP based clusters only support a single reservation.')
+ return node_selector_args, 1
+ reservation_name = reservations[0] if len(reservations) > 0 else None
+
  node_selector_args, return_code = (
- get_capacity_node_selectors_from_capacity_type(args, capacity_type_str)
+ get_capacity_node_selectors_from_capacity_type(
+ capacity_type_str, reservation_name
+ )
  )
  if return_code != 0:
  xpk_print('Unable to get node selectors from capacity type.')
xpk/core/nodepool.py CHANGED
@@ -14,7 +14,8 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- from typing import List
+ from typing import Iterator, List
+ from itertools import cycle

  from ..utils.feature_flags import FeatureFlags
  from ..utils.console import ask_for_user_consent, xpk_print
@@ -25,6 +26,7 @@ from .capacity import (
  CapacityType,
  get_capacity_arguments_from_capacity_type,
  get_capacity_type,
+ get_reservations_list,
  print_reservations,
  )
  from .commands import run_command_for_value, run_commands, FailedCommand
@@ -85,12 +87,6 @@ def run_gke_node_pool_create_command(
  max_nodes = system.vms_per_slice
  else:
  max_nodes = 1000
- capacity_args, return_code = get_capacity_arguments_from_capacity_type(
- args, capacity_type, max_nodes, system.accelerator_type
- )
- if return_code > 0:
- xpk_print('Parsing capacity arguments failed!')
- return return_code

  desired_node_pool_count = (
  1 if system.accelerator_type == AcceleratorType.GPU else args.num_slices
@@ -274,9 +270,34 @@ def run_gke_node_pool_create_command(

  create_commands = []
  create_task_names = []
- for node_pool_name in desired_node_pool_names:
- if node_pool_name in node_pools_to_remain:
- continue
+ node_pools_to_create = [
+ np for np in desired_node_pool_names if np not in node_pools_to_remain
+ ]
+
+ reservations_iter: Iterator[str] | None = None
+ if capacity_type == CapacityType.RESERVATION:
+ reservations = get_reservations_list(args)
+ if (
+ _validate_reservation_count(reservations, len(node_pools_to_create))
+ != 0
+ ):
+ return 1
+ reservations_iter = (
+ cycle(reservations) if len(reservations) == 1 else iter(reservations)
+ )
+
+ for node_pool_name in node_pools_to_create:
+ capacity_args, return_code = get_capacity_arguments_from_capacity_type(
+ args,
+ capacity_type,
+ max_nodes,
+ system.accelerator_type,
+ reservation_name=next(reservations_iter) if reservations_iter else None,
+ )
+ if return_code > 0:
+ xpk_print('Parsing capacity arguments failed!')
+ return return_code
+
  command = (
  'gcloud beta container node-pools create'
  f' {node_pool_name}'
@@ -632,7 +653,7 @@ def ensure_resource_policy_exists(
  ) -> None:
  return_code, _ = run_command_for_value(
  (
- 'gcloud compute resource-policies describe'
+ 'gcloud beta compute resource-policies describe'
  f' {resource_policy_name}'
  f' --project={project}'
  f' --region={zone_to_region(zone)}'
@@ -643,13 +664,12 @@ def ensure_resource_policy_exists(
  if return_code == 0:
  return

- # TODO: b/465696970 - Verify the flag below before launching SUPER_SLICING:
  accelerator_topology_mode = (
  ' --accelerator-topology-mode=PROVISION_ONLY' if super_slicing else ''
  )
  return_code, _ = run_command_for_value(
  (
- 'gcloud compute resource-policies create workload-policy'
+ 'gcloud beta compute resource-policies create workload-policy'
  f' {resource_policy_name} --project={project} --region={zone_to_region(zone)} --type=HIGH_THROUGHPUT'
  f' --accelerator-topology={topology}{accelerator_topology_mode}'
  ),
@@ -658,3 +678,16 @@ def ensure_resource_policy_exists(

  if return_code != 0:
  raise RuntimeError('Unable to create resource policy')
+
+
+ def _validate_reservation_count(
+ reservations: List[str], num_node_pools_to_create: int
+ ) -> int:
+ """Validate that reservation count matches new nodepool count or is 1."""
+ if len(reservations) > 1 and len(reservations) != num_node_pools_to_create:
+ xpk_print(
+ f'Error: Number of reservations ({len(reservations)}) must match'
+ f' the number of NEW nodepools ({num_node_pools_to_create}) or be 1.'
+ )
+ return 1
+ return 0
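The reservation handling added above distributes reservations across the node pools being created: a single reservation is reused for every pool (via itertools.cycle), while multiple reservations must match the number of new pools and are consumed one per pool. A standalone sketch of that assignment, with an illustrative helper name and return shape rather than xpk's actual signatures:

```python
from itertools import cycle
from typing import Dict, Iterator, List, Optional


def assign_reservations(
    reservations: List[str], node_pools: List[str]
) -> Dict[str, Optional[str]]:
    """Map each new node pool to a reservation, mirroring the diff above."""
    if not reservations:
        return {np: None for np in node_pools}
    if len(reservations) > 1 and len(reservations) != len(node_pools):
        raise ValueError('reservation count must be 1 or match the new node pool count')
    # One reservation is cycled for all pools; several are consumed one per pool.
    source: Iterator[str] = (
        cycle(reservations) if len(reservations) == 1 else iter(reservations)
    )
    return {np: next(source) for np in node_pools}


# assign_reservations(['res1'], ['np-0', 'np-1'])         -> both pools use 'res1'
# assign_reservations(['res1', 'res2'], ['np-0', 'np-1']) -> one reservation per pool
```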
xpk/core/nodepool_test.py CHANGED
@@ -20,6 +20,7 @@ from xpk.core.nodepool import (
  ensure_resource_policy_exists,
  get_desired_node_pool_names,
  run_gke_node_pool_create_command,
+ _validate_reservation_count,
  )
  from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics, DockerPlatform, GpuConfig
  from xpk.core.commands import FailedCommand
@@ -103,6 +104,7 @@ def commands_tester(mocker):
  return CommandsTester(
  mocker,
  run_command_for_value_path="xpk.core.nodepool.run_command_for_value",
+ run_command_batch_path="xpk.core.commands.run_command_batch",
  )


@@ -119,7 +121,7 @@ def test_ensure_resource_policy_exists_with_existing_policy_retrieves_existing_p

  assert len(commands_tester.commands_history) == 1
  commands_tester.assert_command_run(
- "gcloud compute resource-policies describe resource-policy",
+ "gcloud beta compute resource-policies describe resource-policy",
  "--project=test-project",
  "--region=us-central1",
  )
@@ -129,7 +131,7 @@ def test_ensure_resource_policy_exists_without_existing_policy_creates_policy(
  commands_tester: CommandsTester,
  ):
  commands_tester.set_result_for_command(
- (1, ""), "gcloud compute resource-policies describe"
+ (1, ""), "gcloud beta compute resource-policies describe"
  )

  ensure_resource_policy_exists(
@@ -142,16 +144,17 @@ def test_ensure_resource_policy_exists_without_existing_policy_creates_policy(

  assert len(commands_tester.commands_history) == 2
  commands_tester.assert_command_run(
- "gcloud compute resource-policies describe"
+ "gcloud beta compute resource-policies describe"
  )
  commands_tester.assert_command_run(
- "gcloud compute resource-policies create workload-policy resource-policy",
+ "gcloud beta compute resource-policies create workload-policy"
+ " resource-policy",
  "--project=test-project",
  "--region=us-central1",
  "--accelerator-topology=2x2x1",
  )
  commands_tester.assert_command_not_run(
- "gcloud compute resource-policies create workload-policy",
+ "gcloud beta compute resource-policies create workload-policy",
  "--accelerator-topology-mode",
  )

@@ -160,7 +163,7 @@ def test_ensure_resource_policy_exists_without_existing_policy_creates_policy_fo
  commands_tester: CommandsTester,
  ):
  commands_tester.set_result_for_command(
- (1, ""), "gcloud compute resource-policies describe"
+ (1, ""), "gcloud beta compute resource-policies describe"
  )

  ensure_resource_policy_exists(
@@ -172,7 +175,7 @@ def test_ensure_resource_policy_exists_without_existing_policy_creates_policy_fo
  )

  commands_tester.assert_command_run(
- "gcloud compute resource-policies create workload-policy",
+ "gcloud beta compute resource-policies create workload-policy",
  "--accelerator-topology-mode",
  )

@@ -182,7 +185,7 @@ def test_ensure_resource_policy_exits_without_existing_policy_throws_when_creati
  ):
  with pytest.raises(RuntimeError):
  commands_tester.set_result_for_command(
- (1, ""), "gcloud compute resource-policies"
+ (1, ""), "gcloud beta compute resource-policies"
  )

  ensure_resource_policy_exists(
@@ -433,3 +436,135 @@ def test_display_nodepool_creation_ignores_logs_without_errors(
  mock_xpk_print.call_args_list[0].args[0]
  == "Create Nodepools returned ERROR 1"
  )
+
+
+ def test_validate_reservation_count_mismatch(mock_xpk_print):
+ result = _validate_reservation_count(
+ ["res1", "res2"], num_node_pools_to_create=3
+ )
+
+ assert result == 1
+ assert mock_xpk_print.call_count == 1
+ assert (
+ "reservations (2) must match the number of NEW nodepools (3)"
+ in mock_xpk_print.call_args_list[0].args[0]
+ )
+
+
+ def test_run_gke_node_pool_create_command_multiple_reservations(
+ mocker,
+ commands_tester: CommandsTester,
+ ):
+ mocker.patch(
+ "xpk.core.nodepool.get_cluster_location", return_value="us-central1"
+ )
+ mocker.patch("xpk.core.capacity.verify_reservations_exist", return_value=0)
+ args = mocker.Mock(
+ num_slices=2,
+ reservation="res1,res2",
+ tpu_type="v4-8",
+ device_type=None,
+ cluster="test-cluster",
+ project="test-project",
+ zone="us-central1-a",
+ on_demand=False,
+ spot=False,
+ flex=False,
+ enable_workload_identity=False,
+ enable_gcsfuse_csi_driver=False,
+ host_maintenance_interval="AS_NEEDED",
+ custom_nodepool_arguments="",
+ )
+ system = SystemCharacteristics(
+ topology="2x2x1",
+ vms_per_slice=2,
+ gke_accelerator="tpu-v4",
+ gce_machine_type="ct4p-hightpu-4t",
+ chips_per_vm=4,
+ accelerator_type=AcceleratorType.TPU,
+ device_type="v4-8",
+ requires_workload_policy=False,
+ supports_sub_slicing=False,
+ supports_super_slicing=False,
+ supports_accelerator_network_profile=True,
+ docker_platform=DockerPlatform.AMD,
+ )
+ commands_tester.set_result_for_command(
+ (0, ""), "gcloud beta container node-pools list"
+ )
+
+ result = run_gke_node_pool_create_command(args, system, "1.2.3")
+
+ assert result == 0
+ commands_tester.assert_command_run(
+ "gcloud", "node-pools create", "--tpu-topology=2x2x1", times=2
+ )
+ commands_tester.assert_command_run(
+ "gcloud", "node-pools create", "test-cluster-np-0", "--reservation=res1"
+ )
+ commands_tester.assert_command_run(
+ "gcloud", "node-pools create", "test-cluster-np-1", "--reservation=res2"
+ )
+
+
+ def test_run_gke_node_pool_create_command_partial_reservations(
+ mocker,
+ commands_tester: CommandsTester,
+ ):
+ mocker.patch(
+ "xpk.core.nodepool.get_cluster_location", return_value="us-central1"
+ )
+ mocker.patch("xpk.core.nodepool.get_node_pools_to_delete", return_value=[])
+ mocker.patch("xpk.core.capacity.verify_reservations_exist", return_value=0)
+ args = mocker.Mock(
+ num_slices=3,
+ reservation="res1,res2",
+ tpu_type="v4-8",
+ device_type=None,
+ cluster="test-cluster",
+ project="test-project",
+ zone="us-central1-a",
+ on_demand=False,
+ spot=False,
+ flex=False,
+ enable_workload_identity=False,
+ enable_gcsfuse_csi_driver=False,
+ host_maintenance_interval="AS_NEEDED",
+ custom_nodepool_arguments="",
+ )
+ system = SystemCharacteristics(
+ topology="2x2x1",
+ vms_per_slice=2,
+ gke_accelerator="tpu-v4",
+ gce_machine_type="ct4p-hightpu-4t",
+ chips_per_vm=4,
+ accelerator_type=AcceleratorType.TPU,
+ device_type="v4-8",
+ requires_workload_policy=False,
+ supports_sub_slicing=False,
+ supports_super_slicing=False,
+ supports_accelerator_network_profile=True,
+ docker_platform=DockerPlatform.AMD,
+ )
+ commands_tester.set_result_for_command(
+ (0, "test-cluster-np-0"), "gcloud beta container node-pools list"
+ )
+ commands_tester.set_result_for_command(
+ (0, "us-central1-a"),
+ "gcloud",
+ "node-pools describe",
+ '--format="value(locations)"',
+ )
+
+ result = run_gke_node_pool_create_command(args, system, "1.2.3")
+
+ assert result == 0
+ commands_tester.assert_command_run(
+ "gcloud", "node-pools create", "--tpu-topology=2x2x1", times=2
+ )
+ commands_tester.assert_command_run(
+ "gcloud", "node-pools create", "test-cluster-np-1", "--reservation=res1"
+ )
+ commands_tester.assert_command_run(
+ "gcloud", "node-pools create", "test-cluster-np-2", "--reservation=res2"
+ )
@@ -56,7 +56,7 @@ class FuseStateClient(RemoteStateClient):

  def upload_state(self) -> None:
  xpk_print(
- f'Uploading dependecies from directory {self.state_dir} to bucket:'
+ f'Uploading dependencies from directory {self.state_dir} to bucket:'
  f' {self.bucket}. Path within bucket is: {self._get_bucket_path()}'
  )
  upload_directory_to_gcs(