xpk 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. xpk/api/__init__.py +15 -0
  2. xpk/api/storage_crd.yaml +52 -0
  3. xpk/commands/batch.py +27 -5
  4. xpk/commands/cluster.py +104 -80
  5. xpk/commands/cluster_gcluster.py +94 -10
  6. xpk/commands/common.py +44 -0
  7. xpk/commands/config.py +29 -0
  8. xpk/commands/info.py +8 -10
  9. xpk/commands/inspector.py +5 -11
  10. xpk/commands/job.py +9 -7
  11. xpk/commands/kind.py +34 -4
  12. xpk/commands/kjob_common.py +44 -0
  13. xpk/commands/run.py +128 -0
  14. xpk/commands/shell.py +27 -7
  15. xpk/commands/storage.py +280 -0
  16. xpk/commands/version.py +6 -18
  17. xpk/commands/workload.py +381 -184
  18. xpk/core/blueprint/blueprint_definitions.py +1 -0
  19. xpk/core/blueprint/blueprint_generator.py +132 -76
  20. xpk/core/capacity.py +185 -0
  21. xpk/core/cluster.py +564 -0
  22. xpk/core/cluster_private.py +6 -3
  23. xpk/core/commands.py +18 -14
  24. xpk/core/config.py +179 -0
  25. xpk/core/docker_container.py +225 -0
  26. xpk/core/docker_image.py +210 -0
  27. xpk/core/docker_resources.py +350 -0
  28. xpk/core/filestore.py +251 -0
  29. xpk/core/gcloud_context.py +196 -0
  30. xpk/core/gcluster_manager.py +20 -2
  31. xpk/core/gcsfuse.py +50 -0
  32. xpk/core/kjob.py +257 -18
  33. xpk/core/kueue.py +12 -6
  34. xpk/core/monitoring.py +134 -0
  35. xpk/core/nap.py +32 -20
  36. xpk/core/network.py +377 -0
  37. xpk/core/nodepool.py +581 -0
  38. xpk/core/pathways.py +124 -45
  39. xpk/core/remote_state/__init__.py +15 -0
  40. xpk/core/remote_state/fuse_remote_state.py +99 -0
  41. xpk/core/remote_state/remote_state_client.py +38 -0
  42. xpk/core/resources.py +238 -0
  43. xpk/core/scheduling.py +253 -0
  44. xpk/core/storage.py +581 -0
  45. xpk/core/system_characteristics.py +38 -1
  46. xpk/core/vertex.py +105 -0
  47. xpk/core/workload.py +209 -1
  48. xpk/core/workload_decorators/rdma_decorator.py +25 -5
  49. xpk/core/workload_decorators/storage_decorator.py +52 -0
  50. xpk/core/workload_decorators/tcpxo_decorator.py +70 -37
  51. xpk/main.py +3 -1
  52. xpk/parser/batch.py +10 -151
  53. xpk/parser/cluster.py +49 -8
  54. xpk/parser/common.py +189 -1
  55. xpk/parser/config.py +49 -0
  56. xpk/parser/core.py +27 -1
  57. xpk/parser/info.py +2 -1
  58. xpk/parser/inspector.py +3 -3
  59. xpk/parser/job.py +25 -4
  60. xpk/parser/kind.py +3 -2
  61. xpk/parser/run.py +47 -0
  62. xpk/parser/shell.py +10 -1
  63. xpk/parser/storage.py +326 -0
  64. xpk/parser/validators.py +3 -3
  65. xpk/parser/workload.py +118 -76
  66. xpk/templates/__init__.py +15 -0
  67. xpk/templates/storage.yaml +13 -0
  68. xpk/utils/gcs_utils.py +125 -0
  69. xpk/utils/kubectl.py +57 -0
  70. xpk/utils/objects.py +8 -5
  71. xpk/utils/templates.py +28 -0
  72. xpk/utils/validation.py +80 -0
  73. {xpk-0.6.0.dist-info → xpk-0.7.1.dist-info}/METADATA +169 -15
  74. xpk-0.7.1.dist-info/RECORD +92 -0
  75. {xpk-0.6.0.dist-info → xpk-0.7.1.dist-info}/WHEEL +1 -1
  76. xpk/core/core.py +0 -2824
  77. xpk-0.6.0.dist-info/RECORD +0 -57
  78. {xpk-0.6.0.dist-info → xpk-0.7.1.dist-info}/entry_points.txt +0 -0
  79. {xpk-0.6.0.dist-info → xpk-0.7.1.dist-info/licenses}/LICENSE +0 -0
  80. {xpk-0.6.0.dist-info → xpk-0.7.1.dist-info}/top_level.txt +0 -0
xpk/core/vertex.py ADDED
@@ -0,0 +1,105 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from ..utils.console import xpk_print
18
+ from .resources import CLUSTER_METADATA_CONFIGMAP, get_cluster_configmap
19
+
20
+ DEFAULT_VERTEX_TENSORBOARD_NAME = 'tb-instance'
21
+
22
+
23
def create_vertex_tensorboard(args) -> dict:
  """Create a Tensorboard instance in Vertex AI for the cluster.

  Args:
    args: user provided arguments.

  Returns:
    dict containing Tensorboard instance name, id and location; empty dict
    when instance creation did not return an id.
  """
  from cloud_accelerator_diagnostics import (  # pylint: disable=import-outside-toplevel
      tensorboard,
  )

  # Fall back to a cluster-derived default when no explicit name was given.
  instance_name = args.tensorboard_name
  if instance_name is None:
    instance_name = f'{args.cluster}-{DEFAULT_VERTEX_TENSORBOARD_NAME}'

  instance_id = tensorboard.create_instance(  # pylint: disable=used-before-assignment
      project=args.project,
      location=args.tensorboard_region,
      tensorboard_name=instance_name,
  )

  config = {}
  if instance_id:
    xpk_print(f'Tensorboard instance {instance_name} is successfully created.')
    config['tensorboard_region'] = args.tensorboard_region
    config['tensorboard_name'] = instance_name
    config['tensorboard_id'] = instance_id
  return config
53
+
54
+
55
def create_vertex_experiment(args) -> dict | None:
  """Create an Experiment in Vertex AI tied to the cluster's Tensorboard.

  Args:
    args: user provided arguments.

  Returns:
    map containing Vertex Tensorboard configurations, or None when the
    cluster has no recorded Tensorboard instance or experiment creation
    returned no URL.
  """
  from cloud_accelerator_diagnostics import (  # pylint: disable=import-outside-toplevel
      tensorboard,
  )

  # The cluster-create flow records the Tensorboard instance in a metadata
  # ConfigMap; without that record we cannot attach an experiment.
  configmap_name = f'{args.cluster}-{CLUSTER_METADATA_CONFIGMAP}'
  cluster_config_map = get_cluster_configmap(args, configmap_name)

  if cluster_config_map is None or 'tensorboard_name' not in cluster_config_map:
    xpk_print(
        'No Vertex Tensorboard instance has been created in cluster create. Run'
        ' `xpk cluster create --create-vertex-tensorboard` before running `xpk'
        ' workload create --use-vertex-tensorboard` to create a Vertex'
        ' Tensorboard instance. Alternatively, use `xpk cluster create-pathways'
        ' --create-vertex-tensorboard` before running `xpk workload'
        ' create-pathways --use-vertex-tensorboard`.'
    )
    return None

  experiment_name = args.experiment_name
  if experiment_name is None:
    experiment_name = f'{args.cluster}-{args.workload}'

  tensorboard_config = {
      'tensorboard_project': args.project,
      'tensorboard_region': cluster_config_map['tensorboard_region'],
      'tensorboard_name': cluster_config_map['tensorboard_name'],
      'experiment_name': experiment_name,
  }

  _, tensorboard_url = tensorboard.create_experiment(
      project=args.project,
      location=tensorboard_config['tensorboard_region'],
      experiment_name=experiment_name,
      tensorboard_name=tensorboard_config['tensorboard_name'],
  )
  if tensorboard_url is None:
    return None

  xpk_print(f'You can view Vertex Tensorboard at: {tensorboard_url}')
  return tensorboard_config
xpk/core/workload.py CHANGED
@@ -14,7 +14,11 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ from ..utils.console import xpk_exit, xpk_print
18
+ from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE
17
19
  from .commands import run_command_for_value
20
+ from .gcloud_context import zone_to_region
21
+ from .system_characteristics import SystemCharacteristics
18
22
 
19
23
 
20
24
  def workload_list_awk_command(filter_key) -> str:
@@ -129,5 +133,209 @@ def get_workload_list(args) -> tuple[int, str]:
129
133
  task += f' with filter-by-job={args.filter_by_job}'
130
134
 
131
135
  return_code, return_value = run_command_for_value(command, task, args)
132
-
133
136
  return return_code, return_value
137
+
138
+
139
def check_if_workload_exists(args) -> bool:
  """Check if workload exists.

  Args:
    args: user provided arguments for running the command.

  Returns:
    returns true if workload exist, otherwise returns false.
  """
  # Ask kubectl for just the owning Jobset name of each workload so the
  # output is one bare name per line.
  columns = {'Jobset': '.metadata.ownerReferences[0].name'}
  column_spec = ','.join(f'{key}:{value}' for key, value in columns.items())

  command = f"kubectl get workloads -o=custom-columns='{column_spec}'"
  return_code, return_msg = run_command_for_value(
      command, 'Check if Workload Already Exists', args
  )

  if return_code != 0:
    xpk_print(f'List Job request returned ERROR {return_code}')
    xpk_exit(return_code)

  # Exact whole-line match only — substring matches must not count.
  return any(line == args.workload for line in return_msg.split('\n'))
169
+
170
+
171
def wait_for_job_completion(args) -> int:
  """Function to wait for job completion.

  Args:
    args: user provided arguments for running the command.

  Returns:
    return_code: 0 if successful, 124 if timeout, 125 if unsuccessful job, 1 otherwise
  """
  # Check that the workload exists
  args.workload = args.wait_for_job_completion
  workload_exists = check_if_workload_exists(args)
  if not workload_exists:
    xpk_print(f'Workload named {args.workload} does not exist.')
    return 1

  # Get the full workload name
  # NOTE(review): grep matches the first row containing `jobset-<name>`;
  # assumes workload names are not prefixes of one another — confirm.
  get_workload_name_cmd = f'kubectl get workloads | grep jobset-{args.workload}'
  return_code, return_value = run_command_for_value(
      get_workload_name_cmd, 'Get full workload name', args
  )
  if return_code != 0:
    xpk_print(f'Get full workload name request returned ERROR {return_code}')
    return return_code
  # First whitespace-separated column of the matched row is the full name.
  full_workload_name = return_value.split(' ')[0]

  # Call kubectl wait on the workload using the full workload name
  # When no --timeout was supplied, -1 is passed through to kubectl,
  # which treats it as its maximum wait (one week).
  timeout_val = args.timeout if args.timeout is not None else -1
  timeout_msg = (
      f'{timeout_val}s' if timeout_val != -1 else 'max timeout (1 week)'
  )
  wait_cmd = (
      "kubectl wait --for jsonpath='.status.conditions[-1].type'=Finished"
      f' workload {full_workload_name} --timeout={timeout_val}s'
  )
  return_code, return_value = run_command_for_value(
      wait_cmd,
      f'Wait for workload to finish with timeout of {timeout_msg}',
      args,
      print_timer=True,
  )
  if return_code != 0:
    # kubectl reports 'timed out' in its output; map that to 124, the
    # timeout code documented in this function's docstring.
    if 'timed out' in return_value:
      xpk_print(
          f'Timed out waiting for your workload after {timeout_msg}, see your'
          ' workload here:'
          # pylint: disable=line-too-long
          f' https://console.cloud.google.com/kubernetes/service/{zone_to_region(args.zone)}/{args.cluster}/default/{args.workload}/details?project={args.project}'
      )
      return 124
    else:
      xpk_print(f'{return_value}')
      xpk_print(f'Wait for workload returned ERROR {return_code}')
      return return_code
  xpk_print(
      'Finished waiting for your workload, see your workload here:'
      # pylint: disable=line-too-long
      f' https://console.cloud.google.com/kubernetes/service/{zone_to_region(args.zone)}/{args.cluster}/default/{args.workload}/details?project={args.project}'
  )
  # The jobset's last status condition type distinguishes success
  # ('Completed') from failure.
  status_cmd = (
      f'kubectl get jobset {args.workload} -o'
      " jsonpath='{.status.conditions[-1].type}'"
  )
  return_code, return_value = run_command_for_value(
      status_cmd, 'Get jobset status', args
  )
  if return_code != 0:
    xpk_print(f'Get workload status request returned ERROR {return_code}')
    return return_code
  xpk_print(f'Your workload finished with status: {return_value}')
  if return_value != 'Completed':
    xpk_print('Your workload did not complete successfully')
    return 125
  return 0
245
+
246
+
247
def get_gpu_volume(system: SystemCharacteristics) -> str:
  """Get gpu volume based on user provided arguments.

  Args:
    system: system characteristics.

  Returns:
    str: yaml containing gpu volume
  """
  # Unknown device types get an empty string (no extra volumes).
  gpu_volume = ''
  # NOTE(review): the indentation inside these literals must line up with
  # the insertion point in the surrounding workload YAML template — confirm.
  if system.device_type == H100_DEVICE_TYPE:
    # A3 / TCPX: NVIDIA host libraries, the TCPX socket directory, a large
    # in-memory shm, plus scratch volumes for the NCCL plugin and
    # termination signalling.
    gpu_volume = """- name: nvidia-install-dir-host
    hostPath:
      path: /home/kubernetes/bin/nvidia/lib64
  - name: tcpd-socket
    hostPath:
      path: /run/tcpx
  - name: shared-memory
    emptyDir:
      medium: "Memory"
      sizeLimit: 200Gi
  - name: workload-terminated-volume
    emptyDir:
  - name: tcpx-nccl-plugin-volume
    emptyDir:"""
  elif system.device_type == H100_MEGA_DEVICE_TYPE:
    # A3 Mega / TCPXO: no TCPX socket or NCCL plugin volume, and a much
    # smaller shared-memory limit.
    gpu_volume = """- name: nvidia-install-dir-host
    hostPath:
      path: /home/kubernetes/bin/nvidia/lib64
  - name: shared-memory
    emptyDir:
      medium: "Memory"
      sizeLimit: 1Gi
  - name: workload-terminated-volume
    emptyDir:"""
  return gpu_volume
283
+
284
+
285
def get_gpu_rxdm_image(system: SystemCharacteristics) -> str:
  """Get config of rxdm based on user provided arguments.

  Args:
    system: system characteristics.

  Returns:
    str: yaml containing the rxdm name and image
  """
  # Unknown device types get an empty string (no daemon container snippet).
  gpu_rxdm_image = ''
  # NOTE(review): literal indentation must match the surrounding workload
  # YAML template's insertion point — confirm.
  if system.device_type == H100_DEVICE_TYPE:
    # A3 uses the TCPX receive-datapath-manager image.
    gpu_rxdm_image = """- name: tcpd-daemon
    image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/tcpgpudmarxd-dev:v2.0.9"""
  elif system.device_type == H100_MEGA_DEVICE_TYPE:
    # A3 Mega uses the TCPXO (fastrak) variant of the same daemon.
    gpu_rxdm_image = """- name: fastrak-daemon
    image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.9"""
  return gpu_rxdm_image
302
+
303
+
304
def get_gpu_rxdm_cmd(system: SystemCharacteristics) -> str:
  """Get rxdm command based on user provided arguments.

  Args:
    system: system characteristics.

  Returns:
    str: command of running rxdm container
  """
  # Map each supported GPU device type to the command line its receive
  # datapath manager container runs; unknown types yield ''.
  commands_by_device = {
      H100_DEVICE_TYPE: (
          '/tcpgpudmarxd/build/app/tcpgpudmarxd --gpu_nic_preset a3vm'
          ' --gpu_shmem_type fd --setup_param "--verbose 128 2 0"'
      ),
      H100_MEGA_DEVICE_TYPE: (
          'set -ex; chmod 755 /fts/entrypoint_rxdm_container.sh;'
          ' /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid='
          ' --alsologtostderr'
      ),
  }
  return commands_by_device.get(system.device_type, '')
326
+
327
+
328
def get_gpu_tcp_volume(system: SystemCharacteristics) -> str:
  """Get gpu tcp volume based on user provided arguments.

  Args:
    system: system characteristics.

  Returns:
    str: yaml containing gpu tcp volume
  """
  # Only the plain H100 (TCPX) device type mounts the TCPX socket;
  # every other type gets an empty string.
  gpu_tcp_volume = ''
  # NOTE(review): literal indentation must match the surrounding workload
  # YAML template's insertion point — confirm.
  if system.device_type == H100_DEVICE_TYPE:
    gpu_tcp_volume = """- name: tcpd-socket
    mountPath: /tmp"""
  return gpu_tcp_volume
@@ -18,6 +18,21 @@ import yaml
18
18
  from ...utils.yaml import literal_string
19
19
 
20
20
 
21
def decorate_kjob_template(job_manifest: dict) -> dict:
  """Decorate a kjob Job template with rdma volumes, tolerations and GPU mounts.

  Args:
    job_manifest: parsed Job manifest; mutated in place.

  Returns:
    The same (mutated) manifest object.
  """
  # Fix: the return annotation said `-> str`, but the function returns the
  # manifest dict — now consistent with the tcpxo decorator's signature.
  # Ensure the nested pod spec and the lists the helpers append to exist.
  spec = (
      job_manifest.setdefault('spec', {})
      .setdefault('template', {})
      .setdefault('spec', {})
  )
  spec.setdefault('tolerations', [])
  spec.setdefault('volumes', [])

  add_volumes(job_manifest)
  add_tolerations(job_manifest)
  update_gpu_containers(job_manifest)
  return job_manifest
34
+
35
+
21
36
  def decorate_jobset(jobset_manifest_str, sub_networks) -> str:
22
37
  """
23
38
  Decorates a JobSet manifest with the necessary components for rdma-daemon.
@@ -52,9 +67,7 @@ def decorate_jobset(jobset_manifest_str, sub_networks) -> str:
52
67
  return yaml.dump(manifest, sort_keys=False)
53
68
 
54
69
 
55
- def add_annotations(job_manifest, sub_networks):
56
- """Adds or updates annotations in the Pod template."""
57
- annotations = job_manifest['spec']['template']['metadata']['annotations']
70
+ def get_interfaces_entry(sub_networks: list[str]) -> tuple[str, str]:
58
71
  interfaces = [
59
72
  '[',
60
73
  ' {"interfaceName":"eth0","network":"default"},',
@@ -64,9 +77,16 @@ def add_annotations(job_manifest, sub_networks):
64
77
  ],
65
78
  ']',
66
79
  ]
80
+ return 'networking.gke.io/interfaces', literal_string('\n'.join(interfaces))
81
+
82
+
83
def add_annotations(job_manifest, sub_networks):
  """Adds or updates annotations in the Pod template."""
  pod_annotations = job_manifest['spec']['template']['metadata']['annotations']
  interfaces_key, interfaces_value = get_interfaces_entry(sub_networks)
  # NOTE(review): the default-interface value carries literal single quotes
  # ("'eth0'") — presumably required by downstream YAML rendering; confirm.
  new_entries = {
      'networking.gke.io/default-interface': "'eth0'",
      interfaces_key: interfaces_value,
  }
  pod_annotations.update(new_entries)
71
91
 
72
92
 
@@ -0,0 +1,52 @@
1
+ """
2
+ Copyright 2024 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import yaml
18
+
19
+ from ...core.storage import GCS_FUSE_TYPE, get_storage_volumes_yaml_dict, GCS_FUSE_ANNOTATION
20
+
21
+
22
def decorate_jobset(jobset_manifest_str, storages) -> str:
  """
  Decorates a JobSet manifest with the necessary storages.

  Args:
    jobset_manifest_str: The JobSet manifest as a YAML string.

  Returns:
    The modified JobSet manifest as a YAML string.
  """
  manifest = yaml.safe_load(jobset_manifest_str)
  # Volumes are resolved once and applied to every replicated job template.
  storage_volumes = get_storage_volumes_yaml_dict(storages)
  for replicated_job in manifest['spec']['replicatedJobs']:
    template = replicated_job['template']
    add_annotations(template, storages)
    add_volumes(template, storage_volumes)
  return yaml.dump(manifest, sort_keys=False)
40
+
41
+
42
def add_annotations(job_manifest, storages):
  """Adds or updates storage annotations in the Pod template.

  Args:
    job_manifest: Job manifest whose pod-template annotations are updated.
    storages: storage definitions attached to the workload.
  """
  annotations = job_manifest['spec']['template']['metadata']['annotations']
  # Fix: the previous code built a LIST of booleans and tested the list,
  # which is truthy for any non-empty `storages` even when every element is
  # False — so the GCS Fuse annotation was added for non-GCS storages too.
  gcs_present = any(storage.type == GCS_FUSE_TYPE for storage in storages)
  if gcs_present:
    annotations.update(GCS_FUSE_ANNOTATION)
48
+
49
+
50
def add_volumes(job_manifest, storage_volumes):
  """Append the given storage volumes to the Pod template's volume list."""
  job_manifest['spec']['template']['spec']['volumes'].extend(storage_volumes)
@@ -21,6 +21,42 @@ from ...utils.yaml import literal_string
21
21
  rxdm = 'v1.0.12'
22
22
 
23
23
 
24
def decorate_kjob_template(job_manifest: dict) -> dict:
  """Decorate a kjob Job template with tcpxo-daemon components.

  Args:
    job_manifest: parsed Job manifest; mutated in place.

  Returns:
    The same (mutated) manifest object.
  """
  # Guarantee the nested pod spec and the lists the helpers append to exist.
  pod_spec = (
      job_manifest.setdefault('spec', {})
      .setdefault('template', {})
      .setdefault('spec', {})
  )
  pod_spec.setdefault('tolerations', [])
  pod_spec.setdefault('volumes', [])

  for mutate in (
      add_volumes,
      add_tolerations,
      add_tcpxo_daemon_container,
      update_gpu_containers,
  ):
    mutate(job_manifest)
  return job_manifest
38
+
39
+
40
def decorate_job(job_manifest: dict, sub_networks: list[str]) -> dict:
  """Decorate a single Job manifest with tcpxo-daemon annotations, volumes,
  tolerations and the daemon container.

  Args:
    job_manifest: parsed Job manifest; mutated in place.
    sub_networks: sub-network names used to build interface annotations.

  Returns:
    The same (mutated) manifest object.
  """
  # Make sure every nested structure the helpers touch exists up front.
  template = job_manifest.setdefault('spec', {}).setdefault('template', {})
  template.setdefault('metadata', {}).setdefault('annotations', {})
  pod_spec = template.setdefault('spec', {})
  pod_spec.setdefault('tolerations', [])
  pod_spec.setdefault('volumes', [])

  add_annotations(job_manifest, sub_networks)
  for mutate in (
      add_volumes,
      add_tolerations,
      add_tcpxo_daemon_container,
      update_gpu_containers,
  ):
    mutate(job_manifest)
  return job_manifest
58
+
59
+
24
60
  def decorate_jobset(jobset_manifest_str, sub_networks) -> str:
25
61
  """
26
62
  Decorates a JobSet manifest with the necessary components for tcpxo-daemon.
@@ -36,29 +72,11 @@ def decorate_jobset(jobset_manifest_str, sub_networks) -> str:
36
72
 
37
73
  for job in manifest['spec']['replicatedJobs']:
38
74
  job_manifest = job['template']
39
- job_manifest.setdefault('spec', {}).setdefault('template', {}).setdefault(
40
- 'metadata', {}
41
- ).setdefault('annotations', {})
42
- spec = (
43
- job_manifest.setdefault('spec', {})
44
- .setdefault('template', {})
45
- .setdefault('spec', {})
46
- )
47
- spec.setdefault('tolerations', [])
48
- spec.setdefault('volumes', [])
49
-
50
- add_annotations(job_manifest, sub_networks)
51
- add_volumes(job_manifest)
52
- add_tolerations(job_manifest)
53
- add_tcpxo_daemon_container(job_manifest)
54
- update_gpu_containers(job_manifest)
55
-
75
+ job_manifest = decorate_job(job_manifest, sub_networks)
56
76
  return yaml.dump(manifest, sort_keys=False)
57
77
 
58
78
 
59
- def add_annotations(job_manifest, sub_networks):
60
- """Adds or updates annotations in the Pod template."""
61
- annotations = job_manifest['spec']['template']['metadata']['annotations']
79
+ def get_interfaces_entry(sub_networks: list[str]) -> tuple[str, str]:
62
80
  interfaces = [
63
81
  '[',
64
82
  ' {"interfaceName":"eth0","network":"default"},',
@@ -68,22 +86,34 @@ def add_annotations(job_manifest, sub_networks):
68
86
  ],
69
87
  ']',
70
88
  ]
89
+ return 'networking.gke.io/interfaces', literal_string('\n'.join(interfaces))
90
+
91
+
92
def get_tcpxo_deamon_entry() -> tuple[str, str]:
  """Return the tcpxo-daemon device annotation key and its device-path list."""
  device_paths = [f'/dev/nvidia{i}' for i in range(8)] + [
      '/dev/nvidiactl',
      '/dev/nvidia-uvm',
      '/dev/dmabuf_import_helper',
  ]
  # One '- path: <dev>' line per device, trailing newline included —
  # identical to the hand-concatenated literal this replaces.
  paths_yaml = ''.join(f'- path: {path}\n' for path in device_paths)
  return 'devices.gke.io/container.tcpxo-daemon', literal_string(paths_yaml)
106
+
107
+
108
def add_annotations(job_manifest, sub_networks):
  """Adds or updates annotations in the Pod template."""
  pod_annotations = job_manifest['spec']['template']['metadata']['annotations']
  daemon_key, daemon_paths = get_tcpxo_deamon_entry()
  interfaces_key, interfaces_value = get_interfaces_entry(sub_networks)
  # Insertion order is kept (daemon, default-interface, interfaces) so the
  # rendered YAML key order is unchanged.
  new_entries = {
      daemon_key: daemon_paths,
      'networking.gke.io/default-interface': 'eth0',
      interfaces_key: interfaces_value,
  }
  pod_annotations.update(new_entries)
88
118
 
89
119
 
@@ -103,7 +133,7 @@ def add_volumes(job_manifest):
103
133
  volumes = job_manifest['spec']['template']['spec']['volumes']
104
134
  volumes.append({
105
135
  'name': 'libraries',
106
- 'hostPath': {'path': '/home/kubernetes/bin/nvidia/lib64'},
136
+ 'hostPath': {'path': '/home/kubernetes/bin/nvidia'},
107
137
  })
108
138
  volumes.append({'name': 'sys', 'hostPath': {'path': '/sys'}})
109
139
  volumes.append({'name': 'proc-sys', 'hostPath': {'path': '/proc/sys'}})
@@ -135,8 +165,8 @@ def add_tcpxo_daemon_container(job_manifest):
135
165
  ],
136
166
  'env': [{'name': 'LD_LIBRARY_PATH', 'value': '/usr/local/nvidia/lib64'}],
137
167
  }
138
- job_manifest['spec']['template']['spec']['containers'].insert(
139
- 0, tcpxo_daemon_container
168
+ job_manifest['spec']['template']['spec']['containers'].append(
169
+ tcpxo_daemon_container
140
170
  )
141
171
 
142
172
 
@@ -155,3 +185,6 @@ def update_gpu_containers(job_manifest):
155
185
  container['volumeMounts'].append(
156
186
  {'name': 'aperture-devices', 'mountPath': '/dev/aperture_devices'}
157
187
  )
188
+ container['volumeMounts'].append(
189
+ {'name': 'libraries', 'mountPath': '/usr/local/nvidia'}
190
+ )
xpk/main.py CHANGED
@@ -36,10 +36,11 @@ import sys
36
36
 
37
37
  from .parser.core import set_parser
38
38
  from .utils.console import xpk_print
39
-
39
+ from .utils.validation import validate_dependencies
40
40
  ################### Compatibility Check ###################
41
41
  # Check that the user runs the below version or greater.
42
42
 
43
+
43
44
  major_version_supported = 3
44
45
  minor_version_supported = 10
45
46
 
@@ -60,6 +61,7 @@ parser = argparse.ArgumentParser(description='xpk command', prog='xpk')
60
61
  set_parser(parser=parser)
61
62
 
62
63
  xpk_print('Starting xpk', flush=True)
64
+ validate_dependencies()
63
65
  main_args = parser.parse_args()
64
66
  main_args.enable_ray_cluster = False
65
67
  main_args.func(main_args)