torchmonarch-nightly 2025.7.1__cp313-cp313-manylinux2014_x86_64.whl → 2025.7.26__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +878 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +303 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +508 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +59 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +53 -0
  23. monarch/actor_mesh.py +6 -765
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +21 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/gradient/_gradient_generator.so +0 -0
  37. monarch/mesh_controller.py +263 -139
  38. monarch/monarch_controller +0 -0
  39. monarch/opaque_module.py +4 -6
  40. monarch/opaque_object.py +3 -3
  41. monarch/proc_mesh.py +6 -309
  42. monarch/python_local_mesh.py +1 -1
  43. monarch/rust_backend_mesh.py +2 -1
  44. monarch/rust_local_mesh.py +4 -2
  45. monarch/sim_mesh.py +10 -19
  46. monarch/simulator/command_history.py +1 -1
  47. monarch/simulator/interface.py +2 -1
  48. monarch/simulator/mock_controller.py +1 -1
  49. monarch/simulator/simulator.py +1 -1
  50. monarch/tensor_engine/__init__.py +23 -0
  51. monarch/tensor_worker_main.py +3 -1
  52. monarch/tools/cli.py +3 -1
  53. monarch/tools/commands.py +129 -47
  54. monarch/tools/components/hyperactor.py +5 -3
  55. monarch/tools/config/__init__.py +18 -1
  56. monarch/tools/config/defaults.py +2 -2
  57. monarch/tools/mesh_spec.py +59 -1
  58. monarch/tools/utils.py +38 -0
  59. monarch/worker/worker.py +1 -1
  60. monarch/world_mesh.py +2 -1
  61. monarch_supervisor/python_executable.py +6 -3
  62. tests/error_test_binary.py +48 -10
  63. tests/test_actor_error.py +370 -21
  64. tests/test_alloc.py +1 -1
  65. tests/test_allocator.py +369 -17
  66. tests/test_controller.py +2 -0
  67. tests/test_debugger.py +416 -0
  68. tests/test_env_before_cuda.py +161 -0
  69. tests/test_python_actors.py +184 -333
  70. tests/test_rdma.py +198 -0
  71. tests/test_remote_functions.py +40 -12
  72. tests/test_rust_backend.py +7 -5
  73. tests/test_sim_backend.py +1 -4
  74. tests/test_tensor_engine.py +81 -1
  75. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/METADATA +39 -1
  76. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/RECORD +84 -72
  77. torchmonarch_nightly-2025.7.26.dist-info/entry_points.txt +3 -0
  78. monarch/_monarch/hyperactor/__init__.py +0 -58
  79. monarch/_monarch/worker/debugger.py +0 -117
  80. monarch/_monarch/worker/logging.py +0 -107
  81. monarch/debugger.py +0 -379
  82. monarch/future.py +0 -76
  83. monarch/rdma.py +0 -162
  84. torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
  85. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  86. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  87. /monarch/{common → _src/actor}/shape.py +0 -0
  88. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  89. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/WHEEL +0 -0
  90. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/licenses/LICENSE +0 -0
  91. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/top_level.txt +0 -0
monarch/tools/commands.py CHANGED
@@ -7,22 +7,22 @@
 # pyre-strict
 
 import argparse
-import functools
+import asyncio
 import inspect
 import logging
 import os
-import time
-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import Any, Callable, Mapping, Optional, Union
 
+from monarch.tools.components.hyperactor import DEFAULT_NAME
+
 from monarch.tools.config import (  # @manual=//monarch/python/monarch/tools/config/meta:defaults
     Config,
     defaults,
 )
-
 from monarch.tools.mesh_spec import mesh_spec_from_metadata, ServerSpec
-from torchx.runner import Runner
-from torchx.specs import AppDef, AppDryRunInfo, AppState, CfgVal
+from torchx.runner import Runner  # @manual=//torchx/runner:lib_core
+from torchx.specs import AppDef, AppDryRunInfo, AppState, CfgVal, parse_app_handle
 from torchx.specs.builders import parse_args
 from torchx.util.types import decode, decode_optional
 
@@ -84,14 +84,10 @@ def component_args_from_cli(
 
 def create(
     config: Config,
-    component_fn: Optional[Callable[..., AppDef]] = None,
-) -> Callable[..., Union[str, AppDryRunInfo]]:
+    name: str = DEFAULT_NAME,
+) -> Union[str, AppDryRunInfo]:
     """Creates a monarch server by submitting it as a job to the target scheduler.
 
-    Note that this function returns a `Callable` that has to be called with the
-    same arguments that one would call the `component_fn` to actually submit
-    the job that runs the monarch server.
-
     Usage:
 
     .. doc-test::
@@ -99,6 +95,8 @@ def create(
         from monarch.tools.config import defaults
 
         config = defaults.config(scheduler="slurm")
+        config.appdef = defaults.component_fn(scheduler=config.scheduler)()
+
         config.scheduler_args.update(
             {
                 "partition": "prod",
@@ -108,7 +106,7 @@ def create(
         )
         config.dryrun = True
 
-        create(default_config)(host_type="gpu.medium", num_hosts=4)
+        create(config)
 
 
     Args:
@@ -117,36 +115,32 @@ def create(
         component_fn: a function that returns the AppDef (job def).
             If not provided, defaults to the configured default for the scheduler
             (in most cases ``monarch.tools.components.hyperactor.proc_mesh``)
+        name: the name of the job. If none, a default job name will be created.
     """
     scheduler: str = config.scheduler
     cfg: Mapping[str, CfgVal] = config.scheduler_args
-    component: Callable[..., AppDef] = component_fn or defaults.component_fn(scheduler)
-
-    @functools.wraps(component)
-    def _run(*args: Any, **kwargs: Any) -> Union[str, AppDryRunInfo]:
-        # for logging call-site context in application metadata
-        os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "monarch")
 
-        appdef = component(*args, **kwargs)
+    # for logging call-site context in application metadata
+    os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "monarch")
 
-        with torchx_runner() as runner:
-            info = runner.dryrun(appdef, scheduler, cfg, config.workspace)
+    with torchx_runner() as runner:
+        appdef: AppDef = AppDef(name, config.appdef.roles, config.appdef.metadata)
 
-            info_json_fmt = AppDryRunInfo(
-                info.request,
-                fmt=defaults.dryrun_info_formatter(info),
-            )
-            info_json_fmt._app = info._app
-            info_json_fmt._cfg = info._cfg
-            info_json_fmt._scheduler = info._scheduler
+        info = runner.dryrun(appdef, scheduler, cfg, config.workspace)
 
-            if config.dryrun:
-                return info_json_fmt
-            else:
-                server_handle = runner.schedule(info)
-                return server_handle
+        info_json_fmt = AppDryRunInfo(
+            info.request,
+            fmt=defaults.dryrun_info_formatter(info),
+        )
+        info_json_fmt._app = info._app
+        info_json_fmt._cfg = info._cfg
+        info_json_fmt._scheduler = info._scheduler
 
-    return _run
+        if config.dryrun:
+            return info_json_fmt
+        else:
+            server_handle = runner.schedule(info)
+            return server_handle
 
 
 def info(server_handle: str) -> Optional[ServerSpec]:
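Net effect: `create()` is no longer a factory that wraps a `component_fn`; the job definition travels on `config.appdef` and the job name becomes a plain argument, with the named `AppDef` assembled inside the runner context. A minimal sketch of the new calling convention, following the updated docstring (the job name here is illustrative):

    from monarch.tools.commands import create
    from monarch.tools.config import defaults

    config = defaults.config(scheduler="slurm")
    # component functions now return an UnnamedAppDef; create() attaches the name
    config.appdef = defaults.component_fn(scheduler=config.scheduler)()

    # returns a server handle (str), or an AppDryRunInfo when config.dryrun is True
    server_handle = create(config, name="my-monarch-job")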
@@ -180,17 +174,27 @@ def info(server_handle: str) -> Optional[ServerSpec]:
         # null-guard since some schedulers do not fill replica_status
         if host_status := replica_status.get(role.name):
             spec.hostnames = [h.hostname for h in host_status]
+            # the mesh status is based on the "least progressive" replica status
+            spec.state = min(h.state for h in host_status)
 
         mesh_specs.append(spec)
 
-    return ServerSpec(name=appdef.name, state=status.state, meshes=mesh_specs)
+    scheduler, namespace, _ = parse_app_handle(server_handle)
+    return ServerSpec(
+        name=appdef.name,
+        state=status.state,
+        meshes=mesh_specs,
+        scheduler=scheduler,
+        namespace=namespace,
+    )
 
 
 _5_SECONDS = timedelta(seconds=5)
 
 
 async def server_ready(
-    server_handle: str, check_interval: timedelta = _5_SECONDS
+    server_handle: str,
+    check_interval: timedelta = _5_SECONDS,
 ) -> Optional[ServerSpec]:
     """Waits until the server's job is in RUNNING state to returns the server spec.
     Returns `None` if the server does not exist.
@@ -213,6 +217,8 @@ async def server_ready(
 
     """
 
+    check_interval_seconds = check_interval.total_seconds()
+    start = datetime.now()
     while True:
         server_spec = info(server_handle)
 
@@ -222,18 +228,94 @@ async def server_ready(
         if server_spec.state <= AppState.PENDING:  # UNSUBMITTED or SUBMITTED or PENDING
             # NOTE: TorchX currently does not have async APIs so need to loop-on-interval
             # TODO maybe inverse exponential backoff instead of constant interval?
-            check_interval_seconds = check_interval.total_seconds()
-            logger.info(
-                "waiting for %s to be %s (current: %s), will check again in %g seconds...",
-                server_handle,
-                AppState.RUNNING,
-                server_spec.state,
-                check_interval_seconds,
+            print(
+                f"Waiting for {server_handle} to be {AppState.RUNNING} (current: {server_spec.state}); "
+                f"will check again in {check_interval_seconds} seconds. "
+                f"Total wait time: {datetime.now() - start}",
+                end="\r",
             )
-            time.sleep(check_interval_seconds)
+            await asyncio.sleep(check_interval_seconds)
             continue
-        else:
-            return server_spec
+
+        # check if hosts are allocated for all the meshes
+        if server_spec.state == AppState.RUNNING:
+            running = True
+            for mesh_spec in server_spec.meshes:
+                if mesh_spec.state <= AppState.PENDING:
+                    print(
+                        f"Job {server_handle} is running but waiting for mesh {mesh_spec.name} "
+                        f"to be {AppState.RUNNING} (current: {mesh_spec.state}); "
+                        f"will check again in {check_interval_seconds} seconds. "
+                        f"Total wait time: {datetime.now() - start}",
+                        end="\r",
+                    )
+                    running = False
+                    break
+            if not running:
+                await asyncio.sleep(check_interval_seconds)
+                continue
+
+        return server_spec
+
+
+# TODO: this API is overloaded. Ideally, we do not need config to get or a handle to create.
+async def get_or_create(
+    name: str,
+    config: Config,
+    check_interval: timedelta = _5_SECONDS,
+) -> ServerSpec:
+    """Waits for the server based on identity `name` in the scheduler specified in the `config`
+    to be ready (e.g. RUNNING). If the server is not found then this function creates one
+    per the `config` spec, and waits for the server to be ready before returning.
+
+    Usage:
+
+    .. code-block:: python
+
+        from monarch.tools.config import defaults
+
+        config = defaults.config(scheduler)
+        config.appdef = defaults.component_fn(config.scheduler)()
+
+        server_handle = get_or_create("my_job_name", config)
+        server_info = info(server_handle)
+
+    Returns: A `ServerSpec` containing information about either the existing or the newly
+    created server.
+
+    """
+    assert not config.dryrun, "dryrun is not supported for get_or_create(), for dryrun use the create() API instead"
+
+    server_handle = f"{config.scheduler}:///{name}"
+    server_info = await server_ready(server_handle, check_interval)
+
+    if not server_info or not server_info.is_running:  # then create one
+        logger.info(
+            "no existing RUNNING server `%s` creating new one...", server_handle
+        )
+
+        # no dryrun (see assertion above) support so will always be a handle (str)
+        new_server_handle = str(create(config, name))
+
+        logger.info(f"created new `{new_server_handle}` waiting for it to be ready...")
+
+        server_info = await server_ready(new_server_handle, check_interval)
+
+        if not server_info:
+            raise RuntimeError(
+                f"the new server `{new_server_handle}` went missing (should never happen)"
+            )
+
+        if not server_info.is_running:
+            raise RuntimeError(
+                f"the new server `{new_server_handle}` has {server_info.state}"
+            )
+
+        print(f"\x1b[36mNew job `{new_server_handle}` is ready to serve. \x1b[0m")
+        return server_info
+    else:
+        print(f"\x1b[36mFound existing job `{server_handle}` ready to serve. \x1b[0m")
+        return server_info
 
 
 def kill(server_handle: str) -> None:
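`get_or_create()` stitches these pieces together: it first polls `server_ready()` (which now also waits for every mesh to leave PENDING), and only calls `create()` when no RUNNING server exists under `{scheduler}:///{name}`. A hedged usage sketch, reusing the config from the `create()` example above (names are illustrative):

    import asyncio

    from monarch.tools.commands import get_or_create
    from monarch.tools.config import defaults

    async def main() -> None:
        config = defaults.config(scheduler="slurm")
        config.appdef = defaults.component_fn(scheduler=config.scheduler)()

        # reuses a RUNNING server named "my_job_name", or creates one per `config`
        server_info = await get_or_create("my_job_name", config)
        assert server_info.is_running
        print(server_info.server_handle)  # e.g. "slurm:///my_job_name"

    asyncio.run(main())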
monarch/tools/components/hyperactor.py CHANGED
@@ -9,6 +9,7 @@ import getpass
 from typing import Optional
 
 from monarch.tools import mesh_spec
+from monarch.tools.config import UnnamedAppDef
 from monarch.tools.mesh_spec import mesh_spec_from_str
 from torchx import specs
 
@@ -16,17 +17,18 @@ _DEFAULT_MESHES = ["mesh_0:1:gpu.small"]
 
 _USER: str = getpass.getuser()
 
+DEFAULT_NAME: str = f"monarch-{_USER}"
+
 __version__ = "latest"  # TODO get version from monarch.__version_
 
 
 def proc_mesh(
-    name: str = f"monarch-{_USER}",
     image: str = f"ghcr.io/pytorch-labs/monarch:{__version__}",  # TODO docker needs to be built and pushed to ghcr
     meshes: list[str] = _DEFAULT_MESHES,
     env: Optional[dict[str, str]] = None,
     port: int = mesh_spec.DEFAULT_REMOTE_ALLOCATOR_PORT,
     program: str = "monarch_bootstrap",  # installed with monarch wheel (as console script)
-) -> specs.AppDef:
+) -> UnnamedAppDef:
     """
     Args:
         name: the name of the monarch server job
@@ -37,7 +39,7 @@
         program: path to the binary that the remote process allocator spawns on an allocation request
     """
 
-    appdef = specs.AppDef(name)
+    appdef = UnnamedAppDef()
 
     for mesh in [mesh_spec_from_str(mesh) for mesh in meshes]:
         mesh_role = specs.Role(
monarch/tools/config/__init__.py CHANGED
@@ -6,15 +6,32 @@
 
 # pyre-strict
 from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import Any, Dict, List, Optional
+
+from torchx.specs import Role
 
 
 NOT_SET: str = "__NOT_SET__"
 
 
+@dataclass
+class UnnamedAppDef:
+    """
+    A TorchX AppDef without a name.
+    """
+
+    roles: List[Role] = field(default_factory=list)
+    metadata: Dict[str, str] = field(default_factory=dict)
+
+
 @dataclass
 class Config:
+    """
+    All configs needed to schedule a mesh of allocators.
+    """
+
     scheduler: str = NOT_SET
     scheduler_args: dict[str, Any] = field(default_factory=dict)
     workspace: Optional[str] = None
     dryrun: bool = False
+    appdef: UnnamedAppDef = UnnamedAppDef()
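`UnnamedAppDef` carries the roles and metadata of a TorchX `AppDef` while deferring the job name to `create()`. A minimal sketch of how the two dataclasses compose (field values are illustrative); note that passing an explicit `appdef`, as here, sidesteps the pitfall of sharing the class-level default instance across `Config` objects:

    from monarch.tools.config import Config, UnnamedAppDef

    appdef = UnnamedAppDef()  # roles=[], metadata={}
    config = Config(scheduler="slurm", appdef=appdef)
    config.scheduler_args["partition"] = "prod"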
monarch/tools/config/defaults.py CHANGED
@@ -11,7 +11,7 @@
 from typing import Callable, Optional
 
 from monarch.tools.components import hyperactor
-from monarch.tools.config import Config
+from monarch.tools.config import Config, UnnamedAppDef
 
 from torchx import specs
 from torchx.schedulers import (
@@ -23,7 +23,7 @@ from torchx.schedulers import (
 )
 
 
-def component_fn(scheduler: str) -> Callable[..., specs.AppDef]:
+def component_fn(scheduler: str) -> Callable[..., UnnamedAppDef]:
     """The default TorchX component function for the scheduler"""
     return hyperactor.proc_mesh
 
monarch/tools/mesh_spec.py CHANGED
@@ -9,8 +9,11 @@ import string
 from dataclasses import dataclass, field
 from typing import Any, Optional
 
+from monarch.tools.config import UnnamedAppDef
+
 from monarch.tools.network import get_sockaddr
 from torchx import specs
+from torchx.specs.api import is_terminal
 
 DEFAULT_REMOTE_ALLOCATOR_PORT = 26600
 
@@ -38,6 +41,7 @@ class MeshSpec:
     transport: str = "tcp"
     port: int = DEFAULT_REMOTE_ALLOCATOR_PORT
     hostnames: list[str] = field(default_factory=list)
+    state: specs.AppState = specs.AppState.UNSUBMITTED
 
     def server_addrs(
         self, transport: Optional[str] = None, port: Optional[int] = None
@@ -68,7 +72,7 @@ def _tag(mesh_name: str, tag_template: str) -> str:
     return string.Template(tag_template).substitute(mesh_name=mesh_name)
 
 
-def tag_as_metadata(mesh_spec: MeshSpec, appdef: specs.AppDef) -> None:
+def tag_as_metadata(mesh_spec: MeshSpec, appdef: UnnamedAppDef) -> None:
     appdef.metadata[_tag(mesh_spec.name, _TAG_HOST_TYPE)] = mesh_spec.host_type
     appdef.metadata[_tag(mesh_spec.name, _TAG_GPUS)] = str(mesh_spec.gpus)
     appdef.metadata[_tag(mesh_spec.name, _TAG_TRANSPORT)] = mesh_spec.transport
@@ -122,11 +126,64 @@ class ServerSpec:
     name: str
     state: specs.AppState
     meshes: list[MeshSpec]
+    scheduler: str
+    namespace: str = ""
+
+    @property
+    def server_handle(self) -> str:
+        return f"{self.scheduler}://{self.namespace}/{self.name}"
 
     @property
     def is_running(self) -> bool:
         return self.state == specs.AppState.RUNNING
 
+    def host0(self, mesh_name: str) -> str:
+        """The hostname of the first node in the given mesh.
+        The return value of this method can be used to set `MASTER_ADDR` env var for torch.distributed.
+
+        NOTE: the state of this server must be RUNNING for this method to return a valid value.
+
+        Usage:
+
+        .. code-block:: python
+
+            from monarch.tools.commands import get_or_create
+
+            server_info = await get_or_create(...)
+            assert server_info.is_running
+
+            # allocate proc mesh -> create actor (code omitted for brevity)...
+
+            trainer_actor.call(
+                MASTER_ADDR=server_info.host0("trainer"),  # trainer mesh's 1st host
+                MASTER_PORT=29500,
+                ...
+            )
+
+        NOTE: The ordering of the hostnames is exactly the same as what comes back from the underlying
+        scheduler's `describe_job` or `list_*` API. Please find the exact semantics in the
+        respective scheduler's implementation in https://github.com/pytorch/torchx/tree/main/torchx/schedulers.
+        """
+        mesh_spec = self.get_mesh_spec(mesh_name)
+        if self.is_running:
+            # hostnames are only valid when the server is RUNNING
+            if not mesh_spec.hostnames:
+                raise RuntimeError(f"{self.server_handle} does not have any hosts")
+            return mesh_spec.hostnames[0]
+        elif self.state in [specs.AppState.SUBMITTED, specs.AppState.PENDING]:
+            raise RuntimeError(
+                f"{self.server_handle} is {self.state}."
+                f" Use `monarch.tools.commands.server_ready()` to wait for the server to be {specs.AppState.RUNNING}"
+            )
+        elif is_terminal(self.state):
+            raise RuntimeError(
+                f"{self.server_handle} is {self.state}."
+                " Use `monarch.tools.commands.get_or_create()` to create a new server"
+            )
+        else:
+            raise RuntimeError(
+                f"{self.server_handle} is in an invalid state: {self.state}. Please report this as a bug"
+            )
+
     def get_mesh_spec(self, mesh_name: str) -> MeshSpec:
         for mesh_spec in self.meshes:
             if mesh_spec.name == mesh_name:
@@ -152,6 +209,7 @@ class ServerSpec:
 
         return {
             "name": self.name,
+            "server_handle": self.server_handle,
            "state": self.state.name,
            "meshes": {
                mesh.name: {
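`ServerSpec` can now reconstruct its own handle (`scheduler://namespace/name`), and `host0()` guards hostname access behind the server state. A hedged sketch of the torch.distributed wiring the `host0` docstring describes, assuming a mesh named "trainer" and an illustrative port:

    from monarch.tools.commands import get_or_create

    async def wire_master(config) -> dict[str, str]:
        server_info = await get_or_create("my_job_name", config)
        assert server_info.is_running
        return {
            "MASTER_ADDR": server_info.host0("trainer"),  # first host of the "trainer" mesh
            "MASTER_PORT": "29500",
        }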
monarch/tools/utils.py ADDED
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import os
+from typing import Optional
+
+
+class conda:
+    """Conda related util functions."""
+
+    @staticmethod
+    def active_env_dir() -> Optional[str]:
+        """
+        Returns the currently active conda environment's directory.
+        `None` if run outside of a conda environment.
+        """
+        return os.getenv("CONDA_PREFIX")
+
+    @staticmethod
+    def active_env_name() -> Optional[str]:
+        """
+        Returns the currently active conda environment name.
+        `None` if run outside of a conda environment.
+        """
+        env_name = os.getenv("CONDA_DEFAULT_ENV")
+
+        if not env_name:
+            # conda envs activated with metaconda don't set CONDA_DEFAULT_ENV so
+            # fall back to CONDA_PREFIX, which points to the path of the currently active conda environment
+            # e.g. /home/$USER/.conda/envs/{env_name}
+            if env_dir := conda.active_env_dir():
+                env_name = os.path.basename(env_dir)
+
+        return env_name
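These helpers only read environment variables that conda itself sets, so they are safe to call outside any conda environment. A quick usage sketch (output depends on the active environment):

    from monarch.tools.utils import conda

    if (env_name := conda.active_env_name()) is not None:
        print(f"running in conda env: {env_name}")
    else:
        print("not inside a conda environment")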
monarch/worker/worker.py CHANGED
@@ -37,13 +37,13 @@ import torch.distributed
 import torch.fx
 import zmq
 import zmq.asyncio
+from monarch._src.actor.shape import NDSlice
 
 from monarch.common import messages
 from monarch.common.function import ResolvableFunction
 from monarch.common.messages import DependentOnError, Dims
 from monarch.common.process_group import SingleControllerProcessGroupWrapper
 from monarch.common.reference import Ref, Referenceable
-from monarch.common.shape import NDSlice
 from monarch.common.tensor_factory import TensorFactory
 from monarch.common.tree import flatten, flattener
 from monarch_supervisor import get_message_queue, Letter
monarch/world_mesh.py CHANGED
@@ -8,10 +8,11 @@
 
 from typing import List
 
+from monarch._src.actor.shape import NDSlice
+
 from monarch.common.client import Client
 
 from monarch.common.device_mesh import DeviceMesh
-from monarch.common.shape import NDSlice
 
 from monarch.controller.backend import ProcessBackend
 
monarch_supervisor/python_executable.py CHANGED
@@ -11,7 +11,10 @@ import sys
 try:
     from __manifest__ import fbmake  # noqa
 
-    IN_PAR = True
+    # simply checking for the existence of __manifest__ is not enough to tell if we are in a PAR
+    # because monarch wheels include a dummy __manifest__ (see fbcode//monarch/python/monarch/session/meta/__manifest__.py)
+    # so that we can use libfb programmatically. Hence additionally check if the `par_style` key is not null/empty
+    IN_PAR = bool(fbmake.get("par_style"))
 except ImportError:
     IN_PAR = False
 
@@ -26,8 +29,8 @@ if IN_PAR:
     PYTHON_EXECUTABLE = os.environ["FB_XAR_INVOKED_NAME"]
 else:
     try:
-        with importlib.resources.path(
-            "monarch_tensor_worker_env", "worker_env"
+        with importlib.resources.as_file(
+            importlib.resources.files("monarch_tensor_worker_env") / "worker_env"
         ) as path:
             if not path.exists():
                 raise ImportError()
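The second hunk migrates from `importlib.resources.path()`, deprecated since Python 3.11, to the `files()` / `as_file()` pair, which also handles resources packaged inside zips or wheels. The general stdlib pattern, sketched against a hypothetical package:

    import importlib.resources

    resource = importlib.resources.files("some_package") / "some_resource"
    with importlib.resources.as_file(resource) as path:
        if path.exists():
            print(f"resource available at {path}")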
tests/error_test_binary.py CHANGED
@@ -13,8 +13,7 @@ from monarch._rust_bindings.monarch_extension.blocking import blocking_function
 
 from monarch._rust_bindings.monarch_extension.panic import panicking_function
 
-from monarch.actor_mesh import Actor, endpoint, send
-from monarch.proc_mesh import proc_mesh
+from monarch.actor import Actor, endpoint, proc_mesh, send
 
 
 class ErrorActor(Actor):
@@ -48,6 +47,13 @@ class ErrorActor(Actor):
         await asyncio.sleep(0.1)
         raise RuntimeError("oh noez")
 
+    @endpoint
+    async def get_pid(self) -> int:
+        """Endpoint that returns the process PID."""
+        import os
+
+        return os.getpid()
+
 
 class ErrorActorSync(Actor):
     """An actor that has endpoints cause segfaults."""
@@ -79,8 +85,7 @@ def _run_error_test_sync(num_procs, sync_endpoint, endpoint_name):
     error_actor = proc.spawn("error_actor", actor_class).get()
 
     # This output is checked in the test to make sure that the process actually got here
-    print("I actually ran")
-    sys.stdout.flush()
+    print("Started function error_test", flush=True)
 
     if endpoint_name == "cause_segfault":
         endpoint = error_actor.cause_segfault
@@ -110,8 +115,7 @@ def _run_error_test(num_procs, sync_endpoint, endpoint_name):
     error_actor = await proc.spawn("error_actor", actor_class)
 
     # This output is checked in the test to make sure that the process actually got here
-    print("I actually ran")
-    sys.stdout.flush()
+    print("Started function error_test", flush=True)
 
     if endpoint_name == "cause_segfault":
         endpoint = error_actor.cause_segfault
@@ -153,15 +157,13 @@ def error_endpoint(num_procs, sync_test_impl, sync_endpoint, endpoint_name):
 
 @main.command("error-bootstrap")
 def error_bootstrap():
-    print("I actually ran")
-    sys.stdout.flush()
+    print("Started function error_bootstrap", flush=True)
 
     proc_mesh(gpus=4, env={"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING": "1"}).get()
 
 
 async def _error_unmonitored():
-    print("I actually ran")
-    sys.stdout.flush()
+    print("Started function _error_unmonitored", flush=True)
 
     proc = await proc_mesh(gpus=1)
     actor = await proc.spawn("error_actor", ErrorActor)
@@ -204,5 +206,41 @@ def error_unmonitored():
     asyncio.run(_error_unmonitored())
 
 
+async def _error_cleanup():
+    """Test function that spawns an 8-process procmesh and calls an endpoint that raises a normal exception."""
+    print("Started function _error_cleanup() for parent process", flush=True)
+
+    # Spawn an 8 process procmesh
+    proc = await proc_mesh(gpus=8)
+    error_actor = await proc.spawn("error_actor", ErrorActor)
+
+    print("Procmesh spawned, collecting child PIDs from actors", flush=True)
+
+    # Get PIDs from all actor processes
+    try:
+        # Call get_pid endpoint on all actors to collect their PIDs
+        pids = await error_actor.get_pid.call()
+        child_pids = [str(pid) for _, pid in pids]
+        print(f"CHILD_PIDS: {','.join(child_pids)}", flush=True)
+    except Exception as e:
+        print(f"Error getting child PIDs from actors: {e}", flush=True)
+
+    print("About to call endpoint that raises exception", flush=True)
+
+    # Call an endpoint that raises a normal exception
+    try:
+        await error_actor.await_then_error.call()
+    except Exception as e:
+        print(f"Expected exception caught: {e}", flush=True)
+        # Re-raise to cause the process to exit with non-zero code
+        raise
+
+
+@main.command("error-cleanup")
+def error_cleanup():
+    """Command that spawns an 8-process procmesh and calls an endpoint that raises a normal exception."""
+    asyncio.run(_error_cleanup())
+
+
 if __name__ == "__main__":
     main()
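The new `error-cleanup` command is driven from the test suite as a subprocess: the parent prints its actors' PIDs (the `CHILD_PIDS:` line) and then exits non-zero by re-raising, letting the test verify that all child processes were cleaned up. A hedged sketch of such a driver, assuming the script is invoked directly by path:

    import subprocess

    result = subprocess.run(
        ["python", "tests/error_test_binary.py", "error-cleanup"],
        capture_output=True,
        text=True,
    )
    assert result.returncode != 0  # the re-raised exception fails the process
    pids_line = next(
        line for line in result.stdout.splitlines() if line.startswith("CHILD_PIDS:")
    )
    child_pids = pids_line.removeprefix("CHILD_PIDS: ").split(",")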