torchmonarch-nightly 2025.8.1__cp313-cp313-manylinux2014_x86_64.whl → 2025.9.3__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. monarch/_rust_bindings.so +0 -0
  2. monarch/_src/actor/actor_mesh.py +414 -216
  3. monarch/_src/actor/allocator.py +75 -6
  4. monarch/_src/actor/bootstrap_main.py +7 -4
  5. monarch/_src/actor/code_sync/__init__.py +2 -0
  6. monarch/_src/actor/debugger/__init__.py +7 -0
  7. monarch/_src/actor/{debugger.py → debugger/debugger.py} +246 -135
  8. monarch/_src/actor/{pdb_wrapper.py → debugger/pdb_wrapper.py} +62 -23
  9. monarch/_src/actor/endpoint.py +27 -45
  10. monarch/_src/actor/future.py +86 -24
  11. monarch/_src/actor/host_mesh.py +125 -0
  12. monarch/_src/actor/logging.py +94 -0
  13. monarch/_src/actor/pickle.py +25 -0
  14. monarch/_src/actor/proc_mesh.py +423 -156
  15. monarch/_src/actor/python_extension_methods.py +90 -0
  16. monarch/_src/actor/shape.py +8 -1
  17. monarch/_src/actor/source_loader.py +45 -0
  18. monarch/_src/actor/telemetry/__init__.py +172 -0
  19. monarch/_src/actor/telemetry/rust_span_tracing.py +6 -39
  20. monarch/_src/debug_cli/__init__.py +7 -0
  21. monarch/_src/debug_cli/debug_cli.py +43 -0
  22. monarch/_src/tensor_engine/rdma.py +64 -9
  23. monarch/_testing.py +1 -3
  24. monarch/actor/__init__.py +24 -4
  25. monarch/common/_C.so +0 -0
  26. monarch/common/device_mesh.py +14 -0
  27. monarch/common/future.py +10 -0
  28. monarch/common/remote.py +14 -25
  29. monarch/common/tensor.py +12 -0
  30. monarch/debug_cli/__init__.py +7 -0
  31. monarch/debug_cli/__main__.py +12 -0
  32. monarch/fetch.py +2 -2
  33. monarch/gradient/_gradient_generator.so +0 -0
  34. monarch/gradient_generator.py +4 -2
  35. monarch/mesh_controller.py +34 -14
  36. monarch/monarch_controller +0 -0
  37. monarch/tools/colors.py +25 -0
  38. monarch/tools/commands.py +42 -7
  39. monarch/tools/components/hyperactor.py +1 -1
  40. monarch/tools/config/__init__.py +31 -4
  41. monarch/tools/config/defaults.py +13 -3
  42. monarch/tools/config/environment.py +45 -0
  43. monarch/tools/config/workspace.py +165 -0
  44. monarch/tools/mesh_spec.py +2 -0
  45. monarch/utils/__init__.py +9 -0
  46. monarch/utils/utils.py +78 -0
  47. tests/error_test_binary.py +5 -3
  48. tests/python_actor_test_binary.py +52 -0
  49. tests/test_actor_error.py +142 -14
  50. tests/test_alloc.py +1 -1
  51. tests/test_allocator.py +59 -72
  52. tests/test_coalescing.py +1 -1
  53. tests/test_debugger.py +639 -45
  54. tests/test_env_before_cuda.py +4 -4
  55. tests/test_mesh_trait.py +38 -0
  56. tests/test_python_actors.py +979 -75
  57. tests/test_rdma.py +7 -6
  58. tests/test_tensor_engine.py +6 -6
  59. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/METADATA +82 -4
  60. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/RECORD +64 -48
  61. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/WHEEL +0 -0
  62. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/entry_points.txt +0 -0
  63. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/licenses/LICENSE +0 -0
  64. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/top_level.txt +0 -0
@@ -8,9 +8,12 @@
8
8
  import bdb
9
9
  import inspect
10
10
  import io
11
+ import linecache
12
+ import os
11
13
  import pdb # noqa
12
14
  import socket
13
15
  import sys
16
+ from contextlib import contextmanager
14
17
  from dataclasses import dataclass
15
18
 
16
19
  from typing import Dict, TYPE_CHECKING
@@ -19,7 +22,7 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
19
22
  from monarch._src.actor.sync_state import fake_sync_state
20
23
 
21
24
  if TYPE_CHECKING:
22
- from monarch._src.actor.debugger import DebugClient
25
+ from monarch._src.actor.debugger.debugger import DebugController
23
26
 
24
27
 
25
28
  @dataclass
@@ -29,31 +32,41 @@ class DebuggerWrite:
29
32
  lineno: int | None
30
33
 
31
34
 
35
+ @contextmanager
36
+ def _debug_controller_request_ctx():
37
+ try:
38
+ with fake_sync_state():
39
+ yield
40
+ except Exception as e:
41
+ raise bdb.BdbQuit from e
42
+
43
+
32
44
  class PdbWrapper(pdb.Pdb):
33
45
  def __init__(
34
46
  self,
35
47
  rank: int,
36
48
  coords: Dict[str, int],
37
49
  actor_id: ActorId,
38
- client_ref: "DebugClient",
50
+ controller: "DebugController",
39
51
  header: str | None = None,
40
52
  ):
41
53
  self.rank = rank
42
54
  self.coords = coords
43
55
  self.header = header
44
56
  self.actor_id = actor_id
45
- self.client_ref = client_ref
57
+ self.controller = controller
46
58
  # pyre-ignore
47
59
  super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self))
48
60
  self._first = True
49
61
 
50
62
  def set_trace(self, frame=None):
51
- self.client_ref.debugger_session_start.broadcast(
52
- self.rank,
53
- self.coords,
54
- socket.getfqdn(socket.gethostname()),
55
- self.actor_id.actor_name,
56
- )
63
+ with _debug_controller_request_ctx():
64
+ self.controller.debugger_session_start.call_one(
65
+ self.rank,
66
+ self.coords,
67
+ socket.getfqdn(socket.gethostname()),
68
+ self.actor_id.actor_name,
69
+ ).get()
57
70
  if self.header:
58
71
  self.message(self.header)
59
72
  super().set_trace(frame)
@@ -69,10 +82,35 @@ class PdbWrapper(pdb.Pdb):
69
82
  else:
70
83
  super().do_clear(arg)
71
84
 
85
+ def lookupmodule(self, filename):
86
+ filename = super().lookupmodule(filename)
87
+ if (
88
+ filename is not None
89
+ and not os.path.exists(filename)
90
+ and filename not in linecache.cache
91
+ ):
92
+ from monarch._src.actor.actor_mesh import ActorError
93
+ from monarch._src.actor.source_loader import load_remote_source
94
+
95
+ try:
96
+ with fake_sync_state():
97
+ source = load_remote_source(filename)
98
+ if source:
99
+ linecache.cache[filename] = (
100
+ len(source),
101
+ None,
102
+ source.splitlines(keepends=True),
103
+ filename,
104
+ )
105
+ except ActorError as e:
106
+ self.error(f"Failed querying root client host for source code: {e}")
107
+ return filename
108
+
72
109
  def end_debug_session(self):
73
- self.client_ref.debugger_session_end.broadcast(
74
- self.actor_id.actor_name, self.rank
75
- )
110
+ with _debug_controller_request_ctx():
111
+ self.controller.debugger_session_end.call_one(
112
+ self.actor_id.actor_name, self.rank
113
+ ).get()
76
114
  # Once the debug client actor is notified of the session being over,
77
115
  # we need to prevent any additional requests being sent for the session
78
116
  # by redirecting stdin and stdout.
@@ -91,8 +129,8 @@ class ReadWrapper(io.RawIOBase):
91
129
  self.session = session
92
130
 
93
131
  def readinto(self, b):
94
- with fake_sync_state():
95
- response = self.session.client_ref.debugger_read.call_one(
132
+ with _debug_controller_request_ctx():
133
+ response = self.session.controller.debugger_read.call_one(
96
134
  self.session.actor_id.actor_name, self.session.rank, len(b)
97
135
  ).get()
98
136
  if response == "detach":
@@ -128,15 +166,16 @@ class WriteWrapper:
128
166
  function = f"{inspect.getmodulename(self.session.curframe.f_code.co_filename)}.{self.session.curframe.f_code.co_name}"
129
167
  # pyre-ignore
130
168
  lineno = self.session.curframe.f_lineno
131
- self.session.client_ref.debugger_write.broadcast(
132
- self.session.actor_id.actor_name,
133
- self.session.rank,
134
- DebuggerWrite(
135
- s.encode(),
136
- function,
137
- lineno,
138
- ),
139
- )
169
+ with _debug_controller_request_ctx():
170
+ self.session.controller.debugger_write.call_one(
171
+ self.session.actor_id.actor_name,
172
+ self.session.rank,
173
+ DebuggerWrite(
174
+ s.encode(),
175
+ function,
176
+ lineno,
177
+ ),
178
+ ).get()
140
179
 
141
180
  def flush(self):
142
181
  pass
@@ -11,7 +11,6 @@ from abc import ABC, abstractmethod
11
11
  from operator import mul
12
12
  from typing import (
13
13
  Any,
14
- AsyncGenerator,
15
14
  Awaitable,
16
15
  Callable,
17
16
  cast,
@@ -31,36 +30,25 @@ from typing import (
31
30
  TypeVar,
32
31
  )
33
32
 
33
+ from monarch._rust_bindings.monarch_hyperactor.shape import Extent
34
+
34
35
  from monarch._src.actor.future import Future
35
36
  from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
36
37
 
37
38
  if TYPE_CHECKING:
38
39
  from monarch._src.actor.actor_mesh import (
39
- ActorMeshRef,
40
+ ActorMesh,
41
+ HyOncePortReceiver,
40
42
  HyPortReceiver,
41
- OncePortReceiver,
42
43
  Port,
43
- PortTuple,
44
+ PortReceiver,
44
45
  ValueMesh,
45
46
  )
46
47
 
47
48
  P = ParamSpec("P")
48
49
  R = TypeVar("R")
49
50
 
50
- Selection = Literal["all", "choose"] | int
51
-
52
-
53
- class Extent:
54
- def __init__(self, labels: Sequence[str], sizes: Sequence[int]) -> None:
55
- self.labels = labels
56
- self.sizes = sizes
57
-
58
- @property
59
- def nelements(self) -> int:
60
- return functools.reduce(mul, self.sizes, 1)
61
-
62
- def __str__(self) -> str:
63
- return str(dict(zip(self.labels, self.sizes)))
51
+ Selection = Literal["all", "choose"]
64
52
 
65
53
 
66
54
  Propagator = Any
@@ -90,9 +78,10 @@ class Endpoint(ABC, Generic[P, R]):
90
78
  """
91
79
  pass
92
80
 
93
- @abstractmethod
94
- def _port(self, once: bool = False) -> "PortTuple[R]":
95
- pass
81
+ def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]":
82
+ from monarch._src.actor.actor_mesh import Channel
83
+
84
+ return Channel[R].open(once)
96
85
 
97
86
  @abstractmethod
98
87
  def _call_name(self) -> Any:
@@ -101,7 +90,7 @@ class Endpoint(ABC, Generic[P, R]):
101
90
  """
102
91
  pass
103
92
 
104
- def _supervise(self, r: "HyPortReceiver | OncePortReceiver") -> Any:
93
+ def _supervise(self, r: "HyPortReceiver | HyOncePortReceiver") -> Any:
105
94
  return r
106
95
 
107
96
  # the following are all 'adverbs' or different ways to handle the
@@ -115,17 +104,14 @@ class Endpoint(ABC, Generic[P, R]):
115
104
 
116
105
  Load balanced RPC-style entrypoint for request/response messaging.
117
106
  """
118
- from monarch._src.actor.actor_mesh import port
119
107
 
120
- p, r = port(self, once=True)
108
+ p, r = self._port(once=True)
121
109
  # pyre-ignore
122
110
  self._send(args, kwargs, port=p, selection="choose")
123
111
  return r.recv()
124
112
 
125
113
  def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
126
- from monarch._src.actor.actor_mesh import port
127
-
128
- p, r = port(self, once=True)
114
+ p, r = self._port(once=True)
129
115
  # pyre-ignore
130
116
  extent = self._send(args, kwargs, port=p, selection="choose")
131
117
  if extent.nelements != 1:
@@ -135,9 +121,10 @@ class Endpoint(ABC, Generic[P, R]):
135
121
  return r.recv()
136
122
 
137
123
  def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
138
- from monarch._src.actor.actor_mesh import ranked_port, ValueMesh
124
+ from monarch._src.actor.actor_mesh import ValueMesh
139
125
 
140
- p, r = ranked_port(self)
126
+ p, unranked = self._port()
127
+ r = unranked.ranked()
141
128
  # pyre-ignore
142
129
  extent = self._send(args, kwargs, port=p)
143
130
 
@@ -157,29 +144,24 @@ class Endpoint(ABC, Generic[P, R]):
157
144
 
158
145
  return Future(coro=process())
159
146
 
160
- def _stream(
147
+ def stream(
161
148
  self, *args: P.args, **kwargs: P.kwargs
162
- ) -> Generator[Coroutine[Any, Any, R], None, None]:
149
+ ) -> Generator[Future[R], None, None]:
163
150
  """
164
151
  Broadcasts to all actors and yields their responses as a stream / generator.
165
152
 
166
153
  This enables processing results from multiple actors incrementally as
167
154
  they become available. Returns an async generator of response values.
168
155
  """
169
- from monarch._src.actor.actor_mesh import port
170
-
171
- p, r = port(self)
172
- # pyre-ignore
156
+ p, r = self._port()
157
+ # type: ignore
173
158
  extent = self._send(args, kwargs, port=p)
174
- for _ in range(extent.nelements):
175
- # pyre-ignore
176
- yield r._recv()
177
159
 
178
- def stream(
179
- self, *args: P.args, **kwargs: P.kwargs
180
- ) -> Generator[Future[R], None, None]:
181
- for coro in self._stream(*args, **kwargs):
182
- yield Future(coro=coro)
160
+ def _stream():
161
+ for _ in range(extent.nelements):
162
+ yield r.recv()
163
+
164
+ return _stream()
183
165
 
184
166
  def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
185
167
  """
@@ -261,12 +243,12 @@ class EndpointProperty(Generic[P, R]):
261
243
 
262
244
  class NotAnEndpoint:
263
245
  """
264
- Used as the dynamic value of functions on an ActorMeshRef that were not marked as endpoints.
246
+ Used as the dynamic value of functions on an ActorMesh that were not marked as endpoints.
265
247
  This is used both to give a better error message (since we cannot prevent the type system from thinking they are methods),
266
248
  and to provide the oppurtunity for someone to do endpoint(x.foo) on something that wasn't marked as an endpoint.
267
249
  """
268
250
 
269
- def __init__(self, ref: "ActorMeshRef", name: str):
251
+ def __init__(self, ref: "ActorMesh", name: str):
270
252
  self._ref = ref
271
253
  self._name = name
272
254
 
@@ -6,6 +6,7 @@
6
6
 
7
7
  import asyncio
8
8
  import traceback
9
+ import warnings
9
10
  from functools import partial
10
11
  from typing import (
11
12
  Any,
@@ -19,9 +20,13 @@ from typing import (
19
20
  TypeVar,
20
21
  )
21
22
 
22
- from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
23
+ from monarch._rust_bindings.monarch_hyperactor.pytokio import (
24
+ is_tokio_thread,
25
+ PythonTask,
26
+ Shared,
27
+ )
23
28
 
24
- from typing_extensions import Self
29
+ from typing_extensions import deprecated, Self
25
30
 
26
31
  R = TypeVar("R")
27
32
 
@@ -78,7 +83,11 @@ class _Asyncio(NamedTuple):
78
83
  fut: asyncio.Future
79
84
 
80
85
 
81
- _Status = _Unawaited | _Complete | _Exception | _Asyncio
86
+ class _Tokio(NamedTuple):
87
+ shared: Shared
88
+
89
+
90
+ _Status = _Unawaited | _Complete | _Exception | _Asyncio | _Tokio
82
91
 
83
92
 
84
93
  class Future(Generic[R]):
@@ -107,31 +116,60 @@ class Future(Generic[R]):
107
116
  return cast("R", value)
108
117
  case _Exception(exe=exe):
109
118
  raise exe
119
+ case _Tokio(_):
120
+ raise ValueError(
121
+ "already converted into a pytokio.Shared object, use 'await' from a PythonTask coroutine to get the value."
122
+ )
110
123
  case _:
111
124
  raise RuntimeError("unknown status")
112
125
 
113
126
  def __await__(self) -> Generator[Any, Any, R]:
114
- match self._status:
115
- case _Unawaited(coro=coro):
116
- loop = asyncio.get_running_loop()
117
- fut = loop.create_future()
118
- self._status = _Asyncio(fut)
119
-
120
- async def mark_complete():
121
- try:
122
- func, value = fut.set_result, await coro
123
- except Exception as e:
124
- func, value = fut.set_exception, e
125
- loop.call_soon_threadsafe(func, value)
126
-
127
- PythonTask.from_coroutine(mark_complete()).spawn()
128
- return fut.__await__()
129
- case _Asyncio(fut=fut):
130
- return fut.__await__()
131
- case _:
132
- raise ValueError(
133
- "already converted into a synchronous future, use 'get' to get the value."
134
- )
127
+ if asyncio._get_running_loop() is not None:
128
+ match self._status:
129
+ case _Unawaited(coro=coro):
130
+ loop = asyncio.get_running_loop()
131
+ fut = loop.create_future()
132
+ self._status = _Asyncio(fut)
133
+
134
+ async def mark_complete():
135
+ try:
136
+ func, value = fut.set_result, await coro
137
+ except Exception as e:
138
+ func, value = fut.set_exception, e
139
+ loop.call_soon_threadsafe(func, value)
140
+
141
+ PythonTask.from_coroutine(mark_complete()).spawn()
142
+ return fut.__await__()
143
+ case _Asyncio(fut=fut):
144
+ return fut.__await__()
145
+ case _Tokio(_):
146
+ raise ValueError(
147
+ "already converted into a tokio future, but being awaited from the asyncio loop."
148
+ )
149
+ case _:
150
+ raise ValueError(
151
+ "already converted into a synchronous future, use 'get' to get the value."
152
+ )
153
+ elif is_tokio_thread():
154
+ match self._status:
155
+ case _Unawaited(coro=coro):
156
+ shared = coro.spawn()
157
+ self._status = _Tokio(shared)
158
+ return shared.__await__()
159
+ case _Tokio(shared=shared):
160
+ return shared.__await__()
161
+ case _Asyncio(_):
162
+ raise ValueError(
163
+ "already converted into asyncio future, but being awaited from the tokio loop."
164
+ )
165
+ case _:
166
+ raise ValueError(
167
+ "already converted into a synchronous future, use 'get' to get the value."
168
+ )
169
+ else:
170
+ raise ValueError(
171
+ "__await__ with no active event loop (either asyncio or tokio)"
172
+ )
135
173
 
136
174
  # compatibility with old tensor engine Future objects
137
175
  # hopefully we do not need done(), add_callback because
@@ -145,3 +183,27 @@ class Future(Generic[R]):
145
183
  return None
146
184
  except Exception as e:
147
185
  return e
186
+
187
+
188
+ class DeprecatedNotAFuture:
189
+ """
190
+ We used to return Future[Alloc] and Future[Actor] and Future[ProcMesh].
191
+ Now the only Futures are generated as responses to messages.
192
+
193
+ This polyfills the await/get methods to those objects and raises the deprecation
194
+ warning that we are going to remove this.
195
+ """
196
+
197
+ def get(self) -> "Self":
198
+ cls = type(self)
199
+ typ = f"{cls.__module__}.{cls.__qualname__}"
200
+ warnings.warn(
201
+ f"This get()/await can be removed. get() and await is deprecated for {typ}, we directly return {typ} instead of Future[{typ}].\n",
202
+ DeprecationWarning,
203
+ stacklevel=2,
204
+ )
205
+ return self
206
+
207
+ def __await__(self) -> "Generator[Any, Any, Self]":
208
+ yield from ()
209
+ return self
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+ from math import prod
9
+
10
+ from typing import Callable, Dict, Optional, Tuple
11
+
12
+ from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
13
+
14
+ from monarch._src.actor.actor_mesh import context
15
+ from monarch._src.actor.allocator import AllocateMixin, AllocHandle, LocalAllocator
16
+ from monarch._src.actor.proc_mesh import _get_bootstrap_args, ProcessAllocator, ProcMesh
17
+ from monarch._src.actor.shape import MeshTrait, NDSlice, Shape
18
+
19
+
20
+ def this_host() -> "HostMesh":
21
+ """
22
+ The current machine.
23
+
24
+ This is just shorthand for looking it up via the context
25
+ """
26
+ return context().actor_instance.proc.host_mesh
27
+
28
+
29
+ def this_proc() -> "ProcMesh":
30
+ """
31
+ The current singleton process that this specific actor is
32
+ running on
33
+ """
34
+ return context().actor_instance.proc
35
+
36
+
37
+ def create_local_host_mesh() -> "HostMesh":
38
+ cmd, args, env = _get_bootstrap_args()
39
+ return HostMesh(Shape.unity(), ProcessAllocator(cmd, args, env))
40
+
41
+
42
+ class HostMesh(MeshTrait):
43
+ def __init__(self, shape: Shape, allocator: AllocateMixin):
44
+ self._allocator = allocator
45
+ self._shape = shape
46
+ self._spawned = 0
47
+
48
+ def _alloc(self, hosts: int, gpus: int) -> "AllocHandle":
49
+ spec: AllocSpec = AllocSpec(AllocConstraints(), hosts=hosts, gpus=gpus)
50
+ return self._allocator.allocate(spec)
51
+
52
+ def spawn_procs(
53
+ self,
54
+ per_host: Optional[Dict[str, int]] = None,
55
+ bootstrap: Optional[Callable[[], None]] = None,
56
+ ) -> "ProcMesh":
57
+ """
58
+ Start new processes on this host mesh. By default this starts one proc
59
+ on each host in the mesh. Additional procs can be started using `per_host` to
60
+ specify the local shape, e.g.
61
+ per_host = {'gpus': 8}
62
+ Will create a proc mesh with an additional 'gpus' dimension.
63
+
64
+ `bootstrap` is a function that will be run at startup on each proc and can be used to e.g.
65
+ configure CUDA or NCCL. We guarantee that CUDA has not been initialized before boostrap is called.
66
+ """
67
+ if per_host is None:
68
+ per_host = {}
69
+ if self._spawned > 0 and len(self._ndslice) > 1:
70
+ warnings.warn(
71
+ "spawning multiple procs on the same host mesh is kinda fake at the moment, there is no guarentee that the two different spawns will be on shared hosts",
72
+ stacklevel=2,
73
+ )
74
+ self._spawned += 1
75
+ hosts = len(self._ndslice)
76
+ flat_per_host = prod(per_host.values())
77
+ alloc_handle = self._alloc(hosts, flat_per_host)
78
+
79
+ new_extent = dict(zip(self._labels, self._ndslice.sizes))
80
+
81
+ conflicting_keys = set(per_host.keys()) & set(new_extent.keys())
82
+ if conflicting_keys:
83
+ raise ValueError(
84
+ f"host mesh already has dims {', '.join(sorted(conflicting_keys))}"
85
+ )
86
+
87
+ new_extent.update(per_host)
88
+ return ProcMesh.from_alloc(alloc_handle.reshape(new_extent), bootstrap)
89
+
90
+ @property
91
+ def _ndslice(self) -> NDSlice:
92
+ return self._shape.ndslice
93
+
94
+ @property
95
+ def _labels(self) -> Tuple[str, ...]:
96
+ return tuple(self._shape.labels)
97
+
98
+ def _new_with_shape(self, shape: Shape) -> "HostMesh":
99
+ warnings.warn(
100
+ "Slicing a host mesh is kinda fake at the moment, there is no guarentee that procs in the slice will end up on the corresponding hosts",
101
+ stacklevel=2,
102
+ )
103
+ return HostMesh(
104
+ Shape(self._labels, NDSlice.new_row_major(self._ndslice.sizes)),
105
+ self._allocator,
106
+ )
107
+
108
+
109
+ def fake_in_process_host() -> "HostMesh":
110
+ return HostMesh(Shape.unity(), LocalAllocator())
111
+
112
+
113
+ def hosts_from_config(name: str):
114
+ """
115
+ Get the host mesh 'name' from the monarch configuration for the project.
116
+
117
+ This config can be modified so that the same code can create meshes from scheduler sources,
118
+ and different sizes etc.
119
+
120
+ WARNING: This function is a standin so that our getting_started example code works. The real implementation
121
+ needs an RFC design.
122
+ """
123
+
124
+ shape = Shape(["hosts"], NDSlice.new_row_major([2]))
125
+ return HostMesh(shape, ProcessAllocator(*_get_bootstrap_args()))
@@ -0,0 +1,94 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ import gc
10
+ import logging
11
+
12
+ from typing import Callable
13
+
14
+ from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
15
+
16
+ from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
17
+ from monarch._src.actor.future import Future
18
+
19
+ IN_IPYTHON = False
20
+ try:
21
+ # Check if we are in ipython environment
22
+ # pyre-ignore[21]
23
+ from IPython import get_ipython
24
+
25
+ # pyre-ignore[21]
26
+ from IPython.core.interactiveshell import ExecutionResult
27
+
28
+ IN_IPYTHON = get_ipython() is not None
29
+ except ImportError:
30
+ pass
31
+
32
+
33
+ class LoggingManager:
34
+ def __init__(self) -> None:
35
+ self._logging_mesh_client: LoggingMeshClient | None = None
36
+ self._ipython_flush_logs_handler: Callable[..., None] | None = None
37
+
38
+ async def init(self, proc_mesh: HyProcMesh, stream_to_client: bool) -> None:
39
+ if self._logging_mesh_client is not None:
40
+ return
41
+
42
+ self._logging_mesh_client = await LoggingMeshClient.spawn(proc_mesh=proc_mesh)
43
+ self._logging_mesh_client.set_mode(
44
+ stream_to_client=stream_to_client,
45
+ aggregate_window_sec=3 if stream_to_client else None,
46
+ level=logging.INFO,
47
+ )
48
+
49
+ if IN_IPYTHON:
50
+ # For ipython environment, a cell can end fast with threads running in background.
51
+ # Flush all the ongoing logs proactively to avoid missing logs.
52
+ assert self._logging_mesh_client is not None
53
+ logging_client: LoggingMeshClient = self._logging_mesh_client
54
+ ipython = get_ipython()
55
+
56
+ # pyre-ignore[11]
57
+ def flush_logs(_: ExecutionResult) -> None:
58
+ try:
59
+ Future(coro=logging_client.flush().spawn().task()).get(3)
60
+ except TimeoutError:
61
+ # We need to prevent failed proc meshes not coming back
62
+ pass
63
+
64
+ # Force to recycle previous undropped proc_mesh.
65
+ # Otherwise, we may end up with unregisterd dead callbacks.
66
+ gc.collect()
67
+
68
+ # Store the handler reference so we can unregister it later
69
+ self._ipython_flush_logs_handler = flush_logs
70
+ ipython.events.register("post_run_cell", flush_logs)
71
+
72
+ async def logging_option(
73
+ self,
74
+ stream_to_client: bool = True,
75
+ aggregate_window_sec: int | None = 3,
76
+ level: int = logging.INFO,
77
+ ) -> None:
78
+ if level < 0 or level > 255:
79
+ raise ValueError("Invalid logging level: {}".format(level))
80
+
81
+ assert self._logging_mesh_client is not None
82
+ self._logging_mesh_client.set_mode(
83
+ stream_to_client=stream_to_client,
84
+ aggregate_window_sec=aggregate_window_sec,
85
+ level=level,
86
+ )
87
+
88
+ def stop(self) -> None:
89
+ if self._ipython_flush_logs_handler is not None:
90
+ assert IN_IPYTHON
91
+ ipython = get_ipython()
92
+ assert ipython is not None
93
+ ipython.events.unregister("post_run_cell", self._ipython_flush_logs_handler)
94
+ self._ipython_flush_logs_handler = None