xpk 0.14.4__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. integration/README.md +19 -0
  2. integration/gcluster_a3mega_test.py +11 -0
  3. integration/gcluster_a3ultra_test.py +11 -0
  4. integration/gcluster_a4_test.py +11 -0
  5. xpk/blueprints/a3mega/config-map.yaml.tftpl +15 -0
  6. xpk/blueprints/a3mega/storage_crd.yaml +52 -0
  7. xpk/blueprints/a3ultra/config-map.yaml.tftpl +15 -0
  8. xpk/blueprints/a3ultra/mlgru-disable.yaml +59 -0
  9. xpk/blueprints/a3ultra/nccl-installer.yaml +95 -0
  10. xpk/blueprints/a3ultra/storage_crd.yaml +52 -0
  11. xpk/blueprints/a4/config-map.yaml.tftpl +15 -0
  12. xpk/blueprints/a4/nccl-rdma-installer-a4.yaml +66 -0
  13. xpk/blueprints/a4/storage_crd.yaml +52 -0
  14. xpk/commands/cluster.py +89 -32
  15. xpk/commands/cluster_gcluster.py +25 -5
  16. xpk/commands/cluster_gcluster_test.py +16 -3
  17. xpk/commands/cluster_test.py +353 -7
  18. xpk/commands/config.py +3 -5
  19. xpk/commands/inspector.py +5 -3
  20. xpk/commands/kind.py +3 -1
  21. xpk/commands/managed_ml_diagnostics.py +249 -0
  22. xpk/commands/managed_ml_diagnostics_test.py +146 -0
  23. xpk/commands/storage.py +8 -10
  24. xpk/commands/workload.py +143 -142
  25. xpk/commands/workload_test.py +160 -118
  26. xpk/core/blueprint/blueprint_generator.py +73 -33
  27. xpk/core/blueprint/blueprint_test.py +9 -0
  28. xpk/core/blueprint/testing/data/a3_mega.yaml +129 -0
  29. xpk/core/blueprint/testing/data/a3_mega_spot.yaml +125 -0
  30. xpk/core/blueprint/testing/data/a3_ultra.yaml +173 -0
  31. xpk/core/blueprint/testing/data/a4.yaml +185 -0
  32. xpk/core/capacity.py +48 -8
  33. xpk/core/capacity_test.py +32 -1
  34. xpk/core/cluster.py +55 -104
  35. xpk/core/cluster_test.py +170 -0
  36. xpk/core/commands.py +4 -10
  37. xpk/core/config.py +88 -7
  38. xpk/core/config_test.py +67 -11
  39. xpk/core/docker_container.py +3 -1
  40. xpk/core/docker_image.py +10 -6
  41. xpk/core/docker_resources.py +1 -10
  42. xpk/core/gcloud_context.py +18 -12
  43. xpk/core/gcloud_context_test.py +111 -1
  44. xpk/core/kjob.py +17 -19
  45. xpk/core/kueue_manager.py +205 -51
  46. xpk/core/kueue_manager_test.py +158 -4
  47. xpk/core/nap.py +13 -14
  48. xpk/core/nodepool.py +37 -43
  49. xpk/core/nodepool_test.py +42 -19
  50. xpk/core/pathways.py +23 -0
  51. xpk/core/pathways_test.py +57 -0
  52. xpk/core/resources.py +84 -27
  53. xpk/core/scheduling.py +144 -133
  54. xpk/core/scheduling_test.py +298 -6
  55. xpk/core/system_characteristics.py +256 -19
  56. xpk/core/system_characteristics_test.py +128 -5
  57. xpk/core/telemetry.py +263 -0
  58. xpk/core/telemetry_test.py +211 -0
  59. xpk/core/vertex.py +4 -3
  60. xpk/core/workload_decorators/tcpx_decorator.py +5 -1
  61. xpk/main.py +33 -13
  62. xpk/parser/cluster.py +40 -67
  63. xpk/parser/cluster_test.py +83 -3
  64. xpk/parser/common.py +84 -0
  65. xpk/parser/storage.py +10 -0
  66. xpk/parser/storage_test.py +47 -0
  67. xpk/parser/workload.py +14 -29
  68. xpk/parser/workload_test.py +3 -49
  69. xpk/telemetry_uploader.py +29 -0
  70. xpk/templates/arm_gpu_workload_crate.yaml.j2 +46 -0
  71. xpk/templates/kueue_gke_default_topology.yaml.j2 +1 -1
  72. xpk/templates/kueue_sub_slicing_topology.yaml.j2 +3 -8
  73. xpk/utils/console.py +41 -10
  74. xpk/utils/console_test.py +106 -0
  75. xpk/utils/feature_flags.py +10 -1
  76. xpk/utils/file.py +4 -1
  77. xpk/utils/topology.py +4 -0
  78. xpk/utils/user_agent.py +35 -0
  79. xpk/utils/user_agent_test.py +44 -0
  80. xpk/utils/user_input.py +48 -0
  81. xpk/utils/user_input_test.py +92 -0
  82. xpk/utils/validation.py +2 -13
  83. xpk/utils/versions.py +31 -0
  84. xpk-0.16.0.dist-info/METADATA +127 -0
  85. xpk-0.16.0.dist-info/RECORD +168 -0
  86. xpk-0.14.4.dist-info/METADATA +0 -1645
  87. xpk-0.14.4.dist-info/RECORD +0 -139
  88. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/WHEEL +0 -0
  89. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/entry_points.txt +0 -0
  90. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/licenses/LICENSE +0 -0
  91. {xpk-0.14.4.dist-info → xpk-0.16.0.dist-info}/top_level.txt +0 -0
@@ -15,9 +15,29 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  from dataclasses import dataclass
18
+ import dataclasses
19
+ from typing import Callable, Literal, Optional
20
+
21
+ from ..core.workload_decorators import rdma_decorator, tcpxo_decorator, tcpx_decorator
18
22
  from ..utils.topology import get_topology_product
19
23
  from enum import Enum
20
24
 
25
+ SUB_SLICING_TOPOLOGIES = ['2x4', '4x4', '4x8', '8x8', '8x16', '16x16']
26
+
27
+ INSTALLER_NCCL_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-tcpx-installer.yaml'
28
+ INSTALLER_NCCL_TCPXO = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpxo/nccl-tcpxo-installer.yaml'
29
+ INSTALLER_NCCL_RDMA = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-rdma/nccl-rdma-installer.yaml'
30
+ INSTALLER_NCCL_RDMA_A4X = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-rdma/nccl-rdma-installer-a4x.yaml'
31
+
32
+
33
+ class DockerPlatform(str, Enum):
34
+ AMD = 'linux/amd64'
35
+ ARM = 'linux/arm64'
36
+
37
+
38
+ AMD_PLATFORM = DockerPlatform.AMD
39
+ ARM_PLATFORM = DockerPlatform.ARM
40
+
21
41
 
22
42
  class AcceleratorType(Enum):
23
43
  TPU = 1
@@ -54,6 +74,45 @@ AcceleratorTypeToAcceleratorCharacteristics = {
54
74
  }
55
75
 
56
76
 
77
+ @dataclass
78
+ class GpuConfig:
79
+ """Contains GPU-specific configuration and requirements."""
80
+
81
+ requires_topology: bool
82
+ gpu_direct_name: Literal['fastrak', 'rdma', 'tcpx', 'tcpxo'] = 'fastrak'
83
+ kjob_decorator_fn: Optional[Callable[[dict], dict]] = None
84
+ """A function to decorate the kjob template for GPU-specific configurations.
85
+
86
+ Args:
87
+ job_manifest (dict): The kjob manifest as a dictionary.
88
+
89
+ Returns:
90
+ dict: The modified kjob manifest as a dictionary.
91
+ """
92
+ nccl_installer: Optional[str] = None
93
+ jobset_decorator_fn: Optional[Callable[[str, list[str]], str]] = None
94
+ """A function to decorate the jobset for GPU-specific configurations.
95
+
96
+ Args:
97
+ jobset_manifest_str (str): The JobSet manifest as a YAML string.
98
+ sub_networks (list[str], optional): A list of sub-network names, used by some decorators.
99
+
100
+ Returns:
101
+ str: The modified JobSet manifest as a YAML string.
102
+ """
103
+
104
+ def __repr__(self) -> str:
105
+ """Returns a string representation of the GpuConfig, omitting memory addresses for functions."""
106
+ parts = []
107
+ for f in dataclasses.fields(self):
108
+ value = getattr(self, f.name)
109
+ if f.name in ('kjob_decorator_fn', 'jobset_decorator_fn') and value:
110
+ parts.append(f'{f.name}=<function {value.__name__}>')
111
+ else:
112
+ parts.append(f'{f.name}={repr(value)}')
113
+ return f"GpuConfig({', '.join(parts)})"
114
+
115
+
57
116
  @dataclass
58
117
  class SystemCharacteristics:
59
118
  """Contains the defining characteristics of a specific accelerator system.
@@ -90,12 +149,28 @@ class SystemCharacteristics:
90
149
  accelerator_type: AcceleratorType
91
150
  device_type: str
92
151
  supports_sub_slicing: bool
152
+ docker_platform: DockerPlatform
93
153
  requires_workload_policy: bool = False
154
+ gpu_config: Optional[GpuConfig] = None
94
155
 
95
156
  def __post_init__(self):
96
157
  if self.accelerator_type == AcceleratorType.GPU:
97
158
  self.requires_workload_policy = True
98
159
 
160
+ if self.gpu_config is None:
161
+ raise ValueError(
162
+ f"Validation Error: System '{self.device_type}' is a GPU, "
163
+ "but 'gpu_config' was not provided."
164
+ )
165
+
166
+ @property
167
+ def gpu_requires_topology(self) -> bool:
168
+ """
169
+ Safely returns whether the GPU config requires topology,
170
+ defaulting to False if no GPU config exists.
171
+ """
172
+ return self.gpu_config.requires_topology if self.gpu_config else False
173
+
99
174
 
100
175
  def get_system_characteristics(
101
176
  args,
@@ -131,6 +206,33 @@ def get_system_characteristics_by_device_type(
131
206
  return None, 1
132
207
 
133
208
 
209
+ def generate_tpu_topologies(
210
+ max_cubes: int, enforce_nondecreasing: bool = True
211
+ ) -> list[str]:
212
+ """Generates a list of unique TPU topologies formatted as strings "AxBxC".
213
+
214
+ The list will contain all triplets (A, B, C) such that:
215
+ - A, B and C are integers in range 4..256 (including 4 and 256)
216
+ - A, B and C are divisible by 4
217
+ - (A/4) * (B/4) * (C/4) <= max_cubes
218
+ - if enforce_nondecreasing: A <= B <= C
219
+ Additionally, the list will also contain the following triplets:
220
+ 2x2x1, 2x2x2, 2x2x4, 2x4x4
221
+
222
+ Args:
223
+ max_cubes: maximum number of cubes supported by a TPU platform
224
+ enforce_nondecreasing: whether to enforce A <= B <= C or not
225
+ """
226
+ topologies = ['2x2x1', '2x2x2', '2x2x4', '2x4x4']
227
+ MAX = 256
228
+ for x in range(4, MAX + 1, 4):
229
+ for y in range(x if enforce_nondecreasing else 4, MAX + 1, 4):
230
+ for z in range(y if enforce_nondecreasing else 4, MAX + 1, 4):
231
+ if (x // 4) * (y // 4) * (z // 4) <= max_cubes:
232
+ topologies.append(f'{x}x{y}x{z}')
233
+ return topologies
234
+
235
+
134
236
  def get_tpu_system_characteristics_map(
135
237
  prefix: str,
136
238
  tensorcores_per_chip: int,
@@ -138,13 +240,18 @@ def get_tpu_system_characteristics_map(
138
240
  machine_type: str,
139
241
  supported_topologies: list[str],
140
242
  supports_sub_slicing: bool,
141
- requires_workload_policy: bool = False,
243
+ docker_platform: DockerPlatform,
244
+ tpu_type_requires_workload_policy: bool = False,
245
+ default_topologies: set[str] | None = None,
142
246
  ) -> dict[str, SystemCharacteristics]:
143
247
  system_characteristics_map = {}
248
+ if default_topologies is None:
249
+ default_topologies = set()
144
250
  for topology in supported_topologies:
145
251
  chips_per_vm = compute_chips_per_vm(topology)
146
252
  vms_per_slice = compute_vms_per_slice(topology)
147
253
  num_tensorcores = compute_num_tensorcores(tensorcores_per_chip, topology)
254
+ device_type = f'{prefix}-{num_tensorcores}'
148
255
  system = SystemCharacteristics(
149
256
  topology=topology,
150
257
  vms_per_slice=vms_per_slice,
@@ -152,12 +259,18 @@ def get_tpu_system_characteristics_map(
152
259
  gce_machine_type=machine_type,
153
260
  chips_per_vm=chips_per_vm,
154
261
  accelerator_type=AcceleratorType.TPU,
155
- device_type=f'{prefix}-{num_tensorcores}',
156
- requires_workload_policy=requires_workload_policy,
262
+ device_type=device_type,
263
+ requires_workload_policy=tpu_type_requires_workload_policy
264
+ and vms_per_slice > 1,
157
265
  supports_sub_slicing=supports_sub_slicing,
266
+ docker_platform=docker_platform,
158
267
  )
159
268
  system_characteristics_map[f'{prefix}-{topology}'] = system
160
- system_characteristics_map[f'{prefix}-{num_tensorcores}'] = system
269
+ if (
270
+ topology in default_topologies
271
+ or device_type not in system_characteristics_map
272
+ ):
273
+ system_characteristics_map[device_type] = system
161
274
 
162
275
  return system_characteristics_map
163
276
 
@@ -193,6 +306,8 @@ UserFacingNameToSystemCharacteristics = {
193
306
  accelerator_type=AcceleratorType.GPU,
194
307
  device_type='l4-1',
195
308
  supports_sub_slicing=False,
309
+ gpu_config=GpuConfig(requires_topology=False),
310
+ docker_platform=AMD_PLATFORM,
196
311
  ),
197
312
  'l4-2': SystemCharacteristics(
198
313
  topology='N/A',
@@ -203,6 +318,8 @@ UserFacingNameToSystemCharacteristics = {
203
318
  accelerator_type=AcceleratorType.GPU,
204
319
  device_type='l4-2',
205
320
  supports_sub_slicing=False,
321
+ gpu_config=GpuConfig(requires_topology=False),
322
+ docker_platform=AMD_PLATFORM,
206
323
  ),
207
324
  'l4-4': SystemCharacteristics(
208
325
  topology='N/A',
@@ -213,6 +330,8 @@ UserFacingNameToSystemCharacteristics = {
213
330
  accelerator_type=AcceleratorType.GPU,
214
331
  device_type='l4-4',
215
332
  supports_sub_slicing=False,
333
+ gpu_config=GpuConfig(requires_topology=False),
334
+ docker_platform=AMD_PLATFORM,
216
335
  ),
217
336
  'l4-8': SystemCharacteristics(
218
337
  topology='N/A',
@@ -223,6 +342,8 @@ UserFacingNameToSystemCharacteristics = {
223
342
  accelerator_type=AcceleratorType.GPU,
224
343
  device_type='l4-8',
225
344
  supports_sub_slicing=False,
345
+ gpu_config=GpuConfig(requires_topology=False),
346
+ docker_platform=AMD_PLATFORM,
226
347
  ),
227
348
  # A100-40gb-$CHIPSc
228
349
  'a100-40gb-1': SystemCharacteristics(
@@ -234,6 +355,8 @@ UserFacingNameToSystemCharacteristics = {
234
355
  accelerator_type=AcceleratorType.GPU,
235
356
  device_type='a100-40gb-1',
236
357
  supports_sub_slicing=False,
358
+ gpu_config=GpuConfig(requires_topology=False),
359
+ docker_platform=AMD_PLATFORM,
237
360
  ),
238
361
  'a100-40gb-2': SystemCharacteristics(
239
362
  topology='N/A',
@@ -244,6 +367,8 @@ UserFacingNameToSystemCharacteristics = {
244
367
  accelerator_type=AcceleratorType.GPU,
245
368
  device_type='a100-40gb-2',
246
369
  supports_sub_slicing=False,
370
+ gpu_config=GpuConfig(requires_topology=False),
371
+ docker_platform=AMD_PLATFORM,
247
372
  ),
248
373
  'a100-40gb-4': SystemCharacteristics(
249
374
  topology='N/A',
@@ -254,6 +379,8 @@ UserFacingNameToSystemCharacteristics = {
254
379
  accelerator_type=AcceleratorType.GPU,
255
380
  device_type='a100-40gb-4',
256
381
  supports_sub_slicing=False,
382
+ gpu_config=GpuConfig(requires_topology=False),
383
+ docker_platform=AMD_PLATFORM,
257
384
  ),
258
385
  'a100-40gb-8': SystemCharacteristics(
259
386
  topology='N/A',
@@ -264,6 +391,8 @@ UserFacingNameToSystemCharacteristics = {
264
391
  accelerator_type=AcceleratorType.GPU,
265
392
  device_type='a100-40gb-8',
266
393
  supports_sub_slicing=False,
394
+ gpu_config=GpuConfig(requires_topology=False),
395
+ docker_platform=AMD_PLATFORM,
267
396
  ),
268
397
  'gb200-4': SystemCharacteristics(
269
398
  topology='1x72',
@@ -274,6 +403,14 @@ UserFacingNameToSystemCharacteristics = {
274
403
  accelerator_type=AcceleratorType.GPU,
275
404
  device_type='gb200-4',
276
405
  supports_sub_slicing=False,
406
+ gpu_config=GpuConfig(
407
+ requires_topology=True,
408
+ nccl_installer=INSTALLER_NCCL_RDMA_A4X,
409
+ kjob_decorator_fn=rdma_decorator.decorate_kjob_template,
410
+ jobset_decorator_fn=rdma_decorator.decorate_jobset,
411
+ gpu_direct_name='rdma',
412
+ ),
413
+ docker_platform=ARM_PLATFORM,
277
414
  ),
278
415
  'gb200-4-nolssd': SystemCharacteristics(
279
416
  topology='1x72',
@@ -284,6 +421,14 @@ UserFacingNameToSystemCharacteristics = {
284
421
  accelerator_type=AcceleratorType.GPU,
285
422
  device_type='gb200-4',
286
423
  supports_sub_slicing=False,
424
+ gpu_config=GpuConfig(
425
+ requires_topology=True,
426
+ nccl_installer=INSTALLER_NCCL_RDMA_A4X,
427
+ kjob_decorator_fn=rdma_decorator.decorate_kjob_template,
428
+ jobset_decorator_fn=rdma_decorator.decorate_jobset,
429
+ gpu_direct_name='rdma',
430
+ ),
431
+ docker_platform=ARM_PLATFORM,
287
432
  ),
288
433
  'b200-8': SystemCharacteristics(
289
434
  topology='N/A',
@@ -294,6 +439,14 @@ UserFacingNameToSystemCharacteristics = {
294
439
  accelerator_type=AcceleratorType.GPU,
295
440
  device_type='b200-8',
296
441
  supports_sub_slicing=False,
442
+ gpu_config=GpuConfig(
443
+ requires_topology=True,
444
+ nccl_installer=INSTALLER_NCCL_RDMA,
445
+ kjob_decorator_fn=rdma_decorator.decorate_kjob_template,
446
+ jobset_decorator_fn=rdma_decorator.decorate_jobset,
447
+ gpu_direct_name='rdma',
448
+ ),
449
+ docker_platform=AMD_PLATFORM,
297
450
  ),
298
451
  'h200-141gb-8': SystemCharacteristics(
299
452
  topology='N/A',
@@ -304,6 +457,14 @@ UserFacingNameToSystemCharacteristics = {
304
457
  accelerator_type=AcceleratorType.GPU,
305
458
  device_type='h200-141gb-8',
306
459
  supports_sub_slicing=False,
460
+ gpu_config=GpuConfig(
461
+ requires_topology=True,
462
+ nccl_installer=INSTALLER_NCCL_RDMA,
463
+ kjob_decorator_fn=rdma_decorator.decorate_kjob_template,
464
+ jobset_decorator_fn=rdma_decorator.decorate_jobset,
465
+ gpu_direct_name='rdma',
466
+ ),
467
+ docker_platform=AMD_PLATFORM,
307
468
  ),
308
469
  # H100-80gb-$CHIPS
309
470
  'h100-80gb-8': SystemCharacteristics(
@@ -315,6 +476,14 @@ UserFacingNameToSystemCharacteristics = {
315
476
  accelerator_type=AcceleratorType.GPU,
316
477
  device_type='h100-80gb-8',
317
478
  supports_sub_slicing=False,
479
+ gpu_config=GpuConfig(
480
+ requires_topology=True,
481
+ nccl_installer=INSTALLER_NCCL_TCPX,
482
+ kjob_decorator_fn=tcpx_decorator.decorate_kjob_template,
483
+ jobset_decorator_fn=tcpx_decorator.decorate_jobset,
484
+ gpu_direct_name='tcpx',
485
+ ),
486
+ docker_platform=AMD_PLATFORM,
318
487
  ),
319
488
  # H100-mega-80gb-$CHIPS
320
489
  'h100-mega-80gb-8': SystemCharacteristics(
@@ -326,6 +495,14 @@ UserFacingNameToSystemCharacteristics = {
326
495
  accelerator_type=AcceleratorType.GPU,
327
496
  device_type='h100-mega-80gb-8',
328
497
  supports_sub_slicing=False,
498
+ gpu_config=GpuConfig(
499
+ requires_topology=True,
500
+ nccl_installer=INSTALLER_NCCL_TCPXO,
501
+ kjob_decorator_fn=tcpxo_decorator.decorate_kjob_template,
502
+ jobset_decorator_fn=tcpxo_decorator.decorate_jobset,
503
+ gpu_direct_name='tcpxo',
504
+ ),
505
+ docker_platform=AMD_PLATFORM,
329
506
  ),
330
507
  # TPU system characteristics
331
508
  **get_tpu_system_characteristics_map(
@@ -334,17 +511,20 @@ UserFacingNameToSystemCharacteristics = {
334
511
  gke_accelerator='tpu7x',
335
512
  machine_type='tpu7x-standard-1t',
336
513
  supported_topologies=['1x1x1'],
337
- requires_workload_policy=True,
514
+ tpu_type_requires_workload_policy=True,
338
515
  supports_sub_slicing=False,
516
+ docker_platform=AMD_PLATFORM,
339
517
  ),
340
518
  **get_tpu_system_characteristics_map(
341
519
  prefix='tpu7x',
342
520
  tensorcores_per_chip=2,
343
521
  gke_accelerator='tpu7x',
344
522
  machine_type='tpu7x-standard-4t',
345
- requires_workload_policy=True,
523
+ tpu_type_requires_workload_policy=True,
346
524
  supports_sub_slicing=False,
347
- supported_topologies=[
525
+ docker_platform=AMD_PLATFORM,
526
+ supported_topologies=generate_tpu_topologies(max_cubes=144),
527
+ default_topologies=set([
348
528
  '12x12x12',
349
529
  '12x12x16',
350
530
  '12x12x20',
@@ -443,7 +623,7 @@ UserFacingNameToSystemCharacteristics = {
443
623
  '8x8x76',
444
624
  '8x8x8',
445
625
  '8x8x92',
446
- ],
626
+ ]),
447
627
  ),
448
628
  **get_tpu_system_characteristics_map(
449
629
  prefix='v6e',
@@ -452,22 +632,27 @@ UserFacingNameToSystemCharacteristics = {
452
632
  machine_type='ct6e-standard-1t',
453
633
  supports_sub_slicing=False,
454
634
  supported_topologies=['1x1'],
635
+ docker_platform=AMD_PLATFORM,
455
636
  ),
456
637
  **get_tpu_system_characteristics_map(
457
638
  prefix='v6e',
458
639
  tensorcores_per_chip=1,
459
640
  gke_accelerator='tpu-v6e-slice',
460
641
  machine_type='ct6e-standard-4t',
461
- supports_sub_slicing=True,
642
+ supports_sub_slicing=False,
462
643
  supported_topologies=[
463
644
  '2x2',
464
- '2x4',
465
- '4x4',
466
- '4x8',
467
- '8x8',
468
- '8x16',
469
- '16x16',
470
645
  ],
646
+ docker_platform=AMD_PLATFORM,
647
+ ),
648
+ **get_tpu_system_characteristics_map(
649
+ prefix='v6e',
650
+ tensorcores_per_chip=1,
651
+ gke_accelerator='tpu-v6e-slice',
652
+ machine_type='ct6e-standard-4t',
653
+ supports_sub_slicing=True,
654
+ supported_topologies=SUB_SLICING_TOPOLOGIES,
655
+ docker_platform=AMD_PLATFORM,
471
656
  ),
472
657
  **get_tpu_system_characteristics_map(
473
658
  prefix='v5p',
@@ -475,7 +660,9 @@ UserFacingNameToSystemCharacteristics = {
475
660
  gke_accelerator='tpu-v5p-slice',
476
661
  machine_type='ct5p-hightpu-4t',
477
662
  supports_sub_slicing=False,
478
- supported_topologies=[
663
+ docker_platform=AMD_PLATFORM,
664
+ supported_topologies=generate_tpu_topologies(max_cubes=140),
665
+ default_topologies=set([
479
666
  '2x2x1',
480
667
  '2x2x2',
481
668
  '2x2x4',
@@ -572,13 +759,14 @@ UserFacingNameToSystemCharacteristics = {
572
759
  '16x16x24',
573
760
  '12x24x24',
574
761
  '16x20x28',
575
- ],
762
+ ]),
576
763
  ),
577
764
  **get_tpu_system_characteristics_map(
578
765
  prefix='v5litepod',
579
766
  tensorcores_per_chip=1,
580
767
  gke_accelerator='tpu-v5-lite-podslice',
581
768
  machine_type='ct5lp-hightpu-4t',
769
+ docker_platform=AMD_PLATFORM,
582
770
  supports_sub_slicing=False,
583
771
  supported_topologies=['2x4', '4x4', '4x8', '8x8', '8x16', '16x16'],
584
772
  ),
@@ -587,8 +775,12 @@ UserFacingNameToSystemCharacteristics = {
587
775
  tensorcores_per_chip=2,
588
776
  gke_accelerator='tpu-v4-podslice',
589
777
  machine_type='ct4p-hightpu-4t',
778
+ docker_platform=AMD_PLATFORM,
590
779
  supports_sub_slicing=False,
591
- supported_topologies=[
780
+ supported_topologies=generate_tpu_topologies(
781
+ max_cubes=64, enforce_nondecreasing=False
782
+ ),
783
+ default_topologies=set([
592
784
  '2x2x1',
593
785
  '2x2x2',
594
786
  '2x2x4',
@@ -600,7 +792,7 @@ UserFacingNameToSystemCharacteristics = {
600
792
  '8x8x12',
601
793
  '8x8x16',
602
794
  '8x16x16',
603
- ],
795
+ ]),
604
796
  ),
605
797
  # CPU system characteristics.
606
798
  # Note that chips_per_vm is actually the number of vCPUs in that CPU.
@@ -615,6 +807,7 @@ UserFacingNameToSystemCharacteristics = {
615
807
  accelerator_type=AcceleratorType.CPU,
616
808
  device_type='m1-megamem-96-1',
617
809
  supports_sub_slicing=False,
810
+ docker_platform=AMD_PLATFORM,
618
811
  ),
619
812
  # n2-standard-#vCPUs-#VMs
620
813
  'n2-standard-64-1': SystemCharacteristics(
@@ -626,6 +819,7 @@ UserFacingNameToSystemCharacteristics = {
626
819
  accelerator_type=AcceleratorType.CPU,
627
820
  device_type='n2-standard-64-1',
628
821
  supports_sub_slicing=False,
822
+ docker_platform=AMD_PLATFORM,
629
823
  ),
630
824
  'n2-standard-32-1': SystemCharacteristics(
631
825
  topology='N/A',
@@ -636,6 +830,7 @@ UserFacingNameToSystemCharacteristics = {
636
830
  accelerator_type=AcceleratorType.CPU,
637
831
  device_type='n2-standard-32-1',
638
832
  supports_sub_slicing=False,
833
+ docker_platform=AMD_PLATFORM,
639
834
  ),
640
835
  'n2-standard-32-2': SystemCharacteristics(
641
836
  topology='N/A',
@@ -646,6 +841,7 @@ UserFacingNameToSystemCharacteristics = {
646
841
  accelerator_type=AcceleratorType.CPU,
647
842
  device_type='n2-standard-32-2',
648
843
  supports_sub_slicing=False,
844
+ docker_platform=AMD_PLATFORM,
649
845
  ),
650
846
  'n2-standard-32-4': SystemCharacteristics(
651
847
  topology='N/A',
@@ -656,6 +852,7 @@ UserFacingNameToSystemCharacteristics = {
656
852
  accelerator_type=AcceleratorType.CPU,
657
853
  device_type='n2-standard-32-4',
658
854
  supports_sub_slicing=False,
855
+ docker_platform=AMD_PLATFORM,
659
856
  ),
660
857
  'n2-standard-32-8': SystemCharacteristics(
661
858
  topology='N/A',
@@ -666,6 +863,7 @@ UserFacingNameToSystemCharacteristics = {
666
863
  accelerator_type=AcceleratorType.CPU,
667
864
  device_type='n2-standard-32-8',
668
865
  supports_sub_slicing=False,
866
+ docker_platform=AMD_PLATFORM,
669
867
  ),
670
868
  'n2-standard-32-16': SystemCharacteristics(
671
869
  topology='N/A',
@@ -676,6 +874,7 @@ UserFacingNameToSystemCharacteristics = {
676
874
  accelerator_type=AcceleratorType.CPU,
677
875
  device_type='n2-standard-32-16',
678
876
  supports_sub_slicing=False,
877
+ docker_platform=AMD_PLATFORM,
679
878
  ),
680
879
  'n2-standard-32-32': SystemCharacteristics(
681
880
  topology='N/A',
@@ -686,6 +885,7 @@ UserFacingNameToSystemCharacteristics = {
686
885
  accelerator_type=AcceleratorType.CPU,
687
886
  device_type='n2-standard-32-32',
688
887
  supports_sub_slicing=False,
888
+ docker_platform=AMD_PLATFORM,
689
889
  ),
690
890
  'n2-standard-32-64': SystemCharacteristics(
691
891
  topology='N/A',
@@ -696,6 +896,7 @@ UserFacingNameToSystemCharacteristics = {
696
896
  accelerator_type=AcceleratorType.CPU,
697
897
  device_type='n2-standard-32-64',
698
898
  supports_sub_slicing=False,
899
+ docker_platform=AMD_PLATFORM,
699
900
  ),
700
901
  'n2-standard-32-128': SystemCharacteristics(
701
902
  topology='N/A',
@@ -706,6 +907,7 @@ UserFacingNameToSystemCharacteristics = {
706
907
  accelerator_type=AcceleratorType.CPU,
707
908
  device_type='n2-standard-32-128',
708
909
  supports_sub_slicing=False,
910
+ docker_platform=AMD_PLATFORM,
709
911
  ),
710
912
  'n2-standard-32-256': SystemCharacteristics(
711
913
  topology='N/A',
@@ -716,6 +918,7 @@ UserFacingNameToSystemCharacteristics = {
716
918
  accelerator_type=AcceleratorType.CPU,
717
919
  device_type='n2-standard-32-256',
718
920
  supports_sub_slicing=False,
921
+ docker_platform=AMD_PLATFORM,
719
922
  ),
720
923
  'n2-standard-32-512': SystemCharacteristics(
721
924
  topology='N/A',
@@ -726,6 +929,7 @@ UserFacingNameToSystemCharacteristics = {
726
929
  accelerator_type=AcceleratorType.CPU,
727
930
  device_type='n2-standard-32-512',
728
931
  supports_sub_slicing=False,
932
+ docker_platform=AMD_PLATFORM,
729
933
  ),
730
934
  'n2-standard-32-1024': SystemCharacteristics(
731
935
  topology='N/A',
@@ -736,6 +940,7 @@ UserFacingNameToSystemCharacteristics = {
736
940
  accelerator_type=AcceleratorType.CPU,
737
941
  device_type='n2-standard-32-1024',
738
942
  supports_sub_slicing=False,
943
+ docker_platform=AMD_PLATFORM,
739
944
  ),
740
945
  'n2-standard-32-2048': SystemCharacteristics(
741
946
  topology='N/A',
@@ -746,7 +951,39 @@ UserFacingNameToSystemCharacteristics = {
746
951
  accelerator_type=AcceleratorType.CPU,
747
952
  device_type='n2-standard-32-2048',
748
953
  supports_sub_slicing=False,
954
+ docker_platform=AMD_PLATFORM,
749
955
  ),
750
956
  }
751
957
  """ If you modify UserFacingNameToSystemCharacteristics you should also modify
752
958
  the corresponding Map in MaxText/accelerator_to_spec_map.py """
959
+
960
+
961
+ def get_system_characteristics_keys_by_accelerator_type(
962
+ accelerators: list[AcceleratorType] | None = None,
963
+ ) -> list[str]:
964
+ """Returns UserFacingNameToSystemCharacteristics keys for given AcceleratorTypes."""
965
+ if accelerators is None:
966
+ accelerators = list(AcceleratorType)
967
+ return [
968
+ key
969
+ for key, value in UserFacingNameToSystemCharacteristics.items()
970
+ if value.accelerator_type in accelerators
971
+ ]
972
+
973
+
974
+ def create_accelerator_label(system: SystemCharacteristics) -> str:
975
+ if system.accelerator_type == AcceleratorType.CPU:
976
+ return ''
977
+ return (
978
+ f'{AcceleratorTypeToAcceleratorCharacteristics[system.accelerator_type].accelerator_label}:'
979
+ f' {system.gke_accelerator}'
980
+ )
981
+
982
+
983
+ def create_machine_label(system: SystemCharacteristics) -> str:
984
+ if system.accelerator_type == AcceleratorType.TPU:
985
+ return (
986
+ f'{AcceleratorTypeToAcceleratorCharacteristics[system.accelerator_type].machine_label}:'
987
+ f' {system.topology}'
988
+ )
989
+ return ''