torchmonarch_nightly-2025.6.4-cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/simulator/worker.py
@@ -0,0 +1,389 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ import copy
+ import itertools
+ import logging
+ import traceback
+ from collections import deque
+ from typing import cast, Dict, List, Optional, Sequence, Tuple
+
+ import numpy as np
+ from monarch.simulator.config import META_VAL
+ from monarch.simulator.profiling import RuntimeEstimator
+ from monarch.simulator.task import Borrow, EventTask, Task, WorkerTaskManager
+ from monarch.simulator.tensor import (
+     FakeTensorTracker,
+     StreamMemoryTracker,
+     TensorManager,
+     WorkerStorageTracker,
+ )
+ from monarch.simulator.trace import TraceEvent
+
+ logger = logging.getLogger(__name__)
+
+
+ class Stream:
+     """Represents a worker stream."""
+
+     def __init__(
+         self,
+         ident: int,
+         name: str,
+         fake_tensor_tracker: FakeTensorTracker,
+         storage_tracker: WorkerStorageTracker,
+         cpu_tensors: TensorManager,
+     ) -> None:
+         self.id = ident
+         self.name = name
+         self.task_queue = deque()
+         self.last_task: Optional[Task] = None
+         self.now = 0
+         self.events: List[TraceEvent] = []
+         self.memory = StreamMemoryTracker(storage_tracker)
+         # Local tensors created on this stream. The value indicates which
+         # tasks or borrows (int) are using this tensor.
+         self.tensors = TensorManager(fake_tensor_tracker, self.memory)
+         self.cpu_tensors = cpu_tensors
+         self.fake_tensor_tracker = fake_tensor_tracker
+
+     def add_task(self, task: Task) -> None:
+         """
+         Add a task to this stream. A task is always pending in the beginning and
+         will be executed only if it is ready and is the first task in the stream.
+         """
+         task.start_time = max(self.now, task.start_time)
+
+         for output in set(task.outputs) - set(task.inputs):
+             self.tensors.add(output, (task,), task.start_time)
+
+         # Inputs must come from previous tasks on the same stream or from
+         # borrowed tensors.
+         for tensor in task.inputs:
+             if tensor in self.cpu_tensors:
+                 self.cpu_tensors.incr_ref(tensor, task)
+             else:
+                 self.tensors.incr_ref(tensor, task)
+
+         if self.task_queue:
+             task.dependencies.append(self.task_queue[-1])
+         elif self.last_task:
+             task.dependencies.append(self.last_task)
+
+         self.task_queue.append(task)
+
+     def lend(self, borrow: Borrow) -> None:
+         self.tensors.incr_ref(borrow.tensor_src_id, borrow.ident)
+
+     def return_borrow(self, borrow: Borrow) -> None:
+         self.tensors.decr_ref(borrow.tensor_src_id, borrow.ident, self.now, None)
+
+     def borrow(self, borrow: Borrow) -> None:
+         # We don't care about the timestamp, as a borrow should not incur any
+         # memory usage change.
+         self.tensors.add(borrow.tensor_dst_id, (), -1)
+         self.tensors.first_use(borrow.tensor_dst_id, -1, None)
+
+     def borrow_drop(self, borrow: Borrow) -> None:
+         # We don't care about the timestamp, as a borrow should not incur any
+         # memory usage change.
+         # self.tensors.delete(borrow.tensor_dst_id, -1)
+         pass
+
+     def delete_refs(self, tensor_ids: List[int], now: int) -> None:
+         tb = traceback.extract_stack()
+         for tensor_id in tensor_ids:
+             if tensor_id not in self.tensors:
+                 continue
+             now = max(self.now, now)
+             self.tensors.delete(tensor_id, now, tb)
+
+     def maybe_set_ready(self) -> bool:
+         if self.task_queue:
+             return self.task_queue[0].maybe_set_ready()
+         return False
+
+     def maybe_execute(self) -> bool:
+         """
+         Check if we can execute the first task of this stream. Return True if
+         the first task's state is changed from READY to EXECUTING.
+         """
+         if self.task_queue:
+             task = self.task_queue[0]
+             executing = task.maybe_execute()
+             if executing:
+                 for output in set(task.outputs) - set(task.inputs):
+                     self.tensors.first_use(output, task.start_time, task.traceback)
+             return executing
+         return False
+
+     def maybe_finish(self) -> Tuple[Optional[Task], Optional[Task]]:
+         """
+         Check if we can finish the first task of this stream. Return a tuple of
+         (previous last task, finished task) if the first task's state is changed
+         from EXECUTING to EXECUTED, else (None, None).
+         """
+         if not self.task_queue:
+             return (None, None)
+
+         task = self.task_queue[0]
+         if not task.maybe_finish():
+             return (None, None)
+
+         task = self.task_queue.popleft()
+         original_last_task = self.last_task
+         self.last_task = task
+
+         # Update the tensor and memory usage.
+         if isinstance(task, EventTask):
+             borrow = task.borrow
+             if borrow is not None and borrow.tensor_src_id in self.tensors:
+                 self.tensors.decr_ref(
+                     borrow.tensor_src_id, borrow.ident, task.end_time, task.traceback
+                 )
+         else:
+             removed_tensors = set()
+             for tensor in itertools.chain(task.inputs, task.outputs):
+                 if tensor in self.cpu_tensors:
+                     self.cpu_tensors.decr_ref(
+                         tensor, task, task.end_time, task.traceback
+                     )
+                     removed_tensors.add(tensor)
+                 elif tensor not in self.tensors:
+                     raise RuntimeError(f"tensor {tensor} not in self.tensors.")
+                 elif tensor not in removed_tensors:
+                     # We also remove the reference even if the tensor is in
+                     # the outputs -- the tensor is not going to be deleted until
+                     # DeleteRef is received.
+                     self.tensors.decr_ref(tensor, task, task.end_time, task.traceback)
+                     removed_tensors.add(tensor)
+
+         # Add a TraceEvent.
+         if task.end_time > task.start_time:
+             runtime = task.end_time - task.start_time
+             self.events.append(
+                 TraceEvent(
+                     task.start_time, runtime, task.meta, task.command_id, task.traceback
+                 )
+             )
+
+         # Update the stream timestamp.
+         self.now = task.end_time
+         return (original_last_task, task)
+
+     def wait_event(self, event: EventTask) -> None:
+         self.add_task(event)
+
+     def record_event(self) -> Task:
+         if self.task_queue:
+             return self.task_queue[-1]
+         elif self.last_task:
+             return self.last_task
+         else:
+             raise RuntimeError("No tasks can be recorded.")
+
+     def clone(
+         self,
+         task_manager: WorkerTaskManager,
+         storage_tracker: WorkerStorageTracker,
+         cpu_tensors: TensorManager,
+     ) -> "Stream":
+         ret = Stream(
+             ident=self.id,
+             name=self.name,
+             fake_tensor_tracker=self.fake_tensor_tracker,
+             storage_tracker=storage_tracker,
+             cpu_tensors=cpu_tensors,
+         )
+         for task in self.task_queue:
+             ret.task_queue.append(task_manager.tasks[task.task_id])
+         if self.last_task:
+             assert self.last_task.task_id is not None
+             ret.last_task = task_manager.tasks[self.last_task.task_id]
+         ret.now = self.now
+         ret.events = copy.copy(self.events)
+         ret.memory = self.memory.clone(storage_tracker)
+         ret.tensors = self.tensors.clone(task_manager, ret.memory)
+         return ret
+
+
+ class Worker:
+     """Represents a worker."""
+
+     def __init__(
+         self,
+         fake_tensor_tracker: FakeTensorTracker,
+         runtime: RuntimeEstimator,
+     ) -> None:
+         self.runtime = runtime
+         self.streams: Dict[int, Stream] = {}
+         self.default_stream_id = 0
+         self.events: List[TraceEvent] = []
+         self.wait_events: Dict[int, EventTask] = {}
+         self.fake_tensor_tracker = fake_tensor_tracker
+         self.storage_tracker = WorkerStorageTracker(fake_tensor_tracker)
+         # We don't track CPU memory usage, so pass None as the memory argument.
+         self.cpu_tensors = TensorManager(fake_tensor_tracker, None)
+         self.borrows: Dict[int, Borrow] = {}
+
+         self.task_manager = WorkerTaskManager()
+
+     def record_command(
+         self,
+         command: str,
+         command_id: int,
+         now: int,
+         traceback: Sequence[traceback.FrameSummary],
+     ) -> None:
+         # This is a CPU activity event.
+         self.events.append(
+             TraceEvent(
+                 now,
+                 self.runtime.get_runtime("kernel_launch"),
+                 [command] + META_VAL,
+                 command_id,
+                 traceback,
+             )
+         )
+
+     def create_stream(self, ident: int, name: str, default: bool) -> None:
+         if ident in self.streams:
+             raise ValueError(f"{ident} is already created.")
+         self.streams[ident] = Stream(
+             ident,
+             name,
+             self.fake_tensor_tracker,
+             self.storage_tracker,
+             self.cpu_tensors,
+         )
+         if default:
+             self.default_stream_id = ident
+
+     def add_task(self, task: Task, now: int, stream: Optional[int] = None) -> None:
+         self.record_command(task.meta[0], task.command_id, now, task.traceback)
+         if stream is None:
+             stream = self.default_stream_id
+         self.streams[stream].add_task(task)
+         self.task_manager.add(task)
+
+     def borrow(self, task: EventTask, borrow: Borrow) -> None:
+         from_stream = task.event_stream
+         to_stream = task.wait_stream
+         self.streams[from_stream].lend(borrow)
+         self.streams[to_stream].borrow(borrow)
+
+         # Record the event from the source stream so that the destination stream
+         # can wait for it when the borrowed tensor is first used.
+         # TODO: can we unify the separate data structures that keep tasks?
+         self.wait_events[borrow.ident] = task
+         self.task_manager.add(task)
+         self.borrows[borrow.ident] = borrow
+
+     def borrow_first_use(self, borrow_id: int, now: int) -> None:
+         task = self.wait_events[borrow_id]
+         to_stream = task.wait_stream
+
+         # The destination stream needs to wait for the event before it can use
+         # the borrowed tensor.
+         self.record_command(task.meta[0], task.command_id, now, task.traceback)
+         self.streams[to_stream].wait_event(task)
+
+     def borrow_last_use(self, task: EventTask, borrow_id: int) -> None:
+         # Record the last-use event from the destination stream so that the
+         # source stream can wait for it when the borrow is dropped.
+         self.wait_events[borrow_id] = task
+         self.task_manager.add(task)
+
+     def borrow_drop(self, borrow_id: int, now: int) -> None:
+         task = self.wait_events[borrow_id]
+         from_stream = task.wait_stream
+         to_stream = task.event_stream
+
+         # Wait for the last usage.
+         borrow = self.borrows[borrow_id]
+         self.record_command(task.meta[0], task.command_id, now, task.traceback)
+         self.streams[from_stream].wait_event(task)
+         self.streams[from_stream].return_borrow(borrow)
+         self.streams[to_stream].borrow_drop(borrow)
+
+     def add_cpu_tensor(self, tensor_id: int, ts: int) -> None:
+         # Currently we don't simulate any CPU ops or memory, so this is the
+         # API to add CPU tensors. We also don't add a dependency on the
+         # creation task, as it is a CPU op (e.g., dataloader).
+         self.cpu_tensors.add(tensor_id, (), ts)
+
+     def delete_refs(self, tensor_ids: List[int], ts: int) -> None:
+         for tensor_id in tensor_ids:
+             if tensor_id in self.cpu_tensors:
+                 self.cpu_tensors.delete(tensor_id, ts, None)
+
+         for stream in self.streams.values():
+             stream.delete_refs(tensor_ids, ts)
+
+     def maybe_set_ready(self) -> bool:
+         """
+         Check if we can set ready for tasks on the streams of the worker. Return
+         True if at least one task becomes ready.
+         """
+         return any(s.maybe_set_ready() for s in self.streams.values())
+
+     def maybe_execute(self) -> bool:
+         """
+         Check if we can execute tasks on the streams of the worker. Return
+         True if we execute at least one task.
+         """
+         return any(s.maybe_execute() for s in self.streams.values())
+
+     def maybe_finish(self) -> bool:
+         """
+         Check if we can finish any task on the streams of the worker. Return
+         True if we finish at least one task.
+         """
+         ret = False
+         for stream in self.streams.values():
+             last_task, task = stream.maybe_finish()
+             if task:
+                 ret = True
+             if last_task:
+                 self.task_manager.remove(last_task)
+         return ret
+
+
+ class WorkerGroup(Worker):
+     def __init__(
+         self,
+         workers,
+         fake_tensor_tracker: FakeTensorTracker,
+         runtime: RuntimeEstimator,
+     ) -> None:
+         super().__init__(fake_tensor_tracker, runtime)
+         self.workers = workers
+
+     def clone(self, workers) -> "WorkerGroup":
+         ret = WorkerGroup(workers, self.fake_tensor_tracker, self.runtime)
+         ret.default_stream_id = self.default_stream_id
+         ret.events = copy.copy(self.events)
+         ret.borrows = copy.copy(self.borrows)
+         ret.task_manager = self.task_manager.clone()
+         ret.storage_tracker = self.storage_tracker.clone()
+         ret.cpu_tensors = self.cpu_tensors.clone(ret.task_manager, None)
+         for ident, task in self.wait_events.items():
+             assert task.task_id is not None
+             ret.wait_events[ident] = cast(
+                 EventTask, ret.task_manager.tasks[task.task_id]
+             )
+         for sid, stream in self.streams.items():
+             ret.streams[sid] = stream.clone(
+                 ret.task_manager, ret.storage_tracker, ret.cpu_tensors
+             )
+         return ret
+
+     def split(self, split_set) -> "WorkerGroup":
+         assert len(np.setdiff1d(split_set, self.workers, assume_unique=True)) == 0
+         self.workers = np.setdiff1d(self.workers, split_set, assume_unique=True)
+         return self.clone(split_set)
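
Taken together, `Stream.maybe_set_ready`, `maybe_execute`, and `maybe_finish` advance each head-of-queue task through PENDING -> READY -> EXECUTING -> EXECUTED, and the `Worker` methods fan those checks out across all streams. A minimal driver sketch (not part of this wheel; it assumes only the `Worker` methods shown in the diff above) would sweep the three transitions to a fixed point:

# Hypothetical scheduler loop for monarch/simulator/worker.py, assuming
# `workers` is an iterable of the Worker objects defined above. Each
# maybe_* call returns True when at least one task changed state, so we
# keep sweeping until no stream can make further progress.
def drive_to_quiescence(workers) -> None:
    progressed = True
    while progressed:
        progressed = False
        for worker in workers:
            progressed |= worker.maybe_set_ready()  # PENDING   -> READY
            progressed |= worker.maybe_execute()    # READY     -> EXECUTING
            progressed |= worker.maybe_finish()     # EXECUTING -> EXECUTED

Because `any()` short-circuits inside the Worker methods, a single sweep may leave later streams untouched; repeating until nothing changes makes the outcome independent of stream order.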
monarch/tensor_worker_main.py
@@ -0,0 +1,260 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ This is the main function for the worker / pipe processes. It expects the args to
+ the process to be passed in on the command line and accessible in `sys.argv`.
+
+ To see the supported arguments, check out `monarch_tensor_worker::bootstrap`.
+ """
+
+ # pyre-unsafe
+
+ import bdb
+
+ import importlib.resources
+ import io
+
+ import logging
+ import os
+
+ import pdb  # noqa
+ import socket
+ import sys
+ from pathlib import Path
+ from typing import cast, Optional
+
+ from monarch._rust_bindings.monarch_extension import debugger
+ from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
+
+ logger = logging.getLogger(__name__)
+
+
+ def _handle_unhandled_exception(*args):
+     logger.error("Uncaught exception", exc_info=args)
+
+
+ _glog_level_to_abbr = {
+     "DEBUG": "V",  # V is for VERBOSE in glog
+     "INFO": "I",
+     "WARNING": "W",
+     "ERROR": "E",
+     "CRITICAL": "C",
+ }
+
+
+ def fix_exception_lines(tb_lines):
+     formatted_lines = []
+     for line in tb_lines:
+         # Replace the standard file-and-line format with the custom format.
+         if line.startswith("  File"):
+             # Extract the filename and line number.
+             parts = line.split(",")
+             file_info = parts[0].strip()[6:-1]  # Remove 'File "' and '"'
+             line_info = parts[1].strip()[5:]  # Remove 'line '
+             new_line = f"  File {file_info}:{line_info}"
+             if len(parts) > 2:
+                 new_line += ", " + ",".join(parts[2:]).strip()
+             formatted_lines.append(new_line)
+         else:
+             formatted_lines.append(line.strip())
+     return formatted_lines
+
+
+ class _Formatter(logging.Formatter):
+     def __init__(self, suffix):
+         self.suffix = suffix
+
+     def format(self, record):
+         message = record.getMessage()
+         asctime = self.formatTime(record, "%m%d %H:%M:%S")
+
+         lines = message.strip().split("\n")
+         if record.exc_info:
+             exc_info = fix_exception_lines(
+                 self.formatException(record.exc_info).split("\n")
+             )
+             lines.extend(exc_info)
+         if record.stack_info:
+             stack_info = self.formatStack(record.stack_info)
+             lines.extend(stack_info.strip().split("\n"))
+
+         shortlevel = _glog_level_to_abbr.get(record.levelname, record.levelname[0])
+
+         prefix = (
+             f"{shortlevel}{asctime}.{int(record.msecs*1000):06d} "
+             f"{record.filename}:"
+             f"{record.lineno}]{self.suffix}"
+         )
+         return "\n".join(f"{prefix} {line}" for line in lines)
+
+
+ def initialize_logging(process_name=None):
+     log_folder = os.environ.get("TORCH_MONARCH_LOG_FOLDER")
+     log_level = os.environ.get("TORCH_MONARCH_LOG_LEVEL", "INFO")
+     suffix = "" if process_name is None else f" {process_name}:"
+     handler = None
+     if log_folder is not None:
+         log_folder_path = Path(log_folder)
+         log_folder_path.mkdir(parents=True, exist_ok=True)
+         safe_process_name = (
+             process_name.replace("/", "_") if process_name else "logfile.log"
+         )
+         log_file_name = f"{safe_process_name}.log"
+         log_file_path = log_folder_path / log_file_name
+         handler = logging.FileHandler(log_file_path)
+     else:
+         handler = logging.StreamHandler()
+     handler.setFormatter(_Formatter(suffix))
+     handler.setLevel(log_level)
+     logging.root.setLevel(log_level)
+     logging.root.addHandler(handler)
+     sys.excepthook = _handle_unhandled_exception
+
+
+ def gethostname():
+     """Get the hostname of the machine."""
+     hostname = socket.gethostname()
+     hostname = hostname.replace(".facebook.com", "")
+     return hostname
+
+
+ def _set_trace(*, header=None):
+     ds = PdbWrapper(header)
+     ds.set_trace()
+
+
+ class PdbWrapper(pdb.Pdb):
+     def __init__(self, header: Optional[str]):
+         self._actor = debugger.PdbActor()
+         self.header = header
+         super().__init__(
+             # pyre-ignore
+             stdout=WriteWrapper(self._actor),
+             stdin=ReadWrapper.create(self._actor),
+         )
+         self._first = True
+
+     def setup(self, *args, **kwargs):
+         r = super().setup(*args, **kwargs)
+         if self._first:
+             self._first = False
+             # When we enter the debugger, we want to present the user's stack
+             # frame, not the nested one inside session.run. This means that the
+             # local variables are what get printed, etc. To do this we first
+             # execute "up 2" to get to that frame.
+             self.do_up(2)
+         return r
+
+     def set_continue(self) -> None:
+         r = super().set_continue()
+         if not self.breaks:
+             # There are no more breakpoints, so this debugger will not be
+             # used again, and we detach from the controller io.
+             self._actor.send(DebuggerAction.Detach())
+             self._actor.drain_and_stop()
+             # Break the cycle with itself before we exit.
+             self.stdin = sys.stdin
+             self.stdout = sys.stdout
+         return r
+
+     def set_trace(self):
+         self._actor.send(DebuggerAction.Paused())
+         message = self._actor.receive()
+         # We give the controller the option to ignore this request to debug
+         # by issuing a "detach" message immediately.
+         if isinstance(message, DebuggerAction.Detach):
+             return
+         elif isinstance(message, DebuggerAction.Attach):
+             pass
+         else:
+             raise RuntimeError(f"unexpected debugger message {message}")
+         if self.header:
+             self.message(self.header)
+         super().set_trace()
+
+     def set_quit(self):
+         self._actor.send(DebuggerAction.Detach())
+         self._actor.drain_and_stop()
+         super().set_quit()
+
+
+ class ReadWrapper(io.RawIOBase):
+     def __init__(self, actor: debugger.PdbActor):
+         self._actor = actor
+
+     def readinto(self, b):
+         self._actor.send(DebuggerAction.Read(len(b)))
+         response = self._actor.receive()
+         if isinstance(response, DebuggerAction.Detach):
+             raise bdb.BdbQuit
+         assert isinstance(response, DebuggerAction.Write)
+         response = cast(DebuggerAction.Write, response)
+         payload = debugger.get_bytes_from_write_action(response)
+         assert len(payload) <= len(b)
+         b[: len(payload)] = payload
+         return len(payload)
+
+     def readable(self) -> bool:
+         return True
+
+     @classmethod
+     def create(cls, actor: debugger.PdbActor):
+         return io.TextIOWrapper(io.BufferedReader(cls(actor)))
+
+
+ class WriteWrapper:
+     def __init__(self, actor: debugger.PdbActor):
+         self._actor = actor
+
+     def writable(self) -> bool:
+         return True
+
+     def write(self, s: str):
+         self._actor.send(DebuggerAction.Write(s.encode()))
+
+     def flush(self):
+         pass
+
+
+ if __name__ == "__main__":
+     # torch is imported to make sure all the dynamic types are registered.
+     import torch  # noqa
+
+     if torch.cuda.is_available():
+         # Force CUDA initialization early on. CUDA init is lazy, and Python CUDA
+         # APIs are guarded to init CUDA if necessary. But our worker calls
+         # raw libtorch APIs which are not similarly guarded. So just initialize here
+         # to avoid issues with potentially using uninitialized CUDA state.
+         torch.cuda.init()
+
+     from monarch._rust_bindings.monarch_extension import (  # @manual=//monarch/monarch_extension:monarch_extension
+         tensor_worker,
+     )
+
+     initialize_logging()
+
+     def check_set_device(device):
+         import os
+
+         if str(device) not in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(","):
+             raise ValueError(
+                 f"Only devices {os.environ.get('CUDA_VISIBLE_DEVICES', 'None')} are available to monarch worker, "
+                 f"but torch.cuda.set_device({device}) was called"
+             )
+
+     torch.cuda.set_device = check_set_device
+
+     with (
+         importlib.resources.path("monarch", "py-spy") as pyspy,
+     ):
+         if pyspy.exists():
+             os.environ["PYSPY_BIN"] = str(pyspy)
+         # Fall back to using a local py-spy otherwise.
+
+     pdb.set_trace = _set_trace
+     # pyre-ignore[16]
+     tensor_worker.worker_main()
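
For reference, a usage sketch of the glog-style logging defined above (not shipped in the wheel): `TORCH_MONARCH_LOG_FOLDER` and `TORCH_MONARCH_LOG_LEVEL` are the environment variables `initialize_logging` actually reads; the folder path and process name below are made-up example values.

# Hypothetical setup for monarch.tensor_worker_main.initialize_logging.
import logging
import os

os.environ["TORCH_MONARCH_LOG_FOLDER"] = "/tmp/monarch_logs"  # assumed example path
os.environ["TORCH_MONARCH_LOG_LEVEL"] = "DEBUG"

from monarch.tensor_worker_main import initialize_logging

# "/" in the process name is replaced with "_", so records land in
# /tmp/monarch_logs/worker_0.log with a glog-style prefix such as
#   I0604 12:00:00.123456 my_script.py:12] worker/0: hello from the worker
initialize_logging(process_name="worker/0")
logging.getLogger(__name__).info("hello from the worker")

If `TORCH_MONARCH_LOG_FOLDER` is unset, the same formatter writes to stderr via a `StreamHandler`, and `sys.excepthook` is replaced either way so uncaught exceptions are logged in the same format.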