xpk 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. xpk/commands/cluster.py +10 -11
  2. xpk/commands/cluster_gcluster.py +2 -1
  3. xpk/commands/common.py +3 -3
  4. xpk/commands/info.py +12 -12
  5. xpk/commands/job.py +12 -10
  6. xpk/commands/kjob_common.py +2 -1
  7. xpk/commands/storage.py +1 -1
  8. xpk/commands/workload.py +12 -6
  9. xpk/core/blueprint/blueprint_generator.py +7 -7
  10. xpk/core/blueprint/blueprint_test.py +218 -0
  11. xpk/core/capacity.py +3 -1
  12. xpk/core/cluster.py +9 -7
  13. xpk/core/cluster_private.py +5 -1
  14. xpk/core/commands.py +3 -3
  15. xpk/core/config.py +3 -4
  16. xpk/core/config_test.py +71 -0
  17. xpk/core/docker_manager.py +1 -1
  18. xpk/core/docker_resources.py +1 -1
  19. xpk/core/filestore.py +7 -2
  20. xpk/core/gcloud_context.py +2 -2
  21. xpk/core/kjob.py +2 -1
  22. xpk/core/kueue.py +6 -2
  23. xpk/core/nap.py +4 -4
  24. xpk/core/nodepool_test.py +82 -0
  25. xpk/core/resources.py +1 -7
  26. xpk/core/storage.py +14 -14
  27. xpk/core/system_characteristics.py +1 -1
  28. xpk/core/workload.py +11 -0
  29. xpk/core/workload_decorators/rdma_decorator.py +3 -2
  30. xpk/core/workload_decorators/storage_decorator.py +2 -1
  31. xpk/core/workload_decorators/tcpx_decorator.py +4 -2
  32. xpk/core/workload_decorators/tcpx_decorator_test.py +267 -0
  33. xpk/core/workload_decorators/tcpxo_decorator.py +2 -1
  34. xpk/core/workload_test.py +28 -0
  35. xpk/main.py +9 -10
  36. xpk/parser/cluster.py +67 -49
  37. xpk/parser/common.py +45 -36
  38. xpk/parser/storage.py +12 -13
  39. xpk/parser/workload.py +57 -39
  40. xpk/utils/console.py +2 -1
  41. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/METADATA +4 -1
  42. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/RECORD +46 -41
  43. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/WHEEL +0 -0
  44. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/entry_points.txt +0 -0
  45. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/licenses/LICENSE +0 -0
  46. {xpk-0.11.0.dist-info → xpk-0.12.0.dist-info}/top_level.txt +0 -0
xpk/core/workload_decorators/tcpx_decorator_test.py ADDED
@@ -0,0 +1,267 @@
+ """
+ Copyright 2024 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import copy
+
+ import yaml
+
+ from xpk.core.workload_decorators import tcpx_decorator
+ from xpk.utils.yaml import literal_string
+
+ # Minimal JobSet manifest for testing
+ BASE_JOBSET_MANIFEST_STR = """
+ apiVersion: jobset.x-k8s.io/v1alpha2
+ kind: JobSet
+ metadata:
+   name: test-jobset
+ spec:
+   replicatedJobs:
+   - name: slice-job
+     template:
+       spec:
+         template:
+           metadata:
+             annotations:
+               existing-annotation: "true"
+           spec:
+             containers:
+             - name: main-gpu-container
+               image: my-gpu-image
+               resources:
+                 limits:
+                   nvidia.com/gpu: 8
+             - name: sidecar-container
+               image: my-sidecar-image
+ """
+
+ # Minimal kjob template for testing
+ BASE_KJOB_TEMPLATE = {
+     "spec": {
+         "template": {
+             "spec": {
+                 "containers": [
+                     {
+                         "name": "main-gpu-container",
+                         "image": "my-gpu-image",
+                         "resources": {"limits": {"nvidia.com/gpu": 8}},
+                     },
+                     {"name": "sidecar-container", "image": "my-sidecar-image"},
+                 ]
+             }
+         }
+     }
+ }
+
+ # Minimal job manifest for testing
+ BASE_JOB_MANIFEST = {
+     "spec": {
+         "template": {
+             "metadata": {"annotations": {"existing-annotation": "true"}},
+             "spec": {
+                 "containers": [
+                     {
+                         "name": "main-gpu-container",
+                         "image": "my-gpu-image",
+                         "resources": {"limits": {"nvidia.com/gpu": 8}},
+                     },
+                     {"name": "sidecar-container", "image": "my-sidecar-image"},
+                 ]
+             },
+         }
+     }
+ }
+
+
+ def test_get_interfaces_annotation():
+   """Tests get_interfaces_annotation."""
+   annotation = tcpx_decorator.get_interfaces_annotation()
+   assert "networking.gke.io/interfaces" in annotation
+   assert isinstance(annotation["networking.gke.io/interfaces"], literal_string)
+   expected_value = (
+       "[\n"
+       ' {"interfaceName":"eth0","network":"default"},\n'
+       ' {"interfaceName":"eth1","network":"vpc1"},\n'
+       ' {"interfaceName":"eth2","network":"vpc2"},\n'
+       ' {"interfaceName":"eth3","network":"vpc3"},\n'
+       ' {"interfaceName":"eth4","network":"vpc4"}\n'
+       "]"
+   )
+   assert str(annotation["networking.gke.io/interfaces"]) == expected_value
+
+
+ def test_get_tcpx_deamon_annotation():
+   """Tests get_tcpx_deamon_annotation."""
+   annotation = tcpx_decorator.get_tcpx_deamon_annotation()
+   assert "devices.gke.io/container.tcpx-daemon" in annotation
+   assert isinstance(
+       annotation["devices.gke.io/container.tcpx-daemon"], literal_string
+   )
+   expected_value = (
+       "- path: /dev/nvidia0\n"
+       "- path: /dev/nvidia1\n"
+       "- path: /dev/nvidia2\n"
+       "- path: /dev/nvidia3\n"
+       "- path: /dev/nvidia4\n"
+       "- path: /dev/nvidia5\n"
+       "- path: /dev/nvidia6\n"
+       "- path: /dev/nvidia7\n"
+       "- path: /dev/nvidiactl\n"
+       "- path: /dev/nvidia-uvm\n"
+   )
+   assert (
+       str(annotation["devices.gke.io/container.tcpx-daemon"]) == expected_value
+   )
+
+
+ def test_decorate_jobset():
+   """Tests decorate_jobset."""
+   decorated_str = tcpx_decorator.decorate_jobset(BASE_JOBSET_MANIFEST_STR)
+   manifest = yaml.safe_load(decorated_str)
+
+   pod_template_spec = manifest["spec"]["replicatedJobs"][0]["template"]["spec"][
+       "template"
+   ]["spec"]
+   pod_template_metadata = manifest["spec"]["replicatedJobs"][0]["template"][
+       "spec"
+   ]["template"]["metadata"]
+
+   # Check annotations
+   annotations = pod_template_metadata["annotations"]
+   assert "existing-annotation" in annotations
+   assert "devices.gke.io/container.tcpx-daemon" in annotations
+   assert "networking.gke.io/default-interface" in annotations
+   assert "networking.gke.io/interfaces" in annotations
+
+   # Check tolerations
+   tolerations = pod_template_spec["tolerations"]
+   assert {
+       "key": "user-workload",
+       "operator": "Equal",
+       "value": "true",
+       "effect": "NoSchedule",
+   } in tolerations
+
+   # Check volumes
+   volumes = pod_template_spec["volumes"]
+   volume_names = {v["name"] for v in volumes}
+   assert "libraries" in volume_names
+   assert "sys" in volume_names
+   assert "proc-sys" in volume_names
+   assert "tcpx-socket" in volume_names
+   assert "dshm" in volume_names
+
+   # Check init container
+   init_containers = pod_template_spec["initContainers"]
+   assert len(init_containers) == 1
+   tcpx_daemon = init_containers[0]
+   assert tcpx_daemon["name"] == "tcpx-daemon"
+   assert tcpx_daemon["image"].endswith(f":{tcpx_decorator.tcpx}")
+
+   # Check GPU container update
+   gpu_container = pod_template_spec["containers"][0]
+   assert gpu_container["name"] == "main-gpu-container"
+
+   # Check env
+   env_vars = {e["name"]: e["value"] for e in gpu_container["env"]}
+   assert env_vars["LD_LIBRARY_PATH"] == "/usr/local/nvidia/lib64"
+
+   # Check volume mounts
+   volume_mounts = {
+       vm["name"]: vm["mountPath"] for vm in gpu_container["volumeMounts"]
+   }
+   assert volume_mounts["tcpx-socket"] == "/tmp"
+   assert volume_mounts["libraries"] == "/usr/local/nvidia/lib64"
+   assert volume_mounts["dshm"] == "/dev/shm"
+
+   # Check non-GPU container is not updated
+   sidecar_container = pod_template_spec["containers"][1]
+   assert "env" not in sidecar_container
+   assert "volumeMounts" not in sidecar_container
+
+
+ def test_decorate_job():
+   """Tests decorate_job."""
+   job_manifest = copy.deepcopy(BASE_JOB_MANIFEST)
+
+   decorated_manifest = tcpx_decorator.decorate_job(job_manifest)
+   pod_template_metadata = decorated_manifest["spec"]["template"]["metadata"]
+
+   # Check annotations
+   annotations = pod_template_metadata["annotations"]
+   assert "existing-annotation" in annotations
+   assert "devices.gke.io/container.tcpx-daemon" in annotations
+   assert "networking.gke.io/default-interface" in annotations
+   assert "networking.gke.io/interfaces" in annotations
+
+
+ def test_decorate_kjob_template():
+   """Tests decorate_kjob_template."""
+   kjob_template = copy.deepcopy(BASE_KJOB_TEMPLATE)
+
+   decorated_manifest = tcpx_decorator.decorate_kjob_template(kjob_template)
+
+   pod_template_spec = decorated_manifest["spec"]["template"]["spec"]
+
+   # Check annotations are NOT added
+   assert "annotations" not in decorated_manifest["spec"]["template"].get(
+       "metadata", {}
+   )
+
+   # Check tolerations
+   tolerations = pod_template_spec["tolerations"]
+   assert {
+       "key": "user-workload",
+       "operator": "Equal",
+       "value": "true",
+       "effect": "NoSchedule",
+   } in tolerations
+
+   # Check volumes
+   volumes = pod_template_spec["volumes"]
+   volume_names = {v["name"] for v in volumes}
+   assert "libraries" in volume_names
+   assert "sys" in volume_names
+   assert "proc-sys" in volume_names
+   assert "tcpx-socket" in volume_names
+   assert "dshm" in volume_names
+
+   # Check init container
+   init_containers = pod_template_spec["initContainers"]
+   assert len(init_containers) == 1
+   tcpx_daemon = init_containers[0]
+   assert tcpx_daemon["name"] == "tcpx-daemon"
+   assert tcpx_daemon["image"].endswith(f":{tcpx_decorator.tcpx}")
+
+   # Check GPU container update
+   gpu_container = pod_template_spec["containers"][0]
+   assert gpu_container["name"] == "main-gpu-container"
+
+   # Check env
+   env_vars = {e["name"]: e["value"] for e in gpu_container["env"]}
+   assert env_vars["LD_LIBRARY_PATH"] == "/usr/local/nvidia/lib64"
+
+   # Check volume mounts
+   volume_mounts = {
+       vm["name"]: vm["mountPath"] for vm in gpu_container["volumeMounts"]
+   }
+   assert volume_mounts["tcpx-socket"] == "/tmp"
+   assert volume_mounts["libraries"] == "/usr/local/nvidia/lib64"
+   assert volume_mounts["dshm"] == "/dev/shm"
+
+   # Check non-GPU container is not updated
+   sidecar_container = pod_template_spec["containers"][1]
+   assert "env" not in sidecar_container
+   assert "volumeMounts" not in sidecar_container
xpk/core/workload_decorators/tcpx_decorator.py CHANGED
@@ -74,7 +74,8 @@ def decorate_jobset(jobset_manifest_str: str, sub_networks: list[str]) -> str:
    for job in manifest['spec']['replicatedJobs']:
      job_manifest = job['template']
      job_manifest = decorate_job(job_manifest, sub_networks)
-   return yaml.dump(manifest, sort_keys=False)
+   yaml_result: str = yaml.dump(manifest, sort_keys=False)
+   return yaml_result


  def get_interfaces_entry(sub_networks: list[str]) -> tuple[str, str]:
xpk/core/workload_test.py ADDED
@@ -0,0 +1,28 @@
+ """
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from xpk.core.workload import get_jobsets_list_gcp_link
+
+
+ def test_get_jobsets_list_gcp_link():
+   result = get_jobsets_list_gcp_link(
+       project='test-project',
+   )
+
+   assert (
+       result
+       == 'https://console.cloud.google.com/kubernetes/aiml/deployments/jobs?project=test-project'
+   )
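The helper exercised by this test is added in xpk/core/workload.py (listed above as +11 -0 but not shown in this excerpt). Based purely on what the test asserts, its shape is presumably something like the following sketch; this is a reconstruction for illustration, not the actual xpk source.

    # Hypothetical reconstruction of the new helper, inferred only from the test above.
    def get_jobsets_list_gcp_link(project: str) -> str:
      """Returns the Cloud Console link that lists JobSets for the given project."""
      return (
          'https://console.cloud.google.com/kubernetes/aiml/deployments/jobs'
          f'?project={project}'
      )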
xpk/main.py CHANGED
@@ -56,18 +56,17 @@ if (
        f' User currently is running {user_major_version}.{user_minor_version}'
    )

- # Create top level parser for xpk command.
- parser = argparse.ArgumentParser(description='xpk command', prog='xpk')
- set_parser(parser=parser)
-
- xpk_print('Starting xpk', flush=True)
- validate_dependencies()
- main_args = parser.parse_args()
- main_args.enable_ray_cluster = False
- main_args.func(main_args)
-

  def main() -> None:
+   # Create top level parser for xpk command.
+   parser = argparse.ArgumentParser(description='xpk command', prog='xpk')
+   set_parser(parser=parser)
+
+   xpk_print('Starting xpk', flush=True)
+   validate_dependencies()
+   main_args = parser.parse_args()
+   main_args.enable_ray_cluster = False
+   main_args.func(main_args)
    xpk_print('XPK Done.', flush=True)


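With this change, building the parser, validating dependencies, and dispatching to the selected subcommand all happen inside main() rather than at module scope, so importing xpk.main no longer parses sys.argv as a side effect. A minimal sketch of how the entry point would now be driven, assuming the console-script entry point still targets xpk.main:main (the unchanged entry_points.txt above suggests the wiring did not move, but that is an assumption):

    # Sketch only; the entry-point wiring is inferred from the unchanged entry_points.txt.
    from xpk.main import main

    if __name__ == '__main__':
      main()  # builds the parser, runs validate_dependencies(), then calls args.func(args)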
xpk/parser/cluster.py CHANGED
@@ -29,7 +29,7 @@ from ..commands.cluster import (
  from ..commands.config import xpk_cfg
  from ..core.config import CFG_BUCKET_KEY
  from ..core.vertex import DEFAULT_VERTEX_TENSORBOARD_NAME
- from .common import add_shared_arguments
+ from .common import add_shared_arguments, ParserOrArgumentGroup
  from .validators import name_type


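The ParserOrArgumentGroup name imported here comes from xpk/parser/common.py, whose changes are only summarized in the file list above (+45 -36) and are not shown in this excerpt. Judging from how the helpers below accept either a parser or a group returned by add_argument_group(), it is presumably a type alias along these lines; this is a sketch, not the actual definition:

    # Hypothetical sketch of ParserOrArgumentGroup; the real definition lives in
    # xpk/parser/common.py and may differ.
    import argparse
    from typing import Union

    # ArgumentParser and the object returned by ArgumentParser.add_argument_group()
    # both expose add_argument(), so either can be passed to the add_*_arguments helpers.
    ParserOrArgumentGroup = Union[argparse.ArgumentParser, argparse._ArgumentGroup]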
@@ -208,6 +208,14 @@ def set_cluster_create_pathways_parser(
        cluster_create_pathways_optional_arguments
    )

+   autoprovisioning_arguments = (
+       cluster_create_pathways_parser.add_argument_group(
+           'Autoprovisioning Arguments',
+           'Optional arguments for enabling autoprovisioning.',
+       )
+   )
+   add_autoprovisioning_arguments(autoprovisioning_arguments)
+
    ### Capacity arguments specific to "cluster create-pathways"
    cluster_create_pathways_capacity_arguments = (
        cluster_create_pathways_parser.add_argument_group(
@@ -529,15 +537,15 @@ def set_cluster_adapt_parser(cluster_adapt_parser: ArgumentParser):
    cluster_adapt_parser.set_defaults(func=cluster_adapt)


- def add_autoprovisioning_arguments(parser: ArgumentParser):
-   parser.add_argument(
+ def add_autoprovisioning_arguments(parser_or_group: ParserOrArgumentGroup):
+   parser_or_group.add_argument(
        '--enable-autoprovisioning',
        action='store_true',
        help=(
            'Enable GKE features for autoprovisioning node pools in GKE clusters.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--autoprovisioning-min-chips',
        type=int,
        help=(
@@ -546,7 +554,7 @@ def add_autoprovisioning_arguments(parser: ArgumentParser):
            ' resources in the cluster as the minimum, and maximum.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--autoprovisioning-max-chips',
        type=int,
        help=(
@@ -557,13 +565,15 @@ def add_autoprovisioning_arguments(parser: ArgumentParser):
    )


- def add_shared_cluster_create_required_arguments(parser: ArgumentParser):
+ def add_shared_cluster_create_required_arguments(
+     parser_or_group: ParserOrArgumentGroup,
+ ):
    """Add shared required arguments in cluster create and Pathways cluster create.

    Args:
-     parser: cluster create argument parser or argument group
+     parser_or_group: cluster create argument parser or argument group
    """
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--cluster',
        type=name_type,
        default=None,
@@ -575,21 +585,23 @@ def add_shared_cluster_create_required_arguments(parser: ArgumentParser):
    )


- def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
+ def add_shared_cluster_create_optional_arguments(
+     parser_or_group: ParserOrArgumentGroup,
+ ):
    """Add shared optional arguments in cluster create and Pathways cluster create.

    Args:
-     parser: cluster create argument parser or argument group
+     parser_or_group: cluster create argument parser or argument group
    """
-   add_shared_arguments(parser)
-   parser.add_argument(
+   add_shared_arguments(parser_or_group)
+   parser_or_group.add_argument(
        '--host-maintenance-interval',
        type=str,
        choices=['AS_NEEDED', 'PERIODIC'],
        default='AS_NEEDED',
        help='The maintenance policy of the cluster and respective clusters.',
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--gke-version',
        type=str,
        help=(
@@ -598,20 +610,20 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            ' recommended version.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--num-slices',
        type=int,
        default=1,
        help='The number of slices to run the job on, defaults to 1.',
        required=False,
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--pathways-gce-machine-type',
        type=str,
        default='n2-standard-64',
        help='The CPU type for Pathways CPU nodepools',
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--default-pool-cpu-machine-type',
        type=str,
        default='e2-standard-16',
@@ -620,7 +632,7 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            ' regional clusters, all zones must support the machine type.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--cluster-cpu-machine-type',
        type=str,
        default='',
@@ -631,7 +643,7 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            ' cpu nodepools using --device-type.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--default-pool-cpu-num-nodes',
        type=int,
        default=6,
@@ -641,7 +653,7 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            ' over time.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--custom-cluster-arguments',
        type=str,
        default='',
@@ -652,7 +664,7 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            " --custom-cluster-arguments='--network=mtu9k --subnetwork=mtu9k'"
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--custom-nodepool-arguments',
        type=str,
        default='',
@@ -663,7 +675,7 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            ' --custom-nodepool-arguments="--disk-size=300"'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--force',
        action='store_true',
        help=(
@@ -671,7 +683,7 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            ' additional approval.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--custom-tpu-nodepool-arguments',
        type=str,
        default='',
@@ -682,7 +694,7 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            ' --custom-tpu-nodepool-arguments="--enable-ip-alias"'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--private',
        action='store_true',
        help=(
@@ -695,7 +707,7 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            ' clusters.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--authorized-networks',
        action='extend',
        nargs='+',
@@ -710,16 +722,16 @@ def add_shared_cluster_create_optional_arguments(parser: ArgumentParser):
            ' Example usage: --authorized-networks 1.2.3.0/24 1.2.4.5/32'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--enable-workload-identity',
        action='store_true',
        help='Enable Workload Identity Federation on the cluster and node-pools.',
    )
-   add_driver_arguments(parser)
+   add_driver_arguments(parser_or_group)


- def add_driver_arguments(parser: ArgumentParser):
-   parser.add_argument(
+ def add_driver_arguments(parser_or_group: ParserOrArgumentGroup):
+   parser_or_group.add_argument(
        '--enable-gcsfuse-csi-driver',
        action='store_true',
        help=(
@@ -728,42 +740,44 @@ def add_driver_arguments(parser: ArgumentParser):
            ' Identity is enabled by default.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--enable-gcpfilestore-csi-driver',
        action='store_true',
        help='Enable GCPFilestore driver on the cluster.',
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--enable-parallelstore-csi-driver',
        action='store_true',
        help='Enable Parallelstore CSI driver on the cluster.',
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--enable-pd-csi-driver',
        action='store_true',
        help='Enable PersistentDisk CSI driver on the cluster.',
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--enable-lustre-csi-driver',
        action='store_true',
        help='Enable Lustre CSI driver on the cluster.',
    )


- def add_shared_cluster_create_tensorboard_arguments(parser: ArgumentParser):
+ def add_shared_cluster_create_tensorboard_arguments(
+     parser_or_group: ParserOrArgumentGroup,
+ ):
    """Add shared tensorboard arguments in cluster create and Pathways cluster create.
    Note that this feature enables non-Pathways workloads to use tensorboard arguments
    on a Pathways cluster.

    Args:
-     parser: cluster create argument parser or argument group
+     parser_or_group: cluster create argument parser or argument group
    """
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--create-vertex-tensorboard',
        action='store_true',
        help='Set this flag to create a Tensorboard instance in Vertex AI.',
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--tensorboard-region',
        type=str,
        default='us-central1',
@@ -774,7 +788,7 @@ def add_shared_cluster_create_tensorboard_arguments(parser: ArgumentParser):
            ' instance will be created in us-central1.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--tensorboard-name',
        type=str,
        required=False,
@@ -787,13 +801,15 @@ def add_shared_cluster_create_tensorboard_arguments(parser: ArgumentParser):
    )


- def add_shared_cluster_create_capacity_arguments(parser: ArgumentParser):
+ def add_shared_cluster_create_capacity_arguments(
+     parser_or_group: ParserOrArgumentGroup,
+ ):
    """Add shared capacity arguments in cluster create and Pathways cluster create.

    Args:
-     parser: cluster create argument parser or argument group
+     parser_or_group: cluster create argument parser or argument group
    """
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--on-demand',
        action='store_true',
        help=(
@@ -802,7 +818,7 @@ def add_shared_cluster_create_capacity_arguments(parser: ArgumentParser):
            ' types.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--reservation',
        type=str,
        help=(
@@ -811,7 +827,7 @@ def add_shared_cluster_create_capacity_arguments(parser: ArgumentParser):
            ' `--flex` or `--on-demand` for other capacity types.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--spot',
        action='store_true',
        help=(
@@ -820,7 +836,7 @@ def add_shared_cluster_create_capacity_arguments(parser: ArgumentParser):
            ' capacity types.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--flex',
        action='store_true',
        help=(
@@ -831,18 +847,20 @@ def add_shared_cluster_create_capacity_arguments(parser: ArgumentParser):
    )


- def add_shared_cluster_create_mtc_arguments(parser: ArgumentParser):
+ def add_shared_cluster_create_mtc_arguments(
+     parser_or_group: ParserOrArgumentGroup,
+ ):
    """Add shared Multi-tier Checkpointing arguments in cluster create and Pathways cluster create.

    Args:
-     List of cluster create MTC arguments parsers
+     List of cluster create MTC arguments parsers or group
    """
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--enable-mtc',
        action='store_true',
        help='Enable MTC on the cluster.',
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--mtc-ramdisk-size',
        type=str,
        default=None,
@@ -851,7 +869,7 @@ def add_shared_cluster_create_mtc_arguments(parser: ArgumentParser):
            ' used for multi-tier checkpointing. e.g. "64Mi" '
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--mtc-gcs-bucket',
        type=str,
        default=None,
@@ -860,7 +878,7 @@ def add_shared_cluster_create_mtc_arguments(parser: ArgumentParser):
            ' multi-tier checkpointing.'
        ),
    )
-   parser.add_argument(
+   parser_or_group.add_argument(
        '--mtc-toleration-key',
        type=str,
        default=None,