torchx-nightly 2023.10.21__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic. Click here for more details.

Files changed (110) hide show
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +7 -6
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +79 -5
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +10 -8
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +431 -32
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +254 -22
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2023.10.21.dist-info/METADATA +0 -174
  108. torchx_nightly-2023.10.21.dist-info/RECORD +0 -118
  109. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,106 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ """
10
+ Overlays are JSON structs applied to :py:class:`~torchx.specs.AppDef` and :py:class:`~torchx.specs.Role`
11
+ to specify attributes of the scheduler's submit-job request that are not currently representable
12
+ as attributes of :py:class:`~torchx.specs.AppDef` and :py:class:`~torchx.specs.Role`.
13
+
14
+ For end-users, here are a few use-cases of overlays:
15
+
16
+ 1. A new version of the scheduler has concepts/features that have not yet been added to TorchX.
17
+ 2. A bespoke internal scheduler has custom features that do not generalize hence not in TorchX.
18
+ 3. Re-using a pre-built ``AppDef`` but need to make a small change to the resulting scheduler request.
19
+
20
+ And for scheduler authors:
21
+
22
+ 1. Scheduler setting needs to be applied to a ``Role``, which makes it hard to add as ``runopts``
23
+ since ``runopts`` apply at the ``AppDef`` level.
24
+ 2. Scheduler setting cannot be represented naturally as the types supported by ``runopts``.
25
+ 3. Exposing the setting as a ``runopts`` obfuscates things.
26
+
27
+ See :py:func:`~torchx.specs.overlays.apply_overlay` for rules on how overlays are applied.
28
+ """
29
+
30
from typing import Any

Json = dict[str, Any]


def apply_overlay(base: Json, overlay: Json) -> None:
    """Merges ``overlay`` into ``base``.

    .. note:: this function mutates ``base`` in place!

    Merge rules:

    1. Keys present only in ``overlay`` are inserted into ``base``.
    2. When both values are dicts, they are merged recursively.
    3. When both values are lists, the overlay items are appended to the
       base list. Nested lists are NOT merged element-wise.
    4. Any other value (bool, str, int, float, ...) in ``base`` is replaced
       by the overlay's value.

    An ``AssertionError`` is raised when a key exists in both dicts but the
    two values have different types.
    """

    def _check_same_type(key: str, lhs: object, rhs: object) -> None:
        # Guard against silently merging incompatible shapes.
        lhs_type = type(lhs)
        rhs_type = type(rhs)
        assert (
            lhs_type == rhs_type
        ), f"Type mismatch for attr: `{key}`. {lhs_type.__qualname__} != {rhs_type.__qualname__}"

    for key, new_value in overlay.items():
        if key not in base:
            base[key] = new_value
            continue

        existing = base[key]
        _check_same_type(key, existing, new_value)

        # Both isinstance checks are kept (rather than relying on the assert
        # above) so behavior is unchanged even under `python -O`, where
        # asserts are stripped.
        if isinstance(existing, dict) and isinstance(new_value, dict):
            apply_overlay(existing, new_value)
        elif isinstance(existing, list) and isinstance(new_value, list):
            existing.extend(new_value)
        else:
            base[key] = new_value
@@ -3,3 +3,5 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
@@ -3,6 +3,8 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
6
8
  import torchx
7
9
  from torchx import specs
8
10
 
@@ -3,3 +3,5 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
@@ -3,6 +3,8 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
6
8
  import torchx
7
9
  from torchx import specs
8
10
 
@@ -4,3 +4,5 @@
4
4
  #
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
@@ -3,6 +3,8 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
6
8
  import torchx
7
9
  from torchx import specs
8
10
 
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  """
8
10
  .. note:: PROTOTYPE, USE AT YOUR OWN RISK, APIs SUBJECT TO CHANGE
9
11
 
@@ -30,14 +32,14 @@ implementation.
30
32
 
31
33
  Example usage
32
34
  -------------
33
- Sample `code <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.
35
+ Sample `code <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/main.py>`__ using tracker API.
34
36
 
35
37
 
36
38
  Tracker Setup
37
39
  -------------
38
40
  To enable tracking it requires:
39
41
 
40
- 1. Defining tracker backends (entrypoints and configuration) on launcher side using :doc:`runner.config`
42
+ 1. Defining tracker backends (entrypoints/modules and configuration) on launcher side using :doc:`runner.config`
41
43
  2. Adding entrypoints within a user job using entry_points (`specification`_)
42
44
 
43
45
  .. _specification: https://packaging.python.org/en/latest/specifications/entry-points/
@@ -49,13 +51,13 @@ To enable tracking it requires:
49
51
  User can define any number of tracker backends under **torchx:tracker** section in :doc:`runner.config`, where:
50
52
  * Key: is an arbitrary name for the tracker, where the name will be used to configure its properties
51
53
  under [tracker:<TRACKER_NAME>]
52
- * Value: is *entrypoint/factory method* that must be available within user job. The value will be injected into a
54
+ * Value: is *entrypoint* or *module* factory method that must be available within user job. The value will be injected into a
53
55
  user job and used to construct tracker implementation.
54
56
 
55
57
  .. code-block:: ini
56
58
 
57
59
  [torchx:tracker]
58
- tracker_name=<entry_point>
60
+ tracker_name=<entry_point_or_module_factory_method>
59
61
 
60
62
 
61
63
  Each tracker can be additionally configured (currently limited to `config` parameter) under `[tracker:<TRACKER NAME>]` section:
@@ -71,11 +73,15 @@ For example, ~/.torchxconfig may be setup as:
71
73
 
72
74
  [torchx:tracker]
73
75
  tracker1=tracker1
74
- tracker12=backend_2_entry_point
76
+ tracker2=backend_2_entry_point
77
+ tracker3=torchx.tracker.mlflow:create_tracker
75
78
 
76
79
  [tracker:tracker1]
77
80
  config=s3://my_bucket/config.json
78
81
 
82
+ [tracker:tracker3]
83
+ config=my_config.json
84
+
79
85
 
80
86
  2. User job configuration (Advanced)
81
87
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -105,7 +111,7 @@ Use :py:meth:`~torchx.tracker.app_run_from_env`:
105
111
  Reference :py:class:`~torchx.tracker.api.TrackerBase` implementation
106
112
  --------------------------------------------------------------------
107
113
  :py:class:`~torchx.tracker.backend.fsspec.FsspecTracker` provides reference implementation of a tracker backend.
108
- GitHub example `directory <https://github.com/pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
114
+ GitHub example `directory <https://github.com/meta-pytorch/torchx/blob/main/torchx/examples/apps/tracker/>`__ provides example on how to
109
115
  configure and use it in user application.
110
116
 
111
117
 
torchx/tracker/api.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  from __future__ import annotations
8
10
 
9
11
  import logging
@@ -14,6 +16,7 @@ from functools import lru_cache
14
16
  from typing import Iterable, Mapping, Optional
15
17
 
16
18
  from torchx.util.entrypoints import load_group
19
+ from torchx.util.modules import load_module
17
20
 
18
21
  logger: logging.Logger = logging.getLogger(__name__)
19
22
 
@@ -66,8 +69,7 @@ class AppRunTrackableSource:
66
69
  artifact_name: Optional[str]
67
70
 
68
71
 
69
- class Lineage:
70
- ...
72
+ class Lineage: ...
71
73
 
72
74
 
73
75
  class TrackerBase(ABC):
@@ -177,30 +179,26 @@ def _extract_tracker_name_and_config_from_environ() -> Mapping[str, Optional[str
177
179
 
178
180
 
179
181
  def build_trackers(
180
- entrypoint_and_config: Mapping[str, Optional[str]]
182
+ factory_and_config: Mapping[str, Optional[str]]
181
183
  ) -> Iterable[TrackerBase]:
182
184
  trackers = []
183
185
 
184
- entrypoint_factories = load_group("torchx.tracker")
186
+ entrypoint_factories = load_group("torchx.tracker") or {}
185
187
  if not entrypoint_factories:
186
- logger.warning(
187
- "No 'torchx.tracker' entry_points are defined. Tracking will not capture any data."
188
- )
189
- return trackers
188
+ logger.warning("No 'torchx.tracker' entry_points are defined.")
190
189
 
191
- for entrypoint_key, config in entrypoint_and_config.items():
192
- if entrypoint_key not in entrypoint_factories:
190
+ for factory_name, config in factory_and_config.items():
191
+ factory = entrypoint_factories.get(factory_name) or load_module(factory_name)
192
+ if not factory:
193
193
  logger.warning(
194
- f"Could not find `{entrypoint_key}` tracker entrypoint. Skipping..."
194
+ f"No tracker factory `{factory_name}` found in entry_points or modules. See https://meta-pytorch.org/torchx/main/tracker.html#module-torchx.tracker"
195
195
  )
196
196
  continue
197
- factory = entrypoint_factories[entrypoint_key]
198
197
  if config:
199
- logger.info(f"Tracker config found for `{entrypoint_key}` as `{config}`")
200
- tracker = factory(config)
198
+ logger.info(f"Tracker config found for `{factory_name}` as `{config}`")
201
199
  else:
202
- logger.info(f"No tracker config specified for `{entrypoint_key}`")
203
- tracker = factory(None)
200
+ logger.info(f"No tracker config specified for `{factory_name}`")
201
+ tracker = factory(config)
204
202
  trackers.append(tracker)
205
203
  return trackers
206
204
 
@@ -335,5 +333,4 @@ class AppRun:
335
333
 
336
334
  return model_run_sources
337
335
 
338
- def children(self) -> Iterable[AppRun]:
339
- ...
336
+ def children(self) -> Iterable[AppRun]: ...
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  from __future__ import annotations
8
10
 
9
11
  import json
torchx/util/cuda.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import torch
8
10
 
9
11
 
torchx/util/datetime.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  from datetime import datetime, timedelta
8
10
 
9
11
 
@@ -4,13 +4,14 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
- from typing import Any, Dict, Optional
7
+ # pyre-strict
8
+ # pyre-ignore-all-errors[3, 2, 16]
8
9
 
9
- import importlib_metadata as metadata
10
- from importlib_metadata import EntryPoint
10
+ from importlib import metadata
11
+ from importlib.metadata import EntryPoint
12
+ from typing import Any, Dict, Optional
11
13
 
12
14
 
13
- # pyre-ignore-all-errors[3, 2]
14
15
  def load(group: str, name: str, default=None):
15
16
  """
16
17
  Loads the entry point specified by
@@ -28,13 +29,34 @@ def load(group: str, name: str, default=None):
28
29
  raises an error.
29
30
  """
30
31
 
31
- entrypoints = metadata.entry_points().select(group=group)
32
+ # [note_on_entrypoints]
33
+ # return type of importlib.metadata.entry_points() is different between python-3.9 and python-3.10
34
+ # https://docs.python.org/3.9/library/importlib.metadata.html#importlib.metadata.entry_points
35
+ # https://docs.python.org/3.10/library/importlib.metadata.html#importlib.metadata.entry_points
36
+ if hasattr(metadata.entry_points(), "select"):
37
+ # python>=3.10
38
+ entrypoints = metadata.entry_points().select(group=group)
32
39
 
33
- if name not in entrypoints.names and default is not None:
34
- return default
40
+ if name not in entrypoints.names and default is not None:
41
+ return default
42
+
43
+ ep = entrypoints[name]
44
+ return ep.load()
35
45
 
36
- ep = entrypoints[name]
37
- return ep.load()
46
+ else:
47
+ # python<3.10 (e.g. 3.9)
48
+ # metadata.entry_points() returns dict[str, tuple[EntryPoint]] (not EntryPoints) in python-3.9
49
+ entrypoints = metadata.entry_points().get(group, ())
50
+
51
+ for ep in entrypoints:
52
+ if ep.name == name:
53
+ return ep.load()
54
+
55
+ # [group].name not found
56
+ if default is not None:
57
+ return default
58
+ else:
59
+ raise KeyError(f"entrypoint {group}.{name} not found")
38
60
 
39
61
 
40
62
  def _defer_load_ep(ep: EntryPoint) -> object:
@@ -47,11 +69,7 @@ def _defer_load_ep(ep: EntryPoint) -> object:
47
69
  return run
48
70
 
49
71
 
50
- # pyre-ignore-all-errors[3, 2]
51
- def load_group(
52
- group: str,
53
- default: Optional[Dict[str, Any]] = None,
54
- ):
72
+ def load_group(group: str, default: Optional[Dict[str, Any]] = None):
55
73
  """
56
74
  Loads all the entry points specified by ``group`` and returns
57
75
  the entry points as a map of ``name (str) -> deferred_load_fn``.
@@ -85,7 +103,13 @@ def load_group(
85
103
 
86
104
  """
87
105
 
88
- entrypoints = metadata.entry_points().select(group=group)
106
+ # see [note_on_entrypoints] above
107
+ if hasattr(metadata.entry_points(), "select"):
108
+ # python>=3.10
109
+ entrypoints = metadata.entry_points().select(group=group)
110
+ else:
111
+ # python<3.10 (e.g. 3.9)
112
+ entrypoints = metadata.entry_points().get(group, ())
89
113
 
90
114
  if len(entrypoints) == 0:
91
115
  return default
torchx/util/io.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  from os import path
8
10
  from pathlib import Path
9
11
  from typing import Optional
@@ -0,0 +1,210 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ """
10
+ If you're wrapping the TorchX API with your own CLI, these functions can
11
+ help show the logs of the job within your CLI, just like
12
+ `torchx log`
13
+ """
14
+
15
+ import logging
16
+ import threading
17
+ from queue import Queue
18
+ from typing import List, Optional, TextIO, Tuple, TYPE_CHECKING
19
+
20
+ from torchx.util.types import none_throws
21
+
22
+ if TYPE_CHECKING:
23
+ from torchx.runner.api import Runner
24
+ from torchx.schedulers.api import Stream
25
+ from torchx.specs.api import AppDef
26
+
27
+ logger: logging.Logger = logging.getLogger(__name__)
28
+
29
+ # A torchX job can have stderr/stdout for many replicas, of many roles
30
+ # The scheduler API has functions that allow us to get,
31
+ # with unspecified detail, the log lines of a given replica of
32
+ # a given role.
33
+ #
34
+ # So, to neatly tee the results, we:
35
+ # 1) Determine every role ID / replica ID pair we want to monitor
36
+ # 2) Request the given stderr / stdout / combined streams from them (1 thread each)
37
+ # 3) Concatenate each of those streams to a given destination file
38
+
39
+
40
+ def _find_role_replicas(
41
+ app: "AppDef",
42
+ role_name: Optional[str],
43
+ ) -> List[Tuple[str, int]]:
44
+ """
45
+ Enumerate all (role, replica id) pairs in the given AppDef.
46
+ Replica IDs are 0-indexed, and range up to num_replicas,
47
+ for each role.
48
+ If role_name is provided, filters to only that name.
49
+ """
50
+ role_replicas = []
51
+ for role in app.roles:
52
+ if role_name is None or role_name == role.name:
53
+ for i in range(role.num_replicas):
54
+ role_replicas.append((role.name, i))
55
+ return role_replicas
56
+
57
+
58
+ def _prefix_line(prefix: str, line: str) -> str:
59
+ """
60
+ _prefix_line ensure the prefix is still present even when dealing with return characters
61
+ """
62
+ if "\r" in line:
63
+ line = line.replace("\r", f"\r{prefix}")
64
+ if "\n" in line[:-1]:
65
+ line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
66
+ if not line.startswith("\r"):
67
+ line = f"{prefix}{line}"
68
+ return line
69
+
70
+
71
+ def _print_log_lines_for_role_replica(
72
+ dst: TextIO,
73
+ app_handle: str,
74
+ regex: Optional[str],
75
+ runner: "Runner",
76
+ which_role: str,
77
+ which_replica: int,
78
+ exceptions: "Queue[Exception]",
79
+ should_tail: bool,
80
+ streams: Optional["Stream"],
81
+ colorize: bool = False,
82
+ ) -> None:
83
+ """
84
+ Helper function that'll run in parallel - one
85
+ per monitored replica of a given role.
86
+
87
+ Based on print_log_lines .. but not designed for TTY
88
+ """
89
+ try:
90
+ for line in runner.log_lines(
91
+ app_handle,
92
+ which_role,
93
+ which_replica,
94
+ regex,
95
+ should_tail=should_tail,
96
+ streams=streams,
97
+ ):
98
+ if colorize:
99
+ color_begin = "\033[32m"
100
+ color_end = "\033[0m"
101
+ else:
102
+ color_begin = ""
103
+ color_end = ""
104
+ prefix = f"{color_begin}{which_role}/{which_replica}{color_end} "
105
+ print(_prefix_line(prefix, line.strip()), file=dst, end="\n", flush=True)
106
+ except Exception as e:
107
+ exceptions.put(e)
108
+ raise
109
+
110
+
111
+ def _start_threads_to_monitor_role_replicas(
112
+ dst: TextIO,
113
+ app_handle: str,
114
+ regex: Optional[str],
115
+ runner: "Runner",
116
+ which_role: Optional[str] = None,
117
+ should_tail: bool = False,
118
+ streams: Optional["Stream"] = None,
119
+ colorize: bool = False,
120
+ ) -> None:
121
+ threads = []
122
+
123
+ app = none_throws(runner.describe(app_handle))
124
+ replica_ids = _find_role_replicas(app, role_name=which_role)
125
+
126
+ # Holds exceptions raised by all threads, in a thread-safe
127
+ # object
128
+ exceptions = Queue()
129
+
130
+ if not replica_ids:
131
+ valid_roles = [role.name for role in app.roles]
132
+ raise ValueError(
133
+ f"{which_role} is not a valid role name. Available: {valid_roles}"
134
+ )
135
+
136
+ for role_name, replica_id in replica_ids:
137
+ threads.append(
138
+ threading.Thread(
139
+ target=_print_log_lines_for_role_replica,
140
+ kwargs={
141
+ "dst": dst,
142
+ "runner": runner,
143
+ "app_handle": app_handle,
144
+ "which_role": role_name,
145
+ "which_replica": replica_id,
146
+ "regex": regex,
147
+ "should_tail": should_tail,
148
+ "exceptions": exceptions,
149
+ "streams": streams,
150
+ "colorize": colorize,
151
+ },
152
+ daemon=True,
153
+ )
154
+ )
155
+
156
+ for t in threads:
157
+ t.start()
158
+
159
+ for t in threads:
160
+ t.join()
161
+
162
+ # Retrieve all exceptions, print all except one and raise the first recorded exception
163
+ threads_exceptions = []
164
+ while not exceptions.empty():
165
+ threads_exceptions.append(exceptions.get())
166
+
167
+ if len(threads_exceptions) > 0:
168
+ for i in range(1, len(threads_exceptions)):
169
+ logger.error(threads_exceptions[i])
170
+
171
+ raise threads_exceptions[0]
172
+
173
+
174
+ def tee_logs(
175
+ dst: TextIO,
176
+ app_handle: str,
177
+ regex: Optional[str],
178
+ runner: "Runner",
179
+ should_tail: bool = False,
180
+ streams: Optional["Stream"] = None,
181
+ colorize: bool = False,
182
+ ) -> threading.Thread:
183
+ """
184
+ Makes a thread, which in turn will start 1 thread per replica
185
+ per role, that tees that role-replica's logs to the given
186
+ destination buffer.
187
+
188
+ You'll need to start and join with this parent thread.
189
+
190
+ dst: TextIO to tee the logs into
191
+ app_handle: The return value of runner.run() or runner.schedule()
192
+ regex: Regex to filter the logs that are tee-d
193
+ runner: The Runner you used to schedule the job
194
+ should_tail: If true, continue until we run out of logs. Otherwise, just fetch
195
+ what's available
196
+ streams: Whether to fetch STDERR, STDOUT, or the temporally COMBINED (default) logs
197
+ """
198
+ thread = threading.Thread(
199
+ target=_start_threads_to_monitor_role_replicas,
200
+ kwargs={
201
+ "dst": dst,
202
+ "runner": runner,
203
+ "app_handle": app_handle,
204
+ "regex": None,
205
+ "should_tail": True,
206
+ "colorize": colorize,
207
+ },
208
+ daemon=True,
209
+ )
210
+ return thread