xpk 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. xpk/commands/batch.py +19 -12
  2. xpk/commands/cluster.py +33 -16
  3. xpk/commands/cluster_gcluster.py +22 -5
  4. xpk/commands/info.py +2 -4
  5. xpk/commands/job.py +7 -8
  6. xpk/commands/kjob_common.py +23 -20
  7. xpk/commands/run.py +17 -11
  8. xpk/commands/shell.py +3 -4
  9. xpk/commands/storage.py +64 -19
  10. xpk/commands/workload.py +154 -319
  11. xpk/core/blueprint/blueprint_definitions.py +2 -0
  12. xpk/core/blueprint/blueprint_generator.py +322 -32
  13. xpk/core/capacity.py +1 -0
  14. xpk/core/cluster.py +75 -5
  15. xpk/core/config.py +3 -1
  16. xpk/core/docker_manager.py +1 -1
  17. xpk/core/docker_resources.py +9 -21
  18. xpk/core/filestore.py +11 -3
  19. xpk/core/gcsfuse.py +8 -5
  20. xpk/core/kjob.py +57 -18
  21. xpk/core/nap.py +4 -0
  22. xpk/core/network.py +11 -21
  23. xpk/core/nodepool.py +28 -26
  24. xpk/core/pathways.py +165 -210
  25. xpk/core/scheduling.py +36 -0
  26. xpk/core/storage.py +66 -12
  27. xpk/core/system_characteristics.py +9 -0
  28. xpk/core/workload.py +27 -82
  29. xpk/core/workload_decorators/rdma_decorator.py +3 -3
  30. xpk/core/workload_decorators/storage_decorator.py +8 -3
  31. xpk/core/workload_decorators/tcpxo_decorator.py +2 -2
  32. xpk/parser/cluster.py +15 -6
  33. xpk/parser/storage.py +14 -3
  34. xpk/parser/workload.py +59 -31
  35. {xpk-0.7.2.dist-info → xpk-0.8.0.dist-info}/METADATA +60 -4
  36. {xpk-0.7.2.dist-info → xpk-0.8.0.dist-info}/RECORD +40 -40
  37. {xpk-0.7.2.dist-info → xpk-0.8.0.dist-info}/WHEEL +0 -0
  38. {xpk-0.7.2.dist-info → xpk-0.8.0.dist-info}/entry_points.txt +0 -0
  39. {xpk-0.7.2.dist-info → xpk-0.8.0.dist-info}/licenses/LICENSE +0 -0
  40. {xpk-0.7.2.dist-info → xpk-0.8.0.dist-info}/top_level.txt +0 -0
xpk/commands/batch.py CHANGED
@@ -14,18 +14,26 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import re
 from argparse import Namespace
 
-from ..core.cluster import create_xpk_k8s_service_account
+from ..core.cluster import (
+    create_xpk_k8s_service_account,
+    get_cluster_credentials,
+)
 from ..core.commands import run_command_for_value
 from ..core.gcloud_context import add_zone_and_project
+from ..core.kjob import (
+    AppProfileDefaults,
+    JobTemplateDefaults,
+    Kueue_TAS_annotation,
+    get_storage_annotations,
+    prepare_kjob,
+)
 from ..core.kueue import LOCAL_QUEUE_NAME
 from ..utils.console import xpk_exit, xpk_print
-from .common import set_cluster_command
-from ..core.kjob import AppProfileDefaults, JobTemplateDefaults, prepare_kjob, Kueue_TAS_annotation, get_gcsfuse_annotation
-from .kjob_common import add_gpu_networking_annotations_to_command
 from .kind import set_local_cluster_command
-import re
+from .kjob_common import add_gpu_networking_annotations_to_command
 
 
 def batch(args: Namespace) -> None:
@@ -38,12 +46,11 @@ def batch(args: Namespace) -> None:
   """
   if not args.kind_cluster:
     add_zone_and_project(args)
-    set_cluster_command_code = set_cluster_command(args)
+    get_cluster_credentials(args)
   else:
     set_cluster_command_code = set_local_cluster_command(args)
-
-  if set_cluster_command_code != 0:
-    xpk_exit(set_cluster_command_code)
+    if set_cluster_command_code != 0:
+      xpk_exit(set_cluster_command_code)
 
   err_code = prepare_kjob(args)
   if err_code > 0:
@@ -66,9 +73,9 @@ def submit_job(args: Namespace) -> None:
       ' --first-node-ip'
   )
   cmd = add_gpu_networking_annotations_to_command(args, cmd)
-  gcsfuse_annotation = get_gcsfuse_annotation(args)
-  if gcsfuse_annotation is not None:
-    cmd += f' --pod-template-annotation {gcsfuse_annotation}'
+
+  for annotation in get_storage_annotations(args):
+    cmd += f' --pod-template-annotation {annotation}'
 
   if args.ignore_unknown_flags:
     cmd += ' --ignore-unknown-flags'
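Note: batch.py, run.py, and shell.py all move from the single-valued get_gcsfuse_annotation to iterating over get_storage_annotations, so every attached storage type can contribute its own pod-template annotation. A minimal sketch of the resulting flag-building pattern; the annotation strings below are illustrative placeholders, not the values xpk actually emits:

    # Sketch only: annotation values here are placeholders.
    def append_storage_annotations(cmd: str, annotations: list[str]) -> str:
      """Appends one --pod-template-annotation flag per annotation."""
      for annotation in annotations:
        cmd += f' --pod-template-annotation {annotation}'
      return cmd

    cmd = append_storage_annotations(
        'kjobctl create slurm',
        ['gke-gcsfuse/volumes=true', 'gke-gcsfuse/cpu-limit=0'],
    )
    print(cmd)
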
xpk/commands/cluster.py CHANGED
@@ -22,9 +22,13 @@ from ..core.cluster import (
     get_cluster_credentials,
     install_nccl_on_cluster,
     set_jobset_on_cluster,
+    set_pathways_job_on_cluster,
     setup_k8s_env,
     update_cluster_with_gcsfuse_driver_if_necessary,
     update_cluster_with_workload_identity_if_necessary,
+    update_cluster_with_gcpfilestore_driver_if_necessary,
+    update_cluster_with_parallelstore_driver_if_necessary,
+    update_cluster_with_pd_driver_if_necessary,
 )
 from ..core.cluster_private import authorize_private_cluster_access_if_necessary
 from ..core.commands import run_command_for_value, run_command_with_updates
@@ -46,7 +50,7 @@ from ..core.nap import enable_autoprovisioning_on_cluster
 from ..core.network import (
     create_cluster_network_config,
     delete_cluster_subnets,
-    set_up_cluster_network_for_gpu,
+    set_up_cluster_network_for_a3,
 )
 from ..core.nodepool import get_gke_node_pool_version, run_gke_node_pool_create_command
 from ..core.ray import install_ray_cluster
@@ -64,7 +68,6 @@ from ..utils.console import get_user_input, xpk_exit, xpk_print
 from ..utils.file import write_tmp_file
 from . import cluster_gcluster
 from .common import set_cluster_command
-from ..core.cluster import update_cluster_with_gcpfilestore_driver_if_necessary
 
 
 def cluster_create(args) -> None:
@@ -117,11 +120,7 @@ def cluster_create(args) -> None:
 
   # ToDo(roshanin@) - Re-enable CloudDNS on Pathways clusters conditionally.
   # Enable WorkloadIdentity if not enabled already.
-  if (
-      args.enable_workload_identity
-      or args.enable_gcsfuse_csi_driver
-      or args.enable_gcpfilestore_csi_driver
-  ):
+  if args.enable_workload_identity or args.enable_gcsfuse_csi_driver:
     update_cluster_command_code = (
         update_cluster_with_workload_identity_if_necessary(args)
     )
@@ -143,6 +142,20 @@ def cluster_create(args) -> None:
     if update_cluster_command_code != 0:
       xpk_exit(update_cluster_command_code)
 
+  if args.enable_parallelstore_csi_driver:
+    update_cluster_command_code = (
+        update_cluster_with_parallelstore_driver_if_necessary(args)
+    )
+    if update_cluster_command_code != 0:
+      xpk_exit(update_cluster_command_code)
+
+  if args.enable_pd_csi_driver:
+    update_cluster_command_code = update_cluster_with_pd_driver_if_necessary(
+        args
+    )
+    if update_cluster_command_code != 0:
+      xpk_exit(update_cluster_command_code)
+
   # Update Pathways clusters with CloudDNS if not enabled already.
 
   get_cluster_credentials(args)
@@ -155,13 +168,12 @@ def cluster_create(args) -> None:
   if not tensorboard_config:
     xpk_exit(1)
 
-  if system.accelerator_type == AcceleratorType['GPU']:
+  if system.device_type == H100_DEVICE_TYPE:
     xpk_print('Setting up Network for cluster')
-    set_up_cluster_network_code = set_up_cluster_network_for_gpu(args, system)
+    set_up_cluster_network_code = set_up_cluster_network_for_a3(args)
     if set_up_cluster_network_code != 0:
       xpk_exit(set_up_cluster_network_code)
 
-  if system.device_type == H100_DEVICE_TYPE:
     xpk_print('Creating Network Config for cluster')
     create_cluster_network_config_code = create_cluster_network_config(args)
     if create_cluster_network_config_code != 0:
@@ -207,6 +219,10 @@ def cluster_create(args) -> None:
   if set_jobset_on_cluster_code != 0:
     xpk_exit(set_jobset_on_cluster_code)
 
+  set_pathways_job_on_cluster_code = set_pathways_job_on_cluster(args)
+  if set_pathways_job_on_cluster_code != 0:
+    xpk_exit(set_pathways_job_on_cluster_code)
+
   xpk_print('Enabling Kueue on the cluster')
   install_kueue_on_cluster_code = install_kueue_on_cluster(args)
   if install_kueue_on_cluster_code != 0:
@@ -783,20 +799,21 @@ def run_gke_cluster_create_command(
   if args.enable_ray_cluster:
     command += ' --addons RayOperator'
 
-  if (
-      args.enable_workload_identity
-      or args.enable_gcsfuse_csi_driver
-      or args.enable_gcpfilestore_csi_driver
-  ):
+  if args.enable_workload_identity or args.enable_gcsfuse_csi_driver:
     command += f' --workload-pool={args.project}.svc.id.goog'
 
   addons = []
   if args.enable_gcsfuse_csi_driver:
     addons.append('GcsFuseCsiDriver')
-
   if args.enable_gcpfilestore_csi_driver:
     addons.append('GcpFilestoreCsiDriver')
 
+  if args.enable_parallelstore_csi_driver:
+    addons.append('ParallelstoreCsiDriver')
+
+  if args.enable_pd_csi_driver:
+    addons.append('GcePersistentDiskCsiDriver')
+
   if len(addons) > 0:
     addons_str = ','.join(addons)
     command += f' --addons={addons_str}'
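Note: cluster creation now accumulates CSI addons in a list before rendering a single --addons flag, which keeps the four driver flags independent of one another. A self-contained sketch of that assembly, assuming a Namespace carrying the four boolean flags:

    from argparse import Namespace

    def addons_flag(args: Namespace) -> str:
      """Maps the CSI driver flags onto GKE addon names (sketch)."""
      addons = []
      if args.enable_gcsfuse_csi_driver:
        addons.append('GcsFuseCsiDriver')
      if args.enable_gcpfilestore_csi_driver:
        addons.append('GcpFilestoreCsiDriver')
      if args.enable_parallelstore_csi_driver:
        addons.append('ParallelstoreCsiDriver')
      if args.enable_pd_csi_driver:
        addons.append('GcePersistentDiskCsiDriver')
      return f' --addons={",".join(addons)}' if addons else ''

    args = Namespace(
        enable_gcsfuse_csi_driver=True,
        enable_gcpfilestore_csi_driver=False,
        enable_parallelstore_csi_driver=True,
        enable_pd_csi_driver=False,
    )
    print(addons_flag(args))  # --addons=GcsFuseCsiDriver,ParallelstoreCsiDriver
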
xpk/commands/cluster_gcluster.py CHANGED
@@ -16,26 +16,27 @@ limitations under the License.
 
 import os
 
-from ..core.remote_state.remote_state_client import RemoteStateClient
-from ..core.remote_state.fuse_remote_state import FuseStateClient
 from ..core.blueprint.blueprint_generator import (
     BlueprintGenerator,
     BlueprintGeneratorOutput,
     a3mega_device_type,
     a3ultra_device_type,
+    a4_device_type,
     supported_device_types,
 )
-from ..core.commands import run_command_for_value
 from ..core.capacity import get_capacity_type
+from ..core.cluster import get_cluster_credentials
+from ..core.commands import run_command_for_value
 from ..core.docker_manager import DockerManager
 from ..core.gcloud_context import zone_to_region
 from ..core.gcluster_manager import GclusterManager
+from ..core.kjob import apply_kjob_crds, prepare_kjob
+from ..core.remote_state.fuse_remote_state import FuseStateClient
+from ..core.remote_state.remote_state_client import RemoteStateClient
 from ..utils.console import xpk_exit, xpk_print
 from ..utils.file import ensure_directory_exists
 from ..utils.network import all_IPs_cidr
 from ..utils.objects import hash_string
-from ..core.cluster import get_cluster_credentials
-from ..core.kjob import apply_kjob_crds, prepare_kjob
 
 blueprints_path = os.path.abspath('xpkclusters/blueprints')
 gcluster_working_dir = os.path.abspath('xpkclusters/gcluster-out')
@@ -266,4 +267,20 @@ def generate_blueprint(
         system_node_pool_min_node_count=args.default_pool_cpu_num_nodes,
         gcs_bucket=args.cluster_state_gcs_bucket,
     )
+  if args.device_type == a4_device_type:
+    num_nodes = args.num_nodes if not args.num_nodes is None else 2
+    return bpg.generate_a4_blueprint(
+        blueprint_name=blueprint_name,
+        prefix=prefix,
+        cluster_name=args.cluster,
+        region=zone_to_region(args.zone),
+        project_id=args.project,
+        zone=args.zone,
+        auth_cidr=all_IPs_cidr,
+        num_nodes=num_nodes,
+        reservation=args.reservation if args.reservation else None,
+        capacity_type=capacity_type,
+        system_node_pool_machine_type=args.default_pool_cpu_machine_type,
+        system_node_pool_min_node_count=args.default_pool_cpu_num_nodes,
+    )
   return None
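Note: generate_blueprint gains an A4 branch alongside the existing a3mega/a3ultra ones, defaulting to 2 nodes when --num-nodes is not given. A sketch of that dispatch; generate_a4_blueprint comes from the diff above, while the other generator names and constants here are hypothetical stand-ins:

    A3MEGA, A3ULTRA, A4 = 'a3mega', 'a3ultra', 'a4'  # placeholder constants

    def pick_blueprint(device_type: str, num_nodes: int | None):
      """Mirrors the per-device-type branching in generate_blueprint."""
      if device_type == A4:
        # New in 0.8.0: A4 defaults to 2 nodes when --num-nodes is unset.
        return ('generate_a4_blueprint', num_nodes if num_nodes is not None else 2)
      if device_type == A3ULTRA:
        return ('generate_a3_ultra_blueprint', num_nodes)  # hypothetical name
      if device_type == A3MEGA:
        return ('generate_a3_mega_blueprint', num_nodes)  # hypothetical name
      return None

    print(pick_blueprint(A4, None))  # ('generate_a4_blueprint', 2)
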
xpk/commands/info.py CHANGED
@@ -20,10 +20,10 @@ from argparse import Namespace
 from tabulate import tabulate
 
 from ..core.commands import run_command_for_value
+from ..core.cluster import get_cluster_credentials
 from ..core.gcloud_context import add_zone_and_project
 from ..core.kueue import verify_kueuectl
 from ..utils.console import xpk_exit, xpk_print
-from .common import set_cluster_command
 
 table_fmt = 'plain'
 
@@ -37,9 +37,7 @@ def info(args: Namespace) -> None:
     None
   """
   add_zone_and_project(args)
-  set_cluster_command_code = set_cluster_command(args)
-  if set_cluster_command_code != 0:
-    xpk_exit(set_cluster_command_code)
+  get_cluster_credentials(args)
 
   verify_kueuectl(args)
   lq, cq = bool(args.localqueue), bool(args.clusterqueue)
xpk/commands/job.py CHANGED
@@ -20,10 +20,10 @@ import sys
 from ruamel.yaml import YAML
 
 from ..core.commands import run_command_for_value, run_command_with_updates
+from ..core.cluster import get_cluster_credentials
 from ..core.gcloud_context import add_zone_and_project
 from ..core.kjob import AppProfileDefaults
 from ..utils.console import xpk_exit, xpk_print
-from .common import set_cluster_command
 from .kind import set_local_cluster_command
 
 
@@ -143,14 +143,14 @@ def job_list(args) -> None:
   """
   if not args.kind_cluster:
     add_zone_and_project(args)
-    set_cluster_command_code = set_cluster_command(args)
+    get_cluster_credentials(args)
     msg = f'Listing jobs for project {args.project} and zone {args.zone}:'
   else:
     set_cluster_command_code = set_local_cluster_command(args)
     msg = 'Listing jobs:'
+    if set_cluster_command_code != 0:
+      xpk_exit(set_cluster_command_code)
 
-  if set_cluster_command_code != 0:
-    xpk_exit(set_cluster_command_code)
   xpk_print(msg, flush=True)
 
   return_code = run_slurm_job_list_command(args)
@@ -178,12 +178,11 @@ def job_cancel(args) -> None:
   xpk_print(f'Starting job cancel for job: {args.name}', flush=True)
   if not args.kind_cluster:
     add_zone_and_project(args)
-    set_cluster_command_code = set_cluster_command(args)
+    get_cluster_credentials(args)
   else:
     set_cluster_command_code = set_local_cluster_command(args)
-
-  if set_cluster_command_code != 0:
-    xpk_exit(set_cluster_command_code)
+    if set_cluster_command_code != 0:
+      xpk_exit(set_cluster_command_code)
 
   return_code = run_slurm_job_delete_command(args)
   xpk_exit(return_code)
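Note: in job_list and job_cancel (as in batch and run) the non-zero-exit check now lives inside the else branch, scoped to the kind-cluster path; the GKE path calls get_cluster_credentials, which presumably handles its own failure exit. The control flow in outline, with stubbed helpers:

    def get_cluster_credentials(args) -> None:
      """Stub: the real helper appears to exit the process itself on failure."""

    def set_local_cluster_command(args) -> int:
      """Stub: returns an exit code the caller must check."""
      return 0

    def connect(args) -> None:
      if not args.kind_cluster:
        get_cluster_credentials(args)  # no return code to inspect
      else:
        code = set_local_cluster_command(args)
        if code != 0:  # check is scoped to the only path returning a code
          raise SystemExit(code)
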
xpk/commands/kjob_common.py CHANGED
@@ -14,31 +14,34 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from ..core.kjob import get_a3mega_pod_template_annotations, get_a3ultra_pod_template_annotations
-from ..core.capacity import H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
+from ..core.capacity import (
+    B200_DEVICE_TYPE,
+    H100_MEGA_DEVICE_TYPE,
+    H200_DEVICE_TYPE,
+)
 from ..core.cluster import get_gpu_type_from_cluster
-
-
-def add_tcpxo_annotations(args, cmd: str) -> str:
-  tcpxo, interfaces, eth0 = get_a3mega_pod_template_annotations(args)
-  cmd += f" --pod-template-annotation {tcpxo} \\\n"
-  cmd += f" --pod-template-annotation {eth0} \\\n"
-  cmd += f" --pod-template-annotation {interfaces} "
-  return cmd
-
-
-def add_rdma_annotations(args, cmd) -> str:
-  eth0, interfaces = get_a3ultra_pod_template_annotations(args)
-  cmd += f" --pod-template-annotation {eth0} \\\n"
-  cmd += f" --pod-template-annotation {interfaces} \\\n"
-  return cmd
+from ..core.kjob import (
+    get_a3mega_pod_template_annotations,
+    get_a3ultra_pod_template_annotations,
+    get_a4_pod_template_annotations,
+)
 
 
 def add_gpu_networking_annotations_to_command(args, cmd: str) -> str:
   gpu_type = get_gpu_type_from_cluster(args)
 
   if gpu_type == H100_MEGA_DEVICE_TYPE:
-    return add_tcpxo_annotations(args, cmd)
-  if gpu_type == H200_DEVICE_TYPE:
-    return add_rdma_annotations(args, cmd)
+    annotations = get_a3mega_pod_template_annotations(args)
+  elif gpu_type == H200_DEVICE_TYPE:
+    annotations = get_a3ultra_pod_template_annotations(args)
+  elif gpu_type == B200_DEVICE_TYPE:
+    annotations = get_a4_pod_template_annotations()
+  else:
+    annotations = []
+
+  flags = [
+      f" --pod-template-annotation {annotation} " for annotation in annotations
+  ]
+  cmd += "\\\n".join(flags)
+
   return cmd
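Note: the two per-accelerator helpers (add_tcpxo_annotations, add_rdma_annotations) collapse into one dispatch that picks an annotation list per GPU type and renders the flags uniformly, joined with trailing backslash-newlines so the generated shell command stays one flag per line. A sketch with placeholder device-type keys and annotation values:

    # Keys and values are illustrative, not the real constants/annotations.
    ANNOTATIONS_BY_GPU = {
        'h100-mega': ['tcpxo=...', 'eth0=...', 'interfaces=...'],
        'h200': ['eth0=...', 'interfaces=...'],
        'b200': ['rdma=...'],
    }

    def add_annotations(gpu_type: str, cmd: str) -> str:
      """Appends one backslash-continued flag per annotation (sketch)."""
      annotations = ANNOTATIONS_BY_GPU.get(gpu_type, [])
      flags = [
          f' --pod-template-annotation {a} ' for a in annotations
      ]
      return cmd + '\\\n'.join(flags)

    print(add_annotations('b200', 'kjobctl create slurm '))
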
xpk/commands/run.py CHANGED
@@ -16,15 +16,23 @@ limitations under the License.
 
 from argparse import Namespace
 
-from ..core.cluster import create_xpk_k8s_service_account
+from ..core.cluster import (
+    create_xpk_k8s_service_account,
+    get_cluster_credentials,
+)
 from ..core.commands import run_command_with_full_controls
 from ..core.gcloud_context import add_zone_and_project
+from ..core.kjob import (
+    AppProfileDefaults,
+    JobTemplateDefaults,
+    Kueue_TAS_annotation,
+    get_storage_annotations,
+    prepare_kjob,
+)
 from ..core.kueue import LOCAL_QUEUE_NAME
 from ..utils.console import xpk_exit, xpk_print
-from .common import set_cluster_command
-from ..core.kjob import JobTemplateDefaults, AppProfileDefaults, prepare_kjob, Kueue_TAS_annotation, get_gcsfuse_annotation
-from .kjob_common import add_gpu_networking_annotations_to_command
 from .kind import set_local_cluster_command
+from .kjob_common import add_gpu_networking_annotations_to_command
 
 
 def run(args: Namespace) -> None:
@@ -37,12 +45,11 @@ def run(args: Namespace) -> None:
   """
   if not args.kind_cluster:
     add_zone_and_project(args)
-    set_cluster_command_code = set_cluster_command(args)
+    get_cluster_credentials(args)
   else:
     set_cluster_command_code = set_local_cluster_command(args)
-
-  if set_cluster_command_code != 0:
-    xpk_exit(set_cluster_command_code)
+    if set_cluster_command_code != 0:
+      xpk_exit(set_cluster_command_code)
 
   err_code = prepare_kjob(args)
   if err_code > 0:
@@ -64,9 +71,8 @@ def submit_job(args: Namespace) -> None:
   )
   cmd = add_gpu_networking_annotations_to_command(args, cmd)
 
-  gcsfuse_annotation = get_gcsfuse_annotation(args)
-  if gcsfuse_annotation is not None:
-    cmd += f' --pod-template-annotation {gcsfuse_annotation}'
+  for annotation in get_storage_annotations(args):
+    cmd += f' --pod-template-annotation {annotation}'
 
   if args.timeout:
     cmd += f' --wait-timeout {args.timeout}s'
xpk/commands/shell.py CHANGED
@@ -20,7 +20,7 @@ from ..core.kjob import (
     AppProfileDefaults,
     prepare_kjob,
     get_pod_template_interactive_command,
-    get_gcsfuse_annotation,
+    get_storage_annotations,
 )
 
 exit_instructions = 'To exit the shell input "exit".'
@@ -89,9 +89,8 @@ def connect_to_new_interactive_shell(args: Namespace) -> int:
       f' {AppProfileDefaults.NAME.value} --pod-running-timeout 180s'
   )
 
-  gcsfuse_annotation = get_gcsfuse_annotation(args)
-  if gcsfuse_annotation is not None:
-    cmd += f' --pod-template-annotation {gcsfuse_annotation}'
+  for annotation in get_storage_annotations(args):
+    cmd += f' --pod-template-annotation {annotation}'
 
   return run_command_with_full_controls(
       command=cmd,
xpk/commands/storage.py CHANGED
@@ -27,6 +27,8 @@ from ..core.cluster import (
     add_zone_and_project,
     get_cluster_network,
     setup_k8s_env,
+    update_cluster_with_parallelstore_driver_if_necessary,
+    update_cluster_with_pd_driver_if_necessary,
     update_cluster_with_gcpfilestore_driver_if_necessary,
     update_cluster_with_gcsfuse_driver_if_necessary,
     update_cluster_with_workload_identity_if_necessary,
@@ -41,6 +43,8 @@ from ..core.kjob import (
 from ..core.storage import (
     GCP_FILESTORE_TYPE,
     GCS_FUSE_TYPE,
+    GCE_PD_TYPE,
+    PARALLELSTORE_TYPE,
     STORAGE_CRD_PLURAL,
     XPK_API_GROUP_NAME,
     XPK_API_GROUP_VERSION,
@@ -78,7 +82,11 @@ def storage_create(args: Namespace) -> None:
       manifest = list(yaml.safe_load_all(f))
   else:
     manifest = filestore_client.manifest(
-        args.name, args.vol, args.access_mode, filestore_network
+        args.name,
+        args.vol,
+        args.access_mode,
+        filestore_network,
+        args.mount_options,
     )
 
   k8s_api_client = setup_k8s_env(args)
@@ -86,9 +94,10 @@ def storage_create(args: Namespace) -> None:
   create_volume_bundle_instance(
       k8s_api_client, args.name, manifest, args.readonly, args.mount_point
   )
-  return_code = update_cluster_with_workload_identity_if_necessary(args)
-  if return_code > 0:
-    xpk_exit(return_code)
+  # Not required for Filestore. Will be uncommented when adding GCSFuse create
+  # return_code = update_cluster_with_workload_identity_if_necessary(args)
+  # if return_code > 0:
+  #   xpk_exit(return_code)
   return_code = update_cluster_with_gcpfilestore_driver_if_necessary(args)
   if return_code > 0:
     xpk_exit(return_code)
@@ -131,6 +140,7 @@ def storage_delete(args: Namespace) -> None:
 
 def storage_attach(args: Namespace) -> None:
   add_zone_and_project(args)
+  manifest = [{}]
   if args.type == GCP_FILESTORE_TYPE:
     if args.instance is None:
       args.instance = args.name
@@ -148,10 +158,14 @@ def storage_attach(args: Namespace) -> None:
     else:
       filestore_network = get_cluster_network(args)
     manifest = filestore_client.manifest(
-        args.name, args.vol, args.access_mode, filestore_network
+        args.name,
+        args.vol,
+        args.access_mode,
+        filestore_network,
+        args.mount_options,
     )
 
-  else:  # args.type == GCS_FUSE_TYPE:
+  elif args.type == GCS_FUSE_TYPE:
     if args.manifest is None and args.size is None:
       xpk_print("--size is required when attaching gcsfuse storage.")
       xpk_exit(1)
@@ -164,30 +178,61 @@ def storage_attach(args: Namespace) -> None:
       manifest = list(yaml.safe_load_all(f))
     else:
       manifest = gcsfuse.manifest(
-          name=args.name, bucket=args.bucket, size=args.size
+          args.name, args.bucket, args.size, args.mount_options
+      )
+
+  elif args.type in [PARALLELSTORE_TYPE, GCE_PD_TYPE]:
+    if args.manifest is None:
+      xpk_print(
+          "Parallelstore and PersistentDisk are currently supported only with"
+          " --manifest"
       )
+      xpk_exit(1)
+
+    with open(args.manifest, "r", encoding="utf-8") as f:
+      manifest = list(yaml.safe_load_all(f))
+
+  else:
+    xpk_print(f"Storage type {args.type} is not supported.")
+    xpk_exit(1)
 
   k8s_api_client = setup_k8s_env(args)
   create_storage_crds(k8s_api_client, args, manifest)
   create_volume_bundle_instance(
       k8s_api_client, args.name, manifest, args.readonly, args.mount_point
   )
-  return_code = update_cluster_with_workload_identity_if_necessary(args)
-  if return_code > 0:
-    xpk_exit(return_code)
-
-  # args.type can have only two values after parsing
-  return_code = (
-      update_cluster_with_gcsfuse_driver_if_necessary(args)
-      if args.type == GCS_FUSE_TYPE
-      else update_cluster_with_gcpfilestore_driver_if_necessary(args)
-  )
-  if return_code > 0:
-    xpk_exit(return_code)
+
+  enable_csi_drivers_if_necessary(args)
 
   apply_kubectl_manifest(k8s_api_client, manifest)
 
 
+def enable_csi_drivers_if_necessary(args: Namespace) -> None:
+  if args.type == GCS_FUSE_TYPE:
+    return_code = update_cluster_with_workload_identity_if_necessary(args)
+    if return_code > 0:
+      xpk_exit(return_code)
+
+    return_code = update_cluster_with_gcsfuse_driver_if_necessary(args)
+    if return_code > 0:
+      xpk_exit(return_code)
+
+  if args.type == GCP_FILESTORE_TYPE:
+    return_code = update_cluster_with_gcpfilestore_driver_if_necessary(args)
+    if return_code > 0:
+      xpk_exit(return_code)
+
+  if args.type == PARALLELSTORE_TYPE:
+    return_code = update_cluster_with_parallelstore_driver_if_necessary(args)
+    if return_code > 0:
+      xpk_exit(return_code)
+
+  if args.type == GCE_PD_TYPE:
+    return_code = update_cluster_with_pd_driver_if_necessary(args)
+    if return_code > 0:
+      xpk_exit(return_code)
+
+
 def storage_list(args: Namespace) -> None:
   k8s_api_client = setup_k8s_env(args)
   storages = list_storages(k8s_api_client)
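Note: the inline two-way driver update in storage_attach is replaced by enable_csi_drivers_if_necessary, which maps each of the four storage types onto its CSI-driver enabler, with GCS Fuse additionally requiring Workload Identity. A table-driven sketch of the same dispatch; the type keys are placeholder names, and the stub updaters stand in for the real update_cluster_with_*_if_necessary helpers, which return a non-zero code on failure:

    def _update_ok(args) -> int:  # stub updater; the real ones call gcloud
      return 0

    DRIVER_UPDATERS = {
        'gcsfuse': [_update_ok, _update_ok],  # workload identity + GcsFuseCsiDriver
        'gcpfilestore': [_update_ok],         # GcpFilestoreCsiDriver
        'parallelstore': [_update_ok],        # ParallelstoreCsiDriver
        'pd': [_update_ok],                   # GcePersistentDiskCsiDriver
    }

    def enable_csi_drivers_if_necessary(storage_type: str, args=None) -> None:
      for update in DRIVER_UPDATERS.get(storage_type, []):
        if update(args) > 0:
          raise SystemExit(1)  # the real code calls xpk_exit(return_code)

    enable_csi_drivers_if_necessary('gcsfuse')
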