torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff shows the contents of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
monarch/controller/rust_backend/controller.py ADDED
@@ -0,0 +1,245 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import logging
+ import traceback
+ from collections import deque
+ from logging import Logger
+ from typing import List, NamedTuple, Optional, Sequence, Union
+
+ from monarch._rust_bindings.monarch_extension import (
+     client,
+     controller,
+     debugger,
+     tensor_worker,
+ )
+ from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension
+     ClientActor,
+     SystemSnapshotFilter,
+     WorldState,
+ )
+ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
+     ActorId,
+     Proc,
+ )
+
+ from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
+ from monarch.common.controller_api import LogMessage, MessageResult
+ from monarch.common.device_mesh import no_mesh
+ from monarch.common.invocation import DeviceException, RemoteException
+ from monarch.common.messages import SupportsToRustMessage
+ from monarch.common.shape import NDSlice
+ from monarch.common.tensor import Tensor
+ from monarch.controller.debugger import read as debugger_read, write as debugger_write
+ from pyre_extensions import none_throws
+
+ logger: Logger = logging.getLogger(__name__)
+
+
+ class RustController:
+     def __init__(
+         self,
+         proc: Proc,
+         client_actor: ClientActor,
+         controller_id: ActorId,
+         worker_world_name: str,
+     ) -> None:
+         self._controller_actor = controller_id
+         self._proc = proc
+         self._actor = client_actor
+         # Attach the client to the controller
+         # Errors will be raised if someone else has attached it already.
+         self._actor.attach(self._controller_actor)
+         self._worker_world_name = worker_world_name
+
+         # Buffer for messages unrelated to debugging that are received while a
+         # debugger session is active.
+         self._non_debugger_pending_messages: deque[
+             Optional[client.LogMessage | client.WorkerResponse]
+         ] = deque()
+         self._pending_debugger_sessions: deque[ActorId] = deque()
+
+     def send(
+         self,
+         ranks: Union[NDSlice, List[NDSlice]],
+         msg: NamedTuple,
+     ) -> None:
+         self._actor.send_obj(self._controller_actor, ranks, msg)
+
+     def drop_refs(self, refs: Sequence[tensor_worker.Ref]) -> None:
+         self._actor.drop_refs(self._controller_actor, list(refs))
+
+     def node(
+         self,
+         seq: int,
+         defs: Sequence["Tensor"],
+         uses: Sequence["Tensor"],
+     ) -> None:
+         node = controller.Node(
+             seq=seq,
+             defs=[tensor_worker.Ref(id=t.ref) for t in defs if t.ref is not None],
+             uses=[tensor_worker.Ref(id=t.ref) for t in uses if t.ref is not None],
+         )
+
+         self._actor.send(self._controller_actor, node.serialize())
+
+     def next_message(
+         self, timeout: Optional[float]
+     ) -> Optional[LogMessage | MessageResult]:
+         if self._non_debugger_pending_messages:
+             msg = self._non_debugger_pending_messages.popleft()
+         else:
+             msg = self._actor.get_next_message(
+                 timeout_msec=int((timeout or 0.0) * 1000.0)
+             )
+         if msg is None:
+             return None
+
+         if isinstance(msg, client.WorkerResponse):
+             return _worker_response_to_result(msg)
+         elif isinstance(msg, client.LogMessage):
+             return LogMessage(msg.level, msg.message)
+         elif isinstance(msg, client.DebuggerMessage):
+             self._run_debugger_loop(msg)
+
+     def stop_mesh(self) -> None:
+         logger.info("rust controller stopping the system")
+         self._actor.stop_worlds(
+             [self._controller_actor.world_name, self._worker_world_name]
+         )
+
+     def drain_and_stop(
+         self,
+     ) -> List[LogMessage | MessageResult | client.DebuggerMessage]:
+         logger.info("rust controller shutting down")
+         results = []
+         for msg in self._actor.drain_and_stop():
+             if isinstance(msg, client.WorkerResponse):
+                 results.append(_worker_response_to_result(msg))
+             elif isinstance(msg, client.LogMessage):
+                 results.append(LogMessage(msg.level, msg.message))
+             elif isinstance(msg, client.DebuggerMessage):
+                 results.append(msg)
+             else:
+                 raise RuntimeError(f"Unexpected message type {type(msg)}")
+         return results
+
+     def _run_debugger_loop(self, message: client.DebuggerMessage) -> None:
+         if not isinstance(message.action, DebuggerAction.Paused):
+             raise RuntimeError(
+                 f"Unexpected debugger message {message} when no debugger session is running"
+             )
+
+         self._pending_debugger_sessions.append(message.debugger_actor_id)
+         while self._pending_debugger_sessions:
+             debugger_actor_id = self._pending_debugger_sessions.popleft()
+             rank = debugger_actor_id.rank
+             proc_id = debugger_actor_id.proc_id
+             debugger_write(
+                 f"pdb attached to proc {proc_id} with rank {rank}, debugger actor {debugger_actor_id} \n"
+             )
+
+             self._actor.send(
+                 debugger_actor_id,
+                 debugger.DebuggerMessage(action=DebuggerAction.Attach()).serialize(),
+             )
+
+             while True:
+                 # TODO: Add appropriate timeout.
+                 msg = self._actor.get_next_message(timeout_msec=None)
+
+                 if not isinstance(msg, client.DebuggerMessage):
+                     self._non_debugger_pending_messages.append(msg)
+                     continue
+
+                 if msg.debugger_actor_id != debugger_actor_id:
+                     if isinstance(msg.action, DebuggerAction.Paused):
+                         self._pending_debugger_sessions.append(msg.debugger_actor_id)
+                         continue
+                     else:
+                         raise RuntimeError(
+                             f"unexpected debugger message {msg} from rank {msg.debugger_actor_id.rank} "
+                             f"when debugging rank {debugger_actor_id.rank}"
+                         )
+
+                 action = msg.action
+                 if isinstance(action, DebuggerAction.Detach):
+                     break
+                 elif isinstance(action, DebuggerAction.Read):
+                     self._actor.send(
+                         debugger_actor_id,
+                         debugger.DebuggerMessage(
+                             action=DebuggerAction.Write(
+                                 bytes=debugger_read(action.requested_size)
+                             )
+                         ).serialize(),
+                     )
+                 elif isinstance(action, DebuggerAction.Write):
+                     debugger_write(
+                         debugger.get_bytes_from_write_action(action).decode()
+                     )
+                 else:
+                     raise RuntimeError(
+                         f"unexpected debugger message {msg} when debugging rank {debugger_actor_id.rank}"
+                     )
+
+     def worker_world_state(self) -> WorldState:
+         worlds_state = self._actor.world_state(
+             SystemSnapshotFilter(worlds=[self._worker_world_name])
+         )
+
+         return worlds_state[self._worker_world_name]
+
+
+ # TODO: Handling conversion of the response can move to a separate module over time
+ # especially as we have structured error messages.
+ def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult:
+     if not result.is_exception():
+         # The result of the message needs to be unwrapped on a real device.
+         # Staying as a fake tensor will fail the tensor deserialization.
+         with no_mesh.activate():
+             return MessageResult(result.seq, result.result(), None)
+     exc = none_throws(result.exception())
+     if isinstance(exc, client.Error):
+         worker_frames = [
+             traceback.FrameSummary("<unknown>", None, frame)
+             for frame in exc.backtrace.split("\\n")
+         ]
+         logger.error(f"Worker {exc.actor_id} failed")
+         return MessageResult(
+             seq=result.seq,
+             result=None,
+             error=RemoteException(
+                 seq=exc.caused_by_seq,
+                 exception=RuntimeError(exc.backtrace),
+                 controller_frame_index=0, # TODO: fix this once we have recording support in rust
+                 controller_frames=None,
+                 worker_frames=worker_frames,
+                 source_actor_id=exc.actor_id,
+                 message=f"Worker {exc.actor_id} failed",
+             ),
+         )
+     elif isinstance(exc, client.Failure):
+         frames = [
+             traceback.FrameSummary("<unknown>", None, frame)
+             for frame in exc.backtrace.split("\n")
+         ]
+         reason = f"Actor {exc.actor_id} crashed on {exc.address}, check the host log for details"
+         logger.error(reason)
+         return MessageResult(
+             seq=0, # seq is not consumed for DeviceException; it will be directly thrown by the client
+             result=None,
+             error=DeviceException(
+                 exception=RuntimeError(reason),
+                 frames=frames,
+                 source_actor_id=exc.actor_id,
+                 message=reason,
+             ),
+         )
+     else:
+         raise RuntimeError(f"Unknown exception type: {type(exc)}")
monarch/fetch.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ """
+ This is a utility file for fetching a shard of a tensor from remote.
+ """
+
+ from typing import TypeVar
+
+ from monarch.common.device_mesh import no_mesh
+
+ from monarch.common.future import Future
+
+ from monarch.common.remote import _call_on_shard_and_fetch
+
+ T = TypeVar("T")
+
+
+ def fetch_shard(
+     obj: T, shard: dict[str, int] | None = None, **kwargs: int
+ ) -> Future[T]:
+     """
+     Retrieve the shard at `coordinates` of the current device mesh of each
+     tensor in obj. All tensors in `obj` will be fetched to the CPU device.
+     obj - a pytree containing the tensors the fetch
+     shard - a dictionary from mesh dimension name to coordinate of the shard
+     If None, this will fetch from coordinate 0 for all dimensions (useful after all_reduce/all_gather)
+     preprocess - a
+     **kwargs - additional keyword arguments are added as entries to the shard dictionary
+     """
+     if kwargs:
+         if shard is None:
+             shard = {}
+         shard.update(kwargs)
+
+     return _call_on_shard_and_fetch(
+         None, lambda *args, **kwargs: None, obj, shard=shard
+     )
+
+
+ def show(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> object:
+     v = inspect(obj, shard=shard, **kwargs)
+     # pyre-ignore
+     from torchshow import show # @manual
+
+     with no_mesh.activate():
+         return show(v)
+
+
+ def inspect(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> T:
+     return fetch_shard(obj, shard=shard, **kwargs).result()
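
For orientation, a hedged usage sketch of the helpers above. It assumes an already-active device mesh with a hypothetical "gpu" dimension and a monarch tensor `t` created under it; neither is defined in this file:

from monarch.fetch import fetch_shard, inspect, show

# Asynchronous fetch: a Future for the CPU copy of the shard at gpu=0.
fut = fetch_shard(t, shard={"gpu": 0})
local_t = fut.result()

# Blocking shorthand; trailing keyword arguments are merged into the shard dict.
local_t = inspect(t, gpu=0)

# Render the fetched value with torchshow (an optional dependency).
show(t, gpu=0)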
monarch/future.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import asyncio
+ from typing import Generator, Generic, TypeVar
+
+ R = TypeVar("R")
+
+
+ # TODO: consolidate with monarch.common.future
+ class ActorFuture(Generic[R]):
+     def __init__(self, impl, blocking_impl=None):
+         self._impl = impl
+         self._blocking_impl = blocking_impl
+
+     def get(self) -> R:
+         if self._blocking_impl is not None:
+             return self._blocking_impl()
+         return asyncio.run(self._impl())
+
+     def __await__(self) -> Generator[R, None, R]:
+         return self._impl().__await__()
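
A small self-contained sketch of how ActorFuture behaves on both paths; the coroutine below is a stand-in for real actor work, not an API from the wheel:

import asyncio

from monarch.future import ActorFuture


async def _produce() -> int:
    await asyncio.sleep(0)  # stand-in for waiting on an actor reply
    return 42


fut = ActorFuture(_produce, blocking_impl=lambda: 42)
print(fut.get())  # synchronous path: uses blocking_impl when it is provided


async def consume() -> int:
    return await fut  # asynchronous path: delegates to the coroutine impl


print(asyncio.run(consume()))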
monarch/gradient/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from ._gradient_generator import GradientGenerator
+
+ __all__ = ["GradientGenerator"]
monarch/gradient/_gradient_generator.pyi ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from typing import Any, Optional
+
+ import torch
+
+ class GradientGenerator:
+     def __init__(
+         self,
+         roots_list: Any,
+         with_respect_to: Any,
+         grad_roots: Any,
+         context_restorer: Any,
+     ): ...
+     # pyre-ignore[11]: Annotation `torch.Tensor` is not defined as a type.
+     def __next__(self) -> Optional[torch.Tensor]: ...
+     def __iter__(self) -> "GradientGenerator": ...
monarch/gradient/_gradient_generator.so ADDED (binary file)
monarch/gradient_generator.py ADDED
@@ -0,0 +1,185 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import math
+ from contextlib import nullcontext
+ from functools import partial
+ from types import CellType, FunctionType
+ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
+
+ import torch
+ import torch.autograd.graph
+
+ from monarch.common import device_mesh, stream
+ from monarch.common.tensor import Tensor
+ from monarch.common.tree import flatten
+ from monarch.gradient import GradientGenerator as _GradientGenerator
+ from torch._C._autograd import _get_sequence_nr # @manual
+ from torch.autograd.graph import get_gradient_edge, GradientEdge
+
+ TensorOrEdge = Union[torch.Tensor, GradientEdge]
+
+
+ class Context(NamedTuple):
+     device_mesh: "Optional[device_mesh.DeviceMesh]"
+     stream: "stream.Stream"
+
+     def enable(self):
+         if device_mesh is None:
+             activate_mesh = device_mesh.no_mesh.activate()
+         elif self.device_mesh is not device_mesh._active:
+             # XXX: something about activating device meshes from this object
+             # doesn't work correctly and somehow inactivates the device mesh
+             # if it is already enabled. This is a temporary workaround for
+             # the demo.
+             activate_mesh = self.device_mesh.activate()
+         else:
+             activate_mesh = nullcontext()
+         with activate_mesh, self.stream.activate(), torch.no_grad():
+             yield
+
+
+ _sequence_nr_to_context: Dict[int, Context] = {}
+ _sequence_nr_end = 0
+
+
+ def restore_context(t: Optional[Tensor], sn: Optional[int], last: bool):
+     if sn is not None:
+         _update_context_map(Context(device_mesh._active, stream._active))
+         ctx = _sequence_nr_to_context.pop(sn) if last else _sequence_nr_to_context[sn]
+         return ctx.enable()
+     if t is not None:
+         return Context(t.mesh, t.stream).enable()
+     return Context(device_mesh._active, stream._active).enable()
+
+
+ def _update_context_map(ctx: Context):
+     global _sequence_nr_end
+     next_sequence_nr = _get_sequence_nr()
+     for i in range(_sequence_nr_end, next_sequence_nr):
+         _sequence_nr_to_context[i] = ctx
+     _sequence_nr_end = _get_sequence_nr()
+
+
+ device_mesh._on_change.append(
+     lambda old, mesh: _update_context_map(Context(old, stream._active))
+ )
+ stream._on_change.append(
+     lambda old, stream: _update_context_map(Context(device_mesh._active, old))
+ )
+
+
+ def grad_generator(
+     roots: Union[torch.Tensor, Sequence[TensorOrEdge]] = (),
+     with_respect_to: Sequence[TensorOrEdge] = (),
+     grad_roots: Sequence[Optional[torch.Tensor]] = (),
+ ):
+     if isinstance(roots, torch.Tensor):
+         roots = [roots]
+     return _GradientGenerator(
+         list(roots), list(with_respect_to), list(grad_roots), restore_context
+     )
+
+
+ def _gradient_edge(a: TensorOrEdge) -> GradientEdge:
+     if isinstance(a, GradientEdge):
+         return a
+     return get_gradient_edge(a)
+
+
+ class GradGenerator:
+     def __init__(self):
+         self.roots: List[torch.autograd.graph.GradientEdge] = []
+         self.with_respect_to: List[torch.autograd.graph.GradientEdge] = []
+         self.grad_roots: List[Optional[torch.Tensor]] = []
+         self.unflattens: List[Tuple[int, Any]] = []
+
+     def grad(self, tree: Any):
+         tensors, unflatten = flatten(tree, lambda x: isinstance(x, torch.Tensor))
+         self.unflattens.append((len(tensors), unflatten))
+         self.with_respect_to.extend(_gradient_edge(t) for t in tensors)
+
+     def root(self, r: TensorOrEdge, grad: Optional[torch.Tensor] = None):
+         self.roots.append(_gradient_edge(r))
+         self.grad_roots.append(grad)
+
+     def __iter__(self):
+         gi = _GradientGenerator(
+             self.roots,
+             list(reversed(self.with_respect_to)),
+             self.grad_roots,
+             restore_context,
+         )
+         for n, unflatten in reversed(self.unflattens):
+             yield unflatten(reversed([next(gi) for _ in range(n)]))
+
+
+ class GradFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, fn, *args, **kwargs):
+         result, backward_continuation = fn(*args, **kwargs)
+         ctx.backward_continuation = backward_continuation
+         values = []
+         if backward_continuation.__closure__ is not None:
+             for cell in backward_continuation.__closure__:
+                 values.append(cell.cell_contents)
+                 cell.cell_contents = None
+         tensors, ctx.unflatten = flatten(values, lambda x: isinstance(x, torch.Tensor))
+         ctx.save_for_backward(*tensors)
+         return result
+
+     @staticmethod
+     def backward(ctx, *args, **kwargs):
+         closure = tuple(CellType(v) for v in ctx.unflatten(ctx.saved_tensors))
+         orig = ctx.backward_continuation
+         fn = FunctionType(
+             orig.__code__, orig.__globals__, orig.__name__, orig.__defaults__, closure
+         )
+         output = fn(*args, **kwargs)
+         if isinstance(output, tuple):
+             return None, *output
+         else:
+             return None, output
+
+
+ def grad_function(fn):
+     return partial(GradFunction.apply, fn)
+
+
+ def gradient_execution_order(
+     roots: Sequence[TensorOrEdge], with_respect_to: Sequence[TensorOrEdge]
+ ) -> List[int]:
+     """
+     Returns the order in which the gradients for `with_respect_to` would become available
+     if autograd were run on `roots`. This is the reverse order of each tensors
+     first use in the gradient computation.
+     """
+     with_respect_to = [_gradient_edge(g) for g in with_respect_to]
+     min_sequence_nr: Dict[Any, float] = {e: math.inf for e in with_respect_to}
+
+     to_scan = [_gradient_edge(r).node for r in roots]
+     scanned = set()
+     while to_scan:
+         node = to_scan.pop()
+         if node in scanned:
+             continue
+         scanned.add(node)
+         for key in node.next_functions:
+             nnode = key[0]
+             if nnode is None:
+                 continue
+             to_scan.append(nnode)
+             value = min_sequence_nr.get(key)
+             if value is not None:
+                 # pyre-ignore
+                 min_sequence_nr[key] = min(node._sequence_nr(), value)
+
+     return sorted(
+         range(len(with_respect_to)),
+         key=lambda i: min_sequence_nr[with_respect_to[i]],
+         reverse=True,
+     )
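
Of the entry points above, gradient_execution_order is plain autograd-graph inspection, so it can be exercised with local tensors; a small sketch (variable names are illustrative, and the printed order assumes the usual sequence-number behavior of autograd nodes):

import torch

from monarch.gradient_generator import gradient_execution_order

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
t1 = a * 2              # `a` is consumed first in the forward pass
loss = (t1 + b).sum()   # `b` is consumed later

# Gradients become available in reverse order of first use, so `b` (index 1)
# is expected to be listed before `a` (index 0): typically [1, 0].
print(gradient_execution_order([loss], [a, b]))

The streaming GradGenerator follows the pattern root(loss), then grad(tree), then iteration, yielding one unflattened gradient tree per grad() call; it is driven by the compiled monarch.gradient extension, so it is not shown running standalone here.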
monarch/memory.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import itertools
+ import os
+ from pathlib import Path
+
+ import torch
+ from monarch.common.remote import remote
+
+
+ PATH_KEY = "dir_snapshots"
+ _counter = itertools.count()
+
+
+ @remote(propagate="inspect")
+ def record_memory_history() -> None:
+     torch.cuda.memory._record_memory_history()
+
+
+ def dump_memory_snapshot(*args, **kwargs) -> None:
+     """
+     This function wraps torch.cuda.memory._dump_snapshot() to dump memory snapshot remotely.
+     """
+     assert isinstance(
+         kwargs.get(PATH_KEY, None), str
+     ), f"{PATH_KEY} must be passed and must be a string to represent the path to save the memory snapshots."
+     id = next(_counter)
+     _memory_controller_dump(id, *args, **kwargs)
+
+
+ @remote(propagate="inspect")
+ def _memory_controller_dump(ident, *args, **kwargs) -> None:
+     dir_path = Path(kwargs[PATH_KEY]).absolute()
+     os.makedirs(dir_path, exist_ok=True)
+     # This is not a synchronized call, so it is okay to call without device mesh.
+     rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+     snapshot_path = f"{dir_path}/snapshot_{rank}.pickle"
+     torch.cuda.memory._dump_snapshot(filename=snapshot_path)
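
A hedged usage sketch of the helpers above; it assumes an active device mesh (the @remote functions execute on the workers), and the output directory is an example path, not something the package defines:

from monarch.memory import dump_memory_snapshot, record_memory_history

record_memory_history()  # start CUDA allocator history recording on each worker

# ... run training or inference steps under the mesh ...

# PATH_KEY is "dir_snapshots"; each worker writes snapshot_<rank>.pickle there.
dump_memory_snapshot(dir_snapshots="/tmp/monarch_mem_snapshots")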
monarch/monarch_controller ADDED (binary file)