xpk 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -52,20 +52,6 @@ cluster_toolkit_url = "github.com/GoogleCloudPlatform/cluster-toolkit"
52
52
  cluster_toolkit_version = "v1.48.0"
53
53
 
54
54
 
55
- def get_subnetworks_for_a3mega(cluster_name: str) -> list[str]:
56
- return [f"{cluster_name}-gpunet-{i}-subnet" for i in range(8)]
57
-
58
-
59
- def get_subnetworks_for_a3ultra(cluster_name: str) -> list[str]:
60
- return [f"{cluster_name}-sub-1"] + [
61
- f"{cluster_name}-rdma-sub-{i}" for i in range(8)
62
- ]
63
-
64
-
65
- def get_subnetworks_for_a4() -> list[str]:
66
- return ["gvnic-1"] + [f"rdma-{i}" for i in range(8)]
67
-
68
-
69
55
  class BlueprintGeneratorOutput:
70
56
  """BlueprintGeneratorOutput is a class containing fields with output blueprint file path and path to blueprint dependencies.
71
57
  Atributes:
@@ -194,13 +180,17 @@ class BlueprintGenerator:
194
180
  a3_megagpu_pool_0 = DeploymentModule(
195
181
  id="a3_megagpu_pool_0",
196
182
  source="modules/compute/gke-node-pool",
197
- use=["gke_cluster", gpu_subnets_name, "group_placement_0"],
183
+ use=["gke_cluster", gpu_subnets_name],
198
184
  settings={
199
185
  "name": f"{cluster_name}-a3-megagpu-pool-0",
200
186
  "machine_type": system.gce_machine_type,
201
187
  "static_node_count": num_nodes,
202
188
  "zones": [zone],
203
- "host_maintenance_interval": "PERIODIC",
189
+ "host_maintenance_interval": (
190
+ None
191
+ if capacity_type == CapacityType.RESERVATION
192
+ else "PERIODIC"
193
+ ),
204
194
  "reservation_affinity": self._getblock_reservation_affinity(
205
195
  reservation
206
196
  ),
@@ -211,6 +201,9 @@ class BlueprintGenerator:
211
201
  },
212
202
  outputs=["instructions"],
213
203
  )
204
+
205
+ set_placement_policy = capacity_type != CapacityType.SPOT
206
+ tas_name = "topologyName: 'gke-default'" if set_placement_policy else ""
214
207
  num_chips = num_nodes * system.chips_per_vm
215
208
  workload = DeploymentModule(
216
209
  id="workload_component_install",
@@ -221,7 +214,10 @@ class BlueprintGenerator:
221
214
  "install": True,
222
215
  "version": "v0.10.0", # TAS feature-gates is enabled in CT
223
216
  "config_path": f'$(ghpc_stage("{blueprint_name}"))/kueue-xpk-configuration.yaml.tftpl',
224
- "config_template_vars": {"num_chips": num_chips},
217
+ "config_template_vars": {
218
+ "num_chips": num_chips,
219
+ "tas_name": tas_name,
220
+ },
225
221
  },
226
222
  "jobset": {"install": True, "version": "v0.7.2"},
227
223
  "apply_manifests": [{
@@ -257,12 +253,16 @@ class BlueprintGenerator:
257
253
  primary_vpc,
258
254
  gpunets,
259
255
  gke_cluster,
260
- group_placement_0,
261
256
  a3_megagpu_pool_0,
262
257
  workload,
263
258
  workload_configmap,
264
259
  ],
265
260
  )
261
+
262
+ if set_placement_policy:
263
+ a3_megagpu_pool_0.use.append(group_placement_0.id)
264
+ primary_group.modules.append(group_placement_0)
265
+
266
266
  a3_mega_blueprint = Blueprint(
267
267
  terraform_backend_defaults=self._getblock_terraform_backend(
268
268
  gcs_bucket, cluster_name, prefix
xpk/core/cluster.py CHANGED
@@ -14,28 +14,37 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import yaml
17
18
  from google.api_core.exceptions import PermissionDenied
18
19
  from google.cloud import resourcemanager_v3
19
20
  from kubernetes import client as k8s_client
20
21
  from kubernetes import config
21
22
  from kubernetes.client.exceptions import ApiException
22
- from .resources import get_cluster_system_characteristics
23
23
 
24
24
  from ..utils.console import xpk_exit, xpk_print
25
- from .capacity import H100_DEVICE_TYPE
25
+ from .capacity import B200_DEVICE_TYPE, H100_DEVICE_TYPE, H200_DEVICE_TYPE
26
26
  from .commands import (
27
27
  run_command_for_value,
28
28
  run_command_with_updates,
29
29
  run_command_with_updates_retry,
30
30
  )
31
- from .gcloud_context import add_zone_and_project, get_gke_server_config, zone_to_region
31
+ from .gcloud_context import (
32
+ add_zone_and_project,
33
+ get_gke_server_config,
34
+ zone_to_region,
35
+ )
32
36
  from .nodepool import upgrade_gke_nodepools_version
37
+ from .resources import get_cluster_system_characteristics
33
38
  from .system_characteristics import SystemCharacteristics
34
39
 
35
40
  JOBSET_VERSION = 'v0.8.0'
36
- PATHWAYS_JOB_VERSION = 'v0.1.0'
37
- INSTALLER_NCC_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-tcpx-installer.yaml'
38
- INSTALLER_NCC_TCPXO = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpxo/nccl-tcpxo-installer.yaml'
41
+ PATHWAYS_JOB_VERSION = 'v0.1.1'
42
+ INSTALLER_NCCL_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-tcpx-installer.yaml'
43
+ INSTALLER_NCCL_TCPXO = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpxo/nccl-tcpxo-installer.yaml'
44
+ INSTALLER_NCCL_RDMA = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-rdma/nccl-rdma-installer.yaml'
45
+ CONFIG_NCCL_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-config.yaml'
46
+ NRI_DEVICE_INJECTOR = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nri_device_injector/nri-device-injector.yaml'
47
+ MGLRU_DISABLE = 'https://raw.githubusercontent.com/GoogleCloudPlatform/cluster-toolkit/refs/heads/main/examples/gke-a3-ultragpu/mglru-disable.yaml'
39
48
 
40
49
  DEFAULT_NAMESPACE = 'default'
41
50
  XPK_SA = 'xpk-sa'
@@ -112,9 +121,11 @@ def install_nccl_on_cluster(args, system: SystemCharacteristics) -> int:
112
121
  0 if successful and 1 otherwise.
113
122
  """
114
123
  if system.device_type == H100_DEVICE_TYPE:
115
- command = f'kubectl apply -f {INSTALLER_NCC_TCPX}'
124
+ command = f'kubectl apply -f {INSTALLER_NCCL_TCPX}'
125
+ elif system.device_type in [H200_DEVICE_TYPE, B200_DEVICE_TYPE]:
126
+ command = f'kubectl apply -f {INSTALLER_NCCL_RDMA}'
116
127
  else:
117
- command = f'kubectl apply -f {INSTALLER_NCC_TCPXO}'
128
+ command = f'kubectl apply -f {INSTALLER_NCCL_TCPXO}'
118
129
 
119
130
  return_code = run_command_with_updates(
120
131
  command, 'Install NCCL Plugin On Cluster', args
@@ -126,9 +137,108 @@ def install_nccl_on_cluster(args, system: SystemCharacteristics) -> int:
126
137
  )
127
138
  return 1
128
139
 
140
+ if system.device_type == H100_DEVICE_TYPE:
141
+ command = f'kubectl apply -f {CONFIG_NCCL_TCPX}'
142
+
143
+ return_code = run_command_with_updates(
144
+ command, 'Install NCCL Config On Cluster', args
145
+ )
146
+
147
+ if return_code != 0:
148
+ xpk_print(
149
+ f'Install NCCL Config On Cluster request returned ERROR {return_code}'
150
+ )
151
+ return 1
152
+
153
+ return 0
154
+
155
+
156
def disable_mglru_on_cluster(args) -> int:
  """Disable MGLRU on the cluster.

  Applies the manifest referenced by MGLRU_DISABLE with `kubectl apply`.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  command = f'kubectl apply -f {MGLRU_DISABLE}'
  return_code = run_command_with_updates(
      command, 'Disable MGLRU On Cluster', args
  )

  if return_code != 0:
    # Include the return code so the failure can be diagnosed, matching the
    # message format used by the other kubectl-apply helpers in this module.
    xpk_print(
        f'Disable MGLRU On Cluster request returned ERROR {return_code}'
    )
    return 1

  return 0
130
175
 
131
176
 
177
def install_nri_on_cluster(args) -> int:
  """Install the NRI Device Injector on the cluster.

  Applies the manifest referenced by NRI_DEVICE_INJECTOR with
  `kubectl apply`.

  Args:
    args: user provided arguments for running the command.

  Returns:
    0 if successful and 1 otherwise.
  """
  task = 'Install NRI Device Injector On Cluster'
  exit_status = run_command_with_updates(
      f'kubectl apply -f {NRI_DEVICE_INJECTOR}', task, args
  )
  if exit_status == 0:
    return 0

  # Any non-zero status from kubectl is reported and collapsed to 1.
  xpk_print(f'{task} request returned ERROR {exit_status}')
  return 1
200
+
201
+
202
def get_cluster_nodes_info(args) -> list[dict]:
  """Get the description of every node in the cluster.

  Runs `kubectl get nodes -o yaml` and parses its output as YAML.

  Args:
    args: user provided arguments for running the command.

  Returns:
    List of node objects taken from the `items` key of the parsed output;
    an empty list when the output is empty or has no `items`.
  """
  xpk_print("Getting cluster's info...")
  command = 'kubectl get nodes -o yaml'
  err_code, val = run_command_for_value(
      command=command,
      task='Get cluster nodes info',
      global_args=args,
  )
  if err_code != 0:
    xpk_exit(err_code)

  data = yaml.safe_load(val)
  # safe_load yields None for empty input and may yield a non-mapping for
  # unexpected output; treat both as "no nodes" instead of raising.
  if not isinstance(data, dict):
    return []
  return data.get('items') or []  # pytype: disable=bad-return-type
222
+
223
+
224
def count_nodes_on_cluster(args, system: SystemCharacteristics) -> int:
  """Count cluster nodes whose accelerator label matches `system`.

  Only H200_DEVICE_TYPE is supported; for any other device type an error is
  printed and xpk_exit(1) is called.

  Args:
    args: user provided arguments for running the command.
    system: system characteristics; `device_type` and `gke_accelerator` are
      read.

  Returns:
    Number of nodes whose `cloud.google.com/gke-accelerator` label equals
    `system.gke_accelerator`.
  """
  # Validate before querying the cluster so an unsupported device type does
  # not cost a kubectl round-trip.
  if system.device_type != H200_DEVICE_TYPE:
    xpk_print(
        'Automatic node detection is not supported for device type:'
        f' {system.device_type}'
    )
    xpk_exit(1)

  accelerator_label = 'cloud.google.com/gke-accelerator'
  num_nodes = 0
  for node in get_cluster_nodes_info(args):
    # Tolerate nodes without metadata or labels rather than raising KeyError.
    labels = (node.get('metadata') or {}).get('labels') or {}
    if (
        accelerator_label in labels
        and labels[accelerator_label] == system.gke_accelerator
    ):
      num_nodes += 1
  return num_nodes
240
+
241
+
132
242
  def get_cluster_network(args) -> str:
133
243
  xpk_print("Getting cluster's VPC network...")
134
244
  cluster_network_cmd = (
@@ -621,6 +731,7 @@ def get_cluster_credentials(args) -> None:
621
731
  command = (
622
732
  'gcloud container clusters get-credentials'
623
733
  f' {args.cluster} --region={zone_to_region(args.zone)}'
734
+ ' --dns-endpoint'
624
735
  f' --project={args.project} &&'
625
736
  ' kubectl config view && kubectl config set-context --current'
626
737
  ' --namespace=default'
xpk/core/config.py CHANGED
@@ -24,7 +24,7 @@ from ..utils.console import xpk_print
24
24
  from .system_characteristics import AcceleratorType, SystemCharacteristics
25
25
 
26
26
  # This is the version for XPK PyPI package
27
- __version__ = 'v0.8.0'
27
+ __version__ = 'v0.9.0'
28
28
  XPK_CURRENT_VERSION = __version__
29
29
  XPK_CONFIG_FILE = os.path.expanduser('~/.config/xpk/config.yaml')
30
30
 
xpk/core/filestore.py CHANGED
@@ -200,9 +200,7 @@ class FilestoreClient:
200
200
  ] = f"projects/{self.project}/global/networks/{network}"
201
201
  return data
202
202
 
203
- def create_pv(
204
- self, name: str, vol: str, access_mode: str, mount_options: str
205
- ) -> dict:
203
+ def create_pv(self, name: str, vol: str, access_mode: str) -> dict:
206
204
  """Create a yaml representing filestore PersistentVolume."""
207
205
  data = templates.load(FS_PV_PATH)
208
206
  data["metadata"]["name"] = get_pv_name(name)
@@ -217,7 +215,6 @@ class FilestoreClient:
217
215
  0
218
216
  ].ip_addresses[0]
219
217
  data["spec"]["csi"]["volumeAttributes"]["volume"] = vol
220
- data["spec"]["mountOptions"] = mount_options.split(",")
221
218
  return data
222
219
 
223
220
  def create_pvc(self, name: str, access_mode: str) -> dict:
@@ -238,10 +235,9 @@ class FilestoreClient:
238
235
  vol: str,
239
236
  access_mode: str,
240
237
  network: str,
241
- mount_options: str,
242
238
  ) -> list[dict]:
243
239
  self.load_instance()
244
- pv = self.create_pv(name, vol, access_mode, mount_options)
240
+ pv = self.create_pv(name, vol, access_mode)
245
241
  pvc = self.create_pvc(name, access_mode)
246
242
  sc = self.create_sc(name, network)
247
243
  return [pv, pvc, sc]
xpk/core/gcsfuse.py CHANGED
@@ -20,11 +20,21 @@ FUSE_PV_PATH = "/../templates/fuse-pv.yaml"
20
20
  FUSE_PVC_PATH = "/../templates/fuse-pvc.yaml"
21
21
 
22
22
 
23
- def create_pv(name: str, size: int, bucket: str, mount_options: str) -> dict:
23
+ def create_pv(
24
+ name: str,
25
+ size: int,
26
+ bucket: str,
27
+ mount_options: str,
28
+ prefetch_metadata: bool,
29
+ ) -> dict:
24
30
  data = templates.load(FUSE_PV_PATH)
25
31
  data["metadata"]["name"] = f"{name}-pv"
26
32
  data["spec"]["capacity"]["storage"] = f"{size}Gi"
27
33
  data["spec"]["csi"]["volumeHandle"] = bucket
34
+ if prefetch_metadata:
35
+ data["spec"]["csi"]["volumeAttributes"][
36
+ "gcsfuseMetadataPrefetchOnMount"
37
+ ] = "true"
28
38
  data["spec"]["mountOptions"] = mount_options.split(",")
29
39
  return data
30
40
 
@@ -38,16 +48,24 @@ def create_pvc(name: str, size: int) -> dict:
38
48
 
39
49
 
40
50
  def manifest(
41
- name: str, bucket: str, size: int, mount_options: str
51
+ name: str,
52
+ bucket: str,
53
+ size: int,
54
+ mount_options: str,
55
+ prefetch_metadata: bool,
42
56
  ) -> list[dict]:
43
- """Creates GCS FUSE manifest file.
57
+ """Creates GCS FUSE storage manifest file.
44
58
 
45
59
  Args:
46
60
  name (str): base name of the volumes
47
61
  bucket (str): name of the storage bucket
48
62
  size (str): size of the storage (in GB)
63
+ prefetch_metadata (bool): if set, then enables metadata pre-population when mounting the volume
49
64
  mount_options (str): comma-separated list of mountOptions for PersistentVolume
65
+
66
+ Returns:
67
+ list[dict]: list of manifests
50
68
  """
51
- pv = create_pv(name, size, bucket, mount_options)
69
+ pv = create_pv(name, size, bucket, mount_options, prefetch_metadata)
52
70
  pvc = create_pvc(name, size)
53
71
  return [pv, pvc]
xpk/core/kjob.py CHANGED
@@ -22,16 +22,9 @@ from kubernetes import client as k8s_client
22
22
  from kubernetes.client import ApiClient
23
23
  from kubernetes.client.rest import ApiException
24
24
 
25
- from ..core.blueprint.blueprint_generator import (
26
- get_subnetworks_for_a3mega,
27
- get_subnetworks_for_a3ultra,
28
- get_subnetworks_for_a4,
29
- )
30
- from ..core.capacity import H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
31
- from ..core.storage import GCS_FUSE_ANNOTATIONS, PARALLELSTORE_ANNOTATIONS
32
- from ..core.workload_decorators import rdma_decorator, tcpxo_decorator
33
25
  from ..utils import templates
34
26
  from ..utils.console import xpk_exit, xpk_print
27
+ from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
35
28
  from .cluster import DEFAULT_NAMESPACE, XPK_SA, setup_k8s_env
36
29
  from .commands import (
37
30
  run_command_for_value,
@@ -46,12 +39,24 @@ from .config import (
46
39
  KJOB_SHELL_WORKING_DIRECTORY,
47
40
  XpkConfig,
48
41
  )
42
+ from .network import get_cluster_subnetworks
49
43
  from .resources import (
50
44
  AcceleratorType,
51
45
  SystemCharacteristics,
52
46
  get_cluster_system_characteristics,
53
47
  )
54
- from .storage import get_auto_mount_gcsfuse_storages, get_auto_mount_storages, get_auto_mount_parallelstore_storages
48
+ from .storage import (
49
+ GCS_FUSE_ANNOTATIONS,
50
+ PARALLELSTORE_ANNOTATIONS,
51
+ get_auto_mount_gcsfuse_storages,
52
+ get_auto_mount_parallelstore_storages,
53
+ get_auto_mount_storages,
54
+ )
55
+ from .workload_decorators import (
56
+ rdma_decorator,
57
+ tcpx_decorator,
58
+ tcpxo_decorator,
59
+ )
55
60
  from .workload_decorators.tcpxo_decorator import get_tcpxo_deamon_entry
56
61
 
57
62
  KJOB_API_GROUP_NAME = "kjobctl.x-k8s.io"
@@ -164,8 +169,8 @@ Kueue_TAS_annotation = "kueue.x-k8s.io/podset-preferred-topology=cloud.google.co
164
169
  default_interface_annotation = "networking.gke.io/default-interface=eth0"
165
170
 
166
171
 
167
- def get_a4_pod_template_annotations() -> tuple[str, str]:
168
- sub_networks = get_subnetworks_for_a4()
172
+ def get_a4_pod_template_annotations(args) -> tuple[str, str]:
173
+ sub_networks = get_cluster_subnetworks(args)
169
174
  interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
170
175
  sub_networks
171
176
  )
@@ -177,7 +182,7 @@ def get_a4_pod_template_annotations() -> tuple[str, str]:
177
182
 
178
183
 
179
184
  def get_a3ultra_pod_template_annotations(args: Namespace) -> tuple[str, str]:
180
- sub_networks = get_subnetworks_for_a3ultra(args.cluster)
185
+ sub_networks = get_cluster_subnetworks(args)
181
186
  interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
182
187
  sub_networks
183
188
  )
@@ -192,7 +197,7 @@ def get_a3mega_pod_template_annotations(
192
197
  args: Namespace,
193
198
  ) -> tuple[str, str, str]:
194
199
  """Adds or updates annotations in the Pod template."""
195
- sub_networks = get_subnetworks_for_a3mega(args.cluster)
200
+ sub_networks = get_cluster_subnetworks(args)
196
201
  tcpxo_deamon_key, tcpxo_deamon_paths = get_tcpxo_deamon_entry()
197
202
  interfaces_key, interfaces_value = tcpxo_decorator.get_interfaces_entry(
198
203
  sub_networks
@@ -267,6 +272,8 @@ def create_app_profile_instance(
267
272
 
268
273
  def decorate_job_template_with_gpu(yml_string: str, gpu_type: str) -> str:
269
274
  job_spec = yaml.safe_load(yml_string)["template"]
275
+ if gpu_type == H100_DEVICE_TYPE:
276
+ job_spec = tcpx_decorator.decorate_kjob_template(job_spec)
270
277
  if gpu_type == H100_MEGA_DEVICE_TYPE:
271
278
  job_spec = tcpxo_decorator.decorate_kjob_template(job_spec)
272
279
  if gpu_type == H200_DEVICE_TYPE:
xpk/core/kueue.py CHANGED
@@ -21,6 +21,7 @@ from packaging.version import Version
21
21
 
22
22
  from ..utils.console import xpk_exit, xpk_print
23
23
  from ..utils.file import write_tmp_file
24
+ from .capacity import B200_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
24
25
  from .commands import (
25
26
  run_command_for_value,
26
27
  run_command_with_updates,
@@ -45,6 +46,19 @@ WAIT_FOR_KUEUE_TIMEOUT = '5m'
45
46
 
46
47
  packaging.version.VERSION_PATTERN = r'^v\d+\.\d+\.\d+$'
47
48
 
49
+ topology_yaml = """apiVersion: kueue.x-k8s.io/v1alpha1
50
+ kind: Topology
51
+ metadata:
52
+ name: "gke-default"
53
+ spec:
54
+ levels:
55
+ - nodeLabel: "cloud.google.com/gce-topology-block"
56
+ - nodeLabel: "cloud.google.com/gce-topology-subblock"
57
+ - nodeLabel: "cloud.google.com/gce-topology-host"
58
+ - nodeLabel: "kubernetes.io/hostname"
59
+ ---
60
+ """
61
+
48
62
  cluster_set_crd_yaml = """apiVersion: kueue.x-k8s.io/v1beta1
49
63
  kind: ResourceFlavor
50
64
  metadata:
@@ -53,6 +67,7 @@ spec:
53
67
  nodeLabels:
54
68
  {accelerator_label}
55
69
  {machine_label}
70
+ {topology_label}
56
71
  ---
57
72
  {pw_resource_flavors}
58
73
  apiVersion: kueue.x-k8s.io/v1beta1
@@ -300,6 +315,14 @@ def install_kueue_crs(
300
315
  resource_type=resource_type,
301
316
  total_chips=total_chips,
302
317
  )
318
+ topology_label = ''
319
+ if system.device_type in [
320
+ H100_MEGA_DEVICE_TYPE,
321
+ H200_DEVICE_TYPE,
322
+ B200_DEVICE_TYPE,
323
+ ]:
324
+ topology_label = 'topologyName: "gke-default"'
325
+
303
326
  yml_string = cluster_set_crd_yaml.format(
304
327
  system=system,
305
328
  cluster_hardware_name=cluster_hardware_name,
@@ -309,6 +332,7 @@ def install_kueue_crs(
309
332
  machine_label=create_machine_label(
310
333
  system.accelerator_type, system, autoprovisioning_enabled
311
334
  ),
335
+ topology_label=topology_label,
312
336
  covered_resources_config=covered_resources_config,
313
337
  resource_type=AcceleratorTypeToAcceleratorCharacteristics[
314
338
  system.accelerator_type
@@ -318,6 +342,12 @@ def install_kueue_crs(
318
342
  cluster_queue_name=CLUSTER_QUEUE_NAME,
319
343
  local_queue_name=LOCAL_QUEUE_NAME,
320
344
  )
345
+ if system.device_type in [
346
+ H100_MEGA_DEVICE_TYPE,
347
+ H200_DEVICE_TYPE,
348
+ B200_DEVICE_TYPE,
349
+ ]:
350
+ yml_string = topology_yaml + yml_string
321
351
 
322
352
  tmp = write_tmp_file(yml_string)
323
353
  command = f'kubectl apply -f {str(tmp.file.name)}'
xpk/core/mtc.py ADDED
@@ -0,0 +1,195 @@
1
+ """
2
+ Copyright 2024 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import requests
18
+ import yaml
19
+
20
+ from ..core.cluster import JOBSET_VERSION
21
+ from ..core.cluster import setup_k8s_env
22
+ from ..utils import templates
23
+ from ..utils.console import xpk_exit
24
+ from ..utils.console import xpk_print
25
+ from ..utils.kubectl import apply_kubectl_manifest
26
+
27
+
28
+ MTC_CPC_PATH = "/../templates/mtc-cpc.yaml"
29
+
30
+
31
def create_mtc_cpc(
    mtc_gcs_bucket: str,
    mtc_machine_type: str,
    mtc_toleration_key: str,
    mtc_ramdisk_size: str,
) -> dict:
  """Build the MTC Checkpoint Configuration manifest.

  Loads the template at MTC_CPC_PATH and fills in its `spec` fields.

  Args:
    mtc_gcs_bucket: GCS bucket for MTC
    mtc_machine_type: Machine type for MTC
    mtc_toleration_key: Toleration key for MTC
    mtc_ramdisk_size: Ramdisk size for MTC

  Returns:
    MTC Checkpoint Configuration
  """
  cpc = templates.load(MTC_CPC_PATH)
  spec = cpc["spec"]

  spec["cloudStorageBucketName"] = mtc_gcs_bucket
  spec["nodeSelector"]["node.kubernetes.io/instance-type"] = mtc_machine_type
  # Only the first toleration of the template is customised.
  spec["tolerations"][0]["key"] = mtc_toleration_key
  spec["inMemoryVolumeSize"] = mtc_ramdisk_size

  return cpc
58
+
59
+
60
def install_mtc_on_cluster(args, system) -> int:
  """Install MTC on the cluster.

  Applies the jobset manifest returned by update_jobset_manifest() (when
  there is one) and then the MTC Checkpoint Configuration.

  Args:
    args: user provided arguments for running the command. Reads
      `mtc_gcs_bucket`, `mtc_ramdisk_size` and `mtc_toleration_key`. The
      bucket is rewritten in place without a leading "gs://", and a missing
      toleration key is set to "google.com/tpu".
    system: system related information; `gce_machine_type` is read.

  Returns:
    return code of the first failing apply, otherwise the return code of
    applying the MTC Checkpoint Configuration.
  """
  if args.mtc_gcs_bucket is None:
    xpk_print("MTC GCS bucket is required.")
    xpk_exit(1)
  if args.mtc_gcs_bucket.startswith("gs://"):
    # Drop only the leading scheme; str.replace would also remove any later
    # occurrence of "gs://" inside the value.
    args.mtc_gcs_bucket = args.mtc_gcs_bucket[len("gs://") :]

  if args.mtc_ramdisk_size is None:
    xpk_print("MTC ramdisk size is required.")
    xpk_exit(1)

  if args.mtc_toleration_key is None:
    args.mtc_toleration_key = "google.com/tpu"

  k8s_api_client = setup_k8s_env(args)
  jobset_manifest = update_jobset_manifest()
  if jobset_manifest is None:
    # Nothing to change on the jobset controller: skip the apply instead of
    # sending a None manifest to kubectl.
    xpk_print(
        "Updated jobset manifest is empty, not updating the jobset controller."
    )
  else:
    xpk_print("Applying Jobset with MTC Configuration")
    return_code = apply_kubectl_manifest(k8s_api_client, [jobset_manifest])
    if return_code != 0:
      return return_code

  mtc_checkpoint_configuration_crd_data = create_mtc_cpc(
      args.mtc_gcs_bucket,
      system.gce_machine_type,
      args.mtc_toleration_key,
      args.mtc_ramdisk_size,
  )
  xpk_print("Applying MTC Checkpoint Configuration")
  return_code = apply_kubectl_manifest(
      k8s_api_client, [mtc_checkpoint_configuration_crd_data]
  )

  return return_code
107
+
108
+
109
def update_jobset_manifest():
  """Update the jobset manifest to increase the resources for the jobset controller manager.

  Downloads the jobset release manifest for JOBSET_VERSION, finds the
  `jobset-controller-manager` Deployment and raises the `manager`
  container's CPU request, memory request and memory limit to at least
  1000m / 1Gi / 2Gi.

  Returns:
    The updated Deployment document when at least one value was raised,
    otherwise None.
  """
  manifest_url = f"https://github.com/kubernetes-sigs/jobset/releases/download/{JOBSET_VERSION}/manifests.yaml"
  manifest_content = None
  # Fetch the manifest content
  try:
    response = requests.get(manifest_url, timeout=10)
    response.raise_for_status()  # Raise an exception for HTTP errors
    manifest_content = response.text
  except requests.exceptions.Timeout as e:
    xpk_print(
        f"Error: Request to {manifest_url} timed out after 10 seconds: {e}"
    )
    xpk_exit(1)
  except requests.exceptions.RequestException as e:
    xpk_print(f"Error fetching manifest from {manifest_url}: {e}")
    xpk_exit(1)

  if manifest_content is None:
    xpk_print("Manifest content not found.")
    xpk_exit(1)

  # The release manifest is a multi-document YAML stream; look for the
  # Deployment of jobset-controller-manager.
  for yaml_data in yaml.safe_load_all(manifest_content):
    if not (
        isinstance(yaml_data, dict)
        and yaml_data.get("apiVersion") == "apps/v1"
        and yaml_data.get("kind") == "Deployment"
        and (yaml_data.get("metadata") or {}).get("name")
        == "jobset-controller-manager"
    ):
      continue

    # Found the Deployment, now modify the resources of its manager container.
    update_manifest = False
    containers = yaml_data["spec"]["template"]["spec"]["containers"]
    for container in containers:
      if container.get("name") == "manager":
        update_manifest = _raise_manager_resources(container)
        break
    if update_manifest:
      xpk_print("Jobset controller updation required.")
      return yaml_data

  xpk_print("Jobset controller no updation required.")
  return None


def _raise_manager_resources(container: dict) -> bool:
  """Raise a container's CPU/memory settings to the MTC minimums, in place.

  Missing `resources`, `requests` or `limits` mappings are created as needed
  instead of raising KeyError.

  Returns:
    True when at least one value was changed.
  """
  # (section, resource, value assumed when absent, required minimum)
  minimums = (
      ("requests", "cpu", "0m", "1000m"),
      ("requests", "memory", "0Mi", "1Gi"),
      ("limits", "memory", "0Mi", "2Gi"),
  )
  resources = container.get("resources") or {}
  changed = False
  for section, resource, absent_value, minimum in minimums:
    values = resources.get(section) or {}
    current = values.get(resource, absent_value)
    if parse_resource_value(current) < parse_resource_value(minimum):
      values[resource] = minimum
      resources[section] = values
      changed = True
  if changed:
    container["resources"] = resources
  return changed
186
+
187
+
188
def parse_resource_value(value) -> int:
  """Convert a Kubernetes-style resource quantity into a comparable integer.

  Supported forms:
    "<n>m"  -> n (e.g. millicores)
    "<n>Ki" -> n // 1024 (whole mebibytes, rounded down)
    "<n>Mi" -> n
    "<n>Gi" -> n * 1024
    "<n>Ti" -> n * 1024 * 1024
    "<n>"   -> n, returned unchanged

  Non-string values (YAML parses an unquoted `cpu: 2` as an int) are accepted
  and surrounding whitespace is ignored.

  NOTE(review): a bare number carries no unit, so it is only comparable with
  other bare numbers; a bare CPU count such as "1" is presumably one full
  core rather than 1 millicore - confirm before comparing it to "<n>m".

  Args:
    value: the quantity to parse.

  Returns:
    The integer magnitude as described above.

  Raises:
    ValueError: if the numeric part is not an integer.
  """
  text = str(value).strip()
  if text.endswith("m"):
    return int(text[:-1])
  if text.endswith("Ki"):
    return int(text[:-2]) // 1024
  if text.endswith("Mi"):
    return int(text[:-2])
  if text.endswith("Gi"):
    return int(text[:-2]) * 1024
  if text.endswith("Ti"):
    return int(text[:-2]) * 1024 * 1024
  return int(text)