torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/simulator/mock_controller.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import logging
+from collections import deque
+from typing import (
+    cast,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    TYPE_CHECKING,
+    Union,
+)
+
+import torch
+from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+    WorldState,
+)
+from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+    ActorId,
+)
+
+from monarch.common import messages
+
+from monarch.common.controller_api import DebuggerMessage, LogMessage, MessageResult
+from monarch.common.device_mesh import no_mesh
+from monarch.common.invocation import Invocation, RemoteException, Seq
+from monarch.common.reference import Ref
+from monarch.common.shape import iter_ranks, NDSlice, Slices as Ranks
+from monarch.common.tree import flatten
+
+if TYPE_CHECKING:
+    from monarch.common.tensor import Tensor
+
+logger = logging.getLogger(__name__)
+
+
+class History:
+    def __init__(self, N):
+        self.first_uncompleted_ident = [0 for _ in range(N)]
+        self.min_first_uncompleted_ident = 0
+        self.invocations = deque[Invocation]()
+
+    def _invocation(
+        self,
+        seq: Seq,
+        defs: Sequence["Tensor"],
+        uses: Sequence["Tensor"],
+    ):
+        r = Invocation(seq)
+        for t in uses:
+            u = t._invocation
+            assert u is not None
+            u.add_user(r)
+        for t in defs:
+            t._invocation = r
+        return r
+
+    def ident(
+        self,
+        seq: Seq,
+        defs: Sequence["Tensor"],
+        uses: Sequence["Tensor"],
+    ):
+        invocation = self._invocation(seq, defs, uses)
+        self.invocations.append(invocation)
+
+    def propagate_failure(self, seq, traceback_index, exception, worker_frames):
+        invocation = self.invocations[seq - self.min_first_uncompleted_ident]
+        remote_exception = RemoteException(
+            seq,
+            exception,
+            traceback_index,
+            None,
+            worker_frames,
+            ActorId.from_string("unknown[0].unknown[0]"),
+        )
+        worklist = deque((invocation,))
+        while worklist:
+            invocation = worklist.popleft()
+            if invocation.fail(remote_exception):
+                worklist.extend(invocation.users)
+
+    def rank_completed(
+        self, rank, first_uncompleted_ident
+    ) -> Generator[MessageResult, None, None]:
+        # advance what our last completed action was, and
+        # trim the list of tracebacks if we no longer need them.
+        prev = self.first_uncompleted_ident[rank]
+        self.first_uncompleted_ident[rank] = first_uncompleted_ident
+        if prev == self.min_first_uncompleted_ident:
+            self.min_first_uncompleted_ident = min(self.first_uncompleted_ident)
+            for seq in range(prev, self.min_first_uncompleted_ident):
+                invocation = self.invocations.popleft()
+                assert seq == invocation.seq
+                result, error = invocation.complete()
+                yield MessageResult(
+                    seq=seq,
+                    result=result,
+                    error=error,
+                )
+
+    def future_completed(self, ident, value):
+        invocation = self.invocations[ident - self.min_first_uncompleted_ident]
+        invocation.fut_value = value
+
+
+class MockController:
+    def __init__(self, world_size: int, verbose: bool = True):
+        self.history = History(world_size)
+        self.world_size = world_size
+        self.responses = deque[MessageResult | LogMessage | DebuggerMessage]()
+        self.exited = False
+        self.verbose = verbose
+
+    @property
+    def gpu_per_host(self) -> int:
+        return self.world_size
+
+    def send(self, ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple) -> None:
+        attr = getattr(self, type(msg).__name__, None)
+        if self.verbose:
+            logger.info(
+                "MockController: %s %s %s", str(ranks), str(type(msg)), str(msg)
+            )
+
+        if attr is not None:
+            attr(ranks, msg)
+
+    def next_message(
+        self, timeout: Optional[float]
+    ) -> Optional[MessageResult | LogMessage]:
+        return (
+            cast(Optional[MessageResult | LogMessage], self.responses.popleft())
+            if len(self.responses) > 0
+            else None
+        )
+
+    def stop_mesh(self) -> None:
+        pass
+
+    def drain_and_stop(self) -> List[MessageResult | LogMessage | DebuggerMessage]:
+        if not self.exited:
+            raise RuntimeError("Got drain_and_stop but exited is not True")
+        r = list(self.responses)
+        self.responses.clear()
+        return r
+
+    def drop_refs(self, refs: Sequence[Ref]) -> None:
+        """
+        noop as this is used for the Rust controller to know when to gc invocations_for_ref for failed invocations
+        """
+        pass
+
+    def node(
+        self, seq: Seq, defs: Sequence["Tensor"], uses: Sequence["Tensor"]
+    ) -> None:
+        self.history.ident(seq, defs, uses)
+
+    def worker_world_state(self) -> WorldState:
+        # Even though not implemented, a return is needed so the return value complies with type checking
+        assert 1 == 2, "not implemented"
+        return WorldState()
+
+    # Below are the messages that should be executed on "workers".
+    def CommandGroup(self, ranks: Ranks, msg: messages.CommandGroup):
+        for command in msg.commands:
+            self.send(ranks, command)
+
+    def RequestStatus(self, ranks: Ranks, msg: messages.RequestStatus):
+        for rank in iter_ranks(ranks):
+            for r in self.history.rank_completed(rank, msg.ident + 1):
+                self.responses.append(r)
+
+    def SendValue(self, ranks: Ranks, msg: messages.SendValue):
+        dtensors, unflatten = flatten(
+            (msg.args, msg.kwargs), lambda x: isinstance(x, torch.Tensor)
+        )
+        fake_args, _fake_kwargs = unflatten(d._fake for d in dtensors)
+        if msg.function is not None:
+            fake_result = None
+        else:
+            fake_result = fake_args[0]
+
+        if msg.destination is None:
+            # If the destination is the controller, we need to send back an actual
+            # tensor, not a fake tensor, because the rest of the operations are
+            # likely to be data dependent (e.g., losses.item()).
+            # Note that this also means that if the controller is going to branch
+            # out the execution, the execution path is going to diverge from the
+            # actual workload.
+            with no_mesh.activate():
+                tensors, unflatten = flatten(
+                    fake_result, lambda x: isinstance(x, torch.Tensor)
+                )
+                fake_result = unflatten(
+                    torch.zeros(
+                        t.size(), dtype=t.dtype, device=t.device, requires_grad=False
+                    )
+                    for t in tensors
+                )
+            for _ in iter_ranks(ranks):
+                self.responses.append(
+                    self.history.future_completed(msg.ident, fake_result)
+                )
+
+    def Exit(self, ranks: Ranks, msg: messages.Exit):
+        self.exited = True
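The MockController above routes each worker message to a handler method that shares the message's class name (`attr = getattr(self, type(msg).__name__, None)` in `send`), which is why its handlers are capitalized like the message types (`CommandGroup`, `RequestStatus`, `SendValue`, `Exit`). A minimal, self-contained sketch of that dispatch pattern, using hypothetical `RequestStatus`/`Exit` stand-ins rather than monarch's real `messages` module:

    from collections import deque
    from typing import NamedTuple


    class RequestStatus(NamedTuple):  # hypothetical stand-in for messages.RequestStatus
        ident: int


    class Exit(NamedTuple):  # hypothetical stand-in for messages.Exit
        pass


    class TinyController:
        # Each message is dispatched to a method named after the message class.
        def __init__(self) -> None:
            self.responses = deque()
            self.exited = False

        def send(self, ranks, msg) -> None:
            handler = getattr(self, type(msg).__name__, None)
            if handler is not None:
                handler(ranks, msg)

        def RequestStatus(self, ranks, msg) -> None:
            self.responses.append(("status", msg.ident))

        def Exit(self, ranks, msg) -> None:
            self.exited = True


    controller = TinyController()
    controller.send([0], RequestStatus(ident=3))
    controller.send([0], Exit())
    assert controller.exited and controller.responses.popleft() == ("status", 3)

The payoff of this design is that adding support for a new message type only requires adding a method with the matching name; unknown messages are silently ignored, which is what a mock controller wants.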
monarch/simulator/profiling.py
@@ -0,0 +1,424 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import contextlib
+import copy
+import enum
+import functools
+import multiprocessing
+import os
+import socket
+import time
+import traceback
+
+from contextlib import closing
+from datetime import timedelta
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+)
+
+import torch
+import torch.distributed as dist
+from monarch.common import messages
+from monarch.common.function import resolvable_function
+from monarch.common.function_caching import (
+    hashable_tensor_flatten,
+    HashableTreeSpec,
+    key_filters,
+    TensorGroup,
+)
+from monarch.common.tensor_factory import TensorFactory
+from monarch.simulator.command_history import CommandHistory, DTensorRef
+from torch.utils import _pytree as pytree
+from torch.utils._mode_utils import no_dispatch
+
+
+def get_free_port() -> int:
+    configs = [(socket.AF_INET6, "::1"), (socket.AF_INET, "127.0.0.1")]
+    errors = []
+
+    for addr_family, address in configs:
+        with socket.socket(addr_family, socket.SOCK_STREAM) as s:
+            try:
+                s.bind((address, 0))
+                s.listen(0)
+                with closing(s):
+                    return s.getsockname()[1]
+            except Exception as e:
+                errors.append(
+                    f"Binding failed with address {address} while getting free port: {e}"
+                )
+
+    # If this is reached, we failed to bind to any of the configs
+    raise Exception(", ".join(errors))
+
+
+# The functions below are from cached_remote_function.py, but depending on
+# cached_remote_function.py can cause dependency issues.
+def _to_factory(x):
+    if isinstance(x, torch.Tensor):
+        return (TensorFactory.from_tensor(x), x.requires_grad)
+    return x
+
+
+def _filter_key(v: Any):
+    for filter in key_filters:
+        v = filter(v)
+    return v
+
+
+def _make_key(args, kwargs):
+    values, spec = pytree.tree_flatten((args, kwargs))
+    return tuple(_filter_key(v) for v in values), HashableTreeSpec.from_treespec(spec)
+
+
+class ProfilingWorker:
+    _float_types: Set[torch.dtype] = {
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+        torch.float64,
+    }
+
+    def __init__(self, world_size, rank) -> None:
+        self.world_size = world_size
+        self.rank = rank
+        self.counter = 0
+
+    @contextlib.contextmanager
+    def _worker_env(self) -> Generator[dist.TCPStore, None, None]:
+        try:
+            store = dist.TCPStore(
+                os.environ["STORE_HOSTNAME"],
+                int(os.environ["STORE_PORT"]),
+                timeout=timedelta(seconds=10),
+            )
+            torch.cuda.set_device(self.rank)
+            yield store
+        finally:
+            if dist.is_initialized():
+                dist.destroy_process_group()
+
+    # Adapted from: https://fburl.com/3xpyoq93
+    # NB: returns fake tensors
+    def _run_function(
+        self, func: Callable, args: Any, kwargs: Any
+    ) -> Tuple[Any | None, int]:
+        """
+        Runs and benchmarks a fallback kernel for a given function.
+
+        Args:
+            func (Callable): The function to benchmark.
+            args (Tuple): The arguments to pass to the function.
+            kwargs (Dict[str, Any]): The keyword arguments to pass to the function.
+
+        Returns:
+            Tuple[Any | None, int]: A tuple containing the result of the function
+            and the mean operation time in nanoseconds.
+        """
+        # these should all be supported, just to be safe
+        # avoid fallback for operators which modify metadata in place
+        # because the input fake tensors would be unmodified
+
+        if torch.Tag.inplace_view in getattr(func, "tags", ()):
+            raise NotImplementedError
+
+        if args is None:
+            args = ()
+
+        if kwargs is None:
+            kwargs = {}
+
+        warmup_iters, actual_iters = 2, 3
+        # We have to deepcopy before entering the `no_dispatch()` context so that
+        # the copy won't materialize the fake tensor to a tensor automatically.
+        args_copies = [
+            copy.deepcopy(args) for _ in range(warmup_iters + actual_iters + 1)
+        ]
+        kwargs_copies = [
+            copy.deepcopy(kwargs) for _ in range(warmup_iters + actual_iters + 1)
+        ]
+
+        with no_dispatch():
+            materialized_tensors = {}
+
+            def to_real_tensor(e):  # type: ignore[no-untyped-def]
+                if isinstance(e, DTensorRef):
+                    ref = e.ref
+
+                    # TODO: Should we investigate this issue, or is there
+                    # not much we can do?
+                    # Context: caching the materialized tensors won't work for
+                    # TE's backward. It will crash without throwing any exception.
+                    # out = materialized_tensors.get(ref, None)
+                    out = None
+                    if out is None:
+                        e = e._fake
+                        assert e is not None
+                        if e.dtype in self._float_types:
+                            out = torch.rand_like(e, device=e.fake_device)
+                        else:
+                            out = torch.ones_like(e, device=e.fake_device)
+                        if e.is_sparse:
+                            out._coalesced_(e.is_coalesced())
+                        materialized_tensors[ref] = out
+                    return out
+                return e
+
+            def materialize():
+                args = args_copies.pop()
+                kwargs = kwargs_copies.pop()
+                flat_args, args_spec = pytree.tree_flatten((args, kwargs))
+                flat_args = [to_real_tensor(a) for a in flat_args]
+                args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
+                return args, kwargs
+
+            args, kwargs = materialize()
+            r = func(*args, **kwargs)
+
+            warmup_iters, actual_iters = 2, 3
+            for _ in range(warmup_iters):
+                args, kwargs = materialize()
+                func(*args, **kwargs)
+
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record(torch.cuda.current_stream())
+            for _ in range(actual_iters):
+                args, kwargs = materialize()
+                func(*args, **kwargs)
+            end_event.record(torch.cuda.current_stream())
+            torch.cuda.synchronize()
+            cuda_time = start_event.elapsed_time(end_event)
+            mean_op_time = int(cuda_time / actual_iters * 1000)
+
+        return r, mean_op_time
+
+    def CallFunction(self, msg) -> Any:
+        func = msg.function.resolve()
+        ret = self._run_function(func, msg.args, msg.kwargs)
+
+        count = 2**31
+
+        def tensor_to_dtensor_ref(t):
+            nonlocal count
+            count += 1
+            t.ref = count
+            return DTensorRef(t)
+
+        return pytree.tree_map_only(torch.Tensor, tensor_to_dtensor_ref, ret)
+
+    def run(self, conn) -> None:
+        with self._worker_env() as store:
+            try:
+                while True:
+                    msg = conn.recv()
+                    if msg == "exit":
+                        break
+                    elif msg == "init_pg":
+                        if not dist.is_initialized():
+                            dist.init_process_group(
+                                backend="nccl",
+                                world_size=self.world_size,
+                                rank=self.rank,
+                                store=store,
+                            )
+                    else:
+                        ret = self.CallFunction(msg)
+                        conn.send(("result", ret))
+                        self.counter += 1
+            except Exception:
+                conn.send(("exception", traceback.format_exc()))
+            finally:
+                conn.close()
+
+
+class RuntimeProfiler:
+    def __init__(self, world_size: int = 8, port: int = -1) -> None:
+        # TODO: Add a cached mode to save the results into a pickle file so that
+        # we can reuse the result without running anything.
+        self.world_size = world_size
+        self.port = port if port > 0 else get_free_port()
+        self._initialized = False
+        self.parent_conns: List[multiprocessing.connection.Connection] = []
+        self.cached: Dict[Tuple[Any, ...], Any] = {}
+
+    def _lazy_init(self):
+        if self._initialized:
+            return
+
+        self.store = dist.TCPStore("localhost", self.port, is_master=True)
+        self.processes = []
+        self.world_size = self.world_size
+        ctx = multiprocessing.get_context("spawn")
+        os.environ["STORE_HOSTNAME"] = "localhost"
+        os.environ["STORE_PORT"] = str(self.port)
+        for i in range(self.world_size):
+            parent_conn, child_conn = multiprocessing.Pipe()
+            worker = ProfilingWorker(self.world_size, i)
+            self.processes.append(
+                ctx.Process(target=worker.run, args=(child_conn,), daemon=True),
+            )
+            self.parent_conns.append(parent_conn)
+            self.processes[-1].start()
+
+        self._initialized = True
+
+    def __exit__(self) -> None:
+        if self._initialized:
+            for i in range(self.world_size):
+                conn = self.parent_conns[i]
+                conn.send("exit")
+                time.sleep(0.1)
+
+    def profile_cmd(self, cmd, ranks) -> List[Any | None]:
+        self._lazy_init()
+
+        ret = []
+        assert type(cmd).__name__ == "CallFunction"
+        cmd = CommandHistory.convert_msg(cmd)
+        cmd = cmd._replace(function=resolvable_function(cmd.function))
+
+        def dtensor_ref_filter(v: Any):
+            if isinstance(v, DTensorRef):
+                return v.factory
+            return v
+
+        key_filters.append(dtensor_ref_filter)
+        tensors, shape_key = hashable_tensor_flatten((cmd, ranks), {})
+        inputs_group = TensorGroup([t._fake for t in tensors])  # pyre-ignore[16]
+        requires_grads = tuple(t.requires_grad for t in tensors)
+        key = (shape_key, inputs_group.pattern, requires_grads)
+        key_filters.pop()
+        # key = _make_key((cmd, ranks), None)
+        if key in self.cached:
+            return self.cached[key]
+
+        for i in ranks:
+            conn = self.parent_conns[i]
+            conn.send(cmd)
+
+        # This cannot be merged into the previous for loop. A deadlock can happen.
+        for i in ranks:
+            ret.append(self.parent_conns[i].recv())
+
+        clean_ret = []
+        for r in ret:
+            if r[0] == "exception":
+                raise RuntimeError(r[1])
+            clean_ret.append(r[1])
+
+        self.cached[key] = clean_ret
+        return clean_ret
+
+
+def _return_if_exist(attr):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(self, *args, **kwargs):
+            user_fn = getattr(self, attr)
+            if isinstance(user_fn, int):
+                return user_fn
+            elif callable(user_fn):
+                return user_fn(*args, **kwargs)
+            return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+class TimingType(str, enum.Enum):
+    SEND_TENSOR = "_send_tensor_time"
+    REDUCE = "_reduce_time"
+    CALL_FUNCTION = "_call_function_time"
+    KERNEL_LAUNCH = "_kernel_launch_time"
+    WAIT_EVENT = "_wait_event_time"
+
+
+TimingFunction = Callable[[Optional[NamedTuple]], int]
+
+
+class RuntimeEstimator:
+    def __init__(self) -> None:
+        self._call_function_time: TimingFunction | int | None = None
+        self._reduce_time: TimingFunction | int | None = None
+        self._send_tensor_time: TimingFunction | int | None = None
+        self._wait_event_time: int | None = None
+        self._kernel_launch_time: int | None = None
+
+    @_return_if_exist("_send_tensor_time")
+    def _get_send_tensor_time(self, msg: messages.SendTensor) -> int:
+        if msg.from_ranks == msg.to_ranks:
+            return 1_000
+        return 100_000
+
+    @_return_if_exist("_reduce_time")
+    def _get_reduce_time(self, msg: messages.Reduce) -> int:
+        return 100_000
+
+    @_return_if_exist("_call_function_time")
+    def _get_call_function_time(self, msg: messages.CallFunction) -> int:
+        return 10_000
+
+    @_return_if_exist("_kernel_launch_time")
+    def _get_kernel_launch_time(self) -> int:
+        return 500
+
+    @_return_if_exist("_wait_event_time")
+    def _get_wait_event_time(self) -> int:
+        return 500
+
+    def set_custom_timing(
+        self, func_or_time: Dict[TimingType, TimingFunction | int]
+    ) -> None:
+        """
+        Set custom timing values for specific message types or events.
+
+        This method allows the user to define custom timing values for various
+        operations in the simulator. The timing can be specified either as a fixed
+        integer value or as a function that computes the timing dynamically.
+        All the integer values are in nanoseconds.
+
+        Args:
+            func_or_time (Dict[TimingType, TimingFunction | int]): A dictionary
+                mapping TimingType to either a function or an integer. If a function
+                is provided, it should accept an optional NamedTuple as input and
+                return an integer representing the timing in nanoseconds.
+
+        Raises:
+            AssertionError: If the values in the dictionary are neither integers
+                nor callable functions.
+        """
+        for k, v in func_or_time.items():
+            assert isinstance(v, int) or callable(
+                v
+            ), "Supported customized timings are an integer or a function."
+            setattr(self, k.value, v)
+
+    def get_runtime(self, msg) -> int:
+        match msg:
+            case messages.CallFunction():
+                return self._get_call_function_time(msg)
+            case messages.Reduce():
+                return self._get_reduce_time(msg)
+            case messages.SendTensor():
+                return self._get_send_tensor_time(msg)
+            case "kernel_launch":
+                return self._get_kernel_launch_time()
+            case "wait_event":
+                return self._get_wait_event_time()
+            case _:
+                raise ValueError(f"Got an unexpected message for profiling: {msg}.")
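The RuntimeEstimator at the end of this hunk lets simulator users override the default per-message costs, either with a fixed value or with a callable that receives the message being timed; all values are in nanoseconds, and `get_runtime` consults the overrides (via the `_return_if_exist` decorator) before falling back to the built-in defaults. A short usage sketch, assuming this hunk is `monarch/simulator/profiling.py` from the file list above and that the module is importable under that path:

    from monarch.simulator.profiling import RuntimeEstimator, TimingType

    estimator = RuntimeEstimator()
    estimator.set_custom_timing(
        {
            # Fixed cost: every kernel launch is charged 1,000 ns (1 microsecond).
            TimingType.KERNEL_LAUNCH: 1_000,
            # Dynamic cost: the callable receives the message being timed
            # (here a messages.Reduce) and returns a cost in nanoseconds.
            TimingType.REDUCE: lambda msg: 50_000,
        }
    )

    # estimator.get_runtime(...) now returns the overridden costs for these
    # message types and the built-in defaults for everything else.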