torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,255 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
|
9
|
+
import copy
|
10
|
+
import itertools
|
11
|
+
import traceback
|
12
|
+
from dataclasses import dataclass
|
13
|
+
from enum import auto, Enum
|
14
|
+
from typing import cast, Dict, List, Optional, Sequence
|
15
|
+
|
16
|
+
from monarch.simulator.config import META_VAL
|
17
|
+
|
18
|
+
|
19
|
+
class TaskState(Enum):
    """
    Lifecycle of a simulated ``Task``.

    A task moves PENDING -> READY (``Task.maybe_set_ready``) -> EXECUTING
    (``Task.maybe_execute``) -> EXECUTED (``Task.maybe_finish``).
    """

    # Waiting for at least one dependency to be executed.
    PENDING = auto()
    # All dependencies executed; may be started.
    READY = auto()
    # Started; waiting for collectives/waits before it can finish.
    EXECUTING = auto()
    # Finished; ``end_time`` is final.
    EXECUTED = auto()
|
24
|
+
|
25
|
+
|
26
|
+
class Task:
    """
    A class to represent a task in a stream. A task is ready once all of its
    dependencies are executed. A task is executed if it is ready and it is
    the first task in the stream. A task can be marked as executed if it is
    executing, all the collectives, if any, of the task are executing (or
    executed), and all the tasks it waits on, if any, are executed.

    Args:
        inputs (List[int]): A list of input tensor ids.
        outputs (List[int]): A list of output tensor ids.
        command_id (int): The id of the command this task executes.
        start_time (int): The earliest start time of the task. It is pushed
            to the latest end time of the dependencies when the task becomes
            ready.
        runtime (int): The runtime of the task in nanoseconds.
        meta (List[str]): A list of metadata associated with the task.
            ``META_VAL`` is appended to it.
        collectives (Optional[List[Task]]): A list shared by all the tasks of
            the same collective; this task appends itself to it.
            Defaults to None.
        waits (Optional[List[Task]]): Tasks that must be executed before this
            task can be marked as executed. Defaults to None.
        traceback (Sequence[traceback.FrameSummary]): A traceback associated
            with the task (presumably where its command was issued).
            Defaults to ().
    """

    def __init__(
        self,
        inputs: List[int],
        outputs: List[int],
        command_id: int,
        start_time: int,
        runtime: int,
        meta: List[str],
        collectives: Optional[List["Task"]] = None,
        waits: Optional[List["Task"]] = None,
        traceback: Sequence[traceback.FrameSummary] = (),
    ):
        self.inputs = inputs
        self.outputs = outputs
        self.runtime = runtime
        self.meta = meta + META_VAL
        self.dependencies: List["Task"] = []
        self.collectives = collectives
        self.waits = waits
        self.command_id = command_id
        self.traceback = traceback
        if self.collectives is not None:
            # The list is shared by every participant of the collective, so
            # each participant registers itself in it.
            self.collectives.append(self)

        self._state = TaskState.PENDING
        self.start_time = start_time
        self.end_time = 0

        # Assigned by WorkerTaskManager.add().
        self.task_id: Optional[int] = None

    def __repr__(self):
        return " ".join(self.meta)

    @property
    def state(self) -> TaskState:
        return self._state

    def maybe_set_ready(self) -> bool:
        """
        Sets the task state to READY if all dependencies are executed, moving
        ``start_time`` to the latest dependency end time. Returns True if the
        task state changes from PENDING to READY.
        """
        if self._state != TaskState.PENDING:
            return False

        for dep in self.dependencies or ():
            if dep._state != TaskState.EXECUTED:
                # start_time may already include the dependencies examined so
                # far; it is recomputed with max() on the next call.
                return False
            self.start_time = max(self.start_time, dep.end_time)
        self._state = TaskState.READY
        return True

    def maybe_execute(self) -> bool:
        """
        Executes the task if it is ready. Returns True if the task state changes
        from READY to EXECUTING.
        """
        if self._state != TaskState.READY:
            return False

        self._state = TaskState.EXECUTING
        return True

    def maybe_finish(self) -> bool:
        """
        Finish the task if it is executing, all the associated collectives,
        if any, are executing or executed, and all the waited-on tasks, if
        any, are executed. Return True if the task state changes from
        EXECUTING to EXECUTED and ``end_time`` has been computed.
        """
        if self._state != TaskState.EXECUTING:
            return False

        if self.collectives and not all(
            c.state in (TaskState.EXECUTING, TaskState.EXECUTED)
            for c in self.collectives
        ):
            return False
        if self.waits and not all(w.state == TaskState.EXECUTED for w in self.waits):
            return False

        self._state = TaskState.EXECUTED
        if self.collectives:
            # A collective cannot finish before its slowest participant starts.
            straggler_time = max(c.start_time for c in self.collectives)
            self.end_time = straggler_time + self.runtime
        if self.waits:
            last_wait_event_time = max(w.end_time for w in self.waits)
            self.end_time = max(self.end_time, last_wait_event_time)
        # Guard against an empty meta list instead of raising IndexError.
        if self.meta and self.meta[0] == "aten.view":
            # TODO: this is a workaround to removing `view` from the trace.
            # What we really should do is to have the CPU trace besides GPU trace.
            self.end_time = self.start_time
        else:
            self.end_time = max(self.end_time, self.start_time + self.runtime)

        return True

    def clone(self) -> "Task":
        """
        Return a shallow copy. The ``dependencies``, ``waits`` and
        ``collectives`` list objects are shared with the original; callers that
        need independent lists must replace them.
        """
        return copy.copy(self)
|
148
|
+
|
149
|
+
|
150
|
+
@dataclass()
class Borrow:
    """
    Describes a tensor borrowed across streams.

    Fields (meanings inferred from their names; confirm with the borrow
    protocol before relying on them):
        ident: identifier of the borrow.
        tensor_src_id: id of the tensor being borrowed.
        tensor_dst_id: id of the tensor that aliases it on the target stream.
        from_stream: stream the tensor is borrowed from.
        to_stream: stream the tensor is borrowed to.
    """

    ident: int
    tensor_src_id: int
    tensor_dst_id: int
    from_stream: int
    to_stream: int
|
157
|
+
|
158
|
+
|
159
|
+
class EventTask(Task):
    """
    Represents an event task in a stream.

    The task owns no tensors; its meta reads "waiting for <event_stream_name>"
    and it can only finish after ``recorded_task`` is executed (it is the sole
    entry of ``waits``). It is presumably scheduled on ``wait_stream``.
    """

    def __init__(
        self,
        recorded_task: Task,
        event_stream: int,
        event_stream_name: str,
        wait_stream: int,
        wait_stream_name: str,
        start_time: int,
        command_id: int,
        runtime: int = 1,
        borrow: Optional[Borrow] = None,
        traceback: Sequence[traceback.FrameSummary] = (),
    ):
        description = ["waiting for", event_stream_name]
        super().__init__(
            meta=description,
            inputs=[],
            outputs=[],
            waits=[recorded_task],
            command_id=command_id,
            start_time=start_time,
            runtime=runtime,
            traceback=traceback,
        )
        # Stream whose event is being waited for.
        self.event_stream, self.event_stream_name = event_stream, event_stream_name
        # Stream that waits on the event.
        self.wait_stream, self.wait_stream_name = wait_stream, wait_stream_name
        # The Borrow this event belongs to, if any.
        self.borrow = borrow

    def clone(self) -> "EventTask":
        """Return a shallow copy (list attributes are shared with the original)."""
        duplicate = copy.copy(self)
        return duplicate
|
193
|
+
|
194
|
+
|
195
|
+
class WorkerTaskManager(Task):
    """
    Registry of a worker's tasks, keyed by task ids that this manager assigns.

    NOTE(review): although this subclasses Task, it never calls
    ``Task.__init__`` and reuses the ``task_id`` attribute as an id counter, so
    Task's methods are not usable on it. The base class is kept for backward
    compatibility (e.g. ``isinstance`` checks elsewhere).
    """

    def __init__(self) -> None:
        self.tasks: Dict[int, Task] = {}
        # Id generator; shadows the Optional[int] ``Task.task_id`` attribute.
        self.task_id = itertools.count()

    def add(self, task: Task) -> int:
        """Register ``task`` under a fresh id, store the id on it, and return the id."""
        task_id = next(self.task_id)
        self.tasks[task_id] = task
        task.task_id = task_id
        return task_id

    def remove(self, task: Task) -> None:
        """
        Unregister ``task``.

        Raises:
            ValueError: if the task was never assigned an id.
            KeyError: if the task's id is not registered in this manager.
        """
        if (task_id := task.task_id) is not None:
            self.tasks.pop(task_id)
        else:
            raise ValueError("task_id is None")

    @staticmethod
    def _clone_link(linked: Task, cloned_tasks: Dict[int, Task]) -> Task:
        """Map a dependency/wait onto its counterpart in the cloned manager."""
        if linked.task_id in cloned_tasks:
            return cloned_tasks[linked.task_id]
        # The linked task is not registered here any more, so it must already
        # be executed. Clone it so the copy does not share the object, but do
        # not add the clone to the new WorkerTaskManager.
        assert linked.state == TaskState.EXECUTED
        return linked.clone()

    def clone(self) -> "WorkerTaskManager":
        """
        Return a new manager holding clones of all registered tasks.

        Dependencies and waits that point at registered tasks are rewired to
        the corresponding clones; links to unregistered (executed) tasks point
        at fresh, unregistered clones. Both managers hand out the same next
        task id afterwards.
        """
        # First clone every registered task so links resolve regardless of
        # the order in which tasks were registered.
        cloned_tasks: Dict[int, Task] = {
            task_id: task.clone() for task_id, task in self.tasks.items()
        }

        for task_id, task in self.tasks.items():
            cloned_task = cloned_tasks[task_id]
            # Both dependencies and waits are all tasks on the same worker
            # thread. Thus, they must be in the same WorkerTaskManager or
            # they must be executed.
            if task.dependencies is not None:
                # Always build a fresh list: Task.clone() is shallow, so even
                # an empty list would otherwise stay shared with the original.
                cloned_task.dependencies = [
                    self._clone_link(dep, cloned_tasks) for dep in task.dependencies
                ]
            if task.waits is not None:
                cloned_task.waits = [
                    self._clone_link(wait, cloned_tasks)
                    for wait in cast(List[Task], task.waits)
                ]

            # TODO: the global list shared by all the tasks with the same collective
            # is a neat idea but can be hard to debug. Consider make it more explicit.
            if cloned_task.collectives:
                cloned_task.collectives.append(cloned_task)

        ret = WorkerTaskManager()
        # Waste one to ensure all the cloned WorkerTaskManager has the same task_id.
        next_task_id = next(self.task_id)
        ret.task_id = itertools.count(next_task_id + 1)
        ret.tasks = cloned_tasks
        return ret
|
@@ -0,0 +1,373 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import copy
|
9
|
+
import heapq
|
10
|
+
import logging
|
11
|
+
import traceback
|
12
|
+
from collections import defaultdict
|
13
|
+
from itertools import count
|
14
|
+
from typing import Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
|
15
|
+
|
16
|
+
import torch
|
17
|
+
from monarch.common.fake import fake_call
|
18
|
+
from monarch.common.tensor_factory import TensorFactory
|
19
|
+
from monarch.simulator.task import Task, WorkerTaskManager
|
20
|
+
|
21
|
+
# Module-level logger used for debug tracing in this module.
logger = logging.getLogger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
class DTensorRef:
    """
    A reference to a `controller.tensor.Tensor` object.

    This class is used to keep track of DTensor objects that have been created
    by the controller and to provide the mechanism to serialize DTensor
    objects (torch.save/torch.load).
    """

    # Cache of references keyed by ``tensor.ref``, shared by the whole class
    # (populated by ``from_ref``).
    created: Dict[int, "DTensorRef"] = {}

    def __init__(self, tensor):
        self.ref = tensor.ref
        self.factory = TensorFactory.from_tensor(tensor)
        self._fake: Optional[torch._subclasses.FakeTensor] = getattr(
            tensor, "_fake", None
        )
        self._refresh_storage_info()

    def _refresh_storage_info(self) -> None:
        """Cache the identity and size of ``self._fake``'s untyped storage (None without a fake)."""
        if self._fake is not None:
            storage = self._fake.untyped_storage()
            self._storage_id: Optional[torch.types._int] = id(storage)
            self._size: Optional[int] = storage.size()
        else:
            self._storage_id = None
            self._size = None

    def __repr__(self):
        return f"DTensorRef({self.ref})"

    @classmethod
    def from_ref(cls, tensor) -> "DTensorRef":
        """Return the cached reference for ``tensor.ref``, creating it on first use."""
        if tensor.ref not in cls.created:
            cls.created[tensor.ref] = cls(tensor)
        return cls.created[tensor.ref]

    def __getstate__(self):
        # The fake tensor is not serialized; __setstate__ rebuilds it.
        return {
            "ref": self.ref,
            "factory": self.factory,
            "_fake": None,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._fake = fake_call(self.factory.zeros)
        # __getstate__ drops the storage bookkeeping; restore it from the new
        # fake tensor so unpickled refs expose the same attributes.
        self._refresh_storage_info()

    def __deepcopy__(self, memo):
        if self._fake is None:
            raise RuntimeError(f"Cannot deepcopy {self!r}: it has no fake tensor.")

        fake = fake_call(self.factory.zeros)
        fake._fake = fake
        fake.ref = self.ref
        return self.__class__(fake)
|
78
|
+
|
79
|
+
|
80
|
+
class FakeTensorTracker:
    """
    Tracks the fake tensors created in the simulator. While each worker and
    stream maintains its own tensors, we don't want to create one FakeTensor per
    stream/worker. Instead, the fake tensor for the same tensor id is shared,
    which reduces the simulation time.

    A fake tensor is created when it is first created in any worker and is
    deleted when it is deleted in all workers (its reference count drops to 0).
    """

    def __init__(self) -> None:
        self.tensors: Dict[int, torch._subclasses.FakeTensor] = {}
        # Reference count per tensor id.
        self._ref: Dict[int, int] = defaultdict(int)
        # Tensor ids registered through add(..., is_borrowed=True).
        self._borrowed_tensors: Set[int] = set()

    def add(
        self, tensors: Dict[int, torch._subclasses.FakeTensor], is_borrowed=False
    ) -> None:
        """Register fake tensors by tensor id, optionally marking them as borrowed."""
        self.tensors.update(tensors)
        if is_borrowed:
            self._borrowed_tensors.update(tensors.keys())

    def is_borrowed(self, tensor: int) -> bool:
        return tensor in self._borrowed_tensors

    def incr_ref(self, tensor_id: int) -> None:
        assert tensor_id in self.tensors, f"Tensor {tensor_id} is not created"
        self._ref[tensor_id] += 1

    def decr_ref(self, tensor_id: int) -> None:
        """Drop one reference; the fake tensor is forgotten when none remain."""
        # Use .get() so an unbalanced decrement does not insert a spurious
        # zero entry into the defaultdict before the assertion fires.
        ref = self._ref.get(tensor_id, 0) - 1
        assert ref >= 0, f"Tensor {tensor_id} has negative ref count {ref}"
        if ref == 0:
            self.tensors.pop(tensor_id)
            self._ref.pop(tensor_id)
        else:
            self._ref[tensor_id] = ref
|
118
|
+
|
119
|
+
|
120
|
+
class StorageEvent(NamedTuple):
    """A storage becoming tracked or untracked, as reported by WorkerStorageTracker."""

    # Simulated address assigned to the storage.
    address: int
    # Size of the storage as returned by ``storage.size()`` (always the
    # positive size; consumers negate it for releases).
    delta: int
|
123
|
+
|
124
|
+
|
125
|
+
class WorkerStorageTracker:
    """
    Tracks, for one worker, which tensor ids reference each fake storage and
    assigns every storage a simulated address when it is first referenced.
    """

    # Distance between consecutive simulated addresses.
    _ADDR_STEP = 128  # aligning 128-byte cache lines?

    def __init__(self, fake_tensor_tracker) -> None:
        # storage -> ids of the tensors currently referencing it.
        self.storages: Dict[torch.UntypedStorage, Set[int]] = {}
        self.fake_tensor_tracker = fake_tensor_tracker
        self._addr_counter = count(step=self._ADDR_STEP)
        # storage -> simulated address handed out by ``_addr_counter``.
        self.storage_addresses: Dict[torch.UntypedStorage, int] = {}

    def incr_ref(self, tensor_id: int) -> Optional[StorageEvent]:
        """
        Record that ``tensor_id`` references its storage. Returns a
        StorageEvent(address, size) when this is the storage's first reference
        and the tensor is not borrowed; otherwise None.
        """
        fake = self.fake_tensor_tracker.tensors[tensor_id]
        storage = fake.untyped_storage()
        if storage in self.storages:
            self.storages[storage].add(tensor_id)
            return None

        self.storages[storage] = {tensor_id}
        addr = next(self._addr_counter)
        self.storage_addresses[storage] = addr
        if self.fake_tensor_tracker.is_borrowed(tensor_id):
            return None  # Q: should self._addr_counter be reversed?
        return StorageEvent(addr, storage.size())

    def decr_ref(self, tensor_id: int) -> Optional[StorageEvent]:
        """
        Drop ``tensor_id``'s reference to its storage. Returns a
        StorageEvent(address, size) when the last reference is dropped and the
        tensor is not borrowed; otherwise None.

        Raises:
            RuntimeError: if the storage is not tracked.
        """
        fake = self.fake_tensor_tracker.tensors[tensor_id]
        storage = fake.untyped_storage()
        if storage not in self.storages:
            raise RuntimeError(
                f"{storage} is being dereferenced but it is not tracked."
            )

        references = self.storages[storage]
        references.remove(tensor_id)
        if len(references) == 0:
            self.storages.pop(storage)
            addr = self.storage_addresses.pop(storage)
            if self.fake_tensor_tracker.is_borrowed(tensor_id):
                # The controller creates a new FakeTensor for Borrow.
                # So we should not count the storage usage of this
                # FakeTensor as it is not a materialized tensor on
                # the workers.
                return None
            return StorageEvent(addr, storage.size())
        return None

    def clone(self) -> "WorkerStorageTracker":
        """Return an independent copy; only ``fake_tensor_tracker`` is shared."""
        ret = WorkerStorageTracker(self.fake_tensor_tracker)
        # Copy each reference set: a shallow dict copy would let decr_ref on
        # the clone mutate this tracker's sets.
        ret.storages = {
            storage: set(refs) for storage, refs in self.storages.items()
        }
        # Without the addresses, decr_ref on the clone would raise KeyError
        # for any storage tracked before cloning.
        ret.storage_addresses = copy.copy(self.storage_addresses)
        # Waste one address so both trackers continue with the same next
        # address instead of the clone restarting at 0 (address collisions).
        next_addr = next(self._addr_counter)
        ret._addr_counter = count(next_addr + self._ADDR_STEP, step=self._ADDR_STEP)
        return ret
|
174
|
+
|
175
|
+
|
176
|
+
class MemoryEvent(NamedTuple):
    """
    A change in a stream's memory usage at ``timestamp``.

    Events compare, test equal and hash by ``(timestamp, delta)`` only, so at
    equal timestamps the smaller delta (e.g. a release, whose delta is
    negative) sorts first. ``address`` and ``traceback`` are ignored by all
    comparisons.
    """

    timestamp: int
    address: int
    delta: int
    traceback: Sequence[traceback.FrameSummary]

    def _order_key(self):
        return (self.timestamp, self.delta)

    # All six comparisons are defined explicitly: otherwise the ones not
    # overridden fall back to tuple semantics over every field, which
    # contradicts __eq__/__lt__ (and may compare tracebacks).
    def __lt__(self, other):
        if not isinstance(other, MemoryEvent):
            return NotImplemented
        return self._order_key() < other._order_key()

    def __le__(self, other):
        if not isinstance(other, MemoryEvent):
            return NotImplemented
        return self._order_key() <= other._order_key()

    def __gt__(self, other):
        if not isinstance(other, MemoryEvent):
            return NotImplemented
        return self._order_key() > other._order_key()

    def __ge__(self, other):
        if not isinstance(other, MemoryEvent):
            return NotImplemented
        return self._order_key() >= other._order_key()

    def __eq__(self, other):
        if not isinstance(other, MemoryEvent):
            return NotImplemented
        return self._order_key() == other._order_key()

    def __ne__(self, other):
        if not isinstance(other, MemoryEvent):
            return NotImplemented
        return self._order_key() != other._order_key()

    def __hash__(self):
        # Must agree with __eq__: equal events hash equally.
        return hash(self._order_key())
|
194
|
+
|
195
|
+
|
196
|
+
class StreamMemoryTracker:
    """
    Tracks the memory events (timestamp, usage_delta) of a stream. The usage
    may not be added in the correct time order due to the asynchronous
    simulated-execution of worker CPU thread and the stream thread. Thus a
    heap is used to sort the events by timestamp.
    """

    def __init__(self, storage_tracker: WorkerStorageTracker) -> None:
        # Sum of the sizes of the storages currently charged to this stream.
        self.usage = 0
        # Min-heap of MemoryEvents (ordered by timestamp, then delta).
        self.events: List[MemoryEvent] = []
        self.storage_tracker = storage_tracker
        # address -> timestamp at which the still-live allocation was recorded.
        self._tracked_addresses: Dict[int, int] = {}

    def incr_ref(
        self, ts: int, tensor_id, traceback: Optional[Sequence[traceback.FrameSummary]]
    ) -> None:
        """Reference ``tensor_id`` at ``ts``; records an allocation when the storage tracker reports one."""
        storage_event = self.storage_tracker.incr_ref(tensor_id)
        delta = 0 if storage_event is None else storage_event.delta
        logger.debug(
            f"StreamMemoryTracker got {tensor_id} at {ts} and delta is {delta}."
        )
        # Some operators may return zero-size tensors, for example
        # torch.ops.aten._scaled_dot_product_flash_attention.default, so
        # zero-delta events are not recorded.
        if storage_event is not None and storage_event.delta != 0:
            assert ts >= 0
            assert traceback is not None
            self._add_usage(ts, storage_event, traceback)

    def decr_ref(
        self, ts: int, tensor_id, traceback: Optional[Sequence[traceback.FrameSummary]]
    ) -> None:
        """Drop a reference to ``tensor_id`` at ``ts``; records a release when the storage tracker reports one."""
        storage_event = self.storage_tracker.decr_ref(tensor_id)
        if storage_event is not None and storage_event.delta != 0:
            assert ts >= 0
            assert traceback is not None
            self._remove_usage(ts, storage_event, traceback)

    def _remove_usage(self, ts: int, storage_event: StorageEvent, traceback) -> None:
        """
        Push a negative-delta event for ``storage_event``.

        Raises:
            RuntimeError: if the address is not tracked, or it was recorded
                at or after ``ts``.
        """
        assert storage_event.delta <= self.usage
        self.usage -= storage_event.delta
        recorded_ts = self._tracked_addresses.pop(storage_event.address, -1)
        if recorded_ts == -1:
            raise RuntimeError(f"Cannot find the address {storage_event.address}")
        if recorded_ts >= ts:
            raise RuntimeError(
                f"The address {storage_event.address} is allocated after being freed"
            )
        heapq.heappush(
            self.events,
            MemoryEvent(ts, storage_event.address, -storage_event.delta, traceback),
        )

    def _add_usage(self, ts: int, storage_event: StorageEvent, traceback) -> None:
        """Push a positive-delta event for ``storage_event`` and remember its address."""
        self.usage += storage_event.delta
        self._tracked_addresses[storage_event.address] = ts
        heapq.heappush(
            self.events,
            MemoryEvent(ts, storage_event.address, storage_event.delta, traceback),
        )

    def pop_event(self) -> MemoryEvent:
        """Pop the earliest pending event."""
        return heapq.heappop(self.events)

    def clone(self, storage_tracker: WorkerStorageTracker) -> "StreamMemoryTracker":
        """Return a copy bound to ``storage_tracker`` with the same usage, events and live addresses."""
        ret = StreamMemoryTracker(storage_tracker)
        ret.usage = self.usage
        # Events are immutable tuples, so a shallow copy of the heap suffices.
        ret.events = copy.copy(self.events)
        # Without the live addresses, releasing any allocation made before the
        # clone would raise "Cannot find the address" on the copy.
        ret._tracked_addresses = copy.copy(self._tracked_addresses)
        return ret
|
266
|
+
|
267
|
+
|
268
|
+
class TensorManager:
    """
    Tracks the tensors created in a worker or a stream. A tensor is either a
    CPU tensor, which can only be owned by the worker, or a GPU tensor, which
    can only be owned by a stream.

    A tensor is released only when it has no remaining references AND the
    controller has asked to delete it (``delete``).
    """

    def __init__(
        self,
        fake_tensor_tracker: FakeTensorTracker,
        memory: Optional[StreamMemoryTracker],
    ) -> None:
        # tensor id -> outstanding references (Tasks or ints).
        self.tensors: Dict[int, Set[Union[Task, int]]] = {}
        # Traceback of the last reference drop for tensors that lost all
        # references before the controller asked to delete them.
        self.delete_tracebacks: Dict[
            int, Optional[Sequence[traceback.FrameSummary]]
        ] = {}
        # Tensors the controller asked to delete that are not released yet.
        self.pending_delete_tensors: Set[int] = set()
        self.memory = memory
        self.fake_tensor_tracker = fake_tensor_tracker

    def add(self, tensor_id: int, refs: Tuple[Union[Task, int], ...], now: int) -> None:
        """Start tracking ``tensor_id`` with the initial ``refs`` (``now`` is only logged)."""
        logger.debug(f"TensorManager got {tensor_id} at {now}.")
        self.tensors[tensor_id] = set(refs)
        self.fake_tensor_tracker.incr_ref(tensor_id)

    def first_use(
        self,
        tensor_id: int,
        now: int,
        traceback: Optional[Sequence[traceback.FrameSummary]],
    ) -> None:
        """Charge ``tensor_id``'s storage to the memory tracker, if any, at ``now``."""
        logger.debug(f"TensorManager: {tensor_id} is first used")
        if self.memory:
            self.memory.incr_ref(now, tensor_id, traceback)

    def incr_ref(self, tensor_id: int, ref: Union[Task, int]) -> None:
        """Add ``ref`` as a reference to ``tensor_id``."""
        logger.debug(f"TensorManager: {tensor_id} is referenced.")
        self.tensors[tensor_id].add(ref)

    def decr_ref(
        self,
        tensor_id: int,
        ref: Union[Task, int],
        now: int,
        traceback: Optional[Sequence[traceback.FrameSummary]],
    ) -> None:
        """Remove ``ref`` from ``tensor_id`` and release the tensor if possible."""
        logger.debug(f"TensorManager: {tensor_id} decr_ref.")
        self.tensors[tensor_id].remove(ref)
        self._maybe_delete_tensor(tensor_id, now, traceback)

    def delete(
        self,
        tensor_id: int,
        now: int,
        traceback: Optional[Sequence[traceback.FrameSummary]],
    ) -> None:
        """Mark ``tensor_id`` for deletion and release it once unreferenced."""
        self.pending_delete_tensors.add(tensor_id)
        self._maybe_delete_tensor(tensor_id, now, traceback)

    def __contains__(self, key: int) -> bool:
        return key in self.tensors

    def _maybe_delete_tensor(
        self,
        tensor_id: int,
        now: int,
        traceback: Optional[Sequence[traceback.FrameSummary]],
    ) -> None:
        if self.tensors[tensor_id]:
            return

        if tensor_id not in self.pending_delete_tensors:
            # While no one is using this tensor, Controller has not
            # asked us to delete the tensor. Track the traceback of
            # the last task.
            self.delete_tracebacks[tensor_id] = traceback
            return

        # Always drop the stored traceback so the entry does not leak; prefer
        # the traceback of the current call when one is given.
        stored_traceback = self.delete_tracebacks.pop(tensor_id, None)
        if traceback is None:
            traceback = stored_traceback

        if self.memory:
            self.memory.decr_ref(now, tensor_id, traceback)

        self.tensors.pop(tensor_id)
        self.fake_tensor_tracker.decr_ref(tensor_id)
        self.pending_delete_tensors.remove(tensor_id)

    @staticmethod
    def _remap_ref(
        ref: Union[Task, int], task_manager: WorkerTaskManager
    ) -> Union[Task, int]:
        """Replace a Task reference with its counterpart in ``task_manager``."""
        if isinstance(ref, Task):
            assert ref.task_id is not None
            return task_manager.tasks[ref.task_id]
        return ref

    def clone(
        self, task_manager: WorkerTaskManager, memory: Optional[StreamMemoryTracker]
    ) -> "TensorManager":
        """
        Return a copy whose Task references point into ``task_manager``.

        NOTE(review): the clone shares ``fake_tensor_tracker`` without taking
        extra references, so deletions on the clone also decrement the shared
        counts; confirm this is intended.
        """
        ret = TensorManager(self.fake_tensor_tracker, memory)
        ret.pending_delete_tensors = copy.copy(self.pending_delete_tensors)
        ret.delete_tracebacks = copy.copy(self.delete_tracebacks)
        for tensor_id, refs in self.tensors.items():
            ret.tensors[tensor_id] = {
                self._remap_ref(ref, task_manager) for ref in refs
            }
        return ret
|