torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/simulator/mock_controller.py
@@ -0,0 +1,214 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import logging
+ from collections import deque
+ from typing import (
+     cast,
+     Generator,
+     List,
+     NamedTuple,
+     Optional,
+     Sequence,
+     TYPE_CHECKING,
+     Union,
+ )
+
+ import torch
+ from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+     WorldState,
+ )
+ from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+     ActorId,
+ )
+
+ from monarch.common import messages
+
+ from monarch.common.controller_api import DebuggerMessage, LogMessage, MessageResult
+ from monarch.common.device_mesh import no_mesh
+ from monarch.common.invocation import Invocation, RemoteException, Seq
+ from monarch.common.reference import Ref
+ from monarch.common.shape import iter_ranks, NDSlice, Slices as Ranks
+ from monarch.common.tree import flatten
+
+ if TYPE_CHECKING:
+     from monarch.common.tensor import Tensor
+
+ logger = logging.getLogger(__name__)
+
+
+ class History:
+     def __init__(self, N):
+         self.first_uncompleted_ident = [0 for _ in range(N)]
+         self.min_first_uncompleted_ident = 0
+         self.invocations = deque[Invocation]()
+
+     def _invocation(
+         self,
+         seq: Seq,
+         defs: Sequence["Tensor"],
+         uses: Sequence["Tensor"],
+     ):
+         r = Invocation(seq)
+         for t in uses:
+             u = t._invocation
+             assert u is not None
+             u.add_user(r)
+         for t in defs:
+             t._invocation = r
+         return r
+
+     def ident(
+         self,
+         seq: Seq,
+         defs: Sequence["Tensor"],
+         uses: Sequence["Tensor"],
+     ):
+         invocation = self._invocation(seq, defs, uses)
+         self.invocations.append(invocation)
+
+     def propagate_failure(self, seq, traceback_index, exception, worker_frames):
+         invocation = self.invocations[seq - self.min_first_uncompleted_ident]
+         remote_exception = RemoteException(
+             seq,
+             exception,
+             traceback_index,
+             None,
+             worker_frames,
+             ActorId.from_string("unknown[0].unknown[0]"),
+         )
+         worklist = deque((invocation,))
+         while worklist:
+             invocation = worklist.popleft()
+             if invocation.fail(remote_exception):
+                 worklist.extend(invocation.users)
+
+     def rank_completed(
+         self, rank, first_uncompleted_ident
+     ) -> Generator[MessageResult, None, None]:
+         # advance what our last completed action was, and
+         # trim the list of tracebacks if we no longer need them.
+         prev = self.first_uncompleted_ident[rank]
+         self.first_uncompleted_ident[rank] = first_uncompleted_ident
+         if prev == self.min_first_uncompleted_ident:
+             self.min_first_uncompleted_ident = min(self.first_uncompleted_ident)
+             for seq in range(prev, self.min_first_uncompleted_ident):
+                 invocation = self.invocations.popleft()
+                 assert seq == invocation.seq
+                 result, error = invocation.complete()
+                 yield MessageResult(
+                     seq=seq,
+                     result=result,
+                     error=error,
+                 )
+
+     def future_completed(self, ident, value):
+         invocation = self.invocations[ident - self.min_first_uncompleted_ident]
+         invocation.fut_value = value
+
+
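History.rank_completed above keeps a per-rank watermark of the first uncompleted sequence number and only retires an invocation once every rank has moved past it. A condensed, standalone sketch of that watermark bookkeeping (hypothetical names, not part of the package):

    from collections import deque

    # Per-rank watermark of the first uncompleted sequence number.
    watermarks = [0, 0, 0]
    # Sequence numbers still in flight, in issue order.
    pending = deque(range(6))

    def rank_completed(rank: int, first_uncompleted: int) -> list[int]:
        """Advance one rank's watermark; return seqs now completed by every rank."""
        watermarks[rank] = first_uncompleted
        new_min = min(watermarks)
        completed = []
        while pending and pending[0] < new_min:
            completed.append(pending.popleft())
        return completed

    print(rank_completed(0, 3))  # [] - the other ranks are still at 0
    print(rank_completed(1, 3))  # []
    print(rank_completed(2, 3))  # [0, 1, 2] - every rank has passed seqs 0-2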
+ class MockController:
+     def __init__(self, world_size: int, verbose: bool = True):
+         self.history = History(world_size)
+         self.world_size = world_size
+         self.responses = deque[MessageResult | LogMessage | DebuggerMessage]()
+         self.exited = False
+         self.verbose = verbose
+
+     @property
+     def gpu_per_host(self) -> int:
+         return self.world_size
+
+     def send(self, ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple) -> None:
+         attr = getattr(self, type(msg).__name__, None)
+         if self.verbose:
+             logger.info(
+                 "MockController: %s %s %s", str(ranks), str(type(msg)), str(msg)
+             )
+
+         if attr is not None:
+             attr(ranks, msg)
+
+     def next_message(
+         self, timeout: Optional[float]
+     ) -> Optional[MessageResult | LogMessage]:
+         return (
+             cast(Optional[MessageResult | LogMessage], self.responses.popleft())
+             if len(self.responses) > 0
+             else None
+         )
+
+     def stop_mesh(self) -> None:
+         pass
+
+     def drain_and_stop(self) -> List[MessageResult | LogMessage | DebuggerMessage]:
+         if not self.exited:
+             raise RuntimeError("Got drain_and_stop but exited is not True")
+         r = list(self.responses)
+         self.responses.clear()
+         return r
+
+     def drop_refs(self, refs: Sequence[Ref]) -> None:
+         """
+         No-op; used by the Rust controller to know when to GC invocations_for_ref for failed invocations.
+         """
+         pass
+
+     def node(
+         self, seq: Seq, defs: Sequence["Tensor"], uses: Sequence["Tensor"]
+     ) -> None:
+         self.history.ident(seq, defs, uses)
+
+     def worker_world_state(self) -> WorldState:
+         # Even though not implemented, a return is needed so the return value complies with type checking.
+         assert 1 == 2, "not implemented"
+         return WorldState()
+
+     # Below are the messages that would be executed on "workers".
+     def CommandGroup(self, ranks: Ranks, msg: messages.CommandGroup):
+         for command in msg.commands:
+             self.send(ranks, command)
+
+     def RequestStatus(self, ranks: Ranks, msg: messages.RequestStatus):
+         for rank in iter_ranks(ranks):
+             for r in self.history.rank_completed(rank, msg.ident + 1):
+                 self.responses.append(r)
+
+     def SendValue(self, ranks: Ranks, msg: messages.SendValue):
+         dtensors, unflatten = flatten(
+             (msg.args, msg.kwargs), lambda x: isinstance(x, torch.Tensor)
+         )
+         fake_args, _fake_kwargs = unflatten(d._fake for d in dtensors)
+         if msg.function is not None:
+             fake_result = None
+         else:
+             fake_result = fake_args[0]
+
+         if msg.destination is None:
+             # If the destination is the controller, we need to send back an actual
+             # tensor, not a fake tensor, because the remaining operations are likely
+             # to be data dependent (e.g., losses.item()).
+             # Note that this also means that if the controller branches on the
+             # execution, the execution path is going to diverge from the
+             # actual workload.
+             with no_mesh.activate():
+                 tensors, unflatten = flatten(
+                     fake_result, lambda x: isinstance(x, torch.Tensor)
+                 )
+                 fake_result = unflatten(
+                     torch.zeros(
+                         t.size(), dtype=t.dtype, device=t.device, requires_grad=False
+                     )
+                     for t in tensors
+                 )
+         for _ in iter_ranks(ranks):
+             self.responses.append(
+                 self.history.future_completed(msg.ident, fake_result)
+             )
+
+     def Exit(self, ranks: Ranks, msg: messages.Exit):
+         self.exited = True
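MockController.send above routes each worker-bound message to a handler method named after the message's class (CommandGroup, RequestStatus, SendValue, Exit) via getattr, and silently drops messages it has no handler for. A minimal, self-contained sketch of that dispatch pattern, using a hypothetical Ping message rather than monarch's real message types:

    from typing import NamedTuple

    class Ping(NamedTuple):  # hypothetical message type, for illustration only
        ident: int

    class TinyController:
        def send(self, ranks, msg: NamedTuple) -> None:
            # Look up a handler whose name matches the message class, if any.
            handler = getattr(self, type(msg).__name__, None)
            if handler is not None:
                handler(ranks, msg)

        def Ping(self, ranks, msg: Ping) -> None:
            print(f"handled Ping({msg.ident}) for ranks {ranks}")

    TinyController().send([0, 1], Ping(ident=7))  # handled Ping(7) for ranks [0, 1]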
monarch/simulator/profiling.py
@@ -0,0 +1,424 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import contextlib
+ import copy
+ import enum
+ import functools
+ import multiprocessing
+ import os
+ import socket
+ import time
+ import traceback
+
+ from contextlib import closing
+ from datetime import timedelta
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Generator,
+     List,
+     NamedTuple,
+     Optional,
+     Set,
+     Tuple,
+ )
+
+ import torch
+ import torch.distributed as dist
+ from monarch.common import messages
+ from monarch.common.function import resolvable_function
+ from monarch.common.function_caching import (
+     hashable_tensor_flatten,
+     HashableTreeSpec,
+     key_filters,
+     TensorGroup,
+ )
+ from monarch.common.tensor_factory import TensorFactory
+ from monarch.simulator.command_history import CommandHistory, DTensorRef
+ from torch.utils import _pytree as pytree
+ from torch.utils._mode_utils import no_dispatch
+
+
+ def get_free_port() -> int:
+     configs = [(socket.AF_INET6, "::1"), (socket.AF_INET, "127.0.0.1")]
+     errors = []
+
+     for addr_family, address in configs:
+         with socket.socket(addr_family, socket.SOCK_STREAM) as s:
+             try:
+                 s.bind((address, 0))
+                 s.listen(0)
+                 with closing(s):
+                     return s.getsockname()[1]
+             except Exception as e:
+                 errors.append(
+                     f"Binding failed with address {address} while getting free port: {e}"
+                 )
+
+     # If this is reached, we failed to bind to any of the configs.
+     raise Exception(", ".join(errors))
+
+
+ # The functions below are copied from cached_remote_function.py because
+ # depending on cached_remote_function.py can cause dependency issues.
+ def _to_factory(x):
+     if isinstance(x, torch.Tensor):
+         return (TensorFactory.from_tensor(x), x.requires_grad)
+     return x
+
+
+ def _filter_key(v: Any):
+     for filter in key_filters:
+         v = filter(v)
+     return v
+
+
+ def _make_key(args, kwargs):
+     values, spec = pytree.tree_flatten((args, kwargs))
+     return tuple(_filter_key(v) for v in values), HashableTreeSpec.from_treespec(spec)
+
+
+ class ProfilingWorker:
+     _float_types: Set[torch.dtype] = {
+         torch.float16,
+         torch.bfloat16,
+         torch.float32,
+         torch.float64,
+     }
+
+     def __init__(self, world_size, rank) -> None:
+         self.world_size = world_size
+         self.rank = rank
+         self.counter = 0
+
+     @contextlib.contextmanager
+     def _worker_env(self) -> Generator[dist.TCPStore, None, None]:
+         try:
+             store = dist.TCPStore(
+                 os.environ["STORE_HOSTNAME"],
+                 int(os.environ["STORE_PORT"]),
+                 timeout=timedelta(seconds=10),
+             )
+             torch.cuda.set_device(self.rank)
+             yield store
+         finally:
+             if dist.is_initialized():
+                 dist.destroy_process_group()
+
+     # Adapted from: https://fburl.com/3xpyoq93
+     # NB: returns fake tensors
+     def _run_function(
+         self, func: Callable, args: Any, kwargs: Any
+     ) -> Tuple[Any | None, int]:
+         """
+         Runs and benchmarks a fallback kernel for a given function.
+
+         Args:
+             func (Callable): The function to benchmark.
+             args (Tuple): The arguments to pass to the function.
+             kwargs (Dict[str, Any]): The keyword arguments to pass to the function.
+
+         Returns:
+             Tuple[Any | None, int]: A tuple containing the result of the function
+                 and the mean operation time.
+         """
+         # These should all be supported, but just to be safe, avoid the fallback
+         # for operators which in-place modify metadata, because the input fake
+         # tensors would be left unmodified.
+
+         if torch.Tag.inplace_view in getattr(func, "tags", ()):
+             raise NotImplementedError
+
+         if args is None:
+             args = ()
+
+         if kwargs is None:
+             kwargs = {}
+
+         warmup_iters, actual_iters = 2, 3
+         # We have to deepcopy before entering the `no_dispatch()` context so that
+         # the copy won't materialize the fake tensor to a real tensor automatically.
+         args_copies = [
+             copy.deepcopy(args) for _ in range(warmup_iters + actual_iters + 1)
+         ]
+         kwargs_copies = [
+             copy.deepcopy(kwargs) for _ in range(warmup_iters + actual_iters + 1)
+         ]
+
+         with no_dispatch():
+             materialized_tensors = {}
+
+             def to_real_tensor(e):  # type: ignore[no-untyped-def]
+                 if isinstance(e, DTensorRef):
+                     ref = e.ref
+
+                     # TODO: Should we investigate this issue, or is there not
+                     # much we can do?
+                     # Context: caching the materialized tensors won't work for
+                     # TE's backward. It will crash without throwing any exception.
+                     # out = materialized_tensors.get(ref, None)
+                     out = None
+                     if out is None:
+                         e = e._fake
+                         assert e is not None
+                         if e.dtype in self._float_types:
+                             out = torch.rand_like(e, device=e.fake_device)
+                         else:
+                             out = torch.ones_like(e, device=e.fake_device)
+                         if e.is_sparse:
+                             out._coalesced_(e.is_coalesced())
+                         materialized_tensors[ref] = out
+                     return out
+                 return e
+
+             def materialize():
+                 args = args_copies.pop()
+                 kwargs = kwargs_copies.pop()
+                 flat_args, args_spec = pytree.tree_flatten((args, kwargs))
+                 flat_args = [to_real_tensor(a) for a in flat_args]
+                 args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
+                 return args, kwargs
+
+             args, kwargs = materialize()
+             r = func(*args, **kwargs)
+
+             warmup_iters, actual_iters = 2, 3
+             for _ in range(warmup_iters):
+                 args, kwargs = materialize()
+                 func(*args, **kwargs)
+
+             start_event = torch.cuda.Event(enable_timing=True)
+             end_event = torch.cuda.Event(enable_timing=True)
+             start_event.record(torch.cuda.current_stream())
+             for _ in range(actual_iters):
+                 args, kwargs = materialize()
+                 func(*args, **kwargs)
+             end_event.record(torch.cuda.current_stream())
+             torch.cuda.synchronize()
+             cuda_time = start_event.elapsed_time(end_event)
+             mean_op_time = int(cuda_time / actual_iters * 1000)
+
+         return r, mean_op_time
+
+     def CallFunction(self, msg) -> None:
+         func = msg.function.resolve()
+         ret = self._run_function(func, msg.args, msg.kwargs)
+
+         count = 2**31
+
+         def tensor_to_dtensor_ref(t):
+             nonlocal count
+             count += 1
+             t.ref = count
+             return DTensorRef(t)
+
+         return pytree.tree_map_only(torch.Tensor, tensor_to_dtensor_ref, ret)
+
+     def run(self, conn) -> None:
+         with self._worker_env() as store:
+             try:
+                 while True:
+                     msg = conn.recv()
+                     if msg == "exit":
+                         break
+                     elif msg == "init_pg":
+                         if not dist.is_initialized():
+                             dist.init_process_group(
+                                 backend="nccl",
+                                 world_size=self.world_size,
+                                 rank=self.rank,
+                                 store=store,
+                             )
+                     else:
+                         ret = self.CallFunction(msg)
+                         conn.send(("result", ret))
+                         self.counter += 1
+             except Exception:
+                 conn.send(("exception", traceback.format_exc()))
+             finally:
+                 conn.close()
+
+
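ProfilingWorker._run_function above measures each operator with CUDA events: a couple of warmup runs, then a timed loop bracketed by a start/end event pair and a synchronize, averaged over the timed iterations. A condensed sketch of just that timing pattern, with torch.matmul standing in for the profiled function (assumes a CUDA device is available):

    import torch

    def time_op_us(func, *args, warmup_iters: int = 2, actual_iters: int = 3) -> int:
        """Mean runtime of func(*args) in microseconds, measured with CUDA events."""
        for _ in range(warmup_iters):
            func(*args)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record(torch.cuda.current_stream())
        for _ in range(actual_iters):
            func(*args)
        end.record(torch.cuda.current_stream())
        torch.cuda.synchronize()
        elapsed_ms = start.elapsed_time(end)          # milliseconds over all timed runs
        return int(elapsed_ms / actual_iters * 1000)  # mean per run, in microseconds

    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device="cuda")
        print(time_op_us(torch.matmul, x, x))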
+ class RuntimeProfiler:
+     def __init__(self, world_size: int = 8, port: int = -1) -> None:
+         # TODO: Add a cached mode to save the results into a pickle file so that
+         # we can reuse the results without running anything.
+         self.world_size = world_size
+         self.port = port if port > 0 else get_free_port()
+         self._initialized = False
+         self.parent_conns: List[multiprocessing.connection.Connection] = []
+         self.cached: Dict[Tuple[Any, ...], Any] = {}
+
+     def _lazy_init(self):
+         if self._initialized:
+             return
+
+         self.store = dist.TCPStore("localhost", self.port, is_master=True)
+         self.processes = []
+         self.world_size = self.world_size
+         ctx = multiprocessing.get_context("spawn")
+         os.environ["STORE_HOSTNAME"] = "localhost"
+         os.environ["STORE_PORT"] = str(self.port)
+         for i in range(self.world_size):
+             parent_conn, child_conn = multiprocessing.Pipe()
+             worker = ProfilingWorker(self.world_size, i)
+             self.processes.append(
+                 ctx.Process(target=worker.run, args=(child_conn,), daemon=True),
+             )
+             self.parent_conns.append(parent_conn)
+             self.processes[-1].start()
+
+         self._initialized = True
+
+     def __exit__(self) -> None:
+         if self._initialized:
+             for i in range(self.world_size):
+                 conn = self.parent_conns[i]
+                 conn.send("exit")
+                 time.sleep(0.1)
+
+     def profile_cmd(self, cmd, ranks) -> List[Any | None]:
+         self._lazy_init()
+
+         ret = []
+         assert type(cmd).__name__ == "CallFunction"
+         cmd = CommandHistory.convert_msg(cmd)
+         cmd = cmd._replace(function=resolvable_function(cmd.function))
+
+         def dtensor_ref_filter(v: Any):
+             if isinstance(v, DTensorRef):
+                 return v.factory
+             return v
+
+         key_filters.append(dtensor_ref_filter)
+         tensors, shape_key = hashable_tensor_flatten((cmd, ranks), {})
+         inputs_group = TensorGroup([t._fake for t in tensors])  # pyre-ignore[16]
+         requires_grads = tuple(t.requires_grad for t in tensors)
+         key = (shape_key, inputs_group.pattern, requires_grads)
+         key_filters.pop()
+         # key = _make_key((cmd, ranks), None)
+         if key in self.cached:
+             return self.cached[key]
+
+         for i in ranks:
+             conn = self.parent_conns[i]
+             conn.send(cmd)
+
+         # This cannot be merged into the previous for loop; a deadlock can happen.
+         for i in ranks:
+             ret.append(self.parent_conns[i].recv())
+
+         clean_ret = []
+         for r in ret:
+             if r[0] == "exception":
+                 raise RuntimeError(r[1])
+             clean_ret.append(r[1])
+
+         self.cached[key] = clean_ret
+         return clean_ret
+
+
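RuntimeProfiler.profile_cmd above fans a command out to every selected worker before collecting any replies; as the comment in the code notes, interleaving send and recv per worker risks a deadlock once pipe buffers fill up. A small standalone sketch of that fan-out/fan-in request-response pattern over multiprocessing pipes (hypothetical echo worker, not monarch's protocol):

    import multiprocessing as mp

    def echo_worker(conn) -> None:
        # Hypothetical worker: answer each request until told to exit.
        while True:
            msg = conn.recv()
            if msg == "exit":
                break
            conn.send(("result", msg * 2))
        conn.close()

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        conns = []
        for _ in range(4):
            parent, child = mp.Pipe()
            ctx.Process(target=echo_worker, args=(child,), daemon=True).start()
            conns.append(parent)

        # Fan out all requests first, then collect all replies, as profile_cmd does.
        for i, conn in enumerate(conns):
            conn.send(i)
        print([conn.recv() for conn in conns])
        # [('result', 0), ('result', 2), ('result', 4), ('result', 6)]

        for conn in conns:
            conn.send("exit")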
+ def _return_if_exist(attr):
+     def decorator(func):
+         @functools.wraps(func)
+         def wrapper(self, *args, **kwargs):
+             user_fn = getattr(self, attr)
+             if isinstance(user_fn, int):
+                 return user_fn
+             elif callable(user_fn):
+                 return user_fn(*args, **kwargs)
+             return func(self, *args, **kwargs)
+
+         return wrapper
+
+     return decorator
+
+
+ class TimingType(str, enum.Enum):
+     SEND_TENSOR = "_send_tensor_time"
+     REDUCE = "_reduce_time"
+     CALL_FUNCTION = "_call_function_time"
+     KERNEL_LAUNCH = "_kernel_launch_time"
+     WAIT_EVENT = "_wait_event_time"
+
+
+ TimingFunction = Callable[[Optional[NamedTuple]], int]
+
+
+ class RuntimeEstimator:
+     def __init__(self) -> None:
+         self._call_function_time: TimingFunction | int | None = None
+         self._reduce_time: TimingFunction | int | None = None
+         self._send_tensor_time: TimingFunction | int | None = None
+         self._wait_event_time: int | None = None
+         self._kernel_launch_time: int | None = None
+
+     @_return_if_exist("_send_tensor_time")
+     def _get_send_tensor_time(self, msg: messages.SendTensor) -> int:
+         if msg.from_ranks == msg.to_ranks:
+             return 1_000
+         return 100_000
+
+     @_return_if_exist("_reduce_time")
+     def _get_reduce_time(self, msg: messages.Reduce) -> int:
+         return 100_000
+
+     @_return_if_exist("_call_function_time")
+     def _get_call_function_time(self, msg: messages.CallFunction) -> int:
+         return 10_000
+
+     @_return_if_exist("_kernel_launch_time")
+     def _get_kernel_launch_time(self) -> int:
+         return 500
+
+     @_return_if_exist("_wait_event_time")
+     def _get_wait_event_time(self) -> int:
+         return 500
+
+     def set_custom_timing(
+         self, func_or_time: Dict[TimingType, TimingFunction | int]
+     ) -> None:
+         """
+         Set custom timing values for specific message types or events.
+
+         This method allows the user to define custom timing values for various
+         operations in the simulator. The timing can be specified either as a fixed
+         integer value or as a function that computes the timing dynamically.
+         All integer values are in nanoseconds.
+
+         Args:
+             func_or_time (Dict[TimingType, TimingFunction | int]): A dictionary
+                 mapping TimingType to either a function or an integer. If a function
+                 is provided, it should accept an optional NamedTuple as input and
+                 return an integer representing the timing in nanoseconds.
+
+         Raises:
+             AssertionError: If a value in the dictionary is neither an integer
+                 nor a callable function.
+         """
+         for k, v in func_or_time.items():
+             assert isinstance(v, int) or callable(
+                 v
+             ), "The supported custom timings are an integer or a function."
+             setattr(self, k.value, v)
+
+     def get_runtime(self, msg) -> int:
+         match msg:
+             case messages.CallFunction():
+                 return self._get_call_function_time(msg)
+             case messages.Reduce():
+                 return self._get_reduce_time(msg)
+             case messages.SendTensor():
+                 return self._get_send_tensor_time(msg)
+             case "kernel_launch":
+                 return self._get_kernel_launch_time()
+             case "wait_event":
+                 return self._get_wait_event_time()
+             case _:
+                 raise ValueError(f"Got an unexpected message for profiling: {msg}.")
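Assuming this hunk is monarch/simulator/profiling.py (the +424 line count in the file listing above matches), a short usage sketch of the override mechanism: each TimingType value names the attribute that _return_if_exist consults, so an integer or callable installed by set_custom_timing short-circuits the built-in estimate.

    # Module path assumed from the file listing above.
    from monarch.simulator.profiling import RuntimeEstimator, TimingType

    estimator = RuntimeEstimator()
    print(estimator.get_runtime("kernel_launch"))  # 500, the built-in default (ns)

    # Override one event with a fixed value and another with a callable.
    estimator.set_custom_timing(
        {
            TimingType.KERNEL_LAUNCH: 800,                 # fixed, in nanoseconds
            TimingType.CALL_FUNCTION: lambda msg: 25_000,  # computed per message
        }
    )
    print(estimator.get_runtime("kernel_launch"))  # 800, taken from the override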