mccl 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. mccl-0.3.0/LICENSE +21 -0
  2. mccl-0.3.0/MANIFEST.in +3 -0
  3. mccl-0.3.0/PKG-INFO +162 -0
  4. mccl-0.3.0/README.md +138 -0
  5. mccl-0.3.0/csrc/backend/MPSDispatch.cpp +144 -0
  6. mccl-0.3.0/csrc/backend/Options.hpp +47 -0
  7. mccl-0.3.0/csrc/backend/ProcessGroupMCCL.cpp +1372 -0
  8. mccl-0.3.0/csrc/backend/ProcessGroupMCCL.hpp +143 -0
  9. mccl-0.3.0/csrc/backend/Registration.cpp +214 -0
  10. mccl-0.3.0/csrc/backend/WorkMCCL.cpp +104 -0
  11. mccl-0.3.0/csrc/backend/WorkMCCL.hpp +48 -0
  12. mccl-0.3.0/csrc/common/Errors.hpp +68 -0
  13. mccl-0.3.0/csrc/common/Logging.hpp +104 -0
  14. mccl-0.3.0/csrc/common/TensorChecks.hpp +60 -0
  15. mccl-0.3.0/csrc/common/Version.hpp +20 -0
  16. mccl-0.3.0/csrc/compression/Compression.cpp +22 -0
  17. mccl-0.3.0/csrc/compression/Compression.hpp +47 -0
  18. mccl-0.3.0/csrc/compression/FP16Compression.cpp +151 -0
  19. mccl-0.3.0/csrc/compression/FP16Compression.hpp +32 -0
  20. mccl-0.3.0/csrc/compression/TopKCompression.cpp +143 -0
  21. mccl-0.3.0/csrc/compression/TopKCompression.hpp +48 -0
  22. mccl-0.3.0/csrc/metal/AccelerateOps.hpp +54 -0
  23. mccl-0.3.0/csrc/metal/AccelerateOps.mm +205 -0
  24. mccl-0.3.0/csrc/metal/EventSync.hpp +52 -0
  25. mccl-0.3.0/csrc/metal/EventSync.mm +162 -0
  26. mccl-0.3.0/csrc/metal/MPSInterop.hpp +79 -0
  27. mccl-0.3.0/csrc/metal/MPSInterop.mm +425 -0
  28. mccl-0.3.0/csrc/metal/MetalKernels.hpp +43 -0
  29. mccl-0.3.0/csrc/metal/MetalKernels.mm +665 -0
  30. mccl-0.3.0/csrc/metal/shaders.metal +552 -0
  31. mccl-0.3.0/csrc/runtime/HealthMonitor.cpp +79 -0
  32. mccl-0.3.0/csrc/runtime/HealthMonitor.hpp +59 -0
  33. mccl-0.3.0/csrc/runtime/MemoryPool.cpp +128 -0
  34. mccl-0.3.0/csrc/runtime/MemoryPool.hpp +77 -0
  35. mccl-0.3.0/csrc/runtime/Metrics.cpp +139 -0
  36. mccl-0.3.0/csrc/runtime/Metrics.hpp +97 -0
  37. mccl-0.3.0/csrc/runtime/ProgressEngine.cpp +153 -0
  38. mccl-0.3.0/csrc/runtime/ProgressEngine.hpp +69 -0
  39. mccl-0.3.0/csrc/runtime/Rendezvous.cpp +81 -0
  40. mccl-0.3.0/csrc/runtime/Rendezvous.hpp +39 -0
  41. mccl-0.3.0/csrc/runtime/Watchdog.cpp +84 -0
  42. mccl-0.3.0/csrc/runtime/Watchdog.hpp +64 -0
  43. mccl-0.3.0/csrc/transport/Connection.cpp +383 -0
  44. mccl-0.3.0/csrc/transport/Connection.hpp +75 -0
  45. mccl-0.3.0/csrc/transport/Protocol.hpp +161 -0
  46. mccl-0.3.0/csrc/transport/TcpTransport.cpp +880 -0
  47. mccl-0.3.0/csrc/transport/TcpTransport.hpp +101 -0
  48. mccl-0.3.0/csrc/transport/Transport.hpp +57 -0
  49. mccl-0.3.0/csrc/transport/rdma/IbvWrapper.cpp +98 -0
  50. mccl-0.3.0/csrc/transport/rdma/IbvWrapper.hpp +39 -0
  51. mccl-0.3.0/csrc/transport/rdma/RdmaConnection.cpp +230 -0
  52. mccl-0.3.0/csrc/transport/rdma/RdmaConnection.hpp +95 -0
  53. mccl-0.3.0/csrc/transport/rdma/RdmaTransport.cpp +419 -0
  54. mccl-0.3.0/csrc/transport/rdma/RdmaTransport.hpp +97 -0
  55. mccl-0.3.0/csrc/transport/rdma/SharedBuffer.cpp +94 -0
  56. mccl-0.3.0/csrc/transport/rdma/SharedBuffer.hpp +50 -0
  57. mccl-0.3.0/csrc/transport/rdma/ibverbs_compat.h +332 -0
  58. mccl-0.3.0/mccl/__init__.py +159 -0
  59. mccl-0.3.0/mccl/config.py +191 -0
  60. mccl-0.3.0/mccl/tuning.py +31 -0
  61. mccl-0.3.0/mccl/version.py +16 -0
  62. mccl-0.3.0/mccl.egg-info/PKG-INFO +162 -0
  63. mccl-0.3.0/mccl.egg-info/SOURCES.txt +77 -0
  64. mccl-0.3.0/mccl.egg-info/dependency_links.txt +1 -0
  65. mccl-0.3.0/mccl.egg-info/requires.txt +5 -0
  66. mccl-0.3.0/mccl.egg-info/top_level.txt +1 -0
  67. mccl-0.3.0/pyproject.toml +46 -0
  68. mccl-0.3.0/setup.cfg +4 -0
  69. mccl-0.3.0/setup.py +265 -0
  70. mccl-0.3.0/tests/test_build.py +39 -0
  71. mccl-0.3.0/tests/test_compression.py +127 -0
  72. mccl-0.3.0/tests/test_cpu_tensors.py +111 -0
  73. mccl-0.3.0/tests/test_ddp.py +196 -0
  74. mccl-0.3.0/tests/test_local_kernels.py +290 -0
  75. mccl-0.3.0/tests/test_process_group_local.py +325 -0
  76. mccl-0.3.0/tests/test_protocol.py +91 -0
  77. mccl-0.3.0/tests/test_soak.py +188 -0
  78. mccl-0.3.0/tests/test_two_host_ddp.py +117 -0
  79. mccl-0.3.0/tests/test_v2_collectives.py +253 -0
mccl-0.3.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 MCCL Contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
mccl-0.3.0/MANIFEST.in ADDED
@@ -0,0 +1,3 @@
1
+ # Ensure native sources ship in sdist (wheels are built on macOS arm64 only).
2
+ recursive-include csrc *
3
+ include LICENSE README.md pyproject.toml setup.py
mccl-0.3.0/PKG-INFO ADDED
@@ -0,0 +1,162 @@
1
+ Metadata-Version: 2.4
2
+ Name: mccl
3
+ Version: 0.3.0
4
+ Summary: MPS-native ProcessGroup backend for PyTorch Distributed on Apple Silicon
5
+ License: MIT
6
+ Project-URL: Repository, https://github.com/OWNER/REPO
7
+ Classifier: Development Status :: 4 - Beta
8
+ Classifier: Intended Audience :: Developers
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: MacOS :: MacOS X
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Programming Language :: Python :: 3.13
14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
15
+ Requires-Python: >=3.11
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: torch>=2.5.0
19
+ Provides-Extra: dev
20
+ Requires-Dist: pytest>=7.0; extra == "dev"
21
+ Requires-Dist: pytest-timeout>=2.1; extra == "dev"
22
+ Dynamic: license-file
23
+ Dynamic: requires-python
24
+
25
+ # MCCL
26
+
27
+ [![CI](https://github.com/OWNER/REPO/actions/workflows/ci.yml/badge.svg)](https://github.com/OWNER/REPO/actions/workflows/ci.yml)
28
+
29
+ `torch.distributed` backend for DDP and collectives on **MPS** (Apple Silicon). TCP by default; RDMA only if the machine/OS actually supports it.
30
+
31
+ ## Requirements
32
+
33
+ - Apple Silicon Mac (arm64). No Intel.
34
+ - **Xcode Command Line Tools** — `xcode-select --install` (needed to compile the extension).
35
+ - **Full Xcode** — optional; speeds up Metal by emitting a `.metallib` at build time instead of JIT at runtime.
36
+ - **Python 3.11+**
37
+ - **PyTorch 2.5+** installed *before* you build or `pip install` this package.
38
+
39
+ ## Install
40
+
41
+ ```bash
42
+ pip install torch
43
+ pip install mccl
44
+ ```
45
+
46
+ Source tree: `pip install -e ".[dev]"`. If the PyPI name `mccl` is taken, rename in `pyproject.toml` and `setup.py`.
47
+
48
+ Demo: https://github.com/user-attachments/assets/21865149-b077-4b65-93cc-f9e319ff0328
49
+
50
+ ## Examples
51
+
52
+ ```bash
53
+ python examples/ddp_dummy_train.py --baseline
54
+ torchrun --nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=29500 \
55
+ examples/ddp_dummy_train.py
56
+ ```
57
+
58
+ Defaults there: DDP `BATCH_SIZE=128` per rank → global 256 with 2 ranks; baseline path uses global 256 unless you override. Shrink batch if you OOM.
59
+
60
+ Minimal DDP script (run with `torchrun` below). Multi-node needs `MCCL_LISTEN_ADDR`, `MCCL_PORT_BASE`, etc. — [docs/MULTINODE.md](docs/MULTINODE.md).
61
+
62
+ ```python
63
+ import os
64
+ import torch
65
+ import torch.nn as nn
66
+ import torch.distributed as dist
67
+ from torch.nn.parallel import DistributedDataParallel as DDP
68
+ import mccl
69
+
70
+ def main():
71
+ rank = int(os.environ["RANK"])
72
+ world_size = int(os.environ["WORLD_SIZE"])
73
+ device = torch.device("mps:0")
74
+
75
+ dist.init_process_group(backend="mccl", device_id=device)
76
+
77
+ torch.manual_seed(42)
78
+ model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10)).to(device)
79
+ ddp_model = DDP(model)
80
+ optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-3)
81
+ loss_fn = nn.CrossEntropyLoss()
82
+
83
+ for step in range(10):
84
+ x = torch.randn(8, 512, device=device)
85
+ y = torch.randint(0, 10, (8,), device=device)
86
+ optimizer.zero_grad(set_to_none=True)
87
+ loss_fn(ddp_model(x), y).backward()
88
+ optimizer.step()
89
+ if rank == 0:
90
+ print(step, "ok")
91
+
92
+ dist.destroy_process_group()
93
+
94
+ if __name__ == "__main__":
95
+ main()
96
+ ```
97
+
98
+ ```bash
99
+ torchrun --nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=29500 your_train.py
100
+ ```
101
+
102
+ ## Throughput
103
+
104
+ One saved run, **M4 Max** + **M1 Max** MBPs, TCP over TB, global batch **256**, ~**96.5M** params. Your numbers will differ.
105
+
106
+ ```
107
+ single M1 Max (MPS): 78.3 samples/s (global_batch=256, world=1)
108
+ DDP (MCCL): 134.2 samples/s (global_batch=256, world=2)
109
+ baseline / DDP: 0.58× (~172% DDP vs baseline)
110
+ ```
111
+
112
+ Tiny batches = comm noise dominates. Different chips on each rank = slowest one paces the step.
113
+
114
+ ```bash
115
+ python examples/ddp_dummy_train.py --baseline --save-stats baseline_stats.json
116
+ torchrun --nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=29500 \
117
+ examples/ddp_dummy_train.py --save-stats ddp_stats.json
118
+ python examples/benchmark_throughput.py --baseline baseline_stats.json --ddp ddp_stats.json -o bench
119
+ ```
120
+
121
+ `bash scripts/benchmark_matrix.sh` for more sweeps.
122
+
123
+ ![bench](bench.png)
124
+ ![bars](bench_bars.png)
125
+
126
+ ## PyPI (maintainers)
127
+
128
+ **CI (tests only):** push to **`main`** or **`master`**, or open a PR targeting those branches. That runs [`.github/workflows/ci.yml`](.github/workflows/ci.yml) — it does **not** upload to PyPI.
129
+
130
+ **Upload:** [`.github/workflows/publish.yml`](.github/workflows/publish.yml) runs on **GitHub Release (published)** or **Actions → Publish to PyPI → Run workflow** (`workflow_dispatch`).
131
+
132
+ 1. GitHub repo → **Settings → Secrets and variables → Actions** → New repository secret **`PYPI_API_TOKEN`** (PyPI → Account settings → API tokens).
133
+ 2. Bump **`version`** in `pyproject.toml`, `setup.py`, and the assertion in `tests/test_build.py`.
134
+ 3. Either: **Releases → Draft a new release** → publish (triggers upload), or **Actions** tab → **Publish to PyPI** → **Run workflow** → branch `main`.
135
+
136
+ First-time PyPI: create the **`mccl`** project on pypi.org (or change the package `name` everywhere if the name is taken).
137
+
138
+ ## Collectives
139
+
140
+ `allreduce`, `broadcast`, `barrier`, `allgather`, `reduce_scatter`, `send`, `recv`
141
+
142
+ ## Diagnostics
143
+
144
+ ```python
145
+ mccl.get_metrics(); mccl.log_metrics(); mccl.reset_metrics()
146
+ ```
147
+
148
+ Verbose startup: `MCCL_LOG_LEVEL=INFO`. Stuck multi-node: [docs/MULTINODE.md](docs/MULTINODE.md).
149
+
150
+ ## Transport
151
+
152
+ Bench plots were TCP over a Thunderbolt-style link, not RDMA. Wi‑Fi/Ethernet work, just slower. TB wiring: [docs/THUNDERBOLT_SETUP.md](docs/THUNDERBOLT_SETUP.md). RDMA path exists on TB5-capable hardware + `librdma.dylib`; `rdma_ctl enable` from Recovery once; we didn’t use that for the graphs above.
153
+
154
+ ## Internals
155
+
156
+ Apple Silicon is **UMA**: GPU and CPU share a physical memory pool. MPS tensors are usually **`MTLBuffer`s**; with **`MTLStorageModeShared`**, `buffer.contents` is a CPU pointer into the **same pages** the GPU uses (`extract_mps_buffer`, `MPSInterop.mm`). MCCL **exploits that** by staging sends from that pointer, writing receives with `memcpy` into it, and running **Accelerate/vDSP** in **`AccelerateOps.mm`** on the same bytes—no duplicate host tensor when the fast path applies. **Private** GPU storage still needs a **blit** through a shared staging buffer (`chunked_blit_*`).
157
+
158
+ I/O runs on a **queued worker** (`ProgressEngine`, `csrc/runtime/`). Before the worker reads or sends, **`commit_mps_and_signal` / `wait_for_mps`** (`EventSync.mm`) align CPU access with a finished PyTorch MPS command buffer via **`MTLSharedEvent`**; `MCCL_EVENT_SYNC=0` forces stream sync instead. `ProcessGroupMCCL.cpp` submits work into this pipeline. [docs/DEVELOPING.md](docs/DEVELOPING.md) covers collectives and transport.
159
+
160
+ ## License
161
+
162
+ MIT — [LICENSE](LICENSE)
mccl-0.3.0/README.md ADDED
@@ -0,0 +1,138 @@
1
+ # MCCL
2
+
3
+ [![CI](https://github.com/OWNER/REPO/actions/workflows/ci.yml/badge.svg)](https://github.com/OWNER/REPO/actions/workflows/ci.yml)
4
+
5
+ `torch.distributed` backend for DDP and collectives on **MPS** (Apple Silicon). TCP by default; RDMA only if the machine/OS actually supports it.
6
+
7
+ ## Requirements
8
+
9
+ - Apple Silicon Mac (arm64). No Intel.
10
+ - **Xcode Command Line Tools** — `xcode-select --install` (needed to compile the extension).
11
+ - **Full Xcode** — optional; speeds up Metal by emitting a `.metallib` at build time instead of JIT at runtime.
12
+ - **Python 3.11+**
13
+ - **PyTorch 2.5+** installed *before* you build or `pip install` this package.
14
+
15
+ ## Install
16
+
17
+ ```bash
18
+ pip install torch
19
+ pip install mccl
20
+ ```
21
+
22
+ Source tree: `pip install -e ".[dev]"`. If the PyPI name `mccl` is taken, rename in `pyproject.toml` and `setup.py`.
23
+
24
+ Demo: https://github.com/user-attachments/assets/21865149-b077-4b65-93cc-f9e319ff0328
25
+
26
+ ## Examples
27
+
28
+ ```bash
29
+ python examples/ddp_dummy_train.py --baseline
30
+ torchrun --nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=29500 \
31
+ examples/ddp_dummy_train.py
32
+ ```
33
+
34
+ Defaults there: DDP `BATCH_SIZE=128` per rank → global 256 with 2 ranks; baseline path uses global 256 unless you override. Shrink batch if you OOM.
35
+
36
+ Minimal DDP script (run with `torchrun` below). Multi-node needs `MCCL_LISTEN_ADDR`, `MCCL_PORT_BASE`, etc. — [docs/MULTINODE.md](docs/MULTINODE.md).
37
+
38
+ ```python
39
+ import os
40
+ import torch
41
+ import torch.nn as nn
42
+ import torch.distributed as dist
43
+ from torch.nn.parallel import DistributedDataParallel as DDP
44
+ import mccl
45
+
46
+ def main():
47
+ rank = int(os.environ["RANK"])
48
+ world_size = int(os.environ["WORLD_SIZE"])
49
+ device = torch.device("mps:0")
50
+
51
+ dist.init_process_group(backend="mccl", device_id=device)
52
+
53
+ torch.manual_seed(42)
54
+ model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10)).to(device)
55
+ ddp_model = DDP(model)
56
+ optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-3)
57
+ loss_fn = nn.CrossEntropyLoss()
58
+
59
+ for step in range(10):
60
+ x = torch.randn(8, 512, device=device)
61
+ y = torch.randint(0, 10, (8,), device=device)
62
+ optimizer.zero_grad(set_to_none=True)
63
+ loss_fn(ddp_model(x), y).backward()
64
+ optimizer.step()
65
+ if rank == 0:
66
+ print(step, "ok")
67
+
68
+ dist.destroy_process_group()
69
+
70
+ if __name__ == "__main__":
71
+ main()
72
+ ```
73
+
74
+ ```bash
75
+ torchrun --nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=29500 your_train.py
76
+ ```
77
+
78
+ ## Throughput
79
+
80
+ One saved run, **M4 Max** + **M1 Max** MBPs, TCP over TB, global batch **256**, ~**96.5M** params. Your numbers will differ.
81
+
82
+ ```
83
+ single M1 Max (MPS): 78.3 samples/s (global_batch=256, world=1)
84
+ DDP (MCCL): 134.2 samples/s (global_batch=256, world=2)
85
+ baseline / DDP: 0.58× (~172% DDP vs baseline)
86
+ ```
87
+
88
+ Tiny batches = comm noise dominates. Different chips on each rank = slowest one paces the step.
89
+
90
+ ```bash
91
+ python examples/ddp_dummy_train.py --baseline --save-stats baseline_stats.json
92
+ torchrun --nproc_per_node=2 --nnodes=1 --master_addr=127.0.0.1 --master_port=29500 \
93
+ examples/ddp_dummy_train.py --save-stats ddp_stats.json
94
+ python examples/benchmark_throughput.py --baseline baseline_stats.json --ddp ddp_stats.json -o bench
95
+ ```
96
+
97
+ `bash scripts/benchmark_matrix.sh` for more sweeps.
98
+
99
+ ![bench](bench.png)
100
+ ![bars](bench_bars.png)
101
+
102
+ ## PyPI (maintainers)
103
+
104
+ **CI (tests only):** push to **`main`** or **`master`**, or open a PR targeting those branches. That runs [`.github/workflows/ci.yml`](.github/workflows/ci.yml) — it does **not** upload to PyPI.
105
+
106
+ **Upload:** [`.github/workflows/publish.yml`](.github/workflows/publish.yml) runs on **GitHub Release (published)** or **Actions → Publish to PyPI → Run workflow** (`workflow_dispatch`).
107
+
108
+ 1. GitHub repo → **Settings → Secrets and variables → Actions** → New repository secret **`PYPI_API_TOKEN`** (PyPI → Account settings → API tokens).
109
+ 2. Bump **`version`** in `pyproject.toml`, `setup.py`, and the assertion in `tests/test_build.py`.
110
+ 3. Either: **Releases → Draft a new release** → publish (triggers upload), or **Actions** tab → **Publish to PyPI** → **Run workflow** → branch `main`.
111
+
112
+ First-time PyPI: create the **`mccl`** project on pypi.org (or change the package `name` everywhere if the name is taken).
113
+
114
+ ## Collectives
115
+
116
+ `allreduce`, `broadcast`, `barrier`, `allgather`, `reduce_scatter`, `send`, `recv`
117
+
118
+ ## Diagnostics
119
+
120
+ ```python
121
+ mccl.get_metrics(); mccl.log_metrics(); mccl.reset_metrics()
122
+ ```
123
+
124
+ Verbose startup: `MCCL_LOG_LEVEL=INFO`. Stuck multi-node: [docs/MULTINODE.md](docs/MULTINODE.md).
125
+
126
+ ## Transport
127
+
128
+ Bench plots were TCP over a Thunderbolt-style link, not RDMA. Wi‑Fi/Ethernet work, just slower. TB wiring: [docs/THUNDERBOLT_SETUP.md](docs/THUNDERBOLT_SETUP.md). RDMA path exists on TB5-capable hardware + `librdma.dylib`; `rdma_ctl enable` from Recovery once; we didn’t use that for the graphs above.
129
+
130
+ ## Internals
131
+
132
+ Apple Silicon is **UMA**: GPU and CPU share a physical memory pool. MPS tensors are usually **`MTLBuffer`s**; with **`MTLStorageModeShared`**, `buffer.contents` is a CPU pointer into the **same pages** the GPU uses (`extract_mps_buffer`, `MPSInterop.mm`). MCCL **exploits that** by staging sends from that pointer, writing receives with `memcpy` into it, and running **Accelerate/vDSP** in **`AccelerateOps.mm`** on the same bytes—no duplicate host tensor when the fast path applies. **Private** GPU storage still needs a **blit** through a shared staging buffer (`chunked_blit_*`).
133
+
134
+ I/O runs on a **queued worker** (`ProgressEngine`, `csrc/runtime/`). Before the worker reads or sends, **`commit_mps_and_signal` / `wait_for_mps`** (`EventSync.mm`) align CPU access with a finished PyTorch MPS command buffer via **`MTLSharedEvent`**; `MCCL_EVENT_SYNC=0` forces stream sync instead. `ProcessGroupMCCL.cpp` submits work into this pipeline. [docs/DEVELOPING.md](docs/DEVELOPING.md) covers collectives and transport.
135
+
136
+ ## License
137
+
138
+ MIT — [LICENSE](LICENSE)
@@ -0,0 +1,144 @@
1
+ /**
2
+ * MPS dispatch key registration for c10d collective ops.
3
+ *
4
+ * PyTorch (2.5+) routes all c10d ops through the c10 Dispatcher, keyed by the
5
+ * device type of the tensor arguments. Out of the box only CPU, CUDA, and
6
+ * PrivateUse1 keys are registered. This file adds the MPS key so that
7
+ * collectives on MPS tensors route to the MCCL backend.
8
+ *
9
+ * Each implementation is a thin shim that extracts the Backend for MPS from
10
+ * the ProcessGroup and forwards to the corresponding virtual method — the
11
+ * same pattern PyTorch uses internally for CPU/CUDA in Ops.cpp.
12
+ */
13
+
14
+ #include <torch/library.h>
15
+ #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
16
+ #include <torch/csrc/distributed/c10d/Types.hpp>
17
+
18
+ namespace mccl {
19
+ namespace {
20
+
21
+ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<c10d::Work>>
22
+ allreduce_MPS(
23
+ at::TensorList tensors,
24
+ const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
25
+ const c10::intrusive_ptr<c10d::ReduceOp>& reduce_op,
26
+ const std::optional<at::Tensor>& sparse_indices,
27
+ bool async_op,
28
+ int64_t timeout) {
29
+ auto tensor_vec = tensors.vec();
30
+ c10d::AllreduceOptions opts;
31
+ opts.reduceOp = *reduce_op;
32
+ opts.sparseIndices = sparse_indices;
33
+ opts.asyncOp = async_op;
34
+ if (timeout >= 0)
35
+ opts.timeout = std::chrono::milliseconds(timeout);
36
+ auto work =
37
+ process_group->getBackend(c10::DeviceType::MPS)->allreduce(tensor_vec, opts);
38
+ return {std::move(tensor_vec), work};
39
+ }
40
+
41
+ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<c10d::Work>>
42
+ broadcast_MPS(
43
+ at::TensorList tensors,
44
+ const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
45
+ int64_t root_rank,
46
+ int64_t root_tensor,
47
+ bool async_op,
48
+ int64_t timeout) {
49
+ auto tensor_vec = tensors.vec();
50
+ c10d::BroadcastOptions opts;
51
+ opts.rootRank = root_rank;
52
+ opts.rootTensor = root_tensor;
53
+ opts.asyncOp = async_op;
54
+ if (timeout >= 0)
55
+ opts.timeout = std::chrono::milliseconds(timeout);
56
+ auto work =
57
+ process_group->getBackend(c10::DeviceType::MPS)->broadcast(tensor_vec, opts);
58
+ return {std::move(tensor_vec), work};
59
+ }
60
+
61
+ c10::intrusive_ptr<c10d::Work> barrier_MPS(
62
+ at::Tensor /* tensor */,
63
+ const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
64
+ const std::vector<int64_t>& device_ids,
65
+ bool async_op,
66
+ int64_t timeout) {
67
+ c10d::BarrierOptions opts;
68
+ opts.device_ids = device_ids;
69
+ opts.asyncOp = async_op;
70
+ if (timeout >= 0)
71
+ opts.timeout = std::chrono::milliseconds(timeout);
72
+ return process_group->getBackend(c10::DeviceType::MPS)->barrier(opts);
73
+ }
74
+
75
+ std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<c10d::Work>>
76
+ allgather_MPS(
77
+ const std::vector<std::vector<at::Tensor>>& output_tensors,
78
+ at::TensorList input_tensors,
79
+ const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
80
+ bool async_op,
81
+ int64_t timeout) {
82
+ auto output_vec = output_tensors;
83
+ auto input_vec = input_tensors.vec();
84
+ c10d::AllgatherOptions opts;
85
+ opts.asyncOp = async_op;
86
+ if (timeout >= 0)
87
+ opts.timeout = std::chrono::milliseconds(timeout);
88
+ auto work = process_group->getBackend(c10::DeviceType::MPS)
89
+ ->allgather(output_vec, input_vec, opts);
90
+ return {std::move(output_vec), work};
91
+ }
92
+
93
+ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<c10d::Work>>
94
+ reduce_scatter_MPS(
95
+ const at::TensorList& output_tensors,
96
+ const std::vector<std::vector<at::Tensor>>& input_tensors,
97
+ const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
98
+ const c10::intrusive_ptr<c10d::ReduceOp>& reduce_op,
99
+ bool async_op,
100
+ int64_t timeout) {
101
+ auto output_vec = output_tensors.vec();
102
+ auto input_vec = input_tensors;
103
+ c10d::ReduceScatterOptions opts;
104
+ opts.reduceOp = *reduce_op;
105
+ opts.asyncOp = async_op;
106
+ if (timeout >= 0)
107
+ opts.timeout = std::chrono::milliseconds(timeout);
108
+ auto work = process_group->getBackend(c10::DeviceType::MPS)
109
+ ->reduce_scatter(output_vec, input_vec, opts);
110
+ return {std::move(output_vec), work};
111
+ }
112
+
113
+ c10::intrusive_ptr<c10d::Work> send_MPS(
114
+ at::TensorList tensors,
115
+ const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
116
+ int64_t dst_rank,
117
+ int64_t tag) {
118
+ auto tensor_vec = tensors.vec();
119
+ return process_group->getBackend(c10::DeviceType::MPS)
120
+ ->send(tensor_vec, static_cast<int>(dst_rank), static_cast<int>(tag));
121
+ }
122
+
123
+ c10::intrusive_ptr<c10d::Work> recv_MPS(
124
+ at::TensorList tensors,
125
+ const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
126
+ int64_t src_rank,
127
+ int64_t tag) {
128
+ auto tensor_vec = tensors.vec();
129
+ return process_group->getBackend(c10::DeviceType::MPS)
130
+ ->recv(tensor_vec, static_cast<int>(src_rank), static_cast<int>(tag));
131
+ }
132
+
133
// Register the shims above under the MPS dispatch key of the c10d library.
// The string op names (trailing underscore = the in-place/returning variant)
// must match the schemas PyTorch declares for these ops — see the Ops.cpp
// pattern referenced in the file header. barrier and send have no
// underscore variant.
TORCH_LIBRARY_IMPL(c10d, MPS, m) {
  m.impl("allreduce_", allreduce_MPS);
  m.impl("broadcast_", broadcast_MPS);
  m.impl("barrier", barrier_MPS);
  m.impl("allgather_", allgather_MPS);
  m.impl("reduce_scatter_", reduce_scatter_MPS);
  m.impl("send", send_MPS);
  m.impl("recv_", recv_MPS);
}
142
+
143
+ } // anonymous namespace
144
+ } // namespace mccl
@@ -0,0 +1,47 @@
1
+ #pragma once
2
+
3
+ #include <torch/torch.h>
4
+ #include <c10d/ProcessGroup.hpp>
5
+ #include <chrono>
6
+ #include <string>
7
+
8
+ #include "compression/Compression.hpp"
9
+
10
+ namespace mccl {
11
+
12
+ struct MCCLOptions : public c10d::Backend::Options {
13
+ MCCLOptions()
14
+ : c10d::Backend::Options("mccl", std::chrono::milliseconds(30000)) {}
15
+
16
+ explicit MCCLOptions(std::chrono::milliseconds timeout)
17
+ : c10d::Backend::Options("mccl", timeout) {}
18
+
19
+ // Transport
20
+ std::string transport = "auto"; // "auto", "tcp", "rdma"
21
+ std::string listen_addr = "0.0.0.0";
22
+ uint16_t port_base = 29600;
23
+ std::string ifname;
24
+ size_t chunk_bytes = 4 * 1024 * 1024;
25
+ size_t small_msg_threshold = 65536;
26
+ bool transport_crc = false;
27
+
28
+ // Compute
29
+ bool fast_math = true;
30
+ uint32_t gpu_threshold = 4096;
31
+ bool overlap_comm = true;
32
+
33
+ // Engine
34
+ size_t max_queue_depth = 1024;
35
+
36
+ // Compression
37
+ CompressionMode compression = CompressionMode::NONE;
38
+ double topk_ratio = 0.01;
39
+
40
+ // Watchdog
41
+ std::chrono::milliseconds watchdog_timeout{300000};
42
+
43
+ // Health monitor
44
+ std::chrono::milliseconds heartbeat_interval{5000};
45
+ };
46
+
47
+ } // namespace mccl