xpk-0.6.0-py3-none-any.whl → xpk-0.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. xpk/api/__init__.py +15 -0
  2. xpk/api/storage_crd.yaml +52 -0
  3. xpk/commands/batch.py +27 -5
  4. xpk/commands/cluster.py +104 -80
  5. xpk/commands/cluster_gcluster.py +94 -10
  6. xpk/commands/common.py +44 -0
  7. xpk/commands/config.py +29 -0
  8. xpk/commands/info.py +8 -10
  9. xpk/commands/inspector.py +5 -11
  10. xpk/commands/job.py +9 -7
  11. xpk/commands/kind.py +34 -4
  12. xpk/commands/kjob_common.py +44 -0
  13. xpk/commands/run.py +128 -0
  14. xpk/commands/shell.py +27 -7
  15. xpk/commands/storage.py +267 -0
  16. xpk/commands/version.py +6 -18
  17. xpk/commands/workload.py +381 -184
  18. xpk/core/blueprint/blueprint_definitions.py +1 -0
  19. xpk/core/blueprint/blueprint_generator.py +132 -76
  20. xpk/core/capacity.py +185 -0
  21. xpk/core/cluster.py +564 -0
  22. xpk/core/cluster_private.py +6 -3
  23. xpk/core/commands.py +18 -14
  24. xpk/core/config.py +179 -0
  25. xpk/core/docker_container.py +225 -0
  26. xpk/core/docker_image.py +210 -0
  27. xpk/core/docker_resources.py +350 -0
  28. xpk/core/filestore.py +251 -0
  29. xpk/core/gcloud_context.py +196 -0
  30. xpk/core/gcluster_manager.py +20 -2
  31. xpk/core/gcsfuse.py +50 -0
  32. xpk/core/kjob.py +257 -18
  33. xpk/core/kueue.py +12 -6
  34. xpk/core/monitoring.py +134 -0
  35. xpk/core/nap.py +32 -20
  36. xpk/core/network.py +377 -0
  37. xpk/core/nodepool.py +581 -0
  38. xpk/core/pathways.py +124 -45
  39. xpk/core/remote_state/__init__.py +15 -0
  40. xpk/core/remote_state/fuse_remote_state.py +99 -0
  41. xpk/core/remote_state/remote_state_client.py +38 -0
  42. xpk/core/resources.py +238 -0
  43. xpk/core/scheduling.py +253 -0
  44. xpk/core/storage.py +581 -0
  45. xpk/core/system_characteristics.py +38 -1
  46. xpk/core/vertex.py +105 -0
  47. xpk/core/workload.py +209 -1
  48. xpk/core/workload_decorators/rdma_decorator.py +25 -5
  49. xpk/core/workload_decorators/storage_decorator.py +52 -0
  50. xpk/core/workload_decorators/tcpxo_decorator.py +70 -37
  51. xpk/main.py +3 -1
  52. xpk/parser/batch.py +10 -151
  53. xpk/parser/cluster.py +49 -8
  54. xpk/parser/common.py +189 -1
  55. xpk/parser/config.py +49 -0
  56. xpk/parser/core.py +27 -1
  57. xpk/parser/info.py +2 -1
  58. xpk/parser/inspector.py +3 -3
  59. xpk/parser/job.py +25 -4
  60. xpk/parser/kind.py +3 -2
  61. xpk/parser/run.py +47 -0
  62. xpk/parser/shell.py +10 -1
  63. xpk/parser/storage.py +316 -0
  64. xpk/parser/validators.py +3 -3
  65. xpk/parser/workload.py +118 -76
  66. xpk/templates/__init__.py +15 -0
  67. xpk/templates/storage.yaml +13 -0
  68. xpk/utils/gcs_utils.py +125 -0
  69. xpk/utils/kubectl.py +57 -0
  70. xpk/utils/objects.py +8 -5
  71. xpk/utils/templates.py +28 -0
  72. xpk/utils/validation.py +80 -0
  73. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/METADATA +165 -14
  74. xpk-0.7.0.dist-info/RECORD +92 -0
  75. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/WHEEL +1 -1
  76. xpk/core/core.py +0 -2824
  77. xpk-0.6.0.dist-info/RECORD +0 -57
  78. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/LICENSE +0 -0
  79. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/entry_points.txt +0 -0
  80. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/top_level.txt +0 -0
xpk/commands/workload.py CHANGED
@@ -14,36 +14,24 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """
 
- from ..core.commands import (
- run_command_with_updates,
- run_commands,
+ from ..core.blueprint.blueprint_generator import get_subnetworks_for_a3mega, get_subnetworks_for_a3ultra
+ from ..core.cluster import (
+ create_xpk_k8s_service_account,
+ get_cluster_credentials,
+ setup_k8s_env,
+ XPK_SA,
  )
- from ..core.core import (
- CLUSTER_METADATA_CONFIGMAP,
- VERTEX_TENSORBOARD_FEATURE_FLAG,
- AcceleratorTypeToAcceleratorCharacteristics,
- add_zone_and_project,
- check_if_workload_can_schedule,
- check_if_workload_exists,
- create_accelerator_label,
- create_machine_label,
- create_vertex_experiment,
- get_cluster_configmap,
- get_cpu_affinity,
- get_gke_outlier_dashboard,
- get_gpu_rxdm_cmd,
- get_gpu_rxdm_image,
- get_gpu_scheduler,
- get_gpu_tcp_volume,
- get_gpu_volume,
+ from ..core.commands import run_command_with_updates, run_commands
+ from ..core.config import VERTEX_TENSORBOARD_FEATURE_FLAG, XPK_CURRENT_VERSION, parse_env_config
+ from ..core.docker_container import (
+ get_main_container_docker_image,
  get_user_workload_container,
- get_volumes,
- parse_env_config,
- wait_for_job_completion,
- xpk_current_version,
- zone_to_region,
  )
+
+ from ..core.docker_resources import get_volumes
+ from ..core.gcloud_context import add_zone_and_project
  from ..core.kueue import LOCAL_QUEUE_NAME
+ from ..core.monitoring import get_gke_outlier_dashboard
  from ..core.nap import (
  get_autoprovisioning_node_selector_args,
  is_autoprovisioning_enabled,
@@ -52,22 +40,53 @@ from ..core.pathways import (
  ensure_pathways_workload_prerequisites,
  get_pathways_proxy_args,
  get_pathways_rm_args,
+ get_pathways_sidecar_container,
  get_pathways_unified_query_link,
  get_pathways_worker_args,
  get_user_workload_for_pathways,
  )
+ from ..core.resources import CLUSTER_METADATA_CONFIGMAP, get_cluster_configmap
+ from ..core.scheduling import (
+ check_if_workload_can_schedule,
+ create_accelerator_label,
+ create_machine_label,
+ get_cpu_affinity,
+ get_gpu_scheduler,
+ )
+ from ..core.storage import (
+ GCS_FUSE_TYPE,
+ GCP_FILESTORE_TYPE,
+ Storage,
+ add_bucket_iam_members,
+ get_storage_volume_mounts_yaml,
+ get_storage_volumes_yaml,
+ get_storages_to_mount,
+ get_storage_volume_mounts_yaml_for_gpu,
+ get_storage_volumes_yaml_for_gpu,
+ GCS_FUSE_ANNOTATION,
+ )
  from ..core.system_characteristics import (
  AcceleratorType,
+ AcceleratorTypeToAcceleratorCharacteristics,
  get_system_characteristics,
  )
- from ..core.workload import get_workload_list
+ from ..core.vertex import create_vertex_experiment
+ from ..core.workload import (
+ check_if_workload_exists,
+ get_gpu_rxdm_cmd,
+ get_gpu_rxdm_image,
+ get_gpu_tcp_volume,
+ get_gpu_volume,
+ get_workload_list,
+ wait_for_job_completion,
+ zone_to_region,
+ )
+ from ..core.workload_decorators import rdma_decorator, tcpxo_decorator, storage_decorator
  from ..utils.console import get_user_input, xpk_exit, xpk_print
  from ..utils.file import write_tmp_file
- from .cluster import set_cluster_command
- from ..core.workload_decorators import tcpxo_decorator, rdma_decorator
  from . import cluster_gcluster
 
- workload_create_yaml = """apiVersion: jobset.x-k8s.io/v1alpha2
+ WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
  kind: JobSet
  metadata:
  name: {args.workload}
@@ -79,6 +98,7 @@ metadata:
  spec:
  ttlSecondsAfterFinished: {args.ttl_seconds_after_finished}
  failurePolicy:
+ {failure_policy_rules}
  maxRestarts: {args.max_restarts}
  replicatedJobs:
  - name: slice-job
@@ -88,10 +108,13 @@ spec:
  parallelism: {system.vms_per_slice} # Equal to the number of VMs per slice
  completions: {system.vms_per_slice} # Same as the above.
  backoffLimit: 0 # When any pod fails, the job is failed
+ {pod_failure_policy}
  template:
  metadata:
  labels:
  xpk.google.com/workload: {args.workload}
+ annotations:
+ {storage_annotations}
  spec:
  schedulerName: {args.scheduler}
  restartPolicy: Never
@@ -106,30 +129,37 @@ spec:
  terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
  containers:
  {container}
+ serviceAccountName: {service_account}
  volumes:
  {volumes}
  """
 
 
- gpu_workload_create_yaml = """apiVersion: jobset.x-k8s.io/v1alpha2
+ GPU_WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
  kind: JobSet
  metadata:
  name: {args.workload}
+ annotations: {storage_annotations}
  labels:
  kueue.x-k8s.io/queue-name: multislice-queue # Name of the LocalQueue
  xpk.google.com/workload: {args.workload}
  spec:
  ttlSecondsAfterFinished: {args.ttl_seconds_after_finished}
  failurePolicy:
+ {failure_policy_rules}
  maxRestarts: {args.max_restarts}
  replicatedJobs:
  - name: slice-job
  replicas: 1
  template:
+ metadata:
+ annotations:
+ {storage_annotations}
  spec:
  parallelism: {args.num_nodes}
  completions: {args.num_nodes}
  backoffLimit: 0 # When any pod fails, the job is failed
+ {pod_failure_policy}
  template:
  metadata:
  labels:
@@ -141,11 +171,13 @@ spec:
  hostNetwork: true
  dnsPolicy: ClusterFirstWithHostNet
  terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
+ serviceAccountName: {service_account}
  tolerations:
  - operator: "Exists"
  key: nvidia.com/gpu
  volumes:
  {gpu_volume}
+ {storage_volumes}
  containers:
  {gpu_rxdm_image}
  imagePullPolicy: Always
@@ -159,6 +191,7 @@ spec:
  privileged: true
  volumeMounts:
  {gpu_tcp_volume}
+ {storage_volume_mounts}
  - name: nvidia-install-dir-host
  mountPath: /usr/local/nvidia/lib64
  - name: workload-terminated-volume
@@ -169,7 +202,7 @@ spec:
  {container}
  """
 
- a3_gpu_workload_create_yaml = """apiVersion: jobset.x-k8s.io/v1alpha2
+ A3_GPU_WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
  kind: JobSet
  metadata:
  name: {args.workload}
@@ -179,6 +212,7 @@ metadata:
  spec:
  ttlSecondsAfterFinished: {args.ttl_seconds_after_finished}
  failurePolicy:
+ {failure_policy_rules}
  maxRestarts: {args.max_restarts}
  replicatedJobs:
  - name: slice-job
@@ -188,6 +222,7 @@ spec:
  parallelism: {args.num_nodes}
  completions: {args.num_nodes}
  backoffLimit: 0 # When any pod fails, the job is failed
+ {pod_failure_policy}
  template:
  metadata:
  labels:
@@ -199,6 +234,7 @@ spec:
  restartPolicy: Never
  dnsPolicy: ClusterFirstWithHostNet
  terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
+ serviceAccountName: {service_account}
  tolerations:
  - operator: "Exists"
  key: nvidia.com/gpu
@@ -206,7 +242,7 @@ spec:
  {container}
  """
 
- pw_workload_create_yaml = """apiVersion: jobset.x-k8s.io/v1alpha2
+ PW_WORKLOAD_CREATE_YAML = """apiVersion: jobset.x-k8s.io/v1alpha2
  kind: JobSet
  metadata:
  name: {args.workload}
@@ -216,129 +252,208 @@ metadata:
  spec:
  ttlSecondsAfterFinished: {args.ttl_seconds_after_finished}
  failurePolicy:
+ {failure_policy_rules}
  maxRestarts: {args.max_restarts}
  successPolicy:
  operator: "All"
  targetReplicatedJobs:
  - {args.targetReplicatedJob}
  replicatedJobs:
- - name: worker
- replicas: {args.num_slices}
- template:
- metadata:
- annotations:
- alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
- labels:
- xpk.google.com/workload: {args.workload}
- spec:
- backoffLimit: {backoff_limit}
- completions: {system.vms_per_slice}
- parallelism: {system.vms_per_slice}
- template:
- spec:
- terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
- containers:
- - args:
- {pathways_worker_args}
- image: {args.server_image}
- imagePullPolicy: Always
- name: pathways-worker
- ports:
- - containerPort: 29001
- - containerPort: 8471
- - containerPort: 8080
- resources:
- limits:
- {resource_type}: {system.chips_per_vm}
- securityContext:
- privileged: true
- volumeMounts:
- - mountPath: /tmp
+ - name: worker
+ replicas: {args.num_slices}
+ template:
+ metadata:
+ annotations:
+ alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
+ labels:
+ xpk.google.com/workload: {args.workload}
+ spec:
+ backoffLimit: {backoff_limit}
+ completions: {system.vms_per_slice}
+ parallelism: {system.vms_per_slice}
+ template:
+ metadata:
+ annotations:
+ {storage_annotations}
+ spec:
+ terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
+ serviceAccountName: {service_account}
+ containers:
+ - args:
+ {pathways_worker_args}
+ image: {args.server_image}
+ imagePullPolicy: Always
+ name: pathways-worker
+ ports:
+ - containerPort: 29001
+ - containerPort: 8471
+ - containerPort: 8080
+ resources:
+ limits:
+ {resource_type}: {system.chips_per_vm}
+ securityContext:
+ privileged: true
+ volumeMounts:
+ - mountPath: /tmp
+ name: shared-tmp
+ {storage_volume_mounts}
+ env:
+ - name: PROJECT_ID
+ value: {args.project}
+ - name: LOCATION
+ value: {args.zone}
+ - name: CLUSTER_NAME
+ value: {args.cluster}
+ - name: POD_NAME
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.name
+ - name: CONTAINER_NAME
+ value: "pathways-worker"
+ - name: NAMESPACE
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.namespace
+ # Workaround for v6e
+ - name: MEGASCALE_GRPC_ENABLE_XOR_TRACER
+ value: "false"
+ - name: MEGASCALE_NUM_SLICES
+ valueFrom:
+ fieldRef:
+ fieldPath: "metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']"
+ - name: JOBSET_NAME
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
+ - name: REPLICATED_JOB_NAME
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
+ - name: MEGASCALE_SLICE_ID
+ valueFrom:
+ fieldRef:
+ fieldPath: "metadata.labels['jobset.sigs.k8s.io/job-index']"
+ - name: MEGASCALE_COORDINATOR_ADDRESS
+ value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-$(MEGASCALE_SLICE_ID)-0.$(JOBSET_NAME)"
+ {pathways_sidecar_container}
+ nodeSelector:
+ {accelerator_label}
+ {machine_label}
+ {autoprovisioning_args}
+ priorityClassName: {args.priority}
+ hostNetwork: true
+ dnsPolicy: ClusterFirstWithHostNet
+ volumes:
+ - hostPath:
+ path: /tmp
+ type: DirectoryOrCreate
  name: shared-tmp
- nodeSelector:
- {accelerator_label}
- {machine_label}
- {autoprovisioning_args}
- priorityClassName: {args.priority}
- hostNetwork: true
- dnsPolicy: ClusterFirstWithHostNet
- volumes:
- - hostPath:
- path: /tmp
- type: DirectoryOrCreate
- name: shared-tmp
- - name: rm
- replicas: 1
- template:
- metadata:
- labels:
- xpk.google.com/workload: {args.workload}
- spec:
- backoffLimit: 0
- completions: 1
- parallelism: 1
- template:
- spec:
- containers:
- - args:
- {pathways_rm_args}
- env:
- - name: REPLICATED_JOB_NAME
- valueFrom:
- fieldRef:
- fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
- - name: JOBSET_NAME
- valueFrom:
- fieldRef:
- fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
- - name: HOST_ADDRESS
- value: $(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)
- - name: TPU_SKIP_MDS_QUERY
- value: "true"
- image: {args.server_image}
- imagePullPolicy: Always
- name: pathways-rm
- ports:
- - containerPort: 29001
- securityContext:
- privileged: true
- volumeMounts:
- - mountPath: /tmp
+ {storage_volumes}
+ - name: rm
+ replicas: 1
+ template:
+ metadata:
+ labels:
+ xpk.google.com/workload: {args.workload}
+ spec:
+ backoffLimit: 0
+ completions: 1
+ parallelism: 1
+ template:
+ spec:
+ containers:
+ - args:
+ {pathways_rm_args}
+ env:
+ - name: PROJECT_ID
+ value: {args.project}
+ - name: LOCATION
+ value: {args.zone}
+ - name: CLUSTER_NAME
+ value: {args.cluster}
+ - name: POD_NAME
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.name
+ - name: CONTAINER_NAME
+ value: "pathways-rm"
+ - name: NAMESPACE
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.namespace
+ - name: REPLICATED_JOB_NAME
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
+ - name: JOBSET_NAME
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
+ - name: HOST_ADDRESS
+ value: $(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)
+ - name: TPU_SKIP_MDS_QUERY
+ value: "true"
+ image: {args.server_image}
+ imagePullPolicy: Always
+ name: pathways-rm
+ ports:
+ - containerPort: 29001
+ securityContext:
+ privileged: true
+ volumeMounts:
+ - mountPath: /tmp
+ name: shared-tmp
+ nodeSelector:
+ cloud.google.com/gke-nodepool: cpu-rm-np
+ hostNetwork: true
+ dnsPolicy: ClusterFirstWithHostNet
+ volumes:
+ - hostPath:
+ path: /tmp
+ type: DirectoryOrCreate
  name: shared-tmp
- nodeSelector:
- cloud.google.com/gke-nodepool: cpu-rm-np
- hostNetwork: true
- dnsPolicy: ClusterFirstWithHostNet
- volumes:
- - hostPath:
- path: /tmp
- type: DirectoryOrCreate
- name: shared-tmp
- - name: proxy
- replicas: 1
- template:
- metadata:
- labels:
- xpk.google.com/workload: {args.workload}
- spec:
- backoffLimit: 0
- completions: 1
- parallelism: 1
- template:
- spec:
- containers:
- - args:
- {pathways_proxy_args}
- image: {args.proxy_server_image}
- imagePullPolicy: Always
- name: pathways-proxy
- ports:
- - containerPort: 29000
- hostNetwork: true
- dnsPolicy: ClusterFirstWithHostNet
- nodeSelector:
- cloud.google.com/gke-nodepool: cpu-proxy-np
- {user_workload}
+ - name: proxy
+ replicas: 1
+ template:
+ metadata:
+ labels:
+ xpk.google.com/workload: {args.workload}
+ spec:
+ backoffLimit: 0
+ completions: 1
+ parallelism: 1
+ template:
+ spec:
+ containers:
+ - args:
+ {pathways_proxy_args}
+ env:
+ - name: PROJECT_ID
+ value: {args.project}
+ - name: LOCATION
+ value: {args.zone}
+ - name: CLUSTER_NAME
+ value: {args.cluster}
+ - name: POD_NAME
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.name
+ - name: CONTAINER_NAME
+ value: "pathways-proxy"
+ - name: NAMESPACE
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.namespace
+ image: {args.proxy_server_image}
+ imagePullPolicy: Always
+ name: pathways-proxy
+ ports:
+ - containerPort: 29000
+ hostNetwork: true
+ dnsPolicy: ClusterFirstWithHostNet
+ nodeSelector:
+ cloud.google.com/gke-nodepool: cpu-proxy-np
+ {user_workload}
  """
 
 
@@ -352,6 +467,13 @@ def workload_create_pathways(args) -> None:
  0 if successful and 1 otherwise.
  """
  args.use_pathways = True
+ if args.headless:
+ xpk_print(
+ 'Please use kubectl port forwarding to connect to the Pathways proxy.'
+ ' kubectl get pods kubectl port-forward <proxy-pod-name> 29000:29000'
+ ' JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 python'
+ " -c 'import pathwaysutils; import jax; print(jax.devices())'"
+ )
  workload_create(args)
 
 
@@ -364,19 +486,8 @@ def workload_create(args) -> None:
  Returns:
  0 if successful and 1 otherwise.
  """
- add_zone_and_project(args)
-
- if args.headless:
- xpk_print(
- 'Please use kubectl port forwarding to connect to the Pathways proxy.'
- ' kubectl get pods kubectl port-forward <proxy-pod-name> 29000:29000'
- ' JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 python'
- " -c 'import pathwaysutils; import jax; print(jax.devices())'"
- )
-
- set_cluster_command_code = set_cluster_command(args)
- if set_cluster_command_code != 0:
- xpk_exit(set_cluster_command_code)
+ k8s_api_client = setup_k8s_env(args)
+ create_xpk_k8s_service_account()
 
  workload_exists = check_if_workload_exists(args)
 
@@ -412,12 +523,12 @@ def workload_create(args) -> None:
  cluster_xpk_version = cluster_config_map.get('xpk_version')
  if (
  cluster_xpk_version is not None
- and cluster_xpk_version != xpk_current_version
+ and cluster_xpk_version != XPK_CURRENT_VERSION
  ):
  xpk_print(
  'Warning: Cluster has been created using XPK version:'
  f' {cluster_config_map["xpk_version"]} but the XPK version you are'
- f' using to schedule workload is: {xpk_current_version}. Some features'
+ f' using to schedule workload is: {XPK_CURRENT_VERSION}. Some features'
  ' might not be available for this cluster. We recommend to'
  ' upgrade/downgrade your XPK version or cluster by running `xpk'
  ' cluster create`.'
@@ -449,6 +560,44 @@ def workload_create(args) -> None:
  if return_code != 0:
  xpk_exit(return_code)
 
+ storages: list[Storage] = get_storages_to_mount(k8s_api_client, args.storage)
+ gcs_fuse_storages = list(
+ filter(lambda storage: storage.type == GCS_FUSE_TYPE, storages)
+ )
+ gcpfilestore_storages: list[Storage] = list(
+ filter(lambda storage: storage.type == GCP_FILESTORE_TYPE, storages)
+ )
+ storage_annotations = ''
+ service_account = ''
+ if len(gcs_fuse_storages) > 0:
+ storage_annotations = GCS_FUSE_ANNOTATION
+ service_account = XPK_SA
+ xpk_print(f'Detected gcsfuse Storages to add: {gcs_fuse_storages}')
+ else:
+ xpk_print('No gcsfuse Storages to add detected')
+ failure_policy_rules = """rules:
+ - action: FailJobSet
+ onJobFailureReasons:
+ - PodFailurePolicy"""
+ restart_on_exit_codes = get_restart_exit_codes(args)
+ restart_on_exit_codes = ','.join(map(str, restart_on_exit_codes))
+ pod_failure_policy = f"""
+ podFailurePolicy:
+ rules:
+ - action: FailJob
+ onExitCodes:
+ containerName: {get_main_container_docker_image(args, system)}
+ operator: NotIn
+ values: [{restart_on_exit_codes}]"""
+
+ if len(gcpfilestore_storages) > 0:
+ xpk_print(
+ f'Detected gcp filestores instances to add: {gcpfilestore_storages}'
+ )
+ service_account = XPK_SA
+ else:
+ xpk_print('No gcp filestore instances to add detected.')
+ all_storages = gcs_fuse_storages + gcpfilestore_storages
  # Create the workload file based on accelerator type or workload type.
  if system.accelerator_type == AcceleratorType['GPU']:
  container, debugging_dashboard_id = get_user_workload_container(
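
For readers skimming the hunk above: the new block builds a JobSet-level failure policy plus a Kubernetes podFailurePolicy whose NotIn exit-code rule fails the Job outright for unexpected exit codes, while the codes returned by get_restart_exit_codes stay eligible for restart via maxRestarts. A standalone sketch (not xpk's own code path) of how those strings render; the container name and the short exit-code list are hypothetical stand-ins for get_main_container_docker_image(args, system) and the full default set, and the indentation here is only illustrative:

# Sketch: render the failure-policy snippets with stand-in values.
main_container = 'tpu-workload-container'  # hypothetical stand-in
restart_on_exit_codes = ','.join(map(str, [42, 127, 255]))  # subset of the defaults

failure_policy_rules = """rules:
  - action: FailJobSet
    onJobFailureReasons:
    - PodFailurePolicy"""

pod_failure_policy = f"""
podFailurePolicy:
  rules:
  - action: FailJob
    onExitCodes:
      containerName: {main_container}
      operator: NotIn
      values: [{restart_on_exit_codes}]"""

print(failure_policy_rules)
print(pod_failure_policy)
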
@@ -461,21 +610,26 @@ def workload_create(args) -> None:
  xpk_exit(return_code)
 
  if system.device_type in cluster_gcluster.supported_device_types:
- yml_string = a3_gpu_workload_create_yaml.format(
- args=args, container=container
+ yml_string = A3_GPU_WORKLOAD_CREATE_YAML.format(
+ args=args,
+ container=container,
+ service_account=XPK_SA,
+ failure_policy_rules=failure_policy_rules,
+ pod_failure_policy=pod_failure_policy,
  )
 
  if args.device_type == cluster_gcluster.a3mega_device_type:
- sub_networks = [f'{args.cluster}-gpunet-{i}-subnet' for i in range(8)]
+ sub_networks = get_subnetworks_for_a3mega(args.cluster)
  yml_string = tcpxo_decorator.decorate_jobset(yml_string, sub_networks)
 
  if args.device_type == cluster_gcluster.a3ultra_device_type:
- sub_networks = [f'{args.cluster}-sub-1'] + [
- f'{args.cluster}-rdma-sub-{i}' for i in range(8)
- ]
+ sub_networks = get_subnetworks_for_a3ultra(args.cluster)
  yml_string = rdma_decorator.decorate_jobset(yml_string, sub_networks)
+
+ if len(gcs_fuse_storages) + len(gcpfilestore_storages) > 0:
+ yml_string = storage_decorator.decorate_jobset(yml_string, all_storages)
  else:
- yml_string = gpu_workload_create_yaml.format(
+ yml_string = GPU_WORKLOAD_CREATE_YAML.format(
  args=args,
  container=container,
  command=args.command,
@@ -485,33 +639,51 @@ def workload_create(args) -> None:
  gpu_rxdm_image=get_gpu_rxdm_image(system),
  gpu_rxdm_cmd=get_gpu_rxdm_cmd(system),
  gpu_tcp_volume=get_gpu_tcp_volume(system),
+ storage_volumes=get_storage_volumes_yaml_for_gpu(all_storages),
+ storage_volume_mounts=get_storage_volume_mounts_yaml_for_gpu(
+ all_storages
+ ),
+ storage_annotations=storage_annotations,
+ service_account=service_account,
+ failure_policy_rules=failure_policy_rules,
+ pod_failure_policy=pod_failure_policy,
  )
+
  elif args.use_pathways and ensure_pathways_workload_prerequisites(
  args, system
  ):
- yml_string = pw_workload_create_yaml.format(
+ yml_string = PW_WORKLOAD_CREATE_YAML.format(
  args=args,
  system=system,
  accelerator_label=create_accelerator_label(
  system.accelerator_type, system
  ),
  machine_label=create_machine_label(system.accelerator_type, system),
- pathways_rm_args=get_pathways_rm_args(args, system),
  pathways_worker_args=get_pathways_worker_args(args),
  pathways_proxy_args=get_pathways_proxy_args(args),
- user_workload=get_user_workload_for_pathways(args, system),
+ pathways_sidecar_container=get_pathways_sidecar_container(args),
+ user_workload=get_user_workload_for_pathways(
+ args, system, pod_failure_policy, storages
+ ),
  resource_type=AcceleratorTypeToAcceleratorCharacteristics[
  system.accelerator_type
  ].resource_type,
  local_queue_name=LOCAL_QUEUE_NAME,
  autoprovisioning_args=autoprovisioning_args,
  backoff_limit=system.vms_per_slice * 4,
+ storage_annotations=storage_annotations,
+ storage_volumes=get_storage_volumes_yaml(all_storages),
+ storage_volume_mounts=get_storage_volume_mounts_yaml(all_storages),
+ pathways_rm_args=get_pathways_rm_args(args, system),
+ service_account=service_account,
+ failure_policy_rules=failure_policy_rules,
+ pod_failure_policy=pod_failure_policy,
  )
  else:
  container, debugging_dashboard_id = get_user_workload_container(
  args, system
  )
- yml_string = workload_create_yaml.format(
+ yml_string = WORKLOAD_CREATE_YAML.format(
  args=args,
  system=system,
  container=container,
@@ -523,6 +695,10 @@ def workload_create(args) -> None:
  local_queue_name=LOCAL_QUEUE_NAME,
  autoprovisioning_args=autoprovisioning_args,
  volumes=get_volumes(args, system),
+ storage_annotations=storage_annotations,
+ service_account=service_account,
+ failure_policy_rules=failure_policy_rules,
+ pod_failure_policy=pod_failure_policy,
  )
  tmp = write_tmp_file(yml_string)
  command = f'kubectl apply -f {str(tmp.file.name)}'
@@ -532,6 +708,7 @@ def workload_create(args) -> None:
  xpk_print(f'Create Workload request returned ERROR {return_code}')
  xpk_exit(return_code)
 
+ add_bucket_iam_members(args, storages)
  # Get GKE outlier dashboard for TPU
  outlier_dashboard_id = None
  if system.accelerator_type == AcceleratorType['TPU']:
@@ -559,10 +736,16 @@ def workload_create(args) -> None:
  if args.use_pathways:
  if args.headless:
  xpk_print(
- ' \n ******* Please connect to your Pathways proxy at'
- f' {args.pathways_proxy_address}, once you see "IFRT proxy server'
- ' started with status OK" on the proxy link below.'
- ' Remember to delete the workload once done! ****** \n'
+ '******* Please use kubectl port forwarding to connect to the'
+ ' Pathways proxy, once you see "IFRT proxy server started with status'
+ ' OK" on the proxy link below. Remember to delete the workload once'
+ ' done! ******* '
+ )
+ xpk_print(
+ 'Steps to connect to the proxy: kubectl get pods | grep proxy ;'
+ ' kubectl port-forward <proxy-pod-name> 29000:29000; '
+ ' JAX_PLATFORMS=proxy; JAX_BACKEND_TARGET=grpc://127.0.0.1:29000;'
+ " python -c 'import pathwaysutils; import jax; print(jax.devices())'"
  )
  pathways_proxy_link = f'https://console.cloud.google.com/kubernetes/job/{zone_to_region(args.zone)}/{args.cluster}/default/{args.workload}-proxy-0/details?project={args.project}'
  xpk_print(
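
The headless branch in the hunk above drops the old pathways_proxy_address hint in favor of kubectl port-forwarding instructions. A minimal sketch of the client side those messages describe, assuming `kubectl port-forward <proxy-pod-name> 29000:29000` is already running and that jax and pathwaysutils are installed locally:

# Client-side connection sketch for a headless Pathways workload.
import os

# Point JAX at the port-forwarded proxy before importing jax.
os.environ['JAX_PLATFORMS'] = 'proxy'
os.environ['JAX_BACKEND_TARGET'] = 'grpc://127.0.0.1:29000'

import pathwaysutils  # noqa: F401  # imported before jax, as the xpk message instructs
import jax

print(jax.devices())  # should list the TPU devices served behind the proxy
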
@@ -593,6 +776,24 @@ def workload_create(args) -> None:
  xpk_exit(0)
 
 
+ def get_restart_exit_codes(args) -> list:
+ exit_codes = [42]
+ exit_codes.extend(range(127, 256, 1))
+
+ if args.restart_on_exit_codes is not None:
+ items = args.restart_on_exit_codes.split(',')
+ for item in items:
+ item = item.strip()
+ if '-' in item:
+ start, end = map(int, item.split('-'))
+ exit_codes.extend(range(start, end + 1))
+ else:
+ exit_codes.append(int(item))
+
+ # Remove duplicates that the user may have added.
+ return list(set(exit_codes))
+
+
  def workload_delete(args) -> None:
  """Function around workload delete.
 
@@ -604,9 +805,7 @@ def workload_delete(args) -> None:
  """
  xpk_print('Starting Workload delete', flush=True)
  add_zone_and_project(args)
- set_cluster_command_code = set_cluster_command(args)
- if set_cluster_command_code != 0:
- xpk_exit(set_cluster_command_code)
+ get_cluster_credentials(args)
 
  will_delete = True
  if not args.workload:
@@ -672,9 +871,7 @@ def workload_list(args) -> None:
 
  xpk_print('Starting workload list', flush=True)
  add_zone_and_project(args)
- set_cluster_command_code = set_cluster_command(args)
- if set_cluster_command_code != 0:
- xpk_exit(set_cluster_command_code)
+ get_cluster_credentials(args)
 
  if args.wait_for_job_completion:
  return_code = wait_for_job_completion(args)
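
The new get_restart_exit_codes helper shown in the workload_create hunks defaults to exit code 42 plus 127-255 and merges in any --restart-on-exit-codes values, accepting single codes and dash-separated ranges. A small usage sketch, assuming xpk 0.7.0 and its dependencies are installed; the SimpleNamespace is a hypothetical stand-in for the parsed CLI args:

from types import SimpleNamespace

from xpk.commands.workload import get_restart_exit_codes

# Hypothetical flag value: a single code and a range are both accepted.
args = SimpleNamespace(restart_on_exit_codes='3, 8-10')
codes = set(get_restart_exit_codes(args))

assert {42, 127, 255}.issubset(codes)  # defaults: 42 and 127-255
assert {3, 8, 9, 10}.issubset(codes)   # user-supplied additions
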