torchmonarch-nightly 2025.8.2-cp310-cp310-manylinux2014_x86_64.whl → 2025.9.3-cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. monarch/_rust_bindings.so +0 -0
  2. monarch/_src/actor/actor_mesh.py +414 -216
  3. monarch/_src/actor/allocator.py +75 -6
  4. monarch/_src/actor/bootstrap_main.py +7 -4
  5. monarch/_src/actor/code_sync/__init__.py +2 -0
  6. monarch/_src/actor/debugger/__init__.py +7 -0
  7. monarch/_src/actor/{debugger.py → debugger/debugger.py} +246 -135
  8. monarch/_src/actor/{pdb_wrapper.py → debugger/pdb_wrapper.py} +62 -23
  9. monarch/_src/actor/endpoint.py +27 -45
  10. monarch/_src/actor/future.py +86 -24
  11. monarch/_src/actor/host_mesh.py +125 -0
  12. monarch/_src/actor/logging.py +94 -0
  13. monarch/_src/actor/pickle.py +25 -0
  14. monarch/_src/actor/proc_mesh.py +423 -156
  15. monarch/_src/actor/python_extension_methods.py +90 -0
  16. monarch/_src/actor/shape.py +8 -1
  17. monarch/_src/actor/source_loader.py +45 -0
  18. monarch/_src/actor/telemetry/__init__.py +172 -0
  19. monarch/_src/actor/telemetry/rust_span_tracing.py +6 -39
  20. monarch/_src/debug_cli/__init__.py +7 -0
  21. monarch/_src/debug_cli/debug_cli.py +43 -0
  22. monarch/_src/tensor_engine/rdma.py +64 -9
  23. monarch/_testing.py +1 -3
  24. monarch/actor/__init__.py +24 -4
  25. monarch/common/_C.so +0 -0
  26. monarch/common/device_mesh.py +14 -0
  27. monarch/common/future.py +10 -0
  28. monarch/common/remote.py +14 -25
  29. monarch/common/tensor.py +12 -0
  30. monarch/debug_cli/__init__.py +7 -0
  31. monarch/debug_cli/__main__.py +12 -0
  32. monarch/fetch.py +2 -2
  33. monarch/gradient/_gradient_generator.so +0 -0
  34. monarch/gradient_generator.py +4 -2
  35. monarch/mesh_controller.py +34 -14
  36. monarch/monarch_controller +0 -0
  37. monarch/tools/colors.py +25 -0
  38. monarch/tools/commands.py +42 -7
  39. monarch/tools/components/hyperactor.py +1 -1
  40. monarch/tools/config/__init__.py +31 -4
  41. monarch/tools/config/defaults.py +13 -3
  42. monarch/tools/config/environment.py +45 -0
  43. monarch/tools/config/workspace.py +165 -0
  44. monarch/tools/mesh_spec.py +2 -0
  45. monarch/utils/__init__.py +9 -0
  46. monarch/utils/utils.py +78 -0
  47. tests/error_test_binary.py +5 -3
  48. tests/python_actor_test_binary.py +52 -0
  49. tests/test_actor_error.py +142 -14
  50. tests/test_alloc.py +1 -1
  51. tests/test_allocator.py +59 -72
  52. tests/test_debugger.py +639 -45
  53. tests/test_env_before_cuda.py +4 -4
  54. tests/test_mesh_trait.py +38 -0
  55. tests/test_python_actors.py +965 -75
  56. tests/test_rdma.py +7 -6
  57. tests/test_tensor_engine.py +6 -6
  58. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/METADATA +82 -4
  59. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/RECORD +63 -47
  60. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/WHEEL +0 -0
  61. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/entry_points.txt +0 -0
  62. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/licenses/LICENSE +0 -0
  63. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/top_level.txt +0 -0
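
Only one of the 63 files is reproduced below; from the test names and the +965 -75 line count it is evidently tests/test_python_actors.py (item 55). The dominant change is a migration of the process-mesh entry points: the old top-level local_proc_mesh(...) and proc_mesh(...) helpers are replaced by host meshes (fake_in_process_host().spawn_procs(...) for in-process tests, this_host().spawn_procs(...) for real processes), and PortTuple.create(mailbox) is replaced by Channel[T].open(). A minimal sketch of the new spelling, assembled only from calls that appear in the diff below — the Echo actor is illustrative rather than part of the package, and treating the send half of Channel.open() like the Port passed to SendAlot is an assumption:

    from monarch._src.actor.actor_mesh import Channel
    from monarch.actor import Actor, endpoint, this_host


    class Echo(Actor):  # hypothetical example actor, not in the wheel
        @endpoint
        def echo(self, x: int) -> int:
            return x


    # 2025.8.2: pm = proc_mesh(gpus=1)  (or local_proc_mesh(gpus=1))
    # 2025.9.3: procs are spawned from an explicit host mesh
    procs = this_host().spawn_procs(per_host={"gpus": 1})
    actors = procs.spawn("echo", Echo)
    assert actors.echo.call_one(7).get() == 7

    # 2025.8.2: send, recv = PortTuple.create(pm._mailbox)
    # 2025.9.3: typed channels
    send, recv = Channel[int].open()
    send.send(3)
    assert recv.recv().get() == 3

    procs.stop().get()
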
@@ -5,25 +5,45 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
+
 import asyncio
+import ctypes
+import gc
+import importlib.resources
 import logging
 import operator
 import os
+import re
+import subprocess
 import sys
 import tempfile
 import threading
 import time
 import unittest
-from logging import INFO
+import unittest.mock
 from types import ModuleType
-from typing import cast
+from typing import cast, Tuple
 
 import pytest
 
 import torch
+from monarch._rust_bindings.monarch_hyperactor.actor import (
+    PythonMessage,
+    PythonMessageKind,
+)
+from monarch._rust_bindings.monarch_hyperactor.mailbox import (
+    PortId,
+    PortRef,
+    UndeliverableMessageEnvelope,
+)
+from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
 from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
 
-from monarch._src.actor.actor_mesh import ActorMeshRef, Port, PortTuple
+from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
+from monarch._src.actor.allocator import AllocHandle
+from monarch._src.actor.future import Future
+from monarch._src.actor.host_mesh import create_local_host_mesh, fake_in_process_host
+from monarch._src.actor.proc_mesh import ProcMesh
 
 from monarch.actor import (
     Accumulator,
@@ -32,10 +52,10 @@ from monarch.actor import (
     current_rank,
     current_size,
     endpoint,
-    Future,
-    local_proc_mesh,
-    proc_mesh,
+    this_host,
+    this_proc,
 )
+from monarch.tools.config import defaults
 from typing_extensions import assert_type
 
 
@@ -68,8 +88,9 @@ class Indirect(Actor):
         return await c.value.choose()
 
 
+@pytest.mark.timeout(60)
 async def test_choose():
-    proc = await local_proc_mesh(gpus=2)
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     v = await proc.spawn("counter", Counter, 3)
     i = await proc.spawn("indirect", Indirect)
     v.incr.broadcast()
@@ -86,8 +107,9 @@ async def test_choose():
     assert result2 == result3
 
 
+@pytest.mark.timeout(60)
 async def test_stream():
-    proc = await local_proc_mesh(gpus=2)
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     v = await proc.spawn("counter2", Counter, 3)
     v.incr.broadcast()
 
@@ -102,46 +124,50 @@ class To(Actor):
 
 class From(Actor):
     @endpoint
-    async def get(self, to: To):
+    async def fetch(self, to: To):
         return [await x for x in to.whoami.stream()]
 
 
+@pytest.mark.timeout(60)
 async def test_mesh_passed_to_mesh():
-    proc = await local_proc_mesh(gpus=2)
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     f = await proc.spawn("from", From)
     t = await proc.spawn("to", To)
-    all = [y for x in f.get.stream(t) for y in await x]
+    all = [y for x in f.fetch.stream(t) for y in await x]
     assert len(all) == 4
     assert all[0] != all[1]
 
 
+@pytest.mark.timeout(60)
 async def test_mesh_passed_to_mesh_on_different_proc_mesh():
-    proc = await local_proc_mesh(gpus=2)
-    proc2 = await local_proc_mesh(gpus=2)
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
+    proc2 = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     f = await proc.spawn("from", From)
     t = await proc2.spawn("to", To)
-    all = [y for x in f.get.stream(t) for y in await x]
+    all = [y for x in f.fetch.stream(t) for y in await x]
     assert len(all) == 4
     assert all[0] != all[1]
 
 
-async def test_actor_slicing():
-    proc = await local_proc_mesh(gpus=2)
-    proc2 = await local_proc_mesh(gpus=2)
+@pytest.mark.timeout(60)
+def test_actor_slicing():
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
+    proc2 = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
 
-    f = await proc.spawn("from", From)
-    t = await proc2.spawn("to", To)
+    f = proc.spawn("from", From)
+    t = proc2.spawn("to", To)
 
-    assert await t.slice(gpus=0).whoami.call() != await t.slice(gpus=1).whoami.call()
+    assert t.slice(gpus=0).whoami.call().get() != t.slice(gpus=1).whoami.call().get()
 
-    result = [y for x in f.get.stream(t.slice(gpus=0)) for y in await x]
+    result = [y for x in f.fetch.stream(t.slice(gpus=0)) for y in x.get()]
     assert len(result) == 2
 
     assert result[0] == result[1]
 
 
+@pytest.mark.timeout(60)
 async def test_aggregate():
-    proc = await local_proc_mesh(gpus=2)
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     counter = await proc.spawn("counter", Counter, 1)
     counter.incr.broadcast()
     acc = Accumulator(counter.value, 0, operator.add)
@@ -154,9 +180,14 @@ class RunIt(Actor):
     async def run(self, fn):
         return fn()
 
+    @endpoint
+    async def return_current_rank_str(self):
+        return str(current_rank())
+
 
+@pytest.mark.timeout(60)
 async def test_rank_size():
-    proc = await local_proc_mesh(gpus=2)
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     r = await proc.spawn("runit", RunIt)
 
     acc = Accumulator(r.run, 0, operator.add)
@@ -165,35 +196,50 @@ async def test_rank_size():
     assert 4 == await acc.accumulate(lambda: current_size()["gpus"])
 
 
+@pytest.mark.timeout(60)
+async def test_rank_string():
+    proc = fake_in_process_host().spawn_procs(per_host={"hosts": 1, "gpus": 2})
+    r = proc.spawn("runit", RunIt)
+    vm = r.return_current_rank_str.call().get()
+    r0 = vm.flatten("r").slice(r=0).item()
+    r1 = vm.flatten("r").slice(r=1).item()
+    assert r0 == "{'hosts': 0/1, 'gpus': 0/2}"
+    assert r1 == "{'hosts': 0/1, 'gpus': 1/2}"
+
+
 class SyncActor(Actor):
     @endpoint
     def sync_endpoint(self, a_counter: Counter):
         return a_counter.value.choose().get()
 
 
+@pytest.mark.timeout(60)
 async def test_sync_actor():
-    proc = await local_proc_mesh(gpus=2)
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
    a = await proc.spawn("actor", SyncActor)
     c = await proc.spawn("counter", Counter, 5)
     r = await a.sync_endpoint.choose(c)
     assert r == 5
 
 
-def test_sync_actor_sync_client():
-    proc = local_proc_mesh(gpus=2).get()
+@pytest.mark.timeout(60)
+def test_sync_actor_sync_client() -> None:
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     a = proc.spawn("actor", SyncActor).get()
     c = proc.spawn("counter", Counter, 5).get()
     r = a.sync_endpoint.choose(c).get()
     assert r == 5
 
 
+@pytest.mark.timeout(60)
 def test_proc_mesh_size() -> None:
-    proc = local_proc_mesh(gpus=2).get()
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     assert 2 == proc.size("gpus")
 
 
+@pytest.mark.timeout(60)
 def test_rank_size_sync() -> None:
-    proc = local_proc_mesh(gpus=2).get()
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     r = proc.spawn("runit", RunIt).get()
 
     acc = Accumulator(r.run, 0, operator.add)
@@ -201,8 +247,9 @@ def test_rank_size_sync() -> None:
     assert 4 == acc.accumulate(lambda: current_size()["gpus"]).get()
 
 
+@pytest.mark.timeout(60)
 def test_accumulate_sync() -> None:
-    proc = local_proc_mesh(gpus=2).get()
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     counter = proc.spawn("counter", Counter, 1).get()
     counter.incr.broadcast()
     acc = Accumulator(counter.value, 0, operator.add)
@@ -216,8 +263,9 @@ class CastToCounter(Actor):
         return list(c.value.call().get())
 
 
+@pytest.mark.timeout(60)
 def test_value_mesh() -> None:
-    proc = local_proc_mesh(gpus=2).get()
+    proc = fake_in_process_host().spawn_procs(per_host={"hosts": 1, "gpus": 2})
     counter = proc.spawn("counter", Counter, 0).get()
     counter.slice(hosts=0, gpus=1).incr.broadcast()
     x = counter.value.call().get()
@@ -228,7 +276,18 @@ def test_value_mesh() -> None:
     assert list(x) == n.slice(gpus=0).doit.call_one(counter).get()
 
 
+@pytest.mark.timeout(60)
 def test_rust_binding_modules_correct() -> None:
+    """
+    This tests that rust bindings will survive pickling correctly.
+
+    To correctly define a rust binding, either
+
+    (1) Set its module to "monarch._rust_bindings.rust_crate.rust_module",
+    and make sure it is registered in monarch_extension/lib.rs
+    (2) Set its module to some existing python file, and use @rust_struct to install
+    the rust struct in that file and patch in any python extension methods.
+    """
     import monarch._rust_bindings as bindings
 
     def check(module, path):
@@ -238,14 +297,16 @@ def test_rust_binding_modules_correct() -> None:
             if isinstance(value, ModuleType):
                 check(value, f"{path}.{name}")
             elif hasattr(value, "__module__"):
-                assert value.__name__ == name
-                assert value.__module__ == path
+                value_module = importlib.import_module(value.__module__)
+                resolved_value = getattr(value_module, value.__name__)
+                assert value is resolved_value
 
     check(bindings, "monarch._rust_bindings")
 
 
+@pytest.mark.timeout(60)
 def test_proc_mesh_liveness() -> None:
-    mesh = proc_mesh(gpus=2).get()
+    mesh = this_host().spawn_procs(per_host={"gpus": 2})
     counter = mesh.spawn("counter", Counter, 1).get()
     del mesh
     # Give some time for the mesh to have been shut down.
@@ -270,7 +331,7 @@ class TLSActor(Actor):
         self.local.value += 1
 
     @endpoint
-    def get(self):
+    def get_value(self):
         return self.local.value
 
     @endpoint
@@ -278,16 +339,17 @@ class TLSActor(Actor):
         return self.local.value
 
 
+@pytest.mark.timeout(60)
 async def test_actor_tls() -> None:
     """Test that thread-local state is respected."""
-    pm = await proc_mesh(gpus=1)
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
     am = await pm.spawn("tls", TLSActor)
     await am.increment.call_one()
     await am.increment_async.call_one()
     await am.increment.call_one()
     await am.increment_async.call_one()
 
-    assert 4 == await am.get.call_one()
+    assert 4 == await am.get_value.call_one()
     assert 4 == await am.get_async.call_one()
 
 
@@ -303,20 +365,21 @@ class TLSActorFullSync(Actor):
         self.local.value += 1
 
     @endpoint
-    def get(self):
+    def get_value(self):
         return self.local.value
 
 
+@pytest.mark.timeout(60)
 async def test_actor_tls_full_sync() -> None:
     """Test that thread-local state is respected."""
-    pm = await proc_mesh(gpus=1)
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
    am = await pm.spawn("tls", TLSActorFullSync)
     await am.increment.call_one()
     await am.increment.call_one()
     await am.increment.call_one()
     await am.increment.call_one()
 
-    assert 4 == await am.get.call_one()
+    assert 4 == await am.get_value.call_one()
 
 
 class AsyncActor(Actor):
@@ -333,10 +396,10 @@ class AsyncActor(Actor):
         self.should_exit = True
 
 
-@pytest.mark.timeout(15)
+@pytest.mark.timeout(30)
 async def test_async_concurrency():
     """Test that async endpoints will be processed concurrently."""
-    pm = await proc_mesh(gpus=1)
+    pm = await this_host().spawn_procs()
     am = await pm.spawn("async", AsyncActor)
     fut = am.sleep.call()
     # This call should go through and exit the sleep loop, as long as we are
@@ -442,19 +505,35 @@ async def awaitit(f):
 
 
 class Printer(Actor):
-    def __init__(self):
-        self.logger = logging.getLogger()
-        self.logger.setLevel(INFO)
+    def __init__(self) -> None:
+        self._logger: logging.Logger = logging.getLogger()
 
     @endpoint
-    async def print(self, content: str):
-        print(f"{os.getpid()} {content}")
+    async def print(self, content: str) -> None:
+        print(f"{content}", flush=True)
+        sys.stdout.flush()
+        sys.stderr.flush()
 
     @endpoint
-    async def log(self, content: str):
-        self.logger.info(f"{os.getpid()} {content}")
-
-
+    async def log(self, content: str) -> None:
+        self._logger.error(f"{content}")
+        for handler in self._logger.handlers:
+            handler.flush()
+        sys.stdout.flush()
+        sys.stderr.flush()
+
+    def _handle_undeliverable_message(
+        self, message: UndeliverableMessageEnvelope
+    ) -> bool:
+        # Don't throw an error on undeliverable messages. This actor is used in a test for
+        # stopping actor meshes, and if we throw an error here then there is a race between
+        # the asserted error that the mesh was stopped and the supervision error that a message
+        # wasn't delivered.
+        self._logger.error(f"Ignoring undeliverable message: {message}")
+        return True
+
+
+@pytest.mark.timeout(60)
 async def test_actor_log_streaming() -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
@@ -482,19 +561,247 @@ async def test_actor_log_streaming() -> None:
             sys.stderr = stderr_file
 
             try:
-                pm = await proc_mesh(gpus=2)
+                pm = this_host().spawn_procs(per_host={"gpus": 2})
                 am = await pm.spawn("printer", Printer)
 
-                await am.print.call("hello 1")
-                await am.log.call("hello 2")
+                # Disable streaming logs to client
+                await pm.logging_option(
+                    stream_to_client=False, aggregate_window_sec=None
+                )
+                await asyncio.sleep(1)
+
+                # These should not be streamed to client initially
+                for _ in range(5):
+                    await am.print.call("no print streaming")
+                    await am.log.call("no log streaming")
+                await asyncio.sleep(1)
+
+                # Enable streaming logs to client
+                await pm.logging_option(
+                    stream_to_client=True, aggregate_window_sec=1, level=logging.FATAL
+                )
+                # Give it some time to reflect
+                await asyncio.sleep(1)
+
+                # These should be streamed to client
+                for _ in range(5):
+                    await am.print.call("has print streaming")
+                    await am.log.call("no log streaming due to level mismatch")
+                await asyncio.sleep(1)
+
+                # Enable streaming logs to client
+                await pm.logging_option(
+                    stream_to_client=True, aggregate_window_sec=1, level=logging.ERROR
+                )
+                # Give it some time to reflect
+                await asyncio.sleep(1)
+
+                # These should be streamed to client
+                for _ in range(5):
+                    await am.print.call("has print streaming too")
+                    await am.log.call("has log streaming as level matched")
+
+                await pm.stop()
+
+                # Flush all outputs
+                stdout_file.flush()
+                stderr_file.flush()
+                os.fsync(stdout_file.fileno())
+                os.fsync(stderr_file.fileno())
+
+            finally:
+                # Restore Python's sys.stdout/stderr
+                sys.stdout = original_sys_stdout
+                sys.stderr = original_sys_stderr
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+                os.dup2(original_stderr_fd, 2)
+
+            # Read the captured output
+            with open(stdout_path, "r") as f:
+                stdout_content = f.read()
+
+            with open(stderr_path, "r") as f:
+                stderr_content = f.read()
+
+            # Clean up temp files
+            os.unlink(stdout_path)
+            os.unlink(stderr_path)
+
+            # Assertions on the captured output
+            # Has a leading context so we can distinguish between streamed log and
+            # the log directly printed by the child processes as they share the same stdout/stderr
+            assert not re.search(
+                r"similar log lines.*no print streaming", stdout_content
+            ), stdout_content
+            assert not re.search(
+                r"similar log lines.*no print streaming", stderr_content
+            ), stderr_content
+            assert not re.search(
+                r"similar log lines.*no log streaming", stdout_content
+            ), stdout_content
+            assert not re.search(
+                r"similar log lines.*no log streaming", stderr_content
+            ), stderr_content
+            assert not re.search(
+                r"similar log lines.*no log streaming due to level mismatch", stdout_content
+            ), stdout_content
+            assert not re.search(
+                r"similar log lines.*no log streaming due to level mismatch", stderr_content
+            ), stderr_content
+
+            assert re.search(
+                r"similar log lines.*has print streaming", stdout_content
+            ), stdout_content
+            assert not re.search(
+                r"similar log lines.*has print streaming", stderr_content
+            ), stderr_content
+            assert re.search(
+                r"similar log lines.*has print streaming too", stdout_content
+            ), stdout_content
+            assert not re.search(
+                r"similar log lines.*has print streaming too", stderr_content
+            ), stderr_content
+            assert not re.search(
+                r"similar log lines.*log streaming as level matched", stdout_content
+            ), stdout_content
+            assert re.search(
+                r"similar log lines.*log streaming as level matched",
+                stderr_content,
+            ), stderr_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.dup2(original_stderr_fd, 2)
+            os.close(original_stdout_fd)
+            os.close(original_stderr_fd)
+        except OSError:
+            pass
+
+
+@pytest.mark.timeout(120)
+async def test_alloc_based_log_streaming() -> None:
+    """Test both AllocHandle.stream_logs = False and True cases."""
+
+    async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
+        # Save original file descriptors
+        original_stdout_fd = os.dup(1)  # stdout
+
+        try:
+            # Create temporary files to capture output
+            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+                stdout_path = stdout_file.name
+                os.dup2(stdout_file.fileno(), 1)
+                original_sys_stdout = sys.stdout
+                sys.stdout = stdout_file
+
+                try:
+                    # Create proc mesh with custom stream_logs setting
+                    host_mesh = create_local_host_mesh()
+                    alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
+
+                    # Override the stream_logs setting
+                    custom_alloc_handle = AllocHandle(
+                        alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
+                    )
+
+                    pm = ProcMesh.from_alloc(custom_alloc_handle)
+                    am = await pm.spawn("printer", Printer)
+
+                    await pm.initialized
+
+                    for _ in range(5):
+                        await am.print.call(f"{test_name} print streaming")
+
+                    await pm.stop()
+
+                    # Flush all outputs
+                    stdout_file.flush()
+                    os.fsync(stdout_file.fileno())
+
+                finally:
+                    # Restore Python's sys.stdout
+                    sys.stdout = original_sys_stdout
+
+                    # Restore original file descriptors
+                    os.dup2(original_stdout_fd, 1)
+
+            # Read the captured output
+            with open(stdout_path, "r") as f:
+                stdout_content = f.read()
+
+            # Clean up temp files
+            os.unlink(stdout_path)
+
+            if not stream_logs:
+                # When stream_logs=False, logs should not be streamed to client
+                assert not re.search(
+                    rf"similar log lines.*{test_name} print streaming", stdout_content
+                ), f"stream_logs=True case: {stdout_content}"
+                assert re.search(
+                    rf"{test_name} print streaming", stdout_content
+                ), f"stream_logs=True case: {stdout_content}"
+            else:
+                # When stream_logs=True, logs should be streamed to client (no aggregation by default)
+                assert re.search(
+                    rf"similar log lines.*{test_name} print streaming", stdout_content
+                ), f"stream_logs=False case: {stdout_content}"
+                assert not re.search(
+                    rf"\[[0-9]\]{test_name} print streaming", stdout_content
+                ), f"stream_logs=False case: {stdout_content}"
+
+        finally:
+            # Ensure file descriptors are restored even if something goes wrong
+            try:
+                os.dup2(original_stdout_fd, 1)
+                os.close(original_stdout_fd)
+            except OSError:
+                pass
+
+    # Test both cases
+    await test_stream_logs_case(False, "stream_logs_false")
+    await test_stream_logs_case(True, "stream_logs_true")
+
+
+@pytest.mark.timeout(60)
+async def test_logging_option_defaults() -> None:
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+    original_stderr_fd = os.dup(2)  # stderr
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(
+            mode="w+", delete=False
+        ) as stdout_file, tempfile.NamedTemporaryFile(
+            mode="w+", delete=False
+        ) as stderr_file:
+            stdout_path = stdout_file.name
+            stderr_path = stderr_file.name
+
+            # Redirect file descriptors to our temp files
+            # This will capture both Python and Rust output
+            os.dup2(stdout_file.fileno(), 1)
+            os.dup2(stderr_file.fileno(), 2)
+
+            # Also redirect Python's sys.stdout/stderr for completeness
+            original_sys_stdout = sys.stdout
+            original_sys_stderr = sys.stderr
+            sys.stdout = stdout_file
+            sys.stderr = stderr_file
 
-                await pm.logging_option(stream_to_client=True)
+            try:
+                pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                am = await pm.spawn("printer", Printer)
 
-                await am.print.call("hello 3")
-                await am.log.call("hello 4")
+                for _ in range(5):
+                    await am.print.call("print streaming")
+                    await am.log.call("log streaming")
 
-                # Give it sometime to send log back
-                time.sleep(5)
+                await pm.stop()
 
                 # Flush all outputs
                 stdout_file.flush()
@@ -515,16 +822,27 @@ async def test_actor_log_streaming() -> None:
             with open(stdout_path, "r") as f:
                 stdout_content = f.read()
 
+            with open(stderr_path, "r") as f:
+                stderr_content = f.read()
+
             # Clean up temp files
             os.unlink(stdout_path)
             os.unlink(stderr_path)
 
-            # TODO: (@jamessun) we need to disable logging forwarder for python logger
-            # assert "hello 1" not in stdout_content
-            assert "hello 2" not in stdout_content
-
-            assert "hello 3" in stdout_content
-            # assert "hello 4" in stdout_content
+            # Assertions on the captured output
+            assert not re.search(
+                r"similar log lines.*print streaming", stdout_content
+            ), stdout_content
+            assert re.search(r"print streaming", stdout_content), stdout_content
+            assert not re.search(
+                r"similar log lines.*print streaming", stderr_content
+            ), stderr_content
+            assert not re.search(
+                r"similar log lines.*log streaming", stdout_content
+            ), stdout_content
+            assert not re.search(
+                r"similar log lines.*log streaming", stderr_content
+            ), stderr_content
 
     finally:
         # Ensure file descriptors are restored even if something goes wrong
@@ -537,6 +855,378 @@ async def test_actor_log_streaming() -> None:
             pass
 
 
+# oss_skip: pytest keeps complaining about mocking get_ipython module
+@pytest.mark.oss_skip
+@pytest.mark.timeout(180)
+async def test_flush_logs_ipython() -> None:
+    """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                # Mock IPython environment
+                class MockExecutionResult:
+                    pass
+
+                class MockEvents:
+                    def __init__(self):
+                        self.callbacks = {}
+                        self.registers = 0
+                        self.unregisters = 0
+
+                    def register(self, event_name, callback):
+                        if event_name not in self.callbacks:
+                            self.callbacks[event_name] = []
+                        self.callbacks[event_name].append(callback)
+                        self.registers += 1
+
+                    def unregister(self, event_name, callback):
+                        if event_name not in self.callbacks:
+                            raise ValueError(f"Event {event_name} not registered")
+                        assert callback in self.callbacks[event_name]
+                        self.callbacks[event_name].remove(callback)
+                        self.unregisters += 1
+
+                    def trigger(self, event_name, *args, **kwargs):
+                        if event_name in self.callbacks:
+                            for callback in self.callbacks[event_name]:
+                                callback(*args, **kwargs)
+
+                class MockIPython:
+                    def __init__(self):
+                        self.events = MockEvents()
+
+                mock_ipython = MockIPython()
+
+                with unittest.mock.patch(
+                    "monarch._src.actor.logging.get_ipython",
+                    lambda: mock_ipython,
+                ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
+                    # Make sure we can register and unregister callbacks
+                    for i in range(3):
+                        pm1 = await this_host().spawn_procs(per_host={"gpus": 2})
+                        pm2 = await this_host().spawn_procs(per_host={"gpus": 2})
+                        am1 = await pm1.spawn("printer", Printer)
+                        am2 = await pm2.spawn("printer", Printer)
+
+                        # Set aggregation window to ensure logs are buffered
+                        await pm1.logging_option(
+                            stream_to_client=True, aggregate_window_sec=600
+                        )
+                        await pm2.logging_option(
+                            stream_to_client=True, aggregate_window_sec=600
+                        )
+                        # TODO: fix the following assertion
+                        # assert mock_ipython.events.unregisters == 2 * i
+
+                        # _get_controller_controller() spawns an extra local mesh
+                        # but log streaming is disabled so it doesn't hurt
+                        assert mock_ipython.events.registers == 1 + 2 * (i + 1)
+                        await asyncio.sleep(1)
+
+                        # Generate some logs that will be aggregated
+                        for _ in range(5):
+                            await am1.print.call("ipython1 test log")
+                            await am2.print.call("ipython2 test log")
+
+                        # Trigger the post_run_cell event which should flush logs
+                        mock_ipython.events.trigger(
+                            "post_run_cell", MockExecutionResult()
+                        )
+
+                    # Flush all outputs
+                    stdout_file.flush()
+                    os.fsync(stdout_file.fileno())
+
+                    gc.collect()
+
+                    # Same as above, _get_controller_controller() spawns an extra local mesh
+                    assert mock_ipython.events.registers == 7
+                    # There are many objects still taking refs
+                    # TODO: fix the following assertion
+                    assert mock_ipython.events.unregisters == 0
+                    assert len(mock_ipython.events.callbacks["post_run_cell"]) == 7
+            finally:
+                # Restore Python's sys.stdout
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+            # Read the captured output
+            with open(stdout_path, "r") as f:
+                stdout_content = f.read()
+
+            # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils
+
+            # Clean up temp files
+            os.unlink(stdout_path)
+
+            # Verify that logs were flushed when the post_run_cell event was triggered
+            # We should see the aggregated logs in the output
+            assert (
+                len(
+                    re.findall(
+                        r"\[10 similar log lines\].*ipython1 test log", stdout_content
+                    )
+                )
+                == 3
+            ), stdout_content
+
+            assert (
+                len(
+                    re.findall(
+                        r"\[10 similar log lines\].*ipython2 test log", stdout_content
+                    )
+                )
+                == 3
+            ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
+# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
+@pytest.mark.oss_skip
+async def test_flush_logs_fast_exit() -> None:
+    # We use a subprocess to run the test so we can handle the flushed logs at the end.
+    # Otherwise, it is hard to restore the original stdout/stderr.
+
+    test_bin = importlib.resources.files(str(__package__)).joinpath("test_bin")
+
+    # Run the binary in a separate process and capture stdout and stderr
+    cmd = [str(test_bin), "flush-logs"]
+
+    process = subprocess.run(cmd, capture_output=True, timeout=60, text=True)
+
+    # Check if the process ended without error
+    if process.returncode != 0:
+        raise RuntimeError(f"{cmd} ended with error code {process.returncode}. ")
+
+    # Assertions on the captured output, 160 = 32 procs * 5 logs per proc
+    # 32 and 5 are specified in the test_bin flush-logs.
+    assert (
+        len(
+            re.findall(
+                r"160 similar log lines.*has print streaming",
+                process.stdout,
+            )
+        )
+        == 1
+    ), process.stdout
+
+
+@pytest.mark.timeout(60)
+async def test_flush_on_disable_aggregation() -> None:
+    """Test that logs are flushed when disabling aggregation.
+
+    This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."
+    """
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                am = await pm.spawn("printer", Printer)
+
+                # Set a long aggregation window to ensure logs aren't flushed immediately
+                await pm.logging_option(stream_to_client=True, aggregate_window_sec=60)
+
+                # Generate some logs that will be aggregated but not flushed immediately
+                for _ in range(5):
+                    await am.print.call("aggregated log line")
+                await asyncio.sleep(1)
+
+                # Now disable aggregation - this should trigger an immediate flush
+                await pm.logging_option(
+                    stream_to_client=True, aggregate_window_sec=None
+                )
+
+                # Wait a bit to ensure logs are collected
+                await asyncio.sleep(1)
+                for _ in range(5):
+                    await am.print.call("single log line")
+
+                await pm.stop()
+
+                # Flush all outputs
+                stdout_file.flush()
+                os.fsync(stdout_file.fileno())
+
+            finally:
+                # Restore Python's sys.stdout
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+            # Read the captured output
+            with open(stdout_path, "r") as f:
+                stdout_content = f.read()
+
+            # Clean up temp files
+            os.unlink(stdout_path)
+
+            # Verify that logs were flushed when aggregation was disabled
+            # We should see the aggregated logs in the output
+            # 10 = 5 log lines * 2 procs
+            assert re.search(
+                r"\[10 similar log lines\].*aggregated log line", stdout_content
+            ), stdout_content
+
+            # No aggregated single log lines
+            assert not re.search(
+                r"similar log lines.*single log line", stdout_content
+            ), stdout_content
+
+            # 10 = 5 log lines * 2 procs
+            assert (
+                len(re.findall(r"\[.* [0-9]+\] single log line", stdout_content)) == 10
+            ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
+@pytest.mark.timeout(120)
+async def test_multiple_ongoing_flushes_no_deadlock() -> None:
+    """
+    The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
+    Because now a flush call is purely sync, it is very easy to get into a deadlock.
+    So we assert the last flush call will not get into such a state.
+    """
+    pm = this_host().spawn_procs(per_host={"gpus": 4})
+    am = pm.spawn("printer", Printer)
+
+    # Generate some logs that will be aggregated but not flushed immediately
+    for _ in range(10):
+        await am.print.call("aggregated log line")
+
+    log_mesh = pm._logging_manager._logging_mesh_client
+    assert log_mesh is not None
+    futures = []
+    for _ in range(5):
+        # FIXME: the order of futures doesn't necessarily mean the order of flushes due to the async nature.
+        await asyncio.sleep(0.1)
+        futures.append(Future(coro=log_mesh.flush().spawn().task()))
+
+    # The last flush should not block
+    futures[-1].get()
+
+
+@pytest.mark.timeout(60)
+async def test_adjust_aggregation_window() -> None:
+    """Test that the flush deadline is updated when the aggregation window is adjusted.
+
+    This tests the corner case: "This can happen if the user has adjusted the aggregation window."
+    """
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                am = await pm.spawn("printer", Printer)
+
+                # Set a long aggregation window initially
+                await pm.logging_option(stream_to_client=True, aggregate_window_sec=100)
+
+                # Generate some logs that will be aggregated
+                for _ in range(3):
+                    await am.print.call("first batch of logs")
+                await asyncio.sleep(1)
+
+                # Now adjust to a shorter window - this should update the flush deadline
+                await pm.logging_option(stream_to_client=True, aggregate_window_sec=2)
+
+                # Generate more logs
+                for _ in range(3):
+                    await am.print.call("second batch of logs")
+
+                await pm.stop()
+
+                # Flush all outputs
+                stdout_file.flush()
+                os.fsync(stdout_file.fileno())
+
+            finally:
+                # Restore Python's sys.stdout/stderr
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+            # Read the captured output
+            with open(stdout_path, "r") as f:
+                stdout_content = f.read()
+
+            # Clean up temp files
+            os.unlink(stdout_path)
+
+            # Verify that logs were flushed when the aggregation window was adjusted
+            # We should see both batches of logs in the output
+            assert re.search(
+                r"\[6 similar log lines\].*first batch of logs", stdout_content
+            ), stdout_content
+
+            assert re.search(
+                r"similar log lines.*second batch of logs", stdout_content
+            ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
 class SendAlot(Actor):
     @endpoint
     async def send(self, port: Port[int]):
@@ -544,10 +1234,11 @@ class SendAlot(Actor):
             port.send(i)
 
 
-def test_port_as_argument():
-    proc_mesh = local_proc_mesh(gpus=1).get()
+@pytest.mark.timeout(60)
+def test_port_as_argument() -> None:
+    proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 1})
     s = proc_mesh.spawn("send_alot", SendAlot).get()
-    send, recv = PortTuple.create(proc_mesh._mailbox)
+    send, recv = Channel[int].open()
 
     s.send.broadcast(send)
 
@@ -555,14 +1246,14 @@ def test_port_as_argument():
         assert i == recv.recv().get()
 
 
-@pytest.mark.timeout(15)
+@pytest.mark.timeout(30)
 async def test_same_actor_twice() -> None:
-    pm = await proc_mesh(gpus=1)
-    await pm.spawn("dup", Counter, 0)
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
+    await pm.spawn("dup", Counter, 0).initialized
 
     # The second spawn with the same name should fail with a specific error
     with pytest.raises(Exception) as exc_info:
-        await pm.spawn("dup", Counter, 0)
+        await pm.spawn("dup", Counter, 0).initialized
 
     # Assert that the error message contains the expected text about duplicate actor name
    error_msg = str(exc_info.value)
@@ -571,23 +1262,81 @@ async def test_same_actor_twice() -> None:
     ), f"Expected error message about duplicate actor name, got: {error_msg}"
 
 
+class LsActor(Actor):
+    def __init__(self, workspace: str):
+        self.workspace = workspace
+
+    @endpoint
+    async def ls(self) -> list[str]:
+        return os.listdir(self.workspace)
+
+
+async def test_sync_workspace() -> None:
+    # create two workspaces: one for local and one for remote
+    with tempfile.TemporaryDirectory() as workspace_src, tempfile.TemporaryDirectory() as workspace_dst:
+
+        def bootstrap_WORKSPACE_DIR() -> None:
+            import os
+
+            os.environ["WORKSPACE_DIR"] = workspace_dst
+
+        pm = this_host().spawn_procs(
+            per_host={"gpus": 1}, bootstrap=bootstrap_WORKSPACE_DIR
+        )
+
+        config = defaults.config("slurm", workspace_src)
+        await pm.sync_workspace(workspace=config.workspace, auto_reload=True)
+
+        # no file in remote workspace initially
+        am = await pm.spawn("ls", LsActor, workspace_dst)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 0
+
+        # write a file to local workspace
+        file_path = os.path.join(workspace_src, "new_file")
+        with open(file_path, "w") as f:
+            f.write("hello world")
+            f.flush()
+
+        # force a sync and it should populate on the dst workspace
+        await pm.sync_workspace(config.workspace, auto_reload=True)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 1
+            assert item[1][0] == "new_file"
+            file_path = os.path.join(workspace_dst, item[1][0])
+            with open(file_path, "r") as f:
+                assert f.readline() == "hello world"
+
+    # sanity check
+    assert "WORKSPACE_DIR" not in os.environ, "test leaves env var side-effects!"
+
+
 class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
     async def test_actor_mesh_stop(self) -> None:
-        pm = await proc_mesh(gpus=2)
+        pm = this_host().spawn_procs(per_host={"gpus": 2})
        am_1 = await pm.spawn("printer", Printer)
         am_2 = await pm.spawn("printer2", Printer)
         await am_1.print.call("hello 1")
         await am_1.log.call("hello 2")
-        await cast(ActorMeshRef, am_1).stop()
+        await cast(ActorMesh, am_1).stop()
 
         with self.assertRaisesRegex(
-            RuntimeError, expected_regex="`ActorMesh` has been stopped"
+            RuntimeError, expected_regex="`PythonActorMesh` has already been stopped"
         ):
             await am_1.print.call("hello 1")
 
         await am_2.print.call("hello 3")
         await am_2.log.call("hello 4")
 
+        await pm.stop()
+
+    async def test_proc_mesh_stop_after_actor_mesh_stop(self) -> None:
+        pm = this_host().spawn_procs(per_host={"gpus": 2})
+        am = await pm.spawn("printer", Printer)
+
+        await cast(ActorMesh, am).stop()
+        await pm.stop()
+
 
 class PortedActor(Actor):
     @endpoint(explicit_response_port=True)
@@ -595,8 +1344,9 @@ class PortedActor(Actor):
         port.send(3 + b)
 
 
+@pytest.mark.timeout(60)
 def test_ported_actor():
-    proc_mesh = local_proc_mesh(gpus=1).get()
+    proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 1}).get()
     a = proc_mesh.spawn("port_actor", PortedActor).get()
     assert 5 == a.add.call_one(2).get()
 
@@ -610,5 +1360,145 @@ async def consume():
     assert r == (7, 2, 3)
 
 
+@pytest.mark.timeout(60)
 def test_python_task_tuple() -> None:
     PythonTask.from_coroutine(consume()).block_on()
+
+
+def test_select_result() -> None:
+    def s(t):
+        time.sleep(t)
+        return t
+
+    a = PythonTask.spawn_blocking(lambda: s(4))
+    b = PythonTask.spawn_blocking(lambda: s(0))
+    r = PythonTask.select_one([a.task(), b.task()]).block_on()
+    assert r == (0, 1)
+
+
+def test_mesh_len():
+    proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 12})
+    s = proc_mesh.spawn("sync_actor", SyncActor).get()
+    assert 12 == len(s)
+
+
+class UndeliverableMessageReceiver(Actor):
+    def __init__(self):
+        self._messages = asyncio.Queue()
+
+    @endpoint
+    async def receive_undeliverable(
+        self, sender: ActorId, dest: PortId, error_msg: str
+    ) -> None:
+        await self._messages.put((sender, dest, error_msg))
+
+    @endpoint
+    async def get_messages(self) -> Tuple[ActorId, PortId, str]:
+        return await self._messages.get()
+
+
+class UndeliverableMessageSender(Actor):
+    @endpoint
+    def send_undeliverable(self) -> None:
+        mailbox = context().actor_instance._mailbox
+        port_id = PortId(
+            actor_id=ActorId(
+                world_name=mailbox.actor_id.world_name, rank=0, actor_name="bogus"
+            ),
+            port=1234,
+        )
+        port_ref = PortRef(port_id)
+        port_ref.send(
+            mailbox,
+            PythonMessage(PythonMessageKind.Result(None), b"123"),
+        )
+
+
+class UndeliverableMessageSenderWithOverride(UndeliverableMessageSender):
+    def __init__(self, receiver: UndeliverableMessageReceiver):
+        self._receiver = receiver
+
+    def _handle_undeliverable_message(
+        self, message: UndeliverableMessageEnvelope
+    ) -> bool:
+        self._receiver.receive_undeliverable.call_one(
+            message.sender(), message.dest(), message.error_msg()
+        ).get()
+        return True
+
+
+@pytest.mark.timeout(60)
+async def test_undeliverable_message_with_override() -> None:
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
+    receiver = pm.spawn("undeliverable_receiver", UndeliverableMessageReceiver)
+    sender = pm.spawn(
+        "undeliverable_sender", UndeliverableMessageSenderWithOverride, receiver
+    )
+    sender.send_undeliverable.call().get()
+    sender, dest, error_msg = receiver.get_messages.call_one().get()
+    assert sender.actor_name == "undeliverable_sender"
+    assert dest.actor_id.actor_name == "bogus"
+    assert error_msg is not None
+    pm.stop().get()
+
+
+@pytest.mark.timeout(60)
+async def test_undeliverable_message_without_override() -> None:
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
+    sender = pm.spawn("undeliverable_sender", UndeliverableMessageSender)
+    sender.send_undeliverable.call().get()
+    # Wait a few seconds to ensure that the undeliverable message is processed
+    # without crashing anything
+    await asyncio.sleep(5)
+    pm.stop().get()
+
+
+def test_this_and_that():
+    counter = this_proc().spawn("counter", Counter, 7)
+    assert 7 == counter.value.call_one().get()
+
+
+class ReceptorActor(Actor):
+    @endpoint
+    def status(self):
+        return 1
+
+
+async def test_things_survive_losing_python_reference() -> None:
+    """Test the slice_receptor_mesh function in LOCAL mode, verifying that setup methods are called."""
+
+    receptor = (
+        this_host()
+        .spawn_procs(per_host={"gpus": 1})
+        .spawn(
+            "receptor",
+            ReceptorActor,
+        )
+    )
+    receptor = receptor.slice(gpus=0)
+
+    await receptor.status.call()
+
+
+class IsInit(Actor):
+    @endpoint
+    def is_cuda_initialized(self) -> bool:
+        cuda = ctypes.CDLL("libcuda.so.1")
+        CUresult = ctypes.c_int
+        cuDeviceGetCount = cuda.cuDeviceGetCount
+        cuDeviceGetCount.argtypes = [ctypes.POINTER(ctypes.c_int)]
+        cuDeviceGetCount.restype = CUresult
+        count = ctypes.c_int()
+        result = cuDeviceGetCount(ctypes.byref(count))
+        CUDA_ERROR_NOT_INITIALIZED = 3
+        return result == CUDA_ERROR_NOT_INITIALIZED
+
+
+@pytest.mark.oss_skip
+def test_cuda_is_not_initialized_in_a_new_proc():
+    try:
+        ctypes.CDLL("libcuda.so.1")
+    except OSError:
+        pytest.skip("cannot find cuda")
+    proc = this_host().spawn_procs().spawn("is_init", IsInit)
+    assert not proc.is_cuda_initialized.call_one().get()
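
The logging tests above collectively pin down the per-mesh log-streaming controls added in this release: ProcMesh.logging_option(stream_to_client=..., aggregate_window_sec=..., level=...), plus aggregated "[N similar log lines] ..." summaries on the client. A condensed usage sketch, using only keyword arguments the tests themselves pass; the semantics stated in the comments are inferred from the test assertions rather than from package documentation, and the Noisy actor is illustrative:

    import asyncio
    import logging

    from monarch.actor import Actor, endpoint, this_host


    class Noisy(Actor):  # hypothetical example actor, not in the wheel
        @endpoint
        async def say(self, line: str) -> None:
            print(line, flush=True)


    async def main() -> None:
        procs = this_host().spawn_procs(per_host={"gpus": 2})
        noisy = await procs.spawn("noisy", Noisy)

        # Buffer repeated lines for up to 60 s; they surface on the client as
        # "[N similar log lines] ..." (see test_flush_on_disable_aggregation).
        await procs.logging_option(stream_to_client=True, aggregate_window_sec=60)
        for _ in range(5):
            await noisy.say.call("hello")

        # aggregate_window_sec=None disables aggregation and flushes whatever
        # the aggregators hold; level filters forwarded Python log records
        # (see test_actor_log_streaming's logging.FATAL/ERROR cases).
        await procs.logging_option(
            stream_to_client=True, aggregate_window_sec=None, level=logging.ERROR
        )

        await procs.stop()


    asyncio.run(main())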