torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
tests/test_controller.py
@@ -0,0 +1,845 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import itertools
+import logging
+import re
+import sys
+import traceback
+from contextlib import contextmanager
+
+import monarch
+import monarch.random
+import pytest
+
+import torch
+
+from monarch import (
+    DeviceMesh,
+    fetch_shard,
+    grad_function,
+    grad_generator,
+    no_mesh,
+    Stream,
+    Tensor,
+)
+
+from monarch._testing import BackendType, TestingContext
+from monarch.common.controller_api import LogMessage
+from monarch.common.invocation import DeviceException
+from monarch.common.remote import remote
+from monarch.common.tree import flattener
+from monarch.rust_local_mesh import (
+    ControllerParams,
+    local_mesh,
+    local_meshes_and_bootstraps,
+    LoggingLocation,
+    SocketType,
+    SupervisionParams,
+)
+from monarch_supervisor.logging import fix_exception_lines
+
+
+def custom_excepthook(exc_type, exc_value, exc_traceback):
+    tb_lines = fix_exception_lines(
+        traceback.format_exception(exc_type, exc_value, exc_traceback)
+    )
+    print("\n".join(tb_lines), file=sys.stderr)
+
+
+sys.excepthook = custom_excepthook
+
+
+@pytest.fixture(scope="module", autouse=True)
+def testing_context():
+    global local
+    with TestingContext() as local:
+        yield
+
+
+@contextmanager
+def local_rust_device_mesh(
+    hosts,
+    gpu_per_host,
+    activate: bool = True,
+    controller_params: ControllerParams | None = None,
+):
+    with local_mesh(
+        hosts=hosts,
+        gpus_per_host=gpu_per_host,
+        socket_type=SocketType.UNIX,
+        logging_location=LoggingLocation.FILE,
+        controller_params=controller_params,
+    ) as dm:
+        try:
+            if activate:
+                with dm.activate():
+                    yield dm
+            else:
+                yield dm
+            dm.exit()
+        except Exception:
+            dm.client._shutdown = True
+            raise
+
+
+panic = remote("__test_panic", propagate="inspect")
+
+remote_sleep = remote("time.sleep", propagate="inspect")
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2,
+    reason="Not enough GPUs, this test requires at least 2 GPUs",
+)
+@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS, "mesh"])
+# Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
+# out is not counted as a failure, so we set a more restrictive timeout to
+# ensure we see a hard failure in CI.
+@pytest.mark.timeout(120)
+class TestController:
+    @classmethod
+    def local_device_mesh(
+        cls,
+        N,
+        gpu_per_host,
+        backend_type,
+        activate=True,
+    ):
+        return local.local_device_mesh(
+            N,
+            gpu_per_host,
+            activate,
+            backend=str(backend_type),
+        )
+
+    def test_errors(self, backend_type):
+        t = torch.rand(3, 4)
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            y = torch.rand(3, 4)
+            with pytest.raises(TypeError, match="LOCAL_TENSOR"):
+                t.add(y)
+            with pytest.raises(TypeError, match="WRONG_MESH"):
+                sm = device_mesh.slice(host=0)
+                with sm.activate():
+                    x = torch.rand(3, 4)
+                    x.add(y)
+
+            other = Stream("other")
+            t = torch.rand(10).cuda()
+            with pytest.raises(TypeError, match="WRONG_STREAM"):
+                with other.activate():
+                    t = t.reduce("host", "sum")
+
+    def test_sub_mesh(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            h0 = device_mesh.slice(host=0)
+            h1 = device_mesh.slice(host=1)
+            with h0.activate():
+                _ = torch.rand(3, 4)
+            with h1.activate():
+                _ = torch.rand(3, 4)
+                # Runs on a different mesh but should still work
+
+    def test_fetch_result_device(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type):
+            on_gpu = torch.ones(2, 3, device="cuda")
+            on_cpu = torch.ones(2, 3, device="cpu")
+
+            on_gpu_local = fetch_shard(on_gpu).result()
+            on_cpu_local = fetch_shard(on_cpu).result()
+
+        assert on_gpu_local.device == torch.device("cpu")
+        assert on_cpu_local.device == torch.device("cpu")
+
+    def test_dim1_mesh(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
+            mesh3d = device_mesh.split(host=("oh", "ih"), ih=1)
+            with mesh3d.activate():
+                x = torch.ones(3, 4)
+                local_x = fetch_shard(x).result()
+
+        assert torch.equal(local_x, torch.ones(3, 4))
+
+    def test_sub_mesh_use_only_one(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
+            h0 = device_mesh.slice(host=0)
+
+            with h0.activate():
+                x = torch.ones(3, 4)
+                local_x = fetch_shard(x)
+
+        local_x = local_x.result(timeout=20)
+        assert torch.equal(local_x, torch.ones(3, 4))
+
+    def test_sub_mesh_process_grop(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
+            h0 = device_mesh.slice(host=0)
+            pg0 = h0.process_group(("gpu",))
+            pg1 = h0.process_group(("gpu",))
+            # Is there a way to functionally test that these two PG's aren't
+            # the same in the backend?
+            assert pg0 != pg1
+
+    def test_reduce(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            x = (
+                12 * 2 * device_mesh.rank("host")
+                + 12 * device_mesh.rank("gpu")
+                + torch.arange(12, device="cuda").reshape(3, 4)
+            )
+            y = x.reduce("gpu", "sum")
+            g = x.reduce("gpu", "stack")
+            with pytest.raises(TypeError, match="When scattering"):
+                x = x.reduce("gpu", "sum", scatter=True)
+            x = x.reshape(2, 6)
+            atoa = x.reduce("gpu", "stack", scatter=True)
+            rs = x.reduce("gpu", "sum", scatter=True)
+            rad = x.reduce((), "sum")
+            rade = x.reduce(("gpu", "host"), "sum")
+            with pytest.raises(
+                ValueError, match="is not valid for multiple dimensions"
+            ):
+                x.reduce((), "sum", scatter=True)
+            with pytest.raises(
+                ValueError, match="is not valid for multiple dimensions"
+            ):
+                x.reduce((), "stack")
+            with pytest.raises(
+                ValueError, match="is not valid for multiple dimensions"
+            ):
+                x.reduce((), "stack", scatter=True)
+            y_local = fetch_shard(y).result()
+            g_local = fetch_shard(g).result()
+            # TODO compute the expected values to compare against in the below section
+            _ = fetch_shard(atoa).result()
+            _ = fetch_shard(rs).result()
+            rad_local = fetch_shard(rad).result()
+            rade_local = fetch_shard(rade).result()
+
+        xs = {
+            (h, g): 12 * 2 * h + 12 * g + torch.arange(12, device="cpu").reshape(3, 4)
+            for h, g in itertools.product(range(2), range(2))
+        }
+
+        y_expected = xs[(0, 0)] + xs[(0, 1)]
+        g_expected = torch.stack([xs[(0, 0)], xs[(0, 1)]])
+        assert torch.equal(y_local, y_expected)
+        assert torch.equal(g_local, g_expected)
+        rad_expected = (xs[(0, 0)] + xs[(0, 1)] + xs[(1, 0)] + xs[(1, 1)]).reshape(
+            rad_local.shape
+        )
+        assert torch.equal(rad_local, rad_expected)
+        assert torch.equal(rade_local, rad_expected)
+
+        # test is run on 4 GPUs, can't have mesh with 3 non-trivial dimensions
+        with self.local_device_mesh(2, 2, backend_type, activate=False) as mesh2d:
+            device_mesh = mesh2d.split(host=("oh", "ih"), ih=1)
+            with device_mesh.activate():
+                x = (
+                    12 * 2 * device_mesh.rank("oh")
+                    + 12 * device_mesh.rank("gpu")
+                    + torch.arange(12, device="cuda").reshape(3, 4)
+                )
+                y = x.reduce(("ih", "gpu"), "sum")
+                y_local = fetch_shard(y).result()
+                z = x.reduce(("oh", "gpu"), "sum")
+                z_local = fetch_shard(z).result()
+
+        assert torch.equal(y_local, y_expected)
+        assert torch.equal(z_local, rad_expected.reshape(z_local.shape))
+
+    def test_reduce_out(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type):
+            inp = torch.rand(2, 4, device="cuda")
+            out_incorrect = torch.rand(2, 4, device="cuda")
+            out = torch.rand(4, device="cuda")
+
+            with pytest.raises(
+                ValueError, match="Reduce expects the shape to be torch.Size."
+            ):
+                _ = inp.reduce("host", reduction="sum", scatter=True, out=out_incorrect)
+
+            reduce_out = inp.reduce("host", reduction="sum", scatter=True)
+            local_out = fetch_shard(out).result()
+            local_reduce_out = fetch_shard(reduce_out).result()
+            assert out._fake is not reduce_out._fake
+            with no_mesh.activate():
+                assert not torch.equal(local_out, local_reduce_out)
+
+            reduce_out = inp.reduce("host", reduction="sum", scatter=True, out=out)
+            local_out = fetch_shard(out).result()
+            local_reduce_out = fetch_shard(reduce_out).result()
+            assert out._fake is reduce_out._fake
+            with no_mesh.activate():
+                assert torch.equal(local_out, local_reduce_out)
+
+    def test_fetch(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            h = device_mesh.rank("host")
+            g = device_mesh.rank("gpu")
+            for hi in range(2):
+                for gi in range(2):
+                    x, y = fetch_shard((h, g), {"host": hi, "gpu": gi}).result()
+                    with no_mesh.activate():
+                        assert (hi, gi) == (x.item(), y.item())
+
+    def test_mutate(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type):
+            x = torch.rand(3, 4).cuda()
+            x.abs_()
+            s = Stream("other")
+            b, drop = s.borrow(x)
+            with pytest.raises(TypeError, match="would be mutated"):
+                x.abs_()
+            with s.activate():
+                _ = b.add(b)
+            drop.drop()
+            x.abs_()
+            b, drop = s.borrow(x, mutable=True)
+            with s.activate():
+                b.abs_()
+            drop.drop()
+            # del b
+            x.abs_()
+
+    def test_movement(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            sm0 = device_mesh.slice(host=0)
+            sm1 = device_mesh.slice(host=1)
+
+            with sm0.activate():
+                x = torch.rand(3, 4, device="cuda")
+                _ = x.to_mesh(sm1)
+
+            a = torch.rand(3, 4, device="cuda")
+
+            b = a.slice_mesh(host=0)
+            _ = b.to_mesh(sm0)
+            _ = b.to_mesh(sm1)
+
+    def test_broadcast_one(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            for dim in ("host", "gpu"):
+                subset = device_mesh.slice(**{dim: 1})
+                with subset.activate():
+                    x = torch.rand(3, device="cuda")
+                    y = x.to_mesh(device_mesh)
+
+                with subset.activate():
+                    a = monarch.inspect(x)
+                with device_mesh.activate():
+                    b = monarch.inspect(y.reduce(dim, reduction="stack"))
+                with no_mesh.activate():
+                    assert torch.allclose(a.expand(2, -1), b, rtol=0, atol=0)
+
+    def test_broadcast_two(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            subset = device_mesh.slice(host=1, gpu=1)
+            with subset.activate():
+                x = torch.rand(3, device="cuda")
+                y = x.to_mesh(device_mesh)
+
+            with subset.activate():
+                a = monarch.inspect(x)
+            with device_mesh.activate():
+                b = monarch.inspect(
+                    y.reduce("host", reduction="stack").reduce("gpu", reduction="stack")
+                )
+            with no_mesh.activate():
+                assert torch.allclose(a.expand(2, 2, -1), b, rtol=0, atol=0)
+
+    def test_autograd(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            x = torch.rand(3, 4, requires_grad=True)
+            y = torch.rand(4, 3, requires_grad=True)
+            z = torch.rand(3, requires_grad=True)
+
+            foo = (x @ y + z).sum()
+            with no_mesh.activate():
+                # check backward restores forward mesh
+                for t in grad_generator(foo, [z, y, x]):
+                    with device_mesh.activate():
+                        fetch_shard(t).result()
+
+    def test_mesh_semantics(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
+            with host0.activate():
+                x = torch.randn(5)
+                y = x * 5
+            with host1.activate():
+                a = torch.randn(5)
+                b = a * 5
+            x.cos()
+            y.cos()
+            b.cos()
+
+    def test_autograd_multi_mesh(self, backend_type):
+        @grad_function
+        def to_mesh(x: Tensor, mesh: DeviceMesh):
+            omesh = x.mesh
+
+            def backward(grad_x: Tensor):
+                print(grad_x.mesh, omesh)
+                return grad_x.to_mesh(omesh), None
+
+            return x.to_mesh(mesh), backward
+
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
+            with host0.activate():
+                x = torch.rand(3, 4, requires_grad=True, device="cuda")
+                y = torch.rand(4, 3, requires_grad=True, device="cuda")
+                t = x @ y
+                t = to_mesh(t, host1)
+            with host1.activate():
+                z = torch.rand(3, requires_grad=True, device="cuda")
+                foo = (t + z).sum()
+
+            for r in grad_generator(foo, [z, y, x]):
+                with r.mesh.activate():
+                    print(fetch_shard(r).result())
+
+    def test_many(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type):
+            x = torch.rand(3, 4)
+            for _ in range(2048):
+                x = x + torch.rand(3, 4)
+            fetch_shard(x).result()
+
+    def test_flattener(self, backend_type):
+        e = (8, 9, {"a": 10, "b": 11})
+        flatten = flattener(e)
+        e2 = (0, 1, {"a": 2, "b": 3})
+        assert [0, 1, 2, 3] == flatten(e2)
+
+    def test_torch_tensor(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type):
+            t = torch.tensor([1, 2, 4])
+            tc = torch.tensor([1, 2, 4], device="cuda")
+            t2 = fetch_shard(t).result()
+            tc2 = fetch_shard(tc).result()
+        assert torch.allclose(t2, torch.tensor([1, 2, 4]))
+        assert torch.allclose(tc2, torch.tensor([1, 2, 4], device="cpu"))
+
+    def test_to_mesh_aliasing(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as mesh:
+            p2p_stream = Stream("p2p_stream")
+
+            ppmesh = mesh.flatten("all").split(
+                all=(
+                    "dp",
+                    "pp",
+                ),
+                pp=2,
+            )
+            pp_meshes = [ppmesh.slice(pp=i) for i in range(2)]
+
+            with ppmesh.activate():
+                with pp_meshes[0].activate():
+                    x = torch.randn((3, 3), device="cuda")
+                    x_borrowed_tensor, x_borrow = p2p_stream.borrow(x)
+                    with p2p_stream.activate():
+                        y_on_mesh_1_p2p_stream = x_borrowed_tensor.to_mesh(pp_meshes[1])
+
+                with pp_meshes[1].activate():
+                    x_borrow.drop()
+                    y_on_mesh_1_default_stream, y_borrow = (
+                        monarch.get_active_stream().borrow(y_on_mesh_1_p2p_stream)
+                    )
+
+                    monarch.inspect(y_on_mesh_1_default_stream)
+                    y_borrow.drop()
+
+    def test_to_mesh_cow(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as mesh:
+            t = torch.zeros((), device="cuda")
+            t2 = t.to_mesh(mesh)
+            t.add_(1)
+            assert monarch.inspect(t2).item() == 0
+            assert monarch.inspect(t).item() == 1
+
+    def test_to_mesh_stream(self, backend_type):
+        other = monarch.Stream("other")
+        with self.local_device_mesh(2, 2, backend_type) as mesh:
+            m0 = mesh.slice(host=0)
+            m1 = mesh.slice(host=1)
+            with m0.activate():
+                t2 = torch.rand(3, 4, device="cuda").to_mesh(m1, stream=other)
+            with m1.activate(), other.activate():
+                # assert doesn't fail
+                monarch.inspect(t2 + t2)
+
+    def test_dropped_trace(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as _:
+            x = torch.rand(4, 4).cuda()
+            s = Stream("other")
+            b, drop = s.borrow(x)
+            drop.drop()
+            with s.activate():
+                pattern = re.compile(
+                    ".*tensor.*is dropped at.*.*drop.drop().*", flags=re.DOTALL
+                )
+                with pytest.raises(TypeError, match=pattern):
+                    _ = b.abs()
+
+    def test_sub_mesh_reduce(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            host1 = device_mesh.slice(host=1)
+            with host1.activate():
+                myrank = (
+                    (device_mesh.rank("host") + 1) * 2 + device_mesh.rank("gpu") + 1
+                )
+                x = torch.ones((3, 4), device="cuda") * myrank
+                reduce = x.reduce("gpu", "sum")
+                local_reduce = fetch_shard(reduce).result()
+
+        assert torch.equal(local_reduce, torch.ones(3, 4) * 11)
+
+    def test_size(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            assert device_mesh.size(["host", "gpu"]) == 4
+
+    def test_random_state(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            monarch.random.make_deterministic()
+            for device in ("cpu", "cuda"):
+                a = monarch.random.get_state()
+                monarch.inspect(a)
+                first = torch.rand(1, device=device)
+                monarch.random.set_state(a)
+                second = torch.rand(1, device=device)
+                f, s = monarch.inspect((first, second))
+                with no_mesh.activate():
+                    assert torch.allclose(f, s, atol=0, rtol=1)
+                seed = device_mesh.rank(["host", "gpu"]) + 4
+                s2 = monarch.random.new_state(seed)
+                s3 = monarch.random.new_state(seed)
+                monarch.random.set_state(s2)
+                r0 = torch.rand(1, device=device)
+                if device == "cuda":
+                    for d in ("host", "gpu"):
+                        r0 = r0.reduce(d, reduction="stack")
+                monarch.random.set_state(s3)
+                r1 = torch.rand(1, device=device)
+                if device == "cuda":
+                    for d in ("host", "gpu"):
+                        r1 = r1.reduce(d, reduction="stack")
+                r2, r3 = monarch.inspect((r0, r1))
+                monarch.random.set_state(a)
+                with no_mesh.activate():
+                    assert torch.allclose(r2, r3, atol=0, rtol=0)
+                    assert not torch.allclose(r2, f, atol=0, rtol=0)
+
+    def test_torch_op_with_optional_tensors(self, backend_type):
+        """
+        This test ensures that for torch ops like LayerNorm, which allow for
+        optional tensor arguments, the controller serializes monarch tensors
+        correctly as Refs instead of as IValues.
+        """
+        with self.local_device_mesh(2, 2, backend_type):
+            x = torch.rand(3, 4, device="cuda")
+            # When bias and elementwise_affine are true, extra tensors are passed through optional
+            # fields inside LayerNorm. When they are false, None is passed to the same optional fields.
+            # If we are handling serialization correctly, there shouldn't be a crash in either case.
+            layer_norm_with_vals = torch.nn.LayerNorm(
+                4, device="cuda", bias=True, elementwise_affine=True
+            )
+            layer_norm_with_none = torch.nn.LayerNorm(
+                4, device="cuda", bias=False, elementwise_affine=False
+            )
+            monarch.inspect(layer_norm_with_vals(x))
+            monarch.inspect(layer_norm_with_none(x))
+
+    def test_reduce_pytree(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            a = device_mesh.rank(("gpu", "host")) + torch.zeros((1,), device="cuda")
+            b = device_mesh.rank(("gpu", "host")) + torch.ones((1,), device="cuda")
+
+            tensor_dict = {"a": a, "b": b}
+            _ = monarch.reduce_(tensor_dict, dims=("gpu", "host"), reduction="sum")
+            reduced_tensor_dict = monarch.reduce(
+                tensor_dict, dims=("gpu", "host"), reduction="sum"
+            )
+            reduced_a = fetch_shard(reduced_tensor_dict["a"]).result()
+            reduced_b = fetch_shard(reduced_tensor_dict["b"]).result()
+            reduced_a_inplace = fetch_shard(tensor_dict["a"]).result()
+            reduced_b_inplace = fetch_shard(tensor_dict["b"]).result()
+
+        assert torch.equal(reduced_a_inplace, torch.tensor([6.0]))
+        assert torch.equal(reduced_b_inplace, torch.tensor([10.0]))
+        assert torch.equal(reduced_a, torch.tensor([24.0]))
+        assert torch.equal(reduced_b, torch.tensor([40.0]))
+
+    def test_to_mesh_pytree(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
+
+            with host0.activate():
+                a = torch.zeros((1,), device="cuda")
+                b = torch.ones((1,), device="cuda")
+                tensor_dict = {"a": a, "b": b}
+                moved_tensor_dict = monarch.to_mesh(tensor_dict, host1)
+
+            with host1.activate():
+                moved_tensor_dict["a"].add_(1)
+                moved_tensor_dict["b"].add_(1)
+
+                moved_tensor_a = monarch.inspect(moved_tensor_dict["a"])
+                moved_tensor_b = monarch.inspect(moved_tensor_dict["b"])
+
+            host0.exit()
+            host1.exit()
+
+        assert torch.equal(moved_tensor_a, torch.tensor([1.0]))
+        assert torch.equal(moved_tensor_b, torch.tensor([2.0]))
+
+    def test_hanging_error(self, backend_type):
+        if backend_type != "mesh":
+            pytest.skip("only relevant for mesh backend")
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            remote(lambda: torch.rand(3) + torch.rand(4), propagate=lambda: None)()
+
+            with pytest.raises(Exception, match="The size of tensor"):
+                device_mesh.client.shutdown()
+
+    def test_slice_mesh_pytree(self, backend_type):
+        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
+            a = device_mesh.rank(("host")) + torch.zeros((1,), device="cuda")
+            b = device_mesh.rank(("host")) + torch.ones((1,), device="cuda")
+
+            tensor_dict = {"a": a, "b": b}
+            host0_slices = monarch.slice_mesh(tensor_dict, host=0)
+            host1_slices = monarch.slice_mesh(tensor_dict, host=1)
+
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
+
+            host0_tensors = monarch.to_mesh(host0_slices, host0)
+            host1_tensors = monarch.to_mesh(host1_slices, host1)
+
+            with host0.activate():
+                _ = monarch.reduce_(host0_tensors, dims=("gpu"), reduction="sum")
+                host0_a = fetch_shard(host0_tensors["a"]).result()
+                host0_b = fetch_shard(host0_tensors["b"]).result()
+
+            with host1.activate():
+                _ = monarch.reduce_(host1_tensors, dims=("gpu"), reduction="sum")
+                host1_a = fetch_shard(host1_tensors["a"]).result()
+                host1_b = fetch_shard(host1_tensors["b"]).result()
+
+            host0.exit()
+            host1.exit()
+
+        assert torch.equal(host0_a, torch.tensor([0.0]))
+        assert torch.equal(host0_b, torch.tensor([2.0]))
+        assert torch.equal(host1_a, torch.tensor([2.0]))
+        assert torch.equal(host1_b, torch.tensor([4.0]))
+
+
+def test_panicking_worker():
+    with pytest.raises(DeviceException, match="__test_panic called"):
+        with local_rust_device_mesh(1, 1) as _:
+            panic()
+            # induce a sync to allow the panic to propagate back
+            _ = fetch_shard(torch.ones(2, 3)).result()
+
+
+def test_timeout_warning(caplog):
+    timeout = 3
+    with local_rust_device_mesh(
+        1,
+        2,
+        True,
+        controller_params=ControllerParams(1, timeout, 100, False),
+    ) as dm:
+        for _ in range(3):
+            dm.client.new_node([], [])
+
+        assert dm.client.inner.next_message(timeout * 3) is None
+
+        remote_sleep(timeout * 2)
+        for _ in range(3):
+            dm.client.new_node([], [])
+
+        with caplog.at_level(logging.WARNING, logger=dm.client.__module__):
+            has_message = dm.client.handle_next_message(120)
+            assert has_message
+            assert (
+                f"ranks 1, 0 have operations that have not completed after {timeout} seconds"
+                in caplog.text
+            ) or (
+                f"ranks 0, 1 have operations that have not completed after {timeout} seconds"
+                in caplog.text
+            )
+
+
+def test_timeout_failure():
+    timeout = 3
+    with local_rust_device_mesh(
+        1,
+        1,
+        True,
+        controller_params=ControllerParams(1, timeout, 100, True),
+    ) as dm:
+        for _ in range(3):
+            dm.client.new_node([], [])
+
+        assert dm.client.inner.next_message(timeout * 3) is None
+
+        remote_sleep(timeout * 2)
+        for _ in range(3):
+            dm.client.new_node([], [])
+
+        for _ in range(5):
+            result = dm.client.inner.next_message(1)
+            if result is None:
+                continue
+            if isinstance(result, LogMessage):
+                continue
+            if result.error is None:
+                continue
+            assert isinstance(result.error, DeviceException)
+            assert "crashed" in result.error.message
+            assert "mesh_0_worker[0].worker[0]" in result.error.message
+            assert (
+                f"ranks 0 have operations that have not completed after {timeout} seconds"
+                in result.error.frames[0].name
+            )
+
+
+def test_supervision_heartbeat_failure():
+    (dms, bootstrap) = local_meshes_and_bootstraps(
+        meshes=1,
+        hosts_per_mesh=1,
+        gpus_per_host=2,
+        socket_type=SocketType.UNIX,
+        logging_location=LoggingLocation.DEFAULT,
+        supervision_params=SupervisionParams(
+            # Set a low timeout so heartbeat failure can be detected faster.
+            update_timeout_in_sec=10,
+            query_interval_in_sec=1,
+            update_interval_in_sec=1,
+        ),
+    )
+    assert len(dms) == 1
+    dm = dms[0]
+
+    # Kill a process of a worker actor. This should trigger supervision
+    # heartbeat failure event.
+    # Index 0 and 1 are system process and controller process respectively.
+    process = bootstrap.processes[2]
+    process.kill()
+
+    for _ in range(20):
+        # poll the next message in order to get the supervision failure
+        result = dm.client.inner.next_message(3)
+        if result is None:
+            continue
+        if result.error is None:
+            continue
+        assert isinstance(result.error, DeviceException)
+        assert "crashed" in result.error.message
+        return
+
+    dm.exit()
+    raise AssertionError("Should have failed supervision health check")
+
+
+def test_supervision_system_actor_down():
+    (dms, bootstrap) = local_meshes_and_bootstraps(
+        meshes=1,
+        hosts_per_mesh=1,
+        gpus_per_host=2,
+        socket_type=SocketType.UNIX,
+        logging_location=LoggingLocation.DEFAULT,
+        supervision_params=SupervisionParams(
+            # Set a low timeout so heartbeat failure can be detected faster.
+            update_timeout_in_sec=10,
+            query_interval_in_sec=1,
+            update_interval_in_sec=1,
+        ),
+    )
+    assert len(dms) == 1
+    dm = dms[0]
+
+    # Index 0 is system process
+    process = bootstrap.processes[0]
+    process.kill()
+
+    try:
+        for _ in range(20):
+            # poll the next message in order to get the supervision failure
+            dm.client.inner.next_message(3)
+    except RuntimeError as e:
+        assert "actor has been stopped" in str(e)
+        return
+
+    dm.exit()
+    raise AssertionError("Should have failed supervision health check")
+
+
+def test_supervision_controller_actor_down():
+    (dms, bootstrap) = local_meshes_and_bootstraps(
+        meshes=1,
+        hosts_per_mesh=1,
+        gpus_per_host=2,
+        socket_type=SocketType.UNIX,
+        logging_location=LoggingLocation.DEFAULT,
+        supervision_params=SupervisionParams(
+            # Set a low timeout so heartbeat failure can be detected faster.
+            update_timeout_in_sec=10,
+            query_interval_in_sec=1,
+            update_interval_in_sec=1,
+        ),
+    )
+    assert len(dms) == 1
+    dm = dms[0]
+
+    # Index 1 is controller process
+    process = bootstrap.processes[1]
+    process.kill()
+
+    for _ in range(20):
+        # poll the next message in order to get the supervision failure
+        result = dm.client.inner.next_message(3)
+        if result is None:
+            continue
+        if result.error is None:
+            continue
+        assert isinstance(result.error, DeviceException)
+        assert "mesh_0_controller[0].controller[0] crashed" in result.error.message
+        return
+
+    dm.exit()
+    raise AssertionError("Should have failed supervision health check")
+
+
+def a_function_called_by_a_live_function(x):
+    return 2 * x
+
+
+def a_live_function_call_by_a_live_function(x):
+    return 3 * x
+
+
+def test_delete_refs():
+    with local_mesh(
+        hosts=2,
+        gpus_per_host=2,
+        socket_type=SocketType.UNIX,
+        logging_location=LoggingLocation.DEFAULT,
+    ) as dm:
+        dm.client.delete_ref(dm, 1)
+        dm.client.delete_ref(dm, 2)
+        assert len(dm.client._pending_del[dm]) == 2
+        dm.client.flush_deletes()
+        assert len(dm.client._pending_del[dm]) == 0