xpk 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. xpk/__init__.py +15 -0
  2. xpk/api/__init__.py +15 -0
  3. xpk/api/storage_crd.yaml +52 -0
  4. xpk/commands/__init__.py +15 -0
  5. xpk/commands/batch.py +131 -0
  6. xpk/commands/cluster.py +808 -0
  7. xpk/commands/cluster_gcluster.py +269 -0
  8. xpk/commands/common.py +44 -0
  9. xpk/commands/config.py +29 -0
  10. xpk/commands/info.py +243 -0
  11. xpk/commands/inspector.py +357 -0
  12. xpk/commands/job.py +199 -0
  13. xpk/commands/kind.py +283 -0
  14. xpk/commands/kjob_common.py +44 -0
  15. xpk/commands/run.py +128 -0
  16. xpk/commands/shell.py +140 -0
  17. xpk/commands/storage.py +267 -0
  18. xpk/commands/version.py +27 -0
  19. xpk/commands/workload.py +889 -0
  20. xpk/core/__init__.py +15 -0
  21. xpk/core/blueprint/__init__.py +15 -0
  22. xpk/core/blueprint/blueprint_definitions.py +62 -0
  23. xpk/core/blueprint/blueprint_generator.py +708 -0
  24. xpk/core/capacity.py +185 -0
  25. xpk/core/cluster.py +564 -0
  26. xpk/core/cluster_private.py +200 -0
  27. xpk/core/commands.py +356 -0
  28. xpk/core/config.py +179 -0
  29. xpk/core/docker_container.py +225 -0
  30. xpk/core/docker_image.py +210 -0
  31. xpk/core/docker_manager.py +308 -0
  32. xpk/core/docker_resources.py +350 -0
  33. xpk/core/filestore.py +251 -0
  34. xpk/core/gcloud_context.py +196 -0
  35. xpk/core/gcluster_manager.py +176 -0
  36. xpk/core/gcsfuse.py +50 -0
  37. xpk/core/kjob.py +444 -0
  38. xpk/core/kueue.py +358 -0
  39. xpk/core/monitoring.py +134 -0
  40. xpk/core/nap.py +361 -0
  41. xpk/core/network.py +377 -0
  42. xpk/core/nodepool.py +581 -0
  43. xpk/core/pathways.py +377 -0
  44. xpk/core/ray.py +222 -0
  45. xpk/core/remote_state/__init__.py +15 -0
  46. xpk/core/remote_state/fuse_remote_state.py +99 -0
  47. xpk/core/remote_state/remote_state_client.py +38 -0
  48. xpk/core/resources.py +238 -0
  49. xpk/core/scheduling.py +253 -0
  50. xpk/core/storage.py +581 -0
  51. xpk/core/system_characteristics.py +1432 -0
  52. xpk/core/vertex.py +105 -0
  53. xpk/core/workload.py +341 -0
  54. xpk/core/workload_decorators/__init__.py +15 -0
  55. xpk/core/workload_decorators/rdma_decorator.py +129 -0
  56. xpk/core/workload_decorators/storage_decorator.py +52 -0
  57. xpk/core/workload_decorators/tcpxo_decorator.py +190 -0
  58. xpk/main.py +75 -0
  59. xpk/parser/__init__.py +15 -0
  60. xpk/parser/batch.py +43 -0
  61. xpk/parser/cluster.py +662 -0
  62. xpk/parser/common.py +259 -0
  63. xpk/parser/config.py +49 -0
  64. xpk/parser/core.py +135 -0
  65. xpk/parser/info.py +64 -0
  66. xpk/parser/inspector.py +65 -0
  67. xpk/parser/job.py +147 -0
  68. xpk/parser/kind.py +95 -0
  69. xpk/parser/run.py +47 -0
  70. xpk/parser/shell.py +59 -0
  71. xpk/parser/storage.py +316 -0
  72. xpk/parser/validators.py +39 -0
  73. xpk/parser/version.py +23 -0
  74. xpk/parser/workload.py +726 -0
  75. xpk/templates/__init__.py +15 -0
  76. xpk/templates/storage.yaml +13 -0
  77. xpk/utils/__init__.py +15 -0
  78. xpk/utils/console.py +55 -0
  79. xpk/utils/file.py +82 -0
  80. xpk/utils/gcs_utils.py +125 -0
  81. xpk/utils/kubectl.py +57 -0
  82. xpk/utils/network.py +168 -0
  83. xpk/utils/objects.py +88 -0
  84. xpk/utils/templates.py +28 -0
  85. xpk/utils/validation.py +80 -0
  86. xpk/utils/yaml.py +30 -0
  87. {xpk-0.5.0.dist-info → xpk-0.7.0.dist-info}/METADATA +456 -32
  88. xpk-0.7.0.dist-info/RECORD +92 -0
  89. {xpk-0.5.0.dist-info → xpk-0.7.0.dist-info}/WHEEL +1 -1
  90. xpk-0.7.0.dist-info/entry_points.txt +2 -0
  91. xpk-0.5.0.dist-info/RECORD +0 -7
  92. xpk-0.5.0.dist-info/entry_points.txt +0 -2
  93. xpk.py +0 -7282
  94. {xpk-0.5.0.dist-info → xpk-0.7.0.dist-info}/LICENSE +0 -0
  95. {xpk-0.5.0.dist-info → xpk-0.7.0.dist-info}/top_level.txt +0 -0
xpk/commands/workload.py
@@ -0,0 +1,889 @@
+ """
+ Copyright 2024 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from ..core.blueprint.blueprint_generator import get_subnetworks_for_a3mega, get_subnetworks_for_a3ultra
+ from ..core.cluster import (
+     create_xpk_k8s_service_account,
+     get_cluster_credentials,
+     setup_k8s_env,
+     XPK_SA,
+ )
+ from ..core.commands import run_command_with_updates, run_commands
+ from ..core.config import VERTEX_TENSORBOARD_FEATURE_FLAG, XPK_CURRENT_VERSION, parse_env_config
+ from ..core.docker_container import (
+     get_main_container_docker_image,
+     get_user_workload_container,
+ )
+
+ from ..core.docker_resources import get_volumes
+ from ..core.gcloud_context import add_zone_and_project
+ from ..core.kueue import LOCAL_QUEUE_NAME
+ from ..core.monitoring import get_gke_outlier_dashboard
+ from ..core.nap import (
+     get_autoprovisioning_node_selector_args,
+     is_autoprovisioning_enabled,
+ )
+ from ..core.pathways import (
+     ensure_pathways_workload_prerequisites,
+     get_pathways_proxy_args,
+     get_pathways_rm_args,
+     get_pathways_sidecar_container,
+     get_pathways_unified_query_link,
+     get_pathways_worker_args,
+     get_user_workload_for_pathways,
+ )
+ from ..core.resources import CLUSTER_METADATA_CONFIGMAP, get_cluster_configmap
+ from ..core.scheduling import (
+     check_if_workload_can_schedule,
+     create_accelerator_label,
+     create_machine_label,
+     get_cpu_affinity,
+     get_gpu_scheduler,
+ )
+ from ..core.storage import (
+     GCS_FUSE_TYPE,
+     GCP_FILESTORE_TYPE,
+     Storage,
+     add_bucket_iam_members,
+     get_storage_volume_mounts_yaml,
+     get_storage_volumes_yaml,
+     get_storages_to_mount,
+     get_storage_volume_mounts_yaml_for_gpu,
+     get_storage_volumes_yaml_for_gpu,
+     GCS_FUSE_ANNOTATION,
+ )
+ from ..core.system_characteristics import (
+     AcceleratorType,
+     AcceleratorTypeToAcceleratorCharacteristics,
+     get_system_characteristics,
+ )
+ from ..core.vertex import create_vertex_experiment
+ from ..core.workload import (
+     check_if_workload_exists,
+     get_gpu_rxdm_cmd,
+     get_gpu_rxdm_image,
+     get_gpu_tcp_volume,
+     get_gpu_volume,
+     get_workload_list,
+     wait_for_job_completion,
+     zone_to_region,
+ )
+ from ..core.workload_decorators import rdma_decorator, tcpxo_decorator, storage_decorator
+ from ..utils.console import get_user_input, xpk_exit, xpk_print
+ from ..utils.file import write_tmp_file
+ from . import cluster_gcluster
+
+ WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
+ kind: JobSet
+ metadata:
+   name: {args.workload}
+   labels:
+     kueue.x-k8s.io/queue-name: {local_queue_name} # Name of the LocalQueue
+     xpk.google.com/workload: {args.workload}
+   annotations:
+     alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool # 1:1 job replica to node pool assignment
+ spec:
+   ttlSecondsAfterFinished: {args.ttl_seconds_after_finished}
+   failurePolicy:
+     {failure_policy_rules}
+     maxRestarts: {args.max_restarts}
+   replicatedJobs:
+   - name: slice-job
+     replicas: {args.num_slices}
+     template:
+       spec:
+         parallelism: {system.vms_per_slice} # Equal to the number of VMs per slice
+         completions: {system.vms_per_slice} # Same as the above.
+         backoffLimit: 0 # When any pod fails, the job is failed
+         {pod_failure_policy}
+         template:
+           metadata:
+             labels:
+               xpk.google.com/workload: {args.workload}
+             annotations:
+               {storage_annotations}
+           spec:
+             schedulerName: {args.scheduler}
+             restartPolicy: Never
+             {affinity}
+             nodeSelector:
+               {accelerator_label}
+               {machine_label}
+               {autoprovisioning_args}
+             priorityClassName: {args.priority}
+             hostNetwork: true
+             dnsPolicy: ClusterFirstWithHostNet
+             terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
+             containers:
+             {container}
+             serviceAccountName: {service_account}
+             volumes:
+             {volumes}
+ """
+
+
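+ # A minimal sketch (not part of the released file) of how the template above
+ # is rendered: str.format resolves dotted placeholders such as
+ # {args.workload} by attribute access on its keyword arguments, so the parsed
+ # argparse namespace can be passed straight in. SimpleNamespace stands in
+ # here for the real args object, which is an assumption for illustration:
+ #
+ #   from types import SimpleNamespace
+ #   args = SimpleNamespace(workload='my-job')
+ #   'name: {args.workload}'.format(args=args)  # -> 'name: my-job'
+
+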
+ GPU_WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
+ kind: JobSet
+ metadata:
+   name: {args.workload}
+   annotations: {storage_annotations}
+   labels:
+     kueue.x-k8s.io/queue-name: multislice-queue # Name of the LocalQueue
+     xpk.google.com/workload: {args.workload}
+ spec:
+   ttlSecondsAfterFinished: {args.ttl_seconds_after_finished}
+   failurePolicy:
+     {failure_policy_rules}
+     maxRestarts: {args.max_restarts}
+   replicatedJobs:
+   - name: slice-job
+     replicas: 1
+     template:
+       metadata:
+         annotations:
+           {storage_annotations}
+       spec:
+         parallelism: {args.num_nodes}
+         completions: {args.num_nodes}
+         backoffLimit: 0 # When any pod fails, the job is failed
+         {pod_failure_policy}
+         template:
+           metadata:
+             labels:
+               xpk.google.com/workload: {args.workload}
+           spec:
+             {gpu_scheduler}
+             priorityClassName: {args.priority}
+             restartPolicy: Never
+             hostNetwork: true
+             dnsPolicy: ClusterFirstWithHostNet
+             terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
+             serviceAccountName: {service_account}
+             tolerations:
+             - operator: "Exists"
+               key: nvidia.com/gpu
+             volumes:
+             {gpu_volume}
+             {storage_volumes}
+             containers:
+             {gpu_rxdm_image}
+               imagePullPolicy: Always
+               command:
+               - "bash"
+               - "-c"
+               - |
+                 {gpu_rxdm_cmd} &
+                 while [ ! -e "/usr/share/workload/workload_terminated" ]; do sleep 10; echo "sleeping"; done
+               securityContext:
+                 privileged: true
+               volumeMounts:
+               {gpu_tcp_volume}
+               {storage_volume_mounts}
+               - name: nvidia-install-dir-host
+                 mountPath: /usr/local/nvidia/lib64
+               - name: workload-terminated-volume
+                 mountPath: /usr/share/workload
+               env:
+               - name: LD_LIBRARY_PATH
+                 value: /usr/local/nvidia/lib64
+             {container}
+ """
+
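+ # Note on the template above: the container defined by {gpu_rxdm_image} runs
+ # {gpu_rxdm_cmd} in the background, then polls for the sentinel file
+ # /usr/share/workload/workload_terminated (shared via the
+ # workload-terminated-volume mount), so this sidecar exits once the main
+ # workload container signals completion.
+
+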
+ A3_GPU_WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
+ kind: JobSet
+ metadata:
+   name: {args.workload}
+   labels:
+     kueue.x-k8s.io/queue-name: multislice-queue # Name of the LocalQueue
+     xpk.google.com/workload: {args.workload}
+ spec:
+   ttlSecondsAfterFinished: {args.ttl_seconds_after_finished}
+   failurePolicy:
+     {failure_policy_rules}
+     maxRestarts: {args.max_restarts}
+   replicatedJobs:
+   - name: slice-job
+     replicas: 1
+     template:
+       spec:
+         parallelism: {args.num_nodes}
+         completions: {args.num_nodes}
+         backoffLimit: 0 # When any pod fails, the job is failed
+         {pod_failure_policy}
+         template:
+           metadata:
+             labels:
+               xpk.google.com/workload: {args.workload}
+             annotations:
+               kueue.x-k8s.io/podset-preferred-topology: "cloud.google.com/gce-topology-host"
+           spec:
+             priorityClassName: {args.priority}
+             restartPolicy: Never
+             dnsPolicy: ClusterFirstWithHostNet
+             terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
+             serviceAccountName: {service_account}
+             tolerations:
+             - operator: "Exists"
+               key: nvidia.com/gpu
+             containers:
+             {container}
+ """
+
+ PW_WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
+ kind: JobSet
+ metadata:
+   name: {args.workload}
+   labels:
+     kueue.x-k8s.io/queue-name: {local_queue_name} # Name of the LocalQueue
+     xpk.google.com/workload: {args.workload}
+ spec:
+   ttlSecondsAfterFinished: {args.ttl_seconds_after_finished}
+   failurePolicy:
+     {failure_policy_rules}
+     maxRestarts: {args.max_restarts}
+   successPolicy:
+     operator: "All"
+     targetReplicatedJobs:
+     - {args.targetReplicatedJob}
+   replicatedJobs:
+   - name: worker
+     replicas: {args.num_slices}
+     template:
+       metadata:
+         annotations:
+           alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
+         labels:
+           xpk.google.com/workload: {args.workload}
+       spec:
+         backoffLimit: {backoff_limit}
+         completions: {system.vms_per_slice}
+         parallelism: {system.vms_per_slice}
+         template:
+           metadata:
+             annotations:
+               {storage_annotations}
+           spec:
+             terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
+             serviceAccountName: {service_account}
+             containers:
+             - args:
+               {pathways_worker_args}
+               image: {args.server_image}
+               imagePullPolicy: Always
+               name: pathways-worker
+               ports:
+               - containerPort: 29001
+               - containerPort: 8471
+               - containerPort: 8080
+               resources:
+                 limits:
+                   {resource_type}: {system.chips_per_vm}
+               securityContext:
+                 privileged: true
+               volumeMounts:
+               - mountPath: /tmp
+                 name: shared-tmp
+               {storage_volume_mounts}
+               env:
+               - name: PROJECT_ID
+                 value: {args.project}
+               - name: LOCATION
+                 value: {args.zone}
+               - name: CLUSTER_NAME
+                 value: {args.cluster}
+               - name: POD_NAME
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.name
+               - name: CONTAINER_NAME
+                 value: "pathways-worker"
+               - name: NAMESPACE
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.namespace
+               # Workaround for v6e
+               - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER
+                 value: "false"
+               - name: MEGASCALE_NUM_SLICES
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: "metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']"
+               - name: JOBSET_NAME
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
+               - name: REPLICATED_JOB_NAME
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
+               - name: MEGASCALE_SLICE_ID
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: "metadata.labels['jobset.sigs.k8s.io/job-index']"
+               - name: MEGASCALE_COORDINATOR_ADDRESS
+                 value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-$(MEGASCALE_SLICE_ID)-0.$(JOBSET_NAME)"
+             {pathways_sidecar_container}
+             nodeSelector:
+               {accelerator_label}
+               {machine_label}
+               {autoprovisioning_args}
+             priorityClassName: {args.priority}
+             hostNetwork: true
+             dnsPolicy: ClusterFirstWithHostNet
+             volumes:
+             - hostPath:
+                 path: /tmp
+                 type: DirectoryOrCreate
+               name: shared-tmp
+             {storage_volumes}
+   - name: rm
+     replicas: 1
+     template:
+       metadata:
+         labels:
+           xpk.google.com/workload: {args.workload}
+       spec:
+         backoffLimit: 0
+         completions: 1
+         parallelism: 1
+         template:
+           spec:
+             containers:
+             - args:
+               {pathways_rm_args}
+               env:
+               - name: PROJECT_ID
+                 value: {args.project}
+               - name: LOCATION
+                 value: {args.zone}
+               - name: CLUSTER_NAME
+                 value: {args.cluster}
+               - name: POD_NAME
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.name
+               - name: CONTAINER_NAME
+                 value: "pathways-rm"
+               - name: NAMESPACE
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.namespace
+               - name: REPLICATED_JOB_NAME
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
+               - name: JOBSET_NAME
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
+               - name: HOST_ADDRESS
+                 value: $(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)
+               - name: TPU_SKIP_MDS_QUERY
+                 value: "true"
+               image: {args.server_image}
+               imagePullPolicy: Always
+               name: pathways-rm
+               ports:
+               - containerPort: 29001
+               securityContext:
+                 privileged: true
+               volumeMounts:
+               - mountPath: /tmp
+                 name: shared-tmp
+             nodeSelector:
+               cloud.google.com/gke-nodepool: cpu-rm-np
+             hostNetwork: true
+             dnsPolicy: ClusterFirstWithHostNet
+             volumes:
+             - hostPath:
+                 path: /tmp
+                 type: DirectoryOrCreate
+               name: shared-tmp
+   - name: proxy
+     replicas: 1
+     template:
+       metadata:
+         labels:
+           xpk.google.com/workload: {args.workload}
+       spec:
+         backoffLimit: 0
+         completions: 1
+         parallelism: 1
+         template:
+           spec:
+             containers:
+             - args:
+               {pathways_proxy_args}
+               env:
+               - name: PROJECT_ID
+                 value: {args.project}
+               - name: LOCATION
+                 value: {args.zone}
+               - name: CLUSTER_NAME
+                 value: {args.cluster}
+               - name: POD_NAME
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.name
+               - name: CONTAINER_NAME
+                 value: "pathways-proxy"
+               - name: NAMESPACE
+                 valueFrom:
+                   fieldRef:
+                     fieldPath: metadata.namespace
+               image: {args.proxy_server_image}
+               imagePullPolicy: Always
+               name: pathways-proxy
+               ports:
+               - containerPort: 29000
+             hostNetwork: true
+             dnsPolicy: ClusterFirstWithHostNet
+             nodeSelector:
+               cloud.google.com/gke-nodepool: cpu-proxy-np
+   {user_workload}
+ """
+
+
+ def workload_create_pathways(args) -> None:
+   """Run jobset apply command for a file, specifically for Pathways.
+
+   Args:
+     args: user provided arguments for running the command.
+
+   Returns:
+     0 if successful and 1 otherwise.
+   """
+   args.use_pathways = True
+   if args.headless:
+     xpk_print(
+         'Please use kubectl port forwarding to connect to the Pathways proxy.'
+         ' kubectl get pods kubectl port-forward <proxy-pod-name> 29000:29000'
+         ' JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 python'
+         " -c 'import pathwaysutils; import jax; print(jax.devices())'"
+     )
+   workload_create(args)
+
+
+ def workload_create(args) -> None:
+   """Run jobset apply command for a file.
+
+   Args:
+     args: user provided arguments for running the command.
+
+   Returns:
+     0 if successful and 1 otherwise.
+   """
+   k8s_api_client = setup_k8s_env(args)
+   create_xpk_k8s_service_account()
+
+   workload_exists = check_if_workload_exists(args)
+
+   if workload_exists:
+     xpk_print(
+         f'{args.workload} already exists, XPK will not create this workload.'
+         ' Please pick a new workload name'
+     )
+     xpk_exit(1)
+
+   xpk_print('Starting workload create', flush=True)
+   system, return_code = get_system_characteristics(args)
+
+   if return_code > 0:
+     xpk_print('Fetching system characteristics failed!')
+     xpk_exit(return_code)
+
+   if not check_if_workload_can_schedule(args, system):
+     xpk_exit(1)
+
+   xpk_print('Starting workload create', flush=True)
+
+   metadata_configmap_name = f'{args.cluster}-{CLUSTER_METADATA_CONFIGMAP}'
+   cluster_config_map = get_cluster_configmap(args, metadata_configmap_name)
+   cluster_xpk_version = None
+   if cluster_config_map is None:
+     xpk_print(
+         f'Warning: Unable to find ConfigMap: {metadata_configmap_name} for the'
+         ' cluster. We recommend to upgrade your cluster by running `xpk'
+         ' cluster create`.'
+     )
+   else:
+     cluster_xpk_version = cluster_config_map.get('xpk_version')
+   if (
+       cluster_xpk_version is not None
+       and cluster_xpk_version != XPK_CURRENT_VERSION
+   ):
+     xpk_print(
+         'Warning: Cluster has been created using XPK version:'
+         f' {cluster_config_map["xpk_version"]} but the XPK version you are'
+         f' using to schedule workload is: {XPK_CURRENT_VERSION}. Some features'
+         ' might not be available for this cluster. We recommend to'
+         ' upgrade/downgrade your XPK version or cluster by running `xpk'
+         ' cluster create`.'
+     )
+
+   debugging_dashboard_id = None
+
+   tensorboard_config = {}
+   if VERTEX_TENSORBOARD_FEATURE_FLAG and args.use_vertex_tensorboard:
+     tensorboard_config = create_vertex_experiment(args)
+     # exit if failed to create Experiment in Vertex AI
+     if not tensorboard_config:
+       xpk_exit(1)
+
+   parse_env_config(args, tensorboard_config, system)
+
+   # Currently autoprovisioning is not enabled for Pathways workloads.
+   autoprovisioning_args = ''
+   autoprovisioning_enabled, return_code = is_autoprovisioning_enabled(
+       args, system
+   )
+   if return_code != 0:
+     xpk_exit(return_code)
+   if autoprovisioning_enabled:
+     # Determine NAP capacity type
+     autoprovisioning_args, return_code = (
+         get_autoprovisioning_node_selector_args(args)
+     )
+     if return_code != 0:
+       xpk_exit(return_code)
+
+   storages: list[Storage] = get_storages_to_mount(k8s_api_client, args.storage)
+   gcs_fuse_storages = list(
+       filter(lambda storage: storage.type == GCS_FUSE_TYPE, storages)
+   )
+   gcpfilestore_storages: list[Storage] = list(
+       filter(lambda storage: storage.type == GCP_FILESTORE_TYPE, storages)
+   )
+   storage_annotations = ''
+   service_account = ''
+   if len(gcs_fuse_storages) > 0:
+     storage_annotations = GCS_FUSE_ANNOTATION
+     service_account = XPK_SA
+     xpk_print(f'Detected gcsfuse Storages to add: {gcs_fuse_storages}')
+   else:
+     xpk_print('No gcsfuse Storages to add detected')
+   failure_policy_rules = """rules:
+     - action: FailJobSet
+       onJobFailureReasons:
+       - PodFailurePolicy"""
+   restart_on_exit_codes = get_restart_exit_codes(args)
+   restart_on_exit_codes = ','.join(map(str, restart_on_exit_codes))
+   pod_failure_policy = f"""
+       podFailurePolicy:
+         rules:
+         - action: FailJob
+           onExitCodes:
+             containerName: {get_main_container_docker_image(args, system)}
+             operator: NotIn
+             values: [{restart_on_exit_codes}]"""
+
+   if len(gcpfilestore_storages) > 0:
+     xpk_print(
+         f'Detected gcp filestores instances to add: {gcpfilestore_storages}'
+     )
+     service_account = XPK_SA
+   else:
+     xpk_print('No gcp filestore instances to add detected.')
+   all_storages = gcs_fuse_storages + gcpfilestore_storages
+   # Create the workload file based on accelerator type or workload type.
+   if system.accelerator_type == AcceleratorType['GPU']:
+     container, debugging_dashboard_id = get_user_workload_container(
+         args, system
+     )
+     gpu_scheduler, return_code = get_gpu_scheduler(
+         args, system, autoprovisioning_args
+     )
+     if return_code != 0:
+       xpk_exit(return_code)
+
+     if system.device_type in cluster_gcluster.supported_device_types:
+       yml_string = A3_GPU_WORKLOAD_CREATE_YAML.format(
+           args=args,
+           container=container,
+           service_account=XPK_SA,
+           failure_policy_rules=failure_policy_rules,
+           pod_failure_policy=pod_failure_policy,
+       )
+
+       if args.device_type == cluster_gcluster.a3mega_device_type:
+         sub_networks = get_subnetworks_for_a3mega(args.cluster)
+         yml_string = tcpxo_decorator.decorate_jobset(yml_string, sub_networks)
+
+       if args.device_type == cluster_gcluster.a3ultra_device_type:
+         sub_networks = get_subnetworks_for_a3ultra(args.cluster)
+         yml_string = rdma_decorator.decorate_jobset(yml_string, sub_networks)
+
+       if len(gcs_fuse_storages) + len(gcpfilestore_storages) > 0:
+         yml_string = storage_decorator.decorate_jobset(yml_string, all_storages)
+     else:
+       yml_string = GPU_WORKLOAD_CREATE_YAML.format(
+           args=args,
+           container=container,
+           command=args.command,
+           chips_per_vm=system.chips_per_vm,
+           gpu_scheduler=gpu_scheduler,
+           gpu_volume=get_gpu_volume(system),
+           gpu_rxdm_image=get_gpu_rxdm_image(system),
+           gpu_rxdm_cmd=get_gpu_rxdm_cmd(system),
+           gpu_tcp_volume=get_gpu_tcp_volume(system),
+           storage_volumes=get_storage_volumes_yaml_for_gpu(all_storages),
+           storage_volume_mounts=get_storage_volume_mounts_yaml_for_gpu(
+               all_storages
+           ),
+           storage_annotations=storage_annotations,
+           service_account=service_account,
+           failure_policy_rules=failure_policy_rules,
+           pod_failure_policy=pod_failure_policy,
+       )
+
+   elif args.use_pathways and ensure_pathways_workload_prerequisites(
+       args, system
+   ):
+     yml_string = PW_WORKLOAD_CREATE_YAML.format(
+         args=args,
+         system=system,
+         accelerator_label=create_accelerator_label(
+             system.accelerator_type, system
+         ),
+         machine_label=create_machine_label(system.accelerator_type, system),
+         pathways_worker_args=get_pathways_worker_args(args),
+         pathways_proxy_args=get_pathways_proxy_args(args),
+         pathways_sidecar_container=get_pathways_sidecar_container(args),
+         user_workload=get_user_workload_for_pathways(
+             args, system, pod_failure_policy, storages
+         ),
+         resource_type=AcceleratorTypeToAcceleratorCharacteristics[
+             system.accelerator_type
+         ].resource_type,
+         local_queue_name=LOCAL_QUEUE_NAME,
+         autoprovisioning_args=autoprovisioning_args,
+         backoff_limit=system.vms_per_slice * 4,
+         storage_annotations=storage_annotations,
+         storage_volumes=get_storage_volumes_yaml(all_storages),
+         storage_volume_mounts=get_storage_volume_mounts_yaml(all_storages),
+         pathways_rm_args=get_pathways_rm_args(args, system),
+         service_account=service_account,
+         failure_policy_rules=failure_policy_rules,
+         pod_failure_policy=pod_failure_policy,
+     )
+   else:
+     container, debugging_dashboard_id = get_user_workload_container(
+         args, system
+     )
+     yml_string = WORKLOAD_CREATE_YAML.format(
+         args=args,
+         system=system,
+         container=container,
+         affinity=get_cpu_affinity(system.accelerator_type),
+         accelerator_label=create_accelerator_label(
+             system.accelerator_type, system
+         ),
+         machine_label=create_machine_label(system.accelerator_type, system),
+         local_queue_name=LOCAL_QUEUE_NAME,
+         autoprovisioning_args=autoprovisioning_args,
+         volumes=get_volumes(args, system),
+         storage_annotations=storage_annotations,
+         service_account=service_account,
+         failure_policy_rules=failure_policy_rules,
+         pod_failure_policy=pod_failure_policy,
+     )
+   tmp = write_tmp_file(yml_string)
+   command = f'kubectl apply -f {str(tmp.file.name)}'
+   return_code = run_command_with_updates(command, 'Creating Workload', args)
+
+   if return_code != 0:
+     xpk_print(f'Create Workload request returned ERROR {return_code}')
+     xpk_exit(return_code)
+
+   add_bucket_iam_members(args, storages)
+   # Get GKE outlier dashboard for TPU
+   outlier_dashboard_id = None
+   if system.accelerator_type == AcceleratorType['TPU']:
+     outlier_dashboard_id = get_gke_outlier_dashboard(args)
+
+   # Outlier and debugging dashboards
+   if outlier_dashboard_id is not None:
+     xpk_print(
+         'Check statistics and outlier mode of GKE metrics here:'
+         # pylint: disable=line-too-long
+         f' https://console.cloud.google.com/monitoring/dashboards/builder/{outlier_dashboard_id}?project={args.project}&f.rlabel.cluster_name.ClusterName={args.cluster}.'
+         ' To view the metric data for your workload, select'
+         f' {args.workload} from the JobName filter on the dashboard.'
+     )
+
+   if debugging_dashboard_id is not None:
+     xpk_print(
+         'Check stack traces collected in Cloud Logging here:'
+         # pylint: disable=line-too-long
+         f' https://console.cloud.google.com/monitoring/dashboards/builder/{debugging_dashboard_id}?project={args.project}&f.rlabel.cluster_name.ClusterName={args.cluster}.'
+         ' To view the stack traces for your workload, select'
+         f' {args.workload} from the JobName filter on the dashboard.'
+     )
+
+   if args.use_pathways:
+     if args.headless:
+       xpk_print(
+           '******* Please use kubectl port forwarding to connect to the'
+           ' Pathways proxy, once you see "IFRT proxy server started with status'
+           ' OK" on the proxy link below. Remember to delete the workload once'
+           ' done! ******* '
+       )
+       xpk_print(
+           'Steps to connect to the proxy: kubectl get pods | grep proxy ;'
+           ' kubectl port-forward <proxy-pod-name> 29000:29000; '
+           ' JAX_PLATFORMS=proxy; JAX_BACKEND_TARGET=grpc://127.0.0.1:29000;'
+           " python -c 'import pathwaysutils; import jax; print(jax.devices())'"
+       )
+     pathways_proxy_link = f'https://console.cloud.google.com/kubernetes/job/{zone_to_region(args.zone)}/{args.cluster}/default/{args.workload}-proxy-0/details?project={args.project}'
+     xpk_print(
+         'Follow the proxy here:'
+         # pylint: disable=line-too-long
+         f' {pathways_proxy_link} '
+     )
+     xpk_print(
+         'Follow your Pathways workload and other resources here : '
+         f'{get_pathways_unified_query_link(args)}'
+     )
+   else:
+     xpk_print(
+         'Follow your workload here:'
+         # pylint: disable=line-too-long
+         f' https://console.cloud.google.com/kubernetes/service/{zone_to_region(args.zone)}/{args.cluster}/default/{args.workload}/details?project={args.project}'
+     )
+     duration_of_logs = 'P1D'  # Past 1 Day
+     xpk_print(
+         'Follow your worker 0, slice 0 logs here:'
+         ' Adjust the pod name'
+         ' ([prefix]-slice-job-[slice_number]-[worker_number])'
+         ' after clicking the url if you want other worker logs.'
+         # pylint: disable=line-too-long
+         f' https://console.cloud.google.com/logs/query;query=resource.type%3D%22k8s_container%22%0Aresource.labels.project_id%3D%22{args.project}%22%0Aresource.labels.location%3D%22{zone_to_region(args.zone)}%22%0Aresource.labels.cluster_name%3D%22{args.cluster}%22%0Aresource.labels.namespace_name%3D%22default%22%0Aresource.labels.pod_name:%22{args.workload}-slice-job-0-0-%22%20severity%3E%3DDEFAULT;storageScope=project;duration={duration_of_logs}?e=13802955&mods=allow_workbench_image_override&project={args.project}'
+     )
+
+   xpk_exit(0)
+
+
+ def get_restart_exit_codes(args) -> list:
+   exit_codes = [42]
+   exit_codes.extend(range(127, 256, 1))
+
+   if args.restart_on_exit_codes is not None:
+     items = args.restart_on_exit_codes.split(',')
+     for item in items:
+       item = item.strip()
+       if '-' in item:
+         start, end = map(int, item.split('-'))
+         exit_codes.extend(range(start, end + 1))
+       else:
+         exit_codes.append(int(item))
+
+   # Remove duplicates that the user may have added.
+   return list(set(exit_codes))
+
+
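+ # A worked example of get_restart_exit_codes (a sketch; SimpleNamespace
+ # stands in for the real parsed args). Defaults are 42 plus 127-255;
+ # user-supplied entries may be single codes or ranges, and the final order
+ # is not guaranteed because duplicates are removed through a set:
+ #
+ #   from types import SimpleNamespace
+ #   codes = get_restart_exit_codes(
+ #       SimpleNamespace(restart_on_exit_codes='1,3-5'))
+ #   assert {1, 3, 4, 5, 42, 255} <= set(codes)
+
+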
+ def workload_delete(args) -> None:
+   """Function around workload delete.
+
+   Args:
+     args: user provided arguments for running the command.
+
+   Returns:
+     0 if successful and 1 otherwise.
+   """
+   xpk_print('Starting Workload delete', flush=True)
+   add_zone_and_project(args)
+   get_cluster_credentials(args)
+
+   will_delete = True
+   if not args.workload:
+     xpk_print('Get the name of the workloads in the cluster.')
+     return_code, return_value = get_workload_list(args)
+
+     if return_code != 0:
+       xpk_print(f'List Job request returned ERROR {return_code}')
+       xpk_exit(return_code)
+     # Skip the header
+     workloads = [x.split(' ')[0] for x in return_value.splitlines()][1:]
+     if workloads and not args.force:
+       will_delete = get_user_input(
+           f'Planning to delete {len(workloads)} workloads in the cluster'
+           f' {args.cluster} including {workloads}. \nDo you wish to delete: y'
+           ' (yes) / n (no):\n'
+       )
+   else:
+     workloads = [args.workload]
+
+   if not workloads:
+     xpk_print(
+         'There are no workloads to delete matching the filter in the cluster.'
+     )
+   elif not will_delete:
+     xpk_print('Skipping delete command.')
+   else:
+     commands = []
+     task_names = []
+     for workload in workloads:
+       args.workload = workload
+       command = f'kubectl delete jobset {workload} -n default'
+       task_name = f'WorkloadDelete-{workload}'
+       commands.append(command)
+       task_names.append(task_name)
+
+     # Not batching deletion for single workload
+     if len(workloads) == 1:
+       return_code = run_command_with_updates(
+           commands[0], 'Delete Workload', args
+       )
+     else:
+       return_code = run_commands(
+           commands, 'Delete Workload', task_names, batch=100
+       )
+
+     if return_code != 0:
+       xpk_print(f'Delete Workload request returned ERROR {return_code}')
+       xpk_exit(return_code)
+   xpk_exit(0)
+
+
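+ # A sketch of the table parsing in workload_delete above: workload names are
+ # taken from the first column of the kubectl-style listing returned by
+ # get_workload_list, skipping the header row. The sample output is made up:
+ #
+ #   return_value = 'NAME    STATUS\njob-a   Running\njob-b   Failed\n'
+ #   workloads = [x.split(' ')[0] for x in return_value.splitlines()][1:]
+ #   # -> ['job-a', 'job-b']
+
+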
+ def workload_list(args) -> None:
+   """Function around workload list.
+
+   Args:
+     args: user provided arguments for running the command.
+
+   Returns:
+     0 if successful and 1 otherwise.
+   """
+   xpk_print(args)
+
+   xpk_print('Starting workload list', flush=True)
+   add_zone_and_project(args)
+   get_cluster_credentials(args)
+
+   if args.wait_for_job_completion:
+     return_code = wait_for_job_completion(args)
+     if return_code != 0:
+       xpk_print(f'Wait for job completion returned ERROR {return_code}')
+       xpk_exit(return_code)
+     args.filter_by_job = args.wait_for_job_completion
+
+   return_code, return_value = get_workload_list(args)
+
+   if return_code != 0:
+     xpk_print(f'List Job request returned ERROR {return_code}')
+     xpk_exit(return_code)
+   xpk_print(f'Workload List Output:\n{return_value}')
+   xpk_exit(0)