torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,736 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import operator
|
9
|
+
import re
|
10
|
+
import threading
|
11
|
+
import time
|
12
|
+
from types import ModuleType
|
13
|
+
from unittest.mock import AsyncMock, patch
|
14
|
+
|
15
|
+
import monarch
|
16
|
+
|
17
|
+
import pytest
|
18
|
+
|
19
|
+
import torch
|
20
|
+
|
21
|
+
from monarch.actor_mesh import (
|
22
|
+
Accumulator,
|
23
|
+
Actor,
|
24
|
+
current_actor_name,
|
25
|
+
current_rank,
|
26
|
+
current_size,
|
27
|
+
endpoint,
|
28
|
+
MonarchContext,
|
29
|
+
)
|
30
|
+
from monarch.debugger import init_debugging
|
31
|
+
from monarch.future import ActorFuture
|
32
|
+
|
33
|
+
from monarch.proc_mesh import local_proc_mesh, proc_mesh
|
34
|
+
from monarch.rdma import RDMABuffer
|
35
|
+
|
36
|
+
# pytest marker: skip the decorated test when torch cannot see a CUDA device.
_CUDA_AVAILABLE = torch.cuda.is_available()
needs_cuda = pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA not available")
|
40
|
+
|
41
|
+
|
42
|
+
class Counter(Actor):
    """Actor holding a single integer that can be incremented and read back."""

    def __init__(self, v: int):
        # Starting value of the counter.
        self.v = v

    @endpoint
    async def incr(self):
        """Add one to the stored value."""
        self.v = self.v + 1

    @endpoint
    async def value(self) -> int:
        """Return the current stored value."""
        current = self.v
        return current
|
53
|
+
|
54
|
+
|
55
|
+
class Indirect(Actor):
    """Actor that reads a Counter's value on behalf of its caller."""

    @endpoint
    async def call_value(self, c: Counter) -> int:
        """Ask one chosen member of ``c`` for its value and relay the answer."""
        relayed = await c.value.choose()
        return relayed
|
59
|
+
|
60
|
+
|
61
|
+
class ParameterServer(Actor):
    """Actor owning a 10x10 parameter tensor and a 10x10 gradient buffer.

    The gradient buffer's raw bytes are exposed to clients through an
    RDMABuffer (see `grad_handle`).
    """

    def __init__(self):
        self.params = torch.rand(10, 10)
        self.grad_buffer = torch.rand(10, 10)

    @endpoint
    async def grad_handle(self) -> RDMABuffer:
        """Return a new RDMABuffer over the gradient buffer's flat uint8 bytes."""
        raw_bytes = self.grad_buffer.view(torch.uint8).flatten()
        return RDMABuffer(raw_bytes)

    @endpoint
    async def update(self):
        """Apply ``params += 0.01 * grad_buffer`` in place; grad_buffer is unchanged."""
        self.params += 0.01 * self.grad_buffer

    @endpoint
    async def get_grad_buffer(self) -> torch.Tensor:
        """Return the gradient buffer (accessor used only by tests)."""
        return self.grad_buffer
|
79
|
+
|
80
|
+
|
81
|
+
async def test_choose():
    """`choose` called directly and `choose` relayed through another actor agree."""
    mesh = await local_proc_mesh(gpus=2)
    counter = await mesh.spawn("counter", Counter, 3)
    relay = await mesh.spawn("indirect", Indirect)

    counter.incr.broadcast()
    direct = await counter.value.choose()
    relayed = await relay.call_value.choose(counter)

    assert direct == relayed
|
90
|
+
|
91
|
+
|
92
|
+
async def test_stream():
    """Streaming `value` yields one reply per actor; the replies sum to 8."""
    mesh = await local_proc_mesh(gpus=2)
    counter = await mesh.spawn("counter2", Counter, 3)
    counter.incr.broadcast()

    total = 0
    async for reply in counter.value.stream():
        total += reply
    assert 8 == total
|
98
|
+
|
99
|
+
|
100
|
+
class ParameterClient(Actor):
    """Actor that moves bytes between a local buffer and a ParameterServer.

    Args:
        server: handle to a ParameterServer actor mesh.
        buffer: local tensor; a flat uint8 view of it is kept as the
            destination for `download`.
    """

    def __init__(self, server, buffer):
        self.server = server
        self.buffer = buffer.view(torch.uint8).flatten()

    @endpoint
    async def upload(self, tensor):
        """Write ``tensor`` into the server's gradient buffer via its RDMABuffer."""
        handle = await self.server.grad_handle.call_one()
        await handle.write(tensor)

    @endpoint
    async def download(self):
        """Read the server's gradient buffer into the local byte buffer."""
        handle = await self.server.grad_handle.call_one()
        await handle.read_into(self.buffer)

    @endpoint
    async def get_buffer(self):
        """Return the local flat uint8 byte buffer."""
        return self.buffer
|
119
|
+
|
120
|
+
|
121
|
+
@needs_cuda
async def test_proc_mesh_rdma():
    """Round-trip a ParameterServer's gradient buffer over RDMA to CPU and GPU clients."""

    def as_matrix(raw):
        # Reinterpret a flat uint8 byte tensor as the 10x10 float32 it encodes.
        return raw.view(torch.float32).view(10, 10)

    mesh = await proc_mesh(gpus=1)
    server = await mesh.spawn("server", ParameterServer)

    # CPU client: its local buffer starts as all ones.
    cpu_client = await mesh.spawn(
        "client_cpu", ParameterClient, server, torch.ones(10, 10)
    )
    local = await cpu_client.get_buffer.call_one()
    assert torch.sum(as_matrix(local)) == 100
    cpu_zeros = torch.zeros(10, 10)
    await cpu_client.upload.call_one(cpu_zeros.view(torch.uint8).flatten())
    await cpu_client.download.call_one()
    local = await cpu_client.get_buffer.call_one()
    assert torch.sum(as_matrix(local)) == 0

    # `update` changes the server's params (reading the grad buffer); a fresh
    # download must still equal the server's grad buffer.
    await server.update.call_one()
    await cpu_client.download.call_one()
    local = await cpu_client.get_buffer.call_one()
    server_grad = await server.get_grad_buffer.call_one()
    assert torch.allclose(as_matrix(local), server_grad)

    # GPU client: same flow with a CUDA-resident local buffer.
    gpu_client = await mesh.spawn(
        "client_gpu", ParameterClient, server, torch.ones(10, 10, device="cuda")
    )
    local = await gpu_client.get_buffer.call_one()
    assert torch.sum(as_matrix(local)) == 100
    gpu_zeros = torch.zeros(10, 10, device="cuda")
    await gpu_client.upload.call_one(gpu_zeros.view(torch.uint8).flatten())
    await gpu_client.download.call_one()
    local = await gpu_client.get_buffer.call_one()
    local_matrix = as_matrix(local)
    assert torch.sum(local_matrix) == 0
    assert local_matrix.device.type == "cuda"

    await server.update.call_one()
    await gpu_client.download.call_one()
    local = await gpu_client.get_buffer.call_one()
    server_grad = await server.get_grad_buffer.call_one()
    assert torch.allclose(as_matrix(local).cpu(), server_grad)
|
170
|
+
|
171
|
+
|
172
|
+
class To(Actor):
    """Actor that reports its own actor name."""

    @endpoint
    async def whoami(self):
        """Return the name of the actor handling this call."""
        name = current_actor_name()
        return name
|
176
|
+
|
177
|
+
|
178
|
+
class From(Actor):
    """Actor that collects the names reported by every actor in a `To` mesh."""

    @endpoint
    async def get(self, to: To):
        """Stream `whoami` from every actor in ``to`` and return the names as a list."""
        names = []
        async for name in to.whoami.stream():
            names.append(name)
        return names
|
182
|
+
|
183
|
+
|
184
|
+
async def test_mesh_passed_to_mesh():
    """An actor mesh handle can be passed to another mesh on the same proc mesh."""
    proc = await local_proc_mesh(gpus=2)
    f = await proc.spawn("from", From)
    t = await proc.spawn("to", To)
    # Flatten the per-`From`-actor lists of `To` names into one list.
    # (Named `names`, not `all`, to avoid shadowing the builtin.)
    names = [name async for per_sender in f.get.stream(t) for name in per_sender]
    assert len(names) == 4
    assert names[0] != names[1]
|
191
|
+
|
192
|
+
|
193
|
+
async def test_mesh_passed_to_mesh_on_different_proc_mesh():
    """An actor mesh handle can be passed to a mesh on a different proc mesh."""
    proc = await local_proc_mesh(gpus=2)
    proc2 = await local_proc_mesh(gpus=2)
    f = await proc.spawn("from", From)
    t = await proc2.spawn("to", To)
    # Flatten the per-`From`-actor lists of `To` names into one list.
    # (Named `names`, not `all`, to avoid shadowing the builtin.)
    names = [name async for per_sender in f.get.stream(t) for name in per_sender]
    assert len(names) == 4
    assert names[0] != names[1]
|
201
|
+
|
202
|
+
|
203
|
+
async def test_actor_slicing():
    """Slicing an actor mesh addresses only the selected ranks."""
    from_mesh = await local_proc_mesh(gpus=2)
    to_mesh = await local_proc_mesh(gpus=2)

    sender = await from_mesh.spawn("from", From)
    receiver = await to_mesh.spawn("to", To)

    first = await receiver.slice(gpus=0).whoami.call()
    second = await receiver.slice(gpus=1).whoami.call()
    assert first != second

    gathered = []
    async for per_sender in sender.get.stream(receiver.slice(gpus=0)):
        gathered.extend(per_sender)
    assert len(gathered) == 2

    assert gathered[0] == gathered[1]
|
216
|
+
|
217
|
+
|
218
|
+
async def test_aggregate():
    """Accumulator folds `value` replies from every actor with operator.add."""
    mesh = await local_proc_mesh(gpus=2)
    counter = await mesh.spawn("counter", Counter, 1)
    counter.incr.broadcast()
    summed = await Accumulator(counter.value, 0, operator.add).accumulate()
    assert summed == 4
|
225
|
+
|
226
|
+
|
227
|
+
class RunIt(Actor):
    """Actor that executes a callable it is sent and returns the result."""

    @endpoint
    async def run(self, fn):
        """Call ``fn()`` inside the actor and return whatever it returns."""
        result = fn()
        return result
|
231
|
+
|
232
|
+
|
233
|
+
async def test_rank_size():
    """current_rank/current_size evaluated inside actors reflect the mesh layout."""
    mesh = await local_proc_mesh(gpus=2)
    runner = await mesh.spawn("runit", RunIt)

    summing = Accumulator(runner.run, 0, operator.add)

    rank_total = await summing.accumulate(lambda: current_rank()["gpus"])
    assert 1 == rank_total  # ranks 0 + 1
    size_total = await summing.accumulate(lambda: current_size()["gpus"])
    assert 4 == size_total  # size 2 reported by each of the 2 actors
|
241
|
+
|
242
|
+
|
243
|
+
class TrainerActor(Actor):
    """Trainer side of the RDMA weight-sync test.

    Holds a CUDA ``Linear(10, 10)`` whose weight starts at zero and is bumped
    by 1.0 on every `weights_ready` call. The weight bytes are shared with
    the paired generator through an RDMABuffer.
    """

    def __init__(self):
        super().__init__()
        self.trainer = torch.nn.Linear(10, 10).to("cuda")
        self.trainer.weight.data.zero_()

    @endpoint
    async def init(self, gen):
        """Keep the slice of the generator mesh at this actor's own rank coordinates."""
        self.gen = gen.slice(**current_rank())

    @endpoint
    async def exchange_metadata(self):
        """Wrap the weight bytes in an RDMABuffer and hand it to the generator."""
        weight_bytes = self.trainer.weight.data.view(torch.uint8).flatten()
        self.handle = RDMABuffer(weight_bytes)
        await self.gen.attach_weight_buffer.call(self.handle)

    @endpoint
    async def weights_ready(self):
        """Stand-in for a training step: add 1.0 to every weight element in place."""
        self.trainer.weight.data.add_(1.0)
|
263
|
+
|
264
|
+
|
265
|
+
class GeneratorActor(Actor):
    """Generator side of the RDMA weight-sync test.

    Reads the trainer's weights over RDMA into its own CUDA ``Linear(10, 10)``
    and checks that, on the N-th read, the weights sum to ``N * 100``.
    """

    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Linear(10, 10).to("cuda")
        self.step = 0

    @endpoint
    async def init(self, trainer):
        """Keep the slice of the trainer mesh at this actor's own rank coordinates."""
        self.trainer = trainer.slice(**current_rank())

    @endpoint
    async def attach_weight_buffer(self, handle):
        """Store the RDMABuffer that exposes the trainer's weight bytes."""
        self.handle = handle

    @endpoint
    async def update_weights(self):
        """Pull the trainer's weights and verify their sum matches the step count."""
        self.step += 1
        weight_bytes = self.generator.weight.data.view(torch.uint8).flatten()
        await self.handle.read_into(weight_bytes)
        total = torch.sum(self.generator.weight.data)
        assert (
            total == self.step * 100
        ), f"{torch.sum(self.generator.weight.data)=}, {self.step=}"
|
288
|
+
|
289
|
+
|
290
|
+
@needs_cuda
async def test_gpu_trainer_generator():
    """Generator pulls the trainer's weight updates over RDMA (async client API)."""
    trainer_mesh = await proc_mesh(gpus=1)
    generator_mesh = await proc_mesh(gpus=1)
    trainer = await trainer_mesh.spawn("trainer", TrainerActor)
    generator = await generator_mesh.spawn("gen", GeneratorActor)

    # Point each side at the other, then share the trainer's weight buffer.
    await generator.init.call(trainer)
    await trainer.init.call(generator)
    await trainer.exchange_metadata.call()

    for _step in range(3):
        await trainer.weights_ready.call()
        await generator.update_weights.call()
|
304
|
+
|
305
|
+
|
306
|
+
class SyncActor(Actor):
    """Actor with a synchronous (non-async) endpoint that blocks on another actor."""

    @endpoint
    def sync_endpoint(self, a_counter: Counter):
        """Block on one chosen Counter's `value` and return it."""
        pending = a_counter.value.choose()
        return pending.get()
|
310
|
+
|
311
|
+
|
312
|
+
async def test_sync_actor():
    """A synchronous endpoint can be awaited from async client code."""
    mesh = await local_proc_mesh(gpus=2)
    sync_actor = await mesh.spawn("actor", SyncActor)
    counter = await mesh.spawn("counter", Counter, 5)
    answer = await sync_actor.sync_endpoint.choose(counter)
    assert answer == 5
|
318
|
+
|
319
|
+
|
320
|
+
@needs_cuda
def test_gpu_trainer_generator_sync() -> None:
    """RDMA weight sync between trainer and generator, driven with blocking `.get()`."""
    trainer_mesh = proc_mesh(gpus=1).get()
    generator_mesh = proc_mesh(gpus=1).get()
    trainer = trainer_mesh.spawn("trainer", TrainerActor).get()
    generator = generator_mesh.spawn("gen", GeneratorActor).get()

    # Point each side at the other, then share the trainer's weight buffer.
    generator.init.call(trainer).get()
    trainer.init.call(generator).get()
    trainer.exchange_metadata.call().get()

    for _step in range(3):
        trainer.weights_ready.call().get()
        generator.update_weights.call().get()
|
334
|
+
|
335
|
+
|
336
|
+
def test_sync_actor_sync_client():
    """A synchronous endpoint driven entirely from a synchronous client."""
    mesh = local_proc_mesh(gpus=2).get()
    sync_actor = mesh.spawn("actor", SyncActor).get()
    counter = mesh.spawn("counter", Counter, 5).get()
    answer = sync_actor.sync_endpoint.choose(counter).get()
    assert answer == 5
|
342
|
+
|
343
|
+
|
344
|
+
def test_proc_mesh_size() -> None:
    """`size(dim)` reports the extent of a named proc-mesh dimension."""
    mesh = local_proc_mesh(gpus=2).get()
    gpu_count = mesh.size("gpus")
    assert 2 == gpu_count
|
347
|
+
|
348
|
+
|
349
|
+
def test_rank_size_sync() -> None:
    """current_rank/current_size inside actors, collected with blocking `.get()`."""
    mesh = local_proc_mesh(gpus=2).get()
    runner = mesh.spawn("runit", RunIt).get()

    summing = Accumulator(runner.run, 0, operator.add)
    rank_total = summing.accumulate(lambda: current_rank()["gpus"]).get()
    assert 1 == rank_total
    size_total = summing.accumulate(lambda: current_size()["gpus"]).get()
    assert 4 == size_total
|
356
|
+
|
357
|
+
|
358
|
+
def test_accumulate_sync() -> None:
    """Accumulator over `value`, resolved with a blocking `.get()`."""
    mesh = local_proc_mesh(gpus=2).get()
    counter = mesh.spawn("counter", Counter, 1).get()
    counter.incr.broadcast()
    summing = Accumulator(counter.value, 0, operator.add)
    total = summing.accumulate().get()
    assert total == 4
|
365
|
+
|
366
|
+
|
367
|
+
class CastToCounter(Actor):
    """Actor that calls `value` on every member of a Counter mesh."""

    @endpoint
    def doit(self, c: Counter):
        """Block on `value` from all of ``c`` and return the replies as a list."""
        replies = c.value.call().get()
        return list(replies)
|
371
|
+
|
372
|
+
|
373
|
+
def test_value_mesh() -> None:
    """The result of `call` supports `item`, `slice`, and iteration by coordinates."""
    mesh = local_proc_mesh(gpus=2).get()
    counter = mesh.spawn("counter", Counter, 0).get()
    # Only the actor at (hosts=0, gpus=1) is incremented.
    counter.slice(hosts=0, gpus=1).incr.broadcast()
    values = counter.value.call().get()
    assert 0 == values.item(hosts=0, gpus=0)
    assert 1 == values.item(hosts=0, gpus=1)
    assert 1 == values.slice(hosts=0, gpus=1).item()
    relay = mesh.spawn("ctc", CastToCounter).get()
    assert list(values) == relay.slice(gpus=0).doit.call_one(counter).get()
|
383
|
+
|
384
|
+
|
385
|
+
def test_rust_binding_modules_correct() -> None:
    """Objects exported by the Rust bindings carry consistent naming metadata.

    Every non-dunder attribute that has ``__module__`` must have ``__name__``
    equal to its attribute name and ``__module__`` equal to the dotted path
    of the (sub)module it is found in.
    """
    import monarch._rust_bindings as bindings

    def check(module, path):
        for attr, obj in module.__dict__.items():
            if attr.startswith("__"):
                continue
            if isinstance(obj, ModuleType):
                # Recurse into submodules, extending the dotted path.
                check(obj, f"{path}.{attr}")
                continue
            if not hasattr(obj, "__module__"):
                continue
            assert obj.__name__ == attr
            assert obj.__module__ == path

    check(bindings, "monarch._rust_bindings")
|
399
|
+
|
400
|
+
|
401
|
+
def test_proc_mesh_liveness() -> None:
    """Actors stay reachable after the client drops its proc mesh reference."""
    mesh = proc_mesh(gpus=2).get()
    counter = mesh.spawn("counter", Counter, 1).get()
    del mesh
    # Leave time for a (buggy) shutdown triggered by dropping the mesh to
    # take effect before we talk to the actor again.
    time.sleep(0.5)
    counter.value.call().get()
|
409
|
+
|
410
|
+
|
411
|
+
def _debugee_actor_internal(rank):
    """Stop at a rank-dependent breakpoint, then return or raise.

    Ranks 0, 1 and 3 return ``rank + 1``, ``rank + 2`` and ``rank + 4``
    respectively; rank 2 raises ``ValueError("bad rank")``. Any other rank
    matches no branch and returns None.

    NOTE: test_debug relies on each ``breakpoint()`` sitting exactly four
    lines after the previous one, and on ``return rank`` sitting two lines
    below its ``breakpoint()``. Do not add lines inside the branches.
    """
    if rank == 0:
        breakpoint()  # noqa
        rank += 1
        return rank
    elif rank == 1:
        breakpoint()  # noqa
        rank += 2
        return rank
    elif rank == 2:
        breakpoint()  # noqa
        rank += 3
        raise ValueError("bad rank")
    elif rank == 3:
        breakpoint()  # noqa
        rank += 4
        return rank
|
428
|
+
|
429
|
+
|
430
|
+
class DebugeeActor(Actor):
    """Actor whose endpoint drops into a rank-dependent breakpoint."""

    @endpoint
    async def to_debug(self):
        """Run _debugee_actor_internal with this actor's ``point.rank``.

        NOTE: test_debug regex-matches the source line
        ``return _debugee_actor_internal(rank)`` in pdb output; keep it verbatim.
        """
        rank = MonarchContext.get().point.rank
        return _debugee_actor_internal(rank)
|
435
|
+
|
436
|
+
|
437
|
+
async def test_debug() -> None:
    """Drive the actor debugger through a scripted session across four debugees.

    Debugger input is fed from a fixed script and all debugger output is
    captured. The test checks breakpoint listings, output replay on
    re-attach, `cast` to several ranks, and the error finally raised by rank 2.
    """
    scripted_input = AsyncMock()
    scripted_input.side_effect = [
        "attach 1",
        "n",
        "n",
        "n",
        "n",
        "detach",
        "attach 1",
        "detach",
        "quit",
        "cast 0,3 n",
        "cast 0,3 n",
        # Attaching to 0 and 3 ensures that when we call "list"
        # the next time, their function/lineno info will be
        # up-to-date.
        "attach 0",
        "detach",
        "attach 3",
        "detach",
        "quit",
        "attach 2",
        "c",
        "quit",
        "continue",
    ]

    captured = []

    def _record_output(msg):
        captured.append(msg)

    internal_fn = "test_python_actors._debugee_actor_internal"
    endpoint_fn = "test_python_actors.to_debug"

    with patch(
        "monarch.debugger._debugger_input", side_effect=scripted_input
    ), patch("monarch.debugger._debugger_output", new=_record_output):
        mesh = await proc_mesh(hosts=2, gpus=2)
        debugee = await mesh.spawn("debugee", DebugeeActor)
        debug_client = await init_debugging(debugee)

        pending = debugee.to_debug.call()
        await debug_client.wait_pending_session.call_one()

        # Poll (10 attempts, 1s apart) until all four ranks hit a breakpoint.
        breakpoints = []
        for _attempt in range(10):
            breakpoints = await debug_client.list.call_one()
            if len(breakpoints) == 4:
                break
            await asyncio.sleep(1)
        else:
            raise RuntimeError("timed out waiting for breakpoints")

        initial_linenos = {}
        for idx, (rank, coords, _, _, function, lineno) in enumerate(breakpoints):
            initial_linenos[rank] = lineno
            assert rank == idx
            assert coords == {"hosts": rank % 2, "gpus": rank // 2}
            assert function == internal_fn
            assert lineno == breakpoints[0][5] + 4 * rank

        await debug_client.enter.call_one()

        # Detaching from and re-attaching to a session replays the tail of its
        # output, so the last chunk must appear twice in a row.
        expected_last_output = [
            r"--Return--",
            r"\n",
            r"> (/.*/)+test_python_actors.py\(\d+\)to_debug\(\)->3\n-> return _debugee_actor_internal\(rank\)",
            r"\n",
            r"\(Pdb\) ",
        ]
        tail_len = len(expected_last_output)
        assert captured[-2 * tail_len : -tail_len] == captured[-tail_len:]
        for actual, pattern in zip(captured[-tail_len:], expected_last_output):
            assert re.match(pattern, actual) is not None

        # Rank 1 has stepped out into to_debug; the other ranks have not moved.
        breakpoints = await debug_client.list.call_one()
        for idx, entry in enumerate(breakpoints):
            if idx == 1:
                assert entry[4] == endpoint_fn
            else:
                assert entry[4] == internal_fn
                assert entry[5] == initial_linenos[idx]

        await debug_client.enter.call_one()

        # Ranks 0 and 3 each advanced two lines via the two `cast 0,3 n` commands.
        breakpoints = await debug_client.list.call_one()
        for idx, entry in enumerate(breakpoints):
            if idx == 1:
                assert entry[4] == endpoint_fn
            elif idx in (0, 3):
                assert entry[4] == internal_fn
                assert entry[5] == initial_linenos[idx] + 2
            else:
                assert entry[4] == internal_fn
                assert entry[5] == initial_linenos[idx]

        await debug_client.enter.call_one()

        # Rank 2 was continued (and raises), leaving ranks 0, 1 and 3.
        breakpoints = await debug_client.list.call_one()
        assert len(breakpoints) == 3
        for idx, rank in enumerate((0, 1, 3)):
            assert breakpoints[idx][0] == rank

        await debug_client.enter.call_one()
        breakpoints = await debug_client.list.call_one()
        assert len(breakpoints) == 0

        with pytest.raises(monarch.actor_mesh.ActorError, match="ValueError: bad rank"):
            await pending
|
549
|
+
|
550
|
+
|
551
|
+
class TLSActor(Actor):
    """An actor that manages thread-local state."""

    def __init__(self):
        # threading.local: `value` exists only on the thread that sets it,
        # so the endpoints below only work if they run on this same thread.
        self.local = threading.local()
        self.local.value = 0

    @endpoint
    def increment(self):
        """Bump the thread-local counter from a sync endpoint."""
        self.local.value = self.local.value + 1

    @endpoint
    async def increment_async(self):
        """Bump the thread-local counter from an async endpoint."""
        self.local.value = self.local.value + 1

    @endpoint
    def get(self):
        """Return the thread-local counter from a sync endpoint."""
        return self.local.value

    @endpoint
    async def get_async(self):
        """Return the thread-local counter from an async endpoint."""
        return self.local.value
|
573
|
+
|
574
|
+
|
575
|
+
async def test_actor_tls() -> None:
    """Test that thread-local state is respected."""
    mesh = await proc_mesh(gpus=1)
    tls = await mesh.spawn("tls", TLSActor)
    # Alternate sync and async increments; all must see the same state.
    for _ in range(2):
        await tls.increment.call_one()
        await tls.increment_async.call_one()

    assert 4 == await tls.get.call_one()
    assert 4 == await tls.get_async.call_one()
|
586
|
+
|
587
|
+
|
588
|
+
class TLSActorFullSync(Actor):
    """An actor that manages thread-local state."""

    def __init__(self):
        # threading.local: `value` exists only on the thread that sets it.
        self.local = threading.local()
        self.local.value = 0

    @endpoint
    def increment(self):
        """Bump the thread-local counter (sync endpoint)."""
        self.local.value = self.local.value + 1

    @endpoint
    def get(self):
        """Return the thread-local counter (sync endpoint)."""
        return self.local.value
|
602
|
+
|
603
|
+
|
604
|
+
async def test_actor_tls_full_sync() -> None:
    """Test that thread-local state is respected."""
    mesh = await proc_mesh(gpus=1)
    tls = await mesh.spawn("tls", TLSActorFullSync)
    for _ in range(4):
        await tls.increment.call_one()

    assert 4 == await tls.get.call_one()
|
614
|
+
|
615
|
+
|
616
|
+
class AsyncActor(Actor):
    """Actor whose `sleep` endpoint loops until `no_more` is called.

    Used to check concurrent handling of async endpoints: `sleep` can only
    return if `no_more` is processed while `sleep` is still awaiting.
    """

    def __init__(self):
        self.should_exit = False

    @endpoint
    async def sleep(self) -> None:
        """Poll once per second until `should_exit` is set."""
        while not self.should_exit:
            await asyncio.sleep(1)

    @endpoint
    async def no_more(self) -> None:
        """Tell a running `sleep` call to return."""
        self.should_exit = True
|
628
|
+
|
629
|
+
|
630
|
+
@pytest.mark.timeout(15)
async def test_async_concurrency():
    """Test that async endpoints will be processed concurrently."""
    mesh = await proc_mesh(gpus=1)
    actor = await mesh.spawn("async", AsyncActor)
    sleeping = actor.sleep.call()
    # `no_more` only gets handled while `sleep` is still looping if messages
    # are processed concurrently; otherwise this hangs until the timeout.
    await actor.no_more.call()
    await sleeping
|
640
|
+
|
641
|
+
|
642
|
+
async def awaitit(f):
    """Await ``f`` and return its result, so `asyncio.run` can drive any awaitable."""
    result = await f
    return result
|
644
|
+
|
645
|
+
|
646
|
+
def test_actor_future():
    """Exercise ActorFuture: result/exception caching, sync fast path, and timeouts.

    The shared counter ``v`` records how many times an implementation
    actually ran, showing that each future executes at most once.
    """
    v = 0

    async def incr():
        nonlocal v
        v += 1
        return v

    # With only an async implementation, the future is still usable from sync code.
    fut = ActorFuture(incr)
    assert fut.get() == 1
    assert v == 1
    assert fut.get() == 1
    assert asyncio.run(awaitit(fut)) == 1

    fut = ActorFuture(incr)
    assert asyncio.run(awaitit(fut)) == 2
    assert fut.get() == 2

    def incr2():
        nonlocal v
        v += 2
        return v

    # When a synchronous implementation is also supplied, `get` uses it.
    fut = ActorFuture(incr, incr2)
    assert fut.get() == 4
    assert asyncio.run(awaitit(fut)) == 4

    async def fail_async():
        nonlocal v
        v += 1
        raise ValueError("nope")

    # A failing future re-raises on every access but runs only once.
    fut = ActorFuture(fail_async)

    with pytest.raises(ValueError):
        fut.get()

    assert v == 5

    with pytest.raises(ValueError):
        fut.get()

    assert v == 5

    with pytest.raises(ValueError):
        asyncio.run(awaitit(fut))

    assert v == 5

    def fail_sync():
        nonlocal v
        v += 1
        raise ValueError("nope")

    fut = ActorFuture(incr, fail_sync)

    with pytest.raises(ValueError):
        fut.get()

    assert v == 6

    with pytest.raises(ValueError):
        fut.result()

    assert fut.exception() is not None

    assert v == 6

    with pytest.raises(ValueError):
        asyncio.run(awaitit(fut))

    assert v == 6

    async def seven():
        return 7

    fut = ActorFuture(seven)

    assert 7 == fut.get(timeout=0.001)

    async def neverfinish():
        never = asyncio.Future()
        await never

    fut = ActorFuture(neverfinish)

    with pytest.raises(asyncio.exceptions.TimeoutError):
        fut.get(timeout=0.1)
|