torchmonarch-nightly 2025.6.30__cp312-cp312-manylinux2014_x86_64.whl → 2025.7.25__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +874 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +270 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +500 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +56 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +51 -0
  23. monarch/actor_mesh.py +6 -752
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +12 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/mesh_controller.py +201 -139
  37. monarch/monarch_controller +0 -0
  38. monarch/opaque_module.py +4 -6
  39. monarch/opaque_object.py +3 -3
  40. monarch/proc_mesh.py +6 -309
  41. monarch/python_local_mesh.py +1 -1
  42. monarch/rust_backend_mesh.py +2 -1
  43. monarch/rust_local_mesh.py +4 -2
  44. monarch/sim_mesh.py +10 -19
  45. monarch/simulator/command_history.py +1 -1
  46. monarch/simulator/interface.py +2 -1
  47. monarch/simulator/mock_controller.py +1 -1
  48. monarch/simulator/simulator.py +1 -1
  49. monarch/tensor_engine/__init__.py +23 -0
  50. monarch/tensor_worker_main.py +3 -1
  51. monarch/tools/cli.py +3 -1
  52. monarch/tools/commands.py +95 -35
  53. monarch/tools/mesh_spec.py +55 -0
  54. monarch/tools/utils.py +38 -0
  55. monarch/worker/worker.py +1 -1
  56. monarch/world_mesh.py +2 -1
  57. monarch_supervisor/python_executable.py +6 -3
  58. tests/error_test_binary.py +75 -9
  59. tests/test_actor_error.py +370 -21
  60. tests/test_alloc.py +1 -1
  61. tests/test_allocator.py +373 -17
  62. tests/test_controller.py +2 -0
  63. tests/test_debugger.py +416 -0
  64. tests/test_env_before_cuda.py +162 -0
  65. tests/test_python_actors.py +184 -332
  66. tests/test_rdma.py +198 -0
  67. tests/test_remote_functions.py +40 -12
  68. tests/test_rust_backend.py +7 -5
  69. tests/test_sim_backend.py +1 -4
  70. tests/test_tensor_engine.py +55 -1
  71. {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/METADATA +6 -1
  72. {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/RECORD +80 -68
  73. torchmonarch_nightly-2025.7.25.dist-info/entry_points.txt +3 -0
  74. monarch/_monarch/hyperactor/__init__.py +0 -58
  75. monarch/_monarch/worker/debugger.py +0 -117
  76. monarch/_monarch/worker/logging.py +0 -107
  77. monarch/debugger.py +0 -379
  78. monarch/future.py +0 -76
  79. monarch/rdma.py +0 -162
  80. torchmonarch_nightly-2025.6.30.dist-info/entry_points.txt +0 -3
  81. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  82. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  83. /monarch/{common → _src/actor}/shape.py +0 -0
  84. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  85. {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/WHEEL +0 -0
  86. {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/licenses/LICENSE +0 -0
  87. {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/top_level.txt +0 -0
@@ -4,34 +4,39 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-unsafe
7
8
  import asyncio
9
+ import logging
8
10
  import operator
9
- import re
11
+ import os
12
+ import sys
13
+ import tempfile
10
14
  import threading
11
15
  import time
16
+ import unittest
17
+ from logging import INFO
12
18
  from types import ModuleType
13
- from unittest.mock import AsyncMock, patch
14
-
15
- import monarch
19
+ from typing import cast
16
20
 
17
21
  import pytest
18
22
 
19
23
  import torch
20
24
 
21
- from monarch.actor_mesh import (
25
+ from monarch._src.actor.actor_mesh import ActorMeshRef, Port, PortTuple
26
+
27
+ from monarch.actor import (
22
28
  Accumulator,
23
29
  Actor,
24
30
  current_actor_name,
25
31
  current_rank,
26
32
  current_size,
27
33
  endpoint,
28
- MonarchContext,
34
+ Future,
35
+ local_proc_mesh,
36
+ proc_mesh,
29
37
  )
30
- from monarch.debugger import init_debugging
31
- from monarch.future import ActorFuture
38
+ from typing_extensions import assert_type
32
39
 
33
- from monarch.proc_mesh import local_proc_mesh, proc_mesh
34
- from monarch.rdma import RDMABuffer
35
40
 
36
41
  needs_cuda = pytest.mark.skipif(
37
42
  not torch.cuda.is_available(),
@@ -51,6 +56,10 @@ class Counter(Actor):
51
56
  async def value(self) -> int:
52
57
  return self.v
53
58
 
59
+ @endpoint
60
+ def value_sync_endpoint(self) -> int:
61
+ return self.v
62
+
54
63
 
55
64
  class Indirect(Actor):
56
65
  @endpoint
@@ -58,36 +67,23 @@ class Indirect(Actor):
58
67
  return await c.value.choose()
59
68
 
60
69
 
61
- class ParameterServer(Actor):
62
- def __init__(self):
63
- self.params = torch.rand(10, 10)
64
- self.grad_buffer = torch.rand(10, 10)
65
-
66
- @endpoint
67
- async def grad_handle(self) -> RDMABuffer:
68
- byte_tensor = self.grad_buffer.view(torch.uint8).flatten()
69
- return RDMABuffer(byte_tensor)
70
-
71
- @endpoint
72
- async def update(self):
73
- self.params += 0.01 * self.grad_buffer
74
-
75
- @endpoint
76
- async def get_grad_buffer(self) -> torch.Tensor:
77
- # just used for testing
78
- return self.grad_buffer
79
-
80
-
81
70
  async def test_choose():
82
71
  proc = await local_proc_mesh(gpus=2)
83
72
  v = await proc.spawn("counter", Counter, 3)
84
73
  i = await proc.spawn("indirect", Indirect)
85
74
  v.incr.broadcast()
86
75
  result = await v.value.choose()
76
+
77
+ # Test that Pyre derives the correct type for result (int, not Any)
78
+ assert_type(result, int)
87
79
  result2 = await i.call_value.choose(v)
88
80
 
89
81
  assert result == result2
90
82
 
83
+ result3 = await v.value_sync_endpoint.choose()
84
+ assert_type(result, int)
85
+ assert result2 == result3
86
+
91
87
 
92
88
  async def test_stream():
93
89
  proc = await local_proc_mesh(gpus=2)
@@ -97,78 +93,6 @@ async def test_stream():
97
93
  assert 8 == sum([x async for x in v.value.stream()])
98
94
 
99
95
 
100
- class ParameterClient(Actor):
101
- def __init__(self, server, buffer):
102
- self.server = server
103
- byte_tensor = buffer.view(torch.uint8).flatten()
104
- self.buffer = byte_tensor
105
-
106
- @endpoint
107
- async def upload(self, tensor):
108
- gh = await self.server.grad_handle.call_one()
109
- await gh.write(tensor)
110
-
111
- @endpoint
112
- async def download(self):
113
- gh = await self.server.grad_handle.call_one()
114
- await gh.read_into(self.buffer)
115
-
116
- @endpoint
117
- async def get_buffer(self):
118
- return self.buffer
119
-
120
-
121
- @needs_cuda
122
- async def test_proc_mesh_rdma():
123
- proc = await proc_mesh(gpus=1)
124
- server = await proc.spawn("server", ParameterServer)
125
-
126
- # --- CPU TESTS ---
127
- client_cpu = await proc.spawn(
128
- "client_cpu", ParameterClient, server, torch.ones(10, 10)
129
- )
130
- x = await client_cpu.get_buffer.call_one()
131
- assert torch.sum(x.view(torch.float32).view(10, 10)) == 100
132
- zeros = torch.zeros(10, 10)
133
- await client_cpu.upload.call_one(zeros.view(torch.uint8).flatten())
134
- await client_cpu.download.call_one()
135
- x = await client_cpu.get_buffer.call_one()
136
- assert torch.sum(x.view(torch.float32).view(10, 10)) == 0
137
-
138
- # --- Modify server's backing buffer directly ---
139
- await server.update.call_one()
140
-
141
- # Should reflect updated values
142
- await client_cpu.download.call_one()
143
-
144
- buffer = await client_cpu.get_buffer.call_one()
145
- remote_grad = await server.get_grad_buffer.call_one()
146
- assert torch.allclose(buffer.view(torch.float32).view(10, 10), remote_grad)
147
-
148
- # --- GPU TESTS ---
149
- client_gpu = await proc.spawn(
150
- "client_gpu", ParameterClient, server, torch.ones(10, 10, device="cuda")
151
- )
152
- x = await client_gpu.get_buffer.call_one()
153
- buffer = x.view(torch.float32).view(10, 10)
154
- assert torch.sum(buffer) == 100
155
- zeros = torch.zeros(10, 10, device="cuda")
156
- await client_gpu.upload.call_one(zeros.view(torch.uint8).flatten())
157
- await client_gpu.download.call_one()
158
- x = await client_gpu.get_buffer.call_one()
159
- buffer_gpu = x.view(torch.float32).view(10, 10)
160
- assert torch.sum(buffer_gpu) == 0
161
- assert buffer_gpu.device.type == "cuda"
162
-
163
- # Modify server state again
164
- await server.update.call_one()
165
- await client_gpu.download.call_one()
166
- x = await client_gpu.get_buffer.call_one()
167
- buffer_gpu = x.view(torch.float32).view(10, 10)
168
- remote_grad = await server.get_grad_buffer.call_one()
169
- assert torch.allclose(buffer_gpu.cpu(), remote_grad)
170
-
171
-
172
96
  class To(Actor):
173
97
  @endpoint
174
98
  async def whoami(self):
@@ -240,69 +164,6 @@ async def test_rank_size():
240
164
  assert 4 == await acc.accumulate(lambda: current_size()["gpus"])
241
165
 
242
166
 
243
- class TrainerActor(Actor):
244
- def __init__(self):
245
- super().__init__()
246
- self.trainer = torch.nn.Linear(10, 10).to("cuda")
247
- self.trainer.weight.data.zero_()
248
-
249
- @endpoint
250
- async def init(self, gen):
251
- ranks = current_rank()
252
- self.gen = gen.slice(**ranks)
253
-
254
- @endpoint
255
- async def exchange_metadata(self):
256
- byte_tensor = self.trainer.weight.data.view(torch.uint8).flatten()
257
- self.handle = RDMABuffer(byte_tensor)
258
- await self.gen.attach_weight_buffer.call(self.handle)
259
-
260
- @endpoint
261
- async def weights_ready(self):
262
- self.trainer.weight.data.add_(1.0)
263
-
264
-
265
- class GeneratorActor(Actor):
266
- def __init__(self):
267
- super().__init__()
268
- self.generator = torch.nn.Linear(10, 10).to("cuda")
269
- self.step = 0
270
-
271
- @endpoint
272
- async def init(self, trainer):
273
- ranks = current_rank()
274
- self.trainer = trainer.slice(**ranks)
275
-
276
- @endpoint
277
- async def attach_weight_buffer(self, handle):
278
- self.handle = handle
279
-
280
- @endpoint
281
- async def update_weights(self):
282
- self.step += 1
283
- byte_tensor = self.generator.weight.data.view(torch.uint8).flatten()
284
- await self.handle.read_into(byte_tensor)
285
- assert (
286
- torch.sum(self.generator.weight.data) == self.step * 100
287
- ), f"{torch.sum(self.generator.weight.data)=}, {self.step=}"
288
-
289
-
290
- @needs_cuda
291
- async def test_gpu_trainer_generator():
292
- trainer_proc = await proc_mesh(gpus=1)
293
- gen_proc = await proc_mesh(gpus=1)
294
- trainer = await trainer_proc.spawn("trainer", TrainerActor)
295
- generator = await gen_proc.spawn("gen", GeneratorActor)
296
-
297
- await generator.init.call(trainer)
298
- await trainer.init.call(generator)
299
- await trainer.exchange_metadata.call()
300
-
301
- for _ in range(3):
302
- await trainer.weights_ready.call()
303
- await generator.update_weights.call()
304
-
305
-
306
167
  class SyncActor(Actor):
307
168
  @endpoint
308
169
  def sync_endpoint(self, a_counter: Counter):
@@ -317,22 +178,6 @@ async def test_sync_actor():
317
178
  assert r == 5
318
179
 
319
180
 
320
- @needs_cuda
321
- def test_gpu_trainer_generator_sync() -> None:
322
- trainer_proc = proc_mesh(gpus=1).get()
323
- gen_proc = proc_mesh(gpus=1).get()
324
- trainer = trainer_proc.spawn("trainer", TrainerActor).get()
325
- generator = gen_proc.spawn("gen", GeneratorActor).get()
326
-
327
- generator.init.call(trainer).get()
328
- trainer.init.call(generator).get()
329
- trainer.exchange_metadata.call().get()
330
-
331
- for _ in range(3):
332
- trainer.weights_ready.call().get()
333
- generator.update_weights.call().get()
334
-
335
-
336
181
  def test_sync_actor_sync_client():
337
182
  proc = local_proc_mesh(gpus=2).get()
338
183
  a = proc.spawn("actor", SyncActor).get()
@@ -408,146 +253,6 @@ def test_proc_mesh_liveness() -> None:
408
253
  counter.value.call().get()
409
254
 
410
255
 
411
- def _debugee_actor_internal(rank):
412
- if rank == 0:
413
- breakpoint() # noqa
414
- rank += 1
415
- return rank
416
- elif rank == 1:
417
- breakpoint() # noqa
418
- rank += 2
419
- return rank
420
- elif rank == 2:
421
- breakpoint() # noqa
422
- rank += 3
423
- raise ValueError("bad rank")
424
- elif rank == 3:
425
- breakpoint() # noqa
426
- rank += 4
427
- return rank
428
-
429
-
430
- class DebugeeActor(Actor):
431
- @endpoint
432
- async def to_debug(self):
433
- rank = MonarchContext.get().point.rank
434
- return _debugee_actor_internal(rank)
435
-
436
-
437
- async def test_debug() -> None:
438
- input_mock = AsyncMock()
439
- input_mock.side_effect = [
440
- "attach 1",
441
- "n",
442
- "n",
443
- "n",
444
- "n",
445
- "detach",
446
- "attach 1",
447
- "detach",
448
- "quit",
449
- "cast 0,3 n",
450
- "cast 0,3 n",
451
- # Attaching to 0 and 3 ensures that when we call "list"
452
- # the next time, their function/lineno info will be
453
- # up-to-date.
454
- "attach 0",
455
- "detach",
456
- "attach 3",
457
- "detach",
458
- "quit",
459
- "attach 2",
460
- "c",
461
- "quit",
462
- "continue",
463
- ]
464
-
465
- outputs = []
466
-
467
- def _patch_output(msg):
468
- nonlocal outputs
469
- outputs.append(msg)
470
-
471
- with patch("monarch.debugger._debugger_input", side_effect=input_mock), patch(
472
- "monarch.debugger._debugger_output", new=_patch_output
473
- ):
474
- proc = await proc_mesh(hosts=2, gpus=2)
475
- debugee = await proc.spawn("debugee", DebugeeActor)
476
- debug_client = await init_debugging(debugee)
477
-
478
- fut = debugee.to_debug.call()
479
- await debug_client.wait_pending_session.call_one()
480
- breakpoints = []
481
- for i in range(10):
482
- breakpoints = await debug_client.list.call_one()
483
- if len(breakpoints) == 4:
484
- break
485
- await asyncio.sleep(1)
486
- if i == 9:
487
- raise RuntimeError("timed out waiting for breakpoints")
488
-
489
- initial_linenos = {}
490
- for i in range(len(breakpoints)):
491
- rank, coords, _, _, function, lineno = breakpoints[i]
492
- initial_linenos[rank] = lineno
493
- assert rank == i
494
- assert coords == {"hosts": rank % 2, "gpus": rank // 2}
495
- assert function == "test_python_actors._debugee_actor_internal"
496
- assert lineno == breakpoints[0][5] + 4 * rank
497
-
498
- await debug_client.enter.call_one()
499
-
500
- # Check that when detaching and re-attaching to a session, the last portion of the output is repeated
501
- expected_last_output = [
502
- r"--Return--",
503
- r"\n",
504
- r"> (/.*/)+test_python_actors.py\(\d+\)to_debug\(\)->3\n-> return _debugee_actor_internal\(rank\)",
505
- r"\n",
506
- r"\(Pdb\) ",
507
- ]
508
- output_len = len(expected_last_output)
509
- assert outputs[-2 * output_len : -output_len] == outputs[-output_len:]
510
- for real_output, expected_output in zip(
511
- outputs[-output_len:], expected_last_output
512
- ):
513
- assert re.match(expected_output, real_output) is not None
514
-
515
- breakpoints = await debug_client.list.call_one()
516
- for i in range(len(breakpoints)):
517
- if i == 1:
518
- assert breakpoints[i][4] == "test_python_actors.to_debug"
519
- else:
520
- assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal"
521
- assert breakpoints[i][5] == initial_linenos[i]
522
-
523
- await debug_client.enter.call_one()
524
-
525
- breakpoints = await debug_client.list.call_one()
526
- for i in range(len(breakpoints)):
527
- if i == 1:
528
- assert breakpoints[i][4] == "test_python_actors.to_debug"
529
- elif i in (0, 3):
530
- assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal"
531
- assert breakpoints[i][5] == initial_linenos[i] + 2
532
- else:
533
- assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal"
534
- assert breakpoints[i][5] == initial_linenos[i]
535
-
536
- await debug_client.enter.call_one()
537
-
538
- breakpoints = await debug_client.list.call_one()
539
- assert len(breakpoints) == 3
540
- for i, rank in enumerate((0, 1, 3)):
541
- assert breakpoints[i][0] == rank
542
-
543
- await debug_client.enter.call_one()
544
- breakpoints = await debug_client.list.call_one()
545
- assert len(breakpoints) == 0
546
-
547
- with pytest.raises(monarch.actor_mesh.ActorError, match="ValueError: bad rank"):
548
- await fut
549
-
550
-
551
256
  class TLSActor(Actor):
552
257
  """An actor that manages thread-local state."""
553
258
 
@@ -643,7 +348,7 @@ async def awaitit(f):
643
348
  return await f
644
349
 
645
350
 
646
- def test_actor_future():
351
+ def test_actor_future() -> None:
647
352
  v = 0
648
353
 
649
354
  async def incr():
@@ -653,32 +358,31 @@ def test_actor_future():
653
358
 
654
359
  # can use async implementation from sync
655
360
  # if no non-blocking is provided
656
- f = ActorFuture(incr)
361
+ f = Future(impl=incr, requires_loop=False)
657
362
  assert f.get() == 1
658
363
  assert v == 1
659
364
  assert f.get() == 1
660
365
  assert asyncio.run(awaitit(f)) == 1
661
366
 
662
- f = ActorFuture(incr)
367
+ f = Future(impl=incr, requires_loop=False)
663
368
  assert asyncio.run(awaitit(f)) == 2
664
369
  assert f.get() == 2
665
370
 
666
- def incr2():
371
+ async def incr2():
667
372
  nonlocal v
668
373
  v += 2
669
374
  return v
670
375
 
671
376
  # Use non-blocking optimization if provided
672
- f = ActorFuture(incr, incr2)
377
+ f = Future(impl=incr2)
673
378
  assert f.get() == 4
674
- assert asyncio.run(awaitit(f)) == 4
675
379
 
676
380
  async def nope():
677
381
  nonlocal v
678
382
  v += 1
679
383
  raise ValueError("nope")
680
384
 
681
- f = ActorFuture(nope)
385
+ f = Future(impl=nope, requires_loop=False)
682
386
 
683
387
  with pytest.raises(ValueError):
684
388
  f.get()
@@ -695,12 +399,12 @@ def test_actor_future():
695
399
 
696
400
  assert v == 5
697
401
 
698
- def nope():
402
+ async def nope2():
699
403
  nonlocal v
700
404
  v += 1
701
405
  raise ValueError("nope")
702
406
 
703
- f = ActorFuture(incr, nope)
407
+ f = Future(impl=nope2)
704
408
 
705
409
  with pytest.raises(ValueError):
706
410
  f.get()
@@ -722,7 +426,7 @@ def test_actor_future():
722
426
  async def seven():
723
427
  return 7
724
428
 
725
- f = ActorFuture(seven)
429
+ f = Future(impl=seven, requires_loop=False)
726
430
 
727
431
  assert 7 == f.get(timeout=0.001)
728
432
 
@@ -730,7 +434,155 @@ def test_actor_future():
730
434
  f = asyncio.Future()
731
435
  await f
732
436
 
733
- f = ActorFuture(neverfinish)
437
+ f = Future(impl=neverfinish, requires_loop=True)
734
438
 
735
439
  with pytest.raises(asyncio.exceptions.TimeoutError):
736
440
  f.get(timeout=0.1)
441
+
442
+
443
+ class Printer(Actor):
444
+ def __init__(self):
445
+ self.logger = logging.getLogger()
446
+ self.logger.setLevel(INFO)
447
+
448
+ @endpoint
449
+ async def print(self, content: str):
450
+ print(f"{os.getpid()} {content}")
451
+
452
+ @endpoint
453
+ async def log(self, content: str):
454
+ self.logger.info(f"{os.getpid()} {content}")
455
+
456
+
457
+ async def test_actor_log_streaming() -> None:
458
+ # Save original file descriptors
459
+ original_stdout_fd = os.dup(1) # stdout
460
+ original_stderr_fd = os.dup(2) # stderr
461
+
462
+ try:
463
+ # Create temporary files to capture output
464
+ with tempfile.NamedTemporaryFile(
465
+ mode="w+", delete=False
466
+ ) as stdout_file, tempfile.NamedTemporaryFile(
467
+ mode="w+", delete=False
468
+ ) as stderr_file:
469
+ stdout_path = stdout_file.name
470
+ stderr_path = stderr_file.name
471
+
472
+ # Redirect file descriptors to our temp files
473
+ # This will capture both Python and Rust output
474
+ os.dup2(stdout_file.fileno(), 1)
475
+ os.dup2(stderr_file.fileno(), 2)
476
+
477
+ # Also redirect Python's sys.stdout/stderr for completeness
478
+ original_sys_stdout = sys.stdout
479
+ original_sys_stderr = sys.stderr
480
+ sys.stdout = stdout_file
481
+ sys.stderr = stderr_file
482
+
483
+ try:
484
+ pm = await proc_mesh(gpus=2)
485
+ am = await pm.spawn("printer", Printer)
486
+
487
+ await am.print.call("hello 1")
488
+ await am.log.call("hello 2")
489
+
490
+ await pm.logging_option(stream_to_client=True)
491
+
492
+ await am.print.call("hello 3")
493
+ await am.log.call("hello 4")
494
+
495
+ # Give it sometime to send log back
496
+ time.sleep(5)
497
+
498
+ # Flush all outputs
499
+ stdout_file.flush()
500
+ stderr_file.flush()
501
+ os.fsync(stdout_file.fileno())
502
+ os.fsync(stderr_file.fileno())
503
+
504
+ finally:
505
+ # Restore Python's sys.stdout/stderr
506
+ sys.stdout = original_sys_stdout
507
+ sys.stderr = original_sys_stderr
508
+
509
+ # Restore original file descriptors
510
+ os.dup2(original_stdout_fd, 1)
511
+ os.dup2(original_stderr_fd, 2)
512
+
513
+ # Read the captured output
514
+ with open(stdout_path, "r") as f:
515
+ stdout_content = f.read()
516
+
517
+ # Clean up temp files
518
+ os.unlink(stdout_path)
519
+ os.unlink(stderr_path)
520
+
521
+ # TODO: (@jamessun) we need to disable logging forwarder for python logger
522
+ # assert "hello 1" not in stdout_content
523
+ assert "hello 2" not in stdout_content
524
+
525
+ assert "hello 3" in stdout_content
526
+ # assert "hello 4" in stdout_content
527
+
528
+ finally:
529
+ # Ensure file descriptors are restored even if something goes wrong
530
+ try:
531
+ os.dup2(original_stdout_fd, 1)
532
+ os.dup2(original_stderr_fd, 2)
533
+ os.close(original_stdout_fd)
534
+ os.close(original_stderr_fd)
535
+ except OSError:
536
+ pass
537
+
538
+
539
+ class SendAlot(Actor):
540
+ @endpoint
541
+ async def send(self, port: Port[int]):
542
+ for i in range(100):
543
+ port.send(i)
544
+
545
+
546
+ def test_port_as_argument():
547
+ proc_mesh = local_proc_mesh(gpus=1).get()
548
+ s = proc_mesh.spawn("send_alot", SendAlot).get()
549
+ send, recv = PortTuple.create(proc_mesh._mailbox)
550
+
551
+ s.send.broadcast(send)
552
+
553
+ for i in range(100):
554
+ assert i == recv.recv().get()
555
+
556
+
557
+ @pytest.mark.timeout(15)
558
+ async def test_same_actor_twice() -> None:
559
+ pm = await proc_mesh(gpus=1)
560
+ await pm.spawn("dup", Counter, 0)
561
+
562
+ # The second spawn with the same name should fail with a specific error
563
+ with pytest.raises(Exception) as exc_info:
564
+ await pm.spawn("dup", Counter, 0)
565
+
566
+ # Assert that the error message contains the expected text about duplicate actor name
567
+ error_msg = str(exc_info.value)
568
+ assert (
569
+ "gspawn failed: an actor with name 'dup' has already been spawned" in error_msg
570
+ ), f"Expected error message about duplicate actor name, got: {error_msg}"
571
+
572
+
573
+ class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
574
+ async def test_actor_mesh_stop(self) -> None:
575
+ pm = await proc_mesh(gpus=2)
576
+ am_1 = await pm.spawn("printer", Printer)
577
+ am_2 = await pm.spawn("printer2", Printer)
578
+ await am_1.print.call("hello 1")
579
+ await am_1.log.call("hello 2")
580
+ await cast(ActorMeshRef, am_1).stop()
581
+
582
+ with self.assertRaisesRegex(
583
+ RuntimeError, expected_regex="`ActorMesh` has been stopped"
584
+ ):
585
+ await am_1.print.call("hello 1")
586
+
587
+ await am_2.print.call("hello 3")
588
+ await am_2.log.call("hello 4")