xpk-1.0.0-py3-none-any.whl → xpk-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. xpk/commands/cluster.py +29 -30
  2. xpk/commands/cluster_gcluster.py +19 -14
  3. xpk/commands/cluster_test.py +1 -21
  4. xpk/commands/common.py +39 -6
  5. xpk/commands/common_test.py +170 -0
  6. xpk/commands/info.py +9 -5
  7. xpk/commands/inspector.py +33 -4
  8. xpk/commands/inspector_test.py +142 -0
  9. xpk/commands/workload.py +22 -8
  10. xpk/commands/workload_test.py +70 -3
  11. xpk/core/blueprint/blueprint_generator.py +19 -8
  12. xpk/core/blueprint/testing/data/a3_ultra.yaml +3 -1
  13. xpk/core/blueprint/testing/data/a4.yaml +3 -1
  14. xpk/core/capacity.py +37 -17
  15. xpk/core/capacity_test.py +66 -1
  16. xpk/core/cluster.py +10 -10
  17. xpk/core/cluster_private.py +3 -3
  18. xpk/core/cluster_test.py +29 -2
  19. xpk/core/docker_container.py +31 -24
  20. xpk/core/docker_manager.py +4 -4
  21. xpk/core/docker_resources.py +4 -1
  22. xpk/core/kueue_manager.py +6 -8
  23. xpk/core/kueue_manager_test.py +4 -5
  24. xpk/core/nap.py +14 -3
  25. xpk/core/nodepool.py +46 -13
  26. xpk/core/nodepool_test.py +143 -8
  27. xpk/core/remote_state/fuse_remote_state.py +1 -1
  28. xpk/core/scheduling.py +4 -1
  29. xpk/core/scheduling_test.py +1 -1
  30. xpk/core/system_characteristics.py +6 -0
  31. xpk/core/telemetry.py +11 -1
  32. xpk/core/telemetry_test.py +39 -0
  33. xpk/core/testing/commands_tester.py +26 -0
  34. xpk/core/testing/commands_tester_test.py +20 -1
  35. xpk/core/workload_decorators/rdma_decorator.py +9 -0
  36. xpk/parser/cluster.py +11 -1
  37. xpk/parser/cluster_test.py +59 -1
  38. xpk/parser/common.py +11 -0
  39. xpk/parser/storage.py +3 -3
  40. xpk/utils/console.py +1 -1
  41. xpk/utils/feature_flags.py +7 -3
  42. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/METADATA +37 -21
  43. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/RECORD +47 -54
  44. xpk-1.1.0.dist-info/top_level.txt +1 -0
  45. integration/README.md +0 -19
  46. integration/__init__.py +0 -15
  47. integration/docker_manager_test.py +0 -102
  48. integration/gcluster_a3mega_test.py +0 -215
  49. integration/gcluster_a3ultra_test.py +0 -187
  50. integration/gcluster_a4_test.py +0 -187
  51. integration/gcluster_test.py +0 -107
  52. xpk/utils/user_input.py +0 -48
  53. xpk/utils/user_input_test.py +0 -92
  54. xpk-1.0.0.dist-info/top_level.txt +0 -2
  55. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/WHEEL +0 -0
  56. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/entry_points.txt +0 -0
  57. {xpk-1.0.0.dist-info → xpk-1.1.0.dist-info}/licenses/LICENSE +0 -0
xpk/commands/inspector_test.py ADDED
@@ -0,0 +1,142 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import pytest
18
+ from unittest import mock
19
+ from xpk.commands import inspector
20
+ from xpk.core.testing.commands_tester import CommandsTester
21
+
22
+
23
+ @pytest.fixture
24
+ def args():
25
+ args = mock.Mock()
26
+ args.print_to_terminal = False
27
+ return args
28
+
29
+
30
+ @pytest.fixture
31
+ def commands_tester(mocker):
32
+ return CommandsTester(
33
+ mocker,
34
+ run_command_for_value_path="xpk.commands.inspector.run_command_for_value",
35
+ )
36
+
37
+
38
+ @pytest.fixture
39
+ def mock_has_super_slicing_enabled(mocker):
40
+ return mocker.patch("xpk.commands.inspector.has_super_slicing_enabled")
41
+
42
+
43
+ @pytest.fixture
44
+ def mock_append_tmp_file(mocker):
45
+ return mocker.patch("xpk.commands.inspector.append_tmp_file")
46
+
47
+
48
+ @pytest.fixture
49
+ def mock_xpk_print(mocker):
50
+ return mocker.patch("xpk.commands.inspector.xpk_print")
51
+
52
+
53
+ def test_inspector_run_slice_controller_helper_no_super_slicing(
54
+ args: mock.Mock,
55
+ commands_tester: CommandsTester,
56
+ mock_has_super_slicing_enabled: mock.Mock,
57
+ mock_append_tmp_file: mock.Mock,
58
+ ):
59
+ mock_has_super_slicing_enabled.return_value = (0, False)
60
+
61
+ inspector.inspector_run_slice_controller_helper(args, "test_file")
62
+ commands_tester.assert_command_not_run(
63
+ "kubectl logs deployment slice-controller-controller-manager"
64
+ )
65
+ commands_tester.assert_command_not_run(
66
+ "kubectl describe deployment slice-controller-controller-manager"
67
+ )
68
+ mock_append_tmp_file.assert_not_called()
69
+
70
+
71
+ def test_inspector_run_slice_controller_helper_with_super_slicing_success(
72
+ args: mock.Mock,
73
+ commands_tester: CommandsTester,
74
+ mock_has_super_slicing_enabled: mock.Mock,
75
+ mock_append_tmp_file: mock.Mock,
76
+ ):
77
+ commands_tester.set_result_for_command(
78
+ (0, "some logs"),
79
+ "kubectl",
80
+ "logs",
81
+ "deployment slice-controller-controller-manager",
82
+ )
83
+ commands_tester.set_result_for_command(
84
+ (0, "some details"),
85
+ "kubectl",
86
+ "describe",
87
+ "deployment slice-controller-controller-manager",
88
+ )
89
+ mock_has_super_slicing_enabled.return_value = (0, True)
90
+
91
+ inspector.inspector_run_slice_controller_helper(args, "test_file")
92
+
93
+ commands_tester.assert_command_run(
94
+ "kubectl logs deployment slice-controller-controller-manager"
95
+ )
96
+ commands_tester.assert_command_run(
97
+ "kubectl describe deployment slice-controller-controller-manager"
98
+ )
99
+
100
+ mock_append_tmp_file.assert_called()
101
+ call_args_list = mock_append_tmp_file.call_args_list
102
+ assert any(
103
+ "Super-slicing topology set up" in args[0] for args, _ in call_args_list
104
+ )
105
+ assert any("some logs" in args[0] for args, _ in call_args_list)
106
+ assert any("some details" in args[0] for args, _ in call_args_list)
107
+
108
+
109
+ def test_inspector_run_slice_controller_helper_with_slice_controller_not_found(
110
+ args: mock.Mock,
111
+ commands_tester: CommandsTester,
112
+ mock_has_super_slicing_enabled: mock.Mock,
113
+ mock_append_tmp_file: mock.Mock,
114
+ mock_xpk_print: mock.Mock,
115
+ ):
116
+ commands_tester.set_result_for_command(
117
+ (1, "Error: Deployment not found"),
118
+ "kubectl",
119
+ "deployment slice-controller-controller-manager",
120
+ )
121
+ mock_has_super_slicing_enabled.return_value = (0, True)
122
+
123
+ inspector.inspector_run_slice_controller_helper(args, "test_file")
124
+
125
+ commands_tester.assert_command_run(
126
+ "kubectl describe deployment slice-controller-controller-manager"
127
+ )
128
+ commands_tester.assert_command_run(
129
+ "kubectl logs deployment slice-controller-controller-manager"
130
+ )
131
+
132
+ mock_append_tmp_file.assert_called()
133
+ call_args_list = mock_append_tmp_file.call_args_list
134
+ assert any(
135
+ "Super-slicing topology set up" in args[0] for args, _ in call_args_list
136
+ )
137
+
138
+ mock_xpk_print.assert_called()
139
+ call_args_list = mock_xpk_print.call_args_list
140
+ assert any(
141
+ "Error: Deployment not found" in args[0] for args, _ in call_args_list
142
+ )
xpk/commands/workload.py CHANGED
@@ -54,6 +54,7 @@ from ..core.resources import get_cluster_capacity_type, get_cluster_system_chara
54
54
  from ..core.resources import ConfigMapType, get_cluster_configmap
55
55
  from ..core.nodepool import ensure_resource_policy_exists
56
56
  from ..core.scheduling import (
57
+ ONE_TO_ONE_REPLICA_NODE_POOL_ASSIGNMENT_ANNOTATION,
57
58
  WorkloadScheduling,
58
59
  check_if_workload_can_schedule,
59
60
  create_tpu_machine_type,
@@ -99,7 +100,7 @@ from ..utils.file import write_tmp_file
99
100
  from ..utils.execution_context import is_dry_run
100
101
  from ..utils.validation import validate_dependencies_list, SystemDependency, should_validate_dependencies
101
102
  from . import cluster_gcluster
102
- from .common import is_TAS_possible
103
+ from .common import is_GPU_TAS_possible
103
104
  from jinja2 import Environment, FileSystemLoader
104
105
  from ..utils.templates import get_templates_absolute_path
105
106
 
@@ -111,7 +112,7 @@ metadata:
111
112
  kueue.x-k8s.io/queue-name: {local_queue_name} # Name of the LocalQueue
112
113
  xpk.google.com/workload: {args.workload}
113
114
  annotations:
114
- alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool # 1:1 job replica to node pool assignment
115
+ {jobset_annotations}
115
116
  spec:
116
117
  ttlSecondsAfterFinished: {args.ttl_seconds_after_finished}
117
118
  failurePolicy:
@@ -490,13 +491,21 @@ def workload_create(args) -> None:
490
491
  - PodFailurePolicy"""
491
492
  restart_on_exit_codes_list = get_restart_exit_codes(args)
492
493
  restart_on_exit_codes = ','.join(map(str, restart_on_exit_codes_list))
493
- pod_failure_policy = f"""
494
+
495
+ pod_failure_policy = """
494
496
  podFailurePolicy:
495
497
  rules:
498
+ """
499
+ docker_image = get_main_container_docker_image(args, workload_system)
500
+ for i in range(workload_system.parallel_containers):
501
+ docker_image_sufix = (
502
+ f'-{i + 1}' if workload_system.parallel_containers > 1 else ''
503
+ )
504
+ pod_failure_policy += f"""
496
505
  - action: FailJob
497
506
  onPodConditions: []
498
507
  onExitCodes:
499
- containerName: {get_main_container_docker_image(args, workload_system)}
508
+ containerName: {docker_image}{docker_image_sufix}
500
509
  operator: NotIn
501
510
  values: [{restart_on_exit_codes}]"""
502
511
 
@@ -534,11 +543,10 @@ def workload_create(args) -> None:
534
543
  capacity_type = get_cluster_capacity_type(args)
535
544
 
536
545
  annotations = (
537
- (
538
- 'kueue.x-k8s.io/podset-preferred-topology:'
539
- ' "cloud.google.com/gce-topology-host"'
546
+ 'kueue.x-k8s.io/podset-preferred-topology: "kubernetes.io/hostname"'
547
+ if is_GPU_TAS_possible(
548
+ cluster_system, capacity_type, args.cluster, args.zone, args.project
540
549
  )
541
- if is_TAS_possible(cluster_system, capacity_type)
542
550
  else ''
543
551
  )
544
552
 
@@ -648,9 +656,15 @@ def workload_create(args) -> None:
648
656
  if use_super_slicing
649
657
  else ''
650
658
  )
659
+ jobset_annotations = (
660
+ ''
661
+ if use_super_slicing or use_sub_slicing
662
+ else ONE_TO_ONE_REPLICA_NODE_POOL_ASSIGNMENT_ANNOTATION
663
+ )
651
664
 
652
665
  yml_string = WORKLOAD_CREATE_YAML.format(
653
666
  args=args,
667
+ jobset_annotations=jobset_annotations,
654
668
  container=container,
655
669
  vms_per_slice=workload_system.vms_per_slice,
656
670
  affinity=get_cpu_affinity(workload_system.accelerator_type),
xpk/commands/workload_test.py CHANGED
@@ -23,6 +23,7 @@ from ..core.scheduling import WorkloadScheduling
23
23
  from ..core.system_characteristics import DockerPlatform, SystemCharacteristics, AcceleratorType, UserFacingNameToSystemCharacteristics, GpuConfig
24
24
  from .workload import workload_create
25
25
  from .cluster_test import construct_args
26
+ from ..core.docker_container import get_user_workload_container as real_get_user_workload_container
26
27
 
27
28
 
28
29
  SYSTEM_CHARACTERISTICS = SystemCharacteristics(
@@ -58,7 +59,7 @@ class _WorkloadCreateMocks:
58
59
  validate_dependencies_list: MagicMock
59
60
  write_tmp_file: MagicMock
60
61
  get_cluster_capacity_type: MagicMock
61
- is_TAS_possible: MagicMock
62
+ is_GPU_TAS_possible: MagicMock
62
63
  get_cluster_location: MagicMock
63
64
  xpk_exit: MagicMock
64
65
  run_command_with_updates: MagicMock
@@ -113,8 +114,8 @@ def workload_create_mocks(mocker) -> _WorkloadCreateMocks:
113
114
  'xpk.commands.workload.get_cluster_capacity_type',
114
115
  return_value='on-demand',
115
116
  ),
116
- is_TAS_possible=mocker.patch(
117
- 'xpk.commands.workload.is_TAS_possible', return_value=False
117
+ is_GPU_TAS_possible=mocker.patch(
118
+ 'xpk.commands.workload.is_GPU_TAS_possible', return_value=False
118
119
  ),
119
120
  get_cluster_location=mocker.patch(
120
121
  'xpk.commands.workload.get_cluster_location',
@@ -206,3 +207,69 @@ def test_workload_create_dry_run_with_output_file(mocker):
206
207
  written_content = mock_open.return_value.write.call_args[0][0]
207
208
  assert 'test-workload' in written_content
208
209
  assert 'cloud.google.com/gke-tpu-topology: 8x8' in written_content
210
+
211
+
212
+ def test_workload_create_multi_container_for_tpu7x(
213
+ workload_create_mocks: _WorkloadCreateMocks,
214
+ mocker,
215
+ ):
216
+ """Tests that the generated YAML for a multi-container workload has correct pod failure policy and container structure."""
217
+
218
+ # Enable dry_run to prevent external calls like get_storages_to_mount -> gcloud
219
+ mocker.patch('xpk.utils.execution_context.dry_run', True)
220
+
221
+ # Mock dependencies required by get_user_workload_container -> get_main_container
222
+ mocker.patch(
223
+ 'xpk.core.docker_container.setup_docker_image',
224
+ return_value=(0, 'dummy-image'),
225
+ )
226
+ mocker.patch(
227
+ 'xpk.core.docker_container.get_gke_debugging_dashboard', return_value=None
228
+ )
229
+
230
+ # Use the real get_user_workload_container to test integration
231
+ workload_create_mocks.get_user_workload_container.side_effect = (
232
+ real_get_user_workload_container
233
+ )
234
+
235
+ args = construct_args(
236
+ workload='test-workload',
237
+ command='echo hello',
238
+ num_nodes=1,
239
+ tpu_type='tpu7x-2x2x2',
240
+ restart_on_exit_codes=None,
241
+ docker_name='test-docker',
242
+ deploy_stacktrace_sidecar=False,
243
+ enable_debug_logs=False,
244
+ scheduler='default-scheduler',
245
+ )
246
+ workload_create(args)
247
+
248
+ assert workload_create_mocks.write_tmp_file.called
249
+ yaml_content = workload_create_mocks.write_tmp_file.call_args[0][0]
250
+ jobset = yaml.safe_load(yaml_content)
251
+
252
+ # Verify Pod Failure Policy
253
+ pod_failure_rules = jobset['spec']['replicatedJobs'][0]['template']['spec'][
254
+ 'podFailurePolicy'
255
+ ]['rules']
256
+ # Should have 2 rules for multi_container
257
+ assert len(pod_failure_rules) == 2
258
+ assert pod_failure_rules[0]['onExitCodes']['containerName'].endswith('-1')
259
+ assert pod_failure_rules[1]['onExitCodes']['containerName'].endswith('-2')
260
+
261
+ # Verify Containers
262
+ # Navigate to the containers list in the YAML
263
+ containers = jobset['spec']['replicatedJobs'][0]['template']['spec'][
264
+ 'template'
265
+ ]['spec']['containers']
266
+
267
+ assert len(containers) == 2
268
+ assert containers[0]['name'].endswith('-1')
269
+ assert containers[1]['name'].endswith('-2')
270
+ assert containers[0]['image'] == 'dummy-image'
271
+ assert containers[1]['image'] == 'dummy-image'
272
+
273
+ # Check if resources are split correctly (4 chips / 2 containers = 2 chips)
274
+ assert containers[0]['resources']['limits']['google.com/tpu'] == 2
275
+ assert containers[1]['resources']['limits']['google.com/tpu'] == 2
xpk/core/blueprint/blueprint_generator.py CHANGED
@@ -24,6 +24,7 @@ from packaging.version import parse
24
24
  from ...utils.console import xpk_exit, xpk_print
25
25
  from ...utils.versions import ReleaseChannel
26
26
  from ...utils.file import ensure_directory_exists
27
+ from ...utils.templates import get_templates_absolute_path
27
28
 
28
29
 
29
30
  from ..capacity import (
@@ -51,9 +52,9 @@ supported_device_types = {
51
52
  a4_device_type,
52
53
  }
53
54
  blueprint_dependencies_dir = {
54
- a3mega_device_type: "src/xpk/blueprints/a3mega",
55
- a3ultra_device_type: "src/xpk/blueprints/a3ultra",
56
- a4_device_type: "src/xpk/blueprints/a4",
55
+ a3mega_device_type: get_templates_absolute_path("blueprints/a3mega"),
56
+ a3ultra_device_type: get_templates_absolute_path("blueprints/a3ultra"),
57
+ a4_device_type: get_templates_absolute_path("blueprints/a4"),
57
58
  }
58
59
 
59
60
  cluster_toolkit_url = "github.com/GoogleCloudPlatform/cluster-toolkit"
@@ -63,7 +64,7 @@ common_cluster_labels = {"gke_product_type": "xpk"}
63
64
 
64
65
  class BlueprintGeneratorOutput:
65
66
  """BlueprintGeneratorOutput is a class containing fields with output blueprint file path and path to blueprint dependencies.
66
- Atributes:
67
+ Attributes:
67
68
  - blueprint_file (str) : path to generated blueprint file.
68
69
  - blueprint_dependencies (str) : path to directory containing blueprint dependencies.
69
70
  """
@@ -75,7 +76,7 @@ class BlueprintGeneratorOutput:
75
76
 
76
77
  class BlueprintGenerator:
77
78
  """BlueprintGenerator is a class for generating blueprints
78
- Atributes:
79
+ Attributes:
79
80
  - storage_path (str) - path to directory where generated files and directories will be stored.
80
81
  """
81
82
 
@@ -239,10 +240,18 @@ class BlueprintGenerator:
239
240
  else:
240
241
  a3_megagpu_pool_0.update_settings({"static_node_count": num_nodes})
241
242
 
243
+ if capacity_type not in (CapacityType.SPOT, CapacityType.FLEX_START):
244
+ a3_megagpu_pool_0.update_settings(
245
+ {"placement_policy": {"type": "COMPACT"}}
246
+ )
247
+
242
248
  if release_channel == ReleaseChannel.RAPID:
243
249
  a3_megagpu_pool_0.set_setting("auto_upgrade", True)
244
250
 
245
- set_placement_policy = capacity_type != CapacityType.SPOT
251
+ set_placement_policy = capacity_type not in (
252
+ CapacityType.SPOT,
253
+ CapacityType.FLEX_START,
254
+ )
246
255
  workload = DeploymentModule(
247
256
  id="workload_component_install",
248
257
  source="modules/management/kubectl-apply",
@@ -521,7 +530,7 @@ class BlueprintGenerator:
521
530
  settings={
522
531
  "release_channel": release_channel.value,
523
532
  "version_prefix": version_prefix,
524
- "min_cluster_version": cluster_version,
533
+ "min_master_version": cluster_version,
525
534
  "prefix_with_deployment_name": False,
526
535
  "name_suffix": cluster_name,
527
536
  "system_node_pool_machine_type": system_node_pool_machine_type,
@@ -614,6 +623,7 @@ class BlueprintGenerator:
614
623
  gpu_pool.update_settings(self.get_dws_flex_start())
615
624
  else:
616
625
  gpu_pool.update_settings({"static_node_count": num_nodes})
626
+ gpu_pool.update_settings({"placement_policy": {"type": "COMPACT"}})
617
627
 
618
628
  if release_channel == ReleaseChannel.RAPID:
619
629
  gpu_pool.set_setting("auto_upgrade", True)
@@ -809,7 +819,7 @@ class BlueprintGenerator:
809
819
  settings={
810
820
  "release_channel": release_channel.value,
811
821
  "version_prefix": version_prefix,
812
- "min_cluster_version": cluster_version,
822
+ "min_master_version": cluster_version,
813
823
  "system_node_pool_machine_type": system_node_pool_machine_type,
814
824
  "system_node_pool_node_count": {
815
825
  "total_min_nodes": system_node_pool_min_node_count,
@@ -896,6 +906,7 @@ class BlueprintGenerator:
896
906
  gpu_pool.update_settings(self.get_dws_flex_start())
897
907
  else:
898
908
  gpu_pool.update_settings({"static_node_count": num_nodes})
909
+ gpu_pool.update_settings({"placement_policy": {"type": "COMPACT"}})
899
910
 
900
911
  if release_channel == ReleaseChannel.RAPID:
901
912
  gpu_pool.set_setting("auto_upgrade", True)
xpk/core/blueprint/testing/data/a3_ultra.yaml CHANGED
@@ -97,7 +97,7 @@ deployment_groups:
97
97
  settings:
98
98
  release_channel: RAPID
99
99
  version_prefix: '1.2'
100
- min_cluster_version: 1.2.3
100
+ min_master_version: 1.2.3
101
101
  prefix_with_deployment_name: false
102
102
  name_suffix: gke-a3-ultra
103
103
  system_node_pool_machine_type: "e2-standard-16"
@@ -142,6 +142,8 @@ deployment_groups:
142
142
  specific_reservations:
143
143
  - name: test-reservation
144
144
  static_node_count: 2
145
+ placement_policy:
146
+ type: COMPACT
145
147
  outputs: [instructions]
146
148
 
147
149
  - !DeploymentModule
xpk/core/blueprint/testing/data/a4.yaml CHANGED
@@ -121,7 +121,7 @@ deployment_groups:
121
121
  network_tier=null}], ipv6_access_config=[], alias_ip_range=[]}], gke-a4-rdma-net.subnetwork_interfaces_gke))
122
122
  release_channel: RAPID
123
123
  version_prefix: '1.2'
124
- min_cluster_version: 1.2.3
124
+ min_master_version: 1.2.3
125
125
  use:
126
126
  - gke-a4-net-0
127
127
  - !DeploymentModule
@@ -154,6 +154,8 @@ deployment_groups:
154
154
  network_ip=null, stack_type=null, access_config=[{nat_ip=null, public_ptr_domain_name=null,
155
155
  network_tier=null}], ipv6_access_config=[], alias_ip_range=[]}], gke-a4-rdma-net.subnetwork_interfaces_gke))
156
156
  static_node_count: 2
157
+ placement_policy:
158
+ type: COMPACT
157
159
 
158
160
  - !DeploymentModule
159
161
  id: workload-manager-install
xpk/core/capacity.py CHANGED
@@ -90,7 +90,7 @@ def get_capacity_type(args) -> tuple[CapacityType, int]:
90
90
  capacity_type = CapacityType.ON_DEMAND
91
91
  num_types += 1
92
92
  if args.reservation:
93
- return_code = verify_reservation_exists(args)
93
+ return_code = verify_reservations_exist(args)
94
94
  if return_code > 0:
95
95
  return capacity_type, return_code
96
96
  capacity_type = CapacityType.RESERVATION
@@ -184,8 +184,22 @@ def get_reservation_deployment_type(
184
184
  return output.strip()
185
185
 
186
186
 
187
- def verify_reservation_exists(args) -> int:
188
- """Verify the reservation exists.
187
+ def get_reservations_list(args) -> list[str]:
188
+ """Get the list of reservations from args.
189
+
190
+ Args:
191
+ args: user provided arguments.
192
+
193
+ Returns:
194
+ List of strings of reservations.
195
+ """
196
+ if not args.reservation:
197
+ return []
198
+ return [r.strip() for r in args.reservation.split(',')]
199
+
200
+
201
+ def verify_reservations_exist(args) -> int:
202
+ """Verify the reservations exist.
189
203
 
190
204
  Args:
191
205
  args: user provided arguments for running the command.
@@ -193,16 +207,20 @@ def verify_reservation_exists(args) -> int:
193
207
  Returns:
194
208
  0 if successful and 1 otherwise.
195
209
  """
196
- reservation = parse_reservation(args.reservation, args.project)
197
- command = (
198
- f'gcloud beta compute reservations describe {reservation.name}'
199
- f' --project={reservation.project} --zone={args.zone}'
200
- )
201
- return_code = run_command_with_updates(command, 'Describe reservation')
202
- if return_code != 0:
203
- xpk_print(f'Describe reservation returned ERROR {return_code}')
204
- xpk_print('Please confirm that your reservation name is correct.')
205
- return 1
210
+ for reservation_name in get_reservations_list(args):
211
+ reservation = parse_reservation(reservation_name, args.project)
212
+ command = (
213
+ f'gcloud beta compute reservations describe {reservation.name}'
214
+ f' --project={reservation.project} --zone={args.zone}'
215
+ )
216
+ return_code = run_command_with_updates(command, 'Describe reservation')
217
+ if return_code != 0:
218
+ xpk_print(f'Describe reservation returned ERROR {return_code}')
219
+ xpk_print(
220
+ f'Please confirm that your reservation name {reservation_name} is'
221
+ ' correct.'
222
+ )
223
+ return 1
206
224
  return 0
207
225
 
208
226
 
@@ -211,6 +229,7 @@ def get_capacity_arguments_from_capacity_type(
211
229
  capacity_type: CapacityType,
212
230
  max_nodes: int,
213
231
  accelerator_type: AcceleratorType,
232
+ reservation_name: str | None,
214
233
  ) -> tuple[str, int]:
215
234
  """Determine the Nodepool creation capacity arguments needed.
216
235
 
@@ -240,7 +259,7 @@ def get_capacity_arguments_from_capacity_type(
240
259
  capacity_args += ' --enable-queued-provisioning'
241
260
  case CapacityType.RESERVATION:
242
261
  capacity_args = (
243
- f'--reservation-affinity=specific --reservation={args.reservation}'
262
+ f'--reservation-affinity=specific --reservation={reservation_name}'
244
263
  )
245
264
  case _:
246
265
  xpk_print(
@@ -252,13 +271,14 @@ def get_capacity_arguments_from_capacity_type(
252
271
 
253
272
 
254
273
  def get_capacity_node_selectors_from_capacity_type(
255
- args, capacity_type: str
274
+ capacity_type: str, reservation_name: str | None
256
275
  ) -> tuple[str, int]:
257
276
  """Determine the node selectors for a workload to run on a specific capacity type.
258
277
 
259
278
  Args:
260
- args: user provided arguments for running the command.
261
279
  capacity_type: The type of capacity the user configured.
280
+ reservation_name: The name of the reservation to use. Set to None if not
281
+ using reservations.
262
282
 
263
283
  Returns:
264
284
  Tuple with string with the node selectors to use and
@@ -275,7 +295,7 @@ def get_capacity_node_selectors_from_capacity_type(
275
295
  case CapacityType.SPOT.name:
276
296
  node_selector = 'cloud.google.com/gke-spot: "true"'
277
297
  case CapacityType.RESERVATION.name:
278
- node_selector = f'cloud.google.com/reservation-name: {args.reservation}'
298
+ node_selector = f'cloud.google.com/reservation-name: {reservation_name}'
279
299
  case _:
280
300
  xpk_print(
281
301
  f'Unknown capacity type: {capacity_type}. Unable to determine the'
xpk/core/capacity_test.py CHANGED
@@ -16,7 +16,15 @@ limitations under the License.
16
16
 
17
17
  import pytest
18
18
  from unittest.mock import MagicMock, patch
19
- from .capacity import get_reservation_deployment_type, parse_reservation, Reservation
19
+ from .capacity import (
20
+ get_reservation_deployment_type,
21
+ parse_reservation,
22
+ Reservation,
23
+ get_capacity_type,
24
+ CapacityType,
25
+ verify_reservations_exist,
26
+ get_reservations_list,
27
+ )
20
28
 
21
29
 
22
30
  @patch('xpk.core.capacity.xpk_print')
@@ -133,3 +141,60 @@ def test_parse_reservation_fails_on_invalid_reservations(
133
141
  parse_reservation(reservation_path, 'cluster-project')
134
142
 
135
143
  assert 'Unable to parse reservation' in xpk_print.mock_calls[0].args[0]
144
+
145
+
146
+ def test_get_capacity_type_multiple_reservations(mocker):
147
+ args = MagicMock()
148
+ args.on_demand = False
149
+ args.spot = False
150
+ args.flex = False
151
+ args.reservation = 'res1,res2'
152
+ args.project = 'test-project'
153
+ args.zone = 'us-central1-a'
154
+ mocker.patch('xpk.core.capacity.run_command_with_updates', return_value=0)
155
+
156
+ capacity_type, return_code = get_capacity_type(args)
157
+
158
+ assert capacity_type == CapacityType.RESERVATION
159
+ assert return_code == 0
160
+
161
+
162
+ def test_verify_reservations_exist_multiple(mocker):
163
+ args = MagicMock()
164
+ args.reservation = 'res1,res2'
165
+ args.project = 'test-project'
166
+ args.zone = 'us-central1-a'
167
+
168
+ mock_run = mocker.patch(
169
+ 'xpk.core.capacity.run_command_with_updates', return_value=0
170
+ )
171
+
172
+ return_code = verify_reservations_exist(args)
173
+
174
+ assert return_code == 0
175
+ assert mock_run.call_count == 2
176
+
177
+
178
+ def test_get_reservations_list_with_single_reservation(mocker):
179
+ args = mocker.Mock(reservation='res1')
180
+ assert get_reservations_list(args) == ['res1']
181
+
182
+
183
+ def test_get_reservations_list_with_multiple_reservations(mocker):
184
+ args = mocker.Mock(reservation='res1,res2')
185
+ assert get_reservations_list(args) == ['res1', 'res2']
186
+
187
+
188
+ def test_get_reservations_list_with_whitespace(mocker):
189
+ args = mocker.Mock(reservation='res1, res2 ')
190
+ assert get_reservations_list(args) == ['res1', 'res2']
191
+
192
+
193
+ def test_get_reservations_list_none(mocker):
194
+ args = mocker.Mock(reservation=None)
195
+ assert get_reservations_list(args) == []
196
+
197
+
198
+ def test_get_reservations_list_empty(mocker):
199
+ args = mocker.Mock(reservation='')
200
+ assert get_reservations_list(args) == []
xpk/core/cluster.py CHANGED
@@ -158,7 +158,7 @@ def install_nri_on_cluster() -> int:
158
158
 
159
159
 
160
160
  def get_cluster_nodes_info() -> list[dict]:
161
- """Get list of cluster's nodes descrition in yaml format
161
+ """Get list of cluster's nodes description in yaml format
162
162
 
163
163
  Returns:
164
164
  List of nodes info yaml objects.
@@ -393,11 +393,13 @@ def project_id_to_project_number(project_id: str) -> str:
393
393
  def setup_k8s_env(args) -> k8s_client.ApiClient:
394
394
  add_zone_and_project(args)
395
395
  get_cluster_credentials(args)
396
- args.project_number = (
397
- project_id_to_project_number(args.project)
398
- if not args.dry_run
399
- else abs(hash(args.project) % (10**12)) # 12 digit hash
400
- )
396
+ # Use provided project number if available, otherwise fetch via API
397
+ if getattr(args, 'project_number', None):
398
+ xpk_print(f'Using provided project number: {args.project_number}')
399
+ elif args.dry_run:
400
+ args.project_number = abs(hash(args.project) % (10**12)) # 12 digit hash
401
+ else:
402
+ args.project_number = project_id_to_project_number(args.project)
401
403
 
402
404
  config.load_kube_config()
403
405
  return k8s_client.ApiClient()
@@ -716,10 +718,8 @@ def get_cluster_credentials(args) -> int:
716
718
  location=location,
717
719
  dns_endpoint=True,
718
720
  )
719
- if return_code != 0:
720
- return return_code
721
721
 
722
- if not _are_credentials_valid():
722
+ if return_code != 0 or not _are_credentials_valid():
723
723
  xpk_print('Detected error. Retrying without --dns-endpoint flag...')
724
724
  return_code = _get_credentials(
725
725
  project=args.project,
@@ -751,6 +751,6 @@ def _get_credentials(
751
751
  def _are_credentials_valid() -> bool:
752
752
  kubectl_command = 'kubectl get pods'
753
753
  kubectl_return_code = run_command_with_updates(
754
- kubectl_command, 'Test kubectl credentials'
754
+ kubectl_command, 'Test kubectl credentials', verbose=False
755
755
  )
756
756
  return kubectl_return_code == 0