torchx-nightly 2023.10.21__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic. Click here for more details.

Files changed (110) hide show
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +7 -6
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +79 -5
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +10 -8
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +431 -32
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +254 -22
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2023.10.21.dist-info/METADATA +0 -174
  108. torchx_nightly-2023.10.21.dist-info/RECORD +0 -118
  109. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/specs/finder.py CHANGED
@@ -4,7 +4,10 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import abc
10
+ import copy
8
11
  import importlib
9
12
  import inspect
10
13
  import logging
@@ -17,11 +20,17 @@ from types import ModuleType
17
20
  from typing import Callable, Dict, Generator, List, Optional, Union
18
21
 
19
22
  from torchx.specs import AppDef
20
- from torchx.specs.file_linter import get_fn_docstring, validate
23
+
24
+ from torchx.specs.file_linter import (
25
+ ComponentFunctionValidator,
26
+ get_fn_docstring,
27
+ validate,
28
+ )
21
29
  from torchx.util import entrypoints
22
30
  from torchx.util.io import read_conf_file
23
31
  from torchx.util.types import none_throws
24
32
 
33
+
25
34
  logger: logging.Logger = logging.getLogger(__name__)
26
35
 
27
36
 
@@ -51,13 +60,17 @@ class _Component:
51
60
  name: str
52
61
  description: str
53
62
  fn_name: str
63
+
54
64
  fn: Callable[..., AppDef]
65
+
55
66
  validation_errors: List[str]
56
67
 
57
68
 
58
69
  class ComponentsFinder(abc.ABC):
59
70
  @abc.abstractmethod
60
- def find(self) -> List[_Component]:
71
+ def find(
72
+ self, validators: Optional[List[ComponentFunctionValidator]]
73
+ ) -> List[_Component]:
61
74
  """
62
75
  Retrieves a set of components. A component is defined as a python
63
76
  function that conforms to ``torchx.specs.file_linter`` linter.
@@ -201,10 +214,12 @@ class ModuleComponentsFinder(ComponentsFinder):
201
214
  else:
202
215
  yield self._try_import(module_info.name)
203
216
 
204
- def find(self) -> List[_Component]:
217
+ def find(
218
+ self, validators: Optional[List[ComponentFunctionValidator]]
219
+ ) -> List[_Component]:
205
220
  components = []
206
221
  for m in self._iter_modules_recursive(self.base_module):
207
- components += self._get_components_from_module(m)
222
+ components += self._get_components_from_module(m, validators)
208
223
  return components
209
224
 
210
225
  def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
@@ -219,7 +234,9 @@ class ModuleComponentsFinder(ComponentsFinder):
219
234
  else:
220
235
  return module
221
236
 
222
- def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
237
+ def _get_components_from_module(
238
+ self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]]
239
+ ) -> List[_Component]:
223
240
  functions = getmembers(module, isfunction)
224
241
  component_defs = []
225
242
 
@@ -228,7 +245,7 @@ class ModuleComponentsFinder(ComponentsFinder):
228
245
  module_path = os.path.abspath(module_path)
229
246
  rel_module_name = module_relname(module, relative_to=self.base_module)
230
247
  for function_name, function in functions:
231
- linter_errors = validate(module_path, function_name)
248
+ linter_errors = validate(module_path, function_name, validators)
232
249
  component_desc, _ = get_fn_docstring(function)
233
250
 
234
251
  # remove empty string to deal with group=""
@@ -253,17 +270,26 @@ class CustomComponentsFinder(ComponentsFinder):
253
270
  self._filepath = filepath
254
271
  self._function_name = function_name
255
272
 
256
- def _get_validation_errors(self, path: str, function_name: str) -> List[str]:
257
- linter_errors = validate(path, function_name)
273
+ def _get_validation_errors(
274
+ self,
275
+ path: str,
276
+ function_name: str,
277
+ validators: Optional[List[ComponentFunctionValidator]],
278
+ ) -> List[str]:
279
+ linter_errors = validate(path, function_name, validators)
258
280
  return [linter_error.description for linter_error in linter_errors]
259
281
 
260
- def find(self) -> List[_Component]:
282
+ def find(
283
+ self, validators: Optional[List[ComponentFunctionValidator]]
284
+ ) -> List[_Component]:
261
285
  validation_errors = self._get_validation_errors(
262
- self._filepath, self._function_name
286
+ self._filepath, self._function_name, validators
263
287
  )
264
288
 
265
289
  file_source = read_conf_file(self._filepath)
266
- namespace = globals()
290
+ namespace = copy.copy(globals())
291
+ # so that __file__ used inside the component points to the correct file
292
+ namespace["__file__"] = os.path.abspath(self._filepath)
267
293
  exec(file_source, namespace) # noqa: P204
268
294
  if self._function_name not in namespace:
269
295
  raise ComponentNotFoundException(
@@ -282,7 +308,9 @@ class CustomComponentsFinder(ComponentsFinder):
282
308
  ]
283
309
 
284
310
 
285
- def _load_custom_components() -> List[_Component]:
311
+ def _load_custom_components(
312
+ validators: Optional[List[ComponentFunctionValidator]],
313
+ ) -> List[_Component]:
286
314
  component_modules = {
287
315
  name: load_fn()
288
316
  for name, load_fn in
@@ -301,11 +329,13 @@ def _load_custom_components() -> List[_Component]:
301
329
  # _0 = torchx.components.dist
302
330
  # _1 = torchx.components.utils
303
331
  group = "" if group.startswith("_") else group
304
- components += ModuleComponentsFinder(module, group).find()
332
+ components += ModuleComponentsFinder(module, group).find(validators)
305
333
  return components
306
334
 
307
335
 
308
- def _load_components() -> Dict[str, _Component]:
336
+ def _load_components(
337
+ validators: Optional[List[ComponentFunctionValidator]],
338
+ ) -> Dict[str, _Component]:
309
339
  """
310
340
  Loads either the custom component defs from the entrypoint ``[torchx.components]``
311
341
  or the default builtins from ``torchx.components`` module.
@@ -316,19 +346,21 @@ def _load_components() -> Dict[str, _Component]:
316
346
 
317
347
  """
318
348
 
319
- components = _load_custom_components()
349
+ components = _load_custom_components(validators)
320
350
  if not components:
321
- components = ModuleComponentsFinder("torchx.components", "").find()
351
+ components = ModuleComponentsFinder("torchx.components", "").find(validators)
322
352
  return {c.name: c for c in components}
323
353
 
324
354
 
325
355
  _components: Optional[Dict[str, _Component]] = None
326
356
 
327
357
 
328
- def _find_components() -> Dict[str, _Component]:
358
+ def _find_components(
359
+ validators: Optional[List[ComponentFunctionValidator]],
360
+ ) -> Dict[str, _Component]:
329
361
  global _components
330
362
  if not _components:
331
- _components = _load_components()
363
+ _components = _load_components(validators)
332
364
  return none_throws(_components)
333
365
 
334
366
 
@@ -336,17 +368,21 @@ def _is_custom_component(component_name: str) -> bool:
336
368
  return ":" in component_name
337
369
 
338
370
 
339
- def _find_custom_components(name: str) -> Dict[str, _Component]:
371
+ def _find_custom_components(
372
+ name: str, validators: Optional[List[ComponentFunctionValidator]]
373
+ ) -> Dict[str, _Component]:
340
374
  if ":" not in name:
341
375
  raise ValueError(
342
376
  f"Invalid custom component: {name}, valid template : `FILEPATH`:`FUNCTION_NAME`"
343
377
  )
344
378
  filepath, component_name = name.split(":")
345
- components = CustomComponentsFinder(filepath, component_name).find()
379
+ components = CustomComponentsFinder(filepath, component_name).find(validators)
346
380
  return {component.name: component for component in components}
347
381
 
348
382
 
349
- def get_components() -> Dict[str, _Component]:
383
+ def get_components(
384
+ validators: Optional[List[ComponentFunctionValidator]] = None,
385
+ ) -> Dict[str, _Component]:
350
386
  """
351
387
  Returns all custom components registered via ``[torchx.components]`` entrypoints
352
388
  OR builtin components that ship with TorchX (but not both).
@@ -393,13 +429,15 @@ def get_components() -> Dict[str, _Component]:
393
429
  """
394
430
 
395
431
  valid_components: Dict[str, _Component] = {}
396
- for component_name, component in _find_components().items():
432
+ for component_name, component in _find_components(validators).items():
397
433
  if len(component.validation_errors) == 0:
398
434
  valid_components[component_name] = component
399
435
  return valid_components
400
436
 
401
437
 
402
- def get_component(name: str) -> _Component:
438
+ def get_component(
439
+ name: str, validators: Optional[List[ComponentFunctionValidator]] = None
440
+ ) -> _Component:
403
441
  """
404
442
  Retrieves components by the provided name.
405
443
 
@@ -407,14 +445,14 @@ def get_component(name: str) -> _Component:
407
445
  Component or None if no component with ``name`` exists
408
446
  """
409
447
  if _is_custom_component(name):
410
- components = _find_custom_components(name)
448
+ components = _find_custom_components(name, validators)
411
449
  else:
412
- components = _find_components()
450
+ components = _find_components(validators)
413
451
  if name not in components:
414
452
  raise ComponentNotFoundException(
415
453
  f"Component `{name}` not found. Please make sure it is one of the "
416
454
  "builtins: `torchx builtins`. Or registered via `[torchx.components]` "
417
- "entry point (see: https://pytorch.org/torchx/latest/configure.html)"
455
+ "entry point (see: https://meta-pytorch.org/torchx/latest/configure.html)"
418
456
  )
419
457
 
420
458
  component = components[name]
@@ -426,7 +464,9 @@ def get_component(name: str) -> _Component:
426
464
  return component
427
465
 
428
466
 
429
- def get_builtin_source(name: str) -> str:
467
+ def get_builtin_source(
468
+ name: str, validators: Optional[List[ComponentFunctionValidator]] = None
469
+ ) -> str:
430
470
  """
431
471
  Returns a string of the builtin component's function source code
432
472
  with all the import statements. Intended to be used to make a copy
@@ -444,7 +484,7 @@ def get_builtin_source(name: str) -> str:
444
484
  are optimized and formatting adheres to your organization's standards.
445
485
  """
446
486
 
447
- component = get_component(name)
487
+ component = get_component(name, validators)
448
488
  fn = component.fn
449
489
  fn_name = component.name.split(".")[-1]
450
490
 
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  r"""
8
10
  `torchx.specs.named_resources_aws` contains resource definitions that represent corresponding AWS instance types
9
11
  taken from https://aws.amazon.com/ec2/instance-types/. The resources are exposed
the equivalent resource in mem, cpu and gpu numbers. The resources are exposed
14
16
 
15
17
  .. note::
16
18
  These resource definitions may change in future. It is expected for each user to
17
- manage their own resources. Follow https://pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
19
+ manage their own resources. Follow https://meta-pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
18
20
  to set up named resources.
19
21
 
20
22
  Usage:
@@ -29,22 +31,36 @@ Usage:
29
31
 
30
32
  """
31
33
 
34
+ import warnings
32
35
  from typing import Callable, Mapping
33
36
 
34
37
  from torchx.specs.api import Resource
35
38
 
36
39
  EFA_DEVICE = "vpc.amazonaws.com/efa"
40
+ NEURON_DEVICE = "aws.amazon.com/neurondevice"
37
41
 
38
42
  # ecs and ec2 have memtax and currently AWS Batch uses hard memory limits
39
43
  # so we have to account for mem tax when registering these resources for AWS
40
44
  # otherwise the job will be stuck in the jobqueue forever
41
- # 97% is based on empirical observation that works well for most instance types
45
+ # 96% is based on empirical observation that works well for most instance types
42
46
  # see: https://docs.aws.amazon.com/batch/latest/userguide/memory-management.html
43
- MEM_TAX = 0.97
47
+ MEM_TAX = 0.96
48
+
49
+ # determines instance type for non-homogeneous CEs
50
+ # see https://github.com/meta-pytorch/torchx/issues/780
44
51
  K8S_ITYPE = "node.kubernetes.io/instance-type"
45
52
  GiB: int = int(1024 * MEM_TAX)
46
53
 
47
54
 
55
+ def instance_type_from_resource(resource: Resource) -> str:
56
+ instance_type = resource.capabilities.get(K8S_ITYPE)
57
+ if instance_type is None:
58
+ warnings.warn(
59
+ "Cannot determine resource instance type which can cause issues for non-homogeneous CEs and multinode jobs. Consider providing torchx.specs.named_resources_aws:K8S_ITYPE resource capability."
60
+ )
61
+ return instance_type
62
+
63
+
48
64
  def aws_p3_2xlarge() -> Resource:
49
65
  return Resource(
50
66
  cpu=8, gpu=1, memMB=61 * GiB, capabilities={K8S_ITYPE: "p3.2xlarge"}
@@ -94,6 +110,36 @@ def aws_p4de_24xlarge() -> Resource:
94
110
  )
95
111
 
96
112
 
113
+ def aws_p5_48xlarge() -> Resource:
114
+ return Resource(
115
+ cpu=192,
116
+ gpu=8,
117
+ memMB=2048 * GiB,
118
+ capabilities={K8S_ITYPE: "p5.48xlarge"},
119
+ devices={EFA_DEVICE: 32},
120
+ )
121
+
122
+
123
+ def aws_p5e_48xlarge() -> Resource:
124
+ return Resource(
125
+ cpu=192,
126
+ gpu=8,
127
+ memMB=2048 * GiB,
128
+ capabilities={K8S_ITYPE: "p5e.48xlarge"},
129
+ devices={EFA_DEVICE: 32},
130
+ )
131
+
132
+
133
+ def aws_p5en_48xlarge() -> Resource:
134
+ return Resource(
135
+ cpu=192,
136
+ gpu=8,
137
+ memMB=2048 * GiB,
138
+ capabilities={K8S_ITYPE: "p5en.48xlarge"},
139
+ devices={EFA_DEVICE: 16},
140
+ )
141
+
142
+
97
143
  def aws_t3_medium() -> Resource:
98
144
  return Resource(cpu=2, gpu=0, memMB=4 * GiB, capabilities={K8S_ITYPE: "t3.medium"})
99
145
 
@@ -104,6 +150,16 @@ def aws_m5_2xlarge() -> Resource:
104
150
  )
105
151
 
106
152
 
153
+ def aws_c5_18xlarge() -> Resource:
154
+ return Resource(
155
+ # using lower memory size than the spec since MEM_TAX is not enough for adjustment
156
+ cpu=72,
157
+ gpu=0,
158
+ memMB=142 * GiB,
159
+ capabilities={K8S_ITYPE: "c5.18xlarge"},
160
+ )
161
+
162
+
107
163
  def aws_g4dn_xlarge() -> Resource:
108
164
  return Resource(
109
165
  cpu=4, gpu=1, memMB=16 * GiB, capabilities={K8S_ITYPE: "g4dn.xlarge"}
@@ -124,25 +180,41 @@ def aws_g4dn_4xlarge() -> Resource:
124
180
 
125
181
  def aws_g4dn_8xlarge() -> Resource:
126
182
  return Resource(
127
- cpu=32, gpu=1, memMB=128 * GiB, capabilities={K8S_ITYPE: "g4dn.8xlarge"}
183
+ cpu=32,
184
+ gpu=1,
185
+ memMB=128 * GiB,
186
+ capabilities={K8S_ITYPE: "g4dn.8xlarge"},
187
+ devices={EFA_DEVICE: 1},
128
188
  )
129
189
 
130
190
 
131
- def aws_g4dn_16xlarge() -> Resource:
191
+ def aws_g4dn_12xlarge() -> Resource:
132
192
  return Resource(
133
- cpu=64, gpu=1, memMB=256 * GiB, capabilities={K8S_ITYPE: "g4dn.16xlarge"}
193
+ cpu=48,
194
+ gpu=4,
195
+ memMB=192 * GiB,
196
+ capabilities={K8S_ITYPE: "g4dn.12xlarge"},
197
+ devices={EFA_DEVICE: 1},
134
198
  )
135
199
 
136
200
 
137
- def aws_g4dn_12xlarge() -> Resource:
201
+ def aws_g4dn_16xlarge() -> Resource:
138
202
  return Resource(
139
- cpu=48, gpu=4, memMB=192 * GiB, capabilities={K8S_ITYPE: "g4dn.12xlarge"}
203
+ cpu=64,
204
+ gpu=1,
205
+ memMB=256 * GiB,
206
+ capabilities={K8S_ITYPE: "g4dn.16xlarge"},
207
+ devices={EFA_DEVICE: 1},
140
208
  )
141
209
 
142
210
 
143
211
  def aws_g4dn_metal() -> Resource:
144
212
  return Resource(
145
- cpu=96, gpu=8, memMB=384 * GiB, capabilities={K8S_ITYPE: "g4dn.metal"}
213
+ cpu=96,
214
+ gpu=8,
215
+ memMB=384 * GiB,
216
+ capabilities={K8S_ITYPE: "g4dn.metal"},
217
+ devices={EFA_DEVICE: 1},
146
218
  )
147
219
 
148
220
 
@@ -164,53 +236,201 @@ def aws_g5_4xlarge() -> Resource:
164
236
 
165
237
  def aws_g5_8xlarge() -> Resource:
166
238
  return Resource(
167
- cpu=32, gpu=1, memMB=128 * GiB, capabilities={K8S_ITYPE: "g5.8xlarge"}
239
+ cpu=32,
240
+ gpu=1,
241
+ memMB=128 * GiB,
242
+ capabilities={K8S_ITYPE: "g5.8xlarge"},
243
+ devices={EFA_DEVICE: 1},
168
244
  )
169
245
 
170
246
 
171
- def aws_g5_16xlarge() -> Resource:
247
+ def aws_g5_12xlarge() -> Resource:
172
248
  return Resource(
173
- cpu=64, gpu=1, memMB=256 * GiB, capabilities={K8S_ITYPE: "g5.16xlarge"}
249
+ cpu=48,
250
+ gpu=4,
251
+ memMB=192 * GiB,
252
+ capabilities={K8S_ITYPE: "g5.12xlarge"},
253
+ devices={EFA_DEVICE: 1},
174
254
  )
175
255
 
176
256
 
177
- def aws_g5_12xlarge() -> Resource:
257
+ def aws_g5_16xlarge() -> Resource:
178
258
  return Resource(
179
- cpu=48, gpu=4, memMB=192 * GiB, capabilities={K8S_ITYPE: "g5.12xlarge"}
259
+ cpu=64,
260
+ gpu=1,
261
+ memMB=256 * GiB,
262
+ capabilities={K8S_ITYPE: "g5.16xlarge"},
263
+ devices={EFA_DEVICE: 1},
180
264
  )
181
265
 
182
266
 
183
267
  def aws_g5_24xlarge() -> Resource:
184
268
  return Resource(
185
- cpu=96, gpu=4, memMB=384 * GiB, capabilities={K8S_ITYPE: "g5.24xlarge"}
269
+ cpu=96,
270
+ gpu=4,
271
+ memMB=384 * GiB,
272
+ capabilities={K8S_ITYPE: "g5.24xlarge"},
273
+ devices={EFA_DEVICE: 1},
186
274
  )
187
275
 
188
276
 
189
277
  def aws_g5_48xlarge() -> Resource:
190
278
  return Resource(
191
- cpu=192, gpu=8, memMB=768 * GiB, capabilities={K8S_ITYPE: "g5.48xlarge"}
279
+ cpu=192,
280
+ gpu=8,
281
+ memMB=768 * GiB,
282
+ capabilities={K8S_ITYPE: "g5.48xlarge"},
283
+ devices={EFA_DEVICE: 1},
284
+ )
285
+
286
+
287
+ def aws_g6e_xlarge() -> Resource:
288
+ return Resource(
289
+ cpu=4,
290
+ gpu=1,
291
+ memMB=32 * GiB,
292
+ capabilities={K8S_ITYPE: "g6e.xlarge"},
192
293
  )
193
294
 
194
295
 
195
- def aws_trn1_2xl() -> Resource:
196
- return Resource(cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xl"})
296
+ def aws_g6e_2xlarge() -> Resource:
297
+ return Resource(
298
+ cpu=8,
299
+ gpu=1,
300
+ memMB=64 * GiB,
301
+ capabilities={K8S_ITYPE: "g6e.2xlarge"},
302
+ )
197
303
 
198
304
 
199
- def aws_trn1_32xl() -> Resource:
305
+ def aws_g6e_4xlarge() -> Resource:
200
306
  return Resource(
201
- cpu=128, gpu=0, memMB=512 * GiB, capabilities={K8S_ITYPE: "trn1.32xl"}
307
+ cpu=16,
308
+ gpu=1,
309
+ memMB=128 * GiB,
310
+ capabilities={K8S_ITYPE: "g6e.4xlarge"},
311
+ )
312
+
313
+
314
+ def aws_g6e_8xlarge() -> Resource:
315
+ return Resource(
316
+ cpu=32,
317
+ gpu=1,
318
+ memMB=256 * GiB,
319
+ capabilities={K8S_ITYPE: "g6e.8xlarge"},
320
+ )
321
+
322
+
323
+ def aws_g6e_16xlarge() -> Resource:
324
+ return Resource(
325
+ cpu=64,
326
+ gpu=1,
327
+ memMB=512 * GiB,
328
+ capabilities={K8S_ITYPE: "g6e.16xlarge"},
329
+ )
330
+
331
+
332
+ def aws_g6e_12xlarge() -> Resource:
333
+ return Resource(
334
+ cpu=48,
335
+ gpu=4,
336
+ memMB=384 * GiB,
337
+ capabilities={K8S_ITYPE: "g6e.12xlarge"},
338
+ )
339
+
340
+
341
+ def aws_g6e_24xlarge() -> Resource:
342
+ return Resource(
343
+ cpu=96,
344
+ gpu=4,
345
+ memMB=768 * GiB,
346
+ capabilities={K8S_ITYPE: "g6e.24xlarge"},
347
+ devices={EFA_DEVICE: 2},
348
+ )
349
+
350
+
351
+ def aws_g6e_48xlarge() -> Resource:
352
+ return Resource(
353
+ cpu=192,
354
+ gpu=8,
355
+ memMB=1536 * GiB,
356
+ capabilities={K8S_ITYPE: "g6e.48xlarge"},
357
+ devices={EFA_DEVICE: 4},
358
+ )
359
+
360
+
361
+ def aws_trn1_2xlarge() -> Resource:
362
+ return Resource(
363
+ cpu=8,
364
+ gpu=0,
365
+ memMB=32 * GiB,
366
+ capabilities={K8S_ITYPE: "trn1.2xlarge"},
367
+ devices={NEURON_DEVICE: 1},
368
+ )
369
+
370
+
371
+ def aws_trn1_32xlarge() -> Resource:
372
+ return Resource(
373
+ cpu=128,
374
+ gpu=0,
375
+ memMB=512 * GiB,
376
+ capabilities={K8S_ITYPE: "trn1.32xlarge"},
377
+ devices={EFA_DEVICE: 8, NEURON_DEVICE: 16},
378
+ )
379
+
380
+
381
+ def aws_inf2_xlarge() -> Resource:
382
+ return Resource(
383
+ cpu=4,
384
+ gpu=0,
385
+ memMB=16 * GiB,
386
+ capabilities={K8S_ITYPE: "inf2.xlarge"},
387
+ devices={NEURON_DEVICE: 1},
388
+ )
389
+
390
+
391
+ def aws_inf2_8xlarge() -> Resource:
392
+ return Resource(
393
+ cpu=32,
394
+ gpu=0,
395
+ memMB=128 * GiB,
396
+ capabilities={K8S_ITYPE: "inf2.8xlarge"},
397
+ devices={NEURON_DEVICE: 1},
398
+ )
399
+
400
+
401
+ def aws_inf2_24xlarge() -> Resource:
402
+ return Resource(
403
+ cpu=96,
404
+ gpu=0,
405
+ memMB=384 * GiB,
406
+ capabilities={K8S_ITYPE: "inf2.24xlarge"},
407
+ devices={NEURON_DEVICE: 6},
408
+ )
409
+
410
+
411
+ def aws_inf2_48xlarge() -> Resource:
412
+ return Resource(
413
+ cpu=192,
414
+ gpu=0,
415
+ memMB=768 * GiB,
416
+ capabilities={K8S_ITYPE: "inf2.48xlarge"},
417
+ devices={NEURON_DEVICE: 12},
202
418
  )
203
419
 
204
420
 
205
421
  NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
206
422
  "aws_t3.medium": aws_t3_medium,
207
423
  "aws_m5.2xlarge": aws_m5_2xlarge,
424
+ "aws_c5.18xlarge": aws_c5_18xlarge,
208
425
  "aws_p3.2xlarge": aws_p3_2xlarge,
209
426
  "aws_p3.8xlarge": aws_p3_8xlarge,
210
427
  "aws_p3.16xlarge": aws_p3_16xlarge,
211
428
  "aws_p3dn.24xlarge": aws_p3dn_24xlarge,
212
429
  "aws_p4d.24xlarge": aws_p4d_24xlarge,
213
430
  "aws_p4de.24xlarge": aws_p4de_24xlarge,
431
+ "aws_p5.48xlarge": aws_p5_48xlarge,
432
+ "aws_p5e.48xlarge": aws_p5e_48xlarge,
433
+ "aws_p5en.48xlarge": aws_p5en_48xlarge,
214
434
  "aws_g4dn.xlarge": aws_g4dn_xlarge,
215
435
  "aws_g4dn.2xlarge": aws_g4dn_2xlarge,
216
436
  "aws_g4dn.4xlarge": aws_g4dn_4xlarge,
@@ -226,6 +446,18 @@ NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
226
446
  "aws_g5.12xlarge": aws_g5_12xlarge,
227
447
  "aws_g5.24xlarge": aws_g5_24xlarge,
228
448
  "aws_g5.48xlarge": aws_g5_48xlarge,
229
- "aws_trn1.2xl": aws_trn1_2xl,
230
- "aws_trn1.32xl": aws_trn1_32xl,
449
+ "aws_g6e.xlarge": aws_g6e_xlarge,
450
+ "aws_g6e.2xlarge": aws_g6e_2xlarge,
451
+ "aws_g6e.4xlarge": aws_g6e_4xlarge,
452
+ "aws_g6e.8xlarge": aws_g6e_8xlarge,
453
+ "aws_g6e.16xlarge": aws_g6e_16xlarge,
454
+ "aws_g6e.12xlarge": aws_g6e_12xlarge,
455
+ "aws_g6e.24xlarge": aws_g6e_24xlarge,
456
+ "aws_g6e.48xlarge": aws_g6e_48xlarge,
457
+ "aws_trn1.2xlarge": aws_trn1_2xlarge,
458
+ "aws_trn1.32xlarge": aws_trn1_32xlarge,
459
+ "aws_inf2.xlarge": aws_inf2_xlarge,
460
+ "aws_inf2.8xlarge": aws_inf2_8xlarge,
461
+ "aws_inf2.24xlarge": aws_inf2_24xlarge,
462
+ "aws_inf2.48xlarge": aws_inf2_48xlarge,
231
463
  }
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  """
8
10
  Defines generic named resources that are not specific to any cloud provider's
9
11
  instance types. These generic named resources are meant to be used as