torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
tests/test_coalescing.py
ADDED
@@ -0,0 +1,492 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
|
9
|
+
import itertools
|
10
|
+
from contextlib import contextmanager
|
11
|
+
from enum import Enum
|
12
|
+
from typing import ContextManager, List
|
13
|
+
from unittest.mock import patch
|
14
|
+
|
15
|
+
import monarch
|
16
|
+
|
17
|
+
import pytest
|
18
|
+
|
19
|
+
import torch
|
20
|
+
from monarch import (
|
21
|
+
coalescing,
|
22
|
+
DeviceMesh,
|
23
|
+
fetch_shard,
|
24
|
+
get_active_mesh,
|
25
|
+
get_active_stream,
|
26
|
+
no_mesh,
|
27
|
+
remote,
|
28
|
+
Stream,
|
29
|
+
)
|
30
|
+
from monarch._testing import TestingContext
|
31
|
+
from monarch.common._coalescing import _record_and_define, compile
|
32
|
+
from monarch.common.function_caching import AliasOf, Storage, TensorGroup
|
33
|
+
from monarch.common.tensor import Tensor
|
34
|
+
|
35
|
+
|
36
|
+
def _do_bogus_tensor_work(x, y, fail_rank=None):
|
37
|
+
return x + y # real function actually does x @ y
|
38
|
+
|
39
|
+
|
40
|
+
# Remote handle for the function at the dotted path below, with
# `_do_bogus_tensor_work` supplied as its `propagate` callable.
# NOTE(review): presumably `propagate` is what runs on the client side to
# derive the result's metadata -- confirm against `monarch.remote`.
do_bogus_tensor_work = remote(
    "monarch.worker._testing_function.do_bogus_tensor_work",
    propagate=_do_bogus_tensor_work,
)
|
44
|
+
|
45
|
+
|
46
|
+
def inspect(x):
    """Fetch ``x`` via ``fetch_shard``, wait for it, and return ``.item()``."""
    pending = fetch_shard(x)
    fetched = pending.result()
    return fetched.item()
|
48
|
+
|
49
|
+
|
50
|
+
@pytest.fixture(scope="module", autouse=True)
def testing_context():
    """Keep one ``TestingContext`` entered for the whole module.

    The value produced on entry is stored in the module global ``local``,
    which ``TestCoalescing.local_device_mesh`` reads.
    """
    global local
    context = TestingContext()
    with context as local:
        yield
|
55
|
+
|
56
|
+
|
57
|
+
class BackendType(Enum):
    """Backend selector used to parametrize ``TestCoalescing``."""

    # NOTE(review): presumably "py" selects the Python backend and "rs" the
    # Rust backend -- confirm against TestingContext.local_device_mesh.
    PY = "py"
    RS = "rs"
|
60
|
+
|
61
|
+
|
62
|
+
@pytest.mark.skipif(
|
63
|
+
torch.cuda.device_count() < 2,
|
64
|
+
reason="Not enough GPUs, this test requires at least 2 GPUs",
|
65
|
+
)
|
66
|
+
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
|
67
|
+
class TestCoalescing:
|
68
|
+
@classmethod
def local_device_mesh(
    cls,
    num_hosts: int,
    gpu_per_host: int,
    backend_type: BackendType,
    activate: bool = True,
) -> ContextManager[DeviceMesh]:
    """Open a device mesh through the module-wide testing context.

    Args:
        num_hosts: forwarded positionally to ``local.local_device_mesh``.
        gpu_per_host: forwarded positionally to ``local.local_device_mesh``.
        backend_type: backend selector; its enum *value* is forwarded as the
            ``backend`` keyword. A plain string is forwarded unchanged.
        activate: forwarded positionally to ``local.local_device_mesh``.

    Returns:
        Whatever ``local.local_device_mesh`` returns.
    """
    # ``str()`` of a plain ``Enum`` member is "BackendType.PY", not "py", so
    # the member's value is forwarded instead.
    # NOTE(review): confirm the callee expects the short names "py"/"rs".
    if isinstance(backend_type, Enum):
        backend = backend_type.value
    else:
        backend = str(backend_type)
    # pyre-fixme[10]: pytest defines this fixture.
    return local.local_device_mesh(
        num_hosts,
        gpu_per_host,
        activate,
        backend=backend,
    )
|
83
|
+
|
84
|
+
@property
def num_outstanding_messages(self) -> int:
    """Total number of messages held in the active mesh's recorder.

    Sums the lengths of every value in
    ``get_active_mesh().client.recorder.flat_messages``.
    """
    pending = get_active_mesh().client.recorder.flat_messages
    return sum(map(len, pending.values()))
|
90
|
+
|
91
|
+
def test_basic_coalescing(self, backend_type) -> None:
    """Messages pile up inside ``coalescing()`` and are gone after it exits."""
    with self.local_device_mesh(1, 1, backend_type):
        with coalescing():
            acc = torch.zeros(3, 4)
            for _ in range(9):
                acc = acc + torch.ones(3, 4)
            # No messages should have been sent while coalescing is enabled,
            # so they must still be sitting in the recorder.
            assert self.num_outstanding_messages >= 10
        # Leaving the coalescing block should have flushed the messages.
        assert self.num_outstanding_messages == 0
|
101
|
+
|
102
|
+
def test_repeat_simple(self, backend_type) -> None:
    """Call a compiled function three times; it mutates a captured tensor."""
    with self.local_device_mesh(1, 1, backend_type):
        total = torch.zeros(())

        @compile(verify=False)
        def bump():
            nonlocal total
            one = torch.ones(())
            total += one
            return one

        latest = None
        for _ in range(3):
            latest = bump()

        # Three in-place additions of 1 onto a zero scalar.
        assert inspect(total) == 3
        assert inspect(latest) == 1
|
119
|
+
|
120
|
+
def test_repeat_formals(self, backend_type) -> None:
    """Repeated compiled calls compute ``2 * a + b`` for fresh arguments."""
    with self.local_device_mesh(1, 1, backend_type):
        base = torch.rand(3, 4)

        # First variant: both operands arrive as formal parameters.
        @compile(verify=False)
        def affine(a, b):
            return 2 * a + b

        for _ in range(3):
            rhs = torch.rand(3, 4)
            out = affine(base, rhs)
            local_out, local_base, local_rhs = monarch.inspect((out, base, rhs))
            assert isinstance(local_base, torch.Tensor)
            assert isinstance(local_rhs, torch.Tensor)
            with no_mesh.activate():
                assert torch.allclose(local_out, 2 * local_base + local_rhs)

        # Second variant (rebinds the same name): the first operand is
        # captured from the enclosing scope instead of being passed in.
        @compile(verify=False)
        def affine(b):
            return 2 * base + b

        for _ in range(3):
            rhs = torch.rand(3, 4)
            out = affine(rhs)
            local_out, local_base, local_rhs = monarch.inspect((out, base, rhs))
            assert isinstance(local_base, torch.Tensor)
            assert isinstance(local_rhs, torch.Tensor)
            with no_mesh.activate():
                assert torch.allclose(local_out, 2 * local_base + local_rhs)
|
149
|
+
|
150
|
+
def test_repeat_error_inside(self, backend_type) -> None:
    """An error raised inside a compiled function surfaces when fetching."""
    with self.local_device_mesh(1, 1, backend_type):
        total = torch.zeros(())

        @compile(verify=False)
        def failing():
            nonlocal total
            one = torch.ones(())
            total += one
            do_bogus_tensor_work(one, one)
            return one

        returned = failing()
        # Recorded coalescing lumps errors together, so fetching the returned
        # tensor is expected to raise the matmul error.
        expected_failure = pytest.raises(Exception, match="both arguments to matmul")
        with expected_failure:
            inspect(returned)
|
166
|
+
|
167
|
+
def test_repeat_inner_borrow(self, backend_type) -> None:
    """Replay a compiled function that takes and releases a borrow inside."""
    with self.local_device_mesh(1, 1, backend_type):
        total = torch.zeros(())
        side_stream = Stream("other")
        with side_stream.activate():
            one = torch.ones(())

        @compile(verify=False)
        def add_borrowed():
            nonlocal total, one
            current = get_active_stream()
            borrowed, lease = current.borrow(one)
            with lease:
                total += borrowed

        for _ in range(3):
            add_borrowed()

        # Three additions of the borrowed 1 onto a zero scalar.
        assert inspect(total) == 3
|
185
|
+
|
186
|
+
def test_repeat_outer_borrow(self, backend_type) -> None:
    """Replay a compiled function that uses a borrow taken outside of it."""
    with self.local_device_mesh(1, 1, backend_type):
        a = torch.zeros(())
        other = Stream("other")
        with other.activate():
            b = torch.ones(())
        # Borrow `b` (created while `other` was active) via whichever stream
        # is active at this point.
        c, borrow = get_active_stream().borrow(b)

        @compile(verify=False)
        def fn():
            nonlocal a, c
            a += c
            z = torch.rand(3, 4)
            # NOTE(review): this unbinds the enclosing `c`, so re-running this
            # Python body would fail on `a += c`; the repeated calls below
            # presumably replay the first recording instead -- confirm.
            del c
            return z

        with borrow:
            z = None
            for _ in range(3):
                z = fn()

        # Wait for both tensors; only `a` is checked for a value.
        result = fetch_shard(a).result()
        fetch_shard(z).result()
        with no_mesh.activate():
            assert result.item() == 3
|
211
|
+
|
212
|
+
def test_nested_coalescing(self, backend_type) -> None:
    """Only leaving the outermost ``coalescing()`` block flushes messages."""
    with self.local_device_mesh(1, 1, backend_type):
        with coalescing():
            acc = torch.zeros(3, 4)
            with coalescing():
                for _ in range(9):
                    acc = acc + torch.ones(3, 4)
                # Confirm that there are messages waiting to be sent.
                assert self.num_outstanding_messages >= 10
            # Only the nested block has exited, so nothing should have been
            # flushed yet.
            assert self.num_outstanding_messages >= 10
        # Now that the outer block is done the messages should be flushed.
        assert self.num_outstanding_messages == 0
|
225
|
+
|
226
|
+
def test_no_coalescing(self, backend_type) -> None:
    """Outside ``coalescing()`` nothing stays queued in the recorder."""
    with self.local_device_mesh(1, 1, backend_type):
        acc = torch.zeros(3, 4)
        for _ in range(9):
            acc = acc + torch.ones(3, 4)
        # Without coalescing the messages are sent right away, leaving
        # nothing outstanding.
        assert self.num_outstanding_messages == 0
|
233
|
+
|
234
|
+
@contextmanager
def assertRecorded(self, times: int):
    """Context manager asserting how often ``_record_and_define`` is called.

    ``monarch.common._coalescing._record_and_define`` is patched with a mock
    whose ``side_effect`` is the original function; after the managed body
    finishes, the mock's call count must equal ``times``.
    """
    spy = patch(
        "monarch.common._coalescing._record_and_define",
        side_effect=_record_and_define,
    )
    with spy as record_mock:
        yield
        assert record_mock.call_count == times
|
242
|
+
|
243
|
+
def assertAliases(self, tensors: List[Tensor], aliasing: List[int]):
|
244
|
+
group = TensorGroup([t._fake for t in tensors])
|
245
|
+
c = iter(itertools.count())
|
246
|
+
actual = []
|
247
|
+
assert len(group.pattern.entries) == len(tensors)
|
248
|
+
assert len(aliasing) == len(tensors)
|
249
|
+
for e in group.pattern.entries:
|
250
|
+
match e.storage:
|
251
|
+
case AliasOf(offset=offset):
|
252
|
+
actual.append(offset)
|
253
|
+
case Storage():
|
254
|
+
actual.append(next(c))
|
255
|
+
assert aliasing == actual
|
256
|
+
|
257
|
+
def test_compile_aliasing(self, backend_type) -> None:
    """Check recording counts and aliasing patterns of compiled functions.

    Each ``assertRecorded(n)`` block asserts how many times
    ``_record_and_define`` ran inside it; ``assertAliases`` checks the
    storage pattern of the listed tensors.
    """
    with self.local_device_mesh(1, 1, backend_type):

        @compile(verify=False)
        def add(a, b):
            return a + b

        @compile(verify=False)
        def return_cond(a, b, c):
            # Returns one of its inputs unchanged, selected by `c`.
            if c:
                return a
            else:
                return b

        a = torch.rand(3, 4)
        b = torch.rand(3, 4)
        # Two calls with swapped (3, 4) operands: exactly one recording.
        with self.assertRecorded(1):
            r = add(a, b)
            assert r.size() == (3, 4)
            r2 = add(b, a)
            self.assertAliases([a, b, r2, r], [0, 1, 2, 3])

        c = torch.rand(4)
        d = torch.rand(4, 4)
        # Operands of sizes (4,) and (4, 4): one further recording.
        with self.assertRecorded(1):
            e = add(c, d)
            assert e.size() == (4, 4)
            e = add(c, torch.rand(4, 4))
            assert e.size() == (4, 4)

        # A Python scalar as second argument: one further recording.
        with self.assertRecorded(1):
            r = add(a, 4)
            self.assertAliases([r, a], [0, 1])

        # `c=True` returns the first argument, so the result aliases it.
        with self.assertRecorded(1):
            r0 = return_cond(a, b, True)
            self.assertAliases([a, b, r0], [0, 1, 0])
            r1 = return_cond(b, a, True)
            self.assertAliases([a, b, r1], [0, 1, 1])

        # `c=False` returns the second argument instead.
        with self.assertRecorded(1):
            r0 = return_cond(a, b, False)
            self.assertAliases([a, b, r0], [0, 1, 1])
            r1 = return_cond(a, b, False)
            self.assertAliases([b, a, r1], [0, 1, 0])

        @compile(verify=False)
        def captured(b):
            # `a` is taken from the enclosing test scope, not passed in.
            return a + b

        with self.assertRecorded(1):
            r = captured(b)
            self.assertAliases([a, b, r], [0, 1, 2])
            r = captured(torch.rand(3, 4))
            assert r.size() == (3, 4)

        with self.assertRecorded(1):
            # input aliased with capture
            captured(a)
            captured(a)

        @compile(verify=False)
        def weird(f, g):
            # Returns a fresh tensor, an index into it, and indexes into both
            # inputs and into the captured `a`.
            o = f + g
            return o, o[0], f[0], g[0], a[0]

        with self.assertRecorded(1):
            r0, r1, r2, r3, r4 = weird(c, d)
            self.assertAliases(
                [c, d, a, r0, r1, r2, r3, r4], [0, 1, 2, 3, 3, 0, 1, 2]
            )
|
328
|
+
|
329
|
+
def test_compile_input_permissions(self, backend_type):
    """Check when a compiled function may read a captured tensor.

    Covers a mutable borrow (expects ``TypeError`` matching "BORROWED"),
    an immutable borrow, a borrowed argument, and a dropped capture
    (expects ``TypeError`` matching "DROPPED").
    """
    with self.local_device_mesh(1, 1, backend_type):
        a = torch.rand(3, 4)

        @compile(verify=False)
        def add(b):
            # Reads the captured `a` from the enclosing test scope.
            return a + b

        with self.assertRecorded(1):
            c = add(torch.rand(3, 4))

        other = Stream("other")
        ab, borrow = other.borrow(a, mutable=True)

        # While the mutable borrow is held, calling `add` must be rejected.
        with borrow:
            with pytest.raises(TypeError, match="BORROWED"):
                add(torch.rand(3, 4))

        # test we can read it again
        add(torch.rand(3, 4))

        # A borrow taken without `mutable=True` does not block `add`.
        ab, borrow = other.borrow(a)
        with borrow:
            add(torch.rand(3, 4))

        # Passing a borrowed tensor as the argument must not re-record.
        with self.assertRecorded(0):
            with other.activate():
                c = torch.rand(3, 4)
            c, borrow = monarch.get_active_stream().borrow(c)
            with borrow:
                add(c)

        a.drop()

        # After the captured tensor is dropped, `add` must be rejected.
        with pytest.raises(TypeError, match="DROPPED"):
            add(torch.rand(3, 4))
|
365
|
+
|
366
|
+
def test_compile_verify(self, backend_type):
    """Exercise ``compile(verify=True)`` on a stable and a diverging function."""
    with self.local_device_mesh(1, 1, backend_type):
        a = torch.rand(3, 4)

        @compile(verify=True)
        def add(b):
            return a + b

        # Flag read by `add_broken`; flipped to True before its second call.
        c = False

        @compile(verify=True)
        def add_broken(b):
            nonlocal c
            # `a` is a local here; it is built with a different torch call
            # depending on `c`.
            if c:
                a = torch.zeros(3, 4)
            else:
                a = torch.rand(3, 4)
            return a.add(b)

        # Three calls, exactly two recordings.
        # NOTE(review): presumably the second recording is the verification
        # pass and the third call reuses it -- confirm against `compile`.
        with self.assertRecorded(2):
            add(torch.rand(3, 4))
            add(torch.rand(3, 4))
            add(torch.rand(3, 4))

        add_broken(torch.rand(3, 4))
        # With `c` flipped the second call must raise a RuntimeError whose
        # message matches "diverges".
        with pytest.raises(RuntimeError, match="diverges"):
            c = True
            add_broken(torch.rand(3, 4))
|
394
|
+
|
395
|
+
def test_dropped(self, backend_type):
    """A tensor assigned to a nonlocal inside a compiled function is unusable.

    Using it afterwards must raise ``TypeError`` matching "DROPPED".
    """
    with self.local_device_mesh(1, 1, backend_type):
        source = torch.rand(3, 4)
        leaked = None

        @compile(verify=False)
        def leak():
            nonlocal leaked
            leaked = source + source

        leak()
        expect_dropped = pytest.raises(TypeError, match="DROPPED")
        with expect_dropped:
            leaked.add(4)
|
408
|
+
|
409
|
+
def test_across_mesh(self, backend_type):
    """Run one compiled function whose body activates two sub-meshes.

    ``foo`` computes ``a + a`` with ``m0`` active and ``b + b`` with ``m1``
    active; each result is then inspected under the mesh that produced it.
    """
    with self.local_device_mesh(2, 1, backend_type) as m:
        m0 = m(host=0)
        m1 = m(host=1)

        @compile
        def foo(a, b):
            with m0.activate():
                r0 = a + a
            with m1.activate():
                r1 = b + b
            return r0, r1

        # Create each input while its own sub-mesh is active.
        with m0.activate():
            a = torch.rand(3, 4)
        with m1.activate():
            b = torch.rand(3, 4)

        r0, r1 = foo(a, b)
        with m0.activate():
            monarch.inspect(r0)
        with m1.activate():
            # Previously this inspected r0 a second time, so the result
            # computed under m1 was never fetched.
            monarch.inspect(r1)
|
432
|
+
|
433
|
+
def test_grad_not_supported(self, backend_type):
    """Compiled functions reject tensors created with ``requires_grad=True``.

    Both passing such a tensor as an argument and returning a captured one
    must raise ``TypeError`` matching "REQUIRES_GRAD".
    """
    with self.local_device_mesh(1, 1, backend_type):

        @compile
        def identity(x):
            return x

        captured_grad = torch.rand(3, requires_grad=True)

        @compile
        def return_captured():
            return captured_grad

        # Case 1: the grad-requiring tensor is an argument.
        with pytest.raises(TypeError, match="REQUIRES_GRAD"):
            identity(torch.rand(3, requires_grad=True))

        # Case 2: the grad-requiring tensor is captured and returned.
        with pytest.raises(TypeError, match="REQUIRES_GRAD"):
            return_captured()
|
451
|
+
|
452
|
+
def test_mutate_inputs(self, backend_type):
    """Check the mutated/used sets passed to ``new_node_nocoalesce``.

    ``mesh.client.new_node_nocoalesce`` is wrapped in a mock that still
    calls the original; after each call to the compiled ``foo`` the first
    two positional arguments of the most recent call are compared against
    the expected alias and input sets.
    """
    with self.local_device_mesh(1, 1, backend_type) as mesh:

        @compile(verify=False)
        def foo(x_not_mutated, w_not_mutated, y, y_alias, z, z_alias):
            # `u` only reads x, w and z_alias.
            u = (
                x_not_mutated.mul(2.0)
                + w_not_mutated
                + z_alias.unsqueeze(0).repeat(3, 1)
            )
            # `v` only reads y (out-of-place add).
            v = y.add(5.0)
            # y_alias is mutated in place through a mutable borrow on a
            # separate stream; z is mutated in place directly.
            stream = monarch.Stream("borrow")
            borrowed_y_alias, y_alias_borrow = stream.borrow(y_alias, mutable=True)
            with stream.activate():
                borrowed_y_alias.add_(1.0)
            y_alias_borrow.drop()
            z.add_(1.0)
            return u, v

        x_not_mutated = torch.rand(3, 3)
        w_not_mutated = torch.rand(3, 3)
        y = torch.rand(3, 3)
        y_alias = y.reshape(-1)
        z = torch.rand(3, 3)
        z_alias = z[0, :]

        mutated_inputs = (y, y_alias, z, z_alias)
        # Union of the `_aliases.aliases` collections of all four inputs.
        mutated_aliases = set().union(*[t._aliases.aliases for t in mutated_inputs])
        all_inputs = (x_not_mutated, w_not_mutated) + mutated_inputs
        with patch.object(
            mesh.client,
            "new_node_nocoalesce",
            side_effect=mesh.client.new_node_nocoalesce,
        ) as new_node:
            for _ in range(2):
                u, v = foo(*all_inputs)
                # Positional args of the latest call; the first two are the
                # mutated and used collections, the rest are ignored.
                (mutated, used, _, _), _ = new_node.call_args
                # The outputs' aliases are expected in the mutated set too.
                assert mutated_aliases.union(
                    u._aliases.aliases, v._aliases.aliases
                ) == set(mutated)
                assert set(all_inputs) == set(used)
|