torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/profiler.py ADDED
@@ -0,0 +1,160 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import itertools
+ import os
+ from dataclasses import dataclass
+ from functools import partial
+ from pathlib import Path
+ from typing import Any, Dict, NamedTuple, Optional, Tuple
+
+ import torch
+ from monarch.common.remote import remote
+ from monarch.remote_class import ControllerRemoteClass, WorkerRemoteClass
+
+
+ class Schedule(NamedTuple):
+     wait: int
+     warmup: int
+     active: int
+     repeat: int = 0
+     skip_first: int = 0
+
+
+ class profile:
+     """
+     The class wraps `torch.profiler.profile()` to allow invoking the profiler remotely.
+     There are two main differences:
+     1) `on_trace_ready` can only be a string, indicating the folder where the traces
+        will be saved.
+     2) `schedule` must be of type `monarch.profiler.Schedule`.
+     """
+
+     PATH_KEY = "on_trace_ready"
+     _counter = itertools.count()
+
+     def __init__(self, *args, **kwargs) -> None:
+         assert isinstance(kwargs.get(self.PATH_KEY, None), str), (
+             f"{self.PATH_KEY} must be passed and must be a string to represent the "
+             "path to save the profiler."
+         )
+         schedule = kwargs.get("schedule", None)
+         assert (
+             isinstance(schedule, Schedule) or schedule is None
+         ), "schedule can only be monarch.profiler.Schedule or None."
+         self.id = next(self._counter)
+         _profiler_controller_init(self.id, *args, **kwargs)
+
+     def __enter__(self) -> "profile":
+         _profiler_controller_enter(self.id)
+         return self
+
+     def __exit__(self, *args, **kwargs) -> None:
+         _profiler_controller_exit(self.id)
+
+     def step(self) -> None:
+         _profiler_controller_step(self.id)
+
+
+ @dataclass
+ class _Profiler:
+     args: Tuple[Any, ...]
+     kwargs: Dict[str, Any]
+     profiler: Optional[torch.profiler.profile] = None
+
+
+ _profilers: Dict[int, _Profiler] = {}
+
+
+ def _profiler_init(ident, *args, **kwargs) -> None:
+     global _profilers
+     assert (
+         ident not in _profilers
+     ), f"Initializing an already existing profiler, {ident=}"
+     _profilers[ident] = _Profiler(args, kwargs)
+     # It's unclear why we cannot create the profiler here. Even though
+     # the thread is the same, profiler complains thread id mismatch.
+
+
+ def _profiler_enter(ident, *args, **kwargs) -> None:
+     def on_trace_ready(prof, dir_path):
+         dir_path = Path(dir_path).absolute()
+         os.makedirs(dir_path, exist_ok=True)
+         # This is not a synchronized call, so it is okay to call without
+         # device mesh.
+         rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+         prof.export_chrome_trace(f"{dir_path}/trace_{rank}.json")
+
+     profiler = _profilers[ident]
+     profiler.kwargs[profile.PATH_KEY] = partial(
+         on_trace_ready, dir_path=profiler.kwargs[profile.PATH_KEY]
+     )
+     schedule = profiler.kwargs.get("schedule", None)
+     if schedule is not None:
+         profiler.kwargs["schedule"] = torch.profiler.schedule(**schedule._asdict())
+     profiler.profiler = torch.profiler.profile(*profiler.args, **profiler.kwargs)
+
+     profiler.profiler.__enter__()
+
+
+ def _profiler_exit(ident, *args, **kwargs) -> None:
+     profiler = _profilers[ident].profiler
+     assert profiler is not None
+     profiler.__exit__(None, None, None)
+     _profilers.pop(ident)
+
+
+ def _profiler_step(ident, *args, **kwargs) -> None:
+     profiler = _profilers[ident].profiler
+     assert profiler is not None
+     profiler.step()
+
+
+ _profiler_controller_init = remote(
+     "monarch.profiler._profiler_init", propagate="inspect"
+ )
+
+ _profiler_controller_enter = remote(
+     "monarch.profiler._profiler_enter", propagate="inspect"
+ )
+
+ _profiler_controller_exit = remote(
+     "monarch.profiler._profiler_exit", propagate="inspect"
+ )
+
+ _profiler_controller_step = remote(
+     "monarch.profiler._profiler_step", propagate="inspect"
+ )
+
+
+ class record_function(ControllerRemoteClass):
+     """
+     The class wraps `torch.profiler.record_function()` to allow invoking the
+     record_function remotely.
+     """
+
+     def __init__(self, name: str, args: Optional[str] = None) -> None:
+         super().__init__("monarch.profiler.WorkerRecordFunction", name, args)
+
+     @ControllerRemoteClass.remote_method
+     def __enter__(self) -> "record_function":
+         return self
+
+     @ControllerRemoteClass.remote_method
+     def __exit__(self, *args, **kwargs) -> None:
+         return
+
+
+ class WorkerRecordFunction(WorkerRemoteClass):
+     def __init__(self, *args, **kwargs) -> None:
+         self._record_function = torch.profiler.record_function(*args, **kwargs)
+
+     def __enter__(self) -> None:
+         self._record_function.__enter__()
+
+     def __exit__(self, *args, **kwargs) -> None:
+         self._record_function.__exit__(*args, **kwargs)
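Usage sketch for monarch/profiler.py (not shipped in the wheel): driving the remote profiler and `record_function` from the controller. It assumes a device mesh is already active; `train_step` and the trace directory are illustrative names.

    from monarch.profiler import profile, record_function, Schedule

    with profile(
        on_trace_ready="/tmp/monarch_traces",           # folder where each worker saves its trace
        schedule=Schedule(wait=1, warmup=1, active=2),  # converted to torch.profiler.schedule on the worker
    ) as prof:
        for _ in range(8):
            with record_function("train_step"):  # annotates the region in the worker-side trace
                train_step()                     # hypothetical per-step work running on the mesh
            prof.step()                          # advances the remote torch profiler schedule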
monarch/python_local_mesh.py ADDED
@@ -0,0 +1,107 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import os
+ import subprocess
+ from time import sleep
+ from typing import Optional, TYPE_CHECKING
+
+ import monarch_supervisor
+ from monarch.common._device_utils import _local_device_count
+ from monarch.common.fake import fake_call
+ from monarch.common.invocation import DeviceException, RemoteException
+ from monarch.world_mesh import world_mesh
+ from monarch_supervisor import Context, HostConnected
+ from monarch_supervisor.python_executable import PYTHON_EXECUTABLE
+
+ if TYPE_CHECKING:
+     from monarch.common.device_mesh import DeviceMesh
+
+
+ class PythonLocalContext:
+     def __init__(self, N: int):
+         # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later
+         fake_call(lambda: 0)
+
+         self.ctx = ctx = Context()
+         ctx.request_hosts(N)
+
+         # we want ctx to start its listener threads
+         # before creating the hosts because
+         # initialization will happen faster in this case
+         sleep(0)
+         supervisor_addr = f"tcp://127.0.0.1:{ctx.port}"
+
+         env = {
+             **os.environ,
+             "TORCH_SUPERVISOR_HEARTBEAT_INTERVAL": str(
+                 monarch_supervisor.HEARTBEAT_INTERVAL
+             ),
+             # This is needed to avoid a hard failure in ncclx when we do not
+             # have backend topology info (eg. on RE).
+             "NCCL_IGNORE_TOPO_LOAD_FAILURE": "true",
+         }
+
+         # start_new_session=True, because we want the host managers to be able to kill
+         # any worker processes before they exit, even if the supervisor crashes, or we ctrl-c
+         # it in testing.
+         self.host_managers = [
+             subprocess.Popen(
+                 [
+                     PYTHON_EXECUTABLE,
+                     "-m",
+                     "monarch_supervisor.host",
+                     supervisor_addr,
+                 ],
+                 env=env,
+                 start_new_session=True,
+             )
+             for _ in range(N)
+         ]
+         connections = ctx.messagefilter(HostConnected)
+         self.hosts = [connections.recv(timeout=30).sender for _ in range(N)]
+
+     def shutdown(self):
+         self.ctx.shutdown()
+         for host_manager in self.host_managers:
+             host_manager.wait(timeout=10)
+
+
+ def python_local_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> "DeviceMesh":
+     """
+     Creates a local device mesh with the given number of hosts and gpus per host.
+     Easy way to use PythonLocalContext.
+
+     Args:
+         gpus (Optional[int]): number of gpus per host.
+             Default: the number of GPUs this machine has.
+
+         hosts (int): number of hosts, primarily used for simulating multiple machines locally.
+             Default: 1
+
+     Example::
+         local_mesh = python_local_mesh(gpus=2)
+         with local_mesh.activate():
+             x = torch.rand(3, 4)
+             local_tensor = fetch_shard(x).result()
+
+         # Cleanly shut down the local mesh and exit.
+         local_mesh.exit()
+     """
+     ctx = PythonLocalContext(hosts)
+     if gpus is None:
+         gpus = _local_device_count()
+     dm = world_mesh(ctx.ctx, ctx.hosts, gpus)
+
+     def exit(
+         error: Optional[RemoteException | DeviceException | Exception] = None,
+     ) -> None:
+         dm.client.shutdown(True, error)
+         ctx.shutdown()
+
+     dm.exit = exit
+     return dm
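A self-contained version of the docstring example above (a sketch, not part of the wheel): it assumes a machine with at least 2 local GPUs, and that `fetch_shard` is imported from monarch/fetch.py shown in the file list.

    import torch
    from monarch.fetch import fetch_shard
    from monarch.python_local_mesh import python_local_mesh

    # Spin up one local host with two GPU workers and make the mesh active.
    local_mesh = python_local_mesh(gpus=2)
    with local_mesh.activate():
        x = torch.rand(3, 4)                      # created on the mesh workers
        local_tensor = fetch_shard(x).result()    # materialize one shard locally
        print(local_tensor)

    # Cleanly shut down the workers, host managers, and supervisor.
    local_mesh.exit()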
monarch/random.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ from typing import NamedTuple, Tuple
+
+ import torch
+ from monarch.common.remote import remote
+ from monarch.common.tensor import Tensor
+
+
+ class State(NamedTuple):
+     cpu: Tensor
+     cuda: Tensor
+
+
+ @remote(
+     propagate=lambda: (
+         torch.empty(5056, dtype=torch.uint8),
+         torch.empty(16, dtype=torch.uint8),
+     )
+ )
+ def _get_state() -> Tuple[torch.Tensor, torch.Tensor]:
+     return (torch.get_rng_state(), torch.cuda.get_rng_state())
+
+
+ @remote(propagate=lambda state: None)
+ def set_state(state: Tuple[Tensor, Tensor]):
+     cpu, device = state
+     torch.set_rng_state(cpu)
+     torch.cuda.set_rng_state(device)
+
+
+ @remote(propagate=lambda _: None)
+ def _manual_seed(seed: torch.Tensor):
+     torch.manual_seed(seed.item())
+
+
+ @remote(propagate=lambda: None)
+ def make_deterministic():
+     torch.use_deterministic_algorithms(True)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     # env var for deterministic CuBLAS
+     # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
+     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+
+ def get_state() -> State:
+     return State(*_get_state())
+
+
+ def new_state(seed: Tensor) -> State:
+     orig = get_state()
+     _manual_seed(seed)
+     mine = get_state()
+     set_state(orig)
+     return mine
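Usage sketch for monarch/random.py (not part of the wheel): snapshotting and restoring worker RNG state. It assumes an active device mesh so that the remote calls and the seed tensor resolve on the workers; the seed value is illustrative.

    import torch
    from monarch import random as monarch_random

    monarch_random.make_deterministic()      # deterministic algorithms + CUBLAS_WORKSPACE_CONFIG on workers

    saved = monarch_random.get_state()       # State(cpu=..., cuda=...) captured on the workers
    _ = torch.rand(4, 4)                     # draws from the worker RNGs
    monarch_random.set_state(saved)          # restore the captured CPU and CUDA RNG state

    # Derive a fresh RNG state from a seed without disturbing the current one.
    forked = monarch_random.new_state(torch.tensor(1234))
    monarch_random.set_state(forked)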
monarch/rdma.py ADDED
@@ -0,0 +1,162 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import ctypes
+
+ from dataclasses import dataclass
+ from typing import cast, Dict, Optional, Tuple
+
+ import torch
+
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
+
+ from monarch.actor_mesh import (
+     _ActorMeshRefImpl,
+     Actor,
+     ActorMeshRef,
+     endpoint,
+     MonarchContext,
+ )
+
+
+ @dataclass
+ class LocalRDMARecord:
+     data: torch.Tensor
+
+
+ _local_buffers: Dict[int, "LocalRDMARecord"] = {}
+
+
+ def _get_bytes(storage: torch.Tensor, offset: int, size: int) -> bytearray:
+     """Extracts a bytearray from a 1D, 1-byte-per-item tensor."""
+     if offset + size > storage.numel():
+         raise ValueError(f"Read out of range: {offset + size} > {storage.size()}")
+     addr = storage.data_ptr()
+     if storage.device.type != "cpu":
+         result = bytearray(size)
+         result_tensor = torch.frombuffer(
+             result,
+             dtype=torch.uint8,
+         )
+         source_tensor = storage[offset:]
+         result_tensor.copy_(source_tensor)
+     else:
+         ctypes_array = (ctypes.c_byte * size).from_address(addr)
+         result = bytearray(ctypes_array)
+     return result
+
+
+ class RDMAManager(Actor):
+     @staticmethod
+     def on_proc(proc_id: str) -> "RDMAManager":
+         ctx = MonarchContext.get()
+         return cast(
+             RDMAManager,
+             ActorMeshRef(
+                 RDMAManager,
+                 _ActorMeshRefImpl.from_actor_id(
+                     ctx.mailbox,
+                     ActorId.from_string(f"{proc_id}.rdma_manager[0]"),
+                 ),
+                 ctx.mailbox,
+             ),
+         )
+
+     @endpoint
+     async def drop(self, addr: int) -> None:
+         if addr in _local_buffers:
+             del _local_buffers[addr]
+
+     @endpoint
+     async def fetch(self, addr: int, offset: int, nbytes: int) -> bytearray:
+         if addr not in _local_buffers:
+             raise ValueError(f"Unknown buffer {addr}")
+         storage = _local_buffers[addr].data
+         return _get_bytes(storage, offset, nbytes)
+
+     @endpoint
+     async def put(self, addr: int, offset: int, bytes: bytearray) -> None:
+         if addr not in _local_buffers:
+             raise ValueError(f"Unknown buffer {addr}")
+         storage = _local_buffers[addr].data
+         storage[offset : offset + len(bytes)] = torch.frombuffer(
+             bytes, dtype=storage.dtype
+         )
+
+
+ def _assert_tensor_is_1d_contiguous_uint8(t: torch.Tensor) -> None:
+     if t.ndim != 1:
+         raise ValueError(f"Tensor must be 1D, got {t.ndim}D")
+     if t.dtype != torch.uint8:
+         raise ValueError(f"Tensor must be uint8, got {t.dtype}")
+     if not t.is_contiguous():
+         raise ValueError("Tensor must be contiguous")
+
+
+ class RDMABuffer:
+     def __init__(self, data: torch.Tensor) -> None:
+         """
+         RDMABuffer only supports 1D contiguous tensors that are 1 byte per item.
+
+         To create a 1-byte, 1D view, use t.view(torch.uint8).flatten().
+
+         TODO: Create TensorBuffer, which will be the main user API supporting
+         non-contiguous, multi-byte-per-element tensors.
+         """
+         _assert_tensor_is_1d_contiguous_uint8(data)
+         assert data.storage_offset() == 0
+         storage = data.untyped_storage()
+         self.addr: int = storage.data_ptr()
+         self.begin = 0
+         self.end: int = storage.size()
+         self.proc_id: str = MonarchContext.get().proc_id
+         self.local_data: object = None
+         _local_buffers[self.addr] = LocalRDMARecord(data)
+
+     def drop(self) -> None:
+         if self.proc_id is None:
+             del _local_buffers[self.addr]
+             return
+         rmda_actor = RDMAManager.on_proc(self.proc_id)
+         # pyre-ignore[16]: Undefined attribute [16]: `Endpoint` has no attribute `cast`.
+         rmda_actor.drop.cast(self.addr)
+
+     def __getstate__(self) -> Tuple[int, int, int, Optional[str]]:
+         proc_id = self.proc_id
+         # locally created RDMABuffer being sent remotely,
+         # record its proc_id so we know how to establish connections to it
+         if proc_id is None:
+             proc_id = MonarchContext.get().proc_id
+         return (self.addr, self.begin, self.end, proc_id)
+
+     def __setstate__(self, state: Tuple[int, int, int, str]) -> None:
+         self.local_data = None
+         self.addr, self.begin, self.end, self.proc_id = state
+
+     async def read_into(self, dst: torch.Tensor, offset: int = 0) -> None:
+         """
+         Read data from the RDMABuffer into a destination tensor.
+
+         The destination tensor must be contiguous and 1 byte per item.
+         """
+         _assert_tensor_is_1d_contiguous_uint8(dst)
+         bytes = await RDMAManager.on_proc(self.proc_id).fetch.call_one(
+             self.addr, offset, dst.numel()
+         )
+         dst.copy_(torch.frombuffer(bytes, dtype=torch.uint8))
+
+     async def write(self, src: torch.Tensor, offset: int = 0) -> None:
+         """
+         Write data from a source tensor into the RDMABuffer.
+
+         The source tensor must be contiguous and 1 byte per item.
+         """
+         _assert_tensor_is_1d_contiguous_uint8(src)
+         bytes = _get_bytes(
+             src,
+             cast(int, src.storage_offset()),
+             src.numel(),
+         )
+         await RDMAManager.on_proc(self.proc_id).put.call_one(self.addr, offset, bytes)
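Usage sketch for monarch/rdma.py (not part of the wheel): one actor exposes a tensor as an RDMABuffer and another reads it. The class names are illustrative, and spawning both actors on procs of the same mesh is assumed.

    import torch
    from monarch.actor_mesh import Actor, endpoint
    from monarch.rdma import RDMABuffer

    class ParameterServer(Actor):
        def __init__(self) -> None:
            self.weights = torch.zeros(1024, dtype=torch.uint8)

        @endpoint
        async def get_buffer(self) -> RDMABuffer:
            # RDMABuffer requires a 1D, contiguous, uint8 view of the data.
            return RDMABuffer(self.weights.view(torch.uint8).flatten())

    class Trainer(Actor):
        @endpoint
        async def pull(self, buf: RDMABuffer) -> None:
            dst = torch.empty(1024, dtype=torch.uint8)
            await buf.read_into(dst)  # fetches the bytes via the remote RDMAManager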
monarch/remote_class.py ADDED
@@ -0,0 +1,114 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import importlib
+ import itertools
+ from typing import Any, Dict
+
+ from monarch.common import device_mesh
+ from monarch.common.remote import remote
+
+
+ class ControllerRemoteClass:
+     """
+     This class simplifies the creation and management of remote classes. It serves as
+     the controller side of a remote class architecture. Classes that are intended to be
+     controlled remotely should inherit from this class. The constructor of the inheriting
+     class must invoke `super().__init__()` with the path to the remote class that will be
+     used on the worker nodes. Methods that are intended for remote execution must be
+     decorated with `ControllerRemoteClass.remote_method`.
+
+     Note: This class is designed for use by the controller developer only and should
+     not be directly used in model code.
+
+     Example usage:
+
+     class ControllerMyClass(ControllerRemoteClass):
+         def __init__(self, *args, **kwargs) -> None:
+             super().__init__("my_package.my_class", *args, **kwargs)
+
+         @ControllerRemoteClass.remote_method
+         def some_method(self, *args, **kwargs) -> None:
+             # This method is intended for remote execution and does nothing locally.
+             pass
+     """
+
+     _counter = itertools.count()
+
+     def __init__(self, cls_path: str, *args, **kwargs) -> None:
+         self.ident = next(ControllerRemoteClass._counter)
+         self.cls_path = cls_path
+         self.mesh = device_mesh._active
+         _controller_remote_class_method(
+             cls_path, "remote_init", self.ident, *args, **kwargs
+         )
+
+     def __del__(self) -> None:
+         mesh = getattr(self, "mesh", None)
+         if mesh is not None and not mesh.client._shutdown:
+             with self.mesh.activate():
+                 _controller_remote_class_method(
+                     self.cls_path,
+                     "remote_del",
+                     self.ident,
+                 )
+
+     @staticmethod
+     def remote_method(fn):
+         def wrapper(self, *args, **kwargs) -> None:
+             _controller_remote_class_method(
+                 self.cls_path, "remote_method", self.ident, fn.__name__, *args, **kwargs
+             )
+
+         return wrapper
+
+
+ # Add the logic as a separate private function instead of adding it to
+ # ResolvableFunctionFromPath. This prevents users from calling it directly.
+ _controller_remote_class_method = remote(
+     "monarch.remote_class._remote_class_method", propagate="inspect"
+ )
+
+
+ def _remote_class_method(cls_path: str, method_name: str, *args, **kwargs) -> None:
+     modulename, classname = cls_path.rsplit(".", 1)
+     module = importlib.import_module(modulename)
+     cls = getattr(module, classname)
+     method = getattr(cls, method_name)
+     method(*args, **kwargs)
+
+
+ class WorkerRemoteClass:
+     """
+     This class is designed to be used alongside ``ControllerRemoteClass`` and represents
+     the worker side of a remote class architecture. Instances of this class should just
+     mimic standard Python classes, with the notable exception that all methods must
+     return None -- the current RemoteClass architecture does not support methods that
+     return values.
+
+     The `ident` attribute is used for tracking object instances created via `remote_init`.
+     This tracking is necessary because the remote function would otherwise lose the
+     reference to the object.
+     """
+
+     _objects: Dict[int, Any] = {}
+
+     @classmethod
+     def remote_init(cls, ident: int, *args, **kwargs) -> None:
+         WorkerRemoteClass._objects[ident] = cls(*args, **kwargs)
+
+     @classmethod
+     def remote_del(cls, ident) -> None:
+         WorkerRemoteClass._objects.pop(ident)
+
+     @classmethod
+     def remote_method(cls, ident: int, method_name, *args, **kwargs) -> None:
+         instance = WorkerRemoteClass._objects[ident]
+         assert (
+             cls == instance.__class__
+         ), f"Mismatched class type {cls} {instance.__class__}"
+         getattr(instance, method_name)(*args, **kwargs)
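A sketch of a matched controller/worker pair built on these classes (not part of the wheel): `my_package.metrics` and `MetricLogger` are hypothetical names, and the controller class is assumed to be constructed while a device mesh is active.

    from monarch.remote_class import ControllerRemoteClass, WorkerRemoteClass

    # Worker side: lives in my_package/metrics.py and is imported by path on the workers.
    class MetricLogger(WorkerRemoteClass):
        def __init__(self, prefix: str) -> None:
            self.prefix = prefix

        def log(self, value: float) -> None:  # methods must return None
            print(f"{self.prefix}: {value}")

    # Controller side: runs in the controller process and forwards calls to the workers.
    class ControllerMetricLogger(ControllerRemoteClass):
        def __init__(self, prefix: str) -> None:
            super().__init__("my_package.metrics.MetricLogger", prefix)

        @ControllerRemoteClass.remote_method
        def log(self, value: float) -> None:
            pass  # executed remotely; no local work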