torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/tensor_factory.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from typing import NamedTuple, Tuple
+
+ import torch
+
+
+ class TensorFactory(NamedTuple):
+     size: Tuple[int, ...]
+     dtype: torch.dtype
+     layout: torch.layout
+     device: torch.device
+
+     @staticmethod
+     def from_tensor(t):
+         return TensorFactory(t.size(), t.dtype, t.layout, t.device)
+
+     def empty(self):
+         return torch.empty(
+             self.size, dtype=self.dtype, layout=self.layout, device=self.device
+         )
+
+     def zeros(self):
+         return torch.full(
+             self.size, 0, dtype=self.dtype, layout=self.layout, device=self.device
+         )
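
A minimal usage sketch of TensorFactory (assuming the wheel is installed and the module path matches the listing above): the factory captures only a tensor's metadata, so it can later rebuild shape- and dtype-compatible tensors without holding on to the original storage.

import torch
from monarch.common.tensor_factory import TensorFactory

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
factory = TensorFactory.from_tensor(t)

# Only metadata (size, dtype, layout, device) is stored, not the data.
empty = factory.empty()   # uninitialized 2x3 float32 tensor
zeros = factory.zeros()   # 2x3 float32 tensor of zeros
assert zeros.shape == t.shape and zeros.dtype == t.dtype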
monarch/common/tree.py ADDED
@@ -0,0 +1,73 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from typing import Any, Callable, Protocol, Sequence, Tuple
+
+ import torch.utils._pytree as _pytree
+ from torch.utils._pytree import (
+     _get_node_type,
+     register_pytree_node,
+     SUPPORTED_NODES,
+     tree_flatten,
+     tree_map,
+     tree_unflatten,
+ )
+
+
+ def flatten(tree, cond):
+     r, spec = tree_flatten(tree)
+
+     # Be careful not to capture the values we return in 'trues':
+     # we do not need them to reconstruct the tree and do not want to
+     # extend their lifetime.
+     trues = []
+     falses = []
+     conds = []
+     for e in r:
+         c = cond(e)
+         (trues if c else falses).append(e)
+         conds.append(c)
+
+     def unflatten(n):
+         n_it = iter(n)
+         falses_it = iter(falses)
+         return tree_unflatten([next(n_it if c else falses_it) for c in conds], spec)
+
+     return trues, unflatten
+
+
+ def flattener(tree, cond=None):
+     """
+     Produce a _traceable_ flattener routine from tree. That is, it produces code that can
+     flatten another object shaped the same as tree, but whose structure cannot
+     be introspected because it might be (e.g.) an fx proxy value.
+     """
+     if isinstance(tree, (tuple, list)):
+         flattens = [flattener(t, cond) for t in tree]
+         return lambda obj: [
+             f for i, flatten in enumerate(flattens) for f in flatten(obj[i])
+         ]
+     elif isinstance(tree, dict):
+         keys = tuple(tree.keys())
+         flattens = [flattener(t, cond) for t in tree.values()]
+         return lambda obj: [
+             f for k, flatten in zip(keys, flattens) for f in flatten(obj[k])
+         ]
+     elif _get_node_type(tree) in SUPPORTED_NODES:
+         flatten_fn = SUPPORTED_NODES[_get_node_type(tree)].flatten_fn
+         trees, _ = flatten_fn(tree)
+         flattens = [flattener(t, cond) for t in trees]
+
+         def the_flattener(obj):
+             trees, _ = flatten_fn(obj)
+             return [f for i, flatten in enumerate(flattens) for f in flatten(trees[i])]
+
+         return the_flattener
+     elif cond is None or cond(tree):
+         return lambda obj: [obj]
+     else:
+         return lambda obj: []
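
A short sketch of how flatten and flattener behave (assuming the wheel is installed; the example values are illustrative only): flatten partitions leaves by a predicate and returns a closure that reassembles the tree, while flattener pre-compiles the tree's structure so that later objects of the same shape can be flattened without introspection.

from monarch.common.tree import flatten, flattener

tree = {"a": [1, "x"], "b": (2.0, 3)}

# Select only the int leaves; the rest are kept aside for reconstruction.
ints, unflatten = flatten(tree, lambda leaf: isinstance(leaf, int))
assert ints == [1, 3]

# Substitute new values for the selected leaves and rebuild the tree.
rebuilt = unflatten([10, 30])
assert rebuilt == {"a": [10, "x"], "b": (2.0, 30)}

# flattener compiles the structure once; the returned function can then
# flatten any object shaped like `tree`, e.g. one holding fx proxies.
flat = flattener(tree, lambda leaf: isinstance(leaf, int))
assert flat(rebuilt) == [10, 30]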
monarch/controller/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
monarch/controller/backend.py ADDED
@@ -0,0 +1,223 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import logging
+
+ import os
+ import socket
+
+ from abc import ABC, abstractmethod
+ from typing import List, NamedTuple, Optional, Sequence, Tuple
+
+ from monarch.common import messages
+
+ from monarch.common.shape import iter_ranks, Slices as Ranks
+ from monarch_supervisor import (
+     Context,
+     FunctionCall,
+     Host,
+     Process,
+     ProcessExited as ProcessExitedMsg,
+ )
+ from torch.distributed import TCPStore
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class Backend(ABC):
+     @abstractmethod
+     def send(self, ranks: Ranks, msg) -> None:
+         raise NotImplementedError()
+
+     @abstractmethod
+     def recvready(self, timeout: Optional[float]) -> Sequence[Tuple[int, NamedTuple]]:
+         raise NotImplementedError()
+
+     @property
+     @abstractmethod
+     def world_size(self):
+         raise NotImplementedError()
+
+     @property
+     @abstractmethod
+     def gpu_per_host(self):
+         raise NotImplementedError()
+
+
+ class ProcessBackend(Backend):
+     def __init__(
+         self,
+         ctx: Context,
+         hosts: List[Host],
+         gpu_per_host: int,
+         _processes=None,
+         _store=None,
+     ):
+         self.ctx = ctx
+         self.hosts = hosts
+         self.store = self._create_store() if _store is None else _store
+         self._gpu_per_host = gpu_per_host
+         self.worker_processes = (
+             self._create_pg(ctx, hosts, gpu_per_host, self.store)
+             if _processes is None
+             else _processes
+         )
+         self.exiting = False
+         self.process_to_rank = {p: p.rank for p in self.worker_processes}
+         self.live_processes_per_rank: List[List[Process]] = [
+             [p] for p in self.worker_processes
+         ]
+
+     @property
+     def world_size(self):
+         return len(self.worker_processes)
+
+     @property
+     def gpu_per_host(self) -> int:
+         return self._gpu_per_host
+
+     def send(self, ranks: Ranks, msg) -> None:
+         handler = getattr(self, msg.__class__.__name__, None)
+         if handler is not None:
+             handler(ranks, msg)
+         self._send(ranks, msg)
+
+     def _send(self, ranks: Ranks, msg):
+         # the intent is for this to be optimized as a tree broadcast,
+         # based on whether members of tree nodes overlap with a slice.
+         for rank in iter_ranks(ranks):
+             self.worker_processes[rank].send(msg)
+
+     def CommandGroup(self, ranks: Ranks, msg: messages.CommandGroup):
+         for command in msg.commands:
+             handler = getattr(self, command.__class__.__name__, None)
+             if handler is not None:
+                 handler(ranks, command)
+
+     def CreatePipe(self, ranks: Ranks, msg: messages.CreatePipe):
+         pipe_ranks = list(enumerate(iter_ranks(ranks)))
+         for i, rank in pipe_ranks:
+             # In general, pipes on different workers may need to have different behavior.
+             # For example, two data loader pipes operating on the same dataset should
+             # load different shards of the dataset. In order to do this, each pipe process
+             # on the worker needs to know the number of instances of the pipe (e.g. len(pipe_ranks))
+             # and its unique rank among all instances of the pipe (e.g., i).
+             proc = self.worker_processes[rank].host.create_process(
+                 FunctionCall(
+                     "monarch.worker.worker.pipe_main",
+                     f"{msg.key}-{rank}",
+                     msg.max_messages,
+                 ),
+                 env={"CUDA_VISIBLE_DEVICES": ""},
+                 name=f"pipe-{rank}",
+             )
+             self.live_processes_per_rank[rank].append(proc)
+             self.process_to_rank[proc] = rank
+
+     def ProcessExited(
+         self, sender: Process, msg: ProcessExitedMsg
+     ) -> List[Tuple[int, NamedTuple]]:
+         return self._process_exited(sender, msg.result)
+
+     def Restarted(
+         self, sender: Process, restarted: messages.Restarted
+     ) -> List[Tuple[int, NamedTuple]]:
+         return self._process_exited(sender, restarted.result)
+
+     def _process_exited(
+         self, sender: Process, result: int | Exception
+     ) -> List[Tuple[int, NamedTuple]]:
+         rank = self.process_to_rank[sender]
+         if result != 0:
+             if not self.exiting or self.worker_processes[rank] is sender:
+                 kind = (
+                     "worker"
+                     if self.worker_processes[rank] is sender
+                     else "pipe_process"
+                 )
+                 raise RuntimeError(f"Unexpected {kind} exit on rank {rank}")
+
+         live_procs = self.live_processes_per_rank[rank]
+         live_procs.remove(sender)
+         if len(live_procs) == 0:
+             return [(rank, ProcessExitedMsg(0))]
+         return []
+
+     def Exit(self, ranks: Ranks, msg: messages.Exit):
+         self.exiting = True
+         for rank in iter_ranks(ranks):
+             # ideally we would be kinder to these processes,
+             # but first we need to develop the API for asking them
+             # to suspend, restore, fast forward, rewind, etc.
+             worker = self.worker_processes[rank]
+             for proc in self.live_processes_per_rank[rank]:
+                 if worker is not proc:
+                     proc.signal()
+             self.worker_processes[rank].send(msg)
+
+     def recvready(self, timeout: Optional[float]) -> Sequence[Tuple[int, NamedTuple]]:
+         result = []
+         for sender, msg in self.ctx.recvready(timeout):
+             handler = getattr(self, msg.__class__.__name__, None)
+             if handler is not None:
+                 result.extend(handler(sender, msg))
+                 continue
+             elif isinstance(sender, Process):
+                 result.append((sender.rank, msg))
+             else:
+                 logger.warning("TODO: ignoring non-worker message: %s %s", sender, msg)
+         return result
+
+     @staticmethod
+     def _create_store():
+         if os.environ.get("INSIDE_RE_WORKER"):
+             hostname = "localhost"
+         else:
+             hostname = socket.gethostname()
+
+         with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
+             sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+             sock.bind(("::", 0))
+             port = sock.getsockname()[1]
+             store = TCPStore(
+                 hostname,
+                 port,
+                 is_master=True,
+                 use_libuv=False,
+                 master_listen_fd=sock.detach(),
+             )
+         return store
+
+     @staticmethod
+     def _create_pg(
+         ctx: Context, hosts: List[Host], gpu_per_host: int, store, _restartable=False
+     ):
+         env = {
+             # cuda event cache disabled pending fix for:
+             # https://github.com/pytorch/pytorch/issues/143470
+             "TORCH_NCCL_CUDA_EVENT_CACHE": "0",
+             # disable nonblocking comm until D68727854 lands.
+             "TORCH_NCCL_USE_COMM_NONBLOCKING": "0",
+             # supervisor_pipe is a unique ID per Host object,
+             # so it lets us put multiple processes on the same GPU.
+             "NCCL_HOSTID": "$SUPERVISOR_PIPE",
+             "STORE_HOSTNAME": store.host,
+             "STORE_PORT": str(store.port),
+         }
+         for name, value in os.environ.items():
+             if name.startswith("NCCL_") and name not in env:
+                 env[name] = value
+         return ctx.create_process_group(
+             hosts,
+             FunctionCall(
+                 "monarch.worker.worker.worker_main", _restartable=_restartable
+             ),
+             processes_per_host=gpu_per_host,
+             env=env,
+             name="worker",
+         )
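
Both ProcessBackend.send and recvready dispatch on the message's class name via getattr, so adding a handler is just a matter of defining a method named after the message type. A self-contained sketch of that convention (CreatePipe and Exit here are hypothetical stand-ins for the classes in monarch.common.messages):

from typing import NamedTuple

class CreatePipe(NamedTuple):
    key: str

class Exit(NamedTuple):
    pass

class MiniBackend:
    def send(self, msg: NamedTuple) -> None:
        # Look up a handler method named after the message class;
        # messages without a handler fall through to the generic path.
        handler = getattr(self, msg.__class__.__name__, None)
        if handler is not None:
            handler(msg)
        print(f"forwarding {msg!r} to workers")

    def CreatePipe(self, msg: CreatePipe) -> None:
        print(f"spawning a pipe process for {msg.key}")

b = MiniBackend()
b.send(CreatePipe("dataloader"))  # runs the CreatePipe handler, then forwards
b.send(Exit())                    # no handler, just forwards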
monarch/controller/controller.py ADDED
@@ -0,0 +1,223 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import logging
+ import traceback
+ from collections import deque
+ from typing import Generator, List, NamedTuple, Optional, Sequence, Tuple, Union
+
+ from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+     DebuggerMessage,
+     WorldState,
+ )
+
+ from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+     ActorId,
+ )
+
+ from monarch.common import messages
+ from monarch.common.controller_api import LogMessage, MessageResult
+ from monarch.common.invocation import DeviceException, Seq
+ from monarch.common.reference import Ref
+ from monarch.common.shape import NDSlice
+ from monarch.common.tensor import Tensor
+ from monarch.controller import debugger
+
+ from .backend import Backend
+ from .history import History
+
+ logger = logging.getLogger(__name__)
+
+
+ class Controller:
+     def __init__(self, backend: Backend):
+         self._backend = backend
+         self._history = History(backend.world_size)
+         self._messages = deque()
+
+         self.exited = {}
+         self.active_debugger: Optional[Tuple[int, int]] = None
+         self.pending_debugger_sessions: deque[Tuple[int, int]] = deque()
+         # for the currently active session
+         self.pending_debugger_messages: deque[messages.DebuggerMessage] = deque()
+
+     def send(
+         self,
+         ranks: Union[NDSlice, List[NDSlice]],
+         msg: NamedTuple,
+     ) -> None:
+         self._backend.send(ranks, msg)
+
+     def next_message(
+         self, timeout: Optional[float]
+     ) -> Optional[MessageResult | LogMessage]:
+         if len(self._messages) == 0:
+             self._messages.extend(self._read_messages(timeout))
+         return self._messages.popleft() if len(self._messages) > 0 else None
+
+     def drop_refs(self, refs: Sequence[Ref]) -> None:
+         """
+         No-op here; the Rust controller uses this to know when to gc
+         invocations_for_ref for failed invocations.
+         """
+         pass
+
+     def _read_messages(
+         self, timeout: Optional[float]
+     ) -> Generator[MessageResult, None, None]:
+         # XXX - how can we avoid always requesting status when waiting on futures?
+         # We need to figure out what submesh we need to hear from before a future
+         # is considered 'good'. This means not just waiting for the future value
+         # but also for a signal that any failures that could invalidate the future
+         # have not happened. We could do better if tensors/collectives had an
+         # invalid bit that we propagate. In real uses fetches might lag behind
+         # anyway, so we would not have to send out so many requests for current
+         # status.
+         for rank, value in self._backend.recvready(timeout):
+             yield from self._handle_message(rank, value)
+
+     def drain_and_stop(self) -> List[MessageResult | LogMessage | DebuggerMessage]:
+         messages = []
+         while self._messages:
+             messages.append(self._messages.popleft())
+         while len(self.exited) < self._backend.world_size:
+             messages.extend(self._read_messages(None))
+         return messages
+
+     def stop_mesh(self) -> None:
+         pass
+
+     def node(
+         self,
+         seq: Seq,
+         defs: Sequence["Tensor"],
+         uses: Sequence["Tensor"],
+     ) -> None:
+         self._history.ident(seq, defs, uses)
+
+     def _handle_message(self, sender, value) -> Generator[MessageResult, None, None]:
+         yield from getattr(self, value.__class__.__name__)(sender, *value)
+
+     def worker_world_state(self) -> WorldState:
+         # Even though this is not implemented, the return is needed so the
+         # return value complies with type checking.
+         assert 1 == 2, "not implemented"
+         return WorldState()
+
+     def ProcessExited(self, proc, result) -> Generator[MessageResult, None, None]:
+         if result != 0:
+             # XXX - this should start the failure recovery process
+             raise RuntimeError("Unexpected worker process exit")
+         self.exited[proc] = result
+         yield from []
+
+     def ProcessStarted(self, proc, pid) -> Generator[MessageResult, None, None]:
+         yield from []
+
+     def FetchResult(self, proc, ident, value) -> Generator[MessageResult, None, None]:
+         self._history.future_completed(ident, value)
+         yield from []
+
+     def RemoteFunctionFailed(
+         self,
+         proc,
+         failing_ident,
+         traceback_index,
+         exception: Exception,
+         worker_frames: List[traceback.FrameSummary],
+     ) -> Generator[MessageResult, None, None]:
+         self._history.propagate_failure(
+             failing_ident, traceback_index, exception, worker_frames
+         )
+         yield from self._history.rank_completed(proc, failing_ident)
+
+     def InternalException(
+         self,
+         proc,
+         exception: Exception,
+         worker_frames: List[traceback.FrameSummary],
+     ) -> Generator[MessageResult, None, None]:
+         yield MessageResult(
+             seq=0,  # will not be used
+             result=None,
+             error=DeviceException(
+                 exception,
+                 worker_frames,
+                 ActorId.from_string("unknown[0].unknown[0]"),
+                 message="A worker experienced an internal error.",
+             ),
+         )
+
+     def RemoteGeneratorFailed(
+         self,
+         proc,
+         exception: Exception,
+         frames: List[traceback.FrameSummary],
+     ) -> Generator[MessageResult, None, None]:
+         yield MessageResult(
+             seq=0,  # will not be used
+             result=None,
+             error=DeviceException(
+                 exception=exception,
+                 frames=frames,
+                 source_actor_id=ActorId.from_string("unknown[0].unknown[0]"),
+                 message="A remote generator failed.",
+             ),
+         )
+
+     def Status(
+         self, proc, first_uncompleted_ident
+     ) -> Generator[MessageResult, None, None]:
+         yield from self._history.rank_completed(proc, first_uncompleted_ident)
+
+     def DebuggerMessage(
+         self, proc, stream_id: int, action
+     ) -> Generator[MessageResult, None, None]:
+         if action == "paused":
+             self.pending_debugger_sessions.append((proc, stream_id))
+         else:
+             assert self.active_debugger == (proc, stream_id)
+             self.pending_debugger_messages.append(action)
+
+         if self.active_debugger is None:
+             yield from self._run_debugger_loop()
+
+     def _run_debugger_loop(self) -> Generator[MessageResult, None, None]:
+         # debug loop
+         while self.pending_debugger_sessions:
+             yield from self._run_debugger_session(
+                 *self.pending_debugger_sessions.popleft()
+             )
+
+     def _run_debugger_session(
+         self, proc_id: int, stream_id: int
+     ) -> Generator[MessageResult, None, None]:
+         debugger.write(f"pdb attached to rank {proc_id}, stream {stream_id}\n")
+         self.active_debugger = (proc_id, stream_id)
+         try:
+             rank = NDSlice(offset=proc_id, sizes=[], strides=[])
+             self.send(rank, messages.DebuggerMessage(stream_id, "attach"))
+             while True:
+                 while not self.pending_debugger_messages:
+                     # todo: eventually we should time out
+                     yield from self._read_messages(None)
+                 message = self.pending_debugger_messages.popleft()
+                 match message:
+                     case "detach":
+                         break
+                     case messages.DebuggerRead(requested):
+                         self.send(
+                             rank,
+                             messages.DebuggerMessage(
+                                 stream_id,
+                                 messages.DebuggerWrite(debugger.read(requested)),
+                             ),
+                         )
+                     case messages.DebuggerWrite(payload):
+                         debugger.write(payload.decode())
+                     case other:
+                         raise RuntimeError(f"unexpected debugger message: {other}")
+         finally:
+             self.active_debugger = None
+             self.pending_debugger_messages.clear()
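
_run_debugger_session relies on Python 3.10 structural pattern matching: NamedTuple classes define __match_args__, so positional capture patterns like messages.DebuggerRead(requested) bind fields directly. A standalone sketch of that mechanism (the message classes below are hypothetical stand-ins for the ones in monarch.common.messages):

from typing import NamedTuple

class DebuggerRead(NamedTuple):
    requested: int

class DebuggerWrite(NamedTuple):
    payload: bytes

def handle(message) -> str:
    match message:
        case "detach":
            return "session over"
        case DebuggerRead(requested):
            return f"worker wants up to {requested} bytes of stdin"
        case DebuggerWrite(payload):
            return f"worker printed: {payload.decode()}"
        case other:
            raise RuntimeError(f"unexpected debugger message: {other}")

print(handle(DebuggerRead(1024)))
print(handle(DebuggerWrite(b"(Pdb) ")))
print(handle("detach"))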
monarch/controller/debugger.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ import sys
+ from typing import Optional
+
+ _is_ipython: Optional[bool] = None
+
+
+ def is_ipython() -> bool:
+     global _is_ipython
+     if _is_ipython is not None:
+         return _is_ipython
+     try:
+         from IPython import get_ipython
+
+         _is_ipython = get_ipython() is not None
+     except ImportError:
+         _is_ipython = False
+     return _is_ipython
+
+
+ def write(msg: str) -> None:
+     sys.stdout.write(msg)
+     sys.stdout.flush()
+
+
+ def read(requested_size: int) -> bytes:
+     if not is_ipython():
+         b = bytearray(requested_size)
+         bytes_read = sys.stdin.buffer.raw.readinto(b)
+         return bytes(b[:bytes_read])
+
+     # ipython doesn't have stdin directly connected,
+     # so we need to use input() instead.
+     user_input = input() + "\n"
+     input_bytes = user_input.encode("utf-8")
+     num_bytes_to_write = len(input_bytes)
+     if requested_size < num_bytes_to_write:
+         raise RuntimeError(
+             f"Debugger input line too long, max length is {requested_size}"
+         )
+     return input_bytes[:num_bytes_to_write]
monarch/controller/history.py ADDED
@@ -0,0 +1,90 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from collections import deque
+ from typing import Generator, Sequence, TYPE_CHECKING
+
+ from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+     ActorId,
+ )
+
+ from monarch.common.controller_api import MessageResult
+
+ from monarch.common.invocation import Invocation, RemoteException, Seq
+
+ if TYPE_CHECKING:
+     from monarch.common.tensor import Tensor
+
+
+ class History:
+     def __init__(self, N):
+         self.first_uncompleted_ident = [0 for _ in range(N)]
+         self.min_first_uncompleted_ident = 0
+         self.invocations = deque[Invocation]()
+
+     def _invocation(
+         self,
+         seq: Seq,
+         defs: Sequence["Tensor"],
+         uses: Sequence["Tensor"],
+     ):
+         r = Invocation(seq)
+         for t in uses:
+             u = t._invocation
+             assert u is not None
+             u.add_user(r)
+         for t in defs:
+             t._invocation = r
+         return r
+
+     def ident(
+         self,
+         seq: Seq,
+         defs: Sequence["Tensor"],
+         uses: Sequence["Tensor"],
+     ):
+         invocation = self._invocation(seq, defs, uses)
+         self.invocations.append(invocation)
+
+     def propagate_failure(self, seq, traceback_index, exception, worker_frames):
+         invocation = self.invocations[seq - self.min_first_uncompleted_ident]
+         remote_exception = RemoteException(
+             seq,
+             exception,
+             traceback_index,
+             None,
+             worker_frames,
+             ActorId.from_string("unknown[0].unknown[0]"),
+         )
+         worklist = deque((invocation,))
+         while worklist:
+             invocation = worklist.popleft()
+             if invocation.fail(remote_exception):
+                 worklist.extend(invocation.users)
+
+     def rank_completed(
+         self, rank, first_uncompleted_ident
+     ) -> Generator[MessageResult, None, None]:
+         # advance what our last completed action was, and
+         # trim the list of tracebacks if we no longer need them.
+         prev = self.first_uncompleted_ident[rank]
+         self.first_uncompleted_ident[rank] = first_uncompleted_ident
+         if prev == self.min_first_uncompleted_ident:
+             self.min_first_uncompleted_ident = min(self.first_uncompleted_ident)
+             for seq in range(prev, self.min_first_uncompleted_ident):
+                 invocation = self.invocations.popleft()
+                 assert seq == invocation.seq
+                 result, error = invocation.complete()
+                 yield MessageResult(
+                     seq=seq,
+                     result=result,
+                     error=error,
+                 )
+
+     def future_completed(self, ident, value):
+         invocation = self.invocations[ident - self.min_first_uncompleted_ident]
+         invocation.fut_value = value
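
rank_completed maintains a per-rank progress vector plus a global minimum watermark: a sequence number is reported complete only once every rank has moved past it, and the minimum is recomputed only when the reporting rank was previously the slowest. A toy illustration of that bookkeeping, with plain ints standing in for Invocation objects (hypothetical, for clarity only):

first_uncompleted = [0, 0, 0]  # per-rank progress
min_watermark = 0

def rank_completed(rank: int, ident: int) -> list[int]:
    global min_watermark
    prev = first_uncompleted[rank]
    first_uncompleted[rank] = ident
    if prev == min_watermark:
        # Only a report from the slowest rank can move the global watermark.
        new_min = min(first_uncompleted)
        completed = list(range(min_watermark, new_min))
        min_watermark = new_min
        return completed  # seqs now complete on *every* rank
    return []

assert rank_completed(0, 5) == []         # ranks 1 and 2 still at 0
assert rank_completed(1, 3) == []
assert rank_completed(2, 4) == [0, 1, 2]  # watermark advances to 3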