torchx-nightly 2024.2.12__py3-none-any.whl → 2025.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchx-nightly might be problematic.
- torchx/__init__.py +2 -0
- torchx/apps/serve/serve.py +2 -0
- torchx/apps/utils/booth_main.py +2 -0
- torchx/apps/utils/copy_main.py +2 -0
- torchx/apps/utils/process_monitor.py +2 -0
- torchx/cli/__init__.py +2 -0
- torchx/cli/argparse_util.py +38 -3
- torchx/cli/cmd_base.py +2 -0
- torchx/cli/cmd_cancel.py +2 -0
- torchx/cli/cmd_configure.py +2 -0
- torchx/cli/cmd_describe.py +2 -0
- torchx/cli/cmd_list.py +2 -0
- torchx/cli/cmd_log.py +6 -24
- torchx/cli/cmd_run.py +30 -12
- torchx/cli/cmd_runopts.py +2 -0
- torchx/cli/cmd_status.py +2 -0
- torchx/cli/cmd_tracker.py +2 -0
- torchx/cli/colors.py +2 -0
- torchx/cli/main.py +2 -0
- torchx/components/__init__.py +2 -0
- torchx/components/component_test_base.py +2 -0
- torchx/components/dist.py +2 -0
- torchx/components/integration_tests/component_provider.py +2 -0
- torchx/components/integration_tests/integ_tests.py +2 -0
- torchx/components/serve.py +2 -0
- torchx/components/structured_arg.py +2 -0
- torchx/components/utils.py +2 -0
- torchx/examples/apps/datapreproc/datapreproc.py +2 -0
- torchx/examples/apps/lightning/data.py +5 -3
- torchx/examples/apps/lightning/model.py +2 -0
- torchx/examples/apps/lightning/profiler.py +7 -4
- torchx/examples/apps/lightning/train.py +2 -0
- torchx/examples/pipelines/kfp/advanced_pipeline.py +2 -0
- torchx/examples/pipelines/kfp/dist_pipeline.py +3 -1
- torchx/examples/pipelines/kfp/intro_pipeline.py +3 -1
- torchx/examples/torchx_out_of_sync_training.py +11 -0
- torchx/notebook.py +2 -0
- torchx/pipelines/kfp/__init__.py +2 -0
- torchx/pipelines/kfp/adapter.py +7 -4
- torchx/pipelines/kfp/version.py +2 -0
- torchx/runner/__init__.py +2 -0
- torchx/runner/api.py +78 -20
- torchx/runner/config.py +34 -3
- torchx/runner/events/__init__.py +37 -3
- torchx/runner/events/api.py +13 -2
- torchx/runner/events/handlers.py +2 -0
- torchx/runtime/tracking/__init__.py +2 -0
- torchx/runtime/tracking/api.py +2 -0
- torchx/schedulers/__init__.py +10 -5
- torchx/schedulers/api.py +3 -1
- torchx/schedulers/aws_batch_scheduler.py +4 -0
- torchx/schedulers/aws_sagemaker_scheduler.py +596 -0
- torchx/schedulers/devices.py +17 -4
- torchx/schedulers/docker_scheduler.py +38 -8
- torchx/schedulers/gcp_batch_scheduler.py +8 -9
- torchx/schedulers/ids.py +2 -0
- torchx/schedulers/kubernetes_mcad_scheduler.py +3 -1
- torchx/schedulers/kubernetes_scheduler.py +31 -5
- torchx/schedulers/local_scheduler.py +45 -6
- torchx/schedulers/lsf_scheduler.py +3 -1
- torchx/schedulers/ray/ray_driver.py +7 -7
- torchx/schedulers/ray_scheduler.py +1 -1
- torchx/schedulers/slurm_scheduler.py +3 -1
- torchx/schedulers/streams.py +2 -0
- torchx/specs/__init__.py +49 -8
- torchx/specs/api.py +87 -5
- torchx/specs/builders.py +61 -19
- torchx/specs/file_linter.py +8 -2
- torchx/specs/finder.py +2 -0
- torchx/specs/named_resources_aws.py +109 -2
- torchx/specs/named_resources_generic.py +2 -0
- torchx/specs/test/components/__init__.py +2 -0
- torchx/specs/test/components/a/__init__.py +2 -0
- torchx/specs/test/components/a/b/__init__.py +2 -0
- torchx/specs/test/components/a/b/c.py +2 -0
- torchx/specs/test/components/c/__init__.py +2 -0
- torchx/specs/test/components/c/d.py +2 -0
- torchx/tracker/__init__.py +2 -0
- torchx/tracker/api.py +4 -4
- torchx/tracker/backend/fsspec.py +2 -0
- torchx/util/cuda.py +2 -0
- torchx/util/datetime.py +2 -0
- torchx/util/entrypoints.py +6 -2
- torchx/util/io.py +2 -0
- torchx/util/log_tee_helpers.py +210 -0
- torchx/util/modules.py +2 -0
- torchx/util/session.py +42 -0
- torchx/util/shlex.py +2 -0
- torchx/util/strings.py +2 -0
- torchx/util/types.py +20 -2
- torchx/version.py +3 -1
- torchx/workspace/__init__.py +2 -0
- torchx/workspace/api.py +34 -1
- torchx/workspace/dir_workspace.py +2 -0
- torchx/workspace/docker_workspace.py +25 -2
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/METADATA +55 -48
- torchx_nightly-2025.1.14.dist-info/RECORD +123 -0
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/WHEEL +1 -1
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/entry_points.txt +0 -1
- torchx_nightly-2024.2.12.dist-info/RECORD +0 -119
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/LICENSE +0 -0
- {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.14.dist-info}/top_level.txt +0 -0
torchx/util/entrypoints.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from typing import Any, Dict, Optional
 
 import importlib_metadata as metadata
@@ -49,8 +51,7 @@ def _defer_load_ep(ep: EntryPoint) -> object:
 
 # pyre-ignore-all-errors[3, 2]
 def load_group(
-    group: str,
-    default: Optional[Dict[str, Any]] = None,
+    group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
 ):
     """
     Loads all the entry points specified by ``group`` and returns
@@ -70,6 +71,7 @@ def load_group(
     1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")``
     1. ``load_group("food")`` -> ``None``
     1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")``
+    1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None``
 
 
     If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn``
@@ -88,6 +90,8 @@ def load_group(
     entrypoints = metadata.entry_points().select(group=group)
 
     if len(entrypoints) == 0:
+        if skip_defaults:
+            return None
         return default
 
     eps = {}
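The practical effect of the new ``skip_defaults`` flag: callers can now tell "no entry points registered for the group" apart from "fell back to the supplied defaults". A minimal sketch of that distinction, assuming torchx is installed; the group name and the fallback mapping below are illustrative, not part of torchx:

from torchx.util.entrypoints import load_group

def _fallback() -> str:
    # hypothetical fallback component, used only for illustration
    return "fallback"

DEFAULTS = {"fallback": _fallback}

# Previous behavior: an unregistered group silently resolves to the defaults.
eps = load_group("my.hypothetical.group", default=DEFAULTS)

# New in this release: skip_defaults=True returns None instead, so the caller
# can detect that nothing was actually registered under the group.
eps_or_none = load_group("my.hypothetical.group", default=DEFAULTS, skip_defaults=True)
if eps_or_none is None:
    print("no entry points registered for my.hypothetical.group")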
torchx/util/io.py
CHANGED
torchx/util/log_tee_helpers.py
ADDED
@@ -0,0 +1,210 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+"""
+If you're wrapping the TorchX API with your own CLI, these functions can
+help show the logs of the job within your CLI, just like
+`torchx log`
+"""
+
+import logging
+import threading
+from queue import Queue
+from typing import List, Optional, TextIO, Tuple, TYPE_CHECKING
+
+from torchx.util.types import none_throws
+
+if TYPE_CHECKING:
+    from torchx.runner.api import Runner
+    from torchx.schedulers.api import Stream
+    from torchx.specs.api import AppDef
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+# A torchX job can have stderr/stdout for many replicas, of many roles
+# The scheduler API has functions that allow us to get,
+# with unspecified detail, the log lines of a given replica of
+# a given role.
+#
+# So, to neatly tee the results, we:
+# 1) Determine every role ID / replica ID pair we want to monitor
+# 2) Request the given stderr / stdout / combined streams from them (1 thread each)
+# 3) Concatenate each of those streams to a given destination file
+
+
+def _find_role_replicas(
+    app: "AppDef",
+    role_name: Optional[str],
+) -> List[Tuple[str, int]]:
+    """
+    Enumerate all (role, replica id) pairs in the given AppDef.
+    Replica IDs are 0-indexed, and range up to num_replicas,
+    for each role.
+    If role_name is provided, filters to only that name.
+    """
+    role_replicas = []
+    for role in app.roles:
+        if role_name is None or role_name == role.name:
+            for i in range(role.num_replicas):
+                role_replicas.append((role.name, i))
+    return role_replicas
+
+
+def _prefix_line(prefix: str, line: str) -> str:
+    """
+    _prefix_line ensure the prefix is still present even when dealing with return characters
+    """
+    if "\r" in line:
+        line = line.replace("\r", f"\r{prefix}")
+    if "\n" in line[:-1]:
+        line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
+    if not line.startswith("\r"):
+        line = f"{prefix}{line}"
+    return line
+
+
+def _print_log_lines_for_role_replica(
+    dst: TextIO,
+    app_handle: str,
+    regex: Optional[str],
+    runner: "Runner",
+    which_role: str,
+    which_replica: int,
+    exceptions: "Queue[Exception]",
+    should_tail: bool,
+    streams: Optional["Stream"],
+    colorize: bool = False,
+) -> None:
+    """
+    Helper function that'll run in parallel - one
+    per monitored replica of a given role.
+
+    Based on print_log_lines .. but not designed for TTY
+    """
+    try:
+        for line in runner.log_lines(
+            app_handle,
+            which_role,
+            which_replica,
+            regex,
+            should_tail=should_tail,
+            streams=streams,
+        ):
+            if colorize:
+                color_begin = "\033[32m"
+                color_end = "\033[0m"
+            else:
+                color_begin = ""
+                color_end = ""
+            prefix = f"{color_begin}{which_role}/{which_replica}{color_end} "
+            print(_prefix_line(prefix, line.strip()), file=dst, end="\n", flush=True)
+    except Exception as e:
+        exceptions.put(e)
+        raise
+
+
+def _start_threads_to_monitor_role_replicas(
+    dst: TextIO,
+    app_handle: str,
+    regex: Optional[str],
+    runner: "Runner",
+    which_role: Optional[str] = None,
+    should_tail: bool = False,
+    streams: Optional["Stream"] = None,
+    colorize: bool = False,
+) -> None:
+    threads = []
+
+    app = none_throws(runner.describe(app_handle))
+    replica_ids = _find_role_replicas(app, role_name=which_role)
+
+    # Holds exceptions raised by all threads, in a thread-safe
+    # object
+    exceptions = Queue()
+
+    if not replica_ids:
+        valid_roles = [role.name for role in app.roles]
+        raise ValueError(
+            f"{which_role} is not a valid role name. Available: {valid_roles}"
+        )
+
+    for role_name, replica_id in replica_ids:
+        threads.append(
+            threading.Thread(
+                target=_print_log_lines_for_role_replica,
+                kwargs={
+                    "dst": dst,
+                    "runner": runner,
+                    "app_handle": app_handle,
+                    "which_role": role_name,
+                    "which_replica": replica_id,
+                    "regex": regex,
+                    "should_tail": should_tail,
+                    "exceptions": exceptions,
+                    "streams": streams,
+                    "colorize": colorize,
+                },
+                daemon=True,
+            )
+        )
+
+    for t in threads:
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    # Retrieve all exceptions, print all except one and raise the first recorded exception
+    threads_exceptions = []
+    while not exceptions.empty():
+        threads_exceptions.append(exceptions.get())
+
+    if len(threads_exceptions) > 0:
+        for i in range(1, len(threads_exceptions)):
+            logger.error(threads_exceptions[i])
+
+        raise threads_exceptions[0]
+
+
+def tee_logs(
+    dst: TextIO,
+    app_handle: str,
+    regex: Optional[str],
+    runner: "Runner",
+    should_tail: bool = False,
+    streams: Optional["Stream"] = None,
+    colorize: bool = False,
+) -> threading.Thread:
+    """
+    Makes a thread, which in turn will start 1 thread per replica
+    per role, that tees that role-replica's logs to the given
+    destination buffer.
+
+    You'll need to start and join with this parent thread.
+
+    dst: TextIO to tee the logs into
+    app_handle: The return value of runner.run() or runner.schedule()
+    regex: Regex to filter the logs that are tee-d
+    runner: The Runner you used to schedule the job
+    should_tail: If true, continue until we run out of logs. Otherwise, just fetch
+        what's available
+    streams: Whether to fetch STDERR, STDOUT, or the temporally COMBINED (default) logs
+    """
+    thread = threading.Thread(
+        target=_start_threads_to_monitor_role_replicas,
+        kwargs={
+            "dst": dst,
+            "runner": runner,
+            "app_handle": app_handle,
+            "regex": None,
+            "should_tail": True,
+            "colorize": colorize,
+        },
+        daemon=True,
+    )
+    return thread
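For CLIs that wrap the TorchX runner, the new ``tee_logs`` helper mirrors every role/replica log stream into a single writable destination. A minimal sketch of the intended usage, assuming a scheduler that can run locally; the built-in ``torchx.components.utils.echo`` component and the ``local_cwd`` scheduler are illustrative choices, not requirements of the helper:

import sys

from torchx.components.utils import echo
from torchx.runner import get_runner
from torchx.util.log_tee_helpers import tee_logs

runner = get_runner()

# Submit any AppDef; echo is just a convenient built-in component.
app_handle = runner.run(echo(msg="hello world"), scheduler="local_cwd")

# tee_logs returns an un-started thread; per the docstring, the caller
# is responsible for starting and joining it.
log_thread = tee_logs(
    dst=sys.stdout,
    app_handle=app_handle,
    regex=None,
    runner=runner,
    should_tail=True,
)
log_thread.start()
runner.wait(app_handle)  # block until the job finishes
log_thread.join()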
torchx/util/modules.py
CHANGED
torchx/util/session.py
ADDED
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import os
+import uuid
+from typing import Optional
+
+TORCHX_INTERNAL_SESSION_ID = "TORCHX_INTERNAL_SESSION_ID"
+
+CURRENT_SESSION_ID: Optional[str] = None
+
+
+def get_session_id_or_create_new() -> str:
+    """
+    Returns the current session ID, or creates a new one if none exists.
+    The session ID remains the same as long as it is in the same process.
+    Please DO NOT use this function out of torchx codebase.
+    """
+    global CURRENT_SESSION_ID
+    if CURRENT_SESSION_ID:
+        return CURRENT_SESSION_ID
+    env_session_id = os.getenv(TORCHX_INTERNAL_SESSION_ID)
+    if env_session_id:
+        CURRENT_SESSION_ID = env_session_id
+        return CURRENT_SESSION_ID
+    session_id = str(uuid.uuid4())
+    CURRENT_SESSION_ID = session_id
+    return session_id
+
+
+def get_torchx_session_id() -> Optional[str]:
+    """
+    Returns the torchx session ID.
+    Please use this function to get the session ID out of torchx codebase.
+    """
+    return CURRENT_SESSION_ID
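The session ID is process-scoped: the first call mints a UUID, later calls return the cached value, and a value already present in ``TORCHX_INTERNAL_SESSION_ID`` takes precedence over minting a new one. A small sketch of that behavior; exporting the variable so that child processes adopt the parent's ID is an assumption about intent inferred from the env var name, not code shown in this release:

import os

from torchx.util.session import get_session_id_or_create_new, get_torchx_session_id

sid = get_session_id_or_create_new()            # mints (or adopts) the ID
assert get_session_id_or_create_new() == sid    # stable within this process
assert get_torchx_session_id() == sid           # read-only accessor sees the same value

# Assumption: setting the variable before spawning a child process that imports
# torchx would make the child reuse this session ID instead of minting its own.
os.environ["TORCHX_INTERNAL_SESSION_ID"] = sid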
torchx/util/shlex.py
CHANGED
torchx/util/strings.py
CHANGED
torchx/util/types.py
CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
@@ -43,6 +45,9 @@ def to_dict(arg: str) -> Dict[str, str]:
 
        to_dict("FOO=v1") == {"FOO": "v1"}
 
+       to_dict("FOO=''") == {"FOO": ""}
+       to_dict('FOO=""') == {"FOO": ""}
+
        to_dict("FOO=v1,v2") == {"FOO": "v1,v2"]}
        to_dict("FOO=v1;v2") == {"FOO": "v1;v2"]}
        to_dict("FOO=v1;v2") == {"FOO": "v1;v2,"]}
@@ -68,6 +73,9 @@ def to_dict(arg: str) -> Dict[str, str]:
         else:
             return vk[0:idx].strip(), vk[idx + 1 :].strip()
 
+    def to_val(val: str) -> str:
+        return val if val != '""' and val != "''" else ""
+
     arg_map: Dict[str, str] = {}
 
     if not arg:
@@ -90,10 +98,10 @@ def to_dict(arg: str) -> Dict[str, str]:
     # middle elements are value_{n}<delim>key_{n+1}
     for vk in split_arg[1 : split_arg_len - 1]:  # python deals with
         val, key_next = parse_val_key(vk)
-        arg_map[key] = val
+        arg_map[key] = to_val(val)
         key = key_next
     val = split_arg[-1]  # last element is always a value
-    arg_map[key] = val
+    arg_map[key] = to_val(val)
     return arg_map
 
 
@@ -120,6 +128,16 @@ def _decode_string_to_list(
     return arg_values
 
 
+def decode(encoded_value: Any, annotation: Any):
+    if encoded_value is None:
+        return None
+    if is_bool(annotation):
+        return encoded_value and encoded_value.lower() == "true"
+    if not is_primitive(annotation) and type(encoded_value) == str:
+        return decode_from_string(encoded_value, annotation)
+    return encoded_value
+
+
 def decode_from_string(
     encoded_value: str, annotation: Any
 ) -> Union[Dict[Any, Any], List[Any], None]:
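The net effect of the new ``to_val`` helper: explicitly quoted empty values now decode to the empty string rather than a literal pair of quote characters. A quick illustration, restricted to the cases spelled out in the updated docstring:

from torchx.util.types import to_dict

assert to_dict("FOO=v1") == {"FOO": "v1"}

# New in this release: quoted empties become "" instead of "''" or '""'
assert to_dict("FOO=''") == {"FOO": ""}
assert to_dict('FOO=""') == {"FOO": ""}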
torchx/version.py
CHANGED
@@ -5,6 +5,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 from torchx.util.entrypoints import load
 
 # Follows PEP-0440 version scheme guidelines
@@ -16,7 +18,7 @@ from torchx.util.entrypoints import load
 # 0.1.0bN  # Beta release
 # 0.1.0rcN  # Release Candidate
 # 0.1.0  # Final release
-__version__ = "0.
+__version__ = "0.8.0dev0"
 
 
 # Use the github container registry images corresponding to the current package
torchx/workspace/__init__.py
CHANGED
torchx/workspace/api.py
CHANGED
@@ -4,10 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import abc
 import fnmatch
 import posixpath
-from
+from dataclasses import dataclass
+from typing import Any, Dict, Generic, Iterable, Mapping, Tuple, TYPE_CHECKING, TypeVar
 
 from torchx.specs import AppDef, CfgVal, Role, runopts
 
@@ -18,6 +21,36 @@ TORCHX_IGNORE = ".torchxignore"
 
 T = TypeVar("T")
 
+PackageType = TypeVar("PackageType")
+WorkspaceConfigType = TypeVar("WorkspaceConfigType")
+
+
+@dataclass
+class PkgInfo(Generic[PackageType]):
+    """
+    Convenience class used to specify information regarding the built workspace
+    """
+
+    img: str
+    lazy_overrides: Dict[str, Any]
+    metadata: PackageType
+
+
+@dataclass
+class WorkspaceBuilder(Generic[PackageType, WorkspaceConfigType]):
+    cfg: WorkspaceConfigType
+
+    @abc.abstractmethod
+    def build_workspace(self, sync: bool = True) -> PkgInfo[PackageType]:
+        """
+        Builds the specified ``workspace`` with respect to ``img``.
+        In the simplest case, this method builds a new image.
+        Certain (more efficient) implementations build
+        incremental diff patches that overlay on top of the role's image.
+
+        """
+        pass
+
 
 class WorkspaceMixin(abc.ABC, Generic[T]):
     """
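The new generics give custom scheduler integrations a typed hook for workspace builds. A hypothetical subclass might look like the sketch below; the ``TarballInfo`` payload, the config dict, and the build logic are invented for illustration and are not part of torchx:

from dataclasses import dataclass
from typing import Dict

from torchx.workspace.api import PkgInfo, WorkspaceBuilder


@dataclass
class TarballInfo:
    path: str  # where the packed workspace archive would live


@dataclass
class TarballWorkspaceBuilder(WorkspaceBuilder[TarballInfo, Dict[str, str]]):
    def build_workspace(self, sync: bool = True) -> PkgInfo[TarballInfo]:
        # A real implementation would pack self.cfg["workspace"] into an archive;
        # this stub only shows the shape of the returned PkgInfo.
        return PkgInfo(
            img="example.com/base:latest",
            lazy_overrides={},
            metadata=TarballInfo(path="/tmp/workspace.tar.gz"),
        )


builder = TarballWorkspaceBuilder(cfg={"workspace": "."})
info = builder.build_workspace()
print(info.img, info.metadata.path)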
torchx/workspace/docker_workspace.py
CHANGED

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-strict
+
 import io
 import logging
 import posixpath
@@ -16,6 +18,7 @@ from typing import Dict, IO, Iterable, Mapping, Optional, TextIO, Tuple, TYPE_CH
 import fsspec
 
 import torchx
+from docker.errors import BuildError
 from torchx.specs import AppDef, CfgVal, Role, runopts
 from torchx.workspace.api import walk_workspace, WorkspaceMixin
 
@@ -91,6 +94,12 @@ class DockerWorkspaceMixin(WorkspaceMixin[Dict[str, Tuple[str, str]]]):
             type_=str,
             help="(remote jobs) the image repository to use when pushing patched images, must have push access. Ex: example.com/your/container",
         )
+        opts.add(
+            "quiet",
+            type_=bool,
+            default=False,
+            help="whether to suppress verbose output for image building. Defaults to ``False``.",
+        )
         return opts
 
     def build_workspace_and_update_role(
@@ -119,7 +128,7 @@ class DockerWorkspaceMixin(WorkspaceMixin[Dict[str, Tuple[str, str]]]):
                    f"failed to pull image {role.image}, falling back to local: {e}"
                )
            log.info("Building workspace docker image (this may take a while)...")
-
+            build_events = self._docker_client.api.build(
                fileobj=context,
                custom_context=True,
                dockerfile=TORCHX_DOCKERFILE,
@@ -129,12 +138,26 @@ class DockerWorkspaceMixin(WorkspaceMixin[Dict[str, Tuple[str, str]]]):
                },
                pull=False,
                rm=True,
+                decode=True,
                labels={
                    self.LABEL_VERSION: torchx.__version__,
                },
            )
+            image_id = None
+            for event in build_events:
+                if message := event.get("stream"):
+                    if not cfg.get("quiet", False):
+                        message = message.strip("\r\n").strip("\n")
+                        if message:
+                            log.info(message)
+                if aux := event.get("aux"):
+                    image_id = aux["ID"]
+                if error := event.get("error"):
+                    raise BuildError(reason=error, build_log=None)
            if len(old_imgs) == 0 or role.image not in old_imgs:
-
+                assert image_id, "image id was not found"
+                role.image = image_id
+
        finally:
            context.close()
 
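Downstream, the new ``quiet`` runopt lets callers silence the per-line Docker build output that the workspace build now streams through ``log.info``. A sketch of passing it via the runner cfg; the component, scheduler, and workspace path here are illustrative, and the exact ``run`` keyword arguments should be checked against the installed torchx version:

from torchx.components.utils import echo
from torchx.runner import get_runner

runner = get_runner()

# "quiet" is the DockerWorkspaceMixin runopt added in this release; the other
# values are placeholders for whatever job you actually want to launch.
app_handle = runner.run(
    echo(msg="hello world"),
    scheduler="local_docker",
    cfg={"quiet": True},
    workspace="file://.",  # triggers the Docker workspace build
)
print(runner.status(app_handle))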