torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,103 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import logging
|
9
|
+
import os
|
10
|
+
import socket
|
11
|
+
import sys
|
12
|
+
from pathlib import Path
|
13
|
+
|
14
|
+
logger = logging.getLogger(__name__)
|
15
|
+
|
16
|
+
|
17
|
+
def _handle_unhandled_exception(*args):
|
18
|
+
logger.error("Uncaught exception", exc_info=args)
|
19
|
+
|
20
|
+
|
21
|
+
_glog_level_to_abbr = {
|
22
|
+
"DEBUG": "V", # V is for VERBOSE in glog
|
23
|
+
"INFO": "I",
|
24
|
+
"WARNING": "W",
|
25
|
+
"ERROR": "E",
|
26
|
+
"CRITICAL": "C",
|
27
|
+
}
|
28
|
+
|
29
|
+
|
30
|
+
def fix_exception_lines(tb_lines):
    """Rewrite traceback location lines into a compact ``path:lineno`` form.

    A line such as ``  File "a.py", line 10, in foo`` becomes
    ``  File a.py:10, in foo``. Every other line (source excerpts, the
    exception message, malformed location lines) is returned with surrounding
    whitespace stripped.

    Args:
        tb_lines: Iterable of traceback text lines, e.g. the output of
            ``logging.Formatter.formatException`` split on newlines.

    Returns:
        A new list with one formatted line per input line, in the same order.
    """
    file_prefix = 'File "'
    line_marker = '", line '
    formatted_lines = []
    for line in tb_lines:
        stripped = line.strip()
        if not line.startswith("  File") or not stripped.startswith(file_prefix):
            formatted_lines.append(stripped)
            continue
        # Split on the last '", line ' marker instead of on commas so that file
        # paths containing commas are preserved intact.
        file_info, sep, tail = stripped[len(file_prefix):].rpartition(line_marker)
        if not sep:
            # Not the standard location format: pass it through rather than
            # producing garbage or raising IndexError.
            formatted_lines.append(stripped)
            continue
        # tail looks like "10, in foo" (or just "10" when there is no function).
        line_info, _, func_part = tail.partition(",")
        new_line = f"  File {file_info}:{line_info}"
        func_part = func_part.strip()
        if func_part:
            new_line += ", " + func_part
        formatted_lines.append(new_line)
    return formatted_lines
|
46
|
+
|
47
|
+
|
48
|
+
class _Formatter(logging.Formatter):
    """Format log records as glog-style lines.

    Every emitted line is prefixed with
    ``<L><MMDD HH:MM:SS>.<microseconds> <filename>:<lineno>]<suffix>`` where
    ``<L>`` comes from ``_glog_level_to_abbr`` (falling back to the first
    letter of the level name). Multi-line messages, formatted exception
    tracebacks and stack info are split so that each line gets the prefix.
    """

    def __init__(self, suffix):
        """
        Args:
            suffix: Text appended right after the ``]`` of the prefix
                (e.g. ``" my_proc:"``; empty string for none).
        """
        # Initialise base Formatter state (_style, datefmt, ...) so inherited
        # helpers such as usesTime() work on this instance.
        super().__init__()
        self.suffix = suffix

    def format(self, record):
        message = record.getMessage()
        asctime = self.formatTime(record, "%m%d %H:%M:%S")

        lines = message.strip().split("\n")
        if record.exc_info:
            exc_info = fix_exception_lines(
                self.formatException(record.exc_info).split("\n")
            )
            lines.extend(exc_info)
        if record.stack_info:
            stack_info = self.formatStack(record.stack_info)
            lines.extend(stack_info.strip().split("\n"))

        # [:1] instead of [0] so an empty custom level name cannot IndexError.
        shortlevel = _glog_level_to_abbr.get(record.levelname, record.levelname[:1])

        # Take microseconds from record.created: on recent CPython releases
        # LogRecord.msecs is truncated to whole milliseconds, so msecs * 1000
        # would always end in "000". Clamp guards against float rounding.
        fraction = record.created - int(record.created)
        micros = min(int(fraction * 1_000_000), 999_999)

        prefix = (
            f"{shortlevel}{asctime}.{micros:06d} "
            f"{record.filename}:"
            f"{record.lineno}]{self.suffix}"
        )
        return "\n".join(f"{prefix} {line}" for line in lines)
|
74
|
+
|
75
|
+
|
76
|
+
def initialize_logging(process_name=None):
    """Attach a glog-style handler to the root logger and install an excepthook.

    Configuration is read from the environment:

    * ``TORCH_MONARCH_LOG_FOLDER``: when set, records go to
      ``<folder>/<name>.log`` where ``<name>`` is ``process_name`` with ``/``
      replaced by ``_`` (``logfile`` when no name is given). The folder is
      created if missing. When unset, records go to a ``StreamHandler``.
    * ``TORCH_MONARCH_LOG_LEVEL``: level name applied to both the root logger
      and the handler. It is matched case-insensitively and defaults to
      ``INFO``.

    Args:
        process_name: Optional label. It is appended to every log prefix and
            used to name the log file.
    """
    log_folder = os.environ.get("TORCH_MONARCH_LOG_FOLDER")
    # logging only accepts upper-case level names; normalise so "debug" works.
    log_level = os.environ.get("TORCH_MONARCH_LOG_LEVEL", "INFO").upper()
    suffix = "" if process_name is None else f" {process_name}:"
    if log_folder is not None:
        log_folder_path = Path(log_folder)
        log_folder_path.mkdir(parents=True, exist_ok=True)
        # ".log" is appended below, so the fallback base name must not carry
        # the extension itself (previously this produced "logfile.log.log").
        safe_process_name = (
            process_name.replace("/", "_") if process_name else "logfile"
        )
        log_file_path = log_folder_path / f"{safe_process_name}.log"
        handler = logging.FileHandler(log_file_path)
    else:
        handler = logging.StreamHandler()
    handler.setFormatter(_Formatter(suffix))
    handler.setLevel(log_level)
    logging.root.setLevel(log_level)
    logging.root.addHandler(handler)
    sys.excepthook = _handle_unhandled_exception
|
97
|
+
|
98
|
+
|
99
|
+
def gethostname():
    """Get the hostname of the machine, minus a trailing ``.facebook.com``.

    Only the domain *suffix* is removed; an occurrence elsewhere in the name
    is left untouched.
    """
    domain_suffix = ".facebook.com"
    hostname = socket.gethostname()
    if hostname.endswith(domain_suffix):
        hostname = hostname[: -len(domain_suffix)]
    return hostname
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import importlib.resources
|
8
|
+
import os
|
9
|
+
import sys
|
10
|
+
|
11
|
+
# IN_PAR is True only when ``from __manifest__ import fbmake`` succeeds
# (presumably only inside a packaged PAR/XAR binary; confirm with build docs).
try:
    from __manifest__ import fbmake  # noqa
except ImportError:
    IN_PAR = False
else:
    IN_PAR = True
|
17
|
+
|
18
|
+
# Path of the Python executable used for worker processes.
PYTHON_EXECUTABLE: str
if IN_PAR:
    # The worker bootstrap binary will import this supervisor lib. When that
    # happens don't try to search for the bootstrap binary again, just use the
    # current executable.
    import __main__ as main_module  # @manual

    if hasattr(main_module, "__MONARCH_TENSOR_WORKER_ENV__"):
        PYTHON_EXECUTABLE = os.environ["FB_XAR_INVOKED_NAME"]
    else:
        try:
            # NOTE(review): importlib.resources.path may hand out a temporary
            # extracted copy for resources that are not plain files (e.g. in a
            # zip); that copy is deleted when this `with` exits, which would
            # leave str(path) dangling. Confirm worker_env is always on disk.
            with importlib.resources.path(
                "monarch_tensor_worker_env", "worker_env"
            ) as path:
                if not path.exists():
                    raise ImportError()
                PYTHON_EXECUTABLE = str(path)
        except (ImportError, FileNotFoundError) as err:
            # FileNotFoundError covers a missing resource in a non-filesystem
            # (e.g. zip-backed) package. Chain the cause so it stays visible.
            raise ImportError(
                "Monarch worker env not found, please define a custom 'monarch_tensor_worker_env' or "
                "add '//monarch/python/monarch_supervisor/worker:default_worker_env' "
                "to your binary dependencies in TARGETS"
            ) from err
else:
    PYTHON_EXECUTABLE = sys.executable
|
tests/__init__.py
ADDED
File without changes
|
tests/dispatch_bench.py
ADDED
@@ -0,0 +1,124 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import logging
|
9
|
+
import sys
|
10
|
+
|
11
|
+
import torch
|
12
|
+
import torch.utils.benchmark as benchmark
|
13
|
+
|
14
|
+
# this function helps get a local device mesh for testing
|
15
|
+
from monarch._testing import mock_mesh
|
16
|
+
from monarch.builtins.log import set_logging_level_remote
|
17
|
+
|
18
|
+
from monarch.common._coalescing import coalescing
|
19
|
+
from monarch.common.remote import remote
|
20
|
+
from monarch.fetch import fetch_shard
|
21
|
+
from monarch.python_local_mesh import python_local_mesh
|
22
|
+
from monarch_supervisor.logging import initialize_logging
|
23
|
+
from tests.dispatch_bench_helper import run_loop, run_loop_local
|
24
|
+
|
25
|
+
# Sizes shared by every benchmark variant below.
NITER = 10_000
DEFAULT_TENSOR_SIZE = (100, 100)

# Configure glog-style logging for this benchmark process at import time.
initialize_logging("dispatch_bench")


# user-defined remote functions
log = remote(
    "monarch.worker._testing_function.log",
    propagate="inspect",
)
|
33
|
+
|
34
|
+
|
35
|
+
def local_run():
    """Run the local add loop once with the default benchmark sizes; returns None."""
    run_loop_local(n_iters=NITER, tensor_shape=DEFAULT_TENSOR_SIZE)
|
37
|
+
|
38
|
+
|
39
|
+
def dispatch_to_worker(device_mesh, n_iter, tensor_size):
    """Run the add loop op-by-op with ``device_mesh`` active, then wait on one shard.

    Each torch op inside ``run_loop_local`` is presumably dispatched to the
    mesh because the mesh is active. The (host=0, gpu=0) shard of the result
    is fetched and blocked on so that the timing includes worker completion.
    """
    first_shard = {"host": 0, "gpu": 0}
    with device_mesh.activate():
        result = run_loop_local(n_iter, tensor_size)
        fetch_shard(result, first_shard).result()
|
44
|
+
|
45
|
+
|
46
|
+
def dispatch_to_worker_remote_function(device_mesh, n_iter, tensor_size):
    """Issue the whole add loop as a single remote call, then wait on one shard.

    ``run_loop`` is the ``remote`` wrapper around ``run_loop_local``. The
    (host=0, gpu=0) shard of its result is fetched and blocked on.
    """
    first_shard = {"host": 0, "gpu": 0}
    with device_mesh.activate():
        result = run_loop(n_iter, tensor_size)
        fetch_shard(result, first_shard).result()
|
51
|
+
|
52
|
+
|
53
|
+
def dispatch_to_worker_coalescing(device_mesh, n_iter, tensor_size):
    """Run the add loop inside ``coalescing()`` on ``device_mesh``, then wait on one shard.

    The shard fetch and wait happen after the coalescing block has closed,
    but while the mesh is still active.
    """
    first_shard = {"host": 0, "gpu": 0}
    with device_mesh.activate():
        with coalescing():
            result = run_loop_local(n_iter, tensor_size)
        fetch_shard(result, first_shard).result()
|
59
|
+
|
60
|
+
|
61
|
+
def _time_and_print(stmt, fn_name, env):
    """Time ``stmt`` with ``torch.utils.benchmark`` and print the measurement.

    Args:
        stmt: Statement string to time.
        fn_name: Name imported from ``__main__`` in the timer's setup.
        env: Globals made available to ``stmt``. A copy is passed, so
            timers never share a namespace.

    Returns:
        The measurement returned by ``blocked_autorange``.
    """
    timer = benchmark.Timer(
        stmt=stmt,
        setup=f"from __main__ import {fn_name}",
        globals=dict(env),
    )
    results = timer.blocked_autorange(min_run_time=10)
    print(results)
    return results


def main(mocked=False):
    """Compare local compute against three ways of dispatching it to a mesh.

    Args:
        mocked: Use ``mock_mesh`` instead of a real ``python_local_mesh``.

    Returns:
        Process exit code (0 on success).
    """
    torch.set_default_device("cuda")
    if mocked:
        device_mesh = mock_mesh(hosts=1, gpus=1)
    else:
        device_mesh = python_local_mesh(hosts=1, gpus=1)

    try:
        with device_mesh.activate():
            torch.set_default_device("cuda")
            set_logging_level_remote(logging.WARNING)

        sizes = {"niter": NITER, "tensor_size": DEFAULT_TENSOR_SIZE}
        mesh_env = {"device_mesh": device_mesh, **sizes}

        # bench 1: local compute only
        _time_and_print("run_loop_local(niter, tensor_size)", "run_loop_local", sizes)

        # benches 2-4: op-by-op dispatch, one remote function, coalesced dispatch
        for fn_name in (
            "dispatch_to_worker",
            "dispatch_to_worker_remote_function",
            "dispatch_to_worker_coalescing",
        ):
            _time_and_print(
                f"{fn_name}(device_mesh, niter, tensor_size)", fn_name, mesh_env
            )
    finally:
        # Tear the mesh down even when a benchmark raises.
        device_mesh.exit()

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from monarch.common.remote import remote
|
11
|
+
|
12
|
+
|
13
|
+
def run_loop_local(n_iters, tensor_shape=(2, 2)):
    """Accumulate ``n_iters`` additions of a ones tensor into a zeros tensor.

    Args:
        n_iters: Number of additions (0 or negative performs none).
        tensor_shape: Shape of both the accumulator and the addend.

    Returns:
        The accumulator: a ``tensor_shape`` tensor where every element is
        the repeated-addition sum (``n_iters`` for non-negative counts).
    """
    addend = torch.ones(*tensor_shape)
    acc = torch.zeros(*tensor_shape)
    for _step in range(n_iters):
        acc = addend + acc
    return acc
|
19
|
+
|
20
|
+
|
21
|
+
def _run_loop(*args, **kwargs):
|
22
|
+
return torch.ones(args[1])
|
23
|
+
|
24
|
+
|
25
|
+
# Remote wrapper around run_loop_local. _run_loop is passed as its
# ``propagate`` callback (presumably used client-side to describe the result
# shape; confirm against monarch.common.remote).
run_loop = remote(
    "tests.dispatch_bench_helper.run_loop_local",
    propagate=_run_loop,
)
|
@@ -0,0 +1,180 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import ctypes
|
9
|
+
import sys
|
10
|
+
|
11
|
+
import click
|
12
|
+
|
13
|
+
from monarch._rust_bindings.monarch_extension.panic import panicking_function
|
14
|
+
|
15
|
+
from monarch.actor_mesh import Actor, endpoint, send
|
16
|
+
from monarch.proc_mesh import proc_mesh
|
17
|
+
|
18
|
+
|
19
|
+
class ErrorActor(Actor):
    """Actor whose async endpoints fail on purpose: segfault, Rust panic, or raise."""

    @endpoint
    async def cause_segfault(self) -> None:
        """Endpoint that causes a segmentation fault."""
        # Call a void(void) C function pointer aimed at a non-zero but invalid
        # address. Non-zero avoids ctypes' NULL-pointer check, so the call
        # crashes the process instead of raising.
        bogus_fn = ctypes.CFUNCTYPE(None)(0xDEADBEEF)
        bogus_fn()

    @endpoint
    async def cause_panic(self) -> None:
        """Endpoint that calls a Rust function that panics."""
        panicking_function()

    @endpoint
    async def await_then_error(self) -> None:
        """Yield to the event loop twice (0.1 s each), then raise RuntimeError."""
        for _ in range(2):
            await asyncio.sleep(0.1)
        raise RuntimeError("oh noez")
|
44
|
+
|
45
|
+
|
46
|
+
class ErrorActorSync(Actor):
    """Synchronous-endpoint actor that segfaults or triggers a Rust panic on demand."""

    @endpoint  # pyre-ignore
    def cause_segfault(self) -> None:
        """Endpoint that causes a segmentation fault."""
        # Call a void(void) C function pointer aimed at a non-zero but invalid
        # address. Non-zero avoids ctypes' NULL-pointer check, so the call
        # crashes the process instead of raising.
        bogus_fn = ctypes.CFUNCTYPE(None)(0xDEADBEEF)
        bogus_fn()

    @endpoint  # pyre-ignore
    def cause_panic(self) -> None:
        """Endpoint that calls a Rust function that panics."""
        panicking_function()
|
65
|
+
|
66
|
+
|
67
|
+
def _run_error_test_sync(num_procs, sync_endpoint, endpoint_name):
    """Invoke a failing endpoint using the blocking ``.get()`` API.

    Args:
        num_procs: Number of processes (``gpus=``) in the proc mesh.
        sync_endpoint: Spawn ``ErrorActorSync`` when true, else ``ErrorActor``.
        endpoint_name: ``"cause_segfault"`` or ``"cause_panic"``.

    Raises:
        ValueError: For any other ``endpoint_name``. This is raised only after
            the actor has been spawned and the marker line printed.
    """
    mesh = proc_mesh(gpus=num_procs).get()
    actor_cls = ErrorActorSync if sync_endpoint else ErrorActor
    actor = mesh.spawn("error_actor", actor_cls).get()

    # This output is checked in the test to make sure that the process actually got here
    print("I actually ran")
    sys.stdout.flush()

    if endpoint_name not in ("cause_segfault", "cause_panic"):
        raise ValueError(f"Unknown endpoint name: {endpoint_name}")
    target = getattr(actor, endpoint_name)

    # Exercise both call() and call_one() in our tests, to check that error
    # aggregation behavior is consistent.
    if num_procs == 1:
        target.call_one().get()
    else:
        target.call().get()
|
92
|
+
|
93
|
+
|
94
|
+
def _run_error_test(num_procs, sync_endpoint, endpoint_name):
    """Invoke a failing endpoint using the awaitable API, driven by ``asyncio.run``.

    Args:
        num_procs: Number of processes (``gpus=``) in the proc mesh.
        sync_endpoint: Spawn ``ErrorActorSync`` when true, else ``ErrorActor``.
        endpoint_name: ``"cause_segfault"`` or ``"cause_panic"``.

    Raises:
        ValueError: For any other ``endpoint_name``. This is raised only after
            the actor has been spawned and the marker line printed.
    """
    # asyncio comes from the module-level import; the former function-local
    # re-import was redundant.
    actor_class = ErrorActorSync if sync_endpoint else ErrorActor

    async def run_test():
        proc = await proc_mesh(gpus=num_procs)
        error_actor = await proc.spawn("error_actor", actor_class)

        # This output is checked in the test to make sure that the process actually got here
        print("I actually ran")
        sys.stdout.flush()

        # Named `target` (not `endpoint`) to avoid shadowing the imported
        # `endpoint` decorator.
        if endpoint_name == "cause_segfault":
            target = error_actor.cause_segfault
        elif endpoint_name == "cause_panic":
            target = error_actor.cause_panic
        else:
            raise ValueError(f"Unknown endpoint name: {endpoint_name}")

        # Exercise both call() and call_one() in our tests, to check that error
        # aggregation behavior is consistent.
        if num_procs == 1:
            await target.call_one()
        else:
            await target.call()

    asyncio.run(run_test())
|
125
|
+
|
126
|
+
|
127
|
+
@click.group()
def main():
    # Root click group; subcommands attach through @main.command. No
    # docstring here because click would display it as --help text.
    ...
|
130
|
+
|
131
|
+
|
132
|
+
@main.command("error-endpoint")
@click.option("--num-procs", type=int, required=True)
@click.option("--sync-test-impl", type=bool, required=True)
@click.option("--sync-endpoint", type=bool, required=True)
@click.option("--endpoint-name", type=str, required=True)
def error_endpoint(num_procs, sync_test_impl, sync_endpoint, endpoint_name):
    # Kept docstring-free: click uses a command's docstring as its --help text.
    print(
        f"Running segfault test: {num_procs=} {sync_test_impl=} {sync_endpoint=}, {endpoint_name=}"
    )

    # Choose the blocking (.get()) or asyncio-driven test harness.
    runner = _run_error_test_sync if sync_test_impl else _run_error_test
    runner(num_procs, sync_endpoint, endpoint_name)
|
146
|
+
|
147
|
+
|
148
|
+
@main.command("error-bootstrap")
def error_bootstrap():
    # Marker line, flushed immediately so it is visible even if the process
    # dies afterwards.
    print("I actually ran", flush=True)

    # The env flag presumably makes the spawned processes fail while
    # bootstrapping (per its name; confirm in the bootstrap code).
    bootstrap_env = {"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING": "1"}
    proc_mesh(gpus=4, env=bootstrap_env).get()
|
154
|
+
|
155
|
+
|
156
|
+
async def _error_unmonitored():
    """Fire-and-forget a failing endpoint, then idle for up to 300 s.

    The endpoint's error is only observable through supervision. The test
    expects the process to exit before the sleep ends.
    """
    print("I actually ran", flush=True)

    mesh = await proc_mesh(gpus=1)
    actor = await mesh.spawn("error_actor", ErrorActor)

    # fire and forget
    send(actor.await_then_error, (), {}, None, "all")

    # Wait. Eventually a supervision event will get propagated and the process
    # will exit.
    #
    # If an event is not delivered, the test will time out before this sleep
    # finishes.
    await asyncio.sleep(300)


@main.command("error-unmonitored")
def error_unmonitored():
    # Run the async scenario on a fresh event loop (no docstring: click help).
    coro = _error_unmonitored()
    asyncio.run(coro)


if __name__ == "__main__":
    main()
|
File without changes
|
@@ -0,0 +1,136 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import unittest
|
9
|
+
|
10
|
+
import pytest
|
11
|
+
|
12
|
+
import torch
|
13
|
+
|
14
|
+
from monarch.common import messages
|
15
|
+
from monarch.simulator.profiling import RuntimeEstimator, RuntimeProfiler, TimingType
|
16
|
+
|
17
|
+
|
18
|
+
# pyre-ignore-all-errors[6]
|
19
|
+
# pyre-ignore-all-errors[16]
|
20
|
+
class TestRuntimeEstimator(unittest.TestCase):
    """Tests for the simulator's RuntimeEstimator and RuntimeProfiler."""

    def test_user_manual_setting(self):
        """Built-in per-message runtimes can be overridden by constants or callables."""
        runtime = RuntimeEstimator()

        input_tensor = torch.rand(10, 10)
        input_tensor.ref = 1
        input_tensor._fake = None
        output_tensor = torch.rand(10, 10)
        output_tensor.ref = 2
        output_tensor._fake = None

        send_tensor = messages.SendTensor(
            result=output_tensor,
            from_ranks=[1],
            to_ranks=[2],
            tensor=input_tensor,
            factory=None,
            from_stream=None,
            to_stream=None,
        )
        reduce = messages.Reduce(
            result=output_tensor,
            local_tensor=input_tensor,
            factory=None,
            source_mesh=None,
            stream=None,
            dims=None,
            reduction=None,
            scatter=False,
            inplace=False,
            out=None,
        )
        call_function = messages.CallFunction(
            ident=1,
            result=None,
            mutates=None,
            function=None,
            args=None,
            kwargs=None,
            stream=None,
            device_mesh=None,
            remote_process_groups=None,
        )

        # Default estimates.
        self.assertEqual(runtime.get_runtime(send_tensor), 100_000)
        self.assertEqual(runtime.get_runtime(reduce), 100_000)
        self.assertEqual(runtime.get_runtime(call_function), 10_000)
        self.assertEqual(runtime.get_runtime("kernel_launch"), 500)
        self.assertEqual(runtime.get_runtime("wait_event"), 500)

        # Constant overrides.
        runtime.set_custom_timing(
            {
                TimingType.SEND_TENSOR: 1_000,
                TimingType.REDUCE: 2_000,
                TimingType.CALL_FUNCTION: 3_000,
                TimingType.KERNEL_LAUNCH: 4_000,
                TimingType.WAIT_EVENT: 5_000,
            }
        )
        self.assertEqual(runtime.get_runtime(send_tensor), 1_000)
        self.assertEqual(runtime.get_runtime(reduce), 2_000)
        self.assertEqual(runtime.get_runtime(call_function), 3_000)
        self.assertEqual(runtime.get_runtime("kernel_launch"), 4_000)
        self.assertEqual(runtime.get_runtime("wait_event"), 5_000)

        # Callable overrides: message-based timings take the message, while
        # kernel-launch / wait-event timings take no arguments.
        runtime.set_custom_timing(
            {
                TimingType.SEND_TENSOR: lambda msg: 4_000,
                TimingType.REDUCE: lambda msg: 5_000,
                TimingType.CALL_FUNCTION: lambda msg: 6_000,
                TimingType.KERNEL_LAUNCH: lambda: 8_000,
                TimingType.WAIT_EVENT: lambda: 9_000,
            }
        )
        self.assertEqual(runtime.get_runtime(send_tensor), 4_000)
        self.assertEqual(runtime.get_runtime(reduce), 5_000)
        self.assertEqual(runtime.get_runtime(call_function), 6_000)
        self.assertEqual(runtime.get_runtime("kernel_launch"), 8_000)
        self.assertEqual(runtime.get_runtime("wait_event"), 9_000)

    @pytest.mark.oss_skip
    def test_runtime_profiler(self) -> None:
        """Profiling a CUDA aten.mm yields the output shape, a plausible time, and caching."""
        m1 = torch.rand(1000, 2000).cuda()
        m2 = torch.rand(2000, 4000).cuda()
        m1.ref = 1
        m2.ref = 2
        msg = messages.CallFunction(
            ident=1,
            result=None,
            mutates=None,
            function=torch.ops.aten.mm.default,
            args=(m1, m2),
            kwargs=None,
            stream=None,
            device_mesh=None,
            remote_process_groups=None,
        )
        profiler = RuntimeProfiler()

        ret = profiler.profile_cmd(msg, ranks=[0])[0]
        self.assertEqual(ret[0].factory.size, (1000, 4000))
        # Should be at least 0.1 ms
        self.assertGreater(ret[1], 100)
        # Should be at most 100 ms
        self.assertLess(ret[1], 100_000)

        # Change the cached profiling result to verify if cached mechanism works
        key = next(iter(profiler.cached.keys()))
        profiler.cached[key][0] = (profiler.cached[key][0][0], 987_654_321)

        ret = profiler.profile_cmd(msg, ranks=[0])[0]
        self.assertEqual(ret[0].factory.size, (1000, 4000))
        self.assertEqual(ret[1], 987_654_321)


if __name__ == "__main__":
    unittest.main()
|