torchx-nightly 2024.2.12__py3-none-any.whl → 2025.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic. Click here for more details.

Files changed (102) hide show
  1. torchx/__init__.py +2 -0
  2. torchx/apps/serve/serve.py +2 -0
  3. torchx/apps/utils/booth_main.py +2 -0
  4. torchx/apps/utils/copy_main.py +2 -0
  5. torchx/apps/utils/process_monitor.py +2 -0
  6. torchx/cli/__init__.py +2 -0
  7. torchx/cli/argparse_util.py +38 -3
  8. torchx/cli/cmd_base.py +2 -0
  9. torchx/cli/cmd_cancel.py +2 -0
  10. torchx/cli/cmd_configure.py +2 -0
  11. torchx/cli/cmd_describe.py +2 -0
  12. torchx/cli/cmd_list.py +2 -0
  13. torchx/cli/cmd_log.py +6 -24
  14. torchx/cli/cmd_run.py +30 -12
  15. torchx/cli/cmd_runopts.py +2 -0
  16. torchx/cli/cmd_status.py +2 -0
  17. torchx/cli/cmd_tracker.py +2 -0
  18. torchx/cli/colors.py +2 -0
  19. torchx/cli/main.py +2 -0
  20. torchx/components/__init__.py +2 -0
  21. torchx/components/component_test_base.py +2 -0
  22. torchx/components/dist.py +2 -0
  23. torchx/components/integration_tests/component_provider.py +2 -0
  24. torchx/components/integration_tests/integ_tests.py +2 -0
  25. torchx/components/serve.py +2 -0
  26. torchx/components/structured_arg.py +2 -0
  27. torchx/components/utils.py +2 -0
  28. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  29. torchx/examples/apps/lightning/data.py +5 -3
  30. torchx/examples/apps/lightning/model.py +2 -0
  31. torchx/examples/apps/lightning/profiler.py +7 -4
  32. torchx/examples/apps/lightning/train.py +2 -0
  33. torchx/examples/pipelines/kfp/advanced_pipeline.py +2 -0
  34. torchx/examples/pipelines/kfp/dist_pipeline.py +3 -1
  35. torchx/examples/pipelines/kfp/intro_pipeline.py +3 -1
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/pipelines/kfp/__init__.py +2 -0
  39. torchx/pipelines/kfp/adapter.py +7 -4
  40. torchx/pipelines/kfp/version.py +2 -0
  41. torchx/runner/__init__.py +2 -0
  42. torchx/runner/api.py +78 -20
  43. torchx/runner/config.py +34 -3
  44. torchx/runner/events/__init__.py +37 -3
  45. torchx/runner/events/api.py +13 -2
  46. torchx/runner/events/handlers.py +2 -0
  47. torchx/runtime/tracking/__init__.py +2 -0
  48. torchx/runtime/tracking/api.py +2 -0
  49. torchx/schedulers/__init__.py +10 -5
  50. torchx/schedulers/api.py +3 -1
  51. torchx/schedulers/aws_batch_scheduler.py +4 -0
  52. torchx/schedulers/aws_sagemaker_scheduler.py +596 -0
  53. torchx/schedulers/devices.py +17 -4
  54. torchx/schedulers/docker_scheduler.py +38 -8
  55. torchx/schedulers/gcp_batch_scheduler.py +8 -9
  56. torchx/schedulers/ids.py +2 -0
  57. torchx/schedulers/kubernetes_mcad_scheduler.py +3 -1
  58. torchx/schedulers/kubernetes_scheduler.py +31 -5
  59. torchx/schedulers/local_scheduler.py +45 -6
  60. torchx/schedulers/lsf_scheduler.py +3 -1
  61. torchx/schedulers/ray/ray_driver.py +7 -7
  62. torchx/schedulers/ray_scheduler.py +1 -1
  63. torchx/schedulers/slurm_scheduler.py +3 -1
  64. torchx/schedulers/streams.py +2 -0
  65. torchx/specs/__init__.py +49 -8
  66. torchx/specs/api.py +87 -5
  67. torchx/specs/builders.py +61 -19
  68. torchx/specs/file_linter.py +8 -2
  69. torchx/specs/finder.py +2 -0
  70. torchx/specs/named_resources_aws.py +109 -2
  71. torchx/specs/named_resources_generic.py +2 -0
  72. torchx/specs/test/components/__init__.py +2 -0
  73. torchx/specs/test/components/a/__init__.py +2 -0
  74. torchx/specs/test/components/a/b/__init__.py +2 -0
  75. torchx/specs/test/components/a/b/c.py +2 -0
  76. torchx/specs/test/components/c/__init__.py +2 -0
  77. torchx/specs/test/components/c/d.py +2 -0
  78. torchx/tracker/__init__.py +2 -0
  79. torchx/tracker/api.py +4 -4
  80. torchx/tracker/backend/fsspec.py +2 -0
  81. torchx/util/cuda.py +2 -0
  82. torchx/util/datetime.py +2 -0
  83. torchx/util/entrypoints.py +6 -2
  84. torchx/util/io.py +2 -0
  85. torchx/util/log_tee_helpers.py +210 -0
  86. torchx/util/modules.py +2 -0
  87. torchx/util/session.py +42 -0
  88. torchx/util/shlex.py +2 -0
  89. torchx/util/strings.py +2 -0
  90. torchx/util/types.py +20 -2
  91. torchx/version.py +3 -1
  92. torchx/workspace/__init__.py +2 -0
  93. torchx/workspace/api.py +34 -1
  94. torchx/workspace/dir_workspace.py +2 -0
  95. torchx/workspace/docker_workspace.py +25 -2
  96. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/METADATA +55 -48
  97. torchx_nightly-2025.1.15.dist-info/RECORD +123 -0
  98. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/WHEEL +1 -1
  99. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/entry_points.txt +0 -1
  100. torchx_nightly-2024.2.12.dist-info/RECORD +0 -119
  101. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/LICENSE +0 -0
  102. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  from typing import Any, Dict, Optional
8
10
 
9
11
  import importlib_metadata as metadata
@@ -49,8 +51,7 @@ def _defer_load_ep(ep: EntryPoint) -> object:
49
51
 
50
52
  # pyre-ignore-all-errors[3, 2]
51
53
  def load_group(
52
- group: str,
53
- default: Optional[Dict[str, Any]] = None,
54
+ group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
54
55
  ):
55
56
  """
56
57
  Loads all the entry points specified by ``group`` and returns
@@ -70,6 +71,7 @@ def load_group(
70
71
  1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")``
71
72
  1. ``load_group("food")`` -> ``None``
72
73
  1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")``
74
+ 1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None``
73
75
 
74
76
 
75
77
  If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn``
@@ -88,6 +90,8 @@ def load_group(
88
90
  entrypoints = metadata.entry_points().select(group=group)
89
91
 
90
92
  if len(entrypoints) == 0:
93
+ if skip_defaults:
94
+ return None
91
95
  return default
92
96
 
93
97
  eps = {}
torchx/util/io.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  from os import path
8
10
  from pathlib import Path
9
11
  from typing import Optional
@@ -0,0 +1,210 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ """
10
+ If you're wrapping the TorchX API with your own CLI, these functions can
11
+ help show the logs of the job within your CLI, just like
12
+ `torchx log`
13
+ """
14
+
15
+ import logging
16
+ import threading
17
+ from queue import Queue
18
+ from typing import List, Optional, TextIO, Tuple, TYPE_CHECKING
19
+
20
+ from torchx.util.types import none_throws
21
+
22
+ if TYPE_CHECKING:
23
+ from torchx.runner.api import Runner
24
+ from torchx.schedulers.api import Stream
25
+ from torchx.specs.api import AppDef
26
+
27
+ logger: logging.Logger = logging.getLogger(__name__)
28
+
29
+ # A torchX job can have stderr/stdout for many replicas, of many roles
30
+ # The scheduler API has functions that allow us to get,
31
+ # with unspecified detail, the log lines of a given replica of
32
+ # a given role.
33
+ #
34
+ # So, to neatly tee the results, we:
35
+ # 1) Determine every role ID / replica ID pair we want to monitor
36
+ # 2) Request the given stderr / stdout / combined streams from them (1 thread each)
37
+ # 3) Concatenate each of those streams to a given destination file
38
+
39
+
40
+ def _find_role_replicas(
41
+ app: "AppDef",
42
+ role_name: Optional[str],
43
+ ) -> List[Tuple[str, int]]:
44
+ """
45
+ Enumerate all (role, replica id) pairs in the given AppDef.
46
+ Replica IDs are 0-indexed, and range up to num_replicas,
47
+ for each role.
48
+ If role_name is provided, filters to only that name.
49
+ """
50
+ role_replicas = []
51
+ for role in app.roles:
52
+ if role_name is None or role_name == role.name:
53
+ for i in range(role.num_replicas):
54
+ role_replicas.append((role.name, i))
55
+ return role_replicas
56
+
57
+
58
+ def _prefix_line(prefix: str, line: str) -> str:
59
+ """
60
+ _prefix_line ensure the prefix is still present even when dealing with return characters
61
+ """
62
+ if "\r" in line:
63
+ line = line.replace("\r", f"\r{prefix}")
64
+ if "\n" in line[:-1]:
65
+ line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
66
+ if not line.startswith("\r"):
67
+ line = f"{prefix}{line}"
68
+ return line
69
+
70
+
71
def _print_log_lines_for_role_replica(
    dst: TextIO,
    app_handle: str,
    regex: Optional[str],
    runner: "Runner",
    which_role: str,
    which_replica: int,
    exceptions: "Queue[Exception]",
    should_tail: bool,
    streams: Optional["Stream"],
    colorize: bool = False,
) -> None:
    """
    Copy the log lines of one replica of one role into ``dst``.

    Intended to run in its own thread, one per monitored role replica. Each
    line is written as ``"<role>/<replica> <line>"`` (the tag is wrapped in a
    green ANSI color code when ``colorize`` is true) and ``dst`` is flushed
    after every line. Based on ``print_log_lines``, but not designed for a TTY.

    Any exception raised while reading or writing is first put on
    ``exceptions`` so the coordinating thread can re-raise it, and is then
    re-raised in this thread.
    """
    # The prefix depends only on the role/replica, so build it once instead of
    # once per log line.
    if colorize:
        prefix = f"\033[32m{which_role}/{which_replica}\033[0m "
    else:
        prefix = f"{which_role}/{which_replica} "

    try:
        for line in runner.log_lines(
            app_handle,
            which_role,
            which_replica,
            regex,
            should_tail=should_tail,
            streams=streams,
        ):
            # Only drop the trailing line terminator (print adds its own).
            # Leading whitespace is kept so indented output such as
            # tracebacks keeps its structure.
            print(
                _prefix_line(prefix, line.rstrip("\r\n")),
                file=dst,
                end="\n",
                flush=True,
            )
    except Exception as e:
        exceptions.put(e)
        raise
109
+
110
+
111
def _start_threads_to_monitor_role_replicas(
    dst: TextIO,
    app_handle: str,
    regex: Optional[str],
    runner: "Runner",
    which_role: Optional[str] = None,
    should_tail: bool = False,
    streams: Optional["Stream"] = None,
    colorize: bool = False,
) -> None:
    """
    Tee the logs of the selected role replicas of ``app_handle`` into ``dst``.

    One daemon thread is started per ``(role, replica)`` pair. These are all
    replicas of all roles, or only those of ``which_role`` when it is given.
    This function blocks until every thread has finished.

    Raises:
        ValueError: if no role replica matches ``which_role``, or, when
            ``which_role`` is ``None``, if the app has no replicas at all.
        Exception: the first exception recorded by any replica thread. Any
            further thread exceptions are logged (with traceback) instead.
    """
    app = none_throws(runner.describe(app_handle))
    replica_ids = _find_role_replicas(app, role_name=which_role)

    if not replica_ids:
        valid_roles = [role.name for role in app.roles]
        if which_role is None:
            # Nothing to monitor at all; don't blame a role name the caller
            # never passed.
            raise ValueError(
                f"{app_handle} has no role replicas to monitor. Roles: {valid_roles}"
            )
        raise ValueError(
            f"{which_role} is not a valid role name. Available: {valid_roles}"
        )

    # Thread-safe sink for exceptions raised by the per-replica threads.
    exceptions: "Queue[Exception]" = Queue()

    threads = [
        threading.Thread(
            target=_print_log_lines_for_role_replica,
            kwargs={
                "dst": dst,
                "runner": runner,
                "app_handle": app_handle,
                "which_role": role_name,
                "which_replica": replica_id,
                "regex": regex,
                "should_tail": should_tail,
                "exceptions": exceptions,
                "streams": streams,
                "colorize": colorize,
            },
            daemon=True,
        )
        for role_name, replica_id in replica_ids
    ]

    for t in threads:
        t.start()

    for t in threads:
        t.join()

    # Every thread has been joined, so draining the queue here cannot race
    # with producers.
    thread_exceptions: List[Exception] = []
    while not exceptions.empty():
        thread_exceptions.append(exceptions.get())

    if thread_exceptions:
        first, *rest = thread_exceptions
        # Report the extra failures with their tracebacks, then surface the
        # first one to the caller.
        for exc in rest:
            logger.error(exc, exc_info=exc)
        raise first
172
+
173
+
174
def tee_logs(
    dst: TextIO,
    app_handle: str,
    regex: Optional[str],
    runner: "Runner",
    should_tail: bool = False,
    streams: Optional["Stream"] = None,
    colorize: bool = False,
) -> threading.Thread:
    """
    Create (but do not start) a daemon thread that tees the job's logs.

    When started, the thread starts one thread per replica per role, each of
    which tees that role replica's logs into ``dst``. You need to ``start()``
    and ``join()`` the returned parent thread yourself.

    dst: TextIO to tee the logs into
    app_handle: The return value of runner.run() or runner.schedule()
    regex: Regex to filter the logs that are tee-d (``None`` for no filtering)
    runner: The Runner you used to schedule the job
    should_tail: If true, keep following the logs as they are produced;
        otherwise, just fetch what's available
    streams: Whether to fetch STDERR, STDOUT, or the temporally COMBINED logs
        (``None`` leaves the choice to the runner/scheduler default)
    colorize: If true, wrap each line's ``role/replica`` tag in ANSI color codes
    """
    # All caller-provided options are forwarded. The previous version
    # hard-coded ``regex=None`` and ``should_tail=True`` and dropped ``streams``.
    return threading.Thread(
        target=_start_threads_to_monitor_role_replicas,
        kwargs={
            "dst": dst,
            "runner": runner,
            "app_handle": app_handle,
            "regex": regex,
            "should_tail": should_tail,
            "streams": streams,
            "colorize": colorize,
        },
        daemon=True,
    )
torchx/util/modules.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import importlib
8
10
  from types import ModuleType
9
11
  from typing import Callable, Optional, Union
torchx/util/session.py ADDED
@@ -0,0 +1,42 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the BSD-style license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
9
+
10
+ import os
11
+ import uuid
12
+ from typing import Optional
13
+
14
# Name of the environment variable from which an existing session ID is read.
TORCHX_INTERNAL_SESSION_ID = "TORCHX_INTERNAL_SESSION_ID"

# Process-wide cached session ID; populated lazily by
# ``get_session_id_or_create_new``.
CURRENT_SESSION_ID: Optional[str] = None


def get_session_id_or_create_new() -> str:
    """
    Return this process's session ID, creating and caching it on first use.

    The first non-empty source wins:
    the cached value, then the ``TORCHX_INTERNAL_SESSION_ID`` environment
    variable, then a freshly generated UUID4 string. Once set, the value is
    reused for the rest of the process.
    Internal to torchx: please DO NOT call this from outside the torchx codebase.
    """
    global CURRENT_SESSION_ID
    if not CURRENT_SESSION_ID:
        CURRENT_SESSION_ID = os.getenv(TORCHX_INTERNAL_SESSION_ID) or str(
            uuid.uuid4()
        )
    return CURRENT_SESSION_ID


def get_torchx_session_id() -> Optional[str]:
    """
    Return the cached torchx session ID, or ``None`` if none has been set yet.

    This is the accessor to use from outside the torchx codebase; it never
    creates a new ID.
    """
    return CURRENT_SESSION_ID
torchx/util/shlex.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import shlex
8
10
  from typing import Iterable
9
11
 
torchx/util/strings.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import re
8
10
 
9
11
 
torchx/util/types.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import inspect
8
10
  from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
9
11
 
@@ -43,6 +45,9 @@ def to_dict(arg: str) -> Dict[str, str]:
43
45
 
44
46
  to_dict("FOO=v1") == {"FOO": "v1"}
45
47
 
48
+ to_dict("FOO=''") == {"FOO": ""}
49
+ to_dict('FOO=""') == {"FOO": ""}
50
+
46
51
  to_dict("FOO=v1,v2") == {"FOO": "v1,v2"]}
47
52
  to_dict("FOO=v1;v2") == {"FOO": "v1;v2"]}
48
53
  to_dict("FOO=v1;v2") == {"FOO": "v1;v2,"]}
@@ -68,6 +73,9 @@ def to_dict(arg: str) -> Dict[str, str]:
68
73
  else:
69
74
  return vk[0:idx].strip(), vk[idx + 1 :].strip()
70
75
 
76
+ def to_val(val: str) -> str:
77
+ return val if val != '""' and val != "''" else ""
78
+
71
79
  arg_map: Dict[str, str] = {}
72
80
 
73
81
  if not arg:
@@ -90,10 +98,10 @@ def to_dict(arg: str) -> Dict[str, str]:
90
98
  # middle elements are value_{n}<delim>key_{n+1}
91
99
  for vk in split_arg[1 : split_arg_len - 1]: # python deals with
92
100
  val, key_next = parse_val_key(vk)
93
- arg_map[key] = val
101
+ arg_map[key] = to_val(val)
94
102
  key = key_next
95
103
  val = split_arg[-1] # last element is always a value
96
- arg_map[key] = val
104
+ arg_map[key] = to_val(val)
97
105
  return arg_map
98
106
 
99
107
 
@@ -120,6 +128,16 @@ def _decode_string_to_list(
120
128
  return arg_values
121
129
 
122
130
 
131
def decode(encoded_value: Any, annotation: Any) -> Any:
    """
    Decode ``encoded_value`` according to the type ``annotation``.

    * ``None`` is returned unchanged.
    * For ``bool`` annotations, a value that already is a ``bool`` is returned
      as-is. Otherwise the result is ``True`` only for a non-empty string equal
      to ``"true"`` (case-insensitive), so an empty string decodes to ``False``.
    * For annotations that ``is_primitive`` rejects, string values are parsed
      with ``decode_from_string``.
    * Any other value is returned unchanged.
    """
    if encoded_value is None:
        return None
    if is_bool(annotation):
        if isinstance(encoded_value, bool):
            # Previously ``True.lower()`` raised AttributeError here.
            return encoded_value
        # bool(...) so that "" yields False instead of leaking "" through.
        return bool(encoded_value) and encoded_value.lower() == "true"
    if not is_primitive(annotation) and isinstance(encoded_value, str):
        return decode_from_string(encoded_value, annotation)
    return encoded_value
139
+
140
+
123
141
  def decode_from_string(
124
142
  encoded_value: str, annotation: Any
125
143
  ) -> Union[Dict[Any, Any], List[Any], None]:
torchx/version.py CHANGED
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  from torchx.util.entrypoints import load
9
11
 
10
12
  # Follows PEP-0440 version scheme guidelines
@@ -16,7 +18,7 @@ from torchx.util.entrypoints import load
16
18
  # 0.1.0bN # Beta release
17
19
  # 0.1.0rcN # Release Candidate
18
20
  # 0.1.0 # Final release
19
- __version__ = "0.7.0dev0"
21
+ __version__ = "0.8.0dev0"
20
22
 
21
23
 
22
24
  # Use the github container registry images corresponding to the current package
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  """
8
10
  Status: Beta
9
11
 
torchx/workspace/api.py CHANGED
@@ -4,10 +4,13 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import abc
8
10
  import fnmatch
9
11
  import posixpath
10
- from typing import Generic, Iterable, Mapping, Tuple, TYPE_CHECKING, TypeVar
12
+ from dataclasses import dataclass
13
+ from typing import Any, Dict, Generic, Iterable, Mapping, Tuple, TYPE_CHECKING, TypeVar
11
14
 
12
15
  from torchx.specs import AppDef, CfgVal, Role, runopts
13
16
 
@@ -18,6 +21,36 @@ TORCHX_IGNORE = ".torchxignore"
18
21
 
19
22
  T = TypeVar("T")
20
23
 
24
PackageType = TypeVar("PackageType")
WorkspaceConfigType = TypeVar("WorkspaceConfigType")


@dataclass
class PkgInfo(Generic[PackageType]):
    """
    Convenience class used to specify information regarding the built workspace.
    """

    # NOTE(review): presumably the image produced/used by the build — confirm
    # against concrete builders.
    img: str
    # Overrides whose interpretation is left to the concrete builder/consumer
    # (TODO confirm semantics).
    lazy_overrides: Dict[str, Any]
    # Builder-specific package metadata (typed by ``PackageType``).
    metadata: PackageType


@dataclass
class WorkspaceBuilder(abc.ABC, Generic[PackageType, WorkspaceConfigType]):
    """
    Base class for builders that turn a workspace configuration into a package.

    Subclasses must implement ``build_workspace``. Deriving from ``abc.ABC``
    (like ``WorkspaceMixin``) makes the ``@abc.abstractmethod`` marker below
    effective, so incomplete subclasses cannot be instantiated.
    """

    # Builder configuration (type given by ``WorkspaceConfigType``).
    cfg: WorkspaceConfigType

    @abc.abstractmethod
    def build_workspace(self, sync: bool = True) -> PkgInfo[PackageType]:
        """
        Build the workspace described by ``self.cfg`` and return its ``PkgInfo``.

        In the simplest case, this method builds a new image. Certain (more
        efficient) implementations build incremental diff patches that overlay
        on top of the role's image.

        ``sync``'s exact meaning is implementation-defined; presumably it
        controls whether the build completes before returning (TODO confirm).
        """
+
21
54
 
22
55
  class WorkspaceMixin(abc.ABC, Generic[T]):
23
56
  """
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  import os
9
11
  import posixpath
10
12
  import shutil
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import io
8
10
  import logging
9
11
  import posixpath
@@ -16,6 +18,7 @@ from typing import Dict, IO, Iterable, Mapping, Optional, TextIO, Tuple, TYPE_CH
16
18
  import fsspec
17
19
 
18
20
  import torchx
21
+ from docker.errors import BuildError
19
22
  from torchx.specs import AppDef, CfgVal, Role, runopts
20
23
  from torchx.workspace.api import walk_workspace, WorkspaceMixin
21
24
 
@@ -91,6 +94,12 @@ class DockerWorkspaceMixin(WorkspaceMixin[Dict[str, Tuple[str, str]]]):
91
94
  type_=str,
92
95
  help="(remote jobs) the image repository to use when pushing patched images, must have push access. Ex: example.com/your/container",
93
96
  )
97
+ opts.add(
98
+ "quiet",
99
+ type_=bool,
100
+ default=False,
101
+ help="whether to suppress verbose output for image building. Defaults to ``False``.",
102
+ )
94
103
  return opts
95
104
 
96
105
  def build_workspace_and_update_role(
@@ -119,7 +128,7 @@ class DockerWorkspaceMixin(WorkspaceMixin[Dict[str, Tuple[str, str]]]):
119
128
  f"failed to pull image {role.image}, falling back to local: {e}"
120
129
  )
121
130
  log.info("Building workspace docker image (this may take a while)...")
122
- image, _ = self._docker_client.images.build(
131
+ build_events = self._docker_client.api.build(
123
132
  fileobj=context,
124
133
  custom_context=True,
125
134
  dockerfile=TORCHX_DOCKERFILE,
@@ -129,12 +138,26 @@ class DockerWorkspaceMixin(WorkspaceMixin[Dict[str, Tuple[str, str]]]):
129
138
  },
130
139
  pull=False,
131
140
  rm=True,
141
+ decode=True,
132
142
  labels={
133
143
  self.LABEL_VERSION: torchx.__version__,
134
144
  },
135
145
  )
146
+ image_id = None
147
+ for event in build_events:
148
+ if message := event.get("stream"):
149
+ if not cfg.get("quiet", False):
150
+ message = message.strip("\r\n").strip("\n")
151
+ if message:
152
+ log.info(message)
153
+ if aux := event.get("aux"):
154
+ image_id = aux["ID"]
155
+ if error := event.get("error"):
156
+ raise BuildError(reason=error, build_log=None)
136
157
  if len(old_imgs) == 0 or role.image not in old_imgs:
137
- role.image = image.id
158
+ assert image_id, "image id was not found"
159
+ role.image = image_id
160
+
138
161
  finally:
139
162
  context.close()
140
163