xpk-0.14.0-py3-none-any.whl → xpk-0.14.1-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. integration/__init__.py +15 -0
  2. integration/docker_manager_test.py +102 -0
  3. integration/gcluster_a3mega_test.py +204 -0
  4. integration/gcluster_a3ultra_test.py +176 -0
  5. integration/gcluster_a4_test.py +176 -0
  6. integration/gcluster_test.py +107 -0
  7. xpk/commands/cluster.py +17 -4
  8. xpk/commands/cluster_gcluster.py +4 -0
  9. xpk/commands/cluster_test.py +92 -0
  10. xpk/commands/common.py +6 -0
  11. xpk/commands/kind.py +1 -0
  12. xpk/commands/workload.py +41 -7
  13. xpk/commands/workload_test.py +81 -0
  14. xpk/core/blueprint/testing/__init__.py +15 -0
  15. xpk/core/config.py +1 -1
  16. xpk/core/kueue_manager.py +60 -20
  17. xpk/core/kueue_manager_test.py +52 -20
  18. xpk/core/system_characteristics.py +16 -4
  19. xpk/core/system_characteristics_test.py +73 -0
  20. xpk/templates/cluster_preheat.yaml.j2 +31 -0
  21. xpk/templates/filestore-pv.yaml +17 -0
  22. xpk/templates/filestore-pvc.yaml +11 -0
  23. xpk/templates/filestore-sc.yaml +10 -0
  24. xpk/templates/fuse-pv.yaml +17 -0
  25. xpk/templates/fuse-pvc.yaml +13 -0
  26. xpk/templates/kueue_config.yaml.j2 +95 -0
  27. xpk/templates/kueue_gke_default_topology.yaml.j2 +10 -0
  28. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +14 -0
  29. xpk/templates/mtc-cpc.yaml +15 -0
  30. xpk/templates/volume_bundle.yaml +7 -0
  31. xpk/utils/templates.py +14 -1
  32. xpk/utils/topology.py +9 -0
  33. xpk/utils/topology_test.py +21 -1
  34. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/METADATA +1 -1
  35. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/RECORD +39 -18
  36. xpk-0.14.1.dist-info/top_level.txt +2 -0
  37. xpk-0.14.0.dist-info/top_level.txt +0 -1
  38. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/WHEEL +0 -0
  39. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/entry_points.txt +0 -0
  40. {xpk-0.14.0.dist-info → xpk-0.14.1.dist-info}/licenses/LICENSE +0 -0
@@ -76,9 +76,7 @@ class KueueManagerTest(unittest.TestCase):
76
76
  mock_install.assert_called_once()
77
77
  mock_configure.assert_called_once()
78
78
 
79
- @patch(
80
- "xpk.core.kueue_manager.KueueManager._KueueManager__get_installed_kueue_version"
81
- )
79
+ @patch("xpk.core.kueue_manager.KueueManager.get_installed_kueue_version")
82
80
  @patch("xpk.core.kueue_manager.KueueManager._KueueManager__install")
83
81
  @patch("xpk.core.kueue_manager.KueueManager._KueueManager__configure")
84
82
  def test_install_or_upgrade_when_newer_version_already_installed(
@@ -95,9 +93,7 @@ class KueueManagerTest(unittest.TestCase):
95
93
  mock_install.assert_not_called()
96
94
  mock_configure.assert_not_called()
97
95
 
98
- @patch(
99
- "xpk.core.kueue_manager.KueueManager._KueueManager__get_installed_kueue_version"
100
- )
96
+ @patch("xpk.core.kueue_manager.KueueManager.get_installed_kueue_version")
101
97
  def test_install_or_upgrade_when_outdated(
102
98
  self,
103
99
  mock_get_version,
@@ -121,9 +117,7 @@ class KueueManagerTest(unittest.TestCase):
121
117
  mock_install.assert_called_once()
122
118
  mock_configure.assert_called_once()
123
119
 
124
- @patch(
125
- "xpk.core.kueue_manager.KueueManager._KueueManager__get_installed_kueue_version"
126
- )
120
+ @patch("xpk.core.kueue_manager.KueueManager.get_installed_kueue_version")
127
121
  def test_install_or_upgrade_when_not_installed(
128
122
  self,
129
123
  mock_get_version,
@@ -155,7 +149,7 @@ class KueueManagerTest(unittest.TestCase):
155
149
  return_value=0,
156
150
  ) as mock_run_retry,
157
151
  patch(
158
- "xpk.core.kueue_manager.KueueManager._KueueManager__get_installed_kueue_version",
152
+ "xpk.core.kueue_manager.KueueManager.get_installed_kueue_version",
159
153
  return_value=(1, None),
160
154
  ),
161
155
  patch(
@@ -199,7 +193,7 @@ class KueueManagerTest(unittest.TestCase):
199
193
  return_value=0,
200
194
  ) as mock_run_retry,
201
195
  patch(
202
- "xpk.core.kueue_manager.KueueManager._KueueManager__get_installed_kueue_version",
196
+ "xpk.core.kueue_manager.KueueManager.get_installed_kueue_version",
203
197
  return_value=(1, None),
204
198
  ),
205
199
  patch(
@@ -224,9 +218,7 @@ class KueueManagerTest(unittest.TestCase):
224
218
  self.assertEqual(result, 0)
225
219
  self.assertEqual(mock_run_retry.call_count, 0)
226
220
 
227
- @patch(
228
- "xpk.core.kueue_manager.KueueManager._KueueManager__get_installed_kueue_version"
229
- )
221
+ @patch("xpk.core.kueue_manager.KueueManager.get_installed_kueue_version")
230
222
  @patch("xpk.core.kueue_manager.KueueManager._KueueManager__apply_manifest")
231
223
  def test_configuration_updates_resources(
232
224
  self, mock_apply_manifest, mock_get_version
@@ -240,6 +232,7 @@ class KueueManagerTest(unittest.TestCase):
240
232
  total_chips=8,
241
233
  cpu_limit=100,
242
234
  memory_limit="100Gi",
235
+ configure_sub_slicing=False,
243
236
  )
244
237
 
245
238
  with (
@@ -265,6 +258,7 @@ class KueueManagerTest(unittest.TestCase):
265
258
  total_chips=8,
266
259
  cpu_limit=100,
267
260
  memory_limit="100Gi",
261
+ configure_sub_slicing=False,
268
262
  )
269
263
 
270
264
  with (
@@ -274,7 +268,7 @@ class KueueManagerTest(unittest.TestCase):
274
268
  ),
275
269
  patch.object(
276
270
  self.kueue_manager,
277
- "_KueueManager__get_installed_kueue_version",
271
+ "get_installed_kueue_version",
278
272
  return_value=(1, None),
279
273
  ),
280
274
  patch.object(
@@ -307,6 +301,7 @@ class KueueManagerTest(unittest.TestCase):
307
301
  total_chips=8,
308
302
  cpu_limit=100,
309
303
  memory_limit="100Gi",
304
+ configure_sub_slicing=False,
310
305
  )
311
306
 
312
307
  with (
@@ -316,7 +311,7 @@ class KueueManagerTest(unittest.TestCase):
316
311
  ),
317
312
  patch.object(
318
313
  self.kueue_manager,
319
- "_KueueManager__get_installed_kueue_version",
314
+ "get_installed_kueue_version",
320
315
  return_value=(1, None),
321
316
  ),
322
317
  patch.object(
@@ -344,7 +339,7 @@ class KueueManagerTest(unittest.TestCase):
344
339
  @patch(
345
340
  "xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary"
346
341
  )
347
- def test_configure_generates_correct_manifest(
342
+ def test_configure_generates_correct_manifest_for_tpu(
348
343
  self, mock_update_resources, mock_install
349
344
  ):
350
345
  """Test that __configure generates the correct manifest content for TPUs."""
@@ -357,6 +352,7 @@ class KueueManagerTest(unittest.TestCase):
357
352
  memory_limit="100Gi",
358
353
  autoprovisioning_enabled=False,
359
354
  num_slices=2,
355
+ configure_sub_slicing=False,
360
356
  )
361
357
 
362
358
  rendered_manifest = self._trigger_installation(kueue_config)
@@ -413,6 +409,7 @@ class KueueManagerTest(unittest.TestCase):
413
409
  autoprovisioning_enabled=False,
414
410
  num_slices=1,
415
411
  flex=True,
412
+ configure_sub_slicing=False,
416
413
  )
417
414
 
418
415
  rendered_manifest = self._trigger_installation(kueue_config)
@@ -432,7 +429,7 @@ class KueueManagerTest(unittest.TestCase):
432
429
  @patch(
433
430
  "xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary"
434
431
  )
435
- def test_configure_generates_correct_manifest_with_topology(
432
+ def test_configure_generates_correct_manifest_with_gke_default_topology(
436
433
  self, mock_update_resources, mock_install
437
434
  ):
438
435
  """Test that __configure generates correct manifest for GPUs."""
@@ -444,11 +441,11 @@ class KueueManagerTest(unittest.TestCase):
444
441
  cpu_limit=100,
445
442
  memory_limit="100Gi",
446
443
  num_slices=2,
444
+ configure_sub_slicing=False,
447
445
  )
448
446
 
449
447
  rendered_manifest = self._trigger_installation(kueue_config)
450
448
 
451
- self.assertIn("kind: Topology", rendered_manifest)
452
449
  manifest_docs = list(yaml.safe_load_all(rendered_manifest))
453
450
  resource_flavor = _first(
454
451
  doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor"
@@ -459,6 +456,40 @@ class KueueManagerTest(unittest.TestCase):
459
456
  ],
460
457
  "h100-mega-80gb-8",
461
458
  )
459
+ self.assertEqual(resource_flavor["spec"]["topologyName"], "gke-default")
460
+ topology = _first(doc for doc in manifest_docs if doc["kind"] == "Topology")
461
+ self.assertEqual(topology["metadata"]["name"], "gke-default")
462
+
463
+ @patch("xpk.core.kueue_manager.KueueManager._KueueManager__install")
464
+ @patch(
465
+ "xpk.core.kueue_manager.KueueManager._KueueManager__update_kueue_resources_if_necessary"
466
+ )
467
+ def test_configure_generates_correct_manifest_with_sub_slicing(
468
+ self, mock_update_resources, mock_install
469
+ ):
470
+ """Test that __configure generates correct manifest with sub-slicing topology."""
471
+ mock_install.return_value = 0
472
+ mock_update_resources.return_value = 0
473
+ kueue_config = KueueConfig(
474
+ system=self.mock_system_chars,
475
+ total_chips=16,
476
+ cpu_limit=100,
477
+ memory_limit="100Gi",
478
+ num_slices=2,
479
+ configure_sub_slicing=True,
480
+ )
481
+
482
+ rendered_manifest = self._trigger_installation(kueue_config)
483
+
484
+ manifest_docs = list(yaml.safe_load_all(rendered_manifest))
485
+ resource_flavor = _first(
486
+ doc for doc in manifest_docs if doc["kind"] == "ResourceFlavor"
487
+ )
488
+ self.assertEqual(
489
+ resource_flavor["spec"]["topologyName"], "sub-slice-topology"
490
+ )
491
+ topology = _first(doc for doc in manifest_docs if doc["kind"] == "Topology")
492
+ self.assertEqual(topology["metadata"]["name"], "sub-slice-topology")
462
493
 
463
494
  @patch("xpk.core.kueue_manager.KueueManager._KueueManager__install")
464
495
  @patch(
@@ -477,6 +508,7 @@ class KueueManagerTest(unittest.TestCase):
477
508
  memory_limit="100Gi",
478
509
  is_pathways_cluster=True,
479
510
  num_slices=2,
511
+ configure_sub_slicing=False,
480
512
  )
481
513
 
482
514
  rendered_manifest = self._trigger_installation(kueue_config)
@@ -513,7 +545,7 @@ class KueueManagerTest(unittest.TestCase):
513
545
  """Calls Kueue installation and returns the rendered manifest."""
514
546
  with (
515
547
  patch.object(
516
- self.kueue_manager, "_KueueManager__get_installed_kueue_version"
548
+ self.kueue_manager, "get_installed_kueue_version"
517
549
  ) as mock_get_version,
518
550
  patch.object(
519
551
  self.kueue_manager, "_KueueManager__apply_manifest"
@@ -135,10 +135,9 @@ def get_tpu_system_characteristics_map(
135
135
  ) -> dict[str, SystemCharacteristics]:
136
136
  system_characteristics_map = {}
137
137
  for topology in supported_topologies:
138
- total_chips = get_topology_product(topology)
139
- num_tensorcores = total_chips * tensorcores_per_chip
140
- chips_per_vm = 1 if total_chips == 1 else 4
141
- vms_per_slice = total_chips // chips_per_vm
138
+ chips_per_vm = compute_chips_per_vm(topology)
139
+ vms_per_slice = compute_vms_per_slice(topology)
140
+ num_tensorcores = compute_num_tensorcores(tensorcores_per_chip, topology)
142
141
  system = SystemCharacteristics(
143
142
  topology=topology,
144
143
  vms_per_slice=vms_per_slice,
@@ -156,6 +155,19 @@ def get_tpu_system_characteristics_map(
156
155
  return system_characteristics_map
157
156
 
158
157
 
158
+ def compute_chips_per_vm(topology: str) -> int:
159
+ return 1 if get_topology_product(topology) == 1 else 4
160
+
161
+
162
+ def compute_num_tensorcores(tensorcores_per_chip: int, topology: str) -> int:
163
+ return get_topology_product(topology) * tensorcores_per_chip
164
+
165
+
166
+ def compute_vms_per_slice(topology: str) -> int:
167
+ chips_per_vm = compute_chips_per_vm(topology)
168
+ return get_topology_product(topology) // chips_per_vm
169
+
170
+
159
171
  ################### Subcommand Helper Functions #############################
160
172
  """ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
161
173
  IF YOU MODIFY THE BELOW UserFacingNameToSystemCharacteristics MAP YOU SHOULD
@@ -0,0 +1,73 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from .system_characteristics import get_tpu_system_characteristics_map, SystemCharacteristics
18
+
19
+
20
+ def test_get_tpu_system_characteristics_map_returns_correct_values_for_1x1_topology():
21
+ result = get_tpu_system_characteristics_map(
22
+ prefix="test",
23
+ tensorcores_per_chip=1,
24
+ gke_accelerator="test",
25
+ machine_type="test",
26
+ supported_topologies=["1x1"],
27
+ supports_sub_slicing=False,
28
+ requires_workload_policy=True,
29
+ )
30
+
31
+ expected_system_characteristics = SystemCharacteristics(
32
+ topology="1x1",
33
+ vms_per_slice=1,
34
+ gke_accelerator="test",
35
+ gce_machine_type="test",
36
+ chips_per_vm=1,
37
+ accelerator_type=1,
38
+ device_type="test-1",
39
+ supports_sub_slicing=False,
40
+ requires_workload_policy=True,
41
+ )
42
+ assert result == {
43
+ "test-1": expected_system_characteristics,
44
+ "test-1x1": expected_system_characteristics,
45
+ }
46
+
47
+
48
+ def test_get_tpu_system_characteristics_map_returns_correct_values_for_2x2_topology():
49
+ result = get_tpu_system_characteristics_map(
50
+ prefix="test",
51
+ tensorcores_per_chip=2,
52
+ gke_accelerator="test",
53
+ machine_type="test",
54
+ supported_topologies=["2x2"],
55
+ supports_sub_slicing=False,
56
+ requires_workload_policy=True,
57
+ )
58
+
59
+ expected_system_characteristics = SystemCharacteristics(
60
+ topology="2x2",
61
+ vms_per_slice=1,
62
+ gke_accelerator="test",
63
+ gce_machine_type="test",
64
+ chips_per_vm=4,
65
+ accelerator_type=1,
66
+ device_type="test-8",
67
+ supports_sub_slicing=False,
68
+ requires_workload_policy=True,
69
+ )
70
+ assert result == {
71
+ "test-8": expected_system_characteristics,
72
+ "test-2x2": expected_system_characteristics,
73
+ }
@@ -0,0 +1,31 @@
1
+ apiVersion: apps/v1
2
+ kind: DaemonSet
3
+ metadata:
4
+ name: {{ cachekey }}
5
+ labels:
6
+ k8s-app: {{ cachekey }}
7
+ spec:
8
+ selector:
9
+ matchLabels:
10
+ k8s-app: {{ cachekey }}
11
+ updateStrategy:
12
+ type: RollingUpdate
13
+ template:
14
+ metadata:
15
+ labels:
16
+ name: {{ cachekey }}
17
+ k8s-app: {{ cachekey }}
18
+ spec:
19
+ affinity:
20
+ nodeAffinity:
21
+ requiredDuringSchedulingIgnoredDuringExecution:
22
+ nodeSelectorTerms:
23
+ - matchExpressions:
24
+ - key: {{ nodeSelectorKey }}
25
+ operator: Exists
26
+ tolerations:
27
+ - operator: "Exists"
28
+ containers:
29
+ - image: {{ image_name }}
30
+ name: {{ cachekey }}
31
+ command: [ "sleep", "inf" ]
@@ -0,0 +1,17 @@
1
+ apiVersion: v1
2
+ kind: PersistentVolume
3
+ metadata:
4
+ name: xpk-filestore-pv
5
+ spec:
6
+ storageClassName:
7
+ capacity:
8
+ storage:
9
+ accessModes:
10
+ persistentVolumeReclaimPolicy: Retain
11
+ volumeMode: Filesystem
12
+ csi:
13
+ driver: filestore.csi.storage.gke.io
14
+ volumeHandle:
15
+ volumeAttributes:
16
+ ip:
17
+ volume:
@@ -0,0 +1,11 @@
1
+ kind: PersistentVolumeClaim
2
+ apiVersion: v1
3
+ metadata:
4
+ name:
5
+ spec:
6
+ accessModes:
7
+ storageClassName:
8
+ volumeName:
9
+ resources:
10
+ requests:
11
+ storage:
@@ -0,0 +1,10 @@
1
+ apiVersion: storage.k8s.io/v1
2
+ kind: StorageClass
3
+ metadata:
4
+ name:
5
+ provisioner: filestore.csi.storage.gke.io
6
+ volumeBindingMode: Immediate
7
+ allowVolumeExpansion: true
8
+ parameters:
9
+ tier: standard
10
+ network: default
@@ -0,0 +1,17 @@
1
+ apiVersion: v1
2
+ kind: PersistentVolume
3
+ metadata:
4
+ name:
5
+ spec:
6
+ accessModes:
7
+ - ReadWriteMany
8
+ capacity:
9
+ storage:
10
+ storageClassName: example-storage-class
11
+ mountOptions:
12
+ - implicit-dirs
13
+ csi:
14
+ driver: gcsfuse.csi.storage.gke.io
15
+ volumeHandle:
16
+ volumeAttributes:
17
+ gcsfuseLoggingSeverity: warning
@@ -0,0 +1,13 @@
1
+ apiVersion: v1
2
+ kind: PersistentVolumeClaim
3
+ metadata:
4
+ name:
5
+ namespace: default
6
+ spec:
7
+ accessModes:
8
+ - ReadWriteMany
9
+ resources:
10
+ requests:
11
+ storage:
12
+ volumeName:
13
+ storageClassName: example-storage-class
@@ -0,0 +1,95 @@
1
+ {% for flavor in flavors %}
2
+ apiVersion: kueue.x-k8s.io/v1beta1
3
+ kind: ResourceFlavor
4
+ metadata:
5
+ name: "{{ flavor.name }}"
6
+ spec:
7
+ nodeLabels: {{ flavor.nodeLabels | tojson }}
8
+ {% if flavor.topologyLabel %}
9
+ {{ flavor.topologyLabel }}
10
+ {% endif %}
11
+ ---
12
+ {% endfor %}
13
+ apiVersion: kueue.x-k8s.io/v1beta1
14
+ kind: AdmissionCheck
15
+ metadata:
16
+ name: dws-prov
17
+ spec:
18
+ controllerName: kueue.x-k8s.io/provisioning-request
19
+ parameters:
20
+ apiGroup: kueue.x-k8s.io
21
+ kind: ProvisioningRequestConfig
22
+ name: dws-config
23
+ ---
24
+ apiVersion: kueue.x-k8s.io/v1beta1
25
+ kind: ProvisioningRequestConfig
26
+ metadata:
27
+ name: dws-config
28
+ spec:
29
+ provisioningClassName: queued-provisioning.gke.io
30
+ podSetUpdates:
31
+ nodeSelector:
32
+ - key: autoscaling.gke.io/provisioning-request
33
+ valueFromProvisioningClassDetail: ResizeRequestName
34
+ managedResources:
35
+ - {{ managed_resource }}
36
+ ---
37
+ apiVersion: kueue.x-k8s.io/v1beta1
38
+ kind: ClusterQueue
39
+ metadata:
40
+ name: "{{ cluster_queue_name }}"
41
+ spec:
42
+ preemption:
43
+ reclaimWithinCohort: Never # Don't preempt other queues in the cohort.
44
+ withinClusterQueue: LowerPriority
45
+ namespaceSelector: {} # match all.
46
+ resourceGroups: {{ resource_groups }}
47
+ {{ admission_checks | indent(2) }}
48
+ ---
49
+ apiVersion: kueue.x-k8s.io/v1beta1
50
+ kind: LocalQueue
51
+ metadata:
52
+ namespace: default
53
+ name: {{ local_queue_name }}
54
+ spec:
55
+ clusterQueue: {{ cluster_queue_name }}
56
+ ---
57
+ apiVersion: scheduling.k8s.io/v1
58
+ kind: PriorityClass
59
+ metadata:
60
+ name: very-low
61
+ value: 100
62
+ globalDefault: false
63
+ description: "Very Low"
64
+ ---
65
+ apiVersion: scheduling.k8s.io/v1
66
+ kind: PriorityClass
67
+ metadata:
68
+ name: low
69
+ value: 250
70
+ globalDefault: false
71
+ description: "Low"
72
+ ---
73
+ apiVersion: scheduling.k8s.io/v1
74
+ kind: PriorityClass
75
+ metadata:
76
+ name: medium
77
+ value: 500
78
+ globalDefault: false
79
+ description: "Medium"
80
+ ---
81
+ apiVersion: scheduling.k8s.io/v1
82
+ kind: PriorityClass
83
+ metadata:
84
+ name: high
85
+ value: 750
86
+ globalDefault: false
87
+ description: "High"
88
+ ---
89
+ apiVersion: scheduling.k8s.io/v1
90
+ kind: PriorityClass
91
+ metadata:
92
+ name: very-high
93
+ value: 1000
94
+ globalDefault: false
95
+ description: "Very High"
@@ -0,0 +1,10 @@
1
+ apiVersion: kueue.x-k8s.io/v1beta1
2
+ kind: Topology
3
+ metadata:
4
+ name: "gke-default"
5
+ spec:
6
+ levels:
7
+ - nodeLabel: "cloud.google.com/gce-topology-block"
8
+ - nodeLabel: "cloud.google.com/gce-topology-subblock"
9
+ - nodeLabel: "cloud.google.com/gce-topology-host"
10
+ - nodeLabel: "kubernetes.io/hostname"
@@ -0,0 +1,14 @@
1
+ apiVersion: kueue.x-k8s.io/v1beta1
2
+ kind: Topology
3
+ metadata:
4
+ name: {{ sub_slice_topology_name }}
5
+ spec:
6
+ levels:
7
+ - nodeLabel: "cloud.google.com/gke-tpu-slice-16x16-id"
8
+ - nodeLabel: "cloud.google.com/gke-tpu-slice-8x16-id"
9
+ - nodeLabel: "cloud.google.com/gke-tpu-slice-8x8-id"
10
+ - nodeLabel: "cloud.google.com/gke-tpu-slice-4x8-id"
11
+ - nodeLabel: "cloud.google.com/gke-tpu-slice-4x4-id"
12
+ - nodeLabel: "cloud.google.com/gke-tpu-slice-2x4-id"
13
+ - nodeLabel: "cloud.google.com/gke-tpu-slice-2x2-id"
14
+ - nodeLabel: "kubernetes.io/hostname"
@@ -0,0 +1,15 @@
1
+ apiVersion: checkpointing.gke.io/v1
2
+ kind: CheckpointConfiguration
3
+ metadata:
4
+ name: my-checkpointconfiguration
5
+ spec:
6
+ cloudStorageBucketName:
7
+ # This field is optional
8
+ nodeSelector:
9
+ node.kubernetes.io/instance-type:
10
+ # This field is optional
11
+ tolerations:
12
+ - key:
13
+ operator: Exists
14
+ effect: NoSchedule
15
+ inMemoryVolumeSize:
@@ -0,0 +1,7 @@
1
+ apiVersion: kjobctl.x-k8s.io/v1alpha1
2
+ kind: VolumeBundle
3
+ metadata:
4
+ name: $NAME
5
+ spec:
6
+ volumes: []
7
+ containerVolumeMounts: []
xpk/utils/templates.py CHANGED
@@ -18,7 +18,7 @@ import os
18
18
 
19
19
  import ruamel.yaml
20
20
 
21
- TEMPLATE_PATH = "src/xpk/templates/"
21
+ TEMPLATE_PATH = "templates"
22
22
 
23
23
  yaml = ruamel.yaml.YAML()
24
24
 
@@ -28,3 +28,16 @@ def load(path: str) -> dict:
28
28
  with open(template_path, "r", encoding="utf-8") as file:
29
29
  data: dict = yaml.load(file)
30
30
  return data
31
+
32
+
33
+ def get_templates_absolute_path(templates_path: str = TEMPLATE_PATH) -> str:
34
+ """
35
+ Return the absolute path to the templates folder
36
+
37
+ Args:
38
+ templates_path: The path to the templates folder relative to the src/xpk directory
39
+ """
40
+ current_file_path = os.path.abspath(__file__)
41
+ current_dir = os.path.dirname(current_file_path)
42
+ xpk_package_dir = os.path.dirname(current_dir)
43
+ return os.path.join(xpk_package_dir, templates_path)
xpk/utils/topology.py CHANGED
@@ -35,3 +35,12 @@ def parse_topology(topology: str) -> list[int]:
35
35
  raise ValueError("Topology is an empty string")
36
36
 
37
37
  return [int(el) for el in topology.lower().split("x")]
38
+
39
+
40
+ def is_topology_contained(contained: str, container: str) -> bool:
41
+ contained_parsed = parse_topology(contained)
42
+ container_parsed = parse_topology(container)
43
+ return len(contained_parsed) == len(container_parsed) and all(
44
+ contained <= container
45
+ for contained, container in zip(contained_parsed, container_parsed)
46
+ )
@@ -15,7 +15,7 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import pytest
18
- from .topology import is_topology_valid, get_topology_product, parse_topology
18
+ from .topology import is_topology_valid, get_topology_product, parse_topology, is_topology_contained
19
19
 
20
20
 
21
21
  def test_is_topology_valid_with_invalid_topology():
@@ -41,3 +41,23 @@ def test_parse_topology_with_empty_input():
41
41
  def test_get_topology_product():
42
42
  result = get_topology_product("1x2x3")
43
43
  assert result == 6
44
+
45
+
46
+ def test_is_topology_contained_with_container_smaller_than_contained_returns_false():
47
+ result = is_topology_contained(contained="3x3x3", container="2x2x2")
48
+ assert result is False
49
+
50
+
51
+ def test_is_topology_contained_with_container_larger_than_contained_returns_true():
52
+ result = is_topology_contained(contained="1x1x1", container="2x2x2")
53
+ assert result is True
54
+
55
+
56
+ def test_is_topology_contained_with_container_equal_to_contained_returns_true():
57
+ result = is_topology_contained(contained="2x2x2", container="2x2x2")
58
+ assert result is True
59
+
60
+
61
+ def test_is_topology_contained_with_different_topologies_dimensions_returns_false():
62
+ result = is_topology_contained(contained="2x2", container="2x2x2")
63
+ assert result is False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xpk
3
- Version: 0.14.0
3
+ Version: 0.14.1
4
4
  Summary: xpk helps Cloud developers to orchestrate training jobs on accelerators on GKE.
5
5
  Author-email: XPK team <xpk-code-reviewers@google.com>
6
6
  License: Apache-2.0