xpk-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. xpk/__init__.py +15 -0
  2. xpk/api/__init__.py +15 -0
  3. xpk/api/storage_crd.yaml +52 -0
  4. xpk/commands/__init__.py +15 -0
  5. xpk/commands/batch.py +131 -0
  6. xpk/commands/cluster.py +808 -0
  7. xpk/commands/cluster_gcluster.py +269 -0
  8. xpk/commands/common.py +44 -0
  9. xpk/commands/config.py +29 -0
  10. xpk/commands/info.py +243 -0
  11. xpk/commands/inspector.py +357 -0
  12. xpk/commands/job.py +199 -0
  13. xpk/commands/kind.py +283 -0
  14. xpk/commands/kjob_common.py +44 -0
  15. xpk/commands/run.py +128 -0
  16. xpk/commands/shell.py +140 -0
  17. xpk/commands/storage.py +267 -0
  18. xpk/commands/version.py +27 -0
  19. xpk/commands/workload.py +889 -0
  20. xpk/core/__init__.py +15 -0
  21. xpk/core/blueprint/__init__.py +15 -0
  22. xpk/core/blueprint/blueprint_definitions.py +62 -0
  23. xpk/core/blueprint/blueprint_generator.py +708 -0
  24. xpk/core/capacity.py +185 -0
  25. xpk/core/cluster.py +564 -0
  26. xpk/core/cluster_private.py +200 -0
  27. xpk/core/commands.py +356 -0
  28. xpk/core/config.py +179 -0
  29. xpk/core/docker_container.py +225 -0
  30. xpk/core/docker_image.py +210 -0
  31. xpk/core/docker_manager.py +308 -0
  32. xpk/core/docker_resources.py +350 -0
  33. xpk/core/filestore.py +251 -0
  34. xpk/core/gcloud_context.py +196 -0
  35. xpk/core/gcluster_manager.py +176 -0
  36. xpk/core/gcsfuse.py +50 -0
  37. xpk/core/kjob.py +444 -0
  38. xpk/core/kueue.py +358 -0
  39. xpk/core/monitoring.py +134 -0
  40. xpk/core/nap.py +361 -0
  41. xpk/core/network.py +377 -0
  42. xpk/core/nodepool.py +581 -0
  43. xpk/core/pathways.py +377 -0
  44. xpk/core/ray.py +222 -0
  45. xpk/core/remote_state/__init__.py +15 -0
  46. xpk/core/remote_state/fuse_remote_state.py +99 -0
  47. xpk/core/remote_state/remote_state_client.py +38 -0
  48. xpk/core/resources.py +238 -0
  49. xpk/core/scheduling.py +253 -0
  50. xpk/core/storage.py +581 -0
  51. xpk/core/system_characteristics.py +1432 -0
  52. xpk/core/vertex.py +105 -0
  53. xpk/core/workload.py +341 -0
  54. xpk/core/workload_decorators/__init__.py +15 -0
  55. xpk/core/workload_decorators/rdma_decorator.py +129 -0
  56. xpk/core/workload_decorators/storage_decorator.py +52 -0
  57. xpk/core/workload_decorators/tcpxo_decorator.py +190 -0
  58. xpk/main.py +75 -0
  59. xpk/parser/__init__.py +15 -0
  60. xpk/parser/batch.py +43 -0
  61. xpk/parser/cluster.py +662 -0
  62. xpk/parser/common.py +259 -0
  63. xpk/parser/config.py +49 -0
  64. xpk/parser/core.py +135 -0
  65. xpk/parser/info.py +64 -0
  66. xpk/parser/inspector.py +65 -0
  67. xpk/parser/job.py +147 -0
  68. xpk/parser/kind.py +95 -0
  69. xpk/parser/run.py +47 -0
  70. xpk/parser/shell.py +59 -0
  71. xpk/parser/storage.py +316 -0
  72. xpk/parser/validators.py +39 -0
  73. xpk/parser/version.py +23 -0
  74. xpk/parser/workload.py +726 -0
  75. xpk/templates/__init__.py +15 -0
  76. xpk/templates/storage.yaml +13 -0
  77. xpk/utils/__init__.py +15 -0
  78. xpk/utils/console.py +55 -0
  79. xpk/utils/file.py +82 -0
  80. xpk/utils/gcs_utils.py +125 -0
  81. xpk/utils/kubectl.py +57 -0
  82. xpk/utils/network.py +168 -0
  83. xpk/utils/objects.py +88 -0
  84. xpk/utils/templates.py +28 -0
  85. xpk/utils/validation.py +80 -0
  86. xpk/utils/yaml.py +30 -0
  87. xpk-0.0.1.dist-info/LICENSE +202 -0
  88. xpk-0.0.1.dist-info/METADATA +1498 -0
  89. xpk-0.0.1.dist-info/RECORD +92 -0
  90. xpk-0.0.1.dist-info/WHEEL +5 -0
  91. xpk-0.0.1.dist-info/entry_points.txt +2 -0
  92. xpk-0.0.1.dist-info/top_level.txt +1 -0
xpk/core/vertex.py ADDED
@@ -0,0 +1,105 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from ..utils.console import xpk_print
+from .resources import CLUSTER_METADATA_CONFIGMAP, get_cluster_configmap
+
+DEFAULT_VERTEX_TENSORBOARD_NAME = 'tb-instance'
+
+
+def create_vertex_tensorboard(args) -> dict:
+  """Creates a Tensorboard instance in Vertex AI.
+
+  Args:
+    args: user provided arguments.
+
+  Returns:
+    dict containing Tensorboard instance name, id and location.
+  """
+  from cloud_accelerator_diagnostics import (  # pylint: disable=import-outside-toplevel
+      tensorboard,
+  )
+
+  tensorboard_config = {}
+  tensorboard_name = args.tensorboard_name
+  if tensorboard_name is None:
+    tensorboard_name = f'{args.cluster}-{DEFAULT_VERTEX_TENSORBOARD_NAME}'
+  instance_id = tensorboard.create_instance(  # pylint: disable=used-before-assignment
+      project=args.project,
+      location=args.tensorboard_region,
+      tensorboard_name=tensorboard_name,
+  )
+  if instance_id:
+    xpk_print(
+        f'Tensorboard instance {tensorboard_name} is successfully created.'
+    )
+    tensorboard_config['tensorboard_region'] = args.tensorboard_region
+    tensorboard_config['tensorboard_name'] = tensorboard_name
+    tensorboard_config['tensorboard_id'] = instance_id
+  return tensorboard_config
+
+
+def create_vertex_experiment(args) -> dict | None:
+  """Creates an Experiment in Vertex AI.
+
+  Args:
+    args: user provided arguments.
+
+  Returns:
+    map containing Vertex Tensorboard configurations.
+  """
+  from cloud_accelerator_diagnostics import (  # pylint: disable=import-outside-toplevel
+      tensorboard,
+  )
+
+  metadata_configmap_name = f'{args.cluster}-{CLUSTER_METADATA_CONFIGMAP}'
+  cluster_config_map = get_cluster_configmap(args, metadata_configmap_name)
+
+  if cluster_config_map is None or 'tensorboard_name' not in cluster_config_map:
+    xpk_print(
+        'No Vertex Tensorboard instance has been created in cluster create. Run'
+        ' `xpk cluster create --create-vertex-tensorboard` before running `xpk'
+        ' workload create --use-vertex-tensorboard` to create a Vertex'
+        ' Tensorboard instance. Alternatively, use `xpk cluster create-pathways'
+        ' --create-vertex-tensorboard` before running `xpk workload'
+        ' create-pathways --use-vertex-tensorboard`.'
+    )
+    return None
+
+  tensorboard_config = {}
+  tensorboard_config['tensorboard_project'] = args.project
+  tensorboard_config['tensorboard_region'] = cluster_config_map[
+      'tensorboard_region'
+  ]
+  tensorboard_config['tensorboard_name'] = cluster_config_map[
+      'tensorboard_name'
+  ]
+  experiment_name = args.experiment_name
+  if experiment_name is None:
+    experiment_name = f'{args.cluster}-{args.workload}'
+  tensorboard_config['experiment_name'] = experiment_name
+
+  _, tensorboard_url = tensorboard.create_experiment(
+      project=args.project,
+      location=tensorboard_config['tensorboard_region'],
+      experiment_name=experiment_name,
+      tensorboard_name=tensorboard_config['tensorboard_name'],
+  )
+  if tensorboard_url is None:
+    return None
+
+  xpk_print(f'You can view Vertex Tensorboard at: {tensorboard_url}')
+  return tensorboard_config
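A minimal, hypothetical usage sketch of the module above (not part of the wheel): it drives create_vertex_tensorboard with argparse-style attributes. Running it for real requires GCP credentials, the cloud-accelerator-diagnostics dependency, and Vertex AI enabled in the project; every value below is a placeholder.

from types import SimpleNamespace

from xpk.core.vertex import create_vertex_tensorboard

# Placeholder argparse-style arguments.
args = SimpleNamespace(
    cluster='demo-cluster',            # used to derive the default instance name
    project='my-gcp-project',
    tensorboard_region='us-central1',
    tensorboard_name=None,             # None falls back to '<cluster>-tb-instance'
)

config = create_vertex_tensorboard(args)
# On success, config holds tensorboard_region, tensorboard_name and tensorboard_id.
print(config)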
xpk/core/workload.py ADDED
@@ -0,0 +1,341 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from ..utils.console import xpk_exit, xpk_print
+from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE
+from .commands import run_command_for_value
+from .gcloud_context import zone_to_region
+from .system_characteristics import SystemCharacteristics
+
+
+def workload_list_awk_command(filter_key) -> str:
+  """Function returns the awk command needed from the filter specified.
+
+  Args:
+    filter_key: workload list filter to awk against
+
+  Returns:
+    awk command to use in filtering workload list.
+  """
+
+  return f" | awk -e 'NR == 1 || {filter_key} {{print $0}}'"
+
+
+def determine_workload_list_filter_by_status(args) -> str:
+  """Function to create the filtered view of workload list.
+
+  Args:
+    args: user provided arguments for running the command.
+
+  Returns:
+    the argument needed to filter by status of jobs in workload list.
+  """
+
+  # Argument positions related to columns created by workload list command.
+  status_arg = '$7'
+  running_vms_arg = '$5'
+  status_verbose_arg = '$9'
+  if args.filter_by_status == 'EVERYTHING':
+    return ''
+  elif args.filter_by_status == 'RUNNING':
+    # Running includes the status Admitted or Evicted, and when the number of
+    # vms running is > 0.
+    return workload_list_awk_command(
+        f'({status_arg} ~ "Admitted|Evicted" && {running_vms_arg} ~ /^[0-9]+$/'
+        f' && {running_vms_arg} > 0)'
+    )
+  elif args.filter_by_status == 'QUEUED':
+    # Queued includes the status Admitted or Evicted, and when the number of
+    # vms running is 0.
+    return workload_list_awk_command(
+        f'({status_arg} ~ "Admitted|Evicted|QuotaReserved" &&'
+        f' ({running_vms_arg} ~ "<none>" || {running_vms_arg} == 0))'
+    )
+  elif args.filter_by_status == 'FINISHED':
+    return workload_list_awk_command(f'{status_arg} == "Finished"')
+  elif args.filter_by_status == 'FAILED':
+    # Failed includes the status Finished, and when the verbose reason is failed.
+    return workload_list_awk_command(
+        f'({status_arg} == "Finished" && {status_verbose_arg} ~ "failed")'
+    )
+  elif args.filter_by_status == 'SUCCESSFUL':
+    # Failed includes the status Finished, and when the verbose reason is finished/success.
+    return workload_list_awk_command(
+        f'({status_arg} == "Finished" && {status_verbose_arg} ~ "finished")'
+    )
+  raise RuntimeError(f'Can not find filter type: {args.filter_by_status}')
+
+
+def determine_workload_list_filter_by_job(args) -> str:
+  """Function to filter view of workload list based on job name.
+
+  Args:
+    args: user provided arguments for running the command.
+
+  Returns:
+    the argument needed to filter job names from workload list
+  """
+  # Argument positions related to columns created by workload list command.
+  if not hasattr(args, 'filter_by_job') or args.filter_by_job is None:
+    return ''
+  else:
+    job_name_arg = '$1'
+    return workload_list_awk_command(f'{job_name_arg} ~ "{args.filter_by_job}"')
+
+
+def get_workload_list(args) -> tuple[int, str]:
+  """Function to get the list of the workloads in the cluster.
+
+  Args:
+    args: user provided arguments for running the command.
+
+  Returns:
+    return_code: 0 if successful and 1 otherwise.
+    return_value: workloads in the cluster matching the criteria.
+  """
+  columns = {
+      'Jobset Name': '.metadata.ownerReferences[0].name',
+      'Created Time': '.metadata.creationTimestamp',
+      'Priority': '.spec.priorityClassName',
+      'TPU VMs Needed': '.spec.podSets[0].count',
+      'TPU VMs Running/Ran': '.status.admission.podSetAssignments[-1].count',
+      'TPU VMs Done': '.status.reclaimablePods[0].count',
+      'Status': '.status.conditions[-1].type',
+      'Status Message': '.status.conditions[-1].message',
+      'Status Time': '.status.conditions[-1].lastTransitionTime',
+  }
+  s = ','.join([key + ':' + value for key, value in columns.items()])
+
+  workload_list_filter_status_cmd = determine_workload_list_filter_by_status(
+      args
+  )
+  workload_list_filter_job_cmd = determine_workload_list_filter_by_job(args)
+  command = (
+      f'kubectl get workloads -o=custom-columns="{s}" '
+      f'{workload_list_filter_status_cmd} {workload_list_filter_job_cmd}'
+  )
+
+  task = f'List Jobs with filter-by-status={args.filter_by_status}'
+  if hasattr(args, 'filter_by_job'):
+    task += f' with filter-by-job={args.filter_by_job}'
+
+  return_code, return_value = run_command_for_value(command, task, args)
+  return return_code, return_value
+
+
+def check_if_workload_exists(args) -> bool:
+  """Check if workload exists.
+
+  Args:
+    args: user provided arguments for running the command.
+
+  Returns:
+    returns true if workload exist, otherwise returns false.
+  """
+  columns = {
+      'Jobset': '.metadata.ownerReferences[0].name',
+  }
+
+  s = ','.join([key + ':' + value for key, value in columns.items()])
+
+  command = f"kubectl get workloads -o=custom-columns='{s}'"
+  return_code, return_msg = run_command_for_value(
+      command, 'Check if Workload Already Exists', args
+  )
+
+  if return_code != 0:
+    xpk_print(f'List Job request returned ERROR {return_code}')
+    xpk_exit(return_code)
+
+  lines = return_msg.split('\n')
+  new_workload_name = args.workload
+  for line in lines:
+    if line == new_workload_name:
+      return True
+  return False
+
+
+def wait_for_job_completion(args) -> int:
+  """Function to wait for job completion.
+
+  Args:
+    args: user provided arguments for running the command.
+
+  Returns:
+    return_code: 0 if successful, 124 if timeout, 125 if unsuccessful job, 1 otherwise
+  """
+  # Check that the workload exists
+  args.workload = args.wait_for_job_completion
+  workload_exists = check_if_workload_exists(args)
+  if not workload_exists:
+    xpk_print(f'Workload named {args.workload} does not exist.')
+    return 1
+
+  # Get the full workload name
+  get_workload_name_cmd = f'kubectl get workloads | grep jobset-{args.workload}'
+  return_code, return_value = run_command_for_value(
+      get_workload_name_cmd, 'Get full workload name', args
+  )
+  if return_code != 0:
+    xpk_print(f'Get full workload name request returned ERROR {return_code}')
+    return return_code
+  full_workload_name = return_value.split(' ')[0]
+
+  # Call kubectl wait on the workload using the full workload name
+  timeout_val = args.timeout if args.timeout is not None else -1
+  timeout_msg = (
+      f'{timeout_val}s' if timeout_val != -1 else 'max timeout (1 week)'
+  )
+  wait_cmd = (
+      "kubectl wait --for jsonpath='.status.conditions[-1].type'=Finished"
+      f' workload {full_workload_name} --timeout={timeout_val}s'
+  )
+  return_code, return_value = run_command_for_value(
+      wait_cmd,
+      f'Wait for workload to finish with timeout of {timeout_msg}',
+      args,
+      print_timer=True,
+  )
+  if return_code != 0:
+    if 'timed out' in return_value:
+      xpk_print(
+          f'Timed out waiting for your workload after {timeout_msg}, see your'
+          ' workload here:'
+          # pylint: disable=line-too-long
+          f' https://console.cloud.google.com/kubernetes/service/{zone_to_region(args.zone)}/{args.cluster}/default/{args.workload}/details?project={args.project}'
+      )
+      return 124
+    else:
+      xpk_print(f'{return_value}')
+      xpk_print(f'Wait for workload returned ERROR {return_code}')
+      return return_code
+  xpk_print(
+      'Finished waiting for your workload, see your workload here:'
+      # pylint: disable=line-too-long
+      f' https://console.cloud.google.com/kubernetes/service/{zone_to_region(args.zone)}/{args.cluster}/default/{args.workload}/details?project={args.project}'
+  )
+  status_cmd = (
+      f'kubectl get jobset {args.workload} -o'
+      " jsonpath='{.status.conditions[-1].type}'"
+  )
+  return_code, return_value = run_command_for_value(
+      status_cmd, 'Get jobset status', args
+  )
+  if return_code != 0:
+    xpk_print(f'Get workload status request returned ERROR {return_code}')
+    return return_code
+  xpk_print(f'Your workload finished with status: {return_value}')
+  if return_value != 'Completed':
+    xpk_print('Your workload did not complete successfully')
+    return 125
+  return 0
+
+
+def get_gpu_volume(system: SystemCharacteristics) -> str:
+  """Get gpu volume based on user provided arguments.
+
+  Args:
+    system: system characteristics.
+
+  Returns:
+    str: yaml containing gpu volume
+  """
+  gpu_volume = ''
+  if system.device_type == H100_DEVICE_TYPE:
+    gpu_volume = """- name: nvidia-install-dir-host
+  hostPath:
+    path: /home/kubernetes/bin/nvidia/lib64
+- name: tcpd-socket
+  hostPath:
+    path: /run/tcpx
+- name: shared-memory
+  emptyDir:
+    medium: "Memory"
+    sizeLimit: 200Gi
+- name: workload-terminated-volume
+  emptyDir:
+- name: tcpx-nccl-plugin-volume
+  emptyDir:"""
+  elif system.device_type == H100_MEGA_DEVICE_TYPE:
+    gpu_volume = """- name: nvidia-install-dir-host
+  hostPath:
+    path: /home/kubernetes/bin/nvidia/lib64
+- name: shared-memory
+  emptyDir:
+    medium: "Memory"
+    sizeLimit: 1Gi
+- name: workload-terminated-volume
+  emptyDir:"""
+  return gpu_volume
+
+
+def get_gpu_rxdm_image(system: SystemCharacteristics) -> str:
+  """Get config of rxdm based on user provided arguments.
+
+  Args:
+    system: system characteristics.
+
+  Returns:
+    str: yaml containing the rxdm name and image
+  """
+  gpu_rxdm_image = ''
+  if system.device_type == H100_DEVICE_TYPE:
+    gpu_rxdm_image = """- name: tcpd-daemon
+  image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/tcpgpudmarxd-dev:v2.0.9"""
+  elif system.device_type == H100_MEGA_DEVICE_TYPE:
+    gpu_rxdm_image = """- name: fastrak-daemon
+  image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.9"""
+  return gpu_rxdm_image
+
+
+def get_gpu_rxdm_cmd(system: SystemCharacteristics) -> str:
+  """Get rxdm command based on user provided arguments.
+
+  Args:
+    system: system characteristics.
+
+  Returns:
+    str: command of running rxdm container
+  """
+  gpu_rxdm_cmd = ''
+  if system.device_type == H100_DEVICE_TYPE:
+    gpu_rxdm_cmd = (
+        '/tcpgpudmarxd/build/app/tcpgpudmarxd --gpu_nic_preset a3vm'
+        ' --gpu_shmem_type fd --setup_param "--verbose 128 2 0"'
+    )
+  elif system.device_type == H100_MEGA_DEVICE_TYPE:
+    gpu_rxdm_cmd = (
+        'set -ex; chmod 755 /fts/entrypoint_rxdm_container.sh;'
+        ' /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid='
+        ' --alsologtostderr'
+    )
+  return gpu_rxdm_cmd
+
+
+def get_gpu_tcp_volume(system: SystemCharacteristics) -> str:
+  """Get gpu tcp volume based on user provided arguments.
+
+  Args:
+    system: system characteristics.
+
+  Returns:
+    str: yaml containing gpu tcp volume
+  """
+  gpu_tcp_volume = ''
+  if system.device_type == H100_DEVICE_TYPE:
+    gpu_tcp_volume = """- name: tcpd-socket
+  mountPath: /tmp"""
+  return gpu_tcp_volume
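A rough, hypothetical sketch of how the filter helpers above compose (assumes the wheel is installed and importable; it only builds the command string and does not run kubectl; the args and job name are placeholders):

from types import SimpleNamespace

from xpk.core.workload import (
    determine_workload_list_filter_by_job,
    determine_workload_list_filter_by_status,
)

# Placeholder args mimicking `xpk workload list --filter-by-status=RUNNING --filter-by-job=llama`.
args = SimpleNamespace(filter_by_status='RUNNING', filter_by_job='llama')

status_filter = determine_workload_list_filter_by_status(args)
job_filter = determine_workload_list_filter_by_job(args)

# get_workload_list appends both awk pipelines to the kubectl command; the
# 'NR == 1' clause keeps the header row while the condition filters data rows.
# The custom-columns spec is elided here.
command = (
    'kubectl get workloads -o=custom-columns="<columns elided>" '
    + status_filter + ' ' + job_filter
)
print(command)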
xpk/core/workload_decorators/__init__.py ADDED
@@ -0,0 +1,15 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
xpk/core/workload_decorators/rdma_decorator.py ADDED
@@ -0,0 +1,129 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import yaml
+from ...utils.yaml import literal_string
+
+
+def decorate_kjob_template(job_manifest) -> str:
+  spec = (
+      job_manifest.setdefault('spec', {})
+      .setdefault('template', {})
+      .setdefault('spec', {})
+  )
+  spec.setdefault('tolerations', [])
+  spec.setdefault('volumes', [])
+
+  add_volumes(job_manifest)
+  add_tolerations(job_manifest)
+  update_gpu_containers(job_manifest)
+  return job_manifest
+
+
+def decorate_jobset(jobset_manifest_str, sub_networks) -> str:
+  """
+  Decorates a JobSet manifest with the necessary components for rdma-daemon.
+
+  Args:
+    jobset_manifest_str: The JobSet manifest as a YAML string.
+
+  Returns:
+    The modified JobSet manifest as a YAML string.
+  """
+
+  manifest = yaml.safe_load(jobset_manifest_str)
+
+  for job in manifest['spec']['replicatedJobs']:
+    job_manifest = job['template']
+    job_manifest.setdefault('spec', {}).setdefault('template', {}).setdefault(
+        'metadata', {}
+    ).setdefault('annotations', {})
+    spec = (
+        job_manifest.setdefault('spec', {})
+        .setdefault('template', {})
+        .setdefault('spec', {})
+    )
+    spec.setdefault('tolerations', [])
+    spec.setdefault('volumes', [])
+
+    add_annotations(job_manifest, sub_networks)
+    add_volumes(job_manifest)
+    add_tolerations(job_manifest)
+    update_gpu_containers(job_manifest)
+
+  return yaml.dump(manifest, sort_keys=False)
+
+
+def get_interfaces_entry(sub_networks: list[str]) -> tuple[str, str]:
+  interfaces = [
+      '[',
+      ' {"interfaceName":"eth0","network":"default"},',
+      *[
+          f' {{"interfaceName":"eth{i + 1}","network":"{sub_networks[i]}"}}{"," if i<8 else ""}'
+          for i in range(9)
+      ],
+      ']',
+  ]
+  return 'networking.gke.io/interfaces', literal_string('\n'.join(interfaces))
+
+
+def add_annotations(job_manifest, sub_networks):
+  """Adds or updates annotations in the Pod template."""
+  annotations = job_manifest['spec']['template']['metadata']['annotations']
+  interfaces_key, interfaces_value = get_interfaces_entry(sub_networks)
+  annotations.update({
+      'networking.gke.io/default-interface': "'eth0'",
+      interfaces_key: interfaces_value,
+  })
+
+
+def add_volumes(job_manifest):
+  """Adds volumes to the Pod spec."""
+  volumes = job_manifest['spec']['template']['spec']['volumes']
+  volumes.append({
+      'name': 'library-dir-host',
+      'hostPath': {'path': '/home/kubernetes/bin/nvidia'},
+  })
+  volumes.append(
+      {'name': 'gib', 'hostPath': {'path': '/home/kubernetes/bin/gib'}}
+  )
+
+
+def add_tolerations(job_manifest):
+  """Adds tolerations to the Pod spec."""
+  tolerations = job_manifest['spec']['template']['spec']['tolerations']
+  tolerations.append({
+      'key': 'user-workload',
+      'operator': 'Equal',
+      'value': 'true',
+      'effect': 'NoSchedule',
+  })
+
+
+def update_gpu_containers(job_manifest):
+  for container in job_manifest['spec']['template']['spec']['containers']:
+    if 'nvidia.com/gpu' in container.get('resources', {}).get('limits', {}):
+      container.setdefault('env', [])
+      container['env'].append(
+          {'name': 'LD_LIBRARY_PATH', 'value': '/usr/local/nvidia/lib64'}
+      )
+      container.setdefault('volumeMounts', [])
+      container['volumeMounts'].append(
+          {'name': 'library-dir-host', 'mountPath': '/usr/local/nvidia'}
+      )
+      container['volumeMounts'].append(
+          {'name': 'gib', 'mountPath': '/usr/local/gib'}
+      )
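A short, hypothetical end-to-end sketch of the decorator above (assumes the wheel is installed; the JobSet below is a toy manifest and the nine sub-network names are placeholders):

import yaml

from xpk.core.workload_decorators import rdma_decorator

# Toy JobSet with one replicated job and one GPU container.
toy_jobset = yaml.dump({
    'apiVersion': 'jobset.x-k8s.io/v1alpha2',
    'kind': 'JobSet',
    'metadata': {'name': 'demo'},
    'spec': {'replicatedJobs': [{
        'name': 'workers',
        'template': {'spec': {'template': {'spec': {'containers': [{
            'name': 'main',
            'image': 'busybox',
            'resources': {'limits': {'nvidia.com/gpu': 8}},
        }]}}}},
    }]},
})

# get_interfaces_entry indexes sub_networks[0..8], so nine names are required
# (they become eth1..eth9 alongside the default eth0 interface).
sub_networks = [f'placeholder-subnet-{i}' for i in range(9)]

decorated = rdma_decorator.decorate_jobset(toy_jobset, sub_networks)
print(decorated)  # annotations, tolerations, volumes and env/volumeMounts added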
xpk/core/workload_decorators/storage_decorator.py ADDED
@@ -0,0 +1,52 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import yaml
+
+from ...core.storage import GCS_FUSE_TYPE, get_storage_volumes_yaml_dict, GCS_FUSE_ANNOTATION
+
+
+def decorate_jobset(jobset_manifest_str, storages) -> str:
+  """
+  Decorates a JobSet manifest with the necessary storages.
+
+  Args:
+    jobset_manifest_str: The JobSet manifest as a YAML string.
+
+  Returns:
+    The modified JobSet manifest as a YAML string.
+  """
+
+  manifest = yaml.safe_load(jobset_manifest_str)
+  storage_volumes = get_storage_volumes_yaml_dict(storages)
+  for job in manifest['spec']['replicatedJobs']:
+    job_manifest = job['template']
+    add_annotations(job_manifest, storages)
+    add_volumes(job_manifest, storage_volumes)
+  return yaml.dump(manifest, sort_keys=False)
+
+
+def add_annotations(job_manifest, storages):
+  """Adds or updates storage annotations in the Pod template."""
+  annotations = job_manifest['spec']['template']['metadata']['annotations']
+  gcs_present = [storage.type == GCS_FUSE_TYPE for storage in storages]
+  if gcs_present:
+    annotations.update(GCS_FUSE_ANNOTATION)
+
+
+def add_volumes(job_manifest, storage_volumes):
+  volumes = job_manifest['spec']['template']['spec']['volumes']
+  volumes.extend(storage_volumes)