torchmonarch-nightly 2025.8.1-cp310-cp310-manylinux2014_x86_64.whl → 2025.9.3-cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/_rust_bindings.so +0 -0
- monarch/_src/actor/actor_mesh.py +414 -216
- monarch/_src/actor/allocator.py +75 -6
- monarch/_src/actor/bootstrap_main.py +7 -4
- monarch/_src/actor/code_sync/__init__.py +2 -0
- monarch/_src/actor/debugger/__init__.py +7 -0
- monarch/_src/actor/{debugger.py → debugger/debugger.py} +246 -135
- monarch/_src/actor/{pdb_wrapper.py → debugger/pdb_wrapper.py} +62 -23
- monarch/_src/actor/endpoint.py +27 -45
- monarch/_src/actor/future.py +86 -24
- monarch/_src/actor/host_mesh.py +125 -0
- monarch/_src/actor/logging.py +94 -0
- monarch/_src/actor/pickle.py +25 -0
- monarch/_src/actor/proc_mesh.py +423 -156
- monarch/_src/actor/python_extension_methods.py +90 -0
- monarch/_src/actor/shape.py +8 -1
- monarch/_src/actor/source_loader.py +45 -0
- monarch/_src/actor/telemetry/__init__.py +172 -0
- monarch/_src/actor/telemetry/rust_span_tracing.py +6 -39
- monarch/_src/debug_cli/__init__.py +7 -0
- monarch/_src/debug_cli/debug_cli.py +43 -0
- monarch/_src/tensor_engine/rdma.py +64 -9
- monarch/_testing.py +1 -3
- monarch/actor/__init__.py +24 -4
- monarch/common/_C.so +0 -0
- monarch/common/device_mesh.py +14 -0
- monarch/common/future.py +10 -0
- monarch/common/remote.py +14 -25
- monarch/common/tensor.py +12 -0
- monarch/debug_cli/__init__.py +7 -0
- monarch/debug_cli/__main__.py +12 -0
- monarch/fetch.py +2 -2
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +4 -2
- monarch/mesh_controller.py +34 -14
- monarch/monarch_controller +0 -0
- monarch/tools/colors.py +25 -0
- monarch/tools/commands.py +42 -7
- monarch/tools/components/hyperactor.py +1 -1
- monarch/tools/config/__init__.py +31 -4
- monarch/tools/config/defaults.py +13 -3
- monarch/tools/config/environment.py +45 -0
- monarch/tools/config/workspace.py +165 -0
- monarch/tools/mesh_spec.py +2 -0
- monarch/utils/__init__.py +9 -0
- monarch/utils/utils.py +78 -0
- tests/error_test_binary.py +5 -3
- tests/python_actor_test_binary.py +52 -0
- tests/test_actor_error.py +142 -14
- tests/test_alloc.py +1 -1
- tests/test_allocator.py +59 -72
- tests/test_coalescing.py +1 -1
- tests/test_debugger.py +639 -45
- tests/test_env_before_cuda.py +4 -4
- tests/test_mesh_trait.py +38 -0
- tests/test_python_actors.py +979 -75
- tests/test_rdma.py +7 -6
- tests/test_tensor_engine.py +6 -6
- {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/METADATA +82 -4
- {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/RECORD +64 -48
- {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/WHEEL +0 -0
- {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/entry_points.txt +0 -0
- {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/licenses/LICENSE +0 -0
- {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/top_level.txt +0 -0
tests/test_python_actors.py
CHANGED
@@ -5,24 +5,45 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
+
 import asyncio
+import ctypes
+import gc
+import importlib.resources
 import logging
 import operator
 import os
+import re
+import subprocess
 import sys
 import tempfile
 import threading
 import time
 import unittest
-
+import unittest.mock
 from types import ModuleType
-from typing import cast
+from typing import cast, Tuple
 
 import pytest
 
 import torch
+from monarch._rust_bindings.monarch_hyperactor.actor import (
+    PythonMessage,
+    PythonMessageKind,
+)
+from monarch._rust_bindings.monarch_hyperactor.mailbox import (
+    PortId,
+    PortRef,
+    UndeliverableMessageEnvelope,
+)
+from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
+from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
 
-from monarch._src.actor.actor_mesh import
+from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
+from monarch._src.actor.allocator import AllocHandle
+from monarch._src.actor.future import Future
+from monarch._src.actor.host_mesh import create_local_host_mesh, fake_in_process_host
+from monarch._src.actor.proc_mesh import ProcMesh
 
 from monarch.actor import (
     Accumulator,
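Note: the `proc_mesh` helper removed from the `monarch.actor` imports is replaced throughout these tests by the host-mesh API (`this_host`, `this_proc`, `fake_in_process_host`). A minimal sketch of the new spawn idiom, assuming the API exactly as exercised by the tests below (the `Echo` actor is a hypothetical example):

    from monarch.actor import Actor, endpoint, this_host

    class Echo(Actor):
        @endpoint
        async def echo(self, s: str) -> str:
            return s

    procs = this_host().spawn_procs(per_host={"gpus": 1})
    echo = procs.spawn("echo", Echo)
    assert echo.echo.call_one("hi").get() == "hi"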
@@ -31,10 +52,10 @@ from monarch.actor import (
     current_rank,
     current_size,
     endpoint,
-
-
-    proc_mesh,
+    this_host,
+    this_proc,
 )
+from monarch.tools.config import defaults
 from typing_extensions import assert_type
 
 
@@ -67,8 +88,9 @@ class Indirect(Actor):
         return await c.value.choose()
 
 
+@pytest.mark.timeout(60)
 async def test_choose():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     v = await proc.spawn("counter", Counter, 3)
     i = await proc.spawn("indirect", Indirect)
     v.incr.broadcast()
@@ -85,8 +107,9 @@ async def test_choose():
     assert result2 == result3
 
 
+@pytest.mark.timeout(60)
 async def test_stream():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     v = await proc.spawn("counter2", Counter, 3)
     v.incr.broadcast()
 
@@ -101,46 +124,50 @@ class To(Actor):
 
 class From(Actor):
     @endpoint
-    async def
+    async def fetch(self, to: To):
         return [await x for x in to.whoami.stream()]
 
 
+@pytest.mark.timeout(60)
 async def test_mesh_passed_to_mesh():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     f = await proc.spawn("from", From)
     t = await proc.spawn("to", To)
-    all = [y for x in f.
+    all = [y for x in f.fetch.stream(t) for y in await x]
     assert len(all) == 4
     assert all[0] != all[1]
 
 
+@pytest.mark.timeout(60)
 async def test_mesh_passed_to_mesh_on_different_proc_mesh():
-    proc = await
-    proc2 = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
+    proc2 = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     f = await proc.spawn("from", From)
     t = await proc2.spawn("to", To)
-    all = [y for x in f.
+    all = [y for x in f.fetch.stream(t) for y in await x]
     assert len(all) == 4
     assert all[0] != all[1]
 
 
-
-
-
+@pytest.mark.timeout(60)
+def test_actor_slicing():
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
+    proc2 = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
 
-    f =
-    t =
+    f = proc.spawn("from", From)
+    t = proc2.spawn("to", To)
 
-    assert
+    assert t.slice(gpus=0).whoami.call().get() != t.slice(gpus=1).whoami.call().get()
 
-    result = [y for x in f.
+    result = [y for x in f.fetch.stream(t.slice(gpus=0)) for y in x.get()]
     assert len(result) == 2
 
     assert result[0] == result[1]
 
 
+@pytest.mark.timeout(60)
 async def test_aggregate():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     counter = await proc.spawn("counter", Counter, 1)
     counter.incr.broadcast()
     acc = Accumulator(counter.value, 0, operator.add)
@@ -153,9 +180,14 @@ class RunIt(Actor):
     async def run(self, fn):
         return fn()
 
+    @endpoint
+    async def return_current_rank_str(self):
+        return str(current_rank())
+
 
+@pytest.mark.timeout(60)
 async def test_rank_size():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     r = await proc.spawn("runit", RunIt)
 
     acc = Accumulator(r.run, 0, operator.add)
@@ -164,35 +196,50 @@ async def test_rank_size():
     assert 4 == await acc.accumulate(lambda: current_size()["gpus"])
 
 
+@pytest.mark.timeout(60)
+async def test_rank_string():
+    proc = fake_in_process_host().spawn_procs(per_host={"hosts": 1, "gpus": 2})
+    r = proc.spawn("runit", RunIt)
+    vm = r.return_current_rank_str.call().get()
+    r0 = vm.flatten("r").slice(r=0).item()
+    r1 = vm.flatten("r").slice(r=1).item()
+    assert r0 == "{'hosts': 0/1, 'gpus': 0/2}"
+    assert r1 == "{'hosts': 0/1, 'gpus': 1/2}"
+
+
 class SyncActor(Actor):
     @endpoint
     def sync_endpoint(self, a_counter: Counter):
         return a_counter.value.choose().get()
 
 
+@pytest.mark.timeout(60)
 async def test_sync_actor():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     a = await proc.spawn("actor", SyncActor)
     c = await proc.spawn("counter", Counter, 5)
     r = await a.sync_endpoint.choose(c)
     assert r == 5
 
 
-
-
+@pytest.mark.timeout(60)
+def test_sync_actor_sync_client() -> None:
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     a = proc.spawn("actor", SyncActor).get()
     c = proc.spawn("counter", Counter, 5).get()
     r = a.sync_endpoint.choose(c).get()
     assert r == 5
 
 
+@pytest.mark.timeout(60)
 def test_proc_mesh_size() -> None:
-    proc =
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     assert 2 == proc.size("gpus")
 
 
+@pytest.mark.timeout(60)
 def test_rank_size_sync() -> None:
-    proc =
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     r = proc.spawn("runit", RunIt).get()
 
     acc = Accumulator(r.run, 0, operator.add)
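The new `test_rank_string` above also pins down how per-rank results are addressed: `call().get()` returns a value mesh with one entry per rank, and `flatten("r").slice(r=i).item()` extracts rank i's value. A condensed sketch of that access pattern, reusing the `RunIt` actor defined earlier:

    vm = r.return_current_rank_str.call().get()  # one string per rank
    r0 = vm.flatten("r").slice(r=0).item()       # value produced by rank 0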
@@ -200,8 +247,9 @@ def test_rank_size_sync() -> None:
     assert 4 == acc.accumulate(lambda: current_size()["gpus"]).get()
 
 
+@pytest.mark.timeout(60)
 def test_accumulate_sync() -> None:
-    proc =
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     counter = proc.spawn("counter", Counter, 1).get()
     counter.incr.broadcast()
     acc = Accumulator(counter.value, 0, operator.add)
@@ -215,8 +263,9 @@ class CastToCounter(Actor):
         return list(c.value.call().get())
 
 
+@pytest.mark.timeout(60)
 def test_value_mesh() -> None:
-    proc =
+    proc = fake_in_process_host().spawn_procs(per_host={"hosts": 1, "gpus": 2})
     counter = proc.spawn("counter", Counter, 0).get()
     counter.slice(hosts=0, gpus=1).incr.broadcast()
     x = counter.value.call().get()
@@ -227,7 +276,18 @@ def test_value_mesh() -> None:
     assert list(x) == n.slice(gpus=0).doit.call_one(counter).get()
 
 
+@pytest.mark.timeout(60)
 def test_rust_binding_modules_correct() -> None:
+    """
+    This tests that rust bindings will survive pickling correctly.
+
+    To correctly define a rust binding, either
+
+    (1) Set its module to "monarch._rust_bindings.rust_crate.rust_module",
+        and make sure it is registered in monarch_extension/lib.rs
+    (2) Set its module to some existing python file, and use @rust_struct to install
+        the rust struct in that file and patch in any python extension methods.
+    """
     import monarch._rust_bindings as bindings
 
     def check(module, path):
@@ -237,14 +297,16 @@ def test_rust_binding_modules_correct() -> None:
         if isinstance(value, ModuleType):
             check(value, f"{path}.{name}")
         elif hasattr(value, "__module__"):
-
-
+            value_module = importlib.import_module(value.__module__)
+            resolved_value = getattr(value_module, value.__name__)
+            assert value is resolved_value
 
     check(bindings, "monarch._rust_bindings")
 
 
+@pytest.mark.timeout(60)
 def test_proc_mesh_liveness() -> None:
-    mesh =
+    mesh = this_host().spawn_procs(per_host={"gpus": 2})
    counter = mesh.spawn("counter", Counter, 1).get()
    del mesh
    # Give some time for the mesh to have been shut down.
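The `check` helper above enforces the invariant that pickling relies on: looking an object up again via its `__module__` and `__name__` must return the very same object. A standalone sketch of the same round-trip test in plain Python:

    import importlib

    def pickle_resolvable(obj) -> bool:
        # pickle serializes obj by module path + qualified name, so
        # resolving that path must yield obj itself, or unpickling
        # would produce a different object.
        mod = importlib.import_module(obj.__module__)
        return getattr(mod, obj.__name__, None) is obj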
@@ -269,7 +331,7 @@ class TLSActor(Actor):
         self.local.value += 1
 
     @endpoint
-    def
+    def get_value(self):
         return self.local.value
 
     @endpoint
@@ -277,16 +339,17 @@ class TLSActor(Actor):
         return self.local.value
 
 
+@pytest.mark.timeout(60)
 async def test_actor_tls() -> None:
     """Test that thread-local state is respected."""
-    pm =
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
     am = await pm.spawn("tls", TLSActor)
     await am.increment.call_one()
     await am.increment_async.call_one()
     await am.increment.call_one()
     await am.increment_async.call_one()
 
-    assert 4 == await am.
+    assert 4 == await am.get_value.call_one()
     assert 4 == await am.get_async.call_one()
 
 
@@ -302,20 +365,21 @@ class TLSActorFullSync(Actor):
         self.local.value += 1
 
     @endpoint
-    def
+    def get_value(self):
         return self.local.value
 
 
+@pytest.mark.timeout(60)
 async def test_actor_tls_full_sync() -> None:
     """Test that thread-local state is respected."""
-    pm =
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
     am = await pm.spawn("tls", TLSActorFullSync)
     await am.increment.call_one()
     await am.increment.call_one()
     await am.increment.call_one()
     await am.increment.call_one()
 
-    assert 4 == await am.
+    assert 4 == await am.get_value.call_one()
 
 
 class AsyncActor(Actor):
@@ -332,10 +396,10 @@ class AsyncActor(Actor):
         self.should_exit = True
 
 
-@pytest.mark.timeout(
+@pytest.mark.timeout(30)
 async def test_async_concurrency():
     """Test that async endpoints will be processed concurrently."""
-    pm = await
+    pm = await this_host().spawn_procs()
     am = await pm.spawn("async", AsyncActor)
     fut = am.sleep.call()
     # This call should go through and exit the sleep loop, as long as we are
@@ -441,19 +505,35 @@ async def awaitit(f):
 
 
 class Printer(Actor):
-    def __init__(self):
-        self.
-        self.logger.setLevel(INFO)
+    def __init__(self) -> None:
+        self._logger: logging.Logger = logging.getLogger()
 
     @endpoint
-    async def print(self, content: str):
-        print(f"{
+    async def print(self, content: str) -> None:
+        print(f"{content}", flush=True)
+        sys.stdout.flush()
+        sys.stderr.flush()
 
     @endpoint
-    async def log(self, content: str):
-        self.
-
-
+    async def log(self, content: str) -> None:
+        self._logger.error(f"{content}")
+        for handler in self._logger.handlers:
+            handler.flush()
+        sys.stdout.flush()
+        sys.stderr.flush()
+
+    def _handle_undeliverable_message(
+        self, message: UndeliverableMessageEnvelope
+    ) -> bool:
+        # Don't throw an error on undeliverable messages. This actor is used in a test for
+        # stopping actor meshes, and if we throw an error here then there is a race between
+        # the asserted error that the mesh was stopped and the supervision error that a message
+        # wasn't delivered.
+        self._logger.error(f"Ignoring undeliverable message: {message}")
+        return True
+
+
+@pytest.mark.timeout(60)
 async def test_actor_log_streaming() -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
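`Printer` now overrides `_handle_undeliverable_message`, the hook these tests use to observe messages that could not be delivered; returning True reports the envelope as handled so no supervision error is raised for it. A minimal sketch of the hook, assuming the `Actor` base class and the `UndeliverableMessageEnvelope` type imported above (the actor name is a hypothetical example):

    class QuietActor(Actor):
        def _handle_undeliverable_message(
            self, message: UndeliverableMessageEnvelope
        ) -> bool:
            # True tells the runtime the undeliverable message was handled
            # here, so no error needs to propagate for it.
            return True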
@@ -481,19 +561,247 @@ async def test_actor_log_streaming() -> None:
             sys.stderr = stderr_file
 
             try:
-                pm =
+                pm = this_host().spawn_procs(per_host={"gpus": 2})
                 am = await pm.spawn("printer", Printer)
 
-
-                await
+                # Disable streaming logs to client
+                await pm.logging_option(
+                    stream_to_client=False, aggregate_window_sec=None
+                )
+                await asyncio.sleep(1)
+
+                # These should not be streamed to client initially
+                for _ in range(5):
+                    await am.print.call("no print streaming")
+                    await am.log.call("no log streaming")
+                await asyncio.sleep(1)
+
+                # Enable streaming logs to client
+                await pm.logging_option(
+                    stream_to_client=True, aggregate_window_sec=1, level=logging.FATAL
+                )
+                # Give it some time to reflect
+                await asyncio.sleep(1)
+
+                # These should be streamed to client
+                for _ in range(5):
+                    await am.print.call("has print streaming")
+                    await am.log.call("no log streaming due to level mismatch")
+                await asyncio.sleep(1)
+
+                # Enable streaming logs to client
+                await pm.logging_option(
+                    stream_to_client=True, aggregate_window_sec=1, level=logging.ERROR
+                )
+                # Give it some time to reflect
+                await asyncio.sleep(1)
+
+                # These should be streamed to client
+                for _ in range(5):
+                    await am.print.call("has print streaming too")
+                    await am.log.call("has log streaming as level matched")
+
+                await pm.stop()
+
+                # Flush all outputs
+                stdout_file.flush()
+                stderr_file.flush()
+                os.fsync(stdout_file.fileno())
+                os.fsync(stderr_file.fileno())
+
+            finally:
+                # Restore Python's sys.stdout/stderr
+                sys.stdout = original_sys_stdout
+                sys.stderr = original_sys_stderr
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+                os.dup2(original_stderr_fd, 2)
+
+                # Read the captured output
+                with open(stdout_path, "r") as f:
+                    stdout_content = f.read()
+
+                with open(stderr_path, "r") as f:
+                    stderr_content = f.read()
+
+                # Clean up temp files
+                os.unlink(stdout_path)
+                os.unlink(stderr_path)
+
+                # Assertions on the captured output
+                # Has a leading context so we can distinguish between streamed log and
+                # the log directly printed by the child processes as they share the same stdout/stderr
+                assert not re.search(
+                    r"similar log lines.*no print streaming", stdout_content
+                ), stdout_content
+                assert not re.search(
+                    r"similar log lines.*no print streaming", stderr_content
+                ), stderr_content
+                assert not re.search(
+                    r"similar log lines.*no log streaming", stdout_content
+                ), stdout_content
+                assert not re.search(
+                    r"similar log lines.*no log streaming", stderr_content
+                ), stderr_content
+                assert not re.search(
+                    r"similar log lines.*no log streaming due to level mismatch", stdout_content
+                ), stdout_content
+                assert not re.search(
+                    r"similar log lines.*no log streaming due to level mismatch", stderr_content
+                ), stderr_content
+
+                assert re.search(
+                    r"similar log lines.*has print streaming", stdout_content
+                ), stdout_content
+                assert not re.search(
+                    r"similar log lines.*has print streaming", stderr_content
+                ), stderr_content
+                assert re.search(
+                    r"similar log lines.*has print streaming too", stdout_content
+                ), stdout_content
+                assert not re.search(
+                    r"similar log lines.*has print streaming too", stderr_content
+                ), stderr_content
+                assert not re.search(
+                    r"similar log lines.*log streaming as level matched", stdout_content
+                ), stdout_content
+                assert re.search(
+                    r"similar log lines.*log streaming as level matched",
+                    stderr_content,
+                ), stderr_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.dup2(original_stderr_fd, 2)
+            os.close(original_stdout_fd)
+            os.close(original_stderr_fd)
+        except OSError:
+            pass
+
+
+@pytest.mark.timeout(120)
+async def test_alloc_based_log_streaming() -> None:
+    """Test both AllocHandle.stream_logs = False and True cases."""
+
+    async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
+        # Save original file descriptors
+        original_stdout_fd = os.dup(1)  # stdout
+
+        try:
+            # Create temporary files to capture output
+            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+                stdout_path = stdout_file.name
+                os.dup2(stdout_file.fileno(), 1)
+                original_sys_stdout = sys.stdout
+                sys.stdout = stdout_file
+
+                try:
+                    # Create proc mesh with custom stream_logs setting
+                    host_mesh = create_local_host_mesh()
+                    alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
+
+                    # Override the stream_logs setting
+                    custom_alloc_handle = AllocHandle(
+                        alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
+                    )
 
-
+                    pm = ProcMesh.from_alloc(custom_alloc_handle)
+                    am = await pm.spawn("printer", Printer)
 
-
-        await am.log.call("hello 4")
+                    await pm.initialized
 
-
-
+                    for _ in range(5):
+                        await am.print.call(f"{test_name} print streaming")
+
+                    await pm.stop()
+
+                    # Flush all outputs
+                    stdout_file.flush()
+                    os.fsync(stdout_file.fileno())
+
+                finally:
+                    # Restore Python's sys.stdout
+                    sys.stdout = original_sys_stdout
+
+                    # Restore original file descriptors
+                    os.dup2(original_stdout_fd, 1)
+
+                    # Read the captured output
+                    with open(stdout_path, "r") as f:
+                        stdout_content = f.read()
+
+                    # Clean up temp files
+                    os.unlink(stdout_path)
+
+                    if not stream_logs:
+                        # When stream_logs=False, logs should not be streamed to client
+                        assert not re.search(
+                            rf"similar log lines.*{test_name} print streaming", stdout_content
+                        ), f"stream_logs=True case: {stdout_content}"
+                        assert re.search(
+                            rf"{test_name} print streaming", stdout_content
+                        ), f"stream_logs=True case: {stdout_content}"
+                    else:
+                        # When stream_logs=True, logs should be streamed to client (no aggregation by default)
+                        assert re.search(
+                            rf"similar log lines.*{test_name} print streaming", stdout_content
+                        ), f"stream_logs=False case: {stdout_content}"
+                        assert not re.search(
+                            rf"\[[0-9]\]{test_name} print streaming", stdout_content
+                        ), f"stream_logs=False case: {stdout_content}"
+
+        finally:
+            # Ensure file descriptors are restored even if something goes wrong
+            try:
+                os.dup2(original_stdout_fd, 1)
+                os.close(original_stdout_fd)
+            except OSError:
+                pass
+
+    # Test both cases
+    await test_stream_logs_case(False, "stream_logs_false")
+    await test_stream_logs_case(True, "stream_logs_true")
+
+
+@pytest.mark.timeout(60)
+async def test_logging_option_defaults() -> None:
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+    original_stderr_fd = os.dup(2)  # stderr
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(
+            mode="w+", delete=False
+        ) as stdout_file, tempfile.NamedTemporaryFile(
+            mode="w+", delete=False
+        ) as stderr_file:
+            stdout_path = stdout_file.name
+            stderr_path = stderr_file.name
+
+            # Redirect file descriptors to our temp files
+            # This will capture both Python and Rust output
+            os.dup2(stdout_file.fileno(), 1)
+            os.dup2(stderr_file.fileno(), 2)
+
+            # Also redirect Python's sys.stdout/stderr for completeness
+            original_sys_stdout = sys.stdout
+            original_sys_stderr = sys.stderr
+            sys.stdout = stdout_file
+            sys.stderr = stderr_file
+
+            try:
+                pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                am = await pm.spawn("printer", Printer)
+
+                for _ in range(5):
+                    await am.print.call("print streaming")
+                    await am.log.call("log streaming")
+
+                await pm.stop()
 
                 # Flush all outputs
                 stdout_file.flush()
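The logging tests above exercise three knobs of `logging_option` on a proc mesh: `stream_to_client` toggles forwarding of child stdout/stderr to the client, `aggregate_window_sec` batches repeated lines into "[N similar log lines]" summaries (None disables aggregation), and `level` filters which log records are streamed. A condensed sketch of the pattern the tests repeat, assuming an async test context:

    pm = this_host().spawn_procs(per_host={"gpus": 2})
    am = await pm.spawn("printer", Printer)
    await pm.logging_option(
        stream_to_client=True, aggregate_window_sec=1, level=logging.ERROR
    )
    await am.log.call("streamed, aggregated, and level-filtered")
    await pm.stop()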
@@ -514,16 +822,27 @@ async def test_actor_log_streaming() -> None:
         with open(stdout_path, "r") as f:
             stdout_content = f.read()
 
+        with open(stderr_path, "r") as f:
+            stderr_content = f.read()
+
         # Clean up temp files
         os.unlink(stdout_path)
         os.unlink(stderr_path)
 
-        #
-
-
-
-        assert "
-
+        # Assertions on the captured output
+        assert not re.search(
+            r"similar log lines.*print streaming", stdout_content
+        ), stdout_content
+        assert re.search(r"print streaming", stdout_content), stdout_content
+        assert not re.search(
+            r"similar log lines.*print streaming", stderr_content
+        ), stderr_content
+        assert not re.search(
+            r"similar log lines.*log streaming", stdout_content
+        ), stdout_content
+        assert not re.search(
+            r"similar log lines.*log streaming", stderr_content
+        ), stderr_content
 
     finally:
         # Ensure file descriptors are restored even if something goes wrong
@@ -536,6 +855,378 @@ async def test_actor_log_streaming() -> None:
             pass
 
 
+# oss_skip: pytest keeps complaining about mocking get_ipython module
+@pytest.mark.oss_skip
+@pytest.mark.timeout(180)
+async def test_flush_logs_ipython() -> None:
+    """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                # Mock IPython environment
+                class MockExecutionResult:
+                    pass
+
+                class MockEvents:
+                    def __init__(self):
+                        self.callbacks = {}
+                        self.registers = 0
+                        self.unregisters = 0
+
+                    def register(self, event_name, callback):
+                        if event_name not in self.callbacks:
+                            self.callbacks[event_name] = []
+                        self.callbacks[event_name].append(callback)
+                        self.registers += 1
+
+                    def unregister(self, event_name, callback):
+                        if event_name not in self.callbacks:
+                            raise ValueError(f"Event {event_name} not registered")
+                        assert callback in self.callbacks[event_name]
+                        self.callbacks[event_name].remove(callback)
+                        self.unregisters += 1
+
+                    def trigger(self, event_name, *args, **kwargs):
+                        if event_name in self.callbacks:
+                            for callback in self.callbacks[event_name]:
+                                callback(*args, **kwargs)
+
+                class MockIPython:
+                    def __init__(self):
+                        self.events = MockEvents()
+
+                mock_ipython = MockIPython()
+
+                with unittest.mock.patch(
+                    "monarch._src.actor.logging.get_ipython",
+                    lambda: mock_ipython,
+                ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
+                    # Make sure we can register and unregister callbacks
+                    for i in range(3):
+                        pm1 = await this_host().spawn_procs(per_host={"gpus": 2})
+                        pm2 = await this_host().spawn_procs(per_host={"gpus": 2})
+                        am1 = await pm1.spawn("printer", Printer)
+                        am2 = await pm2.spawn("printer", Printer)
+
+                        # Set aggregation window to ensure logs are buffered
+                        await pm1.logging_option(
+                            stream_to_client=True, aggregate_window_sec=600
+                        )
+                        await pm2.logging_option(
+                            stream_to_client=True, aggregate_window_sec=600
+                        )
+                        # TODO: fix the following assertion
+                        # assert mock_ipython.events.unregisters == 2 * i
+
+                        # _get_controller_controller() spawns an extra local mesh
+                        # but log streaming is disabled so it doesn't hurt
+                        assert mock_ipython.events.registers == 1 + 2 * (i + 1)
+                        await asyncio.sleep(1)
+
+                        # Generate some logs that will be aggregated
+                        for _ in range(5):
+                            await am1.print.call("ipython1 test log")
+                            await am2.print.call("ipython2 test log")
+
+                        # Trigger the post_run_cell event which should flush logs
+                        mock_ipython.events.trigger(
+                            "post_run_cell", MockExecutionResult()
+                        )
+
+                # Flush all outputs
+                stdout_file.flush()
+                os.fsync(stdout_file.fileno())
+
+                gc.collect()
+
+                # Same as above, _get_controller_controller() spawns an extra local mesh
+                assert mock_ipython.events.registers == 7
+                # There are many objects still taking refs
+                # TODO: fix the following assertion
+                assert mock_ipython.events.unregisters == 0
+                assert len(mock_ipython.events.callbacks["post_run_cell"]) == 7
+            finally:
+                # Restore Python's sys.stdout
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+                # Read the captured output
+                with open(stdout_path, "r") as f:
+                    stdout_content = f.read()
+
+                # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils
+
+                # Clean up temp files
+                os.unlink(stdout_path)
+
+                # Verify that logs were flushed when the post_run_cell event was triggered
+                # We should see the aggregated logs in the output
+                assert (
+                    len(
+                        re.findall(
+                            r"\[10 similar log lines\].*ipython1 test log", stdout_content
+                        )
+                    )
+                    == 3
+                ), stdout_content
+
+                assert (
+                    len(
+                        re.findall(
+                            r"\[10 similar log lines\].*ipython2 test log", stdout_content
+                        )
+                    )
+                    == 3
+                ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
+# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
+@pytest.mark.oss_skip
+async def test_flush_logs_fast_exit() -> None:
+    # We use a subprocess to run the test so we can handle the flushed logs at the end.
+    # Otherwise, it is hard to restore the original stdout/stderr.
+
+    test_bin = importlib.resources.files(str(__package__)).joinpath("test_bin")
+
+    # Run the binary in a separate process and capture stdout and stderr
+    cmd = [str(test_bin), "flush-logs"]
+
+    process = subprocess.run(cmd, capture_output=True, timeout=60, text=True)
+
+    # Check if the process ended without error
+    if process.returncode != 0:
+        raise RuntimeError(f"{cmd} ended with error code {process.returncode}. ")
+
+    # Assertions on the captured output, 160 = 32 procs * 5 logs per proc
+    # 32 and 5 are specified in the test_bin flush-logs.
+    assert (
+        len(
+            re.findall(
+                r"160 similar log lines.*has print streaming",
+                process.stdout,
+            )
+        )
+        == 1
+    ), process.stdout
+
+
+@pytest.mark.timeout(60)
+async def test_flush_on_disable_aggregation() -> None:
+    """Test that logs are flushed when disabling aggregation.
+
+    This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."
+    """
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                am = await pm.spawn("printer", Printer)
+
+                # Set a long aggregation window to ensure logs aren't flushed immediately
+                await pm.logging_option(stream_to_client=True, aggregate_window_sec=60)
+
+                # Generate some logs that will be aggregated but not flushed immediately
+                for _ in range(5):
+                    await am.print.call("aggregated log line")
+                await asyncio.sleep(1)
+
+                # Now disable aggregation - this should trigger an immediate flush
+                await pm.logging_option(
+                    stream_to_client=True, aggregate_window_sec=None
+                )
+
+                # Wait a bit to ensure logs are collected
+                await asyncio.sleep(1)
+                for _ in range(5):
+                    await am.print.call("single log line")
+
+                await pm.stop()
+
+                # Flush all outputs
+                stdout_file.flush()
+                os.fsync(stdout_file.fileno())
+
+            finally:
+                # Restore Python's sys.stdout
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+                # Read the captured output
+                with open(stdout_path, "r") as f:
+                    stdout_content = f.read()
+
+                # Clean up temp files
+                os.unlink(stdout_path)
+
+                # Verify that logs were flushed when aggregation was disabled
+                # We should see the aggregated logs in the output
+                # 10 = 5 log lines * 2 procs
+                assert re.search(
+                    r"\[10 similar log lines\].*aggregated log line", stdout_content
+                ), stdout_content
+
+                # No aggregated single log lines
+                assert not re.search(
+                    r"similar log lines.*single log line", stdout_content
+                ), stdout_content
+
+                # 10 = 5 log lines * 2 procs
+                assert (
+                    len(re.findall(r"\[.* [0-9]+\] single log line", stdout_content)) == 10
+                ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
+@pytest.mark.timeout(120)
+async def test_multiple_ongoing_flushes_no_deadlock() -> None:
+    """
+    The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
+    Because now a flush call is purely sync, it is very easy to get into a deadlock.
+    So we assert the last flush call will not get into such a state.
+    """
+    pm = this_host().spawn_procs(per_host={"gpus": 4})
+    am = pm.spawn("printer", Printer)
+
+    # Generate some logs that will be aggregated but not flushed immediately
+    for _ in range(10):
+        await am.print.call("aggregated log line")
+
+    log_mesh = pm._logging_manager._logging_mesh_client
+    assert log_mesh is not None
+    futures = []
+    for _ in range(5):
+        # FIXME: the order of futures doesn't necessarily mean the order of flushes due to the async nature.
+        await asyncio.sleep(0.1)
+        futures.append(Future(coro=log_mesh.flush().spawn().task()))
+
+    # The last flush should not block
+    futures[-1].get()
+
+
+@pytest.mark.timeout(60)
+async def test_adjust_aggregation_window() -> None:
+    """Test that the flush deadline is updated when the aggregation window is adjusted.
+
+    This tests the corner case: "This can happen if the user has adjusted the aggregation window."
+    """
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                am = await pm.spawn("printer", Printer)
+
+                # Set a long aggregation window initially
+                await pm.logging_option(stream_to_client=True, aggregate_window_sec=100)
+
+                # Generate some logs that will be aggregated
+                for _ in range(3):
+                    await am.print.call("first batch of logs")
+                await asyncio.sleep(1)
+
+                # Now adjust to a shorter window - this should update the flush deadline
+                await pm.logging_option(stream_to_client=True, aggregate_window_sec=2)
+
+                # Generate more logs
+                for _ in range(3):
+                    await am.print.call("second batch of logs")
+
+                await pm.stop()
+
+                # Flush all outputs
+                stdout_file.flush()
+                os.fsync(stdout_file.fileno())
+
+            finally:
+                # Restore Python's sys.stdout/stderr
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+                # Read the captured output
+                with open(stdout_path, "r") as f:
+                    stdout_content = f.read()
+
+                # Clean up temp files
+                os.unlink(stdout_path)
+
+                # Verify that logs were flushed when the aggregation window was adjusted
+                # We should see both batches of logs in the output
+                assert re.search(
+                    r"\[6 similar log lines\].*first batch of logs", stdout_content
+                ), stdout_content
+
+                assert re.search(
+                    r"similar log lines.*second batch of logs", stdout_content
+                ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
 class SendAlot(Actor):
     @endpoint
     async def send(self, port: Port[int]):
@@ -543,10 +1234,11 @@ class SendAlot(Actor):
             port.send(i)
 
 
-
-
+@pytest.mark.timeout(60)
+def test_port_as_argument() -> None:
+    proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 1})
     s = proc_mesh.spawn("send_alot", SendAlot).get()
-    send, recv =
+    send, recv = Channel[int].open()
 
     s.send.broadcast(send)
 
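`test_port_as_argument` above settles the typed-channel idiom: `Channel[int].open()` yields a `(send, recv)` pair, the send side is passed to an endpoint as a `Port[int]`, and the client drains the receive side synchronously. A condensed sketch of the round trip, reusing the `SendAlot` actor defined above:

    send, recv = Channel[int].open()
    s.send.broadcast(send)      # each actor writes into the port
    first = recv.recv().get()   # the client reads values back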
@@ -554,14 +1246,14 @@ def test_port_as_argument():
     assert i == recv.recv().get()
 
 
-@pytest.mark.timeout(
+@pytest.mark.timeout(30)
 async def test_same_actor_twice() -> None:
-    pm =
-    await pm.spawn("dup", Counter, 0)
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
+    await pm.spawn("dup", Counter, 0).initialized
 
     # The second spawn with the same name should fail with a specific error
     with pytest.raises(Exception) as exc_info:
-        await pm.spawn("dup", Counter, 0)
+        await pm.spawn("dup", Counter, 0).initialized
 
     # Assert that the error message contains the expected text about duplicate actor name
     error_msg = str(exc_info.value)
@@ -570,23 +1262,81 @@ async def test_same_actor_twice() -> None:
     ), f"Expected error message about duplicate actor name, got: {error_msg}"
 
 
+class LsActor(Actor):
+    def __init__(self, workspace: str):
+        self.workspace = workspace
+
+    @endpoint
+    async def ls(self) -> list[str]:
+        return os.listdir(self.workspace)
+
+
+async def test_sync_workspace() -> None:
+    # create two workspaces: one for local and one for remote
+    with tempfile.TemporaryDirectory() as workspace_src, tempfile.TemporaryDirectory() as workspace_dst:
+
+        def bootstrap_WORKSPACE_DIR() -> None:
+            import os
+
+            os.environ["WORKSPACE_DIR"] = workspace_dst
+
+        pm = this_host().spawn_procs(
+            per_host={"gpus": 1}, bootstrap=bootstrap_WORKSPACE_DIR
+        )
+
+        config = defaults.config("slurm", workspace_src)
+        await pm.sync_workspace(workspace=config.workspace, auto_reload=True)
+
+        # no file in remote workspace initially
+        am = await pm.spawn("ls", LsActor, workspace_dst)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 0
+
+        # write a file to local workspace
+        file_path = os.path.join(workspace_src, "new_file")
+        with open(file_path, "w") as f:
+            f.write("hello world")
+            f.flush()
+
+        # force a sync and it should populate on the dst workspace
+        await pm.sync_workspace(config.workspace, auto_reload=True)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 1
+            assert item[1][0] == "new_file"
+            file_path = os.path.join(workspace_dst, item[1][0])
+            with open(file_path, "r") as f:
+                assert f.readline() == "hello world"
+
+    # sanity check
+    assert "WORKSPACE_DIR" not in os.environ, "test leaves env var side-effects!"
+
+
 class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
     async def test_actor_mesh_stop(self) -> None:
-        pm =
+        pm = this_host().spawn_procs(per_host={"gpus": 2})
         am_1 = await pm.spawn("printer", Printer)
         am_2 = await pm.spawn("printer2", Printer)
         await am_1.print.call("hello 1")
         await am_1.log.call("hello 2")
-        await cast(
+        await cast(ActorMesh, am_1).stop()
 
         with self.assertRaisesRegex(
-            RuntimeError, expected_regex="`
+            RuntimeError, expected_regex="`PythonActorMesh` has already been stopped"
         ):
             await am_1.print.call("hello 1")
 
         await am_2.print.call("hello 3")
         await am_2.log.call("hello 4")
 
+        await pm.stop()
+
+    async def test_proc_mesh_stop_after_actor_mesh_stop(self) -> None:
+        pm = this_host().spawn_procs(per_host={"gpus": 2})
+        am = await pm.spawn("printer", Printer)
+
+        await cast(ActorMesh, am).stop()
+        await pm.stop()
+
 
 class PortedActor(Actor):
     @endpoint(explicit_response_port=True)
@@ -594,7 +1344,161 @@ class PortedActor(Actor):
         port.send(3 + b)
 
 
+@pytest.mark.timeout(60)
 def test_ported_actor():
-    proc_mesh =
+    proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 1}).get()
     a = proc_mesh.spawn("port_actor", PortedActor).get()
     assert 5 == a.add.call_one(2).get()
+
+
+async def _recv():
+    return (7, 2, 3)
+
+
+async def consume():
+    r = await PythonTask.from_coroutine(_recv())
+    assert r == (7, 2, 3)
+
+
+@pytest.mark.timeout(60)
+def test_python_task_tuple() -> None:
+    PythonTask.from_coroutine(consume()).block_on()
+
+
+def test_select_result() -> None:
+    def s(t):
+        time.sleep(t)
+        return t
+
+    a = PythonTask.spawn_blocking(lambda: s(4))
+    b = PythonTask.spawn_blocking(lambda: s(0))
+    r = PythonTask.select_one([a.task(), b.task()]).block_on()
+    assert r == (0, 1)
+
+
+def test_mesh_len():
+    proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 12})
+    s = proc_mesh.spawn("sync_actor", SyncActor).get()
+    assert 12 == len(s)
+
+
+class UndeliverableMessageReceiver(Actor):
+    def __init__(self):
+        self._messages = asyncio.Queue()
+
+    @endpoint
+    async def receive_undeliverable(
+        self, sender: ActorId, dest: PortId, error_msg: str
+    ) -> None:
+        await self._messages.put((sender, dest, error_msg))
+
+    @endpoint
+    async def get_messages(self) -> Tuple[ActorId, PortId, str]:
+        return await self._messages.get()
+
+
+class UndeliverableMessageSender(Actor):
+    @endpoint
+    def send_undeliverable(self) -> None:
+        mailbox = context().actor_instance._mailbox
+        port_id = PortId(
+            actor_id=ActorId(
+                world_name=mailbox.actor_id.world_name, rank=0, actor_name="bogus"
+            ),
+            port=1234,
+        )
+        port_ref = PortRef(port_id)
+        port_ref.send(
+            mailbox,
+            PythonMessage(PythonMessageKind.Result(None), b"123"),
+        )
+
+
+class UndeliverableMessageSenderWithOverride(UndeliverableMessageSender):
+    def __init__(self, receiver: UndeliverableMessageReceiver):
+        self._receiver = receiver
+
+    def _handle_undeliverable_message(
+        self, message: UndeliverableMessageEnvelope
+    ) -> bool:
+        self._receiver.receive_undeliverable.call_one(
+            message.sender(), message.dest(), message.error_msg()
+        ).get()
+        return True
+
+
+@pytest.mark.timeout(60)
+async def test_undeliverable_message_with_override() -> None:
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
+    receiver = pm.spawn("undeliverable_receiver", UndeliverableMessageReceiver)
+    sender = pm.spawn(
+        "undeliverable_sender", UndeliverableMessageSenderWithOverride, receiver
+    )
+    sender.send_undeliverable.call().get()
+    sender, dest, error_msg = receiver.get_messages.call_one().get()
+    assert sender.actor_name == "undeliverable_sender"
+    assert dest.actor_id.actor_name == "bogus"
+    assert error_msg is not None
+    pm.stop().get()
+
+
+@pytest.mark.timeout(60)
+async def test_undeliverable_message_without_override() -> None:
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
+    sender = pm.spawn("undeliverable_sender", UndeliverableMessageSender)
+    sender.send_undeliverable.call().get()
+    # Wait a few seconds to ensure that the undeliverable message is processed
+    # without crashing anything
+    await asyncio.sleep(5)
+    pm.stop().get()
+
+
+def test_this_and_that():
+    counter = this_proc().spawn("counter", Counter, 7)
+    assert 7 == counter.value.call_one().get()
+
+
+class ReceptorActor(Actor):
+    @endpoint
+    def status(self):
+        return 1
+
+
+async def test_things_survive_losing_python_reference() -> None:
+    """Test the slice_receptor_mesh function in LOCAL mode, verifying that setup methods are called."""
+
+    receptor = (
+        this_host()
+        .spawn_procs(per_host={"gpus": 1})
+        .spawn(
+            "receptor",
+            ReceptorActor,
+        )
+    )
+    receptor = receptor.slice(gpus=0)
+
+    await receptor.status.call()
+
+
+class IsInit(Actor):
+    @endpoint
+    def is_cuda_initialized(self) -> bool:
+        cuda = ctypes.CDLL("libcuda.so.1")
+        CUresult = ctypes.c_int
+        cuDeviceGetCount = cuda.cuDeviceGetCount
+        cuDeviceGetCount.argtypes = [ctypes.POINTER(ctypes.c_int)]
+        cuDeviceGetCount.restype = CUresult
+        count = ctypes.c_int()
+        result = cuDeviceGetCount(ctypes.byref(count))
+        CUDA_ERROR_NOT_INITIALIZED = 3
+        return result == CUDA_ERROR_NOT_INITIALIZED
+
+
+@pytest.mark.oss_skip
+def test_cuda_is_not_initialized_in_a_new_proc():
+    try:
+        ctypes.CDLL("libcuda.so.1")
+    except OSError:
+        pytest.skip("cannot find cuda")
+    proc = this_host().spawn_procs().spawn("is_init", IsInit)
+    assert not proc.is_cuda_initialized.call_one().get()