torchmonarch-nightly 2025.7.1__cp313-cp313-manylinux2014_x86_64.whl → 2025.7.25__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +874 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +270 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +500 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +56 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +51 -0
  23. monarch/actor_mesh.py +6 -765
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +12 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/mesh_controller.py +201 -139
  37. monarch/monarch_controller +0 -0
  38. monarch/opaque_module.py +4 -6
  39. monarch/opaque_object.py +3 -3
  40. monarch/proc_mesh.py +6 -309
  41. monarch/python_local_mesh.py +1 -1
  42. monarch/rust_backend_mesh.py +2 -1
  43. monarch/rust_local_mesh.py +4 -2
  44. monarch/sim_mesh.py +10 -19
  45. monarch/simulator/command_history.py +1 -1
  46. monarch/simulator/interface.py +2 -1
  47. monarch/simulator/mock_controller.py +1 -1
  48. monarch/simulator/simulator.py +1 -1
  49. monarch/tensor_engine/__init__.py +23 -0
  50. monarch/tensor_worker_main.py +3 -1
  51. monarch/tools/cli.py +3 -1
  52. monarch/tools/commands.py +95 -35
  53. monarch/tools/mesh_spec.py +55 -0
  54. monarch/tools/utils.py +38 -0
  55. monarch/worker/worker.py +1 -1
  56. monarch/world_mesh.py +2 -1
  57. monarch_supervisor/python_executable.py +6 -3
  58. tests/error_test_binary.py +48 -10
  59. tests/test_actor_error.py +370 -21
  60. tests/test_alloc.py +1 -1
  61. tests/test_allocator.py +373 -17
  62. tests/test_controller.py +2 -0
  63. tests/test_debugger.py +416 -0
  64. tests/test_env_before_cuda.py +162 -0
  65. tests/test_python_actors.py +184 -333
  66. tests/test_rdma.py +198 -0
  67. tests/test_remote_functions.py +40 -12
  68. tests/test_rust_backend.py +7 -5
  69. tests/test_sim_backend.py +1 -4
  70. tests/test_tensor_engine.py +55 -1
  71. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/METADATA +6 -1
  72. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/RECORD +80 -68
  73. torchmonarch_nightly-2025.7.25.dist-info/entry_points.txt +3 -0
  74. monarch/_monarch/hyperactor/__init__.py +0 -58
  75. monarch/_monarch/worker/debugger.py +0 -117
  76. monarch/_monarch/worker/logging.py +0 -107
  77. monarch/debugger.py +0 -379
  78. monarch/future.py +0 -76
  79. monarch/rdma.py +0 -162
  80. torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
  81. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  82. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  83. /monarch/{common → _src/actor}/shape.py +0 -0
  84. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  85. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/WHEEL +0 -0
  86. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/licenses/LICENSE +0 -0
  87. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/top_level.txt +0 -0
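The largest additions in this release are the actor debugger and the proc-mesh setup hooks, both exercised by the new test files in the diffs below. As a rough orientation only, the following minimal sketch shows the debugger workflow those tests drive; it reuses only calls that appear verbatim in tests/test_debugger.py and tests/test_env_before_cuda.py (proc_mesh, spawn, debug_client, wait_pending_session, list, enter) and is not itself part of the wheel contents.

# Sketch only -- mirrors the API calls exercised in tests/test_debugger.py below;
# not part of the torchmonarch-nightly wheel.
import asyncio

import monarch.actor as actor
from monarch._src.actor.actor_mesh import Actor, current_rank
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.proc_mesh import proc_mesh


class Debugee(Actor):
    @endpoint
    async def to_debug(self) -> int:
        breakpoint()  # noqa -- routed to the monarch debugger, not local pdb
        return current_rank().rank


async def main() -> None:
    proc = await proc_mesh(hosts=1, gpus=1)           # allocate a 1x1 proc mesh
    debugee = await proc.spawn("debugee", Debugee)    # spawn the actor on it
    client = actor.debug_client()                     # debug client actor

    fut = debugee.to_debug.call()                     # endpoint hits breakpoint()
    await client.wait_pending_session.call_one()      # wait for a debug session
    print(await client.list.call_one())               # list pending breakpoints
    await client.enter.call_one()                     # interactive attach/cast/continue
    await fut


if __name__ == "__main__":
    asyncio.run(main())

The commands accepted at the debugger prompt, including the cast ranks(...) syntax, are enumerated in the parser tests at the end of tests/test_debugger.py.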
tests/test_debugger.py ADDED
@@ -0,0 +1,416 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import asyncio
+ import re
+ import sys
+ from unittest.mock import AsyncMock, MagicMock, patch
+
+ import monarch
+ import monarch.actor as actor
+
+ import pytest
+
+ import torch
+
+ from monarch._src.actor.actor_mesh import Actor, current_rank
+ from monarch._src.actor.debugger import (
+     Attach,
+     Cast,
+     Continue,
+     DebugClient,
+     DebugCommand,
+     DebugSession,
+     Help,
+     ListCommand,
+     Quit,
+ )
+ from monarch._src.actor.endpoint import endpoint
+
+ from monarch._src.actor.proc_mesh import proc_mesh
+
+ needs_cuda = pytest.mark.skipif(
+     not torch.cuda.is_available(),
+     reason="CUDA not available",
+ )
+
+
+ def _bad_rank():
+     raise ValueError("bad rank")
+
+
+ def _debugee_actor_internal(rank):
+     if rank == 0:
+         breakpoint()  # noqa
+         rank += 1
+         rank += 1
+         return rank
+     elif rank == 1:
+         breakpoint()  # noqa
+         rank += 2
+         rank += 2
+         return rank
+     elif rank == 2:
+         breakpoint()  # noqa
+         rank += 3
+         rank += 3
+         _bad_rank()
+     elif rank == 3:
+         breakpoint()  # noqa
+         rank += 4
+         rank += 4
+         return rank
+
+
+ class DebugeeActor(Actor):
+     @endpoint
+     async def to_debug(self):
+         rank = current_rank().rank
+         return _debugee_actor_internal(rank)
+
+
+ @pytest.mark.skipif(
+     torch.cuda.device_count() < 2,
+     reason="Not enough GPUs, this test requires at least 2 GPUs",
+ )
+ async def test_debug() -> None:
+     input_mock = AsyncMock()
+     input_mock.side_effect = [
+         "attach 1",
+         "n",
+         "n",
+         "n",
+         "n",
+         "detach",
+         "attach 1",
+         "detach",
+         "quit",
+         "cast ranks(0,3) n",
+         "cast ranks(0,3) n",
+         # Attaching to 0 and 3 ensures that when we call "list"
+         # the next time, their function/lineno info will be
+         # up-to-date.
+         "attach 0",
+         "detach",
+         "attach 3",
+         "detach",
+         "quit",
+         "attach 2",
+         "c",
+         "detach",
+         "quit",
+         "attach 2",
+         "bt",
+         "c",
+         "quit",
+         "continue",
+     ]
+
+     outputs = []
+
+     def _patch_output(msg):
+         nonlocal outputs
+         outputs.append(msg)
+
+     with patch(
+         "monarch._src.actor.debugger._debugger_input", side_effect=input_mock
+     ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output):
+         proc = await proc_mesh(hosts=2, gpus=2)
+         debugee = await proc.spawn("debugee", DebugeeActor)
+         debug_client = actor.debug_client()
+
+         fut = debugee.to_debug.call()
+         await debug_client.wait_pending_session.call_one()
+         breakpoints = []
+         for i in range(10):
+             breakpoints = await debug_client.list.call_one()
+             if len(breakpoints) == 4:
+                 break
+             await asyncio.sleep(1)
+             if i == 9:
+                 raise RuntimeError("timed out waiting for breakpoints")
+
+         initial_linenos = {}
+         for i in range(len(breakpoints)):
+             rank, coords, _, _, function, lineno = breakpoints[i]
+             initial_linenos[rank] = lineno
+             assert rank == i
+             assert coords == {"hosts": rank // 2, "gpus": rank % 2}
+             assert function == "test_debugger._debugee_actor_internal"
+             assert lineno == breakpoints[0][5] + 5 * rank
+
+         await debug_client.enter.call_one()
+
+         # Check that when detaching and re-attaching to a session, the last portion of the output is repeated
+         expected_last_output = [
+             r"--Return--",
+             r"\n",
+             r"> (/.*/)+test_debugger.py\(\d+\)to_debug\(\)->5\n-> return _debugee_actor_internal\(rank\)",
+             r"\n",
+             r"\(Pdb\) ",
+         ]
+         output_len = len(expected_last_output)
+         assert outputs[-2 * output_len : -output_len] == outputs[-output_len:]
+         for real_output, expected_output in zip(
+             outputs[-output_len:], expected_last_output
+         ):
+             assert re.match(expected_output, real_output) is not None
+
+         breakpoints = await debug_client.list.call_one()
+         for i in range(len(breakpoints)):
+             if i == 1:
+                 assert breakpoints[i][4] == "test_debugger.to_debug"
+             else:
+                 assert breakpoints[i][4] == "test_debugger._debugee_actor_internal"
+                 assert breakpoints[i][5] == initial_linenos[i]
+
+         await debug_client.enter.call_one()
+
+         breakpoints = await debug_client.list.call_one()
+         for i in range(len(breakpoints)):
+             if i == 1:
+                 assert breakpoints[i][4] == "test_debugger.to_debug"
+             elif i in (0, 3):
+                 assert breakpoints[i][4] == "test_debugger._debugee_actor_internal"
+                 assert breakpoints[i][5] == initial_linenos[i] + 2
+             else:
+                 assert breakpoints[i][4] == "test_debugger._debugee_actor_internal"
+                 assert breakpoints[i][5] == initial_linenos[i]
+
+         await debug_client.enter.call_one()
+
+         breakpoints = await debug_client.list.call_one()
+         assert len(breakpoints) == 4
+         # Expect post-mortem debugging for rank 2
+         assert breakpoints[2][4] == "test_debugger._bad_rank"
+
+         await debug_client.enter.call_one()
+
+         expected_last_output = [
+             r"\s*(/.*/)+test_debugger.py\(\d+\)_debugee_actor_internal\(\)\n-> _bad_rank\(\)",
+             r"\n",
+             r'> (/.*/)+test_debugger.py\(\d+\)_bad_rank\(\)\n-> raise ValueError\("bad rank"\)',
+             r"\n",
+             r"\(Pdb\) ",
+         ]
+
+         for output, expected_output in zip(
+             outputs[-len(expected_last_output) :], expected_last_output
+         ):
+             assert re.match(expected_output, output) is not None
+
+         breakpoints = await debug_client.list.call_one()
+         assert len(breakpoints) == 3
+         for i, rank in enumerate((0, 1, 3)):
+             assert breakpoints[i][0] == rank
+
+         await debug_client.enter.call_one()
+         breakpoints = await debug_client.list.call_one()
+         assert len(breakpoints) == 0
+
+         with pytest.raises(
+             monarch._src.actor.actor_mesh.ActorError, match="ValueError: bad rank"
+         ):
+             await fut
+
+
+ async def test_cast_input_and_wait() -> None:
+     debug_client = DebugClient()
+
+     mock_sessions = {}
+     for host in range(3):
+         for gpu in range(8):
+             rank = host * 8 + gpu
+             mock_session = MagicMock(spec=DebugSession)
+             mock_session.attach = AsyncMock()
+             mock_session.rank = rank
+             mock_session.coords = {"hosts": host, "gpus": gpu}
+             mock_sessions[rank] = mock_session
+
+     debug_client.sessions = mock_sessions
+
+     # Cast to a single rank
+     await debug_client._cast_input_and_wait("n", 2)
+     mock_sessions[2].attach.assert_called_once_with("n", suppress_output=True)
+     for rank, session in mock_sessions.items():
+         if rank != 2:
+             session.attach.assert_not_called()
+
+     for session in mock_sessions.values():
+         session.attach.reset_mock()
+
+     # Cast to a list of ranks
+     ranks = [1, 3, 5]
+     await debug_client._cast_input_and_wait("n", ranks)
+     for rank in ranks:
+         mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True)
+     for rank, session in mock_sessions.items():
+         if rank not in ranks:
+             session.attach.assert_not_called()
+
+     for session in mock_sessions.values():
+         session.attach.reset_mock()
+
+     # Cast to a range of ranks
+     ranks = range(2, 24, 3)
+     await debug_client._cast_input_and_wait("n", ranks)
+     for rank in ranks:
+         mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True)
+     for rank, session in mock_sessions.items():
+         if rank not in ranks:
+             session.attach.assert_not_called()
+
+     for session in mock_sessions.values():
+         session.attach.reset_mock()
+
+     # Cast to all ranks
+     await debug_client._cast_input_and_wait("n", None)
+     for session in mock_sessions.values():
+         session.attach.assert_called_once_with("n", suppress_output=True)
+
+     for session in mock_sessions.values():
+         session.attach.reset_mock()
+
+     # Cast using dimension filtering with a single value
+     await debug_client._cast_input_and_wait("n", {"hosts": 1})
+     for session in mock_sessions.values():
+         if session.coords["hosts"] == 1:
+             session.attach.assert_called_once_with("n", suppress_output=True)
+         else:
+             session.attach.assert_not_called()
+
+     for session in mock_sessions.values():
+         session.attach.reset_mock()
+
+     # Cast using dimension filtering with a list
+     await debug_client._cast_input_and_wait("n", {"hosts": [0, 2]})
+     for _rank, session in mock_sessions.items():
+         if session.coords["hosts"] in [0, 2]:
+             session.attach.assert_called_once_with("n", suppress_output=True)
+         else:
+             session.attach.assert_not_called()
+
+     for session in mock_sessions.values():
+         session.attach.reset_mock()
+
+     # Cast using dimension filtering with a range
+     await debug_client._cast_input_and_wait("n", {"gpus": range(5, 8)})
+     for session in mock_sessions.values():
+         if session.coords["gpus"] in range(5, 8):
+             session.attach.assert_called_once_with("n", suppress_output=True)
+         else:
+             session.attach.assert_not_called()
+
+     for session in mock_sessions.values():
+         session.attach.reset_mock()
+
+     # Cast using multiple dimension filters
+     await debug_client._cast_input_and_wait(
+         "n", {"hosts": [1, 3], "gpus": range(0, sys.maxsize, 3)}
+     )
+     for session in mock_sessions.values():
+         if session.coords["hosts"] in [1, 3] and session.coords["gpus"] in range(
+             0, sys.maxsize, 3
+         ):
+             session.attach.assert_called_once_with("n", suppress_output=True)
+         else:
+             session.attach.assert_not_called()
+
+     for session in mock_sessions.values():
+         session.attach.reset_mock()
+
+     # Cast with non-existent dimension
+     await debug_client._cast_input_and_wait("n", {"hosts": 0, "gpus": 0, "foo": 0})
+     for session in mock_sessions.values():
+         session.attach.assert_not_called()
+
+
+ @pytest.mark.parametrize(
+     ["user_input", "expected_output"],
+     [
+         ("attach 1", Attach(1)),
+         ("a 100", Attach(100)),
+         ("list", ListCommand()),
+         ("l", ListCommand()),
+         ("help", Help()),
+         ("h", Help()),
+         ("quit", Quit()),
+         ("q", Quit()),
+         ("continue", Continue()),
+         ("c", Continue()),
+         ("cast ranks(123) b 25", Cast(ranks=123, command="b 25")),
+         ("cast ranks(12,34,56) b 25", Cast(ranks=[12, 34, 56], command="b 25")),
+         ("cast ranks(:) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")),
+         ("cast ranks(:123) b 25", Cast(ranks=range(0, 123), command="b 25")),
+         ("cast ranks(123:) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")),
+         ("cast ranks(123:456) b 25", Cast(ranks=range(123, 456), command="b 25")),
+         ("cast ranks(::) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")),
+         (
+             "cast ranks(::123) b 25",
+             Cast(ranks=range(0, sys.maxsize, 123), command="b 25"),
+         ),
+         ("cast ranks(123::) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")),
+         ("cast ranks(:123:) b 25", Cast(ranks=range(0, 123), command="b 25")),
+         ("cast ranks(:456:123) b 25", Cast(ranks=range(0, 456, 123), command="b 25")),
+         (
+             "cast ranks(456::123) b 25",
+             Cast(ranks=range(456, sys.maxsize, 123), command="b 25"),
+         ),
+         ("cast ranks(123:456:) b 25", Cast(ranks=range(123, 456), command="b 25")),
+         (
+             "cast ranks(456:789:123) b 25",
+             Cast(ranks=range(456, 789, 123), command="b 25"),
+         ),
+         ("cast ranks(dim1=123) up 2", Cast(ranks={"dim1": 123}, command="up 2")),
+         (
+             "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2) up 2",
+             Cast(
+                 ranks={
+                     "dim1": 123,
+                     "dim2": [12, 34, 56],
+                     "dim3": range(15, sys.maxsize, 2),
+                 },
+                 command="up 2",
+             ),
+         ),
+     ],
+ )
+ async def test_debug_command_parser_valid_inputs(user_input, expected_output):
+     assert DebugCommand.parse(user_input) == expected_output
+
+
+ @pytest.mark.parametrize(
+     "invalid_input",
+     [
+         "",
+         "attch 1",
+         "attach",
+         "cast rnks(123) b 25",
+         "cast ranks() b 25",
+         "cast ranks(1ab) b 25",
+         "cast ranks(1,a,3) b 25",
+         "cast ranks(a:2:4) b 25",
+         "cast ranks(1,2,3",
+         "cast ranks(1,2,3)) b 25",
+         "cast ranks(1,) b 25",
+         "cast ranks(1,2,) b 25",
+         "cast ranks(,1,2) b 25",
+         "cast ranks(1,,2) b 25",
+         "cast ranks(:::) b 25",
+         "cast ranks(:123::) b 25",
+         "cast ranks(1:2:3,4) b 25",
+         "cast ranks(dim1=) b 25",
+         "cast ranks(dim1=123, dim2=) b 25",
+         "cast ranks(dim1=123, dim2=(12,34,56) b 25",
+         "cast ranks(dim1=123, dim2=(,12,34,56) b 25",
+         "cast ranks(dim1=123, dim2=(12,,34,56) b 25",
+         "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2 b 25",
+         "cast ranks(dim1=123,) b 25",
+     ],
+ )
+ async def test_debug_command_parser_invalid_inputs(invalid_input):
+     assert DebugCommand.parse(invalid_input) is None
tests/test_env_before_cuda.py ADDED
@@ -0,0 +1,162 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import os
+ import sys
+ import unittest
+ from typing import Dict, List
+
+ import cloudpickle
+
+ import torch
+ from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
+ from monarch._src.actor.actor_mesh import MonarchContext
+ from monarch._src.actor.allocator import LocalAllocator
+ from monarch._src.actor.proc_mesh import proc_mesh
+ from monarch.actor import Actor, endpoint, ProcMesh
+
+
+ class CudaInitTestActor(Actor):
+     """Actor that initializes CUDA and checks environment variables"""
+
+     def __init__(self) -> None:
+         self.env_vars_before_init: Dict[str, str] = {}
+         self.cuda_initialized: bool = False
+
+     @endpoint
+     async def init_cuda_and_check_env(self, env_var_names: List[str]) -> Dict[str, str]:
+         """
+         Check environment variables before initializing CUDA
+         Returns the values of the environment variables
+         """
+         for var_name in env_var_names:
+             self.env_vars_before_init[var_name] = os.environ.get(var_name, "NOT_SET")
+
+         if torch.cuda.is_available():
+             torch.cuda.init()
+             self.cuda_initialized = True
+
+         return self.env_vars_before_init
+
+     @endpoint
+     async def is_cuda_initialized(self) -> bool:
+         """Return whether CUDA was initialized"""
+         return self.cuda_initialized
+
+
+ class TestEnvBeforeCuda(unittest.IsolatedAsyncioTestCase):
+     """Test that the env vars are setup before cuda init"""
+
+     @classmethod
+     def setUpClass(cls) -> None:
+         cloudpickle.register_pickle_by_value(sys.modules[CudaInitTestActor.__module__])
+
+     @classmethod
+     def tearDownClass(cls) -> None:
+         cloudpickle.unregister_pickle_by_value(
+             sys.modules[CudaInitTestActor.__module__]
+         )
+
+     async def test_lambda_sets_env_vars_before_cuda_init(self) -> None:
+         """Test that environment variables are set by lambda before CUDA initialization"""
+         cuda_env_vars: Dict[str, str] = {
+             "CUDA_VISIBLE_DEVICES": "0",
+             "CUDA_CACHE_PATH": "/tmp/cuda_cache_test",
+             "CUDA_LAUNCH_BLOCKING": "1",
+         }
+
+         def setup_cuda_env(_: MonarchContext) -> None:
+             for name, value in cuda_env_vars.items():
+                 os.environ[name] = value
+
+         spec = AllocSpec(AllocConstraints(), gpus=1, hosts=1)
+         allocator = LocalAllocator()
+         alloc = await allocator.allocate(spec)
+
+         proc_mesh = await ProcMesh.from_alloc(alloc, setup=setup_cuda_env)
+
+         try:
+             actor = await proc_mesh.spawn("cuda_init", CudaInitTestActor)
+
+             env_vars = await actor.init_cuda_and_check_env.call_one(
+                 list(cuda_env_vars.keys())
+             )
+
+             await actor.is_cuda_initialized.call_one()
+
+             for name, expected_value in cuda_env_vars.items():
+                 self.assertEqual(
+                     env_vars.get(name),
+                     expected_value,
+                     f"Environment variable {name} was not set correctly before CUDA initialization",
+                 )
+
+         finally:
+             await proc_mesh.stop()
+
+     async def test_proc_mesh_with_lambda_env(self) -> None:
+         """Test that proc_mesh function works with lambda for env parameter"""
+         cuda_env_vars: Dict[str, str] = {
+             "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
+             "CUDA_MODULE_LOADING": "LAZY",
+             "CUDA_DEVICE_MAX_CONNECTIONS": "1",
+         }
+
+         def setup_cuda_env(_: MonarchContext) -> None:
+             for name, value in cuda_env_vars.items():
+                 os.environ[name] = value
+
+         proc_mesh_instance = await proc_mesh(gpus=1, hosts=1, setup=setup_cuda_env)
+
+         try:
+             actor = await proc_mesh_instance.spawn("cuda_init", CudaInitTestActor)
+
+             env_vars = await actor.init_cuda_and_check_env.call_one(
+                 list(cuda_env_vars.keys())
+             )
+             for name, expected_value in cuda_env_vars.items():
+                 self.assertEqual(
+                     env_vars.get(name),
+                     expected_value,
+                     f"Environment variable {name} was not set correctly before CUDA initialization",
+                 )
+
+         finally:
+             await proc_mesh_instance.stop()
+
+     async def test_proc_mesh_with_dictionary_env(self) -> None:
+         """Test that proc_mesh function works with dictionary for env parameter"""
+         cuda_env_vars: Dict[str, str] = {
+             "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
+             "CUDA_MODULE_LOADING": "LAZY",
+             "CUDA_DEVICE_MAX_CONNECTIONS": "1",
+         }
+
+         proc_mesh_instance = await proc_mesh(gpus=1, hosts=1, env=cuda_env_vars)
+
+         try:
+             actor = await proc_mesh_instance.spawn("cuda_init", CudaInitTestActor)
+             env_vars = await actor.init_cuda_and_check_env.call_one(
+                 list(cuda_env_vars.keys())
+             )
+
+             self.assertEqual(
+                 env_vars.get("CUDA_DEVICE_ORDER"),
+                 "PCI_BUS_ID",
+             )
+             self.assertEqual(
+                 env_vars.get("CUDA_MODULE_LOADING"),
+                 "LAZY",
+             )
+             self.assertEqual(
+                 env_vars.get("CUDA_DEVICE_MAX_CONNECTIONS"),
+                 "1",
+             )
+
+         finally:
+             await proc_mesh_instance.stop()