dpdfnet 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dpdfnet/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ __all__ = [
4
+ "enhance",
5
+ "enhance_file",
6
+ "available_models",
7
+ "download",
8
+ ]
9
+
10
+ if TYPE_CHECKING:
11
+ from .api import available_models, download, enhance, enhance_file
12
+
13
+
14
def __getattr__(name: str):
    """Resolve the public API names lazily (module-level ``__getattr__``).

    The ``.api`` submodule is imported only when one of the exported
    callables is actually requested; any other name raises
    ``AttributeError`` like a normal missing module attribute.
    """
    lazy_exports = ("enhance", "enhance_file", "available_models", "download")
    if name not in lazy_exports:
        raise AttributeError(f"module 'dpdfnet' has no attribute '{name}'")

    # Deferred so that importing the package does not import ``.api`` eagerly.
    from . import api

    return getattr(api, name)
dpdfnet/__main__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .cli import main
2
+
3
+
4
# Runs only when this file is executed as a script/module entry point; the
# value returned by ``main`` is handed to ``SystemExit`` as the exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
dpdfnet/api.py ADDED
@@ -0,0 +1,151 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Callable, Dict, List, Optional, Union
5
+
6
+ import numpy as np
7
+ from .models import (
8
+ DEFAULT_MODEL,
9
+ available_model_entries,
10
+ download_model,
11
+ download_models,
12
+ resolve_model,
13
+ )
14
+
15
+
16
def available_models() -> List[Dict[str, Any]]:
    """Return the model listing produced by ``available_model_entries``.

    The result is passed through unchanged: one dict per model.
    """
    entries = available_model_entries()
    return entries
19
+
20
+
21
def download(
    model: Optional[str] = None,
    *,
    force: bool = False,
    quiet: bool = False,
    verbose: bool = False,
) -> Union[Path, Dict[str, Path]]:
    """Fetch one model, or every model when ``model`` is ``None``.

    Args:
        model: Name of the model to fetch; ``None`` fetches all of them.
        force: Forwarded to the downloader as ``force``.
        quiet: When true, a no-op notifier is passed so messages are dropped.
        verbose: Forwarded to the downloader as ``verbose``.

    Returns:
        For a single model, the directory containing its ONNX file. For all
        models, a dict mapping model name to that directory.

    Raises:
        ValueError: If both ``quiet`` and ``verbose`` are true.
    """
    if verbose and quiet:
        raise ValueError("quiet=True and verbose=True are mutually exclusive.")

    def _discard(_message):
        return None

    # ``None`` leaves the choice of notifier to the downloader.
    notifier = _discard if quiet else None
    shared = {"force": force, "verbose": verbose, "notifier": notifier}

    if model is not None:
        single = download_model(model=model, **shared)
        return single.onnx_path.parent

    everything = download_models(models=None, **shared)
    return {entry.info.name: entry.onnx_path.parent for entry in everything}
49
+
50
+
51
def enhance(
    audio: np.ndarray,
    sample_rate: int,
    *,
    model: str = DEFAULT_MODEL,
    onnx_path: Optional[Union[str, Path]] = None,
    state_path: Optional[Union[str, Path]] = None,
    verbose: bool = False,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> np.ndarray:
    """Enhance a waveform and return it at the caller's sample rate.

    Processing order: down-mix with ``to_mono``; resolve the model files
    (``resolve_model`` is called with ``auto_download=True``); resample to the
    model's sample rate; zero-pad the tail by one window; run the runtime
    session one spectrogram frame at a time, feeding each returned state into
    the next call; invert the spectrogram; resample back and fit the result to
    the mono input length.

    Args:
        audio: Input samples; converted to float32 and down-mixed to mono.
        sample_rate: Sample rate of ``audio``.
        model: Model name handed to ``resolve_model``.
        onnx_path: Optional explicit ONNX path handed to ``resolve_model``.
        state_path: Optional explicit state path handed to ``resolve_model``.
        verbose: Forwarded to ``resolve_model``.
        progress_callback: Called as ``(done, total)``: once with ``done=0``
            before the first frame, then after every processed frame.

    Returns:
        A float32 array with as many samples as the mono input.
    """
    from .audio import (
        ensure_sample_rate,
        fit_length,
        make_stft_config,
        postprocess_spec,
        preprocess_waveform,
        to_mono,
    )
    from .onnx_backend import build_runtime_model, infer_win_len

    def _notify(done: int, total: int) -> None:
        if progress_callback is not None:
            progress_callback(done, total)

    mono = to_mono(np.asarray(audio, dtype=np.float32))
    input_rate = int(sample_rate)

    resolved = resolve_model(
        model=model,
        onnx_path=onnx_path,
        state_path=state_path,
        auto_download=True,
        verbose=verbose,
    )
    runtime = build_runtime_model(resolved.onnx_path, resolved.state_path)
    model_rate = resolved.info.sample_rate

    mono_at_model_rate = ensure_sample_rate(mono, input_rate, model_rate)
    stft_cfg = make_stft_config(infer_win_len(runtime.session, model_rate))

    # Keep alignment behavior from the original scripts.
    padded = np.pad(mono_at_model_rate, (0, stft_cfg.win_len), mode="constant")
    spec_in = preprocess_waveform(padded, stft_cfg)

    current_state = runtime.init_state.copy()
    enhanced_frames: list[np.ndarray] = []
    frame_count = int(spec_in.shape[1])
    _notify(0, frame_count)

    for index in range(frame_count):
        # One time step at a time, keeping the frame axis (length 1).
        frame_in = np.ascontiguousarray(
            spec_in[:, index : index + 1, :, :], dtype=np.float32
        )
        frame_out, current_state = runtime.session.run(
            [runtime.out_spec_name, runtime.out_state_name],
            {runtime.in_spec_name: frame_in, runtime.in_state_name: current_state},
        )
        enhanced_frames.append(np.ascontiguousarray(frame_out, dtype=np.float32))
        _notify(index + 1, frame_count)

    if not enhanced_frames:
        # Nothing was processed: hand back a copy of the mono input.
        return mono.copy()

    spec_out = np.concatenate(enhanced_frames, axis=1)
    enhanced_at_model_rate = postprocess_spec(spec_out, stft_cfg)
    enhanced = ensure_sample_rate(enhanced_at_model_rate, model_rate, input_rate)
    return fit_length(enhanced, mono.shape[0]).astype(np.float32, copy=False)
113
+
114
+
115
def enhance_file(
    input_path: Union[str, Path],
    output_path: Optional[Union[str, Path]] = None,
    *,
    model: str = DEFAULT_MODEL,
    onnx_path: Optional[Union[str, Path]] = None,
    state_path: Optional[Union[str, Path]] = None,
    verbose: bool = False,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Path:
    """Enhance an audio file on disk and write the result as 16-bit PCM.

    Args:
        input_path: File to read with ``soundfile``.
        output_path: Destination file. When ``None``, the output is written
            next to the input as ``<stem>_enhanced.wav``.
        model, onnx_path, state_path, verbose, progress_callback: Forwarded
            unchanged to ``enhance``.

    Returns:
        The resolved path of the written file.

    Raises:
        FileNotFoundError: If ``input_path`` is not an existing file.
    """
    import soundfile as sf

    from .audio import pcm16_safe

    source = Path(input_path).expanduser().resolve()
    if not source.is_file():
        raise FileNotFoundError(f"Input file not found: {source}")

    samples, file_rate = sf.read(str(source), always_2d=False)
    result = enhance(
        audio=samples,
        sample_rate=int(file_rate),
        model=model,
        onnx_path=onnx_path,
        state_path=state_path,
        verbose=verbose,
        progress_callback=progress_callback,
    )

    if output_path is not None:
        target = Path(output_path).expanduser().resolve()
    else:
        target = source.with_name(f"{source.stem}_enhanced.wav")

    # Create missing parent directories before writing.
    target.parent.mkdir(parents=True, exist_ok=True)
    # The output keeps the sample rate of the input file.
    sf.write(str(target), pcm16_safe(result), int(file_rate), subtype="PCM_16")
    return target
dpdfnet/audio.py ADDED
@@ -0,0 +1,103 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import librosa
6
+ import numpy as np
7
+
8
+
9
def to_mono(audio: np.ndarray) -> np.ndarray:
    """Return ``audio`` as a 1-D float32 signal.

    1-D input is returned after float32 conversion. 2-D input is averaged
    along axis 1, i.e. it is treated as ``(frames, channels)``.

    Raises:
        ValueError: If the input is neither 1-D nor 2-D.
    """
    samples = np.asarray(audio, dtype=np.float32)
    rank = samples.ndim
    if rank == 2:
        return np.mean(samples, axis=1, dtype=np.float32)
    if rank == 1:
        return samples
    raise ValueError(f"Expected mono/stereo audio, got shape {samples.shape}")
16
+
17
+
18
def ensure_sample_rate(audio: np.ndarray, sample_rate: int, target_sample_rate: int) -> np.ndarray:
    """Return ``audio`` as float32 at ``target_sample_rate``.

    When the two rates are equal only the float32 conversion happens;
    otherwise the signal is resampled with ``librosa.resample``.
    """
    samples = np.asarray(audio, dtype=np.float32)
    if sample_rate != target_sample_rate:
        resampled = librosa.resample(
            samples,
            orig_sr=sample_rate,
            target_sr=target_sample_rate,
        )
        return resampled.astype(np.float32, copy=False)
    return samples
26
+
27
+
28
def fit_length(audio: np.ndarray, target_len: int) -> np.ndarray:
    """Flatten ``audio`` to float32 and force it to ``target_len`` samples.

    Longer input is cut at the end; shorter input is zero-padded at the end.
    """
    flat = np.asarray(audio, dtype=np.float32).reshape(-1)
    current_len = flat.shape[0]
    if current_len == target_len:
        return flat
    if current_len > target_len:
        return flat[:target_len]
    padded = np.zeros(target_len, dtype=np.float32)
    padded[:current_len] = flat
    return padded
37
+
38
+
39
def pcm16_safe(audio: np.ndarray) -> np.ndarray:
    """Convert float samples in ``[-1, 1]`` to int16 PCM without wrap-around.

    Samples are clipped to ``[-1.0, 1.0]``, scaled by 32767 and truncated
    toward zero by the integer cast.

    Non-finite samples are sanitised first: NaN becomes silence (0) and
    +/-inf saturate to +/-1.0. ``np.clip`` alone leaves NaN in place, and
    casting NaN to int16 yields a platform-dependent value plus a
    ``RuntimeWarning``.

    Args:
        audio: Array-like of samples; converted to float32.

    Returns:
        int16 array of the same shape as the input.
    """
    samples = np.asarray(audio, dtype=np.float32)
    # Make the int16 cast well defined for NaN/inf before clipping.
    samples = np.nan_to_num(samples, nan=0.0, posinf=1.0, neginf=-1.0)
    clipped = np.clip(samples, -1.0, 1.0)
    return (clipped * 32767.0).astype(np.int16)
42
+
43
+
44
def vorbis_window(window_len: int) -> np.ndarray:
    """Return a float32 window ``sin(pi/2 * sin^2(pi * (n + 0.5) / window_len))``.

    The half-length divisor is kept as a float (``window_len / 2``).
    """
    half_len = window_len / 2
    positions = np.arange(window_len)
    inner = np.sin(0.5 * np.pi * (positions + 0.5) / half_len)
    outer = np.sin(0.5 * np.pi * inner * inner)
    return outer.astype(np.float32)
49
+
50
+
51
def get_wnorm(window_len: int, frame_size: int) -> float:
    """Return the spectrogram scale factor ``1 / (window_len**2 / (2 * frame_size))``."""
    denominator = window_len**2 / (2 * frame_size)
    return 1.0 / denominator
53
+
54
+
55
@dataclass(frozen=True)
class StftConfig:
    """Immutable bundle of the STFT parameters shared by analysis and synthesis."""

    # Window length in samples; also passed to librosa as ``n_fft``.
    win_len: int
    # Hop between consecutive frames, in samples.
    hop_size: int
    # Window array handed to ``librosa.stft`` / ``librosa.istft``.
    window: np.ndarray
    # Scale applied to the spectrogram on analysis and divided out on synthesis.
    wnorm: float
61
+
62
+
63
def make_stft_config(win_len: int) -> StftConfig:
    """Build the STFT configuration for ``win_len`` using a hop of half a window."""
    half_window = win_len // 2
    return StftConfig(
        win_len=win_len,
        hop_size=half_window,
        window=vorbis_window(win_len),
        wnorm=get_wnorm(win_len, half_window),
    )
68
+
69
+
70
def preprocess_waveform(waveform: np.ndarray, cfg: StftConfig) -> np.ndarray:
    """Turn a waveform into a scaled real/imaginary spectrogram.

    The flattened float32 signal is transformed with a centred,
    reflect-padded ``librosa.stft``. The result is transposed, multiplied by
    ``cfg.wnorm``, and split into real and imaginary parts on a new last axis.
    A leading axis of length 1 is added to the returned float32 array.
    """
    signal = np.asarray(waveform, dtype=np.float32).reshape(-1)
    complex_spec = librosa.stft(
        y=signal,
        n_fft=cfg.win_len,
        hop_length=cfg.hop_size,
        win_length=cfg.win_len,
        window=cfg.window,
        center=True,
        pad_mode="reflect",
    )
    scaled = (complex_spec.T * cfg.wnorm).astype(np.complex64, copy=False)
    parts = [scaled.real, scaled.imag]
    real_imag = np.stack(parts, axis=-1).astype(np.float32, copy=False)
    return real_imag[None, ...]
84
+
85
+
86
def postprocess_spec(spec_e: np.ndarray, cfg: StftConfig) -> np.ndarray:
    """Invert a real/imaginary spectrogram back to a float32 waveform.

    Only the first entry along the leading axis is used. Its last axis is
    read as ``[real, imag]``. After ``librosa.istft`` the ``cfg.wnorm``
    scaling is divided out, the first ``2 * cfg.win_len`` samples are dropped
    and the same number of zeros is appended, so the length is unchanged.
    """
    real_imag = np.asarray(spec_e[0], dtype=np.float32)
    real_part = real_imag[..., 0]
    imag_part = real_imag[..., 1]
    complex_spec = (real_part + 1j * imag_part).T.astype(np.complex64, copy=False)

    restored = librosa.istft(
        complex_spec,
        hop_length=cfg.hop_size,
        win_length=cfg.win_len,
        window=cfg.window,
        center=True,
        length=None,
    ).astype(np.float32, copy=False)
    restored = restored / cfg.wnorm

    # Shift left by two windows and refill the tail with silence.
    shift = cfg.win_len * 2
    return np.concatenate(
        [restored[shift:], np.zeros(shift, dtype=np.float32)],
        axis=0,
    )
dpdfnet/cli.py ADDED
@@ -0,0 +1,268 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from importlib import metadata
5
+ from pathlib import Path
6
+ import sys
7
+ from typing import Callable, List, Optional
8
+
9
+ from tqdm import tqdm
10
+ from .models import DEFAULT_MODEL, get_cache_model_dir, supported_models
11
+
12
+
13
+ def _build_frame_progress_callback(
14
+ bar: tqdm,
15
+ ) -> Callable[[int, int], None]:
16
+ last_done = 0
17
+
18
+ def _callback(done: int, total: int) -> None:
19
+ nonlocal last_done
20
+ if bar.total != total:
21
+ bar.total = total
22
+ bar.refresh()
23
+ delta = max(0, done - last_done)
24
+ if delta:
25
+ bar.update(delta)
26
+ last_done = done
27
+
28
+ return _callback
29
+
30
+
31
+ def _version_string() -> str:
32
+ try:
33
+ return f"dpdfnet {metadata.version('dpdfnet')}"
34
+ except metadata.PackageNotFoundError:
35
+ return "dpdfnet (local)"
36
+
37
+
38
def _add_model_resolution_args(parser: argparse.ArgumentParser) -> None:
    """Attach the ``--model`` and ``-v/--verbose`` options to ``parser``."""
    model_option = {
        "default": DEFAULT_MODEL,
        "choices": supported_models(),
        "help": "Model name to run.",
    }
    parser.add_argument("--model", **model_option)

    verbose_option = {
        "action": "store_true",
        "help": "Enable verbose model-resolution/download logs.",
    }
    parser.add_argument("-v", "--verbose", **verbose_option)
51
+
52
+
53
def _build_parser() -> argparse.ArgumentParser:
    """Create the top-level ``dpdfnet`` parser with all sub-commands.

    Sub-commands: ``models``, ``enhance``, ``enhance-dir`` and ``download``.
    The chosen sub-command name is stored in ``args.command``.
    """
    parser = argparse.ArgumentParser(
        prog="dpdfnet",
        description="DPDFNet CPU-only ONNX speech enhancement toolkit.",
    )
    parser.add_argument("--version", action="version", version=_version_string())

    commands = parser.add_subparsers(dest="command")

    # models: takes no arguments of its own.
    commands.add_parser(
        "models",
        help="List supported models and local availability.",
    )

    # enhance: one input file, one output file.
    single = commands.add_parser("enhance", help="Enhance a single wav file.")
    single.add_argument("input", type=Path, help="Input wav file path.")
    single.add_argument("output", type=Path, help="Output wav file path.")
    _add_model_resolution_args(single)

    # enhance-dir: one input directory, one output directory.
    batch = commands.add_parser(
        "enhance-dir",
        help="Enhance all .wav files from one directory (non-recursive).",
    )
    batch.add_argument("input_dir", type=Path, help="Input directory.")
    batch.add_argument("output_dir", type=Path, help="Output directory.")
    _add_model_resolution_args(batch)

    # download: model may be given positionally or through a hidden --model flag.
    fetch = commands.add_parser(
        "download",
        help="Download all models by default, or a single model if provided.",
    )
    fetch.add_argument(
        "model",
        nargs="?",
        choices=supported_models(),
        default=None,
        help="Optional model name to download. If omitted, all models are fetched.",
    )
    fetch.add_argument(
        "--model",
        dest="model_flag",
        choices=supported_models(),
        default=None,
        help=argparse.SUPPRESS,
    )
    fetch.add_argument(
        "--force",
        "--refresh",
        action="store_true",
        help="Force re-download even if files are already cached.",
    )
    verbosity = fetch.add_mutually_exclusive_group()
    verbosity.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Suppress download progress messages.",
    )
    verbosity.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose download logs.",
    )

    return parser
126
+
127
+
128
def _print_model_table() -> int:
    """Print the cache directory followed by one status line per model.

    Returns:
        Always ``0`` (the CLI exit status).
    """
    from .api import available_models

    entries = available_models()
    cache_dir = get_cache_model_dir().resolve()
    print(f"cache_dir={cache_dir}")
    for entry in entries:
        line = (
            f"{entry['name']}: sr={entry['sample_rate']}Hz, "
            f"ready={entry['ready']}, "
            f"onnx_found={entry['onnx_found']}, state_found={entry['state_found']}, "
            f"cached={entry['cached']}"
        )
        print(line)
    return 0
141
+
142
+
143
def _run_enhance(args: argparse.Namespace) -> int:
    """Handle ``dpdfnet enhance``: enhance one file with a frame progress bar.

    Returns:
        Always ``0``; failures propagate as exceptions.
    """
    from .api import enhance_file

    bar_options = {
        "total": 0,
        "unit": "frame",
        "desc": "Enhancing",
        "dynamic_ncols": True,
        "file": sys.stderr,
    }
    with tqdm(**bar_options) as frame_bar:
        enhance_file(
            input_path=args.input,
            output_path=args.output,
            model=args.model,
            verbose=args.verbose,
            progress_callback=_build_frame_progress_callback(frame_bar),
        )

    written = Path(args.output).expanduser().resolve()
    print(f"Wrote enhanced audio: {written}")
    return 0
162
+
163
+
164
def _run_enhance_dir(args: argparse.Namespace) -> int:
    """Handle ``dpdfnet enhance-dir``: enhance every ``.wav`` in a directory.

    Files are taken from ``args.input_dir`` only (no recursion), processed in
    sorted order and written to ``args.output_dir`` as
    ``<stem>_enhanced.wav``. Two bars are shown on stderr: one counting
    files, one counting frames accumulated across all files.

    Returns:
        Always ``0``; failures propagate as exceptions.

    Raises:
        FileNotFoundError: If the input directory does not exist or holds no
            ``.wav`` files.
    """
    from .api import enhance_file

    input_dir = Path(args.input_dir).expanduser().resolve()
    output_dir = Path(args.output_dir).expanduser().resolve()
    if not input_dir.is_dir():
        raise FileNotFoundError(f"Input directory not found: {input_dir}")

    wav_files = sorted(
        entry
        for entry in input_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() == ".wav"
    )
    if not wav_files:
        raise FileNotFoundError(f"No .wav files found in {input_dir}")

    output_dir.mkdir(parents=True, exist_ok=True)

    with tqdm(
        total=len(wav_files),
        unit="file",
        desc="Files",
        dynamic_ncols=True,
        file=sys.stderr,
    ) as files_bar, tqdm(
        total=0,
        unit="frame",
        desc="Frames",
        dynamic_ncols=True,
        file=sys.stderr,
    ) as frames_bar:
        frames_reported = 0

        def _on_frames(done: int, total: int) -> None:
            nonlocal frames_reported
            if done == 0:
                # Start of a file: grow the frame bar by this file's frame count.
                frames_bar.total = (frames_bar.total or 0) + total
                frames_bar.refresh()
                frames_reported = 0
                return
            advance = done - frames_reported
            if advance > 0:
                frames_bar.update(advance)
            frames_reported = done

        for wav_path in wav_files:
            destination = output_dir / f"{wav_path.stem}_enhanced.wav"
            frames_reported = 0
            enhance_file(
                input_path=wav_path,
                output_path=destination,
                model=args.model,
                verbose=args.verbose,
                progress_callback=_on_frames,
            )
            files_bar.update(1)
            files_bar.set_postfix_str(wav_path.name)
    return 0
217
+
218
+
219
def _run_download(args: argparse.Namespace) -> int:
    """Handle ``dpdfnet download`` and print where the files ended up.

    The model can come from the positional argument or the hidden ``--model``
    flag; the positional value wins when both are given and agree.

    Returns:
        Always ``0``; failures propagate as exceptions.

    Raises:
        ValueError: If the positional model and ``--model`` disagree.
    """
    from .api import download

    positional = args.model
    if positional is not None and args.model_flag is not None and positional != args.model_flag:
        raise ValueError("Conflicting model names provided in positional argument and --model.")

    model = args.model_flag if positional is None else positional
    destination = download(
        model=model,
        force=args.force,
        quiet=args.quiet,
        verbose=args.verbose,
    )

    if not isinstance(destination, dict):
        # Single-model result: one path.
        model_name = "<unknown>" if model is None else model
        print(f"Downloaded '{model_name}' to: {destination}")
        return 0

    # All-models result: mapping of model name to path.
    print("Downloaded models:")
    for model_name, model_path in destination.items():
        print(f"- {model_name}: {model_path}")
    return 0
240
+
241
+
242
def main(argv: Optional[List[str]] = None) -> int:
    """Run the ``dpdfnet`` command line and return its exit status.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.

    Returns:
        ``0`` on success or when no sub-command is given (help is printed);
        ``2`` when a handler raises (the message goes to stderr) or when the
        command is not recognised.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    command = args.command

    if command is None:
        parser.print_help()
        return 0

    handlers = {
        "models": lambda _args: _print_model_table(),
        "enhance": _run_enhance,
        "enhance-dir": _run_enhance_dir,
        "download": _run_download,
    }
    handler = handlers.get(command)
    if handler is None:
        parser.print_help()
        return 2

    try:
        return handler(args)
    except Exception as exc:
        # Top-level CLI boundary: report the failure instead of a traceback.
        print(f"Error: {exc}", file=sys.stderr)
        return 2
265
+
266
+
267
# Allows running this module directly as a script; the value returned by
# ``main`` is handed to ``SystemExit`` as the exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)