torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
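The file reproduced in full below is tests/test_remote_functions.py (+1271 lines, entry 148 in the list above). It exercises monarch's remote-function pattern: a worker-side function is referenced by its dotted import path and paired with a client-side `propagate` callable that stands in for the real computation (in the tests it simply returns zero tensors of the expected shape). A minimal sketch of that pattern, using the same `remote` and `fetch_shard` APIs the file imports; the worker-side path shown here is hypothetical:

import torch
from monarch import fetch_shard, remote

# Client-side handle for a worker function, referenced by import path.
# The propagate callable never runs the real computation; it only reports
# the shape/dtype the remote call will produce.
scale = remote(
    "my_project.workers.scale_tensor",  # hypothetical worker-side function
    propagate=lambda t, factor: torch.zeros_like(t),
)

# Inside an active device mesh (see local_device_mesh in the tests), calls
# return placeholder tensors; fetch_shard brings one shard back locally:
#     x = torch.rand(3, 4)
#     y = scale(x, 2.0)
#     local_y = fetch_shard(y).result()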
tests/test_remote_functions.py
@@ -0,0 +1,1271 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import itertools
9
+ import math
10
+ import sys
11
+ import traceback
12
+ from enum import Enum
13
+ from typing import Callable, ContextManager, Tuple
14
+ from unittest.mock import patch
15
+
16
+ import monarch
17
+ import pytest
18
+
19
+ import torch
20
+ from monarch import (
21
+ fetch_shard,
22
+ inspect,
23
+ no_mesh,
24
+ OpaqueRef,
25
+ Pipe,
26
+ remote,
27
+ remote_generator,
28
+ RemoteException,
29
+ Stream,
30
+ )
31
+ from monarch._testing import BackendType, TestingContext
32
+ from monarch.builtins.log import log_remote
33
+ from monarch.builtins.random import set_manual_seed_remote
34
+ from monarch.cached_remote_function import remote_autograd_function
35
+ from monarch.common import remote as remote_module
36
+ from monarch.common.device_mesh import DeviceMesh
37
+ from monarch.common.remote import Remote
38
+
39
+ from monarch.opaque_module import OpaqueModule
40
+ from monarch.opaque_object import opaque_method, OpaqueObject
41
+ from monarch.worker._testing_function import (
42
+ all_gather,
43
+ all_gather_into_tensor,
44
+ all_reduce,
45
+ all_to_all,
46
+ all_to_all_single,
47
+ barrier,
48
+ broadcast,
49
+ gather,
50
+ irecv,
51
+ isend,
52
+ reduce,
53
+ reduce_scatter,
54
+ reduce_scatter_tensor,
55
+ scatter,
56
+ )
57
+ from monarch_supervisor.logging import fix_exception_lines
58
+ from torch.distributed import ReduceOp
59
+
60
+
61
+ def custom_excepthook(exc_type, exc_value, exc_traceback):
62
+ tb_lines = fix_exception_lines(
63
+ traceback.format_exception(exc_type, exc_value, exc_traceback)
64
+ )
65
+ print("\n".join(tb_lines), file=sys.stderr)
66
+
67
+
68
+ sys.excepthook = custom_excepthook
69
+
70
+
71
+ def _set_device_udf(*args):
72
+ return torch.zeros(1)
73
+
74
+
75
+ set_device_udf = remote(
76
+ "monarch.worker._testing_function.set_device_udf_worker", propagate=_set_device_udf
77
+ )
78
+
79
+ rlist = remote("builtins.list", propagate=lambda elem: elem)
80
+
81
+
82
+ def _do_bogus_tensor_work(x, y, fail_rank=None):
83
+ return x + y # real function actually does x @ y
84
+
85
+
86
+ do_bogus_tensor_work = remote(
87
+ "monarch.worker._testing_function.do_bogus_tensor_work",
88
+ propagate=_do_bogus_tensor_work,
89
+ )
90
+
91
+
92
+ @remote_generator("monarch.worker._testing_function.example_echo_add")
93
+ def example_echo_add(p: "Pipe"):
94
+ while True:
95
+ yield p.recv() + 1
96
+
97
+
98
+ @remote_generator("monarch.worker._testing_function.example_data_loader")
99
+ def example_data_loader(p: "Pipe", x, y):
100
+ for _i in range(x, y):
101
+ yield torch.zeros(())
102
+
103
+
104
+ @remote_generator(
105
+ "monarch.worker._testing_function.example_data_loader_small_pipe",
106
+ max_messages=1,
107
+ )
108
+ def example_data_loader_small_pipe(p: "Pipe", iters: int, shape: Tuple[int, int]):
109
+ for _i in range(iters):
110
+ yield torch.zeros(shape)
111
+
112
+
113
+ sleep = remote("monarch.worker._testing_function.remote_sleep", propagate="inspect")
114
+
115
+ new_barrier_hackery = remote(
116
+ "monarch.worker._testing_function.new_barrier_hackery",
117
+ propagate=lambda threads: torch.zeros(1),
118
+ )
119
+
120
+ wait_barrier_hackery = remote(
121
+ "monarch.worker._testing_function.wait_barrier_hackery",
122
+ propagate=lambda t: None,
123
+ )
124
+
125
+ setup_state = remote(
126
+ "monarch.worker._testing_function.setup_state_worker",
127
+ propagate=lambda: [OpaqueRef(None) for _ in range(4)],
128
+ )
129
+
130
+ iteration = remote(
131
+ "monarch.worker._testing_function.iteration_worker",
132
+ propagate=lambda model, dataloader, criterion, optimizer, pg: torch.zeros(1),
133
+ )
134
+
135
+ opaque_ref_key_table_length = remote(
136
+ "monarch.worker._testing_function.opaque_ref_key_table_length_worker",
137
+ propagate=lambda: torch.zeros(1),
138
+ )
139
+
140
+ create_opaque_ref = remote(
141
+ "monarch.worker._testing_function.create_opaque_ref_worker",
142
+ propagate=lambda: OpaqueRef(None),
143
+ )
144
+
145
+ outer_remote_function_that_calls_inner = remote(
146
+ "monarch.worker._testing_function.outer_remote_function_that_calls_inner",
147
+ propagate=lambda: torch.zeros(1),
148
+ )
149
+
150
+
151
+ @pytest.fixture(scope="module", autouse=True)
152
+ def testing_context():
153
+ global local
154
+ with TestingContext() as local:
155
+ yield
156
+
157
+
158
+ class RemoteFunctionsTestBase:
159
+ @classmethod
160
+ def local_device_mesh(
161
+ cls,
162
+ num_hosts: int,
163
+ gpu_per_host: int,
164
+ backend_type: BackendType,
165
+ activate: bool = True,
166
+ ) -> ContextManager[DeviceMesh]:
167
+ # pyre-fixme[10]: pytest defines this fixture.
168
+ return local.local_device_mesh(
169
+ num_hosts,
170
+ gpu_per_host,
171
+ activate,
172
+ rust=backend_type == BackendType.RS,
173
+ )
174
+
175
+
176
+ @pytest.mark.skipif(
177
+ torch.cuda.device_count() < 2,
178
+ reason="Not enough GPUs, this test requires at least 2 GPUs",
179
+ )
180
+ # Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
181
+ # out is not counted as a failure, so we set a more restrictive timeout to
182
+ # ensure we see a hard failure in CI.
183
+ @pytest.mark.timeout(120)
184
+ @pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
185
+ class TestRemoteFunctions(RemoteFunctionsTestBase):
186
+ @classmethod
187
+ def do_test_reduce_scatter_tensor(cls, backend_type, reduce_op, expected_tensor):
188
+ n_gpus = 2
189
+ with cls.local_device_mesh(2, n_gpus, backend_type) as device_mesh:
190
+ rank = device_mesh.rank("host") * n_gpus + device_mesh.rank("gpu")
191
+ tensor_in = rank * torch.arange(0, 8, device="cuda", dtype=float).reshape(
192
+ 4, 2
193
+ )
194
+ tensor_out = torch.arange(2, device="cuda", dtype=float)
195
+ pg = device_mesh.process_group(("host", "gpu"))
196
+
197
+ reduce_scatter_tensor(tensor_out, tensor_in, op=reduce_op, group=pg)
198
+
199
+ for host in range(2):
200
+ for gpu in range(n_gpus):
201
+ rank = 2 * host + gpu
202
+ local_tensor_out = inspect(tensor_out, {"host": host, "gpu": gpu})
203
+ with no_mesh.activate():
204
+ assert torch.equal(
205
+ local_tensor_out,
206
+ expected_tensor[rank],
207
+ )
208
+
209
+ @classmethod
210
+ def do_test_reduce_scatter_tensor_subgroup(
211
+ cls,
212
+ backend_type: BackendType,
213
+ reduce_op,
214
+ expected_tensor_host_group: torch.Tensor,
215
+ expected_tensor_gpu_group: torch.Tensor,
216
+ ) -> None:
217
+ n_gpus = 2
218
+ with cls.local_device_mesh(2, n_gpus, backend_type) as device_mesh:
219
+ # Use a group smaller than the world size.
220
+ host_pg = device_mesh.process_group("host")
221
+ gpu_pg = device_mesh.process_group("gpu")
222
+ # host_rank = device_mesh.rank("host")
223
+ # gpu_rank = device_mesh.rank("gpu")
224
+ rank = device_mesh.rank(("host", "gpu"))
225
+
226
+ tensor_in = rank * torch.arange(
227
+ 0, 8, device="cuda", dtype=torch.float32
228
+ ).reshape(4, 2)
229
+
230
+ gpu_tensor_out = torch.zeros(4, device="cuda", dtype=torch.float32)
231
+ reduce_scatter_tensor(gpu_tensor_out, tensor_in, op=reduce_op, group=gpu_pg)
232
+
233
+ tensor_in = rank * torch.arange(
234
+ 0, 8, device="cuda", dtype=torch.float32
235
+ ).reshape(4, 2)
236
+ host_tensor_out = torch.zeros(4, device="cuda", dtype=torch.float32)
237
+ reduce_scatter_tensor(
238
+ host_tensor_out, tensor_in, op=reduce_op, group=host_pg
239
+ )
240
+
241
+ for host in range(2):
242
+ for gpu in range(n_gpus):
243
+ rank = host * 2 + gpu
244
+ local_gpu_tensor_out = inspect(
245
+ gpu_tensor_out, {"host": host, "gpu": gpu}
246
+ )
247
+ local_host_tensor_out = inspect(
248
+ host_tensor_out, {"host": host, "gpu": gpu}
249
+ )
250
+ with no_mesh.activate():
251
+ assert torch.equal(
252
+ local_host_tensor_out,
253
+ expected_tensor_host_group[rank],
254
+ ), f"{rank=}, {host=}, {gpu=}"
255
+ assert torch.equal(
256
+ local_gpu_tensor_out,
257
+ expected_tensor_gpu_group[rank],
258
+ ), f"{rank=}, {host=}, {gpu=}"
259
+
260
+ @classmethod
261
+ def do_test_reduce_scatter(
262
+ cls,
263
+ backend_type: BackendType,
264
+ reduce_op: ReduceOp,
265
+ expected_tensor: torch.Tensor,
266
+ ) -> None:
267
+ n_gpus = 2
268
+ with cls.local_device_mesh(2, n_gpus, backend_type) as device_mesh:
269
+ rank = device_mesh.rank("host") * n_gpus + device_mesh.rank("gpu")
270
+ tensor_in = rank * torch.arange(0, 8, device="cuda", dtype=torch.float32)
271
+ tensor_out = torch.arange(2, device="cuda", dtype=torch.float32)
272
+ pg = device_mesh.process_group(("host", "gpu"))
273
+
274
+ tensor_out = reduce_scatter(
275
+ tensor_out,
276
+ list(torch.chunk(tensor_in, 2 * n_gpus)),
277
+ op=reduce_op,
278
+ group=pg,
279
+ )
280
+
281
+ for host in range(2):
282
+ for gpu in range(n_gpus):
283
+ rank = 2 * host + gpu
284
+ local_tensor_out = inspect(tensor_out, {"host": host, "gpu": gpu})
285
+ with no_mesh.activate():
286
+ assert torch.equal(
287
+ local_tensor_out,
288
+ expected_tensor[rank],
289
+ )
290
+
291
+ @classmethod
292
+ def do_test_all_reduce(cls, backend_type, reduce_op, expected_tensor):
293
+ n_gpus = 2
294
+ with cls.local_device_mesh(2, n_gpus, backend_type) as device_mesh:
295
+ rank = device_mesh.rank(("host", "gpu"))
296
+ tensor_in = rank * torch.arange(0, 8, device="cuda", dtype=float).reshape(
297
+ 4, 2
298
+ )
299
+ pg = device_mesh.process_group(("host", "gpu"))
300
+
301
+ tensor_out = all_reduce(tensor_in, op=reduce_op, group=pg)
302
+
303
+ for host in range(2):
304
+ for gpu in range(n_gpus):
305
+ local_tensor_out = inspect(tensor_out, {"host": host, "gpu": gpu})
306
+ with no_mesh.activate():
307
+ assert torch.equal(
308
+ local_tensor_out,
309
+ expected_tensor,
310
+ )
311
+
312
+ def test_hello(self, backend_type):
313
+ with self.local_device_mesh(2, 2, backend_type):
314
+ log_remote("hello, world")
315
+
316
+ def test_eager_remote_function_failed(self, backend_type):
317
+ if backend_type == BackendType.PY:
318
+ pytest.skip("Python support not planned for this test")
319
+ with self.local_device_mesh(1, 2, backend_type) as _:
320
+ x = torch.rand(3, 4)
321
+ y = torch.rand(3, 4)
322
+ z = do_bogus_tensor_work(x, y, fail_rank=1)
323
+ a = z + x
324
+ with pytest.raises(RemoteException, match="do_bogus_tensor_work"):
325
+ # NCCL init is slow, and fails on internal RE!
326
+ _ = fetch_shard(a).result(timeout=40)
327
+
328
+ def test_set_device_inside_udf_fails_with_explanation(self, backend_type):
329
+ if backend_type == BackendType.PY:
330
+ pytest.skip("Python support not planned for this test")
331
+ with self.local_device_mesh(2, 2, backend_type):
332
+ t = set_device_udf(2)
333
+ try:
334
+ inspect(t)
335
+ except RemoteException as e:
336
+ backtrace = "\n".join([frame.name for frame in e.worker_frames])
337
+ assert "are available to monarch worker" in backtrace
338
+
339
+ def test_simple_tensors(self, backend_type):
340
+ with self.local_device_mesh(2, 2, backend_type):
341
+ x = torch.rand(3, 4)
342
+ y = x + x
343
+ log_remote("%s %s", x, y)
344
+ z = torch.std_mean(x)
345
+ log_remote("%s", z)
346
+
347
+ def test_user_call(self, backend_type):
348
+ with self.local_device_mesh(2, 2, backend_type) as _:
349
+ x = torch.rand(3, 4)
350
+ y = rlist((x + 1, x))
351
+ log_remote("%s", y)
352
+
353
+ # resume monday:
354
+ # 1. tensor ctor resource guard (done)
355
+ # 2. __torch_dispatch__ forward of normal ops (done)
356
+ # 3. collectives created for device mesh
357
+ # 4. implement comms APIs
358
+ # 5. transfer tensor back, and simple future to wait for result.
359
+
360
+ def test_remote_function_with_comms_full_mesh(self, backend_type):
361
+ nGPUs = 2
362
+ with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
363
+ pg = device_mesh.process_group(("host", "gpu"))
364
+ myrank = (
365
+ (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
366
+ )
367
+ x = torch.ones((3, 4), device="cuda") * myrank
368
+
369
+ reduce = all_reduce(x, group=pg)
370
+ local_reduce = fetch_shard(reduce).result()
371
+ assert torch.equal(local_reduce, torch.ones(3, 4) * 18)
372
+
373
+ def test_remote_function_with_comms_by_dimension(self, backend_type):
374
+ nGPUs = 2
375
+ with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
376
+ pg = device_mesh.process_group(("gpu",))
377
+ myrank = (
378
+ (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
379
+ )
380
+ x = torch.ones((3, 4), device="cuda") * myrank
381
+ reduce = all_reduce(x, group=pg)
382
+ local_reduce_host_0 = fetch_shard(reduce).result()
383
+ local_reduce_host_1 = fetch_shard(reduce, {"gpu": 1, "host": 1}).result()
384
+ assert torch.equal(local_reduce_host_0, torch.ones(3, 4) * 7)
385
+ assert torch.equal(local_reduce_host_1, torch.ones(3, 4) * 11)
386
+
387
+ with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
388
+ pg = device_mesh.process_group(("host",))
389
+ myrank = (
390
+ (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
391
+ )
392
+ x = torch.ones((3, 4), device="cuda") * myrank
393
+ reduce = all_reduce(x, group=pg)
394
+ local_reduce_gpu_0 = fetch_shard(reduce).result()
395
+ local_reduce_gpu_2 = fetch_shard(reduce, {"gpu": 1, "host": 0}).result()
396
+ assert torch.equal(local_reduce_gpu_0, torch.ones(3, 4) * 8)
397
+
398
+ assert torch.equal(local_reduce_gpu_2, torch.ones(3, 4) * 10)
399
+
400
+ def test_remote_function_with_comms_sub_mesh(self, backend_type):
401
+ nGPUs = 2
402
+ with self.local_device_mesh(
403
+ 2, nGPUs, backend_type, activate=False
404
+ ) as device_mesh:
405
+ host1 = device_mesh(host=1)
406
+ with host1.activate():
407
+ pg = device_mesh.process_group(("gpu",))
408
+ myrank = (
409
+ (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
410
+ )
411
+ x = torch.ones((3, 4), device="cuda") * myrank
412
+ reduce = all_reduce(x, group=pg)
413
+ local_reduce = fetch_shard(reduce).result()
414
+
415
+ assert torch.equal(local_reduce, torch.ones(3, 4) * 11)
416
+
417
+ host0 = device_mesh(host=0)
418
+ with host0.activate():
419
+ pg = device_mesh.process_group(("gpu",))
420
+ myrank = (
421
+ (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
422
+ )
423
+ x = torch.ones((3, 4), device="cuda") * myrank
424
+ reduce = all_reduce(x, group=pg)
425
+ local_reduce = fetch_shard(reduce).result()
426
+
427
+ assert torch.equal(local_reduce, torch.ones(3, 4) * 7)
428
+
429
+ def test_remote_exception(self, backend_type):
430
+ with self.local_device_mesh(2, 2, backend_type) as _:
431
+ x = torch.rand(3, 4)
432
+ y = torch.rand(3, 4)
433
+ z = do_bogus_tensor_work(x, y)
434
+ a = z + x
435
+ b = x + y
436
+ with pytest.raises(RemoteException, match="do_bogus_tensor_work"):
437
+ # NCCL init is slow, and fails on internal RE!
438
+ _ = fetch_shard(a).result(timeout=20)
439
+ # but values not dependent on z are fine
440
+ fetch_shard(b).result(timeout=10)
441
+
442
+ def test_remote_function_barrier(self, backend_type):
443
+ if backend_type == BackendType.PY:
444
+ pytest.skip("FIXME: Python support for this function")
445
+ nGPUs = 2
446
+ with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
447
+ pg = device_mesh.process_group(("host", "gpu"))
448
+ finished = barrier(group=pg)
449
+ local = fetch_shard(finished).result()
450
+ assert local.item() == 1.0
451
+
452
+ def test_remote_function_all_gather(self, backend_type: BackendType) -> None:
453
+ nGPUs = 2
454
+ with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
455
+ myrank = (
456
+ (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
457
+ )
458
+ # Don't start at zero to ensure there are no leftover zeros.
459
+ tensor_in = torch.arange(1, 3, device="cuda") * myrank
460
+ world_size = 2 * nGPUs
461
+ tensor_out = list(
462
+ torch.zeros(2 * world_size, dtype=torch.int64, device="cuda").chunk(
463
+ world_size
464
+ )
465
+ )
466
+ pg = device_mesh.process_group(("host", "gpu"))
467
+
468
+ tensor_out = all_gather(tensor_out, tensor_in, group=pg)
469
+ local_tensor_out = inspect(tensor_out)
470
+
471
+ t0, t1, t2, t3 = local_tensor_out
472
+ assert torch.equal(t0, torch.tensor([3, 6]))
473
+ assert torch.equal(t1, torch.tensor([4, 8]))
474
+ assert torch.equal(t2, torch.tensor([5, 10]))
475
+ assert torch.equal(t3, torch.tensor([6, 12]))
476
+
477
+ def test_remote_function_all_gather_into_tensor(self, backend_type):
478
+ nGPUs = 2
479
+ with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
480
+ myrank = (
481
+ (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
482
+ )
483
+ # Don't start at zero to ensure there are no leftover zeros.
484
+ tensor_in = torch.arange(1, 3, device="cuda") * myrank
485
+ tensor_out = torch.zeros(2 * nGPUs * 2, dtype=torch.int64, device="cuda")
486
+ pg = device_mesh.process_group(("host", "gpu"))
487
+
488
+ finished = all_gather_into_tensor(tensor_out, tensor_in, group=pg)
489
+ local_finished = inspect(finished)
490
+ local_tensor_out = inspect(tensor_out)
491
+
492
+ assert local_finished.item() == 1.0
493
+ assert torch.equal(local_tensor_out, torch.tensor([3, 6, 4, 8, 5, 10, 6, 12]))
494
+
495
+ def test_remote_function_isend(self, backend_type):
496
+ nGPUs = 2
497
+ with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
498
+ pg = device_mesh.process_group(("host",))
499
+ host_0_mesh = device_mesh(host=0)
500
+ host_1_mesh = device_mesh(host=1)
501
+ with host_0_mesh.activate():
502
+ to_rank = (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank(
503
+ "gpu"
504
+ )
505
+ t0 = torch.ones(1, device="cuda")
506
+ finished0 = isend(t0, to_rank, group=pg)
507
+ with host_1_mesh.activate():
508
+ from_rank = (device_mesh.rank("host") - 1) * nGPUs + device_mesh.rank(
509
+ "gpu"
510
+ )
511
+ t1 = torch.zeros(1, device="cuda")
512
+ finished1 = irecv(t1, from_rank, group=pg)
513
+
514
+ with host_0_mesh.activate():
515
+ local_finished_0 = inspect(finished0)
516
+ with host_1_mesh.activate():
517
+ local_finished_1 = inspect(finished1)
518
+ assert local_finished_0.item() == 1.0
519
+ assert local_finished_1.item() == 1.0
520
+
521
+ def test_distributed_error(self, backend_type):
522
+ with self.local_device_mesh(2, 2, backend_type) as _:
523
+ x = torch.rand(3, 4).cuda()
524
+ y = torch.rand(3, 4).cuda()
525
+ # z is broken on rank 1 but not others
526
+ z = do_bogus_tensor_work(x, y, fail_rank=1)
527
+ # test that rank 1 is still doing work despite z failing
528
+ a = (x + y).reduce("gpu")
529
+ fetch_shard(a).result()
530
+ # but z itself should fail, even if we do not fetch it from rank 1
531
+ # (since fetch shard says we first want to assert the whole tensor is correct)
532
+ with pytest.raises(RemoteException, match="do_bogus_tensor_work"):
533
+ fetch_shard(z).result()
534
+ # try to reduce z, which should fail, but ranks that are not 1 do not
535
+ # know about the failure. Rank 1 should still participate in the reduce
536
+ # to unblock work.
537
+ rz = z.reduce("gpu")
538
+ # but we should see the error message still retrieving it because it is
539
+ # dependent on an error.
540
+ with pytest.raises(RemoteException, match="do_bogus_tensor_work"):
541
+ fetch_shard(rz).result()
542
+ # however, we should still be able to compute and get a result back
543
+ # from host 1, signaling that the reduction didn't get cuda compute stuck.
544
+ fetch_shard(2 * x, gpu=1, host=0).result()
545
+
546
+ def test_pipe(self, backend_type):
547
+ with self.local_device_mesh(2, 2, backend_type):
548
+ p = example_echo_add()
549
+ for _i in range(10):
550
+ x = torch.rand(3, 4)
551
+ p.send(x)
552
+ y = p.recv()
553
+ x, y = fetch_shard((x, y)).result()
554
+ with no_mesh.activate():
555
+ assert torch.allclose(x + 1, y)
556
+
557
+ def test_loader(self, backend_type):
558
+ with self.local_device_mesh(2, 2, backend_type):
559
+ p = example_data_loader(3, 7)
560
+ for i in range(3, 7):
561
+ x = fetch_shard(p.recv()).result()
562
+ with no_mesh.activate():
563
+ assert x.item() == i
564
+
565
+ def test_loader_blocks_with_small_pipe(self, backend_type):
566
+ with self.local_device_mesh(2, 2, backend_type):
567
+ iters = 10
568
+ p = example_data_loader_small_pipe(iters, (1000, 1000))
569
+ # timeout should proc on pipe process
570
+ sleep(0.6)
571
+ # it takes a few iters of reasonably sized tensors to fill up OS buffer
572
+ # max_messages (SNDHWM) only affects the zmq buffer
573
+ for _ in range(iters - 1):
574
+ p.recv()
575
+ t = fetch_shard(p.recv()).result()
576
+ assert t[0][0].item() == -1.0
577
+
578
+ def test_streams_run_parallel(self, backend_type):
579
+ with self.local_device_mesh(2, 2, backend_type):
580
+ # test that these two streams do in fact run in parallel
581
+ # on the worker by having each stream wait on a barrier.
582
+ # The Tensor t is just used as a data-dependency so that
583
+ # we can make sure new_barrier_hackery is called before
584
+ # the wait on 'other'.
585
+ other = Stream("other")
586
+ t = new_barrier_hackery(2)
587
+ t_other, borrow = other.borrow(t)
588
+ with borrow:
589
+ with other.activate():
590
+ wait_barrier_hackery(t_other)
591
+ wait_barrier_hackery(t)
592
+ fetch_shard(t).result()
593
+
594
+ def test_debug(self, backend_type):
595
+ gonna_pdb = remote(
596
+ "monarch.worker._testing_function.gonna_pdb", propagate="inspect"
597
+ )
598
+
599
+ with self.local_device_mesh(2, 2, backend_type):
600
+ writes = []
601
+
602
+ def dw(s):
603
+ writes.append(s)
604
+
605
+ def dr(n):
606
+ buffer = "".join(["print(x)\n", "c\n"]).encode()
607
+ assert len(buffer) <= n
608
+ return buffer
609
+
610
+ if backend_type == BackendType.RS:
611
+ patch_read = patch(
612
+ "monarch.controller.rust_backend.controller.debugger_read", new=dr
613
+ )
614
+ patch_write = patch(
615
+ "monarch.controller.rust_backend.controller.debugger_write", new=dw
616
+ )
617
+ else:
618
+ patch_read = patch("monarch.controller.debugger.read", new=dr)
619
+ patch_write = patch("monarch.controller.debugger.write", new=dw)
620
+ with patch_read, patch_write:
621
+ gonna_pdb()
622
+ # xxx: we do not process messages from workers
623
+ # unless fetching a result
624
+ fetch_shard(None).result()
625
+ assert "".join(writes).count("7\n") == 4
626
+
627
+ def test_fetch_preprocess(self, backend_type):
628
+ with self.local_device_mesh(2, 2, backend_type):
629
+ assert (
630
+ "an argument processed"
631
+ == remote("monarch.worker._testing_function.do_some_processing")
632
+ .call_on_shard_and_fetch(
633
+ "an argument",
634
+ )
635
+ .result()
636
+ )
637
+
638
+ def test_cached_remote_function(self, backend_type):
639
+ fn = remote("monarch.worker._testing_function.how_many_of_these_do_you_want")
640
+ start_hits = remote_module._hit
641
+ with self.local_device_mesh(2, 2, backend_type):
642
+ x = torch.ones(3, 4)
643
+ y = torch.rand(3, 4)
644
+
645
+ a, _, _ = fn(3, x)
646
+ b, _, _ = fn(3, x)
647
+ assert len(a._aliases.aliases) == 1
648
+ assert len(b._aliases.aliases) == 1
649
+ _, _, _ = fn(3, y)
650
+ t0, t1 = fn(2, x)
651
+ t0.add(t1)
652
+ local_a = fetch_shard(a).result()
653
+ with no_mesh.activate():
654
+ assert torch.all(local_a == 1.0)
655
+
656
+ end_hits = remote_module._hit
657
+ assert end_hits - start_hits == 2
658
+
659
+ def test_remote_autograd_function(self, backend_type):
660
+ from monarch.worker import _testing_function
661
+
662
+ remote_fn = remote_autograd_function(
663
+ _testing_function.TestRemoteAutogradFunction
664
+ )
665
+
666
+ with self.local_device_mesh(1, 1, backend_type):
667
+ x = torch.ones(1, requires_grad=True)
668
+ y = torch.ones_like(x).requires_grad_(True)
669
+ outs = remote_fn.apply(x, y)
670
+ assert outs[3] == 4
671
+ local_0 = fetch_shard(outs[0]).result()
672
+ local_1 = fetch_shard(outs[1]).result()
673
+ (outs[0] + outs[1]).sum().backward()
674
+ # unfortunately, grad_fn of local tensor is always None
675
+ # regardless of whether we set `no_grad` on the worker
676
+ # so we can test only requires_grad
677
+ for ll in (local_0, local_1):
678
+ assert not ll.requires_grad
679
+ grad_local_0 = fetch_shard(x.grad).result()
680
+ grad_local_1 = fetch_shard(x.grad).result()
681
+ x = x.detach()
682
+ x.grad = None
683
+ y.grad = None
684
+ outs = remote_fn.apply(x, y)
685
+ local_0_f = fetch_shard(outs[0]).result()
686
+ (outs[0] + outs[1]).sum().backward()
687
+ assert x.grad is None
688
+ grad_local_1_f = fetch_shard(y.grad).result()
689
+
690
+ assert torch.equal(local_0_f, torch.full_like(local_0_f, 2))
691
+ assert torch.equal(local_0, torch.ones_like(local_0))
692
+ assert torch.equal(grad_local_0, torch.ones_like(local_0))
693
+ assert torch.equal(grad_local_1, torch.ones_like(local_0))
694
+ assert torch.equal(grad_local_1_f, torch.ones_like(local_0))
695
+
696
+ def test_cached_remote_aliases(self, backend_type):
697
+ fn = remote("monarch.worker._testing_function.remote_chunk")
698
+ with self.local_device_mesh(1, 1, backend_type):
699
+ x = torch.randn(16, 5, device="cuda")
700
+ outs = fn(x)
701
+ aliases = outs[0]._aliases.aliases
702
+ # x and 4 results of x.chunk(4)
703
+ assert len(aliases) == 5
704
+ assert outs[2]._fake.storage_offset() == 40
705
+
706
+ def test_live_function(self, backend_type):
707
+ def bar(x, y):
708
+ return (
709
+ a_function_called_by_a_live_function(x)
710
+ + a_live_function_call_by_a_live_function(y)
711
+ + math.pi
712
+ )
713
+
714
+ @remote
715
+ def check(x):
716
+ return torch.allclose(x, torch.zeros(()) + math.pi + 5)
717
+
718
+ y = 7
719
+
720
+ @monarch.remote
721
+ def close():
722
+ return y
723
+
724
+ @monarch.remote
725
+ def cuda_works(x):
726
+ return x.cuda()
727
+
728
+ with self.local_device_mesh(2, 2, backend_type):
729
+ a = torch.ones(())
730
+ assert check.call_on_shard_and_fetch(bar(a, a)).result()
731
+ # ensure we do not attempt to pickle closures
732
+ close()
733
+
734
+ b = cuda_works(a)
735
+ fetch_shard(b).result()
736
+
737
+ @monarch.remote
738
+ def something_else():
739
+ raise Exception("No") # this line appears
740
+
741
+ # check that the stack trace has correct line numbers
742
+ with pytest.raises(Exception, match=r"this line appears"):
743
+ something_else()
744
+
745
+ def test_setting_random_seed(self, backend_type):
746
+ with self.local_device_mesh(2, 2, backend_type):
747
+ set_manual_seed_remote(12345)
748
+ t = torch.randn(3, 4)
749
+ t_d = torch.randn(3, 4, device="cuda")
750
+ ref = fetch_shard(t).result()
751
+ ref_d = fetch_shard(t_d).result()
752
+ vals = {
753
+ (h, d): fetch_shard(t, {"host": h, "gpu": d}).result()
754
+ for h, d in itertools.product(range(2), repeat=2)
755
+ }
756
+
757
+ vals_d = {
758
+ (h, d): fetch_shard(t_d, {"host": h, "gpu": d}).result()
759
+ for h, d in itertools.product(range(2), repeat=2)
760
+ }
761
+
762
+ for v, v_d in zip(vals.values(), vals_d.values()):
763
+ assert torch.equal(v, ref)
764
+ assert torch.equal(v_d, ref_d)
765
+
766
+ def test_return_exception(self, backend_type):
767
+ @monarch.remote
768
+ def simple():
769
+ return Exception("is a valid value to return")
770
+
771
+ with self.local_device_mesh(1, 1, backend_type):
772
+ # This should be a valid return than an exception to raise
773
+ simple.call_on_shard_and_fetch().result()
774
+
775
+ def test_opaque_object(self, backend_type):
776
+ with self.local_device_mesh(2, 2, backend_type):
777
+
778
+ class Foo(OpaqueObject):
779
+ @opaque_method
780
+ def add(self, x: torch.Tensor):
781
+ return x + x
782
+
783
+ f = Foo("monarch.worker._testing_function.WorkerFoo", 4.0)
784
+
785
+ result = monarch.inspect(f.add(torch.ones(3, 4)))
786
+ with monarch.no_mesh.activate():
787
+ assert torch.allclose(torch.full((3, 4), 5.0), result)
788
+
789
+ f.hi = 4
790
+ assert f.hi == 4
791
+
792
+ def test_opaqueRef_setup_state_and_iteration(self, backend_type):
793
+ with self.local_device_mesh(1, 2, backend_type) as mesh:
794
+ pg = mesh.process_group(("gpu",))
795
+ model, dataloader, criterion, optimizer = setup_state()
796
+ num_epochs = 5
797
+ for _ in range(num_epochs):
798
+ loss = iteration(model, dataloader, criterion, optimizer, pg)
799
+ assert inspect(loss).item() > 0
800
+
801
+ def test_opaqueRef_key_deleted(self, backend_type):
802
+ with self.local_device_mesh(1, 1, backend_type):
803
+ ref = create_opaque_ref()
804
+ assert inspect(opaque_ref_key_table_length()).item() == 1
805
+ del ref
806
+ assert inspect(opaque_ref_key_table_length()).item() == 0
807
+
808
+ def test_opaque_module(self, backend_type):
809
+ with self.local_device_mesh(2, 2, backend_type):
810
+ linear = OpaqueModule("torch.nn.Linear", 3, 3, device="cuda")
811
+ with torch.no_grad():
812
+ for p in linear.parameters():
813
+ p.zero_()
814
+ input_ = torch.rand(4, 3, device="cuda")
815
+ # we should have been able to clear the parameters and have that result
816
+ # affect how the linear works.
817
+ output = linear.call_method("forward", lambda self, x: x.clone(), input_)
818
+ assert monarch.inspect(output.sum()).item() == 0
819
+
820
+ def test_opaque_module_autograd(self, backend_type):
821
+ with self.local_device_mesh(2, 2, backend_type):
822
+ input_ = torch.rand(3, 3, device="cuda", requires_grad=True)
823
+
824
+ linear = OpaqueModule("torch.nn.Linear", 3, 3, device="cuda")
825
+ output = linear(input_, propagator=lambda self, x: x.clone())
826
+ r = output.sum()
827
+ with torch.no_grad():
828
+ r.backward()
829
+
830
+ weight, bias = linear.parameters()
831
+ ig0, wg0, bg0 = monarch.inspect((input_.grad, weight.grad, bias.grad))
832
+
833
+ input_.grad = None
834
+ weight.grad = None
835
+ bias.grad = None
836
+
837
+ (input_ @ weight.T + bias).sum().backward()
838
+
839
+ ig1, wg1, bg1 = monarch.inspect((input_.grad, weight.grad, bias.grad))
840
+
841
+ with monarch.no_mesh.activate():
842
+ assert torch.allclose(ig0, ig1)
843
+ assert torch.allclose(wg0, wg1)
844
+ assert torch.allclose(bg0, bg1)
845
+
846
+ def test_remote_function_reduce_scatter_tensor_sum(self, backend_type):
847
+ self.do_test_reduce_scatter_tensor(
848
+ backend_type,
849
+ torch.distributed.ReduceOp.SUM,
850
+ (
851
+ torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
852
+ * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
853
+ ).sum(0),
854
+ )
855
+
856
+ def test_remote_function_reduce_scatter_tensor_subgroup_sum(
857
+ self, backend_type: BackendType
858
+ ) -> None:
859
+ self.do_test_reduce_scatter_tensor_subgroup(
860
+ backend_type,
861
+ torch.distributed.ReduceOp.SUM,
862
+ expected_tensor_host_group=torch.tensor(
863
+ [[0, 2, 4, 6], [0, 4, 8, 12], [8, 10, 12, 14], [16, 20, 24, 28]],
864
+ dtype=torch.float32,
865
+ ),
866
+ expected_tensor_gpu_group=torch.tensor(
867
+ [[0, 1, 2, 3], [4, 5, 6, 7], [0, 5, 10, 15], [20, 25, 30, 35]],
868
+ dtype=torch.float32,
869
+ ),
870
+ )
871
+
872
+ def test_remote_function_reduce_scatter_tensor_avg(self, backend_type):
873
+ self.do_test_reduce_scatter_tensor(
874
+ backend_type,
875
+ torch.distributed.ReduceOp.AVG,
876
+ (
877
+ torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
878
+ * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
879
+ ).mean(0),
880
+ )
881
+
882
+ def test_remote_function_reduce_scatter_sum(
883
+ self, backend_type: BackendType
884
+ ) -> None:
885
+ self.do_test_reduce_scatter(
886
+ backend_type,
887
+ torch.distributed.ReduceOp.SUM,
888
+ (
889
+ torch.arange(0, 8, dtype=torch.float32).reshape(1, 4, 2).repeat(4, 1, 1)
890
+ * torch.arange(4, dtype=torch.float32).unsqueeze(-1).unsqueeze(-1)
891
+ ).sum(0),
892
+ )
893
+
894
+ def test_remote_function_reduce_scatter_avg(
895
+ self, backend_type: BackendType
896
+ ) -> None:
897
+ self.do_test_reduce_scatter(
898
+ backend_type,
899
+ torch.distributed.ReduceOp.AVG,
900
+ (
901
+ torch.arange(0, 8, dtype=torch.float32).reshape(1, 4, 2).repeat(4, 1, 1)
902
+ * torch.arange(4, dtype=torch.float32).unsqueeze(-1).unsqueeze(-1)
903
+ ).mean(0),
904
+ )
905
+
906
+ def test_remote_function_all_reduce_sum(self, backend_type):
907
+ self.do_test_all_reduce(
908
+ backend_type,
909
+ torch.distributed.ReduceOp.SUM,
910
+ (
911
+ torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
912
+ * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
913
+ ).sum(0),
914
+ )
915
+
916
+ def test_remote_function_all_reduce_avg(self, backend_type):
917
+ self.do_test_all_reduce(
918
+ backend_type,
919
+ torch.distributed.ReduceOp.AVG,
920
+ (
921
+ torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
922
+ * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
923
+ ).mean(0),
924
+ )
925
+
926
+ def test_remote_function_all_reduce_max(self, backend_type):
927
+ self.do_test_all_reduce(
928
+ backend_type,
929
+ torch.distributed.ReduceOp.MAX,
930
+ (
931
+ torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
932
+ * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
933
+ ).max(0)[0],
934
+ )
935
+
936
+ def test_remote_function_all_reduce_min(self, backend_type):
937
+ self.do_test_all_reduce(
938
+ backend_type,
939
+ torch.distributed.ReduceOp.MIN,
940
+ (
941
+ torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
942
+ * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
943
+ ).min(0)[0],
944
+ )
945
+
946
+ def test_remote_function_failure_message_contains_traceback(self, backend_type):
947
+ with self.local_device_mesh(2, 2, backend_type):
948
+ x = outer_remote_function_that_calls_inner()
949
+ try:
950
+ inspect(x)
951
+ except RemoteException as e:
952
+ backtrace = "\n".join([frame.name for frame in e.worker_frames])
953
+ assert "outer_remote_function" in backtrace
954
+ assert "inner_remote_function" in backtrace
955
+
956
+ def test_remote_function_broadcast(self, backend_type):
957
+ with self.local_device_mesh(2, 2, backend_type) as device_mesh:
958
+ pg = device_mesh.process_group(("host", "gpu"))
959
+ for i in range(4):
960
+ rank = 2 * device_mesh.rank("host") + device_mesh.rank("gpu")
961
+ rank = rank.cuda()
962
+ broadcast(rank, src=i, group=pg)
963
+ for host in range(2):
964
+ for gpu in range(2):
965
+ with no_mesh.activate():
966
+ assert inspect(rank, {"host": host, "gpu": gpu}).item() == i
967
+
968
+ def test_remote_function_all_to_all_single(self, backend_type):
969
+ with self.local_device_mesh(2, 2, backend_type) as device_mesh:
970
+ pg = device_mesh.process_group(("host", "gpu"))
971
+ tensor_in = torch.arange(4, device="cuda", dtype=float)
972
+ tensor_out = torch.empty(4, device="cuda", dtype=float)
973
+ all_to_all_single(tensor_out, tensor_in, group=pg)
974
+ for host in range(2):
975
+ for gpu in range(2):
976
+ rank = 2 * host + gpu
977
+ with no_mesh.activate():
978
+ assert torch.equal(
979
+ inspect(tensor_out, {"host": host, "gpu": gpu}),
980
+ rank * torch.ones(4),
981
+ )
982
+
983
+ def test_remote_function_all_to_all(self, backend_type: BackendType) -> None:
984
+ world_size = 2
985
+ n_gpus = 2
986
+ size = world_size * n_gpus
987
+ expected_tensors = [
988
+ torch.tensor([0, 4, 8, 12], dtype=torch.float32),
989
+ torch.tensor([1, 5, 9, 13], dtype=torch.float32),
990
+ torch.tensor([2, 6, 10, 14], dtype=torch.float32),
991
+ torch.tensor([3, 7, 11, 15], dtype=torch.float32),
992
+ ]
993
+
994
+ with self.local_device_mesh(world_size, n_gpus, backend_type) as device_mesh:
995
+ pg = device_mesh.process_group(("host", "gpu"))
996
+ rank = n_gpus * device_mesh.rank("host") + device_mesh.rank("gpu")
997
+ in_tensors = list(
998
+ torch.chunk(
999
+ torch.arange(size, device="cuda", dtype=torch.float32)
1000
+ + (rank * size),
1001
+ size,
1002
+ )
1003
+ )
1004
+ # These values will be replaced, just used for shape.
1005
+ out_tensors = list(torch.zeros(size, device="cuda").chunk(size))
1006
+ out_tensors = all_to_all(out_tensors, in_tensors, group=pg)
1007
+ for host in range(world_size):
1008
+ for gpu in range(n_gpus):
1009
+ local_tensor_out = inspect(out_tensors, {"host": host, "gpu": gpu})
1010
+ rank = host * n_gpus + gpu
1011
+ with no_mesh.activate():
1012
+ # Combine the tensor list together for a better comparison
1013
+ # message.
1014
+ local_tensor_out = torch.cat(local_tensor_out)
1015
+ assert torch.equal(
1016
+ local_tensor_out, expected_tensors[rank]
1017
+ ), f"For {rank=}, {host=}, {gpu=}"
1018
+
1019
+
1020
+ @pytest.mark.skipif(
1021
+ torch.cuda.device_count() < 2,
1022
+ reason="Not enough GPUs, this test requires at least 2 GPUs",
1023
+ )
1024
+ # Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
1025
+ # out is not counted as a failure, so we set a more restrictive timeout to
1026
+ # ensure we see a hard failure in CI.
1027
+ @pytest.mark.timeout(120)
1028
+ @pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
1029
+ class TestComm(RemoteFunctionsTestBase):
1030
+ N_GPUS: int = 2
1031
+ N_HOSTS: int = 2
1032
+
1033
+ @property
1034
+ def world_size(self) -> int:
1035
+ return self.N_GPUS * self.N_HOSTS
1036
+
1037
+ @property
1038
+ def device(self):
1039
+ self.fail("test subclass didn't override device")
1040
+
1041
+ def _test_tensor_dtype_complex(self, backend_type: BackendType) -> None:
1042
+ with self.local_device_mesh(
1043
+ self.N_HOSTS, self.N_GPUS, backend_type
1044
+ ) as device_mesh:
1045
+ group = device_mesh.process_group(("host", "gpu"))
1046
+ tensor = torch.rand(2, device="cuda")
1047
+ tensor_c = torch.view_as_complex(tensor)
1048
+ tensor_list = [
1049
+ torch.rand(2, device="cuda") for _ in range(self.N_HOSTS * self.N_GPUS)
1050
+ ]
1051
+ tensor_list_c = list(tensor_list)
1052
+ tensor_list_c[1] = torch.view_as_complex(tensor_list_c[1])
1053
+
1054
+ inspect(all_gather(tensor_list, tensor, group=group))
1055
+ inspect(all_gather(tensor_list, tensor_c, group=group))
1056
+ inspect(all_gather(tensor_list_c, tensor, group=group))
1057
+ inspect(all_gather(tensor_list_c, tensor_c, group=group))
1058
+
1059
+ def test_nccl_barrier(self, backend_type: BackendType) -> None:
1060
+ with self.local_device_mesh(
1061
+ self.N_HOSTS, self.N_GPUS, backend_type
1062
+ ) as device_mesh:
1063
+ pg = device_mesh.process_group(("host", "gpu"))
1064
+ rank = device_mesh.rank(("host", "gpu"))
1065
+ t = torch.tensor([1] * 10, device="cuda") + rank
1066
+ all_reduce(t, group=pg)
1067
+
1068
+ for host in range(self.N_HOSTS):
1069
+ for gpu in range(self.N_GPUS):
1070
+ rank = 2 * host + gpu
1071
+ with no_mesh.activate():
1072
+ # all reduce will sum rank + 1 across all ranks.
1073
+ expected_tensor = torch.tensor(
1074
+ [sum(range(1, self.world_size + 1))] * 10
1075
+ )
1076
+ assert torch.equal(
1077
+ expected_tensor,
1078
+ inspect(t, {"host": host, "gpu": gpu}),
1079
+ )
1080
+
1081
+ def test_tensor_dtype_complex(self, backend_type: BackendType) -> None:
1082
+ self._test_tensor_dtype_complex(backend_type)
1083
+
1084
+ def test_reduce_scatter_base_k(self, backend_type: BackendType) -> None:
1085
+ expected_tensor = (
1086
+ torch.arange(self.N_HOSTS * self.N_GPUS * 2, dtype=torch.float32)
1087
+ .reshape(1, self.N_HOSTS * self.N_GPUS, 2)
1088
+ .repeat(self.N_HOSTS * self.N_GPUS, 1, 1)
1089
+ ).sum(0)
1090
+ with self.local_device_mesh(
1091
+ self.N_HOSTS, self.N_GPUS, backend_type
1092
+ ) as device_mesh:
1093
+ pg = device_mesh.process_group(("host", "gpu"))
1094
+ output_tensor = torch.zeros(2, dtype=torch.int64, device="cuda")
1095
+ input_tensors = torch.arange(
1096
+ self.N_HOSTS * self.N_GPUS * 2, dtype=torch.int64, device="cuda"
1097
+ )
1098
+ input_tensors = torch.reshape(
1099
+ input_tensors, (self.N_HOSTS * self.N_GPUS, 2)
1100
+ )
1101
+ # Input is [[0, 1], [2, 3], [4, 5], [6, 7]] across 4 ranks.
1102
+ # After reduce + scatter, output_tensor should be [0 * 4, 1 * 4] on the 0th rank
1103
+ # and [2 * 4, 3 * 4] on the 1st rank, and so on
1104
+ reduce_scatter_tensor(output_tensor, input_tensors, group=pg)
1105
+
1106
+ for host in range(self.N_HOSTS):
1107
+ for gpu in range(self.N_GPUS):
1108
+ rank = 2 * host + gpu
1109
+ output_tensor_local = inspect(
1110
+ output_tensor, {"host": host, "gpu": gpu}
1111
+ )
1112
+ with no_mesh.activate():
1113
+ assert torch.equal(output_tensor_local, expected_tensor[rank])
1114
+
1115
+
1116
+ @pytest.mark.skipif(
1117
+ torch.cuda.device_count() < 2,
1118
+ reason="Not enough GPUs, this test requires at least 2 GPUs",
1119
+ )
1120
+ # Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
1121
+ # out is not counted as a failure, so we set a more restrictive timeout to
1122
+ # ensure we see a hard failure in CI.
1123
+ @pytest.mark.timeout(120)
1124
+ @pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
1125
+ class TestNcclProcessGroupWithDispatchedCollectives(RemoteFunctionsTestBase):
1126
+ """This test is copied from test_c10d_nccl.py::NcclProcessGroupWithDispatchedCollectivesTests
1127
+ in torch, but modified to setup a Monarch device mesh and use remote functions"""
1128
+
1129
+ N_GPUS: int = 2
1130
+ N_HOSTS: int = 2
1131
+
1132
+ def _call_collective_with_varying_tensors(
1133
+ self,
1134
+ world_size: int,
1135
+ # pyre-fixme[24]: Incorrect ParamsSpec annotation.
1136
+ collective: Remote[..., torch.Tensor],
1137
+ *args,
1138
+ **kwargs,
1139
+ ) -> None:
1140
+ # call collective with varying tensors to ensure that the tensors are
1141
+ # correctly dispatched
1142
+
1143
+ # ensure supported devices (cpu, cuda) succeeds during dispatch call
1144
+ tensor = torch.zeros(2, 2, device=torch.device("cuda"))
1145
+ # multi tensor collectives
1146
+ if collective == barrier:
1147
+ fetch_shard(collective(*args, **kwargs)).result()
1148
+ elif collective == all_gather:
1149
+ output_list = list(
1150
+ torch.zeros(world_size * 2, 2, device=torch.device("cuda")).chunk(
1151
+ world_size
1152
+ )
1153
+ )
1154
+ fetch_shard(collective(output_list, tensor, *args, **kwargs)).result()
1155
+ elif collective == reduce_scatter:
1156
+ fetch_shard(
1157
+ collective(tensor, [tensor] * world_size, *args, **kwargs)
1158
+ ).result()
1159
+ elif collective == gather:
1160
+ gather_list = list(
1161
+ torch.zeros(world_size * 2, 2, device=torch.device("cuda")).chunk(
1162
+ world_size
1163
+ )
1164
+ )
1165
+ fetch_shard(collective(tensor, gather_list, *args, **kwargs)).result()
1166
+ elif collective == scatter:
1167
+ fetch_shard(
1168
+ collective(tensor, [tensor] * world_size, *args, **kwargs)
1169
+ ).result()
1170
+ elif collective == all_to_all:
1171
+ fetch_shard(
1172
+ collective(
1173
+ [tensor] * world_size, [tensor] * world_size, *args, **kwargs
1174
+ )
1175
+ ).result()
1176
+ else:
1177
+ fetch_shard(collective(tensor, *args, **kwargs)).result()
1178
+
1179
+ @pytest.mark.parametrize(
1180
+ "collective",
1181
+ [
1182
+ reduce,
1183
+ broadcast,
1184
+ all_reduce,
1185
+ all_gather,
1186
+ reduce_scatter,
1187
+ barrier,
1188
+ all_to_all,
1189
+ gather,
1190
+ scatter,
1191
+ ],
1192
+ ids=[
1193
+ "reduce",
1194
+ "broadcast",
1195
+ "all_reduce",
1196
+ "all_gather",
1197
+ "reduce_scatter",
1198
+ "barrier",
1199
+ "all_to_all",
1200
+ "gather",
1201
+ "scatter",
1202
+ ],
1203
+ )
1204
+ def test_collectives(
1205
+ self, backend_type: BackendType, collective: Callable[..., torch.Tensor]
1206
+ ) -> None:
1207
+ world_size = self.N_HOSTS * self.N_GPUS
1208
+ with self.local_device_mesh(
1209
+ self.N_HOSTS, self.N_GPUS, backend_type
1210
+ ) as device_mesh:
1211
+ rank = device_mesh.rank(("host", "gpu"))
1212
+ pg = device_mesh.process_group(("host", "gpu"))
1213
+
1214
+ kwargs: dict[str, object] = {"group": pg}
1215
+ if collective == reduce:
1216
+ kwargs["group_dst"] = 0
1217
+ elif collective == broadcast:
1218
+ kwargs["group_src"] = rank
1219
+ elif collective == gather:
1220
+ kwargs["group_dst"] = 0
1221
+ elif collective == scatter:
1222
+ kwargs["group_src"] = 0
1223
+ self._call_collective_with_varying_tensors(world_size, collective, **kwargs)
1224
+
1225
+ def test_all_to_all_single(self, backend_type: BackendType) -> None:
1226
+ with self.local_device_mesh(
1227
+ self.N_HOSTS, self.N_GPUS, backend_type
1228
+ ) as device_mesh:
1229
+ pg = device_mesh.process_group(("host", "gpu"))
1230
+ # test alltoall_base
1231
+ tensor_in = torch.arange(4, device="cuda", dtype=torch.float32)
1232
+ tensor_out = torch.empty(4, device="cuda", dtype=torch.float32)
1233
+ all_to_all_single(tensor_out, tensor_in, group=pg)
1234
+
1235
+ for host in range(self.N_HOSTS):
1236
+ for gpu in range(self.N_GPUS):
1237
+ rank = 2 * host + gpu
1238
+ with no_mesh.activate():
1239
+ assert torch.equal(
1240
+ inspect(tensor_out, {"host": host, "gpu": gpu}),
1241
+ rank * torch.ones(4),
1242
+ )
1243
+
1244
+ def test_allgather_base(self, backend_type: BackendType) -> None:
1245
+ with self.local_device_mesh(
1246
+ self.N_HOSTS, self.N_GPUS, backend_type
1247
+ ) as device_mesh:
1248
+ pg = device_mesh.process_group(("host", "gpu"))
1249
+ rank = (
1250
+ (device_mesh.rank("host") + 1) * self.N_GPUS
1251
+ + device_mesh.rank("gpu")
1252
+ + 1
1253
+ )
1254
+ tensor_in = torch.arange(2, device="cuda") * rank
1255
+ tensor_out = torch.zeros(
1256
+ self.N_HOSTS * self.N_GPUS * 2, dtype=torch.int64, device="cuda"
1257
+ )
1258
+ all_gather_into_tensor(tensor_out, tensor_in, group=pg)
1259
+ local_tensor_out = inspect(tensor_out)
1260
+ with no_mesh.activate():
1261
+ assert torch.equal(
1262
+ local_tensor_out, torch.tensor([0, 3, 0, 4, 0, 5, 0, 6])
1263
+ )
1264
+
1265
+
1266
+ def a_function_called_by_a_live_function(x):
1267
+ return 2 * x
1268
+
1269
+
1270
+ def a_live_function_call_by_a_live_function(x):
1271
+ return 3 * x
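
Taken together, the collective tests above follow one pattern: build a process group from named mesh dimensions, run a wrapped collective, then inspect individual shards by mesh coordinate. A condensed sketch of that flow, assuming the same TestingContext fixture and helper imports used at the top of the file, the tests' 2-host x 2-GPU local mesh (so at least 2 GPUs are needed), and all_reduce's default SUM reduction; the backend flag (rust=) is left at its default here:

import torch
from monarch import inspect, no_mesh
from monarch._testing import TestingContext
from monarch.worker._testing_function import all_reduce

with TestingContext() as local:
    # (num_hosts, gpu_per_host, activate) as in RemoteFunctionsTestBase above.
    with local.local_device_mesh(2, 2, True) as device_mesh:
        pg = device_mesh.process_group(("host", "gpu"))
        rank = device_mesh.rank(("host", "gpu"))
        t = torch.ones(4, device="cuda") + rank   # rank-dependent input: 1..4
        reduced = all_reduce(t, group=pg)         # sums across all 4 ranks
        for host in range(2):
            for gpu in range(2):
                shard = inspect(reduced, {"host": host, "gpu": gpu})
                with no_mesh.activate():
                    # 1 + 2 + 3 + 4 = 10 on every shard after the reduction
                    assert torch.equal(shard, torch.full((4,), 10.0))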