torchmonarch-nightly 2025.6.30__cp311-cp311-manylinux2014_x86_64.whl → 2025.7.25__cp311-cp311-manylinux2014_x86_64.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
- monarch/__init__.py +13 -9
- monarch/_rust_bindings.so +0 -0
- monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
- monarch/_src/actor/actor_mesh.py +874 -0
- monarch/{allocator.py → _src/actor/allocator.py} +26 -17
- monarch/_src/actor/bootstrap_main.py +73 -0
- monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
- monarch/_src/actor/code_sync/auto_reload.py +223 -0
- monarch/_src/actor/debugger.py +565 -0
- monarch/_src/actor/endpoint.py +270 -0
- monarch/_src/actor/event_loop.py +97 -0
- monarch/_src/actor/future.py +100 -0
- monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
- monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
- monarch/_src/actor/proc_mesh.py +500 -0
- monarch/_src/actor/sync_state.py +18 -0
- monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
- monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
- monarch/_src/actor/tensor_engine_shim.py +56 -0
- monarch/_src/tensor_engine/rdma.py +180 -0
- monarch/_testing.py +3 -2
- monarch/actor/__init__.py +51 -0
- monarch/actor_mesh.py +6 -752
- monarch/bootstrap_main.py +8 -47
- monarch/common/client.py +1 -1
- monarch/common/controller_api.py +2 -1
- monarch/common/device_mesh.py +12 -2
- monarch/common/messages.py +12 -1
- monarch/common/recording.py +4 -3
- monarch/common/remote.py +135 -52
- monarch/common/tensor.py +2 -1
- monarch/controller/backend.py +2 -2
- monarch/controller/controller.py +2 -1
- monarch/controller/rust_backend/controller.py +2 -1
- monarch/fetch.py +3 -5
- monarch/mesh_controller.py +201 -139
- monarch/monarch_controller +0 -0
- monarch/opaque_module.py +4 -6
- monarch/opaque_object.py +3 -3
- monarch/proc_mesh.py +6 -309
- monarch/python_local_mesh.py +1 -1
- monarch/rust_backend_mesh.py +2 -1
- monarch/rust_local_mesh.py +4 -2
- monarch/sim_mesh.py +10 -19
- monarch/simulator/command_history.py +1 -1
- monarch/simulator/interface.py +2 -1
- monarch/simulator/mock_controller.py +1 -1
- monarch/simulator/simulator.py +1 -1
- monarch/tensor_engine/__init__.py +23 -0
- monarch/tensor_worker_main.py +3 -1
- monarch/tools/cli.py +3 -1
- monarch/tools/commands.py +95 -35
- monarch/tools/mesh_spec.py +55 -0
- monarch/tools/utils.py +38 -0
- monarch/worker/worker.py +1 -1
- monarch/world_mesh.py +2 -1
- monarch_supervisor/python_executable.py +6 -3
- tests/error_test_binary.py +75 -9
- tests/test_actor_error.py +370 -21
- tests/test_alloc.py +1 -1
- tests/test_allocator.py +373 -17
- tests/test_controller.py +2 -0
- tests/test_debugger.py +416 -0
- tests/test_env_before_cuda.py +162 -0
- tests/test_python_actors.py +184 -332
- tests/test_rdma.py +198 -0
- tests/test_remote_functions.py +40 -12
- tests/test_rust_backend.py +7 -5
- tests/test_sim_backend.py +1 -4
- tests/test_tensor_engine.py +55 -1
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/METADATA +6 -1
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/RECORD +80 -68
- torchmonarch_nightly-2025.7.25.dist-info/entry_points.txt +3 -0
- monarch/_monarch/hyperactor/__init__.py +0 -58
- monarch/_monarch/worker/debugger.py +0 -117
- monarch/_monarch/worker/logging.py +0 -107
- monarch/debugger.py +0 -379
- monarch/future.py +0 -76
- monarch/rdma.py +0 -162
- torchmonarch_nightly-2025.6.30.dist-info/entry_points.txt +0 -3
- /monarch/{_monarch/worker → _src}/__init__.py +0 -0
- /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
- /monarch/{common → _src/actor}/shape.py +0 -0
- /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/WHEEL +0 -0
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/licenses/LICENSE +0 -0
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/top_level.txt +0 -0
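The dominant theme of this release is a repackaging: actor-runtime modules move from the package top level (actor_mesh.py, allocator.py, debugger.py, future.py, rdma.py, pdb_wrapper.py, telemetry.py, ...) into private monarch/_src/actor and monarch/_src/tensor_engine packages, with new public entry points in monarch/actor/__init__.py and monarch/tensor_engine/__init__.py. A minimal sketch of what that implies for imports; the re-exported names below are assumptions inferred from the file list, not taken from package documentation:

# Hedged sketch of the import migration implied by the renames above.
# Before (2025.6.30): actor APIs lived in top-level modules, e.g.
#   from monarch.actor_mesh import Actor
#   from monarch.rdma import RDMABuffer
# After (2025.7.25): implementations live under monarch/_src/...; user code
# imports the new public facades instead (assumed re-exports).
from monarch.actor import Actor, proc_mesh      # assumed re-exports
from monarch.tensor_engine import RDMABuffer    # assumed re-export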
monarch/mesh_controller.py
CHANGED
@@ -7,40 +7,63 @@
 import atexit
 import logging
 import os
-
+
+import pdb  # noqa
 import traceback
 from collections import deque
 from logging import Logger
-from typing import
+from typing import (
+    Any,
+    cast,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)

 import torch.utils._python_dispatch
-
-from monarch import NDSlice
-from monarch._rust_bindings.monarch_extension import client, debugger
+from monarch._rust_bindings.monarch_extension import client
 from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
     WorldState,
 )
 from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller
+from monarch._rust_bindings.monarch_hyperactor.actor import (
+    PythonMessage,
+    PythonMessageKind,
+    UnflattenArg,
+)
+from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
 from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
     ActorId,
 )
+from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple
+from monarch._src.actor.endpoint import Selection
+from monarch._src.actor.shape import NDSlice
+from monarch.common import device_mesh, messages, stream
+from monarch.common.controller_api import TController
+from monarch.common.invocation import Seq
+from monarch.common.messages import Referenceable, SendResultOfActorCall
+from monarch.common.stream import StreamRef
+from monarch.common.tensor import InputChecker, Tensor
+from monarch.tensor_worker_main import _set_trace

 if TYPE_CHECKING:
     from monarch._rust_bindings.monarch_hyperactor.proc_mesh import (
         ProcMesh as HyProcMesh,
     )
-    from monarch.
+    from monarch.actor import ProcMesh

 from monarch._rust_bindings.monarch_hyperactor.shape import Point

-from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
 from monarch.common.client import Client
 from monarch.common.controller_api import LogMessage, MessageResult
-from monarch.common.device_mesh import DeviceMesh
+from monarch.common.device_mesh import DeviceMesh
+from monarch.common.future import Future as OldFuture
 from monarch.common.invocation import DeviceException, RemoteException
-from monarch.controller.debugger import read as debugger_read, write as debugger_write
 from monarch.rust_local_mesh import _get_worker_exec_info
-from pyre_extensions import none_throws

 logger: Logger = logging.getLogger(__name__)

@@ -48,6 +71,7 @@ logger: Logger = logging.getLogger(__name__)
 class Controller(_Controller):
     def __init__(self, workers: "HyProcMesh") -> None:
         super().__init__()
+        self._mailbox: Mailbox = workers.client
         # Buffer for messages unrelated to debugging that are received while a
         # debugger session is active.
         self._non_debugger_pending_messages: deque[
@@ -58,19 +82,9 @@ class Controller(_Controller):
     def next_message(
         self, timeout: Optional[float]
     ) -> Optional[LogMessage | MessageResult]:
-
-
-
-        msg = self._get_next_message(timeout_msec=int((timeout or 0.0) * 1000.0))
-        if msg is None:
-            return None
-
-        if isinstance(msg, client.WorkerResponse):
-            return _worker_response_to_result(msg)
-        elif isinstance(msg, client.LogMessage):
-            return LogMessage(msg.level, msg.message)
-        elif isinstance(msg, client.DebuggerMessage):
-            self._run_debugger_loop(msg)
+        raise RuntimeError(
+            "internal error: tensor engine does not produce futures that call next_message"
+        )

     def send(
         self,
@@ -86,56 +100,6 @@ class Controller(_Controller):
         self._drain_and_stop()
         return []

-    def _run_debugger_loop(self, message: client.DebuggerMessage) -> None:
-        if not isinstance(message.action, DebuggerAction.Paused):
-            raise RuntimeError(
-                f"Unexpected debugger message {message} when no debugger session is running"
-            )
-
-        self._pending_debugger_sessions.append(message.debugger_actor_id)
-        while self._pending_debugger_sessions:
-            debugger_actor_id = self._pending_debugger_sessions.popleft()
-            rank = debugger_actor_id.rank
-            proc_id = debugger_actor_id.proc_id
-            debugger_write(
-                f"pdb attached to proc {proc_id} with rank {rank}, debugger actor {debugger_actor_id} \n"
-            )
-
-            self._debugger_attach(debugger_actor_id)
-            while True:
-                # TODO: Add appropriate timeout.
-                msg = self._get_next_message(timeout_msec=None)
-
-                if not isinstance(msg, client.DebuggerMessage):
-                    self._non_debugger_pending_messages.append(msg)
-                    continue
-
-                if msg.debugger_actor_id != debugger_actor_id:
-                    if isinstance(msg.action, DebuggerAction.Paused):
-                        self._pending_debugger_sessions.append(msg.debugger_actor_id)
-                        continue
-                    else:
-                        raise RuntimeError(
-                            f"unexpected debugger message {msg} from rank {msg.debugger_actor_id.rank} "
-                            f"when debugging rank {debugger_actor_id.rank}"
-                        )
-
-                action = msg.action
-                if isinstance(action, DebuggerAction.Detach):
-                    break
-                elif isinstance(action, DebuggerAction.Read):
-                    self._debugger_write(
-                        debugger_actor_id, debugger_read(action.requested_size)
-                    )
-                elif isinstance(action, DebuggerAction.Write):
-                    debugger_write(
-                        debugger.get_bytes_from_write_action(action).decode()
-                    )
-                else:
-                    raise RuntimeError(
-                        f"unexpected debugger message {msg} when debugging rank {debugger_actor_id.rank}"
-                    )
-
     def worker_world_state(self) -> WorldState:
         raise NotImplementedError("worker world state")

@@ -145,54 +109,6 @@ class Controller(_Controller):
         pass


-# TODO: Handling conversion of the response can move to a separate module over time
-# especially as we have structured error messages.
-def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult:
-    if not result.is_exception():
-        # The result of the message needs to be unwrapped on a real device.
-        # Staying as a fake tensor will fail the tensor deserialization.
-        with no_mesh.activate():
-            return MessageResult(result.seq, result.result(), None)
-    exc = none_throws(result.exception())
-    if isinstance(exc, client.Error):
-        worker_frames = [
-            traceback.FrameSummary("<unknown>", None, frame)
-            for frame in exc.backtrace.split("\\n")
-        ]
-        return MessageResult(
-            seq=result.seq,
-            result=None,
-            error=RemoteException(
-                seq=exc.caused_by_seq,
-                exception=RuntimeError(exc.backtrace),
-                controller_frame_index=0,  # TODO: T225205291 fix this once we have recording support in rust
-                controller_frames=None,
-                worker_frames=worker_frames,
-                source_actor_id=exc.actor_id,
-                message=f"Remote function in {exc.actor_id} errored.",
-            ),
-        )
-    elif isinstance(exc, client.Failure):
-        frames = [
-            traceback.FrameSummary("<unknown>", None, frame)
-            for frame in exc.backtrace.split("\n")
-        ]
-        reason = f"Actor {exc.actor_id} crashed on {exc.address}, check the host log for details"
-        logger.error(reason)
-        return MessageResult(
-            seq=0,  # seq is not consumed for DeviceException; it will be directly thrown by the client
-            result=None,
-            error=DeviceException(
-                exception=RuntimeError(reason),
-                frames=frames,
-                source_actor_id=exc.actor_id,
-                message=reason,
-            ),
-        )
-    else:
-        raise RuntimeError(f"Unknown exception type: {type(exc)}")
-
-
 def _initialize_env(worker_point: Point, proc_id: str) -> None:
     worker_rank = worker_point.rank
     try:
@@ -213,12 +129,50 @@ def _initialize_env(worker_point: Point, proc_id: str) -> None:
             "LOCAL_WORLD_SIZE": str(gpus_per_host),
         }
         os.environ.update(process_env)
+        pdb.set_trace = _set_trace
+        # workaround for set_manual_seed somehow not working if cuda is not initialized\
+        if torch.cuda.is_available():
+            torch.cuda.init()
     except Exception:
         traceback.print_exc()
         raise


 class MeshClient(Client):
+    def fetch(
+        self,
+        mesh: "DeviceMesh",
+        stream: "StreamRef",
+        shard,
+        preprocess_message,
+        args,
+        kwargs,
+        defs: Tuple["Tensor", ...],
+        uses: Tuple["Tensor", ...],
+    ) -> "OldFuture":  # the OldFuture is a lie
+        sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
+
+        ident = self.new_node(defs, uses, cast("OldFuture", sender))
+        process = mesh._process(shard)
+        self.send(
+            process,
+            messages.SendValue(
+                ident,
+                None,
+                defs,
+                preprocess_message,
+                args,
+                kwargs,
+                stream,
+            ),
+        )
+        # we have to ask for status updates
+        # from workers to be sure they have finished
+        # enough work to count this future as finished,
+        # and all potential errors have been reported
+        self._request_status()
+        return cast("OldFuture", receiver.recv())
+
     def shutdown(
         self,
         destroy_pg: bool = True,
@@ -232,27 +186,43 @@ class MeshClient(Client):
         atexit.unregister(self._atexit)
         self._shutdown = True

-
-
-        self.
-
-
-        ttl = 60
-        start_time = time.time()
-        end_time = start_time + ttl
-        while ttl > 0 and self.last_assigned_seq > self.last_processed_seq:
-            ttl = end_time - time.time()
-            self.handle_next_message(ttl)
-            if self._pending_shutdown_error:
-                raise self._pending_shutdown_error
-
-        if ttl <= 0:
-            raise RuntimeError("shutdown timed out")
-
+        sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
+        assert sender._port_ref is not None
+        self._mesh_controller.sync_at_exit(sender._port_ref.port_id)
+        receiver.recv().get(timeout=60)
         # we are not expecting anything more now, because we already
         # waited for the responses
         self.inner.drain_and_stop()

+    @property
+    def _mesh_controller(self) -> Controller:
+        return cast(Controller, self.inner)
+
+    def new_node_nocoalesce(
+        self,
+        defs: Sequence["Tensor"],
+        uses: Sequence["Tensor"],
+        future: Optional["OldFuture"],
+        tracebacks: List[List[traceback.FrameSummary]],
+    ) -> Seq:
+        seq = self._next_seq()
+        for d in defs:
+            d._seq = seq
+        response_port = None
+        if future is not None:
+            # method annotation is a lie to make Client happy
+            port, slice = cast("Tuple[Port[Any], NDSlice]", future)
+            assert port._port_ref is not None
+            response_port = (port._port_ref.port_id, slice)
+        self._mesh_controller.node(seq, defs, uses, response_port, tracebacks)
+        return seq
+
+    def handle_next_message(self, timeout: Optional[float]) -> bool:
+        """
+        Mesh controller message loop is handled by the tokio event loop.
+        """
+        return False
+

 def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     # This argument to Controller
@@ -260,7 +230,7 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     # report the proc ID instead of the rank it currently does.
     gpus = proc_mesh.sizes.get("gpus", 1)
     backend_ctrl = Controller(proc_mesh._proc_mesh)
-    client = MeshClient(backend_ctrl, proc_mesh.size(), gpus)
+    client = MeshClient(cast("TController", backend_ctrl), proc_mesh.size(), gpus)
     dm = DeviceMesh(
         client,
         NDSlice.new_row_major(list(proc_mesh.sizes.values())),
@@ -268,3 +238,95 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     )
     dm.exit = lambda: client.shutdown()
     return dm
+
+
+class RemoteException(Exception):
+    def __init__(
+        self,
+        worker_error_string: str,  # this should really be an exception + stacktrace but
+        # worker code needs major refactor to make this possible
+        controller_frames: List[traceback.FrameSummary],
+        rank: int,
+    ):
+        self.worker_error_string = worker_error_string
+        self.controller_frames = controller_frames
+        self.rank = rank
+
+    def __str__(self):
+        try:
+            controller_tb = "".join(traceback.format_list(self.controller_frames))
+            return (
+                f"A remote function has failed asynchronously on rank {self.rank}.\n"
+                f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_tb}"
+                f"Error as reported from worker:\n{self.worker_error_string}"
+            )
+        except Exception:
+            traceback.print_exc()
+            return "<exception formatting RemoteException>"
+
+
+def actor_send(
+    endpoint: ActorEndpoint,
+    args_kwargs_tuple: bytes,
+    refs: Sequence[Any],
+    port: Optional[Port[Any]],
+    selection: Selection,
+):
+    unflatten_args = [
+        UnflattenArg.PyObject if isinstance(ref, Tensor) else UnflattenArg.Mailbox
+        for ref in refs
+    ]
+    tensors = [ref for ref in refs if isinstance(ref, Tensor)]
+    # we have some monarch references, we need to ensure their
+    # proc_mesh matches that of the tensors we sent to it
+    chosen_stream = stream._active
+    for t in tensors:
+        if hasattr(t, "stream"):
+            chosen_stream = t.stream
+            break
+    with InputChecker(refs, lambda x: f"actor_call({x})") as checker:
+        checker.check_mesh_stream_local(device_mesh._active, chosen_stream)
+        # TODO: move propagators into Endpoint abstraction and run the propagator to get the
+        # mutates
+        checker.check_permission(())
+    selected_device_mesh = (
+        endpoint._actor_mesh._proc_mesh and endpoint._actor_mesh._proc_mesh._device_mesh
+    )
+    if selected_device_mesh is not checker.mesh:
+        raise ValueError(
+            f"monarch Tensors sent to an actor must be located on the same process as the actor. However {checker.mesh} is not {selected_device_mesh}."
+            "NYI: better serialization of mesh names to make the mismatch more clear."
+        )
+
+    client = cast(MeshClient, checker.mesh.client)
+
+    broker_id: Tuple[str, int] = client._mesh_controller.broker_id
+
+    stream_ref = chosen_stream._to_ref(client)
+
+    fut = (port, checker.mesh._ndslice) if port is not None else None
+
+    ident = client.new_node([], tensors, cast("OldFuture", fut))
+
+    # To ensure that both the actor and the stream execute in order, we send a message
+    # to each at this point. The message to the worker will be handled on the stream actor where
+    # it will send the 'tensor's to the broker actor locally, along with a response port with the
+    # computed value.
+
+    # The message to the generic actor tells it to first wait on the broker to get the local arguments
+    # from the stream, then it will run the actor method, and send the result to response port.
+
+    actor_msg = PythonMessage(
+        PythonMessageKind.CallMethodIndirect(
+            endpoint._name, broker_id, ident, unflatten_args
+        ),
+        args_kwargs_tuple,
+    )
+    endpoint._actor_mesh.cast(actor_msg, selection)
+    worker_msg = SendResultOfActorCall(ident, broker_id, tensors, [], stream_ref)
+    client.send(checker.mesh._ndslice, worker_msg)
+    # we have to ask for status updates
+    # from workers to be sure they have finished
+    # enough work to count this future as finished,
+    # and all potential errors have been reported
+    client._request_status()
monarch/monarch_controller
CHANGED
Binary file
monarch/opaque_module.py
CHANGED
@@ -9,7 +9,7 @@ from typing import List
 import torch
 from monarch.common.function_caching import TensorGroup, TensorGroupPattern
 from monarch.common.opaque_ref import OpaqueRef
-from monarch.common.remote import remote
+from monarch.common.remote import call_on_shard_and_fetch, remote
 from monarch.common.tensor_factory import TensorFactory
 from monarch.common.tree import flatten
 from monarch.opaque_object import _fresh_opaque_ref, OpaqueObject
@@ -144,11 +144,9 @@ class OpaqueModule:

     def parameters(self):
         if self._parameters is None:
-            tensor_group_pattern = (
-                remote(_get_parameters_shape)
-
-                .result()
-            )
+            tensor_group_pattern = call_on_shard_and_fetch(
+                remote(_get_parameters_shape), self._object
+            ).result()
             self._parameters = [
                 p.requires_grad_(True)
                 for p in remote(
monarch/opaque_object.py
CHANGED
@@ -14,7 +14,7 @@ from monarch.common.function import (
 )

 from monarch.common.opaque_ref import OpaqueRef
-from monarch.common.remote import remote
+from monarch.common.remote import call_on_shard_and_fetch, remote


 def _invoke_method(obj: OpaqueRef, method_name: str, *args, **kwargs):
@@ -83,6 +83,6 @@ class OpaqueObject(OpaqueRef):
         return endpoint(self, method_name, *args, **kwargs)

     def call_method_on_shard_and_fetch(self, method_name, *args, **kwargs):
-        return
-            self, method_name, *args, **kwargs
+        return call_on_shard_and_fetch(
+            remote(_invoke_method), self, method_name, *args, **kwargs
         )