xpk 0.14.3__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. integration/gcluster_a3mega_test.py +11 -0
  2. integration/gcluster_a3ultra_test.py +11 -0
  3. integration/gcluster_a4_test.py +11 -0
  4. xpk/commands/cluster.py +57 -21
  5. xpk/commands/cluster_gcluster.py +25 -5
  6. xpk/commands/cluster_gcluster_test.py +11 -2
  7. xpk/commands/cluster_test.py +233 -12
  8. xpk/commands/config.py +3 -5
  9. xpk/commands/kind.py +1 -1
  10. xpk/commands/storage.py +8 -10
  11. xpk/commands/workload.py +28 -11
  12. xpk/commands/workload_test.py +3 -3
  13. xpk/core/blueprint/blueprint_generator.py +70 -33
  14. xpk/core/blueprint/blueprint_test.py +9 -0
  15. xpk/core/capacity.py +46 -8
  16. xpk/core/capacity_test.py +32 -1
  17. xpk/core/cluster.py +37 -57
  18. xpk/core/cluster_test.py +95 -0
  19. xpk/core/commands.py +4 -10
  20. xpk/core/config.py +9 -2
  21. xpk/core/gcloud_context.py +18 -12
  22. xpk/core/gcloud_context_test.py +111 -1
  23. xpk/core/kjob.py +6 -9
  24. xpk/core/kueue_manager.py +192 -32
  25. xpk/core/kueue_manager_test.py +132 -4
  26. xpk/core/nodepool.py +21 -29
  27. xpk/core/nodepool_test.py +17 -15
  28. xpk/core/scheduling.py +16 -1
  29. xpk/core/scheduling_test.py +85 -6
  30. xpk/core/system_characteristics.py +77 -19
  31. xpk/core/system_characteristics_test.py +80 -5
  32. xpk/core/telemetry.py +263 -0
  33. xpk/core/telemetry_test.py +211 -0
  34. xpk/main.py +31 -13
  35. xpk/parser/cluster.py +48 -9
  36. xpk/parser/cluster_test.py +42 -3
  37. xpk/parser/workload.py +12 -0
  38. xpk/parser/workload_test.py +4 -4
  39. xpk/telemetry_uploader.py +29 -0
  40. xpk/templates/kueue_gke_default_topology.yaml.j2 +1 -1
  41. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +3 -8
  42. xpk/utils/console.py +41 -10
  43. xpk/utils/console_test.py +106 -0
  44. xpk/utils/feature_flags.py +7 -1
  45. xpk/utils/file.py +4 -1
  46. xpk/utils/topology.py +4 -0
  47. xpk/utils/user_agent.py +35 -0
  48. xpk/utils/user_agent_test.py +44 -0
  49. xpk/utils/user_input.py +48 -0
  50. xpk/utils/user_input_test.py +92 -0
  51. xpk/utils/validation.py +0 -11
  52. xpk/utils/versions.py +31 -0
  53. {xpk-0.14.3.dist-info → xpk-0.15.0.dist-info}/METADATA +113 -92
  54. {xpk-0.14.3.dist-info → xpk-0.15.0.dist-info}/RECORD +58 -48
  55. {xpk-0.14.3.dist-info → xpk-0.15.0.dist-info}/WHEEL +0 -0
  56. {xpk-0.14.3.dist-info → xpk-0.15.0.dist-info}/entry_points.txt +0 -0
  57. {xpk-0.14.3.dist-info → xpk-0.15.0.dist-info}/licenses/LICENSE +0 -0
  58. {xpk-0.14.3.dist-info → xpk-0.15.0.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ import yaml
22
22
  from unittest.mock import MagicMock, patch
23
23
 
24
24
  from xpk.core.kueue_manager import KueueConfig, KueueManager, has_sub_slicing_enabled
25
- from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics
25
+ from xpk.core.system_characteristics import AcceleratorType, SystemCharacteristics, UserFacingNameToSystemCharacteristics
26
26
  from xpk.core.testing.commands_tester import CommandsTester
27
27
  from packaging.version import Version
28
28
 
@@ -61,6 +61,13 @@ def set_installed_kueue_version(
61
61
  )
62
62
 
63
63
 
64
+ @pytest.fixture(autouse=True)
65
+ def mock_ask_for_user_consent(mocker: MockerFixture) -> MagicMock:
66
+ return mocker.patch(
67
+ "xpk.core.kueue_manager.ask_for_user_consent", return_value=True
68
+ )
69
+
70
+
64
71
  @pytest.fixture(autouse=True)
65
72
  def mock_commands(mocker: MockerFixture) -> CommandsTester:
66
73
  return CommandsTester(
@@ -78,7 +85,7 @@ def mock_commands(mocker: MockerFixture) -> CommandsTester:
78
85
  @pytest.fixture(autouse=True)
79
86
  @patch("jinja2.Environment", return_value=MagicMock())
80
87
  def kueue_manager(mock_env: MagicMock) -> KueueManager:
81
- return KueueManager()
88
+ return KueueManager("test-project", "test-zone")
82
89
 
83
90
 
84
91
  def test_install_or_upgrade_when_newer_version_already_installed(
@@ -102,7 +109,7 @@ def test_install_or_upgrade_when_outdated(
102
109
  result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)
103
110
 
104
111
  assert result == 0
105
- mock_commands.assert_command_run("kubectl apply", "v0.12.2/manifests.yaml")
112
+ mock_commands.assert_command_run("kubectl apply", "v0.14.3/manifests.yaml")
106
113
  mock_commands.assert_command_run("kubectl apply -f", "/tmp/")
107
114
 
108
115
 
@@ -115,10 +122,84 @@ def test_install_or_upgrade_when_not_installed(
115
122
  result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)
116
123
 
117
124
  assert result == 0
118
- mock_commands.assert_command_run("kubectl apply", "v0.12.2/manifests.yaml")
125
+ mock_commands.assert_command_run("kubectl apply", "v0.14.3/manifests.yaml")
119
126
  mock_commands.assert_command_run("kubectl apply -f", "/tmp/")
120
127
 
121
128
 
129
+ def test_upgrade_when_no_breaking_changes_between_versions_no_preparation_needed(
130
+ mock_commands: CommandsTester,
131
+ kueue_manager: KueueManager,
132
+ mock_ask_for_user_consent: MagicMock,
133
+ ):
134
+ set_installed_kueue_version(mock_commands, Version("0.14.0"))
135
+
136
+ kueue_manager.install_or_upgrade(KUEUE_CONFIG)
137
+
138
+ mock_ask_for_user_consent.assert_not_called()
139
+
140
+
141
+ def test_upgrade_with_breaking_changes_between_versions_runs_preparation(
142
+ mock_commands: CommandsTester,
143
+ kueue_manager: KueueManager,
144
+ mock_ask_for_user_consent: MagicMock,
145
+ ):
146
+ set_installed_kueue_version(mock_commands, Version("0.11.0"))
147
+ fake_crds = (
148
+ "customresourcedefinition.apiextensions.k8s.io/kueue-crd-1.kueue.x-k8s.io\n"
149
+ "customresourcedefinition.apiextensions.k8s.io/kueue-crd-2.kueue.x-k8s.io"
150
+ )
151
+ mock_commands.set_result_for_command(
152
+ (0, fake_crds), "kubectl get crd -o name"
153
+ )
154
+ mock_ask_for_user_consent.return_value = True
155
+
156
+ result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)
157
+
158
+ assert result == 0
159
+ mock_ask_for_user_consent.assert_called_once()
160
+ assert (
161
+ "CHANGELOG/CHANGELOG-0.14.md"
162
+ in mock_ask_for_user_consent.mock_calls[0].args[0]
163
+ )
164
+ mock_commands.assert_command_run(
165
+ "kubectl delete kueue-crd-1.kueue.x-k8s.io --all"
166
+ )
167
+ mock_commands.assert_command_run(
168
+ "kubectl delete kueue-crd-2.kueue.x-k8s.io --all"
169
+ )
170
+ mock_commands.assert_command_run(
171
+ "kubectl delete crd kueue-crd-1.kueue.x-k8s.io"
172
+ )
173
+ mock_commands.assert_command_run(
174
+ "kubectl delete crd kueue-crd-2.kueue.x-k8s.io"
175
+ )
176
+ mock_commands.assert_command_run(
177
+ "kubectl delete deployment kueue-controller-manager"
178
+ )
179
+
180
+
181
+ def test_upgrade_with_breaking_changes_between_versions_does_not_run_preparation_without_consent(
182
+ mock_commands: CommandsTester,
183
+ kueue_manager: KueueManager,
184
+ mock_ask_for_user_consent: MagicMock,
185
+ ):
186
+ set_installed_kueue_version(mock_commands, Version("0.11.0"))
187
+ mock_commands.set_result_for_command(
188
+ (
189
+ 0,
190
+ "customresourcedefinition.apiextensions.k8s.io/kueue-crd-1.kueue.x-k8s.io",
191
+ ),
192
+ "kubectl get crd -o name",
193
+ )
194
+ mock_ask_for_user_consent.return_value = False
195
+
196
+ result = kueue_manager.install_or_upgrade(KUEUE_CONFIG)
197
+
198
+ assert result == 1
199
+ # Assert there was no command run for the Kueue crd:
200
+ mock_commands.assert_command_not_run("kueue-crd-1.kueue.x-k8s.io")
201
+
202
+
122
203
  def test_installation_with_tolerations(
123
204
  mock_commands: CommandsTester, kueue_manager: KueueManager
124
205
  ):
@@ -199,6 +280,10 @@ def test_configure_generates_correct_manifest_for_tpu(
199
280
  ):
200
281
  """Test that __configure generates the correct manifest content for TPUs."""
201
282
  set_installed_kueue_version(mock_commands, None)
283
+ mock_commands.set_result_for_command(
284
+ (0, "100 102400"), "gcloud compute machine-types describe"
285
+ )
286
+
202
287
  tpu_kueue_config = dataclasses.replace(
203
288
  KUEUE_CONFIG, system=TPU_SYSTEM, num_slices=2
204
289
  )
@@ -239,6 +324,39 @@ def test_configure_generates_correct_manifest_for_tpu(
239
324
  )
240
325
 
241
326
 
327
+ @patch("xpk.core.kueue_manager.write_tmp_file")
328
+ def test_install_autocorrects_resource_limits(
329
+ write_tmp_file_mock: MagicMock,
330
+ mock_commands: CommandsTester,
331
+ kueue_manager: KueueManager,
332
+ ):
333
+ """Test that installation auto-corrects the specified resource limits."""
334
+ set_installed_kueue_version(mock_commands, None)
335
+ # set 50 vCPU, 200Gi memory
336
+ mock_commands.set_result_for_command(
337
+ (0, "50 204800"), "gcloud compute machine-types describe"
338
+ )
339
+
340
+ kueue_config = dataclasses.replace(
341
+ KUEUE_CONFIG, cpu_limit=100, memory_limit="100Gi"
342
+ )
343
+
344
+ kueue_manager.install_or_upgrade(kueue_config)
345
+
346
+ rendered_manifest: str = write_tmp_file_mock.call_args[0][0]
347
+ manifest_docs = list(yaml.safe_load_all(rendered_manifest))
348
+ cluster_queue = _first(
349
+ doc for doc in manifest_docs if doc["kind"] == "ClusterQueue"
350
+ )
351
+ resources = cluster_queue["spec"]["resourceGroups"][0]["flavors"][0][
352
+ "resources"
353
+ ]
354
+ cpu_resource = _first(r for r in resources if r["name"] == "cpu")
355
+ memory_resource = _first(r for r in resources if r["name"] == "memory")
356
+ assert cpu_resource["nominalQuota"] == 50
357
+ assert memory_resource["nominalQuota"] == "204800Mi"
358
+
359
+
242
360
  @patch("xpk.core.kueue_manager.write_tmp_file")
243
361
  def test_configure_generates_manifest_with_admission_checks_for_flex_single_slice(
244
362
  write_tmp_file_mock: MagicMock,
@@ -317,6 +435,7 @@ def test_configure_generates_correct_manifest_with_sub_slicing(
317
435
  kueue_config = dataclasses.replace(
318
436
  KUEUE_CONFIG,
319
437
  configure_sub_slicing=True,
438
+ system=UserFacingNameToSystemCharacteristics["v6e-8x8"],
320
439
  )
321
440
 
322
441
  kueue_manager.install_or_upgrade(kueue_config)
@@ -329,6 +448,15 @@ def test_configure_generates_correct_manifest_with_sub_slicing(
329
448
  assert resource_flavor["spec"]["topologyName"] == "sub-slice-topology"
330
449
  topology = _first(doc for doc in manifest_docs if doc["kind"] == "Topology")
331
450
  assert topology["metadata"]["name"] == "sub-slice-topology"
451
+ expected_levels = [
452
+ "cloud.google.com/gke-tpu-slice-8x8-id",
453
+ "cloud.google.com/gke-tpu-slice-4x8-id",
454
+ "cloud.google.com/gke-tpu-slice-4x4-id",
455
+ "cloud.google.com/gke-tpu-slice-2x4-id",
456
+ "kubernetes.io/hostname",
457
+ ]
458
+ actual_levels = [level["nodeLabel"] for level in topology["spec"]["levels"]]
459
+ assert actual_levels == expected_levels
332
460
 
333
461
 
334
462
  @patch("xpk.core.kueue_manager.write_tmp_file")
xpk/core/nodepool.py CHANGED
@@ -15,8 +15,8 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  from typing import List
18
- from ..utils.console import get_user_input, xpk_print
19
- from ..utils.topology import get_topology_product, is_topology_valid
18
+ from ..utils.console import ask_for_user_consent, xpk_print
19
+ from .scheduling import get_placement_policy_name, is_placement_policy_supported
20
20
  from .capacity import (
21
21
  AUTOPROVISIONING_CONFIG_VALUE,
22
22
  H100_MEGA_DEVICE_TYPE,
@@ -110,6 +110,7 @@ def run_gke_node_pool_create_command(
110
110
  existing_node_pool_names, args.cluster, desired_node_pool_count
111
111
  )
112
112
 
113
+ node_pools_to_delete = []
113
114
  node_pools_to_remain = []
114
115
  delete_commands = []
115
116
  delete_task_names = []
@@ -186,14 +187,10 @@ def run_gke_node_pool_create_command(
186
187
  # when cluster is getting updated from 'x' device_type/gke_accelerator to 'y' device_type/gke_accelerator.
187
188
  # In that case, '{args.cluster}-np-i' nodepool will be re-created for 'y' device_type/gke_accelerator.
188
189
  if delete_commands:
189
- will_delete = True
190
- if node_pools_to_delete and not args.force:
191
- will_delete = get_user_input(
192
- f'Planning to delete {len(node_pools_to_delete)} node pools including'
193
- f' {node_pools_to_delete}. \nDo you wish to delete: y (yes) / n'
194
- ' (no):\n'
195
- )
196
- if not will_delete:
190
+ if node_pools_to_delete and not ask_for_user_consent(
191
+ f'Planning to delete {len(node_pools_to_delete)} node pools including'
192
+ f' {node_pools_to_delete}. \nDo you wish to delete?'
193
+ ):
197
194
  xpk_print(
198
195
  'You have requested to not delete the existing nodepools in the'
199
196
  ' cluster. There will be no change to the cluster.'
@@ -215,18 +212,15 @@ def run_gke_node_pool_create_command(
215
212
 
216
213
  # Enable Workload Identity on existing Nodepools
217
214
  if update_WI_commands:
218
- will_update_WI = True
219
- if node_pools_to_update_WI and not args.force:
220
- will_update_WI = get_user_input(
221
- 'Planning to enable Workload Identity Federation on'
222
- f' {len(node_pools_to_update_WI)} existing node pools including'
223
- f' {node_pools_to_update_WI}.This immediately enables Workload'
224
- ' Identity Federation for GKE for any workloads running in the node'
225
- ' pool. Also, xpk does not support disabling Workload Identity on'
226
- ' clusters that have it enabled already \nDo you wish to update: y'
227
- ' (yes) / n (no):\n'
228
- )
229
- if not will_update_WI:
215
+ will_update_WI = not node_pools_to_update_WI or ask_for_user_consent(
216
+ 'Planning to enable Workload Identity Federation on'
217
+ f' {len(node_pools_to_update_WI)} existing node pools including'
218
+ f' {node_pools_to_update_WI}. This immediately enables Workload'
219
+ ' Identity Federation for GKE for any workloads running in the node'
220
+ ' pool. Also, xpk does not support disabling Workload Identity on'
221
+ ' clusters that have it enabled already \nDo you wish to update?'
222
+ )
223
+ if will_update_WI:
230
224
  for i, command in enumerate(update_WI_commands):
231
225
  xpk_print(
232
226
  f'To complete {update_WI_task_names[i]} we are executing {command}'
@@ -264,10 +258,8 @@ def run_gke_node_pool_create_command(
264
258
  return 1
265
259
 
266
260
  placement_args = ''
267
- if system.requires_workload_policy and is_topology_valid(system.topology):
268
- placement_policy = (
269
- f'{system.device_type}-{system.topology}-placement-policy'
270
- )
261
+ if is_placement_policy_supported(system):
262
+ placement_policy = get_placement_policy_name(system)
271
263
  ensure_resource_policy_exists(placement_policy, args, system.topology)
272
264
  placement_args = f' --placement-policy={placement_policy}'
273
265
 
@@ -290,16 +282,16 @@ def run_gke_node_pool_create_command(
290
282
  )
291
283
  if system.accelerator_type == AcceleratorType.TPU:
292
284
  command += f' --node-version={gke_node_pool_version}'
293
- topology_product = get_topology_product(system.topology)
294
285
  if capacity_type == CapacityType.FLEX_START:
295
286
  command += ' --num-nodes=0'
296
- elif topology_product > 1:
287
+ else:
297
288
  command += f' --num-nodes={system.vms_per_slice}'
298
289
  command += (
299
290
  f' --scopes=storage-full,gke-default,{CLOUD_PLATFORM_AUTH_SCOPE_URL}'
300
291
  )
301
292
 
302
- if topology_product > 1:
293
+ # --tpu-topology should not be set for single-host node pools
294
+ if system.vms_per_slice > 1:
303
295
  # --placement-type=COMPACT enables group placement policy which
304
296
  # is mutually exclusive with workload policy, --tpu-topology should
305
297
  # also not be passed when workload policy is used
xpk/core/nodepool_test.py CHANGED
@@ -145,22 +145,24 @@ def mock_nodepool_dependencies(mocker):
145
145
  "xpk.core.nodepool.get_cluster_location", return_value="us-central1"
146
146
  )
147
147
  mocker.patch("xpk.core.nodepool.run_commands", return_value=0)
148
- mocker.patch("xpk.core.nodepool.get_user_input", return_value=True)
149
- mock_is_topology_valid = mocker.patch("xpk.core.nodepool.is_topology_valid")
148
+ mocker.patch("xpk.core.nodepool.ask_for_user_consent", return_value=True)
149
+ mock_is_placement_policy_supported = mocker.patch(
150
+ "xpk.core.nodepool.is_placement_policy_supported"
151
+ )
150
152
  mock_ensure_resource_policy = mocker.patch(
151
153
  "xpk.core.nodepool.ensure_resource_policy_exists"
152
154
  )
153
- return mock_is_topology_valid, mock_ensure_resource_policy
155
+ return mock_is_placement_policy_supported, mock_ensure_resource_policy
154
156
 
155
157
 
156
158
  def test_placement_policy_created_for_gpu_with_valid_topology(
157
159
  mocker, mock_nodepool_dependencies
158
160
  ):
159
161
  """Tests that placement policy is created for GPUs with a valid topology."""
160
- mock_is_topology_valid, mock_ensure_resource_policy = (
162
+ mock_is_placement_policy_supported, mock_ensure_resource_policy = (
161
163
  mock_nodepool_dependencies
162
164
  )
163
- mock_is_topology_valid.return_value = True
165
+ mock_is_placement_policy_supported.return_value = True
164
166
  args = mocker.Mock(
165
167
  tpu_type=None,
166
168
  device_type="h100-80gb-8",
@@ -170,7 +172,7 @@ def test_placement_policy_created_for_gpu_with_valid_topology(
170
172
  )
171
173
  system = SystemCharacteristics(
172
174
  topology="N/A",
173
- vms_per_slice=1,
175
+ vms_per_slice=2,
174
176
  gke_accelerator="nvidia-h100-80gb",
175
177
  gce_machine_type="a3-highgpu-8g",
176
178
  chips_per_vm=8,
@@ -188,10 +190,10 @@ def test_placement_policy_not_created_for_gpu_with_invalid_topology(
188
190
  mocker, mock_nodepool_dependencies
189
191
  ):
190
192
  """Tests that placement policy is not created for GPUs with an invalid topology."""
191
- mock_is_topology_valid, mock_ensure_resource_policy = (
193
+ mock_is_placement_policy_supported, mock_ensure_resource_policy = (
192
194
  mock_nodepool_dependencies
193
195
  )
194
- mock_is_topology_valid.return_value = False
196
+ mock_is_placement_policy_supported.return_value = False
195
197
  args = mocker.Mock(
196
198
  tpu_type=None,
197
199
  device_type="h100-80gb-8",
@@ -200,7 +202,7 @@ def test_placement_policy_not_created_for_gpu_with_invalid_topology(
200
202
  )
201
203
  system = SystemCharacteristics(
202
204
  topology="N/A",
203
- vms_per_slice=1,
205
+ vms_per_slice=2,
204
206
  gke_accelerator="nvidia-h100-80gb",
205
207
  gce_machine_type="a3-highgpu-8g",
206
208
  chips_per_vm=8,
@@ -218,10 +220,10 @@ def test_placement_policy_created_for_tpu7x_with_valid_topology(
218
220
  mocker, mock_nodepool_dependencies
219
221
  ):
220
222
  """Tests that placement policy is created for tpu7x with a valid topology."""
221
- mock_is_topology_valid, mock_ensure_resource_policy = (
223
+ mock_is_placement_policy_supported, mock_ensure_resource_policy = (
222
224
  mock_nodepool_dependencies
223
225
  )
224
- mock_is_topology_valid.return_value = True
226
+ mock_is_placement_policy_supported.return_value = True
225
227
  args = mocker.Mock(
226
228
  tpu_type="tpu7x-8",
227
229
  device_type=None,
@@ -232,7 +234,7 @@ def test_placement_policy_created_for_tpu7x_with_valid_topology(
232
234
  )
233
235
  system = SystemCharacteristics(
234
236
  topology="2x2x1",
235
- vms_per_slice=1,
237
+ vms_per_slice=2,
236
238
  gke_accelerator="tpu7x",
237
239
  gce_machine_type="tpu7x-standard-4t",
238
240
  chips_per_vm=4,
@@ -251,14 +253,14 @@ def test_placement_policy_not_created_for_non7x_tpu(
251
253
  mocker, mock_nodepool_dependencies
252
254
  ):
253
255
  """Tests that placement policy is not created for non-tpu7x TPUs."""
254
- mock_is_topology_valid, mock_ensure_resource_policy = (
256
+ mock_is_placement_policy_supported, mock_ensure_resource_policy = (
255
257
  mock_nodepool_dependencies
256
258
  )
257
- mock_is_topology_valid.return_value = True
259
+ mock_is_placement_policy_supported.return_value = False
258
260
  args = mocker.Mock(
259
261
  tpu_type="v6e",
260
262
  device_type=None,
261
- num_slices=1,
263
+ num_slices=2,
262
264
  cluster="test-cluster",
263
265
  project="test-project",
264
266
  zone="us-central1-a",
xpk/core/scheduling.py CHANGED
@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ from ..utils.topology import get_slice_topology_level
17
18
  from ..utils.console import xpk_print
19
+ from ..utils.topology import is_topology_valid
18
20
  from ..utils.execution_context import is_dry_run
19
21
  from .capacity import AUTOPROVISIONING_CONFIG_MAXIMUM_KEY, AUTOPROVISIONING_CONFIG_VALUE
20
22
  from .resources import CLUSTER_RESOURCES_CONFIGMAP, get_cluster_configmap
@@ -299,7 +301,20 @@ def create_sub_slicing_annotations(sub_slicing_topology: str) -> list[str]:
299
301
  return [
300
302
  (
301
303
  'kueue.x-k8s.io/podset-required-topology:'
302
- f' "google.com/gke-tpu-slice-{sub_slicing_topology}-id"'
304
+ f' "{get_slice_topology_level(sub_slicing_topology)}"'
303
305
  ),
304
306
  f'cloud.google.com/gke-tpu-slice-topology: {sub_slicing_topology}',
305
307
  ]
308
+
309
+
310
+ def create_placement_policy_label(system: SystemCharacteristics) -> str:
311
+ name = get_placement_policy_name(system)
312
+ return f'cloud.google.com/placement-policy-name: {name}'
313
+
314
+
315
+ def get_placement_policy_name(system: SystemCharacteristics) -> str:
316
+ return f'{system.device_type}-{system.topology}-placement-policy'
317
+
318
+
319
+ def is_placement_policy_supported(system: SystemCharacteristics) -> bool:
320
+ return system.requires_workload_policy and is_topology_valid(system.topology)
@@ -14,18 +14,97 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from .scheduling import create_sub_slicing_annotations
17
+ from .scheduling import create_sub_slicing_annotations, create_placement_policy_label, get_placement_policy_name, is_placement_policy_supported
18
+ from .system_characteristics import SystemCharacteristics, AcceleratorType
18
19
 
19
20
 
20
21
  def test_create_sub_slicing_annotations_returns_valid_annotations():
21
- subslicing_topology = '2x2'
22
-
23
- result = create_sub_slicing_annotations(subslicing_topology)
22
+ result = create_sub_slicing_annotations(sub_slicing_topology='2x4')
24
23
 
25
24
  assert result == [
26
25
  (
27
26
  'kueue.x-k8s.io/podset-required-topology:'
28
- ' "google.com/gke-tpu-slice-2x2-id"'
27
+ ' "cloud.google.com/gke-tpu-slice-2x4-id"'
29
28
  ),
30
- 'cloud.google.com/gke-tpu-slice-topology: 2x2',
29
+ 'cloud.google.com/gke-tpu-slice-topology: 2x4',
31
30
  ]
31
+
32
+
33
+ def test_create_placement_policy_label_returns_valid_label():
34
+ system_characteristics = SystemCharacteristics(
35
+ chips_per_vm=1,
36
+ gce_machine_type='tpu7x-standard-1t',
37
+ gke_accelerator='tpu7x',
38
+ requires_workload_policy=False,
39
+ topology='1x1x1',
40
+ vms_per_slice=1,
41
+ device_type='tpu7x',
42
+ accelerator_type=AcceleratorType.TPU,
43
+ supports_sub_slicing=False,
44
+ )
45
+ label = create_placement_policy_label(system_characteristics)
46
+ assert (
47
+ label
48
+ == 'cloud.google.com/placement-policy-name: tpu7x-1x1x1-placement-policy'
49
+ )
50
+
51
+
52
+ def test_get_placement_policy_name_returns_valid_name():
53
+ system_characteristics = SystemCharacteristics(
54
+ chips_per_vm=1,
55
+ gce_machine_type='tpu7x-standard-1t',
56
+ gke_accelerator='tpu7x',
57
+ requires_workload_policy=False,
58
+ topology='1x1x1',
59
+ vms_per_slice=1,
60
+ device_type='tpu7x',
61
+ accelerator_type=AcceleratorType.TPU,
62
+ supports_sub_slicing=False,
63
+ )
64
+ name = get_placement_policy_name(system_characteristics)
65
+ assert name == 'tpu7x-1x1x1-placement-policy'
66
+
67
+
68
+ def test_is_placement_policy_supported_returns_true_for_system_characteristics_supporting_workload_policy_and_having_valid_topology():
69
+ system_characteristics = SystemCharacteristics(
70
+ chips_per_vm=1,
71
+ gce_machine_type='tpu7x-standard-1t',
72
+ gke_accelerator='tpu7x',
73
+ requires_workload_policy=True,
74
+ topology='1x1x1',
75
+ vms_per_slice=1,
76
+ device_type='tpu7x',
77
+ accelerator_type=AcceleratorType.TPU,
78
+ supports_sub_slicing=False,
79
+ )
80
+ assert is_placement_policy_supported(system_characteristics) is True
81
+
82
+
83
+ def test_is_placement_policy_supported_returns_false_for_system_characteristics_not_supporting_workload_policy_and_having_valid_topology():
84
+ system_characteristics = SystemCharacteristics(
85
+ chips_per_vm=1,
86
+ gce_machine_type='tpu7x-standard-1t',
87
+ gke_accelerator='tpu7x',
88
+ requires_workload_policy=False,
89
+ topology='1x1x1',
90
+ vms_per_slice=1,
91
+ device_type='tpu7x',
92
+ accelerator_type=AcceleratorType.TPU,
93
+ supports_sub_slicing=False,
94
+ )
95
+ assert is_placement_policy_supported(system_characteristics) is False
96
+
97
+
98
+ def test_is_placement_policy_supported_returns_false_for_system_characteristics_supporting_workload_policy_and_having_invalid_topology():
99
+ system_characteristics = SystemCharacteristics(
100
+ chips_per_vm=1,
101
+ gce_machine_type='tpu7x-standard-1t',
102
+ gke_accelerator='tpu7x',
103
+ requires_workload_policy=True,
104
+ topology='aaa',
105
+ vms_per_slice=1,
106
+ device_type='tpu7x',
107
+ accelerator_type=AcceleratorType.TPU,
108
+ supports_sub_slicing=False,
109
+ )
110
+ assert is_placement_policy_supported(system_characteristics) is False