torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,481 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import logging
|
9
|
+
import os
|
10
|
+
import threading
|
11
|
+
from time import sleep, time
|
12
|
+
from typing import Tuple
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import torch.distributed as dist
|
16
|
+
import torch.nn as nn
|
17
|
+
import torch.optim as optim
|
18
|
+
from monarch._rust_bindings.monarch_extension.debugger import ( # @manual=//monarch/monarch_extension:monarch_extension
|
19
|
+
get_bytes_from_write_action,
|
20
|
+
PdbActor,
|
21
|
+
)
|
22
|
+
from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
|
23
|
+
from monarch.common import opaque_ref
|
24
|
+
from monarch.common.pipe import Pipe
|
25
|
+
from monarch.common.process_group import SingleControllerProcessGroupWrapper
|
26
|
+
from monarch.common.remote import remote
|
27
|
+
|
28
|
+
from torch.utils.data import DataLoader, TensorDataset
|
29
|
+
|
30
|
+
|
31
|
+
logger = logging.getLogger(__name__)

# Collection of worker-side remote functions that are used in unit tests.
#
# NOTE: this text used to be a bare string literal placed after the imports.
# Python only treats the *first* statement of a module as its docstring, so the
# string was a no-op expression rather than ``__doc__``; a comment says the same
# thing without pretending to be interpreter-visible documentation.
|
36
|
+
|
37
|
+
|
38
|
+
# code used for testing but useful to have importable (e.g. can refer to remote functions)
|
39
|
+
def do_bogus_tensor_work(x, y, fail_rank=None):
    """Compute ``x @ y``, optionally only on one selected rank.

    With ``fail_rank=None`` every caller computes ``x @ y`` and the ``RANK``
    environment variable is never read. Otherwise only the process whose
    ``int(os.environ["RANK"])`` equals ``fail_rank`` performs the matmul; all
    other ranks return ``x`` untouched. (Presumably tests pick shapes so the
    selected rank's matmul fails — TODO confirm against callers.)
    """
    if fail_rank is None:
        return x @ y
    if int(os.environ["RANK"]) == fail_rank:
        return x @ y
    return x
|
43
|
+
|
44
|
+
|
45
|
+
def set_device_udf_worker(device: int):
    """Make CUDA ``device`` the current device for this worker.

    Returns ``torch.ones(1)``, created without an explicit ``device`` argument.
    """
    torch.cuda.set_device(device)
    result = torch.ones(1)
    return result
|
48
|
+
|
49
|
+
|
50
|
+
def example_data_loader(p: "Pipe", x: int, y: int):
    """Send 0-d tensors holding ``x, x + 1, ..., y - 1`` through ``p``, in order."""
    for value in range(x, y):
        scalar = torch.full((), value)
        p.send(scalar)
|
53
|
+
|
54
|
+
|
55
|
+
def example_data_loader_small_pipe(p: "Pipe", iters: int, shape: Tuple[int, int]):
    """Send ``iters`` tensors of ``shape`` through ``p``.

    Message ``i`` is filled with ``i`` until more than 0.5 s have passed since
    the call started; after that every message is filled with ``-1.0`` instead.
    """
    start = time()
    for step in range(iters):
        timed_out = time() - start > 0.5
        fill_value = -1.0 if timed_out else step
        p.send(torch.full(shape, fill_value))
|
62
|
+
|
63
|
+
|
64
|
+
def example_echo_add(p: "Pipe"):
    """Echo loop that never returns on its own.

    Each value received from ``p`` is sent back plus ``1 + p.ranks["gpu"]``.
    ``p.ranks["gpu"]`` is presumably this worker's coordinate on the "gpu" mesh
    dimension — TODO confirm.
    """
    while True:
        received = p.recv()
        p.send(received + 1 + p.ranks["gpu"])
|
67
|
+
|
68
|
+
|
69
|
+
def log(*args, **kwargs):
    """Pass every argument unchanged to this module's ``logger.info``."""
    emit = logger.info
    emit(*args, **kwargs)
|
71
|
+
|
72
|
+
|
73
|
+
def remote_sleep(t: float):
    """Block the calling thread for ``t`` seconds; returns ``None``."""
    duration = t
    sleep(duration)
|
75
|
+
|
76
|
+
|
77
|
+
def has_nan(t):
    """Return a Python ``bool`` that is ``True`` iff any element of ``t`` is NaN."""
    nan_mask = torch.isnan(t)
    found = nan_mask.any()
    return found.item()
|
79
|
+
|
80
|
+
|
81
|
+
def new_barrier_hackery(threads):
    """(Re)create the module-global ``_barrier`` as a ``threading.Barrier(threads)``.

    ``wait_barrier_hackery`` waits on this barrier. Returns ``torch.zeros(1)``.
    """
    global _barrier
    fresh_barrier = threading.Barrier(threads)
    _barrier = fresh_barrier
    placeholder = torch.zeros(1)
    return placeholder
|
85
|
+
|
86
|
+
|
87
|
+
def wait_barrier_hackery(t: torch.Tensor):
    """Wait on the module-global barrier created by ``new_barrier_hackery``.

    ``t`` is ignored. ``_barrier`` is only assigned by ``new_barrier_hackery``,
    so calling this first raises ``NameError``. Returns ``None``.
    """
    # pyre-fixme[10]: Name `_barrier` is used but not defined.
    shared_barrier = _barrier
    shared_barrier.wait()
|
90
|
+
|
91
|
+
|
92
|
+
def all_reduce_prop(tensor, *args, **kwargs):
    """Propagation hook for ``all_reduce``.

    Increments ``tensor`` in place, mirroring that the real collective writes
    into its input, and returns that same tensor object. Extra arguments are
    ignored.
    """
    return tensor.add_(1)
|
95
|
+
|
96
|
+
|
97
|
+
@remote(propagate=all_reduce_prop)
def all_reduce(tensor, group=None, op=dist.ReduceOp.SUM):
    """All-reduce ``tensor`` in place across ``group`` with ``op`` and return it.

    Like the other collectives in this module, a
    ``SingleControllerProcessGroupWrapper`` is unwrapped to its underlying
    ``process_group`` before being handed to ``torch.distributed``. Any other
    ``group`` value, including ``None``, is passed through unchanged.
    """
    if isinstance(group, SingleControllerProcessGroupWrapper):
        group = group.process_group
    dist.all_reduce(tensor, op=op, group=group)
    return tensor
|
101
|
+
|
102
|
+
|
103
|
+
@remote(propagate=lambda *args, **kwargs: torch.ones(1))
def barrier(group=None, device_ids=None):
    """Synchronously block until every rank of ``group`` reaches this barrier.

    A ``SingleControllerProcessGroupWrapper`` is unwrapped first.
    Returns ``torch.ones(1)``.
    """
    pg = (
        group.process_group
        if isinstance(group, SingleControllerProcessGroupWrapper)
        else group
    )
    dist.barrier(group=pg, async_op=False, device_ids=device_ids)
    return torch.ones(1)
|
109
|
+
|
110
|
+
|
111
|
+
@remote(
    propagate=lambda tensor_list, *args, **kwargs: [
        torch.zeros_like(t) for t in tensor_list
    ]
)
def all_gather(
    tensor_list: list[torch.Tensor],
    tensor: torch.Tensor,
    group=None,
) -> list[torch.Tensor]:
    """Gather ``tensor`` from every rank of ``group`` into ``tensor_list``.

    ``tensor_list`` is filled in place by ``dist.all_gather`` and returned.
    A ``SingleControllerProcessGroupWrapper`` is unwrapped to its
    ``process_group`` first, consistent with the other collectives here.
    """
    if isinstance(group, SingleControllerProcessGroupWrapper):
        group = group.process_group
    dist.all_gather(tensor_list, tensor, group=group, async_op=False)
    return tensor_list
|
123
|
+
|
124
|
+
|
125
|
+
@remote(propagate=lambda output_tensor, input_tensor, group=None: torch.zeros(1))
def all_gather_into_tensor(output_tensor, input_tensor, group=None):
    """All-gather ``input_tensor`` from every rank into ``output_tensor``.

    A ``SingleControllerProcessGroupWrapper`` is unwrapped to its
    ``process_group`` first, consistent with the other collectives here.
    Returns ``torch.ones(1)``; the gathered data lands in ``output_tensor``.
    """
    if isinstance(group, SingleControllerProcessGroupWrapper):
        group = group.process_group
    dist.all_gather_into_tensor(output_tensor, input_tensor, group=group)
    return torch.ones(1)
|
129
|
+
|
130
|
+
|
131
|
+
@remote(propagate=lambda t, *args, **kwargs: torch.ones(1))
def isend(t, destination, group=None):
    """Send ``t`` to rank ``destination.item()`` via ``dist.isend`` and wait for it.

    ``destination`` is a tensor whose ``.item()`` is the peer rank. The request's
    ``is_completed()`` is checked to return a ``bool`` before waiting.
    Returns ``torch.ones(1)``.
    """
    pg = (
        group.process_group
        if isinstance(group, SingleControllerProcessGroupWrapper)
        else group
    )
    work = dist.isend(t, destination.item(), group=pg)
    assert isinstance(work.is_completed(), bool)
    work.wait()
    return torch.ones(1)
|
139
|
+
|
140
|
+
|
141
|
+
def irecv_prop(t, src, group=None):
    """Propagation hook for ``irecv``.

    Increments the receive buffer ``t`` in place (irecv writes into it) and
    returns ``torch.ones(1)``.
    """
    t.add_(1)
    marker = torch.ones(1)
    return marker
|
145
|
+
|
146
|
+
|
147
|
+
@remote(propagate=irecv_prop)
def irecv(t, src, group=None):
    """Receive into ``t`` from rank ``src.item()`` via ``dist.irecv`` and wait.

    ``src`` is a tensor whose ``.item()`` is the peer rank. The request's
    ``is_completed()`` is checked to return a ``bool`` before waiting.
    Returns ``torch.ones(1)``.
    """
    pg = (
        group.process_group
        if isinstance(group, SingleControllerProcessGroupWrapper)
        else group
    )
    work = dist.irecv(tensor=t, src=src.item(), group=pg)
    assert isinstance(work.is_completed(), bool)
    work.wait()
    return torch.ones(1)
|
155
|
+
|
156
|
+
|
157
|
+
def gonna_pdb():
    """Stop in ``pdb`` on purpose; presumably driven by debugger tests.

    The local ``x`` (``3 + 4``) is live at the breakpoint and printed after the
    debugger resumes. Keep names and statement order as-is: an attached debugger
    session may inspect ``x`` or step through these lines.
    """
    x = 3 + 4
    import pdb # noqa

    pdb.set_trace()
    print(x)
|
163
|
+
|
164
|
+
|
165
|
+
def do_some_processing(a_string):
    """Return ``a_string`` with ``" processed"`` appended."""
    suffix = " processed"
    return a_string + suffix
|
167
|
+
|
168
|
+
|
169
|
+
def how_many_of_these_do_you_want(n: int, t: torch.Tensor):
    """Return the list ``[t + 0, t + 1, ..., t + (n - 1)]`` (empty when ``n <= 0``)."""
    shifted = []
    for offset in range(n):
        shifted.append(t + offset)
    return shifted
|
171
|
+
|
172
|
+
|
173
|
+
def remote_chunk(t: torch.Tensor):
    """Split ``t`` into (at most) 4 chunks along dim 0; returns a tuple of views."""
    return torch.chunk(t, 4, dim=0)
|
175
|
+
|
176
|
+
|
177
|
+
class TestRemoteAutogradFunction(torch.autograd.Function):
    """Toy ``autograd.Function`` mixing tensor and non-tensor outputs.

    ``forward`` returns ``(x * y if x.requires_grad else x + y, y,
    torch.ones(4), 4)``. ``backward`` passes the first two incoming gradients
    straight back as the gradients of ``x`` and ``y``. That is not the true
    derivative of ``x * y``, presumably on purpose for a plumbing test.
    """

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x)
        combined = x * y if x.requires_grad else x + y
        return combined, y, torch.ones(4), 4

    @staticmethod
    def backward(ctx, dx1, dx2, dx3, dx4):
        # One grad arrives per forward output; only the first two are routed
        # back, to x and y respectively.
        grad_x, grad_y = dx1, dx2
        return grad_x, grad_y
|
191
|
+
|
192
|
+
|
193
|
+
class _TestMultiplyAllReduce(torch.autograd.Function):
    """Existing user autograd.Function: ``z = all_reduce(x * y)`` over ``pg``.

    ``forward`` draws a random tensor shaped like ``x`` and saves it with
    ``x``/``y``. It also sets ``ctx.my_property`` (asserted in ``backward``)
    and keeps ``pg`` on the context. ``backward`` computes
    ``grad_output * y`` and ``grad_output * x * <random>``, SUM-all-reduces
    both over ``pg``, and returns no gradient for ``pg``.
    """

    @staticmethod
    def forward(ctx, x, y, pg):
        noise = torch.rand(x.shape, device=x.device)
        ctx.save_for_backward(x, y, noise)
        ctx.my_property = True
        ctx.pg = pg
        product = x * y
        dist.all_reduce(product, op=dist.ReduceOp.SUM, group=pg)
        return product

    @staticmethod
    def backward(ctx, grad_output):
        x, y, noise = ctx.saved_tensors
        assert ctx.my_property
        grad_x = grad_output * y
        grad_y = grad_output * x * noise
        for grad in (grad_x, grad_y):
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.pg)
        return grad_x, grad_y, None
|
215
|
+
|
216
|
+
|
217
|
+
class SimpleModel(nn.Module):
    """Two-layer MLP: ``Linear(input, hidden) -> ReLU -> Linear(hidden, output)``."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Registration order (fc1, relu, fc2) fixes both the parameter order and
        # the order in which the Linear layers draw their random init.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
|
229
|
+
|
230
|
+
|
231
|
+
def setup_state_worker(device: str | torch.device = "cuda"):
    """Build a toy regression training setup and wrap each piece in an ``OpaqueRef``.

    100 random samples (10 features -> 1 target) feed a shuffled
    ``DataLoader(batch_size=16)``. The model is ``SimpleModel(10, 20, 1)``,
    trained with ``MSELoss`` and ``SGD(lr=0.01)``.

    Args:
        device: Where the data, model and loss live. Defaults to ``"cuda"``,
            the previously hard-coded behaviour; pass ``"cpu"`` to run without
            a GPU.

    Returns:
        ``[OpaqueRef(model), OpaqueRef(dataloader), OpaqueRef(criterion),
        OpaqueRef(optimizer)]`` in that order.
    """
    input_size = 10
    hidden_size = 20
    output_size = 1
    batch_size = 16
    learning_rate = 0.01

    # Generate on the default (CPU) generator and then move, exactly like the
    # original ``.cuda()`` calls, so seeded runs produce the same data.
    x = torch.rand(100, input_size).to(device)
    y = torch.rand(100, output_size).to(device)
    dataset = TensorDataset(x, y)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = SimpleModel(input_size, hidden_size, output_size).to(device)
    criterion = nn.MSELoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    return [
        opaque_ref.OpaqueRef(obj) for obj in [model, dataloader, criterion, optimizer]
    ]
|
250
|
+
|
251
|
+
|
252
|
+
def iteration_worker(model_ref, dataloader_ref, criterion_ref, optimizer_ref, pg):
    """Run one epoch over the referenced dataloader, all-reducing grads over ``pg``.

    Each ``*_ref`` is expected to expose the wrapped object as ``.value`` (the
    refs come from ``setup_state_worker``). For every batch: forward, loss,
    ``zero_grad``, backward, SUM-all-reduce each parameter's gradient over
    ``pg`` (summed, not averaged), then ``optimizer.step()``.

    Returns:
        A scalar tensor holding the sum of the per-batch loss values.
    """
    model, dataloader, criterion, optimizer = (
        ref.value for ref in (model_ref, dataloader_ref, criterion_ref, optimizer_ref)
    )

    running_loss = 0.0
    for batch_inputs, batch_targets in dataloader:
        batch_loss = criterion(model(batch_inputs), batch_targets)

        optimizer.zero_grad()
        batch_loss.backward()
        for param in model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pg)
        optimizer.step()

        running_loss += batch_loss.item()
    return torch.tensor(running_loss)
|
271
|
+
|
272
|
+
|
273
|
+
def create_opaque_ref_worker():
    """Wrap a freshly initialized ``nn.Linear(10, 10)`` in an ``OpaqueRef``."""
    layer = nn.Linear(10, 10)
    return opaque_ref.OpaqueRef(layer)
|
275
|
+
|
276
|
+
|
277
|
+
def opaque_ref_key_table_length_worker() -> torch.Tensor:
    """Return, as a tensor, how many keys ``opaque_ref._key_table`` currently holds.

    The keys are materialized into a list before counting, as before.
    """
    live_keys = list(opaque_ref._key_table.keys())
    count = len(live_keys)
    return torch.tensor(count)
|
279
|
+
|
280
|
+
|
281
|
+
class WorkerFoo:
    """Small worker-side object holding a 0-d tensor ``t``."""

    def __init__(self, v):
        # 0-d tensor filled with v.
        self.t = torch.full((), v)

    def add(self, b):
        """Return ``self.t + b`` without modifying ``self.t``."""
        total = self.t + b
        return total
|
287
|
+
|
288
|
+
|
289
|
+
def reduce_prop(tensor, *args, **kwargs):
    """Propagation hook for ``reduce``.

    Increments ``tensor`` in place and returns that same object. Extra
    arguments are ignored.
    """
    tensor.add_(1)
    return tensor
|
291
|
+
|
292
|
+
|
293
|
+
@remote(propagate=reduce_prop)
def reduce(
    tensor: torch.Tensor,
    dst: int | None = None,
    op: dist.ReduceOp = dist.ReduceOp.SUM,
    group=None,
    group_dst: int | None = None,
) -> torch.Tensor:
    """Reduce ``tensor`` across ``group`` onto one rank with ``op``, synchronously.

    The destination is ``dst`` or ``group_dst``, forwarded unchanged to
    ``dist.reduce``. A ``SingleControllerProcessGroupWrapper`` is unwrapped
    first. Returns ``tensor``.
    """
    pg = (
        group.process_group
        if isinstance(group, SingleControllerProcessGroupWrapper)
        else group
    )
    dist.reduce(tensor, dst, op=op, group=pg, async_op=False, group_dst=group_dst)
    return tensor
|
305
|
+
|
306
|
+
|
307
|
+
def reduce_scatter_prop(output, *args, **kwargs):
    """Propagation hook for ``reduce_scatter``.

    reduce_scatter writes into ``output``, so this increments it in place and
    returns the same object.
    """
    return output.add_(1)
|
311
|
+
|
312
|
+
|
313
|
+
@remote(propagate=reduce_scatter_prop)
def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None):
    """Reduce ``input_list`` across ``group`` with ``op``; scatter the result into ``output``.

    A ``SingleControllerProcessGroupWrapper`` is unwrapped first. Runs
    synchronously and returns ``output``.
    """
    pg = (
        group.process_group
        if isinstance(group, SingleControllerProcessGroupWrapper)
        else group
    )
    dist.reduce_scatter(output, input_list, op=op, group=pg, async_op=False)
    return output
|
319
|
+
|
320
|
+
|
321
|
+
def reduce_scatter_tensor_prop(tensor, *args, **kwargs):
    """Propagation hook for ``reduce_scatter_tensor``.

    reduce_scatter_tensor writes into its first argument, so this increments it
    in place and returns the same object.
    """
    return tensor.add_(1)
|
325
|
+
|
326
|
+
|
327
|
+
@remote(propagate=reduce_scatter_tensor_prop)
def reduce_scatter_tensor(
    output_tensor, input_tensor, group=None, op=dist.ReduceOp.SUM
):
    """Reduce ``input_tensor`` across ``group`` with ``op``; scatter shards into ``output_tensor``.

    A ``SingleControllerProcessGroupWrapper`` is unwrapped to its
    ``process_group`` first, consistent with the other collectives here.
    Returns ``output_tensor``.
    """
    if isinstance(group, SingleControllerProcessGroupWrapper):
        group = group.process_group
    dist.reduce_scatter_tensor(output_tensor, input_tensor, group=group, op=op)
    return output_tensor
|
333
|
+
|
334
|
+
|
335
|
+
def gather_prop(tensor, gather_list=None, *args, **kwargs) -> torch.Tensor:
    """Propagation hook for ``gather``.

    gather writes into ``gather_list`` and leaves ``tensor`` alone. So every
    tensor in ``gather_list`` (when given) is incremented in place, and a fresh
    ``zeros_like(tensor)`` is returned.
    """
    targets = [] if gather_list is None else gather_list
    for slot in targets:
        slot.add_(1)
    return torch.zeros_like(tensor)
|
341
|
+
|
342
|
+
|
343
|
+
@remote(propagate=gather_prop)
def gather(
    tensor: torch.Tensor,
    gather_list: list[torch.Tensor] | None = None,
    dst: int | None = None,
    group=None,
    group_dst: int | None = None,
) -> torch.Tensor:
    """Gather ``tensor`` from every rank of ``group`` into ``gather_list`` on one rank.

    The destination is given either as ``dst`` (a rank in the *global* process
    group) or as ``group_dst`` (a rank within ``group``), mirroring
    ``torch.distributed.gather``. ``gather_list`` is discarded on every
    non-destination rank, because torch rejects it there.

    Returns:
        ``tensor`` itself; the gathered results land in ``gather_list`` on the
        destination rank.
    """
    if isinstance(group, SingleControllerProcessGroupWrapper):
        group = group.process_group
    if group_dst is not None:
        is_destination = group_dst == dist.get_rank(group)
    elif dst is not None:
        # ``dst`` is a *global* rank, so compare it with this process's global
        # rank. ``dist.get_rank(group)`` is the rank *within* ``group``, which
        # misidentifies the destination whenever ``group`` is a subgroup.
        is_destination = dst == dist.get_rank()
    else:
        is_destination = True
    if not is_destination:
        gather_list = None
    dist.gather(
        tensor,
        gather_list=gather_list,
        dst=dst,
        group=group,
        async_op=False,
        group_dst=group_dst,
    )
    return tensor
|
370
|
+
|
371
|
+
|
372
|
+
# Scatter mutates its input tensor.
@remote(propagate=lambda tensor, *args, **kwargs: tensor.add_(1))
def scatter(
    tensor: torch.Tensor,
    scatter_list: list[torch.Tensor] | None = None,
    src: int | None = None,
    group=None,
    group_src: int | None = None,
) -> torch.Tensor:
    """Scatter ``scatter_list`` from the source rank into ``tensor`` on every rank.

    The source is given either as ``src`` (a rank in the *global* process
    group) or as ``group_src`` (a rank within ``group``), mirroring
    ``torch.distributed.scatter``. ``scatter_list`` is discarded on every
    non-source rank, because torch rejects it there.

    Returns:
        ``tensor``, which ``dist.scatter`` fills in place.
    """
    if isinstance(group, SingleControllerProcessGroupWrapper):
        group = group.process_group
    if group_src is not None:
        is_source = group_src == dist.get_rank(group)
    elif src is not None:
        # ``src`` is a *global* rank, so compare it with this process's global
        # rank. ``dist.get_rank(group)`` is the rank *within* ``group``, which
        # misidentifies the source whenever ``group`` is a subgroup.
        is_source = src == dist.get_rank()
    else:
        is_source = True
    if not is_source:
        scatter_list = None
    dist.scatter(
        tensor,
        scatter_list=scatter_list,
        src=src,
        group=group,
        async_op=False,
        group_src=group_src,
    )
    return tensor
|
400
|
+
|
401
|
+
|
402
|
+
def inner_remote_function_that_fails():
    """Always raise a plain ``Exception`` with a fixed, recognizable message."""
    message = "Failed to execute inner_remote_function_that_fails"
    raise Exception(message)
|
404
|
+
|
405
|
+
|
406
|
+
def outer_remote_function_that_calls_inner():
    """Call ``inner_remote_function_that_fails``, which always raises.

    The trailing ``torch.zeros(1)`` is therefore never reached; it only gives
    the function a tensor-returning shape.
    """
    inner_remote_function_that_fails()
    unreachable = torch.zeros(1)
    return unreachable
|
409
|
+
|
410
|
+
|
411
|
+
def broadcast_prop(tensor, *args, **kwargs) -> torch.Tensor:
    """Propagation hook for ``broadcast``.

    Broadcast mutates its input, so this increments ``tensor`` in place and
    returns the same object.
    """
    tensor.add_(1)
    return tensor
|
414
|
+
|
415
|
+
|
416
|
+
@remote(propagate=broadcast_prop)
def broadcast(
    tensor: torch.Tensor,
    src: int | None = None,
    group=None,
    group_src: int | None = None,
) -> torch.Tensor:
    """Broadcast ``tensor`` from the source rank to all ranks of ``group``, in place.

    The source is ``src`` or ``group_src``, forwarded unchanged to
    ``dist.broadcast``. A ``SingleControllerProcessGroupWrapper`` is unwrapped
    first. Runs synchronously and returns ``tensor``.
    """
    pg = (
        group.process_group
        if isinstance(group, SingleControllerProcessGroupWrapper)
        else group
    )
    dist.broadcast(tensor, src=src, group=pg, async_op=False, group_src=group_src)
    return tensor
|
427
|
+
|
428
|
+
|
429
|
+
def all_to_all_prop(
    output_tensor_list: list[torch.Tensor],
    input_tensor_list: list[torch.Tensor],
    *args,
    **kwargs,
) -> list[torch.Tensor]:
    """Propagation hook for ``all_to_all``.

    Every output tensor is incremented in place, so that fetches on the output
    list see a mutation. The same list object is returned; the inputs are
    untouched.
    """
    for out in output_tensor_list:
        out.add_(1)
    return output_tensor_list
|
440
|
+
|
441
|
+
|
442
|
+
@remote(propagate=all_to_all_prop)
def all_to_all(
    output_tensor_list: list[torch.Tensor],
    input_tensor_list: list[torch.Tensor],
    group=None,
) -> list[torch.Tensor]:
    """Exchange tensors across ``group`` with ``dist.all_to_all``.

    Received tensors are written into ``output_tensor_list``, which is returned.
    A ``SingleControllerProcessGroupWrapper`` is unwrapped first. Runs
    synchronously.
    """
    pg = (
        group.process_group
        if isinstance(group, SingleControllerProcessGroupWrapper)
        else group
    )
    dist.all_to_all(output_tensor_list, input_tensor_list, group=pg, async_op=False)
    return output_tensor_list
|
452
|
+
|
453
|
+
|
454
|
+
def all_to_all_single_prop(output_tensor, *args, **kwargs) -> torch.Tensor:
    """Propagation hook for ``all_to_all_single``.

    Increments ``output_tensor`` in place, so that fetches on it see a
    mutation, and returns the same object.
    """
    return output_tensor.add_(1)
|
459
|
+
|
460
|
+
|
461
|
+
@remote(propagate=all_to_all_single_prop)
def all_to_all_single(
    output_tensor: torch.Tensor, input_tensor: torch.Tensor, group=None
) -> torch.Tensor:
    """Split ``input_tensor`` across ``group`` and gather the received pieces into ``output_tensor``.

    A ``SingleControllerProcessGroupWrapper`` is unwrapped first.
    Returns ``output_tensor``.
    """
    pg = (
        group.process_group
        if isinstance(group, SingleControllerProcessGroupWrapper)
        else group
    )
    dist.all_to_all_single(output_tensor, input_tensor, group=pg)
    return output_tensor
|
469
|
+
|
470
|
+
|
471
|
+
def test_pdb_actor():
    """Drive a ``PdbActor`` through a scripted paused/read/write exchange.

    Each ``assert`` checks the actor's reply to the preceding ``send``. The
    expected replies are Attach, then a Write carrying ``b"1234"``, then
    Detach. Returns ``torch.zeros(1)`` once the whole exchange has matched.
    """
    actor = PdbActor()

    actor.send(DebuggerAction.Paused())
    assert isinstance(actor.receive(), DebuggerAction.Attach)

    actor.send(DebuggerAction.Read(4))
    reply = actor.receive()
    assert isinstance(reply, DebuggerAction.Write)
    assert get_bytes_from_write_action(reply) == b"1234"

    actor.send(DebuggerAction.Write(b"5678"))
    assert isinstance(actor.receive(), DebuggerAction.Detach)

    return torch.zeros(1)
|