torchmonarch-nightly 2025.6.4 (cp310-cp310-manylinux2014_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch_supervisor/logging.py ADDED
@@ -0,0 +1,103 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import logging
+ import os
+ import socket
+ import sys
+ from pathlib import Path
+
+ logger = logging.getLogger(__name__)
+
+
+ def _handle_unhandled_exception(*args):
+     logger.error("Uncaught exception", exc_info=args)
+
+
+ _glog_level_to_abbr = {
+     "DEBUG": "V", # V is for VERBOSE in glog
+     "INFO": "I",
+     "WARNING": "W",
+     "ERROR": "E",
+     "CRITICAL": "C",
+ }
+
+
+ def fix_exception_lines(tb_lines):
+     formatted_lines = []
+     for line in tb_lines:
+         # Replace the standard file and line format with the custom format
+         if line.startswith(" File"):
+             # Extract the filename and line number
+             parts = line.split(",")
+             file_info = parts[0].strip()[6:-1]  # Remove ' File "' and '"'
+             line_info = parts[1].strip()[5:]  # Remove 'line '
+             new_line = f" File {file_info}:{line_info}"
+             if len(parts) > 2:
+                 new_line += ", " + ",".join(parts[2:]).strip()
+             formatted_lines.append(new_line)
+         else:
+             formatted_lines.append(line.strip())
+     return formatted_lines
+
+
+ class _Formatter(logging.Formatter):
+     def __init__(self, suffix):
+         self.suffix = suffix
+
+     def format(self, record):
+         message = record.getMessage()
+         asctime = self.formatTime(record, "%m%d %H:%M:%S")
+
+         lines = message.strip().split("\n")
+         if record.exc_info:
+             exc_info = fix_exception_lines(
+                 self.formatException(record.exc_info).split("\n")
+             )
+             lines.extend(exc_info)
+         if record.stack_info:
+             stack_info = self.formatStack(record.stack_info)
+             lines.extend(stack_info.strip().split("\n"))
+
+         shortlevel = _glog_level_to_abbr.get(record.levelname, record.levelname[0])
+
+         prefix = (
+             f"{shortlevel}{asctime}.{int(record.msecs*1000):06d} "
+             f"{record.filename}:"
+             f"{record.lineno}]{self.suffix}"
+         )
+         return "\n".join(f"{prefix} {line}" for line in lines)
+
+
+ def initialize_logging(process_name=None):
+     log_folder = os.environ.get("TORCH_MONARCH_LOG_FOLDER")
+     log_level = os.environ.get("TORCH_MONARCH_LOG_LEVEL", "INFO")
+     suffix = "" if process_name is None else f" {process_name}:"
+     handler = None
+     if log_folder is not None:
+         log_folder_path = Path(log_folder)
+         log_folder_path.mkdir(parents=True, exist_ok=True)
+         safe_process_name = (
+             process_name.replace("/", "_") if process_name else "logfile.log"
+         )
+         log_file_name = f"{safe_process_name}.log"
+         log_file_path = log_folder_path / log_file_name
+         handler = logging.FileHandler(log_file_path)
+     else:
+         handler = logging.StreamHandler()
+     handler.setFormatter(_Formatter(suffix))
+     handler.setLevel(log_level)
+     logging.root.setLevel(log_level)
+     logging.root.addHandler(handler)
+     sys.excepthook = _handle_unhandled_exception
+
+
+ def gethostname():
+     """Get the hostname of the machine."""
+     hostname = socket.gethostname()
+     hostname = hostname.replace(".facebook.com", "")
+     return hostname
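
Editor's note (not part of the diffed package): monarch_supervisor/logging.py above installs a glog-style formatter on the root logger. A minimal usage sketch, assuming the wheel is installed and TORCH_MONARCH_LOG_FOLDER is unset so output goes to stderr:

# Hypothetical example, not contained in the wheel: exercise the formatter above.
import logging

from monarch_supervisor.logging import initialize_logging

initialize_logging("demo")
logging.getLogger(__name__).warning("worker started")
# With the _Formatter defined above, each line should come out roughly as:
#   W0604 12:00:00.123456 example.py:7] demo: worker started
# (level abbreviation, MMDD HH:MM:SS.microseconds, file:line] suffix, message)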
monarch_supervisor/python_executable.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import importlib.resources
+ import os
+ import sys
+
+ try:
+     from __manifest__ import fbmake  # noqa
+
+     IN_PAR = True
+ except ImportError:
+     IN_PAR = False
+
+ PYTHON_EXECUTABLE: str
+ if IN_PAR:
+     # The worker bootstrap binary will import this supervisor lib. When that
+     # happens don't try to search for the bootstrap binary again, just use the
+     # current executable.
+     import __main__ as main_module  # @manual
+
+     if hasattr(main_module, "__MONARCH_TENSOR_WORKER_ENV__"):
+         PYTHON_EXECUTABLE = os.environ["FB_XAR_INVOKED_NAME"]
+     else:
+         try:
+             with importlib.resources.path(
+                 "monarch_tensor_worker_env", "worker_env"
+             ) as path:
+                 if not path.exists():
+                     raise ImportError()
+                 PYTHON_EXECUTABLE = str(path)
+         except ImportError:
+             raise ImportError(
+                 "Monarch worker env not found, please define a custom 'monarch_tensor_worker_env' or "
+                 "add '//monarch/python/monarch_supervisor/worker:default_worker_env' "
+                 "to your binary dependencies in TARGETS"
+             )
+ else:
+     PYTHON_EXECUTABLE = sys.executable
tests/__init__.py ADDED
File without changes
tests/dispatch_bench.py ADDED
@@ -0,0 +1,124 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import logging
+ import sys
+
+ import torch
+ import torch.utils.benchmark as benchmark
+
+ # this function helps get a local device mesh for testing
+ from monarch._testing import mock_mesh
+ from monarch.builtins.log import set_logging_level_remote
+
+ from monarch.common._coalescing import coalescing
+ from monarch.common.remote import remote
+ from monarch.fetch import fetch_shard
+ from monarch.python_local_mesh import python_local_mesh
+ from monarch_supervisor.logging import initialize_logging
+ from tests.dispatch_bench_helper import run_loop, run_loop_local
+
+ NITER = 10000
+ DEFAULT_TENSOR_SIZE = (100, 100)
+
+ initialize_logging("dispatch_bench")
+
+
+ # user-defined remote functions
+ log = remote("monarch.worker._testing_function.log", propagate="inspect")
+
+
+ def local_run():
+     run_loop_local(NITER, DEFAULT_TENSOR_SIZE)
+
+
+ def dispatch_to_worker(device_mesh, n_iter, tensor_size):
+     with device_mesh.activate():
+         result = run_loop_local(n_iter, tensor_size)
+         local_result = fetch_shard(result, {"host": 0, "gpu": 0})
+         local_result = local_result.result()
+
+
+ def dispatch_to_worker_remote_function(device_mesh, n_iter, tensor_size):
+     with device_mesh.activate():
+         result = run_loop(n_iter, tensor_size)
+         local_result = fetch_shard(result, {"host": 0, "gpu": 0})
+         local_result = local_result.result()
+
+
+ def dispatch_to_worker_coalescing(device_mesh, n_iter, tensor_size):
+     with device_mesh.activate():
+         with coalescing():
+             result = run_loop_local(n_iter, tensor_size)
+         local_result = fetch_shard(result, {"host": 0, "gpu": 0})
+         local_result = local_result.result()
+
+
+ def main():
+     mocked = False
+     torch.set_default_device("cuda")
+     if mocked:
+         device_mesh = mock_mesh(hosts=1, gpus=1)
+     else:
+         device_mesh = python_local_mesh(hosts=1, gpus=1)
+
+     with device_mesh.activate():
+         torch.set_default_device("cuda")
+         set_logging_level_remote(logging.WARNING)
+
+     # bench 1: local compute only
+     t0 = benchmark.Timer(
+         stmt="run_loop_local(niter, tensor_size)",
+         setup="from __main__ import run_loop_local",
+         globals={"niter": NITER, "tensor_size": DEFAULT_TENSOR_SIZE},
+     )
+     local_only_results = t0.blocked_autorange(min_run_time=10)
+     print(local_only_results)
+
+     t1 = benchmark.Timer(
+         stmt="dispatch_to_worker(device_mesh, niter, tensor_size)",
+         setup="from __main__ import dispatch_to_worker",
+         globals={
+             "device_mesh": device_mesh,
+             "niter": NITER,
+             "tensor_size": DEFAULT_TENSOR_SIZE,
+         },
+     )
+     dispatch_to_worker_results = t1.blocked_autorange(min_run_time=10)
+     print(dispatch_to_worker_results)
+
+     t2 = benchmark.Timer(
+         stmt="dispatch_to_worker_remote_function(device_mesh, niter, tensor_size)",
+         setup="from __main__ import dispatch_to_worker_remote_function",
+         globals={
+             "device_mesh": device_mesh,
+             "niter": NITER,
+             "tensor_size": DEFAULT_TENSOR_SIZE,
+         },
+     )
+     dispatch_to_worker_remote_function_results = t2.blocked_autorange(min_run_time=10)
+     print(dispatch_to_worker_remote_function_results)
+
+     t3 = benchmark.Timer(
+         stmt="dispatch_to_worker_coalescing(device_mesh, niter, tensor_size)",
+         setup="from __main__ import dispatch_to_worker_coalescing",
+         globals={
+             "device_mesh": device_mesh,
+             "niter": NITER,
+             "tensor_size": DEFAULT_TENSOR_SIZE,
+         },
+     )
+     dispatch_to_worker_coalescing_results = t3.blocked_autorange(min_run_time=10)
+     print(dispatch_to_worker_coalescing_results)
+
+     device_mesh.exit()
+
+     return 0
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
tests/dispatch_bench_helper.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import torch
+
+ from monarch.common.remote import remote
+
+
+ def run_loop_local(n_iters, tensor_shape=(2, 2)):
+     local = torch.zeros(*tensor_shape)
+     ones = torch.ones(*tensor_shape)
+     for _ in range(n_iters):
+         local = ones + local
+     return local
+
+
+ def _run_loop(*args, **kwargs):
+     return torch.ones(args[1])
+
+
+ run_loop = remote("tests.dispatch_bench_helper.run_loop_local", propagate=_run_loop)
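
Editor's note (not part of the diffed package): run_loop above pairs the worker-side run_loop_local (referenced by its dotted path) with _run_loop, a propagation function that appears to only fabricate a correctly shaped result locally so the controller can keep dispatching without waiting on the workers. tests/dispatch_bench.py drives it under an activated mesh, roughly as in this hedged sketch (mesh setup assumed to work in your environment):

# Hypothetical driver mirroring the call pattern in tests/dispatch_bench.py.
from monarch.fetch import fetch_shard
from monarch.python_local_mesh import python_local_mesh
from tests.dispatch_bench_helper import run_loop

device_mesh = python_local_mesh(hosts=1, gpus=1)
with device_mesh.activate():
    result = run_loop(10, (100, 100))  # executed on the workers
# Materialize one shard of the distributed result on the controller.
local = fetch_shard(result, {"host": 0, "gpu": 0}).result()
device_mesh.exit()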
tests/error_test_binary.py ADDED
@@ -0,0 +1,139 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import ctypes
+ import sys
+
+ from monarch._rust_bindings.monarch_extension.panic import panicking_function
+
+ from monarch.actor_mesh import Actor, endpoint
+ from monarch.proc_mesh import proc_mesh
+
+
+ class ErrorActor(Actor):
+     """An actor that has endpoints cause segfaults."""
+
+     @endpoint
+     async def cause_segfault(self) -> None:
+         """Endpoint that causes a segmentation fault."""
+         # Create a C function pointer to an invalid memory address
+         # This will reliably cause a segmentation fault when called
+         function_type = ctypes.CFUNCTYPE(None)
+         # Use a non-zero but invalid address to avoid ctypes null pointer checks
+         invalid_address = 0xDEADBEEF
+         invalid_function = function_type(invalid_address)
+         # Calling this function will cause a segfault
+         invalid_function()
+
+     @endpoint
+     async def cause_panic(self) -> None:
+         """Endpoint that calls a Rust function that panics."""
+         panicking_function()
+
+
+ class ErrorActorSync(Actor):
+     """An actor that has endpoints cause segfaults."""
+
+     @endpoint  # pyre-ignore
+     def cause_segfault(self) -> None:
+         """Endpoint that causes a segmentation fault."""
+         # Create a C function pointer to an invalid memory address
+         # This will reliably cause a segmentation fault when called
+         function_type = ctypes.CFUNCTYPE(None)
+         # Use a non-zero but invalid address to avoid ctypes null pointer checks
+         invalid_address = 0xDEADBEEF
+         invalid_function = function_type(invalid_address)
+         # Calling this function will cause a segfault
+         invalid_function()
+
+     @endpoint  # pyre-ignore
+     def cause_panic(self) -> None:
+         """Endpoint that calls a Rust function that panics."""
+         panicking_function()
+
+
+ def _run_error_test_sync(num_procs, sync_endpoint, endpoint_name):
+     proc = proc_mesh(gpus=num_procs).get()
+     if sync_endpoint:
+         actor_class = ErrorActorSync
+     else:
+         actor_class = ErrorActor
+     error_actor = proc.spawn("error_actor", actor_class).get()
+
+     # This output is checked in the test to make sure that the process actually got here
+     print("I actually ran")
+     sys.stdout.flush()
+
+     if endpoint_name == "cause_segfault":
+         endpoint = error_actor.cause_segfault
+     elif endpoint_name == "cause_panic":
+         endpoint = error_actor.cause_panic
+     else:
+         raise ValueError(f"Unknown endpoint name: {endpoint_name}")
+
+     # Exercise both call() and call_one() in our tests, to check that error
+     # aggregation behavior is consistent.
+     if num_procs == 1:
+         endpoint.call_one().get()
+     else:
+         endpoint.call().get()
+
+
+ def _run_error_test(num_procs, sync_endpoint, endpoint_name):
+     import asyncio
+
+     if sync_endpoint:
+         actor_class = ErrorActorSync
+     else:
+         actor_class = ErrorActor
+
+     async def run_test():
+         proc = await proc_mesh(gpus=num_procs)
+         error_actor = await proc.spawn("error_actor", actor_class)
+
+         # This output is checked in the test to make sure that the process actually got here
+         print("I actually ran")
+         sys.stdout.flush()
+
+         if endpoint_name == "cause_segfault":
+             endpoint = error_actor.cause_segfault
+         elif endpoint_name == "cause_panic":
+             endpoint = error_actor.cause_panic
+         else:
+             raise ValueError(f"Unknown endpoint name: {endpoint_name}")
+
+         # Exercise both call() and call_one() in our tests, to check that error
+         # aggregation behavior is consistent.
+         if num_procs == 1:
+             await endpoint.call_one()
+         else:
+             await endpoint.call()
+
+     asyncio.run(run_test())
+
+
+ def main():
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--num-procs", type=int)
+     parser.add_argument("--sync-test-impl", type=bool)
+     parser.add_argument("--sync-endpoint", type=bool)
+     parser.add_argument("--endpoint-name", type=str)
+     args = parser.parse_args()
+
+     print(
+         f"Running segfault test: {args.num_procs=} {args.sync_test_impl=} {args.sync_endpoint=}, {args.endpoint_name=}"
+     )
+
+     if args.sync_test_impl:
+         _run_error_test_sync(args.num_procs, args.sync_endpoint, args.endpoint_name)
+     else:
+         _run_error_test(args.num_procs, args.sync_endpoint, args.endpoint_name)
+
+
+ if __name__ == "__main__":
+     main()
tests/simulator/__init__.py ADDED
File without changes
tests/simulator/test_profiling.py ADDED
@@ -0,0 +1,136 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import unittest
+
+ import pytest
+
+ import torch
+
+ from monarch.common import messages
+ from monarch.simulator.profiling import RuntimeEstimator, RuntimeProfiler, TimingType
+
+
+ # pyre-ignore-all-errors[6]
+ # pyre-ignore-all-errors[16]
+ class TestRuntimeEstimator(unittest.TestCase):
+     def test_user_manual_setting(self):
+         runtime = RuntimeEstimator()
+
+         input_tensor = torch.rand(10, 10)
+         input_tensor.ref = 1
+         input_tensor._fake = None
+         output_tensor = torch.rand(10, 10)
+         output_tensor.ref = 2
+         output_tensor._fake = None
+
+         send_tensor = messages.SendTensor(
+             result=output_tensor,
+             from_ranks=[1],
+             to_ranks=[2],
+             tensor=input_tensor,
+             factory=None,
+             from_stream=None,
+             to_stream=None,
+         )
+         reduce = messages.Reduce(
+             result=output_tensor,
+             local_tensor=input_tensor,
+             factory=None,
+             source_mesh=None,
+             stream=None,
+             dims=None,
+             reduction=None,
+             scatter=False,
+             inplace=False,
+             out=None,
+         )
+         call_function = messages.CallFunction(
+             ident=1,
+             result=None,
+             mutates=None,
+             function=None,
+             args=None,
+             kwargs=None,
+             stream=None,
+             device_mesh=None,
+             remote_process_groups=None,
+         )
+
+         self.assertEqual(runtime.get_runtime(send_tensor), 100_000)
+         self.assertEqual(runtime.get_runtime(reduce), 100_000)
+         self.assertEqual(runtime.get_runtime(call_function), 10_000)
+         self.assertEqual(runtime.get_runtime("kernel_launch"), 500)
+         self.assertEqual(runtime.get_runtime("wait_event"), 500)
+
+         runtime.set_custom_timing(
+             {
+                 TimingType.SEND_TENSOR: 1_000,
+                 TimingType.REDUCE: 2_000,
+                 TimingType.CALL_FUNCTION: 3_000,
+                 TimingType.KERNEL_LAUNCH: 4_000,
+                 TimingType.WAIT_EVENT: 5_000,
+             }
+         )
+         self.assertEqual(runtime.get_runtime(send_tensor), 1_000)
+         self.assertEqual(runtime.get_runtime(reduce), 2_000)
+         self.assertEqual(runtime.get_runtime(call_function), 3_000)
+         self.assertEqual(runtime.get_runtime("kernel_launch"), 4_000)
+         self.assertEqual(runtime.get_runtime("wait_event"), 5_000)
+
+         runtime.set_custom_timing(
+             {
+                 TimingType.SEND_TENSOR: lambda msg: 4_000,
+                 TimingType.REDUCE: lambda msg: 5_000,
+                 TimingType.CALL_FUNCTION: lambda msg: 6_000,
+                 TimingType.KERNEL_LAUNCH: lambda: 8_000,
+                 TimingType.WAIT_EVENT: lambda: 9_000,
+             }
+         )
+         self.assertEqual(runtime.get_runtime(send_tensor), 4_000)
+         self.assertEqual(runtime.get_runtime(reduce), 5_000)
+         self.assertEqual(runtime.get_runtime(call_function), 6_000)
+         self.assertEqual(runtime.get_runtime("kernel_launch"), 8_000)
+         self.assertEqual(runtime.get_runtime("wait_event"), 9_000)
+
+     @pytest.mark.oss_skip
+     def test_runtime_profiler(self) -> None:
+         m1 = torch.rand(1000, 2000).cuda()
+         m2 = torch.rand(2000, 4000).cuda()
+         m1.ref = 1
+         m2.ref = 2
+         msg = messages.CallFunction(
+             ident=1,
+             result=None,
+             mutates=None,
+             function=torch.ops.aten.mm.default,
+             args=(m1, m2),
+             kwargs=None,
+             stream=None,
+             device_mesh=None,
+             remote_process_groups=None,
+         )
+         profiler = RuntimeProfiler()
+
+         ret = profiler.profile_cmd(msg, ranks=[0])[0]
+         self.assertEqual(ret[0].factory.size, (1000, 4000))
+         # Should be at least 0.1 ms
+         self.assertTrue(ret[1] > 100)
+         # Should be at most 100 ms
+         self.assertTrue(ret[1] < 100_000)
+
+         # Change the cached profiling result to verify if cached mechanism works
+         key = next(iter(profiler.cached.keys()))
+         profiler.cached[key][0] = (profiler.cached[key][0][0], 987_654_321)
+
+         ret = profiler.profile_cmd(msg, ranks=[0])[0]
+         self.assertEqual(ret[0].factory.size, (1000, 4000))
+         self.assertEqual(ret[1], 987_654_321)
+
+
+ if __name__ == "__main__":
+     unittest.main()