torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,255 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import copy
10
+ import itertools
11
+ import traceback
12
+ from dataclasses import dataclass
13
+ from enum import auto, Enum
14
+ from typing import cast, Dict, List, Optional, Sequence
15
+
16
+ from monarch.simulator.config import META_VAL
17
+
18
+
19
class TaskState(Enum):
    """States a simulated ``Task`` moves through, listed in transition order."""

    # Initial state assigned when a task is constructed.
    PENDING = auto()
    # Entered from PENDING once every dependency is EXECUTED.
    READY = auto()
    # Entered from READY when the task starts executing.
    EXECUTING = auto()
    # Terminal state; the task's end time is computed on entry.
    EXECUTED = auto()
24
+
25
+
26
class Task:
    """
    A unit of simulated work on a stream.

    The task moves through ``TaskState`` in order: ``maybe_set_ready`` takes it
    from PENDING to READY once every dependency is EXECUTED, ``maybe_execute``
    takes it from READY to EXECUTING, and ``maybe_finish`` takes it from
    EXECUTING to EXECUTED once every task in ``collectives`` is at least
    EXECUTING and every task in ``waits`` is EXECUTED.

    Args:
        inputs (List[int]): A list of input tensor ids.
        outputs (List[int]): A list of output tensor ids.
        command_id (int): The id of the command this task executes.
        start_time (int): Initial start time. It is raised to the latest
            dependency end time while the task is being marked ready.
        runtime (int): The runtime of the task in nanoseconds.
        meta (List[str]): Metadata strings; ``META_VAL`` is appended to them.
        collectives (Optional[List[Task]]): List of the tasks associated with
            the same collective. This task appends itself to the list it is
            given. Defaults to None.
        waits (Optional[List[Task]]): Tasks that must be EXECUTED before this
            task can finish. Defaults to None.
        traceback (Sequence[traceback.FrameSummary]): Stack frames stored on
            the task. Defaults to an empty tuple.
    """

    def __init__(
        self,
        inputs: List[int],
        outputs: List[int],
        command_id: int,
        start_time: int,
        runtime: int,
        meta: List[str],
        collectives: Optional[List["Task"]] = None,
        waits: Optional[List["Task"]] = None,
        traceback: Sequence[traceback.FrameSummary] = (),
    ):
        # Identity and payload of the task.
        self.command_id = command_id
        self.inputs = inputs
        self.outputs = outputs
        self.meta = meta + META_VAL
        self.traceback = traceback

        # Relations to other tasks.
        self.dependencies = []
        self.waits = waits
        self.collectives = collectives
        if collectives is not None:
            # The list is owned by the caller; registering here makes this
            # task visible to every other holder of the same list.
            collectives.append(self)

        # Timing and lifecycle.
        self.runtime = runtime
        self.start_time = start_time
        self.end_time = 0
        self._state = TaskState.PENDING

        # Assigned by WorkerTaskManager.
        self.task_id: Optional[int] = None

    def __repr__(self):
        return " ".join(self.meta)

    @property
    def state(self) -> TaskState:
        """Current lifecycle state of the task."""
        return self._state

    def maybe_set_ready(self) -> bool:
        """
        Move the task from PENDING to READY when all dependencies are EXECUTED.

        ``start_time`` is pushed forward to the end time of each executed
        dependency that is visited. Returns True only when the state changes
        from PENDING to READY.
        """
        if self._state != TaskState.PENDING:
            return False

        for dep in self.dependencies or ():
            if dep._state != TaskState.EXECUTED:
                return False
            self.start_time = max(self.start_time, dep.end_time)

        self._state = TaskState.READY
        return True

    def maybe_execute(self) -> bool:
        """
        Move the task from READY to EXECUTING.

        Returns True only when the state changes from READY to EXECUTING.
        """
        is_ready = self._state == TaskState.READY
        if is_ready:
            self._state = TaskState.EXECUTING
        return is_ready

    def maybe_finish(self) -> bool:
        """
        Move the task from EXECUTING to EXECUTED and compute ``end_time``.

        The transition happens only when every associated collective task, if
        any, is EXECUTING or EXECUTED and every waited-on task, if any, is
        EXECUTED. Returns True only when the state changes from EXECUTING to
        EXECUTED.
        """
        if self._state != TaskState.EXECUTING:
            return False

        running_or_done = (TaskState.EXECUTING, TaskState.EXECUTED)
        if self.collectives and not all(
            peer.state in running_or_done for peer in self.collectives
        ):
            return False
        if self.waits and not all(
            waited.state == TaskState.EXECUTED for waited in self.waits
        ):
            return False

        self._state = TaskState.EXECUTED

        if self.collectives:
            # A collective cannot complete before its last participant starts.
            latest_start = max(peer.start_time for peer in self.collectives)
            self.end_time = latest_start + self.runtime
        if self.waits:
            latest_wait_end = max(waited.end_time for waited in self.waits)
            self.end_time = max(self.end_time, latest_wait_end)

        if self.meta[0] == "aten.view":
            # TODO: this is a workaround to removing `view` from the trace.
            # What we really should do is to have the CPU trace besides GPU trace.
            self.end_time = self.start_time
        else:
            self.end_time = max(self.end_time, self.start_time + self.runtime)

        return True

    def clone(self) -> "Task":
        """Return a shallow copy; list attributes are shared with the original."""
        return copy.copy(self)
148
+
149
+
150
@dataclass
class Borrow:
    """
    Plain record describing a borrow.

    NOTE(review): field meanings below are inferred from the field names only;
    confirm against the code that constructs ``Borrow`` objects.
    """

    # Identifier of the borrow.
    ident: int
    # Id of the source tensor.
    tensor_src_id: int
    # Id of the destination tensor.
    tensor_dst_id: int
    # Stream the tensor is borrowed from.
    from_stream: int
    # Stream the tensor is borrowed to.
    to_stream: int
157
+
158
+
159
class EventTask(Task):
    """
    Represents an event task in a stream.

    The task has no tensor inputs or outputs; it lists ``recorded_task`` as the
    single task it waits on, so it cannot finish before that task is executed.

    Args:
        recorded_task (Task): The task this event waits for.
        event_stream (int): Id stored as ``event_stream``.
        event_stream_name (str): Name stored on the task and used in its meta.
        wait_stream (int): Id stored as ``wait_stream``.
        wait_stream_name (str): Name stored as ``wait_stream_name``.
        start_time (int): Initial start time of the task.
        command_id (int): The id of the command this task executes.
        runtime (int): The runtime of the task. Defaults to 1.
        borrow (Optional[Borrow]): Borrow stored on the task. Defaults to None.
        traceback (Sequence[traceback.FrameSummary]): Stack frames stored on
            the task. Defaults to an empty tuple.
    """

    def __init__(
        self,
        recorded_task: Task,
        event_stream: int,
        event_stream_name: str,
        wait_stream: int,
        wait_stream_name: str,
        start_time: int,
        command_id: int,
        runtime: int = 1,
        borrow: Optional[Borrow] = None,
        traceback: Sequence[traceback.FrameSummary] = (),
    ):
        base_kwargs = {
            "inputs": [],
            "outputs": [],
            "command_id": command_id,
            "start_time": start_time,
            "runtime": runtime,
            "meta": ["waiting for", event_stream_name],
            "waits": [recorded_task],
            "traceback": traceback,
        }
        super().__init__(**base_kwargs)

        # Event-specific bookkeeping.
        self.borrow = borrow
        self.event_stream, self.event_stream_name = event_stream, event_stream_name
        self.wait_stream, self.wait_stream_name = wait_stream, wait_stream_name

    def clone(self) -> "EventTask":
        """Return a shallow copy of this event task."""
        return copy.copy(self)
193
+
194
+
195
class WorkerTaskManager:
    """
    Registry of the tasks tracked for one worker, keyed by task id.

    Ids are handed out from a monotonically increasing counter when a task is
    added. The manager is a container of tasks, not a task itself, so it does
    not derive from ``Task``.
    """

    def __init__(self) -> None:
        self.tasks: Dict[int, "Task"] = {}
        # Iterator producing the next task id to hand out.
        self.task_id = itertools.count()

    def add(self, task: "Task") -> int:
        """Register ``task``, store the new id on it, and return that id."""
        task_id = next(self.task_id)
        self.tasks[task_id] = task
        task.task_id = task_id
        return task_id

    def remove(self, task: "Task") -> None:
        """
        Drop ``task`` from the registry.

        Raises:
            ValueError: If the task was never assigned an id.
            KeyError: If the task's id is not currently registered.
        """
        if (task_id := task.task_id) is not None:
            self.tasks.pop(task_id)
        else:
            raise ValueError("task_id is None")

    def clone(self) -> "WorkerTaskManager":
        """
        Return a copy of the manager whose tasks are clones of this one's.

        Dependencies and waits of each cloned task are re-pointed at the
        cloned counterparts. Both managers hand out the same next task id
        after the call.
        """
        cloned_tasks: Dict[int, "Task"] = {}

        def _counterpart(original: "Task") -> "Task":
            # Both dependencies and waits are all tasks on the same worker
            # thread. Thus, they must be in the same WorkerTaskManager or
            # they must be executed.
            twin = cloned_tasks.get(original.task_id)
            if twin is not None:
                return twin
            # The task is executed, so it is not in the WorkerTaskManager.
            # Clone it so the reference is cloned, but do not add it to the
            # new WorkerTaskManager.
            assert original.state == TaskState.EXECUTED
            return original.clone()

        for task_id, task in self.tasks.items():
            cloned_task = task.clone()
            cloned_tasks[task_id] = cloned_task

            # Always build fresh lists: the shallow clone would otherwise
            # share its list objects with the original task.
            if task.dependencies is not None:
                cloned_task.dependencies = [
                    _counterpart(dep) for dep in task.dependencies
                ]
            if task.waits is not None:
                cloned_task.waits = [_counterpart(wait) for wait in task.waits]

            # TODO: the global list shared by all the tasks with the same collective
            # is a neat idea but can be hard to debug. Consider make it more explicit.
            if cloned_task.collectives:
                cloned_task.collectives.append(cloned_task)

        ret = WorkerTaskManager()
        # Waste one to ensure all the cloned WorkerTaskManager has the same task_id.
        next_task_id = next(self.task_id)
        ret.task_id = itertools.count(next_task_id + 1)
        ret.tasks = cloned_tasks
        return ret
@@ -0,0 +1,373 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import copy
9
+ import heapq
10
+ import logging
11
+ import traceback
12
+ from collections import defaultdict
13
+ from itertools import count
14
+ from typing import Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
15
+
16
+ import torch
17
+ from monarch.common.fake import fake_call
18
+ from monarch.common.tensor_factory import TensorFactory
19
+ from monarch.simulator.task import Task, WorkerTaskManager
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class DTensorRef:
    """
    A reference to a `controller.tensor.Tensor` object.

    This class is used to keep track of DTensor objects that have been created
    by the controller and to provide the mechanism to serialize DTensor
    objects (torch.save/torch.load).

    Only ``ref`` and ``factory`` are serialized; the fake tensor and the
    storage information derived from it are rebuilt when the object is
    restored.
    """

    # Cache of references created through ``from_ref``, keyed by ``tensor.ref``.
    created: Dict[int, "DTensorRef"] = {}

    def __init__(self, tensor):
        """
        Build a reference from ``tensor``.

        ``tensor`` must expose ``ref`` and be accepted by
        ``TensorFactory.from_tensor``; its optional ``_fake`` attribute is
        used as the backing fake tensor.
        """
        self.ref = tensor.ref
        self.factory = TensorFactory.from_tensor(tensor)
        self._fake: Optional[torch._subclasses.FakeTensor] = getattr(
            tensor, "_fake", None
        )
        self._storage_id: Optional[torch.types._int] = None
        self._size: Optional[int] = None
        self._refresh_storage_info()

    def _refresh_storage_info(self) -> None:
        """Recompute ``_storage_id`` and ``_size`` from the current ``_fake``."""
        if self._fake is None:
            self._storage_id = None
            self._size = None
            return
        storage = self._fake.untyped_storage()
        self._storage_id = id(storage)
        self._size = storage.size()

    def __repr__(self):
        return f"DTensorRef({self.ref})"

    @classmethod
    def from_ref(cls, tensor) -> "DTensorRef":
        """Return the cached reference for ``tensor.ref``, creating it if needed."""
        if tensor.ref not in cls.created:
            cls.created[tensor.ref] = cls(tensor)
        return cls.created[tensor.ref]

    def __getstate__(self):
        # The fake tensor is deliberately not serialized; ``__setstate__``
        # recreates it from the factory.
        return {
            "ref": self.ref,
            "factory": self.factory,
            "_fake": None,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._fake = fake_call(self.factory.zeros)
        # The serialized state carries no storage information, so derive it
        # from the recreated fake tensor exactly as ``__init__`` does.
        # Without this, restored objects would lack both attributes.
        self._refresh_storage_info()

    def __deepcopy__(self, memo):
        """
        Return a new reference backed by a freshly created fake tensor.

        Raises:
            RuntimeError: If this reference has no fake tensor.
        """
        if self._fake is None:
            raise RuntimeError(
                f"Cannot deepcopy {self!r}: it has no fake tensor to copy."
            )

        fake = fake_call(self.factory.zeros)
        # ``__init__`` reads ``_fake`` and ``ref`` from its argument, so attach
        # both to the new fake tensor before constructing the copy.
        fake._fake = fake
        fake.ref = self.ref
        return self.__class__(fake)
78
+
79
+
80
class FakeTensorTracker:
    """
    Shared registry of the fake tensors used by the simulator.

    Workers and streams each keep their own tensor bookkeeping, but they do
    not need a separate FakeTensor each: a single fake tensor per tensor id is
    stored here and shared, which reduces simulation time.

    Entries are reference counted; a fake tensor is dropped once its count
    returns to zero.
    """

    def __init__(self) -> None:
        # tensor id -> fake tensor
        self.tensors: Dict[int, torch._subclasses.FakeTensor] = {}
        # tensor id -> number of outstanding references
        self._ref: Dict[int, int] = defaultdict(int)
        # ids of tensors that were registered as borrowed
        self._borrowed_tensors: Set[int] = set()

    def add(
        self, tensors: Dict[int, torch._subclasses.FakeTensor], is_borrowed=False
    ) -> None:
        """Register ``tensors``; mark all of their ids borrowed if requested."""
        self.tensors.update(tensors)
        if not is_borrowed:
            return
        for tensor_id in tensors:
            self._borrowed_tensors.add(tensor_id)

    def is_borrowed(self, tensor: int) -> bool:
        """Return True if the tensor id was registered as borrowed."""
        return tensor in self._borrowed_tensors

    def incr_ref(self, tensor_id: int) -> None:
        """Add one reference to a tensor that has already been registered."""
        assert tensor_id in self.tensors, f"Tensor {tensor_id} is not created"
        self._ref[tensor_id] = self._ref[tensor_id] + 1

    def decr_ref(self, tensor_id: int):
        """Drop one reference; forget the tensor when none remain."""
        remaining = self._ref[tensor_id] - 1
        assert remaining >= 0, f"Tensor {tensor_id} has negative ref count {remaining}"
        if remaining != 0:
            self._ref[tensor_id] = remaining
            return
        self.tensors.pop(tensor_id)
        self._ref.pop(tensor_id)
118
+
119
+
120
class StorageEvent(NamedTuple):
    """A storage allocation or release reported by ``WorkerStorageTracker``."""

    # Simulated address assigned to the storage.
    address: int
    # Size of the storage as returned by ``storage.size()``.
    delta: int


class WorkerStorageTracker:
    """
    Tracks which tensor ids reference each storage and assigns every newly
    seen storage a simulated address.

    Fake tensors are looked up in the supplied ``fake_tensor_tracker``; a
    storage is considered alive while at least one tensor id references it.
    """

    # Distance between two consecutively assigned simulated addresses.
    _ADDRESS_STEP = 128

    def __init__(self, fake_tensor_tracker) -> None:
        # storage -> ids of the tensors currently referencing it
        self.storages: Dict[torch.UntypedStorage, Set[int]] = {}
        self.fake_tensor_tracker = fake_tensor_tracker
        self._addr_counter = count(
            step=self._ADDRESS_STEP
        )  # aligning 128-byte cache lines?
        # storage -> simulated address assigned when it was first seen
        self.storage_addresses: Dict[torch.UntypedStorage, int] = {}

    def incr_ref(self, tensor_id: int) -> Optional[StorageEvent]:
        """
        Record that ``tensor_id`` references its storage.

        Returns:
            A ``StorageEvent`` with the new address and the storage size when
            the storage is seen for the first time and the tensor is not
            borrowed; otherwise None.
        """
        fake = self.fake_tensor_tracker.tensors[tensor_id]
        storage = fake.untyped_storage()
        if storage in self.storages:
            self.storages[storage].add(tensor_id)
            return None

        self.storages[storage] = {tensor_id}
        addr = next(self._addr_counter)
        self.storage_addresses[storage] = addr
        if self.fake_tensor_tracker.is_borrowed(tensor_id):
            return None  # Q: should self._addr_counter be reversed?
        return StorageEvent(addr, storage.size())

    def decr_ref(self, tensor_id: int) -> Optional[StorageEvent]:
        """
        Record that ``tensor_id`` no longer references its storage.

        Returns:
            A ``StorageEvent`` with the storage's address and size when the
            last reference is dropped and the tensor is not borrowed;
            otherwise None.

        Raises:
            RuntimeError: If the storage is not tracked.
            KeyError: If ``tensor_id`` is not among the storage's references.
        """
        fake = self.fake_tensor_tracker.tensors[tensor_id]
        storage = fake.untyped_storage()
        if storage not in self.storages:
            raise RuntimeError(
                f"{storage} is being dereferenced but it is not tracked."
            )

        references = self.storages[storage]
        references.remove(tensor_id)
        if references:
            return None

        self.storages.pop(storage)
        addr = self.storage_addresses.pop(storage)
        if self.fake_tensor_tracker.is_borrowed(tensor_id):
            # The controller creates a new FakeTensor for Borrow.
            # So we should not count the storage usage of this
            # FakeTensor as it is not a materialized tensor on
            # the workers.
            return None
        return StorageEvent(addr, storage.size())

    def clone(self) -> "WorkerStorageTracker":
        """
        Return an independent copy that shares only ``fake_tensor_tracker``.

        The reference sets and the address table are copied so that changes
        made through the clone never leak into this tracker, and so that the
        clone can release storages allocated before the clone was taken.
        """
        ret = WorkerStorageTracker(self.fake_tensor_tracker)
        ret.storages = {
            storage: set(tensor_ids) for storage, tensor_ids in self.storages.items()
        }
        ret.storage_addresses = dict(self.storage_addresses)
        # Waste one address so that both trackers hand out the same next
        # address and the clone never reuses an address already assigned.
        wasted_addr = next(self._addr_counter)
        ret._addr_counter = count(
            wasted_addr + self._ADDRESS_STEP, step=self._ADDRESS_STEP
        )
        return ret
174
+
175
+
176
class MemoryEvent(NamedTuple):
    """
    A memory usage change at a point in time.

    Ordering and equality look only at ``(timestamp, delta)``; ``address`` and
    ``traceback`` take no part in the ``<``, ``>`` and ``==`` comparisons
    defined here.
    """

    timestamp: int
    address: int
    delta: int
    traceback: Sequence[traceback.FrameSummary]

    def __lt__(self, other):
        # Earlier timestamp first; ties are broken by the smaller delta.
        return (self.timestamp, self.delta) < (other.timestamp, other.delta)

    def __gt__(self, other):
        # Later timestamp first; ties are broken by the larger delta.
        return (self.timestamp, self.delta) > (other.timestamp, other.delta)

    def __eq__(self, other):
        if self.timestamp != other.timestamp:
            return False
        return self.delta == other.delta
194
+
195
+
196
class StreamMemoryTracker:
    """
    Tracks the memory events (timestamp, usage_delta) of a stream. The usage
    may not be added in the correct time order due to the asynchronous
    simulated-execution of worker CPU thread and the stream thread. Thus a
    heap is used to sort the events by timestamp.
    """

    def __init__(self, storage_tracker: WorkerStorageTracker) -> None:
        # Running total of the deltas added minus the deltas removed.
        self.usage = 0
        # Heap of memory events, ordered by ``MemoryEvent`` comparison.
        self.events: List[MemoryEvent] = []
        self.storage_tracker = storage_tracker
        # address -> timestamp at which the address was recorded as allocated
        self._tracked_addresses: Dict[int, int] = {}

    def incr_ref(
        self, ts: int, tensor_id, traceback: Optional[Sequence[traceback.FrameSummary]]
    ) -> None:
        """
        Reference ``tensor_id`` at time ``ts`` and record an allocation event
        when the storage tracker reports a new, non-empty storage.
        """
        storage_event = self.storage_tracker.incr_ref(tensor_id)
        delta = 0 if storage_event is None else storage_event.delta
        logger.debug(
            "StreamMemoryTracker got %s at %s and delta is %s.", tensor_id, ts, delta
        )
        # Some operators may return zero-size tensors.
        # One example is
        # torch.ops.aten._scaled_dot_product_flash_attention.default
        if storage_event is not None and storage_event.delta != 0:
            assert ts >= 0
            assert traceback is not None
            self._add_usage(ts, storage_event, traceback)

    def decr_ref(
        self, ts: int, tensor_id, traceback: Optional[Sequence[traceback.FrameSummary]]
    ) -> None:
        """
        Dereference ``tensor_id`` at time ``ts`` and record a release event
        when the storage tracker reports a freed, non-empty storage.
        """
        storage_event = self.storage_tracker.decr_ref(tensor_id)
        if storage_event is not None and storage_event.delta != 0:
            assert ts >= 0
            assert traceback is not None
            self._remove_usage(ts, storage_event, traceback)

    def _remove_usage(self, ts: int, storage_event: StorageEvent, traceback) -> None:
        """
        Subtract the event's delta from the usage and push a negative event.

        Raises:
            RuntimeError: If the address was never recorded as allocated, or
                if its recorded allocation time is not before ``ts``.
        """
        assert storage_event.delta <= self.usage
        self.usage -= storage_event.delta
        recorded_ts = self._tracked_addresses.pop(storage_event.address, -1)
        if recorded_ts == -1:
            raise RuntimeError(f"Cannot find the address {storage_event.address}")
        if recorded_ts >= ts:
            raise RuntimeError(
                f"The address {storage_event.address} is allocated after being freed"
            )
        heapq.heappush(
            self.events,
            MemoryEvent(ts, storage_event.address, -storage_event.delta, traceback),
        )

    def _add_usage(self, ts: int, storage_event: StorageEvent, traceback) -> None:
        """Add the event's delta to the usage and push a positive event."""
        self.usage += storage_event.delta
        self._tracked_addresses[storage_event.address] = ts
        heapq.heappush(
            self.events,
            MemoryEvent(ts, storage_event.address, storage_event.delta, traceback),
        )

    def pop_event(self) -> MemoryEvent:
        """Remove and return the smallest event on the heap."""
        return heapq.heappop(self.events)

    def clone(self, storage_tracker: WorkerStorageTracker) -> "StreamMemoryTracker":
        """
        Return a copy bound to ``storage_tracker``.

        The usage, the event heap and the table of live addresses are all
        copied, so storages allocated before the clone can still be released
        through the clone.
        """
        ret = StreamMemoryTracker(storage_tracker)
        ret.usage = self.usage
        ret.events = copy.copy(self.events)
        ret._tracked_addresses = dict(self._tracked_addresses)
        return ret
266
+
267
+
268
class TensorManager:
    """
    Tracks the tensor created in a worker or a stream. It can be CPU tensor,
    which can only be owned by the worker or a gpu tensor which can only be
    owned by a stream.

    Each tensor id maps to the set of references (tasks or ints) that keep it
    alive. A tensor is released once that set is empty and ``delete`` has been
    requested for it.
    """

    def __init__(
        self,
        fake_tensor_tracker: FakeTensorTracker,
        memory: Optional[StreamMemoryTracker],
    ) -> None:
        # tensor id -> references that currently keep the tensor alive
        self.tensors: Dict[int, Set[Union[Task, int]]] = {}
        # tensor id -> traceback seen when the last reference was dropped
        # before deletion was requested
        self.delete_tracebacks: Dict[
            int, Optional[Sequence[traceback.FrameSummary]]
        ] = {}
        # ids for which ``delete`` was requested while references remained
        self.pending_delete_tensors: Set[int] = set()
        self.memory = memory
        self.fake_tensor_tracker = fake_tensor_tracker

    def add(self, tensor_id: int, refs: Tuple[Union[Task, int], ...], now: int) -> None:
        """Start tracking ``tensor_id`` with the initial references ``refs``."""
        logger.debug("TensorManager got %s at %s.", tensor_id, now)
        self.tensors[tensor_id] = set(refs)
        self.fake_tensor_tracker.incr_ref(tensor_id)

    def first_use(
        self,
        tensor_id: int,
        now: int,
        traceback: Optional[Sequence[traceback.FrameSummary]],
    ) -> None:
        """Report the first use of ``tensor_id`` to the memory tracker, if any."""
        logger.debug("TensorManager: %s is first used", tensor_id)
        if self.memory:
            self.memory.incr_ref(now, tensor_id, traceback)

    def incr_ref(self, tensor_id: int, ref: Union[Task, int]) -> None:
        """Add ``ref`` to the references of ``tensor_id``."""
        logger.debug("TensorManager: %s is referenced.", tensor_id)
        self.tensors[tensor_id].add(ref)

    def decr_ref(
        self,
        tensor_id: int,
        ref: Union[Task, int],
        now: int,
        traceback: Optional[Sequence[traceback.FrameSummary]],
    ) -> None:
        """Remove ``ref`` from ``tensor_id`` and release the tensor if possible."""
        logger.debug("TensorManager: %s decr_ref.", tensor_id)
        self.tensors[tensor_id].remove(ref)
        self._maybe_delete_tensor(tensor_id, now, traceback)

    def delete(
        self,
        tensor_id: int,
        now: int,
        traceback: Optional[Sequence[traceback.FrameSummary]],
    ) -> None:
        """Request deletion of ``tensor_id``; it is released once unreferenced."""
        self.pending_delete_tensors.add(tensor_id)
        self._maybe_delete_tensor(tensor_id, now, traceback)

    def __contains__(self, key: int) -> bool:
        return key in self.tensors

    def _maybe_delete_tensor(
        self,
        tensor_id: int,
        now: int,
        traceback: Optional[Sequence[traceback.FrameSummary]],
    ) -> None:
        """Release ``tensor_id`` when it is unreferenced and pending deletion."""
        if len(self.tensors[tensor_id]) > 0:
            return

        if tensor_id not in self.pending_delete_tensors:
            # While no one is using this tensor, Controller has not
            # asked us to delete the tensor. Track the traceback of
            # the last task.
            self.delete_tracebacks[tensor_id] = traceback
            return

        # Fall back to the traceback remembered when the last reference was
        # dropped if the caller did not supply one.
        traceback = (
            traceback
            if traceback is not None
            else self.delete_tracebacks.pop(tensor_id, None)
        )

        if self.memory:
            self.memory.decr_ref(now, tensor_id, traceback)

        self.tensors.pop(tensor_id)
        self.fake_tensor_tracker.decr_ref(tensor_id)
        self.pending_delete_tensors.remove(tensor_id)

    def clone(
        self, task_manager: WorkerTaskManager, memory: Optional[StreamMemoryTracker]
    ) -> "TensorManager":
        """
        Return a copy bound to ``memory`` whose task references point at the
        tasks held by ``task_manager``.

        Task references are looked up in ``task_manager.tasks`` by task id;
        int references are carried over unchanged.
        """
        ret = TensorManager(self.fake_tensor_tracker, memory)
        ret.pending_delete_tensors = copy.copy(self.pending_delete_tensors)
        # Carry the remembered tracebacks over as well; without them a later
        # delete on the clone would reach the memory tracker with no traceback.
        ret.delete_tracebacks = dict(self.delete_tracebacks)
        for tensor_id, refs in self.tensors.items():
            new_refs = set()
            for ref in refs:
                if isinstance(ref, Task):
                    assert ref.task_id is not None
                    new_refs.add(task_manager.tasks[ref.task_id])
                else:
                    new_refs.add(ref)
            ret.tensors[tensor_id] = new_refs
        return ret