pymss-core 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymss_core-0.1.0/LICENSE +21 -0
- pymss_core-0.1.0/PKG-INFO +113 -0
- pymss_core-0.1.0/README.md +81 -0
- pymss_core-0.1.0/pymss_core/__init__.py +25 -0
- pymss_core-0.1.0/pymss_core/checkpoint.py +127 -0
- pymss_core-0.1.0/pymss_core/config.py +135 -0
- pymss_core-0.1.0/pymss_core/modules/__init__.py +0 -0
- pymss_core-0.1.0/pymss_core/modules/_dsp.py +84 -0
- pymss_core-0.1.0/pymss_core/modules/apollo_mlx.py +242 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/__init__.py +0 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/bandsplit.py +166 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/__init__.py +9 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/model/__init__.py +9 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/model/_spectral.py +95 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/__init__.py +17 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/bandsplit.py +32 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/core.py +171 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/maskestim.py +20 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/tfmodel.py +12 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/utils.py +355 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/wrapper.py +235 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/maskestim.py +312 -0
- pymss_core-0.1.0/pymss_core/modules/bandit/tfmodel.py +194 -0
- pymss_core-0.1.0/pymss_core/modules/bandit_mlx.py +415 -0
- pymss_core-0.1.0/pymss_core/modules/bandit_v2/__init__.py +0 -0
- pymss_core-0.1.0/pymss_core/modules/bandit_v2/bandit.py +326 -0
- pymss_core-0.1.0/pymss_core/modules/bandit_v2/bandsplit.py +13 -0
- pymss_core-0.1.0/pymss_core/modules/bandit_v2/maskestim.py +119 -0
- pymss_core-0.1.0/pymss_core/modules/bandit_v2/tfmodel.py +18 -0
- pymss_core-0.1.0/pymss_core/modules/bandit_v2/utils.py +33 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/__init__.py +5 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/attend.py +29 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/bands.py +683 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/bs_roformer.py +111 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/bs_roformer_hyperace.py +43 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/common.py +383 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/hyperace_segm.py +331 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/mel_band_roformer.py +168 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/mlx_attention.py +408 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/mlx_roformer.py +667 -0
- pymss_core-0.1.0/pymss_core/modules/bs_roformer/transformer.py +355 -0
- pymss_core-0.1.0/pymss_core/modules/demucs4ht.py +479 -0
- pymss_core-0.1.0/pymss_core/modules/demucs_local.py +597 -0
- pymss_core-0.1.0/pymss_core/modules/demucs_mlx.py +605 -0
- pymss_core-0.1.0/pymss_core/modules/legacy_demucs.py +1552 -0
- pymss_core-0.1.0/pymss_core/modules/look2hear/__init__.py +3 -0
- pymss_core-0.1.0/pymss_core/modules/look2hear/apollo.py +549 -0
- pymss_core-0.1.0/pymss_core/modules/mdx23c_mlx.py +303 -0
- pymss_core-0.1.0/pymss_core/modules/mdx23c_tfc_tdf_v3.py +175 -0
- pymss_core-0.1.0/pymss_core/modules/mlx_utils.py +24 -0
- pymss_core-0.1.0/pymss_core/modules/scnet/__init__.py +3 -0
- pymss_core-0.1.0/pymss_core/modules/scnet/scnet.py +266 -0
- pymss_core-0.1.0/pymss_core/modules/scnet/separation.py +64 -0
- pymss_core-0.1.0/pymss_core/modules/scnet_mlx.py +411 -0
- pymss_core-0.1.0/pymss_core/modules/spectrogram.py +87 -0
- pymss_core-0.1.0/pymss_core/modules/vocal_remover/__init__.py +12 -0
- pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/__init__.py +1 -0
- pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/__init__.py +1 -0
- pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/layers.py +135 -0
- pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/layers_new.py +101 -0
- pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/model_param_init.py +19 -0
- pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/nets.py +155 -0
- pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/nets_new.py +118 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr16000_hl512.json +19 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr32000_hl512.json +19 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr33075_hl384.json +19 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl1024.json +19 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl256.json +19 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl512.json +19 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl512_cut.json +19 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl512_nf1024.json +19 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/2band_32000.json +30 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/2band_44100_lofi.json +30 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/2band_48000.json +30 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/3band_44100.json +42 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/3band_44100_mid.json +43 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/3band_44100_msb2.json +43 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100.json +54 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_mid.json +55 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_msb.json +55 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_msb2.json +55 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_reverse.json +55 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_sw.json +55 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v2.json +54 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v2_sn.json +55 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v3.json +54 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v3_sn.json +55 -0
- pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v4_ms_fullband.json +58 -0
- pymss_core-0.1.0/pymss_core/utils.py +53 -0
- pymss_core-0.1.0/pymss_core.egg-info/PKG-INFO +113 -0
- pymss_core-0.1.0/pymss_core.egg-info/SOURCES.txt +95 -0
- pymss_core-0.1.0/pymss_core.egg-info/dependency_links.txt +1 -0
- pymss_core-0.1.0/pymss_core.egg-info/requires.txt +8 -0
- pymss_core-0.1.0/pymss_core.egg-info/top_level.txt +1 -0
- pymss_core-0.1.0/pyproject.toml +87 -0
- pymss_core-0.1.0/setup.cfg +4 -0
- pymss_core-0.1.0/tests/test_core_api.py +57 -0
pymss_core-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 KitsuneX07
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pymss-core
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Core model, configuration, and checkpoint package for music source separation.
|
|
5
|
+
Author-email: KitsuneX07 <ghast1085654218@163.com>
|
|
6
|
+
Maintainer-email: baicai1145 <3423714059@qq.com>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Project-URL: Bug Tracker, https://github.com/pymss-project/pymss-core/issues
|
|
9
|
+
Project-URL: Source Code, https://github.com/pymss-project/pymss-core
|
|
10
|
+
Project-URL: Documentation, https://github.com/pymss-project/pymss-core/blob/main/README.md
|
|
11
|
+
Keywords: music source separation,audio separation,music processing,machine learning,audio
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Topic :: Multimedia :: Sound/Audio
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Classifier: Operating System :: OS Independent
|
|
23
|
+
Requires-Python: >=3.10
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
License-File: LICENSE
|
|
26
|
+
Requires-Dist: numpy>=1.26
|
|
27
|
+
Requires-Dist: pyyaml>=6.0.1
|
|
28
|
+
Requires-Dist: torch>=2.7.1
|
|
29
|
+
Provides-Extra: mlx
|
|
30
|
+
Requires-Dist: mlx>=0.31.0; (sys_platform == "darwin" and platform_machine == "arm64") and extra == "mlx"
|
|
31
|
+
Dynamic: license-file
|
|
32
|
+
|
|
33
|
+
# pymss-core
|
|
34
|
+
|
|
35
|
+
[中文文档](README_CN.md)
|
|
36
|
+
|
|
37
|
+
Core model, configuration, and checkpoint package for music source separation.
|
|
38
|
+
|
|
39
|
+
`pymss-core` is the shared low-level package for higher-level projects such as
|
|
40
|
+
`pymss` inference and `pymsst` training. It contains model definitions,
|
|
41
|
+
configuration loading, and checkpoint compatibility helpers. It intentionally
|
|
42
|
+
does not include inference DSP pipelines, chunked demixing, audio file I/O,
|
|
43
|
+
model downloads, catalog management, CLI, HTTP server, WebUI, datasets, losses,
|
|
44
|
+
or training loops.
|
|
45
|
+
|
|
46
|
+
## Install
|
|
47
|
+
|
|
48
|
+
```bash
|
|
49
|
+
pip install pymss-core
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
For local development:
|
|
53
|
+
|
|
54
|
+
```bash
|
|
55
|
+
uv sync --dev
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
Optional MLX backend on Apple Silicon:
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
pip install "pymss-core[mlx]"
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
## Public API
|
|
65
|
+
|
|
66
|
+
```python
|
|
67
|
+
from pymss_core import (
|
|
68
|
+
get_model_from_config,
|
|
69
|
+
load_config,
|
|
70
|
+
load_model_weights,
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
model, config = get_model_from_config("bs_roformer", "config.yaml")
|
|
74
|
+
load_model_weights(model, "model.ckpt", model_type="bs_roformer", strict=True)
|
|
75
|
+
|
|
76
|
+
model.eval()
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
## Package Boundary
|
|
80
|
+
|
|
81
|
+
Included:
|
|
82
|
+
|
|
83
|
+
- YAML config loading with `AttrDict`
|
|
84
|
+
- PyTorch model definitions under `pymss_core.modules`
|
|
85
|
+
- Optional MLX backend implementations for supported model forward paths
|
|
86
|
+
- Model factory: `get_model_from_config(model_type, config_path)`
|
|
87
|
+
- Checkpoint helpers for common MSS checkpoint containers
|
|
88
|
+
- Small model-internal DSP math needed to construct model structures
|
|
89
|
+
- VR network structures and VR model parameter JSON files
|
|
90
|
+
|
|
91
|
+
Excluded:
|
|
92
|
+
|
|
93
|
+
- Audio file decoding/encoding
|
|
94
|
+
- Resampling, preprocessing, and full inference DSP pipelines
|
|
95
|
+
- Tensor-level chunked demixing runtime
|
|
96
|
+
- Model catalog, aliases, downloads, and cache management
|
|
97
|
+
- CLI, server, WebUI, and endpoint schemas
|
|
98
|
+
- Dataset, augmentation, loss, metrics, and trainer code
|
|
99
|
+
- Any default dependency on MLX, Librosa, tqdm, Lightning, FastAPI, Uvicorn,
|
|
100
|
+
PyAV, WandB, or training extras
|
|
101
|
+
|
|
102
|
+
## Repository Roles
|
|
103
|
+
|
|
104
|
+
```text
|
|
105
|
+
pymss-core
|
|
106
|
+
shared model/config/checkpoint layer
|
|
107
|
+
|
|
108
|
+
pymss
|
|
109
|
+
user-facing inference package built on pymss-core, with audio I/O and demix
|
|
110
|
+
|
|
111
|
+
pymsst
|
|
112
|
+
training package built on pymss-core, with training data/loss/runtime code
|
|
113
|
+
```
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# pymss-core
|
|
2
|
+
|
|
3
|
+
[中文文档](README_CN.md)
|
|
4
|
+
|
|
5
|
+
Core model, configuration, and checkpoint package for music source separation.
|
|
6
|
+
|
|
7
|
+
`pymss-core` is the shared low-level package for higher-level projects such as
|
|
8
|
+
`pymss` inference and `pymsst` training. It contains model definitions,
|
|
9
|
+
configuration loading, and checkpoint compatibility helpers. It intentionally
|
|
10
|
+
does not include inference DSP pipelines, chunked demixing, audio file I/O,
|
|
11
|
+
model downloads, catalog management, CLI, HTTP server, WebUI, datasets, losses,
|
|
12
|
+
or training loops.
|
|
13
|
+
|
|
14
|
+
## Install
|
|
15
|
+
|
|
16
|
+
```bash
|
|
17
|
+
pip install pymss-core
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
For local development:
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
uv sync --dev
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
Optional MLX backend on Apple Silicon:
|
|
27
|
+
|
|
28
|
+
```bash
|
|
29
|
+
pip install "pymss-core[mlx]"
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## Public API
|
|
33
|
+
|
|
34
|
+
```python
|
|
35
|
+
from pymss_core import (
|
|
36
|
+
get_model_from_config,
|
|
37
|
+
load_config,
|
|
38
|
+
load_model_weights,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
model, config = get_model_from_config("bs_roformer", "config.yaml")
|
|
42
|
+
load_model_weights(model, "model.ckpt", model_type="bs_roformer", strict=True)
|
|
43
|
+
|
|
44
|
+
model.eval()
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
## Package Boundary
|
|
48
|
+
|
|
49
|
+
Included:
|
|
50
|
+
|
|
51
|
+
- YAML config loading with `AttrDict`
|
|
52
|
+
- PyTorch model definitions under `pymss_core.modules`
|
|
53
|
+
- Optional MLX backend implementations for supported model forward paths
|
|
54
|
+
- Model factory: `get_model_from_config(model_type, config_path)`
|
|
55
|
+
- Checkpoint helpers for common MSS checkpoint containers
|
|
56
|
+
- Small model-internal DSP math needed to construct model structures
|
|
57
|
+
- VR network structures and VR model parameter JSON files
|
|
58
|
+
|
|
59
|
+
Excluded:
|
|
60
|
+
|
|
61
|
+
- Audio file decoding/encoding
|
|
62
|
+
- Resampling, preprocessing, and full inference DSP pipelines
|
|
63
|
+
- Tensor-level chunked demixing runtime
|
|
64
|
+
- Model catalog, aliases, downloads, and cache management
|
|
65
|
+
- CLI, server, WebUI, and endpoint schemas
|
|
66
|
+
- Dataset, augmentation, loss, metrics, and trainer code
|
|
67
|
+
- Any default dependency on MLX, Librosa, tqdm, Lightning, FastAPI, Uvicorn,
|
|
68
|
+
PyAV, WandB, or training extras
|
|
69
|
+
|
|
70
|
+
## Repository Roles
|
|
71
|
+
|
|
72
|
+
```text
|
|
73
|
+
pymss-core
|
|
74
|
+
shared model/config/checkpoint layer
|
|
75
|
+
|
|
76
|
+
pymss
|
|
77
|
+
user-facing inference package built on pymss-core, with audio I/O and demix
|
|
78
|
+
|
|
79
|
+
pymsst
|
|
80
|
+
training package built on pymss-core, with training data/loss/runtime code
|
|
81
|
+
```
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Core model, configuration, and checkpoint API for music source separation.
|
|
2
|
+
|
|
3
|
+
`pymss_core` contains the shared pieces used by higher-level packages:
|
|
4
|
+
configuration loading, model construction, model definitions, and
|
|
5
|
+
checkpoint/state-dict helpers. It intentionally does not provide file audio
|
|
6
|
+
I/O, inference DSP pipelines, chunked demixing, model catalog downloads, CLI,
|
|
7
|
+
HTTP server, or WebUI functionality.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from .checkpoint import load_checkpoint, load_model_weights, load_state_dict, unwrap_state_dict
|
|
11
|
+
from .config import AttrDict, ConfigLoader, load_config, to_attrdict, to_plain
|
|
12
|
+
from .utils import get_model_from_config
|
|
13
|
+
|
|
14
|
+
# Public names re-exported from the checkpoint, config, and utils submodules
# (listed alphabetically, ignoring case).
__all__ = (
    "AttrDict",
    "ConfigLoader",
    "get_model_from_config",
    "load_checkpoint",
    "load_config",
    "load_model_weights",
    "load_state_dict",
    "to_attrdict",
    "to_plain",
    "unwrap_state_dict",
)
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Checkpoint helpers shared by inference and training frontends."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from types import ModuleType
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Wrapper keys probed, in priority order, when a checkpoint is a dict.
STATE_DICT_KEYS = ("state", "state_dict", "model_state_dict")


def unwrap_state_dict(checkpoint: Any) -> Any:
    """Extract the model state dict from a checkpoint container.

    Args:
        checkpoint: Either a raw state dict, a dict wrapping the state dict
            under one of ``STATE_DICT_KEYS``, or any other object.

    Returns:
        The value stored under the first matching wrapper key, or
        ``checkpoint`` itself when it is not a dict or has no wrapper key.
    """
    if not isinstance(checkpoint, dict):
        return checkpoint
    # The first key in STATE_DICT_KEYS order wins when several are present.
    wrapper_key = next((key for key in STATE_DICT_KEYS if key in checkpoint), None)
    return checkpoint if wrapper_key is None else checkpoint[wrapper_key]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _install_demucs_pickle_stubs() -> dict[str, ModuleType | None]:
|
|
25
|
+
import sys
|
|
26
|
+
import types
|
|
27
|
+
|
|
28
|
+
module_names = ("demucs", "demucs.demucs", "demucs.hdemucs", "demucs.htdemucs")
|
|
29
|
+
previous = {name: sys.modules.get(name) for name in module_names}
|
|
30
|
+
package = sys.modules.setdefault("demucs", types.ModuleType("demucs"))
|
|
31
|
+
package.__path__ = []
|
|
32
|
+
for module_name, class_names in {
|
|
33
|
+
"demucs": ("Demucs",),
|
|
34
|
+
"hdemucs": ("HDemucs", "HTDemucs"),
|
|
35
|
+
"htdemucs": ("HTDemucs",),
|
|
36
|
+
}.items():
|
|
37
|
+
full_name = f"demucs.{module_name}"
|
|
38
|
+
module = sys.modules.setdefault(full_name, types.ModuleType(full_name))
|
|
39
|
+
setattr(package, module_name, module)
|
|
40
|
+
for class_name in class_names:
|
|
41
|
+
if not hasattr(module, class_name):
|
|
42
|
+
setattr(module, class_name, type(class_name, (), {"__module__": full_name}))
|
|
43
|
+
return previous
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _restore_modules(previous: dict[str, ModuleType | None]) -> None:
|
|
47
|
+
import sys
|
|
48
|
+
|
|
49
|
+
for name, module in previous.items():
|
|
50
|
+
if module is None:
|
|
51
|
+
sys.modules.pop(name, None)
|
|
52
|
+
else:
|
|
53
|
+
sys.modules[name] = module
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _torch_load(path: str | Path, *, map_location="cpu", weights_only: bool | None = None, mmap: bool = True) -> Any:
|
|
57
|
+
kwargs: dict[str, Any] = {"map_location": map_location}
|
|
58
|
+
if weights_only is not None:
|
|
59
|
+
kwargs["weights_only"] = weights_only
|
|
60
|
+
if mmap:
|
|
61
|
+
kwargs["mmap"] = True
|
|
62
|
+
try:
|
|
63
|
+
return torch.load(path, **kwargs)
|
|
64
|
+
except TypeError:
|
|
65
|
+
kwargs.pop("mmap", None)
|
|
66
|
+
try:
|
|
67
|
+
return torch.load(path, **kwargs)
|
|
68
|
+
except TypeError:
|
|
69
|
+
kwargs.pop("weights_only", None)
|
|
70
|
+
return torch.load(path, **kwargs)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def load_checkpoint(
    path: str | Path,
    *,
    model_type: str | None = None,
    map_location: str | torch.device = "cpu",
    weights_only: bool | None = None,
    mmap: bool = True,
) -> Any:
    """Read a checkpoint file, applying per-model-type loading quirks.

    Args:
        path: Checkpoint file to read.
        model_type: Model family name, compared case-insensitively. The Demucs
            family (``htdemucs``, ``demucs``, ``legacy_demucs``,
            ``legacy_tasnet``) is loaded with temporary ``demucs`` stub modules
            installed and always with ``weights_only=False``, ignoring the
            ``weights_only`` argument. ``apollo`` defaults ``weights_only`` to
            ``False`` when the caller passes ``None``.
        map_location: Forwarded to the underlying loader.
        weights_only: Forwarded to the underlying loader for non-Demucs types.
        mmap: Forwarded to the underlying loader.

    Returns:
        The deserialized checkpoint object.

    Note:
        ``weights_only=False`` performs full unpickling, which can execute
        arbitrary code; only load such checkpoints from trusted sources.
    """
    kind = (model_type or "").lower()
    if kind not in {"htdemucs", "demucs", "legacy_demucs", "legacy_tasnet"}:
        if kind == "apollo" and weights_only is None:
            weights_only = False
        return _torch_load(path, map_location=map_location, weights_only=weights_only, mmap=mmap)

    saved_modules = _install_demucs_pickle_stubs()
    try:
        return _torch_load(path, map_location=map_location, weights_only=False, mmap=mmap)
    finally:
        # Always undo the sys.modules changes, even when loading fails.
        _restore_modules(saved_modules)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def load_state_dict(
    path: str | Path,
    *,
    model_type: str | None = None,
    map_location: str | torch.device = "cpu",
    weights_only: bool | None = None,
    mmap: bool = True,
) -> Any:
    """Read a checkpoint file and return the state dict it contains.

    All keyword arguments are forwarded unchanged to ``load_checkpoint``; the
    loaded object is then passed through ``unwrap_state_dict``.

    Returns:
        The unwrapped state dict, or the loaded object itself when it is not
        a recognised wrapper.
    """
    checkpoint = load_checkpoint(
        path,
        model_type=model_type,
        map_location=map_location,
        weights_only=weights_only,
        mmap=mmap,
    )
    return unwrap_state_dict(checkpoint)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def load_model_weights(
    model: torch.nn.Module,
    checkpoint_or_path: Any,
    *,
    model_type: str | None = None,
    strict: bool = True,
    map_location: str | torch.device = "cpu",
) -> Any:
    """Populate ``model`` with weights from a checkpoint object or file.

    Args:
        model: Module that receives the weights.
        checkpoint_or_path: A ``str``/``Path`` pointing at a checkpoint file,
            or an already-loaded checkpoint / state dict.
        model_type: Forwarded to ``load_state_dict`` when a path is given.
        strict: Forwarded to ``model.load_state_dict``.
        map_location: Forwarded to ``load_state_dict`` when a path is given.

    Returns:
        The result of ``model.load_state_dict``.
    """
    is_file = isinstance(checkpoint_or_path, (str, Path))
    state_dict = (
        load_state_dict(checkpoint_or_path, model_type=model_type, map_location=map_location)
        if is_file
        else unwrap_state_dict(checkpoint_or_path)
    )
    return model.load_state_dict(state_dict, strict=strict)
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
import yaml
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ConfigLoader(yaml.FullLoader):
    """YAML loader used by pymss-core model configuration files."""


# Register an additional implicit float resolver on this loader only. The
# pattern accepts exponent notation without a decimal point (for example
# ``1e-4``), which would otherwise not be tagged as a float by this rule.
ConfigLoader.add_implicit_resolver(
    tag="tag:yaml.org,2002:float",
    regexp=re.compile(
        r"""^[-+]?(
        ([0-9][0-9_]*)?\.[0-9_]+([eE][-+]?[0-9]+)?
        |[0-9][0-9_]*[eE][-+]?[0-9]+
        |\.(inf|Inf|INF)
        |\.(nan|NaN|NAN)
        )$""",
        re.VERBOSE,
    ),
    first=list("-+0123456789."),
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AttrDict(dict):
    """Dictionary whose keys can also be read, written, and deleted as attributes.

    Nested dictionaries (including those inside lists and tuples) are wrapped
    as ``AttrDict`` whenever they are stored, whether through item assignment,
    attribute assignment, ``update``, or ``setdefault``.

    Args:
        data: Optional mapping or iterable of key/value pairs used to seed
            the instance. Defaults to None.
        **kwargs: Extra entries; they override same-named keys from ``data``.

    Example:
        >>> cfg = AttrDict({"audio": {"chunk_size": 485100}})
        >>> cfg.audio.chunk_size
        485100
    """

    def __init__(self, data=None, **kwargs):
        """Populate the mapping from ``data`` and ``kwargs``."""
        super().__init__()
        # Go through __setitem__ so every nested dict gets wrapped.
        for key, value in dict(data or {}, **kwargs).items():
            self[key] = value

    def __getattr__(self, key):
        """Resolve a missing attribute as a key lookup.

        Raises:
            AttributeError: If ``key`` is not present in the mapping.
        """
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key, value):
        """Store an attribute assignment as a mapping entry."""
        self[key] = value

    def __delattr__(self, key):
        """Delete the mapping entry named by an attribute deletion.

        Raises:
            AttributeError: If ``key`` is not present in the mapping.
        """
        try:
            del self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setitem__(self, key, value):
        """Store ``value`` after wrapping any nested dictionaries."""
        super().__setitem__(key, to_attrdict(value))

    def update(self, *args, **kwargs):
        """Merge entries like ``dict.update`` while wrapping nested dictionaries.

        ``dict.update`` does not call ``__setitem__`` on subclasses, so it is
        overridden here to keep attribute access working on merged values.
        """
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        """Return ``self[key]``, first storing a wrapped ``default`` if missing."""
        if key not in self:
            self[key] = default
        return self[key]


def to_attrdict(value):
    """Recursively wrap dictionaries as ``AttrDict`` objects.

    Args:
        value: Any object. Dictionaries are wrapped, lists and tuples are
            rebuilt with their items converted, and everything else is
            returned unchanged.

    Returns:
        The converted value. An existing ``AttrDict`` is returned as-is.
    """
    if isinstance(value, AttrDict):
        return value
    if isinstance(value, dict):
        return AttrDict(value)
    if isinstance(value, list):
        return [to_attrdict(item) for item in value]
    if isinstance(value, tuple):
        return tuple(to_attrdict(item) for item in value)
    return value
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def to_plain(value):
    """Recursively convert mappings back to plain Python containers.

    Every ``dict`` instance, including ``AttrDict`` and other subclasses, is
    rebuilt as a built-in ``dict``. Plain dictionaries are traversed too, so
    an ``AttrDict`` nested inside one is still converted.

    Args:
        value: Any object. Dictionaries, lists, and tuples are rebuilt with
            their contents converted; everything else is returned unchanged.

    Returns:
        The converted value built from plain dictionaries, lists, and tuples.
    """
    if isinstance(value, dict):
        return {key: to_plain(item) for key, item in value.items()}
    if isinstance(value, list):
        return [to_plain(item) for item in value]
    if isinstance(value, tuple):
        return tuple(to_plain(item) for item in value)
    return value
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def load_config(path):
    """Parse a YAML model configuration file into attribute-accessible form.

    The file is read as UTF-8 and parsed with ``ConfigLoader``; the parsed
    document is then converted with ``to_attrdict``.

    Args:
        path (str | os.PathLike): Location of the YAML file.

    Returns:
        AttrDict: Parsed configuration with attribute access when the
        document's top level is a mapping. Other top-level values are
        returned as converted by ``to_attrdict``.

    Example:
        >>> config = load_config("config.yaml")
        >>> config.inference.batch_size
    """
    # NOTE(review): ConfigLoader derives from yaml.FullLoader rather than
    # SafeLoader; confirm configuration files only come from trusted sources.
    with open(path, encoding="utf-8") as stream:
        document = yaml.load(stream, Loader=ConfigLoader)
    return to_attrdict(document)
|
|
File without changes
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Small DSP helpers needed by model definitions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def hz_to_midi(hz):
    """Map frequencies in Hz to (fractional) MIDI note numbers.

    440 Hz maps to note 69, and each doubling of frequency adds 12.
    """
    ratio_to_a440 = np.asarray(hz) / 440.0
    return 69.0 + 12.0 * np.log2(ratio_to_a440)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def midi_to_hz(midi):
    """Map MIDI note numbers to frequencies in Hz.

    Note 69 maps to 440 Hz, and every 12 notes doubles the frequency.
    """
    notes_from_69 = np.asarray(midi) - 69.0
    return 440.0 * np.power(2.0, notes_from_69 / 12.0)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _hz_to_mel(frequencies, *, htk=False):
|
|
21
|
+
frequencies = np.asarray(frequencies, dtype=np.float64)
|
|
22
|
+
if htk:
|
|
23
|
+
return 2595.0 * np.log10(1.0 + frequencies / 700.0)
|
|
24
|
+
|
|
25
|
+
f_sp = 200.0 / 3
|
|
26
|
+
mels = frequencies / f_sp
|
|
27
|
+
min_log_hz = 1000.0
|
|
28
|
+
min_log_mel = min_log_hz / f_sp
|
|
29
|
+
logstep = np.log(6.4) / 27.0
|
|
30
|
+
log_t = frequencies >= min_log_hz
|
|
31
|
+
mels = np.array(mels, copy=True)
|
|
32
|
+
mels[log_t] = min_log_mel + np.log(frequencies[log_t] / min_log_hz) / logstep
|
|
33
|
+
return mels
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _mel_to_hz(mels, *, htk=False):
|
|
37
|
+
mels = np.asarray(mels, dtype=np.float64)
|
|
38
|
+
if htk:
|
|
39
|
+
return 700.0 * (np.power(10.0, mels / 2595.0) - 1.0)
|
|
40
|
+
|
|
41
|
+
f_sp = 200.0 / 3
|
|
42
|
+
freqs = f_sp * mels
|
|
43
|
+
min_log_hz = 1000.0
|
|
44
|
+
min_log_mel = min_log_hz / f_sp
|
|
45
|
+
logstep = np.log(6.4) / 27.0
|
|
46
|
+
log_t = mels >= min_log_mel
|
|
47
|
+
freqs = np.array(freqs, copy=True)
|
|
48
|
+
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
|
|
49
|
+
return freqs
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def mel_frequencies(n_mels, *, fmin=0.0, fmax=11025.0, htk=False):
    """Return ``n_mels`` frequencies in Hz, evenly spaced on the mel scale.

    The first and last values correspond to ``fmin`` and ``fmax``.
    """
    low_mel = _hz_to_mel(fmin, htk=htk)
    high_mel = _hz_to_mel(fmax, htk=htk)
    mel_grid = np.linspace(low_mel, high_mel, int(n_mels))
    return _mel_to_hz(mel_grid, htk=htk)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def fft_frequencies(*, sr, n_fft):
    """Return ``1 + n_fft // 2`` evenly spaced frequencies from 0 to ``sr / 2``."""
    nyquist = float(sr) / 2.0
    bin_count = int(1 + n_fft // 2)
    return np.linspace(0.0, nyquist, bin_count, endpoint=True)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def mel_filterbank(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False, norm="slaney", dtype=np.float32):
    """Build a bank of triangular mel filters over the FFT bin frequencies.

    Args:
        sr: Sample rate; also sets the default ``fmax`` of ``sr / 2``.
        n_fft: FFT size; the bank has ``1 + n_fft // 2`` columns.
        n_mels: Number of filters (rows).
        fmin: Lowest band edge in Hz.
        fmax: Highest band edge in Hz, or ``None`` for ``sr / 2``.
        htk: Selects the mel formula used for the band edges.
        norm: ``"slaney"`` scales each filter by ``2 / (upper edge - lower
            edge)``; ``None`` leaves the triangles unscaled.
        dtype: Output dtype.

    Returns:
        Array of shape ``(n_mels, 1 + n_fft // 2)`` with the filter weights.

    Raises:
        ValueError: If ``norm`` is neither ``"slaney"`` nor ``None``.
    """
    top_hz = float(sr) / 2.0 if fmax is None else fmax

    # n_mels triangles need n_mels + 2 edge frequencies.
    band_edges = mel_frequencies(int(n_mels) + 2, fmin=fmin, fmax=top_hz, htk=htk)
    bin_hz = fft_frequencies(sr=sr, n_fft=n_fft)

    edge_gaps = np.diff(band_edges)
    offsets = np.subtract.outer(band_edges, bin_hz)
    rising = -offsets[:-2] / edge_gaps[:-1, np.newaxis]
    falling = offsets[2:] / edge_gaps[1:, np.newaxis]
    # Each triangle is the lower of its two slopes, clipped at zero.
    weights = np.maximum(0.0, np.minimum(rising, falling))

    if norm == "slaney":
        band_scale = 2.0 / (band_edges[2 : int(n_mels) + 2] - band_edges[: int(n_mels)])
        weights *= band_scale[:, np.newaxis]
    elif norm is not None:
        raise ValueError(f"Unsupported mel filterbank norm: {norm!r}")

    return weights.astype(dtype, copy=False)
|