torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/worker/worker.py
ADDED
@@ -0,0 +1,1191 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import asyncio
|
9
|
+
import bdb
|
10
|
+
import itertools
|
11
|
+
import logging
|
12
|
+
import os
|
13
|
+
import pdb # noqa
|
14
|
+
import queue
|
15
|
+
import threading
|
16
|
+
from collections import deque
|
17
|
+
from contextlib import contextmanager
|
18
|
+
from traceback import extract_tb
|
19
|
+
from typing import (
|
20
|
+
Any,
|
21
|
+
Callable,
|
22
|
+
Dict,
|
23
|
+
Generator,
|
24
|
+
List,
|
25
|
+
NamedTuple,
|
26
|
+
Optional,
|
27
|
+
Protocol,
|
28
|
+
Sequence,
|
29
|
+
Tuple,
|
30
|
+
Union,
|
31
|
+
)
|
32
|
+
|
33
|
+
from weakref import WeakKeyDictionary
|
34
|
+
|
35
|
+
import torch
|
36
|
+
import torch.distributed
|
37
|
+
import torch.fx
|
38
|
+
import zmq
|
39
|
+
import zmq.asyncio
|
40
|
+
|
41
|
+
from monarch.common import messages
|
42
|
+
from monarch.common.function import ResolvableFunction
|
43
|
+
from monarch.common.messages import DependentOnError, Dims
|
44
|
+
from monarch.common.process_group import SingleControllerProcessGroupWrapper
|
45
|
+
from monarch.common.reference import Ref, Referenceable
|
46
|
+
from monarch.common.shape import NDSlice
|
47
|
+
from monarch.common.tensor_factory import TensorFactory
|
48
|
+
from monarch.common.tree import flatten, flattener
|
49
|
+
from monarch_supervisor import get_message_queue, Letter
|
50
|
+
from monarch_supervisor.logging import initialize_logging
|
51
|
+
|
52
|
+
from .compiled_block import CompiledBlock
|
53
|
+
from .debugger import _set_trace
|
54
|
+
from .monitor import Monitor
|
55
|
+
|
56
|
+
logger = logging.getLogger(__name__)

# Read once, at import time, from the environment variable of the same name.
# Any non-zero integer enables the flag, "0" disables it, and an unset variable
# leaves it enabled. A value that is not an integer raises ValueError on import.
CONTROLLER_COMPILED_REPEAT = (
    int(os.environ.get("CONTROLLER_COMPILED_REPEAT", "1")) != 0
)
|
61
|
+
|
62
|
+
|
63
|
+
def set_default_dtype(dtype: torch.dtype):
    """Set torch's default floating point dtype to ``dtype``.

    Thin module-level wrapper around ``torch.set_default_dtype``.
    """
    torch.set_default_dtype(dtype)
|
65
|
+
|
66
|
+
|
67
|
+
class Dim(NamedTuple):
    """One named dimension of a device mesh, as seen from the local rank."""

    # Name of the dimension.
    name: str
    # Coordinate of the local rank along this dimension.
    rank: int
    # Number of positions along this dimension.
    size: int
    # Ranks obtained by varying only this dimension's coordinate;
    # members[rank] is the local rank itself.
    members: List[int]
|
72
|
+
|
73
|
+
|
74
|
+
class RemoteProcessGroupShell:
    """Handle that names a process group by (device mesh, dims, ref).

    The shell does not hold a process group itself; the group is looked up
    on the device mesh for whichever stream asks for it.
    """

    def __init__(self, device_mesh: "DeviceMesh", dims: Dims, ref: Ref):
        self.device_mesh, self.dims, self.ref = device_mesh, dims, ref

    def get_process_group_for_stream(self, stream: "Stream"):
        """Return the process group registered for ``stream``.

        This is a pure lookup: it sanity checks that the stream the group was
        created on is the stream it is being used on, since the mesh lookup
        fails when no group exists for this stream/dims/ref combination.
        """
        mesh = self.device_mesh
        return mesh.get_process_group(stream, self.dims, pg=self.ref)
|
83
|
+
|
84
|
+
|
85
|
+
def _new_process_group(
    controller_global_unique_name: str, ranks: Optional[List[int]], split: bool
):
    """Create a torch.distributed process group with a caller-chosen name.

    Args:
        controller_global_unique_name: name that torch's internal
            ``_process_group_name`` is forced to return while the group is
            being created.
        ranks: ranks that belong to the group. Must not be None when
            ``split`` is true.
        split: when true the group is created with
            ``torch.distributed.split_group(None, [ranks])``; otherwise with
            ``torch.distributed.new_group(ranks, use_local_synchronization=True)``.

    Returns:
        The process group returned by torch.

    Raises:
        AssertionError: if torch.distributed is not initialized, if ``split``
            is true and ``ranks`` is None, or if torch never asked for a
            process group name while the patch was active.
    """
    assert torch.distributed.is_initialized()
    from unittest.mock import patch

    # Pytorch versions from about the past month have an implementation of process group name with local names that
    # can cause TCPStore name collisions (https://www.internalfb.com/intern/diff/D67312715/).
    # This will get fixed soon in pytorch, but will take some time to rollout.
    # In the meantime, our workers have enough knowledge to simply generate unique names based on the data they already have.
    # While not strictly needed once pytorch fixes the bug, this illustrates how our own initialization of nccl can just directly
    # provide a unique key for each process group it is creating.
    def _forced_name(*args, **kwargs):
        return controller_global_unique_name

    with patch(
        "torch.distributed.distributed_c10d._process_group_name",
        side_effect=_forced_name,
    ) as name_patch:
        if not split:
            group = torch.distributed.new_group(ranks, use_local_synchronization=True)
        else:
            assert ranks is not None
            group = torch.distributed.split_group(None, [ranks])

        # If torch did not consult the patched function, the name was not applied.
        assert name_patch.called
    return group
|
109
|
+
|
110
|
+
|
111
|
+
# Interpolated into every process group name built by
# DeviceMesh.create_process_group (the "restart_{restart_count}_..." prefix).
# NOTE(review): nothing in the visible part of this file reassigns it;
# presumably it is bumped on restart so names stay unique - confirm.
restart_count = 0
|
112
|
+
|
113
|
+
|
114
|
+
class DeviceMesh:
    """Worker-side description of a device mesh from one rank's point of view.

    ``dims`` maps each dimension name to a ``Dim`` holding this rank's
    coordinate, the dimension size, and the ranks reached by varying only
    that coordinate. ``all_ranks`` lists every rank in the mesh. Process
    groups are cached per stream in ``process_group_for_stream`` (weakly keyed
    by stream) under the key ``(pg, sorted dims)``.
    """

    def __init__(self, id: int, names: Dims, ranks: NDSlice, rank: int):
        self.id = id
        self.dims: Dict[str, Dim] = {}
        local_coordinates = ranks.coordinates(rank)
        for coord, dim_name, dim_size, dim_stride in zip(
            local_coordinates, names, ranks.sizes, ranks.strides
        ):
            # Rank at coordinate 0 of this dimension, all other coordinates fixed.
            first = rank - dim_stride * coord
            last = first + dim_stride * dim_size
            dim_members = list(range(first, last, dim_stride))
            assert dim_members[coord] == rank
            self.dims[dim_name] = Dim(dim_name, coord, dim_size, dim_members)
        self.all_ranks: List[int] = list(ranks)
        self.process_group_for_stream: WeakKeyDictionary["Stream", Any] = (
            WeakKeyDictionary()
        )

    def get_ranks_for_dim_slice(self, names: Dims):
        """Return the ranks reached from this rank by varying only ``names``.

        No names gives ``[]``; a single name gives that dimension's member
        list; naming every dimension gives ``all_ranks``. In the last two
        cases the stored list itself is returned, not a copy.
        """
        count = len(names)
        if count == 0:
            return []
        if count == 1:
            return self.dims[names[0]].members
        if count == len(self.dims):
            return self.all_ranks

        selected = [self.dims[n] for n in names]
        # Recover each dimension's stride from its first two members
        # (a dimension of size 1 contributes no offset).
        strides = [
            d.members[1] - d.members[0] if len(d.members) > 1 else 0
            for d in selected
        ]
        # Start from the local rank and move to coordinate 0 on every selected dim.
        origin = selected[0].members[selected[0].rank] - sum(
            s * d.rank for d, s in zip(selected, strides)
        )
        return [
            origin + sum(i * s for i, s in zip(idxs, strides))
            for idxs in itertools.product(*(range(d.size) for d in selected))
        ]

    def create_process_group(
        self, stream: "Stream", dims: Dims, pg: Optional[Ref] = None
    ):
        """Create, cache and return the process group for (stream, dims, pg).

        The group name encodes the restart count, mesh id, stream id and this
        rank's coordinates, with "X" standing in for each dimension that is
        part of the group.

        Raises:
            AssertionError: if a group for this key already exists on ``stream``.
        """
        groups = self.process_group_for_stream.setdefault(stream, {})
        dims = tuple(sorted(dims))
        key = (pg, dims)
        if key in groups:
            raise AssertionError(
                f"Tried to create a process group for {stream=}, {dims=} but it already exists!"
            )
        ranks = self.get_ranks_for_dim_slice(dims)
        indices = [
            "X" if d.name in dims else str(d.rank) for d in self.dims.values()
        ]
        name = f"restart_{restart_count}_mesh_{self.id}_stream_{stream.id}_{'_'.join(indices)}"
        if pg is not None:
            name += f"_group_{pg}"
        groups[key] = SingleControllerProcessGroupWrapper(
            _new_process_group(name, ranks, split=True)
        )
        return self.get_process_group(stream, dims, pg=pg)

    def get_process_group(self, stream: "Stream", dims: Dims, pg: Optional[Ref] = None):
        """Look up a previously created group; raises KeyError if there is none."""
        key = (pg, tuple(sorted(dims)))
        return self.process_group_for_stream[stream][key]

    def create_process_group_shell(self, dims: Dims, ref: Ref):
        """Return a RemoteProcessGroupShell bound to this mesh, ``dims`` and ``ref``."""
        return RemoteProcessGroupShell(self, dims, ref)
|
185
|
+
|
186
|
+
|
187
|
+
def _rank(mesh: "DeviceMesh", dim: str):
    """Return this process's coordinate along ``dim`` as a 0-d int64 tensor."""
    coordinate = mesh.dims[dim].rank
    return torch.full((), coordinate, dtype=torch.long)
|
189
|
+
|
190
|
+
|
191
|
+
def _process_idx(mesh: "DeviceMesh"):
    """
    Return linear idx of the current process in the mesh.

    The result is a 0-d int64 tensor.
    """
    # any dimension can be used to query our rank: members[rank] of every
    # dimension is the local process, so the first one is taken.
    first_dim = next(iter(mesh.dims.values()))
    return torch.full((), first_dim.members[first_dim.rank], dtype=torch.long)
|
198
|
+
|
199
|
+
|
200
|
+
def _reduce(
    local_tensor: torch.Tensor,
    source_mesh: DeviceMesh,
    group,
    group_size: int,
    reduction: str,
    scatter: bool,
    inplace: bool,
    out: Optional[torch.Tensor],
):
    """Run one collective over ``group`` and return the output tensor.

    The collective is chosen from ``reduction`` and ``scatter``:

    * ``"stack"`` with scatter: ``all_to_all_single``
    * ``"stack"`` without scatter: ``all_gather_into_tensor`` into a tensor of
      shape ``[group_size, *local_tensor.shape]``
    * other reduction with scatter: ``reduce_scatter_tensor`` into a tensor of
      shape ``local_tensor.shape[1:]``
    * other reduction without scatter: ``all_reduce``

    Any other ``reduction`` is upper-cased and looked up on
    ``torch.distributed.ReduceOp``. ``inplace`` is only permitted for the
    all-to-all and all-reduce paths (the other two assert it is false).
    When ``out`` is given and the call is not in place, ``out`` is used as
    the output tensor. ``source_mesh`` is accepted but not used here.
    """
    dist = torch.distributed

    def _empty_like_input(shape):
        # Fresh output buffer matching the input's dtype, device and layout.
        return torch.empty(
            shape,
            dtype=local_tensor.dtype,
            device=local_tensor.device,
            layout=local_tensor.layout,
        )

    def _same_shape_output():
        # Output for collectives whose result has the input's shape.
        if inplace:
            return local_tensor
        return local_tensor.clone() if out is None else out

    if reduction == "stack":
        if scatter:
            output = _same_shape_output()
            dist.all_to_all_single(output, local_tensor, group=group)
            return output

        assert not inplace
        if out is None:
            output = _empty_like_input([group_size, *local_tensor.shape])
        else:
            output = out
        dist.all_gather_into_tensor(output, local_tensor, group=group)
        return output

    op = getattr(dist.ReduceOp, reduction.upper())

    if scatter:
        assert not inplace
        if out is None:
            output = _empty_like_input(local_tensor.shape[1:])
        else:
            output = out
        dist.reduce_scatter_tensor(output, local_tensor, op=op, group=group)
        return output

    output = _same_shape_output()
    dist.all_reduce(output, op=op, group=group)
    return output
|
256
|
+
|
257
|
+
|
258
|
+
class _TLS(threading.local):
    """Thread-local worker state; each thread sees its own attribute values."""

    def __init__(self):
        # Stream bound to the current thread (assigned in Stream.main).
        self.stream: Optional["Stream"] = None
        # CompiledBlock being traced on the current thread, if any.
        self.tracing: Optional["CompiledBlock"] = None


# Module-wide thread-local instance used by the classes in this file.
_tls = _TLS()
|
265
|
+
|
266
|
+
|
267
|
+
def schedule_on_stream_thread(executes_on_error: bool):
    """Decorator factory that defers a method call through ``self.schedule``.

    Calling the decorated method does not run its body directly. It builds a
    zero-argument thunk that logs the call at debug level and then invokes the
    original method with the captured arguments, and hands that thunk to
    ``self.schedule(thunk, executes_on_error)``.

    Args:
        executes_on_error: forwarded unchanged as the second argument of
            ``self.schedule``.

    Returns:
        A decorator. The decorated method returns whatever ``self.schedule``
        returns; the original method's return value is discarded. The
        wrapper keeps the original method's name and docstring.
    """
    from functools import wraps

    def wrapper(fn):
        @wraps(fn)
        def scheduled(self, *args, **kwargs):
            def run():
                logger.debug(
                    "executing: %s(args=%s, kwargs=%s)", fn.__name__, args, kwargs
                )
                fn(self, *args, **kwargs)

            return self.schedule(run, executes_on_error)

        return scheduled

    return wrapper
|
280
|
+
|
281
|
+
|
282
|
+
class Stream:
    """A worker-side execution stream.

    Work is submitted as zero-argument callables through ``schedule``. Outside
    of tracing they are put on the FIFO queue ``q`` and run in order by a
    dedicated thread (``main``). A non-default stream also owns a
    ``torch.cuda.Stream`` that is made current for everything that thread runs.
    """

    def __init__(self, worker: "Worker", id: int, default: bool):
        self.id = id
        self.worker = worker
        # Thread draining ``q``; created lazily by ``schedule`` or by ``start``.
        self.thread: Optional[threading.Thread] = None
        self.q: queue.Queue[Callable[[], None]] = queue.Queue()
        # used to send messages to pdb from the controller, see debugger.py
        self.debugger_queue: queue.Queue[Any] = queue.Queue()
        # NOTE(review): not read anywhere in this class.
        self.should_exit = threading.Event()
        # Ident of the recording being replayed on this stream, if any.
        self.current_recording: Optional[int] = None
        if default:
            # The default stream uses whatever CUDA stream is current.
            self._cuda_stream = None
        else:
            self._cuda_stream = torch.cuda.Stream()

    @schedule_on_stream_thread(executes_on_error=False)
    def run_recording(
        self, ident: int, impl: Callable, results: List["Cell"], inputs: List["Cell"]
    ):
        """Run ``impl(results, inputs)`` with ``current_recording`` set to ``ident``."""
        self.current_recording = ident
        try:
            impl(results, inputs)
        finally:
            self.current_recording = None

    @property
    def cuda_stream(self):
        """The CUDA stream for this stream (torch's current one for the default stream)."""
        if self._cuda_stream is None:
            return torch.cuda.current_stream()
        else:
            return self._cuda_stream

    @contextmanager
    def enable(self):
        """Context manager that makes this stream's CUDA stream current.

        Does nothing for the default stream.
        """
        if self._cuda_stream is None:
            yield
            return
        with torch.cuda.stream(self._cuda_stream):
            yield

    def event(self):
        """Record a new CUDA event on this stream and return it."""
        e = torch.cuda.Event()
        self.cuda_stream.record_event(e)
        return e

    def wait_event(self, event):
        """Make this stream's CUDA stream wait for ``event``."""
        self.cuda_stream.wait_event(event)

    def wait_stream(self, stream):
        """Make this stream's CUDA stream wait for ``stream``'s CUDA stream."""
        self.cuda_stream.wait_stream(stream.cuda_stream)

    def start(self) -> threading.Thread:
        """Start the thread that runs ``main`` and return it.

        The thread is recorded on ``self.thread`` so that ``join`` stops and
        joins it, and so that a later ``schedule`` call reuses it instead of
        starting a second consumer of the same queue.
        """
        thread = threading.Thread(target=self.main)
        self.thread = thread
        thread.start()
        return thread

    def main(self):
        """Thread body: run queued callables until one raises.

        ``StopIteration`` (see ``exit``) ends the loop quietly. Any other
        exception is logged, wrapped in ``messages.InternalException`` and
        handed to the worker through ``worker.schedule``; the loop ends in
        that case too.
        """
        _tls.stream = self
        with self.enable():
            try:
                while True:
                    self.q.get()()
            except StopIteration:
                pass
            except Exception as e:
                logger.exception("Stream thread exiting with exception.")
                msg = messages.InternalException(e, extract_tb(e.__traceback__))
                self.worker.schedule(lambda: self.worker.internal_error(msg))

    def exit(self):
        """Schedule a callable that stops ``main`` and send "detach" to the debugger queue."""

        def stop():
            raise StopIteration

        self.schedule(stop)
        self.debugger_queue.put("detach")

    def join(self):
        """Stop and join the stream thread; a no-op if no thread was started."""
        if self.thread is None:
            return
        self.exit()
        self.thread.join()

    def schedule(self, fn: Callable[[], None], executes_on_error: bool = False):
        """Submit ``fn`` to this stream.

        While the calling thread is tracing, ``fn`` runs immediately inside
        ``tracing.record_to(self)``; if ``executes_on_error`` is set it is
        also appended to ``tracing.fallback[self]``. Otherwise the stream
        thread is started on first use (as a daemon thread) and ``fn`` is
        queued for it.
        """
        if _tls.tracing:
            tracing = _tls.tracing
            if executes_on_error:
                tracing.fallback[self].append(fn)
            with tracing.record_to(self):
                fn()
            return

        if self.thread is None:
            self.thread = threading.Thread(target=self.main, daemon=True)
            self.thread.start()
        self.q.put(fn)

    def call_or_trace(self, fn, *args, **kwargs):
        """Call ``fn`` directly, or record the call on the active trace.

        When tracing, returns ``tracing.call_function(fn, args, kwargs)``
        instead of the result of calling ``fn``.
        """
        if _tls.tracing:
            return _tls.tracing.call_function(fn, args, kwargs)
        return fn(*args, **kwargs)

    def report_error(self, ident: int, index: int, e: Exception, extra: Any = None):
        """Log ``e``, send a ``RemoteFunctionFailed`` message, and return a marker.

        Returns:
            ``DependentOnError(ident)``. It is returned, not raised, so the
            caller can store it in the cells that depended on the failed call.
        """
        logger.exception(f"Error generating {ident}, {extra=}", exc_info=e)
        self.worker.q.send(
            messages.RemoteFunctionFailed(ident, index, e, extract_tb(e.__traceback__))
        )
        return DependentOnError(ident)

    @contextmanager
    def try_define(
        self, ident: Optional[int], results: Sequence["Cell"], extra: Any = None
    ):
        """Context manager that turns failures in its body into cell values.

        * ``DependentOnError`` is stored in every cell of ``results``.
        * ``bdb.BdbQuit`` is re-raised.
        * Any other exception is re-raised when ``ident`` is None. Otherwise
          it is reported with ``report_error`` and the returned
          ``DependentOnError`` is stored in every cell of ``results``.

        When tracing, the trace's current context gets ``ident``, ``results``
        are passed to ``tracing.mutates``, and the current context is reset to
        None on exit.
        """
        tracing = _tls.tracing
        if tracing:
            ctx = tracing.current_context
            ctx.ident = ident
            tracing.mutates(results)

        try:
            yield
        except DependentOnError as e:
            for r in results:
                r.set(e)
            # note: there is no need to send RemoteFunctionFailed
            # because the controller would have already gotten and propagated the
            # original creator of the DependentOnError.
        except bdb.BdbQuit:
            raise
        except Exception as e:
            # when try_define does not have an ident,
            # the only error we expect is DependentOnError;
            # other errors should get treated as internal errors.
            if ident is None:
                raise
            if self.current_recording is not None:
                # Inside a recording the failure is attributed to the
                # recording, with this call's ident passed as the index.
                exc = self.report_error(self.current_recording, ident, e, extra)
            else:
                exc = self.report_error(ident, 0, e, extra)
            for r in results:
                r.set(exc)
        finally:
            if _tls.tracing:
                # pyre-fixme[8]: Attribute has type `ErrorContext`; used as `None`.
                _tls.tracing.current_context = None

    @schedule_on_stream_thread(executes_on_error=False)
    def call_function(
        self,
        ident: int,
        defines: Tuple["Cell", ...],
        flatten_result: Any,
        mutates: Tuple["Cell", ...],
        rfunction: ResolvableFunction,
        inputs: List["Cell"],
        unflatten_inputs: Any,
        device_mesh: Optional["DeviceMesh"] = None,
    ):
        """Resolve and call ``rfunction`` and store its flattened results in ``defines``.

        Inputs are read from ``inputs``; a ``RemoteProcessGroupShell`` input
        is replaced with the process group for this stream.
        ``unflatten_inputs`` turns the resolved list into ``(args, kwargs)``.
        When tracing, the call is recorded on the trace and the cells receive
        fx nodes instead of values. Failures are handled by ``try_define``
        over ``defines`` and ``mutates``. ``device_mesh`` is accepted but not
        used here.
        """
        with self.try_define(
            ident, [*defines, *mutates], extra=(rfunction, defines, mutates, inputs)
        ):
            function = rfunction.resolve()
            resolved_inputs = []
            for i in inputs:
                input_ = i.get()
                if isinstance(input_, RemoteProcessGroupShell):
                    # get the process group for the stream but don't allow it to be created from
                    # this context since this isn't being run on the event loop.
                    resolved_inputs.append(input_.get_process_group_for_stream(self))
                else:
                    resolved_inputs.append(input_)

            args, kwargs = unflatten_inputs(resolved_inputs)
            if _tls.tracing:
                block = _tls.tracing
                fn_node: torch.fx.Node = block.call_function(function, args, kwargs)
                tensors = [
                    t.node if isinstance(t, torch.fx.Proxy) else t
                    for t in flatten_result(block.proxy(fn_node))
                ]
            else:
                result = function(*args, **kwargs)
                tensors = flatten_result(result)
            assert len(defines) == len(tensors)
            for d, t in zip(defines, tensors):
                d.set(t)

    @schedule_on_stream_thread(executes_on_error=False)
    def send_value(
        self,
        ident: int,
        rfunction: Optional[ResolvableFunction],
        mutates: Tuple["Cell", ...],
        inputs: List["Cell"],
        unflatten: Any,
        pipe: Optional["WorkerPipe"],
    ):
        """Compute a value and send it to the worker queue or to ``pipe``.

        The value is ``rfunction.resolve()(*args, **kwargs)``, or the single
        positional argument itself when ``rfunction`` is None. With no pipe it
        is sent as ``messages.FetchResult(ident, value)`` on the worker's
        queue; otherwise it goes through ``pipe.send``.
        """
        with self.try_define(ident, mutates):
            args, kwargs = unflatten(c.get() for c in inputs)
            function = (lambda x: x) if rfunction is None else rfunction.resolve()
            result = function(*args, **kwargs)
            if pipe is None:
                self.worker.q.send(messages.FetchResult(ident, result))
            else:
                self.call_or_trace(pipe.send, result)

    @schedule_on_stream_thread(executes_on_error=False)
    def collective_call(
        self,
        function: Callable,
        factory: TensorFactory,
        input_: "Cell",
        result: "Cell",
        out: Optional["Cell"] = None,
    ):
        """Run ``function(local_tensor, out_tensor)`` and store the output in ``result``.

        If reading ``input_`` or ``out`` raises ``DependentOnError``, the call
        still happens with ``factory.zeros()`` as the input and no out tensor.
        """
        try:
            local_tensor = input_.get()
            out_tensor = None if out is None else out.get()
        except DependentOnError:
            # even if we were broken before, we have to participate in the collective
            # because we cannot signal to other ranks that we were broken
            # the controller will see the error message we sent before and know
            # the downstream values are broken.
            local_tensor = factory.zeros()
            out_tensor = None
        # XXX - we should be careful about starting the collective with a tensor that doesn't match the expected
        # factory size. It can error. however, before we can do something about it we need to assign a failure
        # identity to this reduce object.
        output = self.call_or_trace(function, local_tensor, out_tensor)
        result.set(output)

    @schedule_on_stream_thread(executes_on_error=True)
    def borrow_create(self, input_: "Cell", borrow: "Borrow"):
        """Run ``borrow.create`` with the value of ``input_`` and this stream."""
        self.call_or_trace(borrow.create, input_.get(), self)

    @schedule_on_stream_thread(executes_on_error=True)
    def borrow_first_use(self, result: "Cell", borrow: "Borrow"):
        """Run ``borrow.first_use`` and store what it returns in ``result``."""
        with self.try_define(None, [result]):
            result.set(self.call_or_trace(borrow.first_use))

    @schedule_on_stream_thread(executes_on_error=True)
    def borrow_last_use(self, borrow: "Borrow"):
        """Run ``borrow.last_use`` on this stream."""
        self.call_or_trace(borrow.last_use)

    @schedule_on_stream_thread(executes_on_error=True)
    def borrow_drop(self, borrow: "Borrow"):
        """Run ``borrow.drop`` on this stream."""
        self.call_or_trace(borrow.drop)
|
528
|
+
|
529
|
+
|
530
|
+
class Borrow:
    """Hands a value from ``from_stream`` to ``to_stream`` and back.

    Two queues carry ``(cuda event, value)`` pairs between the streams: one
    for the hand-off to the borrowing stream (``create`` -> ``first_use``) and
    one for the return (``last_use`` -> ``drop``). Each receiving side makes
    its stream wait on the event recorded by the sending side.
    """

    def __init__(self, from_stream: Stream, to_stream: Stream):
        self.from_stream = from_stream
        self.to_stream = to_stream
        self.first_use_queue = queue.Queue()
        self.last_use_queue = queue.Queue()
        # used to ensure the tensor memory stays alive in the
        # allocator until it is returned to its original stream
        self.tensor_storage = Cell(None)

    def create(self, input_: Any, stream: Stream):
        """Record an event on ``stream`` and queue it together with ``input_``."""
        ready = stream.event()
        self.first_use_queue.put((ready, input_))

    def first_use(self):
        """Take the queued value, wait on its event, and return the value.

        Blocks until ``create`` has queued something.
        """
        ready, borrowed = self.first_use_queue.get()
        self.tensor_storage.set(borrowed)
        self.to_stream.wait_event(ready)
        # raise any potential error _after_ already processing
        # the events. We always do the synchronizations even
        # if the value being borrowed is an error.
        return self.tensor_storage.get()

    def last_use(self):
        """Release the stored value and queue it with an event from ``to_stream``."""
        # Read .value directly so a stored exception is passed along, not raised.
        borrowed = self.tensor_storage.value
        self.tensor_storage.clear()
        finished = self.to_stream.event()
        self.last_use_queue.put((finished, borrowed))

    def drop(self):
        """Make ``from_stream`` wait for the borrower's last use, then drop the value.

        Blocks until ``last_use`` has queued something.
        """
        finished, borrowed = self.last_use_queue.get()
        self.from_stream.wait_event(finished)
        del borrowed
|
561
|
+
|
562
|
+
|
563
|
+
class WorkerMessageQueue(Protocol):
    """Structural interface of the message queue object handed to the worker."""

    def _socket(self, kind) -> zmq.Socket:
        """Return a zmq socket of the given kind."""
        ...

    def send(self, message: Any) -> None:
        """Send ``message`` through the queue."""
        ...

    async def recv_async(self) -> Letter:
        """Await and return one ``Letter``."""
        ...

    def recvready(self, timeout: Optional[float]) -> List[Letter]:
        """Return a list of ``Letter`` objects, given an optional timeout."""
        ...
|
571
|
+
|
572
|
+
|
573
|
+
class WorkerPipe:
    """
    Worker (e.g Trainer) process pipe.

    Wraps a zmq PAIR socket bound to ``pipe_name`` and exchanges Python
    objects over it.
    """

    def __init__(self, q: WorkerMessageQueue, pipe_name: str, max_messages: int = 50):
        # breaking abstraction layer here, but it is an easy way to get a way to send messages
        # to the process
        self._sock = sock = q._socket(zmq.PAIR)
        # Apply the same high-water mark to both directions.
        for option in (zmq.SNDHWM, zmq.RCVHWM):
            sock.setsockopt(option, max_messages)
        sock.bind(pipe_name)

    def send(self, v: Any):
        """Send ``v`` over the pipe as a Python object."""
        self._sock.send_pyobj(v)

    def recv(self) -> Any:
        """Receive and return the next Python object from the pipe."""
        return self._sock.recv_pyobj()

    # Allows us to pass the pipe as a function that can be called to get the next value
    def resolve(self) -> Callable:
        """Return ``self.recv``, so the pipe can stand in for a resolvable function."""
        return self.recv
|
595
|
+
|
596
|
+
|
597
|
+
# Sentinel stored in a Cell that has no value. Because it is an exception
# instance, reading such a cell with Cell.get() raises it.
undefined_cell = RuntimeError("undefined cell")


class Cell:
    """A mutable box holding one value (or an exception standing in for it)."""

    __slots__ = ("value",)

    def __init__(self, initial_value=undefined_cell):
        self.value: Any = initial_value

    def __repr__(self):
        return "<C>"

    def set(self, value: Any):
        """Store ``value`` in the cell."""
        self.value = value

    def clear(self):
        """Reset the cell to the undefined sentinel."""
        self.value = undefined_cell

    def is_defined(self):
        """Return True unless the cell holds the undefined sentinel."""
        return not (self.value is undefined_cell)

    def get(self) -> Any:
        """Return the stored value, raising it instead if it is an exception.

        While the current thread is tracing with a recording stream set, a
        cell that is not among the trace's defined cells is not read at all;
        ``tracing.input_cell(self)`` is returned instead.
        """
        block = _tls.tracing
        if block is not None:
            if self not in block.defined_cells and block.recording_stream is not None:
                return block.input_cell(self)
        stored = self.value
        if not isinstance(stored, Exception):
            return stored
        raise stored
|
630
|
+
|
631
|
+
|
632
|
+
class Worker:
|
633
|
+
def __init__(self, q: WorkerMessageQueue, rank: int, world: int, local_rank: int):
    """Store the message queue and rank information and start with empty state."""
    # Transport and identity of this worker.
    self.q = q
    self.rank = rank
    self.world = world
    self.local_rank = local_rank

    # remote ref id to local value
    self.env: Dict[int, Cell] = {}
    self.borrows: Dict[int, Tuple[Ref, Borrow]] = {}
    self.streams: List[Stream] = []
    self.send_recv_process_groups: Dict[Tuple[Stream, Stream], Any] = {}

    # Bookkeeping.
    self.loop: Optional[asyncio.AbstractEventLoop] = None
    self.last_send_status = 0
    self.stream_thread_error = False
    # Largest truthy ``ident`` seen by handle_message so far.
    self.max_received_ident = 0
|
647
|
+
|
648
|
+
def handle_message(self, event: NamedTuple):
    """Dispatch ``event`` to the method named after its class.

    Also tracks the largest ``ident`` seen, when the event has a truthy one.

    Returns:
        Whatever the handler method returns.

    Raises:
        RuntimeError: if this object has no attribute named after the
            event's class.
    """
    cmd = event.__class__.__name__
    ident = getattr(event, "ident", None)
    if ident:
        self.max_received_ident = max(self.max_received_ident, ident)
    handler = getattr(self, cmd, None)
    if handler is None:
        raise RuntimeError(f"unhandled event: {event}")
    return handler(event)
|
656
|
+
|
657
|
+
def CreateDeviceMesh(
    self, m: messages.CreateDeviceMesh
):  # result: "Ref", names: Tuple[str, ...], ranks: NDSlice):
    """Build a DeviceMesh for this worker's rank and define it under ``m.result``."""
    # pyre-ignore
    mesh = DeviceMesh(m.result.id, m.names, m.ranks, self.rank)
    self.define(m.result, mesh)
|
662
|
+
|
663
|
+
def resolve(self, r: Union[Referenceable, Ref]) -> Cell:
    """Return the cell stored in ``env`` for ref ``r``.

    Raises:
        AssertionError: if ``r`` is not a ``Ref``.
        KeyError: if ``env`` has no entry for ``r.id``.
    """
    assert isinstance(r, Ref)
    cell = self.env[r.id]
    return cell
|
666
|
+
|
667
|
+
def CallFunction(self, m: messages.CallFunction):
    """Translate a CallFunction message into a ``Stream.call_function`` call.

    Refs in ``m.result`` become the cells to define, ``m.mutates`` and the
    refs inside ``m.args``/``m.kwargs`` are resolved to existing cells, and
    the call is handed to the stream named by ``m.stream``.
    """

    def _is_ref(x):
        return isinstance(x, Ref)

    flatten_result = flattener(m.result, _is_ref)
    defines = tuple(self.cell(ref) for ref in flatten_result(m.result))
    mutates = tuple(self.resolve(ref) for ref in m.mutates)
    stream: Stream = self.resolve(m.stream).get()
    device_mesh = None
    if m.device_mesh is not None:
        device_mesh = self.resolve(m.device_mesh).get()
    inputs, unflatten_inputs = self._inputs((m.args, m.kwargs))

    stream.call_function(
        m.ident,
        defines,
        flatten_result,
        mutates,
        m.function,
        inputs,
        unflatten_inputs,
        device_mesh,
    )
|
688
|
+
|
689
|
+
def CreateRemoteProcessGroup(self, m: messages.CreateRemoteProcessGroup):
    """Store a process-group shell for ``m.dims`` in the cell for ``m.result``.

    The shell is produced by ``create_process_group_shell`` on the device
    mesh that ``m.device_mesh`` resolves to.
    """
    mesh = self.resolve(m.device_mesh).get()
    target = self.cell(m.result)
    shell = mesh.create_process_group_shell(m.dims, m.result)
    target.set(shell)
|
693
|
+
|
694
|
+
def CreateStream(self, m: messages.CreateStream):
    """Create a ``Stream``, record it in ``self.streams`` and bind it to ``m.result``."""
    # pyre-ignore
    new_stream = Stream(self, m.result.id, m.default)
    self.streams.append(new_stream)
    self.define(m.result, new_stream)
|
699
|
+
|
700
|
+
def _inputs(self, obj):
    """Replace the ``Ref`` leaves of ``obj`` with their cells from ``self.env``.

    Returns:
        ``(cells, unflatten)`` where ``cells`` lists the environment entries
        for each ``Ref`` found and ``unflatten`` is the second value returned
        by ``flatten``.
    """
    refs, unflatten = flatten(obj, lambda x: isinstance(x, Ref))
    cells = []
    for ref in refs:
        cells.append(self.env[ref.id])
    return cells, unflatten
|
704
|
+
|
705
|
+
def SendValue(self, m: messages.SendValue):
    """Forward a ``SendValue`` request to its stream.

    Must not be called while tracing a recording. ``m.destination`` is
    resolved to a pipe when set; otherwise ``None`` is passed as the pipe.
    """
    assert (
        not _tls.tracing
    ), "controller should have prevented SendValue in repeat block."
    stream: Stream = self.resolve(m.stream).get()
    if m.destination is not None:
        pipe = self.resolve(m.destination).get()
    else:
        pipe = None
    inputs, unflatten = self._inputs((m.args, m.kwargs))
    mutated_cells = tuple(self.resolve(ref) for ref in m.mutates)
    stream.send_value(m.ident, m.function, mutated_cells, inputs, unflatten, pipe)
|
716
|
+
|
717
|
+
def PipeRecv(self, m: messages.PipeRecv):
    """Schedule a call on the stream that uses the pipe as the callee.

    Fresh cells are created for every ``Ref`` in ``m.result``; the pipe is
    passed in the function slot of ``stream.call_function`` with no inputs.
    """
    stream: Stream = self.resolve(m.stream).get()
    pipe: WorkerPipe = self.resolve(m.pipe).get()
    flatten_result = flattener(m.result, lambda x: isinstance(x, Ref))
    result_refs = flatten_result(m.result)
    result_cells = tuple(self.cell(ref) for ref in result_refs)
    stream.call_function(
        m.ident,
        result_cells,
        flatten_result,
        (),
        pipe,
        (),
        lambda x: ((), {}),
    )
|
732
|
+
|
733
|
+
def RequestStatus(self, m: messages.RequestStatus):
    """Reply with a status once every threaded stream has caught up.

    A work item is queued on each stream whose ``thread`` is not ``None``.
    When a stream reaches that item it schedules a callback on the worker;
    once all such callbacks have run, ``_send_status(m.ident + 1)`` is called.
    """
    next_ident = m.ident + 1
    arrived = 0
    awaited = 0

    # Runs on the asyncio event loop, but is placed there by the stream
    # thread when it reaches this work item.
    def on_stream_reached():
        nonlocal arrived
        arrived += 1
        if arrived == awaited:
            self._send_status(next_ident)

    def hop_to_event_loop():
        self.schedule(on_stream_reached)

    for stream in self.streams:
        if stream.thread is None:
            continue
        awaited += 1
        stream.schedule(hop_to_event_loop)

    # With no active threads we still answer, so the controller knows this
    # worker is alive.
    if awaited == 0:
        self._send_status(next_ident)
|
758
|
+
|
759
|
+
def Exit(self, m: messages.Exit):
    """Shut down all streams and signal the worker loop to stop.

    Every stream is asked to exit before any of them is joined. If
    ``torch.distributed`` is initialized and ``m.destroy_pg`` is truthy, the
    send/recv process groups and then the default group are destroyed.

    Raises:
        StopIteration: always, after the cleanup above.
    """
    for stream in self.streams:
        stream.exit()
    for stream in self.streams:
        logger.info("joining stream")
        stream.join()

    destroy_groups = torch.distributed.is_initialized() and m.destroy_pg
    if destroy_groups:
        for group in self.send_recv_process_groups.values():
            torch.distributed.destroy_process_group(group)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()
        logger.info("PG destroyed")
    raise StopIteration()
|
774
|
+
|
775
|
+
def CommandGroup(self, m: messages.CommandGroup):
    """Handle each command in ``m.commands`` in order via ``handle_message``."""
    dispatch = self.handle_message
    for command in m.commands:
        dispatch(command)
|
778
|
+
|
779
|
+
@contextmanager
def trace(self, value: Optional["CompiledBlock"]) -> Generator[None, Any, Any]:
    """Set ``_tls.tracing`` to ``value`` for the duration of the ``with`` body.

    The previous value is restored on exit, including when the body raises.
    """
    previous = _tls.tracing
    _tls.tracing = value
    try:
        yield
    finally:
        _tls.tracing = previous
|
786
|
+
|
787
|
+
def DefineRecording(self, m: messages.DefineRecording):
    """Record ``m.commands`` into a new ``CompiledBlock`` bound to ``m.result``.

    The commands are replayed through ``handle_message`` while the block is
    installed as the active trace; the block is then emitted and defined.
    """
    recording = CompiledBlock()
    with self.trace(recording):
        for command in m.commands:
            self.handle_message(command)
        recording.emit()
    self.define(m.result, recording)
|
794
|
+
|
795
|
+
def RecordingFormal(self, m: messages.RecordingFormal):
    """Define a formal argument on the block currently being traced.

    The value returned by ``define_formal`` for the resolved stream and
    ``m.argument_index`` is stored in a fresh cell for ``m.result``.
    """
    active_block = _tls.tracing
    assert active_block is not None
    target = self.cell(m.result)
    stream = self.resolve(m.stream).get()
    target.set(active_block.define_formal(stream, m.argument_index))
|
801
|
+
|
802
|
+
def RecordingResult(self, m: messages.RecordingResult):
    """Mark a traced node as output ``m.output_index`` of the active block.

    Runs inside ``record_to`` for the resolved stream; the value resolved
    from ``m.input`` must be a ``torch.fx.Node``.
    """
    active_block = _tls.tracing
    assert active_block is not None
    stream = self.resolve(m.stream).get()
    with active_block.record_to(stream):
        fx_node = self.resolve(m.input).get()
        assert isinstance(fx_node, torch.fx.Node)
        active_block.define_result(fx_node, m.output_index)
|
809
|
+
|
810
|
+
def CallRecording(self, m: messages.CallRecording):
    """Run a previously defined recording on each stream it has an impl for.

    Only actuals whose index is in ``used_formals`` are resolved, and only
    results whose index is in ``used_results`` get a fresh cell; the other
    positions are passed as ``None``.
    """
    recording: CompiledBlock = self.resolve(m.recording).get()

    actuals = []
    for index, actual in enumerate(m.actuals):
        used = index in recording.used_formals
        actuals.append(self.resolve(actual) if used else None)

    results = []
    for index, result in enumerate(m.results):
        used = index in recording.used_results
        results.append(self.cell(result) if used else None)

    for stream, impl in recording.impls.items():
        stream.run_recording(m.ident, impl, results, actuals)
|
822
|
+
|
823
|
+
def DeleteRefs(self, m: messages.DeleteRefs):
    """Remove every id in ``m.refs`` from ``self.env``.

    An id that is not present raises ``KeyError``.
    """
    env = self.env
    for ref_id in m.refs:
        del env[ref_id]
|
826
|
+
|
827
|
+
def BorrowCreate(self, m: messages.BorrowCreate):
    """Create a borrow of ``m.tensor`` from one stream to another.

    The new ``Borrow`` is registered with the active trace (if any), handed
    to the source stream, and remembered in ``self.borrows`` under
    ``m.borrow`` together with ``m.result``.
    """
    from_stream: Stream = self.resolve(m.from_stream).get()
    to_stream: Stream = self.resolve(m.to_stream).get()
    tensor_cell = self.resolve(m.tensor)
    new_borrow = Borrow(from_stream, to_stream)
    if _tls.tracing:
        _tls.tracing.defined_borrows[new_borrow] = True
    from_stream.borrow_create(tensor_cell, new_borrow)
    # pyre-fixme[6]: For 2nd argument expected `Tuple[Ref, Borrow]` but got
    #  `Tuple[Tensor, Borrow]`.
    self.borrows[m.borrow] = (m.result, new_borrow)
|
838
|
+
|
839
|
+
def BorrowFirstUse(self, m: messages.BorrowFirstUse):
    """Notify the borrowing stream of the first use of borrow ``m.borrow``.

    A fresh cell is created for the result ref recorded for this borrow and
    passed to ``borrow_first_use`` on the borrow's ``to_stream``.
    """
    result_ref, borrow = self.borrows[m.borrow]
    result_cell = self.cell(result_ref)
    receiver = borrow.to_stream
    receiver.borrow_first_use(result_cell, borrow)
|
843
|
+
|
844
|
+
def BorrowLastUse(self, m: messages.BorrowLastUse):
    """Notify the borrowing stream of the last use of borrow ``m.borrow``."""
    borrow = self.borrows[m.borrow][1]
    borrow.to_stream.borrow_last_use(borrow)
|
848
|
+
|
849
|
+
def BorrowDrop(self, m: messages.BorrowDrop):
    """Forget borrow ``m.borrow`` and tell its source stream to drop it.

    While tracing, the borrow must have been created inside the same trace.
    """
    dropped = self.borrows.pop(m.borrow)[1]
    assert (
        not _tls.tracing or dropped in _tls.tracing.defined_borrows
    ), "controller should have stopped a drop of a borrow not created in a repeat loop"
    owner = dropped.from_stream
    owner.borrow_drop(dropped)
|
856
|
+
|
857
|
+
def CreatePipe(self, m: messages.CreatePipe):
    """Create a ``WorkerPipe``, bind it to ``m.result`` and send its setup tuple.

    The first message on the pipe is
    ``(m.function, ranks, sizes, m.args, m.kwargs)``, where ``ranks`` and
    ``sizes`` map each dimension name of the resolved device mesh to that
    dimension's ``rank`` and ``size``.
    """
    device_mesh: DeviceMesh = self.resolve(m.device_mesh).get()
    pipe_name = f"{m.key}-{self.rank}"
    ranks = {}
    for name, dim in device_mesh.dims.items():
        ranks[name] = dim.rank
    sizes = {}
    for name, dim in device_mesh.dims.items():
        sizes[name] = dim.size
    pipe = WorkerPipe(self.q, pipe_name, m.max_messages)
    self.define(m.result, pipe)

    setup = (m.function, ranks, sizes, m.args, m.kwargs)
    pipe.send(setup)
|
866
|
+
|
867
|
+
def SplitComm(self, m: messages.SplitComm):
    """Take part in a communicator split for the mesh in ``m.device_mesh``.

    Ranks that can resolve the mesh create the process group for
    ``m.dims``; all other ranks contribute a no-color split on the default
    process group.
    """
    # Membership test: dereferencing the mesh ref only succeeds on ranks
    # that are part of the mesh named by the SplitComm command.
    device_mesh = None
    try:
        device_mesh = self.resolve(m.device_mesh).get()
    except KeyError:
        in_mesh = False
    else:
        in_mesh = True

    if in_mesh:
        # Create a split process group.
        stream = self.resolve(m.stream).get()
        device_mesh.create_process_group(stream, m.dims)
        return

    # This rank is not in the split group, but it still has to take part in
    # the commSplit call.
    #
    # The current default split_group API requires all participants to know
    # what the split ranks should be. Workers outside the new group don't
    # know, so we manually contribute a NOCOLOR ncclCommSplit call instead.
    default_pg = torch.distributed.distributed_c10d._get_default_group()
    # pyre-ignore[16]
    backend = default_pg._get_backend(torch.device("cuda"))
    backend.perform_nocolor_split(default_pg.bound_device_id)
|
894
|
+
|
895
|
+
def SplitCommForProcessGroup(self, m: messages.SplitCommForProcessGroup):
    """Take part in a communicator split for an existing remote process group.

    Ranks that can resolve ``m.remote_process_group`` create the process
    group on its device mesh; all other ranks contribute a no-color split
    on the default process group.
    """
    # Membership test: dereferencing the process-group ref only succeeds on
    # ranks that are on the mesh named by the SplitCommForProcessGroup
    # command.
    pg = None
    try:
        pg = self.resolve(m.remote_process_group).get()
    except KeyError:
        in_mesh = False
    else:
        in_mesh = True

    if in_mesh:
        # Create a split process group.
        stream = self.resolve(m.stream).get()
        pg.device_mesh.create_process_group(
            stream, pg.dims, pg=m.remote_process_group
        )
        return

    # This rank is not in the split group, but it still has to take part in
    # the commSplit call.
    #
    # The current default split_group API requires all participants to know
    # what the split ranks should be. Workers outside the new group don't
    # know, so we manually contribute a NOCOLOR ncclCommSplit call instead.
    default_pg = torch.distributed.distributed_c10d._get_default_group()
    # pyre-ignore[16]
    backend = default_pg._get_backend(torch.device("cuda"))
    backend.perform_nocolor_split(default_pg.bound_device_id)
|
925
|
+
|
926
|
+
def Reduce(self, m: messages.Reduce):
    """Schedule a reduction of ``m.local_tensor`` over ``m.dims`` of a mesh.

    The actual work is a closure around ``_reduce`` that is handed to
    ``stream.collective_call`` together with the input, output and optional
    ``out`` cells.
    """
    stream: Stream = self.resolve(m.stream).get()
    source_mesh: DeviceMesh = self.resolve(m.source_mesh).get()
    dim_count = len(m.dims)
    assert dim_count <= len(source_mesh.dims)
    if dim_count > 1:
        assert m.reduction != "stack" and not m.scatter
    pg = source_mesh.get_process_group(stream, m.dims)
    local_tensor = self.resolve(m.local_tensor)
    if m.out is None:
        out = None
    else:
        out = self.resolve(m.out)
    output = self.cell(m.result)

    # N is only needed for "stack"; the assertion above rules out "stack"
    # with more than one dim, so the first dim is the one being stacked.
    if m.reduction == "stack":
        N = len(source_mesh.dims[m.dims[0]].members)
    else:
        N = -1

    def reducer(local_tensor, out):
        return _reduce(
            local_tensor,
            source_mesh,
            pg,
            N,
            m.reduction,
            m.scatter,
            m.inplace,
            out,
        )

    stream.collective_call(reducer, m.factory, local_tensor, output, out)
|
953
|
+
|
954
|
+
def SendTensor(self, m: messages.SendTensor):
    """Move a tensor between ranks using batched point-to-point ops.

    ``m.from_ranks[i]`` sends to ``m.to_ranks[i]``. This rank may act as a
    sender, a receiver, or both. The work is issued through
    ``collective_call`` on the recv stream when this rank does not send, on
    the send stream when it does not receive, and on the shared stream when
    it does both and the two streams are the same object.

    Raises:
        NotImplementedError: if this rank both sends and receives and the
            send and recv streams differ.
        KeyError: if no process group was registered for the stream pair.
    """
    send_stream: Stream = self.resolve(m.from_stream).get()
    recv_stream: Stream = self.resolve(m.to_stream).get()
    pg = self.send_recv_process_groups[(send_stream, recv_stream)]

    # Peer this rank sends to, or None when it is not a sender.
    try:
        index = m.from_ranks.index(self.rank)
        send_to_rank = m.to_ranks[index]
    except ValueError:
        send_to_rank = None

    # Peer this rank receives from, or None when it is not a receiver.
    try:
        index = m.to_ranks.index(self.rank)
        recv_from_rank = m.from_ranks[index]
    except ValueError:
        recv_from_rank = None

    if send_to_rank is None:
        the_stream = recv_stream
    elif recv_from_rank is None:
        the_stream = send_stream
    elif send_stream is recv_stream:
        the_stream = send_stream
    else:
        raise NotImplementedError(
            "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. "
            "It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. "
            "Then the send stream would do the nccl op, and then sync with sending stream again."
        )

    def send_recv(
        input_tensor: torch.Tensor, out: Optional[torch.Tensor]
    ) -> Optional[torch.Tensor]:
        # We consider to_mesh to always copy a tensor. But if this rank is
        # sending to itself we do not really have to copy it; in that case
        # we use _lazy_clone, as the original code did.
        #
        # The peers must be compared against self.rank: checking only
        # send_to_rank == recv_from_rank is also true when two ranks swap
        # tensors (A sends to B and receives from B), which needs real
        # communication.
        if send_to_rank == recv_from_rank == self.rank:
            return input_tensor._lazy_clone()
        ops = []
        P2POp = torch.distributed.P2POp
        isend, irecv = torch.distributed.isend, torch.distributed.irecv
        if send_to_rank is not None:
            ops.append(P2POp(isend, input_tensor, send_to_rank, pg))

        if recv_from_rank is not None:
            output = m.factory.empty()
            ops.append(P2POp(irecv, output, recv_from_rank, pg))
        else:
            output = None
        # invoke batched p2p ops
        for op in torch.distributed.batch_isend_irecv(ops):
            op.wait()
        return output

    input = Cell(None) if send_to_rank is None else self.resolve(m.tensor)
    output = Cell(None) if recv_from_rank is None else self.cell(m.result)
    the_stream.collective_call(send_recv, m.factory, input, output, None)
|
1013
|
+
|
1014
|
+
def BackendNetworkInit(self, m: messages.BackendNetworkInit):
    """Initialize the default NCCL process group unless one already exists.

    The store host and port come from the message, falling back to the
    ``STORE_HOSTNAME`` and ``STORE_PORT`` environment variables. After
    initialization a one-element all-reduce is issued on the default group.
    """
    if torch.distributed.is_initialized():
        return  # for restarts in tests
    hostname = m.hostname or os.environ["STORE_HOSTNAME"]
    port = m.port or int(os.environ["STORE_PORT"])
    store = torch.distributed.TCPStore(hostname, port)
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=self.world,
        rank=self.rank,
        store=store,
        device_id=torch.device("cuda:0"),
    )
    probe = torch.zeros(1, device="cuda")
    torch.distributed.all_reduce(probe)
|
1030
|
+
|
1031
|
+
def BackendNetworkPointToPointInit(
    self, m: messages.BackendNetworkPointToPointInit
):
    """Create and register the process group used for a stream pair.

    The group is stored in ``self.send_recv_process_groups`` keyed by
    ``(from_stream, to_stream)``; its name includes the current
    ``restart_count`` and both stream ids.
    """
    from_stream: Stream = self.resolve(m.from_stream).get()
    to_stream: Stream = self.resolve(m.to_stream).get()
    group_name = f"restart_{restart_count}_send_{from_stream.id}_recv_{to_stream.id}"
    group = _new_process_group(group_name, None, split=False)
    self.send_recv_process_groups[(from_stream, to_stream)] = group
|
1041
|
+
|
1042
|
+
def DebuggerMessage(self, m: messages.DebuggerMessage):
    """Put ``m.action`` on the debugger queue of the stream ``m.stream_id``."""
    target: Stream = self.env[m.stream_id].get()
    queue = target.debugger_queue
    queue.put(m.action)
|
1045
|
+
|
1046
|
+
def define(self, r: Union[Ref, Referenceable], value: Any):
    """Bind ``value``, wrapped in a new ``Cell``, to ``r.id`` in ``self.env``."""
    assert isinstance(r, Ref)
    wrapped = Cell(value)
    self.env[r.id] = wrapped
|
1049
|
+
|
1050
|
+
def cell(self, r: Union[Ref, Referenceable]):
    """Create an empty ``Cell`` for ``r`` and store it in ``self.env``.

    While a trace is active, the cell is also recorded in the trace's
    ``defined_cells`` with ``r.id`` as its value.

    Returns:
        The newly created cell.
    """
    assert isinstance(r, Ref)
    fresh = Cell()
    self.env[r.id] = fresh
    if _tls.tracing:
        _tls.tracing.defined_cells[fresh] = r.id
    return fresh
|
1056
|
+
|
1057
|
+
def _send_status(self, first_uncompleted_ident):
    """Send a ``Status`` message if it advances past the last one sent.

    Nothing is sent when ``first_uncompleted_ident`` is not greater than
    ``self.last_send_status``.
    """
    if first_uncompleted_ident <= self.last_send_status:
        return
    status = messages.Status(first_uncompleted_ident)
    self.q.send(status)
    self.last_send_status = first_uncompleted_ident
|
1061
|
+
|
1062
|
+
async def worker_loop(self):
    """Receive messages from ``self.q`` and handle them until told to stop.

    A ``Monitor`` callback is armed before waiting for a message and again
    before handling one, each with a 30 second argument. A ``StopIteration``
    raised while handling ends the loop; any other exception is logged and
    reported through ``internal_error``.
    """
    monitor = Monitor()
    monitor.start()
    self.loop = asyncio.get_event_loop()
    recent = deque()
    while True:
        try:
            # Eventually this event loop should be handled as a separate
            # thread (maybe not even Python) that just takes and responds
            # to messages, with a strong guarantee of never getting stuck.
            # For now we just run everything on this thread.
            monitor(
                lambda: (
                    logger.error(
                        f"possible stall while waiting for message: recent messages: {recent} "
                        f"{self.max_received_ident=} {self.last_send_status=}"
                    ),
                    logger.setLevel(logging.INFO),
                ),
                30.0,
            )
            _, msg = await self.q.recv_async()
            logger.debug(f"event: {msg}, env={list(self.env.keys())}")
            monitor(
                lambda msg=msg: logger.error(f"possible stall while handling {msg}"),
                30.0,
            )
            self.handle_message(msg)

            # Keep only the ten most recent messages for stall reports.
            recent.append(msg)
            while len(recent) > 10:
                recent.popleft()
        except StopIteration:
            self.q.recvready(0)
            self.q.recvready(0.01)
            return
        except Exception as e:
            logger.exception("Worker event loop exiting with internal exception")
            failure = messages.InternalException(e, extract_tb(e.__traceback__))
            self.internal_error(failure)
|
1107
|
+
|
1108
|
+
def schedule(self, fn: Callable[[], None]):
    """Queue ``fn`` on the worker's event loop via ``call_soon_threadsafe``.

    ``self.loop`` must already be set.
    """
    loop = self.loop
    assert loop is not None
    loop.call_soon_threadsafe(fn)
|
1111
|
+
|
1112
|
+
def internal_error(self, msg: messages.InternalException):
    """Send ``msg`` on ``self.q`` and then stop the event loop.

    The message is sent before ``self.loop`` is checked, so it goes out
    even if the loop has not been set.
    """
    self.q.send(msg)
    loop = self.loop
    assert loop is not None
    loop.stop()
|
1116
|
+
|
1117
|
+
def event_loop(self):
    """Run ``worker_loop`` to completion with ``asyncio.run``.

    ``pdb.set_trace`` is replaced with ``_set_trace`` first. A
    ``RuntimeError`` whose text contains "Event loop stopped" is logged as a
    warning and swallowed; any other ``RuntimeError`` propagates.
    """
    pdb.set_trace = _set_trace
    try:
        asyncio.run(self.worker_loop())
    except RuntimeError as e:
        if "Event loop stopped" not in str(e):
            raise
        logger.warning("Event loop exiting after reporting an internal error.")
|
1127
|
+
|
1128
|
+
|
1129
|
+
def worker_main(_restartable):
    """Entry point of a worker process.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment,
    narrows ``CUDA_VISIBLE_DEVICES`` to the single entry selected by the
    local rank, then runs ``Worker.event_loop``. When ``_restartable`` is
    truthy, a ``Restarted`` message is sent after each run and a new
    ``Worker`` is started; otherwise the function returns after one run.
    """
    global restart_count

    rank = int(os.environ["RANK"])
    world = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        device = str(local_rank)
    else:
        device = visible.split(",")[local_rank]
    os.environ["CUDA_VISIBLE_DEVICES"] = device

    initialize_logging(process_name=f"worker_{rank}")
    logger.info("starting, restartable=%s, local_rank=%d", _restartable, local_rank)
    # Force CUDA to initialize before we do any multithreading. This is a
    # workaround until https://github.com/pytorch/pytorch/pull/143238 is
    # available everywhere.
    if torch.cuda.is_available():
        torch.ones(1, device="cuda")
    q = get_message_queue()
    for attempt in itertools.count():
        restart_count = attempt
        Worker(q, rank, world, local_rank).event_loop()
        if not _restartable:
            break
        q.send(messages.Restarted(0))
        logger.info("restarting")
|
1156
|
+
|
1157
|
+
|
1158
|
+
class ProcessPipe:
    """Pipe-process end of a ZeroMQ ``PAIR`` connection.

    ``ranks`` and ``sizes`` start as empty dicts and are filled in by the
    code that uses the pipe.
    """

    def __init__(self, key: str, max_messages):
        """Connect a ``PAIR`` socket to ``key`` with both high-water marks set to ``max_messages``."""
        import zmq

        queue = get_message_queue()
        self._sock = queue._socket(zmq.PAIR)
        sock = self._sock
        for option in (zmq.SNDHWM, zmq.RCVHWM):
            sock.setsockopt(option, max_messages)
        sock.connect(key)
        self.ranks = {}
        self.sizes = {}

    def send(self, any: Any):
        """Send a Python object over the socket with ``send_pyobj``."""
        self._sock.send_pyobj(any)

    def recv(self):
        """Return the next Python object from the socket via ``recv_pyobj``."""
        received = self._sock.recv_pyobj()
        return received
|
1177
|
+
|
1178
|
+
|
1179
|
+
def pipe_main(key: str, max_messages):
    """Entry point of a pipe process.

    The first object received on the pipe is unpacked as
    ``(function, ranks, sizes, args, kwargs)``. The function is resolved and
    called with the pipe as its first argument. An exception raised by the
    call is logged and reported as ``RemoteGeneratorFailed`` on the message
    queue.
    """
    initialize_logging(f"pipe_{key}")
    pipe_obj = ProcessPipe(key, max_messages)
    rfunction, pipe_obj.ranks, pipe_obj.sizes, args, kwargs = pipe_obj.recv()
    target = rfunction.resolve()
    try:
        target(pipe_obj, *args, **kwargs)
    except Exception as e:
        logger.exception("pipe_main exiting with exception")
        failure = messages.RemoteGeneratorFailed(e, extract_tb(e.__traceback__))
        get_message_queue().send(failure)
|