xpk 0.9.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,7 +22,9 @@ from ruamel import yaml
 
 from ...utils.console import xpk_exit, xpk_print
 from ...utils.file import ensure_directory_exists
+
 from ..capacity import (
+    H100_DEVICE_TYPE,
     B200_DEVICE_TYPE,
     H100_MEGA_DEVICE_TYPE,
     H200_DEVICE_TYPE,
@@ -30,10 +32,11 @@ from ..capacity import (
 )
 from ..system_characteristics import get_system_characteristics_by_device_type
 from .blueprint_definitions import Blueprint, DeploymentGroup, DeploymentModule
-
+from ..kueue import KUEUE_VERSION
 
 yaml = yaml.YAML()
 
+a3high_device_type = H100_DEVICE_TYPE
 a3mega_device_type = H100_MEGA_DEVICE_TYPE
 a3ultra_device_type = H200_DEVICE_TYPE
 a4_device_type = B200_DEVICE_TYPE
@@ -49,7 +52,7 @@ blueprint_dependencies_dir = {
 }
 
 cluster_toolkit_url = "github.com/GoogleCloudPlatform/cluster-toolkit"
-cluster_toolkit_version = "v1.48.0"
+cluster_toolkit_version = "v1.57.1"
 
 
 class BlueprintGeneratorOutput:
@@ -92,6 +95,8 @@ class BlueprintGenerator:
       group_placement_max_distance: int = 2,
       subnetwork_cidr_suffix: int = 24,
       reservation: str | None = None,
+      reservation_placement_policy: dict[str, str] | None = None,
+      reservation_maintenance_interval: str = "PERIODIC",
       gcs_bucket: Optional[str | None] = None,
       capacity_type: CapacityType = CapacityType.ON_DEMAND,
       system_node_pool_min_node_count: int = 2,
@@ -142,7 +147,6 @@ class BlueprintGenerator:
         source="modules/scheduler/gke-cluster",
         use=[primary_vpc_name, gpu_subnets_name],
         settings={
-            "release_channel": "RAPID",
             "prefix_with_deployment_name": False,
             "name_suffix": cluster_name,
             "enable_private_endpoint": False,
@@ -176,34 +180,42 @@ class BlueprintGenerator:
             "group_placement_max_distance": group_placement_max_distance,
         },
     )
+    nodepool_used_deps = ["gke_cluster", gpu_subnets_name]
 
     a3_megagpu_pool_0 = DeploymentModule(
         id="a3_megagpu_pool_0",
         source="modules/compute/gke-node-pool",
-        use=["gke_cluster", gpu_subnets_name],
+        use=nodepool_used_deps,
        settings={
             "name": f"{cluster_name}-a3-megagpu-pool-0",
             "machine_type": system.gce_machine_type,
-            "static_node_count": num_nodes,
             "zones": [zone],
-            "host_maintenance_interval": (
-                None
-                if capacity_type == CapacityType.RESERVATION
-                else "PERIODIC"
-            ),
+            "host_maintenance_interval": reservation_maintenance_interval,
             "reservation_affinity": self._getblock_reservation_affinity(
                 reservation
             ),
             "run_workload_script": False,
             "spot": capacity_type == CapacityType.SPOT,
             "max_pods_per_node": 32,
-            "auto_upgrade": True,
+            "guest_accelerator": [{
+                "type": "nvidia-h100-mega-80gb",
+                "count": 8,
+                "gpu_driver_installation_config": {
+                    "gpu_driver_version": "LATEST"
+                },
+            }],
+            "auto_upgrade": (
+                True if capacity_type != CapacityType.FLEX_START else False
+            ),
         },
         outputs=["instructions"],
     )
+    if capacity_type == CapacityType.FLEX_START:
+      a3_megagpu_pool_0.settings.update(self.get_dws_flex_start())
+    else:
+      a3_megagpu_pool_0.settings.update({"static_node_count": num_nodes})
 
     set_placement_policy = capacity_type != CapacityType.SPOT
-    tas_name = "topologyName: 'gke-default'" if set_placement_policy else ""
     num_chips = num_nodes * system.chips_per_vm
     workload = DeploymentModule(
         id="workload_component_install",
@@ -212,11 +224,16 @@ class BlueprintGenerator:
         settings={
             "kueue": {
                 "install": True,
-                "version": "v0.10.0",  # TAS feature-gates is enabled in CT
+                "version": KUEUE_VERSION,  # TAS feature-gates is enabled in CT
                 "config_path": f'$(ghpc_stage("{blueprint_name}"))/kueue-xpk-configuration.yaml.tftpl',
                 "config_template_vars": {
                     "num_chips": num_chips,
-                    "tas_name": tas_name,
+                    "reservation": (
+                        1 if capacity_type == CapacityType.RESERVATION else 0
+                    ),
+                    "flex_start": (
+                        1 if capacity_type == CapacityType.FLEX_START else 0
+                    ),
                 },
             },
             "jobset": {"install": True, "version": "v0.7.2"},
@@ -247,6 +264,13 @@ class BlueprintGenerator:
             }]
         },
     )
+
+    print(reservation_placement_policy)
+    if reservation_placement_policy is not None:
+      a3_megagpu_pool_0.settings["placement_policy"] = (
+          reservation_placement_policy
+      )
+
     primary_group = DeploymentGroup(
         group="primary",
         modules=[
@@ -258,11 +282,9 @@ class BlueprintGenerator:
             workload_configmap,
         ],
     )
-
-    if set_placement_policy:
+    if set_placement_policy and reservation_placement_policy is None:
       a3_megagpu_pool_0.use.append(group_placement_0.id)
       primary_group.modules.append(group_placement_0)
-
     a3_mega_blueprint = Blueprint(
         terraform_backend_defaults=self._getblock_terraform_backend(
             gcs_bucket, cluster_name, prefix
@@ -478,14 +500,22 @@ class BlueprintGenerator:
         source="modules/scheduler/gke-cluster",
         use=[net_0_id],
         settings={
-            "release_channel": "RAPID",
-            "version_prefix": "1.31.",
-            "maintenance_exclusions": [{
-                "name": "no-minor-or-node-upgrades-indefinite",
-                "start_time": "2024-12-01T00:00:00Z",
-                "end_time": "2025-12-22T00:00:00Z",
-                "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
-            }],
+            "release_channel": (
+                "UNSPECIFIED"
+                if capacity_type == CapacityType.FLEX_START
+                else "RAPID"
+            ),
+            "version_prefix": "1.32.",
+            "maintenance_exclusions": (
+                []
+                if capacity_type == CapacityType.FLEX_START
+                else [{
+                    "name": "no-minor-or-node-upgrades-indefinite",
+                    "start_time": "2024-12-01T00:00:00Z",
+                    "end_time": "2025-12-22T00:00:00Z",
+                    "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
+                }]
+            ),
             "prefix_with_deployment_name": False,
             "name_suffix": cluster_name,
             "system_node_pool_machine_type": system_node_pool_machine_type,
@@ -534,9 +564,10 @@ class BlueprintGenerator:
         use=[cluster_id],
         settings={
             "machine_type": system.gce_machine_type,
-            "auto_upgrade": True,
+            "auto_upgrade": (
+                True if capacity_type != CapacityType.FLEX_START else False
+            ),
             "zones": [zone],
-            "static_node_count": num_nodes,
             "spot": capacity_type == CapacityType.SPOT,
             "reservation_affinity": self._getblock_reservation_affinity(
                 reservation
@@ -562,6 +593,10 @@ class BlueprintGenerator:
         },
         outputs=["instructions"],
     )
+    if capacity_type == CapacityType.FLEX_START:
+      gpu_pool.settings.update(self.get_dws_flex_start())
+    else:
+      gpu_pool.settings.update({"static_node_count": num_nodes})
 
     num_chips = num_nodes * system.chips_per_vm
     workload_manager_install_id = "workload-manager-install"
@@ -572,9 +607,14 @@ class BlueprintGenerator:
         settings={
             "kueue": {
                 "install": True,
-                "version": "v0.10.0",  # TAS feature-gates is enabled in CT
+                "version": KUEUE_VERSION,  # TAS feature-gates is enabled in CT
                 "config_path": f'$(ghpc_stage("{blueprint_name}"))/kueue-xpk-configuration.yaml.tftpl',
-                "config_template_vars": {"num_chips": num_chips},
+                "config_template_vars": {
+                    "num_chips": num_chips,
+                    "flex_start": (
+                        1 if capacity_type == CapacityType.FLEX_START else 0
+                    ),
+                },
             },
             "jobset": {"install": True, "version": "v0.7.2"},
             "apply_manifests": [
@@ -777,13 +817,21 @@ class BlueprintGenerator:
                 f" {cluster_name}-rdma-net.subnetwork_interfaces_gke))"
             ),
             "version_prefix": "1.32.",
-            "release_channel": "RAPID",
-            "maintenance_exclusions": [{
-                "name": "no-minor-or-node-upgrades-indefinite",
-                "start_time": "2024-12-01T00:00:00Z",
-                "end_time": "2025-12-22T00:00:00Z",
-                "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
-            }],
+            "release_channel": (
+                "UNSPECIFIED"
+                if capacity_type == CapacityType.FLEX_START
+                else "RAPID"
+            ),
+            "maintenance_exclusions": (
+                []
+                if capacity_type == CapacityType.FLEX_START
+                else [{
+                    "name": "no-minor-or-node-upgrades-indefinite",
+                    "start_time": "2024-12-01T00:00:00Z",
+                    "end_time": "2025-12-22T00:00:00Z",
+                    "exclusion_scope": "NO_MINOR_OR_NODE_UPGRADES",
+                }]
+            ),
         },
         outputs=["instructions"],
     )
@@ -800,10 +848,11 @@ class BlueprintGenerator:
         use=[cluster_id],
         settings={
             "machine_type": system.gce_machine_type,
-            "auto_upgrade": True,
+            "auto_upgrade": (
+                True if capacity_type != CapacityType.FLEX_START else False
+            ),
             "zones": [zone],
             "disk_type": "hyperdisk-balanced",
-            "static_node_count": num_nodes,
             "local_ssd_count_ephemeral_storage": 32,
             "spot": capacity_type == CapacityType.SPOT,
             "reservation_affinity": self._getblock_reservation_affinity(
@@ -830,6 +879,10 @@ class BlueprintGenerator:
         },
         outputs=["instructions"],
     )
+    if capacity_type == CapacityType.FLEX_START:
+      gpu_pool.settings.update(self.get_dws_flex_start())
+    else:
+      gpu_pool.settings.update({"static_node_count": num_nodes})
 
     num_chips = num_nodes * system.chips_per_vm
     workload_manager_install_id = "workload-manager-install"
@@ -840,9 +893,14 @@ class BlueprintGenerator:
         settings={
             "kueue": {
                 "install": True,
-                "version": "v0.10.0",  # TAS feature-gates is enabled in CT
+                "version": KUEUE_VERSION,  # TAS feature-gates is enabled in CT
                 "config_path": f'$(ghpc_stage("{blueprint_name}"))/kueue-xpk-configuration.yaml.tftpl',
-                "config_template_vars": {"num_chips": num_chips},
+                "config_template_vars": {
+                    "num_chips": num_chips,
+                    "flex_start": (
+                        1 if capacity_type == CapacityType.FLEX_START else 0
+                    ),
+                },
             },
             "jobset": {"install": True, "version": "v0.7.2"},
             "apply_manifests": [
@@ -992,6 +1050,16 @@
     )
     return deployment_files_path
 
+  def get_dws_flex_start(self) -> dict:
+    return {
+        "enable_flex_start": True,
+        "enable_queued_provisioning": True,
+        "autoscaling_total_min_nodes": 0,
+        "release_channel": "UNSPECIFIED",
+        "auto_repair": False,
+        "auto_upgrade": False,
+    }
+
 
 yaml.register_class(Blueprint)
 yaml.register_class(DeploymentGroup)
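`get_dws_flex_start` is the pivot for the DWS flex-start path in this file: wherever a pool previously pinned `static_node_count`, the generator now merges these settings instead. A minimal standalone sketch of that merge, matching the `settings.update(...)` pattern used for `a3_megagpu_pool_0` and `gpu_pool` above (machine type and zone are placeholder values, not the package's defaults):

```python
def node_pool_settings(num_nodes: int, flex_start: bool) -> dict:
  settings = {"machine_type": "a3-megagpu-8g", "zones": ["us-central1-a"]}
  if flex_start:
    # DWS flex-start pools scale from zero via queued provisioning, so a
    # fixed node count, auto-repair and auto-upgrade are all disabled.
    settings.update({
        "enable_flex_start": True,
        "enable_queued_provisioning": True,
        "autoscaling_total_min_nodes": 0,
        "release_channel": "UNSPECIFIED",
        "auto_repair": False,
        "auto_upgrade": False,
    })
  else:
    settings.update({"static_node_count": num_nodes})
  return settings


assert "static_node_count" not in node_pool_settings(4, flex_start=True)
assert node_pool_settings(4, flex_start=False)["static_node_count"] == 4
```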
xpk/core/capacity.py CHANGED
@@ -16,8 +16,8 @@ limitations under the License.
 
 import enum
 
-from ..utils.console import xpk_print
-from .commands import run_command_with_updates
+from ..utils.console import xpk_print, xpk_exit
+from .commands import run_command_with_updates, run_command_for_value
 
 AUTOPROVISIONING_CONFIG_VALUE = 'AUTOPROVISION'
 AUTOPROVISIONING_CONFIG_MINIMUM_KEY = 'minimum_chips'
@@ -36,6 +36,7 @@ class CapacityType(enum.Enum):
   RESERVATION = 'reservation'
   SPOT = 'spot'
   UNKNOWN = 'unknown'
+  FLEX_START = 'flex_start'
 
 
 def print_reservations(args) -> int:
@@ -84,6 +85,9 @@ def get_capacity_type(args) -> tuple[CapacityType, int]:
   if args.spot:
     capacity_type = CapacityType.SPOT
     num_types += 1
+  if args.flex:
+    capacity_type = CapacityType.FLEX_START
+    num_types += 1
 
   # Check that the number of user arguments provided is valid.
   if num_types == 0:
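`--flex` joins the existing mutually exclusive capacity flags. A self-contained sketch (a hypothetical args object, not the xpk CLI itself) of the rule the counter above enforces:

```python
from types import SimpleNamespace


def count_capacity_flags(args) -> int:
  # One increment per capacity source, as in get_capacity_type above.
  return sum([bool(args.on_demand), bool(args.spot), bool(args.flex),
              args.reservation is not None])


args = SimpleNamespace(on_demand=False, spot=False, flex=True, reservation=None)
assert count_capacity_flags(args) == 1  # valid: exactly one capacity source
```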
@@ -91,14 +95,62 @@ def get_capacity_type(args) -> tuple[CapacityType, int]:
   elif num_types != 1:
     xpk_print(
         'ERROR: User specified more than one of the following arguments. Please'
-        ' specify only one of `--reservation=$RESERVATION_NAME`, `--on-demand`'
-        ' or `--spot`.'
+        ' specify only one of `--reservation=$RESERVATION_NAME`, `--on-demand`,'
+        ' `--flex` or `--spot`.'
     )
     return_code = 1
 
   return capacity_type, return_code
 
 
+def get_reservation_maintenance_interval(
+    reservation: str, zone: str, project: str
+) -> str:
+  """Get reservation maintenance interval.
+
+  Args:
+    reservation, zone, project: identify the reservation to describe.
+
+  Returns:
+    The reservation's maintenance interval; exits on command failure.
+  """
+  command = (
+      f'gcloud beta compute reservations describe {reservation}'
+      f' --project={project} --zone={zone} --format="value(specificReservation.instanceProperties.maintenanceInterval)"'
+  )
+  return_code, output = run_command_for_value(
+      command, 'Get reservation maintenance interval', None
+  )
+  if return_code != 0:
+    xpk_print(f'Get reservation maintenance interval ERROR {return_code}')
+    xpk_exit(1)
+  return output.strip()
+
+
+def get_reservation_placement_policy(
+    reservation: str, zone: str, project: str
+) -> str:
+  """Get reservation placement policy.
+
+  Args:
+    reservation, zone, project: identify the reservation to describe.
+
+  Returns:
+    The reservation's placement policy resource; exits on command failure.
+  """
+  command = (
+      f'gcloud beta compute reservations describe {reservation}'
+      f' --project={project} --zone={zone} --format="value(resourcePolicies.policy)"'
+  )
+  return_code, output = run_command_for_value(
+      command, 'Get reservation placement policy', None
+  )
+  if return_code != 0:
+    xpk_print(f'Get reservation placement policy ERROR {return_code}')
+    xpk_exit(1)
+  return output.strip()
+
+
 def verify_reservation_exists(args) -> int:
   """Verify the reservation exists.
 
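These two helpers shell out to `gcloud beta compute reservations describe` and feed the new `BlueprintGenerator` parameters from the first file. A hedged usage sketch (reservation, zone, and project values are placeholders; note the placement-policy helper returns the policy string, which a caller must still map into the `dict[str, str]` shape `reservation_placement_policy` expects):

```python
# Hypothetical call site, assuming the two functions above are importable.
maintenance_interval = get_reservation_maintenance_interval(
    reservation='my-reservation', zone='us-central1-a', project='my-project'
)
policy_url = get_reservation_placement_policy(
    reservation='my-reservation', zone='us-central1-a', project='my-project'
)
# e.g. pass reservation_maintenance_interval=maintenance_interval and a
# placement-policy dict derived from policy_url into BlueprintGenerator.
```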
@@ -121,9 +173,9 @@ def verify_reservation_exists(args) -> int:
 
 
 def get_capacity_arguments_from_capacity_type(
-    args, capacity_type: CapacityType
+    args, capacity_type: CapacityType, max_nodes: int
 ) -> tuple[str, int]:
-  """Determine the TPU Nodepool creation capacity arguments needed.
+  """Determine the Nodepool creation capacity arguments needed.
 
   Args:
     args: user provided arguments for running the command.
@@ -141,6 +193,12 @@ def get_capacity_arguments_from_capacity_type(
       capacity_args = ''
     case CapacityType.SPOT:
       capacity_args = '--spot'
+    case CapacityType.FLEX_START:
+      capacity_args = (
+          ' --flex-start --enable-queued-provisioning --enable-autoscaling'
+          ' --location-policy=ANY --reservation-affinity=none'
+          f' --no-enable-autorepair --max-nodes={max_nodes}'
+      )
     case CapacityType.RESERVATION:
       capacity_args = (
           f'--reservation-affinity=specific --reservation={args.reservation}'
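For FLEX_START the helper now emits a whole flag bundle rather than a single switch. Assuming the surrounding function behaves as shown in the hunk, a worked example of the rendered arguments for `max_nodes=4`:

```python
# Flags appended to the node-pool creation command for flex-start capacity;
# --max-nodes caps the autoscaler, since queued pools start at zero nodes.
expected = (
    ' --flex-start --enable-queued-provisioning --enable-autoscaling'
    ' --location-policy=ANY --reservation-affinity=none'
    ' --no-enable-autorepair --max-nodes=4'
)
```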
@@ -173,6 +231,8 @@ def get_capacity_node_selectors_from_capacity_type(
   match capacity_type:
     case CapacityType.ON_DEMAND.name:
       node_selector = ''
+    case CapacityType.FLEX_START.name:
+      node_selector = 'cloud.google.com/gke-queued="true"'
     case CapacityType.SPOT.name:
       node_selector = 'cloud.google.com/gke-spot="true"'
     case CapacityType.RESERVATION.name:
xpk/core/cluster.py CHANGED
@@ -38,7 +38,7 @@ from .resources import get_cluster_system_characteristics
 from .system_characteristics import SystemCharacteristics
 
 JOBSET_VERSION = 'v0.8.0'
-PATHWAYS_JOB_VERSION = 'v0.1.1'
+PATHWAYS_JOB_VERSION = 'v0.1.2'
 INSTALLER_NCCL_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-tcpx-installer.yaml'
 INSTALLER_NCCL_TCPXO = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpxo/nccl-tcpxo-installer.yaml'
 INSTALLER_NCCL_RDMA = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-rdma/nccl-rdma-installer.yaml'
@@ -315,28 +315,60 @@ def update_cluster_with_pd_driver_if_necessary(args) -> int:
   return 0
 
 
-def is_driver_enabled_on_cluster(args, driver: str) -> bool:
+def update_cluster_with_lustre_driver_if_necessary(args) -> int:
+  """Updates a GKE cluster to enable Lustre CSI driver, if not enabled already.
+  Args:
+    args: user provided arguments for running the command.
+  Returns:
+    0 if successful and error code otherwise.
+  """
+  if is_driver_enabled_on_cluster(
+      args, driver='lustreCsiDriver'
+  ) and is_driver_enabled_on_cluster(
+      args, driver='lustreCsiDriver', config_key='enableLegacyLustrePort'
+  ):
+    return 0
+  cluster_update_return_code = update_gke_cluster_with_lustre_driver_enabled(
+      args
+  )
+  if cluster_update_return_code > 0:
+    xpk_print(
+        'Updating GKE cluster to enable Lustre CSI driver failed!'
+    )
+    return cluster_update_return_code
+
+  return 0
+
+
+def is_driver_enabled_on_cluster(
+    args, driver: str, config_key: str = 'enabled', config_val: str = 'true'
+) -> bool:
   """Checks if the CSI driver is enabled on the cluster.
   Args:
     args: user provided arguments for running the command.
     driver (str): name of the driver
+    config_key (str): the config key to look for; defaults to "enabled"
+    config_val (str): the value indicating enabled; defaults to "true"
   Returns:
     True if driver is enabled on the cluster and False otherwise.
   """
   command = (
       f'gcloud container clusters describe {args.cluster}'
       f' --project={args.project} --region={zone_to_region(args.zone)}'
-      f' --format="value(addonsConfig.{driver}Config.enabled)"'
+      f' --format="value(addonsConfig.{driver}Config.{config_key})"'
   )
   return_code, driver_enabled = run_command_for_value(
       command,
-      f'Checks if {driver} driver is enabled in cluster describe.',
+      f"Checks if {driver} driver's {config_key} is enabled in cluster"
+      ' describe.',
       args,
   )
   if return_code != 0:
     xpk_exit(return_code)
-  if driver_enabled.strip().lower() == 'true':
-    xpk_print(f'{driver} driver is enabled on the cluster, no update needed.')
+  if driver_enabled.strip().lower() == config_val.lower():
+    xpk_print(
+        f"{driver} driver's {config_key} config is {config_val} on the cluster."
+    )
     return True
   return False
 
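`is_driver_enabled_on_cluster` is generalized from a hard-coded `enabled == true` check to an arbitrary `config_key`/`config_val` comparison, which is what lets the Lustre path also test `enableLegacyLustrePort`. A stub of just that comparison (the sample gcloud output values are assumptions for illustration):

```python
def driver_config_matches(gcloud_value: str, config_val: str = 'true') -> bool:
  # Mirrors: driver_enabled.strip().lower() == config_val.lower()
  return gcloud_value.strip().lower() == config_val.lower()


assert driver_config_matches(' True\n')       # addon reported as enabled
assert driver_config_matches('TRUE', 'true')  # comparison is case-insensitive
assert not driver_config_matches('')          # config absent -> update needed
```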
@@ -423,6 +455,19 @@ def get_gpu_type_from_cluster(args) -> str:
   return ''
 
 
+def setup_k8s_service_accounts() -> None:
+  """
+  Creates/sets up service accounts and the roles for them.
+  """
+  default_sa = 'default'
+
+  create_xpk_k8s_service_account()
+
+  role_name = create_pod_reader_role()
+  create_role_binding(default_sa, role_name)
+  create_role_binding(XPK_SA, role_name)
+
+
 def create_xpk_k8s_service_account() -> None:
   k8s_core_client = k8s_client.CoreV1Api()
   sa = k8s_client.V1ServiceAccount(
@@ -441,6 +486,94 @@ def create_xpk_k8s_service_account() -> None:
   )
 
 
+def create_pod_reader_role() -> str:
+  """
+  Creates the 'pod-reader' Role in the default namespace.
+  """
+  k8s_rbac_client = k8s_client.RbacAuthorizationV1Api()
+  role_name = 'pod-reader'
+
+  role = k8s_client.V1Role(
+      metadata=k8s_client.V1ObjectMeta(
+          name=role_name, namespace=DEFAULT_NAMESPACE
+      ),
+      rules=[
+          k8s_client.V1PolicyRule(
+              api_groups=[''],
+              resources=['pods', 'services'],
+              verbs=['get', 'list', 'watch'],
+          ),
+          k8s_client.V1PolicyRule(
+              api_groups=['batch'],
+              resources=['jobs'],
+              verbs=['get', 'list', 'watch'],
+          ),
+      ],
+  )
+
+  xpk_print(
+      f'Attempting to create Role: {role_name} in namespace:'
+      f' {DEFAULT_NAMESPACE}'
+  )
+  try:
+    k8s_rbac_client.create_namespaced_role(DEFAULT_NAMESPACE, role, pretty=True)
+    xpk_print(f'Successfully created Role: {role_name}')
+    return role_name
+  except ApiException as e:
+    if e.status == 409:  # Conflict, meaning it already exists
+      xpk_print(f'Role: {role_name} already exists. Skipping its creation.')
+      return role_name
+    else:
+      xpk_print(f'Error creating Role {role_name}: {e}')
+      xpk_exit(1)
+
+
+def create_role_binding(sa: str, role_name: str) -> None:
+  """
+  Creates a RoleBinding to associate the Service Account
+  with the Role in the default namespace.
+  Assumes the Service Account and the Role already exist.
+  """
+  k8s_rbac_client = k8s_client.RbacAuthorizationV1Api()
+  role_binding_name = f'{sa}-{role_name}-binding'
+
+  role_binding = k8s_client.V1RoleBinding(
+      metadata=k8s_client.V1ObjectMeta(
+          name=role_binding_name, namespace=DEFAULT_NAMESPACE
+      ),
+      subjects=[
+          k8s_client.RbacV1Subject(
+              kind='ServiceAccount', name=sa, namespace=DEFAULT_NAMESPACE
+          )
+      ],
+      role_ref=k8s_client.V1RoleRef(
+          kind='Role', name=role_name, api_group='rbac.authorization.k8s.io'
+      ),
+  )
+
+  xpk_print(
+      f'Attempting to create RoleBinding: {role_binding_name} for Service'
+      f' Account: {sa} to Role: {role_name} in namespace:'
+      f' {DEFAULT_NAMESPACE}'
+  )
+  try:
+    k8s_rbac_client.create_namespaced_role_binding(
+        DEFAULT_NAMESPACE, role_binding, pretty=True
+    )
+    xpk_print(
+        f'Successfully created RoleBinding: {role_binding_name} for {sa}'
+    )
+  except ApiException as e:
+    if e.status == 409:  # Conflict, meaning it already exists
+      xpk_print(
+          f'RoleBinding: {role_binding_name} already exists. Skipping its'
+          ' creation.'
+      )
+    else:
+      xpk_print(f'Error creating RoleBinding {role_binding_name}: {e}')
+      xpk_exit(1)
+
+
 def update_gke_cluster_with_clouddns(args) -> int:
   """Run the GKE cluster update command for existing clusters and enable CloudDNS.
 
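Together, the three helpers above give both the `default` and xpk service accounts read access to pods, services, and jobs. A hedged sketch of the resulting driver sequence (assumes a loaded kubernetes client config; `DEFAULT_NAMESPACE` is `'default'` per the docstrings, and binding names follow `f'{sa}-{role_name}-binding'`):

```python
from kubernetes import config as k8s_config

k8s_config.load_kube_config()  # or load_incluster_config() inside a cluster
setup_k8s_service_accounts()
# Idempotently (409 conflicts are tolerated) creates in namespace 'default':
#   Role        pod-reader
#   RoleBinding default-pod-reader-binding  -> ServiceAccount 'default'
#   RoleBinding <XPK_SA>-pod-reader-binding -> ServiceAccount XPK_SA
```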
@@ -522,6 +655,32 @@ def update_gke_cluster_with_gcsfuse_driver_enabled(args) -> int:
   return 0
 
 
+def update_gke_cluster_with_lustre_driver_enabled(args) -> int:
+  """Run the GKE cluster update command for existing clusters and enable Lustre CSI driver.
+  Args:
+    args: user provided arguments for running the command.
+  Returns:
+    0 if successful and 1 otherwise.
+  """
+  command = (
+      'gcloud container clusters update'
+      f' {args.cluster} --project={args.project}'
+      f' --region={zone_to_region(args.zone)}'
+      ' --enable-legacy-lustre-port'
+      ' --quiet'
+  )
+  xpk_print(
+      'Updating GKE cluster to enable Lustre CSI driver, may take a while!'
+  )
+  return_code = run_command_with_updates(
+      command, 'GKE Cluster Update to enable Lustre CSI driver', args
+  )
+  if return_code != 0:
+    xpk_print(f'GKE Cluster Update request returned ERROR {return_code}')
+    return 1
+  return 0
+
+
 def upgrade_gke_control_plane_version(args, default_rapid_gke_version) -> int:
   """Upgrade GKE cluster's control plane version before updating nodepools to use CloudDNS.
 
@@ -731,7 +890,6 @@ def get_cluster_credentials(args) -> None:
   command = (
       'gcloud container clusters get-credentials'
       f' {args.cluster} --region={zone_to_region(args.zone)}'
-      ' --dns-endpoint'
       f' --project={args.project} &&'
       ' kubectl config view && kubectl config set-context --current'
       ' --namespace=default'