torchmonarch-nightly 2025.6.4-cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/parallel/pipelining/scheduler.py ADDED
@@ -0,0 +1,249 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from functools import cache
from logging import getLogger
from timeit import default_timer as timer

from .schedule_ir import (
    _Action,
    _add_send_recv,
    _ComputationType,
    _dump_csv,
    _format_pipeline_order,
    _merge_bw,
    BACKWARD,
    FORWARD,
    FULL_BACKWARD,
)

logger = getLogger()


def get_stage_str(model_chunk_index, training_stage, mb_index):
    ctype = _ComputationType.from_str(training_stage)
    return str(_Action(model_chunk_index, ctype, mb_index))


def get_dora_schedule(
    num_model_chunks,
    pipeline_parallel_size,
    num_round,
    num_microbatch_per_round,
    zero_bubble,
    total_num_microbatches,
    num_microbatches,
    dfs=False,
    prefetch_weight_latency=1.0,
    enable_weight_sharding_in_pp=False,
    enable_wgrad_sharding_in_pp=False,
):
    start_time = timer()
    num_warmup_microbatches_list = []
    num_1f1b_microbatches_list = []
    num_additional_1b1w_list = []
    for pipeline_parallel_rank in range(pipeline_parallel_size):
        num_warmup_microbatches = 0
        # The number of microbatches that last pipeline stage run before 1f1b.
        num_warmup_microbatches += (num_model_chunks - 1) * num_microbatch_per_round
        # From last PP stage up, each rank will be 2 more than the previous one.
        num_warmup_microbatches += (
            pipeline_parallel_size - pipeline_parallel_rank - 1
        ) * 2
        num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
        # The number of 1f1b for zero bubble schedule
        if num_microbatches == pipeline_parallel_size:
            num_1f1b_microbatches = pipeline_parallel_rank
        else:
            num_1f1b_microbatches = 2 * pipeline_parallel_rank
        num_additional_1b1w = max(
            int(math.ceil((pipeline_parallel_size - 4) / 2)) - pipeline_parallel_rank,
            0,
        )
        if dfs:
            num_1f1b_microbatches = 0
            num_additional_1b1w = 0

        num_warmup_microbatches_list.append(num_warmup_microbatches)
        num_1f1b_microbatches_list.append(num_1f1b_microbatches)
        num_additional_1b1w_list.append(num_additional_1b1w)
    schedules = []

    def get_last_pp_rank(i):
        return (i - 1) % pipeline_parallel_size, i - 1 < 0

    def get_next_pp_rank(i):
        return (i + 1) % pipeline_parallel_size, i + 1 >= pipeline_parallel_size

    for pipeline_parallel_rank in range(pipeline_parallel_size):
        s = []
        fwd_mb_index_list = [0 for i in range(num_model_chunks)]
        bwd_mb_index_list = [0 for i in range(num_model_chunks)]
        fwd_model_chunk_index = 0
        bwd_model_chunk_index = num_model_chunks - 1
        weight_store = []
        num_warmup_microbatches = num_warmup_microbatches_list[pipeline_parallel_rank]
        num_1f1b_microbatches = num_1f1b_microbatches_list[pipeline_parallel_rank]
        num_additional_1b1w = num_additional_1b1w_list[pipeline_parallel_rank]
        fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
        bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
        fill_1b1w = False
        for _ in range(num_warmup_microbatches):  # warm up fwd
            fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
            bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
            tmp = get_stage_str(fwd_model_chunk_index, "F", fwd_mb_index)
            s.append(tmp)
            fwd_mb_index_list[fwd_model_chunk_index] += 1
            if fwd_mb_index_list[fwd_model_chunk_index] % num_microbatch_per_round == 0:
                if fwd_model_chunk_index < num_model_chunks - 1:
                    fwd_model_chunk_index += 1
                else:
                    fwd_model_chunk_index = 0
        for i in range(
            total_num_microbatches - num_warmup_microbatches
        ):  # 1f1b and 1f1b1w
            if (
                fwd_model_chunk_index == 1 and not fill_1b1w
            ):  # additional 1b1w to fill before fwd
                fill_1b1w = True
                for _ in range(num_additional_1b1w):
                    bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
                    tmp = get_stage_str(bwd_model_chunk_index, "B", bwd_mb_index)
                    s.append(tmp)
                    tmp = get_stage_str(bwd_model_chunk_index, "W", bwd_mb_index)
                    s.append(tmp)
                    bwd_mb_index_list[bwd_model_chunk_index] += 1
                    if (
                        bwd_mb_index_list[bwd_model_chunk_index]
                        % num_microbatch_per_round
                        == 0
                    ):
                        if bwd_model_chunk_index > 0:
                            bwd_model_chunk_index -= 1
                        else:
                            bwd_model_chunk_index = num_model_chunks - 1
            fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
            bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
            tmp = get_stage_str(fwd_model_chunk_index, "F", fwd_mb_index)
            s.append(tmp)
            fwd_mb_index_list[fwd_model_chunk_index] += 1
            if fwd_mb_index_list[fwd_model_chunk_index] % num_microbatch_per_round == 0:
                if fwd_model_chunk_index < num_model_chunks - 1:
                    fwd_model_chunk_index += 1
                else:
                    fwd_model_chunk_index = 0
            tmp = get_stage_str(
                bwd_model_chunk_index, "B" if zero_bubble else "BW", bwd_mb_index
            )
            s.append(tmp)
            tmp = get_stage_str(bwd_model_chunk_index, "W", bwd_mb_index)
            if zero_bubble and i < num_1f1b_microbatches:
                weight_store.append(tmp)
            else:
                s.append(tmp)
            bwd_mb_index_list[bwd_model_chunk_index] += 1
            if bwd_mb_index_list[bwd_model_chunk_index] % num_microbatch_per_round == 0:
                if bwd_model_chunk_index > 0:
                    bwd_model_chunk_index -= 1
                else:
                    bwd_model_chunk_index = num_model_chunks - 1
        num_cooldown = (
            num_warmup_microbatches - num_additional_1b1w
            if fill_1b1w
            else num_warmup_microbatches
        )
        for _ in range(num_cooldown):  # cooldown bwd
            fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
            bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
            tmp = get_stage_str(bwd_model_chunk_index, "B", bwd_mb_index)
            s.append(tmp)
            tmp = get_stage_str(bwd_model_chunk_index, "W", bwd_mb_index)
            s.append(tmp)
            bwd_mb_index_list[bwd_model_chunk_index] += 1
            if bwd_mb_index_list[bwd_model_chunk_index] % num_microbatch_per_round == 0:
                if bwd_model_chunk_index > 0:
                    bwd_model_chunk_index -= 1
                else:
                    bwd_model_chunk_index = num_model_chunks - 1
        if len(weight_store) > 0:
            s += weight_store
        schedules.append(s)

    compute_schedules = {}
    for rank in range(pipeline_parallel_size):
        compute_schedules[rank] = []
        for action_str in schedules[rank]:
            action = _Action.from_str(action_str)
            stage_index = action.stage_index * pipeline_parallel_size + rank
            action = _Action(
                stage_index, action.computation_type, action.microbatch_index
            )
            compute_schedules[rank].append(action)

    lowered_comm_schedule = compute_schedules
    for rank in lowered_comm_schedule:
        lowered_comm_schedule[rank] = _merge_bw(lowered_comm_schedule[rank])

    dump_scheduler_ir = True
    if dump_scheduler_ir:
        compute_str = _format_pipeline_order(lowered_comm_schedule)
        with open("lowered_compute.log", "w") as logf:
            logf.write(compute_str)
        _dump_csv(compute_schedules, "lowered_compute.csv")

    lowered_comm_schedule = _add_send_recv(
        lowered_comm_schedule,
        stage_to_rank=lambda chunk_index: chunk_index % pipeline_parallel_size,
        num_stages=num_model_chunks * pipeline_parallel_size,
    )

    comms_str = _format_pipeline_order(lowered_comm_schedule)
    if dump_scheduler_ir:
        with open("lowered_comms.log", "w") as logf:
            logf.write(comms_str)
        _dump_csv(lowered_comm_schedule, "lowered_compute_with_send_recv.csv")
    logger.debug("---------- lowered IR\n%s----------", comms_str)

    if not enable_weight_sharding_in_pp and not enable_wgrad_sharding_in_pp:
        return lowered_comm_schedule

    generation_time = timer() - start_time
    logger.info(f"schedule generation took {generation_time:.6f} seconds")

    return lowered_comm_schedule


# TODO - replace bfs / dfs functions below with new IR generators
ir_schedules = {
    # "dora": get_dora_schedule,
    "dora-dfs": lambda *args, **kwargs: get_dora_schedule(*args, **kwargs, dfs=True),
    # "zbv": get_zbv_schedule,
    # "zbw": get_zbw_schedule,
}

is_zero_bubble = {
    # "dora": True,
    "dora-dfs": True,
    # "zbv": True,
    # "zbw": True,
}


@cache
def generate_schedule(name: str, *args, **kwargs):
    assert name in ir_schedules, f"{name} is not a supported schedule type"
    schedules = ir_schedules[name](*args, **kwargs)
    stage_to_rank = {}
    for rank, schedule_actions_rank in schedules.items():
        for action in schedule_actions_rank:
            comp_type = action.computation_type
            stage_idx = action.stage_index
            if comp_type == FORWARD:
                stage_to_rank[stage_idx] = rank
            if comp_type in (BACKWARD, FULL_BACKWARD):
                stage_to_rank[stage_idx] = rank
    return schedules, stage_to_rank
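
A minimal sketch of how `generate_schedule` might be driven, for orientation only; it is not part of the packaged file, the argument values are made up, and the import path assumes this file lands at monarch/parallel/pipelining/scheduler.py as the file list above suggests.

# Illustrative invocation of the "dora-dfs" generator; values are hypothetical.
from monarch.parallel.pipelining.scheduler import generate_schedule

schedules, stage_to_rank = generate_schedule(
    "dora-dfs",
    num_model_chunks=2,          # virtual pipeline chunks per rank
    pipeline_parallel_size=4,    # number of pipeline-parallel ranks
    num_round=2,
    num_microbatch_per_round=4,
    zero_bubble=True,
    total_num_microbatches=8,
    num_microbatches=8,
)
# schedules maps each rank to its ordered list of _Action entries;
# stage_to_rank maps every stage index to the rank that runs it.
for rank, actions in schedules.items():
    print(rank, [str(a) for a in actions])
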
monarch/proc_mesh.py ADDED
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

from typing import Any, cast, Optional, Type, TypeVar

import monarch
from monarch import ActorFuture as Future

from monarch._rust_bindings.hyperactor_extension.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
    Alloc,
    AllocConstraints,
    AllocSpec,
)
from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef

from monarch.common._device_utils import _local_device_count
from monarch.rdma import RDMAManager

T = TypeVar("T")
try:
    from __manifest__ import fbmake  # noqa

    IN_PAR = True
except ImportError:
    IN_PAR = False


async def _allocate_nonblocking(alloc: Alloc) -> "ProcMesh":
    return ProcMesh(await HyProcMesh.allocate_nonblocking(alloc))


def _allocate_blocking(alloc: Alloc) -> "ProcMesh":
    return ProcMesh(HyProcMesh.allocate_blocking(alloc))


class ProcMesh:
    def __init__(self, hy_proc_mesh: HyProcMesh) -> None:
        self._proc_mesh = hy_proc_mesh
        self._mailbox: Mailbox = self._proc_mesh.client
        self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager)

    def spawn(self, name: str, Class: Type[T], *args: Any, **kwargs: Any) -> Future[T]:
        return Future(
            lambda: self._spawn_nonblocking(name, Class, *args, **kwargs),
            lambda: self._spawn_blocking(name, Class, *args, **kwargs),
        )

    @classmethod
    def from_alloc(self, alloc: Alloc) -> Future["ProcMesh"]:
        return Future(
            lambda: _allocate_nonblocking(alloc),
            lambda: _allocate_blocking(alloc),
        )

    def _spawn_blocking(
        self, name: str, Class: Type[T], *args: Any, **kwargs: Any
    ) -> T:
        if not issubclass(Class, Actor):
            raise ValueError(
                f"{Class} must subclass monarch.service.Actor to spawn it."
            )

        actor_mesh = self._proc_mesh.spawn_blocking(name, _Actor)
        service = ActorMeshRef(
            Class,
            _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh),
            self._mailbox,
        )
        # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by
        # doing `ActorMeshRef(Class, actor_handle)` but not calling _create.
        service._create(args, kwargs)
        return cast(T, service)

    def __repr__(self) -> str:
        return repr(self._proc_mesh)

    def __str__(self) -> str:
        return str(self._proc_mesh)

    async def _spawn_nonblocking(
        self, name: str, Class: Type[T], *args: Any, **kwargs: Any
    ) -> T:
        if not issubclass(Class, Actor):
            raise ValueError(
                f"{Class} must subclass monarch.service.Actor to spawn it."
            )

        actor_mesh = await self._proc_mesh.spawn_nonblocking(name, _Actor)
        service = ActorMeshRef(
            Class,
            _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh),
            self._mailbox,
        )
        # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by
        # doing `ActorMeshRef(Class, actor_handle)` but not calling _create.
        service._create(args, kwargs)
        return cast(T, service)


async def local_proc_mesh_nonblocking(
    *, gpus: Optional[int] = None, hosts: int = 1
) -> ProcMesh:
    if gpus is None:
        gpus = _local_device_count()
    spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
    allocator = monarch.LocalAllocator()
    alloc = await allocator.allocate(spec)
    return await ProcMesh.from_alloc(alloc)


def local_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh:
    if gpus is None:
        gpus = _local_device_count()
    spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
    allocator = monarch.LocalAllocator()
    alloc = allocator.allocate(spec).get()
    return ProcMesh.from_alloc(alloc).get()


def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]:
    return Future(
        lambda: local_proc_mesh_nonblocking(gpus=gpus, hosts=hosts),
        lambda: local_proc_mesh_blocking(gpus=gpus, hosts=hosts),
    )


_BOOTSTRAP_MAIN = "monarch.bootstrap_main"


def _get_bootstrap_args() -> tuple[str, Optional[list[str]], dict[str, str]]:
    if IN_PAR:
        cmd = sys.argv[0]
        args = None
        env = {
            "PAR_MAIN_OVERRIDE": _BOOTSTRAP_MAIN,
        }
    else:
        cmd = sys.executable
        args = ["-m", _BOOTSTRAP_MAIN]
        env = {}

    return cmd, args, env


async def proc_mesh_nonblocking(
    *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None
) -> ProcMesh:
    if gpus is None:
        gpus = _local_device_count()
    spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
    env = env or {}
    cmd, args, base_env = _get_bootstrap_args()
    env.update(base_env)
    env["HYPERACTOR_MANAGED_SUBPROCESS"] = "1"
    allocator = monarch.ProcessAllocator(cmd, args, env)
    alloc = await allocator.allocate(spec)
    return await ProcMesh.from_alloc(alloc)


def proc_mesh_blocking(
    *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None
) -> ProcMesh:
    if gpus is None:
        gpus = _local_device_count()
    spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
    env = env or {}
    cmd, args, base_env = _get_bootstrap_args()
    env.update(base_env)
    env["HYPERACTOR_MANAGED_SUBPROCESS"] = "1"
    allocator = monarch.ProcessAllocator(cmd, args, env)
    alloc = allocator.allocate(spec).get()
    return ProcMesh.from_alloc(alloc).get()


def proc_mesh(
    *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None
) -> Future[ProcMesh]:
    return Future(
        lambda: proc_mesh_nonblocking(gpus=gpus, hosts=hosts, env=env),
        lambda: proc_mesh_blocking(gpus=gpus, hosts=hosts, env=env),
    )
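
A rough usage sketch of the public entry points above, not part of the packaged file; `MyActor` is a hypothetical Actor subclass, and the blocking `.get()` calls mirror the `allocate(...).get()` pattern used inside this module.

# Illustrative only; MyActor is not defined by this package.
from monarch.actor_mesh import Actor
from monarch.proc_mesh import proc_mesh


class MyActor(Actor):
    def __init__(self, label: str) -> None:
        self.label = label


mesh = proc_mesh(gpus=2).get()                            # block until the ProcMesh is allocated
actors = mesh.spawn("my_actors", MyActor, "demo").get()   # one MyActor per process in the mesh
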
monarch/profiler.py ADDED
@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import itertools
import os
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Dict, NamedTuple, Optional, Tuple

import torch
from monarch.common.remote import remote
from monarch.remote_class import ControllerRemoteClass, WorkerRemoteClass


class Schedule(NamedTuple):
    wait: int
    warmup: int
    active: int
    repeat: int = 0
    skip_first: int = 0


class profile:
    """
    The class wraps `torch.profiler.profile()` to allow invoking the profiler remotely.
    There are two main differences:
    1) `on_trace_ready` can only be a string, indicating the folder where the traces
       will be saved.
    2) `schedule` must be of type `monarch.profiler.Schedule`.
    """

    PATH_KEY = "on_trace_ready"
    _counter = itertools.count()

    def __init__(self, *args, **kwargs) -> None:
        assert isinstance(kwargs.get(self.PATH_KEY, None), str), (
            f"{self.PATH_KEY} must be passed and must be a string to represent the "
            "path to save the profiler."
        )
        schedule = kwargs.get("schedule", None)
        assert (
            isinstance(schedule, Schedule) or schedule is None
        ), "schedule can only be monarch.profiler.Schedule or None."
        self.id = next(self._counter)
        _profiler_controller_init(self.id, *args, **kwargs)

    def __enter__(self) -> "profile":
        _profiler_controller_enter(self.id)
        return self

    def __exit__(self, *args, **kwargs) -> None:
        _profiler_controller_exit(self.id)

    def step(self) -> None:
        _profiler_controller_step(self.id)


@dataclass
class _Profiler:
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]
    profiler: Optional[torch.profiler.profile] = None


_profilers: Dict[int, _Profiler] = {}


def _profiler_init(ident, *args, **kwargs) -> None:
    global _profilers
    assert (
        ident not in _profilers
    ), f"Initializing an already existing profiler, {ident=}"
    _profilers[ident] = _Profiler(args, kwargs)
    # It's unclear why we cannot create the profiler here. Even though
    # the thread is the same, profiler complains thread id mismatch.


def _profiler_enter(ident, *args, **kwargs) -> None:
    def on_trace_ready(prof, dir_path):
        dir_path = Path(dir_path).absolute()
        os.makedirs(dir_path, exist_ok=True)
        # This is not a synchronized call, so it is okay to call without
        # device mesh.
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        prof.export_chrome_trace(f"{dir_path}/trace_{rank}.json")

    profiler = _profilers[ident]
    profiler.kwargs[profile.PATH_KEY] = partial(
        on_trace_ready, dir_path=profiler.kwargs[profile.PATH_KEY]
    )
    schedule = profiler.kwargs.get("schedule", None)
    if schedule is not None:
        profiler.kwargs["schedule"] = torch.profiler.schedule(**schedule._asdict())
    profiler.profiler = torch.profiler.profile(*profiler.args, **profiler.kwargs)

    profiler.profiler.__enter__()


def _profiler_exit(ident, *args, **kwargs) -> None:
    profiler = _profilers[ident].profiler
    assert profiler is not None
    profiler.__exit__(None, None, None)
    _profilers.pop(ident)


def _profiler_step(ident, *args, **kwargs) -> None:
    profiler = _profilers[ident].profiler
    assert profiler is not None
    profiler.step()


_profiler_controller_init = remote(
    "monarch.profiler._profiler_init", propagate="inspect"
)

_profiler_controller_enter = remote(
    "monarch.profiler._profiler_enter", propagate="inspect"
)

_profiler_controller_exit = remote(
    "monarch.profiler._profiler_exit", propagate="inspect"
)

_profiler_controller_step = remote(
    "monarch.profiler._profiler_step", propagate="inspect"
)


class record_function(ControllerRemoteClass):
    """
    The class wraps `torch.profiler.record_function()` to allow invoking the
    record_function remotely.
    """

    def __init__(self, name: str, args: Optional[str] = None) -> None:
        super().__init__("monarch.profiler.WorkerRecordFunction", name, args)

    @ControllerRemoteClass.remote_method
    def __enter__(self) -> "record_function":
        return self

    @ControllerRemoteClass.remote_method
    def __exit__(self, *args, **kwargs) -> None:
        return


class WorkerRecordFunction(WorkerRemoteClass):
    def __init__(self, *args, **kwargs) -> None:
        self._record_function = torch.profiler.record_function(*args, **kwargs)

    def __enter__(self) -> None:
        self._record_function.__enter__()

    def __exit__(self, *args, **kwargs) -> None:
        self._record_function.__exit__(*args, **kwargs)
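
A short usage sketch based on the docstrings above, not part of the packaged file; it assumes an active monarch device mesh (these wrappers dispatch to workers via remote calls), `train_step` stands in for the caller's own work, and the trace directory is arbitrary.

# Illustrative only; train_step() and the trace directory are hypothetical.
from monarch.profiler import profile, record_function, Schedule

with profile(
    on_trace_ready="./profiler_traces",             # folder where per-rank traces are written
    schedule=Schedule(wait=1, warmup=1, active=3),  # converted to torch.profiler.schedule on the worker
) as prof:
    for _ in range(5):
        with record_function("train_step"):
            train_step()
        prof.step()
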
monarch/python_local_mesh.py ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import os
import subprocess
from time import sleep
from typing import Optional, TYPE_CHECKING

import monarch_supervisor
from monarch.common._device_utils import _local_device_count
from monarch.common.fake import fake_call
from monarch.common.invocation import DeviceException, RemoteException
from monarch.world_mesh import world_mesh
from monarch_supervisor import Context, HostConnected
from monarch_supervisor.python_executable import PYTHON_EXECUTABLE

if TYPE_CHECKING:
    from monarch.common.device_mesh import DeviceMesh


class PythonLocalContext:
    def __init__(self, N: int):
        # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later
        fake_call(lambda: 0)

        self.ctx = ctx = Context()
        ctx.request_hosts(N)

        # we want ctx to start its listener threads
        # before creating the hosts because
        # initialization will happen faster in this case
        sleep(0)
        supervisor_addr = f"tcp://127.0.0.1:{ctx.port}"

        env = {
            **os.environ,
            "TORCH_SUPERVISOR_HEARTBEAT_INTERVAL": str(
                monarch_supervisor.HEARTBEAT_INTERVAL
            ),
            # This is needed to avoid a hard failure in ncclx when we do not
            # have backend topology info (eg. on RE).
            "NCCL_IGNORE_TOPO_LOAD_FAILURE": "true",
        }

        # start_new_session=True, because we want the host managers to be able to kill
        # any worker processes before they exit, even if the supervisor crashes, or we ctrl-c
        # it in testing.
        self.host_managers = [
            subprocess.Popen(
                [
                    PYTHON_EXECUTABLE,
                    "-m",
                    "monarch_supervisor.host",
                    supervisor_addr,
                ],
                env=env,
                start_new_session=True,
            )
            for _ in range(N)
        ]
        connections = ctx.messagefilter(HostConnected)
        self.hosts = [connections.recv(timeout=30).sender for _ in range(N)]

    def shutdown(self):
        self.ctx.shutdown()
        for host_manager in self.host_managers:
            host_manager.wait(timeout=10)


def python_local_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> "DeviceMesh":
    """
    Creates a local device mesh with the given number of hosts and gpus per host.
    Easy way to use PythonLocalContext.

    Args:
        gpus (Optional[int]): number of gpus per host.
            Default: the number of GPUs this machine has.

        hosts (int): number of hosts, primarily used for simulating multiple machines locally.
            Default: 1

    Example::
        local_mesh = python_local_mesh(gpus=2)
        with local_mesh.activate():
            x = torch.rand(3, 4)
            local_tensor = fetch_shard(x).result()

        # Cleanly shut down the local mesh and exit.
        local_mesh.exit()
    """
    ctx = PythonLocalContext(hosts)
    if gpus is None:
        gpus = _local_device_count()
    dm = world_mesh(ctx.ctx, ctx.hosts, gpus)

    def exit(
        error: Optional[RemoteException | DeviceException | Exception] = None,
    ) -> None:
        dm.client.shutdown(True, error)
        ctx.shutdown()

    dm.exit = exit
    return dm