xpk 0.17.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. xpk/commands/cluster.py +33 -43
  2. xpk/commands/cluster_gcluster.py +19 -14
  3. xpk/commands/cluster_gcluster_test.py +2 -0
  4. xpk/commands/cluster_test.py +1 -21
  5. xpk/commands/common.py +39 -6
  6. xpk/commands/common_test.py +170 -0
  7. xpk/commands/info.py +9 -5
  8. xpk/commands/inspector.py +33 -4
  9. xpk/commands/inspector_test.py +142 -0
  10. xpk/commands/workload.py +32 -11
  11. xpk/commands/workload_test.py +71 -3
  12. xpk/core/blueprint/blueprint_generator.py +19 -8
  13. xpk/core/blueprint/testing/data/a3_ultra.yaml +3 -1
  14. xpk/core/blueprint/testing/data/a4.yaml +3 -1
  15. xpk/core/capacity.py +37 -17
  16. xpk/core/capacity_test.py +66 -1
  17. xpk/core/cluster.py +11 -10
  18. xpk/core/cluster_private.py +3 -3
  19. xpk/core/cluster_test.py +29 -2
  20. xpk/core/config.py +5 -2
  21. xpk/core/docker_container.py +31 -24
  22. xpk/core/docker_manager.py +4 -4
  23. xpk/core/docker_resources.py +4 -1
  24. xpk/core/kueue_manager.py +6 -8
  25. xpk/core/kueue_manager_test.py +6 -5
  26. xpk/core/nap.py +14 -3
  27. xpk/core/nodepool.py +52 -13
  28. xpk/core/nodepool_test.py +147 -8
  29. xpk/core/remote_state/fuse_remote_state.py +1 -1
  30. xpk/core/scheduling.py +32 -4
  31. xpk/core/scheduling_test.py +39 -2
  32. xpk/core/system_characteristics.py +44 -0
  33. xpk/core/system_characteristics_test.py +11 -0
  34. xpk/core/telemetry.py +11 -1
  35. xpk/core/telemetry_test.py +39 -0
  36. xpk/core/testing/commands_tester.py +26 -0
  37. xpk/core/testing/commands_tester_test.py +20 -1
  38. xpk/core/workload_decorators/rdma_decorator.py +9 -0
  39. xpk/parser/cluster.py +11 -1
  40. xpk/parser/cluster_test.py +59 -1
  41. xpk/parser/common.py +11 -17
  42. xpk/parser/core.py +0 -8
  43. xpk/parser/storage.py +3 -14
  44. xpk/utils/console.py +1 -1
  45. xpk/utils/feature_flags.py +8 -4
  46. {xpk-0.17.3.dist-info → xpk-1.1.0.dist-info}/METADATA +50 -23
  47. {xpk-0.17.3.dist-info → xpk-1.1.0.dist-info}/RECORD +51 -60
  48. xpk-1.1.0.dist-info/top_level.txt +1 -0
  49. integration/README.md +0 -19
  50. integration/__init__.py +0 -15
  51. integration/docker_manager_test.py +0 -102
  52. integration/gcluster_a3mega_test.py +0 -215
  53. integration/gcluster_a3ultra_test.py +0 -187
  54. integration/gcluster_a4_test.py +0 -187
  55. integration/gcluster_test.py +0 -107
  56. xpk/commands/kind.py +0 -265
  57. xpk/parser/kind.py +0 -95
  58. xpk/utils/user_input.py +0 -48
  59. xpk/utils/user_input_test.py +0 -92
  60. xpk-0.17.3.dist-info/top_level.txt +0 -2
  61. {xpk-0.17.3.dist-info → xpk-1.1.0.dist-info}/WHEEL +0 -0
  62. {xpk-0.17.3.dist-info → xpk-1.1.0.dist-info}/entry_points.txt +0 -0
  63. {xpk-0.17.3.dist-info → xpk-1.1.0.dist-info}/licenses/LICENSE +0 -0
xpk/core/cluster.py CHANGED
@@ -158,7 +158,7 @@ def install_nri_on_cluster() -> int:
 
 
 def get_cluster_nodes_info() -> list[dict]:
-  """Get list of cluster's nodes descrition in yaml format
+  """Get list of cluster's nodes description in yaml format
 
   Returns:
     List of nodes info yaml objects.
@@ -391,14 +391,15 @@ def project_id_to_project_number(project_id: str) -> str:
 
 
 def setup_k8s_env(args) -> k8s_client.ApiClient:
-  if not getattr(args, 'kind_cluster', False):
-    add_zone_and_project(args)
-    get_cluster_credentials(args)
-    args.project_number = (
-        project_id_to_project_number(args.project)
-        if not args.dry_run
-        else abs(hash(args.project) % (10**12))  # 12 digit hash
-    )
+  add_zone_and_project(args)
+  get_cluster_credentials(args)
+  # Use provided project number if available, otherwise fetch via API
+  if getattr(args, 'project_number', None):
+    xpk_print(f'Using provided project number: {args.project_number}')
+  elif args.dry_run:
+    args.project_number = abs(hash(args.project) % (10**12))  # 12 digit hash
+  else:
+    args.project_number = project_id_to_project_number(args.project)
 
   config.load_kube_config()
   return k8s_client.ApiClient()
@@ -750,6 +751,6 @@ def _get_credentials(
 def _are_credentials_valid() -> bool:
   kubectl_command = 'kubectl get pods'
   kubectl_return_code = run_command_with_updates(
-      kubectl_command, 'Test kubectl credentials'
+      kubectl_command, 'Test kubectl credentials', verbose=False
  )
  return kubectl_return_code == 0
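Note on the setup_k8s_env change above: project-number resolution now has three tiers instead of two. A minimal sketch of the same precedence, using a hypothetical standalone helper and a stand-in args namespace rather than xpk's actual API:

from types import SimpleNamespace

def project_id_to_project_number(project_id: str) -> int:
  # Stub standing in for the real Cloud Resource Manager lookup.
  raise NotImplementedError

def resolve_project_number(args) -> int:
  # 1. An explicitly provided project number always wins.
  if getattr(args, 'project_number', None):
    return args.project_number
  # 2. Dry runs use a 12-digit stand-in derived from the project id,
  #    so no API traffic is generated.
  if args.dry_run:
    return abs(hash(args.project) % (10**12))
  # 3. Otherwise fetch the real number via the API.
  return project_id_to_project_number(args.project)

args = SimpleNamespace(project='my-project', project_number=None, dry_run=True)
print(resolve_project_number(args))  # hash-based stand-in, no gcloud call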
xpk/core/cluster_private.py CHANGED
@@ -61,7 +61,7 @@ def authorize_private_cluster_access_if_necessary(args) -> int:
   if new_authorized_networks_needed or not is_current_machine_in_network:
     return update_cluster_new_authorized_networks(args, authorized_networks)
 
-  xpk_print("Current machine's IP adrress is already authorized.")
+  xpk_print("Current machine's IP address is already authorized.")
   return 0
 
 
@@ -84,7 +84,7 @@ def add_current_machine_to_networks_if_needed(
       is_current_machine_in_any_network(authorized_networks)
   )
   if is_current_machine_in_network_return_code != 0:
-    xpk_print("Error on checking current machine's IP adrress.")
+    xpk_print("Error on checking current machine's IP address.")
     return is_current_machine_in_network_return_code, False, authorized_networks
 
   if not is_current_machine_in_network:
@@ -148,7 +148,7 @@ def is_cluster_private(args) -> bool:
 
 
 def get_cluster_authorized_networks(args) -> list[str]:
-  """Retreives the networks list that are authorized to have access to Control Plane.
+  """Retrieves the networks list that are authorized to have access to Control Plane.
   Args:
     args: user provided arguments for running the command.
 
xpk/core/cluster_test.py CHANGED
@@ -41,11 +41,15 @@ def command_args(mocker: MockerFixture):
   return mocker.Mock(cluster="cluster", project="project", zone="zone")
 
 
-def test_get_cluster_credentials_returns_1_when_retrieval_command_fails(
+def test_get_cluster_credentials_returns_1_when_retrieval_commands_fail(
    commands_tester: CommandsTester, command_args
 ):
   commands_tester.set_result_for_command(
-      (1, ""), "gcloud container clusters get-credentials"
+      (1, ""), "gcloud container clusters get-credentials", " --dns-endpoint"
+  )
+  commands_tester.set_result_for_command(
+      (1, ""),
+      "gcloud container clusters get-credentials",
   )
   assert get_cluster_credentials(command_args) == 1
 
@@ -95,6 +99,29 @@ def test_get_cluster_credentials_retries_without_dns_when_dns_retrieval_fails(
   assert len(non_dns_endpoint_commands) == 1
 
 
+def test_get_cluster_credentials_retries_without_dns_when_dns_retrieval_returns_error(
+    commands_tester: CommandsTester, command_args
+):
+  commands_tester.set_result_for_command(
+      (1, ""), "gcloud container clusters get-credentials", "--dns-endpoint"
+  )
+  commands_tester.set_result_for_command(
+      (0, ""),
+      "gcloud container clusters get-credentials",
+  )
+
+  assert get_cluster_credentials(command_args) == 0
+
+  non_dns_endpoint_commands = [
+      c
+      for c in commands_tester.get_matching_commands(
+          "gcloud container clusters get-credentials"
+      )
+      if "dns-endpoint" not in c
+  ]
+  assert len(non_dns_endpoint_commands) == 1
+
+
 def test_update_cluster_with_lustre_driver_if_necessary_with_default_port_runs_correct_checks(
    commands_tester: CommandsTester, command_args
 ):
xpk/core/config.py CHANGED
@@ -19,6 +19,7 @@ import os
 import ruamel.yaml
 from abc import ABC, abstractmethod
 from ..utils import file
+from ..utils.execution_context import is_dry_run
 from ..utils.console import xpk_print
 from setuptools_scm import get_version as setuptools_get_version
 from importlib.metadata import version, PackageNotFoundError
@@ -96,8 +97,7 @@ class FileSystemConfig(Config):
     self._allowed_keys = DEFAULT_KEYS
 
   def _open_configs(self) -> dict | None:
-    dir_path = '/'.join(self._config.split('/')[:-1])
-    file.ensure_directory_exists(dir_path)
+    file.ensure_directory_exists(os.path.dirname(self._config))
 
     if not os.path.exists(self._config):
       return None
@@ -107,6 +107,9 @@ class FileSystemConfig(Config):
     return config_yaml
 
   def _save_configs(self, config_yaml: dict) -> None:
+    if is_dry_run():
+      return None
+
     with open(self._config, encoding='utf-8', mode='w') as stream:
       yaml.dump(config_yaml, stream)
 
xpk/core/docker_container.py CHANGED
@@ -17,9 +17,7 @@ limitations under the License.
 from ..utils.console import xpk_exit, xpk_print
 from .docker_image import setup_docker_image
 from .docker_resources import (
-    add_container_ports,
     add_image_pull_policy_for_pw_or_gpu,
-    add_jax_coordinator_port,
     get_env_container,
     get_main_container_resources,
     get_volume_mounts,
@@ -112,13 +110,12 @@ def get_main_container(args, system, docker_image, resource_type) -> str:
         'touch /shared-volume/stacktrace_signal; '
     )
 
-  yaml = """- name: {docker_name}
+  containers = []
+  container_yaml = """
+              - name: {docker_name}
                 image: {docker_image}
                 {image_pull_policy}
                 env: {env}
-                ports:
-                {container_ports}
-                {jax_coordinator_port}
                 securityContext:
                   privileged: true
                 command:
@@ -145,29 +142,39 @@ def get_main_container(args, system, docker_image, resource_type) -> str:
                   limits:
                     {resources}
 """
+  docker_name = get_main_container_docker_image(args, system)
   volume_mounts = get_volume_mounts(args, system)
   if volume_mounts != '':
-    yaml += """
+    container_yaml += """
                 volumeMounts:
                 {volume_mounts}
 """
-  return yaml.format(
-      args=args,
-      system=system,
-      image_pull_policy=add_image_pull_policy_for_pw_or_gpu(args, system),
-      env=get_env_container(args, system),
-      container_ports=add_container_ports(args, system),
-      jax_coordinator_port=add_jax_coordinator_port(system),
-      docker_name=get_main_container_docker_image(args, system),
-      docker_image=docker_image,
-      gsutil_test_command=gsutil_test_command,
-      command=command,
-      tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
-      gpu_workload_terminate_command=gpu_workload_terminate_command,
-      xpk_internal_commands=xpk_internal_commands,
-      resources=get_main_container_resources(args, system, resource_type),
-      volume_mounts=volume_mounts,
-  )
+  # pathways job running on 2 parallel containers is not verified yet
+  if args.use_pathways:
+    system.parallel_containers = 1
+
+  env = get_env_container(args, system)
+  image_pull_policy = add_image_pull_policy_for_pw_or_gpu(args, system)
+  for i in range(system.parallel_containers):
+    docker_name_sufix = f'-{i + 1}' if system.parallel_containers > 1 else ''
+    containers.append(
+        container_yaml.format(
+            args=args,
+            system=system,
+            image_pull_policy=image_pull_policy,
+            env=env,
+            docker_name=f'{docker_name}{docker_name_sufix}',
+            docker_image=docker_image,
+            gsutil_test_command=gsutil_test_command,
+            command=command,
+            tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
+            gpu_workload_terminate_command=gpu_workload_terminate_command,
+            xpk_internal_commands=xpk_internal_commands,
+            resources=get_main_container_resources(args, system, resource_type),
+            volume_mounts=volume_mounts,
+        )
+    )
+  return ''.join(containers)
 
 
 def get_user_workload_container(args, system: SystemCharacteristics):
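The loop above now emits one container block per parallel container, suffixing names only when there is more than one. A tiny sketch of the naming rule (hypothetical helper mirroring the loop, not xpk's API):

def container_names(base: str, parallel_containers: int) -> list[str]:
  # A single container keeps the bare name; multiple get 1-based suffixes.
  if parallel_containers == 1:
    return [base]
  return [f'{base}-{i + 1}' for i in range(parallel_containers)]

assert container_names('jax-tpu', 1) == ['jax-tpu']
assert container_names('jax-tpu', 2) == ['jax-tpu-1', 'jax-tpu-2']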
xpk/core/docker_manager.py CHANGED
@@ -44,7 +44,7 @@ class CommandRunner(ABC):
 
   @abstractmethod
   def initialize(self) -> None:
-    """initialize is a method that should implement all steps neccessary to run command.
+    """initialize is a method that should implement all steps necessary to run command.
 
     Returns:
       None
@@ -95,7 +95,7 @@ class DockerManager(CommandRunner):
  - gcloud_cfg_path (str) : path to directory containing gcloud configuration
  - working_dir (str) : path to directory in which gcluster deployment directory will be saved
  - client (DockerClient) : docker client
-  - nocache (bool) : wheter to use docker cache when building image
+  - nocache (bool) : whether to use docker cache when building image
  - img_name (str) : name of docker image to create
  - container_name (str) : name of the container that will be created from img_name
  - rm_container_after (bool) : if set to True, docker container in which command is executed will be removed after each execution.
@@ -294,12 +294,12 @@ class DockerManager(CommandRunner):
       xpk_print(f"error while building image {self.img_name}: {e.msg}")
       xpk_exit(dockerBuildErrorCode)
     except APIError as e:
-      xpk_print(f"erro while building image {self.img_name}: {e.explanation}")
+      xpk_print(f"error while building image {self.img_name}: {e.explanation}")
       xpk_exit(dockerBuildErrorCode)
     except TypeError as e:
       xpk_print(f"TypeError while building image {self.img_name}: {e.args}")
       xpk_exit(dockerBuildErrorCode)
-    xpk_print("Docker image build succesfully.")
+    xpk_print("Docker image build successfully.")
     os.remove(self.dockerfile_path)
     tmp_dockerfile_dir = "/".join(self.dockerfile_path.split("/")[:-1])
     os.rmdir(tmp_dockerfile_dir)
xpk/core/docker_resources.py CHANGED
@@ -53,7 +53,10 @@ def get_main_container_resources(
     offset_vCPUs = int(system.chips_per_vm) * 0.95
     return f'{resource_type}: {offset_vCPUs}'
 
-  return f'{resource_type}: {system.chips_per_vm}'
+  return (
+      f'{resource_type}:'
+      f' {int(system.chips_per_vm / system.parallel_containers)}'
+  )
 
 
 def get_env_container(args, system: SystemCharacteristics) -> str:
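Together with the get_main_container loop in docker_container.py, this change splits each VM's accelerator chips evenly across the parallel containers. A toy sketch of the arithmetic, with a hypothetical dataclass standing in for SystemCharacteristics:

from dataclasses import dataclass

@dataclass
class System:  # stand-in for SystemCharacteristics
  chips_per_vm: int
  parallel_containers: int

def chip_limit_per_container(system: System) -> int:
  # Each of the N parallel containers requests chips_per_vm / N chips.
  return int(system.chips_per_vm / system.parallel_containers)

system = System(chips_per_vm=8, parallel_containers=2)
assert chip_limit_per_container(system) == 4  # two containers, 4 chips each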
xpk/core/kueue_manager.py CHANGED
@@ -41,8 +41,8 @@ from ..utils.console import xpk_print, xpk_exit, ask_for_user_consent
 from ..utils.templates import TEMPLATE_PATH, get_templates_absolute_path
 from packaging.version import Version
 
-KUEUE_VERSION = Version("v0.14.3")
-LATEST_BREAKING_VERSION = Version("v0.14.0")
+KUEUE_VERSION = Version("v0.15.2")
+LATEST_BREAKING_VERSION = Version("v0.15.0")
 WAIT_FOR_KUEUE_TIMEOUT = "10m"
 CLUSTER_QUEUE_NAME = "cluster-queue"
 LOCAL_QUEUE_NAME = "multislice-queue"
@@ -290,6 +290,7 @@ class KueueManager:
         cpu_limit=cpu_limit,
         memory_limit=memory_limit,
         topology_name=topology_name,
+        configure_super_slicing=kueue_config.configure_super_slicing,
     )
 
     config_yaml = template.render(context)
@@ -316,6 +317,7 @@ class KueueManager:
       cpu_limit: int,
       memory_limit: str,
       topology_name: str | None,
+      configure_super_slicing: bool,
   ) -> Dict[str, Any]:
     """Prepares the context for the Jinja2 template."""
     # Main accelerator flavor
@@ -328,11 +330,7 @@ class KueueManager:
       key, value = accelerator_label.split(":", 1)
       node_labels_dict[key] = value.strip()
 
-    if system.supports_super_slicing:
-      node_labels_dict["cloud.google.com/gke-tpu-partition-4x4x4-state"] = (
-          "HEALTHY"
-      )
-    elif not autoprovisioning:
+    if not autoprovisioning and not configure_super_slicing:
      machine_label = create_machine_label(system)
      if machine_label:
        key, value = machine_label.split(":", 1)
@@ -383,7 +381,7 @@ class KueueManager:
     })
 
     admission_checks = []
-    if system.supports_super_slicing:
+    if configure_super_slicing:
       admission_checks.append("ss-kueue-operator")
     if flex and is_queued_cluster(num_slices, system.accelerator_type):
       admission_checks.append("dws-prov")
xpk/core/kueue_manager_test.py CHANGED
@@ -36,6 +36,7 @@ TPU_SYSTEM: SystemCharacteristics = SystemCharacteristics(
     device_type="v5p-8",
     supports_sub_slicing=False,
     supports_super_slicing=False,
+    supports_accelerator_network_profile=False,
     docker_platform=DockerPlatform.ARM,
 )
 
@@ -112,7 +113,7 @@ def test_install_or_upgrade_when_outdated(
   result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)
 
   assert result == 0
-  mock_commands.assert_command_run("kubectl apply", "v0.14.3/manifests.yaml")
+  mock_commands.assert_command_run("kubectl apply", "v0.15.2/manifests.yaml")
   mock_commands.assert_command_run("kubectl apply -f", "/tmp/")
 
 
@@ -125,7 +126,7 @@ def test_install_or_upgrade_when_not_installed(
   result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)
 
   assert result == 0
-  mock_commands.assert_command_run("kubectl apply", "v0.14.3/manifests.yaml")
+  mock_commands.assert_command_run("kubectl apply", "v0.15.2/manifests.yaml")
   mock_commands.assert_command_run("kubectl apply -f", "/tmp/")
 
 
@@ -134,7 +135,7 @@ def test_upgrade_when_no_breaking_changes_between_versions_no_preparation_needed(
    kueue_manager: KueueManager,
    mock_ask_for_user_consent: MagicMock,
 ):
-  set_installed_kueue_version(mock_commands, Version("0.14.0"))
+  set_installed_kueue_version(mock_commands, Version("0.15.0"))
 
   kueue_manager.install_or_upgrade(KUEUE_CONFIG)
 
@@ -161,7 +162,7 @@ def test_upgrade_with_breaking_changes_between_versions_runs_preparation(
   assert result == 0
   mock_ask_for_user_consent.assert_called_once()
   assert (
-      "CHANGELOG/CHANGELOG-0.14.md"
+      "CHANGELOG/CHANGELOG-0.15.md"
      in mock_ask_for_user_consent.mock_calls[0].args[0]
  )
   mock_commands.assert_command_run(
@@ -411,6 +412,7 @@ def test_configure_generates_correct_manifest_with_gke_default_topology(
        supports_sub_slicing=False,
        supports_super_slicing=False,
        docker_platform=DockerPlatform.ARM,
+        supports_accelerator_network_profile=True,
        gpu_config=GpuConfig(requires_topology=True),
    ),
 )
@@ -490,7 +492,6 @@ def test_configure_generates_correct_manifest_with_super_slicing(
   assert resource_flavor["spec"]["topologyName"] == "super-slice-topology"
   assert resource_flavor["spec"]["nodeLabels"] == {
       "cloud.google.com/gke-tpu-accelerator": "tpu7x",
-      "cloud.google.com/gke-tpu-partition-4x4x4-state": "HEALTHY",
   }
   topology = _first(doc for doc in manifest_docs if doc["kind"] == "Topology")
   assert topology["metadata"]["name"] == "super-slice-topology"
xpk/core/nap.py CHANGED
@@ -24,7 +24,8 @@ from .capacity import (
     CapacityType,
     get_capacity_node_selectors_from_capacity_type,
     get_capacity_type,
-    verify_reservation_exists,
+    get_reservations_list,
+    verify_reservations_exist,
 )
 from .commands import run_command_with_updates, run_commands
 from .gcloud_context import get_cluster_location
@@ -345,14 +346,24 @@ def get_autoprovisioning_node_selector_args(args) -> tuple[str, int]:
   )
   if return_code != 0:
     return node_selector_args, return_code
-  return_code = verify_reservation_exists(args)
+  return_code = verify_reservations_exist(args)
   if return_code > 0:
     xpk_print('Unable to verify reservation name saved in config map.')
     return node_selector_args, return_code
 
   # Check if reservation id is valid. Shared function with cluster creation.
+  reservation_name = None
+  if capacity_type_str == CapacityType.RESERVATION.name:
+    reservations = get_reservations_list(args)
+    if len(reservations) > 1:
+      xpk_print('Error: NAP based clusters only support a single reservation.')
+      return node_selector_args, 1
+    reservation_name = reservations[0] if len(reservations) > 0 else None
+
   node_selector_args, return_code = (
-      get_capacity_node_selectors_from_capacity_type(args, capacity_type_str)
+      get_capacity_node_selectors_from_capacity_type(
+          capacity_type_str, reservation_name
+      )
   )
   if return_code != 0:
     xpk_print('Unable to get node selectors from capacity type.')
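The new guard makes the constraint explicit: node auto-provisioning accepts at most one reservation, while multi-reservation lists are handled per node pool (see nodepool.py below). A minimal sketch of the check, as a hypothetical standalone helper:

def pick_nap_reservation(reservations: list[str]) -> str | None:
  # NAP-based clusters support a single reservation at most.
  if len(reservations) > 1:
    raise ValueError('NAP based clusters only support a single reservation.')
  return reservations[0] if reservations else None

assert pick_nap_reservation(['res-a']) == 'res-a'
assert pick_nap_reservation([]) is None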
xpk/core/nodepool.py CHANGED
@@ -14,7 +14,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from typing import List
+from typing import Iterator, List
+from itertools import cycle
 
 from ..utils.feature_flags import FeatureFlags
 from ..utils.console import ask_for_user_consent, xpk_print
@@ -25,6 +26,7 @@ from .capacity import (
     CapacityType,
     get_capacity_arguments_from_capacity_type,
     get_capacity_type,
+    get_reservations_list,
     print_reservations,
 )
 from .commands import run_command_for_value, run_commands, FailedCommand
@@ -85,12 +87,6 @@ def run_gke_node_pool_create_command(
     max_nodes = system.vms_per_slice
   else:
     max_nodes = 1000
-  capacity_args, return_code = get_capacity_arguments_from_capacity_type(
-      args, capacity_type, max_nodes, system.accelerator_type
-  )
-  if return_code > 0:
-    xpk_print('Parsing capacity arguments failed!')
-    return return_code
 
   desired_node_pool_count = (
       1 if system.accelerator_type == AcceleratorType.GPU else args.num_slices
@@ -274,9 +270,34 @@ def run_gke_node_pool_create_command(
 
   create_commands = []
   create_task_names = []
-  for node_pool_name in desired_node_pool_names:
-    if node_pool_name in node_pools_to_remain:
-      continue
+  node_pools_to_create = [
+      np for np in desired_node_pool_names if np not in node_pools_to_remain
+  ]
+
+  reservations_iter: Iterator[str] | None = None
+  if capacity_type == CapacityType.RESERVATION:
+    reservations = get_reservations_list(args)
+    if (
+        _validate_reservation_count(reservations, len(node_pools_to_create))
+        != 0
+    ):
+      return 1
+    reservations_iter = (
+        cycle(reservations) if len(reservations) == 1 else iter(reservations)
+    )
+
+  for node_pool_name in node_pools_to_create:
+    capacity_args, return_code = get_capacity_arguments_from_capacity_type(
+        args,
+        capacity_type,
+        max_nodes,
+        system.accelerator_type,
+        reservation_name=next(reservations_iter) if reservations_iter else None,
+    )
+    if return_code > 0:
+      xpk_print('Parsing capacity arguments failed!')
+      return return_code
+
     command = (
         'gcloud beta container node-pools create'
         f' {node_pool_name}'
@@ -289,6 +310,12 @@ def run_gke_node_pool_create_command(
         f'{placement_args}'
         ' --enable-gvnic'
     )
+
+    if system.supports_accelerator_network_profile:
+      command += (
+          ' --accelerator-network-profile=auto'
+          ' --node-labels=cloud.google.com/gke-networking-dra-driver=true'
+      )
     if system.accelerator_type == AcceleratorType.TPU:
       command += f' --node-version={gke_node_pool_version}'
     if capacity_type == CapacityType.FLEX_START:
@@ -626,7 +653,7 @@ def ensure_resource_policy_exists(
 ) -> None:
   return_code, _ = run_command_for_value(
       (
-          'gcloud compute resource-policies describe'
+          'gcloud beta compute resource-policies describe'
          f' {resource_policy_name}'
          f' --project={project}'
          f' --region={zone_to_region(zone)}'
@@ -637,13 +664,12 @@ def ensure_resource_policy_exists(
   if return_code == 0:
     return
 
-  # TODO: b/465696970 - Verify the flag below before launching SUPER_SLICING:
   accelerator_topology_mode = (
       ' --accelerator-topology-mode=PROVISION_ONLY' if super_slicing else ''
   )
   return_code, _ = run_command_for_value(
       (
-          'gcloud compute resource-policies create workload-policy'
+          'gcloud beta compute resource-policies create workload-policy'
          f' {resource_policy_name} --project={project} --region={zone_to_region(zone)} --type=HIGH_THROUGHPUT'
          f' --accelerator-topology={topology}{accelerator_topology_mode}'
      ),
@@ -652,3 +678,16 @@ def ensure_resource_policy_exists(
 
   if return_code != 0:
     raise RuntimeError('Unable to create resource policy')
+
+
+def _validate_reservation_count(
+    reservations: List[str], num_node_pools_to_create: int
+) -> int:
+  """Validate that reservation count matches new nodepool count or is 1."""
+  if len(reservations) > 1 and len(reservations) != num_node_pools_to_create:
+    xpk_print(
+        f'Error: Number of reservations ({len(reservations)}) must match'
+        f' the number of NEW nodepools ({num_node_pools_to_create}) or be 1.'
+    )
+    return 1
+  return 0
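The reservation-to-nodepool mapping above is the heart of this change: a single reservation is reused for every new node pool via itertools.cycle, while a list of N reservations is consumed one per pool. A self-contained sketch of that assignment policy (names are illustrative, not xpk's API):

from itertools import cycle

def assign_reservations(
    reservations: list[str], node_pools: list[str]
) -> dict[str, str]:
  # One reservation: reuse it for every new pool (cycle);
  # N reservations: consume one per pool (iter).
  if not reservations:
    return {}
  if len(reservations) > 1 and len(reservations) != len(node_pools):
    raise ValueError('Reservation count must be 1 or match the new pool count.')
  source = cycle(reservations) if len(reservations) == 1 else iter(reservations)
  return {pool: next(source) for pool in node_pools}

pools = ['np-0', 'np-1', 'np-2']
print(assign_reservations(['res-a'], pools))
# {'np-0': 'res-a', 'np-1': 'res-a', 'np-2': 'res-a'}
print(assign_reservations(['res-a', 'res-b', 'res-c'], pools))
# {'np-0': 'res-a', 'np-1': 'res-b', 'np-2': 'res-c'}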