xpk 0.14.4__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. integration/README.md +19 -0
  2. integration/gcluster_a3mega_test.py +11 -0
  3. integration/gcluster_a3ultra_test.py +11 -0
  4. integration/gcluster_a4_test.py +11 -0
  5. xpk/blueprints/a3mega/config-map.yaml.tftpl +15 -0
  6. xpk/blueprints/a3mega/storage_crd.yaml +52 -0
  7. xpk/blueprints/a3ultra/config-map.yaml.tftpl +15 -0
  8. xpk/blueprints/a3ultra/mlgru-disable.yaml +59 -0
  9. xpk/blueprints/a3ultra/nccl-installer.yaml +95 -0
  10. xpk/blueprints/a3ultra/storage_crd.yaml +52 -0
  11. xpk/blueprints/a4/config-map.yaml.tftpl +15 -0
  12. xpk/blueprints/a4/nccl-rdma-installer-a4.yaml +66 -0
  13. xpk/blueprints/a4/storage_crd.yaml +52 -0
  14. xpk/commands/cluster.py +89 -32
  15. xpk/commands/cluster_gcluster.py +25 -5
  16. xpk/commands/cluster_gcluster_test.py +16 -3
  17. xpk/commands/cluster_test.py +353 -7
  18. xpk/commands/config.py +3 -5
  19. xpk/commands/inspector.py +5 -3
  20. xpk/commands/kind.py +3 -1
  21. xpk/commands/managed_ml_diagnostics.py +249 -0
  22. xpk/commands/managed_ml_diagnostics_test.py +146 -0
  23. xpk/commands/storage.py +8 -10
  24. xpk/commands/workload.py +143 -142
  25. xpk/commands/workload_test.py +160 -118
  26. xpk/core/blueprint/blueprint_generator.py +73 -33
  27. xpk/core/blueprint/blueprint_test.py +9 -0
  28. xpk/core/blueprint/testing/data/a3_mega.yaml +129 -0
  29. xpk/core/blueprint/testing/data/a3_mega_spot.yaml +125 -0
  30. xpk/core/blueprint/testing/data/a3_ultra.yaml +173 -0
  31. xpk/core/blueprint/testing/data/a4.yaml +185 -0
  32. xpk/core/capacity.py +48 -8
  33. xpk/core/capacity_test.py +32 -1
  34. xpk/core/cluster.py +55 -104
  35. xpk/core/cluster_test.py +170 -0
  36. xpk/core/commands.py +4 -10
  37. xpk/core/config.py +88 -7
  38. xpk/core/config_test.py +67 -11
  39. xpk/core/docker_container.py +3 -1
  40. xpk/core/docker_image.py +10 -6
  41. xpk/core/docker_resources.py +1 -10
  42. xpk/core/gcloud_context.py +18 -12
  43. xpk/core/gcloud_context_test.py +111 -1
  44. xpk/core/kjob.py +17 -19
  45. xpk/core/kueue_manager.py +205 -51
  46. xpk/core/kueue_manager_test.py +158 -4
  47. xpk/core/nap.py +13 -14
  48. xpk/core/nodepool.py +37 -43
  49. xpk/core/nodepool_test.py +42 -19
  50. xpk/core/pathways.py +23 -0
  51. xpk/core/pathways_test.py +57 -0
  52. xpk/core/resources.py +84 -27
  53. xpk/core/scheduling.py +144 -133
  54. xpk/core/scheduling_test.py +298 -6
  55. xpk/core/system_characteristics.py +256 -19
  56. xpk/core/system_characteristics_test.py +128 -5
  57. xpk/core/telemetry.py +263 -0
  58. xpk/core/telemetry_test.py +211 -0
  59. xpk/core/vertex.py +4 -3
  60. xpk/core/workload_decorators/tcpx_decorator.py +5 -1
  61. xpk/main.py +33 -13
  62. xpk/parser/cluster.py +40 -67
  63. xpk/parser/cluster_test.py +83 -3
  64. xpk/parser/common.py +84 -0
  65. xpk/parser/storage.py +10 -0
  66. xpk/parser/storage_test.py +47 -0
  67. xpk/parser/workload.py +14 -29
  68. xpk/parser/workload_test.py +3 -49
  69. xpk/telemetry_uploader.py +29 -0
  70. xpk/templates/arm_gpu_workload_crate.yaml.j2 +46 -0
  71. xpk/templates/kueue_gke_default_topology.yaml.j2 +1 -1
  72. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +3 -8
  73. xpk/utils/console.py +41 -10
  74. xpk/utils/console_test.py +106 -0
  75. xpk/utils/feature_flags.py +10 -1
  76. xpk/utils/file.py +4 -1
  77. xpk/utils/topology.py +4 -0
  78. xpk/utils/user_agent.py +35 -0
  79. xpk/utils/user_agent_test.py +44 -0
  80. xpk/utils/user_input.py +48 -0
  81. xpk/utils/user_input_test.py +92 -0
  82. xpk/utils/validation.py +2 -13
  83. xpk/utils/versions.py +31 -0
  84. xpk-0.16.0.dist-info/METADATA +127 -0
  85. xpk-0.16.0.dist-info/RECORD +168 -0
  86. xpk-0.14.4.dist-info/METADATA +0 -1645
  87. xpk-0.14.4.dist-info/RECORD +0 -139
  88. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/WHEEL +0 -0
  89. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/entry_points.txt +0 -0
  90. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/licenses/LICENSE +0 -0
  91. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/top_level.txt +0 -0
xpk/core/scheduling.py CHANGED
@@ -14,59 +14,63 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from enum import Enum
+
+from .kueue_manager import get_installed_kueue_version, has_sub_slicing_enabled
+from ..utils.feature_flags import FeatureFlags
+from ..utils.topology import get_slice_topology_level
 from ..utils.console import xpk_print
+from ..utils.topology import is_topology_valid
 from ..utils.execution_context import is_dry_run
 from .capacity import AUTOPROVISIONING_CONFIG_MAXIMUM_KEY, AUTOPROVISIONING_CONFIG_VALUE
-from .resources import CLUSTER_RESOURCES_CONFIGMAP, get_cluster_configmap
 from .system_characteristics import (
+    SUB_SLICING_TOPOLOGIES,
     AcceleratorType,
-    AcceleratorTypeToAcceleratorCharacteristics,
     SystemCharacteristics,
+    create_accelerator_label,
+    create_machine_label,
 )
+from packaging.version import Version
 
+_SUB_SLICING_MINIMUM_KUEUE_VERSION = Version('0.13.0')
 
-def check_if_workload_can_schedule(args, system: SystemCharacteristics) -> bool:
-  """Check if workload can schedule based on the cluster resources (tpu_type and maximum VM in cluster).
 
-  Args:
-    args: user provided arguments for running the command.
-    system: system characteristics
+class WorkloadScheduling(Enum):
+  UNAVAILABLE = 0
+  AVAILABLE = 1
+  SUB_SLICING_AVAILABLE = 2
+
+
+def check_if_workload_can_schedule(
+    args,
+    workload_system: SystemCharacteristics,
+    cluster_system: SystemCharacteristics | None,
+    resources_config_map: dict[str, str] | None,
+) -> WorkloadScheduling:
+  """Check if workload can schedule based on the cluster resources (tpu_type and maximum VM in cluster).
 
   Returns:
-    returns true if workload can schedule, otherwise returns false.
+    returns WorkloadScheduling describing scheduling option.
   """
-  resources_configmap_name = f'{args.cluster}-{CLUSTER_RESOURCES_CONFIGMAP}'
-  cluster_config_map = get_cluster_configmap(resources_configmap_name)
+  if is_dry_run() and not cluster_system:
+    xpk_print('Skipping workload scheduling validation in dry run.')
+    return WorkloadScheduling.AVAILABLE
 
-  # Prevents workload creation failure for existing clusters with no ConfigMap
-  if cluster_config_map is None:
+  if resources_config_map is None:
     xpk_print(
-        'No ConfigMap exist for cluster with the name'
-        f' {resources_configmap_name}.'
+        "Skipping workload scheduling validation, because there's no Resources"
+        ' ConfigMap in the cluster.'
     )
-    return True
-
-  if is_dry_run():
-    return True
+    return WorkloadScheduling.AVAILABLE
 
-  # Check for gke accelerator type:
-  missing_gke_accelerator_type = False
-  if not cluster_config_map.get(system.gke_accelerator):
-    xpk_print(
-        f'GKE Accelerator Type Check: {args.workload} is requesting'
-        f' {system.gke_accelerator} but cluster only contains'
-        f' {cluster_config_map.keys()}. '
-    )
-    missing_gke_accelerator_type = True
-  elif (
-      cluster_config_map[system.gke_accelerator]
-      == AUTOPROVISIONING_CONFIG_VALUE
-  ):
+  if _is_cluster_set_up_for_nap(workload_system, resources_config_map):
     # Run total chip check when in autoprovisioning mode.
     max_chips_in_cluster = int(
-        cluster_config_map[AUTOPROVISIONING_CONFIG_MAXIMUM_KEY]
+        resources_config_map[AUTOPROVISIONING_CONFIG_MAXIMUM_KEY]
+    )
+    num_chips_in_workload = get_total_chips_requested_from_args(
+        args, workload_system
     )
-    num_chips_in_workload = get_total_chips_requested_from_args(args, system)
 
     if num_chips_in_workload > max_chips_in_cluster:
       xpk_print(
@@ -75,44 +79,100 @@ def check_if_workload_can_schedule(args, system: SystemCharacteristics) -> bool:
           ' Resize the cluster to support more chips with'
           ' `xpk cluster create --autoprovisioning-max-chips=X ...`'
       )
-      return False
-    return True
+      return WorkloadScheduling.UNAVAILABLE
+    return WorkloadScheduling.AVAILABLE
+
+  if workload_system.device_type in resources_config_map:
+    if _check_workload_size_fits(
+        args,
+        workload_system,
+        max_vm_in_cluster=int(
+            resources_config_map[workload_system.device_type]
+        ),
+    ):
+      return WorkloadScheduling.AVAILABLE
+    else:
+      return WorkloadScheduling.UNAVAILABLE
+
+  if _check_sub_slicing_availability(
+      workload_system=workload_system, cluster_system=cluster_system
+  ):
+    assert cluster_system
+    if _check_workload_size_fits(
+        args,
+        workload_system,
+        max_vm_in_cluster=int(resources_config_map[cluster_system.device_type]),
+    ):
+      return WorkloadScheduling.SUB_SLICING_AVAILABLE
+    else:
+      return WorkloadScheduling.UNAVAILABLE
+
+  xpk_print(
+      'Workload scheduling validation failed. XPK will not create the workload'
+      f' {args.workload}.'
+  )
+  return WorkloadScheduling.UNAVAILABLE
 
-  # Check for device type
-  missing_device_type = False
-  device_type = system.device_type
-  if device_type not in cluster_config_map:
-    xpk_print(
-        f'Device Type Check: {args.workload} is requesting {device_type} but '
-        f'cluster only contains {cluster_config_map.keys()}. '
-    )
-    missing_device_type = True
 
-  if missing_device_type and missing_gke_accelerator_type:
+def _is_cluster_set_up_for_nap(
+    workload_system: SystemCharacteristics, resources_config_map: dict[str, str]
+) -> bool:
+  return (
+      resources_config_map.get(workload_system.gke_accelerator, None)
+      == AUTOPROVISIONING_CONFIG_VALUE
+  )
+
+
+def _check_workload_size_fits(
+    args,
+    workload_system: SystemCharacteristics,
+    max_vm_in_cluster: int,
+) -> bool:
+  if workload_system.accelerator_type == AcceleratorType.GPU:
+    vm_required_by_workload = args.num_nodes
+  else:
+    vm_required_by_workload = args.num_slices * workload_system.vms_per_slice
+
+  if vm_required_by_workload > max_vm_in_cluster:
     xpk_print(
-        'Both Device Type and GKE Accelerator Type checks failed.'
-        f' XPK will not create the workload {args.workload}.'
+        f'{args.workload} is requesting {args.num_slices} slice/slices of'
+        f' {workload_system.device_type}, which is'
+        f' {vm_required_by_workload} VMs, but the cluster only contains'
+        f' {max_vm_in_cluster} VMs of {workload_system.device_type}. XPK will'
+        ' not create this workload.'
    )
    return False
-  else:
-    # Check if the size of the workload will fit in the cluster.
-    max_vm_in_cluster = int(cluster_config_map[device_type])
-    if system.accelerator_type == AcceleratorType.GPU:
-      vm_required_by_workload = args.num_nodes
-    else:
-      vm_required_by_workload = args.num_slices * system.vms_per_slice
-    if vm_required_by_workload > max_vm_in_cluster:
-      xpk_print(
-          f'{args.workload} is requesting {args.num_slices} slice/slices of'
-          f' {device_type}, which is {vm_required_by_workload} VMs, but the'
-          f' cluster only contains {max_vm_in_cluster} VMs of {device_type}.'
-          ' XPK will not create this workload.'
-      )
-      return False
-
   return True
 
 
+def _check_sub_slicing_availability(
+    workload_system: SystemCharacteristics,
+    cluster_system: SystemCharacteristics | None,
+) -> bool:
+  if (
+      (not FeatureFlags.SUB_SLICING_ENABLED)
+      or (not cluster_system)
+      or (workload_system.gke_accelerator != cluster_system.gke_accelerator)
+      or (not cluster_system.supports_sub_slicing)
+      or (workload_system.topology not in SUB_SLICING_TOPOLOGIES)
+  ):
+    return False
+
+  return_code, sub_slicing_enabled = has_sub_slicing_enabled()
+  if return_code != 0 or not sub_slicing_enabled:
+    return False
+
+  return_code, current_version = get_installed_kueue_version(
+      dry_run_version=Version('0.13')
+  )
+
+  return (
+      return_code == 0
+      and current_version is not None
+      and current_version >= _SUB_SLICING_MINIMUM_KUEUE_VERSION
+  )
+
+
 def get_total_chips_requested_from_args(
     args, system: SystemCharacteristics
 ) -> int:
@@ -133,7 +193,7 @@ def get_total_chips_requested_from_args(
   return int(num_chips)
 
 
-def get_cpu_affinity(accelerator_type) -> str:
+def get_cpu_affinity(accelerator_type: AcceleratorType) -> str:
   """Generate affinity rules for CPU nodepools, so that workload pods are
   not scheduled on the default pool machines.
   Args:
@@ -197,10 +257,8 @@ def get_gpu_scheduler(
   """
   gpu_scheduler = gpu_scheduler_yaml.format(
       scheduler_name=args.scheduler,
-      accelerator_label=create_accelerator_label(
-          system.accelerator_type, system
-      ),
-      machine_label=create_machine_label(system.accelerator_type, system),
+      accelerator_label=create_accelerator_label(system),
+      machine_label=create_machine_label(system),
       node_pool_name=f'{args.cluster}-np-0',
       autoprovisioning_args=autoprovisioning_args,
   )
@@ -215,74 +273,14 @@ def get_gpu_scheduler(
   return gpu_scheduler, return_code
 
 
-def create_accelerator_label(accelerator_type, system) -> str:
-  """Generates accelerator label.
-
-  Args:
-    accelerator_type: type of accelerator.
-    system: system characteristics.
-
-  Returns:
-    The accelerator label.
-  """
-  if accelerator_type == AcceleratorType.CPU:
-    return ''
-  return (
-      f'{AcceleratorTypeToAcceleratorCharacteristics[accelerator_type].accelerator_label}:'
-      f' {system.gke_accelerator}'
-  )
-
-
-def create_tpu_machine_type(accelerator_type, system) -> str:
-  """Generates TPU machine type..
-
-  Args:
-    accelerator_type: type of accelerator.
-    system: system characteristics.
-
-  Returns:
-    The accelerator label.
-  """
-  if accelerator_type == AcceleratorType.TPU:
+def create_tpu_machine_type(system: SystemCharacteristics) -> str:
+  if system.accelerator_type == AcceleratorType.TPU:
     return f'{system.gce_machine_type}'
   return ''
 
 
-def create_machine_label(
-    accelerator_type, system, autoprovisioning_enabled: bool = False
-) -> str:
-  """Generates machine label.
-
-  Args:
-    accelerator_type: type of accelerator.
-    system: system characteristics.
-    autoprovisioning_enabled: describes autoprovisioning enablement.
-
-  Returns:
-    The machine label.
-  """
-  if accelerator_type == AcceleratorType.TPU and not autoprovisioning_enabled:
-    return (
-        f'{AcceleratorTypeToAcceleratorCharacteristics[accelerator_type].machine_label}:'
-        f' {system.topology}'
-    )
-  return ''
-
-
-def create_tpu_topology(
-    accelerator_type, system, autoprovisioning_enabled: bool = False
-) -> str:
-  """Generates TPU topology.
-
-  Args:
-    accelerator_type: type of accelerator.
-    system: system characteristics.
-    autoprovisioning_enabled: describes autoprovisioning enablement.
-
-  Returns:
-    The machine label.
-  """
-  if accelerator_type == AcceleratorType.TPU and not autoprovisioning_enabled:
+def create_tpu_topology(system: SystemCharacteristics) -> str:
+  if system.accelerator_type == AcceleratorType.TPU:
     return f'{system.topology}'
   return ''
 
@@ -299,7 +297,20 @@ def create_sub_slicing_annotations(sub_slicing_topology: str) -> list[str]:
   return [
       (
           'kueue.x-k8s.io/podset-required-topology:'
-          f' "google.com/gke-tpu-slice-{sub_slicing_topology}-id"'
+          f' "{get_slice_topology_level(sub_slicing_topology)}"'
      ),
      f'cloud.google.com/gke-tpu-slice-topology: {sub_slicing_topology}',
  ]
+
+
+def create_placement_policy_label(system: SystemCharacteristics) -> str:
+  name = get_placement_policy_name(system)
+  return f'cloud.google.com/placement-policy-name: {name}'
+
+
+def get_placement_policy_name(system: SystemCharacteristics) -> str:
+  return f'{system.device_type}-{system.topology}-placement-policy'
+
+
+def is_placement_policy_supported(system: SystemCharacteristics) -> bool:
+  return system.requires_workload_policy and is_topology_valid(system.topology)
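
For orientation, a minimal sketch of how a caller might consume the new enum-based API. The `validate_scheduling` wrapper below is hypothetical (it is not the actual xpk call site), and it assumes the caller now fetches the resources ConfigMap itself, since the ConfigMap lookup moved out of scheduling.py in this change:

from argparse import Namespace

from xpk.core.scheduling import (
    WorkloadScheduling,
    check_if_workload_can_schedule,
    create_sub_slicing_annotations,
)


def validate_scheduling(
    args: Namespace,
    workload_system,
    cluster_system,
    resources_config_map: dict[str, str] | None,
) -> bool:
  """Illustrative wrapper: returns True if workload creation should proceed."""
  scheduling = check_if_workload_can_schedule(
      args,
      workload_system=workload_system,
      cluster_system=cluster_system,
      resources_config_map=resources_config_map,
  )
  if scheduling == WorkloadScheduling.UNAVAILABLE:
    # check_if_workload_can_schedule already printed the reason via xpk_print.
    return False
  if scheduling == WorkloadScheduling.SUB_SLICING_AVAILABLE:
    # The workload fits only as a sub-slice of a larger slice; presumably the
    # caller attaches the Kueue sub-slicing annotations at this point.
    _ = create_sub_slicing_annotations(workload_system.topology)
  return True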
xpk/core/scheduling_test.py CHANGED
@@ -14,18 +14,310 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from .scheduling import create_sub_slicing_annotations
+from argparse import Namespace
+from dataclasses import dataclass
+import dataclasses
+import pytest
+from pytest_mock import MockerFixture
+from xpk.core.capacity import AUTOPROVISIONING_CONFIG_MAXIMUM_KEY, AUTOPROVISIONING_CONFIG_VALUE
+from xpk.core.testing.commands_tester import CommandsTester
+from xpk.utils.feature_flags import FeatureFlags
+from .scheduling import WorkloadScheduling, check_if_workload_can_schedule, create_sub_slicing_annotations, create_placement_policy_label, get_placement_policy_name, is_placement_policy_supported
+from .system_characteristics import SystemCharacteristics, AcceleratorType, DockerPlatform, get_system_characteristics_by_device_type
 
 
-def test_create_sub_slicing_annotations_returns_valid_annotations():
-  subslicing_topology = '2x2'
+def _get_system_characteristics_or_die(
+    device_type: str,
+) -> SystemCharacteristics:
+  system = get_system_characteristics_by_device_type(device_type)[0]
+  assert system
+  return system
+
 
-  result = create_sub_slicing_annotations(subslicing_topology)
+@pytest.fixture(autouse=True)
+def commands_tester(mocker: MockerFixture) -> CommandsTester:
+  return CommandsTester(
+      mocker=mocker,
+      run_command_for_value_path='xpk.core.kueue_manager.run_command_for_value',
+  )
+
+
+def test_create_sub_slicing_annotations_returns_valid_annotations():
+  result = create_sub_slicing_annotations(sub_slicing_topology='2x4')
 
   assert result == [
       (
           'kueue.x-k8s.io/podset-required-topology:'
-          ' "google.com/gke-tpu-slice-2x2-id"'
+          ' "cloud.google.com/gke-tpu-slice-2x4-id"'
       ),
-      'cloud.google.com/gke-tpu-slice-topology: 2x2',
+      'cloud.google.com/gke-tpu-slice-topology: 2x4',
   ]
+
+
+def test_create_placement_policy_label_returns_valid_label():
+  system_characteristics = SystemCharacteristics(
+      chips_per_vm=1,
+      gce_machine_type='tpu7x-standard-1t',
+      gke_accelerator='tpu7x',
+      requires_workload_policy=False,
+      topology='1x1x1',
+      vms_per_slice=1,
+      device_type='tpu7x',
+      accelerator_type=AcceleratorType.TPU,
+      supports_sub_slicing=False,
+      docker_platform=DockerPlatform.ARM,
+  )
+  label = create_placement_policy_label(system_characteristics)
+  assert (
+      label
+      == 'cloud.google.com/placement-policy-name: tpu7x-1x1x1-placement-policy'
+  )
+
+
+def test_get_placement_policy_name_returns_valid_name():
+  system_characteristics = SystemCharacteristics(
+      chips_per_vm=1,
+      gce_machine_type='tpu7x-standard-1t',
+      gke_accelerator='tpu7x',
+      requires_workload_policy=False,
+      topology='1x1x1',
+      vms_per_slice=1,
+      device_type='tpu7x',
+      accelerator_type=AcceleratorType.TPU,
+      supports_sub_slicing=False,
+      docker_platform=DockerPlatform.ARM,
+  )
+  name = get_placement_policy_name(system_characteristics)
+  assert name == 'tpu7x-1x1x1-placement-policy'
+
+
+def test_is_placement_policy_supported_returns_true_for_system_characteristics_supporting_workload_policy_and_having_valid_topology():
+  system_characteristics = SystemCharacteristics(
+      chips_per_vm=1,
+      gce_machine_type='tpu7x-standard-1t',
+      gke_accelerator='tpu7x',
+      requires_workload_policy=True,
+      topology='1x1x1',
+      vms_per_slice=1,
+      device_type='tpu7x',
+      accelerator_type=AcceleratorType.TPU,
+      supports_sub_slicing=False,
+      docker_platform=DockerPlatform.ARM,
+  )
+  assert is_placement_policy_supported(system_characteristics) is True
+
+
+def test_is_placement_policy_supported_returns_false_for_system_characteristics_not_supporting_workload_policy_and_having_valid_topology():
+  system_characteristics = SystemCharacteristics(
+      chips_per_vm=1,
+      gce_machine_type='tpu7x-standard-1t',
+      gke_accelerator='tpu7x',
+      requires_workload_policy=False,
+      topology='1x1x1',
+      vms_per_slice=1,
+      device_type='tpu7x',
+      accelerator_type=AcceleratorType.TPU,
+      supports_sub_slicing=False,
+      docker_platform=DockerPlatform.ARM,
+  )
+  assert is_placement_policy_supported(system_characteristics) is False
+
+
+def test_is_placement_policy_supported_returns_false_for_system_characteristics_supporting_workload_policy_and_having_invalid_topology():
+  system_characteristics = SystemCharacteristics(
+      chips_per_vm=1,
+      gce_machine_type='tpu7x-standard-1t',
+      gke_accelerator='tpu7x',
+      requires_workload_policy=True,
+      topology='aaa',
+      vms_per_slice=1,
+      device_type='tpu7x',
+      accelerator_type=AcceleratorType.TPU,
+      supports_sub_slicing=False,
+      docker_platform=DockerPlatform.ARM,
+  )
+  assert is_placement_policy_supported(system_characteristics) is False
+
+
+@dataclass(frozen=True)
+class SchedulingTestCase:
+  workload_system: SystemCharacteristics
+  num_slices: int = 1
+  cluster_system: SystemCharacteristics | None = None
+  resources_config_map: dict[str, str] | None = None
+  sub_slicing_feature_enabled: bool = False
+  kueue_version: str | None = None
+  sub_slicing_topology_set: bool = False
+
+
+SUB_SLICING_CASE = SchedulingTestCase(
+    workload_system=_get_system_characteristics_or_die('v6e-8'),
+    cluster_system=_get_system_characteristics_or_die('v6e-16'),
+    resources_config_map={'v6e-16': '8'},
+    sub_slicing_feature_enabled=True,
+    kueue_version='0.13.0',
+    sub_slicing_topology_set=True,
+    num_slices=1,
+)
+
+NAP_CASE = SchedulingTestCase(
+    workload_system=_get_system_characteristics_or_die('v6e-8'),
+    cluster_system=None,
+    resources_config_map={
+        'tpu-v6e-slice': AUTOPROVISIONING_CONFIG_VALUE,
+        AUTOPROVISIONING_CONFIG_MAXIMUM_KEY: '10',
+    },
+)
+
+
+@pytest.mark.parametrize(
+    'title, case, expected',
+    [
+        (
+            'No resources config map',
+            SchedulingTestCase(
+                workload_system=_get_system_characteristics_or_die('v6e-8'),
+                resources_config_map=None,
+            ),
+            WorkloadScheduling.AVAILABLE,
+        ),
+        (
+            'Cluster system matches and workload fits',
+            SchedulingTestCase(
+                workload_system=_get_system_characteristics_or_die('v6e-8'),
+                resources_config_map={'v6e-8': '8'},
+                num_slices=2,
+            ),
+            WorkloadScheduling.AVAILABLE,
+        ),
+        (
+            'Cluster system does not match',
+            SchedulingTestCase(
+                workload_system=_get_system_characteristics_or_die('v6e-8'),
+                resources_config_map={'tpu7x-32': '16'},
+            ),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            'Workload does not fit',
+            SchedulingTestCase(
+                workload_system=_get_system_characteristics_or_die('v6e-8'),
+                resources_config_map={'v6e-8': '8'},
+                num_slices=100,
+            ),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            'Correct NAP',
+            NAP_CASE,
+            WorkloadScheduling.AVAILABLE,
+        ),
+        (
+            'NAP, too big workload',
+            dataclasses.replace(NAP_CASE, num_slices=100),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            'Correct Sub-slicing',
+            SUB_SLICING_CASE,
+            WorkloadScheduling.SUB_SLICING_AVAILABLE,
+        ),
+        (
+            'Sub-slicing, but disabled flag',
+            dataclasses.replace(
+                SUB_SLICING_CASE, sub_slicing_feature_enabled=False
+            ),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            'Sub-slicing, but low Kueue version',
+            dataclasses.replace(SUB_SLICING_CASE, kueue_version='0.12.0'),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            'Sub-slicing, but no sub-slicing-topology',
+            dataclasses.replace(
+                SUB_SLICING_CASE, sub_slicing_topology_set=False
+            ),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            'Sub-slicing, but workload too big',
+            dataclasses.replace(SUB_SLICING_CASE, num_slices=100),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            'Sub-slicing, but cluster system is incorrect',
+            dataclasses.replace(
+                SUB_SLICING_CASE,
+                cluster_system=_get_system_characteristics_or_die('tpu7x-16'),
+            ),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            'Sub-slicing, but workload system is incorrect',
+            dataclasses.replace(
+                SUB_SLICING_CASE,
+                workload_system=_get_system_characteristics_or_die('tpu7x-8'),
+            ),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            'Sub-slicing, but workload topology is incorrect',
+            dataclasses.replace(
+                SUB_SLICING_CASE,
+                workload_system=_get_system_characteristics_or_die('v6e-2x2'),
+            ),
+            WorkloadScheduling.UNAVAILABLE,
+        ),
+        (
+            (
+                'Sub-slicing should be ignored when a given device is already'
+                ' present in the cluster'
+            ),
+            dataclasses.replace(
+                SUB_SLICING_CASE,
+                workload_system=_get_system_characteristics_or_die('v6e-8'),
+                cluster_system=_get_system_characteristics_or_die('v6e-8'),
+                resources_config_map={'v6e-8': '4'},
+            ),
+            WorkloadScheduling.AVAILABLE,
+        ),
+    ],
+)
+def test_check_if_workload_can_schedule(
+    commands_tester: CommandsTester,
+    title: str,
+    case: SchedulingTestCase,
+    expected: WorkloadScheduling,
+):
+  FeatureFlags.SUB_SLICING_ENABLED = case.sub_slicing_feature_enabled
+  commands_tester.set_result_for_command(
+      (
+          0,
+          f'registry.k8s.io/kueue/kueue:v{case.kueue_version}'
+          if case.kueue_version
+          else '',
+      ),
+      'kubectl get deployment',
+      'image',
+  )
+  commands_tester.set_result_for_command(
+      (0, 'sub-slice-topology' if case.sub_slicing_topology_set else ''),
+      'kubectl get topology',
+  )
+  args = Namespace(
+      cluster='test-cluster',
+      workload='test-workload',
+      num_slices=case.num_slices,
+  )
+
+  assert (
+      check_if_workload_can_schedule(
+          args,
+          workload_system=case.workload_system,
+          cluster_system=case.cluster_system,
+          resources_config_map=case.resources_config_map,
+      )
+      == expected
+  )