xpk 0.7.2__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. xpk/commands/batch.py +19 -13
  2. xpk/commands/cluster.py +240 -71
  3. xpk/commands/cluster_gcluster.py +22 -5
  4. xpk/commands/common.py +33 -1
  5. xpk/commands/info.py +2 -4
  6. xpk/commands/job.py +7 -8
  7. xpk/commands/kjob_common.py +30 -18
  8. xpk/commands/run.py +17 -12
  9. xpk/commands/shell.py +3 -4
  10. xpk/commands/storage.py +75 -19
  11. xpk/commands/workload.py +161 -324
  12. xpk/core/blueprint/blueprint_definitions.py +2 -0
  13. xpk/core/blueprint/blueprint_generator.py +335 -45
  14. xpk/core/capacity.py +1 -0
  15. xpk/core/cluster.py +193 -12
  16. xpk/core/config.py +3 -1
  17. xpk/core/docker_manager.py +1 -1
  18. xpk/core/docker_resources.py +9 -21
  19. xpk/core/filestore.py +5 -1
  20. xpk/core/gcsfuse.py +27 -6
  21. xpk/core/kjob.py +66 -20
  22. xpk/core/kueue.py +30 -0
  23. xpk/core/mtc.py +195 -0
  24. xpk/core/nap.py +4 -0
  25. xpk/core/network.py +34 -22
  26. xpk/core/nodepool.py +28 -26
  27. xpk/core/pathways.py +165 -210
  28. xpk/core/resources.py +21 -0
  29. xpk/core/scheduling.py +36 -0
  30. xpk/core/storage.py +66 -12
  31. xpk/core/system_characteristics.py +9 -0
  32. xpk/core/workload.py +28 -83
  33. xpk/core/workload_decorators/rdma_decorator.py +11 -15
  34. xpk/core/workload_decorators/storage_decorator.py +8 -3
  35. xpk/core/workload_decorators/tcpx_decorator.py +179 -0
  36. xpk/core/workload_decorators/tcpxo_decorator.py +17 -16
  37. xpk/parser/cluster.py +574 -381
  38. xpk/parser/storage.py +25 -5
  39. xpk/parser/workload.py +59 -31
  40. xpk/utils/kubectl.py +4 -1
  41. {xpk-0.7.2.dist-info → xpk-0.9.0.dist-info}/METADATA +192 -93
  42. {xpk-0.7.2.dist-info → xpk-0.9.0.dist-info}/RECORD +46 -44
  43. {xpk-0.7.2.dist-info → xpk-0.9.0.dist-info}/WHEEL +1 -1
  44. {xpk-0.7.2.dist-info → xpk-0.9.0.dist-info}/entry_points.txt +0 -0
  45. {xpk-0.7.2.dist-info → xpk-0.9.0.dist-info}/licenses/LICENSE +0 -0
  46. {xpk-0.7.2.dist-info → xpk-0.9.0.dist-info}/top_level.txt +0 -0
xpk/core/cluster.py CHANGED
@@ -14,27 +14,37 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import yaml
17
18
  from google.api_core.exceptions import PermissionDenied
18
19
  from google.cloud import resourcemanager_v3
19
20
  from kubernetes import client as k8s_client
20
21
  from kubernetes import config
21
22
  from kubernetes.client.exceptions import ApiException
22
- from .resources import get_cluster_system_characteristics
23
23
 
24
24
  from ..utils.console import xpk_exit, xpk_print
25
- from .capacity import H100_DEVICE_TYPE
25
+ from .capacity import B200_DEVICE_TYPE, H100_DEVICE_TYPE, H200_DEVICE_TYPE
26
26
  from .commands import (
27
27
  run_command_for_value,
28
28
  run_command_with_updates,
29
29
  run_command_with_updates_retry,
30
30
  )
31
- from .gcloud_context import add_zone_and_project, get_gke_server_config, zone_to_region
31
+ from .gcloud_context import (
32
+ add_zone_and_project,
33
+ get_gke_server_config,
34
+ zone_to_region,
35
+ )
32
36
  from .nodepool import upgrade_gke_nodepools_version
37
+ from .resources import get_cluster_system_characteristics
33
38
  from .system_characteristics import SystemCharacteristics
34
39
 
35
- JOBSET_VERSION = 'v0.7.2'
36
- INSTALLER_NCC_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-tcpx-installer.yaml'
37
- INSTALLER_NCC_TCPXO = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpxo/nccl-tcpxo-installer.yaml'
40
+ JOBSET_VERSION = 'v0.8.0'
41
+ PATHWAYS_JOB_VERSION = 'v0.1.1'
42
+ INSTALLER_NCCL_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-tcpx-installer.yaml'
43
+ INSTALLER_NCCL_TCPXO = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpxo/nccl-tcpxo-installer.yaml'
44
+ INSTALLER_NCCL_RDMA = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-rdma/nccl-rdma-installer.yaml'
45
+ CONFIG_NCCL_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-config.yaml'
46
+ NRI_DEVICE_INJECTOR = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nri_device_injector/nri-device-injector.yaml'
47
+ MGLRU_DISABLE = 'https://raw.githubusercontent.com/GoogleCloudPlatform/cluster-toolkit/refs/heads/main/examples/gke-a3-ultragpu/mglru-disable.yaml'
38
48
 
39
49
  DEFAULT_NAMESPACE = 'default'
40
50
  XPK_SA = 'xpk-sa'
@@ -71,6 +81,35 @@ def set_jobset_on_cluster(args) -> int:
71
81
  return return_code
72
82
 
73
83
 
84
def set_pathways_job_on_cluster(args) -> int:
  """Apply the PathwaysJob install manifest to the cluster (server-side apply).

  The manifest is taken from the pathways-job GitHub release matching
  PATHWAYS_JOB_VERSION. On failure, a hint about missing Kubernetes
  permissions is printed.

  Args:
    args: user provided arguments for running the command.

  Returns:
    The return code of the install command: 0 if successful, non-zero
    otherwise.
  """
  manifest_url = (
      'https://github.com/google/pathways-job/releases/download/'
      f'{PATHWAYS_JOB_VERSION}/install.yaml'
  )
  command = f'kubectl apply --server-side -f {manifest_url}'
  task = f'Install PathwaysJob on {args.cluster}'
  return_code = run_command_with_updates_retry(command, task, args)

  # Success path first; everything below is failure reporting.
  if return_code == 0:
    return return_code

  xpk_print(f'{task} returned with ERROR {return_code}.\n')
  xpk_print(
      "This LIKELY means you're missing Kubernetes Permissions, you can"
      ' validate this by checking if the error references permission problems'
      ' such as `requires one of ["container.*"] permission(s)`. Follow our'
      ' readme:'
      ' https://github.com/google/xpk/blob/main/README.md#troubleshooting for'
      ' instructions on how to fix these permissions.'
  )
  return return_code
111
+
112
+
74
113
  def install_nccl_on_cluster(args, system: SystemCharacteristics) -> int:
75
114
  """Install NCCL plugin on the cluster.
76
115
 
@@ -82,9 +121,11 @@ def install_nccl_on_cluster(args, system: SystemCharacteristics) -> int:
82
121
  0 if successful and 1 otherwise.
83
122
  """
84
123
  if system.device_type == H100_DEVICE_TYPE:
85
- command = f'kubectl apply -f {INSTALLER_NCC_TCPX}'
124
+ command = f'kubectl apply -f {INSTALLER_NCCL_TCPX}'
125
+ elif system.device_type in [H200_DEVICE_TYPE, B200_DEVICE_TYPE]:
126
+ command = f'kubectl apply -f {INSTALLER_NCCL_RDMA}'
86
127
  else:
87
- command = f'kubectl apply -f {INSTALLER_NCC_TCPXO}'
128
+ command = f'kubectl apply -f {INSTALLER_NCCL_TCPXO}'
88
129
 
89
130
  return_code = run_command_with_updates(
90
131
  command, 'Install NCCL Plugin On Cluster', args
@@ -96,9 +137,108 @@ def install_nccl_on_cluster(args, system: SystemCharacteristics) -> int:
96
137
  )
97
138
  return 1
98
139
 
140
+ if system.device_type == H100_DEVICE_TYPE:
141
+ command = f'kubectl apply -f {CONFIG_NCCL_TCPX}'
142
+
143
+ return_code = run_command_with_updates(
144
+ command, 'Install NCCL Config On Cluster', args
145
+ )
146
+
147
+ if return_code != 0:
148
+ xpk_print(
149
+ f'Install NCCL Config On Cluster request returned ERROR {return_code}'
150
+ )
151
+ return 1
152
+
153
+ return 0
154
+
155
+
156
def disable_mglru_on_cluster(args) -> int:
  """Disable MGLRU on the cluster by applying the MGLRU_DISABLE manifest.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  command = f'kubectl apply -f {MGLRU_DISABLE}'
  return_code = run_command_with_updates(
      command, 'Disable MGLRU On Cluster', args
  )

  if return_code != 0:
    # Include the return code so the failure can be diagnosed from the log.
    xpk_print(
        f'Disable MGLRU On Cluster request returned ERROR {return_code}'
    )
    return 1

  return 0
175
+
176
+
177
def install_nri_on_cluster(args) -> int:
  """Install the NRI Device Injector on the cluster.

  Applies the NRI_DEVICE_INJECTOR manifest with `kubectl apply`.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  task = 'Install NRI Device Injector On Cluster'
  return_code = run_command_with_updates(
      f'kubectl apply -f {NRI_DEVICE_INJECTOR}', task, args
  )
  if return_code == 0:
    return 0

  xpk_print(
      'Install NRI Device Injector On Cluster request returned ERROR'
      f' {return_code}'
  )
  return 1
100
200
 
101
201
 
202
def get_cluster_nodes_info(args) -> list[dict]:
  """Get the descriptions of the cluster's nodes, parsed from YAML.

  Runs `kubectl get nodes -o yaml`. If the command fails, xpk_exit is called
  with the command's error code.

  Args:
    args: user provided arguments for running the command.

  Returns:
    The `items` list from the parsed YAML output.
  """
  xpk_print("Getting cluster's info...")
  err_code, raw_yaml = run_command_for_value(
      command='kubectl get nodes -o yaml',
      task='Get cluster nodes info',
      global_args=args,
  )
  if err_code != 0:
    xpk_exit(err_code)
  return yaml.safe_load(raw_yaml)['items']  # pytype: disable=bad-return-type
+
223
+
224
def count_nodes_on_cluster(args, system: SystemCharacteristics) -> int:
  """Count the cluster nodes whose accelerator label matches the system.

  Only H200_DEVICE_TYPE is supported; for any other device type a message is
  printed and xpk_exit(1) is called.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics; `device_type` selects support and
      `gke_accelerator` is the label value to match.

  Returns:
    Number of nodes whose `cloud.google.com/gke-accelerator` label equals
    `system.gke_accelerator`.
  """
  # Reject unsupported device types before querying the cluster, so no
  # kubectl call is made for a request that cannot succeed.
  if system.device_type != H200_DEVICE_TYPE:
    xpk_print(
        'Automatic node detection is not supported for device type:'
        f' {system.device_type}'
    )
    xpk_exit(1)

  accelerator_label = 'cloud.google.com/gke-accelerator'
  num_nodes = 0
  for node in get_cluster_nodes_info(args):
    # A node without metadata/labels simply cannot match; do not crash on it.
    labels = (node.get('metadata') or {}).get('labels') or {}
    if (
        accelerator_label in labels
        and labels[accelerator_label] == system.gke_accelerator
    ):
      num_nodes += 1
  return num_nodes
240
+
241
+
102
242
  def get_cluster_network(args) -> str:
103
243
  xpk_print("Getting cluster's VPC network...")
104
244
  cluster_network_cmd = (
@@ -135,8 +275,48 @@ def update_cluster_with_gcpfilestore_driver_if_necessary(args) -> int:
135
275
  return 0
136
276
 
137
277
 
278
def update_cluster_with_parallelstore_driver_if_necessary(args) -> int:
  """Enable the Parallelstore CSI driver on a GKE cluster unless already on.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if the driver was already enabled or the update succeeded, otherwise
    the error code of the cluster update.
  """
  if is_driver_enabled_on_cluster(args, driver='parallelstoreCsiDriver'):
    return 0

  update_code = update_gke_cluster_with_addon(args, 'ParallelstoreCsiDriver')
  if update_code <= 0:
    return 0

  xpk_print('Updating GKE cluster to enable Parallelstore CSI driver failed!')
  return update_code
295
+
296
+
297
def update_cluster_with_pd_driver_if_necessary(args) -> int:
  """Enable the PersistentDisk CSI driver on a GKE cluster unless already on.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if the driver was already enabled or the update succeeded, otherwise
    the error code of the cluster update.
  """
  if is_driver_enabled_on_cluster(args, driver='gcePersistentDiskCsiDriver'):
    return 0

  update_code = update_gke_cluster_with_addon(
      args, 'GcePersistentDiskCsiDriver'
  )
  if update_code <= 0:
    return 0

  xpk_print('Updating GKE cluster to enable PersistentDisk CSI driver failed!')
  return update_code
316
+
317
+
138
318
  def is_driver_enabled_on_cluster(args, driver: str) -> bool:
139
- """Checks if GCSFuse CSI driver is enabled on the cluster.
319
+ """Checks if the CSI driver is enabled on the cluster.
140
320
  Args:
141
321
  args: user provided arguments for running the command.
142
322
  driver (str) : name of the driver
@@ -148,14 +328,14 @@ def is_driver_enabled_on_cluster(args, driver: str) -> bool:
148
328
  f' --project={args.project} --region={zone_to_region(args.zone)}'
149
329
  f' --format="value(addonsConfig.{driver}Config.enabled)"'
150
330
  )
151
- return_code, gcsfuse_driver_enabled = run_command_for_value(
331
+ return_code, driver_enabled = run_command_for_value(
152
332
  command,
153
333
  f'Checks if {driver} driver is enabled in cluster describe.',
154
334
  args,
155
335
  )
156
336
  if return_code != 0:
157
337
  xpk_exit(return_code)
158
- if gcsfuse_driver_enabled.lower() == 'true':
338
+ if driver_enabled.strip().lower() == 'true':
159
339
  xpk_print(f'{driver} driver is enabled on the cluster, no update needed.')
160
340
  return True
161
341
  return False
@@ -446,7 +626,7 @@ def is_gcsfuse_driver_enabled_on_cluster(args) -> bool:
446
626
  )
447
627
  if return_code != 0:
448
628
  xpk_exit(return_code)
449
- if gcsfuse_driver_enabled.lower() == 'true':
629
+ if gcsfuse_driver_enabled.strip().lower() == 'true':
450
630
  xpk_print('GCSFuse CSI driver is enabled on the cluster, no update needed.')
451
631
  return True
452
632
  return False
@@ -551,6 +731,7 @@ def get_cluster_credentials(args) -> None:
551
731
  command = (
552
732
  'gcloud container clusters get-credentials'
553
733
  f' {args.cluster} --region={zone_to_region(args.zone)}'
734
+ ' --dns-endpoint'
554
735
  f' --project={args.project} &&'
555
736
  ' kubectl config view && kubectl config set-context --current'
556
737
  ' --namespace=default'
xpk/core/config.py CHANGED
@@ -24,7 +24,7 @@ from ..utils.console import xpk_print
24
24
  from .system_characteristics import AcceleratorType, SystemCharacteristics
25
25
 
26
26
  # This is the version for XPK PyPI package
27
- __version__ = 'v0.7.2'
27
+ __version__ = 'v0.9.0'
28
28
  XPK_CURRENT_VERSION = __version__
29
29
  XPK_CONFIG_FILE = os.path.expanduser('~/.config/xpk/config.yaml')
30
30
 
@@ -39,6 +39,7 @@ KJOB_SHELL_IMAGE = 'shell-image'
39
39
  KJOB_SHELL_INTERACTIVE_COMMAND = 'shell-interactive-command'
40
40
  KJOB_SHELL_WORKING_DIRECTORY = 'shell-working-directory'
41
41
  CONFIGS_KEY = 'configs'
42
+ GKE_ENDPOINT_KEY = 'gke-endpoint'
42
43
  DEPENDENCIES_KEY = 'deps-verified-version'
43
44
  XPK_CONFIG_FILE = os.path.expanduser('~/.config/xpk/config.yaml')
44
45
 
@@ -47,6 +48,7 @@ DEFAULT_KEYS = [
47
48
  CLUSTER_NAME_KEY,
48
49
  PROJECT_KEY,
49
50
  ZONE_KEY,
51
+ GKE_ENDPOINT_KEY,
50
52
  DEPENDENCIES_KEY,
51
53
  KJOB_BATCH_IMAGE,
52
54
  KJOB_BATCH_WORKING_DIRECTORY,
xpk/core/docker_manager.py CHANGED
@@ -30,7 +30,7 @@ import time
30
30
  DockerRunCommandExitCode = 135
31
31
  dockerBuildErrorCode = 134
32
32
  ctk_dockerfile_path = "Dockerfile"
33
- ctk_build_ref = "v1.45.1"
33
+ ctk_build_ref = "v1.48.0"
34
34
  ctk_docker_image = "xpk-ctk"
35
35
  ctk_container_name = "xpk-ctk-container"
36
36
  gcloud_cfg_mount_path = "/root/.config/gcloud"
xpk/core/docker_resources.py CHANGED
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
17
+ from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE, B200_DEVICE_TYPE
18
18
  from .cluster import setup_k8s_env
19
19
  from .storage import GCS_FUSE_TYPE, GCP_FILESTORE_TYPE, Storage, get_storages_to_mount
20
20
  from .system_characteristics import AcceleratorType, SystemCharacteristics
@@ -64,22 +64,6 @@ def get_env_container(args, system: SystemCharacteristics) -> str:
64
64
  str:
65
65
  YAML with the env config for the main container, as a YAML string.
66
66
  """
67
- pw_env_yaml = """
68
- - name: XCLOUD_ENVIRONMENT
69
- value: GCP
70
- - name: JAX_PLATFORMS
71
- value: proxy
72
- - name: JAX_BACKEND_TARGET
73
- value: {proxy_address}
74
- - name: JOBSET_NAME
75
- valueFrom:
76
- fieldRef:
77
- fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']"""
78
- if args.use_pathways:
79
- return pw_env_yaml.format(
80
- args=args, proxy_address=args.pathways_proxy_address
81
- )
82
-
83
67
  gpu_env_yaml = """
84
68
  - name: REPLICATED_JOB_NAME
85
69
  valueFrom:
@@ -182,11 +166,14 @@ def get_volumes(args, system: SystemCharacteristics) -> str:
182
166
  name: dshm-2
183
167
  """
184
168
 
185
- if args.ramdisk_directory != '':
186
- volumes += """
169
+ if hasattr(args, 'ramdisk_directory') and args.ramdisk_directory != '':
170
+ driver = 'phase1-checkpoint.csi.storage.gke.io'
171
+ if hasattr(args, 'mtc_enabled') and args.mtc_enabled:
172
+ driver = 'multitier-checkpoint.csi.storage.gke.io'
173
+ volumes += f"""
187
174
  - name: cache
188
175
  csi:
189
- driver: phase1-checkpoint.csi.storage.gke.io"""
176
+ driver: {driver}"""
190
177
 
191
178
  if (
192
179
  system.accelerator_type == AcceleratorType['TPU']
@@ -229,7 +216,7 @@ def get_volume_mounts(args, system: SystemCharacteristics) -> str:
229
216
  name: dshm-2
230
217
  """
231
218
 
232
- if args.ramdisk_directory != '':
219
+ if hasattr(args, 'ramdisk_directory') and args.ramdisk_directory != '':
233
220
  volume_mount_yaml += f"""
234
221
  - mountPath: /{args.ramdisk_directory}
235
222
  name: cache"""
@@ -262,6 +249,7 @@ def get_volume_mounts(args, system: SystemCharacteristics) -> str:
262
249
  elif (
263
250
  system.device_type == H100_MEGA_DEVICE_TYPE
264
251
  or system.device_type == H200_DEVICE_TYPE
252
+ or system.device_type == B200_DEVICE_TYPE
265
253
  ):
266
254
  volume_mount_yaml = ''
267
255
 
xpk/core/filestore.py CHANGED
@@ -230,7 +230,11 @@ class FilestoreClient:
230
230
  return data
231
231
 
232
232
  def manifest(
233
- self, name: str, vol: str, access_mode: str, network: str
233
+ self,
234
+ name: str,
235
+ vol: str,
236
+ access_mode: str,
237
+ network: str,
234
238
  ) -> list[dict]:
235
239
  self.load_instance()
236
240
  pv = self.create_pv(name, vol, access_mode)
xpk/core/gcsfuse.py CHANGED
@@ -20,11 +20,22 @@ FUSE_PV_PATH = "/../templates/fuse-pv.yaml"
20
20
  FUSE_PVC_PATH = "/../templates/fuse-pvc.yaml"
21
21
 
22
22
 
23
- def create_pv(name: str, size: int, bucket: str) -> dict:
23
def create_pv(
    name: str,
    size: int,
    bucket: str,
    mount_options: str,
    prefetch_metadata: bool,
) -> dict:
  """Build the GCS FUSE PersistentVolume manifest from the PV template.

  Args:
    name: base name of the volume; the PV is named "<name>-pv".
    size: storage capacity, written to the manifest as "<size>Gi".
    bucket: name of the storage bucket, used as the CSI volume handle.
    mount_options: comma-separated list of mountOptions for the
      PersistentVolume.
    prefetch_metadata: if set, sets the gcsfuseMetadataPrefetchOnMount volume
      attribute to "true".

  Returns:
    The populated PersistentVolume manifest.
  """
  data = templates.load(FUSE_PV_PATH)
  data["metadata"]["name"] = f"{name}-pv"
  data["spec"]["capacity"]["storage"] = f"{size}Gi"
  data["spec"]["csi"]["volumeHandle"] = bucket
  if prefetch_metadata:
    data["spec"]["csi"]["volumeAttributes"][
        "gcsfuseMetadataPrefetchOnMount"
    ] = "true"
  # Strip whitespace and drop blank entries so inputs such as "" or "a, b,"
  # do not produce empty or padded mount options in the manifest.
  data["spec"]["mountOptions"] = [
      option.strip() for option in mount_options.split(",") if option.strip()
  ]
  return data
29
40
 
30
41
 
@@ -36,15 +47,25 @@ def create_pvc(name: str, size: int) -> dict:
36
47
  return data
37
48
 
38
49
 
39
- def manifest(name: str, bucket: str, size: int) -> list[dict]:
40
- """Creates GCS FUSE manifest file.
50
def manifest(
    name: str,
    bucket: str,
    size: int,
    mount_options: str,
    prefetch_metadata: bool,
) -> list[dict]:
  """Creates the GCS FUSE storage manifests (PersistentVolume and claim).

  Args:
    name (str): base name of the volumes
    bucket (str): name of the storage bucket
    size (int): size of the storage (in GB)
    mount_options (str): comma-separated list of mountOptions for PersistentVolume
    prefetch_metadata (bool): if set, then enables metadata pre-population when mounting the volume

  Returns:
    list[dict]: list of manifests, the PV first and the PVC second
  """
  return [
      create_pv(name, size, bucket, mount_options, prefetch_metadata),
      create_pvc(name, size),
  ]
xpk/core/kjob.py CHANGED
@@ -14,27 +14,50 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from ..core.blueprint.blueprint_generator import get_subnetworks_for_a3mega, get_subnetworks_for_a3ultra
18
- from ..core.capacity import H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
19
17
  from argparse import Namespace
20
- import yaml
21
- from .workload_decorators.tcpxo_decorator import get_tcpxo_deamon_entry
22
- from ..utils.console import xpk_print, xpk_exit
18
+ from enum import Enum
23
19
 
24
- from ..utils import templates
20
+ import yaml
25
21
  from kubernetes import client as k8s_client
26
22
  from kubernetes.client import ApiClient
27
23
  from kubernetes.client.rest import ApiException
28
- from .cluster import setup_k8s_env, XPK_SA, DEFAULT_NAMESPACE
29
- from .storage import get_auto_mount_storages, get_auto_mount_gcsfuse_storages
30
- from .commands import run_command_for_value, run_kubectl_apply, run_command_with_updates
31
- from .config import XpkConfig, KJOB_SHELL_IMAGE, KJOB_SHELL_INTERACTIVE_COMMAND, KJOB_SHELL_WORKING_DIRECTORY, KJOB_BATCH_IMAGE, KJOB_BATCH_WORKING_DIRECTORY
32
- from .resources import get_cluster_system_characteristics, SystemCharacteristics, AcceleratorType
33
- from enum import Enum
34
24
 
35
- from ..core.workload_decorators import tcpxo_decorator
36
-
37
- from ..core.workload_decorators import rdma_decorator
25
+ from ..utils import templates
26
+ from ..utils.console import xpk_exit, xpk_print
27
+ from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
28
+ from .cluster import DEFAULT_NAMESPACE, XPK_SA, setup_k8s_env
29
+ from .commands import (
30
+ run_command_for_value,
31
+ run_command_with_updates,
32
+ run_kubectl_apply,
33
+ )
34
+ from .config import (
35
+ KJOB_BATCH_IMAGE,
36
+ KJOB_BATCH_WORKING_DIRECTORY,
37
+ KJOB_SHELL_IMAGE,
38
+ KJOB_SHELL_INTERACTIVE_COMMAND,
39
+ KJOB_SHELL_WORKING_DIRECTORY,
40
+ XpkConfig,
41
+ )
42
+ from .network import get_cluster_subnetworks
43
+ from .resources import (
44
+ AcceleratorType,
45
+ SystemCharacteristics,
46
+ get_cluster_system_characteristics,
47
+ )
48
+ from .storage import (
49
+ GCS_FUSE_ANNOTATIONS,
50
+ PARALLELSTORE_ANNOTATIONS,
51
+ get_auto_mount_gcsfuse_storages,
52
+ get_auto_mount_parallelstore_storages,
53
+ get_auto_mount_storages,
54
+ )
55
+ from .workload_decorators import (
56
+ rdma_decorator,
57
+ tcpx_decorator,
58
+ tcpxo_decorator,
59
+ )
60
+ from .workload_decorators.tcpxo_decorator import get_tcpxo_deamon_entry
38
61
 
39
62
  KJOB_API_GROUP_NAME = "kjobctl.x-k8s.io"
40
63
  KJOB_API_GROUP_VERSION = "v1alpha1"
@@ -146,8 +169,20 @@ Kueue_TAS_annotation = "kueue.x-k8s.io/podset-preferred-topology=cloud.google.co
146
169
  default_interface_annotation = "networking.gke.io/default-interface=eth0"
147
170
 
148
171
 
172
def get_a4_pod_template_annotations(args) -> tuple[str, str]:
  """Build the Pod template annotations for A4 workloads.

  Args:
    args: user provided arguments, used to look up the cluster subnetworks.

  Returns:
    A pair of annotation strings: the default-interface annotation and the
    interfaces annotation built from the cluster's subnetworks.
  """
  interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
      get_cluster_subnetworks(args)
  )
  interfaces_annotation = f"{interfaces_key}=$'{interfaces_value}'"
  return default_interface_annotation, interfaces_annotation
182
+
183
+
149
184
  def get_a3ultra_pod_template_annotations(args: Namespace) -> tuple[str, str]:
150
- sub_networks = get_subnetworks_for_a3ultra(args.cluster)
185
+ sub_networks = get_cluster_subnetworks(args)
151
186
  interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
152
187
  sub_networks
153
188
  )
@@ -162,7 +197,7 @@ def get_a3mega_pod_template_annotations(
162
197
  args: Namespace,
163
198
  ) -> tuple[str, str, str]:
164
199
  """Adds or updates annotations in the Pod template."""
165
- sub_networks = get_subnetworks_for_a3mega(args.cluster)
200
+ sub_networks = get_cluster_subnetworks(args)
166
201
  tcpxo_deamon_key, tcpxo_deamon_paths = get_tcpxo_deamon_entry()
167
202
  interfaces_key, interfaces_value = tcpxo_decorator.get_interfaces_entry(
168
203
  sub_networks
@@ -237,6 +272,8 @@ def create_app_profile_instance(
237
272
 
238
273
  def decorate_job_template_with_gpu(yml_string: str, gpu_type: str) -> str:
239
274
  job_spec = yaml.safe_load(yml_string)["template"]
275
+ if gpu_type == H100_DEVICE_TYPE:
276
+ job_spec = tcpx_decorator.decorate_kjob_template(job_spec)
240
277
  if gpu_type == H100_MEGA_DEVICE_TYPE:
241
278
  job_spec = tcpxo_decorator.decorate_kjob_template(job_spec)
242
279
  if gpu_type == H200_DEVICE_TYPE:
@@ -436,9 +473,18 @@ def create_volume_bundle_instance(
436
473
  xpk_exit(1)
437
474
 
438
475
 
439
- def get_gcsfuse_annotation(args: Namespace) -> str | None:
476
def get_storage_annotations(args: Namespace) -> list[str]:
  """Collect the annotations required by auto-mounted storages.

  Args:
    args: user provided arguments, used to set up the Kubernetes client.

  Returns:
    A list of "key=value" annotation strings: the GCS FUSE annotations if any
    auto-mount GCS FUSE storage exists, followed by the Parallelstore
    annotations if any auto-mount Parallelstore storage exists.
  """
  k8s_api_client = setup_k8s_env(args)
  annotations = []

  if len(get_auto_mount_gcsfuse_storages(k8s_api_client)) > 0:
    annotations.extend(
        f"{key}={value}" for key, value in GCS_FUSE_ANNOTATIONS.items()
    )

  if len(get_auto_mount_parallelstore_storages(k8s_api_client)) > 0:
    annotations.extend(
        f"{key}={value}" for key, value in PARALLELSTORE_ANNOTATIONS.items()
    )

  return annotations
xpk/core/kueue.py CHANGED
@@ -21,6 +21,7 @@ from packaging.version import Version
21
21
 
22
22
  from ..utils.console import xpk_exit, xpk_print
23
23
  from ..utils.file import write_tmp_file
24
+ from .capacity import B200_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
24
25
  from .commands import (
25
26
  run_command_for_value,
26
27
  run_command_with_updates,
@@ -45,6 +46,19 @@ WAIT_FOR_KUEUE_TIMEOUT = '5m'
45
46
 
46
47
  packaging.version.VERSION_PATTERN = r'^v\d+\.\d+\.\d+$'
47
48
 
49
+ topology_yaml = """apiVersion: kueue.x-k8s.io/v1alpha1
50
+ kind: Topology
51
+ metadata:
52
+ name: "gke-default"
53
+ spec:
54
+ levels:
55
+ - nodeLabel: "cloud.google.com/gce-topology-block"
56
+ - nodeLabel: "cloud.google.com/gce-topology-subblock"
57
+ - nodeLabel: "cloud.google.com/gce-topology-host"
58
+ - nodeLabel: "kubernetes.io/hostname"
59
+ ---
60
+ """
61
+
48
62
  cluster_set_crd_yaml = """apiVersion: kueue.x-k8s.io/v1beta1
49
63
  kind: ResourceFlavor
50
64
  metadata:
@@ -53,6 +67,7 @@ spec:
53
67
  nodeLabels:
54
68
  {accelerator_label}
55
69
  {machine_label}
70
+ {topology_label}
56
71
  ---
57
72
  {pw_resource_flavors}
58
73
  apiVersion: kueue.x-k8s.io/v1beta1
@@ -300,6 +315,14 @@ def install_kueue_crs(
300
315
  resource_type=resource_type,
301
316
  total_chips=total_chips,
302
317
  )
318
+ topology_label = ''
319
+ if system.device_type in [
320
+ H100_MEGA_DEVICE_TYPE,
321
+ H200_DEVICE_TYPE,
322
+ B200_DEVICE_TYPE,
323
+ ]:
324
+ topology_label = 'topologyName: "gke-default"'
325
+
303
326
  yml_string = cluster_set_crd_yaml.format(
304
327
  system=system,
305
328
  cluster_hardware_name=cluster_hardware_name,
@@ -309,6 +332,7 @@ def install_kueue_crs(
309
332
  machine_label=create_machine_label(
310
333
  system.accelerator_type, system, autoprovisioning_enabled
311
334
  ),
335
+ topology_label=topology_label,
312
336
  covered_resources_config=covered_resources_config,
313
337
  resource_type=AcceleratorTypeToAcceleratorCharacteristics[
314
338
  system.accelerator_type
@@ -318,6 +342,12 @@ def install_kueue_crs(
318
342
  cluster_queue_name=CLUSTER_QUEUE_NAME,
319
343
  local_queue_name=LOCAL_QUEUE_NAME,
320
344
  )
345
+ if system.device_type in [
346
+ H100_MEGA_DEVICE_TYPE,
347
+ H200_DEVICE_TYPE,
348
+ B200_DEVICE_TYPE,
349
+ ]:
350
+ yml_string = topology_yaml + yml_string
321
351
 
322
352
  tmp = write_tmp_file(yml_string)
323
353
  command = f'kubectl apply -f {str(tmp.file.name)}'