torchmonarch-nightly 2025.8.2__cp312-cp312-manylinux2014_x86_64.whl → 2025.9.3__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. monarch/_rust_bindings.so +0 -0
  2. monarch/_src/actor/actor_mesh.py +414 -216
  3. monarch/_src/actor/allocator.py +75 -6
  4. monarch/_src/actor/bootstrap_main.py +7 -4
  5. monarch/_src/actor/code_sync/__init__.py +2 -0
  6. monarch/_src/actor/debugger/__init__.py +7 -0
  7. monarch/_src/actor/{debugger.py → debugger/debugger.py} +246 -135
  8. monarch/_src/actor/{pdb_wrapper.py → debugger/pdb_wrapper.py} +62 -23
  9. monarch/_src/actor/endpoint.py +27 -45
  10. monarch/_src/actor/future.py +86 -24
  11. monarch/_src/actor/host_mesh.py +125 -0
  12. monarch/_src/actor/logging.py +94 -0
  13. monarch/_src/actor/pickle.py +25 -0
  14. monarch/_src/actor/proc_mesh.py +423 -156
  15. monarch/_src/actor/python_extension_methods.py +90 -0
  16. monarch/_src/actor/shape.py +8 -1
  17. monarch/_src/actor/source_loader.py +45 -0
  18. monarch/_src/actor/telemetry/__init__.py +172 -0
  19. monarch/_src/actor/telemetry/rust_span_tracing.py +6 -39
  20. monarch/_src/debug_cli/__init__.py +7 -0
  21. monarch/_src/debug_cli/debug_cli.py +43 -0
  22. monarch/_src/tensor_engine/rdma.py +64 -9
  23. monarch/_testing.py +1 -3
  24. monarch/actor/__init__.py +24 -4
  25. monarch/common/_C.so +0 -0
  26. monarch/common/device_mesh.py +14 -0
  27. monarch/common/future.py +10 -0
  28. monarch/common/remote.py +14 -25
  29. monarch/common/tensor.py +12 -0
  30. monarch/debug_cli/__init__.py +7 -0
  31. monarch/debug_cli/__main__.py +12 -0
  32. monarch/fetch.py +2 -2
  33. monarch/gradient/_gradient_generator.so +0 -0
  34. monarch/gradient_generator.py +4 -2
  35. monarch/mesh_controller.py +34 -14
  36. monarch/monarch_controller +0 -0
  37. monarch/tools/colors.py +25 -0
  38. monarch/tools/commands.py +42 -7
  39. monarch/tools/components/hyperactor.py +1 -1
  40. monarch/tools/config/__init__.py +31 -4
  41. monarch/tools/config/defaults.py +13 -3
  42. monarch/tools/config/environment.py +45 -0
  43. monarch/tools/config/workspace.py +165 -0
  44. monarch/tools/mesh_spec.py +2 -0
  45. monarch/utils/__init__.py +9 -0
  46. monarch/utils/utils.py +78 -0
  47. tests/error_test_binary.py +5 -3
  48. tests/python_actor_test_binary.py +52 -0
  49. tests/test_actor_error.py +142 -14
  50. tests/test_alloc.py +1 -1
  51. tests/test_allocator.py +59 -72
  52. tests/test_debugger.py +639 -45
  53. tests/test_env_before_cuda.py +4 -4
  54. tests/test_mesh_trait.py +38 -0
  55. tests/test_python_actors.py +965 -75
  56. tests/test_rdma.py +7 -6
  57. tests/test_tensor_engine.py +6 -6
  58. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/METADATA +82 -4
  59. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/RECORD +63 -47
  60. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/WHEEL +0 -0
  61. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/entry_points.txt +0 -0
  62. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/licenses/LICENSE +0 -0
  63. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,90 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import importlib
8
+
9
+ from typing import cast, Type, TypeVar
10
+
11
+
12
T = TypeVar("T")


class PatchRustClass:
    """Class decorator that grafts Python-defined members onto a pyo3 class.

    Applying an instance to a Python class does three things:
    - copies that class's functions, properties, classmethods and
      staticmethods onto the wrapped Rust class;
    - skips any name the Rust class already defines;
    - returns the Rust class itself, so the decorated name ends up bound to
      the Rust type.
    """

    def __init__(self, rust_class: Type):
        # The pyo3-generated class that receives the Python members.
        self.rust_class = rust_class

    def __call__(self, python_class: Type[T]) -> Type[T]:
        """Patch ``python_class``'s members onto the Rust class.

        Args:
            python_class: The Python stub/extension class. Its
                ``__module__.__name__`` must match the Rust class's.

        Returns:
            The Rust class, typed as ``python_class`` for type checkers.

        Raises:
            ValueError: If the qualified names of the two classes differ.
        """
        rust_name = f"{self.rust_class.__module__}.{self.rust_class.__name__}"
        python_name = f"{python_class.__module__}.{python_class.__name__}"
        if rust_name != python_name:
            raise ValueError(f"mismatched type names {rust_name} != {python_name}")
        for name, implementation in python_class.__dict__.items():
            if hasattr(self.rust_class, name):
                # Do not patch in the stub methods that are already defined
                # by the Rust implementation.
                continue
            # ``classmethod`` objects are not callable (nor is ``staticmethod``
            # before Python 3.10), so descriptors are checked explicitly
            # alongside ``property``. Plain data attributes are skipped.
            if not callable(implementation) and not isinstance(
                implementation, (property, classmethod, staticmethod)
            ):
                continue
            setattr(self.rust_class, name, implementation)
        return cast(Type[T], self.rust_class)
35
+
36
+
37
def rust_struct(name: str) -> "PatchRustClass":
    """
    Return a class decorator that extends a pyo3 class with Python methods.

    When we bind a Rust struct into Python, it is sometimes faster to implement
    parts of the desired Python API in Python. It is also easier to understand
    what the class does in terms of these methods.

    We also want to avoid wrapping Rust objects in another layer of Python
    objects because:
      * wrappers double the Python overhead;
      * it is easy to confuse which level of wrapper an API takes, especially
        along the Python<->Rust boundary.

    To avoid wrappers we first define the class in pyo3. Let's say we add a
    class ``monarch_hyperactor::actor_mesh::TestClass`` that we want to extend
    with Python methods in ``monarch/_src/actor/actor_mesh.py``. In Rust we
    define the class as::

        #[pyclass(name = "TestClass", module = "monarch._src.actor.actor_mesh")]
        struct TestClass {}
        #[pymethods]
        impl TestClass {
            fn hello(&self) {
                println!("hello");
            }
        }

    Then, rather than writing typing stubs in a ``.pyi`` file, we write the
    stub code directly in ``monarch/_src/actor/actor_mesh.py`` along with any
    helper methods::

        @rust_struct("monarch_hyperactor::actor_mesh::TestClass")
        class TestClass:
            def hello(self) -> None:
                ...
            def hello_world(self) -> None:
                self.hello()
                print("world")

    The decorator merges the Python extension methods into the Rust class
    implementation. Any Rust code that returns a ``TestClass`` therefore has
    the ``hello_world`` extension method attached. Python type checking always
    thinks ``TestClass`` is the Python code, so typing works.

    The pyclass ``module`` may differ from the Rust module where the struct is
    defined, for two reasons:
      (1) The decorator returns the Rust class, which is then bound under the
          Python class's name in that Python module, so pickling finds it.
      (2) The ``rust_struct`` argument points directly to the Rust code and
          will be discovered by go-to-definition in the IDE.

    The pyclass ``module`` must, however, equal the decorated Python class's
    ``__module__``; otherwise the decorator raises ``ValueError``.

    Args:
        name: ``::``-separated path of the Rust class, e.g.
            ``"monarch_hyperactor::actor_mesh::TestClass"``. Everything before
            the last ``::`` is resolved as a submodule of
            ``monarch._rust_bindings``; the last component is the class name.

    Returns:
        A ``PatchRustClass`` decorator bound to the Rust class.

    Raises:
        ValueError: If ``name`` has no module part or contains an empty
            component.
        ImportError: If the bindings submodule cannot be imported.
        AttributeError: If the submodule does not define the class.
    """
    *module_parts, class_name = name.split("::")
    # Reject malformed paths up front; otherwise they would surface as an
    # obscure import error for "monarch._rust_bindings." or similar.
    if not module_parts or not all(module_parts) or not class_name:
        raise ValueError(
            f"expected a rust path of the form 'crate::module::Class', got {name!r}"
        )
    module_name = ".".join(module_parts)
    module = importlib.import_module(f"monarch._rust_bindings.{module_name}")

    rust_class = getattr(module, class_name)

    return PatchRustClass(rust_class)
@@ -10,7 +10,7 @@ from abc import ABC, abstractmethod
10
10
 
11
11
  from typing import Dict, Generator, Sequence, Tuple, Union
12
12
 
13
- from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
13
+ from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice
14
14
 
15
15
  from typing_extensions import Self
16
16
 
@@ -224,5 +224,12 @@ class MeshTrait(ABC):
224
224
  def sizes(self) -> dict[str, int]:
225
225
  return dict(zip(self._labels, self._ndslice.sizes))
226
226
 
227
@property
def extent(self) -> "Extent":
    """The mesh's labels paired with the sizes of its underlying ndslice."""
    labels = self._labels
    sizes = self._ndslice.sizes
    return Extent(labels, sizes)
230
+
231
def __len__(self) -> int:
    """Length of the mesh, delegated to ``len()`` of its underlying ndslice."""
    ndslice = self._ndslice
    return len(ndslice)
233
+
227
234
 
228
235
  __all__ = ["NDSlice", "Shape", "MeshTrait"]
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import functools
9
+ import importlib
10
+ import importlib.abc
11
+ import linecache
12
+
13
+ from monarch._src.actor.actor_mesh import _context, Actor
14
+ from monarch._src.actor.endpoint import endpoint
15
+ from monarch._src.actor.proc_mesh import get_or_spawn_controller
16
+ from monarch._src.actor.sync_state import fake_sync_state
17
+
18
+
19
class SourceLoaderController(Actor):
    """Actor that serves Python source text from this process's linecache."""

    @endpoint
    def get_source(self, filename: str) -> str:
        """Return the full text of ``filename`` as seen by ``linecache``.

        ``linecache.getlines`` yields an empty list when the file cannot be
        read, in which case the result is the empty string.
        """
        source_lines = linecache.getlines(filename)
        return "".join(source_lines)
23
+
24
+
25
@functools.cache
def source_loader_controller() -> SourceLoaderController:
    """Return the handle to the "source_loader" controller actor.

    Memoized with ``functools.cache``, so after the first successful call the
    same handle is returned without another spawn/lookup.
    """
    with fake_sync_state():
        pending = get_or_spawn_controller("source_loader", SourceLoaderController)
        return pending.get()
29
+
30
+
31
@functools.cache
def load_remote_source(filename: str) -> str:
    """Fetch the source text of ``filename`` from the source-loader controller.

    Results are memoized per ``filename`` via ``functools.cache``.
    """
    with fake_sync_state():
        get_source_endpoint = source_loader_controller().get_source
        return get_source_endpoint.call_one(filename).get()
35
+
36
+
37
class RemoteImportLoader(importlib.abc.Loader):
    """Loader whose ``get_source`` fetches a file's text through the
    source-loader controller actor."""

    def __init__(self, filename: str):
        # Path of the file whose source will be requested remotely.
        self._filename = filename

    def get_source(self, _module_name: str) -> str:
        """Return the remote source for this loader's file.

        Raises:
            ImportError: If no actor context is active (``_context`` unset),
                since the controller cannot be reached without one.
        """
        if _context.get(None) is None:
            raise ImportError(f"could not get source for {self._filename}")
        return load_remote_source(self._filename)
@@ -8,12 +8,184 @@
8
8
 
9
9
 
10
10
  import logging
11
+ import warnings
12
+ from typing import Optional, Sequence
13
+
14
+ import opentelemetry.metrics as metrics # @manual=fbsource//third-party/pypi/opentelemetry-api:opentelemetry-api
15
+ import opentelemetry.trace as trace # @manual=fbsource//third-party/pypi/opentelemetry-api:opentelemetry-api
11
16
 
12
17
  from monarch._rust_bindings.monarch_hyperactor.telemetry import ( # @manual=//monarch/monarch_extension:monarch_extension
13
18
  forward_to_tracing,
19
+ PyCounter,
20
+ PyHistogram,
21
+ PyUpDownCounter,
14
22
  )
23
+ from monarch._src.actor.telemetry.rust_span_tracing import RustTracerProvider
24
+ from opentelemetry.context import Context
25
+ from opentelemetry.metrics import CallbackT
26
+ from opentelemetry.util.types import Attributes
15
27
 
16
28
 
17
29
class TracingForwarder(logging.Handler):
    """``logging.Handler`` that hands every log record to the Rust tracing
    bridge (``forward_to_tracing``)."""

    def emit(self, record: logging.LogRecord) -> None:
        # No formatting or filtering here; the Rust side receives the raw record.
        forward_to_tracing(record)
32
+
33
+
34
class Counter(metrics.Counter):
    """OpenTelemetry ``Counter`` backed by the Rust ``PyCounter``.

    ``attributes`` and ``context`` are accepted for API compatibility but are
    not forwarded to the Rust counter.
    """

    # Rust-side counter that actually records the increments.
    inner: PyCounter

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.inner = PyCounter(name)

    def add(
        self,
        amount: int | float,
        attributes: Optional[Attributes] = None,
        context: Optional[Context] = None,
    ) -> None:
        """Increment the counter by ``amount``.

        ``amount`` is converted with ``int()``, so any fractional part is
        truncated before it reaches ``PyCounter``.
        """
        # Deliberately returns nothing, as the ``-> None`` annotation promises.
        self.inner.add(int(amount))
48
+
49
+
50
class UpDownCounter(metrics.UpDownCounter):
    """OpenTelemetry ``UpDownCounter`` backed by the Rust ``PyUpDownCounter``.

    ``attributes`` and ``context`` are accepted for API compatibility and
    ignored.
    """

    # Rust-side counter receiving every delta.
    inner: PyUpDownCounter

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.inner = PyUpDownCounter(name)

    def add(
        self,
        amount: int | float,
        attributes: Optional[Attributes] = None,
        context: Optional[Context] = None,
    ) -> None:
        """Add ``amount`` (truncated with ``int()``) to the counter."""
        delta = int(amount)
        self.inner.add(delta)
64
+
65
+
66
class Histogram(metrics.Histogram):
    """OpenTelemetry ``Histogram`` that records into the Rust ``PyHistogram``.

    Values are passed through unchanged; ``attributes`` and ``context`` are
    accepted for API compatibility and ignored.
    """

    # Rust-side histogram receiving every recorded value.
    inner: PyHistogram

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.inner = PyHistogram(name)

    def record(
        self,
        amount: int | float,
        attributes: Optional[Attributes] = None,
        context: Optional[Context] = None,
    ) -> None:
        """Record one observation of ``amount``."""
        histogram = self.inner
        histogram.record(amount)
80
+
81
+
82
class Meter(metrics.Meter):
    """OpenTelemetry ``Meter`` whose synchronous instruments are Rust-backed.

    Counters, up/down counters and histograms are supported. ``unit``,
    ``description`` and histogram bucket-boundary advice are accepted for API
    compatibility but ignored. Gauges and all observable (callback-based)
    instruments raise ``NotImplementedError``.
    """

    def create_counter(
        self,
        name: str,
        unit: str = "",
        description: str = "",
    ) -> metrics.Counter:
        return Counter(name)

    def create_up_down_counter(
        self,
        name: str,
        unit: str = "",
        description: str = "",
    ) -> metrics.UpDownCounter:
        return UpDownCounter(name)

    def create_observable_counter(
        self,
        name: str,
        callbacks: Optional[Sequence[CallbackT]] = None,
        unit: str = "",
        description: str = "",
    ) -> metrics.ObservableCounter:
        raise NotImplementedError("create_observable_counter() is not implemented")

    def create_histogram(
        self,
        name: str,
        unit: str = "",
        description: str = "",
        *,
        explicit_bucket_boundaries_advisory: Optional[Sequence[float]] = None,
    ) -> metrics.Histogram:
        return Histogram(name)

    def create_gauge(  # type: ignore # pylint: disable=no-self-use
        self,
        name: str,
        unit: str = "",
        description: str = "",
    ) -> metrics._Gauge:  # pyright: ignore[reportReturnType]
        # Gauges are unsupported. Raise a single accurate error rather than
        # warning that the call "will be a no-op" and then raising anyway.
        raise NotImplementedError("create_gauge() is not implemented")

    def create_observable_gauge(
        self,
        name: str,
        callbacks: Optional[Sequence[CallbackT]] = None,
        unit: str = "",
        description: str = "",
    ) -> metrics.ObservableGauge:
        raise NotImplementedError("create_observable_gauge() is not implemented")

    def create_observable_up_down_counter(
        self,
        name: str,
        callbacks: Optional[Sequence[CallbackT]] = None,
        unit: str = "",
        description: str = "",
    ) -> metrics.ObservableUpDownCounter:
        raise NotImplementedError(
            "create_observable_up_down_counter() is not implemented"
        )
146
+
147
+
148
class MeterProvider(metrics.MeterProvider):
    """``MeterProvider`` that hands out the Rust-backed ``Meter``."""

    def get_meter(
        self,
        name: str,
        version: Optional[str] = None,
        schema_url: Optional[str] = None,
        attributes: Optional[Attributes] = None,
    ) -> metrics.Meter:
        """Return a new ``Meter`` named ``name``.

        NOTE(review): ``attributes`` is accepted but not forwarded to the
        meter; confirm whether that is intended.
        """
        meter = Meter(name, version, schema_url)
        return meter
157
+
158
+
159
def get_monarch_tracer() -> trace.Tracer:
    """
    Return a Monarch Python tracer whose spans go to the Rust telemetry system.

    Calls ``install()`` first, so the Rust-backed providers are registered
    before the tracer is obtained.

    Returns:
        Tracer: An OpenTelemetry tracer named "monarch.python.tracer".

    Usage:
        tracer = get_monarch_tracer()
        with tracer.start_as_current_span("span_name") as span:
            # code here
    """
    install()
    tracer = trace.get_tracer("monarch.python.tracer")
    return tracer
173
+
174
+
175
# Set once install() has registered the Rust-backed providers.
_INSTALLED = False

# Module-level meter. install() rebinds it to a meter obtained after the
# Rust-backed MeterProvider has been registered.
METER: metrics.Meter = metrics.get_meter("monarch")
178
+
179
+
180
def install() -> None:
    """Register Monarch's Rust-backed tracer and meter providers once.

    Steps:
    - set the global OpenTelemetry tracer provider to ``RustTracerProvider``;
    - set the global meter provider to ``MeterProvider``;
    - rebind the module-level ``METER`` to a meter from the new provider.

    After a completed call, later calls return immediately.
    """
    global _INSTALLED, METER
    if _INSTALLED:
        return

    trace.set_tracer_provider(RustTracerProvider())
    metrics.set_meter_provider(MeterProvider())

    METER = metrics.get_meter("monarch")
    _INSTALLED = True
@@ -11,28 +11,23 @@ from contextlib import contextmanager
11
11
  from typing import Iterator, Mapping, Optional, Union
12
12
 
13
13
  import opentelemetry.util.types as types # @manual=fbsource//third-party/pypi/opentelemetry-api:opentelemetry-api
14
-
15
14
  from monarch._rust_bindings.monarch_hyperactor.telemetry import (
16
15
  get_current_span_id,
17
16
  PySpan,
18
17
  )
19
-
20
18
  from opentelemetry import ( # @manual=fbsource//third-party/pypi/opentelemetry-api:opentelemetry-api
21
19
  trace,
22
20
  )
23
- from opentelemetry.trace import Tracer
24
21
  from opentelemetry.trace.status import Status, StatusCode
25
- from pyre_extensions import override
26
22
 
27
23
  logger: logging.Logger = logging.getLogger(__name__)
28
24
 
29
25
 
30
26
  class SpanWrapper(trace.Span):
31
- def __init__(self, name: str) -> None:
27
+ def __init__(self, name: str, actor_id: Optional[str]) -> None:
32
28
  super().__init__()
33
- self._span: PySpan | None = PySpan(name)
29
+ self._span: PySpan | None = PySpan(name, actor_id)
34
30
 
35
- @override
36
31
  def end(self, end_time: Optional[int] = None) -> None:
37
32
  # since PySpan is not sendable, we need to make sure it is deallocated on this thread so it doesn't log warnings.
38
33
  s = self._span
@@ -94,7 +89,8 @@ class RustTracer(trace.Tracer):
94
89
  record_exception: bool = True,
95
90
  set_status_on_exception: bool = True,
96
91
  ) -> trace.Span:
97
- return SpanWrapper(name)
92
+ actor_id = str(attributes.get("actor_id")) if attributes else None
93
+ return SpanWrapper(name, actor_id)
98
94
 
99
95
  @contextmanager
100
96
  # pyre-fixme[15]: `start_as_current_span` overrides method defined in `Tracer`
@@ -111,14 +107,14 @@ class RustTracer(trace.Tracer):
111
107
  set_status_on_exception: bool = True,
112
108
  end_on_exit: bool = True,
113
109
  ) -> Iterator[trace.Span]:
114
- with SpanWrapper(name) as s:
110
+ actor_id = str(attributes.get("actor_id")) if attributes else None
111
+ with SpanWrapper(name, actor_id) as s:
115
112
  with trace.use_span(s):
116
113
  yield s
117
114
  del s
118
115
 
119
116
 
120
117
  class RustTracerProvider(trace.TracerProvider):
121
- @override
122
118
  def get_tracer(
123
119
  self,
124
120
  instrumenting_module_name: str,
@@ -128,32 +124,3 @@ class RustTracerProvider(trace.TracerProvider):
128
124
  **kwargs: object,
129
125
  ) -> trace.Tracer:
130
126
  return RustTracer()
131
-
132
-
133
- def get_monarch_tracer() -> Tracer:
134
- """
135
- Creates and returns a Monarch python tracer that logs to the Rust telemetry system.
136
-
137
- Returns:
138
- Tracer: A configured OpenTelemetry tracer for Monarch.
139
-
140
- Usage:
141
- tracer = get_monarch_tracer()
142
- with tracer.start_as_current_span("span_name") as span:
143
- # code here
144
- """
145
- install()
146
- return trace.get_tracer("monarch.python.tracer")
147
-
148
-
149
- _INSTALLED = False
150
-
151
-
152
- def install() -> None:
153
- global _INSTALLED
154
- if _INSTALLED:
155
- return
156
-
157
- provider = RustTracerProvider()
158
- trace.set_tracer_provider(provider)
159
- _INSTALLED = True
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import argparse
9
+ import logging
10
+ import subprocess
11
+
12
+ from monarch._src.actor.debugger.debugger import (
13
+ _get_debug_server_host,
14
+ _get_debug_server_port,
15
+ )
16
+
17
+
18
def run():
    """Attach the terminal to the Monarch debug server using netcat.

    Parses ``--host`` and ``--port``. Their defaults come from
    ``_get_debug_server_host()`` and ``_get_debug_server_port()``.

    Runs the first of ``ncat``, ``nc`` or ``netcat`` found on ``PATH``. The
    child inherits this process's stdin and stdout.

    Raises:
        subprocess.CalledProcessError: If the netcat process exits non-zero.
        SystemExit: With status 1 if no netcat binary is found on ``PATH``.
    """
    import shutil

    parser = argparse.ArgumentParser(description="Monarch Debug CLI")
    parser.add_argument(
        "--host",
        type=str,
        default=_get_debug_server_host(),
        help="Hostname where the debug server is running",
    )
    parser.add_argument(
        "--port",
        type=str,
        default=_get_debug_server_port(),
        help="Port that the debug server is listening on",
    )
    args = parser.parse_args()

    for candidate in ("ncat", "nc", "netcat"):
        executable = shutil.which(candidate)
        if executable is None:
            continue
        subprocess.run([executable, str(args.host), str(args.port)], check=True)
        return

    logging.error(
        "Could not find a suitable netcat binary. Please install one and try again."
    )
    # Exit non-zero so shells and scripts can detect the failure.
    raise SystemExit(1)
@@ -4,19 +4,28 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-unsafe
8
+ import asyncio
9
+ import functools
7
10
  import logging
8
11
  import warnings
9
12
  from typing import Optional
10
13
 
11
14
  import torch
15
+ from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
12
16
 
13
17
  try:
14
- from monarch._rust_bindings.rdma import _RdmaBuffer
18
+ from monarch._rust_bindings.rdma import _RdmaBuffer, _RdmaManager
15
19
  except ImportError as e:
16
20
  logging.error("RDMA is not available: {}".format(e))
17
21
  raise e
18
- from monarch._src.actor.actor_mesh import MonarchContext
22
+ from typing import Dict
23
+
24
+ from monarch._src.actor.actor_mesh import Actor, context
25
+ from monarch._src.actor.endpoint import endpoint
19
26
  from monarch._src.actor.future import Future
27
+ from monarch._src.actor.proc_mesh import get_or_spawn_controller, ProcMesh
28
+ from pyre_extensions import none_throws
20
29
 
21
30
 
22
31
  # RDMARead/WriteTransferWarnings are warnings that are only printed once per process.
@@ -37,6 +46,44 @@ def is_available():
37
46
  return _RdmaBuffer.rdma_supported()
38
47
 
39
48
 
49
class RdmaController(Actor):
    """Controller actor that lazily creates one ``_RdmaManager`` per proc mesh.

    NOTE(review): managers are cached in a dict keyed by the ``ProcMesh``
    argument this actor receives. Whether repeated requests hit the cache
    depends on ``ProcMesh``'s ``__eq__``/``__hash__`` after the value is sent
    here; confirm.
    """

    def __init__(self) -> None:
        # Proc mesh -> RDMA manager created for it.
        self._managers: Dict[ProcMesh, _RdmaManager] = {}
        # Guards manager creation across concurrently running endpoint calls.
        self._lock = asyncio.Lock()

    @endpoint
    async def init_rdma_on_mesh(self, proc_mesh: ProcMesh) -> None:
        """Ensure an ``_RdmaManager`` exists for ``proc_mesh``.

        Raises:
            RuntimeError: If RDMA is not supported on this machine.
        """
        if not _RdmaBuffer.rdma_supported():
            raise RuntimeError(
                "Cannot spawn _RdmaManager because RDMA is not supported on this machine"
            )

        # Fast path: skip the lock when the manager already exists.
        if proc_mesh in self._managers:
            return

        async with self._lock:
            # Re-check under the lock; another call may have created the
            # manager while this one was waiting.
            if proc_mesh in self._managers:
                return
            hy_proc_mesh = await Future(coro=proc_mesh._proc_mesh.task())
            manager = await Future(
                coro=_RdmaManager.create_rdma_manager_nonblocking(hy_proc_mesh)
            )
            self._managers[proc_mesh] = none_throws(manager)
73
+
74
+
75
# Cached so that we don't have to call out to the root client every time,
# which may be on a different host.
@functools.cache
def _ensure_init_rdma_manager() -> Shared[None]:
    """Spawn, once per process, a task that asks the "rdma_controller" actor
    to set up an ``_RdmaManager`` for the current actor's proc mesh.

    Returns the spawned ``Shared`` handle. Because the handle is memoized,
    every caller observes the same task.

    NOTE(review): a failed initialization is presumably replayed to all later
    callers rather than retried; confirm.
    """

    async def _request_init() -> None:
        controller = await get_or_spawn_controller("rdma_controller", RdmaController)
        call_one = controller.init_rdma_on_mesh.call_one
        # Resolved when the task runs, not when it is spawned.
        current_proc_mesh = none_throws(context().actor_instance.proc_mesh)
        await call_one(current_proc_mesh)

    return PythonTask.from_coroutine(_request_init()).spawn()
85
+
86
+
40
87
  def _assert_tensor_is_1d_contiguous_uint8(t: torch.Tensor) -> None:
41
88
  if t.ndim != 1:
42
89
  raise ValueError(f"Tensor must be 1D, got {t.ndim}D")
@@ -59,6 +106,10 @@ class RDMABuffer:
59
106
  is_available()
60
107
  ), "Tried to create an RDMABuffer, but RDMA is not available on this platform."
61
108
 
109
+ # We need to ensure that _RdmaManager is initialized at this point, because under the hood
110
+ # _RdmaBuffer.create_rdma_buffer_blocking relies on this being the case.
111
+ _ensure_init_rdma_manager().block_on()
112
+
62
113
  if data.device.type != "cpu":
63
114
  # TODO - CUDA support for RDMABuffer exists at the Rust layer, but
64
115
  # runs into issues with MR creation. For now, only support CPU tensors.
@@ -76,12 +127,12 @@ class RDMABuffer:
76
127
  storage = data.untyped_storage()
77
128
  addr: int = storage.data_ptr()
78
129
  size = storage.element_size() * data.numel()
79
- ctx = MonarchContext.get()
130
+ ctx = context()
80
131
  self._buffer: _RdmaBuffer = _RdmaBuffer.create_rdma_buffer_blocking(
81
132
  addr=addr,
82
133
  size=size,
83
- proc_id=ctx.proc_id,
84
- client=ctx.mailbox,
134
+ proc_id=ctx.actor_instance.proc_id,
135
+ client=ctx.actor_instance._mailbox,
85
136
  )
86
137
  # TODO - specific exception
87
138
  except Exception as e:
@@ -120,10 +171,12 @@ class RDMABuffer:
120
171
  f"offset + size ({offset + size}) must be <= dst.numel() ({dst.numel()})"
121
172
  )
122
173
 
123
- local_proc_id = MonarchContext.get().proc_id
124
- client = MonarchContext.get().mailbox
174
+ local_proc_id = context().actor_instance.proc_id
175
+ client = context().actor_instance._mailbox
125
176
 
126
177
  async def read_into_nonblocking() -> Optional[int]:
178
+ await _ensure_init_rdma_manager()
179
+
127
180
  res = await self._buffer.read_into(
128
181
  addr=addr,
129
182
  size=size,
@@ -167,10 +220,12 @@ class RDMABuffer:
167
220
  f"size + offset ({size + offset}) must be <= src.numel() ({src.numel()})"
168
221
  )
169
222
 
170
- local_proc_id = MonarchContext.get().proc_id
171
- client = MonarchContext.get().mailbox
223
+ local_proc_id = context().actor_instance.proc_id
224
+ client = context().actor_instance._mailbox
172
225
 
173
226
  async def write_from_nonblocking() -> None:
227
+ await _ensure_init_rdma_manager()
228
+
174
229
  res = await self._buffer.write_from(
175
230
  addr=addr,
176
231
  size=size,
monarch/_testing.py CHANGED
@@ -133,9 +133,7 @@ class TestingContext:
133
133
  ) -> Generator[DeviceMesh, None, None]:
134
134
  key = (num_hosts, gpu_per_host)
135
135
  if key not in self._proc_mesh_cache:
136
- self._proc_mesh_cache[key] = proc_mesh(
137
- hosts=num_hosts, gpus=gpu_per_host
138
- ).get()
136
+ self._proc_mesh_cache[key] = proc_mesh(hosts=num_hosts, gpus=gpu_per_host)
139
137
 
140
138
  dm = spawn_tensor_engine(self._proc_mesh_cache[key])
141
139
  dm = dm.rename(hosts="host", gpus="gpu")
monarch/actor/__init__.py CHANGED
@@ -4,6 +4,7 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-unsafe
7
8
  """
8
9
  Monarch Actor API - Public interface for actor functionality.
9
10
  """
@@ -13,18 +14,29 @@ from monarch._src.actor.actor_mesh import (
13
14
  Actor,
14
15
  ActorError,
15
16
  as_endpoint,
17
+ Channel,
18
+ context,
16
19
  current_actor_name,
17
20
  current_rank,
18
21
  current_size,
19
22
  Point,
20
- port,
23
+ Port,
24
+ PortReceiver,
21
25
  send,
22
26
  ValueMesh,
23
27
  )
28
+ from monarch._src.actor.debugger.debugger import debug_controller
24
29
  from monarch._src.actor.endpoint import endpoint
25
30
  from monarch._src.actor.future import Future
31
+
32
+ from monarch._src.actor.host_mesh import (
33
+ HostMesh,
34
+ hosts_from_config,
35
+ this_host,
36
+ this_proc,
37
+ )
26
38
  from monarch._src.actor.proc_mesh import (
27
- debug_client,
39
+ get_or_spawn_controller,
28
40
  local_proc_mesh,
29
41
  proc_mesh,
30
42
  ProcMesh,
@@ -45,9 +57,17 @@ __all__ = [
45
57
  "Point",
46
58
  "proc_mesh",
47
59
  "ProcMesh",
48
- "port",
60
+ "Channel",
49
61
  "send",
50
62
  "sim_proc_mesh",
51
63
  "ValueMesh",
52
- "debug_client",
64
+ "debug_controller",
65
+ "get_or_spawn_controller",
66
+ "this_host",
67
+ "this_proc",
68
+ "HostMesh",
69
+ "context",
70
+ "hosts_from_config",
71
+ "Port",
72
+ "PortReceiver",
53
73
  ]
monarch/common/_C.so CHANGED
Binary file