torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +74 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +198 -0
- monarch/actor_mesh.py +692 -0
- monarch/allocator.py +62 -0
- monarch/bootstrap_main.py +75 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +69 -0
- monarch/cached_remote_function.py +257 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +646 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +443 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +572 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +304 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +204 -0
- monarch/common/stream.py +111 -0
- monarch/common/tensor.py +793 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/fetch.py +55 -0
- monarch/future.py +25 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/proc_mesh.py +188 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +190 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +357 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +189 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +57 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +121 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +139 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +112 -0
- tests/test_alloc.py +25 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +835 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +372 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +182 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
- torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
@@ -0,0 +1,249 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import math
|
8
|
+
from functools import cache
|
9
|
+
from logging import getLogger
|
10
|
+
from timeit import default_timer as timer
|
11
|
+
|
12
|
+
from .schedule_ir import (
|
13
|
+
_Action,
|
14
|
+
_add_send_recv,
|
15
|
+
_ComputationType,
|
16
|
+
_dump_csv,
|
17
|
+
_format_pipeline_order,
|
18
|
+
_merge_bw,
|
19
|
+
BACKWARD,
|
20
|
+
FORWARD,
|
21
|
+
FULL_BACKWARD,
|
22
|
+
)
|
23
|
+
|
24
|
+
# Module-level logger named after this module rather than the root logger, so that
# applications can tune or silence scheduler output through the logging hierarchy.
logger = getLogger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
def get_stage_str(model_chunk_index, training_stage, mb_index):
    """Render one schedule step as the string form of an ``_Action``.

    Args:
        model_chunk_index: index of the model chunk on the current rank.
        training_stage: computation-type token parsed by
            ``_ComputationType.from_str`` (this module passes "F", "B", "W", "BW").
        mb_index: microbatch index.

    Returns:
        ``str()`` of the corresponding ``_Action``.
    """
    computation = _ComputationType.from_str(training_stage)
    action = _Action(model_chunk_index, computation, mb_index)
    return str(action)
|
30
|
+
|
31
|
+
|
32
|
+
def _next_fwd_chunk(chunk_index, num_model_chunks):
    """Return the chunk the forward pass moves to after a round (wraps back to 0)."""
    return chunk_index + 1 if chunk_index < num_model_chunks - 1 else 0


def _next_bwd_chunk(chunk_index, num_model_chunks):
    """Return the chunk the backward pass moves to after a round (wraps to the last chunk)."""
    return chunk_index - 1 if chunk_index > 0 else num_model_chunks - 1


def get_dora_schedule(
    num_model_chunks,
    pipeline_parallel_size,
    num_round,
    num_microbatch_per_round,
    zero_bubble,
    total_num_microbatches,
    num_microbatches,
    dfs=False,
    prefetch_weight_latency=1.0,
    enable_weight_sharding_in_pp=False,
    enable_wgrad_sharding_in_pp=False,
    dump_scheduler_ir=True,
):
    """Generate the DORA pipeline schedule for every pipeline-parallel rank.

    Each rank's schedule has three phases:
      1. warmup: forward ("F") steps only;
      2. steady: one forward followed by one backward per iteration. The first time
         the forward pass reaches model chunk 1, ``num_additional_1b1w`` extra
         backward+weight ("B"/"W") pairs are emitted first. With ``zero_bubble``,
         the first ``num_1f1b_microbatches`` weight steps are deferred to the end of
         the rank's schedule;
      3. cooldown: the remaining backward+weight pairs.

    Actions are then re-indexed to global stage indices
    (``chunk * pipeline_parallel_size + rank``) and post-processed with
    ``_merge_bw``. Send/recv actions are added by ``_add_send_recv``.

    Args:
        num_model_chunks: model chunks per rank.
        pipeline_parallel_size: number of pipeline ranks.
        num_round: accepted but not used by this implementation.
        num_microbatch_per_round: microbatches processed on one chunk before the
            schedule moves to the next chunk.
        zero_bubble: if True, steady-phase backward is emitted as "B" and W steps
            may be deferred; otherwise it is emitted as "BW".
        total_num_microbatches: number of forward (and backward) steps per rank.
        num_microbatches: compared with ``pipeline_parallel_size`` to size the
            zero-bubble deferral window.
        dfs: if True, disables W deferral and the additional 1B1W fill steps.
        prefetch_weight_latency: accepted but not used by this implementation.
        enable_weight_sharding_in_pp: accepted; does not change the result.
        enable_wgrad_sharding_in_pp: accepted; does not change the result.
        dump_scheduler_ir: if True (default), write ``lowered_compute.log``,
            ``lowered_compute.csv``, ``lowered_comms.log`` and
            ``lowered_compute_with_send_recv.csv`` into the current directory.

    Returns:
        dict mapping rank -> list of lowered actions, as returned by
        ``_add_send_recv``.
    """
    start_time = timer()

    # Phase sizes for each rank.
    num_warmup_microbatches_list = []
    num_1f1b_microbatches_list = []
    num_additional_1b1w_list = []
    for pipeline_parallel_rank in range(pipeline_parallel_size):
        # The number of microbatches that last pipeline stage run before 1f1b.
        num_warmup_microbatches = (num_model_chunks - 1) * num_microbatch_per_round
        # From last PP stage up, each rank will be 2 more than the previous one.
        num_warmup_microbatches += (
            pipeline_parallel_size - pipeline_parallel_rank - 1
        ) * 2
        num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
        # The number of 1f1b for zero bubble schedule
        if num_microbatches == pipeline_parallel_size:
            num_1f1b_microbatches = pipeline_parallel_rank
        else:
            num_1f1b_microbatches = 2 * pipeline_parallel_rank
        num_additional_1b1w = max(
            int(math.ceil((pipeline_parallel_size - 4) / 2)) - pipeline_parallel_rank,
            0,
        )
        if dfs:
            num_1f1b_microbatches = 0
            num_additional_1b1w = 0

        num_warmup_microbatches_list.append(num_warmup_microbatches)
        num_1f1b_microbatches_list.append(num_1f1b_microbatches)
        num_additional_1b1w_list.append(num_additional_1b1w)

    # Emit per-rank action strings (stage index = local model chunk index).
    schedules = []
    for pipeline_parallel_rank in range(pipeline_parallel_size):
        s = []
        fwd_mb_index_list = [0] * num_model_chunks
        bwd_mb_index_list = [0] * num_model_chunks
        fwd_model_chunk_index = 0
        bwd_model_chunk_index = num_model_chunks - 1
        weight_store = []
        num_warmup_microbatches = num_warmup_microbatches_list[pipeline_parallel_rank]
        num_1f1b_microbatches = num_1f1b_microbatches_list[pipeline_parallel_rank]
        num_additional_1b1w = num_additional_1b1w_list[pipeline_parallel_rank]
        fill_1b1w = False

        # Warmup: forwards only.
        for _ in range(num_warmup_microbatches):
            fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
            s.append(get_stage_str(fwd_model_chunk_index, "F", fwd_mb_index))
            fwd_mb_index_list[fwd_model_chunk_index] += 1
            if fwd_mb_index_list[fwd_model_chunk_index] % num_microbatch_per_round == 0:
                fwd_model_chunk_index = _next_fwd_chunk(
                    fwd_model_chunk_index, num_model_chunks
                )

        # Steady state: 1f1b and 1f1b1w.
        for i in range(total_num_microbatches - num_warmup_microbatches):
            if fwd_model_chunk_index == 1 and not fill_1b1w:
                # Additional 1b1w to fill before fwd (emitted only once).
                fill_1b1w = True
                for _ in range(num_additional_1b1w):
                    bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
                    s.append(get_stage_str(bwd_model_chunk_index, "B", bwd_mb_index))
                    s.append(get_stage_str(bwd_model_chunk_index, "W", bwd_mb_index))
                    bwd_mb_index_list[bwd_model_chunk_index] += 1
                    if (
                        bwd_mb_index_list[bwd_model_chunk_index]
                        % num_microbatch_per_round
                        == 0
                    ):
                        bwd_model_chunk_index = _next_bwd_chunk(
                            bwd_model_chunk_index, num_model_chunks
                        )
            fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
            # Snapshot the backward microbatch before the forward step advances.
            bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
            s.append(get_stage_str(fwd_model_chunk_index, "F", fwd_mb_index))
            fwd_mb_index_list[fwd_model_chunk_index] += 1
            if fwd_mb_index_list[fwd_model_chunk_index] % num_microbatch_per_round == 0:
                fwd_model_chunk_index = _next_fwd_chunk(
                    fwd_model_chunk_index, num_model_chunks
                )
            s.append(
                get_stage_str(
                    bwd_model_chunk_index, "B" if zero_bubble else "BW", bwd_mb_index
                )
            )
            weight_action = get_stage_str(bwd_model_chunk_index, "W", bwd_mb_index)
            if zero_bubble and i < num_1f1b_microbatches:
                # Deferred to the end of this rank's schedule.
                weight_store.append(weight_action)
            else:
                s.append(weight_action)
            bwd_mb_index_list[bwd_model_chunk_index] += 1
            if bwd_mb_index_list[bwd_model_chunk_index] % num_microbatch_per_round == 0:
                bwd_model_chunk_index = _next_bwd_chunk(
                    bwd_model_chunk_index, num_model_chunks
                )

        # Cooldown: remaining backwards (fewer if the 1b1w fill already ran).
        num_cooldown = (
            num_warmup_microbatches - num_additional_1b1w
            if fill_1b1w
            else num_warmup_microbatches
        )
        for _ in range(num_cooldown):
            bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
            s.append(get_stage_str(bwd_model_chunk_index, "B", bwd_mb_index))
            s.append(get_stage_str(bwd_model_chunk_index, "W", bwd_mb_index))
            bwd_mb_index_list[bwd_model_chunk_index] += 1
            if bwd_mb_index_list[bwd_model_chunk_index] % num_microbatch_per_round == 0:
                bwd_model_chunk_index = _next_bwd_chunk(
                    bwd_model_chunk_index, num_model_chunks
                )
        s.extend(weight_store)
        schedules.append(s)

    # Re-index local chunk indices to global stage indices.
    compute_schedules = {}
    for rank, action_strs in enumerate(schedules):
        compute_schedules[rank] = []
        for action_str in action_strs:
            action = _Action.from_str(action_str)
            stage_index = action.stage_index * pipeline_parallel_size + rank
            compute_schedules[rank].append(
                _Action(stage_index, action.computation_type, action.microbatch_index)
            )

    for rank in compute_schedules:
        compute_schedules[rank] = _merge_bw(compute_schedules[rank])

    if dump_scheduler_ir:
        compute_str = _format_pipeline_order(compute_schedules)
        with open("lowered_compute.log", "w") as logf:
            logf.write(compute_str)
        _dump_csv(compute_schedules, "lowered_compute.csv")

    lowered_comm_schedule = _add_send_recv(
        compute_schedules,
        stage_to_rank=lambda chunk_index: chunk_index % pipeline_parallel_size,
        num_stages=num_model_chunks * pipeline_parallel_size,
    )

    comms_str = _format_pipeline_order(lowered_comm_schedule)
    if dump_scheduler_ir:
        with open("lowered_comms.log", "w") as logf:
            logf.write(comms_str)
        _dump_csv(lowered_comm_schedule, "lowered_compute_with_send_recv.csv")
    logger.debug("---------- lowered IR\n%s----------", comms_str)

    generation_time = timer() - start_time
    logger.info("schedule generation took %.6f seconds", generation_time)
    return lowered_comm_schedule
|
218
|
+
|
219
|
+
|
220
|
+
# TODO - replace bfs / dfs functions below with new IR generators


def _get_dora_dfs_schedule(*args, **kwargs):
    """DFS variant of the DORA schedule: ``get_dora_schedule`` with ``dfs=True``."""
    return get_dora_schedule(*args, **kwargs, dfs=True)


# Schedule generators selectable by name in generate_schedule.
# Not yet enabled: "dora" (get_dora_schedule), "zbv" (get_zbv_schedule),
# "zbw" (get_zbw_schedule).
ir_schedules = {
    "dora-dfs": _get_dora_dfs_schedule,
}

# Whether each named schedule is a zero-bubble schedule.
# Not yet enabled: "dora", "zbv", "zbw" (all zero-bubble).
is_zero_bubble = {
    "dora-dfs": True,
}
|
234
|
+
|
235
|
+
|
236
|
+
@cache
def generate_schedule(name: str, *args, **kwargs):
    """Build the named schedule and derive which rank owns each stage.

    Results are memoized with ``functools.cache``. All arguments must therefore be
    hashable, and repeated calls with equal arguments return the same objects.

    Returns:
        ``(schedules, stage_to_rank)``. ``schedules`` is the generator's rank ->
        actions mapping. ``stage_to_rank`` maps every stage index that has a
        FORWARD, BACKWARD or FULL_BACKWARD action to the rank running it (the
        last one seen wins).
    """
    assert name in ir_schedules, f"{name} is not a supported schedule type"
    schedules = ir_schedules[name](*args, **kwargs)
    owning_types = (FORWARD, BACKWARD, FULL_BACKWARD)
    stage_to_rank = {
        action.stage_index: rank
        for rank, actions in schedules.items()
        for action in actions
        if action.computation_type in owning_types
    }
    return schedules, stage_to_rank
|
monarch/proc_mesh.py
ADDED
@@ -0,0 +1,188 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import sys
|
8
|
+
|
9
|
+
from typing import Any, cast, Optional, Type, TypeVar
|
10
|
+
|
11
|
+
import monarch
|
12
|
+
from monarch import ActorFuture as Future
|
13
|
+
|
14
|
+
from monarch._rust_bindings.hyperactor_extension.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension
|
15
|
+
Alloc,
|
16
|
+
AllocConstraints,
|
17
|
+
AllocSpec,
|
18
|
+
)
|
19
|
+
from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
|
20
|
+
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
|
21
|
+
from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef
|
22
|
+
|
23
|
+
from monarch.common._device_utils import _local_device_count
|
24
|
+
from monarch.rdma import RDMAManager
|
25
|
+
|
26
|
+
T = TypeVar("T")

# IN_PAR is True when the `__manifest__` module (with `fbmake`) is importable,
# presumably meaning we run from inside a PAR archive — TODO confirm.
try:
    from __manifest__ import fbmake  # noqa
except ImportError:
    IN_PAR = False
else:
    IN_PAR = True
|
33
|
+
|
34
|
+
|
35
|
+
async def _allocate_nonblocking(alloc: Alloc) -> "ProcMesh":
    """Await a hyperactor proc mesh allocation from *alloc* and wrap it."""
    hy_mesh = await HyProcMesh.allocate_nonblocking(alloc)
    return ProcMesh(hy_mesh)
|
37
|
+
|
38
|
+
|
39
|
+
def _allocate_blocking(alloc: Alloc) -> "ProcMesh":
    """Allocate a hyperactor proc mesh from *alloc* synchronously and wrap it."""
    hy_mesh = HyProcMesh.allocate_blocking(alloc)
    return ProcMesh(hy_mesh)
|
41
|
+
|
42
|
+
|
43
|
+
class ProcMesh:
    """Python wrapper around a hyperactor ``ProcMesh`` that spawns actor meshes.

    ``spawn`` and ``from_alloc`` return a ``Future`` built from a coroutine factory
    and a blocking factory, so callers can use either style.
    """

    def __init__(self, hy_proc_mesh: HyProcMesh) -> None:
        self._proc_mesh = hy_proc_mesh
        self._mailbox: Mailbox = self._proc_mesh.client
        # Every proc mesh gets an RDMA manager actor. NOTE(review): this spawn is
        # blocking even when the mesh was created via the async path.
        self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager)

    def spawn(self, name: str, Class: Type[T], *args: Any, **kwargs: Any) -> Future[T]:
        """Spawn an actor mesh of ``Class`` named *name* on this proc mesh.

        ``args``/``kwargs`` are forwarded to the actor constructor.
        Raises ValueError if ``Class`` does not subclass ``Actor``.
        """
        return Future(
            lambda: self._spawn_nonblocking(name, Class, *args, **kwargs),
            lambda: self._spawn_blocking(name, Class, *args, **kwargs),
        )

    @classmethod
    def from_alloc(cls, alloc: Alloc) -> Future["ProcMesh"]:
        """Create a ProcMesh from an existing allocation."""
        return Future(
            lambda: _allocate_nonblocking(alloc),
            lambda: _allocate_blocking(alloc),
        )

    @staticmethod
    def _check_actor_class(Class: type) -> None:
        """Raise ValueError unless ``Class`` subclasses ``Actor``."""
        if not issubclass(Class, Actor):
            raise ValueError(
                f"{Class} must subclass monarch.service.Actor to spawn it."
            )

    def _wrap_actor_mesh(
        self,
        Class: Type[T],
        actor_mesh: Any,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> T:
        """Wrap a spawned hyperactor actor mesh in an ``ActorMeshRef`` and construct it."""
        service = ActorMeshRef(
            Class,
            _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh),
            self._mailbox,
        )
        # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by
        # doing `ActorMeshRef(Class, actor_handle)` but not calling _create.
        service._create(args, kwargs)
        return cast(T, service)

    def _spawn_blocking(
        self, name: str, Class: Type[T], *args: Any, **kwargs: Any
    ) -> T:
        self._check_actor_class(Class)
        actor_mesh = self._proc_mesh.spawn_blocking(name, _Actor)
        return self._wrap_actor_mesh(Class, actor_mesh, args, kwargs)

    def __repr__(self) -> str:
        return repr(self._proc_mesh)

    def __str__(self) -> str:
        return str(self._proc_mesh)

    async def _spawn_nonblocking(
        self, name: str, Class: Type[T], *args: Any, **kwargs: Any
    ) -> T:
        self._check_actor_class(Class)
        actor_mesh = await self._proc_mesh.spawn_nonblocking(name, _Actor)
        return self._wrap_actor_mesh(Class, actor_mesh, args, kwargs)
|
105
|
+
|
106
|
+
|
107
|
+
async def local_proc_mesh_nonblocking(
    *, gpus: Optional[int] = None, hosts: int = 1
) -> ProcMesh:
    """Asynchronously allocate a ProcMesh using ``monarch.LocalAllocator``.

    Args:
        gpus: passed as ``AllocSpec(gpus=...)``; defaults to ``_local_device_count()``.
        hosts: passed as ``AllocSpec(hosts=...)``.
    """
    num_gpus = _local_device_count() if gpus is None else gpus
    spec = AllocSpec(AllocConstraints(), gpus=num_gpus, hosts=hosts)
    alloc = await monarch.LocalAllocator().allocate(spec)
    return await ProcMesh.from_alloc(alloc)
|
116
|
+
|
117
|
+
|
118
|
+
def local_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh:
    """Blocking counterpart of ``local_proc_mesh_nonblocking``.

    Args:
        gpus: passed as ``AllocSpec(gpus=...)``; defaults to ``_local_device_count()``.
        hosts: passed as ``AllocSpec(hosts=...)``.
    """
    num_gpus = _local_device_count() if gpus is None else gpus
    spec = AllocSpec(AllocConstraints(), gpus=num_gpus, hosts=hosts)
    alloc = monarch.LocalAllocator().allocate(spec).get()
    return ProcMesh.from_alloc(alloc).get()
|
125
|
+
|
126
|
+
|
127
|
+
def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]:
    """Return a Future for a locally allocated ProcMesh (usable sync or async)."""

    def _make_async():
        return local_proc_mesh_nonblocking(gpus=gpus, hosts=hosts)

    def _make_blocking():
        return local_proc_mesh_blocking(gpus=gpus, hosts=hosts)

    return Future(_make_async, _make_blocking)
|
132
|
+
|
133
|
+
|
134
|
+
_BOOTSTRAP_MAIN = "monarch.bootstrap_main"
|
135
|
+
|
136
|
+
|
137
|
+
def _get_bootstrap_args() -> tuple[str, Optional[list[str]], dict[str, str]]:
|
138
|
+
if IN_PAR:
|
139
|
+
cmd = sys.argv[0]
|
140
|
+
args = None
|
141
|
+
env = {
|
142
|
+
"PAR_MAIN_OVERRIDE": _BOOTSTRAP_MAIN,
|
143
|
+
}
|
144
|
+
else:
|
145
|
+
cmd = sys.executable
|
146
|
+
args = ["-m", _BOOTSTRAP_MAIN]
|
147
|
+
env = {}
|
148
|
+
|
149
|
+
return cmd, args, env
|
150
|
+
|
151
|
+
|
152
|
+
async def proc_mesh_nonblocking(
    *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None
) -> ProcMesh:
    """Asynchronously allocate a ProcMesh via ``monarch.ProcessAllocator``.

    Args:
        gpus: passed as ``AllocSpec(gpus=...)``; defaults to ``_local_device_count()``.
        hosts: passed as ``AllocSpec(hosts=...)``.
        env: extra environment variables for the launched processes. The caller's
            dict is never modified. Variables from ``_get_bootstrap_args()`` and
            ``HYPERACTOR_MANAGED_SUBPROCESS=1`` take precedence over it.
    """
    if gpus is None:
        gpus = _local_device_count()
    spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
    cmd, args, base_env = _get_bootstrap_args()
    # Build a fresh dict so the caller's `env` is not mutated.
    proc_env = {**(env or {}), **base_env, "HYPERACTOR_MANAGED_SUBPROCESS": "1"}
    allocator = monarch.ProcessAllocator(cmd, args, proc_env)
    alloc = await allocator.allocate(spec)
    return await ProcMesh.from_alloc(alloc)
|
165
|
+
|
166
|
+
|
167
|
+
def proc_mesh_blocking(
    *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None
) -> ProcMesh:
    """Blocking: allocate a ProcMesh via ``monarch.ProcessAllocator``.

    Args:
        gpus: passed as ``AllocSpec(gpus=...)``; defaults to ``_local_device_count()``.
        hosts: passed as ``AllocSpec(hosts=...)``.
        env: extra environment variables for the launched processes. The caller's
            dict is never modified. Variables from ``_get_bootstrap_args()`` and
            ``HYPERACTOR_MANAGED_SUBPROCESS=1`` take precedence over it.
    """
    if gpus is None:
        gpus = _local_device_count()
    spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
    cmd, args, base_env = _get_bootstrap_args()
    # Build a fresh dict so the caller's `env` is not mutated.
    proc_env = {**(env or {}), **base_env, "HYPERACTOR_MANAGED_SUBPROCESS": "1"}
    allocator = monarch.ProcessAllocator(cmd, args, proc_env)
    alloc = allocator.allocate(spec).get()
    return ProcMesh.from_alloc(alloc).get()
|
180
|
+
|
181
|
+
|
182
|
+
def proc_mesh(
    *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None
) -> Future[ProcMesh]:
    """Return a Future for a process-backed ProcMesh (usable sync or async)."""

    def _make_async():
        return proc_mesh_nonblocking(gpus=gpus, hosts=hosts, env=env)

    def _make_blocking():
        return proc_mesh_blocking(gpus=gpus, hosts=hosts, env=env)

    return Future(_make_async, _make_blocking)
|
monarch/profiler.py
ADDED
@@ -0,0 +1,160 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import itertools
|
9
|
+
import os
|
10
|
+
from dataclasses import dataclass
|
11
|
+
from functools import partial
|
12
|
+
from pathlib import Path
|
13
|
+
from typing import Any, Dict, NamedTuple, Optional, Tuple
|
14
|
+
|
15
|
+
import torch
|
16
|
+
from monarch.common.remote import remote
|
17
|
+
from monarch.remote_class import ControllerRemoteClass, WorkerRemoteClass
|
18
|
+
|
19
|
+
|
20
|
+
class Schedule(NamedTuple):
    """Profiler schedule accepted by ``profile(schedule=...)``.

    On the workers its fields are expanded by name into
    ``torch.profiler.schedule(**schedule._asdict())``.
    """

    wait: int
    warmup: int
    active: int
    # Optional fields, defaulting to 0.
    repeat: int = 0
    skip_first: int = 0
|
26
|
+
|
27
|
+
|
28
|
+
class profile:
    """
    Remote-invocable wrapper around ``torch.profiler.profile()``.

    It differs from ``torch.profiler.profile()`` in two ways:
    1) ``on_trace_ready`` must be a string: the directory the traces are saved to.
    2) ``schedule``, if given, must be a ``monarch.profiler.Schedule``.
    """

    PATH_KEY = "on_trace_ready"
    _counter = itertools.count()

    def __init__(self, *args, **kwargs) -> None:
        trace_dir = kwargs.get(self.PATH_KEY, None)
        assert isinstance(trace_dir, str), (
            f"{self.PATH_KEY} must be passed and must be a string to represent the "
            "path to save the profiler."
        )
        requested_schedule = kwargs.get("schedule", None)
        assert (
            isinstance(requested_schedule, Schedule) or requested_schedule is None
        ), "schedule can only be monarch.profiler.Schedule or None."
        # Key for this profiler in the worker-side `_profilers` registry.
        self.id = next(self._counter)
        _profiler_controller_init(self.id, *args, **kwargs)

    def __enter__(self) -> "profile":
        _profiler_controller_enter(self.id)
        return self

    def __exit__(self, *args, **kwargs) -> None:
        # Exception details are not forwarded to the workers.
        _profiler_controller_exit(self.id)

    def step(self) -> None:
        """Advance the remote profiler by one step."""
        _profiler_controller_step(self.id)
|
61
|
+
|
62
|
+
|
63
|
+
@dataclass
|
64
|
+
class _Profiler:
|
65
|
+
args: Tuple[Any, ...]
|
66
|
+
kwargs: Dict[str, Any]
|
67
|
+
profiler: Optional[torch.profiler.profile] = None
|
68
|
+
|
69
|
+
|
70
|
+
_profilers: Dict[int, _Profiler] = {}
|
71
|
+
|
72
|
+
|
73
|
+
def _profiler_init(ident, *args, **kwargs) -> None:
    """Worker side of ``profile.__init__``: store the arguments under *ident*.

    The torch profiler itself is only created in ``_profiler_enter``.
    """
    assert (
        ident not in _profilers
    ), f"Initializing an already existing profiler, {ident=}"
    # It's unclear why we cannot create the profiler here. Even though
    # the thread is the same, profiler complains thread id mismatch.
    _profilers[ident] = _Profiler(args=args, kwargs=kwargs)
|
81
|
+
|
82
|
+
|
83
|
+
def _profiler_enter(ident, *args, **kwargs) -> None:
    """Worker side of ``profile.__enter__``: build and start the torch profiler.

    The stored kwargs are rewritten in place. The ``on_trace_ready`` directory
    string is replaced by a callback that writes ``trace_<rank>.json`` there, and
    a ``Schedule`` is converted into a ``torch.profiler.schedule``.
    """

    def _export_trace(prof, dir_path):
        target_dir = Path(dir_path).absolute()
        os.makedirs(target_dir, exist_ok=True)
        # This is not a synchronized call, so it is okay to call without
        # device mesh.
        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
        else:
            rank = 0
        prof.export_chrome_trace(f"{target_dir}/trace_{rank}.json")

    state = _profilers[ident]
    trace_dir = state.kwargs[profile.PATH_KEY]
    state.kwargs[profile.PATH_KEY] = partial(_export_trace, dir_path=trace_dir)
    requested_schedule = state.kwargs.get("schedule", None)
    if requested_schedule is not None:
        state.kwargs["schedule"] = torch.profiler.schedule(
            **requested_schedule._asdict()
        )
    state.profiler = torch.profiler.profile(*state.args, **state.kwargs)
    state.profiler.__enter__()
|
102
|
+
|
103
|
+
|
104
|
+
def _profiler_exit(ident, *args, **kwargs) -> None:
    """Worker side of ``profile.__exit__``: stop the profiler, then drop its entry."""
    torch_profiler = _profilers[ident].profiler
    assert torch_profiler is not None
    torch_profiler.__exit__(None, None, None)
    del _profilers[ident]
|
109
|
+
|
110
|
+
|
111
|
+
def _profiler_step(ident, *args, **kwargs) -> None:
    """Worker side of ``profile.step``: advance the entered torch profiler."""
    torch_profiler = _profilers[ident].profiler
    assert torch_profiler is not None
    torch_profiler.step()
|
115
|
+
|
116
|
+
|
117
|
+
# Controller-side handles for the worker functions above, referenced by dotted
# path; `propagate="inspect"` is passed straight through to `remote`.
_profiler_controller_init = remote("monarch.profiler._profiler_init", propagate="inspect")
_profiler_controller_enter = remote("monarch.profiler._profiler_enter", propagate="inspect")
_profiler_controller_exit = remote("monarch.profiler._profiler_exit", propagate="inspect")
_profiler_controller_step = remote("monarch.profiler._profiler_step", propagate="inspect")
|
132
|
+
|
133
|
+
|
134
|
+
class record_function(ControllerRemoteClass):
    """
    Controller-side counterpart of ``torch.profiler.record_function()``.

    It is constructed against ``monarch.profiler.WorkerRecordFunction``;
    ``__enter__``/``__exit__`` are decorated with
    ``ControllerRemoteClass.remote_method``.
    """

    def __init__(self, name: str, args: Optional[str] = None) -> None:
        super().__init__("monarch.profiler.WorkerRecordFunction", name, args)

    @ControllerRemoteClass.remote_method
    def __enter__(self) -> "record_function":
        return self

    @ControllerRemoteClass.remote_method
    def __exit__(self, *args, **kwargs) -> None:
        return None
|
150
|
+
|
151
|
+
|
152
|
+
class WorkerRecordFunction(WorkerRemoteClass):
    """Worker-side holder of a real ``torch.profiler.record_function``."""

    def __init__(self, *args, **kwargs) -> None:
        self._record_function = torch.profiler.record_function(*args, **kwargs)

    def __enter__(self) -> None:
        # The wrapped __enter__ result is discarded.
        self._record_function.__enter__()

    def __exit__(self, *args, **kwargs) -> None:
        self._record_function.__exit__(*args, **kwargs)
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import os
|
9
|
+
import subprocess
|
10
|
+
from time import sleep
|
11
|
+
from typing import Optional, TYPE_CHECKING
|
12
|
+
|
13
|
+
import monarch_supervisor
|
14
|
+
from monarch.common._device_utils import _local_device_count
|
15
|
+
from monarch.common.fake import fake_call
|
16
|
+
from monarch.common.invocation import DeviceException, RemoteException
|
17
|
+
from monarch.world_mesh import world_mesh
|
18
|
+
from monarch_supervisor import Context, HostConnected
|
19
|
+
from monarch_supervisor.python_executable import PYTHON_EXECUTABLE
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
from monarch.common.device_mesh import DeviceMesh
|
23
|
+
|
24
|
+
|
25
|
+
class PythonLocalContext:
    """A local supervisor ``Context`` plus ``N`` host-manager subprocesses.

    After construction:
        ctx: the supervisor context.
        host_managers: ``Popen`` handles for the host-manager processes.
        hosts: senders of the ``HostConnected`` messages, one per host.
    """

    def __init__(self, N: int):
        # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later
        fake_call(lambda: 0)

        self.ctx = ctx = Context()
        ctx.request_hosts(N)

        # we want ctx to start its listener threads
        # before creating the hosts because
        # initialization will happen faster in this case
        sleep(0)
        supervisor_addr = f"tcp://127.0.0.1:{ctx.port}"

        env = {
            **os.environ,
            "TORCH_SUPERVISOR_HEARTBEAT_INTERVAL": str(
                monarch_supervisor.HEARTBEAT_INTERVAL
            ),
            # This is needed to avoid a hard failure in ncclx when we do not
            # have backend topology info (eg. on RE).
            "NCCL_IGNORE_TOPO_LOAD_FAILURE": "true",
        }

        # start_new_session=True, because we want the host managers to be able to kill
        # any worker processes before they exit, even if the supervisor crashes, or we ctrl-c
        # it in testing.
        self.host_managers = [
            subprocess.Popen(
                [
                    PYTHON_EXECUTABLE,
                    "-m",
                    "monarch_supervisor.host",
                    supervisor_addr,
                ],
                env=env,
                start_new_session=True,
            )
            for _ in range(N)
        ]
        try:
            connections = ctx.messagefilter(HostConnected)
            self.hosts = [connections.recv(timeout=30).sender for _ in range(N)]
        except BaseException:
            # The host managers run in their own sessions, so nothing else will
            # reap them if connecting fails (e.g. recv times out): kill them
            # before propagating the error.
            for host_manager in self.host_managers:
                host_manager.kill()
            for host_manager in self.host_managers:
                host_manager.wait()
            raise

    def shutdown(self):
        """Shut down the supervisor context, then wait up to 10 s for each host manager."""
        self.ctx.shutdown()
        for host_manager in self.host_managers:
            host_manager.wait(timeout=10)
|
72
|
+
|
73
|
+
|
74
|
+
def python_local_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> "DeviceMesh":
    """
    Creates a local device mesh with the given number of hosts and gpus per host.
    Easy way to use PythonLocalContext.

    Args:
        gpus (Optional[int]): number of gpus per host.
            Default: the number of GPUs this machine has.

        hosts (int): number of hosts, primarily used for simulating multiple machines locally.
            Default: 1

    Example::
        local_mesh = python_local_mesh(gpus=2)
        with local_mesh.activate():
            x = torch.rand(3, 4)
            local_tensor = fetch_shard(x).result()

        # Cleanly shut down the local mesh and exit.
        local_mesh.exit()
    """
    ctx = PythonLocalContext(hosts)
    if gpus is None:
        gpus = _local_device_count()
    dm = world_mesh(ctx.ctx, ctx.hosts, gpus)

    def _exit(
        error: Optional[RemoteException | DeviceException | Exception] = None,
    ) -> None:
        # Tear down the local supervisor/host managers even if the client
        # shutdown raises, so no host-manager processes are leaked.
        try:
            dm.client.shutdown(True, error)
        finally:
            ctx.shutdown()

    dm.exit = _exit
    return dm
|