pymss-core 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. pymss_core-0.1.0/LICENSE +21 -0
  2. pymss_core-0.1.0/PKG-INFO +113 -0
  3. pymss_core-0.1.0/README.md +81 -0
  4. pymss_core-0.1.0/pymss_core/__init__.py +25 -0
  5. pymss_core-0.1.0/pymss_core/checkpoint.py +127 -0
  6. pymss_core-0.1.0/pymss_core/config.py +135 -0
  7. pymss_core-0.1.0/pymss_core/modules/__init__.py +0 -0
  8. pymss_core-0.1.0/pymss_core/modules/_dsp.py +84 -0
  9. pymss_core-0.1.0/pymss_core/modules/apollo_mlx.py +242 -0
  10. pymss_core-0.1.0/pymss_core/modules/bandit/__init__.py +0 -0
  11. pymss_core-0.1.0/pymss_core/modules/bandit/bandsplit.py +166 -0
  12. pymss_core-0.1.0/pymss_core/modules/bandit/core/__init__.py +9 -0
  13. pymss_core-0.1.0/pymss_core/modules/bandit/core/model/__init__.py +9 -0
  14. pymss_core-0.1.0/pymss_core/modules/bandit/core/model/_spectral.py +95 -0
  15. pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/__init__.py +17 -0
  16. pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/bandsplit.py +32 -0
  17. pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/core.py +171 -0
  18. pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/maskestim.py +20 -0
  19. pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/tfmodel.py +12 -0
  20. pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/utils.py +355 -0
  21. pymss_core-0.1.0/pymss_core/modules/bandit/core/model/bsrnn/wrapper.py +235 -0
  22. pymss_core-0.1.0/pymss_core/modules/bandit/maskestim.py +312 -0
  23. pymss_core-0.1.0/pymss_core/modules/bandit/tfmodel.py +194 -0
  24. pymss_core-0.1.0/pymss_core/modules/bandit_mlx.py +415 -0
  25. pymss_core-0.1.0/pymss_core/modules/bandit_v2/__init__.py +0 -0
  26. pymss_core-0.1.0/pymss_core/modules/bandit_v2/bandit.py +326 -0
  27. pymss_core-0.1.0/pymss_core/modules/bandit_v2/bandsplit.py +13 -0
  28. pymss_core-0.1.0/pymss_core/modules/bandit_v2/maskestim.py +119 -0
  29. pymss_core-0.1.0/pymss_core/modules/bandit_v2/tfmodel.py +18 -0
  30. pymss_core-0.1.0/pymss_core/modules/bandit_v2/utils.py +33 -0
  31. pymss_core-0.1.0/pymss_core/modules/bs_roformer/__init__.py +5 -0
  32. pymss_core-0.1.0/pymss_core/modules/bs_roformer/attend.py +29 -0
  33. pymss_core-0.1.0/pymss_core/modules/bs_roformer/bands.py +683 -0
  34. pymss_core-0.1.0/pymss_core/modules/bs_roformer/bs_roformer.py +111 -0
  35. pymss_core-0.1.0/pymss_core/modules/bs_roformer/bs_roformer_hyperace.py +43 -0
  36. pymss_core-0.1.0/pymss_core/modules/bs_roformer/common.py +383 -0
  37. pymss_core-0.1.0/pymss_core/modules/bs_roformer/hyperace_segm.py +331 -0
  38. pymss_core-0.1.0/pymss_core/modules/bs_roformer/mel_band_roformer.py +168 -0
  39. pymss_core-0.1.0/pymss_core/modules/bs_roformer/mlx_attention.py +408 -0
  40. pymss_core-0.1.0/pymss_core/modules/bs_roformer/mlx_roformer.py +667 -0
  41. pymss_core-0.1.0/pymss_core/modules/bs_roformer/transformer.py +355 -0
  42. pymss_core-0.1.0/pymss_core/modules/demucs4ht.py +479 -0
  43. pymss_core-0.1.0/pymss_core/modules/demucs_local.py +597 -0
  44. pymss_core-0.1.0/pymss_core/modules/demucs_mlx.py +605 -0
  45. pymss_core-0.1.0/pymss_core/modules/legacy_demucs.py +1552 -0
  46. pymss_core-0.1.0/pymss_core/modules/look2hear/__init__.py +3 -0
  47. pymss_core-0.1.0/pymss_core/modules/look2hear/apollo.py +549 -0
  48. pymss_core-0.1.0/pymss_core/modules/mdx23c_mlx.py +303 -0
  49. pymss_core-0.1.0/pymss_core/modules/mdx23c_tfc_tdf_v3.py +175 -0
  50. pymss_core-0.1.0/pymss_core/modules/mlx_utils.py +24 -0
  51. pymss_core-0.1.0/pymss_core/modules/scnet/__init__.py +3 -0
  52. pymss_core-0.1.0/pymss_core/modules/scnet/scnet.py +266 -0
  53. pymss_core-0.1.0/pymss_core/modules/scnet/separation.py +64 -0
  54. pymss_core-0.1.0/pymss_core/modules/scnet_mlx.py +411 -0
  55. pymss_core-0.1.0/pymss_core/modules/spectrogram.py +87 -0
  56. pymss_core-0.1.0/pymss_core/modules/vocal_remover/__init__.py +12 -0
  57. pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/__init__.py +1 -0
  58. pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/__init__.py +1 -0
  59. pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/layers.py +135 -0
  60. pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/layers_new.py +101 -0
  61. pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/model_param_init.py +19 -0
  62. pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/nets.py +155 -0
  63. pymss_core-0.1.0/pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/nets_new.py +118 -0
  64. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr16000_hl512.json +19 -0
  65. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr32000_hl512.json +19 -0
  66. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr33075_hl384.json +19 -0
  67. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl1024.json +19 -0
  68. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl256.json +19 -0
  69. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl512.json +19 -0
  70. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl512_cut.json +19 -0
  71. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/1band_sr44100_hl512_nf1024.json +19 -0
  72. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/2band_32000.json +30 -0
  73. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/2band_44100_lofi.json +30 -0
  74. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/2band_48000.json +30 -0
  75. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/3band_44100.json +42 -0
  76. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/3band_44100_mid.json +43 -0
  77. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/3band_44100_msb2.json +43 -0
  78. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100.json +54 -0
  79. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_mid.json +55 -0
  80. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_msb.json +55 -0
  81. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_msb2.json +55 -0
  82. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_reverse.json +55 -0
  83. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_44100_sw.json +55 -0
  84. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v2.json +54 -0
  85. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v2_sn.json +55 -0
  86. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v3.json +54 -0
  87. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v3_sn.json +55 -0
  88. pymss_core-0.1.0/pymss_core/resources/vr_modelparams/4band_v4_ms_fullband.json +58 -0
  89. pymss_core-0.1.0/pymss_core/utils.py +53 -0
  90. pymss_core-0.1.0/pymss_core.egg-info/PKG-INFO +113 -0
  91. pymss_core-0.1.0/pymss_core.egg-info/SOURCES.txt +95 -0
  92. pymss_core-0.1.0/pymss_core.egg-info/dependency_links.txt +1 -0
  93. pymss_core-0.1.0/pymss_core.egg-info/requires.txt +8 -0
  94. pymss_core-0.1.0/pymss_core.egg-info/top_level.txt +1 -0
  95. pymss_core-0.1.0/pyproject.toml +87 -0
  96. pymss_core-0.1.0/setup.cfg +4 -0
  97. pymss_core-0.1.0/tests/test_core_api.py +57 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 KitsuneX07
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,113 @@
1
+ Metadata-Version: 2.4
2
+ Name: pymss-core
3
+ Version: 0.1.0
4
+ Summary: Core model, configuration, and checkpoint package for music source separation.
5
+ Author-email: KitsuneX07 <ghast1085654218@163.com>
6
+ Maintainer-email: baicai1145 <3423714059@qq.com>
7
+ License-Expression: MIT
8
+ Project-URL: Bug Tracker, https://github.com/pymss-project/pymss-core/issues
9
+ Project-URL: Source Code, https://github.com/pymss-project/pymss-core
10
+ Project-URL: Documentation, https://github.com/pymss-project/pymss-core/blob/main/README.md
11
+ Keywords: music source separation,audio separation,music processing,machine learning,audio
12
+ Classifier: Development Status :: 3 - Alpha
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Programming Language :: Python :: 3.13
20
+ Classifier: Topic :: Multimedia :: Sound/Audio
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Classifier: Operating System :: OS Independent
23
+ Requires-Python: >=3.10
24
+ Description-Content-Type: text/markdown
25
+ License-File: LICENSE
26
+ Requires-Dist: numpy>=1.26
27
+ Requires-Dist: pyyaml>=6.0.1
28
+ Requires-Dist: torch>=2.7.1
29
+ Provides-Extra: mlx
30
+ Requires-Dist: mlx>=0.31.0; (sys_platform == "darwin" and platform_machine == "arm64") and extra == "mlx"
31
+ Dynamic: license-file
32
+
33
+ # pymss-core
34
+
35
+ [中文文档](README_CN.md)
36
+
37
+ Core model, configuration, and checkpoint package for music source separation.
38
+
39
+ `pymss-core` is the shared low-level package for higher-level projects such as
40
+ `pymss` inference and `pymsst` training. It contains model definitions,
41
+ configuration loading, and checkpoint compatibility helpers. It intentionally
42
+ does not include inference DSP pipelines, chunked demixing, audio file I/O,
43
+ model downloads, catalog management, CLI, HTTP server, WebUI, datasets, losses,
44
+ or training loops.
45
+
46
+ ## Install
47
+
48
+ ```bash
49
+ pip install pymss-core
50
+ ```
51
+
52
+ For local development:
53
+
54
+ ```bash
55
+ uv sync --dev
56
+ ```
57
+
58
+ Optional MLX backend on Apple Silicon:
59
+
60
+ ```bash
61
+ pip install "pymss-core[mlx]"
62
+ ```
63
+
64
+ ## Public API
65
+
66
+ ```python
67
+ from pymss_core import (
68
+ get_model_from_config,
69
+ load_config,
70
+ load_model_weights,
71
+ )
72
+
73
+ model, config = get_model_from_config("bs_roformer", "config.yaml")
74
+ load_model_weights(model, "model.ckpt", model_type="bs_roformer", strict=True)
75
+
76
+ model.eval()
77
+ ```
78
+
79
+ ## Package Boundary
80
+
81
+ Included:
82
+
83
+ - YAML config loading with `AttrDict`
84
+ - PyTorch model definitions under `pymss_core.modules`
85
+ - Optional MLX backend implementations for supported model forward paths
86
+ - Model factory: `get_model_from_config(model_type, config_path)`
87
+ - Checkpoint helpers for common MSS checkpoint containers
88
+ - Small model-internal DSP math needed to construct model structures
89
+ - VR network structures and VR model parameter JSON files
90
+
91
+ Excluded:
92
+
93
+ - Audio file decoding/encoding
94
+ - Resampling, preprocessing, and full inference DSP pipelines
95
+ - Tensor-level chunked demixing runtime
96
+ - Model catalog, aliases, downloads, and cache management
97
+ - CLI, server, WebUI, and endpoint schemas
98
+ - Dataset, augmentation, loss, metrics, and trainer code
99
+ - Any default dependency on MLX, Librosa, tqdm, Lightning, FastAPI, Uvicorn,
100
+ PyAV, WandB, or training extras
101
+
102
+ ## Repository Roles
103
+
104
+ ```text
105
+ pymss-core
106
+ shared model/config/checkpoint layer
107
+
108
+ pymss
109
+ user-facing inference package built on pymss-core, with audio I/O and demix
110
+
111
+ pymsst
112
+ training package built on pymss-core, with training data/loss/runtime code
113
+ ```
@@ -0,0 +1,81 @@
1
+ # pymss-core
2
+
3
+ [中文文档](README_CN.md)
4
+
5
+ Core model, configuration, and checkpoint package for music source separation.
6
+
7
+ `pymss-core` is the shared low-level package for higher-level projects such as
8
+ `pymss` inference and `pymsst` training. It contains model definitions,
9
+ configuration loading, and checkpoint compatibility helpers. It intentionally
10
+ does not include inference DSP pipelines, chunked demixing, audio file I/O,
11
+ model downloads, catalog management, CLI, HTTP server, WebUI, datasets, losses,
12
+ or training loops.
13
+
14
+ ## Install
15
+
16
+ ```bash
17
+ pip install pymss-core
18
+ ```
19
+
20
+ For local development:
21
+
22
+ ```bash
23
+ uv sync --dev
24
+ ```
25
+
26
+ Optional MLX backend on Apple Silicon:
27
+
28
+ ```bash
29
+ pip install "pymss-core[mlx]"
30
+ ```
31
+
32
+ ## Public API
33
+
34
+ ```python
35
+ from pymss_core import (
36
+ get_model_from_config,
37
+ load_config,
38
+ load_model_weights,
39
+ )
40
+
41
+ model, config = get_model_from_config("bs_roformer", "config.yaml")
42
+ load_model_weights(model, "model.ckpt", model_type="bs_roformer", strict=True)
43
+
44
+ model.eval()
45
+ ```
46
+
47
+ ## Package Boundary
48
+
49
+ Included:
50
+
51
+ - YAML config loading with `AttrDict`
52
+ - PyTorch model definitions under `pymss_core.modules`
53
+ - Optional MLX backend implementations for supported model forward paths
54
+ - Model factory: `get_model_from_config(model_type, config_path)`
55
+ - Checkpoint helpers for common MSS checkpoint containers
56
+ - Small model-internal DSP math needed to construct model structures
57
+ - VR network structures and VR model parameter JSON files
58
+
59
+ Excluded:
60
+
61
+ - Audio file decoding/encoding
62
+ - Resampling, preprocessing, and full inference DSP pipelines
63
+ - Tensor-level chunked demixing runtime
64
+ - Model catalog, aliases, downloads, and cache management
65
+ - CLI, server, WebUI, and endpoint schemas
66
+ - Dataset, augmentation, loss, metrics, and trainer code
67
+ - Any default dependency on MLX, Librosa, tqdm, Lightning, FastAPI, Uvicorn,
68
+ PyAV, WandB, or training extras
69
+
70
+ ## Repository Roles
71
+
72
+ ```text
73
+ pymss-core
74
+ shared model/config/checkpoint layer
75
+
76
+ pymss
77
+ user-facing inference package built on pymss-core, with audio I/O and demix
78
+
79
+ pymsst
80
+ training package built on pymss-core, with training data/loss/runtime code
81
+ ```
@@ -0,0 +1,25 @@
1
+ """Core model, configuration, and checkpoint API for music source separation.
2
+
3
+ `pymss_core` contains the shared pieces used by higher-level packages:
4
+ configuration loading, model construction, model definitions, and
5
+ checkpoint/state-dict helpers. It intentionally does not provide audio file
6
+ I/O, inference DSP pipelines, chunked demixing, model catalog downloads, CLI,
7
+ HTTP server, or WebUI functionality.
8
+ """
9
+
10
+ from .checkpoint import load_checkpoint, load_model_weights, load_state_dict, unwrap_state_dict
11
+ from .config import AttrDict, ConfigLoader, load_config, to_attrdict, to_plain
12
+ from .utils import get_model_from_config
13
+
14
+ __all__ = (
15
+ "AttrDict",
16
+ "ConfigLoader",
17
+ "get_model_from_config",
18
+ "load_checkpoint",
19
+ "load_config",
20
+ "load_model_weights",
21
+ "load_state_dict",
22
+ "to_attrdict",
23
+ "to_plain",
24
+ "unwrap_state_dict",
25
+ )
@@ -0,0 +1,127 @@
1
+ """Checkpoint helpers shared by inference and training frontends."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+ from typing import Any
8
+
9
+ import torch
10
+
11
+
12
# Wrapper keys probed, in priority order, when a checkpoint is a dict.
STATE_DICT_KEYS = ("state", "state_dict", "model_state_dict")


def unwrap_state_dict(checkpoint: Any) -> Any:
    """Extract the model state dict from a checkpoint container.

    Args:
        checkpoint: A loaded checkpoint object of any type.

    Returns:
        The value stored under the first key of ``STATE_DICT_KEYS`` present in
        ``checkpoint`` when it is a dict; otherwise ``checkpoint`` itself,
        unchanged.
    """
    if not isinstance(checkpoint, dict):
        return checkpoint
    wrapper_key = next((name for name in STATE_DICT_KEYS if name in checkpoint), None)
    if wrapper_key is None:
        # No known wrapper key: treat the dict itself as the state dict.
        return checkpoint
    return checkpoint[wrapper_key]
22
+
23
+
24
+ def _install_demucs_pickle_stubs() -> dict[str, ModuleType | None]:
25
+ import sys
26
+ import types
27
+
28
+ module_names = ("demucs", "demucs.demucs", "demucs.hdemucs", "demucs.htdemucs")
29
+ previous = {name: sys.modules.get(name) for name in module_names}
30
+ package = sys.modules.setdefault("demucs", types.ModuleType("demucs"))
31
+ package.__path__ = []
32
+ for module_name, class_names in {
33
+ "demucs": ("Demucs",),
34
+ "hdemucs": ("HDemucs", "HTDemucs"),
35
+ "htdemucs": ("HTDemucs",),
36
+ }.items():
37
+ full_name = f"demucs.{module_name}"
38
+ module = sys.modules.setdefault(full_name, types.ModuleType(full_name))
39
+ setattr(package, module_name, module)
40
+ for class_name in class_names:
41
+ if not hasattr(module, class_name):
42
+ setattr(module, class_name, type(class_name, (), {"__module__": full_name}))
43
+ return previous
44
+
45
+
46
+ def _restore_modules(previous: dict[str, ModuleType | None]) -> None:
47
+ import sys
48
+
49
+ for name, module in previous.items():
50
+ if module is None:
51
+ sys.modules.pop(name, None)
52
+ else:
53
+ sys.modules[name] = module
54
+
55
+
56
+ def _torch_load(path: str | Path, *, map_location="cpu", weights_only: bool | None = None, mmap: bool = True) -> Any:
57
+ kwargs: dict[str, Any] = {"map_location": map_location}
58
+ if weights_only is not None:
59
+ kwargs["weights_only"] = weights_only
60
+ if mmap:
61
+ kwargs["mmap"] = True
62
+ try:
63
+ return torch.load(path, **kwargs)
64
+ except TypeError:
65
+ kwargs.pop("mmap", None)
66
+ try:
67
+ return torch.load(path, **kwargs)
68
+ except TypeError:
69
+ kwargs.pop("weights_only", None)
70
+ return torch.load(path, **kwargs)
71
+
72
+
73
def load_checkpoint(
    path: str | Path,
    *,
    model_type: str | None = None,
    map_location: str | torch.device = "cpu",
    weights_only: bool | None = None,
    mmap: bool = True,
) -> Any:
    """Load a checkpoint file, applying per-model-type loading rules.

    Args:
        path: Checkpoint file to load.
        model_type: Model family name, compared case-insensitively. ``None``
            is treated as an empty name.
        map_location: Forwarded to the underlying torch loader.
        weights_only: Forwarded to the underlying torch loader. Ignored for
            the Demucs family, which is always loaded with
            ``weights_only=False``; for ``"apollo"`` a value of ``None`` is
            replaced by ``False``.
        mmap: Whether to request memory-mapped loading.

    Returns:
        The loaded checkpoint object, not unwrapped.
    """
    kind = model_type.lower() if model_type else ""
    demucs_family = ("htdemucs", "demucs", "legacy_demucs", "legacy_tasnet")

    if kind not in demucs_family:
        if kind == "apollo" and weights_only is None:
            weights_only = False
        return _torch_load(path, map_location=map_location, weights_only=weights_only, mmap=mmap)

    # Demucs-family checkpoints are loaded as full pickles with placeholder
    # ``demucs`` modules installed for the duration of the load only.
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # load such files from trusted sources.
    saved_modules = _install_demucs_pickle_stubs()
    try:
        return _torch_load(path, map_location=map_location, weights_only=False, mmap=mmap)
    finally:
        _restore_modules(saved_modules)
92
+
93
+
94
def load_state_dict(
    path: str | Path,
    *,
    model_type: str | None = None,
    map_location: str | torch.device = "cpu",
    weights_only: bool | None = None,
    mmap: bool = True,
) -> Any:
    """Load a checkpoint file and return its unwrapped model state dict.

    All arguments are passed straight through to ``load_checkpoint``; the
    loaded object is then handed to ``unwrap_state_dict``.

    Args:
        path: Checkpoint file to load.
        model_type: Model family name used to pick loading rules.
        map_location: Device mapping for loaded tensors.
        weights_only: Optional ``weights_only`` override for the loader.
        mmap: Whether to request memory-mapped loading.

    Returns:
        The state dict found inside the checkpoint, or the checkpoint itself
        when it has no known wrapper key.
    """
    package = load_checkpoint(
        path,
        model_type=model_type,
        map_location=map_location,
        weights_only=weights_only,
        mmap=mmap,
    )
    return unwrap_state_dict(package)
112
+
113
+
114
def load_model_weights(
    model: torch.nn.Module,
    checkpoint_or_path: Any,
    *,
    model_type: str | None = None,
    strict: bool = True,
    map_location: str | torch.device = "cpu",
) -> Any:
    """Copy weights from a checkpoint object or checkpoint file into a model.

    Args:
        model: Module that receives the weights.
        checkpoint_or_path: Either a ``str``/``Path`` pointing at a checkpoint
            file, or an already loaded checkpoint object.
        model_type: Model family name, used only when loading from a file.
        strict: Forwarded to ``model.load_state_dict``.
        map_location: Device mapping, used only when loading from a file.

    Returns:
        The result of ``model.load_state_dict``.
    """
    given_a_path = isinstance(checkpoint_or_path, (str, Path))
    weights = (
        load_state_dict(checkpoint_or_path, model_type=model_type, map_location=map_location)
        if given_a_path
        else unwrap_state_dict(checkpoint_or_path)
    )
    return model.load_state_dict(weights, strict=strict)
@@ -0,0 +1,135 @@
1
+ import re
2
+
3
+ import yaml
4
+
5
+
6
class ConfigLoader(yaml.FullLoader):
    """YAML loader used by pymss-core model configuration files.

    Behaves like ``yaml.FullLoader`` with one extra implicit float resolver,
    registered right after the class definition.
    """


# Float pattern registered on ConfigLoader. Besides the usual forms it also
# accepts a mantissa without a decimal point followed by an exponent (for
# example ``1e-5``), so such scalars resolve to floats.
_FLOAT_PATTERN = re.compile(
    r"""^[-+]?(
    ([0-9][0-9_]*)?\.[0-9_]+([eE][-+]?[0-9]+)?
    |[0-9][0-9_]*[eE][-+]?[0-9]+
    |\.(inf|Inf|INF)
    |\.(nan|NaN|NAN)
    )$""",
    re.VERBOSE,
)

ConfigLoader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    _FLOAT_PATTERN,
    list("-+0123456789."),
)
25
+
26
+
27
class AttrDict(dict):
    """Mapping whose keys can also be read and written as attributes.

    Every value stored through item or attribute assignment is passed through
    ``to_attrdict`` first, so nested dictionaries become ``AttrDict`` objects.

    Args:
        data (Mapping | None, optional): Initial items. Defaults to None.
        **kwargs: Extra items; they take precedence over same-named keys in
            ``data``.

    Example:
        >>> cfg = AttrDict({"audio": {"chunk_size": 485100}})
        >>> cfg.audio.chunk_size
        485100
    """

    def __init__(self, data=None, **kwargs):
        """Populate the mapping from ``data`` and keyword arguments.

        Args:
            data (Mapping | None, optional): Initial items. Defaults to None.
            **kwargs: Extra items merged on top of ``data``.
        """
        super().__init__()
        initial_items = dict(data or {}, **kwargs)
        # Assign one by one so each value goes through ``__setitem__``.
        for name in initial_items:
            self[name] = initial_items[name]

    def __getattr__(self, key):
        """Look up ``key`` in the mapping when normal attribute lookup fails.

        Args:
            key (str): Attribute name being accessed.

        Returns:
            Any: The value stored under ``key``.

        Raises:
            AttributeError: If ``key`` is not present in the mapping.
        """
        try:
            found = self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc
        return found

    def __setattr__(self, key, value):
        """Store an attribute assignment as a mapping item.

        Args:
            key (str): Attribute name being assigned.
            value (Any): Value to store.
        """
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` after converting nested dictionaries.

        Args:
            key (Hashable): Item key.
            value (Any): Value to store; converted with ``to_attrdict``.
        """
        wrapped = to_attrdict(value)
        super().__setitem__(key, wrapped)
86
+
87
+
88
def to_attrdict(value):
    """Wrap dictionaries found in ``value`` as ``AttrDict`` objects.

    Args:
        value (Any): Object to convert.

    Returns:
        Any: ``value`` itself when it already is an ``AttrDict``; a new
        ``AttrDict`` for any other dict; a new list or tuple with converted
        items for lists and tuples; every other object unchanged.
    """
    if isinstance(value, AttrDict):
        return value
    if isinstance(value, dict):
        return AttrDict(value)
    if isinstance(value, (list, tuple)):
        converted = (to_attrdict(item) for item in value)
        return list(converted) if isinstance(value, list) else tuple(converted)
    return value
103
+
104
+
105
def to_plain(value):
    """Recursively convert mappings back to plain Python containers.

    Any ``dict`` instance, including ``AttrDict`` and other dict subclasses,
    is rebuilt as a plain ``dict``. Recursing into every dict rather than only
    ``AttrDict`` guarantees that a plain dict holding ``AttrDict`` values does
    not leak them into the result.

    Args:
        value (Any): Object to convert.

    Returns:
        Any: A new plain ``dict`` for dict inputs, a new ``list`` or ``tuple``
        with converted items for lists and tuples, and every other object
        unchanged.
    """
    if isinstance(value, dict):
        return {key: to_plain(item) for key, item in value.items()}
    if isinstance(value, list):
        return [to_plain(item) for item in value]
    if isinstance(value, tuple):
        return tuple(to_plain(item) for item in value)
    return value
120
+
121
+
122
def load_config(path):
    """Read a YAML model configuration file.

    The file is opened as UTF-8 text, parsed with ``ConfigLoader`` and the
    parsed document is passed through ``to_attrdict``.

    Args:
        path (str | os.PathLike): Location of the YAML file.

    Returns:
        Any: The parsed document with dictionaries wrapped as ``AttrDict``.

    Example:
        >>> config = load_config("config.yaml")
        >>> config.inference.batch_size
    """
    with open(path, encoding="utf-8") as stream:
        parsed = yaml.load(stream, Loader=ConfigLoader)
    return to_attrdict(parsed)
File without changes
@@ -0,0 +1,84 @@
1
+ """Small DSP helpers needed by model definitions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+
8
def hz_to_midi(hz):
    """Map frequencies in Hz to MIDI note numbers (440 Hz maps to 69)."""
    ratio_to_a4 = np.asarray(hz) / 440.0
    return 69.0 + 12.0 * np.log2(ratio_to_a4)
12
+
13
+
14
def midi_to_hz(midi):
    """Map MIDI note numbers to frequencies in Hz (note 69 maps to 440 Hz)."""
    semitones_from_a4 = np.asarray(midi) - 69.0
    return 440.0 * np.power(2.0, semitones_from_a4 / 12.0)
18
+
19
+
20
+ def _hz_to_mel(frequencies, *, htk=False):
21
+ frequencies = np.asarray(frequencies, dtype=np.float64)
22
+ if htk:
23
+ return 2595.0 * np.log10(1.0 + frequencies / 700.0)
24
+
25
+ f_sp = 200.0 / 3
26
+ mels = frequencies / f_sp
27
+ min_log_hz = 1000.0
28
+ min_log_mel = min_log_hz / f_sp
29
+ logstep = np.log(6.4) / 27.0
30
+ log_t = frequencies >= min_log_hz
31
+ mels = np.array(mels, copy=True)
32
+ mels[log_t] = min_log_mel + np.log(frequencies[log_t] / min_log_hz) / logstep
33
+ return mels
34
+
35
+
36
+ def _mel_to_hz(mels, *, htk=False):
37
+ mels = np.asarray(mels, dtype=np.float64)
38
+ if htk:
39
+ return 700.0 * (np.power(10.0, mels / 2595.0) - 1.0)
40
+
41
+ f_sp = 200.0 / 3
42
+ freqs = f_sp * mels
43
+ min_log_hz = 1000.0
44
+ min_log_mel = min_log_hz / f_sp
45
+ logstep = np.log(6.4) / 27.0
46
+ log_t = mels >= min_log_mel
47
+ freqs = np.array(freqs, copy=True)
48
+ freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
49
+ return freqs
50
+
51
+
52
def mel_frequencies(n_mels, *, fmin=0.0, fmax=11025.0, htk=False):
    """Return ``n_mels`` frequencies in Hz spaced evenly on the mel scale.

    The first and last values correspond to ``fmin`` and ``fmax``.
    """
    low_mel, high_mel = (_hz_to_mel(edge, htk=htk) for edge in (fmin, fmax))
    mel_grid = np.linspace(low_mel, high_mel, int(n_mels))
    return _mel_to_hz(mel_grid, htk=htk)
57
+
58
+
59
def fft_frequencies(*, sr, n_fft):
    """Return ``1 + n_fft // 2`` evenly spaced frequencies from 0 to ``sr / 2``."""
    nyquist = float(sr) / 2.0
    bin_count = int(1 + n_fft // 2)
    return np.linspace(0.0, nyquist, bin_count, endpoint=True)
62
+
63
+
64
def mel_filterbank(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False, norm="slaney", dtype=np.float32):
    """Build a bank of triangular mel filters over the FFT bin frequencies.

    Args:
        sr: Sample rate; also sets the default ``fmax`` of ``sr / 2``.
        n_fft: FFT size; the bank has ``1 + n_fft // 2`` columns.
        n_mels: Number of filters (rows).
        fmin: Lowest band edge in Hz.
        fmax: Highest band edge in Hz, or ``None`` for ``sr / 2``.
        htk: Select the HTK mel formula instead of the default one.
        norm: ``"slaney"`` scales each filter by ``2 / (upper edge - lower
            edge)``; ``None`` leaves the triangles unscaled.
        dtype: Output array dtype.

    Returns:
        Array of filter weights with one row per mel band.

    Raises:
        ValueError: If ``norm`` is neither ``"slaney"`` nor ``None``.
    """
    band_count = int(n_mels)
    upper_hz = float(sr) / 2.0 if fmax is None else fmax

    # band_count filters need band_count + 2 edges: each triangle spans
    # three consecutive edge frequencies.
    edges = mel_frequencies(band_count + 2, fmin=fmin, fmax=upper_hz, htk=htk)
    bins = fft_frequencies(sr=sr, n_fft=n_fft)

    edge_widths = np.diff(edges)
    offsets = np.subtract.outer(edges, bins)
    rising = -offsets[:-2] / edge_widths[:-1, np.newaxis]
    falling = offsets[2:] / edge_widths[1:, np.newaxis]
    # Each triangle is the lower of its two slopes, clipped at zero.
    bank = np.maximum(0.0, np.minimum(rising, falling))

    if norm == "slaney":
        scale = 2.0 / (edges[2 : band_count + 2] - edges[:band_count])
        bank *= scale[:, np.newaxis]
    elif norm is not None:
        raise ValueError(f"Unsupported mel filterbank norm: {norm!r}")

    return bank.astype(dtype, copy=False)