torchmonarch-nightly 2025.7.28__cp312-cp312-manylinux2014_x86_64.whl → 2025.7.30__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,21 @@
 import asyncio
 import traceback
 from functools import partial
-from typing import Generator, Generic, Optional, TypeVar
+from typing import (
+    Any,
+    cast,
+    Coroutine,
+    Generator,
+    Generic,
+    Literal,
+    NamedTuple,
+    Optional,
+    TypeVar,
+)
+
+from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
+
+from typing_extensions import Self
 
 R = TypeVar("R")
 
@@ -48,43 +62,76 @@ async def _aincomplete(impl, self):
 # loop machinery, this gives it the same throughput as if we ran it synchronously.
 
 
-class Future(Generic[R]):
-    def __init__(self, *, impl, requires_loop=True):
-        self._aget = partial(_aincomplete, impl)
-        self._requires_loop = requires_loop
+class _Unawaited(NamedTuple):
+    coro: PythonTask
 
-    def get(self, timeout: Optional[float] = None) -> R:
-        if asyncio._get_running_loop() is not None:
-            raise RuntimeError("get() cannot be called from within an async context")
-        if timeout is not None:
-            return asyncio.run(asyncio.wait_for(self._aget(self), timeout))
-        if not self._requires_loop:
-            try:
-                coro = self._aget(self)
-                next(coro.__await__())
-                tb_str = "".join(traceback.format_stack(coro.cr_frame))
-                raise RuntimeError(
-                    f"a coroutine paused with a future with requires_loop=False cannot block on a python asyncio.Future. Use requires_loop=True.\n{tb_str}"
-                )
-            except StopIteration as e:
-                return e.value
-        return asyncio.run(self._aget(self))
 
-    def __await__(self) -> Generator[R, None, R]:
-        return self._aget(self).__await__()
+class _Complete(NamedTuple):
+    value: Any
+
+
+class _Exception(NamedTuple):
+    exe: Exception
+
 
-    def _set_result(self, result):
-        async def af(self):
-            return result
+class _Asyncio(NamedTuple):
+    fut: asyncio.Future
 
-        self._aget = af
-        return result
 
-    def _set_exception(self, e):
-        async def af(self):
-            raise e
+_Status = _Unawaited | _Complete | _Exception | _Asyncio
 
-        self._aget = af
+
+class Future(Generic[R]):
+    def __init__(self, *, coro: "Coroutine[Any, Any, R] | PythonTask[R]"):
+        self._status: _Status = _Unawaited(
+            coro if isinstance(coro, PythonTask) else PythonTask.from_coroutine(coro)
+        )
+
+    def get(self, timeout: Optional[float] = None) -> R:
+        match self._status:
+            case _Unawaited(coro=coro):
+                try:
+                    if timeout is not None:
+                        coro = coro.with_timeout(timeout)
+                    v = coro.block_on()
+                    self._status = _Complete(v)
+                    return cast("R", v)
+                except Exception as e:
+                    self._status = _Exception(e)
+                    raise e from None
+            case _Asyncio(_):
+                raise ValueError(
+                    "already converted into an asyncio.Future, use 'await' to get the value."
+                )
+            case _Complete(value=value):
+                return cast("R", value)
+            case _Exception(exe=exe):
+                raise exe
+            case _:
+                raise RuntimeError("unknown status")
+
+    def __await__(self) -> Generator[Any, Any, R]:
+        match self._status:
+            case _Unawaited(coro=coro):
+                loop = asyncio.get_running_loop()
+                fut = loop.create_future()
+                self._status = _Asyncio(fut)
+
+                async def mark_complete():
+                    try:
+                        func, value = fut.set_result, await coro
+                    except Exception as e:
+                        func, value = fut.set_exception, e
+                    loop.call_soon_threadsafe(func, value)
+
+                PythonTask.from_coroutine(mark_complete()).spawn()
+                return fut.__await__()
+            case _Asyncio(fut=fut):
+                return fut.__await__()
+            case _:
+                raise ValueError(
+                    "already converted into a synchronous future, use 'get' to get the value."
+                )
 
 # compatibility with old tensor engine Future objects
 # hopefully we do not need done(), add_callback because
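The rewritten Future is a small state machine (_Unawaited, then one of _Complete, _Exception, or _Asyncio) and is consumed in exactly one mode: synchronously via get(), which drives the underlying PythonTask to completion and caches the outcome, or asynchronously via await, which converts it into an asyncio.Future completed from a spawned tokio task. Mixing the two raises ValueError. A minimal usage sketch, assuming the monarch tokio runtime is available:

    import asyncio

    async def compute() -> int:
        return 42

    # Synchronous mode: block_on() drives the task; the outcome is cached,
    # so repeated get() calls return the same _Complete value.
    f = Future(coro=compute())
    assert f.get(timeout=5) == 42
    assert f.get() == 42

    # Asynchronous mode: awaiting flips the status to _Asyncio; calling
    # get() on the same Future afterwards raises ValueError.
    async def main() -> None:
        g = Future(coro=compute())
        assert await g == 42

    asyncio.run(main())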
@@ -47,9 +47,12 @@ class PdbWrapper(pdb.Pdb):
         super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self))
         self._first = True
 
-    def set_trace(self, frame):
+    def set_trace(self, frame=None):
         self.client_ref.debugger_session_start.broadcast(
-            self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id
+            self.rank,
+            self.coords,
+            socket.getfqdn(socket.gethostname()),
+            self.actor_id.actor_name,
         )
         if self.header:
             self.message(self.header)
@@ -67,7 +70,9 @@ class PdbWrapper(pdb.Pdb):
         super().do_clear(arg)
 
     def end_debug_session(self):
-        self.client_ref.debugger_session_end.broadcast(self.rank)
+        self.client_ref.debugger_session_end.broadcast(
+            self.actor_id.actor_name, self.rank
+        )
         # Once the debug client actor is notified of the session being over,
         # we need to prevent any additional requests being sent for the session
         # by redirecting stdin and stdout.
@@ -88,7 +93,7 @@ class ReadWrapper(io.RawIOBase):
     def readinto(self, b):
         with fake_sync_state():
             response = self.session.client_ref.debugger_read.call_one(
-                self.session.rank, len(b)
+                self.session.actor_id.actor_name, self.session.rank, len(b)
             ).get()
             if response == "detach":
                 # this gets injected by the worker event loop to
@@ -124,6 +129,7 @@ class WriteWrapper:
             # pyre-ignore
             lineno = self.session.curframe.f_lineno
         self.session.client_ref.debugger_write.broadcast(
+            self.session.actor_id.actor_name,
            self.session.rank,
            DebuggerWrite(
                s.encode(),
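Every debugger message (debugger_session_start, debugger_session_end, debugger_read, debugger_write) now carries actor_id.actor_name alongside the rank, so the debug client can distinguish sessions coming from different actor meshes whose ranks collide. A hypothetical sketch of the keying this enables on the client side (the DebugSessionKey tuple and sessions table are illustrative, not part of the diff):

    from typing import Dict, NamedTuple

    class DebugSessionKey(NamedTuple):
        actor_name: str  # newly threaded through every debugger message
        rank: int

    # rank 0 of two different actor meshes no longer collides:
    sessions: Dict[DebugSessionKey, object] = {}
    sessions[DebugSessionKey("trainer", 0)] = object()
    sessions[DebugSessionKey("generator", 0)] = object()
    assert len(sessions) == 2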
@@ -37,15 +37,14 @@ from monarch._rust_bindings.monarch_hyperactor.proc_mesh import (
     ProcMeshMonitor,
 )
 from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
-from monarch._src.actor.actor_mesh import (
-    _Actor,
-    _ActorMeshRefImpl,
-    Actor,
-    ActorMeshRef,
-    fake_sync_state,
-)
+from monarch._src.actor.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef
 
-from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator, SimAllocator
+from monarch._src.actor.allocator import (
+    AllocateMixin,
+    LocalAllocator,
+    ProcessAllocator,
+    SimAllocator,
+)
 from monarch._src.actor.code_sync import (
     CodeSyncMeshClient,
     RemoteWorkspace,
@@ -111,29 +110,12 @@ except ImportError:
     IN_PAR = False
 
 
-async def _allocate_nonblocking(
-    alloc: Alloc, setup: Callable[[], None] | None = None
-) -> "ProcMesh":
-    _proc_mesh = await HyProcMesh.allocate_nonblocking(alloc)
-    if setup is None:
-        return ProcMesh(_proc_mesh)
-    # If the user has passed the setup lambda, we need to call
-    # it here before any of the other actors are spawned so that
-    # the environment variables are set up before cuda init.
-    proc_mesh = ProcMesh(_proc_mesh)
-    setup_actor = await proc_mesh.spawn("setup", SetupActor, setup)
-    await setup_actor.setup.call()
-    del setup_actor
-    return proc_mesh
-
-
 class ProcMesh(MeshTrait):
     def __init__(
         self,
         hy_proc_mesh: HyProcMesh,
         _mock_shape: Optional[Shape] = None,
         _device_mesh: Optional["DeviceMesh"] = None,
-        _is_initializing_debugger: bool = False,
     ) -> None:
         self._proc_mesh = hy_proc_mesh
         self._mock_shape: Optional[Shape] = _mock_shape
@@ -146,20 +128,32 @@ class ProcMesh(MeshTrait):
         self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh
         self._stopped = False
 
-        # This code is unsafe in async contexts, but we currently do it all over the place
-        # we need to refactor this by moving it to the first time we try to spawn on the mesh.
-        # Right now we simply preserve the previous behavior and disable the check that prevents
-        # end users from doing the same.
-        with fake_sync_state():
-            if _mock_shape is None and HAS_TENSOR_ENGINE:
-                # type: ignore[21]
-                self._rdma_manager = _RdmaManager.create_rdma_manager_blocking(
-                    self._proc_mesh
-                )
-            if not _is_initializing_debugger and _mock_shape is None:
-                self._debug_manager = self.spawn(
-                    _DEBUG_MANAGER_ACTOR_NAME, DebugManager, debug_client()
-                ).get()
+    async def _init_manager_actors(
+        self,
+        setup: Callable[[], None] | None = None,
+    ) -> "ProcMesh":
+        _rdma_manager = (
+            # pyre-ignore
+            await _RdmaManager.create_rdma_manager_nonblocking(self._proc_mesh)
+            if HAS_TENSOR_ENGINE
+            else None
+        )
+
+        _debug_manager = await self._spawn_nonblocking(
+            _DEBUG_MANAGER_ACTOR_NAME, DebugManager, await _debug_client()
+        )
+
+        self._debug_manager = _debug_manager
+        self._rdma_manager = _rdma_manager
+
+        if setup is not None:
+            # If the user has passed the setup lambda, we need to call
+            # it here before any of the other actors are spawned so that
+            # the environment variables are set up before cuda init.
+            setup_actor = await self._spawn_nonblocking("setup", SetupActor, setup)
+            # pyre-ignore
+            await setup_actor.setup.call()._status.coro
+        return self
 
     @property
     def _shape(self) -> Shape:
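Manager-actor startup (the RDMA manager and the debug manager) moves out of ProcMesh.__init__, where it ran synchronously under fake_sync_state(), into the awaitable _init_manager_actors, so it composes with the rest of the async allocation path instead of blocking inside a constructor. A sketch of the resulting flow, mirroring _proc_mesh_from_alloc_coro later in this diff (an Alloc already in hand is assumed):

    async def build_mesh(alloc: Alloc) -> "ProcMesh":
        # Allocation and manager-actor spawning form one awaitable
        # pipeline instead of blocking calls inside __init__.
        proc_mesh = ProcMesh(await HyProcMesh.allocate_nonblocking(alloc))
        await proc_mesh._init_manager_actors(setup=None)
        return proc_mesh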
@@ -184,10 +178,7 @@
     def spawn(self, name: str, Class: Type[T], *args: Any, **kwargs: Any) -> Future[T]:
         if self._mock_shape is not None:
             raise NotImplementedError("NYI: spawn on slice of a proc mesh.")
-        return Future(
-            impl=lambda: self._spawn_nonblocking(name, Class, *args, **kwargs),
-            requires_loop=False,
-        )
+        return Future(coro=self._spawn_nonblocking(name, Class, *args, **kwargs))
 
     async def monitor(self) -> ProcMeshMonitor:
         """
@@ -230,8 +221,7 @@
         ```
         """
         return Future(
-            impl=lambda: _allocate_nonblocking(alloc, setup),
-            requires_loop=False,
+            coro=_proc_mesh_from_alloc_coro(alloc, setup, init_manager_actors=True)
         )
 
     def __repr__(self) -> str:
@@ -345,10 +335,7 @@ class ProcMesh(MeshTrait):
             await self._proc_mesh.stop_nonblocking()
             self._stopped = True
 
-        return Future(
-            impl=lambda: _stop_nonblocking(),
-            requires_loop=False,
-        )
+        return Future(coro=_stop_nonblocking())
 
     async def __aexit__(
         self, exc_type: object, exc_val: object, exc_tb: object
@@ -370,46 +357,15 @@ class ProcMesh(MeshTrait):
         # Cannot call stop here because it is async.
 
 
-async def local_proc_mesh_nonblocking(
-    *,
-    gpus: Optional[int] = None,
-    hosts: int = 1,
-    _is_initializing_debugger: bool = False,
-) -> ProcMesh:
-    if gpus is None:
-        gpus = _local_device_count()
-    spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
-    allocator = LocalAllocator()
-    alloc = await allocator.allocate(spec)
-    proc_mesh = HyProcMesh.allocate_nonblocking(alloc)
-    return ProcMesh(
-        await proc_mesh,
-        _is_initializing_debugger=_is_initializing_debugger,
-    )
-
-
 def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]:
     return Future(
-        impl=lambda: local_proc_mesh_nonblocking(gpus=gpus, hosts=hosts),
-        requires_loop=False,
+        coro=_proc_mesh_coro(gpus=gpus, hosts=hosts, allocator=LocalAllocator())
     )
 
 
-async def sim_proc_mesh_nonblocking(
-    *, gpus: Optional[int] = None, hosts: int = 1
-) -> ProcMesh:
-    if gpus is None:
-        gpus = _local_device_count()
-    spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts)
-    allocator = SimAllocator()
-    alloc = await allocator.allocate(spec)
-    return await ProcMesh.from_alloc(alloc)
-
-
 def sim_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]:
     return Future(
-        impl=lambda: sim_proc_mesh_nonblocking(gpus=gpus, hosts=hosts),
-        requires_loop=False,
+        coro=_proc_mesh_coro(gpus=gpus, hosts=hosts, allocator=SimAllocator())
     )
 
 
@@ -431,33 +387,35 @@ def _get_bootstrap_args() -> tuple[str, Optional[list[str]], dict[str, str]]:
     return cmd, args, env
 
 
-async def proc_mesh_nonblocking(
+async def _proc_mesh_from_alloc_coro(
+    alloc: Alloc,
+    setup: Callable[[], None] | None,
+    init_manager_actors: bool,
+) -> ProcMesh:
+    _hy_proc_mesh = await HyProcMesh.allocate_nonblocking(alloc)
+    proc_mesh = ProcMesh(_hy_proc_mesh)
+    if init_manager_actors:
+        await proc_mesh._init_manager_actors(setup)
+    return proc_mesh
+
+
+async def _proc_mesh_coro(
     *,
+    allocator: AllocateMixin,
     gpus: Optional[int] = None,
     hosts: int = 1,
-    env: dict[str, str] | None = None,
     setup: Callable[[], None] | None = None,
+    init_manager_actors: bool = True,
 ) -> ProcMesh:
     if gpus is None:
         gpus = _local_device_count()
     # gpus must come last in this order because
     # test_remote_function_all_gather expects that hosts comes before gpus
     # in the order of the dimensions.
-    spec = AllocSpec(AllocConstraints(), hosts=hosts, gpus=gpus)
-    env = env or {}
-    # Todo: Deprecate the env field from the ProcessAllocator
-    # The PAR_MAIN_OVERRIDE needs to be passed as an env
-    # to the proc mesh construction in rust, so can not be moved to the
-    # SetupActor yet
-    cmd, args, bootstrap_env = _get_bootstrap_args()
-    env.update(bootstrap_env)
-    allocator = ProcessAllocator(cmd, args, env)
-    alloc = await allocator.allocate(spec)
+    spec: AllocSpec = AllocSpec(AllocConstraints(), hosts=hosts, gpus=gpus)
+    alloc = await allocator.allocate_nonblocking(spec)
 
-    return await ProcMesh.from_alloc(
-        alloc,
-        setup=setup,
-    )
+    return await _proc_mesh_from_alloc_coro(alloc, setup, init_manager_actors)
 
 
 def proc_mesh(
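_proc_mesh_coro now receives the allocator as an AllocateMixin argument, collapsing the three near-identical *_nonblocking functions (local, sim, process) into one coroutine; the public entry points differ only in the allocator they inject, as the local_proc_mesh and sim_proc_mesh hunks above show:

    # Same coroutine, different allocator (wiring taken from this diff):
    local = Future(coro=_proc_mesh_coro(gpus=2, allocator=LocalAllocator()))
    sim = Future(coro=_proc_mesh_coro(gpus=2, allocator=SimAllocator()))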
@@ -467,12 +425,22 @@
     env: dict[str, str] | None = None,
     setup: Callable[[], None] | None = None,
 ) -> Future[ProcMesh]:
-    return Future(
-        impl=lambda: proc_mesh_nonblocking(
-            gpus=gpus, hosts=hosts, env=env, setup=setup
-        ),
-        requires_loop=False,
+    env = env or {}
+
+    # Todo: Deprecate the env field from the ProcessAllocator
+    # The PAR_MAIN_OVERRIDE needs to be passed as an env
+    # to the proc mesh construction in rust, so can not be moved to the
+    # SetupActor yet
+    cmd, args, bootstrap_env = _get_bootstrap_args()
+    env.update(bootstrap_env)
+    task = _proc_mesh_coro(
+        gpus=gpus,
+        hosts=hosts,
+        setup=setup,
+        allocator=ProcessAllocator(cmd, args, env),
+        init_manager_actors=True,
     )
+    return Future(coro=task)
 
 
 _debug_proc_mesh: Optional["ProcMesh"] = None
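Note the split in proc_mesh(): _get_bootstrap_args() and the env merge now run eagerly at call time, and only the allocator-parameterized coroutine is deferred into the Future, so a bad bootstrap environment fails fast. Caller-visible behavior is unchanged, e.g.:

    mesh = proc_mesh(gpus=4, env={"MY_FLAG": "1"}).get()  # sync
    # or: mesh = await proc_mesh(gpus=4, env={"MY_FLAG": "1"})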
@@ -482,15 +450,12 @@ _debug_proc_mesh: Optional["ProcMesh"] = None
482
450
  # doesn't trigger the debug client to spawn, which could cause confusing
483
451
  # logs. This is defined in proc_mesh.py instead of debugger.py for
484
452
  # circular import reasons.
485
- def _get_debug_proc_mesh() -> "ProcMesh":
453
+ async def _get_debug_proc_mesh() -> "ProcMesh":
486
454
  global _debug_proc_mesh
487
455
  if _debug_proc_mesh is None:
488
- _debug_proc_mesh = Future(
489
- impl=lambda: local_proc_mesh_nonblocking(
490
- gpus=1, hosts=1, _is_initializing_debugger=True
491
- ),
492
- requires_loop=False,
493
- ).get()
456
+ _debug_proc_mesh = await _proc_mesh_coro(
457
+ gpus=1, hosts=1, allocator=LocalAllocator(), init_manager_actors=False
458
+ )
494
459
  return _debug_proc_mesh
495
460
 
496
461
 
@@ -499,10 +464,13 @@ _debug_client_mesh: Optional[DebugClient] = None
 
 # Lazy init for the same reason as above. This is defined in proc_mesh.py
 # instead of debugger.py for circular import reasons.
-def debug_client() -> DebugClient:
+async def _debug_client() -> DebugClient:
     global _debug_client_mesh
     if _debug_client_mesh is None:
-        _debug_client_mesh = (
-            _get_debug_proc_mesh().spawn("debug_client", DebugClient).get()
-        )
+        mesh = await _get_debug_proc_mesh()
+        _debug_client_mesh = await mesh._spawn_nonblocking("debug_client", DebugClient)
     return _debug_client_mesh
+
+
+def debug_client() -> DebugClient:
+    return Future(coro=_debug_client()).get()
@@ -31,6 +31,32 @@ def iter_ranks(ranks: Slices) -> Generator[int, None, None]:
         yield from ranks
 
 
+class ShapeExt:
+    """Extension methods for Shape that add higher-level
+    functionality."""
+
+    @staticmethod
+    def slice(shape: Shape, **kwargs) -> Shape:
+        """Select along named dimensions. Integer values remove
+        dimensions, slice objects keep dimensions but restrict them.
+
+        Examples: ShapeExt.slice(shape, batch=3, gpu=slice(2, 6))
+        """
+        for label, selector in kwargs.items():
+            if label not in shape.labels:
+                raise TypeError(f"Shape does not have dimension labeled {label!r}")
+            if isinstance(selector, slice):
+                shape = shape.select(label, selector)
+            else:
+                if (
+                    selector < 0
+                    or selector >= shape.ndslice.sizes[shape.labels.index(label)]
+                ):
+                    raise IndexError("index out of range")
+                shape = shape.at(label, selector)
+        return shape
+
+
 class MeshTrait(ABC):
     """
     Mesh interface. Implemented via Shape.
@@ -51,40 +77,13 @@ class MeshTrait(ABC):
     def _new_with_shape(self, shape: Shape) -> Self: ...
 
     def slice(self, **kwargs) -> Self:
-        """
-        mesh.slice(batch=3) or mesh.slice(batch=slice(3, None))
-        """
-        ndslice = self._ndslice
-        labels = self._labels
-        offset = ndslice.offset
-        names = []
-        sizes = []
-        strides = []
-        for name, size, stride in zip(labels, ndslice.sizes, ndslice.strides):
-            if name in kwargs:
-                e = kwargs.pop(name)
-                if isinstance(e, slice):
-                    start, stop, slice_stride = e.indices(size)
-                    offset += start * stride
-                    names.append(name)
-                    sizes.append((stop - start) // slice_stride)
-                    strides.append(slice_stride * stride)
-                else:
-                    if e >= size or e < 0:
-                        raise IndexError("index out of range")
-                    offset += e * stride
-            else:
-                names.append(name)
-                sizes.append(size)
-                strides.append(stride)
-
-        if kwargs:
-            raise TypeError(
-                f"{self} does not have dimension(s) named {tuple(kwargs.keys())}"
-            )
+        """Select along named dimensions. Integer values remove
+        dimensions, slice objects keep dimensions but restrict them.
 
-        new_ndslice = NDSlice(offset=offset, sizes=sizes, strides=strides)
-        return self._new_with_shape(Shape(names, new_ndslice))
+        Examples: mesh.slice(batch=3, gpu=slice(2, 6))
+        """
+        shape = Shape(list(self._labels), self._ndslice)
+        return self._new_with_shape(ShapeExt.slice(shape, **kwargs))
 
     def split(self, **kwargs) -> Self:
         """
@@ -120,12 +120,15 @@ class RDMABuffer:
                 f"offset + size ({offset + size}) must be <= dst.numel() ({dst.numel()})"
             )
 
+        local_proc_id = MonarchContext.get().proc_id
+        client = MonarchContext.get().mailbox
+
         async def read_into_nonblocking() -> Optional[int]:
             res = await self._buffer.read_into(
                 addr=addr,
                 size=size,
-                local_proc_id=MonarchContext.get().proc_id,
-                client=MonarchContext.get().mailbox,
+                local_proc_id=local_proc_id,
+                client=client,
                 timeout=timeout,
             )
             # TODO - remove this once GPU support is added.
@@ -133,7 +136,7 @@ class RDMABuffer:
                 dst_gpu.copy_(dst)
             return res
 
-        return Future(impl=read_into_nonblocking, requires_loop=False)
+        return Future(coro=read_into_nonblocking())
 
     def write_from(
         self, src: torch.Tensor, offset: int = 0, timeout: int = 3
@@ -164,12 +167,15 @@ class RDMABuffer:
                 f"size + offset ({size + offset}) must be <= src.numel() ({src.numel()})"
             )
 
+        local_proc_id = MonarchContext.get().proc_id
+        client = MonarchContext.get().mailbox
+
         async def write_from_nonblocking() -> None:
             res = await self._buffer.write_from(
                 addr=addr,
                 size=size,
-                local_proc_id=MonarchContext.get().proc_id,
-                client=MonarchContext.get().mailbox,
+                local_proc_id=local_proc_id,
+                client=client,
                 timeout=timeout,
             )
             # TODO - remove this once GPU support is added.
@@ -177,4 +183,4 @@ class RDMABuffer:
                 src_gpu.copy_(src)
             return res
 
-        return Future(impl=write_from_nonblocking, requires_loop=False)
+        return Future(coro=write_from_nonblocking())
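In both read_into and write_from, MonarchContext.get() is now evaluated eagerly on the calling actor and captured by the closure, instead of being looked up inside the coroutine. Because the coroutine now runs as a PythonTask under the new Future, it may execute on a tokio thread where the context-local lookup would not resolve to the caller. A minimal illustration of the eager-capture pattern, using contextvars as a stand-in for MonarchContext (illustrative only):

    import asyncio
    import contextvars

    proc_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("proc_id")
    proc_id_var.set("proc-0")

    local_proc_id = proc_id_var.get()  # capture now, on the caller's thread

    async def do_transfer() -> str:
        # Uses the captured value; correct regardless of which thread or
        # context the task is eventually scheduled on.
        return local_proc_id

    assert asyncio.run(do_transfer()) == "proc-0"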
@@ -11,6 +11,7 @@ import os
 import pdb  # noqa
 import traceback
 from collections import deque
+from functools import partial
 from logging import Logger
 from typing import (
     Any,
@@ -32,6 +33,7 @@ from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monar
 from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller
 from monarch._rust_bindings.monarch_extension.tensor_worker import Ref
 from monarch._rust_bindings.monarch_hyperactor.actor import (
+    MethodSpecifier,
     PythonMessage,
     PythonMessageKind,
     UnflattenArg,
@@ -40,6 +42,7 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
 from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
     ActorId,
 )
+from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
 from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple
 from monarch._src.actor.endpoint import Selection
 from monarch._src.actor.shape import NDSlice
@@ -48,7 +51,7 @@ from monarch.common.controller_api import TController
 from monarch.common.function import ResolvableFunction
 from monarch.common.invocation import Seq
 from monarch.common.messages import Referenceable, SendResultOfActorCall
-from monarch.common.stream import StreamRef
+from monarch.common.stream import Stream, StreamRef
 from monarch.common.tensor import dtensor_check, InputChecker, Tensor
 from monarch.common.tree import flatten
 from monarch.tensor_worker_main import _set_trace
@@ -322,9 +325,39 @@ def actor_send(
 
     client = cast(MeshClient, checker.mesh.client)
 
-    stream_ref = chosen_stream._to_ref(client)
+    rest = partial(
+        _actor_send,
+        endpoint,
+        args_kwargs_tuple,
+        refs,
+        port,
+        selection,
+        client,
+        checker.mesh,
+        tensors,
+        chosen_stream,
+    )
+    if isinstance(endpoint._name, MethodSpecifier.Init):
+        # Init runs within the tokio loop, but creating a node blocks the loop sending actor messages, so
+        # we offload to a blocking thread
+        PythonTask.spawn_blocking(rest)
+    else:
+        rest()
 
-    fut = (port, checker.mesh._ndslice) if port is not None else None
+
+def _actor_send(
+    endpoint: ActorEndpoint,
+    args_kwargs_tuple: bytes,
+    refs: Sequence[Any],
+    port: Optional[Port[Any]],
+    selection: Selection,
+    client: MeshClient,
+    mesh: DeviceMesh,
+    tensors: List[Tensor],
+    chosen_stream: Stream,
+):
+    stream_ref = chosen_stream._to_ref(client)
+    fut = (port, mesh._ndslice) if port is not None else None
 
     ident = client.new_node([], tensors, cast("OldFuture", fut))
 
@@ -340,7 +373,7 @@ def actor_send(
         endpoint, selection, client, ident, args_kwargs_tuple, refs
     )
     worker_msg = SendResultOfActorCall(ident, broker_id, tensors, [], stream_ref)
-    client.send(checker.mesh._ndslice, worker_msg)
+    client.send(mesh._ndslice, worker_msg)
     # we have to ask for status updates
     # from workers to be sure they have finished
     # enough work to count this future as finished,
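actor_send now packages the message-construction tail into a rest continuation: MethodSpecifier.Init calls, which would otherwise build a graph node while running on the tokio event loop and stall actor-message delivery, are offloaded with PythonTask.spawn_blocking, while all other calls run inline. A generic sketch of the same offload-if-blocking pattern in plain asyncio (a stand-in for the PythonTask binding used here):

    import asyncio
    from functools import partial

    def build_node(arg: int) -> None:
        ...  # blocking work that must not run on the event-loop thread

    async def dispatch(is_init: bool) -> None:
        rest = partial(build_node, 42)
        if is_init:
            # analogous to PythonTask.spawn_blocking(rest)
            await asyncio.get_running_loop().run_in_executor(None, rest)
        else:
            rest()  # cheap enough to run inline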
Binary file