torchmonarch-nightly 2025.6.4-cp310-cp310-manylinux2014_x86_64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +74 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +198 -0
- monarch/actor_mesh.py +692 -0
- monarch/allocator.py +62 -0
- monarch/bootstrap_main.py +75 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +69 -0
- monarch/cached_remote_function.py +257 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +646 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +443 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +572 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +304 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +204 -0
- monarch/common/stream.py +111 -0
- monarch/common/tensor.py +793 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/fetch.py +55 -0
- monarch/future.py +25 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/proc_mesh.py +188 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +190 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +357 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +189 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +57 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +121 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +139 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +112 -0
- tests/test_alloc.py +25 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +835 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +372 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +182 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
- torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/simulator/worker.py
@@ -0,0 +1,389 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import copy
import itertools
import logging
import traceback
from collections import deque
from typing import cast, Dict, List, Optional, Sequence, Tuple

import numpy as np
from monarch.simulator.config import META_VAL
from monarch.simulator.profiling import RuntimeEstimator
from monarch.simulator.task import Borrow, EventTask, Task, WorkerTaskManager
from monarch.simulator.tensor import (
    FakeTensorTracker,
    StreamMemoryTracker,
    TensorManager,
    WorkerStorageTracker,
)
from monarch.simulator.trace import TraceEvent

logger = logging.getLogger(__name__)


class Stream:
    """Represents a worker stream."""

    def __init__(
        self,
        ident: int,
        name: str,
        fake_tensor_tracker: FakeTensorTracker,
        storage_tracker: WorkerStorageTracker,
        cpu_tensors: TensorManager,
    ) -> None:
        self.id = ident
        self.name = name
        self.task_queue = deque()
        self.last_task: Optional[Task] = None
        self.now = 0
        self.events: List[TraceEvent] = []
        self.memory = StreamMemoryTracker(storage_tracker)
        # Local tensors created on this stream. The value records which tasks
        # or borrows (int) are using this tensor.
        self.tensors = TensorManager(fake_tensor_tracker, self.memory)
        self.cpu_tensors = cpu_tensors
        self.fake_tensor_tracker = fake_tensor_tracker

    def add_task(self, task: Task) -> None:
        """
        Add a task to this stream. A task is always pending in the beginning and
        will be executed only if it is ready and is the first task in the stream.
        """
        task.start_time = max(self.now, task.start_time)

        for output in set(task.outputs) - set(task.inputs):
            self.tensors.add(output, (task,), task.start_time)

        # Input must be from the previous tasks on the same stream or from
        # the borrowed tensors.
        for tensor in task.inputs:
            if tensor in self.cpu_tensors:
                self.cpu_tensors.incr_ref(tensor, task)
            else:
                self.tensors.incr_ref(tensor, task)

        if self.task_queue:
            task.dependencies.append(self.task_queue[-1])
        elif self.last_task:
            task.dependencies.append(self.last_task)

        self.task_queue.append(task)

    def lend(self, borrow: Borrow) -> None:
        self.tensors.incr_ref(borrow.tensor_src_id, borrow.ident)

    def return_borrow(self, borrow: Borrow) -> None:
        self.tensors.decr_ref(borrow.tensor_src_id, borrow.ident, self.now, None)

    def borrow(self, borrow: Borrow) -> None:
        # We don't care about the timestamp as borrow should not incur any memory
        # usage change.
        self.tensors.add(borrow.tensor_dst_id, (), -1)
        self.tensors.first_use(borrow.tensor_dst_id, -1, None)

    def borrow_drop(self, borrow: Borrow) -> None:
        # We don't care about the timestamp as borrow should not incur any memory
        # usage change.
        # self.tensors.delete(borrow.tensor_dst_id, -1)
        pass

    def delete_refs(self, tensor_ids: List[int], now: int) -> None:
        tb = traceback.extract_stack()
        for tensor_id in tensor_ids:
            if tensor_id not in self.tensors:
                continue
            now = max(self.now, now)
            self.tensors.delete(tensor_id, now, tb)

    def maybe_set_ready(self) -> bool:
        if self.task_queue:
            return self.task_queue[0].maybe_set_ready()
        return False

    def maybe_execute(self) -> bool:
        """
        Check if we can execute the first task of this stream. Return True if
        the first task's state is changed from READY to EXECUTING.
        """
        if self.task_queue:
            task = self.task_queue[0]
            executing = task.maybe_execute()
            if executing:
                for output in set(task.outputs) - set(task.inputs):
                    self.tensors.first_use(output, task.start_time, task.traceback)
            return executing
        return False

    def maybe_finish(self) -> Tuple[Optional[Task], Optional[Task]]:
        """
        Check if we can finish the first task of this stream. Return the task if
        the first task's state is changed from EXECUTING to EXECUTED else return
        None.
        """
        if not self.task_queue:
            return (None, None)

        task = self.task_queue[0]
        if not task.maybe_finish():
            return (None, None)

        task = self.task_queue.popleft()
        original_last_task = self.last_task
        self.last_task = task

        # Update the tensor and memory usage.
        if isinstance(task, EventTask):
            borrow = task.borrow
            if borrow is not None and borrow.tensor_src_id in self.tensors:
                self.tensors.decr_ref(
                    borrow.tensor_src_id, borrow.ident, task.end_time, task.traceback
                )
        else:
            removed_tensors = set()
            for tensor in itertools.chain(task.inputs, task.outputs):
                if tensor in self.cpu_tensors:
                    self.cpu_tensors.decr_ref(
                        tensor, task, task.end_time, task.traceback
                    )
                    removed_tensors.add(tensor)
                elif tensor not in self.tensors:
                    raise RuntimeError(f"tensor {tensor} not in self.tensors.")
                elif tensor not in removed_tensors:
                    # We also remove the reference even if the tensor is in
                    # outputs -- the tensor is not going to be deleted until
                    # DeleteRef is received.
                    self.tensors.decr_ref(tensor, task, task.end_time, task.traceback)
                    removed_tensors.add(tensor)

        # Add TraceEvent.
        if task.end_time > task.start_time:
            runtime = task.end_time - task.start_time
            self.events.append(
                TraceEvent(
                    task.start_time, runtime, task.meta, task.command_id, task.traceback
                )
            )

        # Update the stream timestamp.
        self.now = task.end_time
        return (original_last_task, task)

    def wait_event(self, event: EventTask) -> None:
        self.add_task(event)

    def record_event(self) -> Task:
        if self.task_queue:
            return self.task_queue[-1]
        elif self.last_task:
            return self.last_task
        else:
            raise RuntimeError("No tasks can be recorded.")

    def clone(
        self,
        task_manager: WorkerTaskManager,
        storage_tracker: WorkerStorageTracker,
        cpu_tensors: TensorManager,
    ) -> "Stream":
        ret = Stream(
            ident=self.id,
            name=self.name,
            fake_tensor_tracker=self.fake_tensor_tracker,
            storage_tracker=storage_tracker,
            cpu_tensors=cpu_tensors,
        )
        for task in self.task_queue:
            ret.task_queue.append(task_manager.tasks[task.task_id])
        if self.last_task:
            assert self.last_task.task_id is not None
            ret.last_task = task_manager.tasks[self.last_task.task_id]
        ret.now = self.now
        ret.events = copy.copy(self.events)
        ret.memory = self.memory.clone(storage_tracker)
        ret.tensors = self.tensors.clone(task_manager, ret.memory)
        return ret
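

# Illustrative sketch (not part of the original file): how a Stream is
# typically driven. Tasks move PENDING -> READY -> EXECUTING -> EXECUTED, and
# only the task at the head of the queue can advance, so one simulation step
# nudges the head task through at most one transition.
def _example_stream_step(stream: Stream) -> None:
    stream.maybe_set_ready()  # PENDING -> READY once dependencies are met
    stream.maybe_execute()  # READY -> EXECUTING for the head task
    _, finished = stream.maybe_finish()  # EXECUTING -> EXECUTED, pops the task
    if finished is not None:
        logger.debug("finished %s at %d", finished.meta, stream.now)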


class Worker:
    """Represents a worker."""

    def __init__(
        self,
        fake_tensor_tracker: FakeTensorTracker,
        runtime: RuntimeEstimator,
    ) -> None:
        self.runtime = runtime
        self.streams: Dict[int, Stream] = {}
        self.default_stream_id = 0
        self.events: List[TraceEvent] = []
        self.wait_events: Dict[int, EventTask] = {}
        self.fake_tensor_tracker = fake_tensor_tracker
        self.storage_tracker = WorkerStorageTracker(fake_tensor_tracker)
        # We don't track CPU memory usage, so pass None as the memory
        # argument.
        self.cpu_tensors = TensorManager(fake_tensor_tracker, None)
        self.borrows: Dict[int, Borrow] = {}

        self.task_manager = WorkerTaskManager()

    def record_command(
        self,
        command: str,
        command_id: int,
        now: int,
        traceback: Sequence[traceback.FrameSummary],
    ) -> None:
        # This is a CPU activity event.
        self.events.append(
            TraceEvent(
                now,
                self.runtime.get_runtime("kernel_launch"),
                [command] + META_VAL,
                command_id,
                traceback,
            )
        )

    def create_stream(self, ident: int, name: str, default: bool) -> None:
        if ident in self.streams:
            raise ValueError(f"{ident} is already created.")
        self.streams[ident] = Stream(
            ident,
            name,
            self.fake_tensor_tracker,
            self.storage_tracker,
            self.cpu_tensors,
        )
        if default:
            self.default_stream_id = ident

    def add_task(self, task: Task, now: int, stream: Optional[int] = None) -> None:
        self.record_command(task.meta[0], task.command_id, now, task.traceback)
        if stream is None:
            stream = self.default_stream_id
        self.streams[stream].add_task(task)
        self.task_manager.add(task)

    def borrow(self, task: EventTask, borrow: Borrow) -> None:
        from_stream = task.event_stream
        to_stream = task.wait_stream
        self.streams[from_stream].lend(borrow)
        self.streams[to_stream].borrow(borrow)

        # Record the event from the source stream so that the destination stream
        # can wait for it when the borrowed tensor is first used.
        # TODO: can we unify the separate data structures that keep tasks?
        self.wait_events[borrow.ident] = task
        self.task_manager.add(task)
        self.borrows[borrow.ident] = borrow

    def borrow_first_use(self, borrow_id: int, now: int) -> None:
        task = self.wait_events[borrow_id]
        to_stream = task.wait_stream

        # The destination stream needs to wait for the event before it can use
        # the borrowed tensor.
        self.record_command(task.meta[0], task.command_id, now, task.traceback)
        self.streams[to_stream].wait_event(task)

    def borrow_last_use(self, task: EventTask, borrow_id: int) -> None:
        # Record the last use event from the destination stream so that the
        # source stream can wait for it when the borrow is dropped.
        self.wait_events[borrow_id] = task
        self.task_manager.add(task)

    def borrow_drop(self, borrow_id: int, now: int) -> None:
        task = self.wait_events[borrow_id]
        from_stream = task.wait_stream
        to_stream = task.event_stream

        # Wait for the last usage.
        borrow = self.borrows[borrow_id]
        self.record_command(task.meta[0], task.command_id, now, task.traceback)
        self.streams[from_stream].wait_event(task)
        self.streams[from_stream].return_borrow(borrow)
        self.streams[to_stream].borrow_drop(borrow)

    def add_cpu_tensor(self, tensor_id: int, ts: int) -> None:
        # Currently we don't simulate any CPU ops and memory, so this is the
        # API to add CPU tensors. We also don't add the dependency of the
        # creation task as it is a CPU op (e.g., dataloader).
        self.cpu_tensors.add(tensor_id, (), ts)

    def delete_refs(self, tensor_ids: List[int], ts: int) -> None:
        for tensor_id in tensor_ids:
            if tensor_id in self.cpu_tensors:
                self.cpu_tensors.delete(tensor_id, ts, None)

        for stream in self.streams.values():
            stream.delete_refs(tensor_ids, ts)

    def maybe_set_ready(self) -> bool:
        """
        Check if we can set ready for tasks on the streams of the worker. Return
        True if at least one task is set ready.
        """
        return any(s.maybe_set_ready() for s in self.streams.values())

    def maybe_execute(self) -> bool:
        """
        Check if we can execute tasks on the streams of the worker. Return
        True if we execute at least one task.
        """
        return any(s.maybe_execute() for s in self.streams.values())

    def maybe_finish(self) -> bool:
        """
        Check if we can finish any task on the streams of the worker. Return
        True if we finish at least one task.
        """
        ret = False
        for stream in self.streams.values():
            last_task, task = stream.maybe_finish()
            if task:
                ret = True
            if last_task:
                self.task_manager.remove(last_task)
        return ret
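

# Illustrative sketch (not part of the original file): the expected ordering of
# the borrow calls above for a single borrow between two streams. The EventTask
# arguments are assumed to be constructed by the simulator frontend.
def _example_borrow_lifecycle(
    worker: Worker,
    create_event: EventTask,
    last_use_event: EventTask,
    borrow: Borrow,
    now: int,
) -> None:
    worker.borrow(create_event, borrow)  # lend from src stream, map dst tensor
    worker.borrow_first_use(borrow.ident, now)  # dst stream waits on the event
    worker.borrow_last_use(last_use_event, borrow.ident)  # record last use
    worker.borrow_drop(borrow.ident, now)  # src waits, then returns the borrow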


class WorkerGroup(Worker):
    def __init__(
        self,
        workers,
        fake_tensor_tracker: FakeTensorTracker,
        runtime: RuntimeEstimator,
    ) -> None:
        super().__init__(fake_tensor_tracker, runtime)
        self.workers = workers

    def clone(self, workers) -> "WorkerGroup":
        ret = WorkerGroup(workers, self.fake_tensor_tracker, self.runtime)
        ret.default_stream_id = self.default_stream_id
        ret.events = copy.copy(self.events)
        ret.borrows = copy.copy(self.borrows)
        ret.task_manager = self.task_manager.clone()
        ret.storage_tracker = self.storage_tracker.clone()
        ret.cpu_tensors = self.cpu_tensors.clone(ret.task_manager, None)
        for ident, task in self.wait_events.items():
            assert task.task_id is not None
            ret.wait_events[ident] = cast(
                EventTask, ret.task_manager.tasks[task.task_id]
            )
        for sid, stream in self.streams.items():
            ret.streams[sid] = stream.clone(
                ret.task_manager, ret.storage_tracker, ret.cpu_tensors
            )
        return ret

    def split(self, split_set) -> "WorkerGroup":
        assert len(np.setdiff1d(split_set, self.workers, assume_unique=True)) == 0
        self.workers = np.setdiff1d(self.workers, split_set, assume_unique=True)
        return self.clone(split_set)
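

# Illustrative sketch (not part of the original file): split() carves a subset
# of workers out of a group so the two halves can diverge independently; the
# even/odd partition here is hypothetical.
def _example_split(group: WorkerGroup) -> WorkerGroup:
    odd = group.split(group.workers[1::2])
    # `group` keeps the remaining (even-indexed) workers, while `odd` gets its
    # own cloned streams, task manager, and storage tracker.
    return odd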
monarch/tensor_worker_main.py
@@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This is the main function for the worker / pipe processes. It expects the
process's arguments to be passed on the command line and accessible via
`sys.argv`.

To see the supported arguments, check out `monarch_tensor_worker::bootstrap`.
"""

# pyre-unsafe

import bdb

import importlib.resources
import io

import logging
import os

import pdb  # noqa
import socket
import sys
from pathlib import Path
from typing import cast, Optional

from monarch._rust_bindings.monarch_extension import debugger
from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction

logger = logging.getLogger(__name__)


def _handle_unhandled_exception(*args):
    logger.error("Uncaught exception", exc_info=args)


_glog_level_to_abbr = {
    "DEBUG": "V",  # V is for VERBOSE in glog
    "INFO": "I",
    "WARNING": "W",
    "ERROR": "E",
    "CRITICAL": "C",
}


def fix_exception_lines(tb_lines):
    formatted_lines = []
    for line in tb_lines:
        # Replace the standard file and line format with the custom format
        if line.startswith("  File"):
            # Extract the filename and line number
            parts = line.split(",")
            file_info = parts[0].strip()[6:-1]  # Remove '  File "' and '"'
            line_info = parts[1].strip()[5:]  # Remove 'line '
            new_line = f"  File {file_info}:{line_info}"
            if len(parts) > 2:
                new_line += ", " + ",".join(parts[2:]).strip()
            formatted_lines.append(new_line)
        else:
            formatted_lines.append(line.strip())
    return formatted_lines
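

# Illustrative example (not part of the original file): fix_exception_lines
# collapses the standard traceback location format into a "file:line" form,
# so a line such as
#
#     '  File "/tmp/example.py", line 7, in main'
#
# is rewritten to
#
#     '  File /tmp/example.py:7, in main'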


class _Formatter(logging.Formatter):
    def __init__(self, suffix):
        self.suffix = suffix

    def format(self, record):
        message = record.getMessage()
        asctime = self.formatTime(record, "%m%d %H:%M:%S")

        lines = message.strip().split("\n")
        if record.exc_info:
            exc_info = fix_exception_lines(
                self.formatException(record.exc_info).split("\n")
            )
            lines.extend(exc_info)
        if record.stack_info:
            stack_info = self.formatStack(record.stack_info)
            lines.extend(stack_info.strip().split("\n"))

        shortlevel = _glog_level_to_abbr.get(record.levelname, record.levelname[0])

        prefix = (
            f"{shortlevel}{asctime}.{int(record.msecs * 1000):06d} "
            f"{record.filename}:"
            f"{record.lineno}]{self.suffix}"
        )
        return "\n".join(f"{prefix} {line}" for line in lines)
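

# Illustrative example (not part of the original file): the glog-style prefix
# produced by _Formatter. A WARNING logged from worker.py line 42 on June 4th,
# with suffix " worker/0:", would render roughly as
#
#     W0604 12:30:05.123456 worker.py:42] worker/0: something happened
#
# Multi-line messages (and formatted exceptions) repeat the prefix per line.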


def initialize_logging(process_name=None):
    log_folder = os.environ.get("TORCH_MONARCH_LOG_FOLDER")
    log_level = os.environ.get("TORCH_MONARCH_LOG_LEVEL", "INFO")
    suffix = "" if process_name is None else f" {process_name}:"
    handler = None
    if log_folder is not None:
        log_folder_path = Path(log_folder)
        log_folder_path.mkdir(parents=True, exist_ok=True)
        safe_process_name = (
            process_name.replace("/", "_") if process_name else "logfile.log"
        )
        log_file_name = f"{safe_process_name}.log"
        log_file_path = log_folder_path / log_file_name
        handler = logging.FileHandler(log_file_path)
    else:
        handler = logging.StreamHandler()
    handler.setFormatter(_Formatter(suffix))
    handler.setLevel(log_level)
    logging.root.setLevel(log_level)
    logging.root.addHandler(handler)
    sys.excepthook = _handle_unhandled_exception
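

# Illustrative usage (not part of the original file): logging is configured
# entirely from the environment, so a typical launch might look like
#
#     TORCH_MONARCH_LOG_FOLDER=/tmp/monarch_logs \
#     TORCH_MONARCH_LOG_LEVEL=DEBUG \
#     python -m monarch.tensor_worker_main ...
#
# When TORCH_MONARCH_LOG_FOLDER is unset, records go to stderr instead.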


def gethostname():
    """Get the hostname of the machine."""
    hostname = socket.gethostname()
    hostname = hostname.replace(".facebook.com", "")
    return hostname


def _set_trace(*, header=None):
    ds = PdbWrapper(header)
    ds.set_trace()


class PdbWrapper(pdb.Pdb):
    def __init__(self, header: Optional[str]):
        self._actor = debugger.PdbActor()
        self.header = header
        super().__init__(
            # pyre-ignore
            stdout=WriteWrapper(self._actor),
            stdin=ReadWrapper.create(self._actor),
        )
        self._first = True

    def setup(self, *args, **kwargs):
        r = super().setup(*args, **kwargs)
        if self._first:
            self._first = False
            # When we enter the debugger, we want to present the user's stack
            # frame, not the nested one inside session.run. This means that the
            # local variables are what gets printed, etc. To do this we first
            # execute "up 2" to get to that frame.
            self.do_up(2)
        return r

    def set_continue(self) -> None:
        r = super().set_continue()
        if not self.breaks:
            # No more breakpoints, so this debugger will not be used again,
            # and we detach from the controller io.
            self._actor.send(DebuggerAction.Detach())
            self._actor.drain_and_stop()
            # Break the cycle with itself before we exit.
            self.stdin = sys.stdin
            self.stdout = sys.stdout
        return r

    def set_trace(self):
        self._actor.send(DebuggerAction.Paused())
        message = self._actor.receive()
        # We give the controller the option to ignore this request to debug
        # by issuing a "detach" message immediately.
        if isinstance(message, DebuggerAction.Detach):
            return
        elif isinstance(message, DebuggerAction.Attach):
            pass
        else:
            raise RuntimeError(f"unexpected debugger message {message}")
        if self.header:
            self.message(self.header)
        super().set_trace()

    def set_quit(self):
        self._actor.send(DebuggerAction.Detach())
        self._actor.drain_and_stop()
        super().set_quit()
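

# Illustrative summary (not part of the original file): the message flow
# between PdbWrapper and the controller, as implemented above.
#
#     worker                         controller
#     ------                         ----------
#     DebuggerAction.Paused()  -->                 (set_trace)
#                              <--  Attach or Detach
#     DebuggerAction.Read(n)   -->                 (stdin via ReadWrapper)
#                              <--  Write(bytes)
#     DebuggerAction.Write(b)  -->                 (stdout via WriteWrapper)
#     DebuggerAction.Detach()  -->                 (set_quit / set_continue)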


class ReadWrapper(io.RawIOBase):
    def __init__(self, actor: debugger.PdbActor):
        self._actor = actor

    def readinto(self, b):
        self._actor.send(DebuggerAction.Read(len(b)))
        response = self._actor.receive()
        if isinstance(response, DebuggerAction.Detach):
            raise bdb.BdbQuit
        assert isinstance(response, DebuggerAction.Write)
        response = cast(DebuggerAction.Write, response)
        payload = debugger.get_bytes_from_write_action(response)
        assert len(payload) <= len(b)
        b[: len(payload)] = payload
        return len(payload)

    def readable(self) -> bool:
        return True

    @classmethod
    def create(cls, actor: debugger.PdbActor):
        return io.TextIOWrapper(io.BufferedReader(cls(actor)))


class WriteWrapper:
    def __init__(self, actor: debugger.PdbActor):
        self._actor = actor

    def writable(self) -> bool:
        return True

    def write(self, s: str):
        self._actor.send(DebuggerAction.Write(s.encode()))

    def flush(self):
        pass
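

# Illustrative note (not part of the original file): ReadWrapper.create layers
# the raw actor-backed stream so pdb sees an ordinary text file object:
#
#     io.TextIOWrapper(io.BufferedReader(ReadWrapper(actor)))
#
# WriteWrapper only needs write()/flush(), which is all pdb.Pdb calls on its
# stdout, so it does not subclass an io type.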


if __name__ == "__main__":
    # torch is imported to make sure all the dynamic types are registered
    import torch  # noqa

    if torch.cuda.is_available():
        # Force CUDA initialization early on. CUDA init is lazy, and Python CUDA
        # APIs are guarded to init CUDA if necessary. But our worker calls raw
        # libtorch APIs which are not similarly guarded. So just initialize here
        # to avoid issues with potentially using uninitialized CUDA state.
        torch.cuda.init()

    from monarch._rust_bindings.monarch_extension import (  # @manual=//monarch/monarch_extension:monarch_extension
        tensor_worker,
    )

    initialize_logging()

    def check_set_device(device):
        import os

        if str(device) not in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(","):
            raise ValueError(
                f"Only devices {os.environ.get('CUDA_VISIBLE_DEVICES', 'None')} are available to monarch worker, "
                f"but torch.cuda.set_device({device}) was called"
            )

    torch.cuda.set_device = check_set_device

    with (
        importlib.resources.path("monarch", "py-spy") as pyspy,
    ):
        if pyspy.exists():
            os.environ["PYSPY_BIN"] = str(pyspy)
        # fallback to using local py-spy

    pdb.set_trace = _set_trace
    # pyre-ignore[16]
    tensor_worker.worker_main()