xpk 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. xpk/__init__.py +15 -0
  2. xpk/commands/__init__.py +15 -0
  3. xpk/commands/batch.py +109 -0
  4. xpk/commands/cluster.py +784 -0
  5. xpk/commands/cluster_gcluster.py +185 -0
  6. xpk/commands/info.py +245 -0
  7. xpk/commands/inspector.py +363 -0
  8. xpk/commands/job.py +197 -0
  9. xpk/commands/kind.py +253 -0
  10. xpk/commands/shell.py +120 -0
  11. xpk/commands/version.py +39 -0
  12. xpk/commands/workload.py +692 -0
  13. xpk/core/__init__.py +15 -0
  14. xpk/core/blueprint/__init__.py +15 -0
  15. xpk/core/blueprint/blueprint_definitions.py +61 -0
  16. xpk/core/blueprint/blueprint_generator.py +652 -0
  17. xpk/core/cluster_private.py +197 -0
  18. xpk/core/commands.py +352 -0
  19. xpk/core/core.py +2824 -0
  20. xpk/core/docker_manager.py +308 -0
  21. xpk/core/gcluster_manager.py +158 -0
  22. xpk/core/kjob.py +205 -0
  23. xpk/core/kueue.py +352 -0
  24. xpk/core/nap.py +349 -0
  25. xpk/core/pathways.py +298 -0
  26. xpk/core/ray.py +222 -0
  27. xpk/core/system_characteristics.py +1395 -0
  28. xpk/core/workload.py +133 -0
  29. xpk/core/workload_decorators/__init__.py +15 -0
  30. xpk/core/workload_decorators/rdma_decorator.py +109 -0
  31. xpk/core/workload_decorators/tcpxo_decorator.py +157 -0
  32. xpk/main.py +73 -0
  33. xpk/parser/__init__.py +15 -0
  34. xpk/parser/batch.py +184 -0
  35. xpk/parser/cluster.py +621 -0
  36. xpk/parser/common.py +71 -0
  37. xpk/parser/core.py +109 -0
  38. xpk/parser/info.py +63 -0
  39. xpk/parser/inspector.py +65 -0
  40. xpk/parser/job.py +126 -0
  41. xpk/parser/kind.py +94 -0
  42. xpk/parser/shell.py +50 -0
  43. xpk/parser/validators.py +39 -0
  44. xpk/parser/version.py +23 -0
  45. xpk/parser/workload.py +684 -0
  46. xpk/utils/__init__.py +15 -0
  47. xpk/utils/console.py +55 -0
  48. xpk/utils/file.py +82 -0
  49. xpk/utils/network.py +168 -0
  50. xpk/utils/objects.py +85 -0
  51. xpk/utils/yaml.py +30 -0
  52. {xpk-0.5.0.dist-info → xpk-0.6.0.dist-info}/METADATA +301 -28
  53. xpk-0.6.0.dist-info/RECORD +57 -0
  54. {xpk-0.5.0.dist-info → xpk-0.6.0.dist-info}/WHEEL +1 -1
  55. xpk-0.6.0.dist-info/entry_points.txt +2 -0
  56. xpk-0.5.0.dist-info/RECORD +0 -7
  57. xpk-0.5.0.dist-info/entry_points.txt +0 -2
  58. xpk.py +0 -7282
  59. {xpk-0.5.0.dist-info → xpk-0.6.0.dist-info}/LICENSE +0 -0
  60. {xpk-0.5.0.dist-info → xpk-0.6.0.dist-info}/top_level.txt +0 -0
xpk/core/core.py ADDED
@@ -0,0 +1,2824 @@
1
+ """
2
+ Copyright 2023 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ r"""xpk (Accelerated Processing Kit).
18
+
19
+ Next Steps:
20
+ - Cluster describe is broken by Cacheimage since that counts as a workload.
21
+ - Cluster describe: count by jobset.
22
+ - If any instance goes down, bring down the whole job.
23
+ - How to more gracefully handle job failures, distinguishing between software
24
+ and infra?
25
+ - Look into --docker-name and --docker-image.
26
+ Shouldn't one string be adequate to express what we want?
27
+ - Apply learnings from about private, region, coredns, etc:
28
+ - Enable special preheater
29
+ - Make Argparse logic this a function?
30
+ - Obvious logic that starts in main instead of here in code but args will
31
+ not be a universal argument.
32
+ """
33
+
34
+ import datetime
35
+ import enum
36
+ import os
37
+ import random
38
+ import re
39
+ import string
40
+ import subprocess
41
+ import sys
42
+ from dataclasses import dataclass
43
+
44
+ from ..utils.file import write_tmp_file
45
+ from ..utils.console import get_user_input, xpk_exit, xpk_print
46
+ from .commands import (
47
+ run_command_for_value,
48
+ run_command_with_updates,
49
+ run_command_with_updates_retry,
50
+ run_commands,
51
+ )
52
+ from .system_characteristics import (
53
+ AcceleratorType,
54
+ AcceleratorTypeToAcceleratorCharacteristics,
55
+ SystemCharacteristics,
56
+ )
57
+
58
################### Internally used constants ##############

default_docker_image = 'python:3.10'
default_script_dir = os.getcwd()
# This is the version for XPK PyPI package
__version__ = '0.6.0'
xpk_current_version = __version__

# Known GPU device-type identifiers (A3 / A3+ / H200 machine shapes).
h100_device_type = 'h100-80gb-8'
h100_mega_device_type = 'h100-mega-80gb-8'
h200_device_type = 'h200-141gb-8'

JOBSET_VERSION = 'v0.7.2'

CAPACITY_TYPE_CONFIG_KEY = 'capacity_type'
RESERVATION_CONFIG_KEY = 'reservation_id'
_DEFAULT_POOL_NAME = 'default-pool'
CLUSTER_RESOURCES_CONFIGMAP = 'resources-configmap'
CLUSTER_METADATA_CONFIGMAP = 'metadata-configmap'
# Compare numeric version tuples rather than raw strings: a lexicographic
# string comparison misorders versions (e.g. '0.10.0' < '0.4.0'), so the
# flag would silently flip off on a future double-digit minor release.
VERTEX_TENSORBOARD_FEATURE_FLAG = tuple(
    int(part) for part in xpk_current_version.split('.')
) >= (0, 4, 0)
DEFAULT_VERTEX_TENSORBOARD_NAME = 'tb-instance'
AUTOPROVISIONING_CONFIG_VALUE = 'AUTOPROVISION'
AUTOPROVISIONING_CONFIG_MINIMUM_KEY = 'minimum_chips'
AUTOPROVISIONING_CONFIG_MAXIMUM_KEY = 'maximum_chips'
CLOUD_PLATFORM_AUTH_SCOPE_URL = (
    '"https://www.googleapis.com/auth/cloud-platform"'
)
PLATFORM = 'linux/amd64'
86
+
87
+
88
class CapacityType(enum.Enum):
  """Provisioning model for cluster capacity.

  Values are the lowercase strings used in CLI flags and the cluster
  metadata ConfigMap (see CAPACITY_TYPE_CONFIG_KEY).
  """

  ON_DEMAND = 'on_demand'
  RESERVATION = 'reservation'
  SPOT = 'spot'
  UNKNOWN = 'unknown'
93
+
94
+
95
@dataclass
class AutoprovisioningConfig:
  """Settings for GKE node auto-provisioning.

  Written into the cluster resources ConfigMap by create_cluster_configmaps.
  """

  # Name of the generated autoprovisioning config file — presumably a path
  # consumed by gcloud; TODO(review): confirm against the writer of this field.
  config_filename: str
  # Minimum number of chips autoprovisioning may keep provisioned.
  minimum_chips: int
  # Maximum number of chips autoprovisioning may provision.
  maximum_chips: int
100
+
101
+
102
# cluster_configmap_yaml: template for a v1 ConfigMap. `data` must be a
# pre-indented block of `key: value` lines (the first line gets the template's
# two-space indent; continuation lines carry their own indent).
cluster_configmap_yaml = """kind: ConfigMap
apiVersion: v1
metadata:
  name: {name}
data:
  {data}
"""
109
+
110
# cluster_network_yaml: the config when creating the network for a3 cluster.
# Declares four secondary Networks (vpc1..vpc4) for GPU-GPU traffic, each
# bound to a GKENetworkParamSet that references the {cluster_name}-net-N VPC
# and {cluster_name}-sub-N subnet created by set_up_cluster_network_for_gpu.
cluster_network_yaml = """
apiVersion: networking.gke.io/v1
kind: Network
metadata:
  name: vpc1
spec:
  parametersRef:
    group: networking.gke.io
    kind: GKENetworkParamSet
    name: vpc1
  type: Device
---
apiVersion: networking.gke.io/v1
kind: Network
metadata:
  name: vpc2
spec:
  parametersRef:
    group: networking.gke.io
    kind: GKENetworkParamSet
    name: vpc2
  type: Device
---
apiVersion: networking.gke.io/v1
kind: Network
metadata:
  name: vpc3
spec:
  parametersRef:
    group: networking.gke.io
    kind: GKENetworkParamSet
    name: vpc3
  type: Device
---
apiVersion: networking.gke.io/v1
kind: Network
metadata:
  name: vpc4
spec:
  parametersRef:
    group: networking.gke.io
    kind: GKENetworkParamSet
    name: vpc4
  type: Device
---
apiVersion: networking.gke.io/v1
kind: GKENetworkParamSet
metadata:
  name: vpc1
spec:
  vpc: {cluster_name}-net-1
  vpcSubnet: {cluster_name}-sub-1
  deviceMode: NetDevice
---
apiVersion: networking.gke.io/v1
kind: GKENetworkParamSet
metadata:
  name: vpc2
spec:
  vpc: {cluster_name}-net-2
  vpcSubnet: {cluster_name}-sub-2
  deviceMode: NetDevice
---
apiVersion: networking.gke.io/v1
kind: GKENetworkParamSet
metadata:
  name: vpc3
spec:
  vpc: {cluster_name}-net-3
  vpcSubnet: {cluster_name}-sub-3
  deviceMode: NetDevice
---
apiVersion: networking.gke.io/v1
kind: GKENetworkParamSet
metadata:
  name: vpc4
spec:
  vpc: {cluster_name}-net-4
  vpcSubnet: {cluster_name}-sub-4
  deviceMode: NetDevice
"""
192
+
193
+
194
def add_zone_and_project(args):
  """Fill in zone and project on args from gcloud config when not provided.

  Args:
    args: user provided arguments for running the command.
  """
  # Fall back to the active gcloud configuration for anything missing.
  args.project = args.project or get_project()
  args.zone = args.zone or get_zone()
  xpk_print(f'Working on {args.project} and {args.zone}')
205
+
206
+
207
def parse_env_config(args, tensorboard_config, system: SystemCharacteristics):
  """Parses the environment configurations to the jobset config.

  Merges variables from --env-file and --env (the CLI entries are applied
  second, so they win on duplicate names), then rewrites args.env in place
  as the YAML `- name/value` snippet consumed by the jobset template.

  Args:
    args: user provided arguments for running the command.
    tensorboard_config: configuration of Vertex Tensorboard.
    system: system characteristics.
  """
  env = {}

  # Matches `VAR` or `VAR=value` per line; group(2) is None for a bare name.
  env_pat = re.compile(r'(^[a-zA-Z_][a-zA-Z0-9_]*?)(?:=(.*))?$', re.M)
  if args.env_file:
    print('Setting container environment from', args.env_file)
    with open(file=args.env_file, mode='r', encoding='utf-8') as f:
      for match in env_pat.finditer(f.read()):
        variable = match.group(1)
        if match.group(2) is not None:
          env[variable] = match.group(2)
        else:
          # Bare name: inherit the value from the caller's own environment.
          assert variable in os.environ, (
              f'Variable {variable} is not set in the current '
              'environment, a value must be specified.'
          )
          env[variable] = os.environ[variable]
  if args.env:
    for var in args.env:
      match = env_pat.match(var)
      # CLI form must always carry an explicit value.
      assert match and match.group(2) is not None, (
          'Invalid environment variable, format must be '
          f'`--env VARIABLE=value`: {var}'
      )
      variable = match.group(1)
      env[variable] = match.group(2)

  if not args.use_pathways:
    if args.debug_dump_gcs:
      if 'XLA_FLAGS' in env:
        raise ValueError(
            'Conflict: XLA_FLAGS defined in both --debug_dump_gcs '
            'and environment file. Please choose one way to define '
            'XLA_FLAGS.'
        )
      env['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump/'

    if tensorboard_config:
      # NOTE(review): the bool is later rendered as the string "True" by the
      # env_format template below.
      env['UPLOAD_DATA_TO_TENSORBOARD'] = True
      for key, value in tensorboard_config.items():
        env[key.upper()] = value

  if system.accelerator_type == AcceleratorType['GPU']:
    # For GPUs, it has two more spaces ahead of name and value respectively
    # NOTE(review): the two branches differ only in leading whitespace to
    # match the nesting depth of the GPU vs TPU jobset templates — confirm
    # against the templates if those ever change.
    env_format = '''
                  - name: {key}
                    value: "{value}"'''
  else:
    env_format = '''
                - name: {key}
                  value: "{value}"'''

  args.env = ''.join(env_format.format(key=k, value=v) for k, v in env.items())
267
+
268
+
269
def get_project():
  """Get GCE project from `gcloud config get project`.

  Exits the process with an error message when no project is configured.

  Returns:
    The project name.
  """
  result = subprocess.run(
      ['gcloud', 'config', 'get', 'project'], check=True, capture_output=True
  )
  lines = result.stdout.decode().strip().split('\n')
  # The project name lives on the last line of the output.
  if not lines or not lines[-1]:
    sys.exit(
        'You must specify the project in the project flag or set it with'
        " 'gcloud config set project <project>'"
    )
  return lines[-1]
287
+
288
+
289
def get_zone():
  """Get GCE zone from `gcloud config get compute/zone`.

  Exits the process with an error message when no zone is configured.

  Returns:
    The zone name.
  """
  result = subprocess.run(
      ['gcloud', 'config', 'get', 'compute/zone'],
      check=True,
      capture_output=True,
  )
  lines = result.stdout.decode().strip().split('\n')
  # The zone name lives on the last line of the output.
  if not lines or not lines[-1]:
    sys.exit(
        "You must specify the zone in the zone flag or set it with 'gcloud"
        " config set compute/zone <zone>'"
    )
  return lines[-1]
307
+
308
+
309
def zone_to_region(zone) -> str:
  """Helper function converts zone name to region name.

  A zone such as 'us-central1-a' maps to the region 'us-central1'
  (the first two dash-separated terms).

  Args:
    zone: zone name.

  Returns:
    The region name.
  """
  parts = zone.split('-')
  return f'{parts[0]}-{parts[1]}'
320
+
321
+
322
def get_total_chips_requested_from_args(
    args, system: SystemCharacteristics
) -> int:
  """Return the total chips requested based on user args.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics.

  Returns:
    num of chips for the current request.
  """
  chips_per_unit = system.vms_per_slice * system.chips_per_vm
  # GPU clusters are sized in nodes; TPU clusters are sized in slices.
  if system.accelerator_type == AcceleratorType['GPU']:
    return chips_per_unit * args.num_nodes
  return chips_per_unit * args.num_slices
340
+
341
+
342
def set_up_cluster_network_for_gpu(args, system: SystemCharacteristics) -> int:
  """Set up GKE Cluster networks, subnets and firewall rules for A3/A3+.

  Note: there are 4 NICs for GPU-GPU bw and 1 NIC for host in an A3 node,
  and there are 8 NICs for GPU-GPU bw and 1 NIC for host in an A3+ node.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics.

  Returns:
    0 if successful and 1 otherwise.
  """
  # 4 data networks for A3 (h100), 8 for A3+ — plus the host NIC's default.
  num_networks = 5 if system.device_type == h100_device_type else 9
  for index in range(1, num_networks):
    # Each network needs its VPC, subnet, and firewall rule; abort on the
    # first failure.
    for create_step in (
        create_cluster_network,
        create_cluster_subnet,
        create_cluster_firewall_rule,
    ):
      if create_step(args, index) != 0:
        return 1
  return 0
366
+
367
+
368
def create_cluster_network(args, index) -> int:
  """Create one GKE Cluster network.

  Skips creation when a network with the target name already exists.

  Args:
    args: user provided arguments for running the command.
    index: index number for the network to be created.

  Returns:
    0 if successful and 1 otherwise.
  """
  existing_network_names, return_code = get_all_networks_programmatic(args)
  if return_code > 0:
    xpk_print('Listing all networks failed!')
    return return_code

  network_name = f'{args.cluster}-net-{index}'
  if network_name in existing_network_names:
    xpk_print(f'Reusing existing network {network_name}')
    return 0

  command = (
      f'gcloud compute --project={args.project}'
      f' networks create {network_name}'
      ' --subnet-mode=custom --mtu=8244'
  )
  return_code = run_command_with_updates(
      command, 'Create Cluster Network', args, verbose=False
  )
  if return_code != 0:
    xpk_print(f'Create Cluster Network request returned ERROR {return_code}')
    return 1
  return 0
401
+
402
+
403
def create_cluster_subnet(args, index) -> int:
  """Create one GKE Cluster subnet.

  Skips creation when a subnet with the target name already exists.

  Args:
    args: user provided arguments for running the command.
    index: index number for the subnet to be created.

  Returns:
    0 if successful and 1 otherwise.
  """
  existing_subnet_names, return_code = get_all_subnets_programmatic(args)
  if return_code > 0:
    xpk_print('Listing all subnets failed!')
    return return_code

  subnet_name = f'{args.cluster}-{zone_to_region(args.zone)}-sub-{index}'
  if subnet_name in existing_subnet_names:
    xpk_print(f'Reusing existing subnet {subnet_name}')
    return 0

  # Each data network gets its own /24 keyed by index.
  command = (
      f'gcloud compute --project={args.project}'
      f' networks subnets create {subnet_name}'
      f' --network={args.cluster}-net-{index}'
      f' --region={zone_to_region(args.zone)} --range=192.168.{index}.0/24'
  )
  return_code = run_command_with_updates(
      command, 'Create Cluster Subnet', args, verbose=False
  )
  if return_code != 0:
    xpk_print(f'Create Cluster Subnet request returned ERROR {return_code}')
    return 1
  return 0
436
+
437
+
438
def delete_cluster_subnets(args) -> int:
  """Delete GKE Cluster subnets.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  existing_subnet_names, return_code = get_all_subnets_programmatic(args)
  if return_code > 0:
    xpk_print('Listing all subnets failed!')
    return return_code

  # NOTE(review): every subnet returned by get_all_subnets_programmatic is
  # deleted — presumably that helper filters to this cluster's subnets;
  # confirm before reusing this elsewhere.
  for subnet_name in existing_subnet_names:
    command = (
        f'gcloud compute networks subnets delete {subnet_name}'
        f' --region={zone_to_region(args.zone)} --project={args.project} --quiet'
    )
    return_code = run_command_with_updates(
        command, 'Delete Cluster Subnet', args, verbose=False
    )
    if return_code != 0:
      xpk_print(f'Delete Cluster Subnet request returned ERROR {return_code}')
      return 1
    xpk_print(f'Deleted existing subnet {subnet_name}')

  return 0
469
+
470
+
471
def create_cluster_firewall_rule(args, index) -> int:
  """Create one GKE Cluster firewall rule.

  Skips creation when a rule with the target name already exists.

  Args:
    args: user provided arguments for running the command.
    index: index number for the firewall rule to be created.

  Returns:
    0 if successful and 1 otherwise.
  """
  existing_firewall_rules_names, return_code = (
      get_all_firewall_rules_programmatic(args)
  )
  if return_code > 0:
    xpk_print('Listing all firewall rules failed!')
    return return_code

  firewall_rule_name = f'{args.cluster}-internal-{index}'
  if firewall_rule_name in existing_firewall_rules_names:
    xpk_print(f'Reusing existing firewall rule {firewall_rule_name}')
    return 0

  # Allow all TCP/UDP/ICMP between nodes on the cluster's private range.
  command = (
      f'gcloud compute --project={args.project} firewall-rules create'
      f' {firewall_rule_name} --network={args.cluster}-net-{index} --action=ALLOW'
      ' --rules=tcp:0-65535,udp:0-65535,icmp --source-ranges=192.168.0.0/16'
  )
  return_code = run_command_with_updates(
      command, 'Create Cluster Firewall Rule', args, verbose=False
  )
  if return_code != 0:
    xpk_print(
        f'Create Cluster Firewall Rule request returned ERROR {return_code}'
    )
    return 1
  return 0
506
+
507
+
508
def create_cluster_network_config(args) -> int:
  """Run the Create GKE Cluster Network Config request.

  Renders cluster_network_yaml for this cluster and applies it via kubectl.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  network_yaml = cluster_network_yaml.format(cluster_name=args.cluster)
  yaml_file = write_tmp_file(network_yaml)
  apply_command = f'kubectl apply -f {str(yaml_file.file.name)}'

  return_code = run_command_with_updates(
      apply_command, 'GKE Cluster Create Network Config', args
  )
  if return_code == 0:
    return 0
  xpk_print(
      f'GKE Cluster Create ConfigMap request returned ERROR {return_code}'
  )
  return 1
531
+
532
+
533
def print_reservations(args) -> int:
  """Print the reservations in the project.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  list_command = (
      f'gcloud beta compute reservations list --project={args.project}'
  )
  return_code = run_command_with_updates(
      list_command, 'Get all reservations in the project', args
  )
  if return_code == 0:
    return 0
  xpk_print(f'Get all reservations returned ERROR {return_code}')
  return 1
550
+
551
+
552
def verify_reservation_exists(args) -> int:
  """Verify the reservation exists.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  describe_command = (
      f'gcloud beta compute reservations describe {args.reservation}'
      f' --project={args.project} --zone={args.zone}'
  )
  return_code = run_command_with_updates(
      describe_command, 'Describe reservation', args
  )
  if return_code == 0:
    return 0
  xpk_print(f'Describe reservation returned ERROR {return_code}')
  xpk_print('Please confirm that your reservation name is correct.')
  return 1
571
+
572
+
573
def get_capacity_type(args) -> tuple[CapacityType, int]:
  """Determine the capacity type based on user arguments.

  Exactly one of --on-demand, --reservation, or --spot may be set; setting
  none yields CapacityType.UNKNOWN with success.

  Args:
    args: user provided arguments for running the command.

  Returns:
    Tuple with string with the system characteristics and
    int of 0 if successful and 1 otherwise.
  """
  capacity_type = CapacityType.UNKNOWN
  num_types = 0

  if args.on_demand:
    capacity_type = CapacityType.ON_DEMAND
    num_types += 1
  if args.reservation:
    # A reservation must actually exist before we accept it.
    return_code = verify_reservation_exists(args)
    if return_code > 0:
      return capacity_type, return_code
    capacity_type = CapacityType.RESERVATION
    num_types += 1
  if args.spot:
    capacity_type = CapacityType.SPOT
    num_types += 1

  if num_types == 0:
    return CapacityType.UNKNOWN, 0
  if num_types > 1:
    xpk_print(
        'ERROR: User specified more than one of the following arguments. Please'
        ' specify only one of `--reservation=$RESERVATION_NAME`, `--on-demand`'
        ' or `--spot`.'
    )
    return capacity_type, 1
  return capacity_type, 0
613
+
614
+
615
def get_capacity_arguments_from_capacity_type(
    args, capacity_type: CapacityType
) -> tuple[str, int]:
  """Determine the TPU Nodepool creation capacity arguments needed.

  Args:
    args: user provided arguments for running the command.
    capacity_type: The type of capacity the user configured.

  Returns:
    Tuple with string with the capacity argument to use and
    int of 0 if successful and 1 otherwise.
  """
  if capacity_type == CapacityType.ON_DEMAND:
    # On-demand is gcloud's default; no extra flag needed.
    return '', 0
  if capacity_type == CapacityType.SPOT:
    return '--spot', 0
  if capacity_type == CapacityType.RESERVATION:
    return (
        f'--reservation-affinity=specific --reservation={args.reservation}',
        0,
    )
  xpk_print(
      f'Unknown capacity type: {capacity_type}. Unable to determine'
      ' capacity args.'
  )
  return '', 1
647
+
648
+
649
def get_capacity_node_selectors_from_capacity_type(
    args, capacity_type: str
) -> tuple[str, int]:
  """Determine the node selectors for a workload to run on a specific capacity type.

  Args:
    args: user provided arguments for running the command.
    capacity_type: The type of capacity the user configured.

  Returns:
    Tuple with string with the node selectors to use and
    int of 0 if successful and 1 otherwise.
  """
  if capacity_type == CapacityType.ON_DEMAND.name:
    return '', 0
  if capacity_type == CapacityType.SPOT.name:
    return 'cloud.google.com/gke-spot="true"', 0
  if capacity_type == CapacityType.RESERVATION.name:
    # NOTE(review): this selector uses ': ' where the spot selector uses '=' —
    # presumably the consumer of this string expects that format; confirm
    # before normalizing.
    return f'cloud.google.com/reservation-name: {args.reservation}', 0
  xpk_print(
      f'Unknown capacity type: {capacity_type}. Unable to determine the'
      ' node selectors.'
  )
  return '', 1
679
+
680
+
681
def create_or_update_cluster_configmap(configmap_yml: dict) -> int:
  """Apply each provided ConfigMap to the cluster via kubectl.

  Args:
    configmap_yml: dict containing ConfigMap name and yml string.

  Returns:
    0 if successful, 1 otherwise.
  """
  commands = []
  task_names = []
  for configmap_name, yml_string in configmap_yml.items():
    # Each ConfigMap is written to its own temp file and applied separately.
    tmp = write_tmp_file(yml_string)
    commands.append(f'kubectl apply -f {str(tmp.file.name)}')
    task_names.append(f'ConfigMap CreateOrUpdate-{configmap_name}')

  return_code = run_commands(
      commands, 'GKE Cluster CreateOrUpdate ConfigMap(s)', task_names
  )
  if return_code != 0:
    xpk_print(
        'GKE Cluster Create/Update ConfigMap(s) request returned ERROR'
        f' {return_code}'
    )
    return 1
  return 0
708
+
709
+
710
def create_cluster_configmaps(
    args,
    system,
    tensorboard_config: dict,
    autoprovisioning_config: AutoprovisioningConfig | None,
) -> int:
  """Run the Create GKE Cluster ConfigMap request.

  Builds two ConfigMaps — one describing the resources available in the
  cluster, one holding cluster metadata (xpk version, tensorboard info,
  capacity type) — and applies them via create_or_update_cluster_configmap.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics.
    tensorboard_config: map that contains Vertex Tensorboard name, id and location
    autoprovisioning_config: Config used in autoprovisioning.
  Returns:
    0 if successful and 1 otherwise.
  """
  configmap_yml = {}

  # ConfigMap to store resources available in the cluster.
  device_type = system.device_type
  if system.accelerator_type == AcceleratorType['GPU']:
    # GPU clusters count whole nodes.
    resources_data = f'{device_type}: "{int(args.num_nodes)}"'
  elif (
      not args.enable_pathways
      and args.enable_autoprovisioning
      and autoprovisioning_config
  ):
    # Currently autoprovisioning is not supported with Pathways.
    # Auto provisioning will have variable topologies for a gke accelerator type.
    resources_data = (
        f'{system.gke_accelerator}: {AUTOPROVISIONING_CONFIG_VALUE}'
    )
    resources_data += (
        f'\n {AUTOPROVISIONING_CONFIG_MINIMUM_KEY}:'
        f' "{autoprovisioning_config.minimum_chips}"'
    )
    resources_data += (
        f'\n {AUTOPROVISIONING_CONFIG_MAXIMUM_KEY}:'
        f' "{autoprovisioning_config.maximum_chips}"'
    )
  else:
    # TPU clusters count VMs across all requested slices.
    resources_data = (
        f'{device_type}: "{int(args.num_slices) * system.vms_per_slice}"'
    )
  resources_configmap_name = f'{args.cluster}-{CLUSTER_RESOURCES_CONFIGMAP}'
  # NOTE(review): the args= kwarg is unused by the template (it only has
  # {name} and {data} placeholders); str.format ignores extras.
  resources_yml = cluster_configmap_yaml.format(
      args=args, name=resources_configmap_name, data=resources_data
  )
  configmap_yml[resources_configmap_name] = resources_yml

  # ConfigMap to store cluster metadata.
  # XPK Version.
  metadata = f'xpk_version: {xpk_current_version}'
  # Vertex Tensorboard information
  for key, value in tensorboard_config.items():
    metadata += f'\n {key}: "{value}"'
  # Capacity Type.
  capacity_type, return_code = get_capacity_type(args)
  if return_code != 0:
    xpk_print('Unable to determine capacity type.')
    return return_code
  metadata += f'\n {CAPACITY_TYPE_CONFIG_KEY}: {capacity_type.name}'
  # Reservation ID if applicable.
  if capacity_type == CapacityType.RESERVATION:
    metadata += f'\n {RESERVATION_CONFIG_KEY}: {args.reservation}'
  metadata_configmap_name = f'{args.cluster}-{CLUSTER_METADATA_CONFIGMAP}'
  metadata_yml = cluster_configmap_yaml.format(
      args=args, name=metadata_configmap_name, data=metadata
  )
  configmap_yml[metadata_configmap_name] = metadata_yml
  return create_or_update_cluster_configmap(configmap_yml)
781
+
782
+
783
def get_cluster_configmap(args, configmap_name) -> dict[str, str] | None:
  """Run the Get GKE Cluster ConfigMap request.

  Args:
    args: user provided arguments for running the command.
    configmap_name: name of the configmap.

  Returns:
    key:value pairs stored in cluster ConfigMap, or None when the kubectl
    request fails.
  """
  command = (
      'kubectl get configmap'
      f' {configmap_name} -o=custom-columns="ConfigData:data" --no-headers=true'
  )

  return_code, return_value = run_command_for_value(
      command, 'GKE Cluster Get ConfigMap', args
  )
  if return_code != 0:
    xpk_print(f'GKE Cluster Get ConfigMap request returned ERROR {return_code}')
    return None

  config_map = {}
  return_value = return_value.strip()

  if return_value:
    # Format of ConfigMap: map[key1:value1 key2:value2]
    # NOTE(review): this parser assumes keys and values contain no spaces;
    # a value with an embedded space would be truncated at the space.
    return_value = return_value[return_value.index('map') :]
    configs = return_value[4:-1].split(' ')

    for config in configs:
      # Split on the first ':' only, so values that themselves contain a
      # colon (URLs, timestamps, port specs) are preserved instead of
      # raising "too many values to unpack".
      key, value = config.strip().split(':', 1)
      config_map[key] = value
  return config_map
817
+
818
+
819
def get_cluster_provisioner(args) -> str:
  """Return the tool that provisioned the cluster.

  Reads the cluster metadata ConfigMap; falls back to 'gcloud' when the
  ConfigMap or its 'provisioner' key is missing.

  Args:
    args: user provided arguments for running the command.

  Returns:
    The provisioner name (e.g. 'gcloud').
  """
  metadata_configmap_name = f'{args.cluster}-{CLUSTER_METADATA_CONFIGMAP}'
  cluster_config_map = get_cluster_configmap(args, metadata_configmap_name)
  cluster_provisioner = 'gcloud'
  # PEP 8: prefer `is not None` over `not ... is None`.
  if cluster_config_map is not None:
    provisioner = cluster_config_map.get('provisioner')
    if provisioner is not None:
      cluster_provisioner = provisioner
  xpk_print(f'Cluster provisioner: {cluster_provisioner}')
  return cluster_provisioner
829
+
830
+
831
def create_vertex_tensorboard(args) -> dict:
  """Creates a Tensorboard instance in Vertex AI.

  Args:
    args: user provided arguments.

  Returns:
    dict containing Tensorboard instance name, id and location; empty when
    creation did not yield an instance id.
  """
  from cloud_accelerator_diagnostics import tensorboard  # pylint: disable=import-outside-toplevel

  tensorboard_name = args.tensorboard_name
  if tensorboard_name is None:
    # Default instance name is derived from the cluster name.
    tensorboard_name = f'{args.cluster}-{DEFAULT_VERTEX_TENSORBOARD_NAME}'
  instance_id = tensorboard.create_instance(  # pylint: disable=used-before-assignment
      project=args.project,
      location=args.tensorboard_region,
      tensorboard_name=tensorboard_name,
  )
  if instance_id:
    xpk_print(
        f'Tensorboard instance {tensorboard_name} is successfully created.'
    )
    return {
        'tensorboard_region': args.tensorboard_region,
        'tensorboard_name': tensorboard_name,
        'tensorboard_id': instance_id,
    }
  return {}
859
+
860
+
861
def create_vertex_experiment(args) -> dict | None:
  """Creates an Experiment in Vertex AI.

  Args:
    args: user provided arguments.

  Returns:
    map containing Vertex Tensorboard configurations, or None when no
    Tensorboard instance exists for the cluster or experiment creation
    fails. (Annotation fixed: the original declared `-> dict` but returns
    None on both error paths.)
  """
  from cloud_accelerator_diagnostics import tensorboard  # pylint: disable=import-outside-toplevel

  metadata_configmap_name = f'{args.cluster}-{CLUSTER_METADATA_CONFIGMAP}'
  cluster_config_map = get_cluster_configmap(args, metadata_configmap_name)

  # A Tensorboard instance must have been created at cluster-create time.
  if cluster_config_map is None or 'tensorboard_name' not in cluster_config_map:
    xpk_print(
        'No Vertex Tensorboard instance has been created in cluster create. Run'
        ' `xpk cluster create --create-vertex-tensorboard` before running `xpk'
        ' workload create --use-vertex-tensorboard` to create a Vertex'
        ' Tensorboard instance. Alternatively, use `xpk cluster create-pathways'
        ' --create-vertex-tensorboard` before running `xpk workload'
        ' create-pathways --use-vertex-tensorboard`.'
    )
    return None

  tensorboard_config = {}
  tensorboard_config['tensorboard_project'] = args.project
  tensorboard_config['tensorboard_region'] = cluster_config_map[
      'tensorboard_region'
  ]
  tensorboard_config['tensorboard_name'] = cluster_config_map[
      'tensorboard_name'
  ]
  experiment_name = args.experiment_name
  if experiment_name is None:
    # Default experiment name combines cluster and workload.
    experiment_name = f'{args.cluster}-{args.workload}'
  tensorboard_config['experiment_name'] = experiment_name

  _, tensorboard_url = tensorboard.create_experiment(
      project=args.project,
      location=tensorboard_config['tensorboard_region'],
      experiment_name=experiment_name,
      tensorboard_name=tensorboard_config['tensorboard_name'],
  )
  if tensorboard_url is None:
    return None

  xpk_print(f'You can view Vertex Tensorboard at: {tensorboard_url}')
  return tensorboard_config
910
+
911
+
912
def get_all_clusters_programmatic(args) -> tuple[list[str], int]:
  """Gets all the clusters associated with the project / region.

  Args:
    args: user provided arguments for running the command.

  Returns:
    List of cluster names and 0 if successful and 1 otherwise.
  """
  list_command = (
      'gcloud container clusters list'
      f' --project={args.project} --region={zone_to_region(args.zone)}'
      ' --format="csv[no-heading](name)"'
  )
  err_code, cluster_listing = run_command_for_value(
      list_command, 'Find if Cluster Exists', args
  )
  if err_code == 0:
    # One cluster name per output line (headerless CSV format).
    return cluster_listing.splitlines(), 0

  xpk_print(f'Find if Cluster Exists returned ERROR {err_code}')
  return [], err_code
934
+
935
+
936
def get_nodepool_zone(args, nodepool_name) -> tuple[int, str | None]:
  """Return zone in which nodepool exists in the cluster.

  Args:
    args: user provided arguments for running the command.
    nodepool_name: name of nodepool.

  Returns:
    Tuple of int, str where
    int is the return code - 0 if successful, 1 otherwise.
    str is the zone of nodepool (None on failure).
  """
  # `value(locations)` prints the node locations of the pool; callers treat
  # the result as a single zone string — presumably xpk-managed pools are
  # single-zone (TODO confirm).
  command = (
      f'gcloud beta container node-pools describe {nodepool_name}'
      f' --cluster {args.cluster} --project={args.project}'
      f' --region={zone_to_region(args.zone)} --format="value(locations)"'
  )
  return_code, nodepool_zone = run_command_for_value(
      command, 'Get Node Pool Zone', args
  )
  if return_code != 0:
    xpk_print(f'Get Node Pool Zone returned ERROR {return_code}')
    return 1, None

  # Strip the trailing newline from the gcloud output.
  return 0, nodepool_zone.strip()
961
+
962
+
963
def check_cluster_resources(args, system) -> tuple[bool, bool]:
  """Check if cluster has resources of a specified device_type/gke_accelerator.
  This check will be skipped if <args.cluster>-<_CLUSTER_RESOURCES_CONFIGMAP> ConfigMap doesn't exist for the cluster.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics.

  Returns:
    Tuple of bool, bool
    True if resources in the cluster should be checked, False otherwise.
    True if device_type/gke_accelerator exists in the cluster, False otherwise.
  """
  resources_configmap_name = f'{args.cluster}-{CLUSTER_RESOURCES_CONFIGMAP}'
  resources_config_map = get_cluster_configmap(args, resources_configmap_name)
  if resources_config_map is None:
    # Bug fix: the message previously interpolated `resources_config_map`
    # (which is None on this path) instead of the ConfigMap's name.
    xpk_print(
        'No ConfigMap exist for cluster with the name'
        f' {resources_configmap_name}. Cluster resources check will be'
        ' skipped.'
    )
    return False, False
  # The requested resource may be keyed by device type (e.g. v4-8) or by the
  # GKE accelerator name; either counts as present in the cluster.
  resource_in_cluster = (
      system.device_type in resources_config_map
      or system.gke_accelerator in resources_config_map
  )
  return True, resource_in_cluster
989
+
990
+
991
def get_all_nodepools_programmatic(args) -> tuple[list[str], int]:
  """Gets all the nodepools associated with the cluster / project / region.

  Args:
    args: user provided arguments for running the command.

  Returns:
    List of nodepools and 0 if successful and 1 otherwise.
  """
  pools_command = (
      'gcloud beta container node-pools list'
      ' --cluster'
      f' {args.cluster} --project={args.project} --region={zone_to_region(args.zone)}'
      ' --format="csv[no-heading](name)"'
  )
  err_code, pools_listing = run_command_for_value(
      pools_command, 'Get All Node Pools', args
  )
  if err_code == 0:
    # Headerless CSV: one node pool name per line.
    return pools_listing.splitlines(), 0

  xpk_print(f'Get All Node Pools returned ERROR {err_code}')
  return [], 1
1014
+
1015
+
1016
def get_all_networks_programmatic(args) -> tuple[list[str], int]:
  """Gets all the networks associated with project .

  Args:
    args: user provided arguments for running the command.

  Returns:
    List of networks and 0 if successful and 1 otherwise.
  """
  networks_command = (
      'gcloud compute networks list --format="csv[no-heading](name)"'
  )
  err_code, networks_listing = run_command_for_value(
      networks_command, 'Get All Networks', args
  )
  if err_code == 0:
    return networks_listing.splitlines(), 0

  xpk_print(f'Get All Networks returned ERROR {err_code}')
  return [], 1
1034
+
1035
+
1036
def get_all_subnets_programmatic(args) -> tuple[list[str], int]:
  """Gets all the subnets associated with the project.

  Args:
    args: user provided arguments for running the command.

  Returns:
    List of subnets and 0 if successful and 1 otherwise.
  """
  # Restrict the listing to xpk's <cluster>-<region>-sub-N naming scheme.
  subnet_name_filter = f'{args.cluster}-{zone_to_region(args.zone)}-sub-*'

  subnets_command = (
      'gcloud compute networks subnets list'
      f' --filter=name~"{subnet_name_filter}" --project={args.project}'
  )
  err_code, subnets_listing = run_command_for_value(
      subnets_command, 'Get All Subnets', args
  )
  if err_code != 0:
    xpk_print(f'Get All Subnets returned ERROR {err_code}')
    return [], 1

  # The command uses the default tabular output: skip the header row and
  # take the first space-delimited field (the subnet name) of each line.
  output_rows = subnets_listing.splitlines()
  subnet_names = [row.split(' ')[0] for row in output_rows[1:]]
  return subnet_names, 0
1063
+
1064
+
1065
def get_all_firewall_rules_programmatic(args) -> tuple[list[str], int]:
  """Gets all the firewall rules associated with the project.

  Args:
    args: user provided arguments for running the command.

  Returns:
    List of firewall rules and 0 if successful and 1 otherwise.
  """
  rules_command = (
      'gcloud compute firewall-rules list --format="csv[no-heading](name)"'
  )
  err_code, rules_listing = run_command_for_value(
      rules_command, 'Get All Firewall Rules', args
  )
  if err_code == 0:
    return rules_listing.splitlines(), 0

  xpk_print(f'Get All Firewall Rules returned ERROR {err_code}')
  return [], 1
1085
+
1086
+
1087
def get_node_pools_to_delete(
    args, system, existing_node_pool_names, desired_node_pool_names
) -> list:
  """Get list of nodepools to delete from the cluster.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics.
    existing_node_pool_names: names of nodepools that already exist in the cluster.
    desired_node_pool_names: names of nodepools that should exist in the cluster.

  Returns:
    List of nodepool names to delete.
  """
  check_resource, resource_in_cluster = check_cluster_resources(args, system)
  # Only xpk-managed pools ('<cluster>-np-*') are candidates; anything else
  # (e.g. Pathways CPU nodepools) is left untouched.
  managed_prefix = f'{args.cluster}-np-'

  # A managed pool is deleted when either:
  # Scenario 1: it is no longer in the desired set (e.g. slice count shrank
  # from 3 to 2, so '<cluster>-np-2' goes away), or
  # Scenario 2: the cluster is being switched from device type 'x' to 'y',
  # in which case every existing managed pool is rebuilt.
  return [
      pool_name
      for pool_name in existing_node_pool_names
      if pool_name.startswith(managed_prefix)
      and (
          pool_name not in desired_node_pool_names
          or (check_resource and not resource_in_cluster)
      )
  ]
1123
+
1124
+
1125
def run_gke_node_pool_create_command(
    args, system, gke_node_pool_version
) -> int:
  """Run the Create GKE Node Pool request.

  Reconciles the cluster's node pools with the desired state: deletes stale
  xpk-managed pools, records the new resource in the cluster resources
  ConfigMap, then creates the missing pools (plus Pathways CPU pools when
  requested).

  Args:
    args: user provided arguments for running the command.
    system: System characteristics based on device type/topology.
    gke_node_pool_version: GKE version to use to create node pools.

  Returns:
    0 if successful and 1 otherwise.
  """
  device_type = args.tpu_type if args.tpu_type else args.device_type
  xpk_print(
      f'Creating {args.num_slices} node pool or pools of {device_type}\n'
      f'We assume that the underlying system is: {system}'
  )
  existing_node_pool_names, return_code = get_all_nodepools_programmatic(args)
  if return_code > 0:
    xpk_print('Listing all node pools failed!')
    return return_code

  capacity_type, return_code = get_capacity_type(args)
  if return_code > 0:
    xpk_print('Parsing capacity type failed!')
    return return_code
  if capacity_type == CapacityType.UNKNOWN:
    return_code = print_reservations(args)
    xpk_print(
        'ERROR: User needs to provide the capacity type. Please specify one of'
        ' the following `--reservation=$RESERVATION_NAME`, `--on-demand`'
        ' or `--spot`. See the above list of reservations to choose from.'
    )
    if return_code > 0:
      xpk_print('Listing all reservations failed!')
    # NOTE(review): return_code is set to 1 here but execution falls through
    # to get_capacity_arguments_from_capacity_type below, which overwrites
    # it — presumably that call fails for UNKNOWN and triggers the early
    # return; confirm this is intentional.
    return_code = 1
  capacity_args, return_code = get_capacity_arguments_from_capacity_type(
      args, capacity_type
  )
  if return_code > 0:
    xpk_print('Parsing capacity arguments failed!')
    return return_code

  # GPU clusters use a single multi-node pool; TPU/CPU clusters get one pool
  # per slice.
  if system.accelerator_type == AcceleratorType['GPU']:
    xpk_print(
        f'Creating 1 node pool with {args.num_nodes} nodes of'
        f' {system.device_type}\nUnderlyingly, we assume that means: {system}'
    )
    desired_node_pool_names = [f'{args.cluster}-np-0']
  else:
    xpk_print(
        f'Creating {args.num_slices} node pool or pools of'
        f' {system.device_type}\nUnderlyingly, we assume that means: {system}'
    )
    desired_node_pool_names = [
        f'{args.cluster}-np-{slice_num}' for slice_num in range(args.num_slices)
    ]

  node_pools_to_remain = []
  delete_commands = []
  delete_task_names = []
  if existing_node_pool_names:
    # All xpk node pools must live in the same zone; probe the first pool.
    return_code, existing_node_pool_zone = get_nodepool_zone(
        args, existing_node_pool_names[0]
    )
    if return_code != 0:
      return 1

    if existing_node_pool_zone and existing_node_pool_zone != args.zone:
      xpk_print(
          f'Cluster {args.cluster} already has nodepools in zone:'
          f' {existing_node_pool_zone}. Use the same zone to update nodepools'
          ' in the cluster.'
      )
      return 1

    node_pools_to_delete = get_node_pools_to_delete(
        args, system, existing_node_pool_names, desired_node_pool_names
    )
    for node_pool_name in existing_node_pool_names:
      # Deletion logic would leave behind any Pathways CPU nodepools.
      if node_pool_name.find(f'{args.cluster}-np-') != 0:
        continue

      if node_pool_name in node_pools_to_delete:
        command = (
            'gcloud beta container node-pools delete'
            f' {node_pool_name} --cluster={args.cluster}'
            f' --zone={zone_to_region(args.zone)}'
            f' --project={args.project} --quiet'
        )
        task = f'NodepoolDelete-{node_pool_name}'
        delete_commands.append(command)
        delete_task_names.append(task)
      else:
        node_pools_to_remain.append(node_pool_name)

  # Deletion of nodepools should happen before attempting to create new nodepools for the case
  # when cluster is getting updated from 'x' device_type/gke_accelerator to 'y' device_type/gke_accelerator.
  # In that case, '{args.cluster}-np-i' nodepool will be re-created for 'y' device_type/gke_accelerator.
  if delete_commands:
    will_delete = True
    # Ask for confirmation unless --force was supplied.
    if node_pools_to_delete and not args.force:
      will_delete = get_user_input(
          f'Planning to delete {len(node_pools_to_delete)} node pools including'
          f' {node_pools_to_delete}. \nDo you wish to delete: y (yes) / n'
          ' (no):\n'
      )
    if not will_delete:
      xpk_print(
          'You have requested to not delete the existing nodepools in the'
          ' cluster. There will be no change to the cluster.'
      )
      return 1

    for i, command in enumerate(delete_commands):
      xpk_print(
          f'To complete {delete_task_names[i]} we are executing {command}'
      )
    max_return_code = run_commands(
        delete_commands,
        'Delete Nodepools',
        delete_task_names,
        dry_run=args.dry_run,
    )
    if max_return_code != 0:
      xpk_print(f'Delete Nodepools returned ERROR {max_return_code}')
      return 1

    # Update {args.cluster}-{_CLUSTER_RESOURCES_CONFIGMAP} ConfigMap to 'y': '0'
    # and remove 'x' from the ConfigMap when cluster is getting updated from
    # 'x' device_type/gke_accelerator to 'y' device_type/gke_accelerator.
    if not node_pools_to_remain:
      if args.enable_autoprovisioning:
        resources_data = (
            f'{system.gke_accelerator}: {AUTOPROVISIONING_CONFIG_VALUE}'
        )
      else:
        resources_data = f'{device_type}: "0"'
      resources_configmap_name = f'{args.cluster}-{CLUSTER_RESOURCES_CONFIGMAP}'
      resources_yml = cluster_configmap_yaml.format(
          args=args, name=resources_configmap_name, data=resources_data
      )
      configmap_yml = {}
      configmap_yml[resources_configmap_name] = resources_yml
      return_code = create_or_update_cluster_configmap(configmap_yml)
      if return_code != 0:
        return 1

  # Build the create command for each desired pool that doesn't already
  # exist, with accelerator-specific flags appended per type.
  create_commands = []
  create_task_names = []
  for node_pool_name in desired_node_pool_names:
    if node_pool_name in node_pools_to_remain:
      continue
    command = (
        'gcloud beta container node-pools create'
        f' {node_pool_name}'
        f' --region={zone_to_region(args.zone)}'
        f' --cluster={args.cluster}'
        f' --project={args.project} --node-locations={args.zone}'
        f' --machine-type={system.gce_machine_type}'
        f' --host-maintenance-interval={args.host_maintenance_interval}'
        f' {capacity_args}'
        ' --enable-gvnic'
        f' {args.custom_nodepool_arguments}'
    )
    if system.accelerator_type == AcceleratorType['TPU']:
      command += f' --node-version={gke_node_pool_version}'
      command += f' --num-nodes={system.vms_per_slice}'
      command += ' --placement-type=COMPACT --max-pods-per-node 15'
      command += (
          f' --scopes=storage-full,gke-default,{CLOUD_PLATFORM_AUTH_SCOPE_URL}'
      )
      command += f' --tpu-topology={system.topology}'
      command += f' {args.custom_tpu_nodepool_arguments}'
    elif system.accelerator_type == AcceleratorType['GPU']:
      # GPU pools attach extra node networks (net-1..net-4, plus net-5..net-8
      # for H100 Mega) for multi-NIC GPU networking.
      subnet_prefix = f'{args.cluster}-{zone_to_region(args.zone)}'
      command += f' --num-nodes={args.num_nodes}'
      command += (
          ' --accelerator'
          f' type={system.gke_accelerator},count={str(system.chips_per_vm)},gpu-driver-version=latest'
          ' --no-enable-autoupgrade '
          f' --scopes={CLOUD_PLATFORM_AUTH_SCOPE_URL} --additional-node-network'
          f' network={args.cluster}-net-1,subnetwork={subnet_prefix}-sub-1'
          ' --additional-node-network'
          f' network={args.cluster}-net-2,subnetwork={subnet_prefix}-sub-2'
          ' --additional-node-network'
          f' network={args.cluster}-net-3,subnetwork={subnet_prefix}-sub-3'
          ' --additional-node-network'
          f' network={args.cluster}-net-4,subnetwork={subnet_prefix}-sub-4'
      )
      if device_type == h100_mega_device_type:
        command += (
            ' --additional-node-network'
            f' network={args.cluster}-net-5,subnetwork={subnet_prefix}-sub-5'
            ' --additional-node-network'
            f' network={args.cluster}-net-6,subnetwork={subnet_prefix}-sub-6'
            ' --additional-node-network'
            f' network={args.cluster}-net-7,subnetwork={subnet_prefix}-sub-7'
            ' --additional-node-network'
            f' network={args.cluster}-net-8,subnetwork={subnet_prefix}-sub-8'
            ' --max-pods-per-node=32'
        )
    elif system.accelerator_type == AcceleratorType['CPU']:
      command += f' --num-nodes={system.vms_per_slice}'
      command += (
          f' --scopes=storage-full,gke-default,{CLOUD_PLATFORM_AUTH_SCOPE_URL}'
      )

    task = f'NodepoolCreate-{node_pool_name}'
    create_commands.append(command)
    create_task_names.append(task)

  desired_pw_cpu_node_pools = ['cpu-user-np', 'cpu-rm-np', 'cpu-proxy-np']
  if args.enable_pathways:
    # Pathways needs CPU nodepools in addition to TPU nodepools
    for node_pool_name in desired_pw_cpu_node_pools:
      if node_pool_name in existing_node_pool_names:
        continue
      command = (
          'gcloud beta container node-pools create'
          f' {node_pool_name} --node-version={gke_node_pool_version} --cluster={args.cluster} --project={args.project} --node-locations={args.zone} --region={zone_to_region(args.zone)} --num-nodes=1'
          f' --machine-type={args.pathways_gce_machine_type} --scopes=storage-full,gke-default,{CLOUD_PLATFORM_AUTH_SCOPE_URL} --enable-autoscaling'
          ' --min-nodes=1 --max-nodes=20'
      )
      task = f'NodepoolCreate-{node_pool_name}'
      create_commands.append(command)
      create_task_names.append(task)

  for i, command in enumerate(create_commands):
    xpk_print(f'To complete {create_task_names[i]} we are executing {command}')
  max_return_code = run_commands(
      create_commands,
      'Create Nodepools',
      create_task_names,
      dry_run=args.dry_run,
  )
  if max_return_code != 0:
    xpk_print(f'Create Nodepools returned ERROR {max_return_code}')
    return 1

  xpk_print('Create or delete node pool request complete.')
  return 0
1368
+
1369
+
1370
# TODO(vbarr): Remove this function when jobsets gets enabled by default on
# GKE clusters.
def set_jobset_on_cluster(args) -> int:
  """Add jobset command on server side and ask user to verify it is created.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  install_command = (
      'kubectl apply --server-side -f'
      f' https://github.com/kubernetes-sigs/jobset/releases/download/{JOBSET_VERSION}/manifests.yaml'
  )
  task = f'Install Jobset on {args.cluster}'
  return_code = run_command_with_updates_retry(install_command, task, args)
  if return_code == 0:
    return return_code

  # Installation failures are almost always RBAC related; point the user at
  # the troubleshooting guide.
  xpk_print(f'{task} returned with ERROR {return_code}.\n')
  xpk_print(
      "This LIKELY means you're missing Kubernetes Permissions, you can"
      ' validate this by checking if the error references permission problems'
      ' such as `requires one of ["container.*"] permission(s)`. Follow our'
      ' readme:'
      ' https://github.com/google/xpk/blob/main/README.md#troubleshooting for'
      ' instructions on how to fix these permissions.'
  )
  return return_code
1399
+
1400
+
1401
def install_nccl_on_cluster(args, system: SystemCharacteristics) -> int:
  """Install NCCL plugin on the cluster.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics.

  Returns:
    0 if successful and 1 otherwise.
  """
  # h100 machines use the TCPX installer; other device types use TCPXO.
  if system.device_type == h100_device_type:
    installer_manifest = (
        # pylint: disable=line-too-long
        'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-tcpx-installer.yaml'
    )
  else:
    installer_manifest = (
        # pylint: disable=line-too-long
        'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpxo/nccl-tcpxo-installer.yaml'
    )

  return_code = run_command_with_updates(
      'kubectl apply -f ' + installer_manifest,
      'Install NCCL Plugin On Cluster',
      args,
  )
  if return_code == 0:
    return 0

  xpk_print(
      f'Install NCCL Plugin On Cluster request returned ERROR {return_code}'
  )
  return 1
1435
+
1436
+
1437
@dataclass
class GkeServerConfig:
  """Stores the valid gke versions based on gcloud recommendations."""

  # Default version of the RAPID release channel as reported by
  # `gcloud container get-server-config`.
  default_rapid_gke_version: str
  # All versions gcloud currently accepts for the RAPID channel.
  valid_versions: set[str]
1443
+
1444
+
1445
def get_gke_server_config(args) -> tuple[int, GkeServerConfig | None]:
  """Determine the GKE versions supported by gcloud currently.

  Args:
    args: user provided arguments for running the command.

  Returns:
    Tuple of
    int: 0 if successful and 1 otherwise.
    GkeServerConfig: stores valid gke version to use in node pool and cluster.
  """
  base_command = (
      'gcloud container get-server-config'
      f' --project={args.project} --region={zone_to_region(args.zone)}'
  )
  # Two queries against the RAPID channel: its default version and the
  # complete set of versions the server accepts.
  default_rapid_gke_version_cmd = (
      base_command
      + ' --flatten="channels" --filter="channels.channel=RAPID"'
      ' --format="value(channels.defaultVersion)"'
  )
  valid_versions_cmd = (
      base_command
      + ' --flatten="channels" --filter="channels.channel=RAPID"'
      ' --format="value(channels.validVersions)"'
  )
  base_command_description = 'Determine server supported GKE versions for '

  server_config_commands_and_descriptions = [
      (
          default_rapid_gke_version_cmd,
          base_command_description + 'default rapid gke version',
      ),
      (
          valid_versions_cmd,
          base_command_description + 'valid versions',
      ),
  ]
  command_outputs = []

  for command, command_description in server_config_commands_and_descriptions:
    return_code, cmd_output = run_command_for_value(
        command,
        command_description,
        args,
        hide_error=True,
    )
    if return_code != 0:
      xpk_print(f'Unable to get server config for {command_description}.')
      return return_code, None
    command_outputs.append(cmd_output)

  # gcloud prints validVersions as a ';'-separated list.
  return 0, GkeServerConfig(
      default_rapid_gke_version=command_outputs[0].strip(),
      valid_versions=set(command_outputs[1].split(';')),
  )
1500
+
1501
+
1502
def get_gke_control_plane_version(
    args, gke_server_config: GkeServerConfig
) -> tuple[int, str | None]:
  """Determine gke control plane version for cluster creation.

  Args:
    args: user provided arguments for running the command.
    gke_server_config: holds valid gke versions and recommended default version.

  Returns:
    Tuple of
    int: 0 if successful and 1 otherwise.
    str: gke control plane version to use.
  """

  # Override with user provide gke version if specified.
  if args.gke_version is not None:
    master_gke_version = args.gke_version
  else:
    master_gke_version = gke_server_config.default_rapid_gke_version

  is_valid_version = master_gke_version in gke_server_config.valid_versions

  if not is_valid_version:
    xpk_print(
        f'Planned GKE Version: {master_gke_version}\n Valid Versions:'
        f'\n{gke_server_config.valid_versions}\nRecommended / Default GKE'
        f' Version: {gke_server_config.default_rapid_gke_version}'
    )
    # NOTE(review): the two adjacent f-strings below concatenate without a
    # separating space ("...is not valid.Checks failed..."); consider adding
    # one in a follow-up.
    xpk_print(
        f'Error: Planned GKE Version {master_gke_version} is not valid.'
        f'Checks failed: Is Version Valid: {is_valid_version}'
    )
    xpk_print(
        'Please select a gke version from the above list using --gke-version=x'
        ' argument or rely on the default gke version:'
        f' {gke_server_config.default_rapid_gke_version}'
    )
    return 1, None

  return 0, master_gke_version
1543
+
1544
+
1545
def get_gke_node_pool_version(
    args, gke_server_config: GkeServerConfig
) -> tuple[int, str | None]:
  """Determine the gke node pool version for the node pool.

  Args:
    args: user provided arguments for running the command.
    gke_server_config: holds valid gke versions and recommended default version.

  Returns:
    Tuple of
    int: 0 if successful and 1 otherwise.
    str: gke control plane version to use.
  """

  # By default use the current gke master version for creating node pools.
  command_description = 'Determine current gke master version'
  command = (
      f'gcloud beta container clusters describe {args.cluster}'
      f' --region {zone_to_region(args.zone)} --project {args.project}'
      ' --format="value(currentMasterVersion)"'
  )

  return_code, current_gke_master_version = run_command_for_value(
      command, command_description, args
  )
  if return_code != 0:
    xpk_print(
        f'Unable to get server config for command: {command_description}.'
    )
    return return_code, None

  # Override with user provide gke version if specified.
  if args.gke_version is not None:
    node_pool_gke_version = args.gke_version
  else:
    master_gke_version = current_gke_master_version.strip()
    node_pool_gke_version = ''
    # Select minimum version which is >= master_gke_version and has the same minor version.
    # If this does not exist select maximum version which is < master_gke_version.
    # NOTE: comparisons below are plain string comparisons of version
    # strings, matching the original behavior.
    for version in gke_server_config.valid_versions:
      if (
          (node_pool_gke_version == '' or node_pool_gke_version < version)
          and version < master_gke_version
      ) or (
          (node_pool_gke_version == '' or node_pool_gke_version > version)
          and master_gke_version <= version
          and master_gke_version.split('.')[:2] == version.split('.')[:2]
      ):
        node_pool_gke_version = version

  is_supported_node_pool_version = (
      node_pool_gke_version in gke_server_config.valid_versions
  )
  # In rare cases, user's provided gke version may be invalid, but gke will return an error if so.
  # An example scenario is if the user provided gke version is greater than the master version.
  if not is_supported_node_pool_version:
    xpk_print(
        f'Planned node pool version {node_pool_gke_version} is not supported in'
        ' valid version'
        f' {gke_server_config.valid_versions}\nPlease adjust the gke version'
        ' using --gke-version=x or remove the arg and depend on xpk default of'
        f' {current_gke_master_version}'
    )
    return 1, None
  return 0, node_pool_gke_version
1611
+
1612
+
1613
def validate_docker_image(docker_image, args) -> int:
  """Validates that the user provided docker image exists in your project.

  Args:
    docker_image: The docker image to verify.
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  project = args.project

  # Only images hosted in Google registries can be described via gcloud;
  # anything else is accepted without validation.
  google_registries = ('gcr.io', 'docker.pkg.dev')
  if all(registry not in docker_image for registry in google_registries):
    return 0

  describe_command = (
      f'gcloud container images describe {docker_image} --project {project}'
  )
  return_code = run_command_with_updates(
      describe_command, 'Validate Docker Image', args, verbose=False
  )
  if return_code == 0:
    return 0

  xpk_print(
      'Failed to validate your docker image, check that the docker image'
      f' exists. You may be able to find the {docker_image} in {project}.'
      ' If the docker image exists, the service account of this'
      ' project maybe be missing the permissions to access the docker image.'
  )
  return return_code
1645
+
1646
+
1647
def build_docker_image_from_base_image(args, verbose=True) -> tuple[int, str]:
  """Adds script dir to the base docker image and uploads the image.

  Args:
    args: user provided arguments for running the command.
    verbose: whether to stream command output while building/pushing.

  Returns:
    Tuple of:
      0 if successful and 1 otherwise.
      Name of the Docker image created.

  Note: on build, tag, or push failure this exits the process via
  xpk_exit(1) rather than returning a non-zero code.
  """

  # Pick a name for the docker image.
  docker_image_prefix = os.getenv('USER', 'unknown')
  docker_name = f'{docker_image_prefix}-runner'

  # Dockerfile template: layer the user's script dir on top of the base
  # image. This is a runtime string formatted below with the base image.
  script_dir_dockerfile = """FROM {base_docker_image}

# Set the working directory in the container
WORKDIR /app

# Copy all files from local workspace into docker container
COPY . .

WORKDIR /app
"""

  docker_file = script_dir_dockerfile.format(
      base_docker_image=args.base_docker_image,
  )
  # The rendered Dockerfile is written to a temp file and referenced by -f.
  tmp = write_tmp_file(docker_file)
  docker_build_command = (
      f'docker buildx build --platform={PLATFORM} -f {str(tmp.file.name)} -t'
      f' {docker_name} {args.script_dir}'
  )
  xpk_print(f'Building {args.script_dir} into docker image.')
  return_code = run_command_with_updates(
      docker_build_command,
      'Building script_dir into docker image',
      args,
      verbose=verbose,
  )
  if return_code != 0:
    xpk_print(
        'Failed to add script_dir to docker image, check the base docker image.'
        f' You should be able to navigate to the URL {args.base_docker_image}'
        f' in {args.project}.'
    )
    xpk_exit(1)

  # Pick a randomly generated `tag_length` character docker tag.
  tag_length = 4
  tag_random_prefix = ''.join(
      random.choices(string.ascii_lowercase, k=tag_length)
  )
  tag_datetime = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
  tag_name = f'{tag_random_prefix}-{tag_datetime}'
  cloud_docker_image = f'gcr.io/{args.project}/{docker_name}:{tag_name}'
  xpk_print(f'Adding Docker Image: {cloud_docker_image} to {args.project}')

  # Tag the docker image.
  tag_docker_image_command = f'docker tag {docker_name} {cloud_docker_image}'
  return_code = run_command_with_updates(
      tag_docker_image_command, 'Tag Docker Image', args, verbose=verbose
  )
  if return_code != 0:
    xpk_print(
        f'Failed to tag docker image with tag: {tag_name}.'
        f' You should be able to navigate to the URL {cloud_docker_image} in'
        f' {args.project}.'
    )
    xpk_exit(1)

  # Upload image to Artifact Registry.
  upload_docker_image_command = f'docker push {cloud_docker_image}'
  return_code = run_command_with_updates(
      upload_docker_image_command, 'Upload Docker Image', args, verbose=verbose
  )
  if return_code != 0:
    xpk_print(
        'Failed to upload docker image.'
        f' You should be able to navigate to the URL {cloud_docker_image} in'
        f' {args.project}.'
    )
    xpk_exit(1)
  return return_code, cloud_docker_image
1733
+
1734
+
1735
def check_if_workload_exists(args) -> bool:
  """Check if workload exists.

  Args:
    args: user provided arguments for running the command.

  Returns:
    returns true if workload exist, otherwise returns false.
  """
  # Ask kubectl for the owning Jobset name of every Kueue workload.
  custom_columns = {
      'Jobset': '.metadata.ownerReferences[0].name',
  }
  column_spec = ','.join(
      f'{name}:{path}' for name, path in custom_columns.items()
  )

  command = f"kubectl get workloads -o=custom-columns='{column_spec}'"
  return_code, return_msg = run_command_for_value(
      command, 'Check if Workload Already Exists', args
  )

  if return_code != 0:
    xpk_print(f'List Job request returned ERROR {return_code}')
    xpk_exit(return_code)

  # The requested workload exists iff its name appears on its own line.
  return any(line == args.workload for line in return_msg.split('\n'))
1765
+
1766
+
1767
def check_if_workload_can_schedule(args, system: SystemCharacteristics) -> bool:
  """Check if workload can schedule based on the cluster resources (tpu_type and maximum VM in cluster).

  The decision is made from the cluster's resources ConfigMap, which maps
  accelerator/device types to either a VM count or the autoprovisioning
  marker value.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics

  Returns:
    returns true if workload can schedule, otherwise returns false.
  """
  resources_configmap_name = f'{args.cluster}-{CLUSTER_RESOURCES_CONFIGMAP}'
  cluster_config_map = get_cluster_configmap(args, resources_configmap_name)

  # Prevents workload creation failure for existing clusters with no ConfigMap
  if cluster_config_map is None:
    xpk_print(
        'No ConfigMap exist for cluster with the name'
        f' {resources_configmap_name}.'
    )
    return True

  # Check for gke accelerator type:
  missing_gke_accelerator_type = False
  if not cluster_config_map.get(system.gke_accelerator):
    xpk_print(
        f'Gke Accelerator Type Check: {args.workload} is requesting'
        f' {system.gke_accelerator} but cluster only contains'
        f' {cluster_config_map.keys()}. '
    )
    missing_gke_accelerator_type = True
  elif (
      cluster_config_map[system.gke_accelerator]
      == AUTOPROVISIONING_CONFIG_VALUE
  ):
    # Run total chip check when in autoprovisioning mode.
    max_chips_in_cluster = int(
        cluster_config_map[AUTOPROVISIONING_CONFIG_MAXIMUM_KEY]
    )
    num_chips_in_workload = get_total_chips_requested_from_args(args, system)

    if num_chips_in_workload > max_chips_in_cluster:
      xpk_print(
          f'{args.workload} is requesting {num_chips_in_workload} chips but'
          f' the cluster {args.cluster} supports up to {max_chips_in_cluster}.'
          ' Resize the cluster to support more chips with'
          ' `xpk cluster create --autoprovisioning-max-chips=X ...`'
      )
      return False
    # Autoprovisioning path skips the fixed-nodepool VM-count check below.
    return True

  # Check for device type
  missing_device_type = False
  device_type = system.device_type
  if device_type not in cluster_config_map:
    xpk_print(
        f'Device Type Check: {args.workload} is requesting {device_type} but '
        f'cluster only contains {cluster_config_map.keys()}. '
    )
    missing_device_type = True

  if missing_device_type and missing_gke_accelerator_type:
    xpk_print(
        'Both Device Type and GKE Accelerator Type checks failed.'
        f' XPK will not create the workload {args.workload}.'
    )
    return False
  else:
    # Check if the size of the workload will fit in the cluster.
    # NOTE(review): if only the device-type check failed (missing_device_type
    # True, accelerator present), this lookup raises KeyError because
    # device_type is not in the ConfigMap — confirm intended behavior.
    max_vm_in_cluster = int(cluster_config_map[device_type])
    if system.accelerator_type == AcceleratorType['GPU']:
      # GPU workloads are sized by node count, not slices.
      vm_required_by_workload = args.num_nodes
    else:
      vm_required_by_workload = args.num_slices * system.vms_per_slice
    if vm_required_by_workload > max_vm_in_cluster:
      xpk_print(
          f'{args.workload} is requesting {args.num_slices} slice/slices of'
          f' {device_type}, which is {vm_required_by_workload} VMs, but the'
          f' cluster only contains {max_vm_in_cluster} VMs of {device_type}.'
          ' XPK will not create this workload.'
      )
      return False

  return True
1850
+
1851
+
1852
def use_base_docker_image_or_docker_image(args) -> bool:
  """Checks for correct docker image arguments.

  Args:
    args: user provided arguments for running the command.

  Returns:
    True if intended to use base docker image, False to use docker image.
  """
  # No prebuilt image supplied: build from the base docker image.
  if args.docker_image is None:
    return True

  # --docker-image is mutually exclusive with --script-dir and
  # --base-docker-image; exit if either was explicitly overridden.
  if args.script_dir is not default_script_dir:
    xpk_print(
        '`--script-dir` and --docker-image can not be used together. Please'
        ' see `--help` command for more details.'
    )
    xpk_exit(1)
  if args.base_docker_image is not default_docker_image:
    xpk_print(
        '`--base-docker-image` and --docker-image can not be used together.'
        ' Please see `--help` command for more details.'
    )
    xpk_exit(1)
  return False
1878
+
1879
+
1880
def setup_docker_image(args) -> tuple[int, str]:
  """Does steps to verify docker args, check image, and build image (if asked).

  Args:
    args: user provided arguments for running the command.

  Returns:
    tuple:
      0 if successful and 1 otherwise.
      Name of the docker image to use.
  """
  if use_base_docker_image_or_docker_image(args):
    # Build a new image on top of the validated base image.
    docker_image = args.base_docker_image
    validation_code = validate_docker_image(docker_image, args)
    if validation_code != 0:
      xpk_exit(validation_code)
    build_code, docker_image = build_docker_image_from_base_image(args)
    if build_code != 0:
      xpk_exit(build_code)
  else:
    # Use the user-supplied image directly after validation.
    docker_image = args.docker_image
    validation_code = validate_docker_image(docker_image, args)
    if validation_code != 0:
      xpk_exit(validation_code)

  return 0, docker_image
1910
+
1911
+
1912
def get_main_and_sidecar_container(args, system, docker_image) -> str:
  """Generate yaml for main and sidecar container.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics
    docker_image: docker image

  Returns:
    str:
      yaml for main and sidecar container
  """
  resource_type = AcceleratorTypeToAcceleratorCharacteristics[
      system.accelerator_type
  ].resource_type
  # The sidecar tails TPU stack traces from the shared debugging volume and
  # exits once the main container drops the stacktrace_signal file.
  sidecar_template = """- name: stacktrace-explorer
                image: busybox:1.28
                args: [/bin/sh, -c, "check_signal() (while [ ! -f /shared-volume/stacktrace_signal ]; do sleep 1; done; pid=$(pidof 'tail'); kill $pid;); check_signal & while [ ! -d /tmp/debugging ]; do sleep 60; done; while [ ! -e /tmp/debugging/* ]; do sleep 60; done; tail -n+1 -f /tmp/debugging/*; exit 0;"]
                volumeMounts:
                - name: tpu-stack-trace
                  readOnly: true
                  mountPath: /tmp/debugging
                - name: shared-data
                  mountPath: /shared-volume
                {main_container}
                """
  return sidecar_template.format(
      main_container=get_main_container(
          args, system, docker_image, resource_type
      )
  )
1939
+
1940
+
1941
def get_main_container(args, system, docker_image, resource_type) -> str:
  """Generate yaml for main container including the xpk command.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics
    docker_image: docker image
    resource_type: The label to describe the resource type for TPUs/GPUs/CPUs.

  Returns:
    str:
      yaml for main container
  """

  xpk_internal_commands = ''
  gsutil_test_command = ''
  # Optionally copy XLA dumps to GCS after the user command finishes;
  # fail fast (exit 24) if gsutil is missing inside the container.
  if not args.use_pathways and args.debug_dump_gcs:
    gsutil_test_command = (
        'which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil'
        ' is required but not installed. Aborting"; exit 24;};'
    )
    xpk_internal_commands += (
        'WORKER_ID=$HOSTNAME;'
        f'gsutil -m cp -r /tmp/xla_dump/ {args.debug_dump_gcs}/$WORKER_ID;'
    )

  command = args.command
  # Prefix the user command with verbose TPU/TF logging env vars.
  if args.enable_debug_logs:
    command = (
        'export TPU_STDERR_LOG_LEVEL=0 &&'
        ' export TPU_MIN_LOG_LEVEL=0 &&'
        ' export TF_CPP_MIN_LOG_LEVEL=0 &&'
        ' export TPU_VMODULE=real_program_continuator=1 &&'
        f' {args.command}'
    )

  # GPU pods tell the rxdm sidecar to terminate via a shared file.
  gpu_workload_terminate_command = ''
  if system.accelerator_type == AcceleratorType['GPU']:
    gpu_workload_terminate_command = (
        'echo Main app is done > /usr/share/workload/workload_terminated; '
    )

  # TPU pods signal the stacktrace-explorer sidecar to stop tailing.
  tpu_stacktrace_terminate_command = ''
  if (
      not args.use_pathways
      and system.accelerator_type == AcceleratorType['TPU']
      and args.deploy_stacktrace_sidecar
  ):
    tpu_stacktrace_terminate_command = (
        'touch /shared-volume/stacktrace_signal; '
    )

  xpk_return_user_exit_code = ''
  if args.restart_on_user_code_failure:
    if int(args.max_restarts) <= 0:
      xpk_print(
          f'Warning: --max-restarts, is set to {args.max_restarts}. Will not'
          ' restart on user failure.'
      )
    # Propagate the user command's exit code so the jobset can restart the pod.
    xpk_return_user_exit_code = 'exit $EXIT_CODE'

  # Shell wrapper: traps SIGTERM, forwards it to the user command, records
  # its exit code, then runs the sidecar/debug termination hooks. Exit code
  # 143 (SIGTERM) is always propagated.
  yaml = """- name: {docker_name}
                image: {docker_image}
                {image_pull_policy}
                env: {env}
                ports:
                {container_ports}
                {jax_coordinator_port}
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  echo XPK Start: $(date);
                  _sigterm() (kill -SIGTERM $! 2>/dev/null;);
                  trap _sigterm SIGTERM;
                  {gsutil_test_command}
                  ({command}) & PID=$!;
                  while kill -0 $PID 2>/dev/null;
                      do sleep 5;
                  done;
                  wait $PID;
                  EXIT_CODE=$?;
                  {xpk_internal_commands}
                  echo XPK End: $(date);
                  echo EXIT_CODE=$EXIT_CODE;
                  {tpu_stacktrace_terminate_command}
                  {gpu_workload_terminate_command}
                  if [ "$EXIT_CODE" = 143 ]; then
                    exit $EXIT_CODE
                  fi
                  {xpk_return_user_exit_code}
                resources:
                  limits:
                    {resources}
                """
  volume_mounts = get_volume_mounts(args, system)
  # Only emit a volumeMounts section when there is something to mount.
  if volume_mounts != '':
    yaml += """
                volumeMounts:
                {volume_mounts}
                """
  return yaml.format(
      args=args,
      system=system,
      image_pull_policy=add_image_pull_policy_for_pw_or_gpu(args, system),
      env=get_env_container(args, system),
      container_ports=add_container_ports(args, system),
      jax_coordinator_port=add_jax_coordinator_port(system),
      docker_name=get_main_container_docker_image(args, system),
      docker_image=docker_image,
      gsutil_test_command=gsutil_test_command,
      command=command,
      tpu_stacktrace_terminate_command=tpu_stacktrace_terminate_command,
      gpu_workload_terminate_command=gpu_workload_terminate_command,
      xpk_internal_commands=xpk_internal_commands,
      resources=get_main_container_resources(args, system, resource_type),
      volume_mounts=volume_mounts,
      xpk_return_user_exit_code=xpk_return_user_exit_code,
  )
2061
+
2062
+
2063
def add_image_pull_policy_for_pw_or_gpu(args, system: SystemCharacteristics):
  """Add image pull policy only for Pathways and GPU containers.

  Args:
    args: user provided args.
    system: system characteristics

  Returns:
    str:
      YAML stating that the image will be pulled from GCR every time,
      or '' when the default pull policy suffices.
  """
  needs_always_pull = (
      args.use_pathways or system.accelerator_type == AcceleratorType['GPU']
  )
  return 'imagePullPolicy: Always' if needs_always_pull else ''
2078
+
2079
+
2080
def get_main_container_docker_image(args, system: SystemCharacteristics) -> str:
  """Docker name for the main container.

  Args:
    args: user provided args.
    system: system characteristics.

  Returns:
    str:
      Workload docker image as a YAML string
  """
  # GPU workloads use a fixed container name; everything else uses the
  # user-configured docker name.
  is_gpu_workload = system.accelerator_type == AcceleratorType['GPU']
  return 'gpu-image' if is_gpu_workload else f'{args.docker_name}'
2095
+
2096
+
2097
def get_volumes(args, system: SystemCharacteristics) -> str:
  """Get volumes accessible to the containers in the pod.

  Args:
    args: user provided args.
    system: system characteristics.

  Returns:
    str:
      YAML for the volumes.
  """
  # Always provide an in-memory volume backing /dev/shm.
  volumes = """- emptyDir:
                  medium: Memory
                name: dshm-2"""

  # Optional ramdisk volume served by the phase1-checkpoint CSI driver.
  if args.ramdisk_directory != '':
    volumes += """
              - name: cache
                csi:
                  driver: phase1-checkpoint.csi.storage.gke.io"""

  # Volumes shared with the stacktrace-explorer sidecar for TPU debugging.
  if (
      system.accelerator_type == AcceleratorType['TPU']
      and args.deploy_stacktrace_sidecar
  ):
    volumes += """
              - name: tpu-stack-trace
              - name: shared-data"""

  return volumes
2126
+
2127
+
2128
def get_volume_mounts(args, system: SystemCharacteristics) -> str:
  """Volume mounts for the main container.

  Args:
    args: user provided args.
    system: system characteristics.

  Returns:
    str:
      YAML for the volumes mounted within a Pathways container or GPU container as a YAML string.
  """
  # Base mount: in-memory /dev/shm.
  volume_mount_yaml = """- mountPath: /dev/shm
                  name: dshm-2"""

  # Optional ramdisk checkpoint directory mount.
  if args.ramdisk_directory != '':
    volume_mount_yaml += f"""
                - mountPath: /{args.ramdisk_directory}
                  name: cache"""

  # NOTE: the Pathways branch REPLACES the accumulated mounts (including any
  # ramdisk mount added above) rather than appending to them.
  if args.use_pathways:
    volume_mount_yaml = """- mountPath: /tmp
                  name: shared-tmp"""
  elif (
      system.accelerator_type == AcceleratorType['TPU']
      and args.deploy_stacktrace_sidecar
  ):
    # TPU + sidecar: append the debugging and shared-signal volumes.
    volume_mount_yaml += """
                - name: tpu-stack-trace
                  mountPath: /tmp/debugging
                - name: shared-data
                  mountPath: /shared-volume"""
  elif system.accelerator_type == AcceleratorType['GPU']:
    if system.device_type == h100_device_type:
      # H100 (TCPX): replace mounts with the NCCL/TCPX plumbing.
      volume_mount_yaml = """- name: nvidia-install-dir-host
                  mountPath: /usr/local/nvidia/lib64
                - name: tcpx-nccl-plugin-volume
                  mountPath: /usr/local/tcpx
                - name: tcpd-socket
                  mountPath: /tmp
                - name: shared-memory
                  mountPath: /dev/shm
                - name: workload-terminated-volume
                  mountPath: /usr/share/workload"""
    elif (
        system.device_type == h100_mega_device_type
        or system.device_type == h200_device_type
    ):
      # H100-mega / H200: mounts are injected elsewhere (workload decorators).
      volume_mount_yaml = ''

  return volume_mount_yaml
2176
+
2177
+
2178
def get_user_workload_container(args, system: SystemCharacteristics):
  """Build the user workload container spec, plus optional debug sidecar.

  Args:
    args: user provided args.
    system: system characteristics.

  Returns:
    container: main container
    debugging_dashboard_id: id of the GKE dashboard
  """
  setup_code, docker_image = setup_docker_image(args)
  if setup_code != 0:
    xpk_exit(setup_code)

  resource_type = AcceleratorTypeToAcceleratorCharacteristics[
      system.accelerator_type
  ].resource_type
  debugging_dashboard_id = None

  # A stack-trace sidecar only applies to non-Pathways TPU workloads that
  # explicitly opted in.
  deploy_sidecar = (
      not args.use_pathways
      and system.accelerator_type == AcceleratorType['TPU']
      and args.deploy_stacktrace_sidecar
  )
  if deploy_sidecar:
    xpk_print(
        'Sidecar container to display stack traces for TPU workloads will also'
        ' be deployed.'
    )
    container = get_main_and_sidecar_container(args, system, docker_image)
    # Get GKE debugging dashboard only when sidecar container is deployed for TPU workloads
    debugging_dashboard_id = get_gke_debugging_dashboard(args)
  else:
    container = get_main_container(args, system, docker_image, resource_type)

  return container, debugging_dashboard_id
2214
+
2215
+
2216
def get_env_container(args, system: SystemCharacteristics):
  """Environment configuration for the main container.

  Args:
    args: user provided args.
    system: system characteristics.

  Returns:
    str:
      YAML with the env config for the main container, as a YAML string.
  """
  # Pathways: point JAX at the proxy backend and expose the jobset name.
  pw_env_yaml = """
                - name: XCLOUD_ENVIRONMENT
                  value: GCP
                - name: JAX_PLATFORMS
                  value: proxy
                - name: JAX_BACKEND_TARGET
                  value: {proxy_address}
                - name: JOBSET_NAME
                  valueFrom:
                    fieldRef:
                      fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']"""
  if args.use_pathways:
    return pw_env_yaml.format(
        args=args, proxy_address=args.pathways_proxy_address
    )

  # GPU: derive the JAX coordinator address/rank from jobset annotations.
  gpu_env_yaml = """
                  - name: REPLICATED_JOB_NAME
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
                  - name: JOBSET_NAME
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
                  - name: JAX_COORDINATOR_ADDRESS
                    value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
                  - name: NNODES
                    value: "{args.num_nodes}"
                  - name: NODE_RANK
                    valueFrom:
                      fieldRef:
                        fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
                  - name: USE_GPUDIRECT
                    value: {gpu_direct_name}
                  - name: GPUS_PER_NODE
                    value: "{system.chips_per_vm}"
                  - name: JAX_COORDINATOR_PORT
                    value: "6002"
                  - name: COMMAND
                    value: "{args.command}"
                  {args.env}"""

  if system.accelerator_type == AcceleratorType['GPU']:
    # Select the GPUDirect flavor by device type; 'fastrak' is the default
    # and is overridden for H100 (tcpx), H100-mega (tcpxo) and H200 (rdma).
    gpu_direct_name = 'fastrak'
    if args.device_type == h100_device_type:
      gpu_direct_name = 'tcpx'
      # TCPX needs the host NVIDIA libraries on the loader path.
      gpu_env_yaml += """
                  - name: LD_LIBRARY_PATH
                    value: /usr/local/nvidia/lib64
                  """
    elif args.device_type == h100_mega_device_type:
      gpu_direct_name = 'tcpxo'
    elif args.device_type == h200_device_type:
      gpu_direct_name = 'rdma'
    return gpu_env_yaml.format(
        args=args, system=system, gpu_direct_name=gpu_direct_name
    )

  if system.accelerator_type == AcceleratorType['CPU']:
    return get_cpu_env(args.num_slices, args.env, system)

  # TPU (non-Pathways) workloads just pass through the user env block.
  return args.env
2289
+
2290
+
2291
def get_main_container_resources(
    args, system: SystemCharacteristics, resource_type
) -> str:
  """Resources for the main container.

  Args:
    args: user provided args.
    system: system characteristics.
    resource_type: TPU / GPU / CPU

  Returns:
    str:
      Workload resources port as a YAML string
  """
  if args.use_pathways:
    # Resources requirements for Pathways workload containers are known.
    return """cpu: "24"
                    memory: 100G"""

  if system.accelerator_type == AcceleratorType['GPU']:
    return f'nvidia.com/gpu: {system.chips_per_vm}'

  if system.accelerator_type == AcceleratorType['CPU']:
    # CPUs don't have chips, but have a subresource called vCPUs.
    # system.chips_per_vm is used as a proxy for vCPUs.
    # Some vCPUs get used in hosting system pods of the workloads,
    # hence an offset of 0.95 is introduced.
    offset_vcpus = int(system.chips_per_vm) * 0.95
    return f'{resource_type}: {offset_vcpus}'

  # TPU: request one chip resource per VM.
  return f'{resource_type}: {system.chips_per_vm}'
2323
+
2324
+
2325
def add_container_ports(args, system: SystemCharacteristics) -> str:
  """Add slice builder and megascale container ports,
  for non-pathways workloads.

  Args:
    args: user provided args.

  Returns:
    str:
      Pathways server port as a YAML string
  """
  # Pathways workloads expose no extra container ports here.
  if args.use_pathways:
    return ''

  # GPU workloads only expose the JAX coordinator port.
  if system.accelerator_type == AcceleratorType['GPU']:
    return """- containerPort: 6002"""

  # TPU/CPU: slice builder (8471) and megascale (8080) ports.
  return """- containerPort: 8471
                - containerPort: 8080"""
2345
+
2346
+
2347
def add_jax_coordinator_port(system) -> str:
  """Add jax coordinator port only for CPUs

  Args:
    system: system characteristics.

  Returns:
    str:
      jax coordinator port as a YAML string, '' for non-CPU workloads.
  """
  is_cpu = system.accelerator_type == AcceleratorType['CPU']
  return '- containerPort: 1234' if is_cpu else ''
2360
+
2361
+
2362
def get_gke_dashboard(args, dashboard_filter):
  """Get the identifier of GKE dashboard deployed in the project.

  Args:
    args: user provided arguments for running the command.
    dashboard_filter: gcloud `--filter` expression selecting the dashboard
      by display name.

  Returns:
    bool:
      True if 'gcloud monitoring dashboards list' returned an error or
      multiple dashboards with same filter exist in the project,
      False otherwise.
    str:
      identifier of dashboard if deployed in project,
      None otherwise.
  """
  command = (
      'gcloud monitoring dashboards list'
      f' --project={args.project} --filter="{dashboard_filter}"'
      ' --format="value(name)" --verbosity=error'
  )

  return_code, return_value = run_command_for_value(
      command, 'GKE Dashboard List', args
  )

  if return_code != 0:
    xpk_print(
        f'GKE Dashboard List request returned ERROR {return_code}. If there is'
        ' a permissions error, please check'
        ' https://github.com/google/xpk/blob/main/README.md#roles-needed-based-on-permission-errors'
        ' for possible solutions.'
    )
    return True, None

  # Empty output: command succeeded but no dashboard matched the filter.
  if not return_value:
    xpk_print(
        f'No dashboard with {dashboard_filter} found in the'
        f' project:{args.project}.'
    )
    return False, return_value

  # One dashboard resource name per output line.
  dashboards = return_value.strip().split('\n')
  if len(dashboards) > 1:
    xpk_print(
        f'Multiple dashboards with same {dashboard_filter} exist in the'
        f' project:{args.project}. Delete all but one dashboard deployed using'
        ' https://github.com/google/cloud-tpu-monitoring-debugging.'
    )
    return True, None

  # The dashboard id is the last path segment of the resource name
  # (projects/<p>/dashboards/<id>).
  if dashboards[0]:
    return False, dashboards[0].strip().split('/')[-1]

  return True, None
2416
+
2417
+
2418
def get_gke_outlier_dashboard(args):
  """Get the identifier of GKE outlier dashboard deployed in the project.

  Args:
    args: user provided arguments for running the command.

  Returns:
    str:
      identifier of outlier dashboard if deployed in project,
      None otherwise.
  """
  is_error, dashboard_id = get_gke_dashboard(
      args, "displayName:'GKE - TPU Monitoring Dashboard'"
  )

  # Listing errored, or several dashboards matched the filter.
  if is_error:
    return None

  if dashboard_id:
    return dashboard_id

  # Listing succeeded but nothing matched: point the user at the deployer.
  xpk_print(
      'Follow https://github.com/google/cloud-tpu-monitoring-debugging to'
      ' deploy monitoring dashboard to view statistics and outlier mode of'
      ' GKE metrics.'
  )
  return None
2446
+
2447
+
2448
def get_gke_debugging_dashboard(args):
  """Get the identifier of GKE debugging dashboard deployed in the project.

  Args:
    args: user provided arguments for running the command.

  Returns:
    str:
      identifier of debugging dashboard if deployed in project,
      None otherwise.
  """
  is_error, dashboard_id = get_gke_dashboard(
      args, "displayName:'GKE - TPU Logging Dashboard'"
  )

  # Listing errored, or several dashboards matched the filter.
  if is_error:
    return None

  if dashboard_id:
    return dashboard_id

  # Listing succeeded but nothing matched: point the user at the deployer.
  xpk_print(
      'Follow https://github.com/google/cloud-tpu-monitoring-debugging to'
      ' deploy debugging dashboard to view stack traces collected in Cloud'
      ' Logging.'
  )
  return None
2476
+
2477
+
2478
def create_accelerator_label(accelerator_type, system) -> str:
  """Generates accelerator label.

  Args:
    accelerator_type: type of accelerator.
    system: system characteristics.

  Returns:
    The accelerator label, or '' for CPU-only workloads.
  """
  # CPU workloads carry no accelerator node label.
  if accelerator_type == AcceleratorType['CPU']:
    return ''
  label_key = AcceleratorTypeToAcceleratorCharacteristics[
      accelerator_type
  ].accelerator_label
  return f'{label_key}: {system.gke_accelerator}'
2494
+
2495
+
2496
def create_machine_label(
    accelerator_type, system, autoprovisioning_enabled: bool = False
) -> str:
  """Generates machine label.

  Args:
    accelerator_type: type of accelerator.
    system: system characteristics.
    autoprovisioning_enabled: describes autoprovisioning enablement.

  Returns:
    The machine label, or '' when it does not apply.
  """
  # Only fixed-nodepool TPU workloads pin a topology via the machine label.
  if accelerator_type != AcceleratorType['TPU'] or autoprovisioning_enabled:
    return ''
  machine_key = AcceleratorTypeToAcceleratorCharacteristics[
      accelerator_type
  ].machine_label
  return f'{machine_key}: {system.topology}'
2518
+
2519
+
2520
def calculate_process_count(num_slices, vms_per_slice) -> str:
  """Calculates the total number of processes in the workload.

  Args:
    num_slices: Number of slices to be used in the workload.
    vms_per_slice: number of VMs in each slice.

  Returns:
    str: total number of processes.
  """
  # One process per VM, across every slice.
  return str(int(num_slices) * int(vms_per_slice))
2532
+
2533
+
2534
def get_cpu_env(num_slices, env_vars, system) -> str:
  """Generate environment variables for CPU nodepools

  Args:
    num_slices: Number of slices to be used in the workload.
    env_vars: Environment variables, processed from user args.
    system: system characteristics

  Returns:
    str: yaml containing env variables
  """
  # Job identity and JAX process-count wiring come from jobset/batch
  # annotations; the coordinator address targets pod 0 of replicated job 0.
  yaml = """
                - name: REPLICATED_JOB_NAME
                  valueFrom:
                    fieldRef:
                      fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
                - name: JOB_INDEX
                  valueFrom:
                    fieldRef:
                      fieldPath: metadata.annotations['jobset.sigs.k8s.io/job-index']
                - name: JOB_COMPLETION_INDEX
                  valueFrom:
                    fieldRef:
                      fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
                - name: PROCESSES_IN_JOB
                  value: "{processes_in_job}"
                - name: JAX_PROCESS_COUNT
                  value: "{process_count}"
                {env_vars}
                - name: JAX_COORDINATOR_ADDRESS
                  value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
  """
  return yaml.format(
      processes_in_job=system.vms_per_slice,
      process_count=calculate_process_count(num_slices, system.vms_per_slice),
      env_vars=env_vars,
  )
2570
+
2571
+
2572
def get_cpu_affinity(accelerator_type) -> str:
  """Generate affinity rules for CPU nodepools, so that workload pods are
  not scheduled on the default pool machines.

  Args:
    accelerator_type: TPU / GPU / CPU

  Returns:
    str: yaml containing affinity constraints, '' for non-CPU workloads.
  """
  if accelerator_type != AcceleratorType['CPU']:
    return ''
  # Keep workload pods off the default pool nodes.
  return """affinity:
                nodeAffinity:
                  requiredDuringSchedulingIgnoredDuringExecution:
                    nodeSelectorTerms:
                    - matchExpressions:
                      - key: cloud.google.com/gke-nodepool
                        operator: NotIn
                        values:
                        - default-pool
              """
2594
+
2595
+
2596
def get_gpu_scheduler(
    args, system: SystemCharacteristics, autoprovisioning_args: str
) -> tuple[str, int]:
  """Get gpu scheduler configuration.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics.
    autoprovisioning_args: a string of arguments for Autoprovisioning.

  Returns:
    str: yaml containing gpu scheduler configuration
    int of 0 if successful and 1 otherwise.
  """
  gpu_scheduler = ''
  return_code = 0

  if args.scheduler == 'gke.io/topology-aware-auto':
    # Topology-aware scheduling: gate the pods until the scheduler
    # releases them; the gate name is workload-specific.
    gpu_scheduler = f"""schedulingGates:
                - name: "{args.scheduler}-{args.workload}"
                """
  elif args.scheduler == 'default-scheduler':
    # Default scheduler: pin the pods onto the cluster's GPU nodepool via
    # node affinity plus accelerator/machine node selectors.
    gpu_scheduler_yaml = """schedulerName: {scheduler_name}
                affinity:
                  nodeAffinity:
                    requiredDuringSchedulingIgnoredDuringExecution:
                      nodeSelectorTerms:
                      - matchExpressions:
                        - key: cloud.google.com/gke-accelerator
                          operator: Exists
                        - key: cloud.google.com/gke-nodepool
                          operator: In
                          values: [{node_pool_name}]
                nodeSelector:
                  {accelerator_label}
                  {machine_label}
                  {autoprovisioning_args}
                """
    gpu_scheduler = gpu_scheduler_yaml.format(
        scheduler_name=args.scheduler,
        accelerator_label=create_accelerator_label(
            system.accelerator_type, system
        ),
        machine_label=create_machine_label(system.accelerator_type, system),
        node_pool_name=f'{args.cluster}-np-0',
        autoprovisioning_args=autoprovisioning_args,
    )
  else:
    return_code = 1
    xpk_print(
        '--scheduler needs to be set as either `default-scheduler`'
        ' or `gke.io/topology-aware-auto` in order to schedule the'
        ' workloads on GPUs.'
    )

  return gpu_scheduler, return_code
2652
+
2653
+
2654
def get_gpu_volume(system: SystemCharacteristics) -> str:
  """Get gpu volume based on user provided arguments.

  Args:
    system: system characteristics.

  Returns:
    str: yaml containing gpu volume, '' for device types with no extra volumes.
  """
  gpu_volume = ''
  if system.device_type == h100_device_type:
    # H100 (TCPX): host NVIDIA libs, TCPX socket/plugin dirs, large shm,
    # and the sidecar-termination signal volume.
    gpu_volume = """- name: nvidia-install-dir-host
                hostPath:
                  path: /home/kubernetes/bin/nvidia/lib64
              - name: tcpd-socket
                hostPath:
                  path: /run/tcpx
              - name: shared-memory
                emptyDir:
                  medium: "Memory"
                  sizeLimit: 200Gi
              - name: workload-terminated-volume
                emptyDir:
              - name: tcpx-nccl-plugin-volume
                emptyDir:"""
  elif system.device_type == h100_mega_device_type:
    # H100-mega (TCPXO): no TCPX socket/plugin volumes; smaller shm limit.
    gpu_volume = """- name: nvidia-install-dir-host
                hostPath:
                  path: /home/kubernetes/bin/nvidia/lib64
              - name: shared-memory
                emptyDir:
                  medium: "Memory"
                  sizeLimit: 1Gi
              - name: workload-terminated-volume
                emptyDir:"""
  return gpu_volume
2690
+
2691
+
2692
def get_gpu_rxdm_image(system: SystemCharacteristics) -> str:
  """Get config of rxdm based on user provided arguments.

  Args:
    system: system characteristics.

  Returns:
    str: yaml containing the rxdm name and image, '' when no rxdm applies.
  """
  # H100 uses the TCPX receive-datapath daemon; H100-mega uses TCPXO.
  if system.device_type == h100_device_type:
    return """- name: tcpd-daemon
                image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/tcpgpudmarxd-dev:v2.0.9"""
  if system.device_type == h100_mega_device_type:
    return """- name: fastrak-daemon
                image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.9"""
  return ''
2709
+
2710
+
2711
def get_gpu_rxdm_cmd(system: SystemCharacteristics) -> str:
  """Get rxdm command based on user provided arguments.

  Chooses the command line for the RxDM sidecar container: the TCPX
  daemon invocation for A3 (H100), the TCPXO entrypoint script for
  A3-Mega (H100-Mega), or an empty command for any other device type.

  Args:
    system: system characteristics.

  Returns:
    str: command of running rxdm container
  """
  device = system.device_type
  if device == h100_device_type:
    return (
        '/tcpgpudmarxd/build/app/tcpgpudmarxd --gpu_nic_preset a3vm'
        ' --gpu_shmem_type fd --setup_param "--verbose 128 2 0"'
    )
  if device == h100_mega_device_type:
    return (
        'set -ex; chmod 755 /fts/entrypoint_rxdm_container.sh;'
        ' /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid='
        ' --alsologtostderr'
    )
  # Non-H100 systems run no RxDM sidecar.
  return ''
2733
+
2734
+
2735
def get_gpu_tcp_volume(system: SystemCharacteristics) -> str:
  """Get gpu tcp volume based on user provided arguments.

  Only A3 (H100, GPUDirect-TCPX) mounts the TCPX socket volume into the
  workload container; A3-Mega and other device types need no mount.

  Args:
    system: system characteristics.

  Returns:
    str: yaml containing gpu tcp volume; empty string for any other
      device type.
  """
  gpu_tcp_volume = ''
  if system.device_type == h100_device_type:
    # NOTE(review): string indentation reconstructed from a diff view —
    # confirm against the released file before relying on exact bytes.
    gpu_tcp_volume = """- name: tcpd-socket
                  mountPath: /tmp"""
  return gpu_tcp_volume
2749
+
2750
+
2751
def wait_for_job_completion(args) -> int:
  """Function to wait for job completion.

  Resolves the Kueue workload backing the jobset named by
  ``--wait-for-job-completion``, blocks with ``kubectl wait`` until it
  reaches the ``Finished`` condition (or the timeout elapses), then reads
  the jobset's final condition to distinguish success from failure.

  Args:
    args: user provided arguments for running the command; reads
      ``wait_for_job_completion``, ``timeout``, ``zone``, ``cluster`` and
      ``project``, and sets ``args.workload`` as a side effect.

  Returns:
    return_code: 0 if successful, 124 if timeout, 125 if unsuccessful job, 1 otherwise
  """
  # Check that the workload exists
  args.workload = args.wait_for_job_completion
  workload_exists = check_if_workload_exists(args)
  if not workload_exists:
    xpk_print(f'Workload named {args.workload} does not exist.')
    return 1

  # Get the full workload name. Kueue prefixes workloads created from a
  # jobset with 'jobset-', so grep for that and take the first column
  # (the NAME field) of the matching line.
  get_workload_name_cmd = f'kubectl get workloads | grep jobset-{args.workload}'
  return_code, return_value = run_command_for_value(
      get_workload_name_cmd, 'Get full workload name', args
  )
  if return_code != 0:
    xpk_print(f'Get full workload name request returned ERROR {return_code}')
    return return_code
  full_workload_name = return_value.split(' ')[0]

  # Call kubectl wait on the workload using the full workload name.
  # A negative --timeout tells kubectl to wait for a week (its maximum).
  timeout_val = args.timeout if args.timeout is not None else -1
  timeout_msg = (
      f'{timeout_val}s' if timeout_val != -1 else 'max timeout (1 week)'
  )
  wait_cmd = (
      "kubectl wait --for jsonpath='.status.conditions[-1].type'=Finished"
      f' workload {full_workload_name} --timeout={timeout_val}s'
  )
  return_code, return_value = run_command_for_value(
      wait_cmd,
      f'Wait for workload to finish with timeout of {timeout_msg}',
      args,
      print_timer=True,
  )
  if return_code != 0:
    # kubectl reports 'timed out' in its error output when --timeout
    # expires; map that to the conventional timeout exit code 124.
    if 'timed out' in return_value:
      xpk_print(
          f'Timed out waiting for your workload after {timeout_msg}, see your'
          ' workload here:'
          # pylint: disable=line-too-long
          f' https://console.cloud.google.com/kubernetes/service/{zone_to_region(args.zone)}/{args.cluster}/default/{args.workload}/details?project={args.project}'
      )
      return 124
    else:
      xpk_print(f'{return_value}')
      xpk_print(f'Wait for workload returned ERROR {return_code}')
      return return_code
  xpk_print(
      'Finished waiting for your workload, see your workload here:'
      # pylint: disable=line-too-long
      f' https://console.cloud.google.com/kubernetes/service/{zone_to_region(args.zone)}/{args.cluster}/default/{args.workload}/details?project={args.project}'
  )
  # The workload is Finished either way; inspect the jobset's final
  # condition type to tell success ('Completed') from failure.
  status_cmd = (
      f'kubectl get jobset {args.workload} -o'
      " jsonpath='{.status.conditions[-1].type}'"
  )
  return_code, return_value = run_command_for_value(
      status_cmd, 'Get jobset status', args
  )
  if return_code != 0:
    xpk_print(f'Get workload status request returned ERROR {return_code}')
    return return_code
  xpk_print(f'Your workload finished with status: {return_value}')
  if return_value != 'Completed':
    xpk_print('Your workload did not complete successfully')
    # 125: job finished but did not complete successfully.
    return 125
  return 0