torchx-nightly 2024.1.6__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchx-nightly might be problematic.

Files changed (110)
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +4 -3
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +75 -6
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +9 -7
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +418 -30
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +181 -4
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2024.1.6.dist-info/METADATA +0 -176
  108. torchx_nightly-2024.1.6.dist-info/RECORD +0 -118
  109. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/specs/finder.py CHANGED
@@ -4,7 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import abc
+import copy
 import importlib
 import inspect
 import logging
@@ -17,11 +20,17 @@ from types import ModuleType
 from typing import Callable, Dict, Generator, List, Optional, Union
 
 from torchx.specs import AppDef
-from torchx.specs.file_linter import get_fn_docstring, validate
+
+from torchx.specs.file_linter import (
+    ComponentFunctionValidator,
+    get_fn_docstring,
+    validate,
+)
 from torchx.util import entrypoints
 from torchx.util.io import read_conf_file
 from torchx.util.types import none_throws
 
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -51,13 +60,17 @@ class _Component:
     name: str
     description: str
     fn_name: str
+
     fn: Callable[..., AppDef]
+
     validation_errors: List[str]
 
 
 class ComponentsFinder(abc.ABC):
     @abc.abstractmethod
-    def find(self) -> List[_Component]:
+    def find(
+        self, validators: Optional[List[ComponentFunctionValidator]]
+    ) -> List[_Component]:
         """
         Retrieves a set of components. A component is defined as a python
        function that conforms to ``torchx.specs.file_linter`` linter.
@@ -201,10 +214,12 @@ class ModuleComponentsFinder(ComponentsFinder):
         else:
             yield self._try_import(module_info.name)
 
-    def find(self) -> List[_Component]:
+    def find(
+        self, validators: Optional[List[ComponentFunctionValidator]]
+    ) -> List[_Component]:
         components = []
         for m in self._iter_modules_recursive(self.base_module):
-            components += self._get_components_from_module(m)
+            components += self._get_components_from_module(m, validators)
         return components
 
     def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
@@ -219,7 +234,9 @@ class ModuleComponentsFinder(ComponentsFinder):
         else:
             return module
 
-    def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
+    def _get_components_from_module(
+        self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]]
+    ) -> List[_Component]:
         functions = getmembers(module, isfunction)
         component_defs = []
 
@@ -228,7 +245,7 @@ class ModuleComponentsFinder(ComponentsFinder):
         module_path = os.path.abspath(module_path)
         rel_module_name = module_relname(module, relative_to=self.base_module)
         for function_name, function in functions:
-            linter_errors = validate(module_path, function_name)
+            linter_errors = validate(module_path, function_name, validators)
             component_desc, _ = get_fn_docstring(function)
 
             # remove empty string to deal with group=""
@@ -253,17 +270,26 @@ class CustomComponentsFinder(ComponentsFinder):
         self._filepath = filepath
         self._function_name = function_name
 
-    def _get_validation_errors(self, path: str, function_name: str) -> List[str]:
-        linter_errors = validate(path, function_name)
+    def _get_validation_errors(
+        self,
+        path: str,
+        function_name: str,
+        validators: Optional[List[ComponentFunctionValidator]],
+    ) -> List[str]:
+        linter_errors = validate(path, function_name, validators)
         return [linter_error.description for linter_error in linter_errors]
 
-    def find(self) -> List[_Component]:
+    def find(
+        self, validators: Optional[List[ComponentFunctionValidator]]
+    ) -> List[_Component]:
         validation_errors = self._get_validation_errors(
-            self._filepath, self._function_name
+            self._filepath, self._function_name, validators
         )
 
         file_source = read_conf_file(self._filepath)
-        namespace = globals()
+        namespace = copy.copy(globals())
+        # so that __file__ used inside the component points to the correct file
+        namespace["__file__"] = os.path.abspath(self._filepath)
         exec(file_source, namespace)  # noqa: P204
         if self._function_name not in namespace:
             raise ComponentNotFoundException(
@@ -282,7 +308,9 @@ class CustomComponentsFinder(ComponentsFinder):
         ]
 
 
-def _load_custom_components() -> List[_Component]:
+def _load_custom_components(
+    validators: Optional[List[ComponentFunctionValidator]],
+) -> List[_Component]:
     component_modules = {
         name: load_fn()
         for name, load_fn in
@@ -301,11 +329,13 @@ def _load_custom_components() -> List[_Component]:
         #   _0 = torchx.components.dist
         #   _1 = torchx.components.utils
         group = "" if group.startswith("_") else group
-        components += ModuleComponentsFinder(module, group).find()
+        components += ModuleComponentsFinder(module, group).find(validators)
     return components
 
 
-def _load_components() -> Dict[str, _Component]:
+def _load_components(
+    validators: Optional[List[ComponentFunctionValidator]],
+) -> Dict[str, _Component]:
     """
     Loads either the custom component defs from the entrypoint ``[torchx.components]``
     or the default builtins from ``torchx.components`` module.
@@ -316,19 +346,21 @@ def _load_components() -> Dict[str, _Component]:
 
     """
 
-    components = _load_custom_components()
+    components = _load_custom_components(validators)
     if not components:
-        components = ModuleComponentsFinder("torchx.components", "").find()
+        components = ModuleComponentsFinder("torchx.components", "").find(validators)
     return {c.name: c for c in components}
 
 
 _components: Optional[Dict[str, _Component]] = None
 
 
-def _find_components() -> Dict[str, _Component]:
+def _find_components(
+    validators: Optional[List[ComponentFunctionValidator]],
+) -> Dict[str, _Component]:
     global _components
     if not _components:
-        _components = _load_components()
+        _components = _load_components(validators)
     return none_throws(_components)
 
 
@@ -336,17 +368,21 @@ def _is_custom_component(component_name: str) -> bool:
     return ":" in component_name
 
 
-def _find_custom_components(name: str) -> Dict[str, _Component]:
+def _find_custom_components(
+    name: str, validators: Optional[List[ComponentFunctionValidator]]
+) -> Dict[str, _Component]:
     if ":" not in name:
         raise ValueError(
             f"Invalid custom component: {name}, valid template : `FILEPATH`:`FUNCTION_NAME`"
         )
     filepath, component_name = name.split(":")
-    components = CustomComponentsFinder(filepath, component_name).find()
+    components = CustomComponentsFinder(filepath, component_name).find(validators)
    return {component.name: component for component in components}
 
 
-def get_components() -> Dict[str, _Component]:
+def get_components(
+    validators: Optional[List[ComponentFunctionValidator]] = None,
+) -> Dict[str, _Component]:
     """
     Returns all custom components registered via ``[torchx.components]`` entrypoints
     OR builtin components that ship with TorchX (but not both).
@@ -393,13 +429,15 @@ def get_components() -> Dict[str, _Component]:
     """
 
     valid_components: Dict[str, _Component] = {}
-    for component_name, component in _find_components().items():
+    for component_name, component in _find_components(validators).items():
         if len(component.validation_errors) == 0:
             valid_components[component_name] = component
     return valid_components
 
 
-def get_component(name: str) -> _Component:
+def get_component(
+    name: str, validators: Optional[List[ComponentFunctionValidator]] = None
+) -> _Component:
     """
     Retrieves components by the provided name.
 
@@ -407,14 +445,14 @@ def get_component(name: str) -> _Component:
         Component or None if no component with ``name`` exists
     """
     if _is_custom_component(name):
-        components = _find_custom_components(name)
+        components = _find_custom_components(name, validators)
     else:
-        components = _find_components()
+        components = _find_components(validators)
     if name not in components:
         raise ComponentNotFoundException(
             f"Component `{name}` not found. Please make sure it is one of the "
             "builtins: `torchx builtins`. Or registered via `[torchx.components]` "
-            "entry point (see: https://pytorch.org/torchx/latest/configure.html)"
+            "entry point (see: https://meta-pytorch.org/torchx/latest/configure.html)"
         )
 
 
@@ -426,7 +464,9 @@
     return component
 
 
-def get_builtin_source(name: str) -> str:
+def get_builtin_source(
+    name: str, validators: Optional[List[ComponentFunctionValidator]] = None
+) -> str:
     """
     Returns a string of the the builtin component's function source code
     with all the import statements. Intended to be used to make a copy
@@ -444,7 +484,7 @@
     are optimized and formatting adheres to your organization's standards.
     """
 
-    component = get_component(name)
+    component = get_component(name, validators)
     fn = component.fn
     fn_name = component.name.split(".")[-1]
 
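Taken together, these changes thread an optional list of `ComponentFunctionValidator` instances from every public lookup function down to the linter, with `validators=None` preserving the previous behavior. A minimal usage sketch (the signatures come from the diff above; the `utils.echo` builtin and its `msg` parameter are assumed from the TorchX builtins):

```python
# Minimal sketch: get_components/get_component signatures are from the diff above.
# `utils.echo` is assumed to be available as a TorchX builtin component.
from torchx.specs.finder import get_component, get_components

# Discover all builtin (or entrypoint-registered) components that pass validation;
# validators=None (the default) applies only the standard linter checks.
for name, component in get_components(validators=None).items():
    print(name, "-", component.description)

# Resolve one component by name; custom components use `FILEPATH:FUNCTION_NAME`.
echo = get_component("utils.echo")
app = echo.fn(msg="hello world")  # component functions return an AppDef
```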
torchx/specs/named_resources_aws.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 r"""
 `torchx.specs.named_resources_aws` contains resource definitions that represent corresponding AWS instance types
 taken from https://aws.amazon.com/ec2/instance-types/. The resources are exposed
@@ -14,7 +16,7 @@ the equvalent resource in mem, cpu and gpu numbers.
 
 .. note::
     These resource definitions may change in future. It is expected for each user to
-    manage their own resources. Follow https://pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
+    manage their own resources. Follow https://meta-pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
     to set up named resources.
 
 Usage:
@@ -35,6 +37,7 @@ from typing import Callable, Mapping
 from torchx.specs.api import Resource
 
 EFA_DEVICE = "vpc.amazonaws.com/efa"
+NEURON_DEVICE = "aws.amazon.com/neurondevice"
 
 # ecs and ec2 have memtax and currently AWS Batch uses hard memory limits
 # so we have to account for mem tax when registering these resources for AWS
@@ -44,7 +47,7 @@ EFA_DEVICE = "vpc.amazonaws.com/efa"
 MEM_TAX = 0.96
 
 # determines instance type for non-honogeneous CEs
-# see https://github.com/pytorch/torchx/issues/780
+# see https://github.com/meta-pytorch/torchx/issues/780
 K8S_ITYPE = "node.kubernetes.io/instance-type"
 GiB: int = int(1024 * MEM_TAX)
 
@@ -107,6 +110,36 @@ def aws_p4de_24xlarge() -> Resource:
     )
 
 
+def aws_p5_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=2048 * GiB,
+        capabilities={K8S_ITYPE: "p5.48xlarge"},
+        devices={EFA_DEVICE: 32},
+    )
+
+
+def aws_p5e_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=2048 * GiB,
+        capabilities={K8S_ITYPE: "p5e.48xlarge"},
+        devices={EFA_DEVICE: 32},
+    )
+
+
+def aws_p5en_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=2048 * GiB,
+        capabilities={K8S_ITYPE: "p5en.48xlarge"},
+        devices={EFA_DEVICE: 16},
+    )
+
+
 def aws_t3_medium() -> Resource:
     return Resource(cpu=2, gpu=0, memMB=4 * GiB, capabilities={K8S_ITYPE: "t3.medium"})
 
@@ -117,6 +150,16 @@ def aws_m5_2xlarge() -> Resource:
     )
 
 
+def aws_c5_18xlarge() -> Resource:
+    return Resource(
+        # using lower memory size than the spec since MEM_TAX is not enough for adjustment
+        cpu=72,
+        gpu=0,
+        memMB=142 * GiB,
+        capabilities={K8S_ITYPE: "c5.18xlarge"},
+    )
+
+
 def aws_g4dn_xlarge() -> Resource:
     return Resource(
         cpu=4, gpu=1, memMB=16 * GiB, capabilities={K8S_ITYPE: "g4dn.xlarge"}
@@ -241,9 +284,87 @@ def aws_g5_48xlarge() -> Resource:
     )
 
 
+def aws_g6e_xlarge() -> Resource:
+    return Resource(
+        cpu=4,
+        gpu=1,
+        memMB=32 * GiB,
+        capabilities={K8S_ITYPE: "g6e.xlarge"},
+    )
+
+
+def aws_g6e_2xlarge() -> Resource:
+    return Resource(
+        cpu=8,
+        gpu=1,
+        memMB=64 * GiB,
+        capabilities={K8S_ITYPE: "g6e.2xlarge"},
+    )
+
+
+def aws_g6e_4xlarge() -> Resource:
+    return Resource(
+        cpu=16,
+        gpu=1,
+        memMB=128 * GiB,
+        capabilities={K8S_ITYPE: "g6e.4xlarge"},
+    )
+
+
+def aws_g6e_8xlarge() -> Resource:
+    return Resource(
+        cpu=32,
+        gpu=1,
+        memMB=256 * GiB,
+        capabilities={K8S_ITYPE: "g6e.8xlarge"},
+    )
+
+
+def aws_g6e_16xlarge() -> Resource:
+    return Resource(
+        cpu=64,
+        gpu=1,
+        memMB=512 * GiB,
+        capabilities={K8S_ITYPE: "g6e.16xlarge"},
+    )
+
+
+def aws_g6e_12xlarge() -> Resource:
+    return Resource(
+        cpu=48,
+        gpu=4,
+        memMB=384 * GiB,
+        capabilities={K8S_ITYPE: "g6e.12xlarge"},
+    )
+
+
+def aws_g6e_24xlarge() -> Resource:
+    return Resource(
+        cpu=96,
+        gpu=4,
+        memMB=768 * GiB,
+        capabilities={K8S_ITYPE: "g6e.24xlarge"},
+        devices={EFA_DEVICE: 2},
+    )
+
+
+def aws_g6e_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=1536 * GiB,
+        capabilities={K8S_ITYPE: "g6e.48xlarge"},
+        devices={EFA_DEVICE: 4},
+    )
+
+
 def aws_trn1_2xlarge() -> Resource:
     return Resource(
-        cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xlarge"}
+        cpu=8,
+        gpu=0,
+        memMB=32 * GiB,
+        capabilities={K8S_ITYPE: "trn1.2xlarge"},
+        devices={NEURON_DEVICE: 1},
     )
 
 
@@ -253,19 +374,63 @@ def aws_trn1_32xlarge() -> Resource:
         gpu=0,
         memMB=512 * GiB,
         capabilities={K8S_ITYPE: "trn1.32xlarge"},
-        devices={EFA_DEVICE: 8},
+        devices={EFA_DEVICE: 8, NEURON_DEVICE: 16},
+    )
+
+
+def aws_inf2_xlarge() -> Resource:
+    return Resource(
+        cpu=4,
+        gpu=0,
+        memMB=16 * GiB,
+        capabilities={K8S_ITYPE: "inf2.xlarge"},
+        devices={NEURON_DEVICE: 1},
+    )
+
+
+def aws_inf2_8xlarge() -> Resource:
+    return Resource(
+        cpu=32,
+        gpu=0,
+        memMB=128 * GiB,
+        capabilities={K8S_ITYPE: "inf2.8xlarge"},
+        devices={NEURON_DEVICE: 1},
+    )
+
+
+def aws_inf2_24xlarge() -> Resource:
+    return Resource(
+        cpu=96,
+        gpu=0,
+        memMB=384 * GiB,
+        capabilities={K8S_ITYPE: "inf2.24xlarge"},
+        devices={NEURON_DEVICE: 6},
+    )
+
+
+def aws_inf2_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=0,
+        memMB=768 * GiB,
+        capabilities={K8S_ITYPE: "inf2.48xlarge"},
+        devices={NEURON_DEVICE: 12},
     )
 
 
 NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
     "aws_t3.medium": aws_t3_medium,
     "aws_m5.2xlarge": aws_m5_2xlarge,
+    "aws_c5.18xlarge": aws_c5_18xlarge,
     "aws_p3.2xlarge": aws_p3_2xlarge,
     "aws_p3.8xlarge": aws_p3_8xlarge,
     "aws_p3.16xlarge": aws_p3_16xlarge,
     "aws_p3dn.24xlarge": aws_p3dn_24xlarge,
     "aws_p4d.24xlarge": aws_p4d_24xlarge,
     "aws_p4de.24xlarge": aws_p4de_24xlarge,
+    "aws_p5.48xlarge": aws_p5_48xlarge,
+    "aws_p5e.48xlarge": aws_p5e_48xlarge,
+    "aws_p5en.48xlarge": aws_p5en_48xlarge,
     "aws_g4dn.xlarge": aws_g4dn_xlarge,
     "aws_g4dn.2xlarge": aws_g4dn_2xlarge,
     "aws_g4dn.4xlarge": aws_g4dn_4xlarge,
@@ -281,6 +446,18 @@ NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
     "aws_g5.12xlarge": aws_g5_12xlarge,
     "aws_g5.24xlarge": aws_g5_24xlarge,
     "aws_g5.48xlarge": aws_g5_48xlarge,
+    "aws_g6e.xlarge": aws_g6e_xlarge,
+    "aws_g6e.2xlarge": aws_g6e_2xlarge,
+    "aws_g6e.4xlarge": aws_g6e_4xlarge,
+    "aws_g6e.8xlarge": aws_g6e_8xlarge,
+    "aws_g6e.16xlarge": aws_g6e_16xlarge,
+    "aws_g6e.12xlarge": aws_g6e_12xlarge,
+    "aws_g6e.24xlarge": aws_g6e_24xlarge,
+    "aws_g6e.48xlarge": aws_g6e_48xlarge,
     "aws_trn1.2xlarge": aws_trn1_2xlarge,
     "aws_trn1.32xlarge": aws_trn1_32xlarge,
+    "aws_inf2.xlarge": aws_inf2_xlarge,
+    "aws_inf2.8xlarge": aws_inf2_8xlarge,
+    "aws_inf2.24xlarge": aws_inf2_24xlarge,
+    "aws_inf2.48xlarge": aws_inf2_48xlarge,
 }
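The new entries behave like the existing named resources: each factory returns a fully-populated `Resource`, and `NAMED_RESOURCES` maps the instance-type string to its factory. A short sketch grounded in the definitions above:

```python
# Sketch: the factories and NAMED_RESOURCES mapping are defined in the diff above.
from torchx.specs.named_resources_aws import NAMED_RESOURCES, aws_p5_48xlarge

# Direct construction: 192 vCPUs, 8 GPUs, memory discounted by MEM_TAX,
# and 32 EFA devices for inter-node networking.
p5 = aws_p5_48xlarge()
print(p5.cpu, p5.gpu, p5.memMB, p5.devices)

# Lookup by instance-type key, as schedulers resolve named resources:
trn1 = NAMED_RESOURCES["aws_trn1.32xlarge"]()
assert trn1.devices == {"vpc.amazonaws.com/efa": 8, "aws.amazon.com/neurondevice": 16}
```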
torchx/specs/named_resources_generic.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Defines generic named resources that are not specific to any cloud provider's
 instance types. These generic named resources are meant to be used as
torchx/specs/overlays.py ADDED
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+"""
+Overlays are JSON structs applied to :py:class:`~torchx.specs.AppDef` and :py:class:`~torchx.specs.Role`
+to specify attributes of the scheduler's submit-job request that are not currently representable
+as attributes of :py:class:`~torchx.specs.AppDef` and :py:class:`~torchx.specs.Role`.
+
+For end-uses, here are a few use-cases of overlays:
+
+1. A new version of the scheduler has concepts/features that have not yet been added to TorchX.
+2. A bespoke internal scheduler has custom features that do not generalize hence not in TorchX.
+3. Re-using a pre-built ``AppDef`` but need to make a small change to the resulting scheduler request.
+
+And for scheduler authors:
+
+1. Scheduler setting needs to be applied to a ``Role``, which makes it hard to add as ``runopts``
+   since ``runopts`` apply at the ``AppDef`` level.
+2. Scheduler setting cannot be represented naturally as the types supported by ``runopts``.
+3. Exposing the setting as a ``runopts`` obfuscates things.
+
+See :py:func:`~torchx.specs.overlays.apply_overlay` for rules on how overlays are applied.
+"""
+
+from typing import Any
+
+Json = dict[str, Any]
+
+
+def apply_overlay(base: Json, overlay: Json) -> None:
+    """Applies ``overlay`` on ``base``.
+
+    .. note:: this function mutates the ``base``!
+
+    Overlays follow these rules:
+
+    1. Dicts, upsert key, value in base with the ones in overlay.
+    2. Nested dicts, overlay recursively.
+    3. Lists, append the overlay values to the base values.
+    4. Nested lists DO NOT append recursively.
+    5. Primitives (bool, str, int, float), replace base with the value in overlay.
+
+    .. doctest::
+
+        from torchx.specs.overlays import apply_overlay
+
+        base = {
+            "scheduler": {"policy": "default"},
+            "resources": {"limits": {"cpu": "500m"}},
+            "tolerations": [{"key": "gpu"}],
+            "nodeSelectorTerms": [
+                [{"matchExpressions": []}]
+            ],
+            "maxPods": 110,
+        }
+        overlay = {
+            "scheduler": {"policy": "binpacking"},
+            "resources": {"limits": {"memory": "1Gi"}},
+            "tolerations": [{"key": "spot"}],
+            "nodeSelectorTerms": [
+                [{"matchExpressions": [{"key": "disk"}]}]
+            ],
+            "maxPods": 250,
+        }
+
+        apply_overlay(base, overlay)
+
+        assert {
+            "scheduler": {"policy": "binpacking"},
+            "resources": {"limits": {"cpu": "500m", "memory": "1Gi"}},
+            "tolerations": [{"key": "gpu"}, {"key": "spot"}],
+            "nodeSelectorTerms": [
+                [{"matchExpressions": []}],
+                [{"matchExpressions": [{"key": "disk"}]}],
+            ],
+            "maxPods": 250,
+        } == base
+
+    """
+
+    def assert_type_equal(key: str, o1: object, o2: object) -> None:
+        o1_type = type(o1)
+        o2_type = type(o2)
+        assert (
+            o1_type == o2_type
+        ), f"Type mismatch for attr: `{key}`. {o1_type.__qualname__} != {o2_type.__qualname__}"
+
+    for key, overlay_value in overlay.items():
+        if key in base:
+            base_value = base[key]
+
+            assert_type_equal(key, base_value, overlay_value)
+
+            if isinstance(base_value, dict) and isinstance(overlay_value, dict):
+                apply_overlay(base_value, overlay_value)
+            elif isinstance(base_value, list) and isinstance(overlay_value, list):
+                base_value.extend(overlay_value)
+            else:
+                base[key] = overlay_value
+        else:
+            base[key] = overlay_value
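Note that `assert_type_equal` makes overlays fail fast on shape mismatches instead of silently coercing values. A small sketch of the behavior, following directly from the source above:

```python
# Sketch: behavior follows from the apply_overlay implementation above.
from torchx.specs.overlays import apply_overlay

base = {"maxPods": 110, "tolerations": [{"key": "gpu"}]}

# Primitives are replaced, lists are appended (rules 5 and 3).
apply_overlay(base, {"maxPods": 250, "tolerations": [{"key": "spot"}]})
assert base == {"maxPods": 250, "tolerations": [{"key": "gpu"}, {"key": "spot"}]}

# Overlaying a list onto an int raises an AssertionError:
try:
    apply_overlay(base, {"maxPods": [250]})
except AssertionError as e:
    print(e)  # Type mismatch for attr: `maxPods`. int != list
```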
@@ -3,3 +3,5 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# pyre-strict
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# pyre-strict
 import torchx
 from torchx import specs
 
@@ -3,3 +3,5 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# pyre-strict
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# pyre-strict
 import torchx
 from torchx import specs
 
@@ -4,3 +4,5 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# pyre-strict
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# pyre-strict
 import torchx
 from torchx import specs
 