xpk 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. xpk/api/__init__.py +15 -0
  2. xpk/api/storage_crd.yaml +52 -0
  3. xpk/commands/batch.py +27 -5
  4. xpk/commands/cluster.py +104 -80
  5. xpk/commands/cluster_gcluster.py +94 -10
  6. xpk/commands/common.py +44 -0
  7. xpk/commands/config.py +29 -0
  8. xpk/commands/info.py +8 -10
  9. xpk/commands/inspector.py +5 -11
  10. xpk/commands/job.py +9 -7
  11. xpk/commands/kind.py +34 -4
  12. xpk/commands/kjob_common.py +44 -0
  13. xpk/commands/run.py +128 -0
  14. xpk/commands/shell.py +27 -7
  15. xpk/commands/storage.py +267 -0
  16. xpk/commands/version.py +6 -18
  17. xpk/commands/workload.py +381 -184
  18. xpk/core/blueprint/blueprint_definitions.py +1 -0
  19. xpk/core/blueprint/blueprint_generator.py +132 -76
  20. xpk/core/capacity.py +185 -0
  21. xpk/core/cluster.py +564 -0
  22. xpk/core/cluster_private.py +6 -3
  23. xpk/core/commands.py +18 -14
  24. xpk/core/config.py +179 -0
  25. xpk/core/docker_container.py +225 -0
  26. xpk/core/docker_image.py +210 -0
  27. xpk/core/docker_resources.py +350 -0
  28. xpk/core/filestore.py +251 -0
  29. xpk/core/gcloud_context.py +196 -0
  30. xpk/core/gcluster_manager.py +20 -2
  31. xpk/core/gcsfuse.py +50 -0
  32. xpk/core/kjob.py +257 -18
  33. xpk/core/kueue.py +12 -6
  34. xpk/core/monitoring.py +134 -0
  35. xpk/core/nap.py +32 -20
  36. xpk/core/network.py +377 -0
  37. xpk/core/nodepool.py +581 -0
  38. xpk/core/pathways.py +124 -45
  39. xpk/core/remote_state/__init__.py +15 -0
  40. xpk/core/remote_state/fuse_remote_state.py +99 -0
  41. xpk/core/remote_state/remote_state_client.py +38 -0
  42. xpk/core/resources.py +238 -0
  43. xpk/core/scheduling.py +253 -0
  44. xpk/core/storage.py +581 -0
  45. xpk/core/system_characteristics.py +38 -1
  46. xpk/core/vertex.py +105 -0
  47. xpk/core/workload.py +209 -1
  48. xpk/core/workload_decorators/rdma_decorator.py +25 -5
  49. xpk/core/workload_decorators/storage_decorator.py +52 -0
  50. xpk/core/workload_decorators/tcpxo_decorator.py +70 -37
  51. xpk/main.py +3 -1
  52. xpk/parser/batch.py +10 -151
  53. xpk/parser/cluster.py +49 -8
  54. xpk/parser/common.py +189 -1
  55. xpk/parser/config.py +49 -0
  56. xpk/parser/core.py +27 -1
  57. xpk/parser/info.py +2 -1
  58. xpk/parser/inspector.py +3 -3
  59. xpk/parser/job.py +25 -4
  60. xpk/parser/kind.py +3 -2
  61. xpk/parser/run.py +47 -0
  62. xpk/parser/shell.py +10 -1
  63. xpk/parser/storage.py +316 -0
  64. xpk/parser/validators.py +3 -3
  65. xpk/parser/workload.py +118 -76
  66. xpk/templates/__init__.py +15 -0
  67. xpk/templates/storage.yaml +13 -0
  68. xpk/utils/gcs_utils.py +125 -0
  69. xpk/utils/kubectl.py +57 -0
  70. xpk/utils/objects.py +8 -5
  71. xpk/utils/templates.py +28 -0
  72. xpk/utils/validation.py +80 -0
  73. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/METADATA +165 -14
  74. xpk-0.7.0.dist-info/RECORD +92 -0
  75. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/WHEEL +1 -1
  76. xpk/core/core.py +0 -2824
  77. xpk-0.6.0.dist-info/RECORD +0 -57
  78. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/LICENSE +0 -0
  79. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/entry_points.txt +0 -0
  80. {xpk-0.6.0.dist-info → xpk-0.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,196 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import subprocess
18
+ import sys
19
+ from dataclasses import dataclass
20
+
21
+ from ..utils.console import xpk_print
22
+ from .commands import run_command_for_value
23
+
24
+
25
def get_project():
  """Return the project currently configured in gcloud.

  Runs `gcloud config get project` and treats the last line of its standard
  output as the project name.

  Returns:
    The project name.

  Raises:
    SystemExit: if the last line of the gcloud output is empty, i.e. no
      project is configured.
    subprocess.CalledProcessError: if gcloud exits with a non-zero status
      (the command is run with `check=True`).
  """
  gcloud_result = subprocess.run(
      ['gcloud', 'config', 'get', 'project'], check=True, capture_output=True
  )
  # The project name lives on the last line of the output.
  project = gcloud_result.stdout.decode().strip().split('\n')[-1]
  if not project:
    sys.exit(
        'You must specify the project in the project flag or set it with'
        " 'gcloud config set project <project>'"
    )
  return project
43
+
44
+
45
def get_zone():
  """Return the compute zone currently configured in gcloud.

  Runs `gcloud config get compute/zone` and treats the last line of its
  standard output as the zone name.

  Returns:
    The zone name.

  Raises:
    SystemExit: if the last line of the gcloud output is empty, i.e. no zone
      is configured.
    subprocess.CalledProcessError: if gcloud exits with a non-zero status
      (the command is run with `check=True`).
  """
  gcloud_result = subprocess.run(
      ['gcloud', 'config', 'get', 'compute/zone'],
      check=True,
      capture_output=True,
  )
  # The zone name lives on the last line of the output.
  zone = gcloud_result.stdout.decode().strip().split('\n')[-1]
  if not zone:
    sys.exit(
        "You must specify the zone in the zone flag or set it with 'gcloud"
        " config set compute/zone <zone>'"
    )
  return zone
63
+
64
+
65
def add_zone_and_project(args):
  """Fill in `args.project` and `args.zone` from gcloud when they are unset.

  A falsy `args.project` is replaced by `get_project()` and a falsy
  `args.zone` by `get_zone()`; the resolved pair is then printed.

  Args:
    args: user provided arguments for running the command. Modified in place.
  """
  # Project is resolved before zone, matching the lookup order callers see.
  for field, resolve in (('project', get_project), ('zone', get_zone)):
    if not getattr(args, field):
      setattr(args, field, resolve())
  xpk_print(f'Working on {args.project} and {args.zone}')
76
+
77
+
78
def zone_to_region(zone) -> str:
  """Derive the region name from a zone name.

  The region is the first two dash-separated terms of the zone, e.g.
  'us-central1-a' -> 'us-central1'.

  Args:
    zone: zone name.

  Returns:
    The region name.

  Raises:
    IndexError: if `zone` contains no '-' separator.
  """
  terms = zone.split('-')
  return f'{terms[0]}-{terms[1]}'  # pytype: disable=bad-return-type
89
+
90
+
91
@dataclass
class GkeServerConfig:
  """GKE versions reported by gcloud for the RAPID release channel."""

  # Default GKE version of the RAPID channel.
  default_rapid_gke_version: str
  # GKE versions listed as valid for the RAPID channel.
  valid_versions: set[str]
97
+
98
+
99
def get_gke_server_config(args) -> tuple[int, GkeServerConfig | None]:
  """Determine the GKE versions gcloud currently offers on the RAPID channel.

  Runs `gcloud container get-server-config` twice, once for the channel's
  default version and once for its list of valid versions.

  Args:
    args: user provided arguments for running the command. `args.project` and
      `args.zone` are used to build the gcloud command.

  Returns:
    Tuple of
    int: 0 if successful, otherwise the non-zero return code of the gcloud
      command that failed.
    GkeServerConfig: stores valid gke version to use in node pool and cluster,
      or None on failure.
  """
  base_command = (
      'gcloud container get-server-config'
      f' --project={args.project} --region={zone_to_region(args.zone)}'
      ' --flatten="channels" --filter="channels.channel=RAPID"'
  )
  # (gcloud field to print, human readable label), in the order unpacked below.
  queries = (
      ('defaultVersion', 'default rapid gke version'),
      ('validVersions', 'valid versions'),
  )

  command_outputs = []
  for field, label in queries:
    command_description = (
        f'Determine server supported GKE versions for {label}'
    )
    return_code, cmd_output = run_command_for_value(
        f'{base_command} --format="value(channels.{field})"',
        command_description,
        args,
        hide_error=True,
    )
    if return_code != 0:
      xpk_print(f'Unable to get server config for {command_description}.')
      return return_code, None
    command_outputs.append(cmd_output)

  default_version_output, valid_versions_output = command_outputs
  # The versions arrive as one ';'-separated line. Strip every entry so that
  # surrounding whitespace (notably a trailing newline on the last entry)
  # cannot prevent a later membership check from matching, and drop empties.
  valid_versions = {
      version.strip()
      for version in valid_versions_output.split(';')
      if version.strip()
  }
  return 0, GkeServerConfig(
      default_rapid_gke_version=default_version_output.strip(),
      valid_versions=valid_versions,
  )
154
+
155
+
156
def get_gke_control_plane_version(
    args, gke_server_config: GkeServerConfig
) -> tuple[int, str | None]:
  """Determine gke control plane version for cluster creation.

  Uses `args.gke_version` when it is set, otherwise the default RAPID version
  from `gke_server_config`, and checks the choice against the valid versions.

  Args:
    args: user provided arguments for running the command.
    gke_server_config: holds valid gke versions and recommended default version.

  Returns:
    Tuple of
    int: 0 if successful and 1 otherwise.
    str: gke control plane version to use, or None if the planned version is
      not in the list of valid versions.
  """
  # Override with the user provided gke version if specified.
  if args.gke_version is not None:
    master_gke_version = args.gke_version
  else:
    master_gke_version = gke_server_config.default_rapid_gke_version

  is_valid_version = master_gke_version in gke_server_config.valid_versions
  if is_valid_version:
    return 0, master_gke_version

  xpk_print(
      f'Planned GKE Version: {master_gke_version}\n Valid Versions:'
      f'\n{gke_server_config.valid_versions}\nRecommended / Default GKE'
      f' Version: {gke_server_config.default_rapid_gke_version}'
  )
  # The leading space keeps the two sentences from running together.
  xpk_print(
      f'Error: Planned GKE Version {master_gke_version} is not valid.'
      f' Checks failed: Is Version Valid: {is_valid_version}'
  )
  xpk_print(
      'Please select a gke version from the above list using --gke-version=x'
      ' argument or rely on the default gke version:'
      f' {gke_server_config.default_rapid_gke_version}'
  )
  return 1, None
@@ -15,8 +15,8 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  from .docker_manager import CommandRunner
18
- from ..utils.console import xpk_print
19
-
18
+ from ..utils.console import xpk_exit, xpk_print
19
+ from .remote_state.remote_state_client import RemoteStateClient
20
20
 
21
21
  xpk_gcloud_cfg_path = '~/gcloud/cfg'
22
22
  xpk_deployment_dir = '/deployment'
@@ -44,8 +44,10 @@ class GclusterManager:
44
44
  def __init__(
45
45
  self,
46
46
  gcluster_command_runner: CommandRunner,
47
+ remote_state_client: RemoteStateClient | None,
47
48
  ) -> None:
48
49
  self.gcluster_command_runner = gcluster_command_runner
50
+ self.remote_state_client = remote_state_client
49
51
 
50
52
  def _run_create_deployment_cmd(
51
53
  self, blueprint_container_path: str, prefix: str = ''
@@ -156,3 +158,19 @@ class GclusterManager:
156
158
  xpk_print('Staging blueprint completed!')
157
159
  xpk_print(f"File path in gcluster's working directory: {staged_blueprint}")
158
160
  return staged_blueprint
161
+
162
def upload_state(self) -> None:
  """Upload state through the configured remote state client.

  Prints an error and calls `xpk_exit(1)` when no remote state client was
  provided to this manager.
  """
  xpk_print('Uploading state.')
  client = self.remote_state_client
  if client is None:
    xpk_print('No remote state defined')
    xpk_exit(1)
  client.upload_state()
168
+
169
def download_state(self) -> None:
  """Download the remote state if the remote state client reports one.

  Prints an error and calls `xpk_exit(1)` when no remote state client was
  provided to this manager. When the client reports that no remote state
  exists, nothing is downloaded and a notice is printed instead.
  """
  if self.remote_state_client is None:
    xpk_print('No remote state defined')
    xpk_exit(1)

  if self.remote_state_client.check_remote_state_exists():
    self.remote_state_client.download_state()
    # Return here so the "not found" notice is only shown when it is true.
    return
  xpk_print('Remote state not found.')
xpk/core/gcsfuse.py ADDED
@@ -0,0 +1,50 @@
1
+ """
2
+ Copyright 2024 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from ..utils import templates
18
+
19
+ FUSE_PV_PATH = "/../templates/fuse-pv.yaml"
20
+ FUSE_PVC_PATH = "/../templates/fuse-pvc.yaml"
21
+
22
+
23
def create_pv(name: str, size: int, bucket: str) -> dict:
  """Build the FUSE PV manifest from its template.

  Args:
    name: base name; the volume is named "<name>-pv".
    size: capacity, written to the manifest as "<size>Gi".
    bucket: value stored as the CSI volume handle.

  Returns:
    The loaded template with the fields above filled in.
  """
  pv = templates.load(FUSE_PV_PATH)
  pv["metadata"]["name"] = f"{name}-pv"
  pv_spec = pv["spec"]
  pv_spec["capacity"]["storage"] = f"{size}Gi"
  pv_spec["csi"]["volumeHandle"] = bucket
  return pv
29
+
30
+
31
def create_pvc(name: str, size: int) -> dict:
  """Build the FUSE PVC manifest from its template.

  Args:
    name: base name; the claim is named "<name>-pvc" and refers to the
      volume "<name>-pv".
    size: requested storage, written to the manifest as "<size>Gi".

  Returns:
    The loaded template with the fields above filled in.
  """
  pvc = templates.load(FUSE_PVC_PATH)
  pvc["metadata"]["name"] = f"{name}-pvc"
  pvc_spec = pvc["spec"]
  pvc_spec["resources"]["requests"]["storage"] = f"{size}Gi"
  pvc_spec["volumeName"] = f"{name}-pv"
  return pvc
37
+
38
+
39
def manifest(name: str, bucket: str, size: int) -> list[dict]:
  """Creates the GCS FUSE manifest.

  Args:
    name (str): base name of the volumes
    bucket (str): name of the storage bucket
    size (int): size of the storage

  Returns:
    A two-element list: the PV manifest followed by the PVC manifest.
  """
  return [create_pv(name, size, bucket), create_pvc(name, size)]
xpk/core/kjob.py CHANGED
@@ -14,11 +14,33 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ from ..core.blueprint.blueprint_generator import get_subnetworks_for_a3mega, get_subnetworks_for_a3ultra
18
+ from ..core.capacity import H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
17
19
  from argparse import Namespace
18
- from ..utils.console import xpk_print
20
+ import yaml
21
+ from .workload_decorators.tcpxo_decorator import get_tcpxo_deamon_entry
22
+ from ..utils.console import xpk_print, xpk_exit
23
+
24
+ from ..utils import templates
25
+ from kubernetes import client as k8s_client
26
+ from kubernetes.client import ApiClient
27
+ from kubernetes.client.rest import ApiException
28
+ from .cluster import setup_k8s_env, XPK_SA, DEFAULT_NAMESPACE
29
+ from .storage import get_auto_mount_storages, get_auto_mount_gcsfuse_storages
19
30
  from .commands import run_command_for_value, run_kubectl_apply, run_command_with_updates
31
+ from .config import XpkConfig, KJOB_SHELL_IMAGE, KJOB_SHELL_INTERACTIVE_COMMAND, KJOB_SHELL_WORKING_DIRECTORY, KJOB_BATCH_IMAGE, KJOB_BATCH_WORKING_DIRECTORY
32
+ from .resources import get_cluster_system_characteristics, SystemCharacteristics, AcceleratorType
20
33
  from enum import Enum
21
34
 
35
+ from ..core.workload_decorators import tcpxo_decorator
36
+
37
+ from ..core.workload_decorators import rdma_decorator
38
+
39
+ KJOB_API_GROUP_NAME = "kjobctl.x-k8s.io"
40
+ KJOB_API_GROUP_VERSION = "v1alpha1"
41
+ KJOB_API_VOLUME_BUNDLE_PLURAL = "volumebundles"
42
+ VOLUME_BUNDLE_TEMPLATE_PATH = "/../templates/volume_bundle.yaml"
43
+
22
44
 
23
45
  class AppProfileDefaults(Enum):
24
46
  NAME = "xpk-def-app-profile"
@@ -30,12 +52,14 @@ class JobTemplateDefaults(Enum):
30
52
  COMPLETIONS = 1
31
53
  CONTAINER_NAME = "xpk-batch-container"
32
54
  IMAGE = "ubuntu:22.04"
55
+ WORKING_DIRECTORY = "/"
33
56
 
34
57
 
35
58
  class PodTemplateDefaults(Enum):
36
59
  NAME = "xpk-def-pod"
37
60
  CONTAINER_NAME = "xpk-interactive-container"
38
61
  IMAGE = "busybox:1.28"
62
+ WORKING_DIRECTORY = "/"
39
63
  INTERACTIVE_COMMAND = "/bin/sh"
40
64
 
41
65
 
@@ -52,10 +76,29 @@ job_template_yaml = """
52
76
  completionMode: Indexed
53
77
  template:
54
78
  spec:
79
+ dnsPolicy: ClusterFirstWithHostNet
80
+ tolerations:
81
+ - operator: "Exists"
82
+ key: nvidia.com/gpu
55
83
  containers:
56
84
  - name: {container_name}
57
85
  image: {image}
58
- restartPolicy: OnFailure"""
86
+ workingDir: {working_directory}
87
+ {resources}
88
+ {node_selector}
89
+ priorityClassName: {priority}
90
+ restartPolicy: OnFailure
91
+ serviceAccountName: {service_account}
92
+ """
93
+ job_node_selector_template = """
94
+ nodeSelector:
95
+ cloud.google.com/gke-accelerator: {gpu_name}
96
+ """
97
+ job_resources_template = """
98
+ resources:
99
+ limits:
100
+ nvidia.com/gpu: {gpu_per_node}
101
+ """
59
102
 
60
103
  app_profile_yaml = """
61
104
  apiVersion: kjobctl.x-k8s.io/v1alpha1
@@ -70,6 +113,7 @@ spec:
70
113
  requiredFlags: []
71
114
  - name: Interactive
72
115
  template: {interactive_template}
116
+ volumeBundles: {volume_bundles}
73
117
  """
74
118
 
75
119
  pod_template_yaml = """
@@ -80,12 +124,53 @@ metadata:
80
124
  namespace: default
81
125
  template:
82
126
  spec:
127
+ tolerations:
128
+ - effect: NoSchedule
129
+ key: components.gke.io/gke-managed-components
130
+ operator: Equal
131
+ value: "true"
83
132
  containers:
84
133
  - name: {container_name}
85
134
  image: {image}
86
135
  command: [{interactive_command}]
136
+ workingDir: {working_directory}
137
+ initContainers:
138
+ - name: init
139
+ image: {image}
140
+ command: ['/bin/mkdir', '-p', '{working_directory}']
141
+ serviceAccountName: {service_account}
87
142
  """
88
143
 
144
+ Kueue_TAS_annotation = "kueue.x-k8s.io/podset-preferred-topology=cloud.google.com/gce-topology-host"
145
+
146
+ default_interface_annotation = "networking.gke.io/default-interface=eth0"
147
+
148
+
149
def get_a3ultra_pod_template_annotations(args: Namespace) -> tuple[str, str]:
  """Build the pod template annotations for an A3 Ultra cluster.

  Args:
    args: user provided arguments; `args.cluster` is used to look up the
      cluster's subnetworks.

  Returns:
    Tuple of the default interface annotation and the interfaces annotation
    built from the RDMA decorator's interfaces entry, each as `key=value`.
  """
  key, value = rdma_decorator.get_interfaces_entry(
      get_subnetworks_for_a3ultra(args.cluster)
  )
  interfaces_annotation = f"{key}=$'{value}'"
  return default_interface_annotation, interfaces_annotation
159
+
160
+
161
def get_a3mega_pod_template_annotations(
    args: Namespace,
) -> tuple[str, str, str]:
  """Build the pod template annotations for an A3 Mega cluster.

  Args:
    args: user provided arguments; `args.cluster` is used to look up the
      cluster's subnetworks.

  Returns:
    Tuple of the tcpxo daemon annotation, the interfaces annotation and the
    default interface annotation, each as `key=value`.
  """
  subnetworks = get_subnetworks_for_a3mega(args.cluster)
  daemon_key, daemon_paths = get_tcpxo_deamon_entry()
  nic_key, nic_value = tcpxo_decorator.get_interfaces_entry(subnetworks)
  return (
      f"{daemon_key}=$'{daemon_paths}'",
      f"{nic_key}=$'{nic_value}'",
      default_interface_annotation,
  )
173
+
89
174
 
90
175
  def verify_kjob_installed(args: Namespace) -> int:
91
176
  """Check if kjob is installed. If not provide user with proper communicate and exit.
@@ -112,7 +197,25 @@ def verify_kjob_installed(args: Namespace) -> int:
112
197
  return 0
113
198
 
114
199
 
115
- def create_app_profile_instance(args: Namespace) -> int:
200
def get_pod_template_interactive_command() -> str:
  """Gets the interactive command for PodTemplate.

  Returns:
    str - the value stored in the xpk config under
      KJOB_SHELL_INTERACTIVE_COMMAND, or
      PodTemplateDefaults.INTERACTIVE_COMMAND when that value is missing or
      empty.
  """
  configured = XpkConfig().get(KJOB_SHELL_INTERACTIVE_COMMAND)
  if configured is not None and len(configured) > 0:
    return configured
  return PodTemplateDefaults.INTERACTIVE_COMMAND.value
214
+
215
+
216
+ def create_app_profile_instance(
217
+ args: Namespace, volume_bundles: list[str]
218
+ ) -> int:
116
219
  """Create new AppProfile instance on cluster with default settings.
117
220
 
118
221
  Args:
@@ -125,13 +228,29 @@ def create_app_profile_instance(args: Namespace) -> int:
125
228
  name=AppProfileDefaults.NAME.value,
126
229
  batch_template=JobTemplateDefaults.NAME.value,
127
230
  interactive_template=PodTemplateDefaults.NAME.value,
231
+ volume_bundles=volume_bundles,
128
232
  ),
129
233
  task="Creating AppProfile",
130
234
  args=args,
131
235
  )
132
236
 
133
237
 
134
- def create_job_template_instance(args: Namespace) -> int:
238
def decorate_job_template_with_gpu(yml_string: str, gpu_type: str) -> str:
  """Apply device-specific decoration to a JobTemplate manifest.

  Args:
    yml_string: the JobTemplate manifest as YAML text; it must contain a
      top-level `template` key.
    gpu_type: device type to decorate for. H100_MEGA_DEVICE_TYPE selects the
      tcpxo decorator and H200_DEVICE_TYPE the rdma decorator; any other
      value leaves the template section as parsed.

  Returns:
    The manifest serialized back to YAML, keys in their original order, with
    its `template` section replaced by the decorated one.
  """
  # Parse once; the decorated section is written back into the same document.
  job_template_dict = yaml.safe_load(yml_string)
  job_spec = job_template_dict["template"]
  if gpu_type == H100_MEGA_DEVICE_TYPE:
    job_spec = tcpxo_decorator.decorate_kjob_template(job_spec)
  if gpu_type == H200_DEVICE_TYPE:
    job_spec = rdma_decorator.decorate_kjob_template(job_spec)
  job_template_dict["template"] = job_spec
  return yaml.dump(job_template_dict, sort_keys=False)
247
+
248
+
249
+ def create_job_template_instance(
250
+ args: Namespace,
251
+ system: SystemCharacteristics | None,
252
+ service_account: str,
253
+ ) -> int:
135
254
  """Create new JobTemplate instance on cluster with default settings.
136
255
 
137
256
  Args:
@@ -139,20 +258,49 @@ def create_job_template_instance(args: Namespace) -> int:
139
258
  Returns:
140
259
  exit_code > 0 if creating JobTemplate fails, 0 otherwise
141
260
  """
261
+ config = XpkConfig()
262
+ job_image = config.get(KJOB_BATCH_IMAGE)
263
+ if job_image is None or len(job_image) == 0:
264
+ job_image = JobTemplateDefaults.IMAGE.value
265
+ working_directory = config.get(KJOB_BATCH_WORKING_DIRECTORY)
266
+ if working_directory is None or len(working_directory) == 0:
267
+ working_directory = JobTemplateDefaults.WORKING_DIRECTORY.value
268
+ resources = (
269
+ job_resources_template.format(gpu_per_node=system.chips_per_vm)
270
+ if system is not None
271
+ and system.accelerator_type == AcceleratorType["GPU"]
272
+ else ""
273
+ )
274
+
275
+ node_selector = (
276
+ job_node_selector_template.format(gpu_name=system.gke_accelerator)
277
+ if system is not None
278
+ and system.accelerator_type == AcceleratorType["GPU"]
279
+ else ""
280
+ )
281
+ yml_string = job_template_yaml.format(
282
+ name=JobTemplateDefaults.NAME.value,
283
+ parallelism=JobTemplateDefaults.PARALLELISM.value,
284
+ completions=JobTemplateDefaults.COMPLETIONS.value,
285
+ container_name=JobTemplateDefaults.CONTAINER_NAME.value,
286
+ image=job_image,
287
+ working_directory=working_directory,
288
+ resources=resources,
289
+ node_selector=node_selector,
290
+ priority=args.priority if hasattr(args, "priority") else "medium",
291
+ service_account=service_account,
292
+ )
293
+ if system is not None and system.accelerator_type == AcceleratorType["GPU"]:
294
+ yml_string = decorate_job_template_with_gpu(yml_string, system.device_type)
295
+
142
296
  return run_kubectl_apply(
143
- yml_string=job_template_yaml.format(
144
- name=JobTemplateDefaults.NAME.value,
145
- parallelism=JobTemplateDefaults.PARALLELISM.value,
146
- completions=JobTemplateDefaults.COMPLETIONS.value,
147
- container_name=JobTemplateDefaults.CONTAINER_NAME.value,
148
- image=JobTemplateDefaults.IMAGE.value,
149
- ),
297
+ yml_string,
150
298
  task="Creating JobTemplate",
151
299
  args=args,
152
300
  )
153
301
 
154
302
 
155
- def create_pod_template_instance(args: Namespace) -> int:
303
+ def create_pod_template_instance(args: Namespace, service_account: str) -> int:
156
304
  """Create new PodTemplate instance on cluster with default settings.
157
305
 
158
306
  Args:
@@ -160,28 +308,49 @@ def create_pod_template_instance(args: Namespace) -> int:
160
308
  Returns:
161
309
  exit_code > 0 if creating PodTemplate fails, 0 otherwise
162
310
  """
311
+ config = XpkConfig()
312
+ pod_image = config.get(KJOB_SHELL_IMAGE)
313
+ if pod_image is None or len(pod_image) == 0:
314
+ pod_image = PodTemplateDefaults.IMAGE.value
315
+ working_directory = config.get(KJOB_SHELL_WORKING_DIRECTORY)
316
+ if working_directory is None or len(working_directory) == 0:
317
+ working_directory = PodTemplateDefaults.WORKING_DIRECTORY.value
318
+
163
319
  return run_kubectl_apply(
164
320
  yml_string=pod_template_yaml.format(
165
321
  name=PodTemplateDefaults.NAME.value,
166
322
  container_name=PodTemplateDefaults.CONTAINER_NAME.value,
167
- image=PodTemplateDefaults.IMAGE.value,
168
- interactive_command=PodTemplateDefaults.INTERACTIVE_COMMAND.value,
323
+ image=pod_image,
324
+ working_directory=working_directory,
325
+ interactive_command=get_pod_template_interactive_command(),
326
+ service_account=service_account,
169
327
  ),
170
328
  task="Creating PodTemplate",
171
329
  args=args,
172
330
  )
173
331
 
174
332
 
175
- def prepare_kjob(args) -> int:
176
- job_err_code = create_job_template_instance(args)
333
+ def prepare_kjob(args: Namespace) -> int:
334
+ system = get_cluster_system_characteristics(args)
335
+
336
+ k8s_api_client = setup_k8s_env(args)
337
+ storages = get_auto_mount_storages(k8s_api_client)
338
+
339
+ service_account = ""
340
+ if len(storages) > 0:
341
+ service_account = XPK_SA
342
+
343
+ job_err_code = create_job_template_instance(args, system, service_account)
177
344
  if job_err_code > 0:
178
345
  return job_err_code
179
346
 
180
- pod_err_code = create_pod_template_instance(args)
347
+ pod_err_code = create_pod_template_instance(args, service_account)
181
348
  if pod_err_code > 0:
182
349
  return pod_err_code
183
350
 
184
- return create_app_profile_instance(args)
351
+ volume_bundles = [item.name for item in storages]
352
+
353
+ return create_app_profile_instance(args, volume_bundles)
185
354
 
186
355
 
187
356
  def apply_kjob_crds(args: Namespace) -> int:
@@ -203,3 +372,73 @@ def apply_kjob_crds(args: Namespace) -> int:
203
372
  return return_code
204
373
  xpk_print("Creating kjob CRDs succeeded")
205
374
  return 0
375
+
376
+
377
def create_volume_bundle_instance(
    k8s_api_client: ApiClient,
    name: str,
    manifest: list[dict],
    readonly: bool,
    mount_point: str,
) -> None:
  """
  Creates a new VolumeBundle resource in the Kubernetes cluster.

  Loads the VolumeBundle template, adds one volume and one container volume
  mount for every PersistentVolumeClaim found in `manifest`, and creates the
  resulting object in DEFAULT_NAMESPACE. A 409 response is reported as
  "already exists" and skipped; any other API error is printed and followed
  by `xpk_exit(1)`.

  Args:
    k8s_api_client: An ApiClient object for interacting with the Kubernetes API.
    name: metadata name given to the VolumeBundle.
    manifest: Kubernetes objects as dicts; only those whose `kind` is
      PersistentVolumeClaim are used.
    readonly: value of `readOnly` on every generated volume.
    mount_point: `mountPath` used for every generated container volume mount.
  """
  bundle = templates.load(VOLUME_BUNDLE_TEMPLATE_PATH)
  bundle["metadata"]["name"] = name

  claim_names = [
      obj["metadata"]["name"]
      for obj in manifest
      if obj["kind"] == "PersistentVolumeClaim"
  ]
  bundle_spec = bundle["spec"]
  bundle_spec["volumes"] = [
      {
          "name": claim,
          "persistentVolumeClaim": {
              "claimName": claim,
              "readOnly": readonly,
          },
      }
      for claim in claim_names
  ]
  bundle_spec["containerVolumeMounts"] = [
      {"name": claim, "mountPath": mount_point} for claim in claim_names
  ]
  bundle["spec"] = bundle_spec

  custom_objects_api = k8s_client.CustomObjectsApi(k8s_api_client)
  try:
    custom_objects_api.create_namespaced_custom_object(
        namespace=DEFAULT_NAMESPACE,
        group=KJOB_API_GROUP_NAME,
        version=KJOB_API_GROUP_VERSION,
        plural=KJOB_API_VOLUME_BUNDLE_PLURAL,
        body=bundle,
    )
    xpk_print(
        f"Created {KJOB_API_VOLUME_BUNDLE_PLURAL}.{KJOB_API_GROUP_NAME} object:"
        f" {bundle['metadata']['name']}"
    )
  except ApiException as e:
    if e.status != 409:
      xpk_print(f"Encountered error during VolumeBundle creation: {e}")
      xpk_exit(1)
    else:
      xpk_print(f"VolumeBundle: {name} already exists. Skipping its creation")
+
438
+
439
def get_gcsfuse_annotation(args: Namespace) -> str | None:
  """Return the gcsfuse annotation when auto-mount gcsfuse storages exist.

  Args:
    args: user provided arguments, used to set up the Kubernetes client.

  Returns:
    "gke-gcsfuse/volumes=true" if at least one auto-mount gcsfuse storage is
    found, otherwise None.
  """
  gcsfuse_storages = get_auto_mount_gcsfuse_storages(setup_k8s_env(args))
  return "gke-gcsfuse/volumes=true" if len(gcsfuse_storages) > 0 else None
+ return None