torchdtw 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchdtw-0.0.1/PKG-INFO +15 -0
- torchdtw-0.0.1/README.md +2 -0
- torchdtw-0.0.1/pyproject.toml +68 -0
- torchdtw-0.0.1/setup.cfg +4 -0
- torchdtw-0.0.1/setup.py +58 -0
- torchdtw-0.0.1/src/torchdtw/__init__.py +49 -0
- torchdtw-0.0.1/src/torchdtw/csrc/dtw.cpp +140 -0
- torchdtw-0.0.1/src/torchdtw/py.typed +0 -0
- torchdtw-0.0.1/src/torchdtw.egg-info/PKG-INFO +15 -0
- torchdtw-0.0.1/src/torchdtw.egg-info/SOURCES.txt +13 -0
- torchdtw-0.0.1/src/torchdtw.egg-info/dependency_links.txt +1 -0
- torchdtw-0.0.1/src/torchdtw.egg-info/requires.txt +2 -0
- torchdtw-0.0.1/src/torchdtw.egg-info/top_level.txt +1 -0
- torchdtw-0.0.1/tests/test_dtw.py +63 -0
- torchdtw-0.0.1/tests/test_opcheck.py +54 -0
torchdtw-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torchdtw
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Author: Maxime Poli
|
|
6
|
+
Author-email: CoML <dev@cognitive-ml.fr>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Keywords: machine learning
|
|
9
|
+
Requires-Python: >=3.12
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: numpy>=1.26.4
|
|
12
|
+
Requires-Dist: torch>=2.9.0
|
|
13
|
+
|
|
14
|
+
# PyTorch DTW C++ extension
|
|
15
|
+
|
torchdtw-0.0.1/README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = [
|
|
3
|
+
"setuptools>=80.9.0",
|
|
4
|
+
"torch>=2.9.0",
|
|
5
|
+
"numpy>=1.26.4",
|
|
6
|
+
"ninja>=1.11",
|
|
7
|
+
]
|
|
8
|
+
build-backend = "setuptools.build_meta"
|
|
9
|
+
|
|
10
|
+
[project]
|
|
11
|
+
name = "torchdtw"
|
|
12
|
+
version = "0.0.1"
|
|
13
|
+
description = "Add your description here"
|
|
14
|
+
readme = "README.md"
|
|
15
|
+
requires-python = ">=3.12"
|
|
16
|
+
authors = [
|
|
17
|
+
{ name = "Maxime Poli" },
|
|
18
|
+
{ name = "CoML", email = "dev@cognitive-ml.fr" },
|
|
19
|
+
]
|
|
20
|
+
keywords = ["machine learning"]
|
|
21
|
+
license = "MIT"
|
|
22
|
+
dependencies = [
|
|
23
|
+
"numpy>=1.26.4",
|
|
24
|
+
"torch>=2.9.0",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[dependency-groups]
|
|
28
|
+
dev = [
|
|
29
|
+
"ipykernel>=7.1.0",
|
|
30
|
+
"ruff>=0.14.2",
|
|
31
|
+
"typos>=1.36.2",
|
|
32
|
+
]
|
|
33
|
+
test = [
|
|
34
|
+
"hypothesis>=6.142.5",
|
|
35
|
+
"pytest>=8.4.2",
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
[tool.ruff]
|
|
39
|
+
line-length = 119
|
|
40
|
+
|
|
41
|
+
[tool.ruff.lint]
|
|
42
|
+
select = ["ALL"]
|
|
43
|
+
ignore = [
|
|
44
|
+
"COM812", # missing-trailing-comma
|
|
45
|
+
"D105", # undocumented-magic-method
|
|
46
|
+
"D107", # undocumented-public-init
|
|
47
|
+
"D203", # incorrect-blank-line-before-class
|
|
48
|
+
"D213", # multi-line-summary-second-line
|
|
49
|
+
"N803", # invalid-argument-name
|
|
50
|
+
"N806", # non-lowercase-variable-in-function
|
|
51
|
+
"PLR0913", # too-many-arguments
|
|
52
|
+
]
|
|
53
|
+
|
|
54
|
+
[tool.ruff.lint.pylint]
|
|
55
|
+
allow-magic-value-types = ["int"]
|
|
56
|
+
|
|
57
|
+
[tool.ruff.lint.flake8-self]
|
|
58
|
+
ignore-names = ["_check"]
|
|
59
|
+
|
|
60
|
+
[tool.ruff.lint.pep8-naming]
|
|
61
|
+
ignore-names = ["F"]
|
|
62
|
+
|
|
63
|
+
[tool.uv.workspace]
|
|
64
|
+
members = [
|
|
65
|
+
"benchmark",
|
|
66
|
+
"torchdtw",
|
|
67
|
+
]
|
|
68
|
+
|
torchdtw-0.0.1/setup.cfg
ADDED
torchdtw-0.0.1/setup.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Build the DTW PyTorch C++ extension."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from setuptools import Extension, setup
|
|
8
|
+
from torch.torch_version import Version
|
|
9
|
+
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CppExtension, CUDAExtension
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_openmp_flags() -> tuple[list[str], list[str]]:
    """Return the OpenMP ``(compile_flags, link_flags)`` for the current platform.

    :returns: ``-fopenmp`` for both compiling and linking on Linux, ``-openmp`` for compiling
        only on Windows, and two empty lists on every other platform.
    """
    platform = sys.platform
    if platform == "linux":
        return ["-fopenmp"], ["-fopenmp"]
    if platform == "win32":
        return ["-openmp"], []
    # On MacOS, we use the OpenMP version vendored by PyTorch, so no extra flag is passed.
    return [], []
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_cuda_arch_list() -> str:
    """Return the semicolon-separated CUDA architecture names to build for.

    Volta is not supported by CUDA 13.0, so it is only listed when the installed PyTorch
    reports an older CUDA version (or no CUDA version at all).
    """
    cuda_version = torch.version.cuda
    volta_dropped = cuda_version is not None and not Version(cuda_version) < Version("13.0")
    if volta_dropped:
        return "Turing;Ampere;Ada;Hopper"
    return "Volta;Turing;Ampere;Ada;Hopper"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_extension() -> Extension:
    """Build the ``torchdtw._C`` extension, with CUDA when a toolkit is found.

    A ``CUDAExtension`` is used when ``CUDA_HOME`` is set, a ``CppExtension`` otherwise.
    In the CUDA case ``TORCH_CUDA_ARCH_LIST`` is overwritten in ``os.environ`` (any value
    already present is replaced) and the ``.cu`` source is added to the build.

    :returns: The configured extension, compiled against the Python limited API.
    """
    cxx_openmp_flags, link_openmp_flags = get_openmp_flags()
    sources = ["src/torchdtw/csrc/dtw.cpp"]
    if CUDA_HOME is None:
        extension_cls = CppExtension
    else:
        extension_cls = CUDAExtension
        os.environ["TORCH_CUDA_ARCH_LIST"] = get_cuda_arch_list()
        # NOTE(review): this path is not listed in the sdist's SOURCES.txt -- confirm that
        # building from the sdist on a machine with CUDA installed still works.
        sources.append("src/torchdtw/csrc/cuda/dtw.cu")
    compile_args = {
        "cxx": ["-fdiagnostics-color=always", "-O3", *cxx_openmp_flags],
        "nvcc": ["-O3"],
    }
    return extension_cls(
        "torchdtw._C",
        sources,
        extra_compile_args=compile_args,
        extra_link_args=link_openmp_flags,
        py_limited_api=True,
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
if __name__ == "__main__":
    # Tag the wheel as a CPython 3.12 limited-API wheel, matching ``py_limited_api=True``
    # on the extension itself.
    wheel_options = {"bdist_wheel": {"py_limited_api": "cp312"}}
    setup(
        ext_modules=[get_extension()],
        cmdclass={"build_ext": BuildExtension},
        options=wheel_options,
    )
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""DTW implementation using PyTorch C++ extensions, with CPU and CUDA backends."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from . import _C # noqa: F401 # ty: ignore[unresolved-import]
|
|
6
|
+
|
|
7
|
+
__all__ = ["dtw", "dtw_batch"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def dtw(distances: torch.Tensor) -> torch.Tensor:
    """Compute the DTW of the given ``distances`` 2D tensor.

    Dispatches to the ``torchdtw::dtw`` custom operator registered by the compiled ``_C``
    extension.

    :param distances: A 2D tensor of shape (n, m) representing the pairwise distances between two sequences.
    :returns: A 0-dimensional tensor. The CPU kernel shipped with this package returns the
        accumulated cost of the best warping path divided by the length of that path.
    """
    dtw_op = torch.ops.torchdtw.dtw.default
    return dtw_op(distances)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def dtw_batch(distances: torch.Tensor, sx: torch.Tensor, sy: torch.Tensor, *, symmetric: bool) -> torch.Tensor:
    """Compute the batched DTW on the ``distances`` 4D tensor.

    Dispatches to the ``torchdtw::dtw_batch`` custom operator registered by the compiled
    ``_C`` extension.

    :param distances: A 4D tensor of shape (n1, n2, s1, s2) representing the pairwise distances between two
        batches of sequences.
    :param sx: A 1D tensor of shape (n1,) representing the lengths of the sequences in the first batch.
    :param sy: A 1D tensor of shape (n2,) representing the lengths of the sequences in the second batch.
    :param symmetric: Whether or not the DTW is symmetric (i.e., the two batches are the same). The CPU
        kernel then only computes the pairs above the diagonal, mirrors them, and leaves the diagonal at zero.
    :returns: A 2D tensor of shape (n1, n2) with one DTW value per pair of sequences.
    """
    dtw_batch_op = torch.ops.torchdtw.dtw_batch.default
    return dtw_batch_op(distances, sx, sy, symmetric)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@torch.library.register_fake("torchdtw::dtw")
def _(distances: torch.Tensor) -> torch.Tensor:
    """Describe the output of ``torchdtw::dtw`` on fake tensors, so that torch.compile can trace it.

    Requires a 2D float32 input and produces a 0-dimensional float32 tensor on the same
    device and with the same layout; no DTW is actually computed here.
    """
    requirements = (
        distances.ndim == 2,
        distances.dtype == torch.float32,
    )
    for requirement in requirements:
        torch._check(requirement)
    scalar_shape = ()
    return torch.empty(scalar_shape, dtype=torch.float32, layout=distances.layout, device=distances.device)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@torch.library.register_fake("torchdtw::dtw_batch")
def _(distances: torch.Tensor, sx: torch.Tensor, sy: torch.Tensor, symmetric: bool) -> torch.Tensor:  # noqa: FBT001
    """Describe the output of ``torchdtw::dtw_batch`` on fake tensors, so that torch.compile can trace it.

    Requires a 4D float32 ``distances``, 1D int64 ``sx`` and ``sy`` and a boolean
    ``symmetric``. The result has the first two dimensions of ``distances`` as its shape,
    on the same device and with the same layout; no DTW is actually computed here.
    """
    requirements = (
        distances.ndim == 4,
        sx.ndim == 1,
        sy.ndim == 1,
        distances.dtype == torch.float32,
        sx.dtype == torch.long,
        sy.dtype == torch.long,
        isinstance(symmetric, bool),
    )
    for requirement in requirements:
        torch._check(requirement)
    # The rank check above guarantees that the two batch dimensions exist.
    nx, ny = distances.shape[:2]
    return torch.empty((nx, ny), dtype=torch.float32, layout=distances.layout, device=distances.device)
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
#include <Python.h>
|
|
2
|
+
#include <omp.h>
|
|
3
|
+
#include <torch/csrc/stable/library.h>
|
|
4
|
+
#include <torch/csrc/stable/ops.h>
|
|
5
|
+
#include <torch/csrc/stable/tensor.h>
|
|
6
|
+
#include <torch/headeronly/util/Exception.h>
|
|
7
|
+
#include <algorithm>
|
|
8
|
+
#include <vector>
|
|
9
|
+
|
|
10
|
+
extern "C" {
|
|
11
|
+
/* Creates a dummy empty _C module that can be imported from Python.
|
|
12
|
+
The import from Python will load the .so consisting of this file
|
|
13
|
+
in this extension, so that the STABLE_TORCH_LIBRARY static initializers
|
|
14
|
+
below are run. */
|
|
15
|
+
PyObject* PyInit__C(void) {
|
|
16
|
+
static struct PyModuleDef module_def = {
|
|
17
|
+
PyModuleDef_HEAD_INIT,
|
|
18
|
+
"_C", /* name of module */
|
|
19
|
+
NULL, /* module documentation, may be NULL */
|
|
20
|
+
-1, /* size of per-interpreter state of the module,
|
|
21
|
+
or -1 if the module keeps state in global variables. */
|
|
22
|
+
NULL, /* methods */
|
|
23
|
+
};
|
|
24
|
+
return PyModule_Create(&module_def);
|
|
25
|
+
}
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
namespace torchdtw {
|
|
29
|
+
|
|
30
|
+
using torch::stable::Tensor;
|
|
31
|
+
|
|
32
|
+
inline float dtw(
|
|
33
|
+
const float* distances,
|
|
34
|
+
const int64_t N,
|
|
35
|
+
const int64_t M,
|
|
36
|
+
const int64_t stride_x,
|
|
37
|
+
const int64_t stride_y) {
|
|
38
|
+
STD_TORCH_CHECK(N > 0 && M > 0, "Empty input tensor");
|
|
39
|
+
STD_TORCH_CHECK(stride_x > 0 && stride_y > 0, "Strides must be positive");
|
|
40
|
+
std::vector<float> cost(N * M);
|
|
41
|
+
|
|
42
|
+
cost[0] = distances[0];
|
|
43
|
+
for (int64_t i = 1; i < N; i++) {
|
|
44
|
+
cost[i * M] = distances[i * stride_x] + cost[(i - 1) * M];
|
|
45
|
+
}
|
|
46
|
+
for (int64_t j = 1; j < M; j++) {
|
|
47
|
+
cost[j] = distances[j * stride_y] + cost[j - 1];
|
|
48
|
+
}
|
|
49
|
+
for (int64_t i = 1; i < N; i++) {
|
|
50
|
+
for (int64_t j = 1; j < M; j++) {
|
|
51
|
+
cost[i * M + j] = distances[i * stride_x + j * stride_y] +
|
|
52
|
+
std::min({cost[(i - 1) * M + j], cost[(i - 1) * M + j - 1], cost[i * M + j - 1]});
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
int64_t path_len = 1;
|
|
57
|
+
int64_t i = N - 1;
|
|
58
|
+
int64_t j = M - 1;
|
|
59
|
+
while (i > 0 && j > 0) {
|
|
60
|
+
const float c_up = cost[(i - 1) * M + j];
|
|
61
|
+
const float c_left = cost[i * M + j - 1];
|
|
62
|
+
const float c_diag = cost[(i - 1) * M + j - 1];
|
|
63
|
+
if (c_diag <= c_left && c_diag <= c_up) {
|
|
64
|
+
i--;
|
|
65
|
+
j--;
|
|
66
|
+
} else if (c_left <= c_up) {
|
|
67
|
+
j--;
|
|
68
|
+
} else {
|
|
69
|
+
i--;
|
|
70
|
+
}
|
|
71
|
+
path_len++;
|
|
72
|
+
}
|
|
73
|
+
if (i == 0)
|
|
74
|
+
path_len += j;
|
|
75
|
+
if (j == 0)
|
|
76
|
+
path_len += i;
|
|
77
|
+
return cost[(N - 1) * M + M - 1] / path_len;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
Tensor dtw_cpu(const Tensor distances) {
|
|
81
|
+
float result =
|
|
82
|
+
dtw(reinterpret_cast<const float*>(distances.data_ptr()),
|
|
83
|
+
distances.size(0),
|
|
84
|
+
distances.size(1),
|
|
85
|
+
distances.stride(0),
|
|
86
|
+
distances.stride(1));
|
|
87
|
+
Tensor out = torch::stable::new_empty(distances, {});
|
|
88
|
+
torch::stable::fill_(out, result);
|
|
89
|
+
return out;
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
Tensor dtw_batch_cpu(const Tensor distances, const Tensor sx, const Tensor sy, bool symmetric) {
|
|
93
|
+
const int64_t nx = distances.size(0);
|
|
94
|
+
const int64_t ny = distances.size(1);
|
|
95
|
+
Tensor out = torch::stable::new_zeros(distances, {nx, ny});
|
|
96
|
+
|
|
97
|
+
const float* distances_ptr = reinterpret_cast<const float*>(distances.data_ptr());
|
|
98
|
+
const int64_t* sx_ptr = reinterpret_cast<const int64_t*>(sx.data_ptr());
|
|
99
|
+
const int64_t* sy_ptr = reinterpret_cast<const int64_t*>(sy.data_ptr());
|
|
100
|
+
float* out_ptr = reinterpret_cast<float*>(out.data_ptr());
|
|
101
|
+
|
|
102
|
+
#pragma omp parallel for schedule(dynamic)
|
|
103
|
+
for (int64_t i = 0; i < nx; i++) {
|
|
104
|
+
const int64_t start_j = symmetric ? i : 0;
|
|
105
|
+
for (int64_t j = start_j; j < ny; j++) {
|
|
106
|
+
if (symmetric && i == j)
|
|
107
|
+
continue;
|
|
108
|
+
out_ptr[i * ny + j] =
|
|
109
|
+
dtw(distances_ptr + i * distances.stride(0) + j * distances.stride(1),
|
|
110
|
+
sx_ptr[i],
|
|
111
|
+
sy_ptr[j],
|
|
112
|
+
distances.stride(2),
|
|
113
|
+
distances.stride(3));
|
|
114
|
+
if (symmetric && i != j) {
|
|
115
|
+
out_ptr[j * ny + i] = out_ptr[i * ny + j];
|
|
116
|
+
}
|
|
117
|
+
}
|
|
118
|
+
};
|
|
119
|
+
return out;
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
void boxed_dtw_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
|
123
|
+
stack[0] = from(dtw_cpu(to<Tensor>(stack[0])));
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
void boxed_dtw_batch_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
|
127
|
+
stack[0] = from(dtw_batch_cpu(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]), to<bool>(stack[3])));
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
STABLE_TORCH_LIBRARY(torchdtw, m) {
|
|
131
|
+
m.def("dtw(Tensor distances) -> Tensor");
|
|
132
|
+
m.def("dtw_batch(Tensor distances, Tensor sx, Tensor sy, bool symmetric) -> Tensor");
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
STABLE_TORCH_LIBRARY_IMPL(torchdtw, CPU, m) {
|
|
136
|
+
m.impl("dtw", &boxed_dtw_cpu);
|
|
137
|
+
m.impl("dtw_batch", &boxed_dtw_batch_cpu);
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
} // namespace torchdtw
|
|
File without changes
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torchdtw
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Author: Maxime Poli
|
|
6
|
+
Author-email: CoML <dev@cognitive-ml.fr>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Keywords: machine learning
|
|
9
|
+
Requires-Python: >=3.12
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: numpy>=1.26.4
|
|
12
|
+
Requires-Dist: torch>=2.9.0
|
|
13
|
+
|
|
14
|
+
# PyTorch DTW C++ extension
|
|
15
|
+
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
setup.py
|
|
4
|
+
src/torchdtw/__init__.py
|
|
5
|
+
src/torchdtw/py.typed
|
|
6
|
+
src/torchdtw.egg-info/PKG-INFO
|
|
7
|
+
src/torchdtw.egg-info/SOURCES.txt
|
|
8
|
+
src/torchdtw.egg-info/dependency_links.txt
|
|
9
|
+
src/torchdtw.egg-info/requires.txt
|
|
10
|
+
src/torchdtw.egg-info/top_level.txt
|
|
11
|
+
src/torchdtw/csrc/dtw.cpp
|
|
12
|
+
tests/test_dtw.py
|
|
13
|
+
tests/test_opcheck.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torchdtw
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""Compare CPU and CUDA dtw implementations."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import torch
|
|
5
|
+
from hypothesis import given, settings
|
|
6
|
+
from hypothesis import strategies as st
|
|
7
|
+
|
|
8
|
+
from torchdtw import dtw, dtw_batch
|
|
9
|
+
|
|
10
|
+
# Tolerances used when comparing the CPU and CUDA results.
rtol = 0
atol = 1e-9
# Marker that skips a test when no CUDA device is available.
skipifnogpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available")

# Hypothesis strategies: sequence lengths, batch sizes, and the value range of the distances.
DIM = st.integers(1, 1280)
BATCH = st.integers(1, 3)
LOW = st.floats(-100, 100)
HIGH_MINUS_LOW = st.floats(0.1, 100)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def make_tensor(shape: tuple[int, ...], *, dtype: torch.dtype, low: float, high: float) -> torch.Tensor:
    """Build a CPU tensor for testing, with values drawn between ``low`` and ``high``.

    :param shape: Shape of the tensor to create.
    :param dtype: Data type of the tensor to create.
    :param low: Lower bound of the values.
    :param high: Upper bound of the values.
    :returns: A tensor from ``torch.testing.make_tensor``, or a constant tensor filled with
        ``low`` when an integer range is degenerate (``low == high``).
    """
    if low == high and dtype == torch.long:
        # The degenerate range is special-cased instead of being forwarded to
        # ``torch.testing.make_tensor``; the tests reach it by passing ``low=1, high=x``
        # with ``x == 1``. Fill with ``low`` itself rather than a hard-coded 1.
        return torch.full(shape, int(low), dtype=torch.long, device="cpu")
    return torch.testing.make_tensor(shape, dtype=dtype, device="cpu", low=low, high=high)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@skipifnogpu
@given(x=DIM, y=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_dtw(x: int, y: int, low: float, high_minus_low: float) -> None:
    """Check that ``dtw`` gives the same result on CPU and on GPU for a random (x, y) matrix."""
    high = high_minus_low + low
    d = make_tensor((x, y), dtype=torch.float32, low=low, high=high)
    cpu_result = dtw(d)
    gpu_result = dtw(d.cuda()).cpu()
    torch.testing.assert_close(cpu_result, gpu_result, rtol=rtol, atol=atol)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@skipifnogpu
@given(n=BATCH, x=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_dtw_batch_symmetric(n: int, x: int, low: float, high_minus_low: float) -> None:
    """Check that ``dtw_batch`` gives the same result on CPU and on GPU, symmetric case."""
    high = high_minus_low + low
    d = make_tensor((n, n, x, x), dtype=torch.float32, low=low, high=high)
    sx = make_tensor((n,), dtype=torch.long, low=1, high=x)
    # Overwrite every block (i, j) of the upper triangle with block (j, i).
    rows, cols = torch.triu_indices(n, n)
    d[rows, cols] = d[cols, rows]
    cpu_result = dtw_batch(d, sx, sx, symmetric=True)
    gpu_result = dtw_batch(d.cuda(), sx.cuda(), sx.cuda(), symmetric=True).cpu()
    torch.testing.assert_close(cpu_result, gpu_result, rtol=rtol, atol=atol)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@skipifnogpu
@given(n=BATCH, m=BATCH, x=DIM, y=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_dtw_batch_not_symmetric(n: int, m: int, x: int, y: int, low: float, high_minus_low: float) -> None:
    """Check that ``dtw_batch`` gives the same result on CPU and on GPU, non symmetric case."""
    high = high_minus_low + low
    d = make_tensor((n, m, x, y), dtype=torch.float32, low=low, high=high)
    sx = make_tensor((n,), dtype=torch.long, low=1, high=x)
    sy = make_tensor((m,), dtype=torch.long, low=1, high=y)
    cpu_result = dtw_batch(d, sx, sy, symmetric=False)
    gpu_result = dtw_batch(d.cuda(), sx.cuda(), sy.cuda(), symmetric=False).cpu()
    torch.testing.assert_close(cpu_result, gpu_result, rtol=rtol, atol=atol)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Check for compatibility with torch.compile."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from hypothesis import given, settings
|
|
5
|
+
from hypothesis import strategies as st
|
|
6
|
+
from torch.library import opcheck
|
|
7
|
+
|
|
8
|
+
import torchdtw # noqa: F401 # Need to import it to register dtw operation
|
|
9
|
+
|
|
10
|
+
# Hypothesis strategies: sequence lengths, batch sizes, and the value range of the distances.
DIM = st.integers(1, 1280)
BATCH = st.integers(1, 3)
LOW = st.floats(-100, 100)
HIGH_MINUS_LOW = st.floats(0.1, 100)
# Evaluated once at import time; the CUDA variants of the checks only run when this is true.
CUDA_AVAILABLE = torch.cuda.is_available()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def make_tensor(shape: tuple[int, ...], *, dtype: torch.dtype, low: float, high: float) -> torch.Tensor:
    """Build a CPU tensor for testing, with values drawn between ``low`` and ``high``.

    :param shape: Shape of the tensor to create.
    :param dtype: Data type of the tensor to create.
    :param low: Lower bound of the values.
    :param high: Upper bound of the values.
    :returns: A tensor from ``torch.testing.make_tensor``, or a constant tensor filled with
        ``low`` when an integer range is degenerate (``low == high``).
    """
    if low == high and dtype == torch.long:
        # The degenerate range is special-cased instead of being forwarded to
        # ``torch.testing.make_tensor``; the tests reach it by passing ``low=1, high=x``
        # with ``x == 1``. Fill with ``low`` itself rather than a hard-coded 1.
        return torch.full(shape, int(low), dtype=torch.long, device="cpu")
    return torch.testing.make_tensor(shape, dtype=dtype, device="cpu", low=low, high=high)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@given(x=DIM, y=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_opcheck_dtw(x: int, y: int, low: float, high_minus_low: float) -> None:
    """Run ``opcheck`` on the dtw operator, on CPU and on GPU when one is available."""
    dtw_op = torch.ops.torchdtw.dtw.default
    sample = make_tensor((x, y), dtype=torch.float32, low=low, high=high_minus_low + low)
    opcheck(dtw_op, (sample,))
    if not CUDA_AVAILABLE:
        return
    opcheck(dtw_op, (sample.cuda(),))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@given(n=BATCH, x=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_opcheck_dtw_batch_symmetric(n: int, x: int, low: float, high_minus_low: float) -> None:
    """Run ``opcheck`` on the dtw_batch operator with symmetric input, on CPU and on GPU when available."""
    dtw_batch_op = torch.ops.torchdtw.dtw_batch.default
    kwargs = {"symmetric": True}
    sample = make_tensor((n, n, x, x), dtype=torch.float32, low=low, high=high_minus_low + low)
    sx = make_tensor((n,), dtype=torch.long, low=1, high=x)
    # Overwrite every block (i, j) of the upper triangle with block (j, i).
    rows, cols = torch.triu_indices(n, n)
    sample[rows, cols] = sample[cols, rows]
    opcheck(dtw_batch_op, (sample, sx, sx), kwargs)
    if not CUDA_AVAILABLE:
        return
    opcheck(dtw_batch_op, (sample.cuda(), sx.cuda(), sx.cuda()), kwargs)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@given(n=BATCH, m=BATCH, x=DIM, y=DIM, low=LOW, high_minus_low=HIGH_MINUS_LOW)
@settings(deadline=None)
def test_opcheck_dtw_batch_not_symmetric(n: int, m: int, x: int, y: int, low: float, high_minus_low: float) -> None:
    """Run ``opcheck`` on the dtw_batch operator with non symmetric input, on CPU and on GPU when available."""
    dtw_batch_op = torch.ops.torchdtw.dtw_batch.default
    kwargs = {"symmetric": False}
    sample = make_tensor((n, m, x, y), dtype=torch.float32, low=low, high=high_minus_low + low)
    sx = make_tensor((n,), dtype=torch.long, low=1, high=x)
    sy = make_tensor((m,), dtype=torch.long, low=1, high=y)
    opcheck(dtw_batch_op, (sample, sx, sy), kwargs)
    if not CUDA_AVAILABLE:
        return
    opcheck(dtw_batch_op, (sample.cuda(), sx.cuda(), sy.cuda()), kwargs)
|