xpk 0.14.4__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. integration/gcluster_a3mega_test.py +11 -0
  2. integration/gcluster_a3ultra_test.py +11 -0
  3. integration/gcluster_a4_test.py +11 -0
  4. xpk/commands/cluster.py +57 -21
  5. xpk/commands/cluster_gcluster.py +25 -5
  6. xpk/commands/cluster_gcluster_test.py +11 -2
  7. xpk/commands/cluster_test.py +233 -12
  8. xpk/commands/config.py +3 -5
  9. xpk/commands/kind.py +1 -1
  10. xpk/commands/storage.py +8 -10
  11. xpk/commands/workload.py +28 -12
  12. xpk/commands/workload_test.py +3 -3
  13. xpk/core/blueprint/blueprint_generator.py +70 -33
  14. xpk/core/blueprint/blueprint_test.py +9 -0
  15. xpk/core/capacity.py +46 -8
  16. xpk/core/capacity_test.py +32 -1
  17. xpk/core/cluster.py +37 -57
  18. xpk/core/cluster_test.py +95 -0
  19. xpk/core/commands.py +4 -10
  20. xpk/core/config.py +9 -2
  21. xpk/core/gcloud_context.py +18 -12
  22. xpk/core/gcloud_context_test.py +111 -1
  23. xpk/core/kjob.py +6 -9
  24. xpk/core/kueue_manager.py +192 -32
  25. xpk/core/kueue_manager_test.py +132 -4
  26. xpk/core/nodepool.py +21 -29
  27. xpk/core/nodepool_test.py +17 -15
  28. xpk/core/scheduling.py +16 -1
  29. xpk/core/scheduling_test.py +85 -6
  30. xpk/core/system_characteristics.py +77 -19
  31. xpk/core/system_characteristics_test.py +80 -5
  32. xpk/core/telemetry.py +263 -0
  33. xpk/core/telemetry_test.py +211 -0
  34. xpk/main.py +31 -13
  35. xpk/parser/cluster.py +48 -9
  36. xpk/parser/cluster_test.py +42 -3
  37. xpk/parser/workload.py +12 -0
  38. xpk/parser/workload_test.py +4 -4
  39. xpk/telemetry_uploader.py +29 -0
  40. xpk/templates/kueue_gke_default_topology.yaml.j2 +1 -1
  41. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +3 -8
  42. xpk/utils/console.py +41 -10
  43. xpk/utils/console_test.py +106 -0
  44. xpk/utils/feature_flags.py +7 -1
  45. xpk/utils/file.py +4 -1
  46. xpk/utils/topology.py +4 -0
  47. xpk/utils/user_agent.py +35 -0
  48. xpk/utils/user_agent_test.py +44 -0
  49. xpk/utils/user_input.py +48 -0
  50. xpk/utils/user_input_test.py +92 -0
  51. xpk/utils/validation.py +0 -11
  52. xpk/utils/versions.py +31 -0
  53. {xpk-0.14.4.dist-info → xpk-0.15.0.dist-info}/METADATA +113 -92
  54. {xpk-0.14.4.dist-info → xpk-0.15.0.dist-info}/RECORD +58 -48
  55. {xpk-0.14.4.dist-info → xpk-0.15.0.dist-info}/WHEEL +0 -0
  56. {xpk-0.14.4.dist-info → xpk-0.15.0.dist-info}/entry_points.txt +0 -0
  57. {xpk-0.14.4.dist-info → xpk-0.15.0.dist-info}/licenses/LICENSE +0 -0
  58. {xpk-0.14.4.dist-info → xpk-0.15.0.dist-info}/top_level.txt +0 -0
xpk/commands/cluster_test.py CHANGED
@@ -14,16 +14,20 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+import json
 from argparse import Namespace
 from dataclasses import dataclass
 from typing import Any
 from unittest.mock import MagicMock, patch
 import pytest
 
-from xpk.commands.cluster import _install_kueue, _validate_cluster_create_args, run_gke_cluster_create_command
+from xpk.core.telemetry import MetricsCollector
+from xpk.commands.cluster import _install_kueue, _validate_cluster_create_args, run_gke_cluster_create_command, cluster_create, _log_cluster_create_telemetry
+from xpk.core.capacity import CapacityType
 from xpk.core.system_characteristics import SystemCharacteristics, UserFacingNameToSystemCharacteristics
 from xpk.core.testing.commands_tester import CommandsTester
 from xpk.utils.feature_flags import FeatureFlags
+from xpk.utils.versions import ReleaseChannel
 
 
 @dataclass
@@ -34,6 +38,30 @@ class _Mocks:
   commands_tester: CommandsTester
 
 
+@dataclass
+class _ClusterCreateMocks:
+  """Holds all the mocked dependencies for the cluster_create function."""
+
+  get_all_clusters_programmatic: MagicMock
+  get_gke_server_config: MagicMock
+  get_gke_control_plane_version: MagicMock
+  get_system_characteristics: MagicMock
+  authorize_private_cluster_access_if_necessary: MagicMock
+  update_coredns_if_necessary: MagicMock
+  get_cluster_credentials: MagicMock
+  setup_k8s_env: MagicMock
+  get_gke_node_pool_version: MagicMock
+  run_gke_node_pool_create_command: MagicMock
+  create_cluster_configmaps: MagicMock
+  set_jobset_on_cluster: MagicMock
+  get_cluster_location: MagicMock
+  install_kjob: MagicMock
+  xpk_exit: MagicMock
+  update_jobset_resources_if_necessary: MagicMock
+  _install_kueue: MagicMock
+  set_pathways_job_on_cluster: MagicMock
+
+
 @pytest.fixture
 def mocks(mocker) -> _Mocks:
   common_print_mock = mocker.patch(
@@ -65,6 +93,10 @@ def construct_args(**kwargs: Any) -> Namespace:
       project='project',
       zone='us-central1-a',
       reservation='',
+      on_demand=False,
+      tpu_type=None,
+      device_type=None,
+      spot=False,
       default_pool_cpu_machine_type='test-machine-type',
       cluster='test-cluster',
       default_pool_cpu_num_nodes='100',
@@ -87,11 +119,82 @@ def construct_args(**kwargs: Any) -> Namespace:
       memory_limit='100Gi',
       cpu_limit=100,
       cluster_cpu_machine_type='',
+      create_vertex_tensorboard=False,
+      enable_autoprovisioning=False,
   )
   args_dict.update(kwargs)
   return Namespace(**args_dict)
 
 
+@pytest.fixture
+def cluster_create_mocks(mocker) -> _ClusterCreateMocks:
+  """Mocks all dependencies for the cluster_create function."""
+  # This fixture patches all the functions called by cluster_create, allowing
+  # tests to focus on specific logic paths without executing external commands
+  # or complex sub-functions. Each mock can be configured within the test
+  # itself if a specific return value or behavior is needed.
+  return _ClusterCreateMocks(
+      get_all_clusters_programmatic=mocker.patch(
+          'xpk.commands.cluster.get_all_clusters_programmatic',
+          return_value=([], 0),
+      ),
+      get_gke_server_config=mocker.patch(
+          'xpk.commands.cluster.get_gke_server_config',
+          return_value=(0, MagicMock()),
+      ),
+      get_gke_control_plane_version=mocker.patch(
+          'xpk.commands.cluster.get_gke_control_plane_version'
+      ),
+      get_system_characteristics=mocker.patch(
+          'xpk.commands.cluster.get_system_characteristics',
+          return_value=(TPU_TEST_SYSTEM, 0),
+      ),
+      authorize_private_cluster_access_if_necessary=mocker.patch(
+          'xpk.commands.cluster.authorize_private_cluster_access_if_necessary',
+          return_value=0,
+      ),
+      update_coredns_if_necessary=mocker.patch(
+          'xpk.commands.cluster.update_coredns_if_necessary', return_value=0
+      ),
+      get_cluster_credentials=mocker.patch(
+          'xpk.commands.cluster.get_cluster_credentials', return_value=0
+      ),
+      setup_k8s_env=mocker.patch('xpk.commands.cluster.setup_k8s_env'),
+      get_gke_node_pool_version=mocker.patch(
+          'xpk.commands.cluster.get_gke_node_pool_version',
+          return_value=(0, '1.2.3'),
+      ),
+      run_gke_node_pool_create_command=mocker.patch(
+          'xpk.commands.cluster.run_gke_node_pool_create_command',
+          return_value=0,
+      ),
+      create_cluster_configmaps=mocker.patch(
+          'xpk.commands.cluster.create_cluster_configmaps', return_value=0
+      ),
+      set_jobset_on_cluster=mocker.patch(
+          'xpk.commands.cluster.set_jobset_on_cluster', return_value=0
+      ),
+      get_cluster_location=mocker.patch(
+          'xpk.commands.cluster.get_cluster_location',
+          return_value='us-central1',
+      ),
+      install_kjob=mocker.patch(
+          'xpk.commands.cluster.install_kjob', return_value=0
+      ),
+      xpk_exit=mocker.patch('xpk.commands.cluster.xpk_exit'),
+      update_jobset_resources_if_necessary=mocker.patch(
+          'xpk.commands.cluster.update_jobset_resources_if_necessary',
+          return_value=0,
+      ),
+      _install_kueue=mocker.patch(
+          'xpk.commands.cluster._install_kueue', return_value=0
+      ),
+      set_pathways_job_on_cluster=mocker.patch(
+          'xpk.commands.cluster.set_pathways_job_on_cluster', return_value=0
+      ),
+  )
+
+
 GPU_TEST_SYSTEM: SystemCharacteristics = UserFacingNameToSystemCharacteristics[
     'l4-1'
 ]
@@ -106,7 +209,7 @@ TPU_TEST_SYSTEM: SystemCharacteristics = UserFacingNameToSystemCharacteristics[
 def test_validate_cluster_create_args_for_correct_args_pass(
     mocks: _Mocks,
 ):
-  args = Namespace()
+  args = construct_args()
 
   _validate_cluster_create_args(args, GPU_TEST_SYSTEM)
 
@@ -209,6 +312,7 @@ def test_run_gke_cluster_create_command_specifies_custom_cluster_arguments_last(
       ),
       gke_control_plane_version='1.2.3',
       system=TPU_TEST_SYSTEM,
+      release_channel=ReleaseChannel.STABLE,
   )
 
   assert result == 0
@@ -226,12 +330,16 @@ def test_run_gke_cluster_create_command_without_gke_version_does_not_have_no_aut
       args=construct_args(gke_version=''),
       gke_control_plane_version='1.2.3',
       system=TPU_TEST_SYSTEM,
+      release_channel=ReleaseChannel.RAPID,
   )
 
   assert result == 0
   mocks.commands_tester.assert_command_not_run(
       'clusters create', ' --no-enable-autoupgrade'
   )
+  mocks.commands_tester.assert_command_run(
+      'clusters create', ' --release-channel=rapid'
+  )
 
 
 def test_run_gke_cluster_create_command_with_gke_version_has_no_autoupgrade_flag(
@@ -241,24 +349,137 @@ def test_run_gke_cluster_create_command_with_gke_version_has_no_autoupgrade_flag
       args=construct_args(gke_version='1.2.3'),
       gke_control_plane_version='1.2.3',
       system=TPU_TEST_SYSTEM,
+      release_channel=ReleaseChannel.REGULAR,
   )
 
   assert result == 0
   mocks.commands_tester.assert_command_run(
-      'clusters create', ' --no-enable-autoupgrade'
+      'clusters create', '--release-channel=regular', ' --no-enable-autoupgrade'
   )
 
 
-def test_run_gke_cluster_create_command_with_gpu_system_has_no_enable_autoupgrade(
-    mocks: _Mocks,
+def test_log_cluster_create_telemetry_does_not_log_when_feature_flag_is_disabled():
+  FeatureFlags.TELEMETRY_ENABLED = False
+  _log_cluster_create_telemetry(construct_args())
+  events = json.loads(MetricsCollector.flush())['log_event']
+  assert len(events) == 0
+
+
+def test_log_cluster_create_telemetry_logs_correct_event_when_tpu_type_is_provided(
+    mocker: MagicMock,
 ):
-  result = run_gke_cluster_create_command(
-      args=construct_args(gke_version=''),
-      gke_control_plane_version='1.2.3',
-      system=GPU_TEST_SYSTEM,
+  FeatureFlags.TELEMETRY_ENABLED = True
+  mocker.patch(
+      'xpk.commands.cluster.get_capacity_type',
+      return_value=(CapacityType.SPOT, 0),
+  )
+  _log_cluster_create_telemetry(construct_args(device_type='test-device-type'))
+  event = json.loads(MetricsCollector.flush())['log_event'][0]
+  payload = json.loads(event['source_extension_json'])
+  event_metadata = payload['event_metadata']
+  assert payload['event_name'] == 'cluster_create'
+  assert (
+      _get_event_metadata_value_by_key(
+          event_metadata,
+          'XPK_ZONE',
+      )
+      == 'us-central1-a'
+  )
+  assert (
+      _get_event_metadata_value_by_key(
+          event_metadata,
+          'XPK_SYSTEM_CHARACTERISTICS',
+      )
+      == 'test-device-type'
+  )
+  assert (
+      _get_event_metadata_value_by_key(
+          event_metadata,
+          'XPK_PROVISIONING_MODE',
+      )
+      == 'spot'
   )
 
-  assert result == 0
-  mocks.commands_tester.assert_command_run(
-      'clusters create', ' --no-enable-autoupgrade'
+
+def test_log_cluster_create_telemetry_logs_correct_event_when_device_type_is_provided(
+    mocker: MagicMock,
+):
+  FeatureFlags.TELEMETRY_ENABLED = True
+  mocker.patch(
+      'xpk.commands.cluster.get_capacity_type',
+      return_value=(CapacityType.SPOT, 0),
+  )
+  _log_cluster_create_telemetry(construct_args(tpu_type='test-tpu-type'))
+  event = json.loads(MetricsCollector.flush())['log_event'][0]
+  payload = json.loads(event['source_extension_json'])
+  event_metadata = payload['event_metadata']
+  assert payload['event_name'] == 'cluster_create'
+  assert (
+      _get_event_metadata_value_by_key(
+          event_metadata,
+          'XPK_ZONE',
+      )
+      == 'us-central1-a'
+  )
+  assert (
+      _get_event_metadata_value_by_key(
+          event_metadata,
+          'XPK_SYSTEM_CHARACTERISTICS',
+      )
+      == 'test-tpu-type'
   )
+  assert (
+      _get_event_metadata_value_by_key(
+          event_metadata,
+          'XPK_PROVISIONING_MODE',
+      )
+      == 'spot'
+  )
+
+
+def _get_event_metadata_value_by_key(
+    event_metadata: list[dict[str, str]], key: str
+) -> str | None:
+  return next(
+      (meta['value'] for meta in event_metadata if meta['key'] == key),
+      None,
+  )
+
+
+@pytest.mark.parametrize(
+    'gke_version_arg, expected_channel, expected_version',
+    [
+        (None, ReleaseChannel.RAPID, '1.2.4'),  # No version, should use RAPID
+        (
+            '1.2.3',
+            ReleaseChannel.REGULAR,
+            '1.2.3',
+        ),  # Version provided, should use REGULAR
+    ],
+)
+def test_cluster_create_calls_run_command_with_correct_channel_and_version(
+    gke_version_arg,
+    expected_channel,
+    expected_version,
+    mocks: _Mocks,
+    cluster_create_mocks: _ClusterCreateMocks,
+):
+  """
+  Verifies that cluster_create calls run_gke_cluster_create_command with the
+  correct release channel and GKE version based on whether a version is
+  provided.
+  """
+  cluster_create_mocks.get_gke_control_plane_version.return_value = (
+      0,
+      expected_version,
+  )
+
+  args = construct_args(gke_version=gke_version_arg)
+  cluster_create(args)
+
+  expected_command_parts = [
+      'clusters create',
+      f'--cluster-version={expected_version}',
+      f'--release-channel={expected_channel.value.lower()}',
+  ]
+
+  mocks.commands_tester.assert_command_run(*expected_command_parts)
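
The parametrized test above pins down the new behavior: omitting --gke-version selects the RAPID channel with the control plane's default version, while an explicit version is pinned on the REGULAR channel. A minimal sketch of that selection rule, assuming ReleaseChannel is an enum whose value lowercases to the gcloud flag; the shipped xpk.utils.versions implementation may differ:

# Sketch only: inferred from the assertions above, not copied from xpk.
from enum import Enum


class ReleaseChannel(Enum):
  RAPID = 'rapid'
  REGULAR = 'regular'
  STABLE = 'stable'


def pick_channel_and_version(gke_version, control_plane_default):
  # An explicit version pins REGULAR; otherwise RAPID with the default
  # version reported for the control plane (e.g. '1.2.4' in the test).
  if gke_version:
    return ReleaseChannel.REGULAR, gke_version
  return ReleaseChannel.RAPID, control_plane_default


channel, version = pick_channel_and_version(None, '1.2.4')
flag = f'--release-channel={channel.value.lower()}'  # '--release-channel=rapid'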
xpk/commands/config.py CHANGED
@@ -14,16 +14,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from ..core.config import XpkConfig
+from ..core.config import xpk_config
 from ..utils.console import xpk_print
 
-xpk_cfg = XpkConfig()
-
 
 def set_config(args):
-  xpk_cfg.set(args.set_config_args[0], args.set_config_args[1])
+  xpk_config.set(args.set_config_args[0], args.set_config_args[1])
 
 
 def get_config(args):
-  value = xpk_cfg.get(args.get_config_key[0])
+  value = xpk_config.get(args.get_config_key[0])
   xpk_print(value)
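
This change swaps the per-module XpkConfig() instance for an imported xpk_config, presumably a module-level singleton defined in xpk/core/config.py. A minimal sketch of that sharing pattern, with an in-memory store standing in for whatever backing file the real XpkConfig uses:

# Sketch only: illustrates the singleton pattern, not the real storage.
class XpkConfig:

  def __init__(self):
    self._values = {}

  def set(self, key, value):
    self._values[key] = value

  def get(self, key):
    return self._values.get(key)


# One shared instance; commands import this instead of constructing their own.
xpk_config = XpkConfig()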
xpk/commands/kind.py CHANGED
@@ -99,7 +99,7 @@ def cluster_create(args) -> None:
       supports_sub_slicing=False,
   )
 
-  kueue_manager = KueueManager()
+  kueue_manager = KueueManager(project='', zone='')
   kueue_manager.install_or_upgrade(
       KueueConfig(
           system,
xpk/commands/storage.py CHANGED
@@ -56,7 +56,7 @@ from ..core.storage import (
     list_storages,
     print_storages_for_cluster,
 )
-from ..utils.console import get_user_input, xpk_exit, xpk_print
+from ..utils.console import ask_for_user_consent, xpk_exit, xpk_print
 from ..utils.kubectl import apply_kubectl_manifest
 from ..utils.execution_context import is_dry_run
 from ..utils.validation import validate_dependencies_list, SystemDependency, should_validate_dependencies
@@ -133,15 +133,13 @@ def storage_delete(args: Namespace) -> None:
       if storage.bucket.startswith(filestore_instance_name)
   ]
 
-  if children and not args.force:
-    detach = get_user_input(
-        "Deleting a filestore storage will destroy your filestore instance and"
-        " all its data in all volumes will be lost. Do you wish to delete the"
-        f" filestore instance {filestore_instance_name}?\n y (yes) / n (no):\n'"
-    )
-    if not detach:
-      xpk_print("Deleting storage canceled.")
-      xpk_exit(0)
+  if children and not ask_for_user_consent(
+      "Deleting a filestore storage will destroy your filestore instance and"
+      " all its data in all volumes will be lost. Do you wish to delete the"
+      f" filestore instance {filestore_instance_name}?"
+  ):
+    xpk_print("Deleting storage canceled.")
+    xpk_exit(0)
 
   for child in children:
     delete_storage_resources(k8s_api_client, child)
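
Both storage_delete here and workload_delete below drop their inline "y (yes) / n (no)" prompt plumbing in favor of ask_for_user_consent, and notably no longer check args.force at the call site, so the helper presumably handles forced and non-interactive runs itself. A hypothetical sketch of such a helper; the shipped one in xpk/utils and its exact force handling may differ:

def ask_for_user_consent(question: str, force: bool = False) -> bool:
  # Hypothetical: a forced or non-interactive run consents automatically.
  if force:
    return True
  answer = input(f'{question}\ny (yes) / n (no):\n').strip().lower()
  return answer in ('y', 'yes')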
xpk/commands/workload.py CHANGED
@@ -27,15 +27,14 @@ from ..core.cluster import (
     setup_k8s_env,
 )
 from ..core.commands import run_command_with_updates, run_commands
-from ..core.kueue_manager import KueueManager, has_sub_slicing_enabled
 from ..core.config import (VERTEX_TENSORBOARD_FEATURE_FLAG, XPK_CURRENT_VERSION)
 from ..core.docker_container import (
     get_main_container_docker_image,
     get_user_workload_container,
 )
+from ..core.kueue_manager import has_sub_slicing_enabled, get_installed_kueue_version, LOCAL_QUEUE_NAME
 from ..core.docker_resources import get_volumes, parse_env_config
 from ..core.gcloud_context import add_zone_and_project
-from ..core.kueue_manager import LOCAL_QUEUE_NAME
 from ..core.monitoring import get_gke_outlier_dashboard
 from ..core.nap import (
     get_autoprovisioning_node_selector_args,
@@ -64,6 +63,8 @@ from ..core.scheduling import (
     get_cpu_affinity,
     get_gpu_scheduler,
     create_sub_slicing_annotations,
+    create_placement_policy_label,
+    is_placement_policy_supported,
 )
 from ..core.storage import (
     GCE_PD_TYPE,
@@ -77,6 +78,7 @@ from ..core.storage import (
     get_storages_to_mount,
 )
 from ..core.system_characteristics import (
+    SUB_SLICING_TOPOLOGIES,
     AcceleratorType,
     get_system_characteristics,
     compute_vms_per_slice,
@@ -95,7 +97,7 @@ from ..core.workload_decorators import (
     tcpx_decorator,
     tcpxo_decorator,
 )
-from ..utils.console import get_user_input, xpk_exit, xpk_print
+from ..utils.console import ask_for_user_consent, xpk_exit, xpk_print
 from packaging.version import Version
 from ..utils.file import write_tmp_file
 from ..utils.execution_context import is_dry_run
@@ -144,6 +146,7 @@ spec:
               nodeSelector:
                 {accelerator_label}
                 {machine_label}
+                {placement_policy_label}
                 {autoprovisioning_args}
               priorityClassName: {args.priority}
               hostNetwork: true
@@ -193,6 +196,8 @@ spec:
               {gpu_scheduler}
               priorityClassName: {args.priority}
               restartPolicy: Never
+              nodeSelector:
+                {placement_policy_label}
               imagePullSecrets:
               - name: {args.docker_image_pull_secret}
               hostNetwork: true
@@ -238,6 +243,8 @@ spec:
         spec:
           priorityClassName: {args.priority}
           restartPolicy: Never
+          nodeSelector:
+            {placement_policy_label}
           imagePullSecrets:
           - name: {args.docker_image_pull_secret}
           dnsPolicy: ClusterFirstWithHostNet
@@ -273,6 +280,7 @@ PW_WORKLOAD_CREATE_YAML = """
       terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
       priorityClassName: {args.priority}
       nodeSelector:
+        {placement_policy_label}
         {autoprovisioning_args}
       pathwaysDir: {args.pathways_gcs_location} #This bucket needs to be created in advance.
       controller:
@@ -284,7 +292,6 @@ PW_WORKLOAD_CREATE_YAML = """
 {user_workload}
 """
 
-SUB_SLICING_TOPOLOGIES = ['2x2', '2x4', '4x4', '4x8', '8x8', '8x16', '16x16']
 SUB_SLICING_MINIMUM_KUEUE_VERSION = Version('0.13.0')
 
 
@@ -481,12 +488,17 @@ def workload_create(args) -> None:
   podFailurePolicy:
     rules:
     - action: FailJob
-      onPodConditions: []
       onExitCodes:
         containerName: {get_main_container_docker_image(args, system)}
         operator: NotIn
         values: [{restart_on_exit_codes}]"""
 
+  placement_policy_label = (
+      create_placement_policy_label(system)
+      if is_placement_policy_supported(system)
+      else ''
+  )
+
   # Create the workload file based on accelerator type or workload type.
   if system.accelerator_type == AcceleratorType.GPU:
     container, debugging_dashboard_id = get_user_workload_container(
@@ -520,6 +532,7 @@ def workload_create(args) -> None:
         failure_policy_rules=failure_policy_rules,
         pod_failure_policy=pod_failure_policy,
         annotations=annotations,
+        placement_policy_label=placement_policy_label,
     )
 
     sub_networks = get_cluster_subnetworks()
@@ -544,6 +557,7 @@ def workload_create(args) -> None:
         service_account=service_account,
         failure_policy_rules=failure_policy_rules,
         pod_failure_policy=pod_failure_policy,
+        placement_policy_label=placement_policy_label,
     )
 
   elif args.use_pathways and ensure_pathways_workload_prerequisites(
@@ -561,6 +575,7 @@ def workload_create(args) -> None:
         user_workload=get_user_workload_for_pathways(args, system),
         local_queue_name=LOCAL_QUEUE_NAME,
         autoprovisioning_args=autoprovisioning_args,
+        placement_policy_label=placement_policy_label,
     )
   else:
     container, debugging_dashboard_id = get_user_workload_container(
@@ -588,6 +603,7 @@ def workload_create(args) -> None:
                 create_sub_slicing_annotations(args.sub_slicing_topology)
             )
         ),
+        placement_policy_label=placement_policy_label,
         machine_label=create_machine_label(system.accelerator_type, system),
         local_queue_name=LOCAL_QUEUE_NAME,
         autoprovisioning_args=autoprovisioning_args,
@@ -698,9 +714,10 @@ def _validate_sub_slicing_availability():
     )
     xpk_exit(1)
 
-  kueue_manager = KueueManager()
-  return_code, current_version = kueue_manager.get_installed_kueue_version()
-  if return_code != 0:
+  return_code, current_version = get_installed_kueue_version(
+      dry_run_version=Version('0.13')
+  )
+  if return_code != 0 or not current_version:
     xpk_print(
         'Error: Unable to validate sub-slicing support on a given cluster.'
     )
@@ -785,11 +802,10 @@ def workload_delete(args) -> None:
     xpk_exit(return_code)
   # Skip the header
   workloads = [x.split(' ')[0] for x in return_value.splitlines()][1:]
-  if workloads and not args.force:
-    will_delete = get_user_input(
+  if workloads:
+    will_delete = ask_for_user_consent(
         f'Planning to delete {len(workloads)} workloads in the cluster'
-        f' {args.cluster} including {workloads}. \nDo you wish to delete: y'
-        ' (yes) / n (no):\n'
+        f' {args.cluster} including {workloads}. \nDo you wish to delete?'
     )
   else:
     workloads = [args.workload]
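
The new placement_policy_label is computed once in workload_create and threaded into every workload template's nodeSelector; when the system does not support a placement policy it degrades to an empty string, which str.format renders as a blank line in the YAML. A toy demonstration of that templating pattern; the label text below is a placeholder, not what create_placement_policy_label in xpk/core/scheduling.py actually emits:

# Toy sketch of the optional-nodeSelector-entry pattern used above.
TEMPLATE = """\
spec:
  nodeSelector:
    {placement_policy_label}
  restartPolicy: Never
"""


def render(supported: bool) -> str:
  # Placeholder label; xpk derives the real one from the system.
  label = 'placement-policy: example-policy' if supported else ''
  return TEMPLATE.format(placement_policy_label=label)


print(render(True))   # nodeSelector carries the label
print(render(False))  # nodeSelector line renders blank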
xpk/commands/workload_test.py CHANGED
@@ -107,7 +107,7 @@ def test_validate_sub_slicing_availability_exits_when_kueue_version_cannot_be_de
       return_value=(0, True),
   )
   mocker.patch(
-      'xpk.commands.workload.KueueManager.get_installed_kueue_version',
+      'xpk.commands.workload.get_installed_kueue_version',
       return_value=(1, None),
   )
   with pytest.raises(SystemExit):
@@ -124,7 +124,7 @@ def test_validate_sub_slicing_availability_exits_when_kueue_version_does_not_mee
       return_value=(0, True),
   )
   mocker.patch(
-      'xpk.commands.workload.KueueManager.get_installed_kueue_version',
+      'xpk.commands.workload.get_installed_kueue_version',
       return_value=(0, Version('0.0.0')),
   )
   with pytest.raises(SystemExit):
@@ -141,7 +141,7 @@ def test_validate_sub_slicing_availability_does_nothing_when_cluster_is_correctl
       return_value=(0, True),
   )
   mocker.patch(
-      'xpk.commands.workload.KueueManager.get_installed_kueue_version',
+      'xpk.commands.workload.get_installed_kueue_version',
       return_value=(0, Version('0.13.0')),
   )
   _validate_sub_slicing_availability()
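
These tests pin the three outcomes of the sub-slicing gate: a nonzero return code or missing version aborts, anything below Kueue 0.13.0 is rejected, and 0.13.0 or later passes. A condensed sketch of that check, assuming get_installed_kueue_version returns a (return_code, Version | None) pair as the mocks suggest:

# Sketch only: mirrors _validate_sub_slicing_availability's three branches.
from packaging.version import Version

SUB_SLICING_MINIMUM_KUEUE_VERSION = Version('0.13.0')


def sub_slicing_supported(return_code, current_version):
  if return_code != 0 or not current_version:
    raise RuntimeError(
        'Unable to validate sub-slicing support on a given cluster.'
    )
  return current_version >= SUB_SLICING_MINIMUM_KUEUE_VERSION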