torch-chamfer-dist 0.1.1__tar.gz

This diff shows the contents of publicly available package versions as they appear in their respective public registries; it is provided for informational purposes only.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Janos
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,2 @@
1
+ include README.md LICENSE
2
+ recursive-include chamfer/src *
@@ -0,0 +1,87 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-chamfer-dist
3
+ Version: 0.1.1
4
+ Summary: Chamfer distance with Metal/MPS acceleration (macOS)
5
+ Author: Janos
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/Janos95/chamfer
8
+ Requires-Python: >=3.9
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: torch>=2.1
12
+ Requires-Dist: nanobind>=2.0
13
+ Dynamic: license-file
14
+
15
+ # torch-chamfer-dist
16
+
17
+ `torch-chamfer-dist` provides a fast Chamfer distance implementation for PyTorch. On macOS with
18
+ Metal/MPS it runs kd-tree nearest-neighbour queries directly on the GPU; elsewhere it falls back to
19
+ an optimized CPU kd-tree. Autograd support is built in.
20
+
21
+ ## Installation
22
+
23
+ ```bash
24
+ pip install torch-chamfer-dist
25
+ ```
26
+
27
+ The provided wheel targets macOS 13+ (arm64 and x86_64). On other platforms the CPU backend is
28
+ selected automatically.
29
+
30
+ ## Quick start
31
+
32
+ ```python
33
+ import torch
34
+ import chamfer
35
+
36
+ # Create two point clouds on the desired device ("mps" for Metal, "cpu" otherwise)
37
+ a = torch.rand(5_000, 3, device="mps")
38
+ b = torch.rand(5_000, 3, device="mps")
39
+
40
+ # Nearest neighbours via kd-tree
41
+ idx, dist_sq = chamfer.closest_points(a, b)  # indices into b and squared distances
42
+
43
+ # Chamfer distance with gradients
44
+ loss = chamfer.chamfer_distance(a, b)
45
+ loss.backward()
46
+ ```
47
+
48
+ The device of the inputs determines the backend. When both tensors live on MPS the Metal kernel is
49
+ used; otherwise a CPU kd-tree path runs. Gradients are computed on the same device without host
50
+ roundtrips.
51
+
52
+ ## Benchmarks
53
+
54
+ The repository ships a benchmark script comparing brute-force, CPU kd-tree, and Metal kd-tree
55
+ implementations. Example (20k points per cloud on an M2 Pro):
56
+
57
+ ```
58
+ Method | Forward | Backward
59
+ ------------+-------------------+------------------
60
+ Brute force | 0.885 s | 1.829 s
61
+ KD-tree CPU | 0.139 s (6.39x) | 0.269 s (6.79x)
62
+ KD-tree MPS | 0.008 s (115.31x) | 0.012 s (147.63x)
63
+ ```
64
+
65
+ Run the benchmark locally:
66
+
67
+ ```bash
68
+ PYTHONPATH=. python benchmarks/benchmark_chamfer.py --n 20000 --chunk 4096 --repeat 3
69
+ ```
70
+
71
+ Set `CHAMFER_PROFILE=1` to emit per-stage timings (tree build, kernel wait, etc.).
72
+
73
+ ## Development
74
+
75
+ - Install dependencies: `pip install torch nanobind pytest build`.
76
+ - Run tests: `python -m pytest`.
77
+ - Build wheel: `python -m build`.
78
+
79
+ ### Publishing to PyPI
80
+
81
+ ```bash
82
+ python -m pip install --upgrade build twine
83
+ python -m build
84
+ python -m twine upload dist/*
85
+ ```
86
+
87
+ Remember to bump the version in `pyproject.toml` before tagging and uploading a release.
@@ -0,0 +1,73 @@
1
+ # torch-chamfer-dist
2
+
3
+ `torch-chamfer-dist` provides a fast Chamfer distance implementation for PyTorch. On macOS with
4
+ Metal/MPS it runs kd-tree nearest-neighbour queries directly on the GPU; elsewhere it falls back to
5
+ an optimized CPU kd-tree. Autograd support is built in.
6
+
7
+ ## Installation
8
+
9
+ ```bash
10
+ pip install torch-chamfer-dist
11
+ ```
12
+
13
+ The provided wheel targets macOS 13+ (arm64 and x86_64). On other platforms the CPU backend is
14
+ selected automatically.
15
+
16
+ ## Quick start
17
+
18
+ ```python
19
+ import torch
20
+ import chamfer
21
+
22
+ # Create two point clouds on the desired device ("mps" for Metal, "cpu" otherwise)
23
+ a = torch.rand(5_000, 3, device="mps")
24
+ b = torch.rand(5_000, 3, device="mps")
25
+
26
+ # Nearest neighbours via kd-tree
27
+ idx, dist_sq = chamfer.closest_points(a, b)  # indices into b and squared distances
28
+
29
+ # Chamfer distance with gradients
30
+ loss = chamfer.chamfer_distance(a, b)
31
+ loss.backward()
32
+ ```
33
+
34
+ The device of the inputs determines the backend. When both tensors live on MPS the Metal kernel is
35
+ used; otherwise a CPU kd-tree path runs. Gradients are computed on the same device without host
36
+ roundtrips.
37
+
38
+ ## Benchmarks
39
+
40
+ The repository ships a benchmark script comparing brute-force, CPU kd-tree, and Metal kd-tree
41
+ implementations. Example (20k points per cloud on an M2 Pro):
42
+
43
+ ```
44
+ Method | Forward | Backward
45
+ ------------+-------------------+------------------
46
+ Brute force | 0.885 s | 1.829 s
47
+ KD-tree CPU | 0.139 s (6.39x) | 0.269 s (6.79x)
48
+ KD-tree MPS | 0.008 s (115.31x) | 0.012 s (147.63x)
49
+ ```
50
+
51
+ Run the benchmark locally:
52
+
53
+ ```bash
54
+ PYTHONPATH=. python benchmarks/benchmark_chamfer.py --n 20000 --chunk 4096 --repeat 3
55
+ ```
56
+
57
+ Set `CHAMFER_PROFILE=1` to emit per-stage timings (tree build, kernel wait, etc.).
58
+
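+ For example, prefixing the benchmark invocation with the variable prints the per-stage timings
+ from the Metal path to stderr:
+
+ ```bash
+ CHAMFER_PROFILE=1 PYTHONPATH=. python benchmarks/benchmark_chamfer.py --n 20000 --chunk 4096 --repeat 3
+ ```
+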
59
+ ## Development
60
+
61
+ - Install dependencies: `pip install torch nanobind pytest build`.
62
+ - Run tests: `python -m pytest`.
63
+ - Build wheel: `python -m build`.
64
+
65
+ ### Publishing to PyPI
66
+
67
+ ```bash
68
+ python -m pip install --upgrade build twine
69
+ python -m build
70
+ python -m twine upload dist/*
71
+ ```
72
+
73
+ Remember to bump the version in `pyproject.toml` before tagging and uploading a release.
@@ -0,0 +1,240 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import site
5
+ from pathlib import Path
6
+ from typing import Tuple
7
+
8
+ import torch
9
+
10
+ __all__ = ["closest_points", "chamfer_distance"]
11
+
12
+ _EXTENSION = None
13
+
14
+
15
+ def _extension() -> object:
16
+ global _EXTENSION
17
+ if _EXTENSION is not None:
18
+ return _EXTENSION
19
+
20
+ try:
21
+ import chamfer_ext # type: ignore
22
+ except ImportError:
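+ # No prebuilt chamfer_ext module was found: JIT-compile it from the bundled sources via
+ # torch.utils.cpp_extension.load (macOS toolchain required, since the bridge links Metal).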
23
+ src_dir = Path(__file__).resolve().parent / "src"
24
+ if not src_dir.exists():
25
+ raise RuntimeError(
26
+ "chamfer_ext extension not built. Install from wheel or run setup.py."
27
+ ) from None
28
+
29
+ from torch.utils.cpp_extension import load
30
+ import nanobind
31
+
32
+ nanobind_root = Path(nanobind.__file__).resolve().parent
33
+ nb_combined = nanobind_root / "src" / "nb_combined.cpp"
34
+
35
+ sources = [
36
+ src_dir / "metal_bridge.mm",
37
+ src_dir / "kd_tree.cpp",
38
+ nb_combined,
39
+ ]
40
+ include_dirs = [
41
+ str(src_dir),
42
+ str(nanobind_root / "include"),
43
+ str(nanobind_root / "ext" / "robin_map" / "include"),
44
+ ]
45
+
46
+ os.environ.setdefault("MACOSX_DEPLOYMENT_TARGET", "13.0")
47
+ user_bin = Path(site.getuserbase()) / "bin"
48
+ if user_bin.exists():
49
+ current_path = os.environ.get("PATH", "")
50
+ if str(user_bin) not in current_path.split(os.pathsep):
51
+ os.environ["PATH"] = os.pathsep.join(
52
+ [str(user_bin)] + ([current_path] if current_path else [])
53
+ )
54
+
55
+ extra_cflags = ["-std=c++20", "-fobjc-arc", "-fvisibility=hidden"]
56
+ extra_ldflags = ["-framework", "Metal", "-framework", "Foundation"]
57
+
58
+ chamfer_ext = load(
59
+ name="chamfer_ext",
60
+ sources=[str(path) for path in sources if path.exists()],
61
+ extra_include_paths=include_dirs,
62
+ extra_cflags=extra_cflags,
63
+ extra_ldflags=extra_ldflags,
64
+ verbose=False,
65
+ )
66
+
67
+ _EXTENSION = chamfer_ext
68
+ return _EXTENSION
69
+
70
+
71
+ def _mps_available() -> bool:
72
+ return bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
73
+
74
+
75
+ def _validate_pair(query: torch.Tensor, reference: torch.Tensor) -> None:
76
+ if query.dim() != 2:
77
+ raise ValueError("query tensor must be 2D [N, K]")
78
+ if reference.dim() != 2:
79
+ raise ValueError("reference tensor must be 2D [M, K]")
80
+ if query.size(1) != reference.size(1):
81
+ raise ValueError("query and reference tensors must have matching feature dimensions")
82
+
83
+ def _require_device(tensor: torch.Tensor, device: str, name: str) -> None:
84
+ if tensor.device.type != device:
85
+ raise ValueError(f"{name} tensor must live on {device}, but found {tensor.device.type}")
86
+
87
+
88
+ def _require_float32(tensor: torch.Tensor, name: str) -> None:
89
+ if tensor.dtype != torch.float32:
90
+ raise ValueError(f"{name} tensor must be float32, but found {tensor.dtype}")
91
+
92
+
93
+ def _prepare_backend_tensors(
94
+ query: torch.Tensor, reference: torch.Tensor, *, is_mps: bool
95
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
96
+ device = "mps" if is_mps else "cpu"
97
+ _require_device(query, device, "query")
98
+ _require_device(reference, device, "reference")
99
+ _require_float32(query, "query")
100
+ _require_float32(reference, "reference")
101
+ return query.contiguous(), reference.contiguous()
102
+
103
+
104
+ def _decide_backend(
105
+ query: torch.Tensor, reference: torch.Tensor, use_mps: bool | None
106
+ ) -> bool:
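+ # An explicit use_mps flag wins (and is validated against the input devices); otherwise the
+ # devices of the inputs pick the backend, as described in the README.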
107
+ mps_available = _mps_available()
108
+ inputs_on_mps = query.device.type == "mps" and reference.device.type == "mps"
109
+ inputs_on_cpu = query.device.type == "cpu" and reference.device.type == "cpu"
110
+
111
+ if use_mps is True:
112
+ if not mps_available:
113
+ raise RuntimeError("MPS was requested, but torch.backends.mps.is_available() is False")
114
+ if not inputs_on_mps:
115
+ raise ValueError("MPS execution requires both tensors to be on the mps device")
116
+ return True
117
+
118
+ if use_mps is False:
119
+ if not inputs_on_cpu:
120
+ raise ValueError("CPU execution requires both tensors to be on the cpu device")
121
+ return False
122
+
123
+ if inputs_on_mps:
124
+ if not mps_available:
125
+ raise RuntimeError("Input tensors are on MPS, but the MPS backend is unavailable")
126
+ return True
127
+
128
+ if inputs_on_cpu:
129
+ return False
130
+
131
+ raise ValueError("query and reference must both reside on either CPU or MPS device")
132
+
133
+
134
+ def closest_points(
135
+ query: torch.Tensor,
136
+ reference: torch.Tensor,
137
+ *,
138
+ use_mps: bool | None = None,
139
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
140
+ """Return (indices, squared distances) of nearest neighbours in *reference* for each query point.
141
+
142
+ The kd-tree is always built on the CPU; with the MPS backend it is traversed on the GPU via Metal, otherwise the traversal also runs on the CPU.
143
+ """
144
+
145
+ _validate_pair(query, reference)
146
+ use_mps_flag = _decide_backend(query, reference, use_mps)
147
+ query_prepped, reference_prepped = _prepare_backend_tensors(query, reference, is_mps=use_mps_flag)
148
+ ext = _extension()
149
+ if use_mps_flag:
150
+ return ext.kd_query(query_prepped, reference_prepped)
151
+ if not hasattr(ext, "kd_query_cpu"):
152
+ raise RuntimeError("CPU kd-tree query is not available in the compiled extension")
153
+ return ext.kd_query_cpu(query_prepped, reference_prepped)
154
+
155
+
156
+ class _ChamferDistanceFunction(torch.autograd.Function):
157
+ @staticmethod
158
+ def forward(ctx, a: torch.Tensor, b: torch.Tensor, use_mps_flag: bool | None = None) -> torch.Tensor:
159
+ if a.device != b.device:
160
+ raise ValueError("points_a and points_b must be on the same device")
161
+ if a.device.type not in {"cpu", "mps"}:
+ raise ValueError("Unsupported device for chamfer_distance")
162
+
163
+ _validate_pair(a, b)
164
+ backend_is_mps = _decide_backend(a, b, use_mps_flag)
165
+ a_prepped, b_prepped = _prepare_backend_tensors(a, b, is_mps=backend_is_mps)
166
+
167
+ idx_ab_tensor, _ = closest_points(a_prepped, b_prepped, use_mps=backend_is_mps)
168
+ idx_ba_tensor, _ = closest_points(b_prepped, a_prepped, use_mps=backend_is_mps)
169
+
170
+ idx_ab = idx_ab_tensor.to(device=b_prepped.device, dtype=torch.long)
171
+ idx_ba = idx_ba_tensor.to(device=a_prepped.device, dtype=torch.long)
172
+
173
+ nn_ab = torch.index_select(b_prepped, 0, idx_ab)
174
+ nn_ba = torch.index_select(a_prepped, 0, idx_ba)
175
+
176
+ diff_ab = a_prepped - nn_ab
177
+ diff_ba = b_prepped - nn_ba
178
+
179
+ loss_ab = torch.sum(diff_ab * diff_ab, dim=1).mean()
180
+ loss_ba = torch.sum(diff_ba * diff_ba, dim=1).mean()
181
+ loss = loss_ab + loss_ba
182
+
183
+ ctx.save_for_backward(
184
+ a_prepped,
185
+ b_prepped,
186
+ idx_ab_tensor.to(torch.long),
187
+ idx_ba_tensor.to(torch.long),
188
+ )
189
+ ctx.sizes = (a_prepped.shape[0], b_prepped.shape[0])
190
+
191
+ return loss
192
+
193
+ @staticmethod
194
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
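+ # With the nearest-neighbour indices treated as constants, the a->b term contributes
+ # (2 / N_a) * (a_i - b_nn(i)) to grad_a, while the b->a term scatters
+ # -(2 / N_b) * (b_j - a_nn(j)) into grad_a via index_add_. This is why both diff tensors
+ # are built whenever either input needs a gradient.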
195
+ a, b, idx_ab_saved, idx_ba_saved = ctx.saved_tensors
196
+ n_a, n_b = ctx.sizes
197
+
198
+ grad_a = grad_b = None
199
+ scalar_a = grad_output.to(device=a.device, dtype=a.dtype)
200
+ scalar_b = grad_output.to(device=b.device, dtype=b.dtype)
201
+
202
+ # All tensors are either on CPU or MPS; keep computations there.
203
+ assert a.device == b.device == idx_ab_saved.device == idx_ba_saved.device
204
+
205
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
206
+ idx_ab = idx_ab_saved.to(device=b.device)
207
+ nn_ab = torch.index_select(b, 0, idx_ab)
208
+ diff_ab = a - nn_ab
209
+
210
+ coeff_ab = (2.0 / float(n_a)) * scalar_a
211
+
212
+ if ctx.needs_input_grad[1] or ctx.needs_input_grad[0]:
213
+ idx_ba = idx_ba_saved.to(device=a.device)
214
+ nn_ba = torch.index_select(a, 0, idx_ba)
215
+ diff_ba = b - nn_ba
216
+
217
+ coeff_ba = (2.0 / float(n_b)) * scalar_b
218
+
219
+ if ctx.needs_input_grad[0]:
220
+ grad_a = coeff_ab * diff_ab
221
+ grad_a = grad_a.contiguous()
222
+ scatter_idx = idx_ba_saved
223
+ grad_a.index_add_(0, scatter_idx, (-coeff_ba) * diff_ba)
224
+
225
+ if ctx.needs_input_grad[1]:
226
+ grad_b = coeff_ba * diff_ba
227
+ grad_b = grad_b.contiguous()
228
+ scatter_idx = idx_ab_saved
229
+ grad_b.index_add_(0, scatter_idx, (-coeff_ab) * diff_ab)
230
+
231
+ return grad_a, grad_b, None
232
+
233
+
234
+ def chamfer_distance(
235
+ points_a: torch.Tensor,
236
+ points_b: torch.Tensor,
237
+ *,
238
+ use_mps: bool | None = None,
239
+ ) -> torch.Tensor:
240
+ return _ChamferDistanceFunction.apply(points_a, points_b, use_mps)
@@ -0,0 +1,86 @@
1
+ #include "kd_tree.hpp"
2
+
3
+ #include <algorithm>
4
+ #include <atomic>
5
+ #include <functional>
6
+ #include <future>
7
+ #include <numeric>
8
+ #include <stdexcept>
9
+
10
+ namespace chamfer {
11
+
12
+ std::vector<KDNodeGPU> build_kd_tree(const float* points, int64_t num_points, int64_t dims) {
13
+ if (num_points <= 0) {
14
+ throw std::invalid_argument("build_kd_tree: num_points must be positive");
15
+ }
16
+ if (dims <= 0) {
17
+ throw std::invalid_argument("build_kd_tree: dims must be positive");
18
+ }
19
+
20
+ std::vector<int> order(num_points);
21
+ std::iota(order.begin(), order.end(), 0);
22
+
23
+ std::vector<KDNodeGPU> gpu_nodes(static_cast<size_t>(num_points));
24
+ std::atomic<int> next_index{0};
25
+
26
+ const int dims_int = static_cast<int>(dims);
27
+ const int max_parallel_depth = 2;
28
+ const int parallel_threshold = 2048;
29
+
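+ // Recursive median-split construction: nth_element partially orders the index range around
+ // the median along the current axis, that median becomes the node's point, and the two halves
+ // become its children. The first few recursion levels run in parallel via std::async.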
30
+ std::function<int(int, int, int)> build = [&](int start, int end, int depth) -> int {
31
+ if (start >= end) {
32
+ return -1;
33
+ }
34
+
35
+ int axis = depth % dims_int;
36
+ int mid = (start + end) / 2;
37
+
38
+ auto comparator = [points, dims_int, axis](int lhs, int rhs) {
39
+ float l = points[static_cast<int64_t>(lhs) * dims_int + axis];
40
+ float r = points[static_cast<int64_t>(rhs) * dims_int + axis];
41
+ if (l == r) {
42
+ return lhs < rhs;
43
+ }
44
+ return l < r;
45
+ };
46
+
47
+ std::nth_element(order.begin() + start, order.begin() + mid, order.begin() + end, comparator);
48
+
49
+ int current = next_index.fetch_add(1, std::memory_order_relaxed);
50
+ KDNodeGPU& node = gpu_nodes[static_cast<size_t>(current)];
51
+ node.point_index = order[mid];
52
+ node.split_dim = axis;
53
+ node.split_value = points[static_cast<int64_t>(node.point_index) * dims_int + axis];
54
+ node.pad0 = 0.0f;
55
+ node.pad1 = 0.0f;
56
+ node.pad2 = 0.0f;
57
+
58
+ const bool parallel = depth < max_parallel_depth && (end - start) > parallel_threshold;
59
+
60
+ int left_index;
61
+ int right_index;
62
+ if (parallel) {
63
+ auto future_left = std::async(std::launch::async, [&]() {
64
+ return build(start, mid, depth + 1);
65
+ });
66
+ right_index = build(mid + 1, end, depth + 1);
67
+ left_index = future_left.get();
68
+ } else {
69
+ left_index = build(start, mid, depth + 1);
70
+ right_index = build(mid + 1, end, depth + 1);
71
+ }
72
+
73
+ node.left = left_index;
74
+ node.right = right_index;
75
+ return current;
76
+ };
77
+
78
+ int root_index = build(0, static_cast<int>(num_points), 0);
79
+ (void)root_index;
80
+
81
+ gpu_nodes.resize(static_cast<size_t>(next_index.load(std::memory_order_relaxed)));
82
+
83
+ return gpu_nodes;
84
+ }
85
+
86
+ } // namespace chamfer
@@ -0,0 +1,22 @@
1
+ #pragma once
2
+
3
+ #include <vector>
4
+ #include <cstddef>
+ #include <cstdint>
5
+
6
+ namespace chamfer {
7
+
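+ // Mirrors the KDNode struct in the Metal kernel source; the three float pads keep the node
+ // at 32 bytes so the CPU-built array can be uploaded to the GPU buffer as-is.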
8
+ struct KDNodeGPU {
9
+ int left;
10
+ int right;
11
+ int point_index;
12
+ int split_dim;
13
+ float split_value;
14
+ float pad0;
15
+ float pad1;
16
+ float pad2;
17
+ };
18
+
19
+ std::vector<KDNodeGPU> build_kd_tree(const float* points, int64_t num_points, int64_t dims);
20
+
21
+ }  // namespace chamfer
22
+
@@ -0,0 +1,503 @@
1
+ #import <Foundation/Foundation.h>
2
+ #import <Metal/Metal.h>
3
+
4
+ #include <nanobind/nanobind.h>
5
+ #include <torch/extension.h>
6
+ #include <torch/csrc/autograd/python_variable.h>
7
+ #include <ATen/mps/MPSStream.h>
8
+
9
+ #include <algorithm>
10
+ #include <cstring>
11
+ #include <limits>
12
+ #include <mutex>
13
+ #include <stdexcept>
14
+ #include <string>
15
+ #include <vector>
16
+ #include <mach/mach_time.h>
17
+
18
+ #include "kd_tree.hpp"
19
+
20
+ namespace nb = nanobind;
21
+
22
+ namespace {
23
+
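+ // On the MPS backend a tensor's storage data pointer is the backing id<MTLBuffer>, so it can be
+ // bound to the compute encoder directly; the tensor's storage offset is applied when binding.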
24
+ inline id<MTLBuffer> tensor_to_mtl_buffer(const at::Tensor& tensor) {
25
+ return (__bridge id<MTLBuffer>)(tensor.storage().data());
26
+ }
27
+
28
+ struct TimebaseInfo {
29
+ uint64_t numer = 0;
30
+ uint64_t denom = 0;
31
+ TimebaseInfo() {
32
+ mach_timebase_info_data_t info;
33
+ mach_timebase_info(&info);
34
+ numer = info.numer;
35
+ denom = info.denom;
36
+ }
37
+ double to_millis(uint64_t delta) const {
38
+ double nanoseconds = static_cast<double>(delta) * static_cast<double>(numer) / static_cast<double>(denom);
39
+ return nanoseconds / 1e6;
40
+ }
41
+ };
42
+
43
+ const TimebaseInfo& timebase() {
44
+ static TimebaseInfo info;
45
+ return info;
46
+ }
47
+
48
+ bool should_profile() {
49
+ static bool initialized = false;
50
+ static bool enabled = false;
51
+ if (!initialized) {
52
+ const char* env = std::getenv("CHAMFER_PROFILE");
53
+ enabled = env && std::strlen(env) > 0;
54
+ initialized = true;
55
+ }
56
+ return enabled;
57
+ }
58
+
59
+ struct ScopedTimer {
60
+ const TimebaseInfo& info;
61
+ uint64_t start;
62
+ std::string label;
63
+ bool enabled;
64
+ ScopedTimer(const TimebaseInfo& info, std::string lbl, bool en)
65
+ : info(info), start(en ? mach_absolute_time() : 0), label(std::move(lbl)), enabled(en) {}
66
+ ~ScopedTimer() {
67
+ if (enabled) {
68
+ uint64_t end = mach_absolute_time();
69
+ double ms = info.to_millis(end - start);
70
+ fprintf(stderr, "[chamfer] %s: %.3f ms\n", label.c_str(), ms);
71
+ }
72
+ }
73
+ };
74
+
75
+ constexpr const char* kMetalSource = R"(using namespace metal;
76
+
77
+ struct KDNode {
78
+ int left;
79
+ int right;
80
+ int point_index;
81
+ int split_dim;
82
+ float split_value;
83
+ float pad0;
84
+ float pad1;
85
+ float pad2;
86
+ };
87
+
88
+ inline float distance_squared(const device float* a,
89
+ const device float* b,
90
+ int dims) {
91
+ float acc = 0.0f;
92
+ for (int i = 0; i < dims; ++i) {
93
+ float diff = a[i] - b[i];
94
+ acc += diff * diff;
95
+ }
96
+ return acc;
97
+ }
98
+
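+ // One thread per query point: walk the tree with a bounded explicit stack and skip the far
+ // subtree whenever the squared distance to the splitting plane already exceeds the current
+ // best squared distance.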
99
+ kernel void kd_query(device const float* ref_points [[buffer(0)]],
100
+ device const KDNode* nodes [[buffer(1)]],
101
+ constant int& num_nodes [[buffer(2)]],
102
+ constant int& dims [[buffer(3)]],
103
+ device const float* queries [[buffer(4)]],
104
+ constant int& num_queries [[buffer(5)]],
105
+ device int* out_indices [[buffer(6)]],
106
+ device float* out_distances [[buffer(7)]],
107
+ uint gid [[thread_position_in_grid]]) {
108
+ if (gid >= static_cast<uint>(num_queries)) {
109
+ return;
110
+ }
111
+
112
+ constexpr int STACK_CAP = 128;
113
+ int stack[STACK_CAP];
114
+ int stack_size = 0;
115
+
116
+ if (num_nodes > 0) {
117
+ stack[stack_size++] = 0;
118
+ }
119
+
120
+ device const float* query = queries + static_cast<size_t>(gid) * static_cast<size_t>(dims);
121
+
122
+ float best_dist = INFINITY;
123
+ int best_index = -1;
124
+
125
+ while (stack_size > 0) {
126
+ int node_idx = stack[--stack_size];
127
+ if (node_idx < 0 || node_idx >= num_nodes) {
128
+ continue;
129
+ }
130
+
131
+ KDNode node = nodes[node_idx];
132
+ int point_idx = node.point_index;
133
+ device const float* point = ref_points + static_cast<size_t>(point_idx) * static_cast<size_t>(dims);
134
+
135
+ float dist = distance_squared(query, point, dims);
136
+ if (dist < best_dist) {
137
+ best_dist = dist;
138
+ best_index = point_idx;
139
+ }
140
+
141
+ int left = node.left;
142
+ int right = node.right;
143
+ if (left < 0 && right < 0) {
144
+ continue;
145
+ }
146
+
147
+ float diff = query[node.split_dim] - node.split_value;
148
+ int near_child = diff <= 0.0f ? left : right;
149
+ int far_child = diff <= 0.0f ? right : left;
150
+
151
+ if (far_child >= 0 && stack_size < STACK_CAP && diff * diff < best_dist) {
152
+ stack[stack_size++] = far_child;
153
+ }
154
+ if (near_child >= 0 && stack_size < STACK_CAP) {
155
+ stack[stack_size++] = near_child;
156
+ }
157
+ }
158
+
159
+ if (best_index < 0) {
160
+ best_dist = 0.0f;
161
+ }
162
+
163
+ out_indices[gid] = best_index;
164
+ out_distances[gid] = best_dist;
165
+ }
166
+ )";
167
+
168
+ struct MetalContext {
169
+ id<MTLDevice> device = nil;
170
+ id<MTLCommandQueue> queue = nil;
171
+ id<MTLLibrary> library = nil;
172
+ id<MTLComputePipelineState> pipeline = nil;
173
+ bool initialized = false;
174
+ bool attempted = false;
175
+ std::string error_message;
176
+ };
177
+
178
+ MetalContext& get_context() {
179
+ static MetalContext ctx;
180
+ return ctx;
181
+ }
182
+
183
+ void initialize_metal_once() {
184
+ auto& ctx = get_context();
185
+ static std::once_flag once_flag;
186
+ std::call_once(once_flag, [&ctx]() {
187
+ ctx.attempted = true;
188
+ ctx.device = MTLCreateSystemDefaultDevice();
189
+ if (!ctx.device) {
190
+ ctx.error_message = "No Metal-capable device available for MPS";
191
+ return;
192
+ }
193
+ ctx.queue = [ctx.device newCommandQueue];
194
+ if (!ctx.queue) {
195
+ ctx.error_message = "Failed to create Metal command queue";
196
+ return;
197
+ }
198
+
199
+ NSError* error = nil;
200
+ NSString* source = [[NSString alloc] initWithUTF8String:kMetalSource];
201
+ MTLCompileOptions* options = [[MTLCompileOptions alloc] init];
202
+ options.fastMathEnabled = YES;
203
+
204
+ ctx.library = [ctx.device newLibraryWithSource:source options:options error:&error];
205
+ if (!ctx.library) {
206
+ std::string message = "Failed to compile Metal library: ";
207
+ if (error) {
208
+ message += [[error localizedDescription] UTF8String];
209
+ }
210
+ ctx.error_message = message;
211
+ return;
212
+ }
213
+
214
+ id<MTLFunction> function = [ctx.library newFunctionWithName:@"kd_query"];
215
+ if (!function) {
216
+ ctx.error_message = "Failed to load kd_query function from Metal library";
217
+ return;
218
+ }
219
+
220
+ ctx.pipeline = [ctx.device newComputePipelineStateWithFunction:function error:&error];
221
+ if (!ctx.pipeline) {
222
+ std::string message = "Failed to create pipeline state: ";
223
+ if (error) {
224
+ message += [[error localizedDescription] UTF8String];
225
+ }
226
+ ctx.error_message = message;
227
+ return;
228
+ }
229
+
230
+ ctx.initialized = true;
231
+ });
232
+ }
233
+
234
+ void ensure_initialized() {
235
+ initialize_metal_once();
236
+ auto& ctx = get_context();
237
+
238
+ if (!ctx.initialized) {
239
+ if (!ctx.error_message.empty()) {
240
+ throw std::runtime_error(ctx.error_message);
241
+ }
242
+ throw std::runtime_error("Metal context failed to initialize");
243
+ }
244
+ }
245
+
246
+ const at::Tensor& tensor_from_nb(nb::handle h) {
247
+ if (!THPVariable_Check(h.ptr())) {
248
+ throw nb::type_error("expected a torch.Tensor");
249
+ }
250
+ return THPVariable_Unpack(h.ptr());
251
+ }
252
+
253
+ nb::tuple kd_tree_query(nb::handle query_handle, nb::handle reference_handle) {
254
+ torch::NoGradGuard guard;
255
+
256
+ const bool profile = should_profile();
257
+ const TimebaseInfo& tinfo = timebase();
258
+ ScopedTimer total_timer(tinfo, "kd_query_total", profile);
259
+
260
+ const at::Tensor& query_in = tensor_from_nb(query_handle);
261
+ const at::Tensor& reference_in = tensor_from_nb(reference_handle);
262
+
263
+ if (query_in.dim() != 2) {
264
+ throw std::invalid_argument("query tensor must be 2D [N, K]");
265
+ }
266
+ if (reference_in.dim() != 2) {
267
+ throw std::invalid_argument("reference tensor must be 2D [M, K]");
268
+ }
269
+ if (query_in.size(1) != reference_in.size(1)) {
270
+ throw std::invalid_argument("query and reference tensors must have the same dimensionality");
271
+ }
272
+
273
+ if (!query_in.device().is_mps() || !reference_in.device().is_mps()) {
274
+ throw std::invalid_argument("kd_query expects query and reference tensors on MPS device");
275
+ }
276
+ if (query_in.scalar_type() != at::kFloat || reference_in.scalar_type() != at::kFloat) {
277
+ throw std::invalid_argument("kd_query expects float32 tensors");
278
+ }
279
+
280
+ int64_t dims = query_in.size(1);
281
+ int64_t num_query = query_in.size(0);
282
+ int64_t num_reference = reference_in.size(0);
283
+
284
+ if (num_reference == 0) {
285
+ throw std::invalid_argument("reference set must contain at least one point");
286
+ }
287
+
288
+ at::Tensor query_mps = query_in.contiguous();
289
+ at::Tensor reference_mps = reference_in.contiguous();
290
+
291
+ at::mps::getCurrentMPSStream()->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
292
+
293
+ ensure_initialized();
294
+ auto& ctx = get_context();
295
+
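+ // The tree is built on the host, so the reference cloud is copied to the CPU once; the query
+ // points and the results stay on the GPU.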
296
+ at::Tensor reference_cpu;
297
+ {
298
+ ScopedTimer cpu_copy_timer(tinfo, "kd_query_copy_to_cpu", profile);
299
+ reference_cpu = reference_mps.to(at::kCPU).contiguous();
300
+ }
301
+
302
+ std::vector<chamfer::KDNodeGPU> nodes;
303
+ {
304
+ ScopedTimer build_timer(tinfo, "kd_tree_build", profile);
305
+ nodes = chamfer::build_kd_tree(reference_cpu.data_ptr<float>(), num_reference, dims);
306
+ }
307
+ if (nodes.empty()) {
308
+ throw std::runtime_error("Failed to build kd-tree");
309
+ }
310
+
311
+ NSUInteger node_bytes = static_cast<NSUInteger>(nodes.size() * sizeof(chamfer::KDNodeGPU));
312
+ id<MTLBuffer> node_buffer = [ctx.device newBufferWithBytes:nodes.data()
313
+ length:node_bytes
314
+ options:MTLResourceStorageModeShared];
315
+ if (!node_buffer) {
316
+ throw std::runtime_error("Failed to allocate node buffers");
317
+ }
318
+
319
+ auto indices_tensor = torch::empty({num_query}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kMPS));
320
+ auto distances_tensor = torch::empty({num_query}, torch::TensorOptions().dtype(torch::kFloat).device(torch::kMPS));
321
+
322
+ id<MTLBuffer> points_buffer = tensor_to_mtl_buffer(reference_mps);
323
+ id<MTLBuffer> query_buffer = tensor_to_mtl_buffer(query_mps);
324
+ id<MTLBuffer> indices_buffer = tensor_to_mtl_buffer(indices_tensor);
325
+ id<MTLBuffer> distances_buffer = tensor_to_mtl_buffer(distances_tensor);
326
+
327
+ if (!points_buffer || !query_buffer || !node_buffer || !indices_buffer || !distances_buffer) {
328
+ throw std::runtime_error("Failed to allocate Metal buffers");
329
+ }
330
+
331
+ id<MTLCommandBuffer> command_buffer = [ctx.queue commandBuffer];
332
+ if (!command_buffer) {
333
+ throw std::runtime_error("Failed to create Metal command buffer");
334
+ }
335
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
336
+ [encoder setComputePipelineState:ctx.pipeline];
337
+
338
+ int num_nodes = static_cast<int>(nodes.size());
339
+ int dims_i = static_cast<int>(dims);
340
+ int num_query_i = static_cast<int>(num_query);
341
+
342
+ NSUInteger points_offset = static_cast<NSUInteger>(reference_mps.storage_offset() * reference_mps.element_size());
343
+ NSUInteger query_offset = static_cast<NSUInteger>(query_mps.storage_offset() * query_mps.element_size());
344
+ NSUInteger indices_offset = static_cast<NSUInteger>(indices_tensor.storage_offset() * indices_tensor.element_size());
345
+ NSUInteger distances_offset = static_cast<NSUInteger>(distances_tensor.storage_offset() * distances_tensor.element_size());
346
+
347
+ [encoder setBuffer:points_buffer offset:points_offset atIndex:0];
348
+ [encoder setBuffer:node_buffer offset:0 atIndex:1];
349
+ [encoder setBytes:&num_nodes length:sizeof(int) atIndex:2];
350
+ [encoder setBytes:&dims_i length:sizeof(int) atIndex:3];
351
+ [encoder setBuffer:query_buffer offset:query_offset atIndex:4];
352
+ [encoder setBytes:&num_query_i length:sizeof(int) atIndex:5];
353
+ [encoder setBuffer:indices_buffer offset:indices_offset atIndex:6];
354
+ [encoder setBuffer:distances_buffer offset:distances_offset atIndex:7];
355
+
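+ // Launch one thread per query, rounded up to whole threadgroups; the kernel's gid bounds
+ // check discards the padding threads.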
356
+ NSUInteger max_threads = ctx.pipeline.maxTotalThreadsPerThreadgroup;
357
+ if (max_threads == 0) {
358
+ max_threads = 64;
359
+ }
360
+ NSUInteger threadgroup_size = std::min<NSUInteger>(max_threads, 256);
361
+ MTLSize threads_per_threadgroup = MTLSizeMake(threadgroup_size, 1, 1);
362
+ NSUInteger grid_threads = static_cast<NSUInteger>(num_query);
363
+ NSUInteger groups = (grid_threads + threadgroup_size - 1) / threadgroup_size;
364
+ MTLSize threads_per_grid = MTLSizeMake(groups * threadgroup_size, 1, 1);
365
+ {
366
+ ScopedTimer dispatch_timer(tinfo, "kd_query_dispatch", profile);
367
+ [encoder dispatchThreads:threads_per_grid threadsPerThreadgroup:threads_per_threadgroup];
368
+ [encoder endEncoding];
369
+ [command_buffer commit];
370
+ }
371
+
372
+ {
373
+ ScopedTimer wait_timer(tinfo, "kd_query_wait", profile);
374
+ [command_buffer waitUntilCompleted];
375
+ }
376
+
377
+ PyObject* indices_obj = THPVariable_Wrap(indices_tensor);
378
+ PyObject* distances_obj = THPVariable_Wrap(distances_tensor);
379
+
380
+ return nb::make_tuple(nb::steal<nb::object>(indices_obj), nb::steal<nb::object>(distances_obj));
381
+ }
382
+
383
+ nb::tuple kd_tree_query_cpu(nb::handle query_handle, nb::handle reference_handle) {
384
+ torch::NoGradGuard guard;
385
+
386
+ const at::Tensor& query_in = tensor_from_nb(query_handle);
387
+ const at::Tensor& reference_in = tensor_from_nb(reference_handle);
388
+
389
+ if (query_in.dim() != 2) {
390
+ throw std::invalid_argument("query tensor must be 2D [N, K]");
391
+ }
392
+ if (reference_in.dim() != 2) {
393
+ throw std::invalid_argument("reference tensor must be 2D [M, K]");
394
+ }
395
+ if (query_in.size(1) != reference_in.size(1)) {
396
+ throw std::invalid_argument("query and reference tensors must have the same dimensionality");
397
+ }
398
+
399
+ int64_t dims = query_in.size(1);
400
+ int64_t num_query = query_in.size(0);
401
+ int64_t num_reference = reference_in.size(0);
402
+
403
+ if (num_reference == 0) {
404
+ throw std::invalid_argument("reference set must contain at least one point");
405
+ }
406
+
407
+ at::Tensor query_cpu = query_in;
408
+ if (!query_cpu.device().is_cpu() || query_cpu.scalar_type() != at::kFloat || !query_cpu.is_contiguous()) {
409
+ query_cpu = query_in.to(at::kCPU, at::kFloat).contiguous();
410
+ }
411
+
412
+ at::Tensor reference_cpu = reference_in;
413
+ if (!reference_cpu.device().is_cpu() || reference_cpu.scalar_type() != at::kFloat || !reference_cpu.is_contiguous()) {
414
+ reference_cpu = reference_in.to(at::kCPU, at::kFloat).contiguous();
415
+ }
416
+
417
+ auto nodes = chamfer::build_kd_tree(reference_cpu.data_ptr<float>(), num_reference, dims);
418
+ if (nodes.empty()) {
419
+ throw std::runtime_error("Failed to build kd-tree");
420
+ }
421
+
422
+ auto indices_tensor = torch::empty({num_query}, torch::dtype(torch::kInt32).device(torch::kCPU));
423
+ auto distances_tensor = torch::empty({num_query}, torch::dtype(torch::kFloat).device(torch::kCPU));
424
+
425
+ const float* query_ptr = query_cpu.data_ptr<float>();
426
+ const float* reference_ptr = reference_cpu.data_ptr<float>();
427
+ int32_t* index_ptr = indices_tensor.data_ptr<int32_t>();
428
+ float* distance_ptr = distances_tensor.data_ptr<float>();
429
+
430
+ std::vector<int> stack;
431
+ stack.reserve(64);
432
+
433
+ for (int64_t qi = 0; qi < num_query; ++qi) {
434
+ const float* query = query_ptr + qi * dims;
435
+ float best_dist = std::numeric_limits<float>::infinity();
436
+ int best_index = -1;
437
+
438
+ stack.clear();
439
+ if (!nodes.empty()) {
440
+ stack.push_back(0);
441
+ }
442
+
443
+ while (!stack.empty()) {
444
+ int node_idx = stack.back();
445
+ stack.pop_back();
446
+ if (node_idx < 0 || node_idx >= static_cast<int>(nodes.size())) {
447
+ continue;
448
+ }
449
+
450
+ const auto& node = nodes[node_idx];
451
+ int point_idx = node.point_index;
452
+ const float* point = reference_ptr + static_cast<int64_t>(point_idx) * dims;
453
+
454
+ float dist = 0.0f;
455
+ for (int64_t d = 0; d < dims; ++d) {
456
+ float diff = query[d] - point[d];
457
+ dist += diff * diff;
458
+ }
459
+
460
+ if (dist < best_dist) {
461
+ best_dist = dist;
462
+ best_index = point_idx;
463
+ }
464
+
465
+ int left = node.left;
466
+ int right = node.right;
467
+ if (left < 0 && right < 0) {
468
+ continue;
469
+ }
470
+
471
+ float diff = query[node.split_dim] - node.split_value;
472
+ int near_child = diff <= 0.0f ? left : right;
473
+ int far_child = diff <= 0.0f ? right : left;
474
+
475
+ if (far_child >= 0 && diff * diff < best_dist) {
476
+ stack.push_back(far_child);
477
+ }
478
+ if (near_child >= 0) {
479
+ stack.push_back(near_child);
480
+ }
481
+ }
482
+
483
+ if (best_index < 0) {
484
+ best_dist = 0.0f;
485
+ best_index = 0;
486
+ }
487
+
488
+ index_ptr[qi] = best_index;
489
+ distance_ptr[qi] = best_dist;
490
+ }
491
+
492
+ PyObject* indices_obj = THPVariable_Wrap(indices_tensor);
493
+ PyObject* distances_obj = THPVariable_Wrap(distances_tensor);
494
+
495
+ return nb::make_tuple(nb::steal<nb::object>(indices_obj), nb::steal<nb::object>(distances_obj));
496
+ }
497
+
498
+ } // namespace
499
+
500
+ NB_MODULE(chamfer_ext, m) {
501
+ m.def("kd_query", &kd_tree_query, "KD-tree nearest neighbour query using Metal");
502
+ m.def("kd_query_cpu", &kd_tree_query_cpu, "KD-tree nearest neighbour query on CPU");
503
+ }
@@ -0,0 +1,24 @@
1
+ [build-system]
2
+ requires = ["setuptools>=69", "wheel", "torch", "nanobind"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "torch-chamfer-dist"
7
+ version = "0.1.1"
8
+ description = "Chamfer distance with Metal/MPS acceleration (macOS)"
9
+ readme = "README.md"
10
+ license = "MIT"
11
+ requires-python = ">=3.9"
12
+ authors = [{name = "Janos"}]
13
+ dependencies = ["torch>=2.1", "nanobind>=2.0"]
14
+
15
+ [project.urls]
16
+ Homepage = "https://github.com/Janos95/chamfer"
17
+
18
+ [tool.setuptools]
19
+ packages = ["chamfer"]
20
+ package-dir = {"chamfer" = "chamfer"}
21
+ include-package-data = true
22
+
23
+ [tool.setuptools.package-data]
24
+ chamfer = ["src/*.mm", "src/*.cpp", "src/*.hpp"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ from setuptools import Extension, setup
8
+ from setuptools.command.build_ext import build_ext
9
+
10
+
11
+ os.environ.setdefault("MACOSX_DEPLOYMENT_TARGET", "13.0")
12
+
13
+
14
+ def gather_include_dirs() -> list[str]:
15
+ import torch
16
+ import nanobind
17
+
18
+ includes = []
19
+ try:
20
+ from torch.utils.cpp_extension import include_paths
21
+ except ImportError:
22
+ torch_dir = Path(torch.__file__).resolve().parent
23
+ includes.append(str(torch_dir / "include"))
24
+ includes.append(str(torch_dir / "include" / "torch" / "csrc" / "api" / "include"))
25
+ else:
26
+ includes.extend(include_paths())
27
+
28
+ nb_root = Path(nanobind.__file__).resolve().parent
29
+ includes.append(str(nb_root / "include"))
30
+ includes.append(str(nb_root / "ext" / "robin_map" / "include"))
31
+ return includes
32
+
33
+
34
+ def gather_extra_sources() -> list[str]:
35
+ import nanobind
36
+
37
+ nb_root = Path(nanobind.__file__).resolve().parent
38
+ nb_combined = nb_root / "src" / "nb_combined.cpp"
39
+ if nb_combined.exists():
40
+ return [str(nb_combined)]
41
+ return []
42
+
43
+
44
+ class TorchBuildExt(build_ext):
45
+ def build_extensions(self) -> None:
46
+ include_dirs = gather_include_dirs()
47
+ extra_sources = gather_extra_sources()
48
+ compiler = self.compiler
49
+ if ".mm" not in compiler.src_extensions:
50
+ compiler.src_extensions.append(".mm")
51
+ compiler.language_map[".mm"] = "objc++"
52
+ for ext in self.extensions:
53
+ ext.include_dirs.extend(include_dirs)
54
+ ext.sources.extend(extra_sources)
55
+ super().build_extensions()
56
+
57
+
58
+ def make_extension() -> Extension:
59
+ return Extension(
60
+ "chamfer_ext",
61
+ sources=[
62
+ "chamfer/src/metal_bridge.mm",
63
+ "chamfer/src/kd_tree.cpp",
64
+ ],
65
+ extra_compile_args=["-std=c++20", "-fobjc-arc", "-fvisibility=hidden"],
66
+ extra_link_args=["-framework", "Metal", "-framework", "Foundation"],
67
+ language="c++",
68
+ )
69
+
70
+
71
+ IS_BUILDING_SDIST = "sdist" in sys.argv
72
+ extensions = [] if IS_BUILDING_SDIST else [make_extension()]
73
+
74
+ setup(
75
+ cmdclass={"build_ext": TorchBuildExt},
76
+ ext_modules=extensions,
77
+ )
@@ -0,0 +1,109 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import pytest
4
+ import torch
5
+
6
+ import chamfer
7
+
8
+
9
+ def brute_force_closest(query: torch.Tensor, reference: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
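+ # O(N*M) reference implementation used as ground truth for the kd-tree backends.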
10
+ query = query.to(torch.float32).contiguous()
11
+ reference = reference.to(torch.float32).contiguous()
12
+ diff = query[:, None, :] - reference[None, :, :]
13
+ dists = torch.sum(diff * diff, dim=-1)
14
+ min_dists, indices = torch.min(dists, dim=1)
15
+ return indices.to(torch.int32), min_dists
16
+
17
+
18
+ def brute_chamfer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
19
+ _, d1 = brute_force_closest(a, b)
20
+ _, d2 = brute_force_closest(b, a)
21
+ return d1.mean() + d2.mean()
22
+
23
+
24
+ @pytest.mark.parametrize("dims", [2, 3, 5])
25
+ @pytest.mark.parametrize("use_mps", [None, False])
26
+ def test_closest_points_matches_bruteforce(dims: int, use_mps: Optional[bool]) -> None:
27
+ torch.manual_seed(42 + dims)
28
+ query = torch.rand(128, dims)
29
+ reference = torch.rand(200, dims)
30
+
31
+ idx_gpu, dist_gpu = chamfer.closest_points(query, reference, use_mps=use_mps)
32
+ idx_cpu, dist_cpu = brute_force_closest(query, reference)
33
+
34
+ assert idx_gpu.shape == idx_cpu.shape == (query.size(0),)
35
+ assert torch.all(idx_gpu >= 0)
36
+ torch.testing.assert_close(dist_gpu, dist_cpu, atol=1e-5, rtol=1e-4)
37
+
38
+
39
+ @pytest.mark.parametrize("dims", [2, 3])
40
+ def test_chamfer_distance_matches_bruteforce(dims: int) -> None:
41
+ torch.manual_seed(123 + dims)
42
+ a = torch.rand(64, dims)
43
+ b = torch.rand(96, dims)
44
+
45
+ chamfer_gpu = chamfer.chamfer_distance(a, b, use_mps=False)
46
+ chamfer_cpu = brute_chamfer(a, b)
47
+
48
+ torch.testing.assert_close(chamfer_gpu, chamfer_cpu, atol=1e-5, rtol=1e-4)
49
+
50
+
51
+ @pytest.mark.parametrize("dims", [2, 3])
52
+ def test_grad_disabled_by_default(dims: int) -> None:
53
+ a = torch.randn(8, dims, requires_grad=False)
54
+ b = torch.randn(12, dims, requires_grad=False)
55
+
56
+ _, dists = chamfer.closest_points(a, b)
57
+ assert not dists.requires_grad
58
+
59
+
60
+ @pytest.mark.parametrize("dims", [2, 3])
61
+ def test_chamfer_distance_gradients_match_bruteforce(dims: int) -> None:
62
+ torch.manual_seed(321 + dims)
63
+ a = torch.rand(32, dims, requires_grad=True)
64
+ b = torch.rand(40, dims, requires_grad=True)
65
+
66
+ loss_kd = chamfer.chamfer_distance(a, b, use_mps=False)
67
+ grad_a_kd, grad_b_kd = torch.autograd.grad(loss_kd, (a, b), create_graph=False)
68
+
69
+ a_ref = a.detach().clone().requires_grad_(True)
70
+ b_ref = b.detach().clone().requires_grad_(True)
71
+ loss_brute = brute_chamfer(a_ref, b_ref)
72
+ grad_a_brute, grad_b_brute = torch.autograd.grad(loss_brute, (a_ref, b_ref), create_graph=False)
73
+
74
+ torch.testing.assert_close(grad_a_kd, grad_a_brute, atol=1e-4, rtol=1e-4)
75
+ torch.testing.assert_close(grad_b_kd, grad_b_brute, atol=1e-4, rtol=1e-4)
76
+
77
+
78
+ def test_closest_points_use_mps_requires_device() -> None:
79
+ query = torch.rand(16, 3)
80
+ reference = torch.rand(32, 3)
81
+ expected_exception = ValueError if torch.backends.mps.is_available() else RuntimeError
82
+ with pytest.raises(expected_exception):
83
+ chamfer.closest_points(query, reference, use_mps=True)
84
+
85
+
86
+ @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS backend unavailable")
87
+ def test_closest_points_mps_matches_cpu() -> None:
88
+ torch.manual_seed(7)
89
+ query_cpu = torch.rand(64, 3)
90
+ reference_cpu = torch.rand(96, 3)
91
+
92
+ query_mps = query_cpu.to("mps")
93
+ reference_mps = reference_cpu.to("mps")
94
+
95
+ idx_mps, dist_mps = chamfer.closest_points(query_mps, reference_mps, use_mps=True)
96
+
97
+ assert idx_mps.device.type == "mps"
98
+ assert dist_mps.device.type == "mps"
99
+
100
+ _, dist_cpu = brute_force_closest(query_cpu, reference_cpu)
101
+ torch.testing.assert_close(dist_mps.cpu(), dist_cpu, atol=1e-5, rtol=1e-4)
102
+
103
+
104
+ @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS backend unavailable")
105
+ def test_closest_points_mps_rejects_cpu_inputs() -> None:
106
+ query_mps = torch.rand(8, 3, device="mps")
107
+ reference_cpu = torch.rand(8, 3)
108
+ with pytest.raises(ValueError):
109
+ chamfer.closest_points(query_mps, reference_cpu, use_mps=True)
@@ -0,0 +1,87 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-chamfer-dist
3
+ Version: 0.1.1
4
+ Summary: Chamfer distance with Metal/MPS acceleration (macOS)
5
+ Author: Janos
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/Janos95/chamfer
8
+ Requires-Python: >=3.9
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: torch>=2.1
12
+ Requires-Dist: nanobind>=2.0
13
+ Dynamic: license-file
14
+
15
+ # torch-chamfer-dist
16
+
17
+ `torch-chamfer-dist` provides a fast Chamfer distance implementation for PyTorch. On macOS with
18
+ Metal/MPS it runs kd-tree nearest-neighbour queries directly on the GPU; elsewhere it falls back to
19
+ an optimized CPU kd-tree. Autograd support is built in.
20
+
21
+ ## Installation
22
+
23
+ ```bash
24
+ pip install torch-chamfer-dist
25
+ ```
26
+
27
+ The provided wheel targets macOS 13+ (arm64 and x86_64). On other platforms the CPU backend is
28
+ selected automatically.
29
+
30
+ ## Quick start
31
+
32
+ ```python
33
+ import torch
34
+ import chamfer
35
+
36
+ # Create two point clouds on the desired device ("mps" for Metal, "cpu" otherwise)
37
+ a = torch.rand(5_000, 3, device="mps")
38
+ b = torch.rand(5_000, 3, device="mps")
39
+
40
+ # Nearest neighbours via kd-tree
41
+ idx, dist_sq = chamfer.closest_points(a, b)  # indices into b and squared distances
42
+
43
+ # Chamfer distance with gradients
44
+ loss = chamfer.chamfer_distance(a, b)
45
+ loss.backward()
46
+ ```
47
+
48
+ The device of the inputs determines the backend. When both tensors live on MPS the Metal kernel is
49
+ used; otherwise a CPU kd-tree path runs. Gradients are computed on the same device without host
50
+ roundtrips.
51
+
52
+ ## Benchmarks
53
+
54
+ The repository ships a benchmark script comparing brute-force, CPU kd-tree, and Metal kd-tree
55
+ implementations. Example (20k points per cloud on an M2 Pro):
56
+
57
+ ```
58
+ Method | Forward | Backward
59
+ ------------+-------------------+------------------
60
+ Brute force | 0.885 s | 1.829 s
61
+ KD-tree CPU | 0.139 s (6.39x) | 0.269 s (6.79x)
62
+ KD-tree MPS | 0.008 s (115.31x) | 0.012 s (147.63x)
63
+ ```
64
+
65
+ Run the benchmark locally:
66
+
67
+ ```bash
68
+ PYTHONPATH=. python benchmarks/benchmark_chamfer.py --n 20000 --chunk 4096 --repeat 3
69
+ ```
70
+
71
+ Set `CHAMFER_PROFILE=1` to emit per-stage timings (tree build, kernel wait, etc.).
72
+
73
+ ## Development
74
+
75
+ - Install dependencies: `pip install torch nanobind pytest build`.
76
+ - Run tests: `python -m pytest`.
77
+ - Build wheel: `python -m build`.
78
+
79
+ ### Publishing to PyPI
80
+
81
+ ```bash
82
+ python -m pip install --upgrade build twine
83
+ python -m build
84
+ python -m twine upload dist/*
85
+ ```
86
+
87
+ Remember to bump the version in `pyproject.toml` before tagging and uploading a release.
@@ -0,0 +1,15 @@
1
+ LICENSE
2
+ MANIFEST.in
3
+ README.md
4
+ pyproject.toml
5
+ setup.py
6
+ chamfer/__init__.py
7
+ chamfer/src/kd_tree.cpp
8
+ chamfer/src/kd_tree.hpp
9
+ chamfer/src/metal_bridge.mm
10
+ tests/test_chamfer.py
11
+ torch_chamfer_dist.egg-info/PKG-INFO
12
+ torch_chamfer_dist.egg-info/SOURCES.txt
13
+ torch_chamfer_dist.egg-info/dependency_links.txt
14
+ torch_chamfer_dist.egg-info/requires.txt
15
+ torch_chamfer_dist.egg-info/top_level.txt
@@ -0,0 +1,2 @@
1
+ torch>=2.1
2
+ nanobind>=2.0