xpk 0.14.4__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. integration/README.md +19 -0
  2. integration/gcluster_a3mega_test.py +11 -0
  3. integration/gcluster_a3ultra_test.py +11 -0
  4. integration/gcluster_a4_test.py +11 -0
  5. xpk/blueprints/a3mega/config-map.yaml.tftpl +15 -0
  6. xpk/blueprints/a3mega/storage_crd.yaml +52 -0
  7. xpk/blueprints/a3ultra/config-map.yaml.tftpl +15 -0
  8. xpk/blueprints/a3ultra/mlgru-disable.yaml +59 -0
  9. xpk/blueprints/a3ultra/nccl-installer.yaml +95 -0
  10. xpk/blueprints/a3ultra/storage_crd.yaml +52 -0
  11. xpk/blueprints/a4/config-map.yaml.tftpl +15 -0
  12. xpk/blueprints/a4/nccl-rdma-installer-a4.yaml +66 -0
  13. xpk/blueprints/a4/storage_crd.yaml +52 -0
  14. xpk/commands/cluster.py +89 -32
  15. xpk/commands/cluster_gcluster.py +25 -5
  16. xpk/commands/cluster_gcluster_test.py +16 -3
  17. xpk/commands/cluster_test.py +353 -7
  18. xpk/commands/config.py +3 -5
  19. xpk/commands/inspector.py +5 -3
  20. xpk/commands/kind.py +3 -1
  21. xpk/commands/managed_ml_diagnostics.py +249 -0
  22. xpk/commands/managed_ml_diagnostics_test.py +146 -0
  23. xpk/commands/storage.py +8 -10
  24. xpk/commands/workload.py +143 -142
  25. xpk/commands/workload_test.py +160 -118
  26. xpk/core/blueprint/blueprint_generator.py +73 -33
  27. xpk/core/blueprint/blueprint_test.py +9 -0
  28. xpk/core/blueprint/testing/data/a3_mega.yaml +129 -0
  29. xpk/core/blueprint/testing/data/a3_mega_spot.yaml +125 -0
  30. xpk/core/blueprint/testing/data/a3_ultra.yaml +173 -0
  31. xpk/core/blueprint/testing/data/a4.yaml +185 -0
  32. xpk/core/capacity.py +48 -8
  33. xpk/core/capacity_test.py +32 -1
  34. xpk/core/cluster.py +55 -104
  35. xpk/core/cluster_test.py +170 -0
  36. xpk/core/commands.py +4 -10
  37. xpk/core/config.py +88 -7
  38. xpk/core/config_test.py +67 -11
  39. xpk/core/docker_container.py +3 -1
  40. xpk/core/docker_image.py +10 -6
  41. xpk/core/docker_resources.py +1 -10
  42. xpk/core/gcloud_context.py +18 -12
  43. xpk/core/gcloud_context_test.py +111 -1
  44. xpk/core/kjob.py +17 -19
  45. xpk/core/kueue_manager.py +205 -51
  46. xpk/core/kueue_manager_test.py +158 -4
  47. xpk/core/nap.py +13 -14
  48. xpk/core/nodepool.py +37 -43
  49. xpk/core/nodepool_test.py +42 -19
  50. xpk/core/pathways.py +23 -0
  51. xpk/core/pathways_test.py +57 -0
  52. xpk/core/resources.py +84 -27
  53. xpk/core/scheduling.py +144 -133
  54. xpk/core/scheduling_test.py +298 -6
  55. xpk/core/system_characteristics.py +256 -19
  56. xpk/core/system_characteristics_test.py +128 -5
  57. xpk/core/telemetry.py +263 -0
  58. xpk/core/telemetry_test.py +211 -0
  59. xpk/core/vertex.py +4 -3
  60. xpk/core/workload_decorators/tcpx_decorator.py +5 -1
  61. xpk/main.py +33 -13
  62. xpk/parser/cluster.py +40 -67
  63. xpk/parser/cluster_test.py +83 -3
  64. xpk/parser/common.py +84 -0
  65. xpk/parser/storage.py +10 -0
  66. xpk/parser/storage_test.py +47 -0
  67. xpk/parser/workload.py +14 -29
  68. xpk/parser/workload_test.py +3 -49
  69. xpk/telemetry_uploader.py +29 -0
  70. xpk/templates/arm_gpu_workload_crate.yaml.j2 +46 -0
  71. xpk/templates/kueue_gke_default_topology.yaml.j2 +1 -1
  72. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +3 -8
  73. xpk/utils/console.py +41 -10
  74. xpk/utils/console_test.py +106 -0
  75. xpk/utils/feature_flags.py +10 -1
  76. xpk/utils/file.py +4 -1
  77. xpk/utils/topology.py +4 -0
  78. xpk/utils/user_agent.py +35 -0
  79. xpk/utils/user_agent_test.py +44 -0
  80. xpk/utils/user_input.py +48 -0
  81. xpk/utils/user_input_test.py +92 -0
  82. xpk/utils/validation.py +2 -13
  83. xpk/utils/versions.py +31 -0
  84. xpk-0.16.0.dist-info/METADATA +127 -0
  85. xpk-0.16.0.dist-info/RECORD +168 -0
  86. xpk-0.14.4.dist-info/METADATA +0 -1645
  87. xpk-0.14.4.dist-info/RECORD +0 -139
  88. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/WHEEL +0 -0
  89. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/entry_points.txt +0 -0
  90. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/licenses/LICENSE +0 -0
  91. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/top_level.txt +0 -0
xpk/main.py CHANGED
@@ -32,11 +32,14 @@ Next Steps:
32
32
  """
33
33
 
34
34
  import argparse
35
+ import argcomplete
35
36
  import sys
36
37
 
37
38
  from .parser.core import set_parser
38
39
  from .core.updates import print_xpk_hello
39
- from .utils.console import xpk_print
40
+ from .core.config import set_config, FileSystemConfig
41
+ from .core.telemetry import MetricsCollector, send_clearcut_payload, should_send_telemetry
42
+ from .utils.console import xpk_print, exit_code_to_int
40
43
  from .utils.execution_context import set_context
41
44
  ################### Compatibility Check ###################
42
45
  # Check that the user runs the below version or greater.
@@ -59,19 +62,36 @@ if (
59
62
 
60
63
 
61
64
  def main() -> None:
62
- # Create top level parser for xpk command.
63
- parser = argparse.ArgumentParser(description='xpk command', prog='xpk')
64
- set_parser(parser=parser)
65
+ try:
66
+ # Create top level parser for xpk command.
67
+ parser = argparse.ArgumentParser(description='xpk command', prog='xpk')
68
+ set_parser(parser=parser)
69
+ argcomplete.autocomplete(parser)
65
70
 
66
- main_args = parser.parse_args()
67
- main_args.enable_ray_cluster = False
68
- set_context(
69
- dry_run_value='dry_run' in main_args and main_args.dry_run,
70
- quiet_value='quiet' in main_args and main_args.quiet,
71
- )
72
- print_xpk_hello()
73
- main_args.func(main_args)
74
- xpk_print('XPK Done.', flush=True)
71
+ main_args = parser.parse_args()
72
+ main_args.enable_ray_cluster = False
73
+ set_config(FileSystemConfig())
74
+ set_context(
75
+ dry_run_value='dry_run' in main_args and main_args.dry_run,
76
+ quiet_value=(
77
+ ('quiet' in main_args and main_args.quiet)
78
+ or ('force' in main_args and main_args.force)
79
+ ),
80
+ )
81
+ MetricsCollector.log_start(main_args.xpk_subcommands)
82
+ print_xpk_hello()
83
+ main_args.func(main_args)
84
+ xpk_print('XPK Done.', flush=True)
85
+ MetricsCollector.log_complete(0)
86
+ except SystemExit as e:
87
+ MetricsCollector.log_complete(exit_code_to_int(e.code))
88
+ raise
89
+ except:
90
+ MetricsCollector.log_complete(-1)
91
+ raise
92
+ finally:
93
+ if should_send_telemetry():
94
+ send_clearcut_payload(MetricsCollector.flush())
75
95
 
76
96
 
77
97
  if __name__ == '__main__':
xpk/parser/cluster.py CHANGED
@@ -26,10 +26,10 @@ from ..commands.cluster import (
26
26
  cluster_describe,
27
27
  cluster_list,
28
28
  )
29
- from ..commands.config import xpk_cfg
29
+ from ..core.config import get_config
30
30
  from ..core.config import CFG_BUCKET_KEY
31
31
  from ..core.vertex import DEFAULT_VERTEX_TENSORBOARD_NAME
32
- from .common import add_shared_arguments, ParserOrArgumentGroup
32
+ from .common import add_shared_arguments, ParserOrArgumentGroup, add_tpu_type_argument, add_tpu_and_device_type_arguments
33
33
  from .validators import name_type
34
34
  from ..utils.feature_flags import FeatureFlags
35
35
 
@@ -98,21 +98,7 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser):
98
98
  required=True
99
99
  )
100
100
  )
101
- cluster_device_group.add_argument(
102
- '--tpu-type',
103
- type=str,
104
- default=None,
105
- help='The tpu type to use, v5litepod-16, etc.',
106
- )
107
- cluster_device_group.add_argument(
108
- '--device-type',
109
- type=str,
110
- default=None,
111
- help=(
112
- 'The device type to use (can be tpu or gpu or cpu), v5litepod-16,'
113
- ' h100-80gb-8, n2-standard-32-4 etc.'
114
- ),
115
- )
101
+ add_tpu_and_device_type_arguments(cluster_device_group)
116
102
 
117
103
  ### Optional arguments specific to "cluster create"
118
104
  cluster_create_optional_arguments = cluster_create_parser.add_argument_group(
@@ -124,7 +110,7 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser):
124
110
  cluster_create_optional_arguments.add_argument(
125
111
  '--cluster-state-gcs-bucket',
126
112
  type=str,
127
- default=xpk_cfg.get(CFG_BUCKET_KEY),
113
+ default=get_config().get(CFG_BUCKET_KEY),
128
114
  help='The name of the bucket to store cluster state.',
129
115
  required=False,
130
116
  )
@@ -143,12 +129,9 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser):
143
129
  ' enable cluster to accept Pathways workloads.'
144
130
  ),
145
131
  )
132
+
146
133
  if FeatureFlags.SUB_SLICING_ENABLED:
147
- cluster_create_optional_arguments.add_argument(
148
- '--sub-slicing',
149
- action='store_true',
150
- help='Whether to set up cluster to support sub-slicing',
151
- )
134
+ add_cluster_create_sub_slicing_arguments(cluster_create_optional_arguments)
152
135
 
153
136
  autoprovisioning_arguments = cluster_create_parser.add_argument_group(
154
137
  'Autoprovisioning Arguments',
@@ -204,11 +187,8 @@ def set_cluster_create_pathways_parser(
204
187
  add_shared_cluster_create_required_arguments(
205
188
  cluster_create_pathways_required_arguments
206
189
  )
207
- cluster_create_pathways_required_arguments.add_argument(
208
- '--tpu-type',
209
- type=str,
210
- default=None,
211
- help='The tpu type to use, v5litepod-16, etc.',
190
+ add_tpu_type_argument(
191
+ cluster_create_pathways_required_arguments, required=True
212
192
  )
213
193
 
214
194
  ### Optional arguments specific to "cluster create-pathways"
@@ -221,6 +201,10 @@ def set_cluster_create_pathways_parser(
221
201
  add_shared_cluster_create_optional_arguments(
222
202
  cluster_create_pathways_optional_arguments
223
203
  )
204
+ if FeatureFlags.SUB_SLICING_ENABLED:
205
+ add_cluster_create_sub_slicing_arguments(
206
+ cluster_create_pathways_optional_arguments
207
+ )
224
208
 
225
209
  autoprovisioning_arguments = (
226
210
  cluster_create_pathways_parser.add_argument_group(
@@ -281,13 +265,8 @@ def set_cluster_create_ray_parser(cluster_create_ray_parser: ArgumentParser):
281
265
  add_shared_cluster_create_required_arguments(
282
266
  cluster_create_ray_required_arguments
283
267
  )
284
- cluster_create_ray_required_arguments.add_argument(
285
- '--tpu-type',
286
- type=str,
287
- default=None,
288
- help='The tpu type to use, v5litepod-16, etc.',
289
- required=True,
290
- )
268
+ add_tpu_type_argument(cluster_create_ray_required_arguments, required=True)
269
+
291
270
  # TODO(bzmarke): Add --device-type to support GPU/CPU
292
271
  cluster_create_ray_required_arguments.add_argument(
293
272
  '--ray-version',
@@ -350,7 +329,9 @@ def set_cluster_create_ray_parser(cluster_create_ray_parser: ArgumentParser):
350
329
  )
351
330
  add_resource_limits(cluster_create_resource_limits)
352
331
 
353
- cluster_create_ray_parser.set_defaults(func=cluster_create_ray_cluster)
332
+ cluster_create_ray_parser.set_defaults(
333
+ func=cluster_create_ray_cluster, sub_slicing=False
334
+ )
354
335
 
355
336
 
356
337
  def set_cluster_delete_parser(cluster_delete_parser: ArgumentParser):
@@ -375,7 +356,7 @@ def set_cluster_delete_parser(cluster_delete_parser: ArgumentParser):
375
356
  cluster_delete_optional_arguments.add_argument(
376
357
  '--cluster-state-gcs-bucket',
377
358
  type=str,
378
- default=xpk_cfg.get(CFG_BUCKET_KEY),
359
+ default=get_config().get(CFG_BUCKET_KEY),
379
360
  help='The name of the bucket to store cluster state.',
380
361
  required=False,
381
362
  )
@@ -404,21 +385,7 @@ def set_cluster_cacheimage_parser(cluster_cacheimage_parser: ArgumentParser):
404
385
  )
405
386
 
406
387
  ### Device Type Argument
407
- cluster_cacheimage_group.add_argument(
408
- '--tpu-type',
409
- type=str,
410
- default=None,
411
- help='The tpu type to cache images on, v5litepod-16, etc.',
412
- )
413
- cluster_cacheimage_group.add_argument(
414
- '--device-type',
415
- type=str,
416
- default=None,
417
- help=(
418
- 'The device type to cache images on (can be tpu or gpu),'
419
- ' v5litepod-16, h100-80gb-8, etc.'
420
- ),
421
- )
388
+ add_tpu_and_device_type_arguments(cluster_cacheimage_group)
422
389
 
423
390
  ### Required arguments
424
391
  cluster_cacheimage_required_arguments.add_argument(
@@ -503,21 +470,7 @@ def set_cluster_adapt_parser(cluster_adapt_parser: ArgumentParser):
503
470
  required=True
504
471
  )
505
472
  )
506
- cluster_adapt_device_group.add_argument(
507
- '--tpu-type',
508
- type=str,
509
- default=None,
510
- help='The tpu type used on cluster, v5litepod-16, etc.',
511
- )
512
- cluster_adapt_device_group.add_argument(
513
- '--device-type',
514
- type=str,
515
- default=None,
516
- help=(
517
- 'The device type used on cluster (can be tpu or gpu or cpu), eg.'
518
- ' h100-80gb-8, n2-standard-32-4 etc.'
519
- ),
520
- )
473
+ add_tpu_and_device_type_arguments(cluster_adapt_device_group)
521
474
 
522
475
  cluster_adapt_optional_arguments = cluster_adapt_parser.add_argument_group(
523
476
  'Optional Arguments',
@@ -662,6 +615,11 @@ def add_shared_cluster_create_optional_arguments(
662
615
  ' regional clusters, all zones must support the machine type.'
663
616
  ),
664
617
  )
618
+ parser_or_group.add_argument(
619
+ '--managed-mldiagnostics',
620
+ action='store_true',
621
+ help='Enables the installation of required ML Diagnostics components.',
622
+ )
665
623
  parser_or_group.add_argument(
666
624
  '--cluster-cpu-machine-type',
667
625
  type=str,
@@ -790,6 +748,11 @@ def add_driver_arguments(parser_or_group: ParserOrArgumentGroup):
790
748
  action='store_true',
791
749
  help='Enable Lustre CSI driver on the cluster.',
792
750
  )
751
+ parser_or_group.add_argument(
752
+ '--enable-legacy-lustre-port',
753
+ action='store_true',
754
+ help='Enable legacy port for Lustre CSI driver on the cluster.',
755
+ )
793
756
 
794
757
 
795
758
  def add_shared_cluster_create_tensorboard_arguments(
@@ -937,3 +900,13 @@ def add_resource_limits(parser_or_group: ParserOrArgumentGroup):
937
900
  default=None,
938
901
  help='The CPU limit for the Kueue controller manager.',
939
902
  )
903
+
904
+
905
+ def add_cluster_create_sub_slicing_arguments(
906
+ parser_or_group: ParserOrArgumentGroup,
907
+ ):
908
+ parser_or_group.add_argument(
909
+ '--sub-slicing',
910
+ action='store_true',
911
+ help='Whether to set up cluster to support sub-slicing',
912
+ )
xpk/parser/cluster_test.py CHANGED
@@ -15,7 +15,7 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import argparse
18
- from xpk.parser.cluster import set_cluster_create_parser
18
+ from xpk.parser.cluster import set_cluster_create_parser, set_cluster_create_pathways_parser, set_cluster_create_ray_parser
19
19
  import pytest
20
20
  from ..utils.feature_flags import FeatureFlags
21
21
 
@@ -49,7 +49,7 @@ def test_cluster_create_sub_slicing_is_false_by_default():
49
49
 
50
50
  set_cluster_create_parser(parser)
51
51
  args = parser.parse_args(
52
- ["--cluster", "test-cluster", "--tpu-type", "test-tpu"]
52
+ ["--cluster", "test-cluster", "--tpu-type", "tpu7x-2"]
53
53
  )
54
54
 
55
55
  assert args.sub_slicing is False
@@ -60,7 +60,87 @@ def test_cluster_create_sub_slicing_can_be_set():
60
60
 
61
61
  set_cluster_create_parser(parser)
62
62
  args = parser.parse_args(
63
- ["--cluster", "test-cluster", "--tpu-type", "test-tpu", "--sub-slicing"]
63
+ ["--cluster", "test-cluster", "--tpu-type", "tpu7x-2", "--sub-slicing"]
64
64
  )
65
65
 
66
66
  assert args.sub_slicing is True
67
+
68
+
69
+ def test_cluster_create_pathways_sub_slicing_is_hidden_with_flag_off():
70
+ FeatureFlags.SUB_SLICING_ENABLED = False
71
+ parser = argparse.ArgumentParser()
72
+
73
+ set_cluster_create_pathways_parser(parser)
74
+ help_str = parser.format_help()
75
+
76
+ assert "--sub-slicing" not in help_str
77
+
78
+
79
+ def test_cluster_create_pathways_sub_slicing_can_be_set():
80
+ parser = argparse.ArgumentParser()
81
+
82
+ set_cluster_create_pathways_parser(parser)
83
+ args = parser.parse_args(
84
+ ["--cluster", "test-cluster", "--tpu-type", "tpu7x-2", "--sub-slicing"]
85
+ )
86
+
87
+ assert args.sub_slicing is True
88
+
89
+
90
+ def test_cluster_create_ray_sub_slicing_is_hidden_but_set_to_false():
91
+ parser = argparse.ArgumentParser()
92
+
93
+ set_cluster_create_ray_parser(parser)
94
+ args = parser.parse_args([
95
+ "--cluster",
96
+ "test-cluster",
97
+ "--tpu-type",
98
+ "tpu7x-2",
99
+ "--ray-version",
100
+ "19.32.0",
101
+ ])
102
+ help_str = parser.format_help()
103
+
104
+ assert args.sub_slicing is False
105
+ assert "--sub-slicing" not in help_str
106
+
107
+
108
+ def test_cluster_create_managed_mldiagnostics():
109
+ parser = argparse.ArgumentParser()
110
+
111
+ set_cluster_create_parser(parser)
112
+ args = parser.parse_args([
113
+ "--cluster",
114
+ "test-cluster",
115
+ "--tpu-type",
116
+ "v5p-8",
117
+ "--managed-mldiagnostics",
118
+ ])
119
+
120
+ assert args.managed_mldiagnostics is True
121
+
122
+
123
+ def test_cluster_create_enable_lustre_legacy_port_is_false_by_default():
124
+ parser = argparse.ArgumentParser()
125
+
126
+ set_cluster_create_parser(parser)
127
+ args = parser.parse_args(
128
+ ["--cluster", "test-cluster", "--tpu-type", "tpu7x-2"]
129
+ )
130
+
131
+ assert args.enable_legacy_lustre_port is False
132
+
133
+
134
+ def test_cluster_create_enable_lustre_legacy_port_can_be_set():
135
+ parser = argparse.ArgumentParser()
136
+
137
+ set_cluster_create_parser(parser)
138
+ args = parser.parse_args([
139
+ "--cluster",
140
+ "test-cluster",
141
+ "--tpu-type",
142
+ "tpu7x-2",
143
+ "--enable-legacy-lustre-port",
144
+ ])
145
+
146
+ assert args.enable_legacy_lustre_port is True
xpk/parser/common.py CHANGED
@@ -16,6 +16,10 @@ limitations under the License.
16
16
 
17
17
  import argparse
18
18
  from typing import Protocol, Any
19
+ from ..core.system_characteristics import get_system_characteristics_keys_by_accelerator_type, AcceleratorType
20
+ import difflib
21
+ from argcomplete import ChoicesCompleter
22
+ from argparse import Action, ArgumentError
19
23
 
20
24
 
21
25
  class ParserOrArgumentGroup(Protocol):
@@ -24,6 +28,46 @@ class ParserOrArgumentGroup(Protocol):
24
28
  ...
25
29
 
26
30
 
31
+ class ManyChoicesAction(Action):
32
+ """An action class to output better error message for arguments with large lists of choices."""
33
+
34
+ def __init__(self, *args, large_choice_list, **kwargs):
35
+ self.large_list_of_choices = large_choice_list
36
+ super().__init__(*args, **kwargs)
37
+
38
+ def __call__(self, parser, namespace, value, option_string=None):
39
+ if value not in self.large_list_of_choices:
40
+ close_matches = difflib.get_close_matches(
41
+ value, self.large_list_of_choices, n=5, cutoff=0
42
+ )
43
+ msg = (
44
+ f"invalid choice: '{value}' (closest matches:"
45
+ f" {', '.join(close_matches)})"
46
+ )
47
+ raise ArgumentError(self, msg)
48
+ setattr(namespace, self.dest, value)
49
+
50
+
51
+ def add_many_choices_argument(
52
+ parserOrGroup: ParserOrArgumentGroup,
53
+ flag_name,
54
+ choices: list[str],
55
+ metavar: str,
56
+ help_msg: str,
57
+ required: bool = False,
58
+ ) -> None:
59
+ parserOrGroup.add_argument(
60
+ flag_name,
61
+ action=ManyChoicesAction,
62
+ large_choice_list=choices,
63
+ type=str,
64
+ metavar=metavar,
65
+ help=help_msg,
66
+ required=required,
67
+ default=None,
68
+ ).completer = ChoicesCompleter(choices)
69
+
70
+
27
71
  def add_shared_arguments(
28
72
  custom_parser_or_group: ParserOrArgumentGroup, required=False
29
73
  ) -> None:
@@ -285,3 +329,43 @@ def add_slurm_arguments(custom_parser_or_group: ParserOrArgumentGroup):
285
329
  ' `very-high`. Defaults to `medium`.'
286
330
  ),
287
331
  )
332
+
333
+
334
+ def add_tpu_type_argument(
335
+ custom_parser_or_group: ParserOrArgumentGroup,
336
+ required: bool = False,
337
+ ) -> None:
338
+ add_many_choices_argument(
339
+ custom_parser_or_group,
340
+ '--tpu-type',
341
+ choices=get_system_characteristics_keys_by_accelerator_type(
342
+ [AcceleratorType.TPU]
343
+ ),
344
+ metavar='TPU_TYPE',
345
+ help_msg='The tpu type to use, v5litepod-16, etc.',
346
+ required=required,
347
+ )
348
+
349
+
350
+ def add_device_type_argument(
351
+ custom_parser_or_group: ParserOrArgumentGroup,
352
+ required: bool = False,
353
+ ) -> None:
354
+ add_many_choices_argument(
355
+ custom_parser_or_group,
356
+ '--device-type',
357
+ choices=get_system_characteristics_keys_by_accelerator_type(),
358
+ metavar='DEVICE_TYPE',
359
+ help_msg=(
360
+ 'The device type to use (can be tpu or gpu or cpu), v5litepod-16,'
361
+ ' h100-80gb-8, n2-standard-32-4 etc.'
362
+ ),
363
+ required=required,
364
+ )
365
+
366
+
367
+ def add_tpu_and_device_type_arguments(
368
+ custom_parser_or_group: ParserOrArgumentGroup,
369
+ ) -> None:
370
+ add_tpu_type_argument(custom_parser_or_group)
371
+ add_device_type_argument(custom_parser_or_group)
xpk/parser/storage.py CHANGED
@@ -104,6 +104,16 @@ def add_storage_attach_parser(
104
104
  help='If true workloads can only read from storage',
105
105
  )
106
106
 
107
+ lustre_args = storage_attach_parser.add_argument_group(
108
+ 'Lustre arguments',
109
+ 'Arguments used when --type=lustre',
110
+ )
111
+ lustre_args.add_argument(
112
+ '--enable-legacy-lustre-port',
113
+ action='store_true',
114
+ help='Enable legacy port for Lustre CSI driver on the cluster.',
115
+ )
116
+
107
117
  gcsfuse_args = storage_attach_parser.add_argument_group(
108
118
  'FUSE arguments',
109
119
  'Arguments used when --type=gcsfuse',
xpk/parser/storage_test.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import argparse
18
+ from xpk.parser.storage import set_storage_parser
19
+
20
+ DEFAULT_ATTACH_ARGUMENTS = (
21
+ "attach test-storage --cluster test-cluster --zone test-zone"
22
+ " --project test-project --mount-point test-mount-point"
23
+ " --readonly false --auto-mount true"
24
+ )
25
+
26
+ DEFAULT_LUSTRE_ATTACH_ARGUMENTS = (
27
+ DEFAULT_ATTACH_ARGUMENTS + " --type lustre --manifest test-manifest"
28
+ )
29
+
30
+
31
+ def test_cluster_create_enable_lustre_legacy_port_is_false_by_default():
32
+ parser = argparse.ArgumentParser()
33
+
34
+ set_storage_parser(parser)
35
+ args = parser.parse_args(DEFAULT_LUSTRE_ATTACH_ARGUMENTS.split())
36
+
37
+ assert args.enable_legacy_lustre_port is False
38
+
39
+
40
+ def test_cluster_create_enable_lustre_legacy_port_can_be_set():
41
+ parser = argparse.ArgumentParser()
42
+ set_storage_parser(parser)
43
+ args = parser.parse_args(
44
+ DEFAULT_LUSTRE_ATTACH_ARGUMENTS.split() + ["--enable-legacy-lustre-port"]
45
+ )
46
+
47
+ assert args.enable_legacy_lustre_port is True
xpk/parser/workload.py CHANGED
@@ -22,9 +22,8 @@ from ..commands.workload import (
22
22
  workload_list,
23
23
  )
24
24
  from ..core.docker_image import DEFAULT_DOCKER_IMAGE, DEFAULT_SCRIPT_DIR
25
- from .common import add_shared_arguments
25
+ from .common import add_shared_arguments, add_tpu_type_argument, add_tpu_and_device_type_arguments
26
26
  from .validators import directory_path_type, name_type
27
- from ..utils.feature_flags import FeatureFlags
28
27
 
29
28
 
30
29
  def set_workload_parsers(workload_parser: ArgumentParser):
@@ -118,21 +117,7 @@ def set_workload_create_parser(workload_create_parser: ArgumentParser):
118
117
  required=True
119
118
  )
120
119
  )
121
- workload_device_group.add_argument(
122
- '--tpu-type',
123
- type=str,
124
- default=None,
125
- help='The tpu type to use, v5litepod-16, etc.',
126
- )
127
- workload_device_group.add_argument(
128
- '--device-type',
129
- type=str,
130
- default=None,
131
- help=(
132
- 'The device type to use (can be tpu or gpu or cpu), v5litepod-16,'
133
- ' h100-80gb-8, n2-standard-32-4 etc.'
134
- ),
135
- )
120
+ add_tpu_and_device_type_arguments(workload_device_group)
136
121
 
137
122
  workload_create_parser_optional_arguments.add_argument(
138
123
  '--storage',
@@ -280,11 +265,8 @@ def set_workload_create_pathways_parser(
280
265
  )
281
266
  )
282
267
  ### "workload create-pathways" Required arguments, specific to Pathways
283
- workload_create_pathways_parser_required_arguments.add_argument(
284
- '--tpu-type',
285
- type=str,
286
- default=None,
287
- help='The tpu type to use, v5litepod-16, etc.',
268
+ add_tpu_type_argument(
269
+ workload_create_pathways_parser_required_arguments, required=True
288
270
  )
289
271
 
290
272
  ### "workload create-pathways" Optional arguments, specific to Pathways
@@ -601,6 +583,16 @@ def add_shared_workload_create_optional_arguments(args_parsers):
601
583
  ' `jax-tpu`.'
602
584
  ),
603
585
  )
586
+ custom_parser.add_argument(
587
+ '--output-manifest-file',
588
+ type=str,
589
+ default=None,
590
+ help=(
591
+ 'If you want to see the generated manifest, provide a file path'
592
+ ' here. This will write the manifest to the file. If used with'
593
+ ' --dry-run, it will skip the actual deployment and cluster checks.'
594
+ ),
595
+ )
604
596
  custom_parser.add_argument(
605
597
  '--num-slices',
606
598
  type=int,
@@ -659,13 +651,6 @@ def add_shared_workload_create_optional_arguments(args_parsers):
659
651
  ' the workload.'
660
652
  ),
661
653
  )
662
- if FeatureFlags.SUB_SLICING_ENABLED:
663
- custom_parser.add_argument(
664
- '--sub-slicing-topology',
665
- type=str,
666
- help='Sub-slicing topology to use.',
667
- required=False,
668
- )
669
654
 
670
655
 
671
656
  def add_shared_workload_create_env_arguments(args_parsers):
xpk/parser/workload_test.py CHANGED
@@ -16,53 +16,9 @@ limitations under the License.
16
16
 
17
17
  import argparse
18
18
  from xpk.parser.workload import set_workload_create_parser
19
- from ..utils.feature_flags import FeatureFlags
20
- import pytest
21
19
 
22
20
 
23
- @pytest.fixture(autouse=True)
24
- def with_sub_slicing_enabled():
25
- FeatureFlags.SUB_SLICING_ENABLED = True
26
-
27
-
28
- def test_workload_create_sub_slicing_topology_is_hidden_with_flag_off():
29
- FeatureFlags.SUB_SLICING_ENABLED = False
30
- parser = argparse.ArgumentParser()
31
-
32
- set_workload_create_parser(parser)
33
- help_str = parser.format_help()
34
-
35
- assert "--sub-slicing" not in help_str
36
-
37
-
38
- def test_workload_create_sub_slicing_topology_is_shown_with_flag_on():
39
- parser = argparse.ArgumentParser()
40
-
41
- set_workload_create_parser(parser)
42
- help_str = parser.format_help()
43
-
44
- assert "--sub-slicing" in help_str
45
-
46
-
47
- def test_workload_create_sub_slicing_topology_is_none_by_default():
48
- parser = argparse.ArgumentParser()
49
-
50
- set_workload_create_parser(parser)
51
- args = parser.parse_args([
52
- "--cluster",
53
- "test-cluster",
54
- "--command",
55
- "python3",
56
- "--workload",
57
- "test",
58
- "--tpu-type",
59
- "test-tpu",
60
- ])
61
-
62
- assert args.sub_slicing_topology is None
63
-
64
-
65
- def test_workload_create_sub_slicing_topology_can_be_set():
21
+ def test_workload_create_parses():
66
22
  parser = argparse.ArgumentParser()
67
23
 
68
24
  set_workload_create_parser(parser)
@@ -74,9 +30,7 @@ def test_workload_create_sub_slicing_topology_can_be_set():
74
30
  "--workload",
75
31
  "test",
76
32
  "--tpu-type",
77
- "test-tpu",
78
- "--sub-slicing-topology",
79
- "2x2",
33
+ "tpu7x-2",
80
34
  ])
81
35
 
82
- assert args.sub_slicing_topology is "2x2"
36
+ assert args