torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/simulator/command_history.py
@@ -0,0 +1,424 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import itertools
+ import traceback
+ import warnings
+ from dataclasses import dataclass
+ from typing import List, NamedTuple, Optional, Sequence
+
+ import torch
+
+ from monarch.common import messages
+ from monarch.common.shape import NDSlice
+ from monarch.simulator.ir import IRGraph
+ from monarch.simulator.tensor import DTensorRef
+ from monarch.simulator.utils import clean_name, file_path_with_iter
+
+ from torch.utils._pytree import tree_map
+
+
+ @dataclass
+ class Command:
+     timestamp: int
+     # Currently either "send" or "recvready".
+     backend_command: str
+     # "send" arguments
+     ranks: Optional[List[NDSlice]] = None
+     msg: Optional[NamedTuple] = None
+     # "recvready" arguments
+     timeout: Optional[float] = None
+
+
+ class CommandHistory:
+     """
+     A class to record commands sent to the SimulatorBackend. The class can
+     later be used to replay the recorded commands.
+
+     Args:
+         world_size (int): The number of workers in the simulated world.
+         maxlen (int): The maximum number of commands to record. Defaults to 10_000_000.
+         file_path (str): The file to dump recorded commands to. Defaults to "command_history.pt".
+     """
+
+     def __init__(
+         self,
+         world_size: int,
+         *,
+         maxlen: int = 10_000_000,
+         file_path: str = "command_history.pt",
+     ) -> None:
+         self.world_size = world_size
+         self.maxlen = maxlen
+         self.commands: List[Command] = []
+         self.warn_once: bool = False
+         self.file_path = file_path
+
+     def __del__(self):
+         DTensorRef.created.clear()
+
+     def record(
+         self,
+         now: int,
+         backend_command: str,
+         command_id: int,
+         traceback: Sequence[traceback.FrameSummary] = (),
+         ranks: Optional[List[NDSlice]] = None,
+         msg: Optional[NamedTuple] = None,
+         timeout: Optional[float] = None,
+         ir: Optional[IRGraph] = None,
+     ) -> Command:
+         command = self.convert_command(
+             now, backend_command, command_id, traceback, ranks, msg, timeout, ir
+         )
+         if len(self.commands) < self.maxlen:
+             self.commands.append(command)
+         elif not self.warn_once:
+             warnings.warn(
+                 (
+                     f"CommandHistory's maxlen is {self.maxlen}, and we have "
+                     "already exceeded the limit. The remaining commands will "
+                     "not be recorded."
+                 ),
+                 stacklevel=2,
+             )
+             self.warn_once = True
+         return command
+
+     @staticmethod
+     def convert_command(
+         now: int,
+         backend_command: str,
+         command_id: int,
+         traceback: Sequence[traceback.FrameSummary] = (),
+         ranks: Optional[List[NDSlice]] = None,
+         msg: Optional[NamedTuple] = None,
+         timeout: Optional[float] = None,
+         ir: Optional[IRGraph] = None,
+     ) -> Command:
+         msg = CommandHistory._convert_command(msg)
+
+         if ir:
+             if isinstance(msg, messages.CommandGroup):
+                 for i, command in enumerate(msg.commands):
+                     # i starts from 0, so the grouped command's id is
+                     # command_id + i + 1.
+                     CommandHistory._maybe_insert_ir(
+                         ir, command_id + i + 1, traceback, ranks, command
+                     )
+             else:
+                 CommandHistory._maybe_insert_ir(ir, command_id, traceback, ranks, msg)
+         return Command(
+             timestamp=now,
+             backend_command=backend_command,
+             ranks=ranks,
+             msg=msg,
+             timeout=timeout,
+         )
+
+     @staticmethod
+     def convert_msg(msg):
+         def _convert_arg(v):
+             if isinstance(v, torch.Tensor):
+                 return DTensorRef.from_ref(v)
+             return v
+
+         name = type(msg).__name__
+         match name:
+             case "CallFunction":
+                 args, kwargs, mutates, result = tree_map(
+                     _convert_arg, (msg.args, msg.kwargs, msg.mutates, msg.result)
+                 )
+                 msg = msg._replace(
+                     args=args, kwargs=kwargs, mutates=mutates, result=result
+                 )
+             case "SendTensor":
+                 msg = msg._replace(
+                     tensor=DTensorRef.from_ref(msg.tensor),
+                     result=DTensorRef.from_ref(msg.result),
+                 )
+             case "Reduce":
+                 msg = msg._replace(
+                     local_tensor=DTensorRef.from_ref(msg.local_tensor),
+                     result=DTensorRef.from_ref(msg.result),
+                 )
+             case "BorrowCreate":
+                 msg = msg._replace(
+                     result=DTensorRef.from_ref(msg.result),
+                     tensor=DTensorRef.from_ref(msg.tensor),
+                 )
+
+         return msg
+
+     @staticmethod
+     def _convert_command(msg):
+         if isinstance(msg, messages.CommandGroup):
+             for idx, command in enumerate(msg.commands):
+                 msg.commands[idx] = CommandHistory.convert_msg(command)
+             return msg
+         else:
+             return CommandHistory.convert_msg(msg)
+
+     # TODO: Add a helper to simplify the repeated modifications to the IR.
+     @staticmethod
+     def _maybe_insert_ir(
+         ir: IRGraph,
+         command_id: int,
+         tb: Sequence[traceback.FrameSummary] = (),
+         ranks: Optional[List[NDSlice]] = None,
+         msg: Optional[NamedTuple] = None,
+     ) -> None:
+         # Process tensor results and update the IR.
+         def _process_tensor_results(
+             result,
+             worker_rank,
+             stream_name,
+             command_id,
+             mutate=False,
+             borrow_src_tensor_ref=None,
+         ):
+             if result is not None:
+                 results_list = result if isinstance(result, list) else [result]
+                 for tensor_ref in results_list:
+                     fake = tensor_ref._fake
+                     ir.update_tensor(
+                         tensor_ref._storage_id,
+                         tensor_ref.ref,
+                         fake.dtype,
+                         tuple(fake.shape),
+                         worker_rank,
+                         stream_name,
+                         command_id,
+                         mutate=mutate,
+                         borrow_src_tensor_ref=borrow_src_tensor_ref,
+                         tensor_size=tensor_ref._size,
+                     )
+
+         assert msg is not None
+         stream_name = src_stream_name = dst_stream_name = ""
+         flattened_ranks = list(itertools.chain.from_iterable(ranks or []))
+         command_type = ""
+         devices = []
+         control_dependencies = []
+         dag_item_type = type(msg).__name__
+         result = getattr(msg, "result", None)
+         for worker_rank in flattened_ranks:
+             match dag_item_type:
+                 case "CallFunction":
+                     stream_name = getattr(msg, "stream", None).name
+                     command_type = (
+                         f"CallFunction: {clean_name(str(getattr(msg, 'function', '')))}"
+                     )
+                     devices = [worker_rank]
+                     msg_args = getattr(msg, "args", None)
+                     if msg_args is not None:
+                         for arg in msg_args:
+                             if isinstance(arg, DTensorRef):
+                                 _process_tensor_results(
+                                     arg, worker_rank, stream_name, command_id
+                                 )
+                     msg_mutates = getattr(msg, "mutates", None)
+                     if msg_mutates is not None:
+                         for mutate_src in msg_mutates:
+                             if isinstance(mutate_src, DTensorRef) or (
+                                 isinstance(mutate_src, list)
+                                 and all(isinstance(m, DTensorRef) for m in mutate_src)
+                             ):
+                                 mutates_list = (
+                                     mutate_src
+                                     if isinstance(mutate_src, list)
+                                     else [mutate_src]
+                                 )
+                                 _process_tensor_results(
+                                     mutates_list,
+                                     worker_rank,
+                                     stream_name,
+                                     command_id,
+                                     mutate=True,
+                                 )
+                     _process_tensor_results(
+                         result,
+                         worker_rank,
+                         stream_name,
+                         command_id,
+                     )
+
+                 case "Reduce":
+                     stream_name = getattr(msg, "stream", None).name
+                     reduction = getattr(msg, "reduction", None)
+                     scatter = getattr(msg, "scatter", False)
+                     if reduction == "stack":
+                         reduce_type = "all_to_all" if scatter else "all_gather"
+                     else:
+                         reduce_type = "reduce_scatter" if scatter else "all_reduce"
+                     # Use result.ref as a unique id for the Reduce.
+                     command_type = f"Reduce: {reduce_type}: {result.ref}"
+                     devices = flattened_ranks
+                     _process_tensor_results(
+                         result, worker_rank, stream_name, command_id
+                     )
+                 case "BorrowCreate":
+                     borrow_id = getattr(msg, "borrow", None)
+                     borrow_src_tensor_ref = getattr(msg, "tensor", None).ref
+                     stream_name = src_stream_name = getattr(
+                         msg, "from_stream", None
+                     ).name
+                     dst_stream_name = getattr(msg, "to_stream", None).name
+
+                     command_type = f"BorrowCreate: {borrow_id}"
+                     devices = [worker_rank]
+                     ir.add_borrow(
+                         borrow_id,
+                         worker_rank,
+                         src_stream_name,
+                         dst_stream_name,
+                         command_id,
+                     )
+                     _process_tensor_results(
+                         result,
+                         worker_rank,
+                         dst_stream_name,
+                         command_id,
+                         borrow_src_tensor_ref=borrow_src_tensor_ref,
+                     )
+                 case "BorrowFirstUse":
+                     borrow_id = getattr(msg, "borrow", None)
+                     stream_name = ir._control.borrows_info[borrow_id].dst_stream_name
+                     command_type = f"BorrowFirstUse: {borrow_id}"
+                     devices = [worker_rank]
+                     control_dependencies = [
+                         ir._control.borrows_info[borrow_id].create_id
+                     ]
+                     ir._control.borrows_info[borrow_id].firstuse_id = command_id
+                 case "BorrowLastUse":
+                     borrow_id = getattr(msg, "borrow", None)
+                     stream_name = src_stream_name = ir._control.borrows_info[
+                         borrow_id
+                     ].dst_stream_name
+                     dst_stream_name = ir._control.borrows_info[
+                         borrow_id
+                     ].src_stream_name
+                     command_type = f"BorrowLastUse: {borrow_id}"
+                     devices = [worker_rank]
+                     ir._control.borrows_info[borrow_id].lastuse_id = command_id
+                 case "BorrowDrop":
+                     borrow_id = getattr(msg, "borrow", None)
+                     stream_name = ir._control.borrows_info[borrow_id].src_stream_name
+                     command_type = f"BorrowDrop: {borrow_id}"
+                     devices = [worker_rank]
+                     control_dependencies = [
+                         ir._control.borrows_info[borrow_id].lastuse_id
+                     ]
+                     ir._control.borrows_info[borrow_id].drop_id = command_id
+
+             if dag_item_type in [
+                 "CallFunction",
+                 "Reduce",
+                 "BorrowCreate",
+                 "BorrowFirstUse",
+                 "BorrowLastUse",
+                 "BorrowDrop",
+             ]:
+                 ir.insert_node(
+                     worker_rank,
+                     stream_name,
+                     command_id,
+                     command_type,
+                     devices,
+                     control_dependencies,
+                     traceback.format_list(tb),
+                 )
+
+         assert ranks is not None
+         if dag_item_type == "SendTensor" and len(ranks) == 2:
+             # For SendTensor, ranks[0] holds the source ranks and ranks[1]
+             # holds the destination ranks.
+             src_flattened_ranks = list(itertools.chain.from_iterable([ranks[0]]))
+             dst_flattened_ranks = list(itertools.chain.from_iterable([ranks[1]]))
+
+             src_stream_name = getattr(msg, "from_stream", None).name
+             dst_stream_name = getattr(msg, "to_stream", None).name
+
+             # Build the sets of (rank, stream) pairs for the source and
+             # destination ranks, then take their union.
+             src_rank_stream_pairs = {
+                 (rank, src_stream_name) for rank in src_flattened_ranks
+             }
+             dst_rank_stream_pairs = {
+                 (rank, dst_stream_name) for rank in dst_flattened_ranks
+             }
+             rank_stream_pairs = src_rank_stream_pairs | dst_rank_stream_pairs
+             command_type = f"SendTensor: {result.ref if result else None}"
+             devices = flattened_ranks
+             control_dependencies = flattened_ranks
+             for rank, stream_name in rank_stream_pairs:
+                 ir.insert_node(
+                     rank,
+                     stream_name,
+                     command_id,
+                     command_type,
+                     devices,
+                     control_dependencies,
+                     traceback.format_list(tb),
+                 )
+             src_tensor = getattr(msg, "tensor", None)
+             if src_tensor is not None:
+                 src_tensors_list = (
+                     src_tensor if isinstance(src_tensor, list) else [src_tensor]
+                 )
+                 for src_t in src_tensors_list:
+                     for rank, src_stream_name in src_rank_stream_pairs:
+                         _process_tensor_results(
+                             src_t, rank, src_stream_name, command_id
+                         )
+             if result is not None:
+                 results_list = result if isinstance(result, list) else [result]
+                 for res in results_list:
+                     ir.add_sendtensor(
+                         res.ref,
+                         src_flattened_ranks,
+                         src_stream_name,
+                         dst_flattened_ranks,
+                         dst_stream_name,
+                         tuple(res._fake.size()),
+                     )
+                     for rank, dst_stream_name in dst_rank_stream_pairs:
+                         _process_tensor_results(res, rank, dst_stream_name, command_id)
+
+         if dag_item_type == "DeleteRefs":
+             refs = getattr(msg, "refs", None)
+             for ref in refs:
+                 stream_name = ir._data.tensorref_to_stream[ref]
+                 # Do not call _insert_node() since we do not need DeleteRefs
+                 # in the control DAG.
+                 ir.delete_tensor(
+                     ref,
+                     flattened_ranks,
+                     stream_name,
+                     command_id,
+                 )
+
+     def step(self, iter_count: int, dump: bool = False) -> None:
+         if dump:
+             self.dump(file_path_with_iter(self.file_path, iter_count))
+
+         self.commands.clear()
+
+     def dump(self, file_path: str) -> None:
+         with open(file_path, "wb") as f:
+             torch.save({"world_size": self.world_size, "commands": self.commands}, f)
+
+     @classmethod
+     def load(cls, filename: str) -> "CommandHistory":
+         with open(filename, "rb") as f:
+             states = torch.load(f, weights_only=False)
+         self = cls(states["world_size"])
+         self.commands = states["commands"]
+
+         return self
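
A minimal usage sketch of the record/dump/load round trip above (editorial example, not part of the diff). The record() arguments are illustrative, and the dumped filename is an assumption: it depends on file_path_with_iter(), which is not shown here.

    from monarch.simulator.command_history import CommandHistory

    # Record a "recvready" command; msg and ranks default to None.
    history = CommandHistory(world_size=8, file_path="command_history.pt")
    history.record(now=0, backend_command="recvready", command_id=0, timeout=5.0)

    # step() dumps the buffered commands (when dump=True) and clears the buffer.
    history.step(iter_count=0, dump=True)

    # load() restores world_size and the recorded commands via torch.load.
    replayed = CommandHistory.load("command_history_0.pt")  # filename assumed
    print(replayed.world_size, len(replayed.commands))
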
monarch/simulator/config.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import contextlib
+
+ META_VAL = []
+
+
+ @contextlib.contextmanager
+ def set_meta(new_value):
+     # Sets the metadata for any tasks created under this context manager.
+     global META_VAL
+     META_VAL.append(new_value)
+     try:
+         yield
+     finally:
+         META_VAL.pop()
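
As a quick illustration of the stack discipline above (editorial example): META_VAL always ends with the innermost active metadata, and the try/finally guarantees every push is popped on exit, even if the body raises.

    from monarch.simulator.config import META_VAL, set_meta

    with set_meta({"phase": "forward"}):
        with set_meta({"phase": "backward"}):
            assert META_VAL[-1] == {"phase": "backward"}
        assert META_VAL[-1] == {"phase": "forward"}
    assert META_VAL == []
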
monarch/simulator/interface.py
@@ -0,0 +1,59 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Union
+
+ from monarch.common.client import Client as _Client
+ from monarch.common.device_mesh import DeviceMesh
+ from monarch.common.shape import NDSlice
+
+ from monarch.simulator.ir import IRGraph
+ from monarch.simulator.simulator import (
+     SimulatorBackendMode,
+     SimulatorController as _SimulatorController,
+     SimulatorInterface,
+     SimulatorTraceMode,
+ )
+
+
+ def Simulator(
+     hosts: int,
+     gpus: int,
+     *,
+     simulate_mode: Union[str, SimulatorBackendMode] = SimulatorBackendMode.SIMULATE,
+     trace_mode: Union[str, SimulatorTraceMode] = SimulatorTraceMode.STREAM_ONLY,
+     upload_trace: bool = False,
+     trace_path: str = "trace.json",
+     command_history_path: str = "command_history.pkl",
+     group_workers: bool = False,
+     build_ir: bool = False,
+ ) -> "SimulatorInterface":
+     if isinstance(simulate_mode, str):
+         simulate_mode = getattr(SimulatorBackendMode, simulate_mode.upper())
+     if isinstance(trace_mode, str):
+         trace_mode = getattr(SimulatorTraceMode, trace_mode.upper())
+
+     ir = IRGraph() if build_ir else None
+     ctrl = _SimulatorController(
+         hosts * gpus,
+         gpu_per_host=gpus,
+         simulate_mode=simulate_mode,
+         trace_mode=trace_mode,
+         upload_trace=upload_trace,
+         trace_path=trace_path,
+         command_history_path=command_history_path,
+         group_workers=group_workers,
+         ir=ir,
+     )
+     client = _Client(ctrl, ctrl.world_size, ctrl.gpu_per_host)
+     dm = DeviceMesh(
+         client,
+         NDSlice(offset=0, sizes=[hosts, gpus], strides=[gpus, 1]),
+         ("host", "gpu"),
+     )
+
+     dm.exit = lambda: client.shutdown()
+     return SimulatorInterface(dm, ctrl, ir)
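
A construction sketch based only on the code above (editorial example): the mode arguments accept either the enum values or their lower-case names, since Simulator() upper-cases a string and looks it up on the corresponding enum.

    from monarch.simulator.interface import Simulator

    # 2 hosts x 4 GPUs -> a simulated world of 8 workers. "simulate" resolves
    # to SimulatorBackendMode.SIMULATE and "stream_only" to
    # SimulatorTraceMode.STREAM_ONLY.
    sim = Simulator(
        2,
        4,
        simulate_mode="simulate",
        trace_mode="stream_only",
        trace_path="trace.json",
        build_ir=True,
    )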