torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/worker/worker.py
@@ -0,0 +1,1191 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import asyncio
9
+ import bdb
10
+ import itertools
11
+ import logging
12
+ import os
13
+ import pdb # noqa
14
+ import queue
15
+ import threading
16
+ from collections import deque
17
+ from contextlib import contextmanager
18
+ from traceback import extract_tb
19
+ from typing import (
20
+ Any,
21
+ Callable,
22
+ Dict,
23
+ Generator,
24
+ List,
25
+ NamedTuple,
26
+ Optional,
27
+ Protocol,
28
+ Sequence,
29
+ Tuple,
30
+ Union,
31
+ )
32
+
33
+ from weakref import WeakKeyDictionary
34
+
35
+ import torch
36
+ import torch.distributed
37
+ import torch.fx
38
+ import zmq
39
+ import zmq.asyncio
40
+
41
+ from monarch.common import messages
42
+ from monarch.common.function import ResolvableFunction
43
+ from monarch.common.messages import DependentOnError, Dims
44
+ from monarch.common.process_group import SingleControllerProcessGroupWrapper
45
+ from monarch.common.reference import Ref, Referenceable
46
+ from monarch.common.shape import NDSlice
47
+ from monarch.common.tensor_factory import TensorFactory
48
+ from monarch.common.tree import flatten, flattener
49
+ from monarch_supervisor import get_message_queue, Letter
50
+ from monarch_supervisor.logging import initialize_logging
51
+
52
+ from .compiled_block import CompiledBlock
53
+ from .debugger import _set_trace
54
+ from .monitor import Monitor
55
+
56
+ logger = logging.getLogger(__name__)
57
+ try:
58
+ CONTROLLER_COMPILED_REPEAT = 0 != int(os.environ["CONTROLLER_COMPILED_REPEAT"])
59
+ except KeyError:
60
+ CONTROLLER_COMPILED_REPEAT = True
61
+
62
+
63
+ def set_default_dtype(dtype: torch.dtype):
64
+ torch.set_default_dtype(dtype)
65
+
66
+
67
+ class Dim(NamedTuple):
68
+ name: str
69
+ rank: int
70
+ size: int
71
+ members: List[int]
72
+
73
+
74
+ class RemoteProcessGroupShell:
75
+ def __init__(self, device_mesh: "DeviceMesh", dims: Dims, ref: Ref):
76
+ self.device_mesh = device_mesh
77
+ self.dims = dims
78
+ self.ref = ref
79
+
80
+ # return the process group, sanity checking that the stream it was created on is the stream it is being used on.
81
+ def get_process_group_for_stream(self, stream: "Stream"):
82
+ return self.device_mesh.get_process_group(stream, self.dims, pg=self.ref)
83
+
84
+
85
+ def _new_process_group(
86
+ controller_global_unique_name: str, ranks: Optional[List[int]], split: bool
87
+ ):
88
+ assert torch.distributed.is_initialized()
89
+ from unittest.mock import patch
90
+
91
+ # PyTorch versions from about the past month generate process group names from rank-local state, which
92
+ # can cause TCPStore name collisions (https://www.internalfb.com/intern/diff/D67312715/).
93
+ # This will get fixed soon in PyTorch, but will take some time to roll out.
94
+ # In the meantime, our workers have enough knowledge to simply generate a unique name based on the data they already have.
95
+ # While not strictly needed once pytorch fixes the bug, this illustrates how our own initialization of nccl can just directly
96
+ # provide a unique key for each process group it is creating.
97
+ with patch(
98
+ "torch.distributed.distributed_c10d._process_group_name",
99
+ side_effect=lambda *args, **kwargs: controller_global_unique_name,
100
+ ) as the_patch:
101
+ if split:
102
+ assert ranks is not None
103
+ pg = torch.distributed.split_group(None, [ranks])
104
+ else:
105
+ pg = torch.distributed.new_group(ranks, use_local_synchronization=True)
106
+
107
+ assert the_patch.called
108
+ return pg
109
+
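A minimal usage sketch (hypothetical name and ranks, assuming torch.distributed is already initialized with an NCCL backend): every participating rank must pass the same controller-chosen string, since that string becomes the process group's unique key via the patch above.

    # hypothetical controller-generated name; see create_process_group below for the real scheme
    name = "restart_0_mesh_1_stream_2_X_0"
    pg = _new_process_group(name, ranks=[0, 1, 2, 3], split=True)
    # split=False instead calls new_group(ranks, use_local_synchronization=True)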
110
+
111
+ restart_count = 0
112
+
113
+
114
+ class DeviceMesh:
115
+ def __init__(self, id: int, names: Dims, ranks: NDSlice, rank: int):
116
+ self.id = id
117
+ self.dims: Dict[str, Dim] = {}
118
+ coordinates = ranks.coordinates(rank)
119
+ for coordinate, name, size, stride in zip(
120
+ coordinates, names, ranks.sizes, ranks.strides
121
+ ):
122
+ start = rank - stride * coordinate
123
+ members = [*range(start, start + stride * size, stride)]
124
+ assert members[coordinate] == rank
125
+ self.dims[name] = Dim(name, coordinate, size, members)
126
+ self.all_ranks: List[int] = list(ranks)
127
+ self.process_group_for_stream: WeakKeyDictionary["Stream", Any] = (
128
+ WeakKeyDictionary()
129
+ )
130
+
131
+ def get_ranks_for_dim_slice(self, names: Dims):
132
+ if len(names) == 0:
133
+ return []
134
+ if len(names) == 1:
135
+ return self.dims[names[0]].members
136
+ if len(names) == len(self.dims):
137
+ return self.all_ranks
138
+
139
+ dims = [self.dims[n] for n in names]
140
+
141
+ members = [dim.members for dim in dims]
142
+ strides = [d[1] - d[0] if len(d) > 1 else 0 for d in members]
143
+ start = members[0][dims[0].rank]
144
+ for d, s in zip(dims, strides):
145
+ start -= s * d.rank
146
+
147
+ ranks = []
148
+ for idxs in itertools.product(*[range(d.size) for d in dims]):
149
+ offset = sum([i * s for i, s in zip(idxs, strides)])
150
+ ranks.append(start + offset)
151
+ return ranks
152
+
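A worked example of the dimension bookkeeping above, for a hypothetical 2x4 mesh over ranks 0..7 with dims ("dp", "tp") and row-major strides (4, 1), as seen from rank 6:

    # coordinates(6) == (1, 2)
    # dim "dp": start = 6 - 4*1 = 2 -> members == [2, 6]
    # dim "tp": start = 6 - 1*2 = 4 -> members == [4, 5, 6, 7]
    # get_ranks_for_dim_slice(("dp",))       -> [2, 6]
    # get_ranks_for_dim_slice(("dp", "tp"))  -> all_ranks == [0, 1, 2, 3, 4, 5, 6, 7]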
153
+ def create_process_group(
154
+ self, stream: "Stream", dims: Dims, pg: Optional[Ref] = None
155
+ ):
156
+ if stream not in self.process_group_for_stream:
157
+ self.process_group_for_stream[stream] = {}
158
+ dims = tuple(sorted(dims))
159
+ key = (pg, dims)
160
+ if key in self.process_group_for_stream[stream]:
161
+ raise AssertionError(
162
+ f"Tried to create a process group for {stream=}, {dims=} but it already exists!"
163
+ )
164
+ ranks = self.get_ranks_for_dim_slice(dims)
165
+ indices = [
166
+ str(d.rank) if d.name not in dims else "X" for d in self.dims.values()
167
+ ]
168
+ name = f"restart_{restart_count}_mesh_{self.id}_stream_{stream.id}_{'_'.join(indices)}"
169
+ if pg is not None:
170
+ name += f"_group_{pg}"
171
+ self.process_group_for_stream[stream][key] = (
172
+ SingleControllerProcessGroupWrapper(
173
+ _new_process_group(name, ranks, split=True)
174
+ )
175
+ )
176
+ return self.get_process_group(stream, dims, pg=pg)
177
+
178
+ def get_process_group(self, stream: "Stream", dims: Dims, pg: Optional[Ref] = None):
179
+ dims = tuple(sorted(dims))
180
+ key = (pg, dims)
181
+ return self.process_group_for_stream[stream][key]
182
+
183
+ def create_process_group_shell(self, dims: Dims, ref: Ref):
184
+ return RemoteProcessGroupShell(self, dims, ref)
185
+
186
+
187
+ def _rank(mesh: "DeviceMesh", dim: str):
188
+ return torch.full((), mesh.dims[dim].rank, dtype=torch.long)
189
+
190
+
191
+ def _process_idx(mesh: "DeviceMesh"):
192
+ """
193
+ Return linear idx of the current process in the mesh.
194
+ """
195
+ # any dimension can be used to query our rank
196
+ _, dim = next(iter(mesh.dims.items()))
197
+ return torch.full((), dim.members[dim.rank], dtype=torch.long)
198
+
199
+
200
+ def _reduce(
201
+ local_tensor: torch.Tensor,
202
+ source_mesh: DeviceMesh,
203
+ group,
204
+ group_size: int,
205
+ reduction: str,
206
+ scatter: bool,
207
+ inplace: bool,
208
+ out: Optional[torch.Tensor],
209
+ ):
210
+ if reduction == "stack":
211
+ if scatter:
212
+ output = local_tensor
213
+ if not inplace:
214
+ output = local_tensor.clone() if out is None else out
215
+ torch.distributed.all_to_all_single(output, local_tensor, group=group)
216
+ return output
217
+
218
+ assert not inplace
219
+ output = (
220
+ torch.empty(
221
+ [group_size, *local_tensor.shape],
222
+ dtype=local_tensor.dtype,
223
+ device=local_tensor.device,
224
+ layout=local_tensor.layout,
225
+ )
226
+ if out is None
227
+ else out
228
+ )
229
+ torch.distributed.all_gather_into_tensor(output, local_tensor, group=group)
230
+ return output
231
+
232
+ op = getattr(torch.distributed.ReduceOp, reduction.upper())
233
+
234
+ if scatter:
235
+ assert not inplace
236
+ output = (
237
+ torch.empty(
238
+ local_tensor.shape[1:],
239
+ dtype=local_tensor.dtype,
240
+ device=local_tensor.device,
241
+ layout=local_tensor.layout,
242
+ )
243
+ if out is None
244
+ else out
245
+ )
246
+ torch.distributed.reduce_scatter_tensor(
247
+ output, local_tensor, op=op, group=group
248
+ )
249
+ return output
250
+
251
+ output = local_tensor
252
+ if not inplace:
253
+ output = local_tensor.clone() if out is None else out
254
+ torch.distributed.all_reduce(output, op=op, group=group)
255
+ return output
256
+
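The collective that _reduce issues is determined by (reduction, scatter); a summary of the mapping implemented above (shapes are per rank):

    # reduction        scatter  collective                           output shape
    # "stack"          True     all_to_all_single                    local_tensor.shape
    # "stack"          False    all_gather_into_tensor               [group_size, *local_tensor.shape]
    # "sum"/"avg"/...  True     reduce_scatter_tensor                local_tensor.shape[1:]
    # "sum"/"avg"/...  False    all_reduce (inplace, out, or clone)  local_tensor.shape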
257
+
258
+ class _TLS(threading.local):
259
+ def __init__(self):
260
+ self.tracing: Optional["CompiledBlock"] = None
261
+ self.stream: Optional["Stream"] = None
262
+
263
+
264
+ _tls = _TLS()
265
+
266
+
267
+ def schedule_on_stream_thread(executes_on_error: bool):
268
+ def wrapper(fn):
269
+ return lambda self, *args, **kwargs: self.schedule(
270
+ lambda: (
271
+ logger.debug(
272
+ "executing: %s(args=%s, kwargs=%s)", fn.__name__, args, kwargs
273
+ ),
274
+ fn(self, *args, **kwargs),
275
+ ),
276
+ executes_on_error,
277
+ )
278
+
279
+ return wrapper
280
+
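A toy sketch of what the decorator does: the decorated body is not run by the caller; it is wrapped in a thunk and handed to self.schedule, so the event-loop thread returns immediately and the stream thread (or the tracer) runs it later. The _Demo class below is hypothetical and runs the thunk inline; note the return value of the decorated call is always discarded.

    class _Demo:
        def schedule(self, thunk, executes_on_error):
            thunk()  # a real Stream queues this onto its worker thread instead

        @schedule_on_stream_thread(executes_on_error=False)
        def work(self, x):
            return x + 1

    _Demo().work(1)  # logs "executing: work(...)" and runs the body via schedule(); returns None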
281
+
282
+ class Stream:
283
+ def __init__(self, worker: "Worker", id: int, default: bool):
284
+ self.id = id
285
+ self.worker = worker
286
+ self.thread: Optional[threading.Thread] = None
287
+ self.q: queue.Queue[Callable[[], None]] = queue.Queue()
288
+ # used to send pdb messages from the controller; see debugger.py
289
+ self.debugger_queue: queue.Queue[Any] = queue.Queue()
290
+ self.should_exit = threading.Event()
291
+ self.current_recording: Optional[int] = None
292
+ if default:
293
+ self._cuda_stream = None
294
+ else:
295
+ self._cuda_stream = torch.cuda.Stream()
296
+
297
+ @schedule_on_stream_thread(executes_on_error=False)
298
+ def run_recording(
299
+ self, ident: int, impl: Callable, results: List["Cell"], inputs: List["Cell"]
300
+ ):
301
+ self.current_recording = ident
302
+ try:
303
+ impl(results, inputs)
304
+ finally:
305
+ self.current_recording = None
306
+
307
+ @property
308
+ def cuda_stream(self):
309
+ if self._cuda_stream is None:
310
+ return torch.cuda.current_stream()
311
+ else:
312
+ return self._cuda_stream
313
+
314
+ @contextmanager
315
+ def enable(self):
316
+ if self._cuda_stream is None:
317
+ yield
318
+ return
319
+ with torch.cuda.stream(self._cuda_stream):
320
+ yield
321
+
322
+ def event(self):
323
+ e = torch.cuda.Event()
324
+ self.cuda_stream.record_event(e)
325
+ return e
326
+
327
+ def wait_event(self, event):
328
+ self.cuda_stream.wait_event(event)
329
+
330
+ def wait_stream(self, stream):
331
+ self.cuda_stream.wait_stream(stream.cuda_stream)
332
+
333
+ def start(self) -> threading.Thread:
334
+ thread = threading.Thread(target=self.main)
335
+ thread.start()
336
+ return thread
337
+
338
+ def main(self):
339
+ _tls.stream = self
340
+ with self.enable():
341
+ try:
342
+ while True:
343
+ self.q.get()()
344
+ except StopIteration:
345
+ pass
346
+ except Exception as e:
347
+ logger.exception("Stream thread exiting with exception.")
348
+ msg = messages.InternalException(e, extract_tb(e.__traceback__))
349
+ self.worker.schedule(lambda: self.worker.internal_error(msg))
350
+
351
+ def exit(self):
352
+ def stop():
353
+ raise StopIteration
354
+
355
+ self.schedule(stop)
356
+ self.debugger_queue.put("detach")
357
+
358
+ def join(self):
359
+ if self.thread is None:
360
+ return
361
+ self.exit()
362
+ self.thread.join()
363
+
364
+ def schedule(self, fn: Callable[[], None], executes_on_error: bool = False):
365
+ if _tls.tracing:
366
+ tracing = _tls.tracing
367
+ if executes_on_error:
368
+ tracing.fallback[self].append(fn)
369
+ with tracing.record_to(self):
370
+ fn()
371
+ return
372
+
373
+ if self.thread is None:
374
+ self.thread = threading.Thread(target=self.main, daemon=True)
375
+ self.thread.start()
376
+ self.q.put(fn)
377
+
378
+ def call_or_trace(self, fn, *args, **kwargs):
379
+ if _tls.tracing:
380
+ return _tls.tracing.call_function(fn, args, kwargs)
381
+ return fn(*args, **kwargs)
382
+
383
+ def report_error(self, ident: int, index: int, e: Exception, extra: Any = None):
384
+ logger.exception(f"Error generating {ident}, {extra=}", exc_info=e)
385
+ self.worker.q.send(
386
+ messages.RemoteFunctionFailed(ident, index, e, extract_tb(e.__traceback__))
387
+ )
388
+ return DependentOnError(ident)
389
+
390
+ @contextmanager
391
+ def try_define(
392
+ self, ident: Optional[int], results: Sequence["Cell"], extra: Any = None
393
+ ):
394
+ tracing = _tls.tracing
395
+ if tracing:
396
+ ctx = tracing.current_context
397
+ ctx.ident = ident
398
+ tracing.mutates(results)
399
+
400
+ try:
401
+ yield
402
+ except DependentOnError as e:
403
+ for r in results:
404
+ r.set(e)
405
+ # note: there is no need to send RemoteFunctionFailed
406
+ # because the controller would have already gotten and propagated the
407
+ # original error that created the DependentOnError.
408
+ except bdb.BdbQuit:
409
+ raise
410
+ except Exception as e:
411
+ # when try_define does not have an ident,
412
+ # the only error we expect is DependentOnError;
413
+ # other errors should get treated as internal errors.
414
+ if ident is None:
415
+ raise
416
+ if self.current_recording is not None:
417
+ exc = self.report_error(self.current_recording, ident, e, extra)
418
+ else:
419
+ exc = self.report_error(ident, 0, e, extra)
420
+ for r in results:
421
+ r.set(exc)
422
+ finally:
423
+ if _tls.tracing:
424
+ # pyre-fixme[8]: Attribute has type `ErrorContext`; used as `None`.
425
+ _tls.tracing.current_context = None
426
+
427
+ @schedule_on_stream_thread(executes_on_error=False)
428
+ def call_function(
429
+ self,
430
+ ident: int,
431
+ defines: Tuple["Cell", ...],
432
+ flatten_result: Any,
433
+ mutates: Tuple["Cell", ...],
434
+ rfunction: ResolvableFunction,
435
+ inputs: List["Cell"],
436
+ unflatten_inputs: Any,
437
+ device_mesh: Optional["DeviceMesh"] = None,
438
+ ):
439
+ with self.try_define(
440
+ ident, [*defines, *mutates], extra=(rfunction, defines, mutates, inputs)
441
+ ):
442
+ function = rfunction.resolve()
443
+ resolved_inputs = []
444
+ for i in inputs:
445
+ input_ = i.get()
446
+ if isinstance(input_, RemoteProcessGroupShell):
447
+ # get the process group for the stream but don't allow it to be created from
448
+ # this context since this isn't being run on the event loop.
449
+ resolved_inputs.append(input_.get_process_group_for_stream(self))
450
+ else:
451
+ resolved_inputs.append(input_)
452
+
453
+ args, kwargs = unflatten_inputs(resolved_inputs)
454
+ if _tls.tracing:
455
+ block = _tls.tracing
456
+ fn_node: torch.fx.Node = block.call_function(function, args, kwargs)
457
+ tensors = [
458
+ t.node if isinstance(t, torch.fx.Proxy) else t
459
+ for t in flatten_result(block.proxy(fn_node))
460
+ ]
461
+ else:
462
+ result = function(*args, **kwargs)
463
+ tensors = flatten_result(result)
464
+ assert len(defines) == len(tensors)
465
+ for d, t in zip(defines, tensors):
466
+ d.set(t)
467
+
468
+ @schedule_on_stream_thread(executes_on_error=False)
469
+ def send_value(
470
+ self,
471
+ ident: int,
472
+ rfunction: Optional[ResolvableFunction],
473
+ mutates: Tuple["Cell", ...],
474
+ inputs: List["Cell"],
475
+ unflatten: Any,
476
+ pipe: Optional["WorkerPipe"],
477
+ ):
478
+ with self.try_define(ident, mutates):
479
+ args, kwargs = unflatten(c.get() for c in inputs)
480
+ function = (lambda x: x) if rfunction is None else rfunction.resolve()
481
+ result = function(*args, **kwargs)
482
+ if pipe is None:
483
+ self.worker.q.send(messages.FetchResult(ident, result))
484
+ else:
485
+ self.call_or_trace(pipe.send, result)
486
+
487
+ @schedule_on_stream_thread(executes_on_error=False)
488
+ def collective_call(
489
+ self,
490
+ function: Callable,
491
+ factory: TensorFactory,
492
+ input_: "Cell",
493
+ result: "Cell",
494
+ out: Optional["Cell"] = None,
495
+ ):
496
+ try:
497
+ local_tensor = input_.get()
498
+ out_tensor = None if out is None else out.get()
499
+ except DependentOnError:
500
+ # even if we were broken before, we have to participate in the collective
501
+ # because we cannot signal to other ranks that we were broken
502
+ # the controller will see the error message we sent before and know
503
+ # the downstream values are broken.
504
+ local_tensor = factory.zeros()
505
+ out_tensor = None
506
+ # XXX - we should be careful about starting the collective with a tensor that doesn't match the expected
507
+ # factory size. It can error. However, before we can do something about it we need to assign a failure
508
+ # identity to this reduce object.
509
+ output = self.call_or_trace(function, local_tensor, out_tensor)
510
+ result.set(output)
511
+
512
+ @schedule_on_stream_thread(executes_on_error=True)
513
+ def borrow_create(self, input_: "Cell", borrow: "Borrow"):
514
+ self.call_or_trace(borrow.create, input_.get(), self)
515
+
516
+ @schedule_on_stream_thread(executes_on_error=True)
517
+ def borrow_first_use(self, result: "Cell", borrow: "Borrow"):
518
+ with self.try_define(None, [result]):
519
+ result.set(self.call_or_trace(borrow.first_use))
520
+
521
+ @schedule_on_stream_thread(executes_on_error=True)
522
+ def borrow_last_use(self, borrow: "Borrow"):
523
+ self.call_or_trace(borrow.last_use)
524
+
525
+ @schedule_on_stream_thread(executes_on_error=True)
526
+ def borrow_drop(self, borrow: "Borrow"):
527
+ self.call_or_trace(borrow.drop)
528
+
529
+
530
+ class Borrow:
531
+ def __init__(self, from_stream: Stream, to_stream: Stream):
532
+ self.from_stream = from_stream
533
+ self.to_stream = to_stream
534
+ self.first_use_queue = queue.Queue()
535
+ self.last_use_queue = queue.Queue()
536
+ # used to ensure the tensor memory stays alive in the
537
+ # allocator until it is returned to its original stream
538
+ self.tensor_storage = Cell(None)
539
+
540
+ def create(self, input_: Any, stream: Stream):
541
+ self.first_use_queue.put((stream.event(), input_))
542
+
543
+ def first_use(self):
544
+ event, t = self.first_use_queue.get()
545
+ self.tensor_storage.set(t)
546
+ self.to_stream.wait_event(event)
547
+ # raise any potential error _after_ already processing
548
+ # the events. We always do the synchronizations even
549
+ # if the value being borrowed is an error.
550
+ return self.tensor_storage.get()
551
+
552
+ def last_use(self):
553
+ t = self.tensor_storage.value
554
+ self.tensor_storage.set(undefined_cell)
555
+ self.last_use_queue.put((self.to_stream.event(), t))
556
+
557
+ def drop(self):
558
+ event, t = self.last_use_queue.get()
559
+ self.from_stream.wait_event(event)
560
+ del t
561
+
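The borrow protocol is two event handshakes between the streams; a sketch of the ordering for a single borrowed tensor t:

    # from_stream: borrow.create(t)    -> records event E1 after t's producer work
    # to_stream:   borrow.first_use()  -> wait_event(E1), then t is usable on to_stream
    # to_stream:   borrow.last_use()   -> records event E2 after the last borrowing use
    # from_stream: borrow.drop()       -> wait_event(E2), then t's memory may be reused
    # tensor_storage keeps t alive between create() and drop() so the caching allocator
    # cannot hand its memory to other work on the original stream too early.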
562
+
563
+ class WorkerMessageQueue(Protocol):
564
+ def _socket(self, kind) -> zmq.Socket: ...
565
+
566
+ def send(self, message: Any) -> None: ...
567
+
568
+ async def recv_async(self) -> Letter: ...
569
+
570
+ def recvready(self, timeout: Optional[float]) -> List[Letter]: ...
571
+
572
+
573
+ class WorkerPipe:
574
+ """
575
+ Worker (e.g. Trainer) process pipe
576
+ """
577
+
578
+ def __init__(self, q: WorkerMessageQueue, pipe_name: str, max_messages: int = 50):
579
+ # breaking the abstraction layer here, but it is an easy way to send messages
580
+ # to the process
581
+ self._sock = q._socket(zmq.PAIR)
582
+ self._sock.setsockopt(zmq.SNDHWM, max_messages)
583
+ self._sock.setsockopt(zmq.RCVHWM, max_messages)
584
+ self._sock.bind(pipe_name)
585
+
586
+ def send(self, v: Any):
587
+ self._sock.send_pyobj(v)
588
+
589
+ def recv(self) -> Any:
590
+ return self._sock.recv_pyobj()
591
+
592
+ # Allows us to pass the pipe as a function that can be called to get the next value
593
+ def resolve(self) -> Callable:
594
+ return self.recv
595
+
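WorkerPipe (worker side, bind) and ProcessPipe (pipe-process side, connect, defined near the end of this file) form a ZeroMQ PAIR pair keyed by the pipe name, and the first message the worker sends carries the function plus mesh metadata that pipe_main expects. A rough sketch, with a hypothetical ipc endpoint:

    # worker process                            pipe process
    # pipe = WorkerPipe(q, "ipc://pipe-0")      p = ProcessPipe("ipc://pipe-0", max_messages=50)
    # pipe.send((rfunction, ranks, sizes,       rfn, p.ranks, p.sizes, args, kwargs = p.recv()
    #            args, kwargs))                 rfn.resolve()(p, *args, **kwargs)
    # pipe.recv() / pipe.send(...)              p.send(...) / p.recv()   # generator traffic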
596
+
597
+ undefined_cell = RuntimeError("undefined cell")
598
+
599
+
600
+ class Cell:
601
+ __slots__ = ("value",)
602
+
603
+ def __init__(self, initial_value=undefined_cell):
604
+ self.value: Any = initial_value
605
+
606
+ def __repr__(self):
607
+ return "<C>"
608
+
609
+ def set(self, value: Any):
610
+ self.value = value
611
+
612
+ def clear(self):
613
+ self.value = undefined_cell
614
+
615
+ def is_defined(self):
616
+ return self.value is not undefined_cell
617
+
618
+ def get(self) -> Any:
619
+ tracing = _tls.tracing
620
+ if (
621
+ tracing is not None
622
+ and self not in tracing.defined_cells
623
+ and tracing.recording_stream is not None
624
+ ):
625
+ return tracing.input_cell(self)
626
+ v = self.value
627
+ if isinstance(v, Exception):
628
+ raise v
629
+ return v
630
+
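A cell holds a value, the undefined sentinel, or an exception; reading an exception re-raises it, which is how a failed upstream result poisons the work that consumes it. A small sketch (the ident value is hypothetical):

    c = Cell()
    c.is_defined()                 # False; get() now would raise the "undefined cell" RuntimeError
    c.set(DependentOnError(7))
    try:
        c.get()                    # re-raises the stored error in whoever consumes the cell
    except DependentOnError:
        pass                       # try_define() catches this and marks its own results broken too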
631
+
632
+ class Worker:
633
+ def __init__(self, q: WorkerMessageQueue, rank: int, world: int, local_rank: int):
634
+ # remote ref id to local value
635
+ self.env: Dict[int, Cell] = {}
636
+ self.q = q
637
+ self.rank = rank
638
+ self.world = world
639
+ self.local_rank = local_rank
640
+ self.last_send_status = 0
641
+ self.borrows: Dict[int, Tuple[Ref, Borrow]] = {}
642
+ self.streams: List[Stream] = []
643
+ self.send_recv_process_groups: Dict[Tuple[Stream, Stream], Any] = {}
644
+ self.loop: Optional[asyncio.AbstractEventLoop] = None
645
+ self.stream_thread_error = False
646
+ self.max_received_ident = 0
647
+
648
+ def handle_message(self, event: NamedTuple):
649
+ cmd = event.__class__.__name__
650
+ if ident := getattr(event, "ident", None):
651
+ self.max_received_ident = max(self.max_received_ident, ident)
652
+ fn = getattr(self, cmd, None)
653
+ if fn is not None:
654
+ return fn(event)
655
+ raise RuntimeError(f"unhandled event: {event}")
656
+
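Dispatch is by message class name: each messages.Foo handled here has a Worker method of the same name, so adding a handler is just adding a method.

    # messages.CallFunction(...) -> Worker.CallFunction(m)
    # messages.BorrowCreate(...) -> Worker.BorrowCreate(m)
    # any message without a matching method raises "unhandled event: ..."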
657
+ def CreateDeviceMesh(
658
+ self, m: messages.CreateDeviceMesh
659
+ ): # result: "Ref", names: Tuple[str, ...], ranks: NDSlice):
660
+ # pyre-ignore
661
+ self.define(m.result, DeviceMesh(m.result.id, m.names, m.ranks, self.rank))
662
+
663
+ def resolve(self, r: Union[Referenceable, Ref]) -> Cell:
664
+ assert isinstance(r, Ref)
665
+ return self.env[r.id]
666
+
667
+ def CallFunction(self, m: messages.CallFunction):
668
+ flatten_result = flattener(m.result, lambda x: isinstance(x, Ref))
669
+ results = flatten_result(m.result)
670
+ defines = tuple(self.cell(r) for r in results)
671
+ mutates = tuple(self.resolve(r) for r in m.mutates)
672
+ stream: Stream = self.resolve(m.stream).get()
673
+ device_mesh = (
674
+ self.resolve(m.device_mesh).get() if m.device_mesh is not None else None
675
+ )
676
+ inputs, unflatten_inputs = self._inputs((m.args, m.kwargs))
677
+
678
+ stream.call_function(
679
+ m.ident,
680
+ defines,
681
+ flatten_result,
682
+ mutates,
683
+ m.function,
684
+ inputs,
685
+ unflatten_inputs,
686
+ device_mesh,
687
+ )
688
+
689
+ def CreateRemoteProcessGroup(self, m: messages.CreateRemoteProcessGroup):
690
+ device_mesh = self.resolve(m.device_mesh).get()
691
+ result = self.cell(m.result)
692
+ result.set(device_mesh.create_process_group_shell(m.dims, m.result))
693
+
694
+ def CreateStream(self, m: messages.CreateStream):
695
+ # pyre-ignore
696
+ stream = Stream(self, m.result.id, m.default)
697
+ self.streams.append(stream)
698
+ self.define(m.result, stream)
699
+
700
+ def _inputs(self, obj):
701
+ refs, unflatten = flatten(obj, lambda x: isinstance(x, Ref))
702
+ inputs = [self.env[r.id] for r in refs]
703
+ return inputs, unflatten
704
+
705
+ def SendValue(self, m: messages.SendValue):
706
+ assert (
707
+ not _tls.tracing
708
+ ), "controller should have prevented SendValue in repeat block."
709
+ stream: Stream = self.resolve(m.stream).get()
710
+ pipe: Optional["WorkerPipe"] = (
711
+ self.resolve(m.destination).get() if m.destination is not None else None
712
+ )
713
+ inputs, unflatten = self._inputs((m.args, m.kwargs))
714
+ mutates = tuple(self.resolve(r) for r in m.mutates)
715
+ stream.send_value(m.ident, m.function, mutates, inputs, unflatten, pipe)
716
+
717
+ def PipeRecv(self, m: messages.PipeRecv):
718
+ stream: Stream = self.resolve(m.stream).get()
719
+ pipe: WorkerPipe = self.resolve(m.pipe).get()
720
+ flatten = flattener(m.result, lambda x: isinstance(x, Ref))
721
+ results = flatten(m.result)
722
+ results = tuple(self.cell(r) for r in results)
723
+ stream.call_function(
724
+ m.ident,
725
+ results,
726
+ flatten,
727
+ (),
728
+ pipe,
729
+ (),
730
+ lambda x: ((), {}),
731
+ )
732
+
733
+ def RequestStatus(self, m: messages.RequestStatus):
734
+ # wait until all streams have reached the point
735
+ # we have scheduled, and then respond to the message
736
+ ident = m.ident
737
+ count = 0
738
+ expected = 0
739
+
740
+ # runs on asyncio event loop, but
741
+ # is placed on the event loop by the
742
+ # stream thread when it reaches this work item
743
+ def increment_and_send():
744
+ nonlocal count
745
+ count += 1
746
+ if count == expected:
747
+ self._send_status(ident + 1)
748
+
749
+ for stream in self.streams:
750
+ if stream.thread is not None:
751
+ expected += 1
752
+ stream.schedule(lambda: self.schedule(increment_and_send))
753
+
754
+ # if there were no active threads we still need to respond to status
755
+ # messages to make sure controller knows we are alive
756
+ if expected == 0:
757
+ self._send_status(ident + 1)
758
+
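The status reply is a small rendezvous: the event loop enqueues a marker on every live stream thread, each marker bounces back onto the event loop, and only when all of them have arrived is the status sent. A sketch for two live streams:

    # event loop: expected = 2; stream0.schedule(marker); stream1.schedule(marker)
    # stream0:    drains earlier work, then marker -> loop.call_soon_threadsafe(increment_and_send)
    # stream1:    drains earlier work, then marker -> loop.call_soon_threadsafe(increment_and_send)
    # event loop: count == expected -> _send_status(ident + 1), i.e. all work up to ident completed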
759
+ def Exit(self, m: messages.Exit):
760
+ for stream in self.streams:
761
+ stream.exit()
762
+ for stream in self.streams:
763
+ logger.info("joining stream")
764
+ stream.join()
765
+ if torch.distributed.is_initialized() and m.destroy_pg:
766
+ for pg in self.send_recv_process_groups.values():
767
+ torch.distributed.destroy_process_group(pg)
768
+ if torch.cuda.is_available():
769
+ torch.cuda.synchronize()
770
+ torch.distributed.barrier()
771
+ torch.distributed.destroy_process_group()
772
+ logger.info("PG destroyed")
773
+ raise StopIteration()
774
+
775
+ def CommandGroup(self, m: messages.CommandGroup):
776
+ for cmd in m.commands:
777
+ self.handle_message(cmd)
778
+
779
+ @contextmanager
780
+ def trace(self, value: Optional["CompiledBlock"]) -> Generator[None, Any, Any]:
781
+ old, _tls.tracing = _tls.tracing, value
782
+ try:
783
+ yield
784
+ finally:
785
+ _tls.tracing = old
786
+
787
+ def DefineRecording(self, m: messages.DefineRecording):
788
+ block = CompiledBlock()
789
+ with self.trace(block):
790
+ for cmd in m.commands:
791
+ self.handle_message(cmd)
792
+ block.emit()
793
+ self.define(m.result, block)
794
+
795
+ def RecordingFormal(self, m: messages.RecordingFormal):
796
+ block = _tls.tracing
797
+ assert block is not None
798
+ self.cell(m.result).set(
799
+ block.define_formal(self.resolve(m.stream).get(), m.argument_index)
800
+ )
801
+
802
+ def RecordingResult(self, m: messages.RecordingResult):
803
+ block = _tls.tracing
804
+ assert block is not None
805
+ with block.record_to(self.resolve(m.stream).get()):
806
+ node = self.resolve(m.input).get()
807
+ assert isinstance(node, torch.fx.Node)
808
+ block.define_result(node, m.output_index)
809
+
810
+ def CallRecording(self, m: messages.CallRecording):
811
+ recording: CompiledBlock = self.resolve(m.recording).get()
812
+ actuals = [
813
+ self.resolve(a) if i in recording.used_formals else None
814
+ for i, a in enumerate(m.actuals)
815
+ ]
816
+ results = [
817
+ self.cell(r) if i in recording.used_results else None
818
+ for i, r in enumerate(m.results)
819
+ ]
820
+ for stream, impl in recording.impls.items():
821
+ stream.run_recording(m.ident, impl, results, actuals)
822
+
823
+ def DeleteRefs(self, m: messages.DeleteRefs):
824
+ for id in m.refs:
825
+ del self.env[id]
826
+
827
+ def BorrowCreate(self, m: messages.BorrowCreate):
828
+ from_stream: Stream = self.resolve(m.from_stream).get()
829
+ to_stream: Stream = self.resolve(m.to_stream).get()
830
+ tensor = self.resolve(m.tensor)
831
+ borrow = Borrow(from_stream, to_stream)
832
+ if _tls.tracing:
833
+ _tls.tracing.defined_borrows[borrow] = True
834
+ from_stream.borrow_create(tensor, borrow)
835
+ # pyre-fixme[6]: For 2nd argument expected `Tuple[Ref, Borrow]` but got
836
+ # `Tuple[Tensor, Borrow]`.
837
+ self.borrows[m.borrow] = (m.result, borrow)
838
+
839
+ def BorrowFirstUse(self, m: messages.BorrowFirstUse):
840
+ result_id, borrow = self.borrows[m.borrow]
841
+ result = self.cell(result_id)
842
+ borrow.to_stream.borrow_first_use(result, borrow)
843
+
844
+ def BorrowLastUse(self, m: messages.BorrowLastUse):
845
+ _, borrow = self.borrows[m.borrow]
846
+ stream = borrow.to_stream
847
+ stream.borrow_last_use(borrow)
848
+
849
+ def BorrowDrop(self, m: messages.BorrowDrop):
850
+ _, borrow = self.borrows.pop(m.borrow)
851
+ assert (
852
+ not _tls.tracing or borrow in _tls.tracing.defined_borrows
853
+ ), "controller should have stopped a drop of a borrow not created in a repeat loop"
854
+ stream = borrow.from_stream
855
+ stream.borrow_drop(borrow)
856
+
857
+ def CreatePipe(self, m: messages.CreatePipe):
858
+ device_mesh: DeviceMesh = self.resolve(m.device_mesh).get()
859
+ pipe_name = f"{m.key}-{self.rank}"
860
+ ranks = {k: v.rank for k, v in device_mesh.dims.items()}
861
+ sizes = {k: v.size for k, v in device_mesh.dims.items()}
862
+ pipe = WorkerPipe(self.q, pipe_name, m.max_messages)
863
+ self.define(m.result, pipe)
864
+
865
+ pipe.send((m.function, ranks, sizes, m.args, m.kwargs))
866
+
867
+ def SplitComm(self, m: messages.SplitComm):
868
+ # Test whether this rank is in the mesh specified by the SplitComm
869
+ # command. We do this by attempting to dereference the mesh ref; only
870
+ # the ranks that are on the mesh will succeed.
871
+ try:
872
+ device_mesh = self.resolve(m.device_mesh).get()
873
+ in_mesh = True
874
+ except KeyError:
875
+ in_mesh = False
876
+
877
+ if in_mesh:
878
+ # Create a split process group
879
+ stream = self.resolve(m.stream).get()
880
+ device_mesh.create_process_group(stream, m.dims)
881
+ else:
882
+ # this rank is not in the split group. We still need to participate
883
+ # in the commSplit call, however.
884
+
885
+ # This weird incantation is because the current default split_group
886
+ # API requires all participants to know what the split ranks should
887
+ # be. In our case, workers not part of the new group don't know. So
888
+ # instead we manually contribute a NOCOLOR ncclCommSplit call.
889
+ default_pg = torch.distributed.distributed_c10d._get_default_group()
890
+ # pyre-ignore[16]
891
+ default_pg._get_backend(torch.device("cuda")).perform_nocolor_split(
892
+ default_pg.bound_device_id
893
+ )
894
+
895
+ def SplitCommForProcessGroup(self, m: messages.SplitCommForProcessGroup):
896
+ # Test whether this rank is in the mesh specified by the
897
+ # SplitCommForProcessGroup command. We do this by attempting to
898
+ # dereference the mesh ref; only the ranks that are on the mesh will
899
+ # succeed.
900
+ try:
901
+ pg = self.resolve(m.remote_process_group).get()
902
+ in_mesh = True
903
+ except KeyError:
904
+ in_mesh = False
905
+
906
+ if in_mesh:
907
+ # Create a split process group
908
+ stream = self.resolve(m.stream).get()
909
+ pg.device_mesh.create_process_group(
910
+ stream, pg.dims, pg=m.remote_process_group
911
+ )
912
+ else:
913
+ # this rank is not in the split group. We still need to participate
914
+ # in the commSplit call, however.
915
+
916
+ # This weird incantation is because the current default split_group
917
+ # API requires all participants to know what the split ranks should
918
+ # be. In our case, workers not part of the new group don't know. So
919
+ # instead we manually contribute a NOCOLOR ncclCommSplit call.
920
+ default_pg = torch.distributed.distributed_c10d._get_default_group()
921
+ # pyre-ignore[16]
922
+ default_pg._get_backend(torch.device("cuda")).perform_nocolor_split(
923
+ default_pg.bound_device_id
924
+ )
925
+
926
+ def Reduce(self, m: messages.Reduce):
927
+ stream: Stream = self.resolve(m.stream).get()
928
+ source_mesh: DeviceMesh = self.resolve(m.source_mesh).get()
929
+ assert len(m.dims) <= len(source_mesh.dims)
930
+ if len(m.dims) > 1:
931
+ assert m.reduction != "stack" and not m.scatter
932
+ pg = source_mesh.get_process_group(stream, m.dims)
933
+ local_tensor = self.resolve(m.local_tensor)
934
+ out = None if m.out is None else self.resolve(m.out)
935
+ output = self.cell(m.result)
936
+
937
+ # we need N only for "stack", and in that case we asserted that len(m.dims) == 1
938
+ N = len(source_mesh.dims[m.dims[0]].members) if m.reduction == "stack" else -1
939
+
940
+ def reducer(local_tensor, out):
941
+ return _reduce(
942
+ local_tensor,
943
+ source_mesh,
944
+ pg,
945
+ N,
946
+ m.reduction,
947
+ m.scatter,
948
+ m.inplace,
949
+ out,
950
+ )
951
+
952
+ stream.collective_call(reducer, m.factory, local_tensor, output, out)
953
+
954
+ def SendTensor(self, m: messages.SendTensor):
955
+ send_stream: Stream = self.resolve(m.from_stream).get()
956
+ recv_stream: Stream = self.resolve(m.to_stream).get()
957
+ pg = self.send_recv_process_groups[(send_stream, recv_stream)]
958
+
959
+ try:
960
+ index = m.from_ranks.index(self.rank)
961
+ send_to_rank = m.to_ranks[index]
962
+ except ValueError:
963
+ send_to_rank = None
964
+
965
+ try:
966
+ index = m.to_ranks.index(self.rank)
967
+ recv_from_rank = m.from_ranks[index]
968
+ except ValueError:
969
+ recv_from_rank = None
970
+
971
+ if send_to_rank is None:
972
+ the_stream = recv_stream
973
+ elif recv_from_rank is None:
974
+ the_stream = send_stream
975
+ elif send_stream is recv_stream:
976
+ the_stream = send_stream
977
+ else:
978
+ raise NotImplementedError(
979
+ "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver."
980
+ "It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync."
981
+ "Then the send stream would do the nccl op, and then sync with sending stream again."
982
+ )
983
+
984
+ def send_recv(
985
+ input_tensor: torch.Tensor, out: Optional[torch.Tensor]
986
+ ) -> Optional[torch.Tensor]:
987
+ # we consider to_mesh to always copy a tensor. But if the
988
+ # from and to ranks are the same, we really do not have to
989
+ # copy it. In this case we do a copy-on-write via _lazy_clone.
990
+ # The tensor will only be copied for real if someone later
991
+ # tries to mutate it.
992
+ if send_to_rank == recv_from_rank:
993
+ return input_tensor._lazy_clone()
994
+ ops = []
995
+ P2POp = torch.distributed.P2POp
996
+ isend, irecv = torch.distributed.isend, torch.distributed.irecv
997
+ if send_to_rank is not None:
998
+ ops.append(P2POp(isend, input_tensor, send_to_rank, pg))
999
+
1000
+ if recv_from_rank is not None:
1001
+ output = m.factory.empty()
1002
+ ops.append(P2POp(irecv, output, recv_from_rank, pg))
1003
+ else:
1004
+ output = None
1005
+ # invoke batched p2p ops
1006
+ for op in torch.distributed.batch_isend_irecv(ops):
1007
+ op.wait()
1008
+ return output
1009
+
1010
+ input = Cell(None) if send_to_rank is None else self.resolve(m.tensor)
1011
+ output = Cell(None) if recv_from_rank is None else self.cell(m.result)
1012
+ the_stream.collective_call(send_recv, m.factory, input, output, None)
1013
+
1014
+ def BackendNetworkInit(self, m: messages.BackendNetworkInit):
1015
+ if torch.distributed.is_initialized():
1016
+ return # for restarts in tests
1017
+ store = torch.distributed.TCPStore(
1018
+ m.hostname or os.environ["STORE_HOSTNAME"],
1019
+ m.port or int(os.environ["STORE_PORT"]),
1020
+ )
1021
+ torch.distributed.init_process_group(
1022
+ backend="nccl",
1023
+ world_size=self.world,
1024
+ rank=self.rank,
1025
+ store=store,
1026
+ device_id=torch.device("cuda:0"),
1027
+ )
1028
+ b = torch.zeros(1, device="cuda")
1029
+ torch.distributed.all_reduce(b)
1030
+
1031
+ def BackendNetworkPointToPointInit(
1032
+ self, m: messages.BackendNetworkPointToPointInit
1033
+ ):
1034
+ from_stream: Stream = self.resolve(m.from_stream).get()
1035
+ to_stream: Stream = self.resolve(m.to_stream).get()
1036
+ self.send_recv_process_groups[(from_stream, to_stream)] = _new_process_group(
1037
+ f"restart_{restart_count}_send_{from_stream.id}_recv_{to_stream.id}",
1038
+ None,
1039
+ split=False,
1040
+ )
1041
+
1042
+ def DebuggerMessage(self, m: messages.DebuggerMessage):
1043
+ stream: Stream = self.env[m.stream_id].get()
1044
+ stream.debugger_queue.put(m.action)
1045
+
1046
+ def define(self, r: Union[Ref, Referenceable], value: Any):
1047
+ assert isinstance(r, Ref)
1048
+ self.env[r.id] = Cell(value)
1049
+
1050
+ def cell(self, r: Union[Ref, Referenceable]):
1051
+ assert isinstance(r, Ref)
1052
+ c = self.env[r.id] = Cell()
1053
+ if _tls.tracing:
1054
+ _tls.tracing.defined_cells[c] = r.id
1055
+ return c
1056
+
1057
+ def _send_status(self, first_uncompleted_ident):
1058
+ if first_uncompleted_ident > self.last_send_status:
1059
+ self.q.send(messages.Status(first_uncompleted_ident))
1060
+ self.last_send_status = first_uncompleted_ident
1061
+
1062
+ async def worker_loop(self):
1063
+ monitor = Monitor()
1064
+ monitor.start()
1065
+ self.loop = asyncio.get_event_loop()
1066
+ debugq = deque()
1067
+ while True:
1068
+ try:
1069
+ # eventually this event loop should be handled as a separate
1070
+ # thread (maybe not even python) that just takes and
1071
+ # responds to messages, with a strong guarantee of never
1072
+ # getting stuck. For now we just run everything on this thread.
1073
+ monitor(
1074
+ lambda: (
1075
+ logger.error(
1076
+ f"possible stall while waiting for message: recent messages: {debugq} "
1077
+ f"{self.max_received_ident=} {self.last_send_status=}"
1078
+ ),
1079
+ logger.setLevel(logging.INFO),
1080
+ ),
1081
+ 30.0,
1082
+ )
1083
+ _, msg = await self.q.recv_async()
1084
+ logger.debug(f"event: {msg}, env={list(self.env.keys())}")
1085
+ monitor(
1086
+ (
1087
+ lambda msg=msg: logger.error(
1088
+ f"possible stall while handling {msg}"
1089
+ )
1090
+ ),
1091
+ 30.0,
1092
+ )
1093
+ self.handle_message(msg)
1094
+
1095
+ debugq.append(msg)
1096
+ while len(debugq) > 10:
1097
+ debugq.popleft()
1098
+ except StopIteration:
1099
+ self.q.recvready(0)
1100
+ self.q.recvready(0.01)
1101
+ return
1102
+ except Exception as e:
1103
+ logger.exception("Worker event loop exiting with internal exception")
1104
+ self.internal_error(
1105
+ messages.InternalException(e, extract_tb(e.__traceback__))
1106
+ )
1107
+
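The monitor calls above act as a re-armable watchdog: each call arms a callback that fires only if the loop does not re-arm within the timeout. A toy stand-in with the same calling convention (this is an assumption about the interface; the real implementation lives in monitor.py):

    import threading
    import time

    class ToyMonitor:
        """Fires the armed callback once if not re-armed within its timeout."""

        def __init__(self):
            self._lock = threading.Lock()
            self._armed = None  # (deadline, callback) or None

        def start(self):
            threading.Thread(target=self._run, daemon=True).start()

        def __call__(self, callback, timeout):
            with self._lock:
                self._armed = (time.monotonic() + timeout, callback)

        def _run(self):
            while True:
                time.sleep(0.5)
                callback = None
                with self._lock:
                    if self._armed and time.monotonic() > self._armed[0]:
                        callback = self._armed[1]
                        self._armed = None
                if callback:
                    callback()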
1108
+ def schedule(self, fn: Callable[[], None]):
1109
+ assert self.loop is not None
1110
+ self.loop.call_soon_threadsafe(fn)
1111
+
1112
+ def internal_error(self, msg: messages.InternalException):
1113
+ self.q.send(msg)
1114
+ assert self.loop is not None
1115
+ self.loop.stop()
1116
+
1117
+ def event_loop(self):
1118
+ pdb.set_trace = _set_trace
1119
+ try:
1120
+ asyncio.run(self.worker_loop())
1121
+ except RuntimeError as e:
1122
+ if "Event loop stopped" in str(e):
1123
+ logger.warning("Event loop exiting after reporting an internal error.")
1124
+
1125
+ else:
1126
+ raise
1127
+
1128
+
1129
+ def worker_main(_restartable):
1130
+ rank = int(os.environ["RANK"])
1131
+ world = int(os.environ["WORLD_SIZE"])
1132
+ local_rank = int(os.environ["LOCAL_RANK"])
1133
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
1134
+ devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
1135
+ device = devices[local_rank]
1136
+ else:
1137
+ device = str(local_rank)
1138
+ os.environ["CUDA_VISIBLE_DEVICES"] = device
1139
+ initialize_logging(process_name=f"worker_{rank}")
1140
+ logger.info("starting, restartable=%s, local_rank=%d", _restartable, local_rank)
1141
+ # force CUDA to initialize before we do any multithreading. This is a
1142
+ # workaround until https://github.com/pytorch/pytorch/pull/143238 is
1143
+ # available everywhere.
1144
+ if torch.cuda.is_available():
1145
+ torch.ones(1, device="cuda")
1146
+ q = get_message_queue()
1147
+ global restart_count
1148
+ for restart in itertools.count():
1149
+ restart_count = restart
1150
+ worker = Worker(q, rank, world, local_rank)
1151
+ worker.event_loop()
1152
+ if not _restartable:
1153
+ break
1154
+ q.send(messages.Restarted(0))
1155
+ logger.info("restarting")
1156
+
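worker_main is driven entirely by environment variables plus the supervisor message queue; a sketch of the environment it reads (values hypothetical):

    # RANK=3 WORLD_SIZE=8 LOCAL_RANK=3 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
    #   -> rank=3, world=8, local_rank=3, CUDA_VISIBLE_DEVICES narrowed to "3"
    # With CUDA_VISIBLE_DEVICES unset, it is set to str(local_rank) instead.
    # STORE_HOSTNAME / STORE_PORT are consulted later by BackendNetworkInit when the
    # message itself does not carry a hostname/port.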
1157
+
1158
+ class ProcessPipe:
1159
+ """Pipe Process Pipe"""
1160
+
1161
+ def __init__(self, key: str, max_messages):
1162
+ import zmq
1163
+
1164
+ q = get_message_queue()
1165
+ self._sock = q._socket(zmq.PAIR)
1166
+ self._sock.setsockopt(zmq.SNDHWM, max_messages)
1167
+ self._sock.setsockopt(zmq.RCVHWM, max_messages)
1168
+ self._sock.connect(key)
1169
+ self.ranks = {}
1170
+ self.sizes = {}
1171
+
1172
+ def send(self, any: Any):
1173
+ self._sock.send_pyobj(any)
1174
+
1175
+ def recv(self):
1176
+ return self._sock.recv_pyobj()
1177
+
1178
+
1179
+ def pipe_main(key: str, max_messages):
1180
+ """Main function for pipe process"""
1181
+ initialize_logging(f"pipe_{key}")
1182
+ pipe_obj = ProcessPipe(key, max_messages)
1183
+ rfunction, pipe_obj.ranks, pipe_obj.sizes, args, kwargs = pipe_obj.recv()
1184
+ function = rfunction.resolve()
1185
+ try:
1186
+ function(pipe_obj, *args, **kwargs)
1187
+ except Exception as e:
1188
+ logger.exception("pipe_main exiting with exception")
1189
+ get_message_queue().send(
1190
+ messages.RemoteGeneratorFailed(e, extract_tb(e.__traceback__))
1191
+ )