torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,736 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import asyncio
8
+ import operator
9
+ import re
10
+ import threading
11
+ import time
12
+ from types import ModuleType
13
+ from unittest.mock import AsyncMock, patch
14
+
15
+ import monarch
16
+
17
+ import pytest
18
+
19
+ import torch
20
+
21
+ from monarch.actor_mesh import (
22
+ Accumulator,
23
+ Actor,
24
+ current_actor_name,
25
+ current_rank,
26
+ current_size,
27
+ endpoint,
28
+ MonarchContext,
29
+ )
30
+ from monarch.debugger import init_debugging
31
+ from monarch.future import ActorFuture
32
+
33
+ from monarch.proc_mesh import local_proc_mesh, proc_mesh
34
+ from monarch.rdma import RDMABuffer
35
+
36
# Skip marker for tests that require a CUDA device; evaluated once at import.
_cuda_unavailable = not torch.cuda.is_available()
needs_cuda = pytest.mark.skipif(_cuda_unavailable, reason="CUDA not available")
40
+
41
+
42
class Counter(Actor):
    """Actor holding a single integer that can be incremented and read back."""

    def __init__(self, v: int):
        # Initial counter value for this actor.
        self.v = v

    @endpoint
    async def incr(self):
        """Add one to the stored value."""
        self.v = self.v + 1

    @endpoint
    async def value(self) -> int:
        """Return the current stored value."""
        current = self.v
        return current
53
+
54
+
55
class Indirect(Actor):
    """Actor that reads a Counter's value on behalf of its caller."""

    @endpoint
    async def call_value(self, c: Counter) -> int:
        """Ask one actor chosen from the Counter mesh ``c`` for its value."""
        pending = c.value.choose()
        return await pending
59
+
60
+
61
class ParameterServer(Actor):
    """Holds a 10x10 parameter tensor and a 10x10 gradient buffer shared via RDMA."""

    def __init__(self):
        self.params = torch.rand(10, 10)
        self.grad_buffer = torch.rand(10, 10)

    @endpoint
    async def grad_handle(self) -> RDMABuffer:
        """Return an RDMABuffer over a flat uint8 view of ``grad_buffer``."""
        return RDMABuffer(self.grad_buffer.view(torch.uint8).flatten())

    @endpoint
    async def update(self):
        """Add ``0.01 * grad_buffer`` to ``params`` in place (grad_buffer is unchanged)."""
        self.params += 0.01 * self.grad_buffer

    @endpoint
    async def get_grad_buffer(self) -> torch.Tensor:
        """Return the gradient buffer; only used by tests to inspect server state."""
        return self.grad_buffer
79
+
80
+
81
async def test_choose():
    """A direct choose() and one relayed through another actor return the same value."""
    proc = await local_proc_mesh(gpus=2)
    counter = await proc.spawn("counter", Counter, 3)
    relay = await proc.spawn("indirect", Indirect)
    counter.incr.broadcast()

    direct = await counter.value.choose()
    relayed = await relay.call_value.choose(counter)
    assert direct == relayed
90
+
91
+
92
async def test_stream():
    """stream() yields one value per actor; with gpus=2, each at 3 + 1, they sum to 8."""
    proc = await local_proc_mesh(gpus=2)
    counter = await proc.spawn("counter2", Counter, 3)
    counter.incr.broadcast()

    streamed = [x async for x in counter.value.stream()]
    assert sum(streamed) == 8
98
+
99
+
100
class ParameterClient(Actor):
    """Client that writes to / reads from the server's gradient buffer over RDMA."""

    def __init__(self, server, buffer):
        self.server = server
        # Local storage is kept as a flat uint8 view of ``buffer``.
        self.buffer = buffer.view(torch.uint8).flatten()

    @endpoint
    async def upload(self, tensor):
        """Write the raw bytes in ``tensor`` into the server's gradient buffer."""
        handle = await self.server.grad_handle.call_one()
        await handle.write(tensor)

    @endpoint
    async def download(self):
        """Read the server's gradient buffer into this client's local byte buffer."""
        handle = await self.server.grad_handle.call_one()
        await handle.read_into(self.buffer)

    @endpoint
    async def get_buffer(self):
        """Return the local byte buffer."""
        return self.buffer
119
+
120
+
121
@needs_cuda
async def test_proc_mesh_rdma():
    """Round-trip the server's gradient buffer over RDMA for a CPU and a GPU client."""

    def as_matrix(t):
        # Reinterpret a flat byte tensor as the 10x10 float32 matrix it holds.
        return t.view(torch.float32).view(10, 10)

    proc = await proc_mesh(gpus=1)
    server = await proc.spawn("server", ParameterServer)

    # --- CPU client ---
    client_cpu = await proc.spawn(
        "client_cpu", ParameterClient, server, torch.ones(10, 10)
    )
    assert torch.sum(as_matrix(await client_cpu.get_buffer.call_one())) == 100
    await client_cpu.upload.call_one(torch.zeros(10, 10).view(torch.uint8).flatten())
    await client_cpu.download.call_one()
    assert torch.sum(as_matrix(await client_cpu.get_buffer.call_one())) == 0

    # update() changes the server's params (grad_buffer is not modified); a
    # fresh download must still match the server's current grad buffer.
    await server.update.call_one()
    await client_cpu.download.call_one()
    local_cpu = as_matrix(await client_cpu.get_buffer.call_one())
    remote_grad = await server.get_grad_buffer.call_one()
    assert torch.allclose(local_cpu, remote_grad)

    # --- GPU client ---
    client_gpu = await proc.spawn(
        "client_gpu", ParameterClient, server, torch.ones(10, 10, device="cuda")
    )
    assert torch.sum(as_matrix(await client_gpu.get_buffer.call_one())) == 100
    await client_gpu.upload.call_one(
        torch.zeros(10, 10, device="cuda").view(torch.uint8).flatten()
    )
    await client_gpu.download.call_one()
    local_gpu = as_matrix(await client_gpu.get_buffer.call_one())
    assert torch.sum(local_gpu) == 0
    assert local_gpu.device.type == "cuda"

    # Same consistency check after another server-side update().
    await server.update.call_one()
    await client_gpu.download.call_one()
    local_gpu = as_matrix(await client_gpu.get_buffer.call_one())
    remote_grad = await server.get_grad_buffer.call_one()
    assert torch.allclose(local_gpu.cpu(), remote_grad)
170
+
171
+
172
class To(Actor):
    """Actor that reports the name of the actor handling the call."""

    @endpoint
    async def whoami(self):
        """Return ``current_actor_name()`` for this actor."""
        name = current_actor_name()
        return name
176
+
177
+
178
class From(Actor):
    """Actor that collects the names of every actor in a ``To`` mesh."""

    @endpoint
    async def get(self, to: To):
        """Stream ``whoami`` from every actor in ``to`` and return the names as a list."""
        names = []
        async for name in to.whoami.stream():
            names.append(name)
        return names
182
+
183
+
184
async def test_mesh_passed_to_mesh():
    """An actor mesh can be passed as an argument to another mesh's endpoint."""
    proc = await local_proc_mesh(gpus=2)
    f = await proc.spawn("from", From)
    t = await proc.spawn("to", To)
    # Each From actor returns the list of names it streamed from ``t``; flatten
    # them. Named ``names`` so the builtin ``all`` is not shadowed.
    names = [y async for x in f.get.stream(t) for y in x]
    assert len(names) == 4
    assert names[0] != names[1]
191
+
192
+
193
async def test_mesh_passed_to_mesh_on_different_proc_mesh():
    """An actor mesh from one proc mesh can be passed to actors on another proc mesh."""
    proc = await local_proc_mesh(gpus=2)
    proc2 = await local_proc_mesh(gpus=2)
    f = await proc.spawn("from", From)
    t = await proc2.spawn("to", To)
    # Flatten the per-From-actor name lists; ``names`` avoids shadowing ``all``.
    names = [y async for x in f.get.stream(t) for y in x]
    assert len(names) == 4
    assert names[0] != names[1]
201
+
202
+
203
async def test_actor_slicing():
    """Slicing an actor mesh addresses individual actors, also when passed to another mesh."""
    proc = await local_proc_mesh(gpus=2)
    proc2 = await local_proc_mesh(gpus=2)

    f = await proc.spawn("from", From)
    t = await proc2.spawn("to", To)

    # Compare the actual names of the two single-actor slices. ``call()`` returns
    # a value mesh, and ``!=`` between two of those is only a real check if that
    # type defines value equality; ``call_one()`` returns the plain name.
    name0 = await t.slice(gpus=0).whoami.call_one()
    name1 = await t.slice(gpus=1).whoami.call_one()
    assert name0 != name1

    # Both From actors query the same sliced To actor, so they see one name.
    result = [y async for x in f.get.stream(t.slice(gpus=0)) for y in x]
    assert len(result) == 2

    assert result[0] == result[1]
216
+
217
+
218
async def test_aggregate():
    """Accumulator folds per-actor values with operator.add, starting from 0."""
    proc = await local_proc_mesh(gpus=2)
    counter = await proc.spawn("counter", Counter, 1)
    counter.incr.broadcast()
    total = await Accumulator(counter.value, 0, operator.add).accumulate()
    assert total == 4
225
+
226
+
227
class RunIt(Actor):
    """Actor that executes a zero-argument callable and returns its result."""

    @endpoint
    async def run(self, fn):
        """Invoke ``fn()`` inside this actor and return whatever it returns."""
        outcome = fn()
        return outcome
231
+
232
+
233
async def test_rank_size():
    """current_rank()/current_size() inside an endpoint reflect the calling actor."""
    proc = await local_proc_mesh(gpus=2)
    runner = await proc.spawn("runit", RunIt)

    acc = Accumulator(runner.run, 0, operator.add)

    # Summed over the gpus dimension: ranks total 1, sizes total 4.
    rank_total = await acc.accumulate(lambda: current_rank()["gpus"])
    assert rank_total == 1
    size_total = await acc.accumulate(lambda: current_size()["gpus"])
    assert size_total == 4
241
+
242
+
243
class TrainerActor(Actor):
    """Owns a zero-initialised 10x10 Linear layer on CUDA and shares its weight via RDMA."""

    def __init__(self):
        super().__init__()
        layer = torch.nn.Linear(10, 10).to("cuda")
        layer.weight.data.zero_()
        self.trainer = layer

    @endpoint
    async def init(self, gen):
        """Keep the slice of ``gen`` at this actor's own rank coordinates."""
        self.gen = gen.slice(**current_rank())

    @endpoint
    async def exchange_metadata(self):
        """Wrap the weight bytes in an RDMABuffer and hand it to the paired generator."""
        weight_bytes = self.trainer.weight.data.view(torch.uint8).flatten()
        self.handle = RDMABuffer(weight_bytes)
        await self.gen.attach_weight_buffer.call(self.handle)

    @endpoint
    async def weights_ready(self):
        """Add 1.0 to every weight element in place."""
        self.trainer.weight.data.add_(1.0)
263
+
264
+
265
class GeneratorActor(Actor):
    """Pulls trainer weights over RDMA and checks they advanced by one step each time."""

    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Linear(10, 10).to("cuda")
        self.step = 0

    @endpoint
    async def init(self, trainer):
        """Keep the slice of ``trainer`` at this actor's own rank coordinates."""
        self.trainer = trainer.slice(**current_rank())

    @endpoint
    async def attach_weight_buffer(self, handle):
        """Store the RDMA handle to the trainer's weight bytes."""
        self.handle = handle

    @endpoint
    async def update_weights(self):
        """Read the trainer weights into the local layer and verify their sum."""
        self.step += 1
        weight = self.generator.weight.data
        await self.handle.read_into(weight.view(torch.uint8).flatten())
        # TrainerActor starts from zeros and adds 1.0 to all 100 elements per
        # weights_ready(), so after ``step`` rounds the sum is step * 100.
        total = torch.sum(weight)
        assert (
            total == self.step * 100
        ), f"{torch.sum(self.generator.weight.data)=}, {self.step=}"
288
+
289
+
290
@needs_cuda
async def test_gpu_trainer_generator():
    """Trainer and generator on separate proc meshes exchange weights over RDMA."""
    trainer_proc = await proc_mesh(gpus=1)
    gen_proc = await proc_mesh(gpus=1)
    trainer = await trainer_proc.spawn("trainer", TrainerActor)
    generator = await gen_proc.spawn("gen", GeneratorActor)

    # Pair each side with its peer, then publish the trainer's weight buffer.
    await generator.init.call(trainer)
    await trainer.init.call(generator)
    await trainer.exchange_metadata.call()

    for _round in range(3):
        await trainer.weights_ready.call()
        await generator.update_weights.call()
304
+
305
+
306
class SyncActor(Actor):
    """Actor with a synchronous endpoint that blocks on another actor mesh."""

    @endpoint
    def sync_endpoint(self, a_counter: Counter):
        """Block until one chosen Counter actor returns its value, and return it."""
        pending = a_counter.value.choose()
        return pending.get()
310
+
311
+
312
async def test_sync_actor():
    """A sync endpoint can block on another actor's future and return the result."""
    proc = await local_proc_mesh(gpus=2)
    actor = await proc.spawn("actor", SyncActor)
    counter = await proc.spawn("counter", Counter, 5)
    result = await actor.sync_endpoint.choose(counter)
    assert result == 5
318
+
319
+
320
@needs_cuda
def test_gpu_trainer_generator_sync() -> None:
    """Trainer/generator RDMA weight exchange driven entirely with blocking .get()."""
    trainer_proc = proc_mesh(gpus=1).get()
    gen_proc = proc_mesh(gpus=1).get()
    trainer = trainer_proc.spawn("trainer", TrainerActor).get()
    generator = gen_proc.spawn("gen", GeneratorActor).get()

    # Pair each side with its peer, then publish the trainer's weight buffer.
    generator.init.call(trainer).get()
    trainer.init.call(generator).get()
    trainer.exchange_metadata.call().get()

    for _round in range(3):
        trainer.weights_ready.call().get()
        generator.update_weights.call().get()
334
+
335
+
336
def test_sync_actor_sync_client():
    """A sync endpoint works when the client also resolves futures with .get()."""
    proc = local_proc_mesh(gpus=2).get()
    actor = proc.spawn("actor", SyncActor).get()
    counter = proc.spawn("counter", Counter, 5).get()
    result = actor.sync_endpoint.choose(counter).get()
    assert result == 5
342
+
343
+
344
def test_proc_mesh_size() -> None:
    """ProcMesh.size() reports the extent of the requested dimension."""
    mesh = local_proc_mesh(gpus=2).get()
    assert mesh.size("gpus") == 2
347
+
348
+
349
def test_rank_size_sync() -> None:
    """Summed gpus ranks equal 1 and summed gpus sizes equal 4, resolved with .get()."""
    proc = local_proc_mesh(gpus=2).get()
    runner = proc.spawn("runit", RunIt).get()

    acc = Accumulator(runner.run, 0, operator.add)
    assert acc.accumulate(lambda: current_rank()["gpus"]).get() == 1
    assert acc.accumulate(lambda: current_size()["gpus"]).get() == 4
356
+
357
+
358
def test_accumulate_sync() -> None:
    """Accumulator.accumulate() can be resolved synchronously with .get()."""
    proc = local_proc_mesh(gpus=2).get()
    counter = proc.spawn("counter", Counter, 1).get()
    counter.incr.broadcast()
    total = Accumulator(counter.value, 0, operator.add).accumulate().get()
    assert total == 4
365
+
366
+
367
class CastToCounter(Actor):
    """Actor whose sync endpoint gathers every Counter value into a plain list."""

    @endpoint
    def doit(self, c: Counter):
        """Call ``value`` on all actors of ``c`` and return the results as a list."""
        gathered = c.value.call().get()
        return list(gathered)
371
+
372
+
373
def test_value_mesh() -> None:
    """call() returns a value mesh supporting item(), slice() and iteration."""
    proc = local_proc_mesh(gpus=2).get()
    counter = proc.spawn("counter", Counter, 0).get()
    # Bump only the actor at hosts=0, gpus=1.
    counter.slice(hosts=0, gpus=1).incr.broadcast()

    values = counter.value.call().get()
    assert values.item(hosts=0, gpus=0) == 0
    assert values.item(hosts=0, gpus=1) == 1
    assert values.slice(hosts=0, gpus=1).item() == 1

    caster = proc.spawn("ctc", CastToCounter).get()
    assert list(values) == caster.slice(gpus=0).doit.call_one(counter).get()
383
+
384
+
385
def test_rust_binding_modules_correct() -> None:
    """Check that exported binding objects report the name and module they live under.

    Every non-dunder attribute that is not a module and has ``__module__`` must
    have ``__name__`` equal to its attribute name and ``__module__`` equal to the
    dotted path of the module it was found in. Submodules are checked recursively.
    """
    import monarch._rust_bindings as bindings

    def check(module, path):
        for attr, obj in module.__dict__.items():
            if attr.startswith("__"):
                continue
            if isinstance(obj, ModuleType):
                check(obj, f"{path}.{attr}")
                continue
            if not hasattr(obj, "__module__"):
                continue
            assert obj.__name__ == attr
            assert obj.__module__ == path

    check(bindings, "monarch._rust_bindings")
399
+
400
+
401
def test_proc_mesh_liveness() -> None:
    """Actors remain callable after the local ``mesh`` reference is deleted."""
    mesh = proc_mesh(gpus=2).get()
    counter = mesh.spawn("counter", Counter, 1).get()
    del mesh
    # Allow time for any (buggy) shutdown caused by dropping the mesh to take effect.
    pause_seconds = 0.5
    time.sleep(pause_seconds)
    counter.value.call().get()
409
+
410
+
411
def _debugee_actor_internal(rank: int):
    """Stop at a rank-specific breakpoint; used by test_debug.

    Ranks 0, 1 and 3 return ``rank + 1``, ``rank + 2`` and ``rank + 4``;
    rank 2 raises ``ValueError("bad rank")``; any other rank returns None.

    NOTE: test_debug asserts that the four breakpoint() lines are exactly 4
    lines apart and that two "n" steps advance 2 lines, so do not add or remove
    lines inside the branches below.
    """
    if rank == 0:
        breakpoint() # noqa
        rank += 1
        return rank
    elif rank == 1:
        breakpoint() # noqa
        rank += 2
        return rank
    elif rank == 2:
        breakpoint() # noqa
        rank += 3
        raise ValueError("bad rank")
    elif rank == 3:
        breakpoint() # noqa
        rank += 4
        return rank
428
+
429
+
430
class DebugeeActor(Actor):
    """Actor whose endpoint runs _debugee_actor_internal with the actor's rank."""

    @endpoint
    async def to_debug(self):
        """Look up this actor's rank and hand it to _debugee_actor_internal.

        test_debug matches pdb output against the literal text of the
        ``return`` line below, so keep that line unchanged.
        """
        rank = MonarchContext.get().point.rank
        return _debugee_actor_internal(rank)
435
+
436
+
437
async def test_debug() -> None:
    """Drive a scripted debugger session against DebugeeActor on a hosts=2, gpus=2 mesh.

    ``monarch.debugger._debugger_input`` is patched to return the commands in
    ``input_mock.side_effect`` one at a time, and ``_debugger_output`` is
    patched to append each message to ``outputs``. After each
    ``debug_client.enter`` call, ``debug_client.list`` is inspected to check
    where each rank is currently stopped.
    """
    input_mock = AsyncMock()
    input_mock.side_effect = [
        # NOTE(review): each run of commands below presumably belongs to one
        # debug_client.enter() call (ending in "quit", or "continue" for the
        # last one) — verify against monarch.debugger.
        # Step rank 1 four times, then detach and re-attach.
        "attach 1",
        "n",
        "n",
        "n",
        "n",
        "detach",
        "attach 1",
        "detach",
        "quit",
        # Step ranks 0 and 3 twice each.
        "cast 0,3 n",
        "cast 0,3 n",
        # Attaching to 0 and 3 ensures that when we call "list"
        # the next time, their function/lineno info will be
        # up-to-date.
        "attach 0",
        "detach",
        "attach 3",
        "detach",
        "quit",
        # Continue rank 2, which raises ValueError("bad rank").
        "attach 2",
        "c",
        "quit",
        # Afterwards no breakpoints remain listed (asserted below).
        "continue",
    ]

    # Every message the debugger emits, in order.
    outputs = []

    def _patch_output(msg):
        nonlocal outputs
        outputs.append(msg)

    with patch("monarch.debugger._debugger_input", side_effect=input_mock), patch(
        "monarch.debugger._debugger_output", new=_patch_output
    ):
        proc = await proc_mesh(hosts=2, gpus=2)
        debugee = await proc.spawn("debugee", DebugeeActor)
        debug_client = await init_debugging(debugee)

        # Not awaited yet: the call stays pending while ranks sit in pdb.
        fut = debugee.to_debug.call()
        await debug_client.wait_pending_session.call_one()
        # Poll up to 10 times, 1s apart, until all 4 ranks report a breakpoint.
        breakpoints = []
        for i in range(10):
            breakpoints = await debug_client.list.call_one()
            if len(breakpoints) == 4:
                break
            await asyncio.sleep(1)
            if i == 9:
                raise RuntimeError("timed out waiting for breakpoints")

        # Each entry unpacks as (rank, coords, ?, ?, function, lineno).
        initial_linenos = {}
        for i in range(len(breakpoints)):
            rank, coords, _, _, function, lineno = breakpoints[i]
            initial_linenos[rank] = lineno
            assert rank == i
            assert coords == {"hosts": rank % 2, "gpus": rank // 2}
            assert function == "test_python_actors._debugee_actor_internal"
            # Breakpoints in _debugee_actor_internal are 4 source lines apart.
            assert lineno == breakpoints[0][5] + 4 * rank

        await debug_client.enter.call_one()

        # Check that when detaching and re-attaching to a session, the last portion of the output is repeated
        expected_last_output = [
            r"--Return--",
            r"\n",
            r"> (/.*/)+test_python_actors.py\(\d+\)to_debug\(\)->3\n-> return _debugee_actor_internal\(rank\)",
            r"\n",
            r"\(Pdb\) ",
        ]
        output_len = len(expected_last_output)
        assert outputs[-2 * output_len : -output_len] == outputs[-output_len:]
        for real_output, expected_output in zip(
            outputs[-output_len:], expected_last_output
        ):
            assert re.match(expected_output, real_output) is not None

        # Rank 1 has stepped out into to_debug; the others have not moved.
        breakpoints = await debug_client.list.call_one()
        for i in range(len(breakpoints)):
            if i == 1:
                assert breakpoints[i][4] == "test_python_actors.to_debug"
            else:
                assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal"
                assert breakpoints[i][5] == initial_linenos[i]

        await debug_client.enter.call_one()

        # Ranks 0 and 3 advanced two lines; rank 2 is unchanged.
        breakpoints = await debug_client.list.call_one()
        for i in range(len(breakpoints)):
            if i == 1:
                assert breakpoints[i][4] == "test_python_actors.to_debug"
            elif i in (0, 3):
                assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal"
                assert breakpoints[i][5] == initial_linenos[i] + 2
            else:
                assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal"
                assert breakpoints[i][5] == initial_linenos[i]

        await debug_client.enter.call_one()

        # Rank 2 was continued and no longer appears.
        breakpoints = await debug_client.list.call_one()
        assert len(breakpoints) == 3
        for i, rank in enumerate((0, 1, 3)):
            assert breakpoints[i][0] == rank

        await debug_client.enter.call_one()
        breakpoints = await debug_client.list.call_one()
        assert len(breakpoints) == 0

        # Rank 2's ValueError surfaces through the original call as an ActorError.
        with pytest.raises(monarch.actor_mesh.ActorError, match="ValueError: bad rank"):
            await fut
549
+
550
+
551
class TLSActor(Actor):
    """An actor whose counter lives in a ``threading.local``.

    It has both sync and async endpoints, so a test can check that all of them
    see the value initialised in ``__init__``.
    """

    def __init__(self):
        self.local = threading.local()
        self.local.value = 0

    @endpoint
    def increment(self):
        """Add one to the thread-local counter (sync endpoint)."""
        self.local.value = self.local.value + 1

    @endpoint
    async def increment_async(self):
        """Add one to the thread-local counter (async endpoint)."""
        self.local.value = self.local.value + 1

    @endpoint
    def get(self):
        """Return the thread-local counter (sync endpoint)."""
        current = self.local.value
        return current

    @endpoint
    async def get_async(self):
        """Return the thread-local counter (async endpoint)."""
        current = self.local.value
        return current
573
+
574
+
575
async def test_actor_tls() -> None:
    """Sync and async endpoints of one actor all see the same thread-local state."""
    pm = await proc_mesh(gpus=1)
    am = await pm.spawn("tls", TLSActor)
    for bump in (am.increment, am.increment_async, am.increment, am.increment_async):
        await bump.call_one()

    assert await am.get.call_one() == 4
    assert await am.get_async.call_one() == 4
586
+
587
+
588
class TLSActorFullSync(Actor):
    """An actor whose counter lives in a ``threading.local``; sync endpoints only."""

    def __init__(self):
        self.local = threading.local()
        self.local.value = 0

    @endpoint
    def increment(self):
        """Add one to the thread-local counter."""
        self.local.value = self.local.value + 1

    @endpoint
    def get(self):
        """Return the thread-local counter."""
        current = self.local.value
        return current
602
+
603
+
604
async def test_actor_tls_full_sync() -> None:
    """Repeated sync endpoint calls all see the same thread-local state."""
    pm = await proc_mesh(gpus=1)
    am = await pm.spawn("tls", TLSActorFullSync)
    for _ in range(4):
        await am.increment.call_one()

    assert await am.get.call_one() == 4
614
+
615
+
616
class AsyncActor(Actor):
    """Actor with a long-running async endpoint that another endpoint can stop."""

    def __init__(self):
        # Set by no_more(); polled by sleep().
        self.should_exit = False

    @endpoint
    async def sleep(self) -> None:
        """Poll once per second until ``no_more`` has been called."""
        while not self.should_exit:
            await asyncio.sleep(1)

    @endpoint
    async def no_more(self) -> None:
        """Ask a running ``sleep`` call to return at its next poll."""
        self.should_exit = True
628
+
629
+
630
@pytest.mark.timeout(15)
async def test_async_concurrency():
    """Async endpoints of a single actor are processed concurrently."""
    pm = await proc_mesh(gpus=1)
    am = await pm.spawn("async", AsyncActor)
    sleeping = am.sleep.call()
    # If messages were handled strictly one at a time, no_more could not run
    # while sleep() is looping and the timeout above would fire.
    await am.no_more.call()
    await sleeping
640
+
641
+
642
async def awaitit(f):
    """Await ``f`` and return its result, letting asyncio.run drive any awaitable."""
    result = await f
    return result
644
+
645
+
646
def test_actor_future():
    """ActorFuture can be resolved with get() or await, and reuses its first outcome.

    Results and exceptions are produced once and then returned/re-raised on
    later get()/await calls (the shared counter ``v`` does not advance again).
    get(timeout=...) raises a timeout error for work that never finishes.
    """
    v = 0

    async def incr():
        nonlocal v
        v += 1
        return v

    # can use async implementation from sync
    # if no non-blocking is provided
    f = ActorFuture(incr)
    assert f.get() == 1
    assert v == 1
    assert f.get() == 1
    assert asyncio.run(awaitit(f)) == 1

    f = ActorFuture(incr)
    assert asyncio.run(awaitit(f)) == 2
    assert f.get() == 2

    def incr2():
        nonlocal v
        v += 2
        return v

    # Use non-blocking optimization if provided
    f = ActorFuture(incr, incr2)
    assert f.get() == 4
    assert asyncio.run(awaitit(f)) == 4

    async def nope():
        nonlocal v
        v += 1
        raise ValueError("nope")

    # An exception from the async implementation is re-raised on every access
    # without re-running the implementation (v stays at 5).
    f = ActorFuture(nope)

    with pytest.raises(ValueError):
        f.get()

    assert v == 5

    with pytest.raises(ValueError):
        f.get()

    assert v == 5

    with pytest.raises(ValueError):
        asyncio.run(awaitit(f))

    assert v == 5

    # Distinct name so it does not hide the async ``nope`` defined above.
    def nope_blocking():
        nonlocal v
        v += 1
        raise ValueError("nope")

    # Same for an exception from the blocking implementation; result() and
    # exception() expose it too.
    f = ActorFuture(incr, nope_blocking)

    with pytest.raises(ValueError):
        f.get()

    assert v == 6

    with pytest.raises(ValueError):
        f.result()

    assert f.exception() is not None

    assert v == 6

    with pytest.raises(ValueError):
        asyncio.run(awaitit(f))

    assert v == 6

    async def seven():
        return 7

    f = ActorFuture(seven)

    assert 7 == f.get(timeout=0.001)

    async def neverfinish():
        # Nothing ever resolves this future, so the coroutine never completes.
        pending = asyncio.Future()
        await pending

    f = ActorFuture(neverfinish)

    with pytest.raises(asyncio.exceptions.TimeoutError):
        f.get(timeout=0.1)