xpk 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. xpk/commands/cluster.py +48 -5
  2. xpk/commands/cluster_gcluster.py +3 -0
  3. xpk/commands/cluster_gcluster_test.py +2 -0
  4. xpk/commands/cluster_test.py +203 -0
  5. xpk/commands/common.py +6 -0
  6. xpk/commands/kind.py +2 -0
  7. xpk/commands/workload.py +35 -15
  8. xpk/commands/workload_test.py +1 -0
  9. xpk/core/capacity.py +83 -46
  10. xpk/core/capacity_test.py +82 -28
  11. xpk/core/commands.py +39 -12
  12. xpk/core/kueue_manager.py +42 -11
  13. xpk/core/kueue_manager_test.py +83 -3
  14. xpk/core/nap.py +5 -4
  15. xpk/core/nodepool.py +57 -20
  16. xpk/core/nodepool_test.py +152 -23
  17. xpk/core/pathways.py +2 -1
  18. xpk/core/resources.py +3 -3
  19. xpk/core/scheduling.py +54 -10
  20. xpk/core/scheduling_test.py +118 -13
  21. xpk/core/system_characteristics.py +41 -24
  22. xpk/core/system_characteristics_test.py +37 -4
  23. xpk/core/telemetry.py +5 -0
  24. xpk/core/telemetry_test.py +19 -2
  25. xpk/core/updates.py +1 -1
  26. xpk/main.py +2 -1
  27. xpk/parser/cluster.py +34 -2
  28. xpk/parser/cluster_test.py +117 -0
  29. xpk/parser/common.py +32 -0
  30. xpk/parser/common_test.py +49 -0
  31. xpk/templates/kueue_config.yaml.j2 +21 -5
  32. xpk/templates/kueue_super_slicing_topology.yaml.j2 +9 -0
  33. xpk/utils/kueue.py +6 -2
  34. {xpk-0.16.0.dist-info → xpk-0.17.0.dist-info}/METADATA +2 -1
  35. {xpk-0.16.0.dist-info → xpk-0.17.0.dist-info}/RECORD +39 -37
  36. {xpk-0.16.0.dist-info → xpk-0.17.0.dist-info}/WHEEL +0 -0
  37. {xpk-0.16.0.dist-info → xpk-0.17.0.dist-info}/entry_points.txt +0 -0
  38. {xpk-0.16.0.dist-info → xpk-0.17.0.dist-info}/licenses/LICENSE +0 -0
  39. {xpk-0.16.0.dist-info → xpk-0.17.0.dist-info}/top_level.txt +0 -0
xpk/core/capacity.py CHANGED
@@ -15,10 +15,12 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import enum
18
+ from dataclasses import dataclass
18
19
 
20
+ from .commands import run_command_with_updates, run_command_for_value
21
+ from .system_characteristics import AcceleratorType
19
22
  from ..utils.console import xpk_print, xpk_exit
20
23
  from ..utils.kueue import is_queued_cluster
21
- from .commands import run_command_with_updates, run_command_for_value
22
24
 
23
25
  AUTOPROVISIONING_CONFIG_VALUE = 'AUTOPROVISION'
24
26
  AUTOPROVISIONING_CONFIG_MINIMUM_KEY = 'minimum_chips'
@@ -42,6 +44,14 @@ class CapacityType(enum.Enum):
42
44
  FLEX_START = 'flex_start'
43
45
 
44
46
 
47
+ @dataclass
48
+ class Reservation:
49
+ project: str
50
+ name: str
51
+ block_name: str | None = None
52
+ sub_block_name: str | None = None
53
+
54
+
45
55
  def print_reservations(args) -> int:
46
56
  """Print the reservations in the project.
47
57
 
@@ -107,7 +117,7 @@ def get_capacity_type(args) -> tuple[CapacityType, int]:
107
117
 
108
118
 
109
119
  def get_reservation_maintenance_interval(
110
- reservation: str, zone: str, project: str
120
+ reservation_path: str, zone: str, project: str
111
121
  ) -> str:
112
122
  """Get reservation maintenance interval.
113
123
 
@@ -117,12 +127,10 @@ def get_reservation_maintenance_interval(
117
127
  Returns:
118
128
  0 if successful and 1 otherwise.
119
129
  """
120
- reservation_project, reservation_name = get_reservation_project_and_name(
121
- reservation, project
122
- )
130
+ reservation = parse_reservation(reservation_path, project)
123
131
  command = (
124
- f'gcloud beta compute reservations describe {reservation_name}'
125
- f' --project={reservation_project} --zone={zone} --format="value(specificReservation.instanceProperties.maintenanceInterval)"'
132
+ f'gcloud beta compute reservations describe {reservation.name}'
133
+ f' --project={reservation.project} --zone={zone} --format="value(specificReservation.instanceProperties.maintenanceInterval)"'
126
134
  )
127
135
  return_code, output = run_command_for_value(
128
136
  command, 'Get reservation maintenance interval'
@@ -134,7 +142,7 @@ def get_reservation_maintenance_interval(
134
142
 
135
143
 
136
144
  def get_reservation_placement_policy(
137
- reservation: str, zone: str, project: str
145
+ reservation_path: str, zone: str, project: str
138
146
  ) -> str:
139
147
  """Get reservation placement policy.
140
148
 
@@ -144,12 +152,10 @@ def get_reservation_placement_policy(
144
152
  Returns:
145
153
  0 if successful and 1 otherwise.
146
154
  """
147
- reservation_project, reservation_name = get_reservation_project_and_name(
148
- reservation, project
149
- )
155
+ reservation = parse_reservation(reservation_path, project)
150
156
  command = (
151
- f'gcloud beta compute reservations describe {reservation_name}'
152
- f' --project={reservation_project} --zone={zone} --format="value(resourcePolicies.policy)"'
157
+ f'gcloud beta compute reservations describe {reservation.name}'
158
+ f' --project={reservation.project} --zone={zone} --format="value(resourcePolicies.policy)"'
153
159
  )
154
160
  return_code, output = run_command_for_value(
155
161
  command, 'Get reservation placement policy'
@@ -161,15 +167,13 @@ def get_reservation_placement_policy(
161
167
 
162
168
 
163
169
  def get_reservation_deployment_type(
164
- reservation: str, zone: str, project: str
170
+ reservation_path: str, zone: str, project: str
165
171
  ) -> str:
166
172
  """Get reservation deployment type."""
167
- reservation_project, reservation_name = get_reservation_project_and_name(
168
- reservation, project
169
- )
173
+ reservation = parse_reservation(reservation_path, project)
170
174
  command = (
171
- f'gcloud beta compute reservations describe {reservation_name}'
172
- f' --project={reservation_project} --zone={zone} --format="value(deploymentType)"'
175
+ f'gcloud beta compute reservations describe {reservation.name}'
176
+ f' --project={reservation.project} --zone={zone} --format="value(deploymentType)"'
173
177
  )
174
178
  return_code, output = run_command_for_value(
175
179
  command, 'Get reservation deployment type', dry_run_return_val='DENSE'
@@ -189,12 +193,10 @@ def verify_reservation_exists(args) -> int:
189
193
  Returns:
190
194
  0 if successful and 1 otherwise.
191
195
  """
192
- reservation_project, reservation_name = get_reservation_project_and_name(
193
- args.reservation, args.project
194
- )
196
+ reservation = parse_reservation(args.reservation, args.project)
195
197
  command = (
196
- f'gcloud beta compute reservations describe {reservation_name}'
197
- f' --project={reservation_project} --zone={args.zone}'
198
+ f'gcloud beta compute reservations describe {reservation.name}'
199
+ f' --project={reservation.project} --zone={args.zone}'
198
200
  )
199
201
  return_code = run_command_with_updates(command, 'Describe reservation')
200
202
  if return_code != 0:
@@ -205,7 +207,10 @@ def verify_reservation_exists(args) -> int:
205
207
 
206
208
 
207
209
  def get_capacity_arguments_from_capacity_type(
208
- args, capacity_type: CapacityType, max_nodes: int
210
+ args,
211
+ capacity_type: CapacityType,
212
+ max_nodes: int,
213
+ accelerator_type: AcceleratorType,
209
214
  ) -> tuple[str, int]:
210
215
  """Determine the Nodepool creation capacity arguments needed.
211
216
 
@@ -231,7 +236,7 @@ def get_capacity_arguments_from_capacity_type(
231
236
  ' --location-policy=ANY --reservation-affinity=none'
232
237
  f' --no-enable-autorepair --max-nodes={max_nodes}'
233
238
  )
234
- if is_queued_cluster(args.num_slices):
239
+ if is_queued_cluster(args.num_slices, accelerator_type):
235
240
  capacity_args += ' --enable-queued-provisioning'
236
241
  case CapacityType.RESERVATION:
237
242
  capacity_args = (
@@ -280,27 +285,59 @@ def get_capacity_node_selectors_from_capacity_type(
280
285
  return node_selector, return_code
281
286
 
282
287
 
283
- def get_reservation_project_and_name(
284
- reservation_name_or_path: str, cluster_project: str
285
- ) -> tuple[str, str]:
286
- """Get the reservation project and name.
288
+ def parse_reservation(
289
+ reservation_path: str, cluster_project: str
290
+ ) -> Reservation:
291
+ """Parses the reservation details from the reservation path.
292
+ Also supports reservation blocks and sub-blocks.
293
+ Assumes cluster project if project is not contained in the path.
287
294
 
288
- Args:
289
- reservation_name_or_path: either reservation name or reservation path in format
290
- projects/RESERVATION_PROJECT_ID/reservations/RESERVATION_NAME
291
- cluster_project: the cluster project
295
+ Args:
296
+ reservation_path: path to the reservation, reservation block or sub-block in format:
297
+ `[projects/RESERVATION_PROJECT_ID/reservations/]RESERVATION_NAME[/reservationBlocks/BLOCK_NAME[/reservationSubBlocks/SUB_BLOCK_NAME]]`
298
+ cluster_project: the cluster project
292
299
 
293
- Returns:
294
- Tuple with reservation project and reservation name.
300
+ Returns:
301
+ Reservation instance containing reservation details.
295
302
  """
296
- if '/' not in reservation_name_or_path:
297
- return cluster_project, reservation_name_or_path
298
- reservation_parts = reservation_name_or_path.split('/')
299
- if (
300
- len(reservation_parts) != 4
301
- or reservation_parts[0] != 'projects'
302
- or reservation_parts[2] != 'reservations'
303
- ):
304
- xpk_print('Unable to parse reservation: ', reservation_name_or_path)
303
+ reservation = _try_parse_reservation(reservation_path, cluster_project)
304
+ if reservation is None:
305
+ xpk_print('Unable to parse reservation: ', reservation_path)
305
306
  xpk_exit(1)
306
- return reservation_parts[1], reservation_parts[3]
307
+ return reservation
308
+
309
+
310
+ def _try_parse_reservation(
311
+ reservation_path: str, cluster_project: str
312
+ ) -> Reservation | None:
313
+ # assume trivial case, path contains just the reservation name
314
+ reservation = Reservation(
315
+ project=cluster_project,
316
+ name=reservation_path,
317
+ block_name=None,
318
+ sub_block_name=None,
319
+ )
320
+ parts = reservation_path.split('/')
321
+ if min(map(len, parts)) == 0: # all parts must be non-empty
322
+ return None
323
+ if len(parts) == 1:
324
+ return reservation # trivial case
325
+
326
+ if parts[0] == 'projects':
327
+ reservation.project = parts[1]
328
+ if len(parts) < 4 or parts[2] != 'reservations':
329
+ return None
330
+ parts = parts[3:] # remove projects/PROJECT/reservations/ prefix
331
+
332
+ if len(parts) not in (1, 3, 5):
333
+ return None
334
+ reservation.name = parts[0]
335
+ if len(parts) >= 3:
336
+ if parts[1] != 'reservationBlocks':
337
+ return None
338
+ reservation.block_name = parts[2]
339
+ if len(parts) >= 5:
340
+ if parts[3] != 'reservationSubBlocks':
341
+ return None
342
+ reservation.sub_block_name = parts[4]
343
+ return reservation
xpk/core/capacity_test.py CHANGED
@@ -16,7 +16,7 @@ limitations under the License.
16
16
 
17
17
  import pytest
18
18
  from unittest.mock import MagicMock, patch
19
- from .capacity import get_reservation_deployment_type, get_reservation_project_and_name
19
+ from .capacity import get_reservation_deployment_type, parse_reservation, Reservation
20
20
 
21
21
 
22
22
  @patch('xpk.core.capacity.xpk_print')
@@ -28,7 +28,7 @@ def test_get_reservation_deployment_type_exits_with_command_fails(
28
28
  )
29
29
  with pytest.raises(SystemExit):
30
30
  get_reservation_deployment_type(
31
- reservation='reservation', zone='zone', project='project'
31
+ reservation_path='reservation', zone='zone', project='project'
32
32
  )
33
33
 
34
34
  assert (
@@ -45,37 +45,91 @@ def test_get_reservation_deployment_type_returns_deployment_type_when_command_su
45
45
  return_value=(0, 'DENSE'),
46
46
  )
47
47
  result = get_reservation_deployment_type(
48
- reservation='reservation', zone='zone', project='project'
48
+ reservation_path='reservation', zone='zone', project='project'
49
49
  )
50
50
  assert result == 'DENSE'
51
51
 
52
52
 
53
- def test_get_reservation_project_and_name_parses_local_reservation():
54
- project, name = get_reservation_project_and_name(
55
- 'test-reservation', 'cluster-project'
56
- )
57
-
58
- assert project == 'cluster-project'
59
- assert name == 'test-reservation'
60
-
61
-
62
- def test_get_reservation_project_and_name_parses_shared_reservation():
63
- project, name = get_reservation_project_and_name(
64
- 'projects/reservation-project/reservations/test-reservation',
65
- 'cluster-project',
66
- )
67
-
68
- assert project == 'reservation-project'
69
- assert name == 'test-reservation'
70
-
71
-
53
+ @pytest.mark.parametrize(
54
+ argnames='reservation_path,expected_reservation',
55
+ argvalues=[
56
+ (
57
+ 'reservation',
58
+ Reservation(project='cluster-project', name='reservation'),
59
+ ),
60
+ (
61
+ 'reservation/reservationBlocks/block',
62
+ Reservation(
63
+ project='cluster-project',
64
+ name='reservation',
65
+ block_name='block',
66
+ ),
67
+ ),
68
+ (
69
+ 'reservation/reservationBlocks/block/reservationSubBlocks/subblock',
70
+ Reservation(
71
+ project='cluster-project',
72
+ name='reservation',
73
+ block_name='block',
74
+ sub_block_name='subblock',
75
+ ),
76
+ ),
77
+ (
78
+ 'projects/p/reservations/reservation',
79
+ Reservation(project='p', name='reservation'),
80
+ ),
81
+ (
82
+ 'projects/p/reservations/reservation/reservationBlocks/block',
83
+ Reservation(
84
+ project='p',
85
+ name='reservation',
86
+ block_name='block',
87
+ ),
88
+ ),
89
+ (
90
+ 'projects/p/reservations/reservation/reservationBlocks/block/reservationSubBlocks/subblock',
91
+ Reservation(
92
+ project='p',
93
+ name='reservation',
94
+ block_name='block',
95
+ sub_block_name='subblock',
96
+ ),
97
+ ),
98
+ ],
99
+ )
100
+ def test_parse_reservation_parses_valid_reservations(
101
+ reservation_path: str,
102
+ expected_reservation: Reservation,
103
+ ):
104
+ actual_reservation = parse_reservation(reservation_path, 'cluster-project')
105
+
106
+ assert actual_reservation == expected_reservation
107
+
108
+
109
+ @pytest.mark.parametrize(
110
+ argnames='reservation_path',
111
+ argvalues=[
112
+ '',
113
+ '/name',
114
+ 'name/',
115
+ 'name/reservationBlocks/',
116
+ 'name/reservationBlocks/block/reservationSubBlocks/',
117
+ 'name/reservationBlocks/block/reservationSubBlocks/subblock/extra',
118
+ 'name/reservationBlock/block/reservationSubBlocks/subblock',
119
+ 'name/reservationBlocks/block/reservationSubBlock/subblock',
120
+ 'reservations/name',
121
+ 'project/p/reservations/name',
122
+ 'projects/p/reservation/name',
123
+ 'projects/p/reservations',
124
+ 'projects/p/reservations/name/reservationBlocks/block/reservationSubBlocks/subblock/extra',
125
+ 'projects/p/reservations/name/reservationBlocks//reservationSubBlocks/subblock',
126
+ ],
127
+ )
72
128
  @patch('xpk.core.capacity.xpk_print')
73
- def test_get_reservation_project_and_name_fails_for_invalid_reservation(
74
- xpk_print: MagicMock, mocker
129
+ def test_parse_reservation_fails_on_invalid_reservations(
130
+ xpk_print: MagicMock, reservation_path: str
75
131
  ):
76
132
  with pytest.raises(SystemExit):
77
- get_reservation_project_and_name(
78
- 'invalid/reservation',
79
- 'cluster-project',
80
- )
133
+ parse_reservation(reservation_path, 'cluster-project')
134
+
81
135
  assert 'Unable to parse reservation' in xpk_print.mock_calls[0].args[0]
xpk/core/commands.py CHANGED
@@ -19,13 +19,27 @@ import subprocess
19
19
  import sys
20
20
  import time
21
21
 
22
+ from dataclasses import dataclass
22
23
  from ..utils.objects import chunks
23
24
  from ..utils.file import make_tmp_files, write_tmp_file
24
25
  from ..utils.console import xpk_print
25
26
  from ..utils.execution_context import is_dry_run
26
27
 
27
28
 
28
- def run_commands(commands, jobname, per_command_name, batch=10):
29
+ @dataclass
30
+ class FailedCommand:
31
+ return_code: int
32
+ name: str
33
+ command: str
34
+ logfile: str
35
+
36
+
37
+ def run_commands(
38
+ commands: list[str],
39
+ jobname: str,
40
+ per_command_name: list[str],
41
+ batch: int = 10,
42
+ ) -> FailedCommand | None:
29
43
  """Run commands in groups of `batch`.
30
44
 
31
45
  Args:
@@ -35,8 +49,10 @@ def run_commands(commands, jobname, per_command_name, batch=10):
35
49
  batch: number of commands to run in parallel.
36
50
 
37
51
  Returns:
38
- 0 if successful and 1 otherwise.
52
+ None if all commands were successful, FailedCommand instance containing
53
+ details of a single failing command otherwise
39
54
  """
55
+
40
56
  temporary_files_batches = chunks(make_tmp_files(per_command_name), batch)
41
57
  commands_batched = chunks(commands, batch)
42
58
  per_command_name_batches = chunks(per_command_name, batch)
@@ -47,24 +63,27 @@ def run_commands(commands, jobname, per_command_name, batch=10):
47
63
  )
48
64
  if is_dry_run():
49
65
  xpk_print('Pretending all the jobs succeeded')
50
- return 0
66
+ return None
51
67
 
52
- max_return_code = 0
53
68
  for i, _ in enumerate(commands_batched):
54
69
  xpk_print(f'Dispatching batch {i}/{len(commands_batched)}')
55
- batch_max_return_code, _ = run_command_batch(
70
+ maybe_failure = run_command_batch(
56
71
  commands_batched[i],
57
72
  jobname,
58
73
  per_command_name_batches[i],
59
74
  temporary_files_batches[i],
60
75
  )
61
- max_return_code = max(max_return_code, batch_max_return_code)
62
- if max_return_code > 0:
63
- return max_return_code
64
- return max_return_code
76
+ if maybe_failure is not None:
77
+ return maybe_failure
78
+ return None
65
79
 
66
80
 
67
- def run_command_batch(commands, jobname, per_command_name, output_logs):
81
+ def run_command_batch(
82
+ commands: list[str],
83
+ jobname: str,
84
+ per_command_name: list[str],
85
+ output_logs: list[str],
86
+ ) -> FailedCommand | None:
68
87
  """Runs commands in parallel.
69
88
 
70
89
  Args:
@@ -74,7 +93,8 @@ def run_command_batch(commands, jobname, per_command_name, output_logs):
74
93
  output_logs: list of n log paths, each command will output to each log.
75
94
 
76
95
  Returns:
77
- The max return code and a list of all the return codes.
96
+ None if all commands were successful, FailedCommand instance containing
97
+ details of a single failing command otherwise
78
98
  """
79
99
 
80
100
  files = [open(f, 'w', encoding='utf-8') for f in output_logs]
@@ -86,6 +106,7 @@ def run_command_batch(commands, jobname, per_command_name, output_logs):
86
106
  subprocess.Popen(command, stdout=file, stderr=file, shell=True)
87
107
  )
88
108
 
109
+ maybe_failure: FailedCommand | None = None
89
110
  while True:
90
111
  returncodes = [child.poll() for child in children]
91
112
  max_returncode = max([0] + [r for r in returncodes if r is not None])
@@ -118,6 +139,12 @@ def run_command_batch(commands, jobname, per_command_name, output_logs):
118
139
  )
119
140
  for child in children:
120
141
  child.terminate()
142
+ maybe_failure = FailedCommand(
143
+ return_code=returncodes[failing_index] or 0,
144
+ name=per_command_name[failing_index],
145
+ command=commands[failing_index],
146
+ logfile=output_logs[failing_index],
147
+ )
121
148
  break
122
149
 
123
150
  if completed == total:
@@ -128,7 +155,7 @@ def run_command_batch(commands, jobname, per_command_name, output_logs):
128
155
  for file in files:
129
156
  file.close()
130
157
 
131
- return max_returncode, returncodes
158
+ return maybe_failure
132
159
 
133
160
 
134
161
  def run_command_with_updates_retry(
xpk/core/kueue_manager.py CHANGED
@@ -15,7 +15,6 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import math
18
- import textwrap
19
18
  from dataclasses import dataclass
20
19
  from typing import Optional, List, Dict, Any
21
20
  import json
@@ -48,10 +47,12 @@ WAIT_FOR_KUEUE_TIMEOUT = "10m"
48
47
  CLUSTER_QUEUE_NAME = "cluster-queue"
49
48
  LOCAL_QUEUE_NAME = "multislice-queue"
50
49
  SUB_SLICE_TOPOLOGY_NAME = "sub-slice-topology"
50
+ SUPER_SLICE_TOPOLOGY_NAME = "super-slice-topology"
51
51
  KUEUE_CONFIG_JINJA_FILE = "kueue_config.yaml.j2"
52
52
  KUEUE_GKE_DEFAULT_TOPOLOGY_JINJA_FILE = "kueue_gke_default_topology.yaml.j2"
53
53
  KUEUE_CONTROLLER_MANAGER_JINJA_FILE = "kueue_controller_manager.yaml.j2"
54
54
  KUEUE_SUB_SLICING_TOPOLOGY_JINJA_FILE = "kueue_sub_slicing_topology.yaml.j2"
55
+ KUEUE_SUPER_SLICING_TOPOLOGY_JINJA_FILE = "kueue_super_slicing_topology.yaml.j2"
55
56
  MEMORY_SIZE_PER_VM = 1.2
56
57
  MIN_MEMORY_LIMIT_SIZE = 4096
57
58
 
@@ -63,6 +64,7 @@ class KueueConfig:
63
64
  cpu_limit: int
64
65
  memory_limit: str
65
66
  configure_sub_slicing: bool
67
+ configure_super_slicing: bool
66
68
  is_pathways_cluster: bool = False
67
69
  autoprovisioning_enabled: bool = False
68
70
  flex: bool = False
@@ -268,7 +270,9 @@ class KueueManager:
268
270
  template = self.template_env.get_template(KUEUE_CONFIG_JINJA_FILE)
269
271
 
270
272
  topology_name_and_yaml = self.__get_topology_name_and_yaml(
271
- kueue_config.system, kueue_config.configure_sub_slicing
273
+ kueue_config.system,
274
+ kueue_config.configure_sub_slicing,
275
+ kueue_config.configure_super_slicing,
272
276
  )
273
277
  topology_name = (
274
278
  topology_name_and_yaml.name if topology_name_and_yaml else None
@@ -324,7 +328,11 @@ class KueueManager:
324
328
  key, value = accelerator_label.split(":", 1)
325
329
  node_labels_dict[key] = value.strip()
326
330
 
327
- if not autoprovisioning:
331
+ if system.supports_super_slicing:
332
+ node_labels_dict["cloud.google.com/gke-tpu-partition-4x4x4-state"] = (
333
+ "HEALTHY"
334
+ )
335
+ elif not autoprovisioning:
328
336
  machine_label = create_machine_label(system)
329
337
  if machine_label:
330
338
  key, value = machine_label.split(":", 1)
@@ -374,13 +382,11 @@ class KueueManager:
374
382
  }],
375
383
  })
376
384
 
377
- if flex and is_queued_cluster(num_slices):
378
- admission_checks = textwrap.dedent("""
379
- admissionChecks:
380
- - dws-prov
381
- """)
382
- else:
383
- admission_checks = ""
385
+ admission_checks = []
386
+ if system.supports_super_slicing:
387
+ admission_checks.append("ss-kueue-operator")
388
+ if flex and is_queued_cluster(num_slices, system.accelerator_type):
389
+ admission_checks.append("dws-prov")
384
390
 
385
391
  return {
386
392
  "flavors": flavors,
@@ -393,7 +399,10 @@ class KueueManager:
393
399
  }
394
400
 
395
401
  def __get_topology_name_and_yaml(
396
- self, system: SystemCharacteristics, configure_sub_slicing: bool
402
+ self,
403
+ system: SystemCharacteristics,
404
+ configure_sub_slicing: bool,
405
+ configure_super_slicing: bool,
397
406
  ) -> _NameAndYaml | None:
398
407
  if (
399
408
  system.accelerator_type == AcceleratorType["GPU"]
@@ -427,6 +436,15 @@ class KueueManager:
427
436
  "levels": levels,
428
437
  }),
429
438
  )
439
+ elif configure_super_slicing:
440
+ return _NameAndYaml(
441
+ name=SUPER_SLICE_TOPOLOGY_NAME,
442
+ yaml=self.template_env.get_template(
443
+ KUEUE_SUPER_SLICING_TOPOLOGY_JINJA_FILE
444
+ ).render({
445
+ "super_slice_topology_name": SUPER_SLICE_TOPOLOGY_NAME,
446
+ }),
447
+ )
430
448
  else:
431
449
  return None
432
450
 
@@ -552,6 +570,19 @@ def has_sub_slicing_enabled() -> tuple[int, bool | None]:
552
570
  return return_code, SUB_SLICE_TOPOLOGY_NAME in value
553
571
 
554
572
 
573
+ def has_super_slicing_enabled() -> tuple[int, bool | None]:
574
+ return_code, value = run_command_for_value(
575
+ command="kubectl get topology",
576
+ task="Get defined topologies",
577
+ dry_run_return_val=SUPER_SLICE_TOPOLOGY_NAME,
578
+ )
579
+
580
+ if return_code != 0:
581
+ return return_code, None
582
+
583
+ return return_code, SUPER_SLICE_TOPOLOGY_NAME in value
584
+
585
+
555
586
  def _autocorrect_cpu_limit(cpu_limit: int, cpu_capacity: int) -> int:
556
587
  if cpu_limit > cpu_capacity:
557
588
  xpk_print(