torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,240 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import asyncio
8
+ import importlib.resources
9
+ import subprocess
10
+
11
+ import pytest
12
+ from monarch.actor_mesh import Actor, ActorError, endpoint, send
13
+
14
+ from monarch.proc_mesh import proc_mesh
15
+
16
+
17
class ExceptionActor(Actor):
    """Actor with async endpoints used to exercise error propagation."""

    @endpoint
    async def raise_exception(self) -> None:
        """Unconditionally raise so callers can observe a failing endpoint."""
        raise Exception("This is a test exception")

    @endpoint
    async def print_value(self, value: object) -> object:
        """Endpoint that takes a value, prints it, and returns it unchanged.

        The return annotation is ``object`` rather than ``None`` because the
        argument is handed back to the caller.
        """
        print(f"Value received: {value}")
        return value
27
+
28
+
29
class ExceptionActorSync(Actor):
    """Actor whose only endpoint is a plain (non-async) function that raises."""

    @endpoint  # pyre-ignore
    def raise_exception(self) -> None:
        """Unconditionally raise so callers can observe a failing endpoint."""
        raise Exception("This is a test exception")
33
+
34
+
35
class BrokenPickleClass:
    """Helper whose pickling hooks can be told to fail on demand.

    ``raise_on_getstate`` makes ``__getstate__`` raise ``RuntimeError``;
    ``raise_on_setstate`` is carried in the state dict and makes
    ``__setstate__`` raise ``RuntimeError`` when that state is applied.
    ``exception_message`` is appended to either error message.
    """

    def __init__(
        self,
        raise_on_getstate=False,
        raise_on_setstate=False,
        exception_message="Pickle error",
    ):
        self.raise_on_getstate = raise_on_getstate
        self.raise_on_setstate = raise_on_setstate
        self.exception_message = exception_message
        self.value = "test_value"

    def __getstate__(self):
        """Return the state dict, or raise if configured to fail here."""
        if self.raise_on_getstate:
            raise RuntimeError(f"__getstate__ error: {self.exception_message}")
        # Attributes carried in the state dict, in the order they are emitted.
        persisted = (
            "raise_on_getstate",
            "raise_on_setstate",
            "exception_message",
            "value",
        )
        return {name: getattr(self, name) for name in persisted}

    def __setstate__(self, state):
        """Apply ``state`` to the instance, or raise if it asks for a failure."""
        if not state.get("raise_on_setstate", False):
            self.__dict__.update(state)
            return
        detail = state.get("exception_message", "Unpickle error")
        raise RuntimeError(f"__setstate__ error: {detail}")
67
+
68
+
69
@pytest.mark.parametrize(
    "actor_class",
    [ExceptionActor, ExceptionActorSync],
)
@pytest.mark.parametrize("num_procs", [1, 2])
async def test_actor_exception(actor_class, num_procs):
    """
    Check that an exception raised inside an actor endpoint surfaces on the
    client as an ActorError, using the awaitable API.
    """
    mesh = await proc_mesh(gpus=num_procs)
    actor = await mesh.spawn("exception_actor", actor_class)

    with pytest.raises(ActorError, match="This is a test exception"):
        failing = actor.raise_exception
        # One proc is addressed with call_one(); several procs with call().
        invoke = failing.call_one if num_procs == 1 else failing.call
        await invoke()
86
+
87
+
88
@pytest.mark.parametrize(
    "actor_class",
    [ExceptionActor, ExceptionActorSync],
)
@pytest.mark.parametrize("num_procs", [1, 2])
def test_actor_exception_sync(actor_class, num_procs):
    """
    Check that an exception raised inside an actor endpoint surfaces on the
    client as an ActorError, using the blocking ``.get()`` API.
    """
    mesh = proc_mesh(gpus=num_procs).get()
    actor = mesh.spawn("exception_actor", actor_class).get()

    with pytest.raises(ActorError, match="This is a test exception"):
        failing = actor.raise_exception
        # One proc is addressed with call_one(); several procs with call().
        invoke = failing.call_one if num_procs == 1 else failing.call
        invoke().get()
105
+
106
+
107
+ # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
108
@pytest.mark.oss_skip
@pytest.mark.parametrize("num_procs", [1, 2])
@pytest.mark.parametrize("sync_endpoint", [False, True])
@pytest.mark.parametrize("sync_test_impl", [False, True])
@pytest.mark.parametrize("endpoint_name", ["cause_segfault", "cause_panic"])
def test_actor_supervision(num_procs, sync_endpoint, sync_test_impl, endpoint_name):
    """
    Test that an endpoint causing spontaneous process exit is handled by the supervisor.

    Today, these events are delivered to the client and cause the client process
    to exit with a non-zero code, so the only way we can test it is via a
    subprocess harness.
    """
    # Run the crashing endpoint in a subprocess
    test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
    cmd = [
        str(test_bin),
        "error-endpoint",
        f"--num-procs={num_procs}",
        f"--sync-endpoint={sync_endpoint}",
        f"--sync-test-impl={sync_test_impl}",
        f"--endpoint-name={endpoint_name}",
    ]
    try:
        print("running cmd", " ".join(cmd))
        process = subprocess.run(cmd, capture_output=True, timeout=180)
    except subprocess.TimeoutExpired as e:
        print("timeout expired")
        # errors="replace": crash output may not be valid UTF-8, and a decode
        # error here would hide the timeout itself.
        if e.stdout is not None:
            print(e.stdout.decode(errors="replace"))
        if e.stderr is not None:
            print(e.stderr.decode(errors="replace"))
        raise

    # The output is captured, so surface it in the assertion messages;
    # otherwise a failure leaves nothing to debug with.
    stdout = process.stdout.decode(errors="replace")
    stderr = process.stderr.decode(errors="replace")
    captured = f"\nstdout:\n{stdout}\nstderr:\n{stderr}"

    # The harness is expected to print this marker before the failure hits.
    assert "I actually ran" in stdout, f"Marker not found in stdout.{captured}"
    # Assert that the subprocess exited with a non-zero code
    assert (
        process.returncode != 0
    ), f"Expected non-zero exit code, got {process.returncode}{captured}"
147
+
148
+
149
+ # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
150
@pytest.mark.oss_skip
def test_proc_mesh_bootstrap_error():
    """
    Test that attempts to spawn a ProcMesh with a failure during bootstrap.
    """
    # Run the bootstrap-failure scenario in a subprocess
    test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
    cmd = [
        str(test_bin),
        "error-bootstrap",
    ]
    try:
        print("running cmd", " ".join(cmd))
        process = subprocess.run(cmd, capture_output=True, timeout=180)
    except subprocess.TimeoutExpired as e:
        print("timeout expired")
        # errors="replace": the output may not be valid UTF-8, and a decode
        # error here would hide the timeout itself.
        if e.stdout is not None:
            print(e.stdout.decode(errors="replace"))
        if e.stderr is not None:
            print(e.stderr.decode(errors="replace"))
        raise

    # The output is captured, so surface it in the assertion messages;
    # otherwise a failure leaves nothing to debug with.
    stdout = process.stdout.decode(errors="replace")
    stderr = process.stderr.decode(errors="replace")
    captured = f"\nstdout:\n{stdout}\nstderr:\n{stderr}"

    # The harness is expected to print this marker before the failure hits.
    assert "I actually ran" in stdout, f"Marker not found in stdout.{captured}"
    # Assert that the subprocess exited with a non-zero code
    assert (
        process.returncode != 0
    ), f"Expected non-zero exit code, got {process.returncode}{captured}"
177
+
178
+
179
@pytest.mark.parametrize("raise_on_getstate", [True, False])
@pytest.mark.parametrize("raise_on_setstate", [True, False])
@pytest.mark.parametrize("num_procs", [1, 2])
async def test_broken_pickle_class(raise_on_getstate, raise_on_setstate, num_procs):
    """
    Check that failures while pickling or unpickling an endpoint argument
    are reported to the caller.

    A BrokenPickleClass instance configured to raise from __getstate__
    and/or __setstate__ is passed to ExceptionActor.print_value, and the
    resulting error type and message are verified.
    """
    if not (raise_on_getstate or raise_on_setstate):
        # Nothing is configured to fail, so there is nothing to verify.
        return

    mesh = await proc_mesh(gpus=num_procs)
    actor = await mesh.spawn("exception_actor", ExceptionActor)

    payload = BrokenPickleClass(
        raise_on_getstate=raise_on_getstate,
        raise_on_setstate=raise_on_setstate,
        exception_message="Test pickle error",
    )

    # A __getstate__ failure is expected locally as a RuntimeError; a
    # __setstate__ failure is expected from the remote side as an ActorError.
    if raise_on_getstate:
        expected_error, expected_message = RuntimeError, "__getstate__ error"
    else:
        expected_error, expected_message = ActorError, "__setstate__ error"

    with pytest.raises(expected_error, match=expected_message):
        target = actor.print_value
        invoke = target.call_one if num_procs == 1 else target.call
        await invoke(payload)
214
+
215
+
216
+ # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
217
@pytest.mark.oss_skip
async def test_exception_after_wait_unmonitored():
    """
    Run the ``error-unmonitored`` scenario of the test binary in a subprocess
    and expect it to print its marker and exit with a non-zero code.
    """
    # Run the test in a subprocess
    test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
    cmd = [
        str(test_bin),
        "error-unmonitored",
    ]
    try:
        print("running cmd", " ".join(cmd))
        process = subprocess.run(cmd, capture_output=True, timeout=180)
    except subprocess.TimeoutExpired as e:
        print("timeout expired")
        # errors="replace": the output may not be valid UTF-8, and a decode
        # error here would hide the timeout itself.
        if e.stdout is not None:
            print(e.stdout.decode(errors="replace"))
        if e.stderr is not None:
            print(e.stderr.decode(errors="replace"))
        raise

    # The output is captured, so surface it in the assertion messages;
    # otherwise a failure leaves nothing to debug with.
    stdout = process.stdout.decode(errors="replace")
    stderr = process.stderr.decode(errors="replace")
    captured = f"\nstdout:\n{stdout}\nstderr:\n{stderr}"

    # The harness is expected to print this marker before the failure hits.
    assert "I actually ran" in stdout, f"Marker not found in stdout.{captured}"
    # Assert that the subprocess exited with a non-zero code
    assert (
        process.returncode != 0
    ), f"Expected non-zero exit code, got {process.returncode}{captured}"
tests/test_alloc.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ from unittest import IsolatedAsyncioTestCase
10
+
11
+ from monarch import ProcessAllocator
12
+ from monarch._rust_bindings.hyperactor_extension.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension
13
+ AllocConstraints,
14
+ AllocSpec,
15
+ )
16
+
17
+
18
class TestAlloc(IsolatedAsyncioTestCase):
    """Smoke test for ProcessAllocator."""

    async def test_basic(self) -> None:
        """Allocate two replicas from an ``echo hello`` allocator and print the result."""
        allocator = ProcessAllocator("echo hello")
        two_replicas = AllocSpec(AllocConstraints(), replica=2)

        print(await allocator.allocate(two_replicas))
@@ -0,0 +1,365 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ import contextlib
10
+ import importlib.resources
11
+ import math
12
+ import os
13
+ import subprocess
14
+ import sys
15
+ import unittest
16
+ from datetime import timedelta
17
+ from typing import Generator, Optional
18
+ from unittest import mock
19
+
20
+ import cloudpickle
21
+ import pytest
22
+
23
+ import torch
24
+ import torch.distributed as dist
25
+ import torch.nn.functional as F
26
+
27
+ from monarch._rust_bindings.hyperactor_extension.alloc import (
28
+ AllocConstraints,
29
+ AllocSpec,
30
+ )
31
+ from monarch._rust_bindings.monarch_hyperactor.channel import (
32
+ ChannelAddr,
33
+ ChannelTransport,
34
+ )
35
+ from monarch.actor_mesh import Actor, current_rank, current_size, endpoint, ValueMesh
36
+ from monarch.allocator import (
37
+ ALLOC_LABEL_PROC_MESH_NAME,
38
+ RemoteAllocator,
39
+ StaticRemoteAllocInitializer,
40
+ TorchXRemoteAllocInitializer,
41
+ )
42
+ from monarch.proc_mesh import ProcMesh
43
+ from monarch.tools.mesh_spec import MeshSpec, ServerSpec
44
+ from monarch.tools.network import get_sockaddr
45
+
46
+ from torch.distributed.elastic.utils.distributed import get_free_port
47
+ from torchx.specs import AppState
48
+
49
+ _100_MILLISECONDS = timedelta(milliseconds=100)
50
+
51
+ SERVER_READY = "monarch.tools.commands.server_ready"
52
+
53
+
54
class TestActor(Actor):
    """Silly actor that computes the world size by all-reducing rank-hot tensors"""

    def __init__(self) -> None:
        # Rank of this actor, and the product of all mesh dimension sizes.
        self.rank: int = current_rank().rank
        self.world_size: int = math.prod(current_size().values())

    @endpoint
    async def compute_world_size(self, master_addr: str, master_port: int) -> int:
        """Join a gloo process group and return the sum of an all-reduced one-hot rank vector.

        The process group is always destroyed before returning.
        """
        os.environ.update(MASTER_ADDR=master_addr, MASTER_PORT=str(master_port))
        dist.init_process_group("gloo", rank=self.rank, world_size=self.world_size)

        try:
            rank_hot = F.one_hot(
                torch.tensor(self.rank), num_classes=dist.get_world_size()
            )
            dist.all_reduce(rank_hot)
            return int(rank_hot.sum().item())
        finally:
            dist.destroy_process_group()
73
+
74
+
75
@contextlib.contextmanager
def remote_process_allocator(addr: Optional[str] = None) -> Generator[str, None, None]:
    """Run a ``process_allocator`` subprocess for the duration of the ``with`` block.

    Args:
        addr: address passed to the subprocess as ``--addr``. When falsy, one
            is obtained from ``ChannelAddr.any(ChannelTransport.Unix)``.

    Yields:
        The address the subprocess was started with.

    On exit the subprocess is terminated; if it has not exited within five
    seconds it is killed and then reaped.
    """
    with importlib.resources.path(__package__, "") as package_path:
        addr = addr or ChannelAddr.any(ChannelTransport.Unix)

        process_allocator = subprocess.Popen(
            args=[
                "process_allocator",
                f"--addr={addr}",
            ],
            env={
                # prefix PATH with this test module's directory to
                # give 'process_allocator' and 'monarch_bootstrap' binary resources
                # in this test module's directory precedence over the installed ones
                # useful in BUCK where these binaries are added as 'resources' of this test target
                "PATH": f"{package_path}:{os.getenv('PATH', '')}",
                "RUST_LOG": "debug",
            },
        )
        try:
            yield addr
        finally:
            process_allocator.terminate()
            try:
                five_seconds = 5
                process_allocator.wait(timeout=five_seconds)
            except subprocess.TimeoutExpired:
                process_allocator.kill()
                # kill() only sends the signal; wait() reaps the child so it
                # does not linger as a zombie for the rest of the test run.
                process_allocator.wait()
103
+
104
+
105
class TestRemoteAllocator(unittest.IsolatedAsyncioTestCase):
    """Tests for RemoteAllocator with static and TorchX-backed alloc initializers."""

    @classmethod
    def setUpClass(cls) -> None:
        # Register the module that defines TestActor for pickle-by-value.
        cloudpickle.register_pickle_by_value(sys.modules[TestActor.__module__])

    @classmethod
    def tearDownClass(cls) -> None:
        # Undo the registration made in setUpClass.
        cloudpickle.unregister_pickle_by_value(sys.modules[TestActor.__module__])

    @staticmethod
    def _single_mesh_server(name: str) -> ServerSpec:
        """Build a RUNNING server spec with one TCP mesh named "x" on localhost."""
        return ServerSpec(
            name=name,
            state=AppState.RUNNING,
            meshes=[
                MeshSpec(
                    name="x",
                    num_hosts=1,
                    transport="tcp",
                    hostnames=["localhost"],
                )
            ],
        )

    def assert_computed_world_size(
        self, computed: ValueMesh[int], expected_world_size: int
    ) -> None:
        """Assert every rank in ``[0, expected_world_size)`` reported ``expected_world_size``."""
        expected = dict.fromkeys(range(expected_world_size), expected_world_size)
        actual = {point.rank: value for point, value in computed.flatten("rank")}
        self.assertDictEqual(expected, actual)

    async def test_call_allocate_twice(self) -> None:
        class DeletingAllocInitializer(StaticRemoteAllocInitializer):
            """test initializer that removes the last address from the list each time initialize_alloc() is called
            used to test that the state of the initializer is preserved across calls to allocate()
            """

            async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
                addrs = await super().initialize_alloc(match_labels)
                self.addrs.pop()
                return addrs

        with remote_process_allocator() as host1, remote_process_allocator() as host2:
            initializer = DeletingAllocInitializer(host1, host2)
            allocator = RemoteAllocator(
                world_id="test_remote_allocator",
                initializer=initializer,
                heartbeat_interval=_100_MILLISECONDS,
            )
            spec = AllocSpec(AllocConstraints(), host=1, gpu=1)

            # Addresses expected to remain after the first and second allocate().
            for remaining in ([host1], []):
                await allocator.allocate(spec)
                self.assertEqual(remaining, initializer.addrs)

    async def test_throws_when_initializer_returns_empty_addrs(self) -> None:
        class EmptyAllocInitializer(StaticRemoteAllocInitializer):
            """test initializer that returns an empty list of addresses"""

            async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
                _ = match_labels  # Suppress unused variable warning
                return []

        initializer = EmptyAllocInitializer()
        expected_message = r"initializer must return non-empty list of addresses"
        with self.assertRaisesRegex(RuntimeError, expected_message):
            allocator = RemoteAllocator(
                world_id="test_remote_allocator",
                initializer=initializer,
                heartbeat_interval=_100_MILLISECONDS,
            )
            await allocator.allocate(AllocSpec(AllocConstraints(), host=1, gpu=1))

    async def test_allocate_2d_mesh(self) -> None:
        num_hosts, gpus_per_host = 2, 4
        spec = AllocSpec(AllocConstraints(), host=num_hosts, gpu=gpus_per_host)

        # create 2x process-allocators (on their own bind addresses) to simulate 2 hosts
        with remote_process_allocator() as host1, remote_process_allocator() as host2:
            allocator = RemoteAllocator(
                world_id="test_remote_allocator",
                initializer=StaticRemoteAllocInitializer(host1, host2),
                heartbeat_interval=_100_MILLISECONDS,
            )
            mesh = await ProcMesh.from_alloc(await allocator.allocate(spec))
            actor = await mesh.spawn("test_actor", TestActor)

            values = await actor.compute_world_size.call(
                master_addr="0.0.0.0",
                master_port=get_free_port(),
            )

            self.assert_computed_world_size(values, num_hosts * gpus_per_host)

    async def test_stacked_1d_meshes(self) -> None:
        # create two stacked actor meshes on the same host
        # each actor mesh running on separate process-allocators

        with remote_process_allocator() as host1_a, remote_process_allocator() as host1_b:
            # mesh "a" is 1x2 and mesh "b" is 1x6
            gpus = {"a": 2, "b": 6}
            hosts = {"a": host1_a, "b": host1_b}
            order = ("a", "b")

            allocators = {
                name: RemoteAllocator(
                    world_id=name,
                    initializer=StaticRemoteAllocInitializer(hosts[name]),
                    heartbeat_interval=_100_MILLISECONDS,
                )
                for name in order
            }
            specs = {
                name: AllocSpec(AllocConstraints(), host=1, gpu=gpus[name])
                for name in order
            }

            meshes = {}
            for name in order:
                alloc = await allocators[name].allocate(specs[name])
                meshes[name] = await ProcMesh.from_alloc(alloc)

            actors = {}
            for name in order:
                actors[name] = await meshes[name].spawn(f"actor_{name}", TestActor)

            results = {}
            for name in order:
                results[name] = await actors[name].compute_world_size.call(
                    master_addr="0.0.0.0", master_port=get_free_port()
                )

            for name in order:
                self.assert_computed_world_size(results[name], gpus[name])

    async def test_torchx_remote_alloc_initializer_no_server(self) -> None:
        with mock.patch(SERVER_READY, return_value=None):
            allocator = RemoteAllocator(
                world_id="test",
                initializer=TorchXRemoteAllocInitializer("slurm:///123"),
            )

            expected_message = r"slurm:///123 does not exist or is in a terminal state"
            with self.assertRaisesRegex(RuntimeError, expected_message):
                await allocator.allocate(AllocSpec(AllocConstraints(), host=1, gpu=1))

    async def test_torchx_remote_alloc_initializer_no_match_label_gt_1_meshes(
        self,
    ) -> None:
        # asserts that an exception is raised if no match label is specified in alloc constraints
        # but there are more than 1 mesh (hence ambiguous which mesh to allocate on)

        server = ServerSpec(
            name="__UNUSED__",
            state=AppState.RUNNING,
            meshes=[MeshSpec(name=mesh_name, num_hosts=1) for mesh_name in ("x", "y")],
        )

        with mock.patch(SERVER_READY, return_value=server):
            allocator = RemoteAllocator(
                world_id="test",
                initializer=TorchXRemoteAllocInitializer("slurm:///123"),
            )

            expected_message = r"2 proc meshes in slurm:///123, please specify the mesh name as a match label `procmesh.monarch.meta.com/name`"
            with self.assertRaisesRegex(RuntimeError, expected_message):
                await allocator.allocate(AllocSpec(AllocConstraints(), host=1, gpu=1))

    @pytest.mark.oss_skip  # pyre-ignore[56] TODO T228752279
    async def test_torchx_remote_alloc_initializer_no_match_label_1_mesh(self) -> None:
        server = self._single_mesh_server("__UNUSED__")
        port = get_free_port()
        with remote_process_allocator(
            addr=f"tcp!{get_sockaddr('localhost', port)}"
        ), mock.patch(SERVER_READY, return_value=server):
            allocator = RemoteAllocator(
                world_id="test",
                initializer=TorchXRemoteAllocInitializer("local:///test", port=port),
                heartbeat_interval=_100_MILLISECONDS,
            )
            spec = AllocSpec(AllocConstraints(), host=1, gpu=4)
            mesh = await ProcMesh.from_alloc(await allocator.allocate(spec))
            actor = await mesh.spawn("test_actor", TestActor)
            results = await actor.compute_world_size.call(
                master_addr="0.0.0.0", master_port=get_free_port()
            )
            self.assert_computed_world_size(results, 4)  # 1x4 mesh

    @pytest.mark.oss_skip  # pyre-ignore[56] TODO T228752279
    async def test_torchx_remote_alloc_initializer_with_match_label(self) -> None:
        server = self._single_mesh_server("__UNUSED__")
        port = get_free_port()
        with remote_process_allocator(
            addr=f"tcp!{get_sockaddr('localhost', port)}"
        ), mock.patch(SERVER_READY, return_value=server):
            allocator = RemoteAllocator(
                world_id="test",
                initializer=TorchXRemoteAllocInitializer("local:///test", port=port),
                heartbeat_interval=_100_MILLISECONDS,
            )
            constraints = AllocConstraints(
                match_labels={ALLOC_LABEL_PROC_MESH_NAME: "x"}
            )
            alloc = await allocator.allocate(AllocSpec(constraints, host=1, gpu=3))
            mesh = await ProcMesh.from_alloc(alloc)
            actor = await mesh.spawn("test_actor", TestActor)
            results = await actor.compute_world_size.call(
                master_addr="0.0.0.0", master_port=get_free_port()
            )
            self.assert_computed_world_size(results, 3)  # 1x3 mesh

    async def test_torchx_remote_alloc_initializer_with_match_label_no_match(
        self,
    ) -> None:
        # assert that match label with a mesh name that does not exist should error out

        server = self._single_mesh_server("test")

        with mock.patch(SERVER_READY, return_value=server), self.assertRaisesRegex(
            RuntimeError, r"'y' not found in job: test"
        ):
            allocator = RemoteAllocator(
                world_id="test",
                initializer=TorchXRemoteAllocInitializer("local:///test"),
            )
            constraints = AllocConstraints(
                match_labels={ALLOC_LABEL_PROC_MESH_NAME: "y"}
            )
            alloc = await allocator.allocate(AllocSpec(constraints, host=1, gpu=1))
            await ProcMesh.from_alloc(alloc)