torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/parallel/pipelining/scheduler.py ADDED
@@ -0,0 +1,249 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ from functools import cache
+ from logging import getLogger
+ from timeit import default_timer as timer
+
+ from .schedule_ir import (
+     _Action,
+     _add_send_recv,
+     _ComputationType,
+     _dump_csv,
+     _format_pipeline_order,
+     _merge_bw,
+     BACKWARD,
+     FORWARD,
+     FULL_BACKWARD,
+ )
+
+ logger = getLogger()
+
+
+ def get_stage_str(model_chunk_index, training_stage, mb_index):
+     ctype = _ComputationType.from_str(training_stage)
+     return str(_Action(model_chunk_index, ctype, mb_index))
+
+
+ def get_dora_schedule(
+     num_model_chunks,
+     pipeline_parallel_size,
+     num_round,
+     num_microbatch_per_round,
+     zero_bubble,
+     total_num_microbatches,
+     num_microbatches,
+     dfs=False,
+     prefetch_weight_latency=1.0,
+     enable_weight_sharding_in_pp=False,
+     enable_wgrad_sharding_in_pp=False,
+ ):
+     start_time = timer()
+     num_warmup_microbatches_list = []
+     num_1f1b_microbatches_list = []
+     num_additional_1b1w_list = []
+     for pipeline_parallel_rank in range(pipeline_parallel_size):
+         num_warmup_microbatches = 0
+         # The number of microbatches that last pipeline stage run before 1f1b.
+         num_warmup_microbatches += (num_model_chunks - 1) * num_microbatch_per_round
+         # From last PP stage up, each rank will be 2 more than the previous one.
+         num_warmup_microbatches += (
+             pipeline_parallel_size - pipeline_parallel_rank - 1
+         ) * 2
+         num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
+         # The number of 1f1b for zero bubble schedule
+         if num_microbatches == pipeline_parallel_size:
+             num_1f1b_microbatches = pipeline_parallel_rank
+         else:
+             num_1f1b_microbatches = 2 * pipeline_parallel_rank
+         num_additional_1b1w = max(
+             int(math.ceil((pipeline_parallel_size - 4) / 2)) - pipeline_parallel_rank,
+             0,
+         )
+         if dfs:
+             num_1f1b_microbatches = 0
+             num_additional_1b1w = 0
+
+         num_warmup_microbatches_list.append(num_warmup_microbatches)
+         num_1f1b_microbatches_list.append(num_1f1b_microbatches)
+         num_additional_1b1w_list.append(num_additional_1b1w)
+     schedules = []
+
+     def get_last_pp_rank(i):
+         return (i - 1) % pipeline_parallel_size, i - 1 < 0
+
+     def get_next_pp_rank(i):
+         return (i + 1) % pipeline_parallel_size, i + 1 >= pipeline_parallel_size
+
+     for pipeline_parallel_rank in range(pipeline_parallel_size):
+         s = []
+         fwd_mb_index_list = [0 for i in range(num_model_chunks)]
+         bwd_mb_index_list = [0 for i in range(num_model_chunks)]
+         fwd_model_chunk_index = 0
+         bwd_model_chunk_index = num_model_chunks - 1
+         weight_store = []
+         num_warmup_microbatches = num_warmup_microbatches_list[pipeline_parallel_rank]
+         num_1f1b_microbatches = num_1f1b_microbatches_list[pipeline_parallel_rank]
+         num_additional_1b1w = num_additional_1b1w_list[pipeline_parallel_rank]
+         fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
+         bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
+         fill_1b1w = False
+         for _ in range(num_warmup_microbatches):  # warm up fwd
+             fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
+             bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
+             tmp = get_stage_str(fwd_model_chunk_index, "F", fwd_mb_index)
+             s.append(tmp)
+             fwd_mb_index_list[fwd_model_chunk_index] += 1
+             if fwd_mb_index_list[fwd_model_chunk_index] % num_microbatch_per_round == 0:
+                 if fwd_model_chunk_index < num_model_chunks - 1:
+                     fwd_model_chunk_index += 1
+                 else:
+                     fwd_model_chunk_index = 0
+         for i in range(
+             total_num_microbatches - num_warmup_microbatches
+         ):  # 1f1b and 1f1b1w
+             if (
+                 fwd_model_chunk_index == 1 and not fill_1b1w
+             ):  # additional 1b1w to fill before fwd
+                 fill_1b1w = True
+                 for _ in range(num_additional_1b1w):
+                     bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
+                     tmp = get_stage_str(bwd_model_chunk_index, "B", bwd_mb_index)
+                     s.append(tmp)
+                     tmp = get_stage_str(bwd_model_chunk_index, "W", bwd_mb_index)
+                     s.append(tmp)
+                     bwd_mb_index_list[bwd_model_chunk_index] += 1
+                     if (
+                         bwd_mb_index_list[bwd_model_chunk_index]
+                         % num_microbatch_per_round
+                         == 0
+                     ):
+                         if bwd_model_chunk_index > 0:
+                             bwd_model_chunk_index -= 1
+                         else:
+                             bwd_model_chunk_index = num_model_chunks - 1
+             fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
+             bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
+             tmp = get_stage_str(fwd_model_chunk_index, "F", fwd_mb_index)
+             s.append(tmp)
+             fwd_mb_index_list[fwd_model_chunk_index] += 1
+             if fwd_mb_index_list[fwd_model_chunk_index] % num_microbatch_per_round == 0:
+                 if fwd_model_chunk_index < num_model_chunks - 1:
+                     fwd_model_chunk_index += 1
+                 else:
+                     fwd_model_chunk_index = 0
+             tmp = get_stage_str(
+                 bwd_model_chunk_index, "B" if zero_bubble else "BW", bwd_mb_index
+             )
+             s.append(tmp)
+             tmp = get_stage_str(bwd_model_chunk_index, "W", bwd_mb_index)
+             if zero_bubble and i < num_1f1b_microbatches:
+                 weight_store.append(tmp)
+             else:
+                 s.append(tmp)
+             bwd_mb_index_list[bwd_model_chunk_index] += 1
+             if bwd_mb_index_list[bwd_model_chunk_index] % num_microbatch_per_round == 0:
+                 if bwd_model_chunk_index > 0:
+                     bwd_model_chunk_index -= 1
+                 else:
+                     bwd_model_chunk_index = num_model_chunks - 1
+         num_cooldown = (
+             num_warmup_microbatches - num_additional_1b1w
+             if fill_1b1w
+             else num_warmup_microbatches
+         )
+         for _ in range(num_cooldown):  # cooldown bwd
+             fwd_mb_index = fwd_mb_index_list[fwd_model_chunk_index]
+             bwd_mb_index = bwd_mb_index_list[bwd_model_chunk_index]
+             tmp = get_stage_str(bwd_model_chunk_index, "B", bwd_mb_index)
+             s.append(tmp)
+             tmp = get_stage_str(bwd_model_chunk_index, "W", bwd_mb_index)
+             s.append(tmp)
+             bwd_mb_index_list[bwd_model_chunk_index] += 1
+             if bwd_mb_index_list[bwd_model_chunk_index] % num_microbatch_per_round == 0:
+                 if bwd_model_chunk_index > 0:
+                     bwd_model_chunk_index -= 1
+                 else:
+                     bwd_model_chunk_index = num_model_chunks - 1
+         if len(weight_store) > 0:
+             s += weight_store
+         schedules.append(s)
+
+     compute_schedules = {}
+     for rank in range(pipeline_parallel_size):
+         compute_schedules[rank] = []
+         for action_str in schedules[rank]:
+             action = _Action.from_str(action_str)
+             stage_index = action.stage_index * pipeline_parallel_size + rank
+             action = _Action(
+                 stage_index, action.computation_type, action.microbatch_index
+             )
+             compute_schedules[rank].append(action)
+
+     lowered_comm_schedule = compute_schedules
+     for rank in lowered_comm_schedule:
+         lowered_comm_schedule[rank] = _merge_bw(lowered_comm_schedule[rank])
+
+     dump_scheduler_ir = True
+     if dump_scheduler_ir:
+         compute_str = _format_pipeline_order(lowered_comm_schedule)
+         with open("lowered_compute.log", "w") as logf:
+             logf.write(compute_str)
+         _dump_csv(compute_schedules, "lowered_compute.csv")
+
+     lowered_comm_schedule = _add_send_recv(
+         lowered_comm_schedule,
+         stage_to_rank=lambda chunk_index: chunk_index % pipeline_parallel_size,
+         num_stages=num_model_chunks * pipeline_parallel_size,
+     )
+
+     comms_str = _format_pipeline_order(lowered_comm_schedule)
+     if dump_scheduler_ir:
+         with open("lowered_comms.log", "w") as logf:
+             logf.write(comms_str)
+         _dump_csv(lowered_comm_schedule, "lowered_compute_with_send_recv.csv")
+     logger.debug("---------- lowered IR\n%s----------", comms_str)
+
+     if not enable_weight_sharding_in_pp and not enable_wgrad_sharding_in_pp:
+         return lowered_comm_schedule
+
+     generation_time = timer() - start_time
+     logger.info(f"schedule generation took {generation_time:.6f} seconds")
+
+     return lowered_comm_schedule
+
+
+ # TODO - replace bfs / dfs functions below with new IR generators
+ ir_schedules = {
+     # "dora": get_dora_schedule,
+     "dora-dfs": lambda *args, **kwargs: get_dora_schedule(*args, **kwargs, dfs=True),
+     # "zbv": get_zbv_schedule,
+     # "zbw": get_zbw_schedule,
+ }
+
+ is_zero_bubble = {
+     # "dora": True,
+     "dora-dfs": True,
+     # "zbv": True,
+     # "zbw": True,
+ }
+
+
+ @cache
+ def generate_schedule(name: str, *args, **kwargs):
+     assert name in ir_schedules, f"{name} is not a supported schedule type"
+     schedules = ir_schedules[name](*args, **kwargs)
+     stage_to_rank = {}
+     for rank, schedule_actions_rank in schedules.items():
+         for action in schedule_actions_rank:
+             comp_type = action.computation_type
+             stage_idx = action.stage_index
+             if comp_type == FORWARD:
+                 stage_to_rank[stage_idx] = rank
+             if comp_type in (BACKWARD, FULL_BACKWARD):
+                 stage_to_rank[stage_idx] = rank
+     return schedules, stage_to_rank
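
Note: `generate_schedule` above is the cached entry point. It looks up an IR generator by name ("dora-dfs" is the only one enabled in this file), runs it, and derives a stage-to-rank map from the FORWARD/BACKWARD actions. The call below is a minimal invocation sketch and not part of the wheel; the positional values are illustrative assumptions that must be mutually consistent with your pipeline setup, and `get_dora_schedule` will also write `lowered_compute.log`/`.csv` and `lowered_comms.log` because `dump_scheduler_ir` is hard-coded to True.

from monarch.parallel.pipelining.scheduler import generate_schedule

# Positional args mirror get_dora_schedule(num_model_chunks, pipeline_parallel_size,
# num_round, num_microbatch_per_round, zero_bubble, total_num_microbatches,
# num_microbatches); the concrete numbers here are illustrative only.
schedules, stage_to_rank = generate_schedule(
    "dora-dfs",
    2,     # num_model_chunks
    4,     # pipeline_parallel_size
    1,     # num_round
    8,     # num_microbatch_per_round
    True,  # zero_bubble
    16,    # total_num_microbatches
    8,     # num_microbatches
)
for rank, actions in schedules.items():
    print(rank, [str(a) for a in actions])  # per-rank compute + send/recv actions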
monarch/pdb_wrapper.py ADDED
@@ -0,0 +1,135 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import bdb
+ import inspect
+ import io
+ import pdb  # noqa
+ import socket
+ import sys
+ from dataclasses import dataclass
+
+ from typing import Dict, TYPE_CHECKING
+
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
+
+ if TYPE_CHECKING:
+     from monarch.debugger import DebugClient
+
+
+ @dataclass
+ class DebuggerWrite:
+     payload: bytes
+     function: str | None
+     lineno: int | None
+
+
+ class PdbWrapper(pdb.Pdb):
+     def __init__(
+         self,
+         rank: int,
+         coords: Dict[str, int],
+         actor_id: ActorId,
+         client_ref: "DebugClient",
+         header: str | None = None,
+     ):
+         self.rank = rank
+         self.coords = coords
+         self.header = header
+         self.actor_id = actor_id
+         self.client_ref = client_ref
+         # pyre-ignore
+         super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self))
+         self._first = True
+
+     def setup(self, *args, **kwargs):
+         r = super().setup(*args, **kwargs)
+         if self._first:
+             self._first = False
+             # when we enter the debugger, we want to present the user's stack frame
+             # not the nested one inside session.run. This means that the local
+             # variables are what gets printed, etc. To do this
+             # we first execute up 2 to get to that frame.
+             self.do_up(2)
+         return r
+
+     def set_continue(self) -> None:
+         r = super().set_continue()
+         if not self.breaks:
+             # no more breakpoints so this debugger will not
+             # be used again, and we detach from the controller io.
+             self.client_ref.debugger_session_end.call_one(self.rank).get()
+             # break cycle with itself before we exit
+             self.stdin = sys.stdin
+             self.stdout = sys.stdout
+         return r
+
+     def set_trace(self):
+         self.client_ref.debugger_session_start.call_one(
+             self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id
+         ).get()
+         if self.header:
+             self.message(self.header)
+         super().set_trace()
+
+
+ class ReadWrapper(io.RawIOBase):
+     def __init__(self, session: "PdbWrapper"):
+         self.session = session
+
+     def readinto(self, b):
+         response = self.session.client_ref.debugger_read.call_one(
+             self.session.rank, len(b)
+         ).get()
+         if response == "detach":
+             # this gets injected by the worker event loop to
+             # get the worker thread to exit on an Exit command.
+             raise bdb.BdbQuit
+         assert isinstance(response, DebuggerWrite) and len(response.payload) <= len(b)
+         b[: len(response.payload)] = response.payload
+         return len(response.payload)
+
+     def readable(self) -> bool:
+         return True
+
+     @classmethod
+     def create(cls, session: "PdbWrapper"):
+         return io.TextIOWrapper(io.BufferedReader(cls(session)))
+
+
+ class WriteWrapper:
+     def __init__(self, session: "PdbWrapper"):
+         self.session = session
+
+     def writable(self) -> bool:
+         return True
+
+     def write(self, s: str):
+         function = None
+         lineno = None
+         if self.session.curframe is not None:
+             # pyre-ignore
+             function = f"{inspect.getmodulename(self.session.curframe.f_code.co_filename)}.{self.session.curframe.f_code.co_name}"
+             # pyre-ignore
+             lineno = self.session.curframe.f_lineno
+         self.session.client_ref.debugger_write.call_one(
+             self.session.rank,
+             DebuggerWrite(
+                 s.encode(),
+                 function,
+                 lineno,
+             ),
+         ).get()
+
+     def flush(self):
+         pass
+
+
+ def remote_breakpointhook(
+     rank: int, coords: Dict[str, int], actor_id: ActorId, client_ref: "DebugClient"
+ ):
+     ds = PdbWrapper(rank, coords, actor_id, client_ref)
+     ds.set_trace()
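
Note: `PdbWrapper` proxies pdb's stdin/stdout through endpoint calls on a `DebugClient` actor, so an interactive session can be driven from the controller. The snippet below is a minimal sketch (not code from this module) of how a worker might route Python's built-in `breakpoint()` through `remote_breakpointhook`; the `rank`, `coords`, `actor_id`, and `debug_client` values are assumed to come from the worker's actor context.

import sys

from monarch.pdb_wrapper import remote_breakpointhook

def install_remote_breakpoint(rank, coords, actor_id, debug_client) -> None:
    # breakpoint() consults sys.breakpointhook; route it to the DebugClient
    # so pdb commands are read from and written to the controller session.
    sys.breakpointhook = lambda *args, **kwargs: remote_breakpointhook(
        rank, coords, actor_id, debug_client
    )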
monarch/proc_mesh.py ADDED
@@ -0,0 +1,299 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import os
+ import sys
+ from contextlib import AbstractContextManager
+
+ from typing import (
+     Any,
+     cast,
+     Dict,
+     List,
+     Optional,
+     Sequence,
+     Type,
+     TYPE_CHECKING,
+     TypeVar,
+ )
+
+ if TYPE_CHECKING:
+     import torch
+
+ import monarch
+ from monarch import ActorFuture as Future
+
+ # Conditionally import DeviceMesh and spawn_tensor_engine only if tensor_engine is available
+ # pyre-ignore[21]
+ from monarch._rust_bindings import has_tensor_engine
+
+ from monarch._rust_bindings.hyperactor_extension.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
+     Alloc,
+     AllocConstraints,
+     AllocSpec,
+ )
+ from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
+ from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
+ from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
+ from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef
+
+ from monarch.code_sync import RemoteWorkspace, RsyncMeshClient
+ from monarch.common._device_utils import _local_device_count
+ from monarch.common.shape import MeshTrait
+ from monarch.rdma import RDMAManager
+
+ if has_tensor_engine():
+     from monarch.common.device_mesh import DeviceMesh
+     from monarch.mesh_controller import spawn_tensor_engine
+ else:
+     DeviceMesh = None
+     spawn_tensor_engine = None
+
+ T = TypeVar("T")
+ try:
+     from __manifest__ import fbmake  # noqa
+
+     IN_PAR = True
+ except ImportError:
+     IN_PAR = False
+
+
+ async def _allocate_nonblocking(alloc: Alloc) -> "ProcMesh":
+     return ProcMesh(await HyProcMesh.allocate_nonblocking(alloc))
+
+
+ def _allocate_blocking(alloc: Alloc) -> "ProcMesh":
+     return ProcMesh(HyProcMesh.allocate_blocking(alloc))
+
+
+ class ProcMesh(MeshTrait):
+     def __init__(
+         self,
+         hy_proc_mesh: HyProcMesh,
+         _mock_shape: Optional[Shape] = None,
+         _device_mesh: Optional[DeviceMesh] = None,
+     ) -> None:
+         self._proc_mesh = hy_proc_mesh
+         self._mock_shape: Optional[Shape] = _mock_shape
+         self._mailbox: Mailbox = self._proc_mesh.client
+         self._rdma_manager: Optional[RDMAManager] = None
+         self._rsync_mesh_client: Optional[RsyncMeshClient] = None
+         self._maybe_device_mesh: Optional[DeviceMesh] = _device_mesh
+         if _mock_shape is None:
+             self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager)
+
+     @property
+     def _shape(self) -> Shape:
+         return self._proc_mesh.shape if self._mock_shape is None else self._mock_shape
+
+     @property
+     def _ndslice(self) -> Slice:
+         return self._shape.ndslice
+
+     @property
+     def _labels(self) -> List[str]:
+         return self._shape.labels
+
+     def _new_with_shape(self, shape: Shape) -> "ProcMesh":
+         device_mesh = (
+             None
+             if self._device_mesh is None
+             else self._device_mesh._new_with_shape(shape)
+         )
+         return ProcMesh(self._proc_mesh, _mock_shape=shape, _device_mesh=device_mesh)
+
+     def spawn(
+         self, name: str, Class: Type[T], *args: Any, **kwargs: Any
+     ) -> Future[ActorMeshRef[T]]:
+         if self._mock_shape is not None:
+             raise NotImplementedError("NYI: spawn on slice of a proc mesh.")
+         return Future(
+             lambda: self._spawn_nonblocking(name, Class, *args, **kwargs),
+             lambda: self._spawn_blocking(name, Class, *args, **kwargs),
+         )
+
+     @classmethod
+     def from_alloc(self, alloc: Alloc) -> Future["ProcMesh"]:
+         return Future(
+             lambda: _allocate_nonblocking(alloc),
+             lambda: _allocate_blocking(alloc),
+         )
+
+     def _spawn_blocking(
+         self, name: str, Class: Type[T], *args: Any, **kwargs: Any
+     ) -> T:
+         if not issubclass(Class, Actor):
+             raise ValueError(
+                 f"{Class} must subclass monarch.service.Actor to spawn it."
+             )
+
+         actor_mesh = self._proc_mesh.spawn_blocking(name, _Actor)
+         service = ActorMeshRef(
+             Class,
+             _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh),
+             self._mailbox,
+         )
+         # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by
+         # doing `ActorMeshRef(Class, actor_handle)` but not calling _create.
+         service._create(args, kwargs)
+         return cast(T, service)
+
+     def __repr__(self) -> str:
+         return repr(self._proc_mesh)
+
+     def __str__(self) -> str:
+         return str(self._proc_mesh)
+
+     async def _spawn_nonblocking(
+         self, name: str, Class: Type[T], *args: Any, **kwargs: Any
+     ) -> T:
+         if not issubclass(Class, Actor):
+             raise ValueError(
+                 f"{Class} must subclass monarch.service.Actor to spawn it."
+             )
+
+         actor_mesh = await self._proc_mesh.spawn_nonblocking(name, _Actor)
+         service = ActorMeshRef(
+             Class,
+             _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh),
+             self._mailbox,
+         )
+         # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by
+         # doing `ActorMeshRef(Class, actor_handle)` but not calling _create.
+         service._create(args, kwargs)
+         return cast(T, service)
+
+     @property
+     def _device_mesh(self) -> "DeviceMesh":
+         if spawn_tensor_engine is None:
+             raise RuntimeError(
+                 "DeviceMesh is not available because tensor_engine was not compiled (USE_TENSOR_ENGINE=0)"
+             )
+         if self._maybe_device_mesh is None:
+             if self._mock_shape is not None:
+                 raise NotImplementedError(
+                     "NYI: activating a proc mesh must first happen on the root proc_mesh until we fix spawning on submeshes."
+                 )
+             self._maybe_device_mesh = spawn_tensor_engine(self)
+         return self._maybe_device_mesh
+
+     # pyre-ignore
+     def activate(self) -> AbstractContextManager:
+         return self._device_mesh.activate()
+
+     def rank_tensor(self, dim: str | Sequence[str]) -> "torch.Tensor":
+         return self._device_mesh.rank(dim)
+
+     def rank_tensors(self) -> Dict[str, "torch.Tensor"]:
+         return self._device_mesh.ranks
+
+     async def sync_workspace(self) -> None:
+         if self._rsync_mesh_client is None:
+             # TODO(agallagher): We need some way to configure and pass this
+             # in -- right now we're assuming the `gpu` dimension, which isn't
+             # correct.
+             assert set(self._proc_mesh.shape.labels).issubset({"gpus", "hosts"})
+             # The workspace shape (i.e. only perform one rsync per host).
+             workspace_shape = self.slice(gpus=slice(0, 1, 1))._mock_shape
+             assert workspace_shape is not None
+             # TODO(agallagher): We should probably hide this behind something
+             # like a `Workspace` class and support abstracting/configuring
+             # different sync methods.
+             self._rsync_mesh_client = RsyncMeshClient.spawn_blocking(
+                 proc_mesh=self._proc_mesh,
+                 shape=workspace_shape,
+                 # TODO(agallagher): Is there a better way to infer/set the local
+                 # workspace dir, rather than use PWD?
+                 local_workspace=os.getcwd(),
+                 remote_workspace=RemoteWorkspace.FromEnvVar("WORKSPACE_DIR"),
+             )
+         await self._rsync_mesh_client.sync_workspace()
+
+
+ async def local_proc_mesh_nonblocking(
+     *, gpus: Optional[int] = None, hosts: int = 1
+ ) -> ProcMesh:
+     if gpus is None:
+         gpus = _local_device_count()
+     spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
+     allocator = monarch.LocalAllocator()
+     alloc = await allocator.allocate(spec)
+     return await ProcMesh.from_alloc(alloc)
+
+
+ def local_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh:
+     if gpus is None:
+         gpus = _local_device_count()
+     spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
+     allocator = monarch.LocalAllocator()
+     alloc = allocator.allocate(spec).get()
+     return ProcMesh.from_alloc(alloc).get()
+
+
+ def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]:
+     return Future(
+         lambda: local_proc_mesh_nonblocking(gpus=gpus, hosts=hosts),
+         lambda: local_proc_mesh_blocking(gpus=gpus, hosts=hosts),
+     )
+
+
+ _BOOTSTRAP_MAIN = "monarch.bootstrap_main"
+
+
+ def _get_bootstrap_args() -> tuple[str, Optional[list[str]], dict[str, str]]:
+     if IN_PAR:
+         cmd = sys.argv[0]
+         args = None
+         env = {
+             "PAR_MAIN_OVERRIDE": _BOOTSTRAP_MAIN,
+         }
+     else:
+         cmd = sys.executable
+         args = ["-m", _BOOTSTRAP_MAIN]
+         env = {}
+
+     return cmd, args, env
+
+
+ async def proc_mesh_nonblocking(
+     *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None
+ ) -> ProcMesh:
+     if gpus is None:
+         gpus = _local_device_count()
+     spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
+     env = env or {}
+     cmd, args, base_env = _get_bootstrap_args()
+     env.update(base_env)
+     env["HYPERACTOR_MANAGED_SUBPROCESS"] = "1"
+     allocator = monarch.ProcessAllocator(cmd, args, env)
+     alloc = await allocator.allocate(spec)
+     return await ProcMesh.from_alloc(alloc)
+
+
+ def proc_mesh_blocking(
+     *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None
+ ) -> ProcMesh:
+     if gpus is None:
+         gpus = _local_device_count()
+     spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
+     env = env or {}
+     cmd, args, base_env = _get_bootstrap_args()
+     env.update(base_env)
+     env["HYPERACTOR_MANAGED_SUBPROCESS"] = "1"
+     allocator = monarch.ProcessAllocator(cmd, args, env)
+     alloc = allocator.allocate(spec).get()
+     return ProcMesh.from_alloc(alloc).get()
+
+
+ def proc_mesh(
+     *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None
+ ) -> Future[ProcMesh]:
+     return Future(
+         lambda: proc_mesh_nonblocking(gpus=gpus, hosts=hosts, env=env),
+         lambda: proc_mesh_blocking(gpus=gpus, hosts=hosts, env=env),
+     )
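
Note: the helpers above come in nonblocking/blocking pairs wrapped in a `Future`, so callers can either `await` the result or block with `.get()`. The snippet below is a short usage sketch and not part of the wheel: the `Counter` actor, its `value` endpoint, and `gpus=2` are assumptions for illustration, and the `endpoint` decorator is assumed to come from `monarch.actor_mesh` (which also provides `Actor`, imported by this module).

from monarch.actor_mesh import Actor, endpoint
from monarch.proc_mesh import proc_mesh

class Counter(Actor):
    def __init__(self, start: int) -> None:
        self._value = start

    @endpoint
    async def value(self) -> int:
        return self._value

mesh = proc_mesh(gpus=2).get()                       # block on Future[ProcMesh]
counters = mesh.spawn("counter", Counter, 0).get()   # Future[ActorMeshRef[Counter]]
print(counters.value.call_one().get())               # ask one actor for its value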