torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
@@ -0,0 +1,112 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import importlib.resources
8
+ import subprocess
9
+
10
+ import pytest
11
+ from monarch.actor_mesh import Actor, ActorMeshRefCallFailedException, endpoint
12
+
13
+ from monarch.proc_mesh import proc_mesh
14
+
15
+
16
+ class ExceptionActor(Actor):
17
+ """An actor that has endpoints which raise exceptions."""
18
+
19
+ @endpoint
20
+ async def raise_exception(self) -> None:
21
+ """Endpoint that raises an exception."""
22
+ raise Exception("This is a test exception")
23
+
24
+
25
+ class ExceptionActorSync(Actor):
26
+ """An actor that has endpoints which raise exceptions."""
27
+
28
+ @endpoint # pyre-ignore
29
+ def raise_exception(self) -> None:
30
+ """Endpoint that raises an exception."""
31
+ raise Exception("This is a test exception")
32
+
33
+
34
+ @pytest.mark.parametrize(
35
+ "actor_class,actor_name",
36
+ [
37
+ (ExceptionActor, "exception_actor_async_call"),
38
+ (ExceptionActorSync, "exception_actor_sync_call"),
39
+ ],
40
+ )
41
+ @pytest.mark.parametrize("num_procs", [1, 2])
42
+ async def test_actor_exception(actor_class, actor_name, num_procs):
43
+ """
44
+ Test that exceptions raised in actor endpoints are propagated to the client.
45
+ """
46
+ proc = await proc_mesh(gpus=num_procs)
47
+ exception_actor = await proc.spawn(actor_name, actor_class)
48
+
49
+ with pytest.raises(
50
+ ActorMeshRefCallFailedException, match="This is a test exception"
51
+ ):
52
+ if num_procs == 1:
53
+ await exception_actor.raise_exception.call_one()
54
+ else:
55
+ await exception_actor.raise_exception.call()
56
+
57
+
58
+ @pytest.mark.parametrize(
59
+ "actor_class,actor_name",
60
+ [
61
+ (ExceptionActor, "exception_actor_async_call"),
62
+ (ExceptionActorSync, "exception_actor_sync_call"),
63
+ ],
64
+ )
65
+ @pytest.mark.parametrize("num_procs", [1, 2])
66
+ def test_actor_exception_sync(actor_class, actor_name, num_procs):
67
+ """
68
+ Test that exceptions raised in actor endpoints are propagated to the client.
69
+ """
70
+ proc = proc_mesh(gpus=num_procs).get()
71
+ exception_actor = proc.spawn(actor_name, actor_class).get()
72
+
73
+ with pytest.raises(
74
+ ActorMeshRefCallFailedException, match="This is a test exception"
75
+ ):
76
+ if num_procs == 1:
77
+ exception_actor.raise_exception.call_one().get()
78
+ else:
79
+ exception_actor.raise_exception.call().get()
80
+
81
+
82
+ # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
83
+ @pytest.mark.oss_skip
84
+ @pytest.mark.parametrize("num_procs", [1, 2])
85
+ @pytest.mark.parametrize("sync_endpoint", [False, True])
86
+ @pytest.mark.parametrize("sync_test_impl", [False, True])
87
+ @pytest.mark.parametrize("endpoint_name", ["cause_segfault", "cause_panic"])
88
+ def test_actor_segfault(num_procs, sync_endpoint, sync_test_impl, endpoint_name):
89
+ """
90
+ Test that segfaults in actor endpoints result in a non-zero exit code.
91
+ This test spawns a subprocess that will segfault and checks its exit code.
92
+
93
+ Tests both ExceptionActor and ExceptionActorSync using async API.
94
+ """
95
+ # Run the segfault test in a subprocess
96
+ test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
97
+ cmd = [
98
+ str(test_bin),
99
+ f"--num-procs={num_procs}",
100
+ f"--sync-endpoint={sync_endpoint}",
101
+ f"--sync-test-impl={sync_test_impl}",
102
+ f"--endpoint-name={endpoint_name}",
103
+ ]
104
+ process = subprocess.run(cmd, capture_output=True, timeout=60)
105
+ print(process.stdout.decode())
106
+ print(process.stderr.decode())
107
+
108
+ # Assert that the subprocess actually ran the endpoint and then exited with a non-zero code
109
+ assert "I actually ran" in process.stdout.decode()
110
+ assert (
111
+ process.returncode != 0
112
+ ), f"Expected non-zero exit code, got {process.returncode}"
tests/test_alloc.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ from unittest import IsolatedAsyncioTestCase
10
+
11
+ from monarch import ProcessAllocator
12
+ from monarch._rust_bindings.hyperactor_extension.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension
13
+ AllocConstraints,
14
+ AllocSpec,
15
+ )
16
+
17
+
18
+ class TestAlloc(IsolatedAsyncioTestCase):
19
+ async def test_basic(self) -> None:
20
+ cmd = "echo hello"
21
+ allocator = ProcessAllocator(cmd)
22
+ spec = AllocSpec(AllocConstraints(), replica=2)
23
+ alloc = await allocator.allocate(spec)
24
+
25
+ print(alloc)
@@ -0,0 +1,492 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import itertools
10
+ from contextlib import contextmanager
11
+ from enum import Enum
12
+ from typing import ContextManager, List
13
+ from unittest.mock import patch
14
+
15
+ import monarch
16
+
17
+ import pytest
18
+
19
+ import torch
20
+ from monarch import (
21
+ coalescing,
22
+ DeviceMesh,
23
+ fetch_shard,
24
+ get_active_mesh,
25
+ get_active_stream,
26
+ no_mesh,
27
+ remote,
28
+ Stream,
29
+ )
30
+ from monarch._testing import TestingContext
31
+ from monarch.common._coalescing import _record_and_define, compile
32
+ from monarch.common.function_caching import AliasOf, Storage, TensorGroup
33
+ from monarch.common.tensor import Tensor
34
+
35
+
36
+ def _do_bogus_tensor_work(x, y, fail_rank=None):
37
+ return x + y # real function actually does x @ y
38
+
39
+
40
+ do_bogus_tensor_work = remote(
41
+ "monarch.worker._testing_function.do_bogus_tensor_work",
42
+ propagate=_do_bogus_tensor_work,
43
+ )
44
+
45
+
46
+ def inspect(x):
47
+ return fetch_shard(x).result().item()
48
+
49
+
50
+ @pytest.fixture(scope="module", autouse=True)
51
+ def testing_context():
52
+ global local
53
+ with TestingContext() as local:
54
+ yield
55
+
56
+
57
+ class BackendType(Enum):
58
+ PY = "py"
59
+ RS = "rs"
60
+
61
+
62
+ @pytest.mark.skipif(
63
+ torch.cuda.device_count() < 2,
64
+ reason="Not enough GPUs, this test requires at least 2 GPUs",
65
+ )
66
+ @pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
67
+ class TestCoalescing:
68
+ @classmethod
69
+ def local_device_mesh(
70
+ cls,
71
+ num_hosts: int,
72
+ gpu_per_host: int,
73
+ backend_type: BackendType,
74
+ activate: bool = True,
75
+ ) -> ContextManager[DeviceMesh]:
76
+ # pyre-fixme[10]: pytest defines this fixture.
77
+ return local.local_device_mesh(
78
+ num_hosts,
79
+ gpu_per_host,
80
+ activate,
81
+ rust=backend_type == BackendType.RS,
82
+ )
83
+
84
+ @property
85
+ def num_outstanding_messages(self) -> int:
86
+ return sum(
87
+ len(msgs)
88
+ for msgs in get_active_mesh().client.recorder.flat_messages.values()
89
+ )
90
+
91
+ def test_basic_coalescing(self, backend_type) -> None:
92
+ with self.local_device_mesh(1, 1, backend_type):
93
+ with coalescing():
94
+ a = torch.zeros(3, 4)
95
+ for _ in range(1, 10):
96
+ a = a + torch.ones(3, 4)
97
+ # no messages should have been sent since coalescing is enabled
98
+ assert self.num_outstanding_messages >= 10
99
+ # now that the coalesce is done we should have flushed the messages
100
+ assert self.num_outstanding_messages == 0
101
+
102
+ def test_repeat_simple(self, backend_type) -> None:
103
+ with self.local_device_mesh(1, 1, backend_type):
104
+ a = torch.zeros(())
105
+
106
+ @compile(verify=False)
107
+ def fn():
108
+ nonlocal a
109
+ z = torch.ones(())
110
+ a += z
111
+ return z
112
+
113
+ z = None
114
+ for _ in range(3):
115
+ z = fn()
116
+
117
+ assert inspect(a) == 3
118
+ assert inspect(z) == 1
119
+
120
+ def test_repeat_formals(self, backend_type) -> None:
121
+ with self.local_device_mesh(1, 1, backend_type):
122
+ a = torch.rand(3, 4)
123
+
124
+ @compile(verify=False)
125
+ def fn(a, b):
126
+ return 2 * a + b
127
+
128
+ for _ in range(3):
129
+ b = torch.rand(3, 4)
130
+ z = fn(a, b)
131
+ lz, la, lb = monarch.inspect((z, a, b))
132
+ assert isinstance(la, torch.Tensor)
133
+ assert isinstance(lb, torch.Tensor)
134
+ with no_mesh.activate():
135
+ assert torch.allclose(lz, 2 * la + lb)
136
+
137
+ @compile(verify=False)
138
+ def fn(b):
139
+ return 2 * a + b
140
+
141
+ for _ in range(3):
142
+ b = torch.rand(3, 4)
143
+ z = fn(b)
144
+ lz, la, lb = monarch.inspect((z, a, b))
145
+ assert isinstance(la, torch.Tensor)
146
+ assert isinstance(lb, torch.Tensor)
147
+ with no_mesh.activate():
148
+ assert torch.allclose(lz, 2 * la + lb)
149
+
150
+ def test_repeat_error_inside(self, backend_type) -> None:
151
+ with self.local_device_mesh(1, 1, backend_type):
152
+ a = torch.zeros(())
153
+
154
+ @compile(verify=False)
155
+ def fn():
156
+ nonlocal a
157
+ z = torch.ones(())
158
+ a += z
159
+ do_bogus_tensor_work(z, z)
160
+ return z
161
+
162
+ z = fn()
163
+ # recorded coalescing will lump errors together so check that
164
+ with pytest.raises(Exception, match="both arguments to matmul"):
165
+ inspect(z)
166
+
167
+ def test_repeat_inner_borrow(self, backend_type) -> None:
168
+ with self.local_device_mesh(1, 1, backend_type):
169
+ a = torch.zeros(())
170
+ other = Stream("other")
171
+ with other.activate():
172
+ b = torch.ones(())
173
+
174
+ @compile(verify=False)
175
+ def fn():
176
+ nonlocal a, b
177
+ c, borrow = get_active_stream().borrow(b)
178
+ with borrow:
179
+ a += c
180
+
181
+ for _ in range(3):
182
+ fn()
183
+
184
+ assert inspect(a) == 3
185
+
186
+ def test_repeat_outer_borrow(self, backend_type) -> None:
187
+ with self.local_device_mesh(1, 1, backend_type):
188
+ a = torch.zeros(())
189
+ other = Stream("other")
190
+ with other.activate():
191
+ b = torch.ones(())
192
+ c, borrow = get_active_stream().borrow(b)
193
+
194
+ @compile(verify=False)
195
+ def fn():
196
+ nonlocal a, c
197
+ a += c
198
+ z = torch.rand(3, 4)
199
+ del c
200
+ return z
201
+
202
+ with borrow:
203
+ z = None
204
+ for _ in range(3):
205
+ z = fn()
206
+
207
+ result = fetch_shard(a).result()
208
+ fetch_shard(z).result()
209
+ with no_mesh.activate():
210
+ assert result.item() == 3
211
+
212
+ def test_nested_coalescing(self, backend_type) -> None:
213
+ with self.local_device_mesh(1, 1, backend_type):
214
+ with coalescing():
215
+ a = torch.zeros(3, 4)
216
+ with coalescing():
217
+ for _ in range(1, 10):
218
+ a = a + torch.ones(3, 4)
219
+ # confirm that there are messages waiting to be sent
220
+ assert self.num_outstanding_messages >= 10
221
+ # since we are in the nested block we shouldn't have flushed the messages yet
222
+ assert self.num_outstanding_messages >= 10
223
+ # now that the outer coalesce is done we should have flushed the messages
224
+ assert self.num_outstanding_messages == 0
225
+
226
+ def test_no_coalescing(self, backend_type) -> None:
227
+ with self.local_device_mesh(1, 1, backend_type):
228
+ a = torch.zeros(3, 4)
229
+ for _ in range(1, 10):
230
+ a = a + torch.ones(3, 4)
231
+ # without coalescing the messages should be sent with nothing outstanding
232
+ assert self.num_outstanding_messages == 0
233
+
234
+ @contextmanager
235
+ def assertRecorded(self, times: int):
236
+ with patch(
237
+ "monarch.common._coalescing._record_and_define",
238
+ side_effect=_record_and_define,
239
+ ) as m:
240
+ yield
241
+ assert m.call_count == times
242
+
243
+ def assertAliases(self, tensors: List[Tensor], aliasing: List[int]):
244
+ group = TensorGroup([t._fake for t in tensors])
245
+ c = iter(itertools.count())
246
+ actual = []
247
+ assert len(group.pattern.entries) == len(tensors)
248
+ assert len(aliasing) == len(tensors)
249
+ for e in group.pattern.entries:
250
+ match e.storage:
251
+ case AliasOf(offset=offset):
252
+ actual.append(offset)
253
+ case Storage():
254
+ actual.append(next(c))
255
+ assert aliasing == actual
256
+
257
+ def test_compile_aliasing(self, backend_type) -> None:
258
+ with self.local_device_mesh(1, 1, backend_type):
259
+
260
+ @compile(verify=False)
261
+ def add(a, b):
262
+ return a + b
263
+
264
+ @compile(verify=False)
265
+ def return_cond(a, b, c):
266
+ if c:
267
+ return a
268
+ else:
269
+ return b
270
+
271
+ a = torch.rand(3, 4)
272
+ b = torch.rand(3, 4)
273
+ with self.assertRecorded(1):
274
+ r = add(a, b)
275
+ assert r.size() == (3, 4)
276
+ r2 = add(b, a)
277
+ self.assertAliases([a, b, r2, r], [0, 1, 2, 3])
278
+
279
+ c = torch.rand(4)
280
+ d = torch.rand(4, 4)
281
+ with self.assertRecorded(1):
282
+ e = add(c, d)
283
+ assert e.size() == (4, 4)
284
+ e = add(c, torch.rand(4, 4))
285
+ assert e.size() == (4, 4)
286
+
287
+ with self.assertRecorded(1):
288
+ r = add(a, 4)
289
+ self.assertAliases([r, a], [0, 1])
290
+
291
+ with self.assertRecorded(1):
292
+ r0 = return_cond(a, b, True)
293
+ self.assertAliases([a, b, r0], [0, 1, 0])
294
+ r1 = return_cond(b, a, True)
295
+ self.assertAliases([a, b, r1], [0, 1, 1])
296
+
297
+ with self.assertRecorded(1):
298
+ r0 = return_cond(a, b, False)
299
+ self.assertAliases([a, b, r0], [0, 1, 1])
300
+ r1 = return_cond(a, b, False)
301
+ self.assertAliases([b, a, r1], [0, 1, 0])
302
+
303
+ @compile(verify=False)
304
+ def captured(b):
305
+ return a + b
306
+
307
+ with self.assertRecorded(1):
308
+ r = captured(b)
309
+ self.assertAliases([a, b, r], [0, 1, 2])
310
+ r = captured(torch.rand(3, 4))
311
+ assert r.size() == (3, 4)
312
+
313
+ with self.assertRecorded(1):
314
+ # input aliased with capture
315
+ captured(a)
316
+ captured(a)
317
+
318
+ @compile(verify=False)
319
+ def weird(f, g):
320
+ o = f + g
321
+ return o, o[0], f[0], g[0], a[0]
322
+
323
+ with self.assertRecorded(1):
324
+ r0, r1, r2, r3, r4 = weird(c, d)
325
+ self.assertAliases(
326
+ [c, d, a, r0, r1, r2, r3, r4], [0, 1, 2, 3, 3, 0, 1, 2]
327
+ )
328
+
329
+ def test_compile_input_permissions(self, backend_type):
330
+ with self.local_device_mesh(1, 1, backend_type):
331
+ a = torch.rand(3, 4)
332
+
333
+ @compile(verify=False)
334
+ def add(b):
335
+ return a + b
336
+
337
+ with self.assertRecorded(1):
338
+ c = add(torch.rand(3, 4))
339
+
340
+ other = Stream("other")
341
+ ab, borrow = other.borrow(a, mutable=True)
342
+
343
+ with borrow:
344
+ with pytest.raises(TypeError, match="BORROWED"):
345
+ add(torch.rand(3, 4))
346
+
347
+ # test we can read it again
348
+ add(torch.rand(3, 4))
349
+
350
+ ab, borrow = other.borrow(a)
351
+ with borrow:
352
+ add(torch.rand(3, 4))
353
+
354
+ with self.assertRecorded(0):
355
+ with other.activate():
356
+ c = torch.rand(3, 4)
357
+ c, borrow = monarch.get_active_stream().borrow(c)
358
+ with borrow:
359
+ add(c)
360
+
361
+ a.drop()
362
+
363
+ with pytest.raises(TypeError, match="DROPPED"):
364
+ add(torch.rand(3, 4))
365
+
366
+ def test_compile_verify(self, backend_type):
367
+ with self.local_device_mesh(1, 1, backend_type):
368
+ a = torch.rand(3, 4)
369
+
370
+ @compile(verify=True)
371
+ def add(b):
372
+ return a + b
373
+
374
+ c = False
375
+
376
+ @compile(verify=True)
377
+ def add_broken(b):
378
+ nonlocal c
379
+ if c:
380
+ a = torch.zeros(3, 4)
381
+ else:
382
+ a = torch.rand(3, 4)
383
+ return a.add(b)
384
+
385
+ with self.assertRecorded(2):
386
+ add(torch.rand(3, 4))
387
+ add(torch.rand(3, 4))
388
+ add(torch.rand(3, 4))
389
+
390
+ add_broken(torch.rand(3, 4))
391
+ with pytest.raises(RuntimeError, match="diverges"):
392
+ c = True
393
+ add_broken(torch.rand(3, 4))
394
+
395
+ def test_dropped(self, backend_type):
396
+ with self.local_device_mesh(1, 1, backend_type):
397
+ a = torch.rand(3, 4)
398
+ b = None
399
+
400
+ @compile(verify=False)
401
+ def foo():
402
+ nonlocal b
403
+ b = a + a
404
+
405
+ foo()
406
+ with pytest.raises(TypeError, match="DROPPED"):
407
+ b.add(4)
408
+
409
+ def test_across_mesh(self, backend_type):
410
+ with self.local_device_mesh(2, 1, backend_type) as m:
411
+ m0 = m(host=0)
412
+ m1 = m(host=1)
413
+
414
+ @compile
415
+ def foo(a, b):
416
+ with m0.activate():
417
+ r0 = a + a
418
+ with m1.activate():
419
+ r1 = b + b
420
+ return r0, r1
421
+
422
+ with m0.activate():
423
+ a = torch.rand(3, 4)
424
+ with m1.activate():
425
+ b = torch.rand(3, 4)
426
+
427
+ r0, r1 = foo(a, b)
428
+ with m0.activate():
429
+ monarch.inspect(r0)
430
+ with m1.activate():
431
+ monarch.inspect(r0)
432
+
433
+ def test_grad_not_supported(self, backend_type):
434
+ with self.local_device_mesh(1, 1, backend_type):
435
+
436
+ @compile
437
+ def foo(x):
438
+ return x
439
+
440
+ y = torch.rand(3, requires_grad=True)
441
+
442
+ @compile
443
+ def returnit():
444
+ return y
445
+
446
+ with pytest.raises(TypeError, match="REQUIRES_GRAD"):
447
+ foo(torch.rand(3, requires_grad=True))
448
+
449
+ with pytest.raises(TypeError, match="REQUIRES_GRAD"):
450
+ returnit()
451
+
452
+ def test_mutate_inputs(self, backend_type):
453
+ with self.local_device_mesh(1, 1, backend_type) as mesh:
454
+
455
+ @compile(verify=False)
456
+ def foo(x_not_mutated, w_not_mutated, y, y_alias, z, z_alias):
457
+ u = (
458
+ x_not_mutated.mul(2.0)
459
+ + w_not_mutated
460
+ + z_alias.unsqueeze(0).repeat(3, 1)
461
+ )
462
+ v = y.add(5.0)
463
+ stream = monarch.Stream("borrow")
464
+ borrowed_y_alias, y_alias_borrow = stream.borrow(y_alias, mutable=True)
465
+ with stream.activate():
466
+ borrowed_y_alias.add_(1.0)
467
+ y_alias_borrow.drop()
468
+ z.add_(1.0)
469
+ return u, v
470
+
471
+ x_not_mutated = torch.rand(3, 3)
472
+ w_not_mutated = torch.rand(3, 3)
473
+ y = torch.rand(3, 3)
474
+ y_alias = y.reshape(-1)
475
+ z = torch.rand(3, 3)
476
+ z_alias = z[0, :]
477
+
478
+ mutated_inputs = (y, y_alias, z, z_alias)
479
+ mutated_aliases = set().union(*[t._aliases.aliases for t in mutated_inputs])
480
+ all_inputs = (x_not_mutated, w_not_mutated) + mutated_inputs
481
+ with patch.object(
482
+ mesh.client,
483
+ "new_node_nocoalesce",
484
+ side_effect=mesh.client.new_node_nocoalesce,
485
+ ) as new_node:
486
+ for _ in range(2):
487
+ u, v = foo(*all_inputs)
488
+ (mutated, used, _, _), _ = new_node.call_args
489
+ assert mutated_aliases.union(
490
+ u._aliases.aliases, v._aliases.aliases
491
+ ) == set(mutated)
492
+ assert set(all_inputs) == set(used)