torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,48 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import io
|
8
|
+
import pickle
|
9
|
+
from typing import Any, Callable, Iterable, List, Tuple
|
10
|
+
|
11
|
+
import cloudpickle
|
12
|
+
|
13
|
+
|
14
|
+
class _Pickler(cloudpickle.Pickler):
    """cloudpickle-based pickler that lifts selected objects out of the stream.

    Every object for which ``filter(obj)`` is truthy is not serialized into
    the byte stream; instead it is appended to ``self._saved`` and the pickle
    records a persistent id equal to its index in that list. The pickled
    bytes accumulate in the in-memory buffer ``self.f``.
    """

    def __init__(self, filter):
        buffer = io.BytesIO()
        self.f = buffer
        super().__init__(buffer)
        self._filter = filter
        self._saved = []

    def persistent_id(self, obj):
        # Returning None tells pickle to serialize ``obj`` normally.
        if self._filter(obj):
            slot = len(self._saved)
            self._saved.append(obj)
            return slot
        return None
|
26
|
+
|
27
|
+
|
28
|
+
class _Unpickler(pickle.Unpickler):
    """Inverse of ``_Pickler``: resolves persistent ids from a value sequence.

    Persistent id ``i`` in the pickle is replaced by the ``i``-th item drawn
    from ``sequence``. Items are consumed lazily and cached, so an id may be
    referenced more than once and in any order.
    """

    def __init__(self, data, sequence: Iterable[Any]):
        super().__init__(io.BytesIO(data))
        self._iter = iter(sequence)
        self._values = []

    def persistent_load(self, id):
        # Pull just enough items from the sequence to reach index ``id``.
        while id >= len(self._values):
            try:
                self._values.append(next(self._iter))
            except StopIteration:
                # Never let StopIteration escape load(): inside a generator it
                # would be turned into an opaque RuntimeError (PEP 479).
                raise ValueError(
                    f"persistent id {id} requested but only "
                    f"{len(self._values)} values were provided"
                ) from None
        return self._values[id]
|
38
|
+
|
39
|
+
|
40
|
+
def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], bytes]:
    """Serialize ``obj``, extracting every sub-object selected by ``filter``.

    Returns ``(extracted, payload)``: ``extracted`` lists the objects for
    which ``filter`` returned a truthy value, in the order the pickler
    visited them, and ``payload`` is the pickled byte string in which each
    extracted object is replaced by its index in ``extracted``. Pass both
    to ``unflatten`` to rebuild the object.
    """
    writer = _Pickler(filter)
    writer.dump(obj)
    extracted = writer._saved
    payload = writer.f.getvalue()
    return extracted, payload
|
44
|
+
|
45
|
+
|
46
|
+
def unflatten(data: bytes, values: Iterable[Any]) -> Any:
    """Rebuild an object from a ``flatten`` payload.

    ``data`` is the pickled payload; ``values`` supplies, in index order, the
    objects to substitute for its persistent ids.
    """
    return _Unpickler(data, values).load()
|
monarch/common/pipe.py
ADDED
@@ -0,0 +1,152 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import uuid
|
9
|
+
from collections import deque
|
10
|
+
from typing import Any, Dict
|
11
|
+
|
12
|
+
import torch
|
13
|
+
from monarch.common.remote import Remote, remote
|
14
|
+
|
15
|
+
from . import device_mesh, messages, stream
|
16
|
+
from .fake import fake_call
|
17
|
+
from .function import ResolvableFunctionFromPath
|
18
|
+
from .reference import Referenceable
|
19
|
+
from .tensor import dtensor_check, Tensor
|
20
|
+
from .tree import flatten
|
21
|
+
|
22
|
+
|
23
|
+
def remote_generator(path: str, max_messages: int = 50):
    """Decorator factory turning a propagation annotation into a pipe factory.

    ``path`` names the remote generator function; the decorated function is
    passed to ``remote`` as its ``propagate`` annotation. Calling the
    decorated result with ``*args, **kwargs`` forwards them, together with
    ``max_messages``, to ``create_pipe``.
    """

    def decorate(annotation):
        remote_fn = remote(path, propagate=annotation)

        def make_pipe(*args, **kwargs):
            return create_pipe(
                remote_fn, *args, max_messages=max_messages, **kwargs
            )

        return make_pipe

    return decorate
|
31
|
+
|
32
|
+
|
33
|
+
def create_pipe(fn, *args, max_messages: int = 50, **kwargs):
    """Build a ``Pipe`` for remote generator ``fn`` with the given arguments.

    Positional and keyword arguments are collected and handed to ``Pipe``
    as ``args``/``kwargs``; ``max_messages`` is passed through unchanged.
    """
    pipe = Pipe(fn, max_messages, args, kwargs)
    return pipe
|
35
|
+
|
36
|
+
|
37
|
+
class Pipe(Referenceable):
    """
    Pipe abstraction on the controller. Designed to be used with ipc PAIR sockets,
    e.g. dataloaders and trainers.

    Example::
        @remote_generator('dataloader.main')
        def dataloader_pipe(pipe: Pipe, batch_size: int, sequence_length: int):
            while True:
                yield {
                    'input': torch.zeros(batch_size, sequence_length),
                    'target': torch.zeros(batch_size)
                }

        # On the controller
        with mesh.activate():
            dataloader = dataloader_pipe(1, 1)
            batch = dataloader.recv()
            input, target = batch['input'], batch['target']
    """

    def __init__(self, fn: Remote, max_messages: int, args, kwargs):
        """Create the pipe on the currently active device mesh.

        Args:
            fn: the remote generator function; must be a ``Remote``.
            max_messages: forwarded in the ``CreatePipe`` message.
            args, kwargs: arguments for the generator; they must not contain
                ``Referenceable`` objects.

        Raises:
            ValueError: if no device mesh is active, or if ``no_references``
                rejects an element of args/kwargs.
            TypeError: if ``fn`` is not a ``Remote``.
        """
        # Start with ref = None so Referenceable.__del__ is a no-op if
        # construction fails before a remote reference is allocated.
        super().__init__()
        mesh = device_mesh._active
        if mesh is None:
            raise ValueError(
                "Remote generators require an active device mesh (use `with mesh.activate():`)"
            )
        mesh.define_remotely()

        def no_references(x):
            if isinstance(x, Referenceable):
                raise ValueError("Cannot pass references to external generators")

        # Called only for its side effect of running no_references over the
        # contents of args/kwargs; the flattened result is not needed.
        flatten((args, kwargs), no_references)
        self._fake_pipe = FakePipe()
        if not isinstance(fn, Remote):
            raise TypeError("expected fn to be a monarch.remote function.")
        args_ = (self._fake_pipe, *args)
        # we do not pass references to generators so fake_args == args
        self._iterator = iter(fn._pipe_propagate(args_, kwargs, args_, kwargs))
        # Set mesh before ref: once ref is non-None, __del__ -> delete_ref
        # reads self.mesh.
        self.mesh = mesh
        self.ref = mesh.client.new_ref()
        key = f"ipc:///tmp/proc-{uuid.uuid4()}"
        self.mesh._send(
            messages.CreatePipe(
                self, key, fn._resolvable, max_messages, mesh, args, kwargs
            )
        )

    def send(self, obj: Any):
        """Send ``obj`` from the controller to the pipe's remote side.

        The fake result computed by ``dtensor_check`` is also queued on the
        fake pipe so the local fake generator can observe it via ``recv``.

        Raises:
            ValueError: if ``dtensor_check`` reports a mesh other than this
                pipe's mesh for the inputs.
        """
        client = self.mesh.client
        _fake_result, dtensors, _mutates, obj_mesh = dtensor_check(
            (lambda args, kwargs, fake_args, fake_kwargs: fake_args[0]),
            ResolvableFunctionFromPath("ident"),
            (obj,),
            {},
            self.mesh,
            stream._active,
        )
        if self.mesh is not obj_mesh:
            raise ValueError(
                f"Pipe is defined on mesh {self.mesh} but inputs are defined on mesh {obj_mesh}"
            )
        self._fake_pipe._fake_sends.append(_fake_result)
        seq = client.new_node((), dtensors)
        self.mesh._send(
            messages.SendValue(
                seq, self, (), None, (obj,), {}, stream._active._to_ref(client)
            )
        )

    def recv(self) -> Any:
        """Receive the next value produced by the remote generator.

        The local fake generator is advanced under ``fake_call`` to learn the
        structure of the result; each ``torch.Tensor`` in it is replaced by a
        controller-side ``Tensor`` on the current stream, and a ``PipeRecv``
        message naming those tensors is sent to the mesh.
        """
        mesh = self.mesh
        fake_result = fake_call(next, self._iterator)
        fake_result_tensors, unflatten = flatten(
            fake_result, lambda x: isinstance(x, torch.Tensor)
        )
        tensors = tuple(
            Tensor(fake, mesh, stream._active) for fake in fake_result_tensors
        )
        seq = mesh.client.new_node(tensors, ())
        result = unflatten(tensors)
        mesh._send(
            messages.PipeRecv(seq, result, self, stream._active._to_ref(mesh.client))
        )
        return result

    def delete_ref(self, ref: int):
        if not self.mesh.client._shutdown:
            self.mesh.client.handle_deletes(self.mesh.processes, [ref])

    # make typechecking happy for actual process functions
    @property
    def ranks(self) -> Dict[str, int]:
        raise ValueError("cannot be accessed on controller")

    @property
    def sizes(self) -> Dict[str, int]:
        raise ValueError("cannot be accessed on controller")
|
134
|
+
|
135
|
+
|
136
|
+
class FakePipe(Pipe):
    """
    Controller-side stand-in for the pipe handed to the fake generator.

    ``Pipe.send`` queues the fake version of each value it sends here, so
    the fake generator can read it back with ``recv``.
    """

    def __init__(self):
        # No remote reference, so Referenceable.__del__ never calls delete_ref.
        self.ref = None
        self._fake_sends = deque[Any]()

    def send(self, obj: Any):
        raise RuntimeError(
            "Rather than p.send(x) use yield x to simulate a pipe worker sending data."
        )

    def recv(self):
        """Pop the oldest queued fake value, or return None if none is queued."""
        queue = self._fake_sends
        return queue.popleft() if queue else None
|
@@ -0,0 +1,55 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
|
9
|
+
import logging
|
10
|
+
|
11
|
+
import torch.distributed as dist
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
def _wrap_method(process_group: dist.ProcessGroup, method):
    """Return a callable forwarding to ``getattr(process_group, method)``.

    Every call is logged at DEBUG level. If the underlying call raises an
    ``Exception``, the failure is logged at WARNING level and the same
    exception is re-raised.
    """

    def wrapper(*args, **kwargs):
        logger.debug(
            "ProcessGroup Call: %s with args %s and kwargs %s", method, args, kwargs
        )
        # Resolved on each call, so the wrapper always uses the group's
        # current attribute.
        fn = getattr(process_group, method)
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            logger.warning(
                "ProcessGroup Call: %s with args %s and kwargs %s failed with exception: %s",
                method,
                args,
                kwargs,
                str(e),
            )
            # TODO(rajeshn): send a message back to the controller that this
            # worker had a failed communication event
            # Bare raise re-raises with the original traceback, without
            # adding this line as an extra frame.
            raise

    return wrapper
|
37
|
+
|
38
|
+
|
39
|
+
class SingleControllerProcessGroupWrapper:
    """
    Wraps a ProcessGroup object to provide a single controller process group. This provides us a hook to observe
    all the operations on the process group to the controller.

    Every callable attribute of the wrapped group's type whose name does not
    start with ``__`` is installed on the wrapper instance as a forwarder
    built by ``_wrap_method``; the wrapped group is kept in ``process_group``.
    """

    def __new__(cls, process_group: dist.ProcessGroup):
        # Constructor arguments are passed to both __new__ and __init__, so the
        # parameter names must agree for keyword construction
        # (``SingleControllerProcessGroupWrapper(process_group=pg)``) to work.
        instance = super().__new__(cls)
        pg_type = type(process_group)
        for attr in dir(pg_type):
            if not attr.startswith("__") and callable(getattr(pg_type, attr)):
                setattr(instance, attr, _wrap_method(process_group, attr))
        return instance

    def __init__(self, process_group):
        self.process_group = process_group
|
@@ -0,0 +1,127 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import logging
|
9
|
+
import traceback
|
10
|
+
from collections import defaultdict
|
11
|
+
from typing import cast, Dict, Generator, List, NamedTuple, Tuple, TYPE_CHECKING, Union
|
12
|
+
|
13
|
+
from monarch.common.reference import Ref
|
14
|
+
|
15
|
+
from monarch.common.shape import iter_ranks
|
16
|
+
|
17
|
+
from monarch.common.tensor import InputChecker
|
18
|
+
|
19
|
+
from . import messages
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
from monarch.common.client import Client
|
23
|
+
|
24
|
+
from .reference import Referenceable
|
25
|
+
from .shape import NDSlice
|
26
|
+
from .tensor import Tensor
|
27
|
+
|
28
|
+
logger = logging.getLogger(__name__)
|
29
|
+
|
30
|
+
_MAX_MESSAGES_PER_DEFINE_RECORDING = 1000
|
31
|
+
|
32
|
+
|
33
|
+
def flatten_messages(
    messages: List[Tuple[Union[NDSlice, List[NDSlice]], NamedTuple]],
) -> Dict[int, List[NamedTuple]]:
    """Group buffered messages by the individual rank they target.

    Each input entry pairs a rank selection (expanded with ``iter_ranks``)
    with a message. The result maps every rank to the messages addressed to
    it, preserving input order.
    """
    per_rank: Dict[int, List[NamedTuple]] = defaultdict(list)
    for target_ranks, message in messages:
        for rank in iter_ranks(target_ranks):
            per_rank[rank].append(message)
    return per_rank
|
41
|
+
|
42
|
+
|
43
|
+
class Recording(Referenceable):
    """A recorded sequence of worker messages that can be replayed with ``run``.

    On construction the buffered messages are grouped per rank and sent to
    each rank as ``DefineRecording`` messages, in chunks of at most
    ``_MAX_MESSAGES_PER_DEFINE_RECORDING`` messages. ``run`` then sends a
    single ``CallRecording`` to every rank that received messages.
    """

    def __init__(
        self,
        client: "Client",
        uses: List["Tensor"],
        mutates: List["Tensor"],
        mutated_formal_indices: List[int],
        tracebacks: List[List[traceback.FrameSummary]],
        buffered_messages: List[Tuple[Union[NDSlice, List[NDSlice]], NamedTuple]],
        nresults: int,
        nformals: int,
        first_ref: int,
    ):
        # ref starts as None so Referenceable.__del__ is a no-op if anything
        # below fails before a reference has been allocated.
        super().__init__()
        self.uses = uses
        self.mutates = mutates
        # on future invocations of this recording, new aliases for our mutated tensors exists
        # and we will technically mutate them as well. This would be simplified and faster if our
        # node tracking worked with storages rather than tensors, but for now we have to collect
        # all the aliases on each invocation
        self.mutate_aliases = [m._aliases.aliases for m in self.mutates]
        self.mutated_formal_indices = mutated_formal_indices
        self.tracebacks = tracebacks
        self.first_ref = first_ref
        self.client = client
        self.buffered_messages = buffered_messages
        flat_messages = flatten_messages(self.buffered_messages)
        # Every rank that received at least one message.
        self.ranks = NDSlice.from_list(sorted(flat_messages.keys()))
        # Allocate the ref only once client and ranks exist: delete_ref
        # (reached from __del__) needs both. Nothing above allocates refs,
        # so the ref value is the same as allocating it earlier.
        self.ref = client.new_ref()
        chunk = _MAX_MESSAGES_PER_DEFINE_RECORDING
        for rank, msgs in flat_messages.items():
            ndslice = NDSlice(offset=rank, sizes=[], strides=[])
            # Number of DefineRecording chunks for this rank (ceil division).
            ntotal_messages = (len(msgs) + chunk - 1) // chunk
            for enum_index, msg_index in enumerate(range(0, len(msgs), chunk)):
                self.client.send_nocoalesce(
                    ndslice,
                    messages.DefineRecording(
                        self,
                        nresults,
                        nformals,
                        msgs[msg_index : msg_index + chunk],  # noqa: E203
                        ntotal_messages,
                        enum_index,
                    ),
                )

    def run(self, results: Generator[Tensor, None, None], actuals: List[Tensor]):
        """Invoke the recording with ``actuals`` as its formal arguments.

        ``results`` is consumed only after the input checks have completed;
        the materialized list of result tensors is returned.
        """
        all_uses: List[Tensor] = [*self.uses, *actuals]
        with InputChecker.from_flat_args(
            "recording", all_uses, lambda ts: (tuple(ts), {})
        ) as checker:
            mutates_actuals = [
                actuals[i]._aliases.aliases for i in self.mutated_formal_indices
            ]
            mutates = list(set().union(*self.mutate_aliases, *mutates_actuals))
            checker.check_permission(mutates)
        # we are careful to not generate the results tensors until
        # after the input checker so that we do not create tensor objects
        # for tensors that will never be defined by CallRecording
        results_tuple = list(results)
        seq = self.client.new_node(
            results_tuple + mutates,
            all_uses,
            None,
            self.tracebacks,
        )
        self.client.send(
            self.ranks,
            messages.CallRecording(
                seq,
                self,
                cast(List[Tensor | Ref], results_tuple),
                cast(List[Tensor | Ref], actuals),
            ),
        )
        return results_tuple

    def delete_ref(self, ref: int):
        if not self.client.has_shutdown:
            self.client.handle_deletes(self.ranks, [ref])
|
@@ -0,0 +1,33 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
from typing import Optional
|
9
|
+
|
10
|
+
from monarch._rust_bindings.monarch_extension.tensor_worker import Ref
|
11
|
+
|
12
|
+
|
13
|
+
class Referenceable:
    """Base class for controller-side objects identified by an integer ref.

    When pickled, an instance is replaced by ``Ref(self.ref)``. When it is
    finalized with a ref assigned, ``delete_ref`` is called with that ref;
    subclasses must override ``delete_ref``.
    """

    def __init__(self):
        self.ref: Optional[int] = None

    def delete_ref(self, ref):
        raise NotImplementedError("no delete_ref method")

    def __reduce_ex__(self, protocol):
        assert (
            self.ref is not None
        ), f"{self} is being sent but does not have a reference"
        return Ref, (self.ref,)

    # Used by rust backend to get the ref for this object
    def __monarch_ref__(self) -> int:
        assert self.ref is not None
        return self.ref

    def __del__(self):
        # A subclass __init__ that raised early, or never called
        # super().__init__(), may not have set ``ref``. Without getattr, the
        # AttributeError would be reported as "Exception ignored in __del__".
        ref = getattr(self, "ref", None)
        if ref is not None:
            self.delete_ref(ref)
|