torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
import logging
import os
import re
import socket
import sys
from pathlib import Path
13
+
14
# Logger used by this module (named after the module).
logger = logging.getLogger(__name__)


def _handle_unhandled_exception(*exc_details):
    """Log an uncaught exception at ERROR level.

    Intended as a ``sys.excepthook`` replacement: whatever positional
    arguments the hook receives are forwarded unchanged as ``exc_info``.
    """
    logger.error("Uncaught exception", exc_info=exc_details)


# Python logging level name -> single-letter glog severity used in prefixes.
_glog_level_to_abbr = dict(
    DEBUG="V",  # glog has no DEBUG severity; V stands for VERBOSE
    INFO="I",
    WARNING="W",
    ERROR="E",
    CRITICAL="C",
)
28
+
29
+
30
# Matches a traceback frame header such as
#     '  File "/path/to/mod.py", line 12, in func'
# regardless of leading indentation. The file name is matched up to the
# '", line N' marker, so paths containing commas are kept intact.
_TRACEBACK_FILE_LINE_RE = re.compile(
    r'^\s*File "(?P<file>.*)", line (?P<lineno>\d+)(?:, (?P<rest>.*))?\s*$'
)


def fix_exception_lines(tb_lines):
    """Rewrite traceback lines into a compact ``File <path>:<lineno>`` form.

    Frame header lines (``File "<path>", line <n>[, <rest>]``) become
    ``" File <path>:<n>[, <rest>]"``. Every other line (including lines that
    merely start with "File" but do not follow that layout) is returned with
    surrounding whitespace stripped.

    Args:
        tb_lines: Iterable of traceback text lines, e.g.
            ``formatter.formatException(exc_info).split("\\n")``.

    Returns:
        A new list with exactly one formatted string per input line.
    """
    formatted_lines = []
    for line in tb_lines:
        match = _TRACEBACK_FILE_LINE_RE.match(line)
        if match is None:
            formatted_lines.append(line.strip())
            continue
        new_line = f" File {match['file']}:{match['lineno']}"
        rest = match["rest"]
        if rest:
            # Typically "in <function name>".
            new_line += ", " + rest.strip()
        formatted_lines.append(new_line)
    return formatted_lines
46
+
47
+
48
class _Formatter(logging.Formatter):
    """Format log records in a glog-like layout.

    Every emitted line looks like::

        <L><MMDD HH:MM:SS>.<microseconds> <filename>:<lineno>]<suffix> <text>

    where ``<L>`` is the one-letter severity from ``_glog_level_to_abbr``
    (falling back to the first letter of the level name). Multi-line
    messages, exception tracebacks and stack traces are split so that every
    line carries the same prefix.
    """

    def __init__(self, suffix):
        # Initialise logging.Formatter's own state (format style, datefmt, ...)
        # so inherited helpers that depend on it keep working.
        super().__init__()
        # Text placed right after the "]" of the prefix, e.g. " worker_0:".
        self.suffix = suffix

    def format(self, record):
        message = record.getMessage()
        asctime = self.formatTime(record, "%m%d %H:%M:%S")

        lines = message.strip().split("\n")
        if record.exc_info:
            exc_text = self.formatException(record.exc_info)
            lines.extend(fix_exception_lines(exc_text.split("\n")))
        if record.stack_info:
            stack_text = self.formatStack(record.stack_info)
            lines.extend(stack_text.strip().split("\n"))

        levelname = record.levelname
        # Unknown levels use their first letter; an empty level name yields
        # "?" rather than raising IndexError inside the logging machinery.
        shortlevel = _glog_level_to_abbr.get(levelname, levelname[:1] or "?")

        # Recent CPython versions store record.msecs as whole milliseconds,
        # which would make this microsecond field always end in "000"; derive
        # it from the full-precision creation timestamp instead.
        usecs = int((record.created % 1) * 1_000_000)

        prefix = (
            f"{shortlevel}{asctime}.{usecs:06d} "
            f"{record.filename}:"
            f"{record.lineno}]{self.suffix}"
        )
        return "\n".join(f"{prefix} {line}" for line in lines)
74
+
75
+
76
def initialize_logging(process_name=None):
    """Attach a glog-style handler to the root logger and install an excepthook.

    Reads two environment variables:

    * ``TORCH_MONARCH_LOG_FOLDER``: when set, the folder is created if needed
      and records are written to ``<folder>/<name>.log``. ``<name>`` is
      ``process_name`` with "/" replaced by "_", or ``logfile`` when no
      process name is given. When unset, a ``logging.StreamHandler`` is used.
    * ``TORCH_MONARCH_LOG_LEVEL``: level name (case-insensitive) applied to
      both the handler and the root logger; defaults to ``INFO``.

    Uncaught exceptions are routed through logging via ``sys.excepthook``.

    Args:
        process_name: Optional name shown in every log prefix and used to
            derive the log file name.
    """
    log_folder = os.environ.get("TORCH_MONARCH_LOG_FOLDER")
    # logging's level names are upper-case; also accept e.g. "debug".
    log_level = os.environ.get("TORCH_MONARCH_LOG_LEVEL", "INFO").strip().upper()
    suffix = "" if process_name is None else f" {process_name}:"
    if log_folder is not None:
        log_folder_path = Path(log_folder)
        log_folder_path.mkdir(parents=True, exist_ok=True)
        # ".log" is appended exactly once below; the fallback name used to
        # carry it as well, producing "logfile.log.log".
        safe_process_name = (
            process_name.replace("/", "_") if process_name else "logfile"
        )
        handler = logging.FileHandler(log_folder_path / f"{safe_process_name}.log")
    else:
        handler = logging.StreamHandler()
    handler.setFormatter(_Formatter(suffix))
    handler.setLevel(log_level)
    logging.root.setLevel(log_level)
    logging.root.addHandler(handler)
    sys.excepthook = _handle_unhandled_exception
97
+
98
+
99
def gethostname():
    """Return the local hostname with every ".facebook.com" occurrence removed."""
    return socket.gethostname().replace(".facebook.com", "")
@@ -0,0 +1,42 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import importlib.resources
8
+ import os
9
+ import sys
10
+
11
# Running from a packaged (PAR) binary is detected by the presence of the
# packaging manifest module.
try:
    from __manifest__ import fbmake  # noqa
except ImportError:
    IN_PAR = False
else:
    IN_PAR = True

# Executable used to launch worker processes.
PYTHON_EXECUTABLE: str
if not IN_PAR:
    PYTHON_EXECUTABLE = sys.executable
else:
    # The worker bootstrap binary will import this supervisor lib. When that
    # happens don't try to search for the bootstrap binary again, just use the
    # current executable.
    import __main__ as main_module  # @manual

    if hasattr(main_module, "__MONARCH_TENSOR_WORKER_ENV__"):
        PYTHON_EXECUTABLE = os.environ["FB_XAR_INVOKED_NAME"]
    else:
        try:
            with importlib.resources.path(
                "monarch_tensor_worker_env", "worker_env"
            ) as path:
                # A missing resource is reported the same way as a missing
                # package: through the ImportError handler below.
                if not path.exists():
                    raise ImportError()
                PYTHON_EXECUTABLE = str(path)
        except ImportError:
            raise ImportError(
                "Monarch worker env not found, please define a custom 'monarch_tensor_worker_env' or "
                "add '//monarch/python/monarch_supervisor/worker:default_worker_env' "
                "to your binary dependencies in TARGETS"
            )
tests/__init__.py ADDED
File without changes
@@ -0,0 +1,124 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import sys
10
+
11
+ import torch
12
+ import torch.utils.benchmark as benchmark
13
+
14
+ # this function helps get a local device mesh for testing
15
+ from monarch._testing import mock_mesh
16
+ from monarch.builtins.log import set_logging_level_remote
17
+
18
+ from monarch.common._coalescing import coalescing
19
+ from monarch.common.remote import remote
20
+ from monarch.fetch import fetch_shard
21
+ from monarch.python_local_mesh import python_local_mesh
22
+ from monarch_supervisor.logging import initialize_logging
23
+ from tests.dispatch_bench_helper import run_loop, run_loop_local
24
+
25
# Loop iterations per timed call, and the shape of the tensors being added.
NITER = 10_000
DEFAULT_TENSOR_SIZE = (100, 100)

# Configure process-wide logging at import time.
initialize_logging("dispatch_bench")


# user-defined remote functions
log = remote(
    "monarch.worker._testing_function.log",
    propagate="inspect",
)
33
+
34
+
35
def local_run():
    """Run ``run_loop_local`` once with the benchmark's default settings."""
    run_loop_local(n_iters=NITER, tensor_shape=DEFAULT_TENSOR_SIZE)
37
+
38
+
39
def dispatch_to_worker(device_mesh, n_iter, tensor_size):
    """Run the add loop op-by-op on ``device_mesh`` and fetch shard (host 0, gpu 0)."""
    with device_mesh.activate():
        remote_result = run_loop_local(n_iter, tensor_size)
        # Waiting on .result() (presumably blocking) keeps the remote work
        # inside the timed region, not just its dispatch.
        fetch_shard(remote_result, {"host": 0, "gpu": 0}).result()
44
+
45
+
46
def dispatch_to_worker_remote_function(device_mesh, n_iter, tensor_size):
    """Run the whole loop as one remote call (``run_loop``) and fetch shard (0, 0)."""
    with device_mesh.activate():
        remote_result = run_loop(n_iter, tensor_size)
        # Waiting on .result() (presumably blocking) keeps the remote work
        # inside the timed region, not just its dispatch.
        fetch_shard(remote_result, {"host": 0, "gpu": 0}).result()
51
+
52
+
53
def dispatch_to_worker_coalescing(device_mesh, n_iter, tensor_size):
    """Issue the op-by-op loop inside ``coalescing()``, then fetch shard (0, 0)."""
    with device_mesh.activate():
        with coalescing():
            remote_result = run_loop_local(n_iter, tensor_size)
        fetch_shard(remote_result, {"host": 0, "gpu": 0}).result()
59
+
60
+
61
def main(mocked=False):
    """Time local compute against three ways of running it on a worker mesh.

    Four measurements are printed, each from
    ``torch.utils.benchmark.Timer.blocked_autorange``:
      1. ``run_loop_local`` executed locally (local compute only),
      2. the same loop dispatched op-by-op to the mesh,
      3. the loop as a single remote function call (``run_loop``),
      4. the op-by-op loop issued inside ``coalescing()``.

    Args:
        mocked: build the mesh with ``mock_mesh`` instead of
            ``python_local_mesh``. Defaults to ``False``, the value that used
            to be hard-coded.

    Returns:
        ``0``, used as the process exit status.
    """
    torch.set_default_device("cuda")
    if mocked:
        device_mesh = mock_mesh(hosts=1, gpus=1)
    else:
        device_mesh = python_local_mesh(hosts=1, gpus=1)

    # (statement to time, callable the statement invokes)
    benches = [
        # bench 1: local compute only
        ("run_loop_local(niter, tensor_size)", run_loop_local),
        ("dispatch_to_worker(device_mesh, niter, tensor_size)", dispatch_to_worker),
        (
            "dispatch_to_worker_remote_function(device_mesh, niter, tensor_size)",
            dispatch_to_worker_remote_function,
        ),
        (
            "dispatch_to_worker_coalescing(device_mesh, niter, tensor_size)",
            dispatch_to_worker_coalescing,
        ),
    ]

    try:
        with device_mesh.activate():
            torch.set_default_device("cuda")
            set_logging_level_remote(logging.WARNING)

        for stmt, fn in benches:
            timer = benchmark.Timer(
                stmt=stmt,
                # Hand the callable to the timer directly instead of
                # setup="from __main__ import ...", which only works when this
                # file is executed as a script.
                globals={
                    fn.__name__: fn,
                    "device_mesh": device_mesh,
                    "niter": NITER,
                    "tensor_size": DEFAULT_TENSOR_SIZE,
                },
            )
            print(timer.blocked_autorange(min_run_time=10))
    finally:
        # Tear the mesh down even if a benchmark raises.
        device_mesh.exit()

    return 0


if __name__ == "__main__":
    sys.exit(main())
@@ -0,0 +1,25 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import torch
9
+
10
+ from monarch.common.remote import remote
11
+
12
+
13
def run_loop_local(n_iters, tensor_shape=(2, 2)):
    """Accumulate ``n_iters`` tensors of ones into a zero tensor and return it.

    The loop of separate element-wise adds is deliberate: it is the workload
    whose dispatch cost the benchmark measures.

    Args:
        n_iters: number of additions to perform.
        tensor_shape: shape of the accumulator and of the ones being added.
    """
    acc = torch.zeros(*tensor_shape)
    increment = torch.ones(*tensor_shape)
    step = 0
    while step < n_iters:
        acc = increment + acc
        step += 1
    return acc
19
+
20
+
21
def _run_loop(*args, **kwargs):
    """Propagation stub for ``run_loop``: return ones shaped like its result.

    Mirrors ``run_loop_local(n_iters, tensor_shape=(2, 2))``. The shape may
    be passed positionally, as the ``tensor_shape`` keyword, or omitted (in
    which case the same default is used). Previously only the positional
    form worked; the other two raised IndexError.
    """
    if len(args) > 1:
        tensor_shape = args[1]
    else:
        tensor_shape = kwargs.get("tensor_shape", (2, 2))
    return torch.ones(tensor_shape)
23
+
24
+
25
# Remote-callable form of run_loop_local; _run_loop supplies the propagated
# (shape-only) result.
run_loop = remote(
    "tests.dispatch_bench_helper.run_loop_local",
    propagate=_run_loop,
)
@@ -0,0 +1,180 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import asyncio
8
+ import ctypes
9
+ import sys
10
+
11
+ import click
12
+
13
+ from monarch._rust_bindings.monarch_extension.panic import panicking_function
14
+
15
+ from monarch.actor_mesh import Actor, endpoint, send
16
+ from monarch.proc_mesh import proc_mesh
17
+
18
+
19
class ErrorActor(Actor):
    """Actor whose async endpoints deliberately crash or fail."""

    @endpoint
    async def cause_segfault(self) -> None:
        """Crash the hosting process with a segmentation fault."""
        # ctypes rejects a NULL function pointer, so aim at a non-zero bogus
        # address instead; calling through it faults.
        bogus_fn = ctypes.CFUNCTYPE(None)(0xDEADBEEF)
        bogus_fn()

    @endpoint
    async def cause_panic(self) -> None:
        """Call into a Rust function that panics."""
        panicking_function()

    @endpoint
    async def await_then_error(self) -> None:
        """Sleep 0.1 s twice, then raise ``RuntimeError``."""
        for _ in range(2):
            await asyncio.sleep(0.1)
        raise RuntimeError("oh noez")
44
+
45
+
46
class ErrorActorSync(Actor):
    """Actor whose synchronous endpoints deliberately crash the process."""

    @endpoint  # pyre-ignore
    def cause_segfault(self) -> None:
        """Crash the hosting process with a segmentation fault."""
        # ctypes rejects a NULL function pointer, so aim at a non-zero bogus
        # address instead; calling through it faults.
        fn_type = ctypes.CFUNCTYPE(None)
        bogus_fn = fn_type(0xDEADBEEF)
        bogus_fn()

    @endpoint  # pyre-ignore
    def cause_panic(self) -> None:
        """Call into a Rust function that panics."""
        panicking_function()
65
+
66
+
67
def _run_error_test_sync(num_procs, sync_endpoint, endpoint_name):
    """Blocking driver: spawn an error actor and invoke one failing endpoint.

    Raises:
        ValueError: if ``endpoint_name`` is not a supported endpoint.
    """
    proc = proc_mesh(gpus=num_procs).get()
    actor_class = ErrorActorSync if sync_endpoint else ErrorActor
    error_actor = proc.spawn("error_actor", actor_class).get()

    # This output is checked in the test to make sure that the process actually got here
    print("I actually ran")
    sys.stdout.flush()

    if endpoint_name not in ("cause_segfault", "cause_panic"):
        raise ValueError(f"Unknown endpoint name: {endpoint_name}")
    target = getattr(error_actor, endpoint_name)

    # Exercise both call() and call_one() in our tests, to check that error
    # aggregation behavior is consistent.
    if num_procs == 1:
        target.call_one().get()
    else:
        target.call().get()
92
+
93
+
94
def _run_error_test(num_procs, sync_endpoint, endpoint_name):
    """Asyncio driver: spawn an error actor and await one failing endpoint.

    Raises:
        ValueError: if ``endpoint_name`` is not a supported endpoint.
    """
    actor_class = ErrorActorSync if sync_endpoint else ErrorActor

    async def run_test():
        proc = await proc_mesh(gpus=num_procs)
        error_actor = await proc.spawn("error_actor", actor_class)

        # This output is checked in the test to make sure that the process actually got here
        print("I actually ran")
        sys.stdout.flush()

        if endpoint_name not in ("cause_segfault", "cause_panic"):
            raise ValueError(f"Unknown endpoint name: {endpoint_name}")
        target = getattr(error_actor, endpoint_name)

        # Exercise both call() and call_one() in our tests, to check that error
        # aggregation behavior is consistent.
        if num_procs == 1:
            await target.call_one()
        else:
            await target.call()

    asyncio.run(run_test())
125
+
126
+
127
@click.group()
def main():
    # Root CLI group; the subcommands below attach via @main.command(...).
    # (A comment rather than a docstring, so the --help text is unchanged.)
    ...
130
+
131
+
132
@main.command("error-endpoint")
@click.option("--num-procs", type=int, required=True)
@click.option("--sync-test-impl", type=bool, required=True)
@click.option("--sync-endpoint", type=bool, required=True)
@click.option("--endpoint-name", type=str, required=True)
def error_endpoint(num_procs, sync_test_impl, sync_endpoint, endpoint_name):
    # Pick the blocking or asyncio-based driver, then trigger the endpoint.
    print(
        f"Running segfault test: {num_procs=} {sync_test_impl=} {sync_endpoint=}, {endpoint_name=}"
    )
    runner = _run_error_test_sync if sync_test_impl else _run_error_test
    runner(num_procs, sync_endpoint, endpoint_name)
146
+
147
+
148
@main.command("error-bootstrap")
def error_bootstrap():
    # Marker the test looks for; flushed before the mesh is started.
    print("I actually ran")
    sys.stdout.flush()

    # Env switch (by its name, test-only) asking procs to fail during bootstrap.
    bootstrap_env = {"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING": "1"}
    proc_mesh(gpus=4, env=bootstrap_env).get()
154
+
155
+
156
async def _error_unmonitored():
    """Fire a failing endpoint without awaiting it, then wait for supervision."""
    print("I actually ran")
    sys.stdout.flush()

    mesh = await proc_mesh(gpus=1)
    actor = await mesh.spawn("error_actor", ErrorActor)

    # fire and forget
    send(actor.await_then_error, (), {}, None, "all")

    # Eventually a supervision event propagates and the process exits. If the
    # event is never delivered, the test times out before this sleep ends.
    wait_seconds = 300
    await asyncio.sleep(wait_seconds)
172
+
173
+
174
@main.command("error-unmonitored")
def error_unmonitored():
    # Drive the async scenario to completion on a fresh event loop.
    scenario = _error_unmonitored()
    asyncio.run(scenario)


if __name__ == "__main__":
    main()
File without changes
@@ -0,0 +1,136 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import unittest
9
+
10
+ import pytest
11
+
12
+ import torch
13
+
14
+ from monarch.common import messages
15
+ from monarch.simulator.profiling import RuntimeEstimator, RuntimeProfiler, TimingType
16
+
17
+
18
+ # pyre-ignore-all-errors[6]
19
+ # pyre-ignore-all-errors[16]
20
class TestRuntimeEstimator(unittest.TestCase):
    def test_user_manual_setting(self):
        runtime = RuntimeEstimator()

        input_tensor = torch.rand(10, 10)
        input_tensor.ref = 1
        input_tensor._fake = None
        output_tensor = torch.rand(10, 10)
        output_tensor.ref = 2
        output_tensor._fake = None

        send_tensor = messages.SendTensor(
            result=output_tensor,
            from_ranks=[1],
            to_ranks=[2],
            tensor=input_tensor,
            factory=None,
            from_stream=None,
            to_stream=None,
        )
        reduce = messages.Reduce(
            result=output_tensor,
            local_tensor=input_tensor,
            factory=None,
            source_mesh=None,
            stream=None,
            dims=None,
            reduction=None,
            scatter=False,
            inplace=False,
            out=None,
        )
        call_function = messages.CallFunction(
            ident=1,
            result=None,
            mutates=None,
            function=None,
            args=None,
            kwargs=None,
            stream=None,
            device_mesh=None,
            remote_process_groups=None,
        )

        # (label, argument passed to get_runtime), queried in this order.
        queries = [
            ("send_tensor", send_tensor),
            ("reduce", reduce),
            ("call_function", call_function),
            ("kernel_launch", "kernel_launch"),
            ("wait_event", "wait_event"),
        ]

        def assert_runtimes(*expected):
            # Compare get_runtime() for every query with its expected value;
            # the label shows which query failed.
            self.assertEqual(len(expected), len(queries))
            for (label, query), want in zip(queries, expected):
                self.assertEqual(runtime.get_runtime(query), want, msg=label)

        # Built-in defaults.
        assert_runtimes(100_000, 100_000, 10_000, 500, 500)

        # Constant overrides.
        runtime.set_custom_timing(
            {
                TimingType.SEND_TENSOR: 1_000,
                TimingType.REDUCE: 2_000,
                TimingType.CALL_FUNCTION: 3_000,
                TimingType.KERNEL_LAUNCH: 4_000,
                TimingType.WAIT_EVENT: 5_000,
            }
        )
        assert_runtimes(1_000, 2_000, 3_000, 4_000, 5_000)

        # Callable overrides: message timings take the message, kernel-launch
        # and wait-event timings take no arguments.
        runtime.set_custom_timing(
            {
                TimingType.SEND_TENSOR: lambda msg: 4_000,
                TimingType.REDUCE: lambda msg: 5_000,
                TimingType.CALL_FUNCTION: lambda msg: 6_000,
                TimingType.KERNEL_LAUNCH: lambda: 8_000,
                TimingType.WAIT_EVENT: lambda: 9_000,
            }
        )
        assert_runtimes(4_000, 5_000, 6_000, 8_000, 9_000)

    @pytest.mark.oss_skip
    def test_runtime_profiler(self) -> None:
        m1 = torch.rand(1000, 2000).cuda()
        m2 = torch.rand(2000, 4000).cuda()
        m1.ref = 1
        m2.ref = 2
        msg = messages.CallFunction(
            ident=1,
            result=None,
            mutates=None,
            function=torch.ops.aten.mm.default,
            args=(m1, m2),
            kwargs=None,
            stream=None,
            device_mesh=None,
            remote_process_groups=None,
        )
        profiler = RuntimeProfiler()

        ret = profiler.profile_cmd(msg, ranks=[0])[0]
        self.assertEqual(ret[0].factory.size, (1000, 4000))
        # Should be at least 0.1 ms
        self.assertGreater(ret[1], 100)
        # Should be at most 100 ms
        self.assertLess(ret[1], 100_000)

        # Change the cached profiling result to verify if cached mechanism works
        key = next(iter(profiler.cached.keys()))
        profiler.cached[key][0] = (profiler.cached[key][0][0], 987_654_321)

        ret = profiler.profile_cmd(msg, ranks=[0])[0]
        self.assertEqual(ret[0].factory.size, (1000, 4000))
        self.assertEqual(ret[1], 987_654_321)


if __name__ == "__main__":
    unittest.main()