pymss-core 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymss_core/__init__.py +25 -0
- pymss_core/checkpoint.py +127 -0
- pymss_core/config.py +135 -0
- pymss_core/modules/__init__.py +0 -0
- pymss_core/modules/_dsp.py +84 -0
- pymss_core/modules/apollo_mlx.py +242 -0
- pymss_core/modules/bandit/__init__.py +0 -0
- pymss_core/modules/bandit/bandsplit.py +166 -0
- pymss_core/modules/bandit/core/__init__.py +9 -0
- pymss_core/modules/bandit/core/model/__init__.py +9 -0
- pymss_core/modules/bandit/core/model/_spectral.py +95 -0
- pymss_core/modules/bandit/core/model/bsrnn/__init__.py +17 -0
- pymss_core/modules/bandit/core/model/bsrnn/bandsplit.py +32 -0
- pymss_core/modules/bandit/core/model/bsrnn/core.py +171 -0
- pymss_core/modules/bandit/core/model/bsrnn/maskestim.py +20 -0
- pymss_core/modules/bandit/core/model/bsrnn/tfmodel.py +12 -0
- pymss_core/modules/bandit/core/model/bsrnn/utils.py +355 -0
- pymss_core/modules/bandit/core/model/bsrnn/wrapper.py +235 -0
- pymss_core/modules/bandit/maskestim.py +312 -0
- pymss_core/modules/bandit/tfmodel.py +194 -0
- pymss_core/modules/bandit_mlx.py +415 -0
- pymss_core/modules/bandit_v2/__init__.py +0 -0
- pymss_core/modules/bandit_v2/bandit.py +326 -0
- pymss_core/modules/bandit_v2/bandsplit.py +13 -0
- pymss_core/modules/bandit_v2/maskestim.py +119 -0
- pymss_core/modules/bandit_v2/tfmodel.py +18 -0
- pymss_core/modules/bandit_v2/utils.py +33 -0
- pymss_core/modules/bs_roformer/__init__.py +5 -0
- pymss_core/modules/bs_roformer/attend.py +29 -0
- pymss_core/modules/bs_roformer/bands.py +683 -0
- pymss_core/modules/bs_roformer/bs_roformer.py +111 -0
- pymss_core/modules/bs_roformer/bs_roformer_hyperace.py +43 -0
- pymss_core/modules/bs_roformer/common.py +383 -0
- pymss_core/modules/bs_roformer/hyperace_segm.py +331 -0
- pymss_core/modules/bs_roformer/mel_band_roformer.py +168 -0
- pymss_core/modules/bs_roformer/mlx_attention.py +408 -0
- pymss_core/modules/bs_roformer/mlx_roformer.py +667 -0
- pymss_core/modules/bs_roformer/transformer.py +355 -0
- pymss_core/modules/demucs4ht.py +479 -0
- pymss_core/modules/demucs_local.py +597 -0
- pymss_core/modules/demucs_mlx.py +605 -0
- pymss_core/modules/legacy_demucs.py +1552 -0
- pymss_core/modules/look2hear/__init__.py +3 -0
- pymss_core/modules/look2hear/apollo.py +549 -0
- pymss_core/modules/mdx23c_mlx.py +303 -0
- pymss_core/modules/mdx23c_tfc_tdf_v3.py +175 -0
- pymss_core/modules/mlx_utils.py +24 -0
- pymss_core/modules/scnet/__init__.py +3 -0
- pymss_core/modules/scnet/scnet.py +266 -0
- pymss_core/modules/scnet/separation.py +64 -0
- pymss_core/modules/scnet_mlx.py +411 -0
- pymss_core/modules/spectrogram.py +87 -0
- pymss_core/modules/vocal_remover/__init__.py +12 -0
- pymss_core/modules/vocal_remover/uvr_lib_v5/__init__.py +1 -0
- pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/__init__.py +1 -0
- pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/layers.py +135 -0
- pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/layers_new.py +101 -0
- pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/model_param_init.py +19 -0
- pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/nets.py +155 -0
- pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/nets_new.py +118 -0
- pymss_core/resources/vr_modelparams/1band_sr16000_hl512.json +19 -0
- pymss_core/resources/vr_modelparams/1band_sr32000_hl512.json +19 -0
- pymss_core/resources/vr_modelparams/1band_sr33075_hl384.json +19 -0
- pymss_core/resources/vr_modelparams/1band_sr44100_hl1024.json +19 -0
- pymss_core/resources/vr_modelparams/1band_sr44100_hl256.json +19 -0
- pymss_core/resources/vr_modelparams/1band_sr44100_hl512.json +19 -0
- pymss_core/resources/vr_modelparams/1band_sr44100_hl512_cut.json +19 -0
- pymss_core/resources/vr_modelparams/1band_sr44100_hl512_nf1024.json +19 -0
- pymss_core/resources/vr_modelparams/2band_32000.json +30 -0
- pymss_core/resources/vr_modelparams/2band_44100_lofi.json +30 -0
- pymss_core/resources/vr_modelparams/2band_48000.json +30 -0
- pymss_core/resources/vr_modelparams/3band_44100.json +42 -0
- pymss_core/resources/vr_modelparams/3band_44100_mid.json +43 -0
- pymss_core/resources/vr_modelparams/3band_44100_msb2.json +43 -0
- pymss_core/resources/vr_modelparams/4band_44100.json +54 -0
- pymss_core/resources/vr_modelparams/4band_44100_mid.json +55 -0
- pymss_core/resources/vr_modelparams/4band_44100_msb.json +55 -0
- pymss_core/resources/vr_modelparams/4band_44100_msb2.json +55 -0
- pymss_core/resources/vr_modelparams/4band_44100_reverse.json +55 -0
- pymss_core/resources/vr_modelparams/4band_44100_sw.json +55 -0
- pymss_core/resources/vr_modelparams/4band_v2.json +54 -0
- pymss_core/resources/vr_modelparams/4band_v2_sn.json +55 -0
- pymss_core/resources/vr_modelparams/4band_v3.json +54 -0
- pymss_core/resources/vr_modelparams/4band_v3_sn.json +55 -0
- pymss_core/resources/vr_modelparams/4band_v4_ms_fullband.json +58 -0
- pymss_core/utils.py +53 -0
- pymss_core-0.1.0.dist-info/METADATA +113 -0
- pymss_core-0.1.0.dist-info/RECORD +91 -0
- pymss_core-0.1.0.dist-info/WHEEL +5 -0
- pymss_core-0.1.0.dist-info/licenses/LICENSE +21 -0
- pymss_core-0.1.0.dist-info/top_level.txt +1 -0
pymss_core/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Core model, configuration, and checkpoint API for music source separation.
|
|
2
|
+
|
|
3
|
+
`pymss_core` contains the shared pieces used by higher-level packages:
|
|
4
|
+
configuration loading, model construction, model definitions, and
|
|
5
|
+
checkpoint/state-dict helpers. It intentionally does not provide file audio
|
|
6
|
+
I/O, inference DSP pipelines, chunked demixing, model catalog downloads, CLI,
|
|
7
|
+
HTTP server, or WebUI functionality.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from .checkpoint import load_checkpoint, load_model_weights, load_state_dict, unwrap_state_dict
|
|
11
|
+
from .config import AttrDict, ConfigLoader, load_config, to_attrdict, to_plain
|
|
12
|
+
from .utils import get_model_from_config
|
|
13
|
+
|
|
14
|
+
__all__ = (
|
|
15
|
+
"AttrDict",
|
|
16
|
+
"ConfigLoader",
|
|
17
|
+
"get_model_from_config",
|
|
18
|
+
"load_checkpoint",
|
|
19
|
+
"load_config",
|
|
20
|
+
"load_model_weights",
|
|
21
|
+
"load_state_dict",
|
|
22
|
+
"to_attrdict",
|
|
23
|
+
"to_plain",
|
|
24
|
+
"unwrap_state_dict",
|
|
25
|
+
)
|
pymss_core/checkpoint.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Checkpoint helpers shared by inference and training frontends."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from types import ModuleType
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Wrapper keys probed, in priority order, when a checkpoint is a dict.
STATE_DICT_KEYS = ("state", "state_dict", "model_state_dict")


def unwrap_state_dict(checkpoint: Any) -> Any:
    """Extract the model state dict from a checkpoint container.

    Args:
        checkpoint: Loaded checkpoint object of any type.

    Returns:
        The value stored under the first key of ``STATE_DICT_KEYS`` present
        in ``checkpoint`` when it is a dict; otherwise ``checkpoint`` itself.
    """
    if not isinstance(checkpoint, dict):
        return checkpoint
    wrapper_key = next((name for name in STATE_DICT_KEYS if name in checkpoint), None)
    if wrapper_key is None:
        # No known wrapper key: treat the dict itself as the state dict.
        return checkpoint
    return checkpoint[wrapper_key]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _install_demucs_pickle_stubs() -> dict[str, ModuleType | None]:
    """Register placeholder ``demucs`` modules/classes in ``sys.modules``.

    Every module name in the ``demucs`` family that is missing from
    ``sys.modules`` gets an empty stub module, and every expected class that
    is missing from its module gets an empty placeholder type. Modules and
    classes that already exist are reused as they are.

    Returns:
        Mapping from each handled module name to the module that was
        registered before the call (``None`` when there was none), in the
        form expected by ``_restore_modules``.
    """
    import sys
    import types

    module_names = ("demucs", "demucs.demucs", "demucs.hdemucs", "demucs.htdemucs")
    previous = {name: sys.modules.get(name) for name in module_names}

    package = sys.modules.setdefault("demucs", types.ModuleType("demucs"))
    # Only a stub needs the package marker. Overwriting the ``__path__`` of an
    # already-imported real package would survive ``_restore_modules`` (which
    # re-registers the same object) and break later ``import demucs.<sub>``.
    if not hasattr(package, "__path__"):
        package.__path__ = []

    stub_layout = {
        "demucs": ("Demucs",),
        "hdemucs": ("HDemucs", "HTDemucs"),
        "htdemucs": ("HTDemucs",),
    }
    for module_name, class_names in stub_layout.items():
        full_name = f"demucs.{module_name}"
        module = sys.modules.setdefault(full_name, types.ModuleType(full_name))
        setattr(package, module_name, module)
        for class_name in class_names:
            if not hasattr(module, class_name):
                setattr(module, class_name, type(class_name, (), {"__module__": full_name}))
    return previous
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _restore_modules(previous: dict[str, ModuleType | None]) -> None:
    """Put ``sys.modules`` entries back to a previously captured state.

    Args:
        previous: Mapping of module name to the module object to re-register,
            or ``None`` when the name should be removed from ``sys.modules``.
    """
    import sys

    registry = sys.modules
    for module_name in previous:
        saved = previous[module_name]
        if saved is not None:
            registry[module_name] = saved
            continue
        # The name was absent before; drop whatever is registered now.
        registry.pop(module_name, None)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _torch_load(path: str | Path, *, map_location="cpu", weights_only: bool | None = None, mmap: bool = True) -> Any:
    """Call ``torch.load`` with graceful fallbacks for unsupported options.

    Args:
        path: Checkpoint file to load.
        map_location: Forwarded to ``torch.load``.
        weights_only: Forwarded to ``torch.load`` only when not ``None``.
        mmap: Request memory-mapped loading when true.

    Returns:
        Whatever ``torch.load`` returns for the file.

    The load is attempted with progressively fewer keyword arguments:

    1. all requested options;
    2. without ``mmap`` when the first attempt raised ``TypeError``
       (keyword unknown to this torch version) or a ``RuntimeError`` /
       ``ValueError`` whose message mentions ``mmap`` (torch refuses to
       memory-map this file, e.g. a legacy non-zip checkpoint);
    3. additionally without ``weights_only`` when the second attempt raised
       ``TypeError``.

    Any other exception propagates unchanged.
    """
    kwargs: dict[str, Any] = {"map_location": map_location}
    if weights_only is not None:
        kwargs["weights_only"] = weights_only
    if mmap:
        kwargs["mmap"] = True

    try:
        return torch.load(path, **kwargs)
    except TypeError:
        kwargs.pop("mmap", None)
    except (RuntimeError, ValueError) as exc:
        # Retry only for mmap-specific refusals; real load errors must surface.
        if "mmap" not in kwargs or "mmap" not in str(exc):
            raise
        del kwargs["mmap"]

    try:
        return torch.load(path, **kwargs)
    except TypeError:
        kwargs.pop("weights_only", None)
    return torch.load(path, **kwargs)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def load_checkpoint(
    path: str | Path,
    *,
    model_type: str | None = None,
    map_location: str | torch.device = "cpu",
    weights_only: bool | None = None,
    mmap: bool = True,
) -> Any:
    """Load a checkpoint package with compatibility for common MSS formats.

    Args:
        path: Checkpoint file to load.
        model_type: Case-insensitive model family name; selects the
            compatibility path below. ``None`` means the generic path.
        map_location: Forwarded to the loader.
        weights_only: Forwarded to the loader on the generic path. For
            ``"apollo"`` an unset value becomes ``False``.
        mmap: Request memory-mapped loading.

    Returns:
        The object stored in the checkpoint file.

    NOTE: for ``htdemucs``, ``demucs``, ``legacy_demucs`` and
    ``legacy_tasnet`` the file is always loaded with ``weights_only=False``
    (the caller's value is ignored) while placeholder ``demucs`` modules are
    installed. Full unpickling can execute code from the file, so only load
    checkpoints from trusted sources.
    """
    kind = model_type.lower() if model_type else ""

    if kind not in {"htdemucs", "demucs", "legacy_demucs", "legacy_tasnet"}:
        if kind == "apollo" and weights_only is None:
            weights_only = False
        return _torch_load(path, map_location=map_location, weights_only=weights_only, mmap=mmap)

    saved_modules = _install_demucs_pickle_stubs()
    try:
        return _torch_load(path, map_location=map_location, weights_only=False, mmap=mmap)
    finally:
        # Always undo the sys.modules changes, even when loading fails.
        _restore_modules(saved_modules)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def load_state_dict(
    path: str | Path,
    *,
    model_type: str | None = None,
    map_location: str | torch.device = "cpu",
    weights_only: bool | None = None,
    mmap: bool = True,
) -> Any:
    """Load a checkpoint file and return its unwrapped model state dict.

    All arguments are forwarded unchanged to ``load_checkpoint``; the loaded
    object is then passed through ``unwrap_state_dict``.

    Returns:
        The state dict found inside the checkpoint, or the loaded object
        itself when it is not a recognised wrapper.
    """
    package = load_checkpoint(
        path,
        model_type=model_type,
        map_location=map_location,
        weights_only=weights_only,
        mmap=mmap,
    )
    return unwrap_state_dict(package)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def load_model_weights(
    model: torch.nn.Module,
    checkpoint_or_path: Any,
    *,
    model_type: str | None = None,
    strict: bool = True,
    map_location: str | torch.device = "cpu",
) -> Any:
    """Load weights from a checkpoint package or file into a model.

    Args:
        model: Module that receives the weights.
        checkpoint_or_path: A ``str``/``Path`` to load from disk, or an
            already loaded checkpoint object.
        model_type: Forwarded to ``load_state_dict`` when loading from disk.
        strict: Forwarded to ``model.load_state_dict``.
        map_location: Forwarded to ``load_state_dict`` when loading from disk.

    Returns:
        The result of ``model.load_state_dict``.
    """
    if not isinstance(checkpoint_or_path, (str, Path)):
        # Already in memory: just strip any wrapper around the state dict.
        weights = unwrap_state_dict(checkpoint_or_path)
    else:
        weights = load_state_dict(checkpoint_or_path, model_type=model_type, map_location=map_location)
    return model.load_state_dict(weights, strict=strict)
|
pymss_core/config.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
import yaml
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ConfigLoader(yaml.FullLoader):
    """YAML loader used by pymss-core model configuration files."""


# Register an extra implicit float pattern on ConfigLoader. Besides the usual
# decimal forms it accepts exponent notation without a decimal point (for
# example ``1e-4``) as well as ``.inf`` / ``.nan``. The pattern is compiled
# with re.X, so the line breaks and indentation inside it are ignored.
ConfigLoader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        r"""^[-+]?(
        ([0-9][0-9_]*)?\.[0-9_]+([eE][-+]?[0-9]+)?
        |[0-9][0-9_]*[eE][-+]?[0-9]+
        |\.(inf|Inf|INF)
        |\.(nan|NaN|NAN)
        )$""",
        re.X,
    ),
    # Characters a scalar must start with for this resolver to be tried.
    list("-+0123456789."),
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AttrDict(dict):
    """Dictionary that recursively exposes keys as attributes.

    Every value stored through item assignment, attribute assignment,
    ``update`` or ``setdefault`` is passed through ``to_attrdict``, so nested
    dictionaries (including those inside lists and tuples) also support
    attribute access.

    Args:
        data (Mapping | None, optional): Initial items. Defaults to None.
        **kwargs: Additional items; they override same-named keys of ``data``.

    Example:
        >>> cfg = AttrDict({"audio": {"chunk_size": 485100}})
        >>> cfg.audio.chunk_size
        485100
    """

    def __init__(self, data=None, **kwargs):
        """Populate the mapping from ``data`` and ``kwargs``."""
        super().__init__()
        for key, value in dict(data or {}, **kwargs).items():
            self[key] = value

    def __getattr__(self, key):
        """Return the item stored under ``key``.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: If ``key`` is not in the mapping.
        """
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key, value):
        """Store an attribute assignment as a mapping item."""
        self[key] = value

    def __delattr__(self, key):
        """Delete the item stored under ``key``.

        Raises:
            AttributeError: If ``key`` is not in the mapping.
        """
        try:
            del self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setitem__(self, key, value):
        """Store an item after recursively converting nested dictionaries."""
        super().__setitem__(key, to_attrdict(value))

    def update(self, *args, **kwargs):
        """Update like ``dict.update`` while converting nested dictionaries.

        ``dict.update`` does not go through ``__setitem__``, so it is
        re-implemented on top of item assignment.
        """
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        """Return ``self[key]``, first storing the converted ``default`` if absent."""
        if key not in self:
            self[key] = default
        return self[key]


def to_attrdict(value):
    """Recursively convert dictionaries to AttrDict objects.

    Args:
        value (Any): Value to convert.

    Returns:
        Any: ``value`` with dictionaries wrapped as ``AttrDict``. Lists and
        tuples are rebuilt with converted elements; existing ``AttrDict``
        instances and all other values are returned unchanged.
    """
    if isinstance(value, dict):
        return value if isinstance(value, AttrDict) else AttrDict(value)
    if isinstance(value, list):
        return [to_attrdict(item) for item in value]
    if isinstance(value, tuple):
        return tuple(to_attrdict(item) for item in value)
    return value
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def to_plain(value):
    """Recursively convert AttrDict objects back to plain Python containers.

    Any mapping derived from ``dict`` is rebuilt as a plain ``dict``, so
    attribute-style dictionaries nested inside ordinary dicts, lists or
    tuples are converted as well.

    Args:
        value (Any): Value to convert.

    Returns:
        Any: Converted value using plain dictionaries, lists, and tuples;
        other values are returned unchanged.
    """
    if isinstance(value, dict):
        return {key: to_plain(item) for key, item in value.items()}
    if isinstance(value, list):
        return [to_plain(item) for item in value]
    if isinstance(value, tuple):
        return tuple(to_plain(item) for item in value)
    return value
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def load_config(path):
    """Load a YAML model configuration file.

    Args:
        path (str | os.PathLike): File system path.

    Returns:
        AttrDict: Parsed configuration with attribute access. An empty
        document yields an empty ``AttrDict``. A non-mapping top-level value
        is returned as converted by ``to_attrdict``.

    Example:
        >>> config = load_config("config.yaml")
        >>> config.inference.batch_size
    """
    with open(path, encoding="utf-8") as f:
        data = yaml.load(f, Loader=ConfigLoader)
    if data is None:
        # yaml.load returns None for an empty document; keep the documented
        # return type so attribute access on the result still works.
        data = {}
    return to_attrdict(data)
|
|
File without changes
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Small DSP helpers needed by model definitions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def hz_to_midi(hz):
    """Convert frequencies in Hz to MIDI note numbers.

    Uses 440 Hz as MIDI note 69 with 12 notes per octave. Accepts a scalar
    or anything ``np.asarray`` can convert.
    """
    ratio = np.asarray(hz) / 440.0
    semitones = 12.0 * np.log2(ratio)
    return 69.0 + semitones
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def midi_to_hz(midi):
    """Convert MIDI note numbers to frequencies in Hz.

    Uses MIDI note 69 as 440 Hz with 12 notes per octave. Accepts a scalar
    or anything ``np.asarray`` can convert.
    """
    octaves = (np.asarray(midi) - 69.0) / 12.0
    return 440.0 * np.power(2.0, octaves)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _hz_to_mel(frequencies, *, htk=False):
    """Convert frequencies in Hz to mel values.

    Args:
        frequencies: Scalar or array-like of frequencies; converted to
            ``float64``.
        htk: Use the closed-form ``2595 * log10(1 + f / 700)`` mapping when
            true. Otherwise use a piecewise scale that is linear below
            1000 Hz and logarithmic from 1000 Hz upwards.

    Returns:
        NumPy array (0-d for scalar input) of mel values.
    """
    hz = np.asarray(frequencies, dtype=np.float64)
    if htk:
        return 2595.0 * np.log10(1.0 + hz / 700.0)

    linear_step = 200.0 / 3  # Hz per mel in the linear region
    knee_hz = 1000.0  # start of the logarithmic region
    knee_mel = knee_hz / linear_step
    log_step = np.log(6.4) / 27.0

    # Start from the linear mapping, then overwrite the entries at or above
    # the knee with the logarithmic mapping.
    mel = np.array(hz / linear_step, copy=True)
    above_knee = hz >= knee_hz
    mel[above_knee] = knee_mel + np.log(hz[above_knee] / knee_hz) / log_step
    return mel
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _mel_to_hz(mels, *, htk=False):
    """Convert mel values to frequencies in Hz (inverse of ``_hz_to_mel``).

    Args:
        mels: Scalar or array-like of mel values; converted to ``float64``.
        htk: Use the closed-form ``700 * (10 ** (m / 2595) - 1)`` mapping
            when true. Otherwise use the piecewise scale that is linear below
            the 1000 Hz knee and exponential above it.

    Returns:
        NumPy array (0-d for scalar input) of frequencies in Hz.
    """
    mel = np.asarray(mels, dtype=np.float64)
    if htk:
        return 700.0 * (np.power(10.0, mel / 2595.0) - 1.0)

    linear_step = 200.0 / 3  # Hz per mel in the linear region
    knee_hz = 1000.0  # start of the logarithmic region
    knee_mel = knee_hz / linear_step
    log_step = np.log(6.4) / 27.0

    # Start from the linear mapping, then overwrite the entries at or above
    # the knee with the exponential mapping.
    hz = np.array(linear_step * mel, copy=True)
    above_knee = mel >= knee_mel
    hz[above_knee] = knee_hz * np.exp(log_step * (mel[above_knee] - knee_mel))
    return hz
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def mel_frequencies(n_mels, *, fmin=0.0, fmax=11025.0, htk=False):
    """Return ``n_mels`` frequencies in Hz evenly spaced on the mel scale.

    The first value corresponds to ``fmin`` and the last to ``fmax``; ``htk``
    selects the mel formula used for both conversions.
    """
    lowest_mel = _hz_to_mel(fmin, htk=htk)
    highest_mel = _hz_to_mel(fmax, htk=htk)
    mel_grid = np.linspace(lowest_mel, highest_mel, int(n_mels))
    return _mel_to_hz(mel_grid, htk=htk)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def fft_frequencies(*, sr, n_fft):
    """Return FFT bin center frequencies.

    Produces ``1 + n_fft // 2`` evenly spaced values from 0 Hz to ``sr / 2``
    inclusive.
    """
    nyquist = float(sr) / 2.0
    bin_count = int(1 + n_fft // 2)
    return np.linspace(0.0, nyquist, bin_count, endpoint=True)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def mel_filterbank(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False, norm="slaney", dtype=np.float32):
    """Create a triangular mel filterbank for model initialization.

    Args:
        sr: Sample rate in Hz.
        n_fft: FFT size; the bank has ``1 + n_fft // 2`` frequency bins.
        n_mels: Number of mel filters.
        fmin: Lowest band edge in Hz.
        fmax: Highest band edge in Hz; ``None`` means ``sr / 2``.
        htk: Mel formula selector forwarded to ``mel_frequencies``.
        norm: ``"slaney"`` scales each filter by ``2 / (upper edge - lower
            edge)``; ``None`` leaves the triangles unscaled.
        dtype: dtype of the returned array.

    Returns:
        Array of shape ``(n_mels, 1 + n_fft // 2)`` with filter weights.

    Raises:
        ValueError: If ``norm`` is neither ``"slaney"`` nor ``None``.
    """
    band_count = int(n_mels)
    if fmax is None:
        fmax = float(sr) / 2.0

    # band_count filters need band_count + 2 edges (lower, center, upper).
    edges_hz = mel_frequencies(band_count + 2, fmin=fmin, fmax=fmax, htk=htk)
    bins_hz = fft_frequencies(sr=sr, n_fft=n_fft)

    edge_gaps = np.diff(edges_hz)
    offsets = edges_hz[:, np.newaxis] - bins_hz[np.newaxis, :]

    # Rising and falling slopes of each triangle; the filter is their minimum
    # clipped at zero.
    rising = -offsets[:-2] / edge_gaps[:-1, np.newaxis]
    falling = offsets[2:] / edge_gaps[1:, np.newaxis]
    weights = np.maximum(0.0, np.minimum(rising, falling))

    if norm == "slaney":
        band_scale = 2.0 / (edges_hz[2 : band_count + 2] - edges_hz[:band_count])
        weights *= band_scale[:, np.newaxis]
    elif norm is not None:
        raise ValueError(f"Unsupported mel filterbank norm: {norm!r}")

    return weights.astype(dtype, copy=False)
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from .bs_roformer.mlx_attention import _mlx_dtype, _torch_to_mlx_array, mlx_to_torch_mps
|
|
4
|
+
from .look2hear.apollo import BSNet, ConvActNorm1d, ICB, RMSNorm
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def torch_to_mlx_input(tensor, dtype):
    """Copy a torch tensor into an MLX array.

    The tensor is detached, cast to the torch ``dtype``, moved to the CPU and
    converted through NumPy before being wrapped with ``mx.array``.
    """
    import mlx.core as mx

    host_array = tensor.detach().to(dtype=dtype).cpu().numpy()
    return mx.array(host_array)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _mlx_param(module, name, tensor, dtype):
    """Return the MLX copy of a module parameter, converting it at most once.

    Conversions are cached per module in ``_pymss_mlx_full_param_cache``
    under ``name``. A cached value is reused only while the tensor's data
    pointer, version counter, shape and the requested ``dtype`` are unchanged.
    """
    store = getattr(module, "_pymss_mlx_full_param_cache", None)
    if store is None:
        store = {}
        module._pymss_mlx_full_param_cache = store

    fingerprint = (name, tensor.data_ptr(), tensor._version, tuple(tensor.shape), dtype)
    entry = store.get(name)
    if entry is None or entry[0] != fingerprint:
        # Missing or stale: convert again and replace the cached entry.
        entry = (fingerprint, _torch_to_mlx_array(tensor, dtype))
        store[name] = entry
    return entry[1]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _reflect_pad_last(x, pad):
    """Reflect-pad the last axis of ``x`` by ``pad`` samples on both sides.

    The edge sample itself is not repeated. ``x`` is returned unchanged when
    ``pad <= 0``.

    Raises:
        ValueError: If the last axis is not longer than ``pad``.
    """
    import mlx.core as mx

    if pad <= 0:
        return x
    if x.shape[-1] <= pad:
        raise ValueError("reflect padding requires input length greater than padding")

    leading = x[..., 1 : pad + 1][..., ::-1]
    trailing = x[..., -pad - 1 : -1][..., ::-1]
    return mx.concatenate((leading, x, trailing), axis=-1)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _stft(module, raw_audio, dtype):
    """Compute a centered STFT of 2-D audio with MLX.

    Args:
        module: Provides ``win`` (FFT/window size), ``stride`` (hop) and
            ``window`` (torch tensor).
        raw_audio: MLX array of shape ``(rows, samples)``.
        dtype: MLX dtype used for the padded signal and the window.

    Returns:
        Tuple ``(spec, context)``: ``spec`` is the ``rfft`` output with the
        frequency axis moved before the frame axis; ``context`` holds the
        ``length``, ``n_fft``, ``hop``, ``window`` and ``dtype`` entries that
        ``_istft`` reads.
    """
    import mlx.core as mx

    rows, sample_count = raw_audio.shape
    n_fft = module.win
    hop = module.stride

    padded = _reflect_pad_last(raw_audio.astype(dtype), n_fft // 2)
    padded_length = padded.shape[-1]
    frame_count = 1 + (padded_length - n_fft) // hop

    # Overlapping frames as a strided view: stepping one frame advances by
    # ``hop`` samples within a row.
    framed = mx.as_strided(
        padded,
        shape=(rows, frame_count, n_fft),
        strides=(padded_length, hop, 1),
    )
    window = _torch_to_mlx_array(module.window, torch.float32).astype(dtype)
    spec = mx.fft.rfft(framed * window, n=n_fft, axis=-1)

    context = {"length": sample_count, "n_fft": n_fft, "hop": hop, "window": window, "dtype": dtype}
    return mx.moveaxis(spec, -1, -2), context
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _istft(spec, context):
    """Invert a centered STFT by windowed overlap-add with MLX.

    Args:
        spec: Complex spectrogram with the frequency axis before the frame
            axis (the layout returned by ``_stft``).
        context: Dict with ``n_fft``, ``hop``, ``window``, ``dtype`` and the
            target ``length`` in samples.

    Returns:
        Audio of ``context["length"]`` samples per row, with the centering
        padding of ``n_fft // 2`` removed.
    """
    import mlx.core as mx

    n_fft = context["n_fft"]
    hop = context["hop"]
    window = context["window"]
    out_dtype = context["dtype"]

    frames = mx.fft.irfft(mx.moveaxis(spec, -2, -1), n=n_fft, axis=-1).astype(out_dtype) * window
    frame_count = frames.shape[1]
    full_length = n_fft + hop * (frame_count - 1)

    # positions[f, k] is the output sample index of sample k of frame f.
    positions = mx.arange(n_fft)[None, :] + hop * mx.arange(frame_count)[:, None]

    audio = mx.zeros((frames.shape[0], full_length), dtype=out_dtype).at[:, positions].add(frames)

    # Overlap-added squared window, used to normalise the summed frames.
    window_energy = mx.broadcast_to(mx.square(window)[None, :], (frame_count, n_fft))
    denom = mx.zeros((full_length,), dtype=out_dtype).at[positions].add(window_energy)
    # The floor avoids division by zero where the window sum vanishes.
    audio = audio / mx.maximum(denom[None, :], mx.array(1e-11, dtype=out_dtype))

    pad = n_fft // 2
    return audio[..., pad : pad + context["length"]]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _conv1d_ncl(conv, x, dtype):
    """Apply a ``torch.nn.Conv1d`` to channels-first MLX input.

    The input and the weight are transposed with ``(0, 2, 1)`` before
    ``mx.conv1d`` and the result is transposed back, so callers keep the
    ``(batch, channels, length)`` layout. Stride, padding, dilation and
    groups are taken from ``conv``; the bias is added when present.
    """
    import mlx.core as mx

    kernel = _mlx_param(conv, "weight", conv.weight, dtype).transpose(0, 2, 1)
    channels_last = x.transpose(0, 2, 1)
    out = mx.conv1d(
        channels_last,
        kernel,
        stride=conv.stride[0],
        padding=conv.padding[0],
        dilation=conv.dilation[0],
        groups=conv.groups,
    )
    if conv.bias is not None:
        out = out + _mlx_param(conv, "bias", conv.bias, dtype)
    return out.transpose(0, 2, 1)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _rms_norm(module, x, dtype):
    """Apply grouped RMS normalisation to ``(batch, channels, frames)`` input.

    Channels are split into ``module.groups`` groups and each group is scaled
    by the reciprocal root of its mean square (plus ``module.eps``), computed
    in float32. The result is cast back to the input dtype and multiplied by
    the per-channel ``module.weight``.
    """
    import mlx.core as mx

    batch, channels, frames = x.shape
    group_count = int(module.groups)

    grouped = x.astype(mx.float32).reshape(batch, group_count, channels // group_count, frames)
    mean_square = mx.mean(mx.square(grouped), axis=2, keepdims=True)
    grouped = grouped * mx.rsqrt(mean_square + module.eps)

    normalised = grouped.reshape(batch, channels, frames).astype(x.dtype)
    scale = _mlx_param(module, "weight", module.weight, dtype).reshape(1, -1, 1)
    return normalised * scale
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _silu(x):
    """Return the SiLU activation ``x * sigmoid(x)`` of an MLX array."""
    import mlx.core as mx

    gate = mx.sigmoid(x)
    return x * gate
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _glu_channel(x):
    """Apply a gated linear unit along axis 1 of an MLX array.

    Axis 1 is split into two halves; the first half is multiplied by the
    sigmoid of the second.
    """
    import mlx.core as mx

    value_half, gate_half = mx.split(x, 2, axis=1)
    return value_half * mx.sigmoid(gate_half)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _module_forward(module, x, dtype):
    """Run one supported torch layer on MLX data.

    ``torch.nn.Sequential`` containers are unrolled child by child; every
    other layer is dispatched by type, testing the types in the order listed
    below.

    NOTE(review): ``torch.nn.GLU`` is always applied along axis 1, regardless
    of the module's own ``dim`` setting — confirm the model only uses that.

    Raises:
        TypeError: If the layer type has no MLX implementation here.
    """
    if isinstance(module, torch.nn.Sequential):
        for child in module:
            x = _module_forward(child, x, dtype)
        return x

    # (layer type, thunk) pairs; the first matching type wins.
    handlers = (
        (torch.nn.Conv1d, lambda: _conv1d_ncl(module, x, dtype)),
        (RMSNorm, lambda: _rms_norm(module, x, dtype)),
        (torch.nn.SiLU, lambda: _silu(x)),
        (torch.nn.GLU, lambda: _glu_channel(x)),
        (ConvActNorm1d, lambda: _conv_act_norm(module, x, dtype)),
        (ICB, lambda: _module_forward(module.blocks, x, dtype)),
        (BSNet, lambda: _bsnet(module, x, dtype)),
    )
    for layer_type, run in handlers:
        if isinstance(module, layer_type):
            return run()

    raise TypeError(f"unsupported Apollo layer for MLX full backend: {type(module).__name__}")
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _conv_act_norm(module, x, dtype):
    """Run a ``ConvActNorm1d`` block with MLX and add the residual input.

    Uses entries 0, 1, 2 and 4 of ``module.conv`` as conv, RMS norm, conv and
    conv; SiLU is applied between the last two.

    NOTE(review): entry 3 is presumably the activation module replaced here
    by ``_silu`` — confirm against the layer definition.
    """
    layers = module.conv
    hidden = _conv1d_ncl(layers[0], x, dtype)
    hidden = _rms_norm(layers[1], hidden, dtype)
    hidden = _conv1d_ncl(layers[2], hidden, dtype)
    hidden = _silu(hidden)
    hidden = _conv1d_ncl(layers[4], hidden, dtype)
    if module.causal:
        # Drop the trailing ``kernel - 1`` frames of the causal output.
        hidden = hidden[..., : -module.kernel + 1]
    return x + hidden
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _apply_rope(module, x, dtype):
    """Apply rotary position embedding to the last axis of ``x``.

    The sequence length is ``x.shape[-2]``. Even/odd feature pairs are
    rotated using the first ``seq_len`` rows of ``module.cos_freq`` and
    ``module.sin_freq``; only the even-indexed columns of those tables are
    used.
    """
    import mlx.core as mx

    seq_len = x.shape[-2]
    cos_table = _torch_to_mlx_array(module.cos_freq[:seq_len], dtype).reshape(1, 1, seq_len, -1)
    sin_table = _torch_to_mlx_array(module.sin_freq[:seq_len], dtype).reshape(1, 1, seq_len, -1)

    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    cos_half = cos_table[..., 0::2]
    sin_half = sin_table[..., 0::2]

    # Write the rotated pairs back into interleaved positions.
    rotated = mx.zeros_like(x)
    rotated = rotated.at[..., 0::2].add(x_even * cos_half - x_odd * sin_half)
    rotated = rotated.at[..., 1::2].add(x_odd * cos_half + x_even * sin_half)
    return rotated
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _roformer(module, x, dtype):
    """Run one rotary self-attention + gated MLP block with MLX.

    Args:
        module: Provides ``input_norm``, ``weight`` (QKV conv), ``output``,
            ``MLP``, ``MLP_output``, ``num_head`` and ``hidden_size``.
        x: MLX array of shape ``(batch, channels, frames)``.
        dtype: Torch dtype forwarded to the parameter conversions.

    Returns:
        Array with the same layout as ``x`` after both residual sub-blocks.
    """
    import mlx.core as mx

    batch, _, frames = x.shape

    # Attention sub-block.
    normed = _rms_norm(module.input_norm, x, dtype)
    qkv = _conv1d_ncl(module.weight, normed, dtype)
    qkv = qkv.reshape(batch, module.num_head, module.hidden_size * 3, frames)
    qkv = qkv.transpose(0, 1, 3, 2)
    query, key, value = mx.split(qkv, 3, axis=-1)
    query = _apply_rope(module, query, dtype)
    key = _apply_rope(module, key, dtype)
    attended = mx.fast.scaled_dot_product_attention(
        query, key, value, scale=module.hidden_size**-0.5, mask=None
    )
    merged = attended.transpose(0, 1, 3, 2).reshape(batch, -1, frames)
    out = _conv1d_ncl(module.output, merged, dtype) + x

    # Gated MLP sub-block.
    hidden = _rms_norm(module.MLP[0], out, dtype)
    hidden = _conv1d_ncl(module.MLP[1], hidden, dtype)
    hidden = _silu(hidden)
    gate, passthrough = mx.split(hidden, 2, axis=1)
    return out + _conv1d_ncl(module.MLP_output, _silu(gate) * passthrough, dtype)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _bsnet(module, x, dtype):
    """Run a ``BSNet`` block on ``(batch, bands, channels, frames)`` input.

    ``module.band_net`` is applied with the band axis as the sequence axis
    (one sequence per batch/frame pair), then ``module.seq_net`` with the
    frame axis as the sequence axis (one sequence per batch/band pair).
    The output has the same shape as the input.
    """
    batch, bands, channels, frames = x.shape

    across_bands = x.transpose(0, 3, 2, 1).reshape(batch * frames, channels, bands)
    across_bands = _roformer(module.band_net, across_bands, dtype)
    restored = across_bands.reshape(batch, frames, channels, bands).transpose(0, 3, 2, 1)

    across_time = restored.reshape(batch * bands, channels, frames)
    across_time = _module_forward(module.seq_net, across_time, dtype)
    return across_time.reshape(batch, bands, channels, frames)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _feature_extractor(module, raw_audio, dtype):
    """Compute per-band input features from raw audio with MLX.

    Args:
        module: Provides the STFT settings, ``band_width`` (bins per band),
            ``BN`` (one feature network per band) and ``eps``.
        raw_audio: MLX array of shape ``(batch, channels, samples)``.
        dtype: Torch compute dtype.

    Returns:
        Tuple ``(features, spec)``: the band features stacked on axis 1, and
        the complex spectrogram of the flattened ``batch * channels`` rows.
    """
    import mlx.core as mx

    mx_dtype = _mlx_dtype(dtype)
    batch, channels, samples = raw_audio.shape
    spec, _ = _stft(module, raw_audio.reshape(batch * channels, samples), mx_dtype)

    features = []
    band_start = 0
    for width, band_net in zip(module.band_width, module.BN):
        sub = spec[:, band_start : band_start + width]
        # Per-frame magnitude of the whole band, used to normalise it.
        power = mx.sqrt(mx.sum(mx.square(sub.real) + mx.square(sub.imag), axis=1, keepdims=True) + module.eps)
        normed = sub / power
        # Network input: normalised real part, imaginary part and log power.
        band_input = mx.concatenate((normed.real, normed.imag, mx.log(power)), axis=1)
        features.append(_module_forward(band_net, band_input.astype(mx_dtype), dtype))
        band_start += width
    return mx.stack(features, axis=1), spec
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _estimate_spec(module, feature, batch_channels, dtype):
    """Turn band features into an estimated complex spectrogram.

    Each band slice along axis 1 of ``feature`` is fed to its network in
    ``module.output``. The result is reshaped to
    ``(batch_channels, 2, width, -1)`` and read as real/imaginary parts, and
    the bands are concatenated along the frequency axis.
    """
    import mlx.core as mx

    band_slices = mx.split(feature, feature.shape[1], axis=1)
    band_specs = []
    for band_slice, band_net, width in zip(band_slices, module.output, module.band_width):
        squeezed = band_slice[:, 0]  # drop the singleton band axis
        real_imag = _module_forward(band_net, squeezed, dtype).reshape(batch_channels, 2, width, -1)
        band_specs.append(real_imag[:, 0] + (1j * real_imag[:, 1]))
    return mx.concatenate(band_specs, axis=1)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def mlx_forward_apollo_mx(module, raw_audio, dtype=torch.float16):
    """Run the Apollo forward pass entirely on MLX arrays.

    Args:
        module: Apollo torch module providing the weights and STFT settings.
        raw_audio: MLX array of shape ``(batch, channels, samples)``.
        dtype: Torch compute dtype, ``torch.float16`` or ``torch.float32``.

    Returns:
        MLX array reshaped to ``(batch, channels, -1)`` from the inverse STFT
        output, which is cut to ``samples`` per row.

    Raises:
        TypeError: If ``dtype`` is not one of the supported dtypes.
    """
    if dtype not in (torch.float16, torch.float32):
        raise TypeError("MLX full Apollo supports torch.float16 or torch.float32 compute dtype")

    mx_dtype = _mlx_dtype(dtype)
    raw_audio = raw_audio.astype(mx_dtype)
    batch, channels, samples = raw_audio.shape

    feature, _ = _feature_extractor(module, raw_audio, dtype)
    for block in module.net:
        feature = _bsnet(block, feature, dtype)
    est_spec = _estimate_spec(module, feature, batch * channels, dtype)

    # Synthesis settings mirror the analysis STFT.
    istft_context = {
        "length": samples,
        "n_fft": module.win,
        "hop": module.stride,
        "window": _torch_to_mlx_array(module.window, torch.float32).astype(mx_dtype),
        "dtype": mx_dtype,
    }
    return _istft(est_spec, istft_context).reshape(batch, channels, -1)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def mlx_forward_apollo(module, raw_audio, dtype=torch.float16):
    """Run Apollo through the MLX backend for a torch input tensor.

    The torch tensor is converted to MLX, processed by
    ``mlx_forward_apollo_mx``, and the result is converted back with
    ``mlx_to_torch_mps``, which receives ``raw_audio`` as its second
    argument.
    """
    mlx_audio = torch_to_mlx_input(raw_audio, dtype=dtype)
    mlx_output = mlx_forward_apollo_mx(module, mlx_audio, dtype)
    return mlx_to_torch_mps(mlx_output, raw_audio)
|
|
File without changes
|