torchmonarch-nightly 2025.8.1__cp313-cp313-manylinux2014_x86_64.whl → 2025.9.3__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. monarch/_rust_bindings.so +0 -0
  2. monarch/_src/actor/actor_mesh.py +414 -216
  3. monarch/_src/actor/allocator.py +75 -6
  4. monarch/_src/actor/bootstrap_main.py +7 -4
  5. monarch/_src/actor/code_sync/__init__.py +2 -0
  6. monarch/_src/actor/debugger/__init__.py +7 -0
  7. monarch/_src/actor/{debugger.py → debugger/debugger.py} +246 -135
  8. monarch/_src/actor/{pdb_wrapper.py → debugger/pdb_wrapper.py} +62 -23
  9. monarch/_src/actor/endpoint.py +27 -45
  10. monarch/_src/actor/future.py +86 -24
  11. monarch/_src/actor/host_mesh.py +125 -0
  12. monarch/_src/actor/logging.py +94 -0
  13. monarch/_src/actor/pickle.py +25 -0
  14. monarch/_src/actor/proc_mesh.py +423 -156
  15. monarch/_src/actor/python_extension_methods.py +90 -0
  16. monarch/_src/actor/shape.py +8 -1
  17. monarch/_src/actor/source_loader.py +45 -0
  18. monarch/_src/actor/telemetry/__init__.py +172 -0
  19. monarch/_src/actor/telemetry/rust_span_tracing.py +6 -39
  20. monarch/_src/debug_cli/__init__.py +7 -0
  21. monarch/_src/debug_cli/debug_cli.py +43 -0
  22. monarch/_src/tensor_engine/rdma.py +64 -9
  23. monarch/_testing.py +1 -3
  24. monarch/actor/__init__.py +24 -4
  25. monarch/common/_C.so +0 -0
  26. monarch/common/device_mesh.py +14 -0
  27. monarch/common/future.py +10 -0
  28. monarch/common/remote.py +14 -25
  29. monarch/common/tensor.py +12 -0
  30. monarch/debug_cli/__init__.py +7 -0
  31. monarch/debug_cli/__main__.py +12 -0
  32. monarch/fetch.py +2 -2
  33. monarch/gradient/_gradient_generator.so +0 -0
  34. monarch/gradient_generator.py +4 -2
  35. monarch/mesh_controller.py +34 -14
  36. monarch/monarch_controller +0 -0
  37. monarch/tools/colors.py +25 -0
  38. monarch/tools/commands.py +42 -7
  39. monarch/tools/components/hyperactor.py +1 -1
  40. monarch/tools/config/__init__.py +31 -4
  41. monarch/tools/config/defaults.py +13 -3
  42. monarch/tools/config/environment.py +45 -0
  43. monarch/tools/config/workspace.py +165 -0
  44. monarch/tools/mesh_spec.py +2 -0
  45. monarch/utils/__init__.py +9 -0
  46. monarch/utils/utils.py +78 -0
  47. tests/error_test_binary.py +5 -3
  48. tests/python_actor_test_binary.py +52 -0
  49. tests/test_actor_error.py +142 -14
  50. tests/test_alloc.py +1 -1
  51. tests/test_allocator.py +59 -72
  52. tests/test_coalescing.py +1 -1
  53. tests/test_debugger.py +639 -45
  54. tests/test_env_before_cuda.py +4 -4
  55. tests/test_mesh_trait.py +38 -0
  56. tests/test_python_actors.py +979 -75
  57. tests/test_rdma.py +7 -6
  58. tests/test_tensor_engine.py +6 -6
  59. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/METADATA +82 -4
  60. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/RECORD +64 -48
  61. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/WHEEL +0 -0
  62. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/entry_points.txt +0 -0
  63. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/licenses/LICENSE +0 -0
  64. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/top_level.txt +0 -0
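
The bulk of this diff is the rewrite of tests/test_python_actors.py (+979 -75), which tracks a migration of the mesh-construction API: the local_proc_mesh and proc_mesh helpers (and the top-level Future import) disappear from monarch.actor in favor of host-mesh entry points (this_host, this_proc, fake_in_process_host) whose spawn_procs method takes an explicit per_host extent. A minimal before/after sketch of the pattern as it appears in the hunks below (the names come from this diff; exact signatures in the released wheel may differ):

    # 2025.8.1: module-level helpers returned a proc mesh directly.
    proc = await local_proc_mesh(gpus=2)    # in-process mesh for tests
    pm = await proc_mesh(gpus=2)            # mesh of local processes

    # 2025.9.3: meshes are spawned from a host mesh with an explicit extent.
    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
    pm = this_host().spawn_procs(per_host={"gpus": 2})
    counter = pm.spawn("counter", Counter, 3)    # actor spawning is unchanged

tests/test_python_actors.py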
@@ -5,24 +5,45 @@
  # LICENSE file in the root directory of this source tree.

  # pyre-unsafe
+
  import asyncio
+ import ctypes
+ import gc
+ import importlib.resources
  import logging
  import operator
  import os
+ import re
+ import subprocess
  import sys
  import tempfile
  import threading
  import time
  import unittest
- from logging import INFO
+ import unittest.mock
  from types import ModuleType
- from typing import cast
+ from typing import cast, Tuple

  import pytest

  import torch
+ from monarch._rust_bindings.monarch_hyperactor.actor import (
+     PythonMessage,
+     PythonMessageKind,
+ )
+ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
+     PortId,
+     PortRef,
+     UndeliverableMessageEnvelope,
+ )
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
+ from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask

- from monarch._src.actor.actor_mesh import ActorMeshRef, Port, PortTuple
+ from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
+ from monarch._src.actor.allocator import AllocHandle
+ from monarch._src.actor.future import Future
+ from monarch._src.actor.host_mesh import create_local_host_mesh, fake_in_process_host
+ from monarch._src.actor.proc_mesh import ProcMesh

  from monarch.actor import (
      Accumulator,
@@ -31,10 +52,10 @@ from monarch.actor import (
      current_rank,
      current_size,
      endpoint,
-     Future,
-     local_proc_mesh,
-     proc_mesh,
+     this_host,
+     this_proc,
  )
+ from monarch.tools.config import defaults
  from typing_extensions import assert_type


@@ -67,8 +88,9 @@ class Indirect(Actor):
          return await c.value.choose()


+ @pytest.mark.timeout(60)
  async def test_choose():
-     proc = await local_proc_mesh(gpus=2)
+     proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      v = await proc.spawn("counter", Counter, 3)
      i = await proc.spawn("indirect", Indirect)
      v.incr.broadcast()
@@ -85,8 +107,9 @@ async def test_choose():
      assert result2 == result3


+ @pytest.mark.timeout(60)
  async def test_stream():
-     proc = await local_proc_mesh(gpus=2)
+     proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      v = await proc.spawn("counter2", Counter, 3)
      v.incr.broadcast()

@@ -101,46 +124,50 @@ class To(Actor):

  class From(Actor):
      @endpoint
-     async def get(self, to: To):
+     async def fetch(self, to: To):
          return [await x for x in to.whoami.stream()]


+ @pytest.mark.timeout(60)
  async def test_mesh_passed_to_mesh():
-     proc = await local_proc_mesh(gpus=2)
+     proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      f = await proc.spawn("from", From)
      t = await proc.spawn("to", To)
-     all = [y for x in f.get.stream(t) for y in await x]
+     all = [y for x in f.fetch.stream(t) for y in await x]
      assert len(all) == 4
      assert all[0] != all[1]


+ @pytest.mark.timeout(60)
  async def test_mesh_passed_to_mesh_on_different_proc_mesh():
-     proc = await local_proc_mesh(gpus=2)
-     proc2 = await local_proc_mesh(gpus=2)
+     proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
+     proc2 = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      f = await proc.spawn("from", From)
      t = await proc2.spawn("to", To)
-     all = [y for x in f.get.stream(t) for y in await x]
+     all = [y for x in f.fetch.stream(t) for y in await x]
      assert len(all) == 4
      assert all[0] != all[1]


- async def test_actor_slicing():
-     proc = await local_proc_mesh(gpus=2)
-     proc2 = await local_proc_mesh(gpus=2)
+ @pytest.mark.timeout(60)
+ def test_actor_slicing():
+     proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
+     proc2 = fake_in_process_host().spawn_procs(per_host={"gpus": 2})

-     f = await proc.spawn("from", From)
-     t = await proc2.spawn("to", To)
+     f = proc.spawn("from", From)
+     t = proc2.spawn("to", To)

-     assert await t.slice(gpus=0).whoami.call() != await t.slice(gpus=1).whoami.call()
+     assert t.slice(gpus=0).whoami.call().get() != t.slice(gpus=1).whoami.call().get()

-     result = [y for x in f.get.stream(t.slice(gpus=0)) for y in await x]
+     result = [y for x in f.fetch.stream(t.slice(gpus=0)) for y in x.get()]
      assert len(result) == 2

      assert result[0] == result[1]


+ @pytest.mark.timeout(60)
  async def test_aggregate():
-     proc = await local_proc_mesh(gpus=2)
+     proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      counter = await proc.spawn("counter", Counter, 1)
      counter.incr.broadcast()
      acc = Accumulator(counter.value, 0, operator.add)
@@ -153,9 +180,14 @@ class RunIt(Actor):
      async def run(self, fn):
          return fn()

+     @endpoint
+     async def return_current_rank_str(self):
+         return str(current_rank())
+

+ @pytest.mark.timeout(60)
  async def test_rank_size():
-     proc = await local_proc_mesh(gpus=2)
+     proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      r = await proc.spawn("runit", RunIt)

      acc = Accumulator(r.run, 0, operator.add)
@@ -164,35 +196,50 @@ async def test_rank_size():
      assert 4 == await acc.accumulate(lambda: current_size()["gpus"])


+ @pytest.mark.timeout(60)
+ async def test_rank_string():
+     proc = fake_in_process_host().spawn_procs(per_host={"hosts": 1, "gpus": 2})
+     r = proc.spawn("runit", RunIt)
+     vm = r.return_current_rank_str.call().get()
+     r0 = vm.flatten("r").slice(r=0).item()
+     r1 = vm.flatten("r").slice(r=1).item()
+     assert r0 == "{'hosts': 0/1, 'gpus': 0/2}"
+     assert r1 == "{'hosts': 0/1, 'gpus': 1/2}"
+
+
  class SyncActor(Actor):
      @endpoint
      def sync_endpoint(self, a_counter: Counter):
          return a_counter.value.choose().get()


+ @pytest.mark.timeout(60)
  async def test_sync_actor():
-     proc = await local_proc_mesh(gpus=2)
+     proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      a = await proc.spawn("actor", SyncActor)
      c = await proc.spawn("counter", Counter, 5)
      r = await a.sync_endpoint.choose(c)
      assert r == 5


- def test_sync_actor_sync_client():
-     proc = local_proc_mesh(gpus=2).get()
+ @pytest.mark.timeout(60)
+ def test_sync_actor_sync_client() -> None:
+     proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      a = proc.spawn("actor", SyncActor).get()
      c = proc.spawn("counter", Counter, 5).get()
      r = a.sync_endpoint.choose(c).get()
      assert r == 5


+ @pytest.mark.timeout(60)
  def test_proc_mesh_size() -> None:
-     proc = local_proc_mesh(gpus=2).get()
+     proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      assert 2 == proc.size("gpus")


+ @pytest.mark.timeout(60)
  def test_rank_size_sync() -> None:
-     proc = local_proc_mesh(gpus=2).get()
+     proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      r = proc.spawn("runit", RunIt).get()

      acc = Accumulator(r.run, 0, operator.add)
@@ -200,8 +247,9 @@ def test_rank_size_sync() -> None:
      assert 4 == acc.accumulate(lambda: current_size()["gpus"]).get()


+ @pytest.mark.timeout(60)
  def test_accumulate_sync() -> None:
-     proc = local_proc_mesh(gpus=2).get()
+     proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
      counter = proc.spawn("counter", Counter, 1).get()
      counter.incr.broadcast()
      acc = Accumulator(counter.value, 0, operator.add)
@@ -215,8 +263,9 @@ class CastToCounter(Actor):
          return list(c.value.call().get())


+ @pytest.mark.timeout(60)
  def test_value_mesh() -> None:
-     proc = local_proc_mesh(gpus=2).get()
+     proc = fake_in_process_host().spawn_procs(per_host={"hosts": 1, "gpus": 2})
      counter = proc.spawn("counter", Counter, 0).get()
      counter.slice(hosts=0, gpus=1).incr.broadcast()
      x = counter.value.call().get()
@@ -227,7 +276,18 @@ def test_value_mesh() -> None:
      assert list(x) == n.slice(gpus=0).doit.call_one(counter).get()


+ @pytest.mark.timeout(60)
  def test_rust_binding_modules_correct() -> None:
+     """
+     This tests that rust bindings will survive pickling correctly.
+
+     To correctly define a rust binding, either
+
+     (1) Set its module to "monarch._rust_bindings.rust_crate.rust_module",
+         and make sure it is registered in monarch_extension/lib.rs
+     (2) Set its module to some existing python file, and use @rust_struct to install
+         the rust struct in that file and patch in any python extension methods.
+     """
      import monarch._rust_bindings as bindings

      def check(module, path):
@@ -237,14 +297,16 @@ def test_rust_binding_modules_correct() -> None:
              if isinstance(value, ModuleType):
                  check(value, f"{path}.{name}")
              elif hasattr(value, "__module__"):
-                 assert value.__name__ == name
-                 assert value.__module__ == path
+                 value_module = importlib.import_module(value.__module__)
+                 resolved_value = getattr(value_module, value.__name__)
+                 assert value is resolved_value

      check(bindings, "monarch._rust_bindings")


+ @pytest.mark.timeout(60)
  def test_proc_mesh_liveness() -> None:
-     mesh = proc_mesh(gpus=2).get()
+     mesh = this_host().spawn_procs(per_host={"gpus": 2})
      counter = mesh.spawn("counter", Counter, 1).get()
      del mesh
      # Give some time for the mesh to have been shut down.
@@ -269,7 +331,7 @@ class TLSActor(Actor):
          self.local.value += 1

      @endpoint
-     def get(self):
+     def get_value(self):
          return self.local.value

      @endpoint
@@ -277,16 +339,17 @@ class TLSActor(Actor):
          return self.local.value


+ @pytest.mark.timeout(60)
  async def test_actor_tls() -> None:
      """Test that thread-local state is respected."""
-     pm = await proc_mesh(gpus=1)
+     pm = this_host().spawn_procs(per_host={"gpus": 1})
      am = await pm.spawn("tls", TLSActor)
      await am.increment.call_one()
      await am.increment_async.call_one()
      await am.increment.call_one()
      await am.increment_async.call_one()

-     assert 4 == await am.get.call_one()
+     assert 4 == await am.get_value.call_one()
      assert 4 == await am.get_async.call_one()


@@ -302,20 +365,21 @@ class TLSActorFullSync(Actor):
          self.local.value += 1

      @endpoint
-     def get(self):
+     def get_value(self):
          return self.local.value


+ @pytest.mark.timeout(60)
  async def test_actor_tls_full_sync() -> None:
      """Test that thread-local state is respected."""
-     pm = await proc_mesh(gpus=1)
+     pm = this_host().spawn_procs(per_host={"gpus": 1})
      am = await pm.spawn("tls", TLSActorFullSync)
      await am.increment.call_one()
      await am.increment.call_one()
      await am.increment.call_one()
      await am.increment.call_one()

-     assert 4 == await am.get.call_one()
+     assert 4 == await am.get_value.call_one()


  class AsyncActor(Actor):
@@ -332,10 +396,10 @@ class AsyncActor(Actor):
          self.should_exit = True


- @pytest.mark.timeout(15)
+ @pytest.mark.timeout(30)
  async def test_async_concurrency():
      """Test that async endpoints will be processed concurrently."""
-     pm = await proc_mesh(gpus=1)
+     pm = await this_host().spawn_procs()
      am = await pm.spawn("async", AsyncActor)
      fut = am.sleep.call()
      # This call should go through and exit the sleep loop, as long as we are
@@ -441,19 +505,35 @@ async def awaitit(f):


  class Printer(Actor):
-     def __init__(self):
-         self.logger = logging.getLogger()
-         self.logger.setLevel(INFO)
+     def __init__(self) -> None:
+         self._logger: logging.Logger = logging.getLogger()

      @endpoint
-     async def print(self, content: str):
-         print(f"{os.getpid()} {content}")
+     async def print(self, content: str) -> None:
+         print(f"{content}", flush=True)
+         sys.stdout.flush()
+         sys.stderr.flush()

      @endpoint
-     async def log(self, content: str):
-         self.logger.info(f"{os.getpid()} {content}")
-
-
+     async def log(self, content: str) -> None:
+         self._logger.error(f"{content}")
+         for handler in self._logger.handlers:
+             handler.flush()
+         sys.stdout.flush()
+         sys.stderr.flush()
+
+     def _handle_undeliverable_message(
+         self, message: UndeliverableMessageEnvelope
+     ) -> bool:
+         # Don't throw an error on undeliverable messages. This actor is used in a test for
+         # stopping actor meshes, and if we throw an error here then there is a race between
+         # the asserted error that the mesh was stopped and the supervision error that a message
+         # wasn't delivered.
+         self._logger.error(f"Ignoring undeliverable message: {message}")
+         return True
+
+
+ @pytest.mark.timeout(60)
  async def test_actor_log_streaming() -> None:
      # Save original file descriptors
      original_stdout_fd = os.dup(1)  # stdout
@@ -481,19 +561,247 @@ async def test_actor_log_streaming() -> None:
              sys.stderr = stderr_file

              try:
-                 pm = await proc_mesh(gpus=2)
+                 pm = this_host().spawn_procs(per_host={"gpus": 2})
                  am = await pm.spawn("printer", Printer)

-                 await am.print.call("hello 1")
-                 await am.log.call("hello 2")
+                 # Disable streaming logs to client
+                 await pm.logging_option(
+                     stream_to_client=False, aggregate_window_sec=None
+                 )
+                 await asyncio.sleep(1)
+
+                 # These should not be streamed to client initially
+                 for _ in range(5):
+                     await am.print.call("no print streaming")
+                     await am.log.call("no log streaming")
+                 await asyncio.sleep(1)
+
+                 # Enable streaming logs to client
+                 await pm.logging_option(
+                     stream_to_client=True, aggregate_window_sec=1, level=logging.FATAL
+                 )
+                 # Give it some time to reflect
+                 await asyncio.sleep(1)
+
+                 # These should be streamed to client
+                 for _ in range(5):
+                     await am.print.call("has print streaming")
+                     await am.log.call("no log streaming due to level mismatch")
+                 await asyncio.sleep(1)
+
+                 # Enable streaming logs to client
+                 await pm.logging_option(
+                     stream_to_client=True, aggregate_window_sec=1, level=logging.ERROR
+                 )
+                 # Give it some time to reflect
+                 await asyncio.sleep(1)
+
+                 # These should be streamed to client
+                 for _ in range(5):
+                     await am.print.call("has print streaming too")
+                     await am.log.call("has log streaming as level matched")
+
+                 await pm.stop()
+
+                 # Flush all outputs
+                 stdout_file.flush()
+                 stderr_file.flush()
+                 os.fsync(stdout_file.fileno())
+                 os.fsync(stderr_file.fileno())
+
+             finally:
+                 # Restore Python's sys.stdout/stderr
+                 sys.stdout = original_sys_stdout
+                 sys.stderr = original_sys_stderr
+
+                 # Restore original file descriptors
+                 os.dup2(original_stdout_fd, 1)
+                 os.dup2(original_stderr_fd, 2)
+
+             # Read the captured output
+             with open(stdout_path, "r") as f:
+                 stdout_content = f.read()
+
+             with open(stderr_path, "r") as f:
+                 stderr_content = f.read()
+
+             # Clean up temp files
+             os.unlink(stdout_path)
+             os.unlink(stderr_path)
+
+             # Assertions on the captured output
+             # Has a leading context so we can distinguish between streamed log and
+             # the log directly printed by the child processes as they share the same stdout/stderr
+             assert not re.search(
+                 r"similar log lines.*no print streaming", stdout_content
+             ), stdout_content
+             assert not re.search(
+                 r"similar log lines.*no print streaming", stderr_content
+             ), stderr_content
+             assert not re.search(
+                 r"similar log lines.*no log streaming", stdout_content
+             ), stdout_content
+             assert not re.search(
+                 r"similar log lines.*no log streaming", stderr_content
+             ), stderr_content
+             assert not re.search(
+                 r"similar log lines.*no log streaming due to level mismatch", stdout_content
+             ), stdout_content
+             assert not re.search(
+                 r"similar log lines.*no log streaming due to level mismatch", stderr_content
+             ), stderr_content
+
+             assert re.search(
+                 r"similar log lines.*has print streaming", stdout_content
+             ), stdout_content
+             assert not re.search(
+                 r"similar log lines.*has print streaming", stderr_content
+             ), stderr_content
+             assert re.search(
+                 r"similar log lines.*has print streaming too", stdout_content
+             ), stdout_content
+             assert not re.search(
+                 r"similar log lines.*has print streaming too", stderr_content
+             ), stderr_content
+             assert not re.search(
+                 r"similar log lines.*log streaming as level matched", stdout_content
+             ), stdout_content
+             assert re.search(
+                 r"similar log lines.*log streaming as level matched",
+                 stderr_content,
+             ), stderr_content
+
+     finally:
+         # Ensure file descriptors are restored even if something goes wrong
+         try:
+             os.dup2(original_stdout_fd, 1)
+             os.dup2(original_stderr_fd, 2)
+             os.close(original_stdout_fd)
+             os.close(original_stderr_fd)
+         except OSError:
+             pass
+
+
+ @pytest.mark.timeout(120)
+ async def test_alloc_based_log_streaming() -> None:
+     """Test both AllocHandle.stream_logs = False and True cases."""
+
+     async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
+         # Save original file descriptors
+         original_stdout_fd = os.dup(1)  # stdout
+
+         try:
+             # Create temporary files to capture output
+             with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+                 stdout_path = stdout_file.name
+                 os.dup2(stdout_file.fileno(), 1)
+                 original_sys_stdout = sys.stdout
+                 sys.stdout = stdout_file
+
+                 try:
+                     # Create proc mesh with custom stream_logs setting
+                     host_mesh = create_local_host_mesh()
+                     alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
+
+                     # Override the stream_logs setting
+                     custom_alloc_handle = AllocHandle(
+                         alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
+                     )

-                 await pm.logging_option(stream_to_client=True)
+                     pm = ProcMesh.from_alloc(custom_alloc_handle)
+                     am = await pm.spawn("printer", Printer)

-                 await am.print.call("hello 3")
-                 await am.log.call("hello 4")
+                     await pm.initialized

-                 # Give it sometime to send log back
-                 time.sleep(5)
+                     for _ in range(5):
+                         await am.print.call(f"{test_name} print streaming")
+
+                     await pm.stop()
+
+                     # Flush all outputs
+                     stdout_file.flush()
+                     os.fsync(stdout_file.fileno())
+
+                 finally:
+                     # Restore Python's sys.stdout
+                     sys.stdout = original_sys_stdout
+
+                     # Restore original file descriptors
+                     os.dup2(original_stdout_fd, 1)
+
+             # Read the captured output
+             with open(stdout_path, "r") as f:
+                 stdout_content = f.read()
+
+             # Clean up temp files
+             os.unlink(stdout_path)
+
+             if not stream_logs:
+                 # When stream_logs=False, logs should not be streamed to client
+                 assert not re.search(
+                     rf"similar log lines.*{test_name} print streaming", stdout_content
+                 ), f"stream_logs=True case: {stdout_content}"
+                 assert re.search(
+                     rf"{test_name} print streaming", stdout_content
+                 ), f"stream_logs=True case: {stdout_content}"
+             else:
+                 # When stream_logs=True, logs should be streamed to client (no aggregation by default)
+                 assert re.search(
+                     rf"similar log lines.*{test_name} print streaming", stdout_content
+                 ), f"stream_logs=False case: {stdout_content}"
+                 assert not re.search(
+                     rf"\[[0-9]\]{test_name} print streaming", stdout_content
+                 ), f"stream_logs=False case: {stdout_content}"
+
+         finally:
+             # Ensure file descriptors are restored even if something goes wrong
+             try:
+                 os.dup2(original_stdout_fd, 1)
+                 os.close(original_stdout_fd)
+             except OSError:
+                 pass
+
+     # Test both cases
+     await test_stream_logs_case(False, "stream_logs_false")
+     await test_stream_logs_case(True, "stream_logs_true")
+
+
+ @pytest.mark.timeout(60)
+ async def test_logging_option_defaults() -> None:
+     # Save original file descriptors
+     original_stdout_fd = os.dup(1)  # stdout
+     original_stderr_fd = os.dup(2)  # stderr
+
+     try:
+         # Create temporary files to capture output
+         with tempfile.NamedTemporaryFile(
+             mode="w+", delete=False
+         ) as stdout_file, tempfile.NamedTemporaryFile(
+             mode="w+", delete=False
+         ) as stderr_file:
+             stdout_path = stdout_file.name
+             stderr_path = stderr_file.name
+
+             # Redirect file descriptors to our temp files
+             # This will capture both Python and Rust output
+             os.dup2(stdout_file.fileno(), 1)
+             os.dup2(stderr_file.fileno(), 2)
+
+             # Also redirect Python's sys.stdout/stderr for completeness
+             original_sys_stdout = sys.stdout
+             original_sys_stderr = sys.stderr
+             sys.stdout = stdout_file
+             sys.stderr = stderr_file
+
+             try:
+                 pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                 am = await pm.spawn("printer", Printer)
+
+                 for _ in range(5):
+                     await am.print.call("print streaming")
+                     await am.log.call("log streaming")
+
+                 await pm.stop()

                  # Flush all outputs
                  stdout_file.flush()
@@ -514,16 +822,27 @@ async def test_actor_log_streaming() -> None:
              with open(stdout_path, "r") as f:
                  stdout_content = f.read()

+             with open(stderr_path, "r") as f:
+                 stderr_content = f.read()
+
              # Clean up temp files
              os.unlink(stdout_path)
              os.unlink(stderr_path)

-             # TODO: (@jamessun) we need to disable logging forwarder for python logger
-             # assert "hello 1" not in stdout_content
-             assert "hello 2" not in stdout_content
-
-             assert "hello 3" in stdout_content
-             # assert "hello 4" in stdout_content
+             # Assertions on the captured output
+             assert not re.search(
+                 r"similar log lines.*print streaming", stdout_content
+             ), stdout_content
+             assert re.search(r"print streaming", stdout_content), stdout_content
+             assert not re.search(
+                 r"similar log lines.*print streaming", stderr_content
+             ), stderr_content
+             assert not re.search(
+                 r"similar log lines.*log streaming", stdout_content
+             ), stdout_content
+             assert not re.search(
+                 r"similar log lines.*log streaming", stderr_content
+             ), stderr_content

      finally:
          # Ensure file descriptors are restored even if something goes wrong
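
The logging tests above and below drive ProcMesh.logging_option, which in this diff controls whether child stdout/stderr is streamed back to the client, a per-record level filter, and the aggregation window after which repeated lines are emitted as "[N similar log lines] ...". A minimal usage sketch, assuming only the keyword arguments these tests actually pass (stream_to_client, aggregate_window_sec, level):

    pm = this_host().spawn_procs(per_host={"gpus": 2})
    am = await pm.spawn("printer", Printer)

    # Stream logs to the client, aggregate duplicates over a 1-second window,
    # and only forward records at ERROR or above.
    await pm.logging_option(
        stream_to_client=True, aggregate_window_sec=1, level=logging.ERROR
    )

    # Passing aggregate_window_sec=None disables aggregation and, per
    # test_flush_on_disable_aggregation below, flushes anything still buffered.
    await pm.logging_option(stream_to_client=True, aggregate_window_sec=None)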
@@ -536,6 +855,378 @@ async def test_actor_log_streaming() -> None:
              pass


+ # oss_skip: pytest keeps complaining about mocking get_ipython module
+ @pytest.mark.oss_skip
+ @pytest.mark.timeout(180)
+ async def test_flush_logs_ipython() -> None:
+     """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
+     # Save original file descriptors
+     original_stdout_fd = os.dup(1)  # stdout
+
+     try:
+         # Create temporary files to capture output
+         with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+             stdout_path = stdout_file.name
+
+             # Redirect file descriptors to our temp files
+             os.dup2(stdout_file.fileno(), 1)
+
+             # Also redirect Python's sys.stdout
+             original_sys_stdout = sys.stdout
+             sys.stdout = stdout_file
+
+             try:
+                 # Mock IPython environment
+                 class MockExecutionResult:
+                     pass
+
+                 class MockEvents:
+                     def __init__(self):
+                         self.callbacks = {}
+                         self.registers = 0
+                         self.unregisters = 0
+
+                     def register(self, event_name, callback):
+                         if event_name not in self.callbacks:
+                             self.callbacks[event_name] = []
+                         self.callbacks[event_name].append(callback)
+                         self.registers += 1
+
+                     def unregister(self, event_name, callback):
+                         if event_name not in self.callbacks:
+                             raise ValueError(f"Event {event_name} not registered")
+                         assert callback in self.callbacks[event_name]
+                         self.callbacks[event_name].remove(callback)
+                         self.unregisters += 1
+
+                     def trigger(self, event_name, *args, **kwargs):
+                         if event_name in self.callbacks:
+                             for callback in self.callbacks[event_name]:
+                                 callback(*args, **kwargs)
+
+                 class MockIPython:
+                     def __init__(self):
+                         self.events = MockEvents()
+
+                 mock_ipython = MockIPython()
+
+                 with unittest.mock.patch(
+                     "monarch._src.actor.logging.get_ipython",
+                     lambda: mock_ipython,
+                 ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
+                     # Make sure we can register and unregister callbacks
+                     for i in range(3):
+                         pm1 = await this_host().spawn_procs(per_host={"gpus": 2})
+                         pm2 = await this_host().spawn_procs(per_host={"gpus": 2})
+                         am1 = await pm1.spawn("printer", Printer)
+                         am2 = await pm2.spawn("printer", Printer)
+
+                         # Set aggregation window to ensure logs are buffered
+                         await pm1.logging_option(
+                             stream_to_client=True, aggregate_window_sec=600
+                         )
+                         await pm2.logging_option(
+                             stream_to_client=True, aggregate_window_sec=600
+                         )
+                         # TODO: fix the following assertion
+                         # assert mock_ipython.events.unregisters == 2 * i
+
+                         # _get_controller_controller() spawns an extra local mesh
+                         # but log streaming is disabled so it doesn't hurt
+                         assert mock_ipython.events.registers == 1 + 2 * (i + 1)
+                         await asyncio.sleep(1)
+
+                         # Generate some logs that will be aggregated
+                         for _ in range(5):
+                             await am1.print.call("ipython1 test log")
+                             await am2.print.call("ipython2 test log")
+
+                         # Trigger the post_run_cell event which should flush logs
+                         mock_ipython.events.trigger(
+                             "post_run_cell", MockExecutionResult()
+                         )
+
+                     # Flush all outputs
+                     stdout_file.flush()
+                     os.fsync(stdout_file.fileno())
+
+                     gc.collect()
+
+                     # Same as above, _get_controller_controller() spawns an extra local mesh
+                     assert mock_ipython.events.registers == 7
+                     # There are many objects still taking refs
+                     # TODO: fix the following assertion
+                     assert mock_ipython.events.unregisters == 0
+                     assert len(mock_ipython.events.callbacks["post_run_cell"]) == 7
+             finally:
+                 # Restore Python's sys.stdout
+                 sys.stdout = original_sys_stdout
+
+                 # Restore original file descriptors
+                 os.dup2(original_stdout_fd, 1)
+
+             # Read the captured output
+             with open(stdout_path, "r") as f:
+                 stdout_content = f.read()
+
+             # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils
+
+             # Clean up temp files
+             os.unlink(stdout_path)
+
+             # Verify that logs were flushed when the post_run_cell event was triggered
+             # We should see the aggregated logs in the output
+             assert (
+                 len(
+                     re.findall(
+                         r"\[10 similar log lines\].*ipython1 test log", stdout_content
+                     )
+                 )
+                 == 3
+             ), stdout_content
+
+             assert (
+                 len(
+                     re.findall(
+                         r"\[10 similar log lines\].*ipython2 test log", stdout_content
+                     )
+                 )
+                 == 3
+             ), stdout_content
+
+     finally:
+         # Ensure file descriptors are restored even if something goes wrong
+         try:
+             os.dup2(original_stdout_fd, 1)
+             os.close(original_stdout_fd)
+         except OSError:
+             pass
+
+
+ # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
+ @pytest.mark.oss_skip
+ async def test_flush_logs_fast_exit() -> None:
+     # We use a subprocess to run the test so we can handle the flushed logs at the end.
+     # Otherwise, it is hard to restore the original stdout/stderr.
+
+     test_bin = importlib.resources.files(str(__package__)).joinpath("test_bin")
+
+     # Run the binary in a separate process and capture stdout and stderr
+     cmd = [str(test_bin), "flush-logs"]
+
+     process = subprocess.run(cmd, capture_output=True, timeout=60, text=True)
+
+     # Check if the process ended without error
+     if process.returncode != 0:
+         raise RuntimeError(f"{cmd} ended with error code {process.returncode}. ")
+
+     # Assertions on the captured output, 160 = 32 procs * 5 logs per proc
+     # 32 and 5 are specified in the test_bin flush-logs.
+     assert (
+         len(
+             re.findall(
+                 r"160 similar log lines.*has print streaming",
+                 process.stdout,
+             )
+         )
+         == 1
+     ), process.stdout
+
+
+ @pytest.mark.timeout(60)
+ async def test_flush_on_disable_aggregation() -> None:
+     """Test that logs are flushed when disabling aggregation.
+
+     This tests the corner case: "Make sure we flush whatever is in the aggregators before disabling aggregation."
+     """
+     # Save original file descriptors
+     original_stdout_fd = os.dup(1)  # stdout
+
+     try:
+         # Create temporary files to capture output
+         with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+             stdout_path = stdout_file.name
+
+             # Redirect file descriptors to our temp files
+             os.dup2(stdout_file.fileno(), 1)
+
+             # Also redirect Python's sys.stdout
+             original_sys_stdout = sys.stdout
+             sys.stdout = stdout_file
+
+             try:
+                 pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                 am = await pm.spawn("printer", Printer)
+
+                 # Set a long aggregation window to ensure logs aren't flushed immediately
+                 await pm.logging_option(stream_to_client=True, aggregate_window_sec=60)
+
+                 # Generate some logs that will be aggregated but not flushed immediately
+                 for _ in range(5):
+                     await am.print.call("aggregated log line")
+                 await asyncio.sleep(1)
+
+                 # Now disable aggregation - this should trigger an immediate flush
+                 await pm.logging_option(
+                     stream_to_client=True, aggregate_window_sec=None
+                 )
+
+                 # Wait a bit to ensure logs are collected
+                 await asyncio.sleep(1)
+                 for _ in range(5):
+                     await am.print.call("single log line")
+
+                 await pm.stop()
+
+                 # Flush all outputs
+                 stdout_file.flush()
+                 os.fsync(stdout_file.fileno())
+
+             finally:
+                 # Restore Python's sys.stdout
+                 sys.stdout = original_sys_stdout
+
+                 # Restore original file descriptors
+                 os.dup2(original_stdout_fd, 1)
+
+             # Read the captured output
+             with open(stdout_path, "r") as f:
+                 stdout_content = f.read()
+
+             # Clean up temp files
+             os.unlink(stdout_path)
+
+             # Verify that logs were flushed when aggregation was disabled
+             # We should see the aggregated logs in the output
+             # 10 = 5 log lines * 2 procs
+             assert re.search(
+                 r"\[10 similar log lines\].*aggregated log line", stdout_content
+             ), stdout_content
+
+             # No aggregated single log lines
+             assert not re.search(
+                 r"similar log lines.*single log line", stdout_content
+             ), stdout_content
+
+             # 10 = 5 log lines * 2 procs
+             assert (
+                 len(re.findall(r"\[.* [0-9]+\] single log line", stdout_content)) == 10
+             ), stdout_content
+
+     finally:
+         # Ensure file descriptors are restored even if something goes wrong
+         try:
+             os.dup2(original_stdout_fd, 1)
+             os.close(original_stdout_fd)
+         except OSError:
+             pass
+
+
+ @pytest.mark.timeout(120)
+ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
+     """
+     The goal is to make sure that when a user sends multiple sync flushes, we are not deadlocked.
+     Because a flush call is now purely sync, it is very easy to get into a deadlock.
+     So we assert the last flush call will not get into such a state.
+     """
+     pm = this_host().spawn_procs(per_host={"gpus": 4})
+     am = pm.spawn("printer", Printer)
+
+     # Generate some logs that will be aggregated but not flushed immediately
+     for _ in range(10):
+         await am.print.call("aggregated log line")
+
+     log_mesh = pm._logging_manager._logging_mesh_client
+     assert log_mesh is not None
+     futures = []
+     for _ in range(5):
+         # FIXME: the order of futures doesn't necessarily match the order of flushes due to the async nature.
+         await asyncio.sleep(0.1)
+         futures.append(Future(coro=log_mesh.flush().spawn().task()))
+
+     # The last flush should not block
+     futures[-1].get()
+
+
+ @pytest.mark.timeout(60)
+ async def test_adjust_aggregation_window() -> None:
+     """Test that the flush deadline is updated when the aggregation window is adjusted.
+
+     This tests the corner case: "This can happen if the user has adjusted the aggregation window."
+     """
+     # Save original file descriptors
+     original_stdout_fd = os.dup(1)  # stdout
+
+     try:
+         # Create temporary files to capture output
+         with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+             stdout_path = stdout_file.name
+
+             # Redirect file descriptors to our temp files
+             os.dup2(stdout_file.fileno(), 1)
+
+             # Also redirect Python's sys.stdout
+             original_sys_stdout = sys.stdout
+             sys.stdout = stdout_file
+
+             try:
+                 pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                 am = await pm.spawn("printer", Printer)
+
+                 # Set a long aggregation window initially
+                 await pm.logging_option(stream_to_client=True, aggregate_window_sec=100)
+
+                 # Generate some logs that will be aggregated
+                 for _ in range(3):
+                     await am.print.call("first batch of logs")
+                 await asyncio.sleep(1)
+
+                 # Now adjust to a shorter window - this should update the flush deadline
+                 await pm.logging_option(stream_to_client=True, aggregate_window_sec=2)
+
+                 # Generate more logs
+                 for _ in range(3):
+                     await am.print.call("second batch of logs")
+
+                 await pm.stop()
+
+                 # Flush all outputs
+                 stdout_file.flush()
+                 os.fsync(stdout_file.fileno())
+
+             finally:
+                 # Restore Python's sys.stdout/stderr
+                 sys.stdout = original_sys_stdout
+
+                 # Restore original file descriptors
+                 os.dup2(original_stdout_fd, 1)
+
+             # Read the captured output
+             with open(stdout_path, "r") as f:
+                 stdout_content = f.read()
+
+             # Clean up temp files
+             os.unlink(stdout_path)
+
+             # Verify that logs were flushed when the aggregation window was adjusted
+             # We should see both batches of logs in the output
+             assert re.search(
+                 r"\[6 similar log lines\].*first batch of logs", stdout_content
+             ), stdout_content
+
+             assert re.search(
+                 r"similar log lines.*second batch of logs", stdout_content
+             ), stdout_content
+
+     finally:
+         # Ensure file descriptors are restored even if something goes wrong
+         try:
+             os.dup2(original_stdout_fd, 1)
+             os.close(original_stdout_fd)
+         except OSError:
+             pass
+
+
  class SendAlot(Actor):
      @endpoint
      async def send(self, port: Port[int]):
@@ -543,10 +1234,11 @@ class SendAlot(Actor):
              port.send(i)


- def test_port_as_argument():
-     proc_mesh = local_proc_mesh(gpus=1).get()
+ @pytest.mark.timeout(60)
+ def test_port_as_argument() -> None:
+     proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 1})
      s = proc_mesh.spawn("send_alot", SendAlot).get()
-     send, recv = PortTuple.create(proc_mesh._mailbox)
+     send, recv = Channel[int].open()

      s.send.broadcast(send)

@@ -554,14 +1246,14 @@ def test_port_as_argument():
          assert i == recv.recv().get()


- @pytest.mark.timeout(15)
+ @pytest.mark.timeout(30)
  async def test_same_actor_twice() -> None:
-     pm = await proc_mesh(gpus=1)
-     await pm.spawn("dup", Counter, 0)
+     pm = this_host().spawn_procs(per_host={"gpus": 1})
+     await pm.spawn("dup", Counter, 0).initialized

      # The second spawn with the same name should fail with a specific error
      with pytest.raises(Exception) as exc_info:
-         await pm.spawn("dup", Counter, 0)
+         await pm.spawn("dup", Counter, 0).initialized

      # Assert that the error message contains the expected text about duplicate actor name
      error_msg = str(exc_info.value)
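
test_port_as_argument above also swaps the port-pair constructor: PortTuple.create(mailbox) is gone, replaced by the typed Channel API imported from monarch._src.actor.actor_mesh. A sketch of the round trip as the test uses it (only Channel[T].open(), Port.send via the actor endpoint, and recv().get() are attested by this diff; anything else would be an assumption):

    send, recv = Channel[int].open()    # typed (Port[int], receiver) pair
    s.send.broadcast(send)              # SendAlot pushes ints into the port
    value = recv.recv().get()           # client blocks on the next value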
@@ -570,23 +1262,81 @@ async def test_same_actor_twice() -> None:
      ), f"Expected error message about duplicate actor name, got: {error_msg}"


+ class LsActor(Actor):
+     def __init__(self, workspace: str):
+         self.workspace = workspace
+
+     @endpoint
+     async def ls(self) -> list[str]:
+         return os.listdir(self.workspace)
+
+
+ async def test_sync_workspace() -> None:
+     # create two workspaces: one for local and one for remote
+     with tempfile.TemporaryDirectory() as workspace_src, tempfile.TemporaryDirectory() as workspace_dst:
+
+         def bootstrap_WORKSPACE_DIR() -> None:
+             import os
+
+             os.environ["WORKSPACE_DIR"] = workspace_dst
+
+         pm = this_host().spawn_procs(
+             per_host={"gpus": 1}, bootstrap=bootstrap_WORKSPACE_DIR
+         )
+
+         config = defaults.config("slurm", workspace_src)
+         await pm.sync_workspace(workspace=config.workspace, auto_reload=True)
+
+         # no file in remote workspace initially
+         am = await pm.spawn("ls", LsActor, workspace_dst)
+         for item in list(am.ls.call().get()):
+             assert len(item[1]) == 0
+
+         # write a file to local workspace
+         file_path = os.path.join(workspace_src, "new_file")
+         with open(file_path, "w") as f:
+             f.write("hello world")
+             f.flush()
+
+         # force a sync and it should populate on the dst workspace
+         await pm.sync_workspace(config.workspace, auto_reload=True)
+         for item in list(am.ls.call().get()):
+             assert len(item[1]) == 1
+             assert item[1][0] == "new_file"
+             file_path = os.path.join(workspace_dst, item[1][0])
+             with open(file_path, "r") as f:
+                 assert f.readline() == "hello world"
+
+     # sanity check
+     assert "WORKSPACE_DIR" not in os.environ, "test leaves env var side-effects!"
+
+
  class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
      async def test_actor_mesh_stop(self) -> None:
-         pm = await proc_mesh(gpus=2)
+         pm = this_host().spawn_procs(per_host={"gpus": 2})
          am_1 = await pm.spawn("printer", Printer)
          am_2 = await pm.spawn("printer2", Printer)
          await am_1.print.call("hello 1")
          await am_1.log.call("hello 2")
-         await cast(ActorMeshRef, am_1).stop()
+         await cast(ActorMesh, am_1).stop()

          with self.assertRaisesRegex(
-             RuntimeError, expected_regex="`ActorMesh` has been stopped"
+             RuntimeError, expected_regex="`PythonActorMesh` has already been stopped"
          ):
              await am_1.print.call("hello 1")

          await am_2.print.call("hello 3")
          await am_2.log.call("hello 4")

+         await pm.stop()
+
+     async def test_proc_mesh_stop_after_actor_mesh_stop(self) -> None:
+         pm = this_host().spawn_procs(per_host={"gpus": 2})
+         am = await pm.spawn("printer", Printer)
+
+         await cast(ActorMesh, am).stop()
+         await pm.stop()
+

  class PortedActor(Actor):
      @endpoint(explicit_response_port=True)
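
TestActorMeshStop pins down the stop() contract after this release: stopping an actor mesh makes further endpoint calls raise RuntimeError (the expected message changes from "`ActorMesh` has been stopped" to "`PythonActorMesh` has already been stopped"), and the owning proc mesh can still be stopped afterwards. A condensed sketch of that contract, using only calls shown in the test class above:

    pm = this_host().spawn_procs(per_host={"gpus": 2})
    am = await pm.spawn("printer", Printer)

    await cast(ActorMesh, am).stop()
    try:
        await am.print.call("after stop")    # expected to fail
    except RuntimeError as e:
        assert "`PythonActorMesh` has already been stopped" in str(e)

    await pm.stop()    # still legal after the actor mesh has been stopped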
@@ -594,7 +1344,161 @@ class PortedActor(Actor):
          port.send(3 + b)


+ @pytest.mark.timeout(60)
  def test_ported_actor():
-     proc_mesh = local_proc_mesh(gpus=1).get()
+     proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 1}).get()
      a = proc_mesh.spawn("port_actor", PortedActor).get()
      assert 5 == a.add.call_one(2).get()
+
+
+ async def _recv():
+     return (7, 2, 3)
+
+
+ async def consume():
+     r = await PythonTask.from_coroutine(_recv())
+     assert r == (7, 2, 3)
+
+
+ @pytest.mark.timeout(60)
+ def test_python_task_tuple() -> None:
+     PythonTask.from_coroutine(consume()).block_on()
+
+
+ def test_select_result() -> None:
+     def s(t):
+         time.sleep(t)
+         return t
+
+     a = PythonTask.spawn_blocking(lambda: s(4))
+     b = PythonTask.spawn_blocking(lambda: s(0))
+     r = PythonTask.select_one([a.task(), b.task()]).block_on()
+     assert r == (0, 1)
+
+
+ def test_mesh_len():
+     proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 12})
+     s = proc_mesh.spawn("sync_actor", SyncActor).get()
+     assert 12 == len(s)
+
+
+ class UndeliverableMessageReceiver(Actor):
+     def __init__(self):
+         self._messages = asyncio.Queue()
+
+     @endpoint
+     async def receive_undeliverable(
+         self, sender: ActorId, dest: PortId, error_msg: str
+     ) -> None:
+         await self._messages.put((sender, dest, error_msg))
+
+     @endpoint
+     async def get_messages(self) -> Tuple[ActorId, PortId, str]:
+         return await self._messages.get()
+
+
+ class UndeliverableMessageSender(Actor):
+     @endpoint
+     def send_undeliverable(self) -> None:
+         mailbox = context().actor_instance._mailbox
+         port_id = PortId(
+             actor_id=ActorId(
+                 world_name=mailbox.actor_id.world_name, rank=0, actor_name="bogus"
+             ),
+             port=1234,
+         )
+         port_ref = PortRef(port_id)
+         port_ref.send(
+             mailbox,
+             PythonMessage(PythonMessageKind.Result(None), b"123"),
+         )
+
+
+ class UndeliverableMessageSenderWithOverride(UndeliverableMessageSender):
+     def __init__(self, receiver: UndeliverableMessageReceiver):
+         self._receiver = receiver
+
+     def _handle_undeliverable_message(
+         self, message: UndeliverableMessageEnvelope
+     ) -> bool:
+         self._receiver.receive_undeliverable.call_one(
+             message.sender(), message.dest(), message.error_msg()
+         ).get()
+         return True
+
+
+ @pytest.mark.timeout(60)
+ async def test_undeliverable_message_with_override() -> None:
+     pm = this_host().spawn_procs(per_host={"gpus": 1})
+     receiver = pm.spawn("undeliverable_receiver", UndeliverableMessageReceiver)
+     sender = pm.spawn(
+         "undeliverable_sender", UndeliverableMessageSenderWithOverride, receiver
+     )
+     sender.send_undeliverable.call().get()
+     sender, dest, error_msg = receiver.get_messages.call_one().get()
+     assert sender.actor_name == "undeliverable_sender"
+     assert dest.actor_id.actor_name == "bogus"
+     assert error_msg is not None
+     pm.stop().get()
+
+
+ @pytest.mark.timeout(60)
+ async def test_undeliverable_message_without_override() -> None:
+     pm = this_host().spawn_procs(per_host={"gpus": 1})
+     sender = pm.spawn("undeliverable_sender", UndeliverableMessageSender)
+     sender.send_undeliverable.call().get()
+     # Wait a few seconds to ensure that the undeliverable message is processed
+     # without crashing anything
+     await asyncio.sleep(5)
+     pm.stop().get()
+
+
+ def test_this_and_that():
+     counter = this_proc().spawn("counter", Counter, 7)
+     assert 7 == counter.value.call_one().get()
+
+
+ class ReceptorActor(Actor):
+     @endpoint
+     def status(self):
+         return 1
+
+
+ async def test_things_survive_losing_python_reference() -> None:
+     """Test that an actor mesh remains usable after the Python references to its host and proc meshes are dropped."""
+
+     receptor = (
+         this_host()
+         .spawn_procs(per_host={"gpus": 1})
+         .spawn(
+             "receptor",
+             ReceptorActor,
+         )
+     )
+     receptor = receptor.slice(gpus=0)
+
+     await receptor.status.call()
+
+
+ class IsInit(Actor):
+     @endpoint
+     def is_cuda_initialized(self) -> bool:
+         cuda = ctypes.CDLL("libcuda.so.1")
+         CUresult = ctypes.c_int
+         cuDeviceGetCount = cuda.cuDeviceGetCount
+         cuDeviceGetCount.argtypes = [ctypes.POINTER(ctypes.c_int)]
+         cuDeviceGetCount.restype = CUresult
+         count = ctypes.c_int()
+         result = cuDeviceGetCount(ctypes.byref(count))
+         CUDA_ERROR_NOT_INITIALIZED = 3
+         return result == CUDA_ERROR_NOT_INITIALIZED
+
+
+ @pytest.mark.oss_skip
+ def test_cuda_is_not_initialized_in_a_new_proc():
+     try:
+         ctypes.CDLL("libcuda.so.1")
+     except OSError:
+         pytest.skip("cannot find cuda")
+     proc = this_host().spawn_procs().spawn("is_init", IsInit)
+     assert not proc.is_cuda_initialized.call_one().get()
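
A recurring addition in this file is the per-actor _handle_undeliverable_message hook: both Printer and UndeliverableMessageSenderWithOverride override it, and returning True marks the undeliverable message as handled so no supervision error is raised. A minimal sketch of the hook as these tests exercise it (the hook name and the envelope accessors sender(), dest(), and error_msg() come from this diff; everything else here is illustrative):

    class QuietActor(Actor):
        def _handle_undeliverable_message(
            self, message: UndeliverableMessageEnvelope
        ) -> bool:
            # Log and swallow instead of surfacing a supervision error.
            logging.getLogger().error(
                f"undeliverable from {message.sender()} to {message.dest()}: "
                f"{message.error_msg()}"
            )
            return True  # True = handled; the runtime will not raise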