torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/_coalescing.py
@@ -0,0 +1,308 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ import functools
+ from collections import defaultdict
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Generator,
+     List,
+     NamedTuple,
+     Optional,
+     Sequence,
+     Tuple,
+     TYPE_CHECKING,
+ )
+
+ import torch
+ from monarch.common import messages
+
+ from monarch.common.fake import fake_call
+ from monarch.common.function_caching import (
+     hashable_tensor_flatten,
+     TensorGroup,
+     TensorGroupPattern,
+ )
+ from monarch.common.tensor import InputChecker, Tensor
+ from monarch.common.tree import flatten
+
+ if TYPE_CHECKING:
+     from monarch.common.client import Recorder
+     from monarch.common.recording import Recording
+
+     from .client import Client
+
+ _coalescing = None
+
+
+ class CoalescingState:
+     def __init__(self, recording=False):
+         self.controller: Optional["Client"] = None
+         self.recorder: Optional["Recorder"] = None
+         self.recording = recording
+
+     def set_controller(self, controller: "Client"):
+         if self.controller is None:
+             self.controller = controller
+             controller.flush_deletes(False)
+         if self.controller is not controller:
+             raise ValueError(
+                 "using multiple controllers in the same coalescing block is not supported"
+             )
+
+     @contextmanager
+     def activate(self) -> Generator[None, Any, Any]:
+         global _coalescing
+         assert _coalescing is None
+         finished = False
+         try:
+             _coalescing = self
+             yield
+             finished = True
+         finally:
+             ctrl = self.controller
+             if ctrl is not None:
+                 if finished:
+                     ctrl.flush_deletes()
+                 self.recorder = ctrl.reset_recorder()
+                 if not finished:
+                     self.recorder.abandon()
+             _coalescing = None
+
+
+ @contextmanager
+ def coalescing() -> Generator[None, Any, Any]:
+     global _coalescing
+     if _coalescing is not None:
+         yield
+         return
+
+     state = CoalescingState()
+     with state.activate():
+         yield
+
+     if state.recorder is not None:
+         assert state.controller is not None
+         state.recorder.run_once(state.controller)
+
+
+ def _record_and_define(
+     fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+ ) -> "CacheEntry":
+     input_tensors, unflatten_input = flatten(
+         (args, kwargs), lambda x: isinstance(x, Tensor)
+     )
+
+     with InputChecker.from_flat_args(
+         "compile", input_tensors, unflatten_input
+     ) as checker:
+         checker.check_no_requires_grad()
+
+     for a in input_tensors:
+         assert a._seq is not None
+
+     state = CoalescingState(recording=True)
+     with state.activate():
+         formal_tensors = []
+         for i, input in enumerate(input_tensors):
+             state.set_controller(input.mesh.client)
+             t = Tensor(input._fake, input.mesh, input.stream)
+             input.mesh._send(
+                 messages.RecordingFormal(t, i, t.stream._to_ref(input.mesh.client))
+             )
+             formal_tensors.append(t)
+         formal_args, formal_kwargs = unflatten_input(formal_tensors)
+         recorded_result = fn(*formal_args, **formal_kwargs)
+         output_tensors, unflatten_result = flatten(
+             recorded_result, lambda x: isinstance(x, Tensor)
+         )
+         with InputChecker(
+             output_tensors,
+             lambda ts: f"{unflatten_result(ts)} = compiled_function(...)",
+         ) as checker:
+             checker.check_no_requires_grad()
+         for i, output in enumerate(output_tensors):
+             state.set_controller(output.mesh.client)
+             output.mesh._send(
+                 messages.RecordingResult(
+                     output, i, output.stream._to_ref(output.mesh.client)
+                 )
+             )
+
+     recorder = state.recorder
+     if recorder is None:
+         # no input tensors or output tensors, so just cache the result
+         return CacheEntry(
+             TensorGroup([]),
+             TensorGroupPattern(()),
+             lambda args, kwargs: recorded_result,
+             None,
+         )
+
+     controller = state.controller
+     assert controller is not None
+     recorder.add((), output_tensors, [])
+     recording = recorder.define_recording(
+         controller, len(output_tensors), len(input_tensors)
+     )
+
+     fake_uses = [r._fake for r in recording.uses]
+     captures_group = TensorGroup(fake_uses)
+     inputs_group = TensorGroup([i._fake for i in input_tensors], parent=captures_group)
+
+     outputs_group = TensorGroup([o._fake for o in output_tensors], parent=inputs_group)
+     outputs_pattern = outputs_group.pattern
+
+     def run(args, kwargs):
+         actuals, _ = flatten((args, kwargs), lambda x: isinstance(x, Tensor))
+         for a in actuals:
+             assert a._seq is not None
+
+         fake_result_tensors = fake_call(
+             outputs_pattern.empty, [fake_uses, [a._fake for a in actuals]]
+         )
+
+         # recording.run does permissions checks on all the tensors.
+         # if those checks fail then the tensors here will have been created
+         # but not defined, causing spurious delete messages.
+         # To avoid this, we pass a generator rather than a list
+         # and only create the tensors in run
+         result_tensors_generator = (
+             Tensor(f, o.mesh, o.stream)
+             for f, o in zip(fake_result_tensors, output_tensors)
+         )
+         return unflatten_result(recording.run(result_tensors_generator, actuals))
+
+     return CacheEntry(captures_group, inputs_group.pattern, run, recording)
+
+
+ @dataclass
+ class CacheEntry:
+     captures_group: TensorGroup
+     inputs_pattern: TensorGroupPattern
+     run: Callable[[Tuple[Any, ...], Dict[str, Any]], Any]
+     to_verify: Optional["Recording"]
+
+     def matches(self, input_tensors: List[torch.Tensor]) -> bool:
+         # if an input aliases a captured tensor, then we have
+         # to check that all future inputs alias the _same exact_
+         # captured tensor. These are additional checks after
+         # matching on the pattern of aliasing for just the inputs because
+         # we do not know what the captures would be without first matching the inputs without the captures.
+         inputs_group = TensorGroup(input_tensors, parent=self.captures_group)
+         return self.inputs_pattern == inputs_group.pattern
+
+
+ def compile(fn=None, verify=True):
+     """
+     Wraps `fn` such that it records and later replays a single message to workers
+     to instruct them to run the entire contents of this function. Since the function invocation
+     is much smaller than the original set of messages and since we do not re-execute the python inside
+     the function after recording, this has substantially lower latency.
+
+     While eventually `compile` will be backed by `torch.compile`'s dynamo executor, it currently
+     works as a simple tracer with the following rules for when it chooses to trace vs when
+     it will reuse an existing trace.
+
+     A new trace is created whenever:
+
+     * The _values_ of the non-tensor arguments to fn have not been seen before.
+     * The _metadata_ of a tensor argument has not been seen before. Metadata includes the sizes, strides,
+       dtype, devices, layout, device meshes, streams, and pattern of aliasing of the arguments
+       with respect to other arguments and any values the trace captures.
+
+     A new trace will not be created in the following situations, which are known to be **unsafe**:
+
+     * A value that is not an argument to the function but is used by the function (e.g. a global)
+       changes in a way that would affect what messages are being sent.
+     * A tensor that is not an argument to the function changes metadata, or gets reassigned to
+       a new tensor in Python.
+
+
+     The trace is allowed to use tensors that are referenced in the body but not listed as arguments,
+     such as globals or closure-captured locals, as long as these values are not modified in
+     the ways that are listed as unsafe above. When switched to a torch.compile backed version,
+     these safety caveats will be improved.
+
+     Compilation currently does not work if the inputs or outputs to the function have `requires_grad=True`,
+     because we will not generate a correct backwards pass graph. However, captured tensors
+     are allowed to be requires_grad=True, and gradient calculation (forward+backward)
+     can run entirely within the function.
+
+     Can be used as a wrapper:
+         wrapped = compile(my_function, verify=False)
+
+     Or as a decorator:
+
+         @compile
+         def my_function(...):
+             ...
+
+         @compile(verify=False)
+         def my_function(...):
+             ...
+
+     Args:
+
+         fn (callable): the function to be wrapped. (Default: None, in which case we return a single-argument
+             function that can be used as a decorator)
+         verify (bool): To guard as much as possible against the above unsafe situations,
+             if `verify=True`, the first time we would reuse a trace, we additionally do another
+             recording and check that the second recording matches the original recording, and report
+             where they diverge. (Default: True)
+
+
+     Returns:
+         If fn=None, it returns a function that can be used as a decorator on a function to
+         be wrapped. Otherwise, it returns the wrapped function itself.
+
+     """
+     if fn is None:
+         return lambda fn: compile(fn, verify)
+     cache: Dict[Any, List[CacheEntry]] = defaultdict(list)
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         global _coalescing
+         if _coalescing:
+             return fn(*args, **kwargs)
+
+         tensors, shape_key = hashable_tensor_flatten(args, kwargs)
+         input_group = TensorGroup([t._fake for t in tensors])
+         props = tuple((t.mesh, t.stream, t.requires_grad) for t in tensors)
+         key = (shape_key, input_group.pattern, props)
+         for entry in cache[key]:
+             if entry.matches(input_group.tensors):
+                 if entry.to_verify is not None:
+                     entry.to_verify.client.recorder.verify_against(entry.to_verify)
+                     _record_and_define(fn, args, kwargs)
+                     entry.to_verify = None
+                 return entry.run(args, kwargs)
+
+         entry = _record_and_define(fn, args, kwargs)
+         if not verify:
+             entry.to_verify = None
+         cache[key].append(entry)
+         return entry.run(args, kwargs)
+
+     return wrapper
+
+
+ def is_active(controller: "Client"):
+     if _coalescing is None:
+         return False
+     _coalescing.set_controller(controller)
+     return True
+
+
+ def is_recording(controller: "Client"):
+     return is_active(controller) and _coalescing.recording
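The two entry points above are coalescing() and compile(). A minimal usage sketch, assuming an active monarch device mesh and that x and y are monarch Tensors with requires_grad=False (the names and the body of fused_update are illustrative, not from the package):

from monarch.common._coalescing import coalescing, compile

@compile(verify=False)
def fused_update(x, y):
    # traced on the first call; later calls with matching tensor metadata
    # replay the recording as a single message to the workers
    return x * 2 + y

out = fused_update(x, y)

# coalescing() instead batches the messages of an eager block and runs the
# recording once, without defining a reusable trace:
with coalescing():
    a = x + y
    b = a * 3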
monarch/common/_device_utils.py
@@ -0,0 +1,18 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ import re
+ from pathlib import Path
+
+
+ def _local_device_count():
+     if "CUDA_VISIBLE_DEVICES" in os.environ:
+         return len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
+     dev_path = Path("/dev")
+     pattern = re.compile(r"nvidia\d+$")
+     nvidia_devices = [dev for dev in dev_path.iterdir() if pattern.match(dev.name)]
+     return len(nvidia_devices)
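_local_device_count prefers CUDA_VISIBLE_DEVICES and only falls back to counting /dev/nvidia* entries when the variable is unset. A quick illustration with hypothetical values:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,3"

from monarch.common._device_utils import _local_device_count
print(_local_device_count())  # 3 -- taken from the env var, /dev is never scanned

Note that an empty CUDA_VISIBLE_DEVICES still yields 1, because "".split(",") returns [""].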
monarch/common/_tensor_to_table.py
@@ -0,0 +1,172 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Callable, List, Optional
+
+ import torch
+
+
+ def tensor_to_table(
+     tensor: torch.Tensor,
+     format_data: Callable,
+     axis_labels: Optional[List[List[str]]] = None,
+     axis_names: Optional[List[str]] = None,
+     format_spec: str = ".4f",
+     table_format: str = "grid",
+ ) -> str:
+     """
+     Convert a tensor into formatted tables with generic dimension handling.
+
+     Parameters:
+     -----------
+     tensor : torch.Tensor or np.ndarray
+         Input tensor to be converted (1D, 2D, or 3D)
+     axis_labels : list of lists, optional
+         Labels for each axis, ordered from outer to inner dimension
+         For 1D: [column_labels]
+         For 2D: [row_labels, column_labels]
+         For 3D: [depth_labels, row_labels, column_labels]
+     axis_names : list, optional
+         Names for each axis, ordered from outer to inner dimension
+         For 1D: [column_name]
+         For 2D: [row_name, column_name]
+         For 3D: [depth_name, row_name, column_name]
+     format_spec : str, optional
+         Format specification for numbers (default: ".4f")
+     table_format : str, optional
+         Table format style for tabulate (default: "grid")
+
+     Returns:
+     --------
+     str : Formatted table string
+     """
+     import numpy as np
+     from tabulate import tabulate
+
+     assert tensor.dtype == torch.int
+     # Convert tensor to numpy for easier manipulation
+     data = tensor.detach().cpu().numpy()
+
+     # Normalize dimensions
+     orig_ndim = data.ndim
+     if data.ndim == 1:
+         data = data.reshape(1, 1, -1)
+     elif data.ndim == 2:
+         data = data.reshape(1, *data.shape)
+     elif data.ndim > 3:
+         raise ValueError("Input tensor must be 1D, 2D, or 3D")
+
+     # Get tensor dimensions
+     depth, rows, cols = data.shape
+
+     # Generate or validate labels for each dimension
+     if axis_labels is None:
+         axis_labels = []
+
+     # Pad or truncate axis_labels based on tensor dimensions
+     ndim = orig_ndim
+     while len(axis_labels) < ndim:
+         dim_size = data.shape[-(len(axis_labels) + 1)]
+         axis_labels.insert(0, [f"D{len(axis_labels)}_{i+1}" for i in range(dim_size)])
+     axis_labels = axis_labels[-ndim:]
+
+     # Convert to internal format (depth, rows, cols)
+     all_labels = [None] * 3
+     if ndim == 1:
+         all_labels = [["1"], ["1"], axis_labels[0]]
+     elif ndim == 2:
+         all_labels = [["1"], axis_labels[0], axis_labels[1]]
+     else:
+         all_labels = axis_labels
+
+     # Handle axis names similarly
+     if axis_names is None:
+         axis_names = []
+
+     # Pad or truncate axis_names based on tensor dimensions
+     while len(axis_names) < ndim:
+         axis_names.insert(0, f"Dimension {len(axis_names)}")
+     axis_names = axis_names[-ndim:]
+
+     # Convert to internal format (depth, rows, cols)
+     all_names = [None] * 3
+     if ndim == 1:
+         all_names = [None, None, axis_names[0]]
+     elif ndim == 2:
+         all_names = [None, axis_names[0], axis_names[1]]
+     else:
+         all_names = axis_names
+
+     # Format output
+     tables = []
+     for d in range(depth):
+         # Format slice data
+         formatted_data = [[format_data(x) for x in row] for row in data[d]]
+
+         # Add row labels except for 1D tensors
+         if orig_ndim > 1:
+             formatted_data = [
+                 [all_labels[1][i]] + row for i, row in enumerate(formatted_data)
+             ]
+
+         # Create slice header for 3D tensors
+         if orig_ndim == 3:
+             slice_header = (
+                 f"\n{all_names[0]}: {all_labels[0][d]}\n"
+                 if d > 0
+                 else f"{all_names[0]}: {all_labels[0][d]}\n"
+             )
+         else:
+             slice_header = ""
+
+         # Create table
+         headers = [""] + all_labels[2] if orig_ndim > 1 else all_labels[2]
+         table = tabulate(
+             formatted_data,
+             headers=headers,
+             tablefmt=table_format,
+             stralign="right",
+             numalign="right",
+         )
+
+         # Add axis labels
+         lines = table.split("\n")
+
+         # Add column axis name for all dimensions on first slice
+         if d == 0 and all_names[2]:
+             if orig_ndim == 1:
+                 # For 1D, center the column name over the entire table
+                 col_label = f"{all_names[2]:^{len(lines[0])}}"
+             else:
+                 # For 2D and 3D, account for row labels
+                 total_width = len(lines[0])
+                 y_axis_width = max(len(label) for label in all_labels[1]) + 4
+                 data_width = total_width - y_axis_width
+                 col_label = f"{' ' * y_axis_width}{all_names[2]:^{data_width}}"
+             lines.insert(0, col_label)
+
+         # Add row axis name (only for 2D and 3D tensors)
+         if orig_ndim > 1 and all_names[1]:
+             label_lines = lines[1:] if d == 0 and all_names[2] else lines
+             max_label_length = len(all_names[1])
+             padded_label = f"{all_names[1]:>{max_label_length}} │"
+
+             if d == 0 and all_names[2]:
+                 lines[0] = f"{' ' * (max_label_length + 2)}{lines[0]}"
+
+             for i, line in enumerate(label_lines):
+                 if i == len(label_lines) // 2:
+                     lines[i + (1 if d == 0 and all_names[2] else 0)] = (
+                         f"{padded_label}{line}"
+                     )
+                 else:
+                     lines[i + (1 if d == 0 and all_names[2] else 0)] = (
+                         f"{' ' * max_label_length} │{line}"
+                     )
+
+         tables.append(slice_header + "\n".join(lines))
+
+     return "\n".join(tables)
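tensor_to_table expects an integer tensor plus a format_data callable and renders one table per outermost slice via tabulate. A small illustrative call (module path taken from the file list above; requires the tabulate package; the labels are made up):

import torch
from monarch.common._tensor_to_table import tensor_to_table

ranks = torch.arange(6, dtype=torch.int).reshape(2, 3)
print(
    tensor_to_table(
        ranks,
        format_data=str,
        axis_labels=[["row0", "row1"], ["gpu0", "gpu1", "gpu2"]],
        axis_names=["host", "gpu"],
    )
)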
monarch/common/base_tensor.py
@@ -0,0 +1,28 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import torch
+
+
+ # All of the tensor examples in this zoo inherit from BaseTensor. Ideally,
+ # however, they would inherit directly from Tensor. This is just our staging
+ # ground for applying behavior that hasn't yet made it into core but that
+ # we would like to apply by default.
+ class BaseTensor(torch.Tensor):
+     # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
+     # to ensure that super().__new__ can cooperate with each other
+     @staticmethod
+     def __new__(cls, elem, *, requires_grad=None):
+         if requires_grad is None:
+             return super().__new__(cls, elem)
+         else:
+             return cls._make_subclass(cls, elem, requires_grad)
+
+     # If __torch_dispatch__ is defined (which it will be for all our examples)
+     # the default torch function implementation (which preserves subclasses)
+     # typically must be disabled
+     __torch_function__ = torch._C._disabled_torch_function_impl
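BaseTensor's __new__ routes construction through super().__new__ when requires_grad is omitted and through Tensor._make_subclass when it is given, so subclasses keep their type either way. A tiny sketch (the subclass name is made up; no __torch_dispatch__ behavior is shown):

import torch
from monarch.common.base_tensor import BaseTensor

class MyTensor(BaseTensor):
    pass

a = MyTensor(torch.ones(3))                      # super().__new__ path
b = MyTensor(torch.ones(3), requires_grad=True)  # _make_subclass path
print(type(a).__name__, a.requires_grad)         # MyTensor False
print(type(b).__name__, b.requires_grad)         # MyTensor True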
monarch/common/borrows.py
@@ -0,0 +1,143 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import traceback
+ import warnings
+ from typing import List, Optional, TYPE_CHECKING
+ from weakref import ref, WeakSet
+
+ from . import messages
+
+ if TYPE_CHECKING:
+     from .device_mesh import DeviceMesh
+     from .tensor import Tensor
+
+
+ # all the aliases for the same storage on a particular stream
+ # borrows of a storage to another stream are not considered aliases
+ # but instead copies and will have a different set of storage aliases.
+ # conceptually, think of borrows as copies that have been guarded
+ # so that we do not actually perform the data movement.
+ class StorageAliases:
+     def __init__(self):
+         # how are we allowed to access this storage
+         # string containing 0 or more of:
+         # r - can read
+         # w - can write
+         self.access = "rw"
+         # what Tensor aliases exist for this storage
+         self.aliases = WeakSet()
+
+         # was this set of storages originally a borrow
+         # from another stream?
+         self._borrow: Optional[ref[Borrow]] = None
+         self.borrowed_from: "Optional[StorageAliases]" = None
+         # how many times has this storage been borrowed?
+         self.live_borrows = WeakSet()
+
+     @property
+     def borrow(self) -> "Borrow":
+         assert self._borrow is not None
+         borrow = self._borrow()
+         assert borrow is not None
+         return borrow
+
+     def register(self, tensor: "Tensor"):
+         self.aliases.add(tensor)
+         if self.borrowed_from is not None:
+             self.borrow._live_tensors += 1
+
+     def unregister(self, tensor: "Tensor"):
+         borrowed_from = self.borrowed_from
+         if borrowed_from is not None:
+             borrow = self.borrow
+             borrow._live_tensors -= 1
+             if borrow._live_tensors == 0:
+                 borrow._use()
+                 if self.access == "rw":
+                     # returning a mutable borrow needs to propagate errors
+                     # from the stream (which may have mutated the value) back to the values
+                     # on the origin stream. This does not happen automatically because
+                     # borrows are not tracked as tensor aliases, but are instead treated
+                     # as a kind of optimized copy or move.
+                     tensor.mesh.client.new_node(borrowed_from.aliases, (tensor,))
+                 tensor.mesh._send(messages.BorrowLastUse(borrow._id))
+
+     def borrow_from(
+         self, id: int, mesh: "DeviceMesh", f: "StorageAliases", mutable: bool
+     ):
+         assert (
+             self.borrowed_from is None
+         ), "we should have created a new storage with no borrows"
+         if mutable:
+             if "w" not in f.access:
+                 raise RuntimeError(
+                     "Cannot borrow this tensor mutably because it (or a view) is already being borrowed non-mutably."
+                 )
+             f.access = ""
+             self.access = "rw"
+         else:
+             f.access = self.access = "r"
+         self.borrowed_from = f
+         borrow = Borrow(id, self, mesh)
+         f.live_borrows.add(borrow)
+         self._borrow = ref(borrow)
+         return borrow
+
+
+ class Borrow:
+     def __init__(self, id: int, aliases: StorageAliases, mesh: "DeviceMesh"):
+         self._storage_aliases = aliases
+         self._mesh = mesh
+         self._id = id
+         self._live_tensors = 1
+         self._dropped = False
+         self._used = False
+         self._frames: List[traceback.FrameSummary] = traceback.extract_stack()
+
+     @property
+     def traceback_string(self):
+         return "".join(traceback.format_list(self._frames))
+
+     def __enter__(self):
+         pass
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.drop()
+
+     def _use(self):
+         if self._used:
+             return
+         self._used = True
+         self._mesh._send(messages.BorrowFirstUse(self._id))
+
+     def drop(self) -> None:
+         if self._dropped:
+             return
+         self._dropped = True
+
+         for alias in self._storage_aliases.aliases:
+             alias._drop_ref()
+
+         self._mesh.client.drop_borrow(self)
+         self._mesh._send(messages.BorrowDrop(self._id))
+         f = self._storage_aliases.borrowed_from
+         assert f is not None
+         f.live_borrows.remove(self)
+         if len(f.live_borrows) == 0:
+             f.access = "rw" if f.borrowed_from is None else self._storage_aliases.access
+
+     def __del__(self):
+         if not self._dropped:
+             current = "".join(traceback.format_stack())
+             warnings.warn(
+                 "borrow.drop() must be called before a borrowed tensor is freed to specify when the borrowed tensor should return to its origin stream, but borrow is being deleted before drop. "
+                 "borrow.drop() is being called automatically here to ensure correctness, but this will force a synchronization back to the original stream at this point which might not be intended."
+                 f"\nTraceback of __del__ (most recent call last):\n{current}\nTraceback of original borrow (most recent call last):{self.traceback_string}",
+                 stacklevel=2,
+             )
+             self.drop()
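A Borrow is also a context manager: __exit__ calls drop(), which sends BorrowDrop and restores access on the origin storage once no live borrows remain, and __del__ warns if a borrow is garbage collected before being dropped. A hypothetical lifecycle sketch (assumes `borrowed, borrow` were obtained from a stream borrow elsewhere in monarch; the names are illustrative):

with borrow:          # __exit__ calls borrow.drop() when the block ends
    y = borrowed * 2  # work with the borrowed tensor on the borrowing stream

# or manage it explicitly:
# y = borrowed * 2
# borrow.drop()       # skipping this triggers the __del__ warning above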