pymss-core 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. pymss_core/__init__.py +25 -0
  2. pymss_core/checkpoint.py +127 -0
  3. pymss_core/config.py +135 -0
  4. pymss_core/modules/__init__.py +0 -0
  5. pymss_core/modules/_dsp.py +84 -0
  6. pymss_core/modules/apollo_mlx.py +242 -0
  7. pymss_core/modules/bandit/__init__.py +0 -0
  8. pymss_core/modules/bandit/bandsplit.py +166 -0
  9. pymss_core/modules/bandit/core/__init__.py +9 -0
  10. pymss_core/modules/bandit/core/model/__init__.py +9 -0
  11. pymss_core/modules/bandit/core/model/_spectral.py +95 -0
  12. pymss_core/modules/bandit/core/model/bsrnn/__init__.py +17 -0
  13. pymss_core/modules/bandit/core/model/bsrnn/bandsplit.py +32 -0
  14. pymss_core/modules/bandit/core/model/bsrnn/core.py +171 -0
  15. pymss_core/modules/bandit/core/model/bsrnn/maskestim.py +20 -0
  16. pymss_core/modules/bandit/core/model/bsrnn/tfmodel.py +12 -0
  17. pymss_core/modules/bandit/core/model/bsrnn/utils.py +355 -0
  18. pymss_core/modules/bandit/core/model/bsrnn/wrapper.py +235 -0
  19. pymss_core/modules/bandit/maskestim.py +312 -0
  20. pymss_core/modules/bandit/tfmodel.py +194 -0
  21. pymss_core/modules/bandit_mlx.py +415 -0
  22. pymss_core/modules/bandit_v2/__init__.py +0 -0
  23. pymss_core/modules/bandit_v2/bandit.py +326 -0
  24. pymss_core/modules/bandit_v2/bandsplit.py +13 -0
  25. pymss_core/modules/bandit_v2/maskestim.py +119 -0
  26. pymss_core/modules/bandit_v2/tfmodel.py +18 -0
  27. pymss_core/modules/bandit_v2/utils.py +33 -0
  28. pymss_core/modules/bs_roformer/__init__.py +5 -0
  29. pymss_core/modules/bs_roformer/attend.py +29 -0
  30. pymss_core/modules/bs_roformer/bands.py +683 -0
  31. pymss_core/modules/bs_roformer/bs_roformer.py +111 -0
  32. pymss_core/modules/bs_roformer/bs_roformer_hyperace.py +43 -0
  33. pymss_core/modules/bs_roformer/common.py +383 -0
  34. pymss_core/modules/bs_roformer/hyperace_segm.py +331 -0
  35. pymss_core/modules/bs_roformer/mel_band_roformer.py +168 -0
  36. pymss_core/modules/bs_roformer/mlx_attention.py +408 -0
  37. pymss_core/modules/bs_roformer/mlx_roformer.py +667 -0
  38. pymss_core/modules/bs_roformer/transformer.py +355 -0
  39. pymss_core/modules/demucs4ht.py +479 -0
  40. pymss_core/modules/demucs_local.py +597 -0
  41. pymss_core/modules/demucs_mlx.py +605 -0
  42. pymss_core/modules/legacy_demucs.py +1552 -0
  43. pymss_core/modules/look2hear/__init__.py +3 -0
  44. pymss_core/modules/look2hear/apollo.py +549 -0
  45. pymss_core/modules/mdx23c_mlx.py +303 -0
  46. pymss_core/modules/mdx23c_tfc_tdf_v3.py +175 -0
  47. pymss_core/modules/mlx_utils.py +24 -0
  48. pymss_core/modules/scnet/__init__.py +3 -0
  49. pymss_core/modules/scnet/scnet.py +266 -0
  50. pymss_core/modules/scnet/separation.py +64 -0
  51. pymss_core/modules/scnet_mlx.py +411 -0
  52. pymss_core/modules/spectrogram.py +87 -0
  53. pymss_core/modules/vocal_remover/__init__.py +12 -0
  54. pymss_core/modules/vocal_remover/uvr_lib_v5/__init__.py +1 -0
  55. pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/__init__.py +1 -0
  56. pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/layers.py +135 -0
  57. pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/layers_new.py +101 -0
  58. pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/model_param_init.py +19 -0
  59. pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/nets.py +155 -0
  60. pymss_core/modules/vocal_remover/uvr_lib_v5/vr_network/nets_new.py +118 -0
  61. pymss_core/resources/vr_modelparams/1band_sr16000_hl512.json +19 -0
  62. pymss_core/resources/vr_modelparams/1band_sr32000_hl512.json +19 -0
  63. pymss_core/resources/vr_modelparams/1band_sr33075_hl384.json +19 -0
  64. pymss_core/resources/vr_modelparams/1band_sr44100_hl1024.json +19 -0
  65. pymss_core/resources/vr_modelparams/1band_sr44100_hl256.json +19 -0
  66. pymss_core/resources/vr_modelparams/1band_sr44100_hl512.json +19 -0
  67. pymss_core/resources/vr_modelparams/1band_sr44100_hl512_cut.json +19 -0
  68. pymss_core/resources/vr_modelparams/1band_sr44100_hl512_nf1024.json +19 -0
  69. pymss_core/resources/vr_modelparams/2band_32000.json +30 -0
  70. pymss_core/resources/vr_modelparams/2band_44100_lofi.json +30 -0
  71. pymss_core/resources/vr_modelparams/2band_48000.json +30 -0
  72. pymss_core/resources/vr_modelparams/3band_44100.json +42 -0
  73. pymss_core/resources/vr_modelparams/3band_44100_mid.json +43 -0
  74. pymss_core/resources/vr_modelparams/3band_44100_msb2.json +43 -0
  75. pymss_core/resources/vr_modelparams/4band_44100.json +54 -0
  76. pymss_core/resources/vr_modelparams/4band_44100_mid.json +55 -0
  77. pymss_core/resources/vr_modelparams/4band_44100_msb.json +55 -0
  78. pymss_core/resources/vr_modelparams/4band_44100_msb2.json +55 -0
  79. pymss_core/resources/vr_modelparams/4band_44100_reverse.json +55 -0
  80. pymss_core/resources/vr_modelparams/4band_44100_sw.json +55 -0
  81. pymss_core/resources/vr_modelparams/4band_v2.json +54 -0
  82. pymss_core/resources/vr_modelparams/4band_v2_sn.json +55 -0
  83. pymss_core/resources/vr_modelparams/4band_v3.json +54 -0
  84. pymss_core/resources/vr_modelparams/4band_v3_sn.json +55 -0
  85. pymss_core/resources/vr_modelparams/4band_v4_ms_fullband.json +58 -0
  86. pymss_core/utils.py +53 -0
  87. pymss_core-0.1.0.dist-info/METADATA +113 -0
  88. pymss_core-0.1.0.dist-info/RECORD +91 -0
  89. pymss_core-0.1.0.dist-info/WHEEL +5 -0
  90. pymss_core-0.1.0.dist-info/licenses/LICENSE +21 -0
  91. pymss_core-0.1.0.dist-info/top_level.txt +1 -0
pymss_core/__init__.py ADDED
@@ -0,0 +1,25 @@
1
"""Core model, configuration, and checkpoint API for music source separation.

``pymss_core`` holds the pieces shared by higher-level packages: configuration
loading, model construction, model definitions, and checkpoint/state-dict
helpers. It deliberately does not provide file audio I/O, inference DSP
pipelines, chunked demixing, model-catalog downloads, a CLI, an HTTP server,
or WebUI functionality.
"""

from .checkpoint import (
    load_checkpoint,
    load_model_weights,
    load_state_dict,
    unwrap_state_dict,
)
from .config import (
    AttrDict,
    ConfigLoader,
    load_config,
    to_attrdict,
    to_plain,
)
from .utils import get_model_from_config

# Names re-exported at package level (alphabetical, kept as a tuple).
__all__ = (
    "AttrDict", "ConfigLoader", "get_model_from_config",
    "load_checkpoint", "load_config", "load_model_weights",
    "load_state_dict", "to_attrdict", "to_plain", "unwrap_state_dict",
)
pymss_core/checkpoint.py ADDED
@@ -0,0 +1,127 @@
1
+ """Checkpoint helpers shared by inference and training frontends."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+ from typing import Any
8
+
9
+ import torch
10
+
11
+
12
# Keys under which common MSS checkpoint containers keep the model weights,
# checked in this priority order.
STATE_DICT_KEYS = ("state", "state_dict", "model_state_dict")


def unwrap_state_dict(checkpoint: Any) -> Any:
    """Extract the model state dict from a checkpoint container.

    If *checkpoint* is a ``dict`` containing one of ``STATE_DICT_KEYS``, the
    value stored under the first matching key is returned. Anything else is
    returned unchanged (presumably it is already a bare state dict).
    """
    if not isinstance(checkpoint, dict):
        return checkpoint
    matched = next((key for key in STATE_DICT_KEYS if key in checkpoint), None)
    return checkpoint if matched is None else checkpoint[matched]
22
+
23
+
24
+ def _install_demucs_pickle_stubs() -> dict[str, ModuleType | None]:
25
+ import sys
26
+ import types
27
+
28
+ module_names = ("demucs", "demucs.demucs", "demucs.hdemucs", "demucs.htdemucs")
29
+ previous = {name: sys.modules.get(name) for name in module_names}
30
+ package = sys.modules.setdefault("demucs", types.ModuleType("demucs"))
31
+ package.__path__ = []
32
+ for module_name, class_names in {
33
+ "demucs": ("Demucs",),
34
+ "hdemucs": ("HDemucs", "HTDemucs"),
35
+ "htdemucs": ("HTDemucs",),
36
+ }.items():
37
+ full_name = f"demucs.{module_name}"
38
+ module = sys.modules.setdefault(full_name, types.ModuleType(full_name))
39
+ setattr(package, module_name, module)
40
+ for class_name in class_names:
41
+ if not hasattr(module, class_name):
42
+ setattr(module, class_name, type(class_name, (), {"__module__": full_name}))
43
+ return previous
44
+
45
+
46
+ def _restore_modules(previous: dict[str, ModuleType | None]) -> None:
47
+ import sys
48
+
49
+ for name, module in previous.items():
50
+ if module is None:
51
+ sys.modules.pop(name, None)
52
+ else:
53
+ sys.modules[name] = module
54
+
55
+
56
+ def _torch_load(path: str | Path, *, map_location="cpu", weights_only: bool | None = None, mmap: bool = True) -> Any:
57
+ kwargs: dict[str, Any] = {"map_location": map_location}
58
+ if weights_only is not None:
59
+ kwargs["weights_only"] = weights_only
60
+ if mmap:
61
+ kwargs["mmap"] = True
62
+ try:
63
+ return torch.load(path, **kwargs)
64
+ except TypeError:
65
+ kwargs.pop("mmap", None)
66
+ try:
67
+ return torch.load(path, **kwargs)
68
+ except TypeError:
69
+ kwargs.pop("weights_only", None)
70
+ return torch.load(path, **kwargs)
71
+
72
+
73
def load_checkpoint(
    path: str | Path,
    *,
    model_type: str | None = None,
    map_location: str | torch.device = "cpu",
    weights_only: bool | None = None,
    mmap: bool = True,
) -> Any:
    """Load a checkpoint file, accounting for quirks of common MSS formats.

    Args:
        path: Checkpoint file path.
        model_type: Case-insensitive model family name used to pick loading rules.
        map_location: Forwarded to ``torch.load``.
        weights_only: Forwarded to ``torch.load`` (see per-family rules below).
        mmap: Request memory-mapped loading when supported.

    Rules:
        * ``htdemucs``, ``demucs``, ``legacy_demucs``, ``legacy_tasnet``: placeholder
          ``demucs`` modules are registered for the duration of the load, which
          always uses ``weights_only=False`` (the caller's value is ignored).
          The previous ``sys.modules`` entries are restored afterwards.
        * ``apollo``: ``weights_only`` defaults to ``False`` when not given.
    """
    family = (model_type or "").lower()
    if family in {"htdemucs", "demucs", "legacy_demucs", "legacy_tasnet"}:
        saved_modules = _install_demucs_pickle_stubs()
        try:
            return _torch_load(path, map_location=map_location, weights_only=False, mmap=mmap)
        finally:
            _restore_modules(saved_modules)
    if family == "apollo" and weights_only is None:
        weights_only = False
    return _torch_load(path, map_location=map_location, weights_only=weights_only, mmap=mmap)
92
+
93
+
94
def load_state_dict(
    path: str | Path,
    *,
    model_type: str | None = None,
    map_location: str | torch.device = "cpu",
    weights_only: bool | None = None,
    mmap: bool = True,
) -> Any:
    """Load a checkpoint file and return only its model state dict.

    All arguments are forwarded to ``load_checkpoint``. The loaded object is
    then passed through ``unwrap_state_dict``.
    """
    checkpoint = load_checkpoint(
        path,
        model_type=model_type,
        map_location=map_location,
        weights_only=weights_only,
        mmap=mmap,
    )
    return unwrap_state_dict(checkpoint)
112
+
113
+
114
def load_model_weights(
    model: torch.nn.Module,
    checkpoint_or_path: Any,
    *,
    model_type: str | None = None,
    strict: bool = True,
    map_location: str | torch.device = "cpu",
) -> Any:
    """Load weights into *model* from a checkpoint object or a checkpoint file.

    Args:
        model: Module receiving the weights.
        checkpoint_or_path: A ``str`` or any ``os.PathLike`` path (loaded with
            ``load_state_dict``), or an already-loaded checkpoint object
            (unwrapped with ``unwrap_state_dict``).
        model_type: Forwarded to ``load_state_dict`` when loading from a path.
        strict: Forwarded to ``model.load_state_dict``.
        map_location: Forwarded to ``load_state_dict`` when loading from a path.

    Returns:
        The return value of ``model.load_state_dict``.
    """
    import os

    # os.PathLike covers pathlib.Path as well as PurePath, os.DirEntry, etc.
    if isinstance(checkpoint_or_path, (str, os.PathLike)):
        state_dict = load_state_dict(checkpoint_or_path, model_type=model_type, map_location=map_location)
    else:
        state_dict = unwrap_state_dict(checkpoint_or_path)
    return model.load_state_dict(state_dict, strict=strict)
pymss_core/config.py ADDED
@@ -0,0 +1,135 @@
1
+ import re
2
+
3
+ import yaml
4
+
5
+
6
class ConfigLoader(yaml.FullLoader):
    """PyYAML loader for pymss-core model configuration files.

    Kept as a dedicated subclass so extra implicit resolvers can be attached to
    it rather than to ``yaml.FullLoader`` itself.

    NOTE(review): FullLoader is not meant for untrusted input; configs from
    untrusted sources should go through a SafeLoader-based loader instead.
    """
10
+
11
+
12
# PyYAML's built-in float pattern does not match exponent-only numbers such as
# ``1e-5`` (no decimal point), so they would load as strings. Register an extra
# float resolver on ConfigLoader that accepts them, together with ``.5``-style
# values and the ``.inf`` / ``.nan`` spellings.
ConfigLoader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        r"""^[-+]?(
        ([0-9][0-9_]*)?\.[0-9_]+([eE][-+]?[0-9]+)?
        |[0-9][0-9_]*[eE][-+]?[0-9]+
        |\.(inf|Inf|INF)
        |\.(nan|NaN|NAN)
        )$""",
        flags=re.VERBOSE,
    ),
    [*"-+0123456789."],
)
25
+
26
+
27
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    Nested ``dict`` values, including dicts inside lists and tuples, are
    converted to ``AttrDict`` whenever they are stored. This applies to item
    assignment, attribute assignment, ``update`` and ``setdefault``, so
    attribute access works at every nesting level.

    Args:
        data (Mapping | None, optional): Initial contents. Defaults to None.
        **kwargs: Extra items merged over *data*.

    Example:
        >>> cfg = AttrDict({"audio": {"chunk_size": 485100}})
        >>> cfg.audio.chunk_size
        485100"""

    def __init__(self, data=None, **kwargs):
        """Populate from *data* and *kwargs*, converting nested dictionaries."""
        super().__init__()
        for key, value in dict(data or {}, **kwargs).items():
            self[key] = value

    def __getattr__(self, key):
        """Look up a missing attribute as a key.

        Raises:
            AttributeError: If *key* is not present, so ``getattr``/``hasattr``,
                ``copy`` and ``pickle`` behave normally.
        """
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key, value):
        """Store ``obj.key = value`` as an item."""
        self[key] = value

    def __delattr__(self, key):
        """Delete the item behind ``del obj.key``.

        Raises:
            AttributeError: If *key* is not present.
        """
        try:
            del self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setitem__(self, key, value):
        """Store an item after recursively converting nested dictionaries."""
        super().__setitem__(key, to_attrdict(value))

    def update(self, *args, **kwargs):
        """Same signature as ``dict.update``, but every item goes through ``__setitem__``.

        The built-in ``dict.update`` writes directly into the underlying storage
        and would skip the nested-dict conversion.
        """
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        """Same as ``dict.setdefault``, converting *default* before it is stored."""
        if key not in self:
            self[key] = default
        return self[key]


def to_attrdict(value):
    """Recursively convert dictionaries to AttrDict objects.

    Args:
        value (Any): Value to convert.

    Returns:
        Any: *value* with every ``dict`` wrapped as ``AttrDict``. Existing
        AttrDicts are returned as-is; lists and tuples are rebuilt element-wise.
    """
    if isinstance(value, dict):
        return value if isinstance(value, AttrDict) else AttrDict(value)
    if isinstance(value, list):
        return [to_attrdict(item) for item in value]
    if isinstance(value, tuple):
        return tuple(to_attrdict(item) for item in value)
    return value
103
+
104
+
105
def to_plain(value):
    """Recursively convert AttrDict objects back to plain Python containers.

    Every ``dict``, including ``AttrDict`` and other dict subclasses, is rebuilt
    as a plain ``dict``. This also converts AttrDicts nested inside ordinary
    dicts, which previously stayed unconverted.

    Args:
        value (Any): Value to convert.

    Returns:
        Any: Converted value built from plain dictionaries, lists, and tuples;
        non-container values are returned unchanged.
    """
    if isinstance(value, dict):
        return {key: to_plain(item) for key, item in value.items()}
    if isinstance(value, list):
        return [to_plain(item) for item in value]
    if isinstance(value, tuple):
        return tuple(to_plain(item) for item in value)
    return value
120
+
121
+
122
def load_config(path):
    """Read a YAML model configuration file into an AttrDict.

    Args:
        path (str | os.PathLike): Location of the YAML file (read as UTF-8).

    Returns:
        AttrDict: Parsed configuration with attribute access. An empty document
        parses to ``None``, which is returned unchanged.

    NOTE(review): parsing uses ``ConfigLoader`` (a ``yaml.FullLoader`` subclass);
    only load configuration files from trusted sources.

    Example:
        >>> config = load_config("config.yaml")
        >>> config.inference.batch_size"""
    with open(path, encoding="utf-8") as handle:
        parsed = yaml.load(handle, Loader=ConfigLoader)
    return to_attrdict(parsed)
pymss_core/modules/__init__.py ADDED
File without changes
pymss_core/modules/_dsp.py ADDED
@@ -0,0 +1,84 @@
1
+ """Small DSP helpers needed by model definitions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+
8
def hz_to_midi(hz):
    """Map frequencies in Hz to (fractional) MIDI note numbers; 440 Hz is note 69."""
    ratio_to_a4 = np.asarray(hz) / 440.0
    return 12.0 * np.log2(ratio_to_a4) + 69.0
12
+
13
+
14
def midi_to_hz(midi):
    """Map MIDI note numbers to frequencies in Hz; note 69 is 440 Hz."""
    semitones_from_a4 = np.asarray(midi) - 69.0
    return 440.0 * np.power(2.0, semitones_from_a4 / 12.0)
18
+
19
+
20
+ def _hz_to_mel(frequencies, *, htk=False):
21
+ frequencies = np.asarray(frequencies, dtype=np.float64)
22
+ if htk:
23
+ return 2595.0 * np.log10(1.0 + frequencies / 700.0)
24
+
25
+ f_sp = 200.0 / 3
26
+ mels = frequencies / f_sp
27
+ min_log_hz = 1000.0
28
+ min_log_mel = min_log_hz / f_sp
29
+ logstep = np.log(6.4) / 27.0
30
+ log_t = frequencies >= min_log_hz
31
+ mels = np.array(mels, copy=True)
32
+ mels[log_t] = min_log_mel + np.log(frequencies[log_t] / min_log_hz) / logstep
33
+ return mels
34
+
35
+
36
+ def _mel_to_hz(mels, *, htk=False):
37
+ mels = np.asarray(mels, dtype=np.float64)
38
+ if htk:
39
+ return 700.0 * (np.power(10.0, mels / 2595.0) - 1.0)
40
+
41
+ f_sp = 200.0 / 3
42
+ freqs = f_sp * mels
43
+ min_log_hz = 1000.0
44
+ min_log_mel = min_log_hz / f_sp
45
+ logstep = np.log(6.4) / 27.0
46
+ log_t = mels >= min_log_mel
47
+ freqs = np.array(freqs, copy=True)
48
+ freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
49
+ return freqs
50
+
51
+
52
def mel_frequencies(n_mels, *, fmin=0.0, fmax=11025.0, htk=False):
    """Return ``n_mels`` frequencies in Hz, evenly spaced on the mel scale from *fmin* to *fmax* inclusive."""
    low_mel = _hz_to_mel(fmin, htk=htk)
    high_mel = _hz_to_mel(fmax, htk=htk)
    mel_grid = np.linspace(low_mel, high_mel, int(n_mels))
    return _mel_to_hz(mel_grid, htk=htk)
57
+
58
+
59
def fft_frequencies(*, sr, n_fft):
    """Return the ``1 + n_fft // 2`` real-FFT bin centre frequencies, from 0 to ``sr / 2``."""
    bin_count = int(1 + n_fft // 2)
    nyquist = float(sr) / 2.0
    return np.linspace(0.0, nyquist, bin_count, endpoint=True)
62
+
63
+
64
def mel_filterbank(sr, n_fft, n_mels=128, fmin=0.0, fmax=None, htk=False, norm="slaney", dtype=np.float32):
    """Build an ``(n_mels, 1 + n_fft // 2)`` triangular mel filterbank.

    Args:
        sr: Sample rate in Hz.
        n_fft: FFT size.
        n_mels: Number of mel filters.
        fmin, fmax: Frequency range in Hz (``fmax`` defaults to ``sr / 2``).
        htk: Use the HTK mel formula instead of Slaney's.
        norm: ``"slaney"`` scales each filter by ``2 / (upper edge - lower edge)``;
            ``None`` leaves the triangles with unit peak.
        dtype: Output dtype.

    Raises:
        ValueError: For any other *norm*.
    """
    filter_count = int(n_mels)
    upper_hz = float(sr) / 2.0 if fmax is None else fmax

    # filter_count + 2 edge frequencies: filter i spans edges i, i+1, i+2.
    edges = mel_frequencies(filter_count + 2, fmin=fmin, fmax=upper_hz, htk=htk)
    bins = fft_frequencies(sr=sr, n_fft=n_fft)

    spacing = np.diff(edges)
    offsets = np.subtract.outer(edges, bins)
    rising = -offsets[:-2] / spacing[:-1, np.newaxis]
    falling = offsets[2:] / spacing[1:, np.newaxis]
    weights = np.maximum(0.0, np.minimum(rising, falling))

    if norm == "slaney":
        weights *= (2.0 / (edges[2 : filter_count + 2] - edges[:filter_count]))[:, np.newaxis]
    elif norm is not None:
        raise ValueError(f"Unsupported mel filterbank norm: {norm!r}")

    return weights.astype(dtype, copy=False)
pymss_core/modules/apollo_mlx.py ADDED
@@ -0,0 +1,242 @@
1
+ import torch
2
+
3
+ from .bs_roformer.mlx_attention import _mlx_dtype, _torch_to_mlx_array, mlx_to_torch_mps
4
+ from .look2hear.apollo import BSNet, ConvActNorm1d, ICB, RMSNorm
5
+
6
+
7
def torch_to_mlx_input(tensor, dtype):
    """Cast *tensor* to *dtype*, move it to the CPU, and wrap it as an MLX array via NumPy."""
    import mlx.core as mx

    host_tensor = tensor.detach().to(dtype=dtype).cpu()
    return mx.array(host_tensor.numpy())
11
+
12
+
13
+ def _mlx_param(module, name, tensor, dtype):
14
+ cache = getattr(module, "_pymss_mlx_full_param_cache", None)
15
+ if cache is None:
16
+ cache = {}
17
+ module._pymss_mlx_full_param_cache = cache
18
+ key = (name, tensor.data_ptr(), tensor._version, tuple(tensor.shape), dtype)
19
+ cached = cache.get(name)
20
+ if cached is not None and cached[0] == key:
21
+ return cached[1]
22
+ value = _torch_to_mlx_array(tensor, dtype)
23
+ cache[name] = (key, value)
24
+ return value
25
+
26
+
27
def _reflect_pad_last(x, pad):
    """Reflect-pad the last axis of *x* by *pad* samples on each side (edge sample not repeated).

    Raises:
        ValueError: If the last axis is not longer than *pad*.
    """
    import mlx.core as mx

    if pad > 0:
        if x.shape[-1] <= pad:
            raise ValueError("reflect padding requires input length greater than padding")
        head = x[..., 1 : pad + 1][..., ::-1]
        tail = x[..., -pad - 1 : -1][..., ::-1]
        x = mx.concatenate((head, x, tail), axis=-1)
    return x
35
+
36
+
37
def _stft(module, raw_audio, dtype):
    """Centered, reflect-padded, windowed STFT of ``(batch_channels, length)`` audio in MLX.

    ``module.win`` is the FFT size, ``module.stride`` the hop, and
    ``module.window`` the analysis window (converted via float32 to *dtype*).

    Returns:
        tuple: ``(spec, context)`` where *spec* has shape
        ``(batch_channels, n_fft // 2 + 1, frames)`` and *context* holds what
        ``_istft`` needs to invert it.
    """
    import mlx.core as mx

    batch_channels, length = raw_audio.shape
    n_fft, hop = module.win, module.stride
    padded = _reflect_pad_last(raw_audio.astype(dtype), n_fft // 2)
    padded_len = padded.shape[-1]
    frame_count = 1 + (padded_len - n_fft) // hop
    # Overlapping frame view: frame f starts at sample f * hop.
    framed = mx.as_strided(padded, shape=(batch_channels, frame_count, n_fft), strides=(padded_len, hop, 1))
    window = _torch_to_mlx_array(module.window, torch.float32).astype(dtype)
    spectrum = mx.fft.rfft(framed * window, n=n_fft, axis=-1)
    context = {"length": length, "n_fft": n_fft, "hop": hop, "window": window, "dtype": dtype}
    return mx.moveaxis(spectrum, -1, -2), context
49
+
50
+
51
def _istft(spec, context):
    """Invert ``_stft``: windowed overlap-add normalised by the summed squared window.

    *spec* is ``(batch_channels, n_fft // 2 + 1, frames)``. The centre padding
    (``n_fft // 2`` samples) is trimmed and ``context["length"]`` samples are
    returned per row.
    """
    import mlx.core as mx

    n_fft = context["n_fft"]
    hop = context["hop"]
    out_dtype = context["dtype"]
    window = context["window"]

    frames = mx.fft.irfft(mx.moveaxis(spec, -2, -1), n=n_fft, axis=-1).astype(out_dtype) * window
    frame_count = frames.shape[1]
    total_len = n_fft + hop * (frame_count - 1)
    # positions[f, k]: output sample index written by tap k of frame f.
    positions = mx.arange(n_fft)[None, :] + hop * mx.arange(frame_count)[:, None]
    audio = mx.zeros((frames.shape[0], total_len), dtype=out_dtype).at[:, positions].add(frames)

    window_sq = mx.broadcast_to(mx.square(window)[None, :], (frame_count, n_fft))
    envelope = mx.zeros((total_len,), dtype=out_dtype).at[positions].add(window_sq)
    # NOTE(review): in float16, 1e-11 underflows to 0, so this floor only guards
    # against zero envelopes when out_dtype is float32.
    floor = mx.array(1e-11, dtype=out_dtype)
    audio = audio / mx.maximum(envelope[None, :], floor)

    start = n_fft // 2
    return audio[..., start : start + context["length"]]
66
+
67
+
68
def _conv1d_ncl(conv, x, dtype):
    """Apply a torch ``nn.Conv1d``'s weights to an MLX array laid out as ``(N, C, L)``.

    The input is transposed to ``(N, L, C)`` and the torch weight
    ``(C_out, C_in/groups, K)`` to ``(C_out, K, C_in/groups)`` before
    ``mx.conv1d``; the result is transposed back to ``(N, C_out, L_out)``.

    NOTE(review): only index 0 of ``stride``/``padding``/``dilation`` is used and
    ``padding_mode`` is not consulted; string padding (e.g. "same") is unsupported.
    """
    import mlx.core as mx

    kernel = _mlx_param(conv, "weight", conv.weight, dtype).transpose(0, 2, 1)
    channels_last = x.transpose(0, 2, 1)
    out = mx.conv1d(
        channels_last,
        kernel,
        stride=conv.stride[0],
        padding=conv.padding[0],
        dilation=conv.dilation[0],
        groups=conv.groups,
    )
    bias = conv.bias
    if bias is not None:
        out = out + _mlx_param(conv, "bias", bias, dtype)
    return out.transpose(0, 2, 1)
83
+
84
+
85
def _rms_norm(module, x, dtype):
    """Grouped RMS normalisation of ``(batch, channels, frames)`` input.

    Channels are split into ``module.groups`` groups. At every frame, each group
    is divided by its root-mean-square over its channels (plus ``module.eps``),
    computed in float32. The result is cast back to ``x.dtype`` and scaled by
    ``module.weight``.
    """
    import mlx.core as mx

    batch, channels, frames = x.shape
    group_count = int(module.groups)
    grouped = x.astype(mx.float32).reshape(batch, group_count, channels // group_count, frames)
    mean_square = mx.mean(mx.square(grouped), axis=2, keepdims=True)
    grouped = grouped * mx.rsqrt(mean_square + module.eps)
    normalised = grouped.reshape(batch, channels, frames).astype(x.dtype)
    gain = _mlx_param(module, "weight", module.weight, dtype).reshape(1, -1, 1)
    return normalised * gain
94
+
95
+
96
def _silu(x):
    """SiLU (swish) activation on an MLX array: ``x * sigmoid(x)``."""
    import mlx.core as mx

    gate = mx.sigmoid(x)
    return x * gate
100
+
101
+
102
def _glu_channel(x):
    """Gated linear unit along axis 1: first half times sigmoid of the second half."""
    import mlx.core as mx

    value, gate = mx.split(x, 2, axis=1)
    return value * mx.sigmoid(gate)
107
+
108
+
109
def _module_forward(module, x, dtype):
    """Run a torch Apollo sub-module on an MLX array by dispatching on its type.

    ``nn.Sequential`` is applied child by child. Other supported layers are
    matched in a fixed order (first ``isinstance`` match wins).

    NOTE(review): ``nn.GLU``'s ``dim`` attribute is not consulted; axis 1 is assumed.

    Raises:
        TypeError: If the module type has no MLX implementation here.
    """
    if isinstance(module, torch.nn.Sequential):
        for child in module:
            x = _module_forward(child, x, dtype)
        return x

    handlers = (
        (torch.nn.Conv1d, lambda layer, value: _conv1d_ncl(layer, value, dtype)),
        (RMSNorm, lambda layer, value: _rms_norm(layer, value, dtype)),
        (torch.nn.SiLU, lambda layer, value: _silu(value)),
        (torch.nn.GLU, lambda layer, value: _glu_channel(value)),
        (ConvActNorm1d, lambda layer, value: _conv_act_norm(layer, value, dtype)),
        (ICB, lambda layer, value: _module_forward(layer.blocks, value, dtype)),
        (BSNet, lambda layer, value: _bsnet(layer, value, dtype)),
    )
    for layer_type, handler in handlers:
        if isinstance(module, layer_type):
            return handler(module, x)
    raise TypeError(f"unsupported Apollo layer for MLX full backend: {type(module).__name__}")
129
+
130
+
131
def _conv_act_norm(module, x, dtype):
    """Residual conv block: ``x + conv[4](silu(conv[2](norm(conv[0](x)))))``.

    Uses ``module.conv`` entries 0 (Conv1d), 1 (RMSNorm), 2 and 4 (Conv1d).
    Entry 3 is skipped and SiLU is applied directly. When ``module.causal`` is
    set, the trailing ``module.kernel - 1`` frames are dropped before the
    residual add.
    """
    y = _conv1d_ncl(module.conv[0], x, dtype)
    y = _rms_norm(module.conv[1], y, dtype)
    y = _conv1d_ncl(module.conv[2], y, dtype)
    y = _silu(y)
    y = _conv1d_ncl(module.conv[4], y, dtype)
    if module.causal:
        # Explicit end index: the previous ``y[..., : -kernel + 1]`` became
        # ``y[..., :0]`` (empty) for kernel == 1. Clamping at 0 keeps the old
        # result when fewer frames than ``kernel - 1`` exist.
        keep = max(y.shape[-1] - (module.kernel - 1), 0)
        y = y[..., :keep]
    return x + y
140
+
141
+
142
def _apply_rope(module, x, dtype):
    """Rotate interleaved (even, odd) feature pairs of *x* by the module's rotary tables.

    Uses the first ``x.shape[-2]`` rows of ``module.cos_freq`` / ``module.sin_freq``
    and only their even-indexed columns. This presumably assumes each angle is
    stored duplicated per pair; confirm against the torch module.
    """
    import mlx.core as mx

    seq_len = x.shape[-2]
    cos_table = _torch_to_mlx_array(module.cos_freq[:seq_len], dtype).reshape(1, 1, seq_len, -1)
    sin_table = _torch_to_mlx_array(module.sin_freq[:seq_len], dtype).reshape(1, 1, seq_len, -1)
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    cos_pair = cos_table[..., 0::2]
    sin_pair = sin_table[..., 0::2]

    rotated = mx.zeros_like(x)
    rotated = rotated.at[..., 0::2].add(x_even * cos_pair - x_odd * sin_pair)
    rotated = rotated.at[..., 1::2].add(x_odd * cos_pair + x_even * sin_pair)
    return rotated
155
+
156
+
157
def _roformer(module, x, dtype):
    """Attention + gated-MLP block over the last axis of ``(batch, channels, frames)`` input.

    Attention: RMS-norm, a 1x1 conv producing q/k/v per head, rotary embedding
    on q and k, and scaled dot-product attention (scale ``hidden_size ** -0.5``,
    no mask). Then an output 1x1 conv and a residual add.
    MLP: RMS-norm and a 1x1 conv, then SiLU. The result is split into gate/z
    halves, SiLU is applied to the gate half again, and ``gate * z`` goes
    through ``MLP_output`` with a residual add.
    """
    import mlx.core as mx

    batch, _, frames = x.shape
    heads = module.num_head
    head_dim = module.hidden_size

    normed = _rms_norm(module.input_norm, x, dtype)
    qkv = _conv1d_ncl(module.weight, normed, dtype)
    qkv = qkv.reshape(batch, heads, head_dim * 3, frames).transpose(0, 1, 3, 2)
    q, k, v = mx.split(qkv, 3, axis=-1)
    q = _apply_rope(module, q, dtype)
    k = _apply_rope(module, k, dtype)
    attended = mx.fast.scaled_dot_product_attention(q, k, v, scale=head_dim**-0.5, mask=None)
    merged = attended.transpose(0, 1, 3, 2).reshape(batch, -1, frames)
    out = _conv1d_ncl(module.output, merged, dtype) + x

    projected = _conv1d_ncl(module.MLP[1], _rms_norm(module.MLP[0], out, dtype), dtype)
    gate, z = mx.split(_silu(projected), 2, axis=1)
    return out + _conv1d_ncl(module.MLP_output, _silu(gate) * z, dtype)
176
+
177
+
178
def _bsnet(module, x, dtype):
    """One band/sequence block on ``(batch, bands, channels, frames)`` input (same shape out).

    First ``module.band_net`` (a Roformer) runs along the band axis, then
    ``module.seq_net`` runs along the frame axis.
    """
    batch, bands, channels, frames = x.shape

    # Fold frames into the batch so the Roformer's sequence axis is bands.
    across_bands = x.transpose(0, 3, 2, 1).reshape(batch * frames, channels, bands)
    across_bands = _roformer(module.band_net, across_bands, dtype)
    restored = across_bands.reshape(batch, frames, channels, bands).transpose(0, 3, 2, 1)

    # Fold bands into the batch so seq_net runs along frames.
    across_frames = _module_forward(module.seq_net, restored.reshape(batch * bands, channels, frames), dtype)
    return across_frames.reshape(batch, bands, channels, frames)
185
+
186
+
187
def _feature_extractor(module, raw_audio, dtype):
    """Compute per-band normalised spectral features for ``(batch, channels, samples)`` audio.

    For each band of ``module.band_width`` consecutive frequency bins:
    1. Divide the complex sub-spectrum by its band magnitude
       ``sqrt(sum |X|^2 + eps)``.
    2. Concatenate the real part, imaginary part and log-magnitude along axis 1.
    3. Pass the result through the matching ``module.BN`` layer.

    Returns:
        tuple: ``(features, spec)`` where *features* stacks the per-band outputs
        along a new axis 1 and *spec* is the complex STFT from ``_stft``.
    """
    import mlx.core as mx

    mx_dtype = _mlx_dtype(dtype)
    batch, channels, samples = raw_audio.shape
    spec, _ = _stft(module, raw_audio.reshape(batch * channels, samples), mx_dtype)
    features = []
    band_start = 0
    # (An unused per-band ``powers`` list was removed; it only kept arrays alive.)
    for width, bn in zip(module.band_width, module.BN):
        sub = spec[:, band_start : band_start + width]
        magnitude = mx.sqrt(mx.sum(mx.square(sub.real) + mx.square(sub.imag), axis=1, keepdims=True) + module.eps)
        normed = sub / magnitude
        band_input = mx.concatenate((normed.real, normed.imag, mx.log(magnitude)), axis=1)
        features.append(_module_forward(bn, band_input.astype(mx_dtype), dtype))
        band_start += width
    return mx.stack(features, axis=1), spec
205
+
206
+
207
def _estimate_spec(module, feature, batch_channels, dtype):
    """Decode band features into a complex spectrum.

    Each band slice of *feature* (axis 1) goes through its ``module.output``
    layer. The result is read as ``(batch_channels, 2, width, frames)``
    real/imaginary parts, and the bands are concatenated along frequency (axis 1).
    """
    import mlx.core as mx

    band_slices = mx.split(feature, feature.shape[1], axis=1)
    pieces = []
    for band_slice, head, width in zip(band_slices, module.output, module.band_width):
        ri = _module_forward(head, band_slice[:, 0], dtype).reshape(batch_channels, 2, width, -1)
        real_part, imag_part = ri[:, 0], ri[:, 1]
        pieces.append(real_part + (1j * imag_part))
    return mx.concatenate(pieces, axis=1)
216
+
217
+
218
def mlx_forward_apollo_mx(module, raw_audio, dtype=torch.float16):
    """Run a torch Apollo model's forward pass entirely in MLX.

    Args:
        module: Apollo model whose weights are read. MLX parameter caches are
            attached to its sub-modules via ``_mlx_param``.
        raw_audio: MLX array of shape ``(batch, channels, samples)``.
        dtype: Compute dtype, ``torch.float16`` or ``torch.float32``.

    Returns:
        MLX array of shape ``(batch, channels, samples)``.

    Raises:
        TypeError: For any other *dtype*.
    """
    if dtype not in (torch.float16, torch.float32):
        raise TypeError("MLX full Apollo supports torch.float16 or torch.float32 compute dtype")
    compute_dtype = _mlx_dtype(dtype)
    audio = raw_audio.astype(compute_dtype)
    batch, channels, samples = audio.shape

    feature = _feature_extractor(module, audio, dtype)[0]
    for block in module.net:
        feature = _bsnet(block, feature, dtype)
    spectrum = _estimate_spec(module, feature, batch * channels, dtype)

    window = _torch_to_mlx_array(module.window, torch.float32).astype(compute_dtype)
    context = {
        "length": samples,
        "n_fft": module.win,
        "hop": module.stride,
        "window": window,
        "dtype": compute_dtype,
    }
    return _istft(spectrum, context).reshape(batch, channels, -1)
238
+
239
+
240
def mlx_forward_apollo(module, raw_audio, dtype=torch.float16):
    """Torch-in / torch-out wrapper around ``mlx_forward_apollo_mx``.

    Args:
        module: Apollo model.
        raw_audio: Torch tensor ``(batch, channels, samples)``.
        dtype: Compute dtype, ``torch.float16`` or ``torch.float32``.

    Returns:
        The value of ``mlx_to_torch_mps`` for the MLX result, given *raw_audio*
        as its second argument.

    Raises:
        TypeError: For any other *dtype*. This is checked before the input is
            converted, so e.g. ``torch.bfloat16`` (which the NumPy conversion
            cannot represent) gets this message instead of a conversion error,
            and no host copy is made first.
    """
    if dtype not in (torch.float16, torch.float32):
        raise TypeError("MLX full Apollo supports torch.float16 or torch.float32 compute dtype")
    x_mx = torch_to_mlx_input(raw_audio, dtype=dtype)
    return mlx_to_torch_mps(mlx_forward_apollo_mx(module, x_mx, dtype), raw_audio)
pymss_core/modules/bandit/__init__.py ADDED
File without changes