torchmonarch-nightly 2025.6.27 (cp313-cp313-manylinux2014_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/_coalescing.py
@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import functools
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    TYPE_CHECKING,
)

import torch
from monarch.common import messages

from monarch.common.fake import fake_call
from monarch.common.function_caching import (
    hashable_tensor_flatten,
    TensorGroup,
    TensorGroupPattern,
)
from monarch.common.tensor import InputChecker, Tensor
from monarch.common.tree import flatten

if TYPE_CHECKING:
    from monarch.common.client import Recorder
    from monarch.common.recording import Recording

    from .client import Client

_coalescing = None


class CoalescingState:
    def __init__(self, recording=False):
        self.controller: Optional["Client"] = None
        self.recorder: Optional["Recorder"] = None
        self.recording = recording

    def set_controller(self, controller: "Client"):
        if self.controller is None:
            self.controller = controller
            controller.flush_deletes(False)
        if self.controller is not controller:
            raise ValueError(
                "using multiple controllers in the same coalescing block is not supported"
            )

    @contextmanager
    def activate(self) -> Generator[None, Any, Any]:
        global _coalescing
        assert _coalescing is None
        finished = False
        try:
            _coalescing = self
            yield
            finished = True
        finally:
            ctrl = self.controller
            if ctrl is not None:
                if finished:
                    ctrl.flush_deletes()
                self.recorder = ctrl.reset_recorder()
                if not finished:
                    self.recorder.abandon()
            _coalescing = None


@contextmanager
def coalescing() -> Generator[None, Any, Any]:
    global _coalescing
    if _coalescing is not None:
        yield
        return

    state = CoalescingState()
    with state.activate():
        yield

    if state.recorder is not None:
        assert state.controller is not None
        state.recorder.run_once(state.controller)


def _record_and_define(
    fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> "CacheEntry":
    input_tensors, unflatten_input = flatten(
        (args, kwargs), lambda x: isinstance(x, Tensor)
    )

    with InputChecker.from_flat_args(
        "compile", input_tensors, unflatten_input
    ) as checker:
        checker.check_no_requires_grad()

    for a in input_tensors:
        assert a._seq is not None

    state = CoalescingState(recording=True)
    with state.activate():
        formal_tensors = []
        for i, input in enumerate(input_tensors):
            state.set_controller(input.mesh.client)
            t = Tensor(input._fake, input.mesh, input.stream)
            input.mesh._send(
                messages.RecordingFormal(t, i, t.stream._to_ref(input.mesh.client))
            )
            formal_tensors.append(t)
        formal_args, formal_kwargs = unflatten_input(formal_tensors)
        recorded_result = fn(*formal_args, **formal_kwargs)
        output_tensors, unflatten_result = flatten(
            recorded_result, lambda x: isinstance(x, Tensor)
        )
        with InputChecker(
            output_tensors,
            lambda ts: f"{unflatten_result(ts)} = compiled_function(...)",
        ) as checker:
            checker.check_no_requires_grad()
        for i, output in enumerate(output_tensors):
            state.set_controller(output.mesh.client)
            output.mesh._send(
                messages.RecordingResult(
                    output, i, output.stream._to_ref(output.mesh.client)
                )
            )

    recorder = state.recorder
    if recorder is None:
        # no input tensors or output tensors, so just cache the result
        return CacheEntry(
            TensorGroup([]),
            TensorGroupPattern(()),
            lambda args, kwargs: recorded_result,
            None,
        )

    controller = state.controller
    assert controller is not None
    recorder.add((), output_tensors, [])
    recording = recorder.define_recording(
        controller, len(output_tensors), len(input_tensors)
    )

    fake_uses = [r._fake for r in recording.uses]
    captures_group = TensorGroup(fake_uses)
    inputs_group = TensorGroup([i._fake for i in input_tensors], parent=captures_group)

    outputs_group = TensorGroup([o._fake for o in output_tensors], parent=inputs_group)
    outputs_pattern = outputs_group.pattern

    def run(args, kwargs):
        actuals, _ = flatten((args, kwargs), lambda x: isinstance(x, Tensor))
        for a in actuals:
            assert a._seq is not None

        fake_result_tensors = fake_call(
            outputs_pattern.empty, [fake_uses, [a._fake for a in actuals]]
        )

        # recording.run does permissions checks on all the tensors.
        # if those checks fail then the tensors here will have been created
        # but not defined, causing spurious delete messages.
        # To avoid this, we pass a generator rather than a list
        # and only create the tensors in run
        result_tensors_generator = (
            Tensor(f, o.mesh, o.stream)
            for f, o in zip(fake_result_tensors, output_tensors)
        )
        return unflatten_result(recording.run(result_tensors_generator, actuals))

    return CacheEntry(captures_group, inputs_group.pattern, run, recording)


@dataclass
class CacheEntry:
    captures_group: TensorGroup
    inputs_pattern: TensorGroupPattern
    run: Callable[[Tuple[Any, ...], Dict[str, Any]], Any]
    to_verify: Optional["Recording"]

    def matches(self, input_tensors: List[torch.Tensor]) -> bool:
        # if an input aliases a captured tensor, then we have
        # to check that all future inputs alias the _same exact_
        # captured tensor. These are additional checks after
        # matching on the pattern of aliasing for just the inputs, because
        # we do not know what the captures would be without first matching
        # the inputs without the captures.
        inputs_group = TensorGroup(input_tensors, parent=self.captures_group)
        return self.inputs_pattern == inputs_group.pattern


def compile(fn=None, verify=True):
    """
    Wraps `fn` such that it records and later replays a single message to workers
    to instruct them to run the entire contents of this function. Since the function invocation
    is much smaller than the original set of messages, and since we do not re-execute the python inside
    the function after recording, this has substantially lower latency.

    While eventually `compile` will be backed by `torch.compile`'s dynamo executor, it currently
    works as a simple tracer with the following rules for when it chooses to trace vs. when
    it will reuse an existing trace.

    A new trace is created whenever:

    * The _values_ of a non-tensor argument to fn have not been seen before.
    * The _metadata_ of a tensor argument has not been seen before. Metadata includes the sizes, strides,
      dtype, devices, layout, device meshes, streams, and pattern of aliasing of the arguments
      with respect to other arguments and any values the trace captures.

    A new trace will not be created in the following situations, which are known to be **unsafe**:

    * A value that is not an argument to the function but is used by the function (e.g. a global)
      changes in a way that would affect what messages are being sent.
    * A tensor that is not an argument to the function changes metadata, or gets reassigned to
      a new tensor in Python.

    The trace is allowed to use tensors that are referenced in the body but not listed as arguments,
    such as globals or closure-captured locals, as long as these values are not modified in
    the ways that are listed as unsafe above. When switched to a torch.compile-backed version,
    these safety caveats will be improved.

    Compilation currently does not work if the inputs or outputs to the function have `requires_grad=True`,
    because we will not generate a correct backwards-pass graph. However, captured tensors
    are allowed to be requires_grad=True, and gradient calculation (forward+backward)
    can run entirely within the function.

    Can be used as a wrapper:

        wrapped = compile(my_function, verify=False)

    Or as a decorator:

        @compile
        def my_function(...):
            ...

        @compile(verify=False)
        def my_function(...):
            ...

    Args:
        fn (callable): the function to be wrapped. (Default: None, in which case we return a
            single-argument function that can be used as a decorator)
        verify (bool): To guard as much as possible against the above unsafe situations,
            if `verify=True`, the first time we would reuse a trace, we additionally do another
            recording, check that the second recording matches the original, and report
            where they diverge. (Default: True)

    Returns:
        If fn=None, it returns a function that can be used as a decorator on the function to
        be wrapped. Otherwise, it returns the wrapped function itself.
    """
    if fn is None:
        return lambda fn: compile(fn, verify)
    cache: Dict[Any, List[CacheEntry]] = defaultdict(list)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        global _coalescing
        if _coalescing:
            return fn(*args, **kwargs)

        tensors, shape_key = hashable_tensor_flatten(args, kwargs)
        input_group = TensorGroup([t._fake for t in tensors])
        props = tuple((t.mesh, t.stream, t.requires_grad) for t in tensors)
        key = (shape_key, input_group.pattern, props)
        for entry in cache[key]:
            if entry.matches(input_group.tensors):
                if entry.to_verify is not None:
                    entry.to_verify.client.recorder.verify_against(entry.to_verify)
                    _record_and_define(fn, args, kwargs)
                    entry.to_verify = None
                return entry.run(args, kwargs)

        entry = _record_and_define(fn, args, kwargs)
        if not verify:
            entry.to_verify = None
        cache[key].append(entry)
        return entry.run(args, kwargs)

    return wrapper


def is_active(controller: "Client"):
    if _coalescing is None:
        return False
    _coalescing.set_controller(controller)
    return True


def is_recording(controller: "Client"):
    return is_active(controller) and _coalescing.recording

monarch/common/_device_utils.py
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import re
from pathlib import Path


def _local_device_count():
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        return len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
    dev_path = Path("/dev")
    pattern = re.compile(r"nvidia\d+$")
    nvidia_devices = [dev for dev in dev_path.iterdir() if pattern.match(dev.name)]
    return len(nvidia_devices)

monarch/common/_tensor_to_table.py
@@ -0,0 +1,172 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, List, Optional

import torch


def tensor_to_table(
    tensor: torch.Tensor,
    format_data: Callable,
    axis_labels: Optional[List[List[str]]] = None,
    axis_names: Optional[List[str]] = None,
    format_spec: str = ".4f",
    table_format: str = "grid",
) -> str:
    """
    Convert a tensor into formatted tables with generic dimension handling.

    Parameters:
    -----------
    tensor : torch.Tensor
        Input tensor to be converted (1D, 2D, or 3D); must have dtype torch.int
    format_data : Callable
        Function applied to each element to produce its cell text
    axis_labels : list of lists, optional
        Labels for each axis, ordered from outer to inner dimension
        For 1D: [column_labels]
        For 2D: [row_labels, column_labels]
        For 3D: [depth_labels, row_labels, column_labels]
    axis_names : list, optional
        Names for each axis, ordered from outer to inner dimension
        For 1D: [column_name]
        For 2D: [row_name, column_name]
        For 3D: [depth_name, row_name, column_name]
    format_spec : str, optional
        Format specification for numbers (default: ".4f")
    table_format : str, optional
        Table format style for tabulate (default: "grid")

    Returns:
    --------
    str : Formatted table string
    """
    import numpy as np
    from tabulate import tabulate

    assert tensor.dtype == torch.int
    # Convert tensor to numpy for easier manipulation
    data = tensor.detach().cpu().numpy()

    # Normalize dimensions
    orig_ndim = data.ndim
    if data.ndim == 1:
        data = data.reshape(1, 1, -1)
    elif data.ndim == 2:
        data = data.reshape(1, *data.shape)
    elif data.ndim > 3:
        raise ValueError("Input tensor must be 1D, 2D, or 3D")

    # Get tensor dimensions
    depth, rows, cols = data.shape

    # Generate or validate labels for each dimension
    if axis_labels is None:
        axis_labels = []

    # Pad or truncate axis_labels based on tensor dimensions
    ndim = orig_ndim
    while len(axis_labels) < ndim:
        dim_size = data.shape[-(len(axis_labels) + 1)]
        axis_labels.insert(0, [f"D{len(axis_labels)}_{i+1}" for i in range(dim_size)])
    axis_labels = axis_labels[-ndim:]

    # Convert to internal format (depth, rows, cols)
    all_labels = [None] * 3
    if ndim == 1:
        all_labels = [["1"], ["1"], axis_labels[0]]
    elif ndim == 2:
        all_labels = [["1"], axis_labels[0], axis_labels[1]]
    else:
        all_labels = axis_labels

    # Handle axis names similarly
    if axis_names is None:
        axis_names = []

    # Pad or truncate axis_names based on tensor dimensions
    while len(axis_names) < ndim:
        axis_names.insert(0, f"Dimension {len(axis_names)}")
    axis_names = axis_names[-ndim:]

    # Convert to internal format (depth, rows, cols)
    all_names = [None] * 3
    if ndim == 1:
        all_names = [None, None, axis_names[0]]
    elif ndim == 2:
        all_names = [None, axis_names[0], axis_names[1]]
    else:
        all_names = axis_names

    # Format output
    tables = []
    for d in range(depth):
        # Format slice data
        formatted_data = [[format_data(x) for x in row] for row in data[d]]

        # Add row labels except for 1D tensors
        if orig_ndim > 1:
            formatted_data = [
                [all_labels[1][i]] + row for i, row in enumerate(formatted_data)
            ]

        # Create slice header for 3D tensors
        if orig_ndim == 3:
            slice_header = (
                f"\n{all_names[0]}: {all_labels[0][d]}\n"
                if d > 0
                else f"{all_names[0]}: {all_labels[0][d]}\n"
            )
        else:
            slice_header = ""

        # Create table
        headers = [""] + all_labels[2] if orig_ndim > 1 else all_labels[2]
        table = tabulate(
            formatted_data,
            headers=headers,
            tablefmt=table_format,
            stralign="right",
            numalign="right",
        )

        # Add axis labels
        lines = table.split("\n")

        # Add column axis name for all dimensions on first slice
        if d == 0 and all_names[2]:
            if orig_ndim == 1:
                # For 1D, center the column name over the entire table
                col_label = f"{all_names[2]:^{len(lines[0])}}"
            else:
                # For 2D and 3D, account for row labels
                total_width = len(lines[0])
                y_axis_width = max(len(label) for label in all_labels[1]) + 4
                data_width = total_width - y_axis_width
                col_label = f"{' ' * y_axis_width}{all_names[2]:^{data_width}}"
            lines.insert(0, col_label)

        # Add row axis name (only for 2D and 3D tensors)
        if orig_ndim > 1 and all_names[1]:
            label_lines = lines[1:] if d == 0 and all_names[2] else lines
            max_label_length = len(all_names[1])
            padded_label = f"{all_names[1]:>{max_label_length}} │"

            if d == 0 and all_names[2]:
                lines[0] = f"{' ' * (max_label_length + 2)}{lines[0]}"

            for i, line in enumerate(label_lines):
                if i == len(label_lines) // 2:
                    lines[i + (1 if d == 0 and all_names[2] else 0)] = (
                        f"{padded_label}{line}"
                    )
                else:
                    lines[i + (1 if d == 0 and all_names[2] else 0)] = (
                        f"{' ' * max_label_length} │{line}"
                    )

        tables.append(slice_header + "\n".join(lines))

    return "\n".join(tables)

monarch/common/base_tensor.py
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import torch


# All of the tensor examples in this zoo inherit from BaseTensor. Ideally,
# however, they would inherit directly from Tensor. This is just our staging
# ground for applying behavior that hasn't yet made it into core but that
# we would like to apply by default.
class BaseTensor(torch.Tensor):
    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
    # so that subclass __new__ implementations can cooperate with each other
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        if requires_grad is None:
            return super().__new__(cls, elem)
        else:
            return cls._make_subclass(cls, elem, requires_grad)

    # If __torch_dispatch__ is defined (which it will be for all our examples)
    # the default torch function implementation (which preserves subclasses)
    # typically must be disabled
    __torch_function__ = torch._C._disabled_torch_function_impl

monarch/common/borrows.py
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import traceback
import warnings
from typing import List, Optional, TYPE_CHECKING
from weakref import ref, WeakSet

from . import messages

if TYPE_CHECKING:
    from .device_mesh import DeviceMesh
    from .tensor import Tensor


# all the aliases for the same storage on a particular stream.
# borrows of a storage to another stream are not considered aliases
# but instead copies, and will have a different set of storage aliases.
# conceptually, think of borrows as copies that have been guarded
# so that we do not actually perform the data movement.
class StorageAliases:
    def __init__(self):
        # how we are allowed to access this storage:
        # string containing 0 or more of:
        # r - can read
        # w - can write
        self.access = "rw"
        # what Tensor aliases exist for this storage
        self.aliases = WeakSet()

        # was this set of storages originally a borrow
        # from another stream?
        self._borrow: Optional[ref["Borrow"]] = None
        self.borrowed_from: "Optional[StorageAliases]" = None
        # the borrows of this storage that are still live
        self.live_borrows = WeakSet()

    @property
    def borrow(self) -> "Borrow":
        assert self._borrow is not None
        borrow = self._borrow()
        assert borrow is not None
        return borrow

    def register(self, tensor: "Tensor"):
        self.aliases.add(tensor)
        if self.borrowed_from is not None:
            self.borrow._live_tensors += 1

    def unregister(self, tensor: "Tensor"):
        borrowed_from = self.borrowed_from
        if borrowed_from is not None:
            borrow = self.borrow
            borrow._live_tensors -= 1
            if borrow._live_tensors == 0:
                borrow._use()
                if self.access == "rw":
                    # returning a mutable borrow needs to propagate errors
                    # from the stream (which may have mutated the value) back to the values
                    # on the origin stream. This does not happen automatically because
                    # borrows are not tracked as tensor aliases, but are instead treated
                    # as a kind of optimized copy or move.
                    tensor.mesh.client.new_node(borrowed_from.aliases, (tensor,))
                tensor.mesh._send(messages.BorrowLastUse(borrow._id))

    def borrow_from(
        self, id: int, mesh: "DeviceMesh", f: "StorageAliases", mutable: bool
    ):
        assert (
            self.borrowed_from is None
        ), "we should have created a new storage with no borrows"
        if mutable:
            if "w" not in f.access:
                raise RuntimeError(
                    "Cannot borrow this tensor mutably because it (or a view) is already being borrowed non-mutably."
                )
            f.access = ""
            self.access = "rw"
        else:
            f.access = self.access = "r"
        self.borrowed_from = f
        borrow = Borrow(id, self, mesh)
        f.live_borrows.add(borrow)
        self._borrow = ref(borrow)
        return borrow


class Borrow:
    def __init__(self, id: int, aliases: StorageAliases, mesh: "DeviceMesh"):
        self._storage_aliases = aliases
        self._mesh = mesh
        self._id = id
        self._live_tensors = 1
        self._dropped = False
        self._used = False
        self._frames: List[traceback.FrameSummary] = traceback.extract_stack()

    @property
    def traceback_string(self):
        return "".join(traceback.format_list(self._frames))

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        self.drop()

    def _use(self):
        if self._used:
            return
        self._used = True
        self._mesh._send(messages.BorrowFirstUse(self._id))

    def drop(self) -> None:
        if self._dropped:
            return
        self._dropped = True

        for alias in self._storage_aliases.aliases:
            alias._drop_ref()

        self._mesh.client.drop_borrow(self)
        self._mesh._send(messages.BorrowDrop(self._id))
        f = self._storage_aliases.borrowed_from
        assert f is not None
        f.live_borrows.remove(self)
        if len(f.live_borrows) == 0:
            f.access = "rw" if f.borrowed_from is None else self._storage_aliases.access

    def __del__(self):
        if not self._dropped:
            current = "".join(traceback.format_stack())
            warnings.warn(
                "borrow.drop() must be called before a borrowed tensor is freed to specify when the borrowed tensor should return to its origin stream, but borrow is being deleted before drop. "
                "borrow.drop() is being called automatically here to ensure correctness, but this will force a synchronization back to the original stream at this point, which might not be intended."
                f"\nTraceback of __del__ (most recent call last):\n{current}\nTraceback of original borrow (most recent call last):\n{self.traceback_string}",
                stacklevel=2,
            )
            self.drop()