xpk 0.13.0-py3-none-any.whl → 0.14.0-py3-none-any.whl

Files changed (64)
  1. xpk/commands/batch.py +9 -2
  2. xpk/commands/cluster.py +128 -115
  3. xpk/commands/cluster_gcluster.py +77 -14
  4. xpk/commands/cluster_gcluster_test.py +177 -0
  5. xpk/commands/common.py +10 -28
  6. xpk/commands/info.py +11 -9
  7. xpk/commands/inspector.py +21 -10
  8. xpk/commands/job.py +25 -9
  9. xpk/commands/kind.py +38 -40
  10. xpk/commands/kjob_common.py +4 -4
  11. xpk/commands/run.py +9 -2
  12. xpk/commands/shell.py +13 -10
  13. xpk/commands/storage.py +21 -0
  14. xpk/commands/version.py +0 -4
  15. xpk/commands/workload.py +43 -22
  16. xpk/core/blueprint/blueprint_generator.py +4 -40
  17. xpk/core/blueprint/blueprint_test.py +0 -6
  18. xpk/core/capacity.py +6 -5
  19. xpk/core/cluster.py +91 -194
  20. xpk/core/cluster_private.py +6 -11
  21. xpk/core/commands.py +11 -18
  22. xpk/core/config.py +1 -1
  23. xpk/core/docker_image.py +3 -4
  24. xpk/core/gcloud_context.py +26 -2
  25. xpk/core/gcloud_context_test.py +96 -0
  26. xpk/core/gcluster_manager.py +0 -3
  27. xpk/core/jobset.py +4 -7
  28. xpk/core/kjob.py +14 -27
  29. xpk/core/kueue_manager.py +383 -0
  30. xpk/core/kueue_manager_test.py +542 -0
  31. xpk/core/monitoring.py +1 -1
  32. xpk/core/nap.py +10 -15
  33. xpk/core/network.py +17 -18
  34. xpk/core/nodepool.py +66 -77
  35. xpk/core/nodepool_test.py +198 -1
  36. xpk/core/pathways.py +5 -5
  37. xpk/core/ray.py +10 -14
  38. xpk/core/resources.py +6 -11
  39. xpk/core/scheduling.py +19 -1
  40. xpk/core/scheduling_test.py +31 -0
  41. xpk/core/system_characteristics.py +335 -229
  42. xpk/core/vertex.py +1 -1
  43. xpk/core/workload.py +7 -8
  44. xpk/main.py +2 -4
  45. xpk/parser/cluster.py +7 -0
  46. xpk/parser/cluster_test.py +66 -0
  47. xpk/parser/common.py +11 -0
  48. xpk/parser/workload.py +62 -25
  49. xpk/parser/workload_test.py +82 -0
  50. xpk/utils/feature_flags.py +28 -0
  51. xpk/utils/kueue.py +20 -0
  52. xpk/utils/templates.py +2 -0
  53. xpk/utils/topology.py +37 -0
  54. xpk/utils/topology_test.py +43 -0
  55. xpk/utils/validation.py +79 -55
  56. xpk/utils/validation_test.py +37 -0
  57. {xpk-0.13.0.dist-info → xpk-0.14.0.dist-info}/METADATA +6 -1
  58. xpk-0.14.0.dist-info/RECORD +112 -0
  59. xpk/core/kueue.py +0 -561
  60. xpk-0.13.0.dist-info/RECORD +0 -101
  61. {xpk-0.13.0.dist-info → xpk-0.14.0.dist-info}/WHEEL +0 -0
  62. {xpk-0.13.0.dist-info → xpk-0.14.0.dist-info}/entry_points.txt +0 -0
  63. {xpk-0.13.0.dist-info → xpk-0.14.0.dist-info}/licenses/LICENSE +0 -0
  64. {xpk-0.13.0.dist-info → xpk-0.14.0.dist-info}/top_level.txt +0 -0
xpk/core/network.py CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
  from ..utils.console import xpk_exit, xpk_print
  from ..utils.file import write_tmp_file
  from .commands import run_command_for_value, run_command_with_updates
- from .gcloud_context import zone_to_region
+ from .gcloud_context import zone_to_region, get_cluster_location

  # cluster_network_yaml: the config when creating the network for a3 cluster
  CLUSTER_NETWORK_YAML = """
@@ -126,7 +126,7 @@ def create_cluster_network(args, index) -> int:
  ' --subnet-mode=custom --mtu=8244'
  )
  return_code = run_command_with_updates(
- command, 'Create Cluster Network', args, verbose=False
+ command, 'Create Cluster Network', verbose=False
  )

  if return_code != 0:
@@ -152,7 +152,9 @@ def create_cluster_subnet(args, index) -> int:
  if return_code > 0:
  xpk_print('Listing all subnets failed!')
  return return_code
- subnet_name = f'{args.cluster}-{zone_to_region(args.zone)}-sub-{index}'
+ subnet_name = (
+ f'{args.cluster}-{get_cluster_location(args.project, args.cluster, args.zone)}-sub-{index}'
+ )
  if subnet_name not in existing_subnet_names:
  command = (
  f'gcloud compute --project={args.project}'
@@ -161,7 +163,7 @@ def create_cluster_subnet(args, index) -> int:
  f' --region={zone_to_region(args.zone)} --range=192.168.{index}.0/24'
  )
  return_code = run_command_with_updates(
- command, 'Create Cluster Subnet', args, verbose=False
+ command, 'Create Cluster Subnet', verbose=False
  )

  if return_code != 0:
@@ -197,7 +199,7 @@ def create_cluster_firewall_rule(args, index) -> int:
  ' --rules=tcp:0-65535,udp:0-65535,icmp --source-ranges=192.168.0.0/16'
  )
  return_code = run_command_with_updates(
- command, 'Create Cluster Firewall Rule', args, verbose=False
+ command, 'Create Cluster Firewall Rule', verbose=False
  )

  if return_code != 0:
@@ -224,7 +226,7 @@ def create_cluster_network_config(args) -> int:
  command = f'kubectl apply -f {str(tmp)}'

  return_code = run_command_with_updates(
- command, 'GKE Cluster Create Network Config', args
+ command, 'GKE Cluster Create Network Config'
  )
  if return_code != 0:
  xpk_print(
@@ -235,19 +237,14 @@ def create_cluster_network_config(args) -> int:
  return 0


- def get_cluster_subnetworks(args) -> list[str]:
+ def get_cluster_subnetworks() -> list[str]:
  """Gets the list of cluster networks.

- Args:
- args: user provided arguments for running the command.
-
  Returns:
  list[str]: list of cluster networks
  """
  command = 'kubectl get GKENetworkParamSet'
- return_code, stdout = run_command_for_value(
- command, 'Get Cluster Networks', args
- )
+ return_code, stdout = run_command_for_value(command, 'Get Cluster Networks')
  if return_code != 0:
  xpk_print('GKE Cluster Get NetworkParamSet failed')
  xpk_exit(return_code)
@@ -302,7 +299,7 @@ def delete_cluster_subnets(args) -> int:
  )

  return_code = run_command_with_updates(
- command, 'Delete Cluster Subnet', args, verbose=False
+ command, 'Delete Cluster Subnet', verbose=False
  )

  if return_code != 0:
@@ -328,7 +325,7 @@ def get_all_networks_programmatic(args) -> tuple[list[str], int]:
  f' --project={args.project}'
  )
  return_code, raw_network_output = run_command_for_value(
- command, 'Get All Networks', args
+ command, 'Get All Networks'
  )
  if return_code != 0:
  xpk_print(f'Get All Networks returned ERROR {return_code}')
@@ -346,14 +343,16 @@ def get_all_subnets_programmatic(args) -> tuple[list[str], int]:
  Returns:
  List of subnets and 0 if successful and 1 otherwise.
  """
- subnet_name_filter = f'{args.cluster}-{zone_to_region(args.zone)}-sub-*'
+ subnet_name_filter = (
+ f'{args.cluster}-{get_cluster_location(args.project, args.cluster, args.zone)}-sub-*'
+ )

  command = (
  'gcloud compute networks subnets list'
  f' --filter=name~"{subnet_name_filter}" --project={args.project}'
  )
  return_code, raw_subnets_output = run_command_for_value(
- command, 'Get All Subnets', args
+ command, 'Get All Subnets'
  )
  if return_code != 0:
  xpk_print(f'Get All Subnets returned ERROR {return_code}')
@@ -380,7 +379,7 @@ def get_all_firewall_rules_programmatic(args) -> tuple[list[str], int]:
  f' --project={args.project}'
  )
  return_code, raw_subnets_output = run_command_for_value(
- command, 'Get All Firewall Rules', args
+ command, 'Get All Firewall Rules'
  )
  if return_code != 0:
  xpk_print(f'Get All Firewall Rules returned ERROR {return_code}')
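
The recurring change in this file: subnet names and filters are now derived from get_cluster_location(args.project, args.cluster, args.zone) rather than zone_to_region(args.zone), and the command helpers no longer take args. A minimal sketch of what the naming change can mean, assuming get_cluster_location resolves the cluster's actual location (a region for regional clusters, a zone for zonal ones); the helper body below is an illustrative stand-in, not the xpk implementation:

def zone_to_region(zone: str) -> str:
  # Illustrative stand-in: drop the zone suffix, 'us-central1-a' -> 'us-central1'.
  return zone.rsplit('-', 1)[0]

cluster, zone, index = 'demo', 'us-central1-a', 0
old_name = f'{cluster}-{zone_to_region(zone)}-sub-{index}'
# Hypothetical zonal cluster: the resolved location is the zone itself.
location = 'us-central1-a'  # stand-in for get_cluster_location(project, cluster, zone)
new_name = f'{cluster}-{location}-sub-{index}'
print(old_name, new_name)  # demo-us-central1-sub-0 demo-us-central1-a-sub-0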
xpk/core/nodepool.py CHANGED
@@ -16,6 +16,7 @@ limitations under the License.

  from typing import List
  from ..utils.console import get_user_input, xpk_print
+ from ..utils.topology import get_topology_product, is_topology_valid
  from .capacity import (
  AUTOPROVISIONING_CONFIG_VALUE,
  H100_MEGA_DEVICE_TYPE,
@@ -25,7 +26,7 @@ from .capacity import (
  print_reservations,
  )
  from .commands import run_command_for_value, run_commands
- from .gcloud_context import GkeServerConfig, zone_to_region
+ from .gcloud_context import GkeServerConfig, get_cluster_location, zone_to_region
  from .resources import (
  CLUSTER_CONFIGMAP_YAML,
  CLUSTER_RESOURCES_CONFIGMAP,
@@ -33,8 +34,7 @@ from .resources import (
  create_or_update_cluster_configmap,
  )
  from .system_characteristics import AcceleratorType
- from functools import reduce
- from operator import mul
+

  CLOUD_PLATFORM_AUTH_SCOPE_URL = (
  '"https://www.googleapis.com/auth/cloud-platform"'
@@ -147,7 +147,7 @@ def run_gke_node_pool_create_command(
  command = (
  'gcloud beta container node-pools delete'
  f' {node_pool_name} --cluster={args.cluster}'
- f' --zone={zone_to_region(args.zone)}'
+ f' --zone={get_cluster_location(args.project, args.cluster, args.zone)}'
  f' --project={args.project} --quiet'
  )
  task = f'NodepoolDelete-{node_pool_name}'
@@ -173,9 +173,7 @@ def run_gke_node_pool_create_command(
  ):
  command = (
  'gcloud container node-pools update'
- f' {node_pool_name} --cluster={args.cluster}'
- f' --zone={zone_to_region(args.zone)}'
- f' --project={args.project} --quiet'
+ f' {node_pool_name} --cluster={args.cluster} --location={get_cluster_location(args.project, args.cluster, args.zone)} --project={args.project} --quiet'
  ' --workload-metadata=GKE_METADATA'
  )
  task = (
@@ -212,7 +210,6 @@ def run_gke_node_pool_create_command(
  delete_commands,
  'Delete Nodepools',
  delete_task_names,
- dry_run=args.dry_run,
  )
  if max_return_code != 0:
  xpk_print(f'Delete Nodepools returned ERROR {max_return_code}')
@@ -240,7 +237,6 @@ def run_gke_node_pool_create_command(
  update_WI_commands,
  'Enable Workload Identity on existing Nodepools',
  update_WI_task_names,
- dry_run=args.dry_run,
  )
  if max_return_code != 0:
  xpk_print(
@@ -265,12 +261,16 @@ def run_gke_node_pool_create_command(
  )
  configmap_yml = {}
  configmap_yml[resources_configmap_name] = resources_yml
- return_code = create_or_update_cluster_configmap(
- configmap_yml, args.dry_run
- )
+ return_code = create_or_update_cluster_configmap(configmap_yml)
  if return_code != 0:
  return 1

+ placement_args = ''
+ if system.requires_workload_policy and is_topology_valid(system.topology):
+ placement_policy = f'{args.cluster}-placement-policy'
+ ensure_resource_policy_exists(placement_policy, args, system.topology)
+ placement_args = f' --placement-policy={placement_policy}'
+
  create_commands = []
  create_task_names = []
  for node_pool_name in desired_node_pool_names:
@@ -279,19 +279,18 @@ def run_gke_node_pool_create_command(
  command = (
  'gcloud beta container node-pools create'
  f' {node_pool_name}'
- f' --region={zone_to_region(args.zone)}'
+ f' --location={get_cluster_location(args.project, args.cluster, args.zone)}'
  f' --cluster={args.cluster}'
  f' --project={args.project} --node-locations={args.zone}'
  f' --machine-type={system.gce_machine_type}'
  f' --host-maintenance-interval={args.host_maintenance_interval}'
  f' {capacity_args}'
+ f'{placement_args}'
  ' --enable-gvnic'
  )
  if system.accelerator_type == AcceleratorType['TPU']:
  command += f' --node-version={gke_node_pool_version}'
- topology_product = reduce(
- mul, (int(x) for x in system.topology.split('x')), 1
- )
+ topology_product = get_topology_product(system.topology)
  if capacity_type == CapacityType.FLEX_START:
  command += ' --num-nodes=0'
  elif topology_product > 1:
@@ -301,11 +300,18 @@ def run_gke_node_pool_create_command(
  )

  if topology_product > 1:
- command += ' --placement-type=COMPACT --max-pods-per-node 15'
- command += f' --tpu-topology={system.topology}'
+ # --placement-type=COMPACT enables group placement policy which
+ # is mutually exclusive with workload policy, --tpu-topology should
+ # also not be passed when workload policy is used
+ if not system.requires_workload_policy:
+ command += ' --placement-type=COMPACT'
+ command += f' --tpu-topology={system.topology}'
+ command += ' --max-pods-per-node 15'
  command += f' {args.custom_tpu_nodepool_arguments}'
  elif system.accelerator_type == AcceleratorType['GPU']:
- subnet_prefix = f'{args.cluster}-{zone_to_region(args.zone)}'
+ subnet_prefix = (
+ f'{args.cluster}-{get_cluster_location(args.project, args.cluster, args.zone)}'
+ )
  if capacity_type == CapacityType.FLEX_START:
  command += ' --num-nodes=0'
  else:
@@ -348,7 +354,7 @@ def run_gke_node_pool_create_command(
  continue
  command = (
  'gcloud beta container node-pools create'
- f' {node_pool_name} --node-version={gke_node_pool_version} --cluster={args.cluster} --project={args.project} --node-locations={args.zone} --region={zone_to_region(args.zone)} --num-nodes=1'
+ f' {node_pool_name} --node-version={gke_node_pool_version} --cluster={args.cluster} --project={args.project} --node-locations={args.zone} --location={get_cluster_location(args.project, args.cluster, args.zone)} --num-nodes=1'
  f' --machine-type={args.pathways_gce_machine_type} --scopes=storage-full,gke-default,{CLOUD_PLATFORM_AUTH_SCOPE_URL} --enable-autoscaling'
  ' --min-nodes=1 --max-nodes=20'
  )
@@ -362,7 +368,6 @@ def run_gke_node_pool_create_command(
  create_commands,
  'Create Nodepools',
  create_task_names,
- dry_run=args.dry_run,
  )
  if max_return_code != 0:
  xpk_print(f'Create Nodepools returned ERROR {max_return_code}')
@@ -432,11 +437,12 @@ def get_all_nodepools_programmatic(args) -> tuple[list[str], int]:
  command = (
  'gcloud beta container node-pools list'
  ' --cluster'
- f' {args.cluster} --project={args.project} --region={zone_to_region(args.zone)}'
+ f' {args.cluster} --project={args.project} '
+ f'--location={get_cluster_location(args.project, args.cluster, args.zone)}'
  ' --format="csv[no-heading](name)"'
  )
  return_code, raw_nodepool_output = run_command_for_value(
- command, 'Get All Node Pools', args
+ command, 'Get All Node Pools'
  )
  if return_code != 0:
  xpk_print(f'Get All Node Pools returned ERROR {return_code}')
@@ -460,10 +466,10 @@ def get_nodepool_zone(args, nodepool_name) -> tuple[int, str | None]:
  command = (
  f'gcloud beta container node-pools describe {nodepool_name}'
  f' --cluster {args.cluster} --project={args.project}'
- f' --region={zone_to_region(args.zone)} --format="value(locations)"'
+ f' --location={get_cluster_location(args.project, args.cluster, args.zone)} --format="value(locations)"'
  )
  return_code, nodepool_zone = run_command_for_value(
- command, 'Get Node Pool Zone', args, dry_run_return_val=args.zone
+ command, 'Get Node Pool Zone', dry_run_return_val=args.zone
  )
  if return_code != 0:
  xpk_print(f'Get Node Pool Zone returned ERROR {return_code}')
@@ -490,13 +496,13 @@ def get_gke_node_pool_version(
  # By default use the current gke master version for creating node pools.
  command_description = 'Determine current gke master version'
  command = (
- f'gcloud beta container clusters describe {args.cluster}'
- f' --region {zone_to_region(args.zone)} --project {args.project}'
- ' --format="value(currentMasterVersion)"'
+ f'gcloud beta container clusters describe {args.cluster} --location'
+ f' {get_cluster_location(args.project, args.cluster, args.zone)} --project'
+ f' {args.project} --format="value(currentMasterVersion)"'
  )

  return_code, current_gke_master_version = run_command_for_value(
- command, command_description, args
+ command, command_description
  )
  if return_code != 0:
  xpk_print(
@@ -540,52 +546,6 @@ def get_gke_node_pool_version(
  return 0, node_pool_gke_version


- def upgrade_gke_nodepools_version(args, default_rapid_gke_version) -> int:
- """Upgrade nodepools in the cluster to default rapid gke version. Recreates the nodes.
-
- Args:
- args: user provided arguments for running the command.
- default_rapid_gke_version: Rapid default version for the upgrade.
-
- Returns:
- 0 if successful and 1 otherwise.
- """
- existing_node_pool_names, return_code = get_all_nodepools_programmatic(args)
- if return_code != 0:
- xpk_print('Listing all node pools failed!')
- return return_code
-
- # Batch execution to upgrade node pools simultaneously
- commands = []
- task_names = []
- for node_pool_name in existing_node_pool_names:
- commands.append(
- 'gcloud container clusters upgrade'
- f' {args.cluster} --project={args.project}'
- f' --region={zone_to_region(args.zone)}'
- f' --cluster-version={default_rapid_gke_version}'
- f' --node-pool={node_pool_name}'
- ' --quiet'
- )
- task_names.append(f'Upgrading node pool {node_pool_name}.')
-
- for i, command in enumerate(commands):
- xpk_print(f'To complete {task_names[i]} we are executing {command}')
- max_return_code = run_commands(
- commands,
- 'Update GKE node pools to default RAPID GKE version',
- task_names,
- dry_run=args.dry_run,
- )
- if max_return_code != 0:
- xpk_print(
- 'GKE node pools update to default RAPID GKE version returned ERROR:'
- f' {max_return_code}'
- )
- return int(max_return_code)
- return 0
-
-
  def get_nodepool_workload_metadata_mode(
  args, nodepool_name
  ) -> tuple[int, str | None]:
@@ -601,10 +561,10 @@ def get_nodepool_workload_metadata_mode(
  command = (
  f'gcloud beta container node-pools describe {nodepool_name}'
  f' --cluster {args.cluster} --project={args.project}'
- f' --region={zone_to_region(args.zone)} --format="value(config.workloadMetadataConfig.mode)"'
+ f' --location={get_cluster_location(args.project, args.cluster, args.zone)} --format="value(config.workloadMetadataConfig.mode)"'
  )
  return_code, nodepool_WI_mode = run_command_for_value(
- command, 'Get Node Pool Workload Identity Metadata Mode', args
+ command, 'Get Node Pool Workload Identity Metadata Mode'
  )
  if return_code != 0:
  xpk_print(
@@ -632,3 +592,32 @@ def get_desired_node_pool_names(
  result.add(f'{cluster_name}-np-{i}')
  i += 1
  return list(result)
+
+
+ def ensure_resource_policy_exists(
+ resource_policy_name: str, args, topology: str
+ ) -> None:
+ return_code, _ = run_command_for_value(
+ (
+ 'gcloud compute resource-policies describe'
+ f' {resource_policy_name} '
+ f'--project={args.project} '
+ f'--region={zone_to_region(args.zone)}'
+ ),
+ 'Retrieve resource policy',
+ )
+
+ if return_code == 0:
+ return
+
+ return_code, _ = run_command_for_value(
+ (
+ 'gcloud compute resource-policies create workload-policy'
+ f' {resource_policy_name} --project={args.project} --region={zone_to_region(args.zone)} --type=HIGH_THROUGHPUT'
+ f' --accelerator-topology={topology}'
+ ),
+ 'Create resource policy',
+ )
+
+ if return_code != 0:
+ raise RuntimeError('Unable to create resource policy')
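
The inline reduce(mul, ...) that computed a slice's chip count from a topology string like '2x2x1' now lives in xpk/utils/topology.py as get_topology_product (imported at the top of this file), and the new placement_args block only attaches a workload policy when the system requires one and the topology is valid. A minimal sketch equivalent to the removed inline code; the real helpers, including is_topology_valid, may add validation beyond this:

from functools import reduce
from operator import mul

def get_topology_product(topology: str) -> int:
  # Multiply the 'x'-separated dimensions: '2x2x1' -> 4, '4x4x4' -> 64.
  return reduce(mul, (int(x) for x in topology.split('x')), 1)

assert get_topology_product('2x2x1') == 4
assert get_topology_product('4x4x4') == 64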
xpk/core/nodepool_test.py CHANGED
@@ -14,7 +14,13 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- from xpk.core.nodepool import get_desired_node_pool_names
+ import pytest
+ from xpk.core.nodepool import (
+ ensure_resource_policy_exists,
+ get_desired_node_pool_names,
+ run_gke_node_pool_create_command,
+ )
+ from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics

  CLUSTER_NAME = "running-cucumber"

@@ -80,3 +86,194 @@ def test_compute_desired_node_pool_names_with_unknown_node_pools():

  expected_result = [node_pool_name(0), node_pool_name(3)]
  assert set(result) == set(expected_result)
+
+
+ def test_ensure_resource_policy_exists_with_existing_policy_retrieves_existing_policy(
+ mocker,
+ ):
+ args = mocker.Mock(project="test-project", zone="us-central1-a")
+ mocker.patch("xpk.core.nodepool.get_cluster_location", return_value=args.zone)
+ mock = mocker.patch(
+ "xpk.core.nodepool.run_command_for_value", return_value=(0, "")
+ )
+ ensure_resource_policy_exists("resource-policy", args, "2x2x1")
+ mock.assert_called_once()
+
+
+ def test_ensure_resource_policy_exists_without_existing_policy_creates_policy(
+ mocker,
+ ):
+ args = mocker.Mock(project="test-project", zone="us-central1-a")
+ mocker.patch("xpk.core.nodepool.get_cluster_location", return_value=args.zone)
+ mock = mocker.patch(
+ "xpk.core.nodepool.run_command_for_value", side_effect=[(1, ""), (0, "")]
+ )
+ ensure_resource_policy_exists("resource-policy", args, "2x2x1")
+ assert mock.call_count == 2
+ assert mock.call_args_list[0].args[1] == "Retrieve resource policy"
+
+
+ def test_ensure_resource_policy_exits_without_existing_policy_throws_when_creation_fails(
+ mocker,
+ ):
+ with pytest.raises(RuntimeError):
+ args = mocker.Mock(project="test-project", zone="us-central1-a")
+ mocker.patch(
+ "xpk.core.nodepool.get_cluster_location", return_value=args.zone
+ )
+ mocker.patch(
+ "xpk.core.nodepool.run_command_for_value",
+ side_effect=[(1, ""), (1, "")],
+ )
+ ensure_resource_policy_exists("resource-policy", args, "2x2x1")
+
+
+ @pytest.fixture
+ def mock_nodepool_dependencies(mocker):
+ """Mocks dependencies for run_gke_node_pool_create_command."""
+ mocker.patch(
+ "xpk.core.nodepool.get_all_nodepools_programmatic", return_value=([], 0)
+ )
+ mocker.patch(
+ "xpk.core.nodepool.get_capacity_type", return_value=("on-demand", 0)
+ )
+ mocker.patch(
+ "xpk.core.nodepool.get_capacity_arguments_from_capacity_type",
+ return_value=("--on-demand", 0),
+ )
+ mocker.patch(
+ "xpk.core.nodepool.get_cluster_location", return_value="us-central1"
+ )
+ mocker.patch("xpk.core.nodepool.run_commands", return_value=0)
+ mocker.patch("xpk.core.nodepool.get_user_input", return_value=True)
+ mock_is_topology_valid = mocker.patch("xpk.core.nodepool.is_topology_valid")
+ mock_ensure_resource_policy = mocker.patch(
+ "xpk.core.nodepool.ensure_resource_policy_exists"
+ )
+ return mock_is_topology_valid, mock_ensure_resource_policy
+
+
+ def test_placement_policy_created_for_gpu_with_valid_topology(
+ mocker, mock_nodepool_dependencies
+ ):
+ """Tests that placement policy is created for GPUs with a valid topology."""
+ mock_is_topology_valid, mock_ensure_resource_policy = (
+ mock_nodepool_dependencies
+ )
+ mock_is_topology_valid.return_value = True
+ args = mocker.Mock(
+ tpu_type=None,
+ device_type="h100-80gb-8",
+ cluster="test-cluster",
+ project="test-project",
+ zone="us-central1-a",
+ )
+ system = SystemCharacteristics(
+ topology="N/A",
+ vms_per_slice=1,
+ gke_accelerator="nvidia-h100-80gb",
+ gce_machine_type="a3-highgpu-8g",
+ chips_per_vm=8,
+ accelerator_type=AcceleratorType["GPU"],
+ device_type="h100-80gb-8",
+ supports_sub_slicing=False,
+ )
+
+ run_gke_node_pool_create_command(args, system, "1.2.3")
+
+ mock_ensure_resource_policy.assert_called_once()
+
+
+ def test_placement_policy_not_created_for_gpu_with_invalid_topology(
+ mocker, mock_nodepool_dependencies
+ ):
+ """Tests that placement policy is not created for GPUs with an invalid topology."""
+ mock_is_topology_valid, mock_ensure_resource_policy = (
+ mock_nodepool_dependencies
+ )
+ mock_is_topology_valid.return_value = False
+ args = mocker.Mock(
+ tpu_type=None,
+ device_type="h100-80gb-8",
+ cluster="test-cluster",
+ zone="us-central1-a",
+ )
+ system = SystemCharacteristics(
+ topology="N/A",
+ vms_per_slice=1,
+ gke_accelerator="nvidia-h100-80gb",
+ gce_machine_type="a3-highgpu-8g",
+ chips_per_vm=8,
+ accelerator_type=AcceleratorType["GPU"],
+ device_type="h100-80gb-8",
+ supports_sub_slicing=False,
+ )
+
+ run_gke_node_pool_create_command(args, system, "1.2.3")
+
+ mock_ensure_resource_policy.assert_not_called()
+
+
+ def test_placement_policy_created_for_tpu7x_with_valid_topology(
+ mocker, mock_nodepool_dependencies
+ ):
+ """Tests that placement policy is created for tpu7x with a valid topology."""
+ mock_is_topology_valid, mock_ensure_resource_policy = (
+ mock_nodepool_dependencies
+ )
+ mock_is_topology_valid.return_value = True
+ args = mocker.Mock(
+ tpu_type="tpu7x-8",
+ device_type=None,
+ num_slices=1,
+ cluster="test-cluster",
+ project="test-project",
+ zone="us-central1-a",
+ )
+ system = SystemCharacteristics(
+ topology="2x2x1",
+ vms_per_slice=1,
+ gke_accelerator="tpu7x",
+ gce_machine_type="tpu7x-standard-4t",
+ chips_per_vm=4,
+ accelerator_type=AcceleratorType["TPU"],
+ device_type="tpu7x-8",
+ requires_workload_policy=True,
+ supports_sub_slicing=False,
+ )
+
+ run_gke_node_pool_create_command(args, system, "1.2.3")
+
+ mock_ensure_resource_policy.assert_called_once()
+
+
+ def test_placement_policy_not_created_for_non7x_tpu(
+ mocker, mock_nodepool_dependencies
+ ):
+ """Tests that placement policy is not created for non-tpu7x TPUs."""
+ mock_is_topology_valid, mock_ensure_resource_policy = (
+ mock_nodepool_dependencies
+ )
+ mock_is_topology_valid.return_value = True
+ args = mocker.Mock(
+ tpu_type="v6e",
+ device_type=None,
+ num_slices=1,
+ cluster="test-cluster",
+ project="test-project",
+ zone="us-central1-a",
+ )
+ system = SystemCharacteristics(
+ topology="2x2",
+ vms_per_slice=1,
+ gke_accelerator="v6e",
+ gce_machine_type="tpu-v6e-slice",
+ chips_per_vm=4,
+ accelerator_type=AcceleratorType["TPU"],
+ device_type="v6e-4",
+ supports_sub_slicing=True,
+ )
+
+ run_gke_node_pool_create_command(args, system, "1.2.3")
+
+ mock_ensure_resource_policy.assert_not_called()
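
These tests rely on pytest-mock's mocker fixture, which patches symbols where xpk.core.nodepool imports them for the duration of a single test. A minimal standalone example of the same pattern, assuming pytest and pytest-mock are installed (run with `pytest -q`):

import os

def test_mocker_patch_example(mocker):
  # mocker.patch works like unittest.mock.patch but is undone automatically
  # when the test finishes.
  patched = mocker.patch("os.getcwd", return_value="/tmp/fake")
  assert os.getcwd() == "/tmp/fake"
  patched.assert_called_once()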
xpk/core/pathways.py CHANGED
@@ -16,7 +16,7 @@ limitations under the License.

  from ..core.commands import run_command_for_value, run_command_with_updates, run_commands
  from ..core.docker_container import get_user_workload_container
- from ..core.gcloud_context import zone_to_region
+ from ..core.gcloud_context import get_cluster_location
  from ..core.nodepool import get_all_nodepools_programmatic
  from ..utils.console import xpk_exit, xpk_print
  from ..utils.execution_context import is_dry_run
@@ -116,7 +116,7 @@ def check_if_pathways_job_is_installed(args) -> bool:
  ' custom-columns=NAME:.metadata.name'
  )
  task = f'Check if PathwaysJob is installed on {args.cluster}'
- return_code, return_msg = run_command_for_value(command, task, args)
+ return_code, return_msg = run_command_for_value(command, task)
  # return_msg contains the name of the controller pod, if found.
  xpk_print('check_if_pathways_job_is_installed', return_code, return_msg)

@@ -138,7 +138,7 @@ def get_pathways_unified_query_link(args) -> str:
  query_params = (
  'resource.type%3D"k8s_container"%0A'
  f'resource.labels.project_id%3D"{args.project}"%0A'
- f'resource.labels.location%3D"{zone_to_region(args.zone)}"%0A'
+ f'resource.labels.location%3D"{get_cluster_location(args.project, args.cluster, args.zone)}"%0A'
  f'resource.labels.cluster_name%3D"{args.cluster}"%0A'
  f'resource.labels.pod_name:"{args.workload}-"%0A'
  'severity>%3DDEFAULT'
@@ -323,10 +323,10 @@ def try_to_delete_pathwaysjob_first(args, workloads) -> bool:

  # Not batching deletion for single workload
  if len(workloads) == 1:
- return_code = run_command_with_updates(commands[0], 'Delete Workload', args)
+ return_code = run_command_with_updates(commands[0], 'Delete Workload')
  else:
  return_code = run_commands(
- commands, 'Delete Workload', task_names, batch=100, dry_run=args.dry_run
+ commands, 'Delete Workload', task_names, batch=100
  )

  if return_code != 0:
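
As in the other files in this diff, the command helpers here drop their trailing args parameter and run_commands drops dry_run=args.dry_run; judging by the is_dry_run import above, dry-run handling appears to move into the shared execution context rather than being threaded through every call. A before/after sketch of the call shape, using a stub helper rather than the real xpk implementation:

def run_command_for_value(command: str, task: str) -> tuple[int, str]:
  # Stand-in only: the real helper shells out and reports progress.
  print(f'[{task}] {command}')
  return 0, ''

# 0.13.0 call shape: run_command_for_value(command, task, args)
# 0.14.0 call shape:
return_code, msg = run_command_for_value(
    'kubectl get pods -n kube-system', 'Check if PathwaysJob is installed'
)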