torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff shows the contents of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects changes between released package versions.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
tests/test_actor_error.py
ADDED
@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import importlib.resources
import subprocess

import pytest
from monarch.actor_mesh import Actor, ActorError, endpoint, send

from monarch.proc_mesh import proc_mesh


class ExceptionActor(Actor):
    @endpoint
    async def raise_exception(self) -> None:
        raise Exception("This is a test exception")

    @endpoint
    async def print_value(self, value) -> None:
        """Endpoint that takes a value and prints it."""
        print(f"Value received: {value}")
        return value


class ExceptionActorSync(Actor):
    @endpoint  # pyre-ignore
    def raise_exception(self) -> None:
        raise Exception("This is a test exception")


class BrokenPickleClass:
    """A class that can be configured to raise exceptions during pickling/unpickling."""

    def __init__(
        self,
        raise_on_getstate=False,
        raise_on_setstate=False,
        exception_message="Pickle error",
    ):
        self.raise_on_getstate = raise_on_getstate
        self.raise_on_setstate = raise_on_setstate
        self.exception_message = exception_message
        self.value = "test_value"

    def __getstate__(self):
        """Called when pickling the object."""
        if self.raise_on_getstate:
            raise RuntimeError(f"__getstate__ error: {self.exception_message}")
        return {
            "raise_on_getstate": self.raise_on_getstate,
            "raise_on_setstate": self.raise_on_setstate,
            "exception_message": self.exception_message,
            "value": self.value,
        }

    def __setstate__(self, state):
        """Called when unpickling the object."""
        if state.get("raise_on_setstate", False):
            raise RuntimeError(
                f"__setstate__ error: {state.get('exception_message', 'Unpickle error')}"
            )
        self.__dict__.update(state)


@pytest.mark.parametrize(
    "actor_class",
    [ExceptionActor, ExceptionActorSync],
)
@pytest.mark.parametrize("num_procs", [1, 2])
async def test_actor_exception(actor_class, num_procs):
    """
    Test that exceptions raised in actor endpoints are propagated to the client.
    """
    proc = await proc_mesh(gpus=num_procs)
    exception_actor = await proc.spawn("exception_actor", actor_class)

    with pytest.raises(ActorError, match="This is a test exception"):
        if num_procs == 1:
            await exception_actor.raise_exception.call_one()
        else:
            await exception_actor.raise_exception.call()


@pytest.mark.parametrize(
    "actor_class",
    [ExceptionActor, ExceptionActorSync],
)
@pytest.mark.parametrize("num_procs", [1, 2])
def test_actor_exception_sync(actor_class, num_procs):
    """
    Test that exceptions raised in actor endpoints are propagated to the client.
    """
    proc = proc_mesh(gpus=num_procs).get()
    exception_actor = proc.spawn("exception_actor", actor_class).get()

    with pytest.raises(ActorError, match="This is a test exception"):
        if num_procs == 1:
            exception_actor.raise_exception.call_one().get()
        else:
            exception_actor.raise_exception.call().get()


# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
@pytest.mark.oss_skip
@pytest.mark.parametrize("num_procs", [1, 2])
@pytest.mark.parametrize("sync_endpoint", [False, True])
@pytest.mark.parametrize("sync_test_impl", [False, True])
@pytest.mark.parametrize("endpoint_name", ["cause_segfault", "cause_panic"])
def test_actor_supervision(num_procs, sync_endpoint, sync_test_impl, endpoint_name):
    """
    Test that an endpoint causing spontaneous process exit is handled by the supervisor.

    Today, these events are delivered to the client and cause the client process
    to exit with a non-zero code, so the only way we can test it is via a
    subprocess harness.
    """
    # Run the segfault test in a subprocess
    test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
    cmd = [
        str(test_bin),
        "error-endpoint",
        f"--num-procs={num_procs}",
        f"--sync-endpoint={sync_endpoint}",
        f"--sync-test-impl={sync_test_impl}",
        f"--endpoint-name={endpoint_name}",
    ]
    try:
        print("running cmd", " ".join(cmd))
        process = subprocess.run(cmd, capture_output=True, timeout=180)
    except subprocess.TimeoutExpired as e:
        print("timeout expired")
        if e.stdout is not None:
            print(e.stdout.decode())
        if e.stderr is not None:
            print(e.stderr.decode())
        raise

    # Assert that the subprocess exited with a non-zero code
    assert "I actually ran" in process.stdout.decode()
    assert (
        process.returncode != 0
    ), f"Expected non-zero exit code, got {process.returncode}"


# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
@pytest.mark.oss_skip
def test_proc_mesh_bootstrap_error():
    """
    Test that attempts to spawn a ProcMesh with a failure during bootstrap.
    """
    # Run the segfault test in a subprocess
    test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
    cmd = [
        str(test_bin),
        "error-bootstrap",
    ]
    try:
        print("running cmd", " ".join(cmd))
        process = subprocess.run(cmd, capture_output=True, timeout=180)
    except subprocess.TimeoutExpired as e:
        print("timeout expired")
        if e.stdout is not None:
            print(e.stdout.decode())
        if e.stderr is not None:
            print(e.stderr.decode())
        raise

    # Assert that the subprocess exited with a non-zero code
    assert "I actually ran" in process.stdout.decode()
    assert (
        process.returncode != 0
    ), f"Expected non-zero exit code, got {process.returncode}"


@pytest.mark.parametrize("raise_on_getstate", [True, False])
@pytest.mark.parametrize("raise_on_setstate", [True, False])
@pytest.mark.parametrize("num_procs", [1, 2])
async def test_broken_pickle_class(raise_on_getstate, raise_on_setstate, num_procs):
    """
    Test that exceptions during pickling/unpickling are properly handled.

    This test creates a BrokenPickleClass instance configured to raise exceptions
    during __getstate__ and/or __setstate__, then passes it to an ExceptionActor's
    print_value endpoint and verifies that an ActorError is raised.
    """
    if not raise_on_getstate and not raise_on_setstate:
        # Pass this test trivially
        return

    proc = await proc_mesh(gpus=num_procs)
    exception_actor = await proc.spawn("exception_actor", ExceptionActor)

    # Create a BrokenPickleClass instance configured to raise exceptions
    broken_obj = BrokenPickleClass(
        raise_on_getstate=raise_on_getstate,
        raise_on_setstate=raise_on_setstate,
        exception_message="Test pickle error",
    )

    # On the getstate path, we expect a RuntimeError to be raised locally.
    # On the setstate path, we expect an ActorError to be raised remotely.
    error_type = RuntimeError if raise_on_getstate else ActorError
    error_pattern = "__getstate__ error" if raise_on_getstate else "__setstate__ error"

    with pytest.raises(error_type, match=error_pattern):
        if num_procs == 1:
            await exception_actor.print_value.call_one(broken_obj)
        else:
            await exception_actor.print_value.call(broken_obj)


# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
@pytest.mark.oss_skip
async def test_exception_after_wait_unmonitored():
    # Run the test in a subprocess
    test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
    cmd = [
        str(test_bin),
        "error-unmonitored",
    ]
    try:
        print("running cmd", " ".join(cmd))
        process = subprocess.run(cmd, capture_output=True, timeout=180)
    except subprocess.TimeoutExpired as e:
        print("timeout expired")
        if e.stdout is not None:
            print(e.stdout.decode())
        if e.stderr is not None:
            print(e.stderr.decode())
        raise

    # Assert that the subprocess exited with a non-zero code
    assert "I actually ran" in process.stdout.decode()
    assert (
        process.returncode != 0
    ), f"Expected non-zero exit code, got {process.returncode}"
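The supervision, bootstrap-error, and unmonitored-exception tests above assert only on the subprocess's observable behavior: it must print "I actually ran" and then exit with a non-zero code. The sketch below is a minimal, hypothetical stand-in for such a target binary; the subcommand and flag names are taken from the commands built in the tests, while everything else is an assumption for illustration and is not the packaged tests/error_test_binary.py.

# Hypothetical sketch of a crash-on-demand subprocess target, mirroring only the
# contract the tests above rely on: print a marker, then terminate abnormally.
# Not the packaged tests/error_test_binary.py.
import argparse
import ctypes
import sys


def cause_segfault() -> None:
    # Reading a C string at address 0 crashes the interpreter with SIGSEGV.
    ctypes.string_at(0)


def main() -> None:
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers(dest="command", required=True)

    error_endpoint = sub.add_parser("error-endpoint")
    error_endpoint.add_argument("--num-procs", type=int, default=1)
    error_endpoint.add_argument("--sync-endpoint", default="False")
    error_endpoint.add_argument("--sync-test-impl", default="False")
    error_endpoint.add_argument("--endpoint-name", default="cause_segfault")
    sub.add_parser("error-bootstrap")
    sub.add_parser("error-unmonitored")

    args = parser.parse_args()
    # The tests assert on this marker appearing in captured stdout.
    print("I actually ran", flush=True)

    if args.command == "error-endpoint" and args.endpoint_name == "cause_segfault":
        cause_segfault()
    # For every other path, simply exit non-zero as the tests expect.
    sys.exit(1)


if __name__ == "__main__":
    main()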
tests/test_alloc.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from unittest import IsolatedAsyncioTestCase

from monarch import ProcessAllocator
from monarch._rust_bindings.hyperactor_extension.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
    AllocConstraints,
    AllocSpec,
)


class TestAlloc(IsolatedAsyncioTestCase):
    async def test_basic(self) -> None:
        cmd = "echo hello"
        allocator = ProcessAllocator(cmd)
        spec = AllocSpec(AllocConstraints(), replica=2)
        alloc = await allocator.allocate(spec)

        print(alloc)
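Outside of unittest, the same allocation flow can be driven with asyncio. The following is a minimal sketch that reuses only the calls shown in TestAlloc.test_basic above.

# Minimal sketch driving the allocation flow from TestAlloc.test_basic directly;
# it reuses only the APIs already imported in the test above.
import asyncio

from monarch import ProcessAllocator
from monarch._rust_bindings.hyperactor_extension.alloc import (
    AllocConstraints,
    AllocSpec,
)


async def main() -> None:
    allocator = ProcessAllocator("echo hello")
    spec = AllocSpec(AllocConstraints(), replica=2)
    alloc = await allocator.allocate(spec)
    print(alloc)


if __name__ == "__main__":
    asyncio.run(main())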
tests/test_allocator.py
ADDED
@@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import contextlib
import importlib.resources
import math
import os
import subprocess
import sys
import unittest
from datetime import timedelta
from typing import Generator, Optional
from unittest import mock

import cloudpickle
import pytest

import torch
import torch.distributed as dist
import torch.nn.functional as F

from monarch._rust_bindings.hyperactor_extension.alloc import (
    AllocConstraints,
    AllocSpec,
)
from monarch._rust_bindings.monarch_hyperactor.channel import (
    ChannelAddr,
    ChannelTransport,
)
from monarch.actor_mesh import Actor, current_rank, current_size, endpoint, ValueMesh
from monarch.allocator import (
    ALLOC_LABEL_PROC_MESH_NAME,
    RemoteAllocator,
    StaticRemoteAllocInitializer,
    TorchXRemoteAllocInitializer,
)
from monarch.proc_mesh import ProcMesh
from monarch.tools.mesh_spec import MeshSpec, ServerSpec
from monarch.tools.network import get_sockaddr

from torch.distributed.elastic.utils.distributed import get_free_port
from torchx.specs import AppState

_100_MILLISECONDS = timedelta(milliseconds=100)

SERVER_READY = "monarch.tools.commands.server_ready"


class TestActor(Actor):
    """Silly actor that computes the world size by all-reducing rank-hot tensors"""

    def __init__(self) -> None:
        self.rank: int = current_rank().rank
        self.world_size: int = math.prod(current_size().values())

    @endpoint
    async def compute_world_size(self, master_addr: str, master_port: int) -> int:
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)
        dist.init_process_group("gloo", rank=self.rank, world_size=self.world_size)

        try:
            t = F.one_hot(torch.tensor(self.rank), num_classes=dist.get_world_size())
            dist.all_reduce(t)
            return int(torch.sum(t).item())
        finally:
            dist.destroy_process_group()


@contextlib.contextmanager
def remote_process_allocator(addr: Optional[str] = None) -> Generator[str, None, None]:
    with importlib.resources.path(__package__, "") as package_path:
        addr = addr or ChannelAddr.any(ChannelTransport.Unix)

        process_allocator = subprocess.Popen(
            args=[
                "process_allocator",
                f"--addr={addr}",
            ],
            env={
                # prefix PATH with this test module's directory to
                # give 'process_allocator' and 'monarch_bootstrap' binary resources
                # in this test module's directory precedence over the installed ones
                # useful in BUCK where these binaries are added as 'resources' of this test target
                "PATH": f"{package_path}:{os.getenv('PATH', '')}",
                "RUST_LOG": "debug",
            },
        )
        try:
            yield addr
        finally:
            process_allocator.terminate()
            try:
                five_seconds = 5
                process_allocator.wait(timeout=five_seconds)
            except subprocess.TimeoutExpired:
                process_allocator.kill()


class TestRemoteAllocator(unittest.IsolatedAsyncioTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cloudpickle.register_pickle_by_value(sys.modules[TestActor.__module__])

    @classmethod
    def tearDownClass(cls) -> None:
        cloudpickle.unregister_pickle_by_value(sys.modules[TestActor.__module__])

    def assert_computed_world_size(
        self, computed: ValueMesh[int], expected_world_size: int
    ) -> None:
        expected_world_sizes = {
            rank: expected_world_size for rank in range(0, expected_world_size)
        }
        computed_world_sizes = {p.rank: v for p, v in list(computed.flatten("rank"))}
        self.assertDictEqual(expected_world_sizes, computed_world_sizes)

    async def test_call_allocate_twice(self) -> None:
        class DeletingAllocInitializer(StaticRemoteAllocInitializer):
            """test initializer that removes the last address from the list each time initialize_alloc() is called
            used to test that the state of the initializer is preserved across calls to allocate()
            """

            async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
                alloc = await super().initialize_alloc(match_labels)
                self.addrs.pop(-1)
                return alloc

        with remote_process_allocator() as host1, remote_process_allocator() as host2:
            initializer = DeletingAllocInitializer(host1, host2)

            allocator = RemoteAllocator(
                world_id="test_remote_allocator",
                initializer=initializer,
                heartbeat_interval=_100_MILLISECONDS,
            )

            spec = AllocSpec(AllocConstraints(), host=1, gpu=1)

            await allocator.allocate(spec)
            self.assertEqual([host1], initializer.addrs)

            await allocator.allocate(spec)
            self.assertEqual([], initializer.addrs)

    async def test_throws_when_initializer_returns_empty_addrs(self) -> None:
        class EmptyAllocInitializer(StaticRemoteAllocInitializer):
            """test initializer that returns an empty list of addresses"""

            async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
                _ = match_labels  # Suppress unused variable warning
                return []

        empty_initializer = EmptyAllocInitializer()
        with self.assertRaisesRegex(
            RuntimeError, r"initializer must return non-empty list of addresses"
        ):
            allocator = RemoteAllocator(
                world_id="test_remote_allocator",
                initializer=empty_initializer,
                heartbeat_interval=_100_MILLISECONDS,
            )
            await allocator.allocate(AllocSpec(AllocConstraints(), host=1, gpu=1))

    async def test_allocate_2d_mesh(self) -> None:
        hosts = 2
        gpus = 4
        world_size = hosts * gpus
        spec = AllocSpec(AllocConstraints(), host=hosts, gpu=gpus)

        # create 2x process-allocators (on their own bind addresses) to simulate 2 hosts
        with remote_process_allocator() as host1, remote_process_allocator() as host2:
            allocator = RemoteAllocator(
                world_id="test_remote_allocator",
                initializer=StaticRemoteAllocInitializer(host1, host2),
                heartbeat_interval=_100_MILLISECONDS,
            )
            alloc = await allocator.allocate(spec)
            proc_mesh = await ProcMesh.from_alloc(alloc)
            actor = await proc_mesh.spawn("test_actor", TestActor)

            values = await actor.compute_world_size.call(
                master_addr="0.0.0.0",
                master_port=get_free_port(),
            )

            self.assert_computed_world_size(values, world_size)

    async def test_stacked_1d_meshes(self) -> None:
        # create two stacked actor meshes on the same host
        # each actor mesh running on separate process-allocators

        with remote_process_allocator() as host1_a, remote_process_allocator() as host1_b:
            allocator_a = RemoteAllocator(
                world_id="a",
                initializer=StaticRemoteAllocInitializer(host1_a),
                heartbeat_interval=_100_MILLISECONDS,
            )
            allocator_b = RemoteAllocator(
                world_id="b",
                initializer=StaticRemoteAllocInitializer(host1_b),
                heartbeat_interval=_100_MILLISECONDS,
            )

            spec_a = AllocSpec(AllocConstraints(), host=1, gpu=2)
            spec_b = AllocSpec(AllocConstraints(), host=1, gpu=6)

            proc_mesh_a = await ProcMesh.from_alloc(await allocator_a.allocate(spec_a))
            proc_mesh_b = await ProcMesh.from_alloc(await allocator_b.allocate(spec_b))

            actor_a = await proc_mesh_a.spawn("actor_a", TestActor)
            actor_b = await proc_mesh_b.spawn("actor_b", TestActor)

            results_a = await actor_a.compute_world_size.call(
                master_addr="0.0.0.0", master_port=get_free_port()
            )
            results_b = await actor_b.compute_world_size.call(
                master_addr="0.0.0.0", master_port=get_free_port()
            )

            self.assert_computed_world_size(results_a, 2)  # a is a 1x2 mesh
            self.assert_computed_world_size(results_b, 6)  # b is a 1x6 mesh

    async def test_torchx_remote_alloc_initializer_no_server(self) -> None:
        with mock.patch(SERVER_READY, return_value=None):
            initializer = TorchXRemoteAllocInitializer("slurm:///123")
            allocator = RemoteAllocator(world_id="test", initializer=initializer)

            with self.assertRaisesRegex(
                RuntimeError,
                r"slurm:///123 does not exist or is in a terminal state",
            ):
                await allocator.allocate(AllocSpec(AllocConstraints(), host=1, gpu=1))

    async def test_torchx_remote_alloc_initializer_no_match_label_gt_1_meshes(
        self,
    ) -> None:
        # asserts that an exception is raised if no match label is specified in alloc constraints
        # but there are more than 1 mesh (hence ambiguous which mesh to allocate on)

        server = ServerSpec(
            name="__UNUSED__",
            state=AppState.RUNNING,
            meshes=[MeshSpec(name="x", num_hosts=1), MeshSpec(name="y", num_hosts=1)],
        )

        with mock.patch(SERVER_READY, return_value=server):
            initializer = TorchXRemoteAllocInitializer("slurm:///123")
            allocator = RemoteAllocator(world_id="test", initializer=initializer)

            with self.assertRaisesRegex(
                RuntimeError,
                r"2 proc meshes in slurm:///123, please specify the mesh name as a match label `procmesh.monarch.meta.com/name`",
            ):
                await allocator.allocate(AllocSpec(AllocConstraints(), host=1, gpu=1))

    @pytest.mark.oss_skip  # pyre-ignore[56] TODO T228752279
    async def test_torchx_remote_alloc_initializer_no_match_label_1_mesh(self) -> None:
        server = ServerSpec(
            name="__UNUSED__",
            state=AppState.RUNNING,
            meshes=[
                MeshSpec(
                    name="x",
                    num_hosts=1,
                    transport="tcp",
                    hostnames=["localhost"],
                )
            ],
        )
        port = get_free_port()
        with remote_process_allocator(addr=f"tcp!{get_sockaddr('localhost', port)}"):
            with mock.patch(SERVER_READY, return_value=server):
                initializer = TorchXRemoteAllocInitializer("local:///test", port=port)
                allocator = RemoteAllocator(
                    world_id="test",
                    initializer=initializer,
                    heartbeat_interval=_100_MILLISECONDS,
                )
                alloc = await allocator.allocate(
                    AllocSpec(AllocConstraints(), host=1, gpu=4)
                )
                proc_mesh = await ProcMesh.from_alloc(alloc)
                actor = await proc_mesh.spawn("test_actor", TestActor)
                results = await actor.compute_world_size.call(
                    master_addr="0.0.0.0", master_port=get_free_port()
                )
                self.assert_computed_world_size(results, 4)  # 1x4 mesh

    @pytest.mark.oss_skip  # pyre-ignore[56] TODO T228752279
    async def test_torchx_remote_alloc_initializer_with_match_label(self) -> None:
        server = ServerSpec(
            name="__UNUSED__",
            state=AppState.RUNNING,
            meshes=[
                MeshSpec(
                    name="x",
                    num_hosts=1,
                    transport="tcp",
                    hostnames=["localhost"],
                )
            ],
        )
        port = get_free_port()
        with remote_process_allocator(addr=f"tcp!{get_sockaddr('localhost', port)}"):
            with mock.patch(SERVER_READY, return_value=server):
                initializer = TorchXRemoteAllocInitializer("local:///test", port=port)
                allocator = RemoteAllocator(
                    world_id="test",
                    initializer=initializer,
                    heartbeat_interval=_100_MILLISECONDS,
                )
                alloc = await allocator.allocate(
                    AllocSpec(
                        AllocConstraints(
                            match_labels={ALLOC_LABEL_PROC_MESH_NAME: "x"}
                        ),
                        host=1,
                        gpu=3,
                    )
                )
                proc_mesh = await ProcMesh.from_alloc(alloc)
                actor = await proc_mesh.spawn("test_actor", TestActor)
                results = await actor.compute_world_size.call(
                    master_addr="0.0.0.0", master_port=get_free_port()
                )
                self.assert_computed_world_size(results, 3)  # 1x3 mesh

    async def test_torchx_remote_alloc_initializer_with_match_label_no_match(
        self,
    ) -> None:
        # assert that match label with a mesh name that does not exist should error out

        server = ServerSpec(
            name="test",
            state=AppState.RUNNING,
            meshes=[
                MeshSpec(
                    name="x",
                    num_hosts=1,
                    transport="tcp",
                    hostnames=["localhost"],
                )
            ],
        )

        with mock.patch(SERVER_READY, return_value=server):
            with self.assertRaisesRegex(RuntimeError, r"'y' not found in job: test"):
                initializer = TorchXRemoteAllocInitializer("local:///test")
                allocator = RemoteAllocator(world_id="test", initializer=initializer)
                alloc = await allocator.allocate(
                    AllocSpec(
                        AllocConstraints(
                            match_labels={ALLOC_LABEL_PROC_MESH_NAME: "y"}
                        ),
                        host=1,
                        gpu=1,
                    )
                )
                await ProcMesh.from_alloc(alloc)