xpk 0.8.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. xpk/commands/batch.py +5 -6
  2. xpk/commands/cluster.py +246 -73
  3. xpk/commands/cluster_gcluster.py +27 -0
  4. xpk/commands/common.py +40 -1
  5. xpk/commands/kjob_common.py +13 -1
  6. xpk/commands/run.py +4 -5
  7. xpk/commands/shell.py +2 -2
  8. xpk/commands/storage.py +24 -6
  9. xpk/commands/workload.py +66 -27
  10. xpk/core/blueprint/blueprint_generator.py +115 -47
  11. xpk/core/capacity.py +66 -6
  12. xpk/core/cluster.py +282 -13
  13. xpk/core/config.py +1 -65
  14. xpk/core/docker_manager.py +1 -1
  15. xpk/core/docker_resources.py +145 -72
  16. xpk/core/filestore.py +2 -6
  17. xpk/core/gcsfuse.py +22 -4
  18. xpk/core/jobset.py +143 -0
  19. xpk/core/kjob.py +21 -18
  20. xpk/core/kueue.py +194 -4
  21. xpk/core/mtc.py +195 -0
  22. xpk/core/network.py +23 -1
  23. xpk/core/nodepool.py +17 -4
  24. xpk/core/pathways.py +2 -3
  25. xpk/core/resources.py +21 -0
  26. xpk/core/storage.py +1 -95
  27. xpk/core/system_characteristics.py +1 -1
  28. xpk/core/workload.py +1 -45
  29. xpk/core/workload_decorators/rdma_decorator.py +8 -10
  30. xpk/core/workload_decorators/tcpx_decorator.py +185 -0
  31. xpk/core/workload_decorators/tcpxo_decorator.py +22 -14
  32. xpk/parser/cluster.py +589 -389
  33. xpk/parser/storage.py +12 -3
  34. xpk/parser/workload.py +21 -3
  35. xpk/utils/kubectl.py +4 -1
  36. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/METADATA +178 -96
  37. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/RECORD +41 -38
  38. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/WHEEL +1 -1
  39. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/entry_points.txt +0 -0
  40. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/licenses/LICENSE +0 -0
  41. {xpk-0.8.0.dist-info → xpk-0.10.0.dist-info}/top_level.txt +0 -0
xpk/core/docker_resources.py CHANGED
@@ -14,9 +14,11 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE, B200_DEVICE_TYPE
17
+ import os
18
+ import re
19
+ from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
18
20
  from .cluster import setup_k8s_env
19
- from .storage import GCS_FUSE_TYPE, GCP_FILESTORE_TYPE, Storage, get_storages_to_mount
21
+ from .storage import GCS_FUSE_TYPE, GCP_FILESTORE_TYPE, PARALLELSTORE_TYPE, GCE_PD_TYPE, LUSTRE_TYPE, Storage, get_storages_to_mount
20
22
  from .system_characteristics import AcceleratorType, SystemCharacteristics
21
23
 
22
24
 
@@ -64,6 +66,25 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
64
66
  str:
65
67
  YAML with the env config for the main container, as a YAML string.
66
68
  """
69
+ if system.accelerator_type == AcceleratorType['GPU']:
70
+ return get_gpu_env(args, system)
71
+
72
+ if system.accelerator_type == AcceleratorType['CPU']:
73
+ return get_cpu_env(args, system)
74
+
75
+ return format_env_dict(args.env, system) # pytype: disable=bad-return-type
76
+
77
+
78
+ def get_gpu_env(args, system) -> str:
79
+ """Generate environment variables for GPU nodepools
80
+ Args:
81
+ num_slices: Number of slices to be used in the workload.
82
+ env_vars: Environment variables, processed from user args.
83
+ system: system characteristics
84
+
85
+ Returns:
86
+ str: yaml containing env variables
87
+ """
67
88
  gpu_env_yaml = """
68
89
  - name: REPLICATED_JOB_NAME
69
90
  valueFrom:
@@ -73,8 +94,6 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
73
94
  valueFrom:
74
95
  fieldRef:
75
96
  fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
76
- - name: JAX_COORDINATOR_ADDRESS
77
- value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
78
97
  - name: NNODES
79
98
  value: "{args.num_nodes}"
80
99
  - name: NODE_RANK
@@ -84,36 +103,37 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
84
103
  - name: USE_GPUDIRECT
85
104
  value: {gpu_direct_name}
86
105
  - name: GPUS_PER_NODE
87
- value: "{system.chips_per_vm}"
88
- - name: JAX_COORDINATOR_PORT
89
- value: "6002"
106
+ value: "{chips_per_vm}"
90
107
  - name: COMMAND
91
108
  value: "{args.command}"
92
- {args.env}"""
93
-
94
- if system.accelerator_type == AcceleratorType['GPU']:
95
- gpu_direct_name = 'fastrak'
96
- if args.device_type == H100_DEVICE_TYPE:
97
- gpu_direct_name = 'tcpx'
98
- gpu_env_yaml += """
99
- - name: LD_LIBRARY_PATH
100
- value: /usr/local/nvidia/lib64
101
- """
102
- elif args.device_type == H100_MEGA_DEVICE_TYPE:
103
- gpu_direct_name = 'tcpxo'
104
- elif args.device_type == H200_DEVICE_TYPE:
105
- gpu_direct_name = 'rdma'
106
- return gpu_env_yaml.format(
107
- args=args, system=system, gpu_direct_name=gpu_direct_name
108
- )
109
-
110
- if system.accelerator_type == AcceleratorType['CPU']:
111
- return get_cpu_env(args.num_slices, args.env, system)
112
-
113
- return args.env # pytype: disable=bad-return-type
109
+ {custom_envs}"""
110
+
111
+ gpu_direct_name = 'fastrak'
112
+ if args.device_type == H100_DEVICE_TYPE:
113
+ gpu_direct_name = 'tcpx'
114
+ elif args.device_type == H100_MEGA_DEVICE_TYPE:
115
+ gpu_direct_name = 'tcpxo'
116
+ elif args.device_type == H200_DEVICE_TYPE:
117
+ gpu_direct_name = 'rdma'
118
+
119
+ gpu_env_dic = {
120
+ 'JAX_COORDINATOR_PORT': '6002',
121
+ 'JAX_COORDINATOR_ADDRESS': (
122
+ '$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)'
123
+ ),
124
+ }
125
+
126
+ args.env = gpu_env_dic | args.env
127
+
128
+ return gpu_env_yaml.format(
129
+ args=args,
130
+ chips_per_vm=system.chips_per_vm,
131
+ gpu_direct_name=gpu_direct_name,
132
+ custom_envs=format_env_dict(args.env, system),
133
+ )
114
134
 
115
135
 
116
- def get_cpu_env(num_slices, env_vars, system) -> str:
136
+ def get_cpu_env(args, system) -> str:
117
137
  """Generate environment variables for CPU nodepools
118
138
  Args:
119
139
  num_slices: Number of slices to be used in the workload.
@@ -136,19 +156,87 @@ def get_cpu_env(num_slices, env_vars, system) -> str:
136
156
  valueFrom:
137
157
  fieldRef:
138
158
  fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
139
- - name: PROCESSES_IN_JOB
140
- value: "{processes_in_job}"
141
- - name: JAX_PROCESS_COUNT
142
- value: "{process_count}"
143
- {env_vars}
144
- - name: JAX_COORDINATOR_ADDRESS
145
- value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
159
+ {custom_envs}
146
160
  """
147
- return yaml.format(
148
- processes_in_job=system.vms_per_slice,
149
- process_count=calculate_process_count(num_slices, system.vms_per_slice),
150
- env_vars=env_vars,
151
- )
161
+
162
+ cpu_env_dic = {
163
+ 'PROCESSES_IN_JOB': str(system.vms_per_slice),
164
+ 'JAX_PROCESS_COUNT': str(
165
+ calculate_process_count(args.num_slices, system.vms_per_slice)
166
+ ),
167
+ 'JAX_COORDINATOR_ADDRESS': (
168
+ '$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)'
169
+ ),
170
+ }
171
+
172
+ args.env = cpu_env_dic | args.env
173
+
174
+ return yaml.format(custom_envs=format_env_dict(args.env, system))
175
+
176
+
177
+ def format_env_dict(env, system: SystemCharacteristics) -> str:
178
+ if system.accelerator_type == AcceleratorType['GPU']:
179
+ # For GPUs, it has two more spaces ahead of name and value respectively
180
+ env_format = '''
181
+ - name: {key}
182
+ value: "{value}"'''
183
+ else:
184
+ env_format = '''
185
+ - name: {key}
186
+ value: "{value}"'''
187
+ return ''.join(env_format.format(key=k, value=v) for k, v in env.items())
188
+
189
+
190
+ def parse_env_config(args, tensorboard_config):
191
+ """Parses the environment configurations to the a dictionary.
192
+
193
+ Args:
194
+ args: user provided arguments for running the command.
195
+ tensorboard_config: configuration of Vertex Tensorboard.
196
+ system: system characteristics.
197
+ """
198
+ env = {}
199
+
200
+ env_pat = re.compile(r'(^[a-zA-Z_][a-zA-Z0-9_]*?)(?:=(.*))?$', re.M)
201
+ if args.env_file:
202
+ print('Setting container environment from', args.env_file)
203
+ with open(file=args.env_file, mode='r', encoding='utf-8') as f:
204
+ for match in env_pat.finditer(f.read()):
205
+ variable = match.group(1)
206
+ if match.group(2) is not None:
207
+ env[variable] = match.group(2)
208
+ else:
209
+ assert variable in os.environ, (
210
+ f'Variable {variable} is not set in the current '
211
+ 'environment, a value must be specified.'
212
+ )
213
+ env[variable] = os.environ[variable]
214
+ if args.env:
215
+ for var in args.env:
216
+ match = env_pat.match(var)
217
+ assert match and match.group(2) is not None, (
218
+ 'Invalid environment variable, format must be '
219
+ f'`--env VARIABLE=value`: {var}'
220
+ )
221
+ variable = match.group(1)
222
+ env[variable] = match.group(2)
223
+
224
+ if not args.use_pathways:
225
+ if args.debug_dump_gcs:
226
+ if 'XLA_FLAGS' in env:
227
+ raise ValueError(
228
+ 'Conflict: XLA_FLAGS defined in both --debug_dump_gcs '
229
+ 'and environment file. Please choose one way to define '
230
+ 'XLA_FLAGS.'
231
+ )
232
+ env['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump/'
233
+
234
+ if tensorboard_config:
235
+ env['UPLOAD_DATA_TO_TENSORBOARD'] = True
236
+ for key, value in tensorboard_config.items():
237
+ env[key.upper()] = value
238
+
239
+ args.env = env
152
240
 
153
241
 
154
242
  def get_volumes(args, system: SystemCharacteristics) -> str:
@@ -188,13 +276,13 @@ def get_volumes(args, system: SystemCharacteristics) -> str:
188
276
  setup_k8s_env(args), args.storage
189
277
  )
190
278
  for storage in storages:
191
- if storage.type == GCS_FUSE_TYPE:
192
- volumes += f"""- name: {storage.pv}
193
- persistentVolumeClaim:
194
- claimName: {storage.pvc}
195
- readOnly: {storage.readonly}
196
- """
197
- if storage.type == GCP_FILESTORE_TYPE:
279
+ if storage.type in {
280
+ GCS_FUSE_TYPE,
281
+ GCP_FILESTORE_TYPE,
282
+ PARALLELSTORE_TYPE,
283
+ GCE_PD_TYPE,
284
+ LUSTRE_TYPE,
285
+ }:
198
286
  volumes += f"""- name: {storage.pv}
199
287
  persistentVolumeClaim:
200
288
  claimName: {storage.pvc}
@@ -235,34 +323,19 @@ def get_volume_mounts(args, system: SystemCharacteristics) -> str:
235
323
  mountPath: /shared-volume
236
324
  """
237
325
  elif system.accelerator_type == AcceleratorType['GPU']:
238
- if system.device_type == H100_DEVICE_TYPE:
239
- volume_mount_yaml = """- name: nvidia-install-dir-host
240
- mountPath: /usr/local/nvidia/lib64
241
- - name: tcpx-nccl-plugin-volume
242
- mountPath: /usr/local/tcpx
243
- - name: tcpd-socket
244
- mountPath: /tmp
245
- - name: shared-memory
246
- mountPath: /dev/shm
247
- - name: workload-terminated-volume
248
- mountPath: /usr/share/workload"""
249
- elif (
250
- system.device_type == H100_MEGA_DEVICE_TYPE
251
- or system.device_type == H200_DEVICE_TYPE
252
- or system.device_type == B200_DEVICE_TYPE
253
- ):
254
- volume_mount_yaml = ''
326
+ volume_mount_yaml = ''
255
327
 
256
328
  storages: list[Storage] = get_storages_to_mount(
257
329
  setup_k8s_env(args), args.storage
258
330
  )
259
331
  for storage in storages:
260
- if storage.type == GCS_FUSE_TYPE:
261
- volume_mount_yaml += f"""- name: {storage.pv}
262
- mountPath: {storage.mount_point}
263
- readOnly: {storage.readonly}
264
- """
265
- if storage.type == GCP_FILESTORE_TYPE:
332
+ if storage.type in {
333
+ GCS_FUSE_TYPE,
334
+ GCP_FILESTORE_TYPE,
335
+ PARALLELSTORE_TYPE,
336
+ GCE_PD_TYPE,
337
+ LUSTRE_TYPE,
338
+ }:
266
339
  volume_mount_yaml += f"""- name: {storage.pv}
267
340
  mountPath: {storage.mount_point}
268
341
  readOnly: {storage.readonly}
xpk/core/filestore.py CHANGED
@@ -200,9 +200,7 @@ class FilestoreClient:
200
200
  ] = f"projects/{self.project}/global/networks/{network}"
201
201
  return data
202
202
 
203
- def create_pv(
204
- self, name: str, vol: str, access_mode: str, mount_options: str
205
- ) -> dict:
203
+ def create_pv(self, name: str, vol: str, access_mode: str) -> dict:
206
204
  """Create a yaml representing filestore PersistentVolume."""
207
205
  data = templates.load(FS_PV_PATH)
208
206
  data["metadata"]["name"] = get_pv_name(name)
@@ -217,7 +215,6 @@ class FilestoreClient:
217
215
  0
218
216
  ].ip_addresses[0]
219
217
  data["spec"]["csi"]["volumeAttributes"]["volume"] = vol
220
- data["spec"]["mountOptions"] = mount_options.split(",")
221
218
  return data
222
219
 
223
220
  def create_pvc(self, name: str, access_mode: str) -> dict:
@@ -238,10 +235,9 @@ class FilestoreClient:
238
235
  vol: str,
239
236
  access_mode: str,
240
237
  network: str,
241
- mount_options: str,
242
238
  ) -> list[dict]:
243
239
  self.load_instance()
244
- pv = self.create_pv(name, vol, access_mode, mount_options)
240
+ pv = self.create_pv(name, vol, access_mode)
245
241
  pvc = self.create_pvc(name, access_mode)
246
242
  sc = self.create_sc(name, network)
247
243
  return [pv, pvc, sc]
xpk/core/gcsfuse.py CHANGED
@@ -20,11 +20,21 @@ FUSE_PV_PATH = "/../templates/fuse-pv.yaml"
20
20
  FUSE_PVC_PATH = "/../templates/fuse-pvc.yaml"
21
21
 
22
22
 
23
- def create_pv(name: str, size: int, bucket: str, mount_options: str) -> dict:
23
+ def create_pv(
24
+ name: str,
25
+ size: int,
26
+ bucket: str,
27
+ mount_options: str,
28
+ prefetch_metadata: bool,
29
+ ) -> dict:
24
30
  data = templates.load(FUSE_PV_PATH)
25
31
  data["metadata"]["name"] = f"{name}-pv"
26
32
  data["spec"]["capacity"]["storage"] = f"{size}Gi"
27
33
  data["spec"]["csi"]["volumeHandle"] = bucket
34
+ if prefetch_metadata:
35
+ data["spec"]["csi"]["volumeAttributes"][
36
+ "gcsfuseMetadataPrefetchOnMount"
37
+ ] = "true"
28
38
  data["spec"]["mountOptions"] = mount_options.split(",")
29
39
  return data
30
40
 
@@ -38,16 +48,24 @@ def create_pvc(name: str, size: int) -> dict:
38
48
 
39
49
 
40
50
  def manifest(
41
- name: str, bucket: str, size: int, mount_options: str
51
+ name: str,
52
+ bucket: str,
53
+ size: int,
54
+ mount_options: str,
55
+ prefetch_metadata: bool,
42
56
  ) -> list[dict]:
43
- """Creates GCS FUSE manifest file.
57
+ """Creates GCS FUSE storage manifest file.
44
58
 
45
59
  Args:
46
60
  name (str): base name of the volumes
47
61
  bucket (str): name of the storage bucket
48
62
  size (str): size of the storage (in GB)
63
+ prefetch_metadata (bool): if set, then enables metadata pre-population when mounting the volume
49
64
  mount_options (str): comma-separated list of mountOptions for PersistentVolume
65
+
66
+ Returns:
67
+ list[dict]: list of manifests
50
68
  """
51
- pv = create_pv(name, size, bucket, mount_options)
69
+ pv = create_pv(name, size, bucket, mount_options, prefetch_metadata)
52
70
  pvc = create_pvc(name, size)
53
71
  return [pv, pvc]
xpk/core/jobset.py ADDED
@@ -0,0 +1,143 @@
1
+ """
2
+ Copyright 2024 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import math
18
+
19
+ from ..utils.console import xpk_exit, xpk_print
20
+ from ..utils.file import write_tmp_file
21
+ from ..core.kueue import (
22
+ MEMORY_SIZE_PER_VM,
23
+ MIN_MEMORY_LIMIT_SIZE,
24
+ )
25
+ from .commands import (
26
+ run_command_for_value,
27
+ run_command_with_updates_retry,
28
+ )
29
+
30
+ jobset_controller_manager_yml = """
31
+ apiVersion: apps/v1
32
+ kind: Deployment
33
+ metadata:
34
+ labels:
35
+ app.kubernetes.io/component: manager
36
+ app.kubernetes.io/created-by: jobset
37
+ app.kubernetes.io/instance: controller-manager
38
+ app.kubernetes.io/managed-by: kustomize
39
+ app.kubernetes.io/name: deployment
40
+ app.kubernetes.io/part-of: jobset
41
+ control-plane: controller-manager
42
+ name: jobset-controller-manager
43
+ namespace: jobset-system
44
+ spec:
45
+ replicas: 1
46
+ selector:
47
+ matchLabels:
48
+ control-plane: controller-manager
49
+ template:
50
+ metadata:
51
+ annotations:
52
+ kubectl.kubernetes.io/default-container: manager
53
+ labels:
54
+ control-plane: controller-manager
55
+ spec:
56
+ containers:
57
+ - args:
58
+ - --config=/controller_manager_config.yaml
59
+ - --zap-log-level=2
60
+ command:
61
+ - /manager
62
+ image: registry.k8s.io/jobset/jobset:v0.8.0
63
+ livenessProbe:
64
+ httpGet:
65
+ path: /healthz
66
+ port: 8081
67
+ initialDelaySeconds: 15
68
+ periodSeconds: 20
69
+ name: manager
70
+ ports:
71
+ - containerPort: 9443
72
+ name: webhook-server
73
+ protocol: TCP
74
+ readinessProbe:
75
+ httpGet:
76
+ path: /readyz
77
+ port: 8081
78
+ initialDelaySeconds: 5
79
+ periodSeconds: 10
80
+ resources:
81
+ limits:
82
+ memory: {memory_limit_size}
83
+ requests:
84
+ cpu: 500m
85
+ memory: 128Mi
86
+ securityContext:
87
+ allowPrivilegeEscalation: false
88
+ capabilities:
89
+ drop:
90
+ - ALL
91
+ volumeMounts:
92
+ - mountPath: /controller_manager_config.yaml
93
+ name: manager-config
94
+ subPath: controller_manager_config.yaml
95
+ - mountPath: /tmp/k8s-webhook-server/serving-certs
96
+ name: cert
97
+ readOnly: true
98
+ securityContext:
99
+ runAsNonRoot: true
100
+ serviceAccountName: jobset-controller-manager
101
+ terminationGracePeriodSeconds: 10
102
+ volumes:
103
+ - configMap:
104
+ name: jobset-manager-config
105
+ name: manager-config
106
+ - name: cert
107
+ secret:
108
+ defaultMode: 420
109
+ secretName: jobset-webhook-server-cert
110
+ """
111
+
112
+
113
+ def update_jobset_resources_if_necessary(args):
114
+ """Update the jobset manifest to increase the resources for the jobset controller manager.
115
+
116
+ Args:
117
+ args: user provided arguments for running the command.
118
+
119
+ Returns:
120
+ 0 if successful and 1 otherwise.
121
+ """
122
+ # Get total number of nodes
123
+ cmd_total_node_num = 'kubectl get node --no-headers | wc -l'
124
+ return_code, out = run_command_for_value(
125
+ cmd_total_node_num, 'Count total nodes', args
126
+ )
127
+ if return_code != 0:
128
+ xpk_exit(1)
129
+ # 1.2MiB per VM or 4GiB (whichever is greater).
130
+ new_memory_limit = (
131
+ f'{max(math.ceil(int(out) * MEMORY_SIZE_PER_VM), MIN_MEMORY_LIMIT_SIZE)}Mi'
132
+ )
133
+ yml_string = jobset_controller_manager_yml.format(
134
+ memory_limit_size=new_memory_limit,
135
+ )
136
+ tmp = write_tmp_file(yml_string)
137
+ command = f'kubectl apply -f {str(tmp.file.name)}'
138
+
139
+ task = 'Updating jobset Controller Manager resources'
140
+ return_code = run_command_with_updates_retry(command, task, args)
141
+ if return_code != 0:
142
+ xpk_print(f'{task} returned ERROR {return_code}')
143
+ return return_code
xpk/core/kjob.py CHANGED
@@ -22,16 +22,9 @@ from kubernetes import client as k8s_client
22
22
  from kubernetes.client import ApiClient
23
23
  from kubernetes.client.rest import ApiException
24
24
 
25
- from ..core.blueprint.blueprint_generator import (
26
- get_subnetworks_for_a3mega,
27
- get_subnetworks_for_a3ultra,
28
- get_subnetworks_for_a4,
29
- )
30
- from ..core.capacity import H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
31
- from ..core.storage import GCS_FUSE_ANNOTATIONS, PARALLELSTORE_ANNOTATIONS
32
- from ..core.workload_decorators import rdma_decorator, tcpxo_decorator
33
25
  from ..utils import templates
34
26
  from ..utils.console import xpk_exit, xpk_print
27
+ from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
35
28
  from .cluster import DEFAULT_NAMESPACE, XPK_SA, setup_k8s_env
36
29
  from .commands import (
37
30
  run_command_for_value,
@@ -46,12 +39,21 @@ from .config import (
46
39
  KJOB_SHELL_WORKING_DIRECTORY,
47
40
  XpkConfig,
48
41
  )
49
- from .resources import (
50
- AcceleratorType,
51
- SystemCharacteristics,
52
- get_cluster_system_characteristics,
42
+ from .network import get_cluster_subnetworks
43
+ from .system_characteristics import AcceleratorType, SystemCharacteristics
44
+ from .resources import get_cluster_system_characteristics
45
+ from .storage import (
46
+ GCS_FUSE_ANNOTATIONS,
47
+ PARALLELSTORE_ANNOTATIONS,
48
+ get_auto_mount_gcsfuse_storages,
49
+ get_auto_mount_parallelstore_storages,
50
+ get_auto_mount_storages,
51
+ )
52
+ from .workload_decorators import (
53
+ rdma_decorator,
54
+ tcpx_decorator,
55
+ tcpxo_decorator,
53
56
  )
54
- from .storage import get_auto_mount_gcsfuse_storages, get_auto_mount_storages, get_auto_mount_parallelstore_storages
55
57
  from .workload_decorators.tcpxo_decorator import get_tcpxo_deamon_entry
56
58
 
57
59
  KJOB_API_GROUP_NAME = "kjobctl.x-k8s.io"
@@ -164,8 +166,8 @@ Kueue_TAS_annotation = "kueue.x-k8s.io/podset-preferred-topology=cloud.google.co
164
166
  default_interface_annotation = "networking.gke.io/default-interface=eth0"
165
167
 
166
168
 
167
- def get_a4_pod_template_annotations() -> tuple[str, str]:
168
- sub_networks = get_subnetworks_for_a4()
169
+ def get_a4_pod_template_annotations(args) -> tuple[str, str]:
170
+ sub_networks = get_cluster_subnetworks(args)
169
171
  interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
170
172
  sub_networks
171
173
  )
@@ -177,7 +179,7 @@ def get_a4_pod_template_annotations() -> tuple[str, str]:
177
179
 
178
180
 
179
181
  def get_a3ultra_pod_template_annotations(args: Namespace) -> tuple[str, str]:
180
- sub_networks = get_subnetworks_for_a3ultra(args.cluster)
182
+ sub_networks = get_cluster_subnetworks(args)
181
183
  interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
182
184
  sub_networks
183
185
  )
@@ -192,7 +194,7 @@ def get_a3mega_pod_template_annotations(
192
194
  args: Namespace,
193
195
  ) -> tuple[str, str, str]:
194
196
  """Adds or updates annotations in the Pod template."""
195
- sub_networks = get_subnetworks_for_a3mega(args.cluster)
197
+ sub_networks = get_cluster_subnetworks(args)
196
198
  tcpxo_deamon_key, tcpxo_deamon_paths = get_tcpxo_deamon_entry()
197
199
  interfaces_key, interfaces_value = tcpxo_decorator.get_interfaces_entry(
198
200
  sub_networks
@@ -267,6 +269,8 @@ def create_app_profile_instance(
267
269
 
268
270
  def decorate_job_template_with_gpu(yml_string: str, gpu_type: str) -> str:
269
271
  job_spec = yaml.safe_load(yml_string)["template"]
272
+ if gpu_type == H100_DEVICE_TYPE:
273
+ job_spec = tcpx_decorator.decorate_kjob_template(job_spec)
270
274
  if gpu_type == H100_MEGA_DEVICE_TYPE:
271
275
  job_spec = tcpxo_decorator.decorate_kjob_template(job_spec)
272
276
  if gpu_type == H200_DEVICE_TYPE:
@@ -373,7 +377,6 @@ def prepare_kjob(args: Namespace) -> int:
373
377
  job_err_code = create_job_template_instance(args, system, service_account)
374
378
  if job_err_code > 0:
375
379
  return job_err_code
376
-
377
380
  pod_err_code = create_pod_template_instance(args, service_account)
378
381
  if pod_err_code > 0:
379
382
  return pod_err_code