xpk 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. xpk/commands/cluster.py +57 -22
  2. xpk/commands/cluster_gcluster_test.py +2 -2
  3. xpk/commands/cluster_test.py +197 -25
  4. xpk/commands/inspector.py +20 -7
  5. xpk/commands/kind.py +1 -1
  6. xpk/commands/workload.py +42 -4
  7. xpk/commands/workload_test.py +88 -5
  8. xpk/core/blueprint/blueprint_definitions.py +16 -1
  9. xpk/core/blueprint/blueprint_generator.py +11 -11
  10. xpk/core/capacity.py +17 -0
  11. xpk/core/capacity_test.py +50 -0
  12. xpk/core/config.py +1 -1
  13. xpk/core/docker_container.py +4 -4
  14. xpk/core/docker_resources.py +11 -11
  15. xpk/core/kjob.py +3 -5
  16. xpk/core/kueue_manager.py +21 -10
  17. xpk/core/kueue_manager_test.py +379 -536
  18. xpk/core/nap.py +1 -1
  19. xpk/core/nodepool.py +9 -9
  20. xpk/core/nodepool_test.py +4 -4
  21. xpk/core/pathways.py +1 -1
  22. xpk/core/resources.py +1 -1
  23. xpk/core/scheduling.py +7 -13
  24. xpk/core/system_characteristics.py +42 -35
  25. xpk/core/system_characteristics_test.py +3 -3
  26. xpk/core/testing/__init__.py +15 -0
  27. xpk/core/testing/commands_tester.py +131 -0
  28. xpk/core/testing/commands_tester_test.py +129 -0
  29. xpk/core/updates.py +57 -0
  30. xpk/core/updates_test.py +80 -0
  31. xpk/main.py +7 -4
  32. xpk/parser/common.py +8 -0
  33. xpk/utils/execution_context.py +20 -2
  34. {xpk-0.14.2.dist-info → xpk-0.14.3.dist-info}/METADATA +1 -3
  35. {xpk-0.14.2.dist-info → xpk-0.14.3.dist-info}/RECORD +39 -33
  36. {xpk-0.14.2.dist-info → xpk-0.14.3.dist-info}/WHEEL +0 -0
  37. {xpk-0.14.2.dist-info → xpk-0.14.3.dist-info}/entry_points.txt +0 -0
  38. {xpk-0.14.2.dist-info → xpk-0.14.3.dist-info}/licenses/LICENSE +0 -0
  39. {xpk-0.14.2.dist-info → xpk-0.14.3.dist-info}/top_level.txt +0 -0
xpk/core/nap.py CHANGED
@@ -85,7 +85,7 @@ def enable_autoprovisioning_on_cluster(
85
85
  # TODO(@vbarr): Support timeout period for idle np before they are deleted.
86
86
  # TODO(@vbarr): Support for hot idle configuration (timeout period is infinity).
87
87
  return_code = 0
88
- if system.accelerator_type == AcceleratorType['CPU']:
88
+ if system.accelerator_type == AcceleratorType.CPU:
89
89
  xpk_print("Error: XPK NAP doesn't support Accelerators of Types: CPUs.")
90
90
  return None, 1
91
91
 
xpk/core/nodepool.py CHANGED
@@ -80,7 +80,7 @@ def run_gke_node_pool_create_command(
80
80
  if return_code > 0:
81
81
  xpk_print('Listing all reservations failed!')
82
82
  return_code = 1
83
- if system.accelerator_type == AcceleratorType['TPU']:
83
+ if system.accelerator_type == AcceleratorType.TPU:
84
84
  max_nodes = system.vms_per_slice
85
85
  else:
86
86
  max_nodes = 1000
@@ -92,16 +92,14 @@ def run_gke_node_pool_create_command(
92
92
  return return_code
93
93
 
94
94
  desired_node_pool_count = (
95
- 1
96
- if system.accelerator_type == AcceleratorType['GPU']
97
- else args.num_slices
95
+ 1 if system.accelerator_type == AcceleratorType.GPU else args.num_slices
98
96
  )
99
97
  message = (
100
98
  (
101
99
  f'Creating 1 node pool with {args.num_nodes} nodes of'
102
100
  f' {system.device_type}\nUnderlyingly, we assume that means: {system}'
103
101
  )
104
- if system.accelerator_type == AcceleratorType['GPU']
102
+ if system.accelerator_type == AcceleratorType.GPU
105
103
  else (
106
104
  f'Creating {args.num_slices} node pool or pools of'
107
105
  f' {system.device_type}\nUnderlyingly, we assume that means: {system}'
@@ -267,7 +265,9 @@ def run_gke_node_pool_create_command(
267
265
 
268
266
  placement_args = ''
269
267
  if system.requires_workload_policy and is_topology_valid(system.topology):
270
- placement_policy = f'{args.cluster}-placement-policy'
268
+ placement_policy = (
269
+ f'{system.device_type}-{system.topology}-placement-policy'
270
+ )
271
271
  ensure_resource_policy_exists(placement_policy, args, system.topology)
272
272
  placement_args = f' --placement-policy={placement_policy}'
273
273
 
@@ -288,7 +288,7 @@ def run_gke_node_pool_create_command(
288
288
  f'{placement_args}'
289
289
  ' --enable-gvnic'
290
290
  )
291
- if system.accelerator_type == AcceleratorType['TPU']:
291
+ if system.accelerator_type == AcceleratorType.TPU:
292
292
  command += f' --node-version={gke_node_pool_version}'
293
293
  topology_product = get_topology_product(system.topology)
294
294
  if capacity_type == CapacityType.FLEX_START:
@@ -308,7 +308,7 @@ def run_gke_node_pool_create_command(
308
308
  command += f' --tpu-topology={system.topology}'
309
309
  command += ' --max-pods-per-node 15'
310
310
  command += f' {args.custom_tpu_nodepool_arguments}'
311
- elif system.accelerator_type == AcceleratorType['GPU']:
311
+ elif system.accelerator_type == AcceleratorType.GPU:
312
312
  subnet_prefix = (
313
313
  f'{args.cluster}-{get_cluster_location(args.project, args.cluster, args.zone)}'
314
314
  )
@@ -328,7 +328,7 @@ def run_gke_node_pool_create_command(
328
328
  f' network={args.cluster}-net-{i},subnetwork={subnet_prefix}-sub-{i}'
329
329
  )
330
330
  command += ' --max-pods-per-node=32'
331
- elif system.accelerator_type == AcceleratorType['CPU']:
331
+ elif system.accelerator_type == AcceleratorType.CPU:
332
332
  if capacity_type == CapacityType.FLEX_START:
333
333
  command += ' --num-nodes=0'
334
334
  else:
xpk/core/nodepool_test.py CHANGED
@@ -174,7 +174,7 @@ def test_placement_policy_created_for_gpu_with_valid_topology(
174
174
  gke_accelerator="nvidia-h100-80gb",
175
175
  gce_machine_type="a3-highgpu-8g",
176
176
  chips_per_vm=8,
177
- accelerator_type=AcceleratorType["GPU"],
177
+ accelerator_type=AcceleratorType.GPU,
178
178
  device_type="h100-80gb-8",
179
179
  supports_sub_slicing=False,
180
180
  )
@@ -204,7 +204,7 @@ def test_placement_policy_not_created_for_gpu_with_invalid_topology(
204
204
  gke_accelerator="nvidia-h100-80gb",
205
205
  gce_machine_type="a3-highgpu-8g",
206
206
  chips_per_vm=8,
207
- accelerator_type=AcceleratorType["GPU"],
207
+ accelerator_type=AcceleratorType.GPU,
208
208
  device_type="h100-80gb-8",
209
209
  supports_sub_slicing=False,
210
210
  )
@@ -236,7 +236,7 @@ def test_placement_policy_created_for_tpu7x_with_valid_topology(
236
236
  gke_accelerator="tpu7x",
237
237
  gce_machine_type="tpu7x-standard-4t",
238
238
  chips_per_vm=4,
239
- accelerator_type=AcceleratorType["TPU"],
239
+ accelerator_type=AcceleratorType.TPU,
240
240
  device_type="tpu7x-8",
241
241
  requires_workload_policy=True,
242
242
  supports_sub_slicing=False,
@@ -269,7 +269,7 @@ def test_placement_policy_not_created_for_non7x_tpu(
269
269
  gke_accelerator="v6e",
270
270
  gce_machine_type="tpu-v6e-slice",
271
271
  chips_per_vm=4,
272
- accelerator_type=AcceleratorType["TPU"],
272
+ accelerator_type=AcceleratorType.TPU,
273
273
  device_type="v6e-4",
274
274
  supports_sub_slicing=True,
275
275
  )
xpk/core/pathways.py CHANGED
@@ -91,7 +91,7 @@ def ensure_pathways_workload_prerequisites(args, system) -> bool:
91
91
  xpk_exit(1)
92
92
 
93
93
  # Ensure device type is TPUs - currently Pathways supports TPUs only.
94
- if system.accelerator_type != AcceleratorType['TPU']:
94
+ if system.accelerator_type != AcceleratorType.TPU:
95
95
  xpk_print('Currently, Pathways workloads can only be run on TPUs.')
96
96
  xpk_exit(1)
97
97
 
xpk/core/resources.py CHANGED
@@ -109,7 +109,7 @@ def create_cluster_configmaps(
109
109
 
110
110
  # ConfigMap to store resources available in the cluster.
111
111
  device_type = system.device_type
112
- if system.accelerator_type == AcceleratorType['GPU']:
112
+ if system.accelerator_type == AcceleratorType.GPU:
113
113
  resources_data = f'{device_type}: "{int(args.num_nodes)}"'
114
114
  elif args.enable_autoprovisioning and autoprovisioning_config:
115
115
  resources_data = (
xpk/core/scheduling.py CHANGED
@@ -97,7 +97,7 @@ def check_if_workload_can_schedule(args, system: SystemCharacteristics) -> bool:
97
97
  else:
98
98
  # Check if the size of the workload will fit in the cluster.
99
99
  max_vm_in_cluster = int(cluster_config_map[device_type])
100
- if system.accelerator_type == AcceleratorType['GPU']:
100
+ if system.accelerator_type == AcceleratorType.GPU:
101
101
  vm_required_by_workload = args.num_nodes
102
102
  else:
103
103
  vm_required_by_workload = args.num_slices * system.vms_per_slice
@@ -125,7 +125,7 @@ def get_total_chips_requested_from_args(
125
125
  Returns:
126
126
  num of chips for the current request.
127
127
  """
128
- if system.accelerator_type == AcceleratorType['GPU']:
128
+ if system.accelerator_type == AcceleratorType.GPU:
129
129
  num_chips = system.vms_per_slice * system.chips_per_vm * args.num_nodes
130
130
  else:
131
131
  num_chips = system.vms_per_slice * system.chips_per_vm * args.num_slices
@@ -152,7 +152,7 @@ def get_cpu_affinity(accelerator_type) -> str:
152
152
  values:
153
153
  - default-pool
154
154
  """
155
- if accelerator_type == AcceleratorType['CPU']:
155
+ if accelerator_type == AcceleratorType.CPU:
156
156
  return yaml
157
157
  return ''
158
158
 
@@ -225,7 +225,7 @@ def create_accelerator_label(accelerator_type, system) -> str:
225
225
  Returns:
226
226
  The accelerator label.
227
227
  """
228
- if accelerator_type == AcceleratorType['CPU']:
228
+ if accelerator_type == AcceleratorType.CPU:
229
229
  return ''
230
230
  return (
231
231
  f'{AcceleratorTypeToAcceleratorCharacteristics[accelerator_type].accelerator_label}:'
@@ -243,7 +243,7 @@ def create_tpu_machine_type(accelerator_type, system) -> str:
243
243
  Returns:
244
244
  The accelerator label.
245
245
  """
246
- if accelerator_type == AcceleratorType['TPU']:
246
+ if accelerator_type == AcceleratorType.TPU:
247
247
  return f'{system.gce_machine_type}'
248
248
  return ''
249
249
 
@@ -261,10 +261,7 @@ def create_machine_label(
261
261
  Returns:
262
262
  The machine label.
263
263
  """
264
- if (
265
- accelerator_type == AcceleratorType['TPU']
266
- and not autoprovisioning_enabled
267
- ):
264
+ if accelerator_type == AcceleratorType.TPU and not autoprovisioning_enabled:
268
265
  return (
269
266
  f'{AcceleratorTypeToAcceleratorCharacteristics[accelerator_type].machine_label}:'
270
267
  f' {system.topology}'
@@ -285,10 +282,7 @@ def create_tpu_topology(
285
282
  Returns:
286
283
  The machine label.
287
284
  """
288
- if (
289
- accelerator_type == AcceleratorType['TPU']
290
- and not autoprovisioning_enabled
291
- ):
285
+ if accelerator_type == AcceleratorType.TPU and not autoprovisioning_enabled:
292
286
  return f'{system.topology}'
293
287
  return ''
294
288
 
@@ -16,9 +16,16 @@ limitations under the License.
16
16
 
17
17
  from dataclasses import dataclass
18
18
  from ..utils.topology import get_topology_product
19
+ from enum import Enum
19
20
 
20
21
 
21
- AcceleratorType = {'TPU': 1, 'GPU': 2, 'CPU': 3}
22
+ class AcceleratorType(Enum):
23
+ TPU = 1
24
+ GPU = 2
25
+ CPU = 3
26
+
27
+ def __repr__(self):
28
+ return self._name_
22
29
 
23
30
 
24
31
  @dataclass
@@ -29,17 +36,17 @@ class AcceleratorCharacteristics:
29
36
 
30
37
 
31
38
  AcceleratorTypeToAcceleratorCharacteristics = {
32
- AcceleratorType['TPU']: AcceleratorCharacteristics(
39
+ AcceleratorType.TPU: AcceleratorCharacteristics(
33
40
  resource_type='google.com/tpu',
34
41
  accelerator_label='cloud.google.com/gke-tpu-accelerator',
35
42
  machine_label='cloud.google.com/gke-tpu-topology',
36
43
  ),
37
- AcceleratorType['GPU']: AcceleratorCharacteristics(
44
+ AcceleratorType.GPU: AcceleratorCharacteristics(
38
45
  resource_type='nvidia.com/gpu',
39
46
  accelerator_label='cloud.google.com/gke-accelerator',
40
47
  machine_label='cloud.google.com/gce-machine-type',
41
48
  ),
42
- AcceleratorType['CPU']: AcceleratorCharacteristics(
49
+ AcceleratorType.CPU: AcceleratorCharacteristics(
43
50
  resource_type='cpu',
44
51
  accelerator_label='',
45
52
  machine_label='cloud.google.com/gke-nodepool',
@@ -80,13 +87,13 @@ class SystemCharacteristics:
80
87
  gke_accelerator: str
81
88
  gce_machine_type: str
82
89
  chips_per_vm: int
83
- accelerator_type: int # TODO: use enums
90
+ accelerator_type: AcceleratorType
84
91
  device_type: str
85
92
  supports_sub_slicing: bool
86
93
  requires_workload_policy: bool = False
87
94
 
88
95
  def __post_init__(self):
89
- if self.accelerator_type == AcceleratorType['GPU']:
96
+ if self.accelerator_type == AcceleratorType.GPU:
90
97
  self.requires_workload_policy = True
91
98
 
92
99
 
@@ -144,7 +151,7 @@ def get_tpu_system_characteristics_map(
144
151
  gke_accelerator=gke_accelerator,
145
152
  gce_machine_type=machine_type,
146
153
  chips_per_vm=chips_per_vm,
147
- accelerator_type=AcceleratorType['TPU'],
154
+ accelerator_type=AcceleratorType.TPU,
148
155
  device_type=f'{prefix}-{num_tensorcores}',
149
156
  requires_workload_policy=requires_workload_policy,
150
157
  supports_sub_slicing=supports_sub_slicing,
@@ -183,7 +190,7 @@ UserFacingNameToSystemCharacteristics = {
183
190
  gke_accelerator='nvidia-l4',
184
191
  gce_machine_type='g2-standard-12',
185
192
  chips_per_vm=1,
186
- accelerator_type=AcceleratorType['GPU'],
193
+ accelerator_type=AcceleratorType.GPU,
187
194
  device_type='l4-1',
188
195
  supports_sub_slicing=False,
189
196
  ),
@@ -193,7 +200,7 @@ UserFacingNameToSystemCharacteristics = {
193
200
  gke_accelerator='nvidia-l4',
194
201
  gce_machine_type='g2-standard-24',
195
202
  chips_per_vm=2,
196
- accelerator_type=AcceleratorType['GPU'],
203
+ accelerator_type=AcceleratorType.GPU,
197
204
  device_type='l4-2',
198
205
  supports_sub_slicing=False,
199
206
  ),
@@ -203,7 +210,7 @@ UserFacingNameToSystemCharacteristics = {
203
210
  gke_accelerator='nvidia-l4',
204
211
  gce_machine_type='g2-standard-48',
205
212
  chips_per_vm=4,
206
- accelerator_type=AcceleratorType['GPU'],
213
+ accelerator_type=AcceleratorType.GPU,
207
214
  device_type='l4-4',
208
215
  supports_sub_slicing=False,
209
216
  ),
@@ -213,7 +220,7 @@ UserFacingNameToSystemCharacteristics = {
213
220
  gke_accelerator='nvidia-l4',
214
221
  gce_machine_type='g2-standard-96',
215
222
  chips_per_vm=8,
216
- accelerator_type=AcceleratorType['GPU'],
223
+ accelerator_type=AcceleratorType.GPU,
217
224
  device_type='l4-8',
218
225
  supports_sub_slicing=False,
219
226
  ),
@@ -224,7 +231,7 @@ UserFacingNameToSystemCharacteristics = {
224
231
  gke_accelerator='nvidia-tesla-a100',
225
232
  gce_machine_type='a2-highgpu-1g',
226
233
  chips_per_vm=1,
227
- accelerator_type=AcceleratorType['GPU'],
234
+ accelerator_type=AcceleratorType.GPU,
228
235
  device_type='a100-40gb-1',
229
236
  supports_sub_slicing=False,
230
237
  ),
@@ -234,7 +241,7 @@ UserFacingNameToSystemCharacteristics = {
234
241
  gke_accelerator='nvidia-tesla-a100',
235
242
  gce_machine_type='a2-highgpu-2g',
236
243
  chips_per_vm=2,
237
- accelerator_type=AcceleratorType['GPU'],
244
+ accelerator_type=AcceleratorType.GPU,
238
245
  device_type='a100-40gb-2',
239
246
  supports_sub_slicing=False,
240
247
  ),
@@ -244,7 +251,7 @@ UserFacingNameToSystemCharacteristics = {
244
251
  gke_accelerator='nvidia-tesla-a100',
245
252
  gce_machine_type='a2-highgpu-4g',
246
253
  chips_per_vm=4,
247
- accelerator_type=AcceleratorType['GPU'],
254
+ accelerator_type=AcceleratorType.GPU,
248
255
  device_type='a100-40gb-4',
249
256
  supports_sub_slicing=False,
250
257
  ),
@@ -254,7 +261,7 @@ UserFacingNameToSystemCharacteristics = {
254
261
  gke_accelerator='nvidia-tesla-a100',
255
262
  gce_machine_type='a2-highgpu-8g',
256
263
  chips_per_vm=8,
257
- accelerator_type=AcceleratorType['GPU'],
264
+ accelerator_type=AcceleratorType.GPU,
258
265
  device_type='a100-40gb-8',
259
266
  supports_sub_slicing=False,
260
267
  ),
@@ -264,7 +271,7 @@ UserFacingNameToSystemCharacteristics = {
264
271
  gke_accelerator='nvidia-gb200',
265
272
  gce_machine_type='a4x-highgpu-4g',
266
273
  chips_per_vm=4,
267
- accelerator_type=AcceleratorType['GPU'],
274
+ accelerator_type=AcceleratorType.GPU,
268
275
  device_type='gb200-4',
269
276
  supports_sub_slicing=False,
270
277
  ),
@@ -274,7 +281,7 @@ UserFacingNameToSystemCharacteristics = {
274
281
  gke_accelerator='nvidia-gb200',
275
282
  gce_machine_type='a4x-highgpu-4g-nolssd',
276
283
  chips_per_vm=4,
277
- accelerator_type=AcceleratorType['GPU'],
284
+ accelerator_type=AcceleratorType.GPU,
278
285
  device_type='gb200-4',
279
286
  supports_sub_slicing=False,
280
287
  ),
@@ -284,7 +291,7 @@ UserFacingNameToSystemCharacteristics = {
284
291
  gke_accelerator='nvidia-b200',
285
292
  gce_machine_type='a4-highgpu-8g',
286
293
  chips_per_vm=8,
287
- accelerator_type=AcceleratorType['GPU'],
294
+ accelerator_type=AcceleratorType.GPU,
288
295
  device_type='b200-8',
289
296
  supports_sub_slicing=False,
290
297
  ),
@@ -294,7 +301,7 @@ UserFacingNameToSystemCharacteristics = {
294
301
  gke_accelerator='nvidia-h200-141gb',
295
302
  gce_machine_type='a3-ultragpu-8g',
296
303
  chips_per_vm=8,
297
- accelerator_type=AcceleratorType['GPU'],
304
+ accelerator_type=AcceleratorType.GPU,
298
305
  device_type='h200-141gb-8',
299
306
  supports_sub_slicing=False,
300
307
  ),
@@ -305,7 +312,7 @@ UserFacingNameToSystemCharacteristics = {
305
312
  gke_accelerator='nvidia-h100-80gb',
306
313
  gce_machine_type='a3-highgpu-8g',
307
314
  chips_per_vm=8,
308
- accelerator_type=AcceleratorType['GPU'],
315
+ accelerator_type=AcceleratorType.GPU,
309
316
  device_type='h100-80gb-8',
310
317
  supports_sub_slicing=False,
311
318
  ),
@@ -316,7 +323,7 @@ UserFacingNameToSystemCharacteristics = {
316
323
  gke_accelerator='nvidia-h100-mega-80gb',
317
324
  gce_machine_type='a3-megagpu-8g',
318
325
  chips_per_vm=8,
319
- accelerator_type=AcceleratorType['GPU'],
326
+ accelerator_type=AcceleratorType.GPU,
320
327
  device_type='h100-mega-80gb-8',
321
328
  supports_sub_slicing=False,
322
329
  ),
@@ -605,7 +612,7 @@ UserFacingNameToSystemCharacteristics = {
605
612
  gke_accelerator='N/A',
606
613
  gce_machine_type='m1-megamem-96',
607
614
  chips_per_vm=96,
608
- accelerator_type=AcceleratorType['CPU'],
615
+ accelerator_type=AcceleratorType.CPU,
609
616
  device_type='m1-megamem-96-1',
610
617
  supports_sub_slicing=False,
611
618
  ),
@@ -616,7 +623,7 @@ UserFacingNameToSystemCharacteristics = {
616
623
  gke_accelerator='N/A',
617
624
  gce_machine_type='n2-standard-64',
618
625
  chips_per_vm=64,
619
- accelerator_type=AcceleratorType['CPU'],
626
+ accelerator_type=AcceleratorType.CPU,
620
627
  device_type='n2-standard-64-1',
621
628
  supports_sub_slicing=False,
622
629
  ),
@@ -626,7 +633,7 @@ UserFacingNameToSystemCharacteristics = {
626
633
  gke_accelerator='N/A',
627
634
  gce_machine_type='n2-standard-32',
628
635
  chips_per_vm=32,
629
- accelerator_type=AcceleratorType['CPU'],
636
+ accelerator_type=AcceleratorType.CPU,
630
637
  device_type='n2-standard-32-1',
631
638
  supports_sub_slicing=False,
632
639
  ),
@@ -636,7 +643,7 @@ UserFacingNameToSystemCharacteristics = {
636
643
  gke_accelerator='N/A',
637
644
  gce_machine_type='n2-standard-32',
638
645
  chips_per_vm=32,
639
- accelerator_type=AcceleratorType['CPU'],
646
+ accelerator_type=AcceleratorType.CPU,
640
647
  device_type='n2-standard-32-2',
641
648
  supports_sub_slicing=False,
642
649
  ),
@@ -646,7 +653,7 @@ UserFacingNameToSystemCharacteristics = {
646
653
  gke_accelerator='N/A',
647
654
  gce_machine_type='n2-standard-32',
648
655
  chips_per_vm=32,
649
- accelerator_type=AcceleratorType['CPU'],
656
+ accelerator_type=AcceleratorType.CPU,
650
657
  device_type='n2-standard-32-4',
651
658
  supports_sub_slicing=False,
652
659
  ),
@@ -656,7 +663,7 @@ UserFacingNameToSystemCharacteristics = {
656
663
  gke_accelerator='N/A',
657
664
  gce_machine_type='n2-standard-32',
658
665
  chips_per_vm=32,
659
- accelerator_type=AcceleratorType['CPU'],
666
+ accelerator_type=AcceleratorType.CPU,
660
667
  device_type='n2-standard-32-8',
661
668
  supports_sub_slicing=False,
662
669
  ),
@@ -666,7 +673,7 @@ UserFacingNameToSystemCharacteristics = {
666
673
  gke_accelerator='N/A',
667
674
  gce_machine_type='n2-standard-32',
668
675
  chips_per_vm=32,
669
- accelerator_type=AcceleratorType['CPU'],
676
+ accelerator_type=AcceleratorType.CPU,
670
677
  device_type='n2-standard-32-16',
671
678
  supports_sub_slicing=False,
672
679
  ),
@@ -676,7 +683,7 @@ UserFacingNameToSystemCharacteristics = {
676
683
  gke_accelerator='N/A',
677
684
  gce_machine_type='n2-standard-32',
678
685
  chips_per_vm=32,
679
- accelerator_type=AcceleratorType['CPU'],
686
+ accelerator_type=AcceleratorType.CPU,
680
687
  device_type='n2-standard-32-32',
681
688
  supports_sub_slicing=False,
682
689
  ),
@@ -686,7 +693,7 @@ UserFacingNameToSystemCharacteristics = {
686
693
  gke_accelerator='N/A',
687
694
  gce_machine_type='n2-standard-32',
688
695
  chips_per_vm=32,
689
- accelerator_type=AcceleratorType['CPU'],
696
+ accelerator_type=AcceleratorType.CPU,
690
697
  device_type='n2-standard-32-64',
691
698
  supports_sub_slicing=False,
692
699
  ),
@@ -696,7 +703,7 @@ UserFacingNameToSystemCharacteristics = {
696
703
  gke_accelerator='N/A',
697
704
  gce_machine_type='n2-standard-32',
698
705
  chips_per_vm=32,
699
- accelerator_type=AcceleratorType['CPU'],
706
+ accelerator_type=AcceleratorType.CPU,
700
707
  device_type='n2-standard-32-128',
701
708
  supports_sub_slicing=False,
702
709
  ),
@@ -706,7 +713,7 @@ UserFacingNameToSystemCharacteristics = {
706
713
  gke_accelerator='N/A',
707
714
  gce_machine_type='n2-standard-32',
708
715
  chips_per_vm=32,
709
- accelerator_type=AcceleratorType['CPU'],
716
+ accelerator_type=AcceleratorType.CPU,
710
717
  device_type='n2-standard-32-256',
711
718
  supports_sub_slicing=False,
712
719
  ),
@@ -716,7 +723,7 @@ UserFacingNameToSystemCharacteristics = {
716
723
  gke_accelerator='N/A',
717
724
  gce_machine_type='n2-standard-32',
718
725
  chips_per_vm=32,
719
- accelerator_type=AcceleratorType['CPU'],
726
+ accelerator_type=AcceleratorType.CPU,
720
727
  device_type='n2-standard-32-512',
721
728
  supports_sub_slicing=False,
722
729
  ),
@@ -726,7 +733,7 @@ UserFacingNameToSystemCharacteristics = {
726
733
  gke_accelerator='N/A',
727
734
  gce_machine_type='n2-standard-32',
728
735
  chips_per_vm=32,
729
- accelerator_type=AcceleratorType['CPU'],
736
+ accelerator_type=AcceleratorType.CPU,
730
737
  device_type='n2-standard-32-1024',
731
738
  supports_sub_slicing=False,
732
739
  ),
@@ -736,7 +743,7 @@ UserFacingNameToSystemCharacteristics = {
736
743
  gke_accelerator='N/A',
737
744
  gce_machine_type='n2-standard-32',
738
745
  chips_per_vm=32,
739
- accelerator_type=AcceleratorType['CPU'],
746
+ accelerator_type=AcceleratorType.CPU,
740
747
  device_type='n2-standard-32-2048',
741
748
  supports_sub_slicing=False,
742
749
  ),
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from .system_characteristics import get_tpu_system_characteristics_map, SystemCharacteristics
17
+ from .system_characteristics import get_tpu_system_characteristics_map, SystemCharacteristics, AcceleratorType
18
18
 
19
19
 
20
20
  def test_get_tpu_system_characteristics_map_returns_correct_values_for_1x1_topology():
@@ -34,7 +34,7 @@ def test_get_tpu_system_characteristics_map_returns_correct_values_for_1x1_topol
34
34
  gke_accelerator="test",
35
35
  gce_machine_type="test",
36
36
  chips_per_vm=1,
37
- accelerator_type=1,
37
+ accelerator_type=AcceleratorType.TPU,
38
38
  device_type="test-1",
39
39
  supports_sub_slicing=False,
40
40
  requires_workload_policy=True,
@@ -62,7 +62,7 @@ def test_get_tpu_system_characteristics_map_returns_correct_values_for_2x2_topol
62
62
  gke_accelerator="test",
63
63
  gce_machine_type="test",
64
64
  chips_per_vm=4,
65
- accelerator_type=1,
65
+ accelerator_type=AcceleratorType.TPU,
66
66
  device_type="test-8",
67
67
  supports_sub_slicing=False,
68
68
  requires_workload_policy=True,
@@ -0,0 +1,15 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
@@ -0,0 +1,131 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ https://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import re
18
+ from pytest_mock import MockerFixture
19
+
20
+
21
+ class CommandsTester:
22
+ """Tester class useful for mocking and asserting command runs."""
23
+
24
+ def __init__(
25
+ self,
26
+ mocker: MockerFixture,
27
+ run_command_for_value_path: str | None = None,
28
+ run_command_with_updates_path: str | None = None,
29
+ run_command_with_updates_retry_path: str | None = None,
30
+ ):
31
+ self.__results: dict[re.Pattern, tuple[int, str]] = {}
32
+ self.commands_history: list[str] = []
33
+ if run_command_for_value_path:
34
+ mocker.patch(
35
+ run_command_for_value_path,
36
+ wraps=self.__fake_run_command_for_value,
37
+ )
38
+ if run_command_with_updates_path:
39
+ mocker.patch(
40
+ run_command_with_updates_path,
41
+ wraps=self.__fake_run_command_with_updates,
42
+ )
43
+ if run_command_with_updates_retry_path:
44
+ mocker.patch(
45
+ run_command_with_updates_retry_path,
46
+ wraps=self.__fake_run_command_with_updates_retry,
47
+ )
48
+
49
+ def set_result_for_command(
50
+ self, result: tuple[int, str], *command_parts: str
51
+ ):
52
+ """Sets the result for the given command parts.
53
+ The command parts will be joined with '.*' during comparison with the actual commands.
54
+ """
55
+ pattern = self.__get_pattern_for_command_parts(*command_parts)
56
+ self.__results[pattern] = result
57
+
58
+ def assert_command_run(self, *command_parts: str, times: int = 1):
59
+ """Asserts the command composed from the command parts (joined with '.*') was run exactly `times` times."""
60
+ matching = self.get_matching_commands(*command_parts)
61
+ if not matching:
62
+ raise AssertionError(
63
+ f"{command_parts} was not found in {self.commands_history}"
64
+ )
65
+ elif len(matching) != times:
66
+ raise AssertionError(
67
+ f"{command_parts} was expected to be run {times} times in"
68
+ f" {self.commands_history}"
69
+ )
70
+
71
+ def assert_command_not_run(self, *command_parts: str):
72
+ """Asserts the command composed from the command parts (joined with '.*') was never run."""
73
+ if self.get_matching_commands(*command_parts):
74
+ raise AssertionError(
75
+ f"{command_parts} was found in {self.commands_history}"
76
+ )
77
+
78
+ def get_matching_commands(self, *command_parts: str) -> list[str]:
79
+ """Returns list of already run commands matching the command parts (joined with '.*')."""
80
+ pattern = self.__get_pattern_for_command_parts(*command_parts)
81
+ return [c for c in self.commands_history if pattern.match(c)]
82
+
83
+ # Unused arguments, but the signature has to match the original one:
84
+ # pylint: disable=unused-argument
85
+ def __fake_run_command_with_updates(
86
+ self,
87
+ command: str,
88
+ task: str,
89
+ verbose=True,
90
+ ) -> int:
91
+ return self.__common_fake_run_command(command, (0, ""))[0]
92
+
93
+ def __fake_run_command_with_updates_retry(
94
+ self,
95
+ command: str,
96
+ task: str,
97
+ verbose=True,
98
+ num_retry_attempts=5,
99
+ wait_seconds=10,
100
+ ) -> int:
101
+ return self.__common_fake_run_command(command, (0, ""))[0]
102
+
103
+ def __fake_run_command_for_value(
104
+ self,
105
+ command: str,
106
+ task: str,
107
+ dry_run_return_val="0",
108
+ print_timer=False,
109
+ hide_error=False,
110
+ quiet=False,
111
+ ) -> tuple[int, str]:
112
+ return self.__common_fake_run_command(command, (0, dry_run_return_val))
113
+
114
+ # pylint: enable=unused-argument
115
+
116
+ def __common_fake_run_command(
117
+ self,
118
+ command: str,
119
+ default_result: tuple[int, str],
120
+ ) -> tuple[int, str]:
121
+ self.commands_history.append(command)
122
+ matching_results = [
123
+ result
124
+ for pattern, result in self.__results.items()
125
+ if pattern.match(command)
126
+ ]
127
+ return len(matching_results) > 0 and matching_results[0] or default_result
128
+
129
+ def __get_pattern_for_command_parts(self, *command_parts: str) -> re.Pattern:
130
+ pattern_s = ".*" + ".*".join(map(re.escape, command_parts)) + ".*"
131
+ return re.compile(pattern_s)