torchx-nightly 2024.1.6__py3-none-any.whl → 2025.12.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchx/__init__.py +2 -0
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/apps/serve/serve.py +2 -0
- torchx/apps/utils/booth_main.py +2 -0
- torchx/apps/utils/copy_main.py +2 -0
- torchx/apps/utils/process_monitor.py +2 -0
- torchx/cli/__init__.py +2 -0
- torchx/cli/argparse_util.py +38 -3
- torchx/cli/cmd_base.py +2 -0
- torchx/cli/cmd_cancel.py +2 -0
- torchx/cli/cmd_configure.py +2 -0
- torchx/cli/cmd_delete.py +30 -0
- torchx/cli/cmd_describe.py +2 -0
- torchx/cli/cmd_list.py +8 -4
- torchx/cli/cmd_log.py +6 -24
- torchx/cli/cmd_run.py +269 -45
- torchx/cli/cmd_runopts.py +2 -0
- torchx/cli/cmd_status.py +12 -1
- torchx/cli/cmd_tracker.py +3 -1
- torchx/cli/colors.py +2 -0
- torchx/cli/main.py +4 -0
- torchx/components/__init__.py +3 -8
- torchx/components/component_test_base.py +2 -0
- torchx/components/dist.py +18 -7
- torchx/components/integration_tests/component_provider.py +4 -2
- torchx/components/integration_tests/integ_tests.py +2 -0
- torchx/components/serve.py +2 -0
- torchx/components/structured_arg.py +4 -3
- torchx/components/utils.py +15 -4
- torchx/distributed/__init__.py +2 -4
- torchx/examples/apps/datapreproc/datapreproc.py +2 -0
- torchx/examples/apps/lightning/data.py +5 -3
- torchx/examples/apps/lightning/model.py +7 -6
- torchx/examples/apps/lightning/profiler.py +7 -4
- torchx/examples/apps/lightning/train.py +11 -2
- torchx/examples/torchx_out_of_sync_training.py +11 -0
- torchx/notebook.py +2 -0
- torchx/runner/__init__.py +2 -0
- torchx/runner/api.py +167 -60
- torchx/runner/config.py +43 -10
- torchx/runner/events/__init__.py +57 -13
- torchx/runner/events/api.py +14 -3
- torchx/runner/events/handlers.py +2 -0
- torchx/runtime/tracking/__init__.py +2 -0
- torchx/runtime/tracking/api.py +2 -0
- torchx/schedulers/__init__.py +16 -15
- torchx/schedulers/api.py +70 -14
- torchx/schedulers/aws_batch_scheduler.py +75 -6
- torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
- torchx/schedulers/devices.py +17 -4
- torchx/schedulers/docker_scheduler.py +43 -11
- torchx/schedulers/ids.py +29 -23
- torchx/schedulers/kubernetes_mcad_scheduler.py +9 -7
- torchx/schedulers/kubernetes_scheduler.py +383 -38
- torchx/schedulers/local_scheduler.py +100 -27
- torchx/schedulers/lsf_scheduler.py +5 -4
- torchx/schedulers/slurm_scheduler.py +336 -20
- torchx/schedulers/streams.py +2 -0
- torchx/specs/__init__.py +89 -12
- torchx/specs/api.py +418 -30
- torchx/specs/builders.py +176 -38
- torchx/specs/file_linter.py +143 -57
- torchx/specs/finder.py +68 -28
- torchx/specs/named_resources_aws.py +181 -4
- torchx/specs/named_resources_generic.py +2 -0
- torchx/specs/overlays.py +106 -0
- torchx/specs/test/components/__init__.py +2 -0
- torchx/specs/test/components/a/__init__.py +2 -0
- torchx/specs/test/components/a/b/__init__.py +2 -0
- torchx/specs/test/components/a/b/c.py +2 -0
- torchx/specs/test/components/c/__init__.py +2 -0
- torchx/specs/test/components/c/d.py +2 -0
- torchx/tracker/__init__.py +12 -6
- torchx/tracker/api.py +15 -18
- torchx/tracker/backend/fsspec.py +2 -0
- torchx/util/cuda.py +2 -0
- torchx/util/datetime.py +2 -0
- torchx/util/entrypoints.py +39 -15
- torchx/util/io.py +2 -0
- torchx/util/log_tee_helpers.py +210 -0
- torchx/util/modules.py +65 -0
- torchx/util/session.py +42 -0
- torchx/util/shlex.py +2 -0
- torchx/util/strings.py +3 -1
- torchx/util/types.py +90 -29
- torchx/version.py +4 -2
- torchx/workspace/__init__.py +2 -0
- torchx/workspace/api.py +136 -6
- torchx/workspace/dir_workspace.py +2 -0
- torchx/workspace/docker_workspace.py +30 -2
- torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
- torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
- {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
- {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
- torchx/pipelines/kfp/__init__.py +0 -28
- torchx/pipelines/kfp/adapter.py +0 -271
- torchx/pipelines/kfp/version.py +0 -17
- torchx/schedulers/gcp_batch_scheduler.py +0 -487
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -453
- torchx_nightly-2024.1.6.dist-info/METADATA +0 -176
- torchx_nightly-2024.1.6.dist-info/RECORD +0 -118
- {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/specs/finder.py
CHANGED
@@ -4,7 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import abc
+import copy
 import importlib
 import inspect
 import logging
@@ -17,11 +20,17 @@ from types import ModuleType
 from typing import Callable, Dict, Generator, List, Optional, Union
 
 from torchx.specs import AppDef
-from torchx.specs.file_linter import get_fn_docstring, validate
+
+from torchx.specs.file_linter import (
+    ComponentFunctionValidator,
+    get_fn_docstring,
+    validate,
+)
 from torchx.util import entrypoints
 from torchx.util.io import read_conf_file
 from torchx.util.types import none_throws
 
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -51,13 +60,17 @@ class _Component:
     name: str
     description: str
     fn_name: str
+
     fn: Callable[..., AppDef]
+
     validation_errors: List[str]
 
 
 class ComponentsFinder(abc.ABC):
     @abc.abstractmethod
-    def find(self) -> List[_Component]:
+    def find(
+        self, validators: Optional[List[ComponentFunctionValidator]]
+    ) -> List[_Component]:
         """
         Retrieves a set of components. A component is defined as a python
         function that conforms to ``torchx.specs.file_linter`` linter.
@@ -201,10 +214,12 @@ class ModuleComponentsFinder(ComponentsFinder):
         else:
             yield self._try_import(module_info.name)
 
-    def find(self) -> List[_Component]:
+    def find(
+        self, validators: Optional[List[ComponentFunctionValidator]]
+    ) -> List[_Component]:
         components = []
         for m in self._iter_modules_recursive(self.base_module):
-            components += self._get_components_from_module(m)
+            components += self._get_components_from_module(m, validators)
         return components
 
     def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
@@ -219,7 +234,9 @@ class ModuleComponentsFinder(ComponentsFinder):
         else:
             return module
 
-    def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
+    def _get_components_from_module(
+        self, module: ModuleType, validators: Optional[List[ComponentFunctionValidator]]
+    ) -> List[_Component]:
         functions = getmembers(module, isfunction)
         component_defs = []
 
@@ -228,7 +245,7 @@ class ModuleComponentsFinder(ComponentsFinder):
         module_path = os.path.abspath(module_path)
         rel_module_name = module_relname(module, relative_to=self.base_module)
         for function_name, function in functions:
-            linter_errors = validate(module_path, function_name)
+            linter_errors = validate(module_path, function_name, validators)
             component_desc, _ = get_fn_docstring(function)
 
             # remove empty string to deal with group=""
@@ -253,17 +270,26 @@ class CustomComponentsFinder(ComponentsFinder):
         self._filepath = filepath
         self._function_name = function_name
 
-    def _get_validation_errors(self, path: str, function_name: str) -> List[str]:
-        linter_errors = validate(path, function_name)
+    def _get_validation_errors(
+        self,
+        path: str,
+        function_name: str,
+        validators: Optional[List[ComponentFunctionValidator]],
+    ) -> List[str]:
+        linter_errors = validate(path, function_name, validators)
         return [linter_error.description for linter_error in linter_errors]
 
-    def find(self) -> List[_Component]:
+    def find(
+        self, validators: Optional[List[ComponentFunctionValidator]]
+    ) -> List[_Component]:
         validation_errors = self._get_validation_errors(
-            self._filepath, self._function_name
+            self._filepath, self._function_name, validators
        )
 
         file_source = read_conf_file(self._filepath)
-        namespace = globals()
+        namespace = copy.copy(globals())
+        # so that __file__ used inside the component points to the correct file
+        namespace["__file__"] = os.path.abspath(self._filepath)
         exec(file_source, namespace)  # noqa: P204
         if self._function_name not in namespace:
             raise ComponentNotFoundException(
@@ -282,7 +308,9 @@ class CustomComponentsFinder(ComponentsFinder):
     ]
 
 
-def _load_custom_components() -> List[_Component]:
+def _load_custom_components(
+    validators: Optional[List[ComponentFunctionValidator]],
+) -> List[_Component]:
     component_modules = {
         name: load_fn()
         for name, load_fn in
@@ -301,11 +329,13 @@ def _load_custom_components() -> List[_Component]:
         # _0 = torchx.components.dist
         # _1 = torchx.components.utils
         group = "" if group.startswith("_") else group
-        components += ModuleComponentsFinder(module, group).find()
+        components += ModuleComponentsFinder(module, group).find(validators)
     return components
 
 
-def _load_components() -> Dict[str, _Component]:
+def _load_components(
+    validators: Optional[List[ComponentFunctionValidator]],
+) -> Dict[str, _Component]:
     """
     Loads either the custom component defs from the entrypoint ``[torchx.components]``
     or the default builtins from ``torchx.components`` module.
@@ -316,19 +346,21 @@ def _load_components() -> Dict[str, _Component]:
 
     """
 
-    components = _load_custom_components()
+    components = _load_custom_components(validators)
     if not components:
-        components = ModuleComponentsFinder("torchx.components", "").find()
+        components = ModuleComponentsFinder("torchx.components", "").find(validators)
     return {c.name: c for c in components}
 
 
 _components: Optional[Dict[str, _Component]] = None
 
 
-def _find_components() -> Dict[str, _Component]:
+def _find_components(
+    validators: Optional[List[ComponentFunctionValidator]],
+) -> Dict[str, _Component]:
     global _components
     if not _components:
-        _components = _load_components()
+        _components = _load_components(validators)
     return none_throws(_components)
 
 
@@ -336,17 +368,21 @@ def _is_custom_component(component_name: str) -> bool:
     return ":" in component_name
 
 
-def _find_custom_components(name: str) -> Dict[str, _Component]:
+def _find_custom_components(
+    name: str, validators: Optional[List[ComponentFunctionValidator]]
+) -> Dict[str, _Component]:
     if ":" not in name:
         raise ValueError(
             f"Invalid custom component: {name}, valid template : `FILEPATH`:`FUNCTION_NAME`"
         )
     filepath, component_name = name.split(":")
-    components = CustomComponentsFinder(filepath, component_name).find()
+    components = CustomComponentsFinder(filepath, component_name).find(validators)
     return {component.name: component for component in components}
 
 
-def get_components() -> Dict[str, _Component]:
+def get_components(
+    validators: Optional[List[ComponentFunctionValidator]] = None,
+) -> Dict[str, _Component]:
     """
     Returns all custom components registered via ``[torchx.components]`` entrypoints
     OR builtin components that ship with TorchX (but not both).
@@ -393,13 +429,15 @@ def get_components() -> Dict[str, _Component]:
     """
 
     valid_components: Dict[str, _Component] = {}
-    for component_name, component in _find_components().items():
+    for component_name, component in _find_components(validators).items():
         if len(component.validation_errors) == 0:
             valid_components[component_name] = component
     return valid_components
 
 
-def get_component(name: str) -> _Component:
+def get_component(
+    name: str, validators: Optional[List[ComponentFunctionValidator]] = None
+) -> _Component:
     """
     Retrieves components by the provided name.
 
@@ -407,14 +445,14 @@ def get_component(name: str) -> _Component:
         Component or None if no component with ``name`` exists
     """
     if _is_custom_component(name):
-        components = _find_custom_components(name)
+        components = _find_custom_components(name, validators)
     else:
-        components = _find_components()
+        components = _find_components(validators)
     if name not in components:
         raise ComponentNotFoundException(
             f"Component `{name}` not found. Please make sure it is one of the "
             "builtins: `torchx builtins`. Or registered via `[torchx.components]` "
-            "entry point (see: https://pytorch.org/torchx/latest/configure.html)"
+            "entry point (see: https://meta-pytorch.org/torchx/latest/configure.html)"
         )
 
     component = components[name]
@@ -426,7 +464,9 @@ def get_component(name: str) -> _Component:
     return component
 
 
-def get_builtin_source(name: str) -> str:
+def get_builtin_source(
+    name: str, validators: Optional[List[ComponentFunctionValidator]] = None
+) -> str:
     """
     Returns a string of the the builtin component's function source code
     with all the import statements. Intended to be used to make a copy
@@ -444,7 +484,7 @@ def get_builtin_source(name: str) -> str:
     are optimized and formatting adheres to your organization's standards.
     """
 
-    component = get_component(name)
+    component = get_component(name, validators)
     fn = component.fn
     fn_name = component.name.split(".")[-1]
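Taken together, the finder.py changes thread an optional list of ComponentFunctionValidator objects (imported from torchx.specs.file_linter) through every step of component discovery, from the public get_components/get_component/get_builtin_source entry points down to the validate() call. Since validators defaults to None on the public functions, existing callers keep the old behavior. A minimal sketch of the resulting surface, assuming this nightly wheel is installed (utils.echo is one of the stock builtin components):

    # A minimal sketch, assuming this nightly wheel is installed.
    # Omitting `validators` (default None) preserves the pre-change validation.
    from torchx.specs.finder import get_component, get_components

    components = get_components()        # dict: component name -> _Component
    echo = get_component("utils.echo")   # resolve one builtin by name
    print(echo.name, echo.description)   # fields of the _Component dataclass

Custom ComponentFunctionValidator instances would be passed via the new validators argument; their construction lives in torchx.specs.file_linter and is not shown in this diff.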
torchx/specs/named_resources_aws.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 r"""
 `torchx.specs.named_resources_aws` contains resource definitions that represent corresponding AWS instance types
 taken from https://aws.amazon.com/ec2/instance-types/. The resources are exposed
@@ -14,7 +16,7 @@ the equvalent resource in mem, cpu and gpu numbers.
 
 .. note::
     These resource definitions may change in future. It is expected for each user to
-    manage their own resources. Follow https://pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
+    manage their own resources. Follow https://meta-pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
     to set up named resources.
 
 Usage:
@@ -35,6 +37,7 @@ from typing import Callable, Mapping
 from torchx.specs.api import Resource
 
 EFA_DEVICE = "vpc.amazonaws.com/efa"
+NEURON_DEVICE = "aws.amazon.com/neurondevice"
 
 # ecs and ec2 have memtax and currently AWS Batch uses hard memory limits
 # so we have to account for mem tax when registering these resources for AWS
@@ -44,7 +47,7 @@ EFA_DEVICE = "vpc.amazonaws.com/efa"
 MEM_TAX = 0.96
 
 # determines instance type for non-honogeneous CEs
-# see https://github.com/pytorch/torchx/issues/780
+# see https://github.com/meta-pytorch/torchx/issues/780
 K8S_ITYPE = "node.kubernetes.io/instance-type"
 GiB: int = int(1024 * MEM_TAX)
 
@@ -107,6 +110,36 @@ def aws_p4de_24xlarge() -> Resource:
     )
 
 
+def aws_p5_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=2048 * GiB,
+        capabilities={K8S_ITYPE: "p5.48xlarge"},
+        devices={EFA_DEVICE: 32},
+    )
+
+
+def aws_p5e_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=2048 * GiB,
+        capabilities={K8S_ITYPE: "p5e.48xlarge"},
+        devices={EFA_DEVICE: 32},
+    )
+
+
+def aws_p5en_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=2048 * GiB,
+        capabilities={K8S_ITYPE: "p5en.48xlarge"},
+        devices={EFA_DEVICE: 16},
+    )
+
+
 def aws_t3_medium() -> Resource:
     return Resource(cpu=2, gpu=0, memMB=4 * GiB, capabilities={K8S_ITYPE: "t3.medium"})
 
@@ -117,6 +150,16 @@ def aws_m5_2xlarge() -> Resource:
     )
 
 
+def aws_c5_18xlarge() -> Resource:
+    return Resource(
+        # using lower memory size than the spec since MEM_TAX is not enough for adjustment
+        cpu=72,
+        gpu=0,
+        memMB=142 * GiB,
+        capabilities={K8S_ITYPE: "c5.18xlarge"},
+    )
+
+
 def aws_g4dn_xlarge() -> Resource:
     return Resource(
         cpu=4, gpu=1, memMB=16 * GiB, capabilities={K8S_ITYPE: "g4dn.xlarge"}
@@ -241,9 +284,87 @@ def aws_g5_48xlarge() -> Resource:
     )
 
 
+def aws_g6e_xlarge() -> Resource:
+    return Resource(
+        cpu=4,
+        gpu=1,
+        memMB=32 * GiB,
+        capabilities={K8S_ITYPE: "g6e.xlarge"},
+    )
+
+
+def aws_g6e_2xlarge() -> Resource:
+    return Resource(
+        cpu=8,
+        gpu=1,
+        memMB=64 * GiB,
+        capabilities={K8S_ITYPE: "g6e.2xlarge"},
+    )
+
+
+def aws_g6e_4xlarge() -> Resource:
+    return Resource(
+        cpu=16,
+        gpu=1,
+        memMB=128 * GiB,
+        capabilities={K8S_ITYPE: "g6e.4xlarge"},
+    )
+
+
+def aws_g6e_8xlarge() -> Resource:
+    return Resource(
+        cpu=32,
+        gpu=1,
+        memMB=256 * GiB,
+        capabilities={K8S_ITYPE: "g6e.8xlarge"},
+    )
+
+
+def aws_g6e_16xlarge() -> Resource:
+    return Resource(
+        cpu=64,
+        gpu=1,
+        memMB=512 * GiB,
+        capabilities={K8S_ITYPE: "g6e.16xlarge"},
+    )
+
+
+def aws_g6e_12xlarge() -> Resource:
+    return Resource(
+        cpu=48,
+        gpu=4,
+        memMB=384 * GiB,
+        capabilities={K8S_ITYPE: "g6e.12xlarge"},
+    )
+
+
+def aws_g6e_24xlarge() -> Resource:
+    return Resource(
+        cpu=96,
+        gpu=4,
+        memMB=768 * GiB,
+        capabilities={K8S_ITYPE: "g6e.24xlarge"},
+        devices={EFA_DEVICE: 2},
+    )
+
+
+def aws_g6e_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=8,
+        memMB=1536 * GiB,
+        capabilities={K8S_ITYPE: "g6e.48xlarge"},
+        devices={EFA_DEVICE: 4},
+    )
+
+
 def aws_trn1_2xlarge() -> Resource:
     return Resource(
-        cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xlarge"}
+        cpu=8,
+        gpu=0,
+        memMB=32 * GiB,
+        capabilities={K8S_ITYPE: "trn1.2xlarge"},
+        devices={NEURON_DEVICE: 1},
     )
 
 
@@ -253,19 +374,63 @@ def aws_trn1_32xlarge() -> Resource:
         gpu=0,
         memMB=512 * GiB,
         capabilities={K8S_ITYPE: "trn1.32xlarge"},
-        devices={EFA_DEVICE: 8},
+        devices={EFA_DEVICE: 8, NEURON_DEVICE: 16},
+    )
+
+
+def aws_inf2_xlarge() -> Resource:
+    return Resource(
+        cpu=4,
+        gpu=0,
+        memMB=16 * GiB,
+        capabilities={K8S_ITYPE: "inf2.xlarge"},
+        devices={NEURON_DEVICE: 1},
+    )
+
+
+def aws_inf2_8xlarge() -> Resource:
+    return Resource(
+        cpu=32,
+        gpu=0,
+        memMB=128 * GiB,
+        capabilities={K8S_ITYPE: "inf2.8xlarge"},
+        devices={NEURON_DEVICE: 1},
+    )
+
+
+def aws_inf2_24xlarge() -> Resource:
+    return Resource(
+        cpu=96,
+        gpu=0,
+        memMB=384 * GiB,
+        capabilities={K8S_ITYPE: "inf2.24xlarge"},
+        devices={NEURON_DEVICE: 6},
+    )
+
+
+def aws_inf2_48xlarge() -> Resource:
+    return Resource(
+        cpu=192,
+        gpu=0,
+        memMB=768 * GiB,
+        capabilities={K8S_ITYPE: "inf2.48xlarge"},
+        devices={NEURON_DEVICE: 12},
     )
 
 
 NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
     "aws_t3.medium": aws_t3_medium,
     "aws_m5.2xlarge": aws_m5_2xlarge,
+    "aws_c5.18xlarge": aws_c5_18xlarge,
     "aws_p3.2xlarge": aws_p3_2xlarge,
     "aws_p3.8xlarge": aws_p3_8xlarge,
     "aws_p3.16xlarge": aws_p3_16xlarge,
     "aws_p3dn.24xlarge": aws_p3dn_24xlarge,
     "aws_p4d.24xlarge": aws_p4d_24xlarge,
     "aws_p4de.24xlarge": aws_p4de_24xlarge,
+    "aws_p5.48xlarge": aws_p5_48xlarge,
+    "aws_p5e.48xlarge": aws_p5e_48xlarge,
+    "aws_p5en.48xlarge": aws_p5en_48xlarge,
     "aws_g4dn.xlarge": aws_g4dn_xlarge,
     "aws_g4dn.2xlarge": aws_g4dn_2xlarge,
     "aws_g4dn.4xlarge": aws_g4dn_4xlarge,
@@ -281,6 +446,18 @@ NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
     "aws_g5.12xlarge": aws_g5_12xlarge,
     "aws_g5.24xlarge": aws_g5_24xlarge,
     "aws_g5.48xlarge": aws_g5_48xlarge,
+    "aws_g6e.xlarge": aws_g6e_xlarge,
+    "aws_g6e.2xlarge": aws_g6e_2xlarge,
+    "aws_g6e.4xlarge": aws_g6e_4xlarge,
+    "aws_g6e.8xlarge": aws_g6e_8xlarge,
+    "aws_g6e.16xlarge": aws_g6e_16xlarge,
+    "aws_g6e.12xlarge": aws_g6e_12xlarge,
+    "aws_g6e.24xlarge": aws_g6e_24xlarge,
+    "aws_g6e.48xlarge": aws_g6e_48xlarge,
     "aws_trn1.2xlarge": aws_trn1_2xlarge,
     "aws_trn1.32xlarge": aws_trn1_32xlarge,
+    "aws_inf2.xlarge": aws_inf2_xlarge,
+    "aws_inf2.8xlarge": aws_inf2_8xlarge,
+    "aws_inf2.24xlarge": aws_inf2_24xlarge,
+    "aws_inf2.48xlarge": aws_inf2_48xlarge,
 }
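The named-resource additions follow the module's existing pattern: every instance type is a zero-argument factory registered in NAMED_RESOURCES, and the new NEURON_DEVICE constant advertises Trainium/Inferentia accelerators through the devices field rather than the gpu count. A small sketch of resolving one of the updated entries, assuming this nightly wheel is installed (the printed values follow from the definitions above):

    # A minimal sketch, assuming this nightly wheel is installed.
    from torchx.specs.named_resources_aws import NAMED_RESOURCES, NEURON_DEVICE

    res = NAMED_RESOURCES["aws_trn1.2xlarge"]()  # zero-arg factory -> Resource
    print(res.cpu, res.gpu)            # 8 0 -- Neuron chips are not counted as gpu
    print(res.devices[NEURON_DEVICE])  # 1, keyed by "aws.amazon.com/neurondevice"
    print(res.memMB)                   # 32 * GiB, with GiB discounted by MEM_TAX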
torchx/specs/named_resources_generic.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 """
 Defines generic named resources that are not specific to any cloud provider's
 instance types. These generic named resources are meant to be used as
torchx/specs/overlays.py
ADDED
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+"""
+Overlays are JSON structs applied to :py:class:`~torchx.specs.AppDef` and :py:class:`~torchx.specs.Role`
+to specify attributes of the scheduler's submit-job request that are not currently representable
+as attributes of :py:class:`~torchx.specs.AppDef` and :py:class:`~torchx.specs.Role`.
+
+For end-uses, here are a few use-cases of overlays:
+
+1. A new version of the scheduler has concepts/features that have not yet been added to TorchX.
+2. A bespoke internal scheduler has custom features that do not generalize hence not in TorchX.
+3. Re-using a pre-built ``AppDef`` but need to make a small change to the resulting scheduler request.
+
+And for scheduler authors:
+
+1. Scheduler setting needs to be applied to a ``Role``, which makes it hard to add as ``runopts``
+   since ``runopts`` apply at the ``AppDef`` level.
+2. Scheduler setting cannot be represented naturally as the types supported by ``runopts``.
+3. Exposing the setting as a ``runopts`` obfuscates things.
+
+See :py:func:`~torchx.specs.overlays.apply_overlay` for rules on how overlays are applied.
+"""
+
+from typing import Any
+
+Json = dict[str, Any]
+
+
+def apply_overlay(base: Json, overlay: Json) -> None:
+    """Applies ``overlay`` on ``base``.
+
+    .. note:: this function mutates the ``base``!
+
+    Overlays follow these rules:
+
+    1. Dicts, upsert key, value in base with the ones in overlay.
+    2. Nested dicts, overlay recursively.
+    3. Lists, append the overlay values to the base values.
+    4. Nested lists DO NOT append recursively.
+    5. Primitives (bool, str, int, float), replace base with the value in overlay.
+
+    .. doctest::
+
+        from torchx.specs.overlays import apply_overlay
+
+        base = {
+            "scheduler": {"policy": "default"},
+            "resources": {"limits": {"cpu": "500m"}},
+            "tolerations": [{"key": "gpu"}],
+            "nodeSelectorTerms": [
+                [{"matchExpressions": []}]
+            ],
+            "maxPods": 110,
+        }
+        overlay = {
+            "scheduler": {"policy": "binpacking"},
+            "resources": {"limits": {"memory": "1Gi"}},
+            "tolerations": [{"key": "spot"}],
+            "nodeSelectorTerms": [
+                [{"matchExpressions": [{"key": "disk"}]}]
+            ],
+            "maxPods": 250,
+        }
+
+        apply_overlay(base, overlay)
+
+        assert {
+            "scheduler": {"policy": "binpacking"},
+            "resources": {"limits": {"cpu": "500m", "memory": "1Gi"}},
+            "tolerations": [{"key": "gpu"}, {"key": "spot"}],
+            "nodeSelectorTerms": [
+                [{"matchExpressions": []}],
+                [{"matchExpressions": [{"key": "disk"}]}],
+            ],
+            "maxPods": 250,
+        } == base
+
+    """
+
+    def assert_type_equal(key: str, o1: object, o2: object) -> None:
+        o1_type = type(o1)
+        o2_type = type(o2)
+        assert (
+            o1_type == o2_type
+        ), f"Type mismatch for attr: `{key}`. {o1_type.__qualname__} != {o2_type.__qualname__}"
+
+    for key, overlay_value in overlay.items():
+        if key in base:
+            base_value = base[key]
+
+            assert_type_equal(key, base_value, overlay_value)
+
+            if isinstance(base_value, dict) and isinstance(overlay_value, dict):
+                apply_overlay(base_value, overlay_value)
+            elif isinstance(base_value, list) and isinstance(overlay_value, list):
+                base_value.extend(overlay_value)
+            else:
+                base[key] = overlay_value
+        else:
+            base[key] = overlay_value