dpdfnet 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dpdfnet/__init__.py +19 -0
- dpdfnet/__main__.py +5 -0
- dpdfnet/api.py +151 -0
- dpdfnet/audio.py +103 -0
- dpdfnet/cli.py +268 -0
- dpdfnet/models.py +515 -0
- dpdfnet/onnx_backend.py +111 -0
- dpdfnet-0.2.0.dist-info/METADATA +144 -0
- dpdfnet-0.2.0.dist-info/RECORD +13 -0
- dpdfnet-0.2.0.dist-info/WHEEL +5 -0
- dpdfnet-0.2.0.dist-info/entry_points.txt +2 -0
- dpdfnet-0.2.0.dist-info/licenses/LICENSE +201 -0
- dpdfnet-0.2.0.dist-info/top_level.txt +1 -0
dpdfnet/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING

# Public, lazily-resolved API surface (see __getattr__ below).
__all__ = [
    "enhance",
    "enhance_file",
    "available_models",
    "download",
]

if TYPE_CHECKING:
    # Imported only for static type checkers; at runtime these names are
    # resolved lazily via the module-level __getattr__.
    from .api import available_models, download, enhance, enhance_file
|
+
def __getattr__(name: str):
|
|
15
|
+
if name in {"enhance", "enhance_file", "available_models", "download"}:
|
|
16
|
+
from . import api
|
|
17
|
+
|
|
18
|
+
return getattr(api, name)
|
|
19
|
+
raise AttributeError(f"module 'dpdfnet' has no attribute '{name}'")
|
dpdfnet/__main__.py
ADDED
dpdfnet/api.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from .models import (
|
|
8
|
+
DEFAULT_MODEL,
|
|
9
|
+
available_model_entries,
|
|
10
|
+
download_model,
|
|
11
|
+
download_models,
|
|
12
|
+
resolve_model,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def available_models(
) -> List[Dict[str, Any]]:
    """Return one metadata dict per supported model (delegates to models.py)."""
    entries = available_model_entries()
    return entries
+
|
|
21
|
+
def download(
    model: Optional[str] = None,
    *,
    force: bool = False,
    quiet: bool = False,
    verbose: bool = False,
) -> Union[Path, Dict[str, Path]]:
    """Fetch model files into the local cache.

    With ``model=None`` every supported model is downloaded and a mapping of
    model name to its cache directory is returned; otherwise the single
    model's cache directory is returned.

    Raises ``ValueError`` if both ``quiet`` and ``verbose`` are requested.
    """
    if quiet and verbose:
        raise ValueError("quiet=True and verbose=True are mutually exclusive.")

    # A no-op notifier suppresses progress messages when quiet was requested.
    silent_notifier = (lambda _message: None) if quiet else None

    if model is not None:
        single = download_model(
            model=model,
            force=force,
            verbose=verbose,
            notifier=silent_notifier,
        )
        return single.onnx_path.parent

    fetched = download_models(
        models=None,
        force=force,
        verbose=verbose,
        notifier=silent_notifier,
    )
    return {entry.info.name: entry.onnx_path.parent for entry in fetched}
+
|
|
50
|
+
|
|
51
|
+
def enhance(
    audio: np.ndarray,
    sample_rate: int,
    *,
    model: str = DEFAULT_MODEL,
    onnx_path: Optional[Union[str, Path]] = None,
    state_path: Optional[Union[str, Path]] = None,
    verbose: bool = False,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> np.ndarray:
    """Run streaming ONNX enhancement over ``audio`` and return the result.

    The signal is mixed down to mono, resampled to the model's sample rate,
    processed frame by frame through the ONNX session, then resampled back
    and trimmed/padded to the original input length.

    ``progress_callback(done, total)`` is invoked once before the frame loop
    and once after every processed frame.
    """
    from .audio import (
        ensure_sample_rate,
        fit_length,
        make_stft_config,
        postprocess_spec,
        preprocess_waveform,
        to_mono,
    )
    from .onnx_backend import build_runtime_model, infer_win_len

    mono = to_mono(np.asarray(audio, dtype=np.float32))
    input_sr = int(sample_rate)

    resolved = resolve_model(
        model=model,
        onnx_path=onnx_path,
        state_path=state_path,
        auto_download=True,
        verbose=verbose,
    )
    runtime = build_runtime_model(resolved.onnx_path, resolved.state_path)

    model_sr = resolved.info.sample_rate
    mono_model_sr = ensure_sample_rate(mono, input_sr, model_sr)
    cfg = make_stft_config(infer_win_len(runtime.session, model_sr))

    # Keep alignment behavior from the original scripts: append one window
    # of trailing silence before the STFT.
    padded = np.pad(mono_model_sr, (0, cfg.win_len), mode="constant")
    spec_in = preprocess_waveform(padded, cfg)

    state = runtime.init_state.copy()
    out_frames: list[np.ndarray] = []
    n_frames = int(spec_in.shape[1])
    if progress_callback is not None:
        progress_callback(0, n_frames)
    for idx in range(n_frames):
        frame = np.ascontiguousarray(spec_in[:, idx : idx + 1, :, :], dtype=np.float32)
        enhanced_frame, state = runtime.session.run(
            [runtime.out_spec_name, runtime.out_state_name],
            {runtime.in_spec_name: frame, runtime.in_state_name: state},
        )
        out_frames.append(np.ascontiguousarray(enhanced_frame, dtype=np.float32))
        if progress_callback is not None:
            progress_callback(idx + 1, n_frames)

    if not out_frames:
        # Nothing to process: return the mono input untouched.
        return mono.copy()

    spec_out = np.concatenate(out_frames, axis=1)
    wave_model_sr = postprocess_spec(spec_out, cfg)
    wave = ensure_sample_rate(wave_model_sr, model_sr, input_sr)
    return fit_length(wave, mono.shape[0]).astype(np.float32, copy=False)
+
|
|
114
|
+
|
|
115
|
+
def enhance_file(
    input_path: Union[str, Path],
    output_path: Optional[Union[str, Path]] = None,
    *,
    model: str = DEFAULT_MODEL,
    onnx_path: Optional[Union[str, Path]] = None,
    state_path: Optional[Union[str, Path]] = None,
    verbose: bool = False,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Path:
    """Enhance one audio file and write the result as 16-bit PCM wav.

    When ``output_path`` is omitted the output is written next to the input
    as ``<stem>_enhanced.wav``. Returns the resolved output path.

    Raises ``FileNotFoundError`` if the input file does not exist.
    """
    import soundfile as sf

    from .audio import pcm16_safe

    source = Path(input_path).expanduser().resolve()
    if not source.is_file():
        raise FileNotFoundError(f"Input file not found: {source}")

    samples, rate = sf.read(str(source), always_2d=False)
    rate = int(rate)
    cleaned = enhance(
        audio=samples,
        sample_rate=rate,
        model=model,
        onnx_path=onnx_path,
        state_path=state_path,
        verbose=verbose,
        progress_callback=progress_callback,
    )

    if output_path is not None:
        destination = Path(output_path).expanduser().resolve()
    else:
        destination = source.with_name(f"{source.stem}_enhanced.wav")

    destination.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(destination), pcm16_safe(cleaned), rate, subtype="PCM_16")
    return destination
dpdfnet/audio.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
import librosa
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def to_mono(audio: np.ndarray) -> np.ndarray:
    """Collapse multi-channel audio to a single float32 channel.

    1-D input is returned as float32 unchanged; 2-D input (frames, channels)
    is averaged over its second axis. Any other rank raises ``ValueError``.
    """
    data = np.asarray(audio, dtype=np.float32)
    if data.ndim not in (1, 2):
        raise ValueError(f"Expected mono/stereo audio, got shape {data.shape}")
    if data.ndim == 2:
        return np.mean(data, axis=1, dtype=np.float32)
    return data
+
|
|
17
|
+
|
|
18
|
+
def ensure_sample_rate(audio: np.ndarray, sample_rate: int, target_sample_rate: int) -> np.ndarray:
    """Resample ``audio`` to ``target_sample_rate``; no-op when rates match."""
    samples = np.asarray(audio, dtype=np.float32)
    if sample_rate != target_sample_rate:
        samples = librosa.resample(
            samples,
            orig_sr=sample_rate,
            target_sr=target_sample_rate,
        ).astype(np.float32, copy=False)
    return samples
+
|
|
27
|
+
|
|
28
|
+
def fit_length(audio: np.ndarray, target_len: int) -> np.ndarray:
    """Trim or zero-pad the flattened signal to exactly ``target_len`` samples."""
    flat = np.asarray(audio, dtype=np.float32).reshape(-1)
    n = flat.shape[0]
    if n >= target_len:
        # Covers the exact-length case too: a full-length slice is unchanged.
        return flat[:target_len]
    padded = np.zeros(target_len, dtype=np.float32)
    padded[:n] = flat
    return padded
+
|
|
38
|
+
|
|
39
|
+
def pcm16_safe(audio: np.ndarray) -> np.ndarray:
    """Convert float audio in [-1, 1] to int16 PCM samples.

    Values are clipped to [-1, 1] first so overdriven input cannot wrap
    around, then scaled by 32767 and rounded to the nearest integer.
    Rounding (rather than the implicit truncation-toward-zero of a plain
    ``astype``) avoids a systematic quantization bias toward zero.
    """
    clipped = np.clip(np.asarray(audio, dtype=np.float32), -1.0, 1.0)
    return np.round(clipped * 32767.0).astype(np.int16)
+
|
|
43
|
+
|
|
44
|
+
def vorbis_window(window_len: int) -> np.ndarray:
    """Return the length-``window_len`` Vorbis window as float32.

    w[n] = sin(pi/2 * sin^2(pi * (n + 0.5) / window_len))
    """
    half = window_len / 2
    phase = np.sin(0.5 * np.pi * (np.arange(window_len) + 0.5) / half)
    return np.sin(0.5 * np.pi * phase * phase).astype(np.float32)
+
|
|
50
|
+
|
|
51
|
+
def get_wnorm(window_len: int, frame_size: int) -> float:
    """Return the window normalization factor used when scaling the STFT."""
    denominator = window_len**2 / (2 * frame_size)
    return 1.0 / denominator
+
|
|
54
|
+
|
|
55
|
+
@dataclass(frozen=True)
class StftConfig:
    """Immutable bundle of STFT analysis/synthesis parameters."""

    # FFT window length in samples.
    win_len: int
    # Hop between successive frames (win_len // 2 when built by make_stft_config).
    hop_size: int
    # Analysis/synthesis window samples (Vorbis window, float32).
    window: np.ndarray
    # Spectrogram scaling factor (see get_wnorm).
    wnorm: float
+
|
|
62
|
+
|
|
63
|
+
def make_stft_config(win_len: int) -> StftConfig:
    """Build an StftConfig with a half-window hop and a Vorbis window."""
    half_hop = win_len // 2
    return StftConfig(
        win_len=win_len,
        hop_size=half_hop,
        window=vorbis_window(win_len),
        wnorm=get_wnorm(win_len, half_hop),
    )
+
|
|
69
|
+
|
|
70
|
+
def preprocess_waveform(waveform: np.ndarray, cfg: StftConfig) -> np.ndarray:
    """STFT the waveform into a real/imag tensor of shape (1, frames, bins, 2)."""
    flat = np.asarray(waveform, dtype=np.float32).reshape(-1)
    stft_out = librosa.stft(
        y=flat,
        n_fft=cfg.win_len,
        hop_length=cfg.hop_size,
        win_length=cfg.win_len,
        window=cfg.window,
        center=True,
        pad_mode="reflect",
    )
    # Frames-major layout, scaled by the window norm.
    scaled = (stft_out.T * cfg.wnorm).astype(np.complex64, copy=False)
    real_imag = np.stack([scaled.real, scaled.imag], axis=-1).astype(np.float32, copy=False)
    # Add a leading batch axis of size 1.
    return real_imag[None, ...]
+
|
|
85
|
+
|
|
86
|
+
def postprocess_spec(spec_e: np.ndarray, cfg: StftConfig) -> np.ndarray:
    """Invert an enhanced (1, frames, bins, 2) spectrum back to a waveform.

    Undoes the ``wnorm`` scaling applied in preprocess_waveform, then shifts
    the signal left by two window lengths (zero-padding the tail) to preserve
    the alignment used by the original scripts.
    """
    ri = np.asarray(spec_e[0], dtype=np.float32)
    complex_spec = (ri[..., 0] + 1j * ri[..., 1]).T.astype(np.complex64, copy=False)

    wave = librosa.istft(
        complex_spec,
        hop_length=cfg.hop_size,
        win_length=cfg.win_len,
        window=cfg.window,
        center=True,
        length=None,
    ).astype(np.float32, copy=False)
    wave = wave / cfg.wnorm

    shift = cfg.win_len * 2
    return np.concatenate([wave[shift:], np.zeros(shift, dtype=np.float32)], axis=0)
dpdfnet/cli.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
from importlib import metadata
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import sys
|
|
7
|
+
from typing import Callable, List, Optional
|
|
8
|
+
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
from .models import DEFAULT_MODEL, get_cache_model_dir, supported_models
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _build_frame_progress_callback(
|
|
14
|
+
bar: tqdm,
|
|
15
|
+
) -> Callable[[int, int], None]:
|
|
16
|
+
last_done = 0
|
|
17
|
+
|
|
18
|
+
def _callback(done: int, total: int) -> None:
|
|
19
|
+
nonlocal last_done
|
|
20
|
+
if bar.total != total:
|
|
21
|
+
bar.total = total
|
|
22
|
+
bar.refresh()
|
|
23
|
+
delta = max(0, done - last_done)
|
|
24
|
+
if delta:
|
|
25
|
+
bar.update(delta)
|
|
26
|
+
last_done = done
|
|
27
|
+
|
|
28
|
+
return _callback
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _version_string() -> str:
|
|
32
|
+
try:
|
|
33
|
+
return f"dpdfnet {metadata.version('dpdfnet')}"
|
|
34
|
+
except metadata.PackageNotFoundError:
|
|
35
|
+
return "dpdfnet (local)"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _add_model_resolution_args(parser: argparse.ArgumentParser) -> None:
    """Attach the model-selection and verbosity options shared by run commands."""
    model_kwargs = dict(
        default=DEFAULT_MODEL,
        choices=supported_models(),
        help="Model name to run.",
    )
    parser.add_argument("--model", **model_kwargs)
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose model-resolution/download logs.",
    )
+
|
|
52
|
+
|
|
53
|
+
def _build_parser() -> argparse.ArgumentParser:
    """Construct the top-level CLI parser with its four subcommands."""
    root = argparse.ArgumentParser(
        prog="dpdfnet",
        description="DPDFNet CPU-only ONNX speech enhancement toolkit.",
    )
    root.add_argument(
        "--version",
        action="version",
        version=_version_string(),
    )

    commands = root.add_subparsers(dest="command")

    # "models": report cache location plus per-model availability.
    commands.add_parser(
        "models",
        help="List supported models and local availability.",
    )

    # "enhance": single-file processing.
    enhance_cmd = commands.add_parser(
        "enhance",
        help="Enhance a single wav file.",
    )
    enhance_cmd.add_argument("input", type=Path, help="Input wav file path.")
    enhance_cmd.add_argument("output", type=Path, help="Output wav file path.")
    _add_model_resolution_args(enhance_cmd)

    # "enhance-dir": batch processing of one flat directory.
    enhance_dir_cmd = commands.add_parser(
        "enhance-dir",
        help="Enhance all .wav files from one directory (non-recursive).",
    )
    enhance_dir_cmd.add_argument("input_dir", type=Path, help="Input directory.")
    enhance_dir_cmd.add_argument("output_dir", type=Path, help="Output directory.")
    _add_model_resolution_args(enhance_dir_cmd)

    # "download": prefetch model files into the cache.
    download_cmd = commands.add_parser(
        "download",
        help="Download all models by default, or a single model if provided.",
    )
    download_cmd.add_argument(
        "model",
        nargs="?",
        choices=supported_models(),
        default=None,
        help="Optional model name to download. If omitted, all models are fetched.",
    )
    # Hidden --model alias; _run_download reconciles it with the positional.
    download_cmd.add_argument(
        "--model",
        dest="model_flag",
        choices=supported_models(),
        default=None,
        help=argparse.SUPPRESS,
    )
    download_cmd.add_argument(
        "--force",
        "--refresh",
        action="store_true",
        help="Force re-download even if files are already cached.",
    )
    verbosity = download_cmd.add_mutually_exclusive_group()
    verbosity.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Suppress download progress messages.",
    )
    verbosity.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose download logs.",
    )

    return root
+
|
|
127
|
+
|
|
128
|
+
def _print_model_table() -> int:
    """Print the cache directory plus one availability line per model."""
    from .api import available_models

    print(f"cache_dir={get_cache_model_dir().resolve()}")
    for entry in available_models():
        print(
            f"{entry['name']}: sr={entry['sample_rate']}Hz, "
            f"ready={entry['ready']}, "
            f"onnx_found={entry['onnx_found']}, state_found={entry['state_found']}, "
            f"cached={entry['cached']}"
        )
    return 0
+
|
|
142
|
+
|
|
143
|
+
def _run_enhance(args: argparse.Namespace) -> int:
    """Handle the ``enhance`` subcommand: process one file with a frame bar."""
    from .api import enhance_file

    bar_kwargs = dict(
        total=0,
        unit="frame",
        desc="Enhancing",
        dynamic_ncols=True,
        file=sys.stderr,
    )
    with tqdm(**bar_kwargs) as bar:
        enhance_file(
            input_path=args.input,
            output_path=args.output,
            model=args.model,
            verbose=args.verbose,
            progress_callback=_build_frame_progress_callback(bar),
        )
    print(f"Wrote enhanced audio: {Path(args.output).expanduser().resolve()}")
    return 0
+
|
|
163
|
+
|
|
164
|
+
def _run_enhance_dir(args: argparse.Namespace) -> int:
    """Handle the ``enhance-dir`` subcommand.

    Enhances every ``.wav`` file directly inside ``args.input_dir``
    (non-recursive), writing ``<stem>_enhanced.wav`` files into
    ``args.output_dir`` while showing a per-file bar and an aggregate
    per-frame bar.

    Raises ``FileNotFoundError`` if the input directory is missing or
    contains no wav files.
    """
    from .api import enhance_file

    input_dir = Path(args.input_dir).expanduser().resolve()
    output_dir = Path(args.output_dir).expanduser().resolve()
    if not input_dir.is_dir():
        raise FileNotFoundError(f"Input directory not found: {input_dir}")

    # Non-recursive scan; extension match is case-insensitive.
    wav_files = sorted([p for p in input_dir.iterdir() if p.is_file() and p.suffix.lower() == ".wav"])
    if not wav_files:
        raise FileNotFoundError(f"No .wav files found in {input_dir}")

    output_dir.mkdir(parents=True, exist_ok=True)
    with tqdm(
        total=len(wav_files),
        unit="file",
        desc="Files",
        dynamic_ncols=True,
        file=sys.stderr,
    ) as files_progress:
        with tqdm(
            total=0,
            unit="frame",
            desc="Frames",
            dynamic_ncols=True,
            file=sys.stderr,
        ) as frames_progress:
            for wav_path in wav_files:
                out_path = output_dir / f"{wav_path.stem}_enhanced.wav"
                last_done = 0

                # Per-file adapter onto the aggregate frame bar: the initial
                # (0, total) call grows the bar's total by this file's frame
                # count; later calls advance it by the positive delta since
                # the previous call.
                def _callback(done: int, total: int) -> None:
                    nonlocal last_done
                    if done == 0:
                        frames_progress.total = (frames_progress.total or 0) + total
                        frames_progress.refresh()
                        last_done = 0
                        return
                    delta = max(0, done - last_done)
                    if delta:
                        frames_progress.update(delta)
                        last_done = done

                enhance_file(
                    input_path=wav_path,
                    output_path=out_path,
                    model=args.model,
                    verbose=args.verbose,
                    progress_callback=_callback,
                )
                files_progress.update(1)
                files_progress.set_postfix_str(wav_path.name)
    return 0
+
|
|
218
|
+
|
|
219
|
+
def _run_download(args: argparse.Namespace) -> int:
    """Handle the ``download`` subcommand and report what was fetched."""
    from .api import download

    positional, flagged = args.model, args.model_flag
    if positional is not None and flagged is not None and positional != flagged:
        raise ValueError("Conflicting model names provided in positional argument and --model.")

    selected = flagged if positional is None else positional
    result = download(
        model=selected,
        force=args.force,
        quiet=args.quiet,
        verbose=args.verbose,
    )
    if isinstance(result, dict):
        # All-models download returns a {name: cache_dir} mapping.
        print("Downloaded models:")
        for name, path in result.items():
            print(f"- {name}: {path}")
    else:
        label = "<unknown>" if selected is None else selected
        print(f"Downloaded '{label}' to: {result}")
    return 0
+
|
|
241
|
+
|
|
242
|
+
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point. Returns a process exit code (0 on success, 2 on error)."""
    parser = _build_parser()
    args = parser.parse_args(argv)

    if args.command is None:
        # Bare invocation: show help and exit successfully.
        parser.print_help()
        return 0

    handlers = {
        "models": lambda: _print_model_table(),
        "enhance": lambda: _run_enhance(args),
        "enhance-dir": lambda: _run_enhance_dir(args),
        "download": lambda: _run_download(args),
    }
    handler = handlers.get(args.command)
    try:
        if handler is not None:
            return handler()
    except Exception as exc:  # CLI boundary: report the failure and exit non-zero.
        print(f"Error: {exc}", file=sys.stderr)
        return 2

    # Unknown command (argparse choices normally prevent this): help + failure.
    parser.print_help()
    return 2
+
|
|
266
|
+
|
|
267
|
+
# Allow running this module directly (e.g. ``python -m dpdfnet.cli``).
if __name__ == "__main__":
    raise SystemExit(main())