torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,160 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import importlib
9
+ import logging
10
+ import sys
11
+ import warnings
12
+ from logging import Logger
13
+
14
+ # pyre-ignore
15
+ from pickle import _getattribute, PickleError, whichmodule
16
+ from types import BuiltinFunctionType, FunctionType
17
+ from typing import (
18
+ Any,
19
+ Callable,
20
+ Dict,
21
+ NamedTuple,
22
+ Optional,
23
+ Protocol,
24
+ runtime_checkable,
25
+ )
26
+
27
+ import cloudpickle
28
+
29
+ logger: Logger = logging.getLogger(__name__)
30
+
31
+
32
@runtime_checkable
class ResolvableFunction(Protocol):
    """Structural type for objects that can be turned back into a callable.

    The protocol is ``runtime_checkable``, so ``isinstance`` checks only
    look for the presence of a ``resolve`` member.
    """

    def resolve(self) -> Callable:
        """Return the callable this object stands for."""
        ...
35
+
36
+
37
+ ConvertsToResolvable = Any
38
+
39
+
40
def _string_resolver(arg: Any) -> Optional[ResolvableFunction]:
    """Treat a dotted string such as ``"pkg.mod.func"`` as an import path.

    Returns ``None`` for non-strings and for strings without a ``"."``.
    """
    if not isinstance(arg, str) or "." not in arg:
        return None
    return ResolvableFunctionFromPath(arg)
43
+
44
+
45
def _torch_resolver(arg: Any) -> Optional[ResolvableFunction]:
    """Map a ``torch._ops.OpOverload`` to a path rooted at ``torch.ops``.

    ``torch`` is imported inside the function rather than at module import
    time. Any argument that is not an ``OpOverload`` yields ``None``.
    """
    import torch

    if not isinstance(arg, torch._ops.OpOverload):
        return None
    return ResolvableFunctionFromPath("torch.ops." + str(arg))
50
+
51
+
52
+ def function_to_import_path(arg: BuiltinFunctionType | FunctionType) -> Optional[str]:
53
+ # code replicated from pickler to check if we
54
+ # would successfully be able to pickle this function.
55
+ name = getattr(arg, "__qualname__", None)
56
+ if name is None:
57
+ name = arg.__name__
58
+ try:
59
+ # pyre-ignore
60
+ module_name = whichmodule(arg, name)
61
+ __import__(module_name, level=0)
62
+ module = sys.modules[module_name]
63
+ if module_name == "__main__":
64
+ return None # the workers will not have the same main
65
+
66
+ # pytest installs its own custom loaders that do not
67
+ # survive process creation
68
+ try:
69
+ if "pytest" in module.__loader__.__class__.__module__:
70
+ return None
71
+ except AttributeError:
72
+ pass
73
+
74
+ # pyre-ignore
75
+ obj2, parent = _getattribute(module, name)
76
+ # support annotations that cover up the global impl
77
+ if obj2 is arg or getattr(obj2, "_remote_impl", None) is arg:
78
+ return f"{module_name}.{name}"
79
+ except (PickleError, ImportError, KeyError, AttributeError):
80
+ pass
81
+ return None
82
+
83
+
84
def _function_resolver(arg: Any):
    """Resolve plain and builtin functions through their import path.

    Returns ``None`` when ``arg`` is not a function or when no usable import
    path can be determined for it.
    """
    if not isinstance(arg, (FunctionType, BuiltinFunctionType)):
        return None
    path = function_to_import_path(arg)
    if path:
        return ResolvableFunctionFromPath(path)
    return None
88
+
89
+
90
def _cloudpickle_resolver(arg: Any):
    """Serialize ``arg`` by value with cloudpickle and wrap the bytes."""
    # @lint-ignore PYTHONPICKLEISBAD
    payload = cloudpickle.dumps(arg)
    return ResolvableFromCloudpickle(payload)
93
+
94
+
95
# Resolvers are tried in list order by maybe_resolvable_function; the first
# one that returns a non-None value wins.
resolvers = [
    _torch_resolver,
    _string_resolver,
    _function_resolver,
    _cloudpickle_resolver,
]


# Memo of previously resolved targets, keyed by the target object itself.
_cached_resolvers = {}
104
+
105
+
106
def maybe_resolvable_function(arg: Any) -> Optional[ResolvableFunction]:
    """Convert ``arg`` into a ``ResolvableFunction`` if any resolver accepts it.

    Results are memoized in ``_cached_resolvers`` when ``arg`` is hashable.
    Unhashable targets are still resolved, just without caching.

    Returns ``None`` if no resolver produces a result.
    """
    # Only compare real strings: ``==`` on arbitrary objects is not
    # guaranteed to return a plain bool.
    if isinstance(arg, str) and arg == "__test_panic":
        return ResolvableFunctionFromPath("__test_panic")

    try:
        cached = _cached_resolvers.get(arg)
    except TypeError:
        # Unhashable target (for example a callable object that defines
        # __eq__ without __hash__): skip the cache instead of failing.
        cacheable = False
        cached = None
    else:
        cacheable = True
    if cached is not None:
        return cached

    for resolver in resolvers:
        resolved = resolver(arg)
        if resolved is not None:
            if cacheable:
                _cached_resolvers[arg] = resolved
            return resolved
    return None
118
+
119
+
120
def resolvable_function(arg: ConvertsToResolvable) -> ResolvableFunction:
    """Return ``arg`` as a ``ResolvableFunction``.

    Objects that already satisfy the protocol are returned unchanged;
    anything else goes through ``maybe_resolvable_function``.

    Raises:
        ValueError: if no resolver can handle ``arg``.
    """
    if isinstance(arg, ResolvableFunction):
        return arg
    resolved = maybe_resolvable_function(arg)
    if resolved is not None:
        return resolved
    raise ValueError(f"Unsupported target for a remote call: {arg!r}")
127
+
128
+
129
class ResolvableFunctionFromPath(NamedTuple):
    """A function identified by its dotted import path."""

    path: str

    def resolve(self):
        """Import and return the function named by ``path``.

        Paths whose first component is ``torch`` are resolved by importing
        ``torch`` and walking the remaining components with ``getattr``.
        Every other path is split at its last ``"."`` into a module to
        import and an attribute to fetch from it; if that attribute carries
        a ``_remote_impl`` attribute, the implementation is returned instead.

        Raises:
            TypeError: if a ``torch`` path resolves to a non-callable object.
        """
        first, *parts = self.path.split(".")
        if first == "torch":
            function = importlib.import_module("torch")
            for p in parts:
                function = getattr(function, p)
            # Explicit check rather than ``assert`` so that it still runs
            # when Python is started with -O.
            if not callable(function):
                raise TypeError(f"{self.path!r} does not resolve to a callable")
            return function

        modulename, funcname = self.path.rsplit(".", 1)
        module = importlib.import_module(modulename)
        function = getattr(module, funcname)
        # support annotations that cover up the global impl
        actual = getattr(function, "_remote_impl", None)
        return function if actual is None else actual

    def __str__(self):
        return self.path
150
+
151
+
152
class ResolvableFromCloudpickle(NamedTuple):
    """A callable carried by value as cloudpickle bytes."""

    # serialized payload, as produced by cloudpickle.dumps
    data: bytes

    def resolve(self):
        """Deserialize and return the stored object.

        NOTE(review): unpickling executes arbitrary code; ``data`` must come
        from a trusted peer.
        """
        # @lint-ignore PYTHONPICKLEISBAD
        return cloudpickle.loads(self.data)
158
+
159
+
160
+ Propagator = Any
@@ -0,0 +1,164 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import itertools
8
+ from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Sequence, Tuple
9
+
10
+ import torch
11
+ from torch import autograd
12
+ from torch.utils._pytree import tree_flatten, TreeSpec
13
+
14
+
15
class AliasOf(NamedTuple):
    """Marks a tensor as sharing storage with an earlier tensor."""

    # 0 - this group, -1 - the parent, -2 - parent's parent, etc.
    group: int
    # index of the aliased tensor inside that group
    offset: int
18
+
19
+
20
class Storage(NamedTuple):
    """Marks a tensor as owning a newly created storage."""

    # number of elements to allocate for the storage
    numel: int
22
+
23
+
24
# Hashable pattern for recreating tensors.
# Each tensor either creates its own Storage
# or is an AliasOf another tensor, either earlier in this list
# or in one of the parent lists.
# Parent lists are used to represent other collections of tensors:
# for instance, if this pattern is for the outputs of a function,
# the parents might contain the inputs to the function and the captured
# globals as two separate lists.
class TensorGroupPattern(NamedTuple):
    entries: Tuple["PatternEntry", ...]

    def empty(self, parents: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Build uninitialized tensors that follow this pattern.

        Args:
            parents: tensor lists that ``AliasOf`` entries with a non-zero
                ``group`` index into; ``group`` is used directly as the
                index, so ``-1`` selects the last list.

        Returns:
            One tensor per entry, in entry order.

        Raises:
            ValueError: if an entry's storage is neither ``AliasOf`` nor
                ``Storage``.
        """
        result = []
        for entry in self.entries:
            source = entry.storage
            if isinstance(source, AliasOf):
                # group 0 refers to tensors already built in this call
                if source.group == 0:
                    base = result[source.offset]
                else:
                    base = parents[source.group][source.offset]
            elif isinstance(source, Storage):
                base = torch.empty(
                    (source.numel,),
                    dtype=entry.dtype,
                    layout=entry.layout,
                    device=entry.device,
                )
            else:
                raise ValueError("unexpected storage")
            result.append(
                torch.as_strided(
                    base, entry.size, entry.stride, entry.storage_offset
                )
            )
        return result
53
+
54
+
55
class PatternEntry(NamedTuple):
    """Everything needed to rebuild one tensor of a group."""

    # view geometry, in the form accepted by torch.as_strided
    size: Tuple[int, ...]
    stride: Tuple[int, ...]
    storage_offset: int
    # properties used when a new storage is allocated
    dtype: torch.dtype
    layout: torch.layout
    device: torch.device
    # where the underlying memory comes from
    storage: AliasOf | Storage
63
+
64
+
65
# Takes a list of tensors and computes the pattern of aliasing that
# would reconstruct the group. If `parent` is specified, aliases
# are also computed with respect to that group and its parents.
# A new storage is only specified if a tensor's storage was not
# seen in any parent or previously in this group.
class TensorGroup:
    def __init__(
        self,
        tensors: Sequence[torch.Tensor],
        parent: Optional["TensorGroup"] = None,
    ):
        self.parent = parent
        self.tensors = tensors
        # maps a storage to the index of the first tensor of this group
        # that used it
        self.storage_dict: Dict[torch.UntypedStorage, int] = {}

        def describe(index: int, tensor: torch.Tensor):
            storage = tensor.untyped_storage()
            origin = self._find_alias(storage)
            if origin is None:
                # first sighting: this tensor introduces the storage
                self.storage_dict[storage] = index
                origin = Storage(storage.size() // tensor.element_size())
            return PatternEntry(
                size=tuple(tensor.size()),
                stride=tuple(tensor.stride()),
                storage_offset=int(tensor.storage_offset()),
                dtype=tensor.dtype,
                layout=tensor.layout,
                device=tensor.device,
                storage=origin,
            )

        # entries are built in order because describe() records storages
        # in storage_dict as it goes
        self.pattern = TensorGroupPattern(
            tuple(describe(index, tensor) for index, tensor in enumerate(tensors))
        )

    def _find_alias(self, storage: torch.UntypedStorage) -> Optional[AliasOf]:
        """Search this group, then each ancestor, for ``storage``.

        Returns an ``AliasOf`` whose ``group`` is the negated number of
        parent hops taken, or ``None`` if no group in the chain has
        recorded the storage.
        """
        depth = 0
        group = self
        while group is not None:
            index = group.storage_dict.get(storage)
            if index is not None:
                return AliasOf(-depth, index)
            group = group.parent
            depth += 1
        return None
110
+
111
+
112
class TensorPlaceholder:
    """Marker type with no state."""


# singleton to represent where tensors go in a pytree
tensor_placeholder = TensorPlaceholder()
118
+
119
+
120
def _to_placeholder(x):
    """Replace a tensor with ``tensor_placeholder``; pass anything else through."""
    return tensor_placeholder if isinstance(x, torch.Tensor) else x
124
+
125
+
126
def _remove_ctx(x):
    """Replace an autograd ``FunctionCtx`` with ``None``; pass anything else through."""
    return None if isinstance(x, autograd.function.FunctionCtx) else x
130
+
131
+
132
# Customizable list of filters for data types that show up in the arguments
# of functions one wants to support as cached functions. Each filter takes a
# value and returns its replacement.
key_filters = [_to_placeholder, _remove_ctx]
135
+
136
+
137
def _filter_key(v: Any):
    """Run ``v`` through every entry of ``key_filters`` in list order.

    Each filter receives the output of the previous one.
    """
    result = v
    for apply in key_filters:
        result = apply(result)
    return result
141
+
142
+
143
class HashableTreeSpec(NamedTuple):
    """Tuple-based mirror of a pytree ``TreeSpec``."""

    type: Any
    context: Any
    children_specs: Tuple["HashableTreeSpec", ...]

    @staticmethod
    def from_treespec(t: "TreeSpec"):
        """Recursively convert ``t`` and its children.

        A ``list`` context is turned into a tuple; any other context is
        kept as is.
        """
        context = t.context
        if isinstance(context, list):
            context = tuple(context)
        children = tuple(map(HashableTreeSpec.from_treespec, t.children_specs))
        return HashableTreeSpec(t.type, context, children)
155
+
156
+
157
def hashable_tensor_flatten(args, kwargs) -> Tuple[List[torch.Tensor], Hashable]:
    """Flatten ``(args, kwargs)`` into its tensors plus a lookup key.

    Returns:
        A pair ``(tensors, key)``. ``tensors`` holds the tensor leaves in
        flattening order. ``key`` pairs the leaves, each passed through
        ``_filter_key``, with a ``HashableTreeSpec`` of the tree structure.
    """
    leaves, treespec = tree_flatten((args, kwargs))
    tensors = list(filter(lambda leaf: isinstance(leaf, torch.Tensor), leaves))
    filtered_leaves = tuple(map(_filter_key, leaves))
    key: Hashable = (filtered_leaves, HashableTreeSpec.from_treespec(treespec))
    return tensors, key
@@ -0,0 +1,168 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import math
10
+ import os
11
+ import subprocess
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ cast,
16
+ Generic,
17
+ Optional,
18
+ Sequence,
19
+ TYPE_CHECKING,
20
+ TypeVar,
21
+ )
22
+
23
+ from monarch_supervisor import TTL
24
+
25
+ if TYPE_CHECKING:
26
+ from monarch.common.client import Client
27
+
28
+ from .invocation import RemoteException
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
# Interval in seconds between py-spy dumps while waiting on a future, read
# from CONTROLLER_PYSPY_REPORT_INTERVAL; None when the variable is unset.
PYSPY_REPORT_INTERVAL: Optional[float] = (
    float(os.environ["CONTROLLER_PYSPY_REPORT_INTERVAL"])
    if "CONTROLLER_PYSPY_REPORT_INTERVAL" in os.environ
    else None
)
38
+
39
+
40
def _split(elems, cond):
    """Partition ``elems`` by ``cond`` in a single pass.

    Returns ``(trues, falses)``: two lists that keep the original order.
    ``cond`` is called exactly once per element.
    """
    trues, falses = [], []
    for item in elems:
        bucket = trues if cond(item) else falses
        bucket.append(item)
    return trues, falses
49
+
50
+
51
def _periodic_TTL(interval: Optional[float]) -> Callable[[], float]:
    """Return a callable reporting the time left in a repeating interval.

    With ``interval`` set to ``None`` the callable always returns
    ``math.inf``. Otherwise it wraps a ``TTL(interval)``; whenever that
    reports ``0`` the value ``0`` is returned and a new ``TTL`` is started
    for subsequent calls.
    """
    if interval is None:
        return lambda: math.inf

    # one-element list so the closure can swap in a new TTL
    current = [TTL(interval)]

    def _remaining():
        left = current[0]()
        if left == 0:
            current[0] = TTL(interval)
        return left

    return _remaining
65
+
66
+
67
T = TypeVar("T")


class Future(Generic[T]):
    """Result of an operation that is completed by messages from a client.

    The future starts ``"incomplete"`` and becomes ``"complete"`` or
    ``"exception"`` once ``_set_result`` is called. Waiting on the future
    drives ``client.handle_next_message`` until the result arrives or the
    timeout expires.
    """

    def __init__(self, client: "Client"):
        self._client = client
        self._status = "incomplete"
        self._callbacks = None
        self._result: T | Exception | None = None

    def _set_result(self, r):
        """Store the result, run the queued callbacks and drop the client.

        A ``RemoteException`` result puts the future in the ``"exception"``
        state; any other value puts it in the ``"complete"`` state.
        Exceptions raised by callbacks are logged and not propagated.
        """
        assert self._status == "incomplete"
        self._result = r
        self._status = "exception" if isinstance(r, RemoteException) else "complete"
        if self._callbacks:
            for cb in self._callbacks:
                try:
                    cb(self)
                except Exception:
                    logger.exception("exception in controller's Future callback")
        self._callbacks = None
        self._client = None

    def _wait(self, timeout: Optional[float]):
        """Process client messages until done or ``timeout`` expires.

        Returns ``True`` if the future is no longer incomplete.
        """
        if self._status != "incomplete":
            return True

        assert self._client is not None

        # see if the future is done already
        # and we just haven't processed the messages
        while self._client.handle_next_message(0):
            if self._status != "incomplete":
                return True

        ttl = TTL(timeout)
        ttl_pyspy = _periodic_TTL(PYSPY_REPORT_INTERVAL)
        while self._status == "incomplete" and _wait(self._client, ttl, ttl_pyspy):
            ...

        return self._status != "incomplete"

    def result(self, timeout: Optional[float] = None) -> T:
        """Return the result, raising the stored exception if there is one.

        Raises:
            TimeoutError: if the future is still incomplete after ``timeout``.
        """
        if not self._wait(timeout):
            raise TimeoutError()
        if self._status == "exception":
            raise cast(Exception, self._result)
        return cast(T, self._result)

    def done(self) -> bool:
        """Return whether the future has finished, waiting with a zero timeout."""
        return self._wait(0)

    def exception(self, timeout: Optional[float] = None):
        """Return the stored exception, or ``None`` if the future succeeded.

        Raises:
            TimeoutError: if the future is still incomplete after ``timeout``.
        """
        if not self._wait(timeout):
            raise TimeoutError()
        return self._result if self._status == "exception" else None

    def add_callback(self, callback):
        """Call ``callback(self)`` once the future has a result.

        If the future is already finished the callback runs immediately,
        with the same exception logging used for queued callbacks.
        """
        if self._status != "incomplete":
            # _set_result has already run and will not run again, so a
            # queued callback would never be invoked.
            try:
                callback(self)
            except Exception:
                logger.exception("exception in controller's Future callback")
            return
        if not self._callbacks:
            self._callbacks = [callback]
        else:
            self._callbacks.append(callback)
129
+
130
+
131
def _wait(client: "Client", ttl: Callable[[], float], ttl_pyspy: Callable[[], float]):
    """Handle one client message, bounded by the time left on both timers.

    When ``ttl_pyspy`` reports ``0``, a warning is logged and ``py-spy dump``
    is run against the current process first. The message is then handled
    with a timeout equal to the smaller remaining time, or ``None`` if that
    is infinite.

    Returns ``True`` while ``ttl`` reported time remaining before the call.
    """
    remaining = ttl()
    pyspy_remaining = ttl_pyspy()
    if pyspy_remaining == 0:
        try:
            logging.warning(
                f"future has not finished in {PYSPY_REPORT_INTERVAL} seconds (remaining time to live is {remaining}), py-spying process to debug."
            )
            subprocess.run(["py-spy", "dump", "-s", "-p", str(os.getpid())])
        except FileNotFoundError:
            logging.warning("py-spy is not installed.")

    budget = min(remaining, pyspy_remaining)
    if budget == math.inf:
        budget = None
    client.handle_next_message(timeout=budget)
    return remaining > 0
145
+
146
+
147
def stream(futures: Sequence[Future], timeout: Optional[float] = None):
    """Yield the given futures as they finish.

    ``timeout``, when given, covers the whole set of futures rather than
    each one individually. Once it expires the generator simply stops;
    futures that are still incomplete are not yielded.
    """
    assert len(futures) > 0

    ttl = TTL(timeout)
    pyspy_ttl = _periodic_TTL(PYSPY_REPORT_INTERVAL)

    clients = {f._client for f in futures if f._client is not None}
    assert len(clients) <= 1, "all futures must be from the same controller"

    pending = futures
    while True:
        finished, pending = _split(pending, lambda f: f._status != "incomplete")
        for f in finished:
            yield f

        if len(pending) == 0:
            break
        # any pending future can drive the client: they all share it
        if not _wait(pending[0]._client, ttl, pyspy_ttl):
            break
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import traceback
9
+ from typing import Any, List, Optional, Tuple
10
+
11
+ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
12
+ ActorId,
13
+ )
14
+
15
+
16
+ Seq = int
17
+
18
+
19
class DeviceException(Exception):
    """
    Non-deterministic failure in the underlying worker, controller or its infrastructure.
    For example, a worker may enter a crash loop, or its GPU may be lost
    """

    def __init__(
        self,
        exception: Exception,
        frames: List[traceback.FrameSummary],
        source_actor_id: ActorId,
        message: str,
    ):
        self.exception = exception
        self.frames = frames
        self.source_actor_id = source_actor_id
        self.message = message

    def __str__(self):
        """Format the message, the worker traceback and the original error.

        If formatting fails, the error is printed and ``"oops"`` is returned.
        """
        try:
            exe = str(self.exception)
            worker_tb = "".join(traceback.format_list(self.frames))
            pieces = [
                f"{self.message}\n",
                "Traceback of the failure on worker (most recent call last):\n",
                worker_tb,
                f"{type(self.exception).__name__}: {exe}",
            ]
            return "".join(pieces)
        except Exception as e:
            print(e)
            return "oops"
48
+
49
+
50
class RemoteException(Exception):
    """
    Deterministic problem with the user's code.
    For example, an OOM resulting in trying to allocate too much GPU memory, or violating
    some invariant enforced by the various APIs.
    """

    def __init__(
        self,
        seq: Seq,
        exception: Exception,
        controller_frame_index: Optional[int],
        controller_frames: Optional[List[traceback.FrameSummary]],
        worker_frames: List[traceback.FrameSummary],
        source_actor_id: ActorId,
        message="A remote function has failed asynchronously.",
    ):
        self.seq = seq
        self.exception = exception
        self.controller_frame_index = controller_frame_index
        self.controller_frames = controller_frames
        self.worker_frames = worker_frames
        self.source_actor_id = source_actor_id
        self.message = message

    def __str__(self):
        """Format the message with the controller and worker tracebacks.

        A placeholder line stands in for the controller traceback when no
        controller frames were recorded. If formatting fails, the error is
        printed and ``"oops"`` is returned.
        """
        try:
            exe = str(self.exception)
            worker_tb = "".join(traceback.format_list(self.worker_frames))
            if self.controller_frames is not None:
                controller_tb = "".join(traceback.format_list(self.controller_frames))
            else:
                controller_tb = "  <not related to a specific invocation>\n"
            pieces = [
                f"{self.message}\n",
                "Traceback of where the remote function was issued on controller (most recent call last):\n",
                controller_tb,
                "Traceback of where the remote function failed on worker (most recent call last):\n",
                worker_tb,
                f"{type(self.exception).__name__}: {exe}",
            ]
            return "".join(pieces)
        except Exception as e:
            print(e)
            return "oops"
92
+
93
+
94
class Invocation:
    """Bookkeeping for one invocation, identified by its sequence number."""

    def __init__(self, seq: Seq):
        self.seq = seq
        # invocations registered through add_user; None once completed
        self.users: Optional[set["Invocation"]] = set()
        self.failure: Optional[RemoteException] = None
        self.fut_value: Any = None

    def __repr__(self):
        return f"<Invocation {self.seq}>"

    def fail(self, remote_exception: RemoteException):
        """Record ``remote_exception`` unless an earlier-or-equal one is stored.

        Returns ``True`` if the stored failure was replaced.
        """
        current = self.failure
        if current is not None and current.seq <= remote_exception.seq:
            return False
        self.failure = remote_exception
        return True

    def add_user(self, r: "Invocation"):
        """Register ``r`` as a user and hand it any failure already recorded."""
        users = self.users
        if users is not None:
            users.add(r)
        failure = self.failure
        if failure is not None:
            r.fail(failure)

    def complete(self) -> Tuple[Any, Optional[RemoteException]]:
        """
        Complete the current invocation.
        Return the result and exception tuple.
        """
        # after completion we no longer need to inform users of failures
        # since they will just immediately get the value during add_user
        self.users = None

        value = None if self.failure is not None else self.fut_value
        return (value, self.failure)