torchmonarch-nightly 2025.6.27 (cp313-cp313-manylinux2014_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/worker/_testing_function.py
@@ -0,0 +1,481 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import logging
+import os
+import threading
+from time import sleep, time
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.optim as optim
+from monarch._rust_bindings.monarch_extension.debugger import (  # @manual=//monarch/monarch_extension:monarch_extension
+    get_bytes_from_write_action,
+    PdbActor,
+)
+from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
+from monarch.common import opaque_ref
+from monarch.common.pipe import Pipe
+from monarch.common.process_group import SingleControllerProcessGroupWrapper
+from monarch.common.remote import remote
+
+from torch.utils.data import DataLoader, TensorDataset
+
+
+logger = logging.getLogger(__name__)
+
+"""
+Collection of worker-side remote functions that are used in unit tests
+"""
+
+
+# code used for testing but useful to have importable (e.g. can refer to remote functions)
+def do_bogus_tensor_work(x, y, fail_rank=None):
+    if fail_rank is not None and int(os.environ["RANK"]) != fail_rank:
+        return x
+    return x @ y
+
+
+def set_device_udf_worker(device: int):
+    torch.cuda.set_device(device)
+    return torch.ones(1)
+
+
+def example_data_loader(p: "Pipe", x: int, y: int):
+    for i in range(x, y):
+        p.send(torch.full((), i))
+
+
+def example_data_loader_small_pipe(p: "Pipe", iters: int, shape: Tuple[int, int]):
+    t0 = time()
+    for i in range(iters):
+        if time() - t0 > 0.5:
+            p.send(torch.full(shape, -1.0))
+        else:
+            p.send(torch.full(shape, i))
+
+
+def example_echo_add(p: "Pipe"):
+    while True:
+        p.send(p.recv() + 1 + p.ranks["gpu"])
+
+
+def log(*args, **kwargs):
+    logger.info(*args, **kwargs)
+
+
+def remote_sleep(t: float):
+    sleep(t)
+
+
+def has_nan(t):
+    return torch.isnan(t).any().item()
+
+
+def new_barrier_hackery(threads):
+    global _barrier
+    _barrier = threading.Barrier(threads)
+    return torch.zeros(1)
+
+
+def wait_barrier_hackery(t: torch.Tensor):
+    # pyre-fixme[10]: Name `_barrier` is used but not defined.
+    _barrier.wait()
+
+
+def all_reduce_prop(tensor, *args, **kwargs):
+    tensor.add_(1)
+    return tensor
+
+
+@remote(propagate=all_reduce_prop)
+def all_reduce(tensor, group=None, op=dist.ReduceOp.SUM):
+    dist.all_reduce(tensor, op=op, group=group)
+    return tensor
+
+
+@remote(propagate=lambda *args, **kwargs: torch.ones(1))
+def barrier(group=None, device_ids=None):
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    dist.barrier(group=group, async_op=False, device_ids=device_ids)
+    return torch.ones(1)
+
+
+@remote(
+    propagate=lambda tensor_list, *args, **kwargs: [
+        torch.zeros_like(t) for t in tensor_list
+    ]
+)
+def all_gather(
+    tensor_list: list[torch.Tensor],
+    tensor: torch.Tensor,
+    group=None,
+) -> list[torch.Tensor]:
+    dist.all_gather(tensor_list, tensor, group=group, async_op=False)
+    return tensor_list
+
+
+@remote(propagate=lambda output_tensor, input_tensor, group=None: torch.zeros(1))
+def all_gather_into_tensor(output_tensor, input_tensor, group=None):
+    dist.all_gather_into_tensor(output_tensor, input_tensor, group=group)
+    return torch.ones(1)
+
+
+@remote(propagate=lambda t, *args, **kwargs: torch.ones(1))
+def isend(t, destination, group=None):
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    req = dist.isend(t, destination.item(), group=group)
+    assert isinstance(req.is_completed(), bool)
+    req.wait()
+    return torch.ones(1)
+
+
+def irecv_prop(t, src, group=None):
+    # irecv mutates its input.
+    t.add_(1)
+    return torch.ones(1)
+
+
+@remote(propagate=irecv_prop)
+def irecv(t, src, group=None):
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    req = dist.irecv(tensor=t, src=src.item(), group=group)
+    assert isinstance(req.is_completed(), bool)
+    req.wait()
+    return torch.ones(1)
+
+
+def gonna_pdb():
+    x = 3 + 4
+    import pdb  # noqa
+
+    pdb.set_trace()
+    print(x)
+
+
+def do_some_processing(a_string):
+    return a_string + " processed"
+
+
+def how_many_of_these_do_you_want(n: int, t: torch.Tensor):
+    return [t + i for i in range(n)]
+
+
+def remote_chunk(t: torch.Tensor):
+    return t.chunk(4, dim=0)
+
+
+class TestRemoteAutogradFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, y):
+        ctx.save_for_backward(x)
+        if x.requires_grad:
+            out0 = x * y
+        else:
+            out0 = x + y
+
+        return out0, y, torch.ones(4), 4
+
+    @staticmethod
+    def backward(ctx, dx1, dx2, dx3, dx4):
+        return dx1, dx2
+
+
+class _TestMultiplyAllReduce(torch.autograd.Function):
+    "Existing user autograd.Function"
+
+    @staticmethod
+    def forward(ctx, x, y, pg):
+        wa = torch.rand(x.shape, device=x.device)
+        ctx.save_for_backward(x, y, wa)
+        ctx.my_property = True
+        ctx.pg = pg
+        z = x * y
+        dist.all_reduce(z, op=dist.ReduceOp.SUM, group=pg)
+        return z
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, y, a = ctx.saved_tensors
+        assert ctx.my_property
+        grad_x = grad_output * y
+        grad_y = grad_output * x * a
+        dist.all_reduce(grad_x, op=dist.ReduceOp.SUM, group=ctx.pg)
+        dist.all_reduce(grad_y, op=dist.ReduceOp.SUM, group=ctx.pg)
+        return grad_x, grad_y, None
+
+
+class SimpleModel(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(SimpleModel, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        return x
+
+
+def setup_state_worker():
+    input_size = 10
+    hidden_size = 20
+    output_size = 1
+    batch_size = 16
+    learning_rate = 0.01
+
+    x = torch.rand(100, input_size).cuda()
+    y = torch.rand(100, output_size).cuda()
+    dataset = TensorDataset(x, y)
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    model = SimpleModel(input_size, hidden_size, output_size).cuda()
+    criterion = nn.MSELoss().cuda()
+    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
+
+    return [
+        opaque_ref.OpaqueRef(obj) for obj in [model, dataloader, criterion, optimizer]
+    ]
+
+
+def iteration_worker(model_ref, dataloader_ref, criterion_ref, optimizer_ref, pg):
+    model = model_ref.value
+    dataloader = dataloader_ref.value
+    criterion = criterion_ref.value
+    optimizer = optimizer_ref.value
+
+    epoch_loss = 0.0
+    for inputs, targets in dataloader:
+        outputs = model(inputs)
+        loss = criterion(outputs, targets)
+
+        optimizer.zero_grad()
+        loss.backward()
+        for param in model.parameters():
+            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pg)
+        optimizer.step()
+
+        epoch_loss += loss.item()
+    return torch.tensor(epoch_loss)
+
+
+def create_opaque_ref_worker():
+    return opaque_ref.OpaqueRef(nn.Linear(10, 10))
+
+
+def opaque_ref_key_table_length_worker() -> torch.Tensor:
+    return torch.tensor(len(list(opaque_ref._key_table.keys())))
+
+
+class WorkerFoo:
+    def __init__(self, v):
+        self.t = torch.full((), v)
+
+    def add(self, b):
+        return self.t + b
+
+
+def reduce_prop(tensor, *args, **kwargs):
+    return tensor.add_(1)
+
+
+@remote(propagate=reduce_prop)
+def reduce(
+    tensor: torch.Tensor,
+    dst: int | None = None,
+    op: dist.ReduceOp = dist.ReduceOp.SUM,
+    group=None,
+    group_dst: int | None = None,
+) -> torch.Tensor:
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    dist.reduce(tensor, dst, op=op, group=group, async_op=False, group_dst=group_dst)
+    return tensor
+
+
+def reduce_scatter_prop(output, *args, **kwargs):
+    # reduce_scatter mutates its input argument.
+    output.add_(1)
+    return output
+
+
+@remote(propagate=reduce_scatter_prop)
+def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None):
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    dist.reduce_scatter(output, input_list, op=op, group=group, async_op=False)
+    return output
+
+
+def reduce_scatter_tensor_prop(tensor, *args, **kwargs):
+    # reduce_scatter_tensor mutates its input argument.
+    tensor.add_(1)
+    return tensor
+
+
+@remote(propagate=reduce_scatter_tensor_prop)
+def reduce_scatter_tensor(
+    output_tensor, input_tensor, group=None, op=dist.ReduceOp.SUM
+):
+    dist.reduce_scatter_tensor(output_tensor, input_tensor, group=group, op=op)
+    return output_tensor
+
+
+def gather_prop(tensor, gather_list=None, *args, **kwargs) -> torch.Tensor:
+    # Gather mutates its gather_list and does not modify the input tensor.
+    if gather_list is not None:
+        for t in gather_list:
+            t.add_(1)
+    return torch.zeros_like(tensor)
+
+
+@remote(propagate=gather_prop)
+def gather(
+    tensor: torch.Tensor,
+    gather_list: list[torch.Tensor] | None = None,
+    dst: int | None = None,
+    group=None,
+    group_dst: int | None = None,
+) -> torch.Tensor:
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    if group_dst is not None:
+        if group_dst != dist.get_rank(group):
+            # Don't set the gather_list on any rank other than the source.
+            gather_list = None
+    elif dst is not None:
+        if dst != dist.get_rank(group):
+            # Don't set the gather_list on any rank other than the source.
+            gather_list = None
+    dist.gather(
+        tensor,
+        gather_list=gather_list,
+        dst=dst,
+        group=group,
+        async_op=False,
+        group_dst=group_dst,
+    )
+    return tensor
+
+
+# Scatter mutates its input tensor.
+@remote(propagate=lambda tensor, *args, **kwargs: tensor.add_(1))
+def scatter(
+    tensor: torch.Tensor,
+    scatter_list: list[torch.Tensor] | None = None,
+    src: int | None = None,
+    group=None,
+    group_src: int | None = None,
+) -> torch.Tensor:
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    if group_src is not None:
+        if group_src != dist.get_rank(group):
+            # Don't set the scatter_list on any rank other than the source.
+            scatter_list = None
+    elif src is not None:
+        if src != dist.get_rank(group):
+            # Don't set the scatter_list on any rank other than the source.
+            scatter_list = None
+    dist.scatter(
+        tensor,
+        scatter_list=scatter_list,
+        src=src,
+        group=group,
+        async_op=False,
+        group_src=group_src,
+    )
+    return tensor
+
+
+def inner_remote_function_that_fails():
+    raise Exception("Failed to execute inner_remote_function_that_fails")
+
+
+def outer_remote_function_that_calls_inner():
+    inner_remote_function_that_fails()
+    return torch.zeros(1)
+
+
+def broadcast_prop(tensor, *args, **kwargs) -> torch.Tensor:
+    # Broadcast mutates its input tensor.
+    return tensor.add_(1)
+
+
+@remote(propagate=broadcast_prop)
+def broadcast(
+    tensor: torch.Tensor,
+    src: int | None = None,
+    group=None,
+    group_src: int | None = None,
+) -> torch.Tensor:
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    dist.broadcast(tensor, src=src, group=group, async_op=False, group_src=group_src)
+    return tensor
+
+
+def all_to_all_prop(
+    output_tensor_list: list[torch.Tensor],
+    input_tensor_list: list[torch.Tensor],
+    *args,
+    **kwargs,
+) -> list[torch.Tensor]:
+    for t in output_tensor_list:
+        # Mutate the output tensors to ensure that fetches on the output tensor
+        # list are propagated.
+        t.add_(1)
+    return output_tensor_list
+
+
+@remote(propagate=all_to_all_prop)
+def all_to_all(
+    output_tensor_list: list[torch.Tensor],
+    input_tensor_list: list[torch.Tensor],
+    group=None,
+) -> list[torch.Tensor]:
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    dist.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=False)
+    return output_tensor_list
+
+
+def all_to_all_single_prop(output_tensor, *args, **kwargs) -> torch.Tensor:
+    # Mutate the output tensor to ensure that fetches on the output tensor
+    # are propagated.
+    output_tensor.add_(1)
+    return output_tensor
+
+
+@remote(propagate=all_to_all_single_prop)
+def all_to_all_single(
+    output_tensor: torch.Tensor, input_tensor: torch.Tensor, group=None
+) -> torch.Tensor:
+    if isinstance(group, SingleControllerProcessGroupWrapper):
+        group = group.process_group
+    dist.all_to_all_single(output_tensor, input_tensor, group=group)
+    return output_tensor
+
+
+def test_pdb_actor():
+    pdb_actor = PdbActor()
+    pdb_actor.send(DebuggerAction.Paused())
+    assert isinstance(pdb_actor.receive(), DebuggerAction.Attach)
+    pdb_actor.send(DebuggerAction.Read(4))
+    msg = pdb_actor.receive()
+    assert isinstance(msg, DebuggerAction.Write)
+    assert get_bytes_from_write_action(msg) == b"1234"
+    pdb_actor.send(DebuggerAction.Write(b"5678"))
+    assert isinstance(pdb_actor.receive(), DebuggerAction.Detach)
+    return torch.zeros(1)
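
Note on the recurring pattern in this file: each collective wrapper pairs a plain "propagation" function with the real worker body. The `propagate=` callable only mimics the collective's shape and in-place mutation behavior (returning a placeholder tensor and calling `add_(1)` on whatever the real call would mutate), while the decorated body unwraps `SingleControllerProcessGroupWrapper` and performs the actual `torch.distributed` call on the workers. The sketch below illustrates that pattern with one more hypothetical wrapper; the `send`/`send_prop` names are illustrative only and are not part of the released wheel.

# Hypothetical example following the propagate pattern used in
# monarch/worker/_testing_function.py (not part of the package).
import torch
import torch.distributed as dist

from monarch.common.process_group import SingleControllerProcessGroupWrapper
from monarch.common.remote import remote


def send_prop(tensor, dst, group=None):
    # send does not mutate its arguments, so the stand-in only needs to
    # return a placeholder of a known shape.
    return torch.ones(1)


@remote(propagate=send_prop)
def send(tensor, dst, group=None):
    # Worker-side body: unwrap the single-controller process group wrapper
    # and perform the real point-to-point send.
    if isinstance(group, SingleControllerProcessGroupWrapper):
        group = group.process_group
    dist.send(tensor, dst, group=group)
    return torch.ones(1)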