torchmonarch-nightly 2025.7.1__cp310-cp310-manylinux2014_x86_64.whl → 2025.7.26__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +13 -9
- monarch/_rust_bindings.so +0 -0
- monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
- monarch/_src/actor/actor_mesh.py +878 -0
- monarch/{allocator.py → _src/actor/allocator.py} +26 -17
- monarch/_src/actor/bootstrap_main.py +73 -0
- monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
- monarch/_src/actor/code_sync/auto_reload.py +223 -0
- monarch/_src/actor/debugger.py +565 -0
- monarch/_src/actor/endpoint.py +303 -0
- monarch/_src/actor/event_loop.py +97 -0
- monarch/_src/actor/future.py +100 -0
- monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
- monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
- monarch/_src/actor/proc_mesh.py +508 -0
- monarch/_src/actor/sync_state.py +18 -0
- monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
- monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
- monarch/_src/actor/tensor_engine_shim.py +59 -0
- monarch/_src/tensor_engine/rdma.py +180 -0
- monarch/_testing.py +3 -2
- monarch/actor/__init__.py +53 -0
- monarch/actor_mesh.py +6 -765
- monarch/bootstrap_main.py +8 -47
- monarch/common/client.py +1 -1
- monarch/common/controller_api.py +2 -1
- monarch/common/device_mesh.py +12 -2
- monarch/common/messages.py +21 -1
- monarch/common/recording.py +4 -3
- monarch/common/remote.py +135 -52
- monarch/common/tensor.py +2 -1
- monarch/controller/backend.py +2 -2
- monarch/controller/controller.py +2 -1
- monarch/controller/rust_backend/controller.py +2 -1
- monarch/fetch.py +3 -5
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/mesh_controller.py +263 -139
- monarch/monarch_controller +0 -0
- monarch/opaque_module.py +4 -6
- monarch/opaque_object.py +3 -3
- monarch/proc_mesh.py +6 -309
- monarch/python_local_mesh.py +1 -1
- monarch/rust_backend_mesh.py +2 -1
- monarch/rust_local_mesh.py +4 -2
- monarch/sim_mesh.py +10 -19
- monarch/simulator/command_history.py +1 -1
- monarch/simulator/interface.py +2 -1
- monarch/simulator/mock_controller.py +1 -1
- monarch/simulator/simulator.py +1 -1
- monarch/tensor_engine/__init__.py +23 -0
- monarch/tensor_worker_main.py +3 -1
- monarch/tools/cli.py +3 -1
- monarch/tools/commands.py +129 -47
- monarch/tools/components/hyperactor.py +5 -3
- monarch/tools/config/__init__.py +18 -1
- monarch/tools/config/defaults.py +2 -2
- monarch/tools/mesh_spec.py +59 -1
- monarch/tools/utils.py +38 -0
- monarch/worker/worker.py +1 -1
- monarch/world_mesh.py +2 -1
- monarch_supervisor/python_executable.py +6 -3
- tests/error_test_binary.py +48 -10
- tests/test_actor_error.py +370 -21
- tests/test_alloc.py +1 -1
- tests/test_allocator.py +369 -17
- tests/test_controller.py +2 -0
- tests/test_debugger.py +416 -0
- tests/test_env_before_cuda.py +161 -0
- tests/test_python_actors.py +184 -333
- tests/test_rdma.py +198 -0
- tests/test_remote_functions.py +40 -12
- tests/test_rust_backend.py +7 -5
- tests/test_sim_backend.py +1 -4
- tests/test_tensor_engine.py +81 -1
- {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/METADATA +39 -1
- {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/RECORD +84 -72
- torchmonarch_nightly-2025.7.26.dist-info/entry_points.txt +3 -0
- monarch/_monarch/hyperactor/__init__.py +0 -58
- monarch/_monarch/worker/debugger.py +0 -117
- monarch/_monarch/worker/logging.py +0 -107
- monarch/debugger.py +0 -379
- monarch/future.py +0 -76
- monarch/rdma.py +0 -162
- torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
- /monarch/{_monarch/worker → _src}/__init__.py +0 -0
- /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
- /monarch/{common → _src/actor}/shape.py +0 -0
- /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
- {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/WHEEL +0 -0
- {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/licenses/LICENSE +0 -0
- {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/top_level.txt +0 -0
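Note: taken together, the renames above are a single refactor. The Python actor runtime moves from scattered top-level modules into the private monarch._src.actor package, monarch/actor/__init__.py and monarch/tensor_engine/__init__.py are added as public entry points, and the old top-level modules (bootstrap_main.py, debugger.py, future.py, rdma.py) either disappear or become deprecation shims. The import hunks below are the mechanical fallout of that move. A minimal sketch of the migration, using only names that appear in this diff:

# Before (2025.7.1), shape utilities lived in monarch.common.shape:
#   from monarch.common.shape import MeshTrait, NDSlice, Shape
# After (2025.7.26), per the /monarch/{common → _src/actor}/shape.py rename:
from monarch._src.actor.shape import MeshTrait, NDSlice, Shape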
monarch/bootstrap_main.py
CHANGED
@@ -4,56 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""
-This is the main function for the boostrapping a new process using a ProcessAllocator.
-"""
+import warnings
 
-import asyncio
-import importlib.resources
-import logging
-import os
-import sys
+warnings.warn(
+    "monarch.bootstrap_main is deprecated, please use from monarch._src.actor.bootstrap_main instead.",
+    DeprecationWarning,
+    stacklevel=2,
+)
 
-
-import torch  # noqa[F401]
-
-
-async def main():
-    from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main
-
-    await bootstrap_main()
-
-
-def invoke_main():
-    # if this is invoked with the stdout piped somewhere, then print
-    # changes its buffering behavior. So we default to the standard
-    # behavior of std out as if it were a terminal.
-    sys.stdout.reconfigure(line_buffering=True)
-    global bootstrap_main
-
-    # TODO: figure out what from worker_main.py we should reproduce here.
-    from monarch.telemetry import TracingForwarder
-
-    if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
-        raise RuntimeError("Error during bootstrap for testing")
-
-    # forward logs to rust tracing. Defaults to on.
-    if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1":
-        logging.root.addHandler(TracingForwarder(level=logging.DEBUG))
-
-    try:
-        with (
-            importlib.resources.path("monarch", "py-spy") as pyspy,
-        ):
-            if pyspy.exists():
-                os.environ["PYSPY_BIN"] = str(pyspy)
-        # fallback to using local py-spy
-    except Exception as e:
-        logging.warning(f"Failed to set up py-spy: {e}")
-
-    # Start an event loop for PythonActors to use.
-    asyncio.run(main())
+from monarch._src.actor.bootstrap_main import *  # noqa
 
 
 if __name__ == "__main__":
+    # noqa
     invoke_main()  # pragma: no cover
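Note: monarch/bootstrap_main.py is now a thin deprecation shim: warn once, then star-import the real implementation so existing call sites (including the invoke_main() reference in the __main__ block) keep resolving. Below is a self-contained sketch of the same warn-and-reexport pattern; it uses the stdlib json module as a stand-in target so it runs anywhere, and the module names are illustrative rather than monarch's:

# shim_demo.py: the warn-and-reexport pattern used by the rewritten module above.
import warnings

warnings.warn(
    "shim_demo is deprecated, import json directly instead.",
    DeprecationWarning,
    stacklevel=2,  # attribute the warning to the importing caller, not this module
)

from json import *  # noqa: F401,F403  re-export so old call sites keep working

if __name__ == "__main__":
    # dumps() arrived via the star re-export, just like invoke_main() above
    print(dumps({"ok": True}))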
monarch/common/client.py
CHANGED
@@ -37,6 +37,7 @@ from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monar
     LogLevel,
     WorldState,
 )
+from monarch._src.actor.shape import NDSlice
 from monarch.common import messages
 from monarch.common.borrows import Borrow, StorageAliases
 from monarch.common.controller_api import LogMessage, MessageResult, TController
@@ -47,7 +48,6 @@ from monarch.common.invocation import DeviceException, RemoteException, Seq
 from monarch.common.recording import flatten_messages, Recording
 
 from monarch.common.reference import Ref, Referenceable
-from monarch.common.shape import NDSlice
 from monarch.common.stream import StreamRef
 from monarch.common.tensor import Tensor
 from monarch.common.tree import tree_map
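Note: this hunk and most of those below are the same mechanical move: monarch.common.shape becomes monarch._src.actor.shape. For orientation, here is a hedged sketch of what an NDSlice describes; the offset/sizes/strides constructor is assumed from the pre-existing API and is not shown in this diff:

from monarch._src.actor.shape import NDSlice  # new canonical path

# An NDSlice addresses a strided, N-dimensional block of flat ranks.
# Assuming the keyword constructor of the existing API:
hosts_by_gpus = NDSlice(offset=0, sizes=[2, 4], strides=[4, 1])
# Under that layout, (host=1, gpu=2) would map to rank 0 + 1*4 + 2*1 == 6.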
monarch/common/controller_api.py
CHANGED
@@ -13,9 +13,10 @@ from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monar
     WorldState,
 )
 
+from monarch._src.actor.shape import NDSlice
+
 from monarch.common.invocation import DeviceException, RemoteException, Seq
 from monarch.common.reference import Ref
-from monarch.common.shape import NDSlice
 from monarch.common.tensor import Tensor
 
 
monarch/common/device_mesh.py
CHANGED
@@ -28,16 +28,16 @@ from typing import (
 
 import monarch.common.messages as messages
 import torch
-from monarch.
+from monarch._src.actor.shape import MeshTrait, NDSlice, Shape
 
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
+from torch.utils.weak import weakref
 
 from ._tensor_to_table import tensor_to_table
 from .context_manager import activate_first_context_manager
 from .messages import Dims
 from .reference import Referenceable
-from .shape import NDSlice, Shape
 from .stream import Stream
 from .tensor import MeshSliceTensor, Tensor
 
@@ -171,6 +171,7 @@ class DeviceMesh(Referenceable, MeshTrait):
         self.exit = lambda: None
         self.ref = None
         self._active_mesh_context = None
+        self._subset_of: Optional[weakref.ReferenceType["DeviceMesh"]] = None
 
     def define_remotely(self):
         if self.ref is None:
@@ -228,8 +229,17 @@ class DeviceMesh(Referenceable, MeshTrait):
     def _new_with_shape(self, shape: Shape) -> "DeviceMesh":
         mesh = DeviceMesh(self.client, shape.ndslice, tuple(shape.labels))
         mesh.exit = self.exit
+        mesh._subset_of = weakref.ref(self)
         return mesh
 
+    def _is_subset_of(self, other: "DeviceMesh") -> bool:
+        p = self
+        while p is not None:
+            if p is other:
+                return True
+            p = None if p._subset_of is None else p._subset_of()
+        return False
+
     def __call__(self, **kwargs) -> "DeviceMesh":
         """
         device_mesh(batch=3) or device_mesh(batch=slice(3, None))
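Note: the new _subset_of weak reference and _is_subset_of walk let Remote._send (in the monarch/common/remote.py diff below) verify that the currently active mesh was derived by slicing the mesh the argument tensors live on. A standalone sketch of the idiom, outside monarch: holding a slice never pins its parent mesh in memory, and a collected parent simply terminates the walk.

import weakref

class Mesh:
    def __init__(self, parent: "Mesh | None" = None) -> None:
        # weak reference, so a cached slice does not keep its ancestors alive
        self._subset_of = None if parent is None else weakref.ref(parent)

    def _is_subset_of(self, other: "Mesh") -> bool:
        p = self
        while p is not None:  # walk up the slicing chain
            if p is other:
                return True
            p = None if p._subset_of is None else p._subset_of()
        return False

root = Mesh()
child = Mesh(parent=root)
assert child._is_subset_of(root)
assert not root._is_subset_of(child)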
monarch/common/messages.py
CHANGED
@@ -17,18 +17,21 @@ from typing import (
     NamedTuple,
     Optional,
     Protocol,
+    Sequence,
     Tuple,
     TYPE_CHECKING,
 )
 
 from monarch._rust_bindings.monarch_extension import tensor_worker
+from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
+
+from monarch._src.actor.shape import NDSlice
 from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction
 from monarch.common.invocation import DeviceException, RemoteException
 from monarch.common.reference import Referenceable
 from monarch.common.tree import flattener
 from pyre_extensions import none_throws
 
-from .shape import NDSlice
 from .tensor_factory import TensorFactory
 
 if TYPE_CHECKING:
@@ -424,6 +427,23 @@ class SendTensor(NamedTuple):
     )
 
 
+class SendResultOfActorCall(NamedTuple):
+    seq: int
+    broker_id: Tuple[str, int]
+    local_state: Sequence[Tensor | tensor_worker.Ref]
+    mutates: List[tensor_worker.Ref]
+    stream: tensor_worker.StreamRef
+
+
+class CallActorMethod(NamedTuple):
+    seq: int
+    result: object
+    broker_id: Tuple[str, int]
+    local_state: Sequence[Tensor | tensor_worker.Ref]
+    mutates: List[tensor_worker.Ref]
+    stream: tensor_worker.StreamRef
+
+
 class SplitComm(NamedTuple):
     dims: Dims
     device_mesh: DeviceMesh
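Note: SendResultOfActorCall and CallActorMethod follow this file's existing convention of defining worker messages as plain NamedTuples. A self-contained sketch of why that is convenient; the object-typed fields stand in for the Tensor/Ref/StreamRef bindings used above:

from typing import List, NamedTuple, Sequence, Tuple

class SendResultOfActorCall(NamedTuple):  # mirrors the message added above
    seq: int
    broker_id: Tuple[str, int]
    local_state: Sequence[object]  # Tensor | tensor_worker.Ref in monarch
    mutates: List[object]          # tensor_worker.Ref in monarch
    stream: object                 # tensor_worker.StreamRef in monarch

# Cheap to build, immutable, and fields are addressable by name or position:
msg = SendResultOfActorCall(
    seq=7, broker_id=("broker", 0), local_state=[], mutates=[], stream=None
)
assert msg.seq == 7 and msg[1] == ("broker", 0)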
monarch/common/recording.py
CHANGED
@@ -10,9 +10,9 @@ import traceback
 from collections import defaultdict
 from typing import cast, Dict, Generator, List, NamedTuple, Tuple, TYPE_CHECKING, Union
 
-from monarch.
+from monarch._src.actor.shape import iter_ranks
 
-from monarch.common.
+from monarch.common.reference import Ref
 
 from monarch.common.tensor import InputChecker
 
@@ -21,8 +21,9 @@ from . import messages
 if TYPE_CHECKING:
     from monarch.common.client import Client
 
+    from monarch._src.actor.shape import NDSlice
+
     from .reference import Referenceable
-    from .shape import NDSlice
     from .tensor import Tensor
 
 logger = logging.getLogger(__name__)
monarch/common/remote.py
CHANGED
@@ -8,12 +8,12 @@
 
 import functools
 import logging
-import warnings
 
 from logging import Logger
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Generic,
     Literal,
@@ -28,12 +28,18 @@ from typing import (
 import monarch.common.messages as messages
 
 import torch
+from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
+from monarch._rust_bindings.monarch_hyperactor.shape import Shape
+from monarch._src.actor.actor_mesh import Port, PortTuple
+from monarch._src.actor.endpoint import Extent, Selection
 
-from monarch.common import _coalescing, device_mesh,
+from monarch.common import _coalescing, device_mesh, stream
+from monarch.common.future import Future as OldFuture
 
 if TYPE_CHECKING:
     from monarch.common.client import Client
 
+    from monarch._src.actor.endpoint import Endpoint
     from monarch.common.device_mesh import RemoteProcessGroup
 from monarch.common.fake import fake_call
@@ -49,9 +55,9 @@ from monarch.common.function_caching import (
     TensorGroup,
     TensorPlaceholder,
 )
-from monarch.common.future import Future
 from monarch.common.messages import Dims
-
+
+from monarch.common.tensor import dtensor_check, dtensor_dispatch, InputChecker
 from monarch.common.tree import flatten, tree_map
 from torch import autograd, distributed as dist
 from typing_extensions import ParamSpec
@@ -62,42 +68,96 @@ P = ParamSpec("P")
 R = TypeVar("R")
 T = TypeVar("T")
 
-Propagator = Callable | Literal["mocked", "cached", "inspect"] | None
-
 
-class Remote(Generic[P, R]):
+class Remote(Generic[P, R], Endpoint[P, R]):
     def __init__(self, impl: Any, propagator_arg: Propagator):
+        super().__init__(propagator_arg)
         self._remote_impl = impl
-
-
+
+    def _call_name(self) -> Any:
+        return self._remote_impl
+
+    def _send(
+        self,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        port: "Optional[Port]" = None,
+        selection: Selection = "all",
+    ) -> Extent:
+        ambient_mesh = device_mesh._active
+        propagator = self._fetch_propagate
+        rfunction = self._maybe_resolvable
+        # a None rfunction is an optimization for the identity function (lambda x: x)
+        if rfunction is None:
+            preprocess_message = None
+            rfunction = ResolvableFunctionFromPath("ident")
+        else:
+            preprocess_message = rfunction
+        _, dtensors, mutates, tensor_mesh = dtensor_check(
+            propagator, rfunction, args, kwargs, ambient_mesh, stream._active
+        )
+
+        if ambient_mesh is None:
+            raise ValueError(
+                "Calling a 'remote' monarch function requires an active proc_mesh (`with proc_mesh.activate():`)"
+            )
+
+        if not ambient_mesh._is_subset_of(tensor_mesh):
+            raise ValueError(
+                f"The current mesh {ambient_mesh} is not a subset of the mesh on which the tensors being used are defined {tensor_mesh}"
+            )
+
+        client: "Client" = ambient_mesh.client
+        if _coalescing.is_active(client):
+            raise NotImplementedError("NYI: fetching results during a coalescing block")
+        stream_ref = stream._active._to_ref(client)
+
+        fut = (port, ambient_mesh._ndslice)
+
+        ident = client.new_node(mutates, dtensors, cast("OldFuture", fut))
+
+        client.send(
+            ambient_mesh._ndslice,
+            messages.SendValue(
+                ident,
+                None,
+                mutates,
+                preprocess_message,
+                args,
+                kwargs,
+                stream_ref,
+            ),
+        )
+        # we have to ask for status updates
+        # from workers to be sure they have finished
+        # enough work to count this future as finished,
+        # and all potential errors have been reported
+        client._request_status()
+        return Extent(ambient_mesh._labels, ambient_mesh._ndslice.sizes)
+
+    def _port(self, once: bool = False) -> "PortTuple[R]":
+        ambient_mesh = device_mesh._active
+        if ambient_mesh is None:
+            raise ValueError(
+                "FIXME - cannot create a port without an active proc_mesh, because there is not way to create a port without a mailbox"
+            )
+        mesh_controller = getattr(ambient_mesh.client, "_mesh_controller", None)
+        if mesh_controller is None:
+            raise ValueError(
+                "Cannot create raw port objects with an old-style tensor engine controller."
+            )
+        mailbox: Mailbox = mesh_controller._mailbox
+        return PortTuple.create(mailbox, once)
 
     @property
     def _resolvable(self):
         return resolvable_function(self._remote_impl)
 
-
-
-
-            self._cache = {}
-            return _cached_propagation(self._cache, self._resolvable, args, kwargs)
-        elif self._propagator_arg == "inspect":
-            return None
-        elif self._propagator_arg == "mocked":
-            raise NotImplementedError("mocked propagation")
-        else:
-            return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)
-
-    def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs):
-        if self._propagator_arg is None:
-            return  # no propgator provided, so we just assume no mutations
-        return self._propagate(args, kwargs, fake_args, fake_kwargs)
-
-    def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
-        if not callable(self._propagator_arg):
-            raise ValueError("Must specify explicit callable for pipe")
-        return self._propagate(args, kwargs, fake_args, fake_kwargs)
+    @property
+    def _maybe_resolvable(self):
+        return None if self._remote_impl is None else self._resolvable
 
-    def
+    def _rref(self, args, kwargs):
         return dtensor_dispatch(
             self._resolvable,
             self._propagate,
@@ -107,12 +167,8 @@ class Remote(Generic[P, R]):
             stream._active,
         )
 
-    def
-        self
-    ) -> Future[R]:
-        return _call_on_shard_and_fetch(
-            self._resolvable, self._fetch_propagate, *args, shard=shard, **kwargs
-        )
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+        return self.rref(*args, **kwargs)
 
 
 # This can't just be Callable because otherwise we are not
@@ -151,14 +207,43 @@ def remote(function: Any = None, *, propagate: Propagator = None) -> Any:
     return Remote(function, propagate)
 
 
-
-
-
+remote_identity = Remote(None, lambda x: x)
+
+
+def call_on_shard_and_fetch(
+    remote: Endpoint[P, R], *args, shard: Dict[str, int] | None = None, **kwargs
+) -> OldFuture[R]:
+    # We have to flatten the tensors twice: first to discover
+    # which mesh we are working on to shard it, and then again when doing the
+    # dtensor_check in send. This complexity is a consequence of doing
+    # implicit inference of the mesh from the tensors.
+    dtensors, unflatten = flatten((args, kwargs), lambda x: isinstance(x, torch.Tensor))
+    with InputChecker.from_flat_args(
+        remote._call_name(), dtensors, unflatten
+    ) as checker:
+        checker.check_mesh_stream_local(device_mesh._active, stream._active)
+
+        if not hasattr(checker.mesh.client, "_mesh_controller"):
+            return _old_call_on_shard_and_fetch(
+                cast("Remote[P, R]", remote),
+                *args,
+                shard=shard,
+                **kwargs,
+            )
+
+        selected_slice = checker.mesh._process(shard)
+        shard_mesh = checker.mesh._new_with_shape(Shape(["_"], selected_slice))
+        with shard_mesh.activate():
+            return cast("OldFuture[R]", remote.call_one(*args, **kwargs))
+
+
+def _old_call_on_shard_and_fetch(
+    remote_obj: Remote[P, R],
     /,
     *args: object,
     shard: dict[str, int] | None = None,
     **kwargs: object,
-) ->
+) -> OldFuture[R]:
     """
     Call `function` at the coordinates `shard` of the current device mesh, and retrieve the result as a Future.
     function - the remote function to call
@@ -166,6 +251,9 @@ def _call_on_shard_and_fetch(
     shard - a dictionary from mesh dimension name to coordinate of the shard
         If None, this will fetch from coordinate 0 for all dimensions (useful after all_reduce/all_gather)
     """
+
+    rfunction = remote_obj._maybe_resolvable
+    propagator = remote_obj._fetch_propagate
     ambient_mesh = device_mesh._active
 
     if rfunction is None:
@@ -180,15 +268,9 @@ def _call_on_shard_and_fetch(
     client: "Client" = mesh.client
     if _coalescing.is_active(client):
         raise NotImplementedError("NYI: fetching results during a coalescing block")
+    stream_ref = stream._active._to_ref(client)
     return client.fetch(
-        mesh,
-        stream._active._to_ref(client),
-        shard,
-        preprocess_message,
-        args,
-        kwargs,
-        mutates,
-        dtensors,
+        mesh, stream_ref, shard, preprocess_message, args, kwargs, mutates, dtensors
     )
 
 
@@ -270,8 +352,9 @@ _miss = 0
 _hit = 0
 
 
-def _cached_propagation(_cache, rfunction, args, kwargs):
+def _cached_propagation(_cache, rfunction: ResolvableFunction, args, kwargs):
     tensors, shape_key = hashable_tensor_flatten(args, kwargs)
+    # pyre-ignore
     inputs_group = TensorGroup([t._fake for t in tensors])
     requires_grads = tuple(t.requires_grad for t in tensors)
     key = (shape_key, inputs_group.pattern, requires_grads)
@@ -280,8 +363,8 @@ def _cached_propagation(_cache, rfunction, args, kwargs):
     if key not in _cache:
         _miss += 1
         args_no_pg, kwargs_no_pg = tree_map(_mock_pgs, (args, kwargs))
-        result_with_placeholders, output_pattern =
-            function=rfunction, args=args_no_pg, kwargs=kwargs_no_pg
+        result_with_placeholders, output_pattern = call_on_shard_and_fetch(
+            _propagate, function=rfunction, args=args_no_pg, kwargs=kwargs_no_pg
         ).result()
 
     _, unflatten_result = flatten(
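Note: the net effect of this file's changes is that Remote now implements the shared Endpoint interface from monarch._src.actor.endpoint, a plain call routes through rref, and the previously private fetch helper becomes the public call_on_shard_and_fetch, which falls back to _old_call_on_shard_and_fetch whenever the client lacks a _mesh_controller. A hedged usage sketch follows; it assumes torchmonarch is installed and a tensor-engine mesh is active, so the mesh-dependent calls are left commented:

import torch
from monarch.common.remote import call_on_shard_and_fetch, remote

@remote  # wraps the function in a Remote endpoint (no explicit propagator)
def double(x: torch.Tensor) -> torch.Tensor:
    return x + x

# with some_device_mesh.activate():                      # mesh setup elided
#     y = double(x)                                      # __call__ routes via rref
#     fut = call_on_shard_and_fetch(double, x, shard={"host": 0})
#     local = fut.result()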
monarch/common/tensor.py
CHANGED
@@ -40,12 +40,13 @@ from .borrows import StorageAliases
 if TYPE_CHECKING:
     from monarch.common.device_mesh import DeviceMesh
 
+    from monarch._src.actor.shape import NDSlice
+
 from .fake import fake_call
 from .function import Propagator, ResolvableFunction
 from .invocation import Invocation
 from .messages import Dims
 from .reference import Referenceable
-from .shape import NDSlice
 from .stream import Stream
 from .tree import flatten
 
monarch/controller/backend.py
CHANGED
@@ -13,9 +13,9 @@ import socket
 from abc import ABC, abstractmethod
 from typing import List, NamedTuple, Optional, Sequence, Tuple
 
-from monarch.
+from monarch._src.actor.shape import iter_ranks, Slices as Ranks
 
-from monarch.common
+from monarch.common import messages
 from monarch_supervisor import (
     Context,
     FunctionCall,
monarch/controller/controller.py
CHANGED
@@ -19,11 +19,12 @@ from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarc
     ActorId,
 )
 
+from monarch._src.actor.shape import NDSlice
+
 from monarch.common import messages
 from monarch.common.controller_api import LogMessage, MessageResult
 from monarch.common.invocation import DeviceException, Seq
 from monarch.common.reference import Ref
-from monarch.common.shape import NDSlice
 from monarch.common.tensor import Tensor
 from monarch.controller import debugger
 
monarch/controller/rust_backend/controller.py
CHANGED
@@ -29,11 +29,12 @@ from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarc
 )
 
 from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
+
+from monarch._src.actor.shape import NDSlice
 from monarch.common.controller_api import LogMessage, MessageResult
 from monarch.common.device_mesh import no_mesh
 from monarch.common.invocation import DeviceException, RemoteException
 from monarch.common.messages import SupportsToRustMessage
-from monarch.common.shape import NDSlice
 from monarch.common.tensor import Tensor
 from monarch.controller.debugger import read as debugger_read, write as debugger_write
 from pyre_extensions import none_throws
monarch/fetch.py
CHANGED
@@ -9,13 +9,13 @@
 This is a utility file for fetching a shard of a tensor from remote.
 """
 
-from typing import TypeVar
+from typing import cast, TypeVar
 
 from monarch.common.device_mesh import no_mesh
 
 from monarch.common.future import Future
 
-from monarch.common.remote import
+from monarch.common.remote import call_on_shard_and_fetch, remote_identity
 
 T = TypeVar("T")
 
@@ -37,9 +37,7 @@ def fetch_shard(
     shard = {}
     shard.update(kwargs)
 
-    return
-        None, lambda *args, **kwargs: None, obj, shard=shard
-    )
+    return cast("Future[T]", call_on_shard_and_fetch(remote_identity, obj, shard=shard))
 
 
 def show(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> object:
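Note: fetch_shard keeps its public signature; only the plumbing changed. Instead of handing a bare None/lambda pair to the old private helper, it now goes through the public call_on_shard_and_fetch with the new remote_identity endpoint, cast back to the old Future type. Caller-facing behavior is unchanged, as in this sketch (requires an active mesh, so the calls are commented):

from monarch.fetch import fetch_shard

# t: a monarch distributed tensor created under an active device mesh
# fut = fetch_shard(t, shard={"host": 0, "gpu": 1})  # pick one shard's value
# local = fut.result()                               # fetched to the client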
monarch/gradient/_gradient_generator.so
CHANGED
Binary file
|