torchmonarch-nightly 2025.8.2__cp313-cp313-manylinux2014_x86_64.whl → 2025.9.3__cp313-cp313-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/_rust_bindings.so +0 -0
- monarch/_src/actor/actor_mesh.py +414 -216
- monarch/_src/actor/allocator.py +75 -6
- monarch/_src/actor/bootstrap_main.py +7 -4
- monarch/_src/actor/code_sync/__init__.py +2 -0
- monarch/_src/actor/debugger/__init__.py +7 -0
- monarch/_src/actor/{debugger.py → debugger/debugger.py} +246 -135
- monarch/_src/actor/{pdb_wrapper.py → debugger/pdb_wrapper.py} +62 -23
- monarch/_src/actor/endpoint.py +27 -45
- monarch/_src/actor/future.py +86 -24
- monarch/_src/actor/host_mesh.py +125 -0
- monarch/_src/actor/logging.py +94 -0
- monarch/_src/actor/pickle.py +25 -0
- monarch/_src/actor/proc_mesh.py +423 -156
- monarch/_src/actor/python_extension_methods.py +90 -0
- monarch/_src/actor/shape.py +8 -1
- monarch/_src/actor/source_loader.py +45 -0
- monarch/_src/actor/telemetry/__init__.py +172 -0
- monarch/_src/actor/telemetry/rust_span_tracing.py +6 -39
- monarch/_src/debug_cli/__init__.py +7 -0
- monarch/_src/debug_cli/debug_cli.py +43 -0
- monarch/_src/tensor_engine/rdma.py +64 -9
- monarch/_testing.py +1 -3
- monarch/actor/__init__.py +24 -4
- monarch/common/_C.so +0 -0
- monarch/common/device_mesh.py +14 -0
- monarch/common/future.py +10 -0
- monarch/common/remote.py +14 -25
- monarch/common/tensor.py +12 -0
- monarch/debug_cli/__init__.py +7 -0
- monarch/debug_cli/__main__.py +12 -0
- monarch/fetch.py +2 -2
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +4 -2
- monarch/mesh_controller.py +34 -14
- monarch/monarch_controller +0 -0
- monarch/tools/colors.py +25 -0
- monarch/tools/commands.py +42 -7
- monarch/tools/components/hyperactor.py +1 -1
- monarch/tools/config/__init__.py +31 -4
- monarch/tools/config/defaults.py +13 -3
- monarch/tools/config/environment.py +45 -0
- monarch/tools/config/workspace.py +165 -0
- monarch/tools/mesh_spec.py +2 -0
- monarch/utils/__init__.py +9 -0
- monarch/utils/utils.py +78 -0
- tests/error_test_binary.py +5 -3
- tests/python_actor_test_binary.py +52 -0
- tests/test_actor_error.py +142 -14
- tests/test_alloc.py +1 -1
- tests/test_allocator.py +59 -72
- tests/test_debugger.py +639 -45
- tests/test_env_before_cuda.py +4 -4
- tests/test_mesh_trait.py +38 -0
- tests/test_python_actors.py +965 -75
- tests/test_rdma.py +7 -6
- tests/test_tensor_engine.py +6 -6
- {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/METADATA +82 -4
- {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/RECORD +63 -47
- {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/WHEEL +0 -0
- {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/entry_points.txt +0 -0
- {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/licenses/LICENSE +0 -0
- {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/top_level.txt +0 -0
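Most of the churn in tests/test_python_actors.py below comes from the switch to the host-mesh spawning API (`this_host()` / `fake_in_process_host()` combined with `spawn_procs(per_host=...)`). A minimal sketch of that pattern, pieced together from the new test code; the `Counter` actor here is illustrative rather than the exact class defined in the test file:

    from monarch.actor import Actor, endpoint, this_host

    class Counter(Actor):
        def __init__(self, start: int) -> None:
            self._value = start

        @endpoint
        def incr(self) -> None:
            self._value += 1

        @endpoint
        def value(self) -> int:
            return self._value

    procs = this_host().spawn_procs(per_host={"gpus": 2})   # one proc per "gpu" slot on this host
    counters = procs.spawn("counter", Counter, 3)           # spawn the actor on every proc in the mesh
    counters.incr.broadcast()                                # fire-and-forget cast to all ranks
    print(list(counters.value.call().get()))                 # gather one value per rank
    procs.stop().get()                                       # tear the proc mesh down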
tests/test_python_actors.py
CHANGED
@@ -5,25 +5,45 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
+
 import asyncio
+import ctypes
+import gc
+import importlib.resources
 import logging
 import operator
 import os
+import re
+import subprocess
 import sys
 import tempfile
 import threading
 import time
 import unittest
-
+import unittest.mock
 from types import ModuleType
-from typing import cast
+from typing import cast, Tuple
 
 import pytest
 
 import torch
+from monarch._rust_bindings.monarch_hyperactor.actor import (
+    PythonMessage,
+    PythonMessageKind,
+)
+from monarch._rust_bindings.monarch_hyperactor.mailbox import (
+    PortId,
+    PortRef,
+    UndeliverableMessageEnvelope,
+)
+from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
 from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
 
-from monarch._src.actor.actor_mesh import
+from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
+from monarch._src.actor.allocator import AllocHandle
+from monarch._src.actor.future import Future
+from monarch._src.actor.host_mesh import create_local_host_mesh, fake_in_process_host
+from monarch._src.actor.proc_mesh import ProcMesh
 
 from monarch.actor import (
     Accumulator,
@@ -32,10 +52,10 @@ from monarch.actor import (
     current_rank,
     current_size,
     endpoint,
-
-
-    proc_mesh,
+    this_host,
+    this_proc,
 )
+from monarch.tools.config import defaults
 from typing_extensions import assert_type
 
 
@@ -68,8 +88,9 @@ class Indirect(Actor):
         return await c.value.choose()
 
 
+@pytest.mark.timeout(60)
 async def test_choose():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     v = await proc.spawn("counter", Counter, 3)
     i = await proc.spawn("indirect", Indirect)
     v.incr.broadcast()
@@ -86,8 +107,9 @@ async def test_choose():
     assert result2 == result3
 
 
+@pytest.mark.timeout(60)
 async def test_stream():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     v = await proc.spawn("counter2", Counter, 3)
     v.incr.broadcast()
 
@@ -102,46 +124,50 @@ class To(Actor):
 
 class From(Actor):
     @endpoint
-    async def
+    async def fetch(self, to: To):
         return [await x for x in to.whoami.stream()]
 
 
+@pytest.mark.timeout(60)
 async def test_mesh_passed_to_mesh():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     f = await proc.spawn("from", From)
     t = await proc.spawn("to", To)
-    all = [y for x in f.
+    all = [y for x in f.fetch.stream(t) for y in await x]
     assert len(all) == 4
     assert all[0] != all[1]
 
 
+@pytest.mark.timeout(60)
 async def test_mesh_passed_to_mesh_on_different_proc_mesh():
-    proc = await
-    proc2 = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
+    proc2 = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     f = await proc.spawn("from", From)
     t = await proc2.spawn("to", To)
-    all = [y for x in f.
+    all = [y for x in f.fetch.stream(t) for y in await x]
     assert len(all) == 4
     assert all[0] != all[1]
 
 
-
-
-
+@pytest.mark.timeout(60)
+def test_actor_slicing():
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
+    proc2 = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
 
-    f =
-    t =
+    f = proc.spawn("from", From)
+    t = proc2.spawn("to", To)
 
-    assert
+    assert t.slice(gpus=0).whoami.call().get() != t.slice(gpus=1).whoami.call().get()
 
-    result = [y for x in f.
+    result = [y for x in f.fetch.stream(t.slice(gpus=0)) for y in x.get()]
     assert len(result) == 2
 
     assert result[0] == result[1]
 
 
+@pytest.mark.timeout(60)
 async def test_aggregate():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     counter = await proc.spawn("counter", Counter, 1)
     counter.incr.broadcast()
     acc = Accumulator(counter.value, 0, operator.add)
@@ -154,9 +180,14 @@ class RunIt(Actor):
     async def run(self, fn):
         return fn()
 
+    @endpoint
+    async def return_current_rank_str(self):
+        return str(current_rank())
+
 
+@pytest.mark.timeout(60)
 async def test_rank_size():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     r = await proc.spawn("runit", RunIt)
 
     acc = Accumulator(r.run, 0, operator.add)
@@ -165,35 +196,50 @@ async def test_rank_size():
     assert 4 == await acc.accumulate(lambda: current_size()["gpus"])
 
 
+@pytest.mark.timeout(60)
+async def test_rank_string():
+    proc = fake_in_process_host().spawn_procs(per_host={"hosts": 1, "gpus": 2})
+    r = proc.spawn("runit", RunIt)
+    vm = r.return_current_rank_str.call().get()
+    r0 = vm.flatten("r").slice(r=0).item()
+    r1 = vm.flatten("r").slice(r=1).item()
+    assert r0 == "{'hosts': 0/1, 'gpus': 0/2}"
+    assert r1 == "{'hosts': 0/1, 'gpus': 1/2}"
+
+
 class SyncActor(Actor):
     @endpoint
     def sync_endpoint(self, a_counter: Counter):
         return a_counter.value.choose().get()
 
 
+@pytest.mark.timeout(60)
 async def test_sync_actor():
-    proc = await
+    proc = await fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     a = await proc.spawn("actor", SyncActor)
     c = await proc.spawn("counter", Counter, 5)
     r = await a.sync_endpoint.choose(c)
     assert r == 5
 
 
-
-
+@pytest.mark.timeout(60)
+def test_sync_actor_sync_client() -> None:
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     a = proc.spawn("actor", SyncActor).get()
     c = proc.spawn("counter", Counter, 5).get()
     r = a.sync_endpoint.choose(c).get()
     assert r == 5
 
 
+@pytest.mark.timeout(60)
 def test_proc_mesh_size() -> None:
-    proc =
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     assert 2 == proc.size("gpus")
 
 
+@pytest.mark.timeout(60)
 def test_rank_size_sync() -> None:
-    proc =
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     r = proc.spawn("runit", RunIt).get()
 
     acc = Accumulator(r.run, 0, operator.add)
@@ -201,8 +247,9 @@ def test_rank_size_sync() -> None:
     assert 4 == acc.accumulate(lambda: current_size()["gpus"]).get()
 
 
+@pytest.mark.timeout(60)
 def test_accumulate_sync() -> None:
-    proc =
+    proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
     counter = proc.spawn("counter", Counter, 1).get()
     counter.incr.broadcast()
     acc = Accumulator(counter.value, 0, operator.add)
@@ -216,8 +263,9 @@ class CastToCounter(Actor):
         return list(c.value.call().get())
 
 
+@pytest.mark.timeout(60)
 def test_value_mesh() -> None:
-    proc =
+    proc = fake_in_process_host().spawn_procs(per_host={"hosts": 1, "gpus": 2})
     counter = proc.spawn("counter", Counter, 0).get()
     counter.slice(hosts=0, gpus=1).incr.broadcast()
     x = counter.value.call().get()
@@ -228,7 +276,18 @@ def test_value_mesh() -> None:
     assert list(x) == n.slice(gpus=0).doit.call_one(counter).get()
 
 
+@pytest.mark.timeout(60)
 def test_rust_binding_modules_correct() -> None:
+    """
+    This tests that rust bindings will survive pickling correctly.
+
+    To correctly define a rust binding, either
+
+    (1) Set its module to "monarch._rust_bindings.rust_crate.rust_module",
+        and make sure it is registered in monarch_extension/lib.rs
+    (2) Set its module to some existing python file, and use @rust_struct to install
+        the rust struct in that file and patch in any python extension methods.
+    """
     import monarch._rust_bindings as bindings
 
     def check(module, path):
@@ -238,14 +297,16 @@ def test_rust_binding_modules_correct() -> None:
             if isinstance(value, ModuleType):
                 check(value, f"{path}.{name}")
             elif hasattr(value, "__module__"):
-
-
+                value_module = importlib.import_module(value.__module__)
+                resolved_value = getattr(value_module, value.__name__)
+                assert value is resolved_value
 
     check(bindings, "monarch._rust_bindings")
 
 
+@pytest.mark.timeout(60)
 def test_proc_mesh_liveness() -> None:
-    mesh =
+    mesh = this_host().spawn_procs(per_host={"gpus": 2})
     counter = mesh.spawn("counter", Counter, 1).get()
     del mesh
     # Give some time for the mesh to have been shut down.
@@ -270,7 +331,7 @@ class TLSActor(Actor):
         self.local.value += 1
 
     @endpoint
-    def
+    def get_value(self):
         return self.local.value
 
     @endpoint
@@ -278,16 +339,17 @@ class TLSActor(Actor):
         return self.local.value
 
 
+@pytest.mark.timeout(60)
 async def test_actor_tls() -> None:
     """Test that thread-local state is respected."""
-    pm =
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
     am = await pm.spawn("tls", TLSActor)
     await am.increment.call_one()
     await am.increment_async.call_one()
     await am.increment.call_one()
     await am.increment_async.call_one()
 
-    assert 4 == await am.
+    assert 4 == await am.get_value.call_one()
     assert 4 == await am.get_async.call_one()
 
 
@@ -303,20 +365,21 @@ class TLSActorFullSync(Actor):
         self.local.value += 1
 
     @endpoint
-    def
+    def get_value(self):
         return self.local.value
 
 
+@pytest.mark.timeout(60)
 async def test_actor_tls_full_sync() -> None:
     """Test that thread-local state is respected."""
-    pm =
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
    am = await pm.spawn("tls", TLSActorFullSync)
     await am.increment.call_one()
     await am.increment.call_one()
     await am.increment.call_one()
     await am.increment.call_one()
 
-    assert 4 == await am.
+    assert 4 == await am.get_value.call_one()
 
 
 class AsyncActor(Actor):
@@ -333,10 +396,10 @@ class AsyncActor(Actor):
         self.should_exit = True
 
 
-@pytest.mark.timeout(
+@pytest.mark.timeout(30)
 async def test_async_concurrency():
     """Test that async endpoints will be processed concurrently."""
-    pm = await
+    pm = await this_host().spawn_procs()
     am = await pm.spawn("async", AsyncActor)
     fut = am.sleep.call()
     # This call should go through and exit the sleep loop, as long as we are
@@ -442,19 +505,35 @@ async def awaitit(f):
 
 
 class Printer(Actor):
-    def __init__(self):
-        self.
-        self.logger.setLevel(INFO)
+    def __init__(self) -> None:
+        self._logger: logging.Logger = logging.getLogger()
 
     @endpoint
-    async def print(self, content: str):
-        print(f"{
+    async def print(self, content: str) -> None:
+        print(f"{content}", flush=True)
+        sys.stdout.flush()
+        sys.stderr.flush()
 
     @endpoint
-    async def log(self, content: str):
-        self.
-
-
+    async def log(self, content: str) -> None:
+        self._logger.error(f"{content}")
+        for handler in self._logger.handlers:
+            handler.flush()
+        sys.stdout.flush()
+        sys.stderr.flush()
+
+    def _handle_undeliverable_message(
+        self, message: UndeliverableMessageEnvelope
+    ) -> bool:
+        # Don't throw an error on undeliverable messages. This actor is used in a test for
+        # stopping actor meshes, and if we throw an error here then there is a race between
+        # the asserted error that the mesh was stopped and the supervision error that a message
+        # wasn't delivered.
+        self._logger.error(f"Ignoring undeliverable message: {message}")
+        return True
+
+
+@pytest.mark.timeout(60)
 async def test_actor_log_streaming() -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
@@ -482,19 +561,247 @@ async def test_actor_log_streaming() -> None:
             sys.stderr = stderr_file
 
             try:
-                pm =
+                pm = this_host().spawn_procs(per_host={"gpus": 2})
                 am = await pm.spawn("printer", Printer)
 
-
-                await
+                # Disable streaming logs to client
+                await pm.logging_option(
+                    stream_to_client=False, aggregate_window_sec=None
+                )
+                await asyncio.sleep(1)
+
+                # These should not be streamed to client initially
+                for _ in range(5):
+                    await am.print.call("no print streaming")
+                    await am.log.call("no log streaming")
+                await asyncio.sleep(1)
+
+                # Enable streaming logs to client
+                await pm.logging_option(
+                    stream_to_client=True, aggregate_window_sec=1, level=logging.FATAL
+                )
+                # Give it some time to reflect
+                await asyncio.sleep(1)
+
+                # These should be streamed to client
+                for _ in range(5):
+                    await am.print.call("has print streaming")
+                    await am.log.call("no log streaming due to level mismatch")
+                await asyncio.sleep(1)
+
+                # Enable streaming logs to client
+                await pm.logging_option(
+                    stream_to_client=True, aggregate_window_sec=1, level=logging.ERROR
+                )
+                # Give it some time to reflect
+                await asyncio.sleep(1)
+
+                # These should be streamed to client
+                for _ in range(5):
+                    await am.print.call("has print streaming too")
+                    await am.log.call("has log streaming as level matched")
+
+                await pm.stop()
+
+                # Flush all outputs
+                stdout_file.flush()
+                stderr_file.flush()
+                os.fsync(stdout_file.fileno())
+                os.fsync(stderr_file.fileno())
+
+            finally:
+                # Restore Python's sys.stdout/stderr
+                sys.stdout = original_sys_stdout
+                sys.stderr = original_sys_stderr
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+                os.dup2(original_stderr_fd, 2)
+
+        # Read the captured output
+        with open(stdout_path, "r") as f:
+            stdout_content = f.read()
+
+        with open(stderr_path, "r") as f:
+            stderr_content = f.read()
+
+        # Clean up temp files
+        os.unlink(stdout_path)
+        os.unlink(stderr_path)
+
+        # Assertions on the captured output
+        # Has a leading context so we can distinguish between streamed log and
+        # the log directly printed by the child processes as they share the same stdout/stderr
+        assert not re.search(
+            r"similar log lines.*no print streaming", stdout_content
+        ), stdout_content
+        assert not re.search(
+            r"similar log lines.*no print streaming", stderr_content
+        ), stderr_content
+        assert not re.search(
+            r"similar log lines.*no log streaming", stdout_content
+        ), stdout_content
+        assert not re.search(
+            r"similar log lines.*no log streaming", stderr_content
+        ), stderr_content
+        assert not re.search(
+            r"similar log lines.*no log streaming due to level mismatch", stdout_content
+        ), stdout_content
+        assert not re.search(
+            r"similar log lines.*no log streaming due to level mismatch", stderr_content
+        ), stderr_content
+
+        assert re.search(
+            r"similar log lines.*has print streaming", stdout_content
+        ), stdout_content
+        assert not re.search(
+            r"similar log lines.*has print streaming", stderr_content
+        ), stderr_content
+        assert re.search(
+            r"similar log lines.*has print streaming too", stdout_content
+        ), stdout_content
+        assert not re.search(
+            r"similar log lines.*has print streaming too", stderr_content
+        ), stderr_content
+        assert not re.search(
+            r"similar log lines.*log streaming as level matched", stdout_content
+        ), stdout_content
+        assert re.search(
+            r"similar log lines.*log streaming as level matched",
+            stderr_content,
+        ), stderr_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.dup2(original_stderr_fd, 2)
+            os.close(original_stdout_fd)
+            os.close(original_stderr_fd)
+        except OSError:
+            pass
+
+
+@pytest.mark.timeout(120)
+async def test_alloc_based_log_streaming() -> None:
+    """Test both AllocHandle.stream_logs = False and True cases."""
+
+    async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
+        # Save original file descriptors
+        original_stdout_fd = os.dup(1)  # stdout
+
+        try:
+            # Create temporary files to capture output
+            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+                stdout_path = stdout_file.name
+                os.dup2(stdout_file.fileno(), 1)
+                original_sys_stdout = sys.stdout
+                sys.stdout = stdout_file
+
+                try:
+                    # Create proc mesh with custom stream_logs setting
+                    host_mesh = create_local_host_mesh()
+                    alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
+
+                    # Override the stream_logs setting
+                    custom_alloc_handle = AllocHandle(
+                        alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
+                    )
+
+                    pm = ProcMesh.from_alloc(custom_alloc_handle)
+                    am = await pm.spawn("printer", Printer)
+
+                    await pm.initialized
+
+                    for _ in range(5):
+                        await am.print.call(f"{test_name} print streaming")
+
+                    await pm.stop()
+
+                    # Flush all outputs
+                    stdout_file.flush()
+                    os.fsync(stdout_file.fileno())
+
+                finally:
+                    # Restore Python's sys.stdout
+                    sys.stdout = original_sys_stdout
+
+                    # Restore original file descriptors
+                    os.dup2(original_stdout_fd, 1)
+
+            # Read the captured output
+            with open(stdout_path, "r") as f:
+                stdout_content = f.read()
+
+            # Clean up temp files
+            os.unlink(stdout_path)
+
+            if not stream_logs:
+                # When stream_logs=False, logs should not be streamed to client
+                assert not re.search(
+                    rf"similar log lines.*{test_name} print streaming", stdout_content
+                ), f"stream_logs=True case: {stdout_content}"
+                assert re.search(
+                    rf"{test_name} print streaming", stdout_content
+                ), f"stream_logs=True case: {stdout_content}"
+            else:
+                # When stream_logs=True, logs should be streamed to client (no aggregation by default)
+                assert re.search(
+                    rf"similar log lines.*{test_name} print streaming", stdout_content
+                ), f"stream_logs=False case: {stdout_content}"
+                assert not re.search(
+                    rf"\[[0-9]\]{test_name} print streaming", stdout_content
+                ), f"stream_logs=False case: {stdout_content}"
+
+        finally:
+            # Ensure file descriptors are restored even if something goes wrong
+            try:
+                os.dup2(original_stdout_fd, 1)
+                os.close(original_stdout_fd)
+            except OSError:
+                pass
+
+    # Test both cases
+    await test_stream_logs_case(False, "stream_logs_false")
+    await test_stream_logs_case(True, "stream_logs_true")
+
+
+@pytest.mark.timeout(60)
+async def test_logging_option_defaults() -> None:
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+    original_stderr_fd = os.dup(2)  # stderr
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(
+            mode="w+", delete=False
+        ) as stdout_file, tempfile.NamedTemporaryFile(
+            mode="w+", delete=False
+        ) as stderr_file:
+            stdout_path = stdout_file.name
+            stderr_path = stderr_file.name
+
+            # Redirect file descriptors to our temp files
+            # This will capture both Python and Rust output
+            os.dup2(stdout_file.fileno(), 1)
+            os.dup2(stderr_file.fileno(), 2)
+
+            # Also redirect Python's sys.stdout/stderr for completeness
+            original_sys_stdout = sys.stdout
+            original_sys_stderr = sys.stderr
+            sys.stdout = stdout_file
+            sys.stderr = stderr_file
 
-
+            try:
+                pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                am = await pm.spawn("printer", Printer)
 
-
-
+                for _ in range(5):
+                    await am.print.call("print streaming")
+                    await am.log.call("log streaming")
 
-
-        time.sleep(5)
+                await pm.stop()
 
                 # Flush all outputs
                 stdout_file.flush()
@@ -515,16 +822,27 @@ async def test_actor_log_streaming() -> None:
         with open(stdout_path, "r") as f:
             stdout_content = f.read()
 
+        with open(stderr_path, "r") as f:
+            stderr_content = f.read()
+
         # Clean up temp files
         os.unlink(stdout_path)
         os.unlink(stderr_path)
 
-        #
-
-
-
-        assert "
-
+        # Assertions on the captured output
+        assert not re.search(
+            r"similar log lines.*print streaming", stdout_content
+        ), stdout_content
+        assert re.search(r"print streaming", stdout_content), stdout_content
+        assert not re.search(
+            r"similar log lines.*print streaming", stderr_content
+        ), stderr_content
+        assert not re.search(
+            r"similar log lines.*log streaming", stdout_content
+        ), stdout_content
+        assert not re.search(
+            r"similar log lines.*log streaming", stderr_content
+        ), stderr_content
 
     finally:
         # Ensure file descriptors are restored even if something goes wrong
@@ -537,6 +855,378 @@ async def test_actor_log_streaming() -> None:
             pass
 
 
+# oss_skip: pytest keeps complaining about mocking get_ipython module
+@pytest.mark.oss_skip
+@pytest.mark.timeout(180)
+async def test_flush_logs_ipython() -> None:
+    """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                # Mock IPython environment
+                class MockExecutionResult:
+                    pass
+
+                class MockEvents:
+                    def __init__(self):
+                        self.callbacks = {}
+                        self.registers = 0
+                        self.unregisters = 0
+
+                    def register(self, event_name, callback):
+                        if event_name not in self.callbacks:
+                            self.callbacks[event_name] = []
+                        self.callbacks[event_name].append(callback)
+                        self.registers += 1
+
+                    def unregister(self, event_name, callback):
+                        if event_name not in self.callbacks:
+                            raise ValueError(f"Event {event_name} not registered")
+                        assert callback in self.callbacks[event_name]
+                        self.callbacks[event_name].remove(callback)
+                        self.unregisters += 1
+
+                    def trigger(self, event_name, *args, **kwargs):
+                        if event_name in self.callbacks:
+                            for callback in self.callbacks[event_name]:
+                                callback(*args, **kwargs)
+
+                class MockIPython:
+                    def __init__(self):
+                        self.events = MockEvents()
+
+                mock_ipython = MockIPython()
+
+                with unittest.mock.patch(
+                    "monarch._src.actor.logging.get_ipython",
+                    lambda: mock_ipython,
+                ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
+                    # Make sure we can register and unregister callbacks
+                    for i in range(3):
+                        pm1 = await this_host().spawn_procs(per_host={"gpus": 2})
+                        pm2 = await this_host().spawn_procs(per_host={"gpus": 2})
+                        am1 = await pm1.spawn("printer", Printer)
+                        am2 = await pm2.spawn("printer", Printer)
+
+                        # Set aggregation window to ensure logs are buffered
+                        await pm1.logging_option(
+                            stream_to_client=True, aggregate_window_sec=600
+                        )
+                        await pm2.logging_option(
+                            stream_to_client=True, aggregate_window_sec=600
+                        )
+                        # TODO: fix the following assertion
+                        # assert mock_ipython.events.unregisters == 2 * i
+
+                        # _get_controller_controller() spawns an extra local mesh
+                        # but log streaming is disabled so it doesn't hurt
+                        assert mock_ipython.events.registers == 1 + 2 * (i + 1)
+                        await asyncio.sleep(1)
+
+                        # Generate some logs that will be aggregated
+                        for _ in range(5):
+                            await am1.print.call("ipython1 test log")
+                            await am2.print.call("ipython2 test log")
+
+                        # Trigger the post_run_cell event which should flush logs
+                        mock_ipython.events.trigger(
+                            "post_run_cell", MockExecutionResult()
+                        )
+
+                        # Flush all outputs
+                        stdout_file.flush()
+                        os.fsync(stdout_file.fileno())
+
+                    gc.collect()
+
+                    # Same as above, _get_controller_controller() spawns an extra local mesh
+                    assert mock_ipython.events.registers == 7
+                    # There are many objects still taking refs
+                    # TODO: fix the following assertion
+                    assert mock_ipython.events.unregisters == 0
+                    assert len(mock_ipython.events.callbacks["post_run_cell"]) == 7
+            finally:
+                # Restore Python's sys.stdout
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+        # Read the captured output
+        with open(stdout_path, "r") as f:
+            stdout_content = f.read()
+
+        # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils
+
+        # Clean up temp files
+        os.unlink(stdout_path)
+
+        # Verify that logs were flushed when the post_run_cell event was triggered
+        # We should see the aggregated logs in the output
+        assert (
+            len(
+                re.findall(
+                    r"\[10 similar log lines\].*ipython1 test log", stdout_content
+                )
+            )
+            == 3
+        ), stdout_content
+
+        assert (
+            len(
+                re.findall(
+                    r"\[10 similar log lines\].*ipython2 test log", stdout_content
+                )
+            )
+            == 3
+        ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
+# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
+@pytest.mark.oss_skip
+async def test_flush_logs_fast_exit() -> None:
+    # We use a subprocess to run the test so we can handle the flushed logs at the end.
+    # Otherwise, it is hard to restore the original stdout/stderr.
+
+    test_bin = importlib.resources.files(str(__package__)).joinpath("test_bin")
+
+    # Run the binary in a separate process and capture stdout and stderr
+    cmd = [str(test_bin), "flush-logs"]
+
+    process = subprocess.run(cmd, capture_output=True, timeout=60, text=True)
+
+    # Check if the process ended without error
+    if process.returncode != 0:
+        raise RuntimeError(f"{cmd} ended with error code {process.returncode}. ")
+
+    # Assertions on the captured output, 160 = 32 procs * 5 logs per proc
+    # 32 and 5 are specified in the test_bin flush-logs.
+    assert (
+        len(
+            re.findall(
+                r"160 similar log lines.*has print streaming",
+                process.stdout,
+            )
+        )
+        == 1
+    ), process.stdout
+
+
+@pytest.mark.timeout(60)
+async def test_flush_on_disable_aggregation() -> None:
+    """Test that logs are flushed when disabling aggregation.
+
+    This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."
+    """
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                am = await pm.spawn("printer", Printer)
+
+                # Set a long aggregation window to ensure logs aren't flushed immediately
+                await pm.logging_option(stream_to_client=True, aggregate_window_sec=60)
+
+                # Generate some logs that will be aggregated but not flushed immediately
+                for _ in range(5):
+                    await am.print.call("aggregated log line")
+                await asyncio.sleep(1)
+
+                # Now disable aggregation - this should trigger an immediate flush
+                await pm.logging_option(
+                    stream_to_client=True, aggregate_window_sec=None
+                )
+
+                # Wait a bit to ensure logs are collected
+                await asyncio.sleep(1)
+                for _ in range(5):
+                    await am.print.call("single log line")
+
+                await pm.stop()
+
+                # Flush all outputs
+                stdout_file.flush()
+                os.fsync(stdout_file.fileno())
+
+            finally:
+                # Restore Python's sys.stdout
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+        # Read the captured output
+        with open(stdout_path, "r") as f:
+            stdout_content = f.read()
+
+        # Clean up temp files
+        os.unlink(stdout_path)
+
+        # Verify that logs were flushed when aggregation was disabled
+        # We should see the aggregated logs in the output
+        # 10 = 5 log lines * 2 procs
+        assert re.search(
+            r"\[10 similar log lines\].*aggregated log line", stdout_content
+        ), stdout_content
+
+        # No aggregated single log lines
+        assert not re.search(
+            r"similar log lines.*single log line", stdout_content
+        ), stdout_content
+
+        # 10 = 5 log lines * 2 procs
+        assert (
+            len(re.findall(r"\[.* [0-9]+\] single log line", stdout_content)) == 10
+        ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
+@pytest.mark.timeout(120)
+async def test_multiple_ongoing_flushes_no_deadlock() -> None:
+    """
+    The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
+    Because now a flush call is purely sync, it is very easy to get into a deadlock.
+    So we assert the last flush call will not get into such a state.
+    """
+    pm = this_host().spawn_procs(per_host={"gpus": 4})
+    am = pm.spawn("printer", Printer)
+
+    # Generate some logs that will be aggregated but not flushed immediately
+    for _ in range(10):
+        await am.print.call("aggregated log line")
+
+    log_mesh = pm._logging_manager._logging_mesh_client
+    assert log_mesh is not None
+    futures = []
+    for _ in range(5):
+        # FIXME: the order of futures doesn't necessarily mean the order of flushes due to the async nature.
+        await asyncio.sleep(0.1)
+        futures.append(Future(coro=log_mesh.flush().spawn().task()))
+
+    # The last flush should not block
+    futures[-1].get()
+
+
+@pytest.mark.timeout(60)
+async def test_adjust_aggregation_window() -> None:
+    """Test that the flush deadline is updated when the aggregation window is adjusted.
+
+    This tests the corner case: "This can happen if the user has adjusted the aggregation window."
+    """
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                pm = await this_host().spawn_procs(per_host={"gpus": 2})
+                am = await pm.spawn("printer", Printer)
+
+                # Set a long aggregation window initially
+                await pm.logging_option(stream_to_client=True, aggregate_window_sec=100)
+
+                # Generate some logs that will be aggregated
+                for _ in range(3):
+                    await am.print.call("first batch of logs")
+                await asyncio.sleep(1)
+
+                # Now adjust to a shorter window - this should update the flush deadline
+                await pm.logging_option(stream_to_client=True, aggregate_window_sec=2)
+
+                # Generate more logs
+                for _ in range(3):
+                    await am.print.call("second batch of logs")
+
+                await pm.stop()
+
+                # Flush all outputs
+                stdout_file.flush()
+                os.fsync(stdout_file.fileno())
+
+            finally:
+                # Restore Python's sys.stdout/stderr
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+        # Read the captured output
+        with open(stdout_path, "r") as f:
+            stdout_content = f.read()
+
+        # Clean up temp files
+        os.unlink(stdout_path)
+
+        # Verify that logs were flushed when the aggregation window was adjusted
+        # We should see both batches of logs in the output
+        assert re.search(
+            r"\[6 similar log lines\].*first batch of logs", stdout_content
+        ), stdout_content
+
+        assert re.search(
+            r"similar log lines.*second batch of logs", stdout_content
+        ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
 class SendAlot(Actor):
     @endpoint
     async def send(self, port: Port[int]):
@@ -544,10 +1234,11 @@ class SendAlot(Actor):
             port.send(i)
 
 
-
-
+@pytest.mark.timeout(60)
+def test_port_as_argument() -> None:
+    proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 1})
     s = proc_mesh.spawn("send_alot", SendAlot).get()
-    send, recv =
+    send, recv = Channel[int].open()
 
     s.send.broadcast(send)
 
@@ -555,14 +1246,14 @@ def test_port_as_argument():
     assert i == recv.recv().get()
 
 
-@pytest.mark.timeout(
+@pytest.mark.timeout(30)
 async def test_same_actor_twice() -> None:
-    pm =
-    await pm.spawn("dup", Counter, 0)
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
+    await pm.spawn("dup", Counter, 0).initialized
 
     # The second spawn with the same name should fail with a specific error
     with pytest.raises(Exception) as exc_info:
-        await pm.spawn("dup", Counter, 0)
+        await pm.spawn("dup", Counter, 0).initialized
 
     # Assert that the error message contains the expected text about duplicate actor name
     error_msg = str(exc_info.value)
@@ -571,23 +1262,81 @@ async def test_same_actor_twice() -> None:
     ), f"Expected error message about duplicate actor name, got: {error_msg}"
 
 
+class LsActor(Actor):
+    def __init__(self, workspace: str):
+        self.workspace = workspace
+
+    @endpoint
+    async def ls(self) -> list[str]:
+        return os.listdir(self.workspace)
+
+
+async def test_sync_workspace() -> None:
+    # create two workspaces: one for local and one for remote
+    with tempfile.TemporaryDirectory() as workspace_src, tempfile.TemporaryDirectory() as workspace_dst:
+
+        def bootstrap_WORKSPACE_DIR() -> None:
+            import os
+
+            os.environ["WORKSPACE_DIR"] = workspace_dst
+
+        pm = this_host().spawn_procs(
+            per_host={"gpus": 1}, bootstrap=bootstrap_WORKSPACE_DIR
+        )
+
+        config = defaults.config("slurm", workspace_src)
+        await pm.sync_workspace(workspace=config.workspace, auto_reload=True)
+
+        # no file in remote workspace initially
+        am = await pm.spawn("ls", LsActor, workspace_dst)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 0
+
+        # write a file to local workspace
+        file_path = os.path.join(workspace_src, "new_file")
+        with open(file_path, "w") as f:
+            f.write("hello world")
+            f.flush()
+
+        # force a sync and it should populate on the dst workspace
+        await pm.sync_workspace(config.workspace, auto_reload=True)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 1
+            assert item[1][0] == "new_file"
+            file_path = os.path.join(workspace_dst, item[1][0])
+            with open(file_path, "r") as f:
+                assert f.readline() == "hello world"
+
+    # sanity check
+    assert "WORKSPACE_DIR" not in os.environ, "test leaves env var side-effects!"
+
+
 class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
     async def test_actor_mesh_stop(self) -> None:
-        pm =
+        pm = this_host().spawn_procs(per_host={"gpus": 2})
         am_1 = await pm.spawn("printer", Printer)
         am_2 = await pm.spawn("printer2", Printer)
         await am_1.print.call("hello 1")
         await am_1.log.call("hello 2")
-        await cast(
+        await cast(ActorMesh, am_1).stop()
 
         with self.assertRaisesRegex(
-            RuntimeError, expected_regex="`
+            RuntimeError, expected_regex="`PythonActorMesh` has already been stopped"
         ):
             await am_1.print.call("hello 1")
 
         await am_2.print.call("hello 3")
         await am_2.log.call("hello 4")
 
+        await pm.stop()
+
+    async def test_proc_mesh_stop_after_actor_mesh_stop(self) -> None:
+        pm = this_host().spawn_procs(per_host={"gpus": 2})
+        am = await pm.spawn("printer", Printer)
+
+        await cast(ActorMesh, am).stop()
+        await pm.stop()
+
 
 class PortedActor(Actor):
     @endpoint(explicit_response_port=True)
@@ -595,8 +1344,9 @@ class PortedActor(Actor):
         port.send(3 + b)
 
 
+@pytest.mark.timeout(60)
 def test_ported_actor():
-    proc_mesh =
+    proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 1}).get()
     a = proc_mesh.spawn("port_actor", PortedActor).get()
     assert 5 == a.add.call_one(2).get()
 
@@ -610,5 +1360,145 @@ async def consume():
     assert r == (7, 2, 3)
 
 
+@pytest.mark.timeout(60)
 def test_python_task_tuple() -> None:
     PythonTask.from_coroutine(consume()).block_on()
+
+
+def test_select_result() -> None:
+    def s(t):
+        time.sleep(t)
+        return t
+
+    a = PythonTask.spawn_blocking(lambda: s(4))
+    b = PythonTask.spawn_blocking(lambda: s(0))
+    r = PythonTask.select_one([a.task(), b.task()]).block_on()
+    assert r == (0, 1)
+
+
+def test_mesh_len():
+    proc_mesh = fake_in_process_host().spawn_procs(per_host={"gpus": 12})
+    s = proc_mesh.spawn("sync_actor", SyncActor).get()
+    assert 12 == len(s)
+
+
+class UndeliverableMessageReceiver(Actor):
+    def __init__(self):
+        self._messages = asyncio.Queue()
+
+    @endpoint
+    async def receive_undeliverable(
+        self, sender: ActorId, dest: PortId, error_msg: str
+    ) -> None:
+        await self._messages.put((sender, dest, error_msg))
+
+    @endpoint
+    async def get_messages(self) -> Tuple[ActorId, PortId, str]:
+        return await self._messages.get()
+
+
+class UndeliverableMessageSender(Actor):
+    @endpoint
+    def send_undeliverable(self) -> None:
+        mailbox = context().actor_instance._mailbox
+        port_id = PortId(
+            actor_id=ActorId(
+                world_name=mailbox.actor_id.world_name, rank=0, actor_name="bogus"
+            ),
+            port=1234,
+        )
+        port_ref = PortRef(port_id)
+        port_ref.send(
+            mailbox,
+            PythonMessage(PythonMessageKind.Result(None), b"123"),
+        )
+
+
+class UndeliverableMessageSenderWithOverride(UndeliverableMessageSender):
+    def __init__(self, receiver: UndeliverableMessageReceiver):
+        self._receiver = receiver
+
+    def _handle_undeliverable_message(
+        self, message: UndeliverableMessageEnvelope
+    ) -> bool:
+        self._receiver.receive_undeliverable.call_one(
+            message.sender(), message.dest(), message.error_msg()
+        ).get()
+        return True
+
+
+@pytest.mark.timeout(60)
+async def test_undeliverable_message_with_override() -> None:
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
+    receiver = pm.spawn("undeliverable_receiver", UndeliverableMessageReceiver)
+    sender = pm.spawn(
+        "undeliverable_sender", UndeliverableMessageSenderWithOverride, receiver
+    )
+    sender.send_undeliverable.call().get()
+    sender, dest, error_msg = receiver.get_messages.call_one().get()
+    assert sender.actor_name == "undeliverable_sender"
+    assert dest.actor_id.actor_name == "bogus"
+    assert error_msg is not None
+    pm.stop().get()
+
+
+@pytest.mark.timeout(60)
+async def test_undeliverable_message_without_override() -> None:
+    pm = this_host().spawn_procs(per_host={"gpus": 1})
+    sender = pm.spawn("undeliverable_sender", UndeliverableMessageSender)
+    sender.send_undeliverable.call().get()
+    # Wait a few seconds to ensure that the undeliverable message is processed
+    # without crashing anything
+    await asyncio.sleep(5)
+    pm.stop().get()
+
+
+def test_this_and_that():
+    counter = this_proc().spawn("counter", Counter, 7)
+    assert 7 == counter.value.call_one().get()
+
+
+class ReceptorActor(Actor):
+    @endpoint
+    def status(self):
+        return 1
+
+
+async def test_things_survive_losing_python_reference() -> None:
+    """Test the slice_receptor_mesh function in LOCAL mode, verifying that setup methods are called."""
+
+    receptor = (
+        this_host()
+        .spawn_procs(per_host={"gpus": 1})
+        .spawn(
+            "receptor",
+            ReceptorActor,
+        )
+    )
+    receptor = receptor.slice(gpus=0)
+
+    await receptor.status.call()
+
+
+class IsInit(Actor):
+    @endpoint
+    def is_cuda_initialized(self) -> bool:
+        cuda = ctypes.CDLL("libcuda.so.1")
+        CUresult = ctypes.c_int
+        cuDeviceGetCount = cuda.cuDeviceGetCount
+        cuDeviceGetCount.argtypes = [ctypes.POINTER(ctypes.c_int)]
+        cuDeviceGetCount.restype = CUresult
+        count = ctypes.c_int()
+        result = cuDeviceGetCount(ctypes.byref(count))
+        CUDA_ERROR_NOT_INITIALIZED = 3
+        return result == CUDA_ERROR_NOT_INITIALIZED
+
+
+@pytest.mark.oss_skip
+def test_cuda_is_not_initialized_in_a_new_proc():
+    try:
+        ctypes.CDLL("libcuda.so.1")
+    except OSError:
+        pytest.skip("cannot find cuda")
+    proc = this_host().spawn_procs().spawn("is_init", IsInit)
+    assert not proc.is_cuda_initialized.call_one().get()