torchmonarch-nightly 2025.6.27 (cp313-cp313-manylinux2014_x86_64.whl)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/tensorboard.py ADDED
@@ -0,0 +1,84 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import logging
+ from typing import Any
+
+ from monarch.common.device_mesh import DeviceMesh
+ from monarch.remote_class import ControllerRemoteClass, WorkerRemoteClass
+ from torch.utils.tensorboard import SummaryWriter
+
+ logger = logging.getLogger(__name__)
+
+
+ class Tensorboard(ControllerRemoteClass):
+     def __init__(self, coordinator: DeviceMesh, path: str, *args, **kwargs) -> None:
+         from monarch import IN_PAR
+
+         self.path = path
+         self.url: str = ""
+         self.coordinator = coordinator
+         if path.startswith("manifold://"):
+             if not IN_PAR:
+                 raise RuntimeError(
+                     "Cannot save tensorboard to manifold with conda environment. "
+                     "Save to the local filesystem or oilfs instead"
+                 )
+
+             manifold_url = f"https://internalfb.com/intern/tensorboard/?dir={path}"
+             self.url = manifold_url
+         else:
+             self.url = path
+
+         # Only create tensorboard for the coordinator rank.
+         with self.coordinator.activate():
+             super().__init__(
+                 "monarch.tensorboard._WorkerSummaryWriter",
+                 path,
+                 *args,
+                 **kwargs,
+             )
+
+         logger.info("Run `tensorboard --logdir %s` to launch the tensorboard.", path)
+
+     @ControllerRemoteClass.remote_method
+     def _log(self, name: str, data: Any, step: int) -> None:
+         pass
+
+     @ControllerRemoteClass.remote_method
+     def _flush(self) -> None:
+         pass
+
+     @ControllerRemoteClass.remote_method
+     def _close(self) -> None:
+         pass
+
+     def log(self, name: str, data: Any, step: int) -> None:
+         with self.coordinator.activate():
+             self._log(name, data, step)
+
+     def flush(self) -> None:
+         with self.coordinator.activate():
+             self._flush()
+
+     def close(self) -> None:
+         with self.coordinator.activate():
+             self._close()
+
+
+ class _WorkerSummaryWriter(WorkerRemoteClass):
+     def __init__(self, path: str, *args, **kwargs) -> None:
+         self._writer = SummaryWriter(path, *args, **kwargs)
+
+     def _log(self, name: str, data: Any, step: int) -> None:
+         self._writer.add_scalar(name, data, global_step=step, new_style=True)
+
+     def _flush(self) -> None:
+         self._writer.flush()
+
+     def _close(self) -> None:
+         self._writer.close()
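The Tensorboard class above is a thin controller-side wrapper: the @ControllerRemoteClass.remote_method stubs do nothing locally and are dispatched to the _WorkerSummaryWriter created on the coordinator rank. A minimal usage sketch follows; the python_local_mesh helper and the metric values are assumptions for illustration, and only the Tensorboard API itself is taken from the file above.

    # Sketch only. Assumes a single-GPU local mesh; python_local_mesh and the
    # fake "loss" values are illustrative, not part of this diff.
    from monarch.python_local_mesh import python_local_mesh  # assumed entry point
    from monarch.tensorboard import Tensorboard

    mesh = python_local_mesh(gpus=1)
    with mesh.activate():
        board = Tensorboard(mesh, "/tmp/tb_logs")  # coordinator mesh + log path
        for step in range(10):
            board.log("loss", 1.0 / (step + 1), step)  # name, scalar, step
        board.flush()
        board.close()
    mesh.exit()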
monarch/timer/__init__.py ADDED
@@ -0,0 +1,21 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .execution_timer import (
+     execution_timer_start,
+     execution_timer_stop,
+     ExecutionTimer,
+     get_execution_timer_average_ms,
+     get_latest_timer_measurement,
+ )
+
+ __all__ = [
+     "ExecutionTimer",
+     "execution_timer_start",
+     "execution_timer_stop",
+     "get_latest_timer_measurement",
+     "get_execution_timer_average_ms",
+ ]
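These re-exports define the timer package's public surface. A short sketch of both the context-manager and function-call forms (all names are exactly the ones exported above; CPU timing is forced so the snippet runs without a GPU):

    import time

    from monarch.timer import (
        ExecutionTimer,
        execution_timer_start,
        execution_timer_stop,
        get_latest_timer_measurement,
    )

    # Context-manager form.
    with ExecutionTimer.time("block", use_cpu=True):
        time.sleep(0.01)

    # Function form, mirroring the remote wrappers used in example_monarch.py below.
    execution_timer_start("block", use_cpu=True)
    time.sleep(0.01)
    execution_timer_stop("block", use_cpu=True)

    print(get_latest_timer_measurement("block"))  # 0-d float64 tensor, in ms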
monarch/timer/example_monarch.py ADDED
@@ -0,0 +1,78 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """An example that demonstrates how to use ExecutionTimer with a Monarch program.
+
+ Run this with:
+     buck run //monarch/python/monarch/timer:example_monarch
+
+ """
+ # pyre-unsafe
+
+ import logging
+
+ import torch
+
+ from monarch import inspect, remote
+ from monarch.rust_local_mesh import local_mesh
+
+ logger = logging.getLogger(__name__)
+
+
+ execution_timer_start = remote(
+     "monarch.timer.remote_execution_timer.execution_timer_start", propagate="inspect"
+ )
+
+ execution_timer_stop = remote(
+     "monarch.timer.remote_execution_timer.execution_timer_stop", propagate="inspect"
+ )
+
+ get_execution_timer_average_ms = remote(
+     "monarch.timer.remote_execution_timer.get_execution_timer_average_ms",
+     propagate=lambda: torch.tensor(0.0, dtype=torch.float64),
+ )
+
+ get_time_perfcounter = remote(
+     "monarch.timer.remote_execution_timer.get_time_perfcounter",
+     propagate=lambda: torch.tensor(0.0, dtype=torch.float64),
+ )
+
+
+ def main() -> None:
+     with local_mesh(hosts=1, gpus_per_host=1) as mesh:
+         with mesh.activate():
+             num_iterations = 5
+
+             a = torch.randn(1000, 1000, device="cuda")
+             b = torch.randn(1000, 1000, device="cuda")
+             torch.matmul(a, b)
+
+             total_dt = torch.zeros(1, dtype=torch.float64)
+
+             for _ in range(num_iterations):
+                 t0 = get_time_perfcounter()
+                 torch.matmul(a, b)
+                 total_dt += get_time_perfcounter() - t0
+
+             for _ in range(num_iterations):
+                 execution_timer_start()
+                 torch.matmul(a, b)
+                 execution_timer_stop()
+
+             cuda_average_ms = get_execution_timer_average_ms()
+             local_total_dt = inspect(total_dt)
+             local_cuda_avg_ms = inspect(cuda_average_ms)
+
+         local_total_dt = local_total_dt.item()
+         local_cuda_avg_ms = local_cuda_avg_ms.item()
+         mesh.exit()
+     avg_perfcounter_ms = local_total_dt / num_iterations * 1000
+     print(f"average time w/ perfcounter: {avg_perfcounter_ms:.4f} (ms)")
+     print(f"average time w/ ExecutionTimer: {local_cuda_avg_ms:.4f} (ms)")
+
+
+ if __name__ == "__main__":
+     main()
monarch/timer/example_spmd.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """An example that demonstrates how to use ExecutionTimer in an SPMD-style program.
+
+ Run this with:
+     buck run //monarch/python/monarch/timer:example_spmd
+ """
+
+ # pyre-strict
+
+ import time
+
+ import torch
+ from monarch.timer import ExecutionTimer
+
+
+ def main() -> None:
+     # Check if CUDA is available
+     if not torch.cuda.is_available():
+         print("CUDA is not available. Exiting.")
+         return
+
+     device = torch.device("cuda")
+
+     num_iterations = 5
+
+     a = torch.randn(1000, 1000, device=device)
+     b = torch.randn(1000, 1000, device=device)
+
+     # Warmup
+     torch.matmul(a, b)
+     torch.cuda.synchronize()
+
+     cpu_timings = []
+     for _ in range(num_iterations):
+         t0 = time.perf_counter()
+         torch.matmul(a, b)
+         cpu_timings.append(time.perf_counter() - t0)
+
+     for _ in range(num_iterations):
+         with ExecutionTimer.time("matrix_multiply"):
+             torch.matmul(a, b)
+
+     mean_cuda_ms = ExecutionTimer.summary()["matrix_multiply"]["mean_ms"]
+     mean_perfcounter_ms = sum(cpu_timings) / len(cpu_timings) * 1000
+     print("mean perf counter times: ", mean_perfcounter_ms)
+     print("mean cuda times: ", mean_cuda_ms)
+
+
+ if __name__ == "__main__":
+     main()
monarch/timer/execution_timer.py ADDED
@@ -0,0 +1,199 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """A simple timer that utilizes CUDA events to measure time spent in GPU kernels."""
+
+ # pyre-strict
+ import logging
+ import threading
+ import time
+ from contextlib import contextmanager
+ from typing import Any, Dict, Generator, List, Optional, Tuple
+
+ import torch
+
+
+ class ExecutionTimer:
+     """
+     A lightweight timer for measuring CPU or GPU execution time.
+     """
+
+     _enable_cuda: bool = torch.cuda.is_available()
+     _times: Dict[str, List[float]] = {}
+     _lock = threading.Lock()
+     _threads: Dict[str, List[threading.Thread]] = {}
+     _events: Dict[
+         str, List[Tuple[torch.cuda.Event, torch.cuda.Event, torch.cuda.Stream]]
+     ] = {}
+     _cuda_warning_shown: bool = False
+     _cpu_start_times: Dict[str, List[float]] = {}
+
+     @classmethod
+     @contextmanager
+     # pyre-fixme[3]: Return type must be specified as type that does not contain `Any`.
+     def time(
+         cls, name: Optional[str] = None, use_cpu: bool = False
+     ) -> Generator[None, Any, Any]:
+         """
+         Context manager for timing an operation.
+         Args:
+             name (str): Name of the timing section
+             use_cpu (bool): Whether to use CPU time instead of CUDA time. Defaults to False.
+         Example:
+             with ExecutionTimer.time("matrix_multiply"):
+                 result = torch.matmul(a, b)
+
+             with ExecutionTimer.time("sleep", use_cpu=True):
+                 time.sleep(1)
+         """
+         cls.start(name, use_cpu)
+         try:
+             yield
+         finally:
+             cls.stop(name, use_cpu)
+
+     @classmethod
+     def start(cls, name: Optional[str] = None, use_cpu: bool = False) -> None:
+         if not cls._enable_cuda and not cls._cuda_warning_shown:
+             logging.warning("CUDA not available, falling back to CPU timing")
+             cls._cuda_warning_shown = True
+
+         if not name:
+             name = "default"
+         if name not in cls._times:
+             cls._times[name] = []
+
+         if not cls._enable_cuda or use_cpu:
+             if name not in cls._cpu_start_times:
+                 cls._cpu_start_times[name] = []
+             cls._cpu_start_times[name].append(time.perf_counter())
+         else:
+             stream = torch.cuda.current_stream()
+             start_event = torch.cuda.Event(enable_timing=True)
+             end_event = torch.cuda.Event(enable_timing=True)
+             start_event.record(stream)
+             if name not in cls._events:
+                 cls._events[name] = []
+             cls._events[name].append((start_event, end_event, stream))
+
+     @classmethod
+     def stop(cls, name: Optional[str] = None, use_cpu: bool = False) -> None:
+         if not name:
+             name = "default"
+
+         if not cls._enable_cuda or use_cpu:
+             assert (
+                 name in cls._cpu_start_times
+             ), f"No CPU start time found for {name}, did you run start()?"
+             start_time = cls._cpu_start_times[name].pop()
+             elapsed_time_ms = (time.perf_counter() - start_time) * 1000
+             with cls._lock:
+                 cls._times[name].append(elapsed_time_ms)
+
+         if name in cls._events and cls._events[name]:
+             start_event, end_event, stream = cls._events[name].pop()
+             end_event.record(stream)
+
+             # We create a separate thread to poll on the event status
+             # to avoid blocking the main thread.
+             thread = threading.Thread(
+                 target=cls._check_event_completion, args=(name, start_event, end_event)
+             )
+             thread.start()
+             if name not in cls._threads:
+                 cls._threads[name] = []
+             cls._threads[name].append(thread)
+
+     @classmethod
+     def _check_event_completion(
+         cls, name: str, start_event: torch.cuda.Event, end_event: torch.cuda.Event
+     ) -> None:
+         while True:
+             if end_event.query():
+                 with cls._lock:
+                     cuda_time = start_event.elapsed_time(end_event)
+                     cls._times[name].append(cuda_time)
+                 break
+             time.sleep(0.01)
+
+     @classmethod
+     def reset(cls) -> None:
+         """Clear all timing data."""
+         with cls._lock:
+             cls._times = {}
+             cls._threads = {}
+
+     @classmethod
+     def summary(cls) -> Dict[str, Dict[str, float]]:
+         """
+         Get summary of all timing data.
+         Returns:
+             Dict containing timing statistics for each section
+         """
+         # Wait for all in-flight measurements to complete
+         for _, threads in cls._threads.items():
+             for thread in threads:
+                 thread.join()
+         with cls._lock:
+             result = {}
+             for name, times in cls._times.items():
+                 if not times:
+                     continue
+                 result[name] = {
+                     "count": len(times),
+                     "mean_ms": sum(times) / len(times),
+                     "total_ms": sum(times),
+                     "min_ms": min(times),
+                     "max_ms": max(times),
+                 }
+             return result
+
+     @classmethod
+     def get_latest_measurement(cls, name: Optional[str] = None) -> float:
+         """Get the latest measurement (in ms) for a given section."""
+         if not name:
+             name = "default"
+         if name in cls._threads:
+             for thread in cls._threads[name]:
+                 thread.join()
+             cls._threads[name] = []
+         with cls._lock:
+             if name not in cls._times or not cls._times[name]:
+                 logging.warning(f"Section {name} not found in timing data.")
+                 return 0.0
+             return cls._times[name][-1]
+
+
+ def execution_timer_start(name: Optional[str] = None, use_cpu: bool = False) -> None:
+     """Start the ExecutionTimer."""
+     ExecutionTimer.start(name=name, use_cpu=use_cpu)
+
+
+ def execution_timer_stop(name: Optional[str] = None, use_cpu: bool = False) -> None:
+     """Stop the ExecutionTimer."""
+     ExecutionTimer.stop(name=name, use_cpu=use_cpu)
+
+
+ def get_execution_timer_average_ms(name: str = "default") -> torch.Tensor:
+     """Get the ExecutionTimer results."""
+     return torch.tensor(ExecutionTimer.summary()[name]["mean_ms"], dtype=torch.float64)
+
+
+ def get_latest_timer_measurement(name: Optional[str] = None) -> torch.Tensor:
+     """Get the latest ExecutionTimer results."""
+     return torch.tensor(
+         ExecutionTimer.get_latest_measurement(name), dtype=torch.float64
+     )
+
+
+ def get_execution_timer_summary() -> Dict[str, Dict[str, float]]:
+     """Get the ExecutionTimer summary."""
+     return ExecutionTimer.summary()
+
+
+ def get_time_perfcounter() -> torch.Tensor:
+     """Get the time performance counter. Should be used only for debugging."""
+     return torch.tensor(time.perf_counter(), dtype=torch.float64)
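As the comment in stop() notes, each CUDA measurement is resolved by a background thread that polls end_event.query() rather than forcing a torch.cuda.synchronize() on the main thread. A standalone sketch of that pattern (assumes a CUDA device; uses only documented torch.cuda.Event methods, and is not part of the package):

    import threading
    import time

    import torch

    results = []

    def poll(start: torch.cuda.Event, end: torch.cuda.Event) -> None:
        # query() is non-blocking: it returns False until the GPU has
        # executed past the recorded event.
        while not end.query():
            time.sleep(0.01)
        results.append(start.elapsed_time(end))  # elapsed milliseconds

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    a = torch.randn(1000, 1000, device="cuda")
    start.record()
    torch.matmul(a, a)
    end.record()
    threading.Thread(target=poll, args=(start, end)).start()
    # The main thread is free to queue more GPU work while the poller waits.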
monarch/timer/execution_timer_test.py ADDED
@@ -0,0 +1,131 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Test suite for the ExecutionTimer class."""
+
+ # pyre-strict
+
+ import time
+ import unittest
+ from unittest.mock import MagicMock, patch
+
+ from monarch.timer.execution_timer import ExecutionTimer
+
+
+ class TestExecutionTimer(unittest.TestCase):
+     """Test suite for the ExecutionTimer class."""
+
+     def setUp(self) -> None:
+         """Reset the profiler state before each test."""
+         ExecutionTimer.reset()
+
+     def test_basic_timing(self) -> None:
+         """Test basic timing functionality."""
+         with ExecutionTimer.time("test_section"):
+             time.sleep(0.01)  # Sleep for 10ms
+
+         # Get the stats
+         stats = ExecutionTimer.summary()
+
+         # Check that our section exists
+         self.assertIn("test_section", stats)
+
+         # Check timing (should be at least 10ms, allow some overhead)
+         section_stats = stats["test_section"]
+         self.assertEqual(section_stats["count"], 1)
+         self.assertGreaterEqual(section_stats["mean_ms"], 10)  # At least 10ms
+         self.assertLess(section_stats["mean_ms"], 50)  # Reasonable upper bound
+
+     def test_multiple_timing_same_section(self) -> None:
+         """Test timing the same section multiple times."""
+         for _ in range(5):
+             with ExecutionTimer.time("repeated_section"):
+                 time.sleep(0.01)
+
+         stats = ExecutionTimer.summary()
+         self.assertIn("repeated_section", stats)
+
+         section_stats = stats["repeated_section"]
+         self.assertEqual(section_stats["count"], 5)
+         self.assertGreaterEqual(section_stats["mean_ms"], 10)
+         self.assertGreaterEqual(section_stats["total_ms"], 50)
+
+     def test_reset(self) -> None:
+         """Test that reset clears all timing data."""
+         with ExecutionTimer.time("before_reset"):
+             time.sleep(0.01)
+
+         # Verify the section exists
+         stats_before = ExecutionTimer.summary()
+         self.assertIn("before_reset", stats_before)
+
+         # Reset the profiler
+         ExecutionTimer.reset()
+
+         # Timing section should be gone
+         stats_after = ExecutionTimer.summary()
+         self.assertNotIn("before_reset", stats_after)
+
+     @patch("torch.cuda.is_available")
+     @patch("torch.cuda.Event")
+     def test_cuda_mocked(
+         self, mock_event: MagicMock, mock_is_available: MagicMock
+     ) -> None:
+         """Test CUDA timing with mocked CUDA functions."""
+
+         mock_is_available.return_value = True
+
+         mock_event_instance = MagicMock()
+         mock_event_instance.elapsed_time.return_value = 15.0  # 15ms
+         mock_event.return_value = mock_event_instance
+
+         with ExecutionTimer.time("mocked_cuda"):
+             time.sleep(0.01)
+
+         stats = ExecutionTimer.summary()
+
+         # Should have CUDA timings
+         self.assertIn("mocked_cuda", stats)
+         self.assertEqual(stats["mocked_cuda"]["mean_ms"], 15.0)
+
+     def test_get_latest_measurement(self) -> None:
+         """Test get_latest_measurement."""
+         with ExecutionTimer.time("latest_measurement_test"):
+             time.sleep(0.01)  # Sleep for 10ms
+
+         # Get the latest measurement
+         latest_measurement = ExecutionTimer.get_latest_measurement(
+             "latest_measurement_test"
+         )
+
+         self.assertGreaterEqual(latest_measurement, 5)
+         self.assertLess(latest_measurement, 50)  # Reasonable upper bound
+
+         # Test for a non-existent section
+         non_existent_measurement = ExecutionTimer.get_latest_measurement(
+             "non_existent_section"
+         )
+         self.assertEqual(non_existent_measurement, 0.0)
+
+     def test_cpu_timing(self) -> None:
+         """Test CPU timing functionality."""
+         with ExecutionTimer.time("cpu_section", use_cpu=True):
+             time.sleep(0.01)  # Sleep for 10ms
+         # Get the stats
+         stats = ExecutionTimer.summary()
+
+         # Check that our section exists
+         self.assertIn("cpu_section", stats)
+
+         # Check timing (should be at least 10ms, allow some overhead)
+         section_stats = stats["cpu_section"]
+         self.assertEqual(section_stats["count"], 1)
+         self.assertGreaterEqual(section_stats["mean_ms"], 10)  # At least 10ms
+         self.assertLess(section_stats["mean_ms"], 50)  # Reasonable upper bound
+
+
+ if __name__ == "__main__":
+     unittest.main()
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict