torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +74 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +198 -0
- monarch/actor_mesh.py +692 -0
- monarch/allocator.py +62 -0
- monarch/bootstrap_main.py +75 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +69 -0
- monarch/cached_remote_function.py +257 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +646 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +443 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +572 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +304 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +204 -0
- monarch/common/stream.py +111 -0
- monarch/common/tensor.py +793 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/fetch.py +55 -0
- monarch/future.py +25 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/proc_mesh.py +188 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +190 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +357 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +189 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +57 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +121 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +139 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +112 -0
- tests/test_alloc.py +25 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +835 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +372 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +182 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
- torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
@@ -0,0 +1,112 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import importlib.resources
|
8
|
+
import subprocess
|
9
|
+
|
10
|
+
import pytest
|
11
|
+
from monarch.actor_mesh import Actor, ActorMeshRefCallFailedException, endpoint
|
12
|
+
|
13
|
+
from monarch.proc_mesh import proc_mesh
|
14
|
+
|
15
|
+
|
16
|
+
class ExceptionActor(Actor):
    """Actor whose single endpoint is a coroutine that always fails."""

    @endpoint
    async def raise_exception(self) -> None:
        """Raise a plain ``Exception`` unconditionally."""
        message = "This is a test exception"
        raise Exception(message)
|
23
|
+
|
24
|
+
|
25
|
+
class ExceptionActorSync(Actor):
    """Actor whose single endpoint is a regular (non-async) method that always fails."""

    @endpoint  # pyre-ignore
    def raise_exception(self) -> None:
        """Raise a plain ``Exception`` unconditionally."""
        message = "This is a test exception"
        raise Exception(message)
|
32
|
+
|
33
|
+
|
34
|
+
@pytest.mark.parametrize(
    "actor_class,actor_name",
    [
        (ExceptionActor, "exception_actor_async_call"),
        (ExceptionActorSync, "exception_actor_sync_call"),
    ],
)
@pytest.mark.parametrize("num_procs", [1, 2])
async def test_actor_exception(actor_class, actor_name, num_procs):
    """
    Awaiting a failing endpoint must raise ``ActorMeshRefCallFailedException``.

    When ``num_procs`` is 1 the endpoint is invoked through ``call_one()``;
    otherwise it is invoked through ``call()``.
    """
    mesh = await proc_mesh(gpus=num_procs)
    actor = await mesh.spawn(actor_name, actor_class)

    with pytest.raises(
        ActorMeshRefCallFailedException, match="This is a test exception"
    ):
        failing = actor.raise_exception
        invoke = failing.call_one if num_procs == 1 else failing.call
        await invoke()
|
56
|
+
|
57
|
+
|
58
|
+
@pytest.mark.parametrize(
    "actor_class,actor_name",
    [
        (ExceptionActor, "exception_actor_async_call"),
        (ExceptionActorSync, "exception_actor_sync_call"),
    ],
)
@pytest.mark.parametrize("num_procs", [1, 2])
def test_actor_exception_sync(actor_class, actor_name, num_procs):
    """
    Blocking on a failing endpoint must raise ``ActorMeshRefCallFailedException``.

    Same scenario as the async test, but every step is resolved with ``.get()``
    instead of ``await``.
    """
    mesh = proc_mesh(gpus=num_procs).get()
    actor = mesh.spawn(actor_name, actor_class).get()

    with pytest.raises(
        ActorMeshRefCallFailedException, match="This is a test exception"
    ):
        failing = actor.raise_exception
        invoke = failing.call_one if num_procs == 1 else failing.call
        invoke().get()
|
80
|
+
|
81
|
+
|
82
|
+
# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
@pytest.mark.oss_skip
@pytest.mark.parametrize("num_procs", [1, 2])
@pytest.mark.parametrize("sync_endpoint", [False, True])
@pytest.mark.parametrize("sync_test_impl", [False, True])
@pytest.mark.parametrize("endpoint_name", ["cause_segfault", "cause_panic"])
def test_actor_segfault(num_procs, sync_endpoint, sync_test_impl, endpoint_name):
    """
    Test that a crashing actor endpoint makes the test binary exit non-zero.

    The scenario runs in a ``test_bin`` subprocess located through
    ``importlib.resources``. It is parametrized over the number of procs,
    sync vs. async endpoint, sync vs. async test implementation, and the
    endpoint to call (``cause_segfault`` or ``cause_panic``). The parameters
    are forwarded as command-line flags; how ``test_bin`` interprets them is
    not visible from this test.
    """
    # Run the crashing scenario in a subprocess
    test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
    cmd = [
        str(test_bin),
        f"--num-procs={num_procs}",
        f"--sync-endpoint={sync_endpoint}",
        f"--sync-test-impl={sync_test_impl}",
        f"--endpoint-name={endpoint_name}",
    ]
    process = subprocess.run(cmd, capture_output=True, timeout=60)

    # Decode stdout once; it is both echoed for debugging and asserted on.
    stdout = process.stdout.decode()
    print(stdout)
    print(process.stderr.decode())

    # The binary must have printed its marker, i.e. it got far enough to run.
    assert "I actually ran" in stdout
    # Assert that the subprocess exited with a non-zero code
    assert (
        process.returncode != 0
    ), f"Expected non-zero exit code, got {process.returncode}"
|
tests/test_alloc.py
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-strict
|
8
|
+
|
9
|
+
from unittest import IsolatedAsyncioTestCase
|
10
|
+
|
11
|
+
from monarch import ProcessAllocator
|
12
|
+
from monarch._rust_bindings.hyperactor_extension.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension
|
13
|
+
AllocConstraints,
|
14
|
+
AllocSpec,
|
15
|
+
)
|
16
|
+
|
17
|
+
|
18
|
+
class TestAlloc(IsolatedAsyncioTestCase):
    """Smoke test: ``ProcessAllocator.allocate`` can be awaited with a simple spec."""

    async def test_basic(self) -> None:
        """Allocate with a ``replica=2`` spec and print whatever comes back."""
        allocator = ProcessAllocator("echo hello")
        two_replica_spec = AllocSpec(AllocConstraints(), replica=2)

        print(await allocator.allocate(two_replica_spec))
|
tests/test_coalescing.py
ADDED
@@ -0,0 +1,492 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
|
9
|
+
import itertools
|
10
|
+
from contextlib import contextmanager
|
11
|
+
from enum import Enum
|
12
|
+
from typing import ContextManager, List
|
13
|
+
from unittest.mock import patch
|
14
|
+
|
15
|
+
import monarch
|
16
|
+
|
17
|
+
import pytest
|
18
|
+
|
19
|
+
import torch
|
20
|
+
from monarch import (
|
21
|
+
coalescing,
|
22
|
+
DeviceMesh,
|
23
|
+
fetch_shard,
|
24
|
+
get_active_mesh,
|
25
|
+
get_active_stream,
|
26
|
+
no_mesh,
|
27
|
+
remote,
|
28
|
+
Stream,
|
29
|
+
)
|
30
|
+
from monarch._testing import TestingContext
|
31
|
+
from monarch.common._coalescing import _record_and_define, compile
|
32
|
+
from monarch.common.function_caching import AliasOf, Storage, TensorGroup
|
33
|
+
from monarch.common.tensor import Tensor
|
34
|
+
|
35
|
+
|
36
|
+
def _do_bogus_tensor_work(x, y, fail_rank=None):
|
37
|
+
return x + y # real function actually does x @ y
|
38
|
+
|
39
|
+
|
40
|
+
# Remote handle for the worker function named by the dotted path below.
# The local `_do_bogus_tensor_work` (which adds instead of matmul-ing) is
# passed as its `propagate` argument.
do_bogus_tensor_work = remote(
    "monarch.worker._testing_function.do_bogus_tensor_work",
    propagate=_do_bogus_tensor_work,
)
|
44
|
+
|
45
|
+
|
46
|
+
def inspect(x):
    """Fetch ``x`` with ``fetch_shard``, wait for the result, and return its ``.item()``."""
    fetched = fetch_shard(x).result()
    return fetched.item()
|
48
|
+
|
49
|
+
|
50
|
+
@pytest.fixture(scope="module", autouse=True)
def testing_context():
    """Module-scoped autouse fixture that keeps a ``TestingContext`` open.

    The context object is bound to the module-level name ``local`` while the
    ``with`` block is active; ``TestCoalescing.local_device_mesh`` reads that
    global.
    """
    # `local` is a module global, written by the `with ... as` below.
    global local
    with TestingContext() as local:
        yield
|
55
|
+
|
56
|
+
|
57
|
+
class BackendType(Enum):
    """Backend selector used to parametrize ``TestCoalescing``."""

    # `local_device_mesh` passes `rust=False` for this member.
    PY = "py"
    # `local_device_mesh` passes `rust=True` for this member.
    RS = "rs"
|
60
|
+
|
61
|
+
|
62
|
+
@pytest.mark.skipif(
|
63
|
+
torch.cuda.device_count() < 2,
|
64
|
+
reason="Not enough GPUs, this test requires at least 2 GPUs",
|
65
|
+
)
|
66
|
+
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
|
67
|
+
class TestCoalescing:
|
68
|
+
@classmethod
def local_device_mesh(
    cls,
    num_hosts: int,
    gpu_per_host: int,
    backend_type: BackendType,
    activate: bool = True,
) -> ContextManager[DeviceMesh]:
    """Delegate to ``local.local_device_mesh`` on the module-level testing context.

    ``rust`` is passed as ``True`` exactly when ``backend_type`` equals
    ``BackendType.RS``; the other arguments are forwarded positionally.
    """
    use_rust = backend_type == BackendType.RS
    # pyre-fixme[10]: pytest defines this fixture.
    return local.local_device_mesh(num_hosts, gpu_per_host, activate, rust=use_rust)
|
83
|
+
|
84
|
+
@property
def num_outstanding_messages(self) -> int:
    """Total number of messages held in the active mesh's recorder.

    Sums the lengths of every value in
    ``get_active_mesh().client.recorder.flat_messages``.
    """
    queued = get_active_mesh().client.recorder.flat_messages
    total = 0
    for msgs in queued.values():
        total += len(msgs)
    return total
|
90
|
+
|
91
|
+
def test_basic_coalescing(self, backend_type) -> None:
|
92
|
+
with self.local_device_mesh(1, 1, backend_type):
|
93
|
+
with coalescing():
|
94
|
+
a = torch.zeros(3, 4)
|
95
|
+
for _ in range(1, 10):
|
96
|
+
a = a + torch.ones(3, 4)
|
97
|
+
# no messages should have been sent since coalescing is enabled
|
98
|
+
assert self.num_outstanding_messages >= 10
|
99
|
+
# now that the coalesce is done we should have flushed the messages
|
100
|
+
assert self.num_outstanding_messages == 0
|
101
|
+
|
102
|
+
def test_repeat_simple(self, backend_type) -> None:
|
103
|
+
with self.local_device_mesh(1, 1, backend_type):
|
104
|
+
a = torch.zeros(())
|
105
|
+
|
106
|
+
@compile(verify=False)
|
107
|
+
def fn():
|
108
|
+
nonlocal a
|
109
|
+
z = torch.ones(())
|
110
|
+
a += z
|
111
|
+
return z
|
112
|
+
|
113
|
+
z = None
|
114
|
+
for _ in range(3):
|
115
|
+
z = fn()
|
116
|
+
|
117
|
+
assert inspect(a) == 3
|
118
|
+
assert inspect(z) == 1
|
119
|
+
|
120
|
+
def test_repeat_formals(self, backend_type) -> None:
|
121
|
+
with self.local_device_mesh(1, 1, backend_type):
|
122
|
+
a = torch.rand(3, 4)
|
123
|
+
|
124
|
+
@compile(verify=False)
|
125
|
+
def fn(a, b):
|
126
|
+
return 2 * a + b
|
127
|
+
|
128
|
+
for _ in range(3):
|
129
|
+
b = torch.rand(3, 4)
|
130
|
+
z = fn(a, b)
|
131
|
+
lz, la, lb = monarch.inspect((z, a, b))
|
132
|
+
assert isinstance(la, torch.Tensor)
|
133
|
+
assert isinstance(lb, torch.Tensor)
|
134
|
+
with no_mesh.activate():
|
135
|
+
assert torch.allclose(lz, 2 * la + lb)
|
136
|
+
|
137
|
+
@compile(verify=False)
|
138
|
+
def fn(b):
|
139
|
+
return 2 * a + b
|
140
|
+
|
141
|
+
for _ in range(3):
|
142
|
+
b = torch.rand(3, 4)
|
143
|
+
z = fn(b)
|
144
|
+
lz, la, lb = monarch.inspect((z, a, b))
|
145
|
+
assert isinstance(la, torch.Tensor)
|
146
|
+
assert isinstance(lb, torch.Tensor)
|
147
|
+
with no_mesh.activate():
|
148
|
+
assert torch.allclose(lz, 2 * la + lb)
|
149
|
+
|
150
|
+
def test_repeat_error_inside(self, backend_type) -> None:
|
151
|
+
with self.local_device_mesh(1, 1, backend_type):
|
152
|
+
a = torch.zeros(())
|
153
|
+
|
154
|
+
@compile(verify=False)
|
155
|
+
def fn():
|
156
|
+
nonlocal a
|
157
|
+
z = torch.ones(())
|
158
|
+
a += z
|
159
|
+
do_bogus_tensor_work(z, z)
|
160
|
+
return z
|
161
|
+
|
162
|
+
z = fn()
|
163
|
+
# recorded coalescing will lump errors together so check that
|
164
|
+
with pytest.raises(Exception, match="both arguments to matmul"):
|
165
|
+
inspect(z)
|
166
|
+
|
167
|
+
def test_repeat_inner_borrow(self, backend_type) -> None:
|
168
|
+
with self.local_device_mesh(1, 1, backend_type):
|
169
|
+
a = torch.zeros(())
|
170
|
+
other = Stream("other")
|
171
|
+
with other.activate():
|
172
|
+
b = torch.ones(())
|
173
|
+
|
174
|
+
@compile(verify=False)
|
175
|
+
def fn():
|
176
|
+
nonlocal a, b
|
177
|
+
c, borrow = get_active_stream().borrow(b)
|
178
|
+
with borrow:
|
179
|
+
a += c
|
180
|
+
|
181
|
+
for _ in range(3):
|
182
|
+
fn()
|
183
|
+
|
184
|
+
assert inspect(a) == 3
|
185
|
+
|
186
|
+
def test_repeat_outer_borrow(self, backend_type) -> None:
|
187
|
+
with self.local_device_mesh(1, 1, backend_type):
|
188
|
+
a = torch.zeros(())
|
189
|
+
other = Stream("other")
|
190
|
+
with other.activate():
|
191
|
+
b = torch.ones(())
|
192
|
+
c, borrow = get_active_stream().borrow(b)
|
193
|
+
|
194
|
+
@compile(verify=False)
|
195
|
+
def fn():
|
196
|
+
nonlocal a, c
|
197
|
+
a += c
|
198
|
+
z = torch.rand(3, 4)
|
199
|
+
del c
|
200
|
+
return z
|
201
|
+
|
202
|
+
with borrow:
|
203
|
+
z = None
|
204
|
+
for _ in range(3):
|
205
|
+
z = fn()
|
206
|
+
|
207
|
+
result = fetch_shard(a).result()
|
208
|
+
fetch_shard(z).result()
|
209
|
+
with no_mesh.activate():
|
210
|
+
assert result.item() == 3
|
211
|
+
|
212
|
+
def test_nested_coalescing(self, backend_type) -> None:
|
213
|
+
with self.local_device_mesh(1, 1, backend_type):
|
214
|
+
with coalescing():
|
215
|
+
a = torch.zeros(3, 4)
|
216
|
+
with coalescing():
|
217
|
+
for _ in range(1, 10):
|
218
|
+
a = a + torch.ones(3, 4)
|
219
|
+
# confirm that there are messages waiting to be sent
|
220
|
+
assert self.num_outstanding_messages >= 10
|
221
|
+
# since we are in the nested block we shouldn't have flushed the messages yet
|
222
|
+
assert self.num_outstanding_messages >= 10
|
223
|
+
# now that the outer coalesce is done we should have flushed the messages
|
224
|
+
assert self.num_outstanding_messages == 0
|
225
|
+
|
226
|
+
def test_no_coalescing(self, backend_type) -> None:
|
227
|
+
with self.local_device_mesh(1, 1, backend_type):
|
228
|
+
a = torch.zeros(3, 4)
|
229
|
+
for _ in range(1, 10):
|
230
|
+
a = a + torch.ones(3, 4)
|
231
|
+
# without coalescing the messages should be sent with nothing outstanding
|
232
|
+
assert self.num_outstanding_messages == 0
|
233
|
+
|
234
|
+
@contextmanager
|
235
|
+
def assertRecorded(self, times: int):
|
236
|
+
with patch(
|
237
|
+
"monarch.common._coalescing._record_and_define",
|
238
|
+
side_effect=_record_and_define,
|
239
|
+
) as m:
|
240
|
+
yield
|
241
|
+
assert m.call_count == times
|
242
|
+
|
243
|
+
def assertAliases(self, tensors: List[Tensor], aliasing: List[int]):
|
244
|
+
group = TensorGroup([t._fake for t in tensors])
|
245
|
+
c = iter(itertools.count())
|
246
|
+
actual = []
|
247
|
+
assert len(group.pattern.entries) == len(tensors)
|
248
|
+
assert len(aliasing) == len(tensors)
|
249
|
+
for e in group.pattern.entries:
|
250
|
+
match e.storage:
|
251
|
+
case AliasOf(offset=offset):
|
252
|
+
actual.append(offset)
|
253
|
+
case Storage():
|
254
|
+
actual.append(next(c))
|
255
|
+
assert aliasing == actual
|
256
|
+
|
257
|
+
def test_compile_aliasing(self, backend_type) -> None:
|
258
|
+
with self.local_device_mesh(1, 1, backend_type):
|
259
|
+
|
260
|
+
@compile(verify=False)
|
261
|
+
def add(a, b):
|
262
|
+
return a + b
|
263
|
+
|
264
|
+
@compile(verify=False)
|
265
|
+
def return_cond(a, b, c):
|
266
|
+
if c:
|
267
|
+
return a
|
268
|
+
else:
|
269
|
+
return b
|
270
|
+
|
271
|
+
a = torch.rand(3, 4)
|
272
|
+
b = torch.rand(3, 4)
|
273
|
+
with self.assertRecorded(1):
|
274
|
+
r = add(a, b)
|
275
|
+
assert r.size() == (3, 4)
|
276
|
+
r2 = add(b, a)
|
277
|
+
self.assertAliases([a, b, r2, r], [0, 1, 2, 3])
|
278
|
+
|
279
|
+
c = torch.rand(4)
|
280
|
+
d = torch.rand(4, 4)
|
281
|
+
with self.assertRecorded(1):
|
282
|
+
e = add(c, d)
|
283
|
+
assert e.size() == (4, 4)
|
284
|
+
e = add(c, torch.rand(4, 4))
|
285
|
+
assert e.size() == (4, 4)
|
286
|
+
|
287
|
+
with self.assertRecorded(1):
|
288
|
+
r = add(a, 4)
|
289
|
+
self.assertAliases([r, a], [0, 1])
|
290
|
+
|
291
|
+
with self.assertRecorded(1):
|
292
|
+
r0 = return_cond(a, b, True)
|
293
|
+
self.assertAliases([a, b, r0], [0, 1, 0])
|
294
|
+
r1 = return_cond(b, a, True)
|
295
|
+
self.assertAliases([a, b, r1], [0, 1, 1])
|
296
|
+
|
297
|
+
with self.assertRecorded(1):
|
298
|
+
r0 = return_cond(a, b, False)
|
299
|
+
self.assertAliases([a, b, r0], [0, 1, 1])
|
300
|
+
r1 = return_cond(a, b, False)
|
301
|
+
self.assertAliases([b, a, r1], [0, 1, 0])
|
302
|
+
|
303
|
+
@compile(verify=False)
|
304
|
+
def captured(b):
|
305
|
+
return a + b
|
306
|
+
|
307
|
+
with self.assertRecorded(1):
|
308
|
+
r = captured(b)
|
309
|
+
self.assertAliases([a, b, r], [0, 1, 2])
|
310
|
+
r = captured(torch.rand(3, 4))
|
311
|
+
assert r.size() == (3, 4)
|
312
|
+
|
313
|
+
with self.assertRecorded(1):
|
314
|
+
# input aliased with capture
|
315
|
+
captured(a)
|
316
|
+
captured(a)
|
317
|
+
|
318
|
+
@compile(verify=False)
|
319
|
+
def weird(f, g):
|
320
|
+
o = f + g
|
321
|
+
return o, o[0], f[0], g[0], a[0]
|
322
|
+
|
323
|
+
with self.assertRecorded(1):
|
324
|
+
r0, r1, r2, r3, r4 = weird(c, d)
|
325
|
+
self.assertAliases(
|
326
|
+
[c, d, a, r0, r1, r2, r3, r4], [0, 1, 2, 3, 3, 0, 1, 2]
|
327
|
+
)
|
328
|
+
|
329
|
+
def test_compile_input_permissions(self, backend_type):
|
330
|
+
with self.local_device_mesh(1, 1, backend_type):
|
331
|
+
a = torch.rand(3, 4)
|
332
|
+
|
333
|
+
@compile(verify=False)
|
334
|
+
def add(b):
|
335
|
+
return a + b
|
336
|
+
|
337
|
+
with self.assertRecorded(1):
|
338
|
+
c = add(torch.rand(3, 4))
|
339
|
+
|
340
|
+
other = Stream("other")
|
341
|
+
ab, borrow = other.borrow(a, mutable=True)
|
342
|
+
|
343
|
+
with borrow:
|
344
|
+
with pytest.raises(TypeError, match="BORROWED"):
|
345
|
+
add(torch.rand(3, 4))
|
346
|
+
|
347
|
+
# test we can read it again
|
348
|
+
add(torch.rand(3, 4))
|
349
|
+
|
350
|
+
ab, borrow = other.borrow(a)
|
351
|
+
with borrow:
|
352
|
+
add(torch.rand(3, 4))
|
353
|
+
|
354
|
+
with self.assertRecorded(0):
|
355
|
+
with other.activate():
|
356
|
+
c = torch.rand(3, 4)
|
357
|
+
c, borrow = monarch.get_active_stream().borrow(c)
|
358
|
+
with borrow:
|
359
|
+
add(c)
|
360
|
+
|
361
|
+
a.drop()
|
362
|
+
|
363
|
+
with pytest.raises(TypeError, match="DROPPED"):
|
364
|
+
add(torch.rand(3, 4))
|
365
|
+
|
366
|
+
def test_compile_verify(self, backend_type):
|
367
|
+
with self.local_device_mesh(1, 1, backend_type):
|
368
|
+
a = torch.rand(3, 4)
|
369
|
+
|
370
|
+
@compile(verify=True)
|
371
|
+
def add(b):
|
372
|
+
return a + b
|
373
|
+
|
374
|
+
c = False
|
375
|
+
|
376
|
+
@compile(verify=True)
|
377
|
+
def add_broken(b):
|
378
|
+
nonlocal c
|
379
|
+
if c:
|
380
|
+
a = torch.zeros(3, 4)
|
381
|
+
else:
|
382
|
+
a = torch.rand(3, 4)
|
383
|
+
return a.add(b)
|
384
|
+
|
385
|
+
with self.assertRecorded(2):
|
386
|
+
add(torch.rand(3, 4))
|
387
|
+
add(torch.rand(3, 4))
|
388
|
+
add(torch.rand(3, 4))
|
389
|
+
|
390
|
+
add_broken(torch.rand(3, 4))
|
391
|
+
with pytest.raises(RuntimeError, match="diverges"):
|
392
|
+
c = True
|
393
|
+
add_broken(torch.rand(3, 4))
|
394
|
+
|
395
|
+
def test_dropped(self, backend_type):
|
396
|
+
with self.local_device_mesh(1, 1, backend_type):
|
397
|
+
a = torch.rand(3, 4)
|
398
|
+
b = None
|
399
|
+
|
400
|
+
@compile(verify=False)
|
401
|
+
def foo():
|
402
|
+
nonlocal b
|
403
|
+
b = a + a
|
404
|
+
|
405
|
+
foo()
|
406
|
+
with pytest.raises(TypeError, match="DROPPED"):
|
407
|
+
b.add(4)
|
408
|
+
|
409
|
+
def test_across_mesh(self, backend_type):
    """Compile one function whose body activates two different sub-meshes.

    ``foo`` computes ``r0`` while ``m(host=0)`` is active and ``r1`` while
    ``m(host=1)`` is active. Each result is then inspected with the sub-mesh
    it was produced under active, so both outputs are actually fetched.
    """
    with self.local_device_mesh(2, 1, backend_type) as m:
        m0 = m(host=0)
        m1 = m(host=1)

        @compile
        def foo(a, b):
            with m0.activate():
                r0 = a + a
            with m1.activate():
                r1 = b + b
            return r0, r1

        # Create each input under the sub-mesh that will consume it.
        with m0.activate():
            a = torch.rand(3, 4)
        with m1.activate():
            b = torch.rand(3, 4)

        r0, r1 = foo(a, b)
        with m0.activate():
            monarch.inspect(r0)
        with m1.activate():
            # r1 is the value computed under m1; inspecting r0 here again
            # would leave the second output unchecked.
            monarch.inspect(r1)
|
432
|
+
|
433
|
+
def test_grad_not_supported(self, backend_type):
|
434
|
+
with self.local_device_mesh(1, 1, backend_type):
|
435
|
+
|
436
|
+
@compile
|
437
|
+
def foo(x):
|
438
|
+
return x
|
439
|
+
|
440
|
+
y = torch.rand(3, requires_grad=True)
|
441
|
+
|
442
|
+
@compile
|
443
|
+
def returnit():
|
444
|
+
return y
|
445
|
+
|
446
|
+
with pytest.raises(TypeError, match="REQUIRES_GRAD"):
|
447
|
+
foo(torch.rand(3, requires_grad=True))
|
448
|
+
|
449
|
+
with pytest.raises(TypeError, match="REQUIRES_GRAD"):
|
450
|
+
returnit()
|
451
|
+
|
452
|
+
def test_mutate_inputs(self, backend_type):
|
453
|
+
with self.local_device_mesh(1, 1, backend_type) as mesh:
|
454
|
+
|
455
|
+
@compile(verify=False)
|
456
|
+
def foo(x_not_mutated, w_not_mutated, y, y_alias, z, z_alias):
|
457
|
+
u = (
|
458
|
+
x_not_mutated.mul(2.0)
|
459
|
+
+ w_not_mutated
|
460
|
+
+ z_alias.unsqueeze(0).repeat(3, 1)
|
461
|
+
)
|
462
|
+
v = y.add(5.0)
|
463
|
+
stream = monarch.Stream("borrow")
|
464
|
+
borrowed_y_alias, y_alias_borrow = stream.borrow(y_alias, mutable=True)
|
465
|
+
with stream.activate():
|
466
|
+
borrowed_y_alias.add_(1.0)
|
467
|
+
y_alias_borrow.drop()
|
468
|
+
z.add_(1.0)
|
469
|
+
return u, v
|
470
|
+
|
471
|
+
x_not_mutated = torch.rand(3, 3)
|
472
|
+
w_not_mutated = torch.rand(3, 3)
|
473
|
+
y = torch.rand(3, 3)
|
474
|
+
y_alias = y.reshape(-1)
|
475
|
+
z = torch.rand(3, 3)
|
476
|
+
z_alias = z[0, :]
|
477
|
+
|
478
|
+
mutated_inputs = (y, y_alias, z, z_alias)
|
479
|
+
mutated_aliases = set().union(*[t._aliases.aliases for t in mutated_inputs])
|
480
|
+
all_inputs = (x_not_mutated, w_not_mutated) + mutated_inputs
|
481
|
+
with patch.object(
|
482
|
+
mesh.client,
|
483
|
+
"new_node_nocoalesce",
|
484
|
+
side_effect=mesh.client.new_node_nocoalesce,
|
485
|
+
) as new_node:
|
486
|
+
for _ in range(2):
|
487
|
+
u, v = foo(*all_inputs)
|
488
|
+
(mutated, used, _, _), _ = new_node.call_args
|
489
|
+
assert mutated_aliases.union(
|
490
|
+
u._aliases.aliases, v._aliases.aliases
|
491
|
+
) == set(mutated)
|
492
|
+
assert set(all_inputs) == set(used)
|