torchx-nightly 2024.1.6__py3-none-any.whl → 2025.12.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchx-nightly might be problematic.
- torchx/__init__.py +2 -0
- torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
- torchx/apps/serve/serve.py +2 -0
- torchx/apps/utils/booth_main.py +2 -0
- torchx/apps/utils/copy_main.py +2 -0
- torchx/apps/utils/process_monitor.py +2 -0
- torchx/cli/__init__.py +2 -0
- torchx/cli/argparse_util.py +38 -3
- torchx/cli/cmd_base.py +2 -0
- torchx/cli/cmd_cancel.py +2 -0
- torchx/cli/cmd_configure.py +2 -0
- torchx/cli/cmd_delete.py +30 -0
- torchx/cli/cmd_describe.py +2 -0
- torchx/cli/cmd_list.py +8 -4
- torchx/cli/cmd_log.py +6 -24
- torchx/cli/cmd_run.py +269 -45
- torchx/cli/cmd_runopts.py +2 -0
- torchx/cli/cmd_status.py +12 -1
- torchx/cli/cmd_tracker.py +3 -1
- torchx/cli/colors.py +2 -0
- torchx/cli/main.py +4 -0
- torchx/components/__init__.py +3 -8
- torchx/components/component_test_base.py +2 -0
- torchx/components/dist.py +18 -7
- torchx/components/integration_tests/component_provider.py +4 -2
- torchx/components/integration_tests/integ_tests.py +2 -0
- torchx/components/serve.py +2 -0
- torchx/components/structured_arg.py +4 -3
- torchx/components/utils.py +15 -4
- torchx/distributed/__init__.py +2 -4
- torchx/examples/apps/datapreproc/datapreproc.py +2 -0
- torchx/examples/apps/lightning/data.py +5 -3
- torchx/examples/apps/lightning/model.py +7 -6
- torchx/examples/apps/lightning/profiler.py +7 -4
- torchx/examples/apps/lightning/train.py +11 -2
- torchx/examples/torchx_out_of_sync_training.py +11 -0
- torchx/notebook.py +2 -0
- torchx/runner/__init__.py +2 -0
- torchx/runner/api.py +167 -60
- torchx/runner/config.py +43 -10
- torchx/runner/events/__init__.py +57 -13
- torchx/runner/events/api.py +14 -3
- torchx/runner/events/handlers.py +2 -0
- torchx/runtime/tracking/__init__.py +2 -0
- torchx/runtime/tracking/api.py +2 -0
- torchx/schedulers/__init__.py +16 -15
- torchx/schedulers/api.py +70 -14
- torchx/schedulers/aws_batch_scheduler.py +75 -6
- torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
- torchx/schedulers/devices.py +17 -4
- torchx/schedulers/docker_scheduler.py +43 -11
- torchx/schedulers/ids.py +29 -23
- torchx/schedulers/kubernetes_mcad_scheduler.py +9 -7
- torchx/schedulers/kubernetes_scheduler.py +383 -38
- torchx/schedulers/local_scheduler.py +100 -27
- torchx/schedulers/lsf_scheduler.py +5 -4
- torchx/schedulers/slurm_scheduler.py +336 -20
- torchx/schedulers/streams.py +2 -0
- torchx/specs/__init__.py +89 -12
- torchx/specs/api.py +418 -30
- torchx/specs/builders.py +176 -38
- torchx/specs/file_linter.py +143 -57
- torchx/specs/finder.py +68 -28
- torchx/specs/named_resources_aws.py +181 -4
- torchx/specs/named_resources_generic.py +2 -0
- torchx/specs/overlays.py +106 -0
- torchx/specs/test/components/__init__.py +2 -0
- torchx/specs/test/components/a/__init__.py +2 -0
- torchx/specs/test/components/a/b/__init__.py +2 -0
- torchx/specs/test/components/a/b/c.py +2 -0
- torchx/specs/test/components/c/__init__.py +2 -0
- torchx/specs/test/components/c/d.py +2 -0
- torchx/tracker/__init__.py +12 -6
- torchx/tracker/api.py +15 -18
- torchx/tracker/backend/fsspec.py +2 -0
- torchx/util/cuda.py +2 -0
- torchx/util/datetime.py +2 -0
- torchx/util/entrypoints.py +39 -15
- torchx/util/io.py +2 -0
- torchx/util/log_tee_helpers.py +210 -0
- torchx/util/modules.py +65 -0
- torchx/util/session.py +42 -0
- torchx/util/shlex.py +2 -0
- torchx/util/strings.py +3 -1
- torchx/util/types.py +90 -29
- torchx/version.py +4 -2
- torchx/workspace/__init__.py +2 -0
- torchx/workspace/api.py +136 -6
- torchx/workspace/dir_workspace.py +2 -0
- torchx/workspace/docker_workspace.py +30 -2
- torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
- torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
- {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
- {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
- torchx/examples/pipelines/__init__.py +0 -0
- torchx/examples/pipelines/kfp/__init__.py +0 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
- torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
- torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
- torchx/pipelines/kfp/__init__.py +0 -28
- torchx/pipelines/kfp/adapter.py +0 -271
- torchx/pipelines/kfp/version.py +0 -17
- torchx/schedulers/gcp_batch_scheduler.py +0 -487
- torchx/schedulers/ray/ray_common.py +0 -22
- torchx/schedulers/ray/ray_driver.py +0 -307
- torchx/schedulers/ray_scheduler.py +0 -453
- torchx_nightly-2024.1.6.dist-info/METADATA +0 -176
- torchx_nightly-2024.1.6.dist-info/RECORD +0 -118
- {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
- {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/tracker/__init__.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 """
 .. note:: PROTOTYPE, USE AT YOUR OWN RISK, APIs SUBJECT TO CHANGE

@@ -30,14 +32,14 @@ implementation.

 Example usage
 -------------
-Sample `code <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.
+Sample `code <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.


 Tracker Setup
 -------------
 To enable tracking it requires:

-1. Defining tracker backends (entrypoints and configuration) on launcher side using :doc:`runner.config`
+1. Defining tracker backends (entrypoints/modules and configuration) on launcher side using :doc:`runner.config`
 2. Adding entrypoints within a user job using entry_points (`specification`_)

 .. _specification: https://packaging.python.org/en/latest/specifications/entry-points/
@@ -49,13 +51,13 @@ To enable tracking it requires:
 User can define any number of tracker backends under **torchx:tracker** section in :doc:`runner.config`, where:
 * Key: is an arbitrary name for the tracker, where the name will be used to configure its properties
   under [tracker:<TRACKER_NAME>]
-* Value: is *entrypoint
+* Value: is *entrypoint* or *module* factory method that must be available within user job. The value will be injected into a
   user job and used to construct tracker implementation.

 .. code-block:: ini

     [torchx:tracker]
-    tracker_name=<
+    tracker_name=<entry_point_or_module_factory_method>


 Each tracker can be additionally configured (currently limited to `config` parameter) under `[tracker:<TRACKER NAME>]` section:
@@ -71,11 +73,15 @@ For example, ~/.torchxconfig may be setup as:

     [torchx:tracker]
     tracker1=tracker1
-
+    tracker2=backend_2_entry_point
+    tracker3=torchx.tracker.mlflow:create_tracker

     [tracker:tracker1]
     config=s3://my_bucket/config.json

+    [tracker:tracker3]
+    config=my_config.json
+

 2. User job configuration (Advanced)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -105,7 +111,7 @@ Use :py:meth:`~torchx.tracker.app_run_from_env`:
 Reference :py:class:`~torchx.tracker.api.TrackerBase` implementation
 --------------------------------------------------------------------
 :py:class:`~torchx.tracker.backend.fsspec.FsspecTracker` provides reference implementation of a tracker backend.
-GitHub example `directory <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
+GitHub example `directory <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
 configure and use it in user application.
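The new `tracker3=torchx.tracker.mlflow:create_tracker` line above illustrates the module-path form of a tracker factory. A minimal launcher-side sketch of that wiring, assuming a hypothetical factory `my_project.tracking:create_tracker`; only the [torchx:tracker]/[tracker:<name>] layout comes from the docstring, and the factory must be importable inside the user job:

import configparser
from pathlib import Path

# Append the tracker sections described above to ~/.torchxconfig.
cfg = configparser.ConfigParser()
cfg["torchx:tracker"] = {"my_tracker": "my_project.tracking:create_tracker"}  # hypothetical factory
cfg["tracker:my_tracker"] = {"config": "s3://my_bucket/config.json"}

with open(Path.home() / ".torchxconfig", "a") as f:
    cfg.write(f)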
torchx/tracker/api.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 from __future__ import annotations

 import logging
@@ -14,6 +16,7 @@ from functools import lru_cache
 from typing import Iterable, Mapping, Optional

 from torchx.util.entrypoints import load_group
+from torchx.util.modules import load_module

 logger: logging.Logger = logging.getLogger(__name__)

@@ -66,8 +69,7 @@ class AppRunTrackableSource:
     artifact_name: Optional[str]


-class Lineage:
-    ...
+class Lineage: ...


 class TrackerBase(ABC):
@@ -177,30 +179,26 @@ def _extract_tracker_name_and_config_from_environ() -> Mapping[str, Optional[str


 def build_trackers(
-
+    factory_and_config: Mapping[str, Optional[str]]
 ) -> Iterable[TrackerBase]:
     trackers = []

-    entrypoint_factories = load_group("torchx.tracker")
+    entrypoint_factories = load_group("torchx.tracker") or {}
     if not entrypoint_factories:
-        logger.warning(
-            "No 'torchx.tracker' entry_points are defined. Tracking will not capture any data."
-        )
-        return trackers
+        logger.warning("No 'torchx.tracker' entry_points are defined.")

-    for
-
+    for factory_name, config in factory_and_config.items():
+        factory = entrypoint_factories.get(factory_name) or load_module(factory_name)
+        if not factory:
             logger.warning(
-                f"
+                f"No tracker factory `{factory_name}` found in entry_points or modules. See https://meta-pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
             )
             continue
-        factory = entrypoint_factories[entrypoint_key]
         if config:
-            logger.info(f"Tracker config found for `{
-            tracker = factory(config)
+            logger.info(f"Tracker config found for `{factory_name}` as `{config}`")
         else:
-            logger.info(f"No tracker config specified for `{
-
+            logger.info(f"No tracker config specified for `{factory_name}`")
+        tracker = factory(config)
         trackers.append(tracker)
     return trackers

@@ -335,5 +333,4 @@ class AppRun:

         return model_run_sources

-    def children(self) -> Iterable[AppRun]:
-        ...
+    def children(self) -> Iterable[AppRun]: ...
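With this change build_trackers() falls back to load_module() when a name is not a registered "torchx.tracker" entry point, so factory names may also be "module:function" paths. A small sketch of that lookup, assuming torchx is installed; the factory path below is hypothetical, so it is skipped with a warning rather than constructed:

from torchx.tracker.api import build_trackers

# Keys are entry-point names or "module:attr" paths; values are the optional
# per-tracker config strings.
factory_and_config = {"my_project.tracking:create_tracker": "s3://my_bucket/config.json"}
trackers = list(build_trackers(factory_and_config))
print(trackers)  # [] here, since the hypothetical factory cannot be imported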
torchx/tracker/backend/fsspec.py
CHANGED
torchx/util/cuda.py
CHANGED
torchx/util/datetime.py
CHANGED
torchx/util/entrypoints.py
CHANGED
@@ -4,13 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-
+# pyre-strict
+# pyre-ignore-all-errors[3, 2, 16]

-
-from
+from importlib import metadata
+from importlib.metadata import EntryPoint
+from typing import Any, Dict, Optional


-# pyre-ignore-all-errors[3, 2]
 def load(group: str, name: str, default=None):
     """
     Loads the entry point specified by
@@ -28,13 +29,34 @@ def load(group: str, name: str, default=None):
     raises an error.
     """

-
+    # [note_on_entrypoints]
+    # return type of importlib.metadata.entry_points() is different between python-3.9 and python-3.10
+    # https://docs.python.org/3.9/library/importlib.metadata.html#importlib.metadata.entry_points
+    # https://docs.python.org/3.10/library/importlib.metadata.html#importlib.metadata.entry_points
+    if hasattr(metadata.entry_points(), "select"):
+        # python>=3.10
+        entrypoints = metadata.entry_points().select(group=group)

-
-
+        if name not in entrypoints.names and default is not None:
+            return default
+
+        ep = entrypoints[name]
+        return ep.load()

-
-
+    else:
+        # python<3.10 (e.g. 3.9)
+        # metadata.entry_points() returns dict[str, tuple[EntryPoint]] (not EntryPoints) in python-3.9
+        entrypoints = metadata.entry_points().get(group, ())
+
+        for ep in entrypoints:
+            if ep.name == name:
+                return ep.load()
+
+        # [group].name not found
+        if default is not None:
+            return default
+        else:
+            raise KeyError(f"entrypoint {group}.{name} not found")


 def _defer_load_ep(ep: EntryPoint) -> object:
@@ -47,11 +69,7 @@ def _defer_load_ep(ep: EntryPoint) -> object:
     return run


-
-def load_group(
-    group: str,
-    default: Optional[Dict[str, Any]] = None,
-):
+def load_group(group: str, default: Optional[Dict[str, Any]] = None):
     """
     Loads all the entry points specified by ``group`` and returns
     the entry points as a map of ``name (str) -> deferred_load_fn``.
@@ -85,7 +103,13 @@ def load_group(

     """

-
+    # see [note_on_entrypoints] above
+    if hasattr(metadata.entry_points(), "select"):
+        # python>=3.10
+        entrypoints = metadata.entry_points().select(group=group)
+    else:
+        # python<3.10 (e.g. 3.9)
+        entrypoints = metadata.entry_points().get(group, ())

     if len(entrypoints) == 0:
         return default
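load() and load_group() now branch on the importlib.metadata API differences between Python 3.9 and 3.10+. A short usage sketch, assuming this torchx build is installed; "torchx.schedulers" is the entry-point group torchx reads for third-party schedulers, so it may legitimately resolve to the default:

from torchx.util.entrypoints import load, load_group

# name -> deferred load function, or the default when the group has no entries
schedulers = load_group("torchx.schedulers", default={})
print(sorted(schedulers.keys()))

# load() returns the default for a missing name (and raises KeyError without one)
fallback = load("torchx.schedulers", "does_not_exist", default="fallback")
print(fallback)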
torchx/util/io.py
CHANGED
torchx/util/log_tee_helpers.py
ADDED
@@ -0,0 +1,210 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
If you're wrapping the TorchX API with your own CLI, these functions can
help show the logs of the job within your CLI, just like
`torchx log`
"""

import logging
import threading
from queue import Queue
from typing import List, Optional, TextIO, Tuple, TYPE_CHECKING

from torchx.util.types import none_throws

if TYPE_CHECKING:
    from torchx.runner.api import Runner
    from torchx.schedulers.api import Stream
    from torchx.specs.api import AppDef

logger: logging.Logger = logging.getLogger(__name__)

# A torchX job can have stderr/stdout for many replicas, of many roles
# The scheduler API has functions that allow us to get,
# with unspecified detail, the log lines of a given replica of
# a given role.
#
# So, to neatly tee the results, we:
# 1) Determine every role ID / replica ID pair we want to monitor
# 2) Request the given stderr / stdout / combined streams from them (1 thread each)
# 3) Concatenate each of those streams to a given destination file


def _find_role_replicas(
    app: "AppDef",
    role_name: Optional[str],
) -> List[Tuple[str, int]]:
    """
    Enumerate all (role, replica id) pairs in the given AppDef.
    Replica IDs are 0-indexed, and range up to num_replicas,
    for each role.
    If role_name is provided, filters to only that name.
    """
    role_replicas = []
    for role in app.roles:
        if role_name is None or role_name == role.name:
            for i in range(role.num_replicas):
                role_replicas.append((role.name, i))
    return role_replicas


def _prefix_line(prefix: str, line: str) -> str:
    """
    _prefix_line ensure the prefix is still present even when dealing with return characters
    """
    if "\r" in line:
        line = line.replace("\r", f"\r{prefix}")
    if "\n" in line[:-1]:
        line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
    if not line.startswith("\r"):
        line = f"{prefix}{line}"
    return line


def _print_log_lines_for_role_replica(
    dst: TextIO,
    app_handle: str,
    regex: Optional[str],
    runner: "Runner",
    which_role: str,
    which_replica: int,
    exceptions: "Queue[Exception]",
    should_tail: bool,
    streams: Optional["Stream"],
    colorize: bool = False,
) -> None:
    """
    Helper function that'll run in parallel - one
    per monitored replica of a given role.

    Based on print_log_lines .. but not designed for TTY
    """
    try:
        for line in runner.log_lines(
            app_handle,
            which_role,
            which_replica,
            regex,
            should_tail=should_tail,
            streams=streams,
        ):
            if colorize:
                color_begin = "\033[32m"
                color_end = "\033[0m"
            else:
                color_begin = ""
                color_end = ""
            prefix = f"{color_begin}{which_role}/{which_replica}{color_end} "
            print(_prefix_line(prefix, line.strip()), file=dst, end="\n", flush=True)
    except Exception as e:
        exceptions.put(e)
        raise


def _start_threads_to_monitor_role_replicas(
    dst: TextIO,
    app_handle: str,
    regex: Optional[str],
    runner: "Runner",
    which_role: Optional[str] = None,
    should_tail: bool = False,
    streams: Optional["Stream"] = None,
    colorize: bool = False,
) -> None:
    threads = []

    app = none_throws(runner.describe(app_handle))
    replica_ids = _find_role_replicas(app, role_name=which_role)

    # Holds exceptions raised by all threads, in a thread-safe
    # object
    exceptions = Queue()

    if not replica_ids:
        valid_roles = [role.name for role in app.roles]
        raise ValueError(
            f"{which_role} is not a valid role name. Available: {valid_roles}"
        )

    for role_name, replica_id in replica_ids:
        threads.append(
            threading.Thread(
                target=_print_log_lines_for_role_replica,
                kwargs={
                    "dst": dst,
                    "runner": runner,
                    "app_handle": app_handle,
                    "which_role": role_name,
                    "which_replica": replica_id,
                    "regex": regex,
                    "should_tail": should_tail,
                    "exceptions": exceptions,
                    "streams": streams,
                    "colorize": colorize,
                },
                daemon=True,
            )
        )

    for t in threads:
        t.start()

    for t in threads:
        t.join()

    # Retrieve all exceptions, print all except one and raise the first recorded exception
    threads_exceptions = []
    while not exceptions.empty():
        threads_exceptions.append(exceptions.get())

    if len(threads_exceptions) > 0:
        for i in range(1, len(threads_exceptions)):
            logger.error(threads_exceptions[i])

        raise threads_exceptions[0]


def tee_logs(
    dst: TextIO,
    app_handle: str,
    regex: Optional[str],
    runner: "Runner",
    should_tail: bool = False,
    streams: Optional["Stream"] = None,
    colorize: bool = False,
) -> threading.Thread:
    """
    Makes a thread, which in turn will start 1 thread per replica
    per role, that tees that role-replica's logs to the given
    destination buffer.

    You'll need to start and join with this parent thread.

    dst: TextIO to tee the logs into
    app_handle: The return value of runner.run() or runner.schedule()
    regex: Regex to filter the logs that are tee-d
    runner: The Runner you used to schedule the job
    should_tail: If true, continue until we run out of logs. Otherwise, just fetch
        what's available
    streams: Whether to fetch STDERR, STDOUT, or the temporally COMBINED (default) logs
    """
    thread = threading.Thread(
        target=_start_threads_to_monitor_role_replicas,
        kwargs={
            "dst": dst,
            "runner": runner,
            "app_handle": app_handle,
            "regex": None,
            "should_tail": True,
            "colorize": colorize,
        },
        daemon=True,
    )
    return thread
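A minimal sketch of driving the new tee_logs() helper from a wrapping CLI, assuming torchx is installed; the app handle below is a placeholder for whatever runner.run(...) or runner.schedule(...) returned. Note that in this version tee_logs() internally pins regex=None and should_tail=True when it spawns the monitor thread:

import sys

from torchx.runner import get_runner
from torchx.util.log_tee_helpers import tee_logs

runner = get_runner()
app_handle = "<handle returned by runner.run(...) or runner.schedule(...)>"  # placeholder

tee_thread = tee_logs(
    dst=sys.stdout,
    app_handle=app_handle,
    regex=None,
    runner=runner,
    should_tail=True,
)
tee_thread.start()  # spawns one worker thread per (role, replica)
tee_thread.join()   # returns once every replica's log stream is drained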
torchx/util/modules.py
ADDED
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import importlib
from types import ModuleType
from typing import Callable, Optional, TypeVar, Union


def load_module(path: str) -> Union[ModuleType, Optional[Callable[..., object]]]:
    """
    Loads and returns the module/module attr represented by the ``path``: ``full.module.path:optional_attr``

    1. ``load_module("this.is.a_module:fn")`` -> equivalent to ``this.is.a_module.fn``
    1. ``load_module("this.is.a_module")`` -> equivalent to ``this.is.a_module``
    """
    parts = path.split(":", 2)
    module_path, method = parts[0], parts[1] if len(parts) > 1 else None
    module = None
    i, n = -1, len(module_path)
    try:
        while i < n:
            i = module_path.find(".", i + 1)
            i = i if i >= 0 else n
            module = importlib.import_module(module_path[:i])
        return getattr(module, method) if method else module
    except Exception:
        return None


T = TypeVar("T")


def import_attr(name: str, attr: str, default: T) -> T:
    """
    Imports ``name.attr`` and returns it if the module is found.
    Otherwise, returns the specified ``default``.
    Useful when getting an attribute from an optional dependency.

    Note that the ``default`` parameter is intentionally not an optional
    since this function is intended to be used with modules that may not be
    installed as a dependency. Therefore the caller must ALWAYS provide a
    sensible default.

    Usage:

    .. code-block:: python

        aws_resources = import_attr("torchx.specs.named_resources_aws", "NAMED_RESOURCES", default={})
        all_resources.update(aws_resources)

    Raises:
        AttributeError: If the module exists (e.g. can be imported)
            but does not have an attribute with name ``attr``.
    """
    try:
        mod = importlib.import_module(name)
    except ModuleNotFoundError:
        return default
    else:
        return getattr(mod, attr)
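A quick usage sketch of the two new helpers; the first calls resolve against stdlib modules so they need no extra dependencies, and "not_installed_pkg" is a deliberately missing, hypothetical package used to show the fallback:

from torchx.util.modules import import_attr, load_module

join = load_module("os.path:join")       # os.path.join
os_path = load_module("os.path")         # the os.path module itself
missing = load_module("no.such.module")  # None -- import errors are swallowed

# Returns the default when the module is not installed (hypothetical package name).
named_resources = import_attr("not_installed_pkg.resources", "NAMED_RESOURCES", default={})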
torchx/util/session.py
ADDED
@@ -0,0 +1,42 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os
import uuid
from typing import Optional

TORCHX_INTERNAL_SESSION_ID = "TORCHX_INTERNAL_SESSION_ID"

CURRENT_SESSION_ID: Optional[str] = None


def get_session_id_or_create_new() -> str:
    """
    Returns the current session ID, or creates a new one if none exists.
    The session ID remains the same as long as it is in the same process.
    Please DO NOT use this function out of torchx codebase.
    """
    global CURRENT_SESSION_ID
    if CURRENT_SESSION_ID:
        return CURRENT_SESSION_ID
    env_session_id = os.getenv(TORCHX_INTERNAL_SESSION_ID)
    if env_session_id:
        CURRENT_SESSION_ID = env_session_id
        return CURRENT_SESSION_ID
    session_id = str(uuid.uuid4())
    CURRENT_SESSION_ID = session_id
    return session_id


def get_torchx_session_id() -> Optional[str]:
    """
    Returns the torchx session ID.
    Please use this function to get the session ID out of torchx codebase.
    """
    return CURRENT_SESSION_ID
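A short sketch of the session helpers added above: the first call mints (or reuses, via the TORCHX_INTERNAL_SESSION_ID environment variable) a process-wide ID, which later calls and get_torchx_session_id() then return unchanged:

from torchx.util.session import get_session_id_or_create_new, get_torchx_session_id

sid = get_session_id_or_create_new()
assert get_torchx_session_id() == sid
assert get_session_id_or_create_new() == sid  # stable for the rest of the process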
torchx/util/shlex.py
CHANGED
torchx/util/strings.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 import re


@@ -11,7 +13,7 @@ def normalize_str(data: str) -> str:
     """
     Invokes ``lower`` on thes string and removes all
     characters that do not satisfy ``[a-z0-9\\-]`` pattern.
-    This method is mostly used to make sure kubernetes
+    This method is mostly used to make sure kubernetes scheduler gets
     the job name that does not violate its restrictions.
     """
     if data.startswith("-"):