auvux-dsp 0.1.0.dev2__tar.gz → 0.1.0.dev3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/.gitignore +1 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/PKG-INFO +9 -4
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/README.md +8 -3
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/benchmarks/benchmark.py +28 -6
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/python/auvux/dsp/__init__.py +31 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/python/auvux/dsp/_dispatch.py +1 -1
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/python/auvux/dsp/_functional.py +172 -0
- auvux_dsp-0.1.0.dev3/python/auvux/dsp/_postproc.py +109 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/python/auvux/dsp/_transforms.py +416 -5
- auvux_dsp-0.1.0.dev3/python/auvux/dsp/_version.py +1 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/bindings/abi.hpp +1 -1
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/cuda/cuda_features.cu +146 -25
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/kernels/moments.metal +152 -27
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/moments/moments.hpp +23 -8
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/moments/moments_cpu.cpp +123 -8
- auvux_dsp-0.1.0.dev3/tests/test_chroma_cens.py +64 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_gpu_features.py +7 -3
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_moments.py +20 -1
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_onset_tempo.py +35 -0
- auvux_dsp-0.1.0.dev3/tests/test_postproc.py +76 -0
- auvux_dsp-0.1.0.dev3/tests/test_spectral_contrast.py +108 -0
- auvux_dsp-0.1.0.dev3/tests/test_tonnetz.py +65 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/torch_refs.py +30 -0
- auvux_dsp-0.1.0.dev2/python/auvux/dsp/_version.py +0 -1
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/.clang-format +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/.clangd +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/.github/workflows/ci.yml +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/.github/workflows/wheels.yml +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/CMakeLists.txt +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/LICENSE +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/THIRD_PARTY_LICENSES +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/cmake/embed_text.cmake +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/pyproject.toml +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/python/auvux/dsp/_convert.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/python/auvux/dsp/_filters.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/python/auvux/dsp/_torch.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/python/auvux/dsp/_transform.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/python/auvux/dsp/py.typed +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/scripts/dev-build.ps1 +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/scripts/dev-build.sh +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/bindings/bind_cqt.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/bindings/bind_features.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/bindings/bind_fft.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/bindings/bind_mel.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/bindings/bind_stft.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/bindings/bind_util.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/bindings/module.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/bindings/pooled_array.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/common/dlpack_bridge.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/common/host_pool.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/common/host_pool.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/common/threadpool.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/common/threadpool.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/fft/fft.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/fft/fft.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/fft/fft_impl.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/fft/fft_pffft.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/fft/fft_vdsp.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/cqt_plan.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/cqt_plan.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/cuda/cuda_common.cu +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/cuda/cuda_common.cuh +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/cuda/cuda_cqt.cu +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/cuda/cuda_mel.cu +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/cuda/cuda_stft.cu +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/gpu.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/gpu_common.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/gpu_stub.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/kernels/common.metal +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/kernels/cqt.metal +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/kernels/flux.metal +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/kernels/mel.metal +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/kernels/onset.metal +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/kernels/rms.metal +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/kernels/stft.metal +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/kernels/tempogram.metal +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_common.h +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_common.mm +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_cqt.mm +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_flux.mm +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_mel.mm +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_moments.mm +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_onset.mm +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_rms.mm +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_stft.mm +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/gpu/metal/metal_tempogram.mm +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/chroma/chroma.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/chroma/chroma_cpu.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/cqt/cqt.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/cqt/cqt_cpu.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/cqt/cqt_filterbank.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/cqt/cqt_filterbank.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/db.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/flux/flux.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/flux/flux_cpu.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/frame.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/istft/istft.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/istft/istft_cpu.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/mel/mel.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/mel/mel_cpu.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/mel/mel_filterbank.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/ola.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/ola.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/onset/onset.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/onset/onset_cpu.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/rms/rms.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/rms/rms_cpu.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/stft/stft.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/stft/stft_cpu.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/tempogram/tempogram.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/tempogram/tempogram_cpu.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/types.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/window.cpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/ops/window.hpp +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/third_party/dlpack.h +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/third_party/pffft.c +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/src/third_party/pffft.h +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_adjoint.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_api.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_chroma.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_cqt.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_dynamics.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_fft.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_gpu.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_grad.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_istft.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_mel.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_mfcc.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_namespace.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_resident.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_stft.py +0 -0
- {auvux_dsp-0.1.0.dev2 → auvux_dsp-0.1.0.dev3}/tests/test_vqt.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: auvux-dsp
|
|
3
|
-
Version: 0.1.0.dev2
|
|
3
|
+
Version: 0.1.0.dev3
|
|
4
4
|
Summary: Fast differentiable audio transforms (STFT, mel, MFCC, CQT, chroma) on CPU and GPU
|
|
5
5
|
Keywords: audio,dsp,stft,mel,mfcc,cqt,chroma,spectrogram,gpu
|
|
6
6
|
Author-Email: Peter Kiers <pkiers.1983@gmail.com>
|
|
@@ -73,9 +73,14 @@ Status: under construction.
|
|
|
73
73
|
- CPU (vDSP/PFFFT): STFT, ISTFT, MelSpectrogram, MFCC, CQT, VQT, Chroma —
|
|
74
74
|
forward and native backward, librosa-parity tested, torch autograd built in.
|
|
75
75
|
- Differentiable features: SpectralMoments (centroid, bandwidth, skewness,
|
|
76
|
-
kurtosis, flatness
|
|
77
|
-
forward + analytic adjoint, librosa-parity
|
|
78
|
-
-
|
|
76
|
+
kurtosis, flatness, slope, decrease, crest, entropy), RMS, SpectralFlux,
|
|
77
|
+
OnsetStrength, Tempogram — native forward + analytic adjoint, librosa-parity
|
|
78
|
+
/ torch-grad tested.
|
|
79
|
+
- Composed descriptors (built on the transforms above, so they inherit the GPU
|
|
80
|
+
path and stay differentiable): SpectralContrast, BandEnergy, Tonnetz,
|
|
81
|
+
ChromaCENS, FourierTempogram. Plus `delta` (Savitzky-Golay derivative) and
|
|
82
|
+
`pool` post-processing over any feature matrix.
|
|
83
|
+
- Metal: every native transform on GPU (n_fft <= 4096), forward + backward,
|
|
79
84
|
parity-tested against the CPU path. torch MPS tensors stay on the GPU end to
|
|
80
85
|
end (DLPack), and backend="auto" routes them there — no flags needed.
|
|
81
86
|
- CUDA: kernel-for-kernel twin of the Metal backend including the resident
|
|
@@ -50,9 +50,14 @@ Status: under construction.
|
|
|
50
50
|
- CPU (vDSP/PFFFT): STFT, ISTFT, MelSpectrogram, MFCC, CQT, VQT, Chroma —
|
|
51
51
|
forward and native backward, librosa-parity tested, torch autograd built in.
|
|
52
52
|
- Differentiable features: SpectralMoments (centroid, bandwidth, skewness,
|
|
53
|
-
kurtosis, flatness
|
|
54
|
-
forward + analytic adjoint, librosa-parity
|
|
55
|
-
-
|
|
53
|
+
kurtosis, flatness, slope, decrease, crest, entropy), RMS, SpectralFlux,
|
|
54
|
+
OnsetStrength, Tempogram — native forward + analytic adjoint, librosa-parity
|
|
55
|
+
/ torch-grad tested.
|
|
56
|
+
- Composed descriptors (built on the transforms above, so they inherit the GPU
|
|
57
|
+
path and stay differentiable): SpectralContrast, BandEnergy, Tonnetz,
|
|
58
|
+
ChromaCENS, FourierTempogram. Plus `delta` (Savitzky-Golay derivative) and
|
|
59
|
+
`pool` post-processing over any feature matrix.
|
|
60
|
+
- Metal: every native transform on GPU (n_fft <= 4096), forward + backward,
|
|
56
61
|
parity-tested against the CPU path. torch MPS tensors stay on the GPU end to
|
|
57
62
|
end (DLPack), and backend="auto" routes them there — no flags needed.
|
|
58
63
|
- CUDA: kernel-for-kernel twin of the Metal backend including the resident
|
|
@@ -95,8 +95,10 @@ def build_cases(y, sec, torch, device, gpu, resident):
|
|
|
95
95
|
("vqt", vqt_t),
|
|
96
96
|
("chroma", chroma_t),
|
|
97
97
|
]
|
|
98
|
-
#
|
|
99
|
-
#
|
|
98
|
+
# Differentiable features. moments/rms/flux/onset/tempogram have native GPU
|
|
99
|
+
# kernels; the composed descriptors (contrast, tonnetz, cens, fourier
|
|
100
|
+
# tempogram, band energy) ride the GPU of the transform they build on.
|
|
101
|
+
# librosa comparisons are registered below only for 1:1-matching definitions.
|
|
100
102
|
feat = [
|
|
101
103
|
("spectral_moments", dsp.SpectralMoments(sr=SR, n_fft=2048, hop_length=512)),
|
|
102
104
|
("spectral_centroid",
|
|
@@ -105,6 +107,15 @@ def build_cases(y, sec, torch, device, gpu, resident):
|
|
|
105
107
|
dsp.SpectralMoments(sr=SR, features=("bandwidth",), n_fft=2048, hop_length=512)),
|
|
106
108
|
("spectral_flatness",
|
|
107
109
|
dsp.SpectralMoments(sr=SR, features=("flatness",), n_fft=2048, hop_length=512)),
|
|
110
|
+
("spectral_descriptors",
|
|
111
|
+
dsp.SpectralMoments(sr=SR, features=("slope", "decrease", "crest", "entropy"),
|
|
112
|
+
n_fft=2048, hop_length=512)),
|
|
113
|
+
("spectral_contrast", dsp.SpectralContrast(sr=SR, n_fft=2048, hop_length=512)),
|
|
114
|
+
("band_energy", dsp.BandEnergy(sr=SR, n_fft=2048, hop_length=512)),
|
|
115
|
+
("tonnetz", dsp.Tonnetz(sr=SR, hop_length=512)),
|
|
116
|
+
("chroma_cens", dsp.ChromaCENS(sr=SR, hop_length=512)),
|
|
117
|
+
("fourier_tempogram",
|
|
118
|
+
dsp.FourierTempogram(sr=SR, win_length=256, n_fft=2048, hop_length=512, n_mels=128)),
|
|
108
119
|
("rms", dsp.RMS(sr=SR, frame_length=2048, hop_length=512)),
|
|
109
120
|
("spectral_flux", dsp.SpectralFlux(sr=SR, n_fft=2048, hop_length=512)),
|
|
110
121
|
("onset_strength", dsp.OnsetStrength(sr=SR, n_fft=2048, hop_length=512, n_mels=128)),
|
|
@@ -195,6 +206,16 @@ def build_cases(y, sec, torch, device, gpu, resident):
|
|
|
195
206
|
lambda: librosa.feature.spectral_flatness(y=y, n_fft=2048, hop_length=512))
|
|
196
207
|
add("forward", "rms", "librosa",
|
|
197
208
|
lambda: librosa.feature.rms(y=y, frame_length=2048, hop_length=512))
|
|
209
|
+
add("forward", "spectral_contrast", "librosa",
|
|
210
|
+
lambda: librosa.feature.spectral_contrast(y=y, sr=SR, n_fft=2048, hop_length=512))
|
|
211
|
+
# tonnetz rides chroma_cqt (matched 1:1); cens omits librosa's amplitude
|
|
212
|
+
# quantization and fourier_tempogram inherits the onset lag-shift, so the
|
|
213
|
+
# latter two are speed-only (see _NO_DIFF).
|
|
214
|
+
add("forward", "tonnetz", "librosa", lambda: librosa.feature.tonnetz(y=y, sr=SR))
|
|
215
|
+
add("forward", "chroma_cens", "librosa",
|
|
216
|
+
lambda: librosa.feature.chroma_cens(y=y, sr=SR, hop_length=512))
|
|
217
|
+
add("forward", "fourier_tempogram", "librosa",
|
|
218
|
+
lambda: librosa.feature.fourier_tempogram(y=y, sr=SR, hop_length=512, win_length=256))
|
|
198
219
|
# Speed-only: our onset/tempogram match librosa's values but not its
|
|
199
220
|
# `center=True` envelope lag-shift (n_fft//(2*hop) frames), so the diff
|
|
200
221
|
# is suppressed (see _NO_DIFF) — correctness lives in the test suite.
|
|
@@ -395,10 +416,11 @@ def build_cases(y, sec, torch, device, gpu, resident):
|
|
|
395
416
|
return cases
|
|
396
417
|
|
|
397
418
|
|
|
398
|
-
# Transforms compared against librosa for speed only: their
|
|
399
|
-
#
|
|
400
|
-
#
|
|
401
|
-
|
|
419
|
+
# Transforms compared against librosa for speed only: their values would read as
|
|
420
|
+
# a (misleading) relerr "diff" failure rather than a correctness check —
|
|
421
|
+
# onset/tempogram/fourier_tempogram differ by librosa's envelope lag-shift
|
|
422
|
+
# convention, chroma_cens by the amplitude quantization we intentionally omit.
|
|
423
|
+
_NO_DIFF = frozenset({"onset_strength", "tempogram", "fourier_tempogram", "chroma_cens"})
|
|
402
424
|
|
|
403
425
|
|
|
404
426
|
def run_cases(cases, repeats, warmup):
|
|
@@ -20,9 +20,12 @@ from auvux.dsp._filters import (
|
|
|
20
20
|
mel_to_hz,
|
|
21
21
|
)
|
|
22
22
|
from auvux.dsp._functional import (
|
|
23
|
+
band_energy,
|
|
23
24
|
chroma,
|
|
25
|
+
chroma_cens,
|
|
24
26
|
clear_caches,
|
|
25
27
|
cqt,
|
|
28
|
+
fourier_tempogram,
|
|
26
29
|
istft,
|
|
27
30
|
mel_spectrogram,
|
|
28
31
|
mfcc,
|
|
@@ -30,15 +33,22 @@ from auvux.dsp._functional import (
|
|
|
30
33
|
rms,
|
|
31
34
|
spectral_bandwidth,
|
|
32
35
|
spectral_centroid,
|
|
36
|
+
spectral_contrast,
|
|
37
|
+
spectral_crest,
|
|
38
|
+
spectral_decrease,
|
|
39
|
+
spectral_entropy,
|
|
33
40
|
spectral_flatness,
|
|
34
41
|
spectral_flux,
|
|
35
42
|
spectral_kurtosis,
|
|
36
43
|
spectral_moments,
|
|
37
44
|
spectral_skewness,
|
|
45
|
+
spectral_slope,
|
|
38
46
|
stft,
|
|
39
47
|
tempogram,
|
|
48
|
+
tonnetz,
|
|
40
49
|
vqt,
|
|
41
50
|
)
|
|
51
|
+
from auvux.dsp._postproc import delta, pool
|
|
42
52
|
from auvux.dsp._transform import Transform
|
|
43
53
|
from auvux.dsp._transforms import (
|
|
44
54
|
CQT,
|
|
@@ -47,12 +57,17 @@ from auvux.dsp._transforms import (
|
|
|
47
57
|
RMS,
|
|
48
58
|
STFT,
|
|
49
59
|
VQT,
|
|
60
|
+
BandEnergy,
|
|
50
61
|
Chroma,
|
|
62
|
+
ChromaCENS,
|
|
63
|
+
FourierTempogram,
|
|
51
64
|
MelSpectrogram,
|
|
52
65
|
OnsetStrength,
|
|
66
|
+
SpectralContrast,
|
|
53
67
|
SpectralFlux,
|
|
54
68
|
SpectralMoments,
|
|
55
69
|
Tempogram,
|
|
70
|
+
Tonnetz,
|
|
56
71
|
)
|
|
57
72
|
from auvux.dsp._version import __version__
|
|
58
73
|
|
|
@@ -64,25 +79,34 @@ __all__ = [
|
|
|
64
79
|
"STFT",
|
|
65
80
|
"VQT",
|
|
66
81
|
"BackendError",
|
|
82
|
+
"BandEnergy",
|
|
67
83
|
"Chroma",
|
|
84
|
+
"ChromaCENS",
|
|
85
|
+
"FourierTempogram",
|
|
68
86
|
"MelSpectrogram",
|
|
69
87
|
"OnsetStrength",
|
|
88
|
+
"SpectralContrast",
|
|
70
89
|
"SpectralFlux",
|
|
71
90
|
"SpectralMoments",
|
|
72
91
|
"Tempogram",
|
|
92
|
+
"Tonnetz",
|
|
73
93
|
"Transform",
|
|
74
94
|
"__version__",
|
|
75
95
|
"amplitude_to_db",
|
|
76
96
|
"backend_info",
|
|
77
97
|
"backends",
|
|
98
|
+
"band_energy",
|
|
78
99
|
"chroma",
|
|
100
|
+
"chroma_cens",
|
|
79
101
|
"clear_caches",
|
|
80
102
|
"cqt",
|
|
81
103
|
"cqt_frequencies",
|
|
82
104
|
"db_to_amplitude",
|
|
83
105
|
"db_to_power",
|
|
84
106
|
"dct_matrix",
|
|
107
|
+
"delta",
|
|
85
108
|
"fft_frequencies",
|
|
109
|
+
"fourier_tempogram",
|
|
86
110
|
"get_num_threads",
|
|
87
111
|
"gpu_available",
|
|
88
112
|
"gpu_backend",
|
|
@@ -94,17 +118,24 @@ __all__ = [
|
|
|
94
118
|
"mel_to_hz",
|
|
95
119
|
"mfcc",
|
|
96
120
|
"onset_strength",
|
|
121
|
+
"pool",
|
|
97
122
|
"power_to_db",
|
|
98
123
|
"rms",
|
|
99
124
|
"set_num_threads",
|
|
100
125
|
"spectral_bandwidth",
|
|
101
126
|
"spectral_centroid",
|
|
127
|
+
"spectral_contrast",
|
|
128
|
+
"spectral_crest",
|
|
129
|
+
"spectral_decrease",
|
|
130
|
+
"spectral_entropy",
|
|
102
131
|
"spectral_flatness",
|
|
103
132
|
"spectral_flux",
|
|
104
133
|
"spectral_kurtosis",
|
|
105
134
|
"spectral_moments",
|
|
106
135
|
"spectral_skewness",
|
|
136
|
+
"spectral_slope",
|
|
107
137
|
"stft",
|
|
108
138
|
"tempogram",
|
|
139
|
+
"tonnetz",
|
|
109
140
|
"vqt",
|
|
110
141
|
]
|
|
@@ -17,7 +17,7 @@ class BackendError(RuntimeError):
|
|
|
17
17
|
# Refuse to run against a compiled core older than these Python sources;
|
|
18
18
|
# editable installs do not rebuild the extension on source changes.
|
|
19
19
|
# Keep in lockstep with kAbiVersion in src/bindings/abi.hpp.
|
|
20
|
-
_EXPECTED_ABI =
|
|
20
|
+
_EXPECTED_ABI = 13
|
|
21
21
|
if getattr(_native, "abi_version", lambda: 0)() != _EXPECTED_ABI:
|
|
22
22
|
raise ImportError(
|
|
23
23
|
"auvux.dsp's compiled core is out of date with its Python sources; "
|
|
@@ -18,12 +18,17 @@ from auvux.dsp._transforms import (
|
|
|
18
18
|
RMS,
|
|
19
19
|
STFT,
|
|
20
20
|
VQT,
|
|
21
|
+
BandEnergy,
|
|
21
22
|
Chroma,
|
|
23
|
+
ChromaCENS,
|
|
24
|
+
FourierTempogram,
|
|
22
25
|
MelSpectrogram,
|
|
23
26
|
OnsetStrength,
|
|
27
|
+
SpectralContrast,
|
|
24
28
|
SpectralFlux,
|
|
25
29
|
SpectralMoments,
|
|
26
30
|
Tempogram,
|
|
31
|
+
Tonnetz,
|
|
27
32
|
)
|
|
28
33
|
|
|
29
34
|
_T = TypeVar("_T", bound=Transform)
|
|
@@ -249,6 +254,54 @@ def chroma(
|
|
|
249
254
|
return t(y, backend=backend)
|
|
250
255
|
|
|
251
256
|
|
|
257
|
+
def tonnetz(
    y: Any,
    *,
    sr: float,
    hop_length: int = 512,
    n_chroma: int = 12,
    bins_per_octave: int = 36,
    n_octaves: int = 7,
    fmin: float | None = None,
    backend: str = "auto",
) -> Any:
    """Functional form of the `Tonnetz` transform.

    Obtains a `Tonnetz` plan from `_plan` configured with the keyword
    parameters below and applies it to `y`.

    Args:
        y: Input handed to the transform unchanged.
        sr, hop_length, n_chroma, bins_per_octave, n_octaves, fmin:
            Forwarded as-is to `Tonnetz` through `_plan`.
        backend: Passed to the transform call to select the execution backend.

    Returns:
        Whatever the planned `Tonnetz` transform returns for `y`.
    """
    plan_kwargs = {
        "sr": sr,
        "hop_length": hop_length,
        "n_chroma": n_chroma,
        "bins_per_octave": bins_per_octave,
        "n_octaves": n_octaves,
        "fmin": fmin,
    }
    transform = _plan(Tonnetz, **plan_kwargs)
    return transform(y, backend=backend)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def chroma_cens(
    y: Any,
    *,
    sr: float,
    hop_length: int = 512,
    n_chroma: int = 12,
    bins_per_octave: int = 36,
    n_octaves: int = 7,
    fmin: float | None = None,
    win: int = 41,
    backend: str = "auto",
) -> Any:
    """Functional form of the `ChromaCENS` transform.

    Obtains a `ChromaCENS` plan from `_plan` configured with the keyword
    parameters below and applies it to `y`.

    Args:
        y: Input handed to the transform unchanged.
        sr, hop_length, n_chroma, bins_per_octave, n_octaves, fmin, win:
            Forwarded as-is to `ChromaCENS` through `_plan`.
        backend: Passed to the transform call to select the execution backend.

    Returns:
        Whatever the planned `ChromaCENS` transform returns for `y`.
    """
    plan_kwargs = {
        "sr": sr,
        "hop_length": hop_length,
        "n_chroma": n_chroma,
        "bins_per_octave": bins_per_octave,
        "n_octaves": n_octaves,
        "fmin": fmin,
        "win": win,
    }
    transform = _plan(ChromaCENS, **plan_kwargs)
    return transform(y, backend=backend)
|
|
303
|
+
|
|
304
|
+
|
|
252
305
|
def spectral_moments(
|
|
253
306
|
y: Any,
|
|
254
307
|
*,
|
|
@@ -276,6 +329,68 @@ def spectral_moments(
|
|
|
276
329
|
return t(y, backend=backend)
|
|
277
330
|
|
|
278
331
|
|
|
332
|
+
def band_energy(
    y: Any,
    *,
    sr: float,
    n_fft: int = 2048,
    hop_length: int | None = None,
    win_length: int | None = None,
    window: str = "hann",
    center: bool = True,
    pad_mode: str = "constant",
    n_bands: int = 6,
    fmin: float = 0.0,
    backend: str = "auto",
) -> Any:
    """Functional form of the `BandEnergy` transform.

    Obtains a `BandEnergy` plan from `_plan` configured with the keyword
    parameters below and applies it to `y`.

    Args:
        y: Input handed to the transform unchanged.
        sr, n_fft, hop_length, win_length, window, center, pad_mode,
        n_bands, fmin:
            Forwarded as-is to `BandEnergy` through `_plan`.
        backend: Passed to the transform call to select the execution backend.

    Returns:
        Whatever the planned `BandEnergy` transform returns for `y`.
    """
    plan_kwargs = {
        "sr": sr,
        "n_fft": n_fft,
        "hop_length": hop_length,
        "win_length": win_length,
        "window": window,
        "center": center,
        "pad_mode": pad_mode,
        "n_bands": n_bands,
        "fmin": fmin,
    }
    transform = _plan(BandEnergy, **plan_kwargs)
    return transform(y, backend=backend)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def spectral_contrast(
    y: Any,
    *,
    sr: float,
    n_fft: int = 2048,
    hop_length: int | None = None,
    win_length: int | None = None,
    window: str = "hann",
    center: bool = True,
    pad_mode: str = "constant",
    n_bands: int = 6,
    fmin: float = 200.0,
    quantile: float = 0.02,
    linear: bool = False,
    backend: str = "auto",
) -> Any:
    """Functional form of the `SpectralContrast` transform.

    Obtains a `SpectralContrast` plan from `_plan` configured with the keyword
    parameters below and applies it to `y`.

    Args:
        y: Input handed to the transform unchanged.
        sr, n_fft, hop_length, win_length, window, center, pad_mode,
        n_bands, fmin, quantile, linear:
            Forwarded as-is to `SpectralContrast` through `_plan`.
        backend: Passed to the transform call to select the execution backend.

    Returns:
        Whatever the planned `SpectralContrast` transform returns for `y`.
    """
    plan_kwargs = {
        "sr": sr,
        "n_fft": n_fft,
        "hop_length": hop_length,
        "win_length": win_length,
        "window": window,
        "center": center,
        "pad_mode": pad_mode,
        "n_bands": n_bands,
        "fmin": fmin,
        "quantile": quantile,
        "linear": linear,
    }
    transform = _plan(SpectralContrast, **plan_kwargs)
    return transform(y, backend=backend)
|
|
392
|
+
|
|
393
|
+
|
|
279
394
|
def _one_moment(feature: str, y: Any, **kw: Any) -> Any:
    """Run `spectral_moments` on `y` for exactly one named feature.

    `feature` is wrapped in a one-element tuple and passed as `features`;
    every other keyword argument is forwarded untouched.
    """
    selected = (feature,)
    return spectral_moments(y, features=selected, **kw)
|
|
281
396
|
|
|
@@ -305,6 +420,26 @@ def spectral_flatness(y: Any, **kw: Any) -> Any:
|
|
|
305
420
|
return _one_moment("flatness", y, **kw)
|
|
306
421
|
|
|
307
422
|
|
|
423
|
+
def spectral_slope(y: Any, **kw: Any) -> Any:
    """Per-frame regression slope of the magnitude spectrum on frequency.

    Keyword arguments are forwarded to the single-feature moments helper;
    the result has shape (..., 1, n_frames).
    """
    return _one_moment("slope", y=y, **kw)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def spectral_decrease(y: Any, **kw: Any) -> Any:
    """Per-frame MPEG-7 spectral decrease.

    Keyword arguments are forwarded to the single-feature moments helper;
    the result has shape (..., 1, n_frames).
    """
    return _one_moment("decrease", y=y, **kw)
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def spectral_crest(y: Any, **kw: Any) -> Any:
    """Per-frame spectral crest, the peak-to-mean ratio of the magnitude.

    Keyword arguments are forwarded to the single-feature moments helper;
    the result has shape (..., 1, n_frames).
    """
    return _one_moment("crest", y=y, **kw)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def spectral_entropy(y: Any, **kw: Any) -> Any:
    """Per-frame spectral entropy, normalized to the range [0, 1].

    Keyword arguments are forwarded to the single-feature moments helper;
    the result has shape (..., 1, n_frames).
    """
    return _one_moment("entropy", y=y, **kw)
|
|
441
|
+
|
|
442
|
+
|
|
308
443
|
def rms(
|
|
309
444
|
y: Any,
|
|
310
445
|
*,
|
|
@@ -421,3 +556,40 @@ def tempogram(
|
|
|
421
556
|
mel_norm=mel_norm,
|
|
422
557
|
)
|
|
423
558
|
return t(y, backend=backend)
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
def fourier_tempogram(
    y: Any,
    *,
    sr: float,
    win_length: int = 256,
    n_fft: int = 2048,
    hop_length: int | None = None,
    stft_win_length: int | None = None,
    window: str = "hann",
    center: bool = True,
    pad_mode: str = "constant",
    n_mels: int = 128,
    fmin: float = 0.0,
    fmax: float | None = None,
    mel_scale: str = "slaney",
    mel_norm: str | None = "slaney",
    backend: str = "auto",
) -> Any:
    """Functional form of the `FourierTempogram` transform.

    Obtains a `FourierTempogram` plan from `_plan` configured with the keyword
    parameters below and applies it to `y`.

    Args:
        y: Input handed to the transform unchanged.
        sr, win_length, n_fft, hop_length, stft_win_length, window, center,
        pad_mode, n_mels, fmin, fmax, mel_scale, mel_norm:
            Forwarded as-is to `FourierTempogram` through `_plan`.
        backend: Passed to the transform call to select the execution backend.

    Returns:
        Whatever the planned `FourierTempogram` transform returns for `y`.
    """
    plan_kwargs = {
        "sr": sr,
        "win_length": win_length,
        "n_fft": n_fft,
        "hop_length": hop_length,
        "stft_win_length": stft_win_length,
        "window": window,
        "center": center,
        "pad_mode": pad_mode,
        "n_mels": n_mels,
        "fmin": fmin,
        "fmax": fmax,
        "mel_scale": mel_scale,
        "mel_norm": mel_norm,
    }
    transform = _plan(FourierTempogram, **plan_kwargs)
    return transform(y, backend=backend)
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""Post-processing over feature matrices: delta (derivative) features and
|
|
2
|
+
temporal pooling. These are fixed linear / reduction ops along the time axis,
|
|
3
|
+
so they apply to any (..., t) feature and stay differentiable: numpy in ->
|
|
4
|
+
numpy out, torch in -> torch out with grad."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from functools import lru_cache
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from auvux.dsp import _dispatch
|
|
14
|
+
|
|
15
|
+
_POOL_STATS = ("mean", "std", "var", "max", "min", "median")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@lru_cache(maxsize=32)
def _savgol_matrix(n: int, width: int, order: int) -> np.ndarray:
    """Build the dense (n, n) Savitzky-Golay derivative operator.

    Row ``i`` holds the filter taps that produce output frame ``i``, so for
    data whose last axis has length ``n`` the filtered result is
    ``data @ M.T``. The filter fits a degree-``order`` polynomial to ``width``
    consecutive samples and takes its ``order``-th derivative
    (deriv == polyorder == order); edges follow scipy's ``mode='interp'``
    convention. Built with numpy alone.

    Args:
        n: Number of frames along the filtered axis.
        width: Odd window length with ``1 <= width <= n``.
        order: Derivative order (also the polynomial degree),
            ``0 <= order < width``.

    Returns:
        A float32 array of shape ``(n, n)``. The array is shared through the
        ``lru_cache``, so callers must not modify it in place.

    Raises:
        ValueError: If the arguments violate the constraints above.
    """
    if width < 1 or width % 2 == 0:
        raise ValueError(f"width must be a positive odd integer, got {width}")
    if width > n:
        raise ValueError(f"width ({width}) exceeds the number of frames ({n})")
    if not 0 <= order < width:
        raise ValueError(f"order must satisfy 0 <= order < width, got {order}")

    half = width // 2

    # Least-squares fit of a degree-`order` polynomial on a window centred at
    # x = 0. Its order-th derivative is the constant order! * c_order, i.e. a
    # single length-`width` kernel (row `order` of the pseudo-inverse).
    offsets = np.arange(width) - half
    design = np.vander(offsets, order + 1, increasing=True)
    factorial = float(np.prod(np.arange(1, order + 1)))  # empty product == 1
    kernel = factorial * np.linalg.pinv(design)[order]

    # The leading coefficient of a least-squares polynomial does not depend on
    # where the x origin is placed, so the 'interp' edge rows use this same
    # kernel; only the window position differs. Interior rows centre the
    # window on frame i, edge rows clamp it to the first / last `width` frames.
    starts = np.clip(np.arange(n) - half, 0, n - width)
    M = np.zeros((n, n), dtype=np.float64)
    M[np.arange(n)[:, None], starts[:, None] + np.arange(width)] = kernel
    return M.astype(np.float32)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def delta(data: Any, *, width: int = 9, order: int = 1, axis: int = -1) -> Any:
    """Savitzky-Golay derivative features along `axis`.

    Follows librosa.feature.delta with mode 'interp': a fixed linear filter of
    the requested order. order=1 is the delta, order=2 the delta-delta.
    numpy in -> numpy out, torch in -> torch out.

    Args:
        data: Feature array or tensor; must expose `.shape`.
        width: Odd window length, at least 3 and no longer than the axis.
        order: Derivative order, ``1 <= order < width``.
        axis: Axis to differentiate along (time by default).

    Returns:
        An array/tensor shaped like `data`. Floating-point and complex inputs
        keep their dtype. Integer and boolean inputs are returned in floating
        point, because casting a derivative back to an integer dtype would
        truncate it.

    Raises:
        ValueError: If `width` or `order` is invalid, or `width` exceeds the
            number of frames along `axis`.
    """
    if width < 3 or width % 2 == 0:
        raise ValueError(f"width must be an odd integer >= 3, got {width}")
    if order < 1 or order >= width:
        raise ValueError(f"order must satisfy 1 <= order < width, got {order}")
    n = data.shape[axis]
    if width > n:
        raise ValueError(f"width ({width}) exceeds the {n} frames along axis {axis}")
    M = _savgol_matrix(int(n), int(width), int(order))  # (n, n), out = data @ M.T
    if _dispatch.array_module(data) == "torch":
        import torch

        # The filter taps are fractional: casting them to an integer dtype
        # would zero them out, so non-float tensors are computed in float32
        # (the precision the operator is built in).
        if data.is_floating_point() or data.is_complex():
            work_dtype = data.dtype
        else:
            work_dtype = torch.float32
        Mt = torch.from_numpy(M).to(device=data.device, dtype=work_dtype)
        moved = data.movedim(axis, -1).to(work_dtype)
        return torch.matmul(moved, Mt.transpose(0, 1)).movedim(-1, axis)
    moved = np.moveaxis(data, axis, -1)
    out = np.moveaxis(moved @ M.T, -1, axis)
    if np.issubdtype(data.dtype, np.inexact):
        # Float / complex input: hand back the caller's dtype.
        return out.astype(data.dtype, copy=False)
    # Integer / bool input: keep numpy's promoted float result.
    return out
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def pool(data: Any, *, stat: str = "mean", axis: int = -1) -> Any:
    """Collapse `axis` of a feature matrix into one summary statistic.

    Supported statistics are mean, std, var, max, min and median. torch
    inputs are reduced with torch ops and therefore stay differentiable
    (max/min/median use the standard subgradient); everything else goes
    through numpy. std/var are the population versions (unbiased=False).
    For an even-length axis the median follows each backend's own convention:
    numpy averages the two central values, torch returns the lower one.

    Raises:
        ValueError: If `stat` is not one of the supported names.
    """
    if stat not in _POOL_STATS:
        raise ValueError(f"stat must be one of {_POOL_STATS}; got {stat!r}")

    if _dispatch.array_module(data) == "torch":
        import torch

        torch_reducers = {
            "mean": lambda t: t.mean(dim=axis),
            "std": lambda t: t.std(dim=axis, unbiased=False),
            "var": lambda t: t.var(dim=axis, unbiased=False),
            "max": lambda t: t.amax(dim=axis),
            "min": lambda t: t.amin(dim=axis),
            "median": lambda t: torch.median(t, dim=axis).values,
        }
        return torch_reducers[stat](data)

    numpy_reducers = {
        "mean": np.mean,
        "std": np.std,
        "var": np.var,
        "max": np.max,
        "min": np.min,
        "median": np.median,
    }
    return numpy_reducers[stat](data, axis=axis)
|