torchdtw 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchdtw
3
+ Version: 0.0.1
4
+ Summary: Add your description here
5
+ Author: Maxime Poli
6
+ Author-email: CoML <dev@cognitive-ml.fr>
7
+ License-Expression: MIT
8
+ Keywords: machine learning
9
+ Requires-Python: >=3.12
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: numpy>=1.26.4
12
+ Requires-Dist: torch>=2.9.0
13
+
14
+ # PyTorch DTW C++ extension
15
+
@@ -0,0 +1,2 @@
1
+ # PyTorch DTW C++ extension
2
+
@@ -0,0 +1,68 @@
1
+ [build-system]
2
+ requires = [
3
+ "setuptools>=80.9.0",
4
+ "torch>=2.9.0",
5
+ "numpy>=1.26.4",
6
+ "ninja>=1.11",
7
+ ]
8
+ build-backend = "setuptools.build_meta"
9
+
10
+ [project]
11
+ name = "torchdtw"
12
+ version = "0.0.1"
13
+ description = "DTW implementation using PyTorch C++ extensions, with CPU and CUDA backends"
14
+ readme = "README.md"
15
+ requires-python = ">=3.12"
16
+ authors = [
17
+ { name = "Maxime Poli" },
18
+ { name = "CoML", email = "dev@cognitive-ml.fr" },
19
+ ]
20
+ keywords = ["machine learning"]
21
+ license = "MIT"
22
+ dependencies = [
23
+ "numpy>=1.26.4",
24
+ "torch>=2.9.0",
25
+ ]
26
+
27
+ [dependency-groups]
28
+ dev = [
29
+ "ipykernel>=7.1.0",
30
+ "ruff>=0.14.2",
31
+ "typos>=1.36.2",
32
+ ]
33
+ test = [
34
+ "hypothesis>=6.142.5",
35
+ "pytest>=8.4.2",
36
+ ]
37
+
38
+ [tool.ruff]
39
+ line-length = 119
40
+
41
+ [tool.ruff.lint]
42
+ select = ["ALL"]
43
+ ignore = [
44
+ "COM812", # missing-trailing-comma
45
+ "D105", # undocumented-magic-method
46
+ "D107", # undocumented-public-init
47
+ "D203", # incorrect-blank-line-before-class
48
+ "D213", # multi-line-summary-second-line
49
+ "N803", # invalid-argument-name
50
+ "N806", # non-lowercase-variable-in-function
51
+ "PLR0913", # too-many-arguments
52
+ ]
53
+
54
+ [tool.ruff.lint.pylint]
55
+ allow-magic-value-types = ["int"]
56
+
57
+ [tool.ruff.lint.flake8-self]
58
+ ignore-names = ["_check"]
59
+
60
+ [tool.ruff.lint.pep8-naming]
61
+ ignore-names = ["F"]
62
+
63
+ [tool.uv.workspace]
64
+ members = [
65
+ "benchmark",
66
+ "torchdtw",
67
+ ]
68
+
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,58 @@
1
+ """Build the DTW PyTorch C++ extension."""
2
+
3
+ import os
4
+ import sys
5
+
6
+ import torch
7
+ from setuptools import Extension, setup
8
+ from torch.torch_version import Version
9
+ from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CppExtension, CUDAExtension
10
+
11
+
12
def get_openmp_flags() -> tuple[list[str], list[str]]:
    """Return the ``(compile_flags, link_flags)`` pair enabling OpenMP on the current platform."""
    platform = sys.platform
    if platform == "linux":
        return ["-fopenmp"], ["-fopenmp"]
    if platform == "win32":
        return ["-openmp"], []
    # On macOS (and any other platform), we use the OpenMP version vendored by PyTorch: no extra flags.
    return [], []
22
+
23
+
24
def get_cuda_arch_list() -> str:
    """Return the ``;``-separated CUDA architectures to build for.

    Volta is only listed when ``torch.version.cuda`` is unset or older than 13.0, since it is not supported
    by CUDA 13.0.
    """
    cuda_version = torch.version.cuda
    volta_supported = cuda_version is None or Version(cuda_version) < Version("13.0")
    return "Volta;Turing;Ampere;Ada;Hopper" if volta_supported else "Turing;Ampere;Ada;Hopper"
29
+
30
+
31
def get_extension() -> Extension:
    """Describe the ``torchdtw._C`` extension, either as a CUDA or as a CPU-only build.

    A ``CUDAExtension`` (with the additional ``.cu`` source) is used when ``CUDA_HOME`` is set, and a
    ``CppExtension`` otherwise. Both are compiled with ``-O3``, the OpenMP flags of the current platform and
    ``py_limited_api=True``.

    :return: The extension object to pass to ``setup(ext_modules=...)``.
    """
    use_cuda = CUDA_HOME is not None
    compile_flags, link_flags = get_openmp_flags()
    sources = ["src/torchdtw/csrc/dtw.cpp"]
    if use_cuda:
        # Only provide a default: an architecture list explicitly exported by the person running the build
        # must not be silently overwritten.
        if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
            os.environ["TORCH_CUDA_ARCH_LIST"] = get_cuda_arch_list()
        sources.append("src/torchdtw/csrc/cuda/dtw.cu")
    extension = CUDAExtension if use_cuda else CppExtension
    return extension(
        "torchdtw._C",
        sources,
        extra_compile_args={
            "cxx": ["-fdiagnostics-color=always", "-O3", *compile_flags],
            "nvcc": ["-O3"],
        },
        extra_link_args=link_flags,
        py_limited_api=True,
    )
51
+
52
+
53
if __name__ == "__main__":
    # The wheel is tagged ``cp312`` for the limited API, in line with ``py_limited_api=True`` on the extension.
    wheel_options = {"py_limited_api": "cp312"}
    setup(
        cmdclass={"build_ext": BuildExtension},
        ext_modules=[get_extension()],
        options={"bdist_wheel": wheel_options},
    )
@@ -0,0 +1,49 @@
1
+ """DTW implementation using PyTorch C++ extensions, with CPU and CUDA backends."""
2
+
3
+ import torch
4
+
5
+ from . import _C # noqa: F401 # ty: ignore[unresolved-import]
6
+
7
+ __all__ = ["dtw", "dtw_batch"]
8
+
9
+
10
def dtw(distances: torch.Tensor) -> torch.Tensor:
    """Run the registered ``torchdtw::dtw`` operator on a matrix of pairwise distances.

    :param distances: A 2D tensor of shape (n, m) representing the pairwise distances between two sequences.
    :return: The tensor produced by the ``torchdtw::dtw`` operator.
    """
    op = torch.ops.torchdtw.dtw.default
    return op(distances)
16
+
17
+
18
def dtw_batch(distances: torch.Tensor, sx: torch.Tensor, sy: torch.Tensor, *, symmetric: bool) -> torch.Tensor:
    """Run the registered ``torchdtw::dtw_batch`` operator on a batch of pairwise distance matrices.

    :param distances: A 4D tensor of shape (n1, n2, s1, s2) representing the pairwise distances between two
        batches of sequences.
    :param sx: A 1D tensor of shape (n1,) representing the lengths of the sequences in the first batch.
    :param sy: A 1D tensor of shape (n2,) representing the lengths of the sequences in the second batch.
    :param symmetric: Whether or not the DTW is symmetric (i.e., the two batches are the same).
    :return: The tensor produced by the ``torchdtw::dtw_batch`` operator.
    """
    op = torch.ops.torchdtw.dtw_batch.default
    return op(distances, sx, sy, symmetric)
28
+
29
+
30
@torch.library.register_fake("torchdtw::dtw")
def _(distances: torch.Tensor) -> torch.Tensor:
    """FakeTensor kernel of ``torchdtw::dtw``: validate the input and describe the 0-dim float32 output.

    Registered so that the operator is compatible with torch.compile.
    """
    is_matrix = distances.ndim == 2
    torch._check(is_matrix)
    is_float32 = distances.dtype == torch.float32
    torch._check(is_float32)
    scalar_shape: tuple[int, ...] = ()
    return torch.empty(scalar_shape, device=distances.device, layout=distances.layout, dtype=torch.float32)
36
+
37
+
38
@torch.library.register_fake("torchdtw::dtw_batch")
def _(distances: torch.Tensor, sx: torch.Tensor, sy: torch.Tensor, symmetric: bool) -> torch.Tensor:  # noqa: FBT001
    """FakeTensor kernel of ``torchdtw::dtw_batch``: validate the inputs and describe the (n1, n2) output.

    Registered so that the operator is compatible with torch.compile.
    """
    # Ranks first, then dtypes, in the order distances / sx / sy.
    for tensor, rank in ((distances, 4), (sx, 1), (sy, 1)):
        torch._check(tensor.ndim == rank)
    for tensor, dtype in ((distances, torch.float32), (sx, torch.long), (sy, torch.long)):
        torch._check(tensor.dtype == dtype)
    torch._check(isinstance(symmetric, bool))
    out_shape = (distances.shape[0], distances.shape[1])
    return torch.empty(out_shape, device=distances.device, layout=distances.layout, dtype=torch.float32)
@@ -0,0 +1,140 @@
1
+ #include <Python.h>
2
+ #include <omp.h>
3
+ #include <torch/csrc/stable/library.h>
4
+ #include <torch/csrc/stable/ops.h>
5
+ #include <torch/csrc/stable/tensor.h>
6
+ #include <torch/headeronly/util/Exception.h>
7
+ #include <algorithm>
8
+ #include <vector>
9
+
10
+ extern "C" {
11
+ /* Creates a dummy empty _C module that can be imported from Python.
12
+ The import from Python will load the .so consisting of this file
13
+ in this extension, so that the STABLE_TORCH_LIBRARY static initializers
14
+ below are run. */
15
+ PyObject* PyInit__C(void) {
16
+ static struct PyModuleDef module_def = {
17
+ PyModuleDef_HEAD_INIT,
18
+ "_C", /* name of module */
19
+ NULL, /* module documentation, may be NULL */
20
+ -1, /* size of per-interpreter state of the module,
21
+ or -1 if the module keeps state in global variables. */
22
+ NULL, /* methods */
23
+ };
24
+ return PyModule_Create(&module_def);
25
+ }
26
+ }
27
+
28
+ namespace torchdtw {
29
+
30
+ using torch::stable::Tensor;
31
+
32
+ inline float dtw(
33
+ const float* distances,
34
+ const int64_t N,
35
+ const int64_t M,
36
+ const int64_t stride_x,
37
+ const int64_t stride_y) {
38
+ STD_TORCH_CHECK(N > 0 && M > 0, "Empty input tensor");
39
+ STD_TORCH_CHECK(stride_x > 0 && stride_y > 0, "Strides must be positive");
40
+ std::vector<float> cost(N * M);
41
+
42
+ cost[0] = distances[0];
43
+ for (int64_t i = 1; i < N; i++) {
44
+ cost[i * M] = distances[i * stride_x] + cost[(i - 1) * M];
45
+ }
46
+ for (int64_t j = 1; j < M; j++) {
47
+ cost[j] = distances[j * stride_y] + cost[j - 1];
48
+ }
49
+ for (int64_t i = 1; i < N; i++) {
50
+ for (int64_t j = 1; j < M; j++) {
51
+ cost[i * M + j] = distances[i * stride_x + j * stride_y] +
52
+ std::min({cost[(i - 1) * M + j], cost[(i - 1) * M + j - 1], cost[i * M + j - 1]});
53
+ }
54
+ }
55
+
56
+ int64_t path_len = 1;
57
+ int64_t i = N - 1;
58
+ int64_t j = M - 1;
59
+ while (i > 0 && j > 0) {
60
+ const float c_up = cost[(i - 1) * M + j];
61
+ const float c_left = cost[i * M + j - 1];
62
+ const float c_diag = cost[(i - 1) * M + j - 1];
63
+ if (c_diag <= c_left && c_diag <= c_up) {
64
+ i--;
65
+ j--;
66
+ } else if (c_left <= c_up) {
67
+ j--;
68
+ } else {
69
+ i--;
70
+ }
71
+ path_len++;
72
+ }
73
+ if (i == 0)
74
+ path_len += j;
75
+ if (j == 0)
76
+ path_len += i;
77
+ return cost[(N - 1) * M + M - 1] / path_len;
78
+ }
79
+
80
+ Tensor dtw_cpu(const Tensor distances) {
81
+ float result =
82
+ dtw(reinterpret_cast<const float*>(distances.data_ptr()),
83
+ distances.size(0),
84
+ distances.size(1),
85
+ distances.stride(0),
86
+ distances.stride(1));
87
+ Tensor out = torch::stable::new_empty(distances, {});
88
+ torch::stable::fill_(out, result);
89
+ return out;
90
+ }
91
+
92
+ Tensor dtw_batch_cpu(const Tensor distances, const Tensor sx, const Tensor sy, bool symmetric) {
93
+ const int64_t nx = distances.size(0);
94
+ const int64_t ny = distances.size(1);
95
+ Tensor out = torch::stable::new_zeros(distances, {nx, ny});
96
+
97
+ const float* distances_ptr = reinterpret_cast<const float*>(distances.data_ptr());
98
+ const int64_t* sx_ptr = reinterpret_cast<const int64_t*>(sx.data_ptr());
99
+ const int64_t* sy_ptr = reinterpret_cast<const int64_t*>(sy.data_ptr());
100
+ float* out_ptr = reinterpret_cast<float*>(out.data_ptr());
101
+
102
+ #pragma omp parallel for schedule(dynamic)
103
+ for (int64_t i = 0; i < nx; i++) {
104
+ const int64_t start_j = symmetric ? i : 0;
105
+ for (int64_t j = start_j; j < ny; j++) {
106
+ if (symmetric && i == j)
107
+ continue;
108
+ out_ptr[i * ny + j] =
109
+ dtw(distances_ptr + i * distances.stride(0) + j * distances.stride(1),
110
+ sx_ptr[i],
111
+ sy_ptr[j],
112
+ distances.stride(2),
113
+ distances.stride(3));
114
+ if (symmetric && i != j) {
115
+ out_ptr[j * ny + i] = out_ptr[i * ny + j];
116
+ }
117
+ }
118
+ };
119
+ return out;
120
+ }
121
+
122
+ void boxed_dtw_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
123
+ stack[0] = from(dtw_cpu(to<Tensor>(stack[0])));
124
+ }
125
+
126
+ void boxed_dtw_batch_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
127
+ stack[0] = from(dtw_batch_cpu(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]), to<bool>(stack[3])));
128
+ }
129
+
130
// Declare the schemas of the two operators exposed by the ``torchdtw`` library.
STABLE_TORCH_LIBRARY(torchdtw, m) {
  m.def("dtw(Tensor distances) -> Tensor");
  m.def("dtw_batch(Tensor distances, Tensor sx, Tensor sy, bool symmetric) -> Tensor");
}
134
+
135
// Bind the boxed CPU kernels to the ``torchdtw`` operators for the CPU dispatch key.
STABLE_TORCH_LIBRARY_IMPL(torchdtw, CPU, m) {
  m.impl("dtw", &boxed_dtw_cpu);
  m.impl("dtw_batch", &boxed_dtw_batch_cpu);
}
139
+
140
+ } // namespace torchdtw
File without changes
@@ -0,0 +1,15 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchdtw
3
+ Version: 0.0.1
4
+ Summary: Add your description here
5
+ Author: Maxime Poli
6
+ Author-email: CoML <dev@cognitive-ml.fr>
7
+ License-Expression: MIT
8
+ Keywords: machine learning
9
+ Requires-Python: >=3.12
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: numpy>=1.26.4
12
+ Requires-Dist: torch>=2.9.0
13
+
14
+ # PyTorch DTW C++ extension
15
+
@@ -0,0 +1,13 @@
1
+ README.md
2
+ pyproject.toml
3
+ setup.py
4
+ src/torchdtw/__init__.py
5
+ src/torchdtw/py.typed
6
+ src/torchdtw.egg-info/PKG-INFO
7
+ src/torchdtw.egg-info/SOURCES.txt
8
+ src/torchdtw.egg-info/dependency_links.txt
9
+ src/torchdtw.egg-info/requires.txt
10
+ src/torchdtw.egg-info/top_level.txt
11
+ src/torchdtw/csrc/dtw.cpp
12
+ tests/test_dtw.py
13
+ tests/test_opcheck.py
@@ -0,0 +1,2 @@
1
+ numpy>=1.26.4
2
+ torch>=2.9.0
@@ -0,0 +1 @@
1
+ torchdtw
@@ -0,0 +1,63 @@
1
+ """Compare CPU and CUDA dtw implementations."""
2
+
3
+ import pytest
4
+ import torch
5
+ from hypothesis import given, settings
6
+ from hypothesis import strategies as st
7
+
8
+ from torchdtw import dtw, dtw_batch
9
+
10
+ rtol, atol = 0, 1e-9
11
+ skipifnogpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available")
12
+
13
+ DIM, BATCH = st.integers(1, 1280), st.integers(1, 3)
14
+ LOW, HIGH_MINUS_LOW = st.floats(-100, 100), st.floats(0.1, 100)
15
+
16
+
17
def make_tensor(shape: tuple[int, ...], *, dtype: torch.dtype, low: float, high: float) -> torch.Tensor:
    """Build a CPU tensor for testing, with values generated by ``torch.testing.make_tensor``.

    When ``low == high`` for ``torch.long`` there is no range to sample from, so the tensor is filled with
    that single value instead.

    :param shape: Shape of the tensor to create.
    :param dtype: Data type of the tensor.
    :param low: Lower bound forwarded to ``torch.testing.make_tensor``.
    :param high: Upper bound forwarded to ``torch.testing.make_tensor``.
    """
    if dtype == torch.long and low == high:
        # Fill with the requested value rather than a hard-coded 1.
        return torch.full(shape, int(low), dtype=torch.long, device="cpu")
    return torch.testing.make_tensor(shape, dtype=dtype, device="cpu", low=low, high=high)
22
+
23
+
24
@skipifnogpu
@given(x=DIM, y=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_dtw(x: int, y: int, low: float, high_minus_low: float) -> None:
    """Check that the CPU and GPU implementations of dtw agree on a random distance matrix."""
    high = low + high_minus_low
    distances = make_tensor((x, y), dtype=torch.float32, low=low, high=high)
    on_cpu = dtw(distances)
    on_gpu = dtw(distances.cuda()).cpu()
    torch.testing.assert_close(on_cpu, on_gpu, rtol=rtol, atol=atol)
31
+
32
+
33
@skipifnogpu
@given(n=BATCH, x=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_dtw_batch_symmetric(n: int, x: int, low: float, high_minus_low: float) -> None:
    """Compare the output of dtw_batch between CPU and GPU implementations, symmetric case."""
    d = make_tensor((n, n, x, x), dtype=torch.float32, low=low, high=high_minus_low + low)
    # ``high`` is exclusive for integer dtypes: use ``x + 1`` so that full-length sequences are drawn too.
    sx = make_tensor((n,), dtype=torch.long, low=1, high=x + 1)
    # Copy the blocks below the diagonal onto the ones above it, so that block (i, j) equals block (j, i).
    i, j = torch.triu_indices(n, n)
    d[i, j] = d[j, i]
    torch.testing.assert_close(
        dtw_batch(d, sx, sx, symmetric=True),
        dtw_batch(d.cuda(), sx.cuda(), sx.cuda(), symmetric=True).cpu(),
        rtol=rtol,
        atol=atol,
    )
48
+
49
+
50
@skipifnogpu
@given(n=BATCH, m=BATCH, x=DIM, y=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_dtw_batch_not_symmetric(n: int, m: int, x: int, y: int, low: float, high_minus_low: float) -> None:
    """Compare the output of dtw_batch between CPU and GPU implementations, non symmetric case."""
    d = make_tensor((n, m, x, y), dtype=torch.float32, low=low, high=high_minus_low + low)
    # ``high`` is exclusive for integer dtypes: use ``+ 1`` so that full-length sequences are drawn too.
    sx = make_tensor((n,), dtype=torch.long, low=1, high=x + 1)
    sy = make_tensor((m,), dtype=torch.long, low=1, high=y + 1)
    torch.testing.assert_close(
        dtw_batch(d, sx, sy, symmetric=False),
        dtw_batch(d.cuda(), sx.cuda(), sy.cuda(), symmetric=False).cpu(),
        rtol=rtol,
        atol=atol,
    )
@@ -0,0 +1,54 @@
1
+ """Check for compatibility with torch.compile."""
2
+
3
+ import torch
4
+ from hypothesis import given, settings
5
+ from hypothesis import strategies as st
6
+ from torch.library import opcheck
7
+
8
+ import torchdtw # noqa: F401 # Need to import it to register dtw operation
9
+
10
+ DIM, BATCH = st.integers(1, 1280), st.integers(1, 3)
11
+ LOW, HIGH_MINUS_LOW = st.floats(-100, 100), st.floats(0.1, 100)
12
+ CUDA_AVAILABLE = torch.cuda.is_available()
13
+
14
+
15
def make_tensor(shape: tuple[int, ...], *, dtype: torch.dtype, low: float, high: float) -> torch.Tensor:
    """Build a CPU tensor for testing, with values generated by ``torch.testing.make_tensor``.

    When ``low == high`` for ``torch.long`` there is no range to sample from, so the tensor is filled with
    that single value instead.

    :param shape: Shape of the tensor to create.
    :param dtype: Data type of the tensor.
    :param low: Lower bound forwarded to ``torch.testing.make_tensor``.
    :param high: Upper bound forwarded to ``torch.testing.make_tensor``.
    """
    if dtype == torch.long and low == high:
        # Fill with the requested value rather than a hard-coded 1.
        return torch.full(shape, int(low), dtype=torch.long, device="cpu")
    return torch.testing.make_tensor(shape, dtype=dtype, device="cpu", low=low, high=high)
20
+
21
+
22
@given(x=DIM, y=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_opcheck_dtw(x: int, y: int, low: float, high_minus_low: float) -> None:
    """Run opcheck on the dtw operator, on CPU and on GPU when one is available."""
    op = torch.ops.torchdtw.dtw.default
    high = low + high_minus_low
    sample = make_tensor((x, y), dtype=torch.float32, low=low, high=high)
    opcheck(op, (sample,))
    if not CUDA_AVAILABLE:
        return
    opcheck(op, (sample.cuda(),))
30
+
31
+
32
@given(n=BATCH, x=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_opcheck_dtw_batch_symmetric(n: int, x: int, low: float, high_minus_low: float) -> None:
    """Verify that dtw_batch can be torch compiled, with symmetric input."""
    sample = make_tensor((n, n, x, x), dtype=torch.float32, low=low, high=high_minus_low + low)
    # ``high`` is exclusive for integer dtypes: use ``x + 1`` so that full-length sequences are drawn too.
    sx = make_tensor((n,), dtype=torch.long, low=1, high=x + 1)
    # Copy the blocks below the diagonal onto the ones above it, so that block (i, j) equals block (j, i).
    i, j = torch.triu_indices(n, n)
    sample[i, j] = sample[j, i]
    opcheck(torch.ops.torchdtw.dtw_batch.default, (sample, sx, sx), {"symmetric": True})
    if CUDA_AVAILABLE:
        opcheck(torch.ops.torchdtw.dtw_batch.default, (sample.cuda(), sx.cuda(), sx.cuda()), {"symmetric": True})
43
+
44
+
45
@given(n=BATCH, m=BATCH, x=DIM, y=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_opcheck_dtw_batch_not_symmetric(n: int, m: int, x: int, y: int, low: float, high_minus_low: float) -> None:
    """Verify that dtw_batch can be torch compiled, with non symmetric input."""
    sample = make_tensor((n, m, x, y), dtype=torch.float32, low=low, high=high_minus_low + low)
    # ``high`` is exclusive for integer dtypes: use ``+ 1`` so that full-length sequences are drawn too.
    sx = make_tensor((n,), dtype=torch.long, low=1, high=x + 1)
    sy = make_tensor((m,), dtype=torch.long, low=1, high=y + 1)
    opcheck(torch.ops.torchdtw.dtw_batch.default, (sample, sx, sy), {"symmetric": False})
    if CUDA_AVAILABLE:
        opcheck(torch.ops.torchdtw.dtw_batch.default, (sample.cuda(), sx.cuda(), sy.cuda()), {"symmetric": False})