torchmonarch-nightly 2025.8.2__cp310-cp310-manylinux2014_x86_64.whl → 2025.9.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. monarch/_rust_bindings.so +0 -0
  2. monarch/_src/actor/actor_mesh.py +504 -218
  3. monarch/_src/actor/allocator.py +75 -6
  4. monarch/_src/actor/bootstrap_main.py +7 -4
  5. monarch/_src/actor/code_sync/__init__.py +2 -0
  6. monarch/_src/actor/debugger/__init__.py +7 -0
  7. monarch/_src/actor/{debugger.py → debugger/debugger.py} +246 -135
  8. monarch/_src/actor/{pdb_wrapper.py → debugger/pdb_wrapper.py} +62 -23
  9. monarch/_src/actor/endpoint.py +27 -45
  10. monarch/_src/actor/future.py +86 -24
  11. monarch/_src/actor/host_mesh.py +125 -0
  12. monarch/_src/actor/logging.py +94 -0
  13. monarch/_src/actor/pickle.py +25 -0
  14. monarch/_src/actor/proc_mesh.py +423 -156
  15. monarch/_src/actor/python_extension_methods.py +90 -0
  16. monarch/_src/actor/shape.py +8 -1
  17. monarch/_src/actor/source_loader.py +45 -0
  18. monarch/_src/actor/telemetry/__init__.py +172 -0
  19. monarch/_src/actor/telemetry/rust_span_tracing.py +6 -39
  20. monarch/_src/debug_cli/__init__.py +7 -0
  21. monarch/_src/debug_cli/debug_cli.py +43 -0
  22. monarch/_src/tensor_engine/rdma.py +64 -9
  23. monarch/_testing.py +1 -3
  24. monarch/actor/__init__.py +24 -4
  25. monarch/common/_C.so +0 -0
  26. monarch/common/device_mesh.py +14 -0
  27. monarch/common/future.py +10 -0
  28. monarch/common/remote.py +14 -25
  29. monarch/common/tensor.py +12 -0
  30. monarch/debug_cli/__init__.py +7 -0
  31. monarch/debug_cli/__main__.py +12 -0
  32. monarch/fetch.py +2 -2
  33. monarch/gradient/_gradient_generator.so +0 -0
  34. monarch/gradient_generator.py +4 -2
  35. monarch/mesh_controller.py +34 -14
  36. monarch/monarch_controller +0 -0
  37. monarch/tools/colors.py +25 -0
  38. monarch/tools/commands.py +42 -7
  39. monarch/tools/components/hyperactor.py +6 -4
  40. monarch/tools/config/__init__.py +35 -12
  41. monarch/tools/config/defaults.py +15 -5
  42. monarch/tools/config/environment.py +45 -0
  43. monarch/tools/config/workspace.py +165 -0
  44. monarch/tools/mesh_spec.py +3 -3
  45. monarch/utils/__init__.py +9 -0
  46. monarch/utils/utils.py +78 -0
  47. tests/error_test_binary.py +5 -3
  48. tests/python_actor_test_binary.py +52 -0
  49. tests/test_actor_error.py +142 -14
  50. tests/test_alloc.py +1 -1
  51. tests/test_allocator.py +59 -72
  52. tests/test_debugger.py +639 -45
  53. tests/test_env_before_cuda.py +4 -4
  54. tests/test_mesh_trait.py +38 -0
  55. tests/test_python_actors.py +965 -75
  56. tests/test_rdma.py +7 -6
  57. tests/test_tensor_engine.py +6 -6
  58. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/METADATA +82 -4
  59. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/RECORD +63 -47
  60. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/WHEEL +0 -0
  61. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/entry_points.txt +0 -0
  62. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/licenses/LICENSE +0 -0
  63. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/top_level.txt +0 -0
monarch/common/device_mesh.py CHANGED
@@ -154,6 +154,20 @@ class DeviceMeshInfo:
 
 
 class DeviceMesh(Referenceable, MeshTrait):
+    """A mesh of devices for distributed tensor operations.
+
+    DeviceMesh represents a collection of devices arranged in a
+    multidimensional grid for parallel computation. It manages
+    communication between devices and enables distributed execution
+    of operations across the mesh.
+
+    Args:
+        client (Client): The client connection to the mesh infrastructure
+        processes (NDSlice): Multi-dimensional slice representing the process layout
+        names (Dims): Names for each dimension of the mesh
+        mesh_name (str, optional): Name identifier for the mesh. Default: "default"
+    """
+
     def __init__(
         self,
         client: "Client",
monarch/common/future.py CHANGED
@@ -68,6 +68,16 @@ T = TypeVar("T")
68
68
 
69
69
 
70
70
  class Future(Generic[T]):
71
+ """A future object representing the result of an asynchronous computation.
72
+
73
+ Future provides a way to access the result of a computation that may not
74
+ have completed yet. It allows for non-blocking execution and provides
75
+ methods to wait for completion and retrieve results.
76
+
77
+ Args:
78
+ client (Client): The client connection for handling the future
79
+ """
80
+
71
81
  def __init__(self, client: "Client"):
72
82
  self._client = client
73
83
  self._status = "incomplete"
monarch/common/remote.py CHANGED
@@ -28,10 +28,10 @@ from typing import (
 import monarch.common.messages as messages
 
 import torch
-from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
-from monarch._rust_bindings.monarch_hyperactor.shape import Shape
-from monarch._src.actor.actor_mesh import Port, PortTuple
-from monarch._src.actor.endpoint import Extent, Selection
+from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape
+from monarch._src.actor.actor_mesh import Port
+from monarch._src.actor.endpoint import Selection
+from monarch._src.actor.future import Future
 
 from monarch.common import _coalescing, device_mesh, stream
 from monarch.common.future import Future as OldFuture
@@ -135,20 +135,6 @@ class Remote(Generic[P, R], Endpoint[P, R]):
         client._request_status()
         return Extent(ambient_mesh._labels, ambient_mesh._ndslice.sizes)
 
-    def _port(self, once: bool = False) -> "PortTuple[R]":
-        ambient_mesh = device_mesh._active
-        if ambient_mesh is None:
-            raise ValueError(
-                "FIXME - cannot create a port without an active proc_mesh, because there is not way to create a port without a mailbox"
-            )
-        mesh_controller = getattr(ambient_mesh.client, "_mesh_controller", None)
-        if mesh_controller is None:
-            raise ValueError(
-                "Cannot create raw port objects with an old-style tensor engine controller."
-            )
-        mailbox: Mailbox = mesh_controller._mailbox
-        return PortTuple.create(mailbox, once)
-
     @property
     def _resolvable(self):
         return resolvable_function(self._remote_impl)
@@ -212,7 +198,7 @@ remote_identity = Remote(None, lambda x: x)
 
 def call_on_shard_and_fetch(
     remote: Endpoint[P, R], *args, shard: Dict[str, int] | None = None, **kwargs
-) -> OldFuture[R]:
+) -> Future[R]:
     # We have to flatten the tensors twice: first to discover
     # which mesh we are working on to shard it, and then again when doing the
     # dtensor_check in send. This complexity is a consequence of doing
@@ -224,17 +210,20 @@ def call_on_shard_and_fetch(
     checker.check_mesh_stream_local(device_mesh._active, stream._active)
 
     if not hasattr(checker.mesh.client, "_mesh_controller"):
-        return _old_call_on_shard_and_fetch(
-            cast("Remote[P, R]", remote),
-            *args,
-            shard=shard,
-            **kwargs,
+        return cast(
+            "Future[R]",
+            _old_call_on_shard_and_fetch(
+                cast("Remote[P, R]", remote),
+                *args,
+                shard=shard,
+                **kwargs,
+            ),
         )
 
     selected_slice = checker.mesh._process(shard)
     shard_mesh = checker.mesh._new_with_shape(Shape(["_"], selected_slice))
     with shard_mesh.activate():
-        return cast("OldFuture[R]", remote.call_one(*args, **kwargs))
+        return remote.call_one(*args, **kwargs)
 
 
 def _old_call_on_shard_and_fetch(
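A hedged usage sketch of the updated call_on_shard_and_fetch: it now returns the actor-style Future, whose blocking accessor used elsewhere in this diff is get(). Here t is assumed to be an existing monarch Tensor on the active mesh, and the dimension names in the shard dict are illustrative.

from monarch.common.remote import call_on_shard_and_fetch, remote_identity

fut = call_on_shard_and_fetch(remote_identity, t, shard={"hosts": 0, "gpus": 0})
local_value = fut.get()     # blocks until the shard's value has been fetched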
monarch/common/tensor.py CHANGED
@@ -74,6 +74,18 @@ class DropLocation(NamedTuple):
 
 
 class Tensor(Referenceable, BaseTensor):
+    """A distributed tensor for distributed computation across device meshes.
+
+    Tensor represents a distributed tensor that spans across multiple devices
+    in a device mesh. It provides the same interface as PyTorch tensors but
+    enables distributed operations and communication patterns.
+
+    Args:
+        fake (torch.Tensor): A fake tensor representing the shape and type
+        mesh (DeviceMesh): The device mesh this tensor is distributed across
+        stream (Stream): The computation stream for this tensor
+    """
+
     # pyre-fixme[13]: Attribute `stream` is never initialized.
     stream: Stream
     # pyre-fixme[13]: Attribute `mesh` is never initialized.
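As a hedged illustration of the "same interface as PyTorch tensors" claim in the docstring above (assumes a DeviceMesh dm is already active, e.g. via spawn_tensor_engine from the mesh_controller.py hunks below):

import torch

with dm.activate():
    x = torch.rand(4, 4)    # produces a distributed monarch.common.tensor.Tensor, not a plain torch.Tensor
    y = (x @ x).sum()       # ordinary torch operators; execution happens on every device in the mesh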
monarch/debug_cli/__init__.py ADDED
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
monarch/debug_cli/__main__.py ADDED
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from monarch._src.debug_cli import debug_cli
+
+
+if __name__ == "__main__":
+    debug_cli.run()
monarch/fetch.py CHANGED
@@ -11,9 +11,9 @@ This is a utility file for fetching a shard of a tensor from remote.
 
 from typing import cast, TypeVar
 
-from monarch.common.device_mesh import no_mesh
+from monarch.actor import Future
 
-from monarch.common.future import Future
+from monarch.common.device_mesh import no_mesh
 
 from monarch.common.remote import call_on_shard_and_fetch, remote_identity
 
monarch/gradient/_gradient_generator.so CHANGED
Binary file
monarch/gradient_generator.py CHANGED
@@ -151,14 +151,16 @@ def grad_function(fn):
 
 
 def gradient_execution_order(
-    roots: Sequence[TensorOrEdge], with_respect_to: Sequence[TensorOrEdge]
+    roots: Sequence[TensorOrEdge], with_respect_to: Sequence[Any]
 ) -> List[int]:
     """
     Returns the order in which the gradients for `with_respect_to` would become available
     if autograd were run on `roots`. This is the reverse order of each tensors
     first use in the gradient computation.
     """
-    with_respect_to = [_gradient_edge(g) for g in with_respect_to]
+    with_respect_to = [
+        (g.node, g.output_nr) for g in map(_gradient_edge, with_respect_to)
+    ]
    min_sequence_nr: Dict[Any, float] = {e: math.inf for e in with_respect_to}
 
    to_scan = [_gradient_edge(r).node for r in roots]
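A small, hedged example of calling gradient_execution_order; plain local autograd tensors are used purely for illustration, since the function only walks the autograd graph.

import torch
from monarch.gradient_generator import gradient_execution_order

a = torch.rand(3, requires_grad=True)
b = torch.rand(3, requires_grad=True)
loss = (a * (b * 2)).sum()

# indices into [a, b], ordered by when each gradient would become available
order = gradient_execution_order([loss], [a, b])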
monarch/mesh_controller.py CHANGED
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import atexit
 import logging
 import os
@@ -43,7 +45,7 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarc
     ActorId,
 )
 from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
-from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple
+from monarch._src.actor.actor_mesh import ActorEndpoint, Channel, Port
 from monarch._src.actor.endpoint import Selection
 from monarch._src.actor.shape import NDSlice
 from monarch.common import device_mesh, messages, stream
@@ -63,6 +65,7 @@ if TYPE_CHECKING:
     from monarch.actor import ProcMesh
 
 from monarch._rust_bindings.monarch_hyperactor.shape import Point
+from monarch._src.actor.device_utils import _local_device_count
 
 from monarch.common.client import Client
 from monarch.common.controller_api import LogMessage, MessageResult
@@ -119,9 +122,18 @@ def _initialize_env(worker_point: Point, proc_id: str) -> None:
     worker_rank = worker_point.rank
     try:
         _, worker_env = _get_worker_exec_info()
-        local_rank = worker_point["gpus"]
-        gpus_per_host = worker_point.size("gpus")
-        num_worker_procs = len(worker_point.shape)
+
+        if "gpus" in worker_point:
+            local_rank = worker_point["gpus"]
+            gpus_per_host = worker_point.size("gpus")
+        elif "gpu" in worker_point:
+            local_rank = worker_point["gpu"]
+            gpus_per_host = worker_point.size("gpu")
+        else:
+            gpus_per_host = _local_device_count()
+            local_rank = worker_rank % gpus_per_host
+
+        num_worker_procs = worker_point.extent.nelements
         process_env = {
            **worker_env,
            "CUDA_VISIBLE_DEVICES": str(local_rank),
@@ -156,7 +168,7 @@ class MeshClient(Client):
         defs: Tuple["Tensor", ...],
         uses: Tuple["Tensor", ...],
     ) -> "OldFuture":  # the OldFuture is a lie
-        sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
+        sender, receiver = Channel.open(once=True)
 
         ident = self.new_node(defs, uses, cast("OldFuture", sender))
         process = mesh._process(shard)
@@ -192,7 +204,7 @@ class MeshClient(Client):
         atexit.unregister(self._atexit)
         self._shutdown = True
 
-        sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
+        sender, receiver = Channel.open(once=True)
         assert sender._port_ref is not None
         self._mesh_controller.sync_at_exit(sender._port_ref.port_id)
         receiver.recv().get(timeout=60)
@@ -200,6 +212,14 @@ class MeshClient(Client):
         # waited for the responses
         self.inner.drain_and_stop()
 
+    def _atexit(self) -> None:
+        # Calling self.shutdown may cause a deadlock if something is wrong with
+        # the networking. Or should we make shutdown() not wait indefinitely?
+        self._shutdown = True
+
+        # send shutdown message to stop other processes.
+        self.inner.stop_mesh()
+
     @property
     def _mesh_controller(self) -> Controller:
         return cast(Controller, self.inner)
@@ -235,7 +255,9 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     # is currently only used for debug printing. It should be fixed to
     # report the proc ID instead of the rank it currently does.
     gpus = proc_mesh.sizes.get("gpus", 1)
-    backend_ctrl = Controller(proc_mesh._proc_mesh)
+
+    # we currently block on the creation of the proc mesh, but conceivably we could init concurrently here.
+    backend_ctrl = Controller(proc_mesh._proc_mesh.block_on())
     client = MeshClient(cast("TController", backend_ctrl), proc_mesh.size(), gpus)
     dm = DeviceMesh(
         client,
@@ -273,7 +295,7 @@ class RemoteException(Exception):
 
 def _cast_call_method_indirect(
     endpoint: ActorEndpoint,
-    selection: Selection,
+    selection: str,
     client: MeshClient,
     seq: Seq,
     args_kwargs_tuple: bytes,
@@ -290,7 +312,7 @@ def _cast_call_method_indirect(
         ),
         args_kwargs_tuple,
     )
-    endpoint._actor_mesh.cast(actor_msg, selection)
+    endpoint._actor_mesh.cast(actor_msg, selection, endpoint._mailbox)
     return broker_id
 
 
@@ -299,7 +321,7 @@ def actor_send(
     args_kwargs_tuple: bytes,
     refs: Sequence[Any],
     port: Optional[Port[Any]],
-    selection: Selection,
+    selection: str,
 ):
     tensors = [ref for ref in refs if isinstance(ref, Tensor)]
     # we have some monarch references, we need to ensure their
@@ -314,9 +336,7 @@ def actor_send(
     # TODO: move propagators into Endpoint abstraction and run the propagator to get the
    # mutates
    checker.check_permission(())
-    selected_device_mesh = (
-        endpoint._actor_mesh._proc_mesh and endpoint._actor_mesh._proc_mesh._device_mesh
-    )
+    selected_device_mesh = endpoint._proc_mesh and endpoint._proc_mesh._device_mesh
    if selected_device_mesh is not checker.mesh:
        raise ValueError(
            f"monarch Tensors sent to an actor must be located on the same process as the actor. However {checker.mesh} is not {selected_device_mesh}."
@@ -350,7 +370,7 @@ def _actor_send(
     args_kwargs_tuple: bytes,
     refs: Sequence[Any],
     port: Optional[Port[Any]],
-    selection: Selection,
+    selection: str,
     client: MeshClient,
     mesh: DeviceMesh,
     tensors: List[Tensor],
monarch/monarch_controller CHANGED
Binary file
monarch/tools/colors.py ADDED
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import sys
+
+# only print colors if outputting directly to a terminal
+if not sys.stdout.closed and sys.stdout.isatty():
+    GREEN = "\033[32m"
+    BLUE = "\033[34m"
+    ORANGE = "\033[38:2:238:76:44m"
+    GRAY = "\033[2m"
+    CYAN = "\033[36m"
+    ENDC = "\033[0m"
+else:
+    GREEN = ""
+    ORANGE = ""
+    BLUE = ""
+    GRAY = ""
+    CYAN = ""
+    ENDC = ""
monarch/tools/commands.py CHANGED
@@ -11,9 +11,12 @@ import asyncio
 import inspect
 import logging
 import os
+import tempfile
 from datetime import datetime, timedelta
+from pathlib import Path
 from typing import Any, Callable, Mapping, Optional, Union
 
+from monarch.tools.colors import CYAN, ENDC
 from monarch.tools.components.hyperactor import DEFAULT_NAME
 
 from monarch.tools.config import (  # @manual=//monarch/python/monarch/tools/config/meta:defaults
@@ -21,6 +24,8 @@ from monarch.tools.config import ( # @manual=//monarch/python/monarch/tools/con
     defaults,
 )
 from monarch.tools.mesh_spec import mesh_spec_from_metadata, ServerSpec
+from monarch.tools.utils import MONARCH_HOME
+
 from torchx.runner import Runner  # @manual=//torchx/runner:lib_core
 from torchx.specs import AppDef, AppDryRunInfo, AppState, CfgVal, parse_app_handle
 from torchx.specs.builders import parse_args
@@ -125,8 +130,18 @@ def create(
 
     with torchx_runner() as runner:
         appdef: AppDef = AppDef(name, config.appdef.roles, config.appdef.metadata)
+        if not config.workspace.dirs and not config.workspace.env:
+            info = runner.dryrun(appdef, scheduler, cfg, workspace=None)
+        else:
+            with tempfile.TemporaryDirectory(dir=MONARCH_HOME("out")) as tmpdir:
+                # multi-directory workspace is not supported natively in torchx; so merge into a single one
+                # TODO (kiuk@) may be able to delete bootstrap workspace copy (as the job is created)
+                # since proc_mesh.sync_workspace() can do this without having to merge the workspace
+                workspace_out = Path(tmpdir) / "workspace"
+                config.workspace.merge(workspace_out)
+                config.workspace.set_env_vars(appdef)
 
-        info = runner.dryrun(appdef, scheduler, cfg, config.workspace)
+                info = runner.dryrun(appdef, scheduler, cfg, str(workspace_out))
 
         info_json_fmt = AppDryRunInfo(
             info.request,
@@ -173,19 +188,25 @@ def info(server_handle: str) -> Optional[ServerSpec]:
 
         # null-guard since some schedulers do not fill replica_status
         if host_status := replica_status.get(role.name):
-            spec.hostnames = [h.hostname for h in host_status]
+            # make sure the hostnames are sorted by their respective node indexes
+            # this makes ServerSpec.host0 return hostname of node 0
+            spec.hostnames = [
+                h.hostname for h in sorted(host_status, key=lambda h: h.id)
+            ]
             # the mesh status is based on the "least progressive" replica status
             spec.state = min(h.state for h in host_status)
 
         mesh_specs.append(spec)
 
     scheduler, namespace, _ = parse_app_handle(server_handle)
+
     return ServerSpec(
         name=appdef.name,
         state=status.state,
         meshes=mesh_specs,
         scheduler=scheduler,
         namespace=namespace,
+        ui_url=status.ui_url,
     )
 
 
@@ -263,6 +284,7 @@ async def get_or_create(
     name: str,
     config: Config,
     check_interval: timedelta = _5_SECONDS,
+    force_restart: bool = False,
 ) -> ServerSpec:
     """Waits for the server based on identity `name` in the scheduler specified in the `config`
     to be ready (e.g. RUNNING). If the server is not found then this function creates one
@@ -280,6 +302,12 @@
        server_handle = get_or_create(name="my_job_name", config)
        server_info = info(server_handle)
 
+    Args:
+        name: the name of the server (job) to get or create
+        config: configs used to create the job if one does not exist
+        check_interval: how often to poll the status of the job when waiting for it to be ready
+        force_restart: if True kills and re-creates the job even if one exists
+
    Returns: A `ServerSpec` containing information about either the existing or the newly
        created server.
 
@@ -288,7 +316,6 @@
 
     server_handle = f"{config.scheduler}:///{name}"
     server_info = await server_ready(server_handle, check_interval)
-
     if not server_info or not server_info.is_running:  # then create one
         logger.info(
             "no existing RUNNING server `%s` creating new one...", server_handle
@@ -311,11 +338,19 @@
                 f"the new server `{new_server_handle}` has {server_info.state}"
             )
 
-        print(f"\x1b[36mNew job `{new_server_handle}` is ready to serve. \x1b[0m")
-        return server_info
+        print(f"{CYAN}New job `{new_server_handle}` is ready to serve.{ENDC}")
     else:
-        print(f"\x1b[36mFound existing job `{server_handle}` ready to serve. \x1b[0m")
-        return server_info
+        print(f"{CYAN}Found existing job `{server_handle}` ready to serve.{ENDC}")
+
+    if force_restart:
+        print(f"{CYAN}force_restart=True, restarting `{server_handle}`.{ENDC}")
+        kill(server_handle)
+        server_info = await get_or_create(name, config, check_interval)
+
+    if server_info.ui_url:  # not all schedulers have a UI URL
+        print(f"{CYAN}Job URL: {server_info.ui_url}{ENDC}")
+
+    return server_info
 
 
 def kill(server_handle: str) -> None:
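A hedged usage sketch of the new force_restart flag on get_or_create (scheduler and job names are placeholders; Config is instantiated directly, since defaults.config() is deprecated later in this diff).

import asyncio

from monarch.tools.commands import get_or_create
from monarch.tools.config import Config


async def main() -> None:
    config = Config(scheduler="slurm")  # scheduler name is illustrative
    # kills and re-creates the job even if a RUNNING one already exists
    server = await get_or_create("my_job_name", config, force_restart=True)
    if server.ui_url:  # not every scheduler exposes a UI URL
        print(server.ui_url)


asyncio.run(main())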
monarch/tools/components/hyperactor.py CHANGED
@@ -9,7 +9,8 @@ import getpass
 from typing import Optional
 
 from monarch.tools import mesh_spec
-from monarch.tools.config import UnnamedAppDef
+
+from monarch.tools.config import NOT_SET
 from monarch.tools.mesh_spec import mesh_spec_from_str
 from torchx import specs
 
@@ -19,16 +20,17 @@ _USER: str = getpass.getuser()
 
 DEFAULT_NAME: str = f"monarch-{_USER}"
 
+
 __version__ = "latest"  # TODO get version from monarch.__version_
 
 
 def host_mesh(
-    image: str = f"ghcr.io/pytorch-labs/monarch:{__version__}",  # TODO docker needs to be built and pushed to ghcr
+    image: str = f"ghcr.io/meta-pytorch/monarch:{__version__}",  # TODO docker needs to be built and pushed to ghcr
     meshes: list[str] = _DEFAULT_MESHES,
     env: Optional[dict[str, str]] = None,
     port: int = mesh_spec.DEFAULT_REMOTE_ALLOCATOR_PORT,
     program: str = "monarch_bootstrap",  # installed with monarch wheel (as console script)
-) -> UnnamedAppDef:
+) -> specs.AppDef:
     """
     Args:
        name: the name of the monarch server job
@@ -39,7 +41,7 @@
        program: path to the binary that the remote process allocator spawns on an allocation request
     """
 
-    appdef = UnnamedAppDef()
+    appdef = specs.AppDef(name=NOT_SET)
 
     for mesh in [mesh_spec_from_str(mesh) for mesh in meshes]:
         mesh_role = specs.Role(
monarch/tools/config/__init__.py CHANGED
@@ -5,23 +5,24 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
+import warnings
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import Any
 
-from torchx.specs import Role
+from monarch.tools.config.workspace import Workspace
 
+# Gracefully handle cases where torchx might not be installed
+# NOTE: this can be removed once torchx.specs moves to monarch.session
+try:
+    from torchx import specs
+except ImportError:
+    pass
 
 NOT_SET: str = "__NOT_SET__"
 
 
-@dataclass
-class UnnamedAppDef:
-    """
-    A TorchX AppDef without a name.
-    """
-
-    roles: List[Role] = field(default_factory=list)
-    metadata: Dict[str, str] = field(default_factory=dict)
+def _empty_appdef() -> "specs.AppDef":
+    return specs.AppDef(name=NOT_SET)
 
 
 @dataclass
@@ -32,6 +33,28 @@ class Config:
 
     scheduler: str = NOT_SET
     scheduler_args: dict[str, Any] = field(default_factory=dict)
-    workspace: Optional[str] = None
+    workspace: Workspace = field(default_factory=Workspace.null)
     dryrun: bool = False
-    appdef: UnnamedAppDef = field(default_factory=UnnamedAppDef)
+    appdef: "specs.AppDef" = field(default_factory=_empty_appdef)
+
+    def __post_init__(self) -> None:
+        # workspace used to be Optional[str]
+        # while we type it as class Workspace now, handle workspace=None and str for BC
+        if self.workspace is None:
+            deprecation_msg = (
+                "Setting `workspace=None` is deprecated."
+                " Use `workspace=monarch.tools.config.workspace.Workspace(env=None)` instead."
+            )
+            warnings.warn(deprecation_msg, FutureWarning, stacklevel=2)
+            self.workspace = Workspace.null()
+        elif isinstance(self.workspace, str):
+            deprecation_msg = (
+                f"Setting `workspace='{self.workspace}'` is deprecated."
+                f" Use `workspace=monarch.tools.config.workspace.Workspace(dirs=['{self.workspace}'])` instead."
+            )
+            warnings.warn(deprecation_msg, FutureWarning, stacklevel=2)
+            # previous behavior (when workspace was a str pointing to the local project dir)
+            # was to copy the local dir into $WORKSPACE_DIR. For example:
+            #   ~/github/torch/** (local) -> $WORKSPACE_DIR/** (remote)
+            # so we map it to "".
+            self.workspace = Workspace(dirs={self.workspace: ""})
monarch/tools/config/defaults.py CHANGED
@@ -8,10 +8,12 @@
 
 """Defines defaults for ``monarch.tools``"""
 
-from typing import Callable, Optional
+import warnings
+from typing import Callable
 
 from monarch.tools.components import hyperactor
-from monarch.tools.config import Config, UnnamedAppDef
+from monarch.tools.config import Config
+from monarch.tools.config.workspace import Workspace
 
 from torchx import specs
 from torchx.schedulers import (
@@ -23,7 +25,7 @@ from torchx.schedulers import (
 )
 
 
-def component_fn(scheduler: str) -> Callable[..., UnnamedAppDef]:
+def component_fn(scheduler: str) -> Callable[..., specs.AppDef]:
     """The default TorchX component function for the scheduler"""
     return hyperactor.host_mesh
 
@@ -40,9 +42,17 @@ def scheduler_factories() -> dict[str, SchedulerFactory]:
     }
 
 
-def config(scheduler: str, workspace: Optional[str] = None) -> Config:
+def config(scheduler: str, workspace: str | None = None) -> Config:
     """The default :py:class:`~monarch.tools.config.Config` to use when submitting to the provided ``scheduler``."""
-    return Config(scheduler=scheduler, workspace=workspace)
+    warnings.warn(
+        "`defaults.config()` is deprecated, prefer instantiating `Config()` directly",
+        FutureWarning,
+        stacklevel=2,
+    )
+    return Config(
+        scheduler=scheduler,
+        workspace=Workspace(dirs={workspace: ""}) if workspace else Workspace.null(),
+    )
 
 
 def dryrun_info_formatter(dryrun_info: specs.AppDryRunInfo) -> Callable[..., str]:
monarch/tools/config/environment.py ADDED
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from monarch.tools import utils
+
+
+class Environment:
+    """An environment holds the necessary dependencies for the projects (directories)
+    in a `monarch.tools.workspace.Workspace`. When specified as part of a Workspace,
+    the local environment is packed into an ephemeral "image" (e.g. Docker) to mirror
+    the locally installed packages on the remote job.
+    """
+
+    pass
+
+
+class CondaEnvironment(Environment):
+    """Reference to a conda environment.
+    If no `conda_prefix` is specified, then defaults to the currently active conda environment.
+    """
+
+    def __init__(self, conda_prefix: str | None = None) -> None:
+        self._conda_prefix = conda_prefix
+
+    @property
+    def conda_prefix(self) -> str:
+        """Returns the `conda_prefix` this object was instantiated with or the currently active conda environment
+        if no `conda_prefix` was specified in the constructor."""
+        if not self._conda_prefix:
+            active_conda_prefix = utils.conda.active_env_dir()
+            assert active_conda_prefix, "No currently active conda environment. Either specify a `conda_prefix` or activate one."
+            return active_conda_prefix
+        else:
+            return self._conda_prefix
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, CondaEnvironment):
+            return False
+
+        return self._conda_prefix == other._conda_prefix
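A hedged usage sketch: pairing a CondaEnvironment with the new Workspace class so the active conda environment is mirrored on the remote job. The env= parameter name comes from the Workspace references elsewhere in this diff; the directory mapping is illustrative.

from monarch.tools.config.environment import CondaEnvironment
from monarch.tools.config.workspace import Workspace

# pack the currently active conda environment alongside the project directory
ws = Workspace(dirs={"~/my_project": ""}, env=CondaEnvironment())

# or pin a specific conda prefix explicitly
ws_pinned = Workspace(dirs={"~/my_project": ""}, env=CondaEnvironment("/opt/conda/envs/train"))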