torchmonarch-nightly 2025.6.4-cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
tests/simulator/test_simulator.py ADDED
@@ -0,0 +1,411 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import functools
+ import json
+ import os
+ import tempfile
+ from typing import cast, Dict, List, Tuple
+
+ import monarch
+
+ import numpy as np
+
+ import pytest
+
+ import torch
+ import torch.distributed as dist
+ from monarch import fetch_shard, NDSlice, Stream, Tensor
+ from monarch.simulator.simulator import Simulator, SimulatorTraceMode
+ from monarch.simulator.utils import file_path_with_iter
+
+
+ def with_tempfile(suffix=".json", unlink=True):
+     def decorator(func):
+         @functools.wraps(func)
+         def wrapper(self, *args, **kwargs):
+             temp_fd, temp_path = tempfile.mkstemp(suffix=suffix)
+             os.close(temp_fd)
+             try:
+                 return func(self, *args, trace_path=temp_path, **kwargs)
+             finally:
+                 # unlink should only be False when debugging.
+                 if unlink:
+                     os.unlink(temp_path)
+                 else:
+                     import logging
+
+                     logging.warning(temp_path)
+
+         return wrapper
+
+     return decorator
+
+
+ @monarch.remote(propagate=lambda x, group: x.add_(1))
+ def simple_all_reduce_local(x, group=None):
+     dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
+     return x
+
+
+ # pyre-ignore-all-errors[6]
+ # pyre-ignore-all-errors[16]
+ # Set a global timeout: sandcastle's timeout is 600s, and a test that sandcastle
+ # times out is not counted as a failure, so we set a more restrictive timeout to
+ # ensure we see a hard failure in CI.
+ @pytest.mark.timeout(120)
+ class TestSimulator:
+     def _get_simulation_result(
+         self, pid: int, trace_path
+     ) -> Tuple[Dict[str, List[str] | List[Tuple[float, float]]], List[int]]:
+         with open(file_path_with_iter(trace_path, 0), "r") as f:
+             traces = json.load(f)["traceEvents"]
+         simulator_commands: Dict[str, List[str] | List[Tuple[float, float]]] = {}
+         tid_to_name = {}
+         memory = []
+         for trace in traces:
+             if trace["pid"] != pid:
+                 continue
+             if trace["name"] == "process_name":
+                 continue
+
+             if trace["name"] == "thread_name":
+                 tid = trace["tid"]
+                 name = trace["args"]["name"]
+                 tid_to_name[tid] = name
+                 simulator_commands[name] = []
+                 simulator_commands[f"{name} timestamp"] = []
+             elif trace["cat"] == "compute":
+                 tid = trace["tid"]
+                 name = tid_to_name[tid]
+                 simulator_commands[name].append(trace["name"])
+                 cast(
+                     List[Tuple[float, float]],
+                     simulator_commands[f"{name} timestamp"],
+                 ).append((float(trace["ts"]), float(trace["ts"] + trace["dur"])))
+             elif trace["cat"] == "memory":
+                 memory.append(trace["args"]["allocated"])
+
+         return simulator_commands, memory
+
+     @pytest.mark.parametrize("group_workers", [False, True])
+     @with_tempfile()
+     def test_borrow(self, group_workers, trace_path=None):
+         mesh = monarch.Simulator(
+             hosts=1,
+             gpus=2,
+             trace_path=trace_path,
+             group_workers=group_workers,
+             trace_mode=SimulatorTraceMode.EVERYTHING,
+         ).mesh
+         other_stream = Stream("other")
+
+         with mesh.activate(), torch.device("cuda"):
+             ac1 = torch.randn(100, 100)
+             ac2 = torch.mm(ac1, ac1)
+             ac3 = torch.nn.init.uniform_(ac2)
+             borrow_ac3, borrow = other_stream.borrow(ac3, mutable=True)
+             with other_stream.activate():
+                 borrow_ac3.add_(borrow_ac3)
+             borrow.drop()
+         mesh.exit()
+
+         commands, _ = self._get_simulation_result(0, trace_path)
+         assert commands["Controller"] == [
+             "aten.randn",
+             "aten.mm",
+             "aten.uniform_",
+             "DeleteRefs",  # Delete ac2
+             "BorrowCreate",  # borrow()
+             "BorrowFirstUse",  # borrow_ac3 of add_()
+             "aten.add_.Tensor",  # add_()
+             "DeleteRefs",  # delete the result of add_()
+             "BorrowLastUse",  # drop() will call _drop_ref()
+             "DeleteRefs",  # delete borrow_ac3
+             "BorrowDrop",  # drop()
+             "RequestStatus",  # drop()
+             "Exit",  # isn't this obvious :)
+         ]
+
+         _, memory = self._get_simulation_result(1, trace_path)
+         assert memory == [
+             0.04,  # randn()
+             0.08,  # mm()
+         ]
+
+     @pytest.mark.parametrize("group_workers", [False, True])
+     @with_tempfile()
+     def test_to_mesh(self, group_workers, trace_path=None) -> None:
+         mesh = monarch.Simulator(
+             hosts=2, gpus=2, trace_path=trace_path, group_workers=group_workers
+         ).mesh
+         pp_meshes = [mesh(host=0), mesh(host=1)]
+
+         with pp_meshes[0].activate(), torch.device("cuda"):
+             x = cast(Tensor, torch.randn(100, 100))
+             y = x.to_mesh(pp_meshes[0])  # noqa
+             z = x.to_mesh(pp_meshes[1])  # noqa
+         mesh.exit()
+
+         commands, memory = self._get_simulation_result(1, trace_path)
+         assert commands["main"] == [
+             "aten.randn",
+             "SendTensor",
+             "SendTensor",
+         ]
+
+         # Note: the simulator's definition of SendTensor mentions that memory
+         # may not be modelled accurately when the destination and source are
+         # the same. When to_mesh received aliasing fixes, this seemed to throw
+         # off the simulator's memory calculations here. We need to address the
+         # memory copy behavior of to_mesh first, and then align the simulator
+         # with the fix in copy behavior.
+
+         # assert memory == [
+         #     0.04,  # randn()
+         # ]
+         # commands, memory = self._get_simulation_result(3, trace_path)
+         # assert memory == [
+         #     0.04,  # SendTensor
+         # ]
+
+     @pytest.mark.parametrize("group_workers", [False, True])
+     @with_tempfile()
+     def test_reduce_with_stream_trace_only(
+         self, group_workers, trace_path=None
+     ) -> None:
+         mesh = monarch.Simulator(
+             hosts=1,
+             gpus=2,
+             trace_path=trace_path,
+             trace_mode=SimulatorTraceMode.STREAM_ONLY,
+             group_workers=group_workers,
+         ).mesh
+         reducer_stream = Stream("reducer_stream")
+
+         with mesh.activate(), torch.device("cuda"):
+             x = cast(Tensor, torch.randn(100, 100))
+             y = cast(Tensor, torch.randn(100, 100))
+             z = cast(Tensor, torch.randn(100, 100))
+             flatten = torch.cat((x.view((10000,)), y.view((10000,))))
+             flatten_borrow, borrow = reducer_stream.borrow(cast(Tensor, flatten))
+             with reducer_stream.activate():
+                 flatten_borrow.reduce_("gpu", reduction="avg")
+             y = y @ z
+             borrow.drop()
+             x = cast(Tensor, torch.randn(100, 100))
+             new_x, new_y = flatten.split((10000, 10000))
+             del flatten
+             # Need another command to trigger the controller to send the delete
+             # command.
+             no_use_1 = cast(Tensor, torch.randn(100, 100))  # noqa
+             del new_x
+             del new_y
+             no_use_2 = cast(Tensor, torch.randn(100, 100))  # noqa
+
+         mesh.exit()
+
+         commands, memory = self._get_simulation_result(1, trace_path)
+
+         assert memory == [
+             0.04,  # x
+             0.08,  # y
+             0.12,  # z
+             0.20,  # torch.cat
+             0.24,  # mm
+             0.20,  # del the original y
+             0.24,  # new x
+             0.20,  # del the original x
+             0.24,  # no_use_1
+             0.16,  # del new_x, del new_y => flatten removed
+             0.20,  # no_use_2
+         ]
+         self.maxDiff = 10000
+         assert commands["main"] == [
+             "aten.randn",  # x
+             "aten.randn",  # y
+             "aten.randn",  # z
+             "aten.cat",  # cat
+             "aten.mm",  # mm
+             "waiting for reducer_stream",  # drop
+             "aten.randn",  # second x
+             "aten.split_with_sizes",  # split
+             "aten.randn",  # no_use_1
+             "aten.randn",  # no_use_2
+         ]
+
+         assert commands["main timestamp"] == [
+             (0.0, 10.0),
+             (10.0, 20.0),
+             (20.0, 30.0),
+             (30.0, 40.0),
+             (40.0, 50.0),
+             # Reduce is set to 100ms, which is partially overlapped with mm,
+             # which is set to 10ms.
+             (50.0, 140.0),
+             (140.0, 150.0),
+             (150.0, 160.0),
+             (160.0, 170.0),
+             (170.0, 180.0),
+         ]
+
+         assert commands["reducer_stream"] == [
+             "waiting for main",  # borrow first use
+             "reduce_scatter",  # reduce_
+         ]
+         assert commands["reducer_stream timestamp"] == [
+             (0.0, 40.0),
+             (40.0, 140.0),  # reduce_
+         ]
+
+         if not group_workers:
+             assert commands["Device 0"] == []
+         else:
+             assert commands["Device 0 [0-1]"] == []
+         commands, _ = self._get_simulation_result(0, trace_path)
+         assert commands["Controller"] == []
+
+     def test_ndslice_to_worker_group(self) -> None:
+         simulator = Simulator(world_size=1024, group_workers=True)
+
+         # [0, 1024)
+         ranks = [NDSlice(offset=0, sizes=[1024], strides=[1])]
+         groups = list(simulator._ndslice_to_worker_group(ranks))
+         assert len(groups) == 1
+
+         # [0, 512), [512, 1024)
+         ranks = [NDSlice(offset=0, sizes=[512], strides=[1])]
+         groups = list(simulator._ndslice_to_worker_group(ranks))
+         assert len(groups) == 1
+         assert len(simulator._worker_groups) == 2
+         np.testing.assert_array_equal(
+             simulator._worker_groups[0].workers, np.arange(512)
+         )
+         np.testing.assert_array_equal(
+             simulator._worker_groups[1].workers, np.arange(512, 1024)
+         )
+
+         # [0, 512), ([512, 640), [768, 1024)), [640, 768)
+         ranks = [NDSlice(offset=640, sizes=[128], strides=[1])]
+         groups = list(simulator._ndslice_to_worker_group(ranks))
+         assert len(groups) == 1
+         assert len(simulator._worker_groups) == 3
+         np.testing.assert_array_equal(
+             simulator._worker_groups[0].workers, np.arange(512)
+         )
+         np.testing.assert_array_equal(
+             simulator._worker_groups[1].workers, np.arange(640, 768)
+         )
+         np.testing.assert_array_equal(
+             simulator._worker_groups[2].workers,
+             np.concatenate((np.arange(512, 640), np.arange(768, 1024))),
+         )
+
+         # [0, 256), [256, 512), [512, 600), ([600, 640), [768, 1024)), [640, 768)
+         ranks = [NDSlice(offset=256, sizes=[344], strides=[1])]
+         groups = list(simulator._ndslice_to_worker_group(ranks))
+         assert len(groups) == 2
+         assert len(simulator._worker_groups) == 5
+         np.testing.assert_array_equal(
+             simulator._worker_groups[0].workers, np.arange(256, 512)
+         )
+         np.testing.assert_array_equal(
+             simulator._worker_groups[1].workers, np.arange(0, 256)
+         )
+         np.testing.assert_array_equal(
+             simulator._worker_groups[2].workers, np.arange(640, 768)
+         )
+         np.testing.assert_array_equal(
+             simulator._worker_groups[3].workers, np.arange(512, 600)
+         )
+         np.testing.assert_array_equal(
+             simulator._worker_groups[4].workers,
+             np.concatenate((np.arange(600, 640), np.arange(768, 1024))),
+         )
+
+     @with_tempfile(unlink=False)
+     def test_cached_remote_function(self, trace_path=None) -> None:
+         mesh = monarch.Simulator(
+             hosts=1,
+             gpus=2,
+             trace_path=trace_path,
+             trace_mode=SimulatorTraceMode.STREAM_ONLY,
+         ).mesh
+         with mesh.activate():
+             pg = mesh.process_group(("gpu",))
+             myrank = mesh.rank("host") * 8 + mesh.rank("gpu")
+             x = torch.ones((3, 4), device="cuda") * myrank
+             reduce = simple_all_reduce_local(x, group=pg)
+             assert reduce is not None
+             local_reduce = fetch_shard(reduce)
+             _ = local_reduce.result()
+         mesh.exit()
+
+     @with_tempfile(unlink=False)
+     def test_chunk_cat(self, trace_path=None) -> None:
+         mesh = monarch.Simulator(
+             hosts=1,
+             gpus=2,
+             trace_path=trace_path,
+             trace_mode=SimulatorTraceMode.STREAM_ONLY,
+         ).mesh
+
+         with mesh.activate():
+             x = torch.ones((4, 4), device="cuda")
+             y = torch.ones((4, 4), device="cuda")
+             out = torch.zeros((2, 8), device="cuda")
+             input_tensors = [x, y]
+             torch._chunk_cat(
+                 input_tensors,
+                 dim=0,
+                 num_chunks=2,
+                 out=out,
+             )
+             torch._chunk_cat(
+                 input_tensors,
+                 dim=0,
+                 num_chunks=2,
+                 out=out,
+             )
+         mesh.exit()
+
+     @with_tempfile(unlink=False)
+     def test_view(self, trace_path=None) -> None:
+         mesh = monarch.Simulator(
+             hosts=1,
+             gpus=2,
+             trace_path=trace_path,
+             trace_mode=SimulatorTraceMode.STREAM_ONLY,
+         ).mesh
+
+         with mesh.activate():
+             x = torch.ones((4, 4), device="cuda")
+             x = x.flatten()
+         mesh.exit()
+         commands, memory = self._get_simulation_result(1, trace_path)
+         # Only one op should be captured because view is a CPU op.
+         assert commands["main"] == ["aten.ones"]
+
+     @with_tempfile(unlink=False)
+     def test_send_tensor(self, trace_path=None) -> None:
+         mesh = monarch.Simulator(
+             hosts=1,
+             gpus=2,
+             trace_path=trace_path,
+             trace_mode=SimulatorTraceMode.STREAM_ONLY,
+         ).mesh
+
+         with mesh(gpu=0).activate():
+             x = torch.ones((4, 4), device="cuda")
+             _ = x.to_mesh(mesh(gpu=1))
+         mesh.exit()
+         commands, memory = self._get_simulation_result(1, trace_path)
+         assert commands["main"] == ["aten.ones", "SendTensor"]
+         commands, memory = self._get_simulation_result(3, trace_path)
+         assert commands["main"] == ["RecvTensor"]
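For reference, the _get_simulation_result helper in the file above parses a Chrome-trace-style JSON file produced by the simulator. The hand-written payload below is only an illustration of the shape that parser accepts (one thread_name metadata event, one compute slice, and one memory sample for pid 0); the names, values, and file name are made up here and are not simulator output, and the helper additionally rewrites the path via file_path_with_iter before opening it.

import json

# Minimal traceEvents payload the parser above would accept for pid 0.
example_trace = {
    "traceEvents": [
        {"pid": 0, "tid": 1, "name": "thread_name", "args": {"name": "main"}},
        {"pid": 0, "tid": 1, "cat": "compute", "name": "aten.randn", "ts": 0, "dur": 10},
        {"pid": 0, "tid": 1, "cat": "memory", "name": "alloc", "args": {"allocated": 0.04}},
    ]
}

with open("example_trace.json", "w") as f:
    json.dump(example_trace, f)

Parsing a file like this for pid 0 would yield commands {"main": ["aten.randn"], "main timestamp": [(0.0, 10.0)]} and memory [0.04], which is the structure the assertions in the tests above compare against.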
tests/simulator/test_task.py ADDED
@@ -0,0 +1,64 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import unittest
+
+ from monarch.simulator.task import Task, TaskState, WorkerTaskManager
+
+
+ class TestTask(unittest.TestCase):
+     def test_worker_task_manager(self):
+         manager = WorkerTaskManager()
+         kwargs = {
+             "inputs": [2],
+             "outputs": [3],
+             "command_id": 1,
+             "start_time": 9,
+             "runtime": 1,
+             "meta": ["a"],
+         }
+         task = Task(**kwargs)
+         task._state = TaskState.EXECUTED
+
+         manager.add(task)
+         # This task is executed.
+         manager.remove(task)
+
+         task2 = Task(**kwargs)
+         task2.dependencies = [task]
+         manager.add(task2)
+
+         collectives = []
+         collective_task = Task(collectives=collectives, **kwargs)
+         collective_task.dependencies = [task2]
+         manager.add(collective_task)
+         # This is from another worker. Don't add it to the manager.
+         other_worker_task = Task(**kwargs)
+
+         collectives.append(other_worker_task)
+         wait_task = Task(waits=[task], **kwargs)
+         manager.add(wait_task)
+
+         cloned_manager = manager.clone()
+
+         self.assertEqual(len(manager.tasks), 3)
+         self.assertEqual(manager.tasks.keys(), cloned_manager.tasks.keys())
+         cloned_task2 = cloned_manager.tasks[task2.task_id]
+         self.assertNotEqual(task2, cloned_task2)
+         for k in kwargs.keys():
+             self.assertEqual(getattr(cloned_task2, k), getattr(task2, k))
+         self.assertEqual(cloned_task2.dependencies[0].task_id, task.task_id)
+         self.assertNotEqual(cloned_task2.dependencies[0], task)
+         cloned_wait_task = cloned_manager.tasks[wait_task.task_id]
+         self.assertEqual(cloned_wait_task.waits[0].task_id, task.task_id)
+         self.assertNotEqual(cloned_wait_task.waits[0], task)
+
+         self.assertEqual(len(collectives), 3)
+         cloned_collective_task = cloned_manager.tasks[collective_task.task_id]
+         self.assertTrue(collective_task in collectives)
+         self.assertTrue(cloned_collective_task in collectives)
+         self.assertNotEqual(collective_task, cloned_collective_task)
tests/simulator/test_worker.py ADDED
@@ -0,0 +1,102 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import unittest
+ from typing import Tuple
+
+ import torch
+ from monarch.common.fake import fake_call
+
+ from monarch.simulator.profiling import RuntimeEstimator
+ from monarch.simulator.task import Task
+ from monarch.simulator.tensor import FakeTensorTracker
+ from monarch.simulator.worker import Worker
+
+
+ # pyre-ignore-all-errors[6]
+ # pyre-ignore-all-errors[16]
+ def create_test_tasks(fake_tensor_tracker) -> Tuple[Task, ...]:
+     def fake():
+         for i in range(4):
+             tensor = torch.randn(100, 100).cuda()
+             tensor.ref = i
+             tensor._fake = tensor
+             fake_tensor_tracker.add({i: tensor})
+
+         kwargs = {
+             "inputs": [],
+             "outputs": [0],
+             "command_id": 0,
+             "start_time": 1,
+             "runtime": 1,
+             "meta": ["randn"],
+         }
+
+         task0 = Task(**kwargs)
+
+         kwargs["outputs"] = [1]
+         task1 = Task(**kwargs)
+
+         kwargs["inputs"] = [0, 1]
+         kwargs["outputs"] = [2]
+         kwargs["meta"] = ["mm"]
+         task2 = Task(**kwargs)
+
+         kwargs["inputs"] = [2, 1]
+         kwargs["outputs"] = [3]
+         kwargs["meta"] = ["mm"]
+         task3 = Task(**kwargs)
+
+         return task0, task1, task2, task3
+
+     return fake_call(fake)
+
+
+ class TestWorker(unittest.TestCase):
+     def test_stream_clone(self):
+         worker = Worker(FakeTensorTracker(), RuntimeEstimator())
+         worker.create_stream(0, "main", default=True)
+
+         tasks = create_test_tasks(worker.fake_tensor_tracker)
+         for i in range(3):
+             worker.add_task(tasks[i], stream=0, now=i * 10 + 1)
+         # Execute the first and second task
+         for _ in range(2):
+             worker.maybe_set_ready()
+             worker.maybe_execute()
+             worker.maybe_finish()
+
+         main_stream = worker.streams[0]
+         cloned_task_manager = worker.task_manager.clone()
+         cloned_storage_tracker = worker.storage_tracker.clone()
+         cloned_cpu_tensors = worker.cpu_tensors.clone(
+             cloned_task_manager, cloned_storage_tracker
+         )
+         cloned_stream = main_stream.clone(
+             cloned_task_manager, cloned_storage_tracker, cloned_cpu_tensors
+         )
+
+         self.assertEqual(cloned_stream.last_task.task_id, main_stream.last_task.task_id)
+         self.assertEqual(len(cloned_stream.task_queue), len(main_stream.task_queue))
+
+         self.assertEqual(cloned_stream.now, main_stream.now)
+         self.assertEqual(cloned_stream.events, main_stream.events)
+         self.assertNotEqual(id(cloned_stream.events), id(main_stream.events))
+         self.assertEqual(cloned_stream.memory.usage, main_stream.memory.usage)
+         self.assertEqual(cloned_stream.memory.events, main_stream.memory.events)
+         self.assertNotEqual(
+             cloned_stream.memory.storage_tracker, main_stream.memory.storage_tracker
+         )
+         self.assertNotEqual(id(cloned_stream.memory), id(main_stream.memory))
+         self.assertEqual(
+             cloned_stream.tensors.pending_delete_tensors,
+             main_stream.tensors.pending_delete_tensors,
+         )
+         self.assertEqual(
+             cloned_stream.tensors.tensors.keys(), main_stream.tensors.tensors.keys()
+         )
+         self.assertNotEqual(cloned_stream.tensors.tensors, main_stream.tensors.tensors)
tests/sleep_binary.py ADDED
@@ -0,0 +1,35 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ """
+ A simple binary that calls the sleep_indefinitely_for_unit_tests function from the monarch extension.
+ This is used to test the signal handling behavior of signal_safe_block_on.
+ """
+
+ import sys
+
+ from monarch._rust_bindings.monarch_hyperactor.runtime import (  # @manual
+     sleep_indefinitely_for_unit_tests,
+ )
+
+
+ def main() -> None:
+     print("Starting sleep_binary. Process will sleep indefinitely until interrupted.")
+     sys.stdout.flush()  # Ensure the message is printed before we sleep
+
+     try:
+         # This will sleep indefinitely until interrupted by a signal
+         sleep_indefinitely_for_unit_tests()
+     except KeyboardInterrupt:
+         print("Received KeyboardInterrupt, exiting.")
+         sys.exit(0)
+
+
+ if __name__ == "__main__":
+     main()
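The binary above exists so the suite can check that signal_safe_block_on turns an incoming SIGINT into a KeyboardInterrupt and exits cleanly. The packaged harness lives in tests/test_signal_safe_block_on.py, whose contents are not shown in this diff; the sketch below is only a plausible illustration of how such a check could drive the binary, and the helper name, relative path, and timeout are assumptions.

import signal
import subprocess
import sys


def interrupt_sleep_binary(timeout: float = 30.0) -> int:
    # Launch the sleep binary with the current interpreter; the path is
    # assumed to be relative to the repository root.
    proc = subprocess.Popen(
        [sys.executable, "tests/sleep_binary.py"],
        stdout=subprocess.PIPE,
        text=True,
    )
    try:
        # Wait for the startup message before interrupting the process.
        assert proc.stdout is not None
        proc.stdout.readline()
        # SIGINT should surface as KeyboardInterrupt and yield exit code 0.
        proc.send_signal(signal.SIGINT)
        return proc.wait(timeout=timeout)
    finally:
        if proc.poll() is None:
            proc.kill()


if __name__ == "__main__":
    assert interrupt_sleep_binary() == 0
    print("sleep_binary exited cleanly after SIGINT")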