mlx-spectro 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_spectro-0.2.0/.github/workflows/ci.yml +23 -0
- mlx_spectro-0.2.0/.github/workflows/release-pypi.yml +94 -0
- mlx_spectro-0.2.0/.github/workflows/release-testpypi.yml +98 -0
- mlx_spectro-0.2.0/.gitignore +12 -0
- mlx_spectro-0.2.0/LICENSE +21 -0
- mlx_spectro-0.2.0/PKG-INFO +190 -0
- mlx_spectro-0.2.0/README.md +154 -0
- mlx_spectro-0.2.0/pyproject.toml +47 -0
- mlx_spectro-0.2.0/scripts/benchmark.py +586 -0
- mlx_spectro-0.2.0/src/mlx_spectro/__init__.py +29 -0
- mlx_spectro-0.2.0/src/mlx_spectro/py.typed +0 -0
- mlx_spectro-0.2.0/src/mlx_spectro/spectral_ops.py +2857 -0
- mlx_spectro-0.2.0/tests/__init__.py +0 -0
- mlx_spectro-0.2.0/tests/test_spectral_ops.py +765 -0
- mlx_spectro-0.2.0/uv.lock +841 -0
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [main]
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
test:
|
|
11
|
+
runs-on: macos-latest
|
|
12
|
+
strategy:
|
|
13
|
+
matrix:
|
|
14
|
+
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
|
15
|
+
steps:
|
|
16
|
+
- uses: actions/checkout@v4
|
|
17
|
+
- uses: actions/setup-python@v5
|
|
18
|
+
with:
|
|
19
|
+
python-version: ${{ matrix.python-version }}
|
|
20
|
+
- name: Install
|
|
21
|
+
run: pip install -e ".[dev]"
|
|
22
|
+
- name: Test
|
|
23
|
+
run: pytest -v
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
name: Release (PyPI)
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
workflow_dispatch:
|
|
5
|
+
inputs:
|
|
6
|
+
version:
|
|
7
|
+
description: "Version to publish (for example: 0.2.0)"
|
|
8
|
+
required: true
|
|
9
|
+
type: string
|
|
10
|
+
|
|
11
|
+
jobs:
|
|
12
|
+
build-artifacts:
|
|
13
|
+
name: Build release artifacts
|
|
14
|
+
runs-on: ubuntu-24.04
|
|
15
|
+
steps:
|
|
16
|
+
- name: Checkout
|
|
17
|
+
uses: actions/checkout@v4
|
|
18
|
+
|
|
19
|
+
- name: Setup Python
|
|
20
|
+
uses: actions/setup-python@v5
|
|
21
|
+
with:
|
|
22
|
+
python-version: "3.13"
|
|
23
|
+
|
|
24
|
+
- name: Install packaging tools
|
|
25
|
+
run: python -m pip install --upgrade pip build twine
|
|
26
|
+
|
|
27
|
+
- name: Build sdist and wheel
|
|
28
|
+
run: python -m build
|
|
29
|
+
|
|
30
|
+
- name: Validate metadata
|
|
31
|
+
run: python -m twine check dist/*
|
|
32
|
+
|
|
33
|
+
- name: Upload dist artifacts
|
|
34
|
+
uses: actions/upload-artifact@v4
|
|
35
|
+
with:
|
|
36
|
+
name: dist-artifacts
|
|
37
|
+
path: dist/*
|
|
38
|
+
|
|
39
|
+
publish-pypi:
|
|
40
|
+
name: Publish to PyPI
|
|
41
|
+
runs-on: ubuntu-24.04
|
|
42
|
+
needs: [build-artifacts]
|
|
43
|
+
environment: pypi
|
|
44
|
+
permissions:
|
|
45
|
+
id-token: write
|
|
46
|
+
steps:
|
|
47
|
+
- name: Download dist artifacts
|
|
48
|
+
uses: actions/download-artifact@v4
|
|
49
|
+
with:
|
|
50
|
+
name: dist-artifacts
|
|
51
|
+
path: dist
|
|
52
|
+
|
|
53
|
+
- name: Publish package distributions to PyPI
|
|
54
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
55
|
+
with:
|
|
56
|
+
attestations: true
|
|
57
|
+
|
|
58
|
+
smoke-macos:
|
|
59
|
+
name: Smoke install from PyPI (macOS)
|
|
60
|
+
runs-on: macos-14
|
|
61
|
+
needs: [publish-pypi]
|
|
62
|
+
env:
|
|
63
|
+
PACKAGE_VERSION: ${{ github.event.inputs.version }}
|
|
64
|
+
steps:
|
|
65
|
+
- name: Setup Python
|
|
66
|
+
uses: actions/setup-python@v5
|
|
67
|
+
with:
|
|
68
|
+
python-version: "3.13"
|
|
69
|
+
|
|
70
|
+
- name: Install from PyPI
|
|
71
|
+
run: |
|
|
72
|
+
python -m pip install --upgrade pip
|
|
73
|
+
for i in {1..12}; do
|
|
74
|
+
python -m pip install "mlx-spectro==${PACKAGE_VERSION}" && break
|
|
75
|
+
if [ "$i" -eq 12 ]; then
|
|
76
|
+
exit 1
|
|
77
|
+
fi
|
|
78
|
+
sleep 10
|
|
79
|
+
done
|
|
80
|
+
|
|
81
|
+
- name: Smoke test
|
|
82
|
+
run: |
|
|
83
|
+
python -c "
|
|
84
|
+
import mlx.core as mx
|
|
85
|
+
from mlx_spectro import SpectralTransform, make_window, resolve_fft_params, get_transform_mlx
|
|
86
|
+
t = SpectralTransform(1024, 256)
|
|
87
|
+
x = mx.zeros((1, 4096), dtype=mx.float32)
|
|
88
|
+
z = t.stft(x, output_layout='bnf')
|
|
89
|
+
y = t.istft(z, length=4096, input_layout='bnf')
|
|
90
|
+
mx.eval(z, y)
|
|
91
|
+
assert z.shape == (1, 17, 513), f'unexpected stft shape: {z.shape}'  # center=True -> 1 + 4096 // 256 frames
|
|
92
|
+
assert y.shape == (1, 4096), f'unexpected istft shape: {y.shape}'
|
|
93
|
+
print('mlx-spectro ${PACKAGE_VERSION} smoke test passed')  # version expanded by bash, not Python
|
|
94
|
+
"
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
name: Release RC (TestPyPI)
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
workflow_dispatch:
|
|
5
|
+
inputs:
|
|
6
|
+
version:
|
|
7
|
+
description: "Version to publish and verify (for example: 0.2.0rc1)"
|
|
8
|
+
required: true
|
|
9
|
+
type: string
|
|
10
|
+
|
|
11
|
+
jobs:
|
|
12
|
+
build-artifacts:
|
|
13
|
+
name: Build RC artifacts
|
|
14
|
+
runs-on: ubuntu-24.04
|
|
15
|
+
steps:
|
|
16
|
+
- name: Checkout
|
|
17
|
+
uses: actions/checkout@v4
|
|
18
|
+
|
|
19
|
+
- name: Setup Python
|
|
20
|
+
uses: actions/setup-python@v5
|
|
21
|
+
with:
|
|
22
|
+
python-version: "3.13"
|
|
23
|
+
|
|
24
|
+
- name: Install packaging tools
|
|
25
|
+
run: python -m pip install --upgrade pip build twine
|
|
26
|
+
|
|
27
|
+
- name: Build sdist and wheel
|
|
28
|
+
run: python -m build
|
|
29
|
+
|
|
30
|
+
- name: Validate metadata
|
|
31
|
+
run: python -m twine check dist/*
|
|
32
|
+
|
|
33
|
+
- name: Upload dist artifacts
|
|
34
|
+
uses: actions/upload-artifact@v4
|
|
35
|
+
with:
|
|
36
|
+
name: dist-artifacts
|
|
37
|
+
path: dist/*
|
|
38
|
+
|
|
39
|
+
publish-testpypi:
|
|
40
|
+
name: Publish to TestPyPI
|
|
41
|
+
runs-on: ubuntu-24.04
|
|
42
|
+
needs: [build-artifacts]
|
|
43
|
+
environment: testpypi
|
|
44
|
+
permissions:
|
|
45
|
+
id-token: write
|
|
46
|
+
steps:
|
|
47
|
+
- name: Download dist artifacts
|
|
48
|
+
uses: actions/download-artifact@v4
|
|
49
|
+
with:
|
|
50
|
+
name: dist-artifacts
|
|
51
|
+
path: dist
|
|
52
|
+
|
|
53
|
+
- name: Publish package distributions to TestPyPI
|
|
54
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
55
|
+
with:
|
|
56
|
+
repository-url: https://test.pypi.org/legacy/
|
|
57
|
+
attestations: false
|
|
58
|
+
|
|
59
|
+
smoke-macos:
|
|
60
|
+
name: Smoke install from TestPyPI (macOS)
|
|
61
|
+
runs-on: macos-14
|
|
62
|
+
needs: [publish-testpypi]
|
|
63
|
+
env:
|
|
64
|
+
PACKAGE_VERSION: ${{ github.event.inputs.version }}
|
|
65
|
+
steps:
|
|
66
|
+
- name: Setup Python
|
|
67
|
+
uses: actions/setup-python@v5
|
|
68
|
+
with:
|
|
69
|
+
python-version: "3.13"
|
|
70
|
+
|
|
71
|
+
- name: Install from TestPyPI
|
|
72
|
+
run: |
|
|
73
|
+
python -m pip install --upgrade pip
|
|
74
|
+
for i in {1..12}; do
|
|
75
|
+
python -m pip install \
|
|
76
|
+
--index-url https://test.pypi.org/simple/ \
|
|
77
|
+
--extra-index-url https://pypi.org/simple \
|
|
78
|
+
"mlx-spectro==${PACKAGE_VERSION}" && break
|
|
79
|
+
if [ "$i" -eq 12 ]; then
|
|
80
|
+
exit 1
|
|
81
|
+
fi
|
|
82
|
+
sleep 10
|
|
83
|
+
done
|
|
84
|
+
|
|
85
|
+
- name: Smoke test
|
|
86
|
+
run: |
|
|
87
|
+
python -c "
|
|
88
|
+
import mlx.core as mx
|
|
89
|
+
from mlx_spectro import SpectralTransform, make_window, resolve_fft_params, get_transform_mlx
|
|
90
|
+
t = SpectralTransform(1024, 256)
|
|
91
|
+
x = mx.zeros((1, 4096), dtype=mx.float32)
|
|
92
|
+
z = t.stft(x, output_layout='bnf')
|
|
93
|
+
y = t.istft(z, length=4096, input_layout='bnf')
|
|
94
|
+
mx.eval(z, y)
|
|
95
|
+
assert z.shape == (1, 17, 513), f'unexpected stft shape: {z.shape}'  # center=True -> 1 + 4096 // 256 frames
|
|
96
|
+
assert y.shape == (1, 4096), f'unexpected istft shape: {y.shape}'
|
|
97
|
+
print('mlx-spectro ${PACKAGE_VERSION} smoke test passed')  # version expanded by bash, not Python
|
|
98
|
+
"
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 ssmall256
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mlx-spectro
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: High-performance STFT/iSTFT for Apple MLX with fused Metal kernels
|
|
5
|
+
Project-URL: Homepage, https://github.com/ssmall256/mlx-spectro
|
|
6
|
+
Project-URL: Repository, https://github.com/ssmall256/mlx-spectro
|
|
7
|
+
Project-URL: Issues, https://github.com/ssmall256/mlx-spectro/issues
|
|
8
|
+
Author-email: Sam <ssmall256@users.noreply.github.com>
|
|
9
|
+
License-Expression: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Keywords: apple-silicon,audio,dsp,istft,mlx,spectral,stft
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Operating System :: MacOS
|
|
17
|
+
Classifier: Programming Language :: Python :: 3
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
22
|
+
Classifier: Topic :: Multimedia :: Sound/Audio :: Analysis
|
|
23
|
+
Classifier: Topic :: Scientific/Engineering
|
|
24
|
+
Requires-Python: >=3.10
|
|
25
|
+
Requires-Dist: mlx>=0.5.0
|
|
26
|
+
Requires-Dist: numpy
|
|
27
|
+
Provides-Extra: benchmark
|
|
28
|
+
Requires-Dist: mlx-stft; extra == 'benchmark'
|
|
29
|
+
Requires-Dist: torch>=2.0; extra == 'benchmark'
|
|
30
|
+
Provides-Extra: dev
|
|
31
|
+
Requires-Dist: build; extra == 'dev'
|
|
32
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
33
|
+
Provides-Extra: torch
|
|
34
|
+
Requires-Dist: torch>=2.0; extra == 'torch'
|
|
35
|
+
Description-Content-Type: text/markdown
|
|
36
|
+
|
|
37
|
+
# mlx-spectro
|
|
38
|
+
|
|
39
|
+
High-performance STFT/iSTFT for [Apple MLX](https://github.com/ml-explore/mlx) with fused Metal kernels.
|
|
40
|
+
|
|
41
|
+
- Fused overlap-add with autotuned Metal kernels
|
|
42
|
+
- PyTorch-compatible STFT/iSTFT semantics
|
|
43
|
+
- Cached transforms for zero-overhead repeated calls
|
|
44
|
+
- Optional torch fallback for strict numerical parity
|
|
45
|
+
|
|
46
|
+
## Install
|
|
47
|
+
|
|
48
|
+
```bash
|
|
49
|
+
pip install mlx-spectro
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
With optional torch fallback support:
|
|
53
|
+
|
|
54
|
+
```bash
|
|
55
|
+
pip install "mlx-spectro[torch]"
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
## Quick Start
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
import mlx.core as mx
|
|
62
|
+
from mlx_spectro import SpectralTransform
|
|
63
|
+
|
|
64
|
+
# Create a transform
|
|
65
|
+
transform = SpectralTransform(
|
|
66
|
+
n_fft=2048,
|
|
67
|
+
hop_length=512,
|
|
68
|
+
window_fn="hann",
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Forward STFT
|
|
72
|
+
audio = mx.random.normal((1, 44100))
|
|
73
|
+
spec = transform.stft(audio, output_layout="bnf")
|
|
74
|
+
|
|
75
|
+
# Inverse STFT
|
|
76
|
+
reconstructed = transform.istft(spec, length=44100, input_layout="bnf")
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
## API
|
|
80
|
+
|
|
81
|
+
### `SpectralTransform`
|
|
82
|
+
|
|
83
|
+
Main class for STFT/iSTFT operations.
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
SpectralTransform(
|
|
87
|
+
n_fft: int,
|
|
88
|
+
hop_length: int,
|
|
89
|
+
win_length: int | None = None,
|
|
90
|
+
window_fn: str = "hann", # "hann", "hamming", "rect"
|
|
91
|
+
window: mx.array | None = None, # custom window array
|
|
92
|
+
periodic: bool = True,
|
|
93
|
+
center: bool = True,
|
|
94
|
+
normalized: bool = False,
|
|
95
|
+
istft_backend_policy: str | None = None, # "auto", "mlx_fft", "metal", "torch_fallback"
|
|
96
|
+
)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
**Methods:**
|
|
100
|
+
- `stft(x, output_layout="bfn")` — Forward STFT. Input: `[T]` or `[B, T]`.
|
|
101
|
+
- `istft(z, length=None, ...)` — Inverse STFT. Returns `[B, T]`.
|
|
102
|
+
- `compiled_pair(length, layout="bnf", warmup_batch=None)` — Return compiled `(stft_fn, istft_fn)` for steady-state loops (10–20% faster).
|
|
103
|
+
- `warmup(batch=1, length=4096)` — Force kernel compilation.
|
|
104
|
+
|
|
105
|
+
### `get_transform_mlx(**kwargs)`
|
|
106
|
+
|
|
107
|
+
Factory that returns cached `SpectralTransform` instances for repeated use.
|
|
108
|
+
|
|
109
|
+
### `make_window(window, window_fn, win_length, n_fft, periodic)`
|
|
110
|
+
|
|
111
|
+
Create or validate a 1D analysis window.
|
|
112
|
+
|
|
113
|
+
### `resolve_fft_params(n_fft, hop_length, win_length, pad)`
|
|
114
|
+
|
|
115
|
+
Resolve effective FFT parameters with PyTorch-compatible defaults.
|
|
116
|
+
|
|
117
|
+
## Benchmarks
|
|
118
|
+
|
|
119
|
+
Apple M4 Max, macOS 26.3, MLX 0.30.6, PyTorch 2.10.0, 20 iterations (5 warmup).
|
|
120
|
+
|
|
121
|
+
### STFT Forward
|
|
122
|
+
|
|
123
|
+
| Config | mlx-spectro | torch MPS | mlx-stft | vs torch | vs mlx-stft |
|
|
124
|
+
|---|---|---|---|---|---|
|
|
125
|
+
| B=1 T=16k nfft=512 | 0.16 ms | 0.21 ms | 0.31 ms | 1.4x | 1.9x |
|
|
126
|
+
| B=4 T=160k nfft=1024 | 0.37 ms | 0.78 ms | 1.09 ms | **2.1x** | **3.0x** |
|
|
127
|
+
| B=8 T=160k nfft=1024 | 0.28 ms | 0.68 ms | 1.53 ms | **2.5x** | **5.6x** |
|
|
128
|
+
| B=4 T=1.3M nfft=1024 | 0.79 ms | 1.71 ms | 5.03 ms | **2.2x** | **6.3x** |
|
|
129
|
+
| B=8 T=480k nfft=1024 | 0.58 ms | 1.30 ms | 3.73 ms | **2.2x** | **6.4x** |
|
|
130
|
+
|
|
131
|
+
### iSTFT Forward Pass
|
|
132
|
+
|
|
133
|
+
| Config | mlx-spectro | torch MPS | mlx-stft | vs torch | vs mlx-stft |
|
|
134
|
+
|---|---|---|---|---|---|
|
|
135
|
+
| B=1 T=16k nfft=512 | 0.17 ms | 0.49 ms | 0.25 ms | 3.0x | 1.5x |
|
|
136
|
+
| B=4 T=160k nfft=1024 | 0.21 ms | 0.99 ms | 0.98 ms | **4.8x** | **4.7x** |
|
|
137
|
+
| B=8 T=160k nfft=1024 | 0.29 ms | 1.58 ms | 1.62 ms | **5.4x** | **5.6x** |
|
|
138
|
+
| B=4 T=1.3M nfft=1024 | 0.77 ms | 5.74 ms | 6.68 ms | **7.5x** | **8.7x** |
|
|
139
|
+
| B=8 T=480k nfft=1024 | 0.60 ms | 4.10 ms | 4.55 ms | **6.8x** | **7.6x** |
|
|
140
|
+
|
|
141
|
+
### Differentiable STFT + iSTFT (Forward + Backward)
|
|
142
|
+
|
|
143
|
+
| Config | mlx-spectro | torch MPS | vs torch |
|
|
144
|
+
|---|---|---|---|
|
|
145
|
+
| B=1 T=16k nfft=512 | 0.32 ms | 0.97 ms | **3.0x** |
|
|
146
|
+
| B=4 T=160k nfft=1024 | 0.61 ms | 2.28 ms | **3.7x** |
|
|
147
|
+
| B=8 T=160k nfft=1024 | 1.05 ms | 4.33 ms | **4.1x** |
|
|
148
|
+
| B=4 T=1.3M nfft=1024 | 4.30 ms | 17.44 ms | **4.1x** |
|
|
149
|
+
| B=8 T=480k nfft=1024 | 3.01 ms | 12.53 ms | **4.2x** |
|
|
150
|
+
|
|
151
|
+
### Roundtrip Accuracy (STFT → iSTFT max abs error)
|
|
152
|
+
|
|
153
|
+
| Config | mlx-spectro | torch MPS |
|
|
154
|
+
|---|---|---|
|
|
155
|
+
| B=1 T=16k nfft=512 | 1.67e-06 | 2.38e-06 |
|
|
156
|
+
| B=4 T=160k nfft=2048 | 2.86e-06 | 5.25e-06 |
|
|
157
|
+
| B=8 T=480k nfft=1024 | 3.81e-06 | 4.77e-06 |
|
|
158
|
+
|
|
159
|
+
### Compiled Mode
|
|
160
|
+
|
|
161
|
+
For tight inference loops with fixed input shapes, `compiled_pair` eliminates
|
|
162
|
+
per-call Python dispatch overhead (10–20% faster for small workloads):
|
|
163
|
+
|
|
164
|
+
```python
|
|
165
|
+
t = SpectralTransform(n_fft=1024, hop_length=256, window_fn="hann")
|
|
166
|
+
stft, istft = t.compiled_pair(length=44100, warmup_batch=2)
|
|
167
|
+
|
|
168
|
+
for chunk in audio_stream:
|
|
169
|
+
z = stft(chunk)
|
|
170
|
+
z = process(z)
|
|
171
|
+
y = istft(z)
|
|
172
|
+
mx.eval(y)
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
Use the eager `t.stft()` / `t.istft()` methods when input shapes vary.
|
|
176
|
+
|
|
177
|
+
## Environment Variables
|
|
178
|
+
|
|
179
|
+
| Variable | Default | Description |
|
|
180
|
+
|---|---|---|
|
|
181
|
+
| `SPEC_MLX_AUTOTUNE` | `1` | Enable Metal kernel autotuning |
|
|
182
|
+
| `SPEC_MLX_TGX` | — | Force threadgroup size (e.g. `256` or `kernel:256`) |
|
|
183
|
+
| `SPEC_MLX_AUTOTUNE_PERSIST` | `1` | Persist autotune results to disk |
|
|
184
|
+
| `SPEC_MLX_AUTOTUNE_CACHE_PATH` | — | Override autotune cache file path |
|
|
185
|
+
| `MLX_OLA_FUSE_NORM` | `1` | Enable fused OLA+normalization kernel |
|
|
186
|
+
| `SPEC_MLX_CACHE_STATS` | `0` | Enable cache debug counters |
|
|
187
|
+
|
|
188
|
+
## License
|
|
189
|
+
|
|
190
|
+
MIT
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
# mlx-spectro
|
|
2
|
+
|
|
3
|
+
High-performance STFT/iSTFT for [Apple MLX](https://github.com/ml-explore/mlx) with fused Metal kernels.
|
|
4
|
+
|
|
5
|
+
- Fused overlap-add with autotuned Metal kernels
|
|
6
|
+
- PyTorch-compatible STFT/iSTFT semantics
|
|
7
|
+
- Cached transforms for zero-overhead repeated calls
|
|
8
|
+
- Optional torch fallback for strict numerical parity
|
|
9
|
+
|
|
10
|
+
## Install
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
pip install mlx-spectro
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
With optional torch fallback support:
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
pip install "mlx-spectro[torch]"
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## Quick Start
|
|
23
|
+
|
|
24
|
+
```python
|
|
25
|
+
import mlx.core as mx
|
|
26
|
+
from mlx_spectro import SpectralTransform
|
|
27
|
+
|
|
28
|
+
# Create a transform
|
|
29
|
+
transform = SpectralTransform(
|
|
30
|
+
n_fft=2048,
|
|
31
|
+
hop_length=512,
|
|
32
|
+
window_fn="hann",
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Forward STFT
|
|
36
|
+
audio = mx.random.normal((1, 44100))
|
|
37
|
+
spec = transform.stft(audio, output_layout="bnf")
|
|
38
|
+
|
|
39
|
+
# Inverse STFT
|
|
40
|
+
reconstructed = transform.istft(spec, length=44100, input_layout="bnf")
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
## API
|
|
44
|
+
|
|
45
|
+
### `SpectralTransform`
|
|
46
|
+
|
|
47
|
+
Main class for STFT/iSTFT operations.
|
|
48
|
+
|
|
49
|
+
```python
|
|
50
|
+
SpectralTransform(
|
|
51
|
+
n_fft: int,
|
|
52
|
+
hop_length: int,
|
|
53
|
+
win_length: int | None = None,
|
|
54
|
+
window_fn: str = "hann", # "hann", "hamming", "rect"
|
|
55
|
+
window: mx.array | None = None, # custom window array
|
|
56
|
+
periodic: bool = True,
|
|
57
|
+
center: bool = True,
|
|
58
|
+
normalized: bool = False,
|
|
59
|
+
istft_backend_policy: str | None = None, # "auto", "mlx_fft", "metal", "torch_fallback"
|
|
60
|
+
)
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
**Methods:**
|
|
64
|
+
- `stft(x, output_layout="bfn")` — Forward STFT. Input: `[T]` or `[B, T]`.
|
|
65
|
+
- `istft(z, length=None, ...)` — Inverse STFT. Returns `[B, T]`.
|
|
66
|
+
- `compiled_pair(length, layout="bnf", warmup_batch=None)` — Return compiled `(stft_fn, istft_fn)` for steady-state loops (10–20% faster).
|
|
67
|
+
- `warmup(batch=1, length=4096)` — Force kernel compilation.
|
|
68
|
+
|
|
69
|
+
### `get_transform_mlx(**kwargs)`
|
|
70
|
+
|
|
71
|
+
Factory that returns cached `SpectralTransform` instances for repeated use.
|
|
72
|
+
|
|
73
|
+
### `make_window(window, window_fn, win_length, n_fft, periodic)`
|
|
74
|
+
|
|
75
|
+
Create or validate a 1D analysis window.
|
|
76
|
+
|
|
77
|
+
### `resolve_fft_params(n_fft, hop_length, win_length, pad)`
|
|
78
|
+
|
|
79
|
+
Resolve effective FFT parameters with PyTorch-compatible defaults.
|
|
80
|
+
|
|
81
|
+
## Benchmarks
|
|
82
|
+
|
|
83
|
+
Apple M4 Max, macOS 26.3, MLX 0.30.6, PyTorch 2.10.0, 20 iterations (5 warmup).
|
|
84
|
+
|
|
85
|
+
### STFT Forward
|
|
86
|
+
|
|
87
|
+
| Config | mlx-spectro | torch MPS | mlx-stft | vs torch | vs mlx-stft |
|
|
88
|
+
|---|---|---|---|---|---|
|
|
89
|
+
| B=1 T=16k nfft=512 | 0.16 ms | 0.21 ms | 0.31 ms | 1.4x | 1.9x |
|
|
90
|
+
| B=4 T=160k nfft=1024 | 0.37 ms | 0.78 ms | 1.09 ms | **2.1x** | **3.0x** |
|
|
91
|
+
| B=8 T=160k nfft=1024 | 0.28 ms | 0.68 ms | 1.53 ms | **2.5x** | **5.6x** |
|
|
92
|
+
| B=4 T=1.3M nfft=1024 | 0.79 ms | 1.71 ms | 5.03 ms | **2.2x** | **6.3x** |
|
|
93
|
+
| B=8 T=480k nfft=1024 | 0.58 ms | 1.30 ms | 3.73 ms | **2.2x** | **6.4x** |
|
|
94
|
+
|
|
95
|
+
### iSTFT Forward Pass
|
|
96
|
+
|
|
97
|
+
| Config | mlx-spectro | torch MPS | mlx-stft | vs torch | vs mlx-stft |
|
|
98
|
+
|---|---|---|---|---|---|
|
|
99
|
+
| B=1 T=16k nfft=512 | 0.17 ms | 0.49 ms | 0.25 ms | 3.0x | 1.5x |
|
|
100
|
+
| B=4 T=160k nfft=1024 | 0.21 ms | 0.99 ms | 0.98 ms | **4.8x** | **4.7x** |
|
|
101
|
+
| B=8 T=160k nfft=1024 | 0.29 ms | 1.58 ms | 1.62 ms | **5.4x** | **5.6x** |
|
|
102
|
+
| B=4 T=1.3M nfft=1024 | 0.77 ms | 5.74 ms | 6.68 ms | **7.5x** | **8.7x** |
|
|
103
|
+
| B=8 T=480k nfft=1024 | 0.60 ms | 4.10 ms | 4.55 ms | **6.8x** | **7.6x** |
|
|
104
|
+
|
|
105
|
+
### Differentiable STFT + iSTFT (Forward + Backward)
|
|
106
|
+
|
|
107
|
+
| Config | mlx-spectro | torch MPS | vs torch |
|
|
108
|
+
|---|---|---|---|
|
|
109
|
+
| B=1 T=16k nfft=512 | 0.32 ms | 0.97 ms | **3.0x** |
|
|
110
|
+
| B=4 T=160k nfft=1024 | 0.61 ms | 2.28 ms | **3.7x** |
|
|
111
|
+
| B=8 T=160k nfft=1024 | 1.05 ms | 4.33 ms | **4.1x** |
|
|
112
|
+
| B=4 T=1.3M nfft=1024 | 4.30 ms | 17.44 ms | **4.1x** |
|
|
113
|
+
| B=8 T=480k nfft=1024 | 3.01 ms | 12.53 ms | **4.2x** |
|
|
114
|
+
|
|
115
|
+
### Roundtrip Accuracy (STFT → iSTFT max abs error)
|
|
116
|
+
|
|
117
|
+
| Config | mlx-spectro | torch MPS |
|
|
118
|
+
|---|---|---|
|
|
119
|
+
| B=1 T=16k nfft=512 | 1.67e-06 | 2.38e-06 |
|
|
120
|
+
| B=4 T=160k nfft=2048 | 2.86e-06 | 5.25e-06 |
|
|
121
|
+
| B=8 T=480k nfft=1024 | 3.81e-06 | 4.77e-06 |
|
|
122
|
+
|
|
123
|
+
### Compiled Mode
|
|
124
|
+
|
|
125
|
+
For tight inference loops with fixed input shapes, `compiled_pair` eliminates
|
|
126
|
+
per-call Python dispatch overhead (10–20% faster for small workloads):
|
|
127
|
+
|
|
128
|
+
```python
|
|
129
|
+
t = SpectralTransform(n_fft=1024, hop_length=256, window_fn="hann")
|
|
130
|
+
stft, istft = t.compiled_pair(length=44100, warmup_batch=2)
|
|
131
|
+
|
|
132
|
+
for chunk in audio_stream:
|
|
133
|
+
z = stft(chunk)
|
|
134
|
+
z = process(z)
|
|
135
|
+
y = istft(z)
|
|
136
|
+
mx.eval(y)
|
|
137
|
+
```
|
|
138
|
+
|
|
139
|
+
Use the eager `t.stft()` / `t.istft()` methods when input shapes vary.
|
|
140
|
+
|
|
141
|
+
## Environment Variables
|
|
142
|
+
|
|
143
|
+
| Variable | Default | Description |
|
|
144
|
+
|---|---|---|
|
|
145
|
+
| `SPEC_MLX_AUTOTUNE` | `1` | Enable Metal kernel autotuning |
|
|
146
|
+
| `SPEC_MLX_TGX` | — | Force threadgroup size (e.g. `256` or `kernel:256`) |
|
|
147
|
+
| `SPEC_MLX_AUTOTUNE_PERSIST` | `1` | Persist autotune results to disk |
|
|
148
|
+
| `SPEC_MLX_AUTOTUNE_CACHE_PATH` | — | Override autotune cache file path |
|
|
149
|
+
| `MLX_OLA_FUSE_NORM` | `1` | Enable fused OLA+normalization kernel |
|
|
150
|
+
| `SPEC_MLX_CACHE_STATS` | `0` | Enable cache debug counters |
|
|
151
|
+
|
|
152
|
+
## License
|
|
153
|
+
|
|
154
|
+
MIT
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "mlx-spectro"
|
|
7
|
+
version = "0.2.0"
|
|
8
|
+
description = "High-performance STFT/iSTFT for Apple MLX with fused Metal kernels"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = "MIT"
|
|
11
|
+
requires-python = ">=3.10"
|
|
12
|
+
authors = [{ name = "Sam", email = "ssmall256@users.noreply.github.com" }]
|
|
13
|
+
keywords = ["mlx", "stft", "istft", "spectral", "audio", "dsp", "apple-silicon"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 4 - Beta",
|
|
16
|
+
"Intended Audience :: Developers",
|
|
17
|
+
"Intended Audience :: Science/Research",
|
|
18
|
+
"License :: OSI Approved :: MIT License",
|
|
19
|
+
"Operating System :: MacOS",
|
|
20
|
+
"Programming Language :: Python :: 3",
|
|
21
|
+
"Programming Language :: Python :: 3.10",
|
|
22
|
+
"Programming Language :: Python :: 3.11",
|
|
23
|
+
"Programming Language :: Python :: 3.12",
|
|
24
|
+
"Programming Language :: Python :: 3.13",
|
|
25
|
+
"Topic :: Multimedia :: Sound/Audio :: Analysis",
|
|
26
|
+
"Topic :: Scientific/Engineering",
|
|
27
|
+
]
|
|
28
|
+
dependencies = [
|
|
29
|
+
"mlx>=0.5.0",
|
|
30
|
+
"numpy",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
[project.optional-dependencies]
|
|
34
|
+
torch = ["torch>=2.0"]
|
|
35
|
+
dev = ["pytest>=7.0", "build"]
|
|
36
|
+
benchmark = ["torch>=2.0", "mlx-stft"]
|
|
37
|
+
|
|
38
|
+
[project.urls]
|
|
39
|
+
Homepage = "https://github.com/ssmall256/mlx-spectro"
|
|
40
|
+
Repository = "https://github.com/ssmall256/mlx-spectro"
|
|
41
|
+
Issues = "https://github.com/ssmall256/mlx-spectro/issues"
|
|
42
|
+
|
|
43
|
+
[tool.hatch.build.targets.wheel]
|
|
44
|
+
packages = ["src/mlx_spectro"]
|
|
45
|
+
|
|
46
|
+
[tool.pytest.ini_options]
|
|
47
|
+
testpaths = ["tests"]
|