xpk 0.7.1__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. xpk/commands/batch.py +19 -12
  2. xpk/commands/cluster.py +33 -16
  3. xpk/commands/cluster_gcluster.py +22 -5
  4. xpk/commands/info.py +2 -4
  5. xpk/commands/job.py +7 -8
  6. xpk/commands/kjob_common.py +23 -20
  7. xpk/commands/run.py +17 -11
  8. xpk/commands/shell.py +3 -4
  9. xpk/commands/storage.py +64 -19
  10. xpk/commands/workload.py +154 -319
  11. xpk/core/blueprint/blueprint_definitions.py +2 -0
  12. xpk/core/blueprint/blueprint_generator.py +322 -32
  13. xpk/core/capacity.py +1 -0
  14. xpk/core/cluster.py +75 -5
  15. xpk/core/config.py +3 -1
  16. xpk/core/docker_manager.py +1 -1
  17. xpk/core/docker_resources.py +9 -21
  18. xpk/core/filestore.py +11 -3
  19. xpk/core/gcsfuse.py +8 -5
  20. xpk/core/kjob.py +57 -18
  21. xpk/core/nap.py +4 -0
  22. xpk/core/network.py +11 -21
  23. xpk/core/nodepool.py +28 -26
  24. xpk/core/pathways.py +165 -210
  25. xpk/core/scheduling.py +36 -0
  26. xpk/core/storage.py +66 -12
  27. xpk/core/system_characteristics.py +9 -0
  28. xpk/core/workload.py +27 -82
  29. xpk/core/workload_decorators/rdma_decorator.py +3 -3
  30. xpk/core/workload_decorators/storage_decorator.py +8 -3
  31. xpk/core/workload_decorators/tcpxo_decorator.py +2 -2
  32. xpk/parser/cluster.py +15 -6
  33. xpk/parser/storage.py +14 -3
  34. xpk/parser/workload.py +59 -31
  35. {xpk-0.7.1.dist-info → xpk-0.8.0.dist-info}/METADATA +60 -4
  36. {xpk-0.7.1.dist-info → xpk-0.8.0.dist-info}/RECORD +40 -40
  37. {xpk-0.7.1.dist-info → xpk-0.8.0.dist-info}/WHEEL +1 -1
  38. {xpk-0.7.1.dist-info → xpk-0.8.0.dist-info}/entry_points.txt +0 -0
  39. {xpk-0.7.1.dist-info → xpk-0.8.0.dist-info}/licenses/LICENSE +0 -0
  40. {xpk-0.7.1.dist-info → xpk-0.8.0.dist-info}/top_level.txt +0 -0
xpk/core/pathways.py CHANGED
@@ -14,124 +14,24 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from .cluster import XPK_SA
17
+ from ..core.commands import run_command_for_value, run_command_with_updates, run_commands
18
18
  from ..core.docker_container import get_user_workload_container
19
19
  from ..core.gcloud_context import zone_to_region
20
20
  from ..core.nodepool import get_all_nodepools_programmatic
21
21
  from ..utils.console import xpk_exit, xpk_print
22
22
  from .config import AcceleratorType
23
- from .storage import Storage, get_storage_volumes_yaml, GCS_FUSE_ANNOTATION
24
23
  from .system_characteristics import SystemCharacteristics
25
24
 
26
- PathwaysExpectedInstancesMap = {
27
- 'v6e': 'tpuv6e',
28
- 'v5p': 'tpuv5',
29
- 'v5litepod': 'tpuv5e',
30
- 'v4': 'tpuv4',
31
- 'v3': 'tpuv3',
32
- }
33
-
34
-
35
- def get_pathways_worker_args(args) -> str:
36
- """Arguments for the Pathways workers.
37
- Args:
38
- args: user provided arguments for running the command.
39
-
40
- Returns:
41
- str: yaml containing arguments for the Pathways workers.
42
- """
43
- yaml = """- --server_port=29001
44
- - --resource_manager_address={rm_address}
45
- - --gcs_scratch_location={args.pathways_gcs_location}"""
46
- if args.use_pathways:
47
- if args.custom_pathways_worker_args:
48
- yaml = append_custom_pathways_args(yaml, args.custom_pathways_worker_args)
49
- return yaml.format(args=args, rm_address=get_rm_address(args))
50
- else:
51
- return ''
52
-
53
-
54
- def get_pathways_proxy_args(args) -> str:
55
- """Arguments for the Pathways proxy.
56
- Args:
57
- args: user provided arguments for running the command.
58
-
59
- Returns:
60
- str: yaml containing arguments for the Pathways proxy.
61
- """
62
- yaml = """- --server_port=29000
63
- - --resource_manager_address={rm_address}
64
- - --gcs_scratch_location={args.pathways_gcs_location}"""
65
-
66
- if args.use_pathways:
67
- if args.custom_pathways_proxy_server_args:
68
- yaml = append_custom_pathways_args(
69
- yaml, args.custom_pathways_proxy_server_args
70
- )
71
- return yaml.format(args=args, rm_address=get_rm_address(args))
72
- else:
73
- return ''
74
-
75
-
76
- def get_pathways_sidecar_container(args) -> str:
77
- """This is a sidecar container that runs the remote python server.
78
-
79
- It is a special case of the initContainer (designated by restartPolicy:
80
- Always)
81
- See https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/
82
- for more details.
83
- Args:
84
- args: user provided arguments for running the command.
85
-
86
- Returns:
87
- str: yaml containing arguments for the Pathways sidecar container.
88
- """
89
- yaml = """initContainers:
90
- - name: remote-python-sidecar
91
- image: {args.remote_python_sidecar_image}
92
- imagePullPolicy: Always
93
- securityContext:
94
- privileged: true
95
- volumeMounts:
96
- - mountPath: /tmp # Shared volume mount with the main container.
97
- name: shared-tmp
98
- restartPolicy: Always
99
- ports:
100
- - containerPort: 50051
101
- env:
102
- - name: GRPC_SERVER_ADDRESS
103
- value: '0.0.0.0:50051'"""
104
- if args.use_pathways and args.remote_python_sidecar_image is not None:
105
- return yaml.format(args=args)
106
- else:
107
- return ''
108
-
109
25
 
110
26
  def add_pw_resource_flavors(args):
111
27
  """Add resource flavors required for Pathways enabled clusters."""
112
28
  resource_flavor_yaml = """apiVersion: kueue.x-k8s.io/v1beta1
113
29
  kind: ResourceFlavor
114
- metadata:
115
- name: cpu-rm
116
- spec:
117
- nodeLabels:
118
- cloud.google.com/gke-nodepool: cpu-rm-np
119
- ---
120
- apiVersion: kueue.x-k8s.io/v1beta1
121
- kind: ResourceFlavor
122
- metadata:
123
- name: cpu-proxy
124
- spec:
125
- nodeLabels:
126
- cloud.google.com/gke-nodepool: cpu-proxy-np
127
- ---
128
- apiVersion: kueue.x-k8s.io/v1beta1
129
- kind: ResourceFlavor
130
30
  metadata:
131
31
  name: cpu-user
132
32
  spec:
133
33
  nodeLabels:
134
- cloud.google.com/gke-nodepool: cpu-user-np
34
+ cloud.google.com/gke-nodepool: cpu-np
135
35
  ---"""
136
36
  if args.enable_pathways:
137
37
  return resource_flavor_yaml
@@ -142,18 +42,6 @@ def add_pw_resources_to_kueue(args):
142
42
  """Add resource flavors required for Pathways, to the cluster queue."""
143
43
  resources_yaml = """- coveredResources: ["cpu", "memory"]
144
44
  flavors:
145
- - name: cpu-rm
146
- resources:
147
- - name: "cpu"
148
- nominalQuota: 480
149
- - name: "memory"
150
- nominalQuota: 2000G
151
- - name: cpu-proxy
152
- resources:
153
- - name: "cpu"
154
- nominalQuota: 480
155
- - name: "memory"
156
- nominalQuota: 2000G
157
45
  - name: cpu-user
158
46
  resources:
159
47
  - name: "cpu"
@@ -175,6 +63,10 @@ def ensure_pathways_workload_prerequisites(args, system) -> bool:
175
63
  Returns:
176
64
  True once conditions satisfy and variables are set. Exits otherwise.
177
65
  """
66
+ # Ensure that PathwaysJob is installed and available on the cluster.
67
+ if not check_if_pathways_job_is_installed(args):
68
+ xpk_exit(1)
69
+
178
70
  # Ensure command is provided if not using Pathways in headless mode
179
71
  if args.command is None and not args.headless:
180
72
  xpk_print(
@@ -187,7 +79,7 @@ def ensure_pathways_workload_prerequisites(args, system) -> bool:
187
79
 
188
80
  # Ensure the cluster and CPU nodepools were created with create-pathways
189
81
  all_node_pools = get_all_nodepools_programmatic(args)
190
- desired_pw_cpu_node_pools = {'cpu-user-np', 'cpu-rm-np', 'cpu-proxy-np'}
82
+ desired_pw_cpu_node_pools = {'cpu-np'}
191
83
  if not desired_pw_cpu_node_pools.issubset(set(all_node_pools[0])):
192
84
  xpk_print(
193
85
  'Cluster needs to be created with `xpk create-pathways` to run'
@@ -209,6 +101,35 @@ def ensure_pathways_workload_prerequisites(args, system) -> bool:
209
101
  return True
210
102
 
211
103
 
104
+ def check_if_pathways_job_is_installed(args) -> bool:
105
+ """Check if PathwaysJob is installed on the cluster.
106
+ Args:
107
+ args: user provided arguments for running the command.
108
+ Returns:
109
+ 0 if successful and 1 otherwise.
110
+ """
111
+ command = (
112
+ 'kubectl get pods -n pathways-job-system --no-headers -o'
113
+ ' custom-columns=NAME:.metadata.name'
114
+ )
115
+ task = f'Check if PathwaysJob is installed on {args.cluster}'
116
+ return_code, return_msg = run_command_for_value(command, task, args)
117
+ # return_msg contains the name of the controller pod, if found.
118
+ xpk_print('check_if_pathways_job_is_installed', return_code, return_msg)
119
+
120
+ if return_code != 0:
121
+ xpk_print(f'{task} returned with ERROR {return_code}.\n')
122
+ return False
123
+ if not return_msg:
124
+ xpk_print(
125
+ 'You are using a new version of XPK, which uses PathwaysJob'
126
+ ' for Pathways workloads. Please update the cluster using'
127
+ ' `cluster create-pathways` to enjoy the upgrade!'
128
+ )
129
+ return False
130
+ return True
131
+
132
+
212
133
  def get_pathways_unified_query_link(args) -> str:
213
134
  """Get the unified query link for the pathways workload."""
214
135
  query_params = (
@@ -223,58 +144,106 @@ def get_pathways_unified_query_link(args) -> str:
223
144
  return f'https://console.cloud.google.com/logs/query;query={query_params}'
224
145
 
225
146
 
226
- def get_pathways_rm_args(args, system: SystemCharacteristics) -> str:
227
- """Arguments for the Pathways resource manager.
228
- Args:
229
- args: user provided arguments for running the command.
147
+ def append_custom_pathways_flags(custom_args, prev_indentation=8) -> str:
148
+ """Append custom Pathways args to Pathways components using a YAML with proper indentation.
230
149
 
231
150
  Returns:
232
- str: yaml containing arguments for the Pathways resource manager.
151
+ yaml (string): yaml with additional args appended.
233
152
  """
234
- yaml = """- --server_port=29001
235
- - --gcs_scratch_location={args.pathways_gcs_location}
236
- - --node_type=resource_manager
237
- - --instance_count={instance_count}
238
- - --instance_type={instance_type}"""
239
- if args.use_pathways:
240
- if args.custom_pathways_server_args:
241
- yaml = append_custom_pathways_args(yaml, args.custom_pathways_server_args)
242
- return yaml.format(
243
- args=args,
244
- instance_count=args.num_slices,
245
- instance_type=f'{get_pathways_expected_tpu_type(system.device_type)}:{system.topology}',
153
+ yaml = """"""
154
+ indentation = ' ' * (prev_indentation + 2)
155
+ if custom_args:
156
+ custom_args = custom_args.split(' ')
157
+ for arg in custom_args:
158
+ yaml += '\n' + indentation + '- ' + arg
159
+ return yaml
160
+
161
+
162
+ def append_custom_pathways_proxy_server(args) -> str:
163
+ """Append custom Pathways proxy server component using a YAML with proper indentation.
164
+
165
+ Returns:
166
+ yaml (string): yaml with custom proxy server appended.
167
+ """
168
+ yaml = """"""
169
+ if args.proxy_server_image or args.custom_pathways_proxy_server_args:
170
+ yaml = """- componentType: proxy_server"""
171
+ indentation = (
172
+ ' ' * 8
173
+ ) # Currently 8, based on the YAML, may need to update in the future.
174
+ if args.proxy_server_image:
175
+ yaml += '\n' + indentation + 'image: ' + args.proxy_server_image
176
+ if args.custom_pathways_proxy_server_args:
177
+ yaml += '\n' + indentation + 'customFlags: '
178
+ yaml += append_custom_pathways_flags(
179
+ args.custom_pathways_proxy_server_args, len(indentation)
246
180
  )
247
- else:
248
- return ''
181
+ return yaml
249
182
 
250
183
 
251
- def append_custom_pathways_args(yaml, custom_args) -> str:
252
- """Append custom Pathways args to the YAML with proper indentation.
184
+ def append_custom_pathways_server(args) -> str:
185
+ """Append custom Pathways server component using a YAML with proper indentation.
253
186
 
254
- Args:
255
- yaml (string): existing yaml containing args
187
+ Returns:
188
+ yaml (string): yaml with custom pathways server appended.
189
+ """
190
+ yaml = """"""
191
+ if args.server_image or args.custom_pathways_server_args:
192
+ yaml = """- componentType: pathways_server"""
193
+ indentation = (
194
+ ' ' * 8
195
+ ) # Currently 8, based on the YAML, may need to update in the future.
196
+ if args.server_image:
197
+ yaml += '\n' + indentation + 'image: ' + args.server_image
198
+ if args.custom_pathways_server_args:
199
+ yaml += '\n' + indentation + 'customFlags: '
200
+ yaml += append_custom_pathways_flags(
201
+ args.custom_pathways_server_args, len(indentation)
202
+ )
203
+ return yaml
204
+
205
+
206
+ def append_custom_pathways_worker(args) -> str:
207
+ """Append custom Pathways worker component using a YAML with proper indentation.
208
+
209
+ Returns:
210
+ yaml (string): yaml with custom pathways server appended.
211
+ """
212
+ yaml = """"""
213
+ if args.server_image or args.custom_pathways_worker_args:
214
+ yaml = """- componentType: pathways_worker"""
215
+ indentation = (
216
+ ' ' * 8
217
+ ) # Currently 8, based on the YAML, may need to update in the future.
218
+ if args.server_image:
219
+ yaml += '\n' + indentation + 'image: ' + args.server_image
220
+ if args.custom_pathways_worker_args:
221
+ yaml += '\n' + indentation + 'customFlags: '
222
+ yaml += append_custom_pathways_flags(
223
+ args.custom_pathways_worker_args, len(indentation)
224
+ )
225
+ return yaml
226
+
227
+
228
+ def append_custom_colocated_python_sidecar(args) -> str:
229
+ """Append custom Pathways colocated python sidecar component using a YAML with proper indentation.
256
230
 
257
231
  Returns:
258
- yaml (string): yaml with additional args appended.
232
+ yaml (string): yaml with custom pathways server appended.
259
233
  """
260
- second_line = yaml.split('\n')[1]
261
- if (
262
- not second_line
263
- ): # to cover edge case if only one arg remains, we would have to look at the entire YAML in this case.
264
- return yaml
265
- # Calculate the indentation based on the second line of existing YAML.
266
- indentation = ' ' * (len(second_line) - len(second_line.lstrip()))
267
- custom_args = custom_args.split(' ')
268
- for arg in custom_args:
269
- yaml += '\n' + indentation + '- ' + arg
234
+ yaml = """"""
235
+ if args.colocated_python_sidecar_image:
236
+ yaml = """- componentType: colocated_python_sidecar"""
237
+ indentation = (
238
+ ' ' * 8
239
+ ) # Currently 8, based on the YAML, may need to update in the future.
240
+ yaml += '\n' + indentation + 'image: ' + args.colocated_python_sidecar_image
270
241
  return yaml
271
242
 
272
243
 
273
244
  def get_user_workload_for_pathways(
274
245
  args,
275
246
  system: SystemCharacteristics,
276
- pod_failure_policy,
277
- storages: list[Storage],
278
247
  ) -> str:
279
248
  """
280
249
  Create a user workload container for Pathways.
@@ -289,63 +258,32 @@ def get_user_workload_for_pathways(
289
258
  str:
290
259
  Pathways server port as a YAML string
291
260
  """
292
- user_workload_yaml = """- name: main
293
- replicas: 1
294
- template:
295
- metadata:
296
- labels:
297
- xpk.google.com/workload: {args.workload}
298
- spec:
299
- backoffLimit: 0
300
- completions: 1
301
- parallelism: 1
302
- {pod_failure_policy}
303
- template:
304
- metadata:
305
- annotations:
306
- {gcs_fuse_annotation}
307
- spec:
308
- containers:
261
+ user_workload_yaml = """
262
+ metadata:
263
+ spec:
264
+ containers:
309
265
  {container}
310
- serviceAccountName: {service_account}
311
- nodeSelector:
312
- cloud.google.com/gke-nodepool: cpu-user-np
313
- hostNetwork: true
314
- dnsPolicy: ClusterFirstWithHostNet
315
- restartPolicy: Never
316
- volumes:
317
- - hostPath:
318
- path: /tmp
319
- type: DirectoryOrCreate
320
- name: shared-tmp
321
- {storage_volumes}"""
266
+ nodeSelector:
267
+ cloud.google.com/gke-nodepool: cpu-np
268
+ hostNetwork: true
269
+ dnsPolicy: ClusterFirstWithHostNet
270
+ restartPolicy: Never
271
+ volumes:
272
+ - hostPath:
273
+ path: /tmp
274
+ type: DirectoryOrCreate
275
+ name: shared-tmp
276
+ """
322
277
  if args.headless:
323
278
  return ''
324
279
  else:
325
280
  container, _ = get_user_workload_container(args, system)
326
- storage_volumes = get_storage_volumes_yaml(storages)
327
281
  return user_workload_yaml.format(
328
282
  args=args,
329
283
  container=container,
330
- storage_volumes=storage_volumes,
331
- pod_failure_policy=pod_failure_policy,
332
- service_account=XPK_SA,
333
- gcs_fuse_annotation=GCS_FUSE_ANNOTATION,
334
284
  )
335
285
 
336
286
 
337
- def get_rm_address(args) -> str:
338
- """Generates the Pathways resource manager address.
339
- Args:
340
- args: user provided arguments for running the command.
341
-
342
- Returns:
343
- str: Fully qualified RM address.
344
- """
345
- rm_address = f'{args.workload}-rm-0-0.{args.workload}:29001'
346
- return rm_address
347
-
348
-
349
287
  def get_proxy_address(args) -> str:
350
288
  """Generates the Pathways proxy address.
351
289
  Args:
@@ -354,24 +292,41 @@ def get_proxy_address(args) -> str:
354
292
  Returns:
355
293
  str: Fully qualified proxy address.
356
294
  """
357
- proxy_address = f'grpc://{args.workload}-proxy-0-0.{args.workload}:29000'
295
+ proxy_address = (
296
+ f'grpc://{args.workload}-pathways-head-0-0.{args.workload}:29000'
297
+ )
358
298
  return proxy_address
359
299
 
360
300
 
361
- def get_pathways_expected_tpu_type(device_type: str) -> str:
362
- """Returns the device type expected by Pathways
301
+ def try_to_delete_pathwaysjob_first(args, workloads) -> bool:
302
+ """Function to delete PathwaysJob workload. This is needed as PathwaysJob
303
+ owns the JobSet it creates.
304
+
363
305
  Args:
364
- device_type: the system characteristic device type
306
+ args: user provided arguments for running the command.
307
+ workloads: list of workloads that match the delete filter.
365
308
 
366
309
  Returns:
367
- str: the device type expected by pathways.
310
+ True if successful and False otherwise.
368
311
  """
369
- raw_type = device_type.split('-')[0].lower()
370
- pathways_expected_instance = PathwaysExpectedInstancesMap[raw_type]
371
- if not pathways_expected_instance:
372
- xpk_print(
373
- f'Passed in device_type {device_type} is incorrect. Please pass in a'
374
- ' valid device type'
312
+ commands = []
313
+ task_names = []
314
+ for workload in workloads:
315
+ args.workload = workload
316
+ command = f'kubectl delete pathwaysjob {workload} -n default'
317
+ task_name = f'PathwaysWorkloadDelete-{workload}'
318
+ commands.append(command)
319
+ task_names.append(task_name)
320
+
321
+ # Not batching deletion for single workload
322
+ if len(workloads) == 1:
323
+ return_code = run_command_with_updates(commands[0], 'Delete Workload', args)
324
+ else:
325
+ return_code = run_commands(
326
+ commands, 'Delete Workload', task_names, batch=100
375
327
  )
376
- xpk_exit(1)
377
- return pathways_expected_instance
328
+
329
+ if return_code != 0:
330
+ xpk_print(f'Delete Workload request returned ERROR {return_code}')
331
+ return False
332
+ return True
xpk/core/scheduling.py CHANGED
@@ -229,6 +229,21 @@ def create_accelerator_label(accelerator_type, system) -> str:
229
229
  )
230
230
 
231
231
 
232
+ def create_tpu_machine_type(accelerator_type, system) -> str:
233
+ """Generates TPU machine type..
234
+
235
+ Args:
236
+ accelerator_type: type of accelerator.
237
+ system: system characteristics.
238
+
239
+ Returns:
240
+ The accelerator label.
241
+ """
242
+ if accelerator_type == AcceleratorType['TPU']:
243
+ return f'{system.gce_machine_type}'
244
+ return ''
245
+
246
+
232
247
  def create_machine_label(
233
248
  accelerator_type, system, autoprovisioning_enabled: bool = False
234
249
  ) -> str:
@@ -251,3 +266,24 @@ def create_machine_label(
251
266
  f' {system.topology}'
252
267
  )
253
268
  return ''
269
+
270
+
271
+ def create_tpu_topology(
272
+ accelerator_type, system, autoprovisioning_enabled: bool = False
273
+ ) -> str:
274
+ """Generates TPU topology.
275
+
276
+ Args:
277
+ accelerator_type: type of accelerator.
278
+ system: system characteristics.
279
+ autoprovisioning_enabled: describes autoprovisioning enablement.
280
+
281
+ Returns:
282
+ The machine label.
283
+ """
284
+ if (
285
+ accelerator_type == AcceleratorType['TPU']
286
+ and not autoprovisioning_enabled
287
+ ):
288
+ return f'{system.topology}'
289
+ return ''
xpk/core/storage.py CHANGED
@@ -45,8 +45,18 @@ STORAGE_CRD_PLURAL = "storages"
45
45
  STORAGE_CRD_NAME = f"{XPK_API_GROUP_NAME}.{STORAGE_CRD_PLURAL}"
46
46
  GCS_FUSE_TYPE = "gcsfuse"
47
47
  GCP_FILESTORE_TYPE = "gcpfilestore"
48
+ PARALLELSTORE_TYPE = "parallelstore"
49
+ GCE_PD_TYPE = "pd"
48
50
  MANIFESTS_PATH = os.path.abspath("xpkclusters/storage-manifests")
49
- GCS_FUSE_ANNOTATION = 'gke-gcsfuse/volumes: "true"'
51
+ GCS_FUSE_ANNOTATIONS = {
52
+ "gke-gcsfuse/volumes": "true",
53
+ "gke-gcsfuse/cpu-limit": "0",
54
+ "gke-gcsfuse/memory-limit": "0",
55
+ "gke-gcsfuse/ephemeral-storage-limit": "0",
56
+ }
57
+ PARALLELSTORE_ANNOTATIONS = {
58
+ "gke-parallelstore/volumes": "true",
59
+ }
50
60
 
51
61
 
52
62
  @dataclass
@@ -210,6 +220,24 @@ def get_auto_mount_gcsfuse_storages(k8s_api_client: ApiClient) -> list[Storage]:
210
220
  return list(filter(lambda storage: storage.type == GCS_FUSE_TYPE, storages))
211
221
 
212
222
 
223
+ def get_auto_mount_parallelstore_storages(
224
+ k8s_api_client: ApiClient,
225
+ ) -> list[Storage]:
226
+ """
227
+ Retrieves all GCS Fuse Storage resources that have --auto-mount flag set to true.
228
+
229
+ Args:
230
+ k8s_api_client: An ApiClient object for interacting with the Kubernetes API.
231
+
232
+ Returns:
233
+ A list of GCS Fuse Storage objects that have `auto_mount` set to True.
234
+ """
235
+ storages: list[Storage] = get_auto_mount_storages(k8s_api_client)
236
+ return list(
237
+ filter(lambda storage: storage.type == PARALLELSTORE_TYPE, storages)
238
+ )
239
+
240
+
213
241
  def get_storages(
214
242
  k8s_api_client: ApiClient, requested_storages: list[str]
215
243
  ) -> list[Storage]:
@@ -314,6 +342,29 @@ def install_storage_crd(k8s_api_client: ApiClient) -> None:
314
342
  xpk_exit(1)
315
343
 
316
344
 
345
+ def get_storage_annotations(storages: list[Storage]) -> list[str]:
346
+ """
347
+ Generates the storage annotations for workloads in the format of a YAML snippet.
348
+
349
+ Args:
350
+ storages: A list of Storage objects
351
+ offset: An integer specifying the depth of the YAML file
352
+
353
+ Returns:
354
+ A string containing the YAML representation of the storage annotations.
355
+ """
356
+ annotations = []
357
+ if any(storage.type == GCS_FUSE_TYPE for storage in storages):
358
+ for key, value in GCS_FUSE_ANNOTATIONS.items():
359
+ annotations.append(f'{key}: "{value}"')
360
+
361
+ if any(storage.type == PARALLELSTORE_TYPE for storage in storages):
362
+ for key, value in PARALLELSTORE_ANNOTATIONS.items():
363
+ annotations.append(f'{key}: "{value}"')
364
+
365
+ return annotations
366
+
367
+
317
368
  def get_storage_volume_mounts_yaml(storages: list[Storage]) -> str:
318
369
  """
319
370
  Generates the YAML representation of the volumeMounts section for the given Storages.
@@ -360,26 +411,29 @@ def get_storage_volumes_yaml(storages: list[Storage]) -> str:
360
411
  return yaml_str
361
412
 
362
413
 
363
- def get_storage_volume_mounts_yaml_for_gpu(storages: list[Storage]) -> str:
414
+ def get_storage_volume_mounts_for_gpu(
415
+ storages: list[Storage],
416
+ ) -> list[dict]:
364
417
  """
365
418
  Generates the YAML representation of the volumeMounts section for the given Storages.
366
419
 
367
- This function creates the YAML snippet that defines how the storage volumes
420
+ This function creates the list of storage specifications that define how the storage volumes
368
421
  should be mounted within a Pod's containers.
369
422
 
370
423
  Args:
371
424
  storages: A list of Storage objects.
372
425
 
373
426
  Returns:
374
- A string containing the YAML representation of the volumeMounts section.
375
- """
376
- yaml_str = ""
377
- for storage in storages:
378
- yaml_str += f"""- name: {storage.pv}
379
- mountPath: {storage.mount_point}
380
- readOnly: {storage.readonly}
381
- """
382
- return yaml_str
427
+ A list containing the dictionary representation of the volumeMounts section.
428
+ """
429
+ return [
430
+ {
431
+ "name": storage.pv,
432
+ "mountPath": storage.mount_point,
433
+ "readOnly": storage.readonly,
434
+ }
435
+ for storage in storages
436
+ ]
383
437
 
384
438
 
385
439
  def get_storage_volumes_yaml_for_gpu(storages: list[Storage]) -> str:
xpk/core/system_characteristics.py CHANGED
@@ -173,6 +173,15 @@ UserFacingNameToSystemCharacteristics = {
173
173
  AcceleratorType['GPU'],
174
174
  'a100-40gb-8',
175
175
  ),
176
+ 'b200-8': SystemCharacteristics(
177
+ 'N/A',
178
+ 1,
179
+ 'nvidia-b200',
180
+ 'a4-highgpu-8g',
181
+ 8,
182
+ AcceleratorType['GPU'],
183
+ 'b200-8',
184
+ ),
176
185
  'h200-141gb-8': SystemCharacteristics(
177
186
  'N/A',
178
187
  1,