xpk-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. xpk/__init__.py +15 -0
  2. xpk/api/__init__.py +15 -0
  3. xpk/api/storage_crd.yaml +52 -0
  4. xpk/commands/__init__.py +15 -0
  5. xpk/commands/batch.py +131 -0
  6. xpk/commands/cluster.py +808 -0
  7. xpk/commands/cluster_gcluster.py +269 -0
  8. xpk/commands/common.py +44 -0
  9. xpk/commands/config.py +29 -0
  10. xpk/commands/info.py +243 -0
  11. xpk/commands/inspector.py +357 -0
  12. xpk/commands/job.py +199 -0
  13. xpk/commands/kind.py +283 -0
  14. xpk/commands/kjob_common.py +44 -0
  15. xpk/commands/run.py +128 -0
  16. xpk/commands/shell.py +140 -0
  17. xpk/commands/storage.py +267 -0
  18. xpk/commands/version.py +27 -0
  19. xpk/commands/workload.py +889 -0
  20. xpk/core/__init__.py +15 -0
  21. xpk/core/blueprint/__init__.py +15 -0
  22. xpk/core/blueprint/blueprint_definitions.py +62 -0
  23. xpk/core/blueprint/blueprint_generator.py +708 -0
  24. xpk/core/capacity.py +185 -0
  25. xpk/core/cluster.py +564 -0
  26. xpk/core/cluster_private.py +200 -0
  27. xpk/core/commands.py +356 -0
  28. xpk/core/config.py +179 -0
  29. xpk/core/docker_container.py +225 -0
  30. xpk/core/docker_image.py +210 -0
  31. xpk/core/docker_manager.py +308 -0
  32. xpk/core/docker_resources.py +350 -0
  33. xpk/core/filestore.py +251 -0
  34. xpk/core/gcloud_context.py +196 -0
  35. xpk/core/gcluster_manager.py +176 -0
  36. xpk/core/gcsfuse.py +50 -0
  37. xpk/core/kjob.py +444 -0
  38. xpk/core/kueue.py +358 -0
  39. xpk/core/monitoring.py +134 -0
  40. xpk/core/nap.py +361 -0
  41. xpk/core/network.py +377 -0
  42. xpk/core/nodepool.py +581 -0
  43. xpk/core/pathways.py +377 -0
  44. xpk/core/ray.py +222 -0
  45. xpk/core/remote_state/__init__.py +15 -0
  46. xpk/core/remote_state/fuse_remote_state.py +99 -0
  47. xpk/core/remote_state/remote_state_client.py +38 -0
  48. xpk/core/resources.py +238 -0
  49. xpk/core/scheduling.py +253 -0
  50. xpk/core/storage.py +581 -0
  51. xpk/core/system_characteristics.py +1432 -0
  52. xpk/core/vertex.py +105 -0
  53. xpk/core/workload.py +341 -0
  54. xpk/core/workload_decorators/__init__.py +15 -0
  55. xpk/core/workload_decorators/rdma_decorator.py +129 -0
  56. xpk/core/workload_decorators/storage_decorator.py +52 -0
  57. xpk/core/workload_decorators/tcpxo_decorator.py +190 -0
  58. xpk/main.py +75 -0
  59. xpk/parser/__init__.py +15 -0
  60. xpk/parser/batch.py +43 -0
  61. xpk/parser/cluster.py +662 -0
  62. xpk/parser/common.py +259 -0
  63. xpk/parser/config.py +49 -0
  64. xpk/parser/core.py +135 -0
  65. xpk/parser/info.py +64 -0
  66. xpk/parser/inspector.py +65 -0
  67. xpk/parser/job.py +147 -0
  68. xpk/parser/kind.py +95 -0
  69. xpk/parser/run.py +47 -0
  70. xpk/parser/shell.py +59 -0
  71. xpk/parser/storage.py +316 -0
  72. xpk/parser/validators.py +39 -0
  73. xpk/parser/version.py +23 -0
  74. xpk/parser/workload.py +726 -0
  75. xpk/templates/__init__.py +15 -0
  76. xpk/templates/storage.yaml +13 -0
  77. xpk/utils/__init__.py +15 -0
  78. xpk/utils/console.py +55 -0
  79. xpk/utils/file.py +82 -0
  80. xpk/utils/gcs_utils.py +125 -0
  81. xpk/utils/kubectl.py +57 -0
  82. xpk/utils/network.py +168 -0
  83. xpk/utils/objects.py +88 -0
  84. xpk/utils/templates.py +28 -0
  85. xpk/utils/validation.py +80 -0
  86. xpk/utils/yaml.py +30 -0
  87. xpk-0.0.1.dist-info/LICENSE +202 -0
  88. xpk-0.0.1.dist-info/METADATA +1498 -0
  89. xpk-0.0.1.dist-info/RECORD +92 -0
  90. xpk-0.0.1.dist-info/WHEEL +5 -0
  91. xpk-0.0.1.dist-info/entry_points.txt +2 -0
  92. xpk-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,308 @@
+ """
+ Copyright 2024 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from abc import ABC, abstractmethod
+ import docker
+ from docker.errors import ContainerError, APIError, ImageNotFound, BuildError
+ from ..utils.console import xpk_print, xpk_exit
+ from ..utils.file import ensure_directory_exists
+ from ..utils.objects import hash_string
+ from shutil import copytree, copy
+ import requests
+ import os
+ import tempfile
+ import time
+
+
+ DockerRunCommandExitCode = 135
+ dockerBuildErrorCode = 134
+ ctk_dockerfile_path = "Dockerfile"
+ ctk_build_ref = "v1.45.1"
+ ctk_docker_image = "xpk-ctk"
+ ctk_container_name = "xpk-ctk-container"
+ gcloud_cfg_mount_path = "/root/.config/gcloud"
+ working_dir_mount_path = "/out"
+ dockerfile_gh_path = f"https://raw.githubusercontent.com/GoogleCloudPlatform/cluster-toolkit/refs/tags/{ctk_build_ref}/tools/cloud-build/images/cluster-toolkit-dockerfile/Dockerfile"
+ upload_dir_name = "uploads"
+
+
+ class CommandRunner(ABC):
+   """Base class defining the methods a cluster toolkit command runner should implement."""
+
+   @abstractmethod
+   def initialize(self) -> None:
+     """Performs all steps necessary to run a command.
+
+     Returns:
+       None
+     """
+     return None
+
+   @abstractmethod
+   def run_command(self, cmd: str) -> None:
+     """Executes a command. If command execution fails, an exception should be raised.
+
+     Args:
+       cmd (str): command to run
+
+     Returns:
+       None
+     """
+     return None
+
+   @abstractmethod
+   def upload_file_to_working_dir(self, path: str, prefix: str = "") -> str:
+     """Uploads a single file to the working directory.
+
+     Args:
+       path (str): path to the file to upload
+       prefix (str): optional subdirectory of the working directory to upload into
+
+     Returns:
+       str: path to the destination file
+     """
+     return ""
+
+   @abstractmethod
+   def upload_directory_to_working_dir(self, path: str, prefix: str = "") -> str:
+     """Uploads a directory and its contents to the working directory.
+
+     Args:
+       path (str): path to the directory to upload
+       prefix (str): optional subdirectory of the working directory to upload into
+
+     Returns:
+       str: path to the target directory
+     """
+     return ""
+
+
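To make the contract concrete, here is a minimal sketch of an alternative implementation (hypothetical, not part of the package) that runs commands locally instead of in Docker:

import os
import shutil
import subprocess

from xpk.core.docker_manager import CommandRunner


class LocalCommandRunner(CommandRunner):
  """Hypothetical runner that executes commands on the host, for illustration."""

  def __init__(self, working_dir: str) -> None:
    self.working_dir = working_dir

  def initialize(self) -> None:
    # No image to build; just make sure the working directory exists.
    os.makedirs(self.working_dir, exist_ok=True)

  def run_command(self, cmd: str) -> None:
    # check=True raises CalledProcessError on failure, as the interface requires.
    subprocess.run(cmd, shell=True, check=True, cwd=self.working_dir)

  def upload_file_to_working_dir(self, path: str, prefix: str = "") -> str:
    target_dir = os.path.join(self.working_dir, prefix)
    os.makedirs(target_dir, exist_ok=True)
    return shutil.copy(path, target_dir)

  def upload_directory_to_working_dir(self, path: str, prefix: str = "") -> str:
    target = os.path.join(self.working_dir, prefix, os.path.basename(path))
    return shutil.copytree(path, target, dirs_exist_ok=True)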
+ class DockerManager(CommandRunner):
+   """DockerManager is a class for managing gcluster execution in a Docker container.
+
+   Attributes:
+     - dockerfile_path (str): path to the Dockerfile defining the gcluster execution image
+     - gcloud_cfg_path (str): path to the directory containing the gcloud configuration
+     - working_dir (str): path to the directory in which the gcluster deployment directory will be saved
+     - client (DockerClient): docker client
+     - nocache (bool): whether to bypass the Docker cache when building the image
+     - img_name (str): name of the Docker image to create
+     - container_name (str): name of the container that will be created from img_name
+     - remove_container (bool): if set to True, the Docker container in which a command is executed will be removed after each execution.
+   """
+
+   def __init__(
+       self,
+       gcloud_cfg_path: str,
+       working_dir: str,
+       nocache: bool = False,
+       img_name: str = ctk_docker_image,
+       container_name: str = ctk_container_name,
+       remove_container: bool = True,
+   ) -> None:
+     self.dockerfile_path = ""
+     self.client = docker.from_env()
+     self.gcloud_cfg_path = gcloud_cfg_path
+     self.working_dir = working_dir
+     self.nocache = nocache
+     self.img_name = f"{img_name}:{ctk_build_ref}"
+     self.container_name = container_name
+     self.remove_container = remove_container
+
+   def initialize(self):
+     """Builds the cluster toolkit execution image named img_name, using the
+     Python Docker client.
+
+     Returns:
+       None
+     Raises:
+       docker.errors.BuildError: if there is an error during the build.
+       docker.errors.APIError: if the server returns any other error.
+       TypeError: for other errors.
+     """
+     self._is_docker_installed()
+     xpk_print("Docker found!")
+
+     if not self._docker_image_exists():
+       xpk_print(f"Docker image {self.img_name} not found.")
+       self._build_image()
+     else:
+       xpk_print(f"Docker image {self.img_name} found!")
+
+   def run_command(
+       self,
+       cmd: str,
+   ) -> None:
+     """Runs a container from img_name with the following directories mounted:
+       - gcloud config
+       - deployment directory
+
+     Returns:
+       None
+     Raises:
+       docker.errors.ContainerError,
+       docker.errors.ImageNotFound,
+       docker.errors.APIError
+     """
+     xpk_print(f"Running command: {cmd} ...")
+     xpk_print(
+         f"volumes: {self.gcloud_cfg_path}:{gcloud_cfg_mount_path},"
+         f" {self.working_dir}:{working_dir_mount_path}"
+     )
+     try:
+       container = self.client.containers.run(
+           image=self.img_name,
+           entrypoint=cmd,
+           remove=self.remove_container,
+           name=self._get_container_unique_name(
+               cmd
+           ),  # To allow multiple xpk commands to run on one machine.
+           detach=True,
+           volumes=[
+               f"{self.gcloud_cfg_path}:{gcloud_cfg_mount_path}",
+               f"{self.working_dir}:{working_dir_mount_path}",
+           ],
+           environment={
+               "GOOGLE_APPLICATION_CREDENTIALS": (
+                   "/root/.config/gcloud/application_default_credentials.json"
+               )
+           },
+       )
+       self._print_logs_from_container(container)
+       result = container.wait()
+       if result["StatusCode"] != 0:
+         xpk_print(f"Running gcluster command: {cmd} failed.")
+         xpk_exit(result["StatusCode"])
+     except ContainerError as e:
+       xpk_print(
+           "Running command failed due to ContainerError with exit status:"
+           f" {e.exit_status} and stderr: {e.stderr}"
+       )
+       xpk_exit(DockerRunCommandExitCode)
+     except ImageNotFound:
+       xpk_print(f"Image {self.img_name} not found. Deploying cluster failed")
+       xpk_exit(DockerRunCommandExitCode)
+     except APIError as e:
+       xpk_print(f"Deploying cluster toolkit failed due to {e.explanation}")
+       xpk_exit(DockerRunCommandExitCode)
+
+   def _print_logs_from_container(self, container):
+     output = container.attach(stdout=True, stream=True, logs=True)
+     for line in output:
+       xpk_print(f"[gcluster] {line.decode('utf-8').strip()}")
+
+   def upload_directory_to_working_dir(self, path: str, prefix: str = "") -> str:
+     """Copies a directory and its contents into the directory containing the deployment files.
+
+     Args:
+       path (str): path to the directory that will be copied to the deployment directory
+       prefix (str): optional subdirectory of the upload directory to copy into
+
+     Returns:
+       str: path to the uploaded directory as seen inside the container
+     """
+     name = path.split("/")[-1]
+     target_path = os.path.join(self._get_upload_directory(prefix), name)
+     uploaded_path = os.path.join(
+         self._get_upload_directory_mounted(prefix), name
+     )
+     xpk_print(
+         f"Copying directory from {path} to {target_path}. Path in docker:"
+         f" {uploaded_path}"
+     )
+     copytree(path, target_path, dirs_exist_ok=True)
+     return uploaded_path
+
+   def upload_file_to_working_dir(self, path: str, prefix: str = "") -> str:
+     """Copies a single file into the directory containing the deployment files.
+
+     Args:
+       path (str): path to the file that will be copied to the deployment directory
+       prefix (str): optional subdirectory of the upload directory to copy into
+
+     Returns:
+       str: path to the uploaded file as seen inside the container
+     """
+     name = path.split("/")[-1]
+     target_path = os.path.join(self._get_upload_directory(prefix), name)
+     uploaded_path = os.path.join(
+         self._get_upload_directory_mounted(prefix), name
+     )
+     xpk_print(
+         f"Copying a file from {path} to {target_path}. Path in docker:"
+         f" {uploaded_path}"
+     )
+     copy(path, target_path)
+     return uploaded_path
+
+   def _get_upload_directory(self, prefix: str = "") -> str:
+     upload_dir = os.path.join(self.working_dir, upload_dir_name, prefix)
+     ensure_directory_exists(upload_dir)
+     return upload_dir
+
+   def _get_upload_directory_mounted(self, prefix: str = "") -> str:
+     return os.path.join(working_dir_mount_path, upload_dir_name, prefix)
+
+   def _create_tmp_for_dockerfile(self) -> str:
+     tmp_dir = os.path.join(tempfile.gettempdir(), "xpkutils")
+     ensure_directory_exists(tmp_dir)
+     tmp_path = os.path.join(tmp_dir, "Dockerfile")
+     return tmp_path
+
+   def _is_docker_installed(self) -> None:
+     self.client.info()
+
+   def _docker_image_exists(self) -> bool:
+     try:
+       self.client.images.get(self.img_name)
+     except ImageNotFound:
+       return False
+     return True
+
+   def _download_ctk_dockerfile(self) -> None:
+     """Downloads the cluster toolkit Dockerfile to dockerfile_path.
+
+     Returns:
+       None
+     """
+     xpk_print(f"Downloading Dockerfile from {dockerfile_gh_path} ...")
+     self.dockerfile_path = self._create_tmp_for_dockerfile()
+     r = requests.get(dockerfile_gh_path, timeout=100)
+     with open(self.dockerfile_path, "w+", encoding="utf8") as dockerfile:
+       dockerfile.write(r.text)
+     xpk_print("Downloading Dockerfile completed!")
+
+   def _build_image(self):
+     try:
+       self._download_ctk_dockerfile()
+       dir_path = os.path.dirname(self.dockerfile_path)
+       xpk_print(
+           f"Building {self.img_name} docker image from dockerfile:"
+           f" {self.dockerfile_path}. It may take a while..."
+       )
+       self.client.images.build(
+           nocache=self.nocache,
+           path=dir_path,
+           tag=self.img_name,
+           rm=True,
+           buildargs={"CLUSTER_TOOLKIT_REF": ctk_build_ref},
+       )
+     except BuildError as e:
+       xpk_print(f"error while building image {self.img_name}: {e.msg}")
+       xpk_exit(dockerBuildErrorCode)
+     except APIError as e:
+       xpk_print(f"error while building image {self.img_name}: {e.explanation}")
+       xpk_exit(dockerBuildErrorCode)
+     except TypeError as e:
+       xpk_print(f"TypeError while building image {self.img_name}: {e.args}")
+       xpk_exit(dockerBuildErrorCode)
+     xpk_print("Docker image built successfully.")
+     os.remove(self.dockerfile_path)
+     tmp_dockerfile_dir = os.path.dirname(self.dockerfile_path)
+     os.rmdir(tmp_dockerfile_dir)
+
+   def _get_container_unique_name(self, cmd):
+     return f"{self.container_name}_{hash_string(cmd + str(time.time_ns()))}"
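In practice a caller builds the image once, stages its files into the shared working directory, and then executes gcluster against the mounted paths. A minimal usage sketch (hypothetical paths and command string; not part of the package):

import os

from xpk.core.docker_manager import DockerManager

manager = DockerManager(
    gcloud_cfg_path=os.path.expanduser("~/.config/gcloud"),
    working_dir="/tmp/xpk-deployments",
)
manager.initialize()  # builds the xpk-ctk image only if it is not already cached
# The returned path is the one visible inside the container (under /out).
blueprint_path = manager.upload_directory_to_working_dir("/tmp/my-blueprint")
manager.run_command(f"gcluster deploy {blueprint_path}")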
@@ -0,0 +1,350 @@
+ """
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
+ from .cluster import setup_k8s_env
+ from .storage import GCS_FUSE_TYPE, GCP_FILESTORE_TYPE, Storage, get_storages_to_mount
+ from .system_characteristics import AcceleratorType, SystemCharacteristics
+
+
+ def get_main_container_resources(
+     args, system: SystemCharacteristics, resource_type
+ ) -> str:
+   """Resources for the main container.
+
+   Args:
+     args: user provided args.
+     system: system characteristics.
+     resource_type: TPU / GPU / CPU
+
+   Returns:
+     str:
+       Workload resources as a YAML string
+   """
+   # Resource requirements for Pathways workload containers are known.
+   resources_yaml = """cpu: "24"
+                 memory: 100G"""
+   if args.use_pathways:
+     return resources_yaml
+
+   gpu_resources_yaml = """nvidia.com/gpu: {system.chips_per_vm}"""
+   if system.accelerator_type == AcceleratorType['GPU']:
+     return gpu_resources_yaml.format(system=system)
+
+   if system.accelerator_type == AcceleratorType['CPU']:
+     # CPUs don't have chips, but have a subresource called vCPUs.
+     # system.chips_per_vm is used as a proxy for vCPUs.
+     # Some vCPUs get used for hosting system pods of the workloads,
+     # hence a factor of 0.95 is applied.
+     offset_vCPUs = int(system.chips_per_vm) * 0.95
+     return f'{resource_type}: {offset_vCPUs}'
+
+   return f'{resource_type}: {system.chips_per_vm}'
+
+
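For example, on a hypothetical CPU system with 24 vCPUs per VM and resource_type "cpu", this returns "cpu: 22.8" (24 × 0.95), leaving headroom for system pods, while a GPU system with 8 chips per VM yields "nvidia.com/gpu: 8".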
+ def get_env_container(args, system: SystemCharacteristics) -> str:
+   """Environment configuration for the main container.
+
+   Args:
+     args: user provided args.
+     system: system characteristics.
+
+   Returns:
+     str:
+       The env config for the main container, as a YAML string.
+   """
+   pw_env_yaml = """
+                 - name: XCLOUD_ENVIRONMENT
+                   value: GCP
+                 - name: JAX_PLATFORMS
+                   value: proxy
+                 - name: JAX_BACKEND_TARGET
+                   value: {proxy_address}
+                 - name: JOBSET_NAME
+                   valueFrom:
+                     fieldRef:
+                       fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']"""
+   if args.use_pathways:
+     return pw_env_yaml.format(
+         args=args, proxy_address=args.pathways_proxy_address
+     )
+
+   gpu_env_yaml = """
+                 - name: REPLICATED_JOB_NAME
+                   valueFrom:
+                     fieldRef:
+                       fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
+                 - name: JOBSET_NAME
+                   valueFrom:
+                     fieldRef:
+                       fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
+                 - name: JAX_COORDINATOR_ADDRESS
+                   value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
+                 - name: NNODES
+                   value: "{args.num_nodes}"
+                 - name: NODE_RANK
+                   valueFrom:
+                     fieldRef:
+                       fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
+                 - name: USE_GPUDIRECT
+                   value: {gpu_direct_name}
+                 - name: GPUS_PER_NODE
+                   value: "{system.chips_per_vm}"
+                 - name: JAX_COORDINATOR_PORT
+                   value: "6002"
+                 - name: COMMAND
+                   value: "{args.command}"
+                 {args.env}"""
+
+   if system.accelerator_type == AcceleratorType['GPU']:
+     gpu_direct_name = 'fastrak'
+     if args.device_type == H100_DEVICE_TYPE:
+       gpu_direct_name = 'tcpx'
+       gpu_env_yaml += """
+                 - name: LD_LIBRARY_PATH
+                   value: /usr/local/nvidia/lib64
+                 """
+     elif args.device_type == H100_MEGA_DEVICE_TYPE:
+       gpu_direct_name = 'tcpxo'
+     elif args.device_type == H200_DEVICE_TYPE:
+       gpu_direct_name = 'rdma'
+     return gpu_env_yaml.format(
+         args=args, system=system, gpu_direct_name=gpu_direct_name
+     )
+
+   if system.accelerator_type == AcceleratorType['CPU']:
+     return get_cpu_env(args.num_slices, args.env, system)
+
+   return args.env  # pytype: disable=bad-return-type
+
+
+ def get_cpu_env(num_slices, env_vars, system) -> str:
+   """Generates environment variables for CPU nodepools.
+
+   Args:
+     num_slices: Number of slices to be used in the workload.
+     env_vars: Environment variables, processed from user args.
+     system: system characteristics
+
+   Returns:
+     str: YAML containing env variables
+   """
+   yaml = """
+                 - name: REPLICATED_JOB_NAME
+                   valueFrom:
+                     fieldRef:
+                       fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
+                 - name: JOB_INDEX
+                   valueFrom:
+                     fieldRef:
+                       fieldPath: metadata.annotations['jobset.sigs.k8s.io/job-index']
+                 - name: JOB_COMPLETION_INDEX
+                   valueFrom:
+                     fieldRef:
+                       fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
+                 - name: PROCESSES_IN_JOB
+                   value: "{processes_in_job}"
+                 - name: JAX_PROCESS_COUNT
+                   value: "{process_count}"
+                 {env_vars}
+                 - name: JAX_COORDINATOR_ADDRESS
+                   value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
+   """
+   return yaml.format(
+       processes_in_job=system.vms_per_slice,
+       process_count=calculate_process_count(num_slices, system.vms_per_slice),
+       env_vars=env_vars,
+   )
+
+
+ def get_volumes(args, system: SystemCharacteristics) -> str:
+   """Gets the volumes accessible to the containers in the pod.
+
+   Args:
+     args: user provided args.
+     system: system characteristics.
+
+   Returns:
+     str:
+       YAML for the volumes.
+   """
+   volumes = """- emptyDir:
+                   medium: Memory
+                 name: dshm-2
+               """
+
+   if args.ramdisk_directory != '':
+     volumes += """
+               - name: cache
+                 csi:
+                   driver: phase1-checkpoint.csi.storage.gke.io"""
+
+   if (
+       system.accelerator_type == AcceleratorType['TPU']
+       and args.deploy_stacktrace_sidecar
+   ):
+     volumes += """
+               - name: tpu-stack-trace
+               - name: shared-data
+               """
+
+   storages: list[Storage] = get_storages_to_mount(
+       setup_k8s_env(args), args.storage
+   )
+   for storage in storages:
+     # GCS FUSE and Filestore volumes are both mounted through their PVCs.
+     if storage.type in (GCS_FUSE_TYPE, GCP_FILESTORE_TYPE):
+       volumes += f"""- name: {storage.pv}
+                 persistentVolumeClaim:
+                   claimName: {storage.pvc}
+                   readOnly: {storage.readonly}
+               """
+   return volumes
+
+
219
+ def get_volume_mounts(args, system: SystemCharacteristics) -> str:
220
+ """Resources for the main container.
221
+ Args:
222
+ args: user provided args.
223
+
224
+ Returns:
225
+ str:
226
+ YAML for the volumes mounted within a Pathways container or GPU container as a YAML string.
227
+ """
228
+ volume_mount_yaml = """- mountPath: /dev/shm
229
+ name: dshm-2
230
+ """
231
+
232
+ if args.ramdisk_directory != '':
233
+ volume_mount_yaml += f"""
234
+ - mountPath: /{args.ramdisk_directory}
235
+ name: cache"""
236
+
237
+ if args.use_pathways:
238
+ volume_mount_yaml = """- mountPath: /tmp
239
+ name: shared-tmp
240
+ """
241
+ elif (
242
+ system.accelerator_type == AcceleratorType['TPU']
243
+ and args.deploy_stacktrace_sidecar
244
+ ):
245
+ volume_mount_yaml += """- name: tpu-stack-trace
246
+ mountPath: /tmp/debugging
247
+ - name: shared-data
248
+ mountPath: /shared-volume
249
+ """
250
+ elif system.accelerator_type == AcceleratorType['GPU']:
251
+ if system.device_type == H100_DEVICE_TYPE:
252
+ volume_mount_yaml = """- name: nvidia-install-dir-host
253
+ mountPath: /usr/local/nvidia/lib64
254
+ - name: tcpx-nccl-plugin-volume
255
+ mountPath: /usr/local/tcpx
256
+ - name: tcpd-socket
257
+ mountPath: /tmp
258
+ - name: shared-memory
259
+ mountPath: /dev/shm
260
+ - name: workload-terminated-volume
261
+ mountPath: /usr/share/workload"""
262
+ elif (
263
+ system.device_type == H100_MEGA_DEVICE_TYPE
264
+ or system.device_type == H200_DEVICE_TYPE
265
+ ):
266
+ volume_mount_yaml = ''
267
+
268
+ storages: list[Storage] = get_storages_to_mount(
269
+ setup_k8s_env(args), args.storage
270
+ )
271
+ for storage in storages:
272
+ if storage.type == GCS_FUSE_TYPE:
273
+ volume_mount_yaml += f"""- name: {storage.pv}
274
+ mountPath: {storage.mount_point}
275
+ readOnly: {storage.readonly}
276
+ """
277
+ if storage.type == GCP_FILESTORE_TYPE:
278
+ volume_mount_yaml += f"""- name: {storage.pv}
279
+ mountPath: {storage.mount_point}
280
+ readOnly: {storage.readonly}
281
+ """
282
+ return volume_mount_yaml
283
+
284
+
+ def calculate_process_count(num_slices, vms_per_slice) -> str:
+   """Calculates the total number of processes in the workload.
+
+   Args:
+     num_slices: Number of slices to be used in the workload.
+     vms_per_slice: number of VMs in each slice.
+
+   Returns:
+     str: total number of processes.
+   """
+   num_processes = int(num_slices) * int(vms_per_slice)
+
+   return f'{num_processes}'
+
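Since each VM hosts one JAX process, the total is simply slices × VMs per slice: for example, a workload with 2 slices of 4 VMs each yields "8", so JAX_PROCESS_COUNT is set to 8 in the CPU env template above while PROCESSES_IN_JOB remains 4.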
+
+ def add_container_ports(args, system: SystemCharacteristics) -> str:
+   """Adds slice builder and megascale container ports
+   for non-Pathways workloads.
+
+   Args:
+     args: user provided args.
+     system: system characteristics.
+
+   Returns:
+     str:
+       Container ports as a YAML string; empty for Pathways workloads.
+   """
+   if args.use_pathways:
+     return ''
+
+   port_yaml = """- containerPort: 8471
+                 - containerPort: 8080"""
+
+   gpu_port_yaml = """- containerPort: 6002"""
+   if system.accelerator_type == AcceleratorType['GPU']:
+     return gpu_port_yaml
+   return port_yaml
+
+
+ def add_jax_coordinator_port(system) -> str:
+   """Adds the jax coordinator port, only for CPUs.
+
+   Args:
+     system: system characteristics.
+
+   Returns:
+     str:
+       jax coordinator port as a YAML string; empty otherwise.
+   """
+   if system.accelerator_type == AcceleratorType['CPU']:
+     return '- containerPort: 1234'
+   return ''
+
+
+ def add_image_pull_policy_for_pw_or_gpu(args, system: SystemCharacteristics):
+   """Adds an image pull policy for Pathways and GPU containers.
+
+   Args:
+     args: user provided args.
+     system: system characteristics
+
+   Returns:
+     str:
+       YAML stating that the image will be pulled from GCR every time.
+   """
+   yaml = """imagePullPolicy: Always"""
+
+   if args.use_pathways or system.accelerator_type == AcceleratorType['GPU']:
+     return yaml
+   return ''
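Each of these helpers emits one fragment (ports, env, resources, volumes) that the workload template splices into the JobSet container spec. A small sketch of calling two of them directly, using SimpleNamespace stand-ins for the real parsed args and SystemCharacteristics objects (hypothetical values; the real objects come from xpk's parser and system_characteristics tables):

from types import SimpleNamespace

from xpk.core.docker_resources import (
    add_container_ports,
    add_jax_coordinator_port,
)
from xpk.core.system_characteristics import AcceleratorType

# Stand-ins carrying only the attributes these two helpers read.
args = SimpleNamespace(use_pathways=False)
system = SimpleNamespace(accelerator_type=AcceleratorType['CPU'])

print(add_container_ports(args, system))  # slice builder + megascale ports
print(add_jax_coordinator_port(system))   # '- containerPort: 1234'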