torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/sim_mesh.py
ADDED
@@ -0,0 +1,359 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-strict
|
8
|
+
|
9
|
+
import importlib.resources
|
10
|
+
import logging
|
11
|
+
import os
|
12
|
+
import random
|
13
|
+
import string
|
14
|
+
import subprocess
|
15
|
+
import tempfile
|
16
|
+
import time
|
17
|
+
from pathlib import Path
|
18
|
+
from typing import (
|
19
|
+
Callable,
|
20
|
+
ContextManager as AbstractContextManager,
|
21
|
+
Dict,
|
22
|
+
Generic,
|
23
|
+
Iterable,
|
24
|
+
List,
|
25
|
+
Optional,
|
26
|
+
Tuple,
|
27
|
+
)
|
28
|
+
|
29
|
+
from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension
|
30
|
+
ClientActor,
|
31
|
+
)
|
32
|
+
|
33
|
+
from monarch._rust_bindings.monarch_extension.simulator_client import ( # @manual=//monarch/monarch_extension:monarch_extension
|
34
|
+
bootstrap_simulator_backend,
|
35
|
+
SimulatorClient,
|
36
|
+
)
|
37
|
+
|
38
|
+
from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
|
39
|
+
ActorId,
|
40
|
+
init_proc,
|
41
|
+
Proc,
|
42
|
+
)
|
43
|
+
from monarch.common.client import Client
|
44
|
+
from monarch.common.constants import (
|
45
|
+
SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL,
|
46
|
+
SIM_MESH_CLIENT_TIMEOUT,
|
47
|
+
)
|
48
|
+
from monarch.common.device_mesh import DeviceMesh
|
49
|
+
from monarch.common.fake import fake_call
|
50
|
+
from monarch.common.future import Future, T
|
51
|
+
from monarch.common.invocation import DeviceException, RemoteException
|
52
|
+
from monarch.common.messages import Dims
|
53
|
+
from monarch.common.shape import NDSlice
|
54
|
+
from monarch.controller.rust_backend.controller import RustController
|
55
|
+
from monarch.rust_backend_mesh import MeshWorld
|
56
|
+
|
57
|
+
|
58
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
59
|
+
|
60
|
+
|
61
|
+
def sim_mesh(
    n_meshes: int, hosts: int, gpus_per_host: int, proxy_addr: Optional[str] = None
) -> List[DeviceMesh]:
    """
    Bootstraps a simulator backend and builds ``n_meshes`` simulated device meshes.

    Every mesh is laid out as a ``hosts`` x ``gpus_per_host`` grid with the
    dimension names ("host", "gpu"). All meshes share a single client proc,
    a single root client actor and the simulator client owned by the
    ``Bootstrap`` instance created here.

    Args:
        n_meshes : number of device meshes to create.
        hosts : number of hosts in each mesh.
        gpus_per_host : number of gpus per host.
        proxy_addr : optional proxy address forwarded to ``Bootstrap``.

    Returns:
        The created meshes, ordered by mesh index.
    """
    world_size = hosts * gpus_per_host
    state: Dict[MeshWorld, Optional[DeviceMesh]] = {}
    bootstrap: Bootstrap = Bootstrap(
        n_meshes,
        state,
        proxy_addr=proxy_addr,
        world_size=world_size,
    )

    proc: Proc = init_proc(
        proc_id="client[0]",
        bootstrap_addr=bootstrap.client_bootstrap_addr,
        timeout=SIM_MESH_CLIENT_TIMEOUT,  # unused
        supervision_update_interval=SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL,
        listen_addr=bootstrap.client_listen_addr,
    )
    root_client: ClientActor = ClientActor(proc=proc, actor_name="root_client")

    meshes = []
    for index in range(n_meshes):
        # World names must match the ones Bootstrap used when spawning mesh `index`.
        worker_world = f"mesh_{index}_worker"
        controller = ActorId(
            world_name=f"mesh_{index}_controller", rank=0, actor_name="root"
        )
        # Each mesh gets its own client actor, parented to the shared root client.
        child_actor = ClientActor.new_with_parent(proc, root_client.actor_id)
        backend = RustController(
            proc=proc,
            client_actor=child_actor,
            controller_id=controller,
            worker_world_name=worker_world,
        )
        mesh_client = Client(backend, world_size, gpus_per_host)
        meshes.append(
            SimMesh(
                mesh_client,
                NDSlice(
                    offset=0,
                    sizes=[hosts, gpus_per_host],
                    strides=[gpus_per_host, 1],
                ),
                ("host", "gpu"),
                bootstrap._simulator_client,
                worker_world,
            )
        )

    return meshes
|
119
|
+
|
120
|
+
|
121
|
+
class OriginalFutureWrapper(Generic[T]):
    """
    Keeps references to ``Future.result`` and ``Future._set_result``.

    Both references are taken when this class body runs, so they still point
    at the implementations installed at that time even after the attributes
    on ``Future`` are reassigned.
    """

    # Called as ``result(future, timeout)``.
    result: Callable[[Future[T], float | None], T] = Future.result
    # Called as ``_set_result(future, value)``.
    _set_result: Callable[[Future[T], T], None] = Future._set_result
|
130
|
+
|
131
|
+
|
132
|
+
class SimMesh(DeviceMesh, Generic[T]):
    """
    A ``DeviceMesh`` that reports the training script state to a simulator.

    While the mesh is activated, ``Future.result`` and ``Future._set_result``
    are replaced so that the simulator client is told when the script starts
    waiting on a future and when a future result has been set.
    """

    def __init__(
        self,
        client: "Client",
        processes: "NDSlice",
        names: Dims,
        simulator_client: SimulatorClient,
        mesh_name: str = "default",
    ) -> None:
        """
        Args:
            client: client forwarded to ``DeviceMesh.__init__``.
            processes: process layout forwarded to ``DeviceMesh.__init__``.
            names: dimension names forwarded to ``DeviceMesh.__init__``.
            simulator_client: client notified about waiting/running transitions.
            mesh_name: mesh name forwarded to ``DeviceMesh.__init__``.
        """
        super().__init__(client, processes, names, mesh_name)
        self.simulator_client: SimulatorClient = simulator_client

    # monkey patch Future.result and Future._set_result to hook into set_training_script_state_{running,waiting}
    def activate(self) -> AbstractContextManager[DeviceMesh]:
        """
        Installs the simulator hooks on ``Future`` and activates the mesh.

        The patch is applied to the ``Future`` class itself as soon as this
        method is called, and it stays in place until ``exit`` is called.

        Returns:
            The context manager returned by ``DeviceMesh.activate``.
        """

        def sim_result(fut: Future[T], timeout: float | None = None) -> T:
            # Tell the simulator the script is about to wait on a future.
            self.simulator_client.set_training_script_state_waiting()
            return OriginalFutureWrapper.result(fut, timeout)

        def sim_set_result(fut: Future[T], result: T) -> None:
            # Tell the simulator the script is running again before the result is set.
            self.simulator_client.set_training_script_state_running()
            return OriginalFutureWrapper._set_result(fut, result)

        # pyre-ignore
        Future.result = sim_result
        Future._set_result = sim_set_result

        return super().activate()

    # restore Future.result and Future._set_result to their previous values
    def exit(
        self,
        error: Optional[RemoteException | DeviceException | Exception] = None,
    ) -> None:
        """
        Shuts the client down and removes the simulator hooks from ``Future``.

        Args:
            error: optional error forwarded to ``client.shutdown``.
        """
        try:
            self.client.shutdown(True, error)
        finally:
            # Restore even when shutdown raises, so Future never stays patched.
            # The saved attribute is `result`; `OriginalFutureWrapper` has no `_result`.
            # pyre-ignore
            Future.result = OriginalFutureWrapper.result
            Future._set_result = OriginalFutureWrapper._set_result
|
169
|
+
|
170
|
+
|
171
|
+
def _random_id(length: int = 14) -> str:
|
172
|
+
"""
|
173
|
+
A simple random id generator.
|
174
|
+
"""
|
175
|
+
return "".join(random.choice(string.ascii_lowercase) for _ in range(length))
|
176
|
+
|
177
|
+
|
178
|
+
class Bootstrap:
    """Starts the simulator backend and spawns the simulated meshes."""

    def __init__(
        self,
        num_meshes: int,
        mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]],
        proxy_addr: Optional[str] = None,
        world_size: int = 1,
    ) -> None:
        """
        Bootstraps a SimMesh.

        Args:
            num_meshes: number of meshes to spawn.
            mesh_world_state: mapping filled in place; every spawned MeshWorld
                is registered as a key with the value ``None``.
            proxy_addr: the proxy address of the simulation process. A random
                ``unix!@<id>-proxy`` address is generated when it is empty or None.
            world_size: forwarded to ``bootstrap_simulator_backend``.
        """
        # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later
        fake_call(lambda: 0)

        # Copy of the current environment with the managed-subprocess flag set.
        self.env: dict[str, str] = dict(
            os.environ, HYPERACTOR_MANAGED_SUBPROCESS="1"
        )

        self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state

        if not proxy_addr:
            proxy_addr = f"unix!@{_random_id()}-proxy"
        self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}"

        # The client always gets its own freshly generated proxy address.
        client_proxy_addr = f"unix!@{_random_id()}-proxy"
        self.client_listen_addr: str = f"sim!unix!@client,{client_proxy_addr}"
        self.client_bootstrap_addr: str = (
            f"sim!unix!@client,{client_proxy_addr},unix!@system,{proxy_addr}"
        )
        bootstrap_simulator_backend(self.bootstrap_addr, proxy_addr, world_size)

        self._simulator_client = SimulatorClient(proxy_addr)
        for index in range(num_meshes):
            prefix: str = f"mesh_{index}"
            controller_world: str = f"{prefix}_controller"
            worker_world: str = f"{prefix}_worker"
            root_controller: ActorId = ActorId(
                world_name=controller_world,
                rank=0,
                actor_name="root",
            )
            world = (worker_world, root_controller)
            self._mesh_world_state[world] = None
            self.spawn_mesh(world)
        # sleep for 10 sec for the worker and controller tasks to be spawned and ready.
        time.sleep(10)

    def get_mesh_worlds(self) -> List[MeshWorld]:
        """Always returns an empty list."""
        return []

    def kill_mesh(self, mesh_world: MeshWorld) -> None:
        """Does nothing."""
        pass

    def spawn_mesh(self, mesh_world: MeshWorld) -> None:
        """
        Asks the simulator client to spawn the given mesh world.

        Args:
            mesh_world: a ``(worker_world, controller_id)`` pair.
        """
        worker_world, controller_actor = mesh_world
        controller_addr = f"{controller_actor.world_name}[0].root"
        self._simulator_client.spawn_mesh(
            self.bootstrap_addr, controller_addr, worker_world
        )
|
240
|
+
|
241
|
+
|
242
|
+
def _validate_proccesses_end(
|
243
|
+
processes: Iterable[subprocess.Popen[bytes]],
|
244
|
+
timeout_in_sec: int = 1,
|
245
|
+
raise_on_abnormal_exit: bool = True,
|
246
|
+
) -> list[int]:
|
247
|
+
"""
|
248
|
+
Check if processes have ended properly. Raise errors immediately
|
249
|
+
if any process has ended with a non-zero return code.
|
250
|
+
Return a list of process indices that have not ended yet.
|
251
|
+
"""
|
252
|
+
running = []
|
253
|
+
start_time = time.time()
|
254
|
+
for i, process in enumerate(processes):
|
255
|
+
try:
|
256
|
+
current_time = time.time()
|
257
|
+
elapsed_time = current_time - start_time
|
258
|
+
# The processes are running in parallel. No need to wait for
|
259
|
+
# `timeout_in_sec` for each process. Only count the remaining ones.
|
260
|
+
wait_in_sec = max(0, timeout_in_sec - elapsed_time)
|
261
|
+
return_code = process.wait(timeout=wait_in_sec)
|
262
|
+
if return_code != 0:
|
263
|
+
error_message: str = (
|
264
|
+
f"Process[{i}] {process.pid} exited with "
|
265
|
+
f"return code {return_code}. Command:\n "
|
266
|
+
f"{process.args!r}"
|
267
|
+
)
|
268
|
+
if raise_on_abnormal_exit:
|
269
|
+
raise RuntimeError(error_message)
|
270
|
+
else:
|
271
|
+
logger.error(error_message)
|
272
|
+
except subprocess.TimeoutExpired:
|
273
|
+
running.append(i)
|
274
|
+
|
275
|
+
return running
|
276
|
+
|
277
|
+
|
278
|
+
class PoolDeviceMeshProvider:
    """Hands out ``SimMesh`` instances for the mesh worlds tracked in a shared state map."""

    def __init__(
        self,
        hosts_per_mesh: int,
        gpus_per_host: int,
        client_proc: Proc,
        mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]],
        simulator_client: SimulatorClient,
    ) -> None:
        """
        Args:
            hosts_per_mesh: number of hosts in every mesh.
            gpus_per_host: number of gpus per host.
            client_proc: proc on which the client actors are created.
            mesh_world_state: mapping from MeshWorld to its mesh; an entry
                with a falsy value is treated as not created yet.
            simulator_client: simulator client passed to every ``SimMesh``.
        """
        self._hosts_per_mesh: int = hosts_per_mesh
        self._gpus_per_host: int = gpus_per_host
        self._client_proc: Proc = client_proc
        self._root_client_actor: ClientActor = ClientActor(
            proc=client_proc, actor_name="root_client"
        )
        self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state
        self._simulator_client: SimulatorClient = simulator_client

    def new_mesh(self, timeout_in_sec: Optional[int] = None) -> DeviceMesh:
        """
        Creates a mesh for the first mesh world that has none yet and records it.

        Args:
            timeout_in_sec: not used by this implementation.

        Returns:
            The newly created ``SimMesh``.
        """
        # Pick the first world whose state entry is still falsy.
        pending = None
        for world, existing in self._mesh_world_state.items():
            if not existing:
                pending = world
                break
        assert pending is not None, "No mesh world to create"

        worker_world, controller_id = pending
        hosts, gpus = self._hosts_per_mesh, self._gpus_per_host

        # Create a new device mesh
        child_actor = ClientActor.new_with_parent(
            self._client_proc, self._root_client_actor.actor_id
        )
        backend = RustController(
            proc=self._client_proc,
            client_actor=child_actor,
            controller_id=controller_id,
            worker_world_name=worker_world,
        )
        mesh_client = Client(backend, hosts * gpus, gpus)
        mesh = SimMesh(
            mesh_client,
            NDSlice(offset=0, sizes=[hosts, gpus], strides=[gpus, 1]),
            ("host", "gpu"),
            self._simulator_client,
            worker_world,
        )
        self._mesh_world_state[pending] = mesh

        return mesh
|
336
|
+
|
337
|
+
|
338
|
+
def sim_mesh_provider(
    num_meshes: int, hosts_per_mesh: int, gpus_per_host: int
) -> Tuple[PoolDeviceMeshProvider, Bootstrap]:
    """
    Bootstraps the simulator and returns a provider for its meshes.

    Args:
        num_meshes: number of meshes spawned by the ``Bootstrap``.
        hosts_per_mesh: number of hosts in every mesh the provider creates.
        gpus_per_host: number of gpus per host.

    Returns:
        A ``(provider, bootstrap)`` pair sharing one mesh world state map.
    """
    state = {}
    # NOTE(review): world_size is left at the Bootstrap default here, while
    # sim_mesh passes hosts * gpus_per_host -- confirm this is intended.
    sim_bootstrap = Bootstrap(num_meshes, state)

    proc: Proc = init_proc(
        proc_id="client[0]",
        bootstrap_addr=sim_bootstrap.client_bootstrap_addr,
        timeout=SIM_MESH_CLIENT_TIMEOUT,  # unused
        supervision_update_interval=SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL,
        listen_addr=sim_bootstrap.client_listen_addr,
    )
    provider = PoolDeviceMeshProvider(
        hosts_per_mesh,
        gpus_per_host,
        proc,
        state,
        sim_bootstrap._simulator_client,
    )
    return provider, sim_bootstrap
|